diff --git a/python/context.cpp b/python/context.cpp index be043ec02fdf23b4d20c8d8c8ff3415571bb8b57..4e7dde4f49c3e3b9d2f107176f799946039fb958 100644 --- a/python/context.cpp +++ b/python/context.cpp @@ -4,9 +4,11 @@ #include <arbor/context.hpp> #include <arbor/version.hpp> +#include <arbor/util/optional.hpp> #include <pybind11/pybind11.h> + #include "context.hpp" #include "conversion.hpp" #include "error.hpp" @@ -24,9 +26,7 @@ auto is_nonneg_int = [](int n){ return n>=0; }; // A Python shim that holds the information that describes an arb::proc_allocation. struct proc_allocation_shim { - using opt_int = arb::util::optional<int>; - - opt_int gpu_id = {}; + arb::util::optional<int> gpu_id = {}; int num_threads = 1; proc_allocation_shim(): proc_allocation_shim(1, pybind11::none()) {} @@ -46,9 +46,9 @@ struct proc_allocation_shim { num_threads = threads; }; - const opt_int get_gpu_id() const { return gpu_id; } - const int get_num_threads() const { return num_threads; } - const bool has_gpu() const { return (gpu_id.value_or(-1) >= 0) ? true : false; } + arb::util::optional<int> get_gpu_id() const { return gpu_id; } + int get_num_threads() const { return num_threads; } + bool has_gpu() const { return bool(gpu_id); } // helper function to use arb::make_context(arb::proc_allocation) arb::proc_allocation allocation() const { @@ -75,20 +75,9 @@ std::string proc_alloc_string(const proc_allocation_shim& a) { return s.str(); } -// helper functions to wrap arb::make_context -arb::context make_context_shim(const proc_allocation_shim& Alloc) { - return arb::make_context(Alloc.allocation()); -} - -template <typename Comm> -arb::context make_context_shim(const proc_allocation_shim& Alloc, Comm comm) { - return arb::make_context(Alloc.allocation(), comm); -} - void register_contexts(pybind11::module& m) { using namespace std::string_literals; using namespace pybind11::literals; - using opt_int = arb::util::optional<int>; // proc_allocation pybind11::class_<proc_allocation_shim> proc_allocation(m, "proc_allocation", @@ -104,7 +93,7 @@ void register_contexts(pybind11::module& m) { .def_property("gpu_id", &proc_allocation_shim::get_gpu_id, &proc_allocation_shim::set_gpu_id, "The identifier of the GPU to use.\n" "Corresponds to the integer index used to identify GPUs in CUDA API calls.") - .def_property_readonly("has_gpu", &proc_allocation_shim::has_gpu, + .def_property_readonly("has_gpu", &proc_allocation_shim::has_gpu, "Whether a GPU is being used (True/False).") .def("__str__", &proc_alloc_string) .def("__repr__", &proc_alloc_string); @@ -118,48 +107,53 @@ void register_contexts(pybind11::module& m) { ) .def(pybind11::init( [](const proc_allocation_shim& alloc){ - return context_shim(make_context_shim(alloc)); - }), - "alloc"_a, - "Construct a local context with argument:\n" - " alloc: The computational resources to be used for the simulation.\n") + return context_shim(arb::make_context(alloc.allocation())); }), + "alloc"_a, + "Construct a local context with argument:\n" + " alloc: The computational resources to be used for the simulation.\n") #ifdef ARB_MPI_ENABLED .def(pybind11::init( - [](const proc_allocation_shim& alloc, pybind11::object mpi){ - if (mpi.is_none()) { - return context_shim(make_context_shim(alloc)); + [](proc_allocation_shim alloc, pybind11::object mpi){ + const char* mpi_err_str = "mpi must be None, or an MPI communicator"; + auto a = alloc.allocation(); // unwrap the C++ resource_allocation description + if (can_convert_to_mpi_comm(mpi)) { + return context_shim(arb::make_context(a, convert_to_mpi_comm(mpi))); + } + if (auto c = py2optional<mpi_comm_shim>(mpi, mpi_err_str)) { + return context_shim(arb::make_context(a, c->comm)); } - auto c = py2optional<mpi_comm_shim>(mpi, - "mpi must be None, or an MPI communicator."); - auto comm = c.value_or(MPI_COMM_WORLD).comm; - return context_shim(make_context_shim(alloc, comm)); + return context_shim(arb::make_context(a)); }), - "alloc"_a, "mpi"_a=pybind11::none(), - "Construct a distributed context with arguments:\n" - " alloc: The computational resources to be used for the simulation.\n" - " mpi: The MPI communicator (default None).\n") + "alloc"_a, "mpi"_a=pybind11::none(), + "Construct a distributed context with arguments:\n" + " alloc: The computational resources to be used for the simulation.\n" + " mpi: The MPI communicator (default None).\n") .def(pybind11::init( [](int threads, pybind11::object gpu, pybind11::object mpi){ - opt_int gpu_id = py2optional<int>(gpu, - "gpu id must be None, or a non-negative integer.", is_nonneg_int); + const char* gpu_err_str = "gpu_id must be None, or a non-negative integer"; + const char* mpi_err_str = "mpi must be None, or an MPI communicator"; + + auto gpu_id = py2optional<int>(gpu, gpu_err_str, is_nonneg_int); arb::proc_allocation alloc(threads, gpu_id.value_or(-1)); - if (mpi.is_none()) { - return context_shim(arb::make_context(alloc)); + + if (can_convert_to_mpi_comm(mpi)) { + return context_shim(arb::make_context(alloc, convert_to_mpi_comm(mpi))); + } + if (auto c = py2optional<mpi_comm_shim>(mpi, mpi_err_str)) { + return context_shim(arb::make_context(alloc, c->comm)); } - auto c = py2optional<mpi_comm_shim>(mpi, - "mpi must be None, or an MPI communicator."); - auto comm = c.value_or(MPI_COMM_WORLD).comm; - return context_shim(arb::make_context(alloc, comm)); + + return context_shim(arb::make_context(alloc)); }), - "threads"_a=1, "gpu_id"_a=pybind11::none(), "mpi"_a=pybind11::none(), - "Construct a distributed context with arguments:\n" - " threads: The number of threads available locally for execution (default 1).\n" - " gpu_id: The index of the GPU to use (default None).\n" - " mpi: The MPI communicator (default None).\n") + "threads"_a=1, "gpu_id"_a=pybind11::none(), "mpi"_a=pybind11::none(), + "Construct a distributed context with arguments:\n" + " threads: The number of threads available locally for execution (default 1).\n" + " gpu_id: The index of the GPU to use (default None).\n" + " mpi: The MPI communicator (default None).\n") #else .def(pybind11::init( [](int threads, pybind11::object gpu){ - opt_int gpu_id = py2optional<int>(gpu, + auto gpu_id = py2optional<int>(gpu, "gpu id must be None, or a non-negative integer.", is_nonneg_int); return context_shim(arb::make_context(arb::proc_allocation(threads, gpu_id.value_or(-1)))); }), diff --git a/python/mpi.cpp b/python/mpi.cpp index 0fd0dbadde66b9544288f211a6b5131f8227d9a5..16a2b36260336b5060ebb0dbd43a5977ca9cc651 100644 --- a/python/mpi.cpp +++ b/python/mpi.cpp @@ -20,19 +20,35 @@ namespace pyarb { #ifdef ARB_MPI_ENABLED -#ifdef ARB_WITH_MPI4PY -mpi_comm_shim comm_from_mpi4py(pybind11::object& o) { +// Convert a Python object an MPI Communicator. +// Used to construct mpi_comm_shim from arbitrary Python types. +// Currently only supports mpi4py communicators, but could be extended to +// other types. +MPI_Comm convert_to_mpi_comm(pybind11::object o) { +#ifdef ARB_WITH_MPI4PY import_mpi4py(); - - // If object is already a mpi4py communicator, return if (PyObject_TypeCheck(o.ptr(), &PyMPIComm_Type)) { - return mpi_comm_shim(*PyMPIComm_Get(o.ptr())); + return *PyMPIComm_Get(o.ptr()); } - throw arb::mpi_error(MPI_ERR_OTHER, "The argument is not an mpi4py communicator"); +#endif + throw arb::mpi_error(MPI_ERR_OTHER, "Unable to convert to an MPI Communicatior."); +} + +mpi_comm_shim::mpi_comm_shim(pybind11::object o) { + comm = convert_to_mpi_comm(o); } +// Test if a Python object can be converted to an mpi_comm_shim. +bool can_convert_to_mpi_comm(pybind11::object o) { +#ifdef ARB_WITH_MPI4PY + import_mpi4py(); + if (PyObject_TypeCheck(o.ptr(), &PyMPIComm_Type)) { + return true; + } #endif + return false; +} // Some helper functions for initializing and finalizing MPI. // Arbor requires at least MPI_THREAD_SERIALIZED, because the communication task @@ -83,6 +99,7 @@ void register_mpi(pybind11::module& m) { pybind11::class_<mpi_comm_shim> mpi_comm(m, "mpi_comm"); mpi_comm .def(pybind11::init<>()) + .def(pybind11::init([](pybind11::object o){return mpi_comm_shim(o);})) .def("__str__", &mpi_comm_string) .def("__repr__", &mpi_comm_string); @@ -90,10 +107,6 @@ void register_mpi(pybind11::module& m) { m.def("mpi_finalize", &mpi_finalize, "Finalize MPI (calls MPI_Finalize)"); m.def("mpi_is_initialized", &mpi_is_initialized, "Check if MPI is initialized."); m.def("mpi_is_finalized", &mpi_is_finalized, "Check if MPI is finalized."); - - #ifdef ARB_WITH_MPI4PY - m.def("mpi_comm_from_mpi4py", comm_from_mpi4py); - #endif } #endif } // namespace pyarb diff --git a/python/mpi.hpp b/python/mpi.hpp index 61ba646c34461742c760ba7418a6d78e1e40bd00..2a278e6fcb3539123c74a075a8ea9d49f461e7b0 100644 --- a/python/mpi.hpp +++ b/python/mpi.hpp @@ -13,8 +13,13 @@ struct mpi_comm_shim { mpi_comm_shim() = default; mpi_comm_shim(MPI_Comm c): comm(c) {} + + mpi_comm_shim(pybind11::object o); }; +bool can_convert_to_mpi_comm(pybind11::object o); +MPI_Comm convert_to_mpi_comm(pybind11::object o); + } // namespace pyarb #endif diff --git a/python/test/unit_distributed/runner.py b/python/test/unit_distributed/runner.py index 401071da8df325829e34692d4d4dc211049d84a6..ddb276082b8e7ac77db0804e4d4fb2200a2b31cf 100644 --- a/python/test/unit_distributed/runner.py +++ b/python/test/unit_distributed/runner.py @@ -54,7 +54,7 @@ if __name__ == "__main__": arb.mpi_init() if mpi4py_enabled: - comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD) + comm = arb.mpi_comm(mpi.COMM_WORLD) elif mpi_enabled: comm = arb.mpi_comm() diff --git a/python/test/unit_distributed/test_contexts_arbmpi.py b/python/test/unit_distributed/test_contexts_arbmpi.py index 05183078882f149ab7b6f73b9d7edd5bc5b93a81..947eb45858100c06caec7b68048a638421e8df33 100644 --- a/python/test/unit_distributed/test_contexts_arbmpi.py +++ b/python/test/unit_distributed/test_contexts_arbmpi.py @@ -67,10 +67,10 @@ class Contexts_arbmpi(unittest.TestCase): alloc = arb.proc_allocation() with self.assertRaisesRegex(RuntimeError, - "mpi must be None, or an MPI communicator."): + "mpi must be None, or an MPI communicator"): arb.context(mpi='MPI_COMM_WORLD') with self.assertRaisesRegex(RuntimeError, - "mpi must be None, or an MPI communicator."): + "mpi must be None, or an MPI communicator"): arb.context(alloc, mpi=0) def test_finalized_arbmpi(self): diff --git a/python/test/unit_distributed/test_contexts_mpi4py.py b/python/test/unit_distributed/test_contexts_mpi4py.py index 44af8cb8b19d204efd5620a2b7d18f0dc83bba75..c93171751ae3f7b310838b72bbbc6bed2bb8f9bd 100644 --- a/python/test/unit_distributed/test_contexts_mpi4py.py +++ b/python/test/unit_distributed/test_contexts_mpi4py.py @@ -34,20 +34,20 @@ class Contexts_mpi4py(unittest.TestCase): self.assertTrue(mpi.Is_initialized()) def test_communicator_mpi4py(self): - comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD) + comm = arb.mpi_comm(mpi.COMM_WORLD) # test that set communicator is MPI_COMM_WORLD self.assertEqual(str(comm), '<mpi_comm: MPI_COMM_WORLD>') def test_context_mpi4py(self): - comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD) + comm = arb.mpi_comm(mpi.COMM_WORLD) # test context with mpi ctx = arb.context(mpi=comm) self.assertTrue(ctx.has_mpi) def test_context_allocation_mpi4py(self): - comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD) + comm = arb.mpi_comm(mpi.COMM_WORLD) # test context with alloc and mpi alloc = arb.proc_allocation() @@ -60,10 +60,10 @@ class Contexts_mpi4py(unittest.TestCase): alloc = arb.proc_allocation() with self.assertRaisesRegex(RuntimeError, - "mpi must be None, or an MPI communicator."): + "mpi must be None, or an MPI communicator"): arb.context(mpi='MPI_COMM_WORLD') with self.assertRaisesRegex(RuntimeError, - "mpi must be None, or an MPI communicator."): + "mpi must be None, or an MPI communicator"): arb.context(alloc, mpi=0) def test_finalized_mpi4py(self):