diff --git a/arbor/execution_context.cpp b/arbor/execution_context.cpp index 12d149a55a1030541eed2f9aea9526433c2d1892..989ef13f097d902d51689066ce1cdca4cfa5069c 100644 --- a/arbor/execution_context.cpp +++ b/arbor/execution_context.cpp @@ -2,14 +2,13 @@ #include <memory> #include <arbor/context.hpp> -#include <arbor/version.hpp> #include "gpu_context.hpp" #include "distributed_context.hpp" #include "execution_context.hpp" #include "threading/threading.hpp" -#ifdef ARB_MPI_ENABLED +#ifdef ARB_HAVE_MPI #include <mpi.h> #endif @@ -34,7 +33,7 @@ context make_context(const proc_allocation& p) { return context(new execution_context(p), [](execution_context* p){delete p;}); } -#ifdef ARB_MPI_ENABLED +#ifdef ARB_HAVE_MPI template <> execution_context::execution_context(const proc_allocation& resources, MPI_Comm comm): distributed(make_mpi_context(comm)), diff --git a/arbor/threading/threading.hpp b/arbor/threading/threading.hpp index 6e5ab9840ae67f4035d08c27b629b579d45d8365..9a3d94c5b25a44fce9324bc79a677cf6cc1592bb 100644 --- a/arbor/threading/threading.hpp +++ b/arbor/threading/threading.hpp @@ -112,21 +112,25 @@ private: exception_ = std::move(ex); } - std::exception_ptr get() { + // Clear exception state but return old state. + // For consistency, this must only be called when there + // are no tasks in flight that reference this exception state. + std::exception_ptr reset() { + auto ex = std::move(exception_); error_.store(false, std::memory_order_relaxed); - return std::move(exception_); - } - - ~exception_state() { - if(error_) { - std::rethrow_exception(exception_); - } + exception_ = nullptr; + return ex; } }; std::atomic<std::size_t> in_flight_{0}; - /// We use a raw pointer here instead of a shared_ptr to avoid a race condition - /// on the destruction of a task_system that would lead to a thread trying to join itself + + // Set by run(), cleared by wait(). Used to check task completion status + // in destructor. + bool running_ = false; + + // We use a raw pointer here instead of a shared_ptr to avoid a race condition + // on the destruction of a task_system that would lead to a thread trying to join itself. task_system* task_system_; exception_state exception_status_; @@ -145,7 +149,6 @@ public: exception_state& exception_status_; public: - // Construct from a compatible function and atomic counter template <typename F2> explicit wrap(F2&& other, std::atomic<std::size_t>& c, exception_state& ex): @@ -160,8 +163,8 @@ public: exception_status_(other.exception_status_) {} - // std::function is not guaranteed to not copy the contents on move construction - // But the class is safe because we don't call operator() more than once on the same wrapped task + // std::function is not guaranteed to not copy the contents on move construction, + // but the class is safe because we don't call operator() more than once on the same wrapped task. wrap(const wrap& other): f_(other.f_), counter_(other.counter_), @@ -169,7 +172,7 @@ public: {} void operator()() { - if(!exception_status_) { + if (!exception_status_) { try { f_(); } @@ -191,28 +194,25 @@ public: template<typename F> void run(F&& f) { + running_ = true; ++in_flight_; task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_, exception_status_)); } - // wait till all tasks in this group are done + // Wait till all tasks in this group are done. void wait() { while (in_flight_) { task_system_->try_run_task(); } - if(auto ex = exception_status_.get()) { - try { - std::rethrow_exception(ex); - } - catch (...) { - throw; - } + running_ = false; + + if (auto ex = exception_status_.reset()) { + std::rethrow_exception(ex); } } - // Make sure that all tasks are done before clean up ~task_group() { - wait(); + if (running_) std::terminate(); } }; diff --git a/test/unit/test_threading_exceptions.cpp b/test/unit/test_threading_exceptions.cpp index 273210812bdc0c2d5f0a545bc5c82fd9fd3a49aa..14569bf2824709fcbb6c78033a738a25618d21fe 100644 --- a/test/unit/test_threading_exceptions.cpp +++ b/test/unit/test_threading_exceptions.cpp @@ -1,5 +1,7 @@ #include "../gtest.h" +#include <csignal> + #include <arbor/domain_decomposition.hpp> #include <arbor/load_balance.hpp> #include <arbor/recipe.hpp> @@ -14,7 +16,7 @@ using namespace arb::threading::impl; using namespace arb::threading; using namespace arb; -auto duration = std::chrono::nanoseconds(1); +const auto duration = std::chrono::nanoseconds(1); struct error { int code; @@ -54,7 +56,7 @@ TEST(test_exception, many_tasks_no_throw) { task_group g(&ts); for (int i = 1; i < 100; i++) { for (int j = 0; j < i; j++) { - g.run([j](){ std::this_thread::sleep_for(duration); }); + g.run([](){ std::this_thread::sleep_for(duration); }); } EXPECT_NO_THROW(g.wait()); } @@ -131,7 +133,7 @@ TEST(test_exception, parallel_for_no_throw) { for (int nthreads = 1; nthreads < 20; nthreads*=2) { task_system ts(nthreads); for (int n = 1; n < 100; n*=2) { - EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [n](int i) { std::this_thread::sleep_for(duration); })); + EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [](int i) { std::this_thread::sleep_for(duration); })); } } } @@ -159,7 +161,7 @@ TEST(test_exception, parallel_for_many_throw) { task_system ts(nthreads); for (int n = 1; n < 100; n*=2) { try { - parallel_for::apply(0, n, &ts, [n](int i) { if(i%7 == 0) {throw error(i);} }); + parallel_for::apply(0, n, &ts, [](int i) { if(i%7 == 0) {throw error(i);} }); FAIL() << "Expected exception"; } catch (error &e) { @@ -177,7 +179,7 @@ TEST(test_exception, parallel_for_all_throw) { task_system ts(nthreads); for (int n = 1; n < 100; n*=2) { try { - parallel_for::apply(0, n, &ts, [n](int i) { throw error(i); }); + parallel_for::apply(0, n, &ts, [](int i) { throw error(i); }); FAIL() << "Expected exception"; } catch (error &e) { @@ -268,3 +270,56 @@ TEST(test_exception, nested_parallel_for_all_throw) { } } } + +TEST(test_exception, post_exception_state) { + for (int nthreads: {1, 2, 16}) { + task_system ts(nthreads); + task_group g(&ts); + + for (int trial = 1; trial<100; ++trial) { + std::atomic<int> dummy(0); + std::atomic<int> counter(0); + int throws = 0; + + try { + for (int k = 0; k<100; ++k) { + g.run([&] { dummy.fetch_add(1, std::memory_order_relaxed); }); + } + g.run([&] { dummy.fetch_add(1, std::memory_order_relaxed); throw error(1); }); + g.wait(); + } + catch (error& e) { ++throws; } + + for (int k = 0; k<100; ++k) { + g.run([&] { ++counter; }); + } + try { + g.wait(); + } + catch (...) { + FAIL() << "Expected no error to be thrown"; + } + + EXPECT_EQ(100, counter.load()); + EXPECT_EQ(1, throws); + } + } +} + +TEST(test_exception, terminate_if_no_wait_DeathTest) { + testing::FLAGS_gtest_death_test_style = "threadsafe"; + + auto run_terminate_test = [](int nthread) { + task_system ts(nthread); + task_group g(&ts); + + g.run([] {}); + g.run([] {}); + g.run([] {}); + }; + + for (int n: {1, 2, 16}) { + // Check for (default) std::terminate behaviour: + ASSERT_EXIT(run_terminate_test(n), ::testing::KilledBySignal(SIGABRT), ""); + } +}