diff --git a/mechanisms/allen/Nap.mod b/mechanisms/allen/Nap.mod index 5eac9d53ced960b2ed5426877062a7a85e833ad1..c83c795d5d983c012dedf60469fc25e01aae8cac 100644 --- a/mechanisms/allen/Nap.mod +++ b/mechanisms/allen/Nap.mod @@ -44,9 +44,9 @@ DERIVATIVE states { h' = (hInf-h)/hTau } -INITIAL{ - rates(v) - h = hInf +INITIAL { + rates(v) + h = hInf } PROCEDURE rates(v){ diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt index 102b397f321f2a049fbdc529328b2e9b17d1f05a..59134b29cb08913a777f6c0f1583ebdc778f6b3a 100644 --- a/modcc/CMakeLists.txt +++ b/modcc/CMakeLists.txt @@ -2,13 +2,13 @@ # unit tests for the driver also use this library. set(libmodcc_sources - astmanip.cpp blocks.cpp errorvisitor.cpp expression.cpp functionexpander.cpp functioninliner.cpp + procinliner.cpp lexer.cpp kineticrewriter.cpp linearrewriter.cpp diff --git a/modcc/functioninliner.cpp b/modcc/functioninliner.cpp index 6d727e7f29d5cc6f47508cad5a10b82d49fa6a6f..e0435b0c367d2d87576f6f989bd058b2023429a6 100644 --- a/modcc/functioninliner.cpp +++ b/modcc/functioninliner.cpp @@ -158,7 +158,7 @@ void FunctionInliner::visit(AssignmentExpression* e) { auto body = f->function()->body()->clone(); for (auto&s: body->is_block()->statements()) { - s->semantic(e->scope()); + s->semantic(scope_); } body->accept(this); diff --git a/modcc/module.cpp b/modcc/module.cpp index 45fc00d5f5d4f7aeb6d8cdf907094df44c70e615..6b76116267a4f47c19f8fc27b67b01aa08fc5796 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -10,6 +10,7 @@ #include "errorvisitor.hpp" #include "functionexpander.hpp" #include "functioninliner.hpp" +#include "procinliner.hpp" #include "kineticrewriter.hpp" #include "linearrewriter.hpp" #include "module.hpp" @@ -905,8 +906,14 @@ int Module::semantic_func_proc() { } auto inline_and_simplify = [&](auto&& caller) { - auto rewritten = inline_function_calls(caller->name(), caller->body()); - caller->body(std::move(rewritten)); + { + auto rewritten = inline_function_calls(caller->name(), caller->body()); + caller->body(std::move(rewritten)); + } + { + auto rewritten = inline_procedure_calls(caller->name(), caller->body()); + caller->body(std::move(rewritten)); + } caller->body(constant_simplify(caller->body())); }; diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 6e93c1d0c15b194ea7cac5436f4282dfa5f95149..82f3bfb14a6188328b8eee6e67b489afac8b32d6 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -349,9 +349,9 @@ void SimdExprEmitter::visit(BlockExpression* block) { void SimdExprEmitter::visit(CallExpression* e) { if(is_indirect_) - out_ << e->name() << "(index_"; + out_ << e->name() << "(pp, index_"; else - out_ << e->name() << "(i_"; + out_ << e->name() << "(pp, i_"; if (processing_true_ && !current_mask_.empty()) { out_ << ", " << current_mask_; diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index d86df4faee5a2085b94f4c1dab3e0ac4db8a5c3d..5d6e95352ea8550f8d48bf3a144542c7a9acdbaf 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -49,9 +49,6 @@ struct index_prop { } }; -void emit_procedure_proto(std::ostream&, ProcedureExpression*, const std::string&, const std::string& qualified = ""); -void emit_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string&, const std::string& qualified = ""); -void emit_masked_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string&, const std::string& qualified = ""); void emit_api_body(std::ostream&, APIMethod*, const ApiFlags& flags={}); void emit_simd_api_body(std::ostream&, APIMethod*, const std::vector<VariableExpression*>& scalars, const ApiFlags&); void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>& indices, simd_expr_constraint constraint); @@ -122,7 +119,6 @@ struct simdprint { ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printer_options& opt) { auto name = module_.module_name(); auto namespace_name = "kernel_" + name; - auto ppack_name = "arb_mechanism_ppack"; auto ns_components = namespace_components(opt.cpp_namespace); APIMethod* net_receive_api = find_api_method(module_, "net_rec_api"); @@ -290,23 +286,10 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe out << fmt::format("[[maybe_unused]] auto* {}{} = pp->ion_states[{}].index;\\\n", pp_var_pfx, ion_index(ion), idx); idx++; } - out << "//End of IFACEBLOCK\n\n"; - - out << "// procedure prototypes\n"; - for (auto proc: normal_procedures(module_)) { - if (with_simd) { - emit_simd_procedure_proto(out, proc, ppack_name); - out << ";\n"; - emit_masked_simd_procedure_proto(out, proc, ppack_name); - out << ";\n"; - } else { - emit_procedure_proto(out, proc, ppack_name); - out << ";\n"; - } - } - out << "\n" - << "// interface methods\n"; - out << "static void init(arb_mechanism_ppack* pp) {\n" << indent; + out << "//End of IFACEBLOCK\n\n" + << "\n" + << "// interface methods\n" + << "static void init(arb_mechanism_ppack* pp) {\n" << indent; emit_body(init_api); if (init_api && init_api->body() && !init_api->body()->statements().empty()) { auto n = std::count_if(vars.arrays.begin(), vars.arrays.end(), @@ -374,36 +357,6 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe out << "static void post_event(arb_mechanism_ppack*) {}\n"; } - out << "\n// Procedure definitions\n"; - for (auto proc: normal_procedures(module_)) { - if (with_simd) { - emit_simd_procedure_proto(out, proc, ppack_name); - auto simd_print = simdprint(proc->body(), vars.scalars); - out << " {\n" - << indent - << "PPACK_IFACE_BLOCK;\n" - << simd_print - << popindent - << "}\n\n"; - - emit_masked_simd_procedure_proto(out, proc, ppack_name); - auto masked_print = simdprint(proc->body(), vars.scalars); - masked_print.set_masked(); - out << " {\n" - << indent - << "PPACK_IFACE_BLOCK;\n" - << masked_print - << popindent - << "}\n\n"; - } else { - emit_procedure_proto(out, proc, ppack_name); - out << " {\n" << indent - << "PPACK_IFACE_BLOCK;\n" - << cprint(proc->body()) - << popindent << "}\n"; - } - } - out << popindent << "#undef PPACK_IFACE_BLOCK\n" << "} // namespace kernel_" << name @@ -495,14 +448,6 @@ static std::string index_i_name(const std::string& index_var) { return index_var+"i_"; } -void emit_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& ppack_name, const std::string& qualified) { - out << "[[maybe_unused]] static void " << qualified << (qualified.empty()? "": "::") << e->name() << "(" << ppack_name << "* pp, int i_"; - for (auto& arg: e->args()) { - out << ", arb_value_type " << arg->is_argument()->name(); - } - out << ")"; -} - namespace { // Access through ppack std::string data_via_ppack(const indexed_variable_info& i) { return pp_var_pfx + i.data_var; } @@ -759,27 +704,6 @@ void SimdPrinter::visit(BlockExpression* block) { EXITM(out_, "block"); } -void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& ppack_name, const std::string& qualified) { - ENTER(out); - out << "[[maybe_unused]] static void " << qualified << (qualified.empty()? "": "::") << e->name() << "(arb_mechanism_ppack* pp, arb_index_type i_"; - for (auto& arg: e->args()) { - out << ", const simd_value& " << arg->is_argument()->name(); - } - out << ")"; - EXIT(out); -} - -void emit_masked_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& ppack_name, const std::string& qualified) { - ENTER(out); - out << "[[maybe_unused]] static void " << qualified << (qualified.empty()? "": "::") << e->name() - << "(arb_mechanism_ppack* pp, arb_index_type i_, simd_mask mask_input_"; - for (auto& arg: e->args()) { - out << ", const simd_value& " << arg->is_argument()->name(); - } - out << ")"; - EXIT(out); -} - void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_constraint constraint) { ENTER(out); out << "simd_value " << local->name(); diff --git a/modcc/printer/gpuprinter.cpp b/modcc/printer/gpuprinter.cpp index 8285683daee8de226913f9f258b5535111fd5a3d..7b326fea93b61d110efd547c89c3e52d1ccd376d 100644 --- a/modcc/printer/gpuprinter.cpp +++ b/modcc/printer/gpuprinter.cpp @@ -29,7 +29,6 @@ static std::string scaled(double coeff) { void emit_api_body_cu(std::ostream& out, APIMethod* method, const ApiFlags&); -void emit_procedure_body_cu(std::ostream& out, ProcedureExpression* proc); void emit_state_read_cu(std::ostream& out, LocalVariable* local); void emit_state_update_cu(std::ostream& out, Symbol* from, IndexedVariable* external, const ApiFlags&); @@ -193,32 +192,6 @@ ARB_LIBMODCC_API std::string emit_gpu_cu_source(const Module& module_, const pri << "using ::arb::gpu::min;\n" << "using ::arb::gpu::max;\n\n"; - // Procedures as __device__ functions. - auto emit_procedure_proto = [&] (ProcedureExpression* e) { - out << fmt::format("__device__\n" - "void {}(arb_mechanism_ppack params_, int tid_", - e->name()); - for(auto& arg: e->args()) out << ", arb_value_type " << arg->is_argument()->name(); - out << ")"; - }; - - auto emit_procedure_kernel = [&] (ProcedureExpression* e) { - emit_procedure_proto(e); - out << " {\n" << indent - << "PPACK_IFACE_BLOCK;\n" - << cuprint(e->body()) - << popindent << "}\n\n"; - }; - - for (auto& p: module_normal_procedures(module_)) { - emit_procedure_proto(p); - out << ";\n\n"; - } - - for (auto& p: module_normal_procedures(module_)) { - emit_procedure_kernel(p); - } - // API methods as __global__ kernels. auto emit_api_kernel = [&] (APIMethod* e, bool additive=false) { // Only print the kernel if the method is not empty. @@ -450,10 +423,6 @@ void emit_api_body_cu(std::ostream& out, APIMethod* e, const ApiFlags& flags) { } } -void emit_procedure_body_cu(std::ostream& out, ProcedureExpression* e) { - out << cuprint(e->body()); -} - namespace { // Convenience I/O wrapper for emitting indexed access to an external variable. diff --git a/modcc/procinliner.cpp b/modcc/procinliner.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b3f15269cd6024de26a08fce6e65e3ebece888d --- /dev/null +++ b/modcc/procinliner.cpp @@ -0,0 +1,208 @@ +#include "procinliner.hpp" + +#include <iostream> + +#include "astmanip.hpp" +#include "error.hpp" +#include "errorvisitor.hpp" +#include "symdiff.hpp" + +// Note: on renaming variables when inlining: +// Identifiers will refer to one of +// - LOCAL variables +// - argument, +// - global variable +// - PARAMETER +// - ASSIGNED +// +// All local variables are renamed and the mapping is stored in local_arg_map_. +// The mapping from arguments to local names is in call_args_map_. +// +// Local variable renaming of identifiers should be performed before call +// argument renaming. This means that if a local variable shadows an argument, +// the local variable takes precedence. + +void check_errors(Expression* e) { + ErrorVisitor v(""); + e->accept(&v); + if(v.num_errors()) throw compiler_exception("something went wrong with inlined procedure call ", e->location()); +} + +ARB_LIBMODCC_API expression_ptr inline_procedure_calls(std::string caller, BlockExpression* block) { + // The inliner will inline one procedure at a time. Once all procedures in a + // block have been inlined, the while loop will be broken + auto inline_block = block->clone(); + for(;;) { + inline_block->semantic(block->scope()); + + auto inliner = std::make_unique<ProcedureInliner>(caller); + inline_block->accept(inliner.get()); + inline_block = inliner->as_block(false); + if (inliner->state_ == ProcedureInliner::state::Done) break; + break; + } + return inline_block; +} + +// The inliner works on inlining one item at a time. +void ProcedureInliner::visit(Expression* e) { + if (state_ == state::Running) + throw compiler_exception("I don't know how to do procedure inlining for this statement: " + + e->to_string(), + e->location()); + statements_.push_back(e->clone()); +} + +// Only in procedures, always stays the same +void ProcedureInliner::visit(ConserveExpression *e) { + statements_.push_back(e->clone()); +} + +// Only in procedures, always stays the same +void ProcedureInliner::visit(CompartmentExpression *e) { + statements_.push_back(e->clone()); +} + +// Only in procedures, always stays the same +void ProcedureInliner::visit(LinearExpression *e) { + statements_.push_back(e->clone()); +} + +void ProcedureInliner::visit(LocalDeclaration* e) { + // If we are active, we need to rename variables + if (state_ == state::Running) { + std::map<std::string, Token> new_vars; + for (auto& var: e->variables()) { + auto unique_decl = make_unique_local_decl(scope_, e->location(), "r_"); + auto unique_name = unique_decl.id->is_identifier()->spelling(); + + // Local variables must be renamed to avoid collisions with the caller. + // The mappings are stored in local_arg_map + local_arg_map_.emplace(std::make_pair(var.first, std::move(unique_decl.id))); + + auto e_tok = var.second; + e_tok.spelling = unique_name; + new_vars[unique_name] = e_tok; + } + e->variables().swap(new_vars); + } + statements_.push_back(e->clone()); +} + +void ProcedureInliner::visit(UnaryExpression* e) { + if (state_ != state::Running) return; + + auto sub = substitute(e->expression(), local_arg_map_); + sub = substitute(sub, call_arg_map_); + e->replace_expression(std::move(sub)); + e->semantic(scope_); + check_errors(e); +} + +void ProcedureInliner::visit(BinaryExpression* e) { + if (state_ != state::Running) return; + e->replace_lhs(substitute(substitute(e->lhs(), local_arg_map_), call_arg_map_)); + e->replace_rhs(substitute(substitute(e->rhs(), local_arg_map_), call_arg_map_)); + e->semantic(scope_); + check_errors(e); +} + + +void ProcedureInliner::visit(AssignmentExpression* e) { + // If we're inlining a call, take care of variable renaming + if (state_ == state::Running) { + if (auto lhs = e->lhs()->is_identifier()) { + e->replace_lhs(substitute(e->lhs(), local_arg_map_)); + } + if (e->rhs()->is_identifier()) { + auto sub_rhs = substitute(e->rhs(), local_arg_map_); + sub_rhs = substitute(sub_rhs, call_arg_map_); + e->replace_rhs(std::move(sub_rhs)); + } + else { + e->rhs()->accept(this); + } + } + statements_.push_back(e->clone()); +} + +void ProcedureInliner::visit(IfExpression* e) { + expr_list_type outer; + std::swap(outer, statements_); + + e->condition()->accept(this); + e->true_branch()->accept(this); + auto true_branch = make_expression<BlockExpression>( + e->true_branch()->location(), + std::move(statements_), + true); + + statements_.clear(); + + expression_ptr false_branch; + if (e->false_branch()) { + e->false_branch()->accept(this); + false_branch = make_expression<BlockExpression>( + e->false_branch()->location(), + std::move(statements_), + true); + } + + statements_.clear(); + + statements_ = std::move(outer); + statements_.push_back(make_expression<IfExpression>( + e->location(), + e->condition()->clone(), + std::move(true_branch), + std::move(false_branch))); +} + +void ProcedureInliner::visit(CallExpression* e) { + if (state_ == state::Running) { + if (e->is_procedure_call()) { + auto nm = e->is_procedure_call()->name(); + if (nm == callee_ || nm == caller_) throw compiler_exception("recursive procedures not allowed", e->location()); + } + auto& args = e->is_procedure_call() + ? e->is_procedure_call()->args() + : e->is_function_call()->args(); + for (auto& a: args) { + if (a->is_identifier()) { + a = substitute(substitute(a, local_arg_map_), call_arg_map_); + } + else { + a->accept(this); + } + } + } + else if (state_ == state::Ready) { + // If we are ready to do some inlining, check if we can indeed inline this statement + if (auto call = e->is_procedure_call(); call != nullptr) { + // fetch the procedure, its body and its formal args + const auto& proc = call->procedure(); + const auto& body = proc->body()->clone(); + const auto& args = proc->args(); + + // store the args we are actually called with to do replacement + const auto& subs = call->args(); + for (unsigned i = 0; i < args.size(); ++i) { + call_arg_map_.emplace( + std::make_pair(args[i]->is_argument()->spelling(), + subs[i]->clone())); + } + + scope_ = e->scope(); + callee_ = proc->name(); + state_ = state::Running; + + for (auto& s: body->is_block()->statements()) s->semantic(scope_); + body->accept(this); + + state_ = state::Done; + } + } + else if (e->is_procedure_call()) { + statements_.push_back(e->clone()); + } +} diff --git a/modcc/procinliner.hpp b/modcc/procinliner.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dcc748635a4a425efaccff4fed7fbc323856d376 --- /dev/null +++ b/modcc/procinliner.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include <sstream> + +#include "scope.hpp" +#include "visitor.hpp" +#include <libmodcc/export.hpp> + +ARB_LIBMODCC_API expression_ptr inline_procedure_calls(std::string caller, BlockExpression* block); + +struct ARB_LIBMODCC_API ProcedureInliner: + public BlockRewriterBase +{ + enum class state { + Done, + Running, + Ready, + }; + + using BlockRewriterBase::visit; + ProcedureInliner(std::string caller) : BlockRewriterBase(), caller_(caller) {}; + ProcedureInliner(scope_ptr s): BlockRewriterBase(s) {} + + virtual void visit(Expression *e) override; + virtual void visit(CallExpression *e) override; + virtual void visit(ConserveExpression *e) override; + virtual void visit(CompartmentExpression *e) override; + virtual void visit(LinearExpression *e) override; + virtual void visit(AssignmentExpression* e) override; + virtual void visit(BinaryExpression* e) override; + virtual void visit(UnaryExpression* e) override; + virtual void visit(IfExpression* e) override; + virtual void visit(LocalDeclaration* e) override; + virtual void visit(NumberExpression* e) override {}; + virtual void visit(IdentifierExpression* e) override {}; + + std::string callee_, caller_; + expression_ptr lhs_; + std::map<std::string, expression_ptr> call_arg_map_; + std::map<std::string, expression_ptr> local_arg_map_; + scope_ptr scope_; + + state state_ = state::Ready; + +protected: + virtual void reset() override { + state_ = state::Ready; + callee_.clear(); + lhs_ = nullptr; + call_arg_map_.clear(); + local_arg_map_.clear(); + scope_.reset(); + BlockRewriterBase::reset(); + } +}; diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index a188edc0862d2c69ebcf1fd23cf8468a7b36ef71..89b1ddf95d570e54a6dd513c0cdb4512cdec8685 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -290,13 +290,19 @@ TEST(SimdPrinter, simd_if_else) { "indirect(_pp_var_s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_1_), mask_input_),simd_cast<simd_value>((double)42.0));\n" "indirect(_pp_var_s+i_, simd_width_) = S::where(mask_input_, u);" , + "simd_value r_0_;" "simd_mask mask_2_ = S::cmp_gt(simd_cast<simd_value>(indirect(_pp_var_g+i_, simd_width_)), (double)2.0);\n" "simd_mask mask_3_ = S::cmp_gt(simd_cast<simd_value>(indirect(_pp_var_g+i_, simd_width_)), (double)3.0);\n" "S::where(S::logical_and(mask_2_,mask_3_),i) = (double)0.;\n" "S::where(S::logical_and(mask_2_,S::logical_not(mask_3_)),i) = (double)1.0;\n" "simd_mask mask_4_ = S::cmp_lt(simd_cast<simd_value>(indirect(_pp_var_g+i_, simd_width_)), (double)1.0);\n" "indirect(_pp_var_s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_2_),mask_4_),simd_cast<simd_value>((double)2.0));\n" - "rates(i_, S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)), i);" + // This is the inlined call rates(pp, i_, S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)), i); + "simd_maskmask_5_=S::cmp_gt(i,(double)2.0);" + "S::where(S::logical_and(S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)),mask_5_),r_0_)=(double)7.0;" + "S::where(S::logical_and(S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)),S::logical_not(mask_5_)),r_0_)=(double)5.0;" + "indirect(_pp_var_s+i_,simd_width_)=S::where(S::logical_and(S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)),S::logical_not(mask_5_)),simd_cast<simd_value>((double)42.0));" + "indirect(_pp_var_s+i_,simd_width_)=S::where(S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)),r_0_);" }; Module m(io::read_all(DATADIR "/mod_files/test7.mod"), "test7.mod");