From 48f4795e0e4cb14afb5356fe74a4d841d8f09b3c Mon Sep 17 00:00:00 2001
From: Lennart Landsmeer <lennart@landsmeer.email>
Date: Tue, 11 Oct 2022 11:45:13 +0200
Subject: [PATCH] Heaviside step (#1989)

Adds the Heaviside step function to Arbor's iexpr

Co-authored-by: boeschf <48126478+boeschf@users.noreply.github.com>
---
 arbor/iexpr.cpp               | 23 +++++++++++++++++++++++
 arbor/include/arbor/iexpr.hpp |  3 +++
 arborio/label_parse.cpp       |  3 +++
 doc/concepts/labels.rst       |  4 ++++
 test/unit/test_iexpr.cpp      | 11 +++++++++++
 test/unit/test_s_expr.cpp     |  3 ++-
 6 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/arbor/iexpr.cpp b/arbor/iexpr.cpp
index f2fe81ac..339e4e0d 100644
--- a/arbor/iexpr.cpp
+++ b/arbor/iexpr.cpp
@@ -281,6 +281,20 @@ struct exp: public iexpr_interface {
     iexpr_ptr value;
 };
 
+struct step: public iexpr_interface {
+    step(iexpr_ptr v): value(std::move(v)) {}
+
+    double eval(const mprovider& p, const mcable& c) const override {
+        double x = value->eval(p, c);
+        // x <  0:  0
+        // x == 0:  0.5
+        // x >  0:  1
+        return 0.5*((0. < x) - (x < 0.) + 1);
+    }
+
+    iexpr_ptr value;
+};
+
 struct log: public iexpr_interface {
     log(iexpr_ptr v): value(std::move(v)) {}
 
@@ -402,6 +416,8 @@ iexpr iexpr::div(iexpr left, iexpr right) {
 
 iexpr iexpr::exp(iexpr value) { return iexpr(iexpr_type::exp, std::make_tuple(std::move(value))); }
 
+iexpr iexpr::step(iexpr value) { return iexpr(iexpr_type::step, std::make_tuple(std::move(value))); }
+
 iexpr iexpr::log(iexpr value) { return iexpr(iexpr_type::log, std::make_tuple(std::move(value))); }
 
 iexpr iexpr::named(std::string name) {
@@ -488,6 +504,9 @@ iexpr_ptr thingify(const iexpr& expr, const mprovider& m) {
     case iexpr_type::exp:
         return iexpr_ptr(new iexpr_impl::exp(
             thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m)));
+    case iexpr_type::step:
+        return iexpr_ptr(new iexpr_impl::step(
+            thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m)));
     case iexpr_type::log:
         return iexpr_ptr(new iexpr_impl::log(
             thingify(std::get<0>(std::any_cast<const std::tuple<iexpr>&>(expr.args())), m)));
@@ -581,6 +600,10 @@ std::ostream& operator<<(std::ostream& o, const iexpr& e) {
         o << "exp " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args()));
         break;
     }
+    case iexpr_type::step: {
+        o << "step " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args()));
+        break;
+    }
     case iexpr_type::log: {
         o << "log " << std::get<0>(std::any_cast<const std::tuple<iexpr>&>(e.args()));
         break;
diff --git a/arbor/include/arbor/iexpr.hpp b/arbor/include/arbor/iexpr.hpp
index 6921b777..3164c53d 100644
--- a/arbor/include/arbor/iexpr.hpp
+++ b/arbor/include/arbor/iexpr.hpp
@@ -29,6 +29,7 @@ enum class iexpr_type {
     mul,
     div,
     exp,
+    step,
     log,
     named
 };
@@ -97,6 +98,8 @@ struct ARB_SYMBOL_VISIBLE iexpr {
 
     static iexpr exp(iexpr value);
 
+    static iexpr step(iexpr value);
+
     static iexpr log(iexpr value);
 
     static iexpr named(std::string name);
diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp
index 6700295c..02df7dad 100644
--- a/arborio/label_parse.cpp
+++ b/arborio/label_parse.cpp
@@ -163,6 +163,9 @@ std::unordered_multimap<std::string, evaluator> eval_map {
     {"exp", make_call<arb::iexpr>(arb::iexpr::exp, "iexpr with 1 argument: (value:iexpr)")},
     {"exp", make_call<double>(arb::iexpr::exp, "iexpr with 1 argument: (value:double)")},
 
+    {"step", make_call<arb::iexpr>(arb::iexpr::step, "iexpr with 1 argument: (value:iexpr)")},
+    {"step", make_call<double>(arb::iexpr::step, "iexpr with 1 argument: (value:double)")},
+
     {"log", make_call<arb::iexpr>(arb::iexpr::log, "iexpr with 1 argument: (value:iexpr)")},
     {"log", make_call<double>(arb::iexpr::log, "iexpr with 1 argument: (value:double)")},
 
diff --git a/doc/concepts/labels.rst b/doc/concepts/labels.rst
index 70272753..2997d38a 100644
--- a/doc/concepts/labels.rst
+++ b/doc/concepts/labels.rst
@@ -762,6 +762,10 @@ Inhomogeneous Expressions
 
     The exponential function of the inhomogeneous expression or real ``value``.
 
+.. label:: (step value:(iexpr | real))
+
+    The Heaviside step function of the inhomogeneous expression or real ``value``, with `(step 0.0)` evaluating to 0.5.
+
 .. label:: (log value:(iexpr | real))
 
     The logarithm of the inhomogeneous expression or real ``value``.
diff --git a/test/unit/test_iexpr.cpp b/test/unit/test_iexpr.cpp
index 9043381a..f399fa78 100644
--- a/test/unit/test_iexpr.cpp
+++ b/test/unit/test_iexpr.cpp
@@ -398,6 +398,17 @@ TEST(iexpr, exp) {
     EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.0, 1.0}), std::exp(3.0 * 10.0));
 }
 
+TEST(iexpr, step) {
+    segment_tree tree;
+    tree.append(mnpos, {0, 0, 0, 10}, {0, 0, 10, 10}, 1);
+
+    arb::mprovider prov(arb::morphology(std::move(tree)));
+    auto root_dist = arb::iexpr::distance(1.0, arb::mlocation{0, 0.0});
+    auto e = thingify(arb::iexpr::step(root_dist-5.0), prov);
+    EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.0, 0.5}), 0.0f); /* step(2.5-5) == 0 */
+    EXPECT_DOUBLE_EQ(e->eval(prov, {0, 0.5, 1.0}), 1.0f); /* step(7.5-5) == 1 */
+}
+
 TEST(iexpr, log) {
     segment_tree tree;
     tree.append(mnpos, {0, 0, 0, 10}, {0, 0, 10, 10}, 1);
diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp
index 9920ee86..85ed36d5 100644
--- a/test/unit/test_s_expr.cpp
+++ b/test/unit/test_s_expr.cpp
@@ -269,6 +269,7 @@ TEST(iexpr, round_tripping) {
         "(radius 2.1)",
         "(diameter 2.1)",
         "(exp (scalar 2.1))",
+        "(step (scalar 2.1))",
         "(log (scalar 2.1))",
         "(add (scalar 2.1) (radius 3.2))",
         "(sub (scalar 2.1) (radius 3.2))",
@@ -281,7 +282,7 @@ TEST(iexpr, round_tripping) {
     }
 
     // check double values for input instead of explicit scalar iexpr
-    auto mono_iexpr = {std::string("exp"), std::string("log")};
+    auto mono_iexpr = {std::string("exp"), std::string("step"), std::string("log")};
     auto duo_iexpr = {std::string("add"), std::string("sub"), std::string("mul"), std::string("div")};
     constexpr auto v1 = "1.2";
     constexpr auto v2 = "1.2";
-- 
GitLab