diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp index 5f2082e1b41945ecd0452a27984b6dc34bcb20b8..354bdbf7605c4338642f65004bd5ff3f17108f3f 100644 --- a/miniapp/miniapp.cpp +++ b/miniapp/miniapp.cpp @@ -10,7 +10,17 @@ #include "profiling/profiler.hpp" #include "communication/communicator.hpp" #include "communication/serial_global_policy.hpp" + +#ifdef WITH_MPI + #include "communication/mpi_global_policy.hpp" +using global_policy = nest::mc::communication::mpi_global_policy; + +#else + +using global_policy = nest::mc::communication::serial_global_policy; + +#endif using namespace nest; @@ -19,13 +29,28 @@ using index_type = int; using id_type = uint32_t; using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>; using cell_group = mc::cell_group<numeric_cell>; -#ifdef WITH_MPI -using communicator_type = - mc::communication::communicator<mc::communication::mpi_global_policy>; -#else + using communicator_type = - mc::communication::communicator<mc::communication::serial_global_policy>; -#endif + mc::communication::communicator<global_policy>; + +template <typename Policy> +struct policy_guard { + policy_guard(int argc, char **&argv) { + Policy::setup(argc, argv); + } + + policy_guard() = delete; + policy_guard(policy_guard &&) = delete; + policy_guard(const policy_guard &) = delete; + policy_guard &operator=(policy_guard &&) = delete; + policy_guard &operator=(const policy_guard &) = delete; + + ~policy_guard() { + Policy::teardown(); + } +}; + +using global_policy_guard = policy_guard<global_policy>; struct model { communicator_type communicator; @@ -122,7 +147,7 @@ namespace synapses { mc::cell make_cell(int compartments_per_segment, int num_synapses); /// do basic setup (initialize global state, print banner, etc) -void setup(int argc, char** argv); +void setup(); /// helper function for initializing cells cell_group make_lowered_cell(int cell_index, const mc::cell& c); @@ -131,22 +156,24 @@ cell_group make_lowered_cell(int cell_index, const mc::cell& c); void ring_model(nest::mc::io::options& opt, model& m); void all_to_all_model(nest::mc::io::options& opt, model& m); + /////////////////////////////////////// // main /////////////////////////////////////// int main(int argc, char** argv) { + global_policy_guard _(argc, argv); - setup(argc, argv); + setup(); // read parameters mc::io::options opt; try { opt = mc::io::read_options(""); - if (mc::mpi::rank()==0) { + if (!global_policy::id()) { std::cout << opt << "\n"; } } - catch (std::exception e) { + catch (std::exception &e) { std::cerr << e.what() << std::endl; exit(1); } @@ -174,7 +201,7 @@ int main(int argc, char** argv) { } #ifdef SPLAT - if (!mc::mpi::rank()) { + if (!global_policy::id()) { //for (auto i=0u; i<m.cell_groups.size(); ++i) { m.cell_groups[0].splat("cell0.txt"); m.cell_groups[1].splat("cell1.txt"); @@ -182,10 +209,6 @@ int main(int argc, char** argv) { //} } #endif - -#ifdef WITH_MPI - mc::mpi::finalize(); -#endif } /////////////////////////////////////// @@ -295,25 +318,15 @@ void all_to_all_model(nest::mc::io::options& opt, model& m) { // function definitions /////////////////////////////////////// -void setup(int argc, char** argv) { -#ifdef WITH_MPI - mc::mpi::init(&argc, &argv); - +void setup() { // print banner - if (mc::mpi::rank()==0) { + if (!global_policy::id()) { std::cout << "====================\n"; std::cout << " starting miniapp\n"; std::cout << " - " << mc::threading::description() << " threading support\n"; - std::cout << " - MPI support\n"; + std::cout << " - communication policy: " << global_policy::name() << "\n"; std::cout << "====================\n"; } -#else - // print banner - std::cout << "====================\n"; - std::cout << " starting miniapp\n"; - std::cout << " - " << mc::threading::description() << " threading support\n"; - std::cout << "====================\n"; -#endif // setup global state for the mechanisms mc::mechanisms::setup_mechanism_helpers(); diff --git a/src/communication/mpi.hpp b/src/communication/mpi.hpp index 063e36ef3fc3ae44f952e2caeddf18663832cefd..5be71b6381bf8311bcab85d5396d64ba2c3eb26b 100644 --- a/src/communication/mpi.hpp +++ b/src/communication/mpi.hpp @@ -43,7 +43,7 @@ namespace mpi { template <> \ struct mpi_traits<T> { \ constexpr static size_t count() { return 1; } \ - constexpr static MPI_Datatype mpi_type() { return M; } \ + /* constexpr */ static MPI_Datatype mpi_type() { return M; } \ constexpr static bool is_mpi_native_type() { return true; } \ }; diff --git a/src/communication/mpi_global_policy.hpp b/src/communication/mpi_global_policy.hpp index 79135437e73d21f0a49ac91c6408a33e4b6d61ef..7409585447ae888eec2aa47ddc1fa0f137df7d6e 100644 --- a/src/communication/mpi_global_policy.hpp +++ b/src/communication/mpi_global_policy.hpp @@ -1,5 +1,9 @@ #pragma once +#ifndef WITH_MPI +#error "mpi_global_policy.hpp should only be compiled in a WITH_MPI build" +#endif + #include <type_traits> #include <vector> @@ -9,7 +13,6 @@ #include <communication/mpi.hpp> #include <algorithms.hpp> - namespace nest { namespace mc { namespace communication { @@ -17,22 +20,22 @@ namespace communication { struct mpi_global_policy { using id_type = uint32_t; - std::vector<spike<id_type>> const - gather_spikes(const std::vector<spike<id_type>>& local_spikes) { + std::vector<spike<id_type>> + static gather_spikes(const std::vector<spike<id_type>>& local_spikes) { return mpi::gather_all(local_spikes); } - int id() const { return mpi::rank(); } + static int id() { return mpi::rank(); } - int size() const { return mpi::size(); } + static int size() { return mpi::size(); } template <typename T> - T min(T value) const { + static T min(T value) { return nest::mc::mpi::reduce(value, MPI_MIN); } template <typename T> - T max(T value) const { + static T max(T value) { return nest::mc::mpi::reduce(value, MPI_MAX); } @@ -40,11 +43,24 @@ struct mpi_global_policy { typename T, typename = typename std::enable_if<std::is_integral<T>::value> > - std::vector<T> make_map(T local) { + static std::vector<T> make_map(T local) { return algorithms::make_index(mpi::gather_all(local)); } + + static void setup(int& argc, char**& argv) { + nest::mc::mpi::init(&argc, &argv); + } + + static void teardown() { + nest::mc::mpi::finalize(); + } + + static const char* name() { return "MPI"; } + +private: }; } // namespace communication } // namespace mc } // namespace nest + diff --git a/src/communication/serial_global_policy.hpp b/src/communication/serial_global_policy.hpp index eaa6c4b11168e4e11f0c56df79631548ccb02444..a1c7fee9671cb79768e255a710e0082386302e8a 100644 --- a/src/communication/serial_global_policy.hpp +++ b/src/communication/serial_global_policy.hpp @@ -15,7 +15,7 @@ struct serial_global_policy { using id_type = uint32_t; std::vector<spike<id_type>> const - gather_spikes(const std::vector<spike<id_type>>& local_spikes) { + static gather_spikes(const std::vector<spike<id_type>>& local_spikes) { return local_spikes; } @@ -28,12 +28,12 @@ struct serial_global_policy { } template <typename T> - T min(T value) const { + static T min(T value) { return value; } template <typename T> - T max(T value) const { + static T max(T value) { return value; } @@ -41,9 +41,13 @@ struct serial_global_policy { typename T, typename = typename std::enable_if<std::is_integral<T>::value> > - std::vector<T> make_map(T local) { + static std::vector<T> make_map(T local) { return {T(0), local}; } + + static void setup(int& argc, char**& argv) {} + static void teardown() {} + static const char* name() { return "serial"; } }; } // namespace communication