diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp index 98c0e21c25579a2048fb8c599935012622eb77ef..e3ad9b92cbc8595aaa66a62f1ff290d9b4118071 100644 --- a/miniapp/miniapp.cpp +++ b/miniapp/miniapp.cpp @@ -6,11 +6,10 @@ #include <mechanism_interface.hpp> #include "io.hpp" - -#include <threading/threading.hpp> -#include <profiling/profiler.hpp> -#include <communication/communicator.hpp> -#include <communication/serial_global_policy.hpp> +#include "threading/threading.hpp" +#include "profiling/profiler.hpp" +#include "communication/communicator.hpp" +#include "communication/global_policy.hpp" using namespace nest; @@ -19,14 +18,10 @@ using index_type = int; using id_type = uint32_t; using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>; using cell_group = mc::cell_group<numeric_cell>; -#ifdef WITH_MPI -#include <communication/mpi_global_policy.hpp> -using communicator_type = - mc::communication::communicator<mc::communication::mpi_global_policy>; -#else + +using global_policy = nest::mc::communication::global_policy; using communicator_type = - mc::communication::communicator<mc::communication::serial_global_policy>; -#endif + mc::communication::communicator<global_policy>; struct model { communicator_type communicator; @@ -118,7 +113,7 @@ namespace synapses { mc::cell make_cell(int compartments_per_segment, int num_synapses); /// do basic setup (initialize global state, print banner, etc) -void setup(int argc, char** argv); +void setup(); /// helper function for initializing cells cell_group make_lowered_cell(int cell_index, const mc::cell& c); @@ -127,26 +122,24 @@ cell_group make_lowered_cell(int cell_index, const mc::cell& c); void ring_model(nest::mc::io::options& opt, model& m); void all_to_all_model(nest::mc::io::options& opt, model& m); + /////////////////////////////////////// // main /////////////////////////////////////// int main(int argc, char** argv) { + nest::mc::communication::global_policy_guard global_guard(argc, argv); - setup(argc, argv); + setup(); // read parameters mc::io::options opt; try { opt = mc::io::read_options(""); - #ifdef WITH_MPI - if (mc::mpi::rank()==0) { + if (!global_policy::id()) { std::cout << opt << "\n"; } - #else - std::cout << opt << "\n"; - #endif } - catch (std::exception e) { + catch (std::exception& e) { std::cerr << e.what() << std::endl; exit(1); } @@ -175,7 +168,7 @@ int main(int argc, char** argv) { } #ifdef SPLAT - if (!mc::mpi::rank()) { + if (!global_policy::id()) { //for (auto i=0u; i<m.cell_groups.size(); ++i) { m.cell_groups[0].splat("cell0.txt"); m.cell_groups[1].splat("cell1.txt"); @@ -183,10 +176,6 @@ int main(int argc, char** argv) { //} } #endif - -#ifdef WITH_MPI - mc::mpi::finalize(); -#endif } /////////////////////////////////////// @@ -291,25 +280,15 @@ void all_to_all_model(nest::mc::io::options& opt, model& m) { // function definitions /////////////////////////////////////// -void setup(int argc, char** argv) { -#ifdef WITH_MPI - mc::mpi::init(&argc, &argv); - +void setup() { // print banner - if (mc::mpi::rank()==0) { + if (!global_policy::id()) { std::cout << "====================\n"; std::cout << " starting miniapp\n"; std::cout << " - " << mc::threading::description() << " threading support\n"; - std::cout << " - MPI support\n"; + std::cout << " - communication policy: " << global_policy::name() << "\n"; std::cout << "====================\n"; } -#else - // print banner - std::cout << "====================\n"; - std::cout << " starting miniapp\n"; - std::cout << " - " << mc::threading::description() << " threading support\n"; - std::cout << "====================\n"; -#endif // setup global state for the mechanisms mc::mechanisms::setup_mechanism_helpers(); diff --git a/src/communication/global_policy.hpp b/src/communication/global_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b771e8ed24d5d022b6a23cc7d77171ca477a7dbf --- /dev/null +++ b/src/communication/global_policy.hpp @@ -0,0 +1,42 @@ +#pragma once + +#ifdef WITH_MPI + #include "communication/mpi_global_policy.hpp" +#else + #include "communication/serial_global_policy.hpp" +#endif + +namespace nest { +namespace mc { +namespace communication { + +#ifdef WITH_MPI +using global_policy = nest::mc::communication::mpi_global_policy; +#else +using global_policy = nest::mc::communication::serial_global_policy; +#endif + +template <typename Policy> +struct policy_guard { + using policy_type = Policy; + + policy_guard(int argc, char**& argv) { + policy_type::setup(argc, argv); + } + + policy_guard() = delete; + policy_guard(policy_guard&&) = delete; + policy_guard(const policy_guard&) = delete; + policy_guard& operator=(policy_guard&&) = delete; + policy_guard& operator=(const policy_guard&) = delete; + + ~policy_guard() { + Policy::teardown(); + } +}; + +using global_policy_guard = policy_guard<global_policy>; + +} // namespace communication +} // namespace mc +} // namespace nest diff --git a/src/communication/mpi.hpp b/src/communication/mpi.hpp index 063e36ef3fc3ae44f952e2caeddf18663832cefd..5be71b6381bf8311bcab85d5396d64ba2c3eb26b 100644 --- a/src/communication/mpi.hpp +++ b/src/communication/mpi.hpp @@ -43,7 +43,7 @@ namespace mpi { template <> \ struct mpi_traits<T> { \ constexpr static size_t count() { return 1; } \ - constexpr static MPI_Datatype mpi_type() { return M; } \ + /* constexpr */ static MPI_Datatype mpi_type() { return M; } \ constexpr static bool is_mpi_native_type() { return true; } \ }; diff --git a/src/communication/mpi_global_policy.hpp b/src/communication/mpi_global_policy.hpp index 79135437e73d21f0a49ac91c6408a33e4b6d61ef..7409585447ae888eec2aa47ddc1fa0f137df7d6e 100644 --- a/src/communication/mpi_global_policy.hpp +++ b/src/communication/mpi_global_policy.hpp @@ -1,5 +1,9 @@ #pragma once +#ifndef WITH_MPI +#error "mpi_global_policy.hpp should only be compiled in a WITH_MPI build" +#endif + #include <type_traits> #include <vector> @@ -9,7 +13,6 @@ #include <communication/mpi.hpp> #include <algorithms.hpp> - namespace nest { namespace mc { namespace communication { @@ -17,22 +20,22 @@ namespace communication { struct mpi_global_policy { using id_type = uint32_t; - std::vector<spike<id_type>> const - gather_spikes(const std::vector<spike<id_type>>& local_spikes) { + std::vector<spike<id_type>> + static gather_spikes(const std::vector<spike<id_type>>& local_spikes) { return mpi::gather_all(local_spikes); } - int id() const { return mpi::rank(); } + static int id() { return mpi::rank(); } - int size() const { return mpi::size(); } + static int size() { return mpi::size(); } template <typename T> - T min(T value) const { + static T min(T value) { return nest::mc::mpi::reduce(value, MPI_MIN); } template <typename T> - T max(T value) const { + static T max(T value) { return nest::mc::mpi::reduce(value, MPI_MAX); } @@ -40,11 +43,24 @@ struct mpi_global_policy { typename T, typename = typename std::enable_if<std::is_integral<T>::value> > - std::vector<T> make_map(T local) { + static std::vector<T> make_map(T local) { return algorithms::make_index(mpi::gather_all(local)); } + + static void setup(int& argc, char**& argv) { + nest::mc::mpi::init(&argc, &argv); + } + + static void teardown() { + nest::mc::mpi::finalize(); + } + + static const char* name() { return "MPI"; } + +private: }; } // namespace communication } // namespace mc } // namespace nest + diff --git a/src/communication/serial_global_policy.hpp b/src/communication/serial_global_policy.hpp index eaa6c4b11168e4e11f0c56df79631548ccb02444..a1c7fee9671cb79768e255a710e0082386302e8a 100644 --- a/src/communication/serial_global_policy.hpp +++ b/src/communication/serial_global_policy.hpp @@ -15,7 +15,7 @@ struct serial_global_policy { using id_type = uint32_t; std::vector<spike<id_type>> const - gather_spikes(const std::vector<spike<id_type>>& local_spikes) { + static gather_spikes(const std::vector<spike<id_type>>& local_spikes) { return local_spikes; } @@ -28,12 +28,12 @@ struct serial_global_policy { } template <typename T> - T min(T value) const { + static T min(T value) { return value; } template <typename T> - T max(T value) const { + static T max(T value) { return value; } @@ -41,9 +41,13 @@ struct serial_global_policy { typename T, typename = typename std::enable_if<std::is_integral<T>::value> > - std::vector<T> make_map(T local) { + static std::vector<T> make_map(T local) { return {T(0), local}; } + + static void setup(int& argc, char**& argv) {} + static void teardown() {} + static const char* name() { return "serial"; } }; } // namespace communication diff --git a/src/profiling/profiler.cpp b/src/profiling/profiler.cpp index e9765da76c0e732aa1a27a9ceb532b9fa7544d0c..c07351e296ba8873ed7064ca4b605d2e9bb0cb8d 100644 --- a/src/profiling/profiler.cpp +++ b/src/profiling/profiler.cpp @@ -1,8 +1,6 @@ #include "profiler.hpp" -#ifdef WITH_MPI -#include <communication/mpi.hpp> -#endif +#include <communication/global_policy.hpp> namespace nest { namespace mc { @@ -319,11 +317,9 @@ void profiler_output(double threshold) { p.scale(1./nthreads); -#ifdef WITH_MPI - bool print = nest::mc::mpi::rank()==0 ? true : false; -#else - bool print = true; -#endif + auto ncomms = communication::global_policy::size(); + auto comm_rank = communication::global_policy::id(); + bool print = comm_rank==0 ? true : false; if(print) { std::cout << " ---------------------------------------------------- \n"; std::cout << "| profiler |\n"; @@ -333,12 +329,10 @@ void profiler_output(double threshold) { line, sizeof(line), "%-18s%10.3f s\n", "wall time", float(wall_time)); std::cout << line; - #ifdef WITH_MPI std::snprintf( line, sizeof(line), "%-18s%10d\n", - "MPI ranks", int(nest::mc::mpi::size())); + "communicators", int(ncomms)); std::cout << line; - #endif std::snprintf( line, sizeof(line), "%-18s%10d\n", "threads", int(nthreads)); @@ -352,23 +346,16 @@ void profiler_output(double threshold) { std::cout << "\n\n"; } - nlohmann::json as_json = p.as_json(); + nlohmann::json as_json; as_json["wall time"] = wall_time; as_json["threads"] = nthreads; as_json["efficiency"] = efficiency; -#ifdef WITH_MPI - as_json["communicators"] = nest::mc::mpi::size(); - as_json["rank"] = nest::mc::mpi::rank(); -#else - as_json["communicators"] = 1; - as_json["rank"] = 0; -#endif + as_json["communicators"] = ncomms; + as_json["rank"] = comm_rank; + as_json["regions"] = p.as_json(); -#ifdef WITH_MPI - std::ofstream fid("profile_" + std::to_string(mpi::rank())); -#else - std::ofstream fid("profile"); -#endif + auto fname = std::string("profile_" + std::to_string(comm_rank)); + std::ofstream fid(fname); fid << std::setw(1) << as_json; }