Skip to content
Snippets Groups Projects
Commit d5edf26e authored by Sam Yates's avatar Sam Yates
Browse files

WIP - does not compile

* splitting out model abstraction from miniapp
* move to (gid,index) ids for sources, targets
* simplify lid/gid mappings for communicator, cell group etc.
parent ee03400c
No related branches found
No related tags found
No related merge requests found
...@@ -3,8 +3,9 @@ set(HEADERS ...@@ -3,8 +3,9 @@ set(HEADERS
set(MINIAPP_SOURCES set(MINIAPP_SOURCES
# mpi.cpp # mpi.cpp
io.cpp io.cpp
miniapp.cpp # miniapp.cpp
recipes.cpp recipes.cpp
model.cpp
) )
add_executable(miniapp.exe ${MINIAPP_SOURCES} ${HEADERS}) add_executable(miniapp.exe ${MINIAPP_SOURCES} ${HEADERS})
......
#include <cstdlib>
#include <vector>
#include "catypes.hpp"
#include "cell.hpp"
#include "cell_group.hpp"
#include "communication/communicator.hpp"
#include "communication/global_policy.hpp"
#include "fvm_cell.hpp"
#include "profiling/profiler.hpp"
#include "threading/threading.hpp"
namespace nest {
namespace mc {
struct model {
using cell_group = cell_group<fvm::fvm_cell<double, cell_local_size_type>>;
void reset() {
t_ = 0.;
// otherwise unimplemented
std::abort();
}
double run(double tuntil, double dt) {
while (t_<tuntil) {
auto tstep = std::min(t_+dt, tunitl);
threading::parallel_for::apply(
0u, cell_groups.size(),
[&](unsigned i) {
auto &group = cell_group[i];
util::profiler_enter("stepping","events");
group.enqueue_events(communicator.queue(i));
util::profiler_leave();
group.advance(tstep, dt);
util::profiler_enter("events");
communicator.add_spikes(group.spikes());
group.clear_spikes();
util::profiler_leave(2);
});
util::profiler_enter("stepping", "exchange");
communicator.exchange();
util::profiler_leave(2);
t_ += delta;
}
return t_;
}
explicit model(const recipe &rec, float sample_dt) {
// crude load balancing:
auto num_domains = global_policy::size();
auto domain_id = global_policy::id();
auto num_cells = rec.num_cells();
cell_gid_type cell_from = (cell_gid_type)(num_cells*(domain_id/(double)num_domains));
cell_gid_type cell_to = (cell_gid_type)(num_cells*((domain_id+1)/(double)num_domains));
// construct cell groups (one cell per group) and attach samplers
cell_groups.resize(cell_to-cell_from);
samplers.resize(cell_to-cell_from);
threading::parallel_for::apply(cell_from, cell_to,
[&](cell_gid_type i) {
util::profiler_enter("setup", "cells");
auto cell = cell_group(rec.get_cell(i));
auto idx = i-cell_from;
cell_groups[idx] = cell_group(cell);
cell_local_index_type j = 0;
for (const auto& probe: cell.probes()) {
samplers[idx].emplace_back({i,j}, probe.kind, probe.location, sample_dt);
const auto &sampler = samplers[idx].back();
cell_groups[idx].add_sampler(sampler, sampler.next_sample_t());
}
util::profiler_leave(2);
});
// initialise communicator
communicator = communicator_type(cell_from, cell_to);
}
private:
double t_ = 0.;
std::vector<cell_group> cell_groups;
std::vector<std::vector<sample_to_trace>> samplers;
communicator_type communicator;
};
// move sampler code to another source file...
struct sample_trace {
struct sample_type {
float time;
double value;
};
std::string name;
std::string units;
cell_gid_type cell_gid;
cell_index_type probe_index;
std::vector<sample_type> samples;
};
struct sample_to_trace {
float next_sample_t() const { return t_next_sample_; }
optional<float> operator()(float t, double v) {
if (t<t_next_sample_) {
return t_next_sample_;
}
trace.samples.push_back({t,v});
return t_next_sample_+=sample_dt_;
}
sample_to_trace(cell_member_type probe_id,
const std::string &name,
const std::string &units,
float dt,
float t_start=0):
trace_{{name, units, probe_id.gid, probe_id.index}},
sample_dt_(dt),
t_next_sample_(t_start)
{}
sample_to_trace(cell_member_type probe_id,
probeKind kind,
segment_location loc,
float dt,
float t_start=0):
sample_to_trace(probe_id, "", "", dt, t_start)
{
std::string name = "";
std::string units = "";
switch (kind) {
case probeKind::mebrane_voltage:
name = "v";
units = "mV";
break;
case probeKind::mebrane_current:
name = "i";
units = "mA/cm^2";
break;
default: ;
}
trace_.name = name + (loc.segment? "dend": "soma");
trace_.units = units;
}
void write_trace(const std::string& prefix = "trace_") const {
// do not call during simulation: thread-unsafe access to traces.
auto path = prefix + std::to_string(trace.cell_gid) +
"." + std::to_string(trace.probe_index) + ".json";
nlohmann::json jrep;
jrep["name"] = trace.name;
jrep["units"] = trace.units;
jrep["cell"] = trace.cell_gid;
jrep["probe"] = trace.probe_index;
auto& jt = jrep["data"]["time"];
auto& jy = jrep["data"][trace.name];
for (const auto& sample: trace.samples) {
jt.push_back(sample.time);
jy.push_back(sample.value);
}
std::ofstream file(path);
file << std::setw(1) << jrep << std::endl;
}
private:
sample_trace trace_;
float sample_dt_;
float t_next_sample_;
};
} // namespace mc
} // namespace nest
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include <catypes.hpp>
#include <cell.hpp> #include <cell.hpp>
#include <event_queue.hpp> #include <event_queue.hpp>
#include <spike.hpp> #include <spike.hpp>
...@@ -27,60 +28,41 @@ struct sampler { ...@@ -27,60 +28,41 @@ struct sampler {
template <typename Cell> template <typename Cell>
class cell_group { class cell_group {
public: public:
using index_type = uint32_t; using index_type = cell_gid_type;
using cell_type = Cell; using cell_type = Cell;
using value_type = typename cell_type::value_type; using value_type = typename cell_type::value_type;
using size_type = typename cell_type::value_type; using size_type = typename cell_type::value_type;
using spike_detector_type = spike_detector<Cell>; using spike_detector_type = spike_detector<Cell>;
using source_id_type = cell_member_type;
struct spike_source_type { struct spike_source_type {
index_type index; source_id_type source_id;
spike_detector_type source; spike_detector_type source;
}; };
cell_group() = default; cell_group() = default;
cell_group(const cell& c) : cell_group(cell_gid_type gid, const cell& c) :
cell_{c} gid_base_{gid}, cell_{c}
{ {
cell_.voltage()(memory::all) = -65.; cell_.voltage()(memory::all) = -65.;
cell_.initialize(); cell_.initialize();
source_id_type source_id={gid_base_,0};
for (auto& d : c.detectors()) { for (auto& d : c.detectors()) {
spike_sources_.push_back( { ++source_id.index;
0u, spike_detector_type(cell_, d.location, d.threshold, 0.f) spike_sources_.push_back({
source_id, spike_detector_type(cell_, d.location, d.threshold, 0.f)
}); });
} }
} }
void set_source_gids(index_type gid) {
for (auto& s : spike_sources_) {
s.index = gid++;
}
}
void set_target_gids(index_type lid) {
first_target_gid_ = lid;
}
index_type num_probes() const {
return cell_.num_probes();
}
void set_probe_gids(index_type gid) {
first_probe_gid_ = gid;
}
std::pair<index_type, index_type> probe_gid_range() const {
return { first_probe_gid_, first_probe_gid_+cell_.num_probes() };
}
void advance(double tfinal, double dt) { void advance(double tfinal, double dt) {
while (cell_.time()<tfinal) { while (cell_.time()<tfinal) {
// take any pending samples // take any pending samples
float cell_time = cell_.time(); float cell_time = cell_.time();
nest::mc::util::profiler_enter("sampling"); util::profiler_enter("sampling");
while (auto m = sample_events_.pop_if_before(cell_time)) { while (auto m = sample_events_.pop_if_before(cell_time)) {
auto& sampler = samplers_[m->sampler_index]; auto& sampler = samplers_[m->sampler_index];
EXPECTS((bool)sampler.sample); EXPECTS((bool)sampler.sample);
...@@ -92,7 +74,7 @@ public: ...@@ -92,7 +74,7 @@ public:
sample_events_.push(*m); sample_events_.push(*m);
} }
} }
nest::mc::util::profiler_leave(); util::profiler_leave();
// look for events in the next time step // look for events in the next time step
auto tstep = std::min(tfinal, cell_.time()+dt); auto tstep = std::min(tfinal, cell_.time()+dt);
...@@ -105,11 +87,11 @@ public: ...@@ -105,11 +87,11 @@ public:
std::cerr << "warning: solution out of bounds\n"; std::cerr << "warning: solution out of bounds\n";
} }
nest::mc::util::profiler_enter("events"); util::profiler_enter("events");
// check for new spikes // check for new spikes
for (auto& s : spike_sources_) { for (auto& s : spike_sources_) {
if (auto spike = s.source.test(cell_, cell_.time())) { if (auto spike = s.source.test(cell_, cell_.time())) {
spikes_.push_back({s.index, spike.get()}); spikes_.push_back({s.source_id, spike.get()});
} }
} }
...@@ -123,7 +105,7 @@ public: ...@@ -123,7 +105,7 @@ public:
cell_.apply_event(e.get()); cell_.apply_event(e.get());
} }
} }
nest::mc::util::profiler_leave(); util::profiler_leave();
} }
} }
...@@ -136,10 +118,8 @@ public: ...@@ -136,10 +118,8 @@ public:
} }
} }
const std::vector<spike<index_type>>& const std::vector<spike<source_id_type>>&
spikes() const { spikes() const { return spikes_; }
return spikes_;
}
cell_type& cell() { return cell_; } cell_type& cell() { return cell_; }
const cell_type& cell() const { return cell_; } const cell_type& cell() const { return cell_; }
...@@ -160,7 +140,8 @@ public: ...@@ -160,7 +140,8 @@ public:
} }
private: private:
/// gid of first cell in group
cell_gid_type gid_base_;
/// the lowered cell state (e.g. FVM) of the cell /// the lowered cell state (e.g. FVM) of the cell
cell_type cell_; cell_type cell_;
...@@ -169,7 +150,7 @@ private: ...@@ -169,7 +150,7 @@ private:
std::vector<spike_source_type> spike_sources_; std::vector<spike_source_type> spike_sources_;
//. spikes that are generated //. spikes that are generated
std::vector<spike<index_type>> spikes_; std::vector<spike<source_id_type>> spikes_;
/// pending events to be delivered /// pending events to be delivered
event_queue<postsynaptic_spike_event> events_; event_queue<postsynaptic_spike_event> events_;
......
...@@ -29,76 +29,33 @@ namespace communication { ...@@ -29,76 +29,33 @@ namespace communication {
template <typename CommunicationPolicy> template <typename CommunicationPolicy>
class communicator { class communicator {
public: public:
using id_type = uint32_t; using id_type = cell_gid_type;
using communication_policy_type = CommunicationPolicy; using communication_policy_type = CommunicationPolicy;
using spike_type = spike<id_type>; using spike_type = spike<id_type>;
communicator() = default; communicator() = default;
communicator(id_type n_groups, std::vector<id_type> target_counts) : // for now, still assuming one-to-one association cells <-> groups,
num_groups_local_(n_groups), // so that 'group' gids as represented by their first cell gid are
num_targets_local_(target_counts.size()) // contiguous.
communicator(id_type cell_from, id_type cell_to):
cell_gid_from_(cell_from), cell_gid_to(cell_to)
{ {
target_map_ = nest::mc::algorithms::make_index(target_counts); auto num_groups_local_ = cell_gid_to_-cell_gid_from_;
num_targets_local_ = target_map_.back();
// create an event queue for each target group // create an event queue for each target group
events_.resize(num_groups_local_); events_.resize(num_groups_local_);
// make maps for converting lid to gid
target_gid_map_ = communication_policy_.make_map(num_targets_local_);
group_gid_map_ = communication_policy_.make_map(num_groups_local_);
// transform the target ids from lid to gid
auto first_target = target_gid_map_[domain_id()];
for (auto &id : target_map_) {
id += first_target;
}
}
id_type target_gid_from_group_lid(id_type lid) const {
EXPECTS(lid<num_groups_local_);
return target_map_[lid];
} }
id_type group_gid_from_group_lid(id_type lid) const {
EXPECTS(lid<num_groups_local_);
return group_gid_map_[domain_id()] + lid;
}
void add_connection(connection con) { void add_connection(connection con) {
EXPECTS(is_local_target(con.destination())); EXPECTS(is_local_target(con.destination()));
connections_.push_back(con); connections_.push_back(con);
} }
bool is_local_target(id_type gid) { bool is_local_cell(id_type gid) const {
return gid>=target_gid_map_[domain_id()] return gid>=cell_gid_from_ && gid<cell_gid_to_;
&& gid<target_gid_map_[domain_id()+1];
}
bool is_local_group(id_type gid) {
return gid>=group_gid_map_[domain_id()]
&& gid<group_gid_map_[domain_id()+1];
}
/// return the global id of the first group in domain d
/// the groups in domain d are in the contiguous half open range
/// [domain_first_group(d), domain_first_group(d+1))
id_type group_gid_first(int d) const {
return group_gid_map_[d];
}
id_type target_lid(id_type gid) {
EXPECTS(is_local_group(gid));
return gid - target_gid_map_[domain_id()];
}
id_type group_lid(id_type gid) {
EXPECTS(is_local_group(gid));
return gid - group_gid_map_[domain_id()];
} }
// builds the optimized data structure // builds the optimized data structure
...@@ -117,19 +74,6 @@ public: ...@@ -117,19 +74,6 @@ public:
return communication_policy_.min(local_min); return communication_policy_.min(local_min);
} }
// return the local group index of the group which hosts the target with
// global id gid
id_type local_group_from_global_target(id_type gid) {
// assert that gid is in range
EXPECTS(is_local_target(gid));
return
std::distance(
target_map_.begin(),
std::upper_bound(target_map_.begin(), target_map_.end(), gid)
) - 1;
}
void add_spike(spike_type s) { void add_spike(spike_type s) {
thread_spikes().push_back(s); thread_spikes().push_back(s);
} }
...@@ -144,7 +88,6 @@ public: ...@@ -144,7 +88,6 @@ public:
} }
void exchange() { void exchange() {
// global all-to-all to gather a local copy of the global spike list // global all-to-all to gather a local copy of the global spike list
// on each node // on each node
//profiler_.enter("global exchange"); //profiler_.enter("global exchange");
...@@ -170,7 +113,7 @@ public: ...@@ -170,7 +113,7 @@ public:
// generate an event for each target // generate an event for each target
for (auto it=targets.first; it!=targets.second; ++it) { for (auto it=targets.first; it!=targets.second; ++it) {
auto gidx = local_group_from_global_target(it->destination()); auto gidx = it->destination().gid - cell_gid_from_;
events_[gidx].push_back(it->make_event(spike)); events_[gidx].push_back(it->make_event(spike));
} }
...@@ -181,18 +124,7 @@ public: ...@@ -181,18 +124,7 @@ public:
//profiler_.leave(); // event generation //profiler_.leave(); // event generation
} }
uint64_t num_spikes() const uint64_t num_spikes() const { return num_spikes_; }
{
return num_spikes_;
}
int domain_id() const {
return communication_policy_.id();
}
int num_domains() const {
return communication_policy_.size();
}
const std::vector<postsynaptic_spike_event>& queue(int i) const { const std::vector<postsynaptic_spike_event>& queue(int i) const {
return events_[i]; return events_[i];
...@@ -242,27 +174,9 @@ private: ...@@ -242,27 +174,9 @@ private:
std::vector<connection> connections_; std::vector<connection> connections_;
std::vector<std::vector<postsynaptic_spike_event>> events_; std::vector<std::vector<postsynaptic_spike_event>> events_;
// local target group i has targets in the half open range
// [target_map_[i], target_map_[i+1])
std::vector<id_type> target_map_;
// for keeping track of how time is spent where // for keeping track of how time is spent where
//util::Profiler profiler_; //util::Profiler profiler_;
// the number of groups and targets handled by this communicator
id_type num_groups_local_;
id_type num_targets_local_;
// index maps for the global distribution of groups and targets
// communicator i has the groups in the half open range :
// [group_gid_map_[i], group_gid_map_[i+1])
std::vector<id_type> group_gid_map_;
// communicator i has the targets in the half open range :
// [target_gid_map_[i], target_gid_map_[i+1])
std::vector<id_type> target_gid_map_;
communication_policy_type communication_policy_; communication_policy_type communication_policy_;
uint64_t num_spikes_ = 0u; uint64_t num_spikes_ = 0u;
......
...@@ -6,13 +6,10 @@ ...@@ -6,13 +6,10 @@
namespace nest { namespace nest {
namespace mc { namespace mc {
template < template <typename I>
typename I,
typename = typename std::enable_if<std::is_integral<I>::value>
>
struct spike { struct spike {
using id_type = I; using id_type = I;
id_type source = 0; id_type source = id_type{};
float time = -1.; float time = -1.;
spike() = default; spike() = default;
......
#include <limits>
#include "gtest.h" #include "gtest.h"
#include <catypes.hpp> #include <catypes.hpp>
...@@ -36,7 +34,7 @@ TEST(cell_group, test) ...@@ -36,7 +34,7 @@ TEST(cell_group, test)
using cell_type = cell_group<fvm::fvm_cell<double, cell_local_size_type>>; using cell_type = cell_group<fvm::fvm_cell<double, cell_local_size_type>>;
auto cell = cell_type{make_cell()}; auto cell = cell_type{0, make_cell()};
cell.advance(50, 0.01); cell.advance(50, 0.01);
......
...@@ -105,9 +105,7 @@ void run_neuron_baseline(const char* syn_type, const char* data_file) ...@@ -105,9 +105,7 @@ void run_neuron_baseline(const char* syn_type, const char* data_file)
std::vector<std::vector<double>> v(2); std::vector<std::vector<double>> v(2);
// make the lowered finite volume cell // make the lowered finite volume cell
cell_group<lowered_cell> group(cell); cell_group<lowered_cell> group(0, cell);
group.set_source_gids(0);
group.set_target_gids(0);
// add the 3 spike events to the queue // add the 3 spike events to the queue
group.enqueue_events(synthetic_events); group.enqueue_events(synthetic_events);
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment