diff --git a/python/context.cpp b/python/context.cpp
index be043ec02fdf23b4d20c8d8c8ff3415571bb8b57..4e7dde4f49c3e3b9d2f107176f799946039fb958 100644
--- a/python/context.cpp
+++ b/python/context.cpp
@@ -4,9 +4,11 @@
 
 #include <arbor/context.hpp>
 #include <arbor/version.hpp>
+#include <arbor/util/optional.hpp>
 
 #include <pybind11/pybind11.h>
 
+
 #include "context.hpp"
 #include "conversion.hpp"
 #include "error.hpp"
@@ -24,9 +26,7 @@ auto is_nonneg_int = [](int n){ return n>=0; };
 
 // A Python shim that holds the information that describes an arb::proc_allocation.
 struct proc_allocation_shim {
-    using opt_int = arb::util::optional<int>;
-
-    opt_int gpu_id = {};
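+    // An unset gpu_id indicates that no GPU will be used.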
+    arb::util::optional<int> gpu_id = {};
     int num_threads = 1;
 
     proc_allocation_shim(): proc_allocation_shim(1, pybind11::none()) {}
@@ -46,9 +46,9 @@ struct proc_allocation_shim {
         num_threads = threads;
     };
 
-    const opt_int get_gpu_id()      const { return gpu_id; }
-    const int     get_num_threads() const { return num_threads; }
-    const bool    has_gpu()         const { return (gpu_id.value_or(-1) >= 0) ? true : false; }
+    arb::util::optional<int> get_gpu_id() const { return gpu_id; }
+    int get_num_threads() const { return num_threads; }
+    bool has_gpu() const { return bool(gpu_id); }
 
     // helper function to use arb::make_context(arb::proc_allocation)
     arb::proc_allocation allocation() const {
@@ -75,20 +75,9 @@ std::string proc_alloc_string(const proc_allocation_shim& a) {
     return s.str();
 }
 
-// helper functions to wrap arb::make_context
-arb::context make_context_shim(const proc_allocation_shim& Alloc) {
-    return arb::make_context(Alloc.allocation());
-}
-
-template <typename Comm>
-arb::context make_context_shim(const proc_allocation_shim& Alloc, Comm comm) {
-    return arb::make_context(Alloc.allocation(), comm);
-}
-
 void register_contexts(pybind11::module& m) {
     using namespace std::string_literals;
     using namespace pybind11::literals;
-    using opt_int = arb::util::optional<int>;
 
     // proc_allocation
     pybind11::class_<proc_allocation_shim> proc_allocation(m, "proc_allocation",
@@ -104,7 +93,7 @@ void register_contexts(pybind11::module& m) {
         .def_property("gpu_id", &proc_allocation_shim::get_gpu_id, &proc_allocation_shim::set_gpu_id,
             "The identifier of the GPU to use.\n"
             "Corresponds to the integer index used to identify GPUs in CUDA API calls.")
-	.def_property_readonly("has_gpu", &proc_allocation_shim::has_gpu,
+        .def_property_readonly("has_gpu", &proc_allocation_shim::has_gpu,
             "Whether a GPU is being used (True/False).")
         .def("__str__", &proc_alloc_string)
         .def("__repr__", &proc_alloc_string);
@@ -118,48 +107,53 @@ void register_contexts(pybind11::module& m) {
             )
         .def(pybind11::init(
             [](const proc_allocation_shim& alloc){
-                return context_shim(make_context_shim(alloc));
-            }),
-             "alloc"_a,
-             "Construct a local context with argument:\n"
-             "  alloc:   The computational resources to be used for the simulation.\n")
+                return context_shim(arb::make_context(alloc.allocation()));
+            }),
+            "alloc"_a,
+            "Construct a local context with argument:\n"
+            "  alloc:   The computational resources to be used for the simulation.\n")
 #ifdef ARB_MPI_ENABLED
         .def(pybind11::init(
-            [](const proc_allocation_shim& alloc, pybind11::object mpi){
-                if (mpi.is_none()) {
-                    return context_shim(make_context_shim(alloc));
+            [](proc_allocation_shim alloc, pybind11::object mpi){
+                const char* mpi_err_str = "mpi must be None, or an MPI communicator";
+                auto a = alloc.allocation(); // unwrap the C++ arb::proc_allocation description
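+                // Dispatch on the mpi argument: a raw mpi4py communicator,
+                // an arb.mpi_comm wrapper, or None (local context).
+                // Any other type is rejected by py2optional below.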
+                if (can_convert_to_mpi_comm(mpi)) {
+                    return context_shim(arb::make_context(a, convert_to_mpi_comm(mpi)));
+                }
+                if (auto c = py2optional<mpi_comm_shim>(mpi, mpi_err_str)) {
+                    return context_shim(arb::make_context(a, c->comm));
                 }
-                auto c = py2optional<mpi_comm_shim>(mpi,
-                        "mpi must be None, or an MPI communicator.");
-                auto comm = c.value_or(MPI_COMM_WORLD).comm;
-                return context_shim(make_context_shim(alloc, comm));
+                return context_shim(arb::make_context(a));
             }),
-             "alloc"_a, "mpi"_a=pybind11::none(),
-             "Construct a distributed context with arguments:\n"
-             "  alloc:   The computational resources to be used for the simulation.\n"
-             "  mpi:     The MPI communicator (default None).\n")
+            "alloc"_a, "mpi"_a=pybind11::none(),
+            "Construct a distributed context with arguments:\n"
+            "  alloc:   The computational resources to be used for the simulation.\n"
+            "  mpi:     The MPI communicator (default None).\n")
         .def(pybind11::init(
             [](int threads, pybind11::object gpu, pybind11::object mpi){
-                opt_int gpu_id = py2optional<int>(gpu,
-                        "gpu id must be None, or a non-negative integer.", is_nonneg_int);
+                const char* gpu_err_str = "gpu_id must be None, or a non-negative integer";
+                const char* mpi_err_str = "mpi must be None, or an MPI communicator";
+
+                auto gpu_id = py2optional<int>(gpu, gpu_err_str, is_nonneg_int);
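+                // A gpu_id of -1 tells arb::proc_allocation that no GPU is used.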
                 arb::proc_allocation alloc(threads, gpu_id.value_or(-1));
-                if (mpi.is_none()) {
-                    return context_shim(arb::make_context(alloc));
+
+                if (can_convert_to_mpi_comm(mpi)) {
+                    return context_shim(arb::make_context(alloc, convert_to_mpi_comm(mpi)));
+                }
+                if (auto c = py2optional<mpi_comm_shim>(mpi, mpi_err_str)) {
+                    return context_shim(arb::make_context(alloc, c->comm));
                 }
-                auto c = py2optional<mpi_comm_shim>(mpi,
-                        "mpi must be None, or an MPI communicator.");
-                auto comm = c.value_or(MPI_COMM_WORLD).comm;
-                return context_shim(arb::make_context(alloc, comm));
+
+                return context_shim(arb::make_context(alloc));
             }),
-             "threads"_a=1, "gpu_id"_a=pybind11::none(), "mpi"_a=pybind11::none(),
-             "Construct a distributed context with arguments:\n"
-             "  threads: The number of threads available locally for execution (default 1).\n"
-             "  gpu_id:  The index of the GPU to use (default None).\n"
-             "  mpi:     The MPI communicator (default None).\n")
+            "threads"_a=1, "gpu_id"_a=pybind11::none(), "mpi"_a=pybind11::none(),
+            "Construct a distributed context with arguments:\n"
+            "  threads: The number of threads available locally for execution (default 1).\n"
+            "  gpu_id:  The index of the GPU to use (default None).\n"
+            "  mpi:     The MPI communicator (default None).\n")
 #else
         .def(pybind11::init(
             [](int threads, pybind11::object gpu){
-                opt_int gpu_id = py2optional<int>(gpu,
+                auto gpu_id = py2optional<int>(gpu,
                         "gpu id must be None, or a non-negative integer.", is_nonneg_int);
                 return context_shim(arb::make_context(arb::proc_allocation(threads, gpu_id.value_or(-1))));
             }),
diff --git a/python/mpi.cpp b/python/mpi.cpp
index 0fd0dbadde66b9544288f211a6b5131f8227d9a5..16a2b36260336b5060ebb0dbd43a5977ca9cc651 100644
--- a/python/mpi.cpp
+++ b/python/mpi.cpp
@@ -20,19 +20,35 @@
 namespace pyarb {
 
 #ifdef ARB_MPI_ENABLED
-#ifdef ARB_WITH_MPI4PY
 
-mpi_comm_shim comm_from_mpi4py(pybind11::object& o) {
+// Convert a Python object to an MPI communicator.
+// Used to construct mpi_comm_shim from arbitrary Python types.
+// Currently only supports mpi4py communicators, but could be extended to
+// other types.
+MPI_Comm convert_to_mpi_comm(pybind11::object o) {
+#ifdef ARB_WITH_MPI4PY
     import_mpi4py();
-
-    // If object is already a mpi4py communicator, return
     if (PyObject_TypeCheck(o.ptr(), &PyMPIComm_Type)) {
-        return mpi_comm_shim(*PyMPIComm_Get(o.ptr()));
+        return *PyMPIComm_Get(o.ptr());
     }
-    throw arb::mpi_error(MPI_ERR_OTHER, "The argument is not an mpi4py communicator");
+#endif
+    throw arb::mpi_error(MPI_ERR_OTHER, "Unable to convert to an MPI communicator.");
+}
+
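+// Construct an mpi_comm_shim from an arbitrary Python object; throws if the
+// object can not be converted to an MPI communicator.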
+mpi_comm_shim::mpi_comm_shim(pybind11::object o) {
+    comm = convert_to_mpi_comm(o);
 }
 
+// Test if a Python object can be converted to an MPI communicator.
+bool can_convert_to_mpi_comm(pybind11::object o) {
+#ifdef ARB_WITH_MPI4PY
+    import_mpi4py();
+    if (PyObject_TypeCheck(o.ptr(), &PyMPIComm_Type)) {
+        return true;
+    }
 #endif
+    return false;
+}
 
 // Some helper functions for initializing and finalizing MPI.
 // Arbor requires at least MPI_THREAD_SERIALIZED, because the communication task
@@ -83,6 +99,7 @@ void register_mpi(pybind11::module& m) {
     pybind11::class_<mpi_comm_shim> mpi_comm(m, "mpi_comm");
     mpi_comm
         .def(pybind11::init<>())
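+        // Allow construction from an arbitrary Python object that can be
+        // converted to an MPI communicator, e.g. an mpi4py communicator.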
+        .def(pybind11::init([](pybind11::object o){return mpi_comm_shim(o);}))
         .def("__str__", &mpi_comm_string)
         .def("__repr__", &mpi_comm_string);
 
@@ -90,10 +107,6 @@ void register_mpi(pybind11::module& m) {
     m.def("mpi_finalize", &mpi_finalize, "Finalize MPI (calls MPI_Finalize)");
     m.def("mpi_is_initialized", &mpi_is_initialized, "Check if MPI is initialized.");
     m.def("mpi_is_finalized", &mpi_is_finalized, "Check if MPI is finalized.");
-
-    #ifdef ARB_WITH_MPI4PY
-    m.def("mpi_comm_from_mpi4py", comm_from_mpi4py);
-    #endif
 }
 #endif
 } // namespace pyarb
diff --git a/python/mpi.hpp b/python/mpi.hpp
index 61ba646c34461742c760ba7418a6d78e1e40bd00..2a278e6fcb3539123c74a075a8ea9d49f461e7b0 100644
--- a/python/mpi.hpp
+++ b/python/mpi.hpp
@@ -13,8 +13,13 @@ struct mpi_comm_shim {
 
     mpi_comm_shim() = default;
     mpi_comm_shim(MPI_Comm c): comm(c) {}
+
+    mpi_comm_shim(pybind11::object o);
 };
 
+bool can_convert_to_mpi_comm(pybind11::object o);
+MPI_Comm convert_to_mpi_comm(pybind11::object o);
+
 } // namespace pyarb
 #endif
 
diff --git a/python/test/unit_distributed/runner.py b/python/test/unit_distributed/runner.py
index 401071da8df325829e34692d4d4dc211049d84a6..ddb276082b8e7ac77db0804e4d4fb2200a2b31cf 100644
--- a/python/test/unit_distributed/runner.py
+++ b/python/test/unit_distributed/runner.py
@@ -54,7 +54,7 @@ if __name__ == "__main__":
         arb.mpi_init()
 
     if mpi4py_enabled:
-        comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD)
+        comm = arb.mpi_comm(mpi.COMM_WORLD)
     elif mpi_enabled:
         comm = arb.mpi_comm()
 
diff --git a/python/test/unit_distributed/test_contexts_arbmpi.py b/python/test/unit_distributed/test_contexts_arbmpi.py
index 05183078882f149ab7b6f73b9d7edd5bc5b93a81..947eb45858100c06caec7b68048a638421e8df33 100644
--- a/python/test/unit_distributed/test_contexts_arbmpi.py
+++ b/python/test/unit_distributed/test_contexts_arbmpi.py
@@ -67,10 +67,10 @@ class Contexts_arbmpi(unittest.TestCase):
         alloc = arb.proc_allocation()
 
         with self.assertRaisesRegex(RuntimeError,
-            "mpi must be None, or an MPI communicator."):
+            "mpi must be None, or an MPI communicator"):
             arb.context(mpi='MPI_COMM_WORLD')
         with self.assertRaisesRegex(RuntimeError,
-            "mpi must be None, or an MPI communicator."):
+            "mpi must be None, or an MPI communicator"):
             arb.context(alloc, mpi=0)
 
     def test_finalized_arbmpi(self):
diff --git a/python/test/unit_distributed/test_contexts_mpi4py.py b/python/test/unit_distributed/test_contexts_mpi4py.py
index 44af8cb8b19d204efd5620a2b7d18f0dc83bba75..c93171751ae3f7b310838b72bbbc6bed2bb8f9bd 100644
--- a/python/test/unit_distributed/test_contexts_mpi4py.py
+++ b/python/test/unit_distributed/test_contexts_mpi4py.py
@@ -34,20 +34,20 @@ class Contexts_mpi4py(unittest.TestCase):
         self.assertTrue(mpi.Is_initialized())
 
     def test_communicator_mpi4py(self):
-        comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD)
+        comm = arb.mpi_comm(mpi.COMM_WORLD)
 
         # test that set communicator is MPI_COMM_WORLD
         self.assertEqual(str(comm), '<mpi_comm: MPI_COMM_WORLD>')
 
     def test_context_mpi4py(self):
-        comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD)
+        comm = arb.mpi_comm(mpi.COMM_WORLD)
 
         # test context with mpi
         ctx = arb.context(mpi=comm)
         self.assertTrue(ctx.has_mpi)
 
     def test_context_allocation_mpi4py(self):
-        comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD)
+        comm = arb.mpi_comm(mpi.COMM_WORLD)
 
         # test context with alloc and mpi
         alloc = arb.proc_allocation()
@@ -60,10 +60,10 @@ class Contexts_mpi4py(unittest.TestCase):
         alloc = arb.proc_allocation()
 
         with self.assertRaisesRegex(RuntimeError,
-            "mpi must be None, or an MPI communicator."):
+            "mpi must be None, or an MPI communicator"):
             arb.context(mpi='MPI_COMM_WORLD')
         with self.assertRaisesRegex(RuntimeError,
-            "mpi must be None, or an MPI communicator."):
+            "mpi must be None, or an MPI communicator"):
             arb.context(alloc, mpi=0)
 
     def test_finalized_mpi4py(self):