From 1d3ad8dbbfb1bdf6fd2149a29475fb726ebf90d1 Mon Sep 17 00:00:00 2001
From: Nora Abi Akar <nora.abiakar@gmail.com>
Date: Thu, 5 May 2022 08:56:49 +0200
Subject: [PATCH] Bug fix: Fix voltage vector size in threshold_watcher
 contstructor (#1820)

Fix a crash in GPU code when threshold detector is used for the first time and its internal state is allocated with a wrong size.
---
 arbor/backends/gpu/fvm.hpp                    |  6 +++---
 arbor/backends/gpu/threshold_watcher.cu       |  4 ++--
 arbor/backends/gpu/threshold_watcher.hpp      | 20 +++++++++++--------
 arbor/backends/multicore/fvm.hpp              |  6 +++---
 .../backends/multicore/threshold_watcher.hpp  | 17 +++++++++-------
 arbor/fvm_lowered_cell_impl.hpp               |  2 +-
 test/unit/test_spikes.cpp                     |  8 +++++---
 7 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/arbor/backends/gpu/fvm.hpp b/arbor/backends/gpu/fvm.hpp
index ef025bbf..712b2fc7 100644
--- a/arbor/backends/gpu/fvm.hpp
+++ b/arbor/backends/gpu/fvm.hpp
@@ -53,17 +53,17 @@ struct backend {
 
     static threshold_watcher voltage_watcher(
         shared_state& state,
-        const std::vector<index_type>& cv,
+        const std::vector<index_type>& detector_cv,
         const std::vector<value_type>& thresholds,
         const execution_context& context)
     {
         return threshold_watcher(
             state.cv_to_intdom.data(),
-            state.voltage.data(),
             state.src_to_spike.data(),
             &state.time,
             &state.time_to,
-            cv,
+            state.voltage.size(),
+            detector_cv,
             thresholds,
             context);
     }
diff --git a/arbor/backends/gpu/threshold_watcher.cu b/arbor/backends/gpu/threshold_watcher.cu
index 6748a482..bf048243 100644
--- a/arbor/backends/gpu/threshold_watcher.cu
+++ b/arbor/backends/gpu/threshold_watcher.cu
@@ -46,7 +46,7 @@ void test_thresholds_impl(
         // Test for threshold crossing
         const auto cv     = cv_index[i];
         const auto intdom = cv_to_intdom[cv];
-        const auto v_prev = prev_values[i];
+        const auto v_prev = prev_values[cv];
         const auto v      = values[cv];
         const auto thresh = thresholds[i];
         fvm_index_type spike_idx = 0;
@@ -75,7 +75,7 @@ void test_thresholds_impl(
             is_crossed[i]=0;
         }
 
-        prev_values[i] = v;
+        prev_values[cv] = v;
     }
 
     if (crossed) {
diff --git a/arbor/backends/gpu/threshold_watcher.hpp b/arbor/backends/gpu/threshold_watcher.hpp
index 6d371faa..2c391281 100644
--- a/arbor/backends/gpu/threshold_watcher.hpp
+++ b/arbor/backends/gpu/threshold_watcher.hpp
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <arbor/arbexcept.hpp>
+#include <arbor/assert.hpp>
 #include <arbor/common_types.hpp>
 #include <arbor/fvm_types.hpp>
 
@@ -45,29 +46,28 @@ public:
 
     threshold_watcher(
         const fvm_index_type* cv_to_intdom,
-        const fvm_value_type* values,
         const fvm_index_type* src_to_spike,
         const array* t_before,
         const array* t_after,
-        const std::vector<fvm_index_type>& cv_index,
+        const fvm_size_type num_cv,
+        const std::vector<fvm_index_type>& detector_cv_idx,
         const std::vector<fvm_value_type>& thresholds,
         const execution_context& ctx
     ):
         cv_to_intdom_(cv_to_intdom),
-        values_(values),
         src_to_spike_(src_to_spike),
         t_before_ptr_(t_before),
         t_after_ptr_(t_after),
-        cv_index_(memory::make_const_view(cv_index)),
-        is_crossed_(cv_index.size()),
+        cv_index_(memory::make_const_view(detector_cv_idx)),
+        is_crossed_(detector_cv_idx.size()),
         thresholds_(memory::make_const_view(thresholds)),
-        v_prev_(memory::const_host_view<fvm_value_type>(values, cv_index.size())),
+        v_prev_(num_cv),
         // TODO: allocates enough space for 10 spikes per watch.
         // A more robust approach might be needed to avoid overflows.
         stack_(10*size(), ctx.gpu)
     {
         crossings_.reserve(stack_.capacity());
-        reset();
+        // reset() needs to be called before this is ready for use
     }
 
     /// Remove all stored crossings that were detected in previous calls to test()
@@ -79,7 +79,9 @@ public:
     /// Reset state machine for each detector.
     /// Assume that the values in values_ have been set correctly before
     /// calling, because the values are used to determine the initial state
-    void reset() {
+    void reset(const array& values) {
+        values_ = values.data();
+        memory::copy(values, v_prev_);
         clear_crossings();
         if (size()>0) {
             reset_crossed_impl((int)size(), is_crossed_.data(), cv_index_.data(), values_, thresholds_.data());
@@ -108,6 +110,8 @@ public:
     /// crossed since current time t, and the last time the test was
     /// performed.
     void test(array* time_since_spike) {
+        arb_assert(values_);
+
         if (size()>0) {
             test_thresholds_impl(
                 (int)size(),
diff --git a/arbor/backends/multicore/fvm.hpp b/arbor/backends/multicore/fvm.hpp
index 6e21311d..a9064357 100644
--- a/arbor/backends/multicore/fvm.hpp
+++ b/arbor/backends/multicore/fvm.hpp
@@ -51,17 +51,17 @@ struct backend {
 
     static threshold_watcher voltage_watcher(
         shared_state& state,
-        const std::vector<index_type>& cv,
+        const std::vector<index_type>& detector_cv,
         const std::vector<value_type>& thresholds,
         const execution_context& context)
     {
         return threshold_watcher(
             state.cv_to_intdom.data(),
-            state.voltage.data(),
             state.src_to_spike.data(),
             &state.time,
             &state.time_to,
-            cv,
+            state.voltage.size(),
+            detector_cv,
             thresholds,
             context);
     }
diff --git a/arbor/backends/multicore/threshold_watcher.hpp b/arbor/backends/multicore/threshold_watcher.hpp
index 9b2d5e53..76a260d3 100644
--- a/arbor/backends/multicore/threshold_watcher.hpp
+++ b/arbor/backends/multicore/threshold_watcher.hpp
@@ -19,16 +19,15 @@ public:
 
     threshold_watcher(
         const fvm_index_type* cv_to_intdom,
-        const fvm_value_type* values,
         const fvm_index_type* src_to_spike,
         const array* t_before,
         const array* t_after,
+        const fvm_size_type num_cv,
         const std::vector<fvm_index_type>& cv_index,
         const std::vector<fvm_value_type>& thresholds,
         const execution_context& context
     ):
         cv_to_intdom_(cv_to_intdom),
-        values_(values),
         src_to_spike_(src_to_spike),
         t_before_ptr_(t_before),
         t_after_ptr_(t_after),
@@ -36,10 +35,10 @@ public:
         cv_index_(cv_index),
         is_crossed_(n_cv_),
         thresholds_(thresholds),
-        v_prev_(values_, values_+n_cv_)
+        v_prev_(num_cv)
     {
         arb_assert(n_cv_==thresholds.size());
-        reset();
+        // reset() needs to be called before this is ready for use
     }
 
     /// Remove all stored crossings that were detected in previous calls
@@ -51,7 +50,9 @@ public:
     /// Reset state machine for each detector.
     /// Assume that the values in values_ have been set correctly before
     /// calling, because the values are used to determine the initial state
-    void reset() {
+    void reset(const array& values) {
+        values_ = values.data();
+        std::copy(values.begin(), values.end(), v_prev_.begin());
         clear_crossings();
         for (fvm_size_type i = 0; i<n_cv_; ++i) {
             is_crossed_[i] = values_[cv_index_[i]]>=thresholds_[i];
@@ -66,13 +67,15 @@ public:
     /// Crossing events are recorded for each threshold that
     /// is crossed since the last call to test
     void test(array* time_since_spike) {
+        arb_assert(values_);
+
         // Reset all spike times to -1.0 indicating no spike has been recorded on the detector
         const fvm_value_type* t_before = t_before_ptr_->data();
         const fvm_value_type* t_after  = t_after_ptr_->data();
         for (fvm_size_type i = 0; i<n_cv_; ++i) {
             auto cv     = cv_index_[i];
             auto intdom = cv_to_intdom_[cv];
-            auto v_prev = v_prev_[i];
+            auto v_prev = v_prev_[cv];
             auto v      = values_[cv];
             auto thresh = thresholds_[i];
             fvm_index_type spike_idx = 0;
@@ -103,7 +106,7 @@ public:
                 }
             }
 
-            v_prev_[i] = v;
+            v_prev_[cv] = v;
         }
     }
 
diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp
index ddfe104e..fcd7c4ee 100644
--- a/arbor/fvm_lowered_cell_impl.hpp
+++ b/arbor/fvm_lowered_cell_impl.hpp
@@ -185,7 +185,7 @@ void fvm_lowered_cell_impl<Backend>::reset() {
 
     // NOTE: Threshold watcher reset must come after the voltage values are set,
     // as voltage is implicitly read by watcher to set initial state.
-    threshold_watcher_.reset();
+    threshold_watcher_.reset(state_->voltage);
 }
 
 template <typename Backend>
diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp
index 882de042..12a53234 100644
--- a/test/unit/test_spikes.cpp
+++ b/test/unit/test_spikes.cpp
@@ -76,8 +76,10 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) {
     list expected;
 
     // create the watch
-    backend::threshold_watcher watch(cell_index.data(), values.data(), src_to_spike.data(),
-                                     &time_before, &time_after, index, thresh, context);
+    backend::threshold_watcher watch(cell_index.data(), src_to_spike.data(),
+                                     &time_before, &time_after, 
+                                     values.size(), index, thresh, context);
+    watch.reset(values);
 
     // initially the first and third watch should not be spiking
     //           the second is spiking
@@ -197,7 +199,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) {
     memory::fill(values, 0);
     values[index[0]] = 10.; // first watch should be intialized to spiking state
     memory::fill(time_before, 0.);
-    watch.reset();
+    watch.reset(values);
     EXPECT_EQ(watch.crossings().size(), 0u);
     EXPECT_TRUE(watch.is_crossed(0));
     EXPECT_FALSE(watch.is_crossed(1));
-- 
GitLab