From 3531f7eec5fa664a71b72d7f21bdff788906920d Mon Sep 17 00:00:00 2001 From: Nora Abi Akar <nora.abiakar@gmail.com> Date: Thu, 4 Feb 2021 12:01:29 +0100 Subject: [PATCH] Add `post events` functionality to support models with STDP synapses (#1255) * Added a `POST_EVENTS` procedure to nmodl, that takes an argument that represents the time since the last spike on the cell. In the event of multiple detectors on the cell, and multiple spikes on the detectors within the same integration period, all spikes will be processed by the synapse. Spikes are processed only once and then cleared. * Added 3 vectors to the shared state required to dispatch post-events: `cv_to_cell` map, `time_since_spike` holding max(num_detectors) slots per cell, and `src_to_spike` mapping spike sources (detectors) to slots in `time_since_spike`. * Renamed `vec_ci_` to `vec_di_` (to better reflect that it stands for **domain index**). Named the new `cv_to_cell` index as `vec_ci_` in the mechanisms. * Fixed existing unit tests and added new tests for the new post-events functionality. Fixes #1206 --- arbor/backends/gpu/fvm.hpp | 3 +- arbor/backends/gpu/mechanism.cpp | 5 +- arbor/backends/gpu/mechanism_ppack_base.hpp | 3 + arbor/backends/gpu/shared_state.cpp | 12 +- arbor/backends/gpu/shared_state.hpp | 15 ++- arbor/backends/gpu/stimulus.cu | 2 +- arbor/backends/gpu/threshold_watcher.cu | 27 +++- arbor/backends/gpu/threshold_watcher.hpp | 19 +-- arbor/backends/multicore/fvm.hpp | 3 +- arbor/backends/multicore/mechanism.cpp | 6 +- arbor/backends/multicore/mechanism.hpp | 5 +- arbor/backends/multicore/shared_state.cpp | 14 ++ arbor/backends/multicore/shared_state.hpp | 9 ++ arbor/backends/multicore/stimulus.cpp | 2 +- .../backends/multicore/threshold_watcher.hpp | 20 ++- arbor/fvm_layout.cpp | 6 + arbor/fvm_layout.hpp | 3 + arbor/fvm_lowered_cell_impl.hpp | 77 ++++++++--- arbor/include/arbor/mechanism.hpp | 1 + arbor/include/arbor/mechinfo.hpp | 2 + doc/internals/nmodl.rst | 29 +++- modcc/expression.cpp | 40 ++++++ modcc/expression.hpp | 21 +++ modcc/module.cpp | 2 + modcc/module.hpp | 2 + modcc/parser.cpp | 14 +- modcc/printer/cprinter.cpp | 23 ++++ modcc/printer/gpuprinter.cpp | 41 ++++++ modcc/printer/infoprinter.cpp | 4 +- modcc/printer/printerutil.cpp | 10 +- modcc/printer/printerutil.hpp | 2 + modcc/token.cpp | 2 + modcc/token.hpp | 2 +- test/unit/CMakeLists.txt | 33 ++--- test/unit/mod/post_events_syn.mod | 35 +++++ test/unit/test_fvm_lowered.cpp | 124 +++++++++++++++++- test/unit/test_kinetic_linear.cpp | 3 +- test/unit/test_mech_temp_diam.cpp | 6 +- test/unit/test_spikes.cpp | 49 ++++++- test/unit/test_synapses.cpp | 4 + test/unit/unit_test_catalogue.cpp | 4 +- 41 files changed, 597 insertions(+), 87 deletions(-) create mode 100644 test/unit/mod/post_events_syn.mod diff --git a/arbor/backends/gpu/fvm.hpp b/arbor/backends/gpu/fvm.hpp index 5cddad53..a3467f22 100644 --- a/arbor/backends/gpu/fvm.hpp +++ b/arbor/backends/gpu/fvm.hpp @@ -50,7 +50,7 @@ struct backend { using ion_state = arb::gpu::ion_state; static threshold_watcher voltage_watcher( - const shared_state& state, + shared_state& state, const std::vector<index_type>& cv, const std::vector<value_type>& thresholds, const execution_context& context) @@ -58,6 +58,7 @@ struct backend { return threshold_watcher( state.cv_to_intdom.data(), state.voltage.data(), + state.src_to_spike.data(), &state.time, &state.time_to, cv, diff --git a/arbor/backends/gpu/mechanism.cpp b/arbor/backends/gpu/mechanism.cpp index e577be3f..b96a0189 100644 --- a/arbor/backends/gpu/mechanism.cpp +++ b/arbor/backends/gpu/mechanism.cpp @@ -83,8 +83,10 @@ void mechanism::instantiate(unsigned id, mechanism_ppack_base* pp = ppack_ptr(); // From derived class instance. pp->width_ = width_; + pp->n_detectors_ = shared.n_detector; - pp->vec_ci_ = shared.cv_to_intdom.data(); + pp->vec_ci_ = shared.cv_to_cell.data(); + pp->vec_di_ = shared.cv_to_intdom.data(); pp->vec_dt_ = shared.dt_cv.data(); pp->vec_v_ = shared.voltage.data(); @@ -93,6 +95,7 @@ void mechanism::instantiate(unsigned id, pp->temperature_degC_ = shared.temperature_degC.data(); pp->diam_um_ = shared.diam_um.data(); + pp->time_since_spike_ = shared.time_since_spike.data(); auto ion_state_tbl = ion_state_table(); num_ions_ = ion_state_tbl.size(); diff --git a/arbor/backends/gpu/mechanism_ppack_base.hpp b/arbor/backends/gpu/mechanism_ppack_base.hpp index 0727a829..0a221aa1 100644 --- a/arbor/backends/gpu/mechanism_ppack_base.hpp +++ b/arbor/backends/gpu/mechanism_ppack_base.hpp @@ -29,8 +29,10 @@ struct mechanism_ppack_base { using ion_state_view = ::arb::gpu::ion_state_view; index_type width_; + index_type n_detectors_; const index_type* vec_ci_; + const index_type* vec_di_; const value_type* vec_t_; const value_type* vec_t_to_; const value_type* vec_dt_; @@ -39,6 +41,7 @@ struct mechanism_ppack_base { value_type* vec_g_; const value_type* temperature_degC_; const value_type* diam_um_; + const value_type* time_since_spike_; const index_type* node_index_; const index_type* multiplicity_; diff --git a/arbor/backends/gpu/shared_state.cpp b/arbor/backends/gpu/shared_state.cpp index 93489f1a..ee564845 100644 --- a/arbor/backends/gpu/shared_state.cpp +++ b/arbor/backends/gpu/shared_state.cpp @@ -90,17 +90,23 @@ void ion_state::reset() { shared_state::shared_state( fvm_size_type n_intdom, + fvm_size_type n_cell, + fvm_size_type n_detector, const std::vector<fvm_index_type>& cv_to_intdom_vec, + const std::vector<fvm_index_type>& cv_to_cell_vec, const std::vector<fvm_gap_junction>& gj_vec, const std::vector<fvm_value_type>& init_membrane_potential, const std::vector<fvm_value_type>& temperature_K, const std::vector<fvm_value_type>& diam, + const std::vector<fvm_index_type>& src_to_spike, unsigned // alignment parameter ignored. -): + ): n_intdom(n_intdom), + n_detector(n_detector), n_cv(cv_to_intdom_vec.size()), n_gj(gj_vec.size()), cv_to_intdom(make_const_view(cv_to_intdom_vec)), + cv_to_cell(make_const_view(cv_to_cell_vec)), gap_junctions(make_const_view(gj_vec)), time(n_intdom), time_to(n_intdom), @@ -112,8 +118,11 @@ shared_state::shared_state( init_voltage(make_const_view(init_membrane_potential)), temperature_degC(make_const_view(temperature_K)), diam_um(make_const_view(diam)), + time_since_spike(n_cell*n_detector), + src_to_spike(make_const_view(src_to_spike)), deliverable_events(n_intdom) { + memory::fill(time_since_spike, -1.0); add_scalar(temperature_degC.size(), temperature_degC.data(), -273.15); } @@ -133,6 +142,7 @@ void shared_state::reset() { memory::fill(conductivity, 0); memory::fill(time, 0); memory::fill(time_to, 0); + memory::fill(time_since_spike, -1.0); for (auto& i: ion_data) { i.second.reset(); diff --git a/arbor/backends/gpu/shared_state.hpp b/arbor/backends/gpu/shared_state.hpp index c375aa32..8d0254b3 100644 --- a/arbor/backends/gpu/shared_state.hpp +++ b/arbor/backends/gpu/shared_state.hpp @@ -61,11 +61,13 @@ struct ion_state { }; struct shared_state { - fvm_size_type n_intdom = 0; // Number of distinct integration domains. - fvm_size_type n_cv = 0; // Total number of CVs. - fvm_size_type n_gj = 0; // Total number of GJs. + fvm_size_type n_intdom = 0; // Number of distinct integration domains. + fvm_size_type n_detector = 0; // Max number of detectors on all cells. + fvm_size_type n_cv = 0; // Total number of CVs. + fvm_size_type n_gj = 0; // Total number of GJs. iarray cv_to_intdom; // Maps CV index to intdom index. + iarray cv_to_cell; // Maps CV index to cell index. gjarray gap_junctions; // Stores gap_junction info. array time; // Maps intdom index to integration start time [ms]. array time_to; // Maps intdom index to integration stop time [ms]. @@ -79,6 +81,9 @@ struct shared_state { array temperature_degC; // Maps CV to local temperature (read only) [°C]. array diam_um; // Maps CV to local diameter (read only) [µm]. + array time_since_spike; // Stores time since last spike on any detector, organized by cell. + iarray src_to_spike; // Maps spike source index to spike index + std::unordered_map<std::string, ion_state> ion_data; deliverable_event_stream deliverable_events; @@ -87,11 +92,15 @@ struct shared_state { shared_state( fvm_size_type n_intdom, + fvm_size_type n_cell, + fvm_size_type n_detector, const std::vector<fvm_index_type>& cv_to_intdom_vec, + const std::vector<fvm_index_type>& cv_to_cell_vec, const std::vector<fvm_gap_junction>& gj_vec, const std::vector<fvm_value_type>& init_membrane_potential, const std::vector<fvm_value_type>& temperature_K, const std::vector<fvm_value_type>& diam, + const std::vector<fvm_index_type>& src_to_spike, unsigned align ); diff --git a/arbor/backends/gpu/stimulus.cu b/arbor/backends/gpu/stimulus.cu index 1589d6e4..a207b376 100644 --- a/arbor/backends/gpu/stimulus.cu +++ b/arbor/backends/gpu/stimulus.cu @@ -12,7 +12,7 @@ namespace kernel { void stimulus_current_impl(int n, stimulus_pp pp) { auto i = threadIdx.x + blockDim.x*blockIdx.x; if (i<n) { - auto t = pp.vec_t_[pp.vec_ci_[i]]; + auto t = pp.vec_t_[pp.vec_di_[i]]; if (t>=pp.delay[i] && t<pp.delay[i]+pp.duration[i]) { // use subtraction because the electrode currents are specified // in terms of current into the compartment diff --git a/arbor/backends/gpu/threshold_watcher.cu b/arbor/backends/gpu/threshold_watcher.cu index 28180608..61706f9e 100644 --- a/arbor/backends/gpu/threshold_watcher.cu +++ b/arbor/backends/gpu/threshold_watcher.cu @@ -33,12 +33,15 @@ void test_thresholds_impl( const fvm_index_type* __restrict__ const cv_to_intdom, const fvm_value_type* __restrict__ const t_after, const fvm_value_type* __restrict__ const t_before, + const fvm_index_type* __restrict__ const src_to_spike, + fvm_value_type* __restrict__ const time_since_spike, stack_storage<threshold_crossing>& stack, fvm_index_type* __restrict__ const is_crossed, fvm_value_type* __restrict__ const prev_values, const fvm_index_type* __restrict__ const cv_index, const fvm_value_type* __restrict__ const values, - const fvm_value_type* __restrict__ const thresholds) + const fvm_value_type* __restrict__ const thresholds, + bool record_time_since_spike) { int i = threadIdx.x + blockIdx.x*blockDim.x; @@ -48,17 +51,27 @@ void test_thresholds_impl( if (i<size) { // Test for threshold crossing const auto cv = cv_index[i]; - const auto cell = cv_to_intdom[cv]; + const auto intdom = cv_to_intdom[cv]; const auto v_prev = prev_values[i]; const auto v = values[cv]; const auto thresh = thresholds[i]; + fvm_index_type spike_idx = 0; + // Reset all spike times to -1.0 indicating no spike has been recorded on the detector + if (record_time_since_spike) { + spike_idx = src_to_spike[i]; + time_since_spike[spike_idx] = -1.0; + } if (!is_crossed[i]) { if (v>=thresh) { // The threshold has been passed, so estimate the time using // linear interpolation auto pos = (thresh - v_prev)/(v - v_prev); - crossing_time = lerp(t_before[cell], t_after[cell], pos); + crossing_time = lerp(t_before[intdom], t_after[intdom], pos); + + if (record_time_since_spike) { + time_since_spike[spike_idx] = t_after[intdom] - crossing_time; + } is_crossed[i] = 1; crossed = true; @@ -95,15 +108,17 @@ extern void reset_crossed_impl( void test_thresholds_impl( int size, const fvm_index_type* cv_to_intdom, const fvm_value_type* t_after, const fvm_value_type* t_before, - stack_storage<threshold_crossing>& stack, + const fvm_index_type* src_to_spike, fvm_value_type* time_since_spike, stack_storage<threshold_crossing>& stack, fvm_index_type* is_crossed, fvm_value_type* prev_values, - const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds) + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds, + bool record_time_since_spike) { if (size>0) { constexpr int block_dim = 128; const int grid_dim = impl::block_count(size, block_dim); kernel::test_thresholds_impl<<<grid_dim, block_dim>>>( - size, cv_to_intdom, t_after, t_before, stack, is_crossed, prev_values, cv_index, values, thresholds); + size, cv_to_intdom, t_after, t_before, src_to_spike, time_since_spike, + stack, is_crossed, prev_values, cv_index, values, thresholds, record_time_since_spike); } } diff --git a/arbor/backends/gpu/threshold_watcher.hpp b/arbor/backends/gpu/threshold_watcher.hpp index bf533073..6d371faa 100644 --- a/arbor/backends/gpu/threshold_watcher.hpp +++ b/arbor/backends/gpu/threshold_watcher.hpp @@ -22,9 +22,10 @@ namespace gpu { void test_thresholds_impl( int size, const fvm_index_type* cv_to_intdom, const fvm_value_type* t_after, const fvm_value_type* t_before, - stack_storage<threshold_crossing>& stack, + const fvm_index_type* src_to_spike, fvm_value_type* time_since_spike, stack_storage<threshold_crossing>& stack, fvm_index_type* is_crossed, fvm_value_type* prev_values, - const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds); + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds, + bool record); void reset_crossed_impl( int size, @@ -45,6 +46,7 @@ public: threshold_watcher( const fvm_index_type* cv_to_intdom, const fvm_value_type* values, + const fvm_index_type* src_to_spike, const array* t_before, const array* t_after, const std::vector<fvm_index_type>& cv_index, @@ -53,6 +55,7 @@ public: ): cv_to_intdom_(cv_to_intdom), values_(values), + src_to_spike_(src_to_spike), t_before_ptr_(t_before), t_after_ptr_(t_after), cv_index_(memory::make_const_view(cv_index)), @@ -104,17 +107,16 @@ public: /// Crossing events are recorded for each threshold that has been /// crossed since current time t, and the last time the test was /// performed. - void test() { - const fvm_value_type* t_before = t_before_ptr_->data(); - const fvm_value_type* t_after = t_after_ptr_->data(); - + void test(array* time_since_spike) { if (size()>0) { test_thresholds_impl( (int)size(), - cv_to_intdom_, t_after, t_before, + cv_to_intdom_, t_after_ptr_->data(), t_before_ptr_->data(), + src_to_spike_, time_since_spike->data(), stack_.storage(), is_crossed_.data(), v_prev_.data(), - cv_index_.data(), values_, thresholds_.data()); + cv_index_.data(), values_, thresholds_.data(), + !time_since_spike->empty()); // Check that the number of spikes has not exceeded capacity. arb_assert(!stack_.overflow()); @@ -132,6 +134,7 @@ private: /// and pointers to the time arrays const fvm_index_type* cv_to_intdom_ = nullptr; const fvm_value_type* values_ = nullptr; + const fvm_index_type* src_to_spike_ = nullptr; const array* t_before_ptr_ = nullptr; const array* t_after_ptr_ = nullptr; diff --git a/arbor/backends/multicore/fvm.hpp b/arbor/backends/multicore/fvm.hpp index e81b7d7d..1920f801 100644 --- a/arbor/backends/multicore/fvm.hpp +++ b/arbor/backends/multicore/fvm.hpp @@ -48,7 +48,7 @@ struct backend { using ion_state = arb::multicore::ion_state; static threshold_watcher voltage_watcher( - const shared_state& state, + shared_state& state, const std::vector<index_type>& cv, const std::vector<value_type>& thresholds, const execution_context& context) @@ -56,6 +56,7 @@ struct backend { return threshold_watcher( state.cv_to_intdom.data(), state.voltage.data(), + state.src_to_spike.data(), &state.time, &state.time_to, cv, diff --git a/arbor/backends/multicore/mechanism.cpp b/arbor/backends/multicore/mechanism.cpp index 285c27dd..3e7f6c2f 100644 --- a/arbor/backends/multicore/mechanism.cpp +++ b/arbor/backends/multicore/mechanism.cpp @@ -83,7 +83,8 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me // Assign non-owning views onto shared state: - vec_ci_ = shared.cv_to_intdom.data(); + vec_ci_ = shared.cv_to_cell.data(); + vec_di_ = shared.cv_to_intdom.data(); vec_dt_ = shared.dt_cv.data(); vec_v_ = shared.voltage.data(); @@ -92,6 +93,9 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me temperature_degC_ = shared.temperature_degC.data(); diam_um_ = shared.diam_um.data(); + time_since_spike_ = shared.time_since_spike.data(); + + n_detectors_ = shared.n_detector; auto ion_state_tbl = ion_state_table(); n_ion_ = ion_state_tbl.size(); diff --git a/arbor/backends/multicore/mechanism.hpp b/arbor/backends/multicore/mechanism.hpp index 140c5d78..afd638c3 100644 --- a/arbor/backends/multicore/mechanism.hpp +++ b/arbor/backends/multicore/mechanism.hpp @@ -84,10 +84,12 @@ protected: size_type width_ = 0; // Instance width (number of CVs/sites) size_type width_padded_ = 0; // Width rounded up to multiple of pad/alignment. size_type n_ion_ = 0; + size_type n_detectors_ = 0; // Non-owning views onto shared cell state, excepting ion state. - const index_type* vec_ci_; // CV to cell index. + const index_type* vec_ci_; // CV to cell index + const index_type* vec_di_; // CV to indom index const value_type* vec_t_; // Cell index to cell-local time. const value_type* vec_t_to_; // Cell index to cell-local integration step time end. const value_type* vec_dt_; // CV to integration time step. @@ -96,6 +98,7 @@ protected: value_type* vec_g_; // CV to cell membrane conductivity. const value_type* temperature_degC_; // CV to temperature. const value_type* diam_um_; // CV to diameter. + const value_type* time_since_spike_; // Vector containing time since last spike, indexed by cell index and n_detectors_ const array* vec_t_ptr_; const array* vec_t_to_ptr_; diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index dfdc70b1..02f4e9f8 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -91,19 +91,25 @@ void ion_state::reset() { shared_state::shared_state( fvm_size_type n_intdom, + fvm_size_type n_cell, + fvm_size_type n_detector, const std::vector<fvm_index_type>& cv_to_intdom_vec, + const std::vector<fvm_index_type>& cv_to_cell_vec, const std::vector<fvm_gap_junction>& gj_vec, const std::vector<fvm_value_type>& init_membrane_potential, const std::vector<fvm_value_type>& temperature_K, const std::vector<fvm_value_type>& diam, + const std::vector<fvm_index_type>& src_to_spike, unsigned align ): alignment(min_alignment(align)), alloc(alignment), n_intdom(n_intdom), + n_detector(n_detector), n_cv(cv_to_intdom_vec.size()), n_gj(gj_vec.size()), cv_to_intdom(math::round_up(n_cv, alignment), pad(alignment)), + cv_to_cell(math::round_up(cv_to_cell_vec.size(), alignment), pad(alignment)), gap_junctions(math::round_up(n_gj, alignment), pad(alignment)), time(n_intdom, pad(alignment)), time_to(n_intdom, pad(alignment)), @@ -115,6 +121,8 @@ shared_state::shared_state( init_voltage(init_membrane_potential.begin(), init_membrane_potential.end(), pad(alignment)), temperature_degC(n_cv, pad(alignment)), diam_um(diam.begin(), diam.end(), pad(alignment)), + time_since_spike(n_cell*n_detector, pad(alignment)), + src_to_spike(src_to_spike.begin(), src_to_spike.end(), pad(alignment)), deliverable_events(n_intdom) { // For indices in the padded tail of cv_to_intdom, set index to last valid intdom index. @@ -122,11 +130,16 @@ shared_state::shared_state( std::copy(cv_to_intdom_vec.begin(), cv_to_intdom_vec.end(), cv_to_intdom.begin()); std::fill(cv_to_intdom.begin() + n_cv, cv_to_intdom.end(), cv_to_intdom_vec.back()); } + if (cv_to_cell_vec.size()) { + std::copy(cv_to_cell_vec.begin(), cv_to_cell_vec.end(), cv_to_cell.begin()); + std::fill(cv_to_cell.begin() + n_cv, cv_to_cell.end(), cv_to_cell_vec.back()); + } if (n_gj>0) { std::copy(gj_vec.begin(), gj_vec.end(), gap_junctions.begin()); std::fill(gap_junctions.begin()+n_gj, gap_junctions.end(), gj_vec.back()); } + util::fill(time_since_spike, -1.0); for (unsigned i = 0; i<n_cv; ++i) { temperature_degC[i] = temperature_K[i] - 273.15; } @@ -148,6 +161,7 @@ void shared_state::reset() { util::fill(conductivity, 0); util::fill(time, 0); util::fill(time_to, 0); + util::fill(time_since_spike, -1.0); for (auto& i: ion_data) { i.second.reset(); diff --git a/arbor/backends/multicore/shared_state.hpp b/arbor/backends/multicore/shared_state.hpp index df146a5c..c36908fd 100644 --- a/arbor/backends/multicore/shared_state.hpp +++ b/arbor/backends/multicore/shared_state.hpp @@ -79,10 +79,12 @@ struct shared_state { util::padded_allocator<> alloc; // Allocator with corresponging alignment/padding. fvm_size_type n_intdom = 0; // Number of integration domains. + fvm_size_type n_detector = 0; // Max number of detectors on all cells. fvm_size_type n_cv = 0; // Total number of CVs. fvm_size_type n_gj = 0; // Total number of GJs. iarray cv_to_intdom; // Maps CV index to integration domain index. + iarray cv_to_cell; // Maps CV index to the first spike gjarray gap_junctions; // Stores gap_junction info. array time; // Maps intdom index to integration start time [ms]. array time_to; // Maps intdom index to integration stop time [ms]. @@ -96,6 +98,9 @@ struct shared_state { array temperature_degC; // Maps CV to local temperature (read only) [°C]. array diam_um; // Maps CV to local diameter (read only) [µm]. + array time_since_spike; // Stores time since last spike on any detector, organized by cell. + iarray src_to_spike; // Maps spike source index to spike index + std::unordered_map<std::string, ion_state> ion_data; deliverable_event_stream deliverable_events; @@ -104,11 +109,15 @@ struct shared_state { shared_state( fvm_size_type n_intdom, + fvm_size_type n_cell, + fvm_size_type n_detector, const std::vector<fvm_index_type>& cv_to_intdom_vec, + const std::vector<fvm_index_type>& cv_to_cell_vec, const std::vector<fvm_gap_junction>& gj_vec, const std::vector<fvm_value_type>& init_membrane_potential, const std::vector<fvm_value_type>& temperature_K, const std::vector<fvm_value_type>& diam, + const std::vector<fvm_index_type>& src_to_spike, unsigned align ); diff --git a/arbor/backends/multicore/stimulus.cpp b/arbor/backends/multicore/stimulus.cpp index 11f6bdf3..d17d2505 100644 --- a/arbor/backends/multicore/stimulus.cpp +++ b/arbor/backends/multicore/stimulus.cpp @@ -24,7 +24,7 @@ public: size_type n = size(); for (size_type i=0; i<n; ++i) { auto cv = node_index_[i]; - auto t = vec_t_[vec_ci_[cv]]; + auto t = vec_t_[vec_di_[cv]]; if (t>=delay[i] && t<delay[i]+duration[i]) { // Amplitudes are given as a current into a compartment, so subtract. diff --git a/arbor/backends/multicore/threshold_watcher.hpp b/arbor/backends/multicore/threshold_watcher.hpp index 7074dda5..9b2d5e53 100644 --- a/arbor/backends/multicore/threshold_watcher.hpp +++ b/arbor/backends/multicore/threshold_watcher.hpp @@ -20,6 +20,7 @@ public: threshold_watcher( const fvm_index_type* cv_to_intdom, const fvm_value_type* values, + const fvm_index_type* src_to_spike, const array* t_before, const array* t_after, const std::vector<fvm_index_type>& cv_index, @@ -28,6 +29,7 @@ public: ): cv_to_intdom_(cv_to_intdom), values_(values), + src_to_spike_(src_to_spike), t_before_ptr_(t_before), t_after_ptr_(t_after), n_cv_(cv_index.size()), @@ -63,24 +65,35 @@ public: /// Tests each target for changed threshold state /// Crossing events are recorded for each threshold that /// is crossed since the last call to test - void test() { + void test(array* time_since_spike) { + // Reset all spike times to -1.0 indicating no spike has been recorded on the detector const fvm_value_type* t_before = t_before_ptr_->data(); const fvm_value_type* t_after = t_after_ptr_->data(); for (fvm_size_type i = 0; i<n_cv_; ++i) { auto cv = cv_index_[i]; - auto cell = cv_to_intdom_[cv]; + auto intdom = cv_to_intdom_[cv]; auto v_prev = v_prev_[i]; auto v = values_[cv]; auto thresh = thresholds_[i]; + fvm_index_type spike_idx = 0; + + if (!time_since_spike->empty()) { + spike_idx = src_to_spike_[i]; + (*time_since_spike)[spike_idx] = -1.0; + } if (!is_crossed_[i]) { if (v>=thresh) { // The threshold has been passed, so estimate the time using // linear interpolation. auto pos = (thresh - v_prev)/(v - v_prev); - auto crossing_time = math::lerp(t_before[cell], t_after[cell], pos); + auto crossing_time = math::lerp(t_before[intdom], t_after[intdom], pos); crossings_.push_back({i, crossing_time}); + if (!time_since_spike->empty()) { + (*time_since_spike)[spike_idx] = t_after[intdom] - crossing_time; + } + is_crossed_[i] = true; } } @@ -109,6 +122,7 @@ private: /// and pointers to the time arrays const fvm_index_type* cv_to_intdom_ = nullptr; const fvm_value_type* values_ = nullptr; + const fvm_index_type* src_to_spike_ = nullptr; const array* t_before_ptr_ = nullptr; const array* t_after_ptr_ = nullptr; diff --git a/arbor/fvm_layout.cpp b/arbor/fvm_layout.cpp index 2125fb1e..e97a6e4d 100644 --- a/arbor/fvm_layout.cpp +++ b/arbor/fvm_layout.cpp @@ -751,6 +751,8 @@ fvm_mechanism_data& append(fvm_mechanism_data& left, const fvm_mechanism_data& r } left.n_target += right.n_target; + left.post_events |= right.post_events; + append_divs(left.target_divs, right.target_divs); arb_assert(left.n_target==left.target_divs.back()); @@ -945,10 +947,13 @@ fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties& std::vector<synapse_instance> inst_list; std::vector<size_type> cv_order; + bool post_events = false; + for (const auto& entry: cell.synapses()) { const std::string& name = entry.first; mechanism_info info = catalogue[name]; + post_events |= info.post_events; std::size_t n_param = info.parameters.size(); std::size_t n_inst = entry.second.size(); @@ -1065,6 +1070,7 @@ fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties& M.n_target += config.target.size(); M.mechanisms[name] = std::move(config); } + M.post_events = post_events; // Stimuli: diff --git a/arbor/fvm_layout.hpp b/arbor/fvm_layout.hpp index 69557bc0..cdc7c4da 100644 --- a/arbor/fvm_layout.hpp +++ b/arbor/fvm_layout.hpp @@ -280,6 +280,9 @@ struct fvm_mechanism_data { // Partitions target numbers by cell. std::vector<std::size_t> target_divs; + + // Contains mechanisms with post_event + bool post_events = false; }; fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties& gprop, const std::vector<cable_cell>& cells, const fvm_cv_discretization& D, const arb::execution_context& ctx={}); diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index 5dd419c3..d5126cf0 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -107,7 +107,10 @@ private: std::vector<mechanism_ptr> revpot_mechanisms_; // Non-physical voltage check threshold, 0 => no check. - value_type check_voltage_mV = 0; + value_type check_voltage_mV_ = 0; + + // Flag indicating that at least one of the mechanisms implements the post_events procedure + bool post_events_; // Host-side views/copies and local state. decltype(backend::host_view(sample_time_)) sample_time_host_; @@ -231,7 +234,6 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( m->update_current(); } - // Deliver events and accumulate mechanism current contributions. PE(advance_integrate_events); @@ -291,15 +293,24 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( // Update time and test for spike threshold crossings. PE(advance_integrate_threshold); - threshold_watcher_.test(); - std::swap(state_->time_to, state_->time); + threshold_watcher_.test(&state_->time_since_spike); + PL(); + + PE(advance_integrate_post) + if (post_events_) { + for (auto& m: mechanisms_) { + m->post_event(); + } + } PL(); + std::swap(state_->time_to, state_->time); + // Check for non-physical solutions: - if (check_voltage_mV>0) { + if (check_voltage_mV_>0) { PE(advance_integrate_physicalcheck); - assert_voltage_bounded(check_voltage_mV); + assert_voltage_bounded(check_voltage_mV_); PL(); } @@ -393,6 +404,26 @@ void fvm_lowered_cell_impl<Backend>::initialize( // (Throws cable_cell_error on failure.) check_global_properties(global_props); + // Sanity check recipe; find max num_sources and + // create a list of the global identifiers for the spike sources + + std::vector<fvm_size_type> nsources; + for (auto cell_idx: make_span(ncell)) { + cell_gid_type gid = gids[cell_idx]; + auto& cell = cells[cell_idx]; + + auto num_sources = rec.num_sources(gid); + nsources.push_back(num_sources); + + if (num_sources != cell.detectors().size()) { + throw arb::bad_source_description(gid, num_sources, cell.detectors().size()); + } + auto cell_targets = util::sum_by(cell.synapses(), [](auto& syn) { return syn.second.size(); }); + if (rec.num_targets(gid) > cell_targets) { + throw arb::bad_target_description(gid, rec.num_targets(gid), cell_targets); + } + } + const mechanism_catalogue* catalogue = global_props.catalogue; // Mechanism instantiator helper. @@ -403,9 +434,9 @@ void fvm_lowered_cell_impl<Backend>::initialize( // Check for physically reasonable membrane volages? - check_voltage_mV = global_props.membrane_voltage_limit_mV; + check_voltage_mV_ = global_props.membrane_voltage_limit_mV; - auto num_intdoms = fvm_intdom(rec, gids, cell_to_intdom); + auto nintdom = fvm_intdom(rec, gids, cell_to_intdom); // Discretize cells, build matrix. @@ -418,7 +449,7 @@ void fvm_lowered_cell_impl<Backend>::initialize( arb_assert(D.n_cell() == ncell); matrix_ = matrix<backend>(D.geometry.cv_parent, D.geometry.cell_cv_divs, D.cv_capacitance, D.face_conductance, D.cv_area, cell_to_intdom); - sample_events_ = sample_event_stream(num_intdoms); + sample_events_ = sample_event_stream(nintdom); // Discretize mechanism data. @@ -428,6 +459,21 @@ void fvm_lowered_cell_impl<Backend>::initialize( auto gj_vector = fvm_gap_junctions(cells, gids, rec, D); + // Fill src_to_spike and cv_to_cell vectors only if mechanisms with post_events implemented are present. + post_events_ = mech_data.post_events; + auto max_detector = post_events_ ? util::max_value(nsources) : 0; + std::vector<fvm_index_type> src_to_spike, cv_to_cell; + + if (post_events_) { + for (auto cell_idx: make_span(ncell)) { + for (auto lid: make_span(nsources[cell_idx])) { + src_to_spike.push_back(cell_idx * max_detector + lid); + } + } + src_to_spike.shrink_to_fit(); + cv_to_cell = D.geometry.cv_to_cell; + } + // Create shared cell state. // (SIMD padding requires us to check each mechanism for alignment/padding constraints.) @@ -436,7 +482,8 @@ void fvm_lowered_cell_impl<Backend>::initialize( [&](const std::string& name) { return mech_instance(name).mech->data_alignment(); })); state_ = std::make_unique<shared_state>( - num_intdoms, cv_to_intdom, gj_vector, D.init_membrane_potential, D.temperature_K, D.diam_um, + nintdom, ncell, max_detector, cv_to_intdom, std::move(cv_to_cell), gj_vector, + D.init_membrane_potential, D.temperature_K, D.diam_um, std::move(src_to_spike), data_alignment? data_alignment: 1u); // Instantiate mechanisms and ions. @@ -534,16 +581,6 @@ void fvm_lowered_cell_impl<Backend>::initialize( for (auto cell_idx: make_span(ncell)) { cell_gid_type gid = gids[cell_idx]; - // Sanity check recipe - auto& cell = cells[cell_idx]; - if (rec.num_sources(gid) != cell.detectors().size()) { - throw arb::bad_source_description(gid, rec.num_sources(gid), cell.detectors().size());; - } - auto cell_targets = util::sum_by(cell.synapses(), [](auto& syn) {return syn.second.size();}); - if (rec.num_targets(gid) > cell_targets) { - throw arb::bad_target_description(gid, rec.num_targets(gid), cell_targets); - } - // Collect detectors, probe handles. for (auto entry: cells[cell_idx].detectors()) { detector_cv.push_back(D.geometry.location_cv(cell_idx, entry.loc, cv_prefer::cv_empty)); diff --git a/arbor/include/arbor/mechanism.hpp b/arbor/include/arbor/mechanism.hpp index 5e87a441..c1361f30 100644 --- a/arbor/include/arbor/mechanism.hpp +++ b/arbor/include/arbor/mechanism.hpp @@ -54,6 +54,7 @@ public: virtual void update_state() {}; virtual void update_current() {}; virtual void deliver_events() {}; + virtual void post_event() {}; virtual void update_ions() {}; virtual ~mechanism() = default; diff --git a/arbor/include/arbor/mechinfo.hpp b/arbor/include/arbor/mechinfo.hpp index f3cc748a..45785e31 100644 --- a/arbor/include/arbor/mechinfo.hpp +++ b/arbor/include/arbor/mechinfo.hpp @@ -69,6 +69,8 @@ struct mechanism_info { mechanism_fingerprint fingerprint; bool linear = false; + + bool post_events = false; }; } // namespace arb diff --git a/doc/internals/nmodl.rst b/doc/internals/nmodl.rst index 84474d31..4757a848 100644 --- a/doc/internals/nmodl.rst +++ b/doc/internals/nmodl.rst @@ -31,7 +31,7 @@ Ions must be added explicitly in Arbor along with their default properties and valence (this can be done in the recipe or on a single cell model). Simply specifying them in NMODL will not work. -* The parameters and variabnles of each ion referenced in a ``USEION`` statement +* The parameters and variables of each ion referenced in a ``USEION`` statement are available automatically to the mechanism. The exposed variables are: internal concentration ``Xi``, external concentration ``Xo``, reversal potential ``eX`` and current ``iX``. It is an error to also mark these as @@ -42,18 +42,19 @@ Ions * If ``Xi``, ``Xo``, ``eX``, ``iX`` are used in a ``PROCEDURE`` or ``FUNCTION``, they need to be passed as arguments. * If ``Xi`` or ``Xo`` (internal and external concentrations) are written in the - NMODL mechanism they need to be specified as ``STATE`` variables. + NMODL mechanism they need to be declared as ``STATE`` variables and their initial + values have to be set in the mechanism. Special variables ----------------- * Arbor exposes some parameters from the simulation to the NMODL mechanisms. - These include ``v``, ``diam``, ``celsius`` in addition to the previously + These include ``v``, ``diam``, ``celsius`` and ``t`` in addition to the previously mentioned ion parameters. * Special variables should not be ``ASSIGNED`` or ``CONSTANT``, they are ``PARAMETER``. * ``diam`` and ``celsius`` can be set from the simulation side. -* ``v`` is a reserved varible name and can be written in NMODL. +* ``v`` is a reserved variable name and can be written in NMODL. * If Special variables are used in a ``PROCEDURE`` or ``FUNCTION``, they need to be passed as arguments. * ``dt`` is not exposed to NMODL mechanisms. @@ -76,4 +77,24 @@ Unsupported features However, ``CONSERVE`` statements are supported. * ``TABLE`` is not supported, calculations are exact. * ``derivimplicit`` solving method is not supported, use ``cnexp`` instead. +* `verbatim` blocks are not supported. +Arbor-specific features +----------------------- + +* Arbor's NMODL dialect supports the most widely used features of NEURON. It also + has some features unavailable in NEURON such as the ``POST_EVENT`` procedure block. + This procedure has a single argument representing the time since the last spike on + the cell. In the event of multiple detectors on the cell, and multiple spikes on the + detectors within the same integration period, the times of each of these spikes will + be processed by the ``POST_EVENT`` block. Spikes are processed only once and then + cleared. + + Example of a ``POST_EVENT`` procedure, where ``g`` is a ``STATE`` parameter representing + the conductance: + + .. code:: + + POST_EVENT(t) { + g = g + (0.1*t) + } \ No newline at end of file diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 51073654..a2eb8daa 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -31,6 +31,8 @@ inline std::string to_string(procedureKind k) { return "initial"; case procedureKind::net_receive : return "net_receive"; + case procedureKind::post_event : + return "post_event"; case procedureKind::breakpoint : return "breakpoint"; case procedureKind::derivative : @@ -715,6 +717,41 @@ void NetReceiveExpression::semantic(scope_type::symbol_map &global_symbols) { symbol_ = scope_->find_global(name()); } +/******************************************************************************* + PostEventExpression +*******************************************************************************/ + +void PostEventExpression::semantic(scope_type::symbol_map &global_symbols) { + // assert that the symbol is already visible in the global_symbols + if(global_symbols.find(name()) == global_symbols.end()) { + throw compiler_exception( + "attempt to perform semantic analysis for procedure '" + + yellow(name()) + + "' which has not been added to global symbol table", + location_); + } + + // create the scope for this procedure + scope_ = std::make_shared<scope_type>(global_symbols); + error_ = false; + + // add the argumemts to the list of local variables + for(auto& a : args_) { + a->semantic(scope_); + if(a->has_error()) { + error(a->error_message()); + } + } + + // perform semantic analysis for each expression in the body + body_->semantic(scope_); + if(body_->has_error()) { + error(body_->error_message()); + } + + symbol_ = scope_->find_global(name()); +} + /******************************************************************************* FunctionExpression *******************************************************************************/ @@ -1105,6 +1142,9 @@ void ProcedureExpression::accept(Visitor *v) { void NetReceiveExpression::accept(Visitor *v) { v->visit(this); } +void PostEventExpression::accept(Visitor *v) { + v->visit(this); +} void APIMethod::accept(Visitor *v) { v->visit(this); } diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 50c5c1b1..b208ac4a 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -47,6 +47,7 @@ class PDiffExpression; class VariableExpression; class ProcedureExpression; class NetReceiveExpression; +class PostEventExpression; class APIMethod; class IndexedVariable; class LocalVariable; @@ -79,6 +80,7 @@ enum class procedureKind { api, ///< API PROCEDURE initial, ///< INITIAL net_receive, ///< NET_RECEIVE + post_event, ///< POST_EVENT breakpoint, ///< BREAKPOINT kinetic, ///< KINETIC derivative, ///< DERIVATIVE @@ -234,6 +236,7 @@ public : virtual VariableExpression* is_variable() {return nullptr;} virtual ProcedureExpression* is_procedure() {return nullptr;} virtual NetReceiveExpression* is_net_receive() {return nullptr;} + virtual PostEventExpression* is_post_event() {return nullptr;} virtual APIMethod* is_api_method() {return nullptr;} virtual IndexedVariable* is_indexed_variable() {return nullptr;} virtual LocalVariable* is_local_variable() {return nullptr;} @@ -1120,6 +1123,24 @@ protected: InitialBlock* initial_block_ = nullptr; }; +/// handle PostEventExpression as a special case of ProcedureExpression +class PostEventExpression : public ProcedureExpression { +public: + PostEventExpression( Location loc, + std::string name, + std::vector<expression_ptr>&& args, + expression_ptr&& body) + : ProcedureExpression(loc, std::move(name), std::move(args), std::move(body), procedureKind::post_event) + {} + + void semantic(scope_type::symbol_map &scp) override; + PostEventExpression* is_post_event() override {return this;} + /// hard code the kind + procedureKind kind() {return procedureKind::post_event;} + + void accept(Visitor *v) override; +}; + class FunctionExpression : public Symbol { public: FunctionExpression( Location loc, diff --git a/modcc/module.cpp b/modcc/module.cpp index a9b86b6d..d099ef23 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -537,6 +537,8 @@ bool Module::semantic() { } linear_ = linear; + post_events_ = has_symbol("post_event", symbolKind::procedure); + // Are we writing an ionic reversal potential? If so, change the moduleKind to // `revpot` and assert that the mechanism is 'pure': it has no state variables; // it contributes to no currents, ionic or otherwise; it isn't a point mechanism; diff --git a/modcc/module.hpp b/modcc/module.hpp index 09e29d40..f29ddd06 100644 --- a/modcc/module.hpp +++ b/modcc/module.hpp @@ -104,6 +104,7 @@ public: }; bool is_linear() const { return linear_; } + bool has_post_events() const { return post_events_; } private: moduleKind kind_; @@ -118,6 +119,7 @@ private: ParameterBlock parameter_block_; AssignedBlock assigned_block_; bool linear_; + bool post_events_; // AST storage. std::vector<symbol_ptr> callables_; diff --git a/modcc/parser.cpp b/modcc/parser.cpp index d52f9416..86ae2acc 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -108,6 +108,7 @@ bool Parser::parse() { case tok::net_receive: case tok::breakpoint: case tok::initial: + case tok::post_event: case tok::kinetic: case tok::linear: case tok::derivative: @@ -932,6 +933,10 @@ symbol_ptr Parser::parse_procedure() { kind = procedureKind::net_receive; p = parse_prototype("net_receive"); break; + case tok::post_event: + kind = procedureKind::post_event; + p = parse_prototype("post_event"); + break; default: // it is a compiler error if trying to parse_procedure() without // having DERIVATIVE, KINETIC, PROCEDURE, INITIAL or BREAKPOINT keyword @@ -949,12 +954,13 @@ symbol_ptr Parser::parse_procedure() { if (body == nullptr) return nullptr; auto proto = p->is_prototype(); - if (kind != procedureKind::net_receive) { - return make_symbol<ProcedureExpression>(proto->location(), proto->name(), std::move(proto->args()), std::move(body), kind); + if(kind == procedureKind::net_receive) { + return make_symbol<NetReceiveExpression> (proto->location(), proto->name(), std::move(proto->args()), std::move(body)); } - else { - return make_symbol<NetReceiveExpression>(proto->location(), proto->name(), std::move(proto->args()), std::move(body)); + if(kind == procedureKind::post_event) { + return make_symbol<PostEventExpression> (proto->location(), proto->name(), std::move(proto->args()), std::move(body)); } + return make_symbol<ProcedureExpression> (proto->location(), proto->name(), std::move(proto->args()), std::move(body), kind); } symbol_ptr Parser::parse_function() { diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 0e333686..23ecf2f0 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -108,6 +108,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { auto ns_components = namespace_components(opt.cpp_namespace); NetReceiveExpression* net_receive = find_net_receive(module_); + PostEventExpression* post_event = find_post_event(module_); APIMethod* init_api = find_api_method(module_, "nrn_init"); APIMethod* state_api = find_api_method(module_, "nrn_state"); APIMethod* current_api = find_api_method(module_, "nrn_current"); @@ -254,6 +255,9 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "void deliver_events(deliverable_event_stream::state events) override;\n" "void net_receive(int i_, value_type weight);\n"; + post_event && out << + "void post_event() override;\n"; + with_simd && out << "unsigned simd_width() const override { return simd_width_; }\n"; out << @@ -388,6 +392,25 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "}\n\n"; } + if(post_event) { + const std::string time_arg = post_event->args().empty() ? "time" : post_event->args().front()->is_argument()->name(); + out << + "void " << class_name << "::post_event() {\n" << indent << + "int n_ = width_;\n" + "for (int i_ = 0; i_ < n_; ++i_) {\n" << indent << + "auto node_index_i_ = node_index_[i_];\n" + "auto cid_ = vec_ci_[node_index_i_];\n" + "auto offset_ = n_detectors_ * cid_;\n" + "for (unsigned c = 0; c < n_detectors_; c++) {\n" << indent << + "auto " << time_arg << " = time_since_spike_[offset_ + c];\n" + "if (" << time_arg << " >= 0) {\n" << indent << + cprint(post_event->body()) << popindent << + "}\n" << popindent << + "}\n" << popindent << + "}\n" << popindent << + "}\n\n"; + } + auto emit_body = [&](APIMethod *p) { if (with_simd) { emit_simd_api_body(out, p, vars.scalars); diff --git a/modcc/printer/gpuprinter.cpp b/modcc/printer/gpuprinter.cpp index b0f67d14..31148051 100644 --- a/modcc/printer/gpuprinter.cpp +++ b/modcc/printer/gpuprinter.cpp @@ -55,6 +55,7 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op auto ns_components = namespace_components(opt.cpp_namespace); NetReceiveExpression* net_receive = find_net_receive(module_); + PostEventExpression* post_event = find_post_event(module_); auto vars = local_module_variables(module_); auto ion_deps = module_.ion_deps(); @@ -85,6 +86,10 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op "void " << class_name << "_deliver_events_(int mech_id, " << ppack_name << "&, deliverable_event_stream_state events);\n"; + post_event && out << + "void " << class_name << "_post_event_(" << ppack_name << "&);\n"; + + out << "\n" "class " << class_name << ": public ::arb::gpu::mechanism {\n" @@ -115,6 +120,11 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op class_name << "_deliver_events_(mechanism_id_, pp_, events);\n" << popindent << "}\n\n"; + post_event && out << + "void post_event() override {\n" << indent << + class_name << "_post_event_(pp_);\n" << popindent << + "}\n\n"; + out << popindent << "protected:\n" << indent << "std::size_t object_sizeof() const override { return sizeof(*this); }\n" @@ -215,6 +225,7 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt const bool is_point_proc = module_.kind() == moduleKind::point; NetReceiveExpression* net_receive = find_net_receive(module_); + PostEventExpression* post_event = find_post_event(module_); APIMethod* init_api = find_api_method(module_, "nrn_init"); APIMethod* state_api = find_api_method(module_, "nrn_state"); APIMethod* current_api = find_api_method(module_, "nrn_current"); @@ -314,6 +325,27 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt << popindent << "}\n"; } + // event delivery + if (post_event) { + const std::string time_arg = post_event->args().empty() ? "time" : post_event->args().front()->is_argument()->name(); + out << "__global__\n" + << "void post_event(" << ppack_name << " params_) {\n" << indent + << "int n_ = params_.width_;\n" + << "auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n" + << "if (tid_<n_) {\n" << indent + << "auto node_index_i_ = params_.node_index_[tid_];\n" + << "auto cid_ = params_.vec_ci_[node_index_i_];\n" + << "auto offset_ = params_.n_detectors_ * cid_;\n" + << "for (unsigned c = 0; c < params_.n_detectors_; c++) {\n" << indent + << "auto " << time_arg << " = params_.time_since_spike_[offset_ + c];\n" + << "if (" << time_arg << " >= 0) {\n" << indent + << cuprint(post_event->body()) + << popindent << "}\n" + << popindent << "}\n" + << popindent << "}\n" + << popindent << "}\n"; + } + out << "} // namspace\n\n"; // close anonymous namespace // Write wrappers. @@ -346,6 +378,15 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt << "deliver_events<<<grid_dim, block_dim>>>(mech_id, p, events);\n" << popindent << "}\n\n"; + post_event && out + << "void " << class_name << "_post_event_(" + << ppack_name << "& p) {\n" << indent + << "auto n = p.width_;\n" + << "unsigned block_dim = 128;\n" + << "unsigned grid_dim = ::arb::gpu::impl::block_count(n, block_dim);\n" + << "post_event<<<grid_dim, block_dim>>>(p);\n" + << popindent << "}\n\n"; + out << namespace_declaration_close(ns_components); return out.str(); } diff --git a/modcc/printer/infoprinter.cpp b/modcc/printer/infoprinter.cpp index a5dcf604..7b4ea892 100644 --- a/modcc/printer/infoprinter.cpp +++ b/modcc/printer/infoprinter.cpp @@ -131,7 +131,9 @@ std::string build_info_header(const Module& m, const printer_options& opt) { "// fingerprint\n" << quote(fingerprint) << ",\n" "// linear, homogeneous mechanism\n" - << m.is_linear() << "\n" + << m.is_linear() << ",\n" + "// post_events enabled mechanism\n" + << m.has_post_events() << "\n" << popindent << "};\n" "\n" "return info;\n" diff --git a/modcc/printer/printerutil.cpp b/modcc/printer/printerutil.cpp index 787a8ebe..5309e528 100644 --- a/modcc/printer/printerutil.cpp +++ b/modcc/printer/printerutil.cpp @@ -45,7 +45,8 @@ std::vector<ProcedureExpression*> normal_procedures(const Module& m) { for (auto& sym: m.symbols()) { auto proc = sym.second->is_procedure(); - if (proc && proc->kind()==procedureKind::normal && !proc->is_api_method() && !proc->is_net_receive()) { + if (proc && proc->kind()==procedureKind::normal && !proc->is_api_method() + && !proc->is_net_receive() && !proc->is_post_event()) { procs.push_back(proc); } } @@ -113,6 +114,11 @@ NetReceiveExpression* find_net_receive(const Module& m) { return it==m.symbols().end()? nullptr: it->second->is_net_receive(); } +PostEventExpression* find_post_event(const Module& m) { + auto it = m.symbols().find("post_event"); + return it==m.symbols().end()? nullptr: it->second->is_post_event(); +} + indexed_variable_info decode_indexed_variable(IndexedVariable* sym) { indexed_variable_info v; v.node_index_var = "node_index_"; @@ -157,7 +163,7 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) { break; case sourceKind::time: v.data_var = "vec_t_"; - v.cell_index_var = "vec_ci_"; + v.cell_index_var = "vec_di_"; v.readonly = true; break; case sourceKind::ion_current_density: diff --git a/modcc/printer/printerutil.hpp b/modcc/printer/printerutil.hpp index 4fbae528..902290a1 100644 --- a/modcc/printer/printerutil.hpp +++ b/modcc/printer/printerutil.hpp @@ -114,6 +114,8 @@ APIMethod* find_api_method(const Module& m, const char* which); NetReceiveExpression* find_net_receive(const Module& m); +PostEventExpression* find_post_event(const Module& m); + struct indexed_variable_info { std::string data_var; std::string node_index_var; diff --git a/modcc/token.cpp b/modcc/token.cpp index 30ed8def..e2226c7e 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -36,6 +36,7 @@ static Keyword keywords[] = { {"FUNCTION", tok::function}, {"INITIAL", tok::initial}, {"NET_RECEIVE", tok::net_receive}, + {"POST_EVENT", tok::post_event}, {"UNITSOFF", tok::unitsoff}, {"UNITSON", tok::unitson}, {"SUFFIX", tok::suffix}, @@ -117,6 +118,7 @@ static TokenString token_strings[] = { {"FUNCTION", tok::function}, {"INITIAL", tok::initial}, {"NET_RECEIVE", tok::net_receive}, + {"POST_EVENT", tok::post_event}, {"UNITSOFF", tok::unitsoff}, {"UNITSON", tok::unitson}, {"SUFFIX", tok::suffix}, diff --git a/modcc/token.hpp b/modcc/token.hpp index 21648dcd..e3a7a12c 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -54,7 +54,7 @@ enum class tok { neuron, units, parameter, constant, assigned, state, breakpoint, derivative, kinetic, procedure, initial, function, linear, - net_receive, + net_receive, post_event, // keywoards inside blocks unitsoff, unitson, diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 04075261..dfda9efc 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -4,10 +4,14 @@ set(test_mechanisms ca_linear celsius_test diam_test + fixed_ica_current + linear_ca_conc + non_linear param_as_state - test_linear_state - test_linear_init - test_linear_init_shuffle + point_ica_current + post_events_syn + read_cai_init + read_eX test0_kin_diff test0_kin_conserve test0_kin_compartment @@ -15,25 +19,22 @@ set(test_mechanisms test1_kin_diff test1_kin_conserve test1_kin_compartment + test1_kin_steadystate test2_kin_diff test3_kin_diff test4_kin_compartment - test1_kin_steadystate - fixed_ica_current - point_ica_current - non_linear - linear_ca_conc - test_cl_valence - test_ca_read_valence - read_eX - write_Xi_Xo - write_multiple_eX - write_eX - read_cai_init - write_cai_breakpoint test_ca + test_ca_read_valence + test_cl_valence + test_linear_state + test_linear_init + test_linear_init_shuffle test_kin1 test_kinlva + write_cai_breakpoint + write_eX + write_multiple_eX + write_Xi_Xo ) include(${PROJECT_SOURCE_DIR}/mechanisms/BuildModules.cmake) diff --git a/test/unit/mod/post_events_syn.mod b/test/unit/mod/post_events_syn.mod new file mode 100644 index 00000000..a0292b12 --- /dev/null +++ b/test/unit/mod/post_events_syn.mod @@ -0,0 +1,35 @@ +NEURON { + POINT_PROCESS post_events_syn + RANGE tau, e + NONSPECIFIC_CURRENT i +} + +PARAMETER { + tau = 2.0 (ms) : the default for Neuron is 0.1 + e = 0 (mV) +} + +STATE { + g +} + +INITIAL { + g=0 +} + +BREAKPOINT { + SOLVE state METHOD cnexp + i = g*(v - e) +} + +DERIVATIVE state { + g' = -g/tau +} + +NET_RECEIVE(weight) { + g = g + weight +} + +POST_EVENT(t) { + g = g + (0.1*t) +} diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index e95018ac..82b6bba4 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -604,6 +604,7 @@ TEST(fvm_lowered, ionic_concentrations) { std::vector<fvm_value_type> diam(ncv, 1.); std::vector<fvm_value_type> vinit(ncv, -65); std::vector<fvm_gap_junction> gj = {}; + std::vector<fvm_index_type> src_to_spike = {}; fvm_ion_config ion_config; mechanism_layout layout; @@ -627,7 +628,7 @@ TEST(fvm_lowered, ionic_concentrations) { auto& write_cai_mech = write_cai.mech; auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, diam, read_cai_mech->data_alignment()); + ncell, ncell, 0, cv_to_intdom, cv_to_intdom, gj, vinit, temp, diam, src_to_spike, read_cai_mech->data_alignment()); shared_state->add_ion("ca", 2, ion_config); read_cai_mech->instantiate(0, *shared_state, overrides, layout); @@ -1207,3 +1208,124 @@ TEST(fvm_lowered, integration_domains) { } } +TEST(fvm_lowered, post_events_shared_state) { + arb::proc_allocation resources; + if (auto nt = arbenv::get_env_num_threads()) { + resources.num_threads = nt; + } else { + resources.num_threads = arbenv::thread_concurrency(); + } + arb::execution_context context(resources); + + class detector_recipe: public arb::recipe { + public: + detector_recipe(unsigned ncv, std::vector<unsigned> detectors_per_cell, std::string synapse): + ncell_(detectors_per_cell.size()), + ncv_(ncv), + detectors_per_cell_(detectors_per_cell), + synapse_(synapse), + cat_(make_unit_test_catalogue()) { + const auto default_cat = arb::global_default_catalogue(); + cat_.import(default_cat, ""); + } + + cell_size_type num_cells() const override { + return ncell_; + } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + arb::segment_tree tree; + tree.append(arb::mnpos, {0, 0, 0.0, 1.0}, {0, 0, 200, 1.0}, 1); + + arb::decor decor; + decor.set_default(arb::cv_policy_fixed_per_branch(ncv_)); + + auto ndetectors = detectors_per_cell_[gid]; + auto offset = 1.0 / ndetectors; + for (unsigned i = 0; i < ndetectors; ++i) { + decor.place(arb::mlocation{0, offset * i}, arb::threshold_detector{10}); + } + decor.place(arb::mlocation{0, 0.5}, synapse_); + + return arb::cable_cell(arb::morphology(tree), {}, decor);; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable; + } + + // Each cell has one spike detector (at the soma). + cell_size_type num_sources(cell_gid_type gid) const override { + return detectors_per_cell_[gid]; + } + + // The cell has one target synapse, which will be connected to cell gid-1. + cell_size_type num_targets(cell_gid_type gid) const override { + return 1; + } + + std::any get_global_properties(arb::cell_kind) const override { + arb::cable_cell_global_properties gprop; + gprop.default_parameters = arb::neuron_parameter_defaults; + gprop.catalogue = &cat_; + return gprop; + } + + private: + unsigned ncell_; + unsigned ncv_; + std::vector<unsigned> detectors_per_cell_; + arb::mechanism_desc synapse_; + mechanism_catalogue cat_; + }; + + std::vector<target_handle> targets; + probe_association_map probe_map; + + std::vector<unsigned> gids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + const unsigned ncell = gids.size(); + const unsigned cv_per_cell = 10; + + std::vector<std::vector<unsigned>> detectors_per_cell_vec = { + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 6, 2, 1, 3, 2, 1, 2, 1, 4}, + }; + + for (const auto& detectors_per_cell: detectors_per_cell_vec) { + detector_recipe rec(cv_per_cell, detectors_per_cell, "post_events_syn"); + std::vector<fvm_index_type> cell_to_intdom; + + fvm_cell fvcell(context); + fvcell.initialize(gids, rec, cell_to_intdom, targets, probe_map); + + auto& S = fvcell.*private_state_ptr; + + auto expected_detectors = util::max_value(detectors_per_cell); + + EXPECT_EQ(expected_detectors, S->n_detector); + EXPECT_EQ(util::sum(detectors_per_cell), S->src_to_spike.size()); + EXPECT_EQ(expected_detectors * ncell, S->time_since_spike.size()); + + unsigned detector_id = 0; + for (unsigned c = 0; c < detectors_per_cell.size(); ++c) { + for (unsigned d = 0; d < detectors_per_cell[c]; ++d) { + EXPECT_EQ((int) (c * expected_detectors + d), S->src_to_spike[detector_id++]); + } + } + } + for (const auto& detectors_per_cell: detectors_per_cell_vec) { + detector_recipe rec(cv_per_cell, detectors_per_cell, "expsyn"); + std::vector<fvm_index_type> cell_to_intdom; + + fvm_cell fvcell(context); + fvcell.initialize(gids, rec, cell_to_intdom, targets, probe_map); + + auto& S = fvcell.*private_state_ptr; + + EXPECT_EQ(0u, S->n_detector); + EXPECT_EQ(0u, S->src_to_spike.size()); + EXPECT_EQ(0u, S->time_since_spike.size()); + } + +} diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp index feb1f8f6..b45731a2 100644 --- a/test/unit/test_kinetic_linear.cpp +++ b/test/unit/test_kinetic_linear.cpp @@ -45,9 +45,10 @@ void run_test(std::string mech_name, std::vector<fvm_value_type> temp(ncv, 300.); std::vector<fvm_value_type> diam(ncv, 1.); std::vector<fvm_value_type> vinit(ncv, -65); + std::vector<fvm_index_type> src_to_spike = {}; auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, diam, test->data_alignment()); + ncell, ncell, 0, cv_to_intdom, cv_to_intdom, gj, vinit, temp, diam, src_to_spike, test->data_alignment()); mechanism_layout layout; mechanism_overrides overrides; diff --git a/test/unit/test_mech_temp_diam.cpp b/test/unit/test_mech_temp_diam.cpp index a40f8cfc..ba051960 100644 --- a/test/unit/test_mech_temp_diam.cpp +++ b/test/unit/test_mech_temp_diam.cpp @@ -34,9 +34,10 @@ void run_celsius_test() { std::vector<fvm_value_type> temp(ncv, temperature_K); std::vector<fvm_value_type> diam(ncv, 1.); std::vector<fvm_value_type> vinit(ncv, -65); + std::vector<fvm_index_type> src_to_spike = {}; auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, diam, celsius_test->data_alignment()); + ncell, ncell, 0, cv_to_intdom, cv_to_intdom, gj, vinit, temp, diam, src_to_spike, celsius_test->data_alignment()); mechanism_layout layout; mechanism_overrides overrides; @@ -81,6 +82,7 @@ void run_diam_test() { std::vector<fvm_value_type> temp(ncv, 300.); std::vector<fvm_value_type> vinit(ncv, -65); std::vector<fvm_value_type> diam(ncv); + std::vector<fvm_index_type> src_to_spike = {}; mechanism_layout layout; mechanism_overrides overrides; @@ -93,7 +95,7 @@ void run_diam_test() { } auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, diam, celsius_test->data_alignment()); + ncell, ncell, 0, cv_to_intdom, cv_to_intdom, gj, vinit, temp, diam, src_to_spike, celsius_test->data_alignment()); celsius_test->instantiate(0, *shared_state, overrides, layout); diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp index dd1f8d1e..d207f586 100644 --- a/test/unit/test_spikes.cpp +++ b/test/unit/test_spikes.cpp @@ -42,6 +42,10 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { const std::vector<index_type> index{0, 5, 7}; const std::vector<value_type> thresh{1., 2., 3.}; + std::vector<int> src_to_spike_vec = {0, 1, 5}; + std::vector<fvm_value_type> time_since_spike_vec(10); + memory::fill(time_since_spike_vec, -1.0); + // all values are initially 0, except for values[5] which we set // to exceed the threshold of 2. for the second watch array values(n, 0); @@ -57,11 +61,18 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { array time_before(2, 0.); array time_after(2, 0.); + iarray src_to_spike(src_to_spike_vec.size()); + memory::copy(src_to_spike_vec, src_to_spike); + + array time_since_spike(10, -1.0); + std::vector<unsigned> empty_slots = {2, 3, 4, 6, 7, 8, 9}; + // list for storing expected crossings for validation at the end list expected; // create the watch - backend::threshold_watcher watch(cell_index.data(), values.data(), &time_before, &time_after, index, thresh, context); + backend::threshold_watcher watch(cell_index.data(), values.data(), src_to_spike.data(), + &time_before, &time_after, index, thresh, context); // initially the first and third watch should not be spiking // the second is spiking @@ -72,18 +83,23 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { // test again at t=1, with unchanged values // - nothing should change memory::fill(time_after, 1.); - watch.test(); + watch.test(&time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_TRUE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); EXPECT_EQ(watch.crossings().size(), 0u); + memory::copy(time_since_spike, time_since_spike_vec); + for (auto t: time_since_spike_vec) { + EXPECT_EQ(-1.0, t); + } + // test at t=2, with all values set to zero // - 2nd watch should now stop spiking memory::fill(values, 0.); memory::copy(time_after, time_before); memory::fill(time_after, 2.); - watch.test(); + watch.test(&time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); @@ -95,7 +111,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { memory::copy(time_after, time_before); time_after[0] = 2.5; time_after[1] = 3.0; - watch.test(); + watch.test(&time_since_spike); EXPECT_TRUE(watch.is_crossed(0)); EXPECT_TRUE(watch.is_crossed(1)); EXPECT_TRUE(watch.is_crossed(2)); @@ -106,29 +122,50 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { expected.push_back({1u, 2.250f}); // 2. + (2.5-2)*(2./4.) expected.push_back({2u, 2.750f}); // 2. + (3.0-2)*(3./4.) + memory::copy(time_since_spike, time_since_spike_vec); + EXPECT_EQ(0.375, time_since_spike_vec[src_to_spike[0]]); // 2.5 - 2.125 + EXPECT_EQ(0.250, time_since_spike_vec[src_to_spike[1]]); // 2.5 - 2.250 + EXPECT_EQ(0.250, time_since_spike_vec[src_to_spike[2]]); // 3.0 - 2.750 + for (auto i: empty_slots) { + EXPECT_EQ(-1.0, time_since_spike_vec[i]); + } + // test at t=4, with all values set to 0. // - all watches should stop spiking memory::fill(values, 0.); memory::copy(time_after, time_before); memory::fill(time_after, 4.); - watch.test(); + watch.test(&time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); EXPECT_EQ(watch.crossings().size(), 3u); + memory::copy(time_since_spike, time_since_spike_vec); + for (auto t: time_since_spike_vec) { + EXPECT_EQ(-1.0, t); + } + // test at t=5, with value on 3rd watch set to 6 // - watch 3 should be spiking values[index[2]] = 6.; memory::copy(time_after, time_before); memory::fill(time_after, 5.); - watch.test(); + watch.test(&time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_TRUE(watch.is_crossed(2)); EXPECT_EQ(watch.crossings().size(), 4u); expected.push_back({2u, 4.5f}); + memory::copy(time_since_spike, time_since_spike_vec); + EXPECT_EQ(-1.0, time_since_spike_vec[src_to_spike[0]]); + EXPECT_EQ(-1.0, time_since_spike_vec[src_to_spike[1]]); + EXPECT_EQ(0.50, time_since_spike_vec[src_to_spike[2]]); + for (auto i: empty_slots) { + EXPECT_EQ(-1.0, time_since_spike_vec[i]); + } + // // test that all generated spikes matched the expected values // diff --git a/test/unit/test_synapses.cpp b/test/unit/test_synapses.cpp index daab5fe6..4f315723 100644 --- a/test/unit/test_synapses.cpp +++ b/test/unit/test_synapses.cpp @@ -92,11 +92,15 @@ TEST(synapses, syn_basic_state) { auto align = std::max(expsyn->data_alignment(), exp2syn->data_alignment()); shared_state state(num_intdom, + num_intdom, + 0, + std::vector<index_type>(num_comp, 0), std::vector<index_type>(num_comp, 0), {}, std::vector<value_type>(num_comp, -65), std::vector<value_type>(num_comp, temp_K), std::vector<value_type>(num_comp, 1.), + std::vector<index_type>(0), align); state.reset(); diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index 1d3c5112..f137d4e9 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -10,8 +10,9 @@ #include "mechanisms/ca_linear.hpp" #include "mechanisms/celsius_test.hpp" #include "mechanisms/diam_test.hpp" -#include "mechanisms/param_as_state.hpp" #include "mechanisms/non_linear.hpp" +#include "mechanisms/param_as_state.hpp" +#include "mechanisms/post_events_syn.hpp" #include "mechanisms/test0_kin_diff.hpp" #include "mechanisms/test_linear_state.hpp" #include "mechanisms/test_linear_init.hpp" @@ -66,6 +67,7 @@ mechanism_catalogue make_unit_test_catalogue(const mechanism_catalogue& from) { ADD_MECH(cat, celsius_test) ADD_MECH(cat, diam_test) ADD_MECH(cat, param_as_state) + ADD_MECH(cat, post_events_syn) ADD_MECH(cat, test_linear_state) ADD_MECH(cat, test_linear_init) ADD_MECH(cat, test_linear_init_shuffle) -- GitLab