From 2dabd87a34996d3f11c233f0648c50f715084c72 Mon Sep 17 00:00:00 2001 From: Lennart Landsmeer <lennart@landsmeer.email> Date: Thu, 5 Jan 2023 13:27:58 +0100 Subject: [PATCH] ANN activation functions for NMODL: `sigmoid(x)`, `relu(x)` and `tanh(x)` (#2052) I have been playing around a bit with getting Neural ODE's working (and in general ANN's as black-box homeostatic mechanisms) and found arbor lacking in ANN activation functions. Now not everybody might agree that this is actually needed in Arbor, so I don't expect this PR to be merged directly. But if people like the idea, here it is :) The PR contains 3 common unary functions used in ANN's as nonlinear activations: `sigmoid(x)`, `relu(x)` and `tanh(x)` Here is a very minimal example of a 'neural ode': ``` NEURON { SUFFIX node } PARAMETER { A1 = -0.662 A2 = 0.282 A3 = 0.957 A4 = -0.189 A5 = -1.794 A6 = 1.094 A7 = -1.133 A8 = 0.625 A9 = 0.074 b1 = -1.434 b2 = -0.358 b3 = -0.093 } STATE { x y z } INITIAL { x = 0 y = 1 z = 1 } DERIVATIVE dstate { x' = tanh(A1 * x + A2 * y + A3 * z + b1) y' = tanh(A4 * x + A5 * y + A6 * z + b2) z' = tanh(A7 * x + A8 * y + A9 * z + b3) } BREAKPOINT { SOLVE dstate METHOD sparse } ``` --- arbor/include/arbor/simd/implbase.hpp | 20 +++++++++++ arbor/include/arbor/simd/simd.hpp | 4 +-- doc/dev/simd_api.rst | 12 +++++++ doc/fileformat/nmodl.rst | 5 ++- modcc/expression.cpp | 15 ++++++++ modcc/expression.hpp | 27 +++++++++++++++ modcc/parser.cpp | 3 ++ modcc/perfvisitor.hpp | 22 ++++++++++-- modcc/printer/cexpr_emit.cpp | 19 +++++++++-- modcc/printer/cprinter.cpp | 1 + modcc/symdiff.cpp | 49 +++++++++++++++++++++++++++ modcc/token.cpp | 6 ++++ modcc/token.hpp | 5 +++ modcc/visitor.hpp | 3 ++ test/unit-modcc/test_symdiff.cpp | 6 +++- test/unit/test_simd.cpp | 17 ++++++++++ 16 files changed, 205 insertions(+), 9 deletions(-) diff --git a/arbor/include/arbor/simd/implbase.hpp b/arbor/include/arbor/simd/implbase.hpp index 968ffce0..fd62ad91 100644 --- a/arbor/include/arbor/simd/implbase.hpp +++ b/arbor/include/arbor/simd/implbase.hpp @@ -573,6 +573,26 @@ struct implbase { } return I::copy_from(r); } + + static vector_type relu(const vector_type& s) { + vector_type zeros = I::broadcast(0); + return I::ifelse(I::cmp_gt(s, zeros), s, zeros); + } + + static vector_type sigmoid(const vector_type& s) { + vector_type ones = I::broadcast(1); + return div(ones, add(ones, exp(neg(s)))); + } + + static vector_type tanh(const vector_type& s) { + store a, r; + I::copy_to(s, a); + + for (unsigned i = 0; i<width; ++i) { + r[i] = std::tanh(a[i]); + } + return I::copy_from(r); + } }; } // namespace detail diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index 48c2e0a0..934c5aee 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -82,7 +82,7 @@ typename detail::simd_impl<Impl>::simd_mask name(const typename detail::simd_imp ARB_PP_FOREACH(ARB_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) ARB_PP_FOREACH(ARB_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_leq, cmp_lt, cmp_geq, cmp_gt) -ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum) +ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum, sigmoid, relu, tanh) #undef ARB_BINARY_ARITHMETIC_ #undef ARB_BINARY_COMPARISON__ @@ -685,7 +685,7 @@ namespace detail { ARB_PP_FOREACH(ARB_DECLARE_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min, cmp_eq) ARB_PP_FOREACH(ARB_DECLARE_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_lt, cmp_leq, cmp_gt, cmp_geq) - ARB_PP_FOREACH(ARB_DECLARE_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum) + ARB_PP_FOREACH(ARB_DECLARE_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum, sigmoid, relu, tanh) #undef ARB_DECLARE_UNARY_ARITHMETIC_ #undef ARB_DECLARE_BINARY_ARITHMETIC_ diff --git a/doc/dev/simd_api.rst b/doc/dev/simd_api.rst index a8324fc6..37b5b5cd 100644 --- a/doc/dev/simd_api.rst +++ b/doc/dev/simd_api.rst @@ -594,6 +594,18 @@ In the following: - *S* - Lane-wise :math:`x \mapsto \begin{align*} 1 & ~~ \text{if} ~x \gt 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` + * - ``tanh(s)`` + - *S* + - Lane-wise :math:`x \mapsto tanh(x)` + + * - ``relu(s)`` + - *S* + - Lane-wise :math:`x \mapsto max(0, x)` + + * - ``sigmoid(s)`` + - *S* + - Lane-wise :math:`x \mapsto \frac{1}{1+e^{-x}}` + * - ``simd_cast<std::array<L, N>>(a)`` - ``std::array<L, N>`` - Lane-wise cast of values in *a* to scalar type *L* in ``std::array<L, N>``. diff --git a/doc/fileformat/nmodl.rst b/doc/fileformat/nmodl.rst index 5ec115cc..a6d3de37 100644 --- a/doc/fileformat/nmodl.rst +++ b/doc/fileformat/nmodl.rst @@ -173,8 +173,11 @@ Arbor-specific features step(x) heaviside step with half value :math:`\begin{align*} 1 & ~~ \text{if} ~x \gt 0, \\ 0 & ~~ \text{if} ~x \lt 0, \\ 0.5 & ~~ \text{otherwise}. \end{align*}` signum(x) sign of argument :math:`\begin{align*} +1 & ~~ \text{if} ~x \gt 0, \\ -1 & ~~ \text{if} ~x \lt 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` exprelr(x) smooth continuation over :math:`x=0` of :math:`x/(1 - e^{-x})` + sigmoid(x) sigmoidal function :math:`\frac{1}{1+e^{-x}}` + relu(x) rectified linear function :math:`max(0, x)` + tanh(x) hyperbolic tangent :math:`tanh(x)` ================== ======================================== ========= - + Voltage Processes ----------------- diff --git a/modcc/expression.cpp b/modcc/expression.cpp index b7ba2ace..5178059c 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -1121,6 +1121,15 @@ void StepLeftUnaryExpression::accept(Visitor *v) { void StepUnaryExpression::accept(Visitor *v) { v->visit(this); } +void ReLuUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void TanHUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void SigmoidUnaryExpression::accept(Visitor *v) { + v->visit(this); +} void SignumUnaryExpression::accept(Visitor *v) { v->visit(this); } @@ -1208,6 +1217,12 @@ ARB_LIBMODCC_API expression_ptr unary_expression( Location loc, return make_expression<StepUnaryExpression>(loc, std::move(e)); case tok::signum : return make_expression<SignumUnaryExpression>(loc, std::move(e)); + case tok::sigmoid : + return make_expression<SigmoidUnaryExpression>(loc, std::move(e)); + case tok::relu : + return make_expression<ReLuUnaryExpression>(loc, std::move(e)); + case tok::tanh : + return make_expression<TanHUnaryExpression>(loc, std::move(e)); default : std::cerr << yellow(token_string(op)) << " is not a valid unary operator" diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 31188ab5..4d3cd6af 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -1398,6 +1398,33 @@ public: void accept(Visitor *v) override; }; +class ARB_LIBMODCC_API TanHUnaryExpression : public UnaryExpression { +public: + TanHUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::tanh, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +class ARB_LIBMODCC_API SigmoidUnaryExpression : public UnaryExpression { +public: + SigmoidUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::sigmoid, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +class ARB_LIBMODCC_API ReLuUnaryExpression : public UnaryExpression { +public: + ReLuUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::relu, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + //////////////////////////////////////////////////////////// // binary expressions diff --git a/modcc/parser.cpp b/modcc/parser.cpp index dfd0bf4d..067bb67c 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -1453,6 +1453,9 @@ expression_ptr Parser::parse_unaryop() { case tok::step_left: case tok::step: case tok::signum: + case tok::sigmoid: + case tok::tanh: + case tok::relu: get_token(); // consume operator (exp, sin, cos or log) if (token_.type != tok::lparen) { error("missing parenthesis after call to " + yellow(op.spelling)); diff --git a/modcc/perfvisitor.hpp b/modcc/perfvisitor.hpp index 2971fa4a..80c3088c 100644 --- a/modcc/perfvisitor.hpp +++ b/modcc/perfvisitor.hpp @@ -17,9 +17,10 @@ struct FlopAccumulator { int log=0; int pow=0; int sqrt=0; + int tanh=0; void reset() { - add = neg = mul = div = exp = sin = cos = log = pow = sqrt = 0; + add = neg = mul = div = exp = sin = cos = log = pow = sqrt = tanh = 0; } }; @@ -27,8 +28,8 @@ static std::ostream& operator << (std::ostream& os, FlopAccumulator const& f) { char buffer[512]; snprintf(buffer, 512, - " add neg mul div exp sin cos log pow sqrt\n%6d%6d%6d%6d%6d%6d%6d%6d%6d%6d", - f.add, f.neg, f.mul, f.div, f.exp, f.sin, f.cos, f.log, f.pow, f.sqrt); + " add neg mul div exp sin cos log pow sqrt tanh\n%6d%6d%6d%6d%6d%6d%6d%6d%6d%6d%6d", + f.add, f.neg, f.mul, f.div, f.exp, f.sin, f.cos, f.log, f.pow, f.sqrt, f.tanh); os << buffer << std::endl << std::endl; os << " add+mul+neg " << f.add + f.neg + f.mul << std::endl; @@ -111,6 +112,21 @@ public: e->expression()->accept(this); flops.add++; } + void visit(SigmoidUnaryExpression *e) override { + e->expression()->accept(this); + // 1 / ( 1 + exp(-x)) + flops.exp++; + flops.neg++; + flops.div++; + flops.add++; + } + void visit(ReLuUnaryExpression *e) override { + e->expression()->accept(this); + } + void visit(TanHUnaryExpression *e) override { + e->expression()->accept(this); + flops.tanh++; + } //////////////////////////////////////////////////// // specializations for each type of binary expression diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 4c685a7b..df081195 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -68,7 +68,10 @@ void CExprEmitter::visit(UnaryExpression* e) { {tok::step_right, "step_right"}, {tok::step_left, "step_left"}, {tok::step, "step"}, - {tok::signum, "signum"} + {tok::signum, "signum"}, + {tok::tanh, "tanh"}, + {tok::sigmoid, "sigmoid"}, + {tok::relu, "relu"} }; if (!unaryop_tbl.count(e->op())) { @@ -109,6 +112,15 @@ void CExprEmitter::visit(UnaryExpression* e) { inner->accept(this); out_ << ")<0.)))"; } + else if (e->op()==tok::relu) { + out_ << "max(0.0, ("; inner->accept(this); out_ << "))"; + } + else if (e->op()==tok::sigmoid) { + out_ << "1.0/(1.0 + exp(-("; inner->accept(this); out_ << ")))"; + } + else if (e->op()==tok::tanh) { + out_ << "tanh("; inner->accept(this); out_ << ")"; + } else { emit_as_call(op_spelling, inner); } @@ -222,7 +234,10 @@ void SimdExprEmitter::visit(UnaryExpression* e) { {tok::step_right, "S::step_right"}, {tok::step_left, "S::step_left"}, {tok::step, "S::step"}, - {tok::signum, "S::signum"} + {tok::signum, "S::signum"}, + {tok::sigmoid, "S::sigmoid"}, + {tok::relu, "S::relu"}, + {tok::tanh, "S::tanh"} }; if (!unaryop_tbl.count(e->op())) { diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index d95e97df..4290497c 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -173,6 +173,7 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe "using ::std::pow;\n" "using ::std::sin;\n" "using ::std::sqrt;\n" + "using ::std::tanh;\n" "\n"; if (with_simd) { diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index c6fc0dfc..6cc9ac0d 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -265,6 +265,46 @@ public: result_ = make_expression<IntegerExpression>(loc, 0); } + void visit(TanHUnaryExpression* e) override { + // (1 - tanh(f(x))^2) * df(x)/dx + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<SubBinaryExpression>(loc, + make_expression<NumberExpression>(loc, 1.0), + make_expression<PowBinaryExpression>(loc, + make_expression<TanHUnaryExpression>(loc, e->expression()->clone()), + make_expression<NumberExpression>(loc, 2.0)) + ), + result() + ); + } + + void visit(SigmoidUnaryExpression* e) override { + // s'(x) = s(x) * (1 - s(x)) * x' + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, + make_expression<SigmoidUnaryExpression>(loc, e->expression()->clone()), + make_expression<SubBinaryExpression>(loc, + make_expression<NumberExpression>(loc, 1.0), + make_expression<SigmoidUnaryExpression>(loc, e->expression()->clone()) + ) + ), + result() + ); + } + + void visit(ReLuUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<StepUnaryExpression>(loc, e->expression()->clone()), + result() + ); + } + void visit(SignumUnaryExpression* e) override { // ignore singularity auto loc = e->location(); @@ -423,6 +463,15 @@ public: case tok::signum: as_number(loc, (0. < val) - (val < 0.)); return; + case tok::tanh: + as_number(loc, std::tanh(val)); + return; + case tok::relu: + as_number(loc, std::max(0.0, val)); + return; + case tok::sigmoid: + as_number(loc, 1.0 / (1.0 + std::exp(-val))); + return; default: ; // treat opaquely as below } } diff --git a/modcc/token.cpp b/modcc/token.cpp index 34081c48..fbb13b21 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -82,6 +82,9 @@ static Keyword keywords[] = { {"step_right", tok::step_right}, {"step_left", tok::step_left}, {"step", tok::step}, + {"tanh", tok::tanh}, + {"sigmoid", tok::sigmoid}, + {"relu", tok::relu}, {"signum", tok::signum}, {nullptr, tok::reserved}, }; @@ -168,6 +171,9 @@ static TokenString token_strings[] = { {"step_right", tok::step_right}, {"step_left", tok::step_left}, {"step", tok::step}, + {"tanh", tok::tanh}, + {"sigmoid", tok::sigmoid}, + {"relu", tok::relu}, {"signum", tok::signum}, {"error", tok::reserved}, }; diff --git a/modcc/token.hpp b/modcc/token.hpp index 109d38a0..cf5ce533 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -79,6 +79,11 @@ enum class tok { step, // heaviside step function (H(0) = 0.5) signum, // sign function {-1, 0, +1} + // neural (the other neural) functions + tanh, + sigmoid, + relu, + // logical keywords if_stmt, else_stmt, // add _stmt to avoid clash with c++ keywords diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index 116d80b0..8a638547 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -53,6 +53,9 @@ public: virtual void visit(StepRightUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(StepLeftUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(StepUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(SigmoidUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(ReLuUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(TanHUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(SignumUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(BinaryExpression *e) = 0; diff --git a/test/unit-modcc/test_symdiff.cpp b/test/unit-modcc/test_symdiff.cpp index c9bf6ad4..60b2d31e 100644 --- a/test/unit-modcc/test_symdiff.cpp +++ b/test/unit-modcc/test_symdiff.cpp @@ -87,7 +87,11 @@ TEST(constant_simplify, constants) { { "log(exp(2))+cos(0)", 3. }, { "0/17-1", -1. }, { "2.5*(34/17-1.0e2)", -245. }, - { "-sin(0.523598775598298873077107230546583814)", -0.5 } + { "-sin(0.523598775598298873077107230546583814)", -0.5 }, + { "sigmoid(0.0)", 0.5 }, + { "relu(2.0)", 2.0 }, + { "relu(-2.0)", 0.0 }, + { "tanh(0.0)", 0.0 }, }; for (const auto& item: tests) { diff --git a/test/unit/test_simd.cpp b/test/unit/test_simd.cpp index 7f0e0a33..e0448510 100644 --- a/test/unit/test_simd.cpp +++ b/test/unit/test_simd.cpp @@ -742,6 +742,23 @@ TYPED_TEST_P(simd_fp_value, fp_maths) { sqrt(simd(u)).copy_to(r); EXPECT_TRUE(testing::seq_almost_eq<fp>(sqrt_fp, r)); + // Nonlinear 'AI' functions + fill_random(u, rng, 0., max_value); + fp sigmoid_fp[N]; + for (unsigned i = 0; i<N; ++i) sigmoid_fp[i] = 1 / (1 + std::exp(-u[i])); + sigmoid(simd(u)).copy_to(r); + EXPECT_TRUE(testing::seq_almost_eq<fp>(sigmoid_fp, r)); + fill_random(u, rng, 0., max_value); + fp tanh_fp[N]; + for (unsigned i = 0; i<N; ++i) tanh_fp[i] = std::tanh(u[i]); + tanh(simd(u)).copy_to(r); + EXPECT_TRUE(testing::seq_almost_eq<fp>(tanh_fp, r)); + fill_random(u, rng, 0., max_value); + fp relu_fp[N]; + for (unsigned i = 0; i<N; ++i) relu_fp[i] = u[i] > 0 ? u[i] : 0; + relu(simd(u)).copy_to(r); + EXPECT_TRUE(testing::seq_almost_eq<fp>(relu_fp, r)); + // Indicator functions: fill_random(u, rng, 0.01, 10.0); fill_random(v, rng, -10.0, -0.1); -- GitLab