diff --git a/arbor/threading/threading.hpp b/arbor/threading/threading.hpp
index 5c9d8ca0f203d87a97d888cc66e73d1e22960eaa..6e5ab9840ae67f4035d08c27b629b579d45d8365 100644
--- a/arbor/threading/threading.hpp
+++ b/arbor/threading/threading.hpp
@@ -97,10 +97,38 @@ public:
 
 class task_group {
 private:
+    struct exception_state {
+        std::atomic<bool> error_{false};
+        std::exception_ptr exception_;
+        std::mutex mutex_;
+
+        operator bool() const {
+            return error_.load(std::memory_order_relaxed);
+        }
+
+        void set(std::exception_ptr ex) {
+            error_.store(true, std::memory_order_relaxed);
+            lock ex_lock{mutex_};
+            exception_ = std::move(ex);
+        }
+
+        std::exception_ptr get() {
+            error_.store(false, std::memory_order_relaxed);
+            return std::move(exception_);
+        }
+
+        ~exception_state() {
+            if(error_) {
+                std::rethrow_exception(exception_);
+            }
+        }
+    };
+
     std::atomic<std::size_t> in_flight_{0};
     /// We use a raw pointer here instead of a shared_ptr to avoid a race condition
     /// on the destruction of a task_system that would lead to a thread trying to join itself
     task_system* task_system_;
+    exception_state exception_status_;
 
 public:
     task_group(task_system* ts):
@@ -112,33 +140,44 @@ public:
 
     template <typename F>
     class wrap {
-        F f;
-        std::atomic<std::size_t>& counter;
+        F f_;
+        std::atomic<std::size_t>& counter_;
+        exception_state& exception_status_;
 
     public:
 
         // Construct from a compatible function and atomic counter
         template <typename F2>
-        explicit wrap(F2&& other, std::atomic<std::size_t>& c):
-                f(std::forward<F2>(other)),
-                counter(c)
+        explicit wrap(F2&& other, std::atomic<std::size_t>& c, exception_state& ex):
+                f_(std::forward<F2>(other)),
+                counter_(c),
+                exception_status_(ex)
         {}
 
         wrap(wrap&& other):
-                f(std::move(other.f)),
-                counter(other.counter)
+                f_(std::move(other.f_)),
+                counter_(other.counter_),
+                exception_status_(other.exception_status_)
         {}
 
         // std::function is not guaranteed to not copy the contents on move construction
         // But the class is safe because we don't call operator() more than once on the same wrapped task
         wrap(const wrap& other):
-                f(other.f),
-                counter(other.counter)
+                f_(other.f_),
+                counter_(other.counter_),
+                exception_status_(other.exception_status_)
         {}
 
         void operator()() {
-            f();
-            --counter;
+            if(!exception_status_) {
+                try {
+                    f_();
+                }
+                catch (...) {
+                    exception_status_.set(std::current_exception());
+                }
+            }
+            --counter_;
         }
     };
 
@@ -146,14 +185,14 @@ public:
     using callable = typename std::decay<F>::type;
 
     template <typename F>
-    wrap<callable<F>> make_wrapped_function(F&& f, std::atomic<std::size_t>& c) {
-        return wrap<callable<F>>(std::forward<F>(f), c);
+    wrap<callable<F>> make_wrapped_function(F&& f, std::atomic<std::size_t>& c, exception_state& ex) {
+        return wrap<callable<F>>(std::forward<F>(f), c, ex);
     }
 
     template<typename F>
     void run(F&& f) {
         ++in_flight_;
-        task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_));
+        task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_, exception_status_));
     }
 
     // wait till all tasks in this group are done
@@ -161,6 +200,14 @@ public:
         while (in_flight_) {
             task_system_->try_run_task();
         }
+        if(auto ex = exception_status_.get()) {
+            try {
+                std::rethrow_exception(ex);
+            }
+            catch (...) {
+                throw;
+            }
+        }
     }
 
     // Make sure that all tasks are done before clean up
diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt
index 28fb4002e23d331e40d0efd33b9267ec5c324274..4cacc7cf9ef1f0576f4b5982e220adf76d037dc7 100644
--- a/test/unit/CMakeLists.txt
+++ b/test/unit/CMakeLists.txt
@@ -104,6 +104,7 @@ set(unit_sources
     test_swcio.cpp
     test_synapses.cpp
     test_thread.cpp
+    test_threading_exceptions.cpp
     test_tree.cpp
     test_transform.cpp
     test_uninitialized.cpp
diff --git a/test/unit/test_threading_exceptions.cpp b/test/unit/test_threading_exceptions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..273210812bdc0c2d5f0a545bc5c82fd9fd3a49aa
--- /dev/null
+++ b/test/unit/test_threading_exceptions.cpp
@@ -0,0 +1,270 @@
+#include "../gtest.h"
+
+#include <arbor/domain_decomposition.hpp>
+#include <arbor/load_balance.hpp>
+#include <arbor/recipe.hpp>
+#include <arbor/simulation.hpp>
+#include <arbor/context.hpp>
+#include <arbor/mc_cell.hpp>
+#include <arbor/version.hpp>
+
+#include "threading/threading.hpp"
+
+using namespace arb::threading::impl;
+using namespace arb::threading;
+using namespace arb;
+
+auto duration = std::chrono::nanoseconds(1);
+
+struct error {
+    int code;
+    error(int c): code(c) {};
+};
+
+TEST(test_exception, single_task_no_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        task_group g(&ts);
+        g.run([](){ std::this_thread::sleep_for(duration); });
+        EXPECT_NO_THROW(g.wait());
+    }
+}
+
+TEST(test_exception, single_task_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(1);
+        task_group g(&ts);
+        g.run([](){ throw error(0);});
+        try {
+            g.wait();
+            FAIL() << "Expected exception";
+        }
+        catch (error &e) {
+            EXPECT_EQ(e.code, 0);
+        }
+        catch (...) {
+            FAIL() << "Expected error type";
+        }
+    }
+}
+
+TEST(test_exception, many_tasks_no_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        task_group g(&ts);
+        for (int i = 1; i < 100; i++) {
+            for (int j = 0; j < i; j++) {
+                g.run([j](){ std::this_thread::sleep_for(duration); });
+            }
+            EXPECT_NO_THROW(g.wait());
+        }
+    }
+}
+
+TEST(test_exception, many_tasks_one_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        task_group g(&ts);
+        for (int i = 1; i < 100; i++) {
+            for (int j = 0; j < i; j++) {
+                g.run([j](){ if(j==0) {throw error(j);} });
+            }
+            try {
+                g.wait();
+                FAIL() << "Expected exception";
+            }
+            catch (error &e) {
+                EXPECT_EQ(e.code, 0);
+            }
+            catch (...) {
+                FAIL() << "Expected error type";
+            }
+        }
+    }
+}
+
+TEST(test_exception, many_tasks_many_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        task_group g(&ts);
+        for (int i = 1; i < 100; i++) {
+            for (int j = 0; j < i; j++) {
+                g.run([j](){ if(j%5 == 0) {throw error(j);} });
+            }
+            try {
+                g.wait();
+                FAIL() << "Expected exception";
+            }
+            catch (error &e) {
+                EXPECT_EQ(e.code%5, 0);
+            }
+            catch (...) {
+                FAIL() << "Expected error type";
+            }
+        }
+    }
+}
+
+TEST(test_exception, many_tasks_all_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        task_group g(&ts);
+        for (int i = 1; i < 100; i++) {
+            for (int j = 0; j < i; j++) {
+                g.run([j](){ throw error(j); });
+            }
+            try {
+                g.wait();
+                FAIL() << "Expected exception";
+            }
+            catch (error &e) {
+                EXPECT_TRUE((e.code >= 0) && (e.code < i));
+            }
+            catch (...) {
+                FAIL() << "Expected error type";
+            }
+        }
+    }
+}
+
+TEST(test_exception, parallel_for_no_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        for (int n = 1; n < 100; n*=2) {
+            EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [n](int i) { std::this_thread::sleep_for(duration); }));
+        }
+    }
+}
+
+TEST(test_exception, parallel_for_one_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        for (int n = 1; n < 100; n*=2) {
+            try {
+                parallel_for::apply(0, n, &ts, [n](int i) { if(i==n-1) {throw error(i);} });
+                FAIL() << "Expected exception";
+            }
+            catch (error &e) {
+                EXPECT_TRUE(e.code == n-1);
+            }
+            catch (...) {
+                FAIL() << "Expected error type";
+            }
+        }
+    }
+}
+
+TEST(test_exception, parallel_for_many_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        for (int n = 1; n < 100; n*=2) {
+            try {
+                parallel_for::apply(0, n, &ts, [n](int i) { if(i%7 == 0) {throw error(i);} });
+                FAIL() << "Expected exception";
+            }
+            catch (error &e) {
+                EXPECT_EQ(e.code%7, 0);
+            }
+            catch (...) {
+                FAIL() << "Expected error type";
+            }
+        }
+    }
+}
+
+TEST(test_exception, parallel_for_all_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        for (int n = 1; n < 100; n*=2) {
+            try {
+                parallel_for::apply(0, n, &ts, [n](int i) { throw error(i); });
+                FAIL() << "Expected exception";
+            }
+            catch (error &e) {
+                EXPECT_TRUE((e.code >= 0) && (e.code < n));
+            }
+            catch (...) {
+                FAIL() << "Expected error type";
+            }
+        }
+    }
+}
+
+TEST(test_exception, nested_parallel_for_no_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        for (int m = 1; m < 50; m*=2) {
+            for (int n = 1; n < 50; n = !n?1:2*n) {
+                EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [&](int i) {
+                    parallel_for::apply(0, m, &ts, [](int j) { std::this_thread::sleep_for(duration); });
+                }));
+            }
+        }
+    }
+}
+
+TEST(test_exception, nested_parallel_for_one_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        for (int m = 1; m < 100; m*=2) {
+            for (int n = 1; n < 100; n = !n?1:2*n) {
+                try {
+                    parallel_for::apply(0, n, &ts, [&](int i) {
+                        parallel_for::apply(0, m, &ts, [](int j) { if (j == 0) {throw error(j);} });
+                    });
+                    FAIL() << "Expected exception";
+                }
+                catch (error &e) {
+                    EXPECT_EQ(e.code, 0);
+                }
+                catch (...) {
+                    FAIL() << "Expected error type";
+                }
+            }
+        }
+    }
+}
+
+TEST(test_exception, nested_parallel_for_many_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        for (int m = 1; m < 100; m*=2) {
+            for (int n = 1; n < 100; n = !n?1:2*n) {
+                try {
+                    parallel_for::apply(0, n, &ts, [&](int i) {
+                            parallel_for::apply(0, m, &ts, [](int j) { if (j%10 == 0) {throw error(j);} });
+                    });
+                    FAIL() << "Expected exception";
+                }
+                catch (error &e) {
+                    EXPECT_TRUE(e.code%10==0);
+                }
+                catch (...) {
+                    FAIL() << "Expected error type";
+                }
+            }
+        }
+    }
+}
+
+TEST(test_exception, nested_parallel_for_all_throw) {
+    for (int nthreads = 1; nthreads < 20; nthreads*=2) {
+        task_system ts(nthreads);
+        for (int m = 1; m < 100; m*=2) {
+            for (int n = 1; n < 100; n = !n?1:2*n) {
+                try {
+                    parallel_for::apply(0, n, &ts, [&](int i) {
+                        parallel_for::apply(0, m, &ts, [](int j) { throw error(j); });
+                    });
+                    FAIL() << "Expected exception";
+                }
+                catch (error &e) {
+                    EXPECT_TRUE(e.code >= 0 && e.code < m);
+                }
+                catch (...) {
+                    FAIL() << "Expected error type";
+                }
+            }
+        }
+    }
+}