diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index 4466af131a0215293e171e58ae5dcd71808e494b..dc41015934b3f1bf9900f4fe6648bfa00460d381 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -45,6 +45,7 @@ public: using util::make_span; num_domains_ = comms_.size(); num_local_groups_ = dom_dec.groups.size(); + num_local_cells_ = dom_dec.num_local_cells; // For caching information about each cell struct gid_info { @@ -52,6 +53,7 @@ public: cell_gid_type gid; // global identifier of cell cell_size_type index_on_domain; // index of cell in this domain connection_list conns; // list of connections terminating at this cell + gid_info() = default; // so we can in a std::vector gid_info(cell_gid_type g, cell_size_type di, connection_list c): gid(g), index_on_domain(di), conns(std::move(c)) {} }; @@ -64,30 +66,37 @@ public: // -> src_domains: array with one entry for every local connection // Also the count of presynaptic sources from each domain // -> src_counts: array with one entry for each domain + + // Record all the gid in a flat vector. + // These are used to map from local index to gid in the parallel loop + // that populates gid_infos. + std::vector<cell_gid_type> gids; + gids.reserve(num_local_cells_); + for (auto g: dom_dec.groups) { + util::append(gids, g.gids); + } + // Build the connection information for local cells in parallel. std::vector<gid_info> gid_infos; - gid_infos.reserve(dom_dec.num_local_cells); + gid_infos.resize(num_local_cells_); + threading::parallel_for::apply(0, gids.size(), + [&](cell_size_type i) { + auto gid = gids[i]; + gid_infos[i] = gid_info(gid, i, rec.connections_on(gid)); + }); - cell_local_size_type n_cons = 0; - cell_local_size_type n_gid = 0; + cell_local_size_type n_cons = + util::sum_by(gid_infos, [](const gid_info& g){ return g.conns.size(); }); std::vector<unsigned> src_domains; + src_domains.reserve(n_cons); std::vector<cell_size_type> src_counts(num_domains_); - for (auto g: make_span(0, num_local_groups_)) { - const auto& group = dom_dec.groups[g]; - for (auto gid: group.gids) { - gid_info info(gid, n_gid, rec.connections_on(gid)); - n_cons += info.conns.size(); - for (auto con: info.conns) { - const auto src = dom_dec.gid_domain(con.source.gid); - src_domains.push_back(src); - src_counts[src]++; - } - gid_infos.push_back(std::move(info)); - ++n_gid; + for (const auto& g: gid_infos) { + for (auto con: g.conns) { + const auto src = dom_dec.gid_domain(con.source.gid); + src_domains.push_back(src); + src_counts[src]++; } } - num_local_cells_ = n_gid; - // Construct the connections. // The loop above gave the information required to construct in place // the connections as partitioned by the domain of their source gid.