Skip to content
Snippets Groups Projects
Commit 5b254146 authored by Ben Cumming's avatar Ben Cumming Committed by Sam Yates
Browse files

Padé approximation of exp in 'cnexp' integration (#268)

Fixes #265.

In the `modcc`-generated mechanism code, the `cnexp` solver method uses an expensive call to `exp` to integrate dependent variables over one time step. This commit replaces the exponential with a second-order Padé approximation.

  * Modify `modcc` to insert `exp_pade_11` and `exp_pade_22` functions into every module, which define Padé approximations of second and fourth order respectively (m=n=1 and m=n=2).
  * Have `cnexp` solver use `exp_pade_11` instead of the built in `exp` unary operator.

The validation tests pass for both the 2nd and 4th order approximations; the second order approximation will suffices.
parent 07dd8f35
No related branches found
No related tags found
No related merge requests found
......@@ -209,119 +209,21 @@ bool Module::semantic() {
return true;
};
// Add built in function that approximate exp use pade polynomials
functions_.push_back(
Parser{"FUNCTION exp_pade_11(z) { exp_pade_11=(1+0.5*z)/(1-0.5*z) }"}.parse_function());
functions_.push_back(
Parser{
"FUNCTION exp_pade_22(z)"
"{ exp_pade_22=(1+0.5*z+0.08333333333333333*z*z)/(1-0.5*z+0.08333333333333333*z*z) }"
}.parse_function());
// move functions and procedures to the symbol table
if(!move_symbols(functions_)) return false;
if(!move_symbols(procedures_)) return false;
////////////////////////////////////////////////////////////////////////////
// now iterate over the functions and procedures and perform semantic
// analysis on each. This includes
// - variable, function and procedure lookup
// - generate local variable table for each function/procedure
// - inlining function calls
////////////////////////////////////////////////////////////////////////////
#ifdef LOGGING
std::cout << white("===================================\n");
std::cout << cyan(" Function Inlining\n");
std::cout << white("===================================\n");
#endif
int errors = 0;
for(auto& e : symbols_) {
auto& s = e.second;
if( s->kind() == symbolKind::function
|| s->kind() == symbolKind::procedure)
{
#ifdef LOGGING
std::cout << "\nfunction inlining for " << s->location() << "\n"
<< s->to_string() << "\n"
<< green("\n-call site lowering-\n\n");
#endif
// first perform semantic analysis
s->semantic(symbols_);
// then use an error visitor to print out all the semantic errors
ErrorVisitor v(file_name());
s->accept(&v);
errors += v.num_errors();
// inline function calls
// this requires that the symbol table has already been built
if(v.num_errors()==0) {
auto &b = s->kind()==symbolKind::function ?
s->is_function()->body()->statements() :
s->is_procedure()->body()->statements();
// lower function call sites so that all function calls are of
// the form : variable = call(<args>)
// e.g.
// a = 2 + foo(2+x, y, 1)
// becomes
// ll0_ = foo(2+x, y, 1)
// a = 2 + ll0_
for(auto e=b.begin(); e!=b.end(); ++e) {
b.splice(e, lower_function_calls((*e).get()));
}
#ifdef LOGGING
std::cout << "body after call site lowering\n";
for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n";
std::cout << green("\n-argument lowering-\n\n");
#endif
// lower function arguments that are not identifiers or literals
// e.g.
// ll0_ = foo(2+x, y, 1)
// a = 2 + ll0_
// becomes
// ll1_ = 2+x
// ll0_ = foo(ll1_, y, 1)
// a = 2 + ll0_
for(auto e=b.begin(); e!=b.end(); ++e) {
if(auto be = (*e)->is_binary()) {
// only apply to assignment expressions where rhs is a
// function call because the function call lowering step
// above ensures that all function calls are of this form
if(auto rhs = be->rhs()->is_function_call()) {
b.splice(e, lower_function_arguments(rhs->args()));
}
}
}
#ifdef LOGGING
std::cout << "body after argument lowering\n";
for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n";
std::cout << green("\n-inlining-\n\n");
#endif
// Do the inlining, which currently only works for functions
// that have a single statement in their body
// e.g. if the function foo in the examples above is defined as follows
//
// function foo(a, b, c) {
// foo = a*(b + c)
// }
//
// the full inlined example is
// ll1_ = 2+x
// ll0_ = ll1_*(y + 1)
// a = 2 + ll0_
for(auto e=b.begin(); e!=b.end(); ++e) {
if(auto ass = (*e)->is_assignment()) {
if(ass->rhs()->is_function_call()) {
ass->replace_rhs(inline_function_call(ass->rhs()));
}
}
}
#ifdef LOGGING
std::cout << "body after inlining\n";
for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n";
#endif
}
}
}
if(errors) {
// perform semantic analysis and inlining on function and procedure bodies
if(auto errors = semantic_func_proc()) {
error("There were "+std::to_string(errors)+" errors in the semantic analysis");
return false;
}
......@@ -506,6 +408,10 @@ bool Module::semantic() {
return false;
}
// Perform semantic analysis and inlining on function and procedure bodies
// in order to inline calls inside the newly crated API methods.
semantic_func_proc();
return !has_error();
}
......@@ -754,3 +660,114 @@ bool Module::optimize() {
return true;
}
int Module::semantic_func_proc() {
////////////////////////////////////////////////////////////////////////////
// now iterate over the functions and procedures and perform semantic
// analysis on each. This includes
// - variable, function and procedure lookup
// - generate local variable table for each function/procedure
// - inlining function calls
////////////////////////////////////////////////////////////////////////////
#ifdef LOGGING
std::cout << white("===================================\n");
std::cout << cyan(" Function Inlining\n");
std::cout << white("===================================\n");
#endif
int errors = 0;
for(auto& e : symbols_) {
auto& s = e.second;
if( s->kind() == symbolKind::function
|| s->kind() == symbolKind::procedure)
{
#ifdef LOGGING
std::cout << "\nfunction inlining for " << s->location() << "\n"
<< s->to_string() << "\n"
<< green("\n-call site lowering-\n\n");
#endif
// first perform semantic analysis
s->semantic(symbols_);
// then use an error visitor to print out all the semantic errors
ErrorVisitor v(file_name());
s->accept(&v);
errors += v.num_errors();
// inline function calls
// this requires that the symbol table has already been built
if(v.num_errors()==0) {
auto &b = s->kind()==symbolKind::function ?
s->is_function()->body()->statements() :
s->is_procedure()->body()->statements();
// lower function call sites so that all function calls are of
// the form : variable = call(<args>)
// e.g.
// a = 2 + foo(2+x, y, 1)
// becomes
// ll0_ = foo(2+x, y, 1)
// a = 2 + ll0_
for(auto e=b.begin(); e!=b.end(); ++e) {
b.splice(e, lower_function_calls((*e).get()));
}
#ifdef LOGGING
std::cout << "body after call site lowering\n";
for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n";
std::cout << green("\n-argument lowering-\n\n");
#endif
// lower function arguments that are not identifiers or literals
// e.g.
// ll0_ = foo(2+x, y, 1)
// a = 2 + ll0_
// becomes
// ll1_ = 2+x
// ll0_ = foo(ll1_, y, 1)
// a = 2 + ll0_
for(auto e=b.begin(); e!=b.end(); ++e) {
if(auto be = (*e)->is_binary()) {
// only apply to assignment expressions where rhs is a
// function call because the function call lowering step
// above ensures that all function calls are of this form
if(auto rhs = be->rhs()->is_function_call()) {
b.splice(e, lower_function_arguments(rhs->args()));
}
}
}
#ifdef LOGGING
std::cout << "body after argument lowering\n";
for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n";
std::cout << green("\n-inlining-\n\n");
#endif
// Do the inlining, which currently only works for functions
// that have a single statement in their body
// e.g. if the function foo in the examples above is defined as follows
//
// function foo(a, b, c) {
// foo = a*(b + c)
// }
//
// the full inlined example is
// ll1_ = 2+x
// ll0_ = ll1_*(y + 1)
// a = 2 + ll0_
for(auto e=b.begin(); e!=b.end(); ++e) {
if(auto ass = (*e)->is_assignment()) {
if(ass->rhs()->is_function_call()) {
ass->replace_rhs(inline_function_call(ass->rhs()));
}
}
}
#ifdef LOGGING
std::cout << "body after inlining\n";
for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n";
#endif
}
}
}
return errors;
}
......@@ -109,6 +109,10 @@ private:
return s == symbols_.end() ? false : s->second->kind() == kind;
}
// Perform semantic analysis on functions and procedures.
// Returns the number of errors that were encountered.
int semantic_func_proc();
// blocks
NeuronBlock neuron_block_;
StateBlock state_block_;
......
......@@ -70,7 +70,7 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) {
statements_.push_back(std::move(local_a_term.assignment));
auto a_ = local_a_term.id->is_identifier()->spelling();
std::string s_update = pprintf("% = %*exp(%*dt)", s, s, a_);
std::string s_update = pprintf("% = %*exp_pade_11(%*dt)", s, s, a_);
statements_.push_back(Parser{s_update}.parse_line_expression());
return;
}
......@@ -111,7 +111,7 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) {
statements_.push_back(std::move(local_b_term.local_decl));
statements_.push_back(std::move(local_b_term.assignment));
std::string s_update = pprintf("% = %+(%-%)*exp(-dt/%)", s, b_, s, b_, a_);
std::string s_update = pprintf("% = %+(%-%)*exp_pade_11(-dt/%)", s, b_, s, b_, a_);
statements_.push_back(Parser{s_update}.parse_line_expression());
return;
}
......@@ -130,7 +130,7 @@ not_gating:
statements_.push_back(std::move(local_ba_term.local_decl));
statements_.push_back(std::move(local_ba_term.assignment));
std::string s_update = pprintf("% = -%+(%+%)*exp(%*dt)", s, ba_, s, ba_, a_);
std::string s_update = pprintf("% = -%+(%+%)*exp_pade_11(%*dt)", s, ba_, s, ba_, a_);
statements_.push_back(Parser{s_update}.parse_line_expression());
return;
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment