Skip to content
Snippets Groups Projects
Commit b5662870 authored by noraabiakar's avatar noraabiakar Committed by Benjamin Cumming
Browse files

Threading exceptions (#595)

Propagate exceptions generated in `task_group` tasks on different threads in the threading backend, so that they are thrown on the main thread on `task_group.wait()`.

Add tests that verify that exceptions are propagated correctly.

Fixes #310.
parent ad26b114
Branches
Tags
No related merge requests found
...@@ -97,10 +97,38 @@ public: ...@@ -97,10 +97,38 @@ public:
class task_group { class task_group {
private: private:
struct exception_state {
std::atomic<bool> error_{false};
std::exception_ptr exception_;
std::mutex mutex_;
operator bool() const {
return error_.load(std::memory_order_relaxed);
}
void set(std::exception_ptr ex) {
error_.store(true, std::memory_order_relaxed);
lock ex_lock{mutex_};
exception_ = std::move(ex);
}
std::exception_ptr get() {
error_.store(false, std::memory_order_relaxed);
return std::move(exception_);
}
~exception_state() {
if(error_) {
std::rethrow_exception(exception_);
}
}
};
std::atomic<std::size_t> in_flight_{0}; std::atomic<std::size_t> in_flight_{0};
/// We use a raw pointer here instead of a shared_ptr to avoid a race condition /// We use a raw pointer here instead of a shared_ptr to avoid a race condition
/// on the destruction of a task_system that would lead to a thread trying to join itself /// on the destruction of a task_system that would lead to a thread trying to join itself
task_system* task_system_; task_system* task_system_;
exception_state exception_status_;
public: public:
task_group(task_system* ts): task_group(task_system* ts):
...@@ -112,33 +140,44 @@ public: ...@@ -112,33 +140,44 @@ public:
template <typename F> template <typename F>
class wrap { class wrap {
F f; F f_;
std::atomic<std::size_t>& counter; std::atomic<std::size_t>& counter_;
exception_state& exception_status_;
public: public:
// Construct from a compatible function and atomic counter // Construct from a compatible function and atomic counter
template <typename F2> template <typename F2>
explicit wrap(F2&& other, std::atomic<std::size_t>& c): explicit wrap(F2&& other, std::atomic<std::size_t>& c, exception_state& ex):
f(std::forward<F2>(other)), f_(std::forward<F2>(other)),
counter(c) counter_(c),
exception_status_(ex)
{} {}
wrap(wrap&& other): wrap(wrap&& other):
f(std::move(other.f)), f_(std::move(other.f_)),
counter(other.counter) counter_(other.counter_),
exception_status_(other.exception_status_)
{} {}
// std::function is not guaranteed to not copy the contents on move construction // std::function is not guaranteed to not copy the contents on move construction
// But the class is safe because we don't call operator() more than once on the same wrapped task // But the class is safe because we don't call operator() more than once on the same wrapped task
wrap(const wrap& other): wrap(const wrap& other):
f(other.f), f_(other.f_),
counter(other.counter) counter_(other.counter_),
exception_status_(other.exception_status_)
{} {}
void operator()() { void operator()() {
f(); if(!exception_status_) {
--counter; try {
f_();
}
catch (...) {
exception_status_.set(std::current_exception());
}
}
--counter_;
} }
}; };
...@@ -146,14 +185,14 @@ public: ...@@ -146,14 +185,14 @@ public:
using callable = typename std::decay<F>::type; using callable = typename std::decay<F>::type;
template <typename F> template <typename F>
wrap<callable<F>> make_wrapped_function(F&& f, std::atomic<std::size_t>& c) { wrap<callable<F>> make_wrapped_function(F&& f, std::atomic<std::size_t>& c, exception_state& ex) {
return wrap<callable<F>>(std::forward<F>(f), c); return wrap<callable<F>>(std::forward<F>(f), c, ex);
} }
template<typename F> template<typename F>
void run(F&& f) { void run(F&& f) {
++in_flight_; ++in_flight_;
task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_)); task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_, exception_status_));
} }
// wait till all tasks in this group are done // wait till all tasks in this group are done
...@@ -161,6 +200,14 @@ public: ...@@ -161,6 +200,14 @@ public:
while (in_flight_) { while (in_flight_) {
task_system_->try_run_task(); task_system_->try_run_task();
} }
if(auto ex = exception_status_.get()) {
try {
std::rethrow_exception(ex);
}
catch (...) {
throw;
}
}
} }
// Make sure that all tasks are done before clean up // Make sure that all tasks are done before clean up
......
...@@ -104,6 +104,7 @@ set(unit_sources ...@@ -104,6 +104,7 @@ set(unit_sources
test_swcio.cpp test_swcio.cpp
test_synapses.cpp test_synapses.cpp
test_thread.cpp test_thread.cpp
test_threading_exceptions.cpp
test_tree.cpp test_tree.cpp
test_transform.cpp test_transform.cpp
test_uninitialized.cpp test_uninitialized.cpp
......
#include "../gtest.h"
#include <arbor/domain_decomposition.hpp>
#include <arbor/load_balance.hpp>
#include <arbor/recipe.hpp>
#include <arbor/simulation.hpp>
#include <arbor/context.hpp>
#include <arbor/mc_cell.hpp>
#include <arbor/version.hpp>
#include "threading/threading.hpp"
using namespace arb::threading::impl;
using namespace arb::threading;
using namespace arb;
auto duration = std::chrono::nanoseconds(1);
struct error {
int code;
error(int c): code(c) {};
};
TEST(test_exception, single_task_no_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
task_group g(&ts);
g.run([](){ std::this_thread::sleep_for(duration); });
EXPECT_NO_THROW(g.wait());
}
}
TEST(test_exception, single_task_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(1);
task_group g(&ts);
g.run([](){ throw error(0);});
try {
g.wait();
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_EQ(e.code, 0);
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
TEST(test_exception, many_tasks_no_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
task_group g(&ts);
for (int i = 1; i < 100; i++) {
for (int j = 0; j < i; j++) {
g.run([j](){ std::this_thread::sleep_for(duration); });
}
EXPECT_NO_THROW(g.wait());
}
}
}
TEST(test_exception, many_tasks_one_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
task_group g(&ts);
for (int i = 1; i < 100; i++) {
for (int j = 0; j < i; j++) {
g.run([j](){ if(j==0) {throw error(j);} });
}
try {
g.wait();
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_EQ(e.code, 0);
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
TEST(test_exception, many_tasks_many_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
task_group g(&ts);
for (int i = 1; i < 100; i++) {
for (int j = 0; j < i; j++) {
g.run([j](){ if(j%5 == 0) {throw error(j);} });
}
try {
g.wait();
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_EQ(e.code%5, 0);
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
TEST(test_exception, many_tasks_all_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
task_group g(&ts);
for (int i = 1; i < 100; i++) {
for (int j = 0; j < i; j++) {
g.run([j](){ throw error(j); });
}
try {
g.wait();
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_TRUE((e.code >= 0) && (e.code < i));
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
TEST(test_exception, parallel_for_no_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
for (int n = 1; n < 100; n*=2) {
EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [n](int i) { std::this_thread::sleep_for(duration); }));
}
}
}
TEST(test_exception, parallel_for_one_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
for (int n = 1; n < 100; n*=2) {
try {
parallel_for::apply(0, n, &ts, [n](int i) { if(i==n-1) {throw error(i);} });
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_TRUE(e.code == n-1);
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
TEST(test_exception, parallel_for_many_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
for (int n = 1; n < 100; n*=2) {
try {
parallel_for::apply(0, n, &ts, [n](int i) { if(i%7 == 0) {throw error(i);} });
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_EQ(e.code%7, 0);
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
TEST(test_exception, parallel_for_all_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
for (int n = 1; n < 100; n*=2) {
try {
parallel_for::apply(0, n, &ts, [n](int i) { throw error(i); });
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_TRUE((e.code >= 0) && (e.code < n));
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
TEST(test_exception, nested_parallel_for_no_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
for (int m = 1; m < 50; m*=2) {
for (int n = 1; n < 50; n = !n?1:2*n) {
EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [&](int i) {
parallel_for::apply(0, m, &ts, [](int j) { std::this_thread::sleep_for(duration); });
}));
}
}
}
}
TEST(test_exception, nested_parallel_for_one_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
for (int m = 1; m < 100; m*=2) {
for (int n = 1; n < 100; n = !n?1:2*n) {
try {
parallel_for::apply(0, n, &ts, [&](int i) {
parallel_for::apply(0, m, &ts, [](int j) { if (j == 0) {throw error(j);} });
});
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_EQ(e.code, 0);
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
}
TEST(test_exception, nested_parallel_for_many_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
for (int m = 1; m < 100; m*=2) {
for (int n = 1; n < 100; n = !n?1:2*n) {
try {
parallel_for::apply(0, n, &ts, [&](int i) {
parallel_for::apply(0, m, &ts, [](int j) { if (j%10 == 0) {throw error(j);} });
});
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_TRUE(e.code%10==0);
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
}
TEST(test_exception, nested_parallel_for_all_throw) {
for (int nthreads = 1; nthreads < 20; nthreads*=2) {
task_system ts(nthreads);
for (int m = 1; m < 100; m*=2) {
for (int n = 1; n < 100; n = !n?1:2*n) {
try {
parallel_for::apply(0, n, &ts, [&](int i) {
parallel_for::apply(0, m, &ts, [](int j) { throw error(j); });
});
FAIL() << "Expected exception";
}
catch (error &e) {
EXPECT_TRUE(e.code >= 0 && e.code < m);
}
catch (...) {
FAIL() << "Expected error type";
}
}
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment