#include <memory> #include <pybind11/numpy.h> #include <pybind11/pybind11.h> #include <arbor/common_types.hpp> #include <arbor/sampling.hpp> #include <arbor/simulation.hpp> #include "context.hpp" #include "error.hpp" #include "pyarb.hpp" #include "recipe.hpp" #include "schedule.hpp" namespace py = pybind11; namespace pyarb { // Argument type for simulation_shim::record() (see below). enum class spike_recording { off, local, all }; // Wraps an arb::simulation object and in addition manages a set of // sampler callbacks for retrieving probe data. class simulation_shim { std::unique_ptr<arb::simulation> sim_; std::vector<arb::spike> spike_record_; pyarb_global_ptr global_ptr_; using sample_recorder_ptr = std::unique_ptr<sample_recorder>; using sample_recorder_vec = std::vector<sample_recorder_ptr>; // These are only used as the target sampler of a single probe id. struct sampler_callback { std::shared_ptr<sample_recorder_vec> recorders; void operator()(arb::probe_metadata pm, std::size_t n_record, const arb::sample_record* records) { recorders->at(pm.index)->record(pm.meta, n_record, records); } py::list samples() const { std::size_t size = recorders->size(); py::list result(size); for (std::size_t i = 0; i<size; ++i) { result[i] = py::make_tuple(recorders->at(i)->samples(), recorders->at(i)->meta()); } return result; } }; std::unordered_map<arb::sampler_association_handle, sampler_callback> sampler_map_; public: simulation_shim(std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx, pyarb_global_ptr global_ptr): global_ptr_(global_ptr) { try { sim_.reset(new arb::simulation(py_recipe_shim(rec), decomp, ctx.context)); } catch (...) { py_reset_and_throw(); throw; } } void reset() { sim_->reset(); spike_record_.clear(); for (auto&& [handle, cb]: sampler_map_) { for (auto& rec: *cb.recorders) { rec->reset(); } } } arb::time_type run(arb::time_type tfinal, arb::time_type dt) { return sim_->run(tfinal, dt); } void set_binning_policy(arb::binning_kind policy, arb::time_type bin_interval) { sim_->set_binning_policy(policy, bin_interval); } void record(spike_recording policy) { auto spike_recorder = [this](const std::vector<arb::spike>& spikes) { spike_record_.insert(spike_record_.end(), spikes.begin(), spikes.end()); }; switch (policy) { case spike_recording::off: sim_->set_global_spike_callback(); sim_->set_local_spike_callback(); break; case spike_recording::local: sim_->set_global_spike_callback(); sim_->set_local_spike_callback(spike_recorder); break; case spike_recording::all: sim_->set_global_spike_callback(spike_recorder); sim_->set_local_spike_callback(); break; } } py::object spikes() const { return py::array_t<arb::spike>(py::ssize_t(spike_record_.size()), spike_record_.data()); } py::list get_probe_metadata(arb::cell_member_type probe_id) const { py::list result; for (auto&& pm: sim_->get_probe_metadata(probe_id)) { result.append(global_ptr_->probe_meta_converters.convert(pm.meta)); } return result; } arb::sampler_association_handle sample(arb::cell_member_type probe_id, const pyarb::schedule_shim_base& sched, arb::sampling_policy policy) { std::shared_ptr<sample_recorder_vec> recorders{new sample_recorder_vec}; for (const arb::probe_metadata& pm: sim_->get_probe_metadata(probe_id)) { recorders->push_back(global_ptr_->recorder_factories.make_recorder(pm.meta)); } // Constructed callbacks are passed to the underlying simulator object, _and_ a copy // is kept in sampler_map_; the two copies share the same recorder data. sampler_callback cb{std::move(recorders)}; auto sah = sim_->add_sampler(arb::one_probe(probe_id), sched.schedule(), cb, policy); sampler_map_.insert({sah, cb}); return sah; } void remove_sampler(arb::sampler_association_handle sah) { sim_->remove_sampler(sah); sampler_map_.erase(sah); } void remove_all_samplers() { sim_->remove_all_samplers(); sampler_map_.clear(); } py::list samples(arb::sampler_association_handle sah) { if (auto iter = sampler_map_.find(sah); iter!=sampler_map_.end()) { return iter->second.samples(); } else { return py::list{}; } } }; void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) { using namespace pybind11::literals; py::enum_<arb::sampling_policy>(m, "sampling_policy") .value("lax", arb::sampling_policy::lax) .value("exact", arb::sampling_policy::exact); py::enum_<spike_recording>(m, "spike_recording") .value("off", spike_recording::off) .value("local", spike_recording::local) .value("all", spike_recording::all); // Simulation py::class_<simulation_shim> simulation(m, "simulation", "The executable form of a model.\n" "A simulation is constructed from a recipe, and then used to update and monitor model state."); simulation // A custom constructor that wraps a python recipe with arb::py_recipe_shim // before forwarding it to the arb::recipe constructor. .def(pybind11::init( [global_ptr](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) { try { return new simulation_shim(rec, decomp, ctx, global_ptr); } catch (...) { py_reset_and_throw(); throw; } }), // Release the python gil, so that callbacks into the python recipe don't deadlock. pybind11::call_guard<pybind11::gil_scoped_release>(), "Initialize the model described by a recipe, with cells and network distributed\n" "according to the domain decomposition and computational resources described by a context.", "recipe"_a, "domain_decomposition"_a, "context"_a) .def("reset", &simulation_shim::reset, pybind11::call_guard<pybind11::gil_scoped_release>(), "Reset the state of the simulation to its initial state.") .def("run", &simulation_shim::run, pybind11::call_guard<pybind11::gil_scoped_release>(), "Run the simulation from current simulation time to tfinal [ms], with maximum time step size dt [ms].", "tfinal"_a, "dt"_a=0.025) .def("set_binning_policy", &simulation_shim::set_binning_policy, "Set the binning policy for event delivery, and the binning time interval if applicable [ms].", "policy"_a, "bin_interval"_a) .def("record", &simulation_shim::record, "Disable or enable local or global spike recording.") .def("spikes", &simulation_shim::spikes, "Retrieve recorded spikes as numpy array.") .def("probe_metadata", &simulation_shim::get_probe_metadata, "Retrieve metadata associated with given probe id.", "probe_id"_a) .def("sample", &simulation_shim::sample, "Record data from probes with given probe_id according to supplied schedule.\n" "Returns handle for retrieving data or removing the sampling.", "probe_id"_a, "schedule"_a, "policy"_a = arb::sampling_policy::lax) .def("samples", &simulation_shim::samples, "Retrieve sample data as a list, one element per probe associated with the query.", "handle"_a) .def("remove_sampler", &simulation_shim::remove_sampler, "Remove sampling associated with the given handle.", "handle"_a) .def("remove_all_samplers", &simulation_shim::remove_sampler, "Remove all sampling on the simulatr."); } } // namespace pyarb