diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt index 2398b12fe4b2f722f870afcf241500e9887d47fc..612488e8f2a0449d7584e5ab3b9df832949ccd6c 100644 --- a/modcc/CMakeLists.txt +++ b/modcc/CMakeLists.txt @@ -1,5 +1,6 @@ set(MODCC_SOURCES astmanip.cpp + cexpr_emit.cpp constantfolder.cpp cprinter.cpp cudaprinter.cpp diff --git a/modcc/cexpr_emit.cpp b/modcc/cexpr_emit.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b1e8f49d0f5c5fcccc7756506e673f5c16098dd --- /dev/null +++ b/modcc/cexpr_emit.cpp @@ -0,0 +1,112 @@ +#include <ostream> +#include <unordered_map> + +#include "cexpr_emit.hpp" +#include "error.hpp" + +void CExprEmitter::emit_as_call(const char* sub, Expression* e) { + out_ << sub << '('; + e->accept(this); + out_ << ')'; +} + +void CExprEmitter::emit_as_call(const char* sub, Expression* e1, Expression* e2) { + out_ << sub << '('; + e1->accept(this); + out_ << ", "; + e2->accept(this); + out_ << ')'; +} + +void CExprEmitter::visit(NumberExpression* e) { + out_ << " " << e->value(); +} + +void CExprEmitter::visit(UnaryExpression* e) { + // Place a space in front of minus sign to avoid invalid + // expressions of the form: (v[i]--67) + static std::unordered_map<tok, const char*> unaryop_tbl = { + {tok::minus, " -"}, + {tok::exp, "exp"}, + {tok::cos, "cos"}, + {tok::sin, "sin"}, + {tok::log, "log"} + }; + + if (!unaryop_tbl.count(e->op())) { + throw compiler_exception( + "CExprEmitter: unsupported unary operator "+token_string(e->op()), e->location()); + } + + const char* op_spelling = unaryop_tbl.at(e->op()); + Expression* inner = e->expression(); + + // No need to use parenthesis for unary minus if inner expression is + // not binary. + if (e->op()==tok::minus && !inner->is_binary()) { + out_ << op_spelling; + inner->accept(this); + } + else { + emit_as_call(op_spelling, inner); + } +} + +void CExprEmitter::visit(AssignmentExpression* e) { + e->lhs()->accept(this); + out_ << " = "; + e->rhs()->accept(this); +} + +void CExprEmitter::visit(PowBinaryExpression* e) { + emit_as_call("std::pow", e->lhs(), e->rhs()); +} + +void CExprEmitter::visit(BinaryExpression *e) { + static std::unordered_map<tok, const char*> binop_tbl = { + {tok::minus, "-"}, + {tok::plus, "+"}, + {tok::times, "*"}, + {tok::divide, "/"}, + {tok::lt, "<"}, + {tok::lte, "<="}, + {tok::gt, ">"}, + {tok::gte, ">="}, + {tok::equality, "=="} + }; + + if (!binop_tbl.count(e->op())) { + throw compiler_exception( + "CExprEmitter: unsupported binary operator "+token_string(e->op()), e->location()); + } + + const char* op_spelling = binop_tbl.at(e->op()); + associativityKind assoc = Lexer::operator_associativity(e->op()); + int op_prec = Lexer::binop_precedence(e->op()); + + auto need_paren = [op_prec](Expression* subexpr, bool assoc_side) -> bool { + if (auto b = subexpr->is_binary()) { + int sub_prec = Lexer::binop_precedence(b->op()); + return sub_prec<op_prec || (!assoc_side && sub_prec==op_prec); + } + return false; + }; + + auto lhs = e->lhs(); + if (need_paren(lhs, assoc==associativityKind::left)) { + emit_as_call("", lhs); + } + else { + lhs->accept(this); + } + + out_ << op_spelling; + + auto rhs = e->rhs(); + if (need_paren(rhs, assoc==associativityKind::right)) { + emit_as_call("", rhs); + } + else { + rhs->accept(this); + } +} diff --git a/modcc/cexpr_emit.hpp b/modcc/cexpr_emit.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c9096a0a3a9794eff344134dda246032da104b62 --- /dev/null +++ b/modcc/cexpr_emit.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include "expression.hpp" +#include "visitor.hpp" + +// Common functionality for generating source from binary expressions +// as C expressions. + +class CExprEmitter: public Visitor { +public: + CExprEmitter(std::ostream& out, Visitor* fallback): + out_(out), fallback_(fallback) + {} + + virtual void visit(Expression* e) { e->accept(fallback_); } + + virtual void visit(UnaryExpression *e) override; + virtual void visit(BinaryExpression *e) override; + virtual void visit(AssignmentExpression *e) override; + virtual void visit(PowBinaryExpression *e) override; + virtual void visit(NumberExpression *e) override; + +protected: + std::ostream& out_; + Visitor* fallback_; + + void emit_as_call(const char* sub, Expression*); + void emit_as_call(const char* sub, Expression*, Expression*); +}; + +inline void cexpr_emit(Expression* e, std::ostream& out, Visitor* fallback) { + CExprEmitter renderer(out, fallback); + e->accept(&renderer); +} diff --git a/modcc/cprinter.cpp b/modcc/cprinter.cpp index ebfe5d9d799c6933b074161af5b7bcd624a9e092..4163403216e0cc57effcf50d90626e65cc4dbeba 100644 --- a/modcc/cprinter.cpp +++ b/modcc/cprinter.cpp @@ -2,6 +2,7 @@ #include <string> #include <unordered_set> +#include "cexpr_emit.hpp" #include "cprinter.hpp" #include "lexer.hpp" #include "options.hpp" @@ -544,7 +545,7 @@ void CPrinter::visit(LocalVariable *e) { } void CPrinter::visit(NumberExpression *e) { - text_ << " " << e->value(); + cexpr_emit(e, text_.text(), this); } void CPrinter::visit(IdentifierExpression *e) { @@ -567,41 +568,7 @@ void CPrinter::visit(CellIndexedVariable *e) { } void CPrinter::visit(UnaryExpression *e) { - auto b = (e->expression()->is_binary()!=nullptr); - switch(e->op()) { - case tok::minus : - // place a space in front of minus sign to avoid invalid - // expressions of the form : (v[i]--67) - if(b) text_ << " -("; - else text_ << " -"; - e->expression()->accept(this); - if(b) text_ << ")"; - return; - case tok::exp : - text_ << "exp("; - e->expression()->accept(this); - text_ << ")"; - return; - case tok::cos : - text_ << "cos("; - e->expression()->accept(this); - text_ << ")"; - return; - case tok::sin : - text_ << "sin("; - e->expression()->accept(this); - text_ << ")"; - return; - case tok::log : - text_ << "log("; - e->expression()->accept(this); - text_ << ")"; - return; - default : - throw compiler_exception( - "CPrinter unsupported unary operator " + yellow(token_string(e->op())), - e->location()); - } + cexpr_emit(e, text_.text(), this); } void CPrinter::visit(BlockExpression *e) { @@ -943,72 +910,7 @@ void CPrinter::visit(CallExpression *e) { text_ << ")"; } -void CPrinter::visit(AssignmentExpression *e) { - e->lhs()->accept(this); - text_ << " = "; - e->rhs()->accept(this); -} - -void CPrinter::visit(PowBinaryExpression *e) { - text_ << "std::pow("; - e->lhs()->accept(this); - text_ << ", "; - e->rhs()->accept(this); - text_ << ")"; -} - void CPrinter::visit(BinaryExpression *e) { - auto pop = parent_op_; - // TODO unit tests for parenthesis and binops - bool use_brackets = - Lexer::binop_precedence(pop) > Lexer::binop_precedence(e->op()) - || (pop==tok::divide && e->op()==tok::times); - parent_op_ = e->op(); - - auto lhs = e->lhs(); - auto rhs = e->rhs(); - if(use_brackets) { - text_ << "("; - } - lhs->accept(this); - switch(e->op()) { - case tok::minus : - text_ << "-"; - break; - case tok::plus : - text_ << "+"; - break; - case tok::times : - text_ << "*"; - break; - case tok::divide : - text_ << "/"; - break; - case tok::lt : - text_ << "<"; - break; - case tok::lte : - text_ << "<="; - break; - case tok::gt : - text_ << ">"; - break; - case tok::gte : - text_ << ">="; - break; - case tok::equality : - text_ << "=="; - break; - default : - throw compiler_exception( - "CPrinter unsupported binary operator " + yellow(token_string(e->op())), - e->location()); - } - rhs->accept(this); - if(use_brackets) { - text_ << ")"; - } - - // reset parent precedence - parent_op_ = pop; + cexpr_emit(e, text_.text(), this); } + diff --git a/modcc/cprinter.hpp b/modcc/cprinter.hpp index 0a254bc569fbc7d7209c0767d45612b96cf6fea7..8c8a5721757a770ec8fc749b619cb33f490c7d48 100644 --- a/modcc/cprinter.hpp +++ b/modcc/cprinter.hpp @@ -14,8 +14,6 @@ public: virtual void visit(Expression *e) override; virtual void visit(UnaryExpression *e) override; virtual void visit(BinaryExpression *e) override; - virtual void visit(AssignmentExpression *e) override; - virtual void visit(PowBinaryExpression *e) override; virtual void visit(NumberExpression *e) override; virtual void visit(VariableExpression *e) override; virtual void visit(Symbol *e) override; @@ -62,7 +60,6 @@ protected: void print_APIMethod_unoptimized(APIMethod* e); Module *module_ = nullptr; - tok parent_op_ = tok::eq; TextBuffer text_; bool optimize_ = false; bool aliased_output_ = false; @@ -121,7 +118,7 @@ protected: } bool is_point_process() { - return module_->kind() == moduleKind::point; + return module_ && module_->kind() == moduleKind::point; } std::vector<LocalVariable*> aliased_vars(APIMethod* e); diff --git a/modcc/cudaprinter.cpp b/modcc/cudaprinter.cpp index 461d69308a1ccd4234e94a569ffbcc67b6c4354f..cea5f9317cd0a6e3f4730e161c8bf2a08cc3d5a9 100644 --- a/modcc/cudaprinter.cpp +++ b/modcc/cudaprinter.cpp @@ -2,6 +2,7 @@ #include <string> #include <unordered_set> +#include "cexpr_emit.hpp" #include "cudaprinter.hpp" #include "lexer.hpp" #include "options.hpp" @@ -645,7 +646,7 @@ void CUDAPrinter::visit(LocalDeclaration *e) { } void CUDAPrinter::visit(NumberExpression *e) { - buffer() << " " << e->value(); + cexpr_emit(e, buffer().text(), this); } void CUDAPrinter::visit(IdentifierExpression *e) { @@ -708,43 +709,7 @@ void CUDAPrinter::visit(LocalVariable *e) { } void CUDAPrinter::visit(UnaryExpression *e) { - auto b = (e->expression()->is_binary()!=nullptr); - switch(e->op()) { - case tok::minus : - // place a space in front of minus sign to avoid invalid - // expressions of the form : (v[i]--67) - // use parenthesis if expression is a binop, otherwise - // -(v+2) becomes -v+2 - if(b) buffer() << " -("; - else buffer() << " -"; - e->expression()->accept(this); - if(b) buffer() << ")"; - return; - case tok::exp : - buffer() << "exp("; - e->expression()->accept(this); - buffer() << ")"; - return; - case tok::cos : - buffer() << "cos("; - e->expression()->accept(this); - buffer() << ")"; - return; - case tok::sin : - buffer() << "sin("; - e->expression()->accept(this); - buffer() << ")"; - return; - case tok::log : - buffer() << "log("; - e->expression()->accept(this); - buffer() << ")"; - return; - default : - throw compiler_exception( - "CUDAPrinter unsupported unary operator " + yellow(token_string(e->op())), - e->location()); - } + cexpr_emit(e, buffer().text(), this); } void CUDAPrinter::visit(BlockExpression *e) { @@ -1039,73 +1004,6 @@ void CUDAPrinter::visit(CallExpression *e) { buffer() << ")"; } -void CUDAPrinter::visit(AssignmentExpression *e) { - e->lhs()->accept(this); - buffer() << " = "; - e->rhs()->accept(this); -} - -void CUDAPrinter::visit(PowBinaryExpression *e) { - buffer() << "std::pow("; - e->lhs()->accept(this); - buffer() << ", "; - e->rhs()->accept(this); - buffer() << ")"; -} - void CUDAPrinter::visit(BinaryExpression *e) { - auto pop = parent_op_; - // TODO unit tests for parenthesis and binops - bool use_brackets = - Lexer::binop_precedence(pop) > Lexer::binop_precedence(e->op()) - || (pop==tok::divide && e->op()==tok::times); - parent_op_ = e->op(); - - - auto lhs = e->lhs(); - auto rhs = e->rhs(); - if(use_brackets) { - buffer() << "("; - } - lhs->accept(this); - switch(e->op()) { - case tok::minus : - buffer() << "-"; - break; - case tok::plus : - buffer() << "+"; - break; - case tok::times : - buffer() << "*"; - break; - case tok::divide : - buffer() << "/"; - break; - case tok::lt : - buffer() << "<"; - break; - case tok::lte : - buffer() << "<="; - break; - case tok::gt : - buffer() << ">"; - break; - case tok::gte : - buffer() << ">="; - break; - case tok::equality : - buffer() << "=="; - break; - default : - throw compiler_exception( - "CUDAPrinter unsupported binary operator " + yellow(token_string(e->op())), - e->location()); - } - rhs->accept(this); - if(use_brackets) { - buffer() << ")"; - } - - // reset parent precedence - parent_op_ = pop; + cexpr_emit(e, buffer().text(), this); } diff --git a/modcc/cudaprinter.hpp b/modcc/cudaprinter.hpp index 35750720f5de07b69aed648d5158ff0bfae7c50a..d016f23609866e32e7c6170d62c1b1fb89d626a5 100644 --- a/modcc/cudaprinter.hpp +++ b/modcc/cudaprinter.hpp @@ -14,8 +14,6 @@ public: void visit(Expression *e) override; void visit(UnaryExpression *e) override; void visit(BinaryExpression *e) override; - void visit(AssignmentExpression *e) override; - void visit(PowBinaryExpression *e) override; void visit(NumberExpression *e) override; void visit(VariableExpression *e) override; @@ -44,6 +42,11 @@ public: return interface_.str(); } + // public for testing purposes: + void set_buffer(TextBuffer& buf) { + current_buffer_ = &buf; + } + private: bool is_input(Symbol *s) { @@ -112,10 +115,6 @@ private: TextBuffer impl_interface_; TextBuffer* current_buffer_; - void set_buffer(TextBuffer& buf) { - current_buffer_ = &buf; - } - TextBuffer& buffer() { if (!current_buffer_) { throw std::runtime_error("CUDAPrinter buffer must be set via CUDAPrinter::set_buffer() before accessing via CUDAPrinter::buffer()."); diff --git a/modcc/token.hpp b/modcc/token.hpp index 95f7c461f368775f4bc1c67d6538f0c15e506503..fcafd2e1d47930859fc31f498488fbf3118409cc 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -77,6 +77,16 @@ enum class tok { reserved, // placeholder for generating keyword lookup }; +namespace std { + // note: necessary before C++14 (refer: lwg dr#2148). + template <> + struct hash<tok> { + std::size_t operator()(const tok& x) const { + return std::hash<int>()(static_cast<int>(x)); + } + }; +} + // what is in a token? // tok indicating type of token // information about its location diff --git a/tests/modcc/CMakeLists.txt b/tests/modcc/CMakeLists.txt index 681839de853f21a71c8534e1eb520c7f96d83d7c..0416812aee8bc8abebdd96d04bc2d03a3084fbb9 100644 --- a/tests/modcc/CMakeLists.txt +++ b/tests/modcc/CMakeLists.txt @@ -6,7 +6,7 @@ set(MODCC_TEST_SOURCES test_msparse.cpp test_optimization.cpp test_parser.cpp - #test_printers.cpp + test_printers.cpp test_removelocals.cpp test_symdiff.cpp test_symge.cpp diff --git a/tests/modcc/test_printers.cpp b/tests/modcc/test_printers.cpp index dad3a71911de27e243112678b82698d5b792ad51..855d9d69c5d6788c2b220bb41acbc1eee0e17d9b 100644 --- a/tests/modcc/test_printers.cpp +++ b/tests/modcc/test_printers.cpp @@ -1,55 +1,108 @@ +#include <regex> +#include <string> + #include "test.hpp" #include "cprinter.hpp" +#include "cudaprinter.hpp" +#include "expression.hpp" +#include "textbuffer.hpp" + +struct testcase { + const char* source; + const char* expected; +}; + +static std::string strip(std::string text) { + // Strip all spaces, except when between two minus symbols, where instead + // they should be replaced by a single space: + // + // 1. Replace whitespace with two spaces. + // 2. Replace '- -' with '- -'. + // 3. Replace ' ' with ''. + + static std::regex rx1("\\s+"); + static std::regex rx2("- -"); + static std::regex rx3(" "); + + text = std::regex_replace(text, rx1, " "); + text = std::regex_replace(text, rx2, "- -"); + text = std::regex_replace(text, rx3, ""); + + return text; +} -using scope_type = Scope<Symbol>; -using symbol_map = scope_type::symbol_map; -using symbol_ptr = Scope<Symbol>::symbol_ptr; - -TEST(CPrinter, statement) { - std::vector<const char*> expressions = - { -"y=x+3", -"y=y^z", -"y=exp(x/2 + 3)", +TEST(scalar_printer, statement) { + std::vector<testcase> testcases = { + {"y=x+3", "y=x+3"}, + {"y=y^z", "y=std::pow(y,z)"}, + {"y=exp((x/2) + 3)", "y=exp(x/2+3)"}, + {"z=a/b/c", "z=a/b/c"}, + {"z=a/(b/c)", "z=a/(b/c)"}, + {"z=(a*b)/c", "z=a*b/c"}, + {"z=a-(b+c)", "z=a-(b+c)"}, + {"z=(a>0)<(b>0)", "z=a>0<(b>0)"}, + {"z=a- -2", "z=a- -2"} }; // create a scope that contains the symbols used in the tests Scope<Symbol>::symbol_map globals; - globals["x"] = make_symbol<Symbol>(Location(), "x", symbolKind::local); - globals["y"] = make_symbol<Symbol>(Location(), "y", symbolKind::local); - globals["z"] = make_symbol<Symbol>(Location(), "z", symbolKind::local); - auto scope = std::make_shared<Scope<Symbol>>(globals); - for(auto const& expression : expressions) { - auto e = parse_line_expression(expression); + for (auto var: {"x", "y", "z", "a", "b", "c"}) { + scope->add_local_symbol(var, make_symbol<LocalVariable>(Location(), var, localVariableKind::local)); + } - // sanity check the compiler - EXPECT_NE(e, nullptr); - if( e==nullptr ) continue; + for (const auto& tc: testcases) { + auto e = parse_line_expression(tc.source); + ASSERT_TRUE(e); e->semantic(scope); - auto v = make_unique<CPrinter>(); - e->accept(v.get()); -#ifdef VERBOSE_TEST - std::cout << e->to_string() << std::endl; - << " :--: " << v->text() << std::endl; -#endif + { + SCOPED_TRACE("CPrinter"); + auto printer = make_unique<CPrinter>(); + e->accept(printer.get()); + std::string text = printer->text(); + + verbose_print(e->to_string(), " :--: ", text); + EXPECT_EQ(strip(tc.expected), strip(text)); + } + + { + SCOPED_TRACE("CUDAPrinter"); + TextBuffer buf; + auto printer = make_unique<CUDAPrinter>(); + printer->set_buffer(buf); + + e->accept(printer.get()); + std::string text = buf.str(); + + verbose_print(e->to_string(), " :--: ", text); + EXPECT_EQ(strip(tc.expected), strip(text)); + } } } TEST(CPrinter, proc) { - std::vector<const char*> expressions = - { -"PROCEDURE trates(v) {\n" -" LOCAL k\n" -" minf=1-1/(1+exp((v-k)/k))\n" -" hinf=1/(1+exp((v-k)/k))\n" -" mtau = 0.6\n" -" htau = 1500\n" -"}" + std::vector<testcase> testcases = { + { + "PROCEDURE trates(v) {\n" + " LOCAL k\n" + " minf=1-1/(1+exp((v-k)/k))\n" + " hinf=1/(1+exp((v-k)/k))\n" + " mtau = 0.6\n" + " htau = 1500\n" + "}" + , + "void trates(int i_, value_type v) {\n" + "value_type k;\n" + "minf[i_] = 1-1/(1+exp((v-k)/k));\n" + "hinf[i_] = 1/(1+exp((v-k)/k));\n" + "mtau[i_] = 0.6;\n" + "htau[i_] = 1500;\n" + "}" + } }; // create a scope that contains the symbols used in the tests @@ -60,23 +113,20 @@ TEST(CPrinter, proc) { globals["htau"] = make_symbol<VariableExpression>(Location(), "htau"); globals["v"] = make_symbol<VariableExpression>(Location(), "v"); - for(auto const& expression : expressions) { - auto e = symbol_ptr{parse_procedure(expression)->is_symbol()}; + for (const auto& tc: testcases) { + expression_ptr e = parse_procedure(tc.source); + ASSERT_TRUE(e->is_symbol()); - // sanity check the compiler - EXPECT_NE(e, nullptr); + auto procname = e->is_symbol()->name(); + auto& proc = (globals[procname] = symbol_ptr(e.release()->is_symbol())); - if( e==nullptr ) continue; - - globals["trates"] = std::move(e); - - e->semantic(globals); + proc->semantic(globals); auto v = make_unique<CPrinter>(); - e->accept(v.get()); + proc->accept(v.get()); + + verbose_print(proc->to_string()); + verbose_print(" :--: ", v->text()); -#ifdef VERBOSE_TEST - std::cout << e->to_string() << std::endl; - << " :--: " << v->text() << std::endl; -#endif + EXPECT_EQ(strip(tc.expected), strip(v->text())); } }