From 1db7376752ff4912b0874cb3d616f1a556ae76dc Mon Sep 17 00:00:00 2001 From: Sam Yates <yates@cscs.ch> Date: Tue, 28 Mar 2017 14:13:21 +0200 Subject: [PATCH] Bug fix for issue #196 (#211) Fixes #196. Correct treatment of missing coefficients in `cnexp` solver. * Extend `EXPECT_EXPR_EQ` functionality with wrapper that works with `Expression *` and `expression_ptr` arguments. * Replace string comparison checks in `test_symdiff.cpp` with equivalents that use `EXPECT_EXPR_EQ`. * Check explicitly for missing coefficient in `cnexp` solver, which should be treated equivalently to zero. --- modcc/solvers.cpp | 26 +++++++++++++------------- tests/modcc/test.hpp | 11 ++++++++++- tests/modcc/test_symdiff.cpp | 34 +++++++++++++++++----------------- 3 files changed, 40 insertions(+), 31 deletions(-) diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp index f3f4a7d6..e4a4d490 100644 --- a/modcc/solvers.cpp +++ b/modcc/solvers.cpp @@ -49,19 +49,7 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) { } Expression* coef = r.coef[s].get(); - if (r.is_homogeneous) { - // s' = a*s becomes s = s*exp(a*dt); use a_ as a local variable - // for the coefficient. - auto local_a_term = make_unique_local_assign(scope, coef, "a_"); - statements_.push_back(std::move(local_a_term.local_decl)); - statements_.push_back(std::move(local_a_term.assignment)); - auto a_ = local_a_term.id->is_identifier()->spelling(); - - std::string s_update = pprintf("% = %*exp(%*dt)", s, s, a_); - statements_.push_back(Parser{s_update}.parse_line_expression()); - return; - } - else if (is_zero(coef)) { + if (!coef || is_zero(coef)) { // s' = b becomes s = s + b*dt; use b_ as a local variable for // the constant term b. @@ -74,6 +62,18 @@ void CnexpSolverVisitor::visit(AssignmentExpression *e) { statements_.push_back(Parser{s_update}.parse_line_expression()); return; } + else if (r.is_homogeneous) { + // s' = a*s becomes s = s*exp(a*dt); use a_ as a local variable + // for the coefficient. + auto local_a_term = make_unique_local_assign(scope, coef, "a_"); + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + auto a_ = local_a_term.id->is_identifier()->spelling(); + + std::string s_update = pprintf("% = %*exp(%*dt)", s, s, a_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } else { // s' = a*s + b becomes s = -b/a + (s+b/a)*exp(a*dt); use // a_ as a local variable for the coefficient and ba_ for the diff --git a/tests/modcc/test.hpp b/tests/modcc/test.hpp index fe6e7666..d0059e3c 100644 --- a/tests/modcc/test.hpp +++ b/tests/modcc/test.hpp @@ -33,11 +33,20 @@ inline expression_ptr parse_procedure(std::string const& s) { // Strip ANSI control sequences from `to_string` output. std::string plain_text(Expression* expr); +// Generic get expression pointer from raw pointer or unique_ptr wrapper. +inline Expression* raw_expression(Expression* p) { return p; } +inline Expression* raw_expression(const expression_ptr &p) { return p.get(); } + // Compare two expressions via their representation. // Use with EXPECT_PRED_FORMAT2. ::testing::AssertionResult assert_expr_eq(const char *arg1, const char *arg2, Expression* expected, Expression* value); -#define EXPECT_EXPR_EQ(a,b) EXPECT_PRED_FORMAT2(assert_expr_eq, a, b) +template <typename E1, typename E2> +::testing::AssertionResult assert_expr_eq_wrap(const char *arg1, const char *arg2, E1&& expected, E2&& value) { + return assert_expr_eq(arg1, arg2, raw_expression(std::forward<E1>(expected)), raw_expression(std::forward<E2>(value))); +} + +#define EXPECT_EXPR_EQ(a,b) EXPECT_PRED_FORMAT2(assert_expr_eq_wrap, a, b) // Print arguments, but only if verbose flag set. // Use `to_string()` to print (smart) pointers to Expression or Scope objects. diff --git a/tests/modcc/test_symdiff.cpp b/tests/modcc/test_symdiff.cpp index 7f709825..49571d3f 100644 --- a/tests/modcc/test_symdiff.cpp +++ b/tests/modcc/test_symdiff.cpp @@ -122,7 +122,7 @@ TEST(constant_simplify, simplified_expr) { ASSERT_TRUE(before); ASSERT_TRUE(after); - EXPECT_EQ(after->to_string(), constant_simplify(before)->to_string()); + EXPECT_EXPR_EQ(after, constant_simplify(before)); } } @@ -156,7 +156,7 @@ TEST(constant_simplify, block_with_if) { ASSERT_TRUE(before); ASSERT_TRUE(after); - EXPECT_EQ(after->to_string(), constant_simplify(before)->to_string()); + EXPECT_EXPR_EQ(after, constant_simplify(before)); } TEST(symbolic_pdiff, expressions) { @@ -177,7 +177,7 @@ TEST(symbolic_pdiff, expressions) { ASSERT_TRUE(before); ASSERT_TRUE(after); - EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + EXPECT_EXPR_EQ(after, symbolic_pdiff(before, "x")); } } @@ -195,7 +195,7 @@ TEST(symbolic_pdiff, linear) { ASSERT_TRUE(before); ASSERT_TRUE(after); - EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + EXPECT_EXPR_EQ(after, symbolic_pdiff(before, "x")); } } @@ -214,7 +214,7 @@ TEST(symbolic_pdiff, nonlinear) { ASSERT_TRUE(before); ASSERT_TRUE(after); - EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + EXPECT_EXPR_EQ(after, symbolic_pdiff(before, "x")); } } @@ -242,7 +242,7 @@ TEST(substitute, expressions) { ASSERT_TRUE(after); auto result = substitute(before.get(), "x", yplusz.get()); - EXPECT_EQ(after->to_string(), result->to_string()); + EXPECT_EXPR_EQ(after, result); } } @@ -259,7 +259,7 @@ TEST(substitute, exprmap) { ASSERT_TRUE(after); auto result = substitute(before.get(), subs); - EXPECT_EQ(after->to_string(), result->to_string()); + EXPECT_EXPR_EQ(after, result); } TEST(linear_test, homogeneous) { @@ -269,14 +269,14 @@ TEST(linear_test, homogeneous) { EXPECT_TRUE(r.is_linear); EXPECT_TRUE(r.is_homogeneous); EXPECT_TRUE(r.monolinear()); - EXPECT_EQ(r.coef["x"]->to_string(), "3"_expr->to_string()); + EXPECT_EXPR_EQ(r.coef["x"], "3"_expr); r = linear_test("y-a*x+2*x"_expr, {"x", "y"}); EXPECT_TRUE(r.is_linear); EXPECT_TRUE(r.is_homogeneous); EXPECT_FALSE(r.monolinear()); - EXPECT_EQ(r.coef["x"]->to_string(), "-a+2"_expr->to_string()); - EXPECT_EQ(r.coef["y"]->to_string(), "1"_expr->to_string()); + EXPECT_EXPR_EQ(r.coef["x"], "-a+2"_expr); + EXPECT_EXPR_EQ(r.coef["y"], "1"_expr); } TEST(linear_test, inhomogeneous) { @@ -285,23 +285,23 @@ TEST(linear_test, inhomogeneous) { r = linear_test("sin(y)+3*x"_expr, {"x"}); EXPECT_TRUE(r.is_linear); EXPECT_FALSE(r.is_homogeneous); - EXPECT_EQ(r.coef["x"]->to_string(), "3"_expr->to_string()); - EXPECT_EQ(r.constant->to_string(), "sin(y)"_expr->to_string()); + EXPECT_EXPR_EQ(r.coef["x"], "3"_expr); + EXPECT_EXPR_EQ(r.constant, "sin(y)"_expr); r = linear_test("(x+y+1)*(a+b)"_expr, {"x", "y"}); EXPECT_TRUE(r.is_linear); EXPECT_FALSE(r.is_homogeneous); - EXPECT_EQ(r.coef["x"]->to_string(), "a+b"_expr->to_string()); - EXPECT_EQ(r.coef["y"]->to_string(), "a+b"_expr->to_string()); - EXPECT_EQ(r.constant->to_string(), "a+b"_expr->to_string()); + EXPECT_EXPR_EQ(r.coef["x"], "a+b"_expr); + EXPECT_EXPR_EQ(r.coef["y"], "a+b"_expr); + EXPECT_EXPR_EQ(r.constant, "a+b"_expr); // check 'gating' case still works! (Use plus instead of minus // though because of -1 vs (- 1) parsing makes the test harder.) r = linear_test("(a+x)/b"_expr, {"x"}); EXPECT_TRUE(r.is_linear); EXPECT_FALSE(r.is_homogeneous); - EXPECT_EQ(r.coef["x"]->to_string(), "1/b"_expr->to_string()); - EXPECT_EQ(r.constant->to_string(), "a/b"_expr->to_string()); + EXPECT_EXPR_EQ(r.coef["x"], "1/b"_expr); + EXPECT_EXPR_EQ(r.constant, "a/b"_expr); } TEST(linear_test, nonlinear) { -- GitLab