diff --git a/arbor/include/arbor/simd/implbase.hpp b/arbor/include/arbor/simd/implbase.hpp index 968ffce0353a428809f9ff3f71b2fa60eb99c90a..fd62ad91bf79b9d6250381263f2799d43ddd5cfc 100644 --- a/arbor/include/arbor/simd/implbase.hpp +++ b/arbor/include/arbor/simd/implbase.hpp @@ -573,6 +573,26 @@ struct implbase { } return I::copy_from(r); } + + static vector_type relu(const vector_type& s) { + vector_type zeros = I::broadcast(0); + return I::ifelse(I::cmp_gt(s, zeros), s, zeros); + } + + static vector_type sigmoid(const vector_type& s) { + vector_type ones = I::broadcast(1); + return div(ones, add(ones, exp(neg(s)))); + } + + static vector_type tanh(const vector_type& s) { + store a, r; + I::copy_to(s, a); + + for (unsigned i = 0; i<width; ++i) { + r[i] = std::tanh(a[i]); + } + return I::copy_from(r); + } }; } // namespace detail diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index 48c2e0a0758d3d5052c1bd619f43faab607f6837..934c5aee43972f858149193d1628c650e0cc33c1 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -82,7 +82,7 @@ typename detail::simd_impl<Impl>::simd_mask name(const typename detail::simd_imp ARB_PP_FOREACH(ARB_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) ARB_PP_FOREACH(ARB_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_leq, cmp_lt, cmp_geq, cmp_gt) -ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum) +ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum, sigmoid, relu, tanh) #undef ARB_BINARY_ARITHMETIC_ #undef ARB_BINARY_COMPARISON__ @@ -685,7 +685,7 @@ namespace detail { ARB_PP_FOREACH(ARB_DECLARE_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min, cmp_eq) ARB_PP_FOREACH(ARB_DECLARE_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_lt, cmp_leq, cmp_gt, cmp_geq) - ARB_PP_FOREACH(ARB_DECLARE_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum) + ARB_PP_FOREACH(ARB_DECLARE_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr, sqrt, step_right, step_left, step, signum, sigmoid, relu, tanh) #undef ARB_DECLARE_UNARY_ARITHMETIC_ #undef ARB_DECLARE_BINARY_ARITHMETIC_ diff --git a/doc/dev/simd_api.rst b/doc/dev/simd_api.rst index a8324fc64d0f5364026624aea051fe55b584b596..37b5b5cde08a16bf62d20ac48285acda20d4c40a 100644 --- a/doc/dev/simd_api.rst +++ b/doc/dev/simd_api.rst @@ -594,6 +594,18 @@ In the following: - *S* - Lane-wise :math:`x \mapsto \begin{align*} 1 & ~~ \text{if} ~x \gt 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` + * - ``tanh(s)`` + - *S* + - Lane-wise :math:`x \mapsto tanh(x)` + + * - ``relu(s)`` + - *S* + - Lane-wise :math:`x \mapsto max(0, x)` + + * - ``sigmoid(s)`` + - *S* + - Lane-wise :math:`x \mapsto \frac{1}{1+e^{-x}}` + * - ``simd_cast<std::array<L, N>>(a)`` - ``std::array<L, N>`` - Lane-wise cast of values in *a* to scalar type *L* in ``std::array<L, N>``. diff --git a/doc/fileformat/nmodl.rst b/doc/fileformat/nmodl.rst index 5ec115cc26ce701b9291360c5f50832473afa46c..a6d3de37b07ea6bea942bdcf25ebd62f34e034dd 100644 --- a/doc/fileformat/nmodl.rst +++ b/doc/fileformat/nmodl.rst @@ -173,8 +173,11 @@ Arbor-specific features step(x) heaviside step with half value :math:`\begin{align*} 1 & ~~ \text{if} ~x \gt 0, \\ 0 & ~~ \text{if} ~x \lt 0, \\ 0.5 & ~~ \text{otherwise}. \end{align*}` signum(x) sign of argument :math:`\begin{align*} +1 & ~~ \text{if} ~x \gt 0, \\ -1 & ~~ \text{if} ~x \lt 0, \\ 0 & ~~ \text{otherwise}. \end{align*}` exprelr(x) smooth continuation over :math:`x=0` of :math:`x/(1 - e^{-x})` + sigmoid(x) sigmoidal function :math:`\frac{1}{1+e^{-x}}` + relu(x) rectified linear function :math:`max(0, x)` + tanh(x) hyperbolic tangent :math:`tanh(x)` ================== ======================================== ========= - + Voltage Processes ----------------- diff --git a/modcc/expression.cpp b/modcc/expression.cpp index b7ba2aced605d06b651a4c665e6b3ca2039207f5..5178059c4d155ce6b4a2ba7559706edf15f29496 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -1121,6 +1121,15 @@ void StepLeftUnaryExpression::accept(Visitor *v) { void StepUnaryExpression::accept(Visitor *v) { v->visit(this); } +void ReLuUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void TanHUnaryExpression::accept(Visitor *v) { + v->visit(this); +} +void SigmoidUnaryExpression::accept(Visitor *v) { + v->visit(this); +} void SignumUnaryExpression::accept(Visitor *v) { v->visit(this); } @@ -1208,6 +1217,12 @@ ARB_LIBMODCC_API expression_ptr unary_expression( Location loc, return make_expression<StepUnaryExpression>(loc, std::move(e)); case tok::signum : return make_expression<SignumUnaryExpression>(loc, std::move(e)); + case tok::sigmoid : + return make_expression<SigmoidUnaryExpression>(loc, std::move(e)); + case tok::relu : + return make_expression<ReLuUnaryExpression>(loc, std::move(e)); + case tok::tanh : + return make_expression<TanHUnaryExpression>(loc, std::move(e)); default : std::cerr << yellow(token_string(op)) << " is not a valid unary operator" diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 31188ab58493c78aeb7ce9db1fd582ee26f03ad0..4d3cd6aff6b650f02df4e5b811cab3d86589d118 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -1398,6 +1398,33 @@ public: void accept(Visitor *v) override; }; +class ARB_LIBMODCC_API TanHUnaryExpression : public UnaryExpression { +public: + TanHUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::tanh, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +class ARB_LIBMODCC_API SigmoidUnaryExpression : public UnaryExpression { +public: + SigmoidUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::sigmoid, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + +class ARB_LIBMODCC_API ReLuUnaryExpression : public UnaryExpression { +public: + ReLuUnaryExpression(Location loc, expression_ptr e) + : UnaryExpression(loc, tok::relu, std::move(e)) + {} + + void accept(Visitor *v) override; +}; + //////////////////////////////////////////////////////////// // binary expressions diff --git a/modcc/parser.cpp b/modcc/parser.cpp index dfd0bf4d9186fb845abdbc21229c620edbcccb3e..067bb67c15b419766e5a8ec226ad1673c9a8dbe1 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -1453,6 +1453,9 @@ expression_ptr Parser::parse_unaryop() { case tok::step_left: case tok::step: case tok::signum: + case tok::sigmoid: + case tok::tanh: + case tok::relu: get_token(); // consume operator (exp, sin, cos or log) if (token_.type != tok::lparen) { error("missing parenthesis after call to " + yellow(op.spelling)); diff --git a/modcc/perfvisitor.hpp b/modcc/perfvisitor.hpp index 2971fa4a578f558558dd505d97b57b146439dfa6..80c3088c3c0a203cb6136a1f0e925e1253963a61 100644 --- a/modcc/perfvisitor.hpp +++ b/modcc/perfvisitor.hpp @@ -17,9 +17,10 @@ struct FlopAccumulator { int log=0; int pow=0; int sqrt=0; + int tanh=0; void reset() { - add = neg = mul = div = exp = sin = cos = log = pow = sqrt = 0; + add = neg = mul = div = exp = sin = cos = log = pow = sqrt = tanh = 0; } }; @@ -27,8 +28,8 @@ static std::ostream& operator << (std::ostream& os, FlopAccumulator const& f) { char buffer[512]; snprintf(buffer, 512, - " add neg mul div exp sin cos log pow sqrt\n%6d%6d%6d%6d%6d%6d%6d%6d%6d%6d", - f.add, f.neg, f.mul, f.div, f.exp, f.sin, f.cos, f.log, f.pow, f.sqrt); + " add neg mul div exp sin cos log pow sqrt tanh\n%6d%6d%6d%6d%6d%6d%6d%6d%6d%6d%6d", + f.add, f.neg, f.mul, f.div, f.exp, f.sin, f.cos, f.log, f.pow, f.sqrt, f.tanh); os << buffer << std::endl << std::endl; os << " add+mul+neg " << f.add + f.neg + f.mul << std::endl; @@ -111,6 +112,21 @@ public: e->expression()->accept(this); flops.add++; } + void visit(SigmoidUnaryExpression *e) override { + e->expression()->accept(this); + // 1 / ( 1 + exp(-x)) + flops.exp++; + flops.neg++; + flops.div++; + flops.add++; + } + void visit(ReLuUnaryExpression *e) override { + e->expression()->accept(this); + } + void visit(TanHUnaryExpression *e) override { + e->expression()->accept(this); + flops.tanh++; + } //////////////////////////////////////////////////// // specializations for each type of binary expression diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 4c685a7b13666e484914728ef3482596f2558357..df081195a39bba21e5031dc52f5649a7d7dc2bed 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -68,7 +68,10 @@ void CExprEmitter::visit(UnaryExpression* e) { {tok::step_right, "step_right"}, {tok::step_left, "step_left"}, {tok::step, "step"}, - {tok::signum, "signum"} + {tok::signum, "signum"}, + {tok::tanh, "tanh"}, + {tok::sigmoid, "sigmoid"}, + {tok::relu, "relu"} }; if (!unaryop_tbl.count(e->op())) { @@ -109,6 +112,15 @@ void CExprEmitter::visit(UnaryExpression* e) { inner->accept(this); out_ << ")<0.)))"; } + else if (e->op()==tok::relu) { + out_ << "max(0.0, ("; inner->accept(this); out_ << "))"; + } + else if (e->op()==tok::sigmoid) { + out_ << "1.0/(1.0 + exp(-("; inner->accept(this); out_ << ")))"; + } + else if (e->op()==tok::tanh) { + out_ << "tanh("; inner->accept(this); out_ << ")"; + } else { emit_as_call(op_spelling, inner); } @@ -222,7 +234,10 @@ void SimdExprEmitter::visit(UnaryExpression* e) { {tok::step_right, "S::step_right"}, {tok::step_left, "S::step_left"}, {tok::step, "S::step"}, - {tok::signum, "S::signum"} + {tok::signum, "S::signum"}, + {tok::sigmoid, "S::sigmoid"}, + {tok::relu, "S::relu"}, + {tok::tanh, "S::tanh"} }; if (!unaryop_tbl.count(e->op())) { diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index d95e97df8c78a106e541d30199b0c23eb55dfc07..4290497c877ff5a052ffd8b72168cafc5b853367 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -173,6 +173,7 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe "using ::std::pow;\n" "using ::std::sin;\n" "using ::std::sqrt;\n" + "using ::std::tanh;\n" "\n"; if (with_simd) { diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index c6fc0dfce744e14885c79ea9bbbdd7d73db44a42..6cc9ac0d3d77389e263c118d162ef36112c1e3a8 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -265,6 +265,46 @@ public: result_ = make_expression<IntegerExpression>(loc, 0); } + void visit(TanHUnaryExpression* e) override { + // (1 - tanh(f(x))^2) * df(x)/dx + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<SubBinaryExpression>(loc, + make_expression<NumberExpression>(loc, 1.0), + make_expression<PowBinaryExpression>(loc, + make_expression<TanHUnaryExpression>(loc, e->expression()->clone()), + make_expression<NumberExpression>(loc, 2.0)) + ), + result() + ); + } + + void visit(SigmoidUnaryExpression* e) override { + // s'(x) = s(x) * (1 - s(x)) * x' + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, + make_expression<SigmoidUnaryExpression>(loc, e->expression()->clone()), + make_expression<SubBinaryExpression>(loc, + make_expression<NumberExpression>(loc, 1.0), + make_expression<SigmoidUnaryExpression>(loc, e->expression()->clone()) + ) + ), + result() + ); + } + + void visit(ReLuUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<StepUnaryExpression>(loc, e->expression()->clone()), + result() + ); + } + void visit(SignumUnaryExpression* e) override { // ignore singularity auto loc = e->location(); @@ -423,6 +463,15 @@ public: case tok::signum: as_number(loc, (0. < val) - (val < 0.)); return; + case tok::tanh: + as_number(loc, std::tanh(val)); + return; + case tok::relu: + as_number(loc, std::max(0.0, val)); + return; + case tok::sigmoid: + as_number(loc, 1.0 / (1.0 + std::exp(-val))); + return; default: ; // treat opaquely as below } } diff --git a/modcc/token.cpp b/modcc/token.cpp index 34081c48db414460ce4ec819270c1b9a1afa0d66..fbb13b219ed92212f8eca61b966d27f745dccc44 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -82,6 +82,9 @@ static Keyword keywords[] = { {"step_right", tok::step_right}, {"step_left", tok::step_left}, {"step", tok::step}, + {"tanh", tok::tanh}, + {"sigmoid", tok::sigmoid}, + {"relu", tok::relu}, {"signum", tok::signum}, {nullptr, tok::reserved}, }; @@ -168,6 +171,9 @@ static TokenString token_strings[] = { {"step_right", tok::step_right}, {"step_left", tok::step_left}, {"step", tok::step}, + {"tanh", tok::tanh}, + {"sigmoid", tok::sigmoid}, + {"relu", tok::relu}, {"signum", tok::signum}, {"error", tok::reserved}, }; diff --git a/modcc/token.hpp b/modcc/token.hpp index 109d38a0624fc826350adb5928b9cb56a94d6def..cf5ce533a0e9a0197a7a4ec8dfda9e044d97ef31 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -79,6 +79,11 @@ enum class tok { step, // heaviside step function (H(0) = 0.5) signum, // sign function {-1, 0, +1} + // neural (the other neural) functions + tanh, + sigmoid, + relu, + // logical keywords if_stmt, else_stmt, // add _stmt to avoid clash with c++ keywords diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index 116d80b06dc3eb93b85cfd62899e8b318e01769e..8a638547b90a43ccd4d5fc7dd9dd4092360e474b 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -53,6 +53,9 @@ public: virtual void visit(StepRightUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(StepLeftUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(StepUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(SigmoidUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(ReLuUnaryExpression *e) { visit((UnaryExpression*) e); } + virtual void visit(TanHUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(SignumUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(BinaryExpression *e) = 0; diff --git a/test/unit-modcc/test_symdiff.cpp b/test/unit-modcc/test_symdiff.cpp index c9bf6ad474636d5b384c968f6617938727a34f6d..60b2d31e7cfd228ea55fca95b4ea098e1053e1e0 100644 --- a/test/unit-modcc/test_symdiff.cpp +++ b/test/unit-modcc/test_symdiff.cpp @@ -87,7 +87,11 @@ TEST(constant_simplify, constants) { { "log(exp(2))+cos(0)", 3. }, { "0/17-1", -1. }, { "2.5*(34/17-1.0e2)", -245. }, - { "-sin(0.523598775598298873077107230546583814)", -0.5 } + { "-sin(0.523598775598298873077107230546583814)", -0.5 }, + { "sigmoid(0.0)", 0.5 }, + { "relu(2.0)", 2.0 }, + { "relu(-2.0)", 0.0 }, + { "tanh(0.0)", 0.0 }, }; for (const auto& item: tests) { diff --git a/test/unit/test_simd.cpp b/test/unit/test_simd.cpp index 7f0e0a339aafc959d7c729c9269b51d8504349aa..e04485105b205b053c5c70e3fb7e35c81c63cc2e 100644 --- a/test/unit/test_simd.cpp +++ b/test/unit/test_simd.cpp @@ -742,6 +742,23 @@ TYPED_TEST_P(simd_fp_value, fp_maths) { sqrt(simd(u)).copy_to(r); EXPECT_TRUE(testing::seq_almost_eq<fp>(sqrt_fp, r)); + // Nonlinear 'AI' functions + fill_random(u, rng, 0., max_value); + fp sigmoid_fp[N]; + for (unsigned i = 0; i<N; ++i) sigmoid_fp[i] = 1 / (1 + std::exp(-u[i])); + sigmoid(simd(u)).copy_to(r); + EXPECT_TRUE(testing::seq_almost_eq<fp>(sigmoid_fp, r)); + fill_random(u, rng, 0., max_value); + fp tanh_fp[N]; + for (unsigned i = 0; i<N; ++i) tanh_fp[i] = std::tanh(u[i]); + tanh(simd(u)).copy_to(r); + EXPECT_TRUE(testing::seq_almost_eq<fp>(tanh_fp, r)); + fill_random(u, rng, 0., max_value); + fp relu_fp[N]; + for (unsigned i = 0; i<N; ++i) relu_fp[i] = u[i] > 0 ? u[i] : 0; + relu(simd(u)).copy_to(r); + EXPECT_TRUE(testing::seq_almost_eq<fp>(relu_fp, r)); + // Indicator functions: fill_random(u, rng, 0.01, 10.0); fill_random(v, rng, -10.0, -0.1);