From 8494607e49c68400eba161cf3524cab0b9116a67 Mon Sep 17 00:00:00 2001
From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com>
Date: Mon, 5 Sep 2022 20:04:41 +0200
Subject: [PATCH] Guard against errors in linearity test. (#1964)

- Print better modcc errors
- Catch errors in symbolic diff in linearity test
- Abort upon such errors and advise for a different solver
---
 modcc/error.hpp                            |  3 ---
 modcc/expression.hpp                       | 15 ++++++++++++
 modcc/modcc.cpp                            |  5 ++--
 modcc/module.cpp                           |  4 ++--
 modcc/solvers.cpp                          |  7 +++++-
 modcc/symdiff.cpp                          | 21 ++++++++++++----
 modcc/symdiff.hpp                          |  2 +-
 test/unit-modcc/mod_files/bug-1893-bad.mod | 28 ++++++++++++++++++++++
 test/unit-modcc/mod_files/bug-1893.mod     | 28 ++++++++++++++++++++++
 test/unit-modcc/test_module.cpp            | 21 ++++++++++++++++
 10 files changed, 121 insertions(+), 13 deletions(-)
 create mode 100644 test/unit-modcc/mod_files/bug-1893-bad.mod
 create mode 100644 test/unit-modcc/mod_files/bug-1893.mod

diff --git a/modcc/error.hpp b/modcc/error.hpp
index e66fc454..919d1248 100644
--- a/modcc/error.hpp
+++ b/modcc/error.hpp
@@ -71,6 +71,3 @@ public:
 private:
     error_entry error_info_;
 };
-
-
-
diff --git a/modcc/expression.hpp b/modcc/expression.hpp
index da885630..d369e7a9 100644
--- a/modcc/expression.hpp
+++ b/modcc/expression.hpp
@@ -18,6 +18,7 @@
 class Visitor;
 
 class ARB_LIBMODCC_API Expression;
+struct ARB_LIBMODCC_API ErrorExpression;
 class ARB_LIBMODCC_API CallExpression;
 class ARB_LIBMODCC_API BlockExpression;
 class ARB_LIBMODCC_API IfExpression;
@@ -209,6 +210,20 @@ protected:
     scope_ptr scope_;
 };
 
+struct ARB_LIBMODCC_API ErrorExpression : public Expression {
+    explicit ErrorExpression(Location location): Expression(location)
+    {}
+
+    std::string to_string() const override {
+        return "Error" + error_string_;
+    }
+
+    void accept(Visitor *) override {
+        throw compiler_exception{"Attempted to visit error expression.", location()};
+    }
+};
+
+
 class ARB_LIBMODCC_API Symbol : public Expression {
 public :
     Symbol(Location loc, std::string name, symbolKind kind)
diff --git a/modcc/modcc.cpp b/modcc/modcc.cpp
index 6e9b51ba..3eea843b 100644
--- a/modcc/modcc.cpp
+++ b/modcc/modcc.cpp
@@ -25,12 +25,12 @@ using std::cerr;
 // Options and option parsing:
 
 int report_error(const std::string& message) {
-    cerr << red("error: ") << message << "\n";
+    cerr << red("error trace:\n") << message << "\n";
     return 1;
 }
 
 int report_ice(const std::string& message) {
-    cerr << red("internal compiler error: ") << message << "\n"
+    cerr << red("internal compiler error:\n") << message << "\n"
          << "\nPlease report this error to the modcc developers.\n";
     return 1;
 }
@@ -223,6 +223,7 @@ int main(int argc, char **argv) {
         emit_header("semantic analysis");
         m.semantic();
         if (m.has_warning()) {
+            cerr << yellow("Warnings:\n");
             cerr << m.warning_string() << "\n";
         }
         if (m.has_error()) {
diff --git a/modcc/module.cpp b/modcc/module.cpp
index 6d6a56a0..7dafa6ed 100644
--- a/modcc/module.cpp
+++ b/modcc/module.cpp
@@ -119,7 +119,7 @@ std::string Module::error_string() const {
     std::string str;
     for (const error_entry& entry: errors()) {
         if (!str.empty()) str += '\n';
-        str += red("error   ");
+        str += red("  * ");
         str += white(pprintf("%:% ", source_name(), entry.location));
         str += entry.message;
     }
@@ -130,7 +130,7 @@ std::string Module::warning_string() const {
     std::string str;
     for (auto& entry: warnings()) {
         if (!str.empty()) str += '\n';
-        str += purple("warning   ");
+        str += purple("  * ");
         str += white(pprintf("%:% ", source_name(), entry.location));
         str += entry.message;
     }
diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp
index 1d31e55d..23d89a41 100644
--- a/modcc/solvers.cpp
+++ b/modcc/solvers.cpp
@@ -41,7 +41,13 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) {
     }
 
     auto s = deriv->name();
+
     linear_test_result r = linear_test(rhs, dvars_);
+    if (r.has_error()) {
+        append_errors(r.errors());
+        error({"CNExp: Could not determine linearity, maybe use a different solver?", loc});
+        return;
+    }
 
     if (!r.monolinear(s)) {
         error({"System not diagonal linear for cnexp", loc});
@@ -52,7 +58,6 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) {
     if (!coef || is_zero(coef)) {
         // s' = b becomes s = s + b*dt; use b_ as a local variable for
         // the constant term b.
-
         auto local_b_term = make_unique_local_assign(scope, r.constant.get(), "b_");
         statements_.push_back(std::move(local_b_term.local_decl));
         statements_.push_back(std::move(local_b_term.assignment));
diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp
index 4946a1b8..1b4028eb 100644
--- a/modcc/symdiff.cpp
+++ b/modcc/symdiff.cpp
@@ -562,7 +562,17 @@ ARB_LIBMODCC_API expression_ptr symbolic_pdiff(Expression* e, const std::string&
     SymPDiffVisitor pdiff_visitor(id);
     e->accept(&pdiff_visitor);
 
-    if (pdiff_visitor.has_error()) return nullptr;
+    if (pdiff_visitor.has_error()) {
+        std::string errors, sep = "";
+
+        for (const auto& error: pdiff_visitor.errors()) {
+            errors += sep + error.message;
+            sep = "\n";
+        }
+        auto res = std::make_unique<ErrorExpression>(e->location());
+        res->error(errors);
+        return res;
+    }
 
     return constant_simplify(pdiff_visitor.result());
 }
@@ -666,17 +676,20 @@ ARB_LIBMODCC_API linear_test_result linear_test(Expression* e, const std::vector
     result.constant = e->clone();
     for (const auto& id: vars) {
         auto coef = symbolic_pdiff(e, id);
-        if (!coef) {
-            return linear_test_result{};
+        if (coef->has_error()) {
+            auto res = linear_test_result{};
+            res.error({coef->error_message(), loc});
+            return res;
         }
+        if (!coef) return linear_test_result{};
         if (!is_zero(coef)) result.coef[id] = std::move(coef);
-
         result.constant = substitute(result.constant, id, zero());
     }
 
     ConstantSimplifyVisitor csimp_visitor;
     result.constant->accept(&csimp_visitor);
     result.constant = csimp_visitor.result();
+    if (result.constant.get() == nullptr) throw compiler_exception{"Linear test: simplification of the constant term failed.", loc};
 
     // linearity test: take second order derivatives, test against zero.
     result.is_linear = true;
diff --git a/modcc/symdiff.hpp b/modcc/symdiff.hpp
index cb58ffae..b4387489 100644
--- a/modcc/symdiff.hpp
+++ b/modcc/symdiff.hpp
@@ -81,7 +81,7 @@ inline expression_ptr substitute(const expression_ptr& e, const substitute_map&
 
 // Linearity testing
 
-struct linear_test_result {
+struct linear_test_result: public error_stack {
     bool is_linear = false;
     bool is_homogeneous = false;
     expression_ptr constant;
diff --git a/test/unit-modcc/mod_files/bug-1893-bad.mod b/test/unit-modcc/mod_files/bug-1893-bad.mod
new file mode 100644
index 00000000..fde7f171
--- /dev/null
+++ b/test/unit-modcc/mod_files/bug-1893-bad.mod
@@ -0,0 +1,28 @@
+NEURON {
+    POINT_PROCESS bug_1893
+}
+
+INITIAL {
+    c = 0
+    rho = 0
+    theta_p = 0
+}
+
+STATE {
+    c
+    rho
+    theta_p
+}
+
+PARAMETER {
+    tau_c = 150 (ms)
+}
+
+BREAKPOINT {
+    SOLVE state METHOD cnexp
+}
+
+DERIVATIVE state {
+    c' = -c/tau_c
+    rho' = (c - theta_p) > 0
+}
diff --git a/test/unit-modcc/mod_files/bug-1893.mod b/test/unit-modcc/mod_files/bug-1893.mod
new file mode 100644
index 00000000..14cfd5e5
--- /dev/null
+++ b/test/unit-modcc/mod_files/bug-1893.mod
@@ -0,0 +1,28 @@
+NEURON {
+    POINT_PROCESS bug_1893
+}
+
+INITIAL {
+    c = 0
+    rho = 0
+    theta_p = 0
+}
+
+STATE {
+    c
+    rho
+    theta_p
+}
+
+PARAMETER {
+    tau_c = 150 (ms)
+}
+
+BREAKPOINT {
+    SOLVE state METHOD sparse
+}
+
+DERIVATIVE state {
+    c' = -c/tau_c
+    rho' = (c - theta_p) > 0
+}
diff --git a/test/unit-modcc/test_module.cpp b/test/unit-modcc/test_module.cpp
index 34ff5b4d..9e9a1db0 100644
--- a/test/unit-modcc/test_module.cpp
+++ b/test/unit-modcc/test_module.cpp
@@ -122,3 +122,24 @@ TEST(Module, read_write_ion) {
     EXPECT_TRUE(p.parse());
     EXPECT_TRUE(m.semantic());
 }
+
+// Regression test in #1893 we found that the solver segfaults when handed a
+// naked comparison statement.
+TEST(Module, solver_bug_1893) {
+    {
+        Module m(io::read_all(DATADIR "/mod_files/bug-1893.mod"), "bug-1893.mod");
+        EXPECT_NE(m.buffer().size(), 0);
+
+        Parser p(m, false);
+        EXPECT_TRUE(p.parse());
+        EXPECT_TRUE(m.semantic());
+    }
+    {
+        Module m(io::read_all(DATADIR "/mod_files/bug-1893-bad.mod"), "bug-1893.mod");
+        EXPECT_NE(m.buffer().size(), 0);
+
+        Parser p(m, false);
+        EXPECT_TRUE(p.parse());
+        EXPECT_FALSE(m.semantic());
+    }
+}
-- 
GitLab