diff --git a/modcc/expression.hpp b/modcc/expression.hpp index e1eef14e77ed7258e39327348e8922ffe74bcc2d..280791a7d51cb6f09ba64f533b5d6bade288d0b6 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -1441,6 +1441,9 @@ public: : BinaryExpression(loc, tok::pow, std::move(lhs), std::move(rhs)) {} + // pow is a prefix binop + bool is_infix() const override {return false;} + void accept(Visitor *v) override; }; diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index d1366a46d87a38cca06ba1ffbc84855e3fd12f0d..e9c23e4d229221d853ac116ad947de6f8935cb2a 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -91,10 +91,6 @@ void CExprEmitter::visit(AssignmentExpression* e) { e->rhs()->accept(this); } -void CExprEmitter::visit(PowBinaryExpression* e) { - emit_as_call("pow", e->lhs(), e->rhs()); -} - void CExprEmitter::visit(BinaryExpression* e) { static std::unordered_map<tok, const char*> binop_tbl = { {tok::minus, "-"}, @@ -111,6 +107,7 @@ void CExprEmitter::visit(BinaryExpression* e) { {tok::ne, "!="}, {tok::min, "min"}, {tok::max, "max"}, + {tok::pow, "pow"}, }; if (!binop_tbl.count(e->op())) { @@ -178,14 +175,6 @@ void CExprEmitter::visit(IfExpression* e) { /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// std::unordered_set<std::string> SimdExprEmitter::mask_names_; -void SimdExprEmitter::visit(PowBinaryExpression* e) { - out_ << "S::pow("; - e->lhs()->accept(this); - out_ << ", "; - e->rhs()->accept(this); - out_ << ')'; -} - void SimdExprEmitter::visit(NumberExpression* e) { out_ << " (double)" << as_c_double(e->value()); } @@ -252,6 +241,7 @@ void SimdExprEmitter::visit(BinaryExpression* e) { {tok::ne, "S::cmp_neq"}, {tok::min, "S::min"}, {tok::max, "S::max"}, + {tok::pow, "S::pow"}, }; static std::unordered_map<tok, const char *> binop_tbl = { @@ -269,6 +259,7 @@ void SimdExprEmitter::visit(BinaryExpression* e) { {tok::ne, "!="}, {tok::min, "min"}, {tok::max, "max"}, + {tok::pow, "pow"}, }; @@ -385,25 +376,41 @@ void SimdExprEmitter::visit(AssignmentExpression* e) { auto lhs_pfxd = id_prefix(e->lhs()->is_identifier()); + // lhs should not be an IndexedVariable, only a VariableExpression or LocalVariable. + // IndexedVariables are only assigned in API calls and are handled in a special way. + if (lhs->is_indexed_variable()) { + throw (compiler_exception("Should not be trying to assign an IndexedVariable " + lhs->to_string(), lhs->location())); + } + // If lhs is a VariableExpression, it must be a range variable. Non-range variables + // are scalars and are read-only. if (lhs->is_variable() && lhs->is_variable()->is_range()) { + // input_mask_ will only appear in PROCEDURE and FUNCTION calls which can only + // assign to VariableExpression and LocalVariable. But only VariableExpression + // vectors need to be assigned only at the elements where input_mask_ is true. if (!input_mask_.empty()) { mask = "S::logical_and(" + mask + ", " + input_mask_ + ")"; } - if(is_indirect_) - out_ << "indirect(" << lhs_pfxd << "+index_, simd_width_) = "; - else - out_ << "indirect(" << lhs_pfxd << "+i_, simd_width_) = "; - out_ << "S::where(" << mask << ", "; + std::string index = is_indirect_ ? "index_" : "i_"; + out_ << "indirect(" << lhs_pfxd << "+" << index << ", simd_width_) = S::where(" << mask << ", "; + + // If the rhs is a scalar identifier or a number, it needs to be cast to a vector. + auto id = e->rhs()->is_identifier(); + bool num = e->rhs()->is_number(); + bool cast = num || (id && scalars_.count(id->name())); - bool cast = e->rhs()->is_number(); if (cast) out_ << "simd_cast<simd_value>("; e->rhs()->accept(this); + if (cast) out_ << ")"; out_ << ")"; - - if (cast) out_ << ")"; - } else { + } + else if (lhs->is_variable() && !lhs->is_variable()->is_range()) { + throw (compiler_exception("Should not be trying to assign a non-range variable " + lhs->to_string(), lhs->location())); + } + // Otherwise, lhs must be a LocalVariable, we don't need to mask assignment according to the + // input_mask_. + else { out_ << "S::where(" << mask << ", "; e->lhs()->accept(this); out_ << ") = "; diff --git a/modcc/printer/cexpr_emit.hpp b/modcc/printer/cexpr_emit.hpp index 9a2c3dbc6a56d5fe30e814049c26671ad54e747e..4b278ec85be11e6ad61369b6abbef9aaaf5c708d 100644 --- a/modcc/printer/cexpr_emit.hpp +++ b/modcc/printer/cexpr_emit.hpp @@ -21,7 +21,6 @@ public: void visit(UnaryExpression *e) override; void visit(BinaryExpression *e) override; void visit(AssignmentExpression *e) override; - void visit(PowBinaryExpression *e) override; void visit(NumberExpression *e) override; void visit(IfExpression *e) override; @@ -54,14 +53,14 @@ public: void visit(UnaryExpression *e) override; void visit(BinaryExpression *e) override; void visit(AssignmentExpression *e) override; - void visit(PowBinaryExpression *e) override; void visit(NumberExpression *e) override; void visit(IfExpression *e) override; protected: static std::unordered_set<std::string> mask_names_; bool processing_true_ = false; - bool is_indirect_ = false; + bool is_indirect_ = false; // For choosing between "index_" and "i_" as an index. Depends on whether + // we are in a procedure or handling a simd constraint in an API call. std::string current_mask_, current_mask_bar_, input_mask_; std::unordered_set<std::string> scalars_; Visitor* fallback_; diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 6dd28c88375854b386e346d6ba0de0991356dadd..74d92372977f7cf95c9a6fff1e5b446c01df0fa1 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -89,7 +89,8 @@ std::string do_cprint(Expression* cp, int ind) { struct simdprint { Expression* expr_; - bool is_indirect_ = false; + bool is_indirect_ = false; // For choosing between "index_" and "i_" as an index. Depends on whether + // we are in a procedure or handling a simd constraint in an API call. bool is_masked_ = false; std::unordered_set<std::string> scalars_; @@ -657,38 +658,44 @@ void SimdPrinter::visit(AssignmentExpression* e) { } Symbol* lhs = e->lhs()->is_identifier()->symbol(); - - bool cast = false; - if (auto id = e->rhs()->is_identifier()) { - if (scalars_.count(id->name())) cast = true; - } - if (e->rhs()->is_number()) cast = true; - if (scalars_.count(e->lhs()->is_identifier()->name())) cast = false; + std::string pfx = lhs->is_local_variable() ? "" : pp_var_pfx; + std::string index = is_indirect_ ? "index_" : "i_"; + // lhs should not be an IndexedVariable, only a VariableExpression or LocalVariable. + // IndexedVariables are only assigned in API calls and are handled in a special way. + if (lhs->is_indexed_variable()) { + throw (compiler_exception("Should not be trying to assign an IndexedVariable " + lhs->to_string(), lhs->location())); + } + // If lhs is a VariableExpression, it must be a range variable. Non-range variables + // are scalars and read-only. if (lhs->is_variable() && lhs->is_variable()->is_range()) { - std::string pfx = lhs->is_local_variable() ? "" : pp_var_pfx; - if(is_indirect_) - out_ << "indirect(" << pfx << lhs->name() << "+index_, simd_width_) = "; - else - out_ << "indirect(" << pfx << lhs->name() << "+i_, simd_width_) = "; - + out_ << "indirect(" << pfx << lhs->name() << "+" << index << ", simd_width_) = "; if (!input_mask_.empty()) out_ << "S::where(" << input_mask_ << ", "; + // If the rhs is a scalar identifier or a number, it needs to be cast to a vector. + auto id = e->rhs()->is_identifier(); + auto num = e->rhs()->is_number(); + bool cast = num || (id && scalars_.count(id->name())); + if (cast) out_ << "simd_cast<simd_value>("; e->rhs()->accept(this); if (cast) out_ << ")"; if (!input_mask_.empty()) out_ << ")"; - } else { - std::string pfx = lhs->is_local_variable() ? "" : pp_var_pfx; + } + else if (lhs->is_variable() && !lhs->is_variable()->is_range()) { + throw (compiler_exception("Should not be trying to assign a non-range variable " + lhs->to_string(), lhs->location())); + } + // Otherwise, lhs must be a LocalVariable, we don't need to mask assignment according to the + // input_mask_. + else { out_ << "assign(" << pfx << lhs->name() << ", "; if (auto rhs = e->rhs()->is_identifier()) { if (auto sym = rhs->symbol()) { // We shouldn't call the rhs visitor in this case because it automatically casts indirect expressions if (sym->is_variable() && sym->is_variable()->is_range()) { - auto index = is_indirect_ ? "index_" : "i_"; out_ << "indirect(" << pp_var_pfx << rhs->name() << "+" << index << ", simd_width_))"; return; } diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp index ca450eb284f53326c2f625974d50ea18bfcef258..b07b7f1b550abeda1b0553010c70ea7cbaa273c3 100644 --- a/modcc/printer/cprinter.hpp +++ b/modcc/printer/cprinter.hpp @@ -76,6 +76,7 @@ public: private: std::ostream& out_; std::string input_mask_; - bool is_indirect_ = false; + bool is_indirect_ = false; // For choosing between "index_" and "i_" as an index. Depends on whether + // we are in a procedure or handling a simd constraint in an API call. std::unordered_set<std::string> scalars_; };