Skip to content
Snippets Groups Projects
Commit 3bcb8f97 authored by Sam Yates's avatar Sam Yates Committed by Ben Cumming
Browse files

Update global comms tests for new sampling API. (#338)

* Avoid format securiy warning/error in `mpi_listener.hpp`
* Update recipe classes in `test_communicator.cpp` and `test_domain_decomposition.cpp`.
* Align test names in `test_domain_decomposition.cpp` with those in unit tests.
* Fix `spike_gids` bug in `test_all2all` routine.
* Replace test assertions with `AssertionResult` returns in `test_ring` and `test_all2all`.
* Simplify no-extra-events check in `test_ring` and `test_all2all`.
parent 8739fd55
No related branches found
No related tags found
No related merge requests found
......@@ -59,7 +59,7 @@ private:
/// TODO : it might be an idea to use a resizeable buffer
template <typename... Args>
void printf_helper(const char* s, Args&&... args) {
snprintf(buffer_, sizeof(buffer_), s, std::forward<Args>(args)...);
std::snprintf(buffer_, sizeof(buffer_), s, std::forward<Args>(args)...);
print(buffer_);
}
......@@ -110,7 +110,7 @@ public:
test_case.name()
);
}
printf_helper("\n");
print("\n");
}
// Called before a test starts.
......
#include "../gtest.h"
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <stdexcept>
#include <vector>
#include <communication/communicator.hpp>
......@@ -200,9 +197,9 @@ namespace {
return gid%2? cell_kind::cable1d_neuron: cell_kind::regular_spike_source;
}
cell_count_info get_cell_count_info(cell_gid_type) const override {
return {1, 1, 0};
}
cell_size_type num_sources(cell_gid_type) const override { return 1; }
cell_size_type num_targets(cell_gid_type) const override { return 1; }
cell_size_type num_probes(cell_gid_type) const override { return 0; }
std::vector<cell_connection> connections_on(cell_gid_type gid) const override {
// a single connection from the preceding cell, i.e. a ring
......@@ -216,12 +213,15 @@ namespace {
1.0f)}; // delay
}
probe_info get_probe(cell_member_type) const override {
throw std::logic_error("no probes");
}
private:
cell_size_type size_;
cell_size_type ranks_;
};
cell_gid_type source_of(cell_gid_type gid, cell_size_type num_cells) {
if (gid) {
return gid-1;
......@@ -264,9 +264,9 @@ namespace {
return gid%2? cell_kind::cable1d_neuron: cell_kind::regular_spike_source;
}
cell_count_info get_cell_count_info(cell_gid_type) const override {
return {1, size_, 0}; // sources, targets, probes
}
cell_size_type num_sources(cell_gid_type) const override { return 1; }
cell_size_type num_targets(cell_gid_type) const override { return size_; }
cell_size_type num_probes(cell_gid_type) const override { return 0; }
std::vector<cell_connection> connections_on(cell_gid_type gid) const override {
std::vector<cell_connection> cons;
......@@ -282,6 +282,10 @@ namespace {
return cons;
}
probe_info get_probe(cell_member_type) const override {
throw std::logic_error("no probes");
}
private:
cell_size_type size_;
cell_size_type ranks_;
......@@ -326,7 +330,6 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) {
using util::assign_from;
using util::filter;
auto gids = get_gids(D);
auto group_map = get_group_map(D);
......@@ -338,7 +341,7 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) {
// gather the global set of spikes
auto global_spikes = C.exchange(local_spikes);
if (global_spikes.size()!=policy::sum(local_spikes.size())) {
return ::testing::AssertionFailure() << " the number of gathered spikes "
return ::testing::AssertionFailure() << "the number of gathered spikes "
<< global_spikes.size() << " doesn't match the expected "
<< policy::sum(local_spikes.size());
}
......@@ -354,6 +357,7 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) {
// Iterate over each local gid, and testing whether an event is expected for
// that gid. If so, look up the event queue of the cell_group of gid, and
// search for the expected event.
int expected_count = 0;
for (auto gid: gids) {
auto src = source_of(gid, D.num_global_cells);
if (f(src)) {
......@@ -362,8 +366,9 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) {
auto& q = queues[grp];
if (std::find(q.begin(), q.end(), expected)==q.end()) {
return ::testing::AssertionFailure()
<< " expected event " << expected << " was not found";
<< "expected event " << expected << " was not found";
}
++expected_count;
}
}
......@@ -373,9 +378,12 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) {
int num_events = std::accumulate(queues.begin(), queues.end(), 0,
[](int l, decltype(queues.front())& r){return l + r.size();});
int expected_events = util::size(filter(gids, f));
if (expected_count!=num_events) {
return ::testing::AssertionFailure() <<
"the number of events " << num_events <<
" does not match expected count " << expected_count;
}
EXPECT_EQ(policy::sum(expected_events), policy::sum(num_events));
return ::testing::AssertionSuccess();
}
......@@ -399,9 +407,9 @@ TEST(communicator, ring)
EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return true;}));
// last cell in each domain fires
EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return (g+1)%n_local == 0u;}));
// oddly numbered cells fire
// even-numbered cells fire
EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return g%2==0;}));
// oddly numbered cells fire
// odd-numbered cells fire
EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return g%2==1;}));
}
......@@ -422,12 +430,12 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) {
std::reverse(local_spikes.begin(), local_spikes.end());
std::vector<cell_gid_type> spike_gids = assign_from(
filter(make_span(0, D.groups.size()), f));
filter(make_span(0, D.num_global_cells), f));
// gather the global set of spikes
auto global_spikes = C.exchange(local_spikes);
if (global_spikes.size()!=policy::sum(local_spikes.size())) {
return ::testing::AssertionFailure() << " the number of gathered spikes "
return ::testing::AssertionFailure() << "the number of gathered spikes "
<< global_spikes.size() << " doesn't match the expected "
<< policy::sum(local_spikes.size());
}
......@@ -443,6 +451,7 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) {
// Iterate over each local gid, and testing whether an event is expected for
// that gid. If so, look up the event queue of the cell_group of gid, and
// search for the expected event.
int expected_count = 0;
for (auto gid: gids) {
// get the event queue that this gid belongs to
auto& q = queues[group_map[gid]];
......@@ -453,6 +462,7 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) {
<< "expected event " << expected
<< " from " << src << " was not found";
}
++expected_count;
}
}
......@@ -462,9 +472,12 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) {
int num_events = std::accumulate(queues.begin(), queues.end(), 0,
[](int l, decltype(queues.front())& r){return l + r.size();});
int expected_events = D.num_global_cells*spike_gids.size();
if (expected_count!=num_events) {
return ::testing::AssertionFailure() <<
"the number of events " << num_events <<
" does not match expected count " << expected_count;
}
EXPECT_EQ(expected_events, policy::sum(num_events));
return ::testing::AssertionSuccess();
}
......@@ -488,8 +501,8 @@ TEST(communicator, all2all)
EXPECT_TRUE(test_all2all(D, C, [](cell_gid_type g){return true;}));
// only cell 0 fires
EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g==0u;}));
// oddly numbered cells fire
// even-numbered cells fire
EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g%2==0;}));
// oddly numbered cells fire
// odd-numbered cells fire
EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g%2==1;}));
}
......@@ -3,6 +3,7 @@
#include <cstdio>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
......@@ -11,38 +12,17 @@
#include <hardware/node_info.hpp>
#include <load_balance.hpp>
#include "../simple_recipes.hpp"
using namespace nest::mc;
using communicator_type = communication::communicator<communication::global_policy>;
namespace {
// Homogenous cell population of cable cells.
class homo_recipe: public recipe {
public:
homo_recipe(cell_size_type s): size_(s)
{}
// Dummy recipes types for testing.
cell_size_type num_cells() const override {
return size_;
}
util::unique_any get_cell_description(cell_gid_type) const override {
return {};
}
cell_kind get_cell_kind(cell_gid_type) const override {
return cell_kind::cable1d_neuron;
}
cell_count_info get_cell_count_info(cell_gid_type) const override {
return {0, 0, 0};
}
std::vector<cell_connection> connections_on(cell_gid_type) const override {
return {};
}
private:
cell_size_type size_;
};
struct dummy_cell {};
using homo_recipe = homogeneous_recipe<cell_kind::cable1d_neuron, dummy_cell>;
// Heterogenous cell population of cable and rss cells.
// Interleaved so that cells with even gid are cable cells, and even gid are
......@@ -59,25 +39,31 @@ namespace {
util::unique_any get_cell_description(cell_gid_type) const override {
return {};
}
cell_kind get_cell_kind(cell_gid_type gid) const override {
return gid%2?
cell_kind::regular_spike_source:
cell_kind::cable1d_neuron;
}
cell_count_info get_cell_count_info(cell_gid_type) const override {
return {0, 0, 0};
}
cell_size_type num_sources(cell_gid_type) const override { return 0; }
cell_size_type num_targets(cell_gid_type) const override { return 0; }
cell_size_type num_probes(cell_gid_type) const override { return 0; }
std::vector<cell_connection> connections_on(cell_gid_type) const override {
return {};
}
probe_info get_probe(cell_member_type) const override {
throw std::logic_error("no probes");
}
private:
cell_size_type size_;
};
}
TEST(domain_decomp, homogeneous) {
TEST(domain_decomposition, homogeneous_population) {
const auto N = communication::global_policy::size();
const auto I = communication::global_policy::id();
......@@ -90,7 +76,7 @@ TEST(domain_decomp, homogeneous) {
// 10 cells per domain
unsigned n_local = 10;
unsigned n_global = n_local*N;
const auto D = partition_load_balance(homo_recipe(n_global), nd);
const auto D = partition_load_balance(homo_recipe(n_global, dummy_cell{}), nd);
EXPECT_EQ(D.num_global_cells, n_global);
EXPECT_EQ(D.num_local_cells, n_local);
......@@ -124,7 +110,7 @@ TEST(domain_decomp, homogeneous) {
// 10 cells per domain
unsigned n_local = 10;
unsigned n_global = n_local*N;
const auto D = partition_load_balance(homo_recipe(n_global), nd);
const auto D = partition_load_balance(homo_recipe(n_global, dummy_cell{}), nd);
EXPECT_EQ(D.num_global_cells, n_global);
EXPECT_EQ(D.num_local_cells, n_local);
......@@ -152,7 +138,7 @@ TEST(domain_decomp, homogeneous) {
}
}
TEST(domain_decomp, heterogeneous) {
TEST(domain_decomposition, heterogeneous_population) {
const auto N = communication::global_policy::size();
const auto I = communication::global_policy::id();
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment