diff --git a/miniapp/io.cpp b/miniapp/io.cpp index 344fd56135b232bef3563d6a3cc61b570868a360..c1d4f6fc0297b08e9311cdf5bef683e0de720b87 100644 --- a/miniapp/io.cpp +++ b/miniapp/io.cpp @@ -121,6 +121,8 @@ cl_options read_options(int argc, char** argv) { 100., // tfinal 0.025, // dt false, // all_to_all + false, // ring + 1, // group_size false, // probe_soma_only 0.0, // probe_ratio "trace_", // trace_prefix @@ -170,6 +172,11 @@ cl_options read_options(int argc, char** argv) { false, defopts.dt, "time", cmd); TCLAP::SwitchArg all_to_all_arg( "m","alltoall","all to all network", cmd, false); + TCLAP::SwitchArg ring_arg( + "r","ring","ring network", cmd, false); + TCLAP::ValueArg<uint32_t> group_size_arg( + "g", "group-size", "number of cells per cell group", + false, defopts.compartments_per_segment, "integer", cmd); TCLAP::ValueArg<double> probe_ratio_arg( "p", "probe-ratio", "proportion between 0 and 1 of cells to probe", false, defopts.probe_ratio, "proportion", cmd); @@ -206,6 +213,8 @@ cl_options read_options(int argc, char** argv) { update_option(options.dt, fopts, "dt"); update_option(options.tfinal, fopts, "tfinal"); update_option(options.all_to_all, fopts, "all_to_all"); + update_option(options.ring, fopts, "ring"); + update_option(options.group_size, fopts, "group_size"); update_option(options.probe_ratio, fopts, "probe_ratio"); update_option(options.probe_soma_only, fopts, "probe_soma_only"); update_option(options.trace_prefix, fopts, "trace_prefix"); @@ -239,12 +248,22 @@ cl_options read_options(int argc, char** argv) { update_option(options.tfinal, tfinal_arg); update_option(options.dt, dt_arg); update_option(options.all_to_all, all_to_all_arg); + update_option(options.ring, ring_arg); + update_option(options.group_size, group_size_arg); update_option(options.probe_ratio, probe_ratio_arg); update_option(options.probe_soma_only, probe_soma_only_arg); update_option(options.trace_prefix, trace_prefix_arg); update_option(options.trace_max_gid, trace_max_gid_arg); update_option(options.spike_file_output, spike_output_arg); + if (options.all_to_all && options.ring) { + throw usage_error("can specify at most one of --ring and --all-to-all"); + } + + if (options.group_size<1) { + throw usage_error("minimum of one cell per group"); + } + save_file = ofile_arg.getValue(); } catch (TCLAP::ArgException& e) { diff --git a/miniapp/io.hpp b/miniapp/io.hpp index c18e688c33ceff0c3966b5d66ae8eb1680bc48a1..1c5bba8163596dcf5bdcbed10113c901a1c816b0 100644 --- a/miniapp/io.hpp +++ b/miniapp/io.hpp @@ -21,6 +21,8 @@ struct cl_options { double tfinal; double dt; bool all_to_all; + bool ring; + uint32_t group_size; bool probe_soma_only; double probe_ratio; std::string trace_prefix; diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp index f0983d0514f6eeddff2778aef45e4e0b78b5f9f1..9ad224efd7c1597ef848559b02921911cfe19b15 100644 --- a/miniapp/miniapp.cpp +++ b/miniapp/miniapp.cpp @@ -28,8 +28,8 @@ using namespace nest::mc; using global_policy = communication::global_policy; -//using lowered_cell = fvm::fvm_multicell<double, cell_local_size_type>; -using lowered_cell = fvm::fvm_cell<double, cell_local_size_type>; +using lowered_cell = fvm::fvm_multicell<double, cell_local_size_type>; +//using lowered_cell = fvm::fvm_cell<double, cell_local_size_type>; using model_type = model<lowered_cell>; using time_type = model_type::time_type; using sample_trace_type = sample_trace<time_type, model_type::value_type>; @@ -66,8 +66,16 @@ int main(int argc, char** argv) { auto recipe = make_recipe(options, pdist); auto cell_range = distribute_cells(recipe->num_cells()); + std::vector<cell_gid_type> group_divisions; + for (auto i = cell_range.first; i<cell_range.second; i+=options.group_size) { + group_divisions.push_back(i); + } + group_divisions.push_back(cell_range.second); + + EXPECTS(group_divisions.front() == cell_range.first); + EXPECTS(group_divisions.back() == cell_range.second); - model_type m(*recipe, cell_range.first, cell_range.second); + model_type m(*recipe, util::partition_view(group_divisions)); auto register_exporter = [] (const io::cl_options& options) { return @@ -183,6 +191,9 @@ std::unique_ptr<recipe> make_recipe(const io::cl_options& options, const probe_d if (options.all_to_all) { return make_basic_kgraph_recipe(options.cells, p, pdist); } + else if (options.ring) { + return make_basic_ring_recipe(options.cells, p, pdist); + } else { return make_basic_rgraph_recipe(options.cells, p, pdist); } diff --git a/scripts/tsplot b/scripts/tsplot index 81a5da78c06a24d71f727431ef237a18f2a13b0b..15977538a3ce5d977c7559e536d2e03af845e4e4 100755 --- a/scripts/tsplot +++ b/scripts/tsplot @@ -33,13 +33,19 @@ def parse_clargs(): P.add_argument('inputs', metavar='FILE', nargs='+', help='time series data in JSON format') P.add_argument('-t', '--trange', metavar='RANGE', dest='trange', - type=parse_range_spec, + type=parse_range_spec, help='restrict time axis to RANGE (see below)') P.add_argument('-g', '--group', metavar='KEY,...', dest='groupby', - type=lambda s: s.split(','), + type=lambda s: s.split(','), help='plot series with same KEYs on the same axes') P.add_argument('-o', '--output', metavar='FILE', dest='outfile', help='save plot to file FILE') + P.add_argument('--dpi', metavar='NUM', dest='dpi', + type=int, + help='set dpi for output image') + P.add_argument('--scale', metavar='NUM', dest='scale', + type=float, + help='scale size of output image by NUM') P.add_argument('-x', '--exclude', metavar='NUM', dest='exclude', type=float, help='remove extreme points outside NUM times the 0.9-interquantile range of the median') @@ -82,7 +88,7 @@ class TimeSeries: l_, lq, median, uq, u_ = np.percentile(yfinite, [0, 5.0, 50.0, 95.0, 100]) lb = median - iqr_factor*(uq-lq) ub = median + iqr_factor*(uq-lq) - + np_err_save = np.seterr(all='ignore') yex = np.ma.masked_where(np.isfinite(self.y)&(self.y<=ub)&(self.y>=lb), self.y) np.seterr(**np_err_save) @@ -104,7 +110,7 @@ class TimeSeries: def units(self): return self.meta.get('units',"") - + def trange(self): return (min(self.t), max(self.t)) @@ -155,7 +161,7 @@ def read_json_timeseries(source): ts_list.append(TimeSeries(times, jdata[key], **meta)) return ts_list - + def range_join(r, s): return (min(r[0], s[0]), max(r[1], s[1])) @@ -236,7 +242,7 @@ def make_palette(cm_name, n, cmin=0, cmax=1): M.cm.get_cmap(cm_name)) return [smap.to_rgba((2*i+1)/float(2*n)) for i in xrange(n)] -def plot_plots(plot_groups, save=None): +def plot_plots(plot_groups, save=None, dpi=None, scale=None): nplots = len(plot_groups) plot_groups = sorted(plot_groups, key=lambda g: g.group_label()) @@ -268,7 +274,7 @@ def plot_plots(plot_groups, save=None): lab = unit return lab - + uniq_units = list(set([s.units() for s in group.series])) uniq_units.sort() if len(uniq_units)>2: @@ -297,7 +303,7 @@ def plot_plots(plot_groups, save=None): cm, n in zip(['hot', 'winter'], [len(x) for x in series_by_unit])] lines = cycle(["-",(0,(3,1))]) - + first_plot = True for ui in xrange(len(uniq_units)): if not first_plot: @@ -310,7 +316,7 @@ def plot_plots(plot_groups, save=None): plot.get_yaxis().get_major_formatter().set_useOffset(False) plot.get_yaxis().set_major_locator(M.ticker.MaxNLocator(nbins=6)) - + plot.set_xlim(trange) colours = cycle(palette[ui]) @@ -336,10 +342,14 @@ def plot_plots(plot_groups, save=None): axis_ymin = min([ax.get_position().ymin for ax in figure.axes]) figure.text(0.5, axis_ymin - float(3)/figure.dpi, 'time', ha='center', va='center') if save: - figure.savefig(save) + if scale: + base = figure.get_size_inches() + figure.set_size_inches((base[0]*scale, base[1]*scale)) + + figure.savefig(save, dpi=dpi) else: P.show() - + args = parse_clargs() tss = [] for filename in args.inputs: @@ -360,4 +370,4 @@ plots = gather_ts_plots(tss, groupby) if not args.outfile: M.interactive(False) -plot_plots(plots, save=args.outfile) +plot_plots(plots, save=args.outfile, dpi=args.dpi, scale=args.scale) diff --git a/src/algorithms.hpp b/src/algorithms.hpp index 37eeaf370acd54ef326f70cc4ae17e46f1c6f2b0..f626fa98ddde880cf4fd71a1589dc7ba012764a4 100644 --- a/src/algorithms.hpp +++ b/src/algorithms.hpp @@ -51,6 +51,17 @@ C make_index(C const& c) return out; } +/// test for membership within half-open interval +template <typename X, typename I, typename J> +bool in_interval(const X& x, const I& lower, const J& upper) { + return x>=lower && x<upper; +} + +template <typename X, typename I, typename J> +bool in_interval(const X& x, const std::pair<I, J>& bounds) { + return x>=bounds.first && x<bounds.second; +} + /// works like std::is_sorted(), but with stronger condition that succesive /// elements must be greater than those before them template <typename C> diff --git a/src/cell_group.hpp b/src/cell_group.hpp index 8d8edcf56ccdfd1b46769fa49f63ffdcc2ac01a4..d8fd2f2ad18edf9e401658cf3c2372190e0226c7 100644 --- a/src/cell_group.hpp +++ b/src/cell_group.hpp @@ -2,13 +2,16 @@ #include <cstdint> #include <functional> +#include <iterator> #include <vector> +#include <algorithms.hpp> #include <cell.hpp> #include <common_types.hpp> #include <event_queue.hpp> #include <spike.hpp> #include <spike_source.hpp> +#include <util/partition.hpp> #include <util/range.hpp> #include <profiling/profiler.hpp> @@ -36,28 +39,37 @@ public: cell_group() = default; - cell_group(cell_gid_type gid, const cell& c) : - gid_base_{gid} + template <typename Cells> + cell_group(cell_gid_type first_gid, const Cells& cells): + gid_base_{first_gid} { - detector_handles_.resize(c.detectors().size()); - target_handles_.resize(c.synapses().size()); - probe_handles_.resize(c.probes().size()); + // Create lookup structure for probe and target ids. + build_handle_partitions(cells); + std::size_t n_probes = probe_handle_divisions_.back(); + std::size_t n_targets = target_handle_divisions_.back(); + std::size_t n_detectors = + algorithms::sum(util::transform_view(cells, [](const cell& c) { return c.detectors().size(); })); - cell_.initialize(util::singleton_view(c), detector_handles_, target_handles_, probe_handles_); + // Allocate space to store handles. + detector_handles_.resize(n_detectors); + target_handles_.resize(n_targets); + probe_handles_.resize(n_probes); - // Create spike detectors and associate them with globally unique source ids, - // as specified by cell gid and cell-local zero-based index. + cell_.initialize(cells, detector_handles_, target_handles_, probe_handles_); + // Create spike detectors and associate them with globally unique source ids. cell_gid_type source_gid = gid_base_; - cell_lid_type source_lid = 0u; - unsigned i = 0; - for (auto& d : c.detectors()) { - cell_member_type source_id{source_gid, source_lid++}; - - spike_sources_.push_back({ - source_id, spike_detector_type(cell_, detector_handles_[i], d.threshold, 0.f) - }); + for (const auto& cell: cells) { + cell_lid_type source_lid = 0u; + for (auto& d: cell.detectors()) { + cell_member_type source_id{source_gid, source_lid++}; + + spike_sources_.push_back({ + source_id, spike_detector_type(cell_, detector_handles_[i++], d.threshold, 0.f) + }); + } + ++source_gid; } } @@ -65,7 +77,7 @@ public: clear_spikes(); clear_events(); reset_samplers(); - //initialize_cells(); + cell_.reset(); for (auto& spike_source: spike_sources_) { spike_source.source.reset(cell_, 0.f); } @@ -112,13 +124,13 @@ public: // apply events if (next) { - auto handle = target_handles_[next->target.index]; + auto handle = get_target_handle(next->target); cell_.deliver_event(handle, next->weight); // apply events that are due within some epsilon of the current // time step. This should be a parameter. e.g. with for variable // order time stepping, use the minimum possible time step size. while(auto e = events_.pop_if_before(cell_.time()+dt/10.)) { - auto handle = target_handles_[e->target.index]; + auto handle = get_target_handle(next->target); cell_.deliver_event(handle, e->weight); } } @@ -151,9 +163,10 @@ public: } void add_sampler(cell_member_type probe_id, sampler_function s, time_type start_time = 0) { - EXPECTS(probe_id.gid==gid_base_); + auto handle = get_probe_handle(probe_id); + auto sampler_index = uint32_t(samplers_.size()); - samplers_.push_back({probe_handles_[probe_id.index], s}); + samplers_.push_back({handle, s}); sampler_start_times_.push_back(start_time); sample_events_.push({sampler_index, start_time}); } @@ -173,7 +186,7 @@ public: } value_type probe(cell_member_type probe_id) const { - return cell_.probe(probe_handles_[probe_id.index]); + return cell_.probe(get_probe_handle(probe_id)); } private: @@ -216,6 +229,50 @@ private: /// collection of samplers to be run against probes in this group std::vector<sampler_entry> samplers_; + + /// lookup table for probe ids -> local probe handle indices + std::vector<std::size_t> probe_handle_divisions_; + + /// lookup table for target ids -> local target handle indices + std::vector<std::size_t> target_handle_divisions_; + + /// build handle index lookup tables + template <typename Cells> + void build_handle_partitions(const Cells& cells) { + auto probe_counts = util::transform_view(cells, [](const cell& c) { return c.probes().size(); }); + auto target_counts = util::transform_view(cells, [](const cell& c) { return c.synapses().size(); }); + + make_partition(probe_handle_divisions_, probe_counts); + make_partition(target_handle_divisions_, target_counts); + } + + /// use handle partition to get index from id + template <typename Divisions> + std::size_t handle_partition_lookup(const Divisions& divisions_, cell_member_type id) const { + // NB: without any assertion checking, this would just be: + // return divisions_[id.gid-gid_base_]+id.index; + + EXPECTS(id.gid>=gid_base_); + + auto handle_partition = util::partition_view(divisions_); + EXPECTS(id.gid-gid_base_<handle_partition.size()); + + auto ival = handle_partition[id.gid-gid_base_]; + std::size_t i = ival.first + id.index; + EXPECTS(i<ival.second); + + return i; + } + + /// get probe handle from probe id + probe_handle get_probe_handle(cell_member_type probe_id) const { + return probe_handles_[handle_partition_lookup(probe_handle_divisions_, probe_id)]; + } + + /// get target handle from target id + target_handle get_target_handle(cell_member_type target_id) const { + return target_handles_[handle_partition_lookup(target_handle_divisions_, target_id)]; + } }; } // namespace mc diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index 13d4a54303da74aefbcab0bc4d0ecc3636d7f472..6d1377f49213871fe5a8146217babb9866cdd6bc 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -13,6 +13,7 @@ #include <spike.hpp> #include <util/debug.hpp> #include <util/double_buffer.hpp> +#include <util/partition.hpp> namespace nest { namespace mc { @@ -41,19 +42,18 @@ public: using event_queue = std::vector<postsynaptic_spike_event<time_type>>; - communicator() = default; + using gid_partition_type = + util::partition_range<std::vector<cell_gid_type>::const_iterator>; - // TODO - // for now, still assuming one-to-one association cells <-> groups, - // so that 'group' gids as represented by their first cell gid are - // contiguous. - communicator(id_type cell_from, id_type cell_to): - cell_gid_from_(cell_from), cell_gid_to_(cell_to) + communicator() {} + + explicit communicator(gid_partition_type cell_gid_partition): + cell_gid_partition_(cell_gid_partition) {} cell_local_size_type num_groups_local() const { - return cell_gid_to_-cell_gid_from_; + return cell_gid_partition_.size(); } void add_connection(connection_type con) { @@ -63,7 +63,7 @@ public: /// returns true if the cell with gid is on the domain of the caller bool is_local_cell(id_type gid) const { - return gid>=cell_gid_from_ && gid<cell_gid_to_; + return algorithms::in_interval(gid, cell_gid_partition_.bounds()); } /// builds the optimized data structure @@ -135,9 +135,8 @@ public: private: std::size_t cell_group_index(cell_gid_type cell_gid) const { - // this will be more elaborate when there is more than one cell per cell group - EXPECTS(cell_gid>=cell_gid_from_ && cell_gid<cell_gid_to_); - return cell_gid-cell_gid_from_; + EXPECTS(is_local_cell(cell_gid)); + return cell_gid_partition_.find(cell_gid)-cell_gid_partition_.begin(); } std::vector<connection_type> connections_; @@ -145,8 +144,8 @@ private: communication_policy_type communication_policy_; uint64_t num_spikes_ = 0u; - id_type cell_gid_from_; - id_type cell_gid_to_; + + gid_partition_type cell_gid_partition_; }; } // namespace communication diff --git a/src/fvm_multicell.hpp b/src/fvm_multicell.hpp index 9100a8fb56789be83fb950f7ff666e8dc6c0ffa8..f885906f1ce52902eca89403d27cde9e25559924 100644 --- a/src/fvm_multicell.hpp +++ b/src/fvm_multicell.hpp @@ -373,20 +373,23 @@ void fvm_multicell<T, I>::initialize( EXPECTS(targets_count < targets_size); const auto& name = syn.mechanism.name(); - std::size_t index = 0; + std::size_t syn_mech_index = 0; if (syn_mech_indices.count(name)==0) { - index = syn_mech_map.size(); - syn_mech_indices[name] = index; + syn_mech_index = syn_mech_map.size(); + syn_mech_indices[name] = syn_mech_index; syn_mech_map.push_back(std::vector<size_type>{}); } else { - index = syn_mech_indices[name]; + syn_mech_index = syn_mech_indices[name]; } + auto& map_entry = syn_mech_map[syn_mech_index]; + size_type syn_comp = comps.first+find_compartment_index(syn.location, graph); - syn_mech_map[index].push_back(syn_comp); + size_type syn_index = map_entry.size(); + map_entry.push_back(syn_comp); - *target_hi++ = target_handle{index, syn_mech_map[index].size()}; + *target_hi++ = target_handle{syn_mech_index, syn_index}; ++targets_count; } diff --git a/src/model.hpp b/src/model.hpp index 5d0ce4fd8bcbe010cec3629e4b5d1adde38bcd96..2298e5af8365151654b2ca8a266f0dce2f82041b 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -15,6 +15,8 @@ #include <recipe.hpp> #include <thread_private_spike_store.hpp> #include <util/nop.hpp> +#include <util/partition.hpp> +#include <util/range.hpp> #include "trace_sampler.hpp" @@ -37,27 +39,36 @@ public: probe_spec probe; }; - model(const recipe& rec, cell_gid_type cell_from, cell_gid_type cell_to): - cell_from_(cell_from), - cell_to_(cell_to), - communicator_(cell_from, cell_to) + template <typename Iter> + model(const recipe& rec, const util::partition_range<Iter>& groups): + cell_group_divisions_(groups.divisions().begin(), groups.divisions().end()) { + // set up communicator based on partition + communicator_ = communicator_type{gid_partition()}; + // generate the cell groups in parallel, with one task per cell group - cell_groups_ = std::vector<cell_group_type>{cell_to_-cell_from_}; + cell_groups_ = std::vector<cell_group_type>{gid_partition().size()}; threading::parallel_vector<probe_record> probes; - threading::parallel_for::apply(cell_from_, cell_to_, + threading::parallel_for::apply(0, cell_groups_.size(), [&](cell_gid_type i) { PE("setup", "cells"); - auto cell = rec.get_cell(i); - auto idx = i-cell_from_; - cell_groups_[idx] = cell_group_type(i, cell); - - cell_lid_type j = 0; - for (const auto& probe: cell.probes()) { - cell_member_type probe_id{i,j++}; - probes.push_back({probe_id, probe}); + + auto gids = gid_partition()[i]; + std::vector<cell> cells{gids.second-gids.first}; + + for (auto gid: util::make_span(gids)) { + auto i = gid-gids.first; + cells[i] = rec.get_cell(gid); + + cell_lid_type j = 0; + for (const auto& probe: cells[i].probes()) { + cell_member_type probe_id{gid, j++}; + probes.push_back({probe_id, probe}); + } } + + cell_groups_[i] = cell_group_type(gids.first, cells); PL(2); }); @@ -65,7 +76,7 @@ public: probes_.assign(probes.begin(), probes.end()); // generate the network connections - for (cell_gid_type i=cell_from_; i<cell_to_; ++i) { + for (cell_gid_type i: util::make_span(gid_partition().bounds())) { for (const auto& cc: rec.connections_on(i)) { connection<time_type> conn{cc.source, cc.dest, cc.weight, cc.delay}; communicator_.add_connection(conn); @@ -187,11 +198,11 @@ public: } void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0) { - // TODO: translate probe_id.gid to appropriate group, but for now 1-1. - if (probe_id.gid<cell_from_ || probe_id.gid>=cell_to_) { + if (!algorithms::in_interval(probe_id.gid, gid_partition().bounds())) { return; } - cell_groups_[probe_id.gid-cell_from_].add_sampler(probe_id, f, tfrom); + + cell_groups_[gid_partition().index(probe_id.gid)].add_sampler(probe_id, f, tfrom); } const std::vector<probe_record>& probes() const { return probes_; } @@ -199,12 +210,14 @@ public: std::size_t num_spikes() const { return communicator_.num_spikes(); } + std::size_t num_groups() const { return cell_groups_.size(); } + std::size_t num_cells() const { - // TODO: fix when the assumption that there is one cell per cell group is no longer valid - return num_groups(); + auto bounds = gid_partition().bounds(); + return bounds.second-bounds.first; } // register a callback that will perform a export of the global @@ -220,18 +233,22 @@ public: } private: - cell_gid_type cell_from_; - cell_gid_type cell_to_; + std::vector<cell_gid_type> cell_group_divisions_; + + auto gid_partition() const -> decltype(util::partition_view(cell_group_divisions_)) { + return util::partition_view(cell_group_divisions_); + } + time_type t_ = 0.; std::vector<cell_group_type> cell_groups_; communicator_type communicator_; std::vector<probe_record> probes_; using event_queue_type = typename communicator_type::event_queue; - util::double_buffer< std::vector<event_queue_type> > event_queues_; + util::double_buffer<std::vector<event_queue_type>> event_queues_; using local_spike_store_type = thread_private_spike_store<time_type>; - util::double_buffer< local_spike_store_type > local_spikes_; + util::double_buffer<local_spike_store_type> local_spikes_; spike_export_function global_export_callback_ = util::nop_function; spike_export_function local_export_callback_ = util::nop_function; diff --git a/src/util/debug.hpp b/src/util/debug.hpp index be2b3d51e4d7bf9b5d4c5fb9b7c78855c5e4b4c5..f189852dc1bd49685920af2c0effdabf5a9ac132 100644 --- a/src/util/debug.hpp +++ b/src/util/debug.hpp @@ -78,5 +78,6 @@ void debug_emit_trace(const char* file, int line, const char* varlist, const Arg (void)((condition) || \ nest::mc::util::global_failed_assertion_handler(#condition, __FILE__, __LINE__, DEBUG_FUNCTION_NAME)) #else - #define EXPECTS(condition) + #define EXPECTS(condition) \ + (void)(false && (condition)) #endif // def WITH_ASSERTIONS diff --git a/src/util/partition.hpp b/src/util/partition.hpp index b35f81de8fa1f40217e331dd96e09a11eba1c174..376b769ebd4c43a3aa62ee39764ca8b5a90efaf7 100644 --- a/src/util/partition.hpp +++ b/src/util/partition.hpp @@ -29,12 +29,17 @@ class partition_range: public range<partition_iterator<I>> { public: using typename base::iterator; using typename base::value_type; + using typename base::size_type; using base::left; using base::right; using base::front; using base::back; using base::empty; + static constexpr size_type npos = static_cast<size_type>(-1); + + partition_range() = default; + template <typename Seq> partition_range(const Seq& s): base{std::begin(s), upto(std::begin(s), std::end(s))} { EXPECTS(is_valid()); @@ -62,6 +67,11 @@ public: return iterator{std::prev(i)}; } + size_type index(const inner_value_type& x) const { + iterator i = find(x); + return i==right? npos: i-left; + } + // access to underlying divisions range<I> divisions() const { return {left.get(), std::next(right.get())}; diff --git a/src/util/partition_iterator.hpp b/src/util/partition_iterator.hpp index f13b82f25d43c604680ff4e96b3efd6936e98aea..dae1475ce91382801d6d827fd3f66f30da9c7a8d 100644 --- a/src/util/partition_iterator.hpp +++ b/src/util/partition_iterator.hpp @@ -37,6 +37,8 @@ public: using pointer = const value_type*; using reference = const value_type&; + partition_iterator() = default; + template < typename J, typename = enable_if_t<!std::is_same<decay_t<J>, partition_iterator>::value> diff --git a/src/util/range.hpp b/src/util/range.hpp index 6bcb4c2567d79873cbaffb19f9ca642bb4bddc61..89fcaddbfc8b0b6161f97bd5f3f2cb0cc2d20545 100644 --- a/src/util/range.hpp +++ b/src/util/range.hpp @@ -336,6 +336,30 @@ void append(Container &c, const Seq& seq) { c.insert(c.end(), seq.begin(), seq.end()); } +template <typename AssignableContainer, typename Seq> +AssignableContainer& assign(AssignableContainer& c, const Seq& seq) { + c.assign(seq.begin(), seq.end()); + return c; +} + +template <typename Seq> +range<typename sequence_traits<Seq>::iterator_type, typename sequence_traits<Seq>::sentinel_type> +range_view(Seq& seq) { + return make_range(std::begin(seq), std::end(seq)); +} + +template < + typename Seq, + typename Iter = typename sequence_traits<Seq>::iterator_type, + typename Size = typename sequence_traits<Seq>::size_type +> +enable_if_t<is_forward_iterator<Iter>::value, range<Iter>> +subrange_view(Seq& seq, Size bi, Size ei) { + Iter b = std::next(std::begin(seq), bi); + Iter e = std::next(b, ei-bi); + return make_range(b, e); +} + } // namespace util } // namespace mc } // namespace nest diff --git a/src/util/transform.hpp b/src/util/transform.hpp index d20b006102984e8dc40adb528cd5b4388aab21c5..a5ab0b348a61fb0f3a174058dc31fdb128776e24 100644 --- a/src/util/transform.hpp +++ b/src/util/transform.hpp @@ -37,6 +37,8 @@ public: using pointer = const value_type*; using reference = const value_type&; + transform_iterator() = default; + template <typename J, typename G> transform_iterator(J&& c, G&& g): inner_(std::forward<J>(c)), f_(std::forward<G>(g)) {} diff --git a/tests/unit/test_cell_group.cpp b/tests/unit/test_cell_group.cpp index 99b46e564d66edc03643fba738dcd134ef9ac90c..9b3f6e6a0b13f7338558889e40e47d7fe7ed99c8 100644 --- a/tests/unit/test_cell_group.cpp +++ b/tests/unit/test_cell_group.cpp @@ -1,8 +1,9 @@ #include "gtest.h" +#include <cell_group.hpp> #include <common_types.hpp> #include <fvm_cell.hpp> -#include <cell_group.hpp> +#include <util/range.hpp> #include "../test_common_cells.hpp" @@ -22,7 +23,7 @@ TEST(cell_group, test) using namespace nest::mc; using cell_group_type = cell_group<fvm::fvm_cell<double, cell_local_size_type>>; - auto group = cell_group_type{0, make_cell()}; + auto group = cell_group_type{0, util::singleton_view(make_cell())}; group.advance(50, 0.01); @@ -44,7 +45,7 @@ TEST(cell_group, sources) cell.add_detector({1, 0.3}, 2.3); cell_gid_type first_gid = 37u; - auto group = cell_group_type{first_gid, cell}; + auto group = cell_group_type{first_gid, util::singleton_view(cell)}; // expect group sources to be lexicographically sorted by source id // with gids in cell group's range and indices starting from zero diff --git a/tests/validation/validate_ball_and_stick.cpp b/tests/validation/validate_ball_and_stick.cpp index da34cfcff48d5bae6829edf81140b322c7a7c107..2a5ec45ccd575dc5f12ad79f565996d584b41a22 100644 --- a/tests/validation/validate_ball_and_stick.cpp +++ b/tests/validation/validate_ball_and_stick.cpp @@ -3,7 +3,8 @@ #include <common_types.hpp> #include <cell.hpp> -#include <fvm_cell.hpp> +//#include <fvm_cell.hpp> +#include <fvm_multicell.hpp> #include <util/range.hpp> #include "gtest.h" @@ -72,7 +73,7 @@ TEST(ball_and_stick, neuron_baseline) } }; - using fvm_cell = fvm::fvm_cell<double, cell_local_size_type>; + using fvm_cell = fvm::fvm_multicell<double, cell_local_size_type>; std::vector<fvm_cell::detector_handle> detectors(cell.detectors().size()); std::vector<fvm_cell::target_handle> targets(cell.synapses().size()); std::vector<fvm_cell::probe_handle> probes(cell.probes().size()); @@ -207,7 +208,7 @@ TEST(ball_and_3stick, neuron_baseline) } }; - using fvm_cell = fvm::fvm_cell<double, cell_local_size_type>; + using fvm_cell = fvm::fvm_multicell<double, cell_local_size_type>; std::vector<fvm_cell::detector_handle> detectors(cell.detectors().size()); std::vector<fvm_cell::target_handle> targets(cell.synapses().size()); std::vector<fvm_cell::probe_handle> probes(cell.probes().size()); diff --git a/tests/validation/validate_synapses.cpp b/tests/validation/validate_synapses.cpp index 9aa90abef6f6178a6285f9a1d5db2e9d1cb1d4f1..d1de3df06945b33466ccbbc105d412196c0a5f0b 100644 --- a/tests/validation/validate_synapses.cpp +++ b/tests/validation/validate_synapses.cpp @@ -8,6 +8,7 @@ #include <cell_group.hpp> #include <fvm_cell.hpp> #include <mechanism_interface.hpp> +#include <util/range.hpp> #include "gtest.h" #include "../test_util.hpp" @@ -108,7 +109,7 @@ void run_neuron_baseline(const char* syn_type, const char* data_file) std::vector<std::vector<double>> v(2); // make the lowered finite volume cell - cell_group<lowered_cell> group(0, cell); + cell_group<lowered_cell> group(0, util::singleton_view(cell)); // add the 3 spike events to the queue group.enqueue_events(synthetic_events);