Skip to content
Snippets Groups Projects
simulation.cpp 9.11 KiB
#include <memory>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <arbor/common_types.hpp>
#include <arbor/sampling.hpp>
#include <arbor/simulation.hpp>

#include "context.hpp"
#include "error.hpp"
#include "pyarb.hpp"
#include "recipe.hpp"
#include "schedule.hpp"

namespace py = pybind11;

namespace pyarb {

// Argument type for simulation_shim::record() (see below).

enum class spike_recording {
    off, local, all
};

// Wraps an arb::simulation object and in addition manages a set of
// sampler callbacks for retrieving probe data.

class simulation_shim {
    std::unique_ptr<arb::simulation> sim_;
    std::vector<arb::spike> spike_record_;
    pyarb_global_ptr global_ptr_;

    using sample_recorder_ptr = std::unique_ptr<sample_recorder>;
    using sample_recorder_vec = std::vector<sample_recorder_ptr>;

    // These are only used as the target sampler of a single probe id.
    struct sampler_callback {
        std::shared_ptr<sample_recorder_vec> recorders;

        void operator()(arb::probe_metadata pm, std::size_t n_record, const arb::sample_record* records) {
            recorders->at(pm.index)->record(pm.meta, n_record, records);
        }

        py::list samples() const {
            std::size_t size = recorders->size();
            py::list result(size);

            for (std::size_t i = 0; i<size; ++i) {
                result[i] = py::make_tuple(recorders->at(i)->samples(), recorders->at(i)->meta());
            }
            return result;
        }
    };

    std::unordered_map<arb::sampler_association_handle, sampler_callback> sampler_map_;

public:
    simulation_shim(std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx, pyarb_global_ptr global_ptr):
        global_ptr_(global_ptr)
    {
        try {
            sim_.reset(new arb::simulation(py_recipe_shim(rec), decomp, ctx.context));
        }
        catch (...) {
            py_reset_and_throw();
            throw;
        }
    }

    void reset() {
        sim_->reset();
        spike_record_.clear();
        for (auto&& [handle, cb]: sampler_map_) {
            for (auto& rec: *cb.recorders) {
                rec->reset();
            }
        }
    }

    void clear_samplers() {
        spike_record_.clear();
        for (auto&& [handle, cb]: sampler_map_) {
            for (auto& rec: *cb.recorders) {
                rec->reset();
            }
        }
    }

    arb::time_type run(arb::time_type tfinal, arb::time_type dt) {
        return sim_->run(tfinal, dt);
    }

    void set_binning_policy(arb::binning_kind policy, arb::time_type bin_interval) {
        sim_->set_binning_policy(policy, bin_interval);
    }

    void record(spike_recording policy) {
        auto spike_recorder = [this](const std::vector<arb::spike>& spikes) {
            auto old_size = spike_record_.size();
            // Append the new spikes to the end of the spike record.
            spike_record_.insert(spike_record_.end(), spikes.begin(), spikes.end());
            // Sort the newly appended spikes.
            std::sort(spike_record_.begin()+old_size, spike_record_.end(),
                    [](const auto& lhs, const auto& rhs) {
                        return std::tie(lhs.time, lhs.source.gid, lhs.source.index)<std::tie(rhs.time, rhs.source.gid, rhs.source.index);
                    });
        };

        switch (policy) {
        case spike_recording::off:
            sim_->set_global_spike_callback();
            sim_->set_local_spike_callback();
            break;
        case spike_recording::local:
            sim_->set_global_spike_callback();
            sim_->set_local_spike_callback(spike_recorder);
            break;
        case spike_recording::all:
            sim_->set_global_spike_callback(spike_recorder);
            sim_->set_local_spike_callback();
            break;
        }
    }

    py::object spikes() const {
        return py::array_t<arb::spike>(py::ssize_t(spike_record_.size()), spike_record_.data());
    }

    py::list get_probe_metadata(arb::cell_member_type probe_id) const {
        py::list result;
        for (auto&& pm: sim_->get_probe_metadata(probe_id)) {
             result.append(global_ptr_->probe_meta_converters.convert(pm.meta));
        }
        return result;
    }

    arb::sampler_association_handle sample(arb::cell_member_type probe_id, const pyarb::schedule_shim_base& sched, arb::sampling_policy policy) {
        std::shared_ptr<sample_recorder_vec> recorders{new sample_recorder_vec};

        for (const arb::probe_metadata& pm: sim_->get_probe_metadata(probe_id)) {
            recorders->push_back(global_ptr_->recorder_factories.make_recorder(pm.meta));
        }

        // Constructed callbacks are passed to the underlying simulator object, _and_ a copy
        // is kept in sampler_map_; the two copies share the same recorder data.

        sampler_callback cb{std::move(recorders)};
        auto sah = sim_->add_sampler(arb::one_probe(probe_id), sched.schedule(), cb, policy);
        sampler_map_.insert({sah, cb});

        return sah;
    }

    void remove_sampler(arb::sampler_association_handle sah) {
        sim_->remove_sampler(sah);
        sampler_map_.erase(sah);
    }

    void remove_all_samplers() {
        sim_->remove_all_samplers();
        sampler_map_.clear();
    }

    py::list samples(arb::sampler_association_handle sah) {
        if (auto iter = sampler_map_.find(sah); iter!=sampler_map_.end()) {
            return iter->second.samples();
        }
        else {
            return py::list{};
        }
    }
};

void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) {
    using namespace pybind11::literals;

    py::enum_<arb::sampling_policy>(m, "sampling_policy")
       .value("lax", arb::sampling_policy::lax)
       .value("exact", arb::sampling_policy::exact);

    py::enum_<spike_recording>(m, "spike_recording")
       .value("off", spike_recording::off)
       .value("local", spike_recording::local)
       .value("all", spike_recording::all);

    // Simulation
    py::class_<simulation_shim> simulation(m, "simulation",
        "The executable form of a model.\n"
        "A simulation is constructed from a recipe, and then used to update and monitor model state.");
    simulation
        // A custom constructor that wraps a python recipe with arb::py_recipe_shim
        // before forwarding it to the arb::recipe constructor.
        .def(pybind11::init(
            [global_ptr](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) {
                try {
                    return new simulation_shim(rec, decomp, ctx, global_ptr);
                }
                catch (...) {
                    py_reset_and_throw();
                    throw;
                }
            }),
            // Release the python gil, so that callbacks into the python recipe don't deadlock.
            pybind11::call_guard<pybind11::gil_scoped_release>(),
            "Initialize the model described by a recipe, with cells and network distributed\n"
            "according to the domain decomposition and computational resources described by a context.",
            "recipe"_a, "domain_decomposition"_a, "context"_a)
        .def("reset", &simulation_shim::reset,
            pybind11::call_guard<pybind11::gil_scoped_release>(),
            "Reset the state of the simulation to its initial state.")
        .def("clear_samplers", &simulation_shim::clear_samplers,
             pybind11::call_guard<pybind11::gil_scoped_release>(),
             "Clearing spike and sample information. restoring memory")
        .def("run", &simulation_shim::run,
            pybind11::call_guard<pybind11::gil_scoped_release>(),
            "Run the simulation from current simulation time to tfinal [ms], with maximum time step size dt [ms].",
            "tfinal"_a, "dt"_a=0.025)
        .def("set_binning_policy", &simulation_shim::set_binning_policy,
            "Set the binning policy for event delivery, and the binning time interval if applicable [ms].",
            "policy"_a, "bin_interval"_a)
        .def("record", &simulation_shim::record,
            "Disable or enable local or global spike recording.")
        .def("spikes", &simulation_shim::spikes,
            "Retrieve recorded spikes as numpy array.")
        .def("probe_metadata", &simulation_shim::get_probe_metadata,
            "Retrieve metadata associated with given probe id.",
            "probe_id"_a)
        .def("sample", &simulation_shim::sample,
            "Record data from probes with given probe_id according to supplied schedule.\n"
            "Returns handle for retrieving data or removing the sampling.",
            "probe_id"_a, "schedule"_a, "policy"_a = arb::sampling_policy::lax)
        .def("samples", &simulation_shim::samples,
            "Retrieve sample data as a list, one element per probe associated with the query.",
            "handle"_a)
        .def("remove_sampler", &simulation_shim::remove_sampler,
            "Remove sampling associated with the given handle.",
            "handle"_a)
        .def("remove_all_samplers", &simulation_shim::remove_sampler,
            "Remove all sampling on the simulatr.");

}

} // namespace pyarb