From ad19f1742c0000716486c7fef41f707479bcf402 Mon Sep 17 00:00:00 2001
From: Sam Yates <halfflat@gmail.com>
Date: Thu, 4 Jun 2020 18:33:26 +0200
Subject: [PATCH] Tidy synapse collation code. (#1053)

* Tidy synapse collation code.
* Further reduce synapse collation allocations.
---
 arbor/fvm_layout.cpp | 107 ++++++++++++++++++++++++++++++-------------
 1 file changed, 74 insertions(+), 33 deletions(-)

diff --git a/arbor/fvm_layout.cpp b/arbor/fvm_layout.cpp
index a950e5d0..e370372a 100644
--- a/arbor/fvm_layout.cpp
+++ b/arbor/fvm_layout.cpp
@@ -28,6 +28,7 @@ using util::append;
 using util::assign;
 using util::assign_by;
 using util::count_along;
+using util::make_span;
 using util::pw_elements;
 using util::pw_element;
 using util::sort;
@@ -219,7 +220,7 @@ cv_geometry cv_geometry_from_ends(const cable_cell& cell, const locset& lset) {
             [](auto v) { return v!=no_parent; }));
 
     // Construct CV children mapping by sorting CV indices by parent.
-    assign(geom.cv_children, util::make_span(1, n_cv));
+    assign(geom.cv_children, make_span(1, n_cv));
     stable_sort_by(geom.cv_children, [&geom](auto cv) { return geom.cv_parent[cv]; });
 
     geom.cv_children_divs.reserve(n_cv+1);
@@ -243,7 +244,7 @@ cv_geometry cv_geometry_from_ends(const cable_cell& cell, const locset& lset) {
     geom.branch_cv_map.resize(1);
     std::vector<pw_elements<fvm_size_type>>& bmap = geom.branch_cv_map.back();
 
-    for (auto cv: util::make_span(n_cv)) {
+    for (auto cv: make_span(n_cv)) {
         for (auto cable: geom.cables(cv)) {
             if (cable.branch>=bmap.size()) {
                 bmap.resize(cable.branch+1);
@@ -738,8 +739,8 @@ fvm_mechanism_data& append(fvm_mechanism_data& left, const fvm_mechanism_data& r
             append(L.norm_area, R.norm_area);
             append_offset(L.target, target_offset, R.target);
 
-            arb_assert(util::is_sorted_by(L.param_values, util::first));
-            arb_assert(util::is_sorted_by(R.param_values, util::first));
+            arb_assert(util::equal(L.param_values, R.param_values,
+                [](auto& a, auto& b) { return a.first==b.first; }));
             arb_assert(L.param_values.size()==R.param_values.size());
 
             for (auto j: count_along(R.param_values)) {
@@ -934,42 +935,63 @@ fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties&
 
     struct synapse_instance {
         size_type cv;
-        std::vector<std::pair<unsigned, double>> param_value; // sorted vector of <param_id,value> pairs
+        std::size_t param_values_offset;
         size_type target_index;
     };
 
+    // Working vectors for synapse collation:
+    std::vector<double> default_param_value;
+    std::vector<double> all_param_values;
+    std::vector<synapse_instance> inst_list;
+    std::vector<size_type> cv_order;
+
     for (const auto& entry: cell.synapses()) {
         const std::string& name = entry.first;
         mechanism_info info = catalogue[name];
-        std::vector<synapse_instance> sl;
 
-        // Map from a param string identifier to a unique unsigned int identifier
-        std::map<std::string, unsigned> param_map;
+        std::size_t n_param = info.parameters.size();
+        std::size_t n_inst = entry.second.size();
+
+        default_param_value.resize(n_param);
+        inst_list.clear();
+        inst_list.reserve(n_inst);
+
+        all_param_values.resize(n_param*n_inst);
+
+        // Vectors of parameter values are stored in the order of
+        // parameters given by info.parameters. param_index holds
+        // the mapping from parameter names to their index in this
+        // order.
 
-        // Map from a param unsigned int identifier to a default value
-        std::map<unsigned, double> default_param_value;
+        std::unordered_map<std::string, unsigned> param_index;
 
-        unsigned id=0;
+        unsigned ix=0;
         for (const auto& kv: info.parameters) {
-            param_map[kv.first] = id;
-            default_param_value[id++] = kv.second.default_value;
+            param_index[kv.first] = ix;
+            default_param_value.at(ix++) = kv.second.default_value;
         }
+        arb_assert(ix==n_param);
 
+        std::size_t offset = 0;
         for (const placed<mechanism_desc>& pm: entry.second) {
             verify_mechanism(info, pm.item);
 
             synapse_instance in;
 
-            auto param_value_map = default_param_value;
+            in.param_values_offset = offset;
+            offset += n_param;
+            arb_assert(offset<=all_param_values.size());
+
+            double* in_param = all_param_values.data()+in.param_values_offset;
+            std::copy(default_param_value.begin(), default_param_value.end(), in_param);
+
             for (const auto& kv: pm.item.values()) {
-                param_value_map.at(param_map.at(kv.first)) = kv.second;
+                in_param[param_index.at(kv.first)] = kv.second;
             }
-            std::copy(param_value_map.begin(), param_value_map.end(), std::back_inserter(in.param_value));
-            std::sort(in.param_value.begin(), in.param_value.end());
 
             in.target_index = pm.lid;
             in.cv = D.geometry.location_cv(cell_idx, pm.loc, cv_prefer::cv_nonempty);
-            sl.push_back(std::move(in));
+            inst_list.push_back(std::move(in));
         }
 
         // Permute synapse instances so that they are in increasing order
@@ -977,25 +999,49 @@ fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties&
         // instances in the same CV with the same parameter values are adjacent.
         // cv_order[i] is the index of the ith instance by this ordering.
 
-        std::vector<size_type> cv_order;
-        assign(cv_order, count_along(sl));
-        sort_by(cv_order, [&](size_type i) {
-            return std::tie(sl[i].cv, sl[i].param_value, sl[i].target_index);
+        auto cmp_inst_param = [n_param, &all_param_values](const synapse_instance& a, const synapse_instance& b) {
+            const double* aparam = all_param_values.data()+a.param_values_offset;
+            const double* bparam = all_param_values.data()+b.param_values_offset;
+
+            for (auto j: make_span(n_param)) {
+                if (aparam[j]<bparam[j]) return -1;
+                if (bparam[j]<aparam[j]) return 1;
+            }
+            return 0;
+        };
+
+        assign(cv_order, count_along(inst_list));
+        sort(cv_order, [&](size_type i, size_type j) {
+            const synapse_instance& a = inst_list[i];
+            const synapse_instance& b = inst_list[j];
+
+            if (a.cv<b.cv) return true;
+            if (b.cv<a.cv) return false;
+
+            auto cmp_param = cmp_inst_param(a, b);
+            if (cmp_param<0) return true;
+            if (cmp_param>0) return false;
+
+            // CV and all parameters are equal, so finally sort on target index.
+            return a.target_index<b.target_index;
         });
 
         bool coalesce = catalogue[name].linear && gprop.coalesce_synapses;
 
         fvm_mechanism_config config;
         config.kind = mechanismKind::point;
-        for (auto& pentry: param_map) {
-            config.param_values.emplace_back(pentry.first, std::vector<value_type>{});
+        for (auto& kv: info.parameters) {
+            config.param_values.emplace_back(kv.first, std::vector<value_type>{});
+            if (!coalesce) {
+                config.param_values.back().second.reserve(n_inst);
+            }
         }
 
         const synapse_instance* prev = nullptr;
         for (auto i: cv_order) {
-            const auto& in = sl[i];
+            const auto& in = inst_list[i];
 
-            if (coalesce && prev && prev->cv==in.cv && prev->param_value==in.param_value) {
+            if (coalesce && prev && prev->cv==in.cv && cmp_inst_param(*prev, in)==0) {
                 ++config.multiplicity.back();
             }
             else {
@@ -1004,13 +1050,8 @@ fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties&
                     config.multiplicity.push_back(1);
                 }
 
-                unsigned j = 0;
-                for (auto& pentry: param_map) {
-                    arb_assert(config.param_values[j].first==pentry.first);
-                    auto it = std::lower_bound(in.param_value.begin(), in.param_value.end(), pentry.second, [](auto el, unsigned val) {return el.first < val;});
-
-                    arb_assert(it!= in.param_value.end());
-                    config.param_values[j++].second.push_back((*it).second);
+                for (auto j: make_span(n_param)) {
+                    config.param_values[j].second.push_back(all_param_values[in.param_values_offset+j]);
                 }
             }
             config.target.push_back(in.target_index);
-- 
GitLab