From 74616108b38b8deb65df810bba05907bbc32c9cb Mon Sep 17 00:00:00 2001
From: Nora Abi Akar <nora.abiakar@gmail.com>
Date: Mon, 29 Jul 2019 18:38:40 +0200
Subject: [PATCH] Add CONSTANT block support for modcc (#825)

Addresses #824.

* Add modcc support for `CONSTANT` blocks in NMODL, subject to the following constraints:
    - Any identifier declared in the `CONSTANT` block may only be used after that declaration (including within the `CONSTANT` block itself).
    - Units in the `CONSTANT` block are parsed but not processed.
    - Values assigned to identifiers inside the `CONSTANT` block may only be signed numeric literals, or an already declared constant, possibly preceded by unary minus.
---
 modcc/parser.cpp                  | 86 +++++++++++++++++++++++++++++--
 modcc/parser.hpp                  |  3 ++
 modcc/token.cpp                   |  2 +
 modcc/token.hpp                   |  2 +-
 test/unit-modcc/test_parser.cpp   | 51 ++++++++++++++++++
 test/unit-modcc/test_printers.cpp | 38 ++++++++++++++
 6 files changed, 178 insertions(+), 4 deletions(-)

diff --git a/modcc/parser.cpp b/modcc/parser.cpp
index aa02f3ec..0e9e064b 100644
--- a/modcc/parser.cpp
+++ b/modcc/parser.cpp
@@ -107,6 +107,9 @@ bool Parser::parse() {
             case tok::units :
                 parse_units_block();
                 break;
+            case tok::constant :
+                parse_constant_block();
+                break;
             case tok::parameter :
                 parse_parameter_block();
                 break;
@@ -558,6 +561,60 @@ parm_exit:
     return;
 }
 
+void Parser::parse_constant_block() {
+    get_token();
+
+    // assert that the block starts with a curly brace
+    if(token_.type != tok::lbrace) {
+        error(pprintf("CONSTANT block must start with a curly brace {, found '%'", token_.spelling));
+        return;
+    }
+
+    get_token();
+    while(token_.type!=tok::rbrace && token_.type!=tok::eof) {
+        int line = location_.line;
+        std::string name, value;
+
+        // read the constant name
+        if(token_.type != tok::identifier) {
+            error(pprintf("CONSTANT block unexpected symbol '%s'", token_.spelling));
+            return;
+        }
+        name = token_.spelling; // save full token
+
+        get_token();
+
+        // look for equality
+        if(token_.type==tok::eq) {
+            get_token(); // consume '='
+            value = value_literal();
+            if(status_ == lexerStatus::error) {
+                return;
+            }
+        }
+
+        // get the units
+        if(line==location_.line && token_.type == tok::lparen) {
+            unit_description();
+            if(status_ == lexerStatus::error) {
+                return;
+            }
+        }
+
+        constants_map_.insert({name, value});
+    }
+
+    // error if EOF before closing curly brace
+    if(token_.type==tok::eof) {
+        error("CONSTANT block must have closing '}'");
+        return;
+    }
+
+    get_token(); // consume closing brace
+
+    return;
+}
+
 void Parser::parse_assigned_block() {
     AssignedBlock block;
 
@@ -627,18 +684,31 @@ ass_exit:
 // Parse a value (integral or real) with possible preceding unary minus,
 // and return as a string.
 std::string Parser::value_literal() {
-    std::string value;
+    bool negate = false;
 
     if(token_.type==tok::minus) {
-        value = "-";
+        negate = true;
+        get_token();
+    }
+
+    if (constants_map_.find(token_.spelling) != constants_map_.end()) {
+        // Remove double negation
+        auto v = constants_map_.at(token_.spelling);
+        if (v.at(0) == '-' && negate) {
+            v.erase(0,1);
+            negate = false;
+        }
+        auto value = negate ? "-" + v : v;
         get_token();
+        return value;
     }
+
     if(token_.type != tok::integer && token_.type != tok::real) {
         error(pprintf("numeric constant not an integer or real number '%'", token_));
         return "";
     }
     else {
-        value += token_.spelling;
+        auto value = negate ? "-" + token_.spelling : token_.spelling;
         get_token();
         return value;
     }
@@ -936,6 +1006,16 @@ expression_ptr Parser::parse_statement() {
 }
 
 expression_ptr Parser::parse_identifier() {
+    if (constants_map_.find(token_.spelling) != constants_map_.end()) {
+        // save location and value of the identifier
+        auto id = make_expression<NumberExpression>(token_.location, constants_map_.at(token_.spelling));
+
+        // consume the number
+        get_token();
+
+        // return the value of the constant
+        return id;
+    }
     // save name and location of the identifier
     auto id = make_expression<IdentifierExpression>(token_.location, token_.spelling);
 
diff --git a/modcc/parser.hpp b/modcc/parser.hpp
index 8590938c..3c2a297a 100644
--- a/modcc/parser.hpp
+++ b/modcc/parser.hpp
@@ -53,9 +53,12 @@ public:
     void parse_state_block();
     void parse_units_block();
     void parse_parameter_block();
+    void parse_constant_block();
     void parse_assigned_block();
     void parse_title();
 
+    std::unordered_map<std::string, std::string> constants_map_;
+
 private:
     Module *module_;
 
diff --git a/modcc/token.cpp b/modcc/token.cpp
index ba89e2ce..169bbc51 100644
--- a/modcc/token.cpp
+++ b/modcc/token.cpp
@@ -25,6 +25,7 @@ static Keyword keywords[] = {
     {"NEURON",      tok::neuron},
     {"UNITS",       tok::units},
     {"PARAMETER",   tok::parameter},
+    {"CONSTANT",    tok::constant},
     {"ASSIGNED",    tok::assigned},
     {"STATE",       tok::state},
     {"BREAKPOINT",  tok::breakpoint},
@@ -95,6 +96,7 @@ static TokenString token_strings[] = {
     {"NEURON",      tok::neuron},
     {"UNITS",       tok::units},
     {"PARAMETER",   tok::parameter},
+    {"CONSTANT",    tok::constant},
     {"ASSIGNED",    tok::assigned},
     {"STATE",       tok::state},
     {"BREAKPOINT",  tok::breakpoint},
diff --git a/modcc/token.hpp b/modcc/token.hpp
index 149ee743..ae03b3d5 100644
--- a/modcc/token.hpp
+++ b/modcc/token.hpp
@@ -52,7 +52,7 @@ enum class tok {
     // block keywoards
     title,
     neuron, units, parameter,
-    assigned, state, breakpoint,
+    constant, assigned, state, breakpoint,
     derivative, kinetic, procedure, initial, function,
     net_receive,
 
diff --git a/test/unit-modcc/test_parser.cpp b/test/unit-modcc/test_parser.cpp
index 6551eb10..f251d202 100644
--- a/test/unit-modcc/test_parser.cpp
+++ b/test/unit-modcc/test_parser.cpp
@@ -108,6 +108,57 @@ TEST(Parser, procedure) {
     }
 }
 
+TEST(Parser, load_constant) {
+    char str[] = {
+            "CONSTANT {\n"
+            "  t0 = -1.2\n"
+            "  t1 = 0.5\n"
+            "  t2 = -t0\n"
+            "  t3 = -t1\n"
+            "}"
+    };
+
+    Parser p(str);
+    p.parse_constant_block();
+    EXPECT_TRUE(p.status()==lexerStatus::happy);
+
+    EXPECT_TRUE(p.constants_map_.find("t0") != p.constants_map_.end());
+    EXPECT_EQ("-1.2", p.constants_map_.at("t0"));
+
+    EXPECT_TRUE(p.constants_map_.find("t1") != p.constants_map_.end());
+    EXPECT_EQ("0.5", p.constants_map_.at("t1"));
+
+    EXPECT_TRUE(p.constants_map_.find("t2") != p.constants_map_.end());
+    EXPECT_EQ("1.2", p.constants_map_.at("t2"));
+
+    EXPECT_TRUE(p.constants_map_.find("t3") != p.constants_map_.end());
+    EXPECT_EQ("-0.5", p.constants_map_.at("t3"));
+}
+
+TEST(Parser, parameters_from_constant) {
+    const char str[] =
+            "PARAMETER {   \n"
+            "  tau = -t0   \n"
+            "  e = t1      \n"
+            "}";
+
+    expression_ptr null;
+    Module m(str, str+std::strlen(str), "");
+    Parser p(m, false);
+    p.constants_map_.insert({"t0","-0.5"});
+    p.constants_map_.insert({"t1","1.2"});
+    p.parse_parameter_block();
+
+    EXPECT_EQ(lexerStatus::happy, p.status());
+    verbose_print(null, p, str);
+
+    auto param_block = m.parameter_block();
+    EXPECT_EQ("tau", param_block.parameters[0].name());
+    EXPECT_EQ("0.5", param_block.parameters[0].value);
+    EXPECT_EQ("e", param_block.parameters[1].name());
+    EXPECT_EQ("1.2", param_block.parameters[1].value);
+}
+
 TEST(Parser, net_receive) {
     char str[] =
         "NET_RECEIVE (x, y) {   \n"
diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp
index 8dbbedaf..530f266e 100644
--- a/test/unit-modcc/test_printers.cpp
+++ b/test/unit-modcc/test_printers.cpp
@@ -158,3 +158,41 @@ TEST(CPrinter, proc_body) {
         EXPECT_EQ(strip(tc.expected), strip(text));
     }
 }
+
+TEST(CPrinter, proc_body_const) {
+    std::vector<testcase> testcases = {
+            {
+                    "PROCEDURE trates(v) {\n"
+                    "    mtau = 0.5 - t0 + t1\n"
+                    "}"
+                    ,
+                    "mtau[i_] = 0.5 - -0.5 + 1.2;\n"
+            }
+    };
+
+    // create a scope that contains the symbols used in the tests
+    Scope<Symbol>::symbol_map globals;
+    globals["mtau"] = make_symbol<VariableExpression>(Location(), "mtau");
+
+    for (const auto& tc: testcases) {
+        Parser p(tc.source);
+        p.constants_map_.insert({"t0","-0.5"});
+        p.constants_map_.insert({"t1","1.2"});
+        expression_ptr e = p.parse_procedure();
+        ASSERT_TRUE(e->is_symbol());
+
+        auto procname = e->is_symbol()->name();
+        auto& proc = (globals[procname] = symbol_ptr(e.release()->is_symbol()));
+
+        proc->semantic(globals);
+        std::stringstream out;
+        auto v = std::make_unique<CPrinter>(out);
+        proc->is_procedure()->body()->accept(v.get());
+        std::string text = out.str();
+
+        verbose_print(proc->is_procedure()->body()->to_string());
+        verbose_print(" :--: ", text);
+
+        EXPECT_EQ(strip(tc.expected), strip(text));
+    }
+}
\ No newline at end of file
-- 
GitLab