Skip to content
Snippets Groups Projects
Commit 6c98c1fc authored by Alexander Peyser's avatar Alexander Peyser Committed by Sam Yates
Browse files

Threading pool (#144)

Add threading pool built on `std::thread`

* Provide new threading model 'cthread' for nestmc based on a pool of `std::thread` objects.
* Unify duplicated timer class provided by `serial`, `omp` and now `cthread` threading models.
parent e5092d3f
No related branches found
No related tags found
No related merge requests found
......@@ -69,6 +69,16 @@ elseif(NMC_THREADING_MODEL MATCHES "omp")
add_definitions(-DNMC_HAVE_OMP)
set(NMC_HAVE_OMP TRUE)
elseif(NMC_THREADING_MODEL MATCHES "cthread")
find_package(Threads REQUIRED)
add_definitions(-DNMC_HAVE_CTHREAD)
set(NMC_HAVE_CTHREAD TRUE)
list(APPEND EXTERNAL_LIBRARIES ${CMAKE_THREAD_LIBS_INIT})
if(CMAKE_USE_PTHREADS_INIT)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
endif()
elseif(NMC_THREADING_MODEL MATCHES "serial")
#setup previously done
......
......@@ -20,6 +20,10 @@ if(NMC_WITH_MPI)
set(BASE_SOURCES ${BASE_SOURCES} communication/mpi.cpp)
endif()
if(NMC_HAVE_CTHREAD)
set(BASE_SOURCES ${BASE_SOURCES} threading/cthread.cpp)
endif()
add_library(nestmc ${BASE_SOURCES} ${HEADERS})
if (NMC_AUTO_RUN_MODCC_ON_CHANGES)
......
#include <cassert>
#include <exception>
#include <iostream>
#include "cthread.hpp"
using namespace nest::mc::threading::impl;
// RAII owner for a task in flight
struct task_pool::run_task {
task_pool& pool;
lock& lck;
task tsk;
run_task(task_pool&, lock&);
~run_task();
};
// Own a task in flight
// lock should be passed locked,
// and will be unlocked after call
task_pool::run_task::run_task(task_pool& pool, lock& lck):
pool{pool},
lck{lck},
tsk{}
{
std::swap(tsk, pool.tasks_.front());
pool.tasks_.pop_front();
lck.unlock();
pool.tasks_available_.notify_all();
}
// Release task
// Call unlocked, returns unlocked
task_pool::run_task::~run_task() {
lck.lock();
tsk.second->in_flight--;
lck.unlock();
pool.tasks_available_.notify_all();
}
template<typename B>
void task_pool::run_tasks_loop(B finished) {
lock lck{tasks_mutex_, std::defer_lock};
while (true) {
lck.lock();
while (! quit_ && tasks_.empty() && ! finished()) {
tasks_available_.wait(lck);
}
if (quit_ || finished()) {
return;
}
run_task run{*this, lck};
run.tsk.first();
}
}
// runs forever until quit is true
void task_pool::run_tasks_forever() {
run_tasks_loop([] {return false;});
}
// run until out of tasks for a group
void task_pool::run_tasks_while(task_group* g) {
run_tasks_loop([=] {return ! g->in_flight;});
}
// Create pool and threads
// new threads are nthreads-1
task_pool::task_pool(std::size_t nthreads):
tasks_mutex_{},
tasks_available_{},
tasks_{},
threads_{}
{
assert(nthreads > 0);
// now for the main thread
auto tid = std::this_thread::get_id();
thread_ids_[tid] = 0;
// and go from there
for (std::size_t i = 1; i < nthreads; i++) {
threads_.emplace_back([this]{run_tasks_forever();});
tid = threads_.back().get_id();
thread_ids_[tid] = i;
}
}
task_pool::~task_pool() {
{
lock lck{tasks_mutex_};
quit_ = true;
}
tasks_available_.notify_all();
for (auto& thread: threads_) {
thread.join();
}
}
// push a task into pool
void task_pool::run(const task& tsk) {
{
lock lck{tasks_mutex_};
tasks_.push_back(tsk);
tsk.second->in_flight++;
}
tasks_available_.notify_all();
}
void task_pool::run(task&& tsk) {
{
lock lck{tasks_mutex_};
tasks_.push_back(std::move(tsk));
tsk.second->in_flight++;
}
tasks_available_.notify_all();
}
// call on main thread
// uses this thread to run tasks
// and waits until the entire task
// queue is cleared
void task_pool::wait(task_group* g) {
run_tasks_while(g);
}
[[noreturn]]
static void terminate(const char *const msg) {
std::cerr << "NMC_NUM_THREADS_ERROR: " << msg << std::endl;
std::terminate();
}
// should check string, throw exception on missing or badly formed
static size_t global_get_num_threads() {
const char* nthreads_str;
// select variable to use:
// If NMC_NUM_THREADS_VAR is set, use $NMC_NUM_THREADS_VAR
// else if NMC_NUM_THREAD set, use it
// else if OMP_NUM_THREADS set, use it
if (auto nthreads_var_name = std::getenv("NMC_NUM_THREADS_VAR")) {
nthreads_str = std::getenv(nthreads_var_name);
}
else if (! (nthreads_str = std::getenv("NMC_NUM_THREADS"))) {
nthreads_str = std::getenv("OMP_NUM_THREADS");
}
// If the selected var is unset,
// or no var is set,
// error
if (! nthreads_str) {
terminate("No environmental var defined");
}
// only composed of spaces*digits*space*
auto nthreads_str_end{nthreads_str};
while (std::isspace(*nthreads_str_end)) {
++nthreads_str_end;
}
while (std::isdigit(*nthreads_str_end)) {
++nthreads_str_end;
}
while (std::isspace(*nthreads_str_end)) {
++nthreads_str_end;
}
if (*nthreads_str_end) {
terminate("Num threads is not a single integer");
}
// and it's got a single non-zero value
auto nthreads{std::atoi(nthreads_str)};
if (! nthreads) {
terminate("Num threads is not a non-zero number");
}
return nthreads;
}
task_pool& task_pool::get_global_task_pool() {
static task_pool global_task_pool{global_get_num_threads()};
return global_task_pool;
}
#pragma once
#if !defined(NMC_HAVE_CTHREAD)
#error "this header can only be loaded if NMC_HAVE_CTHREAD is set"
#endif
// task_group definition
#include "cthread_impl.hpp"
// and sorts use cthread_parallel_stable_sort
#include "cthread_sort.hpp"
#pragma once
#include <thread>
#include <mutex>
#include <algorithm>
#include <array>
#include <chrono>
#include <string>
#include <vector>
#include <type_traits>
#include <functional>
#include <condition_variable>
#include <utility>
#include <unordered_map>
#include <deque>
#include <cstdlib>
#include "timer.hpp"
namespace nest {
namespace mc {
namespace threading {
// Forward declare task_group at bottom of this header
class task_group;
using nest::mc::threading::impl::timer;
namespace impl {
using nest::mc::threading::task_group;
using std::mutex;
using lock = std::unique_lock<mutex>;
using std::condition_variable;
using task = std::pair<std::function<void()>, task_group*>;
using task_queue = std::deque<task>;
using thread_list = std::vector<std::thread>;
using thread_map = std::unordered_map<std::thread::id, std::size_t>;
class task_pool {
private:
// lock and signal on task availability change
// this is the crucial bit
mutex tasks_mutex_;
condition_variable tasks_available_;
// fifo of pending tasks
task_queue tasks_;
// thread resource
thread_list threads_;
// threads -> index
thread_map thread_ids_;
// flag to handle exit from all threads
bool quit_ = false;
// internals for taking tasks as a resource
// and running them (updating above)
// They get run by a thread in order to consume
// tasks
struct run_task;
// run tasks until a task_group tasks are done
// for wait
void run_tasks_while(task_group*);
// loop forever for secondary threads
// until quit is set
void run_tasks_forever();
// common code for the previous
// finished is a function/lambda
// that returns true when the infinite loop
// needs to be broken
template<typename B>
void run_tasks_loop(B finished );
// Create nthreads-1 new c std threads
// must be > 0
// singled only created in static get_global_task_pool()
task_pool(std::size_t nthreads);
// task_pool is a singleton
task_pool(const task_pool&) = delete;
task_pool& operator=(const task_pool&) = delete;
// set quit and wait for secondary threads to end
~task_pool();
public:
// Like tbb calls: run queues a task,
// wait waits for all tasks in the group to be done
void run(const task&);
void run(task&&);
void wait(task_group*);
// includes master thread
int get_num_threads() {
return threads_.size() + 1;
}
// get a stable integer for the current thread that
// is 0..nthreads
std::size_t get_current_thread() {
return thread_ids_[std::this_thread::get_id()];
}
// singleton constructor - needed to order construction
// with other singletons (profiler)
static task_pool& get_global_task_pool();
};
} //impl
///////////////////////////////////////////////////////////////////////
// types
///////////////////////////////////////////////////////////////////////
template <typename T>
class enumerable_thread_specific {
impl::task_pool& global_task_pool;
using storage_class = std::vector<T>;
storage_class data;
public :
using iterator = typename storage_class::iterator;
using const_iterator = typename storage_class::const_iterator;
enumerable_thread_specific():
global_task_pool{impl::task_pool::get_global_task_pool()},
data{std::vector<T>(global_task_pool.get_num_threads())}
{}
enumerable_thread_specific(const T& init):
global_task_pool{impl::task_pool::get_global_task_pool()},
data{std::vector<T>(global_task_pool.get_num_threads(), init)}
{}
T& local() {
return data[global_task_pool.get_current_thread()];
}
const T& local() const {
return data[global_task_pool.get_current_thread()];
}
auto size() -> decltype(data.size()) const { return data.size(); }
iterator begin() { return data.begin(); }
iterator end() { return data.end(); }
const_iterator begin() const { return data.begin(); }
const_iterator end() const { return data.end(); }
const_iterator cbegin() const { return data.cbegin(); }
const_iterator cend() const { return data.cend(); }
};
template <typename T>
class parallel_vector {
using value_type = T;
std::vector<value_type> data_;
private:
// lock the parallel_vector to update
impl::mutex mutex;
// call a function of type X f() in a lock
template<typename F>
auto critical(F f) -> decltype(f()) {
impl::lock lock{mutex};
return f();
}
public:
parallel_vector() = default;
using iterator = typename std::vector<value_type>::iterator;
using const_iterator = typename std::vector<value_type>::const_iterator;
iterator begin() { return data_.begin(); }
iterator end() { return data_.end(); }
const_iterator begin() const { return data_.begin(); }
const_iterator end() const { return data_.end(); }
const_iterator cbegin() const { return data_.cbegin(); }
const_iterator cend() const { return data_.cend(); }
// only guarantees the state of the vector, but not the iterators
// unlike tbb push_back
void push_back (value_type&& val) {
critical([&] {
data_.push_back(std::move(val));
});
}
};
inline std::string description() {
return "CThread Pool";
}
constexpr bool multithreaded() { return true; }
class task_group {
private:
std::size_t in_flight = 0;
impl::task_pool& global_task_pool;
// task pool manipulates in_flight
friend impl::task_pool;
public:
task_group():
global_task_pool{impl::task_pool::get_global_task_pool()}
{}
task_group(const task_group&) = delete;
task_group& operator=(const task_group&) = delete;
// send function void f() to threads
template<typename F>
void run(const F& f) {
global_task_pool.run(impl::task{f, this});
}
template<typename F>
void run(F&& f) {
global_task_pool.run(impl::task{std::move(f), this});
}
// run function void f() and then wait on all threads in group
template<typename F>
void run_and_wait(const F& f) {
f();
global_task_pool.wait(this);
}
template<typename F>
void run_and_wait(F&& f) {
f();
global_task_pool.wait(this);
}
// wait till all tasks in this group are done
void wait() {
global_task_pool.wait(this);
}
// Make sure that all tasks are done before clean up
~task_group() {
wait();
}
};
///////////////////////////////////////////////////////////////////////
// algorithms
///////////////////////////////////////////////////////////////////////
struct parallel_for {
template <typename F>
static void apply(int left, int right, F f) {
task_group g;
for(int i = left; i < right; ++i) {
g.run([=] {f(i);});
}
g.wait();
}
};
} // threading
} // mc
} // nest
/*
Copyright (C) 2014 Intel Corporation
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
* Neither the name of Intel Corporation nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
Modified for nestmc
*/
#include <algorithm>
#include "pss_common.h"
namespace pss {
namespace internal {
using task_group = nest::mc::threading::task_group;
// Merge sequences [xs,xe) and [ys,ye) to output sequence [zs,zs+(xe-xs)+(ye-ys))
// Destroy input sequence iff destroy==true
template<typename RandomAccessIterator1,
typename RandomAccessIterator2,
typename RandomAccessIterator3,
typename Compare>
void parallel_move_merge(RandomAccessIterator1 xs,
RandomAccessIterator1 xe,
RandomAccessIterator2 ys,
RandomAccessIterator2 ye,
RandomAccessIterator3 zs,
bool destroy,
Compare comp)
{
task_group g;
const int MERGE_CUT_OFF = 2000;
while( (xe-xs) + (ye-ys) > MERGE_CUT_OFF ) {
RandomAccessIterator1 xm;
RandomAccessIterator2 ym;
if( xe-xs < ye-ys ) {
ym = ys+(ye-ys)/2;
xm = std::upper_bound(xs,xe,*ym,comp);
} else {
xm = xs+(xe-xs)/2;
ym = std::lower_bound(ys,ye,*xm,comp);
}
g.run([=] {
parallel_move_merge( xs, xm, ys, ym, zs, destroy, comp);
});
zs += (xm-xs) + (ym-ys);
xs = xm;
ys = ym;
}
serial_move_merge( xs, xe, ys, ye, zs, comp );
if( destroy ) {
serial_destroy( xs, xe );
serial_destroy( ys, ye );
}
g.wait();
}
// Sorts [xs,xe), where zs[0:xe-xs) is temporary buffer supplied by caller.
// Result is in [xs,xe) if inplace==true, otherwise in [zs,zs+(xe-xs))
template<typename RandomAccessIterator1,
typename RandomAccessIterator2,
typename Compare>
void parallel_stable_sort_aux(RandomAccessIterator1 xs,
RandomAccessIterator1 xe,
RandomAccessIterator2 zs,
int inplace,
Compare comp)
{
//typedef typename std::iterator_traits<RandomAccessIterator2>::value_type T;
const int SORT_CUT_OFF = 500;
if( xe-xs<=SORT_CUT_OFF ) {
stable_sort_base_case(xs, xe, zs, inplace, comp);
}
else {
RandomAccessIterator1 xm = xs + (xe-xs)/2;
RandomAccessIterator2 zm = zs + (xm-xs);
RandomAccessIterator2 ze = zs + (xe-xs);
task_group g;
g.run([&] {
parallel_stable_sort_aux( xs, xm, zs, !inplace, comp );
});
parallel_stable_sort_aux( xm, xe, zm, !inplace, comp );
g.wait();
if( inplace )
parallel_move_merge( zs, zm, zm, ze, xs, inplace==2, comp );
else
parallel_move_merge( xs, xm, xm, xe, zs, false, comp );
}
}
} // namespace internal
template<typename RandomAccessIterator, typename Compare>
void parallel_stable_sort(RandomAccessIterator xs,
RandomAccessIterator xe,
Compare comp )
{
using T
= typename std::iterator_traits<RandomAccessIterator>
::value_type;
if(internal::raw_buffer z
= internal::raw_buffer( sizeof(T)*(xe-xs)))
internal::parallel_stable_sort_aux( xs, xe,
(T*)z.get(), 2, comp );
else
// Not enough memory available - fall back on serial sort
std::stable_sort( xs, xe, comp );
}
template<class RandomAccessIterator>
void parallel_stable_sort(RandomAccessIterator xs,
RandomAccessIterator xe)
{
using T
= typename std::iterator_traits<RandomAccessIterator>
::value_type;
parallel_stable_sort(xs, xe, std::less<T>());
}
} // namespace pss
// parallel stable sort uses threading
#include "cthread_parallel_stable_sort.h"
namespace nest {
namespace mc {
namespace threading {
template <typename RandomIt>
void sort(RandomIt begin, RandomIt end) {
pss::parallel_stable_sort(begin, end);
}
template <typename RandomIt, typename Compare>
void sort(RandomIt begin, RandomIt end, Compare comp) {
pss::parallel_stable_sort(begin, end ,comp);
}
template <typename Container>
void sort(Container& c) {
pss::parallel_stable_sort(c.begin(), c.end());
}
}
}
}
......@@ -13,10 +13,15 @@
#include <string>
#include <vector>
#include "timer.hpp"
namespace nest {
namespace mc {
namespace threading {
using nest::mc::threading::impl::timer;
///////////////////////////////////////////////////////////////////////
// types
///////////////////////////////////////////////////////////////////////
......@@ -113,22 +118,6 @@ inline std::string description() {
return "OpenMP";
}
struct timer {
using time_point = std::chrono::time_point<std::chrono::system_clock>;
static inline time_point tic() {
return std::chrono::system_clock::now();
}
static inline double toc(time_point t) {
return std::chrono::duration<double>(tic() - t).count();
}
static inline double difference(time_point b, time_point e) {
return std::chrono::duration<double>(e-b).count();
}
};
constexpr bool multithreaded() { return true; }
......
......@@ -10,10 +10,14 @@
#include <string>
#include <vector>
#include "timer.hpp"
namespace nest {
namespace mc {
namespace threading {
using nest::mc::threading::impl::timer;
///////////////////////////////////////////////////////////////////////
// types
///////////////////////////////////////////////////////////////////////
......@@ -85,22 +89,6 @@ inline std::string description() {
return "serial";
}
struct timer {
using time_point = std::chrono::time_point<std::chrono::system_clock>;
static inline time_point tic() {
return std::chrono::system_clock::now();
}
static inline double toc(time_point t) {
return std::chrono::duration<double>(tic() - t).count();
}
static inline double difference(time_point b, time_point e) {
return std::chrono::duration<double>(e-b).count();
}
};
constexpr bool multithreaded() { return false; }
/// Proxy for tbb task group.
......
......@@ -4,6 +4,8 @@
#include "tbb.hpp"
#elif defined(NMC_HAVE_OMP)
#include "omp.hpp"
#elif defined(NMC_HAVE_CTHREAD)
#include "cthread.hpp"
#else
#define NMC_HAVE_SERIAL
#include "serial.hpp"
......
#pragma once
#include <chrono>
namespace nest {
namespace mc {
namespace threading {
namespace impl{
struct timer {
using time_point = std::chrono::time_point<std::chrono::system_clock>;
static inline time_point tic() {
return std::chrono::system_clock::now();
}
static inline double toc(time_point t) {
return std::chrono::duration<double>{tic() - t}.count();
}
static inline double difference(time_point b, time_point e) {
return std::chrono::duration<double>{e-b}.count();
}
};
}
}
}
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment