diff --git a/tests/global_communication/mpi_listener.hpp b/tests/global_communication/mpi_listener.hpp index 76cdfd9bcc321519b1102bb4fa1895286518b0d8..c255c6be82026cc19b7d29d2a4ecec686126045a 100644 --- a/tests/global_communication/mpi_listener.hpp +++ b/tests/global_communication/mpi_listener.hpp @@ -59,7 +59,7 @@ private: /// TODO : it might be an idea to use a resizeable buffer template <typename... Args> void printf_helper(const char* s, Args&&... args) { - snprintf(buffer_, sizeof(buffer_), s, std::forward<Args>(args)...); + std::snprintf(buffer_, sizeof(buffer_), s, std::forward<Args>(args)...); print(buffer_); } @@ -110,7 +110,7 @@ public: test_case.name() ); } - printf_helper("\n"); + print("\n"); } // Called before a test starts. diff --git a/tests/global_communication/test_communicator.cpp b/tests/global_communication/test_communicator.cpp index 154c458f3d42e61a94075793081b7bc5d8fa8c2b..e2f65c01c078b1af47e4ae8a95c3927d4eeadead 100644 --- a/tests/global_communication/test_communicator.cpp +++ b/tests/global_communication/test_communicator.cpp @@ -1,9 +1,6 @@ #include "../gtest.h" -#include <cstdio> -#include <fstream> -#include <iostream> -#include <string> +#include <stdexcept> #include <vector> #include <communication/communicator.hpp> @@ -200,9 +197,9 @@ namespace { return gid%2? cell_kind::cable1d_neuron: cell_kind::regular_spike_source; } - cell_count_info get_cell_count_info(cell_gid_type) const override { - return {1, 1, 0}; - } + cell_size_type num_sources(cell_gid_type) const override { return 1; } + cell_size_type num_targets(cell_gid_type) const override { return 1; } + cell_size_type num_probes(cell_gid_type) const override { return 0; } std::vector<cell_connection> connections_on(cell_gid_type gid) const override { // a single connection from the preceding cell, i.e. a ring @@ -216,12 +213,15 @@ namespace { 1.0f)}; // delay } + probe_info get_probe(cell_member_type) const override { + throw std::logic_error("no probes"); + } + private: cell_size_type size_; cell_size_type ranks_; }; - cell_gid_type source_of(cell_gid_type gid, cell_size_type num_cells) { if (gid) { return gid-1; @@ -264,9 +264,9 @@ namespace { return gid%2? cell_kind::cable1d_neuron: cell_kind::regular_spike_source; } - cell_count_info get_cell_count_info(cell_gid_type) const override { - return {1, size_, 0}; // sources, targets, probes - } + cell_size_type num_sources(cell_gid_type) const override { return 1; } + cell_size_type num_targets(cell_gid_type) const override { return size_; } + cell_size_type num_probes(cell_gid_type) const override { return 0; } std::vector<cell_connection> connections_on(cell_gid_type gid) const override { std::vector<cell_connection> cons; @@ -282,6 +282,10 @@ namespace { return cons; } + probe_info get_probe(cell_member_type) const override { + throw std::logic_error("no probes"); + } + private: cell_size_type size_; cell_size_type ranks_; @@ -326,7 +330,6 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) { using util::assign_from; using util::filter; - auto gids = get_gids(D); auto group_map = get_group_map(D); @@ -338,7 +341,7 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) { // gather the global set of spikes auto global_spikes = C.exchange(local_spikes); if (global_spikes.size()!=policy::sum(local_spikes.size())) { - return ::testing::AssertionFailure() << " the number of gathered spikes " + return ::testing::AssertionFailure() << "the number of gathered spikes " << global_spikes.size() << " doesn't match the expected " << policy::sum(local_spikes.size()); } @@ -354,6 +357,7 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) { // Iterate over each local gid, and testing whether an event is expected for // that gid. If so, look up the event queue of the cell_group of gid, and // search for the expected event. + int expected_count = 0; for (auto gid: gids) { auto src = source_of(gid, D.num_global_cells); if (f(src)) { @@ -362,8 +366,9 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) { auto& q = queues[grp]; if (std::find(q.begin(), q.end(), expected)==q.end()) { return ::testing::AssertionFailure() - << " expected event " << expected << " was not found"; + << "expected event " << expected << " was not found"; } + ++expected_count; } } @@ -373,9 +378,12 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) { int num_events = std::accumulate(queues.begin(), queues.end(), 0, [](int l, decltype(queues.front())& r){return l + r.size();}); - int expected_events = util::size(filter(gids, f)); + if (expected_count!=num_events) { + return ::testing::AssertionFailure() << + "the number of events " << num_events << + " does not match expected count " << expected_count; + } - EXPECT_EQ(policy::sum(expected_events), policy::sum(num_events)); return ::testing::AssertionSuccess(); } @@ -399,9 +407,9 @@ TEST(communicator, ring) EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return true;})); // last cell in each domain fires EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return (g+1)%n_local == 0u;})); - // oddly numbered cells fire + // even-numbered cells fire EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return g%2==0;})); - // oddly numbered cells fire + // odd-numbered cells fire EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return g%2==1;})); } @@ -422,12 +430,12 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) { std::reverse(local_spikes.begin(), local_spikes.end()); std::vector<cell_gid_type> spike_gids = assign_from( - filter(make_span(0, D.groups.size()), f)); + filter(make_span(0, D.num_global_cells), f)); // gather the global set of spikes auto global_spikes = C.exchange(local_spikes); if (global_spikes.size()!=policy::sum(local_spikes.size())) { - return ::testing::AssertionFailure() << " the number of gathered spikes " + return ::testing::AssertionFailure() << "the number of gathered spikes " << global_spikes.size() << " doesn't match the expected " << policy::sum(local_spikes.size()); } @@ -443,6 +451,7 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) { // Iterate over each local gid, and testing whether an event is expected for // that gid. If so, look up the event queue of the cell_group of gid, and // search for the expected event. + int expected_count = 0; for (auto gid: gids) { // get the event queue that this gid belongs to auto& q = queues[group_map[gid]]; @@ -453,6 +462,7 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) { << "expected event " << expected << " from " << src << " was not found"; } + ++expected_count; } } @@ -462,9 +472,12 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) { int num_events = std::accumulate(queues.begin(), queues.end(), 0, [](int l, decltype(queues.front())& r){return l + r.size();}); - int expected_events = D.num_global_cells*spike_gids.size(); + if (expected_count!=num_events) { + return ::testing::AssertionFailure() << + "the number of events " << num_events << + " does not match expected count " << expected_count; + } - EXPECT_EQ(expected_events, policy::sum(num_events)); return ::testing::AssertionSuccess(); } @@ -488,8 +501,8 @@ TEST(communicator, all2all) EXPECT_TRUE(test_all2all(D, C, [](cell_gid_type g){return true;})); // only cell 0 fires EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g==0u;})); - // oddly numbered cells fire + // even-numbered cells fire EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g%2==0;})); - // oddly numbered cells fire + // odd-numbered cells fire EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g%2==1;})); } diff --git a/tests/global_communication/test_domain_decomposition.cpp b/tests/global_communication/test_domain_decomposition.cpp index c90114396fb3ce4cd3fac0806d70b8fbdd917e2a..e6ddfd86c4a5f0a502fa80ea47191c919a81fb11 100644 --- a/tests/global_communication/test_domain_decomposition.cpp +++ b/tests/global_communication/test_domain_decomposition.cpp @@ -3,6 +3,7 @@ #include <cstdio> #include <fstream> #include <iostream> +#include <stdexcept> #include <string> #include <vector> @@ -11,38 +12,17 @@ #include <hardware/node_info.hpp> #include <load_balance.hpp> +#include "../simple_recipes.hpp" + using namespace nest::mc; using communicator_type = communication::communicator<communication::global_policy>; namespace { - // Homogenous cell population of cable cells. - class homo_recipe: public recipe { - public: - homo_recipe(cell_size_type s): size_(s) - {} + // Dummy recipes types for testing. - cell_size_type num_cells() const override { - return size_; - } - - util::unique_any get_cell_description(cell_gid_type) const override { - return {}; - } - cell_kind get_cell_kind(cell_gid_type) const override { - return cell_kind::cable1d_neuron; - } - - cell_count_info get_cell_count_info(cell_gid_type) const override { - return {0, 0, 0}; - } - std::vector<cell_connection> connections_on(cell_gid_type) const override { - return {}; - } - - private: - cell_size_type size_; - }; + struct dummy_cell {}; + using homo_recipe = homogeneous_recipe<cell_kind::cable1d_neuron, dummy_cell>; // Heterogenous cell population of cable and rss cells. // Interleaved so that cells with even gid are cable cells, and even gid are @@ -59,25 +39,31 @@ namespace { util::unique_any get_cell_description(cell_gid_type) const override { return {}; } + cell_kind get_cell_kind(cell_gid_type gid) const override { return gid%2? cell_kind::regular_spike_source: cell_kind::cable1d_neuron; } - cell_count_info get_cell_count_info(cell_gid_type) const override { - return {0, 0, 0}; - } + cell_size_type num_sources(cell_gid_type) const override { return 0; } + cell_size_type num_targets(cell_gid_type) const override { return 0; } + cell_size_type num_probes(cell_gid_type) const override { return 0; } + std::vector<cell_connection> connections_on(cell_gid_type) const override { return {}; } + probe_info get_probe(cell_member_type) const override { + throw std::logic_error("no probes"); + } + private: cell_size_type size_; }; } -TEST(domain_decomp, homogeneous) { +TEST(domain_decomposition, homogeneous_population) { const auto N = communication::global_policy::size(); const auto I = communication::global_policy::id(); @@ -90,7 +76,7 @@ TEST(domain_decomp, homogeneous) { // 10 cells per domain unsigned n_local = 10; unsigned n_global = n_local*N; - const auto D = partition_load_balance(homo_recipe(n_global), nd); + const auto D = partition_load_balance(homo_recipe(n_global, dummy_cell{}), nd); EXPECT_EQ(D.num_global_cells, n_global); EXPECT_EQ(D.num_local_cells, n_local); @@ -124,7 +110,7 @@ TEST(domain_decomp, homogeneous) { // 10 cells per domain unsigned n_local = 10; unsigned n_global = n_local*N; - const auto D = partition_load_balance(homo_recipe(n_global), nd); + const auto D = partition_load_balance(homo_recipe(n_global, dummy_cell{}), nd); EXPECT_EQ(D.num_global_cells, n_global); EXPECT_EQ(D.num_local_cells, n_local); @@ -152,7 +138,7 @@ TEST(domain_decomp, homogeneous) { } } -TEST(domain_decomp, heterogeneous) { +TEST(domain_decomposition, heterogeneous_population) { const auto N = communication::global_policy::size(); const auto I = communication::global_policy::id();