From 177cce0d810e4c186bd5960d80c6d471b895afc1 Mon Sep 17 00:00:00 2001
From: noraabiakar <nora.abiakar@gmail.com>
Date: Wed, 25 Jul 2018 15:16:41 +0200
Subject: [PATCH] fix profiler (#548)

Add initialize method to the profiler to set up the needed threading parameters given a simulation's task system.
---
 arbor/profile/profiler.cpp         | 25 ++++++++++++++++++++++---
 example/bench/bench.cpp            |  3 +++
 example/brunel/brunel_miniapp.cpp  |  3 +++
 example/miniapp/miniapp.cpp        |  4 +++-
 include/arbor/profile/profiler.hpp |  2 ++
 5 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/arbor/profile/profiler.cpp b/arbor/profile/profiler.cpp
index 7a9ff84e..5ccea418 100644
--- a/arbor/profile/profiler.cpp
+++ b/arbor/profile/profiler.cpp
@@ -83,6 +83,8 @@ public:
 class profiler {
     std::vector<recorder> recorders_;
 
+    std::unordered_map<std::thread::id, std::size_t> thread_ids_;
+
     // Hash table that maps region names to a unique index.
     // The regions are assigned consecutive indexes in the order that they are
     // added to the profiler with calls to `region_index()`, with the first
@@ -96,9 +98,13 @@ class profiler {
     // Used to protect name_index_, which is shared between all threads.
     std::mutex mutex_;
 
+    // Flag to indicate whether the profiler has been initialized with the task_system
+    bool init_ = false;
+
 public:
     profiler();
 
+    void initialize(task_system_handle& ts);
     void enter(region_id_type index);
     void enter(const char* name);
     void leave();
@@ -169,17 +175,26 @@ profiler::profiler() {
     recorders_.resize(threading::num_threads());
 }
 
+void profiler::initialize(task_system_handle& ts) {
+    recorders_.resize(ts.get()->get_num_threads());
+    thread_ids_ = ts.get()->get_thread_ids();
+    init_ = true;
+}
+
 void profiler::enter(region_id_type index) {
-    recorders_[threading::thread_id()].enter(index);
+    if (!init_) return;
+    recorders_[thread_ids_.at(std::this_thread::get_id())].enter(index);
 }
 
 void profiler::enter(const char* name) {
+    if (!init_) return;
     const auto index = region_index(name);
-    recorders_[threading::thread_id()].enter(index);
+    recorders_[thread_ids_.at(std::this_thread::get_id())].enter(index);
 }
 
 void profiler::leave() {
-    recorders_[threading::thread_id()].leave();
+    if (!init_) return;
+    recorders_[thread_ids_.at(std::this_thread::get_id())].leave();
 }
 
 region_id_type profiler::region_index(const char* name) {
@@ -328,6 +343,10 @@ void profiler_enter(region_id_type region_id) {
     profiler::get_global_profiler().enter(region_id);
 }
 
+void profiler_initialize(task_system_handle& ts) {
+    profiler::get_global_profiler().initialize(ts);
+}
+
 // Print profiler statistics to an ostream
 std::ostream& operator<<(std::ostream& o, const profile& prof) {
     char buf[80];
diff --git a/example/bench/bench.cpp b/example/bench/bench.cpp
index 52e11c20..2c566e7b 100644
--- a/example/bench/bench.cpp
+++ b/example/bench/bench.cpp
@@ -36,6 +36,9 @@ int main(int argc, char** argv) {
 #ifdef ARB_HAVE_MPI
         aux::with_mpi guard(&argc, &argv);
         context.distributed = mpi_context(MPI_COMM_WORLD);
+#endif
+#ifdef ARB_HAVE_PROFILING
+        profile::profiler_initialize(context.thread_pool);
 #endif
         const bool is_root =  context.distributed.id()==0;
 
diff --git a/example/brunel/brunel_miniapp.cpp b/example/brunel/brunel_miniapp.cpp
index 463be264..45d82cb7 100644
--- a/example/brunel/brunel_miniapp.cpp
+++ b/example/brunel/brunel_miniapp.cpp
@@ -193,6 +193,9 @@ int main(int argc, char** argv) {
 #ifdef ARB_MPI_ENABLED
         with_mpi guard(argc, argv, false);
         context.distributed = mpi_context(MPI_COMM_WORLD);
+#endif
+#ifdef ARB_HAVE_PROFILING
+        profile::profiler_initialize(context.thread_pool);
 #endif
         arb::profile::meter_manager meters(&context.distributed);
         meters.start();
diff --git a/example/miniapp/miniapp.cpp b/example/miniapp/miniapp.cpp
index edbc220f..a809a3f3 100644
--- a/example/miniapp/miniapp.cpp
+++ b/example/miniapp/miniapp.cpp
@@ -51,7 +51,9 @@ int main(int argc, char** argv) {
         with_mpi guard(argc, argv, false);
         context.distributed = mpi_context(MPI_COMM_WORLD);
 #endif
-
+#ifdef ARB_HAVE_PROFILING
+        profile::profiler_initialize(context.thread_pool);
+#endif
         profile::meter_manager meters(&context.distributed);
         meters.start();
 
diff --git a/include/arbor/profile/profiler.hpp b/include/arbor/profile/profiler.hpp
index 6d831892..fa8da040 100644
--- a/include/arbor/profile/profiler.hpp
+++ b/include/arbor/profile/profiler.hpp
@@ -5,6 +5,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include <arbor/execution_context.hpp>
 #include <arbor/profile/timer.hpp>
 
 namespace arb {
@@ -32,6 +33,7 @@ struct profile {
 };
 
 void profiler_clear();
+void profiler_initialize(task_system_handle& ts);
 void profiler_enter(std::size_t region_id);
 void profiler_leave();
 
-- 
GitLab