#include <mutex>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <arbor/common_types.hpp>
#include <arbor/sampling.hpp>
#include <arbor/simulation.hpp>
#include "error.hpp"
#include "strprintf.hpp"
namespace pyarb {
// TODO: trace entry of different types/container (e.g. vector of doubles to get all samples of a cell)
// One recorded sample: when it was taken and the scalar value read at that time.
// Field order matters: samples are brace-initialized as {time, value}.
struct trace_sample {
    arb::time_type t;   // sample time (bound to Python as `time`, documented there in ms)
    double v;           // sampled value (bound to Python as `value`)
};
// Shared sample storage: one trace buffer per probe, keyed by probe id. The mutex serializes lookups/insertions into the probe_buffers map.
struct sampler_state {
    // Guards the structure of probe_buffers (lookups and insertions).
    std::mutex mutex;
    // Recorded samples per probe id.
    std::unordered_map<arb::cell_member_type, std::vector<trace_sample>> probe_buffers;

    // Return the buffer for `pid`, creating an empty one if it does not exist yet.
    // The map lookup/insertion happens under the mutex. References to mapped values
    // of a std::unordered_map stay valid across rehashing, so the returned reference
    // remains usable after the lock is released. Writes made through it are NOT
    // synchronized by the mutex, though.
    std::vector<trace_sample>& probe_buffer(arb::cell_member_type pid) {
        std::lock_guard<std::mutex> lock(mutex);
        return probe_buffers[pid];
    }

    // True if a buffer exists for `pid`.
    // Takes the lock so the lookup cannot race with a concurrent insertion.
    bool has_pid(arb::cell_member_type pid) {
        std::lock_guard<std::mutex> lock(mutex);
        return probe_buffers.count(pid) != 0;
    }

    // Append `value` to the buffer of `pid`, creating the buffer if needed.
    // Both the map access and the append happen while holding the lock.
    void push_back(arb::cell_member_type pid, trace_sample value) {
        std::lock_guard<std::mutex> lock(mutex);
        probe_buffers[pid].push_back(value);
    }

    // Direct, unsynchronized read access to all buffers.
    const std::unordered_map<arb::cell_member_type, std::vector<trace_sample>>& samples() const {
        return probe_buffers;
    }
};
// A functor that models arb::sampler_function.
// Holds a shared pointer to the sampler_state used to store the samples, so that if
// the owning sampler is garbage collected in Python, later stores from the simulation
// will not seg fault.
struct sample_callback {
std::shared_ptr<sampler_state> sample_store;
sample_callback(const std::shared_ptr<sampler_state>& state):
sample_store(state)
{}
void operator() (arb::probe_metadata pm, std::size_t n, const arb::sample_record* recs) {
auto& v = sample_store->probe_buffer(pm.id);
for (std::size_t i = 0; i<n; ++i) {
if (auto p = arb::util::any_cast<const double*>(recs[i].data)) {
v.push_back({recs[i].time, *p});
}
else {
throw std::runtime_error("unexpected sample type");
}
}
};
};
// Helper type for recording samples from a simulation.
// This type is wrapped in Python to give access to the recorded samples (samples() and clear()).
struct sampler {
    // Storage shared with the callback handed to the simulation. It is null until
    // callback() has been called, e.g. for a sampler constructed directly from Python.
    std::shared_ptr<sampler_state> sample_store;

    // Create a fresh sample store and return a callback that records into it.
    // The callback holds its own copy of the shared pointer, so the store stays
    // alive for as long as the simulation keeps the callback, even if this
    // sampler is destroyed by the calling Python code first.
    sample_callback callback() {
        sample_store = std::make_shared<sampler_state>();
        return sample_callback(sample_store);
    }

    // Return the samples recorded for probe `pid`.
    // Throws std::runtime_error if no buffer exists for `pid`, including when the
    // sampler was never attached to a simulation (null sample_store).
    const std::vector<trace_sample>& samples(arb::cell_member_type pid) const {
        if (sample_store) {
            // Single lookup under the store's lock, so it cannot race with insertions
            // made by the callback.
            std::lock_guard<std::mutex> lock(sample_store->mutex);
            auto it = sample_store->probe_buffers.find(pid);
            if (it != sample_store->probe_buffers.end()) {
                return it->second;
            }
        }
        throw std::runtime_error(util::pprintf("probe id {} does not exist", pid));
    }

    // Discard all recorded samples. Probe entries are kept, so samples(pid) for a
    // previously recorded probe returns an empty list afterwards. No-op if the
    // sampler was never attached.
    void clear() {
        if (!sample_store) {
            return;
        }
        std::lock_guard<std::mutex> lock(sample_store->mutex);
        // Bind by reference so the stored buffers (not copies of them) are cleared.
        for (auto& entry: sample_store->probe_buffers) {
            entry.second.clear();
        }
    }
};
// Adds sampler to one probe with pid
// Create a sampler and register it with `sim` to record probe `pid` every `interval` ms.
// The returned sampler shares its sample store with the callback registered here.
std::shared_ptr<sampler> attach_sampler(arb::simulation& sim, arb::time_type interval, arb::cell_member_type pid) {
    auto recorder = std::make_shared<sampler>();
    auto probes = arb::one_probe(pid);
    sim.add_sampler(probes, arb::regular_schedule(interval), recorder->callback());
    return recorder;
}
// Adds sampler to all probes
// Create a sampler and register it with `sim` to record every probe every `interval` ms.
// The returned sampler shares its sample store with the callback registered here.
std::shared_ptr<sampler> attach_sampler(arb::simulation& sim, arb::time_type interval) {
    auto recorder = std::make_shared<sampler>();
    auto cb = recorder->callback();
    sim.add_sampler(arb::all_probes, arb::regular_schedule(interval), cb);
    return recorder;
}
// Human-readable representation of a sample, used for Python __str__ and __repr__.
std::string sample_str(const trace_sample& s) {
    std::string text = util::pprintf("<arbor.sample: time {} ms, \tvalue {}>", s.t, s.v);
    return text;
}
// Register the sampling API with Python module `m`:
//   - class `trace_sample` (read-only `time` and `value`),
//   - class `sampler` (`samples(probe_id)`, `clear()`),
//   - overloaded free function `attach_sampler(sim, dt[, probe_id])`.
void register_sampling(pybind11::module& m) {
    using namespace pybind11::literals;

    // Sample
    // The local is named `sample_type` rather than `trace_sample` so that it does
    // not shadow the C++ type `trace_sample` for the rest of this function.
    pybind11::class_<trace_sample> sample_type(m, "trace_sample");
    sample_type
        .def_readonly("time", &trace_sample::t, "The sample time [ms] at a specific probe.")
        .def_readonly("value", &trace_sample::v, "The sample record at a specific probe.")
        .def("__str__", &sample_str)
        .def("__repr__", &sample_str);

    // Sampler
    pybind11::class_<sampler, std::shared_ptr<sampler>> samplerec(m, "sampler");
    samplerec
        .def(pybind11::init<>())
        .def("samples", &sampler::samples,
            "A list of the recorded samples of a probe with probe id.",
            "probe_id"_a)
        .def("clear", &sampler::clear, "Clear all recorded samples.");

    // attach_sampler is overloaded in C++; static_cast picks the overload to bind.
    m.def("attach_sampler",
        static_cast<std::shared_ptr<sampler> (*)(arb::simulation&, arb::time_type)>(&attach_sampler),
        "Attach a sample recorder to an arbor simulation.\n"
        "The recorder will record all samples from a regular sampling interval [ms] matching all probe ids.",
        "sim"_a, "dt"_a);

    m.def("attach_sampler",
        static_cast<std::shared_ptr<sampler> (*)(arb::simulation&, arb::time_type, arb::cell_member_type)>(&attach_sampler),
        "Attach a sample recorder to an arbor simulation.\n"
        "The recorder will record all samples from a regular sampling interval [ms] matching one probe id.",
        "sim"_a, "dt"_a, "probe_id"_a);
}
} // namespace pyarb