From fbd411844ec8fb495d5012ce7cc7332bf1829060 Mon Sep 17 00:00:00 2001
From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com>
Date: Wed, 21 Dec 2022 13:12:59 +0100
Subject: [PATCH] Better handling of powers in modcc. (#2061)

- $x^{-1} \Rightarrow 1/x\quad \forall x$
- $x^n \Rightarrow x\cdot \dots \cdot x \quad x\in N; |x| < 5$
- $x^n \Rightarrow 1/(x\cdot \dots \cdot x) \quad x\in N; |x| < 5; x <
0$
- $b^e \Rightarrow \exp(\log(b) e)\quad \forall  b, e$

The last point introduces potential errors when `pow(b, e)` is allowed,
but `log(b)` is
undefined. These occur exactly when all of the following is true
- $b < 0$
- $e\in N$
- $e$, $b$ not known at compile time (since we cover these cases before)
---
 arbor/include/arbor/simd/simd.hpp | 40 +++++++++++++-------------
 modcc/printer/cexpr_emit.cpp      |  2 +-
 modcc/printer/cprinter.cpp        |  8 ++++--
 modcc/symdiff.cpp                 | 47 +++++++++++++++++++++++++------
 test/unit-modcc/test_symdiff.cpp  | 31 +++++++++++++++++++-
 5 files changed, 96 insertions(+), 32 deletions(-)

diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp
index 3c3ba9e7..48c2e0a0 100644
--- a/arbor/include/arbor/simd/simd.hpp
+++ b/arbor/include/arbor/simd/simd.hpp
@@ -46,24 +46,24 @@ typename detail::simd_impl<Impl>::scalar_type sum(const detail::simd_impl<Impl>&
     return a.sum();
 };
 
-#define ARB_UNARY_ARITHMETIC_(name)\
-template <typename Impl>\
-detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a) {\
-    return detail::simd_impl<Impl>::wrap(Impl::name(a.value_));\
+#define ARB_UNARY_ARITHMETIC_(name)                                     \
+template <typename Impl>                                                \
+detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a) {        \
+    return detail::simd_impl<Impl>::wrap(Impl::name(a.value_));         \
 };
 
-#define ARB_BINARY_ARITHMETIC_(name)\
-template <typename Impl>\
-detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, detail::simd_impl<Impl> b) {\
-    return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, b.value_));\
-};\
-template <typename Impl>\
-detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, typename detail::simd_impl<Impl>::scalar_type b) {\
-    return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, Impl::broadcast(b)));\
-};\
-template <typename Impl>\
-detail::simd_impl<Impl> name(const typename detail::simd_impl<Impl>::scalar_type a, detail::simd_impl<Impl> b) {\
-    return detail::simd_impl<Impl>::wrap(Impl::name(Impl::broadcast(a), b.value_));\
+#define ARB_BINARY_ARITHMETIC_(name)                                                                              \
+template <typename Impl>                                                                                          \
+detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, detail::simd_impl<Impl> b) {                       \
+    return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, b.value_));                                         \
+};                                                                                                                \
+template <typename Impl>                                                                                          \
+detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, typename detail::simd_impl<Impl>::scalar_type b) { \
+    return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, Impl::broadcast(b)));                               \
+};                                                                                                                \
+template <typename Impl>                                                                                          \
+detail::simd_impl<Impl> name(const typename detail::simd_impl<Impl>::scalar_type a, detail::simd_impl<Impl> b) {  \
+    return detail::simd_impl<Impl>::wrap(Impl::name(Impl::broadcast(a), b.value_));                               \
 };
 
 #define ARB_BINARY_COMPARISON_(name)\
@@ -661,9 +661,11 @@ namespace detail {
 
         // Maths functions are implemented as top-level functions; declare as friends for access to `wrap`
 
-        #define ARB_DECLARE_UNARY_ARITHMETIC_(name)\
-        template <typename T>\
-        friend simd_impl<T> arb::simd::name(const simd_impl<T>& a);
+        #define ARB_DECLARE_UNARY_ARITHMETIC_(name)                 \
+        template <typename T>                                       \
+        friend simd_impl<T> arb::simd::name(const simd_impl<T>& a); \
+        template <typename T>                                       \
+        friend simd_impl<T> arb::simd::name(const typename simd_impl<T>::scalar_type a);
 
         #define ARB_DECLARE_BINARY_ARITHMETIC_(name)\
         template <typename T>\
diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp
index 894117b8..4c685a7b 100644
--- a/modcc/printer/cexpr_emit.cpp
+++ b/modcc/printer/cexpr_emit.cpp
@@ -214,7 +214,7 @@ void SimdExprEmitter::visit(UnaryExpression* e) {
         {tok::exp,        "S::exp"},
         {tok::cos,        "S::cos"},
         {tok::sin,        "S::sin"},
-        {tok::log,        "S::log"},
+        {tok::log,        "log"}, // We use an overload to upcast scalars here.
         {tok::abs,        "S::abs"},
         {tok::exprelr,    "S::exprelr"},
         {tok::safeinv,    "safeinv"},
diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp
index 708ad24f..40cb5dea 100644
--- a/modcc/printer/cprinter.cpp
+++ b/modcc/printer/cprinter.cpp
@@ -168,7 +168,6 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe
         "using ::std::abs;\n"
         "using ::std::cos;\n"
         "using ::std::exp;\n"
-        "using ::std::log;\n"
         "using ::std::max;\n"
         "using ::std::min;\n"
         "using ::std::pow;\n"
@@ -222,10 +221,15 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe
             "    S::where(mask, x) = simd_cast<simd_value>(DBL_EPSILON);\n"
             "    return S::div(ones, x);\n"
             "}\n"
+            "\n"
+            "inline simd_value log(const simd_value& v) { return S::log(v); }\n"
+            "inline simd_value log(arb_value_type v) { return S::log(S::simd_cast<simd_value>(v)); }\n"
             "\n";
     } else {
        out << "static constexpr unsigned simd_width_ = 1;\n"
-              "static constexpr unsigned min_align_ = std::max(alignof(arb_value_type), alignof(arb_index_type));\n\n";
+              "static constexpr unsigned min_align_ = std::max(alignof(arb_value_type), alignof(arb_index_type));\n"
+              "using ::std::log;\n"
+              "\n";
     }
 
     // Make implementations
diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp
index 1f0ea6a8..c6fc0dfc 100644
--- a/modcc/symdiff.cpp
+++ b/modcc/symdiff.cpp
@@ -300,7 +300,7 @@ private:
 };
 
 ARB_LIBMODCC_API double expr_value(Expression* e) {
-    return e && e->is_number()? e->is_number()->value(): NAN;
+    return e && e->is_number()? e->is_number()->value(): std::nan("");
 }
 
 class ConstantSimplifyVisitor: public Visitor {
@@ -535,24 +535,53 @@ public:
     void visit(PowBinaryExpression* e) override {
         auto loc = e->location();
         e->lhs()->accept(this);
-        expression_ptr lhs = result();
+        auto lhs = result();
         e->rhs()->accept(this);
-        expression_ptr rhs = result();
+        auto rhs = result();
+        auto lval = expr_value(lhs);
+        auto rval = expr_value(rhs);
+
+        auto rint  = std::nan("");
+        auto rfrac = std::modf(rval, &rint);
+
+        auto mk_f64 = [loc](double v) { return make_expression<NumberExpression>(loc, v); };
 
         if (is_number(lhs) && is_number(rhs)) {
-            as_number(loc, std::pow(expr_value(lhs),expr_value(rhs)));
+            as_number(loc, std::pow(lval, rval));
         }
-        else if (expr_value(lhs)==0) {
+        else if (lval == 0) {
             as_number(loc, 0);
         }
-        else if (expr_value(rhs)==0 || expr_value(lhs)==1) {
+        else if (lval == 1) {
             as_number(loc, 1);
         }
-        else if (expr_value(rhs)==1) {
-            result_ = std::move(lhs);
+        else if (rval == 0) {
+            as_number(loc, 1);
+        }
+        else if (rfrac == 0.0 && std::abs(rint) <= 5.0) { // NOTE somewhat arbitray cut-off; but in line with GCC AFAIR
+            result_ = lhs->clone();
+            for (int ix = 1; ix < std::abs(rint); ++ix) {
+                result_ = make_expression<MulBinaryExpression>(loc,
+                                                               lhs->clone(),
+                                                               std::move(result_));
+            }
+            if (rval < 0.0) {
+                result_ = make_expression<DivBinaryExpression>(loc,
+                                                               mk_f64(1.0),
+                                                               std::move(result_));
+            }
+        }
+        else if (lval < 0.0) {
+            result_ = make_expression<PowBinaryExpression>(loc,
+                                                           mk_f64(lval),
+                                                           std::move(rhs));
         }
         else {
-            result_ = make_expression<PowBinaryExpression>(loc, std::move(lhs), std::move(rhs));
+            result_ = make_expression<ExpUnaryExpression>(loc,
+                                                          make_expression<MulBinaryExpression>(loc,
+                                                                                               make_expression<LogUnaryExpression>(loc,
+                                                                                                                                   std::move(lhs)),
+                                                                                               std::move(rhs)));
         }
     }
 
diff --git a/test/unit-modcc/test_symdiff.cpp b/test/unit-modcc/test_symdiff.cpp
index 22651170..c9bf6ad4 100644
--- a/test/unit-modcc/test_symdiff.cpp
+++ b/test/unit-modcc/test_symdiff.cpp
@@ -102,6 +102,35 @@ TEST(constant_simplify, constants) {
     }
 }
 
+TEST(constant_simplify, powers) {
+    // Expect simplification of 'before' expression matches parse of 'after'.
+    // Use output string representation of expression for easy comparison.
+
+    struct { const char* before; const char* after; } tests[] = {
+        { "x^-1",            "1/x" },
+        { "x^1",             "x" },
+        { "x^2",             "x*x" },
+        { "x^-2",            "1/(x*x)" },
+        { "2^4",             "16" },
+        { "x^y",             "exp(log(x)*y)" },
+        { "(-6)^2",          "36" },
+        { "(-6)^2",          "36" },
+        // { "(-3)^3",          "-27" },  // NOTE doesn't work due to some parser troubles.
+        // { "(-6)^x",          "-6^x" }, // NOTE doesn't work due to some parser troubles.
+    };
+
+    for (const auto& item: tests) {
+        SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after);
+        std::cerr << item.before << " => " << item.after << '\n';
+        auto before = Parser{item.before}.parse_expression();
+        auto after = Parser{item.after}.parse_expression();
+        ASSERT_TRUE(before);
+        ASSERT_TRUE(after);
+        EXPECT_EXPR_EQ(after, constant_simplify(before));
+    }
+}
+
+
 TEST(constant_simplify, simplified_expr) {
     // Expect simplification of 'before' expression matches parse of 'after'.
     // Use output string representation of expression for easy comparison.
@@ -203,7 +232,7 @@ TEST(symbolic_pdiff, nonlinear) {
         { "sin(x)",      "cos(x)" },
         { "exp(2*x)",    "2*exp(2*x)" },
         { "x^2",         "2*x" },
-        { "a^x",         "log(a)*a^x" }
+        { "a^x",         "log(a)*exp(log(a) * x)" }
     };
 
     for (const auto& item: tests) {
-- 
GitLab