diff --git a/arbor/iexpr.cpp b/arbor/iexpr.cpp index f2fe81acac734ed5b82f7e07ed57ed9cf89a1a5c..339e4e0dcf01591c99a695f783f171812c7e5659 100644 --- a/arbor/iexpr.cpp +++ b/arbor/iexpr.cpp @@ -281,6 +281,20 @@ struct exp: public iexpr_interface { iexpr_ptr value; }; +struct step: public iexpr_interface { + step(iexpr_ptr v): value(std::move(v)) {} + + double eval(const mprovider& p, const mcable& c) const override { + double x = value->eval(p, c); + // x < 0: 0 + // x == 0: 0.5 + // x > 0: 1 + return 0.5*((0. < x) - (x < 0.) + 1); + } + + iexpr_ptr value; +}; + struct log: public iexpr_interface { log(iexpr_ptr v): value(std::move(v)) {} @@ -402,6 +416,8 @@ iexpr iexpr::div(iexpr left, iexpr right) { iexpr iexpr::exp(iexpr value) { return iexpr(iexpr_type::exp, std::make_tuple(std::move(value))); } +iexpr iexpr::step(iexpr value) { return iexpr(iexpr_type::step, std::make_tuple(std::move(value))); } + iexpr iexpr::log(iexpr value) { return iexpr(iexpr_type::log, std::make_tuple(std::move(value))); } iexpr iexpr::named(std::string name) { @@ -488,6 +504,9 @@ iexpr_ptr thingify(const iexpr& expr, const mprovider& m) { case iexpr_type::exp: return iexpr_ptr(new iexpr_impl::exp( thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m))); + case iexpr_type::step: + return iexpr_ptr(new iexpr_impl::step( + thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m))); case iexpr_type::log: return iexpr_ptr(new iexpr_impl::log( thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m))); @@ -581,6 +600,10 @@ std::ostream& operator<<(std::ostream& o, const iexpr& e) { o << "exp " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args())); break; } + case iexpr_type::step: { + o << "step " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args())); + break; + } case iexpr_type::log: { o << "log " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args())); break; diff --git a/arbor/include/arbor/iexpr.hpp b/arbor/include/arbor/iexpr.hpp index 6921b777ed912fe6582da404830fba57f2bacc27..3164c53d975e1c071f7d8ede4ef7124d74932428 100644 --- a/arbor/include/arbor/iexpr.hpp +++ b/arbor/include/arbor/iexpr.hpp @@ -29,6 +29,7 @@ enum class iexpr_type { mul, div, exp, + step, log, named }; @@ -97,6 +98,8 @@ struct ARB_SYMBOL_VISIBLE iexpr { static iexpr exp(iexpr value); + static iexpr step(iexpr value); + static iexpr log(iexpr value); static iexpr named(std::string name); diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 6700295c830e23ef9c15077e3824bbb7dbcd97e2..02df7dad8912f24bde490c9fe2e621a606e1e0af 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -163,6 +163,9 @@ std::unordered_multimap<std::string, evaluator> eval_map { {"exp", make_call<arb::iexpr>(arb::iexpr::exp, "iexpr with 1 argument: (value:iexpr)")}, {"exp", make_call<double>(arb::iexpr::exp, "iexpr with 1 argument: (value:double)")}, + {"step", make_call<arb::iexpr>(arb::iexpr::step, "iexpr with 1 argument: (value:iexpr)")}, + {"step", make_call<double>(arb::iexpr::step, "iexpr with 1 argument: (value:double)")}, + {"log", make_call<arb::iexpr>(arb::iexpr::log, "iexpr with 1 argument: (value:iexpr)")}, {"log", make_call<double>(arb::iexpr::log, "iexpr with 1 argument: (value:double)")}, diff --git a/doc/concepts/labels.rst b/doc/concepts/labels.rst index 7027275384f0ef5e2d53143da560c1ecae4e3fbd..2997d38a61daa3be9fe98f0246b83e444eb196eb 100644 --- a/doc/concepts/labels.rst +++ b/doc/concepts/labels.rst @@ -762,6 +762,10 @@ Inhomogeneous Expressions The exponential function of the inhomogeneous expression or real ``value``. +.. label:: (step value:(iexpr | real)) + + The Heaviside step function of the inhomogeneous expression or real ``value``, with `(step 0.0)` evaluating to 0.5. + .. label:: (log value:(iexpr | real)) The logarithm of the inhomogeneous expression or real ``value``. diff --git a/test/unit/test_iexpr.cpp b/test/unit/test_iexpr.cpp index 9043381a0ad7a94380ea70d581a28904ed34c52a..f399fa785a502206c2f93aec756ab190d8d1f4d7 100644 --- a/test/unit/test_iexpr.cpp +++ b/test/unit/test_iexpr.cpp @@ -398,6 +398,17 @@ TEST(iexpr, exp) { EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.0, 1.0}), std::exp(3.0 * 10.0)); } +TEST(iexpr, step) { + segment_tree tree; + tree.append(mnpos, {0, 0, 0, 10}, {0, 0, 10, 10}, 1); + + arb::mprovider prov(arb::morphology(std::move(tree))); + auto root_dist = arb::iexpr::distance(1.0, arb::mlocation{0, 0.0}); + auto e = thingify(arb::iexpr::step(root_dist-5.0), prov); + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.0, 0.5}), 0.0f); /* step(2.5-5) == 0 */ + EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.5, 1.0}), 1.0f); /* step(7.5-5) == 1 */ +} + TEST(iexpr, log) { segment_tree tree; tree.append(mnpos, {0, 0, 0, 10}, {0, 0, 10, 10}, 1); diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp index 9920ee86b100ad0ad60030f9f0d376d27dcb56c7..85ed36d59a60faaa8cdf69e6e1896a5666197dbb 100644 --- a/test/unit/test_s_expr.cpp +++ b/test/unit/test_s_expr.cpp @@ -269,6 +269,7 @@ TEST(iexpr, round_tripping) { "(radius 2.1)", "(diameter 2.1)", "(exp (scalar 2.1))", + "(step (scalar 2.1))", "(log (scalar 2.1))", "(add (scalar 2.1) (radius 3.2))", "(sub (scalar 2.1) (radius 3.2))", @@ -281,7 +282,7 @@ TEST(iexpr, round_tripping) { } // check double values for input instead of explicit scalar iexpr - auto mono_iexpr = {std::string("exp"), std::string("log")}; + auto mono_iexpr = {std::string("exp"), std::string("step"), std::string("log")}; auto duo_iexpr = {std::string("add"), std::string("sub"), std::string("mul"), std::string("div")}; constexpr auto v1 = "1.2"; constexpr auto v2 = "1.2";