diff --git a/arbor/backends/gpu/mechanism.cpp b/arbor/backends/gpu/mechanism.cpp index a0304fe7975cefc2f656a1c994f8658b86b62871..b3d6af949aed4452c0c0b8b9fedf08ef63351af7 100644 --- a/arbor/backends/gpu/mechanism.cpp +++ b/arbor/backends/gpu/mechanism.cpp @@ -80,7 +80,7 @@ void mechanism::instantiate(unsigned id, // Assign non-owning views onto shared state: - mechanism_ppack_base* pp = ppack_ptr(); // From derived class instance. + mechanism_ppack* pp = ppack_ptr(); // From derived class instance. pp->width_ = width_; pp->n_detectors_ = shared.n_detector; @@ -213,7 +213,7 @@ fvm_value_type* mechanism::field_data(const std::string& field_var) { void multiply_in_place(fvm_value_type* s, const fvm_index_type* p, int n); void mechanism::initialize() { - mechanism_ppack_base* pp = ppack_ptr(); + mechanism_ppack* pp = ppack_ptr(); pp->vec_t_ = vec_t_ptr_->data(); init(); diff --git a/arbor/backends/gpu/mechanism.cu b/arbor/backends/gpu/mechanism.cu index dc9ab1679b24999001d51b020303d8e3701b4344..f0befe09e3aba23d2a40f5cbad4fd1d01e736f09 100644 --- a/arbor/backends/gpu/mechanism.cu +++ b/arbor/backends/gpu/mechanism.cu @@ -1,10 +1,5 @@ -#include <iostream> -#include <backends/event.hpp> -#include <backends/multi_event_stream_state.hpp> +#include <arbor/fvm_types.hpp> #include <backends/gpu/gpu_common.hpp> -#include <backends/gpu/math_cu.hpp> -#include <backends/gpu/mechanism_ppack_base.hpp> -#include <backends/gpu/reduce_by_key.hpp> namespace arb { namespace gpu { diff --git a/arbor/backends/gpu/mechanism.hpp b/arbor/backends/gpu/mechanism.hpp index e63d222a28070862175a611237e15e3304c4659b..42daf236bcc5205b1ec63350b6a70f2c1eda945f 100644 --- a/arbor/backends/gpu/mechanism.hpp +++ b/arbor/backends/gpu/mechanism.hpp @@ -13,7 +13,6 @@ #include "backends/gpu/fvm.hpp" #include "backends/gpu/gpu_store_types.hpp" -#include "backends/gpu/mechanism_ppack_base.hpp" namespace arb { namespace gpu { @@ -22,74 +21,10 @@ namespace gpu { class mechanism: public arb::concrete_mechanism<arb::gpu::backend> { public: -protected: - using array = arb::gpu::array; - using iarray = arb::gpu::iarray; - -public: - std::size_t size() const override { - return width_; - } - - std::size_t memory() const override { - std::size_t s = object_sizeof(); - - s += sizeof(value_type) * data_.size(); - s += sizeof(index_type) * indices_.size(); - return s; - } - void instantiate(fvm_size_type id, backend::shared_state& shared, const mechanism_overrides&, const mechanism_layout&) override; - - void deliver_events() override { - // Delegate to derived class, passing in event queue state. - apply_events(event_stream_ptr_->marked_events()); - } - void update_current() override { - mechanism_ppack_base* pp = ppack_ptr(); - pp->vec_t_ = vec_t_ptr_->data(); - compute_currents(); - } - void update_state() override { - mechanism_ppack_base* pp = ppack_ptr(); - pp->vec_t_ = vec_t_ptr_->data(); - advance_state(); - } - void update_ions() override { - mechanism_ppack_base* pp = ppack_ptr(); - pp->vec_t_ = vec_t_ptr_->data(); - write_ions(); - } - + void initialize() override; void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; - - // Peek into mechanism state variable; implements arb::gpu::backend::mechanism_field_data. - // Returns pointer to GPU memory corresponding to state variable data. fvm_value_type* field_data(const std::string& state_var) override; - - void initialize() override; - -protected: - size_type width_ = 0; // Instance width (number of CVs/sites) - size_type num_ions_ = 0; - - // Returns pointer to (derived) parameter-pack object that holds: - // * pointers to shared cell state `vec_ci_` et al., - // * pointer to mechanism weights `weight_`, - // * pointer to mechanism node indices `node_index_`, - // * mechanism global scalars and pointers to mechanism range parameters. - // * mechanism ion_state_view objects and pointers to mechanism ion indices. - - virtual mechanism_ppack_base* ppack_ptr() = 0; - - deliverable_event_stream* event_stream_ptr_; - const array* vec_t_ptr_; - - // Bulk storage for index vectors and state and parameter variables. - - iarray indices_; - array data_; - bool mult_in_place_; }; } // namespace gpu diff --git a/arbor/backends/multicore/mechanism.cpp b/arbor/backends/multicore/mechanism.cpp index bb6de164c7344dba2e929bafbce8ab593aa6deed..0a5ee29badfbc04b4474222db1d058e45fa1e8bf 100644 --- a/arbor/backends/multicore/mechanism.cpp +++ b/arbor/backends/multicore/mechanism.cpp @@ -62,23 +62,25 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me width_ = pos_data.cv.size(); // Assign non-owning views onto shared state: + auto pp = (arb::multicore::mechanism_ppack*) ppack_ptr(); - vec_ci_ = shared.cv_to_cell.data(); - vec_di_ = shared.cv_to_intdom.data(); - vec_dt_ = shared.dt_cv.data(); + pp->width_ = width_; + pp->vec_ci_ = shared.cv_to_cell.data(); + pp->vec_di_ = shared.cv_to_intdom.data(); + pp->vec_dt_ = shared.dt_cv.data(); - vec_v_ = shared.voltage.data(); - vec_i_ = shared.current_density.data(); - vec_g_ = shared.conductivity.data(); + pp->vec_v_ = shared.voltage.data(); + pp->vec_i_ = shared.current_density.data(); + pp->vec_g_ = shared.conductivity.data(); - temperature_degC_ = shared.temperature_degC.data(); - diam_um_ = shared.diam_um.data(); - time_since_spike_ = shared.time_since_spike.data(); + pp->temperature_degC_ = shared.temperature_degC.data(); + pp->diam_um_ = shared.diam_um.data(); + pp->time_since_spike_ = shared.time_since_spike.data(); - n_detectors_ = shared.n_detector; + pp->n_detectors_ = shared.n_detector; auto ion_state_tbl = ion_state_table(); - n_ion_ = ion_state_tbl.size(); + num_ions_ = ion_state_tbl.size(); for (auto i: ion_state_tbl) { auto ion_binding = value_by_key(overrides.ion_rebind, i.first).value_or(i.first); @@ -117,12 +119,11 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me // Take reference to corresponding derived (generated) mechanism value pointer member. fvm_value_type*& field_ptr = *(fields[i].second); field_ptr = data_.data()+(i+1)*width_padded_; - if (auto opt_value = value_by_key(field_default_table(), fields[i].first)) { std::fill(field_ptr, field_ptr+width_padded_, *opt_value); } } - weight_ = data_.data(); + pp->weight_ = data_.data(); // Allocate and copy local state: weight, node indices, ion indices. // The tail comprises those elements between width_ and width_padded_: @@ -131,26 +132,27 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me // * For indices in the padded tail of node_index_, set index to last valid CV index. // * For indices in the padded tail of ion index maps, set index to last valid ion index. - copy_extend(pos_data.weight, make_range(data_.data(), data_.data()+width_padded_), 0); - - if (mult_in_place_) { - multiplicity_ = iarray(width_padded_, pad); - copy_extend(pos_data.multiplicity, multiplicity_, 1); - } + util::copy_extend(pos_data.weight, make_range(data_.data(), data_.data()+width_padded_), 0); + // Make index bulk storage { auto table = ion_index_table(); // Allocate bulk storage - auto count = (table.size() + 1)*width_padded_; - indices_ = iarray(count, 0, pad); + auto count = table.size() + 1 + (mult_in_place_ ? 1 : 0); + indices_ = iarray(count*width_padded_, 0, pad); auto base_ptr = indices_.data(); + auto append_chunk = [&](const auto& input, auto& output, const auto& pad) { + copy_extend(input, make_range(base_ptr, base_ptr + width_padded_), pad); + output = base_ptr; + base_ptr += width_padded_; + }; + // Setup node indices - node_index_ = base_ptr; - base_ptr += width_padded_; - auto node_index = make_range(node_index_, node_index_ + width_padded_); - copy_extend(pos_data.cv, node_index, pos_data.cv.back()); - index_constraints_ = make_constraint_partition(node_index, width_, simd_width()); + append_chunk(pos_data.cv, pp->node_index_, pos_data.cv.back()); + + auto node_index = make_range(pp->node_index_, pp->node_index_ + width_padded_); + pp->index_constraints_ = make_constraint_partition(node_index, width_, simd_width()); // Create ion indices for (const auto& [ion_name, ion_index_ptr]: table) { @@ -160,19 +162,18 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me if (!oion) { throw arbor_internal_error("multicore/mechanism: mechanism holds ion with no corresponding shared state"); } - - // Set the table entry, step offset to next location - *ion_index_ptr = base_ptr; - base_ptr += width_padded_; - // Obtain index and move data - auto indices = util::index_into(node_index, oion->node_index_); - auto ion_index = make_range(*ion_index_ptr, *ion_index_ptr + width_padded_); - copy_extend(indices, ion_index, util::back(indices)); + auto indices = util::index_into(node_index, oion->node_index_); + append_chunk(indices, *ion_index_ptr, util::back(indices)); // Check SIMD constraints + auto ion_index = make_range(*ion_index_ptr, *ion_index_ptr + width_padded_); arb_assert(compatible_index_constraints(node_index, ion_index, simd_width())); } + + if (mult_in_place_) { + append_chunk(pos_data.multiplicity, pp->multiplicity_, 0); + } } } @@ -196,7 +197,8 @@ void mechanism::set_parameter(const std::string& key, const std::vector<fvm_valu } void mechanism::initialize() { - vec_t_ = vec_t_ptr_->data(); + auto pp_ptr = ppack_ptr(); + pp_ptr->vec_t_ = vec_t_ptr_->data(); init(); auto states = state_table(); @@ -204,7 +206,7 @@ void mechanism::initialize() { if (mult_in_place_) { for (auto& state: states) { for (std::size_t j = 0; j < width_; ++j) { - (*state.second)[j] *= multiplicity_[j]; + (*state.second)[j] *= pp_ptr->multiplicity_[j]; } } } diff --git a/arbor/backends/multicore/mechanism.hpp b/arbor/backends/multicore/mechanism.hpp index 465ab876f685fd3b7762560d80b5ca2cd76079fa..3a0803691bcf9ca0e54e48270d05d4d33da68242 100644 --- a/arbor/backends/multicore/mechanism.hpp +++ b/arbor/backends/multicore/mechanism.hpp @@ -10,6 +10,7 @@ #include <arbor/common_types.hpp> #include <arbor/fvm_types.hpp> #include <arbor/mechanism.hpp> +#include <arbor/mechanism_ppack.hpp> #include "backends/multicore/fvm.hpp" #include "backends/multicore/multicore_common.hpp" @@ -18,86 +19,22 @@ namespace arb { namespace multicore { -// Base class for all generated mechanisms for multicore back-end. +// Parameter pack extended for multicore. +struct mechanism_ppack: arb::mechanism_ppack { + constraint_partition index_constraints_; // Per-mechanism index and weight data, excepting ion indices. +}; +// Base class for all generated mechanisms for multicore back-end. class mechanism: public arb::concrete_mechanism<arb::multicore::backend> { -protected: - using array = arb::multicore::array; - using iarray = arb::multicore::iarray; - public: - std::size_t size() const override { - return width_; - } - - std::size_t memory() const override { - std::size_t s = object_sizeof(); - s += sizeof(data_[0]) * data_.size(); - s += sizeof(indices_[0]) * indices_.size(); - return s; - } - void instantiate(fvm_size_type id, backend::shared_state& shared, const mechanism_overrides&, const mechanism_layout&) override; void initialize() override; - - void deliver_events() override { - // Delegate to derived class, passing in event queue state. - apply_events(event_stream_ptr_->marked_events()); - } - void update_current() override { - vec_t_ = vec_t_ptr_->data(); - compute_currents(); - } - void update_state() override { - vec_t_ = vec_t_ptr_->data(); - advance_state(); - } - void update_ions() override { - vec_t_ = vec_t_ptr_->data(); - write_ions(); - } - void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; - - // Peek into mechanism state variable; implements arb::multicore::backend::mechanism_field_data. fvm_value_type* field_data(const std::string& state_var) override; protected: - fvm_size_type width_ = 0; // Instance width (number of CVs/sites) - fvm_size_type width_padded_ = 0; // Width rounded up to multiple of pad/alignment. - fvm_size_type n_ion_ = 0; - fvm_size_type n_detectors_ = 0; - - // Non-owning views onto shared cell state, excepting ion state. - - const fvm_index_type* vec_ci_; // CV to cell index - const fvm_index_type* vec_di_; // CV to indom index - const fvm_value_type* vec_t_; // Cell index to cell-local time. - const fvm_value_type* vec_dt_; // CV to integration time step. - const fvm_value_type* vec_v_; // CV to cell membrane voltage. - fvm_value_type* vec_i_; // CV to cell membrane current density. - fvm_value_type* vec_g_; // CV to cell membrane conductivity. - const fvm_value_type* temperature_degC_; // CV to temperature. - const fvm_value_type* diam_um_; // CV to diameter. - const fvm_value_type* time_since_spike_; // Vector containing time since last spike, indexed by cell index and n_detectors_ - - const array* vec_t_ptr_; - deliverable_event_stream* event_stream_ptr_; - - // Per-mechanism index and weight data, excepting ion indices. - - fvm_index_type* node_index_; - iarray multiplicity_; - bool mult_in_place_; - constraint_partition index_constraints_; - const fvm_value_type* weight_; // Points within data_ after instantiation. - - // Bulk storage for state and parameter variables. - - array data_; - iarray indices_; - virtual unsigned simd_width() const { return 1; } + fvm_size_type width_padded_ = 0; // Width rounded up to multiple of pad/alignment. }; } // namespace multicore diff --git a/arbor/include/arbor/mechanism.hpp b/arbor/include/arbor/mechanism.hpp index 1c18457e905faff259746b8f5c994105a576ad6f..92209c11e0a9f4112149a568af4b887266708a14 100644 --- a/arbor/include/arbor/mechanism.hpp +++ b/arbor/include/arbor/mechanism.hpp @@ -6,6 +6,7 @@ #include <arbor/fvm_types.hpp> #include <arbor/mechinfo.hpp> +#include <arbor/mechanism_ppack.hpp> namespace arb { @@ -118,9 +119,27 @@ public: // Instantiation: allocate per-instance state; set views/pointers to shared data. virtual void instantiate(unsigned id, typename backend::shared_state&, const mechanism_overrides&, const mechanism_layout&) = 0; + std::size_t size() const override { return width_; } + + std::size_t memory() const override { + std::size_t s = object_sizeof(); + s += sizeof(data_[0]) * data_.size(); + s += sizeof(indices_[0]) * indices_.size(); + return s; + } + + // Delegate to derived class. + virtual void deliver_events() override { apply_events(event_stream_ptr_->marked_events()); } + virtual void update_current() override { set_time_ptr(); compute_currents(); } + virtual void update_state() override { set_time_ptr(); advance_state(); } + virtual void update_ions() override { set_time_ptr(); write_ions(); } + protected: using deliverable_event_stream = typename backend::deliverable_event_stream; using iarray = typename backend::iarray; + using array = typename backend::array; + + void set_time_ptr() { ppack_ptr()->vec_t_ = vec_t_ptr_->data(); } // Generated mechanism field, global and ion table lookup types. // First component is name, second is pointer to corresponing member in @@ -148,13 +167,22 @@ protected: // Member tables: introspection into derived mechanism fields, views etc. // Default implementations correspond to no corresponding fields/globals/ions. - virtual mechanism_field_table field_table() { return {}; } + virtual mechanism_field_table field_table() { return {}; } virtual mechanism_field_default_table field_default_table() { return {}; } - virtual mechanism_global_table global_table() { return {}; } - virtual mechanism_state_table state_table() { return {}; } - virtual mechanism_ion_state_table ion_state_table() { return {}; } - virtual mechanism_ion_index_table ion_index_table() { return {}; } - + virtual mechanism_global_table global_table() { return {}; } + virtual mechanism_state_table state_table() { return {}; } + virtual mechanism_ion_state_table ion_state_table() { return {}; } + virtual mechanism_ion_index_table ion_index_table() { return {}; } + + // Returns pointer to (derived) parameter-pack object that holds: + // * pointers to shared cell state `vec_ci_` et al., + // * pointer to mechanism weights `weight_`, + // * pointer to mechanism node indices `node_index_`, + // * mechanism global scalars and pointers to mechanism range parameters. + // * mechanism ion_state_view objects and pointers to mechanism ion indices. + virtual mechanism_ppack* ppack_ptr() = 0; + + // to be overridden in mechanism implemetations virtual void advance_state() {}; virtual void compute_currents() {}; virtual void apply_events(typename deliverable_event_stream::state) {}; @@ -162,6 +190,20 @@ protected: virtual void init() {}; // Report raw size in bytes of mechanism object. virtual std::size_t object_sizeof() const = 0; + + // events to be processed + + // indirection for accessing time in mechanisms + const array* vec_t_ptr_; + + deliverable_event_stream* event_stream_ptr_; + size_type width_ = 0; // Instance width (number of CVs/sites) + size_type num_ions_ = 0; // Ion count + bool mult_in_place_; // perform multipliction in place? + + // Bulk storage for index vectors and state and parameter variables. + iarray indices_; + array data_; }; } // namespace arb diff --git a/arbor/backends/gpu/mechanism_ppack_base.hpp b/arbor/include/arbor/mechanism_ppack.hpp similarity index 72% rename from arbor/backends/gpu/mechanism_ppack_base.hpp rename to arbor/include/arbor/mechanism_ppack.hpp index 6f2456ca5b112207599f4bfe84e915157329c80b..bb510a5c085894fd2a1a467dc09317d0bc02b2a6 100644 --- a/arbor/backends/gpu/mechanism_ppack_base.hpp +++ b/arbor/include/arbor/mechanism_ppack.hpp @@ -1,19 +1,11 @@ #pragma once -// Base class for parameter packs for GPU generated kernels: -// will be included by .cu generated sources. - -#include <arbor/mechanism.hpp> #include <arbor/fvm_types.hpp> namespace arb { -namespace gpu { - -// Parameter pack base: -struct mechanism_ppack_base { +struct mechanism_ppack { fvm_index_type width_; fvm_index_type n_detectors_; - const fvm_index_type* vec_ci_; const fvm_index_type* vec_di_; const fvm_value_type* vec_t_; @@ -24,12 +16,8 @@ struct mechanism_ppack_base { const fvm_value_type* temperature_degC_; const fvm_value_type* diam_um_; const fvm_value_type* time_since_spike_; - const fvm_index_type* node_index_; const fvm_index_type* multiplicity_; - const fvm_value_type* weight_; }; - -} // namespace gpu } // namespace arb diff --git a/doc/fileformat/nmodl.rst b/doc/fileformat/nmodl.rst index bce1d3ff2c9e7bcdc04449cb9e6f518b0e7abddf..a6fd87019a23c66237edb05ec108d41c2db05276 100644 --- a/doc/fileformat/nmodl.rst +++ b/doc/fileformat/nmodl.rst @@ -51,8 +51,10 @@ Special variables * Arbor exposes some parameters from the simulation to the NMODL mechanisms. These include ``v``, ``diam``, ``celsius`` and ``t`` in addition to the previously mentioned ion parameters. -* Special variables should not be ``ASSIGNED`` or ``CONSTANT``, - they are ``PARAMETER``. +* The ``area`` is not currently exposed to NMODL. +* Special variables should not be ``ASSIGNED`` or ``CONSTANT``, they are + ``PARAMETER``. This is different from NEURON where a built-in variable is + declared ``ASSIGNED`` to make it accessible. * ``diam`` and ``celsius`` can be set from the simulation side. * ``v`` is a reserved variable name and can be written in NMODL. * If Special variables are used in a ``PROCEDURE`` or ``FUNCTION``, they need @@ -77,7 +79,7 @@ Unsupported features However, ``CONSERVE`` statements are supported. * ``TABLE`` is not supported, calculations are exact. * ``derivimplicit`` solving method is not supported, use ``cnexp`` instead. -* `verbatim` blocks are not supported. +* ``VERBATIM`` blocks are not supported. Arbor-specific features ----------------------- diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt index 775746875bfade226d52ca5d723e336b47a0209c..67ee89c80f7d68d52114ab1d3f03efb1aa462b34 100644 --- a/modcc/CMakeLists.txt +++ b/modcc/CMakeLists.txt @@ -21,6 +21,7 @@ set(libmodcc_sources io/prefixbuf.cpp printer/cexpr_emit.cpp printer/cprinter.cpp + printer/marks.cpp printer/gpuprinter.cpp printer/infoprinter.cpp printer/printerutil.cpp diff --git a/modcc/modcc.cpp b/modcc/modcc.cpp index d1ec2c76a44abbdbfaf29e0ee4cef95a8e8ea27e..40c522a73279e7b1c3ff6a8ca8834d6597289330 100644 --- a/modcc/modcc.cpp +++ b/modcc/modcc.cpp @@ -141,6 +141,7 @@ const char* usage_str = "-P|--profile [Build with profiled kernels]\n" "-V|--verbose [Toggle verbose mode]\n" "-A|--analyse [Toggle analysis mode]\n" + "-T|--trace-codegen [Leave trace marks in generated source]\n" "<filename> [File to be compiled]\n"; int main(int argc, char **argv) { @@ -167,16 +168,17 @@ int main(int argc, char **argv) { to::option options[] = { { opt.modfile, to::mandatory}, - { opt.outprefix, "-o", "--output" }, - { to::set(opt.verbose), to::flag, "-V", "--verbose" }, - { to::set(opt.analysis), to::flag, "-A", "--analyse" }, - { opt.modulename, "-m", "--module" }, - { to::set(popt.profile), to::flag, "-P", "--profile" }, - { popt.cpp_namespace, "-N", "--namespace" }, - { to::action(enable_simd), to::flag, "-s", "--simd" }, - { popt.simd, "-S", "--simd-abi" }, + { opt.outprefix, "-o", "--output" }, + { to::set(opt.verbose), to::flag, "-V", "--verbose" }, + { to::set(opt.analysis), to::flag, "-A", "--analyse" }, + { opt.modulename, "-m", "--module" }, + { to::set(popt.profile), to::flag, "-P", "--profile" }, + { popt.cpp_namespace, "-N", "--namespace" }, + { to::action(enable_simd), to::flag, "-s", "--simd" }, + { popt.simd, "-S", "--simd-abi" }, + { to::set(popt.trace_codegen), to::flag, "-T", "--trace-codegen"}, { {to::action(add_target, to::keywords(targetKindMap))}, "-t", "--target" }, - { to::action(help), to::flag, to::exit, "-h", "--help" } + { to::action(help), to::flag, to::exit, "-h", "--help" } }; if (!to::run(options, argc, argv+1)) return 0; diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 30e06ef3a26a4e0abb85469910f2284e11815799..fdc02657ae3c8270a67561ae93ebd407ea7c3f62 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -221,6 +221,20 @@ void SimdExprEmitter::visit(UnaryExpression* e) { } } +std::string id_prefix(IdentifierExpression* id) { + if (id) { + if (auto symbol = id->symbol()->is_symbol()) { + if (auto var = symbol->is_variable()) { + if (!var->is_local_variable()) { + return "pp->"+id->name(); + } + } + } + } + return id->name(); +} + + void SimdExprEmitter::visit(BinaryExpression* e) { static std::unordered_map<tok, const char *> func_tbl = { {tok::minus, "S::sub"}, @@ -262,7 +276,7 @@ void SimdExprEmitter::visit(BinaryExpression* e) { "CExprEmitter: unsupported binary operator " + token_string(e->op()), e->location()); } - std::string rhs_name, lhs_name; + std::string rhs_name, lhs_name, rhs_pfxd, lhs_pfxd; auto rhs = e->rhs(); auto lhs = e->lhs(); @@ -270,11 +284,13 @@ void SimdExprEmitter::visit(BinaryExpression* e) { const char *op_spelling = binop_tbl.at(e->op()); const char *func_spelling = func_tbl.at(e->op()); - if (rhs->is_identifier()) { - rhs_name = rhs->is_identifier()->name(); + if (auto id = rhs->is_identifier()) { + rhs_name = id->name(); + rhs_pfxd = id_prefix(id); } - if (lhs->is_identifier()) { - lhs_name = lhs->is_identifier()->name(); + if (auto id = lhs->is_identifier()) { + lhs_name = id->name(); + lhs_pfxd = id_prefix(id); } if (scalars_.count(rhs_name) && scalars_.count(lhs_name)) { @@ -313,10 +329,10 @@ void SimdExprEmitter::visit(BinaryExpression* e) { } else if (scalars_.count(rhs_name) && !scalars_.count(lhs_name)) { out_ << func_spelling << '('; lhs->accept(this); - out_ << ", simd_cast<simd_value>(" << rhs_name ; + out_ << ", simd_cast<simd_value>(" << rhs_pfxd; out_ << "))"; } else if (!scalars_.count(rhs_name) && scalars_.count(lhs_name)) { - out_ << func_spelling << "(simd_cast<simd_value>(" << lhs_name << "), "; + out_ << func_spelling << "(simd_cast<simd_value>(" << lhs_pfxd << "), "; rhs->accept(this); out_ << ")"; } else { @@ -365,14 +381,17 @@ void SimdExprEmitter::visit(AssignmentExpression* e) { auto mask = processing_true_ ? current_mask_ : current_mask_bar_; Symbol* lhs = e->lhs()->is_identifier()->symbol(); + auto lhs_pfxd = id_prefix(e->lhs()->is_identifier()); + + if (lhs->is_variable() && lhs->is_variable()->is_range()) { if (!input_mask_.empty()) { mask = "S::logical_and(" + mask + ", " + input_mask_ + ")"; } if(is_indirect_) - out_ << "indirect(" << lhs->name() << "+index_, simd_width_) = "; + out_ << "indirect(" << lhs_pfxd << "+index_, simd_width_) = "; else - out_ << "indirect(" << lhs->name() << "+i_, simd_width_) = "; + out_ << "indirect(" << lhs_pfxd << "+i_, simd_width_) = "; out_ << "S::where(" << mask << ", "; diff --git a/modcc/printer/cexpr_emit.hpp b/modcc/printer/cexpr_emit.hpp index 9286a5c94a97ea0cebb846c36f8539a86e0cfece..f17cbf56ad8b443cce065951eaf1557eec85681a 100644 --- a/modcc/printer/cexpr_emit.hpp +++ b/modcc/printer/cexpr_emit.hpp @@ -5,6 +5,7 @@ #include "expression.hpp" #include "visitor.hpp" +#include "marks.hpp" // Common functionality for generating source from binary expressions // and conditional structures with C syntax. diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 332cd70cf156d8d1c3812b81a9f2d86ceffa55f9..27ec365d426dd62b24f9b3dd473dbaa6dbf996b2 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -11,6 +11,7 @@ #include "printer/cprinter.hpp" #include "printer/printeropt.hpp" #include "printer/printerutil.hpp" +#include "printer/marks.hpp" using io::indent; using io::popindent; @@ -24,6 +25,10 @@ constexpr bool with_profiling() { #endif } +inline static std::string make_cpu_class_name(const std::string& module_name) { return std::string{"mechanism_cpu_"} + module_name; } + +inline static std::string make_cpu_ppack_name(const std::string& module_name) { return make_cpu_class_name(module_name) + std::string{"_pp_"}; } + struct index_prop { std::string source_var; // array holding the indices std::string index_name; // index into the array @@ -33,9 +38,9 @@ struct index_prop { } }; -void emit_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); -void emit_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); -void emit_masked_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); +void emit_procedure_proto(std::ostream&, ProcedureExpression*, const std::string&, const std::string& qualified = ""); +void emit_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string&, const std::string& qualified = ""); +void emit_masked_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string&, const std::string& qualified = ""); void emit_api_body(std::ostream&, APIMethod*); void emit_simd_api_body(std::ostream&, APIMethod*, const std::vector<VariableExpression*>& scalars); @@ -103,19 +108,23 @@ static std::string ion_state_index(std::string ion_name) { } std::string emit_cpp_source(const Module& module_, const printer_options& opt) { - std::string name = module_.module_name(); - std::string class_name = "mechanism_cpu_"+name; - auto ns_components = namespace_components(opt.cpp_namespace); + auto name = module_.module_name(); + auto class_name = make_cpu_class_name(name); + auto namespace_name = "kernel_" + class_name; + auto ppack_name = make_cpu_ppack_name(name); + auto ns_components = namespace_components(opt.cpp_namespace); NetReceiveExpression* net_receive = find_net_receive(module_); PostEventExpression* post_event = find_post_event(module_); - APIMethod* init_api = find_api_method(module_, "init"); - APIMethod* state_api = find_api_method(module_, "advance_state"); - APIMethod* current_api = find_api_method(module_, "compute_currents"); + APIMethod* init_api = find_api_method(module_, "init"); + APIMethod* state_api = find_api_method(module_, "advance_state"); + APIMethod* current_api = find_api_method(module_, "compute_currents"); APIMethod* write_ions_api = find_api_method(module_, "write_ions"); bool with_simd = opt.simd.abi!=simd_spec::none; + options_trace_codegen = opt.trace_codegen; + // init_api, state_api, current_api methods are mandatory: assert_has_scope(init_api, "init"); @@ -150,6 +159,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { io::pfxstringstream out; + ENTER(out); out << "#include <algorithm>\n" "#include <cmath>\n" @@ -232,6 +242,126 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "\n"; } + out << "struct " << ppack_name << ": public ::arb::multicore::mechanism_ppack {\n" << indent; + for (const auto& scalar: vars.scalars) { + out << "::arb::fvm_value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; + } + for (const auto& array: vars.arrays) { + out << "::arb::fvm_value_type* " << array->name() << ";\n"; + } + for (const auto& dep: ion_deps) { + out << "::arb::ion_state_view " << ion_state_field(dep.name) << ";\n"; + out << "::arb::fvm_index_type* " << ion_state_index(dep.name) << ";\n"; + } + out << popindent << "};\n\n"; + + // Make implementations + auto emit_body = [&](APIMethod *p) { + if (with_simd) { + emit_simd_api_body(out, p, vars.scalars); + } else { + emit_api_body(out, p); + } + }; + + out << "namespace " << namespace_name << " {\n"; + + out << "// procedure prototypes\n"; + for (auto proc: normal_procedures(module_)) { + if (with_simd) { + emit_simd_procedure_proto(out, proc, ppack_name); + out << ";\n"; + emit_masked_simd_procedure_proto(out, proc, ppack_name); + out << ";\n"; + } else { + emit_procedure_proto(out, proc, ppack_name); + out << ";\n"; + } + } + out << "\n"; + + out << "// interface methods\n"; + out << "void init(" << ppack_name << "* pp) {\n" << indent; + emit_body(init_api); + out << popindent << "}\n\n"; + + out << "void advance_state(" << ppack_name << "* pp) {\n" << indent; + out << profiler_enter("advance_integrate_state"); + emit_body(state_api); + out << profiler_leave(); + out << popindent << "}\n\n"; + + out << "void compute_currents(" << ppack_name << "* pp) {\n" << indent; + out << profiler_enter("advance_integrate_current"); + emit_body(current_api); + out << profiler_leave(); + out << popindent << "}\n\n"; + + out << "void write_ions(" << ppack_name << "* pp) {\n" << indent; + emit_body(write_ions_api); + out << popindent << "}\n\n"; + + if (net_receive) { + const std::string weight_arg = net_receive->args().empty() ? "weight" : net_receive->args().front()->is_argument()->name(); + out << + "void net_receive(" << ppack_name << "* pp, int i_, ::arb::fvm_value_type " << weight_arg << ") {\n" << indent << + cprint(net_receive->body()) << popindent << + "}\n\n" + "void apply_events(" << ppack_name << "* pp, ::arb::fvm_size_type mechanism_id, ::arb::multicore::deliverable_event_stream::state events) {\n" << indent << + "auto ncell = events.n_streams();\n" + "for (::arb::fvm_size_type c = 0; c<ncell; ++c) {\n" << indent << + "auto begin = events.begin_marked(c);\n" + "auto end = events.end_marked(c);\n" + "for (auto p = begin; p<end; ++p) {\n" << indent << + "if (p->mech_id==mechanism_id) " << namespace_name << "::net_receive(pp, p->mech_index, p->weight);\n" << popindent << + "}\n" << popindent << + "}\n" << popindent << + "}\n" + "\n"; + } + + if(post_event) { + const std::string time_arg = post_event->args().empty() ? "time" : post_event->args().front()->is_argument()->name(); + out << + "void post_event(" << ppack_name << "* pp) {\n" << indent << + "int n_ = pp->width_;\n" + "for (int i_ = 0; i_ < n_; ++i_) {\n" << indent << + "auto node_index_i_ = pp->node_index_[i_];\n" + "auto cid_ = pp->vec_ci_[node_index_i_];\n" + "auto offset_ = pp->n_detectors_ * cid_;\n" + "for (::arb::fvm_index_type c = 0; c < pp->n_detectors_; c++) {\n" << indent << + "auto " << time_arg << " = pp->time_since_spike_[offset_ + c];\n" + "if (" << time_arg << " >= 0) {\n" << indent << + cprint(post_event->body()) << popindent << + "}\n" << popindent << + "}\n" << popindent << + "}\n" << popindent << + "}\n\n"; + } + + + out << "// Procedure definitions\n"; + for (auto proc: normal_procedures(module_)) { + if (with_simd) { + emit_simd_procedure_proto(out, proc, ppack_name); + auto simd_print = simdprint(proc->body(), vars.scalars); + out << " {\n" << indent << simd_print << popindent << "}\n\n"; + + emit_masked_simd_procedure_proto(out, proc, ppack_name); + auto masked_print = simdprint(proc->body(), vars.scalars); + masked_print.set_masked(); + out << " {\n" << indent << masked_print << popindent << "}\n\n"; + } else { + emit_procedure_proto(out, proc, ppack_name); + out << + " {\n" << indent << + cprint(proc->body()) << popindent << + "}\n\n"; + } + } + + out << popindent << "}\n\n"; // close kernel namespace + out << "class " << class_name << ": public base {\n" "public:\n" << indent << @@ -243,24 +373,25 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "::arb::mechanismKind kind() const override { return " << module_kind_str(module_) << "; }\n" "::arb::mechanism_ptr clone() const override { return ::arb::mechanism_ptr(new " << class_name << "()); }\n" "\n" - "void init() override;\n" - "void advance_state() override;\n" - "void compute_currents() override;\n" - "void write_ions() override;\n"; + "void init() override { " << namespace_name << "::init(&pp_); }\n" + "void advance_state() override { " << namespace_name << "::advance_state(&pp_); }\n" + "void compute_currents() override { " << namespace_name << "::compute_currents(&pp_); }\n" + "void write_ions() override{ " << namespace_name << "::write_ions(&pp_); }\n"; - net_receive && out << - "void apply_events(deliverable_event_stream::state events) override;\n" - "void net_receive(int i_, ::arb::fvm_value_type weight);\n"; + net_receive && + out << "void apply_events(deliverable_event_stream::state events) override { " << namespace_name << "::apply_events(&pp_, mechanism_id_, events); }\n"; - post_event && out << - "void post_event() override;\n"; + post_event && + out << "void post_event() override { " << namespace_name << "::post_event(&pp_); };\n"; - with_simd && out << "unsigned simd_width() const override { return simd_width_; }\n"; + with_simd && + out << "unsigned simd_width() const override { return simd_width_; }\n"; out << "\n" << popindent << "protected:\n" << indent << - "std::size_t object_sizeof() const override { return sizeof(*this); }\n"; + "std::size_t object_sizeof() const override { return sizeof(*this); }\n" << + "virtual ::arb::mechanism_ppack* ppack_ptr() override { return &pp_; }\n"; io::separator sep("\n", ",\n"); if (!vars.scalars.empty()) { @@ -270,7 +401,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { for (const auto& scalar: vars.scalars) { auto memb = scalar->name(); - out << sep << "{" << quote(memb) << ", &" << memb << "}"; + out << sep << "{" << quote(memb) << ", &pp_." << memb << "}"; } out << popindent << "\n};\n" << popindent << "}\n"; } @@ -283,7 +414,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { sep.reset(); for (const auto& array: vars.arrays) { auto memb = array->name(); - out << sep << "{" << quote(memb) << ", &" << memb << "}"; + out << sep << "{" << quote(memb) << ", &pp_." << memb << "}"; } out << popindent << "\n};" << popindent << "\n}\n"; @@ -309,7 +440,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { for (const auto& array: vars.arrays) { auto memb = array->name(); if(array->is_state()) { - out << sep << "{" << quote(memb) << ", &" << memb << "}"; + out << sep << "{" << quote(memb) << ", &pp_." << memb << "}"; } } out << popindent << "\n};" << popindent << "\n}\n"; @@ -323,43 +454,21 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { sep.reset(); for (const auto& dep: ion_deps) { - out << sep << "{\"" << dep.name << "\", &" << ion_state_field(dep.name) << "}"; + out << sep << "{\"" << dep.name << "\", &pp_." << ion_state_field(dep.name) << "}"; } out << popindent << "\n};" << popindent << "\n}\n"; sep.reset(); out << "mechanism_ion_index_table ion_index_table() override {\n" << indent << "return {" << indent; for (const auto& dep: ion_deps) { - out << sep << "{\"" << dep.name << "\", &" << ion_state_index(dep.name) << "}"; + out << sep << "{\"" << dep.name << "\", &pp_." << ion_state_index(dep.name) << "}"; } out << popindent << "\n};" << popindent << "\n}\n"; } out << popindent << "\n" "private:\n" << indent; - - for (const auto& scalar: vars.scalars) { - out << "::arb::fvm_value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; - } - for (const auto& array: vars.arrays) { - out << "::arb::fvm_value_type* " << array->name() << ";\n"; - } - for (const auto& dep: ion_deps) { - out << "::arb::ion_state_view " << ion_state_field(dep.name) << ";\n"; - out << "::arb::fvm_index_type* " << ion_state_index(dep.name) << ";\n"; - } - - for (auto proc: normal_procedures(module_)) { - if (with_simd) { - emit_simd_procedure_proto(out, proc); - out << ";\n"; - emit_masked_simd_procedure_proto(out, proc); - out << ";\n"; - } else { - emit_procedure_proto(out, proc); - out << ";\n"; - } - } + out << ppack_name << " pp_;\n"; out << popindent << "};\n\n" @@ -368,97 +477,8 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "return ::arb::concrete_mech_ptr<backend>(new " << class_name << "());\n" << popindent << "}\n\n"; - // Interface methods: - - if (net_receive) { - const std::string weight_arg = net_receive->args().empty() ? "weight" : net_receive->args().front()->is_argument()->name(); - out << - "void " << class_name << "::apply_events(deliverable_event_stream::state events) {\n" << indent << - "auto ncell = events.n_streams();\n" - "for (size_type c = 0; c<ncell; ++c) {\n" << indent << - "auto begin = events.begin_marked(c);\n" - "auto end = events.end_marked(c);\n" - "for (auto p = begin; p<end; ++p) {\n" << indent << - "if (p->mech_id==mechanism_id_) net_receive(p->mech_index, p->weight);\n" << popindent << - "}\n" << popindent << - "}\n" << popindent << - "}\n" - "\n" - "void " << class_name << "::net_receive(int i_, ::arb::fvm_value_type " << weight_arg << ") {\n" << indent << - cprint(net_receive->body()) << popindent << - "}\n\n"; - } - - if(post_event) { - const std::string time_arg = post_event->args().empty() ? "time" : post_event->args().front()->is_argument()->name(); - out << - "void " << class_name << "::post_event() {\n" << indent << - "int n_ = width_;\n" - "for (int i_ = 0; i_ < n_; ++i_) {\n" << indent << - "auto node_index_i_ = node_index_[i_];\n" - "auto cid_ = vec_ci_[node_index_i_];\n" - "auto offset_ = n_detectors_ * cid_;\n" - "for (unsigned c = 0; c < n_detectors_; c++) {\n" << indent << - "auto " << time_arg << " = time_since_spike_[offset_ + c];\n" - "if (" << time_arg << " >= 0) {\n" << indent << - cprint(post_event->body()) << popindent << - "}\n" << popindent << - "}\n" << popindent << - "}\n" << popindent << - "}\n\n"; - } - - auto emit_body = [&](APIMethod *p) { - if (with_simd) { - emit_simd_api_body(out, p, vars.scalars); - } - else { - emit_api_body(out, p); - } - }; - - out << "void " << class_name << "::init() {\n" << indent; - emit_body(init_api); - out << popindent << "}\n\n"; - - out << "void " << class_name << "::advance_state() {\n" << indent; - out << profiler_enter("advance_integrate_state"); - emit_body(state_api); - out << profiler_leave(); - out << popindent << "}\n\n"; - - out << "void " << class_name << "::compute_currents() {\n" << indent; - out << profiler_enter("advance_integrate_current"); - emit_body(current_api); - out << profiler_leave(); - out << popindent << "}\n\n"; - - out << "void " << class_name << "::write_ions() {\n" << indent; - emit_body(write_ions_api); - out << popindent << "}\n\n"; - - // Mechanism procedures - - for (auto proc: normal_procedures(module_)) { - if (with_simd) { - emit_simd_procedure_proto(out, proc, class_name); - auto simd_print = simdprint(proc->body(), vars.scalars); - out << " {\n" << indent << simd_print << popindent << "}\n\n"; - - emit_masked_simd_procedure_proto(out, proc, class_name); - auto masked_print = simdprint(proc->body(), vars.scalars); - masked_print.set_masked(); - out << " {\n" << indent << masked_print << popindent << "}\n\n"; - } else { - emit_procedure_proto(out, proc, class_name); - out << - " {\n" << indent << - cprint(proc->body()) << popindent << - "}\n\n"; - } - } - out << namespace_declaration_close(ns_components); + EXIT(out); return out.str(); } @@ -473,11 +493,11 @@ void CPrinter::visit(LocalVariable* sym) { } void CPrinter::visit(VariableExpression *sym) { - out_ << sym->name() << (sym->is_range()? "[i_]": ""); + out_ << "pp->" << sym->name() << (sym->is_range()? "[i_]": ""); } void CPrinter::visit(CallExpression* e) { - out_ << e->name() << "(i_"; + out_ << e->name() << "(pp, i_"; for (auto& arg: e->args()) { out_ << ", "; arg->accept(this); @@ -486,6 +506,7 @@ void CPrinter::visit(CallExpression* e) { } void CPrinter::visit(BlockExpression* block) { + ENTERM(out_, "c:block"); // Only include local declarations in outer-most block. if (!block->is_nested()) { auto locals = pure_locals(block->scope()); @@ -505,14 +526,15 @@ void CPrinter::visit(BlockExpression* block) { out_ << (stmt->is_if()? "": ";\n"); } } + EXITM(out_, "c:block"); } static std::string index_i_name(const std::string& index_var) { return index_var+"i_"; } -void emit_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { - out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(int i_"; +void emit_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& ppack_name, const std::string& qualified) { + out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(" << ppack_name << "* pp, int i_"; for (auto& arg: e->args()) { out << ", ::arb::fvm_value_type " << arg->is_argument()->name(); } @@ -520,6 +542,12 @@ void emit_procedure_proto(std::ostream& out, ProcedureExpression* e, const std:: } namespace { + // Access through ppack + std::string data_via_ppack(const indexed_variable_info& i) { return "pp->" + i.data_var; } + std::string node_index_i_name(const indexed_variable_info& i) { return i.node_index_var + "i_"; } + std::string source_index_i_name(const index_prop& i) { return i.source_var + "i_"; } + std::string source_var(const index_prop& i) { return "pp->" + i.source_var; } + // Convenience I/O wrapper for emitting indexed access to an external variable. struct deref { @@ -528,13 +556,15 @@ namespace { friend std::ostream& operator<<(std::ostream& o, const deref& wrap) { auto index_var = wrap.d.cell_index_var.empty() ? wrap.d.node_index_var : wrap.d.cell_index_var; - return o << wrap.d.data_var << '[' - << (wrap.d.scalar()? "0": index_i_name(index_var)) << ']'; + auto i_name = index_i_name(index_var); + index_var = "pp->" + index_var; + return o << data_via_ppack(wrap.d) << '[' << (wrap.d.scalar() ? "0": i_name) << ']'; } }; } void emit_state_read(std::ostream& out, LocalVariable* local) { + ENTER(out); out << "::arb::fvm_value_type " << cprint(local) << " = "; if (local->is_read()) { @@ -547,11 +577,12 @@ void emit_state_read(std::ostream& out, LocalVariable* local) { else { out << "0;\n"; } + EXIT(out); } void emit_state_update(std::ostream& out, Symbol* from, IndexedVariable* external) { if (!external->is_write()) return; - + ENTER(out); auto d = decode_indexed_variable(external); double coeff = 1./d.scale; @@ -564,7 +595,7 @@ void emit_state_update(std::ostream& out, Symbol* from, IndexedVariable* externa if (coeff != 1) { out << as_c_double(coeff) << '*'; } - out << "weight_[i_], " << from->name() << ", " << deref(d) << ");\n"; + out << "pp->weight_[i_], " << from->name() << ", " << deref(d) << ");\n"; } else { out << deref(d) << " = "; @@ -573,9 +604,11 @@ void emit_state_update(std::ostream& out, Symbol* from, IndexedVariable* externa } out << from->name() << ";\n"; } + EXIT(out); } void emit_api_body(std::ostream& out, APIMethod* method) { + ENTER(out); auto body = method->body(); auto indexed_vars = indexed_locals(method->scope()); @@ -587,7 +620,7 @@ void emit_api_body(std::ostream& out, APIMethod* method) { auto it = std::find(indices.begin(), indices.end(), node_idx); if (it == indices.end()) indices.push_front(node_idx); if (!d.cell_index_var.empty()) { - index_prop cell_idx = {d.cell_index_var, index_i_name(d.node_index_var), false}; + index_prop cell_idx = {d.cell_index_var, node_index_i_name(d), false}; auto it = std::find(indices.begin(), indices.end(), cell_idx); if (it == indices.end()) indices.push_back(cell_idx); } @@ -596,11 +629,11 @@ void emit_api_body(std::ostream& out, APIMethod* method) { if (!body->statements().empty()) { out << - "int n_ = width_;\n" + "int n_ = pp->width_;\n" "for (int i_ = 0; i_ < n_; ++i_) {\n" << indent; for (auto index: indices) { - out << "auto " << index_i_name(index.source_var) << " = " << index.source_var << "[" << index.index_name << "];\n"; + out << "auto " << source_index_i_name(index) << " = " << source_var(index) << "[" << index.index_name << " ];\n"; } for (auto& sym: indexed_vars) { @@ -613,29 +646,37 @@ void emit_api_body(std::ostream& out, APIMethod* method) { } out << popindent << "}\n"; } + EXIT(out); } // SIMD printing: void SimdPrinter::visit(IdentifierExpression *e) { + ENTERM(out_, "identifier"); e->symbol()->accept(this); + EXITM(out_, "identifier"); } void SimdPrinter::visit(LocalVariable* sym) { + ENTERM(out_, "local"); out_ << sym->name(); + EXITM(out_, "local"); } void SimdPrinter::visit(VariableExpression *sym) { + ENTERM(out_, "variable"); if (sym->is_range()) { auto index = is_indirect_? "index_": "i_"; - out_ << "simd_cast<simd_value>(indirect(" << sym->name() << "+" << index << ", simd_width_))"; + out_ << "simd_cast<simd_value>(indirect(pp->" << sym->name() << "+" << index << ", simd_width_))"; } else { - out_ << sym->name(); + out_ << "pp->" << sym->name(); } + EXITM(out_, "variable"); } void SimdPrinter::visit(AssignmentExpression* e) { + ENTERM(out_, "assign"); if (!e->lhs() || !e->lhs()->is_identifier() || !e->lhs()->is_identifier()->symbol()) { throw compiler_exception("Expect symbol on lhs of assignment: "+e->to_string()); } @@ -650,10 +691,11 @@ void SimdPrinter::visit(AssignmentExpression* e) { if (scalars_.count(e->lhs()->is_identifier()->name())) cast = false; if (lhs->is_variable() && lhs->is_variable()->is_range()) { + std::string pfx = lhs->is_local_variable() ? "" : "pp->"; if(is_indirect_) - out_ << "indirect(" << lhs->name() << "+index_, simd_width_) = "; + out_ << "indirect(" << pfx << lhs->name() << "+index_, simd_width_) = "; else - out_ << "indirect(" << lhs->name() << "+i_, simd_width_) = "; + out_ << "indirect(" << pfx << lhs->name() << "+i_, simd_width_) = "; if (!input_mask_.empty()) out_ << "S::where(" << input_mask_ << ", "; @@ -664,15 +706,15 @@ void SimdPrinter::visit(AssignmentExpression* e) { if (!input_mask_.empty()) out_ << ")"; - } - else { - out_ << "assign(" << lhs->name() << ", "; + } else { + std::string pfx = lhs->is_local_variable() ? "" : "pp->"; + out_ << "assign(" << pfx << lhs->name() << ", "; if (auto rhs = e->rhs()->is_identifier()) { if (auto sym = rhs->symbol()) { // We shouldn't call the rhs visitor in this case because it automatically casts indirect expressions if (sym->is_variable() && sym->is_variable()->is_range()) { auto index = is_indirect_ ? "index_" : "i_"; - out_ << "indirect(" << rhs->name() << "+" << index << ", simd_width_))"; + out_ << "indirect(pp->" << rhs->name() << "+" << index << ", simd_width_))"; return; } } @@ -680,22 +722,26 @@ void SimdPrinter::visit(AssignmentExpression* e) { e->rhs()->accept(this); out_ << ")"; } + EXITM(out_, "assign"); } void SimdPrinter::visit(CallExpression* e) { + ENTERM(out_, "call"); if(is_indirect_) - out_ << e->name() << "(index_"; + out_ << e->name() << "(pp, index_"; else - out_ << e->name() << "(i_"; + out_ << e->name() << "(pp, i_"; for (auto& arg: e->args()) { out_ << ", "; arg->accept(this); } out_ << ")"; + EXITM(out_, "call"); } void SimdPrinter::visit(BlockExpression* block) { // Only include local declarations in outer-most block. + ENTERM(out_, "block"); if (!block->is_nested()) { auto locals = pure_locals(block->scope()); if (!locals.empty()) { @@ -716,32 +762,38 @@ void SimdPrinter::visit(BlockExpression* block) { } } } + EXITM(out_, "block"); } -void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { - out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(::arb::fvm_index_type i_"; +void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& ppack_name, const std::string& qualified) { + ENTER(out); + out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(" << ppack_name << "* pp, ::arb::fvm_index_type i_"; for (auto& arg: e->args()) { out << ", const simd_value& " << arg->is_argument()->name(); } out << ")"; + EXIT(out); } -void emit_masked_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { +void emit_masked_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& ppack_name, const std::string& qualified) { + ENTER(out); out << "void " << qualified << (qualified.empty()? "": "::") << e->name() - << "(::arb::fvm_index_type i_, simd_mask mask_input_"; + << "(" << ppack_name << "* pp, ::arb::fvm_index_type i_, simd_mask mask_input_"; for (auto& arg: e->args()) { out << ", const simd_value& " << arg->is_argument()->name(); } out << ")"; + EXIT(out); } void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_constraint constraint) { + ENTER(out); out << "simd_value " << local->name(); if (local->is_read()) { auto d = decode_indexed_variable(local->external_variable()); if (d.scalar()) { - out << " = simd_cast<simd_value>(" << d.data_var + out << " = simd_cast<simd_value>(pp->" << d.data_var << "[0]);\n"; } else { @@ -749,22 +801,22 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con switch (constraint) { case simd_expr_constraint::contiguous: out << ";\n" - << "assign(" << local->name() << ", indirect(" << d.data_var - << " + " << index_i_name(d.node_index_var) << ", simd_width_));\n"; + << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) + << " + " << node_index_i_name(d) << ", simd_width_));\n"; break; case simd_expr_constraint::constant: - out << " = simd_cast<simd_value>(" << d.data_var - << "[" << index_i_name(d.node_index_var) << "]);\n"; + out << " = simd_cast<simd_value>(" << data_via_ppack(d) + << "[" << node_index_i_name(d) << "]);\n"; break; default: out << ";\n" - << "assign(" << local->name() << ", indirect(" << d.data_var - << ", " << index_i_name(d.node_index_var) << ", simd_width_, constraint_category_));\n"; + << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) + << ", " << node_index_i_name(d) << ", simd_width_, constraint_category_));\n"; } } else { out << ";\n" - << "assign(" << local->name() << ", indirect(" << d.data_var + << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none));\n"; } } @@ -776,18 +828,21 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con else { out << " = simd_cast<simd_value>(0);\n"; } + EXIT(out); } void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* external, simd_expr_constraint constraint) { if (!external->is_write()) return; - auto d = decode_indexed_variable(external);; + auto d = decode_indexed_variable(external); double coeff = 1./d.scale; if (d.readonly) { throw compiler_exception("Cannot assign to read-only external state: "+external->to_string()); } + ENTER(out); + if (d.accumulate) { if (d.cell_index_var.empty()) { switch (constraint) { @@ -795,18 +850,18 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex { std::string tempvar = "t_" + external->name(); out << "simd_value " << tempvar << ";\n" - << "assign(" << tempvar << ", indirect(" << d.data_var << " + " << index_i_name(d.node_index_var) << ", simd_width_));\n"; + << "assign(" << tempvar << ", indirect(" << data_via_ppack(d) << " + " << node_index_i_name(d) << ", simd_width_));\n"; if (coeff != 1) { out << tempvar << " = S::fma(S::mul(w_, simd_cast<simd_value>(" << as_c_double(coeff) << "))," << from->name() << ", " << tempvar << ");\n"; } else { out << tempvar << " = S::fma(w_, " << from->name() << ", " << tempvar << ");\n"; } - out << "indirect(" << d.data_var << " + " << index_i_name(d.node_index_var) << ", simd_width_) = " << tempvar << ";\n"; + out << "indirect(" << data_via_ppack(d) << " + " << node_index_i_name(d) << ", simd_width_) = " << tempvar << ";\n"; break; } case simd_expr_constraint::constant: { - out << "indirect(" << d.data_var << ", simd_cast<simd_index>(" << index_i_name(d.node_index_var) << "), simd_width_, constraint_category_)"; + out << "indirect(" << data_via_ppack(d) << ", simd_cast<simd_index>(" << node_index_i_name(d) << "), simd_width_, constraint_category_)"; if (coeff != 1) { out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << "), " << from->name() << "));\n"; } else { @@ -816,7 +871,7 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex } default : { - out << "indirect(" << d.data_var << ", " << index_i_name(d.node_index_var) << ", simd_width_, constraint_category_)"; + out << "indirect(" << data_via_ppack(d) << ", " << node_index_i_name(d) << ", simd_width_, constraint_category_)"; if (coeff != 1) { out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << "), " << from->name() << "));\n"; } else { @@ -825,7 +880,7 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex } } } else { - out << "indirect(" << d.data_var << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none)"; + out << "indirect(" << data_via_ppack(d) << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none)"; if (coeff != 1) { out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << "), " << from->name() << "));\n"; } else { @@ -837,16 +892,20 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex if (d.cell_index_var.empty()) { switch (constraint) { case simd_expr_constraint::contiguous: - out << "indirect(" << d.data_var << " + " << index_i_name(d.node_index_var) << ", simd_width_) = "; + out << "indirect(" << data_via_ppack(d) << " + " << node_index_i_name(d) << ", simd_width_) = "; break; case simd_expr_constraint::constant: - out << "indirect(" << d.data_var << ", simd_cast<simd_index>(" << index_i_name(d.node_index_var) << "), simd_width_, constraint_category_) = "; + out << "indirect(" << data_via_ppack(d) << ", simd_cast<simd_index>(" << node_index_i_name(d) << "), simd_width_, constraint_category_) = "; break; default: - out << "indirect(" << d.data_var << ", " << index_i_name(d.node_index_var) << ", simd_width_, constraint_category_) = "; + out << "indirect(" << data_via_ppack(d) << ", " << node_index_i_name(d) << ", simd_width_, constraint_category_) = "; } } else { - out << "indirect(" << d.data_var << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none) = "; + out << "indirect(" << data_via_ppack(d) + + + + << ", " << index_i_name(d.cell_index_var) << ", simd_width_, index_constraint::none) = "; } if (coeff != 1) { @@ -855,39 +914,43 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex out << from->name() << ";\n"; } } + + EXIT(out); } void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>& indices, - simd_expr_constraint constraint) { + simd_expr_constraint constraint) { + ENTER(out); for (auto& index: indices) { if (index.node_index) { switch (constraint) { case simd_expr_constraint::contiguous: case simd_expr_constraint::constant: - out << "auto " << index_i_name(index.source_var) << " = " << index.source_var << "[" << index.index_name << "];\n"; + out << "auto " << source_index_i_name(index) << " = " << source_var(index) << "[" << index.index_name << "];\n"; break; default: - out << "auto " << index_i_name(index.source_var) << " = simd_cast<simd_index>(indirect(&" << index.source_var + out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(&" << source_var(index) << "[0] + " << index.index_name << ", simd_width_));\n"; break; } } else { switch (constraint) { case simd_expr_constraint::contiguous: - out << "auto " << index_i_name(index.source_var) << " = simd_cast<simd_index>(indirect(" << index.source_var + out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(" << source_var(index) << " + " << index.index_name << ", simd_width_));\n"; break; case simd_expr_constraint::constant: - out << "auto " << index_i_name(index.source_var) << " = simd_cast<simd_index>(" << index.source_var + out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(" << source_var(index) << "[" << index.index_name << "]);\n"; break; default: - out << "auto " << index_i_name(index.source_var) << " = simd_cast<simd_index>(indirect(" << index.source_var + out << "auto " << source_index_i_name(index) << " = simd_cast<simd_index>(indirect(" << source_var(index) << ", " << index.index_name << ", simd_width_, constraint_category_));\n"; break; } } } + EXIT(out); } void emit_simd_body_for_loop( @@ -897,6 +960,7 @@ void emit_simd_body_for_loop( const std::vector<VariableExpression*>& scalars, const std::list<index_prop>& indices, const simd_expr_constraint& constraint) { + ENTER(out); emit_simd_index_initialize(out, indices, constraint); for (auto& sym: indexed_vars) { @@ -911,6 +975,7 @@ void emit_simd_body_for_loop( for (auto& sym: indexed_vars) { emit_simd_state_update(out, sym, sym->external_variable(), constraint); } + EXIT(out); } void emit_simd_for_loop_per_constraint(std::ostream& out, BlockExpression* body, @@ -920,21 +985,22 @@ void emit_simd_for_loop_per_constraint(std::ostream& out, BlockExpression* body, const std::list<index_prop>& indices, const simd_expr_constraint& constraint, std::string underlying_constraint_name) { - + ENTER(out); out << "constraint_category_ = index_constraint::"<< underlying_constraint_name << ";\n"; - out << "for (unsigned i_ = 0; i_ < index_constraints_." << underlying_constraint_name + out << "for (unsigned i_ = 0; i_ < pp->index_constraints_." << underlying_constraint_name << ".size(); i_++) {\n" << indent; - out << "::arb::fvm_index_type index_ = index_constraints_." << underlying_constraint_name << "[i_];\n"; + out << "::arb::fvm_index_type index_ = pp->index_constraints_." << underlying_constraint_name << "[i_];\n"; if (requires_weight) { out << "simd_value w_;\n" - << "assign(w_, indirect((weight_+index_), simd_width_));\n"; + << "assign(w_, indirect((pp->weight_+index_), simd_width_));\n"; } emit_simd_body_for_loop(out, body, indexed_vars, scalars, indices, constraint); out << popindent << "}\n"; + EXIT(out); } void emit_simd_api_body(std::ostream& out, APIMethod* method, const std::vector<VariableExpression*>& scalars) { @@ -945,6 +1011,8 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, const std::vector< std::vector<LocalVariable*> scalar_indexed_vars; std::list<index_prop> indices; + ENTER(out); + for (auto& s: body->is_block()->statements()) { if (s->is_assignment()) { for (auto& v: indexed_vars) { @@ -967,7 +1035,7 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, const std::vector< if (it == indices.end()) indices.push_front(node_idx); if (!info.cell_index_var.empty()) { - index_prop cell_idx = {info.cell_index_var, index_i_name(info.node_index_var), false}; + index_prop cell_idx = {info.cell_index_var, node_index_i_name(info), false}; it = std::find(indices.begin(), indices.end(), cell_idx); if (it == indices.end()) indices.push_back(cell_idx); } @@ -1013,11 +1081,11 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, const std::vector< } out << - "unsigned n_ = width_;\n\n" + "unsigned n_ = pp->width_;\n\n" "for (unsigned i_ = 0; i_ < n_; i_ += simd_width_) {\n" << indent << simdprint(body, scalars) << popindent << "}\n"; } } + EXIT(out); } - diff --git a/modcc/printer/gpuprinter.cpp b/modcc/printer/gpuprinter.cpp index 2892528209f12894d70c251fb6d33c4c94f2ca68..61ad097500a94c8d5b0c1841b771a8a2379ec753 100644 --- a/modcc/printer/gpuprinter.cpp +++ b/modcc/printer/gpuprinter.cpp @@ -68,9 +68,8 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op "#include <" << arb_private_header_prefix() << "backends/event.hpp>\n" "#include <" << arb_private_header_prefix() << "backends/multi_event_stream_state.hpp>\n"; - out << - "#include <" << arb_private_header_prefix() << "backends/gpu/mechanism.hpp>\n" - "#include <" << arb_private_header_prefix() << "backends/gpu/mechanism_ppack_base.hpp>\n"; + out << "#include <" << arb_private_header_prefix() << "backends/gpu/mechanism.hpp>\n" + << "#include <arbor/mechanism_ppack.hpp>\n"; out << "\n" << namespace_declaration_open(ns_components) << "\n"; @@ -128,7 +127,7 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op out << popindent << "protected:\n" << indent << "std::size_t object_sizeof() const override { return sizeof(*this); }\n" - "::arb::gpu::mechanism_ppack_base* ppack_ptr() override { return &pp_; }\n\n"; + "::arb::mechanism_ppack* ppack_ptr() override { return &pp_; }\n\n"; io::separator sep("\n", ",\n"); if (!vars.scalars.empty()) { @@ -243,7 +242,8 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt "#include <" << arb_private_header_prefix() << "backends/multi_event_stream_state.hpp>\n" "#include <" << arb_private_header_prefix() << "backends/gpu/gpu_common.hpp>\n" "#include <" << arb_private_header_prefix() << "backends/gpu/math_cu.hpp>\n" - "#include <" << arb_private_header_prefix() << "backends/gpu/mechanism_ppack_base.hpp>\n"; + "#include <arbor/mechanism.hpp>\n" << + "#include <arbor/mechanism_ppack.hpp>\n"; is_point_proc && out << "#include <" << arb_private_header_prefix() << "backends/gpu/reduce_by_key.hpp>\n"; @@ -397,7 +397,7 @@ void emit_common_defs(std::ostream& out, const Module& module_) { "using deliverable_event_stream_state =\n" " ::arb::multi_event_stream_state<::arb::deliverable_event_data>;\n\n"; - out << "struct " << ppack_name << ": ::arb::gpu::mechanism_ppack_base {\n" << indent; + out << "struct " << ppack_name << ": ::arb::mechanism_ppack {\n" << indent; for (const auto& scalar: vars.scalars) { out << "::arb::fvm_value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; diff --git a/modcc/printer/marks.cpp b/modcc/printer/marks.cpp new file mode 100644 index 0000000000000000000000000000000000000000..597065261a880949e5a4501fd1b7f7d59f8f808d --- /dev/null +++ b/modcc/printer/marks.cpp @@ -0,0 +1,3 @@ +#include "marks.hpp" + +bool options_trace_codegen = false; diff --git a/modcc/printer/marks.hpp b/modcc/printer/marks.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5c3c537e290211874056c64a0e3d000bd4baab1d --- /dev/null +++ b/modcc/printer/marks.hpp @@ -0,0 +1,19 @@ +#pragma once +#include "printeropt.hpp" + +extern bool options_trace_codegen; + +#define ENTERM(stream, msg) do { \ + if (options_trace_codegen) { \ + (stream) << " /* " << __FUNCTION__ << ":" << (msg) << ":enter */ "; \ + } \ + } while(0) + +#define EXITM(stream, msg) do { \ + if (options_trace_codegen) { \ + (stream) << " /* " << __FUNCTION__ << ":" << (msg) << ":exit */ "; \ + } \ + } while(0) + +#define ENTER(stream) ENTERM(stream, "") +#define EXIT(stream) EXITM(stream, "") diff --git a/modcc/printer/printeropt.hpp b/modcc/printer/printeropt.hpp index 6c4b9a32b30c54f288ece59d5bd07f750af5f38c..72130dc3221ebf7d47e4a949aca50613611d25d7 100644 --- a/modcc/printer/printeropt.hpp +++ b/modcc/printer/printeropt.hpp @@ -6,6 +6,7 @@ #include <string> #include "simd.hpp" + struct printer_options { // C++ namespace for generated code. std::string cpp_namespace; @@ -17,4 +18,5 @@ struct printer_options { // Currently only supported for C printer. bool profile = false; + bool trace_codegen = false; }; diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index e88518751057db6e9576bf51367e08f48e50c680..b96438963f394784c8b6f36c51fb1c98831824f6 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -125,10 +125,10 @@ TEST(CPrinter, proc_body) { "}" , "::arb::fvm_value_type k;\n" - "minf[i_] = 1.0-1.0/(1.0+exp((v-k)/k));\n" - "hinf[i_] = 1.0/(1.0+exp((v-k)/k));\n" - "mtau[i_] = 0.5;\n" - "htau[i_] = 1500.0;\n" + "pp->minf[i_] = 1.0-1.0/(1.0+exp((v-k)/k));\n" + "pp->hinf[i_] = 1.0/(1.0+exp((v-k)/k));\n" + "pp->mtau[i_] = 0.5;\n" + "pp->htau[i_] = 1500.0;\n" } }; @@ -167,7 +167,7 @@ TEST(CPrinter, proc_body_const) { " mtau = 0.5 - t0 + t1\n" "}" , - "mtau[i_] = 0.5 - -0.5 + 1.2;\n" + "pp->mtau[i_] = 0.5 - -0.5 + 1.2;\n" } }; @@ -204,27 +204,27 @@ TEST(CPrinter, proc_body_inlined) { "r_6_ = 0.;\n" "r_7_ = 0.;\n" "r_8_ = 0.;\n" - "r_9_=s2[i_]*0.33333333333333331;\n" - "r_8_=s1[i_]+2.0;\n" - "if(s1[i_]==3.0){\n" + "r_9_=pp->s2[i_]*0.33333333333333331;\n" + "r_8_=pp->s1[i_]+2.0;\n" + "if(pp->s1[i_]==3.0){\n" " r_7_=2.0*r_8_;\n" "}\n" "else{\n" - " if(s1[i_]==4.0){\n" + " if(pp->s1[i_]==4.0){\n" " r_11_ = 0.;\n" " r_12_ = 0.;\n" - " r_12_=6.0+s1[i_];\n" + " r_12_=6.0+pp->s1[i_];\n" " r_11_=r_12_;\n" " r_7_=r_8_*r_11_;\n" " }\n" " else{\n" " r_10_=exp(r_8_);\n" - " r_7_=r_10_*s1[i_];\n" + " r_7_=r_10_*pp->s1[i_];\n" " }\n" "}\n" "r_13_=0.;\n" "r_14_=0.;\n" - "r_14_=r_9_/s2[i_];\n" + "r_14_=r_9_/pp->s2[i_];\n" "r_15_=log(r_14_);\n" "r_13_=42.0*r_15_;\n" "r_6_=r_9_*r_13_;\n" @@ -247,7 +247,7 @@ TEST(CPrinter, proc_body_inlined) { " t2=r_16_*ll0_;\n" " }\n" "}\n" - "s2[i_]=t2+4.0;\n"; + "pp->s2[i_]=t2+4.0;\n"; Module m(io::read_all(DATADIR "/mod_files/test6.mod"), "test6.mod"); Parser p(m, false); @@ -280,22 +280,22 @@ TEST(SimdPrinter, simd_if_else) { "simd_mask mask_0_ = S::cmp_gt(i, (double)2.0);\n" "S::where(mask_0_,u) = (double)7.0;\n" "S::where(S::logical_not(mask_0_),u) = (double)5.0;\n" - "indirect(s+i_, simd_width_) = S::where(S::logical_not(mask_0_),simd_cast<simd_value>((double)42.0));\n" - "indirect(s+i_, simd_width_) = u;" + "indirect(pp->s+i_, simd_width_) = S::where(S::logical_not(mask_0_),simd_cast<simd_value>((double)42.0));\n" + "indirect(pp->s+i_, simd_width_) = u;" , "simd_value u;\n" "simd_mask mask_1_ = S::cmp_gt(i, (double)2.0);\n" "S::where(mask_1_,u) = (double)7.0;\n" "S::where(S::logical_not(mask_1_),u) = (double)5.0;\n" - "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_1_), mask_input_),simd_cast<simd_value>((double)42.0));\n" - "indirect(s+i_, simd_width_) = S::where(mask_input_, u);" + "indirect(pp->s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_1_), mask_input_),simd_cast<simd_value>((double)42.0));\n" + "indirect(pp->s+i_, simd_width_) = S::where(mask_input_, u);" , - "simd_mask mask_2_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)2.0);\n" - "simd_mask mask_3_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)3.0);\n" + "simd_mask mask_2_ = S::cmp_gt(simd_cast<simd_value>(indirect(pp->g+i_, simd_width_)), (double)2.0);\n" + "simd_mask mask_3_ = S::cmp_gt(simd_cast<simd_value>(indirect(pp->g+i_, simd_width_)), (double)3.0);\n" "S::where(S::logical_and(mask_2_,mask_3_),i) = (double)0.;\n" "S::where(S::logical_and(mask_2_,S::logical_not(mask_3_)),i) = (double)1.0;\n" - "simd_mask mask_4_ = S::cmp_lt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)1.0);\n" - "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_2_),mask_4_),simd_cast<simd_value>((double)2.0));\n" + "simd_mask mask_4_ = S::cmp_lt(simd_cast<simd_value>(indirect(pp->g+i_, simd_width_)), (double)1.0);\n" + "indirect(pp->s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_2_),mask_4_),simd_cast<simd_value>((double)2.0));\n" "rates(i_, S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)), i);" }; diff --git a/test/unit/test_mechcat.cpp b/test/unit/test_mechcat.cpp index 60c9dac7662d1c63acf70c6069ef1e78b9e50f35..a1156b6452ef8158429c5abf0f56493f92faf489 100644 --- a/test/unit/test_mechcat.cpp +++ b/test/unit/test_mechcat.cpp @@ -57,7 +57,7 @@ mechanism_info fleeb_info = { template <typename B> struct common_impl: concrete_mechanism<B> { void instantiate(fvm_size_type id, typename B::shared_state& state, const mechanism_overrides& o, const mechanism_layout& l) override { - width_ = l.cv.size(); + this->width_ = l.cv.size(); // Write mechanism global values to shared state to test instatiation call and catalogue global // variable overrides. for (auto& kv: o.globals) { @@ -67,15 +67,13 @@ struct common_impl: concrete_mechanism<B> { for (auto& ion: mech_ions) { if (o.ion_rebind.count(ion)) { ion_bindings_[ion] = state.ions.at(o.ion_rebind.at(ion)); - } - else { + } else { ion_bindings_[ion] = state.ions.at(ion); } } } std::size_t memory() const override { return 10u; } - std::size_t size() const override { return width_; } void set_parameter(const std::string& key, const std::vector<fvm_value_type>& vs) override {} @@ -88,11 +86,12 @@ struct common_impl: concrete_mechanism<B> { void deliver_events() override {} void update_ions() override {} - std::size_t width_ = 0; - std::vector<std::string> mech_ions; std::unordered_map<std::string, std::string> ion_bindings_; + +protected: + mechanism_ppack* ppack_ptr() override { return nullptr; } }; template <typename B> @@ -105,10 +104,13 @@ struct foo_stream_state {}; struct foo_stream { using state = foo_stream_state; + state& marked_events() { return state_; } + state state_; }; struct foo_backend { using iarray = std::vector<fvm_index_type>; + using array = std::vector<fvm_value_type>; using deliverable_event_stream = foo_stream; struct shared_state { @@ -130,10 +132,13 @@ struct bar_stream_state {}; struct bar_stream { using state = bar_stream_state; + state& marked_events() { return state_; } + state state_; }; struct bar_backend { using iarray = std::vector<fvm_index_type>; + using array = std::vector<fvm_value_type>; using deliverable_event_stream = bar_stream; struct shared_state { std::unordered_map<std::string, fvm_value_type> overrides; diff --git a/test/unit/test_synapses.cpp b/test/unit/test_synapses.cpp index 4f3157234b8054c0041d54a8686ee358c5efa694..60aa02cb2d946df3b279baf7f40331ee04c5b2d6 100644 --- a/test/unit/test_synapses.cpp +++ b/test/unit/test_synapses.cpp @@ -6,6 +6,8 @@ #include <arbor/constants.hpp> #include <arbor/mechcat.hpp> +#include <arbor/mechanism.hpp> +#include <arbor/mechanism_ppack.hpp> #include <arbor/cable_cell.hpp> #include "backends/multicore/fvm.hpp" @@ -25,9 +27,7 @@ using value_type = backend::value_type; using size_type = backend::size_type; // Access to more mechanism protected data: - -ACCESS_BIND(const value_type* multicore::mechanism::*, vec_v_ptr, &multicore::mechanism::vec_v_) -ACCESS_BIND(value_type* multicore::mechanism::*, vec_i_ptr, &multicore::mechanism::vec_i_) +ACCESS_BIND(::arb::mechanism_ppack* (::arb::concrete_mechanism<backend>::*)(), pp_ptr, &::arb::concrete_mechanism<backend>::ppack_ptr); TEST(synapses, add_to_cell) { using namespace arb; @@ -130,17 +130,17 @@ TEST(synapses, syn_basic_state) { // Current and voltage views correctly hooked up? const value_type* v_ptr; - v_ptr = expsyn.get()->*vec_v_ptr; + v_ptr = (expsyn.get()->*pp_ptr)()->vec_v_; EXPECT_TRUE(all_equal_to(util::make_range(v_ptr, v_ptr+num_comp), -65.)); - v_ptr = exp2syn.get()->*vec_v_ptr; + v_ptr = (exp2syn.get()->*pp_ptr)()->vec_v_; EXPECT_TRUE(all_equal_to(util::make_range(v_ptr, v_ptr+num_comp), -65.)); const value_type* i_ptr; - i_ptr = expsyn.get()->*vec_i_ptr; + i_ptr = (expsyn.get()->*pp_ptr)()->vec_i_; EXPECT_TRUE(all_equal_to(util::make_range(i_ptr, i_ptr+num_comp), 1.)); - i_ptr = exp2syn.get()->*vec_i_ptr; + i_ptr = (exp2syn.get()->*pp_ptr)()->vec_i_; EXPECT_TRUE(all_equal_to(util::make_range(i_ptr, i_ptr+num_comp), 1.)); // Initialize state then check g, A, B have been set to zero.