From 45ac39ae77c6db9777b5b21f42768eca1baf53c4 Mon Sep 17 00:00:00 2001
From: Nora Abi Akar <nora.abiakar@gmail.com>
Date: Wed, 17 Nov 2021 06:58:04 +0100
Subject: [PATCH] Generate correct simd code for reading `peer_index` and
 `v_peer` (#1735)

Use `constraint_category` only when reading/writing data sources indexed by `node_index`. Every other indexed data source assumes that the underlying index belongs to the category `index_constraint::none`.
Fixes #1734.
---
 modcc/printer/cprinter.cpp    | 102 ++++++++++++++++++++++------------
 modcc/printer/gpuprinter.cpp  |  24 +++++---
 modcc/printer/printerutil.cpp |  23 +++++++-
 modcc/printer/printerutil.hpp |  21 ++++++-
 4 files changed, 122 insertions(+), 48 deletions(-)

diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp
index 23c4793f..0b9f0ff6 100644
--- a/modcc/printer/cprinter.cpp
+++ b/modcc/printer/cprinter.cpp
@@ -44,7 +44,7 @@ static std::string scaled(double coeff) {
 struct index_prop {
     std::string source_var; // array holding the indices
     std::string index_name; // index into the array
-    bool        node_index; // node index (cv) or cell index
+    index_kind  kind; // node index (cv) or cell index or other
     bool operator==(const index_prop& other) const {
         return (source_var == other.source_var) && (index_name == other.index_name);
     }
@@ -538,8 +538,8 @@ namespace {
         deref(indexed_variable_info d): d(d) {}
 
         friend std::ostream& operator<<(std::ostream& o, const deref& wrap) {
-            auto index_var = wrap.d.cell_index_var.empty() ? wrap.d.node_index_var : wrap.d.cell_index_var;
-            auto i_name    = index_i_name(index_var);
+            auto index_var = wrap.d.outer_index_var();
+            auto i_name = index_i_name(index_var);
             index_var = pp_var_pfx + index_var;
             return o << data_via_ppack(wrap.d) << '[' << (wrap.d.scalar() ? "0": i_name) << ']';
         }
@@ -551,13 +551,19 @@ std::list<index_prop> gather_indexed_vars(const std::vector<LocalVariable*>& ind
     for (auto& sym: indexed_vars) {
         auto d = decode_indexed_variable(sym->external_variable());
         if (!d.scalar()) {
-            index_prop node_idx = {d.node_index_var, index, true};
-            auto it = std::find(indices.begin(), indices.end(), node_idx);
-            if (it == indices.end()) indices.push_front(node_idx);
-            if (!d.cell_index_var.empty()) {
-                index_prop cell_idx = {d.cell_index_var, node_index_i_name(d), false};
-                auto it = std::find(indices.begin(), indices.end(), cell_idx);
-                if (it == indices.end()) indices.push_back(cell_idx);
+            auto nested = !d.inner_index_var().empty();
+            auto outer_index_var = d.outer_index_var();
+            auto inner_index_var = nested? d.inner_index_var()+"i_": index;
+            index_prop index_var = {outer_index_var, inner_index_var, d.index_var_kind};
+            auto it = std::find(indices.begin(), indices.end(), index_var);
+            if (it == indices.end()) {
+                // If an inner index is required, push the outer index_var to the end of the list
+                if (nested) {
+                    indices.push_back(index_var);
+                }
+                else {
+                    indices.push_front(index_var);
+                }
             }
         }
     }
@@ -780,8 +786,9 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con
                 << "[0]);\n";
         }
         else {
-            if (d.cell_index_var.empty()) {
-                switch (constraint) {
+            switch (d.index_var_kind) {
+                case index_kind::node: {
+                    switch (constraint) {
                     case simd_expr_constraint::contiguous:
                         out << ";\n"
                             << "assign(" << local->name() << ", indirect(" << data_via_ppack(d)
@@ -795,12 +802,15 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con
                         out << ";\n"
                             << "assign(" << local->name() << ", indirect(" << data_via_ppack(d)
                             << ", " << node_index_i_name(d) << ", simd_width_, constraint_category_));\n";
+                    }
+                    break;
+                }
+                default: {
+                    out << ";\n"
+                        << "assign(" << local->name() << ", indirect(" << data_via_ppack(d)
+                        << ", " << index_i_name(d.outer_index_var()) << ", simd_width_, index_constraint::none));\n";
+                    break;
                 }
-            }
-            else {
-                out << ";\n"
-                    << "assign(" << local->name() << ", indirect(" << data_via_ppack(d)
-                    << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none));\n";
             }
         }
 
@@ -827,8 +837,9 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex
     ENTER(out);
 
     if (d.accumulate) {
-        if (d.cell_index_var.empty()) {
-            switch (constraint) {
+        switch (d.index_var_kind) {
+            case index_kind::node: {
+                switch (constraint) {
                 case simd_expr_constraint::contiguous:
                 {
                     std::string tempvar = "t_" + external->name();
@@ -861,19 +872,24 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex
                         out << " += S::mul(w_, " << from->name() << ");\n";
                     }
                 }
+                }
+                break;
             }
-        } else {
-            out << "indirect(" << data_via_ppack(d) << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none)";
-            if (coeff != 1) {
-                out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << "), " << from->name() << "));\n";
-            } else {
-                out << " += S::mul(w_, " << from->name() << ");\n";
+            default: {
+                out << "indirect(" << data_via_ppack(d) << ", " << index_i_name(d.outer_index_var()) << ", simd_width_, index_constraint::none)";
+                if (coeff != 1) {
+                    out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << "), " << from->name() << "));\n";
+                } else {
+                    out << " += S::mul(w_, " << from->name() << ");\n";
+                }
+                break;
             }
         }
     }
     else {
-        if (d.cell_index_var.empty()) {
-            switch (constraint) {
+        switch (d.index_var_kind) {
+            case index_kind::node: {
+                switch (constraint) {
                 case simd_expr_constraint::contiguous:
                     out << "indirect(" << data_via_ppack(d) << " + " << node_index_i_name(d) << ", simd_width_) = ";
                     break;
@@ -882,13 +898,14 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex
                     break;
                 default:
                     out << "indirect(" << data_via_ppack(d) << ", " << node_index_i_name(d) << ", simd_width_, constraint_category_) = ";
+                }
+                break;
+            }
+            default: {
+                out << "indirect(" << data_via_ppack(d)
+                    << ", " << index_i_name(d.outer_index_var()) << ", simd_width_, index_constraint::none) = ";
+                break;
             }
-        } else {
-            out << "indirect(" << data_via_ppack(d)
-
-
-
-                << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none) = ";
         }
 
         if (coeff != 1) {
@@ -905,8 +922,9 @@ void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>&
                                 simd_expr_constraint constraint) {
     ENTER(out);
     for (auto& index: indices) {
-        if (index.node_index) {
-            switch (constraint) {
+        switch (index.kind) {
+            case index_kind::node: {
+                switch (constraint) {
                 case simd_expr_constraint::contiguous:
                 case simd_expr_constraint::constant:
                     out << "auto " << source_index_i_name(index) << " = " << source_var(index) << "[" << index.index_name << "];\n";
@@ -915,9 +933,12 @@ void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>&
                     out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(&" << source_var(index)
                         << "[0] + " << index.index_name << ", simd_width_));\n";
                     break;
+                }
+                break;
             }
-        } else {
-            switch (constraint) {
+            case index_kind::cell: {
+                // Treat like reading a state variable.
+                switch (constraint) {
                 case simd_expr_constraint::contiguous:
                     out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(" << source_var(index)
                         << " + " << index.index_name << ", simd_width_));\n";
@@ -930,6 +951,13 @@ void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>&
                     out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(" << source_var(index)
                         << ", " << index.index_name << ", simd_width_, constraint_category_));\n";
                     break;
+                }
+                break;
+            }
+            default: {
+                out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(&" << source_var(index)
+                    << "[0] + " << index.index_name << ", simd_width_));\n";
+                break;
             }
         }
     }
diff --git a/modcc/printer/gpuprinter.cpp b/modcc/printer/gpuprinter.cpp
index 8c6bfc0c..3b431592 100644
--- a/modcc/printer/gpuprinter.cpp
+++ b/modcc/printer/gpuprinter.cpp
@@ -364,13 +364,19 @@ void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc, bool
     for (auto& sym: indexed_vars) {
         auto d = decode_indexed_variable(sym->external_variable());
         if (!d.scalar()) {
-            index_prop node_idx = {d.node_index_var, "tid_"};
-            auto it = std::find(indices.begin(), indices.end(), node_idx);
-            if (it == indices.end()) indices.push_front(node_idx);
-            if (!d.cell_index_var.empty()) {
-                index_prop cell_idx = {d.cell_index_var, index_i_name(d.node_index_var)};
-                auto it = std::find(indices.begin(), indices.end(), cell_idx);
-                if (it == indices.end()) indices.push_back(cell_idx);
+            auto nested = !d.inner_index_var().empty();
+            auto outer_index_var = d.outer_index_var();
+            auto inner_index_var = nested? index_i_name(d.inner_index_var()): "tid_";
+            index_prop index_var = {outer_index_var, inner_index_var};
+            auto it = std::find(indices.begin(), indices.end(), index_var);
+            if (it == indices.end()) {
+                // If an inner index is required, push the outer index_var to the end of the list
+                if (nested) {
+                    indices.push_back(index_var);
+                }
+                else {
+                    indices.push_front(index_var);
+                }
             }
         }
     }
@@ -419,7 +425,7 @@ namespace {
 
         deref(indexed_variable_info v): v(v) {}
         friend std::ostream& operator<<(std::ostream& o, const deref& wrap) {
-            auto index_var = wrap.v.cell_index_var.empty() ? wrap.v.node_index_var : wrap.v.cell_index_var;
+            auto index_var = wrap.v.outer_index_var();
             return o << pp_var_pfx << wrap.v.data_var << '['
                      << (wrap.v.scalar()? "0": index_i_name(index_var)) << ']';
         }
@@ -459,7 +465,7 @@ void emit_state_update_cu(std::ostream& out, Symbol* from,
 
         out << pp_var_pfx << "weight[tid_]*" << from->name() << ',';
 
-        auto index_var = d.cell_index_var.empty() ? d.node_index_var : d.cell_index_var;
+        auto index_var = d.outer_index_var();
         out << pp_var_pfx << d.data_var << ", " << index_i_name(index_var) << ", lane_mask_);\n";
     }
     else if (d.accumulate) {
diff --git a/modcc/printer/printerutil.cpp b/modcc/printer/printerutil.cpp
index 47d146a1..7e7b7c31 100644
--- a/modcc/printer/printerutil.cpp
+++ b/modcc/printer/printerutil.cpp
@@ -119,9 +119,26 @@ PostEventExpression* find_post_event(const Module& m) {
     return it==m.symbols().end()? nullptr: it->second->is_post_event();
 }
 
+bool indexed_variable_info::scalar() const { return index_var_kind==index_kind::none; }
+
+std::string indexed_variable_info::inner_index_var() const {
+    if (index_var_kind == index_kind::cell) return node_index_var;
+    return {};
+}
+
+std::string indexed_variable_info::outer_index_var() const {
+    switch(index_var_kind) {
+        case index_kind::node: return node_index_var;
+        case index_kind::cell: return cell_index_var;
+        case index_kind::other: return other_index_var;
+        default: return {};
+    }
+}
+
 indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
     indexed_variable_info v;
     v.node_index_var = "node_index";
+    v.index_var_kind = index_kind::node;
     v.scale = 1;
     v.accumulate = true;
     v.readonly = true;
@@ -139,7 +156,9 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
         break;
     case sourceKind::peer_voltage:
         v.data_var="vec_v";
-        v.node_index_var = "peer_index";
+        v.other_index_var = "peer_index";
+        v.node_index_var = "";
+        v.index_var_kind = index_kind::other;
         v.readonly = true;
         break;
     case sourceKind::current_density:
@@ -169,6 +188,7 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
     case sourceKind::time:
         v.data_var = "vec_t";
         v.cell_index_var = "vec_di";
+        v.index_var_kind = index_kind::cell;
         v.readonly = true;
         break;
     case sourceKind::ion_current_density:
@@ -197,6 +217,7 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
     case sourceKind::ion_valence:
         v.data_var = ion_pfx+".ionic_charge";
         v.node_index_var = ""; // scalar global
+        v.index_var_kind = index_kind::none;
         v.readonly = true;
         break;
     case sourceKind::temperature:
diff --git a/modcc/printer/printerutil.hpp b/modcc/printer/printerutil.hpp
index cf872106..77daace5 100644
--- a/modcc/printer/printerutil.hpp
+++ b/modcc/printer/printerutil.hpp
@@ -124,10 +124,27 @@ NetReceiveExpression* find_net_receive(const Module& m);
 
 PostEventExpression* find_post_event(const Module& m);
 
+// For generating vectorized code for reading and writing data sources.
+// node: The data source uses the CV index which is categorized into
+//       one of four constraints to optimize memory accesses.
+// cell: The data source uses the cell index, which is in turn indexed
+//       according to the CV index.
+// other: The data source is indexed according to some other index.
+//        Vector optimizations should be skipped.
+// none: The data source is scalar.
+enum class index_kind {
+    node,
+    cell,
+    other,
+    none
+};
+
 struct indexed_variable_info {
     std::string data_var;
     std::string node_index_var;
     std::string cell_index_var;
+    std::string other_index_var;
+    index_kind  index_var_kind;
 
     bool accumulate = true; // true => add with weight_ factor on assignment
     bool readonly = false;  // true => can never be assigned to by a mechanism
@@ -135,7 +152,9 @@ struct indexed_variable_info {
     // Scale is the conversion factor from the data variable
     // to the NMODL value.
     double scale = 1;
-    bool scalar() const { return node_index_var.empty(); }
+    bool scalar() const;
+    std::string inner_index_var() const;
+    std::string outer_index_var() const;
 };
 
 indexed_variable_info decode_indexed_variable(IndexedVariable* sym);
-- 
GitLab