Skip to content
Snippets Groups Projects
  • Sam Yates's avatar
    New recipe probe API (#1054) · 8d866593
    Sam Yates authored
    * Make recipe return probes as a vector.
    * Remove `probe_id` from `probe_info`.
    * Rename `fvm_probe_info` to `fvm_probe_data` (everything was being called info, and it was getting confusing).
    * Make `probe_association_map` specific to `mc_cell_group`/`fvm_lowered_cell`.
    * Change probe_association_map representation: unordered map for probe_id -> tag; unordered multimap for probe_id -> fvm_probe_data. This allows multiple probes to be associated with the same probe id.
    * Call sampler callback separately for each probe with the same probe_id.
    * Replace location-based probes with locset equivalents.
    * Add index for probes sharing a probe id.
    * Bundle all probe metadata (id, tag, index, probe-specific meta) into `probe_metadata` struct for passing to sampler callbacks.
    * Change simple_sampler to work on `trace_vector`, a vector of `trace_data`. The _i_th element is the data from probe with index i.
    * Consolidate hash composition and `std::hash` specialization code in new header.
    * Update python lib for new API.
    * Update tests and examples for new recipe, internal probe, and simple_sampler APIs.
    * Update docs to suit.
    Unverified
    8d866593
recipe.hpp 6.35 KiB
#pragma once

#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>

#include <arbor/event_generator.hpp>
#include <arbor/cable_cell_param.hpp>
#include <arbor/recipe.hpp>

#include "error.hpp"
#include "strprintf.hpp"

namespace pyarb {

// Declaration only; the definition is not part of this header.
// Returns an arb::probe_info for the given probe kind name, cell member id
// and location (arb::mlocation).
// NOTE(review): the set of accepted `kind` strings is decided by the
// definition -- confirm there before relying on a particular name.
arb::probe_info cable_probe(std::string kind, arb::cell_member_type id, arb::mlocation loc);

// pyarb::py_recipe is the recipe interface used by Python.
// Calls that return generic types return pybind11::object, to avoid
// having to wrap some C++ types used by the C++ interface (specifically
// util::unique_any, util::any, std::unique_ptr, etc.).
// For example, requests for a cell description return pybind11::object instead
// of the util::unique_any used by the C++ recipe interface.
// The py_recipe_shim unwraps the Python objects and forwards them
// to the C++ back end.

class py_recipe {
public:
    py_recipe() = default;
    virtual ~py_recipe() = default;

    // Required queries: there is no C++ fallback, so an implementation
    // must provide all three.
    virtual arb::cell_size_type num_cells() const = 0;
    virtual pybind11::object cell_description(arb::cell_gid_type gid) const = 0;
    virtual arb::cell_kind cell_kind(arb::cell_gid_type gid) const = 0;

    // Optional queries: each has a default describing a cell with no
    // sources, targets, generators, connections, gap junctions or probes.

    // Default: the cell has no spike sources.
    virtual arb::cell_size_type num_sources(arb::cell_gid_type) const {
        return 0;
    }

    // Default: the cell has no event targets.
    virtual arb::cell_size_type num_targets(arb::cell_gid_type) const {
        return 0;
    }

    // Default: the number of entries reported by gap_junctions_on(gid).
    virtual arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const {
        const auto junctions = gap_junctions_on(gid);
        return static_cast<arb::cell_size_type>(junctions.size());
    }

    // Default: no event generators. Generators are handed over as generic
    // Python objects.
    virtual std::vector<pybind11::object> event_generators(arb::cell_gid_type gid) const {
        return std::vector<pybind11::object>();
    }

    // Default: no incoming connections.
    virtual std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const {
        return std::vector<arb::cell_connection>();
    }

    // Default: no gap junction connections.
    virtual std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type) const {
        return std::vector<arb::gap_junction_connection>();
    }

    // Default: no probes on the cell.
    virtual std::vector<arb::probe_info> get_probes(arb::cell_gid_type gid) const {
        return std::vector<arb::probe_info>();
    }
    //TODO: virtual pybind11::object global_properties(arb::cell_kind kind) const {return pybind11::none();};
};

// Trampoline class for py_recipe: every virtual member of py_recipe is
// overridden here with a pybind11 overload macro, which dispatches the call
// to a Python-side override of the same name.
//
// The macro arguments are, in order: the C++ return type, the parent class,
// the function name, and then the call arguments.
class py_recipe_trampoline: public py_recipe {
public:
    // The three members below are pure virtual in py_recipe, so they use
    // PYBIND11_OVERLOAD_PURE: there is no C++ implementation to fall back on.
    arb::cell_size_type num_cells() const override {
        PYBIND11_OVERLOAD_PURE(arb::cell_size_type, py_recipe, num_cells);
    }

    pybind11::object cell_description(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD_PURE(pybind11::object, py_recipe, cell_description, gid);
    }

    arb::cell_kind cell_kind(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD_PURE(arb::cell_kind, py_recipe, cell_kind, gid);
    }

    // The remaining members have default implementations in py_recipe, so
    // they use PYBIND11_OVERLOAD, which can fall back to the py_recipe version.
    arb::cell_size_type num_sources(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD(arb::cell_size_type, py_recipe, num_sources, gid);
    }

    arb::cell_size_type num_targets(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD(arb::cell_size_type, py_recipe, num_targets, gid);
    }

    arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD(arb::cell_size_type, py_recipe, num_gap_junction_sites, gid);
    }

    std::vector<pybind11::object> event_generators(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD(std::vector<pybind11::object>, py_recipe, event_generators, gid);
    }

    std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD(std::vector<arb::cell_connection>, py_recipe, connections_on, gid);
    }

    std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid);
    }

    std::vector<arb::probe_info> get_probes(arb::cell_gid_type gid) const override {
        PYBIND11_OVERLOAD(std::vector<arb::probe_info>, py_recipe, get_probes, gid);
    }
};

// A recipe shim that holds a pyarb::py_recipe implementation.
// Unwraps/translates Python-side output from pyarb::py_recipe and forwards
// it to arb::recipe.
// For example, unwraps cell descriptions stored in a Python object and
// rewraps them in util::unique_any.

class py_recipe_shim: public arb::recipe {
    // Shared handle to the py_recipe implementation that calls are forwarded to.
    std::shared_ptr<py_recipe> impl_;

public:
    using recipe::recipe;

    py_recipe_shim(std::shared_ptr<py_recipe> r): impl_(std::move(r)) {}

    // Message passed to try_catch_pyexception by every forwarding call below.
    const char* msg = "Python error already thrown";

    // Each forwarding member wraps the call on impl_ in a lambda and runs it
    // through try_catch_pyexception together with msg.

    arb::cell_size_type num_cells() const override {
        return try_catch_pyexception(
            [this]() { return impl_->num_cells(); },
            msg);
    }

    // pyarb::py_recipe::cell_description returns a pybind11::object, which is
    // unwrapped and copied into a util::unique_any. Defined out of line.
    arb::util::unique_any get_cell_description(arb::cell_gid_type gid) const override;

    arb::cell_kind get_cell_kind(arb::cell_gid_type gid) const override {
        return try_catch_pyexception(
            [this, gid]() { return impl_->cell_kind(gid); },
            msg);
    }

    arb::cell_size_type num_sources(arb::cell_gid_type gid) const override {
        return try_catch_pyexception(
            [this, gid]() { return impl_->num_sources(gid); },
            msg);
    }

    arb::cell_size_type num_targets(arb::cell_gid_type gid) const override {
        return try_catch_pyexception(
            [this, gid]() { return impl_->num_targets(gid); },
            msg);
    }

    arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const override {
        return try_catch_pyexception(
            [this, gid]() { return impl_->num_gap_junction_sites(gid); },
            msg);
    }

    // Defined out of line.
    std::vector<arb::event_generator> event_generators(arb::cell_gid_type gid) const override;

    std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override {
        return try_catch_pyexception(
            [this, gid]() { return impl_->connections_on(gid); },
            msg);
    }

    std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type gid) const override {
        return try_catch_pyexception(
            [this, gid]() { return impl_->gap_junctions_on(gid); },
            msg);
    }

    std::vector<arb::probe_info> get_probes(arb::cell_gid_type gid) const override {
        return try_catch_pyexception(
            [this, gid]() { return impl_->get_probes(gid); },
            msg);
    }

    // Global properties are produced here and are not forwarded to impl_.
    // Cable cells get arb::neuron_parameter_defaults as their default
    // parameters; every other cell kind gets an empty any.
    // TODO: make thread safe
    arb::util::any get_global_properties(arb::cell_kind kind) const override {
        if (kind!=arb::cell_kind::cable) {
            return arb::util::any{};
        }

        arb::cable_cell_global_properties gprop;
        gprop.default_parameters = arb::neuron_parameter_defaults;
        return gprop;
    }
};

} // namespace pyarb