From fbd411844ec8fb495d5012ce7cc7332bf1829060 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 21 Dec 2022 13:12:59 +0100 Subject: [PATCH] Better handling of powers in modcc. (#2061) - $x^{-1} \Rightarrow 1/x\quad \forall x$ - $x^n \Rightarrow x\cdot \dots \cdot x \quad x\in N; |x| < 5$ - $x^n \Rightarrow 1/(x\cdot \dots \cdot x) \quad x\in N; |x| < 5; x < 0$ - $b^e \Rightarrow \exp(\log(b) e)\quad \forall b, e$ The last point introduces potential errors when `pow(b, e)` is allowed, but `log(b)` is undefined. These occur exactly when all of the following is true - $b < 0$ - $e\in N$ - $e$, $b$ not known at compile time (since we cover these cases before) --- arbor/include/arbor/simd/simd.hpp | 40 +++++++++++++------------- modcc/printer/cexpr_emit.cpp | 2 +- modcc/printer/cprinter.cpp | 8 ++++-- modcc/symdiff.cpp | 47 +++++++++++++++++++++++++------ test/unit-modcc/test_symdiff.cpp | 31 +++++++++++++++++++- 5 files changed, 96 insertions(+), 32 deletions(-) diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index 3c3ba9e7..48c2e0a0 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -46,24 +46,24 @@ typename detail::simd_impl<Impl>::scalar_type sum(const detail::simd_impl<Impl>& return a.sum(); }; -#define ARB_UNARY_ARITHMETIC_(name)\ -template <typename Impl>\ -detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a) {\ - return detail::simd_impl<Impl>::wrap(Impl::name(a.value_));\ +#define ARB_UNARY_ARITHMETIC_(name) \ +template <typename Impl> \ +detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a) { \ + return detail::simd_impl<Impl>::wrap(Impl::name(a.value_)); \ }; -#define ARB_BINARY_ARITHMETIC_(name)\ -template <typename Impl>\ -detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, detail::simd_impl<Impl> b) {\ - return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, b.value_));\ -};\ -template <typename Impl>\ -detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, typename detail::simd_impl<Impl>::scalar_type b) {\ - return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, Impl::broadcast(b)));\ -};\ -template <typename Impl>\ -detail::simd_impl<Impl> name(const typename detail::simd_impl<Impl>::scalar_type a, detail::simd_impl<Impl> b) {\ - return detail::simd_impl<Impl>::wrap(Impl::name(Impl::broadcast(a), b.value_));\ +#define ARB_BINARY_ARITHMETIC_(name) \ +template <typename Impl> \ +detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, detail::simd_impl<Impl> b) { \ + return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, b.value_)); \ +}; \ +template <typename Impl> \ +detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, typename detail::simd_impl<Impl>::scalar_type b) { \ + return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, Impl::broadcast(b))); \ +}; \ +template <typename Impl> \ +detail::simd_impl<Impl> name(const typename detail::simd_impl<Impl>::scalar_type a, detail::simd_impl<Impl> b) { \ + return detail::simd_impl<Impl>::wrap(Impl::name(Impl::broadcast(a), b.value_)); \ }; #define ARB_BINARY_COMPARISON_(name)\ @@ -661,9 +661,11 @@ namespace detail { // Maths functions are implemented as top-level functions; declare as friends for access to `wrap` - #define ARB_DECLARE_UNARY_ARITHMETIC_(name)\ - template <typename T>\ - friend simd_impl<T> arb::simd::name(const simd_impl<T>& a); + #define ARB_DECLARE_UNARY_ARITHMETIC_(name) \ + template <typename T> \ + friend simd_impl<T> arb::simd::name(const simd_impl<T>& a); \ + template <typename T> \ + friend simd_impl<T> arb::simd::name(const typename simd_impl<T>::scalar_type a); #define ARB_DECLARE_BINARY_ARITHMETIC_(name)\ template <typename T>\ diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 894117b8..4c685a7b 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -214,7 +214,7 @@ void SimdExprEmitter::visit(UnaryExpression* e) { {tok::exp, "S::exp"}, {tok::cos, "S::cos"}, {tok::sin, "S::sin"}, - {tok::log, "S::log"}, + {tok::log, "log"}, // We use an overload to upcast scalars here. {tok::abs, "S::abs"}, {tok::exprelr, "S::exprelr"}, {tok::safeinv, "safeinv"}, diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 708ad24f..40cb5dea 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -168,7 +168,6 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe "using ::std::abs;\n" "using ::std::cos;\n" "using ::std::exp;\n" - "using ::std::log;\n" "using ::std::max;\n" "using ::std::min;\n" "using ::std::pow;\n" @@ -222,10 +221,15 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe " S::where(mask, x) = simd_cast<simd_value>(DBL_EPSILON);\n" " return S::div(ones, x);\n" "}\n" + "\n" + "inline simd_value log(const simd_value& v) { return S::log(v); }\n" + "inline simd_value log(arb_value_type v) { return S::log(S::simd_cast<simd_value>(v)); }\n" "\n"; } else { out << "static constexpr unsigned simd_width_ = 1;\n" - "static constexpr unsigned min_align_ = std::max(alignof(arb_value_type), alignof(arb_index_type));\n\n"; + "static constexpr unsigned min_align_ = std::max(alignof(arb_value_type), alignof(arb_index_type));\n" + "using ::std::log;\n" + "\n"; } // Make implementations diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index 1f0ea6a8..c6fc0dfc 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -300,7 +300,7 @@ private: }; ARB_LIBMODCC_API double expr_value(Expression* e) { - return e && e->is_number()? e->is_number()->value(): NAN; + return e && e->is_number()? e->is_number()->value(): std::nan(""); } class ConstantSimplifyVisitor: public Visitor { @@ -535,24 +535,53 @@ public: void visit(PowBinaryExpression* e) override { auto loc = e->location(); e->lhs()->accept(this); - expression_ptr lhs = result(); + auto lhs = result(); e->rhs()->accept(this); - expression_ptr rhs = result(); + auto rhs = result(); + auto lval = expr_value(lhs); + auto rval = expr_value(rhs); + + auto rint = std::nan(""); + auto rfrac = std::modf(rval, &rint); + + auto mk_f64 = [loc](double v) { return make_expression<NumberExpression>(loc, v); }; if (is_number(lhs) && is_number(rhs)) { - as_number(loc, std::pow(expr_value(lhs),expr_value(rhs))); + as_number(loc, std::pow(lval, rval)); } - else if (expr_value(lhs)==0) { + else if (lval == 0) { as_number(loc, 0); } - else if (expr_value(rhs)==0 || expr_value(lhs)==1) { + else if (lval == 1) { as_number(loc, 1); } - else if (expr_value(rhs)==1) { - result_ = std::move(lhs); + else if (rval == 0) { + as_number(loc, 1); + } + else if (rfrac == 0.0 && std::abs(rint) <= 5.0) { // NOTE somewhat arbitray cut-off; but in line with GCC AFAIR + result_ = lhs->clone(); + for (int ix = 1; ix < std::abs(rint); ++ix) { + result_ = make_expression<MulBinaryExpression>(loc, + lhs->clone(), + std::move(result_)); + } + if (rval < 0.0) { + result_ = make_expression<DivBinaryExpression>(loc, + mk_f64(1.0), + std::move(result_)); + } + } + else if (lval < 0.0) { + result_ = make_expression<PowBinaryExpression>(loc, + mk_f64(lval), + std::move(rhs)); } else { - result_ = make_expression<PowBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + result_ = make_expression<ExpUnaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, + make_expression<LogUnaryExpression>(loc, + std::move(lhs)), + std::move(rhs))); } } diff --git a/test/unit-modcc/test_symdiff.cpp b/test/unit-modcc/test_symdiff.cpp index 22651170..c9bf6ad4 100644 --- a/test/unit-modcc/test_symdiff.cpp +++ b/test/unit-modcc/test_symdiff.cpp @@ -102,6 +102,35 @@ TEST(constant_simplify, constants) { } } +TEST(constant_simplify, powers) { + // Expect simplification of 'before' expression matches parse of 'after'. + // Use output string representation of expression for easy comparison. + + struct { const char* before; const char* after; } tests[] = { + { "x^-1", "1/x" }, + { "x^1", "x" }, + { "x^2", "x*x" }, + { "x^-2", "1/(x*x)" }, + { "2^4", "16" }, + { "x^y", "exp(log(x)*y)" }, + { "(-6)^2", "36" }, + { "(-6)^2", "36" }, + // { "(-3)^3", "-27" }, // NOTE doesn't work due to some parser troubles. + // { "(-6)^x", "-6^x" }, // NOTE doesn't work due to some parser troubles. + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + std::cerr << item.before << " => " << item.after << '\n'; + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + EXPECT_EXPR_EQ(after, constant_simplify(before)); + } +} + + TEST(constant_simplify, simplified_expr) { // Expect simplification of 'before' expression matches parse of 'after'. // Use output string representation of expression for easy comparison. @@ -203,7 +232,7 @@ TEST(symbolic_pdiff, nonlinear) { { "sin(x)", "cos(x)" }, { "exp(2*x)", "2*exp(2*x)" }, { "x^2", "2*x" }, - { "a^x", "log(a)*a^x" } + { "a^x", "log(a)*exp(log(a) * x)" } }; for (const auto& item: tests) { -- GitLab