diff --git a/modcc/cudaprinter.cpp b/modcc/cudaprinter.cpp index d151c2d03687bd8747a88036c513ea668ac162a5..904bf7b32a162eae55f92ce210c2afe22dce7fc0 100644 --- a/modcc/cudaprinter.cpp +++ b/modcc/cudaprinter.cpp @@ -1,6 +1,5 @@ #include <algorithm> -#include "cprinter.hpp" // needed for printing net_receive method #include "cudaprinter.hpp" #include "lexer.hpp" #include "options.hpp" @@ -121,7 +120,7 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) for(auto const &var : m.symbols()) { if (var.second->kind()==symbolKind::procedure && is_in(var.second->is_procedure()->kind(), - {procedureKind::normal, procedureKind::api})) + {procedureKind::normal, procedureKind::api, procedureKind::net_receive})) { var.second->accept(this); } @@ -434,22 +433,15 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) else if( var.second->kind()==symbolKind::procedure && var.second->is_procedure()->kind()==procedureKind::net_receive) { + // Simplest net_receive implementation forwards a single update + // to a GPU kernel. auto proc = var.second->is_procedure(); auto name = proc->name(); text_.add_line("void " + name + "(int i_, value_type weight) {"); text_.increase_indentation(); - - // Print the body of the net_receive block. - // Use the same body as would be generated with the cprinter. - // This is not omptimal, because each read and write will require - // a copy between host and device memory, so we will need a - // GPU-specific implementation - auto cprinter = CPrinter(*module_); - cprinter.clear_text(); - cprinter.set_gutter(text_.get_gutter()); - proc->body()->accept(&cprinter); - text_ << cprinter.text(); - + text_.add_line( + "kernels::" + name + "<value_type, size_type>" + + "<<<1, 1>>>(param_pack_, i_, weight);"); text_.decrease_indentation(); text_.add_line("}"); text_.add_line(); @@ -669,22 +661,48 @@ void CUDAPrinter::visit(ProcedureExpression *e) { e->location()); } - // print prototype - print_procedure_prototype(e); - text_.end_line(" {"); + if(e->kind() != procedureKind::net_receive) { + // print prototype + print_procedure_prototype(e); + text_.end_line(" {"); - // print body - increase_indentation(); + // print body + increase_indentation(); - text_.add_line("using value_type = T;"); - text_.add_line(); + text_.add_line("using value_type = T;"); + text_.add_line(); - e->body()->accept(this); + e->body()->accept(this); - // close up - decrease_indentation(); - text_.add_line("}"); - text_.add_line(); + // close up + decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + } + else { + // net_receive() kernel is a special case, not covered by APIMethod visit. + text_.add_gutter() << "template <typename T, typename I>\n"; + text_.add_line( "__global__"); + text_.add_gutter() << "void " << e->name() + << "(" << module_->name() << "_ParamPack<T,I> params_, " + << "I i_, T weight) {"; + text_.add_line(); + increase_indentation(); + + text_.add_line("using value_type = T;"); + text_.add_line("using iarray = I;"); + text_.add_line(); + + text_.add_line("if (threadIdx.x || blockIdx.x) return;"); + text_.add_line("auto tid_ = i_;"); + text_.add_line("auto gid_ __attribute__((unused)) = params_.ni[tid_];"); + + print_APIMethod_body(e); + + decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + } } void CUDAPrinter::visit(APIMethod *e) { @@ -725,7 +743,7 @@ void CUDAPrinter::visit(APIMethod *e) { text_.add_line(); } -void CUDAPrinter::print_APIMethod_body(APIMethod* e) { +void CUDAPrinter::print_APIMethod_body(ProcedureExpression* e) { // load indexes of ion channels auto uses_k = false; auto uses_na = false; diff --git a/modcc/cudaprinter.hpp b/modcc/cudaprinter.hpp index 0b2cb00480f90078b068e48474f941fb566b24dc..0f4acbbbe1c8b0ed53b978aeea946a6dd222a3c7 100644 --- a/modcc/cudaprinter.hpp +++ b/modcc/cudaprinter.hpp @@ -97,7 +97,7 @@ private: return module_->kind() == moduleKind::point; } - void print_APIMethod_body(APIMethod* e); + void print_APIMethod_body(ProcedureExpression* e); void print_procedure_prototype(ProcedureExpression *e); std::string index_string(Symbol *e); diff --git a/src/memory/device_coordinator.hpp b/src/memory/device_coordinator.hpp index 92ede76b87f031e298d2f9ccb19865a1109beedc..05fa27b931be3cf7e50320b25c369e45724ab2aa 100644 --- a/src/memory/device_coordinator.hpp +++ b/src/memory/device_coordinator.hpp @@ -290,7 +290,9 @@ public: // fill memory void set(view_type &rng, value_type value) { - gpu::fill<value_type>(rng.data(), value, rng.size()); + if (rng.size()) { + gpu::fill<value_type>(rng.data(), value, rng.size()); + } } // generate reference objects for a raw pointer.