diff --git a/src/fvm_cell.hpp b/src/fvm_cell.hpp index 4a908bba49887948e37bfa7fe3409b19df90bbb3..8a76b0c51a9ee2d729b545cfae1b6c19225a9631 100644 --- a/src/fvm_cell.hpp +++ b/src/fvm_cell.hpp @@ -146,6 +146,18 @@ class fvm_cell { return events_; } + // returns the compartment index of a segment location + int compartment_index(segment_location loc) { + EXPECTS(loc.segment < segment_index_.size()); + + const auto seg = loc.segment; + + auto first = segment_index_[seg]; + auto n = segment_index_[seg+1] - first; + auto index = std::floor(n*loc.position); + return index<n ? first+index : first+n-1; + } + private: /// current time @@ -154,6 +166,9 @@ class fvm_cell { /// the linear system for implicit time stepping of cell state matrix_type matrix_; + /// index for fast lookup of compartment index ranges of segments + index_type segment_index_; + /// cv_areas_[i] is the surface area of CV i vector_type cv_areas_; @@ -217,7 +232,7 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) matrix_ = matrix_type(graph.parent_index); auto parent_index = matrix_.p(); - auto const& segment_index = graph.segment_index; + segment_index_ = graph.segment_index; auto seg_idx = 0; for(auto const& s : cell.segments()) { @@ -250,7 +265,7 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) auto c_m = cable->mechanism("membrane").get("c_m").value; auto r_L = cable->mechanism("membrane").get("r_L").value; for(auto c : cable->compartments()) { - auto i = segment_index[seg_idx] + c.index; + auto i = segment_index_[seg_idx] + c.index; auto j = parent_index[i]; auto radius_center = math::mean(c.radius); @@ -309,7 +324,7 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) // calculate the number of compartments that contain the mechanism auto num_comp = 0u; for(auto seg : mech.second) { - num_comp += segment_index[seg+1] - segment_index[seg]; + num_comp += segment_index_[seg+1] - segment_index_[seg]; } // build a vector of the indexes of the compartments that contain @@ -317,11 +332,11 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) index_type compartment_index(num_comp); auto pos = 0u; for(auto seg : mech.second) { - auto seg_size = segment_index[seg+1] - segment_index[seg]; + auto seg_size = segment_index_[seg+1] - segment_index_[seg]; std::iota( compartment_index.data() + pos, compartment_index.data() + pos + seg_size, - segment_index[seg] + segment_index_[seg] ); pos += seg_size; } diff --git a/src/lowered_cell.hpp b/src/lowered_cell.hpp new file mode 100644 index 0000000000000000000000000000000000000000..07137655f3a0f35c0a258865b694116b9353f590 --- /dev/null +++ b/src/lowered_cell.hpp @@ -0,0 +1,20 @@ +#pragma once + +namespace nest { +namespace mc { + +template <typename Cell> +class lowered_cell { + public : + + using cell_type = Cell; + using value_type = typename cell_type::value_type; + using size_type = typename cell_type::value_type; + + private : + + cell_type cell_; +}; + +} // namespace mc +} // namespace nest diff --git a/src/spike.hpp b/src/spike.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd250ba55b2a61940fa86fec91103122f8749b5b --- /dev/null +++ b/src/spike.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include <type_traits> +#include <ostream> + +namespace nest { +namespace mc { + +template < + typename I, + typename = typename std::enable_if<std::is_integral<I>::value> +> +struct spike { + using index_type = I; + index_type source = 0; + float time = -1.; + + spike() = default; + + spike(index_type s, float t) + : source(s), time(t) + {} +}; + +} // namespace mc +} // namespace nest + +/// custom stream operator for printing nest::mc::spike<> values +template <typename I> +std::ostream& operator <<(std::ostream& o, nest::mc::spike<I> s) { + return o << "spike[t " << s.time << ", src " << s.source << "]"; +} + +/// less than comparison operator for nest::mc::spike<> values +/// spikes are ordered by spike time, for use in sorting and queueing +template <typename I> +bool operator <(nest::mc::spike<I> lhs, nest::mc::spike<I> rhs) { + return lhs.time < rhs.time; +} + diff --git a/src/spike_source.hpp b/src/spike_source.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8a5199bf16eb80411c796bca42708aec5dda96c1 --- /dev/null +++ b/src/spike_source.hpp @@ -0,0 +1,114 @@ +#pragma once + +#include <cell.hpp> +#include <fvm_cell.hpp> +#include <util/optional.hpp> + +namespace nest { +namespace mc { + +// generic spike source +class spike_source { + public: + + virtual std::vector<float> test(float t) = 0; +}; + +// spike detector for a lowered cell +template <typename Cell> +class spike_detector : public spike_source +{ + public: + using cell_type = Cell; + + spike_detector( + cell_type const* cell, + segment_location loc, + double thresh, + float t_init + ) + : cell_(cell), + location_(loc), + threshold_(thresh), + previous_t_(t_init) + { + previous_v_ = cell->voltage(location_); + is_spiking_ = previous_v_ >= thresh ? true : false; + } + + std::vector<float> test(float t) override { + std::vector<float> spike_times; + auto v = cell_->voltage(location_); + if (!is_spiking_) { + if (v>=threshold_) { + // the threshold has been passed, so estimate the time using + // linear interpolation + auto pos = (threshold_ - previous_v_)/(v - previous_v_); + spike_times.push_back(previous_t_ + pos*(t - previous_t_)); + + is_spiking_ = true; + } + } + else { + if (v<threshold_) { + is_spiking_ = false; + } + } + + previous_v_ = v; + previous_t_ = t; + + return spike_times; + } + + bool is_spiking() const { + return is_spiking_; + } + + segment_location location() const { + return location_; + } + + private: + + // parameters/data + cell_type* cell_; + segment_location location_; + double threshold_; + + // state + float previous_t_; + float previous_v_; + bool is_spiking_; +}; + +// spike generator according to a Poisson process +class poisson_generator : public spike_source +{ + public: + + poisson_generator(float r) + : firing_rate_(r) + {} + + util::optional<float> test(float t) { + // generate a uniformly distrubuted random number x \in [0,1] + // if (x > r*dt) we have a spike in the interval + std::vector<float> spike_times; + if(rand() > firing_rate_*(t-previous_t_)) { + return t; + } + return util::optional<float>::nothing; + } + + private: + + // firing rate in spikes/ms + float firing_rate_; + float previous_t_; +}; + + +} // namespace mc +} // namespace nest +