From 1db7376752ff4912b0874cb3d616f1a556ae76dc Mon Sep 17 00:00:00 2001
From: Sam Yates <yates@cscs.ch>
Date: Tue, 28 Mar 2017 14:13:21 +0200
Subject: [PATCH] Bug fix for issue #196 (#211)

Fixes #196.

Correct treatment of missing coefficients in `cnexp` solver.

* Extend `EXPECT_EXPR_EQ` functionality with wrapper that works with `Expression *` and `expression_ptr` arguments.
* Replace string comparison checks in `test_symdiff.cpp` with equivalents that use `EXPECT_EXPR_EQ`.
* Check explicitly for missing coefficient in `cnexp` solver, which should be treated equivalently to zero.
---
 modcc/solvers.cpp            | 26 +++++++++++++-------------
 tests/modcc/test.hpp         | 11 ++++++++++-
 tests/modcc/test_symdiff.cpp | 34 +++++++++++++++++-----------------
 3 files changed, 40 insertions(+), 31 deletions(-)

diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp
index f3f4a7d6..e4a4d490 100644
--- a/modcc/solvers.cpp
+++ b/modcc/solvers.cpp
@@ -49,19 +49,7 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) {
     }
 
     Expression* coef = r.coef[s].get();
-    if (r.is_homogeneous) {
-        // s' = a*s becomes s = s*exp(a*dt); use a_ as a local variable
-        // for the coefficient.
-        auto local_a_term = make_unique_local_assign(scope, coef, "a_");
-        statements_.push_back(std::move(local_a_term.local_decl));
-        statements_.push_back(std::move(local_a_term.assignment));
-        auto a_ = local_a_term.id->is_identifier()->spelling();
-
-        std::string s_update = pprintf("% = %*exp(%*dt)", s, s, a_);
-        statements_.push_back(Parser{s_update}.parse_line_expression());
-        return;
-    }
-    else if (is_zero(coef)) {
+    if (!coef || is_zero(coef)) {
         // s' = b becomes s = s + b*dt; use b_ as a local variable for
         // the constant term b.
 
@@ -74,6 +62,18 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) {
         statements_.push_back(Parser{s_update}.parse_line_expression());
         return;
     }
+    else if (r.is_homogeneous) {
+        // s' = a*s becomes s = s*exp(a*dt); use a_ as a local variable
+        // for the coefficient.
+        auto local_a_term = make_unique_local_assign(scope, coef, "a_");
+        statements_.push_back(std::move(local_a_term.local_decl));
+        statements_.push_back(std::move(local_a_term.assignment));
+        auto a_ = local_a_term.id->is_identifier()->spelling();
+
+        std::string s_update = pprintf("% = %*exp(%*dt)", s, s, a_);
+        statements_.push_back(Parser{s_update}.parse_line_expression());
+        return;
+    }
     else {
         // s' = a*s + b becomes s = -b/a + (s+b/a)*exp(a*dt); use
         // a_ as a local variable for the coefficient and ba_ for the
diff --git a/tests/modcc/test.hpp b/tests/modcc/test.hpp
index fe6e7666..d0059e3c 100644
--- a/tests/modcc/test.hpp
+++ b/tests/modcc/test.hpp
@@ -33,11 +33,20 @@ inline expression_ptr parse_procedure(std::string const& s) {
 // Strip ANSI control sequences from `to_string` output.
 std::string plain_text(Expression* expr);
 
+// Generic get expression pointer from raw pointer or unique_ptr wrapper.
+inline Expression* raw_expression(Expression* p) { return p; }
+inline Expression* raw_expression(const expression_ptr &p) { return p.get(); }
+
 // Compare two expressions via their representation.
 // Use with EXPECT_PRED_FORMAT2.
 ::testing::AssertionResult assert_expr_eq(const char *arg1, const char *arg2, Expression* expected, Expression* value);
 
-#define EXPECT_EXPR_EQ(a,b) EXPECT_PRED_FORMAT2(assert_expr_eq, a, b)
+template <typename E1, typename E2>
+::testing::AssertionResult assert_expr_eq_wrap(const char *arg1, const char *arg2, E1&& expected, E2&& value) {
+    return assert_expr_eq(arg1, arg2, raw_expression(std::forward<E1>(expected)), raw_expression(std::forward<E2>(value)));
+}
+
+#define EXPECT_EXPR_EQ(a,b) EXPECT_PRED_FORMAT2(assert_expr_eq_wrap, a, b)
 
 // Print arguments, but only if verbose flag set.
 // Use `to_string()` to print (smart) pointers to Expression or Scope objects.
diff --git a/tests/modcc/test_symdiff.cpp b/tests/modcc/test_symdiff.cpp
index 7f709825..49571d3f 100644
--- a/tests/modcc/test_symdiff.cpp
+++ b/tests/modcc/test_symdiff.cpp
@@ -122,7 +122,7 @@ TEST(constant_simplify, simplified_expr) {
         ASSERT_TRUE(before);
         ASSERT_TRUE(after);
 
-        EXPECT_EQ(after->to_string(), constant_simplify(before)->to_string());
+        EXPECT_EXPR_EQ(after, constant_simplify(before));
     }
 }
 
@@ -156,7 +156,7 @@ TEST(constant_simplify, block_with_if) {
     ASSERT_TRUE(before);
     ASSERT_TRUE(after);
 
-    EXPECT_EQ(after->to_string(), constant_simplify(before)->to_string());
+    EXPECT_EXPR_EQ(after, constant_simplify(before));
 }
 
 TEST(symbolic_pdiff, expressions) {
@@ -177,7 +177,7 @@ TEST(symbolic_pdiff, expressions) {
         ASSERT_TRUE(before);
         ASSERT_TRUE(after);
 
-        EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string());
+        EXPECT_EXPR_EQ(after, symbolic_pdiff(before, "x"));
     }
 }
 
@@ -195,7 +195,7 @@ TEST(symbolic_pdiff, linear) {
         ASSERT_TRUE(before);
         ASSERT_TRUE(after);
 
-        EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string());
+        EXPECT_EXPR_EQ(after, symbolic_pdiff(before, "x"));
     }
 }
 
@@ -214,7 +214,7 @@ TEST(symbolic_pdiff, nonlinear) {
         ASSERT_TRUE(before);
         ASSERT_TRUE(after);
 
-        EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string());
+        EXPECT_EXPR_EQ(after, symbolic_pdiff(before, "x"));
     }
 }
 
@@ -242,7 +242,7 @@ TEST(substitute, expressions) {
         ASSERT_TRUE(after);
 
         auto result = substitute(before.get(), "x", yplusz.get());
-        EXPECT_EQ(after->to_string(), result->to_string());
+        EXPECT_EXPR_EQ(after, result);
     }
 }
 
@@ -259,7 +259,7 @@ TEST(substitute, exprmap) {
     ASSERT_TRUE(after);
 
     auto result = substitute(before.get(), subs);
-    EXPECT_EQ(after->to_string(), result->to_string());
+    EXPECT_EXPR_EQ(after, result);
 }
 
 TEST(linear_test, homogeneous) {
@@ -269,14 +269,14 @@ TEST(linear_test, homogeneous) {
     EXPECT_TRUE(r.is_linear);
     EXPECT_TRUE(r.is_homogeneous);
     EXPECT_TRUE(r.monolinear());
-    EXPECT_EQ(r.coef["x"]->to_string(), "3"_expr->to_string());
+    EXPECT_EXPR_EQ(r.coef["x"], "3"_expr);
 
     r = linear_test("y-a*x+2*x"_expr, {"x", "y"});
     EXPECT_TRUE(r.is_linear);
     EXPECT_TRUE(r.is_homogeneous);
     EXPECT_FALSE(r.monolinear());
-    EXPECT_EQ(r.coef["x"]->to_string(), "-a+2"_expr->to_string());
-    EXPECT_EQ(r.coef["y"]->to_string(), "1"_expr->to_string());
+    EXPECT_EXPR_EQ(r.coef["x"], "-a+2"_expr);
+    EXPECT_EXPR_EQ(r.coef["y"], "1"_expr);
 }
 
 TEST(linear_test, inhomogeneous) {
@@ -285,23 +285,23 @@ TEST(linear_test, inhomogeneous) {
     r = linear_test("sin(y)+3*x"_expr, {"x"});
     EXPECT_TRUE(r.is_linear);
     EXPECT_FALSE(r.is_homogeneous);
-    EXPECT_EQ(r.coef["x"]->to_string(), "3"_expr->to_string());
-    EXPECT_EQ(r.constant->to_string(), "sin(y)"_expr->to_string());
+    EXPECT_EXPR_EQ(r.coef["x"], "3"_expr);
+    EXPECT_EXPR_EQ(r.constant, "sin(y)"_expr);
 
     r = linear_test("(x+y+1)*(a+b)"_expr, {"x", "y"});
     EXPECT_TRUE(r.is_linear);
     EXPECT_FALSE(r.is_homogeneous);
-    EXPECT_EQ(r.coef["x"]->to_string(), "a+b"_expr->to_string());
-    EXPECT_EQ(r.coef["y"]->to_string(), "a+b"_expr->to_string());
-    EXPECT_EQ(r.constant->to_string(), "a+b"_expr->to_string());
+    EXPECT_EXPR_EQ(r.coef["x"], "a+b"_expr);
+    EXPECT_EXPR_EQ(r.coef["y"], "a+b"_expr);
+    EXPECT_EXPR_EQ(r.constant, "a+b"_expr);
 
     // check 'gating' case still works! (Use plus instead of minus
     // though because of -1 vs (- 1) parsing makes the test harder.)
     r = linear_test("(a+x)/b"_expr, {"x"});
     EXPECT_TRUE(r.is_linear);
     EXPECT_FALSE(r.is_homogeneous);
-    EXPECT_EQ(r.coef["x"]->to_string(), "1/b"_expr->to_string());
-    EXPECT_EQ(r.constant->to_string(), "a/b"_expr->to_string());
+    EXPECT_EXPR_EQ(r.coef["x"], "1/b"_expr);
+    EXPECT_EXPR_EQ(r.constant, "a/b"_expr);
 }
 
 TEST(linear_test, nonlinear) {
-- 
GitLab