From 37ed1ce4fa60390dd0c084a92b6a86ebcd8b9aa7 Mon Sep 17 00:00:00 2001
From: Nora Abi Akar <nora.abiakar@gmail.com>
Date: Tue, 23 Mar 2021 10:51:31 +0100
Subject: [PATCH] modcc: modify semantic analysis and error reporting  (#1450)

* Fix semantic passes and error visitor expression coverage in modcc.
---
 modcc/errorvisitor.cpp |  47 ++++++++++++++--
 modcc/errorvisitor.hpp |  23 +++++---
 modcc/expression.cpp   | 125 ++++++-----------------------------------
 modcc/expression.hpp   |   1 -
 4 files changed, 72 insertions(+), 124 deletions(-)

diff --git a/modcc/errorvisitor.cpp b/modcc/errorvisitor.cpp
index 225ec97d..eddf59b8 100644
--- a/modcc/errorvisitor.cpp
+++ b/modcc/errorvisitor.cpp
@@ -14,7 +14,6 @@ void ErrorVisitor::visit(ProcedureExpression *e) {
     for(auto& expression : e->args()) {
         expression->accept(this);
     }
-
     e->body()->accept(this);
     print_error(e);
 }
@@ -24,18 +23,17 @@ void ErrorVisitor::visit(FunctionExpression *e) {
     for(auto& expression : e->args()) {
         expression->accept(this);
     }
-
     e->body()->accept(this);
     print_error(e);
 }
 
 // an if statement
 void ErrorVisitor::visit(IfExpression *e) {
+    e->condition()->accept(this);
     e->true_branch()->accept(this);
     if(e->false_branch()) {
         e->false_branch()->accept(this);
     }
-
     print_error(e);
 }
 
@@ -43,7 +41,6 @@ void ErrorVisitor::visit(BlockExpression* e) {
     for(auto& expression : e->statements()) {
         expression->accept(this);
     }
-
     print_error(e);
 }
 
@@ -51,7 +48,6 @@ void ErrorVisitor::visit(InitialBlock* e) {
     for(auto& expression : e->statements()) {
         expression->accept(this);
     }
-
     print_error(e);
 }
 
@@ -68,7 +64,7 @@ void ErrorVisitor::visit(BinaryExpression *e) {
     print_error(e);
 }
 
-// binary expresssion
+// call expresssion
 void ErrorVisitor::visit(CallExpression *e) {
     for(auto& expression: e->args()) {
         expression->accept(this);
@@ -76,3 +72,42 @@ void ErrorVisitor::visit(CallExpression *e) {
     print_error(e);
 }
 
+// reaction expresssion
+void ErrorVisitor::visit(ReactionExpression *e) {
+    e->lhs()->accept(this);
+    e->rhs()->accept(this);
+    e->fwd_rate()->accept(this);
+    e->rev_rate()->accept(this);
+    print_error(e);
+}
+
+// stoich expresssion
+void ErrorVisitor::visit(StoichExpression *e) {
+    for (auto& expression: e->terms()) {
+        expression->accept(this);
+    }
+    print_error(e);
+}
+
+// stoich term expresssion
+void ErrorVisitor::visit(StoichTermExpression *e) {
+    e->ident()->accept(this);
+    e->coeff()->accept(this);
+    print_error(e);
+}
+
+// compartment expresssion
+void ErrorVisitor::visit(CompartmentExpression *e) {
+    e->scale_factor()->accept(this);
+    for (auto& expression: e->state_vars()) {
+        expression->accept(this);
+    }
+    print_error(e);
+}
+
+// pdiff expresssion
+void ErrorVisitor::visit(PDiffExpression *e) {
+    e->var()->accept(this);
+    e->arg()->accept(this);
+    print_error(e);
+}
\ No newline at end of file
diff --git a/modcc/errorvisitor.hpp b/modcc/errorvisitor.hpp
index 39afad2b..c9b63385 100644
--- a/modcc/errorvisitor.hpp
+++ b/modcc/errorvisitor.hpp
@@ -10,16 +10,21 @@ public:
         : module_name_(m)
     {}
 
-    void visit(Expression *e)           override;
-    void visit(ProcedureExpression *e)  override;
-    void visit(FunctionExpression *e)   override;
-    void visit(UnaryExpression *e)      override;
-    void visit(BinaryExpression *e)     override;
-    void visit(CallExpression *e)       override;
+    void visit(Expression *e)            override;
+    void visit(ProcedureExpression *e)   override;
+    void visit(FunctionExpression *e)    override;
+    void visit(UnaryExpression *e)       override;
+    void visit(BinaryExpression *e)      override;
+    void visit(CallExpression *e)        override;
+    void visit(ReactionExpression *e)    override;
+    void visit(StoichExpression *e)      override;
+    void visit(StoichTermExpression *e)  override;
+    void visit(CompartmentExpression *e) override;
+    void visit(PDiffExpression *e)       override;
 
-    void visit(BlockExpression *e)      override;
-    void visit(InitialBlock *e)         override;
-    void visit(IfExpression *e)         override;
+    void visit(BlockExpression *e)       override;
+    void visit(InitialBlock *e)          override;
+    void visit(IfExpression *e)          override;
 
     int num_errors()   {return num_errors_;}
     int num_warnings() {return num_warnings_;}
diff --git a/modcc/expression.cpp b/modcc/expression.cpp
index 74f50ef1..8cb81f71 100644
--- a/modcc/expression.cpp
+++ b/modcc/expression.cpp
@@ -148,6 +148,9 @@ void DerivativeExpression::semantic(scope_ptr scp) {
     error_ = false;
 
     IdentifierExpression::semantic(scp);
+    if (has_error()) {
+        return;
+    }
     auto v = symbol_->is_variable();
     if (!v || !v->is_state()) {
         error( pprintf("the variable '%' must be a state variable to be differentiated",
@@ -323,17 +326,8 @@ void ReactionExpression::semantic(scope_ptr scp) {
     fwd_rate()->semantic(scp);
     rev_rate()->semantic(scp);
 
-    std::string msg = lhs_->has_error() ? lhs_->error_message() :
-                      rhs_->has_error() ? rhs_->error_message() :
-                      fwd_rate_->has_error() ? fwd_rate_->error_message() :
-                      rev_rate_->has_error() ? rev_rate_->error_message() : "";
-
-    if (!msg.empty()) {
-        error(msg);
-        return;
-    }
-
-    if(fwd_rate_->is_procedure_call() || rev_rate_->is_procedure_call()) {
+    if((!fwd_rate_->has_error() && fwd_rate_->is_procedure_call()) ||
+       (!rev_rate_->has_error() && rev_rate_->is_procedure_call())) {
         error("procedure calls can't be made in an expression");
     }
 }
@@ -350,11 +344,7 @@ expression_ptr StoichTermExpression::clone() const {
 void StoichTermExpression::semantic(scope_ptr scp) {
     error_ = false;
     scope_ = scp;
-
     ident()->semantic(scp);
-    if(ident()->has_error()) {
-        error(ident()->error_message());
-    }
 }
 
 /*******************************************************************************
@@ -387,9 +377,6 @@ void StoichExpression::semantic(scope_ptr scp) {
 
     for(auto& e: terms()) {
         e->semantic(scp);
-        if(e->has_error()) {
-            error(e->error_message());
-        }
     }
 }
 
@@ -423,10 +410,9 @@ std::string CompartmentExpression::to_string() const {
 void CompartmentExpression::semantic(scope_ptr scp) {
     error_ = false;
     scope_ = scp;
-
     scale_factor()->semantic(scp);
-    if(scale_factor()->has_error()) {
-        error(scale_factor()->error_message());
+    for (auto& e: state_vars_) {
+        e->semantic(scp);
     }
 }
 
@@ -446,14 +432,7 @@ void LinearExpression::semantic(scope_ptr scp) {
     lhs_->semantic(scp);
     rhs_->semantic(scp);
 
-    std::string msg = lhs_->has_error() ? lhs_->error_message() :
-                      rhs_->has_error() ? rhs_->error_message() : "";
-
-    if (!msg.empty()) {
-        error(msg);
-        return;
-    }
-    if(rhs_->is_procedure_call()) {
+    if(!rhs_->has_error() && rhs_->is_procedure_call()) {
         error("procedure calls can't be made in an expression");
     }
 }
@@ -474,14 +453,7 @@ void ConserveExpression::semantic(scope_ptr scp) {
     lhs_->semantic(scp);
     rhs_->semantic(scp);
 
-    std::string msg = lhs_->has_error() ? lhs_->error_message() :
-                      rhs_->has_error() ? rhs_->error_message() : "";
-
-    if (!msg.empty()) {
-        error(msg);
-        return;
-    }
-    if(rhs_->is_procedure_call()) {
+    if(!rhs_->has_error() && rhs_->is_procedure_call()) {
         error("procedure calls can't be made in an expression");
     }
 }
@@ -539,9 +511,6 @@ void CallExpression::semantic(scope_ptr scp) {
     // perform semantic analysis on the arguments
     for(auto& a : args_) {
         a->semantic(scp);
-        if(a->has_error()) {
-            error(a->error_message());
-        }
     }
 }
 
@@ -587,9 +556,6 @@ void ProcedureExpression::semantic(scope_ptr scp) {
     // add the argumemts to the list of local variables
     for(auto& a : args_) {
         a->semantic(scope_);
-        if(a->has_error()) {
-            error(a->error_message());
-        }
     }
 
     // this loop could be used to then check the types of statements in the body
@@ -600,9 +566,6 @@ void ProcedureExpression::semantic(scope_ptr scp) {
 
     // perform semantic analysis for each expression in the body
     body_->semantic(scope_);
-    if(body_->has_error()) {
-        error(body_->error_message());
-    }
 
     // the symbol for this expression is itself
     symbol_ = scope_->find_global(name());
@@ -697,16 +660,10 @@ void NetReceiveExpression::semantic(scope_type::symbol_map &global_symbols) {
     // add the argumemts to the list of local variables
     for(auto& a : args_) {
         a->semantic(scope_);
-        if(a->has_error()) {
-            error(a->error_message());
-        }
     }
 
     // perform semantic analysis for each expression in the body
     body_->semantic(scope_);
-    if(body_->has_error()) {
-        error(body_->error_message());
-    }
 
     // this loop could be used to then check the types of statements in the body
     for(auto& e : *(body_->is_block())) {
@@ -746,16 +703,10 @@ void PostEventExpression::semantic(scope_type::symbol_map &global_symbols) {
     // add the argumemts to the list of local variables
     for(auto& a : args_) {
         a->semantic(scope_);
-        if(a->has_error()) {
-            error(a->error_message());
-        }
     }
 
     // perform semantic analysis for each expression in the body
     body_->semantic(scope_);
-    if(body_->has_error()) {
-        error(body_->error_message());
-    }
 
     symbol_ = scope_->find_global(name());
 }
@@ -792,9 +743,6 @@ void FunctionExpression::semantic(scope_type::symbol_map &global_symbols) {
     // add the argumemts to the list of local variables
     for(auto& a : args_) {
         a->semantic(scope_);
-        if(a->has_error()) {
-            error(a->error_message());
-        }
     }
 
     // Add a variable that has the same name as the function,
@@ -808,12 +756,12 @@ void FunctionExpression::semantic(scope_type::symbol_map &global_symbols) {
 
     // perform semantic analysis for each expression in the body
     body_->semantic(scope_);
-    if(body_->has_error()) {
-        error(body_->error_message());
-    }
+
     // this loop could be used to then check the types of statements in the body
     for(auto& e : *(body())) {
-        if(e->is_initial_block()) error("INITIAL block not allowed inside FUNCTION definition");
+        if(e->is_initial_block()) {
+            error("INITIAL block not allowed inside FUNCTION definition");
+        }
     }
 
     // the symbol for this expression is itself
@@ -829,11 +777,7 @@ void UnaryExpression::semantic(scope_ptr scp) {
     scope_ = scp;
 
     expression_->semantic(scp);
-    if(expression_->has_error()) {
-        error(expression_->error_message());
-        return;
-    }
-    if(expression_->is_procedure_call()) {
+    if(!expression_->has_error() && expression_->is_procedure_call()) {
         error("a procedure call can't be part of an expression");
     }
 }
@@ -856,14 +800,8 @@ void BinaryExpression::semantic(scope_ptr scp) {
     lhs_->semantic(scp);
     rhs_->semantic(scp);
 
-    std::string msg = lhs_->has_error() ? lhs_->error_message() :
-                      rhs_->has_error() ? rhs_->error_message() : "";
-
-    if (!msg.empty()) {
-        error(msg);
-        return;
-    }
-    if(rhs_->is_procedure_call() || lhs_->is_procedure_call()) {
+    if((!rhs_->has_error() && rhs_->is_procedure_call()) ||
+       (!lhs_->has_error() && lhs_->is_procedure_call())) {
         error("procedure calls can't be made in an expression");
     }
 }
@@ -896,14 +834,6 @@ void AssignmentExpression::semantic(scope_ptr scp) {
     lhs_->semantic(scp);
     rhs_->semantic(scp);
 
-    std::string msg = lhs_->has_error() ? lhs_->error_message() :
-                      rhs_->has_error() ? rhs_->error_message() : "";
-
-    if (!msg.empty()) {
-        error(msg);
-        return;
-    }
-
     // only flag an lvalue error if there was no error in the lhs expression
     // this ensures that we don't print redundant error messages when trying
     // to write to an undeclared variable
@@ -986,9 +916,6 @@ void BlockExpression::semantic(scope_ptr scp) {
     scope_ = scp;
     for(auto& e : statements_) {
         e->semantic(scope_);
-        if(e->has_error()) {
-            error(e->error_message());
-        }
     }
 }
 
@@ -1021,24 +948,12 @@ void IfExpression::semantic(scope_ptr scp) {
     scope_ = scp;
 
     condition_->semantic(scp);
-    if(condition_->has_error()) {
-        error(condition()->error_message());
-    }
-
-    if(!condition_->is_conditional()) {
+    if(!condition_->has_error() && !condition_->is_conditional()) {
         error("not a valid conditional expression");
     }
-
     true_branch_->semantic(scp);
-    if(true_branch_->has_error()) {
-        error(true_branch_->error_message());
-    }
-
     if(false_branch_) {
         false_branch_->semantic(scp);
-        if(false_branch_->has_error()) {
-            error(false_branch_->error_message());
-        }
     }
 }
 
@@ -1072,13 +987,7 @@ void PDiffExpression::semantic(scope_ptr scp) {
                       "an identifier, but instead %", yellow(var_->to_string())));
     }
     var_->semantic(scp);
-    if(var_->has_error()) {
-        error(var_->error_message());
-    }
     arg_->semantic(scp);
-    if(arg_->has_error()) {
-        error(arg_->error_message());
-    }
 }
 
 expression_ptr PDiffExpression::clone() const {
diff --git a/modcc/expression.hpp b/modcc/expression.hpp
index f12b429d..fc7c3dac 100644
--- a/modcc/expression.hpp
+++ b/modcc/expression.hpp
@@ -45,7 +45,6 @@ class Symbol;
 class ConductanceExpression;
 class PDiffExpression;
 class VariableExpression;
-class ProcedureExpression;
 class NetReceiveExpression;
 class PostEventExpression;
 class APIMethod;
-- 
GitLab