Commit cba9d458 authored by Ben Cumming, committed by Sam Yates

Refactor domain decomposition for arbitrary gid distribution. (#326)

Changes to `domain_decomposition`:
  * The `domain_decomposition` constructor performs two-pass load balancing:
      1. the first pass performs a global load balance,
      2. the second pass distributes cells locally between cpu and gpu cell_groups.
    The current logic is simple and naive; a follow-up pull request will replace it with a load balancer that returns a lighter-weight domain decomposition description.
  * Provides a simple `group_description` type that carries the gids, `cell_kind` and target backend information needed by `cell_group_factory`, as sketched below.
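
For reference, the new `group_description` type (from `domain_decomposition.hpp` in the diff below) is:

```cpp
/// Utility type for meta data for a local cell group: the cell kind, the
/// gids of its members, and the back end on which the group will run.
struct group_description {
    const cell_kind kind;
    const std::vector<cell_gid_type> gids;
    const backend_kind backend;

    group_description(cell_kind k, std::vector<cell_gid_type> g, backend_kind b):
        kind(k), gids(std::move(g)), backend(b)
    {}
};
```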

Changes to `communicator`:
  * The constructor takes a `domain_decomposition` and a recipe.
  * The interface for adding connections and constructing the connection table has been removed; this is now performed in the constructor.
  * Construction is more involved, because connections are partitioned by source gid, which requires multiple passes over the connection information in the recipe.
  * `make_event_queues` has been updated: spikes and connections are now partitioned by source domain, and an optimization dynamically iterates over whichever of the connection or spike list is shorter, as sketched below.
  * The `exchange` method now sorts `local_spikes` before the global gather to enable the optimized spike/connection search.
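
A simplified, standalone sketch of the shorter-list optimization, using placeholder `spike` and `connection` types (the real logic lives in `communicator::make_event_queues` in the diff below, which also advances both cursors in lockstep):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Placeholder types; the real spike and connection carry more information.
struct spike      { int source; double time; };
struct connection { int source; int group; };

// Both ranges come from the same source domain and are sorted by source gid,
// so iterate over the shorter range and binary-search the longer one.
void match(const std::vector<connection>& cons,
           const std::vector<spike>& spks,
           std::vector<std::size_t>& events_per_group) {
    if (cons.size() < spks.size()) {
        // Few connections, many spikes: locate the matching spikes per connection.
        for (const auto& c: cons) {
            auto lo = std::lower_bound(spks.begin(), spks.end(), c.source,
                [](const spike& s, int src) { return s.source < src; });
            auto hi = std::upper_bound(lo, spks.end(), c.source,
                [](int src, const spike& s) { return src < s.source; });
            events_per_group[c.group] += std::size_t(hi-lo);
        }
    }
    else {
        // Few spikes, many connections: locate the matching connections per spike.
        for (const auto& s: spks) {
            auto lo = std::lower_bound(cons.begin(), cons.end(), s.source,
                [](const connection& c, int src) { return c.source < src; });
            auto hi = std::upper_bound(lo, cons.end(), s.source,
                [](int src, const connection& c) { return src < c.source; });
            for (auto it = lo; it != hi; ++it) {
                ++events_per_group[it->group];
            }
        }
    }
}
```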

Changes to `miniapp`:
  * Automatically uses the gpu if one is available and the miniapp was compiled with gpu support, as shown in the excerpt below.
  * The banner prints useful information about the number of cores, gpus and ranks.
  * Removes the `-g` cell group size flag.
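
The relevant new lines in `miniapp.cpp` (visible in the diff further down) build a `hw::node_info` description, print it in the banner, and hand it to the domain decomposition:

```cpp
// Use a node description with the number of threads used by the threading
// back end, and one gpu if any are available.
hw::node_info nd;
nd.num_cpu_cores = threading::num_threads();
nd.num_gpus = hw::num_gpus()>0? 1: 0;
banner(nd);
// ...
auto decomp = domain_decomposition(*recipe, nd);
model m(*recipe, decomp);
```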

Changes to `cell_group`:
  * The `cell_group` interface takes a list of gid values instead of a contiguous range.
  * Internal `cell_group` logic now converts between gids and local indices: a vector maps local index to gid, and a hash table maps gid to local index in `cell_group` implementations that need the reverse lookup, as sketched below.
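
A minimal standalone sketch of that bookkeeping (the members live directly inside the cell group implementations; `gid_index_map` is a hypothetical name used here only for illustration):

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

using cell_gid_type = unsigned;

struct gid_index_map {
    // Local index -> gid: the order in which cells were handed to the group.
    std::vector<cell_gid_type> gids;
    // gid -> local index: hash table for implementations needing the reverse lookup.
    std::unordered_map<cell_gid_type, cell_gid_type> gid2lid;

    explicit gid_index_map(std::vector<cell_gid_type> g): gids(std::move(g)) {
        for (cell_gid_type i = 0; i < gids.size(); ++i) {
            gid2lid[gids[i]] = i;
        }
    }

    cell_gid_type lid(cell_gid_type gid) const {
        auto it = gid2lid.find(gid);
        assert(it != gid2lid.end());
        return it->second;
    }
};
```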

Changes to unit tests:
  * Tests for the domain decomposition.
  * Tests for the communicator, covering ring and all-to-all networks.
parent 98537ff2
Showing 436 additions and 327 deletions
...@@ -20,6 +20,16 @@ if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") ...@@ -20,6 +20,16 @@ if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
# flag initializations such as # flag initializations such as
# std::array<int,3> a={1,2,3}; # std::array<int,3> a={1,2,3};
set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-missing-braces") set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-missing-braces")
# Clang is erroneously warning that T is an 'unused type alias' in code like this:
# struct X {
# using T = decltype(expression);
# T x;
# };
set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-unused-local-typedef")
# Ignore warning if string passed to snprintf is not a string literal.
set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-format-security")
endif() endif()
if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU")
...@@ -37,14 +47,14 @@ if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") ...@@ -37,14 +47,14 @@ if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU")
endif() endif()
if(${CMAKE_CXX_COMPILER_ID} MATCHES "Intel") if(${CMAKE_CXX_COMPILER_ID} MATCHES "Intel")
# Disable warning for unused template parameter
# this is raised by a templated function in the json library.
set(CXXOPT_WALL "${CXXOPT_WALL} -wd488")
# Compiler flags for generating KNL-specific AVX512 instructions. # Compiler flags for generating KNL-specific AVX512 instructions.
set(CXXOPT_KNL "-xMIC-AVX512") set(CXXOPT_KNL "-xMIC-AVX512")
set(CXXOPT_AVX "-xAVX") set(CXXOPT_AVX "-xAVX")
set(CXXOPT_AVX2 "-xCORE-AVX2") set(CXXOPT_AVX2 "-xCORE-AVX2")
set(CXXOPT_AVX512 "-xCORE-AVX512") set(CXXOPT_AVX512 "-xCORE-AVX512")
# Disable warning for unused template parameter
# this is raised by a templated function in the json library.
set(CXXOPT_WALL "${CXXOPT_WALL} -wd488")
endif() endif()
...@@ -160,9 +160,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) { ...@@ -160,9 +160,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) {
"m","alltoall","all to all network", cmd, false); "m","alltoall","all to all network", cmd, false);
TCLAP::SwitchArg ring_arg( TCLAP::SwitchArg ring_arg(
"r","ring","ring network", cmd, false); "r","ring","ring network", cmd, false);
TCLAP::ValueArg<uint32_t> group_size_arg(
"g", "group-size", "number of cells per cell group",
false, defopts.compartments_per_segment, "integer", cmd);
TCLAP::ValueArg<double> sample_dt_arg( TCLAP::ValueArg<double> sample_dt_arg(
"", "sample-dt", "set sampling interval to <time> ms", "", "sample-dt", "set sampling interval to <time> ms",
false, defopts.bin_dt, "time", cmd); false, defopts.bin_dt, "time", cmd);
...@@ -229,7 +226,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) { ...@@ -229,7 +226,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) {
update_option(options.tfinal, fopts, "tfinal"); update_option(options.tfinal, fopts, "tfinal");
update_option(options.all_to_all, fopts, "all_to_all"); update_option(options.all_to_all, fopts, "all_to_all");
update_option(options.ring, fopts, "ring"); update_option(options.ring, fopts, "ring");
update_option(options.group_size, fopts, "group_size");
update_option(options.sample_dt, fopts, "sample_dt"); update_option(options.sample_dt, fopts, "sample_dt");
update_option(options.probe_ratio, fopts, "probe_ratio"); update_option(options.probe_ratio, fopts, "probe_ratio");
update_option(options.probe_soma_only, fopts, "probe_soma_only"); update_option(options.probe_soma_only, fopts, "probe_soma_only");
...@@ -280,7 +276,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) { ...@@ -280,7 +276,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) {
update_option(options.bin_regular, bin_regular_arg); update_option(options.bin_regular, bin_regular_arg);
update_option(options.all_to_all, all_to_all_arg); update_option(options.all_to_all, all_to_all_arg);
update_option(options.ring, ring_arg); update_option(options.ring, ring_arg);
update_option(options.group_size, group_size_arg);
update_option(options.sample_dt, sample_dt_arg); update_option(options.sample_dt, sample_dt_arg);
update_option(options.probe_ratio, probe_ratio_arg); update_option(options.probe_ratio, probe_ratio_arg);
update_option(options.probe_soma_only, probe_soma_only_arg); update_option(options.probe_soma_only, probe_soma_only_arg);
...@@ -308,10 +303,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) { ...@@ -308,10 +303,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) {
throw usage_error("can specify at most one of --ring and --all-to-all"); throw usage_error("can specify at most one of --ring and --all-to-all");
} }
if (options.group_size<1) {
throw usage_error("minimum of one cell per group");
}
save_file = ofile_arg.getValue(); save_file = ofile_arg.getValue();
} }
catch (TCLAP::ArgException& e) { catch (TCLAP::ArgException& e) {
...@@ -335,7 +326,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) { ...@@ -335,7 +326,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) {
fopts["tfinal"] = options.tfinal; fopts["tfinal"] = options.tfinal;
fopts["all_to_all"] = options.all_to_all; fopts["all_to_all"] = options.all_to_all;
fopts["ring"] = options.ring; fopts["ring"] = options.ring;
fopts["group_size"] = options.group_size;
fopts["sample_dt"] = options.sample_dt; fopts["sample_dt"] = options.sample_dt;
fopts["probe_ratio"] = options.probe_ratio; fopts["probe_ratio"] = options.probe_ratio;
fopts["probe_soma_only"] = options.probe_soma_only; fopts["probe_soma_only"] = options.probe_soma_only;
...@@ -388,7 +378,6 @@ std::ostream& operator<<(std::ostream& o, const cl_options& options) { ...@@ -388,7 +378,6 @@ std::ostream& operator<<(std::ostream& o, const cl_options& options) {
(options.bin_dt==0? "none": options.bin_regular? "regular": "following") << "\n"; (options.bin_dt==0? "none": options.bin_regular? "regular": "following") << "\n";
o << " all to all network : " << (options.all_to_all ? "yes" : "no") << "\n"; o << " all to all network : " << (options.all_to_all ? "yes" : "no") << "\n";
o << " ring network : " << (options.ring ? "yes" : "no") << "\n"; o << " ring network : " << (options.ring ? "yes" : "no") << "\n";
o << " group size : " << options.group_size << "\n";
o << " sample dt : " << options.sample_dt << "\n"; o << " sample dt : " << options.sample_dt << "\n";
o << " probe ratio : " << options.probe_ratio << "\n"; o << " probe ratio : " << options.probe_ratio << "\n";
o << " probe soma only : " << (options.probe_soma_only ? "yes" : "no") << "\n"; o << " probe soma only : " << (options.probe_soma_only ? "yes" : "no") << "\n";
......
...@@ -34,7 +34,6 @@ struct cl_options { ...@@ -34,7 +34,6 @@ struct cl_options {
// Simulation running parameters: // Simulation running parameters:
double tfinal = 100.; double tfinal = 100.;
double dt = 0.025; double dt = 0.025;
uint32_t group_size = 1;
bool bin_regular = false; // False => use 'following' instead of 'regular'. bool bin_regular = false; // False => use 'following' instead of 'regular'.
double bin_dt = 0.0025; // 0 => no binning. double bin_dt = 0.0025; // 0 => no binning.
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include <communication/global_policy.hpp> #include <communication/global_policy.hpp>
#include <cell.hpp> #include <cell.hpp>
#include <fvm_multicell.hpp> #include <fvm_multicell.hpp>
#include <hardware/gpu.hpp>
#include <hardware/node_info.hpp>
#include <io/exporter_spike_file.hpp> #include <io/exporter_spike_file.hpp>
#include <model.hpp> #include <model.hpp>
#include <profiling/profiler.hpp> #include <profiling/profiler.hpp>
...@@ -31,7 +33,7 @@ using namespace nest::mc; ...@@ -31,7 +33,7 @@ using namespace nest::mc;
using global_policy = communication::global_policy; using global_policy = communication::global_policy;
using sample_trace_type = sample_trace<time_type, double>; using sample_trace_type = sample_trace<time_type, double>;
using file_export_type = io::exporter_spike_file<global_policy>; using file_export_type = io::exporter_spike_file<global_policy>;
void banner(); void banner(hw::node_info);
std::unique_ptr<recipe> make_recipe(const io::cl_options&, const probe_distribution&); std::unique_ptr<recipe> make_recipe(const io::cl_options&, const probe_distribution&);
std::unique_ptr<sample_trace_type> make_trace(probe_record probe); std::unique_ptr<sample_trace_type> make_trace(probe_record probe);
using communicator_type = communication::communicator<communication::global_policy>; using communicator_type = communication::communicator<communication::global_policy>;
...@@ -67,7 +69,12 @@ int main(int argc, char** argv) { ...@@ -67,7 +69,12 @@ int main(int argc, char** argv) {
global_policy::set_sizes(options.dry_run_ranks, cells_per_rank); global_policy::set_sizes(options.dry_run_ranks, cells_per_rank);
} }
banner(); // Use a node description that uses the number of threads used by the
// threading back end, and 1 gpu if available.
hw::node_info nd;
nd.num_cpu_cores = threading::num_threads();
nd.num_gpus = hw::num_gpus()>0? 1: 0;
banner(nd);
meters.checkpoint("setup"); meters.checkpoint("setup");
...@@ -85,11 +92,7 @@ int main(int argc, char** argv) { ...@@ -85,11 +92,7 @@ int main(int argc, char** argv) {
options.file_extension, options.over_write); options.file_extension, options.over_write);
}; };
group_rules rules; auto decomp = domain_decomposition(*recipe, nd);
rules.policy = config::has_cuda?
backend_policy::prefer_gpu: backend_policy::use_multicore;
rules.target_group_size = options.group_size;
auto decomp = domain_decomposition(*recipe, rules);
model m(*recipe, decomp); model m(*recipe, decomp);
...@@ -144,8 +147,7 @@ int main(int argc, char** argv) { ...@@ -144,8 +147,7 @@ int main(int argc, char** argv) {
meters.checkpoint("model-simulate"); meters.checkpoint("model-simulate");
// output profile and diagnostic feedback // output profile and diagnostic feedback
auto const num_steps = options.tfinal / options.dt; util::profiler_output(0.001, options.profile_only_zero);
util::profiler_output(0.001, m.num_cells()*num_steps, options.profile_only_zero);
std::cout << "there were " << m.num_spikes() << " spikes\n"; std::cout << "there were " << m.num_spikes() << " spikes\n";
// save traces // save traces
...@@ -176,13 +178,15 @@ int main(int argc, char** argv) { ...@@ -176,13 +178,15 @@ int main(int argc, char** argv) {
return 0; return 0;
} }
void banner() { void banner(hw::node_info nd) {
std::cout << "====================\n"; std::cout << "==========================================\n";
std::cout << " NestMC miniapp\n"; std::cout << " NestMC miniapp\n";
std::cout << " - " << threading::description() << " threading support (" << threading::num_threads() << ")\n"; std::cout << " - distributed : " << global_policy::size()
std::cout << " - communication policy: " << std::to_string(global_policy::kind()) << " (" << global_policy::size() << ")\n"; << " (" << std::to_string(global_policy::kind()) << ")\n";
std::cout << " - gpu support: " << (config::has_cuda? "on": "off") << "\n"; std::cout << " - threads : " << nd.num_cpu_cores
std::cout << "====================\n"; << " (" << threading::description() << ")\n";
std::cout << " - gpus : " << nd.num_gpus << "\n";
std::cout << "==========================================\n";
} }
std::unique_ptr<recipe> make_recipe(const io::cl_options& options, const probe_distribution& pdist) { std::unique_ptr<recipe> make_recipe(const io::cl_options& options, const probe_distribution& pdist) {
......
...@@ -239,7 +239,14 @@ public: ...@@ -239,7 +239,14 @@ public:
basic_rgraph_recipe(cell_gid_type ncell, basic_rgraph_recipe(cell_gid_type ncell,
basic_recipe_param param, basic_recipe_param param,
probe_distribution pdist = probe_distribution{}): probe_distribution pdist = probe_distribution{}):
basic_cell_recipe(ncell, std::move(param), std::move(pdist)) {} basic_cell_recipe(ncell, std::move(param), std::move(pdist))
{
// Cells are not allowed to connect to themselves; hence there must be least two cells
// to build a connected network.
if (ncell<2) {
throw std::runtime_error("A randomly connected network must have at least 2 cells.");
}
}
std::vector<cell_connection> connections_on(cell_gid_type i) const override { std::vector<cell_connection> connections_on(cell_gid_type i) const override {
std::vector<cell_connection> conns; std::vector<cell_connection> conns;
......
...@@ -14,6 +14,7 @@ set(BASE_SOURCES ...@@ -14,6 +14,7 @@ set(BASE_SOURCES
hardware/affinity.cpp hardware/affinity.cpp
hardware/gpu.cpp hardware/gpu.cpp
hardware/memory.cpp hardware/memory.cpp
hardware/node_info.cpp
hardware/power.cpp hardware/power.cpp
threading/threading.cpp threading/threading.cpp
util/debug.cpp util/debug.cpp
......
...@@ -6,16 +6,19 @@ ...@@ -6,16 +6,19 @@
namespace nest { namespace nest {
namespace mc { namespace mc {
enum class backend_policy { enum class backend_kind {
use_multicore, // use multicore backend for all computation multicore, // use multicore backend for all computation
prefer_gpu // use gpu back end when supported by cell_group type gpu // use gpu back end when supported by cell_group type
}; };
inline std::string to_string(backend_policy p) { inline std::string to_string(backend_kind p) {
if (p==backend_policy::use_multicore) { switch (p) {
return "use_multicore"; case backend_kind::multicore:
return "multicore";
case backend_kind::gpu:
return "gpu";
} }
return "prefer_gpu"; return "unknown";
} }
} // namespace mc } // namespace mc
......
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
#include <backends.hpp> #include <backends.hpp>
#include <cell_group.hpp> #include <cell_group.hpp>
#include <domain_decomposition.hpp>
#include <dss_cell_group.hpp> #include <dss_cell_group.hpp>
#include <fvm_multicell.hpp> #include <fvm_multicell.hpp>
#include <mc_cell_group.hpp> #include <mc_cell_group.hpp>
#include <recipe.hpp>
#include <rss_cell_group.hpp> #include <rss_cell_group.hpp>
#include <util/unique_any.hpp> #include <util/unique_any.hpp>
...@@ -14,26 +16,29 @@ namespace mc { ...@@ -14,26 +16,29 @@ namespace mc {
using gpu_fvm_cell = mc_cell_group<fvm::fvm_multicell<gpu::backend>>; using gpu_fvm_cell = mc_cell_group<fvm::fvm_multicell<gpu::backend>>;
using mc_fvm_cell = mc_cell_group<fvm::fvm_multicell<multicore::backend>>; using mc_fvm_cell = mc_cell_group<fvm::fvm_multicell<multicore::backend>>;
cell_group_ptr cell_group_factory( cell_group_ptr cell_group_factory(const recipe& rec, const group_description& group) {
cell_kind kind, // Make a list of all the cell descriptions to be forwarded
cell_gid_type first_gid, // to the appropriate cell_group constructor.
const std::vector<util::unique_any>& cell_descriptions, std::vector<util::unique_any> descriptions;
backend_policy backend) descriptions.reserve(group.gids.size());
{ for (auto gid: group.gids) {
switch (kind) { descriptions.push_back(rec.get_cell_description(gid));
}
switch (group.kind) {
case cell_kind::cable1d_neuron: case cell_kind::cable1d_neuron:
if (backend == backend_policy::prefer_gpu) { if (group.backend == backend_kind::gpu) {
return make_cell_group<gpu_fvm_cell>(first_gid, cell_descriptions); return make_cell_group<gpu_fvm_cell>(group.gids, descriptions);
} }
else { else {
return make_cell_group<mc_fvm_cell>(first_gid, cell_descriptions); return make_cell_group<mc_fvm_cell>(group.gids, descriptions);
} }
case cell_kind::regular_spike_source: case cell_kind::regular_spike_source:
return make_cell_group<rss_cell_group>(first_gid, cell_descriptions); return make_cell_group<rss_cell_group>(group.gids, descriptions);
case cell_kind::data_spike_source: case cell_kind::data_spike_source:
return make_cell_group<dss_cell_group>(first_gid, cell_descriptions); return make_cell_group<dss_cell_group>(group.gids, descriptions);
default: default:
throw std::runtime_error("unknown cell kind"); throw std::runtime_error("unknown cell kind");
......
...@@ -4,17 +4,15 @@ ...@@ -4,17 +4,15 @@
#include <backends.hpp> #include <backends.hpp>
#include <cell_group.hpp> #include <cell_group.hpp>
#include <domain_decomposition.hpp>
#include <recipe.hpp>
#include <util/unique_any.hpp> #include <util/unique_any.hpp>
namespace nest { namespace nest {
namespace mc { namespace mc {
// Helper factory for building cell groups // Helper factory for building cell groups
cell_group_ptr cell_group_factory( cell_group_ptr cell_group_factory(const recipe& rec, const group_description& group);
cell_kind kind,
cell_gid_type first_gid,
const std::vector<util::unique_any>& cells,
backend_policy backend);
} // namespace mc } // namespace mc
} // namespace nest } // namespace nest
...@@ -10,11 +10,14 @@ ...@@ -10,11 +10,14 @@
#include <common_types.hpp> #include <common_types.hpp>
#include <connection.hpp> #include <connection.hpp>
#include <communication/gathered_vector.hpp> #include <communication/gathered_vector.hpp>
#include <domain_decomposition.hpp>
#include <event_queue.hpp> #include <event_queue.hpp>
#include <recipe.hpp>
#include <spike.hpp> #include <spike.hpp>
#include <util/debug.hpp> #include <util/debug.hpp>
#include <util/double_buffer.hpp> #include <util/double_buffer.hpp>
#include <util/partition.hpp> #include <util/partition.hpp>
#include <util/rangeutil.hpp>
namespace nest { namespace nest {
namespace mc { namespace mc {
...@@ -45,33 +48,73 @@ public: ...@@ -45,33 +48,73 @@ public:
communicator() {} communicator() {}
explicit communicator(gid_partition_type cell_gid_partition): explicit communicator(const recipe& rec, const domain_decomposition& dom_dec) {
cell_gid_partition_(cell_gid_partition) using util::make_span;
{} num_domains_ = comms_.size();
num_local_groups_ = dom_dec.num_local_groups();
// For caching information about each cell
struct gid_info {
using connection_list = decltype(rec.connections_on(0));
cell_gid_type gid;
cell_gid_type local_group;
connection_list conns;
gid_info(cell_gid_type g, cell_gid_type lg, connection_list c):
gid(g), local_group(lg), conns(std::move(c)) {}
};
cell_local_size_type num_groups_local() const // Make a list of local gid with their group index and connections
{ // -> gid_infos
return cell_gid_partition_.size(); // Count the number of local connections (i.e. connections terminating on this domain)
// -> n_cons: scalar
// Calculate and store domain id of the presynaptic cell on each local connection
// -> src_domains: array with one entry for every local connection
// Also the count of presynaptic sources from each domain
// -> src_counts: array with one entry for each domain
std::vector<gid_info> gid_infos;
gid_infos.reserve(dom_dec.num_local_cells());
cell_local_size_type n_cons = 0;
std::vector<unsigned> src_domains;
std::vector<cell_size_type> src_counts(num_domains_);
for (auto i: make_span(0, num_local_groups_)) {
const auto& group = dom_dec.get_group(i);
for (auto gid: group.gids) {
gid_info info(gid, i, rec.connections_on(gid));
n_cons += info.conns.size();
for (auto con: info.conns) {
const auto src = dom_dec.gid_domain(con.source.gid);
src_domains.push_back(src);
src_counts[src]++;
} }
gid_infos.push_back(std::move(info));
void add_connection(connection con) {
EXPECTS(is_local_cell(con.destination().gid));
connections_.push_back(con);
} }
/// returns true if the cell with gid is on the domain of the caller
bool is_local_cell(cell_gid_type gid) const {
return algorithms::in_interval(gid, cell_gid_partition_.bounds());
} }
/// builds the optimized data structure // Construct the connections.
/// must be called after all connections have been added // The loop above gave the information required to construct in place
void construct() { // the connections as partitioned by the domain of their source gid.
if (!std::is_sorted(connections_.begin(), connections_.end())) { connections_.resize(n_cons);
threading::sort(connections_); connection_part_ = algorithms::make_index(src_counts);
auto offsets = connection_part_;
std::size_t pos = 0;
for (const auto& cell: gid_infos) {
for (auto c: cell.conns) {
const auto i = offsets[src_domains[pos]]++;
connections_[i] = {c.source, c.dest, c.weight, c.delay, cell.local_group};
++pos;
} }
} }
// Sort the connections for each domain.
// This is num_domains_ independent sorts, so it can be parallelized trivially.
const auto& cp = connection_part_;
threading::parallel_for::apply(0, num_domains_,
[&](cell_gid_type i) {
util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
});
}
/// the minimum delay of all connections in the global network. /// the minimum delay of all connections in the global network.
time_type min_delay() { time_type min_delay() {
auto local_min = std::numeric_limits<time_type>::max(); auto local_min = std::numeric_limits<time_type>::max();
...@@ -79,16 +122,19 @@ public: ...@@ -79,16 +122,19 @@ public:
local_min = std::min(local_min, con.delay()); local_min = std::min(local_min, con.delay());
} }
return communication_policy_.min(local_min); return comms_.min(local_min);
} }
/// Perform exchange of spikes. /// Perform exchange of spikes.
/// ///
/// Takes as input the list of local_spikes that were generated on the calling domain. /// Takes as input the list of local_spikes that were generated on the calling domain.
/// Returns the full global set of vectors, along with meta data about their partition /// Returns the full global set of vectors, along with meta data about their partition
gathered_vector<spike> exchange(const std::vector<spike>& local_spikes) { gathered_vector<spike> exchange(std::vector<spike> local_spikes) {
// sort the spikes in ascending order of source gid
util::sort_by(local_spikes, [](spike s){return s.source;});
// global all-to-all to gather a local copy of the global spike list on each node. // global all-to-all to gather a local copy of the global spike list on each node.
auto global_spikes = communication_policy_.gather_spikes( local_spikes ); auto global_spikes = comms_.gather_spikes(local_spikes);
num_spikes_ += global_spikes.size(); num_spikes_ += global_spikes.size();
return global_spikes; return global_spikes;
} }
...@@ -101,15 +147,51 @@ public: ...@@ -101,15 +147,51 @@ public:
/// events in each queue are all events that must be delivered to targets in that cell /// events in each queue are all events that must be delivered to targets in that cell
/// group as a result of the global spike exchange. /// group as a result of the global spike exchange.
std::vector<event_queue> make_event_queues(const gathered_vector<spike>& global_spikes) { std::vector<event_queue> make_event_queues(const gathered_vector<spike>& global_spikes) {
auto queues = std::vector<event_queue>(num_groups_local()); using util::subrange_view;
for (auto spike : global_spikes.values()) { using util::make_span;
// search for targets using util::make_range;
auto targets = std::equal_range(connections_.begin(), connections_.end(), spike.source);
auto queues = std::vector<event_queue>(num_local_groups_);
const auto& sp = global_spikes.partition();
const auto& cp = connection_part_;
for (auto dom: make_span(0, num_domains_)) {
auto cons = subrange_view(connections_, cp[dom], cp[dom+1]);
auto spks = subrange_view(global_spikes.values(), sp[dom], sp[dom+1]);
struct spike_pred {
bool operator()(const spike& spk, const cell_member_type& src)
{return spk.source<src;}
bool operator()(const cell_member_type& src, const spike& spk)
{return src<spk.source;}
};
if (cons.size()<spks.size()) {
auto sp = spks.begin();
auto cn = cons.begin();
while (cn!=cons.end() && sp!=spks.end()) {
auto sources = std::equal_range(sp, spks.end(), cn->source(), spike_pred());
for (auto s: make_range(sources.first, sources.second)) {
queues[cn->group_index()].push_back(cn->make_event(s));
}
sp = sources.first;
++cn;
}
}
else {
auto cn = cons.begin();
auto sp = spks.begin();
while (cn!=cons.end() && sp!=spks.end()) {
auto targets = std::equal_range(cn, cons.end(), sp->source);
for (auto c: make_range(targets.first, targets.second)) {
queues[c.group_index()].push_back(c.make_event(*sp));
}
// generate an event for each target cn = targets.first;
for (auto it=targets.first; it!=targets.second; ++it) { ++sp;
auto gidx = cell_group_index(it->destination().gid); }
queues[gidx].push_back(it->make_event(spike));
} }
} }
...@@ -117,33 +199,23 @@ public: ...@@ -117,33 +199,23 @@ public:
} }
/// Returns the total number of global spikes over the duration of the simulation /// Returns the total number of global spikes over the duration of the simulation
uint64_t num_spikes() const { return num_spikes_; } std::uint64_t num_spikes() const { return num_spikes_; }
const std::vector<connection>& connections() const { const std::vector<connection>& connections() const {
return connections_; return connections_;
} }
communication_policy_type communication_policy() const {
return communication_policy_;
}
void reset() { void reset() {
num_spikes_ = 0; num_spikes_ = 0;
} }
private: private:
std::size_t cell_group_index(cell_gid_type cell_gid) const { cell_size_type num_local_groups_;
EXPECTS(is_local_cell(cell_gid)); cell_size_type num_domains_;
return cell_gid_partition_.index(cell_gid);
}
std::vector<connection> connections_; std::vector<connection> connections_;
std::vector<cell_size_type> connection_part_;
communication_policy_type communication_policy_; communication_policy_type comms_;
std::uint64_t num_spikes_ = 0u;
uint64_t num_spikes_ = 0u;
gid_partition_type cell_gid_partition_;
}; };
} // namespace communication } // namespace communication
......
...@@ -12,11 +12,16 @@ namespace mc { ...@@ -12,11 +12,16 @@ namespace mc {
class connection { class connection {
public: public:
connection() = default; connection() = default;
connection(cell_member_type src, cell_member_type dest, float w, time_type d) : connection( cell_member_type src,
cell_member_type dest,
float w,
time_type d,
cell_gid_type gidx=cell_gid_type(-1)):
source_(src), source_(src),
destination_(dest), destination_(dest),
weight_(w), weight_(w),
delay_(d) delay_(d),
group_index_(gidx)
{} {}
float weight() const { return weight_; } float weight() const { return weight_; }
...@@ -24,6 +29,7 @@ public: ...@@ -24,6 +29,7 @@ public:
cell_member_type source() const { return source_; } cell_member_type source() const { return source_; }
cell_member_type destination() const { return destination_; } cell_member_type destination() const { return destination_; }
cell_gid_type group_index() const { return group_index_; }
postsynaptic_spike_event make_event(const spike& s) { postsynaptic_spike_event make_event(const spike& s) {
return {destination_, s.time + delay_, weight_}; return {destination_, s.time + delay_, weight_};
...@@ -34,6 +40,7 @@ private: ...@@ -34,6 +40,7 @@ private:
cell_member_type destination_; cell_member_type destination_;
float weight_; float weight_;
time_type delay_; time_type delay_;
cell_gid_type group_index_;
}; };
// connections are sorted by source id // connections are sorted by source id
......
#pragma once #pragma once
#include <type_traits>
#include <vector> #include <vector>
#include <unordered_map>
#include <backends.hpp> #include <backends.hpp>
#include <common_types.hpp> #include <common_types.hpp>
#include <communication/global_policy.hpp> #include <communication/global_policy.hpp>
#include <hardware/node_info.hpp>
#include <recipe.hpp> #include <recipe.hpp>
#include <util/optional.hpp> #include <util/optional.hpp>
#include <util/partition.hpp> #include <util/partition.hpp>
...@@ -13,97 +16,91 @@ ...@@ -13,97 +16,91 @@
namespace nest { namespace nest {
namespace mc { namespace mc {
// Meta data used to guide the domain_decomposition in distributing inline bool has_gpu_backend(cell_kind k) {
// and grouping cells. if (k==cell_kind::cable1d_neuron) {
struct group_rules { return true;
cell_size_type target_group_size; }
backend_policy policy; return false;
}
/// Utility type for meta data for a local cell group.
struct group_description {
const cell_kind kind;
const std::vector<cell_gid_type> gids;
const backend_kind backend;
group_description(cell_kind k, std::vector<cell_gid_type> g, backend_kind b):
kind(k), gids(std::move(g)), backend(b)
{}
}; };
class domain_decomposition { class domain_decomposition {
using gid_partition_type =
util::partition_range<std::vector<cell_gid_type>::const_iterator>;
public: public:
/// Utility type for meta data for a local cell group. domain_decomposition(const recipe& rec, hw::node_info nd):
struct group_range_type { node_(nd)
cell_gid_type begin; {
cell_gid_type end; using kind_type = std::underlying_type<cell_kind>::type;
cell_kind kind; using util::make_span;
num_domains_ = communication::global_policy::size();
domain_id_ = communication::global_policy::id();
num_global_cells_ = rec.num_cells();
auto dom_size = [this](unsigned dom) -> cell_gid_type {
const cell_gid_type B = num_global_cells_/num_domains_;
const cell_gid_type R = num_global_cells_ - num_domains_*B;
return B + (dom<R);
}; };
domain_decomposition(const recipe& rec, const group_rules& rules): // TODO: load balancing logic will be refactored into its own class,
backend_policy_(rules.policy) // and the domain decomposition will become a much simpler representation
{ // of the result distribution of cells over domains.
EXPECTS(rules.target_group_size>0);
auto num_domains = communication::global_policy::size(); // Global load balance
auto domain_id = communication::global_policy::id();
// Partition the cells globally across the domains. gid_part_ = make_partition(
num_global_cells_ = rec.num_cells(); gid_divisions_, transform_view(make_span(0, num_domains_), dom_size));
cell_begin_ = (cell_gid_type)(num_global_cells_*(domain_id/(double)num_domains));
cell_end_ = (cell_gid_type)(num_global_cells_*((domain_id+1)/(double)num_domains)); // Local load balance
// Partition the local cells into cell groups that satisfy three std::unordered_map<kind_type, std::vector<cell_gid_type>> kind_lists;
// criteria: for (auto gid: make_span(gid_part_[domain_id_])) {
// 1. the cells in a group have contiguous gid kind_lists[rec.get_cell_kind(gid)].push_back(gid);
// 2. the size of a cell group does not exceed rules.target_group_size;
// 3. all cells in a cell group have the same cell_kind.
// This simple greedy algorithm appends contiguous cells to a cell
// group until either the target group size is reached, or a cell with a
// different kind is encountered.
// On completion, cell_starts_ partitions the local gid into cell
// groups, and group_kinds_ records the cell kind in each cell group.
if (num_local_cells()>0) {
cell_size_type group_size = 1;
// 1st group starts at cell_begin_
group_starts_.push_back(cell_begin_);
auto group_kind = rec.get_cell_kind(cell_begin_);
// set kind for 1st group
group_kinds_.push_back(group_kind);
cell_gid_type gid = cell_begin_+1;
while (gid<cell_end_) {
auto kind = rec.get_cell_kind(gid);
// Test if gid belongs to a new cell group, i.e. whether it has
// a new cell_kind or if the target group size has been reached.
if (kind!=group_kind || group_size>=rules.target_group_size) {
group_starts_.push_back(gid);
group_kinds_.push_back(kind);
group_size = 0;
} }
++group_size;
++gid; // Create a flat vector of the cell kinds present on this node,
// partitioned such that kinds for which GPU implementation are
// listed before the others. This is a very primitive attempt at
// scheduling; the cell_groups that run on the GPU will be executed
// before other cell_groups, which is likely to be more efficient.
//
// TODO: This creates an dependency between the load balancer and
// the threading internals. We need support for setting the priority
// of cell group updates according to rules such as the back end on
// which the cell group is running.
std::vector<cell_kind> kinds;
for (auto l: kind_lists) {
kinds.push_back(cell_kind(l.first));
} }
group_starts_.push_back(cell_end_); std::partition(kinds.begin(), kinds.end(), has_gpu_backend);
for (auto k: kinds) {
// put all cells into a single cell group on the gpu if possible
if (node_.num_gpus && has_gpu_backend(k)) {
groups_.push_back({k, std::move(kind_lists[k]), backend_kind::gpu});
} }
// otherwise place into cell groups of size 1 on the cpu cores
else {
for (auto gid: kind_lists[k]) {
groups_.push_back({k, {gid}, backend_kind::multicore});
} }
/// Returns the local index of the cell_group that contains a cell with
/// with gid.
/// If the cell is not on the local domain, the optional return value is
/// not set.
util::optional<cell_size_type>
local_group_from_gid(cell_gid_type gid) const {
// check if gid is a local cell
if (!is_local_gid(gid)) {
return util::nothing;
} }
return gid_group_partition().index(gid);
} }
/// Returns the gid of the first cell on the local domain.
cell_gid_type cell_begin() const {
return cell_begin_;
} }
/// Returns one past the gid of the last cell in the local domain. int gid_domain(cell_gid_type gid) const {
cell_gid_type cell_end() const { EXPECTS(gid<num_global_cells_);
return cell_end_; return gid_part_.index(gid);
} }
/// Returns the total number of cells in the global model. /// Returns the total number of cells in the global model.
...@@ -113,42 +110,35 @@ public: ...@@ -113,42 +110,35 @@ public:
/// Returns the number of cells on the local domain. /// Returns the number of cells on the local domain.
cell_size_type num_local_cells() const { cell_size_type num_local_cells() const {
return cell_end()-cell_begin(); auto rng = gid_part_[domain_id_];
return rng.second - rng.first;
} }
/// Returns the number of cell groups on the local domain. /// Returns the number of cell groups on the local domain.
cell_size_type num_local_groups() const { cell_size_type num_local_groups() const {
return group_kinds_.size(); return groups_.size();
} }
/// Returns meta data for a local cell group. /// Returns meta data for a local cell group.
group_range_type get_group(cell_size_type i) const { const group_description& get_group(cell_size_type i) const {
return {group_starts_[i], group_starts_[i+1], group_kinds_[i]}; EXPECTS(i<num_local_groups());
return groups_[i];
} }
/// Tests whether a gid is on the local domain. /// Tests whether a gid is on the local domain.
bool is_local_gid(cell_gid_type i) const { bool is_local_gid(cell_gid_type gid) const {
return i>=cell_begin_ && i<cell_end_; return algorithms::in_interval(gid, gid_part_[domain_id_]);
}
/// Return a partition of the cell gid over local cell groups.
gid_partition_type gid_group_partition() const {
return util::partition_view(group_starts_);
}
/// Returns the backend policy.
backend_policy backend() const {
return backend_policy_;
} }
private: private:
int num_domains_;
backend_policy backend_policy_; int domain_id_;
cell_gid_type cell_begin_; hw::node_info node_;
cell_gid_type cell_end_;
cell_size_type num_global_cells_; cell_size_type num_global_cells_;
std::vector<cell_size_type> group_starts_; std::vector<cell_gid_type> gid_divisions_;
decltype(util::make_partition(gid_divisions_, gid_divisions_)) gid_part_;
std::vector<cell_kind> group_kinds_; std::vector<cell_kind> group_kinds_;
std::vector<group_description> groups_;
}; };
} // namespace mc } // namespace mc
......
...@@ -11,25 +11,19 @@ namespace mc { ...@@ -11,25 +11,19 @@ namespace mc {
/// Cell_group to collect spike sources /// Cell_group to collect spike sources
class dss_cell_group: public cell_group { class dss_cell_group: public cell_group {
public: public:
using source_id_type = cell_member_type; dss_cell_group(std::vector<cell_gid_type> gids,
const std::vector<util::unique_any>& cell_descriptions):
dss_cell_group(cell_gid_type first_gid, const std::vector<util::unique_any>& cell_descriptions): gids_(std::move(gids))
gid_base_(first_gid)
{ {
using util::make_span; using util::make_span;
for (cell_gid_type i: make_span(0, cell_descriptions.size())) { for (cell_gid_type i: make_span(0, cell_descriptions.size())) {
// store spike times from description // store spike times from description
const auto times = util::any_cast<dss_cell_description>(cell_descriptions[i]).spike_times; auto times = util::any_cast<dss_cell_description>(cell_descriptions[i]).spike_times;
spike_times_.push_back(std::vector<time_type>(times)); util::sort(times);
spike_times_.push_back(std::move(times));
// Assure the spike times are sorted
std::sort(spike_times_[i].begin(), spike_times_[i].end());
// Now we can grab the first spike time // Take a reference to the first spike time
not_emit_it_.push_back(spike_times_[i].begin()); not_emit_it_.push_back(spike_times_[i].begin());
// create a lid to gid map
spike_sources_.push_back({gid_base_+i, 0});
} }
} }
...@@ -40,12 +34,11 @@ public: ...@@ -40,12 +34,11 @@ public:
} }
void reset() override { void reset() override {
// Declare both iterators outside of the for loop for consistency // Reset the pointers to the next undelivered spike to the start
// of the input range.
auto it = not_emit_it_.begin(); auto it = not_emit_it_.begin();
auto times = spike_times_.begin(); auto times = spike_times_.begin();
for (;it != not_emit_it_.end(); ++it, times++) { for (;it != not_emit_it_.end(); ++it, times++) {
// Point to the first not emitted spike.
*it = times->begin(); *it = times->begin();
} }
...@@ -56,25 +49,25 @@ public: ...@@ -56,25 +49,25 @@ public:
{} {}
void advance(time_type tfinal, time_type dt) override { void advance(time_type tfinal, time_type dt) override {
for (auto cell_idx: util::make_span(0, not_emit_it_.size())) { for (auto i: util::make_span(0, not_emit_it_.size())) {
// The first potential spike_time to emit for this cell // The first potential spike_time to emit for this cell
auto spike_time_it = not_emit_it_[cell_idx]; auto spike_time_it = not_emit_it_[i];
// Find the first to not emit and store as iterator // Find the first spike past tfinal
not_emit_it_[cell_idx] = std::find_if( not_emit_it_[i] = std::find_if(
spike_time_it, spike_times_[cell_idx].end(), spike_time_it, spike_times_[i].end(),
[tfinal](time_type t) {return t >= tfinal; } [tfinal](time_type t) {return t >= tfinal; }
); );
// loop over the range we now have (might be empty), and create spikes // Loop over the range and create spikes
for (; spike_time_it != not_emit_it_[cell_idx]; ++spike_time_it) { for (; spike_time_it != not_emit_it_[i]; ++spike_time_it) {
spikes_.push_back({ spike_sources_[cell_idx], *spike_time_it }); spikes_.push_back({ {gids_[i], 0u}, *spike_time_it });
} }
} }
}; };
void enqueue_events(const std::vector<postsynaptic_spike_event>& events) override { void enqueue_events(const std::vector<postsynaptic_spike_event>& events) override {
std::logic_error("The dss_cells do not support incoming events!"); std::runtime_error("The dss_cells do not support incoming events!");
} }
const std::vector<spike>& spikes() const override { const std::vector<spike>& spikes() const override {
...@@ -94,14 +87,11 @@ public: ...@@ -94,14 +87,11 @@ public:
} }
private: private:
// gid of first cell in group.
cell_gid_type gid_base_;
// Spikes that are generated. // Spikes that are generated.
std::vector<spike> spikes_; std::vector<spike> spikes_;
// Spike generators attached to the cell // Map of local index to gid
std::vector<source_id_type> spike_sources_; std::vector<cell_gid_type> gids_;
std::vector<probe_record> probes_; std::vector<probe_record> probes_;
......
...@@ -27,6 +27,15 @@ struct postsynaptic_spike_event { ...@@ -27,6 +27,15 @@ struct postsynaptic_spike_event {
cell_member_type target; cell_member_type target;
time_type time; time_type time;
float weight; float weight;
friend bool operator==(postsynaptic_spike_event l, postsynaptic_spike_event r) {
return l.target==r.target && l.time==r.time && l.weight==r.weight;
}
friend std::ostream& operator<<(std::ostream& o, const nest::mc::postsynaptic_spike_event& e)
{
return o << "E[tgt " << e.target << ", t " << e.time << ", w " << e.weight << "]";
}
}; };
struct sample_event { struct sample_event {
...@@ -140,9 +149,3 @@ private: ...@@ -140,9 +149,3 @@ private:
} // namespace nest } // namespace nest
} // namespace mc } // namespace mc
inline std::ostream& operator<<(
std::ostream& o, const nest::mc::postsynaptic_spike_event& e)
{
return o << "event[" << e.target << "," << e.time << "," << e.weight << "]";
}
#include <algorithm>
#include "affinity.hpp"
#include "gpu.hpp"
#include "node_info.hpp"
namespace nest {
namespace mc {
namespace hw {
// Return a node_info that describes the hardware resources available on this node.
// If unable to determine the number of available cores, assumes that there is one
// core available.
node_info get_node_info() {
auto res = num_cores();
unsigned ncpu = res? *res: 1u;
return {ncpu, num_gpus()};
}
} // namespace util
} // namespace mc
} // namespace nest
#pragma once
namespace nest {
namespace mc {
namespace hw {
// Information about the computational resources available on a compute node.
// Currently a simple enumeration of the number of cpu cores and gpus, which
// will become richer.
struct node_info {
node_info() = default;
node_info(unsigned c, unsigned g):
num_cpu_cores(c), num_gpus(g)
{}
unsigned num_cpu_cores = 1;
unsigned num_gpus = 0;
};
node_info get_node_info();
} // namespace util
} // namespace mc
} // namespace nest
...@@ -9,7 +9,7 @@ namespace hw { ...@@ -9,7 +9,7 @@ namespace hw {
#ifdef NMC_HAVE_CRAY #ifdef NMC_HAVE_CRAY
energy_size_type energy() { energy_size_type energy() {
energy_size_type result = -1; energy_size_type result = energy_size_type(-1);
std::ifstream fid("/sys/cray/pm_counters/energy"); std::ifstream fid("/sys/cray/pm_counters/energy");
if (fid) { if (fid) {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <iterator> #include <iterator>
#include <unordered_map>
#include <vector> #include <vector>
#include <algorithms.hpp> #include <algorithms.hpp>
...@@ -35,9 +36,14 @@ public: ...@@ -35,9 +36,14 @@ public:
mc_cell_group() = default; mc_cell_group() = default;
template <typename Cells> template <typename Cells>
mc_cell_group(cell_gid_type first_gid, const Cells& cell_descriptions): mc_cell_group(std::vector<cell_gid_type> gids, const Cells& cell_descriptions):
gid_base_{first_gid} gids_(std::move(gids))
{ {
// Build lookup table for gid to local index.
for (auto i: util::make_span(0, gids_.size())) {
gid2lid_[gids_[i]] = i;
}
// Create lookup structure for probe and target ids. // Create lookup structure for probe and target ids.
build_handle_partitions(cell_descriptions); build_handle_partitions(cell_descriptions);
std::size_t n_probes = probe_handle_divisions_.back(); std::size_t n_probes = probe_handle_divisions_.back();
...@@ -51,41 +57,42 @@ public: ...@@ -51,41 +57,42 @@ public:
lowered_.initialize(cell_descriptions, target_handles_, probe_handles_); lowered_.initialize(cell_descriptions, target_handles_, probe_handles_);
// Create a list of the global identifiers for the spike sources // Create a list of the global identifiers for the spike sources.
auto source_gid = cell_gid_type{gid_base_}; auto source_gid = gids_.begin();
for (const auto& cell: cell_descriptions) { for (const auto& cell: cell_descriptions) {
for (cell_lid_type lid=0u; lid<cell.detectors().size(); ++lid) { for (cell_lid_type lid=0u; lid<cell.detectors().size(); ++lid) {
spike_sources_.push_back(source_id_type{source_gid, lid}); spike_sources_.push_back(source_id_type{*source_gid, lid});
} }
++source_gid; ++source_gid;
} }
EXPECTS(spike_sources_.size()==n_detectors); EXPECTS(spike_sources_.size()==n_detectors);
// Create the enumeration of probes attached to cells in this cell group // Create the enumeration of probes attached to cells in this cell group.
probes_.reserve(n_probes); probes_.reserve(n_probes);
for (auto i: util::make_span(0, cell_descriptions.size())){ for (auto i: util::make_span(0, cell_descriptions.size())){
const cell_gid_type probe_gid = gid_base_ + i; const cell_gid_type probe_gid = gids_[i];
const auto probes_on_cell = cell_descriptions[i].probes(); const auto probes_on_cell = cell_descriptions[i].probes();
for (cell_lid_type lid: util::make_span(0, probes_on_cell.size())) { for (cell_lid_type lid: util::make_span(0, probes_on_cell.size())) {
// get the unique global identifier of this probe // Get the unique global identifier of this probe.
cell_member_type id{probe_gid, lid}; cell_member_type id{probe_gid, lid};
// get the location and kind information of the probe // Get the location and kind information of the probe.
const auto p = probes_on_cell[lid]; const auto p = probes_on_cell[lid];
// record the combined identifier and probe details // Record the combined identifier and probe details.
probes_.push_back(probe_record{id, p.location, p.kind}); probes_.push_back(probe_record{id, p.location, p.kind});
} }
} }
} }
mc_cell_group(cell_gid_type first_gid, const std::vector<util::unique_any>& cell_descriptions): mc_cell_group(std::vector<cell_gid_type> gids,
const std::vector<util::unique_any>& cell_descriptions):
mc_cell_group( mc_cell_group(
first_gid, std::move(gids),
util::transform_view( util::transform_view(cell_descriptions,
cell_descriptions, [](const util::unique_any& c) -> const cell& {
[](const util::unique_any& c) -> const cell& {return util::any_cast<const cell&>(c);}) return util::any_cast<const cell&>(c);
) }))
{} {}
cell_kind get_cell_kind() const override { cell_kind get_cell_kind() const override {
...@@ -132,7 +139,7 @@ public: ...@@ -132,7 +139,7 @@ public:
auto& s = samplers_[m->sampler_index]; auto& s = samplers_[m->sampler_index];
EXPECTS((bool)s.sampler); EXPECTS((bool)s.sampler);
time_type cell_time = lowered_.time(s.cell_gid-gid_base_); time_type cell_time = lowered_.time(gid2lid(s.cell_gid));
if (cell_time<m->time) { if (cell_time<m->time) {
// This cell hasn't reached this sample time yet. // This cell hasn't reached this sample time yet.
requeue_sample_events.push_back(*m); requeue_sample_events.push_back(*m);
...@@ -161,8 +168,8 @@ public: ...@@ -161,8 +168,8 @@ public:
lowered_.step_integration(); lowered_.step_integration();
if (util::is_debug_mode() && !lowered_.is_physical_solution()) { if (util::is_debug_mode() && !lowered_.is_physical_solution()) {
std::cerr << "warning: solution out of bounds for cell " std::cerr << "warning: solution out of bounds at (max) t "
<< gid_base_ << " at (max) t " << lowered_.max_time() << " ms\n"; << lowered_.max_time() << " ms\n";
} }
} }
...@@ -213,8 +220,11 @@ public: ...@@ -213,8 +220,11 @@ public:
} }
private: private:
// gid of first cell in group. // List of the gids of the cells in the group
cell_gid_type gid_base_; std::vector<cell_gid_type> gids_;
// Hash table for converting gid to local index
std::unordered_map<cell_gid_type, cell_gid_type> gid2lid_;
// The lowered cell state (e.g. FVM) of the cell. // The lowered cell state (e.g. FVM) of the cell.
lowered_cell_type lowered_; lowered_cell_type lowered_;
...@@ -275,19 +285,7 @@ private: ...@@ -275,19 +285,7 @@ private:
// Use handle partition to get index from id. // Use handle partition to get index from id.
template <typename Divisions> template <typename Divisions>
std::size_t handle_partition_lookup(const Divisions& divisions, cell_member_type id) const { std::size_t handle_partition_lookup(const Divisions& divisions, cell_member_type id) const {
// NB: without any assertion checking, this would just be: return divisions[gid2lid(id.gid)]+id.index;
// return divisions[id.gid-gid_base_]+id.index;
EXPECTS(id.gid>=gid_base_);
auto handle_partition = util::partition_view(divisions);
EXPECTS(id.gid-gid_base_<handle_partition.size());
auto ival = handle_partition[id.gid-gid_base_];
std::size_t i = ival.first + id.index;
EXPECTS(i<ival.second);
return i;
} }
// Get probe handle from probe id. // Get probe handle from probe id.
...@@ -308,6 +306,12 @@ private: ...@@ -308,6 +306,12 @@ private:
sample_events_.push({i, sampler_start_times_[i]}); sample_events_.push({i, sampler_start_times_[i]});
} }
} }
cell_gid_type gid2lid(cell_gid_type gid) const {
auto it = gid2lid_.find(gid);
EXPECTS(it!=gid2lid_.end());
return it->second;
}
}; };
} // namespace mc } // namespace mc
......
...@@ -15,48 +15,28 @@ namespace nest { ...@@ -15,48 +15,28 @@ namespace nest {
namespace mc { namespace mc {
model::model(const recipe& rec, const domain_decomposition& decomp): model::model(const recipe& rec, const domain_decomposition& decomp):
domain_(decomp) communicator_(rec, decomp)
{ {
// set up communicator based on partition for (auto i: util::make_span(0, decomp.num_local_groups())) {
communicator_ = communicator_type(domain_.gid_group_partition()); for (auto gid: decomp.get_group(i).gids) {
gid_groups_[gid] = i;
// generate the cell groups in parallel, with one task per cell group }
cell_groups_.resize(domain_.num_local_groups()); }
// thread safe vector for constructing the list of probes in parallel
threading::parallel_vector<probe_record> probe_tmp;
// Generate the cell groups in parallel, with one task per cell group.
cell_groups_.resize(decomp.num_local_groups());
threading::parallel_for::apply(0, cell_groups_.size(), threading::parallel_for::apply(0, cell_groups_.size(),
[&](cell_gid_type i) { [&](cell_gid_type i) {
PE("setup", "cells"); PE("setup", "cells");
cell_groups_[i] = cell_group_factory(rec, decomp.get_group(i));
auto group = domain_.get_group(i);
std::vector<util::unique_any> cell_descriptions(group.end-group.begin);
for (auto gid: util::make_span(group.begin, group.end)) {
auto i = gid-group.begin;
cell_descriptions[i] = rec.get_cell_description(gid);
}
cell_groups_[i] = cell_group_factory(
group.kind, group.begin, cell_descriptions, domain_.backend());
PL(2); PL(2);
}); });
// store probes // Store probes.
for (const auto& c: cell_groups_) { for (const auto& c: cell_groups_) {
util::append(probes_, c->probes()); util::append(probes_, c->probes());
} }
// generate the network connections
for (cell_gid_type i: util::make_span(domain_.cell_begin(), domain_.cell_end())) {
for (const auto& cc: rec.connections_on(i)) {
connection conn{cc.source, cc.dest, cc.weight, cc.delay};
communicator_.add_connection(conn);
}
}
communicator_.construct();
// Allocate an empty queue buffer for each cell group // Allocate an empty queue buffer for each cell group
// These must be set initially to ensure that a queue is available for each // These must be set initially to ensure that a queue is available for each
// cell group for the first time step. // cell group for the first time step.
...@@ -169,11 +149,11 @@ time_type model::run(time_type tfinal, time_type dt) { ...@@ -169,11 +149,11 @@ time_type model::run(time_type tfinal, time_type dt) {
} }
void model::attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom) { void model::attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom) {
const auto idx = domain_.local_group_from_gid(probe_id.gid); // TODO: remove the gid_groups data structure completely when/if no longer needed
// for the probe interface.
// only attach samplers for local cells auto it = gid_groups_.find(probe_id.gid);
if (idx) { if (it!=gid_groups_.end()) {
cell_groups_[*idx]->add_sampler(probe_id, f, tfrom); cell_groups_[it->second]->add_sampler(probe_id, f, tfrom);
} }
} }
...@@ -189,10 +169,6 @@ std::size_t model::num_groups() const { ...@@ -189,10 +169,6 @@ std::size_t model::num_groups() const {
return cell_groups_.size(); return cell_groups_.size();
} }
std::size_t model::num_cells() const {
return domain_.num_local_cells();
}
void model::set_binning_policy(binning_kind policy, time_type bin_interval) { void model::set_binning_policy(binning_kind policy, time_type bin_interval) {
for (auto& group: cell_groups_) { for (auto& group: cell_groups_) {
group->set_binning_policy(policy, bin_interval); group->set_binning_policy(policy, bin_interval);
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <sampler_function.hpp> #include <sampler_function.hpp>
#include <thread_private_spike_store.hpp> #include <thread_private_spike_store.hpp>
#include <util/nop.hpp> #include <util/nop.hpp>
#include <util/rangeutil.hpp>
#include <util/unique_any.hpp> #include <util/unique_any.hpp>
namespace nest { namespace nest {
...@@ -34,14 +35,13 @@ public: ...@@ -34,14 +35,13 @@ public:
std::size_t num_spikes() const; std::size_t num_spikes() const;
std::size_t num_groups() const;
std::size_t num_cells() const;
// Set event binning policy on all our groups. // Set event binning policy on all our groups.
void set_binning_policy(binning_kind policy, time_type bin_interval); void set_binning_policy(binning_kind policy, time_type bin_interval);
// access cell_group directly // access cell_group directly
// TODO: depricate. Currently used in some validation tests to inject
// events directly into a cell group. This should be done with a spiking
// neuron.
cell_group& group(int i); cell_group& group(int i);
// register a callback that will perform a export of the global // register a callback that will perform a export of the global
...@@ -53,11 +53,10 @@ public: ...@@ -53,11 +53,10 @@ public:
void set_local_spike_callback(spike_export_function export_callback); void set_local_spike_callback(spike_export_function export_callback);
private: private:
const domain_decomposition &domain_; std::size_t num_groups() const;
time_type t_ = 0.; time_type t_ = 0.;
std::vector<cell_group_ptr> cell_groups_; std::vector<cell_group_ptr> cell_groups_;
communicator_type communicator_;
std::vector<probe_record> probes_; std::vector<probe_record> probes_;
using event_queue_type = typename communicator_type::event_queue; using event_queue_type = typename communicator_type::event_queue;
...@@ -69,6 +68,12 @@ private: ...@@ -69,6 +68,12 @@ private:
spike_export_function global_export_callback_ = util::nop_function; spike_export_function global_export_callback_ = util::nop_function;
spike_export_function local_export_callback_ = util::nop_function; spike_export_function local_export_callback_ = util::nop_function;
// Hash table for looking up the group index of the cell_group that
// contains gid
std::unordered_map<cell_gid_type, cell_gid_type> gid_groups_;
communicator_type communicator_;
// Convenience functions that map the spike buffers and event queues onto // Convenience functions that map the spike buffers and event queues onto
// the appropriate integration interval. // the appropriate integration interval.
// //
......