From d6aec81aba592676143c9bafb0d845865dc889b4 Mon Sep 17 00:00:00 2001
From: Sam Yates <yates@cscs.ch>
Date: Mon, 1 Oct 2018 10:05:59 +0200
Subject: [PATCH] Fix double throw of captured exception in thread group.
 (#606)

Fixes #603.

* Clear exception pointer in exception_state helper class after move of state.
* Rename exception_state::get() method to reset().
* Call std::terminate() if task_group is destroyed before tasks are collected with wait().
* Do not attempt to collect tasks in destructor for task_group.
* Do not attempt to rethrow exception in destructor for exception_state.
* Add unit test to verify correct exception behaviour when a task_group runs and waits on a series of tasks.
* Add unit test for terminate behaviour as above.

Code quality fix ups:
* Remove unused variable warning in threading exception tests.
* Address if-statement spacing in threading.hpp.
* Use ARB_HAVE_MPI in execution_context.cpp instead of introducing a dependency on generated version header via feature macro ARB_MPI_ENABLED.
---
 arbor/execution_context.cpp             |  5 +-
 arbor/threading/threading.hpp           | 48 +++++++++---------
 test/unit/test_threading_exceptions.cpp | 65 +++++++++++++++++++++++--
 3 files changed, 86 insertions(+), 32 deletions(-)

diff --git a/arbor/execution_context.cpp b/arbor/execution_context.cpp
index 12d149a5..989ef13f 100644
--- a/arbor/execution_context.cpp
+++ b/arbor/execution_context.cpp
@@ -2,14 +2,13 @@
 #include <memory>
 
 #include <arbor/context.hpp>
-#include <arbor/version.hpp>
 
 #include "gpu_context.hpp"
 #include "distributed_context.hpp"
 #include "execution_context.hpp"
 #include "threading/threading.hpp"
 
-#ifdef ARB_MPI_ENABLED
+#ifdef ARB_HAVE_MPI
 #include <mpi.h>
 #endif
 
@@ -34,7 +33,7 @@ context make_context(const proc_allocation& p) {
     return context(new execution_context(p), [](execution_context* p){delete p;});
 }
 
-#ifdef ARB_MPI_ENABLED
+#ifdef ARB_HAVE_MPI
 template <>
 execution_context::execution_context(const proc_allocation& resources, MPI_Comm comm):
     distributed(make_mpi_context(comm)),
diff --git a/arbor/threading/threading.hpp b/arbor/threading/threading.hpp
index 6e5ab984..9a3d94c5 100644
--- a/arbor/threading/threading.hpp
+++ b/arbor/threading/threading.hpp
@@ -112,21 +112,25 @@ private:
             exception_ = std::move(ex);
         }
 
-        std::exception_ptr get() {
+        // Clear exception state but return old state.
+        // For consistency, this must only be called when there
+        // are no tasks in flight that reference this exception state.
+        std::exception_ptr reset() {
+            auto ex = std::move(exception_);
             error_.store(false, std::memory_order_relaxed);
-            return std::move(exception_);
-        }
-
-        ~exception_state() {
-            if(error_) {
-                std::rethrow_exception(exception_);
-            }
+            exception_ = nullptr;
+            return ex;
         }
     };
 
     std::atomic<std::size_t> in_flight_{0};
-    /// We use a raw pointer here instead of a shared_ptr to avoid a race condition
-    /// on the destruction of a task_system that would lead to a thread trying to join itself
+
+    // Set by run(), cleared by wait(). Used to check task completion status
+    // in destructor.
+    bool running_ = false;
+
+    // We use a raw pointer here instead of a shared_ptr to avoid a race condition
+    // on the destruction of a task_system that would lead to a thread trying to join itself.
     task_system* task_system_;
     exception_state exception_status_;
 
@@ -145,7 +149,6 @@ public:
         exception_state& exception_status_;
 
     public:
-
         // Construct from a compatible function and atomic counter
         template <typename F2>
         explicit wrap(F2&& other, std::atomic<std::size_t>& c, exception_state& ex):
@@ -160,8 +163,8 @@ public:
                 exception_status_(other.exception_status_)
         {}
 
-        // std::function is not guaranteed to not copy the contents on move construction
-        // But the class is safe because we don't call operator() more than once on the same wrapped task
+        // std::function is not guaranteed to not copy the contents on move construction,
+        // but the class is safe because we don't call operator() more than once on the same wrapped task.
         wrap(const wrap& other):
                 f_(other.f_),
                 counter_(other.counter_),
@@ -169,7 +172,7 @@ public:
         {}
 
         void operator()() {
-            if(!exception_status_) {
+            if (!exception_status_) {
                 try {
                     f_();
                 }
@@ -191,28 +194,25 @@ public:
 
     template<typename F>
     void run(F&& f) {
+        running_ = true;
         ++in_flight_;
         task_system_->async(make_wrapped_function(std::forward<F>(f), in_flight_, exception_status_));
     }
 
-    // wait till all tasks in this group are done
+    // Wait till all tasks in this group are done.
     void wait() {
         while (in_flight_) {
             task_system_->try_run_task();
         }
-        if(auto ex = exception_status_.get()) {
-            try {
-                std::rethrow_exception(ex);
-            }
-            catch (...) {
-                throw;
-            }
+        running_ = false;
+
+        if (auto ex = exception_status_.reset()) {
+            std::rethrow_exception(ex);
         }
     }
 
-    // Make sure that all tasks are done before clean up
     ~task_group() {
-        wait();
+        if (running_) std::terminate();
     }
 };
 
diff --git a/test/unit/test_threading_exceptions.cpp b/test/unit/test_threading_exceptions.cpp
index 27321081..14569bf2 100644
--- a/test/unit/test_threading_exceptions.cpp
+++ b/test/unit/test_threading_exceptions.cpp
@@ -1,5 +1,7 @@
 #include "../gtest.h"
 
+#include <csignal>
+
 #include <arbor/domain_decomposition.hpp>
 #include <arbor/load_balance.hpp>
 #include <arbor/recipe.hpp>
@@ -14,7 +16,7 @@ using namespace arb::threading::impl;
 using namespace arb::threading;
 using namespace arb;
 
-auto duration = std::chrono::nanoseconds(1);
+const auto duration = std::chrono::nanoseconds(1);
 
 struct error {
     int code;
@@ -54,7 +56,7 @@ TEST(test_exception, many_tasks_no_throw) {
         task_group g(&ts);
         for (int i = 1; i < 100; i++) {
             for (int j = 0; j < i; j++) {
-                g.run([j](){ std::this_thread::sleep_for(duration); });
+                g.run([](){ std::this_thread::sleep_for(duration); });
             }
             EXPECT_NO_THROW(g.wait());
         }
@@ -131,7 +133,7 @@ TEST(test_exception, parallel_for_no_throw) {
     for (int nthreads = 1; nthreads < 20; nthreads*=2) {
         task_system ts(nthreads);
         for (int n = 1; n < 100; n*=2) {
-            EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [n](int i) { std::this_thread::sleep_for(duration); }));
+            EXPECT_NO_THROW(parallel_for::apply(0, n, &ts, [](int i) { std::this_thread::sleep_for(duration); }));
         }
     }
 }
@@ -159,7 +161,7 @@ TEST(test_exception, parallel_for_many_throw) {
         task_system ts(nthreads);
         for (int n = 1; n < 100; n*=2) {
             try {
-                parallel_for::apply(0, n, &ts, [n](int i) { if(i%7 == 0) {throw error(i);} });
+                parallel_for::apply(0, n, &ts, [](int i) { if(i%7 == 0) {throw error(i);} });
                 FAIL() << "Expected exception";
             }
             catch (error &e) {
@@ -177,7 +179,7 @@ TEST(test_exception, parallel_for_all_throw) {
         task_system ts(nthreads);
         for (int n = 1; n < 100; n*=2) {
             try {
-                parallel_for::apply(0, n, &ts, [n](int i) { throw error(i); });
+                parallel_for::apply(0, n, &ts, [](int i) { throw error(i); });
                 FAIL() << "Expected exception";
             }
             catch (error &e) {
@@ -268,3 +270,56 @@ TEST(test_exception, nested_parallel_for_all_throw) {
         }
     }
 }
+
+TEST(test_exception, post_exception_state) {  // A task_group must stay usable after wait() has rethrown a task's exception.
+    for (int nthreads: {1, 2, 16}) {
+        task_system ts(nthreads);
+        task_group g(&ts);  // One group is reused for every trial and for both phases within a trial.
+
+        for (int trial = 1; trial<100; ++trial) {
+            std::atomic<int> dummy(0);  // Gives the first-phase tasks something to do; its value is never checked.
+            std::atomic<int> counter(0);  // Counts the second-phase tasks that actually ran.
+            int throws = 0;  // Number of times the first wait() delivered an error.
+
+            try {
+                for (int k = 0; k<100; ++k) {
+                    g.run([&] { dummy.fetch_add(1, std::memory_order_relaxed); });
+                }
+                g.run([&] { dummy.fetch_add(1, std::memory_order_relaxed); throw error(1); });  // The only task that throws.
+                g.wait();  // Expected to rethrow the captured error.
+            }
+            catch (error& e) { ++throws; }
+
+            for (int k = 0; k<100; ++k) {  // Second phase: none of these tasks throws.
+                g.run([&] { ++counter; });
+            }
+            try {
+                g.wait();
+            }
+            catch (...) {
+                FAIL() << "Expected no error to be thrown";  // No task in this phase throws, so any exception here is a failure.
+            }
+
+            EXPECT_EQ(100, counter.load());  // Every second-phase task must have executed.
+            EXPECT_EQ(1, throws);  // The first-phase error must have been delivered exactly once.
+        }
+    }
+}
+
+TEST(test_exception, terminate_if_no_wait_DeathTest) {  // Destroying a task_group that ran tasks without wait() must abort the process.
+    testing::FLAGS_gtest_death_test_style = "threadsafe";  // Selects gtest's "threadsafe" death-test style.
+
+    auto run_terminate_test = [](int nthread) {
+        task_system ts(nthread);
+        task_group g(&ts);
+
+        g.run([] {});
+        g.run([] {});
+        g.run([] {});
+    };  // g is destroyed at the end of the lambda body with no wait() call; that is the condition under test.
+
+    for (int n: {1, 2, 16}) {
+        // Check for (default) std::terminate behaviour:
+        ASSERT_EXIT(run_terminate_test(n), ::testing::KilledBySignal(SIGABRT), "");
+    }
+}
-- 
GitLab