Skip to content
Snippets Groups Projects
Commit 0443c271 authored by Ben Cumming's avatar Ben Cumming Committed by Alexander Peyser
Browse files

Feature/comm tests (#201)

Add unit tests for communicator

fixes #200

    update global_communication test driver to initialize correctly in
    dry run mode
    add unit tests that test the communication::global_policy:
        basic initialization
        global spike exchange (just the spike gather step, not the
        event delivery).
    improve the formatting of the reporting from the MPI GTest wrapper
    to make it easy to see if tests have failed.
parent 5b297b29
No related branches found
No related tags found
No related merge requests found
...@@ -28,7 +28,7 @@ public: ...@@ -28,7 +28,7 @@ public:
return partition_; return partition_;
} }
/// the number of entries in the gathered vector in partiion i /// the number of entries in the gathered vector in partition i
count_type count(std::size_t i) const { count_type count(std::size_t i) const {
return partition_[i+1] - partition_[i]; return partition_[i+1] - partition_[i];
} }
......
...@@ -97,24 +97,30 @@ public: ...@@ -97,24 +97,30 @@ public:
} }
virtual void OnTestCaseEnd(const TestCase& test_case) override { virtual void OnTestCaseEnd(const TestCase& test_case) override {
printf_helper( printf_helper(
"[PASSED %3d; FAILED %3d] of %3d tests in %s\n\n", " PASSED %d of %d tests in %s\n",
test_case_tests_-test_case_failures_, test_case_tests_-test_case_failures_,
test_case_failures_,
test_case_tests_, test_case_tests_,
test_case.name() test_case.name()
); );
if (test_case_failures_>0) {
printf_helper(
" FAILED %d of %d tests in %s\n",
test_case_failures_,
test_case_tests_,
test_case.name()
);
}
printf_helper("\n");
} }
// Called before a test starts. // Called before a test starts.
virtual void OnTestStart(const TestInfo& test_info) override { virtual void OnTestStart(const TestInfo& test_info) override {
printf_helper( " TEST %s::%s\n", test_info.test_case_name(), test_info.name()); printf_helper( "TEST: %s::%s\n", test_info.test_case_name(), test_info.name());
test_failures_ = 0; test_failures_ = 0;
} }
// Called after a failed assertion or a SUCCEED() invocation. // Called after a failed assertion or a SUCCEED() invocation.
virtual void OnTestPartResult(const TestPartResult& test_part_result) override { virtual void OnTestPartResult(const TestPartResult& test_part_result) override {
const char* banner = "--------------------------------------------------------------------------------";
// indent all lines in the summary by 4 spaces // indent all lines in the summary by 4 spaces
std::string summary = " " + std::string(test_part_result.summary()); std::string summary = " " + std::string(test_part_result.summary());
auto pos = summary.find("\n"); auto pos = summary.find("\n");
...@@ -124,13 +130,11 @@ public: ...@@ -124,13 +130,11 @@ public:
} }
printf_helper( printf_helper(
" LOCAL_%s\n %s\n %s:%d\n%s\n %s\n", " LOCAL_%s\n %s:%d\n%s\n",
test_part_result.failed() ? "FAIL" : "SUCCESS", test_part_result.failed() ? "FAIL" : "SUCCESS",
banner,
test_part_result.file_name(), test_part_result.file_name(),
test_part_result.line_number(), test_part_result.line_number(),
summary.c_str(), summary.c_str()
banner
); );
// note that there was a failure in this test case // note that there was a failure in this test case
......
...@@ -7,13 +7,23 @@ ...@@ -7,13 +7,23 @@
#include "mpi_listener.hpp" #include "mpi_listener.hpp"
#include <tinyopt.hpp>
#include <communication/communicator.hpp> #include <communication/communicator.hpp>
#include <communication/global_policy.hpp> #include <communication/global_policy.hpp>
#include <util/ioutil.hpp> #include <util/ioutil.hpp>
using namespace nest::mc; using namespace nest::mc;
const char* usage_str =
"[OPTION]...\n"
"\n"
" -d, --dryrun Number of dry run ranks\n"
" -h, --help Display usage information and exit\n";
int main(int argc, char **argv) { int main(int argc, char **argv) {
using policy = communication::global_policy;
// We need to set the communicator policy at the top level // We need to set the communicator policy at the top level
// this allows us to build multiple communicators in the tests // this allows us to build multiple communicators in the tests
communication::global_policy_guard global_guard(argc, argv); communication::global_policy_guard global_guard(argc, argv);
...@@ -28,12 +38,44 @@ int main(int argc, char **argv) { ...@@ -28,12 +38,44 @@ int main(int argc, char **argv) {
// now add our custom printer // now add our custom printer
listeners.Append(new mpi_listener("results_global_communication")); listeners.Append(new mpi_listener("results_global_communication"));
// record the local return value for tests run on this mpi rank int return_value = 0;
// 0 : success try {
// 1 : failure auto arg = argv+1;
auto result = RUN_ALL_TESTS(); while (*arg) {
if (auto comm_size = to::parse_opt<unsigned>(arg, 'd', "dryrun")) {
if (*comm_size==0) {
throw to::parse_opt_error(*arg, "must be positive integer");
}
// Note that this must be set again for each test that uses a different
// number of cells per domain, e.g.
// policy::set_sizes(policy::size(), new_cells_per_rank)
policy::set_sizes(*comm_size, 0);
}
else if (auto o = to::parse_opt(arg, 'h', "help")) {
to::usage(argv[0], usage_str);
return 0;
}
else {
throw to::parse_opt_error(*arg, "unrecognized option");
}
}
// record the local return value for tests run on this mpi rank
// 0 : success
// 1 : failure
return_value = RUN_ALL_TESTS();
}
catch (to::parse_opt_error& e) {
to::usage(argv[0], usage_str, e.what());
return_value = 1;
}
catch (std::exception& e) {
std::cerr << "caught exception: " << e.what() << "\n";
return_value = 1;
}
// perform global collective, to ensure that all ranks return // perform global collective, to ensure that all ranks return
// the same exit code // the same exit code
return communication::global_policy::max(result); return policy::max(return_value);
} }
...@@ -13,22 +13,163 @@ using namespace nest::mc; ...@@ -13,22 +13,163 @@ using namespace nest::mc;
using communicator_type = communication::communicator<communication::global_policy>; using communicator_type = communication::communicator<communication::global_policy>;
TEST(communicator, setup) { bool is_dry_run() {
/* return communication::global_policy::kind() ==
communication::global_policy_kind::dryrun;
}
TEST(communicator, policy_basics) {
using policy = communication::global_policy; using policy = communication::global_policy;
auto num_domains = policy::size(); const auto num_domains = policy::size();
auto rank = policy::id(); const auto rank = policy::id();
auto counts = policy.gather_all(1); EXPECT_EQ(policy::min(rank), 0);
EXPECT_EQ(counts.size(), unsigned(num_domains)); if (!is_dry_run()) {
for(auto i : counts) { EXPECT_EQ(policy::max(rank), num_domains-1);
EXPECT_EQ(i, 1);
} }
}
// Spike gathering works with a generic spike type that
// * has a member called source that
// * the source must be of a type that has a gid member
//
// Here we defined proxy types for testing the gather_spikes functionality.
// These are a little bit simpler than the spike and source types used inside
// NestMC, to simplify the testing.
// Proxy for a spike source, which represents gid as an integer.
struct source_proxy {
source_proxy() = default;
source_proxy(int g): gid(g) {}
int gid = 0;
};
bool operator==(int other, source_proxy s) {return s.gid==other;};
bool operator==(source_proxy s, int other) {return s.gid==other;};
// Proxy for a spike.
// The value member can be used to test if the spike and its contents were
// successfully gathered.
struct spike_proxy {
spike_proxy() = default;
spike_proxy(int s, int v): source(s), value(v) {}
source_proxy source = 0;
int value = 0;
};
// Test low level spike_gather function when each domain produces the same
// number of spikes in the pattern used by dry run mode.
TEST(communicator, gather_spikes_equal) {
using policy = communication::global_policy;
const auto num_domains = policy::size();
const auto rank = policy::id();
const auto n_local_spikes = 10;
const auto n_local_cells = n_local_spikes;
// Important: set up meta-data in dry run back end.
if (is_dry_run()) {
policy::set_sizes(policy::size(), n_local_cells);
}
// Create local spikes for communication.
std::vector<spike_proxy> local_spikes;
for (auto i=0; i<n_local_spikes; ++i) {
local_spikes.push_back(spike_proxy{i+rank*n_local_spikes, rank});
}
// Perform exchange
const auto global_spikes = policy::gather_spikes(local_spikes);
auto part = util::parition_view(counts); // Test that partition information is correct
for(auto p : part) { const auto& part = global_spikes.partition();
EXPECT_EQ(p.second-p.first, 1); EXPECT_EQ(num_domains+1u, part.size());
for (auto i=0u; i<part.size(); ++i) {
EXPECT_EQ(part[i], n_local_spikes*i);
}
// Test that spikes were correctly exchanged
//
// In dry run mode the local spikes had sources numbered 0:n_local_spikes-1.
// The global exchange should replicate the local spikes and
// shift their sources to make them local to the "dummy" source
// domain.
// We set the model up with n_local_cells==n_local_spikes with
// one spike per local cell, so the result of the global exchange
// is a list of num_domains*n_local_spikes spikes that have
// contiguous source gid
const auto& spikes = global_spikes.values();
EXPECT_EQ(n_local_spikes*policy::size(), int(spikes.size()));
for (auto i=0u; i<spikes.size(); ++i) {
const auto s = spikes[i];
EXPECT_EQ(i, unsigned(s.source.gid));
if (is_dry_run()) {
EXPECT_EQ(0, s.value);
}
else {
EXPECT_EQ(int(i)/n_local_spikes, s.value);
}
}
}
// Test low level spike_gather function when the number of spikes per domain
// are not equal.
TEST(communicator, gather_spikes_variant) {
// This test does not apply if in dry run mode.
// Because dry run mode requires that each domain have the same
// number of spikes.
if (is_dry_run()) return;
using policy = communication::global_policy;
const auto num_domains = policy::size();
const auto rank = policy::id();
// Parameter used to scale the number of spikes generated on successive
// ranks.
const auto scale = 10;
// Calculates the number of spikes generated by the first n ranks.
// Can be used to calculate the index of the range of spikes
// generated by a given rank, and to determine the total number of
// spikes generated globally.
auto sumn = [scale](int n) {return scale*n*(n+1)/2;};
const auto n_local_spikes = scale*rank;
// Create local spikes for communication.
// The ranks generate different numbers of spikes, with the ranks
// generating the following number of spikes
// [ 0, scale, 2*scale, 3*scale, ..., (num_domains-1)*scale ]
// i.e. 0 spikes on the first rank, scale spikes on the second, and so on.
std::vector<spike_proxy> local_spikes;
const auto local_start_id = sumn(rank-1);
for (auto i=0; i<n_local_spikes; ++i) {
local_spikes.push_back(spike_proxy{local_start_id+i, rank});
}
// Perform exchange
const auto global_spikes = policy::gather_spikes(local_spikes);
// Test that partition information is correct
const auto& part =global_spikes.partition();
EXPECT_EQ(unsigned(num_domains+1), part.size());
EXPECT_EQ(0, (int)part[0]);
for (auto i=1u; i<part.size(); ++i) {
EXPECT_EQ(sumn(i-1), (int)part[i]);
}
// Test that spikes were correctly exchanged
for (auto domain=0; domain<num_domains; ++domain) {
auto source = sumn(domain-1);
const auto first_spike = global_spikes.values().begin() + sumn(domain-1);
const auto last_spike = global_spikes.values().begin() + sumn(domain);
const auto spikes = util::make_range(first_spike, last_spike);
for (auto s: spikes) {
EXPECT_EQ(s.value, domain);
EXPECT_EQ(s.source, source++);
}
} }
*/
} }
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment