diff --git a/CMakeLists.txt b/CMakeLists.txt index d181f424c8e58319da6d33471fd6020d33527df0..fb6800b491e2425afda9e1d52ca8a24c96212f68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,11 @@ cmake_minimum_required(VERSION 2.8) project(cell_algorithms) enable_language(CXX) +# Hide warnings about mixing old and new signatures for target_link_libraries. +# These can't be avoided, because the FindCUDA packed provided by CMake uses +# the old signature, while other packages use the new signature. +cmake_policy(SET CMP0023 OLD) + # save incoming CXX flags for forwarding to modcc external project set(SAVED_CXX_FLAGS "${CMAKE_CXX_FLAGS}") @@ -221,27 +226,29 @@ set(NMC_VALIDATION_DATA_DIR "${PROJECT_SOURCE_DIR}/validation/data" CACHE PATH "location of generated validation data") #---------------------------------------------------------- -# Whether to build validation data at all +# Whether to build validation data #---------------------------------------------------------- -option(NMC_BUILD_VALIDATION_DATA "generate validation data" ON) - -# Whether to attempt to use julia to build validation data -find_program(JULIA_BIN julia) -if(JULIA_BIN STREQUAL "JULIA_BIN-NOTFOUND") - message(STATUS "julia not found; will not automatically build validation data sets from julia scripts") - set(NMC_BUILD_JULIA_VALIDATION_DATA FALSE) -else() - set(NMC_BUILD_JULIA_VALIDATION_DATA TRUE) -endif() +# turn off by default +option(NMC_BUILD_VALIDATION_DATA "generate validation data" OFF) +if (NMC_BUILD_VALIDATION_DATA) + # Whether to attempt to use julia to build validation data + find_program(JULIA_BIN julia) + if(JULIA_BIN STREQUAL "JULIA_BIN-NOTFOUND") + message(STATUS "julia not found; will not automatically build validation data sets from julia scripts") + set(NMC_BUILD_JULIA_VALIDATION_DATA FALSE) + else() + set(NMC_BUILD_JULIA_VALIDATION_DATA TRUE) + endif() -# Whether to attempt to use nrniv to build validation data -# (if we find nrniv, do) -find_program(NRNIV_BIN nrniv) -if(NRNIV_BIN STREQUAL "NRNIV_BIN-NOTFOUND") - message(STATUS "nrniv not found; will not automatically build NEURON validation data sets") - set(NMC_BUILD_NRN_VALIDATION_DATA FALSE) -else() - set(NMC_BUILD_NRN_VALIDATION_DATA TRUE) + # Whether to attempt to use nrniv to build validation data + # (if we find nrniv, do) + find_program(NRNIV_BIN nrniv) + if(NRNIV_BIN STREQUAL "NRNIV_BIN-NOTFOUND") + message(STATUS "nrniv not found; will not automatically build NEURON validation data sets") + set(NMC_BUILD_NRN_VALIDATION_DATA FALSE) + else() + set(NMC_BUILD_NRN_VALIDATION_DATA TRUE) + endif() endif() #---------------------------------------------------------- @@ -276,4 +283,3 @@ add_subdirectory(src) add_subdirectory(tests) add_subdirectory(miniapp) add_subdirectory(lmorpho) - diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp index 67697d2ff9450b41b60a726fa25446f76f518e41..8d520c9fee1f13e7381842cbf579380d7dee9748 100644 --- a/miniapp/miniapp.cpp +++ b/miniapp/miniapp.cpp @@ -89,8 +89,6 @@ int main(int argc, char** argv) { EXPECTS(group_divisions.front() == cell_range.first); EXPECTS(group_divisions.back() == cell_range.second); - model_type m(*recipe, util::partition_view(group_divisions)); - auto register_exporter = [] (const io::cl_options& options) { return util::make_unique<file_export_type>( @@ -98,24 +96,7 @@ int main(int argc, char** argv) { options.file_extension, options.over_write); }; - // File output depends on the input arguments - std::unique_ptr<file_export_type> file_exporter; - if (options.spike_file_output) { - if (options.single_file_per_rank) { - file_exporter = register_exporter(options); - m.set_local_spike_callback( - [&](const std::vector<spike_type>& spikes) { - file_exporter->output(spikes); - }); - } - else if(communication::global_policy::id()==0) { - file_exporter = register_exporter(options); - m.set_global_spike_callback( - [&](const std::vector<spike_type>& spikes) { - file_exporter->output(spikes); - }); - } - } + model_type m(*recipe, util::partition_view(group_divisions)); // inject some artificial spikes, 1 per 20 neurons. std::vector<cell_gid_type> local_sources; @@ -137,15 +118,36 @@ int main(int argc, char** argv) { m.attach_sampler(probe.id, make_trace_sampler(traces.back().get(), sample_dt)); } +#ifdef WITH_PROFILING // dummy run of the model for one step to ensure that profiling is consistent m.run(options.dt, options.dt); - - // reset the model + // reset and add the source spikes once again m.reset(); - // reset the source spikes for (auto source : local_sources) { m.add_artificial_spike({source, 0}); } +#endif + + // Initialize the spike exporting interface after the profiler dummy + // steps, to avoid having the initial seed spikes that are artificially + // injected at t=0 from being recorded and output twice. + std::unique_ptr<file_export_type> file_exporter; + if (options.spike_file_output) { + if (options.single_file_per_rank) { + file_exporter = register_exporter(options); + m.set_local_spike_callback( + [&](const std::vector<spike_type>& spikes) { + file_exporter->output(spikes); + }); + } + else if(communication::global_policy::id()==0) { + file_exporter = register_exporter(options); + m.set_global_spike_callback( + [&](const std::vector<spike_type>& spikes) { + file_exporter->output(spikes); + }); + } + } // run model m.run(options.tfinal, options.dt); diff --git a/src/backends/fvm_gpu.hpp b/src/backends/fvm_gpu.hpp index 6572a7c6b113ac22e9638178240a942631e7dcd5..5e3e7a14f128c4ca272b2aabaa83cb255efcafd6 100644 --- a/src/backends/fvm_gpu.hpp +++ b/src/backends/fvm_gpu.hpp @@ -6,9 +6,11 @@ #include <common_types.hpp> #include <mechanism.hpp> #include <memory/memory.hpp> +#include <memory/managed_ptr.hpp> #include <util/span.hpp> #include "stimulus_gpu.hpp" +#include "gpu_stack.hpp" namespace nest { namespace mc { @@ -51,6 +53,24 @@ __global__ void matrix_solve(matrix_solve_param_pack<T, I> params); template <typename T, typename I> __global__ void assemble_matrix(matrix_update_param_pack<T, I> params, T dt); +/// kernel used to test for threshold crossing test code. +/// params: +/// t : current time (ms) +/// t_prev : time of last test (ms) +/// size : number of values to test +/// is_crossed : crossing state at time t_prev (true or false) +/// prev_values : values at sample points (see index) sampled at t_prev +/// index : index with locations in values to test for crossing +/// values : values at t_prev +/// thresholds : threshold values to watch for crossings +template <typename T, typename I, typename Stack> +__global__ +void test_thresholds( + float t, float t_prev, int size, + Stack& stack, + I* is_crossed, T* prev_values, + const I* index, const T* values, const T* thresholds); + struct backend { /// define the real and index types using value_type = double; @@ -98,6 +118,7 @@ struct backend { { auto n = d.size(); host_array invariant_d_tmp(n, 0); + // make a copy of the conductance on the host host_array face_conductance_tmp = face_conductance; for(auto i: util::make_span(1u, n)) { @@ -174,6 +195,129 @@ struct backend { static bool has_mechanism(const std::string& name) { return mech_map_.count(name)>0; } + /// threshold crossing logic + /// used as part of spike detection back end + class threshold_watcher { + public: + /// stores a single crossing event + struct threshold_crossing { + size_type index; // index of variable + value_type time; // time of crossing + __host__ __device__ + friend bool operator== + (const threshold_crossing& lhs, const threshold_crossing& rhs) + { + return lhs.index==rhs.index && lhs.time==rhs.time; + } + }; + + using stack_type = gpu_stack<threshold_crossing>; + + threshold_watcher() = default; + + threshold_watcher( + const_view values, + const std::vector<size_type>& index, + const std::vector<value_type>& thresh, + value_type t=0): + values_(values), + index_(memory::make_const_view(index)), + thresholds_(memory::make_const_view(thresh)), + prev_values_(values), + is_crossed_(size()), + stack_(memory::make_managed_ptr<stack_type>(10*size())) + { + reset(t); + } + + /// Remove all stored crossings that were detected in previous calls + /// to test() + void clear_crossings() { + stack_->clear(); + } + + /// Reset state machine for each detector. + /// Assume that the values in values_ have been set correctly before + /// calling, because the values are used to determine the initial state + void reset(value_type t=0) { + clear_crossings(); + + // Make host-side copies of the information needed to calculate + // the initial crossed state + auto values = memory::on_host(values_); + auto thresholds = memory::on_host(thresholds_); + auto index = memory::on_host(index_); + + // calculate the initial crossed state in host memory + auto crossed = std::vector<size_type>(size()); + for (auto i: util::make_span(0u, size())) { + crossed[i] = values[index[i]] < thresholds[i] ? 0 : 1; + } + + // copy the initial crossed state to device memory + is_crossed_ = memory::on_gpu(crossed); + + // reset time of last test + t_prev_ = t; + } + + bool is_crossed(size_type i) const { + return is_crossed_[i]; + } + + const std::vector<threshold_crossing> crossings() const { + return std::vector<threshold_crossing>(stack_->begin(), stack_->end()); + } + + /// The time at which the last test was performed + value_type last_test_time() const { + return t_prev_; + } + + /// Tests each target for changed threshold state. + /// Crossing events are recorded for each threshold that has been + /// crossed since current time t, and the last time the test was + /// performed. + void test(value_type t) { + EXPECTS(t_prev_<t); + + constexpr int block_dim = 128; + const int grid_dim = (size()+block_dim-1)/block_dim; + test_thresholds<<<grid_dim, block_dim>>>( + t, t_prev_, size(), + *stack_, + is_crossed_.data(), prev_values_.data(), + index_.data(), values_.data(), thresholds_.data()); + + // Check that the number of spikes has not exceeded + // the capacity of the stack. + EXPECTS(stack_->size() <= stack_->capacity()); + + t_prev_ = t; + } + + /// the number of threashold values that are being monitored + std::size_t size() const { + return index_.size(); + } + + /// Data type used to store the crossings. + /// Provided to make type-generic calling code. + using crossing_list = std::vector<threshold_crossing>; + + private: + + const_view values_; // values to watch: on gpu + iarray index_; // indexes of values to watch: on gpu + + array thresholds_; // threshold for each watch: on gpu + value_type t_prev_; // time of previous sample: on host + array prev_values_; // values at previous sample time: on host + iarray is_crossed_; // bool flag for state of each watch: on gpu + + memory::managed_ptr<stack_type> stack_; + }; + private: using maker_type = mechanism (*)(view, view, array&&, iarray&&); @@ -187,7 +331,7 @@ private: }; /// GPU implementation of Hines Matrix solver. -/// Naiive implementation with one CUDA thread per matrix. +/// Naive implementation with one CUDA thread per matrix. template <typename T, typename I> __global__ void matrix_solve(matrix_solve_param_pack<T, I> params) { @@ -240,6 +384,48 @@ void assemble_matrix(matrix_update_param_pack<T, I> params, T dt) { } } +template <typename T, typename I, typename Stack> +__global__ +void test_thresholds( + float t, float t_prev, int size, + Stack& stack, + I* is_crossed, T* prev_values, + const I* index, const T* values, const T* thresholds) +{ + int i = threadIdx.x + blockIdx.x*blockDim.x; + + bool crossed = false; + float crossing_time; + + if (i<size) { + // Test for threshold crossing + const auto v_prev = prev_values[i]; + const auto v = values[index[i]]; + const auto thresh = thresholds[i]; + + if (!is_crossed[i]) { + if (v>=thresh) { + // The threshold has been passed, so estimate the time using + // linear interpolation + auto pos = (thresh - v_prev)/(v - v_prev); + crossing_time = t_prev + pos*(t - t_prev); + + is_crossed[i] = 1; + crossed = true; + } + } + else if (v<thresh) { + is_crossed[i]=0; + } + + prev_values[i] = v; + } + + if (crossed) { + stack.push_back({I(i), crossing_time}); + } +} + } // namespace multicore } // namespace mc } // namespace nest diff --git a/src/backends/fvm_multicore.hpp b/src/backends/fvm_multicore.hpp index a82229e7e2078e0d921d04c2d8f0e5091ba5eef4..8ac525614b05413e63121f50ce873b1d6496c7ba 100644 --- a/src/backends/fvm_multicore.hpp +++ b/src/backends/fvm_multicore.hpp @@ -5,6 +5,7 @@ #include <common_types.hpp> #include <mechanism.hpp> #include <memory/memory.hpp> +#include <memory/wrappers.hpp> #include <util/span.hpp> #include "stimulus_multicore.hpp" @@ -34,6 +35,7 @@ struct backend { using host_view = view; using host_iview = iview; + /// hines matrix solver static void hines_solve( view d, view u, view rhs, const_iview p, const_iview cell_index) @@ -142,6 +144,118 @@ struct backend { return "cpu"; } + /// threshold crossing logic + /// used as part of spike detection back end + class threshold_watcher { + public: + /// stores a single crossing event + struct threshold_crossing { + size_type index; // index of variable + value_type time; // time of crossing + friend bool operator== ( + const threshold_crossing& lhs, const threshold_crossing& rhs) + { + return lhs.index==rhs.index && lhs.time==rhs.time; + } + }; + + threshold_watcher() = default; + + threshold_watcher( + const_view vals, + const std::vector<size_type>& indxs, + const std::vector<value_type>& thresh, + value_type t=0): + values_(vals), + index_(memory::make_const_view(indxs)), + thresholds_(memory::make_const_view(thresh)), + v_prev_(vals) + { + is_crossed_ = iarray(size()); + reset(t); + } + + /// Remove all stored crossings that were detected in previous calls + /// to the test() member function. + void clear_crossings() { + crossings_.clear(); + } + + /// Reset state machine for each detector. + /// Assume that the values in values_ have been set correctly before + /// calling, because the values are used to determine the initial state + void reset(value_type t=0) { + clear_crossings(); + for (auto i=0u; i<size(); ++i) { + is_crossed_[i] = values_[index_[i]]>=thresholds_[i]; + } + t_prev_ = t; + } + + const std::vector<threshold_crossing>& crossings() const { + return crossings_; + } + + /// The time at which the last test was performed + value_type last_test_time() const { + return t_prev_; + } + + /// Tests each target for changed threshold state + /// Crossing events are recorded for each threshold that + /// is crossed since the last call to test + void test(value_type t) { + for (auto i=0u; i<size(); ++i) { + auto v_prev = v_prev_[i]; + auto v = values_[index_[i]]; + auto thresh = thresholds_[i]; + if (!is_crossed_[i]) { + if (v>=thresh) { + // the threshold has been passed, so estimate the time using + // linear interpolation + auto pos = (thresh - v_prev)/(v - v_prev); + auto crossing_time = t_prev_ + pos*(t - t_prev_); + crossings_.push_back({i, crossing_time}); + + is_crossed_[i] = true; + } + } + else { + if (v<thresh) { + is_crossed_[i] = false; + } + } + + v_prev_[i] = v; + } + t_prev_ = t; + } + + bool is_crossed(size_type i) const { + return is_crossed_[i]; + } + + /// the number of threashold values that are being monitored + std::size_t size() const { + return index_.size(); + } + + /// Data type used to store the crossings. + /// Provided to make type-generic calling code. + using crossing_list = std::vector<threshold_crossing>; + + private: + const_view values_; + iarray index_; + + array thresholds_; + value_type t_prev_; + array v_prev_; + crossing_list crossings_; + iarray is_crossed_; + }; + + private: using maker_type = mechanism (*)(view, view, array&&, iarray&&); diff --git a/src/backends/gpu_stack.hpp b/src/backends/gpu_stack.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7767b9e535d22aa05d23b7c7bd18690028f95ca0 --- /dev/null +++ b/src/backends/gpu_stack.hpp @@ -0,0 +1,115 @@ +#pragma once + +#include <memory/allocator.hpp> + +namespace nest { +namespace mc { +namespace gpu { + +// A simple stack data structure for the GPU. +// +// Provides a host side interface for +// * construction and destrutcion +// * reading the values stored on the stack +// * resetting the stack to an empty state +// * querying the size and capacity of the stack +// Provides a device side interface for +// * cooperative grid level push_back +// * querying the size and capacity of the stack +// +// It is designed to be initialized empty with a given capacity on the host, +// updated by device kernels, and periodically read and reset from the host side. +template <typename T> +class gpu_stack { + using value_type = T; + using allocator = memory::managed_allocator<value_type>; + + // The number of items of type value_type that can be stored in the stack + unsigned capacity_; + + // The number of items that have been stored + unsigned size_; + + // Memory containing the value buffer + // Stored in managed memory to facilitate host-side access of values + // pushed from kernels on the device. + value_type* data_; + +public: + + gpu_stack(unsigned capacity): + capacity_(capacity), size_(0u) + { + data_ = allocator().allocate(capacity_); + } + + ~gpu_stack() { + allocator().deallocate(data_, capacity_); + } + + // Append a new value to the stack. + // The value will only be appended if do_push is true. + __device__ + void push_back(const value_type& value) { + // Atomically increment the size_ counter. The atomicAdd returns + // the value of size_ before the increment, which is the location + // at which this thread can store value. + unsigned position = atomicAdd(&size_, 1u); + + // It is possible that size_>capacity_. In this case, only capacity_ + // entries are stored, and additional values are lost. The size_ + // will contain the total number of attempts to push, + if (position<capacity_) { + data_[position] = value; + } + } + + __host__ + void clear() { + size_ = 0; + } + + // The number of items that have been pushed back on the stack. + // size may exceed capacity, which indicates that the caller attempted + // to push back more values than there was space to store. + __host__ __device__ + unsigned size() const { + return size_; + } + + // The maximum number of items that can be stored in the stack. + __host__ __device__ + unsigned capacity() const { + return capacity_; + } + + value_type& operator[](unsigned i) { + EXPECTS(i<size_ && i<capacity_); + return data_[i]; + } + + value_type& operator[](unsigned i) const { + EXPECTS(i<size_ && i<capacity_); + return data_[i]; + } + + value_type* begin() { + return data_; + } + const value_type* begin() const { + return data_; + } + + value_type* end() { + // Take care of the case where size_>capacity_. + return data_ + (size_>capacity_? capacity_: size_); + } + const value_type* end() const { + // Take care of the case where size_>capacity_. + return data_ + (size_>capacity_? capacity_: size_); + } +}; + +} // namespace gpu +} // namespace mc +} // namespace nest diff --git a/src/cell_group.hpp b/src/cell_group.hpp index 606acf586579b2e043ca164d1b60c86e9f7bf78a..807238a440862c2265cc624f57fad5ebb3e71eeb 100644 --- a/src/cell_group.hpp +++ b/src/cell_group.hpp @@ -10,7 +10,6 @@ #include <common_types.hpp> #include <event_queue.hpp> #include <spike.hpp> -#include <spike_source.hpp> #include <util/debug.hpp> #include <util/partition.hpp> #include <util/range.hpp> @@ -27,17 +26,11 @@ public: using lowered_cell_type = LoweredCell; using value_type = typename lowered_cell_type::value_type; using size_type = typename lowered_cell_type::value_type; - using spike_detector_type = spike_detector<lowered_cell_type>; using source_id_type = cell_member_type; using time_type = float; using sampler_function = std::function<util::optional<time_type>(time_type, double)>; - struct spike_source_type { - source_id_type source_id; - spike_detector_type source; - }; - cell_group() = default; template <typename Cells> @@ -48,40 +41,31 @@ public: build_handle_partitions(cells); std::size_t n_probes = probe_handle_divisions_.back(); std::size_t n_targets = target_handle_divisions_.back(); - std::size_t n_detectors = - algorithms::sum(util::transform_view(cells, [](const cell& c) { return c.detectors().size(); })); + std::size_t n_detectors = algorithms::sum(util::transform_view( + cells, [](const cell& c) { return c.detectors().size(); })); // Allocate space to store handles. - detector_handles_.resize(n_detectors); target_handles_.resize(n_targets); probe_handles_.resize(n_probes); - cell_.initialize(cells, detector_handles_, target_handles_, probe_handles_); + cell_.initialize(cells, target_handles_, probe_handles_); - // Create spike detectors and associate them with globally unique source ids. - cell_gid_type source_gid = gid_base_; - unsigned i = 0; + // Create a list of the global identifiers for the spike sources + auto source_gid = cell_gid_type{gid_base_}; for (const auto& cell: cells) { - cell_lid_type source_lid = 0u; - for (auto& d: cell.detectors()) { - cell_member_type source_id{source_gid, source_lid++}; - - spike_sources_.push_back({ - source_id, spike_detector_type(cell_, detector_handles_[i++], d.threshold, 0.f) - }); + for (cell_lid_type lid=0u; lid<cell.detectors().size(); ++lid) { + spike_sources_.push_back(source_id_type{source_gid, lid}); } ++source_gid; } + EXPECTS(spike_sources_.size()==n_detectors); } void reset() { - clear_spikes(); + spikes_.clear(); clear_events(); reset_samplers(); cell_.reset(); - for (auto& spike_source: spike_sources_) { - spike_source.source.reset(cell_, 0.f); - } } time_type min_step(time_type dt) { @@ -127,15 +111,8 @@ public: << gid_base_ << " at t " << cell_.time() << " ms\n"; } - PE("events"); - // check for new spikes - for (auto& s : spike_sources_) { - if (auto spike = s.source.test(cell_, cell_.time())) { - spikes_.push_back({s.source_id, spike.get()}); - } - } - // apply events + PE("events"); if (next) { auto handle = get_target_handle(next->target); cell_.deliver_event(handle, next->weight); @@ -143,6 +120,18 @@ public: PL(); } + // Copy out spike voltage threshold crossings from the back end, then + // generate spikes with global spike source ids. The threshold crossings + // record the local spike source index, which must be converted to a + // global index for spike communication. + PE("events"); + for (auto c: cell_.get_spikes()) { + spikes_.push_back({spike_sources_[c.index], time_type(c.time)}); + } + // Now that the spikes have been generated, clear the old crossings + // to get ready to record spikes from the next integration period. + cell_.clear_spikes(); + PL(); } template <typename R> @@ -153,17 +142,19 @@ public: } const std::vector<spike<source_id_type, time_type>>& - spikes() const { return spikes_; } - - const std::vector<spike_source_type>& - spike_sources() const { - return spike_sources_; + spikes() const { + return spikes_; } void clear_spikes() { spikes_.clear(); } + const std::vector<source_id_type>& + spike_sources() const { + return spike_sources_; + } + void clear_events() { events_.clear(); } @@ -203,7 +194,7 @@ private: lowered_cell_type cell_; /// spike detectors attached to the cell - std::vector<spike_source_type> spike_sources_; + std::vector<source_id_type> spike_sources_; /// spikes that are generated std::vector<spike<source_id_type, time_type>> spikes_; @@ -219,9 +210,6 @@ private: iarray first_target_gid_; /// handles for accessing lowered cell - using detector_handle = typename lowered_cell_type::detector_handle; - std::vector<detector_handle> detector_handles_; - using target_handle = typename lowered_cell_type::target_handle; std::vector<target_handle> target_handles_; diff --git a/src/fvm_multicell.hpp b/src/fvm_multicell.hpp index 4dc1966fd6e92733b7392698e758281fe61cc397..a629da3eeb66e3ccdb7607e717e3ebdac43d1db4 100644 --- a/src/fvm_multicell.hpp +++ b/src/fvm_multicell.hpp @@ -62,7 +62,6 @@ public: using matrix_assembler = typename backend::matrix_assembler; - using detector_handle = size_type; using target_handle = std::pair<size_type, size_type>; using probe_handle = std::pair<const array fvm_multicell::*, size_type>; @@ -72,10 +71,9 @@ public: resting_potential_ = potential_mV; } - template <typename Cells, typename Detectors, typename Targets, typename Probes> + template <typename Cells, typename Targets, typename Probes> void initialize( const Cells& cells, // collection of nest::mc::cell descriptions - Detectors& detector_handles, // (write) where to store detector handles Targets& target_handles, // (write) where to store target handles Probes& probe_handles); // (write) where to store probe handles @@ -85,10 +83,6 @@ public: mechanisms_[h.first]->net_receive(h.second, weight); } - value_type detector_voltage(detector_handle h) const { - return voltage_[h]; // detector_handle is just the compartment index - } - value_type probe(probe_handle h) const { return (this->*h.first)[h.second]; } @@ -184,7 +178,32 @@ public: std::size_t num_probes() const { return probes_.size(); } + // + // Threshold crossing interface. + // Used by calling code to perform spike detection + // + + /// types defined by the back end for threshold detection + using threshold_watcher = typename backend::threshold_watcher; + using crossing_list = typename backend::threshold_watcher::crossing_list; + + /// Forward the list of threshold crossings from the back end. + /// The list is passed by value, because we don't want the calling code + /// to depend on references to internal state of the solver, and because + /// for some backends the results might have to be collated before returning. + crossing_list get_spikes() const { + return threshold_watcher_.crossings(); + } + + /// clear all spikes: aka threshold crossings. + void clear_spikes() { + threshold_watcher_.clear_crossings(); + } + private: + + threshold_watcher threshold_watcher_; + /// current time [ms] value_type t_ = value_type{0}; @@ -399,10 +418,9 @@ fvm_multicell<Backend>::compute_cv_area_capacitance( } template <typename Backend> -template <typename Cells, typename Detectors, typename Targets, typename Probes> +template <typename Cells, typename Targets, typename Probes> void fvm_multicell<Backend>::initialize( const Cells& cells, - Detectors& detector_handles, Targets& target_handles, Probes& probe_handles) { @@ -415,11 +433,9 @@ void fvm_multicell<Backend>::initialize( using util::transform_view; using util::subrange_view; - // count total detectors, targets and probes for validation of handle container sizes - std::size_t detectors_count = 0u; + // count total targets and probes for validation of handle container sizes std::size_t targets_count = 0u; std::size_t probes_count = 0u; - auto detectors_size = size(detector_handles); auto targets_size = size(target_handles); auto probes_size = size(probe_handles); @@ -445,7 +461,6 @@ void fvm_multicell<Backend>::initialize( // create each cell: auto target_hi = target_handles.begin(); - auto detector_hi = detector_handles.begin(); auto probe_hi = probe_handles.begin(); // Allocate scratch storage for calculating quantities used to build the @@ -456,6 +471,10 @@ void fvm_multicell<Backend>::initialize( std::vector<value_type> tmp_cv_areas(ncomp, 0.); std::vector<value_type> tmp_cv_capacitance(ncomp, 0.); + // used to build the information required to construct spike detectors + std::vector<size_type> spike_detector_index; + std::vector<value_type> thresholds; + // Iterate over the input cells and build the indexes etc that descrbe the // fused cell group. On completion: // - group_paranet_index contains the full parent index for the fused cells. @@ -549,13 +568,11 @@ void fvm_multicell<Backend>::initialize( mechanisms_.push_back(mechanism(stim)); } - // detector handles are just their corresponding compartment indices + // calculate spike detector handles are their corresponding compartment indices for (const auto& detector: c.detectors()) { - EXPECTS(detectors_count < detectors_size); - auto comp = comp_ival.first+find_cv_index(detector.location, graph); - *detector_hi++ = comp; - ++detectors_count; + spike_detector_index.push_back(comp); + thresholds.push_back(detector.threshold); } // record probe locations by index into corresponding state vector @@ -577,9 +594,11 @@ void fvm_multicell<Backend>::initialize( } } - // confirm user-supplied containers for detectors and probes were - // appropriately sized. - EXPECTS(detectors_size==detectors_count); + // set a back-end supplied watcher on the voltage vector + threshold_watcher_ = + threshold_watcher(voltage_, spike_detector_index, thresholds, 0); + + // confirm user-supplied container probes were appropriately sized. EXPECTS(probes_size==probes_count); // store the geometric information in target-specific containers @@ -750,6 +769,11 @@ void fvm_multicell<Backend>::reset() { m->set_params(t_, 0.025); m->nrn_init(); } + + // Reset state of the threshold watcher. + // NOTE: this has to come after the voltage_ values have been reinitialized, + // because these values are used by the watchers to set their initial state. + threshold_watcher_.reset(t_); } template <typename Backend> @@ -786,6 +810,9 @@ void fvm_multicell<Backend>::advance(double dt) { PL(); t_ += dt; + + // update spike detector thresholds + threshold_watcher_.test(t_); } } // namespace fvm diff --git a/src/memory/allocator.hpp b/src/memory/allocator.hpp index dcba92670494d80ae413bbced4d929f9d2f6ad79..3755e293f3c951b6510a6c3ed84d77d11769a75b 100644 --- a/src/memory/allocator.hpp +++ b/src/memory/allocator.hpp @@ -181,6 +181,37 @@ namespace impl { } }; + // bare bones implementation of standard compliant allocator for managed memory + template <size_type Alignment=1024> + struct managed_policy { + // Managed memory is aligned on 1024 byte boundaries. + // So the Alignment parameter must be a factor of 1024 + static_assert(1024%Alignment==0, "CUDA managed memory is always aligned on 1024 byte boundaries"); + + void* allocate_policy(std::size_t n) { + void* ptr; + auto status = cudaMallocManaged(&ptr, n); + if(status != cudaSuccess) { + LOG_ERROR("memory:: unable to allocate managed memory"); + ptr = nullptr; + } + return ptr; + } + + static constexpr size_type alignment() { + return Alignment; + } + + // managed memory can be used with standard memcpy + static constexpr bool is_malloc_compatible() { + return true; + } + + void free_policy(void* p) { + cudaFree(p); + } + }; + class device_policy { public: void *allocate_policy(size_type size) { @@ -251,8 +282,9 @@ public: } void deallocate(pointer p, size_type) { - if( p!=nullptr ) + if( p!=nullptr ) { free_policy(p); + } } size_type max_size() const { @@ -326,15 +358,20 @@ using hbw_allocator = allocator<T, impl::knl::hbw_policy<alignment>>; #endif #ifdef NMC_HAVE_CUDA -// for pinned allocation set the default alignment to correspond to the -// alignment of a page (4096 bytes), because pinned memory is allocated at page -// boundaries. -template <class T, size_t alignment=4096> +// For pinned and allocation set the default alignment to correspond to +// the alignment of 1024 bytes, because pinned memory is allocated at +// page boundaries. It is allocated at page boundaries (typically 4k), +// however in practice it will return pointers that are 1k aligned. +template <class T, size_t alignment=1024> using pinned_allocator = allocator<T, impl::cuda::pinned_policy<alignment>>; -// use 256 as default allignment, because that is the default for cudaMalloc +template <class T, size_t alignment=1024> +using managed_allocator = allocator<T, impl::cuda::managed_policy<alignment>>; + +// use 256 as default alignment, because that is the default for cudaMalloc template <class T, size_t alignment=256> using cuda_allocator = allocator<T, impl::cuda::device_policy>; + #endif } // namespace memory diff --git a/src/memory/gpu.hpp b/src/memory/gpu.hpp index c0b7cef8ada18f25e6fea0e3eb51d8ce10686808..8e950d567217459186b7ebf23953104d6366eb3a 100644 --- a/src/memory/gpu.hpp +++ b/src/memory/gpu.hpp @@ -24,7 +24,7 @@ void fill32(uint32_t* v, uint32_t value, std::size_t n); void fill64(uint64_t* v, uint64_t value, std::size_t n); // -// helpers for memory where at least on of the target or source is on the gpu +// helpers for memory where at least one of the target or source is on the gpu // template <typename T> void memcpy_d2h(const T* from, T* to, std::size_t size) { diff --git a/src/memory/managed_ptr.hpp b/src/memory/managed_ptr.hpp new file mode 100644 index 0000000000000000000000000000000000000000..944424bb6f0960bf3a15f62d5935729e4747832b --- /dev/null +++ b/src/memory/managed_ptr.hpp @@ -0,0 +1,118 @@ +#pragma once + +#include <cuda.h> + +#include <memory/allocator.hpp> + +namespace nest { +namespace mc { +namespace memory { + +// used to indicate that the type pointed to by the managed_ptr is to be +// constructed in the managed_ptr constructor +struct construct_in_place_tag {}; + +// Like std::unique_ptr, but for CUDA managed memory. +// Handles memory allocation and freeing, and the construction and destruction +// of the type being stored in the allocated memory. +// Implemented as a stand alone type instead of as a std::unique_ptr with +// custom desctructor so that __device__ annotation can be added to members +// like ::get, ::operator*, etc., which enables the use of the smart pointer +// in device side code. +// +// It is very strongly recommended that the helper make_managed_ptr be used +// instead of directly constructing the managed_ptr. +template <typename T> +class managed_ptr { + public: + + using element_type = T; + using pointer = element_type*; + using reference = element_type&; + + managed_ptr() = default; + + managed_ptr(const managed_ptr& other) = delete; + + // Allocate memory and construct in place using args. + // This is an extension over the std::unique_ptr interface, because the + // point of the wrapper is to hide the complexity of allocating managed + // memory and constructing a type in place. + template <typename... Args> + managed_ptr(construct_in_place_tag, Args&&... args) { + managed_allocator<element_type> allocator; + data_ = allocator.allocate(1u); + synchronize(); + data_ = new (data_) element_type(std::forward<Args>(args)...); + } + + managed_ptr(managed_ptr&& other) { + std::swap(other.data_, data_); + } + + // pointer to the managed object + __host__ __device__ + pointer get() const { + return data_; + } + + // return a reference to the managed object + __host__ __device__ + reference operator *() const { + return *data_; + } + + // return a reference to the managed object + __host__ __device__ + pointer operator->() const { + return get(); + } + + managed_ptr& operator=(managed_ptr&& other) { + swap(std::move(other)); + return *this; + } + + ~managed_ptr() { + if (is_allocated()) { + managed_allocator<element_type> allocator; + synchronize(); // required to ensure that memory is not in use on GPU + data_->~element_type(); + allocator.deallocate(data_, 1u); + } + } + + void swap(managed_ptr&& other) { + std::swap(other.data_, data_); + } + + __host__ __device__ + operator bool() const { + return is_allocated(); + } + + void synchronize() const { + cudaDeviceSynchronize(); + } + + private: + + __host__ __device__ + bool is_allocated() const { + return data_!=nullptr; + } + + pointer data_ = nullptr; +}; + +// The correct way to construct a type in managed memory. +// Equivalent to std::make_unique_ptr for std::unique_ptr +template <typename T, typename... Args> +managed_ptr<T> make_managed_ptr(Args&&... args) { + return managed_ptr<T>(construct_in_place_tag(), std::forward<Args>(args)...); +} + +} // namespace memory +} // namespace mc +} // namespace nest + diff --git a/src/spike_source.hpp b/src/spike_source.hpp deleted file mode 100644 index 52bfccc920298a641308ed29a956cfb861095430..0000000000000000000000000000000000000000 --- a/src/spike_source.hpp +++ /dev/null @@ -1,81 +0,0 @@ -#pragma once - -#include <cell.hpp> -#include <util/optional.hpp> - -namespace nest { -namespace mc { - -// spike detector for a lowered cell -template <typename Cell> -class spike_detector { -public: - using cell_type = Cell; - - spike_detector( - const cell_type& cell, - typename Cell::detector_handle h, - double thresh, - float t_init - ): - handle_(h), - threshold_(thresh) - { - reset(cell, t_init); - } - - util::optional<float> test(const cell_type& cell, float t) { - util::optional<float> result = util::nothing; - auto v = cell.detector_voltage(handle_); - - // these if statements could be simplified, but I keep them like - // this to clearly reflect the finite state machine - if (!is_spiking_) { - if (v>=threshold_) { - // the threshold has been passed, so estimate the time using - // linear interpolation - auto pos = (threshold_ - previous_v_)/(v - previous_v_); - result = previous_t_ + pos*(t - previous_t_); - - is_spiking_ = true; - } - } - else { - if (v<threshold_) { - is_spiking_ = false; - } - } - - previous_v_ = v; - previous_t_ = t; - - return result; - } - - bool is_spiking() const { return is_spiking_; } - - float t() const { return previous_t_; } - - float v() const { return previous_v_; } - - void reset(const cell_type& cell, float t_init) { - previous_t_ = t_init; - previous_v_ = cell.detector_voltage(handle_); - is_spiking_ = previous_v_ >= threshold_; - } - -private: - // parameters/data - typename cell_type::detector_handle handle_; - double threshold_; - - // state - float previous_t_; - float previous_v_; - bool is_spiking_; -}; - - -} // namespace mc -} // namespace nest - diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 9a72b1b9cbc1993bc8d31a2c50f013fc14ea10dd..4f774670128bca75a559fdcbdeb4baedf03117ab 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -18,7 +18,9 @@ build_modules( set(TEST_CUDA_SOURCES test_cell_group.cu + test_gpu_stack.cu test_matrix.cu + test_spikes.cu test_vector.cu # unit test driver diff --git a/tests/unit/test_cell_group.cpp b/tests/unit/test_cell_group.cpp index 0e740f7a890adb5d7a5e1373a678f7c25d2b7c17..f20945ad0f639708ef3fd7742acd6aead554d154 100644 --- a/tests/unit/test_cell_group.cpp +++ b/tests/unit/test_cell_group.cpp @@ -54,13 +54,13 @@ TEST(cell_group, sources) const auto& sources = group.spike_sources(); for (unsigned i = 0; i<sources.size(); ++i) { - auto id = sources[i].source_id; + auto id = sources[i]; if (i==0) { EXPECT_EQ(id.gid, first_gid); EXPECT_EQ(id.index, 0u); } else { - auto prev = sources[i-1].source_id; + auto prev = sources[i-1]; EXPECT_GT(id, prev); EXPECT_EQ(id.index, id.gid==prev.gid? prev.index+1: 0u); } diff --git a/tests/unit/test_fvm_multi.cpp b/tests/unit/test_fvm_multi.cpp index a4d7e90243c51a05fa11b532ee1a6f11a3004484..8c08b29470c1c81d4a3351a3b7d770dedd360572 100644 --- a/tests/unit/test_fvm_multi.cpp +++ b/tests/unit/test_fvm_multi.cpp @@ -20,11 +20,10 @@ TEST(fvm_multi, cable) nest::mc::cell cell=make_cell_ball_and_3stick(); std::vector<fvm_cell::target_handle> targets; - std::vector<fvm_cell::detector_handle> detectors; std::vector<fvm_cell::probe_handle> probes; fvm_cell fvcell; - fvcell.initialize(util::singleton_view(cell), detectors, targets, probes); + fvcell.initialize(util::singleton_view(cell), targets, probes); auto& J = fvcell.jacobian(); @@ -64,11 +63,10 @@ TEST(fvm_multi, init) cell.segment(1)->set_compartments(10); std::vector<fvm_cell::target_handle> targets; - std::vector<fvm_cell::detector_handle> detectors; std::vector<fvm_cell::probe_handle> probes; fvm_cell fvcell; - fvcell.initialize(util::singleton_view(cell), detectors, targets, probes); + fvcell.initialize(util::singleton_view(cell), targets, probes); // This is naughty: removing const from the matrix reference, but is needed // to test the build_matrix() method below (which is only accessable @@ -126,11 +124,10 @@ TEST(fvm_multi, multi_init) cells[1].add_detector({0, 0}, 3.3); std::vector<fvm_cell::target_handle> targets(4); - std::vector<fvm_cell::detector_handle> detectors(1); std::vector<fvm_cell::probe_handle> probes; fvm_cell fvcell; - fvcell.initialize(cells, detectors, targets, probes); + fvcell.initialize(cells, targets, probes); auto& J = fvcell.jacobian(); EXPECT_EQ(J.size(), 5u+13u); @@ -188,11 +185,10 @@ TEST(fvm_multi, stimulus) // as during the stimulus windows. std::vector<fvm_cell::target_handle> targets; - std::vector<fvm_cell::detector_handle> detectors; std::vector<fvm_cell::probe_handle> probes; fvm_cell fvcell; - fvcell.initialize(singleton_view(cell), detectors, targets, probes); + fvcell.initialize(singleton_view(cell), targets, probes); auto ref = fvcell.find_mechanism("stimulus"); ASSERT_TRUE(ref) << "no stimuli retrieved from lowered fvm cell: expected 2"; @@ -275,11 +271,10 @@ TEST(fvm_multi, mechanism_indexes) // generate the lowered fvm cell std::vector<fvm_cell::target_handle> targets; - std::vector<fvm_cell::detector_handle> detectors; std::vector<fvm_cell::probe_handle> probes; fvm_cell fvcell; - fvcell.initialize(util::singleton_view(c), detectors, targets, probes); + fvcell.initialize(util::singleton_view(c), targets, probes); // make vectors with the expected CV indexes for each mechanism std::vector<unsigned> hh_index = {0u, 4u, 5u, 6u, 7u, 8u}; diff --git a/tests/unit/test_gpu_stack.cu b/tests/unit/test_gpu_stack.cu new file mode 100644 index 0000000000000000000000000000000000000000..38087a9570ad975f8a3028966a56a0b47842ac61 --- /dev/null +++ b/tests/unit/test_gpu_stack.cu @@ -0,0 +1,79 @@ +#include "../gtest.h" + +#include <backends/gpu_stack.hpp> +#include <memory/managed_ptr.hpp> + +using namespace nest::mc; + +TEST(gpu_stack, construction) { + using T = int; + + gpu::gpu_stack<T> s(10); + + EXPECT_EQ(0u, s.size()); + EXPECT_EQ(10u, s.capacity()); +} + +// kernel and functors for testing push_back functionality +namespace kernels { + template <typename F> + __global__ + void push_back(gpu::gpu_stack<int>& s, F f) { + if (f(threadIdx.x)) { + s.push_back(threadIdx.x); + } + } + + struct all_ftor { + __host__ __device__ + bool operator() (int i) { + return true; + } + }; + + struct even_ftor { + __host__ __device__ + bool operator() (int i) { + return (i%2)==0; + } + }; + + struct odd_ftor { + __host__ __device__ + bool operator() (int i) { + return i%2; + } + }; +} + +TEST(gpu_stack, push_back) { + using T = int; + using stack = gpu::gpu_stack<T>; + + const unsigned n = 10; + EXPECT_TRUE(n%2 == 0); // require n is even for tests to work + auto s = memory::make_managed_ptr<stack>(n); + + kernels::push_back<<<1, n>>>(*s, kernels::all_ftor()); + cudaDeviceSynchronize(); + EXPECT_EQ(n, s->size()); + for (auto i=0; i<int(s->size()); ++i) { + EXPECT_EQ(i, (*s)[i]); + } + + s->clear(); + kernels::push_back<<<1, n>>>(*s, kernels::even_ftor()); + cudaDeviceSynchronize(); + EXPECT_EQ(n/2, s->size()); + for (auto i=0; i<int(s->size())/2; ++i) { + EXPECT_EQ(2*i, (*s)[i]); + } + + s->clear(); + kernels::push_back<<<1, n>>>(*s, kernels::odd_ftor()); + cudaDeviceSynchronize(); + EXPECT_EQ(n/2, s->size()); + for (auto i=0; i<int(s->size())/2; ++i) { + EXPECT_EQ(2*i+1, (*s)[i]); + } +} diff --git a/tests/unit/test_probe.cpp b/tests/unit/test_probe.cpp index 5c77a72cdaa88282a36c930d081f29ffd3115950..7ebbefc7adfd0c81de952479cc94d8d2f03d2e88 100644 --- a/tests/unit/test_probe.cpp +++ b/tests/unit/test_probe.cpp @@ -59,11 +59,10 @@ TEST(probe, fvm_multicell) using fvm_multicell = fvm::fvm_multicell<nest::mc::multicore::backend>; std::vector<fvm_multicell::target_handle> targets; - std::vector<fvm_multicell::detector_handle> detectors; std::vector<fvm_multicell::probe_handle> probes{3}; fvm_multicell lcell; - lcell.initialize(util::singleton_view(bs), detectors, targets, probes); + lcell.initialize(util::singleton_view(bs), targets, probes); // Know from implementation that probe_handle.second // is a compartment index: expect probe values and diff --git a/tests/unit/test_spikes.cpp b/tests/unit/test_spikes.cpp index 6f8a21dec5743f0c2aee1cb504b298be4c601f00..2941698cc5b47cc2739e19f3eaca7f911001f0f1 100644 --- a/tests/unit/test_spikes.cpp +++ b/tests/unit/test_spikes.cpp @@ -1,69 +1,122 @@ #include "../gtest.h" #include <spike.hpp> -#include <spike_source.hpp> - -struct cell_proxy { - using detector_handle = int; - double detector_voltage(detector_handle) const { - return v; +#include <backends/fvm_multicore.hpp> + +using namespace nest::mc; + +TEST(spikes, threshold_watcher) { + using backend = multicore::backend; + using size_type = backend::size_type; + using value_type = backend::value_type; + using array = backend::array; + using list = backend::threshold_watcher::crossing_list; + + // the test creates a watch on 3 values in the array values (which has 10 + // elements in total). + const auto n = 10; + + const std::vector<size_type> index{0, 5, 7}; + const std::vector<value_type> thresh{1., 2., 3.}; + + // all values are initially 0, except for values[5] which we set + // to exceed the threshold of 2. for the second watch + array values(n, 0); + values[5] = 3.; + + // list for storing expected crossings for validation at the end + list expected; + + // create the watch + backend::threshold_watcher watch(values, index, thresh, 0.f); + + // initially the first and third watch should not be spiking + // the second is spiking + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_TRUE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); + + // test again at t=1, with unchanged values + // - nothing should change + watch.test(1.); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_TRUE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 0u); + + // test at t=2, with all values set to zero + // - 2nd watch should now stop spiking + memory::fill(values, 0.); + watch.test(2.); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 0u); + + // test at t=3, with all values set to 4. + // - all watches should now be spiking + memory::fill(values, 4.); + watch.test(3.); + EXPECT_TRUE(watch.is_crossed(0)); + EXPECT_TRUE(watch.is_crossed(1)); + EXPECT_TRUE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 3u); + + // record the expected spikes + expected.push_back({0u, 2.25f}); + expected.push_back({1u, 2.50f}); + expected.push_back({2u, 2.75f}); + + // test at t=4, with all values set to 0. + // - all watches should stop spiking + memory::fill(values, 0.); + watch.test(4.); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 3u); + + // test at t=5, with value on 3rd watch set to 6 + // - watch 3 should be spiking + values[index[2]] = 6.; + watch.test(5.); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_TRUE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 4u); + expected.push_back({2u, 4.5f}); + + // + // test that all generated spikes matched the expected values + // + if (expected.size()!=watch.crossings().size()) { + FAIL() << "count of recorded crosssings did not match expected count"; } - - double v = -65.; -}; - -TEST(spikes, spike_detector) -{ - using namespace nest::mc; - using detector_type = spike_detector<cell_proxy>; - using detector_handle = cell_proxy::detector_handle; - - cell_proxy proxy; - float threshold = 10.f; - float t = 0.f; - float dt = 1.f; - detector_handle handle{}; - - auto detector = detector_type(proxy, handle, threshold, t); - - EXPECT_FALSE(detector.is_spiking()); - EXPECT_EQ(proxy.v, detector.v()); - EXPECT_EQ(t, detector.t()); - - { - t += dt; - proxy.v = 0; - auto spike = detector.test(proxy, t); - EXPECT_FALSE(spike); - - EXPECT_FALSE(detector.is_spiking()); - EXPECT_EQ(proxy.v, detector.v()); - EXPECT_EQ(t, detector.t()); + auto const& spikes = watch.crossings(); + for (auto i=0u; i<expected.size(); ++i) { + EXPECT_EQ(expected[i], spikes[i]); } - { - t += dt; - proxy.v = 20; - auto spike = detector.test(proxy, t); - - EXPECT_TRUE(spike); - EXPECT_EQ(spike.get(), 1.5); - - EXPECT_TRUE(detector.is_spiking()); - EXPECT_EQ(proxy.v, detector.v()); - EXPECT_EQ(t, detector.t()); - } - - { - t += dt; - proxy.v = 0; - auto spike = detector.test(proxy, t); - - EXPECT_FALSE(spike); - - EXPECT_FALSE(detector.is_spiking()); - EXPECT_EQ(proxy.v, detector.v()); - EXPECT_EQ(t, detector.t()); - } + // + // test that clearing works + // + watch.clear_crossings(); + EXPECT_EQ(watch.crossings().size(), 0u); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_TRUE(watch.is_crossed(2)); + + // + // test that resetting works + // + EXPECT_EQ(watch.last_test_time(), 5); + memory::fill(values, 0); + values[index[0]] = 10.; // first watch should be intialized to spiking state + watch.reset(0); + EXPECT_EQ(watch.last_test_time(), 0); + EXPECT_EQ(watch.crossings().size(), 0u); + EXPECT_TRUE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); } diff --git a/tests/unit/test_spikes.cu b/tests/unit/test_spikes.cu new file mode 100644 index 0000000000000000000000000000000000000000..01daf1562b7ffde7dfd4700f041ac63dacdae2d7 --- /dev/null +++ b/tests/unit/test_spikes.cu @@ -0,0 +1,122 @@ +#include "../gtest.h" + +#include <spike.hpp> +#include <backends/fvm_gpu.hpp> + +using namespace nest::mc; + +TEST(spikes, threshold_watcher) { + using backend = gpu::backend; + using size_type = backend::size_type; + using value_type = backend::value_type; + using array = backend::array; + using list = backend::threshold_watcher::crossing_list; + + // the test creates a watch on 3 values in the array values (which has 10 + // elements in total). + const auto n = 10; + + const std::vector<size_type> index{0, 5, 7}; + const std::vector<value_type> thresh{1., 2., 3.}; + + // all values are initially 0, except for values[5] which we set + // to exceed the threshold of 2. for the second watch + array values(n, 0); + values[5] = 3.; + + // list for storing expected crossings for validation at the end + list expected; + + // create the watch + backend::threshold_watcher watch(values, index, thresh, 0.f); + + // initially the first and third watch should not be spiking + // the second is spiking + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_TRUE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); + + // test again at t=1, with unchanged values + // - nothing should change + watch.test(1.); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_TRUE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 0u); + + // test at t=2, with all values set to zero + // - 2nd watch should now stop spiking + memory::fill(values, 0.); + watch.test(2.); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 0u); + + // test at t=3, with all values set to 4. + // - all watches should now be spiking + memory::fill(values, 4.); + watch.test(3.); + EXPECT_TRUE(watch.is_crossed(0)); + EXPECT_TRUE(watch.is_crossed(1)); + EXPECT_TRUE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 3u); + + // record the expected spikes + expected.push_back({0u, 2.25f}); + expected.push_back({1u, 2.50f}); + expected.push_back({2u, 2.75f}); + + // test at t=4, with all values set to 0. + // - all watches should stop spiking + memory::fill(values, 0.); + watch.test(4.); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 3u); + + // test at t=5, with value on 3rd watch set to 6 + // - watch 3 should be spiking + values[index[2]] = 6.; + watch.test(5.); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_TRUE(watch.is_crossed(2)); + EXPECT_EQ(watch.crossings().size(), 4u); + expected.push_back({2u, 4.5f}); + + // + // test that all generated spikes matched the expected values + // + if (expected.size()!=watch.crossings().size()) { + FAIL() << "count of recorded crosssings did not match expected count"; + } + auto const& spikes = watch.crossings(); + for (auto i=0u; i<expected.size(); ++i) { + EXPECT_EQ(expected[i], spikes[i]); + } + + // + // test that clearing works + // + watch.clear_crossings(); + EXPECT_EQ(watch.crossings().size(), 0u); + EXPECT_FALSE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_TRUE(watch.is_crossed(2)); + + // + // test that resetting works + // + EXPECT_EQ(watch.last_test_time(), 5); + memory::fill(values, 0); + values[index[0]] = 10.; // first watch should be intialized to spiking state + watch.reset(0); + EXPECT_EQ(watch.last_test_time(), 0); + EXPECT_EQ(watch.crossings().size(), 0u); + EXPECT_TRUE(watch.is_crossed(0)); + EXPECT_FALSE(watch.is_crossed(1)); + EXPECT_FALSE(watch.is_crossed(2)); +} +