From 6c98c1fc1ba87cff5e20261e4e3c356e93eee89c Mon Sep 17 00:00:00 2001 From: Alexander Peyser <apeyser@users.noreply.github.com> Date: Wed, 1 Feb 2017 14:16:07 +0100 Subject: [PATCH] Threading pool (#144) Add threading pool built on `std::thread` * Provide new threading model 'cthread' for nestmc based on a pool of `std::thread` objects. * Unify duplicated timer class provided by `serial`, `omp` and now `cthread` threading models. --- CMakeLists.txt | 10 + src/CMakeLists.txt | 4 + src/threading/cthread.cpp | 188 +++++++++++++ src/threading/cthread.hpp | 11 + src/threading/cthread_impl.hpp | 269 +++++++++++++++++++ src/threading/cthread_parallel_stable_sort.h | 154 +++++++++++ src/threading/cthread_sort.hpp | 26 ++ src/threading/omp.hpp | 21 +- src/threading/serial.hpp | 20 +- src/threading/threading.hpp | 2 + src/threading/timer.hpp | 29 ++ 11 files changed, 702 insertions(+), 32 deletions(-) create mode 100644 src/threading/cthread.cpp create mode 100644 src/threading/cthread.hpp create mode 100644 src/threading/cthread_impl.hpp create mode 100644 src/threading/cthread_parallel_stable_sort.h create mode 100644 src/threading/cthread_sort.hpp create mode 100644 src/threading/timer.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f51b8f53..1a4858b0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,16 @@ elseif(NMC_THREADING_MODEL MATCHES "omp") add_definitions(-DNMC_HAVE_OMP) set(NMC_HAVE_OMP TRUE) +elseif(NMC_THREADING_MODEL MATCHES "cthread") + find_package(Threads REQUIRED) + add_definitions(-DNMC_HAVE_CTHREAD) + set(NMC_HAVE_CTHREAD TRUE) + list(APPEND EXTERNAL_LIBRARIES ${CMAKE_THREAD_LIBS_INIT}) + + if(CMAKE_USE_PTHREADS_INIT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") + endif() + elseif(NMC_THREADING_MODEL MATCHES "serial") #setup previously done diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 09a61691..eda28c0c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,6 +20,10 @@ if(NMC_WITH_MPI) set(BASE_SOURCES 
${BASE_SOURCES} communication/mpi.cpp) endif() +if(NMC_HAVE_CTHREAD) + set(BASE_SOURCES ${BASE_SOURCES} threading/cthread.cpp) +endif() + add_library(nestmc ${BASE_SOURCES} ${HEADERS}) if (NMC_AUTO_RUN_MODCC_ON_CHANGES) diff --git a/src/threading/cthread.cpp b/src/threading/cthread.cpp new file mode 100644 index 00000000..0da76d5a --- /dev/null +++ b/src/threading/cthread.cpp @@ -0,0 +1,188 @@ +#include <cassert> +#include <exception> +#include <iostream> + +#include "cthread.hpp" + + +using namespace nest::mc::threading::impl; + +// RAII owner for a task in flight: construction pops it off the pool queue, destruction decrements its group's in_flight count +struct task_pool::run_task { + task_pool& pool; + lock& lck; + task tsk; + + run_task(task_pool&, lock&); + ~run_task(); +}; + +// Own a task in flight: swaps the front of the queue into tsk and pops it +// lock should be passed locked, +// and will be unlocked after call (all waiters are then notified) +task_pool::run_task::run_task(task_pool& pool, lock& lck): + pool{pool}, + lck{lck}, + tsk{} +{ + std::swap(tsk, pool.tasks_.front()); + pool.tasks_.pop_front(); + + lck.unlock(); + pool.tasks_available_.notify_all(); +} + +// Release task: relocks to decrement the group's in_flight count +// Call unlocked, returns unlocked, then notifies all waiters +task_pool::run_task::~run_task() { + lck.lock(); + tsk.second->in_flight--; + + lck.unlock(); + pool.tasks_available_.notify_all(); +} + +template<typename B> // B: nullary callable; the loop returns once it yields true or quit_ is set +void task_pool::run_tasks_loop(B finished) { + lock lck{tasks_mutex_, std::defer_lock}; + while (true) { + lck.lock(); + + while (! quit_ && tasks_.empty() && ! finished()) { + tasks_available_.wait(lck); + } + if (quit_ || finished()) { + return; + } + + run_task run{*this, lck}; + run.tsk.first(); // the task body runs with the lock released; run's destructor does the bookkeeping + } +} + +// runs tasks until quit_ is set (the predicate never reports completion) +void task_pool::run_tasks_forever() { + run_tasks_loop([] {return false;}); +} + +// run queued tasks (of any group) until group g has none in flight +void task_pool::run_tasks_while(task_group* g) { + run_tasks_loop([=] {return ! 
g->in_flight;}); +} + +// Create pool and threads +// new threads are nthreads-1 +task_pool::task_pool(std::size_t nthreads): + tasks_mutex_{}, + tasks_available_{}, + tasks_{}, + threads_{} +{ + assert(nthreads > 0); + + // now for the main thread + auto tid = std::this_thread::get_id(); + thread_ids_[tid] = 0; + + // and go from there + for (std::size_t i = 1; i < nthreads; i++) { + threads_.emplace_back([this]{run_tasks_forever();}); + tid = threads_.back().get_id(); + thread_ids_[tid] = i; + } +} + +task_pool::~task_pool() { + { + lock lck{tasks_mutex_}; + quit_ = true; + } + tasks_available_.notify_all(); + + for (auto& thread: threads_) { + thread.join(); + } +} + +// push a task into pool: queues it, counts it against its group and wakes waiting threads +void task_pool::run(const task& tsk) { + { + lock lck{tasks_mutex_}; + tasks_.push_back(tsk); + tsk.second->in_flight++; + } + tasks_available_.notify_all(); +} + +void task_pool::run(task&& tsk) { + { + lock lck{tasks_mutex_}; + tsk.second->in_flight++; // count the task before tsk is moved from + tasks_.push_back(std::move(tsk)); + } + tasks_available_.notify_all(); +} + +// call on the thread that waits for the group +// uses this thread to run queued tasks +// and returns once group g has no tasks +// in flight (not when the whole queue is empty) +void task_pool::wait(task_group* g) { + run_tasks_while(g); +} + +[[noreturn]] +static void terminate(const char *const msg) { + std::cerr << "NMC_NUM_THREADS_ERROR: " << msg << std::endl; + std::terminate(); +} + +// reads the thread count from the environment; calls terminate() (no exception is thrown) on a missing or badly formed value +static size_t global_get_num_threads() { + const char* nthreads_str; + // select variable to use: + // If NMC_NUM_THREADS_VAR is set, use $NMC_NUM_THREADS_VAR + // else if NMC_NUM_THREADS set, use it + // else if OMP_NUM_THREADS set, use it + if (auto nthreads_var_name = std::getenv("NMC_NUM_THREADS_VAR")) { + nthreads_str = std::getenv(nthreads_var_name); + } + else if (! (nthreads_str = std::getenv("NMC_NUM_THREADS"))) { + nthreads_str = std::getenv("OMP_NUM_THREADS"); + } + + // If the selected var is unset, + // or no var is set, + // error + if (! 
nthreads_str) { + terminate("No environmental var defined"); + } + + // only composed of spaces*digits*space* (chars are cast to unsigned char: the <cctype> tests are undefined for negative values) + auto nthreads_str_end{nthreads_str}; + while (std::isspace(static_cast<unsigned char>(*nthreads_str_end))) { + ++nthreads_str_end; + } + while (std::isdigit(static_cast<unsigned char>(*nthreads_str_end))) { + ++nthreads_str_end; + } + while (std::isspace(static_cast<unsigned char>(*nthreads_str_end))) { + ++nthreads_str_end; + } + if (*nthreads_str_end) { + terminate("Num threads is not a single integer"); + } + + // and it's got a single non-zero value (strtoul: out-of-range input saturates instead of being undefined as with atoi) + auto nthreads{std::strtoul(nthreads_str, nullptr, 10)}; + if (! nthreads) { + terminate("Num threads is not a non-zero number"); + } + + return nthreads; +} + +task_pool& task_pool::get_global_task_pool() { + static task_pool global_task_pool{global_get_num_threads()}; + return global_task_pool; +} diff --git a/src/threading/cthread.hpp b/src/threading/cthread.hpp new file mode 100644 index 00000000..57bdbc32 --- /dev/null +++ b/src/threading/cthread.hpp @@ -0,0 +1,11 @@ +#pragma once + +#if !defined(NMC_HAVE_CTHREAD) + #error "this header can only be loaded if NMC_HAVE_CTHREAD is set" +#endif + +// task_group definition +#include "cthread_impl.hpp" + +// and sorts use cthread_parallel_stable_sort +#include "cthread_sort.hpp" diff --git a/src/threading/cthread_impl.hpp b/src/threading/cthread_impl.hpp new file mode 100644 index 00000000..d17d4407 --- /dev/null +++ b/src/threading/cthread_impl.hpp @@ -0,0 +1,269 @@ +#pragma once + + +#include <thread> +#include <mutex> +#include <algorithm> +#include <array> +#include <chrono> +#include <string> +#include <vector> +#include <type_traits> +#include <functional> +#include <condition_variable> +#include <utility> +#include <unordered_map> +#include <deque> + +#include <cstdlib> + +#include "timer.hpp" + +namespace nest { +namespace mc { +namespace threading { + +// Forward declare task_group (defined at the bottom of this header) +class task_group; +using nest::mc::threading::impl::timer; + +namespace impl { + +using nest::mc::threading::task_group; +using std::mutex; 
+using lock = std::unique_lock<mutex>; +using std::condition_variable; + +using task = std::pair<std::function<void()>, task_group*>; +using task_queue = std::deque<task>; + +using thread_list = std::vector<std::thread>; +using thread_map = std::unordered_map<std::thread::id, std::size_t>; + +class task_pool { +private: + // lock and signal on task availability change + // this is the crucial bit + mutex tasks_mutex_; + condition_variable tasks_available_; + + // fifo of pending tasks + task_queue tasks_; + + // thread resource + thread_list threads_; + // threads -> index + thread_map thread_ids_; + // flag to handle exit from all threads + bool quit_ = false; + + // internals for taking tasks as a resource + // and running them (updating above) + // They get run by a thread in order to consume + // tasks + struct run_task; + // run tasks until a task_group's tasks are done + // for wait + void run_tasks_while(task_group*); + // loop forever for secondary threads + // until quit is set + void run_tasks_forever(); + + // common code for the previous + // finished is a function/lambda + // that returns true when the infinite loop + // needs to be broken + template<typename B> + void run_tasks_loop(B finished ); + + // Create nthreads-1 new std::thread workers + // nthreads must be > 0 + // singleton: only created in static get_global_task_pool() + task_pool(std::size_t nthreads); + + // task_pool is a singleton + task_pool(const task_pool&) = delete; + task_pool& operator=(const task_pool&) = delete; + + // set quit and wait for secondary threads to end + ~task_pool(); + +public: + // Like tbb calls: run queues a task, + // wait waits for all tasks in the group to be done + void run(const task&); + void run(task&&); + void wait(task_group*); + + // includes master thread + int get_num_threads() { + return threads_.size() + 1; + } + + // get a stable integer for the current thread that + // is in 0..nthreads-1; NOTE(review): the operator[] lookup inserts a 0 entry, unlocked, for a thread the constructor did not register - confirm only pool threads call this + std::size_t get_current_thread() { + return 
thread_ids_[std::this_thread::get_id()]; + } + + // singleton constructor - needed to order construction + // with other singletons (profiler) + static task_pool& get_global_task_pool(); +}; +} //impl + +/////////////////////////////////////////////////////////////////////// +// types +/////////////////////////////////////////////////////////////////////// +template <typename T> +class enumerable_thread_specific { // one T slot per pool thread (get_num_threads() entries); local() picks the caller's slot via get_current_thread() + impl::task_pool& global_task_pool; + + using storage_class = std::vector<T>; + storage_class data; + +public : + using iterator = typename storage_class::iterator; + using const_iterator = typename storage_class::const_iterator; + + enumerable_thread_specific(): + global_task_pool{impl::task_pool::get_global_task_pool()}, + data{std::vector<T>(global_task_pool.get_num_threads())} + {} + + enumerable_thread_specific(const T& init): + global_task_pool{impl::task_pool::get_global_task_pool()}, + data{std::vector<T>(global_task_pool.get_num_threads(), init)} + {} + + T& local() { + return data[global_task_pool.get_current_thread()]; + } + const T& local() const { + return data[global_task_pool.get_current_thread()]; + } + + auto size() const -> decltype(data.size()) { return data.size(); } // const qualifies the member function, so size() works on const objects + + iterator begin() { return data.begin(); } + iterator end() { return data.end(); } + + const_iterator begin() const { return data.begin(); } + const_iterator end() const { return data.end(); } + + const_iterator cbegin() const { return data.cbegin(); } + const_iterator cend() const { return data.cend(); } +}; + +template <typename T> +class parallel_vector { // std::vector wrapper whose push_back runs under a mutex; the iterator accessors take no lock + using value_type = T; + std::vector<value_type> data_; + +private: + // lock the parallel_vector to update + impl::mutex mutex; + + // call a function of type X f() in a lock + template<typename F> + auto critical(F f) -> decltype(f()) { + impl::lock lock{mutex}; + return f(); + } + +public: + parallel_vector() = default; + using iterator = typename std::vector<value_type>::iterator; + using const_iterator = 
typename std::vector<value_type>::const_iterator; + + iterator begin() { return data_.begin(); } + iterator end() { return data_.end(); } + + const_iterator begin() const { return data_.begin(); } + const_iterator end() const { return data_.end(); } + + const_iterator cbegin() const { return data_.cbegin(); } + const_iterator cend() const { return data_.cend(); } + + // only guarantees the state of the vector, but not the iterators + // unlike tbb push_back + void push_back (value_type&& val) { + critical([&] { + data_.push_back(std::move(val)); + }); + } +}; + +inline std::string description() { + return "CThread Pool"; +} + +constexpr bool multithreaded() { return true; } + +class task_group { +private: + std::size_t in_flight = 0; + impl::task_pool& global_task_pool; + // task pool manipulates in_flight + friend impl::task_pool; + +public: + task_group(): + global_task_pool{impl::task_pool::get_global_task_pool()} + {} + + task_group(const task_group&) = delete; + task_group& operator=(const task_group&) = delete; + + // send function void f() to threads (f is copied into the queued task) + template<typename F> + void run(const F& f) { + global_task_pool.run(impl::task{f, this}); + } + + template<typename F> // forwarding overload: rvalues are moved into the task, lvalues are copied and left intact + void run(F&& f) { + global_task_pool.run(impl::task{std::forward<F>(f), this}); + } + + // run function void f() on the calling thread and then wait on all tasks in group + template<typename F> + void run_and_wait(const F& f) { + f(); + global_task_pool.wait(this); + } + + template<typename F> + void run_and_wait(F&& f) { + f(); + global_task_pool.wait(this); + } + + // wait till all tasks in this group are done + void wait() { + global_task_pool.wait(this); + } + + // Make sure that all tasks are done before clean up + ~task_group() { + wait(); + } +}; + +/////////////////////////////////////////////////////////////////////// +// algorithms +/////////////////////////////////////////////////////////////////////// +struct parallel_for { // apply() queues one task per index i in [left, right), each calling f(i), and waits for all of them + template <typename F> + static void apply(int left, int right, F f) { + 
task_group g; + for(int i = left; i < right; ++i) { + g.run([=] {f(i);}); + } + g.wait(); + } +}; + +} // threading +} // mc +} // nest diff --git a/src/threading/cthread_parallel_stable_sort.h b/src/threading/cthread_parallel_stable_sort.h new file mode 100644 index 00000000..304b1487 --- /dev/null +++ b/src/threading/cthread_parallel_stable_sort.h @@ -0,0 +1,154 @@ +/* + Copyright (C) 2014 Intel Corporation + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS + OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED + AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY + WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. 
+ + Modified for nestmc +*/ + +#include <algorithm> + +#include "pss_common.h" + +namespace pss { + +namespace internal { + +using task_group = nest::mc::threading::task_group; + +// Merge sequences [xs,xe) and [ys,ye) to output sequence [zs,zs+(xe-xs)+(ye-ys)) +// Destroy input sequence iff destroy==true; while more than MERGE_CUT_OFF elements remain, the lower part is split off as a task in group g and this call continues with the upper part +template<typename RandomAccessIterator1, + typename RandomAccessIterator2, + typename RandomAccessIterator3, + typename Compare> +void parallel_move_merge(RandomAccessIterator1 xs, + RandomAccessIterator1 xe, + RandomAccessIterator2 ys, + RandomAccessIterator2 ye, + RandomAccessIterator3 zs, + bool destroy, + Compare comp) +{ + task_group g; + const int MERGE_CUT_OFF = 2000; + while( (xe-xs) + (ye-ys) > MERGE_CUT_OFF ) { + RandomAccessIterator1 xm; + RandomAccessIterator2 ym; + if( xe-xs < ye-ys ) { + ym = ys+(ye-ys)/2; + xm = std::upper_bound(xs,xe,*ym,comp); + } else { + xm = xs+(xe-xs)/2; + ym = std::lower_bound(ys,ye,*xm,comp); + } + + g.run([=] { + parallel_move_merge( xs, xm, ys, ym, zs, destroy, comp); + }); + + zs += (xm-xs) + (ym-ys); + xs = xm; + ys = ym; + } + + serial_move_merge( xs, xe, ys, ye, zs, comp ); + if( destroy ) { + serial_destroy( xs, xe ); + serial_destroy( ys, ye ); + } + + g.wait(); // the sub-merges queued in the loop must finish before returning +} + +// Sorts [xs,xe), where zs[0:xe-xs) is temporary buffer supplied by caller. 
+// Result is in [xs,xe) if inplace is non-zero, otherwise in [zs,zs+(xe-xs)); inplace==2 (passed by the top-level call) also makes the final merge destroy the temporaries left in the buffer +template<typename RandomAccessIterator1, + typename RandomAccessIterator2, + typename Compare> +void parallel_stable_sort_aux(RandomAccessIterator1 xs, + RandomAccessIterator1 xe, + RandomAccessIterator2 zs, + int inplace, + Compare comp) +{ + //typedef typename std::iterator_traits<RandomAccessIterator2>::value_type T; + const int SORT_CUT_OFF = 500; + if( xe-xs<=SORT_CUT_OFF ) { + stable_sort_base_case(xs, xe, zs, inplace, comp); + } + else { + RandomAccessIterator1 xm = xs + (xe-xs)/2; + RandomAccessIterator2 zm = zs + (xm-xs); + RandomAccessIterator2 ze = zs + (xe-xs); + + task_group g; + g.run([&] { + parallel_stable_sort_aux( xs, xm, zs, !inplace, comp ); + }); + parallel_stable_sort_aux( xm, xe, zm, !inplace, comp ); + g.wait(); // the by-reference captures above stay valid because the task is joined here + + if( inplace ) + parallel_move_merge( zs, zm, zm, ze, xs, inplace==2, comp ); + else + parallel_move_merge( xs, xm, xm, xe, zs, false, comp ); + } +} + +} // namespace internal + +template<typename RandomAccessIterator, typename Compare> +void parallel_stable_sort(RandomAccessIterator xs, + RandomAccessIterator xe, + Compare comp ) +{ + using T + = typename std::iterator_traits<RandomAccessIterator> + ::value_type; + + if(internal::raw_buffer z + = internal::raw_buffer( sizeof(T)*(xe-xs))) + internal::parallel_stable_sort_aux( xs, xe, + (T*)z.get(), 2, comp ); + else + // Not enough memory available - fall back on serial sort + std::stable_sort( xs, xe, comp ); +} + +template<class RandomAccessIterator> +void parallel_stable_sort(RandomAccessIterator xs, + RandomAccessIterator xe) +{ + using T + = typename std::iterator_traits<RandomAccessIterator> + ::value_type; + parallel_stable_sort(xs, xe, std::less<T>()); +} +} // namespace pss diff --git a/src/threading/cthread_sort.hpp b/src/threading/cthread_sort.hpp new file mode 100644 index 00000000..cbdfc246 --- /dev/null +++ b/src/threading/cthread_sort.hpp @@ -0,0 +1,26 @@ +// parallel stable sort uses 
threading +#include "cthread_parallel_stable_sort.h" + +namespace nest { +namespace mc { +namespace threading { + +template <typename RandomIt> // sorts [begin, end) with pss::parallel_stable_sort and its default ordering +void sort(RandomIt begin, RandomIt end) { + pss::parallel_stable_sort(begin, end); +} + +template <typename RandomIt, typename Compare> // same, ordered by comp +void sort(RandomIt begin, RandomIt end, Compare comp) { + pss::parallel_stable_sort(begin, end ,comp); +} + +template <typename Container> // sorts the whole container through c.begin() and c.end() +void sort(Container& c) { + pss::parallel_stable_sort(c.begin(), c.end()); +} + + +} // threading +} // mc +} // nest diff --git a/src/threading/omp.hpp b/src/threading/omp.hpp index d056e6b5..9a5eee45 100644 --- a/src/threading/omp.hpp +++ b/src/threading/omp.hpp @@ -13,10 +13,15 @@ #include <string> #include <vector> +#include "timer.hpp" + namespace nest { namespace mc { namespace threading { +using nest::mc::threading::impl::timer; + + /////////////////////////////////////////////////////////////////////// // types /////////////////////////////////////////////////////////////////////// @@ -113,22 +118,6 @@ inline std::string description() { return "OpenMP"; } -struct timer { - using time_point = std::chrono::time_point<std::chrono::system_clock>; - - static inline time_point tic() { - return std::chrono::system_clock::now(); - } - - static inline double toc(time_point t) { - return std::chrono::duration<double>(tic() - t).count(); - } - - static inline double difference(time_point b, time_point e) { - return std::chrono::duration<double>(e-b).count(); - } -}; - constexpr bool multithreaded() { return true; } diff --git a/src/threading/serial.hpp b/src/threading/serial.hpp index b8b51009..6876d3db 100644 --- a/src/threading/serial.hpp +++ b/src/threading/serial.hpp @@ -10,10 +10,14 @@ #include <string> #include <vector> +#include "timer.hpp" + namespace nest { namespace mc { namespace threading { +using nest::mc::threading::impl::timer; + /////////////////////////////////////////////////////////////////////// // types 
/////////////////////////////////////////////////////////////////////// @@ -85,22 +89,6 @@ inline std::string description() { return "serial"; } -struct timer { - using time_point = std::chrono::time_point<std::chrono::system_clock>; - - static inline time_point tic() { - return std::chrono::system_clock::now(); - } - - static inline double toc(time_point t) { - return std::chrono::duration<double>(tic() - t).count(); - } - - static inline double difference(time_point b, time_point e) { - return std::chrono::duration<double>(e-b).count(); - } -}; - constexpr bool multithreaded() { return false; } /// Proxy for tbb task group. diff --git a/src/threading/threading.hpp b/src/threading/threading.hpp index 157b936a..03979770 100644 --- a/src/threading/threading.hpp +++ b/src/threading/threading.hpp @@ -4,6 +4,8 @@ #include "tbb.hpp" #elif defined(NMC_HAVE_OMP) #include "omp.hpp" +#elif defined(NMC_HAVE_CTHREAD) + #include "cthread.hpp" #else #define NMC_HAVE_SERIAL #include "serial.hpp" diff --git a/src/threading/timer.hpp b/src/threading/timer.hpp new file mode 100644 index 00000000..bf8242a5 --- /dev/null +++ b/src/threading/timer.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include <chrono> + +namespace nest { +namespace mc { +namespace threading { +namespace impl{ + +struct timer { // static helpers over std::chrono::system_clock; toc() and difference() return seconds as double + using time_point = std::chrono::time_point<std::chrono::system_clock>; + + static inline time_point tic() { // current system_clock time + return std::chrono::system_clock::now(); + } + + static inline double toc(time_point t) { // seconds elapsed between t and now + return std::chrono::duration<double>{tic() - t}.count(); + } + + static inline double difference(time_point b, time_point e) { // seconds from b to e + return std::chrono::duration<double>{e-b}.count(); + } +}; + +} // impl +} // threading +} // mc +} // nest -- GitLab