From 520d27a183a4f17db090b486912ee69b49ef77f6 Mon Sep 17 00:00:00 2001
From: bcumming <bcumming@cscs.ch>
Date: Fri, 24 Jun 2016 11:37:08 +0200
Subject: [PATCH] Add MPI support to miniapp

Added MPI support to the miniapp
- added ring model and all-to-all model
- started refactoring general model setup steps into a model class in
  miniapp
- optional -DWITH_MPI to turn on MPI support
---
 .gitignore                                 |   2 +
 miniapp/CMakeLists.txt                     |   1 +
 miniapp/io.cpp                             |   8 +-
 miniapp/miniapp.cpp                        | 303 +++++++++++++++------
 miniapp/mpi.cpp                            |  59 ++++
 src/cell_group.hpp                         |  46 ++--
 src/communication/communicator.hpp         | 111 ++++----
 src/communication/connection.hpp           |  34 +--
 src/communication/mpi.hpp                  |  72 +++--
 src/communication/mpi_global_policy.hpp    |  22 +-
 src/communication/serial_global_policy.hpp |   8 +-
 src/communication/spike.hpp                |  15 +-
 src/communication/spike_source.hpp         |  61 +----
 13 files changed, 480 insertions(+), 262 deletions(-)
 create mode 100644 miniapp/mpi.cpp

diff --git a/.gitignore b/.gitignore
index a18a8b7d..e9d54cb7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -64,3 +64,5 @@ external/modparser-update
 external/tmp
 mechanisms/*.hpp
 
+# build path
+build
diff --git a/miniapp/CMakeLists.txt b/miniapp/CMakeLists.txt
index 196ce7ae..a3f8447e 100644
--- a/miniapp/CMakeLists.txt
+++ b/miniapp/CMakeLists.txt
@@ -1,6 +1,7 @@
 set(HEADERS
 )
 set(MINIAPP_SOURCES
+    mpi.cpp
     io.cpp
     miniapp.cpp
 )
diff --git a/miniapp/io.cpp b/miniapp/io.cpp
index 90498c4c..e1a059ae 100644
--- a/miniapp/io.cpp
+++ b/miniapp/io.cpp
@@ -6,14 +6,12 @@ namespace io {
 
 // read simulation options from json file with name fname
 // for now this is just a placeholder
-options read_options(std::string fname)
-{
+options read_options(std::string fname) {
     // 10 cells, 1 synapses per cell, 10 compartments per segment
-    return {5, 1, 100};
+    return {100, 1, 100};
 }
 
-std::ostream& operator<<(std::ostream& o, const options& opt)
-{
+std::ostream& operator<<(std::ostream& o, const options& opt) {
     o << "simultion options:\n";
     o << "  cells                : " << opt.cells << "\n";
     o << "  compartments/segment : " << opt.compartments_per_segment << "\n";
diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp
index 5fbc8b36..5f2082e1 100644
--- a/miniapp/miniapp.cpp
+++ b/miniapp/miniapp.cpp
@@ -10,15 +10,98 @@
 #include "profiling/profiler.hpp"
 #include "communication/communicator.hpp"
 #include "communication/serial_global_policy.hpp"
+#include "communication/mpi_global_policy.hpp"
 
 using namespace nest;
 
 using real_type = double;
 using index_type = int;
+using id_type = uint32_t;
 using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>;
 using cell_group   = mc::cell_group<numeric_cell>;
+#ifdef WITH_MPI
+using communicator_type =
+    mc::communication::communicator<mc::communication::mpi_global_policy>;
+#else
 using communicator_type =
     mc::communication::communicator<mc::communication::serial_global_policy>;
+#endif
+
+struct model {
+    communicator_type communicator;
+    std::vector<cell_group> cell_groups;
+
+    double time_init;
+    double time_network;
+    double time_solve;
+    double time_comms;
+    void print_times() const {
+        std::cout << "initialization took " << time_init << " s\n";
+        std::cout << "network        took " << time_network << " s\n";
+        std::cout << "solve          took " << time_solve << " s\n";
+        std::cout << "comms          took " << time_comms << " s\n";
+    }
+
+    int num_groups() const {
+        return cell_groups.size();
+    }
+
+    void run(double tfinal, double dt) {
+        auto t = 0.;
+        auto delta = communicator.min_delay();
+        time_solve = 0.;
+        time_comms = 0.;
+        while(t<tfinal) {
+            auto start_solve = mc::util::timer_type::tic();
+            mc::threading::parallel_for::apply(
+                0, num_groups(),
+                [&](int i) {
+                    cell_groups[i].enqueue_events(communicator.queue(i));
+                    cell_groups[i].advance(t+delta, dt);
+                    communicator.add_spikes(cell_groups[i].spikes());
+                    cell_groups[i].clear_spikes();
+                }
+            );
+            time_solve += mc::util::timer_type::toc(start_solve);
+
+            auto start_comms = mc::util::timer_type::tic();
+            communicator.exchange();
+            time_comms += mc::util::timer_type::toc(start_comms);
+
+            t += delta;
+        }
+    }
+
+    void init_communicator() {
+        // calculate the source and synapse distribution serially
+        std::vector<id_type> target_counts(num_groups());
+        std::vector<id_type> source_counts(num_groups());
+        for (auto i=0; i<num_groups(); ++i) {
+            target_counts[i] = cell_groups[i].cell().synapses()->size();
+            source_counts[i] = cell_groups[i].spike_sources().size();
+        }
+
+        target_map = mc::algorithms::make_index(target_counts);
+        source_map = mc::algorithms::make_index(source_counts);
+
+        //  create connections
+        communicator = communicator_type(num_groups(), target_counts);
+    }
+
+    void update_gids() {
+        auto com_policy = communicator.communication_policy();
+        auto global_source_map = com_policy.make_map(source_map.back());
+        auto domain_idx = communicator.domain_id();
+        for (auto i=0; i<num_groups(); ++i) {
+            cell_groups[i].set_source_gids(source_map[i]+global_source_map[domain_idx]);
+            cell_groups[i].set_target_gids(target_map[i]+communicator.target_gid_from_group_lid(0));
+        }
+    }
+
+    // TODO : only stored here because init_communicator() and update_gids() are split
+    std::vector<id_type> source_map;
+    std::vector<id_type> target_map;
+};
 
 // define some global model parameters
 namespace parameters {
@@ -27,7 +110,7 @@ namespace synapses {
     constexpr double delay  = 5.0;  // ms
 
     // connection weight
-    constexpr double weight = 0.05;  // uS
+    constexpr double weight = 0.005;  // uS
 }
 }
 
@@ -36,152 +119,208 @@ namespace synapses {
 ///////////////////////////////////////
 
 /// make a single abstract cell
-mc::cell make_cell(int compartments_per_segment);
+mc::cell make_cell(int compartments_per_segment, int num_synapses);
 
 /// do basic setup (initialize global state, print banner, etc)
-void setup();
+void setup(int argc, char** argv);
 
 /// helper function for initializing cells
 cell_group make_lowered_cell(int cell_index, const mc::cell& c);
 
+/// models
+void ring_model(nest::mc::io::options& opt, model& m);
+void all_to_all_model(nest::mc::io::options& opt, model& m);
+
 ///////////////////////////////////////
 // main
 ///////////////////////////////////////
-int main(void) {
+int main(int argc, char** argv) {
 
-    setup();
+    setup(argc, argv);
 
     // read parameters
     mc::io::options opt;
     try {
         opt = mc::io::read_options("");
-        std::cout << opt << "\n";
+        if (mc::mpi::rank()==0) {
+            std::cout << opt << "\n";
+        }
     }
     catch (std::exception e) {
         std::cerr << e.what() << std::endl;
         exit(1);
     }
 
+    model m;
+    //ring_model(opt, m);
+    all_to_all_model(opt, m);
+
     /////////////////////////////////////////////////////
-    //  make cells
+    //  time stepping
     /////////////////////////////////////////////////////
+    auto tfinal = 50.;
+    auto dt = 0.01;
+
+    auto id = m.communicator.domain_id();
+
+    if (!id) {
+        m.communicator.add_spike({0, 5});
+    }
+
+    m.run(tfinal, dt);
+    if (!id) {
+        m.print_times();
+        std::cout << "there were " << m.communicator.num_spikes() << " spikes\n";
+    }
+
+#ifdef SPLAT
+    if (!mc::mpi::rank()) {
+        //for (auto i=0u; i<m.cell_groups.size(); ++i) {
+        m.cell_groups[0].splat("cell0.txt");
+        m.cell_groups[1].splat("cell1.txt");
+        m.cell_groups[2].splat("cell2.txt");
+        //}
+    }
+#endif
+
+#ifdef WITH_MPI
+    mc::mpi::finalize();
+#endif
+}
+
+///////////////////////////////////////
+// models
+///////////////////////////////////////
+
+void ring_model(nest::mc::io::options& opt, model& m) {
+    //
+    //  make cells
+    //
 
     // make a basic cell
-    auto basic_cell = make_cell(opt.compartments_per_segment);
+    auto basic_cell = make_cell(opt.compartments_per_segment, 1);
 
     // make a vector for storing all of the cells
     auto start_init = mc::util::timer_type::tic();
-    std::vector<cell_group> cell_groups(opt.cells);
+    m.cell_groups = std::vector<cell_group>(opt.cells);
 
     // initialize the cells in parallel
     mc::threading::parallel_for::apply(
         0, opt.cells,
         [&](int i) {
             // initialize cell
-            cell_groups[i] = make_lowered_cell(i, basic_cell);
+            m.cell_groups[i] = make_lowered_cell(i, basic_cell);
         }
     );
-    auto time_init = mc::util::timer_type::toc(start_init);
+    m.time_init = mc::util::timer_type::toc(start_init);
 
-    /////////////////////////////////////////////////////
+    //
     //  network creation
-    /////////////////////////////////////////////////////
-
-    // calculate the source and synapse distribution serially
+    //
     auto start_network = mc::util::timer_type::tic();
-    std::vector<uint32_t> target_counts(opt.cells);
-    std::vector<uint32_t> source_counts(opt.cells);
-    for (auto i=0; i<opt.cells; ++i) {
-        target_counts[i] = cell_groups[i].cell().synapses()->size();
-        source_counts[i] = cell_groups[i].spike_sources().size();
-    }
-
-    auto target_map = mc::algorithms::make_index(target_counts);
-    auto source_map = mc::algorithms::make_index(source_counts);
+    m.init_communicator();
 
-    //  create connections
-    communicator_type communicator(opt.cells, target_counts);
-    for(auto i=0u; i<(uint32_t)opt.cells; ++i) {
-        communicator.add_connection({
+    for (auto i=0u; i<(id_type)opt.cells; ++i) {
+        m.communicator.add_connection({
             i, (i+1)%opt.cells,
             parameters::synapses::weight, parameters::synapses::delay
         });
     }
-    communicator.construct();
 
-    auto global_source_map =
-        communicator.communication_policy().make_map(source_map.back());
-    auto domain_idx = communicator.communication_policy().id();
-    for(auto i=0u; i<(uint32_t)opt.cells; ++i) {
-        cell_groups[i].set_source_gids(source_map[i]+global_source_map[domain_idx]);
-        cell_groups[i].set_target_lids(target_map[i]);
-    }
+    m.communicator.construct();
 
-    auto time_network = mc::util::timer_type::toc(start_network);
+    m.update_gids();
 
-    /////////////////////////////////////////////////////
-    //  time stepping
-    /////////////////////////////////////////////////////
-    auto start_simulation = mc::util::timer_type::tic();
+    m.time_network = mc::util::timer_type::toc(start_network);
+}
 
-    auto tfinal = 20.;
-    auto t =  0.;
-    auto dt = 0.01;
-    auto delta = communicator.min_delay();
-
-    communicator.add_spike({opt.cells-1u, 5});
-
-    while(t<tfinal) {
-        mc::threading::parallel_for::apply(
-            0, opt.cells,
-            [&](int i) {
-                /*if(communicator.queue(i).size()) {
-                    std::cout << ":: delivering events to group " << i << "\n";
-                    std::cout << "  " << communicator.queue(i) << "\n";
-                }*/
-                cell_groups[i].enqueue_events(communicator.queue(i));
-                cell_groups[i].advance(t+delta, dt);
-                communicator.add_spikes(cell_groups[i].spikes());
-                cell_groups[i].clear_spikes();
-            }
-        );
+void all_to_all_model(nest::mc::io::options& opt, model& m) {
+    //
+    //  make cells
+    //
+    auto timer = mc::util::timer_type();
 
-        communicator.exchange();
+    // make a basic cell
+    auto basic_cell = make_cell(opt.compartments_per_segment, opt.cells-1);
 
-        t += delta;
+    // make a vector for storing all of the cells
+    auto start_init = timer.tic();
+    id_type ncell_global = opt.cells;
+    id_type ncell_local  = ncell_global / m.communicator.num_domains();
+    int remainder = ncell_global - (ncell_local*m.communicator.num_domains());
+    if (m.communicator.domain_id()<remainder) {
+        ncell_local++;
     }
 
-    for(auto i=0u; i<cell_groups.size(); ++i) {
-        cell_groups[i].splat("cell"+std::to_string(i)+".txt");
+    m.cell_groups = std::vector<cell_group>(ncell_local);
+
+    // initialize the cells in parallel
+    mc::threading::parallel_for::apply(
+        0, ncell_local,
+        [&](int i) {
+            m.cell_groups[i] = make_lowered_cell(i, basic_cell);
+        }
+    );
+    m.time_init = timer.toc(start_init);
+
+    //
+    //  network creation
+    //
+    auto start_network = timer.tic();
+    m.init_communicator();
+
+    // lid is local cell/group id
+    for (auto lid=0u; lid<ncell_local; ++lid) {
+        auto target = m.communicator.target_gid_from_group_lid(lid);
+        auto gid = m.communicator.group_gid_from_group_lid(lid);
+        // tid is global cell/group id
+        for (auto tid=0u; tid<ncell_global; ++tid) {
+            if (gid!=tid) {
+                m.communicator.add_connection({
+                    tid, target++,
+                    parameters::synapses::weight, parameters::synapses::delay
+                });
+            }
+        }
     }
 
-    auto time_simulation = mc::util::timer_type::toc(start_simulation);
+    m.communicator.construct();
 
-    std::cout << "initialization took " << time_init << " s\n";
-    std::cout << "network        took " << time_network << " s\n";
-    std::cout << "simulation     took " << time_simulation << " s\n";
-    std::cout << "performed " << int(tfinal/dt) << " time steps\n";
+    m.update_gids();
+
+    m.time_network = timer.toc(start_network);
 }
 
 ///////////////////////////////////////
 // function definitions
 ///////////////////////////////////////
 
-void setup()
-{
+void setup(int argc, char** argv) {
+#ifdef WITH_MPI
+    mc::mpi::init(&argc, &argv);
+
+    // print banner
+    if (mc::mpi::rank()==0) {
+        std::cout << "====================\n";
+        std::cout << "  starting miniapp\n";
+        std::cout << "  - " << mc::threading::description() << " threading support\n";
+        std::cout << "  - MPI support\n";
+        std::cout << "====================\n";
+    }
+#else
     // print banner
     std::cout << "====================\n";
     std::cout << "  starting miniapp\n";
     std::cout << "  - " << mc::threading::description() << " threading support\n";
     std::cout << "====================\n";
+#endif
 
     // setup global state for the mechanisms
     mc::mechanisms::setup_mechanism_helpers();
 }
 
 // make a high level cell description for use in simulation
-mc::cell make_cell(int compartments_per_segment)
-{
+mc::cell make_cell(int compartments_per_segment, int num_synapses) {
     nest::mc::cell cell;
 
     // Soma with diameter 12.6157 um and HH channel
@@ -191,10 +330,10 @@ mc::cell make_cell(int compartments_per_segment)
     // add dendrite of length 200 um and diameter 1 um with passive channel
     std::vector<mc::cable_segment*> dendrites;
     dendrites.push_back(cell.add_cable(0, mc::segmentKind::dendrite, 0.5, 0.5, 200));
-    //dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100));
-    //dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100));
+    dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100));
+    dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100));
 
-    for(auto d : dendrites) {
+    for (auto d : dendrites) {
         d->add_mechanism(mc::pas_parameters());
         d->set_compartments(compartments_per_segment);
         d->mechanism("membrane").set("r_L", 100);
@@ -204,13 +343,15 @@ mc::cell make_cell(int compartments_per_segment)
     //cell.add_stimulus({1,1}, {5., 80., 0.3});
 
     cell.add_detector({0,0}, 30);
-    cell.add_synapse({1, 0.5});
+
+    for (auto i=0; i<num_synapses; ++i) {
+        cell.add_synapse({1, 0.5});
+    }
 
     return cell;
 }
 
-cell_group make_lowered_cell(int cell_index, const mc::cell& c)
-{
+cell_group make_lowered_cell(int cell_index, const mc::cell& c) {
     return cell_group(c);
 }
 
diff --git a/miniapp/mpi.cpp b/miniapp/mpi.cpp
new file mode 100644
index 00000000..0481ec15
--- /dev/null
+++ b/miniapp/mpi.cpp
@@ -0,0 +1,59 @@
+#include <mpi.h>
+
+#include <communication/mpi.hpp>
+
+namespace nest {
+namespace mc {
+namespace mpi {
+
+// global state
+namespace state {
+    int size = -1;
+    int rank = -1;
+} // namespace state
+
+void init(int *argc, char ***argv) {
+    int provided;
+
+    // initialize with thread serialized level of thread safety
+    MPI_Init_thread(argc, argv, MPI_THREAD_SERIALIZED, &provided);
+    assert(provided>=MPI_THREAD_SERIALIZED);
+
+    MPI_Comm_rank(MPI_COMM_WORLD, &state::rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &state::size);
+}
+
+void finalize() {
+    MPI_Finalize();
+}
+
+bool is_root() {
+    return state::rank == 0;
+}
+
+int rank() {
+    return state::rank;
+}
+
+int size() {
+    return state::size;
+}
+
+void barrier() {
+    MPI_Barrier(MPI_COMM_WORLD);
+}
+
+bool ballot(bool vote) {
+    using traits = mpi_traits<char>;
+
+    char result;
+    char value = vote ? 1 : 0;
+
+    MPI_Allreduce(&value, &result, 1, traits::mpi_type(), MPI_LAND, MPI_COMM_WORLD);
+
+    return result;
+}
+
+} // namespace mpi
+} // namespace mc
+} // namespace nest
diff --git a/src/cell_group.hpp b/src/cell_group.hpp
index 767a483b..85a0832b 100644
--- a/src/cell_group.hpp
+++ b/src/cell_group.hpp
@@ -33,7 +33,7 @@ class cell_group {
         cell_.voltage()(memory::all) = -65.;
         cell_.initialize();
 
-        for(auto& d : c.detectors()) {
+        for (auto& d : c.detectors()) {
             spike_sources_.push_back( {
                 0u, spike_detector_type(cell_, d.first, d.second, 0.f)
             });
@@ -41,31 +41,33 @@ class cell_group {
     }
 
     void set_source_gids(index_type gid) {
-        for(auto& s : spike_sources_) {
+        for (auto& s : spike_sources_) {
             s.index = gid++;
         }
     }
 
-    void set_target_lids(index_type lid) {
-        first_target_lid_ = lid;
+    void set_target_gids(index_type lid) {
+        first_target_gid_ = lid;
     }
 
+#ifdef SPLAT
     void splat(std::string fname) {
         char buffer[128];
         std::ofstream fid(fname);
-        for(auto i=0u; i<tt.size(); ++i) {
+        for (auto i=0u; i<tt.size(); ++i) {
             sprintf(buffer, "%8.4f %16.8f %16.8f\n", tt[i], vs[i], vd[i]);
             fid << buffer;
         }
     }
+#endif
 
     void advance(double tfinal, double dt) {
-
         while (cell_.time()<tfinal) {
+#ifdef SPLAT
             tt.push_back(cell_.time());
             vs.push_back(cell_.voltage({0,0.0}));
             vd.push_back(cell_.voltage({1,0.5}));
-
+#endif
             // look for events in the next time step
             auto tstep = std::min(tfinal, cell_.time()+dt);
             auto next = events_.pop_if_before(tstep);
@@ -76,8 +78,7 @@ class cell_group {
 
             // check for new spikes
             for (auto& s : spike_sources_) {
-                auto spike = s.source.test(cell_, cell_.time());
-                if(spike) {
+                if (auto spike = s.source.test(cell_, cell_.time())) {
                     spikes_.push_back({s.index, spike.get()});
                 }
             }
@@ -85,6 +86,12 @@ class cell_group {
             // apply events
             if (next) {
                 cell_.apply_event(next.get());
+                // apply events that are due within some epsilon of the current
+                // time step. This should be a parameter, e.g. for variable
+                // order time stepping, use the minimum possible time step size.
+                while(auto e = events_.pop_if_before(cell_.time()+dt/10.)) {
+                    cell_.apply_event(e.get());
+                }
             }
         }
 
@@ -92,8 +99,8 @@ class cell_group {
 
     template <typename R>
     void enqueue_events(R events) {
-        for(auto e : events) {
-            e.target -= first_target_lid_;
+        for (auto e : events) {
+            e.target -= first_target_gid_;
             events_.push(e);
         }
     }
@@ -107,8 +114,7 @@ class cell_group {
     const cell_type& cell() const { return cell_; }
 
     const std::vector<spike_source_type>&
-    spike_sources() const
-    {
+    spike_sources() const {
         return spike_sources_;
     }
 
@@ -118,19 +124,27 @@ class cell_group {
 
     private :
 
-    // TEMPORARY...
+#ifdef SPLAT
+    // REMOVE as soon as we have a better way to probe cell state
     std::vector<float> tt;
     std::vector<float> vs;
     std::vector<float> vd;
+#endif
 
+    /// the lowered cell state (e.g. FVM) of the cell
     cell_type cell_;
+
+    /// spike detectors attached to the cell
     std::vector<spike_source_type> spike_sources_;
 
-    // spikes that are generated
+    /// spikes that are generated
     std::vector<communication::spike<index_type>> spikes_;
+
+    /// pending events to be delivered
     event_queue events_;
 
-    index_type first_target_lid_;
+    /// the global id of the first target (e.g. a synapse) in this group
+    index_type first_target_gid_;
 };
 
 } // namespace mc
diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp
index 2a759385..517a2f6c 100644
--- a/src/communication/communicator.hpp
+++ b/src/communication/communicator.hpp
@@ -29,17 +29,17 @@ namespace communication {
 template <typename CommunicationPolicy>
 class communicator {
 public:
-    using index_type = uint32_t;
+    using id_type = uint32_t;
     using communication_policy_type = CommunicationPolicy;
 
-    using spike_type = spike<index_type>;
+    using spike_type = spike<id_type>;
 
-    communicator( index_type n_groups, std::vector<index_type> target_counts) :
+    communicator() = default;
+
+    communicator(id_type n_groups, std::vector<id_type> target_counts) :
         num_groups_local_(n_groups),
         num_targets_local_(target_counts.size())
     {
-        communicator_id_ = communication_policy_.id();
-
         target_map_ = nest::mc::algorithms::make_index(target_counts);
         num_targets_local_ = target_map_.back();
 
@@ -51,50 +51,60 @@ public:
         group_gid_map_  = communication_policy_.make_map(num_groups_local_);
 
         // transform the target ids from lid to gid
-        auto first_target = target_gid_map_[communicator_id_];
-        for(auto &id : target_map_) {
+        auto first_target = target_gid_map_[domain_id()];
+        for (auto &id : target_map_) {
             id += first_target;
         }
     }
 
+    id_type target_gid_from_group_lid(id_type lid) const {
+        EXPECTS(lid<num_groups_local_);
+        return target_map_[lid];
+    }
+
+    id_type group_gid_from_group_lid(id_type lid) const {
+        EXPECTS(lid<num_groups_local_);
+        return group_gid_map_[domain_id()] + lid;
+    }
+
     void add_connection(connection con) {
         EXPECTS(is_local_target(con.destination()));
         connections_.push_back(con);
     }
 
-    bool is_local_target(index_type gid) {
-        return gid>=target_gid_map_[communicator_id_]
-            && gid<target_gid_map_[communicator_id_+1];
+    bool is_local_target(id_type gid) {
+        return gid>=target_gid_map_[domain_id()]
+            && gid<target_gid_map_[domain_id()+1];
     }
 
-    bool is_local_group(index_type gid) {
-        return gid>=group_gid_map_[communicator_id_]
-            && gid<group_gid_map_[communicator_id_+1];
+    bool is_local_group(id_type gid) {
+        return gid>=group_gid_map_[domain_id()]
+            && gid<group_gid_map_[domain_id()+1];
     }
 
     /// return the global id of the first group in domain d
     /// the groups in domain d are in the contiguous half open range
     ///     [domain_first_group(d), domain_first_group(d+1))
-    index_type group_gid_first(int d) const {
+    id_type group_gid_first(int d) const {
         return group_gid_map_[d];
     }
 
-    index_type target_lid(index_type gid) {
+    id_type target_lid(id_type gid) {
         EXPECTS(is_local_group(gid));
 
-        return gid - target_gid_map_[communicator_id_];
+        return gid - target_gid_map_[domain_id()];
     }
 
     // builds the optimized data structure
     void construct() {
-        if(!std::is_sorted(connections_.begin(), connections_.end())) {
+        if (!std::is_sorted(connections_.begin(), connections_.end())) {
             std::sort(connections_.begin(), connections_.end());
         }
     }
 
     float min_delay() {
         auto local_min = std::numeric_limits<float>::max();
-        for(auto& con : connections_) {
+        for (auto& con : connections_) {
             local_min = std::min(local_min, con.delay());
         }
 
@@ -103,7 +113,7 @@ public:
 
     // return the local group index of the group which hosts the target with
     // global id gid
-    index_type local_group_from_global_target(index_type gid) {
+    id_type local_group_from_global_target(id_type gid) {
         // assert that gid is in range
         EXPECTS(is_local_target(gid));
 
@@ -133,10 +143,11 @@ public:
         // on each node
         //profiler_.enter("global exchange");
         auto global_spikes = communication_policy_.gather_spikes(local_spikes());
+        num_spikes_ += global_spikes.size();
         clear_thread_spike_buffers();
         //profiler_.leave();
 
-        for(auto& q : events_) {
+        for (auto& q : events_) {
             q.clear();
         }
 
@@ -144,7 +155,7 @@ public:
 
         //profiler_.enter("make events");
         // check all global spikes to see if they will generate local events
-        for(auto spike : global_spikes) {
+        for (auto spike : global_spikes) {
             // search for targets
             auto targets =
                 std::equal_range(
@@ -152,7 +163,7 @@ public:
                 );
 
             // generate an event for each target
-            for(auto it=targets.first; it!=targets.second; ++it) {
+            for (auto it=targets.first; it!=targets.second; ++it) {
                 auto gidx = local_group_from_global_target(it->destination());
 
                 events_[gidx].push_back(it->make_event(spike));
@@ -164,56 +175,62 @@ public:
         //profiler_.leave(); // event generation
     }
 
-    const std::vector<local_event>& queue(int i)
+    uint64_t num_spikes() const
     {
+        return num_spikes_;
+    }
+
+    int domain_id() const {
+        return communication_policy_.id();
+    }
+
+    int num_domains() const {
+        return communication_policy_.size();
+    }
+
+    const std::vector<local_event>& queue(int i) const {
         return events_[i];
     }
 
-    std::vector<connection>const& connections() const {
+    const std::vector<connection>& connections() const {
         return connections_;
     }
 
-    communication_policy_type& communication_policy() {
+    communication_policy_type communication_policy() const {
         return communication_policy_;
     }
 
-    const std::vector<index_type>& local_target_map() const {
+    const std::vector<id_type>& local_target_map() const {
         return target_map_;
     }
 
     std::vector<spike_type> local_spikes() {
         std::vector<spike_type> spikes;
-        for(auto& v : thread_spikes_) {
+        for (auto& v : thread_spikes_) {
             spikes.insert(spikes.end(), v.begin(), v.end());
         }
         return spikes;
     }
 
     void clear_thread_spike_buffers() {
-        for(auto& v : thread_spikes_) {
+        for (auto& v : thread_spikes_) {
             v.clear();
         }
     }
 
 private:
 
+    //
+    //  both of these can be fixed with double buffering
+    //
     // FIXME : race condition on the thread_spikes_ buffers when exchange() modifies/access them
     //         ... other threads will be pushing to them simultaneously
-    // FIXME : race condition on the group-specific event queues when exchange pusheds to them
-    //         ... other threads will be accessing them to get events
+    // FIXME : race condition on the group-specific event queues when exchange pushes to them
+    //         ... other threads will be accessing them to update their event queues
 
     // thread private storage for accumulating spikes
     using local_spike_store_type =
         nest::mc::threading::enumerable_thread_specific<std::vector<spike_type>>;
-    /*
-#ifdef WITH_TBB
-    using local_spike_store_type =
-        tbb::enumerable_thread_specific<std::vector<spike_type>>;
-#else
-    using local_spike_store_type =
-        std::array<std::vector<spike_type>, 1>;
-#endif
-    */
     local_spike_store_type thread_spikes_;
 
     std::vector<connection> connections_;
@@ -221,30 +238,28 @@ private:
 
     // local target group i has targets in the half open range
     //      [target_map_[i], target_map_[i+1])
-    std::vector<index_type> target_map_;
+    std::vector<id_type> target_map_;
 
     // for keeping track of how time is spent where
     //util::Profiler profiler_;
 
     // the number of groups and targets handled by this communicator
-    index_type num_groups_local_;
-    index_type num_targets_local_;
+    id_type num_groups_local_;
+    id_type num_targets_local_;
 
     // index maps for the global distribution of groups and targets
 
     // communicator i has the groups in the half open range :
     //      [group_gid_map_[i], group_gid_map_[i+1])
-    std::vector<index_type> group_gid_map_;
+    std::vector<id_type> group_gid_map_;
 
     // communicator i has the targets in the half open range :
     //      [target_gid_map_[i], target_gid_map_[i+1])
-    std::vector<index_type> target_gid_map_;
-
-    // each communicator has a unique id
-    // e.g. for MPI this could be the MPI rank
-    index_type communicator_id_;
+    std::vector<id_type> target_gid_map_;
 
     communication_policy_type communication_policy_;
+
+    uint64_t num_spikes_ = 0u;
 };
 
 } // namespace communication
diff --git a/src/communication/connection.hpp b/src/communication/connection.hpp
index 7519a286..528c57b5 100644
--- a/src/communication/connection.hpp
+++ b/src/communication/connection.hpp
@@ -11,36 +11,28 @@ namespace communication {
 
 class connection {
 public:
-    using index_type = uint32_t;
-    connection(index_type src, index_type dest, float w, float d)
-    :   source_(src),
+    using id_type = uint32_t;
+    connection(id_type src, id_type dest, float w, float d) :
+        source_(src),
         destination_(dest),
         weight_(w),
         delay_(d)
-    { }
+    {}
 
-    float weight() const {
-        return weight_;
-    }
-    float delay() const {
-        return delay_;
-    }
+    float weight() const { return weight_; }
+    float delay() const { return delay_; }
 
-    index_type source() const {
-        return source_;
-    }
-    index_type destination() const {
-        return destination_;
-    }
+    id_type source() const { return source_; }
+    id_type destination() const { return destination_; }
 
-    local_event make_event(spike<index_type> s) {
+    local_event make_event(spike<id_type> s) {
         return {destination_, s.time + delay_, weight_};
     }
 
 private:
 
-    index_type source_;
-    index_type destination_;
+    id_type source_;
+    id_type destination_;
     float weight_;
     float delay_;
 };
@@ -54,12 +46,12 @@ bool operator< (connection lhs, connection rhs) {
 }
 
 static inline
-bool operator< (connection lhs, connection::index_type rhs) {
+bool operator< (connection lhs, connection::id_type rhs) {
     return lhs.source() < rhs;
 }
 
 static inline
-bool operator< (connection::index_type lhs, connection rhs) {
+bool operator< (connection::id_type lhs, connection rhs) {
     return lhs < rhs.source();
 }
 
diff --git a/src/communication/mpi.hpp b/src/communication/mpi.hpp
index 398e6033..063e36ef 100644
--- a/src/communication/mpi.hpp
+++ b/src/communication/mpi.hpp
@@ -2,15 +2,30 @@
 
 #include <algorithm>
 #include <iostream>
+#include <type_traits>
 #include <vector>
 
 #include <cassert>
 
 #include <mpi.h>
-#include "utils.hpp" 
-#include "utils.hpp" 
+
+#include <algorithms.hpp>
+
+namespace nest {
+namespace mc {
 namespace mpi {
 
+    // prototypes
+    void init(int *argc, char ***argv);
+    void finalize();
+    bool is_root();
+    int rank();
+    int size();
+    void barrier();
+    bool ballot(bool vote);
+
+    // type traits for automatically setting MPI_Datatype information
+    // for C++ types
     template <typename T>
     struct mpi_traits {
         constexpr static size_t count() {
@@ -41,15 +56,15 @@ namespace mpi {
     static_assert(sizeof(size_t)==sizeof(unsigned long),
                   "size_t and unsigned long are not equivalent");
 
-    bool init(int *argc, char ***argv);
-    bool finalize();
-    bool is_root();
-    int rank();
-    int size();
-    void barrier();
-
-    template <typename T>
+    // Gather individual values of type T from each rank into a std::vector on
+    // the root rank.
+    // T must be trivially copyable
+    template<typename T>
     std::vector<T> gather(T value, int root) {
+        static_assert(
+            true,//std::is_trivially_copyable<T>::value,
+            "gather can only be performed on trivally copyable types");
+
         using traits = mpi_traits<T>;
         auto buffer_size = (rank()==root) ? size() : 0;
         std::vector<T> buffer(buffer_size);
@@ -61,8 +76,15 @@ namespace mpi {
         return buffer;
     }
 
+    // Gather individual values of type T from each rank into a std::vector on
+    // every rank.
+    // T must be trivially copyable
     template <typename T>
     std::vector<T> gather_all(T value) {
+        static_assert(
+            true,//std::is_trivially_copyable<T>::value,
+            "gather_all can only be performed on trivally copyable types");
+
         using traits = mpi_traits<T>;
         std::vector<T> buffer(size());
 
@@ -75,12 +97,16 @@ namespace mpi {
 
     template <typename T>
     std::vector<T> gather_all(const std::vector<T> &values) {
+        static_assert(
+            true,//std::is_trivially_copyable<T>::value,
+            "gather_all can only be performed on trivally copyable types");
+
         using traits = mpi_traits<T>;
         auto counts = gather_all(int(values.size()));
-        for(auto& c : counts) {
+        for (auto& c : counts) {
             c *= traits::count();
         }
-        auto displs = algorithms::make_map(counts);
+        auto displs = algorithms::make_index(counts);
 
         std::vector<T> buffer(displs.back()/traits::count());
 
@@ -98,8 +124,9 @@ namespace mpi {
     template <typename T>
     T reduce(T value, MPI_Op op, int root) {
         using traits = mpi_traits<T>;
-        static_assert(traits::is_mpi_native_type(),
-                      "can only perform reductions on MPI native types");
+        static_assert(
+            traits::is_mpi_native_type(),
+            "can only perform reductions on MPI native types");
 
         T result;
 
@@ -111,8 +138,9 @@ namespace mpi {
     template <typename T>
     T reduce(T value, MPI_Op op) {
         using traits = mpi_traits<T>;
-        static_assert(traits::is_mpi_native_type(),
-                      "can only perform reductions on MPI native types");
+        static_assert(
+            traits::is_mpi_native_type(),
+            "can only perform reductions on MPI native types");
 
         T result;
 
@@ -131,10 +159,12 @@ namespace mpi {
         return {reduce<T>(value, MPI_MIN, root), reduce<T>(value, MPI_MAX, root)};
     }
 
-    bool ballot(bool vote);
-
     template <typename T>
     T broadcast(T value, int root) {
+        static_assert(
+            true,//std::is_trivially_copyable<T>::value,
+            "broadcast can only be performed on trivally copyable types");
+
         using traits = mpi_traits<T>;
 
         MPI_Bcast(&value, traits::count(), traits::mpi_type(), root, MPI_COMM_WORLD);
@@ -144,6 +174,10 @@ namespace mpi {
 
     template <typename T>
     T broadcast(int root) {
+        static_assert(
+            true,//std::is_trivially_copyable<T>::value,
+            "broadcast can only be performed on trivally copyable types");
+
         using traits = mpi_traits<T>;
         T value;
 
@@ -153,3 +187,5 @@ namespace mpi {
     }
 
 } // namespace mpi
+} // namespace mc
+} // namespace nest
diff --git a/src/communication/mpi_global_policy.hpp b/src/communication/mpi_global_policy.hpp
index 43d830b5..79135437 100644
--- a/src/communication/mpi_global_policy.hpp
+++ b/src/communication/mpi_global_policy.hpp
@@ -6,32 +6,34 @@
 #include <cstdint>
 
 #include <communication/spike.hpp>
+#include <communication/mpi.hpp>
 #include <algorithms.hpp>
 
-#include "mpi.hpp"
 
 namespace nest {
 namespace mc {
 namespace communication {
 
 struct mpi_global_policy {
-    std::vector<spike<uint32_t>> const
-    gather_spikes(const std::vector<spike<uint32_t>>& local_spikes) {
+    using id_type = uint32_t;
+
+    std::vector<spike<id_type>> const
+    gather_spikes(const std::vector<spike<id_type>>& local_spikes) {
         return mpi::gather_all(local_spikes);
     }
 
-    int id() const {
-        return mpi::rank();
-    }
+    int id() const { return mpi::rank(); }
+
+    int size() const { return mpi::size(); }
 
-    /*
     template <typename T>
     T min(T value) const {
+        return nest::mc::mpi::reduce(value, MPI_MIN);
     }
-    */
 
-    int num_communicators() const {
-        return mpi::size();
+    template <typename T>
+    T max(T value) const {
+        return nest::mc::mpi::reduce(value, MPI_MAX);
     }
 
     template <
diff --git a/src/communication/serial_global_policy.hpp b/src/communication/serial_global_policy.hpp
index 94c2c430..eaa6c4b1 100644
--- a/src/communication/serial_global_policy.hpp
+++ b/src/communication/serial_global_policy.hpp
@@ -12,8 +12,10 @@ namespace mc {
 namespace communication {
 
 struct serial_global_policy {
-    std::vector<spike<uint32_t>> const
-    gather_spikes(const std::vector<spike<uint32_t>>& local_spikes) {
+    using id_type = uint32_t;
+
+    std::vector<spike<id_type>> const
+    gather_spikes(const std::vector<spike<id_type>>& local_spikes) {
         return local_spikes;
     }
 
@@ -21,7 +23,7 @@ struct serial_global_policy {
         return 0;
     }
 
-    static int num_communicators() {
+    static int size() {
         return 1;
     }
 
diff --git a/src/communication/spike.hpp b/src/communication/spike.hpp
index 09c66522..03b5ed75 100644
--- a/src/communication/spike.hpp
+++ b/src/communication/spike.hpp
@@ -12,14 +12,14 @@ template <
     typename = typename std::enable_if<std::is_integral<I>::value>
 >
 struct spike {
-    using index_type = I;
-    index_type source = 0;
+    using id_type = I;
+    id_type source = 0;
     float time = -1.;
 
     spike() = default;
 
-    spike(index_type s, float t)
-    :   source(s), time(t)
+    spike(id_type s, float t) :
+        source(s), time(t)
     {}
 };
 
@@ -29,14 +29,17 @@ struct spike {
 
 /// custom stream operator for printing nest::mc::spike<> values
 template <typename I>
-std::ostream& operator <<(std::ostream& o, nest::mc::communication::spike<I> s) {
+std::ostream& operator<<(
+    std::ostream& o,
+    nest::mc::communication::spike<I> s)
+{
     return o << "spike[t " << s.time << ", src " << s.source << "]";
 }
 
 /// less than comparison operator for nest::mc::spike<> values
 /// spikes are ordered by spike time, for use in sorting and queueing
 template <typename I>
-bool operator <(
+bool operator<(
     nest::mc::communication::spike<I> lhs,
     nest::mc::communication::spike<I> rhs)
 {
diff --git a/src/communication/spike_source.hpp b/src/communication/spike_source.hpp
index d3258b7c..95099cb5 100644
--- a/src/communication/spike_source.hpp
+++ b/src/communication/spike_source.hpp
@@ -13,13 +13,8 @@ class spike_detector
 public:
     using cell_type = Cell;
 
-    spike_detector(
-        const cell_type& cell,
-        segment_location loc,
-        double thresh,
-        float t_init
-    )
-    :   location_(loc),
+    spike_detector( const cell_type& cell, segment_location loc, double thresh, float t_init) :
+        location_(loc),
         threshold_(thresh),
         previous_t_(t_init)
     {
@@ -27,8 +22,7 @@ public:
         is_spiking_ = previous_v_ >= thresh ? true : false;
     }
 
-    util::optional<float> test(const cell_type& cell, float t)
-    {
+    util::optional<float> test(const cell_type& cell, float t) {
         util::optional<float> result = util::nothing;
         auto v = cell.voltage(location_);
 
@@ -56,26 +50,17 @@ public:
         return result;
     }
 
-    bool is_spiking() const {
-        return is_spiking_;
-    }
+    bool is_spiking() const { return is_spiking_; }
 
-    segment_location location() const {
-        return location_;
-    }
+    segment_location location() const { return location_; }
 
-    float t() const {
-        return previous_t_;
-    }
+    float t() const { return previous_t_; }
 
-    float v() const {
-        return previous_v_;
-    }
+    float v() const { return previous_v_; }
 
 private:
 
     // parameters/data
-    //const cell_type* cell_;
     segment_location location_;
     double threshold_;
 
@@ -85,38 +70,6 @@ private:
     bool is_spiking_;
 };
 
-/*
-// spike generator according to a Poisson process
-class poisson_generator : public spike_source
-{
-    public:
-
-    poisson_generator(float r)
-    :   dist_(0.0f, 1.0f),
-        firing_rate_(r)
-    {}
-
-    util::optional<float> test(float t) {
-        // generate a uniformly distrubuted random number x \in [0,1]
-        // if  (x > r*dt)  we have a spike in the interval
-        std::vector<float> spike_times;
-        if(dist_(generator_) > firing_rate_*(t-previous_t_)) {
-            return t;
-        }
-        return util::nothing;
-    }
-
-    private:
-
-    std::mt19937 generator_; // for now default initialized
-    std::uniform_real_distribution<float> dist_;
-
-    // firing rate in spikes/ms
-    float firing_rate_;
-    float previous_t_;
-};
-*/
-
 
 } // namespace mc
 } // namespace nest
-- 
GitLab