diff --git a/src/communication/gathered_vector.hpp b/src/communication/gathered_vector.hpp index ceb5c16cde44a872e9ce08d3b4c65c58f9f73c1a..22daa5a68a4c0c288c613a8a547ff532ed463e3c 100644 --- a/src/communication/gathered_vector.hpp +++ b/src/communication/gathered_vector.hpp @@ -28,7 +28,7 @@ public: return partition_; } - /// the number of entries in the gathered vector in partiion i + /// the number of entries in the gathered vector in partition i count_type count(std::size_t i) const { return partition_[i+1] - partition_[i]; } diff --git a/tests/global_communication/mpi_listener.hpp b/tests/global_communication/mpi_listener.hpp index 18d245e4dd2cb3f3752e4fe22a342a28bfa305bd..76cdfd9bcc321519b1102bb4fa1895286518b0d8 100644 --- a/tests/global_communication/mpi_listener.hpp +++ b/tests/global_communication/mpi_listener.hpp @@ -97,24 +97,30 @@ public: } virtual void OnTestCaseEnd(const TestCase& test_case) override { printf_helper( - "[PASSED %3d; FAILED %3d] of %3d tests in %s\n\n", + " PASSED %d of %d tests in %s\n", test_case_tests_-test_case_failures_, - test_case_failures_, test_case_tests_, test_case.name() ); + if (test_case_failures_>0) { + printf_helper( + " FAILED %d of %d tests in %s\n", + test_case_failures_, + test_case_tests_, + test_case.name() + ); + } + printf_helper("\n"); } // Called before a test starts. virtual void OnTestStart(const TestInfo& test_info) override { - printf_helper( " TEST %s::%s\n", test_info.test_case_name(), test_info.name()); + printf_helper( "TEST: %s::%s\n", test_info.test_case_name(), test_info.name()); test_failures_ = 0; } // Called after a failed assertion or a SUCCEED() invocation. virtual void OnTestPartResult(const TestPartResult& test_part_result) override { - const char* banner = "--------------------------------------------------------------------------------"; - // indent all lines in the summary by 4 spaces std::string summary = " " + std::string(test_part_result.summary()); auto pos = summary.find("\n"); @@ -124,13 +130,11 @@ public: } printf_helper( - " LOCAL_%s\n %s\n %s:%d\n%s\n %s\n", + " LOCAL_%s\n %s:%d\n%s\n", test_part_result.failed() ? "FAIL" : "SUCCESS", - banner, test_part_result.file_name(), test_part_result.line_number(), - summary.c_str(), - banner + summary.c_str() ); // note that there was a failure in this test case diff --git a/tests/global_communication/test.cpp b/tests/global_communication/test.cpp index fae2502972c9f8f1025112ca15b069861d4476a8..6eb249034614d886ab3f62d2372216b3b9b2d17f 100644 --- a/tests/global_communication/test.cpp +++ b/tests/global_communication/test.cpp @@ -7,13 +7,23 @@ #include "mpi_listener.hpp" +#include <tinyopt.hpp> #include <communication/communicator.hpp> #include <communication/global_policy.hpp> #include <util/ioutil.hpp> + using namespace nest::mc; +const char* usage_str = +"[OPTION]...\n" +"\n" +" -d, --dryrun Number of dry run ranks\n" +" -h, --help Display usage information and exit\n"; + int main(int argc, char **argv) { + using policy = communication::global_policy; + // We need to set the communicator policy at the top level // this allows us to build multiple communicators in the tests communication::global_policy_guard global_guard(argc, argv); @@ -28,12 +38,44 @@ int main(int argc, char **argv) { // now add our custom printer listeners.Append(new mpi_listener("results_global_communication")); - // record the local return value for tests run on this mpi rank - // 0 : success - // 1 : failure - auto result = RUN_ALL_TESTS(); + int return_value = 0; + try { + auto arg = argv+1; + while (*arg) { + if (auto comm_size = to::parse_opt<unsigned>(arg, 'd', "dryrun")) { + if (*comm_size==0) { + throw to::parse_opt_error(*arg, "must be positive integer"); + } + // Note that this must be set again for each test that uses a different + // number of cells per domain, e.g. + // policy::set_sizes(policy::size(), new_cells_per_rank) + policy::set_sizes(*comm_size, 0); + } + else if (auto o = to::parse_opt(arg, 'h', "help")) { + to::usage(argv[0], usage_str); + return 0; + } + else { + throw to::parse_opt_error(*arg, "unrecognized option"); + } + } + + // record the local return value for tests run on this mpi rank + // 0 : success + // 1 : failure + return_value = RUN_ALL_TESTS(); + } + + catch (to::parse_opt_error& e) { + to::usage(argv[0], usage_str, e.what()); + return_value = 1; + } + catch (std::exception& e) { + std::cerr << "caught exception: " << e.what() << "\n"; + return_value = 1; + } // perform global collective, to ensure that all ranks return // the same exit code - return communication::global_policy::max(result); + return policy::max(return_value); } diff --git a/tests/global_communication/test_communicator.cpp b/tests/global_communication/test_communicator.cpp index e3d427ae36e4df5a2e0b139a05bc3d8b14121b83..b13c8e4f9a7412456289fa2c9dd02e31313dd86d 100644 --- a/tests/global_communication/test_communicator.cpp +++ b/tests/global_communication/test_communicator.cpp @@ -13,22 +13,163 @@ using namespace nest::mc; using communicator_type = communication::communicator<communication::global_policy>; -TEST(communicator, setup) { - /* +bool is_dry_run() { + return communication::global_policy::kind() == + communication::global_policy_kind::dryrun; +} + +TEST(communicator, policy_basics) { using policy = communication::global_policy; - auto num_domains = policy::size(); - auto rank = policy::id(); + const auto num_domains = policy::size(); + const auto rank = policy::id(); - auto counts = policy.gather_all(1); - EXPECT_EQ(counts.size(), unsigned(num_domains)); - for(auto i : counts) { - EXPECT_EQ(i, 1); + EXPECT_EQ(policy::min(rank), 0); + if (!is_dry_run()) { + EXPECT_EQ(policy::max(rank), num_domains-1); } +} + +// Spike gathering works with a generic spike type that +// * has a member called source that +// * the source must be of a type that has a gid member +// +// Here we defined proxy types for testing the gather_spikes functionality. +// These are a little bit simpler than the spike and source types used inside +// NestMC, to simplify the testing. + +// Proxy for a spike source, which represents gid as an integer. +struct source_proxy { + source_proxy() = default; + source_proxy(int g): gid(g) {} + + int gid = 0; +}; + +bool operator==(int other, source_proxy s) {return s.gid==other;}; +bool operator==(source_proxy s, int other) {return s.gid==other;}; + +// Proxy for a spike. +// The value member can be used to test if the spike and its contents were +// successfully gathered. +struct spike_proxy { + spike_proxy() = default; + spike_proxy(int s, int v): source(s), value(v) {} + source_proxy source = 0; + int value = 0; +}; + +// Test low level spike_gather function when each domain produces the same +// number of spikes in the pattern used by dry run mode. +TEST(communicator, gather_spikes_equal) { + using policy = communication::global_policy; + + const auto num_domains = policy::size(); + const auto rank = policy::id(); + + const auto n_local_spikes = 10; + const auto n_local_cells = n_local_spikes; + + // Important: set up meta-data in dry run back end. + if (is_dry_run()) { + policy::set_sizes(policy::size(), n_local_cells); + } + + // Create local spikes for communication. + std::vector<spike_proxy> local_spikes; + for (auto i=0; i<n_local_spikes; ++i) { + local_spikes.push_back(spike_proxy{i+rank*n_local_spikes, rank}); + } + + // Perform exchange + const auto global_spikes = policy::gather_spikes(local_spikes); - auto part = util::parition_view(counts); - for(auto p : part) { - EXPECT_EQ(p.second-p.first, 1); + // Test that partition information is correct + const auto& part = global_spikes.partition(); + EXPECT_EQ(num_domains+1u, part.size()); + for (auto i=0u; i<part.size(); ++i) { + EXPECT_EQ(part[i], n_local_spikes*i); + } + + // Test that spikes were correctly exchanged + // + // In dry run mode the local spikes had sources numbered 0:n_local_spikes-1. + // The global exchange should replicate the local spikes and + // shift their sources to make them local to the "dummy" source + // domain. + // We set the model up with n_local_cells==n_local_spikes with + // one spike per local cell, so the result of the global exchange + // is a list of num_domains*n_local_spikes spikes that have + // contiguous source gid + const auto& spikes = global_spikes.values(); + EXPECT_EQ(n_local_spikes*policy::size(), int(spikes.size())); + for (auto i=0u; i<spikes.size(); ++i) { + const auto s = spikes[i]; + EXPECT_EQ(i, unsigned(s.source.gid)); + if (is_dry_run()) { + EXPECT_EQ(0, s.value); + } + else { + EXPECT_EQ(int(i)/n_local_spikes, s.value); + } + } +} + +// Test low level spike_gather function when the number of spikes per domain +// are not equal. +TEST(communicator, gather_spikes_variant) { + // This test does not apply if in dry run mode. + // Because dry run mode requires that each domain have the same + // number of spikes. + if (is_dry_run()) return; + + using policy = communication::global_policy; + + const auto num_domains = policy::size(); + const auto rank = policy::id(); + + // Parameter used to scale the number of spikes generated on successive + // ranks. + const auto scale = 10; + // Calculates the number of spikes generated by the first n ranks. + // Can be used to calculate the index of the range of spikes + // generated by a given rank, and to determine the total number of + // spikes generated globally. + auto sumn = [scale](int n) {return scale*n*(n+1)/2;}; + const auto n_local_spikes = scale*rank; + + // Create local spikes for communication. + // The ranks generate different numbers of spikes, with the ranks + // generating the following number of spikes + // [ 0, scale, 2*scale, 3*scale, ..., (num_domains-1)*scale ] + // i.e. 0 spikes on the first rank, scale spikes on the second, and so on. + std::vector<spike_proxy> local_spikes; + const auto local_start_id = sumn(rank-1); + for (auto i=0; i<n_local_spikes; ++i) { + local_spikes.push_back(spike_proxy{local_start_id+i, rank}); + } + + // Perform exchange + const auto global_spikes = policy::gather_spikes(local_spikes); + + // Test that partition information is correct + const auto& part =global_spikes.partition(); + EXPECT_EQ(unsigned(num_domains+1), part.size()); + EXPECT_EQ(0, (int)part[0]); + for (auto i=1u; i<part.size(); ++i) { + EXPECT_EQ(sumn(i-1), (int)part[i]); + } + + // Test that spikes were correctly exchanged + for (auto domain=0; domain<num_domains; ++domain) { + auto source = sumn(domain-1); + const auto first_spike = global_spikes.values().begin() + sumn(domain-1); + const auto last_spike = global_spikes.values().begin() + sumn(domain); + const auto spikes = util::make_range(first_spike, last_spike); + for (auto s: spikes) { + EXPECT_EQ(s.value, domain); + EXPECT_EQ(s.source, source++); + } } - */ } +