From 2dc9520ed3ddc65e0e738083d29cdf14be9304e6 Mon Sep 17 00:00:00 2001 From: noraabiakar <nora.abiakar@gmail.com> Date: Wed, 6 Feb 2019 12:43:53 +0100 Subject: [PATCH] Explicit Gap Junctions (#661) Add support for gap junctions in mc_cells, modelled as a conductance between 2 cell CVs. Gap junctions act as additional current sources on the CVs, as opposed to participating in the implicit voltage integration step. Cells connected via gap junctions must be in the same cell group as determined by the provided domain decomposition. * Extend `mc_cell` to hold a list of gap junction locations. * Add `num_gap_junction_sites()` and `gap_junctions_on()` methods to the `recipe` interface. * Add `gather_gids()` collective operation to distributed context interface and implementations. * Extend `partition_load_balance()` functionality to ensure that cells connected by gap junctions are put in the same groups. * Permute cells within `mc_cell_group` so that cells connected by gap junctions are contiguous. * Add gap junction information to `multicore::shared_state` and `gpu::shared_state`, together with `add_gj_current()` method that computes GJ current contributions. * Add `time_dep` field to shared state structures that records which adjacent cells are GJ-connected for the purposes of determining integration time intervals. * Add `sync_time_to()` method to shared state structures to determine the common integration time step across GJ-connected cells. * Incorporate new shared state methods into the `fvm_lower_cell_impl::integrate()` integration loop. * Add new arbor exception `arb::gj_kind_mismatch`. * Add unit tests for new functionality. * Add new example code `gap_junctions` for GJ demonstration. --- arbor/arbexcept.cpp | 7 + arbor/backends/gpu/gpu_store_types.hpp | 1 + arbor/backends/gpu/shared_state.cpp | 18 + arbor/backends/gpu/shared_state.cu | 49 +++ arbor/backends/gpu/shared_state.hpp | 11 + arbor/backends/multicore/multicore_common.hpp | 1 + arbor/backends/multicore/shared_state.cpp | 39 +- arbor/backends/multicore/shared_state.hpp | 11 + arbor/communication/dry_run_context.cpp | 29 +- arbor/communication/mpi_context.cpp | 5 + arbor/distributed_context.hpp | 19 + arbor/fvm_layout.cpp | 1 - arbor/fvm_layout.hpp | 1 + arbor/fvm_lowered_cell.hpp | 1 + arbor/fvm_lowered_cell_impl.hpp | 65 +++- arbor/include/arbor/arbexcept.hpp | 5 + arbor/include/arbor/domain_decomposition.hpp | 4 +- arbor/include/arbor/fvm_types.hpp | 12 + arbor/include/arbor/mc_cell.hpp | 17 + arbor/include/arbor/recipe.hpp | 16 +- arbor/mc_cell_group.cpp | 62 +++- arbor/mc_cell_group.hpp | 17 + arbor/partition_load_balance.cpp | 126 ++++++- example/CMakeLists.txt | 1 + example/gap_junctions/CMakeLists.txt | 4 + example/gap_junctions/gap_junctions.cpp | 338 ++++++++++++++++++ example/gap_junctions/parameters.hpp | 61 ++++ example/gap_junctions/readme.md | 3 + mechanisms/mod/nax.mod | 2 +- test/simple_recipes.hpp | 1 - test/ubench/mech_vec.cpp | 16 +- test/unit-distributed/test_communicator.cpp | 39 ++ .../test_domain_decomposition.cpp | 204 ++++++++++- test/unit/CMakeLists.txt | 1 + test/unit/test_domain_decomposition.cpp | 102 ++++++ test/unit/test_dry_run_context.cpp | 20 ++ test/unit/test_event_delivery.cpp | 13 +- test/unit/test_fvm_lowered.cpp | 254 ++++++++++++- test/unit/test_local_context.cpp | 16 + test/unit/test_mc_cell_group.cpp | 150 ++++++++ test/unit/test_mech_temperature.cpp | 4 +- test/unit/test_probe.cpp | 2 +- test/unit/test_supercell.cpp | 71 ++++ test/unit/test_synapses.cpp | 3 +- 44 files changed, 1763 insertions(+), 59 deletions(-) create mode 100644 example/gap_junctions/CMakeLists.txt create mode 100644 example/gap_junctions/gap_junctions.cpp create mode 100644 example/gap_junctions/parameters.hpp create mode 100644 example/gap_junctions/readme.md create mode 100644 test/unit/test_supercell.cpp diff --git a/arbor/arbexcept.cpp b/arbor/arbexcept.cpp index a8b589b7..8ea9e1dd 100644 --- a/arbor/arbexcept.cpp +++ b/arbor/arbexcept.cpp @@ -26,6 +26,13 @@ bad_probe_id::bad_probe_id(cell_member_type probe_id): probe_id(probe_id) {} + +gj_kind_mismatch::gj_kind_mismatch(cell_gid_type gid_0, cell_gid_type gid_1): + arbor_exception(pprintf("Cells on gid {} and {} connected via gap junction have different cell kinds", gid_0, gid_1)), + gid_0(gid_0), + gid_1(gid_1) +{} + bad_event_time::bad_event_time(time_type event_time, time_type sim_time): arbor_exception(pprintf("event time {} precedes current simulation time {}", event_time, sim_time)), event_time(event_time), diff --git a/arbor/backends/gpu/gpu_store_types.hpp b/arbor/backends/gpu/gpu_store_types.hpp index ac215ed4..708233a9 100644 --- a/arbor/backends/gpu/gpu_store_types.hpp +++ b/arbor/backends/gpu/gpu_store_types.hpp @@ -17,6 +17,7 @@ namespace gpu { using array = memory::device_vector<fvm_value_type>; using iarray = memory::device_vector<fvm_index_type>; +using gjarray = memory::device_vector<fvm_gap_junction>; using deliverable_event_stream = arb::gpu::multi_event_stream<deliverable_event>; using sample_event_stream = arb::gpu::multi_event_stream<sample_event>; diff --git a/arbor/backends/gpu/shared_state.cpp b/arbor/backends/gpu/shared_state.cpp index 5cd8972c..5760b02c 100644 --- a/arbor/backends/gpu/shared_state.cpp +++ b/arbor/backends/gpu/shared_state.cpp @@ -35,10 +35,15 @@ void update_time_to_impl( std::size_t n, fvm_value_type* time_to, const fvm_value_type* time, fvm_value_type dt, fvm_value_type tmax); +void sync_time_to_impl(std::size_t n, fvm_value_type* time_to, const fvm_index_type* time_deps); + void set_dt_impl( fvm_size_type ncell, fvm_size_type ncomp, fvm_value_type* dt_cell, fvm_value_type* dt_comp, const fvm_value_type* time_to, const fvm_value_type* time, const fvm_index_type* cv_to_cell); +void add_gj_current_impl( + fvm_size_type n_gj, const fvm_gap_junction* gj, const fvm_value_type* v, fvm_value_type* i); + void take_samples_impl( const multi_event_stream_state<raw_probe_info>& s, const fvm_value_type* time, fvm_value_type* sample_time, fvm_value_type* sample_value); @@ -109,11 +114,16 @@ void ion_state::zero_current() { shared_state::shared_state( fvm_size_type n_cell, const std::vector<fvm_index_type>& cv_to_cell_vec, + const std::vector<fvm_index_type>& time_dep_vec, + const std::vector<fvm_gap_junction>& gj_vec, unsigned // alignment parameter ignored. ): n_cell(n_cell), n_cv(cv_to_cell_vec.size()), + n_gj(gj_vec.size()), cv_to_cell(make_const_view(cv_to_cell_vec)), + time_dep(make_const_view(time_dep_vec)), + gap_junctions(make_const_view(gj_vec)), time(n_cell), time_to(n_cell), dt_cell(n_cell), @@ -170,10 +180,18 @@ void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) { update_time_to_impl(n_cell, time_to.data(), time.data(), dt_step, tmax); } +void shared_state::sync_time_to() { + sync_time_to_impl(n_cell, time_to.data(), time_dep.data()); +} + void shared_state::set_dt() { set_dt_impl(n_cell, n_cv, dt_cell.data(), dt_cv.data(), time_to.data(), time.data(), cv_to_cell.data()); } +void shared_state::add_gj_current() { + add_gj_current_impl(n_gj, gap_junctions.data(), voltage.data(), current_density.data()); +} + std::pair<fvm_value_type, fvm_value_type> shared_state::time_bounds() const { return minmax_value_impl(n_cell, time.data()); } diff --git a/arbor/backends/gpu/shared_state.cu b/arbor/backends/gpu/shared_state.cu index 995eeb76..806418b9 100644 --- a/arbor/backends/gpu/shared_state.cu +++ b/arbor/backends/gpu/shared_state.cu @@ -5,6 +5,7 @@ #include <backends/event.hpp> #include <backends/multi_event_stream_state.hpp> +#include "cuda_atomic.hpp" #include "cuda_common.hpp" namespace arb { @@ -33,6 +34,24 @@ void init_concentration_impl(unsigned n, T* Xi, T* Xo, const T* weight_Xi, const } } +template <typename T, typename I> +__global__ void sync_time_to_impl(unsigned n, T* time_to, const I* time_deps) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + if (i<n) { + if (time_deps[i] > 0) { + auto min_t = time_to[i]; + for (int j = 1; j < time_deps[i]; j++) { + if (time_to[i+j] < min_t) { + min_t = time_to[i+j]; + } + } + for (int j = 0; j < time_deps[i]; j++) { + time_to[i+j] = min_t; + } + } + } +} + template <typename T> __global__ void update_time_to_impl(unsigned n, T* time_to, const T* time, T dt, T tmax) { unsigned i = threadIdx.x+blockIdx.x*blockDim.x; @@ -42,6 +61,17 @@ __global__ void update_time_to_impl(unsigned n, T* time_to, const T* time, T dt, } } +template <typename T, typename I> +__global__ void add_gj_current_impl(unsigned n, const T* gj_info, const I* voltage, I* current_density) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + if (i<n) { + auto gj = gj_info[i]; + auto curr = gj.weight * (voltage[gj.loc.second] - voltage[gj.loc.first]); // nA + + cuda_atomic_sub(current_density + gj.loc.first, curr); + } +} + // Vector minus: x = y - z template <typename T> __global__ void vec_minus(unsigned n, T* x, const T* y, const T* z) { @@ -101,6 +131,15 @@ void init_concentration_impl( kernel::init_concentration_impl<<<nblock, block_dim>>>(n, Xi, Xo, weight_Xi, weight_Xo, c_int, c_ext); } +void sync_time_to_impl(std::size_t n, fvm_value_type* time_to, const fvm_index_type* time_deps) +{ + if (!n) return; + + constexpr int block_dim = 128; + int nblock = block_count(n, block_dim); + kernel::sync_time_to_impl<<<nblock, block_dim>>>(n, time_to, time_deps); +} + void update_time_to_impl( std::size_t n, fvm_value_type* time_to, const fvm_value_type* time, fvm_value_type dt, fvm_value_type tmax) @@ -126,6 +165,16 @@ void set_dt_impl( kernel::gather<<<nblock, block_dim>>>(ncomp, dt_comp, dt_cell, cv_to_cell); } +void add_gj_current_impl( + fvm_size_type n_gj, const fvm_gap_junction* gj_info, const fvm_value_type* voltage, fvm_value_type* current_density) +{ + if (!n_gj) return; + + constexpr int block_dim = 128; + int nblock = block_count(n_gj, block_dim); + kernel::add_gj_current_impl<<<nblock, block_dim>>>(n_gj, gj_info, voltage, current_density); +} + void take_samples_impl( const multi_event_stream_state<raw_probe_info>& s, const fvm_value_type* time, fvm_value_type* sample_time, fvm_value_type* sample_value) diff --git a/arbor/backends/gpu/shared_state.hpp b/arbor/backends/gpu/shared_state.hpp index 81174715..c66ac3f1 100644 --- a/arbor/backends/gpu/shared_state.hpp +++ b/arbor/backends/gpu/shared_state.hpp @@ -67,8 +67,11 @@ struct ion_state { struct shared_state { fvm_size_type n_cell = 0; // Number of distinct cells (integration domains). fvm_size_type n_cv = 0; // Total number of CVs. + fvm_size_type n_gj = 0; // Total number of GJs. iarray cv_to_cell; // Maps CV index to cell index. + iarray time_dep; // Provides information about supercells + gjarray gap_junctions; // Stores gap_junction info. array time; // Maps cell index to integration start time [ms]. array time_to; // Maps cell index to integration stop time [ms]. array dt_cell; // Maps cell index to (stop time) - (start time) [ms]. @@ -86,6 +89,8 @@ struct shared_state { shared_state( fvm_size_type n_cell, const std::vector<fvm_index_type>& cv_to_cell_vec, + const std::vector<fvm_index_type>& time_dep_vec, + const std::vector<fvm_gap_junction>& gj_vec, unsigned align ); @@ -104,9 +109,15 @@ struct shared_state { // Set time_to to earliest of time+dt_step and tmax. void update_time_to(fvm_value_type dt_step, fvm_value_type tmax); + // Synchrnize the time_to for supercells. + void sync_time_to(); + // Set the per-cell and per-compartment dt from time_to - time. void set_dt(); + // Update gap_junction state + void add_gj_current(); + // Return minimum and maximum time value [ms] across cells. std::pair<fvm_value_type, fvm_value_type> time_bounds() const; diff --git a/arbor/backends/multicore/multicore_common.hpp b/arbor/backends/multicore/multicore_common.hpp index b0b3ed4e..4a3ee9bc 100644 --- a/arbor/backends/multicore/multicore_common.hpp +++ b/arbor/backends/multicore/multicore_common.hpp @@ -23,6 +23,7 @@ using padded_vector = std::vector<V, util::padded_allocator<V>>; using array = padded_vector<fvm_value_type>; using iarray = padded_vector<fvm_index_type>; +using gjarray = padded_vector<fvm_gap_junction>; using deliverable_event_stream = arb::multicore::multi_event_stream<deliverable_event>; using sample_event_stream = arb::multicore::multi_event_stream<sample_event>; diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index 38120f4a..3bb73441 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -117,13 +117,18 @@ void ion_state::zero_current() { shared_state::shared_state( fvm_size_type n_cell, const std::vector<fvm_index_type>& cv_to_cell_vec, + const std::vector<fvm_index_type>& time_dep_vec, + const std::vector<fvm_gap_junction>& gj_vec, unsigned align ): alignment(min_alignment(align)), alloc(alignment), n_cell(n_cell), n_cv(cv_to_cell_vec.size()), + n_gj(gj_vec.size()), cv_to_cell(math::round_up(n_cv, alignment), pad(alignment)), + time_dep(n_cell), + gap_junctions(math::round_up(n_gj, alignment), pad(alignment)), time(n_cell, pad(alignment)), time_to(n_cell, pad(alignment)), dt_cell(n_cell, pad(alignment)), @@ -133,11 +138,16 @@ shared_state::shared_state( temperature_degC(NAN), deliverable_events(n_cell) { - // For indices in the padded tail of cv_to_cell, set index to last valid cell index. + std::copy(time_dep_vec.begin(), time_dep_vec.end(), time_dep.begin()); + // For indices in the padded tail of cv_to_cell, set index to last valid cell index. if (n_cv>0) { std::copy(cv_to_cell_vec.begin(), cv_to_cell_vec.end(), cv_to_cell.begin()); - std::fill(cv_to_cell.begin()+n_cv, cv_to_cell.end(), cv_to_cell_vec.back()); + std::fill(cv_to_cell.begin() + n_cv, cv_to_cell.end(), cv_to_cell_vec.back()); + } + if (n_gj>0) { + std::copy(gj_vec.begin(), gj_vec.end(), gap_junctions.begin()); + std::fill(gap_junctions.begin()+n_gj, gap_junctions.end(), gj_vec.back()); } } @@ -182,6 +192,21 @@ void shared_state::ions_nernst_reversal_potential(fvm_value_type temperature_K) i.second.nernst(temperature_K); } } +void shared_state::sync_time_to() { + for (fvm_size_type i = 0; i<n_cell; i++) { + if (!time_dep[i]) continue; + + fvm_value_type min_t = time_to[i]; + for (int j = 1; j < time_dep[i]; j++) { + if (time_to[i+j] < min_t) { + min_t = time_to[i+j]; + } + } + for (int j = 0; j < time_dep[i]; j++) { + time_to[i+j] = min_t; + } + } +} void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) { for (fvm_size_type i = 0; i<n_cell; i+=simd_width) { @@ -208,6 +233,16 @@ void shared_state::set_dt() { } } +void shared_state::add_gj_current() { + for (unsigned i = 0; i < n_gj; i++) { + auto gj = gap_junctions[i]; + auto curr = gj.weight * + (voltage[gj.loc.second] - voltage[gj.loc.first]); // nA + + current_density[gj.loc.first] -= curr; + } +} + std::pair<fvm_value_type, fvm_value_type> shared_state::time_bounds() const { return util::minmax_value(time); } diff --git a/arbor/backends/multicore/shared_state.hpp b/arbor/backends/multicore/shared_state.hpp index 25247f14..40be9eb2 100644 --- a/arbor/backends/multicore/shared_state.hpp +++ b/arbor/backends/multicore/shared_state.hpp @@ -86,8 +86,11 @@ struct shared_state { fvm_size_type n_cell = 0; // Number of distinct cells (integration domains). fvm_size_type n_cv = 0; // Total number of CVs. + fvm_size_type n_gj = 0; // Total number of GJs. iarray cv_to_cell; // Maps CV index to cell index. + iarray time_dep; // Provides information about supercells + gjarray gap_junctions; // Stores gap_junction info. array time; // Maps cell index to integration start time [ms]. array time_to; // Maps cell index to integration stop time [ms]. array dt_cell; // Maps cell index to (stop time) - (start time) [ms]. @@ -105,6 +108,8 @@ struct shared_state { shared_state( fvm_size_type n_cell, const std::vector<fvm_index_type>& cv_to_cell_vec, + const std::vector<fvm_index_type>& time_dep_vec, + const std::vector<fvm_gap_junction>& gj_vec, unsigned align ); @@ -123,9 +128,15 @@ struct shared_state { // Set time_to to earliest of time+dt_step and tmax. void update_time_to(fvm_value_type dt_step, fvm_value_type tmax); + // Synchrnize the time_to for supercells. + void sync_time_to(); + // Set the per-cell and per-compartment dt from time_to - time. void set_dt(); + // Update gap_junction state + void add_gj_current(); + // Return minimum and maximum time value [ms] across cells. std::pair<fvm_value_type, fvm_value_type> time_bounds() const; diff --git a/arbor/communication/dry_run_context.cpp b/arbor/communication/dry_run_context.cpp index 15714475..cfa6ddf0 100644 --- a/arbor/communication/dry_run_context.cpp +++ b/arbor/communication/dry_run_context.cpp @@ -34,13 +34,40 @@ struct dry_run_context_impl { } std::vector<count_type> partition; - for(count_type i = 0; i <= num_ranks_; i++) { + for (count_type i = 0; i <= num_ranks_; i++) { partition.push_back(static_cast<count_type>(i*local_size)); } return gathered_vector<arb::spike>(std::move(gathered_spikes), std::move(partition)); } + gathered_vector<cell_gid_type> + gather_gids(const std::vector<cell_gid_type>& local_gids) const { + using count_type = typename gathered_vector<cell_gid_type>::count_type; + + count_type local_size = local_gids.size(); + + std::vector<cell_gid_type> gathered_gids; + gathered_gids.reserve(local_size*num_ranks_); + + for (count_type i = 0; i < num_ranks_; i++) { + gathered_gids.insert(gathered_gids.end(), local_gids.begin(), local_gids.end()); + } + + for (count_type i = 0; i < num_ranks_; i++) { + for (count_type j = i*local_size; j < (i+1)*local_size; j++){ + gathered_gids[j] += num_cells_per_tile_*i; + } + } + + std::vector<count_type> partition; + for (count_type i = 0; i <= num_ranks_; i++) { + partition.push_back(static_cast<count_type>(i*local_size)); + } + + return gathered_vector<cell_gid_type>(std::move(gathered_gids), std::move(partition)); + } + int id() const { return 0; } int size() const { return num_ranks_; } diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index be80dca5..9d0bc30b 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -33,6 +33,11 @@ struct mpi_context_impl { return mpi::gather_all_with_partition(local_spikes, comm_); } + gathered_vector<cell_gid_type> + gather_gids(const std::vector<cell_gid_type>& local_gids) const { + return mpi::gather_all_with_partition(local_gids, comm_); + } + std::string name() const { return "MPI"; } int id() const { return rank_; } int size() const { return size_; } diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index ead26887..1037e3f2 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -42,6 +42,7 @@ namespace arb { class distributed_context { public: using spike_vector = std::vector<arb::spike>; + using gid_vector = std::vector<cell_gid_type>; // default constructor uses a local context: see below. distributed_context(); @@ -58,6 +59,10 @@ public: return impl_->gather_spikes(local_spikes); } + gathered_vector<cell_gid_type> gather_gids(const gid_vector& local_gids) const { + return impl_->gather_gids(local_gids); + } + int id() const { return impl_->id(); } @@ -84,6 +89,8 @@ private: struct interface { virtual gathered_vector<arb::spike> gather_spikes(const spike_vector& local_spikes) const = 0; + virtual gathered_vector<cell_gid_type> + gather_gids(const gid_vector& local_gids) const = 0; virtual int id() const = 0; virtual int size() const = 0; virtual void barrier() const = 0; @@ -104,6 +111,10 @@ private: gather_spikes(const spike_vector& local_spikes) const override { return wrapped.gather_spikes(local_spikes); } + virtual gathered_vector<cell_gid_type> + gather_gids(const gid_vector& local_gids) const override { + return wrapped.gather_gids(local_gids); + } int id() const override { return wrapped.id(); } @@ -138,6 +149,14 @@ struct local_context { {0u, static_cast<count_type>(local_spikes.size())} ); } + gathered_vector<cell_gid_type> + gather_gids(const std::vector<cell_gid_type>& local_gids) const { + using count_type = typename gathered_vector<cell_gid_type>::count_type; + return gathered_vector<cell_gid_type>( + std::vector<cell_gid_type>(local_gids), + {0u, static_cast<count_type>(local_gids.size())} + ); + } int id() const { return 0; } diff --git a/arbor/fvm_layout.cpp b/arbor/fvm_layout.cpp index 0f75ffe2..9d2d1607 100644 --- a/arbor/fvm_layout.cpp +++ b/arbor/fvm_layout.cpp @@ -236,7 +236,6 @@ fvm_discretization fvm_discretize(const std::vector<mc_cell>& cells) { return D; } - // Build up mechanisms. // // Processing procedes in the following stages: diff --git a/arbor/fvm_layout.hpp b/arbor/fvm_layout.hpp index ceeb3008..4ab94980 100644 --- a/arbor/fvm_layout.hpp +++ b/arbor/fvm_layout.hpp @@ -5,6 +5,7 @@ #include <arbor/mechanism.hpp> #include <arbor/mechinfo.hpp> #include <arbor/mechcat.hpp> +#include <arbor/recipe.hpp> #include "fvm_compartment.hpp" #include "util/span.hpp" diff --git a/arbor/fvm_lowered_cell.hpp b/arbor/fvm_lowered_cell.hpp index c710074c..e49c610b 100644 --- a/arbor/fvm_lowered_cell.hpp +++ b/arbor/fvm_lowered_cell.hpp @@ -28,6 +28,7 @@ struct fvm_lowered_cell { virtual void initialize( const std::vector<cell_gid_type>& gids, + const std::vector<int>& deps, const recipe& rec, std::vector<target_handle>& target_handles, probe_association_map<probe_handle>& probe_map) = 0; diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index d033f464..31edf2bf 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -50,6 +50,7 @@ public: void initialize( const std::vector<cell_gid_type>& gids, + const std::vector<int>& deps, const recipe& rec, std::vector<target_handle>& target_handles, probe_association_map<probe_handle>& probe_map) override; @@ -60,6 +61,12 @@ public: std::vector<deliverable_event> staged_events, std::vector<sample_event> staged_samples) override; + std::vector<fvm_gap_junction> fvm_gap_junctions( + const std::vector<mc_cell>& cells, + const std::vector<cell_gid_type>& gids, + const recipe& rec, + const fvm_discretization& D); + value_type time() const override { return tmin_; } //Exposed for testing purposes @@ -198,6 +205,9 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( m->nrn_current(); } + // Add current contribution from gap_junctions + state_->add_gj_current(); + PE(advance_integrate_events); state_->deliverable_events.drop_marked_events(); @@ -205,6 +215,7 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( state_->update_time_to(dt_max, tfinal); state_->deliverable_events.event_time_if_before(state_->time_to); + state_->sync_time_to(); state_->set_dt(); PL(); @@ -301,6 +312,7 @@ void fvm_lowered_cell_impl<B>::assert_voltage_bounded(fvm_value_type bound) { template <typename B> void fvm_lowered_cell_impl<B>::initialize( const std::vector<cell_gid_type>& gids, + const std::vector<int>& deps, const recipe& rec, std::vector<target_handle>& target_handles, probe_association_map<probe_handle>& probe_map) @@ -362,6 +374,10 @@ void fvm_lowered_cell_impl<B>::initialize( fvm_mechanism_data mech_data = fvm_build_mechanism_data(*catalogue, cells, D); + // Discritize and build gap junction info + + auto gj_vector = fvm_gap_junctions(cells, gids, rec, D); + // Create shared cell state. // (SIMD padding requires us to check each mechanism for alignment/padding constraints.) @@ -369,7 +385,7 @@ void fvm_lowered_cell_impl<B>::initialize( util::transform_view(keys(mech_data.mechanisms), [&](const std::string& name) { return mech_instance(name)->data_alignment(); })); - state_ = std::make_unique<shared_state>(ncell, D.cv_to_cell, data_alignment? data_alignment: 1u); + state_ = std::make_unique<shared_state>(ncell, D.cv_to_cell, deps, gj_vector, data_alignment? data_alignment: 1u); // Instantiate mechanisms and ions. @@ -468,4 +484,51 @@ void fvm_lowered_cell_impl<B>::initialize( reset(); } +// Get vector of gap_junctions +template <typename B> +std::vector<fvm_gap_junction> fvm_lowered_cell_impl<B>::fvm_gap_junctions( + const std::vector<mc_cell>& cells, + const std::vector<cell_gid_type>& gids, + const recipe& rec, const fvm_discretization& D) { + + std::vector<fvm_gap_junction> v; + + std::unordered_map<cell_gid_type, std::vector<unsigned>> gid_to_cvs; + for (auto cell_idx: util::make_span(0, D.ncell)) { + + if (rec.num_gap_junction_sites(gids[cell_idx])) { + gid_to_cvs[gids[cell_idx]].reserve(rec.num_gap_junction_sites(gids[cell_idx])); + + auto cell_gj = cells[cell_idx].gap_junction_sites(); + for (auto gj : cell_gj) { + auto cv = D.segment_location_cv(cell_idx, gj); + gid_to_cvs[gids[cell_idx]].push_back(cv); + } + } + } + + for (auto gid: gids) { + auto gj_list = rec.gap_junctions_on(gid); + for (auto g: gj_list) { + if (gid != g.local.gid && gid != g.peer.gid) { + throw arb::bad_cell_description(cell_kind::cable1d_neuron, gid); + } + cell_gid_type cv0, cv1; + try { + cv0 = gid_to_cvs[g.local.gid].at(g.local.index); + cv1 = gid_to_cvs[g.peer.gid].at(g.peer.index); + } + catch (std::out_of_range&) { + throw arb::bad_cell_description(cell_kind::cable1d_neuron, gid); + } + if (gid != g.local.gid) { + std::swap(cv0, cv1); + } + v.push_back(fvm_gap_junction(std::make_pair(cv0, cv1), g.ggap * 1e3 / D.cv_area[cv0])); + } + } + + return v; +} + } // namespace arb diff --git a/arbor/include/arbor/arbexcept.hpp b/arbor/include/arbor/arbexcept.hpp index d136362a..cdeaa96d 100644 --- a/arbor/include/arbor/arbexcept.hpp +++ b/arbor/include/arbor/arbexcept.hpp @@ -45,6 +45,11 @@ struct bad_probe_id: arbor_exception { cell_member_type probe_id; }; +struct gj_kind_mismatch: arbor_exception { + gj_kind_mismatch(cell_gid_type gid_0, cell_gid_type gid_1); + cell_gid_type gid_0, gid_1; +}; + // Simulation errors: struct bad_event_time: arbor_exception { diff --git a/arbor/include/arbor/domain_decomposition.hpp b/arbor/include/arbor/domain_decomposition.hpp index 619526c5..4b5d6e9d 100644 --- a/arbor/include/arbor/domain_decomposition.hpp +++ b/arbor/include/arbor/domain_decomposition.hpp @@ -23,9 +23,7 @@ struct group_description { group_description(cell_kind k, std::vector<cell_gid_type> g, backend_kind b): kind(k), gids(std::move(g)), backend(b) - { - arb_assert(std::is_sorted(gids.begin(), gids.end())); - } + {} }; /// Meta data that describes a domain decomposition. diff --git a/arbor/include/arbor/fvm_types.hpp b/arbor/include/arbor/fvm_types.hpp index 9a5bf56d..a4f21386 100644 --- a/arbor/include/arbor/fvm_types.hpp +++ b/arbor/include/arbor/fvm_types.hpp @@ -10,4 +10,16 @@ using fvm_value_type = double; using fvm_size_type = cell_local_size_type; using fvm_index_type = int; +struct fvm_gap_junction { + using value_type = fvm_value_type; + using index_type = fvm_index_type; + + std::pair<index_type, index_type> loc; + value_type weight; + + fvm_gap_junction() {} + fvm_gap_junction(std::pair<index_type, index_type> l, value_type w): loc(l), weight(w) {} + +}; + } // namespace arb diff --git a/arbor/include/arbor/mc_cell.hpp b/arbor/include/arbor/mc_cell.hpp index 71e65794..8ac2addd 100644 --- a/arbor/include/arbor/mc_cell.hpp +++ b/arbor/include/arbor/mc_cell.hpp @@ -100,6 +100,8 @@ public: using value_type = double; using point_type = point<value_type>; + using gap_junction_instance = segment_location; + struct synapse_instance { segment_location location; mechanism_desc mechanism; @@ -123,6 +125,7 @@ public: parents_(other.parents_), stimuli_(other.stimuli_), synapses_(other.synapses_), + gap_junction_sites_(other.gap_junction_sites_), spike_detectors_(other.spike_detectors_) { // unique_ptr's cannot be copy constructed, do a manual assignment @@ -209,6 +212,17 @@ public: return synapses_; } + ////////////////// + // gap-junction + ////////////////// + void add_gap_junction(segment_location location) + { + gap_junction_sites_.push_back(location); + } + const std::vector<gap_junction_instance>& gap_junction_sites() const { + return gap_junction_sites_; + } + ////////////////// // spike detectors ////////////////// @@ -251,6 +265,9 @@ private: // the synapses std::vector<synapse_instance> synapses_; + // the gap_junctions + std::vector<gap_junction_instance> gap_junction_sites_; + // the sensors std::vector<detector_instance> spike_detectors_; }; diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index 00147ff8..5521aa12 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -47,6 +47,15 @@ struct cell_connection { {} }; +struct gap_junction_connection { + cell_member_type local; + cell_member_type peer; + double ggap; + + gap_junction_connection(cell_member_type local, cell_member_type peer, double g): + local(local), peer(peer), ggap(g) {} +}; + class recipe { public: virtual cell_size_type num_cells() const = 0; @@ -58,13 +67,18 @@ public: virtual cell_size_type num_sources(cell_gid_type) const { return 0; } virtual cell_size_type num_targets(cell_gid_type) const { return 0; } virtual cell_size_type num_probes(cell_gid_type) const { return 0; } - + virtual cell_size_type num_gap_junction_sites(cell_gid_type gid) const { + return gap_junctions_on(gid).size(); + } virtual std::vector<event_generator> event_generators(cell_gid_type) const { return {}; } virtual std::vector<cell_connection> connections_on(cell_gid_type) const { return {}; } + virtual std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type) const { + return {}; + } virtual probe_info get_probe(cell_member_type probe_id) const { throw bad_probe_id(probe_id); diff --git a/arbor/mc_cell_group.cpp b/arbor/mc_cell_group.cpp index f0680809..a2cc9f69 100644 --- a/arbor/mc_cell_group.cpp +++ b/arbor/mc_cell_group.cpp @@ -1,5 +1,5 @@ #include <functional> -#include <unordered_map> +#include <unordered_set> #include <vector> #include <arbor/assert.hpp> @@ -25,8 +25,9 @@ namespace arb { mc_cell_group::mc_cell_group(const std::vector<cell_gid_type>& gids, const recipe& rec, fvm_lowered_cell_ptr lowered): - gids_(gids), lowered_(std::move(lowered)) + lowered_(std::move(lowered)) { + generate_deps_gids(rec, gids); // Default to no binning of events set_binning_policy(binning_kind::none, 0); @@ -46,7 +47,7 @@ mc_cell_group::mc_cell_group(const std::vector<cell_gid_type>& gids, const recip target_handles_.reserve(n_targets); // Construct cell implementation, retrieving handles and maps. - lowered_->initialize(gids_, rec, target_handles_, probe_map_); + lowered_->initialize(gids_, deps_, rec, target_handles_, probe_map_); // Create a list of the global identifiers for the spike sources for (auto source_gid: gids_) { @@ -57,6 +58,59 @@ mc_cell_group::mc_cell_group(const std::vector<cell_gid_type>& gids, const recip spike_sources_.shrink_to_fit(); } +// Fills gids_ and deps_: gids_ are sorted such that members of the same supercell are consecutive +void mc_cell_group::generate_deps_gids(const recipe& rec, std::vector<cell_gid_type> gids) { + + std::unordered_map<cell_gid_type, cell_size_type> gid_to_loc; + for (auto i: util::count_along(gids)) { + gid_to_loc[gids[i]] = i; + } + + deps_.reserve(gids_.size()); + std::unordered_set<cell_gid_type> visited; + std::queue<cell_gid_type> scq; + + for (auto gid: gids) { + if (visited.count(gid)) continue; + visited.insert(gid); + + cell_size_type sc_size = 0; + scq.push(gid); + while (!scq.empty()) { + auto g = scq.front(); + scq.pop(); + + gids_.push_back(g); + ++sc_size; + + for (auto gj: rec.gap_junctions_on(g)) { + cell_gid_type peer = + gj.local.gid==g? gj.peer.gid: + gj.peer.gid==g? gj.local.gid: + throw bad_cell_description(cell_kind::cable1d_neuron, g); + + if (!gid_to_loc.count(peer)) { + // actually an error in the domain decomposition... + throw bad_cell_description(cell_kind::cable1d_neuron, g); + } + + if (!visited.count(peer)) { + visited.insert(peer); + scq.push(peer); + } + } + } + + deps_.push_back(sc_size>1? sc_size: 0); + deps_.insert(deps_.end(), sc_size-1, 0); + } + + perm_gids_.reserve(gids_.size()); + for (auto gid: gids_) { + perm_gids_.push_back(gid_to_loc[gid]); + } +} + void mc_cell_group::reset() { spikes_.clear(); @@ -85,7 +139,7 @@ void mc_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& e // skip event binning if empty lanes are passed if (event_lanes.size()) { for (auto lid: util::count_along(gids_)) { - auto& lane = event_lanes[lid]; + auto& lane = event_lanes[perm_gids_[lid]]; for (auto e: lane) { if (e.time>=ep.tfinal) break; e.time = binners_[lid].bin(e.time, tstart); diff --git a/arbor/mc_cell_group.hpp b/arbor/mc_cell_group.hpp index 47653506..4efd99d0 100644 --- a/arbor/mc_cell_group.hpp +++ b/arbor/mc_cell_group.hpp @@ -57,10 +57,27 @@ public: void remove_all_samplers() override; + void generate_deps_gids(const recipe& rec, std::vector<cell_gid_type>); + + std::vector<cell_gid_type> get_gids() { + return gids_; + } + + std::vector<int> get_dependencies() { + return deps_; + } + private: // List of the gids of the cells in the group. std::vector<cell_gid_type> gids_; + // Permutation of the gids of the cells in the group. + // perm_gids_[i] is the index of gids_[i] in the original gids vector passed to the ctor + std::vector<cell_size_type> perm_gids_; + + // List of the dependencies of the cells in the group. + std::vector<int> deps_; + // Hash table for converting gid to local index std::unordered_map<cell_gid_type, cell_gid_type> gid_index_map_; diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp index 29b70174..c393f51c 100644 --- a/arbor/partition_load_balance.cpp +++ b/arbor/partition_load_balance.cpp @@ -1,3 +1,5 @@ +#include <unordered_set> + #include <arbor/domain_decomposition.hpp> #include <arbor/load_balance.hpp> #include <arbor/recipe.hpp> @@ -21,16 +23,28 @@ domain_decomposition partition_load_balance( const bool gpu_avail = ctx->gpu->has_gpu(); struct partition_gid_domain { - partition_gid_domain(std::vector<cell_gid_type> divs): - gid_divisions(std::move(divs)) + partition_gid_domain(gathered_vector<cell_gid_type> divs, unsigned domains): + gids_by_rank(std::move(divs)), num_domains(domains) {} int operator()(cell_gid_type gid) const { - auto gid_part = util::partition_view(gid_divisions); - return gid_part.index(gid); + using namespace util; + auto rank_part = partition_view(gids_by_rank.partition()); + for (auto i: count_along(rank_part)) { + if (binary_search_index(subrange_view(gids_by_rank.values(), rank_part[i]), gid)) { + return i; + } + } + return -1; } - const std::vector<cell_gid_type> gid_divisions; + const gathered_vector<cell_gid_type> gids_by_rank; + unsigned num_domains; + }; + + struct cell_identifier { + cell_gid_type id; + bool is_super_cell; }; using util::make_span; @@ -53,11 +67,81 @@ domain_decomposition partition_load_balance( // Local load balance - std::unordered_map<cell_kind, std::vector<cell_gid_type>> kind_lists; + std::vector<std::vector<cell_gid_type>> super_cells; //cells connected by gj + std::vector<cell_gid_type> reg_cells; //independent cells + + // Map to track visited cells (cells that already belong to a group) + std::unordered_set<cell_gid_type> visited; + + // Connected components algorithm using BFS + std::queue<cell_gid_type> q; for (auto gid: make_span(gid_part[domain_id])) { - kind_lists[rec.get_cell_kind(gid)].push_back(gid); + if (!rec.gap_junctions_on(gid).empty()) { + // If cell hasn't been visited yet, must belong to new super_cell + // Perform BFS starting from that cell + if (!visited.count(gid)) { + visited.insert(gid); + std::vector<cell_gid_type> cg; + q.push(gid); + while (!q.empty()) { + auto element = q.front(); + q.pop(); + cg.push_back(element); + // Adjacency list + auto conns = rec.gap_junctions_on(element); + for (auto c: conns) { + if (element != c.local.gid && element != c.peer.gid) { + throw bad_cell_description(cell_kind::cable1d_neuron, element); + } + cell_member_type other = c.local.gid == element ? c.peer : c.local; + + if (!visited.count(other.gid)) { + visited.insert(other.gid); + q.push(other.gid); + } + } + } + super_cells.push_back(cg); + } + } + else { + // If cell has no gap_junctions, put in separate group of independent cells + reg_cells.push_back(gid); + } + } + + // Sort super_cell groups and only keep those where the first element in the group belongs to domain + super_cells.erase(std::remove_if(super_cells.begin(), super_cells.end(), + [gid_part, domain_id](std::vector<cell_gid_type>& cg) + { + std::sort(cg.begin(), cg.end()); + return cg.front() < gid_part[domain_id].first; + }), super_cells.end()); + + // Collect local gids that belong to this rank, and sort gids into kind lists + // kind_lists maps a cell_kind to a vector of either: + // 1. gids of regular cells (in reg_cells) + // 2. indices of supercells (in super_cells) + + std::vector<cell_gid_type> local_gids; + std::unordered_map<cell_kind, std::vector<cell_identifier>> kind_lists; + for (auto gid: reg_cells) { + local_gids.push_back(gid); + kind_lists[rec.get_cell_kind(gid)].push_back({gid, false}); + } + + for (unsigned i = 0; i < super_cells.size(); i++) { + auto kind = rec.get_cell_kind(super_cells[i].front()); + for (auto gid: super_cells[i]) { + if (rec.get_cell_kind(gid) != kind) { + throw gj_kind_mismatch(gid, super_cells[i].front()); + } + local_gids.push_back(gid); + } + kind_lists[kind].push_back({i, true}); } + // Create a flat vector of the cell kinds present on this node, // partitioned such that kinds for which GPU implementation are // listed before the others. This is a very primitive attempt at @@ -95,8 +179,19 @@ domain_decomposition partition_load_balance( } std::vector<cell_gid_type> group_elements; - for (auto gid: kind_lists[k]) { - group_elements.push_back(gid); + // group_elements are sorted such that the gids of all members of a super_cell are consecutive. + for (auto cell: kind_lists[k]) { + if (cell.is_super_cell == false) { + group_elements.push_back(cell.id); + } else { + if (group_elements.size() + super_cells[cell.id].size() > group_size && !group_elements.empty()) { + groups.push_back({k, std::move(group_elements), backend}); + group_elements.clear(); + } + for (auto gid: super_cells[cell.id]) { + group_elements.push_back(gid); + } + } if (group_elements.size()>=group_size) { groups.push_back({k, std::move(group_elements), backend}); group_elements.clear(); @@ -107,9 +202,14 @@ domain_decomposition partition_load_balance( } } - // calculate the number of local cells - auto rng = gid_part[domain_id]; - cell_size_type num_local_cells = rng.second - rng.first; + cell_size_type num_local_cells = local_gids.size(); + + // Exchange gid list with all other nodes + + util::sort(local_gids); + + // global all-to-all to gather a local copy of the global gid list on each node. + auto global_gids = ctx->distributed->gather_gids(local_gids); domain_decomposition d; d.num_domains = num_domains; @@ -117,7 +217,7 @@ domain_decomposition partition_load_balance( d.num_local_cells = num_local_cells; d.num_global_cells = num_global_cells; d.groups = std::move(groups); - d.gid_domain = partition_gid_domain(std::move(gid_divisions)); + d.gid_domain = partition_gid_domain(std::move(global_gids), num_domains); return d; } diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 9cfab47e..990a5671 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(generators) add_subdirectory(brunel) add_subdirectory(bench) add_subdirectory(ring) +add_subdirectory(gap_junctions) diff --git a/example/gap_junctions/CMakeLists.txt b/example/gap_junctions/CMakeLists.txt new file mode 100644 index 00000000..60a68d61 --- /dev/null +++ b/example/gap_junctions/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(gap_junctions EXCLUDE_FROM_ALL gap_junctions.cpp parameters.hpp) +add_dependencies(examples gap_junctions) + +target_link_libraries(gap_junctions PRIVATE arbor arborenv arbor-sup ext-json) diff --git a/example/gap_junctions/gap_junctions.cpp b/example/gap_junctions/gap_junctions.cpp new file mode 100644 index 00000000..94a697f7 --- /dev/null +++ b/example/gap_junctions/gap_junctions.cpp @@ -0,0 +1,338 @@ +/* + * A miniapp that demonstrates how to make a model with gap junctions + * + */ + +#include <fstream> +#include <iomanip> +#include <iostream> + +#include <nlohmann/json.hpp> + +#include <arbor/assert_macro.hpp> +#include <arbor/common_types.hpp> +#include <arbor/context.hpp> +#include <arbor/load_balance.hpp> +#include <arbor/mc_cell.hpp> +#include <arbor/profile/meter_manager.hpp> +#include <arbor/profile/profiler.hpp> +#include <arbor/simple_sampler.hpp> +#include <arbor/simulation.hpp> +#include <arbor/recipe.hpp> +#include <arbor/version.hpp> + +#include <sup/ioutil.hpp> +#include <sup/json_meter.hpp> + +#include "parameters.hpp" + +#ifdef ARB_MPI_ENABLED +#include <mpi.h> +#include <arborenv/with_mpi.hpp> +#endif + +using arb::cell_gid_type; +using arb::cell_lid_type; +using arb::cell_size_type; +using arb::cell_member_type; +using arb::cell_kind; +using arb::time_type; +using arb::cell_probe_address; + +// Writes voltage trace as a json file. +void write_trace_json(const std::vector<arb::trace_data<double>>& trace, unsigned rank); + +// Generate a cell. +arb::mc_cell gj_cell(double delay, double duration); + +class gj_recipe: public arb::recipe { +public: + gj_recipe(gap_params params): params_(params) {} + + cell_size_type num_cells() const override { + return params_.cells_per_ring * params_.num_rings; + } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + return gj_cell(params_.delay*gid, params_.duration); + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + + // Each cell has one spike detector (at the soma). + cell_size_type num_sources(cell_gid_type gid) const override { + return 1; + } + + // The cell has one target synapse, which will be connected to cell gid-1. + cell_size_type num_targets(cell_gid_type gid) const override { + return 0; + } + + // There is one probe (for measuring voltage at the soma) on the cell. + cell_size_type num_probes(cell_gid_type gid) const override { + return 1; + } + + arb::probe_info get_probe(cell_member_type id) const override { + // Get the appropriate kind for measuring voltage. + cell_probe_address::probe_kind kind = cell_probe_address::membrane_voltage; + // Measure at the soma. + arb::segment_location loc(0, 1); + + return arb::probe_info{id, kind, cell_probe_address{loc, kind}}; + } + + arb::util::any get_global_properties(cell_kind k) const override { + arb::mc_cell_global_properties a; + a.temperature_K = 308.15; + return a; + } + + std::vector<arb::gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override{ + std::vector<arb::gap_junction_connection> conns; + + cell_gid_type ring_start_gid = (gid/params_.cells_per_ring) * params_.cells_per_ring; + cell_gid_type next_cell = (gid + 1) % params_.cells_per_ring + ring_start_gid; + cell_gid_type prev_cell = (gid - 1 + params_.cells_per_ring) % params_.cells_per_ring + ring_start_gid; + + // Our soma is connected to the next cell's dendrite + // Our dendrite is connected to the prev cell's soma + conns.push_back(arb::gap_junction_connection({next_cell, 1}, {gid, 0}, 0.0037)); // 1 is the id of the dendrite junction + conns.push_back(arb::gap_junction_connection({prev_cell, 0}, {gid, 1}, 0.0037)); // 0 is the id of the soma junction + + return conns; + } + +private: + gap_params params_; +}; + +struct cell_stats { + using size_type = unsigned; + size_type ncells = 0; + size_type nsegs = 0; + size_type ncomp = 0; + + cell_stats(arb::recipe& r) { +#ifdef ARB_MPI_ENABLED + int nranks, rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + ncells = r.num_cells(); + size_type cells_per_rank = ncells/nranks; + size_type b = rank*cells_per_rank; + size_type e = (rank==nranks-1)? ncells: (rank+1)*cells_per_rank; + size_type nsegs_tmp = 0; + size_type ncomp_tmp = 0; + for (size_type i=b; i<e; ++i) { + auto c = arb::util::any_cast<arb::mc_cell>(r.get_cell_description(i)); + nsegs_tmp += c.num_segments(); + ncomp_tmp += c.num_compartments(); + } + MPI_Allreduce(&nsegs_tmp, &nsegs, 1, MPI_UNSIGNED, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&ncomp_tmp, &ncomp, 1, MPI_UNSIGNED, MPI_SUM, MPI_COMM_WORLD); +#else + ncells = r.num_cells(); + for (size_type i=0; i<ncells; ++i) { + auto c = arb::util::any_cast<arb::mc_cell>(r.get_cell_description(i)); + nsegs += c.num_segments(); + ncomp += c.num_compartments(); + } +#endif + } + + friend std::ostream& operator<<(std::ostream& o, const cell_stats& s) { + return o << "cell stats: " + << s.ncells << " cells; " + << s.nsegs << " segments; " + << s.ncomp << " compartments."; + } +}; + + +int main(int argc, char** argv) { + try { + bool root = true; + +#ifdef ARB_MPI_ENABLED + arbenv::with_mpi guard(argc, argv, false); + auto context = arb::make_context(arb::proc_allocation(), MPI_COMM_WORLD); + { + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + root = rank==0; + } +#else + auto context = arb::make_context(); +#endif + +#ifdef ARB_PROFILE_ENABLED + arb::profile::profiler_initialize(context); +#endif + + std::cout << sup::mask_stream(root); + + // Print a banner with information about hardware configuration + std::cout << "gpu: " << (has_gpu(context)? "yes": "no") << "\n"; + std::cout << "threads: " << num_threads(context) << "\n"; + std::cout << "mpi: " << (has_mpi(context)? "yes": "no") << "\n"; + std::cout << "ranks: " << num_ranks(context) << "\n" << std::endl; + + auto params = read_options(argc, argv); + + arb::profile::meter_manager meters; + meters.start(context); + + // Create an instance of our recipe. + gj_recipe recipe(params); + + cell_stats stats(recipe); + std::cout << stats << "\n"; + + auto decomp = arb::partition_load_balance(recipe, context); + + // Construct the model. + arb::simulation sim(recipe, decomp, context); + + // Set up the probe that will measure voltage in the cell. + + auto sched = arb::regular_schedule(0.025); + // This is where the voltage samples will be stored as (time, value) pairs + std::vector<arb::trace_data<double>> voltage(decomp.num_local_cells); + + // Now attach the sampler at probe_id, with sampling schedule sched, writing to voltage + unsigned j=0; + for (auto g : decomp.groups) { + for (auto i : g.gids) { + auto t = recipe.get_probe({i, 0}); + sim.add_sampler(arb::one_probe(t.id), sched, arb::make_simple_sampler(voltage[j++])); + } + } + + // Set up recording of spikes to a vector on the root process. + std::vector<arb::spike> recorded_spikes; + if (root) { + sim.set_global_spike_callback( + [&recorded_spikes](const std::vector<arb::spike>& spikes) { + recorded_spikes.insert(recorded_spikes.end(), spikes.begin(), spikes.end()); + }); + } + + meters.checkpoint("model-init", context); + + std::cout << "running simulation" << std::endl; + // Run the simulation for 100 ms, with time steps of 0.025 ms. + sim.run(params.duration, 0.025); + + meters.checkpoint("model-run", context); + + auto ns = sim.num_spikes(); + + // Write spikes to file + if (root) { + std::cout << "\n" << ns << " spikes generated at rate of " + << params.duration/ns << " ms between spikes\n"; + std::ofstream fid("spikes.gdf"); + if (!fid.good()) { + std::cerr << "Warning: unable to open file spikes.gdf for spike output\n"; + } + else { + char linebuf[45]; + for (auto spike: recorded_spikes) { + auto n = std::snprintf( + linebuf, sizeof(linebuf), "%u %.4f\n", + unsigned{spike.source.gid}, float(spike.time)); + fid.write(linebuf, n); + } + } + } + + // Write the samples to a json file. + if (params.print_all) { + write_trace_json(voltage, arb::rank(context)); + } + + auto report = arb::profile::make_meter_report(meters, context); + std::cout << report; + } + catch (std::exception& e) { + std::cerr << "exception caught in gap junction miniapp:\n" << e.what() << "\n"; + return 1; + } + + return 0; +} + +void write_trace_json(const std::vector<arb::trace_data<double>>& trace, unsigned rank) { + for (unsigned i = 0; i < trace.size(); i++) { + std::string path = "./voltages_" + std::to_string(rank) + + "_" + std::to_string(i) + ".json"; + + nlohmann::json json; + json["name"] = "gj demo: cell " + std::to_string(i); + json["units"] = "mV"; + json["cell"] = std::to_string(i); + json["probe"] = "0"; + + auto &jt = json["data"]["time"]; + auto &jy = json["data"]["voltage"]; + + for (const auto &sample: trace[i]) { + jt.push_back(sample.t); + jy.push_back(sample.v); + } + + std::ofstream file(path); + file << std::setw(1) << json << "\n"; + } +} + +arb::mc_cell gj_cell(double delay, double duration) { + arb::mc_cell cell; + + arb::mechanism_desc nax("nax"); + arb::mechanism_desc kdrmt("kdrmt"); + arb::mechanism_desc kamt("kamt"); + arb::mechanism_desc pas("pas"); + + auto set_reg_params = [&]() { + nax["gbar"] = 0.04; + nax["sh"] = 10; + kdrmt["gbar"] = 0.0001; + kamt["gbar"] = 0.004; + pas["g"] = 1.0/12000.0; + pas["e"] = -65; + }; + + auto setup_seg = [&](auto seg) { + seg->rL = 100; + seg->cm = 0.018; + seg->add_mechanism(nax); + seg->add_mechanism(kdrmt); + seg->add_mechanism(kamt); + seg->add_mechanism(pas); + }; + + auto soma = cell.add_soma(22.360679775/2.0); + set_reg_params(); + setup_seg(soma); + + auto dend = cell.add_cable(0, arb::section_kind::dendrite, 3.0/2.0, 3.0/2.0, 300); //cable 1 + dend->set_compartments(1); + set_reg_params(); + setup_seg(dend); + + cell.add_detector({0,0}, 10); + + cell.add_gap_junction({0, 1}); //ggap in μS + cell.add_gap_junction({1, 1}); //ggap in μS + + arb::i_clamp stim(delay, duration, 0.2); + cell.add_stimulus({1, 0.95}, stim); + + return cell; +} + diff --git a/example/gap_junctions/parameters.hpp b/example/gap_junctions/parameters.hpp new file mode 100644 index 00000000..2c2e7487 --- /dev/null +++ b/example/gap_junctions/parameters.hpp @@ -0,0 +1,61 @@ +#include <iostream> + +#include <array> +#include <cmath> +#include <fstream> +#include <random> + +#include <arbor/mc_cell.hpp> + +#include <sup/json_params.hpp> + +struct gap_params { + gap_params() = default; + + std::string name = "default"; + unsigned cells_per_ring = 100; + unsigned num_rings = 10; + double duration = 100; + double delay = 0.5; + bool print_all = false; +}; + +gap_params read_options(int argc, char** argv) { + using sup::param_from_json; + + gap_params params; + if (argc<2) { + std::cout << "Using default parameters.\n"; + return params; + } + if (argc>2) { + throw std::runtime_error("More than command line one option not permitted."); + } + + std::string fname = argv[1]; + std::cout << "Loading parameters from file: " << fname << "\n"; + std::ifstream f(fname); + + if (!f.good()) { + throw std::runtime_error("Unable to open input parameter file: "+fname); + } + + nlohmann::json json; + json << f; + + param_from_json(params.name, "name", json); + param_from_json(params.cells_per_ring, "num-cells", json); + param_from_json(params.num_rings, "num-rings", json); + param_from_json(params.duration, "duration", json); + param_from_json(params.delay, "delay", json); + param_from_json(params.print_all, "print-all", json); + + if (!json.empty()) { + for (auto it=json.begin(); it!=json.end(); ++it) { + std::cout << " Warning: unused input parameter: \"" << it.key() << "\"\n"; + } + std::cout << "\n"; + } + + return params; +} diff --git a/example/gap_junctions/readme.md b/example/gap_junctions/readme.md new file mode 100644 index 00000000..1d32a192 --- /dev/null +++ b/example/gap_junctions/readme.md @@ -0,0 +1,3 @@ +# Gap Junctions Example + +A miniapp that demonstrates how to describe how to build a network with gap junctions. diff --git a/mechanisms/mod/nax.mod b/mechanisms/mod/nax.mod index 931fe82f..911b8355 100644 --- a/mechanisms/mod/nax.mod +++ b/mechanisms/mod/nax.mod @@ -90,5 +90,5 @@ PROCEDURE trates(vm,sh2,celsius) { FUNCTION trap0(v,th,a,q) { : trap0 = a * (v - th) / (1 - exp(-(v - th)/q)) - trap0 = -a*q*exprelr((v-th)/q) + trap0 = a*q*exprelr(-(v-th)/q) } diff --git a/test/simple_recipes.hpp b/test/simple_recipes.hpp index 92c39208..bcf52854 100644 --- a/test/simple_recipes.hpp +++ b/test/simple_recipes.hpp @@ -122,6 +122,5 @@ protected: std::vector<mc_cell> cells_; }; - } // namespace arb diff --git a/test/ubench/mech_vec.cpp b/test/ubench/mech_vec.cpp index 8c0388b8..0a6572f6 100644 --- a/test/ubench/mech_vec.cpp +++ b/test/ubench/mech_vec.cpp @@ -210,7 +210,7 @@ void expsyn_1_branch_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_expsyn_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_expsyn_1_branch, target_handles, probe_handles); auto& m = find_mechanism("expsyn", cell); @@ -229,7 +229,7 @@ void expsyn_1_branch_state(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_expsyn_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_expsyn_1_branch, target_handles, probe_handles); auto& m = find_mechanism("expsyn", cell); @@ -247,7 +247,7 @@ void pas_1_branch_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_pas_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_pas_1_branch, target_handles, probe_handles); auto& m = find_mechanism("pas", cell); @@ -265,7 +265,7 @@ void pas_3_branches_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_pas_3_branches, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_pas_3_branches, target_handles, probe_handles); auto& m = find_mechanism("pas", cell); @@ -283,7 +283,7 @@ void hh_1_branch_state(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_hh_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_hh_1_branch, target_handles, probe_handles); auto& m = find_mechanism("hh", cell); @@ -301,7 +301,7 @@ void hh_1_branch_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_hh_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_hh_1_branch, target_handles, probe_handles); auto& m = find_mechanism("hh", cell); @@ -319,7 +319,7 @@ void hh_3_branches_state(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_hh_3_branches, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_hh_3_branches, target_handles, probe_handles); auto& m = find_mechanism("hh", cell); @@ -337,7 +337,7 @@ void hh_3_branches_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_hh_3_branches, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_hh_3_branches, target_handles, probe_handles); auto& m = find_mechanism("hh", cell); diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp index e40a13ac..4eda7ccb 100644 --- a/test/unit-distributed/test_communicator.cpp +++ b/test/unit-distributed/test_communicator.cpp @@ -140,6 +140,45 @@ TEST(communicator, gather_spikes_variant) { } } +// Test low level gids_gather function when the number of gids per domain +// are not equal. +TEST(communicator, gather_gids_variant) { + const auto num_domains = g_context->distributed->size(); + const auto rank = g_context->distributed->id(); + + constexpr int scale = 10; + const auto n_local_gids = scale*rank; + auto sumn = [](unsigned n) {return scale*n*(n+1)/2;}; + + std::vector<cell_gid_type > local_gids; + const auto local_start_id = sumn(rank-1); + for (auto i=0; i<n_local_gids; ++i) { + local_gids.push_back(local_start_id + i); + } + + // Perform exchange + const auto global_gids = g_context->distributed->gather_gids(local_gids); + + // Test that partition information is correct + const auto& part =global_gids.partition(); + EXPECT_EQ(unsigned(num_domains+1), part.size()); + EXPECT_EQ(0, (int)part[0]); + for (auto i=1u; i<part.size(); ++i) { + EXPECT_EQ(sumn(i-1), part[i]); + } + + // Test that gids were correctly exchanged + for (auto domain=0; domain<num_domains; ++domain) { + auto source = sumn(domain-1); + const auto first_gid = global_gids.values().begin() + sumn(domain-1); + const auto last_gid = global_gids.values().begin() + sumn(domain); + const auto gids = util::make_range(first_gid, last_gid); + for (auto s: gids) { + EXPECT_EQ(s, source++); + } + } +} + namespace { // Population of cable and rss cells with ring connection topology. // Even gid are rss, and odd gid are cable cells. diff --git a/test/unit-distributed/test_domain_decomposition.cpp b/test/unit-distributed/test_domain_decomposition.cpp index 43f3127f..27e69004 100644 --- a/test/unit-distributed/test_domain_decomposition.cpp +++ b/test/unit-distributed/test_domain_decomposition.cpp @@ -67,6 +67,81 @@ namespace { private: cell_size_type size_; }; + + class gj_symmetric: public recipe { + public: + gj_symmetric(unsigned num_ranks): ncopies_(num_ranks){} + + cell_size_type num_cells() const override { + return size_*ncopies_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + return {}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + unsigned shift = (gid/size_)*size_; + switch (gid % size_) { + case 1 : return { gap_junction_connection({7 + shift, 0}, {gid, 0}, 0.1)}; + case 2 : return { + gap_junction_connection({6 + shift, 0}, {gid, 0}, 0.1), + gap_junction_connection({9 + shift, 0}, {gid, 0}, 0.1) + }; + case 6 : return { + gap_junction_connection({2 + shift, 0}, {gid, 0}, 0.1), + gap_junction_connection({7 + shift, 0}, {gid, 0}, 0.1) + }; + case 7 : return { + gap_junction_connection({6 + shift, 0}, {gid, 0}, 0.1), + gap_junction_connection({1 + shift, 0}, {gid, 0}, 0.1) + }; + case 9 : return { gap_junction_connection({2 + shift, 0}, {gid, 0}, 0.1)}; + default : return {}; + } + } + + private: + cell_size_type size_ = 10; + unsigned ncopies_; + }; + + class gj_non_symmetric: public recipe { + public: + gj_non_symmetric(unsigned num_ranks): groups_(num_ranks), size_(num_ranks){} + + cell_size_type num_cells() const override { + return size_*groups_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + return {}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + unsigned group = gid/groups_; + unsigned id = gid%size_; + if (id == group && group != (groups_ - 1)) { + return {gap_junction_connection({gid + size_, 0}, {gid, 0}, 0.1)}; + } + else if (id == group - 1) { + return {gap_junction_connection({gid - size_, 0}, {gid, 0}, 0.1)}; + } + else { + return {}; + } + } + + private: + unsigned groups_; + cell_size_type size_; + }; } TEST(domain_decomposition, homogeneous_population_mc) { @@ -75,7 +150,7 @@ TEST(domain_decomposition, homogeneous_population_mc) { // This assumption will not hold in the future, requiring and update to // the test. proc_allocation resources{1, -1}; -#ifdef ARB_TEST_MPI +#ifdef TEST_MPI auto ctx = make_context(resources, MPI_COMM_WORLD); #else auto ctx = make_context(resources); @@ -124,7 +199,7 @@ TEST(domain_decomposition, homogeneous_population_gpu) { proc_allocation resources; resources.num_threads = 1; -#ifdef ARB_TEST_MPI +#ifdef TEST_MPI auto ctx = make_context(resources, MPI_COMM_WORLD); #else auto ctx = make_context(resources); @@ -169,7 +244,7 @@ TEST(domain_decomposition, heterogeneous_population) { // This assumption will not hold in the future, requiring and update to // the test. proc_allocation resources{1, -1}; -#ifdef ARB_TEST_MPI +#ifdef TEST_MPI auto ctx = make_context(resources, MPI_COMM_WORLD); #else auto ctx = make_context(resources); @@ -217,3 +292,126 @@ TEST(domain_decomposition, heterogeneous_population) { } } +TEST(domain_decomposition, symmetric_groups) +{ + proc_allocation resources{1, -1}; + int nranks = 1; + int rank = 0; +#ifdef TEST_MPI + auto ctx = make_context(resources, MPI_COMM_WORLD); + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); +#else + auto ctx = make_context(resources); +#endif + auto R = gj_symmetric(nranks); + const auto D0 = partition_load_balance(R, ctx); + EXPECT_EQ(6u, D0.groups.size()); + + unsigned shift = rank*R.num_cells()/nranks; + std::vector<std::vector<cell_gid_type>> expected_groups0 = + { {0 + shift}, + {3 + shift}, + {4 + shift}, + {5 + shift}, + {8 + shift}, + {1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift} + }; + + for (unsigned i = 0; i < 6; i++){ + EXPECT_EQ(expected_groups0[i], D0.groups[i].gids); + } + + unsigned cells_per_rank = R.num_cells()/nranks; + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned)D0.gid_domain(i)); + } + + // Test different group_hints + partition_hint_map hints; + hints[cell_kind::cable1d_neuron].cpu_group_size = R.num_cells(); + hints[cell_kind::cable1d_neuron].prefer_gpu = false; + + const auto D1 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(1u, D1.groups.size()); + + std::vector<cell_gid_type> expected_groups1 = + { 0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift, + 1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift }; + + EXPECT_EQ(expected_groups1, D1.groups[0].gids); + + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned)D1.gid_domain(i)); + } + + hints[cell_kind::cable1d_neuron].cpu_group_size = cells_per_rank/2; + hints[cell_kind::cable1d_neuron].prefer_gpu = false; + + const auto D2 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(2u, D2.groups.size()); + + std::vector<std::vector<cell_gid_type>> expected_groups2 = + { { 0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift }, + { 1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift } }; + + for (unsigned i = 0; i < 2u; i++) { + EXPECT_EQ(expected_groups2[i], D2.groups[i].gids); + } + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned)D2.gid_domain(i)); + } + +} + +TEST(domain_decomposition, non_symmetric_groups) +{ + proc_allocation resources{1, -1}; + int nranks = 1; + int rank = 0; +#ifdef TEST_MPI + auto ctx = make_context(resources, MPI_COMM_WORLD); + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); +#else + auto ctx = make_context(resources); +#endif + /*if (nranks == 1) { + return; + }*/ + + auto R = gj_non_symmetric(nranks); + const auto D = partition_load_balance(R, ctx); + + unsigned cells_per_rank = nranks; + // check groups + unsigned i = 0; + for (unsigned gid = rank*cells_per_rank; gid < (rank + 1)*cells_per_rank; gid++) { + if (gid % nranks == (unsigned)rank - 1) { + continue; + } + else if (gid % nranks == (unsigned)rank && rank != nranks - 1) { + std::vector<cell_gid_type> cg = {gid, gid+cells_per_rank}; + EXPECT_EQ(cg, D.groups[D.groups.size()-1].gids); + } + else { + std::vector<cell_gid_type> cg = {gid}; + EXPECT_EQ(cg, D.groups[i++].gids); + } + } + // check gid_domains + for (unsigned gid = 0; gid < R.num_cells(); gid++) { + auto group = gid/cells_per_rank; + auto idx = gid % cells_per_rank; + unsigned ngroups = nranks; + if (idx == group - 1) { + EXPECT_EQ(group - 1, (unsigned)D.gid_domain(gid)); + } + else if (idx == group && group != ngroups - 1) { + EXPECT_EQ(group, (unsigned)D.gid_domain(gid)); + } + else { + EXPECT_EQ(group, (unsigned)D.gid_domain(gid)); + } + } +} diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 977372b9..1148c002 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -103,6 +103,7 @@ set(unit_sources test_spike_store.cpp test_stats.cpp test_strprintf.cpp + test_supercell.cpp test_swcio.cpp test_synapses.cpp test_thread.cpp diff --git a/test/unit/test_domain_decomposition.cpp b/test/unit/test_domain_decomposition.cpp index a35cef75..d19313b1 100644 --- a/test/unit/test_domain_decomposition.cpp +++ b/test/unit/test_domain_decomposition.cpp @@ -54,6 +54,62 @@ namespace { private: cell_size_type size_; }; + + class gap_recipe: public recipe { + public: + gap_recipe() {} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + mc_cell c; + c.add_soma(20); + c.add_gap_junction({0,1}); + return {std::move(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + switch (gid) { + case 0 : return {gap_junction_connection({13, 0}, {0, 0}, 0.1)}; + case 2 : return { + gap_junction_connection({7, 0}, {2, 0}, 0.1), + gap_junction_connection({11, 0}, {2, 0}, 0.1) + }; + case 3 : return { + gap_junction_connection({4, 0}, {3, 0}, 0.1), + gap_junction_connection({8, 0}, {3, 0}, 0.1) + }; + case 4 : return { + gap_junction_connection({3, 0}, {4, 0}, 0.1), + gap_junction_connection({8, 0}, {4, 0}, 0.1), + gap_junction_connection({9, 0}, {4, 0}, 0.1) + }; + case 7 : return { + gap_junction_connection({2, 0}, {7, 0}, 0.1), + gap_junction_connection({11, 0}, {7, 0}, 0.1) + };; + case 8 : return { + gap_junction_connection({3, 0}, {8, 0}, 0.1), + gap_junction_connection({4, 0}, {8, 0}, 0.1) + };; + case 9 : return {gap_junction_connection({4, 0}, {9, 0}, 0.1)}; + case 11 : return { + gap_junction_connection({2, 0}, {11, 0}, 0.1), + gap_junction_connection({7, 0}, {11, 0}, 0.1) + }; + case 13 : return {gap_junction_connection({0, 0}, {13, 0}, 0.1)}; + default : return {}; + } + } + + private: + cell_size_type size_ = 15; + }; } // test assumes one domain @@ -251,3 +307,49 @@ TEST(domain_decomposition, hints) { EXPECT_EQ(expected_c1d_groups, c1d_groups); EXPECT_EQ(expected_ss_groups, ss_groups); } + +TEST(domain_decomposition, compulsory_groups) +{ + proc_allocation resources; + resources.num_threads = 1; + resources.gpu_id = -1; // disable GPU if available + auto ctx = make_context(resources); + + auto R = gap_recipe(); + const auto D0 = partition_load_balance(R, ctx); + EXPECT_EQ(9u, D0.groups.size()); + + std::vector<std::vector<cell_gid_type>> expected_groups0 = + { {1}, {5}, {6}, {10}, {12}, {14}, {0, 13}, {2, 7, 11}, {3, 4, 8, 9} }; + + for (unsigned i = 0; i < 9u; i++) { + EXPECT_EQ(expected_groups0[i], D0.groups[i].gids); + } + + // Test different group_hints + partition_hint_map hints; + hints[cell_kind::cable1d_neuron].cpu_group_size = 3; + hints[cell_kind::cable1d_neuron].prefer_gpu = false; + + const auto D1 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(5u, D1.groups.size()); + + std::vector<std::vector<cell_gid_type>> expected_groups1 = + { {1, 5, 6}, {10, 12, 14}, {0, 13}, {2, 7, 11}, {3, 4, 8, 9} }; + + for (unsigned i = 0; i < 5u; i++) { + EXPECT_EQ(expected_groups1[i], D1.groups[i].gids); + } + + hints[cell_kind::cable1d_neuron].cpu_group_size = 20; + hints[cell_kind::cable1d_neuron].prefer_gpu = false; + + const auto D2 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(1u, D2.groups.size()); + + std::vector<cell_gid_type> expected_groups2 = + { 1, 5, 6, 10, 12, 14, 0, 13, 2, 7, 11, 3, 4, 8, 9 }; + + EXPECT_EQ(expected_groups2, D2.groups[0].gids); + +} diff --git a/test/unit/test_dry_run_context.cpp b/test/unit/test_dry_run_context.cpp index b4c9a6d0..d0c5c91e 100644 --- a/test/unit/test_dry_run_context.cpp +++ b/test/unit/test_dry_run_context.cpp @@ -97,3 +97,23 @@ TEST(dry_run_context, gather_spikes) EXPECT_EQ(part[3], spikes.size()*3); EXPECT_EQ(part[4], spikes.size()*4); } + +TEST(dry_run_context, gather_gids) +{ + distributed_context_handle ctx = arb::make_dry_run_context(4, 4); + using gvec = std::vector<arb::cell_gid_type>; + + gvec gids = {0, 1, 2, 3}; + gvec gathered_gids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + auto s = ctx->gather_gids(gids); + auto& part = s.partition(); + + EXPECT_EQ(s.values(), gathered_gids); + EXPECT_EQ(part.size(), 5u); + EXPECT_EQ(part[0], 0u); + EXPECT_EQ(part[1], gids.size()); + EXPECT_EQ(part[2], gids.size()*2); + EXPECT_EQ(part[3], gids.size()*3); + EXPECT_EQ(part[4], gids.size()*4); +} diff --git a/test/unit/test_event_delivery.cpp b/test/unit/test_event_delivery.cpp index d521f8bf..905d8cb4 100644 --- a/test/unit/test_event_delivery.cpp +++ b/test/unit/test_event_delivery.cpp @@ -6,11 +6,6 @@ // * Inject events one per cell in a given order, and confirm generated spikes // are in the same order. -// Note: this test anticipates the gap junction PR; the #if guards below -// will be removed when that PR is merged. - -#define NO_GJ_YET - #include <arbor/common_types.hpp> #include <arbor/domain_decomposition.hpp> #include <arbor/simulation.hpp> @@ -36,9 +31,7 @@ struct test_recipe: public n_mc_cell_recipe { c.add_soma(10.)->add_mechanism("pas"); c.add_synapse({0, 0.5}, "expsyn"); c.add_detector({0, 0.5}, -64); -#ifndef NO_GJ_YET c.add_gap_junction({0, 0.5}); -#endif return c; } @@ -100,15 +93,13 @@ TEST(mc_event_delivery, two_interleaved_groups) { EXPECT_EQ((gid_vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), spike_gids); } -#ifndef NO_GJ_YET - typedef std::vector<std::pair<unsigned, unsigned>> cell_gj_pairs; struct test_recipe_gj: public test_recipe { explicit test_recipe_gj(int n, cell_gj_pairs gj_pairs): test_recipe(n), gj_pairs_(std::move(gj_pairs)) {} - cell_size_type num_gap_junctions(cell_gid_type) const override { return 1; } + cell_size_type num_gap_junction_sites(cell_gid_type) const override { return 1; } std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type i) const override { std::vector<gap_junction_connection> gjs; @@ -126,5 +117,3 @@ TEST(mc_event_delivery, gj_reordered) { gid_vector spike_gids = run_test_sim(R, {{0, 1, 2, 3, 4}}); EXPECT_EQ((gid_vector{0, 1, 2, 3, 4}), spike_gids); } - -#endif // ndef NO_GJ_YET diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index 80e73927..84e2e226 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -95,7 +95,7 @@ TEST(fvm_lowered, matrix_init) probe_association_map<probe_handle> probe_map; fvm_cell fvcell(context); - fvcell.initialize({0}, cable1d_recipe(cell), targets, probe_map); + fvcell.initialize({0}, {0}, cable1d_recipe(cell), targets, probe_map); auto& J = fvcell.*private_matrix_ptr; EXPECT_EQ(J.size(), 11u); @@ -140,7 +140,7 @@ TEST(fvm_lowered, target_handles) { probe_association_map<probe_handle> probe_map; fvm_cell fvcell(context); - fvcell.initialize({0, 1}, cable1d_recipe(cells), targets, probe_map); + fvcell.initialize({0, 1}, {0, 0}, cable1d_recipe(cells), targets, probe_map); mechanism* expsyn = find_mechanism(fvcell, "expsyn"); ASSERT_TRUE(expsyn); @@ -204,7 +204,7 @@ TEST(fvm_lowered, stimulus) { probe_association_map<probe_handle> probe_map; fvm_cell fvcell(context); - fvcell.initialize({0}, cable1d_recipe(cells), targets, probe_map); + fvcell.initialize({0}, {0}, cable1d_recipe(cells), targets, probe_map); mechanism* stim = find_mechanism(fvcell, "_builtin_stimulus"); ASSERT_TRUE(stim); @@ -298,7 +298,7 @@ TEST(fvm_lowered, derived_mechs) { execution_context context; fvm_cell fvcell(context); - fvcell.initialize({0, 1, 2}, rec, targets, probe_map); + fvcell.initialize({0, 1, 2}, {0, 0, 0}, rec, targets, probe_map); // Both mechanisms will have the same internal name, "test_kin1". @@ -404,7 +404,7 @@ TEST(fvm_lowered, weighted_write_ion) { probe_association_map<probe_handle> probe_map; fvm_cell fvcell(context); - fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); + fvcell.initialize({0}, {0}, cable1d_recipe(c), targets, probe_map); auto& state = *(fvcell.*private_state_ptr).get(); auto& ion = state.ion_data.at(ionKind::ca); @@ -446,3 +446,247 @@ TEST(fvm_lowered, weighted_write_ion) { EXPECT_EQ(expected_iconc, ion_iconc); } +TEST(fvm_lowered, gj_coords_simple) { + using pair = std::pair<int, int>; + + class gap_recipe: public recipe { + public: + gap_recipe() {} + + cell_size_type num_cells() const override { return n_; } + cell_kind get_cell_kind(cell_gid_type) const override { return cell_kind::cable1d_neuron; } + util::unique_any get_cell_description(cell_gid_type gid) const override { + return {}; + } + std::vector<arb::gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override{ + std::vector<gap_junction_connection> conns; + conns.push_back(gap_junction_connection({(gid+1)%2, 0}, {gid, 0}, 0.5)); + return conns; + } + + protected: + cell_size_type n_ = 2; + }; + + execution_context context; + fvm_cell fvcell(context); + + gap_recipe rec; + std::vector<mc_cell> cells; + mc_cell c, d; + c.add_soma(2.1); + c.add_cable(0, section_kind::dendrite, 0.3, 0.2, 10); + c.segment(1)->set_compartments(5); + c.add_gap_junction({1, 0.8}); + cells.push_back(std::move(c)); + + d.add_soma(2.4); + d.add_cable(0, section_kind::dendrite, 0.3, 0.2, 10); + d.segment(1)->set_compartments(2); + d.add_gap_junction({1, 1}); + cells.push_back(std::move(d)); + + fvm_discretization D = fvm_discretize(cells); + + std::vector<cell_gid_type> gids = {0, 1}; + auto GJ = fvcell.fvm_gap_junctions(cells, gids, rec, D); + + auto weight = [&](fvm_value_type g, fvm_index_type i){ + return g * 1e3 / D.cv_area[i]; + }; + + EXPECT_EQ(pair({4,8}), GJ[0].loc); + EXPECT_EQ(weight(0.5, 4), GJ[0].weight); + + EXPECT_EQ(pair({8,4}), GJ[1].loc); + EXPECT_EQ(weight(0.5, 8), GJ[1].weight); +} + +TEST(fvm_lowered, gj_coords_complex) { + using pair = std::pair<int, int>; + + class gap_recipe: public recipe { + public: + gap_recipe() {} + + cell_size_type num_cells() const override { return n_; } + cell_kind get_cell_kind(cell_gid_type) const override { return cell_kind::cable1d_neuron; } + util::unique_any get_cell_description(cell_gid_type gid) const override { + return {}; + } + std::vector<arb::gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override{ + std::vector<gap_junction_connection> conns; + switch (gid) { + case 0 : return { + gap_junction_connection({2, 0}, {0, 1}, 0.01), + gap_junction_connection({1, 0}, {0, 0}, 0.03), + gap_junction_connection({1, 1}, {0, 0}, 0.04) + }; + case 1 : return { + gap_junction_connection({0, 0}, {1, 0}, 0.03), + gap_junction_connection({0, 0}, {1, 1}, 0.04), + gap_junction_connection({2, 1}, {1, 2}, 0.02), + gap_junction_connection({2, 2}, {1, 3}, 0.01) + }; + case 2 : return { + gap_junction_connection({0, 1}, {2, 0}, 0.01), + gap_junction_connection({1, 2}, {2, 1}, 0.02), + gap_junction_connection({1, 3}, {2, 2}, 0.01) + }; + default : return {}; + } + return conns; + } + + protected: + cell_size_type n_ = 3; + }; + + execution_context context; + fvm_cell fvcell(context); + + gap_recipe rec; + mc_cell c0, c1, c2; + std::vector<mc_cell> cells; + + // Make 3 cells + c0.add_soma(2.1); + c0.add_cable(0, section_kind::dendrite, 0.3, 0.2, 8); + c0.segment(1)->set_compartments(4); + + c1.add_soma(1.4); + c1.add_cable(0, section_kind::dendrite, 0.3, 0.5, 12); + c1.segment(1)->set_compartments(6); + c1.add_cable(1, section_kind::dendrite, 0.3, 0.2, 9); + c1.segment(2)->set_compartments(3); + c1.add_cable(1, section_kind::dendrite, 0.2, 0.2, 15); + c1.segment(3)->set_compartments(5); + + c2.add_soma(2.9); + c2.add_cable(0, section_kind::dendrite, 0.3, 0.5, 4); + c2.segment(1)->set_compartments(2); + c2.add_cable(1, section_kind::dendrite, 0.4, 0.2, 6); + c2.segment(2)->set_compartments(2); + c2.add_cable(1, section_kind::dendrite, 0.1, 0.2, 8); + c2.segment(3)->set_compartments(2); + c2.add_cable(2, section_kind::dendrite, 0.2, 0.2, 4); + c2.segment(4)->set_compartments(2); + c2.add_cable(2, section_kind::dendrite, 0.2, 0.2, 4); + c2.segment(5)->set_compartments(2); + + // Add 5 gap junctions + c0.add_gap_junction({1, 1}); + c0.add_gap_junction({1, 0.5}); + + c1.add_gap_junction({2, 1}); + c1.add_gap_junction({1, 1}); + c1.add_gap_junction({1, 0.45}); + c1.add_gap_junction({1, 0.1}); + + c2.add_gap_junction({1, 0.5}); + c2.add_gap_junction({4, 1}); + c2.add_gap_junction({2, 1}); + + cells.push_back(std::move(c0)); + cells.push_back(std::move(c1)); + cells.push_back(std::move(c2)); + + fvm_discretization D = fvm_discretize(cells); + std::vector<cell_gid_type> gids = {0, 1, 2}; + + auto GJ = fvcell.fvm_gap_junctions(cells, gids, rec, D); + EXPECT_EQ(10u, GJ.size()); + + auto weight = [&](fvm_value_type g, fvm_index_type i){ + return g * 1e3 / D.cv_area[i]; + }; + + std::vector<pair> expected_loc = {{4, 14}, {4,11}, {2,21}, {14, 4}, {11,4} ,{8,28}, {6, 24}, {21,2}, {28,8}, {24, 6}}; + std::vector<double> expected_weight = { + weight(0.03, 4), weight(0.04, 4), weight(0.01, 2), weight(0.03, 14), weight(0.04, 11), + weight(0.02, 8), weight(0.01, 6), weight(0.01, 21), weight(0.02, 28), weight(0.01, 24) + }; + + for (unsigned i = 0; i < GJ.size(); i++) { + bool found = false; + for (unsigned j = 0; j < expected_loc.size(); j++) { + if (expected_loc[j].first == GJ[i].loc.first && expected_loc[j].second == GJ[i].loc.second) { + found = true; + EXPECT_EQ(expected_weight[j], GJ[i].weight); + break; + } + } + EXPECT_TRUE(found); + } + std::cout << std::endl; +} + +TEST(fvm_lowered, cell_group_gj) { + using pair = std::pair<int, int>; + + class gap_recipe: public recipe { + public: + gap_recipe() {} + + cell_size_type num_cells() const override { return n_; } + cell_kind get_cell_kind(cell_gid_type) const override { return cell_kind::cable1d_neuron; } + util::unique_any get_cell_description(cell_gid_type gid) const override { + return {}; + } + std::vector<arb::gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override{ + std::vector<gap_junction_connection> conns; + if (gid % 2 == 0) { + // connect 5 of the first 10 cells in a ring; connect 5 of the second 10 cells in a ring + auto next_cell = gid == 8 ? 0 : (gid == 18 ? 10 : gid + 2); + auto prev_cell = gid == 0 ? 8 : (gid == 10 ? 18 : gid - 2); + conns.push_back(gap_junction_connection({next_cell, 0}, {gid, 0}, 0.03)); + conns.push_back(gap_junction_connection({prev_cell, 0}, {gid, 0}, 0.03)); + } + return conns; + } + + protected: + cell_size_type n_ = 20; + }; + execution_context context; + fvm_cell fvcell(context); + + gap_recipe rec; + std::vector<mc_cell> cell_group0; + std::vector<mc_cell> cell_group1; + + // Make 20 cells + for (unsigned i = 0; i < 20; i++) { + mc_cell c; + c.add_soma(2.1); + if (i % 2 == 0) { + c.add_gap_junction({0, 1}); + if (i < 10) { + cell_group0.push_back(std::move(c)); + } + else { + cell_group1.push_back(std::move(c)); + } + } + + } + + std::vector<cell_gid_type> gids_cg0 = { 0, 2, 4, 6, 8}; + std::vector<cell_gid_type> gids_cg1 = {10,12,14,16,18}; + + fvm_discretization D0 = fvm_discretize(cell_group0); + fvm_discretization D1 = fvm_discretize(cell_group1); + + auto GJ0 = fvcell.fvm_gap_junctions(cell_group0, gids_cg0, rec, D0); + auto GJ1 = fvcell.fvm_gap_junctions(cell_group1, gids_cg1, rec, D1); + + EXPECT_EQ(10u, GJ0.size()); + EXPECT_EQ(10u, GJ1.size()); + + std::vector<pair> expected_loc = {{0, 1}, {0, 4}, {1, 2}, {1, 0}, {2, 3} ,{2, 1}, {3, 4}, {3, 2}, {4, 0}, {4, 3}}; + + for (unsigned i = 0; i < GJ0.size(); i++) { + EXPECT_EQ(expected_loc[i], GJ0[i].loc); + EXPECT_EQ(expected_loc[i], GJ1[i].loc); + } +} diff --git a/test/unit/test_local_context.cpp b/test/unit/test_local_context.cpp index 0f42e206..3bda37fc 100644 --- a/test/unit/test_local_context.cpp +++ b/test/unit/test_local_context.cpp @@ -77,3 +77,19 @@ TEST(local_context, gather_spikes) EXPECT_EQ(part[0], 0u); EXPECT_EQ(part[1], spikes.size()); } + +TEST(local_context, gather_gids) +{ + arb::local_context ctx; + using gvec = std::vector<arb::cell_gid_type>; + + gvec gids = {0, 1, 2, 3, 4}; + + auto s = ctx.gather_gids(gids); + + auto& part = s.partition(); + EXPECT_EQ(s.values(), gids); + EXPECT_EQ(part.size(), 2u); + EXPECT_EQ(part[0], 0u); + EXPECT_EQ(part[1], gids.size()); +} diff --git a/test/unit/test_mc_cell_group.cpp b/test/unit/test_mc_cell_group.cpp index 66326c22..70077f32 100644 --- a/test/unit/test_mc_cell_group.cpp +++ b/test/unit/test_mc_cell_group.cpp @@ -28,6 +28,128 @@ namespace { return c; } + + class gap_recipe_0: public recipe { + public: + gap_recipe_0() {} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + mc_cell c; + c.add_soma(20); + c.add_gap_junction({0, 1}); + return {std::move(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + switch (gid) { + case 0 : + return {gap_junction_connection({5, 0}, {0, 0}, 0.1)}; + case 2 : + return { + gap_junction_connection({3, 0}, {2, 0}, 0.1), + }; + case 3 : + return { + gap_junction_connection({7, 0}, {3, 0}, 0.1), + gap_junction_connection({3, 0}, {2, 0}, 0.1) + }; + case 5 : + return {gap_junction_connection({5, 0}, {0, 0}, 0.1)}; + case 7 : + return { + gap_junction_connection({3, 0}, {7, 0}, 0.1), + }; + default : + return {}; + } + } + + private: + cell_size_type size_ = 12; + }; + + class gap_recipe_1: public recipe { + public: + gap_recipe_1() {} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + mc_cell c; + c.add_soma(20); + return {std::move(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + + private: + cell_size_type size_ = 12; + }; + + class gap_recipe_2: public recipe { + public: + gap_recipe_2() {} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + mc_cell c; + c.add_soma(20); + c.add_gap_junction({0,1}); + return {std::move(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + switch (gid) { + case 0 : + return { + gap_junction_connection({2, 0}, {0, 0}, 0.1), + gap_junction_connection({3, 0}, {0, 0}, 0.1), + gap_junction_connection({5, 0}, {0, 0}, 0.1) + }; + case 2 : + return { + gap_junction_connection({0, 0}, {2, 0}, 0.1), + gap_junction_connection({3, 0}, {2, 0}, 0.1), + gap_junction_connection({5, 0}, {2, 0}, 0.1) + }; + case 3 : + return { + gap_junction_connection({0, 0}, {3, 0}, 0.1), + gap_junction_connection({2, 0}, {3, 0}, 0.1), + gap_junction_connection({5, 0}, {3, 0}, 0.1) + }; + case 5 : + return { + gap_junction_connection({2, 0}, {5, 0}, 0.1), + gap_junction_connection({3, 0}, {5, 0}, 0.1), + gap_junction_connection({0, 0}, {5, 0}, 0.1) + }; + default : + return {}; + } + } + + private: + cell_size_type size_ = 12; + }; + } ACCESS_BIND( @@ -84,3 +206,31 @@ TEST(mc_cell_group, sources) { } } } + +TEST(mc_cell_group, generated_gids_deps_) { + { + std::vector<cell_gid_type> gids = {11u, 5u, 2u, 3u, 0u, 8u, 7u}; + mc_cell_group group{gids, gap_recipe_0(), lowered_cell()}; + + std::vector<cell_gid_type> expected_gids = {11u, 5u, 0u, 2u, 3u, 7u, 8u}; + std::vector<int> expected_deps = {0, 2, 0, 3, 0, 0, 0}; + EXPECT_EQ(expected_gids, group.get_gids()); + EXPECT_EQ(expected_deps, group.get_dependencies()); + } + { + std::vector<cell_gid_type> gids = {11u, 5u, 2u, 3u, 0u, 8u, 7u}; + mc_cell_group group{gids, gap_recipe_1(), lowered_cell()}; + + std::vector<int> expected_deps = {0, 0, 0, 0, 0, 0, 0}; + EXPECT_EQ(gids, group.get_gids()); + EXPECT_EQ(expected_deps, group.get_dependencies()); + } + { + std::vector<cell_gid_type> gids = {5u, 2u, 3u, 0u}; + mc_cell_group group{gids, gap_recipe_2(), lowered_cell()}; + + std::vector<int> expected_deps = {4, 0, 0, 0}; + EXPECT_EQ(gids, group.get_gids()); + EXPECT_EQ(expected_deps, group.get_dependencies()); + } +} diff --git a/test/unit/test_mech_temperature.cpp b/test/unit/test_mech_temperature.cpp index 8d152d99..81a11394 100644 --- a/test/unit/test_mech_temperature.cpp +++ b/test/unit/test_mech_temperature.cpp @@ -24,9 +24,11 @@ void run_celsius_test() { fvm_size_type ncv = 3; std::vector<fvm_index_type> cv_to_cell(ncv, 0); + std::vector<fvm_gap_junction> gj = {}; + std::vector<int> deps = {0}; auto celsius_test = cat.instance<backend>("celsius_test"); auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_cell, celsius_test->data_alignment()); + ncell, cv_to_cell, deps, gj, celsius_test->data_alignment()); mechanism::layout layout; layout.weight.assign(ncv, 1.); diff --git a/test/unit/test_probe.cpp b/test/unit/test_probe.cpp index 42326f35..2e52784a 100644 --- a/test/unit/test_probe.cpp +++ b/test/unit/test_probe.cpp @@ -40,7 +40,7 @@ TEST(probe, fvm_lowered_cell) { probe_association_map<probe_handle> probe_map; fvm_cell lcell(context); - lcell.initialize({0}, rec, targets, probe_map); + lcell.initialize({0}, {0}, rec, targets, probe_map); EXPECT_EQ(3u, rec.num_probes(0)); EXPECT_EQ(3u, probe_map.size()); diff --git a/test/unit/test_supercell.cpp b/test/unit/test_supercell.cpp new file mode 100644 index 00000000..f4bebb2c --- /dev/null +++ b/test/unit/test_supercell.cpp @@ -0,0 +1,71 @@ +#include "../gtest.h" + +#include <cmath> +#include <tuple> +#include <vector> + +#include <arbor/constants.hpp> +#include <arbor/mechcat.hpp> +#include <arbor/util/optional.hpp> +#include <arbor/mc_cell.hpp> + +#include "backends/multicore/fvm.hpp" +#include "backends/multicore/mechanism.hpp" +#include "util/maputil.hpp" +#include "util/range.hpp" + +#include "common.hpp" +#include "mech_private_field_access.hpp" + +using namespace arb; + +using backend = ::arb::multicore::backend; +using shared_state = backend::shared_state; +using value_type = backend::value_type; +using size_type = backend::size_type; + +// Access to more mechanism protected data: + +ACCESS_BIND(const value_type* multicore::mechanism::*, vec_v_ptr, &multicore::mechanism::vec_v_) +ACCESS_BIND(value_type* multicore::mechanism::*, vec_i_ptr, &multicore::mechanism::vec_i_) + +TEST(supercell, sync_time_to) { + using value_type = multicore::backend::value_type; + using index_type = multicore::backend::index_type; + + int num_cell = 10; + + std::vector<fvm_gap_junction> gj = {}; + std::vector<index_type> deps = {4, 0, 0, 0, 3, 0, 0, 2, 0, 0}; + + shared_state state(num_cell, std::vector<index_type>(num_cell, 0), deps, gj, 1u); + + state.time_to = {0.3, 0.1, 0.2, 0.4, 0.5, 0.6, 0.1, 0.1, 0.6, 0.9}; + state.sync_time_to(); + + std::vector<value_type> expected = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9}; + + for (unsigned i = 0; i < state.time_to.size(); i++) { + EXPECT_EQ(expected[i], state.time_to[i]); + } + + state.time_dep = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + state.time_to = {0.3, 0.1, 0.2, 0.4, 0.5, 0.6, 0.1, 0.1, 0.6, 0.9}; + expected = {0.3, 0.1, 0.2, 0.4, 0.5, 0.6, 0.1, 0.1, 0.6, 0.9}; + state.sync_time_to(); + + for (unsigned i = 0; i < state.time_to.size(); i++) { + EXPECT_EQ(expected[i], state.time_to[i]); + } + + state.time_dep = {10, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + state.time_to = {0.3, 0.1, 0.2, 0.4, 0.5, 0.6, 0.1, 0.1, 0.6, 0.9}; + expected = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}; + state.sync_time_to(); + + for (unsigned i = 0; i < state.time_to.size(); i++) { + EXPECT_EQ(expected[i], state.time_to[i]); + } + +} + diff --git a/test/unit/test_synapses.cpp b/test/unit/test_synapses.cpp index 1428455d..5210a7d4 100644 --- a/test/unit/test_synapses.cpp +++ b/test/unit/test_synapses.cpp @@ -85,8 +85,9 @@ TEST(synapses, syn_basic_state) { auto exp2syn = unique_cast<multicore::mechanism>(global_default_catalogue().instance<backend>("exp2syn")); ASSERT_TRUE(exp2syn); + std::vector<fvm_gap_junction> gj = {}; auto align = std::max(expsyn->data_alignment(), exp2syn->data_alignment()); - shared_state state(num_cell, std::vector<index_type>(num_comp, 0), align); + shared_state state(num_cell, std::vector<index_type>(num_comp, 0), std::vector<index_type>(num_cell, 0), gj, align); state.reset(-65., constant::hh_squid_temp); fill(state.current_density, 1.0); -- GitLab