diff --git a/arbor/iexpr.cpp b/arbor/iexpr.cpp index 339e4e0dcf01591c99a695f783f171812c7e5659..eaaba7cd07b61957c976601d7061f6269ed653a2 100644 --- a/arbor/iexpr.cpp +++ b/arbor/iexpr.cpp @@ -281,6 +281,32 @@ struct exp: public iexpr_interface { iexpr_ptr value; }; +struct step_right: public iexpr_interface { + step_right(iexpr_ptr v): value(std::move(v)) {} + + double eval(const mprovider& p, const mcable& c) const override { + double x = value->eval(p, c); + // x < 0: 0 + // x >= 0: 1 + return (x >= 0.); + } + + iexpr_ptr value; +}; + +struct step_left: public iexpr_interface { + step_left(iexpr_ptr v): value(std::move(v)) {} + + double eval(const mprovider& p, const mcable& c) const override { + double x = value->eval(p, c); + // x <= 0: 0 + // x > 0: 1 + return (x > 0.); + } + + iexpr_ptr value; +}; + struct step: public iexpr_interface { step(iexpr_ptr v): value(std::move(v)) {} @@ -416,6 +442,10 @@ iexpr iexpr::div(iexpr left, iexpr right) { iexpr iexpr::exp(iexpr value) { return iexpr(iexpr_type::exp, std::make_tuple(std::move(value))); } +iexpr iexpr::step_right(iexpr value) { return iexpr(iexpr_type::step_right, std::make_tuple(std::move(value))); } + +iexpr iexpr::step_left(iexpr value) { return iexpr(iexpr_type::step_left, std::make_tuple(std::move(value))); } + iexpr iexpr::step(iexpr value) { return iexpr(iexpr_type::step, std::make_tuple(std::move(value))); } iexpr iexpr::log(iexpr value) { return iexpr(iexpr_type::log, std::make_tuple(std::move(value))); } @@ -504,6 +534,12 @@ iexpr_ptr thingify(const iexpr& expr, const mprovider& m) { case iexpr_type::exp: return iexpr_ptr(new iexpr_impl::exp( thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m))); + case iexpr_type::step_right: + return iexpr_ptr(new iexpr_impl::step_right( + thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m))); + case iexpr_type::step_left: + return iexpr_ptr(new iexpr_impl::step_left( + thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m))); case iexpr_type::step: return iexpr_ptr(new iexpr_impl::step( thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m))); @@ -600,6 +636,14 @@ std::ostream& operator<<(std::ostream& o, const iexpr& e) { o << "exp " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args())); break; } + case iexpr_type::step_right: { + o << "step_right " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args())); + break; + } + case iexpr_type::step_left: { + o << "step_left " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args())); + break; + } case iexpr_type::step: { o << "step " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args())); break; diff --git a/arbor/include/arbor/iexpr.hpp b/arbor/include/arbor/iexpr.hpp index 3164c53d975e1c071f7d8ede4ef7124d74932428..124612e59ecaac5b039a487b823c53ac9f23c3ad 100644 --- a/arbor/include/arbor/iexpr.hpp +++ b/arbor/include/arbor/iexpr.hpp @@ -29,6 +29,8 @@ enum class iexpr_type { mul, div, exp, + step_right, + step_left, step, log, named @@ -98,6 +100,10 @@ struct ARB_SYMBOL_VISIBLE iexpr { static iexpr exp(iexpr value); + static iexpr step_right(iexpr value); + + static iexpr step_left(iexpr value); + static iexpr step(iexpr value); static iexpr log(iexpr value); diff --git a/arbor/include/arbor/simd/implbase.hpp b/arbor/include/arbor/simd/implbase.hpp index 4c3766b210e8861dc3b49d4979f81af87e38c722..968ffce0353a428809f9ff3f71b2fa60eb99c90a 100644 --- a/arbor/include/arbor/simd/implbase.hpp +++ b/arbor/include/arbor/simd/implbase.hpp @@ -22,6 +22,7 @@ // exp | lane-wise std::exp // log | lane-wise std::log // pow | lane-wise std::pow +// sqrt | lane-wise std::sqrt // expm1 | lane-wise std::expm1 // exprelr | expm1, div, add, cmp_eq, ifelse // @@ -468,6 +469,35 @@ struct implbase { return I::ifelse(I::cmp_gt(t, s), t, s); } + static vector_type step_right(const vector_type& s) { + vector_type zeros = I::broadcast(0); + vector_type ones = I::broadcast(1); + return I::ifelse(I::cmp_geq(s,zeros), ones, zeros); + } + + static vector_type step_left(const vector_type& s) { + vector_type zeros = I::broadcast(0); + vector_type ones = I::broadcast(1); + return I::ifelse(I::cmp_gt(s,zeros), ones, zeros); + } + + static vector_type step(const vector_type& s) { + vector_type zeros = I::broadcast(0); + vector_type halfs = I::broadcast(0.5); + return I::add( + I::sub( + I::ifelse(I::cmp_gt(s,zeros), halfs, zeros), + I::ifelse(I::cmp_gt(zeros,s), halfs, zeros)), + halfs); + } + + static vector_type signum(const vector_type& s) { + vector_type zeros = I::broadcast(0); + vector_type ones = I::broadcast(1); + return I::sub(I::ifelse(I::cmp_gt(s,zeros), ones, zeros), + I::ifelse(I::cmp_gt(zeros,s), ones, zeros)); + } + static vector_type sin(const vector_type& s) { store a, r; I::copy_to(s, a); @@ -533,6 +563,16 @@ struct implbase { } return I::copy_from(r); } + + static vector_type sqrt(const vector_type& s) { + store a, r; + I::copy_to(s, a); + + for (unsigned i = 0; i<width; ++i) { + r[i] = std::sqrt(a[i]); + } + return I::copy_from(r); + } }; } // namespace detail diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index d9acc07fdbf827bbee55fc5e982ae26123f2acbf..966c472f2e50e734c43e9ae3278a724c43168331 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -82,7 +82,7 @@ typename detail::simd_impl<Impl>::simd_mask name(const typename detail::simd_imp ARB_PP_FOREACH(ARB_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) ARB_PP_FOREACH(ARB_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_leq, cmp_lt, cmp_geq, cmp_gt) -ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr) +ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum) #undef ARB_BINARY_ARITHMETIC_ #undef ARB_BINARY_COMPARISON__ @@ -677,7 +677,7 @@ namespace detail { ARB_PP_FOREACH(ARB_DECLARE_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min, cmp_eq) ARB_PP_FOREACH(ARB_DECLARE_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_lt, cmp_leq, cmp_gt, cmp_geq) - ARB_PP_FOREACH(ARB_DECLARE_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr) + ARB_PP_FOREACH(ARB_DECLARE_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum) #undef ARB_DECLARE_UNARY_ARITHMETIC_ #undef ARB_DECLARE_BINARY_ARITHMETIC_ diff --git a/arbor/include/arbor/simd/sve.hpp b/arbor/include/arbor/simd/sve.hpp index aa98d50b40850f068b1b6c3e19919c6373584e06..c1547333fcfd129b6802b1773e4560f114244fa2 100644 --- a/arbor/include/arbor/simd/sve.hpp +++ b/arbor/include/arbor/simd/sve.hpp @@ -576,6 +576,17 @@ struct sve_double { return copy_from(r); } + static svfloat64_t sqrt(const svfloat64_t& x) { + auto len = svlen_f64(x); + double a[len], r[len]; + copy_to(x, a); + + for (unsigned i = 0; i<len; ++i) { + r[i] = std::sqrt(a[i]); + } + return copy_from(r); + } + static svfloat64_t cos(const svfloat64_t& x) { auto len = svlen_f64(x); double a[len], r[len]; @@ -693,7 +704,7 @@ auto name(const typename detail::simd_traits<typename detail::sve_type_to_impl<T ARB_PP_FOREACH(ARB_SVE_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) ARB_PP_FOREACH(ARB_SVE_BINARY_ARITHMETIC_, cmp_eq, cmp_neq, cmp_leq, cmp_lt, cmp_geq, cmp_gt, logical_and, logical_or) -ARB_PP_FOREACH(ARB_SVE_UNARY_ARITHMETIC_, logical_not, neg, abs, exp, log, expm1, exprelr, cos, sin) +ARB_PP_FOREACH(ARB_SVE_UNARY_ARITHMETIC_, logical_not, neg, abs, exp, log, expm1, exprelr, cos, sin, sqrt) #undef ARB_SVE_UNARY_ARITHMETIC_ #undef ARB_SVE_BINARY_ARITHMETIC_ diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 02df7dad8912f24bde490c9fe2e621a606e1e0af..22df96481cd9bdcd5f67b8cdb3355ed34e8fdcaf 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -163,6 +163,12 @@ std::unordered_multimap<std::string, evaluator> eval_map { {"exp", make_call<arb::iexpr>(arb::iexpr::exp, "iexpr with 1 argument: (value:iexpr)")}, {"exp", make_call<double>(arb::iexpr::exp, "iexpr with 1 argument: (value:double)")}, + {"step_right", make_call<arb::iexpr>(arb::iexpr::step_right, "iexpr with 1 argument: (value:iexpr)")}, + {"step_right", make_call<double>(arb::iexpr::step_right, "iexpr with 1 argument: (value:double)")}, + + {"step_left", make_call<arb::iexpr>(arb::iexpr::step_left, "iexpr with 1 argument: (value:iexpr)")}, + {"step_left", make_call<double>(arb::iexpr::step_left, "iexpr with 1 argument: (value:double)")}, + {"step", make_call<arb::iexpr>(arb::iexpr::step, "iexpr with 1 argument: (value:iexpr)")}, {"step", make_call<double>(arb::iexpr::step, "iexpr with 1 argument: (value:double)")}, diff --git a/doc/concepts/labels.rst b/doc/concepts/labels.rst index ea3151bacb81dfdeed00f661faaa83695a5d5e70..d0e6ec74cb2b5a378c80f8210d17b35bf5be45ee 100644 --- a/doc/concepts/labels.rst +++ b/doc/concepts/labels.rst @@ -762,6 +762,14 @@ Inhomogeneous Expressions The exponential function of the inhomogeneous expression or real ``value``. +.. label:: (step_right value:(iexpr | real)) + + The Heaviside step function of the inhomogeneous expression or real ``value``, with `(step 0.0)` evaluating to 1. + +.. label:: (step_left value:(iexpr | real)) + + The Heaviside step function of the inhomogeneous expression or real ``value``, with `(step 0.0)` evaluating to 0. + .. label:: (step value:(iexpr | real)) The Heaviside step function of the inhomogeneous expression or real ``value``, with `(step 0.0)` evaluating to 0.5. diff --git a/doc/dev/simd_api.rst b/doc/dev/simd_api.rst index a8236f039455618ba7bfe1a5ebe430033c2e2954..a8324fc64d0f5364026624aea051fe55b584b596 100644 --- a/doc/dev/simd_api.rst +++ b/doc/dev/simd_api.rst @@ -574,6 +574,26 @@ In the following: - *S* - Lane-wise raise *s* to the power of *t*. + * - ``sqrt(s)`` + - *S* + - Lane-wise square root of *s*. + + * - ``signum(s)`` + - *S* + - Lane-wise :math:`x \mapsto \begin{align*} +1 & ~~ \text{if} ~x \gt 0, \\ -1 & ~~ \text{if} ~x \lt 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` + + * - ``step(s)`` + - *S* + - Lane-wise :math:`x \mapsto \begin{align*} 1 & ~~ \text{if} ~x \gt 0, \\ 0 & ~~ \text{if} ~x \lt 0, \\ 0.5 & ~~ \text{otherwise}. \end{align*}` + + * - ``step_right(s)`` + - *S* + - Lane-wise :math:`x \mapsto \begin{align*} 1 & ~~ \text{if} ~x \geq 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` + + * - ``step_left(s)`` + - *S* + - Lane-wise :math:`x \mapsto \begin{align*} 1 & ~~ \text{if} ~x \gt 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` + * - ``simd_cast<std::array<L, N>>(a)`` - ``std::array<L, N>`` - Lane-wise cast of values in *a* to scalar type *L* in ``std::array<L, N>``. diff --git a/doc/fileformat/nmodl.rst b/doc/fileformat/nmodl.rst index b975fc59d016356ac67a7db3af2a7d2373b5d1c7..a125cc82d8d393ce7e48993ecd5aae95224bd7b7 100644 --- a/doc/fileformat/nmodl.rst +++ b/doc/fileformat/nmodl.rst @@ -129,6 +129,8 @@ Unsupported features * ``LOCAL`` variables outside blocks are not supported. * ``INDEPENDENT`` variables are not supported. +.. _arbornmodl: + Arbor-specific features ----------------------- @@ -158,6 +160,21 @@ Arbor-specific features made available through the ``v_peer`` variable while the local membrane potential is available through ``v``, as usual. +* Arbor offers a number of additional unary math functions which may offer improved performance + compared to hand-rolled solutions (especially with the vectorized and GPU backends). + All of the following functions take a single argument `x` and return a + floating point value. + + ================== ===================================== ========= + Function name Description Semantics + ================== ===================================== ========= + sqrt(x) square root :math:`\sqrt{x}` + step_right(x) right-continuous heaviside step :math:`\begin{align*} 1 & ~~ \text{if} ~x \geq 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` + step_left(x) left-continuous heaviside step :math:`\begin{align*} 1 & ~~ \text{if} ~x \gt 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` + step(x) heaviside step with half value :math:`\begin{align*} 1 & ~~ \text{if} ~x \gt 0, \\ 0 & ~~ \text{if} ~x \lt 0, \\ 0.5 & ~~ \text{otherwise}. \end{align*}` + signum(x) sign of argument :math:`\begin{align*} +1 & ~~ \text{if} ~x \gt 0, \\ -1 & ~~ \text{if} ~x \lt 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` + exprelr(x) guarded exponential :math:`x e^{1-x}` + ================== ===================================== ========= .. _format-sde: @@ -422,7 +439,9 @@ of this pattern from ``hh.mod`` in the Arbor sources Specialised Functions ~~~~~~~~~~~~~~~~~~~~~ -Another common pattern is the use of a guarded exponential of the form +Some extra cost can be saved by choosing Arbor-specific optimized math functions instead of +hand-rolled versions. Please consult the table in :ref:`this section <arbornmodl>`. +A common pattern is the use of a guarded exponential of the form .. code:: @@ -432,8 +451,7 @@ Another common pattern is the use of a guarded exponential of the form r = x } -This incurs some extra cost on most platforms. However, it can be written in -Arbor's NMODL dialect as +However, it can be written in Arbor's NMODL dialect as .. code:: diff --git a/mechanisms/stochastic/ou_input.mod b/mechanisms/stochastic/ou_input.mod index 2c63d75a7c018b71a3f51421d670c3add20d8ba2..bd6d2fe36ac6ba41465db54d445cc0819bb1fa1c 100644 --- a/mechanisms/stochastic/ou_input.mod +++ b/mechanisms/stochastic/ou_input.mod @@ -53,7 +53,7 @@ INITIAL { I_ou = 0 active = -1 alpha = 1.0/tau - beta = sigma * (2.0/tau)^0.5 + beta = sigma * sqrt(2.0/tau) } BREAKPOINT { @@ -62,7 +62,7 @@ BREAKPOINT { } DERIVATIVE state { - I_ou' = heaviside(active) * (alpha * (mu - I_ou) + beta * W) + I_ou' = step_right(active) * (alpha * (mu - I_ou) + beta * W) } NET_RECEIVE(weight) { @@ -75,12 +75,3 @@ NET_RECEIVE(weight) { active = -1 } } - -FUNCTION heaviside(x) { - if (x >= 0) { - heaviside = 1 - } - else { - heaviside = 0 - } -} diff --git a/modcc/expression.cpp b/modcc/expression.cpp index f091741030dc6974aa516af088eb19d16f716889..b7ba2aced605d06b651a4c665e6b3ca2039207f5 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -1109,6 +1109,21 @@ void CosUnaryExpression::accept(Visitor *v) { void SinUnaryExpression::accept(Visitor *v) { v->visit(this); } +void SqrtUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void StepRightUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void StepLeftUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void StepUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void SignumUnaryExpression::accept(Visitor *v) { + v->visit(this); +} void BinaryExpression::accept(Visitor *v) { v->visit(this); } @@ -1183,6 +1198,16 @@ ARB_LIBMODCC_API expression_ptr unary_expression( Location loc, return make_expression<ExprelrUnaryExpression>(loc, std::move(e)); case tok::safeinv : return make_expression<SafeInvUnaryExpression>(loc, std::move(e)); + case tok::sqrt : + return make_expression<SqrtUnaryExpression>(loc, std::move(e)); + case tok::step_right : + return make_expression<StepRightUnaryExpression>(loc, std::move(e)); + case tok::step_left : + return make_expression<StepLeftUnaryExpression>(loc, std::move(e)); + case tok::step : + return make_expression<StepUnaryExpression>(loc, std::move(e)); + case tok::signum : + return make_expression<SignumUnaryExpression>(loc, std::move(e)); default : std::cerr << yellow(token_string(op)) << " is not a valid unary operator" diff --git a/modcc/expression.hpp b/modcc/expression.hpp index f9ed282c89a60b28bf8fe83bdacc5ce79dc041f4..31188ab58493c78aeb7ce9db1fd582ee26f03ad0 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -1338,6 +1338,66 @@ public: void accept(Visitor *v) override; }; +// sqrt unuary expression, i.e. sqrt(x) +class ARB_LIBMODCC_API SqrtUnaryExpression : public UnaryExpression { +public: + SqrtUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::sqrt, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +// step_right unary expression, +// i.e. step_right(x) = 0, for x < 0 +// 1, otherwise +class ARB_LIBMODCC_API StepRightUnaryExpression : public UnaryExpression { +public: + StepRightUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::step_right, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +// step_left unary expression, +// i.e. step_left(x) = 0, for x <= 0 +// 1, otherwise +class ARB_LIBMODCC_API StepLeftUnaryExpression : public UnaryExpression { +public: + StepLeftUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::step_left, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +// step unary expression, +// i.e. step(x) = 0, for x < 0 +// 1, for x > 0 +// 0.5, otherwise +class ARB_LIBMODCC_API StepUnaryExpression : public UnaryExpression { +public: + StepUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::step, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +// signum unary expression, +// i.e. signum(x) = -1, for x < 0 +// 0, for x = 0 +// +1, otherwise +class ARB_LIBMODCC_API SignumUnaryExpression : public UnaryExpression { +public: + SignumUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::signum, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + //////////////////////////////////////////////////////////// // binary expressions diff --git a/modcc/parser.cpp b/modcc/parser.cpp index e7279a92eae6491ede5ad28ada1d409ca9c16925..cc86179f5f63fc912c8d4ac3bdc09b28e70a6f98 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -1446,6 +1446,11 @@ expression_ptr Parser::parse_unaryop() { case tok::abs: case tok::safeinv: case tok::exprelr: + case tok::sqrt: + case tok::step_right: + case tok::step_left: + case tok::step: + case tok::signum: get_token(); // consume operator (exp, sin, cos or log) if (token_.type != tok::lparen) { error("missing parenthesis after call to " + yellow(op.spelling)); diff --git a/modcc/perfvisitor.hpp b/modcc/perfvisitor.hpp index 694a9284dfe977e02bd1d1cbb3b5935600223905..2971fa4a578f558558dd505d97b57b146439dfa6 100644 --- a/modcc/perfvisitor.hpp +++ b/modcc/perfvisitor.hpp @@ -16,9 +16,10 @@ struct FlopAccumulator { int cos=0; int log=0; int pow=0; + int sqrt=0; void reset() { - add = neg = mul = div = exp = sin = cos = log = 0; + add = neg = mul = div = exp = sin = cos = log = pow = sqrt = 0; } }; @@ -26,8 +27,8 @@ static std::ostream& operator << (std::ostream& os, FlopAccumulator const& f) { char buffer[512]; snprintf(buffer, 512, - " add neg mul div exp sin cos log pow\n%6d%6d%6d%6d%6d%6d%6d%6d%6d", - f.add, f.neg, f.mul, f.div, f.exp, f.sin, f.cos, f.log, f.pow); + " add neg mul div exp sin cos log pow sqrt\n%6d%6d%6d%6d%6d%6d%6d%6d%6d%6d", + f.add, f.neg, f.mul, f.div, f.exp, f.sin, f.cos, f.log, f.pow, f.sqrt); os << buffer << std::endl << std::endl; os << " add+mul+neg " << f.add + f.neg + f.mul << std::endl; @@ -97,6 +98,19 @@ public: e->expression()->accept(this); flops.sin++; } + void visit(SqrtUnaryExpression *e) override { + e->expression()->accept(this); + flops.sqrt++; + } + void visit(StepUnaryExpression *e) override { + e->expression()->accept(this); + flops.add+=2; + flops.mul++; + } + void visit(SignumUnaryExpression *e) override { + e->expression()->accept(this); + flops.add++; + } //////////////////////////////////////////////////// // specializations for each type of binary expression diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 82f3bfb14a6188328b8eee6e67b489afac8b32d6..894117b8be5ecd42407692913d7b8a994f14df6a 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -56,14 +56,19 @@ void CExprEmitter::visit(UnaryExpression* e) { // Place a space in front of minus sign to avoid invalid // expressions of the form: (v[i]--67) static std::unordered_map<tok, const char*> unaryop_tbl = { - {tok::minus, " -"}, - {tok::exp, "exp"}, - {tok::cos, "cos"}, - {tok::sin, "sin"}, - {tok::log, "log"}, - {tok::abs, "abs"}, - {tok::exprelr, "exprelr"}, - {tok::safeinv, "safeinv"} + {tok::minus, " -"}, + {tok::exp, "exp"}, + {tok::cos, "cos"}, + {tok::sin, "sin"}, + {tok::log, "log"}, + {tok::abs, "abs"}, + {tok::exprelr, "exprelr"}, + {tok::safeinv, "safeinv"}, + {tok::sqrt, "sqrt"}, + {tok::step_right, "step_right"}, + {tok::step_left, "step_left"}, + {tok::step, "step"}, + {tok::signum, "signum"} }; if (!unaryop_tbl.count(e->op())) { @@ -80,6 +85,30 @@ void CExprEmitter::visit(UnaryExpression* e) { out_ << op_spelling; inner->accept(this); } + else if (e->op()==tok::step_right) { + out_ << "((arb_value_type)(("; + inner->accept(this); + out_ << ")>=0.))"; + } + else if (e->op()==tok::step_left) { + out_ << "((arb_value_type)(("; + inner->accept(this); + out_ << ")>0.))"; + } + else if (e->op()==tok::step) { + out_ << "((arb_value_type)0.5*((0.<("; + inner->accept(this); + out_ << "))-(("; + inner->accept(this); + out_ << ")<0.)+1))"; + } + else if (e->op()==tok::signum) { + out_ << "((arb_value_type)((0.<("; + inner->accept(this); + out_ << "))-(("; + inner->accept(this); + out_ << ")<0.)))"; + } else { emit_as_call(op_spelling, inner); } @@ -181,14 +210,19 @@ void SimdExprEmitter::visit(NumberExpression* e) { void SimdExprEmitter::visit(UnaryExpression* e) { static std::unordered_map<tok, const char*> unaryop_tbl = { - {tok::minus, "S::neg"}, - {tok::exp, "S::exp"}, - {tok::cos, "S::cos"}, - {tok::sin, "S::sin"}, - {tok::log, "S::log"}, - {tok::abs, "S::abs"}, - {tok::exprelr, "S::exprelr"}, - {tok::safeinv, "safeinv"} + {tok::minus, "S::neg"}, + {tok::exp, "S::exp"}, + {tok::cos, "S::cos"}, + {tok::sin, "S::sin"}, + {tok::log, "S::log"}, + {tok::abs, "S::abs"}, + {tok::exprelr, "S::exprelr"}, + {tok::safeinv, "safeinv"}, + {tok::sqrt, "S::sqrt"}, + {tok::step_right, "S::step_right"}, + {tok::step_left, "S::step_left"}, + {tok::step, "S::step"}, + {tok::signum, "S::signum"} }; if (!unaryop_tbl.count(e->op())) { diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 5d6e95352ea8550f8d48bf3a144542c7a9acdbaf..432c1187bf0967c1ff68ea868654a77007b7d0ca 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -173,6 +173,7 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe "using ::std::min;\n" "using ::std::pow;\n" "using ::std::sin;\n" + "using ::std::sqrt;\n" "\n"; if (with_simd) { diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index 35fdfc8dd70ccc1b22f89849775135f786319c8f..1f0ea6a819fd184cdc34298cb762fc942ef3100e 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -234,6 +234,43 @@ public: std::move(dlhs)))); } + void visit(SqrtUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + // d(sqrt(f(x)))/dx = 0.5*(f(x))^(-0.5)*d(f(x))/dx + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, + make_expression<NumberExpression>(loc, 0.5), + make_expression<PowBinaryExpression>(loc, + e->expression()->clone(), + make_expression<NumberExpression>(loc, -0.5))), + result()); + } + + void visit(StepRightUnaryExpression* e) override { + // ignore singularity + auto loc = e->location(); + result_ = make_expression<IntegerExpression>(loc, 0); + } + + void visit(StepLeftUnaryExpression* e) override { + // ignore singularity + auto loc = e->location(); + result_ = make_expression<IntegerExpression>(loc, 0); + } + + void visit(StepUnaryExpression* e) override { + // ignore singularity + auto loc = e->location(); + result_ = make_expression<IntegerExpression>(loc, 0); + } + + void visit(SignumUnaryExpression* e) override { + // ignore singularity + auto loc = e->location(); + result_ = make_expression<IntegerExpression>(loc, 0); + } + void visit(CallExpression* e) override { auto loc = e->location(); result_ = make_expression<PDiffExpression>(loc, @@ -371,6 +408,21 @@ public: case tok::log: as_number(loc, std::log(val)); return; + case tok::sqrt: + as_number(loc, std::sqrt(val)); + return; + case tok::step_right: + as_number(loc, (val >= 0.)); + return; + case tok::step_left: + as_number(loc, (val > 0.)); + return; + case tok::step: + as_number(loc, 0.5*((0. < val) - (val < 0.) + 1)); + return; + case tok::signum: + as_number(loc, (0. < val) - (val < 0.)); + return; default: ; // treat opaquely as below } } diff --git a/modcc/token.cpp b/modcc/token.cpp index 1c94385c9d7d6b99ecdb66f3537bb1ab26a18031..e8b70a5ae3e7c0afc18e78a9c8e998f03f323a7c 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -21,143 +21,153 @@ struct TokenString { }; static Keyword keywords[] = { - {"TITLE", tok::title}, - {"NEURON", tok::neuron}, - {"UNITS", tok::units}, - {"PARAMETER", tok::parameter}, - {"CONSTANT", tok::constant}, - {"ASSIGNED", tok::assigned}, - {"WHITE_NOISE", tok::white_noise}, - {"STATE", tok::state}, - {"BREAKPOINT", tok::breakpoint}, - {"DERIVATIVE", tok::derivative}, - {"KINETIC", tok::kinetic}, - {"LINEAR", tok::linear}, - {"PROCEDURE", tok::procedure}, - {"FUNCTION", tok::function}, - {"INITIAL", tok::initial}, - {"NET_RECEIVE", tok::net_receive}, - {"POST_EVENT", tok::post_event}, - {"UNITSOFF", tok::unitsoff}, - {"UNITSON", tok::unitson}, - {"SUFFIX", tok::suffix}, + {"TITLE", tok::title}, + {"NEURON", tok::neuron}, + {"UNITS", tok::units}, + {"PARAMETER", tok::parameter}, + {"CONSTANT", tok::constant}, + {"ASSIGNED", tok::assigned}, + {"WHITE_NOISE", tok::white_noise}, + {"STATE", tok::state}, + {"BREAKPOINT", tok::breakpoint}, + {"DERIVATIVE", tok::derivative}, + {"KINETIC", tok::kinetic}, + {"LINEAR", tok::linear}, + {"PROCEDURE", tok::procedure}, + {"FUNCTION", tok::function}, + {"INITIAL", tok::initial}, + {"NET_RECEIVE", tok::net_receive}, + {"POST_EVENT", tok::post_event}, + {"UNITSOFF", tok::unitsoff}, + {"UNITSON", tok::unitson}, + {"SUFFIX", tok::suffix}, {"NONSPECIFIC_CURRENT", tok::nonspecific_current}, - {"USEION", tok::useion}, - {"READ", tok::read}, - {"WRITE", tok::write}, - {"VALENCE", tok::valence}, - {"RANGE", tok::range}, - {"LOCAL", tok::local}, - {"CONSERVE", tok::conserve}, - {"SOLVE", tok::solve}, - {"THREADSAFE", tok::threadsafe}, - {"GLOBAL", tok::global}, - {"POINT_PROCESS", tok::point_process}, - {"JUNCTION_PROCESS", tok::junction_process}, - {"COMPARTMENT", tok::compartment}, - {"METHOD", tok::method}, - {"STEADYSTATE", tok::steadystate}, - {"FROM", tok::from}, - {"TO", tok::to}, - {"if", tok::if_stmt}, - {"IF", tok::if_stmt}, - {"else", tok::else_stmt}, - {"ELSE", tok::else_stmt}, - {"cnexp", tok::cnexp}, - {"sparse", tok::sparse}, - {"stochastic", tok::stochastic}, - {"min", tok::min}, - {"max", tok::max}, - {"exp", tok::exp}, - {"sin", tok::sin}, - {"cos", tok::cos}, - {"log", tok::log}, - {"fabs", tok::abs}, - {"exprelr", tok::exprelr}, - {"safeinv", tok::safeinv}, - {"CONDUCTANCE", tok::conductance}, - {"WATCH", tok::watch}, - {nullptr, tok::reserved}, + {"USEION", tok::useion}, + {"READ", tok::read}, + {"WRITE", tok::write}, + {"VALENCE", tok::valence}, + {"RANGE", tok::range}, + {"LOCAL", tok::local}, + {"CONSERVE", tok::conserve}, + {"SOLVE", tok::solve}, + {"THREADSAFE", tok::threadsafe}, + {"GLOBAL", tok::global}, + {"POINT_PROCESS", tok::point_process}, + {"JUNCTION_PROCESS", tok::junction_process}, + {"COMPARTMENT", tok::compartment}, + {"METHOD", tok::method}, + {"STEADYSTATE", tok::steadystate}, + {"FROM", tok::from}, + {"TO", tok::to}, + {"if", tok::if_stmt}, + {"IF", tok::if_stmt}, + {"else", tok::else_stmt}, + {"ELSE", tok::else_stmt}, + {"cnexp", tok::cnexp}, + {"sparse", tok::sparse}, + {"stochastic", tok::stochastic}, + {"min", tok::min}, + {"max", tok::max}, + {"exp", tok::exp}, + {"sin", tok::sin}, + {"cos", tok::cos}, + {"log", tok::log}, + {"fabs", tok::abs}, + {"exprelr", tok::exprelr}, + {"safeinv", tok::safeinv}, + {"CONDUCTANCE", tok::conductance}, + {"WATCH", tok::watch}, + {"sqrt", tok::sqrt}, + {"step_right", tok::step_right}, + {"step_left", tok::step_left}, + {"step", tok::step}, + {"signum", tok::signum}, + {nullptr, tok::reserved}, }; static TokenString token_strings[] = { - {"=", tok::eq}, - {"+", tok::plus}, - {"-", tok::minus}, - {"*", tok::times}, - {"/", tok::divide}, - {"^", tok::pow}, - {"!", tok::lnot}, - {"<", tok::lt}, - {"<=", tok::lte}, - {">", tok::gt}, - {">=", tok::gte}, - {"==", tok::equality}, - {"!=", tok::ne}, - {"&&", tok::land}, - {"||", tok::lor}, - {"<->", tok::arrow}, - {"~", tok::tilde}, - {",", tok::comma}, - {"'", tok::prime}, - {"{", tok::lbrace}, - {"}", tok::rbrace}, - {"(", tok::lparen}, - {")", tok::rparen}, - {"identifier", tok::identifier}, - {"real", tok::real}, - {"integer", tok::integer}, - {"TITLE", tok::title}, - {"NEURON", tok::neuron}, - {"UNITS", tok::units}, - {"PARAMETER", tok::parameter}, - {"CONSTANT", tok::constant}, - {"ASSIGNED", tok::assigned}, - {"WHITE_NOISE", tok::white_noise}, - {"STATE", tok::state}, - {"BREAKPOINT", tok::breakpoint}, - {"DERIVATIVE", tok::derivative}, - {"KINETIC", tok::kinetic}, - {"LINEAR", tok::linear}, - {"PROCEDURE", tok::procedure}, - {"FUNCTION", tok::function}, - {"INITIAL", tok::initial}, - {"NET_RECEIVE", tok::net_receive}, - {"POST_EVENT", tok::post_event}, - {"UNITSOFF", tok::unitsoff}, - {"UNITSON", tok::unitson}, - {"SUFFIX", tok::suffix}, + {"=", tok::eq}, + {"+", tok::plus}, + {"-", tok::minus}, + {"*", tok::times}, + {"/", tok::divide}, + {"^", tok::pow}, + {"!", tok::lnot}, + {"<", tok::lt}, + {"<=", tok::lte}, + {">", tok::gt}, + {">=", tok::gte}, + {"==", tok::equality}, + {"!=", tok::ne}, + {"&&", tok::land}, + {"||", tok::lor}, + {"<->", tok::arrow}, + {"~", tok::tilde}, + {",", tok::comma}, + {"'", tok::prime}, + {"{", tok::lbrace}, + {"}", tok::rbrace}, + {"(", tok::lparen}, + {")", tok::rparen}, + {"identifier", tok::identifier}, + {"real", tok::real}, + {"integer", tok::integer}, + {"TITLE", tok::title}, + {"NEURON", tok::neuron}, + {"UNITS", tok::units}, + {"PARAMETER", tok::parameter}, + {"CONSTANT", tok::constant}, + {"ASSIGNED", tok::assigned}, + {"WHITE_NOISE", tok::white_noise}, + {"STATE", tok::state}, + {"BREAKPOINT", tok::breakpoint}, + {"DERIVATIVE", tok::derivative}, + {"KINETIC", tok::kinetic}, + {"LINEAR", tok::linear}, + {"PROCEDURE", tok::procedure}, + {"FUNCTION", tok::function}, + {"INITIAL", tok::initial}, + {"NET_RECEIVE", tok::net_receive}, + {"POST_EVENT", tok::post_event}, + {"UNITSOFF", tok::unitsoff}, + {"UNITSON", tok::unitson}, + {"SUFFIX", tok::suffix}, {"NONSPECIFIC_CURRENT", tok::nonspecific_current}, - {"USEION", tok::useion}, - {"READ", tok::read}, - {"WRITE", tok::write}, - {"VALENCE", tok::valence}, - {"RANGE", tok::range}, - {"LOCAL", tok::local}, - {"SOLVE", tok::solve}, - {"THREADSAFE", tok::threadsafe}, - {"GLOBAL", tok::global}, - {"POINT_PROCESS", tok::point_process}, - {"JUNCTION_PROCESS", tok::junction_process}, - {"COMPARTMENT", tok::compartment}, - {"METHOD", tok::method}, - {"STEADYSTATE", tok::steadystate}, - {"if", tok::if_stmt}, - {"else", tok::else_stmt}, - {"eof", tok::eof}, - {"min", tok::min}, - {"max", tok::max}, - {"exp", tok::exp}, - {"log", tok::log}, - {"fabs", tok::abs}, - {"exprelr", tok::exprelr}, - {"safeinv", tok::safeinv}, - {"cos", tok::cos}, - {"sin", tok::sin}, - {"cnexp", tok::cnexp}, - {"CONDUCTANCE", tok::conductance}, - {"WATCH", tok::watch}, - {"error", tok::reserved}, + {"USEION", tok::useion}, + {"READ", tok::read}, + {"WRITE", tok::write}, + {"VALENCE", tok::valence}, + {"RANGE", tok::range}, + {"LOCAL", tok::local}, + {"SOLVE", tok::solve}, + {"THREADSAFE", tok::threadsafe}, + {"GLOBAL", tok::global}, + {"POINT_PROCESS", tok::point_process}, + {"JUNCTION_PROCESS", tok::junction_process}, + {"COMPARTMENT", tok::compartment}, + {"METHOD", tok::method}, + {"STEADYSTATE", tok::steadystate}, + {"if", tok::if_stmt}, + {"else", tok::else_stmt}, + {"eof", tok::eof}, + {"min", tok::min}, + {"max", tok::max}, + {"exp", tok::exp}, + {"log", tok::log}, + {"fabs", tok::abs}, + {"exprelr", tok::exprelr}, + {"safeinv", tok::safeinv}, + {"cos", tok::cos}, + {"sin", tok::sin}, + {"cnexp", tok::cnexp}, + {"CONDUCTANCE", tok::conductance}, + {"WATCH", tok::watch}, + {"sqrt", tok::sqrt}, + {"step_right", tok::step_right}, + {"step_left", tok::step_left}, + {"step", tok::step}, + {"signum", tok::signum}, + {"error", tok::reserved}, }; /// set up lookup tables for converting between tokens and their diff --git a/modcc/token.hpp b/modcc/token.hpp index 1e4fa762068f43dc3cf0226f307eb7fb0452b462..c3e6d3168539c4e65524b087ae1d3ba26f7af9b9 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -73,6 +73,11 @@ enum class tok { // unary operators exp, sin, cos, log, abs, safeinv, exprelr, // equivalent to x/(exp(x)-1) with exprelr(0)=1 + sqrt, + step_right, // right-continuous heaviside step function (H(0) = 1) + step_left, // left-continuous heaviside step function (H(0) = 0) + step, // heaviside step function (H(0) = 0.5) + signum, // sign function {-1, 0, +1} // logical keywords if_stmt, else_stmt, // add _stmt to avoid clash with c++ keywords diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index 72d8eb3d3aecdac4fc191dafd6f0d3b321f5e111..116d80b06dc3eb93b85cfd62899e8b318e01769e 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -15,51 +15,56 @@ class Visitor { public: virtual void visit(Expression *e) = 0; - virtual void visit(Symbol *e) { visit((Expression*) e); } - virtual void visit(LocalVariable *e) { visit((Expression*) e); } - virtual void visit(WhiteNoise *e) { visit((Expression*) e); } - virtual void visit(IdentifierExpression *e) { visit((Expression*) e); } - virtual void visit(NumberExpression *e) { visit((Expression*) e); } - virtual void visit(IntegerExpression *e) { visit((NumberExpression*) e); } - virtual void visit(LocalDeclaration *e) { visit((Expression*) e); } - virtual void visit(ArgumentExpression *e) { visit((Expression*) e); } - virtual void visit(PrototypeExpression *e) { visit((Expression*) e); } - virtual void visit(CallExpression *e) { visit((Expression*) e); } - virtual void visit(ReactionExpression *e) { visit((Expression*) e); } - virtual void visit(StoichTermExpression *e) { visit((Expression*) e); } - virtual void visit(StoichExpression *e) { visit((Expression*) e); } - virtual void visit(CompartmentExpression *e){ visit((Expression*) e); } - virtual void visit(VariableExpression *e) { visit((Expression*) e); } - virtual void visit(IndexedVariable *e) { visit((Expression*) e); } - virtual void visit(FunctionExpression *e) { visit((Expression*) e); } - virtual void visit(IfExpression *e) { visit((Expression*) e); } - virtual void visit(SolveExpression *e) { visit((Expression*) e); } - virtual void visit(DerivativeExpression *e) { visit((Expression*) e); } - virtual void visit(PDiffExpression *e) { visit((Expression*) e); } - virtual void visit(ProcedureExpression *e) { visit((Expression*) e); } - virtual void visit(NetReceiveExpression *e) { visit((ProcedureExpression*) e); } - virtual void visit(APIMethod *e) { visit((Expression*) e); } - virtual void visit(ConductanceExpression *e){ visit((Expression*) e); } - virtual void visit(BlockExpression *e) { visit((Expression*) e); } - virtual void visit(InitialBlock *e) { visit((BlockExpression*) e); } + virtual void visit(Symbol *e) { visit((Expression*) e); } + virtual void visit(LocalVariable *e) { visit((Expression*) e); } + virtual void visit(WhiteNoise *e) { visit((Expression*) e); } + virtual void visit(IdentifierExpression *e) { visit((Expression*) e); } + virtual void visit(NumberExpression *e) { visit((Expression*) e); } + virtual void visit(IntegerExpression *e) { visit((NumberExpression*) e); } + virtual void visit(LocalDeclaration *e) { visit((Expression*) e); } + virtual void visit(ArgumentExpression *e) { visit((Expression*) e); } + virtual void visit(PrototypeExpression *e) { visit((Expression*) e); } + virtual void visit(CallExpression *e) { visit((Expression*) e); } + virtual void visit(ReactionExpression *e) { visit((Expression*) e); } + virtual void visit(StoichTermExpression *e) { visit((Expression*) e); } + virtual void visit(StoichExpression *e) { visit((Expression*) e); } + virtual void visit(CompartmentExpression *e) { visit((Expression*) e); } + virtual void visit(VariableExpression *e) { visit((Expression*) e); } + virtual void visit(IndexedVariable *e) { visit((Expression*) e); } + virtual void visit(FunctionExpression *e) { visit((Expression*) e); } + virtual void visit(IfExpression *e) { visit((Expression*) e); } + virtual void visit(SolveExpression *e) { visit((Expression*) e); } + virtual void visit(DerivativeExpression *e) { visit((Expression*) e); } + virtual void visit(PDiffExpression *e) { visit((Expression*) e); } + virtual void visit(ProcedureExpression *e) { visit((Expression*) e); } + virtual void visit(NetReceiveExpression *e) { visit((ProcedureExpression*) e); } + virtual void visit(APIMethod *e) { visit((Expression*) e); } + virtual void visit(ConductanceExpression *e) { visit((Expression*) e); } + virtual void visit(BlockExpression *e) { visit((Expression*) e); } + virtual void visit(InitialBlock *e) { visit((BlockExpression*) e); } virtual void visit(UnaryExpression *e) = 0; - virtual void visit(NegUnaryExpression *e) { visit((UnaryExpression*) e); } - virtual void visit(ExpUnaryExpression *e) { visit((UnaryExpression*) e); } - virtual void visit(LogUnaryExpression *e) { visit((UnaryExpression*) e); } - virtual void visit(CosUnaryExpression *e) { visit((UnaryExpression*) e); } - virtual void visit(SinUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(NegUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(ExpUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(LogUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(CosUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(SinUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(SqrtUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(StepRightUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(StepLeftUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(StepUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(SignumUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(BinaryExpression *e) = 0; - virtual void visit(ConditionalExpression *e){ visit((BinaryExpression*) e); } - virtual void visit(AssignmentExpression *e) { visit((BinaryExpression*) e); } - virtual void visit(ConserveExpression *e) { visit((BinaryExpression*) e); } - virtual void visit(LinearExpression *e) { visit((BinaryExpression*) e); } - virtual void visit(AddBinaryExpression *e) { visit((BinaryExpression*) e); } - virtual void visit(SubBinaryExpression *e) { visit((BinaryExpression*) e); } - virtual void visit(MulBinaryExpression *e) { visit((BinaryExpression*) e); } - virtual void visit(DivBinaryExpression *e) { visit((BinaryExpression*) e); } - virtual void visit(PowBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(ConditionalExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(AssignmentExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(ConserveExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(LinearExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(AddBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(SubBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(MulBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(DivBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(PowBinaryExpression *e) { visit((BinaryExpression*) e); } virtual ~Visitor() {}; }; diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index 89b1ddf95d570e54a6dd513c0cdb4512cdec8685..57da5d65d98636087540fa21b01ddc1014febc54 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -58,25 +58,30 @@ TEST(scalar_printer, constants) { TEST(scalar_printer, statement) { std::vector<testcase> testcases = { - {"y=x+3", "y=x+3.0"}, - {"y=y^z", "y=pow(y,z)"}, - {"y=exp((x/2) + 3)", "y=exp(x/2.0+3.0)"}, - {"z=a/b/c", "z=a/b/c"}, - {"z=a/(b/c)", "z=a/(b/c)"}, - {"z=(a*b)/c", "z=a*b/c"}, - {"z=a-(b+c)", "z=a-(b+c)"}, - {"z=(a>0)<(b>0)", "z=a>0.<(b>0.)"}, - {"z=a- -2", "z=a- -2.0"}, - {"z=fabs(x-z)", "z=abs(x-z)"}, - {"z=min(x,y)", "z=min(x,y)"}, - {"z=min(max(a,b),y)","z=min(max(a,b),y)"}, + {"y=x+3", "y=x+3.0"}, + {"y=y^z", "y=pow(y,z)"}, + {"y=exp((x/2) + 3)", "y=exp(x/2.0+3.0)"}, + {"z=a/b/c", "z=a/b/c"}, + {"z=a/(b/c)", "z=a/(b/c)"}, + {"z=(a*b)/c", "z=a*b/c"}, + {"z=a-(b+c)", "z=a-(b+c)"}, + {"z=(a>0)<(b>0)", "z=a>0.<(b>0.)"}, + {"z=a- -2", "z=a- -2.0"}, + {"z=fabs(x-z)", "z=abs(x-z)"}, + {"z=min(x,y)", "z=min(x,y)"}, + {"z=min(max(a,b),y)", "z=min(max(a,b),y)"}, + {"y=sqrt((x/2) + 3)", "y=sqrt(x/2.0+3.0)"}, + {"y=signum(c-theta)", "y=((arb_value_type)((0.<(c-theta))-((c-theta)<0.)))"}, + {"y=step_right(c-theta)", "y=((arb_value_type)((c-theta)>=0.))"}, + {"y=step_left(c-theta)", "y=((arb_value_type)((c-theta)>0.))"}, + {"y=step(c-theta)", "y=((arb_value_type)0.5*((0.<(c-theta))-((c-theta)<0.)+1))"}, }; // create a scope that contains the symbols used in the tests Scope<Symbol>::symbol_map globals; auto scope = std::make_shared<Scope<Symbol>>(globals); - for (auto var: {"x", "y", "z", "a", "b", "c"}) { + for (auto var: {"x", "y", "z", "a", "b", "c", "theta"}) { scope->add_local_symbol(var, make_symbol<LocalVariable>(Location(), var, localVariableKind::local)); } diff --git a/test/unit-modcc/test_visitors.cpp b/test/unit-modcc/test_visitors.cpp index 348e4a8d26f672c15d2078f856beed26b6c66b81..507706e8826aa16f053d95c38bde8a4d2b5424fe 100644 --- a/test/unit-modcc/test_visitors.cpp +++ b/test/unit-modcc/test_visitors.cpp @@ -72,6 +72,28 @@ TEST(FlopVisitor, basic) { e->accept(&visitor); EXPECT_EQ(visitor.flops.sin, 1); } + + { + FlopVisitor visitor; + auto e = parse_expression("sqrt(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.sqrt, 1); + } + + { + FlopVisitor visitor; + auto e = parse_expression("signum(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 1); + } + + { + FlopVisitor visitor; + auto e = parse_expression("step(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 2); + EXPECT_EQ(visitor.flops.mul, 1); + } } TEST(FlopVisitor, compound) { @@ -121,16 +143,18 @@ TEST(FlopVisitor, procedure) { " hinf=1/(1+exp((v-vhalfh)/kh))\n" " mtau = 0.6\n" " htau = 1500\n" +" rho = step_left(c-theta) + 1/sqrt(tau)*sigma\n" "}"; FlopVisitor visitor; auto e = parse_procedure(expression); e->accept(&visitor); - EXPECT_EQ(visitor.flops.add, 6); + EXPECT_EQ(visitor.flops.add, 7); EXPECT_EQ(visitor.flops.neg, 0); - EXPECT_EQ(visitor.flops.mul, 0); - EXPECT_EQ(visitor.flops.div, 5); + EXPECT_EQ(visitor.flops.mul, 1); + EXPECT_EQ(visitor.flops.div, 6); EXPECT_EQ(visitor.flops.exp, 2); EXPECT_EQ(visitor.flops.pow, 1); + EXPECT_EQ(visitor.flops.sqrt, 1); } TEST(FlopVisitor, function) { @@ -141,15 +165,17 @@ TEST(FlopVisitor, function) { " minf=1-1/(1+exp((v-vhalfm)/km))\n" " hinf=1/(1+exp((v-vhalfh)/kh))\n" " foo = minf + hinf\n" +" rho = signum(c-theta)/tau + 1/sqrt(tau)*sigma\n" "}"; FlopVisitor visitor; auto e = parse_function(expression); e->accept(&visitor); - EXPECT_EQ(visitor.flops.add, 7); + EXPECT_EQ(visitor.flops.add, 10); EXPECT_EQ(visitor.flops.neg, 1); - EXPECT_EQ(visitor.flops.mul, 0); - EXPECT_EQ(visitor.flops.div, 5); + EXPECT_EQ(visitor.flops.mul, 1); + EXPECT_EQ(visitor.flops.div, 7); EXPECT_EQ(visitor.flops.exp, 2); EXPECT_EQ(visitor.flops.pow, 1); + EXPECT_EQ(visitor.flops.sqrt, 1); } diff --git a/test/unit/test_iexpr.cpp b/test/unit/test_iexpr.cpp index 506d924c44ceb8be581eacb6defbdc40eea82dd1..e3ac3ce8fb3df41501b258d4ce9c958948cd9f6b 100644 --- a/test/unit/test_iexpr.cpp +++ b/test/unit/test_iexpr.cpp @@ -398,6 +398,30 @@ TEST(iexpr, exp) { EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.0, 1.0}), std::exp(3.0 * 10.0)); } +TEST(iexpr, step_right) { + segment_tree tree; + tree.append(mnpos, {0, 0, 0, 10}, {0, 0, 10, 10}, 1); + + arb::mprovider prov(arb::morphology(std::move(tree))); + auto root_dist = arb::iexpr::distance(1.0, arb::mlocation{0, 0.0}); + auto e = thingify(arb::iexpr::step_right(root_dist-5.0), prov); + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.0, 0.5}), 0.0f); /* step(2.5-5) == 0 */ + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.4, 0.6}), 1.0f); /* step(5.0-5) == 1 */ + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.5, 1.0}), 1.0f); /* step(7.5-5) == 1 */ +} + +TEST(iexpr, step_left) { + segment_tree tree; + tree.append(mnpos, {0, 0, 0, 10}, {0, 0, 10, 10}, 1); + + arb::mprovider prov(arb::morphology(std::move(tree))); + auto root_dist = arb::iexpr::distance(1.0, arb::mlocation{0, 0.0}); + auto e = thingify(arb::iexpr::step_left(root_dist-5.0), prov); + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.0, 0.5}), 0.0f); /* step(2.5-5) == 0 */ + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.4, 0.6}), 0.0f); /* step(5.0-5) == 0 */ + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.5, 1.0}), 1.0f); /* step(7.5-5) == 1 */ +} + TEST(iexpr, step) { segment_tree tree; tree.append(mnpos, {0, 0, 0, 10}, {0, 0, 10, 10}, 1); @@ -406,6 +430,7 @@ TEST(iexpr, step) { auto root_dist = arb::iexpr::distance(1.0, arb::mlocation{0, 0.0}); auto e = thingify(arb::iexpr::step(root_dist-5.0), prov); EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.0, 0.5}), 0.0f); /* step(2.5-5) == 0 */ + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.4, 0.6}), 0.5f); /* step(5.0-5) == 0.5 */ EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.5, 1.0}), 1.0f); /* step(7.5-5) == 1 */ } diff --git a/test/unit/test_simd.cpp b/test/unit/test_simd.cpp index 7a3978fd041bf6677fc632487406171b2c174191..7f0e0a339aafc959d7c729c9269b51d8504349aa 100644 --- a/test/unit/test_simd.cpp +++ b/test/unit/test_simd.cpp @@ -637,6 +637,7 @@ TYPED_TEST_P(simd_fp_value, fp_maths) { for (unsigned i = 0; i<nrounds; ++i) { fp epsilon = std::numeric_limits<fp>::epsilon(); + fp max_value = std::numeric_limits<fp>::max(); int min_exponent = std::numeric_limits<fp>::min_exponent; int max_exponent = std::numeric_limits<fp>::max_exponent; @@ -733,6 +734,50 @@ TYPED_TEST_P(simd_fp_value, fp_maths) { for (unsigned i = 0; i<N; ++i) pow_u_v_int[i] = std::pow(u[i], v[i]); pow(simd(u), simd(v)).copy_to(r); EXPECT_TRUE(testing::seq_almost_eq<fp>(pow_u_v_int, r)); + + // Sqrt function: + fill_random(u, rng, 0., max_value); + fp sqrt_fp[N]; + for (unsigned i = 0; i<N; ++i) sqrt_fp[i] = std::sqrt(u[i]); + sqrt(simd(u)).copy_to(r); + EXPECT_TRUE(testing::seq_almost_eq<fp>(sqrt_fp, r)); + + // Indicator functions: + fill_random(u, rng, 0.01, 10.0); + fill_random(v, rng, -10.0, -0.1); + v[0] = 0.0; + v[1] = -0.0; + fp signum_fp[N]; + fp step_right_fp[N]; + fp step_left_fp[N]; + fp step_fp[N]; + for (unsigned i = 0; i<N; ++i) { + signum_fp[i] = -1; + step_right_fp[i] = 0; + step_left_fp[i] = 0; + step_fp[i] = 0; + } + signum_fp[0] = 0; + signum_fp[1] = 0; + step_right_fp[0] = 1; + step_right_fp[1] = 1; + step_fp[0] = 0.5; + step_fp[1] = 0.5; + for (unsigned i = 2; i<2+(N-2)/2; ++i) { + v[i] = u[i]; + signum_fp[i] = 1; + step_right_fp[i] = 1; + step_left_fp[i] = 1; + step_fp[i] = 1; + } + signum(simd(v)).copy_to(r); + EXPECT_TRUE(testing::seq_eq(signum_fp, r)); + step_right(simd(v)).copy_to(r); + EXPECT_TRUE(testing::seq_eq(step_right_fp, r)); + step_left(simd(v)).copy_to(r); + EXPECT_TRUE(testing::seq_eq(step_left_fp, r)); + step(simd(v)).copy_to(r); + EXPECT_TRUE(testing::seq_eq(step_fp, r)); } // The tests can cause floating point exceptions, which may set errno to nonzero