diff --git a/arbor/backends/gpu/fvm.hpp b/arbor/backends/gpu/fvm.hpp index 5cddad532eeb97a5500fa54fbc9f03fdb53bcb5e..a3467f22db9c25bedfd8e37f0942d47b048299b9 100644 --- a/arbor/backends/gpu/fvm.hpp +++ b/arbor/backends/gpu/fvm.hpp @@ -50,7 +50,7 @@ struct backend { using ion_state = arb::gpu::ion_state; static threshold_watcher voltage_watcher( - const shared_state& state, + shared_state& state, const std::vector<index_type>& cv, const std::vector<value_type>& thresholds, const execution_context& context) @@ -58,6 +58,7 @@ struct backend { return threshold_watcher( state.cv_to_intdom.data(), state.voltage.data(), + state.src_to_spike.data(), &state.time, &state.time_to, cv, diff --git a/arbor/backends/gpu/mechanism.cpp b/arbor/backends/gpu/mechanism.cpp index e577be3f37cac83d49614fd6738706358160d711..b96a01891c2883f864c6c8881f4f45ec04611c1a 100644 --- a/arbor/backends/gpu/mechanism.cpp +++ b/arbor/backends/gpu/mechanism.cpp @@ -83,8 +83,10 @@ void mechanism::instantiate(unsigned id, mechanism_ppack_base* pp = ppack_ptr(); // From derived class instance. pp->width_ = width_; + pp->n_detectors_ = shared.n_detector; - pp->vec_ci_ = shared.cv_to_intdom.data(); + pp->vec_ci_ = shared.cv_to_cell.data(); + pp->vec_di_ = shared.cv_to_intdom.data(); pp->vec_dt_ = shared.dt_cv.data(); pp->vec_v_ = shared.voltage.data(); @@ -93,6 +95,7 @@ void mechanism::instantiate(unsigned id, pp->temperature_degC_ = shared.temperature_degC.data(); pp->diam_um_ = shared.diam_um.data(); + pp->time_since_spike_ = shared.time_since_spike.data(); auto ion_state_tbl = ion_state_table(); num_ions_ = ion_state_tbl.size(); diff --git a/arbor/backends/gpu/mechanism_ppack_base.hpp b/arbor/backends/gpu/mechanism_ppack_base.hpp index 0727a82957a40adc04d35c47111a5a9a4a4949ef..0a221aa1dac3fb6052e5833bc7ae58e428004330 100644 --- a/arbor/backends/gpu/mechanism_ppack_base.hpp +++ b/arbor/backends/gpu/mechanism_ppack_base.hpp @@ -29,8 +29,10 @@ struct mechanism_ppack_base { using ion_state_view = ::arb::gpu::ion_state_view; index_type width_; + index_type n_detectors_; const index_type* vec_ci_; + const index_type* vec_di_; const value_type* vec_t_; const value_type* vec_t_to_; const value_type* vec_dt_; @@ -39,6 +41,7 @@ struct mechanism_ppack_base { value_type* vec_g_; const value_type* temperature_degC_; const value_type* diam_um_; + const value_type* time_since_spike_; const index_type* node_index_; const index_type* multiplicity_; diff --git a/arbor/backends/gpu/shared_state.cpp b/arbor/backends/gpu/shared_state.cpp index 93489f1ae1fd7d8eeb6599ad6e164dd482d387a5..ee5648457a74e335610e5087f9da18854df335e8 100644 --- a/arbor/backends/gpu/shared_state.cpp +++ b/arbor/backends/gpu/shared_state.cpp @@ -90,17 +90,23 @@ void ion_state::reset() { shared_state::shared_state( fvm_size_type n_intdom, + fvm_size_type n_cell, + fvm_size_type n_detector, const std::vector<fvm_index_type>& cv_to_intdom_vec, + const std::vector<fvm_index_type>& cv_to_cell_vec, const std::vector<fvm_gap_junction>& gj_vec, const std::vector<fvm_value_type>& init_membrane_potential, const std::vector<fvm_value_type>& temperature_K, const std::vector<fvm_value_type>& diam, + const std::vector<fvm_index_type>& src_to_spike, unsigned // alignment parameter ignored. -): + ): n_intdom(n_intdom), + n_detector(n_detector), n_cv(cv_to_intdom_vec.size()), n_gj(gj_vec.size()), cv_to_intdom(make_const_view(cv_to_intdom_vec)), + cv_to_cell(make_const_view(cv_to_cell_vec)), gap_junctions(make_const_view(gj_vec)), time(n_intdom), time_to(n_intdom), @@ -112,8 +118,11 @@ shared_state::shared_state( init_voltage(make_const_view(init_membrane_potential)), temperature_degC(make_const_view(temperature_K)), diam_um(make_const_view(diam)), + time_since_spike(n_cell*n_detector), + src_to_spike(make_const_view(src_to_spike)), deliverable_events(n_intdom) { + memory::fill(time_since_spike, -1.0); add_scalar(temperature_degC.size(), temperature_degC.data(), -273.15); } @@ -133,6 +142,7 @@ void shared_state::reset() { memory::fill(conductivity, 0); memory::fill(time, 0); memory::fill(time_to, 0); + memory::fill(time_since_spike, -1.0); for (auto& i: ion_data) { i.second.reset(); diff --git a/arbor/backends/gpu/shared_state.hpp b/arbor/backends/gpu/shared_state.hpp index c375aa325f2872c7a426ba8cde2fa373a986665a..8d0254b3e95feeb9734a1bac6b3261d95d95d926 100644 --- a/arbor/backends/gpu/shared_state.hpp +++ b/arbor/backends/gpu/shared_state.hpp @@ -61,11 +61,13 @@ struct ion_state { }; struct shared_state { - fvm_size_type n_intdom = 0; // Number of distinct integration domains. - fvm_size_type n_cv = 0; // Total number of CVs. - fvm_size_type n_gj = 0; // Total number of GJs. + fvm_size_type n_intdom = 0; // Number of distinct integration domains. + fvm_size_type n_detector = 0; // Max number of detectors on all cells. + fvm_size_type n_cv = 0; // Total number of CVs. + fvm_size_type n_gj = 0; // Total number of GJs. iarray cv_to_intdom; // Maps CV index to intdom index. + iarray cv_to_cell; // Maps CV index to cell index. gjarray gap_junctions; // Stores gap_junction info. array time; // Maps intdom index to integration start time [ms]. array time_to; // Maps intdom index to integration stop time [ms]. @@ -79,6 +81,9 @@ struct shared_state { array temperature_degC; // Maps CV to local temperature (read only) [°C]. array diam_um; // Maps CV to local diameter (read only) [µm]. + array time_since_spike; // Stores time since last spike on any detector, organized by cell. + iarray src_to_spike; // Maps spike source index to spike index + std::unordered_map<std::string, ion_state> ion_data; deliverable_event_stream deliverable_events; @@ -87,11 +92,15 @@ struct shared_state { shared_state( fvm_size_type n_intdom, + fvm_size_type n_cell, + fvm_size_type n_detector, const std::vector<fvm_index_type>& cv_to_intdom_vec, + const std::vector<fvm_index_type>& cv_to_cell_vec, const std::vector<fvm_gap_junction>& gj_vec, const std::vector<fvm_value_type>& init_membrane_potential, const std::vector<fvm_value_type>& temperature_K, const std::vector<fvm_value_type>& diam, + const std::vector<fvm_index_type>& src_to_spike, unsigned align ); diff --git a/arbor/backends/gpu/stimulus.cu b/arbor/backends/gpu/stimulus.cu index 1589d6e4f340df91451f5caaa5b9af837c45b3b8..a207b376614d2874127c4349aad7dc63a1b597b5 100644 --- a/arbor/backends/gpu/stimulus.cu +++ b/arbor/backends/gpu/stimulus.cu @@ -12,7 +12,7 @@ namespace kernel { void stimulus_current_impl(int n, stimulus_pp pp) { auto i = threadIdx.x + blockDim.x*blockIdx.x; if (i<n) { - auto t = pp.vec_t_[pp.vec_ci_[i]]; + auto t = pp.vec_t_[pp.vec_di_[i]]; if (t>=pp.delay[i] && t<pp.delay[i]+pp.duration[i]) { // use subtraction because the electrode currents are specified // in terms of current into the compartment diff --git a/arbor/backends/gpu/threshold_watcher.cu b/arbor/backends/gpu/threshold_watcher.cu index 28180608ca466369e7dc171983b3216df0ae035e..61706f9ee4c5ea215a110a34f0e90e998a30b73f 100644 --- a/arbor/backends/gpu/threshold_watcher.cu +++ b/arbor/backends/gpu/threshold_watcher.cu @@ -33,12 +33,15 @@ void test_thresholds_impl( const fvm_index_type* __restrict__ const cv_to_intdom, const fvm_value_type* __restrict__ const t_after, const fvm_value_type* __restrict__ const t_before, + const fvm_index_type* __restrict__ const src_to_spike, + fvm_value_type* __restrict__ const time_since_spike, stack_storage<threshold_crossing>& stack, fvm_index_type* __restrict__ const is_crossed, fvm_value_type* __restrict__ const prev_values, const fvm_index_type* __restrict__ const cv_index, const fvm_value_type* __restrict__ const values, - const fvm_value_type* __restrict__ const thresholds) + const fvm_value_type* __restrict__ const thresholds, + bool record_time_since_spike) { int i = threadIdx.x + blockIdx.x*blockDim.x; @@ -48,17 +51,27 @@ void test_thresholds_impl( if (i<size) { // Test for threshold crossing const auto cv = cv_index[i]; - const auto cell = cv_to_intdom[cv]; + const auto intdom = cv_to_intdom[cv]; const auto v_prev = prev_values[i]; const auto v = values[cv]; const auto thresh = thresholds[i]; + fvm_index_type spike_idx = 0; + // Reset all spike times to -1.0 indicating no spike has been recorded on the detector + if (record_time_since_spike) { + spike_idx = src_to_spike[i]; + time_since_spike[spike_idx] = -1.0; + } if (!is_crossed[i]) { if (v>=thresh) { // The threshold has been passed, so estimate the time using // linear interpolation auto pos = (thresh - v_prev)/(v - v_prev); - crossing_time = lerp(t_before[cell], t_after[cell], pos); + crossing_time = lerp(t_before[intdom], t_after[intdom], pos); + + if (record_time_since_spike) { + time_since_spike[spike_idx] = t_after[intdom] - crossing_time; + } is_crossed[i] = 1; crossed = true; @@ -95,15 +108,17 @@ extern void reset_crossed_impl( void test_thresholds_impl( int size, const fvm_index_type* cv_to_intdom, const fvm_value_type* t_after, const fvm_value_type* t_before, - stack_storage<threshold_crossing>& stack, + const fvm_index_type* src_to_spike, fvm_value_type* time_since_spike, stack_storage<threshold_crossing>& stack, fvm_index_type* is_crossed, fvm_value_type* prev_values, - const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds) + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds, + bool record_time_since_spike) { if (size>0) { constexpr int block_dim = 128; const int grid_dim = impl::block_count(size, block_dim); kernel::test_thresholds_impl<<<grid_dim, block_dim>>>( - size, cv_to_intdom, t_after, t_before, stack, is_crossed, prev_values, cv_index, values, thresholds); + size, cv_to_intdom, t_after, t_before, src_to_spike, time_since_spike, + stack, is_crossed, prev_values, cv_index, values, thresholds, record_time_since_spike); } } diff --git a/arbor/backends/gpu/threshold_watcher.hpp b/arbor/backends/gpu/threshold_watcher.hpp index bf533073b4a3b2ecdc072dbfdfaf2bec2cc75619..6d371faa64b2923401eb90bfb7864e533bcea19d 100644 --- a/arbor/backends/gpu/threshold_watcher.hpp +++ b/arbor/backends/gpu/threshold_watcher.hpp @@ -22,9 +22,10 @@ namespace gpu { void test_thresholds_impl( int size, const fvm_index_type* cv_to_intdom, const fvm_value_type* t_after, const fvm_value_type* t_before, - stack_storage<threshold_crossing>& stack, + const fvm_index_type* src_to_spike, fvm_value_type* time_since_spike, stack_storage<threshold_crossing>& stack, fvm_index_type* is_crossed, fvm_value_type* prev_values, - const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds); + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds, + bool record); void reset_crossed_impl( int size, @@ -45,6 +46,7 @@ public: threshold_watcher( const fvm_index_type* cv_to_intdom, const fvm_value_type* values, + const fvm_index_type* src_to_spike, const array* t_before, const array* t_after, const std::vector<fvm_index_type>& cv_index, @@ -53,6 +55,7 @@ public: ): cv_to_intdom_(cv_to_intdom), values_(values), + src_to_spike_(src_to_spike), t_before_ptr_(t_before), t_after_ptr_(t_after), cv_index_(memory::make_const_view(cv_index)), @@ -104,17 +107,16 @@ public: /// Crossing events are recorded for each threshold that has been /// crossed since current time t, and the last time the test was /// performed. - void test() { - const fvm_value_type* t_before = t_before_ptr_->data(); - const fvm_value_type* t_after = t_after_ptr_->data(); - + void test(array* time_since_spike) { if (size()>0) { test_thresholds_impl( (int)size(), - cv_to_intdom_, t_after, t_before, + cv_to_intdom_, t_after_ptr_->data(), t_before_ptr_->data(), + src_to_spike_, time_since_spike->data(), stack_.storage(), is_crossed_.data(), v_prev_.data(), - cv_index_.data(), values_, thresholds_.data()); + cv_index_.data(), values_, thresholds_.data(), + !time_since_spike->empty()); // Check that the number of spikes has not exceeded capacity. arb_assert(!stack_.overflow()); @@ -132,6 +134,7 @@ private: /// and pointers to the time arrays const fvm_index_type* cv_to_intdom_ = nullptr; const fvm_value_type* values_ = nullptr; + const fvm_index_type* src_to_spike_ = nullptr; const array* t_before_ptr_ = nullptr; const array* t_after_ptr_ = nullptr; diff --git a/arbor/backends/multicore/fvm.hpp b/arbor/backends/multicore/fvm.hpp index e81b7d7dcf984733b128fd2703d925128dc85115..1920f801b323a282ae8a11ed1daa214d7e0be170 100644 --- a/arbor/backends/multicore/fvm.hpp +++ b/arbor/backends/multicore/fvm.hpp @@ -48,7 +48,7 @@ struct backend { using ion_state = arb::multicore::ion_state; static threshold_watcher voltage_watcher( - const shared_state& state, + shared_state& state, const std::vector<index_type>& cv, const std::vector<value_type>& thresholds, const execution_context& context) @@ -56,6 +56,7 @@ struct backend { return threshold_watcher( state.cv_to_intdom.data(), state.voltage.data(), + state.src_to_spike.data(), &state.time, &state.time_to, cv, diff --git a/arbor/backends/multicore/mechanism.cpp b/arbor/backends/multicore/mechanism.cpp index 285c27dd189e525e42c2ef1a39674d51624349cb..3e7f6c2f5a5b76e8ce44ab6030fb113c9571f708 100644 --- a/arbor/backends/multicore/mechanism.cpp +++ b/arbor/backends/multicore/mechanism.cpp @@ -83,7 +83,8 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me // Assign non-owning views onto shared state: - vec_ci_ = shared.cv_to_intdom.data(); + vec_ci_ = shared.cv_to_cell.data(); + vec_di_ = shared.cv_to_intdom.data(); vec_dt_ = shared.dt_cv.data(); vec_v_ = shared.voltage.data(); @@ -92,6 +93,9 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me temperature_degC_ = shared.temperature_degC.data(); diam_um_ = shared.diam_um.data(); + time_since_spike_ = shared.time_since_spike.data(); + + n_detectors_ = shared.n_detector; auto ion_state_tbl = ion_state_table(); n_ion_ = ion_state_tbl.size(); diff --git a/arbor/backends/multicore/mechanism.hpp b/arbor/backends/multicore/mechanism.hpp index 140c5d78ebe43a703815db9ee47c8bbdc7f55926..afd638c34e5137ecc911eda54273cc6e3f5f5fe9 100644 --- a/arbor/backends/multicore/mechanism.hpp +++ b/arbor/backends/multicore/mechanism.hpp @@ -84,10 +84,12 @@ protected: size_type width_ = 0; // Instance width (number of CVs/sites) size_type width_padded_ = 0; // Width rounded up to multiple of pad/alignment. size_type n_ion_ = 0; + size_type n_detectors_ = 0; // Non-owning views onto shared cell state, excepting ion state. - const index_type* vec_ci_; // CV to cell index. + const index_type* vec_ci_; // CV to cell index + const index_type* vec_di_; // CV to indom index const value_type* vec_t_; // Cell index to cell-local time. const value_type* vec_t_to_; // Cell index to cell-local integration step time end. const value_type* vec_dt_; // CV to integration time step. @@ -96,6 +98,7 @@ protected: value_type* vec_g_; // CV to cell membrane conductivity. const value_type* temperature_degC_; // CV to temperature. const value_type* diam_um_; // CV to diameter. + const value_type* time_since_spike_; // Vector containing time since last spike, indexed by cell index and n_detectors_ const array* vec_t_ptr_; const array* vec_t_to_ptr_; diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index dfdc70b1e3593d593879b5a51737645f15383ae8..02f4e9f880559a49bd66317249519d8f01e757e9 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -91,19 +91,25 @@ void ion_state::reset() { shared_state::shared_state( fvm_size_type n_intdom, + fvm_size_type n_cell, + fvm_size_type n_detector, const std::vector<fvm_index_type>& cv_to_intdom_vec, + const std::vector<fvm_index_type>& cv_to_cell_vec, const std::vector<fvm_gap_junction>& gj_vec, const std::vector<fvm_value_type>& init_membrane_potential, const std::vector<fvm_value_type>& temperature_K, const std::vector<fvm_value_type>& diam, + const std::vector<fvm_index_type>& src_to_spike, unsigned align ): alignment(min_alignment(align)), alloc(alignment), n_intdom(n_intdom), + n_detector(n_detector), n_cv(cv_to_intdom_vec.size()), n_gj(gj_vec.size()), cv_to_intdom(math::round_up(n_cv, alignment), pad(alignment)), + cv_to_cell(math::round_up(cv_to_cell_vec.size(), alignment), pad(alignment)), gap_junctions(math::round_up(n_gj, alignment), pad(alignment)), time(n_intdom, pad(alignment)), time_to(n_intdom, pad(alignment)), @@ -115,6 +121,8 @@ shared_state::shared_state( init_voltage(init_membrane_potential.begin(), init_membrane_potential.end(), pad(alignment)), temperature_degC(n_cv, pad(alignment)), diam_um(diam.begin(), diam.end(), pad(alignment)), + time_since_spike(n_cell*n_detector, pad(alignment)), + src_to_spike(src_to_spike.begin(), src_to_spike.end(), pad(alignment)), deliverable_events(n_intdom) { // For indices in the padded tail of cv_to_intdom, set index to last valid intdom index. @@ -122,11 +130,16 @@ shared_state::shared_state( std::copy(cv_to_intdom_vec.begin(), cv_to_intdom_vec.end(), cv_to_intdom.begin()); std::fill(cv_to_intdom.begin() + n_cv, cv_to_intdom.end(), cv_to_intdom_vec.back()); } + if (cv_to_cell_vec.size()) { + std::copy(cv_to_cell_vec.begin(), cv_to_cell_vec.end(), cv_to_cell.begin()); + std::fill(cv_to_cell.begin() + n_cv, cv_to_cell.end(), cv_to_cell_vec.back()); + } if (n_gj>0) { std::copy(gj_vec.begin(), gj_vec.end(), gap_junctions.begin()); std::fill(gap_junctions.begin()+n_gj, gap_junctions.end(), gj_vec.back()); } + util::fill(time_since_spike, -1.0); for (unsigned i = 0; i<n_cv; ++i) { temperature_degC[i] = temperature_K[i] - 273.15; } @@ -148,6 +161,7 @@ void shared_state::reset() { util::fill(conductivity, 0); util::fill(time, 0); util::fill(time_to, 0); + util::fill(time_since_spike, -1.0); for (auto& i: ion_data) { i.second.reset(); diff --git a/arbor/backends/multicore/shared_state.hpp b/arbor/backends/multicore/shared_state.hpp index df146a5cb7930a533977c1dd2d797bee2636ae24..c36908fd32335e166318e7416c979ac3f654848d 100644 --- a/arbor/backends/multicore/shared_state.hpp +++ b/arbor/backends/multicore/shared_state.hpp @@ -79,10 +79,12 @@ struct shared_state { util::padded_allocator<> alloc; // Allocator with corresponging alignment/padding. fvm_size_type n_intdom = 0; // Number of integration domains. + fvm_size_type n_detector = 0; // Max number of detectors on all cells. fvm_size_type n_cv = 0; // Total number of CVs. fvm_size_type n_gj = 0; // Total number of GJs. iarray cv_to_intdom; // Maps CV index to integration domain index. + iarray cv_to_cell; // Maps CV index to the first spike gjarray gap_junctions; // Stores gap_junction info. array time; // Maps intdom index to integration start time [ms]. array time_to; // Maps intdom index to integration stop time [ms]. @@ -96,6 +98,9 @@ struct shared_state { array temperature_degC; // Maps CV to local temperature (read only) [°C]. array diam_um; // Maps CV to local diameter (read only) [µm]. + array time_since_spike; // Stores time since last spike on any detector, organized by cell. + iarray src_to_spike; // Maps spike source index to spike index + std::unordered_map<std::string, ion_state> ion_data; deliverable_event_stream deliverable_events; @@ -104,11 +109,15 @@ struct shared_state { shared_state( fvm_size_type n_intdom, + fvm_size_type n_cell, + fvm_size_type n_detector, const std::vector<fvm_index_type>& cv_to_intdom_vec, + const std::vector<fvm_index_type>& cv_to_cell_vec, const std::vector<fvm_gap_junction>& gj_vec, const std::vector<fvm_value_type>& init_membrane_potential, const std::vector<fvm_value_type>& temperature_K, const std::vector<fvm_value_type>& diam, + const std::vector<fvm_index_type>& src_to_spike, unsigned align ); diff --git a/arbor/backends/multicore/stimulus.cpp b/arbor/backends/multicore/stimulus.cpp index 11f6bdf3a792d354c347623c59b254a58b1a0510..d17d2505c3413220b54f3838edde5ab1f0ea0f6b 100644 --- a/arbor/backends/multicore/stimulus.cpp +++ b/arbor/backends/multicore/stimulus.cpp @@ -24,7 +24,7 @@ public: size_type n = size(); for (size_type i=0; i<n; ++i) { auto cv = node_index_[i]; - auto t = vec_t_[vec_ci_[cv]]; + auto t = vec_t_[vec_di_[cv]]; if (t>=delay[i] && t<delay[i]+duration[i]) { // Amplitudes are given as a current into a compartment, so subtract. diff --git a/arbor/backends/multicore/threshold_watcher.hpp b/arbor/backends/multicore/threshold_watcher.hpp index 7074dda5353430dbf9a8ce0cc9846c5829221c53..9b2d5e533184711c86901cf84f35387e6b6567ee 100644 --- a/arbor/backends/multicore/threshold_watcher.hpp +++ b/arbor/backends/multicore/threshold_watcher.hpp @@ -20,6 +20,7 @@ public: threshold_watcher( const fvm_index_type* cv_to_intdom, const fvm_value_type* values, + const fvm_index_type* src_to_spike, const array* t_before, const array* t_after, const std::vector<fvm_index_type>& cv_index, @@ -28,6 +29,7 @@ public: ): cv_to_intdom_(cv_to_intdom), values_(values), + src_to_spike_(src_to_spike), t_before_ptr_(t_before), t_after_ptr_(t_after), n_cv_(cv_index.size()), @@ -63,24 +65,35 @@ public: /// Tests each target for changed threshold state /// Crossing events are recorded for each threshold that /// is crossed since the last call to test - void test() { + void test(array* time_since_spike) { + // Reset all spike times to -1.0 indicating no spike has been recorded on the detector const fvm_value_type* t_before = t_before_ptr_->data(); const fvm_value_type* t_after = t_after_ptr_->data(); for (fvm_size_type i = 0; i<n_cv_; ++i) { auto cv = cv_index_[i]; - auto cell = cv_to_intdom_[cv]; + auto intdom = cv_to_intdom_[cv]; auto v_prev = v_prev_[i]; auto v = values_[cv]; auto thresh = thresholds_[i]; + fvm_index_type spike_idx = 0; + + if (!time_since_spike->empty()) { + spike_idx = src_to_spike_[i]; + (*time_since_spike)[spike_idx] = -1.0; + } if (!is_crossed_[i]) { if (v>=thresh) { // The threshold has been passed, so estimate the time using // linear interpolation. auto pos = (thresh - v_prev)/(v - v_prev); - auto crossing_time = math::lerp(t_before[cell], t_after[cell], pos); + auto crossing_time = math::lerp(t_before[intdom], t_after[intdom], pos); crossings_.push_back({i, crossing_time}); + if (!time_since_spike->empty()) { + (*time_since_spike)[spike_idx] = t_after[intdom] - crossing_time; + } + is_crossed_[i] = true; } } @@ -109,6 +122,7 @@ private: /// and pointers to the time arrays const fvm_index_type* cv_to_intdom_ = nullptr; const fvm_value_type* values_ = nullptr; + const fvm_index_type* src_to_spike_ = nullptr; const array* t_before_ptr_ = nullptr; const array* t_after_ptr_ = nullptr; diff --git a/arbor/fvm_layout.cpp b/arbor/fvm_layout.cpp index 2125fb1e20872ada3125780613dfc3ea0febcaa3..e97a6e4d251b9ab667028d261f688be8450d17a5 100644 --- a/arbor/fvm_layout.cpp +++ b/arbor/fvm_layout.cpp @@ -751,6 +751,8 @@ fvm_mechanism_data& append(fvm_mechanism_data& left, const fvm_mechanism_data& r } left.n_target += right.n_target; + left.post_events |= right.post_events; + append_divs(left.target_divs, right.target_divs); arb_assert(left.n_target==left.target_divs.back()); @@ -945,10 +947,13 @@ fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties& std::vector<synapse_instance> inst_list; std::vector<size_type> cv_order; + bool post_events = false; + for (const auto& entry: cell.synapses()) { const std::string& name = entry.first; mechanism_info info = catalogue[name]; + post_events |= info.post_events; std::size_t n_param = info.parameters.size(); std::size_t n_inst = entry.second.size(); @@ -1065,6 +1070,7 @@ fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties& M.n_target += config.target.size(); M.mechanisms[name] = std::move(config); } + M.post_events = post_events; // Stimuli: diff --git a/arbor/fvm_layout.hpp b/arbor/fvm_layout.hpp index 69557bc005414b57d1532908c2beb77296bd4913..cdc7c4dae5c8f84dd0d085b5e9f101a559876653 100644 --- a/arbor/fvm_layout.hpp +++ b/arbor/fvm_layout.hpp @@ -280,6 +280,9 @@ struct fvm_mechanism_data { // Partitions target numbers by cell. std::vector<std::size_t> target_divs; + + // Contains mechanisms with post_event + bool post_events = false; }; fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties& gprop, const std::vector<cable_cell>& cells, const fvm_cv_discretization& D, const arb::execution_context& ctx={}); diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index 5dd419c39f2c18a54a9dcbe934628db73f4047ad..d5126cf0cb10888fd9d198befa53ece58c386339 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -107,7 +107,10 @@ private: std::vector<mechanism_ptr> revpot_mechanisms_; // Non-physical voltage check threshold, 0 => no check. - value_type check_voltage_mV = 0; + value_type check_voltage_mV_ = 0; + + // Flag indicating that at least one of the mechanisms implements the post_events procedure + bool post_events_; // Host-side views/copies and local state. decltype(backend::host_view(sample_time_)) sample_time_host_; @@ -231,7 +234,6 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( m->update_current(); } - // Deliver events and accumulate mechanism current contributions. PE(advance_integrate_events); @@ -291,15 +293,24 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( // Update time and test for spike threshold crossings. PE(advance_integrate_threshold); - threshold_watcher_.test(); - std::swap(state_->time_to, state_->time); + threshold_watcher_.test(&state_->time_since_spike); + PL(); + + PE(advance_integrate_post) + if (post_events_) { + for (auto& m: mechanisms_) { + m->post_event(); + } + } PL(); + std::swap(state_->time_to, state_->time); + // Check for non-physical solutions: - if (check_voltage_mV>0) { + if (check_voltage_mV_>0) { PE(advance_integrate_physicalcheck); - assert_voltage_bounded(check_voltage_mV); + assert_voltage_bounded(check_voltage_mV_); PL(); } @@ -393,6 +404,26 @@ void fvm_lowered_cell_impl<Backend>::initialize( // (Throws cable_cell_error on failure.) check_global_properties(global_props); + // Sanity check recipe; find max num_sources and + // create a list of the global identifiers for the spike sources + + std::vector<fvm_size_type> nsources; + for (auto cell_idx: make_span(ncell)) { + cell_gid_type gid = gids[cell_idx]; + auto& cell = cells[cell_idx]; + + auto num_sources = rec.num_sources(gid); + nsources.push_back(num_sources); + + if (num_sources != cell.detectors().size()) { + throw arb::bad_source_description(gid, num_sources, cell.detectors().size()); + } + auto cell_targets = util::sum_by(cell.synapses(), [](auto& syn) { return syn.second.size(); }); + if (rec.num_targets(gid) > cell_targets) { + throw arb::bad_target_description(gid, rec.num_targets(gid), cell_targets); + } + } + const mechanism_catalogue* catalogue = global_props.catalogue; // Mechanism instantiator helper. @@ -403,9 +434,9 @@ void fvm_lowered_cell_impl<Backend>::initialize( // Check for physically reasonable membrane volages? - check_voltage_mV = global_props.membrane_voltage_limit_mV; + check_voltage_mV_ = global_props.membrane_voltage_limit_mV; - auto num_intdoms = fvm_intdom(rec, gids, cell_to_intdom); + auto nintdom = fvm_intdom(rec, gids, cell_to_intdom); // Discretize cells, build matrix. @@ -418,7 +449,7 @@ void fvm_lowered_cell_impl<Backend>::initialize( arb_assert(D.n_cell() == ncell); matrix_ = matrix<backend>(D.geometry.cv_parent, D.geometry.cell_cv_divs, D.cv_capacitance, D.face_conductance, D.cv_area, cell_to_intdom); - sample_events_ = sample_event_stream(num_intdoms); + sample_events_ = sample_event_stream(nintdom); // Discretize mechanism data. @@ -428,6 +459,21 @@ void fvm_lowered_cell_impl<Backend>::initialize( auto gj_vector = fvm_gap_junctions(cells, gids, rec, D); + // Fill src_to_spike and cv_to_cell vectors only if mechanisms with post_events implemented are present. + post_events_ = mech_data.post_events; + auto max_detector = post_events_ ? util::max_value(nsources) : 0; + std::vector<fvm_index_type> src_to_spike, cv_to_cell; + + if (post_events_) { + for (auto cell_idx: make_span(ncell)) { + for (auto lid: make_span(nsources[cell_idx])) { + src_to_spike.push_back(cell_idx * max_detector + lid); + } + } + src_to_spike.shrink_to_fit(); + cv_to_cell = D.geometry.cv_to_cell; + } + // Create shared cell state. // (SIMD padding requires us to check each mechanism for alignment/padding constraints.) @@ -436,7 +482,8 @@ void fvm_lowered_cell_impl<Backend>::initialize( [&](const std::string& name) { return mech_instance(name).mech->data_alignment(); })); state_ = std::make_unique<shared_state>( - num_intdoms, cv_to_intdom, gj_vector, D.init_membrane_potential, D.temperature_K, D.diam_um, + nintdom, ncell, max_detector, cv_to_intdom, std::move(cv_to_cell), gj_vector, + D.init_membrane_potential, D.temperature_K, D.diam_um, std::move(src_to_spike), data_alignment? data_alignment: 1u); // Instantiate mechanisms and ions. @@ -534,16 +581,6 @@ void fvm_lowered_cell_impl<Backend>::initialize( for (auto cell_idx: make_span(ncell)) { cell_gid_type gid = gids[cell_idx]; - // Sanity check recipe - auto& cell = cells[cell_idx]; - if (rec.num_sources(gid) != cell.detectors().size()) { - throw arb::bad_source_description(gid, rec.num_sources(gid), cell.detectors().size());; - } - auto cell_targets = util::sum_by(cell.synapses(), [](auto& syn) {return syn.second.size();}); - if (rec.num_targets(gid) > cell_targets) { - throw arb::bad_target_description(gid, rec.num_targets(gid), cell_targets); - } - // Collect detectors, probe handles. for (auto entry: cells[cell_idx].detectors()) { detector_cv.push_back(D.geometry.location_cv(cell_idx, entry.loc, cv_prefer::cv_empty)); diff --git a/arbor/include/arbor/mechanism.hpp b/arbor/include/arbor/mechanism.hpp index 5e87a441fcb6986955858d6fdac54763cd4587bf..c1361f3021bcab12b5c55bcaeb322ccb80982e94 100644 --- a/arbor/include/arbor/mechanism.hpp +++ b/arbor/include/arbor/mechanism.hpp @@ -54,6 +54,7 @@ public: virtual void update_state() {}; virtual void update_current() {}; virtual void deliver_events() {}; + virtual void post_event() {}; virtual void update_ions() {}; virtual ~mechanism() = default; diff --git a/arbor/include/arbor/mechinfo.hpp b/arbor/include/arbor/mechinfo.hpp index f3cc748a20b6e5c45e82a7242be95d4e1c21a62d..45785e31c4f5b48d911b64ba812c4c3ceacc933c 100644 --- a/arbor/include/arbor/mechinfo.hpp +++ b/arbor/include/arbor/mechinfo.hpp @@ -69,6 +69,8 @@ struct mechanism_info { mechanism_fingerprint fingerprint; bool linear = false; + + bool post_events = false; }; } // namespace arb diff --git a/doc/internals/nmodl.rst b/doc/internals/nmodl.rst index 84474d31efcf4a64bc3b61a00b2c8f8603f3f791..4757a8484b67366c875467ab600db98636223a69 100644 --- a/doc/internals/nmodl.rst +++ b/doc/internals/nmodl.rst @@ -31,7 +31,7 @@ Ions must be added explicitly in Arbor along with their default properties and valence (this can be done in the recipe or on a single cell model). Simply specifying them in NMODL will not work. -* The parameters and variabnles of each ion referenced in a ``USEION`` statement +* The parameters and variables of each ion referenced in a ``USEION`` statement are available automatically to the mechanism. The exposed variables are: internal concentration ``Xi``, external concentration ``Xo``, reversal potential ``eX`` and current ``iX``. It is an error to also mark these as @@ -42,18 +42,19 @@ Ions * If ``Xi``, ``Xo``, ``eX``, ``iX`` are used in a ``PROCEDURE`` or ``FUNCTION``, they need to be passed as arguments. * If ``Xi`` or ``Xo`` (internal and external concentrations) are written in the - NMODL mechanism they need to be specified as ``STATE`` variables. + NMODL mechanism they need to be declared as ``STATE`` variables and their initial + values have to be set in the mechanism. Special variables ----------------- * Arbor exposes some parameters from the simulation to the NMODL mechanisms. - These include ``v``, ``diam``, ``celsius`` in addition to the previously + These include ``v``, ``diam``, ``celsius`` and ``t`` in addition to the previously mentioned ion parameters. * Special variables should not be ``ASSIGNED`` or ``CONSTANT``, they are ``PARAMETER``. * ``diam`` and ``celsius`` can be set from the simulation side. -* ``v`` is a reserved varible name and can be written in NMODL. +* ``v`` is a reserved variable name and can be written in NMODL. * If Special variables are used in a ``PROCEDURE`` or ``FUNCTION``, they need to be passed as arguments. * ``dt`` is not exposed to NMODL mechanisms. @@ -76,4 +77,24 @@ Unsupported features However, ``CONSERVE`` statements are supported. * ``TABLE`` is not supported, calculations are exact. * ``derivimplicit`` solving method is not supported, use ``cnexp`` instead. +* `verbatim` blocks are not supported. +Arbor-specific features +----------------------- + +* Arbor's NMODL dialect supports the most widely used features of NEURON. It also + has some features unavailable in NEURON such as the ``POST_EVENT`` procedure block. + This procedure has a single argument representing the time since the last spike on + the cell. In the event of multiple detectors on the cell, and multiple spikes on the + detectors within the same integration period, the times of each of these spikes will + be processed by the ``POST_EVENT`` block. Spikes are processed only once and then + cleared. + + Example of a ``POST_EVENT`` procedure, where ``g`` is a ``STATE`` parameter representing + the conductance: + + .. code:: + + POST_EVENT(t) { + g = g + (0.1*t) + } \ No newline at end of file diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 51073654c49163b500a386279cf6f5a8bb5e87d1..a2eb8daaf16e7741fc112bb26bfd998f8a9b1409 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -31,6 +31,8 @@ inline std::string to_string(procedureKind k) { return "initial"; case procedureKind::net_receive : return "net_receive"; + case procedureKind::post_event : + return "post_event"; case procedureKind::breakpoint : return "breakpoint"; case procedureKind::derivative : @@ -715,6 +717,41 @@ void NetReceiveExpression::semantic(scope_type::symbol_map &global_symbols) { symbol_ = scope_->find_global(name()); } +/******************************************************************************* + PostEventExpression +*******************************************************************************/ + +void PostEventExpression::semantic(scope_type::symbol_map &global_symbols) { + // assert that the symbol is already visible in the global_symbols + if(global_symbols.find(name()) == global_symbols.end()) { + throw compiler_exception( + "attempt to perform semantic analysis for procedure '" + + yellow(name()) + + "' which has not been added to global symbol table", + location_); + } + + // create the scope for this procedure + scope_ = std::make_shared<scope_type>(global_symbols); + error_ = false; + + // add the argumemts to the list of local variables + for(auto& a : args_) { + a->semantic(scope_); + if(a->has_error()) { + error(a->error_message()); + } + } + + // perform semantic analysis for each expression in the body + body_->semantic(scope_); + if(body_->has_error()) { + error(body_->error_message()); + } + + symbol_ = scope_->find_global(name()); +} + /******************************************************************************* FunctionExpression *******************************************************************************/ @@ -1105,6 +1142,9 @@ void ProcedureExpression::accept(Visitor *v) { void NetReceiveExpression::accept(Visitor *v) { v->visit(this); } +void PostEventExpression::accept(Visitor *v) { + v->visit(this); +} void APIMethod::accept(Visitor *v) { v->visit(this); } diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 50c5c1b1c34aac5ed63f0d937e94de93e59fc255..b208ac4a8ba73c3af0213f72f8c04d5f52b6c932 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -47,6 +47,7 @@ class PDiffExpression; class VariableExpression; class ProcedureExpression; class NetReceiveExpression; +class PostEventExpression; class APIMethod; class IndexedVariable; class LocalVariable; @@ -79,6 +80,7 @@ enum class procedureKind { api, ///< API PROCEDURE initial, ///< INITIAL net_receive, ///< NET_RECEIVE + post_event, ///< POST_EVENT breakpoint, ///< BREAKPOINT kinetic, ///< KINETIC derivative, ///< DERIVATIVE @@ -234,6 +236,7 @@ public : virtual VariableExpression* is_variable() {return nullptr;} virtual ProcedureExpression* is_procedure() {return nullptr;} virtual NetReceiveExpression* is_net_receive() {return nullptr;} + virtual PostEventExpression* is_post_event() {return nullptr;} virtual APIMethod* is_api_method() {return nullptr;} virtual IndexedVariable* is_indexed_variable() {return nullptr;} virtual LocalVariable* is_local_variable() {return nullptr;} @@ -1120,6 +1123,24 @@ protected: InitialBlock* initial_block_ = nullptr; }; +/// handle PostEventExpression as a special case of ProcedureExpression +class PostEventExpression : public ProcedureExpression { +public: + PostEventExpression( Location loc, + std::string name, + std::vector<expression_ptr>&& args, + expression_ptr&& body) + : ProcedureExpression(loc, std::move(name), std::move(args), std::move(body), procedureKind::post_event) + {} + + void semantic(scope_type::symbol_map &scp) override; + PostEventExpression* is_post_event() override {return this;} + /// hard code the kind + procedureKind kind() {return procedureKind::post_event;} + + void accept(Visitor *v) override; +}; + class FunctionExpression : public Symbol { public: FunctionExpression( Location loc, diff --git a/modcc/module.cpp b/modcc/module.cpp index a9b86b6db9efe8d7320740cc16eda0952f250190..d099ef2322f4f77d0b7353743540e3c625a56f35 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -537,6 +537,8 @@ bool Module::semantic() { } linear_ = linear; + post_events_ = has_symbol("post_event", symbolKind::procedure); + // Are we writing an ionic reversal potential? If so, change the moduleKind to // `revpot` and assert that the mechanism is 'pure': it has no state variables; // it contributes to no currents, ionic or otherwise; it isn't a point mechanism; diff --git a/modcc/module.hpp b/modcc/module.hpp index 09e29d407594c1f3ef2201ca364083f065e185c7..f29ddd06d99408d11fcd1b7405d20e394e6fbeee 100644 --- a/modcc/module.hpp +++ b/modcc/module.hpp @@ -104,6 +104,7 @@ public: }; bool is_linear() const { return linear_; } + bool has_post_events() const { return post_events_; } private: moduleKind kind_; @@ -118,6 +119,7 @@ private: ParameterBlock parameter_block_; AssignedBlock assigned_block_; bool linear_; + bool post_events_; // AST storage. std::vector<symbol_ptr> callables_; diff --git a/modcc/parser.cpp b/modcc/parser.cpp index d52f9416899074dfaa937b9148fa154d1888c882..86ae2accd4984c311ebdc5b58d8c5a75c52a4711 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -108,6 +108,7 @@ bool Parser::parse() { case tok::net_receive: case tok::breakpoint: case tok::initial: + case tok::post_event: case tok::kinetic: case tok::linear: case tok::derivative: @@ -932,6 +933,10 @@ symbol_ptr Parser::parse_procedure() { kind = procedureKind::net_receive; p = parse_prototype("net_receive"); break; + case tok::post_event: + kind = procedureKind::post_event; + p = parse_prototype("post_event"); + break; default: // it is a compiler error if trying to parse_procedure() without // having DERIVATIVE, KINETIC, PROCEDURE, INITIAL or BREAKPOINT keyword @@ -949,12 +954,13 @@ symbol_ptr Parser::parse_procedure() { if (body == nullptr) return nullptr; auto proto = p->is_prototype(); - if (kind != procedureKind::net_receive) { - return make_symbol<ProcedureExpression>(proto->location(), proto->name(), std::move(proto->args()), std::move(body), kind); + if(kind == procedureKind::net_receive) { + return make_symbol<NetReceiveExpression> (proto->location(), proto->name(), std::move(proto->args()), std::move(body)); } - else { - return make_symbol<NetReceiveExpression>(proto->location(), proto->name(), std::move(proto->args()), std::move(body)); + if(kind == procedureKind::post_event) { + return make_symbol<PostEventExpression> (proto->location(), proto->name(), std::move(proto->args()), std::move(body)); } + return make_symbol<ProcedureExpression> (proto->location(), proto->name(), std::move(proto->args()), std::move(body), kind); } symbol_ptr Parser::parse_function() { diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 0e333686f31edc7a194d262cd7d12ade200da1b1..23ecf2f06203b26f58a4dfa9356ca458b9ccbd36 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -108,6 +108,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { auto ns_components = namespace_components(opt.cpp_namespace); NetReceiveExpression* net_receive = find_net_receive(module_); + PostEventExpression* post_event = find_post_event(module_); APIMethod* init_api = find_api_method(module_, "nrn_init"); APIMethod* state_api = find_api_method(module_, "nrn_state"); APIMethod* current_api = find_api_method(module_, "nrn_current"); @@ -254,6 +255,9 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "void deliver_events(deliverable_event_stream::state events) override;\n" "void net_receive(int i_, value_type weight);\n"; + post_event && out << + "void post_event() override;\n"; + with_simd && out << "unsigned simd_width() const override { return simd_width_; }\n"; out << @@ -388,6 +392,25 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "}\n\n"; } + if(post_event) { + const std::string time_arg = post_event->args().empty() ? "time" : post_event->args().front()->is_argument()->name(); + out << + "void " << class_name << "::post_event() {\n" << indent << + "int n_ = width_;\n" + "for (int i_ = 0; i_ < n_; ++i_) {\n" << indent << + "auto node_index_i_ = node_index_[i_];\n" + "auto cid_ = vec_ci_[node_index_i_];\n" + "auto offset_ = n_detectors_ * cid_;\n" + "for (unsigned c = 0; c < n_detectors_; c++) {\n" << indent << + "auto " << time_arg << " = time_since_spike_[offset_ + c];\n" + "if (" << time_arg << " >= 0) {\n" << indent << + cprint(post_event->body()) << popindent << + "}\n" << popindent << + "}\n" << popindent << + "}\n" << popindent << + "}\n\n"; + } + auto emit_body = [&](APIMethod *p) { if (with_simd) { emit_simd_api_body(out, p, vars.scalars); diff --git a/modcc/printer/gpuprinter.cpp b/modcc/printer/gpuprinter.cpp index b0f67d1427a607586ab26fc7c4e3947f28e20016..31148051a7eed63a3a7b097cea096c2b729fb218 100644 --- a/modcc/printer/gpuprinter.cpp +++ b/modcc/printer/gpuprinter.cpp @@ -55,6 +55,7 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op auto ns_components = namespace_components(opt.cpp_namespace); NetReceiveExpression* net_receive = find_net_receive(module_); + PostEventExpression* post_event = find_post_event(module_); auto vars = local_module_variables(module_); auto ion_deps = module_.ion_deps(); @@ -85,6 +86,10 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op "void " << class_name << "_deliver_events_(int mech_id, " << ppack_name << "&, deliverable_event_stream_state events);\n"; + post_event && out << + "void " << class_name << "_post_event_(" << ppack_name << "&);\n"; + + out << "\n" "class " << class_name << ": public ::arb::gpu::mechanism {\n" @@ -115,6 +120,11 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op class_name << "_deliver_events_(mechanism_id_, pp_, events);\n" << popindent << "}\n\n"; + post_event && out << + "void post_event() override {\n" << indent << + class_name << "_post_event_(pp_);\n" << popindent << + "}\n\n"; + out << popindent << "protected:\n" << indent << "std::size_t object_sizeof() const override { return sizeof(*this); }\n" @@ -215,6 +225,7 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt const bool is_point_proc = module_.kind() == moduleKind::point; NetReceiveExpression* net_receive = find_net_receive(module_); + PostEventExpression* post_event = find_post_event(module_); APIMethod* init_api = find_api_method(module_, "nrn_init"); APIMethod* state_api = find_api_method(module_, "nrn_state"); APIMethod* current_api = find_api_method(module_, "nrn_current"); @@ -314,6 +325,27 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt << popindent << "}\n"; } + // event delivery + if (post_event) { + const std::string time_arg = post_event->args().empty() ? "time" : post_event->args().front()->is_argument()->name(); + out << "__global__\n" + << "void post_event(" << ppack_name << " params_) {\n" << indent + << "int n_ = params_.width_;\n" + << "auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n" + << "if (tid_<n_) {\n" << indent + << "auto node_index_i_ = params_.node_index_[tid_];\n" + << "auto cid_ = params_.vec_ci_[node_index_i_];\n" + << "auto offset_ = params_.n_detectors_ * cid_;\n" + << "for (unsigned c = 0; c < params_.n_detectors_; c++) {\n" << indent + << "auto " << time_arg << " = params_.time_since_spike_[offset_ + c];\n" + << "if (" << time_arg << " >= 0) {\n" << indent + << cuprint(post_event->body()) + << popindent << "}\n" + << popindent << "}\n" + << popindent << "}\n" + << popindent << "}\n"; + } + out << "} // namspace\n\n"; // close anonymous namespace // Write wrappers. @@ -346,6 +378,15 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt << "deliver_events<<<grid_dim, block_dim>>>(mech_id, p, events);\n" << popindent << "}\n\n"; + post_event && out + << "void " << class_name << "_post_event_(" + << ppack_name << "& p) {\n" << indent + << "auto n = p.width_;\n" + << "unsigned block_dim = 128;\n" + << "unsigned grid_dim = ::arb::gpu::impl::block_count(n, block_dim);\n" + << "post_event<<<grid_dim, block_dim>>>(p);\n" + << popindent << "}\n\n"; + out << namespace_declaration_close(ns_components); return out.str(); } diff --git a/modcc/printer/infoprinter.cpp b/modcc/printer/infoprinter.cpp index a5dcf6040426b70395de4e1c9fd4c8c1574754c1..7b4ea8928ce5d23dbdc2206ab4b92d3249546085 100644 --- a/modcc/printer/infoprinter.cpp +++ b/modcc/printer/infoprinter.cpp @@ -131,7 +131,9 @@ std::string build_info_header(const Module& m, const printer_options& opt) { "// fingerprint\n" << quote(fingerprint) << ",\n" "// linear, homogeneous mechanism\n" - << m.is_linear() << "\n" + << m.is_linear() << ",\n" + "// post_events enabled mechanism\n" + << m.has_post_events() << "\n" << popindent << "};\n" "\n" "return info;\n" diff --git a/modcc/printer/printerutil.cpp b/modcc/printer/printerutil.cpp index 787a8ebebeeb5be44f0223949f76ce5e3cd4db8f..5309e528f1c5f6c70f5c52c0c43cd1fba745e56a 100644 --- a/modcc/printer/printerutil.cpp +++ b/modcc/printer/printerutil.cpp @@ -45,7 +45,8 @@ std::vector<ProcedureExpression*> normal_procedures(const Module& m) { for (auto& sym: m.symbols()) { auto proc = sym.second->is_procedure(); - if (proc && proc->kind()==procedureKind::normal && !proc->is_api_method() && !proc->is_net_receive()) { + if (proc && proc->kind()==procedureKind::normal && !proc->is_api_method() + && !proc->is_net_receive() && !proc->is_post_event()) { procs.push_back(proc); } } @@ -113,6 +114,11 @@ NetReceiveExpression* find_net_receive(const Module& m) { return it==m.symbols().end()? nullptr: it->second->is_net_receive(); } +PostEventExpression* find_post_event(const Module& m) { + auto it = m.symbols().find("post_event"); + return it==m.symbols().end()? nullptr: it->second->is_post_event(); +} + indexed_variable_info decode_indexed_variable(IndexedVariable* sym) { indexed_variable_info v; v.node_index_var = "node_index_"; @@ -157,7 +163,7 @@ indexed_variable_info decode_indexed_variable(IndexedVariable* sym) { break; case sourceKind::time: v.data_var = "vec_t_"; - v.cell_index_var = "vec_ci_"; + v.cell_index_var = "vec_di_"; v.readonly = true; break; case sourceKind::ion_current_density: diff --git a/modcc/printer/printerutil.hpp b/modcc/printer/printerutil.hpp index 4fbae5286e5af4b150b39ec5ecb50000c030a568..902290a1c79fbb7b72180aea0d1c166440628270 100644 --- a/modcc/printer/printerutil.hpp +++ b/modcc/printer/printerutil.hpp @@ -114,6 +114,8 @@ APIMethod* find_api_method(const Module& m, const char* which); NetReceiveExpression* find_net_receive(const Module& m); +PostEventExpression* find_post_event(const Module& m); + struct indexed_variable_info { std::string data_var; std::string node_index_var; diff --git a/modcc/token.cpp b/modcc/token.cpp index 30ed8def41408a16c6f37022d4d1a7b021a0cd75..e2226c7e4f45ddf7b6f31c6ab776e8b7ae139543 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -36,6 +36,7 @@ static Keyword keywords[] = { {"FUNCTION", tok::function}, {"INITIAL", tok::initial}, {"NET_RECEIVE", tok::net_receive}, + {"POST_EVENT", tok::post_event}, {"UNITSOFF", tok::unitsoff}, {"UNITSON", tok::unitson}, {"SUFFIX", tok::suffix}, @@ -117,6 +118,7 @@ static TokenString token_strings[] = { {"FUNCTION", tok::function}, {"INITIAL", tok::initial}, {"NET_RECEIVE", tok::net_receive}, + {"POST_EVENT", tok::post_event}, {"UNITSOFF", tok::unitsoff}, {"UNITSON", tok::unitson}, {"SUFFIX", tok::suffix}, diff --git a/modcc/token.hpp b/modcc/token.hpp index 21648dcdb84a0094be95eeb02c0b18b1e5c927ce..e3a7a12ca96b3afd04cdf776c6bd32ed8ddb316b 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -54,7 +54,7 @@ enum class tok { neuron, units, parameter, constant, assigned, state, breakpoint, derivative, kinetic, procedure, initial, function, linear, - net_receive, + net_receive, post_event, // keywoards inside blocks unitsoff, unitson, diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 040752618a9af914f1131b527d42d18a85d060a5..dfda9efc50ffd91fee5daa2626f41dec0b25a53f 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -4,10 +4,14 @@ set(test_mechanisms ca_linear celsius_test diam_test + fixed_ica_current + linear_ca_conc + non_linear param_as_state - test_linear_state - test_linear_init - test_linear_init_shuffle + point_ica_current + post_events_syn + read_cai_init + read_eX test0_kin_diff test0_kin_conserve test0_kin_compartment @@ -15,25 +19,22 @@ set(test_mechanisms test1_kin_diff test1_kin_conserve test1_kin_compartment + test1_kin_steadystate test2_kin_diff test3_kin_diff test4_kin_compartment - test1_kin_steadystate - fixed_ica_current - point_ica_current - non_linear - linear_ca_conc - test_cl_valence - test_ca_read_valence - read_eX - write_Xi_Xo - write_multiple_eX - write_eX - read_cai_init - write_cai_breakpoint test_ca + test_ca_read_valence + test_cl_valence + test_linear_state + test_linear_init + test_linear_init_shuffle test_kin1 test_kinlva + write_cai_breakpoint + write_eX + write_multiple_eX + write_Xi_Xo ) include(${PROJECT_SOURCE_DIR}/mechanisms/BuildModules.cmake) diff --git a/test/unit/mod/post_events_syn.mod b/test/unit/mod/post_events_syn.mod new file mode 100644 index 0000000000000000000000000000000000000000..a0292b12ec6e3eab32b91233f24069ad044c6bd5 --- /dev/null +++ b/test/unit/mod/post_events_syn.mod @@ -0,0 +1,35 @@ +NEURON { + POINT_PROCESS post_events_syn + RANGE tau, e + NONSPECIFIC_CURRENT i +} + +PARAMETER { + tau = 2.0 (ms) : the default for Neuron is 0.1 + e = 0 (mV) +} + +STATE { + g +} + +INITIAL { + g=0 +} + +BREAKPOINT { + SOLVE state METHOD cnexp + i = g*(v - e) +} + +DERIVATIVE state { + g' = -g/tau +} + +NET_RECEIVE(weight) { + g = g + weight +} + +POST_EVENT(t) { + g = g + (0.1*t) +} diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index e95018ac0884a24fd91ac209fee1f47f96b41930..82b6bba4576243cef1ed183dbcd17f7670aa965c 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -604,6 +604,7 @@ TEST(fvm_lowered, ionic_concentrations) { std::vector<fvm_value_type> diam(ncv, 1.); std::vector<fvm_value_type> vinit(ncv, -65); std::vector<fvm_gap_junction> gj = {}; + std::vector<fvm_index_type> src_to_spike = {}; fvm_ion_config ion_config; mechanism_layout layout; @@ -627,7 +628,7 @@ TEST(fvm_lowered, ionic_concentrations) { auto& write_cai_mech = write_cai.mech; auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, diam, read_cai_mech->data_alignment()); + ncell, ncell, 0, cv_to_intdom, cv_to_intdom, gj, vinit, temp, diam, src_to_spike, read_cai_mech->data_alignment()); shared_state->add_ion("ca", 2, ion_config); read_cai_mech->instantiate(0, *shared_state, overrides, layout); @@ -1207,3 +1208,124 @@ TEST(fvm_lowered, integration_domains) { } } +TEST(fvm_lowered, post_events_shared_state) { + arb::proc_allocation resources; + if (auto nt = arbenv::get_env_num_threads()) { + resources.num_threads = nt; + } else { + resources.num_threads = arbenv::thread_concurrency(); + } + arb::execution_context context(resources); + + class detector_recipe: public arb::recipe { + public: + detector_recipe(unsigned ncv, std::vector<unsigned> detectors_per_cell, std::string synapse): + ncell_(detectors_per_cell.size()), + ncv_(ncv), + detectors_per_cell_(detectors_per_cell), + synapse_(synapse), + cat_(make_unit_test_catalogue()) { + const auto default_cat = arb::global_default_catalogue(); + cat_.import(default_cat, ""); + } + + cell_size_type num_cells() const override { + return ncell_; + } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + arb::segment_tree tree; + tree.append(arb::mnpos, {0, 0, 0.0, 1.0}, {0, 0, 200, 1.0}, 1); + + arb::decor decor; + decor.set_default(arb::cv_policy_fixed_per_branch(ncv_)); + + auto ndetectors = detectors_per_cell_[gid]; + auto offset = 1.0 / ndetectors; + for (unsigned i = 0; i < ndetectors; ++i) { + decor.place(arb::mlocation{0, offset * i}, arb::threshold_detector{10}); + } + decor.place(arb::mlocation{0, 0.5}, synapse_); + + return arb::cable_cell(arb::morphology(tree), {}, decor);; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable; + } + + // Each cell has one spike detector (at the soma). + cell_size_type num_sources(cell_gid_type gid) const override { + return detectors_per_cell_[gid]; + } + + // The cell has one target synapse, which will be connected to cell gid-1. + cell_size_type num_targets(cell_gid_type gid) const override { + return 1; + } + + std::any get_global_properties(arb::cell_kind) const override { + arb::cable_cell_global_properties gprop; + gprop.default_parameters = arb::neuron_parameter_defaults; + gprop.catalogue = &cat_; + return gprop; + } + + private: + unsigned ncell_; + unsigned ncv_; + std::vector<unsigned> detectors_per_cell_; + arb::mechanism_desc synapse_; + mechanism_catalogue cat_; + }; + + std::vector<target_handle> targets; + probe_association_map probe_map; + + std::vector<unsigned> gids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + const unsigned ncell = gids.size(); + const unsigned cv_per_cell = 10; + + std::vector<std::vector<unsigned>> detectors_per_cell_vec = { + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 6, 2, 1, 3, 2, 1, 2, 1, 4}, + }; + + for (const auto& detectors_per_cell: detectors_per_cell_vec) { + detector_recipe rec(cv_per_cell, detectors_per_cell, "post_events_syn"); + std::vector<fvm_index_type> cell_to_intdom; + + fvm_cell fvcell(context); + fvcell.initialize(gids, rec, cell_to_intdom, targets, probe_map); + + auto& S = fvcell.*private_state_ptr; + + auto expected_detectors = util::max_value(detectors_per_cell); + + EXPECT_EQ(expected_detectors, S->n_detector); + EXPECT_EQ(util::sum(detectors_per_cell), S->src_to_spike.size()); + EXPECT_EQ(expected_detectors * ncell, S->time_since_spike.size()); + + unsigned detector_id = 0; + for (unsigned c = 0; c < detectors_per_cell.size(); ++c) { + for (unsigned d = 0; d < detectors_per_cell[c]; ++d) { + EXPECT_EQ((int) (c * expected_detectors + d), S->src_to_spike[detector_id++]); + } + } + } + for (const auto& detectors_per_cell: detectors_per_cell_vec) { + detector_recipe rec(cv_per_cell, detectors_per_cell, "expsyn"); + std::vector<fvm_index_type> cell_to_intdom; + + fvm_cell fvcell(context); + fvcell.initialize(gids, rec, cell_to_intdom, targets, probe_map); + + auto& S = fvcell.*private_state_ptr; + + EXPECT_EQ(0u, S->n_detector); + EXPECT_EQ(0u, S->src_to_spike.size()); + EXPECT_EQ(0u, S->time_since_spike.size()); + } + +} diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp index feb1f8f696c1d4dd8bdb8c3601c1f74f2bf9dd60..b45731a265deea17ceebe74d960a492e140b253c 100644 --- a/test/unit/test_kinetic_linear.cpp +++ b/test/unit/test_kinetic_linear.cpp @@ -45,9 +45,10 @@ void run_test(std::string mech_name, std::vector<fvm_value_type> temp(ncv, 300.); std::vector<fvm_value_type> diam(ncv, 1.); std::vector<fvm_value_type> vinit(ncv, -65); + std::vector<fvm_index_type> src_to_spike = {}; auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, diam, test->data_alignment()); + ncell, ncell, 0, cv_to_intdom, cv_to_intdom, gj, vinit, temp, diam, src_to_spike, test->data_alignment()); mechanism_layout layout; mechanism_overrides overrides; diff --git a/test/unit/test_mech_temp_diam.cpp b/test/unit/test_mech_temp_diam.cpp index a40f8cfc2fe696c53fabd3d99c921bcacaf38c98..ba051960eab6c256b366f39e698748de25d2f2e4 100644 --- a/test/unit/test_mech_temp_diam.cpp +++ b/test/unit/test_mech_temp_diam.cpp @@ -34,9 +34,10 @@ void run_celsius_test() { std::vector<fvm_value_type> temp(ncv, temperature_K); std::vector<fvm_value_type> diam(ncv, 1.); std::vector<fvm_value_type> vinit(ncv, -65); + std::vector<fvm_index_type> src_to_spike = {}; auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, diam, celsius_test->data_alignment()); + ncell, ncell, 0, cv_to_intdom, cv_to_intdom, gj, vinit, temp, diam, src_to_spike, celsius_test->data_alignment()); mechanism_layout layout; mechanism_overrides overrides; @@ -81,6 +82,7 @@ void run_diam_test() { std::vector<fvm_value_type> temp(ncv, 300.); std::vector<fvm_value_type> vinit(ncv, -65); std::vector<fvm_value_type> diam(ncv); + std::vector<fvm_index_type> src_to_spike = {}; mechanism_layout layout; mechanism_overrides overrides; @@ -93,7 +95,7 @@ void run_diam_test() { } auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, diam, celsius_test->data_alignment()); + ncell, ncell, 0, cv_to_intdom, cv_to_intdom, gj, vinit, temp, diam, src_to_spike, celsius_test->data_alignment()); celsius_test->instantiate(0, *shared_state, overrides, layout); diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp index dd1f8d1e1298c4a8bc9c998f5b569cfe4f68f1b1..d207f586c7adff0180f456ea70126cb83ad57f9e 100644 --- a/test/unit/test_spikes.cpp +++ b/test/unit/test_spikes.cpp @@ -42,6 +42,10 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { const std::vector<index_type> index{0, 5, 7}; const std::vector<value_type> thresh{1., 2., 3.}; + std::vector<int> src_to_spike_vec = {0, 1, 5}; + std::vector<fvm_value_type> time_since_spike_vec(10); + memory::fill(time_since_spike_vec, -1.0); + // all values are initially 0, except for values[5] which we set // to exceed the threshold of 2. for the second watch array values(n, 0); @@ -57,11 +61,18 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { array time_before(2, 0.); array time_after(2, 0.); + iarray src_to_spike(src_to_spike_vec.size()); + memory::copy(src_to_spike_vec, src_to_spike); + + array time_since_spike(10, -1.0); + std::vector<unsigned> empty_slots = {2, 3, 4, 6, 7, 8, 9}; + // list for storing expected crossings for validation at the end list expected; // create the watch - backend::threshold_watcher watch(cell_index.data(), values.data(), &time_before, &time_after, index, thresh, context); + backend::threshold_watcher watch(cell_index.data(), values.data(), src_to_spike.data(), + &time_before, &time_after, index, thresh, context); // initially the first and third watch should not be spiking // the second is spiking @@ -72,18 +83,23 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { // test again at t=1, with unchanged values // - nothing should change memory::fill(time_after, 1.); - watch.test(); + watch.test(&time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_TRUE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); EXPECT_EQ(watch.crossings().size(), 0u); + memory::copy(time_since_spike, time_since_spike_vec); + for (auto t: time_since_spike_vec) { + EXPECT_EQ(-1.0, t); + } + // test at t=2, with all values set to zero // - 2nd watch should now stop spiking memory::fill(values, 0.); memory::copy(time_after, time_before); memory::fill(time_after, 2.); - watch.test(); + watch.test(&time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); @@ -95,7 +111,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { memory::copy(time_after, time_before); time_after[0] = 2.5; time_after[1] = 3.0; - watch.test(); + watch.test(&time_since_spike); EXPECT_TRUE(watch.is_crossed(0)); EXPECT_TRUE(watch.is_crossed(1)); EXPECT_TRUE(watch.is_crossed(2)); @@ -106,29 +122,50 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { expected.push_back({1u, 2.250f}); // 2. + (2.5-2)*(2./4.) expected.push_back({2u, 2.750f}); // 2. + (3.0-2)*(3./4.) + memory::copy(time_since_spike, time_since_spike_vec); + EXPECT_EQ(0.375, time_since_spike_vec[src_to_spike[0]]); // 2.5 - 2.125 + EXPECT_EQ(0.250, time_since_spike_vec[src_to_spike[1]]); // 2.5 - 2.250 + EXPECT_EQ(0.250, time_since_spike_vec[src_to_spike[2]]); // 3.0 - 2.750 + for (auto i: empty_slots) { + EXPECT_EQ(-1.0, time_since_spike_vec[i]); + } + // test at t=4, with all values set to 0. // - all watches should stop spiking memory::fill(values, 0.); memory::copy(time_after, time_before); memory::fill(time_after, 4.); - watch.test(); + watch.test(&time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); EXPECT_EQ(watch.crossings().size(), 3u); + memory::copy(time_since_spike, time_since_spike_vec); + for (auto t: time_since_spike_vec) { + EXPECT_EQ(-1.0, t); + } + // test at t=5, with value on 3rd watch set to 6 // - watch 3 should be spiking values[index[2]] = 6.; memory::copy(time_after, time_before); memory::fill(time_after, 5.); - watch.test(); + watch.test(&time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_TRUE(watch.is_crossed(2)); EXPECT_EQ(watch.crossings().size(), 4u); expected.push_back({2u, 4.5f}); + memory::copy(time_since_spike, time_since_spike_vec); + EXPECT_EQ(-1.0, time_since_spike_vec[src_to_spike[0]]); + EXPECT_EQ(-1.0, time_since_spike_vec[src_to_spike[1]]); + EXPECT_EQ(0.50, time_since_spike_vec[src_to_spike[2]]); + for (auto i: empty_slots) { + EXPECT_EQ(-1.0, time_since_spike_vec[i]); + } + // // test that all generated spikes matched the expected values // diff --git a/test/unit/test_synapses.cpp b/test/unit/test_synapses.cpp index daab5fe6bb26ad6fb7e407fa580bf6b7e60a6488..4f3157234b8054c0041d54a8686ee358c5efa694 100644 --- a/test/unit/test_synapses.cpp +++ b/test/unit/test_synapses.cpp @@ -92,11 +92,15 @@ TEST(synapses, syn_basic_state) { auto align = std::max(expsyn->data_alignment(), exp2syn->data_alignment()); shared_state state(num_intdom, + num_intdom, + 0, + std::vector<index_type>(num_comp, 0), std::vector<index_type>(num_comp, 0), {}, std::vector<value_type>(num_comp, -65), std::vector<value_type>(num_comp, temp_K), std::vector<value_type>(num_comp, 1.), + std::vector<index_type>(0), align); state.reset(); diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index 1d3c5112acc171b360e282708fa9d623af2d782d..f137d4e9dc5bad5edcf9577766522f3d64753b38 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -10,8 +10,9 @@ #include "mechanisms/ca_linear.hpp" #include "mechanisms/celsius_test.hpp" #include "mechanisms/diam_test.hpp" -#include "mechanisms/param_as_state.hpp" #include "mechanisms/non_linear.hpp" +#include "mechanisms/param_as_state.hpp" +#include "mechanisms/post_events_syn.hpp" #include "mechanisms/test0_kin_diff.hpp" #include "mechanisms/test_linear_state.hpp" #include "mechanisms/test_linear_init.hpp" @@ -66,6 +67,7 @@ mechanism_catalogue make_unit_test_catalogue(const mechanism_catalogue& from) { ADD_MECH(cat, celsius_test) ADD_MECH(cat, diam_test) ADD_MECH(cat, param_as_state) + ADD_MECH(cat, post_events_syn) ADD_MECH(cat, test_linear_state) ADD_MECH(cat, test_linear_init) ADD_MECH(cat, test_linear_init_shuffle)