diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 699485a7f046d383d713d9f2d9e60afa7b8be774..a75b9a47918a832e118b4c8aebd9dfba9571c5d1 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -412,11 +412,11 @@ void emit_api_body(std::ostream& out, APIMethod* method) { // SIMD printing: -std::string index_i_name(const std::string& index_var) { +static std::string index_i_name(const std::string& index_var) { return index_var+"i_"; } -std::string index_constraint_name(const std::string& index_var) { +static std::string index_constraint_name(const std::string& index_var) { return index_var+"constraint_"; } diff --git a/modcc/printer/cudaprinter.cpp b/modcc/printer/cudaprinter.cpp index 8051ab091e8cbd4f1c29fa7dd0df621eab609627..583a8501692b4d6a79d3dc657704e7581c706504 100644 --- a/modcc/printer/cudaprinter.cpp +++ b/modcc/printer/cudaprinter.cpp @@ -362,28 +362,27 @@ void emit_common_defs(std::ostream& out, const Module& module_) { out << popindent << "};\n\n"; } +static std::string index_i_name(const std::string& index_var) { + return index_var+"i_"; +} + void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) { auto body = e->body(); auto indexed_vars = indexed_locals(e->scope()); + std::unordered_set<std::string> indices; + for (auto& sym: indexed_vars) { + indices.insert(decode_indexed_variable(sym->external_variable()).index_var); + } + if (!body->statements().empty()) { out << "int tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n"; out << "if (tid_<n_) {\n" << indent; - out << "auto gid_ __attribute__((unused)) = params_.node_index_[tid_];\n"; - out << "auto cid_ __attribute__((unused)) = params_.vec_ci_[gid_];\n"; - - auto uses_ion = [&] (ionKind k) { - for(auto &symbol : e->scope()->locals()) { - if (k == symbol.second->is_local_variable()->ion_channel()) { - return true; - } - } - return false; - }; - uses_ion(ionKind::K) && out << "auto kid_ = params_.ion_k_index_[tid_];\n"; - uses_ion(ionKind::Ca) && out << "auto caid_ = params_.ion_ca_index_[tid_];\n"; - uses_ion(ionKind::Na) && out << "auto naid_ = params_.ion_na_index_[tid_];\n"; + for (auto& index: indices) { + out << "auto " << index_i_name(index) + << " = params_." << index << "[tid_];\n"; + } for (auto& sym: indexed_vars) { emit_state_read_cu(out, sym); @@ -419,33 +418,19 @@ void emit_state_update_cu(std::ostream& out, Symbol* from, if (!external->is_write()) return; const bool is_minus = external->op()==tok::minus; + auto d = decode_indexed_variable(external); + if (is_point_proc) { out << "arb::gpu::reduce_by_key("; is_minus && out << "-"; out << from->name() - << ", params_." << decode_indexed_variable(external).data_var << ", gid_);\n"; + << ", params_." << d.data_var << ", " << index_i_name(d.index_var) << ");\n"; } else { out << cuprint(external) << (is_minus? " -= ": " += ") << from->name() << ";\n"; } } -// Gives the name of the local variable used to index into a field. -const char* index_id(Symbol *s) { - if(s->is_variable()) return "tid_"; - else if(auto var = s->is_indexed_variable()) { - switch(var->ion_channel()) { - case ionKind::none: return "gid_"; - case ionKind::Ca: return "caid_"; - case ionKind::Na: return "naid_"; - case ionKind::K: return "kid_"; - default : - throw compiler_exception("CudaPrinter unknown ion type", s->location()); - } - } - return ""; -} - // CUDA Printer visitors void CudaPrinter::visit(VariableExpression *sym) { @@ -454,7 +439,7 @@ void CudaPrinter::visit(VariableExpression *sym) { void CudaPrinter::visit(IndexedVariable *e) { auto d = decode_indexed_variable(e); - out_ << "params_." << d.data_var << "[" << index_id(e) << "]"; + out_ << "params_." << d.data_var << "[" << index_i_name(d.index_var) << "]"; } void CudaPrinter::visit(CallExpression* e) { diff --git a/src/memory/device_coordinator.hpp b/src/memory/device_coordinator.hpp index ff45f28644055ed45fdc7d7e8f5e7f3e793a1220..bb8303f11ee157ca6b9bec0988b0f856179de868 100644 --- a/src/memory/device_coordinator.hpp +++ b/src/memory/device_coordinator.hpp @@ -251,7 +251,7 @@ public: // fill memory void set(view_type &rng, value_type value) { if (rng.size()) { - gpu::fill<value_type>(rng.data(), value, rng.size()); + arb::gpu::fill<value_type>(rng.data(), value, rng.size()); } }