Skip to content
Snippets Groups Projects
Commit b07875e1 authored by Sam Yates's avatar Sam Yates Committed by Benjamin Cumming
Browse files

Generic external variables for CUDA printer. (#490)

Replace hard-coded index variable names in modcc cuda printer with ones derived from the external variable.

Uses `decode_indexed_variable`.
parent f8f169db
No related branches found
No related tags found
No related merge requests found
...@@ -412,11 +412,11 @@ void emit_api_body(std::ostream& out, APIMethod* method) { ...@@ -412,11 +412,11 @@ void emit_api_body(std::ostream& out, APIMethod* method) {
// SIMD printing: // SIMD printing:
std::string index_i_name(const std::string& index_var) { static std::string index_i_name(const std::string& index_var) {
return index_var+"i_"; return index_var+"i_";
} }
std::string index_constraint_name(const std::string& index_var) { static std::string index_constraint_name(const std::string& index_var) {
return index_var+"constraint_"; return index_var+"constraint_";
} }
......
...@@ -362,28 +362,27 @@ void emit_common_defs(std::ostream& out, const Module& module_) { ...@@ -362,28 +362,27 @@ void emit_common_defs(std::ostream& out, const Module& module_) {
out << popindent << "};\n\n"; out << popindent << "};\n\n";
} }
static std::string index_i_name(const std::string& index_var) {
return index_var+"i_";
}
void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) { void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) {
auto body = e->body(); auto body = e->body();
auto indexed_vars = indexed_locals(e->scope()); auto indexed_vars = indexed_locals(e->scope());
std::unordered_set<std::string> indices;
for (auto& sym: indexed_vars) {
indices.insert(decode_indexed_variable(sym->external_variable()).index_var);
}
if (!body->statements().empty()) { if (!body->statements().empty()) {
out << "int tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n"; out << "int tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n";
out << "if (tid_<n_) {\n" << indent; out << "if (tid_<n_) {\n" << indent;
out << "auto gid_ __attribute__((unused)) = params_.node_index_[tid_];\n";
out << "auto cid_ __attribute__((unused)) = params_.vec_ci_[gid_];\n";
auto uses_ion = [&] (ionKind k) {
for(auto &symbol : e->scope()->locals()) {
if (k == symbol.second->is_local_variable()->ion_channel()) {
return true;
}
}
return false;
};
uses_ion(ionKind::K) && out << "auto kid_ = params_.ion_k_index_[tid_];\n"; for (auto& index: indices) {
uses_ion(ionKind::Ca) && out << "auto caid_ = params_.ion_ca_index_[tid_];\n"; out << "auto " << index_i_name(index)
uses_ion(ionKind::Na) && out << "auto naid_ = params_.ion_na_index_[tid_];\n"; << " = params_." << index << "[tid_];\n";
}
for (auto& sym: indexed_vars) { for (auto& sym: indexed_vars) {
emit_state_read_cu(out, sym); emit_state_read_cu(out, sym);
...@@ -419,33 +418,19 @@ void emit_state_update_cu(std::ostream& out, Symbol* from, ...@@ -419,33 +418,19 @@ void emit_state_update_cu(std::ostream& out, Symbol* from,
if (!external->is_write()) return; if (!external->is_write()) return;
const bool is_minus = external->op()==tok::minus; const bool is_minus = external->op()==tok::minus;
auto d = decode_indexed_variable(external);
if (is_point_proc) { if (is_point_proc) {
out << "arb::gpu::reduce_by_key("; out << "arb::gpu::reduce_by_key(";
is_minus && out << "-"; is_minus && out << "-";
out << from->name() out << from->name()
<< ", params_." << decode_indexed_variable(external).data_var << ", gid_);\n"; << ", params_." << d.data_var << ", " << index_i_name(d.index_var) << ");\n";
} }
else { else {
out << cuprint(external) << (is_minus? " -= ": " += ") << from->name() << ";\n"; out << cuprint(external) << (is_minus? " -= ": " += ") << from->name() << ";\n";
} }
} }
// Gives the name of the local variable used to index into a field.
const char* index_id(Symbol *s) {
if(s->is_variable()) return "tid_";
else if(auto var = s->is_indexed_variable()) {
switch(var->ion_channel()) {
case ionKind::none: return "gid_";
case ionKind::Ca: return "caid_";
case ionKind::Na: return "naid_";
case ionKind::K: return "kid_";
default :
throw compiler_exception("CudaPrinter unknown ion type", s->location());
}
}
return "";
}
// CUDA Printer visitors // CUDA Printer visitors
void CudaPrinter::visit(VariableExpression *sym) { void CudaPrinter::visit(VariableExpression *sym) {
...@@ -454,7 +439,7 @@ void CudaPrinter::visit(VariableExpression *sym) { ...@@ -454,7 +439,7 @@ void CudaPrinter::visit(VariableExpression *sym) {
void CudaPrinter::visit(IndexedVariable *e) { void CudaPrinter::visit(IndexedVariable *e) {
auto d = decode_indexed_variable(e); auto d = decode_indexed_variable(e);
out_ << "params_." << d.data_var << "[" << index_id(e) << "]"; out_ << "params_." << d.data_var << "[" << index_i_name(d.index_var) << "]";
} }
void CudaPrinter::visit(CallExpression* e) { void CudaPrinter::visit(CallExpression* e) {
......
...@@ -251,7 +251,7 @@ public: ...@@ -251,7 +251,7 @@ public:
// fill memory // fill memory
void set(view_type &rng, value_type value) { void set(view_type &rng, value_type value) {
if (rng.size()) { if (rng.size()) {
gpu::fill<value_type>(rng.data(), value, rng.size()); arb::gpu::fill<value_type>(rng.data(), value, rng.size());
} }
} }
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment