Skip to content
Snippets Groups Projects
Commit 520d27a1 authored by Benjamin Cumming's avatar Benjamin Cumming
Browse files

Add MPI support to miniapp

Added MPI support to the miniapp
- added ring model and all-to-all model
- started refactoring general model setup steps into a model class in
  miniapp
- optional -DWITH_MPI to turn on MPI support
parent a65440cb
No related branches found
No related tags found
No related merge requests found
...@@ -64,3 +64,5 @@ external/modparser-update ...@@ -64,3 +64,5 @@ external/modparser-update
external/tmp external/tmp
mechanisms/*.hpp mechanisms/*.hpp
# build path
build
set(HEADERS set(HEADERS
) )
set(MINIAPP_SOURCES set(MINIAPP_SOURCES
mpi.cpp
io.cpp io.cpp
miniapp.cpp miniapp.cpp
) )
......
...@@ -6,14 +6,12 @@ namespace io { ...@@ -6,14 +6,12 @@ namespace io {
// read simulation options from json file with name fname // read simulation options from json file with name fname
// for now this is just a placeholder // for now this is just a placeholder
options read_options(std::string fname) options read_options(std::string fname) {
{
// 10 cells, 1 synapses per cell, 10 compartments per segment // 10 cells, 1 synapses per cell, 10 compartments per segment
return {5, 1, 100}; return {100, 1, 100};
} }
std::ostream& operator<<(std::ostream& o, const options& opt) std::ostream& operator<<(std::ostream& o, const options& opt) {
{
o << "simultion options:\n"; o << "simultion options:\n";
o << " cells : " << opt.cells << "\n"; o << " cells : " << opt.cells << "\n";
o << " compartments/segment : " << opt.compartments_per_segment << "\n"; o << " compartments/segment : " << opt.compartments_per_segment << "\n";
......
...@@ -10,15 +10,98 @@ ...@@ -10,15 +10,98 @@
#include "profiling/profiler.hpp" #include "profiling/profiler.hpp"
#include "communication/communicator.hpp" #include "communication/communicator.hpp"
#include "communication/serial_global_policy.hpp" #include "communication/serial_global_policy.hpp"
#include "communication/mpi_global_policy.hpp"
using namespace nest; using namespace nest;
using real_type = double; using real_type = double;
using index_type = int; using index_type = int;
using id_type = uint32_t;
using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>; using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>;
using cell_group = mc::cell_group<numeric_cell>; using cell_group = mc::cell_group<numeric_cell>;
#ifdef WITH_MPI
using communicator_type =
mc::communication::communicator<mc::communication::mpi_global_policy>;
#else
using communicator_type = using communicator_type =
mc::communication::communicator<mc::communication::serial_global_policy>; mc::communication::communicator<mc::communication::serial_global_policy>;
#endif
struct model {
communicator_type communicator;
std::vector<cell_group> cell_groups;
double time_init;
double time_network;
double time_solve;
double time_comms;
void print_times() const {
std::cout << "initialization took " << time_init << " s\n";
std::cout << "network took " << time_network << " s\n";
std::cout << "solve took " << time_solve << " s\n";
std::cout << "comms took " << time_comms << " s\n";
}
int num_groups() const {
return cell_groups.size();
}
void run(double tfinal, double dt) {
auto t = 0.;
auto delta = communicator.min_delay();
time_solve = 0.;
time_comms = 0.;
while(t<tfinal) {
auto start_solve = mc::util::timer_type::tic();
mc::threading::parallel_for::apply(
0, num_groups(),
[&](int i) {
cell_groups[i].enqueue_events(communicator.queue(i));
cell_groups[i].advance(t+delta, dt);
communicator.add_spikes(cell_groups[i].spikes());
cell_groups[i].clear_spikes();
}
);
time_solve += mc::util::timer_type::toc(start_solve);
auto start_comms = mc::util::timer_type::tic();
communicator.exchange();
time_comms += mc::util::timer_type::toc(start_comms);
t += delta;
}
}
void init_communicator() {
// calculate the source and synapse distribution serially
std::vector<id_type> target_counts(num_groups());
std::vector<id_type> source_counts(num_groups());
for (auto i=0; i<num_groups(); ++i) {
target_counts[i] = cell_groups[i].cell().synapses()->size();
source_counts[i] = cell_groups[i].spike_sources().size();
}
target_map = mc::algorithms::make_index(target_counts);
source_map = mc::algorithms::make_index(source_counts);
// create connections
communicator = communicator_type(num_groups(), target_counts);
}
void update_gids() {
auto com_policy = communicator.communication_policy();
auto global_source_map = com_policy.make_map(source_map.back());
auto domain_idx = communicator.domain_id();
for (auto i=0; i<num_groups(); ++i) {
cell_groups[i].set_source_gids(source_map[i]+global_source_map[domain_idx]);
cell_groups[i].set_target_gids(target_map[i]+communicator.target_gid_from_group_lid(0));
}
}
// TODO : only stored here because init_communicator() and update_gids() are split
std::vector<id_type> source_map;
std::vector<id_type> target_map;
};
// define some global model parameters // define some global model parameters
namespace parameters { namespace parameters {
...@@ -27,7 +110,7 @@ namespace synapses { ...@@ -27,7 +110,7 @@ namespace synapses {
constexpr double delay = 5.0; // ms constexpr double delay = 5.0; // ms
// connection weight // connection weight
constexpr double weight = 0.05; // uS constexpr double weight = 0.005; // uS
} }
} }
...@@ -36,152 +119,208 @@ namespace synapses { ...@@ -36,152 +119,208 @@ namespace synapses {
/////////////////////////////////////// ///////////////////////////////////////
/// make a single abstract cell /// make a single abstract cell
mc::cell make_cell(int compartments_per_segment); mc::cell make_cell(int compartments_per_segment, int num_synapses);
/// do basic setup (initialize global state, print banner, etc) /// do basic setup (initialize global state, print banner, etc)
void setup(); void setup(int argc, char** argv);
/// helper function for initializing cells /// helper function for initializing cells
cell_group make_lowered_cell(int cell_index, const mc::cell& c); cell_group make_lowered_cell(int cell_index, const mc::cell& c);
/// models
void ring_model(nest::mc::io::options& opt, model& m);
void all_to_all_model(nest::mc::io::options& opt, model& m);
/////////////////////////////////////// ///////////////////////////////////////
// main // main
/////////////////////////////////////// ///////////////////////////////////////
int main(void) { int main(int argc, char** argv) {
setup(); setup(argc, argv);
// read parameters // read parameters
mc::io::options opt; mc::io::options opt;
try { try {
opt = mc::io::read_options(""); opt = mc::io::read_options("");
std::cout << opt << "\n"; if (mc::mpi::rank()==0) {
std::cout << opt << "\n";
}
} }
catch (std::exception e) { catch (std::exception e) {
std::cerr << e.what() << std::endl; std::cerr << e.what() << std::endl;
exit(1); exit(1);
} }
model m;
//ring_model(opt, m);
all_to_all_model(opt, m);
///////////////////////////////////////////////////// /////////////////////////////////////////////////////
// make cells // time stepping
///////////////////////////////////////////////////// /////////////////////////////////////////////////////
auto tfinal = 50.;
auto dt = 0.01;
auto id = m.communicator.domain_id();
if (!id) {
m.communicator.add_spike({0, 5});
}
m.run(tfinal, dt);
if (!id) {
m.print_times();
std::cout << "there were " << m.communicator.num_spikes() << " spikes\n";
}
#ifdef SPLAT
if (!mc::mpi::rank()) {
//for (auto i=0u; i<m.cell_groups.size(); ++i) {
m.cell_groups[0].splat("cell0.txt");
m.cell_groups[1].splat("cell1.txt");
m.cell_groups[2].splat("cell2.txt");
//}
}
#endif
#ifdef WITH_MPI
mc::mpi::finalize();
#endif
}
///////////////////////////////////////
// models
///////////////////////////////////////
void ring_model(nest::mc::io::options& opt, model& m) {
//
// make cells
//
// make a basic cell // make a basic cell
auto basic_cell = make_cell(opt.compartments_per_segment); auto basic_cell = make_cell(opt.compartments_per_segment, 1);
// make a vector for storing all of the cells // make a vector for storing all of the cells
auto start_init = mc::util::timer_type::tic(); auto start_init = mc::util::timer_type::tic();
std::vector<cell_group> cell_groups(opt.cells); m.cell_groups = std::vector<cell_group>(opt.cells);
// initialize the cells in parallel // initialize the cells in parallel
mc::threading::parallel_for::apply( mc::threading::parallel_for::apply(
0, opt.cells, 0, opt.cells,
[&](int i) { [&](int i) {
// initialize cell // initialize cell
cell_groups[i] = make_lowered_cell(i, basic_cell); m.cell_groups[i] = make_lowered_cell(i, basic_cell);
} }
); );
auto time_init = mc::util::timer_type::toc(start_init); m.time_init = mc::util::timer_type::toc(start_init);
///////////////////////////////////////////////////// //
// network creation // network creation
///////////////////////////////////////////////////// //
// calculate the source and synapse distribution serially
auto start_network = mc::util::timer_type::tic(); auto start_network = mc::util::timer_type::tic();
std::vector<uint32_t> target_counts(opt.cells); m.init_communicator();
std::vector<uint32_t> source_counts(opt.cells);
for (auto i=0; i<opt.cells; ++i) {
target_counts[i] = cell_groups[i].cell().synapses()->size();
source_counts[i] = cell_groups[i].spike_sources().size();
}
auto target_map = mc::algorithms::make_index(target_counts);
auto source_map = mc::algorithms::make_index(source_counts);
// create connections for (auto i=0u; i<(id_type)opt.cells; ++i) {
communicator_type communicator(opt.cells, target_counts); m.communicator.add_connection({
for(auto i=0u; i<(uint32_t)opt.cells; ++i) {
communicator.add_connection({
i, (i+1)%opt.cells, i, (i+1)%opt.cells,
parameters::synapses::weight, parameters::synapses::delay parameters::synapses::weight, parameters::synapses::delay
}); });
} }
communicator.construct();
auto global_source_map = m.communicator.construct();
communicator.communication_policy().make_map(source_map.back());
auto domain_idx = communicator.communication_policy().id();
for(auto i=0u; i<(uint32_t)opt.cells; ++i) {
cell_groups[i].set_source_gids(source_map[i]+global_source_map[domain_idx]);
cell_groups[i].set_target_lids(target_map[i]);
}
auto time_network = mc::util::timer_type::toc(start_network); m.update_gids();
///////////////////////////////////////////////////// m.time_network = mc::util::timer_type::toc(start_network);
// time stepping }
/////////////////////////////////////////////////////
auto start_simulation = mc::util::timer_type::tic();
auto tfinal = 20.; void all_to_all_model(nest::mc::io::options& opt, model& m) {
auto t = 0.; //
auto dt = 0.01; // make cells
auto delta = communicator.min_delay(); //
auto timer = mc::util::timer_type();
communicator.add_spike({opt.cells-1u, 5});
while(t<tfinal) {
mc::threading::parallel_for::apply(
0, opt.cells,
[&](int i) {
/*if(communicator.queue(i).size()) {
std::cout << ":: delivering events to group " << i << "\n";
std::cout << " " << communicator.queue(i) << "\n";
}*/
cell_groups[i].enqueue_events(communicator.queue(i));
cell_groups[i].advance(t+delta, dt);
communicator.add_spikes(cell_groups[i].spikes());
cell_groups[i].clear_spikes();
}
);
communicator.exchange(); // make a basic cell
auto basic_cell = make_cell(opt.compartments_per_segment, opt.cells-1);
t += delta; // make a vector for storing all of the cells
auto start_init = timer.tic();
id_type ncell_global = opt.cells;
id_type ncell_local = ncell_global / m.communicator.num_domains();
int remainder = ncell_global - (ncell_local*m.communicator.num_domains());
if (m.communicator.domain_id()<remainder) {
ncell_local++;
} }
for(auto i=0u; i<cell_groups.size(); ++i) { m.cell_groups = std::vector<cell_group>(ncell_local);
cell_groups[i].splat("cell"+std::to_string(i)+".txt");
// initialize the cells in parallel
mc::threading::parallel_for::apply(
0, ncell_local,
[&](int i) {
m.cell_groups[i] = make_lowered_cell(i, basic_cell);
}
);
m.time_init = timer.toc(start_init);
//
// network creation
//
auto start_network = timer.tic();
m.init_communicator();
// lid is local cell/group id
for (auto lid=0u; lid<ncell_local; ++lid) {
auto target = m.communicator.target_gid_from_group_lid(lid);
auto gid = m.communicator.group_gid_from_group_lid(lid);
// tid is global cell/group id
for (auto tid=0u; tid<ncell_global; ++tid) {
if (gid!=tid) {
m.communicator.add_connection({
tid, target++,
parameters::synapses::weight, parameters::synapses::delay
});
}
}
} }
auto time_simulation = mc::util::timer_type::toc(start_simulation); m.communicator.construct();
std::cout << "initialization took " << time_init << " s\n"; m.update_gids();
std::cout << "network took " << time_network << " s\n";
std::cout << "simulation took " << time_simulation << " s\n"; m.time_network = timer.toc(start_network);
std::cout << "performed " << int(tfinal/dt) << " time steps\n";
} }
/////////////////////////////////////// ///////////////////////////////////////
// function definitions // function definitions
/////////////////////////////////////// ///////////////////////////////////////
void setup() void setup(int argc, char** argv) {
{ #ifdef WITH_MPI
mc::mpi::init(&argc, &argv);
// print banner
if (mc::mpi::rank()==0) {
std::cout << "====================\n";
std::cout << " starting miniapp\n";
std::cout << " - " << mc::threading::description() << " threading support\n";
std::cout << " - MPI support\n";
std::cout << "====================\n";
}
#else
// print banner // print banner
std::cout << "====================\n"; std::cout << "====================\n";
std::cout << " starting miniapp\n"; std::cout << " starting miniapp\n";
std::cout << " - " << mc::threading::description() << " threading support\n"; std::cout << " - " << mc::threading::description() << " threading support\n";
std::cout << "====================\n"; std::cout << "====================\n";
#endif
// setup global state for the mechanisms // setup global state for the mechanisms
mc::mechanisms::setup_mechanism_helpers(); mc::mechanisms::setup_mechanism_helpers();
} }
// make a high level cell description for use in simulation // make a high level cell description for use in simulation
mc::cell make_cell(int compartments_per_segment) mc::cell make_cell(int compartments_per_segment, int num_synapses) {
{
nest::mc::cell cell; nest::mc::cell cell;
// Soma with diameter 12.6157 um and HH channel // Soma with diameter 12.6157 um and HH channel
...@@ -191,10 +330,10 @@ mc::cell make_cell(int compartments_per_segment) ...@@ -191,10 +330,10 @@ mc::cell make_cell(int compartments_per_segment)
// add dendrite of length 200 um and diameter 1 um with passive channel // add dendrite of length 200 um and diameter 1 um with passive channel
std::vector<mc::cable_segment*> dendrites; std::vector<mc::cable_segment*> dendrites;
dendrites.push_back(cell.add_cable(0, mc::segmentKind::dendrite, 0.5, 0.5, 200)); dendrites.push_back(cell.add_cable(0, mc::segmentKind::dendrite, 0.5, 0.5, 200));
//dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100)); dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100));
//dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100)); dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100));
for(auto d : dendrites) { for (auto d : dendrites) {
d->add_mechanism(mc::pas_parameters()); d->add_mechanism(mc::pas_parameters());
d->set_compartments(compartments_per_segment); d->set_compartments(compartments_per_segment);
d->mechanism("membrane").set("r_L", 100); d->mechanism("membrane").set("r_L", 100);
...@@ -204,13 +343,15 @@ mc::cell make_cell(int compartments_per_segment) ...@@ -204,13 +343,15 @@ mc::cell make_cell(int compartments_per_segment)
//cell.add_stimulus({1,1}, {5., 80., 0.3}); //cell.add_stimulus({1,1}, {5., 80., 0.3});
cell.add_detector({0,0}, 30); cell.add_detector({0,0}, 30);
cell.add_synapse({1, 0.5});
for (auto i=0; i<num_synapses; ++i) {
cell.add_synapse({1, 0.5});
}
return cell; return cell;
} }
cell_group make_lowered_cell(int cell_index, const mc::cell& c) cell_group make_lowered_cell(int cell_index, const mc::cell& c) {
{
return cell_group(c); return cell_group(c);
} }
#include <mpi.h>
#include <communication/mpi.hpp>
namespace nest {
namespace mc {
namespace mpi {
// global state
namespace state {
int size = -1;
int rank = -1;
} // namespace state
void init(int *argc, char ***argv) {
int provided;
// initialize with thread serialized level of thread safety
MPI_Init_thread(argc, argv, MPI_THREAD_SERIALIZED, &provided);
assert(provided>=MPI_THREAD_SERIALIZED);
MPI_Comm_rank(MPI_COMM_WORLD, &state::rank);
MPI_Comm_size(MPI_COMM_WORLD, &state::size);
}
void finalize() {
MPI_Finalize();
}
bool is_root() {
return state::rank == 0;
}
int rank() {
return state::rank;
}
int size() {
return state::size;
}
void barrier() {
MPI_Barrier(MPI_COMM_WORLD);
}
bool ballot(bool vote) {
using traits = mpi_traits<char>;
char result;
char value = vote ? 1 : 0;
MPI_Allreduce(&value, &result, 1, traits::mpi_type(), MPI_LAND, MPI_COMM_WORLD);
return result;
}
} // namespace mpi
} // namespace mc
} // namespace nest
...@@ -33,7 +33,7 @@ class cell_group { ...@@ -33,7 +33,7 @@ class cell_group {
cell_.voltage()(memory::all) = -65.; cell_.voltage()(memory::all) = -65.;
cell_.initialize(); cell_.initialize();
for(auto& d : c.detectors()) { for (auto& d : c.detectors()) {
spike_sources_.push_back( { spike_sources_.push_back( {
0u, spike_detector_type(cell_, d.first, d.second, 0.f) 0u, spike_detector_type(cell_, d.first, d.second, 0.f)
}); });
...@@ -41,31 +41,33 @@ class cell_group { ...@@ -41,31 +41,33 @@ class cell_group {
} }
void set_source_gids(index_type gid) { void set_source_gids(index_type gid) {
for(auto& s : spike_sources_) { for (auto& s : spike_sources_) {
s.index = gid++; s.index = gid++;
} }
} }
void set_target_lids(index_type lid) { void set_target_gids(index_type lid) {
first_target_lid_ = lid; first_target_gid_ = lid;
} }
#ifdef SPLAT
void splat(std::string fname) { void splat(std::string fname) {
char buffer[128]; char buffer[128];
std::ofstream fid(fname); std::ofstream fid(fname);
for(auto i=0u; i<tt.size(); ++i) { for (auto i=0u; i<tt.size(); ++i) {
sprintf(buffer, "%8.4f %16.8f %16.8f\n", tt[i], vs[i], vd[i]); sprintf(buffer, "%8.4f %16.8f %16.8f\n", tt[i], vs[i], vd[i]);
fid << buffer; fid << buffer;
} }
} }
#endif
void advance(double tfinal, double dt) { void advance(double tfinal, double dt) {
while (cell_.time()<tfinal) { while (cell_.time()<tfinal) {
#ifdef SPLAT
tt.push_back(cell_.time()); tt.push_back(cell_.time());
vs.push_back(cell_.voltage({0,0.0})); vs.push_back(cell_.voltage({0,0.0}));
vd.push_back(cell_.voltage({1,0.5})); vd.push_back(cell_.voltage({1,0.5}));
#endif
// look for events in the next time step // look for events in the next time step
auto tstep = std::min(tfinal, cell_.time()+dt); auto tstep = std::min(tfinal, cell_.time()+dt);
auto next = events_.pop_if_before(tstep); auto next = events_.pop_if_before(tstep);
...@@ -76,8 +78,7 @@ class cell_group { ...@@ -76,8 +78,7 @@ class cell_group {
// check for new spikes // check for new spikes
for (auto& s : spike_sources_) { for (auto& s : spike_sources_) {
auto spike = s.source.test(cell_, cell_.time()); if (auto spike = s.source.test(cell_, cell_.time())) {
if(spike) {
spikes_.push_back({s.index, spike.get()}); spikes_.push_back({s.index, spike.get()});
} }
} }
...@@ -85,6 +86,12 @@ class cell_group { ...@@ -85,6 +86,12 @@ class cell_group {
// apply events // apply events
if (next) { if (next) {
cell_.apply_event(next.get()); cell_.apply_event(next.get());
// apply events that are due within some epsilon of the current
// time step. This should be a parameter. e.g. with for variable
// order time stepping, use the minimum possible time step size.
while(auto e = events_.pop_if_before(cell_.time()+dt/10.)) {
cell_.apply_event(e.get());
}
} }
} }
...@@ -92,8 +99,8 @@ class cell_group { ...@@ -92,8 +99,8 @@ class cell_group {
template <typename R> template <typename R>
void enqueue_events(R events) { void enqueue_events(R events) {
for(auto e : events) { for (auto e : events) {
e.target -= first_target_lid_; e.target -= first_target_gid_;
events_.push(e); events_.push(e);
} }
} }
...@@ -107,8 +114,7 @@ class cell_group { ...@@ -107,8 +114,7 @@ class cell_group {
const cell_type& cell() const { return cell_; } const cell_type& cell() const { return cell_; }
const std::vector<spike_source_type>& const std::vector<spike_source_type>&
spike_sources() const spike_sources() const {
{
return spike_sources_; return spike_sources_;
} }
...@@ -118,19 +124,27 @@ class cell_group { ...@@ -118,19 +124,27 @@ class cell_group {
private : private :
// TEMPORARY... #ifdef SPLAT
// REMOVE as soon as we have a better way to probe cell state
std::vector<float> tt; std::vector<float> tt;
std::vector<float> vs; std::vector<float> vs;
std::vector<float> vd; std::vector<float> vd;
#endif
/// the lowered cell state (e.g. FVM) of the cell
cell_type cell_; cell_type cell_;
/// spike detectors attached to the cell
std::vector<spike_source_type> spike_sources_; std::vector<spike_source_type> spike_sources_;
// spikes that are generated //. spikes that are generated
std::vector<communication::spike<index_type>> spikes_; std::vector<communication::spike<index_type>> spikes_;
/// pending events to be delivered
event_queue events_; event_queue events_;
index_type first_target_lid_; /// the global id of the first target (e.g. a synapse) in this group
index_type first_target_gid_;
}; };
} // namespace mc } // namespace mc
......
...@@ -29,17 +29,17 @@ namespace communication { ...@@ -29,17 +29,17 @@ namespace communication {
template <typename CommunicationPolicy> template <typename CommunicationPolicy>
class communicator { class communicator {
public: public:
using index_type = uint32_t; using id_type = uint32_t;
using communication_policy_type = CommunicationPolicy; using communication_policy_type = CommunicationPolicy;
using spike_type = spike<index_type>; using spike_type = spike<id_type>;
communicator( index_type n_groups, std::vector<index_type> target_counts) : communicator() = default;
communicator(id_type n_groups, std::vector<id_type> target_counts) :
num_groups_local_(n_groups), num_groups_local_(n_groups),
num_targets_local_(target_counts.size()) num_targets_local_(target_counts.size())
{ {
communicator_id_ = communication_policy_.id();
target_map_ = nest::mc::algorithms::make_index(target_counts); target_map_ = nest::mc::algorithms::make_index(target_counts);
num_targets_local_ = target_map_.back(); num_targets_local_ = target_map_.back();
...@@ -51,50 +51,60 @@ public: ...@@ -51,50 +51,60 @@ public:
group_gid_map_ = communication_policy_.make_map(num_groups_local_); group_gid_map_ = communication_policy_.make_map(num_groups_local_);
// transform the target ids from lid to gid // transform the target ids from lid to gid
auto first_target = target_gid_map_[communicator_id_]; auto first_target = target_gid_map_[domain_id()];
for(auto &id : target_map_) { for (auto &id : target_map_) {
id += first_target; id += first_target;
} }
} }
id_type target_gid_from_group_lid(id_type lid) const {
EXPECTS(lid<num_groups_local_);
return target_map_[lid];
}
id_type group_gid_from_group_lid(id_type lid) const {
EXPECTS(lid<num_groups_local_);
return group_gid_map_[domain_id()] + lid;
}
void add_connection(connection con) { void add_connection(connection con) {
EXPECTS(is_local_target(con.destination())); EXPECTS(is_local_target(con.destination()));
connections_.push_back(con); connections_.push_back(con);
} }
bool is_local_target(index_type gid) { bool is_local_target(id_type gid) {
return gid>=target_gid_map_[communicator_id_] return gid>=target_gid_map_[domain_id()]
&& gid<target_gid_map_[communicator_id_+1]; && gid<target_gid_map_[domain_id()+1];
} }
bool is_local_group(index_type gid) { bool is_local_group(id_type gid) {
return gid>=group_gid_map_[communicator_id_] return gid>=group_gid_map_[domain_id()]
&& gid<group_gid_map_[communicator_id_+1]; && gid<group_gid_map_[domain_id()+1];
} }
/// return the global id of the first group in domain d /// return the global id of the first group in domain d
/// the groups in domain d are in the contiguous half open range /// the groups in domain d are in the contiguous half open range
/// [domain_first_group(d), domain_first_group(d+1)) /// [domain_first_group(d), domain_first_group(d+1))
index_type group_gid_first(int d) const { id_type group_gid_first(int d) const {
return group_gid_map_[d]; return group_gid_map_[d];
} }
index_type target_lid(index_type gid) { id_type target_lid(id_type gid) {
EXPECTS(is_local_group(gid)); EXPECTS(is_local_group(gid));
return gid - target_gid_map_[communicator_id_]; return gid - target_gid_map_[domain_id()];
} }
// builds the optimized data structure // builds the optimized data structure
void construct() { void construct() {
if(!std::is_sorted(connections_.begin(), connections_.end())) { if (!std::is_sorted(connections_.begin(), connections_.end())) {
std::sort(connections_.begin(), connections_.end()); std::sort(connections_.begin(), connections_.end());
} }
} }
float min_delay() { float min_delay() {
auto local_min = std::numeric_limits<float>::max(); auto local_min = std::numeric_limits<float>::max();
for(auto& con : connections_) { for (auto& con : connections_) {
local_min = std::min(local_min, con.delay()); local_min = std::min(local_min, con.delay());
} }
...@@ -103,7 +113,7 @@ public: ...@@ -103,7 +113,7 @@ public:
// return the local group index of the group which hosts the target with // return the local group index of the group which hosts the target with
// global id gid // global id gid
index_type local_group_from_global_target(index_type gid) { id_type local_group_from_global_target(id_type gid) {
// assert that gid is in range // assert that gid is in range
EXPECTS(is_local_target(gid)); EXPECTS(is_local_target(gid));
...@@ -133,10 +143,11 @@ public: ...@@ -133,10 +143,11 @@ public:
// on each node // on each node
//profiler_.enter("global exchange"); //profiler_.enter("global exchange");
auto global_spikes = communication_policy_.gather_spikes(local_spikes()); auto global_spikes = communication_policy_.gather_spikes(local_spikes());
num_spikes_ += global_spikes.size();
clear_thread_spike_buffers(); clear_thread_spike_buffers();
//profiler_.leave(); //profiler_.leave();
for(auto& q : events_) { for (auto& q : events_) {
q.clear(); q.clear();
} }
...@@ -144,7 +155,7 @@ public: ...@@ -144,7 +155,7 @@ public:
//profiler_.enter("make events"); //profiler_.enter("make events");
// check all global spikes to see if they will generate local events // check all global spikes to see if they will generate local events
for(auto spike : global_spikes) { for (auto spike : global_spikes) {
// search for targets // search for targets
auto targets = auto targets =
std::equal_range( std::equal_range(
...@@ -152,7 +163,7 @@ public: ...@@ -152,7 +163,7 @@ public:
); );
// generate an event for each target // generate an event for each target
for(auto it=targets.first; it!=targets.second; ++it) { for (auto it=targets.first; it!=targets.second; ++it) {
auto gidx = local_group_from_global_target(it->destination()); auto gidx = local_group_from_global_target(it->destination());
events_[gidx].push_back(it->make_event(spike)); events_[gidx].push_back(it->make_event(spike));
...@@ -164,56 +175,62 @@ public: ...@@ -164,56 +175,62 @@ public:
//profiler_.leave(); // event generation //profiler_.leave(); // event generation
} }
const std::vector<local_event>& queue(int i) uint64_t num_spikes() const
{ {
return num_spikes_;
}
int domain_id() const {
return communication_policy_.id();
}
int num_domains() const {
return communication_policy_.size();
}
const std::vector<local_event>& queue(int i) const {
return events_[i]; return events_[i];
} }
std::vector<connection>const& connections() const { const std::vector<connection>& connections() const {
return connections_; return connections_;
} }
communication_policy_type& communication_policy() { communication_policy_type communication_policy() const {
return communication_policy_; return communication_policy_;
} }
const std::vector<index_type>& local_target_map() const { const std::vector<id_type>& local_target_map() const {
return target_map_; return target_map_;
} }
std::vector<spike_type> local_spikes() { std::vector<spike_type> local_spikes() {
std::vector<spike_type> spikes; std::vector<spike_type> spikes;
for(auto& v : thread_spikes_) { for (auto& v : thread_spikes_) {
spikes.insert(spikes.end(), v.begin(), v.end()); spikes.insert(spikes.end(), v.begin(), v.end());
} }
return spikes; return spikes;
} }
void clear_thread_spike_buffers() { void clear_thread_spike_buffers() {
for(auto& v : thread_spikes_) { for (auto& v : thread_spikes_) {
v.clear(); v.clear();
} }
} }
private: private:
//
// both of these can be fixed with double buffering
//
// FIXME : race condition on the thread_spikes_ buffers when exchange() modifies/access them // FIXME : race condition on the thread_spikes_ buffers when exchange() modifies/access them
// ... other threads will be pushing to them simultaneously // ... other threads will be pushing to them simultaneously
// FIXME : race condition on the group-specific event queues when exchange pusheds to them // FIXME : race condition on the group-specific event queues when exchange pushes to them
// ... other threads will be accessing them to get events // ... other threads will be accessing them to update their event queues
// thread private storage for accumulating spikes // thread private storage for accumulating spikes
using local_spike_store_type = using local_spike_store_type =
nest::mc::threading::enumerable_thread_specific<std::vector<spike_type>>; nest::mc::threading::enumerable_thread_specific<std::vector<spike_type>>;
/*
#ifdef WITH_TBB
using local_spike_store_type =
tbb::enumerable_thread_specific<std::vector<spike_type>>;
#else
using local_spike_store_type =
std::array<std::vector<spike_type>, 1>;
#endif
*/
local_spike_store_type thread_spikes_; local_spike_store_type thread_spikes_;
std::vector<connection> connections_; std::vector<connection> connections_;
...@@ -221,30 +238,28 @@ private: ...@@ -221,30 +238,28 @@ private:
// local target group i has targets in the half open range // local target group i has targets in the half open range
// [target_map_[i], target_map_[i+1]) // [target_map_[i], target_map_[i+1])
std::vector<index_type> target_map_; std::vector<id_type> target_map_;
// for keeping track of how time is spent where // for keeping track of how time is spent where
//util::Profiler profiler_; //util::Profiler profiler_;
// the number of groups and targets handled by this communicator // the number of groups and targets handled by this communicator
index_type num_groups_local_; id_type num_groups_local_;
index_type num_targets_local_; id_type num_targets_local_;
// index maps for the global distribution of groups and targets // index maps for the global distribution of groups and targets
// communicator i has the groups in the half open range : // communicator i has the groups in the half open range :
// [group_gid_map_[i], group_gid_map_[i+1]) // [group_gid_map_[i], group_gid_map_[i+1])
std::vector<index_type> group_gid_map_; std::vector<id_type> group_gid_map_;
// communicator i has the targets in the half open range : // communicator i has the targets in the half open range :
// [target_gid_map_[i], target_gid_map_[i+1]) // [target_gid_map_[i], target_gid_map_[i+1])
std::vector<index_type> target_gid_map_; std::vector<id_type> target_gid_map_;
// each communicator has a unique id
// e.g. for MPI this could be the MPI rank
index_type communicator_id_;
communication_policy_type communication_policy_; communication_policy_type communication_policy_;
uint64_t num_spikes_ = 0u;
}; };
} // namespace communication } // namespace communication
......
...@@ -11,36 +11,28 @@ namespace communication { ...@@ -11,36 +11,28 @@ namespace communication {
class connection { class connection {
public: public:
using index_type = uint32_t; using id_type = uint32_t;
connection(index_type src, index_type dest, float w, float d) connection(id_type src, id_type dest, float w, float d) :
: source_(src), source_(src),
destination_(dest), destination_(dest),
weight_(w), weight_(w),
delay_(d) delay_(d)
{ } {}
float weight() const { float weight() const { return weight_; }
return weight_; float delay() const { return delay_; }
}
float delay() const {
return delay_;
}
index_type source() const { id_type source() const { return source_; }
return source_; id_type destination() const { return destination_; }
}
index_type destination() const {
return destination_;
}
local_event make_event(spike<index_type> s) { local_event make_event(spike<id_type> s) {
return {destination_, s.time + delay_, weight_}; return {destination_, s.time + delay_, weight_};
} }
private: private:
index_type source_; id_type source_;
index_type destination_; id_type destination_;
float weight_; float weight_;
float delay_; float delay_;
}; };
...@@ -54,12 +46,12 @@ bool operator< (connection lhs, connection rhs) { ...@@ -54,12 +46,12 @@ bool operator< (connection lhs, connection rhs) {
} }
static inline static inline
bool operator< (connection lhs, connection::index_type rhs) { bool operator< (connection lhs, connection::id_type rhs) {
return lhs.source() < rhs; return lhs.source() < rhs;
} }
static inline static inline
bool operator< (connection::index_type lhs, connection rhs) { bool operator< (connection::id_type lhs, connection rhs) {
return lhs < rhs.source(); return lhs < rhs.source();
} }
......
...@@ -2,15 +2,30 @@ ...@@ -2,15 +2,30 @@
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <type_traits>
#include <vector> #include <vector>
#include <cassert> #include <cassert>
#include <mpi.h> #include <mpi.h>
#include "utils.hpp"
#include "utils.hpp" #include <algorithms.hpp>
namespace nest {
namespace mc {
namespace mpi { namespace mpi {
// prototypes
void init(int *argc, char ***argv);
void finalize();
bool is_root();
int rank();
int size();
void barrier();
bool ballot(bool vote);
// type traits for automatically setting MPI_Datatype information
// for C++ types
template <typename T> template <typename T>
struct mpi_traits { struct mpi_traits {
constexpr static size_t count() { constexpr static size_t count() {
...@@ -41,15 +56,15 @@ namespace mpi { ...@@ -41,15 +56,15 @@ namespace mpi {
static_assert(sizeof(size_t)==sizeof(unsigned long), static_assert(sizeof(size_t)==sizeof(unsigned long),
"size_t and unsigned long are not equivalent"); "size_t and unsigned long are not equivalent");
bool init(int *argc, char ***argv); // Gather individual values of type T from each rank into a std::vector on
bool finalize(); // the root rank.
bool is_root(); // T must be trivially copyable
int rank(); template<typename T>
int size();
void barrier();
template <typename T>
std::vector<T> gather(T value, int root) { std::vector<T> gather(T value, int root) {
static_assert(
true,//std::is_trivially_copyable<T>::value,
"gather can only be performed on trivally copyable types");
using traits = mpi_traits<T>; using traits = mpi_traits<T>;
auto buffer_size = (rank()==root) ? size() : 0; auto buffer_size = (rank()==root) ? size() : 0;
std::vector<T> buffer(buffer_size); std::vector<T> buffer(buffer_size);
...@@ -61,8 +76,15 @@ namespace mpi { ...@@ -61,8 +76,15 @@ namespace mpi {
return buffer; return buffer;
} }
// Gather individual values of type T from each rank into a std::vector on
// the every rank.
// T must be trivially copyable
template <typename T> template <typename T>
std::vector<T> gather_all(T value) { std::vector<T> gather_all(T value) {
static_assert(
true,//std::is_trivially_copyable<T>::value,
"gather_all can only be performed on trivally copyable types");
using traits = mpi_traits<T>; using traits = mpi_traits<T>;
std::vector<T> buffer(size()); std::vector<T> buffer(size());
...@@ -75,12 +97,16 @@ namespace mpi { ...@@ -75,12 +97,16 @@ namespace mpi {
template <typename T> template <typename T>
std::vector<T> gather_all(const std::vector<T> &values) { std::vector<T> gather_all(const std::vector<T> &values) {
static_assert(
true,//std::is_trivially_copyable<T>::value,
"gather_all can only be performed on trivally copyable types");
using traits = mpi_traits<T>; using traits = mpi_traits<T>;
auto counts = gather_all(int(values.size())); auto counts = gather_all(int(values.size()));
for(auto& c : counts) { for (auto& c : counts) {
c *= traits::count(); c *= traits::count();
} }
auto displs = algorithms::make_map(counts); auto displs = algorithms::make_index(counts);
std::vector<T> buffer(displs.back()/traits::count()); std::vector<T> buffer(displs.back()/traits::count());
...@@ -98,8 +124,9 @@ namespace mpi { ...@@ -98,8 +124,9 @@ namespace mpi {
template <typename T> template <typename T>
T reduce(T value, MPI_Op op, int root) { T reduce(T value, MPI_Op op, int root) {
using traits = mpi_traits<T>; using traits = mpi_traits<T>;
static_assert(traits::is_mpi_native_type(), static_assert(
"can only perform reductions on MPI native types"); traits::is_mpi_native_type(),
"can only perform reductions on MPI native types");
T result; T result;
...@@ -111,8 +138,9 @@ namespace mpi { ...@@ -111,8 +138,9 @@ namespace mpi {
template <typename T> template <typename T>
T reduce(T value, MPI_Op op) { T reduce(T value, MPI_Op op) {
using traits = mpi_traits<T>; using traits = mpi_traits<T>;
static_assert(traits::is_mpi_native_type(), static_assert(
"can only perform reductions on MPI native types"); traits::is_mpi_native_type(),
"can only perform reductions on MPI native types");
T result; T result;
...@@ -131,10 +159,12 @@ namespace mpi { ...@@ -131,10 +159,12 @@ namespace mpi {
return {reduce<T>(value, MPI_MIN, root), reduce<T>(value, MPI_MAX, root)}; return {reduce<T>(value, MPI_MIN, root), reduce<T>(value, MPI_MAX, root)};
} }
bool ballot(bool vote);
template <typename T> template <typename T>
T broadcast(T value, int root) { T broadcast(T value, int root) {
static_assert(
true,//std::is_trivially_copyable<T>::value,
"broadcast can only be performed on trivally copyable types");
using traits = mpi_traits<T>; using traits = mpi_traits<T>;
MPI_Bcast(&value, traits::count(), traits::mpi_type(), root, MPI_COMM_WORLD); MPI_Bcast(&value, traits::count(), traits::mpi_type(), root, MPI_COMM_WORLD);
...@@ -144,6 +174,10 @@ namespace mpi { ...@@ -144,6 +174,10 @@ namespace mpi {
template <typename T> template <typename T>
T broadcast(int root) { T broadcast(int root) {
static_assert(
true,//std::is_trivially_copyable<T>::value,
"broadcast can only be performed on trivally copyable types");
using traits = mpi_traits<T>; using traits = mpi_traits<T>;
T value; T value;
...@@ -153,3 +187,5 @@ namespace mpi { ...@@ -153,3 +187,5 @@ namespace mpi {
} }
} // namespace mpi } // namespace mpi
} // namespace mc
} // namespace nest
...@@ -6,32 +6,34 @@ ...@@ -6,32 +6,34 @@
#include <cstdint> #include <cstdint>
#include <communication/spike.hpp> #include <communication/spike.hpp>
#include <communication/mpi.hpp>
#include <algorithms.hpp> #include <algorithms.hpp>
#include "mpi.hpp"
namespace nest { namespace nest {
namespace mc { namespace mc {
namespace communication { namespace communication {
struct mpi_global_policy { struct mpi_global_policy {
std::vector<spike<uint32_t>> const using id_type = uint32_t;
gather_spikes(const std::vector<spike<uint32_t>>& local_spikes) {
std::vector<spike<id_type>> const
gather_spikes(const std::vector<spike<id_type>>& local_spikes) {
return mpi::gather_all(local_spikes); return mpi::gather_all(local_spikes);
} }
int id() const { int id() const { return mpi::rank(); }
return mpi::rank();
} int size() const { return mpi::size(); }
/*
template <typename T> template <typename T>
T min(T value) const { T min(T value) const {
return nest::mc::mpi::reduce(value, MPI_MIN);
} }
*/
int num_communicators() const { template <typename T>
return mpi::size(); T max(T value) const {
return nest::mc::mpi::reduce(value, MPI_MAX);
} }
template < template <
......
...@@ -12,8 +12,10 @@ namespace mc { ...@@ -12,8 +12,10 @@ namespace mc {
namespace communication { namespace communication {
struct serial_global_policy { struct serial_global_policy {
std::vector<spike<uint32_t>> const using id_type = uint32_t;
gather_spikes(const std::vector<spike<uint32_t>>& local_spikes) {
std::vector<spike<id_type>> const
gather_spikes(const std::vector<spike<id_type>>& local_spikes) {
return local_spikes; return local_spikes;
} }
...@@ -21,7 +23,7 @@ struct serial_global_policy { ...@@ -21,7 +23,7 @@ struct serial_global_policy {
return 0; return 0;
} }
static int num_communicators() { static int size() {
return 1; return 1;
} }
......
...@@ -12,14 +12,14 @@ template < ...@@ -12,14 +12,14 @@ template <
typename = typename std::enable_if<std::is_integral<I>::value> typename = typename std::enable_if<std::is_integral<I>::value>
> >
struct spike { struct spike {
using index_type = I; using id_type = I;
index_type source = 0; id_type source = 0;
float time = -1.; float time = -1.;
spike() = default; spike() = default;
spike(index_type s, float t) spike(id_type s, float t) :
: source(s), time(t) source(s), time(t)
{} {}
}; };
...@@ -29,14 +29,17 @@ struct spike { ...@@ -29,14 +29,17 @@ struct spike {
/// custom stream operator for printing nest::mc::spike<> values /// custom stream operator for printing nest::mc::spike<> values
template <typename I> template <typename I>
std::ostream& operator <<(std::ostream& o, nest::mc::communication::spike<I> s) { std::ostream& operator<<(
std::ostream& o,
nest::mc::communication::spike<I> s)
{
return o << "spike[t " << s.time << ", src " << s.source << "]"; return o << "spike[t " << s.time << ", src " << s.source << "]";
} }
/// less than comparison operator for nest::mc::spike<> values /// less than comparison operator for nest::mc::spike<> values
/// spikes are ordered by spike time, for use in sorting and queueing /// spikes are ordered by spike time, for use in sorting and queueing
template <typename I> template <typename I>
bool operator <( bool operator<(
nest::mc::communication::spike<I> lhs, nest::mc::communication::spike<I> lhs,
nest::mc::communication::spike<I> rhs) nest::mc::communication::spike<I> rhs)
{ {
......
...@@ -13,13 +13,8 @@ class spike_detector ...@@ -13,13 +13,8 @@ class spike_detector
public: public:
using cell_type = Cell; using cell_type = Cell;
spike_detector( spike_detector( const cell_type& cell, segment_location loc, double thresh, float t_init) :
const cell_type& cell, location_(loc),
segment_location loc,
double thresh,
float t_init
)
: location_(loc),
threshold_(thresh), threshold_(thresh),
previous_t_(t_init) previous_t_(t_init)
{ {
...@@ -27,8 +22,7 @@ public: ...@@ -27,8 +22,7 @@ public:
is_spiking_ = previous_v_ >= thresh ? true : false; is_spiking_ = previous_v_ >= thresh ? true : false;
} }
util::optional<float> test(const cell_type& cell, float t) util::optional<float> test(const cell_type& cell, float t) {
{
util::optional<float> result = util::nothing; util::optional<float> result = util::nothing;
auto v = cell.voltage(location_); auto v = cell.voltage(location_);
...@@ -56,26 +50,17 @@ public: ...@@ -56,26 +50,17 @@ public:
return result; return result;
} }
bool is_spiking() const { bool is_spiking() const { return is_spiking_; }
return is_spiking_;
}
segment_location location() const { segment_location location() const { return location_; }
return location_;
}
float t() const { float t() const { return previous_t_; }
return previous_t_;
}
float v() const { float v() const { return previous_v_; }
return previous_v_;
}
private: private:
// parameters/data // parameters/data
//const cell_type* cell_;
segment_location location_; segment_location location_;
double threshold_; double threshold_;
...@@ -85,38 +70,6 @@ private: ...@@ -85,38 +70,6 @@ private:
bool is_spiking_; bool is_spiking_;
}; };
/*
// spike generator according to a Poisson process
class poisson_generator : public spike_source
{
public:
poisson_generator(float r)
: dist_(0.0f, 1.0f),
firing_rate_(r)
{}
util::optional<float> test(float t) {
// generate a uniformly distrubuted random number x \in [0,1]
// if (x > r*dt) we have a spike in the interval
std::vector<float> spike_times;
if(dist_(generator_) > firing_rate_*(t-previous_t_)) {
return t;
}
return util::nothing;
}
private:
std::mt19937 generator_; // for now default initialized
std::uniform_real_distribution<float> dist_;
// firing rate in spikes/ms
float firing_rate_;
float previous_t_;
};
*/
} // namespace mc } // namespace mc
} // namespace nest } // namespace nest
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment