Skip to content
Snippets Groups Projects
Unverified Commit 45ac39ae authored by Nora Abi Akar's avatar Nora Abi Akar Committed by GitHub
Browse files

Generate correct simd code for reading `peer_index` and `v_peer` (#1735)

Use `constraint_category` only when reading/writing data sources indexed by `node_index`. Every other indexed data source assumes that the underlying index belongs to the category `index_constraint::none`. 
Fixes #1734.
parent 7c381f04
No related branches found
No related tags found
No related merge requests found
......@@ -44,7 +44,7 @@ static std::string scaled(double coeff) {
struct index_prop {
std::string source_var; // array holding the indices
std::string index_name; // index into the array
bool node_index; // node index (cv) or cell index
index_kind kind; // node index (cv) or cell index or other
bool operator==(const index_prop& other) const {
return (source_var == other.source_var) && (index_name == other.index_name);
}
......@@ -538,8 +538,8 @@ namespace {
deref(indexed_variable_info d): d(d) {}
friend std::ostream& operator<<(std::ostream& o, const deref& wrap) {
auto index_var = wrap.d.cell_index_var.empty() ? wrap.d.node_index_var : wrap.d.cell_index_var;
auto i_name = index_i_name(index_var);
auto index_var = wrap.d.outer_index_var();
auto i_name = index_i_name(index_var);
index_var = pp_var_pfx + index_var;
return o << data_via_ppack(wrap.d) << '[' << (wrap.d.scalar() ? "0": i_name) << ']';
}
......@@ -551,13 +551,19 @@ std::list<index_prop> gather_indexed_vars(const std::vector<LocalVariable*>& ind
for (auto& sym: indexed_vars) {
auto d = decode_indexed_variable(sym->external_variable());
if (!d.scalar()) {
index_prop node_idx = {d.node_index_var, index, true};
auto it = std::find(indices.begin(), indices.end(), node_idx);
if (it == indices.end()) indices.push_front(node_idx);
if (!d.cell_index_var.empty()) {
index_prop cell_idx = {d.cell_index_var, node_index_i_name(d), false};
auto it = std::find(indices.begin(), indices.end(), cell_idx);
if (it == indices.end()) indices.push_back(cell_idx);
auto nested = !d.inner_index_var().empty();
auto outer_index_var = d.outer_index_var();
auto inner_index_var = nested? d.inner_index_var()+"i_": index;
index_prop index_var = {outer_index_var, inner_index_var, d.index_var_kind};
auto it = std::find(indices.begin(), indices.end(), index_var);
if (it == indices.end()) {
// If an inner index is required, push the outer index_var to the end of the list
if (nested) {
indices.push_back(index_var);
}
else {
indices.push_front(index_var);
}
}
}
}
......@@ -780,8 +786,9 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con
<< "[0]);\n";
}
else {
if (d.cell_index_var.empty()) {
switch (constraint) {
switch (d.index_var_kind) {
case index_kind::node: {
switch (constraint) {
case simd_expr_constraint::contiguous:
out << ";\n"
<< "assign(" << local->name() << ", indirect(" << data_via_ppack(d)
......@@ -795,12 +802,15 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con
out << ";\n"
<< "assign(" << local->name() << ", indirect(" << data_via_ppack(d)
<< ", " << node_index_i_name(d) << ", simd_width_, constraint_category_));\n";
}
break;
}
default: {
out << ";\n"
<< "assign(" << local->name() << ", indirect(" << data_via_ppack(d)
<< ", " << index_i_name(d.outer_index_var()) << ", simd_width_, index_constraint::none));\n";
break;
}
}
else {
out << ";\n"
<< "assign(" << local->name() << ", indirect(" << data_via_ppack(d)
<< ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none));\n";
}
}
......@@ -827,8 +837,9 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex
ENTER(out);
if (d.accumulate) {
if (d.cell_index_var.empty()) {
switch (constraint) {
switch (d.index_var_kind) {
case index_kind::node: {
switch (constraint) {
case simd_expr_constraint::contiguous:
{
std::string tempvar = "t_" + external->name();
......@@ -861,19 +872,24 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex
out << " += S::mul(w_, " << from->name() << ");\n";
}
}
}
break;
}
} else {
out << "indirect(" << data_via_ppack(d) << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none)";
if (coeff != 1) {
out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << "), " << from->name() << "));\n";
} else {
out << " += S::mul(w_, " << from->name() << ");\n";
default: {
out << "indirect(" << data_via_ppack(d) << ", " << index_i_name(d.outer_index_var()) << ", simd_width_, index_constraint::none)";
if (coeff != 1) {
out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << "), " << from->name() << "));\n";
} else {
out << " += S::mul(w_, " << from->name() << ");\n";
}
break;
}
}
}
else {
if (d.cell_index_var.empty()) {
switch (constraint) {
switch (d.index_var_kind) {
case index_kind::node: {
switch (constraint) {
case simd_expr_constraint::contiguous:
out << "indirect(" << data_via_ppack(d) << " + " << node_index_i_name(d) << ", simd_width_) = ";
break;
......@@ -882,13 +898,14 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex
break;
default:
out << "indirect(" << data_via_ppack(d) << ", " << node_index_i_name(d) << ", simd_width_, constraint_category_) = ";
}
break;
}
default: {
out << "indirect(" << data_via_ppack(d)
<< ", " << index_i_name(d.outer_index_var()) << ", simd_width_, index_constraint::none) = ";
break;
}
} else {
out << "indirect(" << data_via_ppack(d)
<< ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none) = ";
}
if (coeff != 1) {
......@@ -905,8 +922,9 @@ void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>&
simd_expr_constraint constraint) {
ENTER(out);
for (auto& index: indices) {
if (index.node_index) {
switch (constraint) {
switch (index.kind) {
case index_kind::node: {
switch (constraint) {
case simd_expr_constraint::contiguous:
case simd_expr_constraint::constant:
out << "auto " << source_index_i_name(index) << " = " << source_var(index) << "[" << index.index_name << "];\n";
......@@ -915,9 +933,12 @@ void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>&
out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(&" << source_var(index)
<< "[0] + " << index.index_name << ", simd_width_));\n";
break;
}
break;
}
} else {
switch (constraint) {
case index_kind::cell: {
// Treat like reading a state variable.
switch (constraint) {
case simd_expr_constraint::contiguous:
out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(" << source_var(index)
<< " + " << index.index_name << ", simd_width_));\n";
......@@ -930,6 +951,13 @@ void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>&
out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(" << source_var(index)
<< ", " << index.index_name << ", simd_width_, constraint_category_));\n";
break;
}
break;
}
default: {
out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(&" << source_var(index)
<< "[0] + " << index.index_name << ", simd_width_));\n";
break;
}
}
}
......
......@@ -364,13 +364,19 @@ void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc, bool
for (auto& sym: indexed_vars) {
auto d = decode_indexed_variable(sym->external_variable());
if (!d.scalar()) {
index_prop node_idx = {d.node_index_var, "tid_"};
auto it = std::find(indices.begin(), indices.end(), node_idx);
if (it == indices.end()) indices.push_front(node_idx);
if (!d.cell_index_var.empty()) {
index_prop cell_idx = {d.cell_index_var, index_i_name(d.node_index_var)};
auto it = std::find(indices.begin(), indices.end(), cell_idx);
if (it == indices.end()) indices.push_back(cell_idx);
auto nested = !d.inner_index_var().empty();
auto outer_index_var = d.outer_index_var();
auto inner_index_var = nested? index_i_name(d.inner_index_var()): "tid_";
index_prop index_var = {outer_index_var, inner_index_var};
auto it = std::find(indices.begin(), indices.end(), index_var);
if (it == indices.end()) {
// If an inner index is required, push the outer index_var to the end of the list
if (nested) {
indices.push_back(index_var);
}
else {
indices.push_front(index_var);
}
}
}
}
......@@ -419,7 +425,7 @@ namespace {
deref(indexed_variable_info v): v(v) {}
friend std::ostream& operator<<(std::ostream& o, const deref& wrap) {
auto index_var = wrap.v.cell_index_var.empty() ? wrap.v.node_index_var : wrap.v.cell_index_var;
auto index_var = wrap.v.outer_index_var();
return o << pp_var_pfx << wrap.v.data_var << '['
<< (wrap.v.scalar()? "0": index_i_name(index_var)) << ']';
}
......@@ -459,7 +465,7 @@ void emit_state_update_cu(std::ostream& out, Symbol* from,
out << pp_var_pfx << "weight[tid_]*" << from->name() << ',';
auto index_var = d.cell_index_var.empty() ? d.node_index_var : d.cell_index_var;
auto index_var = d.outer_index_var();
out << pp_var_pfx << d.data_var << ", " << index_i_name(index_var) << ", lane_mask_);\n";
}
else if (d.accumulate) {
......
......@@ -119,9 +119,26 @@ PostEventExpression* find_post_event(const Module& m) {
return it==m.symbols().end()? nullptr: it->second->is_post_event();
}
bool indexed_variable_info::scalar() const { return index_var_kind==index_kind::none; }
std::string indexed_variable_info::inner_index_var() const {
if (index_var_kind == index_kind::cell) return node_index_var;
return {};
}
std::string indexed_variable_info::outer_index_var() const {
switch(index_var_kind) {
case index_kind::node: return node_index_var;
case index_kind::cell: return cell_index_var;
case index_kind::other: return other_index_var;
default: return {};
}
}
indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
indexed_variable_info v;
v.node_index_var = "node_index";
v.index_var_kind = index_kind::node;
v.scale = 1;
v.accumulate = true;
v.readonly = true;
......@@ -139,7 +156,9 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
break;
case sourceKind::peer_voltage:
v.data_var="vec_v";
v.node_index_var = "peer_index";
v.other_index_var = "peer_index";
v.node_index_var = "";
v.index_var_kind = index_kind::other;
v.readonly = true;
break;
case sourceKind::current_density:
......@@ -169,6 +188,7 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
case sourceKind::time:
v.data_var = "vec_t";
v.cell_index_var = "vec_di";
v.index_var_kind = index_kind::cell;
v.readonly = true;
break;
case sourceKind::ion_current_density:
......@@ -197,6 +217,7 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
case sourceKind::ion_valence:
v.data_var = ion_pfx+".ionic_charge";
v.node_index_var = ""; // scalar global
v.index_var_kind = index_kind::none;
v.readonly = true;
break;
case sourceKind::temperature:
......
......@@ -124,10 +124,27 @@ NetReceiveExpression* find_net_receive(const Module& m);
PostEventExpression* find_post_event(const Module& m);
// For generating vectorized code for reading and writing data sources.
// node: The data source uses the CV index which is categorized into
// one of four constraints to optimize memory accesses.
// cell: The data source uses the cell index, which is in turn indexed
// according to the CV index.
// other: The data source is indexed according to some other index.
// Vector optimizations should be skipped.
// none: The data source is scalar.
enum class index_kind {
node,
cell,
other,
none
};
struct indexed_variable_info {
std::string data_var;
std::string node_index_var;
std::string cell_index_var;
std::string other_index_var;
index_kind index_var_kind;
bool accumulate = true; // true => add with weight_ factor on assignment
bool readonly = false; // true => can never be assigned to by a mechanism
......@@ -135,7 +152,9 @@ struct indexed_variable_info {
// Scale is the conversion factor from the data variable
// to the NMODL value.
double scale = 1;
bool scalar() const { return node_index_var.empty(); }
bool scalar() const;
std::string inner_index_var() const;
std::string outer_index_var() const;
};
indexed_variable_info decode_indexed_variable(IndexedVariable* sym);
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment