diff --git a/arbor/backends/gpu/fvm.hpp b/arbor/backends/gpu/fvm.hpp index 0fa52b63dc0db048b2dd882c602652cfd1220cb3..5cddad532eeb97a5500fa54fbc9f03fdb53bcb5e 100644 --- a/arbor/backends/gpu/fvm.hpp +++ b/arbor/backends/gpu/fvm.hpp @@ -57,9 +57,9 @@ struct backend { { return threshold_watcher( state.cv_to_intdom.data(), - state.time.data(), - state.time_to.data(), state.voltage.data(), + &state.time, + &state.time_to, cv, thresholds, context); diff --git a/arbor/backends/gpu/mechanism.cpp b/arbor/backends/gpu/mechanism.cpp index fc175c65604cf840e8933338b2a30502d4afe85d..e577be3f37cac83d49614fd6738706358160d711 100644 --- a/arbor/backends/gpu/mechanism.cpp +++ b/arbor/backends/gpu/mechanism.cpp @@ -85,8 +85,6 @@ void mechanism::instantiate(unsigned id, pp->width_ = width_; pp->vec_ci_ = shared.cv_to_intdom.data(); - pp->vec_t_ = shared.time.data(); - pp->vec_t_to_ = shared.time_to.data(); pp->vec_dt_ = shared.dt_cv.data(); pp->vec_v_ = shared.voltage.data(); @@ -116,6 +114,8 @@ void mechanism::instantiate(unsigned id, } event_stream_ptr_ = &shared.deliverable_events; + vec_t_ptr_ = &shared.time; + vec_t_to_ptr_ = &shared.time_to; // If there are no sites (is this ever meaningful?) there is nothing more to do. if (width_==0) { @@ -208,8 +208,10 @@ fvm_value_type* mechanism::field_data(const std::string& field_var) { void multiply_in_place(fvm_value_type* s, const fvm_index_type* p, int n); void mechanism::initialize() { - nrn_init(); mechanism_ppack_base* pp = ppack_ptr(); + pp->vec_t_ = vec_t_ptr_->data(); + + nrn_init(); auto states = state_table(); if(mult_in_place_) { diff --git a/arbor/backends/gpu/mechanism.hpp b/arbor/backends/gpu/mechanism.hpp index 402507d690afe8af7638eb9035ef5e4d3464d749..8f04a2383d83e786b65ea2382652ef82172c8ce6 100644 --- a/arbor/backends/gpu/mechanism.hpp +++ b/arbor/backends/gpu/mechanism.hpp @@ -52,6 +52,21 @@ public: // Delegate to derived class, passing in event queue state. deliver_events(event_stream_ptr_->marked_events()); } + void update_current() override { + mechanism_ppack_base* pp = ppack_ptr(); + pp->vec_t_ = vec_t_ptr_->data(); + nrn_current(); + } + void update_state() override { + mechanism_ppack_base* pp = ppack_ptr(); + pp->vec_t_ = vec_t_ptr_->data(); + nrn_state(); + } + void update_ions() override { + mechanism_ppack_base* pp = ppack_ptr(); + pp->vec_t_ = vec_t_ptr_->data(); + write_ions(); + } void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; @@ -75,6 +90,8 @@ protected: virtual mechanism_ppack_base* ppack_ptr() = 0; deliverable_event_stream* event_stream_ptr_; + const array* vec_t_ptr_; + const array* vec_t_to_ptr_; // Bulk storage for index vectors and state and parameter variables. @@ -127,7 +144,10 @@ protected: // Event delivery, given event queue state: + virtual void nrn_state() {}; + virtual void nrn_current() {}; virtual void deliver_events(deliverable_event_stream::state) {}; + virtual void write_ions() {}; }; } // namespace gpu diff --git a/arbor/backends/gpu/threshold_watcher.hpp b/arbor/backends/gpu/threshold_watcher.hpp index 5c5fb246514b3142083e385a39ff1b6cbec28f1b..bf533073b4a3b2ecdc072dbfdfaf2bec2cc75619 100644 --- a/arbor/backends/gpu/threshold_watcher.hpp +++ b/arbor/backends/gpu/threshold_watcher.hpp @@ -44,17 +44,17 @@ public: threshold_watcher( const fvm_index_type* cv_to_intdom, - const fvm_value_type* t_before, - const fvm_value_type* t_after, const fvm_value_type* values, + const array* t_before, + const array* t_after, const std::vector<fvm_index_type>& cv_index, const std::vector<fvm_value_type>& thresholds, const execution_context& ctx ): cv_to_intdom_(cv_to_intdom), - t_before_(t_before), - t_after_(t_after), values_(values), + t_before_ptr_(t_before), + t_after_ptr_(t_after), cv_index_(memory::make_const_view(cv_index)), is_crossed_(cv_index.size()), thresholds_(memory::make_const_view(thresholds)), @@ -105,10 +105,13 @@ public: /// crossed since current time t, and the last time the test was /// performed. void test() { + const fvm_value_type* t_before = t_before_ptr_->data(); + const fvm_value_type* t_after = t_after_ptr_->data(); + if (size()>0) { test_thresholds_impl( (int)size(), - cv_to_intdom_, t_after_, t_before_, + cv_to_intdom_, t_after, t_before, stack_.storage(), is_crossed_.data(), v_prev_.data(), cv_index_.data(), values_, thresholds_.data()); @@ -124,12 +127,13 @@ public: } private: - /// Non-owning pointers to gpu-side cv-to-cell map, per-cell time data, - /// and the values for to test against thresholds. + /// Non-owning pointers to cv-to-intdom map, + /// the values for to test against thresholds, + /// and pointers to the time arrays const fvm_index_type* cv_to_intdom_ = nullptr; - const fvm_value_type* t_before_ = nullptr; - const fvm_value_type* t_after_ = nullptr; const fvm_value_type* values_ = nullptr; + const array* t_before_ptr_ = nullptr; + const array* t_after_ptr_ = nullptr; // Threshold watch state, with data on gpu: iarray cv_index_; // Compartment indexes of values to watch. diff --git a/arbor/backends/multicore/fvm.hpp b/arbor/backends/multicore/fvm.hpp index d7b233de9faeb8a1bfa1840ea0aa992775e3ecb9..e81b7d7dcf984733b128fd2703d925128dc85115 100644 --- a/arbor/backends/multicore/fvm.hpp +++ b/arbor/backends/multicore/fvm.hpp @@ -55,9 +55,9 @@ struct backend { { return threshold_watcher( state.cv_to_intdom.data(), - state.time.data(), - state.time_to.data(), state.voltage.data(), + &state.time, + &state.time_to, cv, thresholds, context); diff --git a/arbor/backends/multicore/mechanism.cpp b/arbor/backends/multicore/mechanism.cpp index 2aec088c825c1091fe1d918d498a79e0ba51d387..285c27dd189e525e42c2ef1a39674d51624349cb 100644 --- a/arbor/backends/multicore/mechanism.cpp +++ b/arbor/backends/multicore/mechanism.cpp @@ -84,8 +84,6 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me // Assign non-owning views onto shared state: vec_ci_ = shared.cv_to_intdom.data(); - vec_t_ = shared.time.data(); - vec_t_to_ = shared.time_to.data(); vec_dt_ = shared.dt_cv.data(); vec_v_ = shared.voltage.data(); @@ -113,6 +111,8 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me ion_view.ionic_charge = oion->charge.data(); } + vec_t_ptr_ = &shared.time; + vec_t_to_ptr_ = &shared.time_to; event_stream_ptr_ = &shared.deliverable_events; // If there are no sites (is this ever meaningful?) there is nothing more to do. @@ -198,6 +198,7 @@ void mechanism::set_parameter(const std::string& key, const std::vector<fvm_valu } void mechanism::initialize() { + vec_t_ = vec_t_ptr_->data(); nrn_init(); auto states = state_table(); diff --git a/arbor/backends/multicore/mechanism.hpp b/arbor/backends/multicore/mechanism.hpp index cf4a331cb9584992eac5245d08b8901c7847e69a..140c5d78ebe43a703815db9ee47c8bbdc7f55926 100644 --- a/arbor/backends/multicore/mechanism.hpp +++ b/arbor/backends/multicore/mechanism.hpp @@ -62,6 +62,18 @@ public: // Delegate to derived class, passing in event queue state. deliver_events(event_stream_ptr_->marked_events()); } + void update_current() override { + vec_t_ = vec_t_ptr_->data(); + nrn_current(); + } + void update_state() override { + vec_t_ = vec_t_ptr_->data(); + nrn_state(); + } + void update_ions() override { + vec_t_ = vec_t_ptr_->data(); + write_ions(); + } void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; @@ -84,6 +96,9 @@ protected: value_type* vec_g_; // CV to cell membrane conductivity. const value_type* temperature_degC_; // CV to temperature. const value_type* diam_um_; // CV to diameter. + + const array* vec_t_ptr_; + const array* vec_t_to_ptr_; deliverable_event_stream* event_stream_ptr_; // Per-mechanism index and weight data, excepting ion indices. @@ -147,7 +162,10 @@ protected: // Event delivery, given event queue state: + virtual void nrn_state() {}; + virtual void nrn_current() {}; virtual void deliver_events(deliverable_event_stream::state) {}; + virtual void write_ions() {}; }; } // namespace multicore diff --git a/arbor/backends/multicore/threshold_watcher.hpp b/arbor/backends/multicore/threshold_watcher.hpp index e356f486cc78db4837542f5e6e56c7ffce865545..7074dda5353430dbf9a8ce0cc9846c5829221c53 100644 --- a/arbor/backends/multicore/threshold_watcher.hpp +++ b/arbor/backends/multicore/threshold_watcher.hpp @@ -19,17 +19,17 @@ public: threshold_watcher( const fvm_index_type* cv_to_intdom, - const fvm_value_type* t_before, - const fvm_value_type* t_after, const fvm_value_type* values, + const array* t_before, + const array* t_after, const std::vector<fvm_index_type>& cv_index, const std::vector<fvm_value_type>& thresholds, const execution_context& context ): cv_to_intdom_(cv_to_intdom), - t_before_(t_before), - t_after_(t_after), values_(values), + t_before_ptr_(t_before), + t_after_ptr_(t_after), n_cv_(cv_index.size()), cv_index_(cv_index), is_crossed_(n_cv_), @@ -64,6 +64,8 @@ public: /// Crossing events are recorded for each threshold that /// is crossed since the last call to test void test() { + const fvm_value_type* t_before = t_before_ptr_->data(); + const fvm_value_type* t_after = t_after_ptr_->data(); for (fvm_size_type i = 0; i<n_cv_; ++i) { auto cv = cv_index_[i]; auto cell = cv_to_intdom_[cv]; @@ -76,7 +78,7 @@ public: // The threshold has been passed, so estimate the time using // linear interpolation. auto pos = (thresh - v_prev)/(v - v_prev); - auto crossing_time = math::lerp(t_before_[cell], t_after_[cell], pos); + auto crossing_time = math::lerp(t_before[cell], t_after[cell], pos); crossings_.push_back({i, crossing_time}); is_crossed_[i] = true; @@ -102,12 +104,13 @@ public: } private: - /// Non-owning pointers to cv-to-cell map, per-cell time data, - /// and the values for to test against thresholds. + /// Non-owning pointers to cv-to-intdom map, + /// the values for to test against thresholds, + /// and pointers to the time arrays const fvm_index_type* cv_to_intdom_ = nullptr; - const fvm_value_type* t_before_ = nullptr; - const fvm_value_type* t_after_ = nullptr; const fvm_value_type* values_ = nullptr; + const array* t_before_ptr_ = nullptr; + const array* t_after_ptr_ = nullptr; /// Threshold watcher state. fvm_size_type n_cv_ = 0; diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index b5a46f192981d026be20b2b2afdba4015009fda6..5dd419c39f2c18a54a9dcbe934628db73f4047ad 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -228,7 +228,7 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( // Update any required reversal potentials based on ionic concs. for (auto& m: revpot_mechanisms_) { - m->nrn_current(); + m->update_current(); } @@ -243,7 +243,7 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( PL(); for (auto& m: mechanisms_) { m->deliver_events(); - m->nrn_current(); + m->update_current(); } // Add current contribution from gap_junctions @@ -279,7 +279,7 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( // Integrate mechanism state. for (auto& m: mechanisms_) { - m->nrn_state(); + m->update_state(); } // Update ion concentrations. @@ -330,7 +330,7 @@ template <typename Backend> void fvm_lowered_cell_impl<Backend>::update_ion_state() { state_->ions_init_concentration(); for (auto& m: mechanisms_) { - m->write_ions(); + m->update_ions(); } } diff --git a/arbor/include/arbor/mechanism.hpp b/arbor/include/arbor/mechanism.hpp index 37e497f865e7965c8591c7d94f6f188e697f9ffb..5e87a441fcb6986955858d6fdac54763cd4587bf 100644 --- a/arbor/include/arbor/mechanism.hpp +++ b/arbor/include/arbor/mechanism.hpp @@ -51,10 +51,10 @@ public: // Simulation interfaces: virtual void initialize() = 0; - virtual void nrn_state() = 0; - virtual void nrn_current() = 0; + virtual void update_state() {}; + virtual void update_current() {}; virtual void deliver_events() {}; - virtual void write_ions() = 0; + virtual void update_ions() {}; virtual ~mechanism() = default; diff --git a/test/ubench/mech_vec.cpp b/test/ubench/mech_vec.cpp index 6a531779b670503bacc436dcc8c7b8fb9b838a93..56d713e184df762ba5759f1d4e97432344dc77bb 100644 --- a/test/ubench/mech_vec.cpp +++ b/test/ubench/mech_vec.cpp @@ -279,7 +279,7 @@ void expsyn_1_branch_current(benchmark::State& state) { auto& m = find_mechanism("expsyn", cell); while (state.KeepRunning()) { - m->nrn_current(); + m->update_current(); } } @@ -299,7 +299,7 @@ void expsyn_1_branch_state(benchmark::State& state) { auto& m = find_mechanism("expsyn", cell); while (state.KeepRunning()) { - m->nrn_state(); + m->update_state(); } } @@ -318,7 +318,7 @@ void pas_1_branch_current(benchmark::State& state) { auto& m = find_mechanism("pas", cell); while (state.KeepRunning()) { - m->nrn_current(); + m->update_current(); } } @@ -337,7 +337,7 @@ void pas_3_branches_current(benchmark::State& state) { auto& m = find_mechanism("pas", cell); while (state.KeepRunning()) { - m->nrn_current(); + m->update_current(); } } @@ -356,7 +356,7 @@ void hh_1_branch_state(benchmark::State& state) { auto& m = find_mechanism("hh", cell); while (state.KeepRunning()) { - m->nrn_state(); + m->update_state(); } } @@ -375,7 +375,7 @@ void hh_1_branch_current(benchmark::State& state) { auto& m = find_mechanism("hh", cell); while (state.KeepRunning()) { - m->nrn_current(); + m->update_current(); } } @@ -394,7 +394,7 @@ void hh_3_branches_state(benchmark::State& state) { auto& m = find_mechanism("hh", cell); while (state.KeepRunning()) { - m->nrn_state(); + m->update_state(); } } @@ -413,7 +413,7 @@ void hh_3_branches_current(benchmark::State& state) { auto& m = find_mechanism("hh", cell); while (state.KeepRunning()) { - m->nrn_current(); + m->update_current(); } } diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index 4e7027558662f9885db22c585d89b965a1e8ae2b..455ed8257f6f8d47456f4459b0a0ccf3fbf8ff27 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -381,7 +381,7 @@ TEST(fvm_lowered, stimulus) { // Test that no current is injected at t=0. memory::fill(J, 0.); memory::fill(T, 0.); - stim->nrn_current(); + stim->update_current(); for (auto j: J) { EXPECT_EQ(j, 0.); @@ -390,19 +390,19 @@ TEST(fvm_lowered, stimulus) { // Test that 0.1 nA current is injected at soma at t=1. memory::fill(J, 0.); memory::fill(T, 1.); - stim->nrn_current(); + stim->update_current(); constexpr double unit_factor = 1e-3; // scale A/m²·µm² to nA EXPECT_DOUBLE_EQ(-0.1, J[soma_cv]*A[soma_cv]*unit_factor); // Test that 0.1 nA is again injected at t=1.5, for a total of 0.2 nA. memory::fill(T, 1.); - stim->nrn_current(); + stim->update_current(); EXPECT_DOUBLE_EQ(-0.2, J[soma_cv]*A[soma_cv]*unit_factor); // Test that at t=10, no more current is injected at soma, and that // that 0.3 nA is injected at dendrite tip. memory::fill(T, 10.); - stim->nrn_current(); + stim->update_current(); EXPECT_DOUBLE_EQ(-0.2, J[soma_cv]*A[soma_cv]*unit_factor); EXPECT_DOUBLE_EQ(-0.3, J[tip_cv]*A[tip_cv]*unit_factor); } @@ -640,13 +640,13 @@ TEST(fvm_lowered, ionic_concentrations) { EXPECT_EQ(expected_s_values, mechanism_field(read_cai_mech.get(), "s")); // expect 5.2 + 2.3 value in state 's' in read_cai_init after state update: - read_cai_mech->nrn_state(); - write_cai_mech->nrn_state(); + read_cai_mech->update_state(); + write_cai_mech->update_state(); - read_cai_mech->write_ions(); - write_cai_mech->write_ions(); + read_cai_mech->update_ions(); + write_cai_mech->update_ions(); - read_cai_mech->nrn_state(); + read_cai_mech->update_state(); expected_s_values.assign(ncv, 7.5e-4); EXPECT_EQ(expected_s_values, mechanism_field(read_cai_mech.get(), "s")); @@ -857,7 +857,7 @@ TEST(fvm_lowered, weighted_write_ion) { } ion.init_concentration(); - test_ca->write_ions(); + test_ca->update_ions(); std::vector<double> ion_iconc = util::assign_from(ion.Xi_); EXPECT_TRUE(testing::seq_almost_eq<double>(expected_iconc, ion_iconc)); } diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp index 7ed540cc6f18616ff24215fe11d01031f0de5f96..feb1f8f696c1d4dd8bdb8c3601c1f74f2bf9dd60 100644 --- a/test/unit/test_kinetic_linear.cpp +++ b/test/unit/test_kinetic_linear.cpp @@ -78,7 +78,7 @@ void run_test(std::string mech_name, shared_state->update_time_to(dt, dt); shared_state->set_dt(); - test->nrn_state(); + test->update_state(); if (!t1_values.empty()) { for (unsigned i = 0; i < state_variables.size(); i++) { diff --git a/test/unit/test_mech_temp_diam.cpp b/test/unit/test_mech_temp_diam.cpp index fbda2292660c7f75b21f787e69a5b267ed0b8648..a40f8cfc2fe696c53fabd3d99c921bcacaf38c98 100644 --- a/test/unit/test_mech_temp_diam.cpp +++ b/test/unit/test_mech_temp_diam.cpp @@ -58,7 +58,7 @@ void run_celsius_test() { // expect temperature_C value in state 'c' after state update: - celsius_test->nrn_state(); + celsius_test->update_state(); expected_c_values.assign(ncv, temperature_C); EXPECT_EQ(expected_c_values, mechanism_field(celsius_test.get(), "c")); @@ -108,7 +108,7 @@ void run_diam_test() { // expect original diam values in state 'd' after state update: - celsius_test->nrn_state(); + celsius_test->update_state(); expected_d_values = diam; EXPECT_EQ(expected_d_values, mechanism_field(celsius_test.get(), "d")); diff --git a/test/unit/test_mechanisms.cpp b/test/unit/test_mechanisms.cpp index 1c991d59dc0fc001e59ba40037851776a2e7d27b..f281e6cb41834353274f41414db63f546c44dd18 100644 --- a/test/unit/test_mechanisms.cpp +++ b/test/unit/test_mechanisms.cpp @@ -102,8 +102,8 @@ void mech_update(T* mech, unsigned num_iters) { } for (auto i=0u; i<num_iters; ++i) { - mech->nrn_current(); - mech->nrn_state(); + mech->update_current(); + mech->update_state(); } } diff --git a/test/unit/test_mechcat.cpp b/test/unit/test_mechcat.cpp index ffd6b26510c9d79564609062db3bb7ae7c66cd60..1b3656765d1fe35bb812f3f309ecbd7938fb679b 100644 --- a/test/unit/test_mechcat.cpp +++ b/test/unit/test_mechcat.cpp @@ -75,10 +75,10 @@ struct common_impl: concrete_mechanism<B> { void set_parameter(const std::string& key, const std::vector<fvm_value_type>& vs) override {} void initialize() override {} - void nrn_state() override {} - void nrn_current() override {} + void update_state() override {} + void update_current() override {} void deliver_events() override {} - void write_ions() override {} + void update_ions() override {} std::size_t width_ = 0; diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp index 7be54ece59ff9d82b6efa8ea46263c87b8dad846..f3b0fcbcb2dab2bdbc7a3e52d036f87f245ae573 100644 --- a/test/unit/test_spikes.cpp +++ b/test/unit/test_spikes.cpp @@ -1,11 +1,18 @@ #include "../gtest.h" +#include <arborenv/concurrency.hpp> +#include <arborenv/gpu_env.hpp> + +#include <arbor/load_balance.hpp> +#include <arbor/simulation.hpp> #include <arbor/spike.hpp> #include <backends/multicore/fvm.hpp> #include <memory/memory.hpp> #include <util/rangeutil.hpp> +#include <simple_recipes.hpp> + using namespace arb; // This source is included in `test_spikes_gpu.cpp`, which defines @@ -54,7 +61,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { list expected; // create the watch - backend::threshold_watcher watch(cell_index.data(), time_before.data(), time_after.data(), values.data(), index, thresh, context); + backend::threshold_watcher watch(cell_index.data(), values.data(), &time_before, &time_after, index, thresh, context); // initially the first and third watch should not be spiking // the second is spiking @@ -155,3 +162,46 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { EXPECT_FALSE(watch.is_crossed(2)); } +TEST(SPIKES_TEST_CLASS, threshold_watcher_interpolation) { + double dt = 0.025; + double duration = 1; + + arb::segment_tree tree; + tree.append(arb::mnpos, { -6.3, 0.0, 0.0, 6.3}, { 6.3, 0.0, 0.0, 6.3}, 1); + arb::morphology morpho(tree); + + arb::label_dict dict; + dict.set("mid", arb::ls::on_branches(0.5)); + + arb::proc_allocation resources; + resources.gpu_id = arbenv::default_gpu(); + auto context = arb::make_context(resources); + + std::vector<arb::spike> spikes; + + for (unsigned i = 0; i < 8; i++) { + arb::cable_cell cell(morpho, dict); + cell.default_parameters.discretization = arb::cv_policy_every_segment(); + cell.place("\"mid\"", arb::threshold_detector{10}); + cell.place("\"mid\"", arb::i_clamp(0.01+i*dt, duration, 0.5)); + cell.place("\"mid\"", arb::mechanism_desc("hh")); + + cable1d_recipe rec({cell}); + + auto decomp = arb::partition_load_balance(rec, context); + arb::simulation sim(rec, decomp, context); + + sim.set_global_spike_callback( + [&spikes](const std::vector<arb::spike>& recorded_spikes) { + spikes.insert(spikes.end(), recorded_spikes.begin(), recorded_spikes.end()); + }); + + sim.run(duration, dt); + ASSERT_EQ(1u, sim.num_spikes()); + } + + for (unsigned i = 1; i < spikes.size(); ++i) { + EXPECT_NEAR(dt, spikes[i].time - spikes[i-1].time, 1e-4); + } +} +