diff --git a/modcc/expression.cpp b/modcc/expression.cpp index a2eb8daaf16e7741fc112bb26bfd998f8a9b1409..74f50ef142e0c151fe1e41e6a6d381471f835e83 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -244,6 +244,10 @@ std::string ArgumentExpression::to_string() const { return blue("arg") + " " + yellow(name_); } +expression_ptr ArgumentExpression::clone() const { + return make_expression<ArgumentExpression>(location_, token_); +} + void ArgumentExpression::semantic(scope_ptr scp) { error_ = false; scope_ = scp; @@ -686,6 +690,8 @@ void NetReceiveExpression::semantic(scope_type::symbol_map &global_symbols) { // create the scope for this procedure scope_ = std::make_shared<scope_type>(global_symbols); + scope_->in_api_context(true); + error_ = false; // add the argumemts to the list of local variables @@ -733,6 +739,8 @@ void PostEventExpression::semantic(scope_type::symbol_map &global_symbols) { // create the scope for this procedure scope_ = std::make_shared<scope_type>(global_symbols); + scope_->in_api_context(true); + error_ = false; // add the argumemts to the list of local variables diff --git a/modcc/expression.hpp b/modcc/expression.hpp index b208ac4a8ba73c3af0213f72f8c04d5f52b6c932..f12b429dd4fde2180d5059b98ac6933b028594f8 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -439,7 +439,7 @@ public: const std::string& spelling() const { return token_.spelling; } - + expression_ptr clone() const override; ~ArgumentExpression() {} void accept(Visitor *v) override; private: diff --git a/modcc/module.cpp b/modcc/module.cpp index 007397650afff58b074da1d5a0b223073f631706..dc2c5b9de93cfb2c2afacbc6f038a75731dfe6af 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -229,10 +229,13 @@ bool Module::semantic() { symbols_.find(name)->second->location()); return std::make_pair(nullptr, source); } - + std::vector<expression_ptr> args; + for (auto& a: source->args()) { + args.push_back(a->clone()); + } symbols_[name] = make_symbol<APIMethod>( loc, name, - std::vector<expression_ptr>(), // no arguments + std::move(args), make_expression<BlockExpression> (loc, expr_list_type(), false) ); @@ -507,6 +510,7 @@ bool Module::semantic() { if (has_symbol("net_receive", symbolKind::procedure)) { auto net_rec_api = make_empty_api_method("net_rec_api", "net_receive"); + net_rec_api.first->body(net_rec_api.second->body()->clone()); if (net_rec_api.second) { for (auto &s: net_rec_api.second->body()->statements()) { if (s->is_assignment()) { @@ -536,6 +540,10 @@ bool Module::semantic() { linear_ = linear; post_events_ = has_symbol("post_event", symbolKind::procedure); + if (post_events_) { + auto post_events_api = make_empty_api_method("post_event_api", "post_event"); + post_events_api.first->body(post_events_api.second->body()->clone()); + } // Are we writing an ionic reversal potential? If so, change the moduleKind to // `revpot` and assert that the mechanism is 'pure': it has no state variables; diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 27ec365d426dd62b24f9b3dd473dbaa6dbf996b2..b2a3c5f1519cbd06181c87c49de913118cd327df 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -42,7 +42,7 @@ void emit_procedure_proto(std::ostream&, ProcedureExpression*, const std::string void emit_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string&, const std::string& qualified = ""); void emit_masked_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string&, const std::string& qualified = ""); -void emit_api_body(std::ostream&, APIMethod*); +void emit_api_body(std::ostream&, APIMethod*, bool cv_loop = true); void emit_simd_api_body(std::ostream&, APIMethod*, const std::vector<VariableExpression*>& scalars); void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>& indices, simd_expr_constraint constraint); @@ -114,12 +114,12 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { auto ppack_name = make_cpu_ppack_name(name); auto ns_components = namespace_components(opt.cpp_namespace); - NetReceiveExpression* net_receive = find_net_receive(module_); - PostEventExpression* post_event = find_post_event(module_); - APIMethod* init_api = find_api_method(module_, "init"); - APIMethod* state_api = find_api_method(module_, "advance_state"); - APIMethod* current_api = find_api_method(module_, "compute_currents"); - APIMethod* write_ions_api = find_api_method(module_, "write_ions"); + APIMethod* net_receive_api = find_api_method(module_, "net_rec_api"); + APIMethod* post_event_api = find_api_method(module_, "post_event_api"); + APIMethod* init_api = find_api_method(module_, "init"); + APIMethod* state_api = find_api_method(module_, "advance_state"); + APIMethod* current_api = find_api_method(module_, "compute_currents"); + APIMethod* write_ions_api = find_api_method(module_, "write_ions"); bool with_simd = opt.simd.abi!=simd_spec::none; @@ -301,11 +301,12 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { emit_body(write_ions_api); out << popindent << "}\n\n"; - if (net_receive) { - const std::string weight_arg = net_receive->args().empty() ? "weight" : net_receive->args().front()->is_argument()->name(); + if (net_receive_api) { + const std::string weight_arg = net_receive_api->args().empty() ? "weight" : net_receive_api->args().front()->is_argument()->name(); out << - "void net_receive(" << ppack_name << "* pp, int i_, ::arb::fvm_value_type " << weight_arg << ") {\n" << indent << - cprint(net_receive->body()) << popindent << + "void net_receive(" << ppack_name << "* pp, int i_, ::arb::fvm_value_type " << weight_arg << ") {\n" << indent; + emit_api_body(out, net_receive_api, false); + out << popindent << "}\n\n" "void apply_events(" << ppack_name << "* pp, ::arb::fvm_size_type mechanism_id, ::arb::multicore::deliverable_event_stream::state events) {\n" << indent << "auto ncell = events.n_streams();\n" @@ -320,8 +321,8 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "\n"; } - if(post_event) { - const std::string time_arg = post_event->args().empty() ? "time" : post_event->args().front()->is_argument()->name(); + if(post_event_api) { + const std::string time_arg = post_event_api->args().empty() ? "time" : post_event_api->args().front()->is_argument()->name(); out << "void post_event(" << ppack_name << "* pp) {\n" << indent << "int n_ = pp->width_;\n" @@ -331,8 +332,9 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "auto offset_ = pp->n_detectors_ * cid_;\n" "for (::arb::fvm_index_type c = 0; c < pp->n_detectors_; c++) {\n" << indent << "auto " << time_arg << " = pp->time_since_spike_[offset_ + c];\n" - "if (" << time_arg << " >= 0) {\n" << indent << - cprint(post_event->body()) << popindent << + "if (" << time_arg << " >= 0) {\n" << indent; + emit_api_body(out, post_event_api, false); + out << popindent << "}\n" << popindent << "}\n" << popindent << "}\n" << popindent << @@ -378,10 +380,10 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "void compute_currents() override { " << namespace_name << "::compute_currents(&pp_); }\n" "void write_ions() override{ " << namespace_name << "::write_ions(&pp_); }\n"; - net_receive && + net_receive_api && out << "void apply_events(deliverable_event_stream::state events) override { " << namespace_name << "::apply_events(&pp_, mechanism_id_, events); }\n"; - post_event && + post_event_api && out << "void post_event() override { " << namespace_name << "::post_event(&pp_); };\n"; with_simd && @@ -563,6 +565,24 @@ namespace { }; } +std::list<index_prop> gather_indexed_vars(const std::vector<LocalVariable*>& indexed_vars, const std::string& index) { + std::list<index_prop> indices; + for (auto& sym: indexed_vars) { + auto d = decode_indexed_variable(sym->external_variable()); + if (!d.scalar()) { + index_prop node_idx = {d.node_index_var, index, true}; + auto it = std::find(indices.begin(), indices.end(), node_idx); + if (it == indices.end()) indices.push_front(node_idx); + if (!d.cell_index_var.empty()) { + index_prop cell_idx = {d.cell_index_var, node_index_i_name(d), false}; + auto it = std::find(indices.begin(), indices.end(), cell_idx); + if (it == indices.end()) indices.push_back(cell_idx); + } + } + } + return indices; +}; + void emit_state_read(std::ostream& out, LocalVariable* local) { ENTER(out); out << "::arb::fvm_value_type " << cprint(local) << " = "; @@ -607,33 +627,19 @@ void emit_state_update(std::ostream& out, Symbol* from, IndexedVariable* externa EXIT(out); } -void emit_api_body(std::ostream& out, APIMethod* method) { +void emit_api_body(std::ostream& out, APIMethod* method, bool cv_loop) { ENTER(out); auto body = method->body(); auto indexed_vars = indexed_locals(method->scope()); - std::list<index_prop> indices; - for (auto& sym: indexed_vars) { - auto d = decode_indexed_variable(sym->external_variable()); - if (!d.scalar()) { - index_prop node_idx = {d.node_index_var, "i_", true}; - auto it = std::find(indices.begin(), indices.end(), node_idx); - if (it == indices.end()) indices.push_front(node_idx); - if (!d.cell_index_var.empty()) { - index_prop cell_idx = {d.cell_index_var, node_index_i_name(d), false}; - auto it = std::find(indices.begin(), indices.end(), cell_idx); - if (it == indices.end()) indices.push_back(cell_idx); - } - } - } - + std::list<index_prop> indices = gather_indexed_vars(indexed_vars, "i_"); if (!body->statements().empty()) { - out << + cv_loop && out << "int n_ = pp->width_;\n" "for (int i_ = 0; i_ < n_; ++i_) {\n" << indent; for (auto index: indices) { - out << "auto " << source_index_i_name(index) << " = " << source_var(index) << "[" << index.index_name << " ];\n"; + out << "auto " << source_index_i_name(index) << " = " << source_var(index) << "[" << index.index_name << "];\n"; } for (auto& sym: indexed_vars) { @@ -644,7 +650,7 @@ void emit_api_body(std::ostream& out, APIMethod* method) { for (auto& sym: indexed_vars) { emit_state_update(out, sym, sym->external_variable()); } - out << popindent << "}\n"; + cv_loop && out << popindent << "}\n"; } EXIT(out); } @@ -1008,9 +1014,6 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, const std::vector< auto indexed_vars = indexed_locals(method->scope()); bool requires_weight = false; - std::vector<LocalVariable*> scalar_indexed_vars; - std::list<index_prop> indices; - ENTER(out); for (auto& s: body->is_block()->statements()) { @@ -1026,21 +1029,10 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, const std::vector< } } } - + std::list<index_prop> indices = gather_indexed_vars(indexed_vars, "index_"); + std::vector<LocalVariable*> scalar_indexed_vars; for (auto& sym: indexed_vars) { - auto info = decode_indexed_variable(sym->external_variable()); - if (!info.scalar()) { - index_prop node_idx = {info.node_index_var, "index_", true}; - auto it = std::find(indices.begin(), indices.end(), node_idx); - if (it == indices.end()) indices.push_front(node_idx); - - if (!info.cell_index_var.empty()) { - index_prop cell_idx = {info.cell_index_var, node_index_i_name(info), false}; - it = std::find(indices.begin(), indices.end(), cell_idx); - if (it == indices.end()) indices.push_back(cell_idx); - } - } - else { + if (decode_indexed_variable(sym->external_variable()).scalar()) { scalar_indexed_vars.push_back(sym); } } diff --git a/modcc/printer/gpuprinter.cpp b/modcc/printer/gpuprinter.cpp index 61ad097500a94c8d5b0c1841b771a8a2379ec753..3b5953b091f0be879ef0b9ea2973ae9e40a93669 100644 --- a/modcc/printer/gpuprinter.cpp +++ b/modcc/printer/gpuprinter.cpp @@ -15,7 +15,7 @@ using io::popindent; using io::quote; void emit_common_defs(std::ostream&, const Module& module_); -void emit_api_body_cu(std::ostream& out, APIMethod* method, bool is_point_proc); +void emit_api_body_cu(std::ostream& out, APIMethod* method, bool is_point_proc, bool cv_loop = true); void emit_procedure_body_cu(std::ostream& out, ProcedureExpression* proc); void emit_state_read_cu(std::ostream& out, LocalVariable* local); void emit_state_update_cu(std::ostream& out, Symbol* from, @@ -223,8 +223,8 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt auto ns_components = namespace_components(opt.cpp_namespace); const bool is_point_proc = module_.kind() == moduleKind::point; - NetReceiveExpression* net_receive = find_net_receive(module_); - PostEventExpression* post_event = find_post_event(module_); + APIMethod* net_receive_api = find_api_method(module_, "net_rec_api"); + APIMethod* post_event_api = find_api_method(module_, "post_event_api"); APIMethod* init_api = find_api_method(module_, "init"); APIMethod* state_api = find_api_method(module_, "advance_state"); APIMethod* current_api = find_api_method(module_, "compute_currents"); @@ -286,7 +286,8 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt if (!e->body()->statements().empty()) { out << "__global__\n" << "void " << e->name() << "(" << ppack_name << " params_) {\n" << indent - << "int n_ = params_.width_;\n"; + << "int n_ = params_.width_;\n" + << "int tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n"; emit_api_body_cu(out, e, is_point_proc); out << popindent << "}\n\n"; } @@ -298,8 +299,8 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt emit_api_kernel(write_ions_api); // event delivery - if (net_receive) { - const std::string weight_arg = net_receive->args().empty() ? "weight" : net_receive->args().front()->is_argument()->name(); + if (net_receive_api) { + const std::string weight_arg = net_receive_api->args().empty() ? "weight" : net_receive_api->args().front()->is_argument()->name(); out << "__global__\n" << "void apply_events(int mech_id_, " << ppack_name << " params_, " << "deliverable_event_stream_state events) {\n" << indent @@ -312,17 +313,17 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt << "for (auto p = begin; p<end; ++p) {\n" << indent << "if (p->mech_id==mech_id_) {\n" << indent << "auto tid_ = p->mech_index;\n" - << "auto " << weight_arg << " = p->weight;\n" - << cuprint(net_receive->body()) - << popindent << "}\n" + << "auto " << weight_arg << " = p->weight;\n"; + emit_api_body_cu(out, net_receive_api, is_point_proc, false); + out << popindent << "}\n" << popindent << "}\n" << popindent << "}\n" << popindent << "}\n"; } // event delivery - if (post_event) { - const std::string time_arg = post_event->args().empty() ? "time" : post_event->args().front()->is_argument()->name(); + if (post_event_api) { + const std::string time_arg = post_event_api->args().empty() ? "time" : post_event_api->args().front()->is_argument()->name(); out << "__global__\n" << "void post_event(" << ppack_name << " params_) {\n" << indent << "int n_ = params_.width_;\n" @@ -333,9 +334,9 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt << "auto offset_ = params_.n_detectors_ * cid_;\n" << "for (unsigned c = 0; c < params_.n_detectors_; c++) {\n" << indent << "auto " << time_arg << " = params_.time_since_spike_[offset_ + c];\n" - << "if (" << time_arg << " >= 0) {\n" << indent - << cuprint(post_event->body()) - << popindent << "}\n" + << "if (" << time_arg << " >= 0) {\n" << indent; + emit_api_body_cu(out, post_event_api, is_point_proc, false); + out << popindent << "}\n" << popindent << "}\n" << popindent << "}\n" << popindent << "}\n"; @@ -363,7 +364,7 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt emit_api_wrapper(state_api); emit_api_wrapper(write_ions_api); - net_receive && out + net_receive_api && out << "void " << class_name << "_apply_events_(" << "int mech_id, " << ppack_name << "& p, deliverable_event_stream_state events) {\n" << indent @@ -373,7 +374,7 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt << "apply_events<<<grid_dim, block_dim>>>(mech_id, p, events);\n" << popindent << "}\n\n"; - post_event && out + post_event_api && out << "void " << class_name << "_post_event_(" << ppack_name << "& p) {\n" << indent << "auto n = p.width_;\n" @@ -417,7 +418,7 @@ static std::string index_i_name(const std::string& index_var) { return index_var+"i_"; } -void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) { +void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc, bool cv_loop) { auto body = e->body(); auto indexed_vars = indexed_locals(e->scope()); @@ -445,7 +446,6 @@ void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) { } if (!body->statements().empty()) { - out << "int tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n"; if (is_point_proc) { // The run length information is only required if this method will // update an indexed variable, like current or conductance. @@ -457,7 +457,7 @@ void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) { } } - out << "if (tid_<n_) {\n" << indent; + cv_loop && out << "if (tid_<n_) {\n" << indent; for (auto& index: indices) { out << "auto " << index_i_name(index.source_var) @@ -473,7 +473,7 @@ void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) { for (auto& sym: indexed_vars) { emit_state_update_cu(out, sym, sym->external_variable(), is_point_proc); } - out << popindent << "}\n"; + cv_loop && out << popindent << "}\n"; } }