diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index 5edd63fe6489b256af1913aba4ec4b5a09a88812..3e0fb573387a6c2e8f291cc1779ba04949b23cd9 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -408,6 +408,28 @@ namespace detail { simd_impl& data_; }; + struct const_where_expression { + const_where_expression(const const_where_expression&) = default; + const_where_expression& operator=(const const_where_expression&) = delete; + + const_where_expression(const simd_mask& m, const simd_impl& s): + mask_(m), data_(s) {} + + void copy_to(scalar_type* p) const { + Impl::copy_to_masked(data_.value_, p, mask_.value_); + } + + template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> + void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { + Impl::scatter(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); + } + + private: + const simd_mask& mask_; + const simd_impl& data_; + }; + + // Maths functions are implemented as top-level functions; declare as friends // for access to `wrap` and to enjoy ADL, allowing implicit conversion from // scalar_type in binary operation arguments. @@ -639,11 +661,19 @@ using simd_mask = typename simd<Value, N>::simd_mask; template <typename Simd> using where_expression = typename Simd::where_expression; +template <typename Simd> +using const_where_expression = typename Simd::const_where_expression; + template <typename Simd> where_expression<Simd> where(const typename Simd::simd_mask& m, Simd& v) { return where_expression<Simd>(m, v); } +template <typename Simd> +const_where_expression<Simd> where(const typename Simd::simd_mask& m, const Simd& v) { + return const_where_expression<Simd>(m, v); +} + template <typename> struct is_simd: std::false_type {}; diff --git a/modcc/functioninliner.cpp b/modcc/functioninliner.cpp index 8a9e63113f48eb1878d1fabbbbf98733b8511e6c..5692da136ea63009efac8563eda224fbb6118a3c 100644 --- a/modcc/functioninliner.cpp +++ b/modcc/functioninliner.cpp @@ -90,7 +90,6 @@ void FunctionInliner::visit(UnaryExpression* e) { if (!inlining_in_progress_) { return; } - auto sub = substitute(e->expression(), local_arg_map_); sub = substitute(sub, call_arg_map_); e->replace_expression(std::move(sub)); diff --git a/modcc/functioninliner.hpp b/modcc/functioninliner.hpp index 34e37a5c038dea86f8e92a7356bc3a4002f32810..ba4bada981e5ad43def73fe3d699391ec271b9db 100644 --- a/modcc/functioninliner.hpp +++ b/modcc/functioninliner.hpp @@ -10,7 +10,6 @@ expression_ptr inline_function_calls(std::string calling_func, BlockExpression* class FunctionInliner : public BlockRewriterBase { public: using BlockRewriterBase::visit; - FunctionInliner(std::string calling_func) : BlockRewriterBase(), calling_func_(calling_func) {}; FunctionInliner(scope_ptr s): BlockRewriterBase(s) {} diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 0c88fb766e85c8714189198c2ada86b4c50e885b..f6edfe9ad6adbae8116c5192101b2a545577c44e 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -7,6 +7,7 @@ #include "error.hpp" #include "lexer.hpp" #include "io/ostream_wrappers.hpp" +#include "astmanip.hpp" #include "io/prefixbuf.hpp" std::ostream& operator<<(std::ostream& out, as_c_double wrap) { @@ -165,3 +166,102 @@ void CExprEmitter::visit(IfExpression* e) { } } } + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::unordered_set<std::string> SimdExprEmitter::mask_names_; + +void SimdExprEmitter::visit(BlockExpression* block) { + for (auto& stmt: block->statements()) { + if (!stmt->is_local_declaration()) { + stmt->accept(this); + if (!stmt->is_if() && !stmt->is_block()) { + out_ << ";\n"; + } + } + } +} + +void SimdExprEmitter::visit(CallExpression* e) { + if(is_indirect_) + out_ << e->name() << "(index_"; + else + out_ << e->name() << "(i_"; + + if (processing_true_ && !current_mask_.empty()) { + out_ << ", " << current_mask_; + } else if (!processing_true_ && !current_mask_bar_.empty()) { + out_ << ", " << current_mask_bar_; + } + for (auto& arg: e->args()) { + out_ << ", "; + arg->accept(this); + } + out_ << ")"; +} + +void SimdExprEmitter::visit(AssignmentExpression* e) { + if (!e->lhs() || !e->lhs()->is_identifier() || !e->lhs()->is_identifier()->symbol()) { + throw compiler_exception("Expect symbol on lhs of assignment: "+e->to_string()); + } + + auto mask = processing_true_ ? current_mask_ : current_mask_bar_; + Symbol* lhs = e->lhs()->is_identifier()->symbol(); + + if (lhs->is_variable() && lhs->is_variable()->is_range()) { + if (!input_mask_.empty()) { + mask = mask + " && " + input_mask_; + } + out_ << "S::where(" << mask << ", " << "simd_value("; + e->rhs()->accept(this); + out_ << "))"; + if(is_indirect_) + out_ << ".copy_to(" << lhs->name() << "+index_)"; + else + out_ << ".copy_to(" << lhs->name() << "+i_)"; + } else { + out_ << "S::where(" << mask << ", "; + e->lhs()->accept(this); + out_ << ") = "; + e->rhs()->accept(this); + } +} + +void SimdExprEmitter::visit(IfExpression* e) { + + // Save old masks + auto old_mask = current_mask_; + auto old_mask_bar = current_mask_bar_; + auto old_branch = processing_true_; + + // Create new mask name + auto new_mask = make_unique_var(e->scope(), "mask_"); + + // Set new masks + out_ << "simd_value::simd_mask " << new_mask << " = "; + e->condition()->accept(this); + out_ << ";\n"; + + if (!current_mask_.empty()) { + auto base_mask = processing_true_ ? current_mask_ : current_mask_bar_; + current_mask_bar_ = base_mask + " && !" + new_mask; + current_mask_ = base_mask + " && " + new_mask; + + } else { + current_mask_bar_ = "!" + new_mask; + current_mask_ = new_mask; + } + + processing_true_ = true; + e->true_branch()->accept(this); + + processing_true_ = false; + if (auto fb = e->false_branch()) { + fb->accept(this); + } + + // Reset old masks + current_mask_ = old_mask; + current_mask_bar_ = old_mask_bar; + processing_true_ = old_branch; + +} diff --git a/modcc/printer/cexpr_emit.hpp b/modcc/printer/cexpr_emit.hpp index 099e3ed5f397eb35a0afb963e37ca37b1eae8395..fbbfd448d017dfde85b69397c967f7b323e500eb 100644 --- a/modcc/printer/cexpr_emit.hpp +++ b/modcc/printer/cexpr_emit.hpp @@ -1,6 +1,7 @@ #pragma once #include <iosfwd> +#include <unordered_set> #include "expression.hpp" #include "visitor.hpp" @@ -36,6 +37,40 @@ inline void cexpr_emit(Expression* e, std::ostream& out, Visitor* fallback) { e->accept(&emitter); } +class SimdExprEmitter: public CExprEmitter { + using CExprEmitter::visit; +public: + SimdExprEmitter(std::ostream& out, bool is_indirect, std::string input_mask, Visitor* fallback): + CExprEmitter(out, fallback), is_indirect_(is_indirect), input_mask_(input_mask) {} + + void visit(BlockExpression *e) override; + void visit(CallExpression *e) override; + void visit(AssignmentExpression *e) override; + void visit(IfExpression *e) override; + +protected: + static std::unordered_set<std::string> mask_names_; + bool processing_true_; + bool is_indirect_; + std::string current_mask_, current_mask_bar_, input_mask_; + +private: + std::string make_unique_var(scope_ptr scope, std::string prefix) { + for (int i = 0;; ++i) { + std::string name = prefix + std::to_string(i) + "_"; + if (!scope->find(name) && !mask_names_.count(name)) { + mask_names_.insert(name); + return name; + } + } + }; +}; + +inline void simd_expr_emit(Expression* e, std::ostream& out, bool is_indirect, std::string input_mask, Visitor* fallback) { + SimdExprEmitter emitter(out, is_indirect, input_mask, fallback); + e->accept(&emitter); +} + // Helper for formatting of double-valued numeric constants. struct as_c_double { double value; diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index c42a16e80c5fec5804c66c5c4d61a7d445ec5700..9c12821b704af9b18e3b2cba208acbf8b3b21ffb 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -26,6 +26,7 @@ constexpr bool with_profiling() { void emit_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); void emit_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); +void emit_masked_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); void emit_api_body(std::ostream&, APIMethod*); void emit_simd_api_body(std::ostream&, APIMethod*, moduleKind); @@ -56,21 +57,24 @@ struct cprint { struct simdprint { Expression* expr_; - bool is_indirect_index_; - simd_expr_constraint constraint_; + bool is_indirect_ = false; + bool is_masked_ = false; - explicit simdprint(Expression* expr): expr_(expr), is_indirect_index_(false), - constraint_(simd_expr_constraint::other) {} - explicit simdprint(Expression* expr, bool is_indexed, simd_expr_constraint constraint): - expr_(expr), is_indirect_index_(is_indexed), constraint_(constraint) {} + explicit simdprint(Expression* expr): expr_(expr) {} void set_indirect_index() { - is_indirect_index_ = true; + is_indirect_ = true; + } + void set_masked() { + is_masked_ = true; } friend std::ostream& operator<<(std::ostream& out, const simdprint& w) { SimdPrinter printer(out); - printer.set_var_indexed_to(w.is_indirect_index_); + if(w.is_masked_) { + printer.set_input_mask("mask_input_"); + } + printer.set_var_indexed(w.is_indirect_); return w.expr_->accept(&printer), out; } }; @@ -312,11 +316,14 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { } for (auto proc: normal_procedures(module_)) { - emit_procedure_proto(out, proc); - out << ";\n"; if (with_simd) { emit_simd_procedure_proto(out, proc); out << ";\n"; + emit_masked_simd_procedure_proto(out, proc); + out << ";\n"; + } else { + emit_procedure_proto(out, proc); + out << ";\n"; } } @@ -377,17 +384,20 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { // Mechanism procedures for (auto proc: normal_procedures(module_)) { - emit_procedure_proto(out, proc, class_name); - out << - " {\n" << indent << - cprint(proc->body()) << popindent << - "}\n\n"; - if (with_simd) { emit_simd_procedure_proto(out, proc, class_name); + auto simd_print = simdprint(proc->body()); + out << " {\n" << indent << simd_print << popindent << "}\n\n"; + + emit_masked_simd_procedure_proto(out, proc, class_name); + auto masked_print = simdprint(proc->body()); + masked_print.set_masked(); + out << " {\n" << indent << masked_print << popindent << "}\n\n"; + } else { + emit_procedure_proto(out, proc, class_name); out << " {\n" << indent << - simdprint(proc->body()) << popindent << + cprint(proc->body()) << popindent << "}\n\n"; } } @@ -541,7 +551,7 @@ void SimdPrinter::visit(LocalVariable* sym) { void SimdPrinter::visit(VariableExpression *sym) { if (sym->is_range()) { - if(is_indirect_index_) + if(is_indirect_) out_ << "simd_value(" << sym->name() << "+index_)"; else out_ << "simd_value(" << sym->name() << "+i_)"; @@ -559,9 +569,17 @@ void SimdPrinter::visit(AssignmentExpression* e) { Symbol* lhs = e->lhs()->is_identifier()->symbol(); if (lhs->is_variable() && lhs->is_variable()->is_range()) { - out_ << "simd_value("; + if (!input_mask_.empty()) + out_ << "S::where(" << input_mask_ << ", simd_value("; + else + out_ << "simd_value("; + e->rhs()->accept(this); - if(is_indirect_index_) + + if (!input_mask_.empty()) + out_ << ")"; + + if(is_indirect_) out_ << ").copy_to(" << lhs->name() << "+index_)"; else out_ << ").copy_to(" << lhs->name() << "+i_)"; @@ -573,7 +591,7 @@ void SimdPrinter::visit(AssignmentExpression* e) { } void SimdPrinter::visit(CallExpression* e) { - if(is_indirect_index_) + if(is_indirect_) out_ << e->name() << "(index_"; else out_ << e->name() << "(i_"; @@ -599,12 +617,11 @@ void SimdPrinter::visit(BlockExpression* block) { } for (auto& stmt: block->statements()) { - if (stmt->is_if()) { - throw compiler_exception("Conditionals not yet supported in SIMD printer: "+stmt->to_string()); - } if (!stmt->is_local_declaration()) { stmt->accept(this); - out_ << ";\n"; + if (!stmt->is_if() && !stmt->is_block()) { + out_ << ";\n"; + } } } } @@ -617,6 +634,15 @@ void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const out << ")"; } +void emit_masked_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { + out << "void " << qualified << (qualified.empty()? "": "::") << e->name() + << "(index_type i_, simd_value::simd_mask mask_input_"; + for (auto& arg: e->args()) { + out << ", const simd_value& " << arg->is_argument()->name(); + } + out << ")"; +} + void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_constraint constraint) { out << "simd_value " << local->name(); diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp index b7612ca98f5ef8669bed42ee0f6de7aed417fa44..5b12957f0a90f84fa421fa3008945ddfbfe8d2fc 100644 --- a/modcc/printer/cprinter.hpp +++ b/modcc/printer/cprinter.hpp @@ -51,8 +51,11 @@ public: void visit(Expression* e) override { throw compiler_exception("SimdPrinter cannot translate expression "+e->to_string()); } - void set_var_indexed_to(bool is_indirect_index) { - is_indirect_index_ = is_indirect_index; + void set_var_indexed(bool is_indirect_index) { + is_indirect_ = is_indirect_index; + } + void set_input_mask(std::string input_mask) { + input_mask_ = input_mask; } void visit(BlockExpression*) override; @@ -65,8 +68,10 @@ public: void visit(NumberExpression* e) override { cexpr_emit(e, out_, this); } void visit(UnaryExpression* e) override { cexpr_emit(e, out_, this); } void visit(BinaryExpression* e) override { cexpr_emit(e, out_, this); } + void visit(IfExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, this); } private: std::ostream& out_; - bool is_indirect_index_; + std::string input_mask_; + bool is_indirect_ = false; }; diff --git a/test/unit-modcc/mod_files/test7.mod b/test/unit-modcc/mod_files/test7.mod new file mode 100644 index 0000000000000000000000000000000000000000..bc88f63e5cb4e20591e709a9079ce20e51e9d986 --- /dev/null +++ b/test/unit-modcc/mod_files/test7.mod @@ -0,0 +1,55 @@ +UNITS { + (mV) = (millivolt) + (S) = (siemens) +} +NEURON { + SUFFIX pas + NONSPECIFIC_CURRENT i + RANGE g, e +} + +INITIAL {} + +PARAMETER { + g = .001 (S/cm2) + e = -65 (mV) : we use -65 for the ball and stick model, instead of Neuron default of -70 +} + +STATE { + s +} + +ASSIGNED { + v (mV) +} + +BREAKPOINT { + foo(i) +} + +PROCEDURE foo(i) { + if (g > 2) { + if (g > 3) { + i = 0 + } else { + i = 1 + } + } else { + if (g < 1) { + s = 2 + } else { + rates(i) + } + } +} + +PROCEDURE rates(i) { +LOCAL u + if(i > 2) { + u = 7 + } else { + u = 5 + s = 42 + } + s = u +} \ No newline at end of file diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index de50a5f3423cff7f374111e14b951b498d06118c..8e1635dcdc8e487666c628d4e219048f4dbf4ecb 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -262,4 +262,62 @@ TEST(CPrinter, proc_body_inlined) { proc_with_locals.erase(0, proc_with_locals.find(";") + 1); EXPECT_EQ(strip(expected), proc_with_locals); +} + +TEST(SimdPrinter, simd_if_else) { + std::vector<const char*> expected_procs = { + "simd_value u;\n" + "simd_value::simd_mask mask_0_ = i > 2;\n" + "S::where(mask_0_,u) = 7;\n" + "S::where(!mask_0_,u) = 5;\n" + "S::where(!mask_0_,simd_value(42)).copy_to(s+i_);\n" + "simd_value(u).copy_to(s+i_);" + , + "simd_value u;\n" + "simd_value::simd_mask mask_1_ = i > 2;\n" + "S::where(mask_1_,u) = 7;\n" + "S::where(!mask_1_,u) = 5;\n" + "S::where(!mask_1_ && mask_input_,simd_value(42)).copy_to(s+i_);\n" + "S::where(mask_input_, simd_value(u)).copy_to(s+i_);" + , + "simd_value::simd_mask mask_2_ = simd_value(g+i_)>2;\n" + "simd_value::simd_mask mask_3_ = simd_value(g+i_)>3;\n" + "S::where(mask_2_&&mask_3_,i) = 0.;\n" + "S::where(mask_2_&&!mask_3_,i) = 1;\n" + "simd_value::simd_mask mask_4_ = simd_value(g+i_)<1;\n" + "S::where(!mask_2_&& mask_4_,simd_value(2)).copy_to(s+i_);\n" + "rates(i_, !mask_2_&&!mask_4_, i);" + }; + + Module m(io::read_all(DATADIR "/mod_files/test7.mod"), "test7.mod"); + Parser p(m, false); + p.parse(); + m.semantic(); + + struct proc { + std::string name; + bool masked; + }; + + std::vector<proc> procs = {{"rates", false}, {"rates", true}, {"foo", false}}; + for (unsigned i = 0; i < procs.size(); i++) { + auto p = procs[i]; + std::stringstream out; + auto &proc = m.symbols().at(p.name); + ASSERT_TRUE(proc->is_symbol()); + + auto v = std::make_unique<SimdPrinter>(out); + if (p.masked) { + v->set_input_mask("mask_input_"); + } + proc->is_procedure()->body()->accept(v.get()); + std::string text = out.str(); + + verbose_print(proc->is_procedure()->body()->to_string()); + verbose_print(" :--: ", text); + + auto proc_with_locals = strip(text); + EXPECT_EQ(strip(expected_procs[i]), proc_with_locals); + + } } \ No newline at end of file diff --git a/test/unit/test_simd.cpp b/test/unit/test_simd.cpp index 47c701c755a589eec2af762c1ba45a77bb77d0ac..c28503be16c36cd8debe1c6523e1aeeb3075fb13 100644 --- a/test/unit/test_simd.cpp +++ b/test/unit/test_simd.cpp @@ -185,10 +185,11 @@ TYPED_TEST_P(simd_value, copy_to_from_masked) { std::minstd_rand rng(1031); for (unsigned i = 0; i<nrounds; ++i) { - scalar buf1[N], buf2[N], buf3[N]; + scalar buf1[N], buf2[N], buf3[N], buf4[N]; fill_random(buf1, rng); fill_random(buf2, rng); fill_random(buf3, rng); + fill_random(buf4, rng); bool mbuf1[N], mbuf2[N]; fill_random(mbuf1, rng); @@ -211,6 +212,13 @@ TYPED_TEST_P(simd_value, copy_to_from_masked) { where(m2, s).copy_to(buf3); EXPECT_TRUE(testing::indexed_eq_n(N, expected, buf3)); + + for (unsigned i = 0; i<N; ++i) { + expected[i] = mbuf2[i]? buf1[i]: buf4[i]; + } + + where(m2, simd(buf1)).copy_to(buf4); + EXPECT_TRUE(testing::indexed_eq_n(N, expected, buf4)); } } @@ -1028,6 +1036,14 @@ TYPED_TEST_P(simd_indirect, masked_scatter) { where(m, s).copy_to(indirect(array, simd_index(offset))); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); + + for (unsigned j = 0; j<buflen; ++j) { + array[j] = test[j]; + } + + where(m, simd(values)).copy_to(indirect(array, simd_index(offset))); + + EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); } }