diff --git a/arbor/communication/dry_run_context.cpp b/arbor/communication/dry_run_context.cpp index 1ef5522015965caa3467f989d4253bb9a4b62032..e3a12d28c144258d41995446d28dee5a99f4454b 100644 --- a/arbor/communication/dry_run_context.cpp +++ b/arbor/communication/dry_run_context.cpp @@ -70,6 +70,26 @@ struct dry_run_context_impl { return gathered_vector<cell_gid_type>(std::move(gathered_gids), std::move(partition)); } + std::vector<std::vector<cell_gid_type>> + gather_gj_connections(const std::vector<std::vector<cell_gid_type>> & local_connections) const { + auto local_size = local_connections.size(); + std::vector<std::vector<cell_gid_type>> global_connections; + global_connections.reserve(local_size*num_ranks_); + + for (unsigned i = 0; i < num_ranks_; i++) { + util::append(global_connections, local_connections); + } + + for (unsigned i = 0; i < num_ranks_; i++) { + for (unsigned j = i*local_size; j < (i+1)*local_size; j++){ + for (auto& conn_gid: global_connections[j]) { + conn_gid += num_cells_per_tile_*i; + } + } + } + return global_connections; + } + cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { cell_label_range global_ranges; for (unsigned i = 0; i < num_ranks_; i++) { diff --git a/arbor/communication/mpi.hpp b/arbor/communication/mpi.hpp index 972e95e7a3df9f006dd56e206271944f679ae84d..7a3546ff1fc78ae8d0eed95137d6791ba149ee64 100644 --- a/arbor/communication/mpi.hpp +++ b/arbor/communication/mpi.hpp @@ -186,6 +186,38 @@ inline std::vector<std::string> gather_all(const std::vector<std::string>& value return string_buffer; } +template <typename T> +std::vector<std::vector<T>> gather_all(const std::vector<std::vector<T>>& values, MPI_Comm comm) { + std::vector<unsigned long> counts_internal, displs_internal; + + // Vector of individual vector sizes + std::vector<unsigned long> internal_sizes(values.size()); + std::transform(values.begin(), values.end(), internal_sizes.begin(), [](const auto& val){return int(val.size());}); + + counts_internal = gather_all(internal_sizes, comm); + auto displs_internal_part = util::make_partition(displs_internal, counts_internal); + + // Concatenate all internal vector data + std::vector<T> values_concat; + for (const auto& v: values) { + values_concat.insert(values_concat.end(), v.begin(), v.end()); + } + + // Gather all concatenated vector data + auto global_vec_concat = gather_all(values_concat, comm); + + // Construct the vector of vectors + std::vector<std::vector<T>> global_vec; + global_vec.reserve(displs_internal_part.size()); + + for (const auto& internal_vec_range: displs_internal_part) { + global_vec.emplace_back(global_vec_concat.begin()+internal_vec_range.first, + global_vec_concat.begin()+internal_vec_range.second); + } + + return global_vec; +} + /// Gather all of a distributed vector /// Retains the meta data (i.e. vector partition) template <typename T> diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index fa1e57c1388ea926719f798caf642bb8284ed2fc..22a2428343b235dd3bbcae3941ee030372f2ce4b 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -39,6 +39,11 @@ struct mpi_context_impl { return mpi::gather_all_with_partition(local_gids, comm_); } + std::vector<std::vector<cell_gid_type>> + gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const { + return mpi::gather_all(local_connections, comm_); + } + cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { std::vector<cell_size_type> sizes; std::vector<cell_tag_type> labels; diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index ab3ec6587478e991760c9f7d46ef80cbfdce2631..2cea04e58fcf79c09c2360ba9c0d9310c26c7c2e 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -44,6 +44,7 @@ class distributed_context { public: using spike_vector = std::vector<arb::spike>; using gid_vector = std::vector<cell_gid_type>; + using gj_connection_vector = std::vector<gid_vector>; // default constructor uses a local context: see below. distributed_context(); @@ -64,6 +65,10 @@ public: return impl_->gather_gids(local_gids); } + gj_connection_vector gather_gj_connections(const gj_connection_vector& local_connections) const { + return impl_->gather_gj_connections(local_connections); + } + cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { return impl_->gather_cell_label_range(local_ranges); } @@ -100,6 +105,8 @@ private: gather_spikes(const spike_vector& local_spikes) const = 0; virtual gathered_vector<cell_gid_type> gather_gids(const gid_vector& local_gids) const = 0; + virtual gj_connection_vector + gather_gj_connections(const gj_connection_vector& local_connections) const = 0; virtual cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const = 0; virtual cell_labels_and_gids @@ -129,6 +136,10 @@ private: gather_gids(const gid_vector& local_gids) const override { return wrapped.gather_gids(local_gids); } + std::vector<std::vector<cell_gid_type>> + gather_gj_connections(const gj_connection_vector& local_connections) const override { + return wrapped.gather_gj_connections(local_connections); + } cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const override { return wrapped.gather_cell_label_range(local_ranges); @@ -179,6 +190,10 @@ struct local_context { {0u, static_cast<count_type>(local_gids.size())} ); } + std::vector<std::vector<cell_gid_type>> + gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const { + return local_connections; + } cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { return local_ranges; diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp index f1f064b7957ec55f788e0f9e98d15ab85f853df7..73227f9238517c92a6217b3beb1f52ea0c78ee7a 100644 --- a/arbor/partition_load_balance.cpp +++ b/arbor/partition_load_balance.cpp @@ -67,6 +67,44 @@ domain_decomposition partition_load_balance( auto gid_part = make_partition( gid_divisions, transform_view(make_span(num_domains), dom_size)); + // Global gj_connection table + + // Generate a local gj_connection table. + // The table is indexed by the index of the target gid in the gid_part of that domain. + // If gid_part[domain_id] = [a, b); local_gj_connection of gid `x` is at index `x-a`. + const auto dom_range = gid_part[domain_id]; + std::vector<std::vector<cell_gid_type>> local_gj_connection_table(dom_range.second-dom_range.first); + for (auto gid: make_span(gid_part[domain_id])) { + for (const auto& c: rec.gap_junctions_on(gid)) { + local_gj_connection_table[gid-dom_range.first].push_back(c.peer.gid); + } + } + // Sort the gj connections of each local cell. + for (auto& gid_conns: local_gj_connection_table) { + util::sort(gid_conns); + } + + // Gather the global gj_connection table. + // The global gj_connection table after gathering is indexed by gid. + auto global_gj_connection_table = ctx->distributed->gather_gj_connections(local_gj_connection_table); + + // Make all gj_connections bidirectional. + std::vector<std::unordered_set<cell_gid_type>> missing_peers(global_gj_connection_table.size()); + for (auto gid: make_span(global_gj_connection_table.size())) { + const auto& local_conns = global_gj_connection_table[gid]; + for (auto peer: local_conns) { + auto& peer_conns = global_gj_connection_table[peer]; + // If gid is not in the peer connection table insert it into the + // missing_peers set + if (!std::binary_search(peer_conns.begin(), peer_conns.end(), gid)) { + missing_peers[peer].insert(gid); + } + } + } + // Append the missing peers into the global_gj_connections table + for (unsigned i = 0; i < global_gj_connection_table.size(); ++i) { + std::move(missing_peers[i].begin(), missing_peers[i].end(), std::back_inserter(global_gj_connection_table[i])); + } // Local load balance std::vector<std::vector<cell_gid_type>> super_cells; //cells connected by gj @@ -78,7 +116,7 @@ domain_decomposition partition_load_balance( // Connected components algorithm using BFS std::queue<cell_gid_type> q; for (auto gid: make_span(gid_part[domain_id])) { - if (!rec.gap_junctions_on(gid).empty()) { + if (!global_gj_connection_table[gid].empty()) { // If cell hasn't been visited yet, must belong to new super_cell // Perform BFS starting from that cell if (!visited.count(gid)) { @@ -90,11 +128,9 @@ domain_decomposition partition_load_balance( q.pop(); cg.push_back(element); // Adjacency list - auto conns = rec.gap_junctions_on(element); - for (auto c: conns) { - if (!visited.count(c.peer.gid)) { - visited.insert(c.peer.gid); - q.push(c.peer.gid); + for (const auto& peer: global_gj_connection_table[element]) { + if (visited.insert(peer).second) { + q.push(peer); } } } diff --git a/test/unit-distributed/test_domain_decomposition.cpp b/test/unit-distributed/test_domain_decomposition.cpp index 92a7a45374bb850fa03d9f09aaf18d88bf481a47..eb539ea54123ba35141f09dcfffb04d65f60f60b 100644 --- a/test/unit-distributed/test_domain_decomposition.cpp +++ b/test/unit-distributed/test_domain_decomposition.cpp @@ -66,7 +66,9 @@ namespace { class gj_symmetric: public recipe { public: - gj_symmetric(unsigned num_ranks): ncopies_(num_ranks){} + gj_symmetric(unsigned num_ranks, bool fully_connected): + ncopies_(num_ranks), + fully_connected_(fully_connected) {} cell_size_type num_cells() const override { return size_*ncopies_; @@ -79,22 +81,34 @@ namespace { cell_kind get_cell_kind(cell_gid_type gid) const override { return cell_kind::cable; } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { unsigned shift = (gid/size_)*size_; switch (gid % size_) { - case 1 : return { gap_junction_connection({7 + shift, "gj"}, {"gj"}, 0.1)}; - case 2 : return { - gap_junction_connection({6 + shift, "gj"}, {"gj"}, 0.1), - gap_junction_connection({9 + shift, "gj"}, {"gj"}, 0.1) - }; + case 1 : { + if (!fully_connected_) return {}; + return {gap_junction_connection({7 + shift, "gj"}, {"gj"}, 0.1)}; + } + case 2 : { + if (!fully_connected_) return {}; + return { + gap_junction_connection({6 + shift, "gj"}, {"gj"}, 0.1), + gap_junction_connection({9 + shift, "gj"}, {"gj"}, 0.1) + }; + } case 6 : return { gap_junction_connection({2 + shift, "gj"}, {"gj"}, 0.1), gap_junction_connection({7 + shift, "gj"}, {"gj"}, 0.1) }; - case 7 : return { - gap_junction_connection({6 + shift, "gj"}, {"gj"}, 0.1), - gap_junction_connection({1 + shift, "gj"}, {"gj"}, 0.1) - }; + case 7 : { + if (!fully_connected_) { + return {gap_junction_connection({1 + shift, "gj"}, {"gj"}, 0.1)}; + } + return { + gap_junction_connection({6 + shift, "gj"}, {"gj"}, 0.1), + gap_junction_connection({1 + shift, "gj"}, {"gj"}, 0.1) + }; + } case 9 : return { gap_junction_connection({2 + shift, "gj"}, {"gj"}, 0.1)}; default : return {}; } @@ -103,11 +117,15 @@ namespace { private: cell_size_type size_ = 10; unsigned ncopies_; + bool fully_connected_; }; - class gj_non_symmetric: public recipe { + class gj_multi_group: public recipe { public: - gj_non_symmetric(unsigned num_ranks): groups_(num_ranks), size_(num_ranks){} + gj_multi_group(unsigned num_ranks, bool fully_connected): + groups_(num_ranks), + size_(num_ranks), + fully_connected_(fully_connected) {} cell_size_type num_cells() const override { return size_*groups_; @@ -120,10 +138,26 @@ namespace { cell_kind get_cell_kind(cell_gid_type gid) const override { return cell_kind::cable; } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + // Example topology on 4 ranks of 4 cells each + // Fully connected Not fully connected + // 0 1 2 3 0 1 2 3 + // ^ | + // | | + // v v + // 4 5 6 7 4 5 6 7 + // ^ | + // | | + // v v + // 8 9 10 11 8 9 10 11 + // ^ | + // | | + // v v + // 12 13 14 15 12 13 14 15 unsigned group = gid/groups_; unsigned id = gid%size_; - if (id == group && group != (groups_ - 1)) { + if (id == group && group != (groups_ - 1) && fully_connected_) { return {gap_junction_connection({gid + size_, "gj"}, {"gj"}, 0.1)}; } else if (id == group - 1) { @@ -134,6 +168,50 @@ namespace { } } + private: + unsigned groups_; + cell_size_type size_; + bool fully_connected_; + }; + + class gj_single_group: public recipe { + public: + gj_single_group(unsigned num_ranks): + groups_(num_ranks), + size_(num_ranks){} + + cell_size_type num_cells() const override { + return size_*groups_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + return {}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable; + } + + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + // Example topology on 4 ranks of 4 cells each + // 0 1 2 3 + // | + // v + // 4 5 6 7 + // | + // v + // 8 9 10 11 + // | + // v + // 12 13 14 15 + unsigned group = gid/groups_; + unsigned id = gid%size_; + if (group!= 0 && id == 1) { + return {gap_junction_connection({gid - size_, "gj"}, {"gj"}, 0.1)}; + } + return {}; + } + private: unsigned groups_; cell_size_type size_; @@ -288,8 +366,7 @@ TEST(domain_decomposition, heterogeneous_population) { } } -TEST(domain_decomposition, symmetric_groups) -{ +TEST(domain_decomposition, symmetric_groups) { proc_allocation resources{1, -1}; int nranks = 1; int rank = 0; @@ -300,68 +377,68 @@ TEST(domain_decomposition, symmetric_groups) #else auto ctx = make_context(resources); #endif - auto R = gj_symmetric(nranks); - const auto D0 = partition_load_balance(R, ctx); - EXPECT_EQ(6u, D0.groups.size()); - - unsigned shift = rank*R.num_cells()/nranks; - std::vector<std::vector<cell_gid_type>> expected_groups0 = - { {0 + shift}, - {3 + shift}, - {4 + shift}, - {5 + shift}, - {8 + shift}, - {1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift} - }; - - for (unsigned i = 0; i < 6; i++){ - EXPECT_EQ(expected_groups0[i], D0.groups[i].gids); - } + std::vector<gj_symmetric> recipes = {gj_symmetric(nranks, true), gj_symmetric(nranks, false)}; + for (const auto& R: recipes) { + const auto D0 = partition_load_balance(R, ctx); + EXPECT_EQ(6u, D0.groups.size()); + + unsigned shift = rank * R.num_cells()/nranks; + std::vector<std::vector<cell_gid_type>> expected_groups0 = + {{0 + shift}, + {3 + shift}, + {4 + shift}, + {5 + shift}, + {8 + shift}, + {1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift} + }; - unsigned cells_per_rank = R.num_cells()/nranks; - for (unsigned i = 0; i < R.num_cells(); i++) { - EXPECT_EQ(i/cells_per_rank, (unsigned)D0.gid_domain(i)); - } + for (unsigned i = 0; i < 6; i++) { + EXPECT_EQ(expected_groups0[i], D0.groups[i].gids); + } - // Test different group_hints - partition_hint_map hints; - hints[cell_kind::cable].cpu_group_size = R.num_cells(); - hints[cell_kind::cable].prefer_gpu = false; + unsigned cells_per_rank = R.num_cells()/nranks; + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned) D0.gid_domain(i)); + } - const auto D1 = partition_load_balance(R, ctx, hints); - EXPECT_EQ(1u, D1.groups.size()); + // Test different group_hints + partition_hint_map hints; + hints[cell_kind::cable].cpu_group_size = R.num_cells(); + hints[cell_kind::cable].prefer_gpu = false; - std::vector<cell_gid_type> expected_groups1 = - { 0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift, - 1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift }; + const auto D1 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(1u, D1.groups.size()); - EXPECT_EQ(expected_groups1, D1.groups[0].gids); + std::vector<cell_gid_type> expected_groups1 = + {0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift, + 1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift}; - for (unsigned i = 0; i < R.num_cells(); i++) { - EXPECT_EQ(i/cells_per_rank, (unsigned)D1.gid_domain(i)); - } + EXPECT_EQ(expected_groups1, D1.groups[0].gids); - hints[cell_kind::cable].cpu_group_size = cells_per_rank/2; - hints[cell_kind::cable].prefer_gpu = false; + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned) D1.gid_domain(i)); + } - const auto D2 = partition_load_balance(R, ctx, hints); - EXPECT_EQ(2u, D2.groups.size()); + hints[cell_kind::cable].cpu_group_size = cells_per_rank/2; + hints[cell_kind::cable].prefer_gpu = false; - std::vector<std::vector<cell_gid_type>> expected_groups2 = - { { 0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift }, - { 1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift } }; + const auto D2 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(2u, D2.groups.size()); - for (unsigned i = 0; i < 2u; i++) { - EXPECT_EQ(expected_groups2[i], D2.groups[i].gids); - } - for (unsigned i = 0; i < R.num_cells(); i++) { - EXPECT_EQ(i/cells_per_rank, (unsigned)D2.gid_domain(i)); - } + std::vector<std::vector<cell_gid_type>> expected_groups2 = + {{0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift}, + {1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift}}; + for (unsigned i = 0; i < 2u; i++) { + EXPECT_EQ(expected_groups2[i], D2.groups[i].gids); + } + for (unsigned i = 0; i < R.num_cells(); i++) { + EXPECT_EQ(i/cells_per_rank, (unsigned) D2.gid_domain(i)); + } + } } -TEST(domain_decomposition, non_symmetric_groups) -{ +TEST(domain_decomposition, gj_multi_distributed_groups) { proc_allocation resources{1, -1}; int nranks = 1; int rank = 0; @@ -372,25 +449,69 @@ TEST(domain_decomposition, non_symmetric_groups) #else auto ctx = make_context(resources); #endif - /*if (nranks == 1) { - return; - }*/ + std::vector<gj_multi_group> recipes = {gj_multi_group(nranks, true), gj_multi_group(nranks, false)}; + for (const auto& R: recipes) { + const auto D = partition_load_balance(R, ctx); + + unsigned cells_per_rank = nranks; + // check groups + unsigned i = 0; + for (unsigned gid = rank * cells_per_rank; gid < (rank + 1) * cells_per_rank; gid++) { + if (gid % nranks == (unsigned) rank - 1) { + continue; + } else if (gid % nranks == (unsigned) rank && rank != nranks - 1) { + std::vector<cell_gid_type> cg = {gid, gid + cells_per_rank}; + EXPECT_EQ(cg, D.groups[D.groups.size() - 1].gids); + } else { + std::vector<cell_gid_type> cg = {gid}; + EXPECT_EQ(cg, D.groups[i++].gids); + } + } + // check gid_domains + for (unsigned gid = 0; gid < R.num_cells(); gid++) { + auto group = gid / cells_per_rank; + auto idx = gid % cells_per_rank; + unsigned ngroups = nranks; + if (idx == group - 1) { + EXPECT_EQ(group - 1, (unsigned) D.gid_domain(gid)); + } else if (idx == group && group != ngroups - 1) { + EXPECT_EQ(group, (unsigned) D.gid_domain(gid)); + } else { + EXPECT_EQ(group, (unsigned) D.gid_domain(gid)); + } + } + } +} - auto R = gj_non_symmetric(nranks); +TEST(domain_decomposition, gj_single_distributed_group) { + proc_allocation resources{1, -1}; + int nranks = 1; + int rank = 0; +#ifdef TEST_MPI + auto ctx = make_context(resources, MPI_COMM_WORLD); + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); +#else + auto ctx = make_context(resources); +#endif + auto R = gj_single_group(nranks); const auto D = partition_load_balance(R, ctx); unsigned cells_per_rank = nranks; // check groups unsigned i = 0; - for (unsigned gid = rank*cells_per_rank; gid < (rank + 1)*cells_per_rank; gid++) { - if (gid % nranks == (unsigned)rank - 1) { - continue; - } - else if (gid % nranks == (unsigned)rank && rank != nranks - 1) { - std::vector<cell_gid_type> cg = {gid, gid+cells_per_rank}; - EXPECT_EQ(cg, D.groups[D.groups.size()-1].gids); - } - else { + for (unsigned gid = rank * cells_per_rank; gid < (rank + 1) * cells_per_rank; gid++) { + if (gid%nranks == 1) { + if (rank == 0) { + std::vector<cell_gid_type> cg; + for (int r = 0; r < nranks; ++r) { + cg.push_back(gid + (r * nranks)); + } + EXPECT_EQ(cg, D.groups.back().gids); + } else { + continue; + } + } else { std::vector<cell_gid_type> cg = {gid}; EXPECT_EQ(cg, D.groups[i++].gids); } @@ -398,16 +519,11 @@ TEST(domain_decomposition, non_symmetric_groups) // check gid_domains for (unsigned gid = 0; gid < R.num_cells(); gid++) { auto group = gid/cells_per_rank; - auto idx = gid % cells_per_rank; - unsigned ngroups = nranks; - if (idx == group - 1) { - EXPECT_EQ(group - 1, (unsigned)D.gid_domain(gid)); - } - else if (idx == group && group != ngroups - 1) { - EXPECT_EQ(group, (unsigned)D.gid_domain(gid)); - } - else { - EXPECT_EQ(group, (unsigned)D.gid_domain(gid)); + auto idx = gid%cells_per_rank; + if (idx == 1) { + EXPECT_EQ(0u, (unsigned) D.gid_domain(gid)); + } else { + EXPECT_EQ(group, (unsigned) D.gid_domain(gid)); } } -} +} \ No newline at end of file diff --git a/test/unit-distributed/test_mpi.cpp b/test/unit-distributed/test_mpi.cpp index 95ae3a8a649e05ff23944ace249f21321ab24390..4097eb488dd8207f7f6e1c292e6b173326985c81 100644 --- a/test/unit-distributed/test_mpi.cpp +++ b/test/unit-distributed/test_mpi.cpp @@ -58,6 +58,40 @@ TEST(mpi, gather_all) { EXPECT_EQ(expected, gathered); } +TEST(mpi, gather_all_nested_vec) { + int id = mpi::rank(MPI_COMM_WORLD); + int size = mpi::size(MPI_COMM_WORLD); + + std::vector<std::vector<big_thing>> data; + if (id%2) { + for (int s = 0; s < (id+1)*2; s++) { + data.push_back({id + s, id + s + 7, id + s + 8}); + } + } + else { + for (int s = 0; s < (id+1)*2; s++) { + data.push_back({id + 2*s}); + } + } + + std::vector<std::vector<big_thing>> expected; + for (int i = 0; i<size; ++i) { + if (i%2) { + for (int s = 0; s < (i+1)*2; s++) { + expected.push_back({i + s, i + s + 7, i + s + 8}); + } + } + else { + for (int s = 0; s < (i+1)*2; s++) { + expected.push_back({i + 2*s}); + } + } + } + + auto gathered = mpi::gather_all(data, MPI_COMM_WORLD); + EXPECT_EQ(expected, gathered); +} + TEST(mpi, gather_all_with_partition) { int id = mpi::rank(MPI_COMM_WORLD); int size = mpi::size(MPI_COMM_WORLD); diff --git a/test/unit/test_domain_decomposition.cpp b/test/unit/test_domain_decomposition.cpp index 2d71b23990634da22a07525ec4031d67a7d232f1..cae46a3cb0c7f60764dbc3948accd0516e21cf1d 100644 --- a/test/unit/test_domain_decomposition.cpp +++ b/test/unit/test_domain_decomposition.cpp @@ -60,7 +60,7 @@ namespace { class gap_recipe: public recipe { public: - gap_recipe() {} + gap_recipe(bool full_connected): fully_connected_(full_connected) {} cell_size_type num_cells() const override { return size_; @@ -77,47 +77,73 @@ namespace { } std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { switch (gid) { - case 0 : return { - gap_junction_connection({13, "gj"}, {"gj"}, 0.1) - }; - case 2 : return { - gap_junction_connection({7, "gj"}, {"gj"}, 0.1), - gap_junction_connection({11, "gj"}, {"gj"}, 0.1) - }; - case 3 : return { - gap_junction_connection({4, "gj"}, {"gj"}, 0.1), - gap_junction_connection({8, "gj"}, {"gj"}, 0.1) - }; - case 4 : return { - gap_junction_connection({3, "gj"}, {"gj"}, 0.1), - gap_junction_connection({8, "gj"}, {"gj"}, 0.1), - gap_junction_connection({9, "gj"}, {"gj"}, 0.1) - }; - case 7 : return { - gap_junction_connection({2, "gj"}, {"gj"}, 0.1), - gap_junction_connection({11, "gj"}, {"gj"}, 0.1) - };; - case 8 : return { - gap_junction_connection({3, "gj"}, {"gj"}, 0.1), - gap_junction_connection({4, "gj"}, {"gj"}, 0.1) - };; - case 9 : return { - gap_junction_connection({4, "gj"}, {"gj"}, 0.1) - }; - case 11 : return { - gap_junction_connection({2, "gj"}, {"gj"}, 0.1), - gap_junction_connection({7, "gj"}, {"gj"}, 0.1) - }; - case 13 : return { - gap_junction_connection({0, "gj"}, {"gj"}, 0.1) - }; - default : return {}; + case 0: return {gap_junction_connection({13, "gj"}, {"gj"}, 0.1)}; + case 2: return {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}; + case 3: return {gap_junction_connection({8, "gj"}, {"gj"}, 0.1)}; + case 4: { + if (!fully_connected_) return {gap_junction_connection({9, "gj"}, {"gj"}, 0.1)}; + return { + gap_junction_connection({8, "gj"}, {"gj"}, 0.1), + gap_junction_connection({9, "gj"}, {"gj"}, 0.1) + }; + } + case 7: { + if (!fully_connected_) return {}; + return { + gap_junction_connection({2, "gj"}, {"gj"}, 0.1), + gap_junction_connection({11, "gj"}, {"gj"}, 0.1) + }; + } + case 8: { + if (!fully_connected_) return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}; + return { + gap_junction_connection({3, "gj"}, {"gj"}, 0.1), + gap_junction_connection({4, "gj"}, {"gj"}, 0.1) + }; + } + case 9: { + if (!fully_connected_) return {}; + return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}; + } + case 11: return {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}; + case 13: { + if (!fully_connected_) return {}; + return { gap_junction_connection({0, "gj"}, {"gj"}, 0.1)}; + } + default: return {}; } } private: + bool fully_connected_ = true; cell_size_type size_ = 15; }; + + class custom_gap_recipe: public recipe { + public: + custom_gap_recipe(cell_size_type ncells, std::vector<std::vector<gap_junction_connection>> gj_conns): + size_(ncells), gj_conns_(std::move(gj_conns)){} + + cell_size_type num_cells() const override { + return size_; + } + + arb::util::unique_any get_cell_description(cell_gid_type) const override { + auto c = arb::make_cell_soma_only(false); + c.decorations.place(mlocation{0,1}, junction("gj"), "gj"); + return {arb::cable_cell(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + return gj_conns_[gid]; + } + private: + cell_size_type size_ = 7; + std::vector<std::vector<gap_junction_connection>> gj_conns_; + }; } // test assumes one domain @@ -316,48 +342,119 @@ TEST(domain_decomposition, hints) { EXPECT_EQ(expected_ss_groups, ss_groups); } -TEST(domain_decomposition, compulsory_groups) -{ +TEST(domain_decomposition, gj_recipe) { proc_allocation resources; resources.num_threads = 1; resources.gpu_id = -1; // disable GPU if available auto ctx = make_context(resources); - auto R = gap_recipe(); - const auto D0 = partition_load_balance(R, ctx); - EXPECT_EQ(9u, D0.groups.size()); + auto recipes = {gap_recipe(false), gap_recipe(true)}; + for (const auto& R: recipes) { + const auto D0 = partition_load_balance(R, ctx); + EXPECT_EQ(9u, D0.groups.size()); - std::vector<std::vector<cell_gid_type>> expected_groups0 = - { {1}, {5}, {6}, {10}, {12}, {14}, {0, 13}, {2, 7, 11}, {3, 4, 8, 9} }; + std::vector<std::vector<cell_gid_type>> expected_groups0 = + {{1}, {5}, {6}, {10}, {12}, {14}, {0, 13}, {2, 7, 11}, {3, 4, 8, 9}}; - for (unsigned i = 0; i < 9u; i++) { - EXPECT_EQ(expected_groups0[i], D0.groups[i].gids); - } + for (unsigned i = 0; i < 9u; i++) { + EXPECT_EQ(expected_groups0[i], D0.groups[i].gids); + } - // Test different group_hints - partition_hint_map hints; - hints[cell_kind::cable].cpu_group_size = 3; - hints[cell_kind::cable].prefer_gpu = false; + // Test different group_hints + partition_hint_map hints; + hints[cell_kind::cable].cpu_group_size = 3; + hints[cell_kind::cable].prefer_gpu = false; - const auto D1 = partition_load_balance(R, ctx, hints); - EXPECT_EQ(5u, D1.groups.size()); + const auto D1 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(5u, D1.groups.size()); - std::vector<std::vector<cell_gid_type>> expected_groups1 = - { {1, 5, 6}, {10, 12, 14}, {0, 13}, {2, 7, 11}, {3, 4, 8, 9} }; + std::vector<std::vector<cell_gid_type>> expected_groups1 = + {{1, 5, 6}, {10, 12, 14}, {0, 13}, {2, 7, 11}, {3, 4, 8, 9}}; - for (unsigned i = 0; i < 5u; i++) { - EXPECT_EQ(expected_groups1[i], D1.groups[i].gids); - } + for (unsigned i = 0; i < 5u; i++) { + EXPECT_EQ(expected_groups1[i], D1.groups[i].gids); + } - hints[cell_kind::cable].cpu_group_size = 20; - hints[cell_kind::cable].prefer_gpu = false; + hints[cell_kind::cable].cpu_group_size = 20; + hints[cell_kind::cable].prefer_gpu = false; - const auto D2 = partition_load_balance(R, ctx, hints); - EXPECT_EQ(1u, D2.groups.size()); + const auto D2 = partition_load_balance(R, ctx, hints); + EXPECT_EQ(1u, D2.groups.size()); + + std::vector<cell_gid_type> expected_groups2 = + {1, 5, 6, 10, 12, 14, 0, 13, 2, 7, 11, 3, 4, 8, 9}; + + EXPECT_EQ(expected_groups2, D2.groups[0].gids); + } +} + +TEST(domain_decomposition, unidirectional_gj_recipe) { + proc_allocation resources; + resources.num_threads = 1; + resources.gpu_id = -1; + auto ctx = make_context(resources); + { + std::vector<std::vector<gap_junction_connection>> gj_conns = + { + {gap_junction_connection({1, "gj"}, {"gj"}, 0.1)}, + {}, + {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({0, "gj"}, {"gj"}, 0.1), gap_junction_connection({5, "gj"}, {"gj"}, 0.1)}, + {}, + {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)} + }; + auto R = custom_gap_recipe(gj_conns.size(), gj_conns); + const auto D = partition_load_balance(R, ctx); + std::vector<cell_gid_type> expected_group = {0, 1, 2, 3, 4, 5, 6}; - std::vector<cell_gid_type> expected_groups2 = - { 1, 5, 6, 10, 12, 14, 0, 13, 2, 7, 11, 3, 4, 8, 9 }; + EXPECT_EQ(1u, D.groups.size()); + EXPECT_EQ(expected_group, D.groups[0].gids); + } + { + std::vector<std::vector<gap_junction_connection>> gj_conns = + { + {}, + {gap_junction_connection({3, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({0, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({0, "gj"}, {"gj"}, 0.1), gap_junction_connection({1, "gj"}, {"gj"}, 0.1)}, + {}, + {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}, + {}, + {gap_junction_connection({9, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({8, "gj"}, {"gj"}, 0.1)} + }; + auto R = custom_gap_recipe(gj_conns.size(), gj_conns); + const auto D = partition_load_balance(R, ctx); + std::vector<std::vector<cell_gid_type>> expected_groups = {{6}, {0, 1, 2, 3}, {4, 5}, {7, 8, 9}}; - EXPECT_EQ(expected_groups2, D2.groups[0].gids); + EXPECT_EQ(expected_groups.size(), D.groups.size()); + for (unsigned i=0; i < expected_groups.size(); ++i) { + EXPECT_EQ(expected_groups[i], D.groups[i].gids); + } + } + { + std::vector<std::vector<gap_junction_connection>> gj_conns = + { + {}, + {}, + {}, + {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}, + {}, + {}, + {gap_junction_connection({5, "gj"}, {"gj"}, 0.1), gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({5, "gj"}, {"gj"}, 0.1), gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({0, "gj"}, {"gj"}, 0.1)}, + {} + }; + auto R = custom_gap_recipe(gj_conns.size(), gj_conns); + const auto D = partition_load_balance(R, ctx); + std::vector<std::vector<cell_gid_type>> expected_groups = {{1}, {2}, {9}, {0, 8}, {3, 4, 5, 6, 7}}; + EXPECT_EQ(expected_groups.size(), D.groups.size()); + for (unsigned i=0; i < expected_groups.size(); ++i) { + EXPECT_EQ(expected_groups[i], D.groups[i].gids); + } + } }