diff --git a/arbor/arbexcept.cpp b/arbor/arbexcept.cpp index a8b589b7b0c3490560f599b3df951f819e1c257a..8ea9e1dd141b3abf8b0658fd100708942a2735e4 100644 --- a/arbor/arbexcept.cpp +++ b/arbor/arbexcept.cpp @@ -26,6 +26,13 @@ bad_probe_id::bad_probe_id(cell_member_type probe_id): probe_id(probe_id) {} + +gj_kind_mismatch::gj_kind_mismatch(cell_gid_type gid_0, cell_gid_type gid_1): + arbor_exception(pprintf("Cells on gid {} and {} connected via gap junction have different cell kinds", gid_0, gid_1)), + gid_0(gid_0), + gid_1(gid_1) +{} + bad_event_time::bad_event_time(time_type event_time, time_type sim_time): arbor_exception(pprintf("event time {} precedes current simulation time {}", event_time, sim_time)), event_time(event_time), diff --git a/arbor/backends/gpu/gpu_store_types.hpp b/arbor/backends/gpu/gpu_store_types.hpp index ac215ed4370e486cec9cca82617f483d047fd081..708233a964634458b7a4da9d2942fa2ae721e48c 100644 --- a/arbor/backends/gpu/gpu_store_types.hpp +++ b/arbor/backends/gpu/gpu_store_types.hpp @@ -17,6 +17,7 @@ namespace gpu { using array = memory::device_vector<fvm_value_type>; using iarray = memory::device_vector<fvm_index_type>; +using gjarray = memory::device_vector<fvm_gap_junction>; using deliverable_event_stream = arb::gpu::multi_event_stream<deliverable_event>; using sample_event_stream = arb::gpu::multi_event_stream<sample_event>; diff --git a/arbor/backends/gpu/shared_state.cpp b/arbor/backends/gpu/shared_state.cpp index 5cd8972c45ab44913977612d2ae7710e846c1300..5760b02cf48db23046e9d94d798545301f0a0b35 100644 --- a/arbor/backends/gpu/shared_state.cpp +++ b/arbor/backends/gpu/shared_state.cpp @@ -35,10 +35,15 @@ void update_time_to_impl( std::size_t n, fvm_value_type* time_to, const fvm_value_type* time, fvm_value_type dt, fvm_value_type tmax); +void sync_time_to_impl(std::size_t n, fvm_value_type* time_to, const fvm_index_type* time_deps); + void set_dt_impl( fvm_size_type ncell, fvm_size_type ncomp, fvm_value_type* dt_cell, fvm_value_type* dt_comp, const fvm_value_type* time_to, const fvm_value_type* time, const fvm_index_type* cv_to_cell); +void add_gj_current_impl( + fvm_size_type n_gj, const fvm_gap_junction* gj, const fvm_value_type* v, fvm_value_type* i); + void take_samples_impl( const multi_event_stream_state<raw_probe_info>& s, const fvm_value_type* time, fvm_value_type* sample_time, fvm_value_type* sample_value); @@ -109,11 +114,16 @@ void ion_state::zero_current() { shared_state::shared_state( fvm_size_type n_cell, const std::vector<fvm_index_type>& cv_to_cell_vec, + const std::vector<fvm_index_type>& time_dep_vec, + const std::vector<fvm_gap_junction>& gj_vec, unsigned // alignment parameter ignored. ): n_cell(n_cell), n_cv(cv_to_cell_vec.size()), + n_gj(gj_vec.size()), cv_to_cell(make_const_view(cv_to_cell_vec)), + time_dep(make_const_view(time_dep_vec)), + gap_junctions(make_const_view(gj_vec)), time(n_cell), time_to(n_cell), dt_cell(n_cell), @@ -170,10 +180,18 @@ void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) { update_time_to_impl(n_cell, time_to.data(), time.data(), dt_step, tmax); } +void shared_state::sync_time_to() { + sync_time_to_impl(n_cell, time_to.data(), time_dep.data()); +} + void shared_state::set_dt() { set_dt_impl(n_cell, n_cv, dt_cell.data(), dt_cv.data(), time_to.data(), time.data(), cv_to_cell.data()); } +void shared_state::add_gj_current() { + add_gj_current_impl(n_gj, gap_junctions.data(), voltage.data(), current_density.data()); +} + std::pair<fvm_value_type, fvm_value_type> shared_state::time_bounds() const { return minmax_value_impl(n_cell, time.data()); } diff --git a/arbor/backends/gpu/shared_state.cu b/arbor/backends/gpu/shared_state.cu index 995eeb7678f0684d7b3bac23626ca5fb632ecc54..806418b99b05e01add18a5f04d96ca64740ee5ee 100644 --- a/arbor/backends/gpu/shared_state.cu +++ b/arbor/backends/gpu/shared_state.cu @@ -5,6 +5,7 @@ #include <backends/event.hpp> #include <backends/multi_event_stream_state.hpp> +#include "cuda_atomic.hpp" #include "cuda_common.hpp" namespace arb { @@ -33,6 +34,24 @@ void init_concentration_impl(unsigned n, T* Xi, T* Xo, const T* weight_Xi, const } } +template <typename T, typename I> +__global__ void sync_time_to_impl(unsigned n, T* time_to, const I* time_deps) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + if (i<n) { + if (time_deps[i] > 0) { + auto min_t = time_to[i]; + for (int j = 1; j < time_deps[i]; j++) { + if (time_to[i+j] < min_t) { + min_t = time_to[i+j]; + } + } + for (int j = 0; j < time_deps[i]; j++) { + time_to[i+j] = min_t; + } + } + } +} + template <typename T> __global__ void update_time_to_impl(unsigned n, T* time_to, const T* time, T dt, T tmax) { unsigned i = threadIdx.x+blockIdx.x*blockDim.x; @@ -42,6 +61,17 @@ __global__ void update_time_to_impl(unsigned n, T* time_to, const T* time, T dt, } } +template <typename T, typename I> +__global__ void add_gj_current_impl(unsigned n, const T* gj_info, const I* voltage, I* current_density) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + if (i<n) { + auto gj = gj_info[i]; + auto curr = gj.weight * (voltage[gj.loc.second] - voltage[gj.loc.first]); // nA + + cuda_atomic_sub(current_density + gj.loc.first, curr); + } +} + // Vector minus: x = y - z template <typename T> __global__ void vec_minus(unsigned n, T* x, const T* y, const T* z) { @@ -101,6 +131,15 @@ void init_concentration_impl( kernel::init_concentration_impl<<<nblock, block_dim>>>(n, Xi, Xo, weight_Xi, weight_Xo, c_int, c_ext); } +void sync_time_to_impl(std::size_t n, fvm_value_type* time_to, const fvm_index_type* time_deps) +{ + if (!n) return; + + constexpr int block_dim = 128; + int nblock = block_count(n, block_dim); + kernel::sync_time_to_impl<<<nblock, block_dim>>>(n, time_to, time_deps); +} + void update_time_to_impl( std::size_t n, fvm_value_type* time_to, const fvm_value_type* time, fvm_value_type dt, fvm_value_type tmax) @@ -126,6 +165,16 @@ void set_dt_impl( kernel::gather<<<nblock, block_dim>>>(ncomp, dt_comp, dt_cell, cv_to_cell); } +void add_gj_current_impl( + fvm_size_type n_gj, const fvm_gap_junction* gj_info, const fvm_value_type* voltage, fvm_value_type* current_density) +{ + if (!n_gj) return; + + constexpr int block_dim = 128; + int nblock = block_count(n_gj, block_dim); + kernel::add_gj_current_impl<<<nblock, block_dim>>>(n_gj, gj_info, voltage, current_density); +} + void take_samples_impl( const multi_event_stream_state<raw_probe_info>& s, const fvm_value_type* time, fvm_value_type* sample_time, fvm_value_type* sample_value) diff --git a/arbor/backends/gpu/shared_state.hpp b/arbor/backends/gpu/shared_state.hpp index 81174715d60c62e074da9db07cf4a507e2e07245..c66ac3f16508a6163ff4b6daab82e8b4fa1b3458 100644 --- a/arbor/backends/gpu/shared_state.hpp +++ b/arbor/backends/gpu/shared_state.hpp @@ -67,8 +67,11 @@ struct ion_state { struct shared_state { fvm_size_type n_cell = 0; // Number of distinct cells (integration domains). fvm_size_type n_cv = 0; // Total number of CVs. + fvm_size_type n_gj = 0; // Total number of GJs. iarray cv_to_cell; // Maps CV index to cell index. + iarray time_dep; // Provides information about supercells + gjarray gap_junctions; // Stores gap_junction info. array time; // Maps cell index to integration start time [ms]. array time_to; // Maps cell index to integration stop time [ms]. array dt_cell; // Maps cell index to (stop time) - (start time) [ms]. @@ -86,6 +89,8 @@ struct shared_state { shared_state( fvm_size_type n_cell, const std::vector<fvm_index_type>& cv_to_cell_vec, + const std::vector<fvm_index_type>& time_dep_vec, + const std::vector<fvm_gap_junction>& gj_vec, unsigned align ); @@ -104,9 +109,15 @@ struct shared_state { // Set time_to to earliest of time+dt_step and tmax. void update_time_to(fvm_value_type dt_step, fvm_value_type tmax); + // Synchrnize the time_to for supercells. + void sync_time_to(); + // Set the per-cell and per-compartment dt from time_to - time. void set_dt(); + // Update gap_junction state + void add_gj_current(); + // Return minimum and maximum time value [ms] across cells. std::pair<fvm_value_type, fvm_value_type> time_bounds() const; diff --git a/arbor/backends/multicore/multicore_common.hpp b/arbor/backends/multicore/multicore_common.hpp index b0b3ed4ec0b01aacfd8cb6113f0f3d60a57e25c1..4a3ee9bc018318a3d5ea990c66a3e1457ca69700 100644 --- a/arbor/backends/multicore/multicore_common.hpp +++ b/arbor/backends/multicore/multicore_common.hpp @@ -23,6 +23,7 @@ using padded_vector = std::vector<V, util::padded_allocator<V>>; using array = padded_vector<fvm_value_type>; using iarray = padded_vector<fvm_index_type>; +using gjarray = padded_vector<fvm_gap_junction>; using deliverable_event_stream = arb::multicore::multi_event_stream<deliverable_event>; using sample_event_stream = arb::multicore::multi_event_stream<sample_event>; diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index 38120f4a15fecb7725943d5a286b304669c954c9..3bb73441a964f0be89c46bda2d0e767d9194da93 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -117,13 +117,18 @@ void ion_state::zero_current() { shared_state::shared_state( fvm_size_type n_cell, const std::vector<fvm_index_type>& cv_to_cell_vec, + const std::vector<fvm_index_type>& time_dep_vec, + const std::vector<fvm_gap_junction>& gj_vec, unsigned align ): alignment(min_alignment(align)), alloc(alignment), n_cell(n_cell), n_cv(cv_to_cell_vec.size()), + n_gj(gj_vec.size()), cv_to_cell(math::round_up(n_cv, alignment), pad(alignment)), + time_dep(n_cell), + gap_junctions(math::round_up(n_gj, alignment), pad(alignment)), time(n_cell, pad(alignment)), time_to(n_cell, pad(alignment)), dt_cell(n_cell, pad(alignment)), @@ -133,11 +138,16 @@ shared_state::shared_state( temperature_degC(NAN), deliverable_events(n_cell) { - // For indices in the padded tail of cv_to_cell, set index to last valid cell index. + std::copy(time_dep_vec.begin(), time_dep_vec.end(), time_dep.begin()); + // For indices in the padded tail of cv_to_cell, set index to last valid cell index. if (n_cv>0) { std::copy(cv_to_cell_vec.begin(), cv_to_cell_vec.end(), cv_to_cell.begin()); - std::fill(cv_to_cell.begin()+n_cv, cv_to_cell.end(), cv_to_cell_vec.back()); + std::fill(cv_to_cell.begin() + n_cv, cv_to_cell.end(), cv_to_cell_vec.back()); + } + if (n_gj>0) { + std::copy(gj_vec.begin(), gj_vec.end(), gap_junctions.begin()); + std::fill(gap_junctions.begin()+n_gj, gap_junctions.end(), gj_vec.back()); } } @@ -182,6 +192,21 @@ void shared_state::ions_nernst_reversal_potential(fvm_value_type temperature_K) i.second.nernst(temperature_K); } } +void shared_state::sync_time_to() { + for (fvm_size_type i = 0; i<n_cell; i++) { + if (!time_dep[i]) continue; + + fvm_value_type min_t = time_to[i]; + for (int j = 1; j < time_dep[i]; j++) { + if (time_to[i+j] < min_t) { + min_t = time_to[i+j]; + } + } + for (int j = 0; j < time_dep[i]; j++) { + time_to[i+j] = min_t; + } + } +} void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) { for (fvm_size_type i = 0; i<n_cell; i+=simd_width) { @@ -208,6 +233,16 @@ void shared_state::set_dt() { } } +void shared_state::add_gj_current() { + for (unsigned i = 0; i < n_gj; i++) { + auto gj = gap_junctions[i]; + auto curr = gj.weight * + (voltage[gj.loc.second] - voltage[gj.loc.first]); // nA + + current_density[gj.loc.first] -= curr; + } +} + std::pair<fvm_value_type, fvm_value_type> shared_state::time_bounds() const { return util::minmax_value(time); } diff --git a/arbor/backends/multicore/shared_state.hpp b/arbor/backends/multicore/shared_state.hpp index 25247f1435d377eb8a01a107738ca49d656f0907..40be9eb297e6bafe3b7e72de31c66cb172a814a7 100644 --- a/arbor/backends/multicore/shared_state.hpp +++ b/arbor/backends/multicore/shared_state.hpp @@ -86,8 +86,11 @@ struct shared_state { fvm_size_type n_cell = 0; // Number of distinct cells (integration domains). fvm_size_type n_cv = 0; // Total number of CVs. + fvm_size_type n_gj = 0; // Total number of GJs. iarray cv_to_cell; // Maps CV index to cell index. + iarray time_dep; // Provides information about supercells + gjarray gap_junctions; // Stores gap_junction info. array time; // Maps cell index to integration start time [ms]. array time_to; // Maps cell index to integration stop time [ms]. array dt_cell; // Maps cell index to (stop time) - (start time) [ms]. @@ -105,6 +108,8 @@ struct shared_state { shared_state( fvm_size_type n_cell, const std::vector<fvm_index_type>& cv_to_cell_vec, + const std::vector<fvm_index_type>& time_dep_vec, + const std::vector<fvm_gap_junction>& gj_vec, unsigned align ); @@ -123,9 +128,15 @@ struct shared_state { // Set time_to to earliest of time+dt_step and tmax. void update_time_to(fvm_value_type dt_step, fvm_value_type tmax); + // Synchrnize the time_to for supercells. + void sync_time_to(); + // Set the per-cell and per-compartment dt from time_to - time. void set_dt(); + // Update gap_junction state + void add_gj_current(); + // Return minimum and maximum time value [ms] across cells. std::pair<fvm_value_type, fvm_value_type> time_bounds() const; diff --git a/arbor/communication/dry_run_context.cpp b/arbor/communication/dry_run_context.cpp index 15714475c8da66b5918573ec0c17a3fe978bf7d4..cfa6ddf0f50481b5082f8b9bdaa5a11f488a9f46 100644 --- a/arbor/communication/dry_run_context.cpp +++ b/arbor/communication/dry_run_context.cpp @@ -34,13 +34,40 @@ struct dry_run_context_impl { } std::vector<count_type> partition; - for(count_type i = 0; i <= num_ranks_; i++) { + for (count_type i = 0; i <= num_ranks_; i++) { partition.push_back(static_cast<count_type>(i*local_size)); } return gathered_vector<arb::spike>(std::move(gathered_spikes), std::move(partition)); } + gathered_vector<cell_gid_type> + gather_gids(const std::vector<cell_gid_type>& local_gids) const { + using count_type = typename gathered_vector<cell_gid_type>::count_type; + + count_type local_size = local_gids.size(); + + std::vector<cell_gid_type> gathered_gids; + gathered_gids.reserve(local_size*num_ranks_); + + for (count_type i = 0; i < num_ranks_; i++) { + gathered_gids.insert(gathered_gids.end(), local_gids.begin(), local_gids.end()); + } + + for (count_type i = 0; i < num_ranks_; i++) { + for (count_type j = i*local_size; j < (i+1)*local_size; j++){ + gathered_gids[j] += num_cells_per_tile_*i; + } + } + + std::vector<count_type> partition; + for (count_type i = 0; i <= num_ranks_; i++) { + partition.push_back(static_cast<count_type>(i*local_size)); + } + + return gathered_vector<cell_gid_type>(std::move(gathered_gids), std::move(partition)); + } + int id() const { return 0; } int size() const { return num_ranks_; } diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index be80dca50966bddde9683d4b05d0b8c4cd1d54c7..9d0bc30b64f6a22614db6c10662a4c2f7c63e219 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -33,6 +33,11 @@ struct mpi_context_impl { return mpi::gather_all_with_partition(local_spikes, comm_); } + gathered_vector<cell_gid_type> + gather_gids(const std::vector<cell_gid_type>& local_gids) const { + return mpi::gather_all_with_partition(local_gids, comm_); + } + std::string name() const { return "MPI"; } int id() const { return rank_; } int size() const { return size_; } diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index ead2688732ae367d9aaab2f6ad3fbcdb6c5f17f9..1037e3f2098be2af748e93b949c6a99641fe749f 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -42,6 +42,7 @@ namespace arb { class distributed_context { public: using spike_vector = std::vector<arb::spike>; + using gid_vector = std::vector<cell_gid_type>; // default constructor uses a local context: see below. distributed_context(); @@ -58,6 +59,10 @@ public: return impl_->gather_spikes(local_spikes); } + gathered_vector<cell_gid_type> gather_gids(const gid_vector& local_gids) const { + return impl_->gather_gids(local_gids); + } + int id() const { return impl_->id(); } @@ -84,6 +89,8 @@ private: struct interface { virtual gathered_vector<arb::spike> gather_spikes(const spike_vector& local_spikes) const = 0; + virtual gathered_vector<cell_gid_type> + gather_gids(const gid_vector& local_gids) const = 0; virtual int id() const = 0; virtual int size() const = 0; virtual void barrier() const = 0; @@ -104,6 +111,10 @@ private: gather_spikes(const spike_vector& local_spikes) const override { return wrapped.gather_spikes(local_spikes); } + virtual gathered_vector<cell_gid_type> + gather_gids(const gid_vector& local_gids) const override { + return wrapped.gather_gids(local_gids); + } int id() const override { return wrapped.id(); } @@ -138,6 +149,14 @@ struct local_context { {0u, static_cast<count_type>(local_spikes.size())} ); } + gathered_vector<cell_gid_type> + gather_gids(const std::vector<cell_gid_type>& local_gids) const { + using count_type = typename gathered_vector<cell_gid_type>::count_type; + return gathered_vector<cell_gid_type>( + std::vector<cell_gid_type>(local_gids), + {0u, static_cast<count_type>(local_gids.size())} + ); + } int id() const { return 0; } diff --git a/arbor/fvm_layout.cpp b/arbor/fvm_layout.cpp index 0f75ffe2628b122878778dfdeba2638a361c04d6..9d2d16075d7dcdcaf2b472a6d307f135fe7ed304 100644 --- a/arbor/fvm_layout.cpp +++ b/arbor/fvm_layout.cpp @@ -236,7 +236,6 @@ fvm_discretization fvm_discretize(const std::vector<mc_cell>& cells) { return D; } - // Build up mechanisms. // // Processing procedes in the following stages: diff --git a/arbor/fvm_layout.hpp b/arbor/fvm_layout.hpp index ceeb30081a70dd8aa19015de8af82cf2615e8983..4ab94980feed388102cd5f807ae02cbefdceccff 100644 --- a/arbor/fvm_layout.hpp +++ b/arbor/fvm_layout.hpp @@ -5,6 +5,7 @@ #include <arbor/mechanism.hpp> #include <arbor/mechinfo.hpp> #include <arbor/mechcat.hpp> +#include <arbor/recipe.hpp> #include "fvm_compartment.hpp" #include "util/span.hpp" diff --git a/arbor/fvm_lowered_cell.hpp b/arbor/fvm_lowered_cell.hpp index c710074cc18b4a096fa463993acc77b03118b901..e49c610b961b3fb9a9b5dc956e043991fa5fd137 100644 --- a/arbor/fvm_lowered_cell.hpp +++ b/arbor/fvm_lowered_cell.hpp @@ -28,6 +28,7 @@ struct fvm_lowered_cell { virtual void initialize( const std::vector<cell_gid_type>& gids, + const std::vector<int>& deps, const recipe& rec, std::vector<target_handle>& target_handles, probe_association_map<probe_handle>& probe_map) = 0; diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index d033f464f343d52bada8c2984d761311852faa71..31edf2bf49d78866a2568bda5ff726f1821f2307 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -50,6 +50,7 @@ public: void initialize( const std::vector<cell_gid_type>& gids, + const std::vector<int>& deps, const recipe& rec, std::vector<target_handle>& target_handles, probe_association_map<probe_handle>& probe_map) override; @@ -60,6 +61,12 @@ public: std::vector<deliverable_event> staged_events, std::vector<sample_event> staged_samples) override; + std::vector<fvm_gap_junction> fvm_gap_junctions( + const std::vector<mc_cell>& cells, + const std::vector<cell_gid_type>& gids, + const recipe& rec, + const fvm_discretization& D); + value_type time() const override { return tmin_; } //Exposed for testing purposes @@ -198,6 +205,9 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( m->nrn_current(); } + // Add current contribution from gap_junctions + state_->add_gj_current(); + PE(advance_integrate_events); state_->deliverable_events.drop_marked_events(); @@ -205,6 +215,7 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( state_->update_time_to(dt_max, tfinal); state_->deliverable_events.event_time_if_before(state_->time_to); + state_->sync_time_to(); state_->set_dt(); PL(); @@ -301,6 +312,7 @@ void fvm_lowered_cell_impl<B>::assert_voltage_bounded(fvm_value_type bound) { template <typename B> void fvm_lowered_cell_impl<B>::initialize( const std::vector<cell_gid_type>& gids, + const std::vector<int>& deps, const recipe& rec, std::vector<target_handle>& target_handles, probe_association_map<probe_handle>& probe_map) @@ -362,6 +374,10 @@ void fvm_lowered_cell_impl<B>::initialize( fvm_mechanism_data mech_data = fvm_build_mechanism_data(*catalogue, cells, D); + // Discritize and build gap junction info + + auto gj_vector = fvm_gap_junctions(cells, gids, rec, D); + // Create shared cell state. // (SIMD padding requires us to check each mechanism for alignment/padding constraints.) @@ -369,7 +385,7 @@ void fvm_lowered_cell_impl<B>::initialize( util::transform_view(keys(mech_data.mechanisms), [&](const std::string& name) { return mech_instance(name)->data_alignment(); })); - state_ = std::make_unique<shared_state>(ncell, D.cv_to_cell, data_alignment? data_alignment: 1u); + state_ = std::make_unique<shared_state>(ncell, D.cv_to_cell, deps, gj_vector, data_alignment? data_alignment: 1u); // Instantiate mechanisms and ions. @@ -468,4 +484,51 @@ void fvm_lowered_cell_impl<B>::initialize( reset(); } +// Get vector of gap_junctions +template <typename B> +std::vector<fvm_gap_junction> fvm_lowered_cell_impl<B>::fvm_gap_junctions( + const std::vector<mc_cell>& cells, + const std::vector<cell_gid_type>& gids, + const recipe& rec, const fvm_discretization& D) { + + std::vector<fvm_gap_junction> v; + + std::unordered_map<cell_gid_type, std::vector<unsigned>> gid_to_cvs; + for (auto cell_idx: util::make_span(0, D.ncell)) { + + if (rec.num_gap_junction_sites(gids[cell_idx])) { + gid_to_cvs[gids[cell_idx]].reserve(rec.num_gap_junction_sites(gids[cell_idx])); + + auto cell_gj = cells[cell_idx].gap_junction_sites(); + for (auto gj : cell_gj) { + auto cv = D.segment_location_cv(cell_idx, gj); + gid_to_cvs[gids[cell_idx]].push_back(cv); + } + } + } + + for (auto gid: gids) { + auto gj_list = rec.gap_junctions_on(gid); + for (auto g: gj_list) { + if (gid != g.local.gid && gid != g.peer.gid) { + throw arb::bad_cell_description(cell_kind::cable1d_neuron, gid); + } + cell_gid_type cv0, cv1; + try { + cv0 = gid_to_cvs[g.local.gid].at(g.local.index); + cv1 = gid_to_cvs[g.peer.gid].at(g.peer.index); + } + catch (std::out_of_range&) { + throw arb::bad_cell_description(cell_kind::cable1d_neuron, gid); + } + if (gid != g.local.gid) { + std::swap(cv0, cv1); + } + v.push_back(fvm_gap_junction(std::make_pair(cv0, cv1), g.ggap * 1e3 / D.cv_area[cv0])); + } + } + + return v; +} + } // namespace arb diff --git a/arbor/include/arbor/arbexcept.hpp b/arbor/include/arbor/arbexcept.hpp index d136362a0ecec42a6e4ce6176f06d80ed80bfdbd..cdeaa96d5400b755956b5b9b17c07e274ed00b8f 100644 --- a/arbor/include/arbor/arbexcept.hpp +++ b/arbor/include/arbor/arbexcept.hpp @@ -45,6 +45,11 @@ struct bad_probe_id: arbor_exception { cell_member_type probe_id; }; +struct gj_kind_mismatch: arbor_exception { + gj_kind_mismatch(cell_gid_type gid_0, cell_gid_type gid_1); + cell_gid_type gid_0, gid_1; +}; + // Simulation errors: struct bad_event_time: arbor_exception { diff --git a/arbor/include/arbor/domain_decomposition.hpp b/arbor/include/arbor/domain_decomposition.hpp index 619526c5e6b55fbec6f2d1aca04c66b15780d148..4b5d6e9d7624c8ab623cf84e5eab9ebb30abd3d1 100644 --- a/arbor/include/arbor/domain_decomposition.hpp +++ b/arbor/include/arbor/domain_decomposition.hpp @@ -23,9 +23,7 @@ struct group_description { group_description(cell_kind k, std::vector<cell_gid_type> g, backend_kind b): kind(k), gids(std::move(g)), backend(b) - { - arb_assert(std::is_sorted(gids.begin(), gids.end())); - } + {} }; /// Meta data that describes a domain decomposition. diff --git a/arbor/include/arbor/fvm_types.hpp b/arbor/include/arbor/fvm_types.hpp index 9a5bf56dffd875c8bbfa9604540d5f7487b3f059..a4f21386025fccadcb4c4f4ec21367f6eb1b80c0 100644 --- a/arbor/include/arbor/fvm_types.hpp +++ b/arbor/include/arbor/fvm_types.hpp @@ -10,4 +10,16 @@ using fvm_value_type = double; using fvm_size_type = cell_local_size_type; using fvm_index_type = int; +struct fvm_gap_junction { + using value_type = fvm_value_type; + using index_type = fvm_index_type; + + std::pair<index_type, index_type> loc; + value_type weight; + + fvm_gap_junction() {} + fvm_gap_junction(std::pair<index_type, index_type> l, value_type w): loc(l), weight(w) {} + +}; + } // namespace arb diff --git a/arbor/include/arbor/mc_cell.hpp b/arbor/include/arbor/mc_cell.hpp index 71e6579428ff0dd8c091f252bd81590e72cc18a8..8ac2adddec06133acfa01d68d5ce6e4e2fbb788b 100644 --- a/arbor/include/arbor/mc_cell.hpp +++ b/arbor/include/arbor/mc_cell.hpp @@ -100,6 +100,8 @@ public: using value_type = double; using point_type = point<value_type>; + using gap_junction_instance = segment_location; + struct synapse_instance { segment_location location; mechanism_desc mechanism; @@ -123,6 +125,7 @@ public: parents_(other.parents_), stimuli_(other.stimuli_), synapses_(other.synapses_), + gap_junction_sites_(other.gap_junction_sites_), spike_detectors_(other.spike_detectors_) { // unique_ptr's cannot be copy constructed, do a manual assignment @@ -209,6 +212,17 @@ public: return synapses_; } + ////////////////// + // gap-junction + ////////////////// + void add_gap_junction(segment_location location) + { + gap_junction_sites_.push_back(location); + } + const std::vector<gap_junction_instance>& gap_junction_sites() const { + return gap_junction_sites_; + } + ////////////////// // spike detectors ////////////////// @@ -251,6 +265,9 @@ private: // the synapses std::vector<synapse_instance> synapses_; + // the gap_junctions + std::vector<gap_junction_instance> gap_junction_sites_; + // the sensors std::vector<detector_instance> spike_detectors_; }; diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index 00147ff854cbecd8e824d91d4f13cd907abd992d..5521aa12e0294f0b45a8960b13b12115ad4b6de2 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -47,6 +47,15 @@ struct cell_connection { {} }; +struct gap_junction_connection { + cell_member_type local; + cell_member_type peer; + double ggap; + + gap_junction_connection(cell_member_type local, cell_member_type peer, double g): + local(local), peer(peer), ggap(g) {} +}; + class recipe { public: virtual cell_size_type num_cells() const = 0; @@ -58,13 +67,18 @@ public: virtual cell_size_type num_sources(cell_gid_type) const { return 0; } virtual cell_size_type num_targets(cell_gid_type) const { return 0; } virtual cell_size_type num_probes(cell_gid_type) const { return 0; } - + virtual cell_size_type num_gap_junction_sites(cell_gid_type gid) const { + return gap_junctions_on(gid).size(); + } virtual std::vector<event_generator> event_generators(cell_gid_type) const { return {}; } virtual std::vector<cell_connection> connections_on(cell_gid_type) const { return {}; } + virtual std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type) const { + return {}; + } virtual probe_info get_probe(cell_member_type probe_id) const { throw bad_probe_id(probe_id); diff --git a/arbor/mc_cell_group.cpp b/arbor/mc_cell_group.cpp index f06808099a25213dfbdd48691e31059cc67a5529..a2cc9f69f1e8a554f39165a6566d14aa2b2d4918 100644 --- a/arbor/mc_cell_group.cpp +++ b/arbor/mc_cell_group.cpp @@ -1,5 +1,5 @@ #include <functional> -#include <unordered_map> +#include <unordered_set> #include <vector> #include <arbor/assert.hpp> @@ -25,8 +25,9 @@ namespace arb { mc_cell_group::mc_cell_group(const std::vector<cell_gid_type>& gids, const recipe& rec, fvm_lowered_cell_ptr lowered): - gids_(gids), lowered_(std::move(lowered)) + lowered_(std::move(lowered)) { + generate_deps_gids(rec, gids); // Default to no binning of events set_binning_policy(binning_kind::none, 0); @@ -46,7 +47,7 @@ mc_cell_group::mc_cell_group(const std::vector<cell_gid_type>& gids, const recip target_handles_.reserve(n_targets); // Construct cell implementation, retrieving handles and maps. - lowered_->initialize(gids_, rec, target_handles_, probe_map_); + lowered_->initialize(gids_, deps_, rec, target_handles_, probe_map_); // Create a list of the global identifiers for the spike sources for (auto source_gid: gids_) { @@ -57,6 +58,59 @@ mc_cell_group::mc_cell_group(const std::vector<cell_gid_type>& gids, const recip spike_sources_.shrink_to_fit(); } +// Fills gids_ and deps_: gids_ are sorted such that members of the same supercell are consecutive +void mc_cell_group::generate_deps_gids(const recipe& rec, std::vector<cell_gid_type> gids) { + + std::unordered_map<cell_gid_type, cell_size_type> gid_to_loc; + for (auto i: util::count_along(gids)) { + gid_to_loc[gids[i]] = i; + } + + deps_.reserve(gids_.size()); + std::unordered_set<cell_gid_type> visited; + std::queue<cell_gid_type> scq; + + for (auto gid: gids) { + if (visited.count(gid)) continue; + visited.insert(gid); + + cell_size_type sc_size = 0; + scq.push(gid); + while (!scq.empty()) { + auto g = scq.front(); + scq.pop(); + + gids_.push_back(g); + ++sc_size; + + for (auto gj: rec.gap_junctions_on(g)) { + cell_gid_type peer = + gj.local.gid==g? gj.peer.gid: + gj.peer.gid==g? gj.local.gid: + throw bad_cell_description(cell_kind::cable1d_neuron, g); + + if (!gid_to_loc.count(peer)) { + // actually an error in the domain decomposition... + throw bad_cell_description(cell_kind::cable1d_neuron, g); + } + + if (!visited.count(peer)) { + visited.insert(peer); + scq.push(peer); + } + } + } + + deps_.push_back(sc_size>1? sc_size: 0); + deps_.insert(deps_.end(), sc_size-1, 0); + } + + perm_gids_.reserve(gids_.size()); + for (auto gid: gids_) { + perm_gids_.push_back(gid_to_loc[gid]); + } +} + void mc_cell_group::reset() { spikes_.clear(); @@ -85,7 +139,7 @@ void mc_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& e // skip event binning if empty lanes are passed if (event_lanes.size()) { for (auto lid: util::count_along(gids_)) { - auto& lane = event_lanes[lid]; + auto& lane = event_lanes[perm_gids_[lid]]; for (auto e: lane) { if (e.time>=ep.tfinal) break; e.time = binners_[lid].bin(e.time, tstart); diff --git a/arbor/mc_cell_group.hpp b/arbor/mc_cell_group.hpp index 47653506aa642842a127ccf8e142684dfbb79b47..4efd99d053eaf98b2d6a6658731ba97736663a7e 100644 --- a/arbor/mc_cell_group.hpp +++ b/arbor/mc_cell_group.hpp @@ -57,10 +57,27 @@ public: void remove_all_samplers() override; + void generate_deps_gids(const recipe& rec, std::vector<cell_gid_type>); + + std::vector<cell_gid_type> get_gids() { + return gids_; + } + + std::vector<int> get_dependencies() { + return deps_; + } + private: // List of the gids of the cells in the group. std::vector<cell_gid_type> gids_; + // Permutation of the gids of the cells in the group. + // perm_gids_[i] is the index of gids_[i] in the original gids vector passed to the ctor + std::vector<cell_size_type> perm_gids_; + + // List of the dependencies of the cells in the group. + std::vector<int> deps_; + // Hash table for converting gid to local index std::unordered_map<cell_gid_type, cell_gid_type> gid_index_map_; diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp index 29b7017424a3338ef8853755b72d32f2b267ac18..c393f51cce863d9d3f7afc2dfb28efc8e8d7c7a7 100644 --- a/arbor/partition_load_balance.cpp +++ b/arbor/partition_load_balance.cpp @@ -1,3 +1,5 @@ +#include <unordered_set> + #include <arbor/domain_decomposition.hpp> #include <arbor/load_balance.hpp> #include <arbor/recipe.hpp> @@ -21,16 +23,28 @@ domain_decomposition partition_load_balance( const bool gpu_avail = ctx->gpu->has_gpu(); struct partition_gid_domain { - partition_gid_domain(std::vector<cell_gid_type> divs): - gid_divisions(std::move(divs)) + partition_gid_domain(gathered_vector<cell_gid_type> divs, unsigned domains): + gids_by_rank(std::move(divs)), num_domains(domains) {} int operator()(cell_gid_type gid) const { - auto gid_part = util::partition_view(gid_divisions); - return gid_part.index(gid); + using namespace util; + auto rank_part = partition_view(gids_by_rank.partition()); + for (auto i: count_along(rank_part)) { + if (binary_search_index(subrange_view(gids_by_rank.values(), rank_part[i]), gid)) { + return i; + } + } + return -1; } - const std::vector<cell_gid_type> gid_divisions; + const gathered_vector<cell_gid_type> gids_by_rank; + unsigned num_domains; + }; + + struct cell_identifier { + cell_gid_type id; + bool is_super_cell; }; using util::make_span; @@ -53,11 +67,81 @@ domain_decomposition partition_load_balance( // Local load balance - std::unordered_map<cell_kind, std::vector<cell_gid_type>> kind_lists; + std::vector<std::vector<cell_gid_type>> super_cells; //cells connected by gj + std::vector<cell_gid_type> reg_cells; //independent cells + + // Map to track visited cells (cells that already belong to a group) + std::unordered_set<cell_gid_type> visited; + + // Connected components algorithm using BFS + std::queue<cell_gid_type> q; for (auto gid: make_span(gid_part[domain_id])) { - kind_lists[rec.get_cell_kind(gid)].push_back(gid); + if (!rec.gap_junctions_on(gid).empty()) { + // If cell hasn't been visited yet, must belong to new super_cell + // Perform BFS starting from that cell + if (!visited.count(gid)) { + visited.insert(gid); + std::vector<cell_gid_type> cg; + q.push(gid); + while (!q.empty()) { + auto element = q.front(); + q.pop(); + cg.push_back(element); + // Adjacency list + auto conns = rec.gap_junctions_on(element); + for (auto c: conns) { + if (element != c.local.gid && element != c.peer.gid) { + throw bad_cell_description(cell_kind::cable1d_neuron, element); + } + cell_member_type other = c.local.gid == element ? c.peer : c.local; + + if (!visited.count(other.gid)) { + visited.insert(other.gid); + q.push(other.gid); + } + } + } + super_cells.push_back(cg); + } + } + else { + // If cell has no gap_junctions, put in separate group of independent cells + reg_cells.push_back(gid); + } + } + + // Sort super_cell groups and only keep those where the first element in the group belongs to domain + super_cells.erase(std::remove_if(super_cells.begin(), super_cells.end(), + [gid_part, domain_id](std::vector<cell_gid_type>& cg) + { + std::sort(cg.begin(), cg.end()); + return cg.front() < gid_part[domain_id].first; + }), super_cells.end()); + + // Collect local gids that belong to this rank, and sort gids into kind lists + // kind_lists maps a cell_kind to a vector of either: + // 1. gids of regular cells (in reg_cells) + // 2. indices of supercells (in super_cells) + + std::vector<cell_gid_type> local_gids; + std::unordered_map<cell_kind, std::vector<cell_identifier>> kind_lists; + for (auto gid: reg_cells) { + local_gids.push_back(gid); + kind_lists[rec.get_cell_kind(gid)].push_back({gid, false}); + } + + for (unsigned i = 0; i < super_cells.size(); i++) { + auto kind = rec.get_cell_kind(super_cells[i].front()); + for (auto gid: super_cells[i]) { + if (rec.get_cell_kind(gid) != kind) { + throw gj_kind_mismatch(gid, super_cells[i].front()); + } + local_gids.push_back(gid); + } + kind_lists[kind].push_back({i, true}); } + // Create a flat vector of the cell kinds present on this node, // partitioned such that kinds for which GPU implementation are // listed before the others. This is a very primitive attempt at @@ -95,8 +179,19 @@ domain_decomposition partition_load_balance( } std::vector<cell_gid_type> group_elements; - for (auto gid: kind_lists[k]) { - group_elements.push_back(gid); + // group_elements are sorted such that the gids of all members of a super_cell are consecutive. + for (auto cell: kind_lists[k]) { + if (cell.is_super_cell == false) { + group_elements.push_back(cell.id); + } else { + if (group_elements.size() + super_cells[cell.id].size() > group_size && !group_elements.empty()) { + groups.push_back({k, std::move(group_elements), backend}); + group_elements.clear(); + } + for (auto gid: super_cells[cell.id]) { + group_elements.push_back(gid); + } + } if (group_elements.size()>=group_size) { groups.push_back({k, std::move(group_elements), backend}); group_elements.clear(); @@ -107,9 +202,14 @@ domain_decomposition partition_load_balance( } } - // calculate the number of local cells - auto rng = gid_part[domain_id]; - cell_size_type num_local_cells = rng.second - rng.first; + cell_size_type num_local_cells = local_gids.size(); + + // Exchange gid list with all other nodes + + util::sort(local_gids); + + // global all-to-all to gather a local copy of the global gid list on each node. + auto global_gids = ctx->distributed->gather_gids(local_gids); domain_decomposition d; d.num_domains = num_domains; @@ -117,7 +217,7 @@ domain_decomposition partition_load_balance( d.num_local_cells = num_local_cells; d.num_global_cells = num_global_cells; d.groups = std::move(groups); - d.gid_domain = partition_gid_domain(std::move(gid_divisions)); + d.gid_domain = partition_gid_domain(std::move(global_gids), num_domains); return d; } diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 9cfab47efaa9771b7eb1b013df93df54d27f3c01..990a5671730dbc18452ff783c16851e1a3c5ff47 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(generators) add_subdirectory(brunel) add_subdirectory(bench) add_subdirectory(ring) +add_subdirectory(gap_junctions) diff --git a/example/gap_junctions/CMakeLists.txt b/example/gap_junctions/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..60a68d61e9ddebfe0b359865e6b10565796f8293 --- /dev/null +++ b/example/gap_junctions/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(gap_junctions EXCLUDE_FROM_ALL gap_junctions.cpp parameters.hpp) +add_dependencies(examples gap_junctions) + +target_link_libraries(gap_junctions PRIVATE arbor arborenv arbor-sup ext-json) diff --git a/example/gap_junctions/gap_junctions.cpp b/example/gap_junctions/gap_junctions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94a697f7e7cea8d26a68e9e833c4486c94f1bad2 --- /dev/null +++ b/example/gap_junctions/gap_junctions.cpp @@ -0,0 +1,338 @@ +/* + * A miniapp that demonstrates how to make a model with gap junctions + * + */ + +#include <fstream> +#include <iomanip> +#include <iostream> + +#include <nlohmann/json.hpp> + +#include <arbor/assert_macro.hpp> +#include <arbor/common_types.hpp> +#include <arbor/context.hpp> +#include <arbor/load_balance.hpp> +#include <arbor/mc_cell.hpp> +#include <arbor/profile/meter_manager.hpp> +#include <arbor/profile/profiler.hpp> +#include <arbor/simple_sampler.hpp> +#include <arbor/simulation.hpp> +#include <arbor/recipe.hpp> +#include <arbor/version.hpp> + +#include <sup/ioutil.hpp> +#include <sup/json_meter.hpp> + +#include "parameters.hpp" + +#ifdef ARB_MPI_ENABLED +#include <mpi.h> +#include <arborenv/with_mpi.hpp> +#endif + +using arb::cell_gid_type; +using arb::cell_lid_type; +using arb::cell_size_type; +using arb::cell_member_type; +using arb::cell_kind; +using arb::time_type; +using arb::cell_probe_address; + +// Writes voltage trace as a json file. +void write_trace_json(const std::vector<arb::trace_data<double>>& trace, unsigned rank); + +// Generate a cell. +arb::mc_cell gj_cell(double delay, double duration); + +class gj_recipe: public arb::recipe { +public: + gj_recipe(gap_params params): params_(params) {} + + cell_size_type num_cells() const override { + return params_.cells_per_ring * params_.num_rings; + } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + return gj_cell(params_.delay*gid, params_.duration); + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + + // Each cell has one spike detector (at the soma). + cell_size_type num_sources(cell_gid_type gid) const override { + return 1; + } + + // The cell has one target synapse, which will be connected to cell gid-1. + cell_size_type num_targets(cell_gid_type gid) const override { + return 0; + } + + // There is one probe (for measuring voltage at the soma) on the cell. + cell_size_type num_probes(cell_gid_type gid) const override { + return 1; + } + + arb::probe_info get_probe(cell_member_type id) const override { + // Get the appropriate kind for measuring voltage. + cell_probe_address::probe_kind kind = cell_probe_address::membrane_voltage; + // Measure at the soma. + arb::segment_location loc(0, 1); + + return arb::probe_info{id, kind, cell_probe_address{loc, kind}}; + } + + arb::util::any get_global_properties(cell_kind k) const override { + arb::mc_cell_global_properties a; + a.temperature_K = 308.15; + return a; + } + + std::vector<arb::gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override{ + std::vector<arb::gap_junction_connection> conns; + + cell_gid_type ring_start_gid = (gid/params_.cells_per_ring) * params_.cells_per_ring; + cell_gid_type next_cell = (gid + 1) % params_.cells_per_ring + ring_start_gid; + cell_gid_type prev_cell = (gid - 1 + params_.cells_per_ring) % params_.cells_per_ring + ring_start_gid; + + // Our soma is connected to the next cell's dendrite + // Our dendrite is connected to the prev cell's soma + conns.push_back(arb::gap_junction_connection({next_cell, 1}, {gid, 0}, 0.0037)); // 1 is the id of the dendrite junction + conns.push_back(arb::gap_junction_connection({prev_cell, 0}, {gid, 1}, 0.0037)); // 0 is the id of the soma junction + + return conns; + } + +private: + gap_params params_; +}; + +struct cell_stats { + using size_type = unsigned; + size_type ncells = 0; + size_type nsegs = 0; + size_type ncomp = 0; + + cell_stats(arb::recipe& r) { +#ifdef ARB_MPI_ENABLED + int nranks, rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + ncells = r.num_cells(); + size_type cells_per_rank = ncells/nranks; + size_type b = rank*cells_per_rank; + size_type e = (rank==nranks-1)? ncells: (rank+1)*cells_per_rank; + size_type nsegs_tmp = 0; + size_type ncomp_tmp = 0; + for (size_type i=b; i<e; ++i) { + auto c = arb::util::any_cast<arb::mc_cell>(r.get_cell_description(i)); + nsegs_tmp += c.num_segments(); + ncomp_tmp += c.num_compartments(); + } + MPI_Allreduce(&nsegs_tmp, &nsegs, 1, MPI_UNSIGNED, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&ncomp_tmp, &ncomp, 1, MPI_UNSIGNED, MPI_SUM, MPI_COMM_WORLD); +#else + ncells = r.num_cells(); + for (size_type i=0; i<ncells; ++i) { + auto c = arb::util::any_cast<arb::mc_cell>(r.get_cell_description(i)); + nsegs += c.num_segments(); + ncomp += c.num_compartments(); + } +#endif + } + + friend std::ostream& operator<<(std::ostream& o, const cell_stats& s) { + return o << "cell stats: " + << s.ncells << " cells; " + << s.nsegs << " segments; " + << s.ncomp << " compartments."; + } +}; + + +int main(int argc, char** argv) { + try { + bool root = true; + +#ifdef ARB_MPI_ENABLED + arbenv::with_mpi guard(argc, argv, false); + auto context = arb::make_context(arb::proc_allocation(), MPI_COMM_WORLD); + { + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + root = rank==0; + } +#else + auto context = arb::make_context(); +#endif + +#ifdef ARB_PROFILE_ENABLED + arb::profile::profiler_initialize(context); +#endif + + std::cout << sup::mask_stream(root); + + // Print a banner with information about hardware configuration + std::cout << "gpu: " << (has_gpu(context)? "yes": "no") << "\n"; + std::cout << "threads: " << num_threads(context) << "\n"; + std::cout << "mpi: " << (has_mpi(context)? "yes": "no") << "\n"; + std::cout << "ranks: " << num_ranks(context) << "\n" << std::endl; + + auto params = read_options(argc, argv); + + arb::profile::meter_manager meters; + meters.start(context); + + // Create an instance of our recipe. + gj_recipe recipe(params); + + cell_stats stats(recipe); + std::cout << stats << "\n"; + + auto decomp = arb::partition_load_balance(recipe, context); + + // Construct the model. + arb::simulation sim(recipe, decomp, context); + + // Set up the probe that will measure voltage in the cell. + + auto sched = arb::regular_schedule(0.025); + // This is where the voltage samples will be stored as (time, value) pairs + std::vector<arb::trace_data<double>> voltage(decomp.num_local_cells); + + // Now attach the sampler at probe_id, with sampling schedule sched, writing to voltage + unsigned j=0; + for (auto g : decomp.groups) { + for (auto i : g.gids) { + auto t = recipe.get_probe({i, 0}); + sim.add_sampler(arb::one_probe(t.id), sched, arb::make_simple_sampler(voltage[j++])); + } + } + + // Set up recording of spikes to a vector on the root process. + std::vector<arb::spike> recorded_spikes; + if (root) { + sim.set_global_spike_callback( + [&recorded_spikes](const std::vector<arb::spike>& spikes) { + recorded_spikes.insert(recorded_spikes.end(), spikes.begin(), spikes.end()); + }); + } + + meters.checkpoint("model-init", context); + + std::cout << "running simulation" << std::endl; + // Run the simulation for 100 ms, with time steps of 0.025 ms. + sim.run(params.duration, 0.025); + + meters.checkpoint("model-run", context); + + auto ns = sim.num_spikes(); + + // Write spikes to file + if (root) { + std::cout << "\n" << ns << " spikes generated at rate of " + << params.duration/ns << " ms between spikes\n"; + std::ofstream fid("spikes.gdf"); + if (!fid.good()) { + std::cerr << "Warning: unable to open file spikes.gdf for spike output\n"; + } + else { + char linebuf[45]; + for (auto spike: recorded_spikes) { + auto n = std::snprintf( + linebuf, sizeof(linebuf), "%u %.4f\n", + unsigned{spike.source.gid}, float(spike.time)); + fid.write(linebuf, n); + } + } + } + + // Write the samples to a json file. + if (params.print_all) { + write_trace_json(voltage, arb::rank(context)); + } + + auto report = arb::profile::make_meter_report(meters, context); + std::cout << report; + } + catch (std::exception& e) { + std::cerr << "exception caught in gap junction miniapp:\n" << e.what() << "\n"; + return 1; + } + + return 0; +} + +void write_trace_json(const std::vector<arb::trace_data<double>>& trace, unsigned rank) { + for (unsigned i = 0; i < trace.size(); i++) { + std::string path = "./voltages_" + std::to_string(rank) + + "_" + std::to_string(i) + ".json"; + + nlohmann::json json; + json["name"] = "gj demo: cell " + std::to_string(i); + json["units"] = "mV"; + json["cell"] = std::to_string(i); + json["probe"] = "0"; + + auto &jt = json["data"]["time"]; + auto &jy = json["data"]["voltage"]; + + for (const auto &sample: trace[i]) { + jt.push_back(sample.t); + jy.push_back(sample.v); + } + + std::ofstream file(path); + file << std::setw(1) << json << "\n"; + } +} + +arb::mc_cell gj_cell(double delay, double duration) { + arb::mc_cell cell; + + arb::mechanism_desc nax("nax"); + arb::mechanism_desc kdrmt("kdrmt"); + arb::mechanism_desc kamt("kamt"); + arb::mechanism_desc pas("pas"); + + auto set_reg_params = [&]() { + nax["gbar"] = 0.04; + nax["sh"] = 10; + kdrmt["gbar"] = 0.0001; + kamt["gbar"] = 0.004; + pas["g"] = 1.0/12000.0; + pas["e"] = -65; + }; + + auto setup_seg = [&](auto seg) { + seg->rL = 100; + seg->cm = 0.018; + seg->add_mechanism(nax); + seg->add_mechanism(kdrmt); + seg->add_mechanism(kamt); + seg->add_mechanism(pas); + }; + + auto soma = cell.add_soma(22.360679775/2.0); + set_reg_params(); + setup_seg(soma); + + auto dend = cell.add_cable(0, arb::section_kind::dendrite, 3.0/2.0, 3.0/2.0, 300); //cable 1 + dend->set_compartments(1); + set_reg_params(); + setup_seg(dend); + + cell.add_detector({0,0}, 10); + + cell.add_gap_junction({0, 1}); //ggap in μS + cell.add_gap_junction({1, 1}); //ggap in μS + + arb::i_clamp stim(delay, duration, 0.2); + cell.add_stimulus({1, 0.95}, stim); + + return cell; +} + diff --git a/example/gap_junctions/parameters.hpp b/example/gap_junctions/parameters.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c2e7487241b04cf9e211aa04c6b8a6d56919a3d --- /dev/null +++ b/example/gap_junctions/parameters.hpp @@ -0,0 +1,61 @@ +#include <iostream> + +#include <array> +#include <cmath> +#include <fstream> +#include <random> + +#include <arbor/mc_cell.hpp> + +#include <sup/json_params.hpp> + +struct gap_params { + gap_params() = default; + + std::string name = "default"; + unsigned cells_per_ring = 100; + unsigned num_rings = 10; + double duration = 100; + double delay = 0.5; + bool print_all = false; +}; + +gap_params read_options(int argc, char** argv) { + using sup::param_from_json; + + gap_params params; + if (argc<2) { + std::cout << "Using default parameters.\n"; + return params; + } + if (argc>2) { + throw std::runtime_error("More than command line one option not permitted."); + } + + std::string fname = argv[1]; + std::cout << "Loading parameters from file: " << fname << "\n"; + std::ifstream f(fname); + + if (!f.good()) { + throw std::runtime_error("Unable to open input parameter file: "+fname); + } + + nlohmann::json json; + json << f; + + param_from_json(params.name, "name", json); + param_from_json(params.cells_per_ring, "num-cells", json); + param_from_json(params.num_rings, "num-rings", json); + param_from_json(params.duration, "duration", json); + param_from_json(params.delay, "delay", json); + param_from_json(params.print_all, "print-all", json); + + if (!json.empty()) { + for (auto it=json.begin(); it!=json.end(); ++it) { + std::cout << " Warning: unused input parameter: \"" << it.key() << "\"\n"; + } + std::cout << "\n"; + } + + return params; +} diff --git a/example/gap_junctions/readme.md b/example/gap_junctions/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..1d32a1928c4cef5270458f5990889aa8c10ef9bc --- /dev/null +++ b/example/gap_junctions/readme.md @@ -0,0 +1,3 @@ +# Gap Junctions Example + +A miniapp that demonstrates how to describe how to build a network with gap junctions. diff --git a/mechanisms/mod/nax.mod b/mechanisms/mod/nax.mod index 931fe82fb08d7f4bee717e9d49290f5c2299bc87..911b835597a9c9e1e7e91186b532ca96483eb44f 100644 --- a/mechanisms/mod/nax.mod +++ b/mechanisms/mod/nax.mod @@ -90,5 +90,5 @@ PROCEDURE trates(vm,sh2,celsius) { FUNCTION trap0(v,th,a,q) { : trap0 = a * (v - th) / (1 - exp(-(v - th)/q)) - trap0 = -a*q*exprelr((v-th)/q) + trap0 = a*q*exprelr(-(v-th)/q) } diff --git a/test/simple_recipes.hpp b/test/simple_recipes.hpp index 92c392082191631e6ce7e98a4ad02c1efa4b288e..bcf528540890a947cb2888057d0628914dad4e5c 100644 --- a/test/simple_recipes.hpp +++ b/test/simple_recipes.hpp @@ -122,6 +122,5 @@ protected: std::vector<mc_cell> cells_; }; - } // namespace arb diff --git a/test/ubench/mech_vec.cpp b/test/ubench/mech_vec.cpp index 8c0388b856b9f864f2297d5339c28793a31aa202..0a6572f613e1ea1d8b96c76095c360f0c1cc32b2 100644 --- a/test/ubench/mech_vec.cpp +++ b/test/ubench/mech_vec.cpp @@ -210,7 +210,7 @@ void expsyn_1_branch_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_expsyn_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_expsyn_1_branch, target_handles, probe_handles); auto& m = find_mechanism("expsyn", cell); @@ -229,7 +229,7 @@ void expsyn_1_branch_state(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_expsyn_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_expsyn_1_branch, target_handles, probe_handles); auto& m = find_mechanism("expsyn", cell); @@ -247,7 +247,7 @@ void pas_1_branch_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_pas_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_pas_1_branch, target_handles, probe_handles); auto& m = find_mechanism("pas", cell); @@ -265,7 +265,7 @@ void pas_3_branches_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_pas_3_branches, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_pas_3_branches, target_handles, probe_handles); auto& m = find_mechanism("pas", cell); @@ -283,7 +283,7 @@ void hh_1_branch_state(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_hh_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_hh_1_branch, target_handles, probe_handles); auto& m = find_mechanism("hh", cell); @@ -301,7 +301,7 @@ void hh_1_branch_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_hh_1_branch, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_hh_1_branch, target_handles, probe_handles); auto& m = find_mechanism("hh", cell); @@ -319,7 +319,7 @@ void hh_3_branches_state(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_hh_3_branches, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_hh_3_branches, target_handles, probe_handles); auto& m = find_mechanism("hh", cell); @@ -337,7 +337,7 @@ void hh_3_branches_current(benchmark::State& state) { probe_association_map<probe_handle> probe_handles; fvm_cell cell((execution_context())); - cell.initialize(gids, rec_hh_3_branches, target_handles, probe_handles); + cell.initialize(gids, {0}, rec_hh_3_branches, target_handles, probe_handles); auto& m = find_mechanism("hh", cell); diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp index e40a13acee3b8906b9440a60663832973a7c9429..4eda7ccbba3a0378561f66801e9d9a532788528e 100644 --- a/test/unit-distributed/test_communicator.cpp +++ b/test/unit-distributed/test_communicator.cpp @@ -140,6 +140,45 @@ TEST(communicator, gather_spikes_variant) { } } +// Test low level gids_gather function when the number of gids per domain +// are not equal. +TEST(communicator, gather_gids_variant) { + const auto num_domains = g_context->distributed->size(); + const auto rank = g_context->distributed->id(); + + constexpr int scale = 10; + const auto n_local_gids = scale*rank; + auto sumn = [](unsigned n) {return scale*n*(n+1)/2;}; + + std::vector<cell_gid_type > local_gids; + const auto local_start_id = sumn(rank-1); + for (auto i=0; i<n_local_gids; ++i) { + local_gids.push_back(local_start_id + i); + } + + // Perform exchange + const auto global_gids = g_context->distributed->gather_gids(local_gids); + + // Test that partition information is correct + const auto& part =global_gids.partition(); + EXPECT_EQ(unsigned(num_domains+1), part.size()); + EXPECT_EQ(0, (int)part[0]); + for (auto i=1u; i<part.size(); ++i) { + EXPECT_EQ(sumn(i-1), part[i]); + } + + // Test that gids were correctly exchanged + for (auto domain=0; domain<num_domains; ++domain) { + auto source = sumn(domain-1); + const auto first_gid = global_gids.values().begin() + sumn(domain-1); + const auto last_gid = global_gids.values().begin() + sumn(domain); + const auto gids = util::make_range(first_gid, last_gid); + for (auto s: gids) { + EXPECT_EQ(s, source++); + } + } +} + namespace { // Population of cable and rss cells with ring connection topology. // Even gid are rss, and odd gid are cable cells. diff --git a/test/unit-distributed/test_domain_decomposition.cpp b/test/unit-distributed/test_domain_decomposition.cpp index 43f3127f7772a680af3285333dfad03e10e198b7..27e69004ddd0520720e008eb6b2e85c51fb3a471 100644 --- a/test/unit-distributed/test_domain_decomposition.cpp +++ b/test/unit-distributed/test_domain_decomposition.cpp @@ -67,6 +67,81 @@ namespace { private: cell_size_type size_; }; + + class gj_symmetric: public recipe { + public: + gj_symmetric(unsigned num_ranks): ncopies_(num_ranks){} + + cell_size_type num_cells() const override { + return size_*ncopies_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + return {}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + unsigned shift = (gid/size_)*size_; + switch (gid % size_) { + case 1 : return { gap_junction_connection({7 + shift, 0}, {gid, 0}, 0.1)}; + case 2 : return { + gap_junction_connection({6 + shift, 0}, {gid, 0}, 0.1), + gap_junction_connection({9 + shift, 0}, {gid, 0}, 0.1) + }; + case 6 : return { + gap_junction_connection({2 + shift, 0}, {gid, 0}, 0.1), + gap_junction_connection({7 + shift, 0}, {gid, 0}, 0.1) + }; + case 7 : return { + gap_junction_connection({6 + shift, 0}, {gid, 0}, 0.1), + gap_junction_connection({1 + shift, 0}, {gid, 0}, 0.1) + }; + case 9 : return { gap_junction_connection({2 + shift, 0}, {gid, 0}, 0.1)}; + default : return {}; + } + } + + private: + cell_size_type size_ = 10; + unsigned ncopies_; + }; + + class gj_non_symmetric: public recipe { + public: + gj_non_symmetric(unsigned num_ranks): groups_(num_ranks), size_(num_ranks){} + + cell_size_type num_cells() const override { + return size_*groups_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + return {}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + unsigned group = gid/groups_; + unsigned id = gid%size_; + if (id == group && group != (groups_ - 1)) { + return {gap_junction_connection({gid + size_, 0}, {gid, 0}, 0.1)}; + } + else if (id == group - 1) { + return {gap_junction_connection({gid - size_, 0}, {gid, 0}, 0.1)}; + } + else { + return {}; + } + } + + private: + unsigned groups_; + cell_size_type size_; + }; } TEST(domain_decomposition, homogeneous_population_mc) { @@ -75,7 +150,7 @@ TEST(domain_decomposition, homogeneous_population_mc) { // This assumption will not hold in the future, requiring and update to // the test. proc_allocation resources{1, -1}; -#ifdef ARB_TEST_MPI +#ifdef TEST_MPI auto ctx = make_context(resources, MPI_COMM_WORLD); #else auto ctx = make_context(resources); @@ -124,7 +199,7 @@ TEST(domain_decomposition, homogeneous_population_gpu) { proc_allocation resources; resources.num_threads = 1; -#ifdef ARB_TEST_MPI +#ifdef TEST_MPI auto ctx = make_context(resources, MPI_COMM_WORLD); #else auto ctx = make_context(resources); @@ -169,7 +244,7 @@ TEST(domain_decomposition, heterogeneous_population) { // This assumption will not hold in the future, requiring and update to // the test. proc_allocation resources{1, -1}; -#ifdef ARB_TEST_MPI +#ifdef TEST_MPI auto ctx = make_context(resources, MPI_COMM_WORLD); #else auto ctx = make_context(resources); @@ -217,3 +292,126 @@ TEST(domain_decomposition, heterogeneous_population) { } } +TEST(domain_decomposition, symmetric_groups) +{ + proc_allocation resources{1, -1}; + int nranks = 1; + int rank = 0; +#ifdef TEST_MPI + auto ctx = make_context(resources, MPI_COMM_WORLD); + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); +#else + auto ctx = make_context(resources); +#endif + auto R = gj_symmetric(nranks); + const auto D0 = partition_load_balance(R, ctx); + EXPECT_EQ(6u, D0.groups.size()); + + unsigned shift = rank*R.num_cells()/nranks; + std::vector<std::vector<cell_gid_type>> expected_groups0 = + { {0 + shift}, + {3 + shift}, + {4 + shift}, + {5 + shift}, + {8 + shift}, + {1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift} + }; + + for (unsigned i = 0; i < 6; i++){ + EXPECT_EQ(expected_groups0[i], D0.groups[i].gids); + } + + unsigned cells_per_rank = R.num_cells()/nranks; + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned)D0.gid_domain(i)); + } + + // Test different group_hints + partition_hint_map hints; + hints[cell_kind::cable1d_neuron].cpu_group_size = R.num_cells(); + hints[cell_kind::cable1d_neuron].prefer_gpu = false; + + const auto D1 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(1u, D1.groups.size()); + + std::vector<cell_gid_type> expected_groups1 = + { 0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift, + 1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift }; + + EXPECT_EQ(expected_groups1, D1.groups[0].gids); + + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned)D1.gid_domain(i)); + } + + hints[cell_kind::cable1d_neuron].cpu_group_size = cells_per_rank/2; + hints[cell_kind::cable1d_neuron].prefer_gpu = false; + + const auto D2 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(2u, D2.groups.size()); + + std::vector<std::vector<cell_gid_type>> expected_groups2 = + { { 0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift }, + { 1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift } }; + + for (unsigned i = 0; i < 2u; i++) { + EXPECT_EQ(expected_groups2[i], D2.groups[i].gids); + } + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned)D2.gid_domain(i)); + } + +} + +TEST(domain_decomposition, non_symmetric_groups) +{ + proc_allocation resources{1, -1}; + int nranks = 1; + int rank = 0; +#ifdef TEST_MPI + auto ctx = make_context(resources, MPI_COMM_WORLD); + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); +#else + auto ctx = make_context(resources); +#endif + /*if (nranks == 1) { + return; + }*/ + + auto R = gj_non_symmetric(nranks); + const auto D = partition_load_balance(R, ctx); + + unsigned cells_per_rank = nranks; + // check groups + unsigned i = 0; + for (unsigned gid = rank*cells_per_rank; gid < (rank + 1)*cells_per_rank; gid++) { + if (gid % nranks == (unsigned)rank - 1) { + continue; + } + else if (gid % nranks == (unsigned)rank && rank != nranks - 1) { + std::vector<cell_gid_type> cg = {gid, gid+cells_per_rank}; + EXPECT_EQ(cg, D.groups[D.groups.size()-1].gids); + } + else { + std::vector<cell_gid_type> cg = {gid}; + EXPECT_EQ(cg, D.groups[i++].gids); + } + } + // check gid_domains + for (unsigned gid = 0; gid < R.num_cells(); gid++) { + auto group = gid/cells_per_rank; + auto idx = gid % cells_per_rank; + unsigned ngroups = nranks; + if (idx == group - 1) { + EXPECT_EQ(group - 1, (unsigned)D.gid_domain(gid)); + } + else if (idx == group && group != ngroups - 1) { + EXPECT_EQ(group, (unsigned)D.gid_domain(gid)); + } + else { + EXPECT_EQ(group, (unsigned)D.gid_domain(gid)); + } + } +} diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 977372b9f2d3f734d4ac7ff0157e4db82dd1c80d..1148c0023337f4cd014b68929594ae731d12245d 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -103,6 +103,7 @@ set(unit_sources test_spike_store.cpp test_stats.cpp test_strprintf.cpp + test_supercell.cpp test_swcio.cpp test_synapses.cpp test_thread.cpp diff --git a/test/unit/test_domain_decomposition.cpp b/test/unit/test_domain_decomposition.cpp index a35cef75753a0d221b65dd389ed8761e2095a749..d19313b1a9cf7f84a924500593824e7f9868c26a 100644 --- a/test/unit/test_domain_decomposition.cpp +++ b/test/unit/test_domain_decomposition.cpp @@ -54,6 +54,62 @@ namespace { private: cell_size_type size_; }; + + class gap_recipe: public recipe { + public: + gap_recipe() {} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + mc_cell c; + c.add_soma(20); + c.add_gap_junction({0,1}); + return {std::move(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + switch (gid) { + case 0 : return {gap_junction_connection({13, 0}, {0, 0}, 0.1)}; + case 2 : return { + gap_junction_connection({7, 0}, {2, 0}, 0.1), + gap_junction_connection({11, 0}, {2, 0}, 0.1) + }; + case 3 : return { + gap_junction_connection({4, 0}, {3, 0}, 0.1), + gap_junction_connection({8, 0}, {3, 0}, 0.1) + }; + case 4 : return { + gap_junction_connection({3, 0}, {4, 0}, 0.1), + gap_junction_connection({8, 0}, {4, 0}, 0.1), + gap_junction_connection({9, 0}, {4, 0}, 0.1) + }; + case 7 : return { + gap_junction_connection({2, 0}, {7, 0}, 0.1), + gap_junction_connection({11, 0}, {7, 0}, 0.1) + };; + case 8 : return { + gap_junction_connection({3, 0}, {8, 0}, 0.1), + gap_junction_connection({4, 0}, {8, 0}, 0.1) + };; + case 9 : return {gap_junction_connection({4, 0}, {9, 0}, 0.1)}; + case 11 : return { + gap_junction_connection({2, 0}, {11, 0}, 0.1), + gap_junction_connection({7, 0}, {11, 0}, 0.1) + }; + case 13 : return {gap_junction_connection({0, 0}, {13, 0}, 0.1)}; + default : return {}; + } + } + + private: + cell_size_type size_ = 15; + }; } // test assumes one domain @@ -251,3 +307,49 @@ TEST(domain_decomposition, hints) { EXPECT_EQ(expected_c1d_groups, c1d_groups); EXPECT_EQ(expected_ss_groups, ss_groups); } + +TEST(domain_decomposition, compulsory_groups) +{ + proc_allocation resources; + resources.num_threads = 1; + resources.gpu_id = -1; // disable GPU if available + auto ctx = make_context(resources); + + auto R = gap_recipe(); + const auto D0 = partition_load_balance(R, ctx); + EXPECT_EQ(9u, D0.groups.size()); + + std::vector<std::vector<cell_gid_type>> expected_groups0 = + { {1}, {5}, {6}, {10}, {12}, {14}, {0, 13}, {2, 7, 11}, {3, 4, 8, 9} }; + + for (unsigned i = 0; i < 9u; i++) { + EXPECT_EQ(expected_groups0[i], D0.groups[i].gids); + } + + // Test different group_hints + partition_hint_map hints; + hints[cell_kind::cable1d_neuron].cpu_group_size = 3; + hints[cell_kind::cable1d_neuron].prefer_gpu = false; + + const auto D1 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(5u, D1.groups.size()); + + std::vector<std::vector<cell_gid_type>> expected_groups1 = + { {1, 5, 6}, {10, 12, 14}, {0, 13}, {2, 7, 11}, {3, 4, 8, 9} }; + + for (unsigned i = 0; i < 5u; i++) { + EXPECT_EQ(expected_groups1[i], D1.groups[i].gids); + } + + hints[cell_kind::cable1d_neuron].cpu_group_size = 20; + hints[cell_kind::cable1d_neuron].prefer_gpu = false; + + const auto D2 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(1u, D2.groups.size()); + + std::vector<cell_gid_type> expected_groups2 = + { 1, 5, 6, 10, 12, 14, 0, 13, 2, 7, 11, 3, 4, 8, 9 }; + + EXPECT_EQ(expected_groups2, D2.groups[0].gids); + +} diff --git a/test/unit/test_dry_run_context.cpp b/test/unit/test_dry_run_context.cpp index b4c9a6d034b004fe8fd4aa0f88ba6a3805012532..d0c5c91e34200202ac7b1570e7215eb774a07d1a 100644 --- a/test/unit/test_dry_run_context.cpp +++ b/test/unit/test_dry_run_context.cpp @@ -97,3 +97,23 @@ TEST(dry_run_context, gather_spikes) EXPECT_EQ(part[3], spikes.size()*3); EXPECT_EQ(part[4], spikes.size()*4); } + +TEST(dry_run_context, gather_gids) +{ + distributed_context_handle ctx = arb::make_dry_run_context(4, 4); + using gvec = std::vector<arb::cell_gid_type>; + + gvec gids = {0, 1, 2, 3}; + gvec gathered_gids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + auto s = ctx->gather_gids(gids); + auto& part = s.partition(); + + EXPECT_EQ(s.values(), gathered_gids); + EXPECT_EQ(part.size(), 5u); + EXPECT_EQ(part[0], 0u); + EXPECT_EQ(part[1], gids.size()); + EXPECT_EQ(part[2], gids.size()*2); + EXPECT_EQ(part[3], gids.size()*3); + EXPECT_EQ(part[4], gids.size()*4); +} diff --git a/test/unit/test_event_delivery.cpp b/test/unit/test_event_delivery.cpp index d521f8bf644cadf247ac3641843bc303a97efbbc..905d8cb428a6ea48b8b5c5bf66ceb6c43f4f01f1 100644 --- a/test/unit/test_event_delivery.cpp +++ b/test/unit/test_event_delivery.cpp @@ -6,11 +6,6 @@ // * Inject events one per cell in a given order, and confirm generated spikes // are in the same order. -// Note: this test anticipates the gap junction PR; the #if guards below -// will be removed when that PR is merged. - -#define NO_GJ_YET - #include <arbor/common_types.hpp> #include <arbor/domain_decomposition.hpp> #include <arbor/simulation.hpp> @@ -36,9 +31,7 @@ struct test_recipe: public n_mc_cell_recipe { c.add_soma(10.)->add_mechanism("pas"); c.add_synapse({0, 0.5}, "expsyn"); c.add_detector({0, 0.5}, -64); -#ifndef NO_GJ_YET c.add_gap_junction({0, 0.5}); -#endif return c; } @@ -100,15 +93,13 @@ TEST(mc_event_delivery, two_interleaved_groups) { EXPECT_EQ((gid_vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), spike_gids); } -#ifndef NO_GJ_YET - typedef std::vector<std::pair<unsigned, unsigned>> cell_gj_pairs; struct test_recipe_gj: public test_recipe { explicit test_recipe_gj(int n, cell_gj_pairs gj_pairs): test_recipe(n), gj_pairs_(std::move(gj_pairs)) {} - cell_size_type num_gap_junctions(cell_gid_type) const override { return 1; } + cell_size_type num_gap_junction_sites(cell_gid_type) const override { return 1; } std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type i) const override { std::vector<gap_junction_connection> gjs; @@ -126,5 +117,3 @@ TEST(mc_event_delivery, gj_reordered) { gid_vector spike_gids = run_test_sim(R, {{0, 1, 2, 3, 4}}); EXPECT_EQ((gid_vector{0, 1, 2, 3, 4}), spike_gids); } - -#endif // ndef NO_GJ_YET diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index 80e73927bb5483d3a551a784488eaf0c546eab72..84e2e2264e0fd248f83a03e745b7ac9b2a1d5e89 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -95,7 +95,7 @@ TEST(fvm_lowered, matrix_init) probe_association_map<probe_handle> probe_map; fvm_cell fvcell(context); - fvcell.initialize({0}, cable1d_recipe(cell), targets, probe_map); + fvcell.initialize({0}, {0}, cable1d_recipe(cell), targets, probe_map); auto& J = fvcell.*private_matrix_ptr; EXPECT_EQ(J.size(), 11u); @@ -140,7 +140,7 @@ TEST(fvm_lowered, target_handles) { probe_association_map<probe_handle> probe_map; fvm_cell fvcell(context); - fvcell.initialize({0, 1}, cable1d_recipe(cells), targets, probe_map); + fvcell.initialize({0, 1}, {0, 0}, cable1d_recipe(cells), targets, probe_map); mechanism* expsyn = find_mechanism(fvcell, "expsyn"); ASSERT_TRUE(expsyn); @@ -204,7 +204,7 @@ TEST(fvm_lowered, stimulus) { probe_association_map<probe_handle> probe_map; fvm_cell fvcell(context); - fvcell.initialize({0}, cable1d_recipe(cells), targets, probe_map); + fvcell.initialize({0}, {0}, cable1d_recipe(cells), targets, probe_map); mechanism* stim = find_mechanism(fvcell, "_builtin_stimulus"); ASSERT_TRUE(stim); @@ -298,7 +298,7 @@ TEST(fvm_lowered, derived_mechs) { execution_context context; fvm_cell fvcell(context); - fvcell.initialize({0, 1, 2}, rec, targets, probe_map); + fvcell.initialize({0, 1, 2}, {0, 0, 0}, rec, targets, probe_map); // Both mechanisms will have the same internal name, "test_kin1". @@ -404,7 +404,7 @@ TEST(fvm_lowered, weighted_write_ion) { probe_association_map<probe_handle> probe_map; fvm_cell fvcell(context); - fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); + fvcell.initialize({0}, {0}, cable1d_recipe(c), targets, probe_map); auto& state = *(fvcell.*private_state_ptr).get(); auto& ion = state.ion_data.at(ionKind::ca); @@ -446,3 +446,247 @@ TEST(fvm_lowered, weighted_write_ion) { EXPECT_EQ(expected_iconc, ion_iconc); } +TEST(fvm_lowered, gj_coords_simple) { + using pair = std::pair<int, int>; + + class gap_recipe: public recipe { + public: + gap_recipe() {} + + cell_size_type num_cells() const override { return n_; } + cell_kind get_cell_kind(cell_gid_type) const override { return cell_kind::cable1d_neuron; } + util::unique_any get_cell_description(cell_gid_type gid) const override { + return {}; + } + std::vector<arb::gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override{ + std::vector<gap_junction_connection> conns; + conns.push_back(gap_junction_connection({(gid+1)%2, 0}, {gid, 0}, 0.5)); + return conns; + } + + protected: + cell_size_type n_ = 2; + }; + + execution_context context; + fvm_cell fvcell(context); + + gap_recipe rec; + std::vector<mc_cell> cells; + mc_cell c, d; + c.add_soma(2.1); + c.add_cable(0, section_kind::dendrite, 0.3, 0.2, 10); + c.segment(1)->set_compartments(5); + c.add_gap_junction({1, 0.8}); + cells.push_back(std::move(c)); + + d.add_soma(2.4); + d.add_cable(0, section_kind::dendrite, 0.3, 0.2, 10); + d.segment(1)->set_compartments(2); + d.add_gap_junction({1, 1}); + cells.push_back(std::move(d)); + + fvm_discretization D = fvm_discretize(cells); + + std::vector<cell_gid_type> gids = {0, 1}; + auto GJ = fvcell.fvm_gap_junctions(cells, gids, rec, D); + + auto weight = [&](fvm_value_type g, fvm_index_type i){ + return g * 1e3 / D.cv_area[i]; + }; + + EXPECT_EQ(pair({4,8}), GJ[0].loc); + EXPECT_EQ(weight(0.5, 4), GJ[0].weight); + + EXPECT_EQ(pair({8,4}), GJ[1].loc); + EXPECT_EQ(weight(0.5, 8), GJ[1].weight); +} + +TEST(fvm_lowered, gj_coords_complex) { + using pair = std::pair<int, int>; + + class gap_recipe: public recipe { + public: + gap_recipe() {} + + cell_size_type num_cells() const override { return n_; } + cell_kind get_cell_kind(cell_gid_type) const override { return cell_kind::cable1d_neuron; } + util::unique_any get_cell_description(cell_gid_type gid) const override { + return {}; + } + std::vector<arb::gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override{ + std::vector<gap_junction_connection> conns; + switch (gid) { + case 0 : return { + gap_junction_connection({2, 0}, {0, 1}, 0.01), + gap_junction_connection({1, 0}, {0, 0}, 0.03), + gap_junction_connection({1, 1}, {0, 0}, 0.04) + }; + case 1 : return { + gap_junction_connection({0, 0}, {1, 0}, 0.03), + gap_junction_connection({0, 0}, {1, 1}, 0.04), + gap_junction_connection({2, 1}, {1, 2}, 0.02), + gap_junction_connection({2, 2}, {1, 3}, 0.01) + }; + case 2 : return { + gap_junction_connection({0, 1}, {2, 0}, 0.01), + gap_junction_connection({1, 2}, {2, 1}, 0.02), + gap_junction_connection({1, 3}, {2, 2}, 0.01) + }; + default : return {}; + } + return conns; + } + + protected: + cell_size_type n_ = 3; + }; + + execution_context context; + fvm_cell fvcell(context); + + gap_recipe rec; + mc_cell c0, c1, c2; + std::vector<mc_cell> cells; + + // Make 3 cells + c0.add_soma(2.1); + c0.add_cable(0, section_kind::dendrite, 0.3, 0.2, 8); + c0.segment(1)->set_compartments(4); + + c1.add_soma(1.4); + c1.add_cable(0, section_kind::dendrite, 0.3, 0.5, 12); + c1.segment(1)->set_compartments(6); + c1.add_cable(1, section_kind::dendrite, 0.3, 0.2, 9); + c1.segment(2)->set_compartments(3); + c1.add_cable(1, section_kind::dendrite, 0.2, 0.2, 15); + c1.segment(3)->set_compartments(5); + + c2.add_soma(2.9); + c2.add_cable(0, section_kind::dendrite, 0.3, 0.5, 4); + c2.segment(1)->set_compartments(2); + c2.add_cable(1, section_kind::dendrite, 0.4, 0.2, 6); + c2.segment(2)->set_compartments(2); + c2.add_cable(1, section_kind::dendrite, 0.1, 0.2, 8); + c2.segment(3)->set_compartments(2); + c2.add_cable(2, section_kind::dendrite, 0.2, 0.2, 4); + c2.segment(4)->set_compartments(2); + c2.add_cable(2, section_kind::dendrite, 0.2, 0.2, 4); + c2.segment(5)->set_compartments(2); + + // Add 5 gap junctions + c0.add_gap_junction({1, 1}); + c0.add_gap_junction({1, 0.5}); + + c1.add_gap_junction({2, 1}); + c1.add_gap_junction({1, 1}); + c1.add_gap_junction({1, 0.45}); + c1.add_gap_junction({1, 0.1}); + + c2.add_gap_junction({1, 0.5}); + c2.add_gap_junction({4, 1}); + c2.add_gap_junction({2, 1}); + + cells.push_back(std::move(c0)); + cells.push_back(std::move(c1)); + cells.push_back(std::move(c2)); + + fvm_discretization D = fvm_discretize(cells); + std::vector<cell_gid_type> gids = {0, 1, 2}; + + auto GJ = fvcell.fvm_gap_junctions(cells, gids, rec, D); + EXPECT_EQ(10u, GJ.size()); + + auto weight = [&](fvm_value_type g, fvm_index_type i){ + return g * 1e3 / D.cv_area[i]; + }; + + std::vector<pair> expected_loc = {{4, 14}, {4,11}, {2,21}, {14, 4}, {11,4} ,{8,28}, {6, 24}, {21,2}, {28,8}, {24, 6}}; + std::vector<double> expected_weight = { + weight(0.03, 4), weight(0.04, 4), weight(0.01, 2), weight(0.03, 14), weight(0.04, 11), + weight(0.02, 8), weight(0.01, 6), weight(0.01, 21), weight(0.02, 28), weight(0.01, 24) + }; + + for (unsigned i = 0; i < GJ.size(); i++) { + bool found = false; + for (unsigned j = 0; j < expected_loc.size(); j++) { + if (expected_loc[j].first == GJ[i].loc.first && expected_loc[j].second == GJ[i].loc.second) { + found = true; + EXPECT_EQ(expected_weight[j], GJ[i].weight); + break; + } + } + EXPECT_TRUE(found); + } + std::cout << std::endl; +} + +TEST(fvm_lowered, cell_group_gj) { + using pair = std::pair<int, int>; + + class gap_recipe: public recipe { + public: + gap_recipe() {} + + cell_size_type num_cells() const override { return n_; } + cell_kind get_cell_kind(cell_gid_type) const override { return cell_kind::cable1d_neuron; } + util::unique_any get_cell_description(cell_gid_type gid) const override { + return {}; + } + std::vector<arb::gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override{ + std::vector<gap_junction_connection> conns; + if (gid % 2 == 0) { + // connect 5 of the first 10 cells in a ring; connect 5 of the second 10 cells in a ring + auto next_cell = gid == 8 ? 0 : (gid == 18 ? 10 : gid + 2); + auto prev_cell = gid == 0 ? 8 : (gid == 10 ? 18 : gid - 2); + conns.push_back(gap_junction_connection({next_cell, 0}, {gid, 0}, 0.03)); + conns.push_back(gap_junction_connection({prev_cell, 0}, {gid, 0}, 0.03)); + } + return conns; + } + + protected: + cell_size_type n_ = 20; + }; + execution_context context; + fvm_cell fvcell(context); + + gap_recipe rec; + std::vector<mc_cell> cell_group0; + std::vector<mc_cell> cell_group1; + + // Make 20 cells + for (unsigned i = 0; i < 20; i++) { + mc_cell c; + c.add_soma(2.1); + if (i % 2 == 0) { + c.add_gap_junction({0, 1}); + if (i < 10) { + cell_group0.push_back(std::move(c)); + } + else { + cell_group1.push_back(std::move(c)); + } + } + + } + + std::vector<cell_gid_type> gids_cg0 = { 0, 2, 4, 6, 8}; + std::vector<cell_gid_type> gids_cg1 = {10,12,14,16,18}; + + fvm_discretization D0 = fvm_discretize(cell_group0); + fvm_discretization D1 = fvm_discretize(cell_group1); + + auto GJ0 = fvcell.fvm_gap_junctions(cell_group0, gids_cg0, rec, D0); + auto GJ1 = fvcell.fvm_gap_junctions(cell_group1, gids_cg1, rec, D1); + + EXPECT_EQ(10u, GJ0.size()); + EXPECT_EQ(10u, GJ1.size()); + + std::vector<pair> expected_loc = {{0, 1}, {0, 4}, {1, 2}, {1, 0}, {2, 3} ,{2, 1}, {3, 4}, {3, 2}, {4, 0}, {4, 3}}; + + for (unsigned i = 0; i < GJ0.size(); i++) { + EXPECT_EQ(expected_loc[i], GJ0[i].loc); + EXPECT_EQ(expected_loc[i], GJ1[i].loc); + } +} diff --git a/test/unit/test_local_context.cpp b/test/unit/test_local_context.cpp index 0f42e206f22eceefb672039097c32a5e278aa50e..3bda37fc1ae03c544f0a015f5cddeb67150fce10 100644 --- a/test/unit/test_local_context.cpp +++ b/test/unit/test_local_context.cpp @@ -77,3 +77,19 @@ TEST(local_context, gather_spikes) EXPECT_EQ(part[0], 0u); EXPECT_EQ(part[1], spikes.size()); } + +TEST(local_context, gather_gids) +{ + arb::local_context ctx; + using gvec = std::vector<arb::cell_gid_type>; + + gvec gids = {0, 1, 2, 3, 4}; + + auto s = ctx.gather_gids(gids); + + auto& part = s.partition(); + EXPECT_EQ(s.values(), gids); + EXPECT_EQ(part.size(), 2u); + EXPECT_EQ(part[0], 0u); + EXPECT_EQ(part[1], gids.size()); +} diff --git a/test/unit/test_mc_cell_group.cpp b/test/unit/test_mc_cell_group.cpp index 66326c222c19320af8f16378335fb8c33cfbbe66..70077f320385110dc913ac1b338822b401d5517d 100644 --- a/test/unit/test_mc_cell_group.cpp +++ b/test/unit/test_mc_cell_group.cpp @@ -28,6 +28,128 @@ namespace { return c; } + + class gap_recipe_0: public recipe { + public: + gap_recipe_0() {} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + mc_cell c; + c.add_soma(20); + c.add_gap_junction({0, 1}); + return {std::move(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + switch (gid) { + case 0 : + return {gap_junction_connection({5, 0}, {0, 0}, 0.1)}; + case 2 : + return { + gap_junction_connection({3, 0}, {2, 0}, 0.1), + }; + case 3 : + return { + gap_junction_connection({7, 0}, {3, 0}, 0.1), + gap_junction_connection({3, 0}, {2, 0}, 0.1) + }; + case 5 : + return {gap_junction_connection({5, 0}, {0, 0}, 0.1)}; + case 7 : + return { + gap_junction_connection({3, 0}, {7, 0}, 0.1), + }; + default : + return {}; + } + } + + private: + cell_size_type size_ = 12; + }; + + class gap_recipe_1: public recipe { + public: + gap_recipe_1() {} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + mc_cell c; + c.add_soma(20); + return {std::move(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + + private: + cell_size_type size_ = 12; + }; + + class gap_recipe_2: public recipe { + public: + gap_recipe_2() {} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + mc_cell c; + c.add_soma(20); + c.add_gap_junction({0,1}); + return {std::move(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable1d_neuron; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + switch (gid) { + case 0 : + return { + gap_junction_connection({2, 0}, {0, 0}, 0.1), + gap_junction_connection({3, 0}, {0, 0}, 0.1), + gap_junction_connection({5, 0}, {0, 0}, 0.1) + }; + case 2 : + return { + gap_junction_connection({0, 0}, {2, 0}, 0.1), + gap_junction_connection({3, 0}, {2, 0}, 0.1), + gap_junction_connection({5, 0}, {2, 0}, 0.1) + }; + case 3 : + return { + gap_junction_connection({0, 0}, {3, 0}, 0.1), + gap_junction_connection({2, 0}, {3, 0}, 0.1), + gap_junction_connection({5, 0}, {3, 0}, 0.1) + }; + case 5 : + return { + gap_junction_connection({2, 0}, {5, 0}, 0.1), + gap_junction_connection({3, 0}, {5, 0}, 0.1), + gap_junction_connection({0, 0}, {5, 0}, 0.1) + }; + default : + return {}; + } + } + + private: + cell_size_type size_ = 12; + }; + } ACCESS_BIND( @@ -84,3 +206,31 @@ TEST(mc_cell_group, sources) { } } } + +TEST(mc_cell_group, generated_gids_deps_) { + { + std::vector<cell_gid_type> gids = {11u, 5u, 2u, 3u, 0u, 8u, 7u}; + mc_cell_group group{gids, gap_recipe_0(), lowered_cell()}; + + std::vector<cell_gid_type> expected_gids = {11u, 5u, 0u, 2u, 3u, 7u, 8u}; + std::vector<int> expected_deps = {0, 2, 0, 3, 0, 0, 0}; + EXPECT_EQ(expected_gids, group.get_gids()); + EXPECT_EQ(expected_deps, group.get_dependencies()); + } + { + std::vector<cell_gid_type> gids = {11u, 5u, 2u, 3u, 0u, 8u, 7u}; + mc_cell_group group{gids, gap_recipe_1(), lowered_cell()}; + + std::vector<int> expected_deps = {0, 0, 0, 0, 0, 0, 0}; + EXPECT_EQ(gids, group.get_gids()); + EXPECT_EQ(expected_deps, group.get_dependencies()); + } + { + std::vector<cell_gid_type> gids = {5u, 2u, 3u, 0u}; + mc_cell_group group{gids, gap_recipe_2(), lowered_cell()}; + + std::vector<int> expected_deps = {4, 0, 0, 0}; + EXPECT_EQ(gids, group.get_gids()); + EXPECT_EQ(expected_deps, group.get_dependencies()); + } +} diff --git a/test/unit/test_mech_temperature.cpp b/test/unit/test_mech_temperature.cpp index 8d152d9990fe131ff4fc91113721d163cb973911..81a11394babbc6a42f3cde0c79af1354ad9911fa 100644 --- a/test/unit/test_mech_temperature.cpp +++ b/test/unit/test_mech_temperature.cpp @@ -24,9 +24,11 @@ void run_celsius_test() { fvm_size_type ncv = 3; std::vector<fvm_index_type> cv_to_cell(ncv, 0); + std::vector<fvm_gap_junction> gj = {}; + std::vector<int> deps = {0}; auto celsius_test = cat.instance<backend>("celsius_test"); auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_cell, celsius_test->data_alignment()); + ncell, cv_to_cell, deps, gj, celsius_test->data_alignment()); mechanism::layout layout; layout.weight.assign(ncv, 1.); diff --git a/test/unit/test_probe.cpp b/test/unit/test_probe.cpp index 42326f35fb269ceaab7914aba786c46468ea76e8..2e52784a75973654d9dcaf27241aec19d0376b70 100644 --- a/test/unit/test_probe.cpp +++ b/test/unit/test_probe.cpp @@ -40,7 +40,7 @@ TEST(probe, fvm_lowered_cell) { probe_association_map<probe_handle> probe_map; fvm_cell lcell(context); - lcell.initialize({0}, rec, targets, probe_map); + lcell.initialize({0}, {0}, rec, targets, probe_map); EXPECT_EQ(3u, rec.num_probes(0)); EXPECT_EQ(3u, probe_map.size()); diff --git a/test/unit/test_supercell.cpp b/test/unit/test_supercell.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4bebb2cd2769a18349758daae61599224e79324 --- /dev/null +++ b/test/unit/test_supercell.cpp @@ -0,0 +1,71 @@ +#include "../gtest.h" + +#include <cmath> +#include <tuple> +#include <vector> + +#include <arbor/constants.hpp> +#include <arbor/mechcat.hpp> +#include <arbor/util/optional.hpp> +#include <arbor/mc_cell.hpp> + +#include "backends/multicore/fvm.hpp" +#include "backends/multicore/mechanism.hpp" +#include "util/maputil.hpp" +#include "util/range.hpp" + +#include "common.hpp" +#include "mech_private_field_access.hpp" + +using namespace arb; + +using backend = ::arb::multicore::backend; +using shared_state = backend::shared_state; +using value_type = backend::value_type; +using size_type = backend::size_type; + +// Access to more mechanism protected data: + +ACCESS_BIND(const value_type* multicore::mechanism::*, vec_v_ptr, &multicore::mechanism::vec_v_) +ACCESS_BIND(value_type* multicore::mechanism::*, vec_i_ptr, &multicore::mechanism::vec_i_) + +TEST(supercell, sync_time_to) { + using value_type = multicore::backend::value_type; + using index_type = multicore::backend::index_type; + + int num_cell = 10; + + std::vector<fvm_gap_junction> gj = {}; + std::vector<index_type> deps = {4, 0, 0, 0, 3, 0, 0, 2, 0, 0}; + + shared_state state(num_cell, std::vector<index_type>(num_cell, 0), deps, gj, 1u); + + state.time_to = {0.3, 0.1, 0.2, 0.4, 0.5, 0.6, 0.1, 0.1, 0.6, 0.9}; + state.sync_time_to(); + + std::vector<value_type> expected = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9}; + + for (unsigned i = 0; i < state.time_to.size(); i++) { + EXPECT_EQ(expected[i], state.time_to[i]); + } + + state.time_dep = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + state.time_to = {0.3, 0.1, 0.2, 0.4, 0.5, 0.6, 0.1, 0.1, 0.6, 0.9}; + expected = {0.3, 0.1, 0.2, 0.4, 0.5, 0.6, 0.1, 0.1, 0.6, 0.9}; + state.sync_time_to(); + + for (unsigned i = 0; i < state.time_to.size(); i++) { + EXPECT_EQ(expected[i], state.time_to[i]); + } + + state.time_dep = {10, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + state.time_to = {0.3, 0.1, 0.2, 0.4, 0.5, 0.6, 0.1, 0.1, 0.6, 0.9}; + expected = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}; + state.sync_time_to(); + + for (unsigned i = 0; i < state.time_to.size(); i++) { + EXPECT_EQ(expected[i], state.time_to[i]); + } + +} + diff --git a/test/unit/test_synapses.cpp b/test/unit/test_synapses.cpp index 1428455d82203481e8376d225a99cde1fec7e080..5210a7d46ada3f052ba12d1cc6268ef84df70647 100644 --- a/test/unit/test_synapses.cpp +++ b/test/unit/test_synapses.cpp @@ -85,8 +85,9 @@ TEST(synapses, syn_basic_state) { auto exp2syn = unique_cast<multicore::mechanism>(global_default_catalogue().instance<backend>("exp2syn")); ASSERT_TRUE(exp2syn); + std::vector<fvm_gap_junction> gj = {}; auto align = std::max(expsyn->data_alignment(), exp2syn->data_alignment()); - shared_state state(num_cell, std::vector<index_type>(num_comp, 0), align); + shared_state state(num_cell, std::vector<index_type>(num_comp, 0), std::vector<index_type>(num_cell, 0), gj, align); state.reset(-65., constant::hh_squid_temp); fill(state.current_density, 1.0);