Skip to content
Snippets Groups Projects
Commit 168f84d2 authored by Sam Yates's avatar Sam Yates
Browse files

Re-enaable WITH_MPI=OFF builds

* Make global_policy methods static
* Move global_policy init/finalize into static methods
* Make miniapp use only global_policy methods for
  setup, rank, etc.
parent 0bfd5f86
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,17 @@
#include "profiling/profiler.hpp"
#include "communication/communicator.hpp"
#include "communication/serial_global_policy.hpp"
#ifdef WITH_MPI
#include "communication/mpi_global_policy.hpp"
using global_policy = nest::mc::communication::mpi_global_policy;
#else
using global_policy = nest::mc::communication::serial_global_policy;
#endif
using namespace nest;
......@@ -19,13 +29,28 @@ using index_type = int;
using id_type = uint32_t;
using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>;
using cell_group = mc::cell_group<numeric_cell>;
#ifdef WITH_MPI
using communicator_type =
mc::communication::communicator<mc::communication::mpi_global_policy>;
#else
using communicator_type =
mc::communication::communicator<mc::communication::serial_global_policy>;
#endif
mc::communication::communicator<global_policy>;
template <typename Policy>
struct policy_guard {
policy_guard(int argc, char **&argv) {
Policy::setup(argc, argv);
}
policy_guard() = delete;
policy_guard(policy_guard &&) = delete;
policy_guard(const policy_guard &) = delete;
policy_guard &operator=(policy_guard &&) = delete;
policy_guard &operator=(const policy_guard &) = delete;
~policy_guard() {
Policy::teardown();
}
};
using global_policy_guard = policy_guard<global_policy>;
struct model {
communicator_type communicator;
......@@ -122,7 +147,7 @@ namespace synapses {
mc::cell make_cell(int compartments_per_segment, int num_synapses);
/// do basic setup (initialize global state, print banner, etc)
void setup(int argc, char** argv);
void setup();
/// helper function for initializing cells
cell_group make_lowered_cell(int cell_index, const mc::cell& c);
......@@ -131,22 +156,24 @@ cell_group make_lowered_cell(int cell_index, const mc::cell& c);
void ring_model(nest::mc::io::options& opt, model& m);
void all_to_all_model(nest::mc::io::options& opt, model& m);
///////////////////////////////////////
// main
///////////////////////////////////////
int main(int argc, char** argv) {
global_policy_guard _(argc, argv);
setup(argc, argv);
setup();
// read parameters
mc::io::options opt;
try {
opt = mc::io::read_options("");
if (mc::mpi::rank()==0) {
if (!global_policy::id()) {
std::cout << opt << "\n";
}
}
catch (std::exception e) {
catch (std::exception &e) {
std::cerr << e.what() << std::endl;
exit(1);
}
......@@ -174,7 +201,7 @@ int main(int argc, char** argv) {
}
#ifdef SPLAT
if (!mc::mpi::rank()) {
if (!global_policy::id()) {
//for (auto i=0u; i<m.cell_groups.size(); ++i) {
m.cell_groups[0].splat("cell0.txt");
m.cell_groups[1].splat("cell1.txt");
......@@ -182,10 +209,6 @@ int main(int argc, char** argv) {
//}
}
#endif
#ifdef WITH_MPI
mc::mpi::finalize();
#endif
}
///////////////////////////////////////
......@@ -295,25 +318,15 @@ void all_to_all_model(nest::mc::io::options& opt, model& m) {
// function definitions
///////////////////////////////////////
void setup(int argc, char** argv) {
#ifdef WITH_MPI
mc::mpi::init(&argc, &argv);
void setup() {
// print banner
if (mc::mpi::rank()==0) {
if (!global_policy::id()) {
std::cout << "====================\n";
std::cout << " starting miniapp\n";
std::cout << " - " << mc::threading::description() << " threading support\n";
std::cout << " - MPI support\n";
std::cout << " - communication policy: " << global_policy::name() << "\n";
std::cout << "====================\n";
}
#else
// print banner
std::cout << "====================\n";
std::cout << " starting miniapp\n";
std::cout << " - " << mc::threading::description() << " threading support\n";
std::cout << "====================\n";
#endif
// setup global state for the mechanisms
mc::mechanisms::setup_mechanism_helpers();
......
......@@ -43,7 +43,7 @@ namespace mpi {
template <> \
struct mpi_traits<T> { \
constexpr static size_t count() { return 1; } \
constexpr static MPI_Datatype mpi_type() { return M; } \
/* constexpr */ static MPI_Datatype mpi_type() { return M; } \
constexpr static bool is_mpi_native_type() { return true; } \
};
......
#pragma once
#ifndef WITH_MPI
#error "mpi_global_policy.hpp should only be compiled in a WITH_MPI build"
#endif
#include <type_traits>
#include <vector>
......@@ -9,7 +13,6 @@
#include <communication/mpi.hpp>
#include <algorithms.hpp>
namespace nest {
namespace mc {
namespace communication {
......@@ -17,22 +20,22 @@ namespace communication {
struct mpi_global_policy {
using id_type = uint32_t;
std::vector<spike<id_type>> const
gather_spikes(const std::vector<spike<id_type>>& local_spikes) {
std::vector<spike<id_type>>
static gather_spikes(const std::vector<spike<id_type>>& local_spikes) {
return mpi::gather_all(local_spikes);
}
int id() const { return mpi::rank(); }
static int id() { return mpi::rank(); }
int size() const { return mpi::size(); }
static int size() { return mpi::size(); }
template <typename T>
T min(T value) const {
static T min(T value) {
return nest::mc::mpi::reduce(value, MPI_MIN);
}
template <typename T>
T max(T value) const {
static T max(T value) {
return nest::mc::mpi::reduce(value, MPI_MAX);
}
......@@ -40,11 +43,24 @@ struct mpi_global_policy {
typename T,
typename = typename std::enable_if<std::is_integral<T>::value>
>
std::vector<T> make_map(T local) {
static std::vector<T> make_map(T local) {
return algorithms::make_index(mpi::gather_all(local));
}
static void setup(int& argc, char**& argv) {
nest::mc::mpi::init(&argc, &argv);
}
static void teardown() {
nest::mc::mpi::finalize();
}
static const char* name() { return "MPI"; }
private:
};
} // namespace communication
} // namespace mc
} // namespace nest
......@@ -15,7 +15,7 @@ struct serial_global_policy {
using id_type = uint32_t;
std::vector<spike<id_type>> const
gather_spikes(const std::vector<spike<id_type>>& local_spikes) {
static gather_spikes(const std::vector<spike<id_type>>& local_spikes) {
return local_spikes;
}
......@@ -28,12 +28,12 @@ struct serial_global_policy {
}
template <typename T>
T min(T value) const {
static T min(T value) {
return value;
}
template <typename T>
T max(T value) const {
static T max(T value) {
return value;
}
......@@ -41,9 +41,13 @@ struct serial_global_policy {
typename T,
typename = typename std::enable_if<std::is_integral<T>::value>
>
std::vector<T> make_map(T local) {
static std::vector<T> make_map(T local) {
return {T(0), local};
}
static void setup(int& argc, char**& argv) {}
static void teardown() {}
static const char* name() { return "serial"; }
};
} // namespace communication
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment