Skip to content
Snippets Groups Projects
  • Sam Yates's avatar
    Add KINETIC block rewriter (issue #63) (#95) · 4e229b01
    Sam Yates authored
    Adds a new KineticRewriter visitor that transforms (after semantic analysis) a parsed KINETIC procedure into an equivalent DERIVATIVE procedure. The visitor takes a ProcedureExpression and composes the equivalent procedure, available via the as_procedure() method on the visitor object.
    
    Move common functionality for 'local' variable insertion during transformation phase to new files astmanip.?pp.
    Add Expression method for directly setting scope.
    Use scope_ptr type alias widely.
    Implement correct clone() behaviour for DerivativeExpression
    Implement KineticRewriter transforming visitor class.
    Add equivalence test for KineticRewriter: the test incorporates a simple ad-hoc algebraic expression simplifier.
    Add unit test to Parser.parse_binop to exercise bug #94
    4e229b01
functionexpander.cpp 4.09 KiB
#include <iostream>

#include "astmanip.hpp"
#include "error.hpp"
#include "functionexpander.hpp"
#include "modccutil.hpp"

// Hoist expression `e` into a uniquely named local variable.
//
// make_unique_local_assign() produces three pieces: a local declaration,
// an assignment of `e` to the new local, and an identifier naming it.
// The declaration is prepended to `stmts` and the assignment appended, and
// the identifier is handed back so the caller can substitute it for `e`.
// NOTE(review): whether `e` is cloned or referenced by the assignment is
// decided inside make_unique_local_assign (not visible here) — confirm
// before relying on `e` outliving this call.
expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* e) {
    auto pieces = make_unique_local_assign(e->scope(), e);

    // declaration goes first so it precedes every statement already queued
    stmts.push_front(std::move(pieces.local_decl));
    // assignment goes last, after anything it may depend on
    stmts.push_back(std::move(pieces.assignment));

    // `pieces` is a local aggregate: its members are not implicitly moved
    // on return, so the move has to be spelled out
    return std::move(pieces.id);
}

///////////////////////////////////////////////////////////////////////////////
//  function call site lowering
///////////////////////////////////////////////////////////////////////////////

// Lower every function call on the right-hand side of an assignment into
// its own local temporary.
//
// Only assignments are inspected. For any other kind of expression the
// visitor is never run, so the result is whatever move_calls() holds for a
// freshly constructed lowerer (presumably an empty list).
//
// The returned statements assign call results to identifiers, e.g.
//      LOCAL ll1_
//      ll1_ = mInf(v)
call_list_type lower_function_calls(Expression* e)
{
    FunctionCallLowerer lowerer(e->scope());

    auto assignment = e->is_assignment();
    if(assignment) {
#ifdef LOGGING
        std::cout << "lower_function_calls inspect expression " << e->to_string() << "\n";
#endif
        // walk the rhs, replacing each function call with an identifier
        assignment->rhs()->accept(&lowerer);
    }

    return lowerer.move_calls();
}

// Fallback overload: any expression kind without a dedicated visit()
// overload cannot be lowered, so report it as a compiler error located at
// the offending expression.
void FunctionCallLowerer::visit(Expression *e) {
    const auto message =
        "function lowering for expressions of the type " + e->to_string()
        + " has not been defined";
    throw compiler_exception(message, e->location());
}

// Lower function calls that appear directly as arguments of call `e`.
//
// Arguments that are not themselves calls are simply visited recursively.
// For an argument that is a call:
//   - its own arguments are visited first, so nested calls are lowered
//     before it;
//   - expand_call() then receives a callback that overwrites the argument
//     slot with the expression it is given (presumably the identifier that
//     holds the call's result — see expand_call);
//   - the replacement argument is re-checked with semantic() in scope_.
void FunctionCallLowerer::visit(CallExpression *e) {
    for(auto& argument : e->args()) {
        auto inner_call = argument->is_function_call();
        if(!inner_call) {
            argument->accept(this);
            continue;
        }

        inner_call->accept(this);
#ifdef LOGGING
        std::cout << "  lowering : " << inner_call->to_string() << "\n";
#endif
        expand_call(
            inner_call,
            [&argument](expression_ptr&& replacement) {
                argument = std::move(replacement);
            }
        );
        argument->semantic(scope_);
    }
}

// Lower a function call that is the direct operand of unary expression `e`.
//
// If the operand is not a call, it is visited recursively and nothing else
// happens. Otherwise the call's arguments are lowered first, expand_call()
// swaps the operand for the expression it produces, and the whole unary
// expression is re-run through semantic() in scope_.
void FunctionCallLowerer::visit(UnaryExpression *e) {
    auto operand_call = e->expression()->is_function_call();
    if(!operand_call) {
        e->expression()->accept(this);
        return;
    }

    operand_call->accept(this);
#ifdef LOGGING
    std::cout << "  lowering : " << operand_call->to_string() << "\n";
#endif
    expand_call(operand_call, [e](expression_ptr&& replacement) {
        e->replace_expression(std::move(replacement));
    });
    e->semantic(scope_);
}

// Lower function calls that are direct operands of binary expression `e`.
//
// The left operand is handled completely before the right operand is even
// inspected. For each side:
//   - a non-call operand is visited recursively;
//   - a call operand has its arguments lowered first, is then replaced via
//     expand_call(), and the whole binary expression is re-run through
//     semantic() in scope_.
// When both sides are calls, semantic() therefore runs twice, once after
// each replacement.
void FunctionCallLowerer::visit(BinaryExpression *e) {
    // left operand
    auto lhs_call = e->lhs()->is_function_call();
    if(!lhs_call) {
        e->lhs()->accept(this);
    }
    else {
        lhs_call->accept(this);
#ifdef LOGGING
        std::cout << "  lowering : " << lhs_call->to_string() << "\n";
#endif
        expand_call(lhs_call, [e](expression_ptr&& replacement) {
            e->replace_lhs(std::move(replacement));
        });
        e->semantic(scope_);
    }

    // right operand: inspected only after the left side has been rewritten
    auto rhs_call = e->rhs()->is_function_call();
    if(!rhs_call) {
        e->rhs()->accept(this);
    }
    else {
        rhs_call->accept(this);
#ifdef LOGGING
        std::cout << "  lowering : " << rhs_call->to_string() << "\n";
#endif
        expand_call(rhs_call, [e](expression_ptr&& replacement) {
            e->replace_rhs(std::move(replacement));
        });
        e->semantic(scope_);
    }
}

///////////////////////////////////////////////////////////////////////////////
//  function argument lowering
///////////////////////////////////////////////////////////////////////////////

// Rewrite a call's argument list so that every argument is a number
// literal or an identifier.
//
// Each argument that is neither is hoisted into a fresh local via
// insert_unique_local_assignment(), and the argument slot is replaced by
// that local's identifier. For example, an argument `a+b` becomes
//      LOCAL <tmp>
//      <tmp> = a+b
// with `<tmp>` passed in its place.
//
// Returns the generated statements: declarations at the front, assignments
// at the back in argument order. `args` is modified in place.
call_list_type
lower_function_arguments(std::vector<expression_ptr>& args)
{
    call_list_type hoisted;

    for(auto& arg : args) {
#ifdef LOGGING
        std::cout << "inspecting argument @ " << arg->location() << " : " << arg->to_string() << std::endl;
#endif
        // literals and identifiers are already in lowered form
        const bool already_simple = arg->is_number() || arg->is_identifier();
        if(already_simple) {
            continue;
        }

        auto local_id = insert_unique_local_assignment(hoisted, arg.get());
#ifdef LOGGING
        std::cout << "  lowering to " << hoisted.back()->to_string() << "\n";
#endif
        // Put the identifier into the argument slot. The displaced expression
        // ends up in local_id and is released at the end of this iteration.
        std::swap(arg, local_id);
    }
#ifdef LOGGING
    std::cout << "\n";
#endif

    return hoisted;
}