diff --git a/.gitmodules b/.gitmodules index b7b2a25e92a25ca1d8d777e4a91ca653067b09f1..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "modparser"] - path = external/modparser - url = git@github.com:eth-cscs/modparser.git diff --git a/CMakeLists.txt b/CMakeLists.txt index e6ade6037d32ab09542b74efd7e9e3971bc264d3..9361984c69ee3ef97eed17800c5b427b94c6304f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ cmake_minimum_required(VERSION 2.8) project(cell_algorithms) enable_language(CXX) -# save incoming CXX flags for forwarding to modparser external project +# save incoming CXX flags for forwarding to modcc external project set(SAVED_CXX_FLAGS "${CMAKE_CXX_FLAGS}") # compilation flags @@ -79,28 +79,15 @@ set(USE_OPTIMIZED_KERNELS OFF CACHE BOOL "generate optimized code that vectorize # Only build modcc if it has not already been installed. # This is useful if cross compiling for KNL, when it is not desirable to compile # modcc with the same flags that are used for the KNL target. -find_program(MODCC_BIN modcc) -set(modcc "${MODCC_BIN}") set(use_external_modcc OFF BOOL) - -# the modcc executable was not found, so build our own copy +find_program(MODCC_BIN modcc) if(MODCC_BIN STREQUAL "MODCC_BIN-NOTFOUND") - include(ExternalProject) - externalproject_add(modparser - PREFIX ${CMAKE_BINARY_DIR}/external - CMAKE_ARGS "-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/external" - "-DCMAKE_CXX_FLAGS=${SAVED_CXX_FLAGS}" - "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" - BINARY_DIR "${CMAKE_BINARY_DIR}/external/modparser" - STAMP_DIR "${CMAKE_BINARY_DIR}/external/" - TMP_DIR "${CMAKE_BINARY_DIR}/external/tmp" - SOURCE_DIR "${CMAKE_SOURCE_DIR}/external/modparser" - ) - # Set up environment to use the version of modcc that is compiled - # as the ExternalProject above. 
- set(use_external_modcc ON) - set(modcc "${CMAKE_BINARY_DIR}/external/bin/modcc") + set(modcc "${CMAKE_BINARY_DIR}/modcc/modcc") +else() + set(modcc "${MODCC_BIN}") + set(use_external_modcc ON BOOL) endif() +message("==== modcc: " ${modcc}) include_directories(${CMAKE_SOURCE_DIR}/tclap/include) include_directories(${CMAKE_SOURCE_DIR}/vector) @@ -112,6 +99,8 @@ if( "${WITH_TBB}" STREQUAL "ON" ) include_directories(${TBB_INCLUDE_DIRS}) endif() +# TODO : only compile modcc if it is not provided externally +add_subdirectory(modcc) add_subdirectory(mechanisms) add_subdirectory(src) add_subdirectory(tests) diff --git a/external/modparser b/external/modparser deleted file mode 160000 index a3aa02b2888a23aa9dbd2244ef3709dd01fc6c31..0000000000000000000000000000000000000000 --- a/external/modparser +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a3aa02b2888a23aa9dbd2244ef3709dd01fc6c31 diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..508202cb0c157e867449aa0142efee551c538ff7 --- /dev/null +++ b/modcc/CMakeLists.txt @@ -0,0 +1,6 @@ +# generated .a and .so go into /lib +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + +add_subdirectory(src) + diff --git a/modcc/src/CMakeLists.txt b/modcc/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa98b8617c6581eee3fe386eaaf168b964c56d96 --- /dev/null +++ b/modcc/src/CMakeLists.txt @@ -0,0 +1,27 @@ +set(MODCC_SOURCES + token.cpp + lexer.cpp + expression.cpp + parser.cpp + textbuffer.cpp + cprinter.cpp + functionexpander.cpp + functioninliner.cpp + cudaprinter.cpp + expressionclassifier.cpp + constantfolder.cpp + errorvisitor.cpp + module.cpp +) + +add_library(compiler ${MODCC_SOURCES}) + +add_executable(modcc modcc.cpp) + +target_link_libraries(modcc LINK_PUBLIC compiler) + +set_target_properties(modcc + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY 
"${CMAKE_BINARY_DIR}/modcc" +) + diff --git a/modcc/src/blocks.hpp b/modcc/src/blocks.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9fee8ab18e72465d3c7f4f97faffca0e174e84b5 --- /dev/null +++ b/modcc/src/blocks.hpp @@ -0,0 +1,167 @@ +#pragma once + +#include <string> +#include <vector> + +#include "identifier.hpp" +#include "location.hpp" +#include "token.hpp" +#include "util.hpp" + +// describes a relationship with an ion channel +struct IonDep { + ionKind kind() const { + if(name=="k") return ionKind::K; + if(name=="na") return ionKind::Na; + if(name=="ca") return ionKind::Ca; + return ionKind::none; + } + std::string name; // name of ion channel + std::vector<Token> read; // name of channels parameters to write + std::vector<Token> write; // name of channels parameters to read +}; + +enum class moduleKind { + point, + density +}; + +// information stored in a NEURON {} block in mod file. +struct NeuronBlock { + bool threadsafe = false; + std::string name; + moduleKind kind; + std::vector<IonDep> ions; + std::vector<Token> ranges; + std::vector<Token> globals; + Token nonspecific_current; + bool has_nonspecific_current() const { + return nonspecific_current.spelling.size()>0; + } +}; + +// information stored in a NEURON {} block in mod file +struct StateBlock { + std::vector<std::string> state_variables; + auto begin() -> decltype(state_variables.begin()) { + return state_variables.begin(); + } + auto end() -> decltype(state_variables.end()) { + return state_variables.end(); + } +}; + +// information stored in a NEURON {} block in mod file +typedef std::vector<Token> unit_tokens; +struct UnitsBlock { + typedef std::pair<unit_tokens, unit_tokens> units_pair; + std::vector<units_pair> unit_aliases; +}; + +struct Id { + Token token; + std::string value; // store the value as a string, not a number : empty string == no value + unit_tokens units; + + Id(Token const& t, std::string const& v, unit_tokens const& u) + : token(t), value(v), 
units(u) + {} + + Id() {} + + bool has_value() const { + return value.size()>0; + } + + std::string const& name() const { + return token.spelling; + } +}; + +// information stored in a NEURON {} block in mod file +struct ParameterBlock { + std::vector<Id> parameters; + + auto begin() -> decltype(parameters.begin()) { + return parameters.begin(); + } + auto end() -> decltype(parameters.end()) { + return parameters.end(); + } +}; + +// information stored in a NEURON {} block in mod file +struct AssignedBlock { + std::vector<Id> parameters; + + auto begin() -> decltype(parameters.begin()) { + return parameters.begin(); + } + auto end() -> decltype(parameters.end()) { + return parameters.end(); + } +}; + +//////////////////////////////////////////////// +// helpers for pretty printing block information +//////////////////////////////////////////////// +inline std::ostream& operator<< (std::ostream& os, Id const& V) { + if(V.units.size()) + os << "(" << V.token << "," << V.value << "," << V.units << ")"; + else + os << "(" << V.token << "," << V.value << ",)"; + + return os; +} + +inline std::ostream& operator<< (std::ostream& os, UnitsBlock::units_pair const& p) { + return os << "(" << p.first << ", " << p.second << ")"; +} + +inline std::ostream& operator<< (std::ostream& os, IonDep const& I) { + return os << "(" << I.name << ": read " << I.read << " write " << I.write << ")"; +} + +inline std::ostream& operator<< (std::ostream& os, moduleKind const& k) { + return os << (k==moduleKind::density ? "density" : "point process"); +} + +inline std::ostream& operator<< (std::ostream& os, NeuronBlock const& N) { + os << blue("NeuronBlock") << std::endl; + os << " kind : " << N.kind << std::endl; + os << " name : " << N.name << std::endl; + os << " threadsafe : " << (N.threadsafe ? 
"yes" : "no") << std::endl; + os << " ranges : " << N.ranges << std::endl; + os << " globals : " << N.globals << std::endl; + os << " ions : " << N.ions << std::endl; + + return os; +} + +inline std::ostream& operator<< (std::ostream& os, StateBlock const& B) { + os << blue("StateBlock") << std::endl; + return os << " variables : " << B.state_variables << std::endl; + +} + +inline std::ostream& operator<< (std::ostream& os, UnitsBlock const& U) { + os << blue("UnitsBlock") << std::endl; + os << " aliases : " << U.unit_aliases << std::endl; + + return os; +} + +inline std::ostream& operator<< (std::ostream& os, ParameterBlock const& P) { + os << blue("ParameterBlock") << std::endl; + os << " parameters : " << P.parameters << std::endl; + + return os; +} + +inline std::ostream& operator<< (std::ostream& os, AssignedBlock const& A) { + os << blue("AssignedBlock") << std::endl; + os << " parameters : " << A.parameters << std::endl; + + return os; +} + diff --git a/modcc/src/constantfolder.cpp b/modcc/src/constantfolder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3dfb609f7c06f16972633cc3c5f3d5edbfe7b2a2 --- /dev/null +++ b/modcc/src/constantfolder.cpp @@ -0,0 +1,176 @@ +#include <iostream> +#include <cmath> + +#include "constantfolder.hpp" + +/* + perform a walk of the AST + - pre-order : mark node as not a number + - in-order : convert all children that marked themselves as numbers into NumberExpressions + - post-order: mark the current node as a constant if all of its children + were converted to NumberExpressions + + all calculations and intermediate results use 80 bit floating point precision (long double) +*/ + +// default is to do nothing and return +void ConstantFolderVisitor::visit(Expression *e) { + is_number = false; +} + +// number expresssion +void ConstantFolderVisitor::visit(NumberExpression *e) { + // set constant number and return + is_number = true; + value = e->value(); +} + +/// unary expresssion +void 
ConstantFolderVisitor::visit(UnaryExpression *e) { + is_number = false; + e->expression()->accept(this); + if(is_number) { + if(!e->is_number()) { + e->replace_expression(make_expression<NumberExpression>(e->location(), value)); + } + switch(e->op()) { + case tok::minus : + value = -value; + return; + case tok::exp : + value = std::exp(value); + return; + case tok::cos : + value = std::cos(value); + return; + case tok::sin : + value = std::sin(value); + return; + case tok::log : + value = std::log(value); + return; + default : + throw compiler_exception( + "attempting constant folding on unsuported unary operator " + + yellow(token_string(e->op())), + e->location()); + } + } +} + +// binary expresssion +// handle all binary expressions with one routine, because the +// pre-order and in-order code is the same for all cases +void ConstantFolderVisitor::visit(BinaryExpression *e) { + bool lhs_is_number = false; + long double lhs_value = 0; + + // check the lhs + is_number = false; + e->lhs()->accept(this); + if(is_number) { + lhs_value = value; + lhs_is_number = true; + // replace lhs with a number node, if it is not already one + if(!e->lhs()->is_number()) { + e->replace_lhs( make_expression<NumberExpression>(e->location(), value) ); + } + } + //std::cout << "lhs : " << e->lhs()->to_string() << std::endl; + + // check the rhs + is_number = false; + e->rhs()->accept(this); + if(is_number) { + // replace rhs with a number node, if it is not already one + if(!e->rhs()->is_number()) { + //std::cout << "rhs : " << e->rhs()->to_string() << " -> "; + e->replace_rhs( make_expression<NumberExpression>(e->location(), value) ); + //std::cout << e->rhs()->to_string() << std::endl; + } + } + //std::cout << "rhs : " << e->rhs()->to_string() << std::endl; + + auto rhs_is_number = is_number; + is_number = rhs_is_number && lhs_is_number; + + // check to see if both lhs and rhs are numbers + // mark this node as a number if so + if(is_number) { + // be careful to get the order of 
operation right for + // non-computative operators + switch(e->op()) { + case tok::plus : + value = lhs_value + value; + return; + case tok::minus : + value = lhs_value - value; + return; + case tok::times : + value = lhs_value * value; + return; + case tok::divide : + value = lhs_value / value; + return; + case tok::pow : + value = std::pow(lhs_value, value); + return; + // don't fold comparison operators (we have no internal support + // for boolean values in nodes). leave for the back end compiler. + // not a big deal, because these are not counted when estimating + // flops with the FLOP visitor + case tok::lt : + case tok::lte : + case tok::gt : + case tok::gte : + case tok::equality : + is_number = false; + return; + default : + throw compiler_exception( + "attempting constant folding on unsuported binary operator " + + yellow(token_string(e->op())), + e->location()); + } + } +} + +void ConstantFolderVisitor::visit(CallExpression *e) { + is_number = false; + for(auto& a : e->args()) { + a->accept(this); + if(is_number) { + // replace rhs with a number node, if it is not already one + if(!a->is_number()) { + a.reset(new NumberExpression(a->location(), value)); + } + } + } +} + +void ConstantFolderVisitor::visit(BlockExpression *e) { + is_number = false; + for(auto &expression : e->statements()) { + expression->accept(this); + } +} + +void ConstantFolderVisitor::visit(FunctionExpression *e) { + is_number = false; + e->body()->accept(this); +} + +void ConstantFolderVisitor::visit(ProcedureExpression *e) { + is_number = false; + e->body()->accept(this); +} + +void ConstantFolderVisitor::visit(IfExpression *e) { + is_number = false; + e->condition()->accept(this); + e->true_branch()->accept(this); + if(e->false_branch()) { + e->false_branch()->accept(this); + } +} + diff --git a/modcc/src/constantfolder.hpp b/modcc/src/constantfolder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1fc5f507a4bbdd3d58622e7772944bd1d7b25b3a --- /dev/null +++ 
b/modcc/src/constantfolder.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "visitor.hpp" + +class ConstantFolderVisitor : public Visitor { +public: + ConstantFolderVisitor() {} + + void visit(Expression *e) override; + // reduce child + void visit(UnaryExpression *e) override; + // reduce left and right children + void visit(BinaryExpression *e) override; + // reduce expressions in arguments + void visit(NumberExpression *e) override; + + void visit(CallExpression *e) override; + void visit(ProcedureExpression *e) override; + void visit(FunctionExpression *e) override; + void visit(BlockExpression *e) override; + void visit(IfExpression *e) override; + + // store intermediate results as long double, i.e. 80-bit precision + long double value = 0.; + bool is_number = false; +}; + diff --git a/modcc/src/cprinter.cpp b/modcc/src/cprinter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93f05a8d1c9fa7777043c8c12b8b55ea98b7ef1e --- /dev/null +++ b/modcc/src/cprinter.cpp @@ -0,0 +1,894 @@ +#include <algorithm> + +#include "cprinter.hpp" +#include "lexer.hpp" + +/****************************************************************************** + CPrinter driver +******************************************************************************/ + +CPrinter::CPrinter(Module &m, bool o) +: module_(&m), + optimize_(o) +{ + // make a list of vector types, both parameters and assigned + // and a list of all scalar types + std::vector<VariableExpression*> scalar_variables; + std::vector<VariableExpression*> array_variables; + for(auto& sym: m.symbols()) { + if(auto var = sym.second->is_variable()) { + if(var->is_range()) { + array_variables.push_back(var); + } + else { + scalar_variables.push_back(var); + } + } + } + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + text_.add_line("#pragma once"); + text_.add_line(); + text_.add_line("#include <cmath>"); + text_.add_line("#include <limits>"); + text_.add_line(); + 
text_.add_line("#include <mechanism.hpp>"); + text_.add_line("#include <mechanism_interface.hpp>"); + text_.add_line("#include <algorithms.hpp>"); + text_.add_line(); + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + std::string class_name = "mechanism_" + m.name(); + + text_.add_line("namespace nest{ namespace mc{ namespace mechanisms{ namespace " + m.name() + "{"); + text_.add_line(); + text_.add_line("template<typename T, typename I>"); + text_.add_line("class " + class_name + " : public mechanism<T, I> {"); + text_.add_line("public:"); + text_.increase_indentation(); + text_.add_line("using base = mechanism<T, I>;"); + text_.add_line("using value_type = typename base::value_type;"); + text_.add_line("using size_type = typename base::size_type;"); + text_.add_line("using vector_type = typename base::vector_type;"); + text_.add_line("using view_type = typename base::view_type;"); + text_.add_line("using index_type = typename base::index_type;"); + text_.add_line("using index_view = typename base::index_view;"); + text_.add_line("using const_index_view = typename base::const_index_view;"); + text_.add_line("using indexed_view_type= typename base::indexed_view_type;"); + text_.add_line("using ion_type = typename base::ion_type;"); + text_.add_line(); + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + for(auto& ion: m.neuron_block().ions) { + auto tname = "Ion" + ion.name; + text_.add_line("struct " + tname + " {"); + text_.increase_indentation(); + for(auto& field : ion.read) { + text_.add_line("view_type " + field.spelling + ";"); + } + for(auto& field : ion.write) { + text_.add_line("view_type " + field.spelling + ";"); + } + text_.add_line("index_type index;"); + text_.add_line("std::size_t memory() const { return sizeof(size_type)*index.size(); }"); + text_.add_line("std::size_t size() const { return index.size(); }"); + text_.decrease_indentation(); + 
text_.add_line("};"); + text_.add_line(tname + " ion_" + ion.name + ";"); + } + text_.add_line(); + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + int num_vars = array_variables.size(); + text_.add_line(class_name + "(view_type vec_v, view_type vec_i, const_index_view node_index)"); + text_.add_line(": base(vec_v, vec_i, node_index)"); + text_.add_line("{"); + text_.increase_indentation(); + text_.add_gutter() << "size_type num_fields = " << num_vars << ";"; + text_.end_line(); + + text_.add_line(); + text_.add_line("// calculate the padding required to maintain proper alignment of sub arrays"); + text_.add_line("auto alignment = data_.alignment();"); + text_.add_line("auto field_size_in_bytes = sizeof(value_type)*size();"); + text_.add_line("auto remainder = field_size_in_bytes % alignment;"); + text_.add_line("auto padding = remainder ? (alignment - remainder)/sizeof(value_type) : 0;"); + text_.add_line("auto field_size = size()+padding;"); + + text_.add_line(); + text_.add_line("// allocate memory"); + text_.add_line("data_ = vector_type(field_size * num_fields);"); + text_.add_line("data_(memory::all) = std::numeric_limits<value_type>::quiet_NaN();"); + + // assign the sub-arrays + // replace this : data_(1*n, 2*n); + // with this : data_(1*field_size, 1*field_size+n); + + text_.add_line(); + text_.add_line("// asign the sub-arrays"); + for(int i=0; i<num_vars; ++i) { + char namestr[128]; + sprintf(namestr, "%-15s", array_variables[i]->name().c_str()); + if(optimize_) { + text_.add_gutter() << namestr << " = data_.data() + " + << i << "*field_size;"; + } + else { + text_.add_gutter() << namestr << " = data_(" + << i << "*field_size, " << i+1 << "*size());"; + } + text_.end_line(); + } + + text_.add_line(); + text_.add_line("// set initial values for variables and parameters"); + for(auto const& var : array_variables) { + double val = var->value(); + // only non-NaN fields need to be initialized, because data_ 
+ // is NaN by default + std::string pointer_name = var->name(); + if(!optimize_) pointer_name += ".data()"; + if(val == val) { + text_.add_gutter() << "std::fill(" << pointer_name << ", " + << pointer_name << "+size(), " + << val << ");"; + text_.end_line(); + } + } + + text_.add_line(); + //text_.add_line("INIT_PROFILE"); + text_.decrease_indentation(); + text_.add_line("}"); + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + + text_.add_line(); + text_.add_line("using base::size;"); + text_.add_line(); + + text_.add_line("std::size_t memory() const override {"); + text_.increase_indentation(); + text_.add_line("auto s = std::size_t{0};"); + text_.add_line("s += data_.size()*sizeof(value_type);"); + for(auto& ion: m.neuron_block().ions) { + text_.add_line("s += ion_" + ion.name + ".memory();"); + } + text_.add_line("return s;"); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + text_.add_line("void set_params(value_type t_, value_type dt_) override {"); + text_.increase_indentation(); + text_.add_line("t = t_;"); + text_.add_line("dt = dt_;"); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + text_.add_line("std::string name() const override {"); + text_.increase_indentation(); + text_.add_line("return \"" + m.name() + "\";"); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + std::string kind_str = m.kind() == moduleKind::density + ? 
"mechanismKind::density" + : "mechanismKind::point"; + text_.add_line("mechanismKind kind() const override {"); + text_.increase_indentation(); + text_.add_line("return " + kind_str + ";"); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + // return true/false indicating if cell has dependency on k + auto const& ions = m.neuron_block().ions; + auto find_ion = [&ions] (ionKind k) { + return std::find_if( + ions.begin(), ions.end(), + [k](IonDep const& d) {return d.kind()==k;} + ); + }; + auto has_ion = [&ions, find_ion] (ionKind k) { + return find_ion(k) != ions.end(); + }; + + // bool uses_ion(ionKind k) const override + text_.add_line("bool uses_ion(ionKind k) const override {"); + text_.increase_indentation(); + text_.add_line("switch(k) {"); + text_.increase_indentation(); + text_.add_gutter() + << "case ionKind::na : return " + << (has_ion(ionKind::Na) ? "true" : "false") << ";"; + text_.end_line(); + text_.add_gutter() + << "case ionKind::ca : return " + << (has_ion(ionKind::Ca) ? "true" : "false") << ";"; + text_.end_line(); + text_.add_gutter() + << "case ionKind::k : return " + << (has_ion(ionKind::K) ? 
"true" : "false") << ";"; + text_.end_line(); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line("return false;"); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + /*************************************************************************** + * + * ion channels have the following fields : + * + * --------------------------------------------------- + * label Ca Na K name + * --------------------------------------------------- + * iX ica ina ik current + * eX eca ena ek reversal_potential + * Xi cai nai ki internal_concentration + * Xo cao nao ko external_concentration + * gX gca gna gk conductance + * --------------------------------------------------- + * + **************************************************************************/ + + // void set_ion(ionKind k, ion_type& i) override + // TODO: this is done manually, which isn't going to scale + auto has_variable = [] (IonDep const& ion, std::string const& name) { + if( std::find_if(ion.read.begin(), ion.read.end(), + [&name] (Token const& t) {return t.spelling==name;} + ) != ion.read.end() + ) return true; + if( std::find_if(ion.write.begin(), ion.write.end(), + [&name] (Token const& t) {return t.spelling==name;} + ) != ion.write.end() + ) return true; + return false; + }; + text_.add_line("void set_ion(ionKind k, ion_type& i) override {"); + text_.increase_indentation(); + text_.add_line("using nest::mc::algorithms::index_into;"); + if(has_ion(ionKind::Na)) { + auto ion = find_ion(ionKind::Na); + text_.add_line("if(k==ionKind::na) {"); + text_.increase_indentation(); + text_.add_line("ion_na.index = index_into(i.node_index(), node_index_);"); + if(has_variable(*ion, "ina")) text_.add_line("ion_na.ina = i.current();"); + if(has_variable(*ion, "ena")) text_.add_line("ion_na.ena = i.reversal_potential();"); + if(has_variable(*ion, "nai")) text_.add_line("ion_na.nai = i.internal_concentration();"); + if(has_variable(*ion, "nao")) text_.add_line("ion_na.nao = 
i.external_concentration();"); + text_.add_line("return;"); + text_.decrease_indentation(); + text_.add_line("}"); + } + if(has_ion(ionKind::Ca)) { + auto ion = find_ion(ionKind::Ca); + text_.add_line("if(k==ionKind::ca) {"); + text_.increase_indentation(); + text_.add_line("ion_ca.index = index_into(i.node_index(), node_index_);"); + if(has_variable(*ion, "ica")) text_.add_line("ion_ca.ica = i.current();"); + if(has_variable(*ion, "eca")) text_.add_line("ion_ca.eca = i.reversal_potential();"); + if(has_variable(*ion, "cai")) text_.add_line("ion_ca.cai = i.internal_concentration();"); + if(has_variable(*ion, "cao")) text_.add_line("ion_ca.cao = i.external_concentration();"); + text_.add_line("return;"); + text_.decrease_indentation(); + text_.add_line("}"); + } + if(has_ion(ionKind::K)) { + auto ion = find_ion(ionKind::K); + text_.add_line("if(k==ionKind::k) {"); + text_.increase_indentation(); + text_.add_line("ion_k.index = index_into(i.node_index(), node_index_);"); + if(has_variable(*ion, "ik")) text_.add_line("ion_k.ik = i.current();"); + if(has_variable(*ion, "ek")) text_.add_line("ion_k.ek = i.reversal_potential();"); + if(has_variable(*ion, "ki")) text_.add_line("ion_k.ki = i.internal_concentration();"); + if(has_variable(*ion, "ko")) text_.add_line("ion_k.ko = i.external_concentration();"); + text_.add_line("return;"); + text_.decrease_indentation(); + text_.add_line("}"); + } + text_.add_line("throw std::domain_error(nest::mc::util::pprintf(\"mechanism % does not support ion type\\n\", name()));"); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + + auto proctest = [] (procedureKind k) { + return + k == procedureKind::normal + || k == procedureKind::api + || k == procedureKind::net_receive; + }; + for(auto &var : m.symbols()) { + auto isproc = var.second->kind()==symbolKind::procedure; + if(isproc ) + { + auto proc = 
var.second->is_procedure(); + if(proctest(proc->kind())) { + proc->accept(this); + } + } + } + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + + text_.add_line("vector_type data_;"); + for(auto var: array_variables) { + if(optimize_) { + text_.add_line( + "__declspec(align(vector_type::alignment())) value_type *" + + var->name() + ";"); + } + else { + text_.add_line("view_type " + var->name() + ";"); + } + } + + for(auto var: scalar_variables) { + double val = var->value(); + // test the default value for NaN + // useful for error propogation from bad initial conditions + if(val==val) { + text_.add_gutter() << "value_type " << var->name() << " = " << val << ";"; + text_.end_line(); + } + else { + text_.add_line("value_type " + var->name() + " = 0;"); + } + } + + text_.add_line(); + text_.add_line("using base::vec_v_;"); + text_.add_line("using base::vec_i_;"); + text_.add_line("using base::vec_area_;"); + text_.add_line("using base::node_index_;"); + + text_.add_line(); + //text_.add_line("DATA_PROFILE"); + text_.decrease_indentation(); + text_.add_line("};"); + text_.add_line(); + + // print the helper type that provides the bridge from the mechanism to + // the calling code + text_.add_line("template<typename T, typename I>"); + text_.add_line("struct helper : public mechanism_helper<T, I> {"); + text_.increase_indentation(); + text_.add_line("using base = mechanism_helper<T, I>;"); + text_.add_line("using index_view = typename base::index_view;"); + text_.add_line("using view_type = typename base::view_type;"); + text_.add_line("using mechanism_ptr_type = typename base::mechanism_ptr_type;"); + text_.add_gutter() << "using mechanism_type = " << class_name << "<T, I>;"; + text_.add_line(); + text_.add_line(); + + text_.add_line("std::string"); + text_.add_line("name() const override"); + text_.add_line("{"); + text_.increase_indentation(); + text_.add_gutter() << "return \"" << m.name() << "\";"; + 
text_.add_line(); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + text_.add_line("mechanism_ptr<T,I>"); + text_.add_line("new_mechanism(view_type vec_v, view_type vec_i, index_view node_index) const override"); + text_.add_line("{"); + text_.increase_indentation(); + text_.add_line("return nest::mc::mechanisms::make_mechanism<mechanism_type>(vec_v, vec_i, node_index);"); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + text_.add_line("void"); + text_.add_line("set_parameters(mechanism_ptr_type&, parameter_list const&) const override"); + text_.add_line("{"); + text_.increase_indentation(); + // TODO : interface that writes parameter_list paramaters into the mechanism's storage + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + + text_.decrease_indentation(); + text_.add_line("};"); + text_.add_line(); + + text_.add_line("}}}} // namespaces"); +} + + +/****************************************************************************** + CPrinter +******************************************************************************/ + +void CPrinter::visit(Expression *e) { + throw compiler_exception( + "CPrinter doesn't know how to print " + e->to_string(), + e->location()); +} + +void CPrinter::visit(LocalDeclaration *e) { +} + +void CPrinter::visit(Symbol *e) { + throw compiler_exception("I don't know how to print raw Symbol " + e->to_string(), + e->location()); +} + +void CPrinter::visit(LocalVariable *e) { + std::string const& name = e->name(); + text_ << name; + if(is_ghost_local(e)) { + text_ << "[j_]"; + } +} + +void CPrinter::visit(NumberExpression *e) { + text_ << " " << e->value(); +} + +void CPrinter::visit(IdentifierExpression *e) { + e->symbol()->accept(this); +} + +void CPrinter::visit(VariableExpression *e) { + text_ << e->name(); + if(e->is_range()) { + text_ << "[i_]"; + } +} + +void CPrinter::visit(IndexedVariable *e) { + text_ << e->index_name() << "[i_]"; +} + +void 
CPrinter::visit(UnaryExpression *e) { + auto b = (e->expression()->is_binary()!=nullptr); + switch(e->op()) { + case tok::minus : + // place a space in front of minus sign to avoid invalid + // expressions of the form : (v[i]--67) + if(b) text_ << " -("; + else text_ << " -"; + e->expression()->accept(this); + if(b) text_ << ")"; + return; + case tok::exp : + text_ << "exp("; + e->expression()->accept(this); + text_ << ")"; + return; + case tok::cos : + text_ << "cos("; + e->expression()->accept(this); + text_ << ")"; + return; + case tok::sin : + text_ << "sin("; + e->expression()->accept(this); + text_ << ")"; + return; + case tok::log : + text_ << "log("; + e->expression()->accept(this); + text_ << ")"; + return; + default : + throw compiler_exception( + "CPrinter unsupported unary operator " + yellow(token_string(e->op())), + e->location()); + } +} + +void CPrinter::visit(BlockExpression *e) { + // ------------- declare local variables ------------- // + // only if this is the outer block + if(!e->is_nested()) { + std::vector<std::string> names; + for(auto& symbol : e->scope()->locals()) { + auto sym = symbol.second.get(); + // input variables are declared earlier, before the + // block body is printed + if(is_stack_local(sym) && !is_input(sym)) { + names.push_back(sym->name()); + } + } + if(names.size()>0) { + //for(auto it=names.begin(); it!=names.end(); ++it) { + // text_.add_gutter() << "value_type " << *it; + // text_.end_line("{0};"); + //} + text_.add_gutter() << "value_type " << *(names.begin()); + for(auto it=names.begin()+1; it!=names.end(); ++it) { + text_ << ", " << *it; + } + text_.end_line(";"); + } + } + + // ------------- statements ------------- // + for(auto& stmt : e->statements()) { + if(stmt->is_local_declaration()) continue; + + // these all must be handled + text_.add_gutter(); + stmt->accept(this); + text_.end_line(";"); + } +} + +void CPrinter::visit(IfExpression *e) { + // for now we remove the brackets around the condition because + 
// the binary expression printer adds them, and we want to work + // around the -Wparentheses-equality warning + text_ << "if("; + e->condition()->accept(this); + text_ << ") {\n"; + increase_indentation(); + e->true_branch()->accept(this); + decrease_indentation(); + text_.add_gutter(); + text_ << "}"; +} + +void CPrinter::visit(ProcedureExpression *e) { + // ------------- print prototype ------------- // + text_.add_gutter() << "void " << e->name() << "(int i_"; + for(auto& arg : e->args()) { + text_ << ", value_type " << arg->is_argument()->name(); + } + if(e->kind() == procedureKind::net_receive) { + text_.end_line(") override {"); + } + else { + text_.end_line(") {"); + } + + if(!e->scope()) { // error: semantic analysis has not been performed + throw compiler_exception( + "CPrinter attempt to print Procedure " + e->name() + + " for which semantic analysis has not been performed", + e->location()); + } + + increase_indentation(); + + e->body()->accept(this); + + // ------------- close up ------------- // + decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + return; +} + +void CPrinter::visit(APIMethod *e) { + // ------------- print prototype ------------- // + text_.add_gutter() << "void " << e->name() << "() override {"; + text_.end_line(); + + if(!e->scope()) { // error: semantic analysis has not been performed + throw compiler_exception( + "CPrinter attempt to print APIMethod " + e->name() + + " for which semantic analysis has not been performed", + e->location()); + } + + // only print the body if it has contents + if(e->is_api_method()->body()->statements().size()) { + increase_indentation(); + + // create local indexed views + for(auto &symbol : e->scope()->locals()) { + auto var = symbol.second->is_local_variable(); + if(var->is_indexed()) { + auto const& name = var->name(); + auto const& index_name = var->external_variable()->index_name(); + text_.add_gutter(); + if(var->is_read()) text_ << "const "; + text_ << "indexed_view_type " + 
index_name; + auto channel = var->external_variable()->ion_channel(); + if(channel==ionKind::none) { + text_ << "(" + index_name + "_, node_index_);\n"; + } + else { + auto iname = ion_store(channel); + text_ << "(" << iname << "." << name << ", " + << ion_store(channel) << ".index);\n"; + } + } + } + + // ------------- get loop dimensions ------------- // + text_.add_line("int n_ = node_index_.size();"); + + // hand off printing of loops to optimized or unoptimized backend + if(optimize_) { + print_APIMethod_optimized(e); + } + else { + print_APIMethod_unoptimized(e); + } + } + + // ------------- close up ------------- // + text_.add_line("}"); + text_.add_line(); +} + +void CPrinter::print_APIMethod_unoptimized(APIMethod* e) { + //text_.add_line("START_PROFILE"); + + // there can not be more than 1 instance of a density channel per grid point, + // so we can assert that aliasing will not occur. + if(optimize_) text_.add_line("#pragma ivdep"); + + text_.add_line("for(int i_=0; i_<n_; ++i_) {"); + text_.increase_indentation(); + + // loads from external indexed arrays + for(auto &symbol : e->scope()->locals()) { + auto var = symbol.second->is_local_variable(); + if(is_input(var)) { + auto ext = var->external_variable(); + text_.add_gutter() << "value_type "; + var->accept(this); + text_ << " = "; + ext->accept(this); + text_.end_line(";"); + } + } + + // print the body of the loop + e->body()->accept(this); + + // perform update of external variables (currents etc) + for(auto &symbol : e->scope()->locals()) { + auto var = symbol.second->is_local_variable(); + if(is_output(var)) { + auto ext = var->external_variable(); + text_.add_gutter(); + ext->accept(this); + text_ << (ext->op() == tok::plus ? 
" += " : " -= "); + var->accept(this); + text_.end_line(";"); + } + } + + text_.decrease_indentation(); + text_.add_line("}"); + + //text_.add_line("STOP_PROFILE"); + decrease_indentation(); + + return; +} + +void CPrinter::print_APIMethod_optimized(APIMethod* e) { + // ------------- get mechanism properties ------------- // + + // make a list of all the local variables that have to be + // written out to global memory via an index + auto is_aliased = [this] (Symbol* s) -> LocalVariable* { + if(is_output(s)) { + return s->is_local_variable(); + } + return nullptr; + }; + + std::vector<LocalVariable*> aliased_variables; + if(is_point_process()) { + for(auto &l : e->scope()->locals()) { + if(auto var = is_aliased(l.second.get())) { + aliased_variables.push_back(var); + } + } + } + aliased_output_ = aliased_variables.size()>0; + + // only proceed with optimized output if the ouputs are aliased + // because all optimizations are for using ghost buffers to avoid + // race conditions in vectorized code + if(!aliased_output_) { + print_APIMethod_unoptimized(e); + return; + } + + // ------------- block loop ------------- // + + text_.add_line("constexpr int BSIZE = 4;"); + text_.add_line("int NB = n_/BSIZE;"); + for(auto out: aliased_variables) { + text_.add_line( + "__declspec(align(vector_type::alignment())) value_type " + + out->name() + "[BSIZE];"); + } + //text_.add_line("START_PROFILE"); + + text_.add_line("for(int b_=0; b_<NB; ++b_) {"); + text_.increase_indentation(); + text_.add_line("int BSTART = BSIZE*b_;"); + text_.add_line("int i_ = BSTART;"); + + + // assert that memory accesses are not aliased because we will + // use ghost arrays to ensure that write-back of point processes does + // not lead to race conditions + text_.add_line("#pragma ivdep"); + text_.add_line("for(int j_=0; j_<BSIZE; ++j_, ++i_) {"); + text_.increase_indentation(); + + // loads from external indexed arrays + for(auto &symbol : e->scope()->locals()) { + auto var = 
symbol.second->is_local_variable(); + if(is_input(var)) { + auto ext = var->external_variable(); + text_.add_gutter() << "value_type "; + var->accept(this); + text_ << " = "; + ext->accept(this); + text_.end_line(";"); + } + } + + e->body()->accept(this); + + text_.decrease_indentation(); + text_.add_line("}"); // end inner compute loop + + text_.add_line("i_ = BSTART;"); + text_.add_line("for(int j_=0; j_<BSIZE; ++j_, ++i_) {"); + text_.increase_indentation(); + + for(auto out: aliased_variables) { + text_.add_gutter(); + auto ext = out->external_variable(); + ext->accept(this); + text_ << (ext->op() == tok::plus ? " += " : " -= "); + out->accept(this); + text_.end_line(";"); + } + + text_.decrease_indentation(); + text_.add_line("}"); // end inner write loop + text_.decrease_indentation(); + text_.add_line("}"); // end outer block loop + + // ------------- block tail loop ------------- // + + text_.add_line("int j_ = 0;"); + text_.add_line("#pragma ivdep"); + text_.add_line("for(int i_=NB*BSIZE; i_<n_; ++j_, ++i_) {"); + text_.increase_indentation(); + + for(auto &symbol : e->scope()->locals()) { + auto var = symbol.second->is_local_variable(); + if(is_input(var)) { + auto ext = var->external_variable(); + text_.add_gutter() << "value_type "; + var->accept(this); + text_ << " = "; + ext->accept(this); + text_.end_line(";"); + } + } + + e->body()->accept(this); + + text_.decrease_indentation(); + text_.add_line("}"); // end inner compute loop + text_.add_line("j_ = 0;"); + text_.add_line("for(int i_=NB*BSIZE; i_<n_; ++j_, ++i_) {"); + text_.increase_indentation(); + + for(auto out: aliased_variables) { + text_.add_gutter(); + auto ext = out->external_variable(); + ext->accept(this); + text_ << (ext->op() == tok::plus ? 
" += " : " -= "); + out->accept(this); + text_.end_line(";"); + } + + text_.decrease_indentation(); + text_.add_line("}"); // end block tail loop + + //text_.add_line("STOP_PROFILE"); + decrease_indentation(); + + aliased_output_ = false; + return; +} + +void CPrinter::visit(CallExpression *e) { + text_ << e->name() << "(i_"; + for(auto& arg: e->args()) { + text_ << ", "; + arg->accept(this); + } + text_ << ")"; +} + +void CPrinter::visit(AssignmentExpression *e) { + e->lhs()->accept(this); + text_ << " = "; + e->rhs()->accept(this); +} + +void CPrinter::visit(PowBinaryExpression *e) { + text_ << "std::pow("; + e->lhs()->accept(this); + text_ << ", "; + e->rhs()->accept(this); + text_ << ")"; +} + +void CPrinter::visit(BinaryExpression *e) { + auto pop = parent_op_; + // TODO unit tests for parenthesis and binops + bool use_brackets = + Lexer::binop_precedence(pop) > Lexer::binop_precedence(e->op()) + || (pop==tok::divide && e->op()==tok::times); + parent_op_ = e->op(); + + auto lhs = e->lhs(); + auto rhs = e->rhs(); + if(use_brackets) { + text_ << "("; + } + lhs->accept(this); + switch(e->op()) { + case tok::minus : + text_ << "-"; + break; + case tok::plus : + text_ << "+"; + break; + case tok::times : + text_ << "*"; + break; + case tok::divide : + text_ << "/"; + break; + case tok::lt : + text_ << "<"; + break; + case tok::lte : + text_ << "<="; + break; + case tok::gt : + text_ << ">"; + break; + case tok::gte : + text_ << ">="; + break; + case tok::equality : + text_ << "=="; + break; + default : + throw compiler_exception( + "CPrinter unsupported binary operator " + yellow(token_string(e->op())), + e->location()); + } + rhs->accept(this); + if(use_brackets) { + text_ << ")"; + } + + // reset parent precedence + parent_op_ = pop; +} + diff --git a/modcc/src/cprinter.hpp b/modcc/src/cprinter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e669c25f02e71da10e9f4622ad07651bbd558e36 --- /dev/null +++ b/modcc/src/cprinter.hpp @@ -0,0 +1,115 
@@ +#pragma once + +#include <sstream> + +#include "module.hpp" +#include "textbuffer.hpp" +#include "visitor.hpp" + +class CPrinter : public Visitor { +public: + CPrinter() {} + CPrinter(Module &m, bool o=false); + + void visit(Expression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; + void visit(AssignmentExpression *e) override; + void visit(PowBinaryExpression *e) override; + void visit(NumberExpression *e) override; + void visit(VariableExpression *e) override; + + void visit(Symbol *e) override; + void visit(LocalVariable *e) override; + void visit(IndexedVariable *e) override; + + void visit(IdentifierExpression *e) override; + void visit(CallExpression *e) override; + void visit(ProcedureExpression *e) override; + void visit(APIMethod *e) override; + void visit(LocalDeclaration *e) override; + void visit(BlockExpression *e) override; + void visit(IfExpression *e) override; + + std::string text() const { + return text_.str(); + } + + void set_gutter(int w) { + text_.set_gutter(w); + } + void increase_indentation(){ + text_.increase_indentation(); + } + void decrease_indentation(){ + text_.decrease_indentation(); + } +private: + + void print_APIMethod_optimized(APIMethod* e); + void print_APIMethod_unoptimized(APIMethod* e); + + Module *module_ = nullptr; + tok parent_op_ = tok::eq; + TextBuffer text_; + bool optimize_ = false; + bool aliased_output_ = false; + + bool is_input(Symbol *s) { + if(auto l = s->is_local_variable() ) { + if(l->is_local()) { + if(l->is_indexed() && l->is_read()) { + return true; + } + } + } + return false; + } + + bool is_output(Symbol *s) { + if(auto l = s->is_local_variable() ) { + if(l->is_local()) { + if(l->is_indexed() && l->is_write()) { + return true; + } + } + } + return false; + } + + bool is_arg_local(Symbol *s) { + if(auto l=s->is_local_variable()) { + if(l->is_arg()) { + return true; + } + } + return false; + } + + bool is_indexed_local(Symbol *s) { + if(auto 
l=s->is_local_variable()) { + if(l->is_indexed()) { + return true; + } + } + return false; + } + + bool is_ghost_local(Symbol *s) { + if(!is_point_process()) return false; + if(!optimize_) return false; + if(!aliased_output_) return false; + if(is_arg_local(s)) return false; + return is_output(s); + } + + bool is_stack_local(Symbol *s) { + if(is_arg_local(s)) return false; + return !is_ghost_local(s); + } + + bool is_point_process() { + return module_->kind() == moduleKind::point; + } +}; + diff --git a/modcc/src/cudaprinter.cpp b/modcc/src/cudaprinter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24433d2be230ea0d083e4223698cecec501577fd --- /dev/null +++ b/modcc/src/cudaprinter.cpp @@ -0,0 +1,652 @@ +#include "cudaprinter.hpp" +#include "lexer.hpp" + +/****************************************************************************** +******************************************************************************/ + +CUDAPrinter::CUDAPrinter(Module &m, bool o) + : module_(&m) +{ + // make a list of vector types, both parameters and assigned + // and a list of all scalar types + std::vector<VariableExpression*> scalar_variables; + std::vector<VariableExpression*> array_variables; + for(auto& sym: m.symbols()) { + if(sym.second->kind()==symbolKind::variable) { + auto var = sym.second->is_variable(); + if(var->is_range()) { + array_variables.push_back(var); + } + else { + scalar_variables.push_back(var) ; + } + } + } + + ////////////////////////////////////////////// + // header files + ////////////////////////////////////////////// + text_.add_line("#pragma once"); + text_.add_line(); + text_.add_line("#include <cmath>"); + text_.add_line("#include <limits>"); + text_.add_line(); + text_.add_line("#include <indexedview.hpp>"); + text_.add_line("#include <mechanism.hpp>"); + text_.add_line("#include <target.hpp>"); + text_.add_line(); + + //////////////////////////////////////////////////////////// + // generate the parameter pack + 
//////////////////////////////////////////////////////////// + std::vector<std::string> param_pack; + text_.add_line("template <typename T, typename I>"); + text_.add_gutter() << "struct " << m.name() << "_ParamPack {"; + text_.end_line(); + text_.increase_indentation(); + text_.add_line("// array parameters"); + for(auto const &var: array_variables) { + text_.add_line("T* " + var->name() + ";"); + param_pack.push_back(var->name() + ".data()"); + } + text_.add_line("// scalar parameters"); + for(auto const &var: scalar_variables) { + text_.add_line("T " + var->name() + ";"); + param_pack.push_back(var->name()); + } + text_.add_line("// ion channel dependencies"); + for(auto& ion: m.neuron_block().ions) { + auto tname = "ion_" + ion.name; + for(auto& field : ion.read) { + text_.add_line("T* ion_" + field.spelling + ";"); + param_pack.push_back(tname + "." + field.spelling + ".data()"); + } + for(auto& field : ion.write) { + text_.add_line("T* ion_" + field.spelling + ";"); + param_pack.push_back(tname + "." 
+ field.spelling + ".data()"); + } + text_.add_line("I* ion_" + ion.name + "_idx_;"); + param_pack.push_back(tname + ".index.data()"); + } + + text_.add_line("// matrix"); + text_.add_line("T* vec_rhs;"); + text_.add_line("T* vec_d;"); + text_.add_line("T* vec_v;"); + param_pack.push_back("matrix_.vec_rhs().data()"); + param_pack.push_back("matrix_.vec_d().data()"); + param_pack.push_back("matrix_.vec_v().data()"); + + text_.add_line("// node index information"); + text_.add_line("I* ni;"); + text_.add_line("unsigned long n;"); + text_.decrease_indentation(); + text_.add_line("};"); + text_.add_line(); + param_pack.push_back("node_indices_.data()"); + param_pack.push_back("node_indices_.size()"); + + + //////////////////////////////////////////////////////// + // write the CUDA kernels + //////////////////////////////////////////////////////// + text_.add_line("namespace impl {"); + text_.add_line("namespace " + m.name() + " {"); + text_.add_line(); + { + increase_indentation(); + // forward declarations of procedures + for(auto const &var : m.symbols()) { + if( var.second->kind()==symbolKind::procedure + && var.second->is_procedure()->kind() == procedureKind::normal) + { + print_procedure_prototype(var.second->is_procedure()); + text_.end_line(";"); + text_.add_line(); + } + } + + // print stubs that call API method kernels that are defined in the + // imp::name namespace + auto proctest = [] (procedureKind k) {return k == procedureKind::normal + || k == procedureKind::api; }; + for(auto const &var : m.symbols()) { + if( var.second->kind()==symbolKind::procedure + && proctest(var.second->is_procedure()->kind())) + { + var.second->accept(this); + } + } + decrease_indentation(); + } + text_.add_line("} // namespace " + m.name()); + text_.add_line("} // namespace impl"); + text_.add_line(); + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + std::string class_name = "Mechanism_" + m.name(); + + text_ << 
"template<typename T, typename I>\n"; + text_ << "class " + class_name + " : public Mechanism<T, I, targetKind::gpu> {\n"; + text_ << "public:\n\n"; + text_ << " using base = Mechanism<T, I, targetKind::gpu>;\n"; + text_ << " using value_type = typename base::value_type;\n"; + text_ << " using size_type = typename base::size_type;\n"; + text_ << " using vector_type = typename base::vector_type;\n"; + text_ << " using view_type = typename base::view_type;\n"; + text_ << " using index_type = typename base::index_type;\n"; + text_ << " using index_view = typename index_type::view_type;\n"; + text_ << " using indexed_view= typename base::indexed_view;\n\n"; + text_ << " using matrix_type = typename base::matrix_type;\n\n"; + text_ << " using param_pack_type = " << m.name() << "_ParamPack<T,I>;\n\n"; + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + for(auto& ion: m.neuron_block().ions) { + auto tname = "Ion" + ion.name; + text_ << " struct " + tname + " {\n"; + for(auto& field : ion.read) { + text_ << " view_type " + field.spelling + ";\n"; + } + for(auto& field : ion.write) { + text_ << " view_type " + field.spelling + ";\n"; + } + text_ << " index_type index;\n"; + text_ << " std::size_t memory() const { return sizeof(size_type)*index.size(); }\n"; + text_ << " std::size_t size() const { return index.size(); }\n"; + text_ << " };\n"; + text_ << " " + tname + " ion_" + ion.name + ";\n\n"; + } + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + int num_vars = array_variables.size(); + text_ << " " + class_name + "(\n"; + text_ << " matrix_type* matrix,\n"; + text_ << " index_view node_indices)\n"; + text_ << " : base(matrix, node_indices)\n"; + text_ << " {\n"; + text_ << " size_type num_fields = " << num_vars << ";\n"; + text_ << " size_type n = size();\n"; + text_ << " data_ = vector_type(n * num_fields);\n"; + text_ << " data_(memory::all) = 
std::numeric_limits<value_type>::quiet_NaN();\n"; + for(int i=0; i<num_vars; ++i) { + char namestr[128]; + sprintf(namestr, "%-15s", array_variables[i]->name().c_str()); + text_ << " " << namestr << " = data_(" << i << "*n, " << i+1 << "*n);\n"; + } + for(auto const& var : array_variables) { + double val = var->value(); + // only non-NaN fields need to be initialized, because data_ + // is NaN by default + if(val == val) { + text_ << " " << var->name() << "(memory::all) = " << val << ";\n"; + } + } + + text_ << " INIT_PROFILE\n"; + text_ << " }\n\n"; + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + + text_ << " using base::size;\n\n"; + + text_ << " std::size_t memory() const override {\n"; + text_ << " auto s = std::size_t{0};\n"; + text_ << " s += data_.size()*sizeof(value_type);\n"; + for(auto& ion: m.neuron_block().ions) { + text_ << " s += ion_" + ion.name + ".memory();\n"; + } + text_ << " return s;\n"; + text_ << " }\n\n"; + + text_ << " void set_params(value_type t_, value_type dt_) override {\n"; + text_ << " t = t_;\n"; + text_ << " dt = dt_;\n"; + text_ << " param_pack_ = param_pack_type{\n"; + //for(auto i=0; i<param_pack.size(); ++i) + for(auto &str: param_pack) { + text_ << " " << str << ",\n"; + } + text_ << " };\n"; + text_ << " }\n\n"; + + text_ << " std::string name() const override {\n"; + text_ << " return \"" << m.name() << "\";\n"; + text_ << " }\n\n"; + + std::string kind_str = m.kind() == moduleKind::density + ? 
"mechanismKind::density" + : "mechanismKind::point_process"; + text_ << " mechanismKind kind() const override {\n"; + text_ << " return " << kind_str << ";\n"; + text_ << " }\n\n"; + + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + + auto proctest = [] (procedureKind k) {return k == procedureKind::api;}; + for(auto const &var : m.symbols()) { + if( var.second->kind()==symbolKind::procedure + && proctest(var.second->is_procedure()->kind())) + { + auto proc = var.second->is_api_method(); + auto name = proc->name(); + text_ << " void " << name << "() {\n"; + text_ << " auto n = size();\n"; + text_ << " auto thread_dim = 192;\n"; + text_ << " dim3 dim_block(thread_dim);\n"; + text_ << " dim3 dim_grid(n/dim_block.x + (n%dim_block.x ? 1 : 0) );\n\n"; + text_ << " START_PROFILE\n"; + text_ << " impl::" << m.name() << "::" << name << "<T,I>" + << "<<<dim_grid, dim_block>>>(param_pack_);\n"; + text_ << " STOP_PROFILE\n"; + text_ << " }\n"; + } + } + + ////////////////////////////////////////////// + ////////////////////////////////////////////// + + //text_ << "private:\n\n"; + text_ << " vector_type data_;\n\n"; + for(auto var: array_variables) { + text_ << " view_type " << var->name() << ";\n"; + } + for(auto var: scalar_variables) { + double val = var->value(); + // test the default value for NaN + // useful for error propogation from bad initial conditions + if(val==val) { + text_ << " value_type " << var->name() << " = " << val << ";\n"; + } + else { + // the cuda compiler has a bug that doesn't allow initialization of + // class members with std::numer_limites<>. So simply set to zero. 
+ text_ << " value_type " << var->name() + << " = value_type{0};\n"; + } + } + + text_ << " using base::matrix_;\n"; + text_ << " using base::node_indices_;\n\n"; + text_ << " param_pack_type param_pack_;\n\n"; + text_ << " DATA_PROFILE\n"; + text_ << "};\n"; +} + +void CUDAPrinter::visit(Expression *e) { + throw compiler_exception( + "CUDAPrinter doesn't know how to print " + e->to_string(), + e->location()); +} + +void CUDAPrinter::visit(LocalDeclaration *e) { +} + +void CUDAPrinter::visit(NumberExpression *e) { + text_ << " " << e->value(); +} + +void CUDAPrinter::visit(IdentifierExpression *e) { + e->symbol()->accept(this); +} + +void CUDAPrinter::visit(Symbol *e) { + text_ << e->name(); +} + +void CUDAPrinter::visit(VariableExpression *e) { + text_ << "params_." << e->name(); + if(e->is_range()) { + text_ << "[" << index_string(e) << "]"; + } +} + +std::string CUDAPrinter::index_string(Symbol *s) { + if(s->is_variable()) { + return "tid_"; + } + else if(auto var = s->is_indexed_variable()) { + switch(var->ion_channel()) { + case ionKind::none : + return "gid_"; + case ionKind::Ca : + return "caid_"; + case ionKind::Na : + return "naid_"; + case ionKind::K : + return "kid_"; + // a nonspecific ion current should never be indexed: it is + // local to a mechanism + case ionKind::nonspecific: + break; + default : + throw compiler_exception( + "CUDAPrinter unknown ion type", + s->location()); + } + } + return ""; +} + +void CUDAPrinter::visit(IndexedVariable *e) { + text_ << "params_." 
<< e->index_name() << "[" << index_string(e) << "]"; +} + +void CUDAPrinter::visit(LocalVariable *e) { + std::string const& name = e->name(); + text_ << name; +} + +void CUDAPrinter::visit(UnaryExpression *e) { + auto b = (e->expression()->is_binary()!=nullptr); + switch(e->op()) { + case tok::minus : + // place a space in front of minus sign to avoid invalid + // expressions of the form : (v[i]--67) + // use parenthesis if expression is a binop, otherwise + // -(v+2) becomes -v+2 + if(b) text_ << " -("; + else text_ << " -"; + e->expression()->accept(this); + if(b) text_ << ")"; + return; + case tok::exp : + text_ << "exp("; + e->expression()->accept(this); + text_ << ")"; + return; + case tok::cos : + text_ << "cos("; + e->expression()->accept(this); + text_ << ")"; + return; + case tok::sin : + text_ << "sin("; + e->expression()->accept(this); + text_ << ")"; + return; + case tok::log : + text_ << "log("; + e->expression()->accept(this); + text_ << ")"; + return; + default : + throw compiler_exception( + "CUDAPrinter unsupported unary operator " + yellow(token_string(e->op())), + e->location()); + } +} + +void CUDAPrinter::visit(BlockExpression *e) { + // ------------- declare local variables ------------- // + // only if this is the outer block + if(!e->is_nested()) { + for(auto& var : e->scope()->locals()) { + auto sym = var.second.get(); + // input variables are declared earlier, before the + // block body is printed + if(is_stack_local(sym) && !is_input(sym)) { + text_.add_line("value_type " + var.first + ";"); + } + } + } + + // ------------- statements ------------- // + for(auto& stmt : e->statements()) { + if(stmt->is_local_declaration()) continue; + // these all must be handled + text_.add_gutter(); + stmt->accept(this); + text_.end_line(";"); + } +} + +void CUDAPrinter::visit(IfExpression *e) { + // for now we remove the brackets around the condition because + // the binary expression printer adds them, and we want to work + // around the 
-Wparentheses-equality warning + text_ << "if("; + e->condition()->accept(this); + text_ << ") {\n"; + increase_indentation(); + e->true_branch()->accept(this); + decrease_indentation(); + text_.add_gutter(); + text_ << "}"; +} + +void CUDAPrinter::print_procedure_prototype(ProcedureExpression *e) { + text_.add_gutter() << "template <typename T, typename I>\n"; + text_.add_line("__device__"); + text_.add_gutter() << "void " << e->name() + << "(" << module_->name() << "_ParamPack<T,I> const& params_," + << "const int tid_"; + for(auto& arg : e->args()) { + text_ << ", T " << arg->is_argument()->name(); + } + text_ << ")"; +} + +void CUDAPrinter::visit(ProcedureExpression *e) { + // error: semantic analysis has not been performed + if(!e->scope()) { // error: semantic analysis has not been performed + throw compiler_exception( + "CUDAPrinter attempt to print Procedure " + e->name() + + " for which semantic analysis has not been performed", + e->location()); + } + + // ------------- print prototype ------------- // + print_procedure_prototype(e); + text_.end_line(" {"); + + // ------------- print body ------------- // + increase_indentation(); + + text_.add_line("using value_type = T;"); + text_.add_line("using index_type = I;"); + text_.add_line(); + + e->body()->accept(this); + + // ------------- close up ------------- // + decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + return; +} + +void CUDAPrinter::visit(APIMethod *e) { + // ------------- print prototype ------------- // + text_.add_gutter() << "template <typename T, typename I>\n"; + text_.add_line( "__global__"); + text_.add_gutter() << "void " << e->name() + << "(" << module_->name() << "_ParamPack<T,I> params_) {"; + text_.add_line(); + + if(!e->scope()) { // error: semantic analysis has not been performed + throw compiler_exception( + "CUDAPrinter attempt to print APIMethod " + e->name() + + " for which semantic analysis has not been performed", + e->location()); + } + 
increase_indentation(); + + text_.add_line("using value_type = T;"); + text_.add_line("using index_type = I;"); + text_.add_line(); + + text_.add_line("auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;"); + text_.add_line("auto const n_ = params_.n;"); + text_.add_line(); + text_.add_line("if(tid_<n_) {"); + increase_indentation(); + + text_.add_line("auto gid_ __attribute__((unused)) = params_.ni[tid_];"); + + print_APIMethod_body(e); + + decrease_indentation(); + text_.add_line("}"); + + decrease_indentation(); + text_.add_line("}\n"); +} + +void CUDAPrinter::print_APIMethod_body(APIMethod* e) { + // load indexes of ion channels + auto uses_k = false; + auto uses_na = false; + auto uses_ca = false; + for(auto &symbol : e->scope()->locals()) { + auto ch = symbol.second->is_local_variable()->ion_channel(); + if(!uses_k && (uses_k = (ch == ionKind::K)) ) { + text_.add_line("auto kid_ = params_.ion_k_idx_[tid_];"); + } + if(!uses_ca && (uses_ca = (ch == ionKind::Ca)) ) { + text_.add_line("auto caid_ = params_.ion_ca_idx_[tid_];"); + } + if(!uses_na && (uses_na = (ch == ionKind::Na)) ) { + text_.add_line("auto naid_ = params_.ion_na_idx_[tid_];"); + } + } + + // shadows for indexed arrays + for(auto &symbol : e->scope()->locals()) { + auto var = symbol.second->is_local_variable(); + if(is_input(var)) { + auto ext = var->external_variable(); + text_.add_gutter() << "value_type "; + var->accept(this); + text_ << " = "; + ext->accept(this); + text_.end_line("; // indexed load"); + } + else if (is_output(var)) { + text_.add_gutter() << "value_type " << var->name() << ";"; + text_.end_line(); + } + } + + text_.add_line(); + text_.add_line("// the kernel computation"); + + e->body()->accept(this); + + // insert stores here + // take care to use atomic operations for the updates for point processes, where + // more than one thread may try add/subtract to the same memory location + auto has_outputs = false; + for(auto &symbol : e->scope()->locals()) { + auto in = 
symbol.second->is_local_variable(); + auto out = in->external_variable(); + if(out==nullptr || !is_output(in)) continue; + if(!has_outputs) { + text_.add_line(); + text_.add_line("// stores to indexed global memory"); + has_outputs = true; + } + text_.add_gutter(); + if(!is_point_process()) { + out->accept(this); + text_ << (out->op()==tok::plus ? " += " : " -= "); + in->accept(this); + } + else { + text_ << (out->op()==tok::plus ? "atomicAdd" : "atomicSub") << "(&"; + out->accept(this); + text_ << ", "; + in->accept(this); + text_ << ")"; + } + text_.end_line(";"); + } + + return; +} + +void CUDAPrinter::visit(CallExpression *e) { + text_ << e->name() << "<T,I>(params_, tid_"; + for(auto& arg: e->args()) { + text_ << ", "; + arg->accept(this); + } + text_ << ")"; +} + +void CUDAPrinter::visit(AssignmentExpression *e) { + e->lhs()->accept(this); + text_ << " = "; + e->rhs()->accept(this); +} + +void CUDAPrinter::visit(PowBinaryExpression *e) { + text_ << "std::pow("; + e->lhs()->accept(this); + text_ << ", "; + e->rhs()->accept(this); + text_ << ")"; +} + +void CUDAPrinter::visit(BinaryExpression *e) { + auto pop = parent_op_; + // TODO unit tests for parenthesis and binops + bool use_brackets = + Lexer::binop_precedence(pop) > Lexer::binop_precedence(e->op()) + || (pop==tok::divide && e->op()==tok::times); + parent_op_ = e->op(); + + + auto lhs = e->lhs(); + auto rhs = e->rhs(); + if(use_brackets) { + text_ << "("; + } + lhs->accept(this); + switch(e->op()) { + case tok::minus : + text_ << "-"; + break; + case tok::plus : + text_ << "+"; + break; + case tok::times : + text_ << "*"; + break; + case tok::divide : + text_ << "/"; + break; + case tok::lt : + text_ << "<"; + break; + case tok::lte : + text_ << "<="; + break; + case tok::gt : + text_ << ">"; + break; + case tok::gte : + text_ << ">="; + break; + case tok::equality : + text_ << "=="; + break; + default : + throw compiler_exception( + "CUDAPrinter unsupported binary operator " + 
yellow(token_string(e->op())), + e->location()); + } + rhs->accept(this); + if(use_brackets) { + text_ << ")"; + } + + // reset parent precedence + parent_op_ = pop; +} + diff --git a/modcc/src/cudaprinter.hpp b/modcc/src/cudaprinter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b2cb00480f90078b068e48474f941fb566b24dc --- /dev/null +++ b/modcc/src/cudaprinter.hpp @@ -0,0 +1,109 @@ +#pragma once + +#include <sstream> + +#include "module.hpp" +#include "textbuffer.hpp" +#include "visitor.hpp" + +class CUDAPrinter : public Visitor { +public: + CUDAPrinter() {} + CUDAPrinter(Module &m, bool o=false); + + void visit(Expression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; + void visit(AssignmentExpression *e) override; + void visit(PowBinaryExpression *e) override; + void visit(NumberExpression *e) override; + void visit(VariableExpression *e) override; + + void visit(Symbol *e) override; + void visit(LocalVariable *e) override; + void visit(IndexedVariable *e) override; + + void visit(IdentifierExpression *e) override; + void visit(CallExpression *e) override; + void visit(ProcedureExpression *e) override; + void visit(APIMethod *e) override; + void visit(LocalDeclaration *e) override; + void visit(BlockExpression *e) override; + void visit(IfExpression *e) override; + + std::string text() const { + return text_.str(); + } + + void set_gutter(int w) { + text_.set_gutter(w); + } + void increase_indentation(){ + text_.increase_indentation(); + } + void decrease_indentation(){ + text_.decrease_indentation(); + } +private: + + bool is_input(Symbol *s) { + if(auto l = s->is_local_variable() ) { + if(l->is_local()) { + if(l->is_indexed() && l->is_read()) { + return true; + } + } + } + return false; + } + + bool is_output(Symbol *s) { + if(auto l = s->is_local_variable() ) { + if(l->is_local()) { + if(l->is_indexed() && l->is_write()) { + return true; + } + } + } + return false; + } + + bool 
is_indexed_local(Symbol *s) { + if(auto l=s->is_local_variable()) { + if(l->is_indexed()) { + return true; + } + } + return false; + } + + bool is_arg_local(Symbol *s) { + if(auto l=s->is_local_variable()) { + if(l->is_arg()) { + return true; + } + } + return false; + } + + bool is_stack_local(Symbol *s) { + if(is_arg_local(s)) return false; + if(is_input(s)) return false; + if(is_output(s)) return false; + return true; + } + + bool is_point_process() const { + return module_->kind() == moduleKind::point; + } + + void print_APIMethod_body(APIMethod* e); + void print_procedure_prototype(ProcedureExpression *e); + std::string index_string(Symbol *e); + + Module *module_ = nullptr; + tok parent_op_ = tok::eq; + TextBuffer text_; + //bool optimize_ = false; +}; + diff --git a/modcc/src/error.hpp b/modcc/src/error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f113eff06de12c34c2f766e7e1b53e8db69e9101 --- /dev/null +++ b/modcc/src/error.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include "location.hpp" + +class compiler_exception : public std::exception { +public: + compiler_exception(std::string m, Location location) + : location_(location), + message_(std::move(m)) + {} + + virtual const char* what() const throw() { + return message_.c_str(); + } + + Location const& location() const { + return location_; + } + +private: + + Location location_; + std::string message_; +}; + diff --git a/modcc/src/errorvisitor.cpp b/modcc/src/errorvisitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..225ec97d1f79972b32746f82606e5279f3112ff2 --- /dev/null +++ b/modcc/src/errorvisitor.cpp @@ -0,0 +1,78 @@ +#include "errorvisitor.hpp" + +/* + * we use a post order walk to print the erros in an expression after those + * in all of its children + */ + +void ErrorVisitor::visit(Expression *e) { + print_error(e); +} + +// traverse the statements in a procedure +void ErrorVisitor::visit(ProcedureExpression *e) { + for(auto& expression : e->args()) 
{ + expression->accept(this); + } + + e->body()->accept(this); + print_error(e); +} + +// traverse the statements in a function +void ErrorVisitor::visit(FunctionExpression *e) { + for(auto& expression : e->args()) { + expression->accept(this); + } + + e->body()->accept(this); + print_error(e); +} + +// an if statement +void ErrorVisitor::visit(IfExpression *e) { + e->true_branch()->accept(this); + if(e->false_branch()) { + e->false_branch()->accept(this); + } + + print_error(e); +} + +void ErrorVisitor::visit(BlockExpression* e) { + for(auto& expression : e->statements()) { + expression->accept(this); + } + + print_error(e); +} + +void ErrorVisitor::visit(InitialBlock* e) { + for(auto& expression : e->statements()) { + expression->accept(this); + } + + print_error(e); +} + +// unary expresssion +void ErrorVisitor::visit(UnaryExpression *e) { + e->expression()->accept(this); + print_error(e); +} + +// binary expresssion +void ErrorVisitor::visit(BinaryExpression *e) { + e->lhs()->accept(this); + e->rhs()->accept(this); + print_error(e); +} + +// binary expresssion +void ErrorVisitor::visit(CallExpression *e) { + for(auto& expression: e->args()) { + expression->accept(this); + } + print_error(e); +} + diff --git a/modcc/src/errorvisitor.hpp b/modcc/src/errorvisitor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..39afad2be75f998ce855a31cf3bb27147c694332 --- /dev/null +++ b/modcc/src/errorvisitor.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include <iostream> +#include "visitor.hpp" +#include "expression.hpp" + +class ErrorVisitor : public Visitor { +public: + ErrorVisitor(std::string const& m) + : module_name_(m) + {} + + void visit(Expression *e) override; + void visit(ProcedureExpression *e) override; + void visit(FunctionExpression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; + void visit(CallExpression *e) override; + + void visit(BlockExpression *e) override; + void visit(InitialBlock *e) 
override; + void visit(IfExpression *e) override; + + int num_errors() {return num_errors_;} + int num_warnings() {return num_warnings_;} +private: + template <typename ExpressionType> + void print_error(ExpressionType *e) { + if(e->has_error()) { + auto header = red("error: ") + + white(pprintf("% % ", module_name_, e->location())); + std::cout << header << "\n " + << e->error_message() + << std::endl; + num_errors_++; + } + if(e->has_warning()) { + auto header = purple("warning: ") + + white(pprintf("% % ", module_name_, e->location())); + std::cout << header << "\n " + << e->warning_message() + << std::endl; + num_warnings_++; + } + } + + std::string module_name_; + int num_errors_ = 0; + int num_warnings_ = 0; +}; + diff --git a/modcc/src/expression.cpp b/modcc/src/expression.cpp new file mode 100644 index 0000000000000000000000000000000000000000..96fc173d7576084f1bdb1547d7c8b4e78f94d9bb --- /dev/null +++ b/modcc/src/expression.cpp @@ -0,0 +1,891 @@ +#include <cstring> + +#include "expression.hpp" + +inline std::string to_string(symbolKind k) { + switch (k) { + case symbolKind::variable: + return std::string("variable"); + case symbolKind::indexed_variable: + return std::string("indexed variable"); + case symbolKind::local_variable: + return std::string("local"); + case symbolKind::procedure: + return std::string("procedure"); + case symbolKind::function: + return std::string("function"); + } + return ""; +} + + + +inline std::string to_string(procedureKind k) { + switch(k) { + case procedureKind::normal : + return "procedure"; + case procedureKind::api : + return "APIprocedure"; + case procedureKind::initial : + return "initial"; + case procedureKind::net_receive : + return "net_receive"; + case procedureKind::breakpoint : + return "breakpoint"; + case procedureKind::derivative : + return "derivative"; + default : + return "undefined"; + } +} + +/******************************************************************************* + Expression 
+*******************************************************************************/ + +void Expression::semantic(std::shared_ptr<scope_type>) { + error("semantic() has not been implemented for this expression"); +} + +expression_ptr Expression::clone() const { + throw compiler_exception( + "clone() has not been implemented for " + this->to_string(), + location_); +} + +/******************************************************************************* + Symbol +*******************************************************************************/ + +std::string Symbol::to_string() const { + return blue("Symbol") + " " + yellow(name_); +} + +/******************************************************************************* + LocalVariable +*******************************************************************************/ + +std::string LocalVariable::to_string() const { + std::string s = blue("Local Variable") + " " + yellow(name()); + if(is_indexed()) { + s += " ->(" + token_string(external_->op()) + ") " + yellow(external_->index_name()); + } + return s; +} + +/******************************************************************************* + IdentifierExpression +*******************************************************************************/ + +void IdentifierExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + + auto s = scope_->find(spelling_); + + if(s==nullptr) { + error( pprintf("the variable '%' is undefined", + yellow(spelling_), location_)); + return; + } + if(s->kind() == symbolKind::procedure || s->kind() == symbolKind::function) { + error( pprintf("the symbol '%' is a function/procedure, not a variable", + yellow(spelling_))); + return; + } + // if the symbol is an indexed variable, this is the first time that the + // indexed variable is used in this procedure. 
In which case, we create + // a local variable which refers to the indexed variable, which will be + // found for any subsequent variable lookup inside the procedure + if(auto sym = s->is_indexed_variable()) { + auto var = new LocalVariable(location_, spelling_); + var->external_variable(sym); + s = scope_->add_local_symbol(spelling_, scope_type::symbol_ptr{var}); + } + + // save the symbol + symbol_ = s; +} + +expression_ptr IdentifierExpression::clone() const { + return make_expression<IdentifierExpression>(location_, spelling_); +} + +bool IdentifierExpression::is_lvalue() { + // check for global variable that is writeable + auto var = symbol_->is_variable(); + if(var) return var->is_writeable(); + + // else look for local symbol + if( symbol_->kind() == symbolKind::local_variable ) { + return true; + } + + return false; +} + +/******************************************************************************* + NumberExpression +********************************************************************************/ + +expression_ptr NumberExpression::clone() const { + return make_expression<NumberExpression>(location_, value_); +} + +/******************************************************************************* + LocalDeclaration +*******************************************************************************/ + +std::string LocalDeclaration::to_string() const { + std::string str = blue("local"); + for(auto v : vars_) { + str += " " + yellow(v.first); + } + return str; +} + +expression_ptr LocalDeclaration::clone() const { + auto local = new LocalDeclaration(location()); + for(auto &v : vars_) { + local->add_variable(v.second); + } + return expression_ptr{local}; +} + +bool LocalDeclaration::add_variable(Token tok) { + if(vars_.find(tok.spelling)!=vars_.end()) { + error( "the variable '" + yellow(tok.spelling) + "' is defined more than once"); + return false; + } + + vars_[tok.spelling] = tok; + return true; +} + +void 
LocalDeclaration::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + + // loop over the variables declared in this LOCAL statement + for(auto &v : vars_) { + auto &name = v.first; + auto s = scope_->find(name); + + // First check that the variable is undefined + // Note that we allow for local variables with the same name as + // class scope variables (globals), in which case the local variable + // name will be used for lookup + if( s==nullptr // symbol has not been defined yet + || s->kind()==symbolKind::variable // symbol is defined at global scope + || s->kind()==symbolKind::indexed_variable) + { + if(s && s->kind()==symbolKind::indexed_variable) { + warning(pprintf("The local variable '%' clashes with the indexed" + " variable defined at %, which will be ignored." + " Remove the local definition of this variable" + " if the previously defined variable was intended.", + yellow(name), s->location() )); + } else { + auto symbol = make_symbol<LocalVariable>(location_, name); + symbols_.push_back( scope_->add_local_symbol(name, std::move(symbol)) ); + } + } + else { + error(pprintf("the symbol '%' has already been defined at %", + yellow(name), s->location() )); + } + } +} + +/******************************************************************************* + ArgumentExpression +*******************************************************************************/ +std::string ArgumentExpression::to_string() const { + return blue("arg") + " " + yellow(name_); +} + +void ArgumentExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + + auto s = scope_->find(name_); + + if(s==nullptr || s->kind()==symbolKind::variable || s->kind()==symbolKind::indexed_variable) { + auto symbol = make_symbol<LocalVariable>( location_, name_, localVariableKind::argument ); + scope_->add_local_symbol(name_, std::move(symbol)); + } + else { + error(pprintf("the symbol '%' has already been defined at %", + yellow(name_), s->location() )); + } +} + + 
+/******************************************************************************* + VariableExpression +*******************************************************************************/ + +std::string VariableExpression::to_string() const { + char n[17]; + snprintf(n, 17, "%-10s", name().c_str()); + std::string + s = blue("variable") + " " + yellow(n) + "(" + + colorize("write", is_writeable() ? stringColor::green : stringColor::red) + ", " + + colorize("read", is_readable() ? stringColor::green : stringColor::red) + ", " + + (is_range() ? "range" : "scalar") + ", " + + "ion" + colorize(::to_string(ion_channel()), + (ion_channel()==ionKind::none) ? stringColor::red : stringColor::green) + ", " + + "vis " + ::to_string(visibility()) + ", " + + "link " + ::to_string(linkage()) + ", " + + colorize("state", is_state() ? stringColor::green : stringColor::red) + ")"; + return s; +} + +/******************************************************************************* + IndexedVariable +*******************************************************************************/ + +std::string IndexedVariable::to_string() const { + auto ch = ::to_string(ion_channel()); + return + blue("indexed") + " " + yellow(name()) + "->" + yellow(index_name()) + "(" + + (is_write() ? " write-only" : " read-only") + + ", ion" + (ion_channel()==ionKind::none ? 
red(ch) : green(ch)) + ") "; +} + +/******************************************************************************* + CallExpression +*******************************************************************************/ + +std::string CallExpression::to_string() const { + std::string str = blue("call") + " " + yellow(spelling_) + " ("; + for(auto& arg : args_) + str += arg->to_string() + ", "; + str += ")"; + + return str; +} + +void CallExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + + // look up to see if symbol is defined + // restrict search to global namespace + auto s = scope_->find_global(spelling_); + + // either undefined or refers to a variable + if(!s) { + error(pprintf("there is no function or procedure named '%' ", + yellow(spelling_))); + } + if(s->kind()==symbolKind::local_variable || s->kind()==symbolKind::variable) { + error(pprintf("the symbol '%' refers to a variable, but it is being" + " called like a function", yellow(spelling_) )); + } + + // save the symbol + symbol_ = s; + + // check that the number of passed arguments matches + if( !has_error() ) { // only analyze if the call was found + int expected_args; + if(auto f = function()) { + expected_args = f->args().size(); + } + else { + expected_args = procedure()->args().size(); + } + if(args_.size() != unsigned(expected_args)) { + error(pprintf("call has the wrong number of arguments: expected %" + ", received %", expected_args, args_.size())); + } + } + + // perform semantic analysis on the arguments + for(auto& a : args_) { + a->semantic(scp); + } +} + +expression_ptr CallExpression::clone() const { + // clone the arguments + std::vector<expression_ptr> cloned_args; + for(auto& a: args_) { + cloned_args.emplace_back(a->clone()); + } + + return make_expression<CallExpression>(location_, spelling_, std::move(cloned_args)); +} + +/******************************************************************************* + ProcedureExpression 
+*******************************************************************************/ + +std::string ProcedureExpression::to_string() const { + std::string str = blue("procedure") + " " + yellow(name()) + "\n"; + str += blue(" special") + " : " + ::to_string(kind_) + "\n"; + str += blue(" args") + " : "; + for(auto& arg : args_) + str += arg->to_string() + " "; + str += "\n "+blue("body")+" :"; + str += body_->to_string(); + + return str; +} + +void ProcedureExpression::semantic(scope_type::symbol_map &global_symbols) { + // assert that the symbol is already visible in the global_symbols + if(global_symbols.find(name()) == global_symbols.end()) { + throw compiler_exception( + "attempt to perform semantic analysis for procedure '" + + yellow(name()) + + "' which has not been added to global symbol table", + location_); + } + + // create the scope for this procedure + scope_ = std::make_shared<scope_type>(global_symbols); + + // add the argumemts to the list of local variables + for(auto& a : args_) { + a->semantic(scope_); + } + + // this loop could be used to then check the types of statements in the body + for(auto& e : *(body_->is_block())) { + if(e->is_initial_block()) + error("INITIAL block not allowed inside "+::to_string(kind_)+" definition"); + } + + // perform semantic analysis for each expression in the body + body_->semantic(scope_); + + // the symbol for this expression is itself + symbol_ = scope_->find_global(name()); +} + +/******************************************************************************* + APIMethod +*******************************************************************************/ + +std::string APIMethod::to_string() const { + auto namestr = [] (Symbol* e) -> std::string { + return yellow(e->name()); + return ""; + }; + std::string str = blue("API method") + " " + yellow(name()) + "\n"; + + str += blue(" locals") + " : "; + for(auto& var : scope_->locals()) { + str += namestr(var.second.get()); + str += ", "; + } + str += "\n"; + + str += " 
"+blue("body ")+" : "; + str += body_->to_string(); + + return str; +} + +/******************************************************************************* + InitialBlock +*******************************************************************************/ + +std::string InitialBlock::to_string() const { + std::string str = green("[[initial"); + for(auto& ex : statements_) { + str += "\n " + ex->to_string(); + } + str += green("\n ]]"); + return str; +} + +/******************************************************************************* + NetReceiveExpression +*******************************************************************************/ + +void NetReceiveExpression::semantic(scope_type::symbol_map &global_symbols) { + // assert that the symbol is already visible in the global_symbols + if(global_symbols.find(name()) == global_symbols.end()) { + throw compiler_exception( + "attempt to perform semantic analysis for procedure '" + + yellow(name()) + + "' which has not been added to global symbol table", + location_); + } + + // create the scope for this procedure + scope_ = std::make_shared<scope_type>(global_symbols); + + // add the argumemts to the list of local variables + for(auto& a : args_) { + a->semantic(scope_); + } + + // perform semantic analysis for each expression in the body + body_->semantic(scope_); + // this loop could be used to then check the types of statements in the body + for(auto& e : *(body_->is_block())) { + if(e->is_initial_block()) { + if(initial_block_) { + error("only one INITIAL block is permitted per NET_RECEIVE block"); + } + initial_block_ = e->is_initial_block(); + } + } + + // the symbol for this expression is itself + // this could lead to nasty self-referencing loops + symbol_ = scope_->find_global(name()); +} + +/******************************************************************************* + FunctionExpression +*******************************************************************************/ + +std::string 
FunctionExpression::to_string() const { + std::string str = blue("function") + " " + yellow(name()) + "\n"; + str += blue(" args") + " : "; + for(auto& arg : args_) + str += arg->to_string() + " "; + str += "\n "+blue("body")+" :"; + str += body_->to_string(); + + return str; +} + +void FunctionExpression::semantic(scope_type::symbol_map &global_symbols) { + // assert that the symbol is already visible in the global_symbols + if(global_symbols.find(name()) == global_symbols.end()) { + throw compiler_exception( + "attempt to perform semantic analysis for procedure '" + + yellow(name()) + + "' which has not been added to global symbol table", + location_); + } + + // create the scope for this procedure + scope_ = std::make_shared<scope_type>(global_symbols); + + // add the argumemts to the list of local variables + for(auto& a : args_) { + a->semantic(scope_); + } + + // Add a variable that has the same name as the function, + // which acts as a placeholder for the return value + // Make its location correspond to that of the first line of the function, + // for want of a better location + auto return_var = scope_type::symbol_ptr( + new Symbol(body_->location(), name(), symbolKind::local_variable) + ); + scope_->add_local_symbol(name(), std::move(return_var)); + + // perform semantic analysis for each expression in the body + body_->semantic(scope_); + // this loop could be used to then check the types of statements in the body + for(auto& e : *(body())) { + if(e->is_initial_block()) error("INITIAL block not allowed inside FUNCTION definition"); + } + + // check that the last expression in the body was an assignment to + // the return placeholder + bool last_expr_is_assign = false; + auto tail = body()->back()->is_assignment(); + if(tail) { + // we know that the tail is an assignment expression + auto lhs = tail->lhs()->is_identifier(); + // use nullptr check followed by lazy name lookup + if(lhs && lhs->name()==name()) { + last_expr_is_assign = true; + } + } + 
if(!last_expr_is_assign) { + warning("the last expression in function '" + + yellow(name()) + + "' does not set the return value"); + } + + // the symbol for this expression is itself + // this could lead to nasty self-referencing loops + symbol_ = scope_->find_global(name()); +} + +/******************************************************************************* + UnaryExpression +*******************************************************************************/ +void UnaryExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + + expression_->semantic(scp); + + if(expression_->is_procedure_call()) { + error("a procedure call can't be part of an expression"); + } +} + +void UnaryExpression::replace_expression(expression_ptr&& other) { + std::swap(expression_, other); +} + +expression_ptr UnaryExpression::clone() const { + return unary_expression(location_, op_, expression_->clone()); +} + +/******************************************************************************* + BinaryExpression +*******************************************************************************/ +void BinaryExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + lhs_->semantic(scp); + rhs_->semantic(scp); + + if(rhs_->is_procedure_call() || lhs_->is_procedure_call()) { + error("procedure calls can't be made in an expression"); + } +} + +expression_ptr BinaryExpression::clone() const { + return binary_expression(location_, op_, lhs_->clone(), rhs_->clone()); +} + +void BinaryExpression::replace_lhs(expression_ptr&& other) { + std::swap(lhs_, other); +} + +void BinaryExpression::replace_rhs(expression_ptr&& other) { + std::swap(rhs_, other); +} + +std::string BinaryExpression::to_string() const { + //return pprintf("(% % %)", blue(token_string(op_)), lhs_->to_string(), rhs_->to_string()); + return pprintf("(% % %)", lhs_->to_string(), blue(token_string(op_)), rhs_->to_string()); +} + 
+/******************************************************************************* + AssignmentExpression +*******************************************************************************/ + +void AssignmentExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + lhs_->semantic(scp); + rhs_->semantic(scp); + + // only flag an lvalue error if there was no error in the lhs expression + // this ensures that we don't print redundant error messages when trying + // to write to an undeclared variable + if(!lhs_->has_error() && !lhs_->is_lvalue()) { + error("the left hand side of an assignment must be an lvalue"); + } + if(rhs_->is_procedure_call()) { + error("procedure calls can't be made in an expression"); + } +} + +/******************************************************************************* + SolveExpression +*******************************************************************************/ + +void SolveExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + + auto e = scp->find(name()); + auto proc = e ? e->is_procedure() : nullptr; + + // this is optimistic: it simply looks for a procedure, + // it should also evaluate the procedure to see whether it contains the derivatives + // if an integration method has been specified (i.e. 
cnexp) + if(proc) { + procedure_ = proc; + } + else { + error( "'" + yellow(name_) + "' is not a valid procedure name" + " for computing the derivatives in a SOLVE statement"); + } +} + +expression_ptr SolveExpression::clone() const { + auto s = new SolveExpression(location_, name_, method_); + s->procedure(procedure_); + return expression_ptr{s}; +} + +/******************************************************************************* + ConductanceExpression +*******************************************************************************/ + +void ConductanceExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + // For now do nothing with the CONDUCTANCE statement, because it is not needed + // to optimize conductance calculation. + // Semantic analysis would involve + // - check that name identifies a valid variable + // - check that the ion channel is an ion channel for which current + // is to be updated + /* + auto e = scp->find(name()); + auto var = e ? e->is_variable() : nullptr; + */ +} + +expression_ptr ConductanceExpression::clone() const { + auto s = new ConductanceExpression(location_, name_, ion_channel_); + //s->procedure(procedure_); + return expression_ptr{s}; +} + +/******************************************************************************* + BlockExpression +*******************************************************************************/ + +std::string BlockExpression::to_string() const { + std::string str; + for(auto& ex : statements_) { + str += "\n " + ex->to_string(); + } + return str; +} + +void BlockExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + for(auto& e : statements_) { + e->semantic(scope_); + } +} + +expression_ptr BlockExpression::clone() const { + std::list<expression_ptr> statements; + for(auto& e: statements_) { + statements.emplace_back(e->clone()); + } + return make_expression<BlockExpression>(location_, std::move(statements), is_nested_); +} + 
+/******************************************************************************* + IfExpression +*******************************************************************************/ + +std::string IfExpression::to_string() const { + std::string s = blue("if") + " :"; + s += "\n " + white("condition") +" " + condition_->to_string(); + s += "\n " + white("true branch ") + true_branch_->to_string(); + if(false_branch_) { + s += "\n " + white("false branch "); + s += false_branch_->to_string(); + } + s += "\n"; + return s; +} + +void IfExpression::semantic(std::shared_ptr<scope_type> scp) { + scope_ = scp; + + condition_->semantic(scp); + + auto cond = condition_->is_conditional(); + if(!cond) { + error("not a valid conditional expression"); + } + + true_branch_->semantic(scp); + + if(false_branch_) { + false_branch_->semantic(scp); + } +} + +expression_ptr IfExpression::clone() const { + return make_expression<IfExpression>( + location_, + condition_->clone(), + true_branch_->clone(), + false_branch_? 
false_branch_->clone() : nullptr + ); +} + +#include "visitor.hpp" +/* + Visitor hooks +*/ +void Expression::accept(Visitor *v) { + v->visit(this); +} +void Symbol::accept(Visitor *v) { + v->visit(this); +} +void LocalVariable::accept(Visitor *v) { + v->visit(this); +} +void IdentifierExpression::accept(Visitor *v) { + v->visit(this); +} +void BlockExpression::accept(Visitor *v) { + v->visit(this); +} +void InitialBlock::accept(Visitor *v) { + v->visit(this); +} +void IfExpression::accept(Visitor *v) { + v->visit(this); +} +void SolveExpression::accept(Visitor *v) { + v->visit(this); +} +void ConductanceExpression::accept(Visitor *v) { + v->visit(this); +} +void DerivativeExpression::accept(Visitor *v) { + v->visit(this); +} +void VariableExpression::accept(Visitor *v) { + v->visit(this); +} +void IndexedVariable::accept(Visitor *v) { + v->visit(this); +} +void NumberExpression::accept(Visitor *v) { + v->visit(this); +} +void LocalDeclaration::accept(Visitor *v) { + v->visit(this); +} +void ArgumentExpression::accept(Visitor *v) { + v->visit(this); +} +void PrototypeExpression::accept(Visitor *v) { + v->visit(this); +} +void CallExpression::accept(Visitor *v) { + v->visit(this); +} +void ProcedureExpression::accept(Visitor *v) { + v->visit(this); +} +void NetReceiveExpression::accept(Visitor *v) { + v->visit(this); +} +void APIMethod::accept(Visitor *v) { + v->visit(this); +} +void FunctionExpression::accept(Visitor *v) { + v->visit(this); +} +void UnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void NegUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void ExpUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void LogUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void CosUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void SinUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void BinaryExpression::accept(Visitor *v) { + v->visit(this); +} +void AssignmentExpression::accept(Visitor *v) { + 
v->visit(this); +} +void AddBinaryExpression::accept(Visitor *v) { + v->visit(this); +} +void SubBinaryExpression::accept(Visitor *v) { + v->visit(this); +} +void MulBinaryExpression::accept(Visitor *v) { + v->visit(this); +} +void DivBinaryExpression::accept(Visitor *v) { + v->visit(this); +} +void PowBinaryExpression::accept(Visitor *v) { + v->visit(this); +} +void ConditionalExpression::accept(Visitor *v) { + v->visit(this); +} + +expression_ptr unary_expression( Location loc, + tok op, + expression_ptr&& e + ) +{ + switch(op) { + case tok::minus : + return make_expression<NegUnaryExpression>(loc, std::move(e)); + case tok::exp : + return make_expression<ExpUnaryExpression>(loc, std::move(e)); + case tok::cos : + return make_expression<CosUnaryExpression>(loc, std::move(e)); + case tok::sin : + return make_expression<SinUnaryExpression>(loc, std::move(e)); + case tok::log : + return make_expression<LogUnaryExpression>(loc, std::move(e)); + default : + std::cerr << yellow(token_string(op)) + << " is not a valid unary operator" + << std::endl;; + return nullptr; + } + return nullptr; +} + +expression_ptr binary_expression( tok op, + expression_ptr&& lhs, + expression_ptr&& rhs + ) +{ + return binary_expression(Location(), op, std::move(lhs), std::move(rhs)); +} + +expression_ptr binary_expression(Location loc, + tok op, + expression_ptr&& lhs, + expression_ptr&& rhs + ) +{ + switch(op) { + case tok::eq : + return make_expression<AssignmentExpression>( + loc, std::move(lhs), std::move(rhs) + ); + case tok::plus : + return make_expression<AddBinaryExpression>( + loc, std::move(lhs), std::move(rhs) + ); + case tok::minus : + return make_expression<SubBinaryExpression>( + loc, std::move(lhs), std::move(rhs) + ); + case tok::times : + return make_expression<MulBinaryExpression>( + loc, std::move(lhs), std::move(rhs) + ); + case tok::divide : + return make_expression<DivBinaryExpression>( + loc, std::move(lhs), std::move(rhs) + ); + case tok::pow : + return 
make_expression<PowBinaryExpression>( + loc, std::move(lhs), std::move(rhs) + ); + case tok::lt : + case tok::lte : + case tok::gt : + case tok::gte : + case tok::equality : + return make_expression<ConditionalExpression>(loc, op, std::move(lhs), std::move(rhs)); + default : + std::cerr << yellow(token_string(op)) + << " is not a valid binary operator" + << std::endl; + return nullptr; + } + return nullptr; +} diff --git a/modcc/src/expression.hpp b/modcc/src/expression.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b37f18552ada9e6c887354a090cee941b6eab01a --- /dev/null +++ b/modcc/src/expression.hpp @@ -0,0 +1,1158 @@ +#pragma once + +#include <iostream> +#include <limits> +#include <list> +#include <memory> +#include <string> +#include <vector> + +#include "error.hpp" +#include "identifier.hpp" +#include "memop.hpp" +#include "scope.hpp" +#include "util.hpp" + +class Visitor; + +class Expression; +class IdentifierExpression; +class BlockExpression; +class InitialBlock; +class IfExpression; +class VariableExpression; +class IndexedVariable; +class NumberExpression; +class LocalDeclaration; +class ArgumentExpression; +class DerivativeExpression; +class PrototypeExpression; +class CallExpression; +class ProcedureExpression; +class NetReceiveExpression; +class APIMethod; +class FunctionExpression; +class UnaryExpression; +class NegUnaryExpression; +class ExpUnaryExpression; +class LogUnaryExpression; +class CosUnaryExpression; +class SinUnaryExpression; +class BinaryExpression; +class AssignmentExpression; +class AddBinaryExpression; +class SubBinaryExpression; +class MulBinaryExpression; +class DivBinaryExpression; +class PowBinaryExpression; +class ConditionalExpression; +class SolveExpression; +class ConductanceExpression; +class Symbol; +class LocalVariable; + +using expression_ptr = std::unique_ptr<Expression>; +using symbol_ptr = std::unique_ptr<Symbol>; + +template <typename T, typename... 
Args> +expression_ptr make_expression(Args&&... args) { + return expression_ptr(new T(std::forward<Args>(args)...)); +} + +template <typename T, typename... Args> +symbol_ptr make_symbol(Args&&... args) { + return symbol_ptr(new T(std::forward<Args>(args)...)); +} + +// helper functions for generating unary and binary expressions +expression_ptr unary_expression(Location, tok, expression_ptr&&); +expression_ptr unary_expression(tok, expression_ptr&&); +expression_ptr binary_expression(Location, tok, expression_ptr&&, expression_ptr&&); +expression_ptr binary_expression(tok, expression_ptr&&, expression_ptr&&); + +/// specifies special properties of a ProcedureExpression +enum class procedureKind { + normal, ///< PROCEDURE + api, ///< API PROCEDURE + initial, ///< INITIAL + net_receive, ///< NET_RECEIVE + breakpoint, ///< BREAKPOINT + derivative ///< DERIVATIVE +}; +std::string to_string(procedureKind k); + +/// classification of different symbol kinds +enum class symbolKind { + function, ///< function call + procedure, ///< procedure call + variable, ///< variable at module scope + indexed_variable, ///< a variable that is indexed + local_variable, ///< variable at local scope +}; +std::string to_string(symbolKind k); + +/// methods for time stepping state +enum class solverMethod { + cnexp, // the only method we have at the moment + none +}; + +static std::string to_string(solverMethod m) { + switch(m) { + case solverMethod::cnexp : return std::string("cnexp"); + case solverMethod::none : return std::string("none"); + } + return std::string("<error : undefined solverMethod>"); +} + +class Expression { +public: + using scope_type = Scope<Symbol>; + + explicit Expression(Location location) + : location_(location) + {} + + virtual ~Expression() {}; + + // This printer should be implemented with a visitor pattern + // expressions must provide a method for stringification + virtual std::string to_string() const = 0; + + Location const& location() const {return 
location_;}; + + std::shared_ptr<scope_type> scope() {return scope_;}; + + void error(std::string const& str) { + error_ = true; + error_string_ += str; + } + void warning(std::string const& str) { + warning_ = true; + warning_string_ += str; + } + bool has_error() { return error_; } + bool has_warning() { return warning_; } + std::string const& error_message() const { return error_string_; } + std::string const& warning_message() const { return warning_string_; } + + // perform semantic analysis + virtual void semantic(std::shared_ptr<scope_type>); + virtual void semantic(scope_type::symbol_map&) { + throw compiler_exception("unable to perform semantic analysis for " + this->to_string(), location_); + }; + + // clone an expression + virtual expression_ptr clone() const; + + // easy lookup of properties + virtual CallExpression* is_function_call() {return nullptr;} + virtual CallExpression* is_procedure_call() {return nullptr;} + virtual BlockExpression* is_block() {return nullptr;} + virtual IfExpression* is_if() {return nullptr;} + virtual LocalDeclaration* is_local_declaration() {return nullptr;} + virtual ArgumentExpression* is_argument() {return nullptr;} + virtual FunctionExpression* is_function() {return nullptr;} + virtual DerivativeExpression* is_derivative() {return nullptr;} + virtual PrototypeExpression* is_prototype() {return nullptr;} + virtual IdentifierExpression* is_identifier() {return nullptr;} + virtual NumberExpression* is_number() {return nullptr;} + virtual BinaryExpression* is_binary() {return nullptr;} + virtual UnaryExpression* is_unary() {return nullptr;} + virtual AssignmentExpression* is_assignment() {return nullptr;} + virtual ConditionalExpression* is_conditional() {return nullptr;} + virtual InitialBlock* is_initial_block() {return nullptr;} + virtual SolveExpression* is_solve_statement() {return nullptr;} + virtual Symbol* is_symbol() {return nullptr;} + virtual ConductanceExpression* is_conductance_statement() {return nullptr;} + + 
virtual bool is_lvalue() {return false;} + + // force all derived classes to implement visitor + // this might be a bad idea + virtual void accept(Visitor *v) = 0; + +protected: + // these are used to flag errors when performing semantic analysis + // we might want to extend these to an additional "contaminated" flag + // which marks whether an error was found in a subnode of a node. + bool error_=false; + bool warning_=false; + std::string error_string_; + std::string warning_string_; + + Location location_; + + std::shared_ptr<scope_type> scope_; +}; + +class Symbol : public Expression { +public : + Symbol(Location loc, std::string name, symbolKind kind) + : Expression(std::move(loc)), + name_(std::move(name)), + kind_(kind) + {} + + std::string const& name() const { + return name_; + } + + symbolKind kind() const { + return kind_; + } + + Symbol* is_symbol() override { + return this; + } + + std::string to_string() const override; + void accept(Visitor *v) override; + + virtual VariableExpression* is_variable() {return nullptr;} + virtual ProcedureExpression* is_procedure() {return nullptr;} + virtual NetReceiveExpression* is_net_receive() {return nullptr;} + virtual APIMethod* is_api_method() {return nullptr;} + virtual IndexedVariable* is_indexed_variable() {return nullptr;} + virtual LocalVariable* is_local_variable() {return nullptr;} + +private : + std::string name_; + + symbolKind kind_; +}; + +enum class localVariableKind { + local, argument +}; + +// an identifier +class IdentifierExpression : public Expression { +public: + IdentifierExpression(Location loc, std::string const& spelling) + : Expression(loc), spelling_(spelling) + {} + + IdentifierExpression(IdentifierExpression const& other) + : Expression(other.location()), spelling_(other.spelling()) + {} + + IdentifierExpression(IdentifierExpression const* other) + : Expression(other->location()), spelling_(other->spelling()) + {} + + std::string const& spelling() const { + return spelling_; + } + + 
std::string to_string() const override { + return yellow(pprintf("%", spelling_)); + } + + expression_ptr clone() const override; + + void semantic(std::shared_ptr<scope_type> scp) override; + + Symbol* symbol() { return symbol_; }; + + void accept(Visitor *v) override; + + IdentifierExpression* is_identifier() override {return this;} + + bool is_lvalue() override; + + ~IdentifierExpression() {} + + std::string const& name() const { + if(symbol_) return symbol_->name(); + throw compiler_exception( + " attempt to look up name of identifier for which no symbol_ yet defined", + location_); + } + +protected: + Symbol* symbol_ = nullptr; + + // there has to be some pointer to a table of identifiers + std::string spelling_; +}; + +// an identifier for a derivative +class DerivativeExpression : public IdentifierExpression { +public: + DerivativeExpression(Location loc, std::string const& name) + : IdentifierExpression(loc, name) + {} + + std::string to_string() const override { + return blue("diff") + "(" + yellow(spelling()) + ")"; + } + DerivativeExpression* is_derivative() override { return this; } + + ~DerivativeExpression() {} + + void accept(Visitor *v) override; +}; + +// a number, stored as a long double +class NumberExpression : public Expression { +public: + NumberExpression(Location loc, std::string const& value) + : Expression(loc), value_(std::stod(value)) /* NOTE(review): std::stod yields a double although value_ is long double; std::stold would parse at long double precision - confirm intent */ + {} + + NumberExpression(Location loc, long double value) + : Expression(loc), value_(value) + {} + + long double value() const {return value_;}; + + std::string to_string() const override { + return purple(pprintf("%", value_)); + } + + // do nothing for number semantic analysis + void semantic(std::shared_ptr<scope_type> scp) override {}; + expression_ptr clone() const override; + + NumberExpression* is_number() override {return this;} + + ~NumberExpression() {} + + void accept(Visitor *v) override; +private: + long double value_; +}; + +// declaration of a LOCAL variable +class LocalDeclaration : public Expression { +public: + 
LocalDeclaration(Location loc) + : Expression(loc) + {} + LocalDeclaration(Location loc, std::string const& name) + : Expression(loc) + { + Token tok(tok::identifier, name, loc); + add_variable(tok); + } + + std::string to_string() const override; + + bool add_variable(Token name); + LocalDeclaration* is_local_declaration() override {return this;} + void semantic(std::shared_ptr<scope_type> scp) override; + std::vector<Symbol*>& symbols() {return symbols_;} + std::map<std::string, Token>& variables() {return vars_;} + expression_ptr clone() const override; + ~LocalDeclaration() {} + void accept(Visitor *v) override; +private: + std::vector<Symbol*> symbols_; + // there has to be some pointer to a table of identifiers + std::map<std::string, Token> vars_; +}; + +// declaration of an argument +class ArgumentExpression : public Expression { +public: + ArgumentExpression(Location loc, Token const& tok) + : Expression(loc), + token_(tok), + name_(tok.spelling) + {} + + std::string to_string() const override; + + bool add_variable(Token name); + ArgumentExpression* is_argument() override {return this;} + void semantic(std::shared_ptr<scope_type> scp) override; + Token token() {return token_;} + std::string const& name() {return name_;} + void set_name(std::string const& n) { + name_ = n; + } + const std::string& spelling() const { + return token_.spelling; + } + + ~ArgumentExpression() {} + void accept(Visitor *v) override; +private: + Token token_; + std::string name_; +}; + +// variable definition +class VariableExpression : public Symbol { +public: + VariableExpression(Location loc, std::string name) + : Symbol(loc, std::move(name), symbolKind::variable) + {} + + std::string to_string() const override; + + void access(accessKind a) { + access_ = a; + } + void visibility(visibilityKind v) { + visibility_ = v; + } + void linkage(linkageKind l) { + linkage_ = l; + } + void range(rangeKind r) { + range_kind_ = r; + } + void ion_channel(ionKind i) { + ion_channel_ = i; + } 
+ void state(bool s) { + is_state_ = s; + } + + accessKind access() const { + return access_; + } + visibilityKind visibility() const { + return visibility_; + } + linkageKind linkage() const { + return linkage_; + } + ionKind ion_channel() const { + return ion_channel_; + } + + bool is_ion() const {return ion_channel_ != ionKind::none;} + bool is_state() const {return is_state_;} + bool is_range() const {return range_kind_ == rangeKind::range;} + bool is_scalar() const {return !is_range();} + + bool is_readable() const {return access_==accessKind::read + || access_==accessKind::readwrite;} + + bool is_writeable() const {return access_==accessKind::write + || access_==accessKind::readwrite;} + + double value() const {return value_;} + void value(double v) {value_ = v;} + + void accept(Visitor *v) override; + VariableExpression* is_variable() override {return this;} + + ~VariableExpression() {} +protected: + + bool is_state_ = false; + accessKind access_ = accessKind::readwrite; + visibilityKind visibility_ = visibilityKind::local; + linkageKind linkage_ = linkageKind::external; + rangeKind range_kind_ = rangeKind::range; + ionKind ion_channel_ = ionKind::none; + double value_ = std::numeric_limits<double>::quiet_NaN(); +}; + +// an indexed variable; the constructor throws compiler_exception when the access kind and the update operator are inconsistent +class IndexedVariable : public Symbol { +public: + IndexedVariable(Location loc, + std::string lookup_name, + std::string index_name, + accessKind acc, + tok o=tok::eq, + ionKind channel=ionKind::none) + : Symbol(loc, std::move(lookup_name), symbolKind::indexed_variable), + access_(acc), + ion_channel_(channel), + index_name_(index_name), + op_(o) + { + std::string msg; /* lookup_name was moved into the Symbol base above, so the messages below use the stored name() */ + // external symbols are either read or write only + if(access()==accessKind::readwrite) { + msg = pprintf("attempt to generate an index % with readwrite access", + yellow(name())); + goto compiler_error; + } + // read only variables must be assigned via equality + if(is_read() && op()!=tok::eq) { + msg = pprintf("read only indexes % must use 
assignment", + yellow(name())); + goto compiler_error; + } + // write only variables must be updated via addition/subtraction + if(is_write() && (op()!=tok::plus && op()!=tok::minus)) { + msg = pprintf("write only index % must use addition or subtraction", + yellow(name())); + goto compiler_error; + } + + return; + +compiler_error: + throw(compiler_exception(msg, location_)); + } + + std::string to_string() const override; + + accessKind access() const { + return access_; + } + ionKind ion_channel() const { + return ion_channel_; + } + tok op() const { + return op_; + } + + std::string const& index_name() const { + return index_name_; + } + + bool is_ion() const {return ion_channel_ != ionKind::none;} + bool is_read() const {return access_ == accessKind::read; } + bool is_write() const {return access_ == accessKind::write; } + + void accept(Visitor *v) override; + IndexedVariable* is_indexed_variable() override {return this;} + + ~IndexedVariable() {} +protected: + accessKind access_; + ionKind ion_channel_; + std::string index_name_; + tok op_; +}; + +class LocalVariable : public Symbol { +public : + LocalVariable(Location loc, + std::string name, + localVariableKind kind=localVariableKind::local) + : Symbol(std::move(loc), std::move(name), symbolKind::local_variable), + kind_(kind) + {} + + LocalVariable* is_local_variable() override { + return this; + } + + localVariableKind kind() const { + return kind_; + } + + bool is_indexed() const { + return external_!=nullptr && ion_channel()!=ionKind::nonspecific; + } + + ionKind ion_channel() const { + if(external_) return external_->ion_channel(); + return ionKind::none; + } + + bool is_read() const { + if(is_indexed()) return external_->is_read(); + return true; + } + + bool is_write() const { + if(is_indexed()) return external_->is_write(); + return true; + } + + bool is_local() const { + return kind_==localVariableKind::local; + } + + bool is_arg() const { + return kind_==localVariableKind::argument; + } + 
+ IndexedVariable* external_variable() { + return external_; + } + + void external_variable(IndexedVariable *i) { + external_ = i; + } + + std::string to_string() const override; + void accept(Visitor *v) override; + +private : + IndexedVariable *external_=nullptr; + localVariableKind kind_; +}; + + +// a SOLVE statement +class SolveExpression : public Expression { +public: + SolveExpression( + Location loc, + std::string name, + solverMethod method) + : Expression(loc), name_(std::move(name)), method_(method), procedure_(nullptr) + {} + + std::string to_string() const override { + return blue("solve") + "(" + yellow(name_) + ", " + + green(::to_string(method_)) + ")"; + } + + std::string const& name() const { + return name_; + } + + solverMethod method() const { + return method_; + } + + ProcedureExpression* procedure() const { + return procedure_; + } + + void procedure(ProcedureExpression *e) { + procedure_ = e; + } + + SolveExpression* is_solve_statement() override { + return this; + } + + expression_ptr clone() const override; + + void semantic(std::shared_ptr<scope_type> scp) override; + void accept(Visitor *v) override; + + ~SolveExpression() {} +private: + /// name passed to the constructor (a std::string, not a symbol pointer) and the solver method; procedure_ stays null until set through procedure(ProcedureExpression*) + std::string name_; + solverMethod method_; + + ProcedureExpression* procedure_; +}; + +// a CONDUCTANCE statement +class ConductanceExpression : public Expression { +public: + ConductanceExpression( + Location loc, + std::string name, + ionKind channel) + : Expression(loc), name_(std::move(name)), ion_channel_(channel) + {} + + std::string to_string() const override { + return blue("conductance") + "(" + yellow(name_) + ", " + + green(::to_string(ion_channel_)) + ")"; + } + + std::string const& name() const { + return name_; + } + + ionKind ion_channel() const { + return ion_channel_; + } + + ConductanceExpression* is_conductance_statement() override { + return this; + } + + expression_ptr clone() const override; + + void 
semantic(std::shared_ptr<scope_type> scp) override; + void accept(Visitor *v) override; + + ~ConductanceExpression() {} +private: + /// name and ion channel passed to the constructor; both are printed by to_string() + std::string name_; + ionKind ion_channel_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// recursive if statement +// requires a BlockExpression that is a simple wrapper around a std::list +// of Expressions surrounded by {} +//////////////////////////////////////////////////////////////////////////////// + +class BlockExpression : public Expression { +protected: + std::list<expression_ptr> statements_; + bool is_nested_ = false; + +public: + BlockExpression( + Location loc, + std::list<expression_ptr>&& statements, + bool is_nested) + : Expression(loc), + statements_(std::move(statements)), + is_nested_(is_nested) + {} + + BlockExpression* is_block() override { + return this; + } + + std::list<expression_ptr>& statements() { + return statements_; + } + + expression_ptr clone() const override; + + // provide iterators for easy iteration over statements + auto begin() -> decltype(statements_.begin()) { + return statements_.begin(); + } + auto end() -> decltype(statements_.end()) { + return statements_.end(); + } + auto back() -> decltype(statements_.back()) { + return statements_.back(); + } + auto front() -> decltype(statements_.front()) { + return statements_.front(); + } + bool is_nested() const { + return is_nested_; + } + + void semantic(std::shared_ptr<scope_type> scp) override; + void accept(Visitor* v) override; + + std::string to_string() const override; +}; + +class IfExpression : public Expression { +public: + IfExpression(Location loc, expression_ptr&& con, expression_ptr&& tb, expression_ptr&& fb) + : Expression(loc), condition_(std::move(con)), true_branch_(std::move(tb)), false_branch_(std::move(fb)) + {} + + IfExpression* is_if() override { + return this; + } + Expression* condition() { + return 
condition_.get(); + } + Expression* true_branch() { + return true_branch_.get(); + } + Expression* false_branch() { + return false_branch_.get(); + } + + expression_ptr clone() const override; + + std::string to_string() const override; + void semantic(std::shared_ptr<scope_type> scp) override; + + void accept(Visitor* v) override; +private: + expression_ptr condition_; + expression_ptr true_branch_; + expression_ptr false_branch_; +}; + +// a procedure prototype: a name plus its argument expressions +class PrototypeExpression : public Expression { +public: + PrototypeExpression( + Location loc, + std::string const& name, + std::vector<expression_ptr>&& args) + : Expression(loc), name_(name), args_(std::move(args)) + {} + + std::string const& name() const {return name_;} + + std::vector<expression_ptr>& args() {return args_;} + std::vector<expression_ptr>const& args() const {return args_;} + PrototypeExpression* is_prototype() override {return this;} + + // TODO: printing out the vector of unique pointers is an unsolved problem... 
+ std::string to_string() const override { + return name_; //+ pprintf("(% args : %)", args_.size(), args_); + } + + ~PrototypeExpression() {} + + void accept(Visitor *v) override; +private: + std::string name_; + std::vector<expression_ptr> args_; +}; + +// marks a call site in the AST +// is used to mark both function and procedure calls; symbol_ starts out null, and while it is null is_function_call(), is_procedure_call(), function() and procedure() all return nullptr +class CallExpression : public Expression { +public: + CallExpression(Location loc, std::string spelling, std::vector<expression_ptr>&& args) + : Expression(loc), spelling_(std::move(spelling)), args_(std::move(args)) + {} + + std::vector<expression_ptr>& args() { return args_; } + std::vector<expression_ptr> const& args() const { return args_; } + + std::string& name() { return spelling_; } + std::string const& name() const { return spelling_; } + + void semantic(std::shared_ptr<scope_type> scp) override; + expression_ptr clone() const override; + + std::string to_string() const override; + + void accept(Visitor *v) override; + + CallExpression* is_function_call() override { + return (symbol_ && symbol_->kind() == symbolKind::function) ? this : nullptr; + } + CallExpression* is_procedure_call() override { + return (symbol_ && symbol_->kind() == symbolKind::procedure) ? this : nullptr; + } + + FunctionExpression* function() { + return (symbol_ && symbol_->kind() == symbolKind::function) + ? symbol_->is_function() : nullptr; + } + + ProcedureExpression* procedure() { + return (symbol_ && symbol_->kind() == symbolKind::procedure) + ? 
symbol_->is_procedure() : nullptr; + } + +private: + Symbol* symbol_ = nullptr; /* null until a symbol is attached; the queries above guard against that */ + + std::string spelling_; + std::vector<expression_ptr> args_; +}; + +class ProcedureExpression : public Symbol { +public: + ProcedureExpression( Location loc, + std::string name, + std::vector<expression_ptr>&& args, + expression_ptr&& body, + procedureKind k=procedureKind::normal) + : Symbol(loc, std::move(name), symbolKind::procedure), args_(std::move(args)), kind_(k) + { + if(!body->is_block()) { + throw compiler_exception( + " attempt to initialize ProcedureExpression with non-block expression, i.e.\n" + + body->to_string(), + location_); + } + body_ = std::move(body); + } + + std::vector<expression_ptr>& args() { + return args_; + } + BlockExpression* body() { + return body_.get()->is_block(); + } + + void semantic(scope_type::symbol_map &scp) override; + ProcedureExpression* is_procedure() override {return this;} + std::string to_string() const override; + void accept(Visitor *v) override; + + /// can be used to determine whether the procedure has been lowered + /// from a special block, e.g. 
BREAKPOINT, INITIAL, NET_RECEIVE, etc + procedureKind kind() const {return kind_;} + +protected: + Symbol* symbol_; /* NOTE(review): symbol_ has no initializer here and the ProcedureExpression constructor does not set it - confirm it is assigned before any read */ + + std::vector<expression_ptr> args_; + expression_ptr body_; + procedureKind kind_ = procedureKind::normal; +}; + +class APIMethod : public ProcedureExpression { +public: + using memop_type = MemOp<Symbol>; + + APIMethod( Location loc, + std::string name, + std::vector<expression_ptr>&& args, + expression_ptr&& body) + : ProcedureExpression(loc, std::move(name), std::move(args), std::move(body), procedureKind::api) + {} + + APIMethod* is_api_method() override {return this;} + void accept(Visitor *v) override; + + std::string to_string() const override; +}; + +/// stores the INITIAL block in a NET_RECEIVE block, if there is one +/// should not be used anywhere but NET_RECEIVE +class InitialBlock : public BlockExpression { +public: + InitialBlock( + Location loc, + std::list<expression_ptr>&& statements) + : BlockExpression(loc, std::move(statements), true) + {} + + std::string to_string() const override; + + // currently we use the semantic for a BlockExpression + // this could be overridden to check for statements + // specific to initialization of a net_receive block + //void semantic() override; + + void accept(Visitor *v) override; + InitialBlock* is_initial_block() override {return this;} +}; + +/// handle NetReceiveExpressions as a special case of ProcedureExpression +class NetReceiveExpression : public ProcedureExpression { +public: + NetReceiveExpression( Location loc, + std::string name, + std::vector<expression_ptr>&& args, + expression_ptr&& body) + : ProcedureExpression(loc, std::move(name), std::move(args), std::move(body), procedureKind::net_receive) + {} + + void semantic(scope_type::symbol_map &scp) override; + NetReceiveExpression* is_net_receive() override {return this;} + /// hard code the kind. NOTE(review): this overload is non-const and ProcedureExpression::kind() const is not virtual, so it hides the base accessor instead of overriding it; calls through a const object or a ProcedureExpression pointer use the base version, which returns kind_ (set to net_receive by the constructor above) + procedureKind kind() {return procedureKind::net_receive;} + + void accept(Visitor *v) override; + InitialBlock* initial_block() {return 
initial_block_;} +protected: + InitialBlock* initial_block_ = nullptr; +}; + +class FunctionExpression : public Symbol { +public: + FunctionExpression( Location loc, + std::string name, + std::vector<expression_ptr>&& args, + expression_ptr&& body) + : Symbol(loc, std::move(name), symbolKind::function), + args_(std::move(args)) + { + if(!body->is_block()) { + throw compiler_exception( + " attempt to initialize FunctionExpression with non-block expression, i.e.\n" + + body->to_string(), + location_); + } + body_ = std::move(body); + } + + std::vector<expression_ptr>& args() { + return args_; + } + BlockExpression* body() { + return body_->is_block(); + } + + FunctionExpression* is_function() override {return this;} + void semantic(scope_type::symbol_map&) override; + std::string to_string() const override; + void accept(Visitor *v) override; + +private: + Symbol* symbol_; + + std::vector<expression_ptr> args_; + expression_ptr body_; +}; + +//////////////////////////////////////////////////////////// +// unary expressions +//////////////////////////////////////////////////////////// + +/// Unary expression +class UnaryExpression : public Expression { +protected: + expression_ptr expression_; + tok op_; +public: + UnaryExpression(Location loc, tok op, expression_ptr&& e) + : Expression(loc), + expression_(std::move(e)), + op_(op) + {} + + std::string to_string() const override { + return pprintf("(% %)", green(token_string(op_)), expression_->to_string()); + } + + expression_ptr clone() const override; + + tok op() const {return op_;} + UnaryExpression* is_unary() override {return this;}; + Expression* expression() {return expression_.get();} + const Expression* expression() const {return expression_.get();} + void semantic(std::shared_ptr<scope_type> scp) override; + void accept(Visitor *v) override; + void replace_expression(expression_ptr&& other); +}; + +/// negation unary expression, i.e. 
-x +class NegUnaryExpression : public UnaryExpression { +public: + NegUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::minus, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +/// exponential unary expression, i.e. e^x or exp(x) +class ExpUnaryExpression : public UnaryExpression { +public: + ExpUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::exp, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +// logarithm unary expression, i.e. log_10(x) +class LogUnaryExpression : public UnaryExpression { +public: + LogUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::log, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +// cosine unary expression, i.e. cos(x) +class CosUnaryExpression : public UnaryExpression { +public: + CosUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::cos, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +// sin unary expression, i.e. sin(x) +class SinUnaryExpression : public UnaryExpression { +public: + SinUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::sin, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +//////////////////////////////////////////////////////////// +// binary expressions + +//////////////////////////////////////////////////////////// + +/// binary expression base class +/// never used directly in the AST, instead the specializations that derive from +/// it are inserted into the AST. 
+class BinaryExpression : public Expression { +protected: + expression_ptr lhs_; + expression_ptr rhs_; + tok op_; +public: + BinaryExpression(Location loc, tok op, expression_ptr&& lhs, expression_ptr&& rhs) + : Expression(loc), + lhs_(std::move(lhs)), + rhs_(std::move(rhs)), + op_(op) + {} + + tok op() const {return op_;} + Expression* lhs() {return lhs_.get();} + Expression* rhs() {return rhs_.get();} + const Expression* lhs() const {return lhs_.get();} + const Expression* rhs() const {return rhs_.get();} + BinaryExpression* is_binary() override {return this;} + void semantic(std::shared_ptr<scope_type> scp) override; + expression_ptr clone() const override; + void replace_rhs(expression_ptr&& other); + void replace_lhs(expression_ptr&& other); + std::string to_string() const override; + void accept(Visitor *v) override; +}; + +class AssignmentExpression : public BinaryExpression { +public: + AssignmentExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, tok::eq, std::move(lhs), std::move(rhs)) + {} + + AssignmentExpression* is_assignment() override {return this;} + + void semantic(std::shared_ptr<scope_type> scp) override; + + void accept(Visitor *v) override; +}; + +class AddBinaryExpression : public BinaryExpression { +public: + AddBinaryExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, tok::plus, std::move(lhs), std::move(rhs)) + {} + + void accept(Visitor *v) override; +}; + +class SubBinaryExpression : public BinaryExpression { +public: + SubBinaryExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, tok::minus, std::move(lhs), std::move(rhs)) + {} + + void accept(Visitor *v) override; +}; + +class MulBinaryExpression : public BinaryExpression { +public: + MulBinaryExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, tok::times, std::move(lhs), std::move(rhs)) + {} + + void accept(Visitor 
*v) override; +}; + +class DivBinaryExpression : public BinaryExpression { +public: + DivBinaryExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, tok::divide, std::move(lhs), std::move(rhs)) + {} + + void accept(Visitor *v) override; +}; + +class PowBinaryExpression : public BinaryExpression { +public: + PowBinaryExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, tok::pow, std::move(lhs), std::move(rhs)) + {} + + void accept(Visitor *v) override; +}; + +class ConditionalExpression : public BinaryExpression { +public: + ConditionalExpression(Location loc, tok op, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, op, std::move(lhs), std::move(rhs)) + {} + + ConditionalExpression* is_conditional() override {return this;} + + void accept(Visitor *v) override; +}; + diff --git a/modcc/src/expressionclassifier.cpp b/modcc/src/expressionclassifier.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83532a4710fb906173e2301c7210d011cb396c63 --- /dev/null +++ b/modcc/src/expressionclassifier.cpp @@ -0,0 +1,323 @@ +#include <iostream> +#include <cmath> + +#include "error.hpp" +#include "expressionclassifier.hpp" +#include "util.hpp" + +// this turns out to be quite easy, however quite fiddly to do right. 
+ +// default: no linear analysis is defined for a generic Expression, so throw +void ExpressionClassifierVisitor::visit(Expression *e) { + throw compiler_exception(" attempting to apply linear analysis on " + e->to_string(), e->location()); +} + +// number expression +void ExpressionClassifierVisitor::visit(NumberExpression *e) { + // save the coefficient as the number + coefficient_ = e->clone(); +} + +// identifier expression +void ExpressionClassifierVisitor::visit(IdentifierExpression *e) { + // check if symbol of identifier matches the identifier + if(symbol_ == e->symbol()) { + found_symbol_ = true; + coefficient_.reset(new NumberExpression(Location(), "1")); + } + else { + coefficient_ = e->clone(); + } +} + +/// unary expression +void ExpressionClassifierVisitor::visit(UnaryExpression *e) { + e->expression()->accept(this); + if(found_symbol_) { + switch(e->op()) { + // plus or minus don't change linearity; NOTE(review): for minus only coefficient_ is negated below, constant_ is left untouched - confirm a constant term under unary minus is handled elsewhere + case tok::minus : + coefficient_ = unary_expression(Location(), + e->op(), + std::move(coefficient_)); + return; + case tok::plus : + return; + // one of these applied to the symbol certainly isn't linear + case tok::exp : + case tok::cos : + case tok::sin : + case tok::log : + is_linear_ = false; + return; + default : + throw compiler_exception( + "attempting to apply linear analysis on unsuported UnaryExpression " + + yellow(token_string(e->op())), e->location()); + } + } + else { + coefficient_ = e->clone(); + } +} + +// binary expression +// handle all binary expressions with one routine, because the +// pre-order and in-order code is the same for all cases +void ExpressionClassifierVisitor::visit(BinaryExpression *e) { + bool lhs_contains_symbol = false; + bool rhs_contains_symbol = false; + expression_ptr lhs_coefficient; + expression_ptr rhs_coefficient; + expression_ptr lhs_constant; + expression_ptr rhs_constant; + + // check the lhs + reset(); + e->lhs()->accept(this); + lhs_contains_symbol = found_symbol_; + lhs_coefficient = std::move(coefficient_); + lhs_constant = 
std::move(constant_); + if(!is_linear_) return; // early return if nonlinear + + // check the rhs + reset(); + e->rhs()->accept(this); + rhs_contains_symbol = found_symbol_; + rhs_coefficient = std::move(coefficient_); + rhs_constant = std::move(constant_); + if(!is_linear_) return; // early return if nonlinear + + // mark symbol as found if in either lhs or rhs + found_symbol_ = rhs_contains_symbol || lhs_contains_symbol; + + if( found_symbol_ ) { + // if both lhs and rhs contain symbol check that the binary operator + // preserves linearity + // note that we don't have to test for linearity, because we abort early + // if either lhs or rhs are nonlinear + if( rhs_contains_symbol && lhs_contains_symbol ) { + // be careful to get the order of operation right for + // non-computative operators + switch(e->op()) { + // addition and subtraction are valid, nothing else is + case tok::plus : + case tok::minus : + coefficient_ = + binary_expression(Location(), + e->op(), + std::move(lhs_coefficient), + std::move(rhs_coefficient)); + return; + // multiplying two expressions that depend on symbol is nonlinear + case tok::times : + case tok::pow : + case tok::divide : + default : + is_linear_ = false; + return; + } + } + // special cases : + // operator | invalid symbol location + // ------------------------------------- + // pow | lhs OR rhs + // comparisons | lhs OR rhs + // division | rhs + //////////////////////////////////////////////////////////////////////// + // only RHS contains the symbol + //////////////////////////////////////////////////////////////////////// + else if(rhs_contains_symbol) { + switch(e->op()) { + case tok::times : + // determine the linear coefficient + if( rhs_coefficient->is_number() && + rhs_coefficient->is_number()->value()==1) { + coefficient_ = lhs_coefficient->clone(); + } + else { + coefficient_ = + binary_expression(Location(), + tok::times, + lhs_coefficient->clone(), + rhs_coefficient->clone()); + } + // determine the constant + 
if(rhs_constant) { + constant_ = + binary_expression(Location(), + tok::times, + std::move(lhs_coefficient), + std::move(rhs_constant)); + } else { + constant_ = nullptr; + } + return; + case tok::plus : + // constant term + if(lhs_constant && rhs_constant) { + constant_ = + binary_expression(Location(), + tok::plus, + std::move(lhs_constant), + std::move(rhs_constant)); + } + else if(rhs_constant) { + constant_ = binary_expression(Location(), + tok::plus, + std::move(lhs_coefficient), + std::move(rhs_constant)); + } + else { + constant_ = std::move(lhs_coefficient); + } + // coefficient + coefficient_ = std::move(rhs_coefficient); + return; + case tok::minus : + // constant term + if(lhs_constant && rhs_constant) { + constant_ = binary_expression(Location(), + tok::minus, + std::move(lhs_constant), + std::move(rhs_constant)); + } + else if(rhs_constant) { + constant_ = binary_expression(Location(), + tok::minus, + std::move(lhs_coefficient), + std::move(rhs_constant)); + } + else { + constant_ = std::move(lhs_coefficient); + } + // coefficient + coefficient_ = unary_expression(Location(), + e->op(), + std::move(rhs_coefficient)); + return; + case tok::pow : + case tok::divide : + case tok::lt : + case tok::lte : + case tok::gt : + case tok::gte : + case tok::equality : + is_linear_ = false; + return; + default: + return; + } + } + //////////////////////////////////////////////////////////////////////// + // only LHS contains the symbol + //////////////////////////////////////////////////////////////////////// + else if(lhs_contains_symbol) { + switch(e->op()) { + case tok::times : + // check if the lhs is == 1 + if( lhs_coefficient->is_number() && + lhs_coefficient->is_number()->value()==1) { + coefficient_ = rhs_coefficient->clone(); + } + else { + coefficient_ = + binary_expression(Location(), + tok::times, + std::move(lhs_coefficient), + std::move(rhs_coefficient)); + } + // constant term + if(lhs_constant) { + constant_ = binary_expression(Location(), + 
tok::times, + std::move(lhs_constant), + std::move(rhs_coefficient)); + } else { + constant_ = nullptr; + } + return; + case tok::plus : + coefficient_ = std::move(lhs_coefficient); + // constant term + if(lhs_constant && rhs_constant) { + constant_ = binary_expression(Location(), + tok::plus, + std::move(lhs_constant), + std::move(rhs_constant)); + } + else if(lhs_constant) { + constant_ = binary_expression(Location(), + tok::plus, + std::move(lhs_constant), + std::move(rhs_coefficient)); + } + else { + constant_ = std::move(rhs_coefficient); + } + return; + case tok::minus : + coefficient_ = std::move(lhs_coefficient); + // constant term + if(lhs_constant && rhs_constant) { + constant_ = binary_expression(Location(), + tok::minus, + std::move(lhs_constant), + std::move(rhs_constant)); + } + else if(lhs_constant) { + constant_ = binary_expression(Location(), + tok::minus, + std::move(lhs_constant), + std::move(rhs_coefficient)); + } + else { + constant_ = unary_expression(Location(), + tok::minus, + std::move(rhs_coefficient)); + } + return; + case tok::divide: + coefficient_ = binary_expression(Location(), + tok::divide, + std::move(lhs_coefficient), + rhs_coefficient->clone()); + if(lhs_constant) { + constant_ = binary_expression(Location(), + tok::divide, + std::move(lhs_constant), + std::move(rhs_coefficient)); + } + return; + case tok::pow : + case tok::lt : + case tok::lte : + case tok::gt : + case tok::gte : + case tok::equality : + is_linear_ = false; + return; + default: + return; + } + } + } + // neither lhs or rhs contains symbol + // continue building the coefficient + else { + coefficient_ = e->clone(); + } +} + +void ExpressionClassifierVisitor::visit(CallExpression *e) { + for(auto& a : e->args()) { + a->accept(this); + // we assume that the parameter passed into a function + // won't be linear + if(found_symbol_) { + is_linear_ = false; + return; + } + } +} + diff --git a/modcc/src/expressionclassifier.hpp b/modcc/src/expressionclassifier.hpp new 
file mode 100644 index 0000000000000000000000000000000000000000..505bdd2b30fb4d018001e144a7dcf117e92c76db --- /dev/null +++ b/modcc/src/expressionclassifier.hpp @@ -0,0 +1,121 @@ +#pragma once + +#include <mutex> + +#include "constantfolder.hpp" +#include "scope.hpp" +#include "visitor.hpp" + +enum class expressionClassification { + constant, + linear, + nonlinear +}; + +class ExpressionClassifierVisitor : public Visitor { +public: + ExpressionClassifierVisitor(Symbol *s) + : symbol_(s) + { + const_folder_ = new ConstantFolderVisitor(); + } + + void reset(Symbol* s) { + reset(); + symbol_ = s; + } + + void reset() { + is_linear_ = true; + found_symbol_ = false; + configured_ = false; + coefficient_ = nullptr; + constant_ = nullptr; + } + + void visit(Expression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; + void visit(NumberExpression *e) override; + void visit(IdentifierExpression *e) override; + void visit(CallExpression *e) override; + + expressionClassification classify() const { + if(!found_symbol_) { + return expressionClassification::constant; + } + if(is_linear_) { + return expressionClassification::linear; + } + return expressionClassification::nonlinear; + } + + Expression *linear_coefficient() { + set(); + return coefficient_.get(); + } + + Expression *constant_term() { + set(); + return constant_.get(); + } + + ~ExpressionClassifierVisitor() { + delete const_folder_; + } + +private: + + void set() const { + // a mutex is required because two threads might attempt to update + // the cached constant_/coefficient_ values, which would violate the + // condition that set() is const + std::lock_guard<std::mutex> g(mutex_); + + // update the constant_ and coefficient_ terms if they have not already + // been set + if(!configured_) { + if(classify() == expressionClassification::linear) { + // if constat_ was never set, it must be zero + if(!constant_) { + constant_ = + 
make_expression<NumberExpression>(Location(), 0.); + } + // perform constant folding on the coefficient term + coefficient_->accept(const_folder_); + if(const_folder_->is_number) { + // if the folding resulted in a constant, reset coefficient + // to be a NumberExpression + coefficient_.reset(new NumberExpression( + Location(), + const_folder_->value) + ); + } + } + else if(classify() == expressionClassification::constant) { + coefficient_.reset(new NumberExpression( + Location(), + 0.) + ); + } + else { // nonlinear expression + coefficient_ = nullptr; + constant_ = nullptr; + } + configured_ = true; + } + } + + // assume linear until otherwise proven + bool is_linear_ = true; + bool found_symbol_ = false; + mutable bool configured_ = false; + mutable expression_ptr coefficient_; + mutable expression_ptr constant_; + Symbol* symbol_; + ConstantFolderVisitor* const_folder_; + + mutable std::mutex mutex_; + +}; + diff --git a/modcc/src/functionexpander.cpp b/modcc/src/functionexpander.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2e727fdb33ecb6bc59c73eddf36ba4353263c22 --- /dev/null +++ b/modcc/src/functionexpander.cpp @@ -0,0 +1,165 @@ +#include <iostream> + +#include "error.hpp" +#include "functionexpander.hpp" +#include "util.hpp" + +/////////////////////////////////////////////////////////////////////////////// +// function call site lowering +/////////////////////////////////////////////////////////////////////////////// + +call_list_type lower_function_calls(Expression* e) +{ + auto v = make_unique<FunctionCallLowerer>(e->scope()); + + if(auto a=e->is_assignment()) { +#ifdef LOGGING + std::cout << "lower_function_calls inspect expression " << e->to_string() << "\n"; +#endif + // recursively inspect and replace function calls with identifiers + a->rhs()->accept(v.get()); + + } + + // return the list of statements that assign function call return values + // to identifiers, e.g. 
+ // LOCAL ll1_ + // ll1_ = mInf(v) + return v->move_calls(); +} + +void FunctionCallLowerer::visit(Expression *e) { + throw compiler_exception( + "function lowering for expressions of the type " + e->to_string() + + " has not been defined", e->location() + ); +} + +void FunctionCallLowerer::visit(CallExpression *e) { + for(auto& arg : e->args()) { + if(auto func = arg->is_function_call()) { + func->accept(this); +#ifdef LOGGING + std::cout << " lowering : " << func->to_string() << "\n"; +#endif + expand_call( + func, [&arg](expression_ptr&& p){arg = std::move(p);} + ); + arg->semantic(scope_); + } + else { + arg->accept(this); + } + } +} + +void FunctionCallLowerer::visit(UnaryExpression *e) { + if(auto func = e->expression()->is_function_call()) { + func->accept(this); +#ifdef LOGGING + std::cout << " lowering : " << func->to_string() << "\n"; +#endif + expand_call(func, [&e](expression_ptr&& p){e->replace_expression(std::move(p));}); + e->semantic(scope_); + } + else { + e->expression()->accept(this); + } +} + +void FunctionCallLowerer::visit(BinaryExpression *e) { + if(auto func = e->lhs()->is_function_call()) { + func->accept(this); +#ifdef LOGGING + std::cout << " lowering : " << func->to_string() << "\n"; +#endif + expand_call(func, [&e](expression_ptr&& p){e->replace_lhs(std::move(p));}); + e->semantic(scope_); + } + else { + e->lhs()->accept(this); + } + + if(auto func = e->rhs()->is_function_call()) { + func->accept(this); +#ifdef LOGGING + std::cout << " lowering : " << func->to_string() << "\n"; +#endif + expand_call(func, [&e](expression_ptr&& p){e->replace_rhs(std::move(p));}); + e->semantic(scope_); + } + else { + e->rhs()->accept(this); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// function argument lowering +/////////////////////////////////////////////////////////////////////////////// +Symbol* make_unique_local(std::shared_ptr<Scope<Symbol>> scope) { + std::string name; + auto i = 0; + do { + name = 
pprintf("ll%_", i); + ++i; + } while(scope->find(name)); + + return + scope->add_local_symbol( + name, + make_symbol<LocalVariable>( + Location(), name, localVariableKind::local + ) + ); +} + +call_list_type +lower_function_arguments(std::vector<expression_ptr>& args) +{ + call_list_type new_statements; + for(auto it=args.begin(); it!=args.end(); ++it) { + // get reference to the unique_ptr with the expression + auto& e = *it; +#ifdef LOGGING + std::cout << "inspecting argument @ " << e->location() << " : " << e->to_string() << std::endl; +#endif + + if(e->is_number() || e->is_identifier()) { + // do nothing, because identifiers and literals are in the correct form + // for lowering + continue; + } + + // use the source location of the original statement + auto loc = e->location(); + + // make an identifier for the new symbol which will store the result of + // the function call + auto id = make_expression<IdentifierExpression> + (loc, make_unique_local(e->scope())->name()); + id->semantic(e->scope()); + + // generate a LOCAL declaration for the variable + new_statements.push_front( + make_expression<LocalDeclaration>(loc, id->is_identifier()->spelling()) + ); + + // make a binary expression which assigns the argument to the variable + auto ass = binary_expression(loc, tok::eq, id->clone(), e->clone()); + ass->semantic(e->scope()); +#ifdef LOGGING + std::cout << " lowering to " << ass->to_string() << "\n"; +#endif + new_statements.push_back(std::move(ass)); + + // replace the function call in the original expression with the local + // variable which holds the pre-computed value + std::swap(e, id); + } +#ifdef LOGGING + std::cout << "\n"; +#endif + + return new_statements; +} + diff --git a/modcc/src/functionexpander.hpp b/modcc/src/functionexpander.hpp new file mode 100644 index 0000000000000000000000000000000000000000..38c799de24506732e1399ed64c2cb585a85defca --- /dev/null +++ b/modcc/src/functionexpander.hpp @@ -0,0 +1,130 @@ +#pragma once + +#include <sstream> 
+ +#include "scope.hpp" +#include "visitor.hpp" + +// storage for a list of expressions +using call_list_type = std::list<expression_ptr>; + +// prototype for lowering function calls +call_list_type lower_function_calls(Expression* e); + +/////////////////////////////////////////////////////////////////////////////// +// visitor that takes function call sites and lowers them to inline assignments +// +// e.g. if called on the following statement +// +// a = 3 + foo(x, y) +// +// the calls_ member will be +// +// LOCAL ll0_ +// ll0_ = foo(x,y) +// +// and the original statment is modified to be +// +// a = 3 + ll0_ +// +// If the calls_ data is spliced directly before the original statement +// the function call will have been fully lowered +/////////////////////////////////////////////////////////////////////////////// +class FunctionCallLowerer : public Visitor { + +public: + using scope_type = Scope<Symbol>; + + FunctionCallLowerer(std::shared_ptr<scope_type> s) + : scope_(s) + {} + + void visit(CallExpression *e) override; + void visit(Expression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; + void visit(NumberExpression *e) override {}; + void visit(IdentifierExpression *e) override {}; + + call_list_type& calls() { + return calls_; + } + + call_list_type move_calls() { + return std::move(calls_); + } + + ~FunctionCallLowerer() {} + +private: + Symbol* make_unique_local() { + std::string name; + auto i = 0; + do { + name = pprintf("ll%_", i); + ++i; + } while(scope_->find(name)); + + auto sym = + scope_->add_local_symbol( + name, + make_symbol<LocalVariable>( + Location(), name, localVariableKind::local + ) + ); + + return sym; + } + + template< typename F> + void expand_call(CallExpression* func, F replacer) { + // use the source location of the original statement + auto loc = func->location(); + + // make an identifier for the new symbol which will store the result of + // the function call + auto id = 
make_expression<IdentifierExpression> + (loc, make_unique_local()->name()); + id->semantic(scope_); + // generate a LOCAL declaration for the variable + calls_.push_front( + make_expression<LocalDeclaration>(loc, id->is_identifier()->spelling()) + ); + calls_.front()->semantic(scope_); + + // make a binary expression which assigns the function to the variable + auto ass = binary_expression(loc, tok::eq, id->clone(), func->clone()); + ass->semantic(scope_); + calls_.push_back(std::move(ass)); + + // replace the function call in the original expression with the local + // variable which holds the pre-computed value + replacer(std::move(id)); + } + + call_list_type calls_; + std::shared_ptr<scope_type> scope_; +}; + +/////////////////////////////////////////////////////////////////////////////// +// visitor that takes function arguments that are not literals of identifiers +// and lowers them to inline assignments +// +// e.g. if called on the following statement +// +// a = foo(2+x, y) +// +// the calls_ member will be +// +// LOCAL ll0_ +// ll0_ = 2+x +// +// and the original statment is modified to be +// +// a = foo(ll0_, y) +// +// If the calls_ data is spliced directly before the original statement +// the function arguments will have been fully lowered +/////////////////////////////////////////////////////////////////////////////// +call_list_type lower_function_arguments(std::vector<expression_ptr>& args); + diff --git a/modcc/src/functioninliner.cpp b/modcc/src/functioninliner.cpp new file mode 100644 index 0000000000000000000000000000000000000000..514594c1d32def708165c881c3f6b4d1a6800a6b --- /dev/null +++ b/modcc/src/functioninliner.cpp @@ -0,0 +1,167 @@ +#include <iostream> + +#include "error.hpp" +#include "functioninliner.hpp" +#include "util.hpp" +#include "errorvisitor.hpp" + +expression_ptr inline_function_call(Expression* e) +{ + if(auto f=e->is_function_call()) { + auto func = f->function(); +#ifdef LOGGING + std::cout << "inline_function_call for 
statement " << f->to_string() + << " with body" << func->body()->to_string() << "\n"; +#endif + auto& body = func->body()->statements(); + if(body.size() != 1) { + throw compiler_exception( + "can only inline functions with one statement", func->location() + ); + } + // assume that the function body is correctly formed, with the last + // statement being an assignment expression + auto last = body.front()->is_assignment(); + auto new_e = last->rhs()->clone(); + + auto& fargs = func->args(); // argument names for the function + auto& cargs = f->args(); // arguments at the call site + for(auto i=0u; i<fargs.size(); ++i) { + if(auto id = cargs[i]->is_identifier()) { +#ifdef LOGGING + std::cout << "inline_function_call symbol replacement " + << id->to_string() << " -> " << fargs[i]->to_string() + << " in the expression " << new_e->to_string() << "\n"; +#endif + auto v = + make_unique<VariableReplacer>( + fargs[i]->is_argument()->spelling(), + id->spelling() + ); + new_e->accept(v.get()); + } + else if(auto value = cargs[i]->is_number()) { +#ifdef LOGGING + std::cout << "inline_function_call symbol replacement " + << value->to_string() << " -> " << fargs[i]->to_string() + << " in the expression " << new_e->to_string() << "\n"; +#endif + auto v = + make_unique<ValueInliner>( + fargs[i]->is_argument()->spelling(), + value->value() + ); + new_e->accept(v.get()); + } + else { + throw compiler_exception( + "can't inline functions with expressions as arguments", + e->location() + ); + } + } + new_e->semantic(e->scope()); + + auto v = make_unique<ErrorVisitor>(""); + new_e->accept(v.get()); +#ifdef LOGGING + std::cout << "inline_function_call result " << new_e->to_string() << "\n\n"; +#endif + if(v->num_errors()) { + throw compiler_exception("something went wrong with inlined function call ", + e->location()); + } + + return new_e; + } + + return {}; +} + +/////////////////////////////////////////////////////////////////////////////// +// variable replacer 
+/////////////////////////////////////////////////////////////////////////////// + +void VariableReplacer::visit(Expression *e) { + throw compiler_exception( + "I don't know how to variable inlining for this statement : " + + e->to_string(), e->location()); +} + +void VariableReplacer::visit(UnaryExpression *e) { + auto exp = e->expression()->is_identifier(); + if(exp && exp->spelling()==source_) { + e->replace_expression( + make_expression<IdentifierExpression>(exp->location(), target_) + ); + } + else if(!exp) { + e->expression()->accept(this); + } +} + +void VariableReplacer::visit(BinaryExpression *e) { + auto lhs = e->lhs()->is_identifier(); + if(lhs && lhs->spelling()==source_) { + e->replace_lhs( + make_expression<IdentifierExpression>(lhs->location(), target_) + ); + } + else if(!lhs){ // only inspect subexpressions that are not themselves identifiers + e->lhs()->accept(this); + } + + auto rhs = e->rhs()->is_identifier(); + if(rhs && rhs->spelling()==source_) { + e->replace_rhs( + make_expression<IdentifierExpression>(rhs->location(), target_) + ); + } + else if(!rhs){ // only inspect subexpressions that are not themselves identifiers + e->rhs()->accept(this); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// value inliner +/////////////////////////////////////////////////////////////////////////////// + +void ValueInliner::visit(Expression *e) { + throw compiler_exception( + "I don't know how to value inlining for this statement : " + + e->to_string(), e->location()); +} + +void ValueInliner::visit(UnaryExpression *e) { + auto exp = e->expression()->is_identifier(); + if(exp && exp->spelling()==source_) { + e->replace_expression( + make_expression<NumberExpression>(exp->location(), value_) + ); + } + else if(!exp){ + e->expression()->accept(this); + } +} + +void ValueInliner::visit(BinaryExpression *e) { + auto lhs = e->lhs()->is_identifier(); + if(lhs && lhs->spelling()==source_) { + e->replace_lhs( + 
make_expression<NumberExpression>(lhs->location(), value_) + ); + } + else if(!lhs) { + e->lhs()->accept(this); + } + + auto rhs = e->rhs()->is_identifier(); + if(rhs && rhs->spelling()==source_) { + e->replace_rhs( + make_expression<NumberExpression>(rhs->location(), value_) + ); + } + else if(!rhs){ + e->rhs()->accept(this); + } +} diff --git a/modcc/src/functioninliner.hpp b/modcc/src/functioninliner.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f100ddbc12914d0536025fbd643637cf065c4af4 --- /dev/null +++ b/modcc/src/functioninliner.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include <sstream> + +#include "scope.hpp" +#include "visitor.hpp" + +expression_ptr inline_function_call(Expression* e); + +class VariableReplacer : public Visitor { + +public: + + VariableReplacer(std::string const& source, std::string const& target) + : source_(source), + target_(target) + {} + + void visit(Expression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; + void visit(NumberExpression *e) override {}; + + ~VariableReplacer() {} + +private: + + std::string source_; + std::string target_; +}; + +class ValueInliner : public Visitor { + +public: + + ValueInliner(std::string const& source, long double value) + : source_(source), + value_(value) + {} + + void visit(Expression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; + void visit(NumberExpression *e) override {}; + + ~ValueInliner() {} + +private: + + std::string source_; + long double value_; +}; diff --git a/modcc/src/identifier.hpp b/modcc/src/identifier.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60a09436acb63dcf2e2eb39cd49a25bb026aeddb --- /dev/null +++ b/modcc/src/identifier.hpp @@ -0,0 +1,118 @@ +#pragma once + +#include <string> + +/// indicate how a variable is accessed +/// access is (read, written, or both) +/// the distinction between write only and read only is 
/// required because
/// if an external variable is to be written/updated, then it does not have
/// to be loaded before applying a kernel.
enum class accessKind {
    read,
    write,
    readwrite
};

/// describes the scope of a variable
enum class visibilityKind {
    local,
    global
};

/// describes whether a variable is a range or a scalar
enum class rangeKind {
    range,
    scalar
};

/// whether the variable value is defined inside or outside of the module.
enum class linkageKind {
    local,
    external
};

/// ion channel that the variable belongs to
enum class ionKind {
    none,     ///< not an ion variable
    nonspecific,  ///< nonspecific current
    Ca,       ///< calcium ion
    Na,       ///< sodium ion
    K         ///< potassium ion
};

/// "yes" for true, "no" for false
inline std::string yesno(bool val) {
    return std::string(val ? "yes" : "no");
}

////////////////////////////////////////////
// to_string functions convert types
// to strings for printing diagnostics
//
// each switch deliberately has no default label; a value outside the
// enumerators falls through to the error string after it
////////////////////////////////////////////
inline std::string to_string(ionKind i) {
    switch(i) {
        case ionKind::none : return std::string("none");
        case ionKind::Ca   : return std::string("calcium");
        case ionKind::Na   : return std::string("sodium");
        case ionKind::K    : return std::string("potassium");
        case ionKind::nonspecific : return std::string("nonspecific");
    }
    return std::string("<error : undefined ionKind>");
}

inline std::string to_string(visibilityKind v) {
    switch(v) {
        case visibilityKind::local : return std::string("local");
        case visibilityKind::global: return std::string("global");
    }
    return std::string("<error : undefined visibilityKind>");
}

inline std::string to_string(linkageKind v) {
    switch(v) {
        case linkageKind::local : return std::string("local");
        case linkageKind::external: return std::string("external");
    }
    return std::string("<error : undefined linkageKind>");
}

// ostream writers
inline std::ostream& operator<< (std::ostream& os, ionKind i) {
    return os << to_string(i);
}

+inline std::ostream& operator<< (std::ostream& os, visibilityKind v) { + return os << to_string(v); +} + +inline std::ostream& operator<< (std::ostream& os, linkageKind l) { + return os << to_string(l); +} + +inline ionKind ion_kind_from_name(std::string field) { + if(field.substr(0,4) == "ion_") { + field = field.substr(4); + } + if(field=="ica" || field=="eca" || field=="cai" || field=="cao") { + return ionKind::Ca; + } + if(field=="ik" || field=="ek" || field=="ki" || field=="ko") { + return ionKind::K; + } + if(field=="ina" || field=="ena" || field=="nai" || field=="nao") { + return ionKind::Na; + } + return ionKind::none; +} + +inline std::string ion_store(ionKind k) { + switch(k) { + case ionKind::Ca: + return "ion_ca"; + case ionKind::Na: + return "ion_na"; + case ionKind::K: + return "ion_k"; + default: + return ""; + } +} diff --git a/modcc/src/lexer.cpp b/modcc/src/lexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b3562dae9571165ce7ec18bf5fa4fb9c6bed83d7 --- /dev/null +++ b/modcc/src/lexer.cpp @@ -0,0 +1,373 @@ +#include <cstdio> + +#include <iostream> +#include <string> + +#include "lexer.hpp" +#include "util.hpp" + +// helpers for identifying character types +inline bool in_range(char c, char first, char last) { + return c>=first && c<=last; +} +inline bool is_numeric(char c) { + return in_range(c, '0', '9'); +} +inline bool is_alpha(char c) { + return (in_range(c, 'a', 'z') || in_range(c, 'A', 'Z') ); +} +inline bool is_alphanumeric(char c) { + return (is_numeric(c) || is_alpha(c) ); +} +inline bool is_whitespace(char c) { + return (c==' ' || c=='\t' || c=='\v' || c=='\f'); +} +inline bool is_eof(char c) { + return (c==0 || c==EOF); +} +inline bool is_operator(char c) { + return (c=='+' || c=='-' || c=='*' || c=='/' || c=='^' || c=='\''); +} + +//********************* +// Lexer +//********************* + +Token Lexer::parse() { + Token t; + + // the while loop strips white space/new lines in front of the next token + 
while(1) { + location_.column = current_-line_+1; + t.location = location_; + + switch(*current_) { + // end of file + case 0 : // end of string + case EOF : // end of file + t.spelling = "eof"; + t.type = tok::eof; + return t; + + // white space + case ' ' : + case '\t' : + case '\v' : + case '\f' : + current_++; + continue; // skip to next character + + // new line + case '\n' : + current_++; + line_ = current_; + location_.line++; + continue; // skip to next line + + // new line + case '\r' : + current_++; + if(*current_ != '\n') { + error_string_ = pprintf("bad line ending: \\n must follow \\r"); + status_ = lexerStatus::error; + t.type = tok::reserved; + return t; + } + current_++; + line_ = current_; + location_.line++; + continue; // skip to next line + + // comment (everything after : on a line is a comment) + case ':' : + // strip characters until either end of file or end of line + while( !is_eof(*current_) && *current_ != '\n') { + ++current_; + } + continue; + + // number + case '0': case '1' : case '2' : case '3' : case '4': + case '5': case '6' : case '7' : case '8' : case '9': + case '.': + t.spelling = number(); + + // test for error when reading number + t.type = (status_==lexerStatus::error) ? tok::reserved : tok::number; + return t; + + // identifier or keyword + case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': case 'g': + case 'h': case 'i': case 'j': case 'k': case 'l': case 'm': case 'n': + case 'o': case 'p': case 'q': case 'r': case 's': case 't': case 'u': + case 'v': case 'w': case 'x': case 'y': case 'z': + case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': case 'G': + case 'H': case 'I': case 'J': case 'K': case 'L': case 'M': case 'N': + case 'O': case 'P': case 'Q': case 'R': case 'S': case 'T': case 'U': + case 'V': case 'W': case 'X': case 'Y': case 'Z': + case '_': + // get std::string of the identifier + t.spelling = identifier(); + t.type + = status_==lexerStatus::error + ? 
tok::reserved + : get_identifier_type(t.spelling); + return t; + case '(': + t.type = tok::lparen; + t.spelling += character(); + return t; + case ')': + t.type = tok::rparen; + t.spelling += character(); + return t; + case '{': + t.type = tok::lbrace; + t.spelling += character(); + return t; + case '}': + t.type = tok::rbrace; + t.spelling += character(); + return t; + case '=': { + t.spelling += character(); + if(*current_=='=') { + t.spelling += character(); + t.type=tok::equality; + } + else { + t.type = tok::eq; + } + return t; + } + case '!': { + t.spelling += character(); + if(*current_=='=') { + t.spelling += character(); + t.type=tok::ne; + } + else { + t.type = tok::lnot; + } + return t; + } + case '+': + t.type = tok::plus; + t.spelling += character(); + return t; + case '-': + t.type = tok::minus; + t.spelling += character(); + return t; + case '/': + t.type = tok::divide; + t.spelling += character(); + return t; + case '*': + t.type = tok::times; + t.spelling += character(); + return t; + case '^': + t.type = tok::pow; + t.spelling += character(); + return t; + // comparison binary operators + case '<': { + t.spelling += character(); + if(*current_=='=') { + t.spelling += character(); + t.type = tok::lte; + } + else { + t.type = tok::lt; + } + return t; + } + case '>': { + t.spelling += character(); + if(*current_=='=') { + t.spelling += character(); + t.type = tok::gte; + } + else { + t.type = tok::gt; + } + return t; + } + case '\'': + t.type = tok::prime; + t.spelling += character(); + return t; + case ',': + t.type = tok::comma; + t.spelling += character(); + return t; + default: + error_string_ = + pprintf( "unexpected character '%' at %", + *current_, location_); + status_ = lexerStatus::error; + t.spelling += character(); + t.type = tok::reserved; + return t; + } + } + + // return the token + return t; +} + +Token Lexer::peek() { + // save the current position + const char *oldpos = current_; + const char *oldlin = line_; + Location oldloc = 
location_; + + Token t = parse(); // read the next token + + // reset position + current_ = oldpos; + location_ = oldloc; + line_ = oldlin; + + return t; +} + +// scan floating point number from stream +std::string Lexer::number() { + std::string str; + char c = *current_; + + // start counting the number of points in the number + auto num_point = (c=='.' ? 1 : 0); + auto uses_scientific_notation = 0; + bool incorrectly_formed_mantisa = false; + + str += c; + current_++; + while(1) { + c = *current_; + if(is_numeric(c)) { + str += c; + current_++; + } + else if(c=='.') { + num_point++; + str += c; + current_++; + if(uses_scientific_notation) { + incorrectly_formed_mantisa = true; + } + } + else if(c=='e' || c=='E') { + uses_scientific_notation++; + str += c; + current_++; + } + else { + break; + } + } + + // check that the mantisa is an integer + if(incorrectly_formed_mantisa) { + error_string_ = pprintf("the exponent/mantissa must be an integer '%'", yellow(str)); + status_ = lexerStatus::error; + } + // check that there is at most one decimal point + // i.e. 
disallow values like 2.2324.323 + if(num_point>1) { + error_string_ = pprintf("too many .'s when reading the number '%'", yellow(str)); + status_ = lexerStatus::error; + } + // check that e or E is not used more than once in the number + if(uses_scientific_notation>1) { + error_string_ = pprintf("can't parse the number '%'", yellow(str)); + status_ = lexerStatus::error; + } + + return str; +} + +// scan identifier from stream +// examples of valid names: +// _1 _a ndfs var9 num_a This_ +// examples of invalid names: +// _ __ 9val 9_ +std::string Lexer::identifier() { + std::string name; + char c = *current_; + + // assert that current position is at the start of a number + // note that the first character can't be numeric + if( !(is_alpha(c) || c=='_') ) { + throw compiler_exception( + "Lexer attempting to read number when none is available", + location_); + } + + name += c; + current_++; + while(1) { + c = *current_; + + if(is_alphanumeric(c) || c=='_') { + name += c; + current_++; + } + else { + break; + } + } + + return name; +} + +// scan a single character from the buffer +char Lexer::character() { + return *current_++; +} + +std::map<tok, int> Lexer::binop_prec_; + +void Lexer::binop_prec_init() { + if(binop_prec_.size()>0) + return; + + // I have taken the operator precedence from C++ + binop_prec_[tok::eq] = 2; + binop_prec_[tok::equality] = 4; + binop_prec_[tok::ne] = 4; + binop_prec_[tok::lt] = 5; + binop_prec_[tok::lte] = 5; + binop_prec_[tok::gt] = 5; + binop_prec_[tok::gte] = 5; + binop_prec_[tok::plus] = 10; + binop_prec_[tok::minus] = 10; + binop_prec_[tok::times] = 20; + binop_prec_[tok::divide] = 20; + binop_prec_[tok::pow] = 30; +} + +int Lexer::binop_precedence(tok tok) { + auto r = binop_prec_.find(tok); + if(r==binop_prec_.end()) + return -1; + return r->second; +} + +associativityKind Lexer::operator_associativity(tok token) { + if(token==tok::pow) { + return associativityKind::right; + } + return associativityKind::left; +} + +// pre : 
identifier is a valid identifier ([_a-zA-Z][_a-zA-Z0-9]*) +// post : if(identifier is a keyword) return tok::<keyword> +// else return tok::identifier +tok Lexer::get_identifier_type(std::string const& identifier) { + auto pos = keyword_map.find(identifier); + return pos==keyword_map.end() ? tok::identifier : pos->second; +} + diff --git a/modcc/src/lexer.hpp b/modcc/src/lexer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a998287fabcf2edddc7c96b43675a1509ff41dbb --- /dev/null +++ b/modcc/src/lexer.hpp @@ -0,0 +1,117 @@ +#pragma once + +// inspiration was taken from the Digital Mars D compiler +// github.com/D-Programming-Language/dmd + +#include <map> +#include <string> +#include <unordered_map> +#include <vector> + +#include "location.hpp" +#include "error.hpp" +#include "token.hpp" + +// status of the lexer +enum class lexerStatus { + error, // lexer has encounterd a problem + happy // lexer is in a good place +}; + +// associativity of an operator +enum class associativityKind { + left, + right +}; + +bool is_keyword(Token const& t); + + +// class that implements the lexer +// takes a range of characters as input parameters +class Lexer { +public: + Lexer(const char * begin, const char* end) + : begin_(begin), + end_(end), + current_(begin), + line_(begin), + location_() + { + if(begin_>end_) { + throw std::out_of_range("Lexer(begin, end) : begin>end"); + } + + initialize_token_maps(); + binop_prec_init(); + } + + Lexer(std::vector<char> const& v) + : Lexer(v.data(), v.data()+v.size()) + {} + + Lexer(std::string const& s) + : buffer_(s.data(), s.data()+s.size()+1) + { + begin_ = buffer_.data(); + end_ = buffer_.data() + buffer_.size(); + current_ = begin_; + line_ = begin_; + + initialize_token_maps(); + binop_prec_init(); + } + + // get the next token + Token parse(); + + void get_token() { + token_ = parse(); + } + + // return the next token in the stream without advancing the current position + Token peek(); + + // scan a number 
from the stream + std::string number(); + + // scan an identifier string from the stream + std::string identifier(); + + // scan a character from the stream + char character(); + + Location location() {return location_;} + + // binary operator precedence + static std::map<tok, int> binop_prec_; + + lexerStatus status() {return status_;} + + const std::string& error_message() {return error_string_;}; + + static int binop_precedence(tok tok); + static associativityKind operator_associativity(tok token); +protected: + // buffer used for short-lived parsers + std::vector<char> buffer_; + + // generate lookup tables (hash maps) for keywords + void keywords_init(); + void token_strings_init(); + void binop_prec_init(); + + // helper for determining if an identifier string matches a keyword + tok get_identifier_type(std::string const& identifier); + + const char *begin_, *end_;// pointer to start and 1 past the end of the buffer + const char *current_; // pointer to current character + const char *line_; // pointer to start of current line + Location location_; // current location (line,column) in buffer + + lexerStatus status_ = lexerStatus::happy; + std::string error_string_; + + Token token_; +}; + diff --git a/modcc/src/location.hpp b/modcc/src/location.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f78f66d4782d3044cc4fc4104e5604abca319ac9 --- /dev/null +++ b/modcc/src/location.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include <ostream> + +struct Location { + int line; + int column; + + Location() : Location(1, 1) + {} + + Location(int ln, int col) + : line(ln), + column(col) + {} +}; + +inline std::ostream& operator<< (std::ostream& os, Location const& L) { + return os << "(line " << L.line << ",col " << L.column << ")"; +} + + diff --git a/modcc/src/mechanism.hpp b/modcc/src/mechanism.hpp new file mode 100644 index 0000000000000000000000000000000000000000..309bd95daf2464be92777a039fa761f1a5ac1fff --- /dev/null +++ b/modcc/src/mechanism.hpp @@ 
-0,0 +1,43 @@ +#pragma once + +#include "lib/vector/include/Vector.h" + +/* + Abstract base class for all mechanisms + + This works well for the standard interface that is exported by all mechanisms, + i.e. nrn_jacobian(), nrn_current(), etc. The overhead of using virtual dispatch + for such functions is negligable compared to the cost of the operations themselves. + However, the friction between compile time and run time dispatch has to be considered + carefully: + - we want to dispatch on template parameters like vector width, target + hardware etc. Maybe these could be template parameters to the base + class? + - how to expose mechanism functionality, e.g. for computing a mechanism- + specific quantity for visualization? +*/ + +class Mechanism { +public: + // typedefs for storage + using value_type = double; + using vector_type = memory::HostVector<value_type>; + using view_type = memory::HostView<value_type>; + + Mechanism(std::string const& name) + : name_(name) + {} + + //virtual void state() = 0; + //virtual void jacobi() = 0; + virtual void current() = 0; + virtual void init() = 0; + + std::string const& name() { + return name_; + } + +protected: + std::string name_; +}; + diff --git a/modcc/src/memop.hpp b/modcc/src/memop.hpp new file mode 100644 index 0000000000000000000000000000000000000000..acfe084a77b1224427aaf427980646da70d698b7 --- /dev/null +++ b/modcc/src/memop.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include "util.hpp" +#include "lexer.hpp" + +/// Defines a memory operation that is to performed by an APIMethod. +/// Kernels can read/write global state via an index, e.g. 
+/// - loading voltage v from VEC_V in matrix before computation +/// - loading a variable associated with an ionic variable +/// - accumulating an update to VEC_RHS/VEC_D after computation +/// - adding contribution to an ionic current +/// How these operations are handled will vary significantly from +/// one backend implementation to another, so inserting expressions +/// directly into the APIMethod body to perform them is not appropriate. +/// Instead, each API method stores two lists +/// - a list of load/input transactions to perform before kernel +/// - a list of store/output transactions to perform after kernel +/// The lists are of MemOps, which describe the local and external variables +template <typename Symbol> +struct MemOp { + using symbol_type = Symbol; + tok op; + Symbol *local; + Symbol *external; + + MemOp(tok o, Symbol *loc, Symbol *ext) + : op(o), local(loc), external(ext) + { + const tok valid_ops[] = {tok::plus, tok::minus, tok::eq}; + if(!is_in(op, valid_ops)) { + throw compiler_exception( + "invalid operation for creating a MemOp : " + + loc->to_string() + yellow(token_string(op)) + ext->to_string(), + loc->location()); + } + } +}; + diff --git a/modcc/src/modcc.cpp b/modcc/src/modcc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9161bbfd2f63d6945739fceb2825a7f3ede1dfef --- /dev/null +++ b/modcc/src/modcc.cpp @@ -0,0 +1,231 @@ +#include <chrono> +#include <iostream> +#include <fstream> + +#include <tclap/CmdLine.h> + +#include "cprinter.hpp" +#include "cudaprinter.hpp" +#include "lexer.hpp" +#include "module.hpp" +#include "parser.hpp" +#include "perfvisitor.hpp" +#include "util.hpp" + +//#define VERBOSE + +enum class targetKind {cpu, gpu}; + +struct Options { + std::string filename; + std::string outputname; + bool has_output = false; + bool verbose = true; + bool optimize = false; + bool analysis = false; + targetKind target = targetKind::cpu; + + void print() { + std::cout << cyan("." 
+ std::string(60, '-') + ".") << std::endl; + std::cout << cyan("| file ") << filename << std::string(61-11-filename.size(),' ') << cyan("|") << std::endl; + std::string outname = (outputname.size() ? outputname : "stdout"); + std::cout << cyan("| output ") << outname << std::string(61-11-outname.size(),' ') << cyan("|") << std::endl; + std::cout << cyan("| verbose ") << (verbose ? "yes" : "no ") << std::string(61-11-3,' ') << cyan("|") << std::endl; + std::cout << cyan("| optimize ") << (optimize ? "yes" : "no ") << std::string(61-11-3,' ') << cyan("|") << std::endl; + std::cout << cyan("| target ") << (target==targetKind::cpu? "cpu" : "gpu") << std::string(61-11-3,' ') << cyan("|") << std::endl; + std::cout << cyan("| analysis ") << (analysis ? "yes" : "no ") << std::string(61-11-3,' ') << cyan("|") << std::endl; + std::cout << cyan("." + std::string(60, '-') + ".") << std::endl; + } +}; + +int main(int argc, char **argv) { + + Options options; + + // parse command line arguments + try { + TCLAP::CmdLine cmd("welcome to mod2c", ' ', "0.1"); + + // input file name (to load multiple files we have to use UnlabeledMultiArg + TCLAP::UnlabeledValueArg<std::string> + fin_arg("input_file", "the name of the .mod file to compile", true, "", "filename"); + // output filename + TCLAP::ValueArg<std::string> + fout_arg("o","output","name of output file", false,"","filname"); + // output filename + TCLAP::ValueArg<std::string> + target_arg("t","target","backend target={cpu,gpu}", true,"cpu","cpu/gpu"); + // verbose mode + TCLAP::SwitchArg verbose_arg("V","verbose","toggle verbose mode", cmd, false); + // analysis mode + TCLAP::SwitchArg analysis_arg("A","analyse","toggle analysis mode", cmd, false); + // optimization mode + TCLAP::SwitchArg opt_arg("O","optimize","turn optimizations on", cmd, false); + + cmd.add(fin_arg); + cmd.add(fout_arg); + cmd.add(target_arg); + + cmd.parse(argc, argv); + + options.outputname = fout_arg.getValue(); + options.has_output = 
options.outputname.size()>0; + options.filename = fin_arg.getValue(); + options.verbose = verbose_arg.getValue(); + options.optimize = opt_arg.getValue(); + options.analysis = analysis_arg.getValue(); + auto targstr = target_arg.getValue(); + if(targstr == "cpu") { + options.target = targetKind::cpu; + } + else if(targstr == "gpu") { + options.target = targetKind::gpu; + } + else { + std::cerr << red("error") << " target must be one in {cpu, gpu}" << std::endl; + return 1; + } + } + // catch any exceptions in command line handling + catch(TCLAP::ArgException &e) { + std::cerr << "error: " << e.error() + << " for arg " << e.argId() + << std::endl; + } + + try { + // load the module from file passed as first argument + Module m(options.filename.c_str()); + + // check that the module is not empty + if(m.buffer().size()==0) { + std::cout << red("error: ") << white(argv[1]) + << " invalid or empty file" << std::endl; + return 1; + } + + if(options.verbose) { + options.print(); + } + + //////////////////////////////////////////////////////////// + // parsing + //////////////////////////////////////////////////////////// + if(options.verbose) std::cout << green("[") + "parsing" + green("]") << std::endl; + + // initialize the parser + Parser p(m, false); + + // parse + p.parse(); + if(p.status() == lexerStatus::error) return 1; + + //////////////////////////////////////////////////////////// + // semantic analysis + //////////////////////////////////////////////////////////// + if(options.verbose) std::cout << green("[") + "semantic analysis" + green("]") << std::endl; + m.semantic(); + + if( m.has_error() || m.has_warning() ) { + std::cout << m.error_string() << std::endl; + } + if(m.status() == lexerStatus::error) { + return 1; + } + + //////////////////////////////////////////////////////////// + // optimize + //////////////////////////////////////////////////////////// + if(options.optimize) { + if(options.verbose) std::cout << green("[") + "optimize" + green("]") << 
std::endl; + m.optimize(); + if(m.status() == lexerStatus::error) { + return 1; + } + } + + //////////////////////////////////////////////////////////// + // generate output + //////////////////////////////////////////////////////////// + if(options.verbose) { + std::cout << green("[") + "code generation" + << green("]") << std::endl; + } + + std::string text; + switch(options.target) { + case targetKind::cpu : + text = CPrinter(m, options.optimize).text(); + break; + case targetKind::gpu : + text = CUDAPrinter(m, options.optimize).text(); + break; + default : + std::cerr << red("error") << ": unknown printer" << std::endl; + exit(1); + } + + if(options.has_output) { + std::ofstream fout(options.outputname); + fout << text; + fout.close(); + } + else { + std::cout << cyan("--------------------------------------") << std::endl; + std::cout << text; + std::cout << cyan("--------------------------------------") << std::endl; + } + + std::cout << yellow("successfully compiled ") << white(options.filename) << " -> " << white(options.outputname) << std::endl; + + //////////////////////////////////////////////////////////// + // print module information + //////////////////////////////////////////////////////////// + if(options.analysis) { + std::cout << green("performance analysis") << std::endl; + for(auto &symbol : m.symbols()) { + if(auto method = symbol.second->is_api_method()) { + std::cout << white("-------------------------") << std::endl; + std::cout << yellow("method " + method->name()) << std::endl; + std::cout << white("-------------------------") << std::endl; + + auto flops = make_unique<FlopVisitor>(); + method->accept(flops.get()); + std::cout << white("FLOPS") << std::endl; + std::cout << flops->print() << std::endl; + + std::cout << white("MEMOPS") << std::endl; + auto memops = make_unique<MemOpVisitor>(); + method->accept(memops.get()); + std::cout << memops->print() << std::endl;; + } + } + } + } + + catch(compiler_exception e) { + std::cerr << 
red("internal compiler error: ") + << white("this means a bug in the compiler," + " please report to modcc developers") + << std::endl + << e.what() << " @ " << e.location() << std::endl; + exit(1); + } + catch(std::exception e) { + std::cerr << red("internal compiler error: ") + << white("this means a bug in the compiler," + " please report to modcc developers") + << std::endl + << e.what() << std::endl; + exit(1); + } + catch(...) { + std::cerr << red("internal compiler error: ") + << white("this means a bug in the compiler," + " please report to modcc developers") + << std::endl; + exit(1); + } + + return 0; +} diff --git a/modcc/src/module.cpp b/modcc/src/module.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d00d90d27c190c0fb74165fe5644dd8d0c15ea7a --- /dev/null +++ b/modcc/src/module.cpp @@ -0,0 +1,814 @@ +#include <algorithm> +#include <cassert> +#include <fstream> +#include <iostream> +#include <set> + +#include "errorvisitor.hpp" +#include "expressionclassifier.hpp" +#include "functionexpander.hpp" +#include "functioninliner.hpp" +#include "module.hpp" +#include "parser.hpp" + +Module::Module(std::string const& fname) +: fname_(fname) +{ + // open the file at the end + std::ifstream fid; + fid.open(fname.c_str(), std::ios::binary | std::ios::ate); + if(!fid.is_open()) { // return if no file opened + return; + } + + // determine size of file + std::size_t size = fid.tellg(); + fid.seekg(0, std::ios::beg); + + // allocate space for storage and read + buffer_.resize(size+1); + fid.read(buffer_.data(), size); + buffer_[size] = 0; // append \0 to terminate string +} + +Module::Module(std::vector<char> const& buffer) +{ + buffer_ = buffer; + + // add \0 to end of buffer if not already present + if(buffer_[buffer_.size()-1] != 0) + buffer_.push_back(0); +} + +std::vector<Module::symbol_ptr>& +Module::procedures() { + return procedures_; +} + +std::vector<Module::symbol_ptr>const& +Module::procedures() const { + return procedures_; +} + 
+std::vector<Module::symbol_ptr>& +Module::functions() { + return functions_; +} + +std::vector<Module::symbol_ptr>const& +Module::functions() const { + return functions_; +} + +Module::symbol_map& +Module::symbols() { + return symbols_; +} + +Module::symbol_map const& +Module::symbols() const { + return symbols_; +} + +void Module::error(std::string const& msg, Location loc) { + std::string location_info = pprintf("%:% ", file_name(), loc); + if(error_string_.size()) {// append to current string + error_string_ += "\n"; + } + error_string_ += red("error ") + white(location_info) + msg; + status_ = lexerStatus::error; +} + +void Module::warning(std::string const& msg, Location loc) { + std::string location_info = pprintf("%:% ", file_name(), loc); + if(error_string_.size()) {// append to current string + error_string_ += "\n"; + } + error_string_ += purple("warning ") + white(location_info) + msg; + has_warning_ = true; +} + +bool Module::semantic() { + //////////////////////////////////////////////////////////////////////////// + // create the symbol table + // there are three types of symbol to look up + // 1. variables + // 2. function calls + // 3. procedure calls + // the symbol table is generated, then we can traverse the AST and verify + // that all symbols are correctly used + //////////////////////////////////////////////////////////////////////////// + + // first add variables defined in the NEURON, ASSIGNED and PARAMETER + // blocks these symbols have "global" scope, i.e. they are visible to all + // functions and procedurs in the mechanism + add_variables_to_symbols(); + + // Helper which iterates over a vector of Symbols, moving them into the + // symbol table. + // Returns false if a symbol name clases with the name of a symbol that + // is already in the symbol table. 
+ auto move_symbols = [this] (std::vector<symbol_ptr>& symbol_list) { + for(auto& symbol: symbol_list) { + bool is_found = (symbols_.find(symbol->name()) != symbols_.end()); + if(is_found) { + error( + pprintf("'%' clashes with previously defined symbol", + symbol->name()), + symbol->location() + ); + return false; + } + // move symbol to table + symbols_[symbol->name()] = std::move(symbol); + } + return true; + }; + + // move functions and procedures to the symbol table + if(!move_symbols(functions_)) return false; + if(!move_symbols(procedures_)) return false; + + //////////////////////////////////////////////////////////////////////////// + // now iterate over the functions and procedures and perform semantic + // analysis on each. This includes + // - variable, function and procedure lookup + // - generate local variable table for each function/procedure + // - inlining function calls + //////////////////////////////////////////////////////////////////////////// +#ifdef LOGGING + std::cout << white("===================================\n"); + std::cout << cyan(" Function Inlining\n"); + std::cout << white("===================================\n"); +#endif + int errors = 0; + for(auto& e : symbols_) { + auto& s = e.second; + + if( s->kind() == symbolKind::function + || s->kind() == symbolKind::procedure) + { +#ifdef LOGGING + std::cout << "\nfunction inlining for " << s->location() << "\n" << s->to_string() << "\n"; + std::cout << green("\n-call site lowering-\n\n"); +#endif + // first perform semantic analysis + s->semantic(symbols_); + + // then use an error visitor to print out all the semantic errors + auto v = make_unique<ErrorVisitor>(file_name()); + s->accept(v.get()); + errors += v->num_errors(); + + // inline function calls + // this requires that the symbol table has already been built + if(v->num_errors()==0) { + auto &b = s->kind()==symbolKind::function ? 
+ s->is_function()->body()->statements() : + s->is_procedure()->body()->statements(); + + // lower function call sites so that all function calls are of + // the form : variable = call(<args>) + // e.g. + // a = 2 + foo(2+x, y, 1) + // becomes + // ll0_ = foo(2+x, y, 1) + // a = 2 + ll0_ + for(auto e=b.begin(); e!=b.end(); ++e) { + b.splice(e, lower_function_calls((*e).get())); + } +#ifdef LOGGING + std::cout << "body after call site lowering\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; + std::cout << green("\n-argument lowering-\n\n"); +#endif + + // lower function arguments that are not identifiers or literals + // e.g. + // ll0_ = foo(2+x, y, 1) + // a = 2 + ll0_ + // becomes + // ll1_ = 2+x + // ll0_ = foo(ll1_, y, 1) + // a = 2 + ll0_ + for(auto e=b.begin(); e!=b.end(); ++e) { + if(auto be = (*e)->is_binary()) { + // only apply to assignment expressions where rhs is a + // function call because the function call lowering step + // above ensures that all function calls are of this form + if(auto rhs = be->rhs()->is_function_call()) { + b.splice(e, lower_function_arguments(rhs->args())); + } + } + } + +#ifdef LOGGING + std::cout << "body after argument lowering\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; + std::cout << green("\n-inlining-\n\n"); +#endif + + // Do the inlining, which currently only works for functions + // that have a single statement in their body + // e.g. 
if the function foo in the examples above is defined as follows + // + // function foo(a, b, c) { + // foo = a*(b + c) + // } + // + // the full inlined example is + // ll1_ = 2+x + // ll0_ = ll1_*(y + 1) + // a = 2 + ll0_ + for(auto e=b.begin(); e!=b.end(); ++e) { + if(auto ass = (*e)->is_assignment()) { + if(ass->rhs()->is_function_call()) { + ass->replace_rhs(inline_function_call(ass->rhs())); + } + } + } + +#ifdef LOGGING + std::cout << "body after inlining\n"; + for(auto& l : b) std::cout << " " << l->to_string() << " @ " << l->location() << "\n"; +#endif + } + } + } + + if(errors) { + std::cout << "\nthere were " << errors + << " errors in the semantic analysis" << std::endl; + status_ = lexerStatus::error; + return false; + } + + // All API methods are generated from statements in one of the special procedures + // defined in NMODL, e.g. the nrn_init() API call is based on the INITIAL block. + // When creating an API method, the first task is to look up the source procedure, + // i.e. the INITIAL block for nrn_init(). This lambda takes care of this repetative + // lookup work, with error checking. 
+ auto make_empty_api_method = [this] + (std::string const& name, std::string const& source_name) + -> std::pair<APIMethod*, ProcedureExpression*> + { + if( !has_symbol(source_name, symbolKind::procedure) ) { + error(pprintf("unable to find symbol '%'", yellow(source_name)), + Location()); + return std::make_pair(nullptr, nullptr); + } + + auto source = symbols_[source_name]->is_procedure(); + auto loc = source->location(); + + if( symbols_.find(name)!=symbols_.end() ) { + error(pprintf("'%' clashes with reserved name, please rename it", + yellow(name)), + symbols_.find(name)->second->location()); + return std::make_pair(nullptr, source); + } + + symbols_[name] = make_symbol<APIMethod>( + loc, name, + std::vector<expression_ptr>(), // no arguments + make_expression<BlockExpression> + (loc, std::list<expression_ptr>(), false) + ); + + auto proc = symbols_[name]->is_api_method(); + return std::make_pair(proc, source); + }; + + //......................................................................... + // nrn_init : based on the INITIAL block (i.e. the 'initial' procedure + //......................................................................... + auto initial_api = make_empty_api_method("nrn_init", "initial"); + auto api_init = initial_api.first; + auto proc_init = initial_api.second; + + if(api_init) + { + auto& body = api_init->body()->statements(); + + for(auto& e : *proc_init->body()) { + body.emplace_back(e->clone()); + } + + api_init->semantic(symbols_); + } + else { + if(!proc_init) { + error("an INITIAL block is required", Location()); + } + return false; + } + + + // evaluate whether an expression has the form + // (b - x)/a + // where x is a state variable with name state_variable + // this test is used to detect ODEs with the signature + // dx/dt = (xinf - x)/xtau + // so that we can integrate them efficiently using the cnexp integrator + // + // this is messy, but ok for a once off. 
If this pattern is + // repeated, it will be worth finding a more sophisticated solution + auto is_gating = [] (Expression* e, std::string const& state_variable) { + IdentifierExpression* a = nullptr; + IdentifierExpression* b = nullptr; + BinaryExpression* other = nullptr; + if(auto binop = e->is_binary()) { + if(binop->op()==tok::divide) { + if((a = binop->rhs()->is_identifier())) { + other = binop->lhs()->is_binary(); + } + } + } + if(other) { + if(other->op()==tok::minus) { + if(auto rhs = other->rhs()->is_identifier()) { + if(rhs->name() == state_variable) { + if(auto lhs = other->lhs()->is_identifier()) { + if(lhs->name() != state_variable) { + b = lhs; + return std::make_pair(a, b); + } + } + } + } + } + } + a = b = nullptr; + return std::make_pair(a, b); + }; + + // Look in the symbol table for a procedure with the name "breakpoint". + // This symbol corresponds to the BREAKPOINT block in the .mod file + // There are two APIMethods generated from BREAKPOINT. + // The first is nrn_state, which is the first case handled below. + // The second is nrn_current, which is handled after this block + auto state_api = make_empty_api_method("nrn_state", "breakpoint"); + auto api_state = state_api.first; + auto breakpoint = state_api.second; + + if( breakpoint ) { + // helper for making identifiers on the fly + auto id = [] (std::string const& name, Location loc=Location()) { + return make_expression<IdentifierExpression>(loc, name); + }; + //.......................................................... + // nrn_state : The temporal integration of state variables + //.......................................................... 
+ + // find the SOLVE statement + SolveExpression* solve_expression = nullptr; + for(auto& e: *(breakpoint->body())) { + solve_expression = e->is_solve_statement(); + if(solve_expression) break; + } + + // handle the case where there is no SOLVE in BREAKPOINT + if( solve_expression==nullptr ) { + warning( " there is no SOLVE statement, required to update the" + " state variables, in the BREAKPOINT block", + breakpoint->location()); + } + else { + // get the DERIVATIVE block + auto dblock = solve_expression->procedure(); + + // body refers to the currently empty body of the APIMethod that + // will hold the AST for the nrn_state function. + auto& body = api_state->body()->statements(); + + auto has_provided_integration_method = + solve_expression->method() == solverMethod::cnexp; + + // loop over the statements in the SOLVE block from the mod file + // put each statement into the new APIMethod, performing + // transformations if necessary. + for(auto& e : *(dblock->body())) { + if(auto ass = e->is_assignment()) { + auto lhs = ass->lhs(); + auto rhs = ass->rhs(); + if(auto deriv = lhs->is_derivative()) { + // Check that a METHOD was provided in the original SOLVE + // statment. We have to do this because it is possible + // to call SOLVE without a METHOD, in which case there should + // be no derivative expressions in the DERIVATIVE block. 
+ if(!has_provided_integration_method) { + error("The DERIVATIVE block has a derivative expression" + " but no METHOD was specified in the SOLVE statement", + deriv->location()); + return false; + } + + auto sym = deriv->symbol(); + auto name = deriv->name(); + + auto gating_vars = is_gating(rhs, name); + if(gating_vars.first && gating_vars.second) { + auto const& inf = gating_vars.second->spelling(); + auto const& rate = gating_vars.first->spelling(); + auto e_string = name + "=" + inf + + "+(" + name + "-" + inf + ")*exp(-dt/" + + rate + ")"; + auto stmt_update = Parser(e_string).parse_line_expression(); + body.emplace_back(std::move(stmt_update)); + continue; + } + else { + // create visitor for linear analysis + auto v = make_unique<ExpressionClassifierVisitor>(sym); + rhs->accept(v.get()); + + // quit if ODE is not linear + if( v->classify() != expressionClassification::linear ) { + error("unable to integrate nonlinear state ODEs", + rhs->location()); + return false; + } + + // the linear differential equation is of the form + // s' = a*s + b + // integration by separation of variables gives the following + // update function to integrate s for one time step dt + // s = -b/a + (s+b/a)*exp(a*dt) + // we are going to build this update function by + // 1. generating statements that define a_=a and ba_=b/a + // 2. 
generating statements that update the solution + + // statement : a_ = a + auto stmt_a = + binary_expression(Location(), + tok::eq, + id("a_"), + v->linear_coefficient()->clone()); + + // expression : b/a + auto expr_ba = + binary_expression(Location(), + tok::divide, + v->constant_term()->clone(), + id("a_")); + // statement : ba_ = b/a + auto stmt_ba = binary_expression(Location(), tok::eq, id("ba_"), std::move(expr_ba)); + + // the update function + auto e_string = name + " = -ba_ + " + "(" + name + " + ba_)*exp(a_*dt)"; + auto stmt_update = Parser(e_string).parse_line_expression(); + + // add declaration of local variables + body.emplace_back(Parser("LOCAL a_").parse_local()); + body.emplace_back(Parser("LOCAL ba_").parse_local()); + // add integration statements + body.emplace_back(std::move(stmt_a)); + body.emplace_back(std::move(stmt_ba)); + body.emplace_back(std::move(stmt_update)); + continue; + } + } + else { + body.push_back(e->clone()); + continue; + } + } + body.push_back(e->clone()); + } + } + + // perform semantic analysis + api_state->semantic(symbols_); + + //.......................................................... + // nrn_current : update contributions to currents + //.......................................................... + std::list<expression_ptr> block; + + // helper which tests a statement to see if it updates an ion + // channel variable. + auto is_ion_update = [] (Expression* e) { + if(auto a = e->is_assignment()) { + // semantic analysis has been performed on the original expression + // which ensures that the lhs is an identifier and a variable + if(auto sym = a->lhs()->is_identifier()->symbol()) { + // assume that a scalar stack variable is being used for + // the indexed value: i.e. 
the value is not cached + if(auto var = sym->is_local_variable()) { + return var->ion_channel(); + } + } + } + return ionKind::none; + }; + + // add statements that initialize the reduction variables + bool has_current_update = false; + for(auto& e: *(breakpoint->body())) { + // ignore solve and conductance statements + if(e->is_solve_statement()) continue; + if(e->is_conductance_statement()) continue; + + // add the expression + block.emplace_back(e->clone()); + + // we are updating an ionic current + // so keep track of current and conductance accumulation + auto channel = is_ion_update(e.get()); + if(channel != ionKind::none) { + auto lhs = e->is_assignment()->lhs()->is_identifier(); + auto rhs = e->is_assignment()->rhs(); + + // analyze the expression for linear terms + //auto v = make_unique<ExpressionClassifierVisitor>(symbols_["v"].get()); + auto v_symbol = breakpoint->scope()->find("v"); + auto v = make_unique<ExpressionClassifierVisitor>(v_symbol); + rhs->accept(v.get()); + + if(v->classify()==expressionClassification::linear) { + // add current update + if(has_current_update) { + block.emplace_back(Parser("current_ = current_ + " + lhs->name()).parse_line_expression()); + } + else { + block.emplace_back(Parser("current_ = " + lhs->name()).parse_line_expression()); + } + } + else { + error("current update functions must be a linear" + " function of v : " + rhs->to_string(), e->location()); + return false; + } + has_current_update = true; + } + } + if(has_current_update && kind()==moduleKind::point) { + block.emplace_back(Parser("current_ = 100. 
* current_ / area_").parse_line_expression()); + } + + auto v = make_unique<ConstantFolderVisitor>(); + for(auto& e : block) { + e->accept(v.get()); + } + + symbols_["nrn_current"] = + make_symbol<APIMethod>( + breakpoint->location(), "nrn_current", + std::vector<expression_ptr>(), + make_expression<BlockExpression>(breakpoint->location(), + std::move(block), false) + ); + symbols_["nrn_current"]->semantic(symbols_); + } + else { + error("a BREAKPOINT block is required", Location()); + return false; + } + + return status() == lexerStatus::happy; +} + +/// populate the symbol table with class scope variables +void Module::add_variables_to_symbols() { + // add reserved symbols (not v, because for some reason it has to be added + // by the user) + auto create_variable = [this] (const char* name, rangeKind rng, accessKind acc) { + auto t = new VariableExpression(Location(), name); + t->state(false); + t->linkage(linkageKind::local); + t->ion_channel(ionKind::none); + t->range(rng); + t->access(acc); + t->visibility(visibilityKind::global); + symbols_[name] = symbol_ptr{t}; + }; + + create_variable("t", rangeKind::scalar, accessKind::read); + create_variable("dt", rangeKind::scalar, accessKind::read); + + // add indexed variables to the table + auto create_indexed_variable = [this] + (std::string const& name, std::string const& indexed_name, + tok op, accessKind acc, ionKind ch, Location loc) + { + if(symbols_.count(name)) { + throw compiler_exception( + "trying to insert a symbol that already exists", + loc); + } + symbols_[name] = + make_symbol<IndexedVariable>(loc, name, indexed_name, acc, op, ch); + }; + + create_indexed_variable("current_", "vec_i", tok::plus, + accessKind::write, ionKind::none, Location()); + create_indexed_variable("v", "vec_v", tok::eq, + accessKind::read, ionKind::none, Location()); + create_indexed_variable("area_", "vec_area", tok::eq, + accessKind::read, ionKind::none, Location()); + + // add state variables + for(auto const &var : 
state_block()) { + VariableExpression *id = new VariableExpression(Location(), var); + + id->state(true); // set state to true + // state variables are private + // what about if the state variables is an ion concentration? + id->linkage(linkageKind::local); + id->visibility(visibilityKind::local); + id->ion_channel(ionKind::none); // no ion channel + id->range(rangeKind::range); // always a range + id->access(accessKind::readwrite); + + symbols_[var] = symbol_ptr{id}; + } + + // add the parameters + for(auto const& var : parameter_block()) { + auto name = var.name(); + if(name == "v") { // global voltage values + // ignore voltage, which is added as an indexed variable by default + continue; + } + VariableExpression *id = new VariableExpression(Location(), name); + + id->state(false); // never a state variable + id->linkage(linkageKind::local); + // parameters are visible to Neuron + id->visibility(visibilityKind::global); + id->ion_channel(ionKind::none); + // scalar by default, may later be upgraded to range + id->range(rangeKind::scalar); + id->access(accessKind::read); + + // check for 'special' variables + if(name == "celcius") { // global celcius parameter + id->linkage(linkageKind::external); + } + + // set default value if one was specified + if(var.value.size()) { + id->value(std::stod(var.value)); + } + + symbols_[name] = symbol_ptr{id}; + } + + // add the assigned variables + for(auto const& var : assigned_block()) { + auto name = var.name(); + if(name == "v") { // global voltage values + // ignore voltage, which is added as an indexed variable by default + continue; + } + VariableExpression *id = new VariableExpression(var.token.location, name); + + id->state(false); // never a state variable + id->linkage(linkageKind::local); + // local visibility by default + id->visibility(visibilityKind::local); + id->ion_channel(ionKind::none); // can change later + // ranges because these are assigned to in loop + id->range(rangeKind::range); + 
id->access(accessKind::readwrite); + + symbols_[name] = symbol_ptr{id}; + } + + //////////////////////////////////////////////////// + // parse the NEURON block data, and use it to update + // the variables in symbols_ + //////////////////////////////////////////////////// + // first the ION channels + // add ion channel variables + auto update_ion_symbols = [this, create_indexed_variable] + (Token const& tkn, accessKind acc, ionKind channel) + { + auto const& var = tkn.spelling; + + // add the ion variable's indexed shadow + if(has_symbol(var)) { + auto sym = symbols_[var].get(); + + // has the user declared a range/parameter with the same name? + if(sym->kind()!=symbolKind::indexed_variable) { + warning( + pprintf("the symbol % clashes with the ion channel variable," + " and will be ignored", yellow(var)), + sym->location() + ); + // erase symbol + symbols_.erase(var); + } + } + + create_indexed_variable(var, "ion_"+var, + acc==accessKind::read ? tok::eq : tok::plus, + acc, channel, tkn.location); + }; + + // check for nonspecific current + if( neuron_block().has_nonspecific_current() ) { + auto const& i = neuron_block().nonspecific_current; + update_ion_symbols(i, accessKind::write, ionKind::nonspecific); + } + + + for(auto const& ion : neuron_block().ions) { + for(auto const& var : ion.read) { + update_ion_symbols(var, accessKind::read, ion.kind()); + } + for(auto const& var : ion.write) { + update_ion_symbols(var, accessKind::write, ion.kind()); + } + } + + // then GLOBAL variables + for(auto const& var : neuron_block().globals) { + if(!symbols_[var.spelling]) { + error( yellow(var.spelling) + + " is declared as GLOBAL, but has not been declared in the" + + " ASSIGNED block", + var.location); + return; + } + auto& sym = symbols_[var.spelling]; + if(auto id = sym->is_variable()) { + id->visibility(visibilityKind::global); + } + else if (!sym->is_indexed_variable()){ + throw compiler_exception( + "unable to find symbol " + yellow(var.spelling) + " in symbols", + 
Location()); + } + } + + // then RANGE variables + for(auto const& var : neuron_block().ranges) { + if(!symbols_[var.spelling]) { + error( yellow(var.spelling) + + " is declared as RANGE, but has not been declared in the" + + " ASSIGNED or PARAMETER block", + var.location); + return; + } + auto& sym = symbols_[var.spelling]; + if(auto id = sym->is_variable()) { + id->range(rangeKind::range); + } + else if (!sym->is_indexed_variable()){ + throw compiler_exception( + "unable to find symbol " + yellow(var.spelling) + " in symbols", + var.location); + } + } +} + +bool Module::optimize() { + // how to structure the optimizer + // loop over APIMethods + // - apply optimization to each in turn + auto folder = make_unique<ConstantFolderVisitor>(); + for(auto &symbol : symbols_) { + auto kind = symbol.second->kind(); + BlockExpression* body; + if(kind == symbolKind::procedure) { + // we are only interested in true procedurs and APIMethods + auto proc = symbol.second->is_procedure(); + auto pkind = proc->kind(); + if(pkind == procedureKind::normal || pkind == procedureKind::api ) + body = symbol.second->is_procedure()->body(); + else + continue; + } + // for now don't look at functions + //else if(kind == symbolKind::function) { + // body = symbol.second.expression->is_function()->body(); + //} + else { + continue; + } + + ///////////////////////////////////////////////////////////////////// + // loop over folding and propogation steps until there are no changes + ///////////////////////////////////////////////////////////////////// + + // perform constant folding + for(auto& line : *body) { + line->accept(folder.get()); + } + + // preform expression simplification + // i.e. 
removing zeros/refactoring reciprocals/etc + + // perform constant propogation + + ///////////////////////////////////////////////////////////////////// + // remove dead local variables + ///////////////////////////////////////////////////////////////////// + } + + return true; +} + diff --git a/modcc/src/module.hpp b/modcc/src/module.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f427d870ebc1cb770e8e424349d4d140892b3f45 --- /dev/null +++ b/modcc/src/module.hpp @@ -0,0 +1,128 @@ +#pragma once + +#include <string> +#include <vector> + +#include "blocks.hpp" +#include "expression.hpp" + +// wrapper around a .mod file +class Module { +public : + using scope_type = Expression::scope_type; + using symbol_map = scope_type::symbol_map; + using symbol_ptr = scope_type::symbol_ptr; + + Module(std::string const& fname); + Module(std::vector<char> const& buffer); + + std::vector<char> const& buffer() const { + return buffer_; + } + + std::string const& file_name() const {return fname_;} + std::string const& name() const {return neuron_block_.name;} + + void title(const std::string& t) {title_ = t;} + std::string const& title() const {return title_;} + + NeuronBlock & neuron_block() {return neuron_block_;} + NeuronBlock const& neuron_block() const {return neuron_block_;} + + StateBlock & state_block() {return state_block_;} + StateBlock const& state_block() const {return state_block_;} + + UnitsBlock & units_block() {return units_block_;} + UnitsBlock const& units_block() const {return units_block_;} + + ParameterBlock & parameter_block() {return parameter_block_;} + ParameterBlock const& parameter_block() const {return parameter_block_;} + + AssignedBlock & assigned_block() {return assigned_block_;} + AssignedBlock const& assigned_block() const {return assigned_block_;} + + void neuron_block(NeuronBlock const &n) {neuron_block_ = n;} + void state_block (StateBlock const &s) {state_block_ = s;} + void units_block (UnitsBlock const &u) {units_block_ = 
u;} + void parameter_block (ParameterBlock const &p) {parameter_block_ = p;} + void assigned_block (AssignedBlock const &a) {assigned_block_ = a;} + + // access to the AST + std::vector<symbol_ptr>& procedures(); + std::vector<symbol_ptr>const& procedures() const; + + std::vector<symbol_ptr>& functions(); + std::vector<symbol_ptr>const& functions() const; + + symbol_map & symbols(); + symbol_map const& symbols() const; + + // error handling + void error(std::string const& msg, Location loc); + std::string const& error_string() { + return error_string_; + } + + lexerStatus status() const { + return status_; + } + + // warnings + void warning(std::string const& msg, Location loc); + bool has_warning() const { + return has_warning_; + } + bool has_error() const { + return status()==lexerStatus::error; + } + + moduleKind kind() const { + return kind_; + } + void kind(moduleKind k) { + kind_ = k; + } + + // perform semantic analysis + void add_variables_to_symbols(); + bool semantic(); + bool optimize(); +private : + moduleKind kind_; + std::string title_; + std::string fname_; + std::vector<char> buffer_; // character buffer loaded from file + + bool generate_initial_api(); + bool generate_current_api(); + bool generate_state_api(); + + // error handling + std::string error_string_; + lexerStatus status_ = lexerStatus::happy; + bool has_warning_ = false; + + // AST storage + std::vector<symbol_ptr> procedures_; + std::vector<symbol_ptr> functions_; + + // hash table for lookup of variable and call names + symbol_map symbols_; + + /// tests if symbol is defined + bool has_symbol(const std::string& name) { + return symbols_.find(name) != symbols_.end(); + } + /// tests if symbol is defined + bool has_symbol(const std::string& name, symbolKind kind) { + auto s = symbols_.find(name); + return s == symbols_.end() ? 
false : s->second->kind() == kind; + } + + // blocks + NeuronBlock neuron_block_; + StateBlock state_block_; + UnitsBlock units_block_; + ParameterBlock parameter_block_; + AssignedBlock assigned_block_; +}; diff --git a/modcc/src/parser.cpp b/modcc/src/parser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30dbb9ea3f72cbd98ae48b3a598275de3e59a3b8 --- /dev/null +++ b/modcc/src/parser.cpp @@ -0,0 +1,1309 @@ +#include <iostream> +#include <list> +#include <cstring> + +#include "constantfolder.hpp" +#include "parser.hpp" +#include "perfvisitor.hpp" +#include "token.hpp" +#include "util.hpp" + +// specialize on const char* for lazy evaluation of compile time strings +bool Parser::expect(tok tok, const char* str) { + if(tok==token_.type) { + return true; + } + + error( + strlen(str)>0 ? + str + : std::string("unexpected token ")+yellow(token_.spelling)); + + return false; +} + +bool Parser::expect(tok tok, std::string const& str) { + if(tok==token_.type) { + return true; + } + + error( + str.size()>0 ? + str + : std::string("unexpected token ")+yellow(token_.spelling)); + + return false; +} + +void Parser::error(std::string msg) { + std::string location_info = pprintf( + "%:% ", module_ ? module_->file_name() : "", token_.location); + if(status_==lexerStatus::error) { + // append to current string + error_string_ += "\n" + white(location_info) + "\n " +msg; + } + else { + error_string_ = white(location_info) + "\n " + msg; + status_ = lexerStatus::error; + } +} + +void Parser::error(std::string msg, Location loc) { + std::string location_info = pprintf( + "%:% ", module_ ? 
module_->file_name() : "", loc); + if(status_==lexerStatus::error) { + // append to current string + error_string_ += "\n" + green(location_info) + msg; + } + else { + error_string_ = green(location_info) + msg; + status_ = lexerStatus::error; + } +} + +Parser::Parser(Module& m, bool advance) +: Lexer(m.buffer()), + module_(&m) +{ + // prime the first token + get_token(); + + if(advance) { + parse(); + } +} + +Parser::Parser(std::string const& buf) +: Lexer(buf), + module_(nullptr) +{ + // prime the first token + get_token(); +} + +bool Parser::parse() { + // perform first pass to read the descriptive blocks and + // record the location of the verb blocks + while(token_.type!=tok::eof) { + switch(token_.type) { + case tok::title : + parse_title(); + break; + case tok::neuron : + parse_neuron_block(); + break; + case tok::state : + parse_state_block(); + break; + case tok::units : + parse_units_block(); + break; + case tok::parameter : + parse_parameter_block(); + break; + case tok::assigned : + parse_assigned_block(); + break; + // INITIAL, DERIVATIVE, PROCEDURE, NET_RECEIVE and BREAKPOINT blocks + // are all lowered to ProcedureExpression + case tok::net_receive: + case tok::breakpoint : + case tok::initial : + case tok::derivative : + case tok::procedure : + { + auto p = parse_procedure(); + if(!p) break; + module_->procedures().emplace_back(std::move(p)); + } + break; + case tok::function : + { + auto f = parse_function(); + if(!f) break; + module_->functions().emplace_back(std::move(f)); + } + break; + default : + error(pprintf("expected block type, found '%'", token_.spelling)); + break; + } + if(status() == lexerStatus::error) { + std::cerr << red("error: ") << error_string_ << std::endl; + return false; + } + } + + return true; +} + +// consume a comma separated list of identifiers +// NOTE: leaves the current location at begining of the last identifier in the list +// OK: empty list "" +// OK: list with one identifier "a" +// OK: list with mutiple 
identifier "a, b, c, d" +// BAD: list with keyword "a, b, else, d" +// list with trailing comma "a, b,\n" +// list with keyword "a, if, b" +std::vector<Token> Parser::comma_separated_identifiers() { + std::vector<Token> tokens; + int startline = location_.line; + // handle is an empty list at the end of a line + if(peek().location.line > startline) { + // this happens when scanning WRITE below: + // USEION k READ a, b WRITE + // leave to the caller to decide whether an empty list is an error + return tokens; + } + while(1) { + get_token(); + + // first check if a new line was encounterd + if(location_.line > startline) { + return tokens; + } + else if(token_.type == tok::identifier) { + tokens.push_back(token_); + } + else if(is_keyword(token_)) { + error(pprintf("found keyword '%', expected a variable name", token_.spelling)); + return tokens; + } + else if(token_.type == tok::number) { + error(pprintf("found number '%', expected a variable name", token_.spelling)); + return tokens; + } + else { + error(pprintf("found '%', expected a variable name", token_.spelling)); + return tokens; + } + + // look ahead to check for a comma. 
This approach ensures that the + // first token after the end of the list is not consumed + if( peek().type == tok::comma ) { + // load the comma + get_token(); + // assert that the list can't run off the end of a line + if(peek().location.line > startline) { + error("line can't end with a '"+yellow(",")+"'"); + return tokens; + } + } + else { + break; + } + } + get_token(); // prime the first token after the list + + return tokens; +} + +/* +NEURON { + THREADSAFE + SUFFIX KdShu2007 + USEION k WRITE ik READ xy + RANGE gkbar, ik, ek + GLOBAL minf, mtau, hinf, htau +} +*/ +void Parser::parse_neuron_block() { + NeuronBlock neuron_block; + + get_token(); + + // assert that the block starts with a curly brace + if(token_.type != tok::lbrace) { + error(pprintf("NEURON block must start with a curly brace {, found '%'", + token_.spelling)); + return; + } + + // initialize neuron block + neuron_block.threadsafe = false; + + // there are no use cases for curly brace in a NEURON block, so we don't + // have to count them we have to get the next token before entering the loop + // to handle the case of an empty block {} + get_token(); + while(token_.type!=tok::rbrace) { + switch(token_.type) { + case tok::threadsafe : + neuron_block.threadsafe = true; + get_token(); // consume THREADSAFE + break; + + case tok::suffix : + case tok::point_process : + neuron_block.kind = (token_.type==tok::suffix) ? 
moduleKind::density + : moduleKind::point; + + // set the modul kind + module_->kind(neuron_block.kind); + + get_token(); // consume SUFFIX / POINT_PROCESS + // assert that a valid name for the Neuron has been specified + if(token_.type != tok::identifier) { + error(pprintf("invalid name for SUFFIX, found '%'", token_.spelling)); + return; + } + neuron_block.name = token_.spelling; + + get_token(); // consume the name + break; + + // this will be a comma-separated list of identifiers + case tok::global : + // the ranges are a comma-seperated list of identifiers + { + auto identifiers = comma_separated_identifiers(); + // bail if there was an error reading the list + if(status_==lexerStatus::error) { + return; + } + for(auto const &id : identifiers) { + neuron_block.globals.push_back(id); + } + } + break; + + // this will be a comma-separated list of identifiers + case tok::range : + // the ranges are a comma-seperated list of identifiers + { + auto identifiers = comma_separated_identifiers(); + if(status_==lexerStatus::error) { // bail if there was an error reading the list + return; + } + for(auto const &id : identifiers) { + neuron_block.ranges.push_back(id); + } + } + break; + + case tok::useion : + { + IonDep ion; + // we have to parse the name of the ion first + get_token(); + // check this is an identifier token + if(token_.type != tok::identifier) { + error(pprintf("invalid name for an ion chanel '%'", + token_.spelling)); + return; + } + // check that the ion type is valid (insist on lower case?) 
+ if(!(token_.spelling == "k" || token_.spelling == "ca" || token_.spelling == "na")) { + error(pprintf("invalid ion type % must be on eof 'k' 'ca' or 'na'", + yellow(token_.spelling))); + return; + } + ion.name = token_.spelling; + get_token(); // consume the ion name + + // this loop ensures that we don't gobble any tokens past + // the end of the USEION clause + while(token_.type == tok::read || token_.type == tok::write) { + auto& target = (token_.type == tok::read) ? ion.read + : ion.write; + std::vector<Token> identifiers + = comma_separated_identifiers(); + // bail if there was an error reading the list + if(status_==lexerStatus::error) { + return; + } + for(auto const &id : identifiers) { + target.push_back(id); + } + } + // add the ion dependency to the NEURON block + neuron_block.ions.push_back(std::move(ion)); + } + break; + + case tok::nonspecific_current : + // Assume that there is one non-specific current per mechanism. + // It would be easy to extend this to multiple currents, + // however there are no mechanisms in the CoreNeuron repository + // that do this + { + get_token(); // consume NONSPECIFIC_CURRENT + + auto tok = token_; + + // parse the current name and check for errors + auto id = parse_identifier(); + if(status_==lexerStatus::error) { + return; + } + + // store the token with nonspecific current's name and location + neuron_block.nonspecific_current = tok; + } + break; + + // the parser encountered an invalid symbol + default : + error(pprintf("there was an invalid statement '%' in NEURON block", + token_.spelling)); + return; + } + } + + // copy neuron block into module + module_->neuron_block(neuron_block); + + // now we have a curly brace, so prime the next token + get_token(); +} + +void Parser::parse_state_block() { + StateBlock state_block; + + get_token(); + + // assert that the block starts with a curly brace + if(token_.type != tok::lbrace) { + error(pprintf("STATE block must start with a curly brace {, found '%'", 
token_.spelling)); + return; + } + + // there are no use cases for curly brace in a STATE block, so we don't have to count them + // we have to get the next token before entering the loop to handle the case of + // an empty block {} + get_token(); + while(token_.type!=tok::rbrace) { + if(token_.type != tok::identifier) { + error(pprintf("'%' is not a valid name for a state variable", token_.spelling)); + return; + } + state_block.state_variables.push_back(token_.spelling); + get_token(); + } + + // add this state block information to the module + module_->state_block(state_block); + + // now we have a curly brace, so prime the next token + get_token(); +} + +// scan a unit block +void Parser::parse_units_block() { + UnitsBlock units_block; + + get_token(); + + // assert that the block starts with a curly brace + if(token_.type != tok::lbrace) { + error(pprintf("UNITS block must start with a curly brace {, found '%'", token_.spelling)); + return; + } + + // there are no use cases for curly brace in a UNITS block, so we don't have to count them + get_token(); + while(token_.type!=tok::rbrace) { + // get the alias + std::vector<Token> lhs = unit_description(); + if( status_!=lexerStatus::happy ) return; + + // consume the '=' sign + if( token_.type!=tok::eq ) { + error(pprintf("expected '=', found '%'", token_.spelling)); + return; + } + + get_token(); // next token + + // get the units + std::vector<Token> rhs = unit_description(); + if( status_!=lexerStatus::happy ) return; + + // store the unit definition + units_block.unit_aliases.push_back({lhs, rhs}); + } + + // add this state block information to the module + module_->units_block(units_block); + + // now we have a curly brace, so prime the next token + get_token(); +} + +////////////////////////////////////////////////////// +// the parameter block describes variables that are +// to be used as parameters. Some are given values, +// others are simply listed, and some have units +// assigned to them. 
Here we want to get a list of the +// parameter names, along with values if given. +// We also store the token that describes the units +////////////////////////////////////////////////////// +void Parser::parse_parameter_block() { + ParameterBlock block; + + get_token(); + + // assert that the block starts with a curly brace + if(token_.type != tok::lbrace) { + error(pprintf("PARAMETER block must start with a curly brace {, found '%'", token_.spelling)); + return; + } + + // there are no use cases for curly brace in a UNITS block, so we don't have to count them + get_token(); + while(token_.type!=tok::rbrace && token_.type!=tok::eof) { + int line = location_.line; + Id parm; + + // read the parameter name + if(token_.type != tok::identifier) { + goto parm_error; + } + parm.token = token_; // save full token + + get_token(); + + // look for equality + if(token_.type==tok::eq) { + get_token(); // consume '=' + if(token_.type==tok::minus) { + parm.value = "-"; + get_token(); + } + if(token_.type != tok::number) { + goto parm_error; + } + parm.value += token_.spelling; // store value as a string + get_token(); + } + + // get the parameters + if(line==location_.line && token_.type == tok::lparen) { + parm.units = unit_description(); + if(status_ == lexerStatus::error) { + goto parm_error; + } + } + + block.parameters.push_back(parm); + } + + // errer if EOF before closeing curly brace + if(token_.type==tok::eof) { + error("PARAMETER block must have closing '}'"); + goto parm_error; + } + + get_token(); // consume closing brace + + module_->parameter_block(block); + + return; +parm_error: + // only write error message if one hasn't already been logged by the lexer + if(status_==lexerStatus::happy) { + error(pprintf("PARAMETER block unexpected symbol '%'", token_.spelling)); + } + return; +} + +void Parser::parse_assigned_block() { + AssignedBlock block; + + get_token(); + + // assert that the block starts with a curly brace + if(token_.type != tok::lbrace) { + 
error(pprintf("ASSIGNED block must start with a curly brace {, found '%'", token_.spelling)); + return; + } + + // there are no use cases for curly brace in an ASSIGNED block, so we don't have to count them + get_token(); + while(token_.type!=tok::rbrace && token_.type!=tok::eof) { + int line = location_.line; + std::vector<Token> variables; // we can have more than one variable on a line + + // the first token must be ... + if(token_.type != tok::identifier) { + goto ass_error; + } + // read all of the identifiers until we run out of identifiers or reach a new line + while(token_.type == tok::identifier && line == location_.line) { + variables.push_back(token_); + get_token(); + } + + // there are some parameters at the end of the line + if(line==location_.line && token_.type == tok::lparen) { + auto u = unit_description(); + if(status_ == lexerStatus::error) { + goto ass_error; + } + for(auto const& t : variables) { + block.parameters.push_back(Id(t, "", u)); + } + } + else { + for(auto const& t : variables) { + block.parameters.push_back(Id(t, "", {})); + } + } + } + + // errer if EOF before closeing curly brace + if(token_.type==tok::eof) { + error("ASSIGNED block must have closing '}'"); + goto ass_error; + } + + get_token(); // consume closing brace + + module_->assigned_block(block); + + return; +ass_error: + // only write error message if one hasn't already been logged by the lexer + if(status_==lexerStatus::happy) { + error(pprintf("ASSIGNED block unexpected symbol '%'", token_.spelling)); + } + return; +} + +std::vector<Token> Parser::unit_description() { + static const tok legal_tokens[] = {tok::identifier, tok::divide, tok::number}; + int startline = location_.line; + std::vector<Token> tokens; + + // chec that we start with a left parenthesis + if(token_.type != tok::lparen) + goto unit_error; + get_token(); + + while(token_.type != tok::rparen) { + // check for illegal tokens or a new line + if( !is_in(token_.type,legal_tokens) || startline < 
location_.line ) + goto unit_error; + + // add this token to the set + tokens.push_back(token_); + get_token(); + } + // remove trailing right parenthesis ')' + get_token(); + + return tokens; + +unit_error: + error(pprintf("incorrect unit description '%'", tokens)); + return tokens; +} + +// Returns a prototype expression for a function or procedure call +// Takes an optional argument that allows the user to specify the +// name of the prototype, which is used for prototypes where the name +// is implcitly defined (e.g. INITIAL and BREAKPOINT blocks) +expression_ptr Parser::parse_prototype(std::string name=std::string()) { + Token identifier = token_; + + if(name.size()) { + // we assume that the current token_ is still pointing at + // the keyword, i.e. INITIAL or BREAKPOINT + identifier.type = tok::identifier; + identifier.spelling = name; + } + + // load the parenthesis + get_token(); + + // check for an argument list enclosed in parenthesis (...) + // return a prototype with an empty argument list if not found + if( token_.type != tok::lparen ) { + //return make_expression<PrototypeExpression>(identifier.location, identifier.spelling, {}); + return expression_ptr{new PrototypeExpression(identifier.location, identifier.spelling, {})}; + } + + get_token(); // consume '(' + std::vector<Token> arg_tokens; + while(token_.type != tok::rparen) { + // check identifier + if(token_.type != tok::identifier) { + error( "expected a valid identifier, found '" + + yellow(token_.spelling) + "'"); + return nullptr; + } + + arg_tokens.push_back(token_); + + get_token(); // consume the identifier + + // look for a comma + if(!(token_.type == tok::comma || token_.type==tok::rparen)) { + error( "expected a comma or closing parenthesis, found '" + + yellow(token_.spelling) + "'"); + return nullptr; + } + + if(token_.type == tok::comma) { + get_token(); // consume ',' + } + } + + if(token_.type != tok::rparen) { + error("procedure argument list must have closing parenthesis ')'"); + 
return nullptr; + } + get_token(); // consume closing parenthesis + + // pack the arguments into LocalDeclarations + std::vector<expression_ptr> arg_expressions; + for(auto const& t : arg_tokens) { + arg_expressions.emplace_back(make_expression<ArgumentExpression>(t.location, t)); + } + + return make_expression<PrototypeExpression> + (identifier.location, identifier.spelling, std::move(arg_expressions)); +} + +void Parser::parse_title() { + std::string title; + int this_line = location().line; + + Token tkn = peek(); + while( tkn.location.line==this_line + && tkn.type!=tok::eof + && status_==lexerStatus::happy) + { + get_token(); + title += token_.spelling; + tkn = peek(); + } + + // set the module title + module_->title(title); + + // load next token + get_token(); +} + +/// parse a procedure +/// can handle both PROCEDURE and INITIAL blocks +/// an initial block is stored as a procedure with name 'initial' and empty argument list +symbol_ptr Parser::parse_procedure() { + expression_ptr p; + procedureKind kind = procedureKind::normal; + + switch( token_.type ) { + case tok::derivative: + kind = procedureKind::derivative; + get_token(); // consume keyword token + if( !expect( tok::identifier ) ) return nullptr; + p = parse_prototype(); + break; + case tok::procedure: + kind = procedureKind::normal; + get_token(); // consume keyword token + if( !expect( tok::identifier ) ) return nullptr; + p = parse_prototype(); + break; + case tok::initial: + kind = procedureKind::initial; + p = parse_prototype("initial"); + break; + case tok::breakpoint: + kind = procedureKind::breakpoint; + p = parse_prototype("breakpoint"); + break; + case tok::net_receive: + kind = procedureKind::net_receive; + p = parse_prototype("net_receive"); + break; + default: + // it is a compiler error if trying to parse_procedure() without + // having DERIVATIVE, PROCEDURE, INITIAL or BREAKPOINT keyword + throw compiler_exception( + "attempt to parser_procedure() without 
{DERIVATIVE,PROCEDURE,INITIAL,BREAKPOINT}", + location_); + } + if(p==nullptr) return nullptr; + + // check for opening left brace { + if(!expect(tok::lbrace)) return nullptr; + + // parse the body of the function + expression_ptr body = parse_block(false); + if(body==nullptr) return nullptr; + + auto proto = p->is_prototype(); + if(kind != procedureKind::net_receive) { + return make_symbol<ProcedureExpression> + (proto->location(), proto->name(), std::move(proto->args()), std::move(body), kind); + } + else { + return make_symbol<NetReceiveExpression> + (proto->location(), proto->name(), std::move(proto->args()), std::move(body)); + } +} + +symbol_ptr Parser::parse_function() { + get_token(); // consume FUNCTION token + + // check that a valid identifier name was specified by the user + if( !expect( tok::identifier ) ) return nullptr; + + // parse the prototype + auto p = parse_prototype(); + if(p==nullptr) return nullptr; + + // check for opening left brace { + if(!expect(tok::lbrace)) return nullptr; + + // parse the body of the function + auto body = parse_block(false); + if(body==nullptr) return nullptr; + + PrototypeExpression *proto = p->is_prototype(); + return make_symbol<FunctionExpression> + (proto->location(), proto->name(), std::move(proto->args()), std::move(body)); +} + +// this is the first port of call when parsing a new line inside a verb block +// it tests to see whether the expression is: +// :: LOCAL identifier +// :: expression +expression_ptr Parser::parse_statement() { + switch(token_.type) { + case tok::if_stmt : + return parse_if(); + break; + case tok::conductance : + return parse_conductance(); + case tok::solve : + return parse_solve(); + case tok::local : + return parse_local(); + case tok::identifier : + return parse_line_expression(); + case tok::initial : + // only used for INITIAL block in NET_RECEIVE + return parse_initial(); + default: + error(pprintf("unexpected token type % '%'", token_string(token_.type), token_.spelling)); + 
return nullptr; + } + return nullptr; +} + +expression_ptr Parser::parse_identifier() { + // save name and location of the identifier + auto id = make_expression<IdentifierExpression>(token_.location, token_.spelling); + + // consume identifier + get_token(); + + // return variable identifier + return id; +} + +expression_ptr Parser::parse_call() { + // save name and location of the identifier + Token idtoken = token_; + + // consume identifier + get_token(); + + // check for a function call + // assert this is so + if(token_.type != tok::lparen) { + throw compiler_exception( + "should not be parsing parse_call without trailing '('", + location_); + } + + std::vector<expression_ptr> args; + + // parse a function call + get_token(); // consume '(' + + while(token_.type != tok::rparen) { + auto e = parse_expression(); + if(!e) return e; + + args.emplace_back(std::move(e)); + + // reached the end of the argument list + if(token_.type == tok::rparen) break; + + // insist on a comma between arguments + if( !expect(tok::comma, "call arguments must be separated by ','") ) + return expression_ptr(); + get_token(); // consume ',' + } + + // check that we have a closing parenthesis + if(!expect(tok::rparen, "function call missing closing ')'") ) { + return expression_ptr(); + } + get_token(); // consume ')' + + return make_expression<CallExpression>(idtoken.location, idtoken.spelling, std::move(args)); +} + +// parse a full line expression, i.e. one of +// :: procedure call e.g. rates(v+0.01) +// :: assignment expression e.g. x = y + 3 +// to parse a subexpression, see parse_expression() +// proceeds by first parsing the LHS (which may be a variable or function call) +// then attempts to parse the RHS if +// 1. the lhs is not a procedure call +// 2. 
the operator that follows is = +expression_ptr Parser::parse_line_expression() { + int line = location_.line; + expression_ptr lhs; + Token next = peek(); + if(next.type == tok::lparen) { + lhs = parse_call(); + // we have to ensure that a procedure call is alone on the line + // to avoid : + // :: assigning to it e.g. foo() = x + 6 + // :: stray symbols coming after e.g. foo() + x + // We assume that foo is a procedure call, if it is an eroneous + // function call this has to be caught in the second pass. + // or optimized away with a warning + if(!lhs) return lhs; + if(location_.line == line && token_.type != tok::eof) { + error(pprintf( + "expected a new line after call expression, found '%'", + yellow(token_.spelling))); + return expression_ptr(); + } + return lhs ; + } else if(next.type == tok::prime) { + lhs = make_expression<DerivativeExpression>(location_, token_.spelling); + // consume both name and derivative operator + get_token(); + get_token(); + // a derivative statement must be followed by '=' + if(token_.type!=tok::eq) { + error("a derivative declaration must have an assignment of the "\ + "form\n x' = expression\n where x is a state variable"); + return expression_ptr(); + } + } else { + lhs = parse_unaryop(); + } + + if(!lhs) { // error + return lhs; + } + + // we parse a binary expression if followed by an operator + if(token_.type == tok::eq) { + Token op = token_; // save the '=' operator with location + get_token(); // consume the '=' operator + return parse_binop(std::move(lhs), op); + } else if(line == location_.line && token_.type != tok::eof){ + error(pprintf("expected an assignment '%' or new line, found '%'", + yellow("="), + yellow(token_.spelling))); + return nullptr; + } + + return lhs; +} + +expression_ptr Parser::parse_expression() { + auto lhs = parse_unaryop(); + + if(lhs==nullptr) { // error + return nullptr; + } + + // we parse a binary expression if followed by an operator + if( binop_precedence(token_.type)>0 ) { + 
if(token_.type==tok::eq) { + error("assignment '"+yellow("=")+"' not allowed in sub-expression"); + return nullptr; + } + Token op = token_; // save the operator + get_token(); // consume the operator + return parse_binop(std::move(lhs), op); + } + + return lhs; +} + +/// Parse a unary expression. +/// If called when the current node in the AST is not a unary expression the call +/// will be forwarded to parse_primary. This mechanism makes it possible to parse +/// all nodes in the expression using parse_unary, which simplifies the call sites +/// with either a primary or unary node is to be parsed. +/// It also simplifies parsing nested unary functions, e.g. x + - - y +expression_ptr Parser::parse_unaryop() { + expression_ptr e; + Token op = token_; + switch(token_.type) { + case tok::plus : + // plus sign is simply ignored + get_token(); // consume '+' + return parse_unaryop(); + case tok::minus : + get_token(); // consume '-' + e = parse_unaryop(); // handle recursive unary + if(!e) return nullptr; + return unary_expression(token_.location, op.type, std::move(e)); + case tok::exp : + case tok::sin : + case tok::cos : + case tok::log : + get_token(); // consume operator (exp, sin, cos or log) + if(token_.type!=tok::lparen) { + error( "missing parenthesis after call to " + + yellow(op.spelling) ); + return nullptr; + } + e = parse_unaryop(); // handle recursive unary + if(!e) return nullptr; + return unary_expression(token_.location, op.type, std::move(e)); + default : + return parse_primary(); + } + return nullptr; +} + +/// parse a primary expression node +/// expects one of : +/// :: number +/// :: identifier +/// :: call +/// :: parenthesis expression (parsed recursively) +expression_ptr Parser::parse_primary() { + switch(token_.type) { + case tok::number: + return parse_number(); + case tok::identifier: + if( peek().type == tok::lparen ) { + return parse_call(); + } + return parse_identifier(); + case tok::lparen: + return parse_parenthesis_expression(); + 
default: // fall through to return nullptr at end of function + error( pprintf( "unexpected token '%' in expression", + yellow(token_.spelling) )); + } + + return nullptr; +} + +expression_ptr Parser::parse_parenthesis_expression() { + // never call unless at start of parenthesis + + if(token_.type!=tok::lparen) { + throw compiler_exception( + "attempt to parse a parenthesis_expression() without opening parenthesis", + location_); + } + + get_token(); // consume '(' + + auto e = parse_expression(); + + // check for closing parenthesis ')' + if( !e || !expect(tok::rparen) ) return nullptr; + + get_token(); // consume ')' + + return e; +} + +expression_ptr Parser::parse_number() { + auto e = make_expression<NumberExpression>(token_.location, token_.spelling); + + get_token(); // consume the number + + return e; +} + +expression_ptr Parser::parse_binop(expression_ptr&& lhs, Token op_left) { + // only way out of the loop below is by return: + // :: return with nullptr on error + // :: return when loop runs out of operators + // i.e. if(pp<0) + // :: return when recursion applied to remainder of expression + // i.e. 
if(p_op>p_left) + while(1) { + // get precedence of the left operator + auto p_left = binop_precedence(op_left.type); + + auto e = parse_unaryop(); + if(!e) return nullptr; + + auto op = token_; + auto p_op = binop_precedence(op.type); + if(operator_associativity(op.type)==associativityKind::right) { + p_op += 1; + } + + // if no binop, parsing of expression is finished with (op_left lhs e) + if(p_op < 0) { + return binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(e)); + } + + get_token(); // consume op + if(p_op > p_left) { + auto rhs = parse_binop(std::move(e), op); + if(!rhs) return nullptr; + return binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(rhs)); + } + + lhs = binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(e)); + op_left = op; + } + throw compiler_exception( + "parse_binop() : fell out of recursive parse descent", + location_); + return nullptr; +} + +/// parse a local variable definition +/// a local variable definition is a line with the form +/// LOCAL x +/// where x is a valid identifier name +expression_ptr Parser::parse_local() { + Location loc = location_; + + get_token(); // consume LOCAL + + // create local expression stub + auto e = make_expression<LocalDeclaration>(loc); + if(!e) return e; + + // add symbols + while(1) { + if(!expect(tok::identifier)) return nullptr; + + // try adding variable name to list + if(!e->is_local_declaration()->add_variable(token_)) { + error(e->error_message()); + return nullptr; + } + get_token(); // consume identifier + + // look for comma that indicates continuation of the variable list + if(token_.type == tok::comma) { + get_token(); + } else { + break; + } + } + + return e; +} + +/// parse a SOLVE statement +/// a SOLVE statement specifies a procedure and a method +/// SOLVE procedure METHOD method +/// we also support SOLVE statements without a METHOD clause +/// for backward compatability with performance hacks that +/// are 
implemented in some key mod files (i.e. Prob* synapses) +expression_ptr Parser::parse_solve() { + int line = location_.line; + Location loc = location_; // solve location for expression + std::string name; + solverMethod method; + + get_token(); // consume the SOLVE keyword + + if(token_.type != tok::identifier) goto solve_statement_error; + + name = token_.spelling; // save name of procedure + get_token(); // consume the procedure identifier + + if(token_.type != tok::method) { // no method was provided + method = solverMethod::none; + } + else { + get_token(); // consume the METHOD keyword + if(token_.type != tok::cnexp) goto solve_statement_error; + method = solverMethod::cnexp; + + get_token(); // consume the method description + } + // check that the rest of the line was empty + if(line == location_.line) { + if(token_.type != tok::eof) goto solve_statement_error; + } + + return make_expression<SolveExpression>(loc, name, method); + +solve_statement_error: + error( "SOLVE statements must have the form\n" + " SOLVE x METHOD cnexp\n" + " or\n" + " SOLVE x\n" + "where 'x' is the name of a DERIVATIVE block", loc); + return nullptr; +} + +/// parse a CONDUCTANCE statement +/// a CONDUCTANCE statement specifies a variable and a channel +/// where the channel is optional +/// CONDUCTANCE name USEION channel +/// CONDUCTANCE name +expression_ptr Parser::parse_conductance() { + int line = location_.line; + Location loc = location_; // solve location for expression + std::string name; + ionKind channel; + + get_token(); // consume the CONDUCTANCE keyword + + if(token_.type != tok::identifier) goto conductance_statement_error; + + name = token_.spelling; // save name of variable + get_token(); // consume the variable identifier + + if(token_.type != tok::useion) { // no ion channel was provided + // we set nonspecific not none because ionKind::none marks + // any variable that is not associated with an ion channel + channel = ionKind::nonspecific; + } + else { + 
get_token(); // consume the USEION keyword + if(token_.type!=tok::identifier) goto conductance_statement_error; + + if (token_.spelling == "na") channel = ionKind::Na; + else if(token_.spelling == "ca") channel = ionKind::Ca; + else if(token_.spelling == "k") channel = ionKind::K; + else goto conductance_statement_error; + + get_token(); // consume the ion channel type + } + // check that the rest of the line was empty + if(line == location_.line) { + if(token_.type != tok::eof) goto conductance_statement_error; + } + + return make_expression<ConductanceExpression>(loc, name, channel); + +conductance_statement_error: + error( "CONDUCTANCE statements must have the form\n" + " CONDUCTANCE g USEION channel\n" + " or\n" + " CONDUCTANCE g\n" + "where 'g' is the name of a variable, and 'channel' is the type of ion channel", loc); + return nullptr; +} + +expression_ptr Parser::parse_if() { + Token if_token = token_; + get_token(); // consume 'if' + + if(!expect(tok::lparen)) return nullptr; + + // parse the conditional + auto cond = parse_parenthesis_expression(); + if(!cond) return nullptr; + + // parse the block of the true branch + auto true_branch = parse_block(true); + if(!true_branch) return nullptr; + + // parse the false branch if there is an else + expression_ptr false_branch; + if(token_.type == tok::else_stmt) { + get_token(); // consume else + + // handle 'else if {}' case recursively + if(token_.type == tok::if_stmt) { + false_branch = parse_if(); + } + // we have a closing 'else {}' + else if(token_.type == tok::lbrace) { + false_branch = parse_block(true); + } + else { + error("expect either '"+yellow("if")+"' or '"+yellow("{")+" after else"); + return nullptr; + } + } + + return make_expression<IfExpression>(if_token.location, std::move(cond), std::move(true_branch), std::move(false_branch)); +} + +// takes a flag indicating whether the block is at procedure/function body, +// or lower. 
Can be used to check for illegal statements inside a nested block, +// e.g. LOCAL declarations. +expression_ptr Parser::parse_block(bool is_nested) { + // blocks have to be enclosed in curly braces {} + expect(tok::lbrace); + + get_token(); // consume '{' + + // save the location of the first statement as the starting point for the block + Location block_location = token_.location; + + std::list<expression_ptr> body; + while(token_.type != tok::rbrace) { + auto e = parse_statement(); + if(!e) return e; + + if(is_nested) { + if(e->is_local_declaration()) { + error("LOCAL variable declarations are not allowed inside a nested scope"); + return nullptr; + } + } + + body.emplace_back(std::move(e)); + } + + if(token_.type != tok::rbrace) { + error("could not find closing '" + yellow("}") + + "' for else statement that started at " + + ::to_string(block_location)); + return nullptr; + } + get_token(); // consume closing '}' + + return make_expression<BlockExpression>(block_location, std::move(body), is_nested); +} + +expression_ptr Parser::parse_initial() { + // has to start with INITIAL: error in compiler implementaion otherwise + expect(tok::initial); + + // save the location of the first statement as the starting point for the block + Location block_location = token_.location; + + get_token(); // consume 'INITIAL' + + if(!expect(tok::lbrace)) return nullptr; + get_token(); // consume '{' + + std::list<expression_ptr> body; + while(token_.type != tok::rbrace) { + auto e = parse_statement(); + if(!e) return e; + + // disallow variable declarations in an INITIAL block + if(e->is_local_declaration()) { + error("LOCAL variable declarations are not allowed inside a nested scope"); + return nullptr; + } + + body.emplace_back(std::move(e)); + } + + if(token_.type != tok::rbrace) { + error("could not find closing '" + yellow("}") + + "' for else statement that started at " + + ::to_string(block_location)); + return nullptr; + } + get_token(); // consume closing '}' + + return 
make_expression<InitialBlock>(block_location, std::move(body)); +} + diff --git a/modcc/src/parser.hpp b/modcc/src/parser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6768749df46e51181ade05d0d6acb1d4dddbe7c6 --- /dev/null +++ b/modcc/src/parser.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include <memory> +#include <string> + +#include "expression.hpp" +#include "lexer.hpp" +#include "module.hpp" + +class Parser : public Lexer { +public: + + explicit Parser(Module& m, bool advance=true); + Parser(std::string const&); + bool parse(); + + expression_ptr parse_prototype(std::string); + expression_ptr parse_statement(); + expression_ptr parse_identifier(); + expression_ptr parse_number(); + expression_ptr parse_call(); + expression_ptr parse_expression(); + expression_ptr parse_primary(); + expression_ptr parse_parenthesis_expression(); + expression_ptr parse_line_expression(); + expression_ptr parse_binop(expression_ptr&&, Token); + expression_ptr parse_unaryop(); + expression_ptr parse_local(); + expression_ptr parse_solve(); + expression_ptr parse_conductance(); + expression_ptr parse_block(bool); + expression_ptr parse_initial(); + expression_ptr parse_if(); + + symbol_ptr parse_procedure(); + symbol_ptr parse_function(); + + std::string const& error_message() { + return error_string_; + } + +private: + Module *module_; + + // functions for parsing descriptive blocks + // these are called in the first pass, and do not + // construct any AST information + void parse_neuron_block(); + void parse_state_block(); + void parse_units_block(); + void parse_parameter_block(); + void parse_assigned_block(); + void parse_title(); + + std::vector<Token> comma_separated_identifiers(); + std::vector<Token> unit_description(); + + /// build the identifier list + void add_variables_to_symbols(); + + // helper function for logging errors + void error(std::string msg); + void error(std::string msg, Location loc); + + // disable default and copy assignment + 
/// Running totals of the floating point operations found in an expression
/// tree, one counter per kind of operation.
struct FlopAccumulator {
    int add=0;
    int neg=0;
    int mul=0;
    int div=0;
    int exp=0;
    int sin=0;
    int cos=0;
    int log=0;
    int pow=0;

    /// Set every counter, including pow, back to zero.
    void reset() {
        // assign from a value-initialised object so that no counter is missed
        *this = FlopAccumulator();
    }
};
visit(NegUnaryExpression *e) override { + // this is a simplification + // we would have to perform analysis of parent nodes to ensure that + // the negation actually translates into an operation + // :: x - - x // not counted + // :: x * -exp(3) // should be counted + // :: x / -exp(3) // should be counted + // :: x / - -exp(3)// should be counted only once + flops.neg++; + } + void visit(ExpUnaryExpression *e) override { + e->expression()->accept(this); + flops.exp++; + } + void visit(LogUnaryExpression *e) override { + e->expression()->accept(this); + flops.log++; + } + void visit(CosUnaryExpression *e) override { + e->expression()->accept(this); + flops.cos++; + } + void visit(SinUnaryExpression *e) override { + e->expression()->accept(this); + flops.sin++; + } + + //////////////////////////////////////////////////// + // specializations for each type of binary expression + // leave UnaryExpression throw an exception, to catch + // any missed specializations + //////////////////////////////////////////////////// + void visit(BinaryExpression *e) override { + // there must be a specialization of the flops counter for every type + // of binary expression: if we get here there has been an attempt to + // visit a binary expression for which no visitor is implemented + throw compiler_exception( + "PerfVisitor unable to analyse binary expression " + e->to_string(), + e->location()); + } + void visit(AssignmentExpression *e) override { + e->rhs()->accept(this); + } + void visit(AddBinaryExpression *e) override { + e->lhs()->accept(this); + e->rhs()->accept(this); + flops.add++; + } + void visit(SubBinaryExpression *e) override { + e->lhs()->accept(this); + e->rhs()->accept(this); + flops.add++; + } + void visit(MulBinaryExpression *e) override { + e->lhs()->accept(this); + e->rhs()->accept(this); + flops.mul++; + } + void visit(DivBinaryExpression *e) override { + e->lhs()->accept(this); + e->rhs()->accept(this); + flops.div++; + } + void visit(PowBinaryExpression *e) 
override { + e->lhs()->accept(this); + e->rhs()->accept(this); + flops.pow++; + } + + FlopAccumulator flops; + + std::string print() const { + std::stringstream s; + + s << flops << std::endl; + + return s.str(); + } +}; + +class MemOpVisitor : public Visitor { +public: + void visit(Expression *e) override {} + + // traverse the statements in an API method + void visit(APIMethod *e) override { + for(auto& expression : *(e->body())) { + expression->accept(this); + } + + // create local indexed views + for(auto &symbol : e->scope()->locals()) { + auto var = symbol.second->is_local_variable(); + if(var->is_indexed()) { + if(var->is_read()) { + indexed_reads_.insert(var); + } + else { + indexed_writes_.insert(var); + } + } + } + } + + // traverse the statements in a procedure + void visit(ProcedureExpression *e) override { + for(auto& expression : *(e->body())) { + expression->accept(this); + } + } + + // traverse the statements in a function + void visit(FunctionExpression *e) override { + for(auto& expression : *(e->body())) { + expression->accept(this); + } + } + + void visit(UnaryExpression *e) override { + e->expression()->accept(this); + } + + void visit(BinaryExpression *e) override { + e->lhs()->accept(this); + e->rhs()->accept(this); + } + + void visit(AssignmentExpression *e) override { + // handle the write on the lhs as a special case + auto symbol = e->lhs()->is_identifier()->symbol(); + if(!symbol) { + throw compiler_exception( + " attempt to look up name of identifier for which no symbol_ yet defined", + e->lhs()->location()); + } + switch (symbol->kind()) { + case symbolKind::variable : + vector_writes_.insert(symbol); + break; + case symbolKind::indexed_variable : + indexed_writes_.insert(symbol); + default : + break; + } + + // let the visitor implementation handle the reads + e->rhs()->accept(this); + } + + void visit(IdentifierExpression* e) override { + auto symbol = e->symbol(); + if(!symbol) { + throw compiler_exception( + " attempt to look up 
name of identifier for which no symbol_ yet defined", + e->location()); + } + switch (symbol->kind()) { + case symbolKind::variable : + if(symbol->is_variable()->is_range()) { + vector_reads_.insert(symbol); + } + break; + case symbolKind::indexed_variable : + indexed_reads_.insert(symbol); + default : + break; + } + } + + std::string print() const { + std::stringstream s; + + auto ir = indexed_reads_.size(); + auto vr = vector_reads_.size(); + auto iw = indexed_writes_.size(); + auto vw = vector_writes_.size(); + + auto w = std::setw(8); + s << " " << w << "read" << w << "write" << w << "total" << std::endl; + s << "indexed " << w << ir << w << iw << w << ir + iw << std::endl; + s << "vector " << w << vr << w << vw << w << vr + vw << std::endl; + s << "total " << w << vr+ir << w << vw +iw << w << vr + vw + ir +iw << std::endl; + + return s.str(); + } + +private: + std::set<Symbol*> indexed_reads_; + std::set<Symbol*> vector_reads_; + std::set<Symbol*> indexed_writes_; + std::set<Symbol*> vector_writes_; +}; + diff --git a/modcc/src/scope.hpp b/modcc/src/scope.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e81044f717f9e118f53370d1019a312c0fc8dc0 --- /dev/null +++ b/modcc/src/scope.hpp @@ -0,0 +1,127 @@ +#pragma once + +#include "util.hpp" + +#include <memory> +#include <string> +#include <unordered_map> + +// Scope is templated to avoid circular compilation issues. +// When performing semantic analysis of expressions via traversal of the AST +// each node in the AST has a reference to a Scope. This leads to circular +// dependencies, where Symbol nodes refer to Scopes which contain Symbols. +// Using a template means that we can defer Scope definition until after +// the Symbol type defined in expression.h has been defined. 
+template <typename Symbol> +class Scope { +public: + using symbol_type = Symbol; + using symbol_ptr = std::unique_ptr<Symbol>; + using symbol_map = std::unordered_map<std::string, symbol_ptr>; + + Scope(symbol_map& s); + ~Scope() {}; + symbol_type* add_local_symbol(std::string const& name, symbol_ptr s); + symbol_type* find(std::string const& name); + symbol_type* find_local(std::string const& name); + symbol_type* find_global(std::string const& name); + std::string to_string() const; + + symbol_map& locals(); + symbol_map* globals(); + +private: + symbol_map* global_symbols_=nullptr; + symbol_map local_symbols_; +}; + +template<typename Symbol> +Scope<Symbol>::Scope(symbol_map &s) + : global_symbols_(&s) +{} + +template<typename Symbol> +Symbol* +Scope<Symbol>::add_local_symbol( std::string const& name, + typename Scope<Symbol>::symbol_ptr s) +{ + // check to see if the symbol already exists + if( local_symbols_.find(name) != local_symbols_.end() ) { + return nullptr; + } + + // add symbol to list + local_symbols_[name] = std::move(s); + + return local_symbols_[name].get(); +} + +template<typename Symbol> +Symbol* +Scope<Symbol>::find(std::string const& name) { + auto local = find_local(name); + return local ? 
local : find_global(name); +} + +template<typename Symbol> +Symbol* +Scope<Symbol>::find_local(std::string const& name) { + // search in local symbols + auto local = local_symbols_.find(name); + + if(local != local_symbols_.end()) { + return local->second.get(); + } + + return nullptr; +} + +template<typename Symbol> +Symbol* +Scope<Symbol>::find_global(std::string const& name) { + // search in global symbols + if( global_symbols_ ) { + auto global = global_symbols_->find(name); + + if(global != global_symbols_->end()) { + return global->second.get(); + } + } + + return nullptr; +} + +template<typename Symbol> +std::string +Scope<Symbol>::to_string() const { + std::string s; + char buffer[16]; + + s += blue("Scope") + "\n"; + s += blue(" global :\n"); + for(auto& sym : *global_symbols_) { + snprintf(buffer, 16, "%-15s", sym.first.c_str()); + s += " " + yellow(buffer); + } + s += "\n"; + s += blue(" local :\n"); + for(auto& sym : local_symbols_) { + snprintf(buffer, 16, "%-15s", sym.first.c_str()); + s += " " + yellow(buffer); + } + + return s; +} + +template<typename Symbol> +typename Scope<Symbol>::symbol_map& +Scope<Symbol>::locals() { + return local_symbols_; +} + +template<typename Symbol> +typename Scope<Symbol>::symbol_map* +Scope<Symbol>::globals() { + return global_symbols_; +} + diff --git a/modcc/src/textbuffer.cpp b/modcc/src/textbuffer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5c8af1df5df85a1f72e241c0b3f461892a65edaa --- /dev/null +++ b/modcc/src/textbuffer.cpp @@ -0,0 +1,49 @@ +#include "textbuffer.hpp" + +/****************************************************************************** + TextBuffer +******************************************************************************/ +TextBuffer& TextBuffer::add_gutter() { + text_ << gutter_; + return *this; +} +void TextBuffer::add_line(std::string const& line) { + text_ << gutter_ << line << std::endl; +} +void TextBuffer::add_line() { + text_ << std::endl; +} +void 
TextBuffer::end_line(std::string const& line) { + text_ << line << std::endl; +} +void TextBuffer::end_line() { + text_ << std::endl; +} + +std::string TextBuffer::str() const { + return text_.str(); +} + +void TextBuffer::set_gutter(int width) { + indent_ = width; + gutter_ = std::string(indent_, ' '); +} + +void TextBuffer::increase_indentation() { + indent_ += indentation_width_; + if(indent_<0) { + indent_=0; + } + gutter_ = std::string(indent_, ' '); +} +void TextBuffer::decrease_indentation() { + indent_ -= indentation_width_; + if(indent_<0) { + indent_=0; + } + gutter_ = std::string(indent_, ' '); +} + +std::stringstream& TextBuffer::text() { + return text_; +} diff --git a/modcc/src/textbuffer.hpp b/modcc/src/textbuffer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cdb3fad3daacece05c39f88a7b5e13f0c95aa779 --- /dev/null +++ b/modcc/src/textbuffer.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include <limits> +#include <sstream> +#include <string> + +class TextBuffer { +public: + TextBuffer() { + text_.precision(std::numeric_limits<double>::max_digits10); + } + TextBuffer& add_gutter(); + void add_line(std::string const& line); + void add_line(); + void end_line(std::string const& line); + void end_line(); + + std::string str() const; + + void set_gutter(int width); + + void increase_indentation(); + void decrease_indentation(); + std::stringstream &text(); + +private: + + int indent_ = 0; + const int indentation_width_=4; + std::string gutter_ = ""; + std::stringstream text_; +}; + +template <typename T> +TextBuffer& operator<< (TextBuffer& buffer, T const& v) { + buffer.text() << v; + + return buffer; +} + diff --git a/modcc/src/token.cpp b/modcc/src/token.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1216c91287a7be1df4a0311c65c8436bdb4fb1cb --- /dev/null +++ b/modcc/src/token.cpp @@ -0,0 +1,161 @@ +#include <mutex> + +#include "token.hpp" + +// lookup table used for checking if an identifier matches a keyword 
+std::unordered_map<std::string, tok> keyword_map; + +// for stringifying a token type +std::map<tok, std::string> token_map; + +std::mutex mutex; + +struct Keyword { + const char *name; + tok type; +}; + +struct TokenString { + const char *name; + tok token; +}; + +static Keyword keywords[] = { + {"TITLE", tok::title}, + {"NEURON", tok::neuron}, + {"UNITS", tok::units}, + {"PARAMETER", tok::parameter}, + {"ASSIGNED", tok::assigned}, + {"STATE", tok::state}, + {"BREAKPOINT", tok::breakpoint}, + {"DERIVATIVE", tok::derivative}, + {"PROCEDURE", tok::procedure}, + {"FUNCTION", tok::function}, + {"INITIAL", tok::initial}, + {"NET_RECEIVE", tok::net_receive}, + {"UNITSOFF", tok::unitsoff}, + {"UNITSON", tok::unitson}, + {"SUFFIX", tok::suffix}, + {"NONSPECIFIC_CURRENT", tok::nonspecific_current}, + {"USEION", tok::useion}, + {"READ", tok::read}, + {"WRITE", tok::write}, + {"RANGE", tok::range}, + {"LOCAL", tok::local}, + {"SOLVE", tok::solve}, + {"THREADSAFE", tok::threadsafe}, + {"GLOBAL", tok::global}, + {"POINT_PROCESS", tok::point_process}, + {"METHOD", tok::method}, + {"if", tok::if_stmt}, + {"else", tok::else_stmt}, + {"cnexp", tok::cnexp}, + {"exp", tok::exp}, + {"sin", tok::sin}, + {"cos", tok::cos}, + {"log", tok::log}, + {"CONDUCTANCE", tok::conductance}, + {nullptr, tok::reserved}, +}; + +static TokenString token_strings[] = { + {"=", tok::eq}, + {"+", tok::plus}, + {"-", tok::minus}, + {"*", tok::times}, + {"/", tok::divide}, + {"^", tok::pow}, + {"!", tok::lnot}, + {"<", tok::lt}, + {"<=", tok::lte}, + {">", tok::gt}, + {">=", tok::gte}, + {"==", tok::equality}, + {"!=", tok::ne}, + {",", tok::comma}, + {"'", tok::prime}, + {"{", tok::lbrace}, + {"}", tok::rbrace}, + {"(", tok::lparen}, + {")", tok::rparen}, + {"identifier", tok::identifier}, + {"number", tok::number}, + {"TITLE", tok::title}, + {"NEURON", tok::neuron}, + {"UNITS", tok::units}, + {"PARAMETER", tok::parameter}, + {"ASSIGNED", tok::assigned}, + {"STATE", tok::state}, + {"BREAKPOINT", 
tok::breakpoint}, + {"DERIVATIVE", tok::derivative}, + {"PROCEDURE", tok::procedure}, + {"FUNCTION", tok::function}, + {"INITIAL", tok::initial}, + {"NET_RECEIVE", tok::net_receive}, + {"UNITSOFF", tok::unitsoff}, + {"UNITSON", tok::unitson}, + {"SUFFIX", tok::suffix}, + {"NONSPECIFIC_CURRENT", tok::nonspecific_current}, + {"USEION", tok::useion}, + {"READ", tok::read}, + {"WRITE", tok::write}, + {"RANGE", tok::range}, + {"LOCAL", tok::local}, + {"SOLVE", tok::solve}, + {"THREADSAFE", tok::threadsafe}, + {"GLOBAL", tok::global}, + {"POINT_PROCESS", tok::point_process}, + {"METHOD", tok::method}, + {"if", tok::if_stmt}, + {"else", tok::else_stmt}, + {"eof", tok::eof}, + {"exp", tok::exp}, + {"log", tok::log}, + {"cos", tok::cos}, + {"sin", tok::sin}, + {"cnexp", tok::cnexp}, + {"CONDUCTANCE", tok::conductance}, + {"error", tok::reserved}, +}; + +/// set up lookup tables for converting between tokens and their +/// string representations +void initialize_token_maps() { + // ensure that tables are initialized only once + std::lock_guard<std::mutex> g(mutex); + + if(keyword_map.size()==0) { + ////////////////////// + /// keyword map + ////////////////////// + for(int i = 0; keywords[i].name!=nullptr; ++i) { + keyword_map.insert( {keywords[i].name, keywords[i].type} ); + } + + ////////////////////// + // token map + ////////////////////// + int i; + for(i = 0; token_strings[i].token!=tok::reserved; ++i) { + token_map.insert( {token_strings[i].token, token_strings[i].name} ); + } + // insert the last token: tok::reserved + token_map.insert( {token_strings[i].token, token_strings[i].name} ); + } +} + +std::string token_string(tok token) { + auto pos = token_map.find(token); + return pos==token_map.end() ? 
std::string("<unknown token>") : pos->second; +} + +bool is_keyword(Token const& t) { + for(Keyword *k=keywords; k->name!=nullptr; ++k) + if(t.type == k->type) + return true; + return false; +} + +std::ostream& operator<< (std::ostream& os, Token const& t) { + return os << "<<" << token_string(t.type) << ", " << t.spelling << ", " << t.location << ">>"; +} diff --git a/modcc/src/token.hpp b/modcc/src/token.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8e97e47f9fa8a46bed23913a4762a90324bf60e2 --- /dev/null +++ b/modcc/src/token.hpp @@ -0,0 +1,110 @@ +#pragma once + +#include <string> +#include <map> +#include <unordered_map> + +#include "location.hpp" + +enum class tok { + eof, // end of file + + ///////////////////////////// + // symbols + ///////////////////////////// + // = + - * / ^ + eq, plus, minus, times, divide, pow, + // comparison + lnot, // ! named logical not, to avoid clash with C++ not keyword + lt, // < + lte, // <= + gt, // > + gte, // >= + equality,// == + ne, // != + + // , ' + comma, prime, + + // { } + lbrace, rbrace, + // ( ) + lparen, rparen, + + // variable/function names + identifier, + + // numbers + number, + + ///////////////////////////// + // keywords + ///////////////////////////// + // block keywoards + title, + neuron, units, parameter, + assigned, state, breakpoint, + derivative, procedure, initial, function, + net_receive, + + // keywoards inside blocks + unitsoff, unitson, + suffix, nonspecific_current, useion, + read, write, + range, local, + solve, method, + threadsafe, global, + point_process, + + // unary operators + exp, sin, cos, log, + + // logical keywords + if_stmt, else_stmt, // add _stmt to avoid clash with c++ keywords + + // solver methods + cnexp, + + conductance, + + reserved, // placeholder for generating keyword lookup +}; + +// what is in a token? 
+// tok indicating type of token +// information about its location +struct Token { + // the spelling string contains the text of the token as it was written + // in the input file + // type = tok::number : spelling = "3.1415" (e.g.) + // type = tok::identifier : spelling = "foo_bar" (e.g.) + // type = tok::plus : spelling = "+" (always) + // type = tok::if : spelling = "if" (always) + std::string spelling; + tok type; + Location location; + + Token(tok tok, std::string const& sp, Location loc=Location(0,0)) + : spelling(sp), + type(tok), + location(loc) + {} + + Token() + : spelling(""), + type(tok::reserved), + location(Location()) + {}; +}; + +// lookup table used for checking if an identifier matches a keyword +extern std::unordered_map<std::string, tok> keyword_map; + +// for stringifying a token type +extern std::map<tok, std::string> token_map; + +void initialize_token_maps(); +std::string token_string(tok token); +bool is_keyword(Token const& t); +std::ostream& operator<< (std::ostream& os, Token const& t); + diff --git a/modcc/src/util.hpp b/modcc/src/util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..00f323f5375637efa2dff0bf9519adf874a7a6ad --- /dev/null +++ b/modcc/src/util.hpp @@ -0,0 +1,136 @@ +#pragma once + +#include <exception> +#include <memory> +#include <sstream> +#include <vector> + +// is thing in list? +template <typename T, int N> +bool is_in(T thing, const T (&list)[N]) { + for(auto const& item : list) { + if(thing==item) { + return true; + } + } + return false; +} + +inline std::string pprintf(const char *s) { + std::string errstring; + while(*s) { + if(*s == '%' && s[1]!='%') { + // instead of throwing an exception, replace with ?? + //throw std::runtime_error("pprintf: the number of arguments did not match the format "); + errstring += "<?>"; + } + else { + errstring += *s; + } + ++s; + } + return errstring; +} + +// variadic printf for easy error messages +template <typename T, typename ... 
Args> +std::string pprintf(const char *s, T value, Args... args) { + std::string errstring; + while(*s) { + if(*s == '%' && s[1]!='%') { + std::stringstream str; + str << value; + errstring += str.str(); + errstring += pprintf(++s, args...); + return errstring; + } + else { + errstring += *s; + ++s; + } + } + return errstring; +} + +template <typename T> +std::string to_string(T val) { + std::stringstream str; + str << val; + return str.str(); +} + +//'\e[1;31m' # Red +//'\e[1;32m' # Green +//'\e[1;33m' # Yellow +//'\e[1;34m' # Blue +//'\e[1;35m' # Purple +//'\e[1;36m' # Cyan +//'\e[1;37m' # White +enum class stringColor {white, red, green, blue, yellow, purple, cyan}; + +#define COLOR_PRINTING +#ifdef COLOR_PRINTING +inline std::string colorize(std::string const& s, stringColor c) { + switch(c) { + case stringColor::white : + return "\033[1;37m" + s + "\033[0m"; + case stringColor::red : + return "\033[1;31m" + s + "\033[0m"; + case stringColor::green : + return "\033[1;32m" + s + "\033[0m"; + case stringColor::blue : + return "\033[1;34m" + s + "\033[0m"; + case stringColor::yellow: + return "\033[1;33m" + s + "\033[0m"; + case stringColor::purple: + return "\033[1;35m" + s + "\033[0m"; + case stringColor::cyan : + return "\033[1;36m" + s + "\033[0m"; + } + return s; +} +#else +inline std::string colorize(std::string const& s, stringColor c) { + return s; +} +#endif + +// helpers for inline printing +inline std::string red(std::string const& s) { + return colorize(s, stringColor::red); +} +inline std::string green(std::string const& s) { + return colorize(s, stringColor::green); +} +inline std::string yellow(std::string const& s) { + return colorize(s, stringColor::yellow); +} +inline std::string blue(std::string const& s) { + return colorize(s, stringColor::blue); +} +inline std::string purple(std::string const& s) { + return colorize(s, stringColor::purple); +} +inline std::string cyan(std::string const& s) { + return colorize(s, stringColor::cyan); +} +inline 
std::string white(std::string const& s) { + return colorize(s, stringColor::white); +} + +template <typename T> +std::ostream& operator<< (std::ostream& os, std::vector<T> const& V) { + os << "["; + for(auto it = V.begin(); it!=V.end(); ++it) { // ugly loop, pretty printing + os << *it << (it==V.end()-1 ? "" : " "); + } + return os << "]"; +} + +// just because we aren't using C++14, doesn't mean we shouldn't go +// without make_unique +template <typename T, typename... Args> +std::unique_ptr<T> make_unique(Args&&... args) { + return std::unique_ptr<T>(new T(std::forward<Args>(args) ...)); +} + diff --git a/modcc/src/visitor.hpp b/modcc/src/visitor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4a32e5fa8f0bde47ba071058be7f5d601c4d561d --- /dev/null +++ b/modcc/src/visitor.hpp @@ -0,0 +1,64 @@ +#pragma once + +#include "error.hpp" +#include "expression.hpp" +#include "util.hpp" + +/// visitor base class +/// The visitors for all AST nodes throw an exception +/// by default, with node types calling the default visitor for a parent +/// For example, all BinaryExpression types call the visitor defined for +/// BinaryExpression, so by overriding just the base implementation, all of the +/// children get that implementation for free, which might be useful for some +/// use cases. 
+/// +/// heavily inspired by the DMD D compiler : github.com/D-Programming-Language/dmd +class Visitor { +public: + virtual void visit(Expression *e) { + throw compiler_exception("unimplemented visitor", Location()); + } + virtual void visit(Symbol *e) { visit((Expression*) e); } + virtual void visit(LocalVariable *e) { visit((Expression*) e); } + virtual void visit(IdentifierExpression *e) { visit((Expression*) e); } + virtual void visit(NumberExpression *e) { visit((Expression*) e); } + virtual void visit(LocalDeclaration *e) { visit((Expression*) e); } + virtual void visit(ArgumentExpression *e) { visit((Expression*) e); } + virtual void visit(PrototypeExpression *e) { visit((Expression*) e); } + virtual void visit(CallExpression *e) { visit((Expression*) e); } + virtual void visit(VariableExpression *e) { visit((Expression*) e); } + virtual void visit(IndexedVariable *e) { visit((Expression*) e); } + virtual void visit(FunctionExpression *e) { visit((Expression*) e); } + virtual void visit(IfExpression *e) { visit((Expression*) e); } + virtual void visit(SolveExpression *e) { visit((Expression*) e); } + virtual void visit(DerivativeExpression *e) { visit((Expression*) e); } + + virtual void visit(ProcedureExpression *e) { visit((Expression*) e); } + virtual void visit(NetReceiveExpression *e) { visit((ProcedureExpression*) e); } + virtual void visit(APIMethod *e) { visit((Expression*) e); } + + virtual void visit(BlockExpression *e) { visit((Expression*) e); } + virtual void visit(InitialBlock *e) { visit((BlockExpression*) e); } + + virtual void visit(UnaryExpression *e) { + throw compiler_exception("unimplemented visitor (UnaryExpression)", Location()); + } + virtual void visit(NegUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(ExpUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(LogUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(CosUnaryExpression *e) { visit((UnaryExpression*) e); 
} + virtual void visit(SinUnaryExpression *e) { visit((UnaryExpression*) e); } + + virtual void visit(BinaryExpression *e) { + throw compiler_exception("unimplemented visitor (BinaryExpression)", Location()); + } + virtual void visit(AssignmentExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(AddBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(SubBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(MulBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(DivBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(PowBinaryExpression *e) { visit((BinaryExpression*) e); } + + virtual ~Visitor() {}; +}; +