From 2775b3f3f59caabc1a84ac75b857d48c286eecec Mon Sep 17 00:00:00 2001
From: Nora Abi Akar <nora.abiakar@gmail.com>
Date: Thu, 11 Feb 2021 14:15:12 +0100
Subject: [PATCH] Modcc: comments & seg fault (#1360)

* fix nmodl comment parsing
* proper handling of \r
* fix segmentation fault when api_state is null
* add unit tests
---
 modcc/lexer.cpp                |  24 +++-
 modcc/module.cpp               | 235 ++++++++++++++++-----------------
 test/unit-modcc/test_lexer.cpp |  67 +++++++---
 3 files changed, 183 insertions(+), 143 deletions(-)

diff --git a/modcc/lexer.cpp b/modcc/lexer.cpp
index 086a5a80..4ada753b 100644
--- a/modcc/lexer.cpp
+++ b/modcc/lexer.cpp
@@ -19,9 +19,6 @@ inline bool is_alpha(char c) {
 inline bool is_alphanumeric(char c) {
     return (is_numeric(c) || is_alpha(c) );
 }
-inline bool is_whitespace(char c) {
-    return (c==' ' || c=='\t' || c=='\v' || c=='\f' || c=='\n' || c=='\r');
-}
 inline bool is_eof(char c) {
     return (c==0);
 }
@@ -103,8 +100,25 @@ Token Lexer::parse() {
                 if (id == "UNITSON" || id == "UNITSOFF") continue;
                 if (id == "COMMENT") {
                     while (!is_eof(*current_)) {
-                        while (is_whitespace(*current_) || !is_alpha(*current_)) current_++;
-                        if (identifier() == "ENDCOMMENT") break;
+                        while ((*current_ != '\n') && (*current_ != '\r') && !is_alpha(*current_)) {
+                            current_++;
+                        }
+                        if (*current_ == '\n') {
+                            current_++;
+                            line_ = current_;
+                            location_.line++;
+                        }
+                        else if (*current_ == '\r') {
+                            current_++;
+                            if(*current_ != '\n') {
+                                error_string_ = pprintf("bad line ending: \\n must follow \\r");
+                                return t;
+                            }
+                            current_++;
+                            line_ = current_;
+                            location_.line++;
+                        }
+                        else if (identifier() == "ENDCOMMENT") break;
                     }
                     continue;
                 }
diff --git a/modcc/module.cpp b/modcc/module.cpp
index d099ef23..dee4f0ab 100644
--- a/modcc/module.cpp
+++ b/modcc/module.cpp
@@ -357,156 +357,155 @@ bool Module::semantic() {
     auto api_state  = state_api.first;
     auto breakpoint = state_api.second; // implies we are building the `nrn_state()` method.
 
+    if(!breakpoint) {
+        error("a BREAKPOINT block is required");
+        return false;
+    }
+
     api_state->semantic(symbols_);
     scope_ptr nrn_state_scope = api_state->scope();
 
-    if(breakpoint) {
-        // Grab SOLVE statements, put them in `nrn_state` after translation.
-        bool found_solve = false;
-        bool found_non_solve = false;
-        std::set<std::string> solved_ids;
-
-        for(auto& e: (breakpoint->body()->statements())) {
-            SolveExpression* solve_expression = e->is_solve_statement();
-            LocalDeclaration* local_expression = e->is_local_declaration();
-            if(local_expression) {
-                continue;
-            }
-            if(!solve_expression) {
-                found_non_solve = true;
-                continue;
-            }
-            if(found_non_solve) {
-                error("SOLVE statements must come first in BREAKPOINT block",
-                    e->location());
-                return false;
-            }
+    // Grab SOLVE statements, put them in `nrn_state` after translation.
+    bool found_solve = false;
+    bool found_non_solve = false;
+    std::set<std::string> solved_ids;
 
-            found_solve = true;
-            std::unique_ptr<SolverVisitorBase> solver;
+    for(auto& e: (breakpoint->body()->statements())) {
+        SolveExpression* solve_expression = e->is_solve_statement();
+        LocalDeclaration* local_expression = e->is_local_declaration();
+        if(local_expression) {
+            continue;
+        }
+        if(!solve_expression) {
+            found_non_solve = true;
+            continue;
+        }
+        if(found_non_solve) {
+            error("SOLVE statements must come first in BREAKPOINT block",
+                e->location());
+            return false;
+        }
 
-            switch(solve_expression->method()) {
-            case solverMethod::cnexp:
-                solver = std::make_unique<CnexpSolverVisitor>();
-                break;
-            case solverMethod::sparse: {
-                solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
-                break;
-            }
-            case solverMethod::none:
-                solver = std::make_unique<DirectSolverVisitor>();
-                break;
-            }
+        found_solve = true;
+        std::unique_ptr<SolverVisitorBase> solver;
 
-            // If the derivative block is a kinetic block, perform kinetic
-            // rewrite first.
+        switch(solve_expression->method()) {
+        case solverMethod::cnexp:
+            solver = std::make_unique<CnexpSolverVisitor>();
+            break;
+        case solverMethod::sparse: {
+            solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
+            break;
+        }
+        case solverMethod::none:
+            solver = std::make_unique<DirectSolverVisitor>();
+            break;
+        }
 
-            auto deriv = solve_expression->procedure();
+        // If the derivative block is a kinetic block, perform kinetic
+        // rewrite first.
 
-            if (deriv->kind()==procedureKind::kinetic) {
-                auto rewrite_body = kinetic_rewrite(deriv->body());
-                bool linear_kinetic = true;
+        auto deriv = solve_expression->procedure();
 
-                for (auto& s: rewrite_body->is_block()->statements()) {
-                    if(s->is_assignment() && !state_vars.empty()) {
-                        linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
-                        linear_kinetic &= r.is_linear;
-                    }
-                }
+        if (deriv->kind()==procedureKind::kinetic) {
+            auto rewrite_body = kinetic_rewrite(deriv->body());
+            bool linear_kinetic = true;
 
-                if (!linear_kinetic) {
-                    solver = std::make_unique<SparseNonlinearSolverVisitor>();
+            for (auto& s: rewrite_body->is_block()->statements()) {
+                if(s->is_assignment() && !state_vars.empty()) {
+                    linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
+                    linear_kinetic &= r.is_linear;
                 }
-
-                rewrite_body->semantic(nrn_state_scope);
-                rewrite_body->accept(solver.get());
             }
-            else if (deriv->kind()==procedureKind::linear) {
-                solver = std::make_unique<LinearSolverVisitor>(state_vars);
-                auto rewrite_body = linear_rewrite(deriv->body(), state_vars);
 
-                rewrite_body->semantic(nrn_state_scope);
-                rewrite_body->accept(solver.get());
+            if (!linear_kinetic) {
+                solver = std::make_unique<SparseNonlinearSolverVisitor>();
             }
-            else {
-                deriv->body()->accept(solver.get());
-                for (auto& s: deriv->body()->statements()) {
-                    if(s->is_assignment() && !state_vars.empty()) {
-                        linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
-                        linear &= r.is_linear;
-                        linear &= r.is_homogeneous;
-                    }
+
+            rewrite_body->semantic(nrn_state_scope);
+            rewrite_body->accept(solver.get());
+        }
+        else if (deriv->kind()==procedureKind::linear) {
+            solver = std::make_unique<LinearSolverVisitor>(state_vars);
+            auto rewrite_body = linear_rewrite(deriv->body(), state_vars);
+
+            rewrite_body->semantic(nrn_state_scope);
+            rewrite_body->accept(solver.get());
+        }
+        else {
+            deriv->body()->accept(solver.get());
+            for (auto& s: deriv->body()->statements()) {
+                if(s->is_assignment() && !state_vars.empty()) {
+                    linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
+                    linear &= r.is_linear;
+                    linear &= r.is_homogeneous;
                 }
             }
+        }
 
-            if (auto solve_block = solver->as_block(false)) {
-                // Check that we didn't solve an already solved variable.
-                for (const auto& id: solver->solved_identifiers()) {
-                    if (solved_ids.count(id)>0) {
-                        error("Variable "+id+" solved twice!", e->location());
-                        return false;
-                    }
-                    solved_ids.insert(id);
+        if (auto solve_block = solver->as_block(false)) {
+            // Check that we didn't solve an already solved variable.
+            for (const auto& id: solver->solved_identifiers()) {
+                if (solved_ids.count(id)>0) {
+                    error("Variable "+id+" solved twice!", e->location());
+                    return false;
                 }
+                solved_ids.insert(id);
+            }
 
-                // May have now redundant local variables; remove these first.
-                solve_block->semantic(nrn_state_scope);
-                solve_block = remove_unused_locals(solve_block->is_block());
+            // May have now redundant local variables; remove these first.
+            solve_block->semantic(nrn_state_scope);
+            solve_block = remove_unused_locals(solve_block->is_block());
 
-                // Copy body into nrn_state.
-                for (auto& stmt: solve_block->is_block()->statements()) {
-                    api_state->body()->statements().push_back(std::move(stmt));
-                }
-            }
-            else {
-                // Something went wrong: copy errors across.
-                append_errors(solver->errors());
-                return false;
+            // Copy body into nrn_state.
+            for (auto& stmt: solve_block->is_block()->statements()) {
+                api_state->body()->statements().push_back(std::move(stmt));
             }
         }
-
-        // handle the case where there is a SOLVE in BREAKPOINT (which is the typical case)
-        if (found_solve) {
-            // Redo semantic pass in order to elimate any removed local symbols.
-            api_state->semantic(symbols_);
+        else {
+            // Something went wrong: copy errors across.
+            append_errors(solver->errors());
+            return false;
         }
+    }
 
-        // Run remove locals pass again on the whole body in case `dt` was never used.
-        api_state->body(remove_unused_locals(api_state->body()));
+    // handle the case where there is a SOLVE in BREAKPOINT (which is the typical case)
+    if (found_solve) {
+        // Redo semantic pass in order to eliminate any removed local symbols.
         api_state->semantic(symbols_);
+    }
 
-        //..........................................................
-        // nrn_current : update contributions to currents
-        //..........................................................
-        NrnCurrentRewriter nrn_current_rewriter;
-        breakpoint->accept(&nrn_current_rewriter);
-
-        for (auto& s: breakpoint->body()->statements()) {
-            if(s->is_assignment() && !state_vars.empty()) {
-                linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
-                linear &= r.is_linear;
-                linear &= r.is_homogeneous;
-            }
-        }
+    // Run remove locals pass again on the whole body in case `dt` was never used.
+    api_state->body(remove_unused_locals(api_state->body()));
+    api_state->semantic(symbols_);
 
-        auto nrn_current_block = nrn_current_rewriter.as_block();
-        if (!nrn_current_block) {
-            append_errors(nrn_current_rewriter.errors());
-            return false;
-        }
+    //..........................................................
+    // nrn_current : update contributions to currents
+    //..........................................................
+    NrnCurrentRewriter nrn_current_rewriter;
+    breakpoint->accept(&nrn_current_rewriter);
 
-        symbols_["nrn_current"] =
-            make_symbol<APIMethod>(
-                    breakpoint->location(), "nrn_current",
-                    std::vector<expression_ptr>(),
-                    constant_simplify(nrn_current_block));
-        symbols_["nrn_current"]->semantic(symbols_);
+    for (auto& s: breakpoint->body()->statements()) {
+        if(s->is_assignment() && !state_vars.empty()) {
+            linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
+            linear &= r.is_linear;
+            linear &= r.is_homogeneous;
+        }
     }
-    else {
-        error("a BREAKPOINT block is required");
+
+    auto nrn_current_block = nrn_current_rewriter.as_block();
+    if (!nrn_current_block) {
+        append_errors(nrn_current_rewriter.errors());
         return false;
     }
 
+    symbols_["nrn_current"] =
+        make_symbol<APIMethod>(
+                breakpoint->location(), "nrn_current",
+                std::vector<expression_ptr>(),
+                constant_simplify(nrn_current_block));
+    symbols_["nrn_current"]->semantic(symbols_);
+
     if (has_symbol("net_receive", symbolKind::procedure)) {
         auto net_rec_api = make_empty_api_method("net_rec_api", "net_receive");
         if (net_rec_api.second) {
diff --git a/test/unit-modcc/test_lexer.cpp b/test/unit-modcc/test_lexer.cpp
index c6fb38b0..b59a8256 100644
--- a/test/unit-modcc/test_lexer.cpp
+++ b/test/unit-modcc/test_lexer.cpp
@@ -254,26 +254,53 @@ TEST(Lexer, braces) {
 
 // test comments
 TEST(Lexer, comments) {
-    char string[] = "foo:this is one line\n"
-                    "bar : another comment\n"
-                    "foobar ? another comment\n";
-    VerboseLexer lexer(string, string+sizeof(string));
-
-    auto t1 = lexer.parse();
-    EXPECT_EQ(t1.type, tok::identifier);
-
-    auto t2 = lexer.parse();
-    EXPECT_EQ(t2.type, tok::identifier);
-    EXPECT_EQ(t2.spelling, "bar");
-    EXPECT_EQ(t2.location.line, 2);
-
-    auto t3 = lexer.parse();
-    EXPECT_EQ(t3.type, tok::identifier);
-    EXPECT_EQ(t3.spelling, "foobar");
-    EXPECT_EQ(t3.location.line, 3);
-
-    auto t4 = lexer.parse();
-    EXPECT_EQ(t4.type, tok::eof);
+    {
+        char string[] = "foo:this is one line\n"
+                        "bar : another comment\n"
+                        "foobar ? another comment\n";
+        VerboseLexer lexer(string, string + sizeof(string));
+
+        auto t1 = lexer.parse();
+        EXPECT_EQ(t1.type, tok::identifier);
+
+        auto t2 = lexer.parse();
+        EXPECT_EQ(t2.type, tok::identifier);
+        EXPECT_EQ(t2.spelling, "bar");
+        EXPECT_EQ(t2.location.line, 2);
+
+        auto t3 = lexer.parse();
+        EXPECT_EQ(t3.type, tok::identifier);
+        EXPECT_EQ(t3.spelling, "foobar");
+        EXPECT_EQ(t3.location.line, 3);
+
+        auto t4 = lexer.parse();
+        EXPECT_EQ(t4.type, tok::eof);
+    }
+    {
+        char string[] = "COMMENT line 1\n"
+                        "comment line 2\n"
+                        "ENDCOMMENT \n"
+                        "foo\n"
+                        "COMMENT <some special comment.> ENDCOMMENT\n"
+                        "bar\n"
+                        "COMMENT\n"
+                        "some info here! ENDCOMMENT";
+        VerboseLexer lexer(string, string + sizeof(string));
+
+        auto t1 = lexer.parse();
+        EXPECT_EQ(t1.type, tok::identifier);
+        EXPECT_EQ(t1.spelling, "foo");
+        EXPECT_EQ(t1.location.line, 4);
+
+        auto t2 = lexer.parse();
+        EXPECT_EQ(t2.type, tok::identifier);
+        EXPECT_EQ(t2.spelling, "bar");
+        EXPECT_EQ(t2.location.line, 6);
+
+        auto t3 = lexer.parse();
+        EXPECT_EQ(t3.type, tok::eof);
+        EXPECT_EQ(t3.location.line, 8);
+    }
 }
 
 // test numbers
-- 
GitLab