From 4f49fe38ac1b0730ba874844fa61e86545926e57 Mon Sep 17 00:00:00 2001
From: Sam Yates <halfflat@gmail.com>
Date: Sat, 3 Sep 2016 01:56:14 +0200
Subject: [PATCH] Bugfixes; tie multicells to multi-cell cell group.

* Include new option -g in miniapp to specify cell group size.
* Miniapp -r option generates a ring network.
* Modify EXPECTS() so that when assertions are disabled we
  avoid unused variable warnings.
---
 miniapp/io.cpp                               |  19 ++++
 miniapp/io.hpp                               |   2 +
 miniapp/miniapp.cpp                          |  17 +++-
 scripts/tsplot                               |  34 ++++---
 src/algorithms.hpp                           |  11 ++
 src/cell_group.hpp                           | 101 +++++++++++++++----
 src/communication/communicator.hpp           |  27 +++--
 src/fvm_multicell.hpp                        |  15 +--
 src/model.hpp                                |  65 +++++++-----
 src/util/debug.hpp                           |   3 +-
 src/util/partition.hpp                       |  10 ++
 src/util/partition_iterator.hpp              |   2 +
 src/util/range.hpp                           |  24 +++++
 src/util/transform.hpp                       |   2 +
 tests/unit/test_cell_group.cpp               |   7 +-
 tests/validation/validate_ball_and_stick.cpp |   7 +-
 tests/validation/validate_synapses.cpp       |   3 +-
 17 files changed, 260 insertions(+), 89 deletions(-)

diff --git a/miniapp/io.cpp b/miniapp/io.cpp
index 344fd561..c1d4f6fc 100644
--- a/miniapp/io.cpp
+++ b/miniapp/io.cpp
@@ -121,6 +121,8 @@ cl_options read_options(int argc, char** argv) {
         100.,       // tfinal
         0.025,      // dt
         false,      // all_to_all
+        false,      // ring
+        1,          // group_size
         false,      // probe_soma_only
         0.0,        // probe_ratio
         "trace_",   // trace_prefix
@@ -170,6 +172,11 @@ cl_options read_options(int argc, char** argv) {
             false, defopts.dt, "time", cmd);
         TCLAP::SwitchArg all_to_all_arg(
             "m","alltoall","all to all network", cmd, false);
+        TCLAP::SwitchArg ring_arg(
+            "r","ring","ring network", cmd, false);
+        TCLAP::ValueArg<uint32_t> group_size_arg(
+            "g", "group-size", "number of cells per cell group",
+            false, defopts.compartments_per_segment, "integer", cmd);
         TCLAP::ValueArg<double> probe_ratio_arg(
             "p", "probe-ratio", "proportion between 0 and 1 of cells to probe",
             false, defopts.probe_ratio, "proportion", cmd);
@@ -206,6 +213,8 @@ cl_options read_options(int argc, char** argv) {
                     update_option(options.dt, fopts, "dt");
                     update_option(options.tfinal, fopts, "tfinal");
                     update_option(options.all_to_all, fopts, "all_to_all");
+                    update_option(options.ring, fopts, "ring");
+                    update_option(options.group_size, fopts, "group_size");
                     update_option(options.probe_ratio, fopts, "probe_ratio");
                     update_option(options.probe_soma_only, fopts, "probe_soma_only");
                     update_option(options.trace_prefix, fopts, "trace_prefix");
@@ -239,12 +248,22 @@ cl_options read_options(int argc, char** argv) {
         update_option(options.tfinal, tfinal_arg);
         update_option(options.dt, dt_arg);
         update_option(options.all_to_all, all_to_all_arg);
+        update_option(options.ring, ring_arg);
+        update_option(options.group_size, group_size_arg);
         update_option(options.probe_ratio, probe_ratio_arg);
         update_option(options.probe_soma_only, probe_soma_only_arg);
         update_option(options.trace_prefix, trace_prefix_arg);
         update_option(options.trace_max_gid, trace_max_gid_arg);
         update_option(options.spike_file_output, spike_output_arg);
 
+        if (options.all_to_all && options.ring) {
+            throw usage_error("can specify at most one of --ring and --all-to-all");
+        }
+
+        if (options.group_size<1) {
+            throw usage_error("minimum of one cell per group");
+        }
+
         save_file = ofile_arg.getValue();
     }
     catch (TCLAP::ArgException& e) {
diff --git a/miniapp/io.hpp b/miniapp/io.hpp
index c18e688c..1c5bba81 100644
--- a/miniapp/io.hpp
+++ b/miniapp/io.hpp
@@ -21,6 +21,8 @@ struct cl_options {
     double tfinal;
     double dt;
     bool all_to_all;
+    bool ring;
+    uint32_t group_size;
     bool probe_soma_only;
     double probe_ratio;
     std::string trace_prefix;
diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp
index f0983d05..9ad224ef 100644
--- a/miniapp/miniapp.cpp
+++ b/miniapp/miniapp.cpp
@@ -28,8 +28,8 @@
 using namespace nest::mc;
 
 using global_policy = communication::global_policy;
-//using lowered_cell = fvm::fvm_multicell<double, cell_local_size_type>;
-using lowered_cell = fvm::fvm_cell<double, cell_local_size_type>;
+using lowered_cell = fvm::fvm_multicell<double, cell_local_size_type>;
+//using lowered_cell = fvm::fvm_cell<double, cell_local_size_type>;
 using model_type = model<lowered_cell>;
 using time_type = model_type::time_type;
 using sample_trace_type = sample_trace<time_type, model_type::value_type>;
@@ -66,8 +66,16 @@ int main(int argc, char** argv) {
         auto recipe = make_recipe(options, pdist);
         auto cell_range = distribute_cells(recipe->num_cells());
 
+        std::vector<cell_gid_type> group_divisions;
+        for (auto i = cell_range.first; i<cell_range.second; i+=options.group_size) {
+            group_divisions.push_back(i);
+        }
+        group_divisions.push_back(cell_range.second);
+
+        EXPECTS(group_divisions.front() == cell_range.first);
+        EXPECTS(group_divisions.back() == cell_range.second);
 
-        model_type m(*recipe, cell_range.first, cell_range.second);
+        model_type m(*recipe, util::partition_view(group_divisions));
 
         auto register_exporter = [] (const io::cl_options& options) {
             return
@@ -183,6 +191,9 @@ std::unique_ptr<recipe> make_recipe(const io::cl_options& options, const probe_d
     if (options.all_to_all) {
         return make_basic_kgraph_recipe(options.cells, p, pdist);
     }
+    else if (options.ring) {
+        return make_basic_ring_recipe(options.cells, p, pdist);
+    }
     else {
         return make_basic_rgraph_recipe(options.cells, p, pdist);
     }
diff --git a/scripts/tsplot b/scripts/tsplot
index 81a5da78..15977538 100755
--- a/scripts/tsplot
+++ b/scripts/tsplot
@@ -33,13 +33,19 @@ def parse_clargs():
     P.add_argument('inputs', metavar='FILE', nargs='+',
                    help='time series data in JSON format')
     P.add_argument('-t', '--trange', metavar='RANGE', dest='trange',
-                   type=parse_range_spec, 
+                   type=parse_range_spec,
                    help='restrict time axis to RANGE (see below)')
     P.add_argument('-g', '--group', metavar='KEY,...', dest='groupby',
-                   type=lambda s: s.split(','), 
+                   type=lambda s: s.split(','),
                    help='plot series with same KEYs on the same axes')
     P.add_argument('-o', '--output', metavar='FILE', dest='outfile',
                    help='save plot to file FILE')
+    P.add_argument('--dpi', metavar='NUM', dest='dpi',
+                   type=int,
+                   help='set dpi for output image')
+    P.add_argument('--scale', metavar='NUM', dest='scale',
+                   type=float,
+                   help='scale size of output image by NUM')
     P.add_argument('-x', '--exclude', metavar='NUM', dest='exclude',
                    type=float,
                    help='remove extreme points outside NUM times the 0.9-interquantile range of the median')
@@ -82,7 +88,7 @@ class TimeSeries:
         l_, lq, median, uq, u_ = np.percentile(yfinite, [0, 5.0, 50.0, 95.0, 100])
         lb = median - iqr_factor*(uq-lq)
         ub = median + iqr_factor*(uq-lq)
-        
+
         np_err_save = np.seterr(all='ignore')
         yex = np.ma.masked_where(np.isfinite(self.y)&(self.y<=ub)&(self.y>=lb), self.y)
         np.seterr(**np_err_save)
@@ -104,7 +110,7 @@ class TimeSeries:
 
     def units(self):
         return self.meta.get('units',"")
-        
+
     def trange(self):
         return (min(self.t), max(self.t))
 
@@ -155,7 +161,7 @@ def read_json_timeseries(source):
         ts_list.append(TimeSeries(times, jdata[key], **meta))
 
     return ts_list
-        
+
 def range_join(r, s):
     return (min(r[0], s[0]), max(r[1], s[1]))
 
@@ -236,7 +242,7 @@ def make_palette(cm_name, n, cmin=0, cmax=1):
                                M.cm.get_cmap(cm_name))
     return [smap.to_rgba((2*i+1)/float(2*n)) for i in xrange(n)]
 
-def plot_plots(plot_groups, save=None):
+def plot_plots(plot_groups, save=None, dpi=None, scale=None):
     nplots = len(plot_groups)
     plot_groups = sorted(plot_groups, key=lambda g: g.group_label())
 
@@ -268,7 +274,7 @@ def plot_plots(plot_groups, save=None):
                 lab = unit
 
             return lab
-            
+
         uniq_units = list(set([s.units() for s in group.series]))
         uniq_units.sort()
         if len(uniq_units)>2:
@@ -297,7 +303,7 @@ def plot_plots(plot_groups, save=None):
                 cm, n in zip(['hot', 'winter'],  [len(x) for x in series_by_unit])]
 
         lines = cycle(["-",(0,(3,1))])
-        
+
         first_plot = True
         for ui in xrange(len(uniq_units)):
             if not first_plot:
@@ -310,7 +316,7 @@ def plot_plots(plot_groups, save=None):
 
             plot.get_yaxis().get_major_formatter().set_useOffset(False)
             plot.get_yaxis().set_major_locator(M.ticker.MaxNLocator(nbins=6))
-            
+
             plot.set_xlim(trange)
 
             colours = cycle(palette[ui])
@@ -336,10 +342,14 @@ def plot_plots(plot_groups, save=None):
     axis_ymin = min([ax.get_position().ymin for ax in figure.axes])
     figure.text(0.5, axis_ymin - float(3)/figure.dpi, 'time', ha='center', va='center')
     if save:
-        figure.savefig(save)
+        if scale:
+            base = figure.get_size_inches()
+            figure.set_size_inches((base[0]*scale, base[1]*scale))
+
+        figure.savefig(save, dpi=dpi)
     else:
         P.show()
-        
+
 args = parse_clargs()
 tss = []
 for filename in args.inputs:
@@ -360,4 +370,4 @@ plots = gather_ts_plots(tss, groupby)
 if not args.outfile:
     M.interactive(False)
 
-plot_plots(plots, save=args.outfile)
+plot_plots(plots, save=args.outfile, dpi=args.dpi, scale=args.scale)
diff --git a/src/algorithms.hpp b/src/algorithms.hpp
index 37eeaf37..f626fa98 100644
--- a/src/algorithms.hpp
+++ b/src/algorithms.hpp
@@ -51,6 +51,17 @@ C make_index(C const& c)
     return out;
 }
 
+/// test for membership within half-open interval
+template <typename X, typename I, typename J>
+bool in_interval(const X& x, const I& lower, const J& upper) {
+    return x>=lower && x<upper;
+}
+
+template <typename X, typename I, typename J>
+bool in_interval(const X& x, const std::pair<I, J>& bounds) {
+    return x>=bounds.first && x<bounds.second;
+}
+
 /// works like std::is_sorted(), but with stronger condition that succesive
 /// elements must be greater than those before them
 template <typename C>
diff --git a/src/cell_group.hpp b/src/cell_group.hpp
index 8d8edcf5..d8fd2f2a 100644
--- a/src/cell_group.hpp
+++ b/src/cell_group.hpp
@@ -2,13 +2,16 @@
 
 #include <cstdint>
 #include <functional>
+#include <iterator>
 #include <vector>
 
+#include <algorithms.hpp>
 #include <cell.hpp>
 #include <common_types.hpp>
 #include <event_queue.hpp>
 #include <spike.hpp>
 #include <spike_source.hpp>
+#include <util/partition.hpp>
 #include <util/range.hpp>
 
 #include <profiling/profiler.hpp>
@@ -36,28 +39,37 @@ public:
 
     cell_group() = default;
 
-    cell_group(cell_gid_type gid, const cell& c) :
-        gid_base_{gid}
+    template <typename Cells>
+    cell_group(cell_gid_type first_gid, const Cells& cells):
+        gid_base_{first_gid}
     {
-        detector_handles_.resize(c.detectors().size());
-        target_handles_.resize(c.synapses().size());
-        probe_handles_.resize(c.probes().size());
+        // Create lookup structure for probe and target ids.
+        build_handle_partitions(cells);
+        std::size_t n_probes = probe_handle_divisions_.back();
+        std::size_t n_targets = target_handle_divisions_.back();
+        std::size_t n_detectors =
+            algorithms::sum(util::transform_view(cells, [](const cell& c) { return c.detectors().size(); }));
 
-        cell_.initialize(util::singleton_view(c), detector_handles_, target_handles_, probe_handles_);
+        // Allocate space to store handles.
+        detector_handles_.resize(n_detectors);
+        target_handles_.resize(n_targets);
+        probe_handles_.resize(n_probes);
 
-        // Create spike detectors and associate them with globally unique source ids,
-        // as specified by cell gid and cell-local zero-based index.
+        cell_.initialize(cells, detector_handles_, target_handles_, probe_handles_);
 
+        // Create spike detectors and associate them with globally unique source ids.
         cell_gid_type source_gid = gid_base_;
-        cell_lid_type source_lid = 0u;
-
         unsigned i = 0;
-        for (auto& d : c.detectors()) {
-            cell_member_type source_id{source_gid, source_lid++};
-
-            spike_sources_.push_back({
-                source_id, spike_detector_type(cell_, detector_handles_[i],  d.threshold, 0.f)
-            });
+        for (const auto& cell: cells) {
+            cell_lid_type source_lid = 0u;
+            for (auto& d: cell.detectors()) {
+                cell_member_type source_id{source_gid, source_lid++};
+
+                spike_sources_.push_back({
+                    source_id, spike_detector_type(cell_, detector_handles_[i++],  d.threshold, 0.f)
+                });
+            }
+            ++source_gid;
         }
     }
 
@@ -65,7 +77,7 @@ public:
         clear_spikes();
         clear_events();
         reset_samplers();
-        //initialize_cells();
+        cell_.reset();
         for (auto& spike_source: spike_sources_) {
             spike_source.source.reset(cell_, 0.f);
         }
@@ -112,13 +124,13 @@ public:
 
             // apply events
             if (next) {
-                auto handle = target_handles_[next->target.index];
+                auto handle = get_target_handle(next->target);
                 cell_.deliver_event(handle, next->weight);
                 // apply events that are due within some epsilon of the current
                 // time step. This should be a parameter. e.g. with for variable
                 // order time stepping, use the minimum possible time step size.
                 while(auto e = events_.pop_if_before(cell_.time()+dt/10.)) {
-                    auto handle = target_handles_[e->target.index];
+                    auto handle = get_target_handle(next->target);
                     cell_.deliver_event(handle, e->weight);
                 }
             }
@@ -151,9 +163,10 @@ public:
     }
 
     void add_sampler(cell_member_type probe_id, sampler_function s, time_type start_time = 0) {
-        EXPECTS(probe_id.gid==gid_base_);
+        auto handle = get_probe_handle(probe_id);
+
         auto sampler_index = uint32_t(samplers_.size());
-        samplers_.push_back({probe_handles_[probe_id.index], s});
+        samplers_.push_back({handle, s});
         sampler_start_times_.push_back(start_time);
         sample_events_.push({sampler_index, start_time});
     }
@@ -173,7 +186,7 @@ public:
     }
 
     value_type probe(cell_member_type probe_id) const {
-        return cell_.probe(probe_handles_[probe_id.index]);
+        return cell_.probe(get_probe_handle(probe_id));
     }
 
 private:
@@ -216,6 +229,50 @@ private:
 
     /// collection of samplers to be run against probes in this group
     std::vector<sampler_entry> samplers_;
+
+    /// lookup table for probe ids -> local probe handle indices
+    std::vector<std::size_t> probe_handle_divisions_;
+
+    /// lookup table for target ids -> local target handle indices
+    std::vector<std::size_t> target_handle_divisions_;
+
+    /// build handle index lookup tables
+    template <typename Cells>
+    void build_handle_partitions(const Cells& cells) {
+        auto probe_counts = util::transform_view(cells, [](const cell& c) { return c.probes().size(); });
+        auto target_counts = util::transform_view(cells, [](const cell& c) { return c.synapses().size(); });
+
+        make_partition(probe_handle_divisions_, probe_counts);
+        make_partition(target_handle_divisions_, target_counts);
+    }
+
+    /// use handle partition to get index from id
+    template <typename Divisions>
+    std::size_t handle_partition_lookup(const Divisions& divisions_, cell_member_type id) const {
+        // NB: without any assertion checking, this would just be:
+        // return divisions_[id.gid-gid_base_]+id.index;
+
+        EXPECTS(id.gid>=gid_base_);
+
+        auto handle_partition = util::partition_view(divisions_);
+        EXPECTS(id.gid-gid_base_<handle_partition.size());
+
+        auto ival = handle_partition[id.gid-gid_base_];
+        std::size_t i = ival.first + id.index;
+        EXPECTS(i<ival.second);
+
+        return i;
+    }
+
+    /// get probe handle from probe id
+    probe_handle get_probe_handle(cell_member_type probe_id) const {
+        return probe_handles_[handle_partition_lookup(probe_handle_divisions_, probe_id)];
+    }
+
+    /// get target handle from target id
+    target_handle get_target_handle(cell_member_type target_id) const {
+        return target_handles_[handle_partition_lookup(target_handle_divisions_, target_id)];
+    }
 };
 
 } // namespace mc
diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp
index 13d4a543..6d1377f4 100644
--- a/src/communication/communicator.hpp
+++ b/src/communication/communicator.hpp
@@ -13,6 +13,7 @@
 #include <spike.hpp>
 #include <util/debug.hpp>
 #include <util/double_buffer.hpp>
+#include <util/partition.hpp>
 
 namespace nest {
 namespace mc {
@@ -41,19 +42,18 @@ public:
     using event_queue =
         std::vector<postsynaptic_spike_event<time_type>>;
 
-    communicator() = default;
+    using gid_partition_type =
+        util::partition_range<std::vector<cell_gid_type>::const_iterator>;
 
-    // TODO
-    // for now, still assuming one-to-one association cells <-> groups,
-    // so that 'group' gids as represented by their first cell gid are
-    // contiguous.
-    communicator(id_type cell_from, id_type cell_to):
-        cell_gid_from_(cell_from), cell_gid_to_(cell_to)
+    communicator() {}
+
+    explicit communicator(gid_partition_type cell_gid_partition):
+        cell_gid_partition_(cell_gid_partition)
     {}
 
     cell_local_size_type num_groups_local() const
     {
-        return cell_gid_to_-cell_gid_from_;
+        return cell_gid_partition_.size();
     }
 
     void add_connection(connection_type con) {
@@ -63,7 +63,7 @@ public:
 
     /// returns true if the cell with gid is on the domain of the caller
     bool is_local_cell(id_type gid) const {
-        return gid>=cell_gid_from_ && gid<cell_gid_to_;
+        return algorithms::in_interval(gid, cell_gid_partition_.bounds());
     }
 
     /// builds the optimized data structure
@@ -135,9 +135,8 @@ public:
 
 private:
     std::size_t cell_group_index(cell_gid_type cell_gid) const {
-        // this will be more elaborate when there is more than one cell per cell group
-        EXPECTS(cell_gid>=cell_gid_from_ && cell_gid<cell_gid_to_);
-        return cell_gid-cell_gid_from_;
+        EXPECTS(is_local_cell(cell_gid));
+        return cell_gid_partition_.find(cell_gid)-cell_gid_partition_.begin();
     }
 
     std::vector<connection_type> connections_;
@@ -145,8 +144,8 @@ private:
     communication_policy_type communication_policy_;
 
     uint64_t num_spikes_ = 0u;
-    id_type cell_gid_from_;
-    id_type cell_gid_to_;
+
+    gid_partition_type cell_gid_partition_;
 };
 
 } // namespace communication
diff --git a/src/fvm_multicell.hpp b/src/fvm_multicell.hpp
index 9100a8fb..f885906f 100644
--- a/src/fvm_multicell.hpp
+++ b/src/fvm_multicell.hpp
@@ -373,20 +373,23 @@ void fvm_multicell<T, I>::initialize(
             EXPECTS(targets_count < targets_size);
 
             const auto& name = syn.mechanism.name();
-            std::size_t index = 0;
+            std::size_t syn_mech_index = 0;
             if (syn_mech_indices.count(name)==0) {
-                index = syn_mech_map.size();
-                syn_mech_indices[name] = index;
+                syn_mech_index = syn_mech_map.size();
+                syn_mech_indices[name] = syn_mech_index;
                 syn_mech_map.push_back(std::vector<size_type>{});
             }
             else {
-                index = syn_mech_indices[name];
+                syn_mech_index = syn_mech_indices[name];
             }
 
+            auto& map_entry = syn_mech_map[syn_mech_index];
+
             size_type syn_comp = comps.first+find_compartment_index(syn.location, graph);
-            syn_mech_map[index].push_back(syn_comp);
+            size_type syn_index = map_entry.size();
+            map_entry.push_back(syn_comp);
 
-            *target_hi++ = target_handle{index, syn_mech_map[index].size()};
+            *target_hi++ = target_handle{syn_mech_index, syn_index};
             ++targets_count;
         }
 
diff --git a/src/model.hpp b/src/model.hpp
index 5d0ce4fd..2298e5af 100644
--- a/src/model.hpp
+++ b/src/model.hpp
@@ -15,6 +15,8 @@
 #include <recipe.hpp>
 #include <thread_private_spike_store.hpp>
 #include <util/nop.hpp>
+#include <util/partition.hpp>
+#include <util/range.hpp>
 
 #include "trace_sampler.hpp"
 
@@ -37,27 +39,36 @@ public:
         probe_spec probe;
     };
 
-    model(const recipe& rec, cell_gid_type cell_from, cell_gid_type cell_to):
-        cell_from_(cell_from),
-        cell_to_(cell_to),
-        communicator_(cell_from, cell_to)
+    template <typename Iter>
+    model(const recipe& rec, const util::partition_range<Iter>& groups):
+        cell_group_divisions_(groups.divisions().begin(), groups.divisions().end())
     {
+        // set up communicator based on partition
+        communicator_ = communicator_type{gid_partition()};
+
         // generate the cell groups in parallel, with one task per cell group
-        cell_groups_ = std::vector<cell_group_type>{cell_to_-cell_from_};
+        cell_groups_ = std::vector<cell_group_type>{gid_partition().size()};
         threading::parallel_vector<probe_record> probes;
 
-        threading::parallel_for::apply(cell_from_, cell_to_,
+        threading::parallel_for::apply(0, cell_groups_.size(),
             [&](cell_gid_type i) {
                 PE("setup", "cells");
-                auto cell = rec.get_cell(i);
-                auto idx = i-cell_from_;
-                cell_groups_[idx] = cell_group_type(i, cell);
-
-                cell_lid_type j = 0;
-                for (const auto& probe: cell.probes()) {
-                    cell_member_type probe_id{i,j++};
-                    probes.push_back({probe_id, probe});
+
+                auto gids = gid_partition()[i];
+                std::vector<cell> cells{gids.second-gids.first};
+
+                for (auto gid: util::make_span(gids)) {
+                    auto i = gid-gids.first;
+                    cells[i] = rec.get_cell(gid);
+
+                    cell_lid_type j = 0;
+                    for (const auto& probe: cells[i].probes()) {
+                        cell_member_type probe_id{gid, j++};
+                        probes.push_back({probe_id, probe});
+                    }
                 }
+
+                cell_groups_[i] = cell_group_type(gids.first, cells);
                 PL(2);
             });
 
@@ -65,7 +76,7 @@ public:
         probes_.assign(probes.begin(), probes.end());
 
         // generate the network connections
-        for (cell_gid_type i=cell_from_; i<cell_to_; ++i) {
+        for (cell_gid_type i: util::make_span(gid_partition().bounds())) {
             for (const auto& cc: rec.connections_on(i)) {
                 connection<time_type> conn{cc.source, cc.dest, cc.weight, cc.delay};
                 communicator_.add_connection(conn);
@@ -187,11 +198,11 @@ public:
     }
 
     void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0) {
-        // TODO: translate probe_id.gid to appropriate group, but for now 1-1.
-        if (probe_id.gid<cell_from_ || probe_id.gid>=cell_to_) {
+        if (!algorithms::in_interval(probe_id.gid, gid_partition().bounds())) {
             return;
         }
-        cell_groups_[probe_id.gid-cell_from_].add_sampler(probe_id, f, tfrom);
+
+        cell_groups_[gid_partition().index(probe_id.gid)].add_sampler(probe_id, f, tfrom);
     }
 
     const std::vector<probe_record>& probes() const { return probes_; }
@@ -199,12 +210,14 @@ public:
     std::size_t num_spikes() const {
         return communicator_.num_spikes();
     }
+
     std::size_t num_groups() const {
         return cell_groups_.size();
     }
+
     std::size_t num_cells() const {
-        // TODO: fix when the assumption that there is one cell per cell group is no longer valid
-        return num_groups();
+        auto bounds = gid_partition().bounds();
+        return bounds.second-bounds.first;
     }
 
     // register a callback that will perform a export of the global
@@ -220,18 +233,22 @@ public:
     }
 
 private:
-    cell_gid_type cell_from_;
-    cell_gid_type cell_to_;
+    std::vector<cell_gid_type> cell_group_divisions_;
+
+    auto gid_partition() const -> decltype(util::partition_view(cell_group_divisions_)) {
+        return util::partition_view(cell_group_divisions_);
+    }
+
     time_type t_ = 0.;
     std::vector<cell_group_type> cell_groups_;
     communicator_type communicator_;
     std::vector<probe_record> probes_;
 
     using event_queue_type = typename communicator_type::event_queue;
-    util::double_buffer< std::vector<event_queue_type> > event_queues_;
+    util::double_buffer<std::vector<event_queue_type>> event_queues_;
 
     using local_spike_store_type = thread_private_spike_store<time_type>;
-    util::double_buffer< local_spike_store_type > local_spikes_;
+    util::double_buffer<local_spike_store_type> local_spikes_;
 
     spike_export_function global_export_callback_ = util::nop_function;
     spike_export_function local_export_callback_ = util::nop_function;
diff --git a/src/util/debug.hpp b/src/util/debug.hpp
index be2b3d51..f189852d 100644
--- a/src/util/debug.hpp
+++ b/src/util/debug.hpp
@@ -78,5 +78,6 @@ void debug_emit_trace(const char* file, int line, const char* varlist, const Arg
        (void)((condition) || \
        nest::mc::util::global_failed_assertion_handler(#condition, __FILE__, __LINE__, DEBUG_FUNCTION_NAME))
 #else
-    #define EXPECTS(condition)
+    #define EXPECTS(condition) \
+       (void)(false && (condition))
 #endif // def WITH_ASSERTIONS
diff --git a/src/util/partition.hpp b/src/util/partition.hpp
index b35f81de..376b769e 100644
--- a/src/util/partition.hpp
+++ b/src/util/partition.hpp
@@ -29,12 +29,17 @@ class partition_range: public range<partition_iterator<I>> {
 public:
     using typename base::iterator;
     using typename base::value_type;
+    using typename base::size_type;
     using base::left;
     using base::right;
     using base::front;
     using base::back;
     using base::empty;
 
+    static constexpr size_type npos = static_cast<size_type>(-1);
+
+    partition_range() = default;
+
     template <typename Seq>
     partition_range(const Seq& s): base{std::begin(s), upto(std::begin(s), std::end(s))} {
         EXPECTS(is_valid());
@@ -62,6 +67,11 @@ public:
         return iterator{std::prev(i)};
     }
 
+    size_type index(const inner_value_type& x) const {
+        iterator i = find(x);
+        return i==right? npos: i-left;
+    }
+
     // access to underlying divisions
     range<I> divisions() const {
         return {left.get(), std::next(right.get())};
diff --git a/src/util/partition_iterator.hpp b/src/util/partition_iterator.hpp
index f13b82f2..dae1475c 100644
--- a/src/util/partition_iterator.hpp
+++ b/src/util/partition_iterator.hpp
@@ -37,6 +37,8 @@ public:
     using pointer = const value_type*;
     using reference = const value_type&;
 
+    partition_iterator() = default;
+
     template <
         typename J,
         typename = enable_if_t<!std::is_same<decay_t<J>, partition_iterator>::value>
diff --git a/src/util/range.hpp b/src/util/range.hpp
index 6bcb4c25..89fcaddb 100644
--- a/src/util/range.hpp
+++ b/src/util/range.hpp
@@ -336,6 +336,30 @@ void append(Container &c, const Seq& seq) {
     c.insert(c.end(), seq.begin(), seq.end());
 }
 
+template <typename AssignableContainer, typename Seq>
+AssignableContainer& assign(AssignableContainer& c, const Seq& seq) {
+    c.assign(seq.begin(), seq.end());
+    return c;
+}
+
+template <typename Seq>
+range<typename sequence_traits<Seq>::iterator_type, typename sequence_traits<Seq>::sentinel_type>
+range_view(Seq& seq) {
+    return make_range(std::begin(seq), std::end(seq));
+}
+
+template <
+    typename Seq,
+    typename Iter = typename sequence_traits<Seq>::iterator_type,
+    typename Size = typename sequence_traits<Seq>::size_type
+>
+enable_if_t<is_forward_iterator<Iter>::value, range<Iter>>
+subrange_view(Seq& seq, Size bi, Size ei) {
+    Iter b = std::next(std::begin(seq), bi);
+    Iter e = std::next(b, ei-bi);
+    return make_range(b, e);
+}
+
 } // namespace util
 } // namespace mc
 } // namespace nest
diff --git a/src/util/transform.hpp b/src/util/transform.hpp
index d20b0061..a5ab0b34 100644
--- a/src/util/transform.hpp
+++ b/src/util/transform.hpp
@@ -37,6 +37,8 @@ public:
     using pointer = const value_type*;
     using reference = const value_type&;
 
+    transform_iterator() = default;
+
     template <typename J, typename G>
     transform_iterator(J&& c, G&& g): inner_(std::forward<J>(c)), f_(std::forward<G>(g)) {}
 
diff --git a/tests/unit/test_cell_group.cpp b/tests/unit/test_cell_group.cpp
index 99b46e56..9b3f6e6a 100644
--- a/tests/unit/test_cell_group.cpp
+++ b/tests/unit/test_cell_group.cpp
@@ -1,8 +1,9 @@
 #include "gtest.h"
 
+#include <cell_group.hpp>
 #include <common_types.hpp>
 #include <fvm_cell.hpp>
-#include <cell_group.hpp>
+#include <util/range.hpp>
 
 #include "../test_common_cells.hpp"
 
@@ -22,7 +23,7 @@ TEST(cell_group, test)
     using namespace nest::mc;
 
     using cell_group_type = cell_group<fvm::fvm_cell<double, cell_local_size_type>>;
-    auto group = cell_group_type{0, make_cell()};
+    auto group = cell_group_type{0, util::singleton_view(make_cell())};
 
     group.advance(50, 0.01);
 
@@ -44,7 +45,7 @@ TEST(cell_group, sources)
     cell.add_detector({1, 0.3}, 2.3);
 
     cell_gid_type first_gid = 37u;
-    auto group = cell_group_type{first_gid, cell};
+    auto group = cell_group_type{first_gid, util::singleton_view(cell)};
 
     // expect group sources to be lexicographically sorted by source id
     // with gids in cell group's range and indices starting from zero
diff --git a/tests/validation/validate_ball_and_stick.cpp b/tests/validation/validate_ball_and_stick.cpp
index da34cfcf..2a5ec45c 100644
--- a/tests/validation/validate_ball_and_stick.cpp
+++ b/tests/validation/validate_ball_and_stick.cpp
@@ -3,7 +3,8 @@
 
 #include <common_types.hpp>
 #include <cell.hpp>
-#include <fvm_cell.hpp>
+//#include <fvm_cell.hpp>
+#include <fvm_multicell.hpp>
 #include <util/range.hpp>
 
 #include "gtest.h"
@@ -72,7 +73,7 @@ TEST(ball_and_stick, neuron_baseline)
         }
     };
 
-    using fvm_cell = fvm::fvm_cell<double, cell_local_size_type>;
+    using fvm_cell = fvm::fvm_multicell<double, cell_local_size_type>;
     std::vector<fvm_cell::detector_handle> detectors(cell.detectors().size());
     std::vector<fvm_cell::target_handle> targets(cell.synapses().size());
     std::vector<fvm_cell::probe_handle> probes(cell.probes().size());
@@ -207,7 +208,7 @@ TEST(ball_and_3stick, neuron_baseline)
         }
     };
 
-    using fvm_cell = fvm::fvm_cell<double, cell_local_size_type>;
+    using fvm_cell = fvm::fvm_multicell<double, cell_local_size_type>;
     std::vector<fvm_cell::detector_handle> detectors(cell.detectors().size());
     std::vector<fvm_cell::target_handle> targets(cell.synapses().size());
     std::vector<fvm_cell::probe_handle> probes(cell.probes().size());
diff --git a/tests/validation/validate_synapses.cpp b/tests/validation/validate_synapses.cpp
index 9aa90abe..d1de3df0 100644
--- a/tests/validation/validate_synapses.cpp
+++ b/tests/validation/validate_synapses.cpp
@@ -8,6 +8,7 @@
 #include <cell_group.hpp>
 #include <fvm_cell.hpp>
 #include <mechanism_interface.hpp>
+#include <util/range.hpp>
 
 #include "gtest.h"
 #include "../test_util.hpp"
@@ -108,7 +109,7 @@ void run_neuron_baseline(const char* syn_type, const char* data_file)
         std::vector<std::vector<double>> v(2);
 
         // make the lowered finite volume cell
-        cell_group<lowered_cell> group(0, cell);
+        cell_group<lowered_cell> group(0, util::singleton_view(cell));
 
         // add the 3 spike events to the queue
         group.enqueue_events(synthetic_events);
-- 
GitLab