diff --git a/arbor/include/arbor/domain_decomposition.hpp b/arbor/include/arbor/domain_decomposition.hpp index ebc7e34bb3dbabf5a134664c3832ca28eca1a51f..43aa6415a27e19de3663eb2b5a76463a5fb54240 100644 --- a/arbor/include/arbor/domain_decomposition.hpp +++ b/arbor/include/arbor/domain_decomposition.hpp @@ -15,7 +15,7 @@ struct group_description { /// The kind of cell in the group. All cells in a cell_group have the same type. cell_kind kind; - /// The gids of the cells in the cell_group, sorted in ascending order. + /// The gids of the cells in the cell_group. Does not need to be sorted. std::vector<cell_gid_type> gids; /// The back end on which the cell_group is to run. diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index 5521aa12e0294f0b45a8960b13b12115ad4b6de2..f5910c908c4cefde96ea48378617ca501e0aedf4 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -40,9 +40,9 @@ struct cell_connection { cell_connection_endpoint dest; float weight; - float delay; + time_type delay; - cell_connection(cell_connection_endpoint src, cell_connection_endpoint dst, float w, float d): + cell_connection(cell_connection_endpoint src, cell_connection_endpoint dst, float w, time_type d): source(src), dest(dst), weight(w), delay(d) {} }; diff --git a/doc/cpp_domdec.rst b/doc/cpp_domdec.rst index 95a3462a732b7c9889c9b850f7a26b75c012790f..78360c8423d24cbb145cb35321605ee86b530184 100644 --- a/doc/cpp_domdec.rst +++ b/doc/cpp_domdec.rst @@ -154,7 +154,7 @@ Documentation for the data structures used to describe domain decompositions. .. cpp:member:: const std::vector<cell_gid_type> gids - The gids of the cells in the cell group, **sorted in ascending order**. + The gids of the cells in the cell group. .. cpp:member:: const backend_kind backend diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 3f4611f76a3ff5ac4df500d4dc85e8f960bcfe05..234597ba6116061e9a3c58e20a15b40190c832df 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -18,6 +18,7 @@ add_subdirectory(pybind11) add_library(pyarb MODULE config.cpp context.cpp + domain_decomposition.cpp event_generator.cpp identifiers.cpp mpi.cpp diff --git a/python/config.cpp b/python/config.cpp index 29636c8d525012e89f8eedb70e0ebf5a4d52f5cc..d011d26a949d6d58d65592eb38d5424e8fcd02dd 100644 --- a/python/config.cpp +++ b/python/config.cpp @@ -2,11 +2,11 @@ #include <ios> #include <sstream> -#include <arbor/version.hpp> - #include <pybind11/pybind11.h> #include <pybind11/stl.h> +#include <arbor/version.hpp> + namespace pyarb { // Returns a python dictionary that python users can use to look up diff --git a/python/context.cpp b/python/context.cpp index 19eb59efcd23ba26eaf05c2fdcee34db3be946ec..f871d3d5d3884f5f701fa51ad85f420371bf16f7 100644 --- a/python/context.cpp +++ b/python/context.cpp @@ -2,12 +2,12 @@ #include <sstream> #include <string> +#include <pybind11/pybind11.h> + #include <arbor/context.hpp> #include <arbor/version.hpp> #include <arbor/util/optional.hpp> -#include <pybind11/pybind11.h> - #include "context.hpp" #include "conversion.hpp" #include "error.hpp" @@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream& o, const context_shim& ctx) { o << "<context: threads " << arb::num_threads(c) << ", gpu " << (gpu? "yes": "no") << ", mpi " << (mpi? "yes": "no") - << " ranks " << arb::num_ranks(c) + << ", ranks " << arb::num_ranks(c) << ">"; } diff --git a/python/domain_decomposition.cpp b/python/domain_decomposition.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8a52275f3e1255b66650b7edde15061e78630f1f --- /dev/null +++ b/python/domain_decomposition.cpp @@ -0,0 +1,92 @@ +#include <string> +#include <sstream> + +#include <pybind11/pybind11.h> + +#include <arbor/context.hpp> +#include <arbor/domain_decomposition.hpp> +#include <arbor/load_balance.hpp> + +#include "context.hpp" +#include "recipe.hpp" +#include "strprintf.hpp" + +namespace pyarb { + +std::string gd_string(const arb::group_description& g) { + std::stringstream s; + s << "<cell group: " << g.gids.size() + << " cells, gids [" << util::csv(g.gids, 5) << "]" + << ", " << g.kind << ", " << g.backend << ">"; + return s.str(); +} + +std::string dd_string(const arb::domain_decomposition& d) { + std::stringstream s; + s << "<domain decomposition: domain " + << d.domain_id << " of " + << d.num_domains << ", " + << d.num_local_cells << "/" << d.num_global_cells << " loc/glob cells, " + << d.groups.size() << " groups>"; + + return s.str(); +} + +void register_domain_decomposition(pybind11::module& m) { + using namespace pybind11::literals; + + // Group description + pybind11::class_<arb::group_description> group_description(m, "group_description", + "The indexes of a set of cells of the same kind that are grouped together in a cell group."); + group_description + .def(pybind11::init<arb::cell_kind, std::vector<arb::cell_gid_type>, arb::backend_kind>(), + "Construct a group description with cell kind, list of gids, and backend kind.", + "kind"_a, "gids"_a, "backend"_a) + .def_readonly("kind", &arb::group_description::kind, + "The type of cell in the cell group.") + .def_readonly("gids", &arb::group_description::gids, + "The gids of the cells in the group in ascending order.") + .def_readonly("backend", &arb::group_description::backend, + "The hardware backend on which the cell group will run.") + .def("__str__", &gd_string) + .def("__repr__", &gd_string); + + // Domain decomposition + pybind11::class_<arb::domain_decomposition> domain_decomposition(m, "domain_decomposition", + "The domain decomposition is responsible for describing the distribution of cells across cell groups and domains."); + domain_decomposition + .def(pybind11::init<>()) + .def("gid_domain", + [](const arb::domain_decomposition& d, arb::cell_gid_type gid) { + return d.gid_domain(gid); + }, + "Query the domain id that a cell assigned to (using global identifier gid).", + "gid"_a) + .def_readonly("num_domains", &arb::domain_decomposition::num_domains, + "Number of domains that the model is distributed over.") + .def_readonly("domain_id", &arb::domain_decomposition::domain_id, + "The index of the local domain.\n" + "Always 0 for non-distributed models, and corresponds to the MPI rank for distributed runs.") + .def_readonly("num_local_cells", &arb::domain_decomposition::num_local_cells, + "Total number of cells in the local domain.") + .def_readonly("num_global_cells", &arb::domain_decomposition::num_global_cells, + "Total number of cells in the global model (sum of num_local_cells over all domains).") + .def_readonly("groups", &arb::domain_decomposition::groups, + "Descriptions of the cell groups on the local domain.") + .def("__str__", &dd_string) + .def("__repr__", &dd_string); + + // Partition load balancer + // The Python recipe has to be shimmed for passing to the function that + // takes a C++ recipe. + m.def("partition_load_balance", + [](std::shared_ptr<py_recipe>& recipe, const context_shim& ctx) { + return arb::partition_load_balance(py_recipe_shim(recipe), ctx.context); + }, + "Construct a domain_decomposition that distributes the cells in the model described by recipe\n" + "over the distributed and local hardware resources described by context.", + "recipe"_a, "context"_a); +} + +} // namespace pyarb + diff --git a/python/event_generator.cpp b/python/event_generator.cpp index 96505ba9505d8c7ee2c8f09658e604021869b9c0..ceb5ec53dc54dc60e1ba52e8c84cde2d96675f44 100644 --- a/python/event_generator.cpp +++ b/python/event_generator.cpp @@ -2,14 +2,14 @@ #include <sstream> #include <string> -#include <arbor/common_types.hpp> -#include <arbor/schedule.hpp> -#include <arbor/util/optional.hpp> - #include <pybind11/pybind11.h> #include <pybind11/pytypes.h> #include <pybind11/stl.h> +#include <arbor/common_types.hpp> +#include <arbor/schedule.hpp> +#include <arbor/util/optional.hpp> + #include "conversion.hpp" #include "error.hpp" #include "event_generator.hpp" diff --git a/python/identifiers.cpp b/python/identifiers.cpp index 78824eff83f07e1d46bc5c614917e6d9cbda89ba..e6dda467964962b796bc82c386a8933ef641e503 100644 --- a/python/identifiers.cpp +++ b/python/identifiers.cpp @@ -1,10 +1,10 @@ #include <ostream> #include <string> -#include <arbor/common_types.hpp> - #include <pybind11/pybind11.h> +#include <arbor/common_types.hpp> + #include "strprintf.hpp" namespace pyarb { @@ -52,6 +52,13 @@ void register_identifiers(pybind11::module& m) { "Leaky-integrate and fire neuron.") .value("spike_source", arb::cell_kind::spike_source, "Proxy cell that generates spikes from a spike sequence provided by the user."); + + pybind11::enum_<arb::backend_kind>(m, "backend_kind", + "Enumeration used to indicate which hardware backend to use for running a cell_group.") + .value("gpu", arb::backend_kind::gpu, + "Use GPU backend.") + .value("multicore", arb::backend_kind::multicore, + "Use multicore backend."); } } // namespace pyarb diff --git a/python/pyarb.cpp b/python/pyarb.cpp index 965d6406fccbe7d5af713f5519b46204c96b4ff1..f037db39886b83fff0254e203d5077302d983bf0 100644 --- a/python/pyarb.cpp +++ b/python/pyarb.cpp @@ -1,13 +1,14 @@ -#include <arbor/version.hpp> - #include <pybind11/pybind11.h> +#include <arbor/version.hpp> + // Forward declarations of functions used to register API // types and functions to be exposed to Python. namespace pyarb { void register_config(pybind11::module& m); void register_contexts(pybind11::module& m); +void register_domain_decomposition(pybind11::module& m); void register_event_generators(pybind11::module& m); void register_identifiers(pybind11::module& m); void register_recipe(pybind11::module& m); @@ -24,6 +25,7 @@ PYBIND11_MODULE(arbor, m) { pyarb::register_config(m); pyarb::register_contexts(m); + pyarb::register_domain_decomposition(m); pyarb::register_event_generators(m); pyarb::register_identifiers(m); #ifdef ARB_MPI_ENABLED diff --git a/python/recipe.cpp b/python/recipe.cpp index 413c00be8727676cdfc71fc1901abd6acb40d079..b12ad0165b6da16b83d1bd560fd183f0deba4300 100644 --- a/python/recipe.cpp +++ b/python/recipe.cpp @@ -110,11 +110,39 @@ std::vector<arb::event_generator> py_recipe_shim::event_generators(arb::cell_gid return gens; } +// Wrap arb::cell_connection in a shim that asserts constraints on connection +// delay when the user attempts to set them in Python. +struct cell_connection_shim { + arb::cell_member_type source; + arb::cell_member_type destination; + float weight; + arb::time_type delay; + + cell_connection_shim(arb::cell_member_type src, arb::cell_member_type dst, float w, arb::time_type del) { + source = src; + destination = dst; + weight = w; + set_delay(del); + } + + // getter and setter + void set_delay(arb::time_type t) { + pyarb::assert_throw([](arb::time_type f){ return f>arb::time_type(0); }(t), "connection delay must be positive"); + delay = t; + } + + arb::time_type get_delay() const { return delay; } + + operator arb::cell_connection() const { + return arb::cell_connection(source, destination, weight, delay); + } +}; + // TODO: implement py_recipe_shim::probe_info -std::string con_to_string(const arb::cell_connection& c) { +std::string con_to_string(const cell_connection_shim& c) { return util::pprintf("<connection: ({},{}) -> ({},{}), delay {}, weight {}>", - c.source.gid, c.source.index, c.dest.gid, c.dest.index, c.delay, c.weight); + c.source.gid, c.source.index, c.destination.gid, c.destination.index, c.delay, c.weight); } std::string gj_to_string(const arb::gap_junction_connection& gc) { @@ -126,31 +154,24 @@ void register_recipe(pybind11::module& m) { using namespace pybind11::literals; // Connections - pybind11::class_<arb::cell_connection> cell_connection(m, "cell_connection", + pybind11::class_<cell_connection_shim> cell_connection(m, "cell_connection", "Describes a connection between two cells:\n" - "a pre-synaptic source and a post-synaptic destination."); + " Defined by source and destination end points (that is pre-synaptic and post-synaptic respectively), a connection weight and a delay time."); cell_connection - .def(pybind11::init<>( - [](){return arb::cell_connection({0u,0u}, {0u,0u}, 0.f, 0.f);}), - "Construct a connection with default arguments:\n" - " source: gid 0, index 0.\n" - " destination: gid 0, index 0.\n" - " weight: 0.\n" - " delay: 0 ms.\n") - .def(pybind11::init<arb::cell_member_type, arb::cell_member_type, float, float>(), - "source"_a, "destination"_a, "weight"_a = 0.f, "delay"_a = 0.f, + .def(pybind11::init<arb::cell_member_type, arb::cell_member_type, float, arb::time_type>(), + "source"_a = arb::cell_member_type{0,0}, "dest"_a = arb::cell_member_type{0,0}, "weight"_a = 0.f, "delay"_a, "Construct a connection with arguments:\n" - " source: The source end point of the connection.\n" - " destination: The destination end point of the connection.\n" + " source: The source end point of the connection (default (0,0)).\n" + " dest: The destination end point of the connection (default (0,0)).\n" " weight: The weight delivered to the target synapse (dimensionless with interpretation specific to synapse type of target, default 0.).\n" - " delay: The delay of the connection (unit: ms, default 0.).\n") - .def_readwrite("source", &arb::cell_connection::source, + " delay: The delay of the connection (unit: ms).") + .def_readwrite("source", &cell_connection_shim::source, "The source of the connection.") - .def_readwrite("destination", &arb::cell_connection::dest, + .def_readwrite("dest", &cell_connection_shim::destination, "The destination of the connection.") - .def_readwrite("weight", &arb::cell_connection::weight, - "The weight of the connection (unit: Sâ‹…cmâ»Â²).") - .def_readwrite("delay", &arb::cell_connection::delay, + .def_readwrite("weight", &cell_connection_shim::weight, + "The weight of the connection.") + .def_property("delay", &cell_connection_shim::get_delay, &cell_connection_shim::set_delay, "The delay time of the connection (unit: ms).") .def("__str__", &con_to_string) .def("__repr__", &con_to_string); @@ -159,18 +180,12 @@ void register_recipe(pybind11::module& m) { pybind11::class_<arb::gap_junction_connection> gap_junction_connection(m, "gap_junction_connection", "Describes a gap junction between two gap junction sites."); gap_junction_connection - .def(pybind11::init<>( - [](){return arb::gap_junction_connection({0u,0u}, {0u,0u}, 0.f);}), - "Construct a gap junction connection with default arguments:\n" - " local: gid 0, index 0.\n" - " peer: gid 0, index 0.\n" - " ggap: 0 μS.\n") .def(pybind11::init<arb::cell_member_type, arb::cell_member_type, double>(), - "local"_a, "peer"_a, "ggap"_a = 0.f, + "local"_a = arb::cell_member_type{0,0}, "peer"_a = arb::cell_member_type{0,0}, "ggap"_a = 0.f, "Construct a gap junction connection with arguments:\n" - " local: One half of the gap junction connection.\n" - " peer: Other half of the gap junction connection.\n" - " ggap: Gap junction conductance (unit: μS, default 0.).\n") + " local: One half of the gap junction connection (default (0,0)).\n" + " peer: Other half of the gap junction connection (default (0,0)).\n" + " ggap: Gap junction conductance (unit: μS, default 0.).") .def_readwrite("local", &arb::gap_junction_connection::local, "One half of the gap junction connection.") .def_readwrite("peer", &arb::gap_junction_connection::peer, @@ -218,7 +233,7 @@ void register_recipe(pybind11::module& m) { .def("global_properties", &py_recipe::global_properties, pybind11::return_value_policy::copy, "cell_kind"_a, "Global property type specific to a given cell kind.") - .def("__str__", [](const py_recipe&){return "<pyarb.recipe>";}) - .def("__repr__", [](const py_recipe&){return "<pyarb.recipe>";}); + .def("__str__", [](const py_recipe&){return "<arbor.recipe>";}) + .def("__repr__", [](const py_recipe&){return "<arbor.recipe>";}); } } // namespace pyarb diff --git a/python/recipe.hpp b/python/recipe.hpp index 4e83c276d37eedb497be973871c0b1d25cb3c168..6057fb6d0227650c2f3f58097dff40bbd64c9ea2 100644 --- a/python/recipe.hpp +++ b/python/recipe.hpp @@ -65,7 +65,7 @@ public: arb::cell_kind cell_kind(arb::cell_gid_type gid) const override { PYBIND11_OVERLOAD_PURE(arb::cell_kind, py_recipe, cell_kind, gid); } - + arb::cell_size_type num_sources(arb::cell_gid_type gid) const override { PYBIND11_OVERLOAD(arb::cell_size_type, py_recipe, num_sources, gid); } diff --git a/python/schedule.cpp b/python/schedule.cpp index f295b85c65c1c76ebcf90197db38e9b0ec479c0e..d2b8449c0080aea1aa3497a7f7df080c35865e21 100644 --- a/python/schedule.cpp +++ b/python/schedule.cpp @@ -19,7 +19,7 @@ std::ostream& operator<<(std::ostream& o, const regular_schedule_shim& x) { std::ostream& operator<<(std::ostream& o, const explicit_schedule_shim& e) { o << "<explicit_schedule: times ["; - return util::csv(o, e.times) << "] ms>"; + return o << util::csv(e.times) << "] ms>"; }; std::ostream& operator<<(std::ostream& o, const poisson_schedule_shim& p) { @@ -44,12 +44,12 @@ regular_schedule_shim::regular_schedule_shim( void regular_schedule_shim::set_tstart(pybind11::object t) { tstart = py2optional<time_type>( - t, "tstart must a non-negative number, or None", is_nonneg()); + t, "tstart must be a non-negative number, or None", is_nonneg()); }; void regular_schedule_shim::set_tstop(pybind11::object t) { tstop = py2optional<time_type>( - t, "tstop must a non-negative number, or None", is_nonneg()); + t, "tstop must be a non-negative number, or None", is_nonneg()); }; void regular_schedule_shim::set_dt(arb::time_type delta_t) { @@ -61,7 +61,7 @@ regular_schedule_shim::opt_time_type regular_schedule_shim::get_tstart() const { return tstart; } -regular_schedule_shim::opt_time_type regular_schedule_shim::get_dt() const { +regular_schedule_shim::time_type regular_schedule_shim::get_dt() const { return dt; } @@ -97,7 +97,7 @@ void explicit_schedule_shim::set_times(std::vector<arb::time_type> t) { // Assert that there are no negative times if (times.size()) { pyarb::assert_throw(is_nonneg()(times[0]), - "explicit time schedule can not contain negative values"); + "explicit time schedule cannot contain negative values"); } }; diff --git a/python/schedule.hpp b/python/schedule.hpp index 9f43ccc189f9a4481d1e706725647eb7ca10b650..32b16c71fed3d4d6a50fb6242fd600ad87c240be 100644 --- a/python/schedule.hpp +++ b/python/schedule.hpp @@ -1,11 +1,11 @@ #pragma once +#include <pybind11/pybind11.h> + #include <arbor/schedule.hpp> #include <arbor/common_types.hpp> #include <arbor/util/optional.hpp> -#include <pybind11/pybind11.h> - namespace pyarb { // A Python shim that holds the information that describes an @@ -30,7 +30,7 @@ struct regular_schedule_shim { void set_dt(time_type delta_t); opt_time_type get_tstart() const; - opt_time_type get_dt() const; + time_type get_dt() const; opt_time_type get_tstop() const; arb::schedule schedule() const; diff --git a/python/strprintf.hpp b/python/strprintf.hpp index 431f83d2de6ae0010ebe0a85693382589022cfdd..e3b9dd49914318627f96c7f07c301df7f9b3fbbf 100644 --- a/python/strprintf.hpp +++ b/python/strprintf.hpp @@ -146,8 +146,8 @@ namespace impl { for (auto& x: s.seq_) { if (!first) { o << s.sep_; - first = false; } + first = false; if (!n) { return o << "..."; } @@ -160,23 +160,23 @@ namespace impl { } template <typename Seq> -std::ostream& sepval(std::ostream& o, const char* sep, const Seq& seq) { - return o << impl::sepval<Seq>(seq, sep); +impl::sepval<Seq> sepval(const char* sep, const Seq& seq) { + return impl::sepval<Seq>(seq, sep); } template <typename Seq> -std::ostream& sepval(std::ostream& o, const char* sep, const Seq& seq, unsigned n) { - return o << impl::sepval_lim<Seq>(seq, sep, n); +impl::sepval_lim<Seq> sepval(const char* sep, const Seq& seq, unsigned n) { + return impl::sepval_lim<Seq>(seq, sep, n); } template <typename Seq> -std::ostream& csv(std::ostream& o, const Seq& seq) { - return o << impl::sepval<Seq>(seq, ", "); +impl::sepval<Seq> csv(const Seq& seq) { + return impl::sepval<Seq>(seq, ", "); } template <typename Seq> -std::ostream& csv(std::ostream& o, const Seq& seq, unsigned n) { - return o << impl::sepval_lim<Seq>(seq, ", ", n); +impl::sepval_lim<Seq> csv(const Seq& seq, unsigned n) { + return impl::sepval_lim<Seq>(seq, ", ", n); } } // namespace util diff --git a/python/test/unit/runner.py b/python/test/unit/runner.py index 1ea70555e0505cfbbbab34e692fb6eae42176890..2833e3816d8e00066beeb608d9eb141149e530ff 100644 --- a/python/test/unit/runner.py +++ b/python/test/unit/runner.py @@ -11,20 +11,24 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../. try: import options import test_contexts - import test_identifiers import test_event_generators + import test_identifiers + import test_tests + import test_schedules # add more if needed except ModuleNotFoundError: from test import options from test.unit import test_contexts - from test.unit import test_identifiers from test.unit import test_event_generators + from test.unit import test_identifiers + from test.unit import test_schedules # add more if needed test_modules = [\ test_contexts,\ + test_event_generators,\ test_identifiers,\ - test_event_generators\ + test_schedules\ ] # add more if needed def suite(): diff --git a/python/test/unit/test_event_generators.py b/python/test/unit/test_event_generators.py index b64119c8bcd8670cb6ea43a3fe9ab685a8e621a2..527564da47dc2a76f8c2c3b56ca275c2d51c32f4 100644 --- a/python/test/unit/test_event_generators.py +++ b/python/test/unit/test_event_generators.py @@ -3,7 +3,6 @@ # test_event_generators.py import unittest -import numpy as np import arbor as arb @@ -20,24 +19,7 @@ except ModuleNotFoundError: all tests for event generators (regular, explicit, poisson) """ -class RegularSchedule(unittest.TestCase): - def test_none_contor_regular_schedule(self): - rs = arb.regular_schedule(tstart=None, tstop=None) - - def test_tstart_dt_tstop_contor_regular_schedule(self): - rs = arb.regular_schedule(10., 1., 20.) - self.assertEqual(rs.tstart, 10.) - self.assertEqual(rs.dt, 1.) - self.assertEqual(rs.tstop, 20.) - - def test_set_tstart_dt_tstop_regular_schedule(self): - rs = arb.regular_schedule() - rs.tstart = 17. - rs.dt = 0.5 - rs.tstop = 42. - self.assertEqual(rs.tstart, 17.) - self.assertAlmostEqual(rs.dt, 0.5) - self.assertEqual(rs.tstop, 42.) +class EventGenerator(unittest.TestCase): def test_event_generator_regular_schedule(self): cm = arb.cell_member() @@ -49,31 +31,6 @@ class RegularSchedule(unittest.TestCase): self.assertEqual(rg.target.index, 3) self.assertAlmostEqual(rg.weight, 3.14) - def test_exceptions_regular_schedule(self): - with self.assertRaisesRegex(RuntimeError, - "tstart must a non-negative number, or None"): - arb.regular_schedule(tstart = -1.) - with self.assertRaisesRegex(RuntimeError, - "dt must be a non-negative number"): - arb.regular_schedule(dt = -0.1) - with self.assertRaises(TypeError): - arb.regular_schedule(dt = None) - with self.assertRaises(TypeError): - arb.regular_schedule(dt = 'dt') - with self.assertRaisesRegex(RuntimeError, - "tstop must a non-negative number, or None"): - arb.regular_schedule(tstop = 'tstop') - -class ExplicitSchedule(unittest.TestCase): - def test_times_contor_explicit_schedule(self): - es = arb.explicit_schedule([1, 2, 3, 4.5]) - self.assertEqual(es.times, [1, 2, 3, 4.5]) - - def test_set_times_explicit_schedule(self): - es = arb.explicit_schedule() - es.times = [42, 43, 44, 55.5, 100] - self.assertEqual(es.times, [42, 43, 44, 55.5, 100]) - def test_event_generator_explicit_schedule(self): cm = arb.cell_member() cm.gid = 0 @@ -84,38 +41,6 @@ class ExplicitSchedule(unittest.TestCase): self.assertEqual(eg.target.index, 42) self.assertAlmostEqual(eg.weight, -0.01) - def test_exceptions_explicit_schedule(self): - with self.assertRaisesRegex(RuntimeError, - "explicit time schedule can not contain negative values"): - arb.explicit_schedule([-1]) - with self.assertRaises(TypeError): - arb.explicit_schedule(['times']) - with self.assertRaises(TypeError): - arb.explicit_schedule([None]) - with self.assertRaises(TypeError): - arb.explicit_schedule([[1,2,3]]) - -class PoissonSchedule(unittest.TestCase): - def test_freq_seed_contor_poisson_schedule(self): - ps = arb.poisson_schedule(freq = 5., seed = 42) - self.assertEqual(ps.freq, 5.) - self.assertEqual(ps.seed, 42) - - def test_tstart_freq_seed_contor_poisson_schedule(self): - ps = arb.poisson_schedule(10., 100., 1000) - self.assertEqual(ps.tstart, 10.) - self.assertEqual(ps.freq, 100.) - self.assertEqual(ps.seed, 1000) - - def test_set_tstart_freq_seed_poisson_schedule(self): - ps = arb.poisson_schedule() - ps.tstart = 4.5 - ps.freq = 5.5 - ps.seed = 83 - self.assertAlmostEqual(ps.tstart, 4.5) - self.assertAlmostEqual(ps.freq, 5.5) - self.assertEqual(ps.seed, 83) - def test_event_generator_poisson_schedule(self): cm = arb.cell_member() cm.gid = 4 @@ -126,34 +51,9 @@ class PoissonSchedule(unittest.TestCase): self.assertEqual(pg.target.index, 2) self.assertEqual(pg.weight, 42.) - def test_exceptions_poisson_schedule(self): - with self.assertRaisesRegex(RuntimeError, - "tstart must be a non-negative number"): - arb.poisson_schedule(tstart = -10.) - with self.assertRaises(TypeError): - arb.poisson_schedule(tstart = None) - with self.assertRaises(TypeError): - arb.poisson_schedule(tstart = 'tstart') - with self.assertRaisesRegex(RuntimeError, - "frequency must be a non-negative number"): - arb.poisson_schedule(freq = -100.) - with self.assertRaises(TypeError): - arb.poisson_schedule(freq = 'freq') - with self.assertRaises(TypeError): - arb.poisson_schedule(seed = -1) - with self.assertRaises(TypeError): - arb.poisson_schedule(seed = 10.) - with self.assertRaises(TypeError): - arb.poisson_schedule(seed = 'seed') - with self.assertRaises(TypeError): - arb.poisson_schedule(seed = None) - def suite(): - # specify class and test functions in tuple (here: all tests starting with 'test' from classes RegularSchedule, ExplicitSchedule and PoissonSchedule - suite = unittest.TestSuite() - suite.addTests(unittest.makeSuite(RegularSchedule, ('test'))) - suite.addTests(unittest.makeSuite(ExplicitSchedule, ('test'))) - suite.addTests(unittest.makeSuite(PoissonSchedule, ('test'))) + # specify class and test functions in tuple (here: all tests starting with 'test' from class EventGenerator + suite = unittest.makeSuite(EventGenerator, ('test')) return suite def run(): diff --git a/python/test/unit/test_schedules.py b/python/test/unit/test_schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..655773d0ca7bc12354b59d3842b23ac996172f19 --- /dev/null +++ b/python/test/unit/test_schedules.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# +# test_schedules.py + +import unittest + +import arbor as arb + +# to be able to run .py file from child directory +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +try: + import options +except ModuleNotFoundError: + from test import options + +""" +all tests for schedules (regular, explicit, poisson) +""" + +class RegularSchedule(unittest.TestCase): + def test_none_contor_regular_schedule(self): + rs = arb.regular_schedule(tstart=None, tstop=None) + + def test_tstart_dt_tstop_contor_regular_schedule(self): + rs = arb.regular_schedule(10., 1., 20.) + self.assertEqual(rs.tstart, 10.) + self.assertEqual(rs.dt, 1.) + self.assertEqual(rs.tstop, 20.) + + def test_set_tstart_dt_tstop_regular_schedule(self): + rs = arb.regular_schedule() + rs.tstart = 17. + rs.dt = 0.5 + rs.tstop = 42. + self.assertEqual(rs.tstart, 17.) + self.assertAlmostEqual(rs.dt, 0.5) + self.assertEqual(rs.tstop, 42.) + + def test_exceptions_regular_schedule(self): + with self.assertRaisesRegex(RuntimeError, + "tstart must be a non-negative number, or None"): + arb.regular_schedule(tstart = -1.) + with self.assertRaisesRegex(RuntimeError, + "dt must be a non-negative number"): + arb.regular_schedule(dt = -0.1) + with self.assertRaises(TypeError): + arb.regular_schedule(dt = None) + with self.assertRaises(TypeError): + arb.regular_schedule(dt = 'dt') + with self.assertRaisesRegex(RuntimeError, + "tstop must be a non-negative number, or None"): + arb.regular_schedule(tstop = 'tstop') + +class ExplicitSchedule(unittest.TestCase): + def test_times_contor_explicit_schedule(self): + es = arb.explicit_schedule([1, 2, 3, 4.5]) + self.assertEqual(es.times, [1, 2, 3, 4.5]) + + def test_set_times_explicit_schedule(self): + es = arb.explicit_schedule() + es.times = [42, 43, 44, 55.5, 100] + self.assertEqual(es.times, [42, 43, 44, 55.5, 100]) + + def test_exceptions_explicit_schedule(self): + with self.assertRaisesRegex(RuntimeError, + "explicit time schedule cannot contain negative values"): + arb.explicit_schedule([-1]) + with self.assertRaises(TypeError): + arb.explicit_schedule(['times']) + with self.assertRaises(TypeError): + arb.explicit_schedule([None]) + with self.assertRaises(TypeError): + arb.explicit_schedule([[1,2,3]]) + +class PoissonSchedule(unittest.TestCase): + def test_freq_seed_contor_poisson_schedule(self): + ps = arb.poisson_schedule(freq = 5., seed = 42) + self.assertEqual(ps.freq, 5.) + self.assertEqual(ps.seed, 42) + + def test_tstart_freq_seed_contor_poisson_schedule(self): + ps = arb.poisson_schedule(10., 100., 1000) + self.assertEqual(ps.tstart, 10.) + self.assertEqual(ps.freq, 100.) + self.assertEqual(ps.seed, 1000) + + def test_set_tstart_freq_seed_poisson_schedule(self): + ps = arb.poisson_schedule() + ps.tstart = 4.5 + ps.freq = 5.5 + ps.seed = 83 + self.assertAlmostEqual(ps.tstart, 4.5) + self.assertAlmostEqual(ps.freq, 5.5) + self.assertEqual(ps.seed, 83) + + def test_exceptions_poisson_schedule(self): + with self.assertRaisesRegex(RuntimeError, + "tstart must be a non-negative number"): + arb.poisson_schedule(tstart = -10.) + with self.assertRaises(TypeError): + arb.poisson_schedule(tstart = None) + with self.assertRaises(TypeError): + arb.poisson_schedule(tstart = 'tstart') + with self.assertRaisesRegex(RuntimeError, + "frequency must be a non-negative number"): + arb.poisson_schedule(freq = -100.) + with self.assertRaises(TypeError): + arb.poisson_schedule(freq = 'freq') + with self.assertRaises(TypeError): + arb.poisson_schedule(seed = -1) + with self.assertRaises(TypeError): + arb.poisson_schedule(seed = 10.) + with self.assertRaises(TypeError): + arb.poisson_schedule(seed = 'seed') + with self.assertRaises(TypeError): + arb.poisson_schedule(seed = None) + +def suite(): + # specify class and test functions in tuple (here: all tests starting with 'test' from classes RegularSchedule, ExplicitSchedule and PoissonSchedule + suite = unittest.TestSuite() + suite.addTests(unittest.makeSuite(RegularSchedule, ('test'))) + suite.addTests(unittest.makeSuite(ExplicitSchedule, ('test'))) + suite.addTests(unittest.makeSuite(PoissonSchedule, ('test'))) + return suite + +def run(): + v = options.parse_arguments().verbosity + runner = unittest.TextTestRunner(verbosity = v) + runner.run(suite()) + +if __name__ == "__main__": + run()