From d6aec81aba592676143c9bafb0d845865dc889b4 Mon Sep 17 00:00:00 2001 From: Sam Yates <yates@cscs.ch> Date: Mon, 1 Oct 2018 10:05:59 +0200 Subject: [PATCH] Fix double throw of captured exception in thread group. (#606) Fixes #603. * Clear exception pointer in exception_state helper class after move of state. * Rename exception_state::get() method to reset(). * Call std::terminate() if task_group is destroyed before tasks are collected with wait(). * Do not attempt to collect tasks in destructor for task_group. * Do not attempt to rethrow exception in destructor for exception_state. * Add unit test to verify correct exception behaviour when a task_group is runs and waits on a series of tasks. * Add unit test for terminate behaviour as above. Code quality fix ups: * Remove unused warning variable warning in threading exception tests. * Address if-statement spacing in threading.hpp. * Use ARB_HAVE_MPI in execution_context.cpp instead of introducing a dependency on generated version header via feature macro ARB_MPI_ENABLED. --- arbor/execution_context.cpp | 5 +- arbor/threading/threading.hpp | 48 +++++++++--------- test/unit/test_threading_exceptions.cpp | 65 +++++++++++++++++++++++-- 3 files changed, 86 insertions(+), 32 deletions(-) diff --git a/arbor/execution_context.cpp b/arbor/execution_context.cpp index 12d149a5..989ef13f 100644 --- a/arbor/execution_context.cpp +++ b/arbor/execution_context.cpp @@ -2,14 +2,13 @@ #include <memory> #include <arbor/context.hpp> -#include <arbor/version.hpp> #include "gpu_context.hpp" #include "distributed_context.hpp" #include "execution_context.hpp" #include "threading/threading.hpp" -#ifdef ARB_MPI_ENABLED +#ifdef ARB_HAVE_MPI #include <mpi.h> #endif @@ -34,7 +33,7 @@ context make_context(const proc_allocation& p) { return context(new execution_context(p), [](execution_context* p){delete p;}); } -#ifdef ARB_MPI_ENABLED +#ifdef ARB_HAVE_MPI template <> execution_context::execution_context(const proc_allocation& resources, MPI_Comm comm): distributed(make_mpi_context(comm)), diff --git a/arbor/threading/threading.hpp b/arbor/threading/threading.hpp index 6e5ab984..9a3d94c5 100644 --- a/arbor/threading/threading.hpp +++ b/arbor/threading/threading.hpp @@ -112,21 +112,25 @@ private: exception_ = std::move(ex); } - std::exception_ptr get() { + // Clear exception state but return old state. + // For consistency, this must only be called when there + // are no tasks in flight that reference this exception state. + std::exception_ptr reset() { + auto ex = std::move(exception_); error_.store(false, std::memory_order_relaxed); - return std::move(exception_); - } - - ~exception_state() { - if(error_) { - std::rethrow_exception(exception_); - } + exception_ = nullptr; + return ex; } }; std::atomic<std::size_t> in_flight_{0}; - /// We use a raw pointer here instead of a shared_ptr to avoid a race condition - /// on the destruction of a task_system that would lead to a thread trying to join itself + + // Set by run(), cleared by wait(). Used to check task completion status + // in destructor. + bool running_ = false; + + // We use a raw pointer here instead of a shared_ptr to avoid a race condition + // on the destruction of a task_system that would lead to a thread trying to join itself. task_system* task_system_; exception_state exception_status_; @@ -145,7 +149,6 @@ public: exception_state& exception_status_; public: - // Construct from a compatible function and atomic counter template <typename F2> explicit wrap(F2&& other, std::atomic<std::size_t>& c, exception_state& ex): @@ -160,8 +163,8 @@ public: exception_status_(other.exception_status_) {} - // std::function is not guaranteed to not copy the contents on move construction - // But the class is safe because we don't call operator() more than once on the same wrapped task + // std::function is not guaranteed to not copy the contents on move construction, + // but the class is safe because we don't call operator() more than once on the same wrapped task. wrap(const wrap& other): f_(other.f_), counter_(other.counter_), @@ -169,7 +172,7 @@ public: {} void operator()() { - if(!exception_status_) { + if (!exception_status_) { try { f_(); } @@ -191,28 +194,25 @@ public: template<typename F> void run(F&& f) { + running_ = true; ++in_flight_; task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_, exception_status_)); } - // wait till all tasks in this group are done + // Wait till all tasks in this group are done. void wait() { while (in_flight_) { task_system_->try_run_task(); } - if(auto ex = exception_status_.get()) { - try { - std::rethrow_exception(ex); - } - catch (...) { - throw; - } + running_ = false; + + if (auto ex = exception_status_.reset()) { + std::rethrow_exception(ex); } } - // Make sure that all tasks are done before clean up ~task_group() { - wait(); + if (running_) std::terminate(); } }; diff --git a/test/unit/test_threading_exceptions.cpp b/test/unit/test_threading_exceptions.cpp index 27321081..14569bf2 100644 --- a/test/unit/test_threading_exceptions.cpp +++ b/test/unit/test_threading_exceptions.cpp @@ -1,5 +1,7 @@ #include "../gtest.h" +#include <csignal> + #include <arbor/domain_decomposition.hpp> #include <arbor/load_balance.hpp> #include <arbor/recipe.hpp> @@ -14,7 +16,7 @@ using namespace arb::threading::impl; using namespace arb::threading; using namespace arb; -auto duration = std::chrono::nanoseconds(1); +const auto duration = std::chrono::nanoseconds(1); struct error { int code; @@ -54,7 +56,7 @@ TEST(test_exception, many_tasks_no_throw) { task_group g(&ts); for (int i = 1; i < 100; i++) { for (int j = 0; j < i; j++) { - g.run([j](){ std::this_thread::sleep_for(duration); }); + g.run([](){ std::this_thread::sleep_for(duration); }); } EXPECT_NO_THROW(g.wait()); } @@ -131,7 +133,7 @@ TEST(test_exception, parallel_for_no_throw) { for (int nthreads = 1; nthreads < 20; nthreads*=2) { task_system ts(nthreads); for (int n = 1; n < 100; n*=2) { - EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [n](int i) { std::this_thread::sleep_for(duration); })); + EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [](int i) { std::this_thread::sleep_for(duration); })); } } } @@ -159,7 +161,7 @@ TEST(test_exception, parallel_for_many_throw) { task_system ts(nthreads); for (int n = 1; n < 100; n*=2) { try { - parallel_for::apply(0, n, &ts, [n](int i) { if(i%7 == 0) {throw error(i);} }); + parallel_for::apply(0, n, &ts, [](int i) { if(i%7 == 0) {throw error(i);} }); FAIL() << "Expected exception"; } catch (error &e) { @@ -177,7 +179,7 @@ TEST(test_exception, parallel_for_all_throw) { task_system ts(nthreads); for (int n = 1; n < 100; n*=2) { try { - parallel_for::apply(0, n, &ts, [n](int i) { throw error(i); }); + parallel_for::apply(0, n, &ts, [](int i) { throw error(i); }); FAIL() << "Expected exception"; } catch (error &e) { @@ -268,3 +270,56 @@ TEST(test_exception, nested_parallel_for_all_throw) { } } } + +TEST(test_exception, post_exception_state) { + for (int nthreads: {1, 2, 16}) { + task_system ts(nthreads); + task_group g(&ts); + + for (int trial = 1; trial<100; ++trial) { + std::atomic<int> dummy(0); + std::atomic<int> counter(0); + int throws = 0; + + try { + for (int k = 0; k<100; ++k) { + g.run([&] { dummy.fetch_add(1, std::memory_order_relaxed); }); + } + g.run([&] { dummy.fetch_add(1, std::memory_order_relaxed); throw error(1); }); + g.wait(); + } + catch (error& e) { ++throws; } + + for (int k = 0; k<100; ++k) { + g.run([&] { ++counter; }); + } + try { + g.wait(); + } + catch (...) { + FAIL() << "Expected no error to be thrown"; + } + + EXPECT_EQ(100, counter.load()); + EXPECT_EQ(1, throws); + } + } +} + +TEST(test_exception, terminate_if_no_wait_DeathTest) { + testing::FLAGS_gtest_death_test_style = "threadsafe"; + + auto run_terminate_test = [](int nthread) { + task_system ts(nthread); + task_group g(&ts); + + g.run([] {}); + g.run([] {}); + g.run([] {}); + }; + + for (int n: {1, 2, 16}) { + // Check for (default) std::terminate behaviour: + ASSERT_EXIT(run_terminate_test(n), ::testing::KilledBySignal(SIGABRT), ""); + } +} -- GitLab