Skip to content
Snippets Groups Projects
Unverified Commit 57394539 authored by Nora Abi Akar's avatar Nora Abi Akar Committed by GitHub
Browse files

Expose `time` to the mechanisms (#1113)

The `vec_t_` time vector is already available to the mechanisms, but not exposed. It is indexed by `cell_index` instead of directly by the `node_index` (`vec_t[vec_ci_[node_index_[i]]]`). This kind of indexing was previously unavailble. It has not been added to the printers. 
This PR also includes some cleanup in the vectorized code printer. 

Address #1109
parent a4fe0d07
No related branches found
No related tags found
No related merge requests found
......@@ -46,6 +46,7 @@ enum class sourceKind {
conductivity,
conductance,
dt,
time,
ion_current,
ion_current_density,
ion_revpot,
......
......@@ -598,8 +598,8 @@ void Module::add_variables_to_symbols() {
continue;
}
// Special case: 'celsius' is an external indexed-variable with a special
// data source. Retrieval of value is handled especially by printers.
// Special cases: 'celsius', 'diam' and 't' are external indexed-variables with special
// data sources. Retrieval of their values is handled especially by printers.
if (id.name() == "celsius") {
create_indexed_variable("celsius", sourceKind::temperature, accessKind::read, "", Location());
......@@ -607,6 +607,9 @@ void Module::add_variables_to_symbols() {
else if (id.name() == "diam") {
create_indexed_variable("diam", sourceKind::diameter, accessKind::read, "", Location());
}
else if (id.name() == "t") {
create_indexed_variable("t", sourceKind::time, accessKind::read, "", Location());
}
else {
// Parameters are scalar by default, but may later be changed to range.
auto& sym = create_variable(id.token,
......@@ -619,7 +622,7 @@ void Module::add_variables_to_symbols() {
}
}
// Remove `celsius` and `diam` from the parameter block, as they are not true parameters anymore.
// Remove `celsius`, `diam` and `t` from the parameter block, as they are not true parameters anymore.
parameter_block_.parameters.erase(
std::remove_if(parameter_block_.begin(), parameter_block_.end(),
[](const Id& id) { return id.name() == "celsius"; }),
......@@ -632,6 +635,12 @@ void Module::add_variables_to_symbols() {
parameter_block_.end()
);
parameter_block_.parameters.erase(
std::remove_if(parameter_block_.begin(), parameter_block_.end(),
[](const Id& id) { return id.name() == "t"; }),
parameter_block_.end()
);
// Add 'assigned' variables, ignoring built-in voltage variable "v".
for (const Id& id: assigned_block_) {
if (id.name() == "v") {
......
This diff is collapsed.
#include <cmath>
#include <iostream>
#include <string>
#include <unordered_set>
#include <set>
#include "gpuprinter.hpp"
#include "expression.hpp"
......@@ -384,11 +384,26 @@ void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) {
auto body = e->body();
auto indexed_vars = indexed_locals(e->scope());
std::unordered_set<std::string> indices;
struct index_prop {
std::string source_var; // array holding the indices
std::string index_name; // index into the array
bool operator==(const index_prop& other) const {
return (source_var == other.source_var) && (index_name==other.index_name);
}
};
std::list<index_prop> indices;
for (auto& sym: indexed_vars) {
auto d = decode_indexed_variable(sym->external_variable());
if (!d.scalar()) {
indices.insert(d.index_var);
index_prop node_idx = {d.node_index_var, "tid_"};
auto it = std::find(indices.begin(), indices.end(), node_idx);
if (it == indices.end()) indices.push_front(node_idx);
if (!d.cell_index_var.empty()) {
index_prop cell_idx = {d.cell_index_var, index_i_name(d.node_index_var)};
auto it = std::find(indices.begin(), indices.end(), cell_idx);
if (it == indices.end()) indices.push_back(cell_idx);
}
}
}
......@@ -408,8 +423,8 @@ void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) {
out << "if (tid_<n_) {\n" << indent;
for (auto& index: indices) {
out << "auto " << index_i_name(index)
<< " = params_." << index << "[tid_];\n";
out << "auto " << index_i_name(index.source_var)
<< " = params_." << index.source_var << "[" << index.index_name << "];\n";
}
for (auto& sym: indexed_vars) {
......@@ -437,8 +452,9 @@ namespace {
deref(indexed_variable_info v): v(v) {}
friend std::ostream& operator<<(std::ostream& o, const deref& wrap) {
auto index_var = wrap.v.cell_index_var.empty() ? wrap.v.node_index_var : wrap.v.cell_index_var;
return o << "params_." << wrap.v.data_var << '['
<< (wrap.v.scalar()? "0": index_i_name(wrap.v.index_var)) << ']';
<< (wrap.v.scalar()? "0": index_i_name(index_var)) << ']';
}
};
}
......@@ -476,7 +492,9 @@ void emit_state_update_cu(std::ostream& out, Symbol* from,
if (coeff != 1) out << as_c_double(coeff) << '*';
out << "params_.weight_[tid_]*" << from->name() << ',';
out << "params_." << d.data_var << ", " << index_i_name(d.index_var) << ", lane_mask_);\n";
auto index_var = d.cell_index_var.empty() ? d.node_index_var : d.cell_index_var;
out << "params_." << d.data_var << ", " << index_i_name(index_var) << ", lane_mask_);\n";
}
else if (d.accumulate) {
out << deref(d) << " = fma(";
......
......@@ -115,7 +115,7 @@ NetReceiveExpression* find_net_receive(const Module& m) {
indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
indexed_variable_info v;
v.index_var = "node_index_";
v.node_index_var = "node_index_";
v.scale = 1;
v.accumulate = true;
v.readonly = true;
......@@ -123,7 +123,7 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
std::string ion_pfx;
if (sym->is_ion()) {
ion_pfx = "ion_"+sym->ion_channel()+"_";
v.index_var = ion_pfx+"index_";
v.node_index_var = ion_pfx+"index_";
}
switch (sym->data_source()) {
......@@ -155,6 +155,11 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
v.data_var = "vec_dt_";
v.readonly = true;
break;
case sourceKind::time:
v.data_var = "vec_t_";
v.cell_index_var = "vec_ci_";
v.readonly = true;
break;
case sourceKind::ion_current_density:
v.data_var = ion_pfx+".current_density";
v.scale = 0.1;
......@@ -180,7 +185,7 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
break;
case sourceKind::ion_valence:
v.data_var = ion_pfx+".ionic_charge";
v.index_var = ""; // scalar global
v.node_index_var = ""; // scalar global
v.readonly = true;
break;
case sourceKind::temperature:
......
......@@ -116,7 +116,8 @@ NetReceiveExpression* find_net_receive(const Module& m);
struct indexed_variable_info {
std::string data_var;
std::string index_var;
std::string node_index_var;
std::string cell_index_var;
bool accumulate = true; // true => add with weight_ factor on assignment
bool readonly = false; // true => can never be assigned to by a mechanism
......@@ -124,7 +125,7 @@ struct indexed_variable_info {
// Scale is the conversion factor from the data variable
// to the NMODL value.
double scale = 1;
bool scalar() const { return index_var.empty(); }
bool scalar() const { return node_index_var.empty(); }
};
indexed_variable_info decode_indexed_variable(IndexedVariable* sym);
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment