diff --git a/arbor/arbexcept.cpp b/arbor/arbexcept.cpp index f03cfae162257379d7b0b3eedeb456d768b30c50..46025c6c776373ce92bf70d8b241e0cfb016da1c 100644 --- a/arbor/arbexcept.cpp +++ b/arbor/arbexcept.cpp @@ -21,7 +21,7 @@ bad_target_description::bad_target_description(cell_gid_type gid, cell_size_type {} bad_source_description::bad_source_description(cell_gid_type gid, cell_size_type rec_val, cell_size_type cell_val): - arbor_exception(pprintf("Model building error on cell {}: recipe::num_sources(gid={}) = {} is greater than the number of detectors on the cell = {}", gid, gid, rec_val, cell_val)), + arbor_exception(pprintf("Model building error on cell {}: recipe::num_sources(gid={}) = {} is not equal to the number of detectors on the cell = {}", gid, gid, rec_val, cell_val)), gid(gid), rec_val(rec_val), cell_val(cell_val) {} diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index 701576cfd34a2628528c5b7dbe5f26e8a29b2ba8..b5a46f192981d026be20b2b2afdba4015009fda6 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -536,7 +536,7 @@ void fvm_lowered_cell_impl<Backend>::initialize( // Sanity check recipe auto& cell = cells[cell_idx]; - if (rec.num_sources(gid) > cell.detectors().size()) { + if (rec.num_sources(gid) != cell.detectors().size()) { throw arb::bad_source_description(gid, rec.num_sources(gid), cell.detectors().size());; } auto cell_targets = util::sum_by(cell.synapses(), [](auto& syn) {return syn.second.size();}); diff --git a/arbor/include/arbor/cable_cell.hpp b/arbor/include/arbor/cable_cell.hpp index 748f4c58c9094332e40a475a7dcd014f58851666..da2dc908d6da8cc2d36efccff763d5cf27f945cc 100644 --- a/arbor/include/arbor/cable_cell.hpp +++ b/arbor/include/arbor/cable_cell.hpp @@ -45,7 +45,7 @@ using cable_sample_range = std::pair<const double*, const double*>; // Sampler functions receive an `any_ptr` to sampled data. The underlying pointer // type is a const pointer to: // * `double` for scalar data; -// * `cable_sample_rang*` for vector data (see definition above). +// * `cable_sample_range*` for vector data (see definition above). // // The metadata associated with a probe is also passed to a sampler via an `any_ptr`; // the underlying pointer will be a const pointer to one of the following metadata types: diff --git a/arbor/include/arbor/sampling.hpp b/arbor/include/arbor/sampling.hpp index fa34494f0f308c579165075b9319cc4c52552c0e..ff57ae883b9ddb8e2c01651bbbc7740a59db5706 100644 --- a/arbor/include/arbor/sampling.hpp +++ b/arbor/include/arbor/sampling.hpp @@ -16,6 +16,12 @@ inline cell_member_predicate one_probe(cell_member_type pid) { return [pid](cell_member_type x) { return pid==x; }; } +// Probe-specific metadata is provided by cell group implementations. +// +// User code is responsible for correctly determining the metadata type, +// but the value of that metadata must be sufficient to determine the +// correct interpretation of sample data provided to sampler callbacks. + struct probe_metadata { cell_member_type id; // probe id probe_tag tag; // probe tag associated with id diff --git a/doc/cpp/cable_cell.rst b/doc/cpp/cable_cell.rst index 6df7f50451b14b19ab16a9a821fbd416ae08b3d7..b6e1bbedc15d68f8cf5f5292f58d4be12612d9ef 100644 --- a/doc/cpp/cable_cell.rst +++ b/doc/cpp/cable_cell.rst @@ -457,6 +457,8 @@ Overriding properties locally the morphology. +.. _cable-cell-probes: + Cable cell probes ----------------- diff --git a/doc/python/cable_cell.rst b/doc/python/cable_cell.rst index baaeed5c3ff3312dc79c69921865832616de3c7b..922854eada6283b0b1ed8ededbcb3bc5cd21b573 100644 --- a/doc/python/cable_cell.rst +++ b/doc/python/cable_cell.rst @@ -93,4 +93,154 @@ Cable cells properties of an ionic species. +.. _pycableprobes: + +Cable cell probes +----------------- + +Cable cell probe addresses are defined analagously to their counterparts in +the C++ API (see :ref:`cable-cell-probes` for details). Sample data recorded +by the Arbor simulation object is returned in the form of a NumPy array, +with the first column holding sample times, and subsequent columns holding +the corresponding scalar- or vector-valued sample. + +Location expressions will be realised as zero or more specific sites on +a cell; probe addresses defined over location expressions will describe zero, +one, or more probes, one per site. They are evaluated in the context of +the cell on which the probe is attached. + +Each of the functions described below generates an opaque :class:`probe` +object for use in the recipe :py:func:`get_probes` method. + +More information on probes, probe metadata, and sampling can be found +in the documentation for the class :class:`simulation`. + +Membrane voltage + .. py:function:: cable_probe_membrane_voltage(where) + + Cell membrane potential (mV) at the sites specified by the location + expression string ``where``. This value is spatially interpolated. + + Metadata: the explicit :class:`location` of the sample site. + + .. py:function:: cable_probe_membrane_voltage_cell() + + Cell membrane potential (mV) associated with each cable in each CV of + the cell discretization. + + Metadata: the list of corresponding :class:`cable` objects. + +Axial current + .. py:function:: cable_probe_axial_current(where) + + Estimation of intracellular current (nA) in the distal direction at the + sites specified by the location expression string ``where``. + + Metadata: the explicit :class:`location` of the sample site. + +Ionic current + .. py:function:: cable_probe_ion_current_density(where, ion) + + Transmembrane current density (A/m²) associated with the given ``ion`` at + sites specified by the location expression string ``where``. + + Metadata: the explicit :class:`location` of the sample site. + + .. py:function:: cable_probe_ion_current_cell(ion) + + Transmembrane current (nA) associated with the given ``ion`` across each + cable in each CV of the cell discretization. + + Metadata: the list of corresponding :class:`cable` objects. + +Total ionic current + .. py:function:: cable_probe_total_ion_current_density(where) + + Transmembrane current density (A/m²) _excluding_ capacitive currents at the + sites specified by the location expression string ``where``. + + Metadata: the explicit :class:`location` of the sample site. + + .. py:function:: cable_probe_total_ion_current_cell() + + Transmembrane current (nA) _excluding_ capacitive currents across each + cable in each CV of the cell discretization. + + Metadata: the list of corresponding :class:`cable` objects. + +Total transmembrane current + .. py:function:: cable_probe_total_current_cell() + + Transmembrane current (nA) _including_ capacitive currents across each + cable in each CV of the cell discretization. + + Metadata: the list of corresponding :class:`cable` objects. + +Density mechanism state variable + .. py:function:: cable_probe_density_state(where, mechanism, state) + + The value of the state variable ``state`` in the density mechanism ``mechanism`` + at the sites specified by the location expression ``where``. + + Metadata: the explicit :class:`location` of the sample site. + + .. py:function:: cable_probe_density_state_cell(mechanism, state) + + The value of the state variable ``state`` in the density mechanism ``mechanism`` + on each cable in each CV of the cell discretixation. + + Metadata: the list of corresponding :class:`cable` objects. + +Point process state variable + .. py:function:: cable_probe_point_state(target, mechanism, state) + + The value of the state variable ``state`` in the point process ``mechanism`` + associated with the target index ``target`` on the cell. If the given mechanism + is not associated with the target index, no probe will be generated. + + Metadata: an object of type :class:`cable_point_probe_info`, comprising three fields: + + * ``target``: target index on the cell; + + * ``multiplicity``: number of targets sharing the same state in the discretization; + + * ``location``: :class:`location` object corresponding to the target site. + + .. py:function:: cable_probe_point_state_cell(mechanism, state) + + The value of the state variable ``state`` in the point process ``mechanism`` + at each of the targets where that mechanism is defined. + + Metadata: a list of :class:`cable_point_probe_info` values, one for each matching + target. + +Ionic internal concentration + .. py:function:: cable_probe_ion_int_concentration(where, ion) + + Ionic internal concentration (mmol/L) of the given ``ion`` at the + sites specified by the location expression string ``where``. + + Metadata: the explicit :class:`location` of the sample site. + + .. py:function:: cable_probe_ion_int_concentration_cell(ion) + + Ionic internal concentration (mmol/L) of the given ``ion`` in each able in each + CV of the cell discretization. + + Metadata: the list of corresponding :class:`cable` objects. + +Ionic external concentration + .. py:function:: cable_probe_ion_ext_concentration(where, ion) + + Ionic external concentration (mmol/L) of the given ``ion`` at the + sites specified by the location expression string ``where``. + + Metadata: the explicit :class:`location` of the sample site. + + .. py:function:: cable_probe_ion_ext_concentration_cell(ion) + + Ionic external concentration (mmol/L) of the given ``ion`` in each able in each + CV of the cell discretization. + + Metadata: the list of corresponding :class:`cable` objects. diff --git a/doc/python/recipe.rst b/doc/python/recipe.rst index 7286c0c2a44940cf1f274a36aa21920ef42ef7db..eade1da2b38012fd24dd544997f9c5f691335bbe 100644 --- a/doc/python/recipe.rst +++ b/doc/python/recipe.rst @@ -92,7 +92,11 @@ Recipe .. function:: get_probes(gid) - Returns a list containing (in order) all the probes on a given cell `gid`. + Returns a list specifying the probe addresses describing probes on the cell ``gid``. + Each address in the list is an opaque object of type :class:`probe` produced by + cell kind-specific probe address functions. Each probe address in the list + has a corresponding probe id of type :class:`cell_member_type`: an id ``(gid, i)`` + refers to the probes described by the ith entry in the list returned by ``get_probes(gid)``. By default returns an empty list. @@ -106,29 +110,6 @@ Synapses See :ref:`pyinterconnectivity`. -Probes ------- - -.. class:: probe - - Describes the cell probe's information. - -.. function:: cable_probe(kind, id, location) - - Returns the description of a probe at an :class:`arbor.location` on a cable cell with :attr:`id` available for monitoring data of ``voltage`` or ``current`` :attr:`kind`. - - An example of a probe on a cable cell for measuring voltage at the soma reads as follows: - - .. container:: example-code - - .. code-block:: python - - import arbor - - id = arbor.cell_member(0, 0) # cell 0, probe 0 - loc = arbor.location(0, 0) # at the soma - probe = arbor.cable_probe('voltage', id, loc) - Event generator and schedules ----------------------------- @@ -268,9 +249,9 @@ helpers in cell_parameters and make_cable_cell for building cells are used. def num_cells(self): return self.ncells - # The cell_description method returns a cell + # The cell_description method returns a cell. def cell_description(self, gid): - return arbor.make_cable_cell(gid, self.params) + return make_cable_cell(gid, self.params) def num_targets(self, gid): return 1 @@ -298,5 +279,5 @@ helpers in cell_parameters and make_cable_cell for building cells are used. return [] def get_probes(self, id): - loc = arbor.location(0, 0) # at the soma - return [arbor.cable_probe('voltage', loc)] + # Probe just the membrane voltage at a location on the soma. + return [arbor.cable_probe_membrane_voltage('(location 0 0)')] diff --git a/doc/python/simulation.rst b/doc/python/simulation.rst index 34167ae59afbbcf300ea339002248061a78a3889..23e49b1e4ecf1a7b744ea65d7dfd12bbe647ff20 100644 --- a/doc/python/simulation.rst +++ b/doc/python/simulation.rst @@ -56,8 +56,10 @@ over the local and distributed hardware resources (see :ref:`pydomdec`). Then, t Simulations provide an interface for executing and interacting with the model: - * **Advance the model state** from one time to another and reset the model state to its original state before simulation was started. - * Sample the simulation state during the execution (e.g. compartment voltage and current) and generate spike output by using an **I/O interface**. + * Specify what data (spikes, probe results) to record. + * **Advance the model state** by running the simulation up to some time point. + * Retrieve recorded data. + * Reset simulator state back to initial conditions. **Constructor:** @@ -70,6 +72,7 @@ over the local and distributed hardware resources (see :ref:`pydomdec`). Then, t .. function:: reset() Reset the state of the simulation to its initial state. + Clears recorded spikes and sample data. .. function:: run(tfinal, dt) @@ -88,6 +91,57 @@ over the local and distributed hardware resources (see :ref:`pydomdec`). Then, t :param bin_interval: The binning time interval [ms]. + **Recording spike data:** + + .. function:: record(policy) + + Disable or enable recorder of rank-local or global spikes, as determined by the ``policy``. + + :param policy: Recording policy of type :class:`spike_recording`. + + .. function:: spikes() + + Return a NumPy structured array of spikes recorded during the course of a simulation. + Each spike is represented as a NumPy structured datatype with signature + ``('source', [('gid', '<u4'), ('index', '<u4')]), ('time', '<f8')``. + + **Sampling probes:** + + .. function:: sample(probe_id, schedule, policy) + + Set up a sampling schedule for the probes associated with the supplied probe_id of type :class:`cell_member`. + The schedule is any schedule object, as might be used with an event generator — see :ref:`pyrecipe` for details. + The policy is of type :class:`sampling_policy`. It can be omitted, in which case the sampling will accord with the + ``sampling_policy.lax`` policy. + + The method returns a handle which can be used in turn to retrieve the sampled data from the simulator or to + remove the corresponding sampling process. + + .. function:: probe_metadata(probe_id) + + Retrieve probe metadata for the probes associated with the given probe_id of type :class:`cell_member`. + The result will be a list, with one entry per probe; the specifics of each metadata entry will depend upon + the kind of probe in question. + + .. function:: remove_sampler(handle) + + Disable the sampling process referenced by the argument ``handle`` and remove any associated recorded data. + + .. function:: remove_all_samplers() + + Disable all sampling processes and remove any associated recorded data. + + .. function:: samples(handle) + + Retrieve a list of sample data associated with the given ``handle``. + There will be one entry in the list per probe associated with the probe id used when the sampling was set up. + Each entry is a pair ``(samples, meta)`` where ``meta`` is the probe metadata as would be returned by + ``probe_metadata(probe_id)``, and ``samples`` contains the recorded values. + + The format of the recorded values will depend upon the specifics of the probe, though generally it will + be a NumPy array, with the first column corresponding to sample time and subsequent columns holding + the value or values that were sampled from that probe at that time. + **Types:** .. class:: binning @@ -106,43 +160,48 @@ over the local and distributed hardware resources (see :ref:`pydomdec`). Then, t Round times down to previous event if within binning interval. -Recording spikes ----------------- -In order to analyze the simulation output spikes can be recorded. + .. class:: spike_recording + + Enumeration for spike recording policy. -**Types**: + .. attribute:: off -.. class:: spike + Disable spike recording. - .. function:: spike() + .. attribute:: local - Construct a spike. + Record all generated spikes from cells on this MPI rank. - .. attribute:: source + .. attribute:: all - The spike source (type: :class:`arbor.cell_member`). + Record all generated spikes from cells on all MPI ranks. - .. attribute:: time + .. class:: sampling_policy - The spike time [ms]. + Enumeration for deteriming sampling policy. -.. class:: spike_recorder + .. attribute:: lax - .. function:: spike_recorder() + Sampling times may not be exactly as requested in the sampling schedule, but + the process of sampling is guaranteed not to disturb the simulation progress or results. - Initialize the spike recorder. + .. attribute:: exact - .. attribute:: spikes + Interrupt the progress of the simulation as required to retrieve probe samples at exactly + those times requested by the sampling schedule. - The recorded spikes (type: :class:`spike`). +Recording spikes +---------------- -**I/O interface**: +By default, spikes are not recorded. Recording is enabled with the +:py:func:`simulation.record` method, which takes a single argument instructing +the simulation object to record no spikes, all locally generated spikes, or all +spikes generated by any MPI rank. -.. function:: attach_spike_recorder(sim) +Spikes recorded during a simulation are returned as a NumPy structured datatype with two fields, +``source`` and ``time``. The ``source`` field itself is a structured datatype with two fields, +``gid`` and ``index``, identifying the spike detector that generated the spike. - Attach a spike recorder to an arbor :class:`simulation` ``sim``. - The recorder that is returned will record all spikes generated after it has been - attached (spikes generated before attaching are not recorded). .. container:: example-code @@ -153,8 +212,8 @@ In order to analyze the simulation output spikes can be recorded. # Instatitate the simulation. sim = arbor.simulation(recipe, decomp, context) - # Build the spike recorder - recorder = arbor.attach_spike_recorder(sim) + # Direct the simulation to record all spikes. + sim.record(arbor.spike_recording.all) # Run the simulation for 2000 ms with time stepping of 0.025 ms tSim = 2000 @@ -162,19 +221,19 @@ In order to analyze the simulation output spikes can be recorded. sim.run(tSim, dt) # Print the spikes and according spike time - for s in recorder.spikes: + for s in sim.spikes(): print(s) ->>> <arbor.spike: source (0,0), time 2.15168 ms> ->>> <arbor.spike: source (1,0), time 14.5235 ms> ->>> <arbor.spike: source (2,0), time 26.9051 ms> ->>> <arbor.spike: source (3,0), time 39.4083 ms> ->>> <arbor.spike: source (4,0), time 51.9081 ms> ->>> <arbor.spike: source (5,0), time 64.2902 ms> ->>> <arbor.spike: source (6,0), time 76.7706 ms> ->>> <arbor.spike: source (7,0), time 89.1529 ms> ->>> <arbor.spike: source (8,0), time 101.641 ms> ->>> <arbor.spike: source (9,0), time 114.125 ms> +>>> ((0,0), 2.15168) +>>> ((1,0), 14.5235) +>>> ((2,0), 26.9051) +>>> ((3,0), 39.4083) +>>> ((4,0), 51.9081) +>>> ((5,0), 64.2902) +>>> ((6,0), 76.7706) +>>> ((7,0), 89.1529) +>>> ((8,0), 101.641) +>>> ((9,0), 114.125) Recording samples ----------------- @@ -183,53 +242,70 @@ Definitions *********** probe - A location or component of a cell that is available for monitoring (see :attr:`arbor.recipe.num_probes`, - :attr:`arbor.recipe.get_probes` and :attr:`arbor.cable_probe` as references). - -sample/record + A measurement that can be perfomed on a cell. Each cell kind will have its own sorts of probe; + Cable cells (:attr:`arbor.cable_probe`) allow the monitoring of membrane voltage, total membrane + current, mechanism state, and a number of other quantities, measured either over the whole cell, + or at specific sites (see :ref:`pycableprobes`). + + Probes are described by probe addresses, and the collection of probe addresses for a given cell is + provided by the :class:`recipe` object. One address may correspond to more than one probe: + as an example, a request for membrane voltage on a cable cell at sites specified by a location + expression will generate one probe for each site in that location expression. + +probe id + A designator for one or more probes as specified by a recipe. The *probe id* is a + :class:`cell_member` refering to a specific cell by gid, and the index into the list of + probe addresses returned by the recipe for that gid. + +metadata + Each probe has associated metadata describing, for example, the location on a cell where the + measurement is being taken, or other such identifying information. Metadata for the probes + associated with a *probe id* can be retrieved from the simulation object, and is also provided + along with any recorded samples. + +sample A record of data corresponding to the value at a specific *probe* at a specific time. -sampler/sample recorder - A function that receives a sequence of *sample* records. - -Samples and sample recorders -**************************** - -In order to analyze the data collected from an :class:`arbor.probe` the samples can be recorded. - -**Types**: +schedule + An object representing a series of monotonically increasing points in time, used for determining + sample times (see :ref:`pyrecipe`). -.. class:: trace_sample +Procedure +********* - .. attribute:: time + There are three parts to the process of recording cell data over a simulation. - The sample time [ms] at a specific probe. + 1. Describing what to measure. - .. attribute:: value + The recipe object must provide a method :py:func:`recipe.get_probes` that returns a list of + probe addresses for the cell with a given ``gid``. The kth element of the list corresponds + to the *probe id* ``(gid, k)``. - The sample record at a specific probe. + Each probe address is an opaue object describing what to measure and where, and each cell kind + will have its own set of functions for generating valid address specifications. Possible cable + cell probes are described in the cable cell documentation: :ref:`pycableprobes`. -.. class:: sampler + 2. Instructing the simulator to record data. - .. function:: sampler() + Recording is set up with the method :py:func:`simulation.sample` + as described above. It returns a handle that is used to retrieve the recorded data after + simulation. - Initialize the sample recorder. + 3. Retrieve recorded data. - .. function:: samples(probe_id) + The method :py:func:`simulation.samples` takes a handle and returns the recorded data as a list, + with one entry for each probe associated with the *probe id* that was used in step 2 above. Each + entry will be a tuple ``(data, meta)`` where ``meta`` is the metadata associated with the + probe, and ``data`` contains all the data sampled on that probe over the course of the + simulation. - A list of the recorded samples of a probe with probe id. + The contents of ``data`` will depend upon the specifics of the probe, but note: -**Sampling interface**: + i. The object type and structure of ``data`` is fully determined by the metadata. -.. function:: attach_sampler(sim, dt) - - Attach a sample recorder to an arbor simulation. - The recorder will record all samples from a regular sampling interval [ms] (see :class:`arbor.regular_schedule`) matching all probe ids. - -.. function:: attach_sampler(sim, dt, probe_id) - - Attach a sample recorder to an arbor simulation. - The recorder will record all samples from a regular sampling interval [ms] (see :class:`arbor.regular_schedule`) matching one probe id. + ii. All currently implemented probes return data that is a NumPy array, with one + row per sample, first column being sample time, and the remaining columns containing + the corresponding data. .. container:: example-code @@ -237,42 +313,49 @@ In order to analyze the data collected from an :class:`arbor.probe` the samples import arbor - # Instantiate the simulation. + # [... define recipe, decomposition, context ... ] + # Initialize simulation: + sim = arbor.simulation(recipe, decomp, context) - # Build the sample recorder on cell 0 and probe 0 with regular sampling interval of 0.1 ms - pid = arbor.cell_member(0,0) # cell 0, probe 0 - sampler = arbor.attach_sampler(sim, 0.1, pid) - - # Run the simulation for 100 ms - sim.run(100) - - # Print the sample times and values - for sa in sampler.samples(pid): - print(sa) - ->>> <arbor.sample: time 0 ms, value -65> ->>> <arbor.sample: time 0.1 ms, value -64.9981> ->>> <arbor.sample: time 0.2 ms, value -64.9967> ->>> <arbor.sample: time 0.3 ms, value -64.9956> ->>> <arbor.sample: time 0.4 ms, value -64.9947> ->>> <arbor.sample: time 0.475 ms, value -64.9941> ->>> <arbor.sample: time 0.6 ms, value -64.9932> ->>> <arbor.sample: time 0.675 ms, value -64.9927> ->>> <arbor.sample: time 0.8 ms, value -64.992> ->>> <arbor.sample: time 0.9 ms, value -64.9916> ->>> <arbor.sample: time 1 ms, value -64.9912> ->>> <arbor.sample: time 1.1 ms, value -62.936> ->>> <arbor.sample: time 1.2 ms, value -59.2284> ->>> <arbor.sample: time 1.3 ms, value -55.8485> ->>> <arbor.sample: time 1.375 ms, value -53.663> ->>> <arbor.sample: time 1.475 ms, value -51.0649> ->>> <arbor.sample: time 1.6 ms, value -47.9543> ->>> <arbor.sample: time 1.7 ms, value -45.1928> ->>> <arbor.sample: time 1.8 ms, value -41.7243> ->>> <arbor.sample: time 1.875 ms, value -38.2573> ->>> <arbor.sample: time 1.975 ms, value -31.576> ->>> <arbor.sample: time 2.1 ms, value -17.2756> ->>> <arbor.sample: time 2.2 ms, value 0.651031> ->>> <arbor.sample: time 2.275 ms, value 15.0592> + # Sample probe id (0, 0) (first probe id on cell 0) every 0.1 ms with exact sample timing: + + handle = sim.sample((0, 0), arbor.regular_schedule(0.1), arbor.sampling_policy.exact) + + # Run simulation and retrieve sample data from the first probe associated with the handle. + + sim.run(tfinal=3, dt=0.1) + data, meta = sim.samples(handle)[0] + print(data) + +>>> [[ 0. -50. ] +>>> [ 0.1 -55.14412111] +>>> [ 0.2 -59.17057625] +>>> [ 0.3 -62.58417912] +>>> [ 0.4 -65.47040168] +>>> [ 0.5 -67.80222861] +>>> [ 0.6 -15.18191623] +>>> [ 0.7 27.21110919] +>>> [ 0.8 48.74665099] +>>> [ 0.9 48.3515727 ] +>>> [ 1. 41.08435987] +>>> [ 1.1 33.53571111] +>>> [ 1.2 26.55165892] +>>> [ 1.3 20.16421752] +>>> [ 1.4 14.37227532] +>>> [ 1.5 9.16209063] +>>> [ 1.6 4.50159342] +>>> [ 1.7 0.34809083] +>>> [ 1.8 -3.3436289 ] +>>> [ 1.9 -6.61665687] +>>> [ 2. -9.51020525] +>>> [ 2.1 -12.05947812] +>>> [ 2.2 -14.29623969] +>>> [ 2.3 -16.24953688] +>>> [ 2.4 -17.94631322] +>>> [ 2.5 -19.41182385] +>>> [ 2.6 -52.19519009] +>>> [ 2.7 -62.53349949] +>>> [ 2.8 -69.22068995] +>>> [ 2.9 -73.41691825]] diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index de05eeb06c4bdf9c764692a1493d21d1a3a1113c..bd487a9a4c089d91a69cf673c728c59864642b36 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -15,6 +15,7 @@ set(PYBIND11_CPP_STANDARD -std=c++17) add_subdirectory(pybind11) set(pyarb_source + cable_probes.cpp cells.cpp config.cpp context.cpp @@ -29,7 +30,6 @@ set(pyarb_source profiler.cpp pyarb.cpp recipe.cpp - sampling.cpp schedule.cpp simulation.cpp single_cell_model.cpp diff --git a/python/cable_probes.cpp b/python/cable_probes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb7e1be401523bb6dc95a08a2e5c7188a5aa0683 --- /dev/null +++ b/python/cable_probes.cpp @@ -0,0 +1,294 @@ +#include <string> + +#include <pybind11/pybind11.h> +#include <pybind11/numpy.h> +#include <pybind11/pytypes.h> +#include <pybind11/stl.h> + +#include <arbor/cable_cell.hpp> +#include <arbor/morph/locset.hpp> +#include <arbor/recipe.hpp> +#include <arbor/sampling.hpp> +#include <arbor/util/any_ptr.hpp> + +#include "pyarb.hpp" +#include "strprintf.hpp" + +using arb::util::any_cast; +using arb::util::any_ptr; +namespace py = pybind11; + +namespace pyarb { + +// Generic recorder classes for array-output sample data, corresponding +// to cable_cell scalar- and vector-valued probes. + +template <typename Meta> +struct recorder_cable_base: sample_recorder { + // Return stride-column array: first column is time, remainder correspond to sample. + + py::object samples() const override { + auto n_record = std::ptrdiff_t(sample_raw_.size()/stride_); + return py::array_t<double>( + std::vector<std::ptrdiff_t>{n_record, stride_}, + sample_raw_.data()); + } + + py::object meta() const override { + return py::cast(meta_); + } + + void reset() override { + sample_raw_.clear(); + } + +protected: + Meta meta_; + std::vector<double> sample_raw_; + std::ptrdiff_t stride_; + + recorder_cable_base(const Meta* meta_ptr, std::ptrdiff_t width): + meta_(*meta_ptr), stride_(1+width) + {} +}; + +template <typename Meta> +struct recorder_cable_scalar: recorder_cable_base<Meta> { + using recorder_cable_base<Meta>::sample_raw_; + + void record(any_ptr, std::size_t n_sample, const arb::sample_record* records) override { + for (std::size_t i = 0; i<n_sample; ++i) { + if (auto* v_ptr =any_cast<const double*>(records[i].data)) { + sample_raw_.push_back(records[i].time); + sample_raw_.push_back(*v_ptr); + } + else { + throw arb::arbor_internal_error("unexpected sample type"); + } + } + } + +protected: + recorder_cable_scalar(const Meta* meta_ptr): recorder_cable_base<Meta>(meta_ptr, 1) {} +}; + +template <typename Meta> +struct recorder_cable_vector: recorder_cable_base<Meta> { + using recorder_cable_base<Meta>::sample_raw_; + + void record(any_ptr, std::size_t n_sample, const arb::sample_record* records) override { + for (std::size_t i = 0; i<n_sample; ++i) { + if (auto* v_ptr = any_cast<const arb::cable_sample_range*>(records[i].data)) { + sample_raw_.push_back(records[i].time); + sample_raw_.insert(sample_raw_.end(), v_ptr->first, v_ptr->second); + } + else { + throw arb::arbor_internal_error("unexpected sample type"); + } + } + } + +protected: + recorder_cable_vector(const Meta* meta_ptr, std::ptrdiff_t width): + recorder_cable_base<Meta>(meta_ptr, width) {} +}; + +// Specific recorder classes: + +struct recorder_cable_scalar_mlocation: recorder_cable_scalar<arb::mlocation> { + explicit recorder_cable_scalar_mlocation(const arb::mlocation* meta_ptr): + recorder_cable_scalar(meta_ptr) {} +}; + +struct recorder_cable_scalar_point_info: recorder_cable_scalar<arb::cable_probe_point_info> { + explicit recorder_cable_scalar_point_info(const arb::cable_probe_point_info* meta_ptr): + recorder_cable_scalar(meta_ptr) {} +}; + +struct recorder_cable_vector_mcable: recorder_cable_vector<arb::mcable_list> { + explicit recorder_cable_vector_mcable(const arb::mcable_list* meta_ptr): + recorder_cable_vector(meta_ptr, std::ptrdiff_t(meta_ptr->size())) {} +}; + +struct recorder_cable_vector_point_info: recorder_cable_vector<std::vector<arb::cable_probe_point_info>> { + explicit recorder_cable_vector_point_info(const std::vector<arb::cable_probe_point_info>* meta_ptr): + recorder_cable_vector(meta_ptr, std::ptrdiff_t(meta_ptr->size())) {} +}; + +// Helper for registering sample recorder factories and (trivial) metadata conversions. + +template <typename Meta, typename Recorder> +void register_probe_meta_maps(pyarb_global_ptr g) { + g->recorder_factories.assign<Meta>( + [](any_ptr meta_ptr) -> std::unique_ptr<sample_recorder> { + return std::unique_ptr<Recorder>(new Recorder(any_cast<const Meta*>(meta_ptr))); + }); + + g->probe_meta_converters.assign<Meta>( + [](any_ptr meta_ptr) -> py::object { + return py::cast(*any_cast<const Meta*>(meta_ptr)); + }); +} + +// Wrapper functions around cable_cell probe types that return arb::probe_info values: +// (Probe tag value is implicitly left at zero.) + +arb::probe_info cable_probe_membrane_voltage(const char* where) { + return arb::cable_probe_membrane_voltage{arb::locset(where)}; +} + +arb::probe_info cable_probe_membrane_voltage_cell() { + return arb::cable_probe_membrane_voltage_cell{}; +} + +arb::probe_info cable_probe_axial_current(const char* where) { + return arb::cable_probe_axial_current{arb::locset(where)}; +} + +arb::probe_info cable_probe_total_ion_current_density(const char* where) { + return arb::cable_probe_total_ion_current_density{arb::locset(where)}; +} + +arb::probe_info cable_probe_total_ion_current_cell() { + return arb::cable_probe_total_ion_current_cell{}; +} + +arb::probe_info cable_probe_total_current_cell() { + return arb::cable_probe_total_current_cell{}; +} + +arb::probe_info cable_probe_density_state(const char* where, const char* mechanism, const char* state) { + return arb::cable_probe_density_state{arb::locset(where), mechanism, state}; +}; + +arb::probe_info cable_probe_density_state_cell(const char* mechanism, const char* state) { + return arb::cable_probe_density_state_cell{mechanism, state}; +}; + +arb::probe_info cable_probe_point_state(arb::cell_lid_type target, const char* mechanism, const char* state) { + return arb::cable_probe_point_state{target, mechanism, state}; +} + +arb::probe_info cable_probe_point_state_cell(const char* mechanism, const char* state_var) { + return arb::cable_probe_point_state_cell{mechanism, state_var}; +} + +arb::probe_info cable_probe_ion_current_density(const char* where, const char* ion) { + return arb::cable_probe_ion_current_density{arb::locset(where), ion}; +} + +arb::probe_info cable_probe_ion_current_cell(const char* ion) { + return arb::cable_probe_ion_current_cell{ion}; +} + +arb::probe_info cable_probe_ion_int_concentration(const char* where, const char* ion) { + return arb::cable_probe_ion_int_concentration{arb::locset(where), ion}; +} + +arb::probe_info cable_probe_ion_int_concentration_cell(const char* ion) { + return arb::cable_probe_ion_int_concentration_cell{ion}; +} + +arb::probe_info cable_probe_ion_ext_concentration(const char* where, const char* ion) { + return arb::cable_probe_ion_ext_concentration{arb::locset(where), ion}; +} + +arb::probe_info cable_probe_ion_ext_concentration_cell(const char* ion) { + return arb::cable_probe_ion_ext_concentration_cell{ion}; +} + +// Add wrappers to module, recorder factories to global data. + +void register_cable_probes(pybind11::module& m, pyarb_global_ptr global_ptr) { + using namespace pybind11::literals; + using util::pprintf; + + // Probe metadata wrappers: + + py::class_<arb::cable_probe_point_info> cable_probe_point_info(m, "cable_probe_point_info", + "Probe metadata associated with a cable cell probe for point process state."); + + cable_probe_point_info + .def_readwrite("target", &arb::cable_probe_point_info::target, + "The target index of the point process instance on the cell.") + .def_readwrite("multiplicity", &arb::cable_probe_point_info::multiplicity, + "Number of coalesced point processes (linear synapses) associated with this instance.") + .def_readwrite("location", &arb::cable_probe_point_info::loc, + "Location of point process instance on cell.") + .def("__str__", [](arb::cable_probe_point_info m) { + return pprintf("<arbor.cable_probe_point_info: target {}, multiplicity {}, location {}>", m.target, m.multiplicity, m.loc);}) + .def("__repr__",[](arb::cable_probe_point_info m) { + return pprintf("<arbor.cable_probe_point_info: target {}, multiplicity {}, location {}>", m.target, m.multiplicity, m.loc);}); + + // Probe address constructors: + + m.def("cable_probe_membrane_voltage", &cable_probe_membrane_voltage, + "Probe specification for cable cell membrane voltage interpolated at points in a location set.", + "where"_a); + + m.def("cable_probe_membrane_voltage_cell", &cable_probe_membrane_voltage_cell, + "Probe specification for cable cell membrane voltage associated with each cable in each CV."); + + m.def("cable_probe_axial_current", &cable_probe_axial_current, + "Probe specification for cable cell axial current at points in a location set.", + "where"_a); + + m.def("cable_probe_total_ion_current_density", &cable_probe_total_ion_current_density, + "Probe specification for cable cell total transmembrane current density excluding capacitive currents at points in a location set.", + "where"_a); + + m.def("cable_probe_total_ion_current_cell", &cable_probe_total_ion_current_cell, + "Probe specification for cable cell total transmembrane current excluding capacitive currents for each cable in each CV."); + + m.def("cable_probe_total_current_cell", &cable_probe_total_current_cell, + "Probe specification for cable cell total transmembrane current for each cable in each CV."); + + m.def("cable_probe_density_state", &cable_probe_density_state, + "Probe specification for a cable cell density mechanism state variable at points in a location set.", + "where"_a, "mechanism"_a, "state"_a); + + m.def("cable_probe_density_state_cell", &cable_probe_density_state_cell, + "Probe specification for a cable cell density mechanism state variable on each cable in each CV.", + "mechanism"_a, "state"_a); + + m.def("cable_probe_point_state", &cable_probe_point_state, + "Probe specification for a cable cell point mechanism state variable value at a given target index.", + "target"_a, "mechanism"_a, "state"_a); + + m.def("cable_probe_point_state_cell", &cable_probe_point_state_cell, + "Probe specification for a cable cell point mechanism state variable value at every corresponding target.", + "mechanism"_a, "state"_a); + + m.def("cable_probe_ion_current_density", &cable_probe_ion_current_density, + "Probe specification for cable cell ionic current density at points in a location set.", + "where"_a, "ion"_a); + + m.def("cable_probe_ion_current_cell", &cable_probe_ion_current_cell, + "Probe specification for cable cell ionic current across each cable in each CV.", + "ion"_a); + + m.def("cable_probe_ion_int_concentration", &cable_probe_ion_int_concentration, + "Probe specification for cable cell internal ionic concentration at points in a location set.", + "where"_a, "ion"_a); + + m.def("cable_probe_ion_int_concentration_cell", &cable_probe_ion_int_concentration_cell, + "Probe specification for cable cell internal ionic concentration for each cable in each CV.", + "ion"_a); + + m.def("cable_probe_ion_ext_concentration", &cable_probe_ion_ext_concentration, + "Probe specification for cable cell external ionic concentration at points in a location set.", + "where"_a, "ion"_a); + + m.def("cable_probe_ion_ext_concentration_cell", &cable_probe_ion_ext_concentration_cell, + "Probe specification for cable cell external ionic concentration for each cable in each CV.", + "ion"_a); + + // Add probe metadata to maps for converters and recorders. + + register_probe_meta_maps<arb::mlocation, recorder_cable_scalar_mlocation>(global_ptr); + register_probe_meta_maps<arb::cable_probe_point_info, recorder_cable_scalar_point_info>(global_ptr); + register_probe_meta_maps<arb::mcable_list, recorder_cable_vector_mcable>(global_ptr); + register_probe_meta_maps<std::vector<arb::cable_probe_point_info>, recorder_cable_vector_point_info>(global_ptr); +} + +} // namespace pyarb diff --git a/python/cells.cpp b/python/cells.cpp index 2e6aed27f2915f35ebc3b8e0764430c169f53672..62ac4011019ef571b8a97f21bdad356e37271cf7 100644 --- a/python/cells.cpp +++ b/python/cells.cpp @@ -268,7 +268,7 @@ void register_cells(pybind11::module& m) { // arb::lif_cell pybind11::class_<arb::lif_cell> lif_cell(m, "lif_cell", - "A benchmarking cell, used by Arbor developers to test communication performance."); + "A leaky integrate-and-fire cell."); lif_cell .def(pybind11::init<>()) diff --git a/python/event_generator.cpp b/python/event_generator.cpp index 957e75e2d56e8ceea5d73889e5358035f025953a..f89df17c1f2c9700a4c4a01d59f1b28ab019ad42 100644 --- a/python/event_generator.cpp +++ b/python/event_generator.cpp @@ -10,15 +10,6 @@ namespace pyarb { -template <typename Sched> -event_generator_shim make_event_generator( - arb::cell_member_type target, - double weight, - const Sched& sched) -{ - return event_generator_shim(target, weight, sched.schedule()); -} - void register_event_generators(pybind11::module& m) { using namespace pybind11::literals; @@ -26,29 +17,13 @@ void register_event_generators(pybind11::module& m) { event_generator .def(pybind11::init<>( - [](arb::cell_member_type target, double weight, const regular_schedule_shim& sched){ - return make_event_generator(target, weight, sched);}), - "target"_a, "weight"_a, "sched"_a, - "Construct an event generator with arguments:\n" - " target: The target synapse (gid, local_id).\n" - " weight: The weight of events to deliver.\n" - " sched: A regular schedule of the events.") - .def(pybind11::init<>( - [](arb::cell_member_type target, double weight, const explicit_schedule_shim& sched){ - return make_event_generator(target, weight, sched);}), - "target"_a, "weight"_a, "sched"_a, - "Construct an event generator with arguments:\n" - " target: The target synapse (gid, local_id).\n" - " weight: The weight of events to deliver.\n" - " sched: An explicit schedule of the events.") - .def(pybind11::init<>( - [](arb::cell_member_type target, double weight, const poisson_schedule_shim& sched){ - return make_event_generator(target, weight, sched);}), + [](arb::cell_member_type target, double weight, const schedule_shim_base& sched) { + return event_generator_shim(target, weight, sched.schedule()); }), "target"_a, "weight"_a, "sched"_a, "Construct an event generator with arguments:\n" " target: The target synapse (gid, local_id).\n" " weight: The weight of events to deliver.\n" - " sched: A poisson schedule of the events.") + " sched: A schedule of the events.") .def_readwrite("target", &event_generator_shim::target, "The target synapse (gid, local_id).") .def_readwrite("weight", &event_generator_shim::weight, diff --git a/python/example/network_ring.py b/python/example/network_ring.py index 5dad8be70edca8753aab35b6c0c39da45ce00578..baffa01d79ea680811bb8421d5d5ed99df02b4c2 100755 --- a/python/example/network_ring.py +++ b/python/example/network_ring.py @@ -2,6 +2,7 @@ import arbor import pandas, seaborn +import numpy from math import sqrt # Construct a cell with the following morphology. @@ -95,8 +96,7 @@ class ring_recipe (arbor.recipe): return [] def get_probes(self, gid): - loc = arbor.location(0, 0) # at the soma - return [arbor.cable_probe('voltage', loc)] + return [arbor.cable_probe_membrane_voltage('(location 0 0)')] context = arbor.context(threads=12, gpu_id=None) print(context) @@ -122,14 +122,13 @@ print(f'{decomp}') meters.checkpoint('load-balance', context) sim = arbor.simulation(recipe, decomp, context) +sim.record(arbor.spike_recording.all) meters.checkpoint('simulation-init', context) -spike_recorder = arbor.attach_spike_recorder(sim) - # Attach a sampler to the voltage probe on cell 0. # Sample rate of 10 sample every ms. -samplers = [arbor.attach_sampler(sim, 0.1, arbor.cell_member(gid,0)) for gid in range(ncells)] +handles = [sim.sample((gid, 0), arbor.regular_schedule(0.1)) for gid in range(ncells)] tfinal=100 sim.run(tfinal) @@ -142,16 +141,15 @@ print(f'{arbor.meter_report(meters, context)}') # Print spike times print('spikes:') -for sp in spike_recorder.spikes: +for sp in sim.spikes(): print(' ', sp) # Plot the recorded voltages over time. print("Plotting results ...") df_list = [] for gid in range(ncells): - times = [s.time for s in samplers[gid].samples(arbor.cell_member(gid,0))] - volts = [s.value for s in samplers[gid].samples(arbor.cell_member(gid,0))] - df_list.append(pandas.DataFrame({'t/ms': times, 'U/mV': volts, 'Cell': f"cell {gid}"})) + samples, meta = sim.samples(handles[gid])[0] + df_list.append(pandas.DataFrame({'t/ms': samples[:, 0], 'U/mV': samples[:, 1], 'Cell': f"cell {gid}"})) df = pandas.concat(df_list) seaborn.relplot(data=df, kind="line", x="t/ms", y="U/mV",hue="Cell",ci=None).savefig('network_ring_result.svg') diff --git a/python/example/single_cell_multi_branch.py b/python/example/single_cell_multi_branch.py index c9c4cb60881c9e5a2b7237f78b3bb868e366b43b..4e8d929e01d3ed7435dbddb9e43a9c268c375008 100755 --- a/python/example/single_cell_multi_branch.py +++ b/python/example/single_cell_multi_branch.py @@ -107,6 +107,6 @@ else: print("Plotting results...") df = pandas.DataFrame() for t in m.traces: - df=df.append(pandas.DataFrame({'t/ms': t.time, 'U/mV': t.value, 'Location': t.location, "Variable": t.variable}) ) + df=df.append(pandas.DataFrame({'t/ms': t.time, 'U/mV': t.value, 'Location': str(t.location), "Variable": t.variable}) ) seaborn.relplot(data=df, kind="line", x="t/ms", y="U/mV",hue="Location",col="Variable",ci=None).savefig('single_cell_multi_branch_result.svg') diff --git a/python/example/single_cell_recipe.py b/python/example/single_cell_recipe.py new file mode 100644 index 0000000000000000000000000000000000000000..f11927f70323809d3661276b81c6ed3546164f94 --- /dev/null +++ b/python/example/single_cell_recipe.py @@ -0,0 +1,82 @@ +import arbor +import numpy, pandas, seaborn # You may have to pip install these. + +# The corresponding generic recipe version of `single_cell_model.py`. + +# (1) Define a recipe for a single cell and set of probes upon it. + +class single_recipe (arbor.recipe): + def __init__(self, cell, probes): + # The base C++ class constructor must be called first, to ensure that + # all memory in the C++ class is initialized correctly. + arbor.recipe.__init__(self) + self.the_cell = cell + self.the_probes = probes + + def num_cells(self): + return 1 + + def num_sources(self, gid): + return 1 + + def cell_kind(self, gid): + return arbor.cell_kind.cable + + def cell_description(self, gid): + return self.the_cell + + def get_probes(self, gid): + return self.the_probes + +# (2) Create a cell. + +tree = arbor.segment_tree() +tree.append(arbor.mnpos, arbor.mpoint(-3, 0, 0, 3), arbor.mpoint(3, 0, 0, 3), tag=1) + +labels = arbor.label_dict() +labels['centre'] = '(location 0 0.5)' + +cell = arbor.cable_cell(tree, labels) +cell.set_properties(Vm=-40) +cell.paint('(all)', 'hh') +cell.place('"centre"', arbor.iclamp( 10, 2, 0.8)) +cell.place('"centre"', arbor.spike_detector(-10)) + +# (3) Instantiate recipe with a voltage probe. + +recipe = single_recipe(cell, [arbor.cable_probe_membrane_voltage('"centre"')]) + +# (4) Instantiate simulation and set up sampling on probe id (0, 0). + +context = arbor.context() +domains = arbor.partition_load_balance(recipe, context) +sim = arbor.simulation(recipe, domains, context) + +sim.record(arbor.spike_recording.all) +handle = sim.sample((0, 0), arbor.regular_schedule(0.1)) + +# (6) Run simulation for 30 ms of simulated activity and collect results. + +sim.run(tfinal=30) +spikes = sim.spikes() +data, meta = sim.samples(handle)[0] + +# (7) Print spike times, if any. + +if len(spikes)>0: + print('{} spikes:'.format(len(spikes))) + for t in spikes['time']: + print('{:3.3f}'.format(t)) +else: + print('no spikes') + +# (8) Plot the recorded voltages over time. + +print("Plotting results ...") +seaborn.set_theme() # Apply some styling to the plot +df = pandas.DataFrame({'t/ms': data[:, 0], 'U/mV': data[:, 1]}) +seaborn.relplot(data=df, kind="line", x="t/ms", y="U/mV", ci=None).savefig('single_cell_recipe_result.svg') + +# (9) Optionally, you can store your results for later processing. + +df.to_csv('single_cell_recipe_result.csv', float_format='%g') diff --git a/python/example/single_cell_swc.py b/python/example/single_cell_swc.py index 05cd7aa8b466a0e4ccb43a31acc690d9d55b34c7..6fa4b967f1805b9d538595b4cfbfd934fbfd3341 100755 --- a/python/example/single_cell_swc.py +++ b/python/example/single_cell_swc.py @@ -85,7 +85,7 @@ else: print("Plotting results ...") df_list = [] for t in m.traces: - df_list.append(pandas.DataFrame({'t/ms': t.time, 'U/mV': t.value, 'Location': t.location, "Variable": t.variable})) + df_list.append(pandas.DataFrame({'t/ms': t.time, 'U/mV': t.value, 'Location': str(t.location), "Variable": t.variable})) df = pandas.concat(df_list) diff --git a/python/identifiers.cpp b/python/identifiers.cpp index 43b6b9c4cbbf081575e1da29a9fc5ba085d7a887..c64ed997b062a608fb6fbd165d8cd6b7c7e1917b 100644 --- a/python/identifiers.cpp +++ b/python/identifiers.cpp @@ -10,23 +10,31 @@ namespace pyarb { using util::pprintf; +namespace py = pybind11; -void register_identifiers(pybind11::module& m) { - using namespace pybind11::literals; +void register_identifiers(py::module& m) { + using namespace py::literals; - pybind11::class_<arb::cell_member_type> cell_member(m, "cell_member", + py::class_<arb::cell_member_type> cell_member(m, "cell_member", "For global identification of a cell-local item.\n\n" "Items of cell_member must:\n" " (1) be associated with a unique cell, identified by the member gid;\n" " (2) identify an item within a cell-local collection by the member index.\n"); cell_member - .def(pybind11::init( + .def(py::init( [](arb::cell_gid_type gid, arb::cell_lid_type idx) { return arb::cell_member_type{gid, idx}; }), "gid"_a, "index"_a, - "Construct a cell member with arguments:\n" + "Construct a cell member identifier with arguments:\n" + " gid: The global identifier of the cell.\n" + " index: The cell-local index of the item.\n") + .def(py::init([](py::tuple t) { + if (py::len(t)!=2) throw std::runtime_error("tuple length != 4"); + return arb::cell_member_type{t[0].cast<arb::cell_gid_type>(), t[1].cast<arb::cell_lid_type>()}; + }), + "Construct a cell member identifier with tuple argument (gid, index):\n" " gid: The global identifier of the cell.\n" " index: The cell-local index of the item.\n") .def_readwrite("gid", &arb::cell_member_type::gid, @@ -36,7 +44,9 @@ void register_identifiers(pybind11::module& m) { .def("__str__", [](arb::cell_member_type m) {return pprintf("<arbor.cell_member: gid {}, index {}>", m.gid, m.index);}) .def("__repr__",[](arb::cell_member_type m) {return pprintf("<arbor.cell_member: gid {}, index {}>", m.gid, m.index);}); - pybind11::enum_<arb::cell_kind>(m, "cell_kind", + py::implicitly_convertible<py::tuple, arb::cell_member_type>(); + + py::enum_<arb::cell_kind>(m, "cell_kind", "Enumeration used to identify the cell kind, used by the model to group equal kinds in the same cell group.") .value("benchmark", arb::cell_kind::benchmark, "Proxy cell used for benchmarking.") @@ -47,14 +57,14 @@ void register_identifiers(pybind11::module& m) { .value("spike_source", arb::cell_kind::spike_source, "Proxy cell that generates spikes from a spike sequence provided by the user."); - pybind11::enum_<arb::backend_kind>(m, "backend", + py::enum_<arb::backend_kind>(m, "backend", "Enumeration used to indicate which hardware backend to execute a cell group on.") .value("gpu", arb::backend_kind::gpu, "Use GPU backend.") .value("multicore", arb::backend_kind::multicore, "Use multicore backend."); - pybind11::enum_<arb::binning_kind>(m, "binning", + py::enum_<arb::binning_kind>(m, "binning", "Enumeration for event time binning policy.") .value("none", arb::binning_kind::none, "No binning policy.") diff --git a/python/morphology.cpp b/python/morphology.cpp index 6f6425e9bb21f2690fb0088691fa4fa3b29bc531..9b89c0fca69021e8ab69e5899b34bdcbad6d6758 100644 --- a/python/morphology.cpp +++ b/python/morphology.cpp @@ -1,4 +1,6 @@ +#include <pybind11/operators.h> #include <pybind11/pybind11.h> +#include <pybind11/pytypes.h> #include <pybind11/stl.h> #include <fstream> @@ -12,6 +14,8 @@ #include "error.hpp" #include "strprintf.hpp" +namespace py = pybind11; + namespace pyarb { void check_trailing(std::istream& in, std::string fname) { @@ -20,8 +24,8 @@ void check_trailing(std::istream& in, std::string fname) { } } -void register_morphology(pybind11::module& m) { - using namespace pybind11::literals; +void register_morphology(py::module& m) { + using namespace py::literals; // // primitives: points, segments, locations, cables... etc. @@ -30,10 +34,10 @@ void register_morphology(pybind11::module& m) { m.attr("mnpos") = arb::mnpos; // arb::mlocation - pybind11::class_<arb::mlocation> location(m, "location", + py::class_<arb::mlocation> location(m, "location", "A location on a cable cell."); location - .def(pybind11::init( + .def(py::init( [](arb::msize_t branch, double pos) { const arb::mlocation mloc{branch, pos}; pyarb::assert_throw(arb::test_invariants(mloc), "invalid location"); @@ -47,19 +51,24 @@ void register_morphology(pybind11::module& m) { "The id of the branch.") .def_readonly("pos", &arb::mlocation::pos, "The relative position on the branch (∈ [0.,1.], where 0. means proximal and 1. distal).") + .def(py::self==py::self) .def("__str__", [](arb::mlocation l) { return util::pprintf("(location {} {})", l.branch, l.pos); }) .def("__repr__", [](arb::mlocation l) { return util::pprintf("(location {} {})", l.branch, l.pos); }); // arb::mpoint - pybind11::class_<arb::mpoint> mpoint(m, "mpoint"); + py::class_<arb::mpoint> mpoint(m, "mpoint"); mpoint - .def(pybind11::init( + .def(py::init( [](double x, double y, double z, double r) { return arb::mpoint{x,y,z,r}; }), "x"_a, "y"_a, "z"_a, "radius"_a, "All values in μm.") + .def(py::init([](py::tuple t) { + if (py::len(t)!=4) throw std::runtime_error("tuple length != 4"); + return arb::mpoint{t[0].cast<double>(), t[1].cast<double>(), t[2].cast<double>(), t[3].cast<double>()}; + })) .def_readonly("x", &arb::mpoint::x, "X coordinate [μm].") .def_readonly("y", &arb::mpoint::y, "Y coordinate [μm].") .def_readonly("z", &arb::mpoint::z, "Z coordinate [μm].") @@ -72,17 +81,19 @@ void register_morphology(pybind11::module& m) { .def("__repr__", [](const arb::mpoint& p) {return util::pprintf("{}>", p);}); + py::implicitly_convertible<py::tuple, arb::mpoint>(); + // arb::msegment - pybind11::class_<arb::msegment> msegment(m, "msegment"); + py::class_<arb::msegment> msegment(m, "msegment"); msegment .def_readonly("prox", &arb::msegment::prox, "the location and radius of the proximal end.") .def_readonly("dist", &arb::msegment::dist, "the location and radius of the distal end.") .def_readonly("tag", &arb::msegment::tag, "tag meta-data."); // arb::mcable - pybind11::class_<arb::mcable> cable(m, "cable"); + py::class_<arb::mcable> cable(m, "cable"); cable - .def(pybind11::init( + .def(py::init( [](arb::msize_t bid, double prox, double dist) { arb::mcable c{bid, prox, dist}; if (!test_invariants(c)) { @@ -97,6 +108,7 @@ void register_morphology(pybind11::module& m) { "The relative position of the proximal end of the cable on its branch ∈ [0,1].") .def_readonly("dist", &arb::mcable::dist_pos, "The relative position of the distal end of the cable on its branch ∈ [0,1].") + .def(py::self==py::self) .def("__str__", [](const arb::mcable& c) { return util::pprintf("{}", c); }) .def("__repr__", [](const arb::mcable& c) { return util::pprintf("{}", c); }); @@ -105,10 +117,10 @@ void register_morphology(pybind11::module& m) { // // arb::segment_tree - pybind11::class_<arb::segment_tree> segment_tree(m, "segment_tree"); + py::class_<arb::segment_tree> segment_tree(m, "segment_tree"); segment_tree // constructors - .def(pybind11::init<>()) + .def(py::init<>()) // modifiers .def("reserve", &arb::segment_tree::reserve) .def("append", [](arb::segment_tree& t, arb::msize_t parent, arb::mpoint prox, arb::mpoint dist, int tag) { @@ -242,10 +254,10 @@ void register_morphology(pybind11::module& m) { // arb::morphology - pybind11::class_<arb::morphology> morph(m, "morphology"); + py::class_<arb::morphology> morph(m, "morphology"); morph // constructors - .def(pybind11::init( + .def(py::init( [](arb::segment_tree t){ return arb::morphology(std::move(t)); })) diff --git a/python/pyarb.cpp b/python/pyarb.cpp index a3d97d5136b68dc4446ac695166906e7c96eeb56..3ac9bcf0c8b821fd3da58a049b31660c475d7e60 100644 --- a/python/pyarb.cpp +++ b/python/pyarb.cpp @@ -1,11 +1,17 @@ #include <pybind11/pybind11.h> +#include <pybind11/numpy.h> +#include <arbor/spike.hpp> +#include <arbor/common_types.hpp> #include <arbor/version.hpp> +#include "pyarb.hpp" + // Forward declarations of functions used to register API // types and functions to be exposed to Python. namespace pyarb { +void register_cable_probes(pybind11::module& m, pyarb_global_ptr); void register_cells(pybind11::module& m); void register_config(pybind11::module& m); void register_contexts(pybind11::module& m); @@ -17,21 +23,28 @@ void register_mechanisms(pybind11::module& m); void register_morphology(pybind11::module& m); void register_profiler(pybind11::module& m); void register_recipe(pybind11::module& m); -void register_sampling(pybind11::module& m); void register_schedules(pybind11::module& m); -void register_simulation(pybind11::module& m); +void register_simulation(pybind11::module& m, pyarb_global_ptr); void register_single_cell(pybind11::module& m); void register_spike_handling(pybind11::module& m); #ifdef ARB_MPI_ENABLED void register_mpi(pybind11::module& m); #endif -} + +} // namespace pyarb PYBIND11_MODULE(_arbor, m) { + pyarb::pyarb_global_ptr global_ptr(new pyarb::pyarb_global); + + // Register NumPy structured datatypes for Arbor structures used in NumPy array outputs. + PYBIND11_NUMPY_DTYPE(arb::cell_member_type, gid, index); + PYBIND11_NUMPY_DTYPE(arb::spike, source, time); + m.doc() = "arbor: multicompartment neural network models."; m.attr("__version__") = ARB_VERSION; + pyarb::register_cable_probes(m, global_ptr); pyarb::register_cells(m); pyarb::register_config(m); pyarb::register_contexts(m); @@ -43,9 +56,8 @@ PYBIND11_MODULE(_arbor, m) { pyarb::register_morphology(m); pyarb::register_profiler(m); pyarb::register_recipe(m); - pyarb::register_sampling(m); pyarb::register_schedules(m); - pyarb::register_simulation(m); + pyarb::register_simulation(m, global_ptr); pyarb::register_single_cell(m); pyarb::register_spike_handling(m); diff --git a/python/pyarb.hpp b/python/pyarb.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6678f5904a7a35cc710727f7de01071e208415b8 --- /dev/null +++ b/python/pyarb.hpp @@ -0,0 +1,84 @@ +#pragma once + +// Common module-global data for use by the pyarb implementation. + +#include <functional> +#include <memory> +#include <typeinfo> +#include <unordered_map> + +#include <arbor/arbexcept.hpp> +#include <arbor/sampling.hpp> +#include <arbor/util/any_ptr.hpp> + +#include <pybind11/pybind11.h> + +namespace pyarb { + +// Sample recorder object interface. + +struct sample_recorder { + virtual void record(arb::util::any_ptr meta, std::size_t n_sample, const arb::sample_record* records) = 0; + virtual pybind11::object samples() const = 0; + virtual pybind11::object meta() const = 0; + virtual void reset() = 0; + virtual ~sample_recorder() {} +}; + +// Recorder 'factory' type: given an any_ptr to probe metadata of a specific subset of types, +// return a corresponding sample_recorder instance. + +using sample_recorder_factory = std::function<std::unique_ptr<sample_recorder> (arb::util::any_ptr)>; + +// Holds map: probe metadata pointer type → recorder object factory. + +struct recorder_factory_map { + std::unordered_map<std::type_index, sample_recorder_factory> map_; + + template <typename Meta> + void assign(sample_recorder_factory rf) { + map_[typeid(const Meta*)] = std::move(rf); + } + + std::unique_ptr<sample_recorder> make_recorder(arb::util::any_ptr meta) const { + try { + return map_.at(meta.type())(meta); + } + catch (std::out_of_range&) { + throw arb::arbor_internal_error("unrecognized probe metadata type"); + } + } +}; + +// Probe metadata to Python object converter. + +using probe_meta_converter = std::function<pybind11::object (arb::util::any_ptr)>; + +struct probe_meta_cvt_map { + std::unordered_map<std::type_index, probe_meta_converter> map_; + + template <typename Meta> + void assign(probe_meta_converter cvt) { + map_[typeid(const Meta*)] = std::move(cvt); + } + + pybind11::object convert(arb::util::any_ptr meta) const { + if (auto iter = map_.find(meta.type()); iter!=map_.end()) { + return iter->second(meta); + } + else { + return pybind11::none(); + } + } +}; + +// Collection of module-global data. + +struct pyarb_global { + recorder_factory_map recorder_factories; + probe_meta_cvt_map probe_meta_converters; +}; + +using pyarb_global_ptr = std::shared_ptr<pyarb_global>; + +} // namespace pyarb diff --git a/python/recipe.cpp b/python/recipe.cpp index 922f47e550c9d4b75ebd2e989013dccff2370297..e1c67da50509cff4a4882d0886d856c1a18459f6 100644 --- a/python/recipe.cpp +++ b/python/recipe.cpp @@ -29,16 +29,6 @@ arb::util::unique_any py_recipe_shim::get_cell_description(arb::cell_gid_type gi "Python error already thrown"); } -arb::probe_info cable_loc_probe(std::string kind, arb::mlocation loc) { - if (kind == "voltage") { - return arb::cable_probe_membrane_voltage{loc}; - } - else if (kind == "ionic current density") { - return arb::cable_probe_total_ion_current_density{loc}; - } - else throw pyarb_error(util::pprintf("unrecognized probe kind: {}", kind)); -}; - std::vector<arb::event_generator> convert_gen(std::vector<pybind11::object> pygens, arb::cell_gid_type gid) { using namespace std::string_literals; using pybind11::isinstance; @@ -169,11 +159,6 @@ void register_recipe(pybind11::module& m) { .def("__repr__", [](const py_recipe&){return "<arbor.recipe>";}); // Probes - m.def("cable_probe", &cable_loc_probe, - "Description of a probe at a location available for monitoring data of kind "\ - "where kind is one of 'voltage' or 'ionic current density'.", - "kind"_a, "location"_a); - pybind11::class_<arb::probe_info> probe(m, "probe"); probe .def("__repr__", [](const arb::probe_info& p){return util::pprintf("<arbor.probe: tag {}>", p.tag);}) diff --git a/python/recipe.hpp b/python/recipe.hpp index bf421aad80948f5e4a28c39e4071deab5fc4d3b6..a6196d0f56be4a39b258bd1824286cceea7a5d2a 100644 --- a/python/recipe.hpp +++ b/python/recipe.hpp @@ -15,9 +15,7 @@ namespace pyarb { -arb::probe_info cable_probe(std::string kind, arb::cell_member_type id, arb::mlocation loc); - -// pyarb::recipe is the recipe interface used by Python. +// pyarb::py_recipe is the recipe interface used by Python. // Calls that return generic types return pybind11::object, to avoid // having to wrap some C++ types used by the C++ interface (specifically // util::unique_any, std::any, std::unique_ptr, etc.) @@ -102,7 +100,7 @@ public: } }; -// A recipe shim that holds a pyarb::recipe implementation. +// A recipe shim that holds a pyarb::py_recipe implementation. // Unwraps/translates python-side output from pyarb::recipe and forwards // to arb::recipe. // For example, unwrap cell descriptions stored in PyObject, and rewrap @@ -123,7 +121,7 @@ public: return try_catch_pyexception([&](){ return impl_->num_cells(); }, msg); } - // The pyarb::recipe::cell_decription returns a pybind11::object, that is + // The pyarb::py_recipe::cell_decription method returns a pybind11::object, that is // unwrapped and copied into a util::unique_any. arb::util::unique_any get_cell_description(arb::cell_gid_type gid) const override; diff --git a/python/sampling.cpp b/python/sampling.cpp deleted file mode 100644 index 6220f625f1cbb9bfc33d8b0de8d1004777b19da1..0000000000000000000000000000000000000000 --- a/python/sampling.cpp +++ /dev/null @@ -1,157 +0,0 @@ -#include <mutex> - -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> - -#include <arbor/common_types.hpp> -#include <arbor/sampling.hpp> -#include <arbor/simulation.hpp> - -#include "error.hpp" -#include "strprintf.hpp" - -namespace pyarb { - -// TODO: trace entry of different types/container (e.g. vector of doubles to get all samples of a cell) - -struct trace_sample { - arb::time_type t; - double v; -}; - -// A helper struct (state) ensuring that only one thread can write to the probe_buffers holding the trace entries (mapped by probe id) -struct sampler_state { - std::mutex mutex; - std::unordered_map<arb::cell_member_type, std::vector<trace_sample>> probe_buffers; - - std::vector<trace_sample>& probe_buffer(arb::cell_member_type pid) { - // lock the mutex, s.t. other threads cannot write - std::lock_guard<std::mutex> lock(mutex); - // return or create entry - return probe_buffers[pid]; - } - - // helper function to search probe id in probe_buffers - bool has_pid(arb::cell_member_type pid) { - return probe_buffers.count(pid); - } - - // helper function to push back to locked vector - void push_back(arb::cell_member_type pid, trace_sample value) { - auto& v = probe_buffer(pid); - v.push_back(std::move(value)); - } - - // Access the probe buffers - const std::unordered_map<arb::cell_member_type, std::vector<trace_sample>>& samples() const { - return probe_buffers; - } -}; - -// A functor that models arb::sampler_function. -// Holds a shared pointer to the trace_sample used to store the samples, so that if -// the trace_sample in sampler is garbage collected in Python, stores will -// not seg fault. - -struct sample_callback { - std::shared_ptr<sampler_state> sample_store; - - sample_callback(const std::shared_ptr<sampler_state>& state): - sample_store(state) - {} - - void operator() (arb::probe_metadata pm, std::size_t n, const arb::sample_record* recs) { - auto& v = sample_store->probe_buffer(pm.id); - for (std::size_t i = 0; i<n; ++i) { - if (auto p = arb::util::any_cast<const double*>(recs[i].data)) { - v.push_back({recs[i].time, *p}); - } - else { - throw std::runtime_error("unexpected sample type"); - } - } - }; -}; - -// Helper type for recording samples from a simulation. -// This type is wrapped in Python, to expose sampler::sample_store. -struct sampler { - std::shared_ptr<sampler_state> sample_store; - - sample_callback callback() { - // initialize the sample_store - sample_store = std::make_shared<sampler_state>(); - - // The callback holds a copy of sample_store, i.e. the shared - // pointer is held by both the sampler and the callback, so if - // the sampler is destructed in the calling Python code, attempts - // to write to sample_store inside the callback will not seg fault. - return sample_callback(sample_store); - } - - const std::vector<trace_sample>& samples(arb::cell_member_type pid) const { - if (!sample_store->has_pid(pid)) { - throw std::runtime_error(util::pprintf("probe id {} does not exist", pid)); - } - return sample_store->probe_buffer(pid); - } - - void clear() { - for (auto b: sample_store->probe_buffers) { - b.second.clear(); - } - } -}; - -// Adds sampler to one probe with pid -std::shared_ptr<sampler> attach_sampler(arb::simulation& sim, arb::time_type interval, arb::cell_member_type pid) { - auto r = std::make_shared<sampler>(); - sim.add_sampler(arb::one_probe(pid), arb::regular_schedule(interval), r->callback()); - return r; -} - -// Adds sampler to all probes -std::shared_ptr<sampler> attach_sampler(arb::simulation& sim, arb::time_type interval) { - auto r = std::make_shared<sampler>(); - sim.add_sampler(arb::all_probes, arb::regular_schedule(interval), r->callback()); - return r; -} - -std::string sample_str(const trace_sample& s) { - return util::pprintf("<arbor.sample: time {} ms, \tvalue {}>", s.t, s.v); -} - -void register_sampling(pybind11::module& m) { - using namespace pybind11::literals; - - // Sample - pybind11::class_<trace_sample> trace_sample(m, "trace_sample"); - trace_sample - .def_readonly("time", &trace_sample::t, "The sample time [ms] at a specific probe.") - .def_readonly("value", &trace_sample::v, "The sample record at a specific probe.") - .def("__str__", &sample_str) - .def("__repr__", &sample_str); - - // Sampler - pybind11::class_<sampler, std::shared_ptr<sampler>> samplerec(m, "sampler"); - samplerec - .def(pybind11::init<>()) - .def("samples", &sampler::samples, - "A list of the recorded samples of a probe with probe id.", - "probe_id"_a) - .def("clear", &sampler::clear, "Clear all recorded samples."); - - m.def("attach_sampler", - (std::shared_ptr<sampler> (*)(arb::simulation&, arb::time_type)) &attach_sampler, - "Attach a sample recorder to an arbor simulation.\n" - "The recorder will record all samples from a regular sampling interval [ms] matching all probe ids.", - "sim"_a, "dt"_a); - - m.def("attach_sampler", - (std::shared_ptr<sampler> (*)(arb::simulation&, arb::time_type, arb::cell_member_type)) &attach_sampler, - "Attach a sample recorder to an arbor simulation.\n" - "The recorder will record all samples from a regular sampling interval [ms] matching one probe id.", - "sim"_a, "dt"_a, "probe_id"_a); -} - -} // namespace pyarb diff --git a/python/schedule.cpp b/python/schedule.cpp index ff51408b63ae719f4467191a9d70925ab5001524..e4f99106065084bd57f642095c8c88dd7b3dd4f5 100644 --- a/python/schedule.cpp +++ b/python/schedule.cpp @@ -8,6 +8,8 @@ #include "schedule.hpp" #include "strprintf.hpp" +namespace py = pybind11; + namespace pyarb { std::ostream& operator<<(std::ostream& o, const regular_schedule_shim& x) { @@ -36,32 +38,33 @@ static std::vector<arb::time_type> as_vector(std::pair<const arb::time_type*, co // regular_schedule shim // -regular_schedule_shim::regular_schedule_shim( - pybind11::object t0, - time_type deltat, - pybind11::object t1) -{ +regular_schedule_shim::regular_schedule_shim(arb::time_type t0, arb::time_type delta_t, py::object t1) { set_tstart(t0); + set_dt(delta_t); set_tstop(t1); - set_dt(deltat); } -void regular_schedule_shim::set_tstart(pybind11::object t) { - tstart = py2optional<time_type>( - t, "tstart must be a non-negative number, or None", is_nonneg()); +regular_schedule_shim::regular_schedule_shim(arb::time_type delta_t) { + set_tstart(0.); + set_dt(delta_t); +} + +void regular_schedule_shim::set_tstart(arb::time_type t) { + pyarb::assert_throw(is_nonneg()(t), "tstart must be a non-negative number"); + tstart = t; }; -void regular_schedule_shim::set_tstop(pybind11::object t) { +void regular_schedule_shim::set_tstop(py::object t) { tstop = py2optional<time_type>( t, "tstop must be a non-negative number, or None", is_nonneg()); }; void regular_schedule_shim::set_dt(arb::time_type delta_t) { - pyarb::assert_throw(is_nonneg()(delta_t), "dt must be a non-negative number"); + pyarb::assert_throw(is_positive()(delta_t), "dt must be a positive number"); dt = delta_t; }; -regular_schedule_shim::opt_time_type regular_schedule_shim::get_tstart() const { +regular_schedule_shim::time_type regular_schedule_shim::get_tstart() const { return tstart; } @@ -75,7 +78,7 @@ regular_schedule_shim::opt_time_type regular_schedule_shim::get_tstop() const { arb::schedule regular_schedule_shim::schedule() const { return arb::regular_schedule( - tstart.value_or(arb::terminal_time), + tstart, dt, tstop.value_or(arb::terminal_time)); } @@ -177,21 +180,27 @@ std::vector<arb::time_type> poisson_schedule_shim::events(arb::time_type t0, arb return as_vector(sched.events(t0, t1)); } -void register_schedules(pybind11::module& m) { - using namespace pybind11::literals; +void register_schedules(py::module& m) { + using namespace py::literals; using time_type = arb::time_type; + py::class_<schedule_shim_base> schedule_base(m, "schedule_base", "Schedule abstract base class."); + // Regular schedule - pybind11::class_<regular_schedule_shim> regular_schedule(m, "regular_schedule", + py::class_<regular_schedule_shim, schedule_shim_base> regular_schedule(m, "regular_schedule", "Describes a regular schedule with multiples of dt within the interval [tstart, tstop)."); regular_schedule - .def(pybind11::init<pybind11::object, time_type, pybind11::object>(), - "tstart"_a = pybind11::none(), "dt"_a = 0., "tstop"_a = pybind11::none(), + .def(py::init<time_type, time_type, py::object>(), + "tstart"_a, "dt"_a, "tstop"_a = py::none(), "Construct a regular schedule with arguments:\n" - " tstart: The delivery time of the first event in the sequence [ms], None by default.\n" - " dt: The interval between time points [ms], 0 by default.\n" + " tstart: The delivery time of the first event in the sequence [ms].\n" + " dt: The interval between time points [ms].\n" " tstop: No events delivered after this time [ms], None by default.") + .def(py::init<time_type>(), + "dt"_a, + "Construct a regular schedule, starting from t = 0 and never terminating, with arguments:\n" + " dt: The interval between time points [ms].\n") .def_property("tstart", ®ular_schedule_shim::get_tstart, ®ular_schedule_shim::set_tstart, "The delivery time of the first event in the sequence [ms].") .def_property("tstop", ®ular_schedule_shim::get_tstop, ®ular_schedule_shim::set_tstop, @@ -204,13 +213,13 @@ void register_schedules(pybind11::module& m) { .def("__repr__", util::to_string<regular_schedule_shim>); // Explicit schedule - pybind11::class_<explicit_schedule_shim> explicit_schedule(m, "explicit_schedule", + py::class_<explicit_schedule_shim, schedule_shim_base> explicit_schedule(m, "explicit_schedule", "Describes an explicit schedule at a predetermined (sorted) sequence of times."); explicit_schedule - .def(pybind11::init<>(), + .def(py::init<>(), "Construct an empty explicit schedule.\n") - .def(pybind11::init<std::vector<time_type>>(), + .def(py::init<std::vector<time_type>>(), "times"_a, "Construct an explicit schedule with argument:\n" " times: A list of times [ms], [] by default.") @@ -222,11 +231,11 @@ void register_schedules(pybind11::module& m) { .def("__repr__", util::to_string<explicit_schedule_shim>); // Poisson schedule - pybind11::class_<poisson_schedule_shim> poisson_schedule(m, "poisson_schedule", + py::class_<poisson_schedule_shim, schedule_shim_base> poisson_schedule(m, "poisson_schedule", "Describes a schedule according to a Poisson process."); poisson_schedule - .def(pybind11::init<time_type, time_type, std::mt19937_64::result_type>(), + .def(py::init<time_type, time_type, std::mt19937_64::result_type>(), "tstart"_a = 0., "freq"_a = 10., "seed"_a = 0, "Construct a Poisson schedule with arguments:\n" " tstart: The delivery time of the first event in the sequence [ms], 0 by default.\n" diff --git a/python/schedule.hpp b/python/schedule.hpp index 5aefeb5fcbc5e783e0764d01688469f2cdf23197..32e6bcefe26964aa3207257b9b4476a48cf88af6 100644 --- a/python/schedule.hpp +++ b/python/schedule.hpp @@ -12,32 +12,41 @@ namespace pyarb { +// Schedule shim base class provides virtual interface for conversion +// to an arb::schedule object. +struct schedule_shim_base { + schedule_shim_base() = default; + schedule_shim_base(const schedule_shim_base&) = delete; + schedule_shim_base& operator=(schedule_shim_base&) = delete; + + virtual arb::schedule schedule() const = 0; +}; + // A Python shim that holds the information that describes an // arb::regular_schedule. This is wrapped in pybind11, and users constructing // a regular_schedule in python are manipulating this type. This is converted to // an arb::regular_schedule when a C++ recipe is created from a Python recipe. -struct regular_schedule_shim { +struct regular_schedule_shim: schedule_shim_base { using time_type = arb::time_type; using opt_time_type = std::optional<time_type>; - opt_time_type tstart = {}; - opt_time_type tstop = {}; + time_type tstart = {}; time_type dt = 0; + opt_time_type tstop = {}; - regular_schedule_shim() = default; - - regular_schedule_shim(pybind11::object t0, time_type deltat, pybind11::object t1); + regular_schedule_shim(time_type t0, time_type delta_t, pybind11::object t1); + explicit regular_schedule_shim(time_type delta_t); // getter and setter (in order to assert when being set) - void set_tstart(pybind11::object t); - void set_tstop(pybind11::object t); + void set_tstart(time_type t); void set_dt(time_type delta_t); + void set_tstop(pybind11::object t); - opt_time_type get_tstart() const; - time_type get_dt() const; - opt_time_type get_tstop() const; + time_type get_tstart() const; + time_type get_dt() const; + opt_time_type get_tstop() const; - arb::schedule schedule() const; + arb::schedule schedule() const override; std::vector<arb::time_type> events(arb::time_type t0, arb::time_type t1); }; @@ -46,7 +55,7 @@ struct regular_schedule_shim { // This is wrapped in pybind11, and users constructing an explicit_schedule in // Python are manipulating this type. This is converted to an // arb::explicit_schedule when a C++ recipe is created from a Python recipe. -struct explicit_schedule_shim { +struct explicit_schedule_shim: schedule_shim_base { std::vector<arb::time_type> times; explicit_schedule_shim() = default; @@ -56,7 +65,7 @@ struct explicit_schedule_shim { void set_times(std::vector<arb::time_type> t); std::vector<arb::time_type> get_times() const; - arb::schedule schedule() const; + arb::schedule schedule() const override; std::vector<arb::time_type> events(arb::time_type t0, arb::time_type t1); }; @@ -65,7 +74,7 @@ struct explicit_schedule_shim { // This is wrapped in pybind11, and users constructing a poisson_schedule in // Python are manipulating this type. This is converted to an // arb::poisson_schedule when a C++ recipe is created from a Python recipe. -struct poisson_schedule_shim { +struct poisson_schedule_shim: schedule_shim_base { using rng_type = std::mt19937_64; arb::time_type tstart = arb::terminal_time; @@ -81,7 +90,7 @@ struct poisson_schedule_shim { arb::time_type get_tstart() const; arb::time_type get_freq() const; - arb::schedule schedule() const; + arb::schedule schedule() const override; std::vector<arb::time_type> events(arb::time_type t0, arb::time_type t1); }; diff --git a/python/simulation.cpp b/python/simulation.cpp index 03b019063fe726a491392f3e686fdc1792373d66..1a95a9f1abbdc2aaabee126341e4baf787b9d030 100644 --- a/python/simulation.cpp +++ b/python/simulation.cpp @@ -1,3 +1,5 @@ +#include <memory> +#include <pybind11/numpy.h> #include <pybind11/pybind11.h> #include <arbor/common_types.hpp> @@ -6,24 +8,176 @@ #include "context.hpp" #include "error.hpp" +#include "pyarb.hpp" #include "recipe.hpp" +#include "schedule.hpp" + +namespace py = pybind11; namespace pyarb { -void register_simulation(pybind11::module& m) { +// Argument type for simulation_shim::record() (see below). + +enum class spike_recording { + off, local, all +}; + +// Wraps an arb::simulation object and in addition manages a set of +// sampler callbacks for retrieving probe data. + +class simulation_shim { + std::unique_ptr<arb::simulation> sim_; + std::vector<arb::spike> spike_record_; + pyarb_global_ptr global_ptr_; + + using sample_recorder_ptr = std::unique_ptr<sample_recorder>; + using sample_recorder_vec = std::vector<sample_recorder_ptr>; + + // These are only used as the target sampler of a single probe id. + struct sampler_callback { + std::shared_ptr<sample_recorder_vec> recorders; + + void operator()(arb::probe_metadata pm, std::size_t n_record, const arb::sample_record* records) { + recorders->at(pm.index)->record(pm.meta, n_record, records); + } + + py::list samples() const { + std::size_t size = recorders->size(); + py::list result(size); + + for (std::size_t i = 0; i<size; ++i) { + result[i] = py::make_tuple(recorders->at(i)->samples(), recorders->at(i)->meta()); + } + return result; + } + }; + + std::unordered_map<arb::sampler_association_handle, sampler_callback> sampler_map_; + +public: + simulation_shim(std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx, pyarb_global_ptr global_ptr): + global_ptr_(global_ptr) + { + try { + sim_.reset(new arb::simulation(py_recipe_shim(rec), decomp, ctx.context)); + } + catch (...) { + py_reset_and_throw(); + throw; + } + } + + void reset() { + sim_->reset(); + spike_record_.clear(); + for (auto&& [handle, cb]: sampler_map_) { + for (auto& rec: *cb.recorders) { + rec->reset(); + } + } + } + + arb::time_type run(arb::time_type tfinal, arb::time_type dt) { + return sim_->run(tfinal, dt); + } + + void set_binning_policy(arb::binning_kind policy, arb::time_type bin_interval) { + sim_->set_binning_policy(policy, bin_interval); + } + + void record(spike_recording policy) { + auto spike_recorder = [this](const std::vector<arb::spike>& spikes) { + spike_record_.insert(spike_record_.end(), spikes.begin(), spikes.end()); + }; + + switch (policy) { + case spike_recording::off: + sim_->set_global_spike_callback(); + sim_->set_local_spike_callback(); + break; + case spike_recording::local: + sim_->set_global_spike_callback(); + sim_->set_local_spike_callback(spike_recorder); + break; + case spike_recording::all: + sim_->set_global_spike_callback(spike_recorder); + sim_->set_local_spike_callback(); + break; + } + } + + py::object spikes() const { + return py::array_t<arb::spike>({spike_record_.size()}, spike_record_.data()); + } + + py::list get_probe_metadata(arb::cell_member_type probe_id) const { + py::list result; + for (auto&& pm: sim_->get_probe_metadata(probe_id)) { + result.append(global_ptr_->probe_meta_converters.convert(pm.meta)); + } + return result; + } + + arb::sampler_association_handle sample(arb::cell_member_type probe_id, const pyarb::schedule_shim_base& sched, arb::sampling_policy policy) { + std::shared_ptr<sample_recorder_vec> recorders{new sample_recorder_vec}; + + for (const arb::probe_metadata& pm: sim_->get_probe_metadata(probe_id)) { + recorders->push_back(global_ptr_->recorder_factories.make_recorder(pm.meta)); + } + + // Constructed callbacks are passed to the underlying simulator object, _and_ a copy + // is kept in sampler_map_; the two copies share the same recorder data. + + sampler_callback cb{std::move(recorders)}; + auto sah = sim_->add_sampler(arb::one_probe(probe_id), sched.schedule(), cb, policy); + sampler_map_.insert({sah, cb}); + + return sah; + } + + void remove_sampler(arb::sampler_association_handle sah) { + sim_->remove_sampler(sah); + sampler_map_.erase(sah); + } + + void remove_all_samplers() { + sim_->remove_all_samplers(); + sampler_map_.clear(); + } + + py::list samples(arb::sampler_association_handle sah) { + if (auto iter = sampler_map_.find(sah); iter!=sampler_map_.end()) { + return iter->second.samples(); + } + else { + return py::list{}; + } + } +}; + +void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) { using namespace pybind11::literals; + py::enum_<arb::sampling_policy>(m, "sampling_policy") + .value("lax", arb::sampling_policy::lax) + .value("exact", arb::sampling_policy::exact); + + py::enum_<spike_recording>(m, "spike_recording") + .value("off", spike_recording::off) + .value("local", spike_recording::local) + .value("all", spike_recording::all); + // Simulation - pybind11::class_<arb::simulation> simulation(m, "simulation", + py::class_<simulation_shim> simulation(m, "simulation", "The executable form of a model.\n" "A simulation is constructed from a recipe, and then used to update and monitor model state."); simulation // A custom constructor that wraps a python recipe with arb::py_recipe_shim // before forwarding it to the arb::recipe constructor. .def(pybind11::init( - [](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) { + [global_ptr](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) { try { - return new arb::simulation(py_recipe_shim(rec), decomp, ctx.context); + return new simulation_shim(rec, decomp, ctx, global_ptr); } catch (...) { py_reset_and_throw(); @@ -35,18 +189,36 @@ void register_simulation(pybind11::module& m) { "Initialize the model described by a recipe, with cells and network distributed\n" "according to the domain decomposition and computational resources described by a context.", "recipe"_a, "domain_decomposition"_a, "context"_a) - .def("reset", &arb::simulation::reset, + .def("reset", &simulation_shim::reset, pybind11::call_guard<pybind11::gil_scoped_release>(), "Reset the state of the simulation to its initial state.") - .def("run", &arb::simulation::run, + .def("run", &simulation_shim::run, pybind11::call_guard<pybind11::gil_scoped_release>(), "Run the simulation from current simulation time to tfinal [ms], with maximum time step size dt [ms].", "tfinal"_a, "dt"_a=0.025) - .def("set_binning_policy", &arb::simulation::set_binning_policy, + .def("set_binning_policy", &simulation_shim::set_binning_policy, "Set the binning policy for event delivery, and the binning time interval if applicable [ms].", "policy"_a, "bin_interval"_a) - .def("__str__", [](const arb::simulation&){ return "<arbor.simulation>"; }) - .def("__repr__", [](const arb::simulation&){ return "<arbor.simulation>"; }); + .def("record", &simulation_shim::record, + "Disable or enable local or global spike recording.") + .def("spikes", &simulation_shim::spikes, + "Retrieve recorded spikes as numpy array.") + .def("probe_metadata", &simulation_shim::get_probe_metadata, + "Retrieve metadata associated with given probe id.", + "probe_id"_a) + .def("sample", &simulation_shim::sample, + "Record data from probes with given probe_id according to supplied schedule.\n" + "Returns handle for retrieving data or removing the sampling.", + "probe_id"_a, "schedule"_a, "policy"_a = arb::sampling_policy::lax) + .def("samples", &simulation_shim::samples, + "Retrieve sample data as a list, one element per probe associated with the query.", + "handle"_a) + .def("remove_sampler", &simulation_shim::remove_sampler, + "Remove sampling associated with the given handle.", + "handle"_a) + .def("remove_all_samplers", &simulation_shim::remove_sampler, + "Remove all sampling on the simulatr."); + } } // namespace pyarb diff --git a/python/test/unit/runner.py b/python/test/unit/runner.py index 06ebb671adb8ed6fd564c98ffe5944c31e15771f..231b88b16ca7c6d041a00ac5252911960907bf35 100644 --- a/python/test/unit/runner.py +++ b/python/test/unit/runner.py @@ -16,6 +16,7 @@ try: import test_identifiers import test_tests import test_schedules + import test_cable_probes # add more if needed except ModuleNotFoundError: from test import options @@ -24,6 +25,7 @@ except ModuleNotFoundError: from test.unit import test_event_generators from test.unit import test_identifiers from test.unit import test_schedules + from test.unit import test_cable_probes # add more if needed test_modules = [\ @@ -31,7 +33,8 @@ test_modules = [\ test_domain_decompositions,\ test_event_generators,\ test_identifiers,\ - test_schedules\ + test_schedules,\ + test_cable_probes\ ] # add more if needed def suite(): diff --git a/python/test/unit/test_cable_probes.py b/python/test/unit/test_cable_probes.py new file mode 100644 index 0000000000000000000000000000000000000000..2640eaaa38c9149e373063cdacb1a36127ca3de2 --- /dev/null +++ b/python/test/unit/test_cable_probes.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- + +import unittest +import arbor as A + +# to be able to run .py file from child directory +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +try: + import options +except ModuleNotFoundError: + from test import options + +""" +tests for cable probe wrappers +""" + +# Test recipe cc comprises one simple cable cell and mechanisms on it +# sufficient to test cable cell probe wrappers wrap correctly. + +class cc_recipe(A.recipe): + def __init__(self): + A.recipe.__init__(self) + st = A.segment_tree() + st.append(A.mnpos, (0, 0, 0, 10), (1, 0, 0, 10), 1) + + self.cell = A.cable_cell(st, A.label_dict()) + self.cell.place('(location 0 0.08)', "expsyn") + self.cell.place('(location 0 0.09)', "exp2syn") + self.cell.paint('(all)', "hh") + + def num_cells(self): + return 1 + + def num_targets(self, gid): + return 2 + + def num_sources(self, gid): + return 0 + + def cell_kind(self, gid): + return A.cell_kind.cable + + def get_probes(self, gid): + # Use keyword arguments to check that the wrappers have actually declared keyword arguments correctly. + # Place single-location probes at (location 0 0.01*j) where j is the index of the probe address in + # the returned list. + return [ + # probe id (0, 0) + A.cable_probe_membrane_voltage(where='(location 0 0.00)'), + # probe id (0, 1) + A.cable_probe_membrane_voltage_cell(), + # probe id (0, 2) + A.cable_probe_axial_current(where='(location 0 0.02)'), + # probe id (0, 3) + A.cable_probe_total_ion_current_density(where='(location 0 0.03)'), + # probe id (0, 4) + A.cable_probe_total_ion_current_cell(), + # probe id (0, 5) + A.cable_probe_total_current_cell(), + # probe id (0, 6) + A.cable_probe_density_state(where='(location 0 0.06)', mechanism='hh', state='m'), + # probe id (0, 7) + A.cable_probe_density_state_cell(mechanism='hh', state='n'), + # probe id (0, 8) + A.cable_probe_point_state(target=0, mechanism='expsyn', state='g'), + # probe id (0, 9) + A.cable_probe_point_state_cell(mechanism='exp2syn', state='B'), + # probe id (0, 10) + A.cable_probe_ion_current_density(where='(location 0 0.10)', ion='na'), + # probe id (0, 11) + A.cable_probe_ion_current_cell(ion='na'), + # probe id (0, 12) + A.cable_probe_ion_int_concentration(where='(location 0 0.12)', ion='na'), + # probe id (0, 13) + A.cable_probe_ion_int_concentration_cell(ion='na'), + # probe id (0, 14) + A.cable_probe_ion_ext_concentration(where='(location 0 0.14)', ion='na'), + # probe id (0, 15) + A.cable_probe_ion_ext_concentration_cell(ion='na'), + ] + + def cell_description(self, gid): + return self.cell + +class CableProbes(unittest.TestCase): + def test_probe_addr_metadata(self): + recipe = cc_recipe() + context = A.context() + dd = A.partition_load_balance(recipe, context) + sim = A.simulation(recipe, dd, context) + + all_cv_cables = [A.cable(0, 0, 1)] + + m = sim.probe_metadata((0, 0)) + self.assertEqual(1, len(m)) + self.assertEqual(A.location(0, 0.0), m[0]) + + m = sim.probe_metadata((0, 1)) + self.assertEqual(1, len(m)) + self.assertEqual(all_cv_cables, m[0]) + + m = sim.probe_metadata((0, 2)) + self.assertEqual(1, len(m)) + self.assertEqual(A.location(0, 0.02), m[0]) + + m = sim.probe_metadata((0, 3)) + self.assertEqual(1, len(m)) + self.assertEqual(A.location(0, 0.03), m[0]) + + m = sim.probe_metadata((0, 4)) + self.assertEqual(1, len(m)) + self.assertEqual(all_cv_cables, m[0]) + + m = sim.probe_metadata((0, 5)) + self.assertEqual(1, len(m)) + self.assertEqual(all_cv_cables, m[0]) + + m = sim.probe_metadata((0, 6)) + self.assertEqual(1, len(m)) + self.assertEqual(A.location(0, 0.06), m[0]) + + m = sim.probe_metadata((0, 7)) + self.assertEqual(1, len(m)) + self.assertEqual(all_cv_cables, m[0]) + + m = sim.probe_metadata((0, 8)) + self.assertEqual(1, len(m)) + self.assertEqual(A.location(0, 0.08), m[0].location) + self.assertEqual(1, m[0].multiplicity) + self.assertEqual(0, m[0].target) + + m = sim.probe_metadata((0, 9)) + self.assertEqual(1, len(m)) + self.assertEqual(1, len(m[0])) + self.assertEqual(A.location(0, 0.09), m[0][0].location) + self.assertEqual(1, m[0][0].multiplicity) + self.assertEqual(1, m[0][0].target) + + m = sim.probe_metadata((0, 10)) + self.assertEqual(1, len(m)) + self.assertEqual(A.location(0, 0.10), m[0]) + + m = sim.probe_metadata((0, 11)) + self.assertEqual(1, len(m)) + self.assertEqual(all_cv_cables, m[0]) + + m = sim.probe_metadata((0, 12)) + self.assertEqual(1, len(m)) + self.assertEqual(A.location(0, 0.12), m[0]) + + m = sim.probe_metadata((0, 13)) + self.assertEqual(1, len(m)) + self.assertEqual(all_cv_cables, m[0]) + + m = sim.probe_metadata((0, 14)) + self.assertEqual(1, len(m)) + self.assertEqual(A.location(0, 0.14), m[0]) + + m = sim.probe_metadata((0, 15)) + self.assertEqual(1, len(m)) + self.assertEqual(all_cv_cables, m[0]) + +def suite(): + # specify class and test functions in tuple (here: all tests starting with 'test' from class Contexts + suite = unittest.makeSuite(CableProbes, ('test')) + return suite + +def run(): + v = options.parse_arguments().verbosity + runner = unittest.TextTestRunner(verbosity = v) + runner.run(suite()) + +if __name__ == "__main__": + run() diff --git a/python/test/unit/test_identifiers.py b/python/test/unit/test_identifiers.py index 587dd293133ad8d66465e239ee2a42342ad95bdf..ba425be22472ae16f76d315d4f87cf70151e7481 100644 --- a/python/test/unit/test_identifiers.py +++ b/python/test/unit/test_identifiers.py @@ -21,12 +21,12 @@ all tests for identifiers, indexes, kinds class CellMembers(unittest.TestCase): - def test_gid_index_contor_cell_member(self): + def test_gid_index_ctor_cell_member(self): cm = arb.cell_member(17,42) self.assertEqual(cm.gid, 17) self.assertEqual(cm.index, 42) - def test_set_git_index_cell_member(self): + def test_set_gid_index_cell_member(self): cm = arb.cell_member(0,0) cm.gid = 13 cm.index = 23 diff --git a/python/test/unit/test_schedules.py b/python/test/unit/test_schedules.py index 2d320c7a097c90d9878a96b174129777a970027c..9d0344783b5875382855ebea5e554cb8796294b9 100644 --- a/python/test/unit/test_schedules.py +++ b/python/test/unit/test_schedules.py @@ -20,45 +20,50 @@ all tests for schedules (regular, explicit, poisson) """ class RegularSchedule(unittest.TestCase): - def test_none_contor_regular_schedule(self): - rs = arb.regular_schedule(tstart=None, tstop=None) + def test_none_ctor_regular_schedule(self): + rs = arb.regular_schedule(tstart=0, dt=0.1, tstop=None) + self.assertEqual(rs.dt, 0.1) - def test_tstart_dt_tstop_contor_regular_schedule(self): + def test_tstart_dt_tstop_ctor_regular_schedule(self): rs = arb.regular_schedule(10., 1., 20.) self.assertEqual(rs.tstart, 10.) self.assertEqual(rs.dt, 1.) self.assertEqual(rs.tstop, 20.) def test_set_tstart_dt_tstop_regular_schedule(self): - rs = arb.regular_schedule() + rs = arb.regular_schedule(0.1) + self.assertAlmostEqual(rs.dt, 0.1, places=1) rs.tstart = 17. rs.dt = 0.5 rs.tstop = 42. self.assertEqual(rs.tstart, 17.) - self.assertAlmostEqual(rs.dt, 0.5, places = 1) + self.assertAlmostEqual(rs.dt, 0.5, places=1) self.assertEqual(rs.tstop, 42.) def test_events_regular_schedule(self): expected = [0, 0.25, 0.5, 0.75, 1.0] - rs = arb.regular_schedule(tstart = 0., dt = 0.25, tstop = 1.25) + rs = arb.regular_schedule(tstart=0., dt=0.25, tstop=1.25) self.assertEqual(expected, rs.events(0., 1.25)) self.assertEqual(expected, rs.events(0., 5.)) self.assertEqual([], rs.events(5., 10.)) def test_exceptions_regular_schedule(self): with self.assertRaisesRegex(RuntimeError, - "tstart must be a non-negative number, or None"): - arb.regular_schedule(tstart = -1.) + "tstart must be a non-negative number"): + arb.regular_schedule(tstart=-1., dt=0.1) + with self.assertRaisesRegex(RuntimeError, + "dt must be a positive number"): + arb.regular_schedule(dt=-0.1) with self.assertRaisesRegex(RuntimeError, - "dt must be a non-negative number"): - arb.regular_schedule(dt = -0.1) + "dt must be a positive number"): + arb.regular_schedule(dt=0) with self.assertRaises(TypeError): - arb.regular_schedule(dt = None) + arb.regular_schedule(dt=None) with self.assertRaises(TypeError): - arb.regular_schedule(dt = 'dt') + arb.regular_schedule(dt='dt') with self.assertRaisesRegex(RuntimeError, "tstop must be a non-negative number, or None"): - arb.regular_schedule(tstop = 'tstop') + arb.regular_schedule(tstart=0, dt=0.1, tstop='tstop') with self.assertRaisesRegex(RuntimeError, "t0 must be a non-negative number"): rs = arb.regular_schedule(0., 1., 10.) @@ -98,13 +103,9 @@ class ExplicitSchedule(unittest.TestCase): arb.explicit_schedule([None]) with self.assertRaises(TypeError): arb.explicit_schedule([[1,2,3]]) - with self.assertRaisesRegex(RuntimeError, - "t0 must be a non-negative number"): - rs = arb.regular_schedule() - rs.events(-1., 1.) with self.assertRaisesRegex(RuntimeError, "t1 must be a non-negative number"): - rs = arb.regular_schedule() + rs = arb.regular_schedule(0.1) rs.events(1., -1.) class PoissonSchedule(unittest.TestCase): diff --git a/python/test/unit/test_simulator.py b/python/test/unit/test_simulator.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc26acdb7266b06c165c74534f6529b03785cd7 --- /dev/null +++ b/python/test/unit/test_simulator.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- +# +# test_simulator.py + +import unittest +import numpy as np +import arbor as A + +# to be able to run .py file from child directory +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +try: + import options +except ModuleNotFoundError: + from test import options + +""" +all tests for the simulator wrapper +""" + +# Test recipe cc2 comprises two cable cells and some probes. + +class cc2_recipe(A.recipe): + def __init__(self): + A.recipe.__init__(self) + st = A.segment_tree() + i = st.append(A.mnpos, (0, 0, 0, 10), (1, 0, 0, 10), 1) + st.append(i, (1, 3, 0, 5), 1) + st.append(i, (1, -4, 0, 3), 1) + self.the_morphology = A.morphology(st) + + def num_cells(self): + return 2 + + def num_targets(self, gid): + return 0 + + def num_sources(self, gid): + return 0 + + def cell_kind(self, gid): + return A.cell_kind.cable + + def connections_on(self, gid): + return [] + + def event_generators(self, gid): + return [] + + def get_probes(self, gid): + # Cell 0 has three voltage probes: + # 0, 0: end of branch 1 + # 0, 1: end of branch 2 + # 0, 2: all terminal points + # Values sampled from (0, 0) and (0, 1) should correspond + # to the values sampled from (0, 2). + + # Cell 1 has whole cell probes: + # 0, 0: all membrane voltages + # 0, 1: all expsyn state variable 'g' + + if gid==0: + return [A.cable_probe_membrane_voltage('(location 1 1)'), + A.cable_probe_membrane_voltage('(location 2 1)'), + A.cable_probe_membrane_voltage('(terminal)')] + elif gid==1: + return [A.cable_probe_membrane_voltage_cell(), + A.cable_probe_point_state_cell('expsyn', 'g')] + else: + return [] + + def cell_description(self, gid): + c = A.cable_cell(self.the_morphology, A.label_dict()) + c.set_properties(Vm=0.0, cm=0.01, rL=30, tempK=300) + c.paint('(all)', "pas") + c.place('(location 0 0)', A.iclamp(current=10 if gid==0 else 20)) + c.place('(sum (on-branches 0.3) (location 0 0.6))', "expsyn") + return c + +# Test recipe lif2 comprises two independent LIF cells driven by a regular, rapid +# sequence of incoming spikes. The cells have differing refactory periods. + +class lif2_recipe(A.recipe): + def __init__(self): + A.recipe.__init__(self) + + def num_cells(self): + return 2 + + def num_targets(self, gid): + return 0 + + def num_sources(self, gid): + return 0 + + def cell_kind(self, gid): + return A.cell_kind.lif + + def connections_on(self, gid): + return [] + + def event_generators(self, gid): + sched_dt = 0.25 + weight = 400 + return [A.event_generator((gid,0), weight, A.regular_schedule(sched_dt)) for gid in range(0, self.num_cells())] + + def get_probes(self, gid): + return [] + + def cell_description(self, gid): + c = A.lif_cell() + if gid==0: + c.t_ref = 2 + if gid==1: + c.t_ref = 4 + return c + +class Simulator(unittest.TestCase): + def init_sim(self, recipe): + context = A.context() + dd = A.partition_load_balance(recipe, context) + return A.simulation(recipe, dd, context) + + def test_simple_run(self): + sim = self.init_sim(cc2_recipe()) + sim.run(1.0, 0.01) + + def test_probe_meta(self): + sim = self.init_sim(cc2_recipe()) + + self.assertEqual([A.location(1, 1)], sim.probe_metadata((0, 0))) + self.assertEqual([A.location(2, 1)], sim.probe_metadata((0, 1))) + self.assertEqual([A.location(1, 1), A.location(2, 1)], sorted(sim.probe_metadata((0, 2)), key=lambda x:(x.branch, x.pos))) + + # Default CV policy is one per branch, which also gives a tivial CV over the branch point. + # Expect metadata cables to be one for each full branch, plus three length-zero cables corresponding to the branch point. + self.assertEqual([A.cable(0, 0, 1), A.cable(0, 1, 1), A.cable(1, 0, 0), A.cable(1, 0, 1), A.cable(2, 0, 0), A.cable(2, 0, 1)], + sorted(sim.probe_metadata((1,0))[0], key=lambda x:(x.branch, x.prox, x.dist))) + + # Four expsyn synapses; the two on branch zero should be coalesced, giving a multiplicity of 2. + # Expect entries to be in target index order. + m11 = sim.probe_metadata((1,1))[0] + self.assertEqual(4, len(m11)) + self.assertEqual([0, 1, 2, 3], [x.target for x in m11]) + self.assertEqual([2, 2, 1, 1], [x.multiplicity for x in m11]) + self.assertEqual([A.location(0, 0.3), A.location(0, 0.6), A.location(1, 0.3), A.location(2, 0.3)], [x.location for x in m11]) + + def test_probe_scalar_recorders(self): + sim = self.init_sim(cc2_recipe()) + ts = [0, 0.1, 0.3, 0.7] + h = sim.sample((0, 0), A.explicit_schedule(ts)) + dt = 0.01 + sim.run(10., dt) + s, meta = sim.samples(h)[0] + self.assertEqual(A.location(1, 1), meta) + for i, t in enumerate(s[:,0]): + self.assertLess(abs(t-ts[i]), dt) + + sim.remove_sampler(h) + sim.reset() + h = sim.sample(A.cell_member(0, 0), A.explicit_schedule(ts), A.sampling_policy.exact) + sim.run(10., dt) + s, meta = sim.samples(h)[0] + for i, t in enumerate(s[:,0]): + self.assertEqual(t, ts[i]) + + + def test_probe_multi_scalar_recorders(self): + sim = self.init_sim(cc2_recipe()) + ts = [0, 0.1, 0.3, 0.7] + h0 = sim.sample((0, 0), A.explicit_schedule(ts)) + h1 = sim.sample((0, 1), A.explicit_schedule(ts)) + h2 = sim.sample((0, 2), A.explicit_schedule(ts)) + + dt = 0.01 + sim.run(10., dt) + + r0 = sim.samples(h0) + self.assertEqual(1, len(r0)) + s0, meta0 = r0[0] + + r1 = sim.samples(h1) + self.assertEqual(1, len(r1)) + s1, meta1 = r1[0] + + r2 = sim.samples(h2) + self.assertEqual(2, len(r2)) + s20, meta20 = r2[0] + s21, meta21 = r2[1] + + # Probe id (0, 2) has probes over the two locations that correspond to probes (0, 0) and (0, 1). + + # (order is not guaranteed to line up though) + if meta20==meta0: + self.assertEqual(meta1, meta21) + self.assertTrue((s0[:,1]==s20[:,1]).all()) + self.assertTrue((s1[:,1]==s21[:,1]).all()) + else: + self.assertEqual(meta1, meta20) + self.assertTrue((s1[:,1]==s20[:,1]).all()) + self.assertEqual(meta0, meta21) + self.assertTrue((s0[:,1]==s21[:,1]).all()) + + def test_probe_vector_recorders(self): + sim = self.init_sim(cc2_recipe()) + ts = [0, 0.1, 0.3, 0.7] + h0 = sim.sample((1, 0), A.explicit_schedule(ts), A.sampling_policy.exact) + h1 = sim.sample((1, 1), A.explicit_schedule(ts), A.sampling_policy.exact) + sim.run(10., 0.01) + + # probe (1, 0) is the whole cell voltage; expect time + 6 sample values per row in returned data (see test_probe_meta above). + + s0, meta0 = sim.samples(h0)[0] + self.assertEqual(6, len(meta0)) + self.assertEqual((len(ts), 7), s0.shape) + for i, t in enumerate(s0[:,0]): + self.assertEqual(t, ts[i]) + + # probe (1, 1) is the 'g' state for all expsyn synapses. + # With the default descretization, expect two synapses with multiplicity 2 and two with multiplicity 1. + + s1, meta1 = sim.samples(h1)[0] + self.assertEqual(4, len(meta1)) + self.assertEqual((len(ts), 5), s1.shape) + for i, t in enumerate(s1[:,0]): + self.assertEqual(t, ts[i]) + + meta1_mult = {(m.location.branch, m.location.pos): m.multiplicity for m in meta1} + self.assertEqual(2, meta1_mult[(0, 0.3)]) + self.assertEqual(2, meta1_mult[(0, 0.6)]) + self.assertEqual(1, meta1_mult[(1, 0.3)]) + self.assertEqual(1, meta1_mult[(2, 0.3)]) + + def test_spikes(self): + sim = self.init_sim(lif2_recipe()) + sim.record(A.spike_recording.all) + sim.run(21, 0.01) + + spikes = sim.spikes().tolist() + s0 = sorted([t for s, t in spikes if s==(0, 0)]) + s1 = sorted([t for s, t in spikes if s==(1, 0)]) + + self.assertEqual([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20], s0) + self.assertEqual([0, 4, 8, 12, 16, 20], s1) + +def suite(): + # specify class and test functions in tuple (here: all tests starting with 'test' from class Contexts + suite = unittest.makeSuite(Simulator, ('test')) + return suite + +def run(): + v = options.parse_arguments().verbosity + runner = unittest.TextTestRunner(verbosity = v) + runner.run(suite()) + +if __name__ == "__main__": + run() diff --git a/python/test/unit_distributed/runner.py b/python/test/unit_distributed/runner.py index b887802bb8c32df75bbebbe00ce72376983372ed..57d36da433559aa3c02c9cd65b7d6f28e34f13ec 100644 --- a/python/test/unit_distributed/runner.py +++ b/python/test/unit_distributed/runner.py @@ -27,12 +27,14 @@ except ModuleNotFoundError: from test.unit_distributed import test_contexts_arbmpi from test.unit_distributed import test_contexts_mpi4py from test.unit_distributed import test_domain_decompositions + from test.unit_distributed import test_simulator # add more if needed test_modules = [\ test_contexts_arbmpi,\ test_contexts_mpi4py,\ - test_domain_decompositions\ + test_domain_decompositions,\ + test_simulator\ ] # add more if needed def suite(): diff --git a/python/test/unit_distributed/test_domain_decompositions.py b/python/test/unit_distributed/test_domain_decompositions.py index f5a3a0dfdac2fbf627dcaaf94e4945d17c2fdb69..ad6c5fda1e5c22f27ceef9f16554a9fc2a9a4844 100644 --- a/python/test/unit_distributed/test_domain_decompositions.py +++ b/python/test/unit_distributed/test_domain_decompositions.py @@ -443,7 +443,7 @@ def run(): arb.mpi_init() comm = arb.mpi_comm() - + alloc = arb.proc_allocation() ctx = arb.context(alloc, comm) rank = ctx.rank diff --git a/python/test/unit_distributed/test_simulator.py b/python/test/unit_distributed/test_simulator.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0da5516df4d981c2984d8f644f6096ab493248 --- /dev/null +++ b/python/test/unit_distributed/test_simulator.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +# +# test_simulator.py + +import unittest +import numpy as np +import arbor as A + +# to be able to run .py file from child directory +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +try: + import options +except ModuleNotFoundError: + from test import options + +mpi_enabled = A.__config__["mpi"] + +""" +test for MPI distribution of spike recording +""" + +class lifN_recipe(A.recipe): + def __init__(self, n_cell): + A.recipe.__init__(self) + self.n_cell = n_cell + + def num_cells(self): + return self.n_cell + + def num_targets(self, gid): + return 0 + + def num_sources(self, gid): + return 0 + + def cell_kind(self, gid): + return A.cell_kind.lif + + def connections_on(self, gid): + return [] + + def event_generators(self, gid): + sched_dt = 0.25 + weight = 400 + return [A.event_generator((gid,0), weight, A.regular_schedule(sched_dt)) for gid in range(0, self.num_cells())] + + def get_probes(self, gid): + return [] + + def cell_description(self, gid): + c = A.lif_cell() + if gid%2==0: + c.t_ref = 2 + else: + c.t_ref = 4 + return c + +@unittest.skipIf(mpi_enabled == False, "MPI not enabled") +class Simulator(unittest.TestCase): + def init_sim(self): + comm = A.mpi_comm() + context = A.context(threads=1, gpu_id=None, mpi=A.mpi_comm()) + self.rank = context.rank + self.ranks = context.ranks + + recipe = lifN_recipe(context.ranks) + dd = A.partition_load_balance(recipe, context) + + # Confirm decomposition has gid 0 on rank 0, ..., gid N-1 on rank N-1. + self.assertEqual(1, dd.num_local_cells) + local_groups = dd.groups + self.assertEqual(1, len(local_groups)) + self.assertEqual([self.rank], local_groups[0].gids) + + return A.simulation(recipe, dd, context) + + def test_local_spikes(self): + sim = self.init_sim() + sim.record(A.spike_recording.local) + sim.run(9, 0.01) + spikes = sim.spikes().tolist() + + # Everything should come from the one cell, gid == rank. + self.assertEqual({(self.rank, 0)}, {s for s, t in spikes}) + + times = sorted([t for s, t in spikes]) + if self.rank%2==0: + self.assertEqual([0, 2, 4, 6, 8], times) + else: + self.assertEqual([0, 4, 8], times) + + def test_global_spikes(self): + sim = self.init_sim() + sim.record(A.spike_recording.all) + sim.run(9, 0.01) + spikes = sim.spikes().tolist() + + expected = [((s, 0), t) for s in range(0, self.ranks) for t in ([0, 2, 4, 6, 8] if s%2==0 else [0, 4, 8])] + self.assertEqual(expected, sorted(spikes)) + + +def suite(): + # specify class and test functions in tuple (here: all tests starting with 'test' from class Contexts + suite = unittest.makeSuite(Simulator, ('test')) + return suite + +def run(): + v = options.parse_arguments().verbosity + + if not A.mpi_is_initialized(): + A.mpi_init() + + comm = A.mpi_comm() + alloc = A.proc_allocation() + ctx = A.context(alloc, comm) + rank = ctx.rank + + if rank == 0: + runner = unittest.TextTestRunner(verbosity = v) + else: + sys.stdout = open(os.devnull, 'w') + runner = unittest.TextTestRunner(stream=sys.stdout) + + runner.run(suite()) + + if not A.mpi_is_finalized(): + A.mpi_finalize() + +if __name__ == "__main__": + run() diff --git a/scripts/travis/build.sh b/scripts/travis/build.sh index 8875dfca32364e16e55d5bff756bb0036a14f1aa..03111283e361e6aa64ae10af6b0671b90aab0820 100755 --- a/scripts/travis/build.sh +++ b/scripts/travis/build.sh @@ -113,6 +113,7 @@ if [[ "${WITH_PYTHON}" == "true" ]]; then progress "Python examples" python$PY $python_path/example/network_ring.py || error "running python network_ring example" python$PY $python_path/example/single_cell_model.py || error "running python single_cell_model example" + python$PY $python_path/example/single_cell_recipe.py || error "running python single_cell_recipe example" python$PY $python_path/example/single_cell_multi_branch.py || error "running python single_cell_multi_branch example" python$PY $python_path/example/single_cell_swc.py $base_path/test/unit/swc/pyramidal.swc || error "running python single_cell_swc example" if [[ "${WITH_DISTRIBUTED}" = "mpi" ]]; then