From b5662870ff0eac66daed24e87611a726dc27944c Mon Sep 17 00:00:00 2001 From: noraabiakar <nora.abiakar@gmail.com> Date: Wed, 26 Sep 2018 09:23:46 +0200 Subject: [PATCH] Threading exceptions (#595) Propagate exceptions generated in `task_group` tasks on different threads in the threading backend, so that they are thrown on the main thread on `task_group.wait()`. Add tests that verify that exceptions are propagated correctly. Fixes #310. --- arbor/threading/threading.hpp | 75 +++++-- test/unit/CMakeLists.txt | 1 + test/unit/test_threading_exceptions.cpp | 270 ++++++++++++++++++++++++ 3 files changed, 332 insertions(+), 14 deletions(-) create mode 100644 test/unit/test_threading_exceptions.cpp diff --git a/arbor/threading/threading.hpp b/arbor/threading/threading.hpp index 5c9d8ca0..6e5ab984 100644 --- a/arbor/threading/threading.hpp +++ b/arbor/threading/threading.hpp @@ -97,10 +97,38 @@ public: class task_group { private: + struct exception_state { + std::atomic<bool> error_{false}; + std::exception_ptr exception_; + std::mutex mutex_; + + operator bool() const { + return error_.load(std::memory_order_relaxed); + } + + void set(std::exception_ptr ex) { + error_.store(true, std::memory_order_relaxed); + lock ex_lock{mutex_}; + exception_ = std::move(ex); + } + + std::exception_ptr get() { + error_.store(false, std::memory_order_relaxed); + return std::move(exception_); + } + + ~exception_state() { + if(error_) { + std::rethrow_exception(exception_); + } + } + }; + std::atomic<std::size_t> in_flight_{0}; /// We use a raw pointer here instead of a shared_ptr to avoid a race condition /// on the destruction of a task_system that would lead to a thread trying to join itself task_system* task_system_; + exception_state exception_status_; public: task_group(task_system* ts): @@ -112,33 +140,44 @@ public: template <typename F> class wrap { - F f; - std::atomic<std::size_t>& counter; + F f_; + std::atomic<std::size_t>& counter_; + exception_state& exception_status_; public: // Construct from a compatible function and atomic counter template <typename F2> - explicit wrap(F2&& other, std::atomic<std::size_t>& c): - f(std::forward<F2>(other)), - counter(c) + explicit wrap(F2&& other, std::atomic<std::size_t>& c, exception_state& ex): + f_(std::forward<F2>(other)), + counter_(c), + exception_status_(ex) {} wrap(wrap&& other): - f(std::move(other.f)), - counter(other.counter) + f_(std::move(other.f_)), + counter_(other.counter_), + exception_status_(other.exception_status_) {} // std::function is not guaranteed to not copy the contents on move construction // But the class is safe because we don't call operator() more than once on the same wrapped task wrap(const wrap& other): - f(other.f), - counter(other.counter) + f_(other.f_), + counter_(other.counter_), + exception_status_(other.exception_status_) {} void operator()() { - f(); - --counter; + if(!exception_status_) { + try { + f_(); + } + catch (...) { + exception_status_.set(std::current_exception()); + } + } + --counter_; } }; @@ -146,14 +185,14 @@ public: using callable = typename std::decay<F>::type; template <typename F> - wrap<callable<F>> make_wrapped_function(F&& f, std::atomic<std::size_t>& c) { - return wrap<callable<F>>(std::forward<F>(f), c); + wrap<callable<F>> make_wrapped_function(F&& f, std::atomic<std::size_t>& c, exception_state& ex) { + return wrap<callable<F>>(std::forward<F>(f), c, ex); } template<typename F> void run(F&& f) { ++in_flight_; - task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_)); + task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_, exception_status_)); } // wait till all tasks in this group are done @@ -161,6 +200,14 @@ public: while (in_flight_) { task_system_->try_run_task(); } + if(auto ex = exception_status_.get()) { + try { + std::rethrow_exception(ex); + } + catch (...) { + throw; + } + } } // Make sure that all tasks are done before clean up diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 28fb4002..4cacc7cf 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -104,6 +104,7 @@ set(unit_sources test_swcio.cpp test_synapses.cpp test_thread.cpp + test_threading_exceptions.cpp test_tree.cpp test_transform.cpp test_uninitialized.cpp diff --git a/test/unit/test_threading_exceptions.cpp b/test/unit/test_threading_exceptions.cpp new file mode 100644 index 00000000..27321081 --- /dev/null +++ b/test/unit/test_threading_exceptions.cpp @@ -0,0 +1,270 @@ +#include "../gtest.h" + +#include <arbor/domain_decomposition.hpp> +#include <arbor/load_balance.hpp> +#include <arbor/recipe.hpp> +#include <arbor/simulation.hpp> +#include <arbor/context.hpp> +#include <arbor/mc_cell.hpp> +#include <arbor/version.hpp> + +#include "threading/threading.hpp" + +using namespace arb::threading::impl; +using namespace arb::threading; +using namespace arb; + +auto duration = std::chrono::nanoseconds(1); + +struct error { + int code; + error(int c): code(c) {}; +}; + +TEST(test_exception, single_task_no_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + task_group g(&ts); + g.run([](){ std::this_thread::sleep_for(duration); }); + EXPECT_NO_THROW(g.wait()); + } +} + +TEST(test_exception, single_task_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(1); + task_group g(&ts); + g.run([](){ throw error(0);}); + try { + g.wait(); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_EQ(e.code, 0); + } + catch (...) { + FAIL() << "Expected error type"; + } + } +} + +TEST(test_exception, many_tasks_no_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + task_group g(&ts); + for (int i = 1; i < 100; i++) { + for (int j = 0; j < i; j++) { + g.run([j](){ std::this_thread::sleep_for(duration); }); + } + EXPECT_NO_THROW(g.wait()); + } + } +} + +TEST(test_exception, many_tasks_one_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + task_group g(&ts); + for (int i = 1; i < 100; i++) { + for (int j = 0; j < i; j++) { + g.run([j](){ if(j==0) {throw error(j);} }); + } + try { + g.wait(); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_EQ(e.code, 0); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } +} + +TEST(test_exception, many_tasks_many_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + task_group g(&ts); + for (int i = 1; i < 100; i++) { + for (int j = 0; j < i; j++) { + g.run([j](){ if(j%5 == 0) {throw error(j);} }); + } + try { + g.wait(); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_EQ(e.code%5, 0); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } +} + +TEST(test_exception, many_tasks_all_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + task_group g(&ts); + for (int i = 1; i < 100; i++) { + for (int j = 0; j < i; j++) { + g.run([j](){ throw error(j); }); + } + try { + g.wait(); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_TRUE((e.code >= 0) && (e.code < i)); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } +} + +TEST(test_exception, parallel_for_no_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + for (int n = 1; n < 100; n*=2) { + EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [n](int i) { std::this_thread::sleep_for(duration); })); + } + } +} + +TEST(test_exception, parallel_for_one_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + for (int n = 1; n < 100; n*=2) { + try { + parallel_for::apply(0, n, &ts, [n](int i) { if(i==n-1) {throw error(i);} }); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_TRUE(e.code == n-1); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } +} + +TEST(test_exception, parallel_for_many_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + for (int n = 1; n < 100; n*=2) { + try { + parallel_for::apply(0, n, &ts, [n](int i) { if(i%7 == 0) {throw error(i);} }); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_EQ(e.code%7, 0); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } +} + +TEST(test_exception, parallel_for_all_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + for (int n = 1; n < 100; n*=2) { + try { + parallel_for::apply(0, n, &ts, [n](int i) { throw error(i); }); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_TRUE((e.code >= 0) && (e.code < n)); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } +} + +TEST(test_exception, nested_parallel_for_no_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + for (int m = 1; m < 50; m*=2) { + for (int n = 1; n < 50; n = !n?1:2*n) { + EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [&](int i) { + parallel_for::apply(0, m, &ts, [](int j) { std::this_thread::sleep_for(duration); }); + })); + } + } + } +} + +TEST(test_exception, nested_parallel_for_one_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + for (int m = 1; m < 100; m*=2) { + for (int n = 1; n < 100; n = !n?1:2*n) { + try { + parallel_for::apply(0, n, &ts, [&](int i) { + parallel_for::apply(0, m, &ts, [](int j) { if (j == 0) {throw error(j);} }); + }); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_EQ(e.code, 0); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } + } +} + +TEST(test_exception, nested_parallel_for_many_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + for (int m = 1; m < 100; m*=2) { + for (int n = 1; n < 100; n = !n?1:2*n) { + try { + parallel_for::apply(0, n, &ts, [&](int i) { + parallel_for::apply(0, m, &ts, [](int j) { if (j%10 == 0) {throw error(j);} }); + }); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_TRUE(e.code%10==0); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } + } +} + +TEST(test_exception, nested_parallel_for_all_throw) { + for (int nthreads = 1; nthreads < 20; nthreads*=2) { + task_system ts(nthreads); + for (int m = 1; m < 100; m*=2) { + for (int n = 1; n < 100; n = !n?1:2*n) { + try { + parallel_for::apply(0, n, &ts, [&](int i) { + parallel_for::apply(0, m, &ts, [](int j) { throw error(j); }); + }); + FAIL() << "Expected exception"; + } + catch (error &e) { + EXPECT_TRUE(e.code >= 0 && e.code < m); + } + catch (...) { + FAIL() << "Expected error type"; + } + } + } + } +} -- GitLab