From 0d6b365a9b3f803830b39d9b623f15a1346fef60 Mon Sep 17 00:00:00 2001
From: noraabiakar <nora.abiakar@gmail.com>
Date: Tue, 31 Jul 2018 12:46:49 +0200
Subject: [PATCH] Manage distributed_context using shared pointers (#555)

* Replace distributed_context with shared_ptr<distributed_context> in execution_context, and pass the shared pointer around instead of a raw pointer.
* Fix construction of mpi_context
* Remove num_threads() from arb and arb::threading. Modify mpi_context so that it also returns a shared_ptr. Initialize proc_allocation from the execution context to determine available resources.
* Rename the threading backend files and delete those that are no longer needed.
* Pass execution_context by const reference or value.
* Remove code duplication in task_system constructors.
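
A usage sketch of the resulting interface (mirroring the example updates
below; `rec` stands in for any concrete recipe):

    arb::execution_context context;        // default distributed context + thread pool
    #ifdef ARB_MPI_ENABLED
    context.distributed = arb::mpi_context(MPI_COMM_WORLD);  // swap in the MPI context
    #endif
    auto alloc  = arb::local_allocation(context);
    auto decomp = arb::partition_load_balance(rec, alloc, context);
    arb::simulation sim(rec, decomp, context);  // passed by value; handles are shared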
---
 arbor/CMakeLists.txt                          |   3 +-
 arbor/communication/communicator.hpp          |   8 +-
 arbor/communication/mpi_context.cpp           |   8 +-
 arbor/local_alloc.cpp                         |   7 +-
 arbor/partition_load_balance.cpp              |   6 +-
 arbor/profile/meter_manager.cpp               |   7 +-
 arbor/profile/profiler.cpp                    |   4 +-
 arbor/simulation.cpp                          |  13 +-
 arbor/threadinfo.cpp                          |  17 --
 arbor/threading/cthread.cpp                   | 131 ---------
 arbor/threading/cthread.hpp                   |   7 -
 arbor/threading/cthread_impl.hpp              | 243 -----------------
 arbor/threading/cthread_sort.hpp              |  24 --
 arbor/threading/thread_info.cpp               |  76 ++++++
 arbor/threading/thread_info.hpp               |  26 ++
 arbor/threading/threading.cpp                 | 173 +++++++-----
 arbor/threading/threading.hpp                 | 248 ++++++++++++++++--
 aux/include/aux/with_mpi.hpp                  |   4 +
 example/bench/bench.cpp                       |  15 +-
 example/brunel/brunel_miniapp.cpp             |  35 ++-
 example/generators/event_gen.cpp              |   6 +-
 example/miniapp/miniapp.cpp                   |  34 ++-
 include/arbor/distributed_context.hpp         |   4 +-
 include/arbor/domain_decomposition.hpp        |   3 +-
 include/arbor/execution_context.hpp           |  15 +-
 include/arbor/load_balance.hpp                |   2 +-
 include/arbor/profile/meter_manager.hpp       |  10 +-
 include/arbor/simulation.hpp                  |   2 +-
 include/arbor/threadinfo.hpp                  |  13 -
 test/ubench/task_system.cpp                   |  13 +-
 .../unit-distributed/distributed_listener.cpp |   4 +-
 .../unit-distributed/distributed_listener.hpp |   6 +-
 test/unit-distributed/test.cpp                |   9 +-
 test/unit-distributed/test_communicator.cpp   |  49 ++--
 .../test_domain_decomposition.cpp             |  15 +-
 test/unit/test_algorithms.cpp                 |  25 +-
 test/unit/test_domain_decomposition.cpp       |  10 +-
 test/unit/test_fvm_lowered.cpp                |   4 +-
 test/unit/test_lif_cell_group.cpp             |  14 +-
 test/unit/test_thread.cpp                     |  27 +-
 test/validation/validate_ball_and_stick.cpp   |  19 +-
 test/validation/validate_kinetic.cpp          |  10 +-
 test/validation/validate_soma.cpp             |   7 +-
 test/validation/validate_synapses.cpp         |  10 +-
 44 files changed, 633 insertions(+), 733 deletions(-)
 delete mode 100644 arbor/threadinfo.cpp
 delete mode 100644 arbor/threading/cthread.cpp
 delete mode 100644 arbor/threading/cthread.hpp
 delete mode 100644 arbor/threading/cthread_impl.hpp
 delete mode 100644 arbor/threading/cthread_sort.hpp
 create mode 100644 arbor/threading/thread_info.cpp
 create mode 100644 arbor/threading/thread_info.hpp
 delete mode 100644 include/arbor/threadinfo.hpp

diff --git a/arbor/CMakeLists.txt b/arbor/CMakeLists.txt
index ec612bbc..f3ff1c4d 100644
--- a/arbor/CMakeLists.txt
+++ b/arbor/CMakeLists.txt
@@ -39,9 +39,8 @@ set(arbor_sources
     spike_event_io.cpp
     spike_source_cell_group.cpp
     swcio.cpp
-    threadinfo.cpp
-    threading/cthread.cpp
     threading/threading.cpp
+    threading/thread_info.cpp
     thread_private_spike_store.cpp
     util/hostname.cpp
     util/unwind.cpp
diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp
index 70b79ad7..4a0777bf 100644
--- a/arbor/communication/communicator.hpp
+++ b/arbor/communication/communicator.hpp
@@ -44,10 +44,10 @@ public:
 
     explicit communicator(const recipe& rec,
                           const domain_decomposition& dom_dec,
-                          const execution_context* ctx)
+                          execution_context ctx)
     {
-        distributed_ = &ctx->distributed;
-        thread_pool_ = ctx->thread_pool;
+        distributed_ = ctx.distributed;
+        thread_pool_ = ctx.thread_pool;
 
         num_domains_ = distributed_->size();
         num_local_groups_ = dom_dec.groups.size();
@@ -261,7 +261,7 @@ private:
     std::vector<cell_size_type> index_divisions_;
     util::partition_view_type<std::vector<cell_size_type>> index_part_;
 
-    const distributed_context* distributed_;
+    distributed_context_handle distributed_;
     task_system_handle thread_pool_;
     std::uint64_t num_spikes_ = 0u;
 };
diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp
index 4161e3b8..b1134b9a 100644
--- a/arbor/communication/mpi_context.cpp
+++ b/arbor/communication/mpi_context.cpp
@@ -62,13 +62,13 @@ struct mpi_context_impl {
     }
 };
 
-distributed_context mpi_context() {
-    return mpi_context_impl(MPI_COMM_WORLD);
+std::shared_ptr<distributed_context> mpi_context() {
+    return std::make_shared<distributed_context>(mpi_context_impl(MPI_COMM_WORLD));
 }
 
 template <>
-distributed_context mpi_context(MPI_Comm comm) {
-    return mpi_context_impl(comm);
+std::shared_ptr<distributed_context> mpi_context(MPI_Comm comm) {
+    return std::make_shared<distributed_context>(mpi_context_impl(comm));
 }
 
 } // namespace arb
diff --git a/arbor/local_alloc.cpp b/arbor/local_alloc.cpp
index 92204a5f..dfe00d60 100644
--- a/arbor/local_alloc.cpp
+++ b/arbor/local_alloc.cpp
@@ -1,13 +1,14 @@
 #include <arbor/domain_decomposition.hpp>
-#include <arbor/threadinfo.hpp>
+#include <arbor/execution_context.hpp>
 
 #include "hardware/node_info.hpp"
+#include "threading/threading.hpp"
 
 namespace arb {
 
-proc_allocation local_allocation() {
+proc_allocation local_allocation(const execution_context& ctx) {
     proc_allocation info;
-    info.num_threads = arb::num_threads();
+    info.num_threads = ctx.thread_pool->get_num_threads();
     info.num_gpus = arb::hw::node_gpus();
 
     return info;
diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp
index f9198140..60dd3946 100644
--- a/arbor/partition_load_balance.cpp
+++ b/arbor/partition_load_balance.cpp
@@ -13,7 +13,7 @@ namespace arb {
 domain_decomposition partition_load_balance(
     const recipe& rec,
     proc_allocation nd,
-    const execution_context* ctx,
+    const execution_context& ctx,
     partition_hint_map hint_map)
 {
     struct partition_gid_domain {
@@ -31,8 +31,8 @@ domain_decomposition partition_load_balance(
 
     using util::make_span;
 
-    unsigned num_domains = ctx->distributed.size();
-    unsigned domain_id = ctx->distributed.id();
+    unsigned num_domains = ctx.distributed->size();
+    unsigned domain_id = ctx.distributed->id();
     auto num_global_cells = rec.num_cells();
 
     auto dom_size = [&](unsigned dom) -> cell_gid_type {
diff --git a/arbor/profile/meter_manager.cpp b/arbor/profile/meter_manager.cpp
index c5663700..e08b1c1a 100644
--- a/arbor/profile/meter_manager.cpp
+++ b/arbor/profile/meter_manager.cpp
@@ -1,6 +1,7 @@
 #include <arbor/profile/timer.hpp>
 
 #include <arbor/profile/meter_manager.hpp>
+#include <arbor/execution_context.hpp>
 
 #include "memory_meter.hpp"
 #include "power_meter.hpp"
@@ -18,7 +19,7 @@ using util::strprintf;
 
 measurement::measurement(std::string n, std::string u,
                          const std::vector<double>& readings,
-                         const distributed_context* ctx):
+                         const distributed_context_handle& ctx):
     name(std::move(n)), units(std::move(u))
 {
     // Assert that the same number of readings were taken on every domain.
@@ -34,7 +35,7 @@ measurement::measurement(std::string n, std::string u,
     }
 }
 
-meter_manager::meter_manager(const distributed_context* ctx): glob_ctx_(ctx) {
+meter_manager::meter_manager(distributed_context_handle ctx): glob_ctx_(ctx) {
     if (auto m = make_memory_meter()) {
         meters_.push_back(std::move(m));
     }
@@ -92,7 +93,7 @@ const std::vector<double>& meter_manager::times() const {
     return times_;
 }
 
-const distributed_context* meter_manager::context() const {
+distributed_context_handle meter_manager::context() const {
     return glob_ctx_;
 }
 
diff --git a/arbor/profile/profiler.cpp b/arbor/profile/profiler.cpp
index 5ccea418..9923438f 100644
--- a/arbor/profile/profiler.cpp
+++ b/arbor/profile/profiler.cpp
@@ -171,9 +171,7 @@ void recorder::clear() {
 
 // profiler implementation
 
-profiler::profiler() {
-    recorders_.resize(threading::num_threads());
-}
+profiler::profiler() {}
 
 void profiler::initialize(task_system_handle& ts) {
     recorders_.resize(ts.get()->get_num_threads());
diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp
index 24587676..be20b6da 100644
--- a/arbor/simulation.cpp
+++ b/arbor/simulation.cpp
@@ -13,6 +13,7 @@
 #include "communication/communicator.hpp"
 #include "merge_events.hpp"
 #include "thread_private_spike_store.hpp"
+#include "threading/threading.hpp"
 #include "util/double_buffer.hpp"
 #include "util/filter.hpp"
 #include "util/maputil.hpp"
@@ -48,7 +49,7 @@ public:
 
 class simulation_state {
 public:
-    simulation_state(const recipe& rec, const domain_decomposition& decomp, const execution_context* ctx);
+    simulation_state(const recipe& rec, const domain_decomposition& decomp, execution_context ctx);
 
     void reset();
 
@@ -126,12 +127,12 @@ private:
 simulation_state::simulation_state(
         const recipe& rec,
         const domain_decomposition& decomp,
-        const execution_context* ctx
+        execution_context ctx
     ):
-    local_spikes_(new spike_double_buffer(thread_private_spike_store(ctx->thread_pool),
-                                          thread_private_spike_store(ctx->thread_pool))),
+    local_spikes_(new spike_double_buffer(thread_private_spike_store(ctx.thread_pool),
+                                          thread_private_spike_store(ctx.thread_pool))),
     communicator_(rec, decomp, ctx),
-    task_system_(ctx->thread_pool)
+    task_system_(ctx.thread_pool)
 {
     const auto num_local_cells = communicator_.num_local_cells();
 
@@ -418,7 +419,7 @@ void simulation_state::inject_events(const pse_vector& events) {
 simulation::simulation(
     const recipe& rec,
     const domain_decomposition& decomp,
-    const execution_context* ctx)
+    execution_context ctx)
 {
     impl_.reset(new simulation_state(rec, decomp, ctx));
 }
diff --git a/arbor/threadinfo.cpp b/arbor/threadinfo.cpp
deleted file mode 100644
index 42df2933..00000000
--- a/arbor/threadinfo.cpp
+++ /dev/null
@@ -1,17 +0,0 @@
-#include <string>
-
-#include <arbor/threadinfo.hpp>
-
-#include "threading/threading.hpp"
-
-namespace arb {
-
-int num_threads() {
-    return threading::num_threads();
-}
-
-std::string thread_implementation() {
-    return threading::description();
-}
-
-} // namespace arb
diff --git a/arbor/threading/cthread.cpp b/arbor/threading/cthread.cpp
deleted file mode 100644
index e613beac..00000000
--- a/arbor/threading/cthread.cpp
+++ /dev/null
@@ -1,131 +0,0 @@
-#include <atomic>
-#include <cassert>
-#include <cstring>
-#include <exception>
-#include <iostream>
-#include <regex>
-
-#include "cthread.hpp"
-#include "threading.hpp"
-#include "arbor/execution_context.hpp"
-
-using namespace arb::threading::impl;
-using namespace arb::threading;
-using namespace arb;
-
-task notification_queue::try_pop() {
-    task tsk;
-    lock q_lock{q_mutex_, std::try_to_lock};
-    if (q_lock && !q_tasks_.empty()) {
-        tsk = std::move(q_tasks_.front());
-        q_tasks_.pop_front();
-    }
-    return tsk;
-}
-
-task notification_queue::pop() {
-    task tsk;
-    lock q_lock{q_mutex_};
-    while (q_tasks_.empty() && !quit_) {
-        q_tasks_available_.wait(q_lock);
-    }
-    if (!q_tasks_.empty()) {
-        tsk = std::move(q_tasks_.front());
-        q_tasks_.pop_front();
-    }
-    return tsk;
-}
-
-bool notification_queue::try_push(task& tsk) {
-    {
-        lock q_lock{q_mutex_, std::try_to_lock};
-        if (!q_lock) return false;
-        q_tasks_.push_back(std::move(tsk));
-        tsk = 0;
-    }
-    q_tasks_available_.notify_all();
-    return true;
-}
-
-void notification_queue::push(task&& tsk) {
-    {
-        lock q_lock{q_mutex_};
-        q_tasks_.push_back(std::move(tsk));
-    }
-    q_tasks_available_.notify_all();
-}
-
-void notification_queue::quit() {
-    {
-        lock q_lock{q_mutex_};
-        quit_ = true;
-    }
-    q_tasks_available_.notify_all();
-}
-
-void task_system::run_tasks_loop(int i){
-    while (true) {
-        task tsk;
-        for (unsigned n = 0; n != count_; n++) {
-            tsk = q_[(i + n) % count_].try_pop();
-            if (tsk) break;
-        }
-        if (!tsk) tsk = q_[i].pop();
-        if (!tsk) break;
-        tsk();
-    }
-}
-
-void task_system::try_run_task() {
-    auto nthreads = get_num_threads();
-    task tsk;
-    for (int n = 0; n != nthreads; n++) {
-        tsk = q_[n % nthreads].try_pop();
-        if (tsk) {
-            tsk();
-            break;
-        }
-    }
-}
-
-task_system::task_system(int nthreads) : count_(nthreads), q_(nthreads) {
-    assert( nthreads > 0);
-
-    // now for the main thread
-    auto tid = std::this_thread::get_id();
-    thread_ids_[tid] = 0;
-
-    for (unsigned i = 1; i < count_; i++) {
-        threads_.emplace_back([this, i]{run_tasks_loop(i);});
-        tid = threads_.back().get_id();
-        thread_ids_[tid] = i;
-    }
-}
-
-task_system::~task_system() {
-    for (auto& e: q_) e.quit();
-    for (auto& e: threads_) e.join();
-}
-
-void task_system::async(task tsk) {
-    auto i = index_++;
-
-    for (unsigned n = 0; n != count_; n++) {
-        if (q_[(i + n) % count_].try_push(tsk)) return;
-    }
-    q_[i % count_].push(std::move(tsk));
-}
-
-int task_system::get_num_threads() {
-    return threads_.size() + 1;
-}
-
-std::unordered_map<std::thread::id, std::size_t> task_system::get_thread_ids() {
-    return thread_ids_;
-};
-
-task_system_handle arb::make_thread_pool(int nthreads) {
-    return task_system_handle(new task_system(nthreads));
-}
-
-
diff --git a/arbor/threading/cthread.hpp b/arbor/threading/cthread.hpp
deleted file mode 100644
index b2a6142b..00000000
--- a/arbor/threading/cthread.hpp
+++ /dev/null
@@ -1,7 +0,0 @@
-#pragma once
-
-// task_group definition
-#include "cthread_impl.hpp"
-
-// and sorts use cthread_parallel_stable_sort
-#include "cthread_sort.hpp"
diff --git a/arbor/threading/cthread_impl.hpp b/arbor/threading/cthread_impl.hpp
deleted file mode 100644
index 6899691d..00000000
--- a/arbor/threading/cthread_impl.hpp
+++ /dev/null
@@ -1,243 +0,0 @@
-#pragma once
-
-#include <iostream>
-#include <type_traits>
-
-#include <thread>
-#include <mutex>
-#include <algorithm>
-#include <array>
-#include <chrono>
-#include <string>
-#include <vector>
-#include <type_traits>
-#include <functional>
-#include <condition_variable>
-#include <utility>
-#include <unordered_map>
-#include <deque>
-#include <atomic>
-#include <type_traits>
-
-#include <cstdlib>
-#include "arbor/execution_context.hpp"
-
-namespace arb {
-namespace threading {
-
-// Forward declare task_group at bottom of this header
-class task_group;
-
-using std::mutex;
-using lock = std::unique_lock<mutex>;
-using std::condition_variable;
-using task = std::function<void()>;
-
-namespace impl {
-class notification_queue {
-private:
-    // FIFO of pending tasks.
-    std::deque<task> q_tasks_;
-
-    // Lock and signal on task availability change this is the crucial bit.
-    mutex q_mutex_;
-    condition_variable q_tasks_available_;
-
-    // Flag to handle exit from all threads.
-    bool quit_ = false;
-
-public:
-    // Pops a task from the task queue returns false when queue is empty.
-    task try_pop();
-    task pop();
-
-    // Pushes a task into the task queue and increases task group counter.
-    void push(task&& tsk); // TODO: need to use value?
-    bool try_push(task& tsk);
-
-    // Finish popping all waiting tasks on queue then stop trying to pop new tasks
-    void quit();
-};
-}// namespace impl
-
-class task_system {
-private:
-    unsigned count_;
-
-    std::vector<std::thread> threads_;
-
-    // queue of tasks
-    std::vector<impl::notification_queue> q_;
-
-    // threads -> index
-    std::unordered_map<std::thread::id, std::size_t> thread_ids_;
-
-    // total number of tasks pushed in all queues
-    std::atomic<unsigned> index_{0};
-
-public:
-    // Create nthreads-1 new c std threads
-    task_system(int nthreads);
-
-    // task_system is a singleton.
-    task_system(const task_system&) = delete;
-    task_system& operator=(const task_system&) = delete;
-
-    ~task_system();
-
-    // Pushes tasks into notification queue.
-    void async(task tsk);
-
-    // Runs tasks until quit is true.
-    void run_tasks_loop(int i);
-
-    // Request that the task_system attempts to find and run a _single_ task.
-    // Will return without executing a task if no tasks available.
-    void try_run_task();
-
-    // Includes master thread.
-    int get_num_threads();
-
-    // Returns the thread_id map
-    std::unordered_map<std::thread::id, std::size_t> get_thread_ids();
-};
-
-///////////////////////////////////////////////////////////////////////
-// types
-///////////////////////////////////////////////////////////////////////
-
-template <typename T>
-class enumerable_thread_specific {
-    std::unordered_map<std::thread::id, std::size_t> thread_ids_;
-
-    using storage_class = std::vector<T>;
-    storage_class data;
-
-public:
-    using iterator = typename storage_class::iterator;
-    using const_iterator = typename storage_class::const_iterator;
-
-    enumerable_thread_specific(const task_system_handle& ts):
-        thread_ids_{ts.get()->get_thread_ids()},
-        data{std::vector<T>(ts.get()->get_num_threads())}
-    {}
-
-    enumerable_thread_specific(const T& init, const task_system_handle& ts):
-        thread_ids_{ts.get()->get_thread_ids()},
-        data{std::vector<T>(ts.get()->get_num_threads(), init)}
-    {}
-
-    T& local() {
-        return data[thread_ids_.at(std::this_thread::get_id())];
-    }
-    const T& local() const {
-        return data[thread_ids_.at(std::this_thread::get_id())];
-    }
-
-    auto size() const { return data.size(); }
-
-    iterator begin() { return data.begin(); }
-    iterator end()   { return data.end(); }
-
-    const_iterator begin() const { return data.begin(); }
-    const_iterator end()   const { return data.end(); }
-
-    const_iterator cbegin() const { return data.cbegin(); }
-    const_iterator cend()   const { return data.cend(); }
-};
-
-inline std::string description() {
-    return "CThread Pool";
-}
-
-constexpr bool multithreaded() { return true; }
-
-class task_group {
-private:
-    std::atomic<std::size_t> in_flight_{0};
-    /// We use a raw pointer here instead of a shared_ptr to avoid a race condition
-    /// on the destruction of a task_system that would lead to a thread trying to join itself
-    task_system* task_system_;
-
-public:
-    task_group(task_system* ts):
-        task_system_{ts}
-    {}
-
-    task_group(const task_group&) = delete;
-    task_group& operator=(const task_group&) = delete;
-
-    template <typename F>
-    class wrap {
-        F f;
-        std::atomic<std::size_t>& counter;
-
-    public:
-
-        // Construct from a compatible function and atomic counter
-        template <typename F2>
-        explicit wrap(F2&& other, std::atomic<std::size_t>& c):
-                f(std::forward<F2>(other)),
-                counter(c)
-        {}
-
-        wrap(wrap&& other):
-                f(std::move(other.f)),
-                counter(other.counter)
-        {}
-
-        // std::function is not guaranteed to not copy the contents on move construction
-        // But the class is safe because we don't call operator() more than once on the same wrapped task
-        wrap(const wrap& other):
-                f(other.f),
-                counter(other.counter)
-        {}
-
-        void operator()() {
-            f();
-            --counter;
-        }
-    };
-
-    template <typename F>
-    using callable = typename std::decay<F>::type;
-
-    template <typename F>
-    wrap<callable<F>> make_wrapped_function(F&& f, std::atomic<std::size_t>& c) {
-        return wrap<callable<F>>(std::forward<F>(f), c);
-    }
-
-    template<typename F>
-    void run(F&& f) {
-        ++in_flight_;
-        task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_));
-    }
-
-    // wait till all tasks in this group are done
-    void wait() {
-        while (in_flight_) {
-            task_system_->try_run_task();
-        }
-    }
-
-    // Make sure that all tasks are done before clean up
-    ~task_group() {
-        wait();
-    }
-};
-
-///////////////////////////////////////////////////////////////////////
-// algorithms
-///////////////////////////////////////////////////////////////////////
-struct parallel_for {
-    template <typename F>
-    static void apply(int left, int right, task_system* ts, F f) {
-        task_group g(ts);
-        for (int i = left; i < right; ++i) {
-          g.run([=] {f(i);});
-        }
-        g.wait();
-    }
-};
-} // namespace threading
-} // namespace arb
diff --git a/arbor/threading/cthread_sort.hpp b/arbor/threading/cthread_sort.hpp
deleted file mode 100644
index 631b2117..00000000
--- a/arbor/threading/cthread_sort.hpp
+++ /dev/null
@@ -1,24 +0,0 @@
-#pragma once
-
-namespace arb {
-namespace threading {
-inline namespace cthread {
-
-template <typename RandomIt>
-void sort(RandomIt begin, RandomIt end) {
-    std::sort(begin, end);
-}
-
-template <typename RandomIt, typename Compare>
-void sort(RandomIt begin, RandomIt end, Compare comp) {
-    std::sort(begin, end, comp);
-}
-
-template <typename Container>
-void sort(Container& c) {
-    std::sort(std::begin(c), std::end(c));
-}
-
-} // namespace cthread
-} // namespace threading
-} // namespace arb
diff --git a/arbor/threading/thread_info.cpp b/arbor/threading/thread_info.cpp
new file mode 100644
index 00000000..025de420
--- /dev/null
+++ b/arbor/threading/thread_info.cpp
@@ -0,0 +1,76 @@
+#include <cstdlib>
+#include <exception>
+#include <regex>
+#include <string>
+
+#include <arbor/arbexcept.hpp>
+#include <arbor/util/optional.hpp>
+#include <hardware/node_info.hpp>
+
+#include "thread_info.hpp"
+#include "util/strprintf.hpp"
+
+namespace arb {
+namespace threading {
+
+// Test environment variables for user-specified count of threads.
+//
+// The variable named by ARB_NUM_THREADS_VAR is consulted first, then ARB_NUM_THREADS, then OMP_NUM_THREADS.
+//
+// If neither variable is set, returns no value.
+//
+// Valid values for the environment variable are:
+//  0 : Arbor is responsible for picking the number of threads.
+//  >0: The number of threads to use.
+//
+// Throws std::runtime_error:
+//  ARB_NUM_THREADS or OMP_NUM_THREADS is set with invalid value.
+util::optional<size_t> get_env_num_threads() {
+    const char* str;
+
+    // select variable to use:
+    //   If ARB_NUM_THREADS_VAR is set, use $ARB_NUM_THREADS_VAR
+    //   else if ARB_NUM_THREADS is set, use it
+    //   else if OMP_NUM_THREADS is set, use it
+    if (auto nthreads_var_name = std::getenv("ARB_NUM_THREADS_VAR")) {
+        str = std::getenv(nthreads_var_name);
+    }
+    else if (! (str = std::getenv("ARB_NUM_THREADS"))) {
+        str = std::getenv("OMP_NUM_THREADS");
+    }
+
+    // If the selected variable is unset, return no value; the caller
+    // will fall back to a hardware-based default thread count.
+    if (!str) {
+        return util::nullopt;
+    }
+
+    auto nthreads = std::strtoul(str, nullptr, 10);
+
+    // check that the environment variable string describes a non-negative integer
+    if (errno==ERANGE ||
+        !std::regex_match(str, std::regex("\\s*\\d*[0-9]\\d*\\s*")))
+    {
+        throw arbor_exception(util::pprintf(
+            "requested number of threads \"{}\" is not a valid value", str));
+    }
+
+    return nthreads;
+}
+
+std::size_t num_threads_init() {
+    std::size_t n = 0;
+
+    if (auto env_threads = get_env_num_threads()) {
+        n = env_threads.value();
+    }
+
+    if (!n) {
+        n = hw::node_processors();
+    }
+
+    return n? n: 1;
+}
+
+} // namespace threading
+} // namespace arb
diff --git a/arbor/threading/thread_info.hpp b/arbor/threading/thread_info.hpp
new file mode 100644
index 00000000..42195fe7
--- /dev/null
+++ b/arbor/threading/thread_info.hpp
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <arbor/util/optional.hpp>
+
+namespace arb {
+namespace threading {
+
+// Test environment variables for user-specified count of threads.
+// Potential environment variables are tested in this order:
+//   1. use the environment variable specified by ARB_NUM_THREADS_VAR
+//   2. use ARB_NUM_THREADS
+//   3. use OMP_NUM_THREADS
+//   If none of these variables is set, no value is returned.
+//
+// Valid values for the environment variable are:
+//      0 : Arbor is responsible for picking the number of threads.
+//     >0 : The number of threads to use.
+//
+// Throws std::runtime_error:
+//      Environment variable is set with invalid value.
+util::optional<size_t> get_env_num_threads();
+
+size_t num_threads_init();
+
+} // namespace threading
+} // namespace arb
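
A minimal sketch of the resolution order implemented by get_env_num_threads()
above (assumes POSIX setenv; the chosen count 4 is illustrative):

    #include <cassert>
    #include <cstdlib>

    #include "threading/thread_info.hpp"

    int main() {
        setenv("ARB_NUM_THREADS", "4", 1);   // takes precedence over OMP_NUM_THREADS
        auto n = arb::threading::get_env_num_threads();
        assert(n && n.value()==4);           // a value of 0 would mean "let Arbor decide"
    }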
diff --git a/arbor/threading/threading.cpp b/arbor/threading/threading.cpp
index 54857415..7c28cdc8 100644
--- a/arbor/threading/threading.cpp
+++ b/arbor/threading/threading.cpp
@@ -1,86 +1,131 @@
-#include <cstdlib>
-#include <exception>
-#include <regex>
-#include <string>
-
-#include <arbor/arbexcept.hpp>
-#include <arbor/util/optional.hpp>
-#include <hardware/node_info.hpp>
+#include <atomic>
 
 #include "threading.hpp"
-#include "util/strprintf.hpp"
-
-namespace arb {
-namespace threading {
-
-// Test environment variables for user-specified count of threads.
-//
-// ARB_NUM_THREADS is used if set, otherwise OMP_NUM_THREADS is used.
-//
-// If neither variable is set, returns no value.
-//
-// Valid values for the environment variable are:
-//  0 : Arbor is responsible for picking the number of threads.
-//  >0: The number of threads to use.
-//
-// Throws std::runtime_error:
-//  ARB_NUM_THREADS or OMP_NUM_THREADS is set with invalid value.
-util::optional<size_t> get_env_num_threads() {
-    const char* str;
-
-    // select variable to use:
-    //   If ARB_NUM_THREADS_VAR is set, use $ARB_NUM_THREADS_VAR
-    //   else if ARB_NUM_THREAD set, use it
-    //   else if OMP_NUM_THREADS set, use it
-    if (auto nthreads_var_name = std::getenv("ARB_NUM_THREADS_VAR")) {
-        str = std::getenv(nthreads_var_name);
+#include "thread_info.hpp"
+#include <arbor/execution_context.hpp>
+
+using namespace arb::threading::impl;
+using namespace arb::threading;
+using namespace arb;
+
+task notification_queue::try_pop() {
+    task tsk;
+    lock q_lock{q_mutex_, std::try_to_lock};
+    if (q_lock && !q_tasks_.empty()) {
+        tsk = std::move(q_tasks_.front());
+        q_tasks_.pop_front();
     }
-    else if (! (str = std::getenv("ARB_NUM_THREADS"))) {
-        str = std::getenv("OMP_NUM_THREADS");
+    return tsk;
+}
+
+task notification_queue::pop() {
+    task tsk;
+    lock q_lock{q_mutex_};
+    while (q_tasks_.empty() && !quit_) {
+        q_tasks_available_.wait(q_lock);
+    }
+    if (!q_tasks_.empty()) {
+        tsk = std::move(q_tasks_.front());
+        q_tasks_.pop_front();
     }
+    return tsk;
+}
 
-    // If the selected var is unset set the number of threads to
-    // the hint given by the standard library
-    if (!str) {
-        return util::nullopt;
+bool notification_queue::try_push(task& tsk) {
+    {
+        lock q_lock{q_mutex_, std::try_to_lock};
+        if (!q_lock) return false;
+        q_tasks_.push_back(std::move(tsk));
+        tsk = 0;
     }
+    q_tasks_available_.notify_all();
+    return true;
+}
 
-    auto nthreads = std::strtoul(str, nullptr, 10);
+void notification_queue::push(task&& tsk) {
+    {
+        lock q_lock{q_mutex_};
+        q_tasks_.push_back(std::move(tsk));
+    }
+    q_tasks_available_.notify_all();
+}
 
-    // check that the environment variable string describes a non-negative integer
-    if (errno==ERANGE ||
-        !std::regex_match(str, std::regex("\\s*\\d*[0-9]\\d*\\s*")))
+void notification_queue::quit() {
     {
-        throw arbor_exception(util::pprintf(
-            "requested number of threads \"{}\" is not a valid value", str));
+        lock q_lock{q_mutex_};
+        quit_ = true;
+    }
+    q_tasks_available_.notify_all();
+}
+
+void task_system::run_tasks_loop(int i) {
+    while (true) {
+        task tsk;
+        for (unsigned n = 0; n != count_; n++) {
+            tsk = q_[(i + n) % count_].try_pop();
+            if (tsk) break;
+        }
+        if (!tsk) tsk = q_[i].pop();
+        if (!tsk) break;
+        tsk();
     }
+}
 
-    return nthreads;
+void task_system::try_run_task() {
+    auto nthreads = get_num_threads();
+    task tsk;
+    for (int n = 0; n != nthreads; n++) {
+        tsk = q_[n % nthreads].try_pop();
+        if (tsk) {
+            tsk();
+            break;
+        }
+    }
 }
 
-std::size_t num_threads_init() {
-    std::size_t n = 0;
+task_system::task_system(): task_system(num_threads_init()) {}
 
-    if (auto env_threads = get_env_num_threads()) {
-        n = env_threads.value();
+task_system::task_system(int nthreads): count_(nthreads), q_(nthreads) {
+    if (nthreads <= 0)
+        throw std::runtime_error("Non-positive number of threads in thread pool");
+
+    // Main thread
+    auto tid = std::this_thread::get_id();
+    thread_ids_[tid] = 0;
+
+    for (unsigned i = 1; i < count_; i++) {
+        threads_.emplace_back([this, i]{run_tasks_loop(i);});
+        tid = threads_.back().get_id();
+        thread_ids_[tid] = i;
     }
+}
 
-    if (!n) {
-        n = hw::node_processors();
+task_system::~task_system() {
+    for (auto& e: q_) e.quit();
+    for (auto& e: threads_) e.join();
+}
+
+void task_system::async(task tsk) {
+    auto i = index_++;
+
+    for (unsigned n = 0; n != count_; n++) {
+        if (q_[(i + n) % count_].try_push(tsk)) return;
     }
+    q_[i % count_].push(std::move(tsk));
+}
 
-    return n? n: 1;
+int task_system::get_num_threads() {
+    return threads_.size() + 1;
 }
 
-// Returns the number of threads used by the threading back end.
-// Throws:
-//      std::runtime_error if an invalid environment variable was set for the
-//      number of threads.
-size_t num_threads() {
-    // TODO: this is a bit of a hack until we have user-configurable threading.
-    static size_t num_threads_cached = num_threads_init();
-    return num_threads_cached;
+std::unordered_map<std::thread::id, std::size_t> task_system::get_thread_ids() {
+    return thread_ids_;
+}
+
+task_system_handle arb::make_thread_pool() {
+    return arb::make_thread_pool(num_threads_init());
 }
 
-} // namespace threading
-} // namespace arb
+task_system_handle arb::make_thread_pool(int nthreads) {
+    return task_system_handle(new task_system(nthreads));
+}
diff --git a/arbor/threading/threading.hpp b/arbor/threading/threading.hpp
index d6fe7b93..ea084dcc 100644
--- a/arbor/threading/threading.hpp
+++ b/arbor/threading/threading.hpp
@@ -1,28 +1,238 @@
 #pragma once
 
-#include <arbor/util/optional.hpp>
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <condition_variable>
+#include <cstddef>
+#include <deque>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+#include <unordered_map>
+#include <utility>
+
+#include <arbor/execution_context.hpp>
 
 namespace arb {
 namespace threading {
 
-// Test environment variables for user-specified count of threads.
-// Potential environment variables are tested in this order:
-//   1. use the environment variable specified by ARB_NUM_THREADS_VAR
-//   2. use ARB_NUM_THREADS
-//   3. use OMP_NUM_THREADS
-//   4. If no variable is set, returns no value.
-//
-// Valid values for the environment variable are:
-//      0 : Arbor is responsible for picking the number of threads.
-//     >0 : The number of threads to use.
-//
-// Throws std::runtime_error:
-//      Environment variable is set with invalid value.
-util::optional<size_t> get_env_num_threads();
-
-size_t num_threads();
+// task_group is declared here and defined later in this header.
+class task_group;
+
+using std::mutex;
+using lock = std::unique_lock<mutex>;
+using std::condition_variable;
+using task = std::function<void()>;
+
+namespace impl {
+class notification_queue {
+private:
+    // FIFO of pending tasks.
+    std::deque<task> q_tasks_;
+
+    // Lock and signal on task-availability changes; this is the crucial bit.
+    mutex q_mutex_;
+    condition_variable q_tasks_available_;
+
+    // Flag to handle exit from all threads.
+    bool quit_ = false;
+
+public:
+    // Pops a task from the task queue; returns an empty task if none is available.
+    task try_pop();
+    task pop();
+
+    // Pushes a task into the task queue.
+    void push(task&& tsk); // TODO: need to use value?
+    bool try_push(task& tsk);
+
+    // Finish popping all waiting tasks on the queue, then stop popping new tasks.
+    void quit();
+};
+}// namespace impl
+
+class task_system {
+private:
+    unsigned count_;
+
+    std::vector<std::thread> threads_;
+
+    // queue of tasks
+    std::vector<impl::notification_queue> q_;
+
+    // threads -> index
+    std::unordered_map<std::thread::id, std::size_t> thread_ids_;
+
+    // total number of tasks pushed in all queues
+    std::atomic<unsigned> index_{0};
+
+public:
+    task_system();
+    // Create nthreads-1 new std::thread workers (the calling thread is also used).
+    task_system(int nthreads);
+
+    // task_system is not copyable.
+    task_system(const task_system&) = delete;
+    task_system& operator=(const task_system&) = delete;
+
+    ~task_system();
+
+    // Pushes tasks into notification queue.
+    void async(task tsk);
+
+    // Runs tasks until quit is true.
+    void run_tasks_loop(int i);
+
+    // Request that the task_system attempts to find and run a _single_ task.
+    // Will return without executing a task if no tasks available.
+    void try_run_task();
+
+    // Includes master thread.
+    int get_num_threads();
+
+    // Returns the thread_id map
+    std::unordered_map<std::thread::id, std::size_t> get_thread_ids();
+};
+
+///////////////////////////////////////////////////////////////////////
+// types
+///////////////////////////////////////////////////////////////////////
+
+template <typename T>
+class enumerable_thread_specific {
+    std::unordered_map<std::thread::id, std::size_t> thread_ids_;
+
+    using storage_class = std::vector<T>;
+    storage_class data;
+
+public:
+    using iterator = typename storage_class::iterator;
+    using const_iterator = typename storage_class::const_iterator;
 
+    enumerable_thread_specific(const task_system_handle& ts):
+        thread_ids_{ts.get()->get_thread_ids()},
+        data{std::vector<T>(ts.get()->get_num_threads())}
+    {}
+
+    enumerable_thread_specific(const T& init, const task_system_handle& ts):
+        thread_ids_{ts.get()->get_thread_ids()},
+        data{std::vector<T>(ts.get()->get_num_threads(), init)}
+    {}
+
+    T& local() {
+        return data[thread_ids_.at(std::this_thread::get_id())];
+    }
+    const T& local() const {
+        return data[thread_ids_.at(std::this_thread::get_id())];
+    }
+
+    auto size() const { return data.size(); }
+
+    iterator begin() { return data.begin(); }
+    iterator end()   { return data.end(); }
+
+    const_iterator begin() const { return data.begin(); }
+    const_iterator end()   const { return data.end(); }
+
+    const_iterator cbegin() const { return data.cbegin(); }
+    const_iterator cend()   const { return data.cend(); }
+};
+
+inline std::string description() {
+    return "CThread Pool";
+}
+
+constexpr bool multithreaded() { return true; }
+
+class task_group {
+private:
+    std::atomic<std::size_t> in_flight_{0};
+    /// We use a raw pointer here instead of a shared_ptr to avoid a race condition
+    /// on the destruction of a task_system that would lead to a thread trying to join itself
+    task_system* task_system_;
+
+public:
+    task_group(task_system* ts):
+        task_system_{ts}
+    {}
+
+    task_group(const task_group&) = delete;
+    task_group& operator=(const task_group&) = delete;
+
+    template <typename F>
+    class wrap {
+        F f;
+        std::atomic<std::size_t>& counter;
+
+    public:
+
+        // Construct from a compatible function and atomic counter
+        template <typename F2>
+        explicit wrap(F2&& other, std::atomic<std::size_t>& c):
+                f(std::forward<F2>(other)),
+                counter(c)
+        {}
+
+        wrap(wrap&& other):
+                f(std::move(other.f)),
+                counter(other.counter)
+        {}
+
+        // std::function is not guaranteed to avoid copying the contents on move construction,
+        // but the class is safe because operator() is called at most once per wrapped task.
+        wrap(const wrap& other):
+                f(other.f),
+                counter(other.counter)
+        {}
+
+        void operator()() {
+            f();
+            --counter;
+        }
+    };
+
+    template <typename F>
+    using callable = typename std::decay<F>::type;
+
+    template <typename F>
+    wrap<callable<F>> make_wrapped_function(F&& f, std::atomic<std::size_t>& c) {
+        return wrap<callable<F>>(std::forward<F>(f), c);
+    }
+
+    template<typename F>
+    void run(F&& f) {
+        ++in_flight_;
+        task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_));
+    }
+
+    // wait till all tasks in this group are done
+    void wait() {
+        while (in_flight_) {
+            task_system_->try_run_task();
+        }
+    }
+
+    // Make sure that all tasks are done before cleanup.
+    ~task_group() {
+        wait();
+    }
+};
+
+///////////////////////////////////////////////////////////////////////
+// algorithms
+///////////////////////////////////////////////////////////////////////
+struct parallel_for {
+    template <typename F>
+    static void apply(int left, int right, task_system* ts, F f) {
+        task_group g(ts);
+        for (int i = left; i < right; ++i) {
+            g.run([=] {f(i);});
+        }
+        g.wait();
+    }
+};
 } // namespace threading
 } // namespace arb
-
-#include "cthread.hpp"
diff --git a/aux/include/aux/with_mpi.hpp b/aux/include/aux/with_mpi.hpp
index f9c8cb70..6a0e21e5 100644
--- a/aux/include/aux/with_mpi.hpp
+++ b/aux/include/aux/with_mpi.hpp
@@ -4,6 +4,8 @@
 
 #include <arbor/communication/mpi_error.hpp>
 
+namespace aux {
+
 struct with_mpi {
     with_mpi(int& argc, char**& argv, bool fatal_errors = true) {
         init(&argc, &argv, fatal_errors);
@@ -33,3 +35,5 @@ private:
         }
     }
 };
+
+} // namespace aux
diff --git a/example/bench/bench.cpp b/example/bench/bench.cpp
index f8934b73..a8c32991 100644
--- a/example/bench/bench.cpp
+++ b/example/bench/bench.cpp
@@ -16,7 +16,6 @@
 #include <arbor/profile/profiler.hpp>
 #include <arbor/recipe.hpp>
 #include <arbor/simulation.hpp>
-#include <arbor/threadinfo.hpp>
 #include <arbor/version.hpp>
 
 
@@ -35,13 +34,13 @@ int main(int argc, char** argv) {
     try {
         arb::execution_context context;
 #ifdef ARB_MPI_ENABLED
-        aux::with_mpi guard(&argc, &argv);
-        context.distributed = mpi_context(MPI_COMM_WORLD);
+        aux::with_mpi guard(argc, argv, false);
+        context.distributed = arb::mpi_context(MPI_COMM_WORLD);
 #endif
 #ifdef ARB_PROFILE_ENABLED
         profile::profiler_initialize(context.thread_pool);
 #endif
-        const bool is_root =  context.distributed.id()==0;
+        const bool is_root =  context.distributed->id()==0;
 
         std::cout << aux::mask_stream(is_root);
 
@@ -49,7 +48,7 @@ int main(int argc, char** argv) {
 
         std::cout << params << "\n";
 
-        profile::meter_manager meters(&context.distributed);
+        profile::meter_manager meters(context.distributed);
         meters.start();
 
         // Create an instance of our recipe.
@@ -57,12 +56,12 @@ int main(int argc, char** argv) {
         meters.checkpoint("recipe-build");
 
         // Make the domain decomposition for the model
-        auto local = arb::local_allocation();
-        auto decomp = arb::partition_load_balance(recipe, local, &context);
+        auto local = arb::local_allocation(context);
+        auto decomp = arb::partition_load_balance(recipe, local, context);
         meters.checkpoint("domain-decomp");
 
         // Construct the model.
-        arb::simulation sim(recipe, decomp, &context);
+        arb::simulation sim(recipe, decomp, context);
         meters.checkpoint("model-build");
 
         // Run the simulation for 100 ms, with time steps of 0.01 ms.
diff --git a/example/brunel/brunel_miniapp.cpp b/example/brunel/brunel_miniapp.cpp
index 19f3df93..9a377a81 100644
--- a/example/brunel/brunel_miniapp.cpp
+++ b/example/brunel/brunel_miniapp.cpp
@@ -7,7 +7,6 @@
 #include <vector>
 
 #include <arbor/common_types.hpp>
-#include <arbor/distributed_context.hpp>
 #include <arbor/domain_decomposition.hpp>
 #include <arbor/event_generator.hpp>
 #include <arbor/lif_cell.hpp>
@@ -16,7 +15,6 @@
 #include <arbor/profile/profiler.hpp>
 #include <arbor/recipe.hpp>
 #include <arbor/simulation.hpp>
-#include <arbor/threadinfo.hpp>
 #include <arbor/version.hpp>
 
 #include <aux/ioutil.hpp>
@@ -32,7 +30,7 @@
 
 using namespace arb;
 
-void banner(proc_allocation, const execution_context*);
+void banner(proc_allocation, const execution_context&);
 
 // Samples m unique values in interval [start, end) - gid.
 // We exclude gid because we don't want self-loops.
@@ -189,19 +187,19 @@ int main(int argc, char** argv) {
 
     try {
 #ifdef ARB_MPI_ENABLED
-        with_mpi guard(argc, argv, false);
+        aux::with_mpi guard(argc, argv, false);
         context.distributed = mpi_context(MPI_COMM_WORLD);
 #endif
 #ifdef ARB_PROFILE_ENABLED
         profile::profiler_initialize(context.thread_pool);
 #endif
-        arb::profile::meter_manager meters(&context.distributed);
+        arb::profile::meter_manager meters(context.distributed);
         meters.start();
-        std::cout << aux::mask_stream(context.distributed.id()==0);
+        std::cout << aux::mask_stream(context.distributed->id()==0);
         // read parameters
-        io::cl_options options = io::read_options(argc, argv, context.distributed.id()==0);
-        proc_allocation nd = local_allocation();
-        banner(nd, &context);
+        io::cl_options options = io::read_options(argc, argv, context.distributed->id()==0);
+        proc_allocation nd = local_allocation(context);
+        banner(nd, context);
 
         meters.checkpoint("setup");
 
@@ -238,16 +236,16 @@ int main(int argc, char** argv) {
 
         partition_hint_map hints;
         hints[cell_kind::lif_neuron].cpu_group_size = group_size;
-        auto decomp = partition_load_balance(recipe, nd, &context, hints);
+        auto decomp = partition_load_balance(recipe, nd, context, hints);
 
-        simulation sim(recipe, decomp, &context);
+        simulation sim(recipe, decomp, context);
 
         // Initialize the spike exporting interface
         std::fstream spike_out;
         if (options.spike_file_output) {
             using std::ios_base;
 
-            auto rank = context.distributed.id();
+            auto rank = context.distributed->id();
             aux::path p = options.output_path;
             p /= aux::strsub("%_%.%", options.file_name, rank, options.file_extension);
 
@@ -274,7 +272,7 @@ int main(int argc, char** argv) {
 
         auto report = profile::make_meter_report(meters);
         std::cout << report;
-        if (context.distributed.id()==0) {
+        if (context.distributed->id()==0) {
             std::ofstream fid;
             fid.exceptions(std::ios_base::badbit | std::ios_base::failbit);
             fid.open("meters.json");
@@ -283,7 +281,7 @@ int main(int argc, char** argv) {
     }
     catch (io::usage_error& e) {
         // only print usage/startup errors on master
-        std::cerr << aux::mask_stream(context.distributed.id()==0);
+        std::cerr << aux::mask_stream(context.distributed->id()==0);
         std::cerr << e.what() << "\n";
         return 1;
     }
@@ -294,13 +292,12 @@ int main(int argc, char** argv) {
     return 0;
 }
 
-void banner(proc_allocation nd, const execution_context* ctx) {
+void banner(proc_allocation nd, const execution_context& ctx) {
     std::cout << "==========================================\n";
     std::cout << "  Arbor miniapp\n";
-    std::cout << "  - distributed : " << ctx->distributed.size()
-              << " (" << ctx->distributed.name() << ")\n";
-    std::cout << "  - threads     : " << nd.num_threads
-              << " (" << arb::thread_implementation() << ")\n";
+    std::cout << "  - distributed : " << ctx.distributed->size()
+              << " (" << ctx.distributed->name() << ")\n";
+    std::cout << "  - threads     : " << nd.num_threads << "\n";
     std::cout << "  - gpus        : " << nd.num_gpus << "\n";
     std::cout << "==========================================\n";
 }
diff --git a/example/generators/event_gen.cpp b/example/generators/event_gen.cpp
index a2e3d595..6cd934d0 100644
--- a/example/generators/event_gen.cpp
+++ b/example/generators/event_gen.cpp
@@ -133,11 +133,11 @@ int main() {
     generator_recipe recipe;
 
     // Make the domain decomposition for the model
-    auto node = arb::local_allocation();
-    auto decomp = arb::partition_load_balance(recipe, node, &context);
+    auto node = arb::local_allocation(context);
+    auto decomp = arb::partition_load_balance(recipe, node, context);
 
     // Construct the model.
-    arb::simulation sim(recipe, decomp, &context);
+    arb::simulation sim(recipe, decomp, context);
 
     // Set up the probe that will measure voltage in the cell.
 
diff --git a/example/miniapp/miniapp.cpp b/example/miniapp/miniapp.cpp
index 590cdcfb..a84aedc1 100644
--- a/example/miniapp/miniapp.cpp
+++ b/example/miniapp/miniapp.cpp
@@ -14,7 +14,6 @@
 #include <arbor/sampling.hpp>
 #include <arbor/schedule.hpp>
 #include <arbor/simulation.hpp>
-#include <arbor/threadinfo.hpp>
 #include <arbor/util/any.hpp>
 #include <arbor/version.hpp>
 
@@ -36,7 +35,7 @@ using namespace arb;
 
 using util::any_cast;
 
-void banner(proc_allocation, const execution_context*);
+void banner(proc_allocation, const execution_context&);
 std::unique_ptr<recipe> make_recipe(const io::cl_options&, const probe_distribution&);
 sample_trace make_trace(const probe_info& probe);
 std::fstream& open_or_throw(std::fstream& file, const aux::path& p, bool exclusive = false);
@@ -48,26 +47,26 @@ int main(int argc, char** argv) {
 
     try {
 #ifdef ARB_MPI_ENABLED
-        with_mpi guard(argc, argv, false);
+        aux::with_mpi guard(argc, argv, false);
         context.distributed = mpi_context(MPI_COMM_WORLD);
 #endif
 #ifdef ARB_PROFILE_ENABLED
         profile::profiler_initialize(context.thread_pool);
 #endif
-        profile::meter_manager meters(&context.distributed);
+        profile::meter_manager meters(context.distributed);
         meters.start();
 
-        std::cout << aux::mask_stream(context.distributed.id()==0);
+        std::cout << aux::mask_stream(context.distributed->id()==0);
         // read parameters
-        io::cl_options options = io::read_options(argc, argv, context.distributed.id()==0);
+        io::cl_options options = io::read_options(argc, argv, context.distributed->id()==0);
 
         // TODO: add dry run mode
 
         // Use a node description that uses the number of threads used by the
         // threading back end, and 1 gpu if available.
-        proc_allocation nd = local_allocation();
+        proc_allocation nd = local_allocation(context);
         nd.num_gpus = nd.num_gpus>=1? 1: 0;
-        banner(nd, &context);
+        banner(nd, context);
 
         meters.checkpoint("setup");
 
@@ -81,8 +80,8 @@ int main(int argc, char** argv) {
             report_compartment_stats(*recipe);
         }
 
-        auto decomp = partition_load_balance(*recipe, nd, &context);
-        simulation sim(*recipe, decomp, &context);
+        auto decomp = partition_load_balance(*recipe, nd, context);
+        simulation sim(*recipe, decomp, context);
 
         // Set up samplers for probes on local cable cells, as requested
         // by command line options.
@@ -119,7 +118,7 @@ int main(int argc, char** argv) {
         if (options.spike_file_output) {
             using std::ios_base;
 
-            auto rank = context.distributed.id();
+            auto rank = context.distributed->id();
             aux::path p = options.output_path;
             p /= aux::strsub("%_%.%", options.file_name, rank, options.file_extension);
 
@@ -153,7 +152,7 @@ int main(int argc, char** argv) {
 
         auto report = profile::make_meter_report(meters);
         std::cout << report;
-        if (context.distributed.id()==0) {
+        if (context.distributed->id()==0) {
             std::ofstream fid;
             fid.exceptions(std::ios_base::badbit | std::ios_base::failbit);
             fid.open("meters.json");
@@ -162,7 +161,7 @@ int main(int argc, char** argv) {
     }
     catch (io::usage_error& e) {
         // only print usage/startup errors on master
-        std::cerr << aux::mask_stream(context.distributed.id()==0);
+        std::cerr << aux::mask_stream(context.distributed->id()==0);
         std::cerr << e.what() << "\n";
         return 1;
     }
@@ -173,13 +172,12 @@ int main(int argc, char** argv) {
     return 0;
 }
 
-void banner(proc_allocation nd, const execution_context* ctx) {
+void banner(proc_allocation nd, const execution_context& ctx) {
     std::cout << "==========================================\n";
     std::cout << "  Arbor miniapp\n";
-    std::cout << "  - distributed : " << ctx->distributed.size()
-              << " (" << ctx->distributed.name() << ")\n";
-    std::cout << "  - threads     : " << nd.num_threads
-              << " (" << arb::thread_implementation() << ")\n";
+    std::cout << "  - distributed : " << ctx.distributed->size()
+              << " (" << ctx.distributed->name() << ")\n";
+    std::cout << "  - threads     : " << nd.num_threads << "\n";
     std::cout << "  - gpus        : " << nd.num_gpus << "\n";
     std::cout << "==========================================\n";
 }
diff --git a/include/arbor/distributed_context.hpp b/include/arbor/distributed_context.hpp
index a59846b5..4f545874 100644
--- a/include/arbor/distributed_context.hpp
+++ b/include/arbor/distributed_context.hpp
@@ -165,10 +165,10 @@ inline distributed_context::distributed_context():
 
 // MPI context creation functions only provided if built with MPI support.
 
-distributed_context mpi_context();
+std::shared_ptr<distributed_context> mpi_context();
 
 template <typename MPICommType>
-distributed_context mpi_context(MPICommType);
+std::shared_ptr<distributed_context> mpi_context(MPICommType);
 
 } // namespace arb
 
diff --git a/include/arbor/domain_decomposition.hpp b/include/arbor/domain_decomposition.hpp
index 61f1710b..2aa14eaf 100644
--- a/include/arbor/domain_decomposition.hpp
+++ b/include/arbor/domain_decomposition.hpp
@@ -6,6 +6,7 @@
 
 #include <arbor/assert.hpp>
 #include <arbor/common_types.hpp>
+#include <arbor/execution_context.hpp>
 
 namespace arb {
 
@@ -16,7 +17,7 @@ struct proc_allocation {
 };
 
 /// Determine available local domain resources.
-proc_allocation local_allocation();
+proc_allocation local_allocation(const execution_context& ctx);
 
 /// Metadata for a local cell group.
 struct group_description {
diff --git a/include/arbor/execution_context.hpp b/include/arbor/execution_context.hpp
index fac08b82..fdc9ea9f 100644
--- a/include/arbor/execution_context.hpp
+++ b/include/arbor/execution_context.hpp
@@ -3,10 +3,8 @@
 #include <memory>
 #include <string>
 
-#include <arbor/domain_decomposition.hpp>
 #include <arbor/distributed_context.hpp>
 #include <arbor/util/pp_util.hpp>
-#include <arbor/threadinfo.hpp>
 
 
 namespace arb {
@@ -14,16 +12,19 @@ namespace threading {
     class task_system;
 }
 using task_system_handle = std::shared_ptr<threading::task_system>;
+using distributed_context_handle = std::shared_ptr<distributed_context>;
+
+task_system_handle make_thread_pool();
+task_system_handle make_thread_pool(int nthreads);
 
-task_system_handle make_thread_pool (int nthreads);
 
 struct execution_context {
-    // TODO: use a shared_ptr for distributed_context
-    distributed_context distributed;
+    distributed_context_handle distributed;
     task_system_handle thread_pool;
 
-    execution_context(): thread_pool(arb::make_thread_pool(arb::num_threads())) {};
-    execution_context(proc_allocation nd): thread_pool(arb::make_thread_pool(nd.num_threads)) {};
+    execution_context(): distributed(std::make_shared<distributed_context>()),
+                         thread_pool(arb::make_thread_pool()) {}
+
 };
 
 }
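
A brief sketch of the ownership semantics the handles above provide (scopes
are illustrative): any handle copied out of an execution_context keeps the
underlying state alive after the context itself goes away.

    arb::distributed_context_handle dist;
    {
        arb::execution_context ctx;   // creates a shared distributed context
        dist = ctx.distributed;       // take a second owner
    }                                 // ctx destroyed; the context survives
    int rank = dist->id();            // handle is still valid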
diff --git a/include/arbor/load_balance.hpp b/include/arbor/load_balance.hpp
index 5866f4a8..004445cc 100644
--- a/include/arbor/load_balance.hpp
+++ b/include/arbor/load_balance.hpp
@@ -19,7 +19,7 @@ using partition_hint_map = std::unordered_map<cell_kind, partition_hint>;
 domain_decomposition partition_load_balance(
     const recipe& rec,
     proc_allocation nd,
-    const execution_context* ctx,
+    const execution_context& ctx,
     partition_hint_map hint_map = {});
 
 } // namespace arb
diff --git a/include/arbor/profile/meter_manager.hpp b/include/arbor/profile/meter_manager.hpp
index 63e90f27..454d8dff 100644
--- a/include/arbor/profile/meter_manager.hpp
+++ b/include/arbor/profile/meter_manager.hpp
@@ -4,7 +4,7 @@
 #include <string>
 #include <vector>
 
-#include <arbor/distributed_context.hpp>
+#include <arbor/execution_context.hpp>
 #include <arbor/profile/meter.hpp>
 #include <arbor/profile/timer.hpp>
 
@@ -25,7 +25,7 @@ struct measurement {
     std::string name;
     std::string units;
     std::vector<std::vector<double>> measurements;
-    measurement(std::string, std::string, const std::vector<double>&, const distributed_context*);
+    measurement(std::string, std::string, const std::vector<double>&, const distributed_context_handle&);
 };
 
 class meter_manager {
@@ -38,13 +38,13 @@ private:
     std::vector<std::unique_ptr<meter>> meters_;
     std::vector<std::string> checkpoint_names_;
 
-    const distributed_context* glob_ctx_;
+    distributed_context_handle glob_ctx_;
 
 public:
-    meter_manager(const distributed_context* ctx);
+    meter_manager(distributed_context_handle ctx);
     void start();
     void checkpoint(std::string name);
-    const distributed_context* context() const;
+    distributed_context_handle context() const;
 
     const std::vector<std::unique_ptr<meter>>& meters() const;
     const std::vector<std::string>& checkpoint_names() const;
diff --git a/include/arbor/simulation.hpp b/include/arbor/simulation.hpp
index d652108e..083409a6 100644
--- a/include/arbor/simulation.hpp
+++ b/include/arbor/simulation.hpp
@@ -22,7 +22,7 @@ class simulation_state;
 
 class simulation {
 public:
-    simulation(const recipe& rec, const domain_decomposition& decomp, const execution_context* ctx);
+    simulation(const recipe& rec, const domain_decomposition& decomp, execution_context ctx);
 
     void reset();
 
diff --git a/include/arbor/threadinfo.hpp b/include/arbor/threadinfo.hpp
deleted file mode 100644
index 95de0c14..00000000
--- a/include/arbor/threadinfo.hpp
+++ /dev/null
@@ -1,13 +0,0 @@
-#pragma once
-
-#include <string>
-
-// Query underlying threading implementation for information.
-// (Stop-gap until we virtualize threading interface.)
-
-namespace arb {
-
-int num_threads();
-std::string thread_implementation();
-
-} // namespace arb
diff --git a/test/ubench/task_system.cpp b/test/ubench/task_system.cpp
index e0a7303d..873b16ea 100644
--- a/test/ubench/task_system.cpp
+++ b/test/ubench/task_system.cpp
@@ -6,31 +6,30 @@
 #include <iostream>
 #include <thread>
 
-#include <arbor/threadinfo.hpp>
 #include <arbor/version.hpp>
 
-#include "threading/cthread.hpp"
+#include "threading/threading.hpp"
 
 #include <benchmark/benchmark.h>
 
 using namespace arb;
 
-void run(unsigned long us_per_task, unsigned tasks) {
-    arb::threading::task_system ts(arb::num_threads());
+void run(unsigned long us_per_task, unsigned tasks, threading::task_system* ts) {
     auto duration = std::chrono::microseconds(us_per_task);
     arb::threading::parallel_for::apply(
-            0, tasks, &ts,
+            0, tasks, ts,
             [&](unsigned i){std::this_thread::sleep_for(duration);});
 }
 
 void task_test(benchmark::State& state) {
     const unsigned us_per_task = state.range(0);
-    const auto nthreads = arb::num_threads();
+    arb::threading::task_system ts;
+    const auto nthreads = ts.get_num_threads();
     const unsigned us_per_s = 1000000;
     const unsigned num_tasks = nthreads*us_per_s/us_per_task;
 
     while (state.KeepRunning()) {
-        run(us_per_task, num_tasks);
+        run(us_per_task, num_tasks, &ts);
     }
 }
 
diff --git a/test/unit-distributed/distributed_listener.cpp b/test/unit-distributed/distributed_listener.cpp
index 16c3f8ce..1110dfeb 100644
--- a/test/unit-distributed/distributed_listener.cpp
+++ b/test/unit-distributed/distributed_listener.cpp
@@ -3,8 +3,6 @@
 #include <stdexcept>
 #include <string>
 
-#include <arbor/distributed_context.hpp>
-
 #include "../gtest.h"
 
 #include "distributed_listener.hpp"
@@ -27,7 +25,7 @@ distributed_listener::printer& operator<<(distributed_listener::printer& p, cons
     return p;
 }
 
-distributed_listener::distributed_listener(std::string f_base, const arb::distributed_context* ctx):
+distributed_listener::distributed_listener(std::string f_base, arb::distributed_context_handle ctx):
     context_(ctx),
     rank_(context_->id()),
     size_(context_->size()),
diff --git a/test/unit-distributed/distributed_listener.hpp b/test/unit-distributed/distributed_listener.hpp
index acf3c1a8..b94acea9 100644
--- a/test/unit-distributed/distributed_listener.hpp
+++ b/test/unit-distributed/distributed_listener.hpp
@@ -4,7 +4,7 @@
 #include <string>
 #include <utility>
 
-#include <arbor/distributed_context.hpp>
+#include <arbor/execution_context.hpp>
 
 #include "../gtest.h"
 
@@ -29,7 +29,7 @@ class distributed_listener: public testing::EmptyTestEventListener {
     using TestPartResult = testing::TestPartResult;
 
 public:
-    distributed_listener(std::string f_base, const arb::distributed_context* ctx);
+    distributed_listener(std::string f_base, arb::distributed_context_handle ctx);
 
     /// Messages that are printed at the start and end of the test program.
     /// i.e. once only.
@@ -63,7 +63,7 @@ private:
     template <typename T>
     friend printer& operator<<(printer&, const T&);
 
-    const arb::distributed_context* context_;
+    arb::distributed_context_handle context_;
     int rank_;
     int size_;
     printer emit_;
diff --git a/test/unit-distributed/test.cpp b/test/unit-distributed/test.cpp
index 4db82042..085a8990 100644
--- a/test/unit-distributed/test.cpp
+++ b/test/unit-distributed/test.cpp
@@ -6,7 +6,6 @@
 #include "../gtest.h"
 
 #include <arbor/execution_context.hpp>
-#include "arbor/threadinfo.hpp"
 
 #include <aux/ioutil.hpp>
 #include <aux/tinyopt.hpp>
@@ -28,10 +27,10 @@ const char* usage_str =
 
 int main(int argc, char **argv) {
 #ifdef TEST_MPI
-    with_mpi guard(argc, argv, false);
+    aux::with_mpi guard(argc, argv, false);
     g_context.distributed = mpi_context(MPI_COMM_WORLD);
 #elif defined(TEST_LOCAL)
-    g_context.distributed = local_context();
+    g_context.distributed = std::make_shared<distributed_context>(local_context());
 #else
 #error "define TEST_MPI or TEST_LOCAL for distributed test"
 #endif
@@ -43,7 +42,7 @@ int main(int argc, char **argv) {
     auto& listeners = testing::UnitTest::GetInstance()->listeners();
     // replace original printer with our custom printer
     delete listeners.Release(listeners.default_result_printer());
-    listeners.Append(new distributed_listener("run_"+g_context.distributed.name(), &g_context.distributed));
+    listeners.Append(new distributed_listener("run_"+g_context.distributed->name(), g_context.distributed));
 
     int return_value = 0;
     try {
@@ -85,5 +84,5 @@ int main(int argc, char **argv) {
 
     // perform global collective, to ensure that all ranks return
     // the same exit code
-    return g_context.distributed.max(return_value);
+    return g_context.distributed->max(return_value);
 }
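
Illustrative sketch (not part of the patch; assumes distributed_context_handle is the std::shared_ptr alias this patch introduces): both factory results store into the same slot:

    // local (single-process) context, wrapped explicitly:
    arb::distributed_context_handle h =
        std::make_shared<arb::distributed_context>(arb::local_context());
    // in an MPI build, mpi_context(MPI_COMM_WORLD) already returns a handle.
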
diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp
index 4e7daef3..14a5bc0b 100644
--- a/test/unit-distributed/test_communicator.cpp
+++ b/test/unit-distributed/test_communicator.cpp
@@ -4,11 +4,10 @@
 #include <stdexcept>
 #include <vector>
 
-#include <arbor/distributed_context.hpp>
 #include <arbor/domain_decomposition.hpp>
 #include <arbor/load_balance.hpp>
 #include <arbor/spike_event.hpp>
-#include <threading/cthread.hpp>
+#include <threading/threading.hpp>
 
 #include "communication/communicator.hpp"
 #include "util/filter.hpp"
@@ -23,12 +22,12 @@ static bool is_dry_run() {
 }
 
 TEST(communicator, policy_basics) {
-    const auto num_domains = g_context.distributed.size();
-    const auto rank = g_context.distributed.id();
+    const auto num_domains = g_context.distributed->size();
+    const auto rank = g_context.distributed->id();
 
-    EXPECT_EQ(g_context.distributed.min(rank), 0);
+    EXPECT_EQ(g_context.distributed->min(rank), 0);
     if (!is_dry_run()) {
-        EXPECT_EQ(g_context.distributed.max(rank), num_domains-1);
+        EXPECT_EQ(g_context.distributed->max(rank), num_domains-1);
     }
 }
 
@@ -52,8 +51,8 @@ int get_value(const arb::spike& s) {
 // Test low level spike_gather function when each domain produces the same
 // number of spikes in the pattern used by dry run mode.
 TEST(communicator, gather_spikes_equal) {
-    const auto num_domains = g_context.distributed.size();
-    const auto rank = g_context.distributed.id();
+    const auto num_domains = g_context.distributed->size();
+    const auto rank = g_context.distributed->id();
 
     const auto n_local_spikes = 10;
 
@@ -72,7 +71,7 @@ TEST(communicator, gather_spikes_equal) {
     }
 
     // Perform exchange
-    const auto global_spikes = g_context.distributed.gather_spikes(local_spikes);
+    const auto global_spikes = g_context.distributed->gather_spikes(local_spikes);
 
     // Test that partition information is correct
     const auto& part = global_spikes.partition();
@@ -92,7 +91,7 @@ TEST(communicator, gather_spikes_equal) {
     // is a list of num_domains*n_local_spikes spikes that have
     // contiguous source gid
     const auto& spikes = global_spikes.values();
-    EXPECT_EQ(n_local_spikes*g_context.distributed.size(), int(spikes.size()));
+    EXPECT_EQ(n_local_spikes*g_context.distributed->size(), int(spikes.size()));
     for (auto i=0u; i<spikes.size(); ++i) {
         const auto s = spikes[i];
         EXPECT_EQ(i, unsigned(s.source.gid));
@@ -113,8 +112,8 @@ TEST(communicator, gather_spikes_variant) {
     // number of spikes.
     if (is_dry_run()) return;
 
-    const auto num_domains = g_context.distributed.size();
-    const auto rank = g_context.distributed.id();
+    const auto num_domains = g_context.distributed->size();
+    const auto rank = g_context.distributed->id();
 
     // Parameter used to scale the number of spikes generated on successive
     // ranks.
@@ -138,7 +137,7 @@ TEST(communicator, gather_spikes_variant) {
     }
 
     // Perform exchange
-    const auto global_spikes = g_context.distributed.gather_spikes(local_spikes);
+    const auto global_spikes = g_context.distributed->gather_spikes(local_spikes);
 
     // Test that partition information is correct
     const auto& part = global_spikes.partition();
@@ -168,7 +167,7 @@ namespace {
     public:
         ring_recipe(cell_size_type s):
             size_(s),
-            ranks_(g_context.distributed.size())
+            ranks_(g_context.distributed->size())
         {}
 
         cell_size_type num_cells() const override {
@@ -232,7 +231,7 @@ namespace {
     public:
         all2all_recipe(cell_size_type s):
             size_(s),
-            ranks_(g_context.distributed.size())
+            ranks_(g_context.distributed->size())
         {}
 
         cell_size_type num_cells() const override {
@@ -315,10 +314,10 @@ test_ring(const domain_decomposition& D, communicator& C, F&& f) {
 
     // gather the global set of spikes
     auto global_spikes = C.exchange(local_spikes);
-    if (global_spikes.size()!=g_context.distributed.sum(local_spikes.size())) {
+    if (global_spikes.size()!=g_context.distributed->sum(local_spikes.size())) {
         return ::testing::AssertionFailure() << "the number of gathered spikes "
             << global_spikes.size() << " doesn't match the expected "
-            << g_context.distributed.sum(local_spikes.size());
+            << g_context.distributed->sum(local_spikes.size());
     }
 
     // generate the events
@@ -364,7 +363,7 @@ TEST(communicator, ring)
     using util::make_span;
 
     // construct a homogeneous network of 10*n_domain identical cells in a ring
-    unsigned N = g_context.distributed.size();
+    unsigned N = g_context.distributed->size();
 
     unsigned n_local = 10u;
     unsigned n_global = n_local*N;
@@ -372,8 +371,8 @@ TEST(communicator, ring)
     auto R = ring_recipe(n_global);
     // use a node decomposition that reflects the resources available
     // on the node that the test is running on, including gpus.
-    const auto D = partition_load_balance(R, local_allocation(), &g_context);
-    auto C = communicator(R, D, &g_context);
+    const auto D = partition_load_balance(R, local_allocation(g_context), g_context);
+    auto C = communicator(R, D, g_context);
 
     // every cell fires
     EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return true;}));
@@ -406,10 +405,10 @@ test_all2all(const domain_decomposition& D, communicator& C, F&& f) {
 
     // gather the global set of spikes
     auto global_spikes = C.exchange(local_spikes);
-    if (global_spikes.size()!=g_context.distributed.sum(local_spikes.size())) {
+    if (global_spikes.size()!=g_context.distributed->sum(local_spikes.size())) {
         return ::testing::AssertionFailure() << "the number of gathered spikes "
             << global_spikes.size() << " doesn't match the expected "
-            << g_context.distributed.sum(local_spikes.size());
+            << g_context.distributed->sum(local_spikes.size());
     }
 
     // generate the events
@@ -459,7 +458,7 @@ TEST(communicator, all2all)
     using util::make_span;
 
     // construct a homogeneous network of 10*n_domain identical cells connected all-to-all
-    unsigned N = g_context.distributed.size();
+    unsigned N = g_context.distributed->size();
 
     unsigned n_local = 10u;
     unsigned n_global = n_local*N;
@@ -467,8 +466,8 @@ TEST(communicator, all2all)
     auto R = all2all_recipe(n_global);
     // use a node decomposition that reflects the resources available
     // on the node that the test is running on, including gpus.
-    const auto D = partition_load_balance(R, local_allocation(), &g_context);
-    auto C = communicator(R, D, &g_context);
+    const auto D = partition_load_balance(R, local_allocation(g_context), g_context);
+    auto C = communicator(R, D, g_context);
 
     // every cell fires
     EXPECT_TRUE(test_all2all(D, C, [](cell_gid_type g){return true;}));
diff --git a/test/unit-distributed/test_domain_decomposition.cpp b/test/unit-distributed/test_domain_decomposition.cpp
index db3c1f1b..2d513c17 100644
--- a/test/unit-distributed/test_domain_decomposition.cpp
+++ b/test/unit-distributed/test_domain_decomposition.cpp
@@ -7,7 +7,6 @@
 #include <string>
 #include <vector>
 
-#include <arbor/distributed_context.hpp>
 #include <arbor/domain_decomposition.hpp>
 #include <arbor/load_balance.hpp>
 
@@ -65,8 +64,8 @@ namespace {
 }
 
 TEST(domain_decomposition, homogeneous_population) {
-    const auto N = g_context.distributed.size();
-    const auto I = g_context.distributed.id();
+    const auto N = g_context.distributed->size();
+    const auto I = g_context.distributed->id();
 
     {   // Test on a node with 1 cpu core and no gpus.
         // We assume that all cells will be put into cell groups of size 1.
@@ -77,7 +76,7 @@ TEST(domain_decomposition, homogeneous_population) {
         // 10 cells per domain
         unsigned n_local = 10;
         unsigned n_global = n_local*N;
-        const auto D = partition_load_balance(homo_recipe(n_global, dummy_cell{}), nd, &g_context);
+        const auto D = partition_load_balance(homo_recipe(n_global, dummy_cell{}), nd, g_context);
 
         EXPECT_EQ(D.num_global_cells, n_global);
         EXPECT_EQ(D.num_local_cells, n_local);
@@ -108,7 +107,7 @@ TEST(domain_decomposition, homogeneous_population) {
         // 10 cells per domain
         unsigned n_local = 10;
         unsigned n_global = n_local*N;
-        const auto D = partition_load_balance(homo_recipe(n_global, dummy_cell{}), nd, &g_context);
+        const auto D = partition_load_balance(homo_recipe(n_global, dummy_cell{}), nd, g_context);
 
         EXPECT_EQ(D.num_global_cells, n_global);
         EXPECT_EQ(D.num_local_cells, n_local);
@@ -134,8 +133,8 @@ TEST(domain_decomposition, homogeneous_population) {
 }
 
 TEST(domain_decomposition, heterogeneous_population) {
-    const auto N = g_context.distributed.size();
-    const auto I = g_context.distributed.id();
+    const auto N = g_context.distributed->size();
+    const auto I = g_context.distributed->id();
 
     {   // Test on a node with 1 cpu core and no gpus.
         // We assume that all cells will be put into cell groups of size 1.
@@ -148,7 +147,7 @@ TEST(domain_decomposition, heterogeneous_population) {
         const unsigned n_global = n_local*N;
         const unsigned n_local_grps = n_local; // 1 cell per group
         auto R = hetero_recipe(n_global);
-        const auto D = partition_load_balance(R, nd, &g_context);
+        const auto D = partition_load_balance(R, nd, g_context);
 
         EXPECT_EQ(D.num_global_cells, n_global);
         EXPECT_EQ(D.num_local_cells, n_local);
diff --git a/test/unit/test_algorithms.cpp b/test/unit/test_algorithms.cpp
index a238565e..76655b79 100644
--- a/test/unit/test_algorithms.cpp
+++ b/test/unit/test_algorithms.cpp
@@ -12,32 +12,9 @@
 
 // (Pending abstraction of threading interface)
 #include <arbor/version.hpp>
-#include "threading/cthread.hpp"
+#include "threading/threading.hpp"
 #include "common.hpp"
 
-/// tests the sort implementation in threading
-/// Not parallel
-TEST(algorithms, parallel_sort)
-{
-    auto n = 10000;
-    std::vector<int> v(n);
-    std::iota(v.begin(), v.end(), 1);
-
-    // initialize with the default random seed
-    std::shuffle(v.begin(), v.end(), std::mt19937());
-
-    // assert that the original vector has in fact been permuted
-    EXPECT_FALSE(std::is_sorted(v.begin(), v.end()));
-
-    arb::threading::sort(v);
-
-    EXPECT_TRUE(std::is_sorted(v.begin(), v.end()));
-    for(auto i=0; i<n; ++i) {
-       EXPECT_EQ(i+1, v[i]);
-   }
-}
-
-
 TEST(algorithms, sum)
 {
     // sum of 10 times 2 is 20
diff --git a/test/unit/test_domain_decomposition.cpp b/test/unit/test_domain_decomposition.cpp
index 1e630d97..ac8b4a2d 100644
--- a/test/unit/test_domain_decomposition.cpp
+++ b/test/unit/test_domain_decomposition.cpp
@@ -57,7 +57,7 @@ TEST(domain_decomposition, homogenous_population)
         proc_allocation nd{1, 0};
 
         unsigned num_cells = 10;
-        const auto D = partition_load_balance(homo_recipe(num_cells, dummy_cell{}), nd, &context);
+        const auto D = partition_load_balance(homo_recipe(num_cells, dummy_cell{}), nd, context);
 
         EXPECT_EQ(D.num_global_cells, num_cells);
         EXPECT_EQ(D.num_local_cells, num_cells);
@@ -83,7 +83,7 @@ TEST(domain_decomposition, homogenous_population)
         proc_allocation nd{1, 1};
 
         unsigned num_cells = 10;
-        const auto D = partition_load_balance(homo_recipe(num_cells, dummy_cell{}), nd, &context);
+        const auto D = partition_load_balance(homo_recipe(num_cells, dummy_cell{}), nd, context);
 
         EXPECT_EQ(D.num_global_cells, num_cells);
         EXPECT_EQ(D.num_local_cells, num_cells);
@@ -118,7 +118,7 @@ TEST(domain_decomposition, heterogenous_population)
 
         unsigned num_cells = 10;
         auto R = hetero_recipe(num_cells);
-        const auto D = partition_load_balance(R, nd, &context);
+        const auto D = partition_load_balance(R, nd, context);
 
         EXPECT_EQ(D.num_global_cells, num_cells);
         EXPECT_EQ(D.num_local_cells, num_cells);
@@ -156,7 +156,7 @@ TEST(domain_decomposition, heterogenous_population)
 
         unsigned num_cells = 10;
         auto R = hetero_recipe(num_cells);
-        const auto D = partition_load_balance(R, nd, &context);
+        const auto D = partition_load_balance(R, nd, context);
 
         EXPECT_EQ(D.num_global_cells, num_cells);
         EXPECT_EQ(D.num_local_cells, num_cells);
@@ -203,7 +203,7 @@ TEST(domain_decomposition, hints) {
     domain_decomposition D = partition_load_balance(
         hetero_recipe(20),
         proc_allocation{16, 1}, // 16 threads, 1 gpu.
-        &context,
+        context,
         hints);
 
     std::vector<std::vector<cell_gid_type>> expected_c1d_groups =
diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp
index b93d8195..d6631dde 100644
--- a/test/unit/test_fvm_lowered.cpp
+++ b/test/unit/test_fvm_lowered.cpp
@@ -329,8 +329,8 @@ TEST(fvm_lowered, derived_mechs) {
         float times[] = {10.f, 20.f};
 
         execution_context context;
-        auto decomp = partition_load_balance(rec, proc_allocation{1, 0}, &context);
-        simulation sim(rec, decomp, &context);
+        auto decomp = partition_load_balance(rec, proc_allocation{1, 0}, context);
+        simulation sim(rec, decomp, context);
         sim.add_sampler(all_probes, explicit_schedule(times), sampler);
         sim.run(30.0, 1.f/1024);
 
diff --git a/test/unit/test_lif_cell_group.cpp b/test/unit/test_lif_cell_group.cpp
index 25297e25..51283529 100644
--- a/test/unit/test_lif_cell_group.cpp
+++ b/test/unit/test_lif_cell_group.cpp
@@ -1,10 +1,8 @@
 #include "../gtest.h"
 
-#include <arbor/distributed_context.hpp>
 #include <arbor/domain_decomposition.hpp>
 #include <arbor/lif_cell.hpp>
 #include <arbor/load_balance.hpp>
-#include <arbor/threadinfo.hpp>
 #include <arbor/recipe.hpp>
 #include <arbor/schedule.hpp>
 #include <arbor/simulation.hpp>
@@ -157,10 +155,10 @@ TEST(lif_cell_group, spikes) {
     path_recipe recipe(2, 1000, 0.1);
 
     execution_context context;
-    proc_allocation nd = local_allocation();
+    proc_allocation nd = local_allocation(context);
 
-    auto decomp = partition_load_balance(recipe, nd, &context);
-    simulation sim(recipe, decomp, &context);
+    auto decomp = partition_load_balance(recipe, nd, context);
+    simulation sim(recipe, decomp, context);
 
     std::vector<spike_event> events;
 
@@ -196,12 +194,12 @@ TEST(lif_cell_group, ring)
     time_type simulation_time = 100;
 
     execution_context context;
-    proc_allocation nd = local_allocation();
+    proc_allocation nd = local_allocation(context);
     auto recipe = ring_recipe(num_lif_cells, weight, delay);
-    auto decomp = partition_load_balance(recipe, nd, &context);
+    auto decomp = partition_load_balance(recipe, nd, context);
 
     // Creates a simulation with a ring recipe of lif neurons
-    simulation sim(recipe, decomp, &context);
+    simulation sim(recipe, decomp, context);
 
     std::vector<spike> spike_buffer;
 
diff --git a/test/unit/test_thread.cpp b/test/unit/test_thread.cpp
index 298e4c51..ea3c0b4e 100644
--- a/test/unit/test_thread.cpp
+++ b/test/unit/test_thread.cpp
@@ -1,6 +1,5 @@
 #include "../gtest.h"
 #include "common.hpp"
-#include <arbor/threadinfo.hpp>
 #include <arbor/execution_context.hpp>
 
 #include <iostream>
@@ -8,7 +7,7 @@
 // (Pending abstraction of threading interface)
 #include <arbor/version.hpp>
 
-#include "threading/cthread.hpp"
+#include "threading/threading.hpp"
 
 using namespace arb::threading::impl;
 using namespace arb::threading;
@@ -53,7 +52,7 @@ struct ftor_parallel_wait {
     ftor_parallel_wait(task_system* ts): ts{ts} {}
 
     void operator()() const {
-        auto nthreads = num_threads();
+        auto nthreads = ts->get_num_threads();
         auto duration = std::chrono::microseconds(100);
         parallel_for::apply(0, nthreads, ts, [=](int i){ std::this_thread::sleep_for(duration);});
     }
@@ -64,7 +63,7 @@ struct ftor_parallel_wait {
 }
 
 TEST(task_system, test_copy) {
-    task_system ts(num_threads());
+    task_system ts;
 
     ftor f;
     ts.async(f);
@@ -76,7 +75,7 @@ TEST(task_system, test_copy) {
 }
 
 TEST(task_system, test_move) {
-    task_system ts(num_threads());
+    task_system ts;
 
     ftor f;
     ts.async(std::move(f));
@@ -112,7 +111,7 @@ TEST(notification_queue, test_move) {
 }
 
 TEST(task_group, test_copy) {
-    task_system ts(num_threads());
+    task_system ts;
     task_group g(&ts);
 
     ftor f;
@@ -126,7 +125,7 @@ TEST(task_group, test_copy) {
 }
 
 TEST(task_group, test_move) {
-    task_system ts(num_threads());
+    task_system ts;
     task_group g(&ts);
 
     ftor f;
@@ -141,10 +140,10 @@ TEST(task_group, test_move) {
 
 TEST(task_group, individual_tasks) {
     // Simple check for deadlock
-    task_system ts(num_threads());
+    task_system ts;
     task_group g(&ts);
 
-    auto nthreads = num_threads();
+    auto nthreads = ts.get_num_threads();
 
     ftor_wait f;
     for (int i = 0; i < 32 * nthreads; i++) {
@@ -155,8 +154,8 @@ TEST(task_group, individual_tasks) {
 
 TEST(task_group, parallel_for_sleep) {
     // Simple check for deadlock for nested parallelism
-    auto nthreads = num_threads();
-    task_system ts(nthreads);
+    task_system ts;
+    auto nthreads = ts.get_num_threads();
     task_group g(&ts);
 
     ftor_parallel_wait f(&ts);
@@ -167,7 +166,7 @@ TEST(task_group, parallel_for_sleep) {
 }
 
 TEST(task_group, parallel_for) {
-    task_system ts(num_threads());
+    task_system ts;
     for (int n = 0; n < 10000; n=!n?1:2*n) {
         std::vector<int> v(n, -1);
         parallel_for::apply(0, n, &ts, [&](int i) {v[i] = i;});
@@ -178,7 +177,7 @@ TEST(task_group, parallel_for) {
 }
 
 TEST(task_group, nested_parallel_for) {
-    task_system ts(num_threads());
+    task_system ts;
     for (int m = 1; m < 512; m*=2) {
         for (int n = 0; n < 1000; n=!n?1:2*n) {
             std::vector<std::vector<int>> v(n, std::vector<int>(m, -1));
@@ -196,7 +195,7 @@ TEST(task_group, nested_parallel_for) {
 }
 
 TEST(enumerable_thread_specific, test) {
-    task_system_handle ts = task_system_handle(new task_system(num_threads()));
+    task_system_handle ts = task_system_handle(new task_system);
     enumerable_thread_specific<int> buffers(ts);
     task_group g(ts.get());
 
diff --git a/test/validation/validate_ball_and_stick.cpp b/test/validation/validate_ball_and_stick.cpp
index 95257ab7..8d3691d2 100644
--- a/test/validation/validate_ball_and_stick.cpp
+++ b/test/validation/validate_ball_and_stick.cpp
@@ -79,8 +79,8 @@ void run_ncomp_convergence_test(
             rec.add_probe(0, 0, cell_probe_address{p.where, cell_probe_address::membrane_voltage});
         }
 
-        auto decomp = partition_load_balance(rec, nd, &context);
-        simulation sim(rec, decomp, &context);
+        auto decomp = partition_load_balance(rec, nd, context);
+        simulation sim(rec, decomp, context);
 
         runner.run(sim, ncomp, sample_dt, t_end, dt, exclude);
     }
@@ -195,36 +195,41 @@ void validate_ball_and_squiggle(arb::backend_kind backend) {
 }
 
 TEST(ball_and_stick, neuron_ref) {
+    execution_context ctx;
     validate_ball_and_stick(backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         validate_ball_and_stick(backend_kind::gpu);
     }
 }
 
 TEST(ball_and_taper, neuron_ref) {
+    execution_context ctx;
     validate_ball_and_taper(backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         validate_ball_and_taper(backend_kind::gpu);
     }
 }
 
 TEST(ball_and_3stick, neuron_ref) {
+    execution_context ctx;
     validate_ball_and_3stick(backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         validate_ball_and_3stick(backend_kind::gpu);
     }
 }
 
 TEST(rallpack1, numeric_ref) {
+    execution_context ctx;
     validate_rallpack1(backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         validate_rallpack1(backend_kind::gpu);
     }
 }
 
 TEST(ball_and_squiggle, neuron_ref) {
+    execution_context ctx;
     validate_ball_and_squiggle(backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         validate_ball_and_squiggle(backend_kind::gpu);
     }
 }
diff --git a/test/validation/validate_kinetic.cpp b/test/validation/validate_kinetic.cpp
index c2994eb3..b8bfb310 100644
--- a/test/validation/validate_kinetic.cpp
+++ b/test/validation/validate_kinetic.cpp
@@ -47,8 +47,8 @@ void run_kinetic_dt(
     proc_allocation nd;
     nd.num_gpus = (backend==backend_kind::gpu);
 
-    auto decomp = partition_load_balance(rec, nd, &context);
-    simulation sim(rec, decomp, &context);
+    auto decomp = partition_load_balance(rec, nd, context);
+    simulation sim(rec, decomp, context);
 
     auto exclude = stimulus_ends(c);
 
@@ -111,15 +111,17 @@ void validate_kinetic_kinlva(arb::backend_kind backend) {
 using namespace arb;
 
 TEST(kinetic, kin1_numeric_ref) {
+    execution_context ctx;
     validate_kinetic_kin1(backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         validate_kinetic_kin1(arb::backend_kind::gpu);
     }
 }
 
 TEST(kinetic, kinlva_numeric_ref) {
+    execution_context ctx;
     validate_kinetic_kinlva(backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         validate_kinetic_kinlva(arb::backend_kind::gpu);
     }
 }
diff --git a/test/validation/validate_soma.cpp b/test/validation/validate_soma.cpp
index a3e6460f..f61e3ab1 100644
--- a/test/validation/validate_soma.cpp
+++ b/test/validation/validate_soma.cpp
@@ -33,8 +33,8 @@ void validate_soma(backend_kind backend) {
     proc_allocation nd;
     nd.num_gpus = (backend==backend_kind::gpu);
 
-    auto decomp = partition_load_balance(rec, nd, &context);
-    simulation sim(rec, decomp, &context);
+    auto decomp = partition_load_balance(rec, nd, context);
+    simulation sim(rec, decomp, context);
 
     nlohmann::json meta = {
         {"name", "membrane voltage"},
@@ -68,8 +68,9 @@ end:
 }
 
 TEST(soma, numeric_ref) {
+    execution_context ctx;
     validate_soma(backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         validate_soma(backend_kind::gpu);
     }
 }
diff --git a/test/validation/validate_synapses.cpp b/test/validation/validate_synapses.cpp
index 57c926c2..cd19a08d 100644
--- a/test/validation/validate_synapses.cpp
+++ b/test/validation/validate_synapses.cpp
@@ -76,8 +76,8 @@ void run_synapse_test(
         // dend.end
         rec.add_probe(0, 0, cell_probe_address{{1, 1.0}, cell_probe_address::membrane_voltage});
 
-        auto decomp = partition_load_balance(rec, nd, &context);
-        simulation sim(rec, decomp, &context);
+        auto decomp = partition_load_balance(rec, nd, context);
+        simulation sim(rec, decomp, context);
 
         sim.inject_events(synthetic_events);
 
@@ -88,18 +88,20 @@ void run_synapse_test(
 }
 
 TEST(simple_synapse, expsyn_neuron_ref) {
+    execution_context ctx;
     SCOPED_TRACE("expsyn-multicore");
     run_synapse_test("expsyn", "neuron_simple_exp_synapse.json", backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         SCOPED_TRACE("expsyn-gpu");
         run_synapse_test("expsyn", "neuron_simple_exp_synapse.json", backend_kind::gpu);
     }
 }
 
 TEST(simple_synapse, exp2syn_neuron_ref) {
+    execution_context ctx;
     SCOPED_TRACE("exp2syn-multicore");
     run_synapse_test("exp2syn", "neuron_simple_exp2_synapse.json", backend_kind::multicore);
-    if (local_allocation().num_gpus) {
+    if (local_allocation(ctx).num_gpus) {
         SCOPED_TRACE("exp2syn-gpu");
         run_synapse_test("exp2syn", "neuron_simple_exp2_synapse.json", backend_kind::gpu);
     }
-- 
GitLab