From 8494607e49c68400eba161cf3524cab0b9116a67 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Mon, 5 Sep 2022 20:04:41 +0200 Subject: [PATCH] Guard against errors in linearity test. (#1964) - Print better modcc errors - Catch errors in symbolic diff in linearity test - Abort upon such errors and advise for a different solver --- modcc/error.hpp | 3 --- modcc/expression.hpp | 15 ++++++++++++ modcc/modcc.cpp | 5 ++-- modcc/module.cpp | 4 ++-- modcc/solvers.cpp | 7 +++++- modcc/symdiff.cpp | 21 ++++++++++++---- modcc/symdiff.hpp | 2 +- test/unit-modcc/mod_files/bug-1893-bad.mod | 28 ++++++++++++++++++++++ test/unit-modcc/mod_files/bug-1893.mod | 28 ++++++++++++++++++++++ test/unit-modcc/test_module.cpp | 21 ++++++++++++++++ 10 files changed, 121 insertions(+), 13 deletions(-) create mode 100644 test/unit-modcc/mod_files/bug-1893-bad.mod create mode 100644 test/unit-modcc/mod_files/bug-1893.mod diff --git a/modcc/error.hpp b/modcc/error.hpp index e66fc454..919d1248 100644 --- a/modcc/error.hpp +++ b/modcc/error.hpp @@ -71,6 +71,3 @@ public: private: error_entry error_info_; }; - - - diff --git a/modcc/expression.hpp b/modcc/expression.hpp index da885630..d369e7a9 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -18,6 +18,7 @@ class Visitor; class ARB_LIBMODCC_API Expression; +struct ARB_LIBMODCC_API ErrorExpression; class ARB_LIBMODCC_API CallExpression; class ARB_LIBMODCC_API BlockExpression; class ARB_LIBMODCC_API IfExpression; @@ -209,6 +210,20 @@ protected: scope_ptr scope_; }; +struct ARB_LIBMODCC_API ErrorExpression : public Expression { + explicit ErrorExpression(Location location): Expression(location) + {} + + std::string to_string() const override { + return "Error" + error_string_; + } + + void accept(Visitor *) override { + throw compiler_exception{"Attempted to visit error expression.", location()}; + } +}; + + class ARB_LIBMODCC_API Symbol : public Expression { public : Symbol(Location loc, std::string name, symbolKind kind) diff --git a/modcc/modcc.cpp b/modcc/modcc.cpp index 6e9b51ba..3eea843b 100644 --- a/modcc/modcc.cpp +++ b/modcc/modcc.cpp @@ -25,12 +25,12 @@ using std::cerr; // Options and option parsing: int report_error(const std::string& message) { - cerr << red("error: ") << message << "\n"; + cerr << red("error trace:\n") << message << "\n"; return 1; } int report_ice(const std::string& message) { - cerr << red("internal compiler error: ") << message << "\n" + cerr << red("internal compiler error:\n") << message << "\n" << "\nPlease report this error to the modcc developers.\n"; return 1; } @@ -223,6 +223,7 @@ int main(int argc, char **argv) { emit_header("semantic analysis"); m.semantic(); if (m.has_warning()) { + cerr << yellow("Warnings:\n"); cerr << m.warning_string() << "\n"; } if (m.has_error()) { diff --git a/modcc/module.cpp b/modcc/module.cpp index 6d6a56a0..7dafa6ed 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -119,7 +119,7 @@ std::string Module::error_string() const { std::string str; for (const error_entry& entry: errors()) { if (!str.empty()) str += '\n'; - str += red("error "); + str += red(" * "); str += white(pprintf("%:% ", source_name(), entry.location)); str += entry.message; } @@ -130,7 +130,7 @@ std::string Module::warning_string() const { std::string str; for (auto& entry: warnings()) { if (!str.empty()) str += '\n'; - str += purple("warning "); + str += purple(" * "); str += white(pprintf("%:% ", source_name(), entry.location)); str += entry.message; } diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp index 1d31e55d..23d89a41 100644 --- a/modcc/solvers.cpp +++ b/modcc/solvers.cpp @@ -41,7 +41,13 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) { } auto s = deriv->name(); + linear_test_result r = linear_test(rhs, dvars_); + if (r.has_error()) { + append_errors(r.errors()); + error({"CNExp: Could not determine linearity, maybe use a different solver?", loc}); + return; + } if (!r.monolinear(s)) { error({"System not diagonal linear for cnexp", loc}); @@ -52,7 +58,6 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) { if (!coef || is_zero(coef)) { // s' = b becomes s = s + b*dt; use b_ as a local variable for // the constant term b. - auto local_b_term = make_unique_local_assign(scope, r.constant.get(), "b_"); statements_.push_back(std::move(local_b_term.local_decl)); statements_.push_back(std::move(local_b_term.assignment)); diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index 4946a1b8..1b4028eb 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -562,7 +562,17 @@ ARB_LIBMODCC_API expression_ptr symbolic_pdiff(Expression* e, const std::string& SymPDiffVisitor pdiff_visitor(id); e->accept(&pdiff_visitor); - if (pdiff_visitor.has_error()) return nullptr; + if (pdiff_visitor.has_error()) { + std::string errors, sep = ""; + + for (const auto& error: pdiff_visitor.errors()) { + errors += sep + error.message; + sep = "\n"; + } + auto res = std::make_unique<ErrorExpression>(e->location()); + res->error(errors); + return res; + } return constant_simplify(pdiff_visitor.result()); } @@ -666,17 +676,20 @@ ARB_LIBMODCC_API linear_test_result linear_test(Expression* e, const std::vector result.constant = e->clone(); for (const auto& id: vars) { auto coef = symbolic_pdiff(e, id); - if (!coef) { - return linear_test_result{}; + if (coef->has_error()) { + auto res = linear_test_result{}; + res.error({coef->error_message(), loc}); + return res; } + if (!coef) return linear_test_result{}; if (!is_zero(coef)) result.coef[id] = std::move(coef); - result.constant = substitute(result.constant, id, zero()); } ConstantSimplifyVisitor csimp_visitor; result.constant->accept(&csimp_visitor); result.constant = csimp_visitor.result(); + if (result.constant.get() == nullptr) throw compiler_exception{"Linear test: simplification of the constant term failed.", loc}; // linearity test: take second order derivatives, test against zero. result.is_linear = true; diff --git a/modcc/symdiff.hpp b/modcc/symdiff.hpp index cb58ffae..b4387489 100644 --- a/modcc/symdiff.hpp +++ b/modcc/symdiff.hpp @@ -81,7 +81,7 @@ inline expression_ptr substitute(const expression_ptr& e, const substitute_map& // Linearity testing -struct linear_test_result { +struct linear_test_result: public error_stack { bool is_linear = false; bool is_homogeneous = false; expression_ptr constant; diff --git a/test/unit-modcc/mod_files/bug-1893-bad.mod b/test/unit-modcc/mod_files/bug-1893-bad.mod new file mode 100644 index 00000000..fde7f171 --- /dev/null +++ b/test/unit-modcc/mod_files/bug-1893-bad.mod @@ -0,0 +1,28 @@ +NEURON { + POINT_PROCESS bug_1893 +} + +INITIAL { + c = 0 + rho = 0 + theta_p = 0 +} + +STATE { + c + rho + theta_p +} + +PARAMETER { + tau_c = 150 (ms) +} + +BREAKPOINT { + SOLVE state METHOD cnexp +} + +DERIVATIVE state { + c' = -c/tau_c + rho' = (c - theta_p) > 0 +} diff --git a/test/unit-modcc/mod_files/bug-1893.mod b/test/unit-modcc/mod_files/bug-1893.mod new file mode 100644 index 00000000..14cfd5e5 --- /dev/null +++ b/test/unit-modcc/mod_files/bug-1893.mod @@ -0,0 +1,28 @@ +NEURON { + POINT_PROCESS bug_1893 +} + +INITIAL { + c = 0 + rho = 0 + theta_p = 0 +} + +STATE { + c + rho + theta_p +} + +PARAMETER { + tau_c = 150 (ms) +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +DERIVATIVE state { + c' = -c/tau_c + rho' = (c - theta_p) > 0 +} diff --git a/test/unit-modcc/test_module.cpp b/test/unit-modcc/test_module.cpp index 34ff5b4d..9e9a1db0 100644 --- a/test/unit-modcc/test_module.cpp +++ b/test/unit-modcc/test_module.cpp @@ -122,3 +122,24 @@ TEST(Module, read_write_ion) { EXPECT_TRUE(p.parse()); EXPECT_TRUE(m.semantic()); } + +// Regression test in #1893 we found that the solver segfaults when handed a +// naked comparison statement. +TEST(Module, solver_bug_1893) { + { + Module m(io::read_all(DATADIR "/mod_files/bug-1893.mod"), "bug-1893.mod"); + EXPECT_NE(m.buffer().size(), 0); + + Parser p(m, false); + EXPECT_TRUE(p.parse()); + EXPECT_TRUE(m.semantic()); + } + { + Module m(io::read_all(DATADIR "/mod_files/bug-1893-bad.mod"), "bug-1893.mod"); + EXPECT_NE(m.buffer().size(), 0); + + Parser p(m, false); + EXPECT_TRUE(p.parse()); + EXPECT_FALSE(m.semantic()); + } +} -- GitLab