Skip to content
Snippets Groups Projects
Unverified Commit 1d3ad8db authored by Nora Abi Akar's avatar Nora Abi Akar Committed by GitHub
Browse files

Bug fix: Fix voltage vector size in threshold_watcher contstructor (#1820)

Fix a crash in GPU code when threshold detector is used for the first time and its internal state is allocated with a wrong size.
parent c5b1436f
No related branches found
No related tags found
No related merge requests found
......@@ -53,17 +53,17 @@ struct backend {
static threshold_watcher voltage_watcher(
shared_state& state,
const std::vector<index_type>& cv,
const std::vector<index_type>& detector_cv,
const std::vector<value_type>& thresholds,
const execution_context& context)
{
return threshold_watcher(
state.cv_to_intdom.data(),
state.voltage.data(),
state.src_to_spike.data(),
&state.time,
&state.time_to,
cv,
state.voltage.size(),
detector_cv,
thresholds,
context);
}
......
......@@ -46,7 +46,7 @@ void test_thresholds_impl(
// Test for threshold crossing
const auto cv = cv_index[i];
const auto intdom = cv_to_intdom[cv];
const auto v_prev = prev_values[i];
const auto v_prev = prev_values[cv];
const auto v = values[cv];
const auto thresh = thresholds[i];
fvm_index_type spike_idx = 0;
......@@ -75,7 +75,7 @@ void test_thresholds_impl(
is_crossed[i]=0;
}
prev_values[i] = v;
prev_values[cv] = v;
}
if (crossed) {
......
#pragma once
#include <arbor/arbexcept.hpp>
#include <arbor/assert.hpp>
#include <arbor/common_types.hpp>
#include <arbor/fvm_types.hpp>
......@@ -45,29 +46,28 @@ public:
threshold_watcher(
const fvm_index_type* cv_to_intdom,
const fvm_value_type* values,
const fvm_index_type* src_to_spike,
const array* t_before,
const array* t_after,
const std::vector<fvm_index_type>& cv_index,
const fvm_size_type num_cv,
const std::vector<fvm_index_type>& detector_cv_idx,
const std::vector<fvm_value_type>& thresholds,
const execution_context& ctx
):
cv_to_intdom_(cv_to_intdom),
values_(values),
src_to_spike_(src_to_spike),
t_before_ptr_(t_before),
t_after_ptr_(t_after),
cv_index_(memory::make_const_view(cv_index)),
is_crossed_(cv_index.size()),
cv_index_(memory::make_const_view(detector_cv_idx)),
is_crossed_(detector_cv_idx.size()),
thresholds_(memory::make_const_view(thresholds)),
v_prev_(memory::const_host_view<fvm_value_type>(values, cv_index.size())),
v_prev_(num_cv),
// TODO: allocates enough space for 10 spikes per watch.
// A more robust approach might be needed to avoid overflows.
stack_(10*size(), ctx.gpu)
{
crossings_.reserve(stack_.capacity());
reset();
// reset() needs to be called before this is ready for use
}
/// Remove all stored crossings that were detected in previous calls to test()
......@@ -79,7 +79,9 @@ public:
/// Reset state machine for each detector.
/// Assume that the values in values_ have been set correctly before
/// calling, because the values are used to determine the initial state
void reset() {
void reset(const array& values) {
values_ = values.data();
memory::copy(values, v_prev_);
clear_crossings();
if (size()>0) {
reset_crossed_impl((int)size(), is_crossed_.data(), cv_index_.data(), values_, thresholds_.data());
......@@ -108,6 +110,8 @@ public:
/// crossed since current time t, and the last time the test was
/// performed.
void test(array* time_since_spike) {
arb_assert(values_);
if (size()>0) {
test_thresholds_impl(
(int)size(),
......
......@@ -51,17 +51,17 @@ struct backend {
static threshold_watcher voltage_watcher(
shared_state& state,
const std::vector<index_type>& cv,
const std::vector<index_type>& detector_cv,
const std::vector<value_type>& thresholds,
const execution_context& context)
{
return threshold_watcher(
state.cv_to_intdom.data(),
state.voltage.data(),
state.src_to_spike.data(),
&state.time,
&state.time_to,
cv,
state.voltage.size(),
detector_cv,
thresholds,
context);
}
......
......@@ -19,16 +19,15 @@ public:
threshold_watcher(
const fvm_index_type* cv_to_intdom,
const fvm_value_type* values,
const fvm_index_type* src_to_spike,
const array* t_before,
const array* t_after,
const fvm_size_type num_cv,
const std::vector<fvm_index_type>& cv_index,
const std::vector<fvm_value_type>& thresholds,
const execution_context& context
):
cv_to_intdom_(cv_to_intdom),
values_(values),
src_to_spike_(src_to_spike),
t_before_ptr_(t_before),
t_after_ptr_(t_after),
......@@ -36,10 +35,10 @@ public:
cv_index_(cv_index),
is_crossed_(n_cv_),
thresholds_(thresholds),
v_prev_(values_, values_+n_cv_)
v_prev_(num_cv)
{
arb_assert(n_cv_==thresholds.size());
reset();
// reset() needs to be called before this is ready for use
}
/// Remove all stored crossings that were detected in previous calls
......@@ -51,7 +50,9 @@ public:
/// Reset state machine for each detector.
/// Assume that the values in values_ have been set correctly before
/// calling, because the values are used to determine the initial state
void reset() {
void reset(const array& values) {
values_ = values.data();
std::copy(values.begin(), values.end(), v_prev_.begin());
clear_crossings();
for (fvm_size_type i = 0; i<n_cv_; ++i) {
is_crossed_[i] = values_[cv_index_[i]]>=thresholds_[i];
......@@ -66,13 +67,15 @@ public:
/// Crossing events are recorded for each threshold that
/// is crossed since the last call to test
void test(array* time_since_spike) {
arb_assert(values_);
// Reset all spike times to -1.0 indicating no spike has been recorded on the detector
const fvm_value_type* t_before = t_before_ptr_->data();
const fvm_value_type* t_after = t_after_ptr_->data();
for (fvm_size_type i = 0; i<n_cv_; ++i) {
auto cv = cv_index_[i];
auto intdom = cv_to_intdom_[cv];
auto v_prev = v_prev_[i];
auto v_prev = v_prev_[cv];
auto v = values_[cv];
auto thresh = thresholds_[i];
fvm_index_type spike_idx = 0;
......@@ -103,7 +106,7 @@ public:
}
}
v_prev_[i] = v;
v_prev_[cv] = v;
}
}
......
......@@ -185,7 +185,7 @@ void fvm_lowered_cell_impl<Backend>::reset() {
// NOTE: Threshold watcher reset must come after the voltage values are set,
// as voltage is implicitly read by watcher to set initial state.
threshold_watcher_.reset();
threshold_watcher_.reset(state_->voltage);
}
template <typename Backend>
......
......@@ -76,8 +76,10 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) {
list expected;
// create the watch
backend::threshold_watcher watch(cell_index.data(), values.data(), src_to_spike.data(),
&time_before, &time_after, index, thresh, context);
backend::threshold_watcher watch(cell_index.data(), src_to_spike.data(),
&time_before, &time_after,
values.size(), index, thresh, context);
watch.reset(values);
// initially the first and third watch should not be spiking
// the second is spiking
......@@ -197,7 +199,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) {
memory::fill(values, 0);
values[index[0]] = 10.; // first watch should be intialized to spiking state
memory::fill(time_before, 0.);
watch.reset();
watch.reset(values);
EXPECT_EQ(watch.crossings().size(), 0u);
EXPECT_TRUE(watch.is_crossed(0));
EXPECT_FALSE(watch.is_crossed(1));
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment