From b7ba25a18ec9a6636f8477a945f934a57febeecc Mon Sep 17 00:00:00 2001
From: Nora Abi Akar <nora.abiakar@gmail.com>
Date: Tue, 28 Sep 2021 11:34:54 +0200
Subject: [PATCH] Fix modcc simd generation (#1681)

Fixes 2 bugs where modcc expects a range variable but gets a non-range variable instead.

    PowBinaryExpressions were being handled separately from other BinaryExpressions and as a result, some crucial analysis was being skipped on the analysis of the lhs and rhs arguments. This PR incorporates the analysis and printing of PowBinaryExpression into BinaryExpression.
    Printing code for a masked AssignmentExpression was not handling the case of a non-range variable on the rhs of the assignment correctly. In that case, an explicit cast to a vector is needed, which is added in this PR.
---
 modcc/expression.hpp         |  3 +++
 modcc/printer/cexpr_emit.cpp | 49 ++++++++++++++++++++----------------
 modcc/printer/cexpr_emit.hpp |  5 ++--
 modcc/printer/cprinter.cpp   | 41 +++++++++++++++++-------------
 modcc/printer/cprinter.hpp   |  3 ++-
 5 files changed, 59 insertions(+), 42 deletions(-)

diff --git a/modcc/expression.hpp b/modcc/expression.hpp
index e1eef14e..280791a7 100644
--- a/modcc/expression.hpp
+++ b/modcc/expression.hpp
@@ -1441,6 +1441,9 @@ public:
     :   BinaryExpression(loc, tok::pow, std::move(lhs), std::move(rhs))
     {}
 
+    // pow is a prefix binop
+    bool is_infix() const override {return false;}
+
     void accept(Visitor *v) override;
 };
 
diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp
index d1366a46..e9c23e4d 100644
--- a/modcc/printer/cexpr_emit.cpp
+++ b/modcc/printer/cexpr_emit.cpp
@@ -91,10 +91,6 @@ void CExprEmitter::visit(AssignmentExpression* e) {
     e->rhs()->accept(this);
 }
 
-void CExprEmitter::visit(PowBinaryExpression* e) {
-    emit_as_call("pow", e->lhs(), e->rhs());
-}
-
 void CExprEmitter::visit(BinaryExpression* e) {
     static std::unordered_map<tok, const char*> binop_tbl = {
         {tok::minus,    "-"},
@@ -111,6 +107,7 @@ void CExprEmitter::visit(BinaryExpression* e) {
         {tok::ne,       "!="},
         {tok::min,      "min"},
         {tok::max,      "max"},
+        {tok::pow,      "pow"},
     };
 
     if (!binop_tbl.count(e->op())) {
@@ -178,14 +175,6 @@ void CExprEmitter::visit(IfExpression* e) {
 ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 std::unordered_set<std::string> SimdExprEmitter::mask_names_;
 
-void SimdExprEmitter::visit(PowBinaryExpression* e) {
-    out_ << "S::pow(";
-    e->lhs()->accept(this);
-    out_ << ", ";
-    e->rhs()->accept(this);
-    out_ << ')';
-}
-
 void SimdExprEmitter::visit(NumberExpression* e) {
     out_ << " (double)" << as_c_double(e->value());
 } 
@@ -252,6 +241,7 @@ void SimdExprEmitter::visit(BinaryExpression* e) {
             {tok::ne,       "S::cmp_neq"},
             {tok::min,      "S::min"},
             {tok::max,      "S::max"},
+            {tok::pow,      "S::pow"},
     };
 
     static std::unordered_map<tok, const char *> binop_tbl = {
@@ -269,6 +259,7 @@ void SimdExprEmitter::visit(BinaryExpression* e) {
             {tok::ne,       "!="},
             {tok::min,      "min"},
             {tok::max,      "max"},
+            {tok::pow,      "pow"},
     };
 
 
@@ -385,25 +376,41 @@ void SimdExprEmitter::visit(AssignmentExpression* e) {
 
     auto lhs_pfxd = id_prefix(e->lhs()->is_identifier());
 
+    // lhs should not be an IndexedVariable, only a VariableExpression or LocalVariable.
+    // IndexedVariables are only assigned in API calls and are handled in a special way.
+    if (lhs->is_indexed_variable()) {
+        throw (compiler_exception("Should not be trying to assign an IndexedVariable " + lhs->to_string(), lhs->location()));
+    }
+    // If lhs is a VariableExpression, it must be a range variable. Non-range variables
+    // are scalars and are read-only.
     if (lhs->is_variable() && lhs->is_variable()->is_range()) {
+        // input_mask_ will only appear in PROCEDURE and FUNCTION calls which can only
+        // assign to VariableExpression and LocalVariable. But only VariableExpression
+        // vectors need to be assigned only at the elements where input_mask_ is true.
         if (!input_mask_.empty()) {
             mask = "S::logical_and(" + mask + ", " + input_mask_ + ")";
         }
-        if(is_indirect_)
-            out_ << "indirect(" << lhs_pfxd << "+index_, simd_width_) = ";
-        else
-            out_ << "indirect(" << lhs_pfxd << "+i_, simd_width_) = ";
 
-        out_ << "S::where(" << mask << ", ";
+        std::string index = is_indirect_ ? "index_" : "i_";
+        out_ << "indirect(" << lhs_pfxd << "+" << index << ", simd_width_) = S::where(" << mask << ", ";
+
+        // If the rhs is a scalar identifier or a number, it needs to be cast to a vector.
+        auto id = e->rhs()->is_identifier();
+        bool num = e->rhs()->is_number();
+        bool cast = num || (id && scalars_.count(id->name()));
 
-        bool cast = e->rhs()->is_number();
         if (cast) out_ << "simd_cast<simd_value>(";
         e->rhs()->accept(this);
+        if (cast) out_ << ")";
 
         out_ << ")";
-
-        if (cast) out_ << ")";
-    } else {
+    }
+    else if (lhs->is_variable() && !lhs->is_variable()->is_range()) {
+        throw (compiler_exception("Should not be trying to assign a non-range variable " + lhs->to_string(), lhs->location()));
+    }
+    // Otherwise, lhs must be a LocalVariable, we don't need to mask assignment according to the
+    // input_mask_.
+    else {
         out_ << "S::where(" << mask << ", ";
         e->lhs()->accept(this);
         out_ << ") = ";
diff --git a/modcc/printer/cexpr_emit.hpp b/modcc/printer/cexpr_emit.hpp
index 9a2c3dbc..4b278ec8 100644
--- a/modcc/printer/cexpr_emit.hpp
+++ b/modcc/printer/cexpr_emit.hpp
@@ -21,7 +21,6 @@ public:
     void visit(UnaryExpression *e) override;
     void visit(BinaryExpression *e) override;
     void visit(AssignmentExpression *e) override;
-    void visit(PowBinaryExpression *e) override;
     void visit(NumberExpression *e) override;
     void visit(IfExpression *e) override;
 
@@ -54,14 +53,14 @@ public:
     void visit(UnaryExpression *e) override;
     void visit(BinaryExpression *e) override;
     void visit(AssignmentExpression *e) override;
-    void visit(PowBinaryExpression *e) override;
     void visit(NumberExpression *e) override;
     void visit(IfExpression *e) override;
 
 protected:
     static std::unordered_set<std::string> mask_names_;
     bool processing_true_ = false;
-    bool is_indirect_ = false;
+    bool is_indirect_ = false; // For choosing between "index_" and "i_" as an index. Depends on whether
+                               // we are in a procedure or handling a simd constraint in an API call.
     std::string current_mask_, current_mask_bar_, input_mask_;
     std::unordered_set<std::string> scalars_;
     Visitor* fallback_;
diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp
index 6dd28c88..74d92372 100644
--- a/modcc/printer/cprinter.cpp
+++ b/modcc/printer/cprinter.cpp
@@ -89,7 +89,8 @@ std::string do_cprint(Expression* cp, int ind) {
 
 struct simdprint {
     Expression* expr_;
-    bool is_indirect_ = false;
+    bool is_indirect_ = false; // For choosing between "index_" and "i_" as an index. Depends on whether
+                               // we are in a procedure or handling a simd constraint in an API call.
     bool is_masked_ = false;
     std::unordered_set<std::string> scalars_;
 
@@ -657,38 +658,44 @@ void SimdPrinter::visit(AssignmentExpression* e) {
     }
 
     Symbol* lhs = e->lhs()->is_identifier()->symbol();
- 
-    bool cast = false;
-    if (auto id = e->rhs()->is_identifier()) {
-        if (scalars_.count(id->name())) cast = true;
-    }
-    if (e->rhs()->is_number()) cast = true;
-    if (scalars_.count(e->lhs()->is_identifier()->name()))  cast = false;
+    std::string pfx = lhs->is_local_variable() ? "" : pp_var_pfx;
+    std::string index = is_indirect_ ? "index_" : "i_";
 
+    // lhs should not be an IndexedVariable, only a VariableExpression or LocalVariable.
+    // IndexedVariables are only assigned in API calls and are handled in a special way.
+    if (lhs->is_indexed_variable()) {
+        throw (compiler_exception("Should not be trying to assign an IndexedVariable " + lhs->to_string(), lhs->location()));
+    }
+    // If lhs is a VariableExpression, it must be a range variable. Non-range variables
+    // are scalars and read-only.
     if (lhs->is_variable() && lhs->is_variable()->is_range()) {
-        std::string pfx = lhs->is_local_variable() ? "" : pp_var_pfx;
-        if(is_indirect_)
-            out_ << "indirect(" << pfx << lhs->name() << "+index_, simd_width_) = ";
-        else
-            out_ << "indirect(" << pfx << lhs->name() << "+i_, simd_width_) = ";
-
+        out_ << "indirect(" << pfx << lhs->name() << "+" << index << ", simd_width_) = ";
         if (!input_mask_.empty())
             out_ << "S::where(" << input_mask_ << ", ";
 
+        // If the rhs is a scalar identifier or a number, it needs to be cast to a vector.
+        auto id = e->rhs()->is_identifier();
+        auto num = e->rhs()->is_number();
+        bool cast = num || (id && scalars_.count(id->name()));
+
         if (cast) out_ << "simd_cast<simd_value>(";
         e->rhs()->accept(this);
         if (cast) out_ << ")";
 
         if (!input_mask_.empty())
             out_ << ")";
-    } else {
-        std::string pfx = lhs->is_local_variable() ? "" : pp_var_pfx;
+    }
+    else if (lhs->is_variable() && !lhs->is_variable()->is_range()) {
+        throw (compiler_exception("Should not be trying to assign a non-range variable " + lhs->to_string(), lhs->location()));
+    }
+    // Otherwise, lhs must be a LocalVariable, we don't need to mask assignment according to the
+    // input_mask_.
+    else {
         out_ << "assign(" << pfx << lhs->name() << ", ";
         if (auto rhs = e->rhs()->is_identifier()) {
             if (auto sym = rhs->symbol()) {
                 // We shouldn't call the rhs visitor in this case because it automatically casts indirect expressions
                 if (sym->is_variable() && sym->is_variable()->is_range()) {
-                    auto index = is_indirect_ ? "index_" : "i_";
                     out_ << "indirect(" << pp_var_pfx << rhs->name() << "+" << index << ", simd_width_))";
                     return;
                 }
diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp
index ca450eb2..b07b7f1b 100644
--- a/modcc/printer/cprinter.hpp
+++ b/modcc/printer/cprinter.hpp
@@ -76,6 +76,7 @@ public:
 private:
     std::ostream& out_;
     std::string input_mask_;
-    bool is_indirect_ = false;
+    bool is_indirect_ = false; // For choosing between "index_" and "i_" as an index. Depends on whether
+                               // we are in a procedure or handling a simd constraint in an API call.
     std::unordered_set<std::string> scalars_;
 };
-- 
GitLab