diff --git a/modcc/linearrewriter.cpp b/modcc/linearrewriter.cpp index 3705a2ed74ee569ca8bc33a949749ad97d1d54d3..e3513694fc5cffda7847e3aa3bdb0a813542f4ff 100644 --- a/modcc/linearrewriter.cpp +++ b/modcc/linearrewriter.cpp @@ -43,9 +43,14 @@ void LinearRewriter::visit(LinearExpression* e) { for (const auto& state : state_vars) { // To factorize w.r.t state, differentiate the lhs and rhs auto ident = make_expression<IdentifierExpression>(loc, state); + auto lhs_pdiff = symbolic_pdiff(e->lhs(), state); + auto rhs_pdiff = symbolic_pdiff(e->rhs(), state); + if (!lhs_pdiff || !rhs_pdiff) { + error({"expression in LINEAR system is not linear", e->location()}); + return; + } auto coeff = constant_simplify(make_expression<SubBinaryExpression>(loc, - symbolic_pdiff(e->lhs(), state), - symbolic_pdiff(e->rhs(), state))); + lhs_pdiff->clone(), rhs_pdiff->clone())); if (expr_value(coeff) != 0) { auto local_coeff = make_unique_local_assign(scope, coeff, "l_"); diff --git a/modcc/module.cpp b/modcc/module.cpp index aaf9fcc113142701791d38796286ac56a9ad47a3..a9b86b6db9efe8d7320740cc16eda0952f250190 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -298,6 +298,11 @@ bool Module::semantic() { if (solve_proc->kind() == procedureKind::linear) { solver = std::make_unique<LinearSolverVisitor>(state_vars); auto rewrite_body = linear_rewrite(solve_proc->body(), state_vars); + if (!rewrite_body) { + error("An error occured while compiling the LINEAR block. " + "Check whether the statements are in fact linear."); + return false; + } rewrite_body->semantic(nrn_init_scope); rewrite_body->accept(solver.get()); @@ -509,6 +514,10 @@ bool Module::semantic() { if (s->is_assignment()) { for (const auto &id: state_vars) { auto coef = symbolic_pdiff(s->is_assignment()->rhs(), id); + if(!coef) { + linear = false; + continue; + } if(coef->is_number()) { if (!s->is_assignment()->lhs()->is_identifier()) { error(pprintf("Left hand side of assignment is not an identifier")); diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 4c947d6088179c5870c1aa9ba5941e214863e50e..30e06ef3a26a4e0abb85469910f2284e11815799 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -21,8 +21,15 @@ std::ostream& operator<<(std::ostream& out, as_c_double wrap) { case FP_ZERO: return out << (neg? "-0.": "0."); default: - return out << - (std::stringstream{} << io::classic << std::setprecision(17) << wrap.value).rdbuf(); + double val; + std::stringstream s; + // If wrap.value is an int print it as X.0, this is needed for std::max and std::min + if (std::modf(wrap.value, &val) == 0) { + s << io::classic << std::fixed << std::setprecision(1) << wrap.value; + } else { + s << io::classic << std::setprecision(17) << wrap.value; + } + return out << s.rdbuf(); } } diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index ef5e3222f937a7dbfdfa93e3c9738c5a93db7945..db32ac3aeb96ffeedfc4fbfaff1996e26229f813 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -562,6 +562,8 @@ expression_ptr symbolic_pdiff(Expression* e, const std::string& id) { SymPDiffVisitor pdiff_visitor(id); e->accept(&pdiff_visitor); + if (pdiff_visitor.has_error()) return nullptr; + return constant_simplify(pdiff_visitor.result()); } @@ -664,6 +666,9 @@ linear_test_result linear_test(Expression* e, const std::vector<std::string>& va result.constant = e->clone(); for (const auto& id: vars) { auto coef = symbolic_pdiff(e, id); + if (!coef) { + return linear_test_result{}; + } if (!is_zero(coef)) result.coef[id] = std::move(coef); result.constant = substitute(result.constant, id, zero()); @@ -681,8 +686,8 @@ linear_test_result linear_test(Expression* e, const std::vector<std::string>& va for (unsigned j = i; j<vars.size(); ++j) { auto v2 = vars[j]; - - if (!is_zero(symbolic_pdiff(result.coef[v1].get(), v2).get())) { + auto coef = symbolic_pdiff(result.coef[v1].get(), v2); + if (!coef || !is_zero(coef.get())) { result.is_linear = false; goto done; } diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index 83f323fb064b90e91d4abd64fc02d289178f0503..81bdc785d1cedb20f2b6a7bbe35051be1dcaa842 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -58,15 +58,15 @@ TEST(scalar_printer, constants) { TEST(scalar_printer, statement) { std::vector<testcase> testcases = { - {"y=x+3", "y=x+3"}, + {"y=x+3", "y=x+3.0"}, {"y=y^z", "y=pow(y,z)"}, - {"y=exp((x/2) + 3)", "y=exp(x/2+3)"}, + {"y=exp((x/2) + 3)", "y=exp(x/2.0+3.0)"}, {"z=a/b/c", "z=a/b/c"}, {"z=a/(b/c)", "z=a/(b/c)"}, {"z=(a*b)/c", "z=a*b/c"}, {"z=a-(b+c)", "z=a-(b+c)"}, {"z=(a>0)<(b>0)", "z=a>0.<(b>0.)"}, - {"z=a- -2", "z=a- -2"}, + {"z=a- -2", "z=a- -2.0"}, {"z=fabs(x-z)", "z=abs(x-z)"}, {"z=min(x,y)", "z=min(x,y)"}, {"z=min(max(a,b),y)","z=min(max(a,b),y)"}, @@ -125,10 +125,10 @@ TEST(CPrinter, proc_body) { "}" , "value_type k;\n" - "minf[i_] = 1-1/(1+exp((v-k)/k));\n" - "hinf[i_] = 1/(1+exp((v-k)/k));\n" + "minf[i_] = 1.0-1.0/(1.0+exp((v-k)/k));\n" + "hinf[i_] = 1.0/(1.0+exp((v-k)/k));\n" "mtau[i_] = 0.5;\n" - "htau[i_] = 1500;\n" + "htau[i_] = 1500.0;\n" } }; @@ -205,15 +205,15 @@ TEST(CPrinter, proc_body_inlined) { "r_7_ = 0.;\n" "r_8_ = 0.;\n" "r_9_=s2[i_]*0.33333333333333331;\n" - "r_8_=s1[i_]+2;\n" - "if(s1[i_]==3){\n" - " r_7_=2*r_8_;\n" + "r_8_=s1[i_]+2.0;\n" + "if(s1[i_]==3.0){\n" + " r_7_=2.0*r_8_;\n" "}\n" "else{\n" - " if(s1[i_]==4){\n" + " if(s1[i_]==4.0){\n" " r_11_ = 0.;\n" " r_12_ = 0.;\n" - " r_12_=6+s1[i_];\n" + " r_12_=6.0+s1[i_];\n" " r_11_=r_12_;\n" " r_7_=r_8_*r_11_;\n" " }\n" @@ -226,28 +226,28 @@ TEST(CPrinter, proc_body_inlined) { "r_14_=0.;\n" "r_14_=r_9_/s2[i_];\n" "r_15_=log(r_14_);\n" - "r_13_=42*r_15_;\n" + "r_13_=42.0*r_15_;\n" "r_6_=r_9_*r_13_;\n" "t0=r_7_*r_6_;\n" "t1=exprelr(t0);\n" - "ll0_=t1+2;\n" - "if(ll0_==3){\n" - " t2=10;\n" + "ll0_=t1+2.0;\n" + "if(ll0_==3.0){\n" + " t2=10.0;\n" "}\n" "else{\n" - " if(ll0_==4){\n" + " if(ll0_==4.0){\n" " r_17_=0.;\n" " r_18_=0.;\n" - " r_18_=6+ll0_;\n" + " r_18_=6.0+ll0_;\n" " r_17_=r_18_;\n" - " t2=5*r_17_;\n" + " t2=5.0*r_17_;\n" " }\n" " else{\n" " r_16_=148.4131591025766;\n" " t2=r_16_*ll0_;\n" " }\n" "}\n" - "s2[i_]=t2+4;\n"; + "s2[i_]=t2+4.0;\n"; Module m(io::read_all(DATADIR "/mod_files/test6.mod"), "test6.mod"); Parser p(m, false); @@ -277,25 +277,25 @@ TEST(CPrinter, proc_body_inlined) { TEST(SimdPrinter, simd_if_else) { std::vector<const char*> expected_procs = { "simd_value u;\n" - "simd_mask mask_0_ = S::cmp_gt(i, (double)2);\n" - "S::where(mask_0_,u) = (double)7;\n" - "S::where(S::logical_not(mask_0_),u) = (double)5;\n" - "indirect(s+i_, simd_width_) = S::where(S::logical_not(mask_0_),simd_cast<simd_value>((double)42));\n" + "simd_mask mask_0_ = S::cmp_gt(i, (double)2.0);\n" + "S::where(mask_0_,u) = (double)7.0;\n" + "S::where(S::logical_not(mask_0_),u) = (double)5.0;\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_not(mask_0_),simd_cast<simd_value>((double)42.0));\n" "indirect(s+i_, simd_width_) = u;" , "simd_value u;\n" - "simd_mask mask_1_ = S::cmp_gt(i, (double)2);\n" - "S::where(mask_1_,u) = (double)7;\n" - "S::where(S::logical_not(mask_1_),u) = (double)5;\n" - "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_1_), mask_input_),simd_cast<simd_value>((double)42));\n" + "simd_mask mask_1_ = S::cmp_gt(i, (double)2.0);\n" + "S::where(mask_1_,u) = (double)7.0;\n" + "S::where(S::logical_not(mask_1_),u) = (double)5.0;\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_1_), mask_input_),simd_cast<simd_value>((double)42.0));\n" "indirect(s+i_, simd_width_) = S::where(mask_input_, u);" , - "simd_mask mask_2_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)2);\n" - "simd_mask mask_3_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)3);\n" + "simd_mask mask_2_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)2.0);\n" + "simd_mask mask_3_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)3.0);\n" "S::where(S::logical_and(mask_2_,mask_3_),i) = (double)0.;\n" - "S::where(S::logical_and(mask_2_,S::logical_not(mask_3_)),i) = (double)1;\n" - "simd_mask mask_4_ = S::cmp_lt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)1);\n" - "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_2_),mask_4_),simd_cast<simd_value>((double)2));\n" + "S::where(S::logical_and(mask_2_,S::logical_not(mask_3_)),i) = (double)1.0;\n" + "simd_mask mask_4_ = S::cmp_lt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)1.0);\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_2_),mask_4_),simd_cast<simd_value>((double)2.0));\n" "rates(i_, S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)), i);" }; diff --git a/test/unit-modcc/test_symdiff.cpp b/test/unit-modcc/test_symdiff.cpp index 410e77001edcf2a1fae852628abbf39b10a9d020..226511701486c3dd3f423a1010269c120f133e09 100644 --- a/test/unit-modcc/test_symdiff.cpp +++ b/test/unit-modcc/test_symdiff.cpp @@ -217,6 +217,19 @@ TEST(symbolic_pdiff, nonlinear) { } } +TEST(symbolic_pdiff, non_differentiable) { + struct { const char* exp; } tests[] = { + { "max(x)"}, + { "min(a)"} + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.exp); + auto exp = Parser{item.exp}.parse_expression(); + ASSERT_FALSE(exp); + } +} + inline expression_ptr operator""_expr(const char* literal, std::size_t) { return Parser{literal}.parse_expression(); } @@ -313,6 +326,18 @@ TEST(linear_test, nonlinear) { EXPECT_FALSE(r.is_linear); } +TEST(linear_test, non_differentiable) { + linear_test_result r; + + r = linear_test("max(x, y)"_expr, {"x", "y"}); + EXPECT_FALSE(r.is_linear); + EXPECT_FALSE(r.is_homogeneous); + + r = linear_test("min(x, y)"_expr, {"x", "y"}); + EXPECT_FALSE(r.is_linear); + EXPECT_FALSE(r.is_homogeneous); +} + TEST(linear_test, diagonality) { auto xdot = "a*x"_expr; auto ydot = "-b*y/2"_expr; diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 3fa832387e893e792ef7304dd6de792ef73d9486..e709899bd2d15e3e7814d4d6943b85669dc3c0a5 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -21,6 +21,7 @@ set(test_mechanisms test1_kin_steadystate fixed_ica_current point_ica_current + non_linear linear_ca_conc test_cl_valence test_ca_read_valence diff --git a/test/unit/mod/non_linear.mod b/test/unit/mod/non_linear.mod new file mode 100644 index 0000000000000000000000000000000000000000..2a066c748a37ce2a7a1262a7a3fe7b87012fdddc --- /dev/null +++ b/test/unit/mod/non_linear.mod @@ -0,0 +1,32 @@ +NEURON { + POINT_PROCESS non_linear + RANGE tau, e + NONSPECIFIC_CURRENT i +} + +PARAMETER { + tau = 2.0 (ms) : the default for Neuron is 0.1 + e = 0 (mV) +} + +STATE { + g +} + +INITIAL { + g=0 +} + +BREAKPOINT { + SOLVE state METHOD cnexp + i = g*(v - e) +} + +DERIVATIVE state { + g' = -g/tau +} + +NET_RECEIVE(weight) { + g = max(0, min(g + weight, 10)) +} + diff --git a/test/unit/test_mechinfo.cpp b/test/unit/test_mechinfo.cpp index d7d20c8dbf1db84369201bfb3c78c538b76fe744..cd5b088530d75b8b73f0eb2f6e0b37366cfadc1c 100644 --- a/test/unit/test_mechinfo.cpp +++ b/test/unit/test_mechinfo.cpp @@ -5,6 +5,7 @@ #include <arbor/cable_cell.hpp> #include "../gtest.h" +#include "unit_test_catalogue.hpp" // TODO: This test is really checking part of the recipe description // for cable1d cells, so move it there. Make actual tests for mechinfo @@ -40,3 +41,16 @@ TEST(mechanism_desc, setting) { EXPECT_EQ(p["b"], m["b"]); EXPECT_EQ(p["d"], m["d"]); } + +TEST(mechanism_desc, linearity) { + { + mechanism_catalogue cat = arb::global_default_catalogue(); + EXPECT_TRUE(cat["expsyn"].linear); + EXPECT_TRUE(cat["exp2syn"].linear); + } + { + mechanism_catalogue cat = make_unit_test_catalogue(); + EXPECT_FALSE(cat["non_linear"].linear); + } + +} diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index 2c2239860edfb2daa2a4bdf4f86819b77d2feafd..1d3c5112acc171b360e282708fa9d623af2d782d 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -11,6 +11,7 @@ #include "mechanisms/celsius_test.hpp" #include "mechanisms/diam_test.hpp" #include "mechanisms/param_as_state.hpp" +#include "mechanisms/non_linear.hpp" #include "mechanisms/test0_kin_diff.hpp" #include "mechanisms/test_linear_state.hpp" #include "mechanisms/test_linear_init.hpp" @@ -80,6 +81,7 @@ mechanism_catalogue make_unit_test_catalogue(const mechanism_catalogue& from) { ADD_MECH(cat, test1_kin_compartment) ADD_MECH(cat, test4_kin_compartment) ADD_MECH(cat, fixed_ica_current) + ADD_MECH(cat, non_linear) ADD_MECH(cat, point_ica_current) ADD_MECH(cat, linear_ca_conc) ADD_MECH(cat, test_cl_valence)