From dfb3209453dcb3cf8ee04fbe3c6d35ee60f671b2 Mon Sep 17 00:00:00 2001
From: Sam Yates <yates@cscs.ch>
Date: Tue, 13 Dec 2016 20:03:25 +0100
Subject: [PATCH] Fix modcc precedence parsing bug (#127)

* Modify `parse_expression` to take a controlling (parent) precedence.
* `parse_expression` folds left over sequences of sub-expressions with decreasing operator precedence (accumulates in `lhs`).
* Use recursion rather than accumulator for left fold in `parse_binop` to simplify code logic.
* Extend parser unit test to cover more complicated, multi-level expression.
* Remove (now) redundant parenthesis from derivative check block in kinetic rewriter test.

Fixes #94
---
 modcc/parser.cpp                      | 82 +++++++++++++--------------
 modcc/parser.hpp                      |  1 +
 tests/modcc/test_kinetic_rewriter.cpp |  2 +-
 tests/modcc/test_parser.cpp           | 20 ++++---
 4 files changed, 55 insertions(+), 50 deletions(-)

diff --git a/modcc/parser.cpp b/modcc/parser.cpp
index bebef194..7bf11c37 100644
--- a/modcc/parser.cpp
+++ b/modcc/parser.cpp
@@ -1095,27 +1095,34 @@ expression_ptr Parser::parse_conserve_expression() {
     return make_expression<ConserveExpression>(here, std::move(lhs), std::move(rhs));
 }
 
-expression_ptr Parser::parse_expression() {
+expression_ptr Parser::parse_expression(int prec) {
     auto lhs = parse_unaryop();
+    if(lhs==nullptr) return nullptr;
 
-    if(lhs==nullptr) { // error
-        return nullptr;
-    }
-
-    // we parse a binary expression if followed by an operator
-    if( binop_precedence(token_.type)>0 ) {
+    // Combine all sub-expressions with precedence greater than prec.
+    for (;;) {
         if(token_.type==tok::eq) {
             error("assignment '"+yellow("=")+"' not allowed in sub-expression");
             return nullptr;
         }
-        Token op = token_;  // save the operator
-        get_token();        // consume the operator
-        return parse_binop(std::move(lhs), op);
+
+        auto op = token_;
+        auto p_op = binop_precedence(op.type);
+
+        if(p_op<=prec) return lhs;
+        get_token();
+
+        lhs = parse_binop(std::move(lhs), op);
+        if(!lhs) return nullptr;
     }
 
     return lhs;
 }
 
+expression_ptr Parser::parse_expression() {
+    return parse_expression(0);
+}
+
 /// Parse a unary expression.
 /// If called when the current node in the AST is not a unary expression the call
 /// will be forwarded to parse_primary. This mechanism makes it possible to parse
@@ -1215,44 +1222,35 @@ expression_ptr Parser::parse_integer() {
 }
 
 expression_ptr Parser::parse_binop(expression_ptr&& lhs, Token op_left) {
-    // only way out of the loop below is by return:
-    //      :: return with nullptr on error
-    //      :: return when loop runs out of operators
-    //          i.e. if(pp<0)
-    //      :: return when recursion applied to remainder of expression
-    //          i.e. if(p_op>p_left)
-    while(1) {
-        // get precedence of the left operator
-        auto p_left = binop_precedence(op_left.type);
+    auto p_op_left = binop_precedence(op_left.type);
+    auto rhs = parse_expression(p_op_left);
+    if(!rhs) return nullptr;
 
-        auto e = parse_unaryop();
-        if(!e) return nullptr;
+    auto op_right = token_;
+    auto p_op_right = binop_precedence(op_right.type);
+    bool right_assoc = operator_associativity(op_right.type)==associativityKind::right;
 
-        auto op = token_;
-        auto p_op = binop_precedence(op.type);
-        if(operator_associativity(op.type)==associativityKind::right) {
-            p_op += 1;
-        }
+    if(p_op_right>p_op_left) {
+        throw compiler_exception(
+            "parse_binop() : encountered operator of higher precedence",
+            location_);
+    }
 
-        //  if no binop, parsing of expression is finished with (op_left lhs e)
-        if(p_op < 0) {
-            return binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(e));
-        }
+    if(p_op_right<p_op_left) {
+        return binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(rhs));
+    }
 
-        get_token(); // consume op
-        if(p_op > p_left) {
-            auto rhs = parse_binop(std::move(e), op);
-            if(!rhs) return nullptr;
-            return binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(rhs));
-        }
+    get_token(); // consume op_right
+    if(right_assoc) {
+        rhs = parse_binop(std::move(rhs), op_right);
+        if(!rhs) return nullptr;
 
-        lhs = binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(e));
-        op_left = op;
+        return binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(rhs));
+    }
+    else {
+        lhs = binary_expression(op_left.location, op_left.type, std::move(lhs), std::move(rhs));
+        return parse_binop(std::move(lhs), op_right);
     }
-    throw compiler_exception(
-        "parse_binop() : fell out of recursive parse descent",
-        location_);
-    return nullptr;
 }
 
 /// parse a local variable definition
diff --git a/modcc/parser.hpp b/modcc/parser.hpp
index dc673e9d..10a33e59 100644
--- a/modcc/parser.hpp
+++ b/modcc/parser.hpp
@@ -20,6 +20,7 @@ public:
     expression_ptr parse_integer();
     expression_ptr parse_real();
     expression_ptr parse_call();
+    expression_ptr parse_expression(int prec);
     expression_ptr parse_expression();
     expression_ptr parse_primary();
     expression_ptr parse_parenthesis_expression();
diff --git a/tests/modcc/test_kinetic_rewriter.cpp b/tests/modcc/test_kinetic_rewriter.cpp
index 53a32d91..4dc1017c 100644
--- a/tests/modcc/test_kinetic_rewriter.cpp
+++ b/tests/modcc/test_kinetic_rewriter.cpp
@@ -43,7 +43,7 @@ static const char* derivative_abc =
     "    a' = -3*a + b*v       \n"
     "    LOCAL rev2            \n"
     "    rev2 = c*b^3*sin(4)   \n"
-    "    b' = 3*a - (v*b) + 8*b - 2*rev2\n"
+    "    b' = 3*a - v*b + 8*b - 2*rev2\n"
     "    c' = 4*b - rev2       \n"
     "}                         \n";
 
diff --git a/tests/modcc/test_parser.cpp b/tests/modcc/test_parser.cpp
index 8cab1b8b..853b2136 100644
--- a/tests/modcc/test_parser.cpp
+++ b/tests/modcc/test_parser.cpp
@@ -513,13 +513,15 @@ long double eval(Expression *e) {
 // test parsing of expressions for correctness
 // by parsing rvalue expressions with numeric atoms, which can be evalutated using eval
 TEST(Parser, parse_binop) {
+    using std::pow;
+
     std::pair<const char*, double> tests[] = {
         // simple
         {"2+3", 2.+3.},
         {"2-3", 2.-3.},
         {"2*3", 2.*3.},
         {"2/3", 2./3.},
-        {"2^3", std::pow(2., 3.)},
+        {"2^3", pow(2., 3.)},
 
         // more complicated
         {"2+3*2", 2.+(3*2)},
@@ -530,12 +532,16 @@ TEST(Parser, parse_binop) {
         {"2 * 7 - 3 * 11 + 4 * 13", 2.*7.-3.*11.+4.*13.},
 
         // right associative
-        {"2^3^1.5", std::pow(2.,std::pow(3.,1.5))},
-        {"2^3^1.5^2", std::pow(2.,std::pow(3.,std::pow(1.5,2.)))},
-        {"2^2^3", std::pow(2.,std::pow(2.,3.))},
-        {"(2^2)^3", std::pow(std::pow(2.,2.),3.)},
-        {"3./2^7.", 3./std::pow(2.,7.)},
-        {"3^2*5.", std::pow(3.,2.)*5.},
+        {"2^3^1.5", pow(2.,pow(3.,1.5))},
+        {"2^3^1.5^2", pow(2.,pow(3.,pow(1.5,2.)))},
+        {"2^2^3", pow(2.,pow(2.,3.))},
+        {"(2^2)^3", pow(pow(2.,2.),3.)},
+        {"3./2^7.", 3./pow(2.,7.)},
+        {"3^2*5.", pow(3.,2.)*5.},
+
+        // multilevel
+        {"1-2*3^4*5^2^3-3^2^3/4/8-5",
+            1.-2*pow(3.,4.)*pow(5.,pow(2.,3.))-pow(3,pow(2.,3.))/4./8.-5}
     };
 
     for (const auto& test_case: tests) {
-- 
GitLab