From bcb4e57d697b3535531be91621bd8f520947db17 Mon Sep 17 00:00:00 2001
From: bcumming <bcumming@cscs.ch>
Date: Fri, 19 Aug 2016 13:08:54 +0200
Subject: [PATCH] remove spike output callback from communicator

* make a new gathered_vector type for storing a "distributed" vector
  that has been gathered onto a single node
    * stores the vector
    * stores the partition information
* split communicator::exchange() into two phases (see the usage sketch
  below)
    * exchange(), which simply exchanges spikes
    * make_event_queues(), which generates the local event queues from the
      gathered spikes
* callbacks for both the local and global spike stores can now be
  performed in the model
* extra metadata about the partition is available for more advanced
  parallel global storage schemes
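
A rough sketch of the intended usage, mirroring the changes to model.hpp
below (not a drop-in snippet):

    // phase 1: global all-to-all exchange of spikes, with partition metadata
    auto global_spikes = communicator_.exchange(local_spikes);

    // spike output callbacks are now driven from the model
    local_export_callback_(local_spikes);
    global_export_callback_(global_spikes.values());

    // phase 2: build the per-cell-group event queues locally
    future_events() = communicator_.make_event_queues(global_spikes);

A parallel spike writer could also use the partition metadata to find the
range of gathered spikes that originated on a given domain d, i.e. the
entries of global_spikes.values() between partition()[d] and
partition()[d+1] (a hypothetical use, not part of this patch).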
---
 src/communication/communicator.hpp         | 38 +++++++---------
 src/communication/gathered_vector.hpp      | 52 ++++++++++++++++++++++
 src/communication/mpi.hpp                  | 33 ++++++++++++++
 src/communication/mpi_global_policy.hpp    |  9 ++--
 src/communication/serial_global_policy.hpp |  9 ++--
 src/model.hpp                              | 23 ++++++----
 6 files changed, 126 insertions(+), 38 deletions(-)
 create mode 100644 src/communication/gathered_vector.hpp

diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp
index 50ce644c..5da7f189 100644
--- a/src/communication/communicator.hpp
+++ b/src/communication/communicator.hpp
@@ -6,13 +6,13 @@
 #include <random>
 #include <functional>
 
-#include <spike.hpp>
-#include <util/double_buffer.hpp>
 #include <algorithms.hpp>
 #include <connection.hpp>
+#include <communication/gathered_vector.hpp>
 #include <event_queue.hpp>
 #include <spike.hpp>
 #include <util/debug.hpp>
+#include <util/double_buffer.hpp>
 
 namespace nest {
 namespace mc {
@@ -87,29 +87,26 @@ public:
     /// Perform exchange of spikes.
     ///
     /// Takes as input the list of local_spikes that were generated on the calling domain.
-    ///
-    /// Returns a vector of event queues, with one queue for each local cell group. The
-    /// events in each queue are all events that must be delivered to targets in that cell
-    /// group as a result of the global spike exchange.
-    std::vector<event_queue> exchange(const std::vector<spike_type>& local_spikes,
-        std::function<void(const std::vector<spike_type>&)> global_export_callback)
-    {
+    /// Returns the full global set of spikes, along with metadata about how they are partitioned across domains.
+    gathered_vector<spike_type> exchange(const std::vector<spike_type>& local_spikes) {
         // global all-to-all to gather a local copy of the global spike list on each node.
         auto global_spikes = communication_policy_.gather_spikes( local_spikes );
         num_spikes_ += global_spikes.size();
+        return global_spikes;
+    }
 
-        global_export_callback(global_spikes);
-
-        // check each global spike in turn to see it generates local events.
-        // if so, make the events and insert them into the appropriate event list.
+    /// Check each global spike in turn to see if it generates local events.
+    /// If so, make the events and insert them into the appropriate event queue.
+    ///
+    /// Returns a vector of event queues, with one queue for each local cell group. The
+    /// events in each queue are all events that must be delivered to targets in that cell
+    /// group as a result of the global spike exchange.
+    std::vector<event_queue> make_event_queues(const gathered_vector<spike_type>& global_spikes) {
         auto queues = std::vector<event_queue>(num_groups_local());
-
-        for (auto spike : global_spikes) {
+        for (auto spike : global_spikes.values()) {
             // search for targets
-            auto targets =
-                std::equal_range(
-                    connections_.begin(), connections_.end(), spike.source
-                );
+            auto targets = std::equal_range(connections_.begin(), connections_.end(), spike.source);
 
             // generate an event for each target
             for (auto it=targets.first; it!=targets.second; ++it) {
@@ -121,8 +118,7 @@ public:
         return queues;
     }
 
-    /// Returns the total number of global spikes over the duration
-    /// of the simulation
+    /// Returns the total number of global spikes over the duration of the simulation
     uint64_t num_spikes() const { return num_spikes_; }
 
     const std::vector<connection_type>& connections() const {
diff --git a/src/communication/gathered_vector.hpp b/src/communication/gathered_vector.hpp
new file mode 100644
index 00000000..9c83f744
--- /dev/null
+++ b/src/communication/gathered_vector.hpp
@@ -0,0 +1,52 @@
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <numeric>
+#include <vector>
+
+#include <algorithms.hpp>
+#include <util/debug.hpp>
+
+namespace nest {
+namespace mc {
+
+template <typename T>
+class gathered_vector {
+public:
+    using value_type = T;
+    using count_type = unsigned;
+
+    gathered_vector(std::vector<value_type>&& v, std::vector<count_type>&& p) :
+        values_(std::move(v)),
+        partition_(std::move(p))
+    {
+        EXPECTS(std::is_sorted(partition_.begin(), partition_.end()));
+        EXPECTS(std::size_t(partition_.back()) == values_.size());
+    }
+
+    /// the partition of the vector's distribution over domains
+    const std::vector<count_type>& partition() const {
+        return partition_;
+    }
+
+    /// the number of entries in the gathered vector in partition i
+    count_type count(std::size_t i) const {
+        return partition_[i+1] - partition_[i];
+    }
+
+    /// the values in the gathered vector
+    const std::vector<value_type>& values() const {
+        return values_;
+    }
+
+    /// the size of the gathered vector
+    std::size_t size() const {
+        return values_.size();
+    }
+
+private:
+    std::vector<value_type> values_;
+    std::vector<count_type> partition_;
+};
+
+} // namespace mc
+} // namespace nest
diff --git a/src/communication/mpi.hpp b/src/communication/mpi.hpp
index 5be71b63..5b43f1a8 100644
--- a/src/communication/mpi.hpp
+++ b/src/communication/mpi.hpp
@@ -10,6 +10,7 @@
 #include <mpi.h>
 
 #include <algorithms.hpp>
+#include <communication/gathered_vector.hpp>
 
 namespace nest {
 namespace mc {
@@ -121,6 +122,38 @@ namespace mpi {
         return buffer;
     }
 
+    /// Gather all of a distributed vector.
+    /// Retains the metadata (i.e. the partition of the vector over the ranks).
+    template <typename T>
+    gathered_vector<T> gather_all_meta(const std::vector<T>& values) {
+        using gathered_type = gathered_vector<T>;
+        using count_type = typename gathered_vector<T>::count_type;
+        using traits = mpi_traits<T>;
+
+        // We have to use int for the count and displs vectors instead
+        // of count_type because these are used as arguments to MPI_Allgatherv
+        // which expects int arguments.
+        auto counts = gather_all(int(values.size()));
+        for (auto& c : counts) {
+            c *= traits::count();
+        }
+        auto displs = algorithms::make_index(counts);
+
+        std::vector<T> buffer(displs.back()/traits::count());
+
+        MPI_Allgatherv(
+            // send buffer
+            values.data(), counts[rank()], traits::mpi_type(),
+            // receive buffer
+            buffer.data(), counts.data(), displs.data(), traits::mpi_type(),
+            MPI_COMM_WORLD
+        );
+
+        // displs is in units of the underlying MPI type; convert it back to
+        // counts of T to obtain the partition of the gathered vector.
+        for (auto& d: displs) d /= traits::count();
+        std::vector<count_type> part(displs.begin(), displs.end());
+        return gathered_type(std::move(buffer), std::move(part));
+    }
+
     template <typename T>
     T reduce(T value, MPI_Op op, int root) {
         using traits = mpi_traits<T>;
diff --git a/src/communication/mpi_global_policy.hpp b/src/communication/mpi_global_policy.hpp
index 0e9ec333..8bb597dd 100644
--- a/src/communication/mpi_global_policy.hpp
+++ b/src/communication/mpi_global_policy.hpp
@@ -10,6 +10,7 @@
 
 #include <algorithms.hpp>
 #include <common_types.hpp>
+#include <communication/gathered_vector.hpp>
 #include <communication/mpi.hpp>
 #include <spike.hpp>
 
@@ -18,10 +19,10 @@ namespace mc {
 namespace communication {
 
 struct mpi_global_policy {
-    template <typename I, typename T>
-    static std::vector<spike<I, T>>
-    gather_spikes(const std::vector<spike<I,T>>& local_spikes) {
-        return mpi::gather_all(local_spikes);
+    template <typename Spike>
+    static gathered_vector<Spike>
+    gather_spikes(const std::vector<Spike>& local_spikes) {
+        return mpi::gather_all_meta(local_spikes);
     }
 
     static int id() { return mpi::rank(); }
diff --git a/src/communication/serial_global_policy.hpp b/src/communication/serial_global_policy.hpp
index 9af5eae3..b048b6cf 100644
--- a/src/communication/serial_global_policy.hpp
+++ b/src/communication/serial_global_policy.hpp
@@ -4,6 +4,7 @@
 #include <type_traits>
 #include <vector>
 
+#include <communication/gathered_vector.hpp>
 #include <spike.hpp>
 
 namespace nest {
@@ -11,10 +12,10 @@ namespace mc {
 namespace communication {
 
 struct serial_global_policy {
-    template <typename I, typename T>
-    static const std::vector<spike<I, T>>&
-    gather_spikes(const std::vector<spike<I, T>>& local_spikes) {
-        return local_spikes;
+    template <typename Spike>
+    static gathered_vector<Spike>
+    gather_spikes(const std::vector<Spike>& local_spikes) {
+        // a single local domain holds all of the spikes
+        return gathered_vector<Spike>(
+            std::vector<Spike>(local_spikes), {0u, unsigned(local_spikes.size())});
     }
 
     static int id() {
diff --git a/src/model.hpp b/src/model.hpp
index 4fa778e0..ae7b4a7a 100644
--- a/src/model.hpp
+++ b/src/model.hpp
@@ -130,18 +130,23 @@ public:
             // the previous integration period, generating the postsynaptic
             // events that must be delivered at the start of the next
             // integration period at the latest.
-
-            //TODO:
-            //An improvement might be :
-            //the exchange method simply exchanges spikes, and does not generate the event queues.It returns a struct that has both 1) the global spike list 2) an integer vector that describes the distribution of spikes across the ranks
-            //another method called something like build_queues that takes this spike info and returns the local spikes
-            //    and the callbacks can then be called on the spike information directly in the model.
             auto exchange = [&] () {
-                PE("stepping", "exchange");
+                PE("stepping", "communciation");
+
+                PE("exchange");
                 auto local_spikes = previous_spikes().gather();
+                auto global_spikes = communicator_.exchange(local_spikes);
+                PL();
+
+                PE("spike output");
                 local_export_callback_(local_spikes);
-                future_events() =
-                    communicator_.exchange(local_spikes, global_export_callback_);
+                global_export_callback_(global_spikes.values());
+                PL();
+
+                PE("events");
+                future_events() = communicator_.make_event_queues(global_spikes);
+                PL();
+
                 PL(2);
             };
 
-- 
GitLab