diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index c2036d91b5091965417bb8ead1b2d0740911789b..802d94f9d13053be78aa45721922e99ebe4b1c0f 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -69,7 +69,7 @@ public: /// must be called after all connections have been added void construct() { if (!std::is_sorted(connections_.begin(), connections_.end())) { - std::sort(connections_.begin(), connections_.end()); + threading::sort(connections_); } } diff --git a/src/threading/serial.hpp b/src/threading/serial.hpp index da2740a6b7bb8909f0c4ded32262f157cef7541d..d32058172ca0c2d5c9bb877f8402f7fc7fc9af3f 100644 --- a/src/threading/serial.hpp +++ b/src/threading/serial.hpp @@ -62,10 +62,24 @@ struct parallel_for { } }; +template <typename RandomIt> +void sort(RandomIt begin, RandomIt end) { + std::sort(begin, end); +} + +template <typename RandomIt, typename Compare> +void sort(RandomIt begin, RandomIt end, Compare comp) { + std::sort(begin, end, comp); +} + +template <typename Container> +void sort(Container& c) { + std::sort(c.begin(), c.end()); +} + template <typename T> using parallel_vector = std::vector<T>; - inline std::string description() { return "serial"; } diff --git a/src/threading/tbb.hpp b/src/threading/tbb.hpp index 853ceefde731cb1cd4db323541132c935011c73a..91a7b59b44ed6eaa8394da7fd5f4fa6586556a84 100644 --- a/src/threading/tbb.hpp +++ b/src/threading/tbb.hpp @@ -51,6 +51,21 @@ using parallel_vector = tbb::concurrent_vector<T>; using task_group = tbb::task_group; +template <typename RandomIt> +void sort(RandomIt begin, RandomIt end) { + tbb::parallel_sort(begin, end); +} + +template <typename RandomIt, typename Compare> +void sort(RandomIt begin, RandomIt end, Compare comp) { + tbb::parallel_sort(begin, end, comp); +} + +template <typename Container> +void sort(Container& c) { + tbb::parallel_sort(c.begin(), c.end()); +} + } // threading } // mc } // nest diff --git a/tests/unit/test_algorithms.cpp b/tests/unit/test_algorithms.cpp index b371b9d6f2d071057082e0354951876c7e732b80..9d370ddaf47c94fda470c01719a5c1989e5b766d 100644 --- a/tests/unit/test_algorithms.cpp +++ b/tests/unit/test_algorithms.cpp @@ -6,6 +6,28 @@ #include "../test_util.hpp" #include "util/debug.hpp" +/// tests the sort implementation in threading +/// is only parallel if TBB is being used +TEST(algorithms, parallel_sort) +{ + auto n = 10000; + std::vector<int> v(n); + std::iota(v.begin(), v.end(), 1); + + std::random_device rd; + std::shuffle(v.begin(), v.end(), std::mt19937(rd())); + + // assert that the original vector has in fact been permuted + EXPECT_FALSE(std::is_sorted(v.begin(), v.end())); + + nest::mc::threading::sort(v); + + EXPECT_TRUE(std::is_sorted(v.begin(), v.end())); + for(auto i=0; i<n; ++i) { + EXPECT_EQ(i+1, v[i]); + } +} + TEST(algorithms, sum) {