diff --git a/modcc/module.cpp b/modcc/module.cpp index 2b1590d667c3872aa4dfd142ecd0fea816baba20..31a7b1261c6bbba6c6b888e0bda3e5701aa64bd1 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -209,119 +209,21 @@ bool Module::semantic() { return true; }; + // Add built in function that approximate exp use pade polynomials + functions_.push_back( + Parser{"FUNCTION exp_pade_11(z) { exp_pade_11=(1+0.5*z)/(1-0.5*z) }"}.parse_function()); + functions_.push_back( + Parser{ + "FUNCTION exp_pade_22(z)" + "{ exp_pade_22=(1+0.5*z+0.08333333333333333*z*z)/(1-0.5*z+0.08333333333333333*z*z) }" + }.parse_function()); + // move functions and procedures to the symbol table if(!move_symbols(functions_)) return false; if(!move_symbols(procedures_)) return false; - //////////////////////////////////////////////////////////////////////////// - // now iterate over the functions and procedures and perform semantic - // analysis on each. This includes - // - variable, function and procedure lookup - // - generate local variable table for each function/procedure - // - inlining function calls - //////////////////////////////////////////////////////////////////////////// -#ifdef LOGGING - std::cout << white("===================================\n"); - std::cout << cyan(" Function Inlining\n"); - std::cout << white("===================================\n"); -#endif - int errors = 0; - for(auto& e : symbols_) { - auto& s = e.second; - - if( s->kind() == symbolKind::function - || s->kind() == symbolKind::procedure) - { -#ifdef LOGGING - std::cout << "\nfunction inlining for " << s->location() << "\n" - << s->to_string() << "\n" - << green("\n-call site lowering-\n\n"); -#endif - // first perform semantic analysis - s->semantic(symbols_); - - // then use an error visitor to print out all the semantic errors - ErrorVisitor v(file_name()); - s->accept(&v); - errors += v.num_errors(); - - // inline function calls - // this requires that the symbol table has already been built - if(v.num_errors()==0) { - auto &b = s->kind()==symbolKind::function ? - s->is_function()->body()->statements() : - s->is_procedure()->body()->statements(); - - // lower function call sites so that all function calls are of - // the form : variable = call(<args>) - // e.g. - // a = 2 + foo(2+x, y, 1) - // becomes - // ll0_ = foo(2+x, y, 1) - // a = 2 + ll0_ - for(auto e=b.begin(); e!=b.end(); ++e) { - b.splice(e, lower_function_calls((*e).get())); - } -#ifdef LOGGING - std::cout << "body after call site lowering\n"; - for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; - std::cout << green("\n-argument lowering-\n\n"); -#endif - - // lower function arguments that are not identifiers or literals - // e.g. - // ll0_ = foo(2+x, y, 1) - // a = 2 + ll0_ - // becomes - // ll1_ = 2+x - // ll0_ = foo(ll1_, y, 1) - // a = 2 + ll0_ - for(auto e=b.begin(); e!=b.end(); ++e) { - if(auto be = (*e)->is_binary()) { - // only apply to assignment expressions where rhs is a - // function call because the function call lowering step - // above ensures that all function calls are of this form - if(auto rhs = be->rhs()->is_function_call()) { - b.splice(e, lower_function_arguments(rhs->args())); - } - } - } - -#ifdef LOGGING - std::cout << "body after argument lowering\n"; - for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; - std::cout << green("\n-inlining-\n\n"); -#endif - - // Do the inlining, which currently only works for functions - // that have a single statement in their body - // e.g. if the function foo in the examples above is defined as follows - // - // function foo(a, b, c) { - // foo = a*(b + c) - // } - // - // the full inlined example is - // ll1_ = 2+x - // ll0_ = ll1_*(y + 1) - // a = 2 + ll0_ - for(auto e=b.begin(); e!=b.end(); ++e) { - if(auto ass = (*e)->is_assignment()) { - if(ass->rhs()->is_function_call()) { - ass->replace_rhs(inline_function_call(ass->rhs())); - } - } - } - -#ifdef LOGGING - std::cout << "body after inlining\n"; - for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; -#endif - } - } - } - - if(errors) { + // perform semantic analysis and inlining on function and procedure bodies + if(auto errors = semantic_func_proc()) { error("There were "+std::to_string(errors)+" errors in the semantic analysis"); return false; } @@ -506,6 +408,10 @@ bool Module::semantic() { return false; } + // Perform semantic analysis and inlining on function and procedure bodies + // in order to inline calls inside the newly crated API methods. + semantic_func_proc(); + return !has_error(); } @@ -754,3 +660,114 @@ bool Module::optimize() { return true; } + +int Module::semantic_func_proc() { + //////////////////////////////////////////////////////////////////////////// + // now iterate over the functions and procedures and perform semantic + // analysis on each. This includes + // - variable, function and procedure lookup + // - generate local variable table for each function/procedure + // - inlining function calls + //////////////////////////////////////////////////////////////////////////// +#ifdef LOGGING + std::cout << white("===================================\n"); + std::cout << cyan(" Function Inlining\n"); + std::cout << white("===================================\n"); +#endif + int errors = 0; + for(auto& e : symbols_) { + auto& s = e.second; + + if( s->kind() == symbolKind::function + || s->kind() == symbolKind::procedure) + { +#ifdef LOGGING + std::cout << "\nfunction inlining for " << s->location() << "\n" + << s->to_string() << "\n" + << green("\n-call site lowering-\n\n"); +#endif + // first perform semantic analysis + s->semantic(symbols_); + + // then use an error visitor to print out all the semantic errors + ErrorVisitor v(file_name()); + s->accept(&v); + errors += v.num_errors(); + + // inline function calls + // this requires that the symbol table has already been built + if(v.num_errors()==0) { + auto &b = s->kind()==symbolKind::function ? + s->is_function()->body()->statements() : + s->is_procedure()->body()->statements(); + + // lower function call sites so that all function calls are of + // the form : variable = call(<args>) + // e.g. + // a = 2 + foo(2+x, y, 1) + // becomes + // ll0_ = foo(2+x, y, 1) + // a = 2 + ll0_ + for(auto e=b.begin(); e!=b.end(); ++e) { + b.splice(e, lower_function_calls((*e).get())); + } +#ifdef LOGGING + std::cout << "body after call site lowering\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; + std::cout << green("\n-argument lowering-\n\n"); +#endif + + // lower function arguments that are not identifiers or literals + // e.g. + // ll0_ = foo(2+x, y, 1) + // a = 2 + ll0_ + // becomes + // ll1_ = 2+x + // ll0_ = foo(ll1_, y, 1) + // a = 2 + ll0_ + for(auto e=b.begin(); e!=b.end(); ++e) { + if(auto be = (*e)->is_binary()) { + // only apply to assignment expressions where rhs is a + // function call because the function call lowering step + // above ensures that all function calls are of this form + if(auto rhs = be->rhs()->is_function_call()) { + b.splice(e, lower_function_arguments(rhs->args())); + } + } + } + +#ifdef LOGGING + std::cout << "body after argument lowering\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; + std::cout << green("\n-inlining-\n\n"); +#endif + + // Do the inlining, which currently only works for functions + // that have a single statement in their body + // e.g. if the function foo in the examples above is defined as follows + // + // function foo(a, b, c) { + // foo = a*(b + c) + // } + // + // the full inlined example is + // ll1_ = 2+x + // ll0_ = ll1_*(y + 1) + // a = 2 + ll0_ + for(auto e=b.begin(); e!=b.end(); ++e) { + if(auto ass = (*e)->is_assignment()) { + if(ass->rhs()->is_function_call()) { + ass->replace_rhs(inline_function_call(ass->rhs())); + } + } + } + +#ifdef LOGGING + std::cout << "body after inlining\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; +#endif + } + } + } + return errors; +} diff --git a/modcc/module.hpp b/modcc/module.hpp index 41abdeb5256ed77b499ea6e0ed1aae038f2b38fd..18c362ac0b31cb1882e26ad8ac56f35d7fdc308e 100644 --- a/modcc/module.hpp +++ b/modcc/module.hpp @@ -109,6 +109,10 @@ private: return s == symbols_.end() ? false : s->second->kind() == kind; } + // Perform semantic analysis on functions and procedures. + // Returns the number of errors that were encountered. + int semantic_func_proc(); + // blocks NeuronBlock neuron_block_; StateBlock state_block_; diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp index e4a4d49051d942ba941058a7b9dfe9b41904b2a2..25e4df5b7f45ecb99ab8ea51db0881b555f56113 100644 --- a/modcc/solvers.cpp +++ b/modcc/solvers.cpp @@ -70,7 +70,7 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) { statements_.push_back(std::move(local_a_term.assignment)); auto a_ = local_a_term.id->is_identifier()->spelling(); - std::string s_update = pprintf("% = %*exp(%*dt)", s, s, a_); + std::string s_update = pprintf("% = %*exp_pade_11(%*dt)", s, s, a_); statements_.push_back(Parser{s_update}.parse_line_expression()); return; } @@ -111,7 +111,7 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) { statements_.push_back(std::move(local_b_term.local_decl)); statements_.push_back(std::move(local_b_term.assignment)); - std::string s_update = pprintf("% = %+(%-%)*exp(-dt/%)", s, b_, s, b_, a_); + std::string s_update = pprintf("% = %+(%-%)*exp_pade_11(-dt/%)", s, b_, s, b_, a_); statements_.push_back(Parser{s_update}.parse_line_expression()); return; } @@ -130,7 +130,7 @@ not_gating: statements_.push_back(std::move(local_ba_term.local_decl)); statements_.push_back(std::move(local_ba_term.assignment)); - std::string s_update = pprintf("% = -%+(%+%)*exp(%*dt)", s, ba_, s, ba_, a_); + std::string s_update = pprintf("% = -%+(%+%)*exp_pade_11(%*dt)", s, ba_, s, ba_, a_); statements_.push_back(Parser{s_update}.parse_line_expression()); return; }