diff --git a/external/modparser b/external/modparser index 588ca1a5ea28ef04d17b318e754d479e5489eb9a..b200bf6376a2dc30edea98fcc2375fc9be095135 160000 --- a/external/modparser +++ b/external/modparser @@ -1 +1 @@ -Subproject commit 588ca1a5ea28ef04d17b318e754d479e5489eb9a +Subproject commit b200bf6376a2dc30edea98fcc2375fc9be095135 diff --git a/miniapp/CMakeLists.txt b/miniapp/CMakeLists.txt index a5b9561198875a6515b4e8a69f7685ea75c19803..682adefe1d04b057cb330dc9468485b276b797b2 100644 --- a/miniapp/CMakeLists.txt +++ b/miniapp/CMakeLists.txt @@ -1,9 +1,9 @@ set(HEADERS ) set(MINIAPP_SOURCES - # mpi.cpp io.cpp miniapp.cpp + miniapp_recipes.cpp ) add_executable(miniapp.exe ${MINIAPP_SOURCES} ${HEADERS}) diff --git a/miniapp/io.cpp b/miniapp/io.cpp index 52594b6bcf872bc3eb22c0bc7376f2d98b35c060..6664cf47e4ddbf1fcd33cd4589553a29cb967255 100644 --- a/miniapp/io.cpp +++ b/miniapp/io.cpp @@ -2,6 +2,7 @@ #include <exception> #include <tclap/CmdLine.h> +#include <json/src/json.hpp> #include "io.hpp" @@ -16,7 +17,7 @@ namespace io { cl_options read_options(int argc, char** argv) { // set default options - const cl_options default_options{"", 1000, 500, "expsyn", 100, 100., 0.025, false}; + const cl_options defopts{"", 1000, 500, "expsyn", 100, 100., 0.025, false}; cl_options options; // parse command line arguments @@ -25,35 +26,27 @@ cl_options read_options(int argc, char** argv) { TCLAP::ValueArg<uint32_t> ncells_arg( "n", "ncells", "total number of cells in the model", - false, 1000, "non negative integer"); + false, defopts.cells, "non negative integer", cmd); TCLAP::ValueArg<uint32_t> nsynapses_arg( "s", "nsynapses", "number of synapses per cell", - false, 500, "non negative integer"); + false, defopts.synapses_per_cell, "non negative integer", cmd); TCLAP::ValueArg<std::string> syntype_arg( "S", "syntype", "type of synapse (expsyn or exp2syn)", - false, "expsyn", "synapse type"); + false, defopts.syn_type, "synapse type", cmd); TCLAP::ValueArg<uint32_t> ncompartments_arg( "c", "ncompartments", "number of compartments per segment", - false, 100, "non negative integer"); + false, defopts.compartments_per_segment, "non negative integer", cmd); TCLAP::ValueArg<std::string> ifile_arg( "i", "ifile", "json file with model parameters", - false, "","file name string"); + false, "","file name string", cmd); TCLAP::ValueArg<double> tfinal_arg( "t", "tfinal", "time to simulate in ms", - false, 100., "positive real number"); + false, defopts.tfinal, "positive real number", cmd); TCLAP::ValueArg<double> dt_arg( "d", "dt", "time step size in ms", - false, 0.025, "positive real number"); + false, defopts.dt, "positive real number", cmd); TCLAP::SwitchArg all_to_all_arg( - "a","alltoall","all to all network", cmd, false); - - cmd.add(ncells_arg); - cmd.add(nsynapses_arg); - cmd.add(syntype_arg); - cmd.add(ncompartments_arg); - cmd.add(ifile_arg); - cmd.add(dt_arg); - cmd.add(tfinal_arg); + "m","alltoall","all to all network", cmd, false); cmd.parse(argc, argv); @@ -67,20 +60,13 @@ cl_options read_options(int argc, char** argv) { options.all_to_all = all_to_all_arg.getValue(); } // catch any exceptions in command line handling - catch(TCLAP::ArgException &e) { - std::cerr << "error: parsing command line arguments:\n " - << e.error() << " for arg " << e.argId() << "\n"; - exit(1); - } - - if(options.ifname == "") { - options.check_and_normalize(); - return options; + catch (TCLAP::ArgException& e) { + throw usage_error("error parsing command line argument "+e.argId()+": "+e.error()); } - else { - std::ifstream fid(options.ifname, std::ifstream::in); - if(fid.is_open()) { + if (options.ifname != "") { + std::ifstream fid(options.ifname); + if (fid) { // read json data in input file nlohmann::json fopts; fid >> fopts; @@ -93,22 +79,17 @@ cl_options read_options(int argc, char** argv) { options.tfinal = fopts["tfinal"]; options.all_to_all = fopts["all_to_all"]; } - catch(std::domain_error e) { - std::cerr << "error: unable to open parameters in " - << options.ifname << " : " << e.what() << "\n"; - exit(1); + catch (std::exception& e) { + throw model_description_error( + "unable to parse parameters in "+options.ifname+": "+e.what()); } - catch(std::exception e) { - std::cerr << "error: unable to open parameters in " - << options.ifname << "\n"; - exit(1); - } - options.check_and_normalize(); - return options; + } + else { + throw usage_error("unable to open model parameter file "+options.ifname); } } - return default_options; + return options; } std::ostream& operator<<(std::ostream& o, const cl_options& options) { diff --git a/miniapp/io.hpp b/miniapp/io.hpp index 1f530d9150c69b01cfa9f6cac8127180ff12ec70..dab584fe2a132f289bb9a2f1ae02711aa29ef5a4 100644 --- a/miniapp/io.hpp +++ b/miniapp/io.hpp @@ -1,6 +1,10 @@ #pragma once -#include <json/src/json.hpp> +#include <string> +#include <cstdint> +#include <iosfwd> +#include <stdexcept> +#include <utility> namespace nest { namespace mc { @@ -16,20 +20,25 @@ struct cl_options { double tfinal; double dt; bool all_to_all; +}; + +class usage_error: public std::runtime_error { +public: + template <typename S> + usage_error(S&& whatmsg): std::runtime_error(std::forward<S>(whatmsg)) {} +}; - // TODO the normalize bit should be moved to the model_parameters when - // we start having more models - void check_and_normalize() { - if(all_to_all) { - synapses_per_cell = cells - 1; - } - } +class model_description_error: public std::runtime_error { +public: + template <typename S> + model_description_error(S&& whatmsg): std::runtime_error(std::forward<S>(whatmsg)) {} }; std::ostream& operator<<(std::ostream& o, const cl_options& opt); cl_options read_options(int argc, char** argv); + } // namespace io } // namespace mc } // namespace nest diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp index 3867f4e8365fa8ac65bbe235e2e7fe915b94e60b..9b26ff86ef37f829667161afca27a129791de3ed 100644 --- a/miniapp/miniapp.cpp +++ b/miniapp/miniapp.cpp @@ -1,448 +1,178 @@ +#include <cmath> +#include <exception> #include <iostream> #include <fstream> -#include <sstream> +#include <memory> +#include <json/src/json.hpp> + +#include <common_types.hpp> #include <cell.hpp> #include <cell_group.hpp> #include <fvm_cell.hpp> -#include <mechanism_interface.hpp> +#include <mechanism_catalogue.hpp> +#include <model.hpp> +#include <threading/threading.hpp> +#include <profiling/profiler.hpp> +#include <communication/communicator.hpp> +#include <communication/global_policy.hpp> +#include <util/ioutil.hpp> +#include <util/optional.hpp> #include "io.hpp" -#include "threading/threading.hpp" -#include "profiling/profiler.hpp" -#include "communication/communicator.hpp" -#include "communication/global_policy.hpp" -#include "util/optional.hpp" - -using namespace nest; +#include "miniapp_recipes.hpp" +#include "trace_sampler.hpp" -using real_type = double; -using index_type = int; -using id_type = uint32_t; -using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>; -using cell_group = mc::cell_group<numeric_cell>; +using namespace nest::mc; -using global_policy = nest::mc::communication::global_policy; -using communicator_type = - mc::communication::communicator<global_policy>; +using global_policy = communication::global_policy; -using nest::mc::util::optional; +using lowered_cell = fvm::fvm_cell<double, cell_local_size_type>; +using model_type = model<lowered_cell>; +using sample_trace_type = sample_trace<model_type::time_type, model_type::value_type>; -struct model { - communicator_type communicator; - std::vector<cell_group> cell_groups; +void banner(); +std::unique_ptr<recipe> make_recipe(const io::cl_options&); +std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_spec probe); +std::pair<cell_gid_type, cell_gid_type> distribute_cells(cell_size_type ncells); - unsigned num_groups() const { - return cell_groups.size(); - } +void write_trace_json(const sample_trace_type& trace, const std::string& prefix = "trace_"); - void run(double tfinal, double dt) { - auto t = 0.; - auto delta = std::min(double(communicator.min_delay()), tfinal); - while (t<tfinal) { - mc::threading::parallel_for::apply( - 0, num_groups(), - [&](int i) { - PE("stepping","events"); - cell_groups[i].enqueue_events(communicator.queue(i)); - PL(); - - cell_groups[i].advance(std::min(t+delta, tfinal), dt); - - PE("events"); - communicator.add_spikes(cell_groups[i].spikes()); - cell_groups[i].clear_spikes(); - PL(2); - } - ); - - PE("stepping", "exchange"); - communicator.exchange(); - PL(2); - - t += delta; - } - } - - void init_communicator() { - PE("setup", "communicator"); +int main(int argc, char** argv) { + nest::mc::communication::global_policy_guard global_guard(argc, argv); - // calculate the source and synapse distribution serially - std::vector<id_type> target_counts(num_groups()); - std::vector<id_type> source_counts(num_groups()); - for (auto i=0u; i<num_groups(); ++i) { - target_counts[i] = cell_groups[i].cell().synapses()->size(); - source_counts[i] = cell_groups[i].spike_sources().size(); - } + try { + std::cout << util::mask_stream(global_policy::id()==0); + banner(); - target_map = mc::algorithms::make_index(target_counts); - source_map = mc::algorithms::make_index(source_counts); + // read parameters + io::cl_options options = io::read_options(argc, argv); + std::cout << options << "\n"; + std::cout << "\n"; + std::cout << ":: simulation to " << options.tfinal << " ms in " + << std::ceil(options.tfinal / options.dt) << " steps of " + << options.dt << " ms" << std::endl; - // create connections - communicator = communicator_type(num_groups(), target_counts); + auto recipe = make_recipe(options); + auto cell_range = distribute_cells(recipe->num_cells()); - PL(2); - } + // build model from recipe + model_type m(*recipe, cell_range.first, cell_range.second); - void update_gids() { - PE("setup", "globalize"); - auto com_policy = communicator.communication_policy(); - auto global_source_map = com_policy.make_map(source_map.back()); - auto domain_idx = communicator.domain_id(); - for (auto i=0u; i<num_groups(); ++i) { - cell_groups[i].set_source_gids( - source_map[i]+global_source_map[domain_idx] - ); - cell_groups[i].set_target_gids( - target_map[i]+communicator.target_gid_from_group_lid(0) - ); + // inject some artificial spikes, 1 per 20 neurons. + cell_gid_type spike_cell = 20*((cell_range.first+19)/20); + for (; spike_cell<cell_range.second; spike_cell+=20) { + m.add_artificial_spike({spike_cell,0u}); } - PL(2); - } - // TODO : only stored here because init_communicator() and update_gids() are split - std::vector<id_type> source_map; - std::vector<id_type> target_map; - - // traces from probes - struct trace_data { - struct sample_type { - float time; - double value; - }; - std::string name; - std::string units; - index_type id; - std::vector<sample_type> samples; - }; - - // different traces may be written to by different threads; - // during simulation, each trace_sampler will be responsible for its - // corresponding element in the traces vector. - - std::vector<trace_data> traces; - - // make a sampler that records to traces - struct simple_sampler_functor { - std::vector<trace_data> &traces_; - size_t trace_index_ = 0; - float requested_sample_time_ = 0; - float dt_ = 0; - - simple_sampler_functor(std::vector<trace_data> &traces, size_t index, float dt) : - traces_(traces), trace_index_(index), dt_(dt) - {} - - optional<float> operator()(float t, double v) { - traces_[trace_index_].samples.push_back({t,v}); - return requested_sample_time_ += dt_; + // attach samplers to all probes + std::vector<std::unique_ptr<sample_trace_type>> traces; + const model_type::time_type sample_dt = 0.1; + for (auto probe: m.probes()) { + traces.push_back(make_trace(probe.id, probe.probe)); + m.attach_sampler(probe.id, make_trace_sampler(traces.back().get(),sample_dt)); } - }; - mc::sampler make_simple_sampler( - index_type probe_gid, const std::string& name, const std::string& units, index_type id, float dt) - { - traces.push_back(trace_data{name, units, id}); - return {probe_gid, simple_sampler_functor(traces, traces.size()-1, dt)}; - } + // run model + m.run(options.tfinal, options.dt); + util::profiler_output(0.001); - void reset_traces() { - // do not call during simulation: thread-unsafe access to traces. - traces.clear(); - } + std::cout << "there were " << m.num_spikes() << " spikes\n"; - void dump_traces() { - // do not call during simulation: thread-unsafe access to traces. + // save traces for (const auto& trace: traces) { - auto path = "trace_" + std::to_string(trace.id) - + "_" + trace.name + ".json"; - - nlohmann::json jrep; - jrep["name"] = trace.name; - jrep["units"] = trace.units; - jrep["id"] = trace.id; - - auto& jt = jrep["data"]["time"]; - auto& jy = jrep["data"][trace.name]; - - for (const auto& sample: trace.samples) { - jt.push_back(sample.time); - jy.push_back(sample.value); - } - std::ofstream file(path); - file << std::setw(1) << jrep << std::endl; + write_trace_json(*trace.get()); } } -}; - -// define some global model parameters -namespace parameters { -namespace synapses { - // synapse delay - constexpr float delay = 20.0; // ms - - // connection weight - constexpr double weight_per_cell = 0.3; // uS -} -} - -/////////////////////////////////////// -// prototypes -/////////////////////////////////////// - -/// make a single abstract cell -mc::cell make_cell(int compartments_per_segment, int num_synapses, const std::string& syn_type); - -/// do basic setup (initialize global state, print banner, etc) -void setup(); - -/// helper function for initializing cells -cell_group make_lowered_cell(int cell_index, const mc::cell& c); - -/// models -void all_to_all_model(nest::mc::io::cl_options& opt, model& m); - -/////////////////////////////////////// -// main -/////////////////////////////////////// -int main(int argc, char** argv) { - nest::mc::communication::global_policy_guard global_guard(argc, argv); - - setup(); - - // read parameters - mc::io::cl_options options; - try { - options = mc::io::read_options(argc, argv); - if (global_policy::id()==0) { - std::cout << options << "\n"; - } + catch (io::usage_error& e) { + // only print usage/startup errors on master + std::cerr << util::mask_stream(global_policy::id()==0); + std::cerr << e.what() << "\n"; + return 1; } catch (std::exception& e) { - std::cerr << e.what() << std::endl; - exit(1); - } - - model m; - all_to_all_model(options, m); - - // - // time stepping - // - auto tfinal = options.tfinal; - auto dt = options.dt; - - auto id = m.communicator.domain_id(); - - if (id==0) { - // use std::endl to force flush of output on cluster jobs - std::cout << "\n"; - std::cout << ":: simulation to " << tfinal << " ms in " - << std::ceil(tfinal / dt) << " steps of " - << dt << " ms" << std::endl; - } - - // add some spikes to the system to start it - auto first = m.communicator.group_gid_first(id); - if(first%20) { - first += 20 - (first%20); // round up to multiple of 20 + std::cerr << e.what() << "\n"; + return 2; } - auto last = m.communicator.group_gid_first(id+1); - for (auto i=first; i<last; i+=20) { - m.communicator.add_spike({i, 0}); - } - - m.run(tfinal, dt); - - mc::util::profiler_output(0.001); - - if (id==0) { - std::cout << "there were " << m.communicator.num_spikes() << " spikes\n"; - } - - m.dump_traces(); + return 0; } -/////////////////////////////////////// -// models -/////////////////////////////////////// +std::pair<cell_gid_type, cell_gid_type> distribute_cells(cell_size_type num_cells) { + // Crude load balancing: + // divide [0, num_cells) into num_domains non-overlapping, contiguous blocks + // of size as close to equal as possible. -void all_to_all_model(nest::mc::io::cl_options& options, model& m) { - // - // make cells - // + auto num_domains = communication::global_policy::size(); + auto domain_id = communication::global_policy::id(); - auto synapses_per_cell = options.synapses_per_cell; - auto is_all_to_all = options.all_to_all; + cell_gid_type cell_from = (cell_gid_type)(num_cells*(domain_id/(double)num_domains)); + cell_gid_type cell_to = (cell_gid_type)(num_cells*((domain_id+1)/(double)num_domains)); - // make a basic cell - auto basic_cell = - make_cell(options.compartments_per_segment, synapses_per_cell, options.syn_type); + return {cell_from, cell_to}; +} - auto num_domains = global_policy::size(); - auto domain_id = global_policy::id(); +void banner() { + std::cout << "====================\n"; + std::cout << " starting miniapp\n"; + std::cout << " - " << threading::description() << " threading support\n"; + std::cout << " - communication policy: " << global_policy::name() << "\n"; + std::cout << "====================\n"; +} - // make a vector for storing all of the cells - id_type ncell_global = options.cells; - id_type ncell_local = ncell_global / num_domains; - int remainder = ncell_global - (ncell_local*num_domains); - if (domain_id<remainder) { - ncell_local++; - } +std::unique_ptr<recipe> make_recipe(const io::cl_options& options) { + basic_recipe_param p; - m.cell_groups = std::vector<cell_group>(ncell_local); + p.num_compartments = options.compartments_per_segment; + p.num_synapses = options.all_to_all? options.cells-1: options.synapses_per_cell; + p.synapse_type = options.syn_type; - // initialize the cells in parallel - mc::threading::parallel_for::apply( - 0, ncell_local, - [&](int i) { - PE("setup", "cells"); - m.cell_groups[i] = make_lowered_cell(i, basic_cell); - PL(2); - } - ); - - // - // network creation - // - m.init_communicator(); - - PE("setup", "connections"); - - // RNG distributions for connection delays and source cell ids - auto weight_distribution = std::exponential_distribution<float>(0.75); - auto source_distribution = - std::uniform_int_distribution<uint32_t>(0u, options.cells-1); - - // calculate the weight of synaptic connections, which is chosen so that - // the sum of all synaptic weights on a cell is - // parameters::synapses::weight_per_cell - float weight = parameters::synapses::weight_per_cell / synapses_per_cell; - - // loop over each local cell and build the list of synapse connections - // that terminate on the cell - for (auto lid=0u; lid<ncell_local; ++lid) { - auto target = m.communicator.target_gid_from_group_lid(lid); - auto gid = m.communicator.group_gid_from_group_lid(lid); - auto gen = std::mt19937(gid); // seed with cell gid for reproducability - // add synapses to cell - auto i = 0u; - auto cells_added = 0u; - while (cells_added < synapses_per_cell) { - auto source = is_all_to_all ? i : source_distribution(gen); - if (gid!=source) { - m.communicator.add_connection({ - source, target++, weight, - parameters::synapses::delay + weight_distribution(gen) - }); - cells_added++; - } - ++i; - } + if (options.all_to_all) { + return make_basic_kgraph_recipe(options.cells, p); } - - m.communicator.construct(); - - //for (auto con : m.communicator.connections()) std::cout << con << "\n"; - - m.update_gids(); - - // - // setup probes - // - - PL(); PE("probes"); - - // monitor soma and dendrite on a few cells - float sample_dt = 0.1; - index_type monitor_group_gids[] = { 0, 1, 2 }; - for (auto gid : monitor_group_gids) { - if (!m.communicator.is_local_group(gid)) { - continue; - } - - auto lid = m.communicator.group_lid(gid); - auto probe_soma = m.cell_groups[lid].probe_gid_range().first; - auto probe_dend = probe_soma+1; - auto probe_dend_current = probe_soma+2; - - m.cell_groups[lid].add_sampler( - m.make_simple_sampler(probe_soma, "vsoma", "mV", gid, sample_dt) - ); - m.cell_groups[lid].add_sampler( - m.make_simple_sampler(probe_dend, "vdend", "mV", gid, sample_dt) - ); - m.cell_groups[lid].add_sampler( - m.make_simple_sampler(probe_dend_current, "idend", "mA/cm²", gid, sample_dt) - ); + else { + return make_basic_rgraph_recipe(options.cells, p); } - - PL(2); } -/////////////////////////////////////// -// function definitions -/////////////////////////////////////// - -void setup() { - // print banner - if (global_policy::id()==0) { - std::cout << "====================\n"; - std::cout << " starting miniapp\n"; - std::cout << " - " << mc::threading::description() << " threading support\n"; - std::cout << " - communication policy: " << global_policy::name() << "\n"; - std::cout << "====================\n"; - } - - // setup global state for the mechanisms - mc::mechanisms::setup_mechanism_helpers(); +std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_spec probe) { + std::string name = ""; + std::string units = ""; + + switch (probe.kind) { + case probeKind::membrane_voltage: + name = "v"; + units = "mV"; + break; + case probeKind::membrane_current: + name = "i"; + units = "mA/cm²"; + break; + default: ; + } + name += probe.location.segment? "dend" : "soma"; + + return util::make_unique<sample_trace_type>(probe_id, name, units); } -// make a high level cell description for use in simulation -mc::cell make_cell(int compartments_per_segment, int num_synapses, const std::string& syn_type) { - nest::mc::cell cell; - - // Soma with diameter 12.6157 um and HH channel - auto soma = cell.add_soma(12.6157/2.0); - soma->add_mechanism(mc::hh_parameters()); +void write_trace_json(const sample_trace_type& trace, const std::string& prefix) { + auto path = prefix + std::to_string(trace.probe_id.gid) + + "." + std::to_string(trace.probe_id.index) + "_" + trace.name + ".json"; - // add dendrite of length 200 um and diameter 1 um with passive channel - std::vector<mc::cable_segment*> dendrites; - dendrites.push_back(cell.add_cable(0, mc::segmentKind::dendrite, 0.5, 0.5, 200)); - dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25,100)); - dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25,100)); - - for (auto d : dendrites) { - d->add_mechanism(mc::pas_parameters()); - d->set_compartments(compartments_per_segment); - d->mechanism("membrane").set("r_L", 100); - } + nlohmann::json jrep; + jrep["name"] = trace.name; + jrep["units"] = trace.units; + jrep["cell"] = trace.probe_id.gid; + jrep["probe"] = trace.probe_id.index; - cell.add_detector({0,0}, 20); + auto& jt = jrep["data"]["time"]; + auto& jy = jrep["data"][trace.name]; - auto gen = std::mt19937(); - auto distribution = std::uniform_real_distribution<float>(0.f, 1.0f); - // distribute the synapses at random locations the terminal dendrites in a - // round robin manner - nest::mc::parameter_list syn_default(syn_type); - for (auto i=0; i<num_synapses; ++i) { - cell.add_synapse({2+(i%2), distribution(gen)}, syn_default); + for (const auto& sample: trace.samples) { + jt.push_back(sample.time); + jy.push_back(sample.value); } - - // add probes: - auto probe_soma = cell.add_probe({0, 0}, mc::probeKind::membrane_voltage); - auto probe_dendrite = cell.add_probe({1, 0.5}, mc::probeKind::membrane_voltage); - auto probe_dendrite_current = cell.add_probe({1, 0.5}, mc::probeKind::membrane_current); - - EXPECTS(probe_soma==0); - EXPECTS(probe_dendrite==1); - EXPECTS(probe_dendrite_current==2); - (void)probe_soma, (void)probe_dendrite, (void)probe_dendrite_current; - - return cell; + std::ofstream file(path); + file << std::setw(1) << jrep << std::endl; } -cell_group make_lowered_cell(int cell_index, const mc::cell& c) { - return cell_group(c); -} diff --git a/miniapp/miniapp_recipes.cpp b/miniapp/miniapp_recipes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..22c59bd851a049e448784b9cfdd635d7e7a96970 --- /dev/null +++ b/miniapp/miniapp_recipes.cpp @@ -0,0 +1,234 @@ +#include <cmath> +#include <random> +#include <vector> +#include <utility> + +#include <cell.hpp> +#include <util/debug.hpp> + +#include "miniapp_recipes.hpp" + +namespace nest { +namespace mc { + +// TODO: split cell description into separate morphology, stimulus, probes, mechanisms etc. +// description for greater data reuse. + +template <typename RNG> +cell make_basic_cell( + unsigned compartments_per_segment, + unsigned num_synapses, + const std::string& syn_type, + RNG& rng) +{ + nest::mc::cell cell; + + // Soma with diameter 12.6157 um and HH channel + auto soma = cell.add_soma(12.6157/2.0); + soma->add_mechanism(mc::hh_parameters()); + + // add dendrite of length 200 um and diameter 1 um with passive channel + std::vector<mc::cable_segment*> dendrites; + dendrites.push_back(cell.add_cable(0, mc::segmentKind::dendrite, 0.5, 0.5, 200)); + dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25,100)); + dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25,100)); + + for (auto d : dendrites) { + d->add_mechanism(mc::pas_parameters()); + d->set_compartments(compartments_per_segment); + d->mechanism("membrane").set("r_L", 100); + } + + cell.add_detector({0,0}, 20); + + auto distribution = std::uniform_real_distribution<float>(0.f, 1.0f); + // distribute the synapses at random locations the terminal dendrites in a + // round robin manner; the terminal dendrites in this cell have indices 2 and 3. + nest::mc::parameter_list syn_default(syn_type); + for (unsigned i=0; i<num_synapses; ++i) { + cell.add_synapse({2+(i%2), distribution(rng)}, syn_default); + } + + return cell; +} + +class basic_cell_recipe: public recipe { +public: + basic_cell_recipe(cell_gid_type ncell, basic_recipe_param param, probe_distribution pdist): + ncell_(ncell), param_(std::move(param)), pdist_(std::move(pdist)) + { + delay_distribution_param = exp_param{param_.mean_connection_delay_ms + - param_.min_connection_delay_ms}; + } + + cell_size_type num_cells() const override { return ncell_; } + + cell get_cell(cell_gid_type i) const override { + auto gen = std::mt19937(i); // TODO: replace this with hashing generator... + + auto cc = get_cell_count_info(i); + auto cell = make_basic_cell(param_.num_compartments, cc.num_targets, + param_.synapse_type, gen); + + EXPECTS(cell.num_segments()==basic_cell_segments); + EXPECTS(cell.probes().size()==0); + EXPECTS(cell.synapses().size()==cc.num_targets); + EXPECTS(cell.detectors().size()==cc.num_sources); + + // add probes + unsigned n_probe_segs = pdist_.all_segments? basic_cell_segments: 1u; + for (unsigned i = 0; i<n_probe_segs; ++i) { + if (pdist_.membrane_voltage) { + cell.add_probe({{i, i? 0.5: 0.0}, mc::probeKind::membrane_voltage}); + } + if (pdist_.membrane_current) { + cell.add_probe({{i, i? 0.5: 0.0}, mc::probeKind::membrane_current}); + } + } + EXPECTS(cell.probes().size()==cc.num_probes); + return cell; + } + + cell_count_info get_cell_count_info(cell_gid_type i) const override { + cell_count_info cc = {1, param_.num_synapses, 0 }; + + // probe this cell? + if (std::floor(i*pdist_.proportion)!=std::floor((i-1.0)*pdist_.proportion)) { + std::size_t np = pdist_.membrane_voltage + pdist_.membrane_current; + if (pdist_.all_segments) { + np *= basic_cell_segments; + } + + cc.num_probes = np; + } + + return cc; + } + +protected: + template <typename RNG> + cell_connection draw_connection_params(RNG& rng) const { + std::exponential_distribution<float> delay_dist(delay_distribution_param); + float delay = param_.min_connection_delay_ms + delay_dist(rng); + float weight = param_.syn_weight_per_cell/param_.num_synapses; + return cell_connection{{0, 0}, {0, 0}, weight, delay}; + } + + cell_gid_type ncell_; + basic_recipe_param param_; + probe_distribution pdist_; + static constexpr int basic_cell_segments = 4; + + using exp_param = std::exponential_distribution<float>::param_type; + exp_param delay_distribution_param; +}; + +class basic_ring_recipe: public basic_cell_recipe { +public: + basic_ring_recipe(cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist = probe_distribution{}): + basic_cell_recipe(ncell, std::move(param), std::move(pdist)) {} + + std::vector<cell_connection> connections_on(cell_gid_type i) const override { + std::vector<cell_connection> conns; + auto gen = std::mt19937(i); // TODO: replace this with hashing generator... + + cell_gid_type prev = i==0? ncell_-1: i-1; + for (unsigned t=0; t<param_.num_synapses; ++t) { + cell_connection cc = draw_connection_params(gen); + cc.source = {prev, 0}; + cc.dest = {i, t}; + conns.push_back(cc); + } + + return conns; + } +}; + +std::unique_ptr<recipe> make_basic_ring_recipe( + cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist) +{ + return std::unique_ptr<recipe>(new basic_ring_recipe(ncell, param, pdist)); +} + + +class basic_rgraph_recipe: public basic_cell_recipe { +public: + basic_rgraph_recipe(cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist = probe_distribution{}): + basic_cell_recipe(ncell, std::move(param), std::move(pdist)) {} + + std::vector<cell_connection> connections_on(cell_gid_type i) const override { + std::vector<cell_connection> conns; + auto conn_param_gen = std::mt19937(i); // TODO: replace this with hashing generator... + auto source_gen = std::mt19937(i*123+457); // ditto + + std::uniform_int_distribution<cell_gid_type> source_distribution(0, ncell_-2); + + for (unsigned t=0; t<param_.num_synapses; ++t) { + auto source = source_distribution(source_gen); + if (source>=i) ++source; + + cell_connection cc = draw_connection_params(conn_param_gen); + cc.source = {source, 0}; + cc.dest = {i, t}; + conns.push_back(cc); + } + + return conns; + } +}; + +std::unique_ptr<recipe> make_basic_rgraph_recipe( + cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist) +{ + return std::unique_ptr<recipe>(new basic_rgraph_recipe(ncell, param, pdist)); +} + +class basic_kgraph_recipe: public basic_cell_recipe { +public: + basic_kgraph_recipe(cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist = probe_distribution{}): + basic_cell_recipe(ncell, std::move(param), std::move(pdist)) + { + if (std::size_t(param.num_synapses) != ncell-1) { + throw invalid_recipe_error("number of synapses per cell must equal number " + "of cells minus one in complete graph model"); + } + } + + std::vector<cell_connection> connections_on(cell_gid_type i) const override { + std::vector<cell_connection> conns; + auto conn_param_gen = std::mt19937(i); // TODO: replace this with hashing generator... + + for (unsigned t=0; t<param_.num_synapses; ++t) { + cell_gid_type source = t>=i? t+1: t; + EXPECTS(source<ncell_); + + cell_connection cc = draw_connection_params(conn_param_gen); + cc.source = {source, 0}; + cc.dest = {i, t}; + conns.push_back(cc); + } + + return conns; + } +}; + +std::unique_ptr<recipe> make_basic_kgraph_recipe( + cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist) +{ + return std::unique_ptr<recipe>(new basic_kgraph_recipe(ncell, param, pdist)); +} + +} // namespace mc +} // namespace nest diff --git a/miniapp/miniapp_recipes.hpp b/miniapp/miniapp_recipes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2e519bd1b492e532db86175acbd11520d1bd7a8b --- /dev/null +++ b/miniapp/miniapp_recipes.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include <cstddef> +#include <memory> +#include <stdexcept> + +#include "recipe.hpp" + +// miniapp-specific recipes + +namespace nest { +namespace mc { + +struct probe_distribution { + float proportion = 1.f; // what proportion of cells should get probes? + bool all_segments = true; // false => soma only + bool membrane_voltage = true; + bool membrane_current = true; +}; + +struct basic_recipe_param { + unsigned num_compartments = 1; + unsigned num_synapses = 1; + std::string synapse_type = "expsyn"; + float min_connection_delay_ms = 20.0; + float mean_connection_delay_ms = 20.75; + float syn_weight_per_cell = 0.3; +}; + +std::unique_ptr<recipe> make_basic_ring_recipe( + cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist = probe_distribution{}); + +std::unique_ptr<recipe> make_basic_kgraph_recipe( + cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist = probe_distribution{}); + +std::unique_ptr<recipe> make_basic_rgraph_recipe( + cell_gid_type ncell, + basic_recipe_param param, + probe_distribution pdist = probe_distribution{}); + +} // namespace mc +} // namespace nest diff --git a/miniapp/trace_sampler.hpp b/miniapp/trace_sampler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..26db0a491444db32a49653f861b18dce464c103d --- /dev/null +++ b/miniapp/trace_sampler.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include <cstdlib> +#include <vector> + +#include <common_types.hpp> +#include <cell.hpp> +#include <util/optional.hpp> + +#include <iostream> + +namespace nest { +namespace mc { + +template <typename Time=float, typename Value=double> +struct sample_trace { + using time_type = Time; + using value_type = Value; + + struct sample_type { + time_type time; + value_type value; + }; + + std::string name; + std::string units; + cell_member_type probe_id; + std::vector<sample_type> samples; + + sample_trace() =default; + sample_trace(cell_member_type probe_id, const std::string& name, const std::string& units): + name(name), units(units), probe_id(probe_id) + {} +}; + +template <typename Time=float, typename Value=double> +struct trace_sampler { + using time_type = Time; + using value_type = Value; + + float next_sample_t() const { return t_next_sample_; } + + util::optional<time_type> operator()(time_type t, value_type v) { + if (t<t_next_sample_) { + return t_next_sample_; + } + + trace_->samples.push_back({t,v}); + return t_next_sample_+=sample_dt_; + } + + trace_sampler(sample_trace<time_type, value_type> *trace, time_type sample_dt, time_type tfrom=0): + trace_(trace), sample_dt_(sample_dt), t_next_sample_(tfrom) + {} + +private: + sample_trace<time_type, value_type> *trace_; + + time_type sample_dt_; + time_type t_next_sample_; +}; + +// with type deduction ... +template <typename Time, typename Value> +trace_sampler<Time, Value> make_trace_sampler(sample_trace<Time, Value> *trace, Time sample_dt, Time tfrom=0) { + return trace_sampler<Time, Value>(trace, sample_dt, tfrom); +} + +} // namespace mc +} // namespace nest diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7f4a588d933a5ea74342c70dc8b6ac5df95aca8b..8e0db8876fadf3d030b65ccac9a653eedc1c8d37 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,8 +2,8 @@ set(HEADERS swcio.hpp ) set(BASE_SOURCES + common_types_io.cpp cell.cpp - mechanism_interface.cpp parameter_list.cpp profiling/profiler.cpp swcio.cpp diff --git a/src/algorithms.hpp b/src/algorithms.hpp index e89f06f539f51b37f7ce95ca8ff2de18b3709d62..7789b24156d403ba702404ccf67a67d0b2cfcedf 100644 --- a/src/algorithms.hpp +++ b/src/algorithms.hpp @@ -90,12 +90,12 @@ bool is_minimal_degree(C const& c) "is_minimal_degree only applies to integral types" ); - if(c.size()==0u) { + if (c.size()==0u) { return true; } using value_type = typename C::value_type; - if(c[0] != value_type(0)) { + if (c[0] != value_type(0)) { return false; } auto i = value_type(1); @@ -121,7 +121,7 @@ bool is_positive(C const& c) } template<typename C> -bool has_contiguous_segments(const C &parent_index) +bool has_contiguous_segments(const C& parent_index) { static_assert( std::is_integral<typename C::value_type>::value, @@ -141,7 +141,7 @@ bool has_contiguous_segments(const C &parent_index) return false; } - if(p != i-1) { + if(p != decltype(p)(i-1)) { // we have a branch and i-1 is a leaf node is_leaf[i-1] = true; } @@ -151,7 +151,7 @@ bool has_contiguous_segments(const C &parent_index) } template<typename C> -std::vector<typename C::value_type> child_count(const C &parent_index) +std::vector<typename C::value_type> child_count(const C& parent_index) { static_assert( std::is_integral<typename C::value_type>::value, @@ -174,7 +174,7 @@ std::vector<typename C::value_type> branches(const C& parent_index) "integral type required" ); - EXPECTS(has_contiguous_segments(parent_index)); + //EXPECTS(has_contiguous_segments(parent_index)); std::vector<typename C::value_type> branch_index; if (parent_index.empty()) { @@ -229,7 +229,7 @@ typename C::value_type find_branch(const C& branch_index, auto it = std::find_if( branch_index.begin(), branch_index.end(), - [nid](const value_type &v) { return v > nid; } + [nid](const value_type& v) { return v > nid; } ); return it - branch_index.begin() - 1; @@ -250,7 +250,7 @@ std::vector<typename C::value_type> make_parent_index( } EXPECTS(parent_index.size() == unsigned(branch_index.back())); - EXPECTS(has_contiguous_segments(parent_index)); + //EXPECTS(has_contiguous_segments(parent_index)); EXPECTS(is_strictly_monotonic_increasing(branch_index)); // expand the branch index diff --git a/src/cell.cpp b/src/cell.cpp index 01d681a7e080c8c0444fe0299820181d5154028e..e96cae69aca834d61d81c9c53c9488532a3b5b69 100644 --- a/src/cell.cpp +++ b/src/cell.cpp @@ -26,7 +26,7 @@ cell::cell() parents_.push_back(0); } -int cell::num_segments() const +cell::size_type cell::num_segments() const { return segments_.size(); } @@ -75,9 +75,9 @@ cable_segment* cell::add_cable(cell::index_type parent, segment_ptr&& cable) return segments_.back()->as_cable(); } -segment* cell::segment(int index) +segment* cell::segment(index_type index) { - if(index<0 || index>=num_segments()) { + if (index>=num_segments()) { throw std::out_of_range( "attempt to access a segment with invalid index" ); @@ -85,9 +85,9 @@ segment* cell::segment(int index) return segments_[index].get(); } -segment const* cell::segment(int index) const +segment const* cell::segment(index_type index) const { - if(index<0 || index>=num_segments()) { + if (index>=num_segments()) { throw std::out_of_range( "attempt to access a segment with invalid index" ); @@ -109,7 +109,7 @@ soma_segment* cell::soma() return nullptr; } -cable_segment* cell::cable(int index) +cable_segment* cell::cable(index_type index) { if(index>0 && index<num_segments()) { return segment(index)->as_cable(); @@ -146,9 +146,9 @@ std::vector<segment_ptr> const& cell::segments() const return segments_; } -std::vector<int> cell::compartment_counts() const +std::vector<cell::size_type> cell::compartment_counts() const { - std::vector<int> comp_count; + std::vector<size_type> comp_count; comp_count.reserve(num_segments()); for(auto const& s : segments()) { comp_count.push_back(s->num_compartments()); @@ -156,10 +156,10 @@ std::vector<int> cell::compartment_counts() const return comp_count; } -size_t cell::num_compartments() const +cell::size_type cell::num_compartments() const { auto n = 0u; - for(auto &s : segments_) { + for(auto& s : segments_) { n += s->num_compartments(); } return n; @@ -196,7 +196,7 @@ void cell::add_detector(segment_location loc, double threshold) spike_detectors_.push_back({loc, threshold}); } -std::vector<int> const& cell::segment_parents() const +std::vector<cell::index_type> const& cell::segment_parents() const { return parents_; } @@ -212,24 +212,22 @@ std::vector<int> const& cell::segment_parents() const // - number of compartments in each segment bool cell_basic_equality(cell const& lhs, cell const& rhs) { - if(lhs.num_segments() != rhs.num_segments()) { + if (lhs.num_segments() != rhs.num_segments()) { return false; } - if(lhs.segment_parents() != rhs.segment_parents()) { + if (lhs.segment_parents() != rhs.segment_parents()) { return false; } - for(auto i=0; i<lhs.num_segments(); ++i) { + for (cell::index_type i=0; i<lhs.num_segments(); ++i) { // a quick and dirty test auto& l = *lhs.segment(i); auto& r = *rhs.segment(i); - if(l.kind() != r.kind()) return false; - if(l.area() != r.area()) return false; - if(l.volume() != r.volume()) return false; - if(l.as_cable()) { - if( l.as_cable()->num_compartments() - != r.as_cable()->num_compartments()) - { + if (l.kind() != r.kind()) return false; + if (l.area() != r.area()) return false; + if (l.volume() != r.volume()) return false; + if (l.as_cable()) { + if (l.as_cable()->num_compartments() != r.as_cable()->num_compartments()) { return false; } } diff --git a/src/cell.hpp b/src/cell.hpp index 1cb9b9f65dedbdcc4476c2b447eed1d1dea243ca..746549289f1b718433388e92f5f1ca9d13d433c1 100644 --- a/src/cell.hpp +++ b/src/cell.hpp @@ -5,8 +5,9 @@ #include <thread> #include <vector> -#include "segment.hpp" +#include "common_types.hpp" #include "cell_tree.hpp" +#include "segment.hpp" #include "stimulus.hpp" #include "util/debug.hpp" @@ -17,12 +18,12 @@ namespace mc { /// description struct compartment_model { cell_tree tree; - std::vector<int> parent_index; - std::vector<int> segment_index; + std::vector<cell_tree::int_type> parent_index; + std::vector<cell_tree::int_type> segment_index; }; struct segment_location { - segment_location(int s, double l) + segment_location(cell_lid_type s, double l) : segment(s), position(l) { EXPECTS(position>=0. && position<=1.); @@ -30,7 +31,7 @@ struct segment_location { friend bool operator==(segment_location l, segment_location r) { return l.segment==r.segment && l.position==r.position; } - int segment; + cell_lid_type segment; double position; }; @@ -44,27 +45,29 @@ enum class probeKind { membrane_current }; +struct probe_spec { + segment_location location; + probeKind kind; +}; + /// high-level abstract representation of a cell and its segments class cell { public: - - // types - using index_type = int; + using index_type = cell_lid_type; + using size_type = cell_local_size_type; using value_type = double; using point_type = point<value_type>; - + struct synapse_instance { segment_location location; parameter_list mechanism; }; - struct probe_instance { - segment_location location; - probeKind kind; - }; + struct stimulus_instance { segment_location location; i_clamp clamp; }; + struct detector_instance { segment_location location; double threshold; @@ -89,12 +92,12 @@ public: cable_segment* add_cable(index_type parent, Args ...args); /// the number of segments in the cell - int num_segments() const; + size_type num_segments() const; bool has_soma() const; - class segment* segment(int index); - class segment const* segment(int index) const; + class segment* segment(index_type index); + class segment const* segment(index_type index) const; /// access pointer to the soma /// returns nullptr if the cell has no soma @@ -103,7 +106,7 @@ public: /// access pointer to a cable segment /// will throw an std::out_of_range exception if /// the cable index is not valid - cable_segment* cable(int index); + cable_segment* cable(index_type index); /// the volume of the cell value_type volume() const; @@ -112,16 +115,16 @@ public: value_type area() const; /// the total number of compartments over all segments - size_t num_compartments() const; + size_type num_compartments() const; std::vector<segment_ptr> const& segments() const; /// return reference to array that enumerates the index of the parent of /// each segment - std::vector<int> const& segment_parents() const; + std::vector<index_type> const& segment_parents() const; /// return a vector with the compartment count for each segment in the cell - std::vector<int> compartment_counts() const; + std::vector<size_type> compartment_counts() const; compartment_model model() const; @@ -169,12 +172,12 @@ public: ////////////////// // probes ////////////////// - index_type add_probe(segment_location loc, probeKind kind) { - probes_.push_back({loc, kind}); + index_type add_probe(probe_spec p) { + probes_.push_back(p); return probes_.size()-1; } - const std::vector<probe_instance>& + const std::vector<probe_spec>& probes() const { return probes_; } private: @@ -195,7 +198,7 @@ private: std::vector<detector_instance> spike_detectors_; // the probes - std::vector<probe_instance> probes_; + std::vector<probe_spec> probes_; }; // Checks that two cells have the same diff --git a/src/cell_group.hpp b/src/cell_group.hpp index a5a6690cf08fe03f64f52d56a3577b0b0c747a52..e70d4b0bb5826b6ccde449f63a7fad6dab29c783 100644 --- a/src/cell_group.hpp +++ b/src/cell_group.hpp @@ -1,93 +1,80 @@ #pragma once #include <cstdint> +#include <functional> #include <vector> #include <cell.hpp> +#include <common_types.hpp> #include <event_queue.hpp> -#include <communication/spike.hpp> -#include <communication/spike_source.hpp> +#include <spike.hpp> +#include <spike_source.hpp> #include <profiling/profiler.hpp> namespace nest { namespace mc { -// samplers take a time and sample value, and return an optional time -// for the next desired sample. - -struct sampler { - using index_type = int; - using time_type = float; - using value_type = double; - - index_type probe_gid; // samplers are attached to probes - std::function<util::optional<time_type>(time_type, value_type)> sample; -}; - template <typename Cell> class cell_group { public: - using index_type = uint32_t; + using index_type = cell_gid_type; using cell_type = Cell; using value_type = typename cell_type::value_type; using size_type = typename cell_type::value_type; using spike_detector_type = spike_detector<Cell>; + using source_id_type = cell_member_type; + + using time_type = float; + using sampler_function = std::function<util::optional<time_type>(time_type, double)>; struct spike_source_type { - index_type index; + source_id_type source_id; spike_detector_type source; }; cell_group() = default; - cell_group(const cell& c) : - cell_{c} + cell_group(cell_gid_type gid, const cell& c) : + gid_base_{gid}, cell_{c} { - cell_.voltage()(memory::all) = -65.; - cell_.initialize(); + initialize_cells(); + + // Create spike detectors and associate them with globally unique source ids, + // as specified by cell gid and cell-local zero-based index. + + cell_gid_type source_gid = gid_base_; + cell_lid_type source_lid = 0u; for (auto& d : c.detectors()) { - spike_sources_.push_back( { - 0u, spike_detector_type(cell_, d.location, d.threshold, 0.f) + cell_member_type source_id{source_gid, source_lid++}; + + spike_sources_.push_back({ + source_id, spike_detector_type(cell_, d.location, d.threshold, 0.f) }); } } - void set_source_gids(index_type gid) { - for (auto& s : spike_sources_) { - s.index = gid++; + void reset() { + remove_samplers(); + initialize_cells(); + for (auto& spike_source: spike_sources_) { + spike_source.source.reset(cell_, 0.f); } } - void set_target_gids(index_type lid) { - first_target_gid_ = lid; - } - - index_type num_probes() const { - return cell_.num_probes(); - } - - void set_probe_gids(index_type gid) { - first_probe_gid_ = gid; - } - - std::pair<index_type, index_type> probe_gid_range() const { - return { first_probe_gid_, first_probe_gid_+cell_.num_probes() }; - } - - void advance(double tfinal, double dt) { + void advance(time_type tfinal, time_type dt) { while (cell_.time()<tfinal) { // take any pending samples - float cell_time = cell_.time(); + time_type cell_time = cell_.time(); PE("sampling"); while (auto m = sample_events_.pop_if_before(cell_time)) { - auto& sampler = samplers_[m->sampler_index]; - EXPECTS((bool)sampler.sample); + auto& sampler_spec = samplers_[m->sampler_index]; + EXPECTS((bool)sampler_spec.sampler); - index_type probe_index = sampler.probe_gid-first_probe_gid_; - auto next = sampler.sample(cell_.time(), cell_.probe(probe_index)); + index_type probe_index = sampler_spec.probe_id.index; + auto next = sampler_spec.sampler(cell_.time(), cell_.probe(probe_index)); if (next) { m->time = std::max(*next, cell_time); sample_events_.push(*m); @@ -96,9 +83,11 @@ public: PL(); // look for events in the next time step - auto tstep = std::min(tfinal, cell_.time()+dt); + time_type tstep = cell_.time()+dt; + tstep = std::min(tstep, tfinal); + auto next = events_.pop_if_before(tstep); - auto tnext = next ? next->time: tstep; + time_type tnext = next ? next->time: tstep; // integrate cell state cell_.advance(tnext - cell_.time()); @@ -110,7 +99,7 @@ public: // check for new spikes for (auto& s : spike_sources_) { if (auto spike = s.source.test(cell_, cell_.time())) { - spikes_.push_back({s.index, spike.get()}); + spikes_.push_back({s.source_id, spike.get()}); } } @@ -132,15 +121,12 @@ public: template <typename R> void enqueue_events(const R& events) { for (auto e : events) { - e.target -= first_target_gid_; events_.push(e); } } - const std::vector<communication::spike<index_type>>& - spikes() const { - return spikes_; - } + const std::vector<spike<source_id_type, time_type>>& + spikes() const { return spikes_; } cell_type& cell() { return cell_; } const cell_type& cell() const { return cell_; } @@ -154,14 +140,25 @@ public: spikes_.clear(); } - void add_sampler(const sampler& s, float start_time = 0) { + void add_sampler(cell_member_type probe_id, sampler_function s, time_type start_time = 0) { auto sampler_index = uint32_t(samplers_.size()); - samplers_.push_back(s); + samplers_.push_back({probe_id, s}); sample_events_.push({sampler_index, start_time}); } + void remove_samplers() { + sample_events_.clear(); + samplers_.clear(); + } + private: + void initialize_cells() { + cell_.voltage()(memory::all) = -65.; + cell_.initialize(); + } + /// gid of first cell in group + cell_gid_type gid_base_; /// the lowered cell state (e.g. FVM) of the cell cell_type cell_; @@ -170,22 +167,27 @@ private: std::vector<spike_source_type> spike_sources_; //. spikes that are generated - std::vector<communication::spike<index_type>> spikes_; + std::vector<spike<source_id_type, time_type>> spikes_; /// pending events to be delivered - event_queue<postsynaptic_spike_event> events_; + event_queue<postsynaptic_spike_event<time_type>> events_; /// pending samples to be taken - event_queue<sample_event> sample_events_; + event_queue<sample_event<time_type>> sample_events_; /// the global id of the first target (e.g. a synapse) in this group index_type first_target_gid_; - + /// the global id of the first probe in this group index_type first_probe_gid_; + struct sampler_entry { + cell_member_type probe_id; + sampler_function sampler; + }; + /// collection of samplers to be run against probes in this group - std::vector<sampler> samplers_; + std::vector<sampler_entry> samplers_; }; } // namespace mc diff --git a/src/cell_tree.hpp b/src/cell_tree.hpp index 4d9f70b1a934ab3b5bdb840d6bb0bd1b8419dfec..ce06d69a8b8ac7427b991e7e54ac2bc764a052e7 100644 --- a/src/cell_tree.hpp +++ b/src/cell_tree.hpp @@ -10,6 +10,8 @@ #include <vector> #include <vector/include/Vector.hpp> + +#include "common_types.hpp" #include "tree.hpp" #include "util.hpp" @@ -29,38 +31,42 @@ namespace mc { /// flexibility in choosing the root. class cell_tree { using range = memory::Range; -public : - // use a signed 16-bit integer for storage of indexes, which is reasonable given - // that typical cells have at most 1000-2000 segments - using int_type = int; + +public: + using int_type = cell_lid_type; + using size_type = cell_local_size_type; + using index_type = memory::HostVector<int_type>; using view_type = index_type::view_type; using const_view_type = index_type::const_view_type; + using tree = nest::mc::tree<int_type, size_type>; + static constexpr int_type no_parent = tree::no_parent; + /// default empty constructor cell_tree() = default; /// construct from a parent index - cell_tree(std::vector<int> const& parent_index) + cell_tree(std::vector<int_type> const& parent_index) { // handle the case of an empty parent list, which implies a single-compartment model if(parent_index.size()>0) { tree_ = tree(parent_index); } else { - tree_ = tree(std::vector<int>({0})); + tree_ = tree(std::vector<int_type>({0})); } } /// construct from a tree // copy constructor - cell_tree(tree const& t, int s) + cell_tree(tree const& t, int_type s) : tree_(t), soma_(s) { } // move constructor - cell_tree(tree&& t, int s) + cell_tree(tree&& t, int_type s) : tree_(std::move(t)), soma_(s) { } @@ -129,12 +135,12 @@ public : } /// returns the number of child segments of segment b - size_t num_children(size_t b) const { + size_type num_children(int_type b) const { return tree_.num_children(b); } /// returns a list of the children of segment b - const_view_type children(size_t b) const { + const_view_type children(int_type b) const { return tree_.children(b); } @@ -162,7 +168,7 @@ public : index_type depth_from_leaf() { tree::index_type depth(num_segments()); - depth_from_leaf(depth, 0); + depth_from_leaf(depth, int_type{0}); return depth; } @@ -170,7 +176,7 @@ public : { tree::index_type depth(num_segments()); depth[0] = 0; - depth_from_root(depth, 1); + depth_from_root(depth, int_type{1}); return depth; } @@ -179,12 +185,11 @@ private : /// helper type for sub-tree computation /// use in balance() struct sub_tree { - sub_tree(int r, int diam, int dpth) - : root(r), diameter(diam), depth(dpth) + sub_tree(int_type r, int_type diam, int_type dpth): + root(r), diameter(diam), depth(dpth) {} - void set(int r, int diam, int dpth) - { + void set(int r, int diam, int dpth) { root = r; diameter = diam; depth = dpth; @@ -198,15 +203,15 @@ private : "]"; } - int root; - int diameter; - int depth; + int_type root; + int_type diameter; + int_type depth; }; /// returns the index of the segment that would minimise the depth of the /// tree if used as the root segment int_type find_minimum_root() { - if(num_segments()==1) { + if (num_segments()==1) { return 0; } @@ -229,14 +234,14 @@ private : // walk has been completed to the root node, the node that has been // selected will be the root of the sub-tree with the largest diameter. sub_tree max_sub_tree(0, 0, 0); - auto distance_from_max_leaf = 1; + int_type distance_from_max_leaf = 1; auto pnt = max_leaf; auto pos = parent(max_leaf); - while(pos != -1) { + while(pos != no_parent) { for(auto c : children(pos)) { if(c!=pnt) { auto diameter = depth[c] + 1 + distance_from_max_leaf; - if(diameter>max_sub_tree.diameter) { + if (diameter>max_sub_tree.diameter) { max_sub_tree.set(pos, diameter, distance_from_max_leaf); } } @@ -266,7 +271,7 @@ private : return new_root; } - int_type depth_from_leaf(index_type& depth, int segment) + int_type depth_from_leaf(index_type& depth, int_type segment) { int_type max_depth = 0; for(auto c : children(segment)) { @@ -276,7 +281,7 @@ private : return max_depth+1; } - void depth_from_root(index_type& depth, int segment) + void depth_from_root(index_type& depth, int_type segment) { auto d = depth[parent(segment)] + 1; depth[segment] = d; diff --git a/src/common_types.hpp b/src/common_types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..da81ed5d65368af60095d8b9dab40aeede29467a --- /dev/null +++ b/src/common_types.hpp @@ -0,0 +1,55 @@ +#pragma once + +/* + * Common definitions for index types etc. across prototype simulator + * library. (Expect analogues in future versions to be template parameters?) + */ + +#include <iosfwd> +#include <type_traits> + +#include <util/lexcmp_def.hpp> + +namespace nest { +namespace mc { + +// For identifying cells globally. + +using cell_gid_type = std::uint32_t; + +// For sizes of collections of cells. + +using cell_size_type = typename std::make_unsigned<cell_gid_type>::type; + +// For indexes into cell-local data. +// +// Local indices for items within a particular cell-local collection should be +// zero-based and numbered contiguously. + +using cell_lid_type = std::uint32_t; + +// For counts of cell-local data. + +using cell_local_size_type = typename std::make_unsigned<cell_lid_type>::type; + +// For global identification of an item of cell local data. +// +// Items of cell_member_type must: +// +// * be associated with a unique cell, identified by the member `gid` +// (see: cell_gid_type); +// +// * identify an item within a cell-local collection by the member `index` +// (see: cell_lid_type). + +struct cell_member_type { + cell_gid_type gid; + cell_lid_type index; +}; + +DEFINE_LEXICOGRAPHIC_ORDERING(cell_member_type,(a.gid,a.index),(b.gid,b.index)) + +} // namespace mc +} // namespace nest + +std::ostream& operator<<(std::ostream& O, nest::mc::cell_member_type m); diff --git a/src/common_types_io.cpp b/src/common_types_io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad6ca540b80e1ce296c8ba2f4eceb52cd9ebfd63 --- /dev/null +++ b/src/common_types_io.cpp @@ -0,0 +1,8 @@ +#include <iostream> + +#include <common_types.hpp> + +std::ostream& operator<<(std::ostream& O, nest::mc::cell_member_type m) { + return O << m.gid << ':' << m.index; +} + diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index 34185ac0fbd795294211ced7190d8b84a3e70874..871bf65266da7c7ffcc962a41f9913d93fa89fc2 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -5,12 +5,12 @@ #include <vector> #include <random> -#include <communication/spike.hpp> -#include <threading/threading.hpp> #include <algorithms.hpp> +#include <connection.hpp> #include <event_queue.hpp> - -#include "connection.hpp" +#include <spike.hpp> +#include <threading/threading.hpp> +#include <util/debug.hpp> namespace nest { namespace mc { @@ -26,79 +26,37 @@ namespace communication { // Once all connections have been specified, the construct() method can be used // to build the data structures required for efficient spike communication and // event generation. -template <typename CommunicationPolicy> +template <typename Time, typename CommunicationPolicy> class communicator { public: - using id_type = uint32_t; + using id_type = cell_gid_type; + using time_type = Time; using communication_policy_type = CommunicationPolicy; - using spike_type = spike<id_type>; + using spike_type = spike<cell_member_type, time_type>; communicator() = default; - communicator(id_type n_groups, std::vector<id_type> target_counts) : - num_groups_local_(n_groups), - num_targets_local_(target_counts.size()) + // for now, still assuming one-to-one association cells <-> groups, + // so that 'group' gids as represented by their first cell gid are + // contiguous. + communicator(id_type cell_from, id_type cell_to): + cell_gid_from_(cell_from), cell_gid_to_(cell_to) { - target_map_ = nest::mc::algorithms::make_index(target_counts); - num_targets_local_ = target_map_.back(); + auto num_groups_local_ = cell_gid_to_-cell_gid_from_; // create an event queue for each target group events_.resize(num_groups_local_); - - // make maps for converting lid to gid - target_gid_map_ = communication_policy_.make_map(num_targets_local_); - group_gid_map_ = communication_policy_.make_map(num_groups_local_); - - // transform the target ids from lid to gid - auto first_target = target_gid_map_[domain_id()]; - for (auto &id : target_map_) { - id += first_target; - } } - id_type target_gid_from_group_lid(id_type lid) const { - EXPECTS(lid<num_groups_local_); - return target_map_[lid]; - } - - id_type group_gid_from_group_lid(id_type lid) const { - EXPECTS(lid<num_groups_local_); - return group_gid_map_[domain_id()] + lid; - } - void add_connection(connection con) { - EXPECTS(is_local_target(con.destination())); + void add_connection(connection<time_type> con) { + EXPECTS(is_local_cell(con.destination().gid)); connections_.push_back(con); } - bool is_local_target(id_type gid) { - return gid>=target_gid_map_[domain_id()] - && gid<target_gid_map_[domain_id()+1]; - } - - bool is_local_group(id_type gid) { - return gid>=group_gid_map_[domain_id()] - && gid<group_gid_map_[domain_id()+1]; - } - - /// return the global id of the first group in domain d - /// the groups in domain d are in the contiguous half open range - /// [domain_first_group(d), domain_first_group(d+1)) - id_type group_gid_first(int d) const { - return group_gid_map_[d]; - } - - id_type target_lid(id_type gid) { - EXPECTS(is_local_group(gid)); - - return gid - target_gid_map_[domain_id()]; - } - - id_type group_lid(id_type gid) { - EXPECTS(is_local_group(gid)); - - return gid - group_gid_map_[domain_id()]; + bool is_local_cell(id_type gid) const { + return gid>=cell_gid_from_ && gid<cell_gid_to_; } // builds the optimized data structure @@ -108,8 +66,8 @@ public: } } - float min_delay() { - auto local_min = std::numeric_limits<float>::max(); + time_type min_delay() { + auto local_min = std::numeric_limits<time_type>::max(); for (auto& con : connections_) { local_min = std::min(local_min, con.delay()); } @@ -117,19 +75,6 @@ public: return communication_policy_.min(local_min); } - // return the local group index of the group which hosts the target with - // global id gid - id_type local_group_from_global_target(id_type gid) { - // assert that gid is in range - EXPECTS(is_local_target(gid)); - - return - std::distance( - target_map_.begin(), - std::upper_bound(target_map_.begin(), target_map_.end(), gid) - ) - 1; - } - void add_spike(spike_type s) { thread_spikes().push_back(s); } @@ -144,22 +89,16 @@ public: } void exchange() { - // global all-to-all to gather a local copy of the global spike list // on each node - //profiler_.enter("global exchange"); auto global_spikes = communication_policy_.gather_spikes(local_spikes()); num_spikes_ += global_spikes.size(); clear_thread_spike_buffers(); - //profiler_.leave(); for (auto& q : events_) { q.clear(); } - //profiler_.enter("events"); - - //profiler_.enter("make events"); // check all global spikes to see if they will generate local events for (auto spike : global_spikes) { // search for targets @@ -170,35 +109,19 @@ public: // generate an event for each target for (auto it=targets.first; it!=targets.second; ++it) { - auto gidx = local_group_from_global_target(it->destination()); - + auto gidx = cell_group_index(it->destination().gid); events_[gidx].push_back(it->make_event(spike)); } } - - //profiler_.leave(); // make events - - //profiler_.leave(); // event generation } - uint64_t num_spikes() const - { - return num_spikes_; - } + uint64_t num_spikes() const { return num_spikes_; } - int domain_id() const { - return communication_policy_.id(); - } - - int num_domains() const { - return communication_policy_.size(); - } - - const std::vector<postsynaptic_spike_event>& queue(int i) const { + const std::vector<postsynaptic_spike_event<time_type>>& queue(int i) const { return events_[i]; } - const std::vector<connection>& connections() const { + const std::vector<connection<time_type>>& connections() const { return connections_; } @@ -206,10 +129,6 @@ public: return communication_policy_; } - const std::vector<id_type>& local_target_map() const { - return target_map_; - } - std::vector<spike_type> local_spikes() { std::vector<spike_type> spikes; for (auto& v : thread_spikes_) { @@ -224,7 +143,20 @@ public: } } + void reset() { + // remove all in-flight spikes/events + clear_thread_spike_buffers(); + for (auto& evbuf: events_) { + evbuf.clear(); + } + } + private: + std::size_t cell_group_index(cell_gid_type cell_gid) const { + // this will be more elaborate when there is more than one cell per cell group + EXPECTS(cell_gid>=cell_gid_from_ && cell_gid<cell_gid_to_); + return cell_gid-cell_gid_from_; + } // // both of these can be fixed with double buffering @@ -239,33 +171,17 @@ private: nest::mc::threading::enumerable_thread_specific<std::vector<spike_type>>; local_spike_store_type thread_spikes_; - std::vector<connection> connections_; - std::vector<std::vector<postsynaptic_spike_event>> events_; - - // local target group i has targets in the half open range - // [target_map_[i], target_map_[i+1]) - std::vector<id_type> target_map_; + std::vector<connection<time_type>> connections_; + std::vector<std::vector<postsynaptic_spike_event<time_type>>> events_; // for keeping track of how time is spent where //util::Profiler profiler_; - // the number of groups and targets handled by this communicator - id_type num_groups_local_; - id_type num_targets_local_; - - // index maps for the global distribution of groups and targets - - // communicator i has the groups in the half open range : - // [group_gid_map_[i], group_gid_map_[i+1]) - std::vector<id_type> group_gid_map_; - - // communicator i has the targets in the half open range : - // [target_gid_map_[i], target_gid_map_[i+1]) - std::vector<id_type> target_gid_map_; - communication_policy_type communication_policy_; uint64_t num_spikes_ = 0u; + id_type cell_gid_from_; + id_type cell_gid_to_; }; } // namespace communication diff --git a/src/communication/connection.hpp b/src/communication/connection.hpp deleted file mode 100644 index ee1acfd72dc21cffe01f28239bca29526366df31..0000000000000000000000000000000000000000 --- a/src/communication/connection.hpp +++ /dev/null @@ -1,70 +0,0 @@ -#pragma once - -#include <cstdint> - -#include <event_queue.hpp> -#include <communication/spike.hpp> - -namespace nest { -namespace mc { -namespace communication { - -class connection { -public: - using id_type = uint32_t; - connection(id_type src, id_type dest, float w, float d) : - source_(src), - destination_(dest), - weight_(w), - delay_(d) - {} - - float weight() const { return weight_; } - float delay() const { return delay_; } - - id_type source() const { return source_; } - id_type destination() const { return destination_; } - - postsynaptic_spike_event make_event(spike<id_type> s) { - return {destination_, s.time + delay_, weight_}; - } - -private: - - id_type source_; - id_type destination_; - float weight_; - float delay_; -}; - -// connections are sorted by source id -// these operators make for easy interopability with STL algorithms - -static inline -bool operator< (connection lhs, connection rhs) { - return lhs.source() < rhs.source(); -} - -static inline -bool operator< (connection lhs, connection::id_type rhs) { - return lhs.source() < rhs; -} - -static inline -bool operator< (connection::id_type lhs, connection rhs) { - return lhs < rhs.source(); -} - -} // namespace communication -} // namespace mc -} // namespace nest - -static inline -std::ostream& operator<<(std::ostream& o, nest::mc::communication::connection const& con) { - char buff[512]; - snprintf( - buff, sizeof(buff), "con [%10u -> %10u : weight %8.4f, delay %8.4f]", - con.source(), con.destination(), con.weight(), con.delay() - ); - return o << buff; -} diff --git a/src/communication/mpi_global_policy.hpp b/src/communication/mpi_global_policy.hpp index 7409585447ae888eec2aa47ddc1fa0f137df7d6e..0e9ec33370695990e2e2fa0d8d03d4a1e8b26c26 100644 --- a/src/communication/mpi_global_policy.hpp +++ b/src/communication/mpi_global_policy.hpp @@ -4,24 +4,23 @@ #error "mpi_global_policy.hpp should only be compiled in a WITH_MPI build" #endif +#include <cstdint> #include <type_traits> #include <vector> -#include <cstdint> - -#include <communication/spike.hpp> -#include <communication/mpi.hpp> #include <algorithms.hpp> +#include <common_types.hpp> +#include <communication/mpi.hpp> +#include <spike.hpp> namespace nest { namespace mc { namespace communication { struct mpi_global_policy { - using id_type = uint32_t; - - std::vector<spike<id_type>> - static gather_spikes(const std::vector<spike<id_type>>& local_spikes) { + template <typename I, typename T> + static std::vector<spike<I, T>> + gather_spikes(const std::vector<spike<I,T>>& local_spikes) { return mpi::gather_all(local_spikes); } diff --git a/src/communication/serial_global_policy.hpp b/src/communication/serial_global_policy.hpp index a1c7fee9671cb79768e255a710e0082386302e8a..9af5eae3fd0dd84be71cb6346018c607fa9ab3df 100644 --- a/src/communication/serial_global_policy.hpp +++ b/src/communication/serial_global_policy.hpp @@ -1,21 +1,19 @@ #pragma once +#include <cstdint> #include <type_traits> #include <vector> -#include <cstdint> - -#include <communication/spike.hpp> +#include <spike.hpp> namespace nest { namespace mc { namespace communication { struct serial_global_policy { - using id_type = uint32_t; - - std::vector<spike<id_type>> const - static gather_spikes(const std::vector<spike<id_type>>& local_spikes) { + template <typename I, typename T> + static const std::vector<spike<I, T>>& + gather_spikes(const std::vector<spike<I, T>>& local_spikes) { return local_spikes; } diff --git a/src/compartment.hpp b/src/compartment.hpp index b92360a560ba5de30a036505d50d457b7b5a5a9c..da6bbcf57298593a60f7510822832060c7a89846 100644 --- a/src/compartment.hpp +++ b/src/compartment.hpp @@ -3,6 +3,8 @@ #include <iterator> #include <utility> +#include "common_types.hpp" + namespace nest { namespace mc { @@ -10,7 +12,7 @@ namespace mc { /// The compartment is a conic frustrum struct compartment { using value_type = double; - using size_type = int; + using size_type = cell_local_size_type; using value_pair = std::pair<value_type, value_type>; compartment() = delete; @@ -103,8 +105,7 @@ class compartment_iterator : }; class compartment_range { - public: - +public: using size_type = compartment_iterator::size_type; using real_type = compartment_iterator::real_type; diff --git a/src/connection.hpp b/src/connection.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b9e40fb5baeb83ad69c87c219f65d9f98acc4456 --- /dev/null +++ b/src/connection.hpp @@ -0,0 +1,68 @@ +#pragma once + +#include <cstdint> + +#include <common_types.hpp> +#include <event_queue.hpp> +#include <spike.hpp> + +namespace nest { +namespace mc { + +template <typename Time> +class connection { +public: + using id_type = cell_member_type; + using time_type = Time; + + connection(id_type src, id_type dest, float w, time_type d) : + source_(src), + destination_(dest), + weight_(w), + delay_(d) + {} + + float weight() const { return weight_; } + float delay() const { return delay_; } + + id_type source() const { return source_; } + id_type destination() const { return destination_; } + + postsynaptic_spike_event<time_type> make_event(spike<id_type, time_type> s) { + return {destination_, s.time + delay_, weight_}; + } + +private: + id_type source_; + id_type destination_; + float weight_; + time_type delay_; +}; + +// connections are sorted by source id +// these operators make for easy interopability with STL algorithms + +template <typename T> +static inline bool operator<(connection<T> lhs, connection<T> rhs) { + return lhs.source() < rhs.source(); +} + +template <typename T> +static inline bool operator<(connection<T> lhs, typename connection<T>::id_type rhs) { + return lhs.source() < rhs; +} + +template <typename T> +static inline bool operator<(typename connection<T>::id_type lhs, connection<T> rhs) { + return lhs < rhs.source(); +} + +} // namespace mc +} // namespace nest + +template <typename T> +static inline std::ostream& operator<<(std::ostream& o, nest::mc::connection<T> const& con) { + return o << "con [" << con.source() << " -> " << con.destination() + << " : weight " << con.weight() + << ", delay " << con.delay() << "]"; +} diff --git a/src/event_queue.hpp b/src/event_queue.hpp index 7f3034f3f08bbb2235bf457c95e8eedfe8d0895d..31f6a5a10cf6b0c9ff9c0ff049909b801cd25b34 100644 --- a/src/event_queue.hpp +++ b/src/event_queue.hpp @@ -4,35 +4,49 @@ #include <ostream> #include <queue> +#include "common_types.hpp" #include "util/optional.hpp" namespace nest { namespace mc { +/* An event class Event must comply with the following conventions: + * Typedefs: + * time_type floating point type used to represent event times + * Member functions: + * time_type when() const return time value associated with event + */ + +template <typename Time> struct postsynaptic_spike_event { - uint32_t target; - float time; + using time_type = Time; + + cell_member_type target; + time_type time; float weight; -}; -inline float event_time(const postsynaptic_spike_event &ev) { return ev.time; } + time_type when() const { return time; } +}; +template <typename Time> struct sample_event { - uint32_t sampler_index; - float time; -}; + using time_type = Time; -inline float event_time(const sample_event &ev) { return ev.time; } + std::uint32_t sampler_index; + time_type time; + + time_type when() const { return time; } +}; /* Event objects must have a method event_time() which returns a value * from a type with a total ordering with respect to <, >, etc. - */ + */ template <typename Event> class event_queue { public : using value_type = Event; - using time_type = decltype(event_time(std::declval<Event>())); + using time_type = typename Event::time_type; // create event_queue() {} @@ -46,7 +60,7 @@ public : } // push thing - void push(const value_type &e) { + void push(const value_type& e) { queue_.push(e); } @@ -56,7 +70,7 @@ public : // pop until util::optional<value_type> pop_if_before(time_type t_until) { - if (!queue_.empty() && event_time(queue_.top()) < t_until) { + if (!queue_.empty() && queue_.top().when() < t_until) { auto ev = queue_.top(); queue_.pop(); return ev; @@ -66,10 +80,15 @@ public : } } + // clear everything + void clear() { + queue_ = decltype(queue_){}; + } + private: struct event_greater { - bool operator()(const Event &a, const Event &b) { - return event_time(a) > event_time(b); + bool operator()(const Event& a, const Event& b) { + return a.when() > b.when(); } }; @@ -83,8 +102,8 @@ private: } // namespace nest } // namespace mc -inline -std::ostream& operator<< (std::ostream& o, const nest::mc::postsynaptic_spike_event& e) +template <typename T> +inline std::ostream& operator<<(std::ostream& o, const nest::mc::postsynaptic_spike_event<T>& e) { return o << "event[" << e.target << "," << e.time << "," << e.weight << "]"; } diff --git a/src/fvm_cell.hpp b/src/fvm_cell.hpp index 148f8ac55be791e2efaefd70e444dfef3a458058..526fd22a391b477d5af3c88a4535b4721c83cbfe 100644 --- a/src/fvm_cell.hpp +++ b/src/fvm_cell.hpp @@ -13,7 +13,7 @@ #include <math.hpp> #include <matrix.hpp> #include <mechanism.hpp> -#include <mechanism_interface.hpp> +#include <mechanism_catalogue.hpp> #include <segment.hpp> #include <stimulus.hpp> #include <util.hpp> @@ -113,8 +113,9 @@ public: void advance(value_type dt); /// pass an event to the appropriate synapse and call net_receive - void apply_event(postsynaptic_spike_event e) { - mechanisms_[synapse_index_]->net_receive(e.target, e.weight); + template <typename Time> + void apply_event(postsynaptic_spike_event<Time> e) { + mechanisms_[synapse_index_]->net_receive(e.target.index, e.weight); } mechanism_type& synapses() { @@ -194,6 +195,9 @@ private: std::vector<std::pair<uint32_t, i_clamp>> stimulii_; std::vector<std::pair<const vector_type fvm_cell::*, uint32_t>> probes_; + + // mechanism factory + using mechanism_catalogue = nest::mc::mechanisms::catalogue<T, I>; }; //////////////////////////////////////////////////////////////////////////////// @@ -291,10 +295,10 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) // for each mechanism in the cell record the indexes of the segments that // contain the mechanism - std::map<std::string, std::vector<int>> mech_map; + std::map<std::string, std::vector<unsigned>> mech_map; - for(auto i=0; i<cell.num_segments(); ++i) { - for(const auto& mech : cell.segment(i)->mechanisms()) { + for (unsigned i=0; i<cell.num_segments(); ++i) { + for (const auto& mech : cell.segment(i)->mechanisms()) { // FIXME : Membrane has to be a proper mechanism, // because it is exposed via the public interface. // This if statement is bad @@ -308,12 +312,12 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) // instance. // TODO : this works well for density mechanisms (e.g. ion channels), but // does it work for point processes (e.g. synapses)? - for(auto& mech : mech_map) { - auto& helper = nest::mc::mechanisms::get_mechanism_helper(mech.first); + for (auto& mech : mech_map) { + //auto& helper = nest::mc::mechanisms::get_mechanism_helper(mech.first); // calculate the number of compartments that contain the mechanism auto num_comp = 0u; - for(auto seg : mech.second) { + for (auto seg : mech.second) { num_comp += segment_index_[seg+1] - segment_index_[seg]; } @@ -321,7 +325,7 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) // the mechanism index_type compartment_index(num_comp); auto pos = 0u; - for(auto seg : mech.second) { + for (auto seg : mech.second) { auto seg_size = segment_index_[seg+1] - segment_index_[seg]; std::iota( compartment_index.data() + pos, @@ -333,24 +337,24 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) // instantiate the mechanism mechanisms_.push_back( - helper->new_mechanism(voltage_, current_, compartment_index) + mechanism_catalogue::make(mech.first, voltage_, current_, compartment_index) + //helper->new_mechanism(voltage_, current_, compartment_index) ); } synapse_index_ = mechanisms_.size(); - std::map<std::string, std::vector<int>> syn_map; + std::map<std::string, std::vector<cell_lid_type>> syn_map; for (const auto& syn : cell.synapses()) { syn_map[syn.mechanism.name()].push_back(find_compartment_index(syn.location, graph)); } - for (const auto &syni : syn_map) { + for (const auto& syni : syn_map) { const auto& mech_name = syni.first; - auto& helper = nest::mc::mechanisms::get_mechanism_helper(mech_name); + // auto& helper = nest::mc::mechanisms::get_mechanism_helper(mech_name); index_type compartment_index(syni.second); - - auto mech = helper->new_mechanism(voltage_, current_, compartment_index); + auto mech = mechanism_catalogue::make(mech_name, voltage_, current_, compartment_index); mech->set_areas(cv_areas_); mechanisms_.push_back(std::move(mech)); } @@ -370,7 +374,7 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) } } } - std::vector<int> indexes(index_set.begin(), index_set.end()); + std::vector<cell_lid_type> indexes(index_set.begin(), index_set.end()); // create the ion state if(indexes.size()) { diff --git a/src/matrix.hpp b/src/matrix.hpp index d0ddb816adc5cb9bfb6c3c88562e0a41835920ec..79b7218bc97fa0b16469ac50dbb5a91dc58c0db6 100644 --- a/src/matrix.hpp +++ b/src/matrix.hpp @@ -164,7 +164,7 @@ class matrix { auto const ncells = num_cells(); // loop over submatrices - for(auto m=0; m<ncells; ++m) { + for (size_type m=0; m<ncells; ++m) { auto first = cell_index_[m]; auto last = cell_index_[m+1]; diff --git a/src/mechanism_catalogue.hpp b/src/mechanism_catalogue.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c9d3e6de2771a1b18fa05f104dd257b8c7fe6dd9 --- /dev/null +++ b/src/mechanism_catalogue.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include <map> +#include <stdexcept> +#include <string> + +#include <mechanism.hpp> +#include <mechanisms/hh.hpp> +#include <mechanisms/pas.hpp> +#include <mechanisms/expsyn.hpp> +#include <mechanisms/exp2syn.hpp> + +namespace nest { +namespace mc { +namespace mechanisms { + +template <typename T, typename I> +struct catalogue { + using view_type = typename mechanism<T, I>::view_type; + using index_view = typename mechanism<T, I>::index_view; + + static mechanism_ptr<T, I> make( + const std::string& name, + view_type vec_v, + view_type vec_i, + index_view node_indices) + { + auto entry = mech_map.find(name); + if (entry==mech_map.end()) { + throw std::out_of_range("no such mechanism"); + } + + return entry->second(vec_v, vec_i, node_indices); + } + + static bool has(const std::string& name) { + return mech_map.count(name)>0; + } + +private: + using maker_type = mechanism_ptr<T, I> (*)(view_type, view_type, index_view); + static const std::map<std::string, maker_type> mech_map; + + template <template <typename, typename> class mech> + static mechanism_ptr<T, I> maker(view_type vec_v, view_type vec_i, index_view node_indices) { + return make_mechanism<mech<T, I>>(vec_v, vec_i, node_indices); + } +}; + +template <typename T, typename I> +const std::map<std::string, typename catalogue<T, I>::maker_type> catalogue<T, I>::mech_map = { + { "pas", maker<pas::mechanism_pas> }, + { "hh", maker<hh::mechanism_hh> }, + { "expsyn", maker<expsyn::mechanism_expsyn> }, + { "exp2syn", maker<exp2syn::mechanism_exp2syn> } +}; + + +} // namespace mechanisms +} // namespace mc +} // namespace nest diff --git a/src/mechanism_interface.cpp b/src/mechanism_interface.cpp deleted file mode 100644 index 680e8739d11ee0dad16890bb1088a1a7d4b774b5..0000000000000000000000000000000000000000 --- a/src/mechanism_interface.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include "mechanism_interface.hpp" - -// -// include the mechanisms -// - -#include <mechanisms/hh.hpp> -#include <mechanisms/pas.hpp> -#include <mechanisms/expsyn.hpp> -#include <mechanisms/exp2syn.hpp> - - -namespace nest { -namespace mc { -namespace mechanisms { - -std::map<std::string, mechanism_helper_ptr<value_type, index_type>> mechanism_helpers; - -void setup_mechanism_helpers() { - mechanism_helpers["pas"] = - make_mechanism_helper< - mechanisms::pas::helper<value_type, index_type> - >(); - - mechanism_helpers["hh"] = - make_mechanism_helper< - mechanisms::hh::helper<value_type, index_type> - >(); - - mechanism_helpers["expsyn"] = - make_mechanism_helper< - mechanisms::expsyn::helper<value_type, index_type> - >(); - - mechanism_helpers["exp2syn"] = - make_mechanism_helper< - mechanisms::exp2syn::helper<value_type, index_type> - >(); -} - -mechanism_helper_ptr<value_type, index_type>& -get_mechanism_helper(const std::string& name) -{ - auto helper = mechanism_helpers.find(name); - if(helper==mechanism_helpers.end()) { - throw std::out_of_range( - nest::mc::util::pprintf("there is no mechanism named \'%\'", name) - ); - } - - return helper->second; -} - -} // namespace mechanisms -} // namespace nest -} // namespace mc - diff --git a/src/mechanism_interface.hpp b/src/mechanism_interface.hpp index c59caaad6c655d6a69c5206b956b1cb42959cfdf..5e5360e6602761ae2923c6c2afec958279235dca 100644 --- a/src/mechanism_interface.hpp +++ b/src/mechanism_interface.hpp @@ -1,7 +1,6 @@ #pragma once -#include <map> -#include <string> +// just for compatibility with current version of modparser... #include "mechanism.hpp" #include "parameter_list.hpp" @@ -10,11 +9,6 @@ namespace nest { namespace mc { namespace mechanisms { -using value_type = double; -using index_type = int; - -/// helper type for building mechanisms -/// the use of abstract base classes everywhere is a bit ugly template <typename T, typename I> struct mechanism_helper { using value_type = T; @@ -25,34 +19,10 @@ struct mechanism_helper { using view_type = typename mechanism<T,I>::view_type; virtual std::string name() const = 0; - virtual mechanism_ptr<T,I> new_mechanism(view_type, view_type, index_view) const = 0; - virtual void set_parameters(mechanism_ptr_type&, parameter_list const&) const = 0; }; -template <typename T, typename I> -using mechanism_helper_ptr = - std::unique_ptr<mechanism_helper<T,I>>; - -template <typename M> -mechanism_helper_ptr<typename M::value_type, typename M::size_type> -make_mechanism_helper() -{ - return util::make_unique<M>(); -} - -// for now use a global variable for the map of mechanism helpers -extern std::map< - std::string, - mechanism_helper_ptr<value_type, index_type> -> mechanism_helpers; - -void setup_mechanism_helpers(); - -mechanism_helper_ptr<value_type, index_type>& -get_mechanism_helper(const std::string& name); - } // namespace mechanisms } // namespace mc } // namespace nest diff --git a/src/model.hpp b/src/model.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3812a44bfffaa21397d12953c93fcd13bb3843d9 --- /dev/null +++ b/src/model.hpp @@ -0,0 +1,136 @@ +#pragma once + +#include <cstdlib> +#include <vector> + +#include <common_types.hpp> +#include <cell.hpp> +#include <cell_group.hpp> +#include <communication/communicator.hpp> +#include <communication/global_policy.hpp> +#include <fvm_cell.hpp> +#include <recipe.hpp> +#include <profiling/profiler.hpp> + +#include "trace_sampler.hpp" + +namespace nest { +namespace mc { + +template <typename Cell> +class model { +public: + using cell_group_type = cell_group<Cell>; + using time_type = typename cell_group_type::time_type; + using value_type = typename cell_group_type::value_type; + using communicator_type = communication::communicator<time_type, communication::global_policy>; + using sampler_function = typename cell_group_type::sampler_function; + + struct probe_record { + cell_member_type id; + probe_spec probe; + }; + + model(const recipe& rec, cell_gid_type cell_from, cell_gid_type cell_to): + cell_from_(cell_from), + cell_to_(cell_to), + communicator_(cell_from, cell_to) + { + cell_groups_ = std::vector<cell_group_type>{cell_to_-cell_from_}; + + threading::parallel_vector<probe_record> probes; + threading::parallel_for::apply(cell_from_, cell_to_, + [&](cell_gid_type i) { + PE("setup", "cells"); + auto cell = rec.get_cell(i); + auto idx = i-cell_from_; + cell_groups_[idx] = cell_group_type(i, cell); + + cell_lid_type j = 0; + for (const auto& probe: cell.probes()) { + cell_member_type probe_id{i,j++}; + probes.push_back({probe_id, probe}); + } + PL(2); + }); + + probes_.assign(probes.begin(), probes.end()); + + for (cell_gid_type i=cell_from_; i<cell_to_; ++i) { + for (const auto& cc: rec.connections_on(i)) { + connection<time_type> conn{cc.source, cc.dest, cc.weight, cc.delay}; + communicator_.add_connection(conn); + } + } + communicator_.construct(); + } + + void reset() { + t_ = 0.; + for (auto& group: cell_groups_) { + group.reset(); + } + communicator_.reset(); + } + + time_type run(time_type tfinal, time_type dt) { + time_type min_delay = communicator_.min_delay(); + while (t_<tfinal) { + auto tuntil = std::min(t_+min_delay, tfinal); + threading::parallel_for::apply( + 0u, cell_groups_.size(), + [&](unsigned i) { + auto& group = cell_groups_[i]; + + PE("stepping","events"); + group.enqueue_events(communicator_.queue(i)); + PL(); + + group.advance(tuntil, dt); + + PE("events"); + communicator_.add_spikes(group.spikes()); + group.clear_spikes(); + PL(2); + }); + + PE("stepping", "exchange"); + communicator_.exchange(); + PL(2); + + t_ = tuntil; + } + return t_; + } + + void add_artificial_spike(cell_member_type source) { + add_artificial_spike(source, t_); + } + + void add_artificial_spike(cell_member_type source, time_type tspike) { + communicator_.add_spike({source, tspike}); + } + + void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0) { + // TODO: translate probe_id.gid to appropriate group, but for now 1-1. + if (probe_id.gid<cell_from_ || probe_id.gid>=cell_to_) { + return; + } + cell_groups_[probe_id.gid-cell_from_].add_sampler(probe_id, f, tfrom); + } + + const std::vector<probe_record>& probes() const { return probes_; } + + std::size_t num_spikes() const { return communicator_.num_spikes(); } + +private: + cell_gid_type cell_from_; + cell_gid_type cell_to_; + time_type t_ = 0.; + std::vector<cell_group_type> cell_groups_; + communicator_type communicator_; + std::vector<probe_record> probes_; +}; + +} // namespace mc +} // namespace nest diff --git a/src/profiling/profiler.cpp b/src/profiling/profiler.cpp index b5d2ad0d8d7bd9efe7fd8e55e5f6720ab051b3f1..1e30449ab4b225385570292bea527fbcc6981c3d 100644 --- a/src/profiling/profiler.cpp +++ b/src/profiling/profiler.cpp @@ -66,7 +66,7 @@ void profiler_node::print_sub( if (print_children) { auto other = 0.; - for (auto &n : children) { + for (auto& n : children) { if (n.value<threshold || n.name=="other") { other += n.value; } @@ -162,7 +162,7 @@ double region_type::subregion_contributions() const { profiler_node region_type::populate_performance_tree() const { profiler_node tree(total(), name()); - for (auto &it : subregions_) { + for (auto& it : subregions_) { tree.children.push_back(it.second->populate_performance_tree()); } diff --git a/src/recipe.hpp b/src/recipe.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b451a1a1cffbd2139707e8f8d8a60fa6e2f4e037 --- /dev/null +++ b/src/recipe.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include <cstddef> +#include <memory> +#include <stdexcept> + +namespace nest { +namespace mc { + +struct cell_count_info { + cell_size_type num_sources; + cell_size_type num_targets; + cell_size_type num_probes; +}; + +class invalid_recipe_error: public std::runtime_error { +public: + invalid_recipe_error(std::string whatstr): std::runtime_error(std::move(whatstr)) {} +}; + +/* Recipe descriptions are cell-oriented: in order that the building + * phase can be done distributedly and in order that the recipe + * description can be built indepdently of any runtime execution + * environment, connection end-points are represented by pairs + * (cell index, source/target index on cell). + */ + +using cell_connection_endpoint = cell_member_type; + +// Note: `cell_connection` and `connection` have essentially the same data +// and represent the same thing conceptually. `cell_connection` objects +// are notionally described in terms of external cell identifiers instead +// of internal gids, but we are not making the distinction between the +// two in the current code. These two types could well be merged. + +struct cell_connection { + cell_connection_endpoint source; + cell_connection_endpoint dest; + + float weight; + float delay; +}; + +class recipe { +public: + virtual cell_size_type num_cells() const =0; + + virtual cell get_cell(cell_gid_type) const =0; + virtual cell_count_info get_cell_count_info(cell_gid_type) const =0; + virtual std::vector<cell_connection> connections_on(cell_gid_type) const =0; +}; + +} // namespace mc +} // namespace nest diff --git a/src/segment.hpp b/src/segment.hpp index c6c49b26a946d36fdc969d9473fa07105d30d4af..248306a6c3a976d123b9d412b93c09156305086d 100644 --- a/src/segment.hpp +++ b/src/segment.hpp @@ -4,6 +4,7 @@ #include <vector> #include "algorithms.hpp" +#include "common_types.hpp" #include "compartment.hpp" #include "math.hpp" #include "parameter_list.hpp" @@ -36,6 +37,7 @@ class segment { public: using value_type = double; + using size_type = cell_local_size_type; using point_type = point<value_type>; segmentKind kind() const { @@ -57,8 +59,8 @@ class segment { return kind_==segmentKind::axon; } - virtual int num_compartments() const = 0; - virtual void set_compartments(int) = 0; + virtual size_type num_compartments() const = 0; + virtual void set_compartments(size_type) = 0; virtual value_type volume() const = 0; virtual value_type area() const = 0; @@ -153,13 +155,12 @@ class segment { std::vector<parameter_list> mechanisms_; }; -class placeholder_segment : public segment -{ - public: - +class placeholder_segment : public segment { +public: using base = segment; using base::kind_; using base::value_type; + using base::size_type; placeholder_segment() { @@ -181,23 +182,22 @@ class placeholder_segment : public segment return true; } - int num_compartments() const override + size_type num_compartments() const override { return 0; } - virtual void set_compartments(int) override + virtual void set_compartments(size_type) override { } }; -class soma_segment : public segment -{ - public : - +class soma_segment : public segment { +public: using base = segment; using base::kind_; using base::value_type; using base::point_type; + using base::size_type; soma_segment() = delete; @@ -208,7 +208,7 @@ class soma_segment : public segment mechanisms_.push_back(membrane_parameters()); } - soma_segment(value_type r, point_type const &c) + soma_segment(value_type r, point_type const& c) : soma_segment(r) { center_ = c; @@ -245,12 +245,12 @@ class soma_segment : public segment } /// soma has one and one only compartments - int num_compartments() const override + size_type num_compartments() const override { return 1; } - void set_compartments(int n) override + void set_compartments(size_type n) override { } private : @@ -261,10 +261,8 @@ class soma_segment : public segment point_type center_; }; -class cable_segment : public segment -{ - public : - +class cable_segment : public segment { +public: using base = segment; using base::kind_; using base::value_type; @@ -332,7 +330,7 @@ class cable_segment : public segment value_type volume() const override { auto sum = value_type{0}; - for(auto i=0; i<num_sub_segments(); ++i) { + for (auto i=0u; i<num_sub_segments(); ++i) { sum += math::volume_frustrum(lengths_[i], radii_[i], radii_[i+1]); } return sum; @@ -341,7 +339,7 @@ class cable_segment : public segment value_type area() const override { auto sum = value_type{0}; - for(auto i=0; i<num_sub_segments(); ++i) { + for (auto i=0u; i<num_sub_segments(); ++i) { sum += math::area_frustrum(lengths_[i], radii_[i], radii_[i+1]); } return sum; @@ -358,7 +356,7 @@ class cable_segment : public segment } // the number sub-segments that define the cable segment - int num_sub_segments() const + size_type num_sub_segments() const { return radii_.size()-1; } @@ -378,12 +376,12 @@ class cable_segment : public segment return this; } - int num_compartments() const override + size_type num_compartments() const override { return num_compartments_; } - void set_compartments(int n) override + void set_compartments(size_type n) override { if(n<1) { throw std::out_of_range( @@ -406,8 +404,8 @@ class cable_segment : public segment // we find ourselves having to do it over and over again. // The time to cache it might be when update_lengths() is called. auto sum = value_type(0); - auto i=0; - for(i=0; i<num_sub_segments(); ++i) { + size_type i = 0; + for (i = 0; i<num_sub_segments(); ++i) { if(sum+lengths_[i]>pos) { break; } @@ -425,19 +423,18 @@ class cable_segment : public segment return {num_compartments(), radii_.front(), radii_.back(), length()}; } - private : - +private: void update_lengths() { - if(locations_.size()) { + if (locations_.size()) { lengths_.resize(num_sub_segments()); - for(auto i=0; i<num_sub_segments(); ++i) { + for (size_type i=0; i<num_sub_segments(); ++i) { lengths_[i] = norm(locations_[i] - locations_[i+1]); } } } - int num_compartments_ = 1; + size_type num_compartments_ = 1; std::vector<value_type> lengths_; std::vector<value_type> radii_; std::vector<point_type> locations_; diff --git a/src/communication/spike.hpp b/src/spike.hpp similarity index 53% rename from src/communication/spike.hpp rename to src/spike.hpp index 03b5ed75def7a2fb68bd491ae7f1208d4b230b8a..d3ea02551456d2408712886dcff8125b314630e9 100644 --- a/src/communication/spike.hpp +++ b/src/spike.hpp @@ -1,48 +1,39 @@ #pragma once -#include <type_traits> #include <ostream> +#include <type_traits> namespace nest { namespace mc { -namespace communication { -template < - typename I, - typename = typename std::enable_if<std::is_integral<I>::value> -> +template <typename I, typename Time> struct spike { using id_type = I; - id_type source = 0; - float time = -1.; + using time_type = Time; + + id_type source = id_type{}; + time_type time = -1.; spike() = default; - spike(id_type s, float t) : + spike(id_type s, time_type t) : source(s), time(t) {} }; } // namespace mc } // namespace nest -} // namespace communication /// custom stream operator for printing nest::mc::spike<> values -template <typename I> -std::ostream& operator<<( - std::ostream& o, - nest::mc::communication::spike<I> s) -{ +template <typename I, typename T> +std::ostream& operator<<(std::ostream& o, nest::mc::spike<I, T> s) { return o << "spike[t " << s.time << ", src " << s.source << "]"; } /// less than comparison operator for nest::mc::spike<> values /// spikes are ordered by spike time, for use in sorting and queueing -template <typename I> -bool operator<( - nest::mc::communication::spike<I> lhs, - nest::mc::communication::spike<I> rhs) -{ +template <typename I, typename T> +bool operator<(nest::mc::spike<I, T> lhs, nest::mc::spike<I, T> rhs) { return lhs.time < rhs.time; } diff --git a/src/communication/spike_source.hpp b/src/spike_source.hpp similarity index 84% rename from src/communication/spike_source.hpp rename to src/spike_source.hpp index 95099cb547de24e76386c2253960b8fffe4bf569..0537b6800a19534d23aedd7de2edae3d022b53e6 100644 --- a/src/communication/spike_source.hpp +++ b/src/spike_source.hpp @@ -13,13 +13,11 @@ class spike_detector public: using cell_type = Cell; - spike_detector( const cell_type& cell, segment_location loc, double thresh, float t_init) : + spike_detector(const cell_type& cell, segment_location loc, double thresh, float t_init) : location_(loc), - threshold_(thresh), - previous_t_(t_init) + threshold_(thresh) { - previous_v_ = cell.voltage(location_); - is_spiking_ = previous_v_ >= thresh ? true : false; + reset(cell, t_init); } util::optional<float> test(const cell_type& cell, float t) { @@ -58,6 +56,12 @@ public: float v() const { return previous_v_; } + void reset(const cell_type& cell, float t_init) { + previous_t_ = t_init; + previous_v_ = cell.voltage(location_); + is_spiking_ = previous_v_ >= threshold_; + } + private: // parameters/data diff --git a/src/threading/serial.hpp b/src/threading/serial.hpp index fb66e6e2ff78ec7eb638091ded2ac874583227e6..8d9c4d64fd84805ed0909ad704ffc192bd0da748 100644 --- a/src/threading/serial.hpp +++ b/src/threading/serial.hpp @@ -7,6 +7,7 @@ #include <array> #include <chrono> #include <string> +#include <vector> namespace nest { namespace mc { @@ -56,6 +57,10 @@ struct parallel_for { } }; +template <typename T> +using parallel_vector = std::vector<T>; + + inline std::string description() { return "serial"; } diff --git a/src/threading/tbb.hpp b/src/threading/tbb.hpp index eae6b112bdb75b697709379b41f00cfbc249a3b1..8e0086741cac1d51e02e9fa148e9eaa3f3f837a3 100644 --- a/src/threading/tbb.hpp +++ b/src/threading/tbb.hpp @@ -46,6 +46,9 @@ struct timer { constexpr bool multithreaded() { return true; } +template <typename T> +using parallel_vector = tbb::concurrent_vector<T>; + } // threading } // mc } // nest diff --git a/src/tree.hpp b/src/tree.hpp index 98ff9aa43656f9d68b879297a6ffd5399348b6f0..cd59537411e90c7b5efa696ebe8d5baed9c5ea1a 100644 --- a/src/tree.hpp +++ b/src/tree.hpp @@ -12,16 +12,18 @@ namespace nest { namespace mc { +template <typename Int, typename Size = std::size_t> class tree { using range = memory::Range; - public : - - using int_type = int; +public: + using int_type = Int; + using size_type = Size; using index_type = memory::HostVector<int_type>; - using view_type = index_type::view_type; - using const_view_type = index_type::const_view_type; + using view_type = typename index_type::view_type; + using const_view_type = typename index_type::const_view_type; + static constexpr int_type no_parent = (int_type)-1; tree() = default; @@ -67,12 +69,12 @@ class tree { init(new_parent_index.size()); parents_(memory::all) = new_parent_index; - parents_[0] = -1; + parents_[0] = no_parent; child_index_(memory::all) = algorithms::make_index(algorithms::child_count(parents_)); - std::vector<int> pos(parents_.size(), 0); + std::vector<int_type> pos(parents_.size(), 0); for (auto i = 1u; i < parents_.size(); ++i) { auto p = parents_[i]; children_[child_index_[p] + pos[p]] = i; @@ -80,16 +82,18 @@ class tree { } } - size_t num_children() const { - return children_.size(); + size_type num_children() const { + return static_cast<size_type>(children_.size()); } - size_t num_children(size_t b) const { + + size_type num_children(size_t b) const { return child_index_[b+1] - child_index_[b]; } - size_t num_nodes() const { + + size_type num_nodes() const { // the number of nodes is the size of the child index minus 1 // ... except for the case of an empty tree - auto sz = child_index_.size(); + auto sz = static_cast<size_type>(child_index_.size()); return sz ? sz - 1 : 0; } @@ -104,7 +108,7 @@ class tree { } /// return the list of all children of branch b - const_view_type children(size_t b) const { + const_view_type children(size_type b) const { return children_(child_index_[b], child_index_[b+1]); } @@ -122,7 +126,7 @@ class tree { } /// memory used to store tree (in bytes) - size_t memory() const { + std::size_t memory() const { return sizeof(int_type)*data_.size() + sizeof(tree); } @@ -164,17 +168,16 @@ class tree { return p; } - private : - - void init(int nnode) { - auto nchild = nnode -1; +private: + void init(size_type nnode) { + auto nchild = nnode - 1; data_ = index_type(nchild + (nnode + 1) + nnode); set_ranges(nnode); } - void set_ranges(int nnode) { - if(nnode) { + void set_ranges(size_type nnode) { + if (nnode) { auto nchild = nnode - 1; // data_ is partitioned as follows: // data_ = [children_[nchild], child_index_[nnode+1], parents_[nnode]] @@ -215,17 +218,17 @@ class tree { /// new_node /// p : permutation vector, p[i] is the new index of node i in the old /// tree - int add_children( - int new_node, - int old_node, - int parent_node, + int_type add_children( + int_type new_node, + int_type old_node, + int_type parent_node, view_type p, tree const& old_tree ) { - // check for the senitel that indicates that the old root has + // check for the sentinel that indicates that the old root has // been processed - if(old_node==-1) { + if (old_node==no_parent) { return new_node; } @@ -237,7 +240,7 @@ class tree { auto this_node = new_node; auto pos = child_index_[this_node]; - auto add_parent_as_child = parent_node>=0 && old_node>0; + auto add_parent_as_child = parent_node!=no_parent && old_node>0; // // STEP 1 : add the child indexes for this_node // @@ -259,12 +262,12 @@ class tree { // STEP 2 : recursively add each child's children // new_node++; - for(auto b : old_children) { - if(b != parent_node) { - new_node = add_children(new_node, b, -1, p, old_tree); + for (auto b : old_children) { + if (b != parent_node) { + new_node = add_children(new_node, b, no_parent, p, old_tree); } } - if(add_parent_as_child) { + if (add_parent_as_child) { new_node = add_children( new_node, old_tree.parent(old_node), old_node, p, old_tree @@ -286,38 +289,38 @@ class tree { view_type parents_ = data_(0, 0); }; -template <typename C> -std::vector<int> make_parent_index(tree const& t, C const& counts) +template <typename IntT, typename SizeT, typename C> +std::vector<IntT> make_parent_index(tree<IntT, SizeT> const& t, C const& counts) { using range = memory::Range; + using int_type = typename tree<IntT, SizeT>::int_type; + constexpr auto no_parent = tree<IntT, SizeT>::no_parent; - if( !algorithms::is_positive(counts) - || counts.size() != t.num_nodes() ) - { + if (!algorithms::is_positive(counts) || counts.size() != t.num_nodes()) { throw std::domain_error( "make_parent_index requires one non-zero count per segment" ); } auto index = algorithms::make_index(counts); auto num_compartments = index.back(); - std::vector<int> parent_index(num_compartments); - auto pos = 0; - for(int i : range(0, t.num_nodes())) { + std::vector<int_type> parent_index(num_compartments); + int_type pos = 0; + for (int_type i : range(0, t.num_nodes())) { // get the parent of this segment // taking care for the case where the root node has -1 as its parent auto parent = t.parent(i); - parent = parent>=0 ? parent : 0; + parent = parent!=no_parent ? parent : 0; // the index of the first compartment in the segment // is calculated differently for the root (i.e when i==parent) - if(i!=parent) { + if (i!=parent) { parent_index[pos++] = index[parent+1]-1; } else { parent_index[pos++] = parent; } // number the remaining compartments in the segment consecutively - while(pos<index[i+1]) { + while (pos<index[i+1]) { parent_index[pos] = pos-1; pos++; } diff --git a/src/util.hpp b/src/util.hpp index 29a6b28a9ec09cce13fcced0db20ae30aa3d803f..243987e628685ee4d4414394b054bebc6e7646a0 100644 --- a/src/util.hpp +++ b/src/util.hpp @@ -18,7 +18,7 @@ using memory::util::cyan; template <typename T> std::ostream& -operator << (std::ostream &o, std::vector<T>const& v) +operator << (std::ostream& o, std::vector<T>const& v) { o << "["; for(auto const& i: v) { @@ -29,7 +29,7 @@ operator << (std::ostream &o, std::vector<T>const& v) } template <typename T> -std::ostream& print(std::ostream &o, std::vector<T>const& v) +std::ostream& print(std::ostream& o, std::vector<T>const& v) { o << "["; for(auto const& i: v) { diff --git a/src/util/debug.hpp b/src/util/debug.hpp index 1d0ef63346068876f8366a0d423a8887fe4a7470..f258b37123755c7ec7a7324641249ed35290e8bb 100644 --- a/src/util/debug.hpp +++ b/src/util/debug.hpp @@ -11,7 +11,7 @@ namespace mc { namespace util { bool failed_assertion(const char* assertion, const char* file, int line, const char* func); -std::ostream &debug_emit_trace_leader(std::ostream& out, const char* file, int line, const char* varlist); +std::ostream& debug_emit_trace_leader(std::ostream& out, const char* file, int line, const char* varlist); inline void debug_emit(std::ostream& out) { out << "\n"; diff --git a/src/util/ioutil.hpp b/src/util/ioutil.hpp index da73bd800e48c66364b0fd884e643a15797f1c8f..2feab394df71d730bde904d041ebbeeb85bf9a44 100644 --- a/src/util/ioutil.hpp +++ b/src/util/ioutil.hpp @@ -1,6 +1,6 @@ #pragma once -#include <ios> +#include <iostream> namespace nest { namespace mc { @@ -24,6 +24,75 @@ private: }; +template <typename charT, typename traitsT = std::char_traits<charT> > +class basic_null_streambuf: public std::basic_streambuf<charT, traitsT> { +private: + typedef typename std::basic_streambuf<charT, traitsT> streambuf_type; + +public: + typedef typename streambuf_type::char_type char_type; + typedef typename streambuf_type::int_type int_type; + typedef typename streambuf_type::pos_type pos_type; + typedef typename streambuf_type::off_type off_type; + typedef typename streambuf_type::traits_type traits_type; + + virtual ~basic_null_streambuf() {} + +protected: + std::streamsize xsputn(const char_type* s, std::streamsize count) override { + return count; + } + + int_type overflow(int_type c) override { + return traits_type::not_eof(c); + } +}; + +class mask_stream { +public: + explicit mask_stream(bool mask): mask_(mask) {} + + operator bool() const { return mask_; } + + template <typename charT, typename traitsT> + friend std::basic_ostream<charT, traitsT>& + operator<<(std::basic_ostream<charT, traitsT>& O, const mask_stream& F) { + int xindex = get_xindex(); + + std::basic_streambuf<charT, traitsT>* saved_streambuf = + static_cast<std::basic_streambuf<charT, traitsT>*>(O.pword(xindex)); + + if (F.mask_ && saved_streambuf) { + // re-enable by restoring saved streambuf + O.pword(xindex) = 0; + O.rdbuf(saved_streambuf); + } + else if (!F.mask_ && !saved_streambuf) { + // disable stream but save old streambuf + O.pword(xindex) = O.rdbuf(); + O.rdbuf(get_null_streambuf<charT, traitsT>()); + } + + return O; + } + +private: + // our key for retrieve saved streambufs. + static int get_xindex() { + static int xindex = std::ios_base::xalloc(); + return xindex; + } + + template <typename charT, typename traitsT> + static std::basic_streambuf<charT, traitsT>* get_null_streambuf() { + static basic_null_streambuf<charT, traitsT> the_null_streambuf; + return &the_null_streambuf; + } + + // true => do not filter + bool mask_; +}; + } // namespace util } // namespace mc } // namespace nest diff --git a/src/util/lexcmp_def.hpp b/src/util/lexcmp_def.hpp index d1c3e8a3bfc63341070b55b1320de03583797bbe..cfb297b1d16d3dbd7023bb7f300271eda0af1e41 100644 --- a/src/util/lexcmp_def.hpp +++ b/src/util/lexcmp_def.hpp @@ -20,7 +20,7 @@ #include <tuple> #define DEFINE_LEXICOGRAPHIC_ORDERING_IMPL_(proxy,op,type,a_fields,b_fields) \ -inline bool operator op(const type &a,const type &b) { return proxy a_fields op proxy b_fields; } +inline bool operator op(const type& a,const type& b) { return proxy a_fields op proxy b_fields; } #define DEFINE_LEXICOGRAPHIC_ORDERING(type,a_fields,b_fields) \ DEFINE_LEXICOGRAPHIC_ORDERING_IMPL_(std::tie,<,type,a_fields,b_fields) \ diff --git a/src/util/optional.hpp b/src/util/optional.hpp index 92191dfff7adf9af084a2984b19210565da761e4..6b4d0f8b66cbdcf24c59914466a01621895ca2bb 100644 --- a/src/util/optional.hpp +++ b/src/util/optional.hpp @@ -29,7 +29,7 @@ namespace util { template <typename X> struct optional; struct optional_unset_error: std::runtime_error { - explicit optional_unset_error(const std::string &what_str) + explicit optional_unset_error(const std::string& what_str) : std::runtime_error(what_str) {} @@ -39,7 +39,7 @@ struct optional_unset_error: std::runtime_error { }; struct optional_invalid_dereference: std::runtime_error { - explicit optional_invalid_dereference(const std::string &what_str) + explicit optional_invalid_dereference(const std::string& what_str) : std::runtime_error(what_str) {} diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 61d713bc831e7ec63dfaf370f7931e204c7d9e45..4ba5bb8b0be6764e84fb24ea1c16ee803f6bc516 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -17,6 +17,7 @@ set(TEST_SOURCES test_fvm.cpp test_cell_group.cpp test_lexcmp.cpp + test_mask_stream.cpp test_matrix.cpp test_mechanisms.cpp test_optional.cpp diff --git a/tests/unit/test_cell.cpp b/tests/unit/test_cell.cpp index 88c8dcaced167d4030529adb8a348a116158c17f..3102d97435d3095b4ef91b885b040ac084f37a5f 100644 --- a/tests/unit/test_cell.cpp +++ b/tests/unit/test_cell.cpp @@ -1,6 +1,6 @@ #include "gtest.h" -#include "../src/cell.hpp" +#include "cell.hpp" TEST(cell_type, soma) { @@ -59,7 +59,7 @@ TEST(cell_type, add_segment) ); c.add_cable(0, std::move(seg)); - EXPECT_EQ(c.num_segments(), 2); + EXPECT_EQ(c.num_segments(), 2u); } // add segment on the fly @@ -78,7 +78,7 @@ TEST(cell_type, add_segment) segmentKind::dendrite, cable_radius, cable_radius, cable_length ); - EXPECT_EQ(c.num_segments(), 2); + EXPECT_EQ(c.num_segments(), 2u); } { cell c; @@ -97,7 +97,7 @@ TEST(cell_type, add_segment) std::vector<double>{cable_length, cable_length, cable_length} ); - EXPECT_EQ(c.num_segments(), 2); + EXPECT_EQ(c.num_segments(), 2u); } } @@ -137,7 +137,7 @@ TEST(cell_type, multiple_cables) c.add_cable(1, seg(segmentKind::dendrite)); c.add_cable(1, seg(segmentKind::dendrite)); - EXPECT_EQ(c.num_segments(), 5); + EXPECT_EQ(c.num_segments(), 5u); // each of the 5 segments has volume 1 by design EXPECT_EQ(c.volume(), 5.); // each of the 4 cables has volume 2., and the soma has an awkward area @@ -148,12 +148,14 @@ TEST(cell_type, multiple_cables) const auto model = c.model(); auto const& con = model.tree; + auto no_parent = cell_tree::no_parent; + EXPECT_EQ(con.num_segments(), 5u); - EXPECT_EQ(con.parent(0), -1); - EXPECT_EQ(con.parent(1), 0); - EXPECT_EQ(con.parent(2), 0); - EXPECT_EQ(con.parent(3), 1); - EXPECT_EQ(con.parent(4), 1); + EXPECT_EQ(con.parent(0), no_parent); + EXPECT_EQ(con.parent(1), 0u); + EXPECT_EQ(con.parent(2), 0u); + EXPECT_EQ(con.parent(3), 1u); + EXPECT_EQ(con.parent(4), 1u); EXPECT_EQ(con.num_children(0), 2u); EXPECT_EQ(con.num_children(1), 2u); EXPECT_EQ(con.num_children(2), 0u); diff --git a/tests/unit/test_cell_group.cpp b/tests/unit/test_cell_group.cpp index d9787a1935657391be82f2e674b13bab9ef908ed..3aed222029f4535a85d9713849cff184a3684531 100644 --- a/tests/unit/test_cell_group.cpp +++ b/tests/unit/test_cell_group.cpp @@ -1,16 +1,12 @@ -#include <limits> - #include "gtest.h" +#include <common_types.hpp> #include <fvm_cell.hpp> #include <cell_group.hpp> nest::mc::cell make_cell() { using namespace nest::mc; - // setup global state for the mechanisms - mechanisms::setup_mechanism_helpers(); - nest::mc::cell cell; // Soma with diameter 12.6157 um and HH channel @@ -24,10 +20,8 @@ nest::mc::cell make_cell() { dendrite->mechanism("membrane").set("r_L", 100); - // add stimulus - cell.add_stimulus({1,1}, {5., 80., 0.3}); - - cell.add_detector({0,0}, 0); + cell.add_detector({0, 0}, 0); + cell.add_stimulus({1, 1}, {5., 80., 0.3}); return cell; } @@ -36,13 +30,45 @@ TEST(cell_group, test) { using namespace nest::mc; - using cell_type = cell_group<fvm::fvm_cell<double, int>>; - - auto cell = cell_type{make_cell()}; + using cell_group_type = cell_group<fvm::fvm_cell<double, cell_local_size_type>>; + auto group = cell_group_type{0, make_cell()}; - cell.advance(50, 0.01); + group.advance(50, 0.01); // a bit lame... - EXPECT_EQ(cell.spikes().size(), 4u); + EXPECT_EQ(group.spikes().size(), 4u); } +TEST(cell_group, sources) +{ + using namespace nest::mc; + + // TODO: extend to multi-cell cell groups when the time comes + + using cell_group_type = cell_group<fvm::fvm_cell<double, cell_local_size_type>>; + + auto cell = make_cell(); + EXPECT_EQ(cell.detectors().size(), 1u); + // add another detector on the cell to make things more interesting + cell.add_detector({1, 0.3}, 2.3); + + cell_gid_type first_gid = 37u; + auto group = cell_group_type{first_gid, cell}; + + // expect group sources to be lexicographically sorted by source id + // with gids in cell group's range and indices starting from zero + + const auto& sources = group.spike_sources(); + for (unsigned i = 0; i<sources.size(); ++i) { + auto id = sources[i].source_id; + if (i==0) { + EXPECT_EQ(id.gid, first_gid); + EXPECT_EQ(id.index, 0u); + } + else { + auto prev = sources[i-1].source_id; + EXPECT_GT(id, prev); + EXPECT_EQ(id.index, id.gid==prev.gid? prev.index+1: 0u); + } + } +} diff --git a/tests/unit/test_compartments.cpp b/tests/unit/test_compartments.cpp index c109f9cda9a335ce32f695f85ccd0295446a5d92..c167b16cce5585a43efc13fc9f136025b21f7811 100644 --- a/tests/unit/test_compartments.cpp +++ b/tests/unit/test_compartments.cpp @@ -15,13 +15,13 @@ TEST(compartments, compartment) { nest::mc::compartment c(100, 1.2, 2.1, 2.2); - EXPECT_EQ(c.index, 100); + EXPECT_EQ(c.index, 100u); EXPECT_EQ(c.length, 1.2); EXPECT_EQ(left(c.radius), 2.1); EXPECT_EQ(right(c.radius), 2.2); auto c2 = c; - EXPECT_EQ(c2.index, 100); + EXPECT_EQ(c2.index, 100u); EXPECT_EQ(c2.length, 1.2); EXPECT_EQ(left(c2.radius), 2.1); EXPECT_EQ(right(c2.radius), 2.2); @@ -29,7 +29,7 @@ TEST(compartments, compartment) { nest::mc::compartment c{100, 1, 2, 3}; - EXPECT_EQ(c.index, 100); + EXPECT_EQ(c.index, 100u); EXPECT_EQ(c.length, 1.); EXPECT_EQ(left(c.radius), 2.); EXPECT_EQ(right(c.radius), 3.); @@ -54,7 +54,7 @@ TEST(compartments, compartment_iterator) ++it; { auto c = *it; - EXPECT_EQ(c.index, 1); + EXPECT_EQ(c.index, 1u); EXPECT_EQ(left(c.radius), 3.0); EXPECT_EQ(right(c.radius), 5.0); EXPECT_EQ(c.length, 2.5); @@ -65,7 +65,7 @@ TEST(compartments, compartment_iterator) // returned iterator should be unchanged { auto c = *(it++); - EXPECT_EQ(c.index, 1); + EXPECT_EQ(c.index, 1u); EXPECT_EQ(left(c.radius), 3.0); EXPECT_EQ(right(c.radius), 5.0); EXPECT_EQ(c.length, 2.5); @@ -73,7 +73,7 @@ TEST(compartments, compartment_iterator) // while the iterator itself was updated { auto c = *it; - EXPECT_EQ(c.index, 2); + EXPECT_EQ(c.index, 2u); EXPECT_EQ(left(c.radius), 5.0); EXPECT_EQ(right(c.radius), 7.0); EXPECT_EQ(c.length, 2.5); @@ -84,7 +84,7 @@ TEST(compartments, compartment_iterator) // copy iterator auto it2 = it; auto c = *it2; - EXPECT_EQ(c.index, 2); + EXPECT_EQ(c.index, 2u); EXPECT_EQ(left(c.radius), 5.0); EXPECT_EQ(right(c.radius), 7.0); EXPECT_EQ(c.length, 2.5); @@ -96,7 +96,7 @@ TEST(compartments, compartment_iterator) // check the copy has updated correctly when incremented c= *it2; - EXPECT_EQ(c.index, 3); + EXPECT_EQ(c.index, 3u); EXPECT_EQ(left(c.radius), 7.0); EXPECT_EQ(right(c.radius), 9.0); EXPECT_EQ(c.length, 2.5); @@ -108,11 +108,11 @@ TEST(compartments, compartment_range) { nest::mc::compartment_range rng(10, 1.0, 2.0, 10.); - EXPECT_EQ((*rng.begin()).index, 0); - EXPECT_EQ((*rng.end()).index, 10); + EXPECT_EQ((*rng.begin()).index, 0u); + EXPECT_EQ((*rng.end()).index, 10u); EXPECT_NE(rng.begin(), rng.end()); - auto count = 0; + unsigned count = 0; for(auto c : rng) { EXPECT_EQ(c.index, count); auto er = 1.0 + double(count)/10.; @@ -121,7 +121,7 @@ TEST(compartments, compartment_range) EXPECT_EQ(c.length, 1.0); ++count; } - EXPECT_EQ(count, 10); + EXPECT_EQ(count, 10u); } // test case of zero length range diff --git a/tests/unit/test_event_queue.cpp b/tests/unit/test_event_queue.cpp index 5ac4c75747d761638b0bbd4b90a9b6fed6fa3c18..f4e31b26715fa2aa3cb690720442cc68f2b18fe2 100644 --- a/tests/unit/test_event_queue.cpp +++ b/tests/unit/test_event_queue.cpp @@ -7,14 +7,14 @@ TEST(event_queue, push) { using namespace nest::mc; - using ps_event_queue = event_queue<postsynaptic_spike_event>; + using ps_event_queue = event_queue<postsynaptic_spike_event<float>>; ps_event_queue q; - q.push({1u, 2.f, 2.f}); - q.push({4u, 1.f, 2.f}); - q.push({8u, 20.f, 2.f}); - q.push({2u, 8.f, 2.f}); + q.push({{1u, 0u}, 2.f, 2.f}); + q.push({{4u, 1u}, 1.f, 2.f}); + q.push({{8u, 2u}, 20.f, 2.f}); + q.push({{2u, 3u}, 8.f, 2.f}); std::vector<float> times; while(q.size()) { @@ -31,13 +31,13 @@ TEST(event_queue, push) TEST(event_queue, push_range) { using namespace nest::mc; - using ps_event_queue = event_queue<postsynaptic_spike_event>; + using ps_event_queue = event_queue<postsynaptic_spike_event<float>>; - postsynaptic_spike_event events[] = { - {1u, 2.f, 2.f}, - {4u, 1.f, 2.f}, - {8u, 20.f, 2.f}, - {2u, 8.f, 2.f} + postsynaptic_spike_event<float> events[] = { + {{1u, 0u}, 2.f, 2.f}, + {{4u, 1u}, 1.f, 2.f}, + {{8u, 2u}, 20.f, 2.f}, + {{2u, 3u}, 8.f, 2.f} }; ps_event_queue q; @@ -56,13 +56,20 @@ TEST(event_queue, push_range) TEST(event_queue, pop_if_before) { using namespace nest::mc; - using ps_event_queue = event_queue<postsynaptic_spike_event>; + using ps_event_queue = event_queue<postsynaptic_spike_event<float>>; - postsynaptic_spike_event events[] = { - {1u, 1.f, 2.f}, - {2u, 2.f, 2.f}, - {3u, 3.f, 2.f}, - {4u, 4.f, 2.f} + cell_member_type target[4] = { + {1u, 0u}, + {4u, 1u}, + {8u, 2u}, + {2u, 3u} + }; + + postsynaptic_spike_event<float> events[] = { + {target[0], 1.f, 2.f}, + {target[1], 2.f, 2.f}, + {target[2], 3.f, 2.f}, + {target[3], 4.f, 2.f} }; ps_event_queue q; @@ -76,12 +83,12 @@ TEST(event_queue, pop_if_before) auto e2 = q.pop_if_before(5.); EXPECT_TRUE(e2); - EXPECT_EQ(e2->target, 1u); + EXPECT_EQ(e2->target, target[0]); EXPECT_EQ(q.size(), 3u); auto e3 = q.pop_if_before(5.); EXPECT_TRUE(e3); - EXPECT_EQ(e3->target, 2u); + EXPECT_EQ(e3->target, target[1]); EXPECT_EQ(q.size(), 2u); auto e4 = q.pop_if_before(2.5); @@ -90,7 +97,7 @@ TEST(event_queue, pop_if_before) auto e5 = q.pop_if_before(5.); EXPECT_TRUE(e5); - EXPECT_EQ(e5->target, 3u); + EXPECT_EQ(e5->target, target[2]); EXPECT_EQ(q.size(), 1u); q.pop_if_before(5.); diff --git a/tests/unit/test_fvm.cpp b/tests/unit/test_fvm.cpp index 1383167aafec9b0c0ae36f2bd6ab8661803d2fe5..c0ed438f20841d40a9f2653e38424aeeba4a64ea 100644 --- a/tests/unit/test_fvm.cpp +++ b/tests/unit/test_fvm.cpp @@ -1,20 +1,19 @@ #include <fstream> #include "gtest.h" -#include "../test_util.hpp" +#include <common_types.hpp> #include <cell.hpp> #include <fvm_cell.hpp> +#include "../test_util.hpp" + TEST(fvm, cable) { using namespace nest::mc; nest::mc::cell cell; - // setup global state for the mechanisms - nest::mc::mechanisms::setup_mechanism_helpers(); - cell.add_soma(6e-4); // 6um in cm // 1um radius and 4mm long, all in cm @@ -42,7 +41,7 @@ TEST(fvm, cable) cell.segment(1)->set_compartments(4); cell.segment(2)->set_compartments(4); - using fvm_cell = fvm::fvm_cell<double, int>; + using fvm_cell = fvm::fvm_cell<double, cell_lid_type>; fvm_cell fvcell(cell); auto& J = fvcell.jacobian(); @@ -64,9 +63,6 @@ TEST(fvm, init) nest::mc::cell cell; - // setup global state for the mechanisms - nest::mc::mechanisms::setup_mechanism_helpers(); - cell.add_soma(12.6157/2.0); //auto& props = cell.soma()->properties; @@ -95,7 +91,7 @@ TEST(fvm, init) cell.segment(1)->set_compartments(10); - using fvm_cell = fvm::fvm_cell<double, int>; + using fvm_cell = fvm::fvm_cell<double, cell_lid_type>; fvm_cell fvcell(cell); auto& J = fvcell.jacobian(); EXPECT_EQ(J.size(), 11u); diff --git a/tests/unit/test_lexcmp.cpp b/tests/unit/test_lexcmp.cpp index a02ac0765823f61d54e8acee3cf38611552d9dd7..e2b5c476d7983f7726349953d15e34509aab031c 100644 --- a/tests/unit/test_lexcmp.cpp +++ b/tests/unit/test_lexcmp.cpp @@ -62,8 +62,8 @@ class lexcmp_test_refmemfn { public: explicit lexcmp_test_refmemfn(int foo): foo_(foo) {} - const int &foo() const { return foo_; } - int &foo() { return foo_; } + const int& foo() const { return foo_; } + int& foo() { return foo_; } private: int foo_; diff --git a/tests/unit/test_mask_stream.cpp b/tests/unit/test_mask_stream.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ca8089eab1f7486204ce0df70fb953000b301a0 --- /dev/null +++ b/tests/unit/test_mask_stream.cpp @@ -0,0 +1,68 @@ +#include "gtest.h" + +#include <sstream> + +#include <util/ioutil.hpp> + +using namespace nest::mc::util; + +TEST(mask_stream,nomask) { + // expect mask_stream(true) on a new stream not to change rdbuf. + std::ostringstream s; + auto sbuf = s.rdbuf(); + s << mask_stream(true); + EXPECT_EQ(sbuf, s.rdbuf()); +} + +TEST(mask_stream,mask) { + // masked stream should produce no ouptut + std::ostringstream s; + s << "one"; + s << mask_stream(false); + + s << "two"; + EXPECT_EQ(s.str(), "one"); + + s << mask_stream(true); + s << "three"; + EXPECT_EQ(s.str(), "onethree"); +} + +TEST(mask_stream,mask_multi) { + // mark_stream(false) should be idempotent + + std::ostringstream s; + auto sbuf1 = s.rdbuf(); + + s << "foo"; + s << mask_stream(false); + auto sbuf2 = s.rdbuf(); + + s << "bar"; + s << mask_stream(false); + auto sbuf3 = s.rdbuf(); + EXPECT_EQ(sbuf2, sbuf3); + + s << "baz"; + s << mask_stream(true); + auto sbuf4 = s.rdbuf(); + EXPECT_EQ(sbuf1, sbuf4); + + s << "xyzzy"; + EXPECT_EQ(s.str(), "fooxyzzy"); +} + +TEST(mask_stream,fmt) { + // expect formatting to be preserved across masks. + + std::ostringstream s; + s.precision(1); + + s << mask_stream(false); + EXPECT_EQ(s.precision(), 1); + s.precision(2); + + s << mask_stream(true); + EXPECT_EQ(s.precision(), 2); +} + diff --git a/tests/unit/test_mechanisms.cpp b/tests/unit/test_mechanisms.cpp index 2e31f1139d25cb3e36bc1ff0b32a3ec8a4316a4b..377f683fa8ff0831b0cdff3f94be96710d966d0c 100644 --- a/tests/unit/test_mechanisms.cpp +++ b/tests/unit/test_mechanisms.cpp @@ -1,25 +1,17 @@ #include "gtest.h" -#include "../src/mechanism_interface.hpp" -#include "../src/matrix.hpp" +//#include "../src/mechanism_interface.hpp" +#include "mechanism_catalogue.hpp" +#include "matrix.hpp" TEST(mechanisms, helpers) { - nest::mc::mechanisms::setup_mechanism_helpers(); - - EXPECT_EQ(nest::mc::mechanisms::mechanism_helpers.size(), 4u); + using namespace nest::mc; + using catalogue = mechanisms::catalogue<double, int>; // verify that the hh and pas channels are available - EXPECT_EQ(nest::mc::mechanisms::get_mechanism_helper("hh")->name(), "hh"); - EXPECT_EQ(nest::mc::mechanisms::get_mechanism_helper("pas")->name(), "pas"); - - // check that an out_of_range exception is thrown if an invalid mechanism is - // requested - ASSERT_THROW( - nest::mc::mechanisms::get_mechanism_helper("dachshund"), - std::out_of_range - ); + EXPECT_TRUE(catalogue::has("hh")); + EXPECT_TRUE(catalogue::has("pas")); - //0 1 2 3 4 5 6 7 8 9 std::vector<int> parent_index = {0,0,1,2,3,4,0,6,7,8}; memory::HostVector<int> node_indices = std::vector<int>{0,6,7,8,9}; auto n = node_indices.size(); @@ -28,10 +20,17 @@ TEST(mechanisms, helpers) { memory::HostVector<double> vec_i(n, 0.); memory::HostVector<double> vec_v(n, 0.); - auto& helper = nest::mc::mechanisms::get_mechanism_helper("hh"); - auto mech = helper->new_mechanism(vec_v, vec_i, node_indices); + auto mech = catalogue::make("hh", vec_v, vec_i, node_indices); EXPECT_EQ(mech->name(), "hh"); EXPECT_EQ(mech->size(), 5u); //EXPECT_EQ(mech->matrix_, &matrix); + + // check that an out_of_range exception is thrown if an invalid mechanism is + // requested + ASSERT_THROW( + catalogue::make("dachshund", vec_v, vec_i, node_indices), + std::out_of_range + ); + //0 1 2 3 4 5 6 7 8 9 } diff --git a/tests/unit/test_optional.cpp b/tests/unit/test_optional.cpp index 35b74a470cea3e82a230592454126d04ce092da9..4610f0c48cf6b317f97e8a895f9bb09245ee7220 100644 --- a/tests/unit/test_optional.cpp +++ b/tests/unit/test_optional.cpp @@ -24,24 +24,30 @@ TEST(optionalm,unset_throw) { optional<int> a; int check=10; - try { a.get(); } - catch(optional_unset_error &e) { + try { + a.get(); + } + catch (optional_unset_error& e) { ++check; } EXPECT_EQ(11,check); check=20; a=2; - try { a.get(); } - catch(optional_unset_error &e) { + try { + a.get(); + } + catch (optional_unset_error& e) { ++check; } EXPECT_EQ(20,check); check=30; a.reset(); - try { a.get(); } - catch(optional_unset_error &e) { + try { + a.get(); + } + catch (optional_unset_error& e) { ++check; } EXPECT_EQ(31,check); @@ -66,13 +72,13 @@ TEST(optionalm,ctor_conv) { TEST(optionalm,ctor_ref) { int v=10; - optional<int &> a(v); + optional<int&> a(v); EXPECT_EQ(10,a.get()); v=20; EXPECT_EQ(20,a.get()); - optional<int &> b(a),c=b,d=v; + optional<int&> b(a),c=b,d=v; EXPECT_EQ(&(a.get()),&(b.get())); EXPECT_EQ(&(a.get()),&(c.get())); EXPECT_EQ(&(a.get()),&(d.get())); @@ -110,13 +116,13 @@ struct nomove { nomove(): value(0) {} nomove(int i): value(i) {} - nomove(const nomove &n): value(n.value) {} - nomove(nomove &&n) = delete; + nomove(const nomove& n): value(n.value) {} + nomove(nomove&& n) = delete; - nomove &operator=(const nomove &n) { value=n.value; return *this; } + nomove& operator=(const nomove& n) { value=n.value; return *this; } - bool operator==(const nomove &them) const { return them.value==value; } - bool operator!=(const nomove &them) const { return !(*this==them); } + bool operator==(const nomove& them) const { return them.value==value; } + bool operator!=(const nomove& them) const { return !(*this==them); } }; TEST(optionalm,ctor_nomove) { @@ -136,21 +142,21 @@ struct nocopy { nocopy(): value(0) {} nocopy(int i): value(i) {} - nocopy(const nocopy &n) = delete; - nocopy(nocopy &&n) { + nocopy(const nocopy& n) = delete; + nocopy(nocopy&& n) { value=n.value; n.value=0; } - nocopy &operator=(const nocopy &n) = delete; - nocopy &operator=(nocopy &&n) { + nocopy& operator=(const nocopy& n) = delete; + nocopy& operator=(nocopy&& n) { value=n.value; n.value=-1; return *this; } - bool operator==(const nocopy &them) const { return them.value==value; } - bool operator!=(const nocopy &them) const { return !(*this==them); } + bool operator==(const nocopy& them) const { return them.value==value; } + bool operator!=(const nocopy& them) const { return !(*this==them); } }; TEST(optionalm,ctor_nocopy) { @@ -251,13 +257,13 @@ TEST(optionalm,bind_to_optional_void) { TEST(optionalm,bind_with_ref) { optional<int> a=10; - a >> [](int &v) {++v; }; + a >> [](int& v) { ++v; }; EXPECT_EQ(11,*a); } struct check_cref { - int operator()(const int &) { return 10; } - int operator()(int &) { return 11; } + int operator()(const int&) { return 10; } + int operator()(int&) { return 11; } }; TEST(optionalm,bind_constness) { diff --git a/tests/unit/test_parameters.cpp b/tests/unit/test_parameters.cpp index 0781640489725a100b3e4ede3be40dfe168b1cef..a573e07cd9fce4581d2bd6b026e1499f94336d1d 100644 --- a/tests/unit/test_parameters.cpp +++ b/tests/unit/test_parameters.cpp @@ -28,7 +28,7 @@ TEST(parameters, setting) EXPECT_FALSE(list.add_parameter({"b", -3.0})); EXPECT_EQ(list.num_parameters(), 2); - auto &parms = list.parameters(); + auto& parms = list.parameters(); EXPECT_EQ(parms[0].name, "a"); EXPECT_EQ(parms[0].value, 0.12); EXPECT_EQ(parms[0].range.min, 0); diff --git a/tests/unit/test_probe.cpp b/tests/unit/test_probe.cpp index 52ac4e7f62171bc556df5b0ba4e25f05b3304fa5..5af744d2cc5f7f9b8a403b991b22d260c9ec74d4 100644 --- a/tests/unit/test_probe.cpp +++ b/tests/unit/test_probe.cpp @@ -1,5 +1,6 @@ #include "gtest.h" +#include "common_types.hpp" #include "cell.hpp" #include "fvm_cell.hpp" @@ -12,13 +13,13 @@ TEST(probe, instantiation) segment_location loc1{0, 0}; segment_location loc2{1, 0.6}; - auto p1 = c1.add_probe(loc1, probeKind::membrane_voltage); - auto p2 = c1.add_probe(loc2, probeKind::membrane_current); + auto p1 = c1.add_probe({loc1, probeKind::membrane_voltage}); + auto p2 = c1.add_probe({loc2, probeKind::membrane_current}); // expect locally provided probe ids to be numbered sequentially from zero. - - EXPECT_EQ(0, p1); - EXPECT_EQ(1, p2); + + EXPECT_EQ(0u, p1); + EXPECT_EQ(1u, p2); // expect the probes() return to be a collection with these two probes. @@ -48,14 +49,14 @@ TEST(probe, fvm_cell) segment_location loc1{1, 1}; segment_location loc2{1, 0.5}; - auto pv0 = bs.add_probe(loc0, probeKind::membrane_voltage); - auto pv1 = bs.add_probe(loc1, probeKind::membrane_voltage); - auto pi2 = bs.add_probe(loc2, probeKind::membrane_current); - + auto pv0 = bs.add_probe({loc0, probeKind::membrane_voltage}); + auto pv1 = bs.add_probe({loc1, probeKind::membrane_voltage}); + auto pi2 = bs.add_probe({loc2, probeKind::membrane_current}); + i_clamp stim(0, 100, 0.3); bs.add_stimulus({1, 1}, stim); - fvm::fvm_cell<double, int> lcell(bs); + fvm::fvm_cell<double, cell_local_size_type> lcell(bs); lcell.setup_matrix(0.01); lcell.initialize(); diff --git a/tests/unit/test_spikes.cpp b/tests/unit/test_spikes.cpp index b16f89cda31768f43b8509b9f9cdde375ed9bc25..3b4862029208c1f6c5a657d83dd982f510852cd6 100644 --- a/tests/unit/test_spikes.cpp +++ b/tests/unit/test_spikes.cpp @@ -1,7 +1,7 @@ #include "gtest.h" -#include <communication/spike.hpp> -#include <communication/spike_source.hpp> +#include <spike.hpp> +#include <spike_source.hpp> struct cell_proxy { double voltage(nest::mc::segment_location loc) const { diff --git a/tests/unit/test_swcio.cpp b/tests/unit/test_swcio.cpp index 32eaff75b967ccf9c5f8b7ef7f06c8b3d8b10f3e..c927224ffaef7e6e9d026a53dd81377b061a24eb 100644 --- a/tests/unit/test_swcio.cpp +++ b/tests/unit/test_swcio.cpp @@ -502,7 +502,7 @@ TEST(swc_io, cell_construction) cell cell = io::swc_read_cell(is); EXPECT_TRUE(cell.has_soma()); - EXPECT_EQ(4, cell.num_segments()); + EXPECT_EQ(4u, cell.num_segments()); EXPECT_EQ(norm(points[1]-points[2]), cell.cable(1)->length()); EXPECT_EQ(norm(points[2]-points[3]), cell.cable(2)->length()); @@ -515,13 +515,13 @@ TEST(swc_io, cell_construction) EXPECT_EQ(2.1, cell.soma()->radius()); EXPECT_EQ(point_type(0, 0, 0), cell.soma()->center()); - for (auto i = 1; i < cell.num_segments(); ++i) { + for (auto i = 1u; i < cell.num_segments(); ++i) { EXPECT_TRUE(cell.segment(i)->is_dendrite()); } - EXPECT_EQ(1, cell.cable(1)->num_sub_segments()); - EXPECT_EQ(1, cell.cable(2)->num_sub_segments()); - EXPECT_EQ(2, cell.cable(3)->num_sub_segments()); + EXPECT_EQ(1u, cell.cable(1)->num_sub_segments()); + EXPECT_EQ(1u, cell.cable(2)->num_sub_segments()); + EXPECT_EQ(2u, cell.cable(3)->num_sub_segments()); // Check the radii @@ -563,7 +563,7 @@ TEST(swc_parser, from_file_ball_and_stick) auto cell = nest::mc::io::swc_read_cell(fid); // verify that the correct number of nodes was read - EXPECT_EQ(cell.num_segments(), 2); + EXPECT_EQ(cell.num_segments(), 2u); EXPECT_EQ(cell.num_compartments(), 2u); // make an equivalent cell via C++ interface diff --git a/tests/unit/test_synapses.cpp b/tests/unit/test_synapses.cpp index 9ad68c2444ef17d6f2b2355e3c388f76ca9220e6..8cd9de04f57625ec7416e5a571ab24447c1a6cae 100644 --- a/tests/unit/test_synapses.cpp +++ b/tests/unit/test_synapses.cpp @@ -14,7 +14,7 @@ TEST(synapses, add_to_cell) nest::mc::cell cell; // setup global state for the mechanisms - nest::mc::mechanisms::setup_mechanism_helpers(); + // nest::mc::mechanisms::setup_mechanism_helpers(); // Soma with diameter 12.6157 um and HH channel auto soma = cell.add_soma(12.6157/2.0); @@ -30,15 +30,15 @@ TEST(synapses, add_to_cell) EXPECT_EQ(3u, cell.synapses().size()); const auto& syns = cell.synapses(); - EXPECT_EQ(syns[0].location.segment, 0); + EXPECT_EQ(syns[0].location.segment, 0u); EXPECT_EQ(syns[0].location.position, 0.1); EXPECT_EQ(syns[0].mechanism.name(), "expsyn"); - EXPECT_EQ(syns[1].location.segment, 1); + EXPECT_EQ(syns[1].location.segment, 1u); EXPECT_EQ(syns[1].location.position, 0.2); EXPECT_EQ(syns[1].mechanism.name(), "exp2syn"); - EXPECT_EQ(syns[2].location.segment, 0); + EXPECT_EQ(syns[2].location.segment, 0u); EXPECT_EQ(syns[2].location.position, 0.3); EXPECT_EQ(syns[2].mechanism.name(), "expsyn"); } diff --git a/tests/unit/test_tree.cpp b/tests/unit/test_tree.cpp index b82da064b6143fa120c92df2b223db8b95db15ff..4c69ce400d1c6146f7f704b2317a0189b8a41a53 100644 --- a/tests/unit/test_tree.cpp +++ b/tests/unit/test_tree.cpp @@ -17,20 +17,24 @@ using json = nlohmann::json; using range = memory::Range; using namespace nest::mc; +using int_type = cell_tree::int_type; + TEST(cell_tree, from_parent_index) { + auto no_parent = cell_tree::no_parent; + // tree with single branch corresponding to the root node // this is equivalent to a single compartment model // CASE 1 : single root node in parent_index { - std::vector<int> parent_index = {0}; + std::vector<int_type> parent_index = {0}; cell_tree tree(parent_index); EXPECT_EQ(tree.num_segments(), 1u); EXPECT_EQ(tree.num_children(0), 0u); } // CASE 2 : empty parent_index { - std::vector<int> parent_index; + std::vector<int_type> parent_index; cell_tree tree(parent_index); EXPECT_EQ(tree.num_segments(), 1u); EXPECT_EQ(tree.num_children(0), 0u); @@ -46,7 +50,7 @@ TEST(cell_tree, from_parent_index) { // / // 3 // - std::vector<int> parent_index = + std::vector<int_type> parent_index = {0, 0, 1, 2, 0, 4}; cell_tree tree(parent_index); EXPECT_EQ(tree.num_segments(), 3u); @@ -66,7 +70,7 @@ TEST(cell_tree, from_parent_index) { // / \. // 3 8 // - std::vector<int> parent_index = + std::vector<int_type> parent_index = {0, 0, 1, 2, 0, 4, 0, 6, 7, 8}; cell_tree tree(parent_index); @@ -79,10 +83,10 @@ TEST(cell_tree, from_parent_index) { EXPECT_EQ(tree.num_children(3), 0u); // Check new structure - EXPECT_EQ(-1, tree.parent(0)); - EXPECT_EQ(0, tree.parent(1)); - EXPECT_EQ(0, tree.parent(2)); - EXPECT_EQ(0, tree.parent(3)); + EXPECT_EQ(no_parent, tree.parent(0)); + EXPECT_EQ(0u, tree.parent(1)); + EXPECT_EQ(0u, tree.parent(2)); + EXPECT_EQ(0u, tree.parent(3)); } { // @@ -100,7 +104,7 @@ TEST(cell_tree, from_parent_index) { // \. // 13 // - std::vector<int> parent_index = + std::vector<int_type> parent_index = {0, 0, 1, 2, 0, 4, 0, 6, 7, 8, 9, 8, 11, 12}; cell_tree tree(parent_index); EXPECT_EQ(tree.num_segments(), 6u); @@ -115,12 +119,12 @@ TEST(cell_tree, from_parent_index) { EXPECT_EQ(tree.num_children(5), 0u); // Check new structure - EXPECT_EQ(-1, tree.parent(0)); - EXPECT_EQ(0, tree.parent(1)); - EXPECT_EQ(0, tree.parent(2)); - EXPECT_EQ(0, tree.parent(3)); - EXPECT_EQ(3, tree.parent(4)); - EXPECT_EQ(3, tree.parent(5)); + EXPECT_EQ(no_parent, tree.parent(0)); + EXPECT_EQ(0u, tree.parent(1)); + EXPECT_EQ(0u, tree.parent(2)); + EXPECT_EQ(0u, tree.parent(3)); + EXPECT_EQ(3u, tree.parent(4)); + EXPECT_EQ(3u, tree.parent(5)); } { // @@ -129,7 +133,7 @@ TEST(cell_tree, from_parent_index) { // 1 // / \. // 2 3 - std::vector<int> parent_index = {0,0,1,1}; + std::vector<int_type> parent_index = {0,0,1,1}; cell_tree tree(parent_index); EXPECT_EQ(tree.num_segments(), 4u); @@ -146,7 +150,7 @@ TEST(cell_tree, from_parent_index) { // 1 4 5 // / \. // 2 3 - std::vector<int> parent_index = {0,0,1,1,0,0}; + std::vector<int_type> parent_index = {0,0,1,1,0,0}; cell_tree tree(parent_index); EXPECT_EQ(tree.num_segments(), 6u); @@ -158,11 +162,11 @@ TEST(cell_tree, from_parent_index) { EXPECT_EQ(tree.num_children(4), 0u); // Check children - EXPECT_EQ(1, tree.children(0)[0]); - EXPECT_EQ(4, tree.children(0)[1]); - EXPECT_EQ(5, tree.children(0)[2]); - EXPECT_EQ(2, tree.children(1)[0]); - EXPECT_EQ(3, tree.children(1)[1]); + EXPECT_EQ(1u, tree.children(0)[0]); + EXPECT_EQ(4u, tree.children(0)[1]); + EXPECT_EQ(5u, tree.children(0)[2]); + EXPECT_EQ(2u, tree.children(1)[0]); + EXPECT_EQ(3u, tree.children(1)[1]); } /* FIXME { @@ -198,8 +202,8 @@ TEST(tree, change_root) { // 1 2 -> 1 // | // 2 - std::vector<int> parent_index = {0,0,0}; - tree t(parent_index); + std::vector<int_type> parent_index = {0,0,0}; + tree<int_type> t(parent_index); t.change_root(1); EXPECT_EQ(t.num_nodes(), 3u); @@ -216,8 +220,8 @@ TEST(tree, change_root) { // 1 2 -> 1 2 3 // / \ | // 3 4 4 - std::vector<int> parent_index = {0,0,0,1,1}; - tree t(parent_index); + std::vector<int_type> parent_index = {0,0,0,1,1}; + tree<int_type> t(parent_index); t.change_root(1u); EXPECT_EQ(t.num_nodes(), 5u); @@ -240,8 +244,8 @@ TEST(tree, change_root) { // 3 4 3 4 6 // / \. // 5 6 - std::vector<int> parent_index = {0,0,0,1,1,4,4}; - tree t(parent_index); + std::vector<int_type> parent_index = {0,0,0,1,1,4,4}; + tree<int_type> t(parent_index); t.change_root(1); @@ -258,6 +262,8 @@ TEST(tree, change_root) { } TEST(cell_tree, balance) { + auto no_parent = cell_tree::no_parent; + { // a cell with the following structure // will balance around 1 @@ -268,13 +274,13 @@ TEST(cell_tree, balance) { // 3 4 3 4 6 // / \. // 5 6 - std::vector<int> parent_index = {0,0,0,1,1,4,4}; + std::vector<int_type> parent_index = {0,0,0,1,1,4,4}; cell_tree t(parent_index); t.balance(); // the soma (original root) has moved to 5 in the new tree - EXPECT_EQ(t.soma(), 5); + EXPECT_EQ(t.soma(), 5u); EXPECT_EQ(t.num_segments(), 7u); EXPECT_EQ(t.num_children(0),3u); @@ -284,13 +290,13 @@ TEST(cell_tree, balance) { EXPECT_EQ(t.num_children(4),0u); EXPECT_EQ(t.num_children(5),1u); EXPECT_EQ(t.num_children(6),0u); - EXPECT_EQ(t.parent(0),-1); - EXPECT_EQ(t.parent(1), 0); - EXPECT_EQ(t.parent(2), 0); - EXPECT_EQ(t.parent(3), 0); - EXPECT_EQ(t.parent(4), 2); - EXPECT_EQ(t.parent(5), 2); - EXPECT_EQ(t.parent(6), 5); + EXPECT_EQ(t.parent(0), no_parent); + EXPECT_EQ(t.parent(1), 0u); + EXPECT_EQ(t.parent(2), 0u); + EXPECT_EQ(t.parent(3), 0u); + EXPECT_EQ(t.parent(4), 2u); + EXPECT_EQ(t.parent(5), 2u); + EXPECT_EQ(t.parent(6), 5u); //t.to_graphviz("cell.dot"); } @@ -307,7 +313,7 @@ TEST(cell_tree, json_load) std::ifstream(path) >> cell_data; for(auto c : range(0,cell_data.size())) { - std::vector<int> parent_index = cell_data[c]["parent_index"]; + std::vector<int_type> parent_index = cell_data[c]["parent_index"]; cell_tree tree(parent_index); //tree.to_graphviz("cell" + std::to_string(c) + ".dot"); } diff --git a/tests/unit/test_uninitialized.cpp b/tests/unit/test_uninitialized.cpp index dcc9e47ce6a7fa266376e054e8f84a04d8ec818d..1f04510d7df6af054eaec854d7e0c954ad5e73c8 100644 --- a/tests/unit/test_uninitialized.cpp +++ b/tests/unit/test_uninitialized.cpp @@ -7,16 +7,16 @@ using namespace nest::mc::util; namespace { struct count_ops { count_ops() {} - count_ops(const count_ops &n) { ++copy_ctor_count; } - count_ops(count_ops &&n) { ++move_ctor_count; } + count_ops(const count_ops& n) { ++copy_ctor_count; } + count_ops(count_ops&& n) { ++move_ctor_count; } - count_ops &operator=(const count_ops &n) { ++copy_assign_count; return *this; } - count_ops &operator=(count_ops &&n) { ++move_assign_count; return *this; } + count_ops& operator=(const count_ops& n) { ++copy_assign_count; return *this; } + count_ops& operator=(count_ops&& n) { ++move_assign_count; return *this; } static int copy_ctor_count,copy_assign_count; static int move_ctor_count,move_assign_count; static void reset_counts() { - copy_ctor_count=copy_assign_count=0; + copy_ctor_count=copy_assign_count=0; move_ctor_count=move_assign_count=0; } }; @@ -53,11 +53,11 @@ TEST(uninitialized,ctor) { namespace { struct nocopy { nocopy() {} - nocopy(const nocopy &n) = delete; - nocopy(nocopy &&n) { ++move_ctor_count; } + nocopy(const nocopy& n) = delete; + nocopy(nocopy&& n) { ++move_ctor_count; } - nocopy &operator=(const nocopy &n) = delete; - nocopy &operator=(nocopy &&n) { ++move_assign_count; return *this; } + nocopy& operator=(const nocopy& n) = delete; + nocopy& operator=(nocopy&& n) { ++move_assign_count; return *this; } static int move_ctor_count,move_assign_count; static void reset_counts() { move_ctor_count=move_assign_count=0; } @@ -85,11 +85,11 @@ TEST(uninitialized,ctor_nocopy) { namespace { struct nomove { nomove() {} - nomove(const nomove &n) { ++copy_ctor_count; } - nomove(nomove &&n) = delete; + nomove(const nomove& n) { ++copy_ctor_count; } + nomove(nomove&& n) = delete; - nomove &operator=(const nomove &n) { ++copy_assign_count; return *this; } - nomove &operator=(nomove &&n) = delete; + nomove& operator=(const nomove& n) { ++copy_assign_count; return *this; } + nomove& operator=(nomove&& n) = delete; static int copy_ctor_count,copy_assign_count; static void reset_counts() { copy_ctor_count=copy_assign_count=0; } @@ -129,7 +129,7 @@ TEST(uninitialized,void) { } TEST(uninitialized,ref) { - uninitialized<int &> x,y; + uninitialized<int&> x,y; int a; x.construct(a); @@ -151,8 +151,8 @@ namespace { mutable int op_count=0; mutable int const_op_count=0; - int operator()(const int &a) const { ++const_op_count; return a+1; } - int operator()(int &a) const { ++op_count; return ++a; } + int operator()(const int& a) const { ++const_op_count; return a+1; } + int operator()(int& a) const { ++op_count; return ++a; } }; } @@ -165,14 +165,14 @@ TEST(uninitialized,apply) { EXPECT_EQ(11,ua.cref()); EXPECT_EQ(11,r); - uninitialized<int &> ub; + uninitialized<int&> ub; ub.construct(ua.ref()); r=ub.apply(A); EXPECT_EQ(12,ua.cref()); EXPECT_EQ(12,r); - uninitialized<const int &> uc; + uninitialized<const int&> uc; uc.construct(ua.ref()); r=uc.apply(A); diff --git a/tests/validation/validate_ball_and_stick.cpp b/tests/validation/validate_ball_and_stick.cpp index 0e41bd021534af6020482659f082eca0867fb63a..62d06b4dac85885f935dd88855223551cc33fcce 100644 --- a/tests/validation/validate_ball_and_stick.cpp +++ b/tests/validation/validate_ball_and_stick.cpp @@ -1,6 +1,7 @@ #include <fstream> #include <json/src/json.hpp> +#include <common_types.hpp> #include <cell.hpp> #include <fvm_cell.hpp> @@ -16,9 +17,6 @@ TEST(ball_and_stick, neuron_baseline) nest::mc::cell cell; - // setup global state for the mechanisms - nest::mc::mechanisms::setup_mechanism_helpers(); - // Soma with diameter 12.6157 um and HH channel auto soma = cell.add_soma(12.6157/2.0); soma->add_mechanism(hh_parameters()); @@ -92,7 +90,7 @@ TEST(ball_and_stick, neuron_baseline) std::vector<std::vector<double>> v(3); // make the lowered finite volume cell - fvm::fvm_cell<double, int> model(cell); + fvm::fvm_cell<double, cell_local_size_type> model(cell); auto graph = cell.model(); // set initial conditions @@ -165,9 +163,6 @@ TEST(ball_and_3stick, neuron_baseline) nest::mc::cell cell; - // setup global state for the mechanisms - nest::mc::mechanisms::setup_mechanism_helpers(); - // Soma with diameter 12.6157 um and HH channel auto soma = cell.add_soma(12.6157/2.0); soma->add_mechanism(hh_parameters()); @@ -250,7 +245,7 @@ TEST(ball_and_3stick, neuron_baseline) std::vector<std::vector<double>> v(3); // make the lowered finite volume cell - fvm::fvm_cell<double, int> model(cell); + fvm::fvm_cell<double, cell_local_size_type> model(cell); auto graph = cell.model(); // set initial conditions diff --git a/tests/validation/validate_soma.cpp b/tests/validation/validate_soma.cpp index 9c3c9fdfa56b14b9aa948c762023355541866d9a..24a0c1c59f0db24bc6f9c7832909f9f52af7beb7 100644 --- a/tests/validation/validate_soma.cpp +++ b/tests/validation/validate_soma.cpp @@ -1,6 +1,7 @@ #include <fstream> #include <json/src/json.hpp> +#include <common_types.hpp> #include <cell.hpp> #include <fvm_cell.hpp> @@ -17,9 +18,6 @@ TEST(soma, neuron_baseline) nest::mc::cell cell; - // setup global state for the mechanisms - nest::mc::mechanisms::setup_mechanism_helpers(); - // Soma with diameter 18.8um and HH channel auto soma = cell.add_soma(18.8/2.0); soma->mechanism("membrane").set("r_L", 123); // no effect for single compartment cell @@ -29,7 +27,7 @@ TEST(soma, neuron_baseline) cell.add_stimulus({0,0.5}, {10., 100., 0.1}); // make the lowered finite volume cell - fvm::fvm_cell<double, int> model(cell); + fvm::fvm_cell<double, cell_local_size_type> model(cell); // load data from file auto cell_data = testing::g_validation_data.load("soma.json"); @@ -85,9 +83,6 @@ TEST(soma, convergence) nest::mc::cell cell; - // setup global state for the mechanisms - nest::mc::mechanisms::setup_mechanism_helpers(); - // Soma with diameter 18.8um and HH channel auto soma = cell.add_soma(18.8/2.0); soma->mechanism("membrane").set("r_L", 123); // no effect for single compartment cell @@ -97,7 +92,7 @@ TEST(soma, convergence) cell.add_stimulus({0,0.5}, {10., 100., 0.1}); // make the lowered finite volume cell - fvm::fvm_cell<double, int> model(cell); + fvm::fvm_cell<double, cell_local_size_type> model(cell); // generate baseline solution with small dt=0.0001 std::vector<double> baseline_spike_times; diff --git a/tests/validation/validate_synapses.cpp b/tests/validation/validate_synapses.cpp index 42ace29c3393bb177b878ab3e8c6eb6de42d8676..876b24eb55c9cad0e08c17cef9cdf16595f5bd90 100644 --- a/tests/validation/validate_synapses.cpp +++ b/tests/validation/validate_synapses.cpp @@ -3,6 +3,7 @@ #include <json/src/json.hpp> +#include <common_types.hpp> #include <cell.hpp> #include <cell_group.hpp> #include <fvm_cell.hpp> @@ -50,12 +51,10 @@ void run_neuron_baseline(const char* syn_type, const char* data_file) { using namespace nest::mc; using namespace nlohmann; + using lowered_cell = fvm::fvm_cell<double, cell_local_size_type>; nest::mc::cell cell; - // setup global state for the mechanisms - mechanisms::setup_mechanism_helpers(); - // Soma with diameter 12.6157 um and HH channel auto soma = cell.add_soma(12.6157/2.0); soma->add_mechanism(hh_parameters()); @@ -72,14 +71,14 @@ void run_neuron_baseline(const char* syn_type, const char* data_file) cell.add_synapse({1, 0.5}, syn_default); // add probes - auto probe_soma = cell.add_probe({0,0}, probeKind::membrane_voltage); - auto probe_dend = cell.add_probe({1,0.5}, probeKind::membrane_voltage); + auto probe_soma = cell.add_probe({{0,0}, probeKind::membrane_voltage}); + auto probe_dend = cell.add_probe({{1,0.5}, probeKind::membrane_voltage}); // injected spike events - postsynaptic_spike_event synthetic_events[] = { - {0u, 10.0, 0.04}, - {0u, 20.0, 0.04}, - {0u, 40.0, 0.04} + postsynaptic_spike_event<float> synthetic_events[] = { + {{0u, 0u}, 10.0, 0.04}, + {{0u, 0u}, 20.0, 0.04}, + {{0u, 0u}, 40.0, 0.04} }; // load data from file @@ -106,9 +105,7 @@ void run_neuron_baseline(const char* syn_type, const char* data_file) std::vector<std::vector<double>> v(2); // make the lowered finite volume cell - cell_group<fvm::fvm_cell<double, int>> group(cell); - group.set_source_gids(0); - group.set_target_gids(0); + cell_group<lowered_cell> group(0, cell); // add the 3 spike events to the queue group.enqueue_events(synthetic_events);