diff --git a/modcc/astmanip.cpp b/modcc/astmanip.cpp index 898374ff332395e40eb7572f03429bffd9f37a86..4d58b09226bc9f75c4930d95f28ceeab73b8e5db 100644 --- a/modcc/astmanip.cpp +++ b/modcc/astmanip.cpp @@ -34,7 +34,7 @@ local_declaration make_unique_local_decl(scope_ptr scope, Location loc, std::str auto local = make_expression<LocalDeclaration>(loc, name); local->semantic(scope); - auto id = make_expression<LocalDeclaration>(loc, name); + auto id = make_expression<IdentifierExpression>(loc, name); id->semantic(scope); return { std::move(local), std::move(id), scope }; diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 1e06c82ae657c076d4bf0066feb7561db06778e7..250af45d7f31877992edc19be2ec0fd2e8682a38 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -690,24 +690,6 @@ void FunctionExpression::semantic(scope_type::symbol_map &global_symbols) { if(e->is_initial_block()) error("INITIAL block not allowed inside FUNCTION definition"); } - // check that the last expression in the body was an assignment to - // the return placeholder - bool last_expr_is_assign = false; - auto tail = body()->back()->is_assignment(); - if(tail) { - // we know that the tail is an assignment expression - auto lhs = tail->lhs()->is_identifier(); - // use nullptr check followed by lazy name lookup - if(lhs && lhs->name()==name()) { - last_expr_is_assign = true; - } - } - if(!last_expr_is_assign) { - warning("the last expression in function '" - + yellow(name()) - + "' does not set the return value"); - } - // the symbol for this expression is itself // this could lead to nasty self-referencing loops symbol_ = scope_->find_global(name()); diff --git a/modcc/functioninliner.cpp b/modcc/functioninliner.cpp index 20c0c2ede1394a57f065c125192f6deab76a0bd2..f238a684529c53ae9fcbe2e1d0da86d17ac9f0a6 100644 --- a/modcc/functioninliner.cpp +++ b/modcc/functioninliner.cpp @@ -1,85 +1,163 @@ #include <iostream> +#include "astmanip.hpp" #include "error.hpp" #include "functioninliner.hpp" #include "errorvisitor.hpp" -expression_ptr inline_function_call(Expression* e) +expression_ptr inline_function_call(const expression_ptr& e) { - if(auto f=e->is_function_call()) { - auto func = f->function(); -#ifdef LOGGING - std::cout << "inline_function_call for statement " << f->to_string() - << " with body" << func->body()->to_string() << "\n"; -#endif - auto& body = func->body()->statements(); - if(body.size() != 1) { - throw compiler_exception( - "can only inline functions with one statement", func->location() - ); + auto assign_to_func = e->is_assignment(); + auto ret_identifier = assign_to_func->lhs()->is_identifier(); + + if(auto f = assign_to_func->rhs()->is_function_call()) { + auto body = f->function()->body()->clone(); + + for (auto&s: body->is_block()->statements()) { + s->semantic(e->scope()); } - if(body.front()->is_if()) { - throw compiler_exception( - "can not inline functions with if statements", func->location() - ); + FunctionInliner func_inliner(f->name(), ret_identifier, f->function()->args(), f->args(), e->scope()); + + body->accept(&func_inliner); + if (!func_inliner.return_val_set()) { + throw compiler_exception(pprintf("return variable of function % not set", f->name()), e->location()); } + return body; + } + return {}; +} +/////////////////////////////////////////////////////////////////////////////// +// function inliner +/////////////////////////////////////////////////////////////////////////////// - // assume that the function body is correctly formed, with the last - // statement being an assignment expression - auto last = body.front()->is_assignment(); - auto new_e = last->rhs()->clone(); - - auto& fargs = func->args(); // argument names for the function - auto& cargs = f->args(); // arguments at the call site - for(auto i=0u; i<fargs.size(); ++i) { - if(auto id = cargs[i]->is_identifier()) { -#ifdef LOGGING - std::cout << "inline_function_call symbol replacement " - << id->to_string() << " -> " << fargs[i]->to_string() - << " in the expression " << new_e->to_string() << "\n"; -#endif - VariableReplacer v( - fargs[i]->is_argument()->spelling(), - id->spelling() - ); - new_e->accept(&v); - } - else if(auto value = cargs[i]->is_number()) { -#ifdef LOGGING - std::cout << "inline_function_call symbol replacement " - << value->to_string() << " -> " << fargs[i]->to_string() - << " in the expression " << new_e->to_string() << "\n"; -#endif - ValueInliner v( - fargs[i]->is_argument()->spelling(), - value->value() - ); - new_e->accept(&v); - } - else { - throw compiler_exception( - "can't inline functions with expressions as arguments", - e->location() - ); +// Takes a Binary or Unary Expression and replaces its variables that match any +// function argument in fargs_ with the corresponding call argument in cargs_ +void FunctionInliner::replace_args(Expression* e) { + for(auto i=0u; i<fargs_.size(); ++i) { + if(auto id = cargs_[i]->is_identifier()) { + VariableReplacer v(fargs_[i], id->spelling()); + e->accept(&v); + } + else if(auto value = cargs_[i]->is_number()) { + ValueInliner v(fargs_[i], value->value()); + e->accept(&v); + } + else { + throw compiler_exception("can't inline functions with expressions as arguments", e->location()); + } + } + e->semantic(scope_); + + ErrorVisitor v(""); + e->accept(&v); + if(v.num_errors()) { + throw compiler_exception("something went wrong with inlined function call ", e->location()); + } +} + +void FunctionInliner::visit(Expression* e) { + throw compiler_exception( + "I don't know how to do function inlining for this statement : " + + e->to_string(), e->location()); +} + +void FunctionInliner::visit(LocalDeclaration* e) { + auto loc = e->location(); + + std::map<std::string, Token> new_vars; + for (auto& var: e->variables()) { + auto unique_decl = make_unique_local_decl(scope_, loc, "r_"); + auto unique_name = unique_decl.id->is_identifier()->spelling(); + + // Local variables must be renamed to avoid collisions with the calling function. + // They are considered part of the function arguments `fargs_` and the renamed + // variable is considered part of the call arguments `cargs_` + fargs_.push_back(var.first); + cargs_.push_back(unique_decl.id->clone()); + + auto e_tok = var.second; + e_tok.spelling = unique_name; + new_vars[unique_name] = e_tok; + } + e->variables().swap(new_vars); +} + +void FunctionInliner::visit(BlockExpression* e) { + for (auto& expr: e->statements()) { + expr->accept(this); + } +} + +void FunctionInliner::visit(UnaryExpression* e) { + replace_args(e); +} + +void FunctionInliner::visit(BinaryExpression* e) { + replace_args(e); +} + +void FunctionInliner::visit(AssignmentExpression* e) { + if (auto lhs = e->lhs()->is_identifier()) { + if (lhs->spelling() == func_name_) { + e->replace_lhs(lhs_->clone()); + return_set_ = true; + } else { + for (unsigned i = 0; i < fargs_.size(); i++) { + if (fargs_[i] == lhs->spelling()) { + e->replace_lhs(cargs_[i]->clone()); + break; + } } } - new_e->semantic(e->scope()); - - ErrorVisitor v(""); - new_e->accept(&v); -#ifdef LOGGING - std::cout << "inline_function_call result " << new_e->to_string() << "\n\n"; -#endif - if(v.num_errors()) { - throw compiler_exception("something went wrong with inlined function call ", - e->location()); + } + + if (auto rhs = e->rhs()->is_identifier()) { + for (unsigned i = 0; i < fargs_.size(); i++) { + if (fargs_[i] == rhs->spelling()) { + e->replace_rhs(cargs_[i]->clone()); + break; + } } + } + else { + e->rhs()->accept(this); + } +} + +void FunctionInliner::visit(IfExpression* e) { + bool if_ret; + bool save_ret = return_set_; + + return_set_ = false; - return new_e; + e->condition()->accept(this); + e->true_branch()->accept(this); + + if_ret = return_set_; + return_set_ = false; + + if (e->false_branch()) { + e->false_branch()->accept(this); } - return {}; + if_ret &= return_set_; + + return_set_ = save_ret? save_ret: if_ret; +} + +void FunctionInliner::visit(CallExpression* e) { + for (auto& a: e->is_function_call()->args()) { + for (unsigned i = 0; i < fargs_.size(); i++) { + if (auto id = a->is_identifier()) { + if (fargs_[i] == id->spelling()) { + a = cargs_[i]->clone(); + } + } else { + a->accept(this); + } + } + } } /////////////////////////////////////////////////////////////////////////////// diff --git a/modcc/functioninliner.hpp b/modcc/functioninliner.hpp index f100ddbc12914d0536025fbd643637cf065c4af4..c302e36f2be73b71f2e9b1fe29b8d27bcb78c54a 100644 --- a/modcc/functioninliner.hpp +++ b/modcc/functioninliner.hpp @@ -5,7 +5,53 @@ #include "scope.hpp" #include "visitor.hpp" -expression_ptr inline_function_call(Expression* e); +// Takes an assignment to a function call, returns an inlined +// version without modifying the original expression's contents +expression_ptr inline_function_call(const expression_ptr& e); + +class FunctionInliner : public Visitor { + +public: + + FunctionInliner(std::string func_name, + Expression* lhs, + const std::vector<expression_ptr>& fargs, + const std::vector<expression_ptr>& cargs, + const scope_ptr& scope) : + func_name_(func_name), lhs_(lhs->clone()), scope_(scope) { + for (auto& f: fargs) { + fargs_.push_back(f->is_argument()->spelling()); + } + for (auto& c: cargs) { + cargs_.push_back(c->clone()); + } + } + + void visit(Expression* e) override; + void visit(UnaryExpression* e) override; + void visit(BinaryExpression* e) override; + void visit(BlockExpression *e) override; + void visit(AssignmentExpression* e) override; + void visit(IfExpression* e) override; + void visit(LocalDeclaration* e) override; + void visit(CallExpression* e) override; + void visit(NumberExpression* e) override {}; + + bool return_val_set() {return return_set_;}; + + ~FunctionInliner() {} + +private: + std::string func_name_; + expression_ptr lhs_; + std::vector<std::string> fargs_; + std::vector<expression_ptr> cargs_; + scope_ptr scope_; + bool return_set_ = false; + + void replace_args(Expression* e); + +}; class VariableReplacer : public Visitor { diff --git a/modcc/module.cpp b/modcc/module.cpp index 2df8ec16b7840c2256a99b4a105b1ecc2a5728db..b982d2059c433ef1e5960ed60c975cd99fbdb438 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -699,6 +699,9 @@ void Module::add_variables_to_symbols() { } int Module::semantic_func_proc() { + bool keep_inlining = true; + int errors = 0; + //////////////////////////////////////////////////////////////////////////// // now iterate over the functions and procedures and perform semantic // analysis on each. This includes @@ -706,107 +709,126 @@ int Module::semantic_func_proc() { // - generate local variable table for each function/procedure // - inlining function calls //////////////////////////////////////////////////////////////////////////// -#ifdef LOGGING - std::cout << white("===================================\n"); - std::cout << cyan(" Function Inlining\n"); - std::cout << white("===================================\n"); -#endif - int errors = 0; - for(auto& e : symbols_) { - auto& s = e.second; - if( s->kind() == symbolKind::function - || s->kind() == symbolKind::procedure) - { -#ifdef LOGGING - std::cout << "\nfunction inlining for " << s->location() << "\n" - << s->to_string() << "\n" - << green("\n-call site lowering-\n\n"); -#endif - // first perform semantic analysis - s->semantic(symbols_); - - // then use an error visitor to print out all the semantic errors - ErrorVisitor v(source_name()); - s->accept(&v); - errors += v.num_errors(); - - // inline function calls - // this requires that the symbol table has already been built - if(v.num_errors()==0) { - auto &b = s->kind()==symbolKind::function ? - s->is_function()->body()->statements() : - s->is_procedure()->body()->statements(); - - // lower function call sites so that all function calls are of - // the form : variable = call(<args>) - // e.g. - // a = 2 + foo(2+x, y, 1) - // becomes - // ll0_ = foo(2+x, y, 1) - // a = 2 + ll0_ - for(auto e=b.begin(); e!=b.end(); ++e) { - b.splice(e, lower_function_calls((*e).get())); - } -#ifdef LOGGING - std::cout << "body after call site lowering\n"; - for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; - std::cout << green("\n-argument lowering-\n\n"); -#endif - - // lower function arguments that are not identifiers or literals - // e.g. - // ll0_ = foo(2+x, y, 1) - // a = 2 + ll0_ - // becomes - // ll1_ = 2+x - // ll0_ = foo(ll1_, y, 1) - // a = 2 + ll0_ - for(auto e=b.begin(); e!=b.end(); ++e) { - if(auto be = (*e)->is_binary()) { - // only apply to assignment expressions where rhs is a - // function call because the function call lowering step - // above ensures that all function calls are of this form - if(auto rhs = be->rhs()->is_function_call()) { - b.splice(e, lower_function_arguments(rhs->args())); + while (keep_inlining) { + #ifdef LOGGING + std::cout << white("===================================\n"); + std::cout << cyan(" Function Inlining\n"); + std::cout << white("===================================\n"); + #endif + keep_inlining = false; + + for (auto& e: symbols_) { + auto& s = e.second; + if(s->kind() == symbolKind::procedure || s->kind() == symbolKind::function) { + + #ifdef LOGGING + std::cout << "\nfunction inlining for " << s->location() << "\n" + << s->to_string() << "\n" + << green("\n-call site lowering-\n\n"); + #endif + + // perform semantic analysis + s->semantic(symbols_); + + // then use an error visitor to print out all the semantic errors + ErrorVisitor v(source_name()); + s->accept(&v); + errors += v.num_errors(); + + // inline function calls + // this requires that the symbol table has already been built + if (v.num_errors() == 0) { + auto &b = s->kind() == symbolKind::function ? + s->is_function()->body()->statements() : + s->is_procedure()->body()->statements(); + + // lower function call sites so that all function calls are of + // the form : variable = call(<args>) + // e.g. + // a = 2 + foo(2+x, y, 1) + // becomes + // ll0_ = foo(2+x, y, 1) + // a = 2 + ll0_ + for (auto e = b.begin(); e != b.end(); ++e) { + b.splice(e, lower_function_calls((*e).get())); + } + #ifdef LOGGING + std::cout << "body after call site lowering\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; + std::cout << green("\n-argument lowering-\n\n"); + #endif + // lower function arguments that are not identifiers or literals + // e.g. + // ll0_ = foo(2+x, y, 1) + // a = 2 + ll0_ + // becomes + // ll1_ = 2+x + // ll0_ = foo(ll1_, y, 1) + // a = 2 + ll0_ + for (auto e = b.begin(); e != b.end(); ++e) { + if (auto be = (*e)->is_binary()) { + // only apply to assignment expressions where rhs is a + // function call because the function call lowering step + // above ensures that all function calls are of this form + if (auto rhs = be->rhs()->is_function_call()) { + b.splice(e, lower_function_arguments(rhs->args())); + } } } + + #ifdef LOGGING + std::cout << "body after argument lowering\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; + std::cout << green("\n-inlining-\n\n"); + #endif } + } + } -#ifdef LOGGING - std::cout << "body after argument lowering\n"; - for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; - std::cout << green("\n-inlining-\n\n"); -#endif - - // Do the inlining, which currently only works for functions - // that have a single statement in their body - // e.g. if the function foo in the examples above is defined as follows - // - // function foo(a, b, c) { - // foo = a*(b + c) - // } - // - // the full inlined example is - // ll1_ = 2+x - // ll0_ = ll1_*(y + 1) - // a = 2 + ll0_ - for(auto e=b.begin(); e!=b.end(); ++e) { - if(auto ass = (*e)->is_assignment()) { - if(ass->rhs()->is_function_call()) { - ass->replace_rhs(inline_function_call(ass->rhs())); + for(auto& e : symbols_) { + auto& s = e.second; + + if(s->kind() == symbolKind::procedure) + { + if(errors==0) { + auto &b = s->kind()==symbolKind::function ? + s->is_function()->body()->statements() : + s->is_procedure()->body()->statements(); + + // Do the inlining: supports multiline functions and if/else statements + // e.g. if the function foo in the examples above is defined as follows + // + // function foo(a, b, c) { + // Local t = b + c + // foo = a*t + // } + // + // the full inlined example is + // ll1_ = 2+x + // r_0_ = y+1 + // ll0_ = ll1_*r_0_ + // a = 2 + ll0_ + + for (auto &e: b) { + if (auto ass = e->is_assignment()) { + if (ass->rhs()->is_function_call()) { + e = inline_function_call(e); + keep_inlining = true; + } } } - } -#ifdef LOGGING - std::cout << "body after inlining\n"; - for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; -#endif - // Finally, run a constant simplification pass. - if (auto proc = s->is_procedure()) { - proc->body(constant_simplify(proc->body())); - s->semantic(symbols_); + + #ifdef LOGGING + std::cout << "body after inlining\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; + #endif + // Finally, run a constant simplification pass. + if (auto proc = s->is_procedure()) { + proc->body(constant_simplify(proc->body())); + s->semantic(symbols_); + } } } } diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index 75a1b76a994d3924898dfa4243feed03f9d80c4a..f8573b6a52fd10378b00b602fcb6b7424e766022 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -328,8 +328,11 @@ public: auto cond_expr = result(); e->true_branch()->accept(this); auto true_expr = result(); - e->false_branch()->accept(this); - auto false_expr = result(); + expression_ptr false_expr; + if (e->false_branch()) { + e->false_branch()->accept(this); + false_expr = result()->clone(); + } if (!is_number(cond_expr)) { result_ = make_expression<IfExpression>(loc, diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index feb9d3a64ff8deb91d7bfc33b187611d8243073d..e41123ade7627520f5d661d9a84e2dcf825259c8 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -130,11 +130,14 @@ public: true); statements_.clear(); - e->false_branch()->accept(this); - auto false_branch = make_expression<BlockExpression>( - e->false_branch()->location(), - std::move(statements_), - true); + expression_ptr false_branch; + if (e->false_branch()) { + e->false_branch()->accept(this); + false_branch = make_expression<BlockExpression>( + e->false_branch()->location(), + std::move(statements_), + true); + } statements_ = std::move(outer); statements_.push_back(make_expression<IfExpression>( diff --git a/test/unit-modcc/mod_files/test6.mod b/test/unit-modcc/mod_files/test6.mod new file mode 100644 index 0000000000000000000000000000000000000000..70b8431997f8fc9c8c3d9deb6a5d24ec022e1c50 --- /dev/null +++ b/test/unit-modcc/mod_files/test6.mod @@ -0,0 +1,65 @@ +NEURON { + SUFFIX test_inilining +} + +UNITS { + (mV) = (millivolt) + (S) = (siemens) +} + +PARAMETER { + alpha = 2 + beta = 4.5 + gamma = -15 + delta = -0.2 +} + +STATE { + s0 + s1 + s2 +} + +BREAKPOINT { + s1 = foo(alpha, beta) + rates(delta) + s0 = s1 * s2 +} + +FUNCTION foo(x, y) { + LOCAL temp + if (x == 3) { + foo = 2 * y + } else if (x == 4) { + foo = y + } else { + temp = exp(y) + foo = temp * x + } +} + +FUNCTION dip(q, p) { + LOCAL temp + temp = log(p) + dip = 2*q*temp +} + +FUNCTION jab(x, y) { + jab = x*dip(21, x/y) +} + +FUNCTION bar(x, y) { + LOCAL p + p = y/3 + bar = foo(x, x+2) * jab(p, y) +} + +PROCEDURE rates(x) +{ + LOCAL t0, t1, t2 + + t0 = bar(s1, s2) + t1 = exprelr(t0) + t2 = foo(t1 + 2, 5) + s2 = t2 + 4 +} \ No newline at end of file diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index 530f266e06fb7805670dc3d561825cdfa7096654..54a62db574eb076d0164207c0b206a6cc5166801 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -4,6 +4,7 @@ #include <sstream> #include "common.hpp" +#include "io/bulkio.hpp" #include "printer/cexpr_emit.hpp" #include "printer/cprinter.hpp" @@ -195,4 +196,63 @@ TEST(CPrinter, proc_body_const) { EXPECT_EQ(strip(tc.expected), strip(text)); } +} + +TEST(CPrinter, proc_body_inlined) { + const char* expected = + " r_0_ = s2[i_]/ 3;\n" + " r_3_ = s1[i_]+ 2;\n" + " if (s1[i_]== 3) {\n" + " r_2_ = 2*r_3_;\n" + " }\n" + " else if (s1[i_]== 4) {\n" + " r_2_ = r_3_;\n" + " }\n" + " else {\n" + " r_5_ = exp(r_3_);\n" + " r_2_ = r_5_*s1[i_];\n" + " }\n" + "\n" + " r_7_ = r_0_/s2[i_];\n" + " r_8_ = log(r_7_);\n" + " r_6_ = 42*r_8_;\n" + " r_1_ = r_0_*r_6_;\n" + " t0 = r_2_*r_1_;\n" + " t1 = exprelr(t0);\n" + " ll0_ = t1+ 2;\n" + " if (ll0_== 3) {\n" + " t2 = 10;\n" + " }\n" + " else if (ll0_== 4) {\n" + " t2 = 5;\n" + " }\n" + " else {\n" + " r_4_ = 148.4131591025766;\n" + " t2 = r_4_*ll0_;\n" + " }\n" + " s2[i_] = t2+ 4;"; + + Module m(io::read_all(DATADIR "/mod_files/test6.mod"), "test6.mod"); + Parser p(m, false); + p.parse(); + m.semantic(); + + auto& proc_rates = m.symbols().at("rates"); + + ASSERT_TRUE(proc_rates->is_symbol()); + + std::stringstream out; + auto v = std::make_unique<CPrinter>(out); + proc_rates->is_procedure()->body()->accept(v.get()); + std::string text = out.str(); + + verbose_print(proc_rates->is_procedure()->body()->to_string()); + verbose_print(" :--: ", text); + + // Remove the first statement that declares the locals + // Their print order is not fixed + auto proc_with_locals = strip(text); + proc_with_locals.erase(0, proc_with_locals.find(";") + 1); + + EXPECT_EQ(strip(expected), proc_with_locals); } \ No newline at end of file