From 1d3ad8dbbfb1bdf6fd2149a29475fb726ebf90d1 Mon Sep 17 00:00:00 2001 From: Nora Abi Akar <nora.abiakar@gmail.com> Date: Thu, 5 May 2022 08:56:49 +0200 Subject: [PATCH] Bug fix: Fix voltage vector size in threshold_watcher contstructor (#1820) Fix a crash in GPU code when threshold detector is used for the first time and its internal state is allocated with a wrong size. --- arbor/backends/gpu/fvm.hpp | 6 +++--- arbor/backends/gpu/threshold_watcher.cu | 4 ++-- arbor/backends/gpu/threshold_watcher.hpp | 20 +++++++++++-------- arbor/backends/multicore/fvm.hpp | 6 +++--- .../backends/multicore/threshold_watcher.hpp | 17 +++++++++------- arbor/fvm_lowered_cell_impl.hpp | 2 +- test/unit/test_spikes.cpp | 8 +++++--- 7 files changed, 36 insertions(+), 27 deletions(-) diff --git a/arbor/backends/gpu/fvm.hpp b/arbor/backends/gpu/fvm.hpp index ef025bbf..712b2fc7 100644 --- a/arbor/backends/gpu/fvm.hpp +++ b/arbor/backends/gpu/fvm.hpp @@ -53,17 +53,17 @@ struct backend { static threshold_watcher voltage_watcher( shared_state& state, - const std::vector<index_type>& cv, + const std::vector<index_type>& detector_cv, const std::vector<value_type>& thresholds, const execution_context& context) { return threshold_watcher( state.cv_to_intdom.data(), - state.voltage.data(), state.src_to_spike.data(), &state.time, &state.time_to, - cv, + state.voltage.size(), + detector_cv, thresholds, context); } diff --git a/arbor/backends/gpu/threshold_watcher.cu b/arbor/backends/gpu/threshold_watcher.cu index 6748a482..bf048243 100644 --- a/arbor/backends/gpu/threshold_watcher.cu +++ b/arbor/backends/gpu/threshold_watcher.cu @@ -46,7 +46,7 @@ void test_thresholds_impl( // Test for threshold crossing const auto cv = cv_index[i]; const auto intdom = cv_to_intdom[cv]; - const auto v_prev = prev_values[i]; + const auto v_prev = prev_values[cv]; const auto v = values[cv]; const auto thresh = thresholds[i]; fvm_index_type spike_idx = 0; @@ -75,7 +75,7 @@ void test_thresholds_impl( is_crossed[i]=0; } - prev_values[i] = v; + prev_values[cv] = v; } if (crossed) { diff --git a/arbor/backends/gpu/threshold_watcher.hpp b/arbor/backends/gpu/threshold_watcher.hpp index 6d371faa..2c391281 100644 --- a/arbor/backends/gpu/threshold_watcher.hpp +++ b/arbor/backends/gpu/threshold_watcher.hpp @@ -1,6 +1,7 @@ #pragma once #include <arbor/arbexcept.hpp> +#include <arbor/assert.hpp> #include <arbor/common_types.hpp> #include <arbor/fvm_types.hpp> @@ -45,29 +46,28 @@ public: threshold_watcher( const fvm_index_type* cv_to_intdom, - const fvm_value_type* values, const fvm_index_type* src_to_spike, const array* t_before, const array* t_after, - const std::vector<fvm_index_type>& cv_index, + const fvm_size_type num_cv, + const std::vector<fvm_index_type>& detector_cv_idx, const std::vector<fvm_value_type>& thresholds, const execution_context& ctx ): cv_to_intdom_(cv_to_intdom), - values_(values), src_to_spike_(src_to_spike), t_before_ptr_(t_before), t_after_ptr_(t_after), - cv_index_(memory::make_const_view(cv_index)), - is_crossed_(cv_index.size()), + cv_index_(memory::make_const_view(detector_cv_idx)), + is_crossed_(detector_cv_idx.size()), thresholds_(memory::make_const_view(thresholds)), - v_prev_(memory::const_host_view<fvm_value_type>(values, cv_index.size())), + v_prev_(num_cv), // TODO: allocates enough space for 10 spikes per watch. // A more robust approach might be needed to avoid overflows. stack_(10*size(), ctx.gpu) { crossings_.reserve(stack_.capacity()); - reset(); + // reset() needs to be called before this is ready for use } /// Remove all stored crossings that were detected in previous calls to test() @@ -79,7 +79,9 @@ public: /// Reset state machine for each detector. /// Assume that the values in values_ have been set correctly before /// calling, because the values are used to determine the initial state - void reset() { + void reset(const array& values) { + values_ = values.data(); + memory::copy(values, v_prev_); clear_crossings(); if (size()>0) { reset_crossed_impl((int)size(), is_crossed_.data(), cv_index_.data(), values_, thresholds_.data()); @@ -108,6 +110,8 @@ public: /// crossed since current time t, and the last time the test was /// performed. void test(array* time_since_spike) { + arb_assert(values_); + if (size()>0) { test_thresholds_impl( (int)size(), diff --git a/arbor/backends/multicore/fvm.hpp b/arbor/backends/multicore/fvm.hpp index 6e21311d..a9064357 100644 --- a/arbor/backends/multicore/fvm.hpp +++ b/arbor/backends/multicore/fvm.hpp @@ -51,17 +51,17 @@ struct backend { static threshold_watcher voltage_watcher( shared_state& state, - const std::vector<index_type>& cv, + const std::vector<index_type>& detector_cv, const std::vector<value_type>& thresholds, const execution_context& context) { return threshold_watcher( state.cv_to_intdom.data(), - state.voltage.data(), state.src_to_spike.data(), &state.time, &state.time_to, - cv, + state.voltage.size(), + detector_cv, thresholds, context); } diff --git a/arbor/backends/multicore/threshold_watcher.hpp b/arbor/backends/multicore/threshold_watcher.hpp index 9b2d5e53..76a260d3 100644 --- a/arbor/backends/multicore/threshold_watcher.hpp +++ b/arbor/backends/multicore/threshold_watcher.hpp @@ -19,16 +19,15 @@ public: threshold_watcher( const fvm_index_type* cv_to_intdom, - const fvm_value_type* values, const fvm_index_type* src_to_spike, const array* t_before, const array* t_after, + const fvm_size_type num_cv, const std::vector<fvm_index_type>& cv_index, const std::vector<fvm_value_type>& thresholds, const execution_context& context ): cv_to_intdom_(cv_to_intdom), - values_(values), src_to_spike_(src_to_spike), t_before_ptr_(t_before), t_after_ptr_(t_after), @@ -36,10 +35,10 @@ public: cv_index_(cv_index), is_crossed_(n_cv_), thresholds_(thresholds), - v_prev_(values_, values_+n_cv_) + v_prev_(num_cv) { arb_assert(n_cv_==thresholds.size()); - reset(); + // reset() needs to be called before this is ready for use } /// Remove all stored crossings that were detected in previous calls @@ -51,7 +50,9 @@ public: /// Reset state machine for each detector. /// Assume that the values in values_ have been set correctly before /// calling, because the values are used to determine the initial state - void reset() { + void reset(const array& values) { + values_ = values.data(); + std::copy(values.begin(), values.end(), v_prev_.begin()); clear_crossings(); for (fvm_size_type i = 0; i<n_cv_; ++i) { is_crossed_[i] = values_[cv_index_[i]]>=thresholds_[i]; @@ -66,13 +67,15 @@ public: /// Crossing events are recorded for each threshold that /// is crossed since the last call to test void test(array* time_since_spike) { + arb_assert(values_); + // Reset all spike times to -1.0 indicating no spike has been recorded on the detector const fvm_value_type* t_before = t_before_ptr_->data(); const fvm_value_type* t_after = t_after_ptr_->data(); for (fvm_size_type i = 0; i<n_cv_; ++i) { auto cv = cv_index_[i]; auto intdom = cv_to_intdom_[cv]; - auto v_prev = v_prev_[i]; + auto v_prev = v_prev_[cv]; auto v = values_[cv]; auto thresh = thresholds_[i]; fvm_index_type spike_idx = 0; @@ -103,7 +106,7 @@ public: } } - v_prev_[i] = v; + v_prev_[cv] = v; } } diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index ddfe104e..fcd7c4ee 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -185,7 +185,7 @@ void fvm_lowered_cell_impl<Backend>::reset() { // NOTE: Threshold watcher reset must come after the voltage values are set, // as voltage is implicitly read by watcher to set initial state. - threshold_watcher_.reset(); + threshold_watcher_.reset(state_->voltage); } template <typename Backend> diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp index 882de042..12a53234 100644 --- a/test/unit/test_spikes.cpp +++ b/test/unit/test_spikes.cpp @@ -76,8 +76,10 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { list expected; // create the watch - backend::threshold_watcher watch(cell_index.data(), values.data(), src_to_spike.data(), - &time_before, &time_after, index, thresh, context); + backend::threshold_watcher watch(cell_index.data(), src_to_spike.data(), + &time_before, &time_after, + values.size(), index, thresh, context); + watch.reset(values); // initially the first and third watch should not be spiking // the second is spiking @@ -197,7 +199,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { memory::fill(values, 0); values[index[0]] = 10.; // first watch should be intialized to spiking state memory::fill(time_before, 0.); - watch.reset(); + watch.reset(values); EXPECT_EQ(watch.crossings().size(), 0u); EXPECT_TRUE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); -- GitLab