From 68af1047aa32c5b70f2e2896cc0264268560408d Mon Sep 17 00:00:00 2001 From: bcumming <bcumming@cscs.ch> Date: Wed, 22 Jun 2016 17:11:26 +0200 Subject: [PATCH] first version of miniapp with spike->event communication --- .ycm_extra_conf.py | 3 + CMakeLists.txt | 15 +- cmake/FindTBB.cmake | 249 +++++++++++++ miniapp/CMakeLists.txt | 17 + miniapp/io.cpp | 27 ++ miniapp/io.hpp | 22 ++ miniapp/miniapp.cpp | 216 +++++++++++ miniapp/plot.py | 22 ++ src/cell.cpp | 7 +- src/cell.hpp | 26 ++ src/cell_group.hpp | 137 +++++++ src/communication/communicator.hpp | 253 +++++++++++++ src/communication/connection.hpp | 78 ++++ src/communication/mpi.hpp | 155 ++++++++ src/communication/mpi_global_policy.hpp | 48 +++ src/communication/serial_global_policy.hpp | 49 +++ src/{ => communication}/spike.hpp | 9 +- src/{ => communication}/spike_source.hpp | 45 +-- src/fvm_cell.hpp | 105 +++--- src/lowered_cell.hpp | 20 -- src/profiling/profiler.hpp | 397 +++++++++++++++++++++ src/threading/serial.hpp | 87 +++++ src/threading/tbb.hpp | 51 +++ src/threading/threading.hpp | 9 + tests/CMakeLists.txt | 4 +- tests/test_cell_group.cpp | 48 +++ tests/test_spikes.cpp | 70 ++++ 27 files changed, 2069 insertions(+), 100 deletions(-) create mode 100644 cmake/FindTBB.cmake create mode 100644 miniapp/CMakeLists.txt create mode 100644 miniapp/io.cpp create mode 100644 miniapp/io.hpp create mode 100644 miniapp/miniapp.cpp create mode 100644 miniapp/plot.py create mode 100644 src/cell_group.hpp create mode 100644 src/communication/communicator.hpp create mode 100644 src/communication/connection.hpp create mode 100644 src/communication/mpi.hpp create mode 100644 src/communication/mpi_global_policy.hpp create mode 100644 src/communication/serial_global_policy.hpp rename src/{ => communication}/spike.hpp (76%) rename src/{ => communication}/spike_source.hpp (77%) delete mode 100644 src/lowered_cell.hpp create mode 100644 src/profiling/profiler.hpp create mode 100644 src/threading/serial.hpp create mode 100644 src/threading/tbb.hpp create mode 100644 src/threading/threading.hpp create mode 100644 tests/test_cell_group.cpp create mode 100644 tests/test_spikes.cpp diff --git a/.ycm_extra_conf.py b/.ycm_extra_conf.py index c356012c..9033c063 100644 --- a/.ycm_extra_conf.py +++ b/.ycm_extra_conf.py @@ -36,6 +36,7 @@ import ycm_core # CHANGE THIS LIST OF FLAGS. YES, THIS IS THE DROID YOU HAVE BEEN LOOKING FOR. flags = [ '-DNDEBUG', + '-DWITH_TBB', '-std=c++11', '-x', 'c++', @@ -47,6 +48,8 @@ flags = [ 'include', '-I', 'external', + '-I', + 'miniapp', # '-isystem', # '/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.10.sdk/usr/include/c++/4.2.1', # '-I', diff --git a/CMakeLists.txt b/CMakeLists.txt index 7be418be..c988ae1a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,6 @@ set(SAVED_CXX_FLAGS "${CMAKE_CXX_FLAGS}") set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake") include("CompilerOptions") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXXOPT_DEBUG} ${CXXOPT_CXX11} ${CXXOPT_PTHREAD} ${CXXOPT_WALL}") -# -g -std=c++11 -pthread -Wall") # this generates a .json file with full compilation command for each file set(CMAKE_EXPORT_COMPILE_COMMANDS "YES") @@ -20,6 +19,14 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS "YES") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +# TBB support +set( WITH_TBB "OFF" CACHE BOOL "use TBB for on-node threading" ) +if( "${WITH_TBB}" STREQUAL "ON" ) + find_package(TBB REQUIRED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWITH_TBB ${TBB_DEFINITIONS}") +endif() + + # targets for extermal dependencies include(ExternalProject) externalproject_add(modparser @@ -37,9 +44,15 @@ externalproject_add(modparser include_directories(${CMAKE_SOURCE_DIR}/external) include_directories(${CMAKE_SOURCE_DIR}/include) include_directories(${CMAKE_SOURCE_DIR}/src) +include_directories(${CMAKE_SOURCE_DIR}/miniapp) include_directories(${CMAKE_SOURCE_DIR}) +if( "${WITH_TBB}" STREQUAL "ON" ) + include_directories(${TBB_INCLUDE_DIRS}) +endif() + add_subdirectory(mechanisms) add_subdirectory(src) add_subdirectory(tests) +add_subdirectory(miniapp) diff --git a/cmake/FindTBB.cmake b/cmake/FindTBB.cmake new file mode 100644 index 00000000..7cb10d64 --- /dev/null +++ b/cmake/FindTBB.cmake @@ -0,0 +1,249 @@ +#----------------------------------------------------------------------------- +# from github justusc/FindTBB +#----------------------------------------------------------------------------- + +# The MIT License (MIT) +# +# Copyright (c) 2015 Justus Calvin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# +# FindTBB +# ------- +# +# Find TBB include directories and libraries. +# +# Usage: +# +# find_package(TBB [major[.minor]] [EXACT] +# [QUIET] [REQUIRED] +# [[COMPONENTS] [components...]] +# [OPTIONAL_COMPONENTS components...]) +# +# where the allowed components are tbbmalloc and tbb_preview. Users may modify +# the behavior of this module with the following variables: +# +# * TBB_ROOT_DIR - The base directory the of TBB installation. +# * TBB_INCLUDE_DIR - The directory that contains the TBB headers files. +# * TBB_LIBRARY - The directory that contains the TBB library files. +# * TBB_<library>_LIBRARY - The path of the TBB the corresponding TBB library. +# These libraries, if specified, override the +# corresponding library search results, where <library> +# may be tbb, tbb_debug, tbbmalloc, tbbmalloc_debug, +# tbb_preview, or tbb_preview_debug. +# * TBB_USE_DEBUG_BUILD - The debug version of tbb libraries, if present, will +# be used instead of the release version. +# +# Users may modify the behavior of this module with the following environment +# variables: +# +# * TBB_INSTALL_DIR +# * TBBROOT +# * LIBRARY_PATH +# +# This module will set the following variables: +# +# * TBB_FOUND - Set to false, or undefined, if we haven’t found, or +# don’t want to use TBB. +# * TBB_<component>_FOUND - If False, optional <component> part of TBB sytem is +# not available. +# * TBB_VERSION - The full version string +# * TBB_VERSION_MAJOR - The major version +# * TBB_VERSION_MINOR - The minor version +# * TBB_INTERFACE_VERSION - The interface version number defined in +# tbb/tbb_stddef.h. +# * TBB_<library>_LIBRARY_RELEASE - The path of the TBB release version of +# <library>, where <library> may be tbb, tbb_debug, +# tbbmalloc, tbbmalloc_debug, tbb_preview, or +# tbb_preview_debug. +# * TBB_<library>_LIBRARY_DEGUG - The path of the TBB release version of +# <library>, where <library> may be tbb, tbb_debug, +# tbbmalloc, tbbmalloc_debug, tbb_preview, or +# tbb_preview_debug. +# +# The following varibles should be used to build and link with TBB: +# +# * TBB_INCLUDE_DIRS - The include directory for TBB. +# * TBB_LIBRARIES - The libraries to link against to use TBB. +# * TBB_DEFINITIONS - Definitions to use when compiling code that uses TBB. + +include(FindPackageHandleStandardArgs) + +if(NOT TBB_FOUND) + + ################################## + # Check the build type + ################################## + + if(NOT DEFINED TBB_USE_DEBUG_BUILD) + if(CMAKE_BUILD_TYPE MATCHES "[Debug|DEBUG|debug|RelWithDebInfo|RELWITHDEBINFO|relwithdebinfo]") + set(TBB_USE_DEBUG_BUILD TRUE) + else() + set(TBB_USE_DEBUG_BUILD FALSE) + endif() + endif() + + ################################## + # Set the TBB search directories + ################################## + + # Define search paths based on user input and environment variables + set(TBB_SEARCH_DIR ${TBB_ROOT_DIR} $ENV{TBB_INSTALL_DIR} $ENV{TBBROOT}) + + # Define the search directories based on the current platform + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + set(TBB_DEFAULT_SEARCH_DIR "C:/Program Files/Intel/TBB" + "C:/Program Files (x86)/Intel/TBB") + + # Set the target architecture + if(CMAKE_SIZEOF_VOID_P EQUAL 8) + set(TBB_ARCHITECTURE "intel64") + else() + set(TBB_ARCHITECTURE "ia32") + endif() + + # Set the TBB search library path search suffix based on the version of VC + if(WINDOWS_STORE) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11_ui") + elseif(MSVC14) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc14") + elseif(MSVC12) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc12") + elseif(MSVC11) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11") + elseif(MSVC10) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc10") + endif() + + # Add the library path search suffix for the VC independent version of TBB + list(APPEND TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc_mt") + + elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + # OS X + set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb") + + # TODO: Check to see which C++ library is being used by the compiler. + if(NOT ${CMAKE_SYSTEM_VERSION} VERSION_LESS 13.0) + # The default C++ library on OS X 10.9 and later is libc++ + set(TBB_LIB_PATH_SUFFIX "lib/libc++") + else() + set(TBB_LIB_PATH_SUFFIX "lib") + endif() + elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") + # Linux + set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb") + + # TODO: Check compiler version to see the suffix should be <arch>/gcc4.1 or + # <arch>/gcc4.1. For now, assume that the compiler is more recent than + # gcc 4.4.x or later. + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + set(TBB_LIB_PATH_SUFFIX "lib/intel64/gcc4.4") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^i.86$") + set(TBB_LIB_PATH_SUFFIX "lib/ia32/gcc4.4") + endif() + endif() + + ################################## + # Find the TBB include dir + ################################## + + find_path(TBB_INCLUDE_DIRS tbb/tbb.h + HINTS ${TBB_INCLUDE_DIR} ${TBB_SEARCH_DIR} + PATHS ${TBB_DEFAULT_SEARCH_DIR} + PATH_SUFFIXES include) + + ################################## + # Find TBB components + ################################## + + # Find each component + foreach(_comp tbb_preview tbbmalloc tbb) + # Search for the libraries + find_library(TBB_${_comp}_LIBRARY_RELEASE ${_comp} + HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR} + PATHS ${TBB_DEFAULT_SEARCH_DIR} + PATH_SUFFIXES "${TBB_LIB_PATH_SUFFIX}") + + find_library(TBB_${_comp}_LIBRARY_DEBUG ${_comp}_debug + HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR} + PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH + PATH_SUFFIXES "${TBB_LIB_PATH_SUFFIX}") + + + # Set the library to be used for the component + if(NOT TBB_${_comp}_LIBRARY) + if(TBB_USE_DEBUG_BUILD AND TBB_${_comp}_LIBRARY_DEBUG) + set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_DEBUG}") + elseif(TBB_${_comp}_LIBRARY_RELEASE) + set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_RELEASE}") + elseif(TBB_${_comp}_LIBRARY_DEBUG) + set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_DEBUG}") + endif() + endif() + + # Set the TBB library list and component found variables + if(TBB_${_comp}_LIBRARY) + list(APPEND TBB_LIBRARIES "${TBB_${_comp}_LIBRARY}") + set(TBB_${_comp}_FOUND TRUE) + else() + set(TBB_${_comp}_FOUND FALSE) + endif() + + mark_as_advanced(TBB_${_comp}_LIBRARY_RELEASE) + mark_as_advanced(TBB_${_comp}_LIBRARY_DEBUG) + mark_as_advanced(TBB_${_comp}_LIBRARY) + + endforeach() + + ################################## + # Set compile flags + ################################## + + if(TBB_tbb_LIBRARY MATCHES "debug") + set(TBB_DEFINITIONS "-DTBB_USE_DEBUG=1") + endif() + + ################################## + # Set version strings + ################################## + + if(TBB_INCLUDE_DIRS) + file(READ "${TBB_INCLUDE_DIRS}/tbb/tbb_stddef.h" _tbb_version_file) + string(REGEX REPLACE ".*#define TBB_VERSION_MAJOR ([0-9]+).*" "\\1" + TBB_VERSION_MAJOR "${_tbb_version_file}") + string(REGEX REPLACE ".*#define TBB_VERSION_MINOR ([0-9]+).*" "\\1" + TBB_VERSION_MINOR "${_tbb_version_file}") + string(REGEX REPLACE ".*#define TBB_INTERFACE_VERSION ([0-9]+).*" "\\1" + TBB_INTERFACE_VERSION "${_tbb_version_file}") + set(TBB_VERSION "${TBB_VERSION_MAJOR}.${TBB_VERSION_MINOR}") + endif() + + find_package_handle_standard_args(TBB + REQUIRED_VARS TBB_INCLUDE_DIRS TBB_LIBRARIES + HANDLE_COMPONENTS + VERSION_VAR TBB_VERSION) + + mark_as_advanced(TBB_INCLUDE_DIRS TBB_LIBRARIES) + + unset(TBB_ARCHITECTURE) + unset(TBB_LIB_PATH_SUFFIX) + unset(TBB_DEFAULT_SEARCH_DIR) + +endif() diff --git a/miniapp/CMakeLists.txt b/miniapp/CMakeLists.txt new file mode 100644 index 00000000..196ce7ae --- /dev/null +++ b/miniapp/CMakeLists.txt @@ -0,0 +1,17 @@ +set(HEADERS +) +set(MINIAPP_SOURCES + io.cpp + miniapp.cpp +) + +add_executable(miniapp.exe ${MINIAPP_SOURCES} ${HEADERS}) + +target_link_libraries(miniapp.exe LINK_PUBLIC cellalgo) +target_link_libraries(miniapp.exe LINK_PUBLIC ${TBB_LIBRARIES}) + +set_target_properties(miniapp.exe + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/miniapp" +) + diff --git a/miniapp/io.cpp b/miniapp/io.cpp new file mode 100644 index 00000000..90498c4c --- /dev/null +++ b/miniapp/io.cpp @@ -0,0 +1,27 @@ +#include "io.hpp" + +namespace nest { +namespace mc { +namespace io { + +// read simulation options from json file with name fname +// for now this is just a placeholder +options read_options(std::string fname) +{ + // 10 cells, 1 synapses per cell, 10 compartments per segment + return {5, 1, 100}; +} + +std::ostream& operator<<(std::ostream& o, const options& opt) +{ + o << "simultion options:\n"; + o << " cells : " << opt.cells << "\n"; + o << " compartments/segment : " << opt.compartments_per_segment << "\n"; + o << " synapses/cell : " << opt.synapses_per_cell << "\n"; + + return o; +} + +} // namespace io +} // namespace mc +} // namespace nest diff --git a/miniapp/io.hpp b/miniapp/io.hpp new file mode 100644 index 00000000..e8a8762c --- /dev/null +++ b/miniapp/io.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include <json/src/json.hpp> + +namespace nest { +namespace mc { +namespace io { + +// holds the options for a simulation run +struct options { + int cells; + int synapses_per_cell; + int compartments_per_segment; +}; + +std::ostream& operator<<(std::ostream& o, const options& opt); + +options read_options(std::string fname); + +} // namespace io +} // namespace mc +} // namespace nest diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp new file mode 100644 index 00000000..5fbc8b36 --- /dev/null +++ b/miniapp/miniapp.cpp @@ -0,0 +1,216 @@ +#include <iostream> + +#include <cell.hpp> +#include <cell_group.hpp> +#include <fvm_cell.hpp> +#include <mechanism_interface.hpp> + +#include "io.hpp" +#include "threading/threading.hpp" +#include "profiling/profiler.hpp" +#include "communication/communicator.hpp" +#include "communication/serial_global_policy.hpp" + +using namespace nest; + +using real_type = double; +using index_type = int; +using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>; +using cell_group = mc::cell_group<numeric_cell>; +using communicator_type = + mc::communication::communicator<mc::communication::serial_global_policy>; + +// define some global model parameters +namespace parameters { +namespace synapses { + // synapse delay + constexpr double delay = 5.0; // ms + + // connection weight + constexpr double weight = 0.05; // uS +} +} + +/////////////////////////////////////// +// prototypes +/////////////////////////////////////// + +/// make a single abstract cell +mc::cell make_cell(int compartments_per_segment); + +/// do basic setup (initialize global state, print banner, etc) +void setup(); + +/// helper function for initializing cells +cell_group make_lowered_cell(int cell_index, const mc::cell& c); + +/////////////////////////////////////// +// main +/////////////////////////////////////// +int main(void) { + + setup(); + + // read parameters + mc::io::options opt; + try { + opt = mc::io::read_options(""); + std::cout << opt << "\n"; + } + catch (std::exception e) { + std::cerr << e.what() << std::endl; + exit(1); + } + + ///////////////////////////////////////////////////// + // make cells + ///////////////////////////////////////////////////// + + // make a basic cell + auto basic_cell = make_cell(opt.compartments_per_segment); + + // make a vector for storing all of the cells + auto start_init = mc::util::timer_type::tic(); + std::vector<cell_group> cell_groups(opt.cells); + + // initialize the cells in parallel + mc::threading::parallel_for::apply( + 0, opt.cells, + [&](int i) { + // initialize cell + cell_groups[i] = make_lowered_cell(i, basic_cell); + } + ); + auto time_init = mc::util::timer_type::toc(start_init); + + ///////////////////////////////////////////////////// + // network creation + ///////////////////////////////////////////////////// + + // calculate the source and synapse distribution serially + auto start_network = mc::util::timer_type::tic(); + std::vector<uint32_t> target_counts(opt.cells); + std::vector<uint32_t> source_counts(opt.cells); + for (auto i=0; i<opt.cells; ++i) { + target_counts[i] = cell_groups[i].cell().synapses()->size(); + source_counts[i] = cell_groups[i].spike_sources().size(); + } + + auto target_map = mc::algorithms::make_index(target_counts); + auto source_map = mc::algorithms::make_index(source_counts); + + // create connections + communicator_type communicator(opt.cells, target_counts); + for(auto i=0u; i<(uint32_t)opt.cells; ++i) { + communicator.add_connection({ + i, (i+1)%opt.cells, + parameters::synapses::weight, parameters::synapses::delay + }); + } + communicator.construct(); + + auto global_source_map = + communicator.communication_policy().make_map(source_map.back()); + auto domain_idx = communicator.communication_policy().id(); + for(auto i=0u; i<(uint32_t)opt.cells; ++i) { + cell_groups[i].set_source_gids(source_map[i]+global_source_map[domain_idx]); + cell_groups[i].set_target_lids(target_map[i]); + } + + auto time_network = mc::util::timer_type::toc(start_network); + + ///////////////////////////////////////////////////// + // time stepping + ///////////////////////////////////////////////////// + auto start_simulation = mc::util::timer_type::tic(); + + auto tfinal = 20.; + auto t = 0.; + auto dt = 0.01; + auto delta = communicator.min_delay(); + + communicator.add_spike({opt.cells-1u, 5}); + + while(t<tfinal) { + mc::threading::parallel_for::apply( + 0, opt.cells, + [&](int i) { + /*if(communicator.queue(i).size()) { + std::cout << ":: delivering events to group " << i << "\n"; + std::cout << " " << communicator.queue(i) << "\n"; + }*/ + cell_groups[i].enqueue_events(communicator.queue(i)); + cell_groups[i].advance(t+delta, dt); + communicator.add_spikes(cell_groups[i].spikes()); + cell_groups[i].clear_spikes(); + } + ); + + communicator.exchange(); + + t += delta; + } + + for(auto i=0u; i<cell_groups.size(); ++i) { + cell_groups[i].splat("cell"+std::to_string(i)+".txt"); + } + + auto time_simulation = mc::util::timer_type::toc(start_simulation); + + std::cout << "initialization took " << time_init << " s\n"; + std::cout << "network took " << time_network << " s\n"; + std::cout << "simulation took " << time_simulation << " s\n"; + std::cout << "performed " << int(tfinal/dt) << " time steps\n"; +} + +/////////////////////////////////////// +// function definitions +/////////////////////////////////////// + +void setup() +{ + // print banner + std::cout << "====================\n"; + std::cout << " starting miniapp\n"; + std::cout << " - " << mc::threading::description() << " threading support\n"; + std::cout << "====================\n"; + + // setup global state for the mechanisms + mc::mechanisms::setup_mechanism_helpers(); +} + +// make a high level cell description for use in simulation +mc::cell make_cell(int compartments_per_segment) +{ + nest::mc::cell cell; + + // Soma with diameter 12.6157 um and HH channel + auto soma = cell.add_soma(12.6157/2.0); + soma->add_mechanism(mc::hh_parameters()); + + // add dendrite of length 200 um and diameter 1 um with passive channel + std::vector<mc::cable_segment*> dendrites; + dendrites.push_back(cell.add_cable(0, mc::segmentKind::dendrite, 0.5, 0.5, 200)); + //dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100)); + //dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100)); + + for(auto d : dendrites) { + d->add_mechanism(mc::pas_parameters()); + d->set_compartments(compartments_per_segment); + d->mechanism("membrane").set("r_L", 100); + } + + // add stimulus + //cell.add_stimulus({1,1}, {5., 80., 0.3}); + + cell.add_detector({0,0}, 30); + cell.add_synapse({1, 0.5}); + + return cell; +} + +cell_group make_lowered_cell(int cell_index, const mc::cell& c) +{ + return cell_group(c); +} + diff --git a/miniapp/plot.py b/miniapp/plot.py new file mode 100644 index 00000000..ced8f4f0 --- /dev/null +++ b/miniapp/plot.py @@ -0,0 +1,22 @@ +from matplotlib import pyplot +import numpy as np + +ncol = 3 + +raw = np.fromfile("cell0.txt", sep=" ") +n = raw.size/ncol +data = raw.reshape(n,ncol) + +t = data[:, 0] +soma = data[:, 1] +dend = data[:, 2] + +pyplot.plot(t, soma, 'k') +pyplot.plot(t, dend, 'r') + +pyplot.xlabel('time (ms)') +pyplot.ylabel('mV') +pyplot.xlim([t[0], t[n-1]]) +pyplot.grid() +pyplot.show() + diff --git a/src/cell.cpp b/src/cell.cpp index 68783157..fb9147ed 100644 --- a/src/cell.cpp +++ b/src/cell.cpp @@ -177,7 +177,7 @@ compartment_model cell::model() const } -void cell::add_stimulus( segment_location loc, i_clamp stim) +void cell::add_stimulus(segment_location loc, i_clamp stim) { if(!(loc.segment<num_segments())) { throw std::out_of_range( @@ -190,6 +190,11 @@ void cell::add_stimulus( segment_location loc, i_clamp stim) stimulii_.push_back({loc, std::move(stim)}); } +void cell::add_detector(segment_location loc, double threshold) +{ + spike_detectors_.push_back({loc, threshold}); +} + std::vector<int> const& cell::segment_parents() const { return parents_; diff --git a/src/cell.hpp b/src/cell.hpp index 2f336f84..0e13e588 100644 --- a/src/cell.hpp +++ b/src/cell.hpp @@ -26,6 +26,9 @@ struct segment_location { { EXPECTS(position>=0. && position<=1.); } + friend bool operator==(segment_location l, segment_location r) { + return l.segment==r.segment && l.position==r.position; + } int segment; double position; }; @@ -99,6 +102,9 @@ class cell { compartment_model model() const; + ////////////////// + // stimulii + ////////////////// void add_stimulus(segment_location loc, i_clamp stim); std::vector<std::pair<segment_location, i_clamp>>& @@ -111,10 +117,27 @@ class cell { return stimulii_; } + ////////////////// + // synapses + ////////////////// void add_synapse(segment_location loc); const std::vector<segment_location>& synapses() const; + ////////////////// + // spike detectors + ////////////////// + void add_detector(segment_location loc, double threshold); + + std::vector<std::pair<segment_location, double>>& + detectors() { + return spike_detectors_; + } + + const std::vector<std::pair<segment_location, double>>& + detectors() const { + return spike_detectors_; + } private: @@ -129,6 +152,9 @@ class cell { // the synapses std::vector<segment_location> synapses_; + + // the sensors + std::vector<std::pair<segment_location, double>> spike_detectors_; }; // Checks that two cells have the same diff --git a/src/cell_group.hpp b/src/cell_group.hpp new file mode 100644 index 00000000..d9ea4551 --- /dev/null +++ b/src/cell_group.hpp @@ -0,0 +1,137 @@ +#pragma once + +#include <vector> + +#include <cell.hpp> +#include <event_queue.hpp> +#include <communication/spike.hpp> +#include <communication/spike_source.hpp> + +namespace nest { +namespace mc { + +template <typename Cell> +class cell_group { + public : + + using index_type = uint32_t; + using cell_type = Cell; + using value_type = typename cell_type::value_type; + using size_type = typename cell_type::value_type; + using spike_detector_type = spike_detector<Cell>; + + struct spike_source_type { + index_type index; + spike_detector_type source; + }; + + cell_group() = default; + + cell_group(const cell& c) : + cell_{c} + { + cell_.voltage()(memory::all) = -65.; + cell_.initialize(); + + for(auto& d : c.detectors()) { + spike_sources_.push_back( { + 0u, spike_detector_type(cell_, d.first, d.second, 0.f) + }); + } + } + + void set_source_gids(index_type gid) { + for(auto& s : spike_sources_) { + s.index = gid++; + } + } + + void set_target_lids(index_type lid) { + first_target_lid_ = lid; + } + + void splat(std::string fname) { + char buffer[128]; + auto fid = std::ofstream(fname); + for(auto i=0u; i<tt.size(); ++i) { + sprintf(buffer, "%8.4f %16.8f %16.8f\n", tt[i], vs[i], vd[i]); + fid << buffer; + } + } + + void advance(double tfinal, double dt) { + + while (cell_.time()<tfinal) { + tt.push_back(cell_.time()); + vs.push_back(cell_.voltage({0,0.0})); + vd.push_back(cell_.voltage({1,0.5})); + + // look for events in the next time step + auto tstep = std::min(tfinal, cell_.time()+dt); + auto next = events_.pop_if_before(tstep); + auto tnext = next ? next->time: tstep; + + // integrate cell state + cell_.advance(tnext - cell_.time()); + + // check for new spikes + for (auto& s : spike_sources_) { + auto spike = s.source.test(cell_, cell_.time()); + if(spike) { + spikes_.push_back({s.index, spike.get()}); + } + } + + // apply events + if (next) { + cell_.apply_event(next.get()); + } + } + + } + + template <typename R> + void enqueue_events(R events) { + for(auto e : events) { + e.target -= first_target_lid_; + events_.push(e); + } + } + + const std::vector<communication::spike<index_type>>& + spikes() const { + return spikes_; + } + + cell_type& cell() { return cell_; } + const cell_type& cell() const { return cell_; } + + const std::vector<spike_source_type>& + spike_sources() const + { + return spike_sources_; + } + + void clear_spikes() { + spikes_.clear(); + } + + private : + + // TEMPORARY... + std::vector<float> tt; + std::vector<float> vs; + std::vector<float> vd; + + cell_type cell_; + std::vector<spike_source_type> spike_sources_; + + // spikes that are generated + std::vector<communication::spike<index_type>> spikes_; + event_queue events_; + + index_type first_target_lid_; +}; + +} // namespace mc +} // namespace nest diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp new file mode 100644 index 00000000..1951143c --- /dev/null +++ b/src/communication/communicator.hpp @@ -0,0 +1,253 @@ +#pragma once + +#include <algorithm> +#include <iostream> +#include <vector> +#include <random> + +#include <communication/spike.hpp> +#include <threading/threading.hpp> +#include <algorithms.hpp> +#include <event_queue.hpp> + + +#include "connection.hpp" + +namespace nest { +namespace mc { +namespace communication { + +// When the communicator is constructed the number of target groups and targets +// is specified, along with a mapping between local cell id and local +// target id. +// +// The user can add connections to an existing communicator object, where +// each connection is between any global cell and any local target. +// +// Once all connections have been specified, the construct() method can be used +// to build the data structures required for efficient spike communication and +// event generation. +template <typename CommunicationPolicy> +class communicator { +public: + using index_type = uint32_t; + using communication_policy_type = CommunicationPolicy; + + using spike_type = spike<index_type>; + + communicator( index_type n_groups, std::vector<index_type> target_counts) : + num_groups_local_(n_groups), + num_targets_local_(target_counts.size()) + { + communicator_id_ = communication_policy_.id(); + + target_map_ = nest::mc::algorithms::make_index(target_counts); + num_targets_local_ = target_map_.back(); + + // create an event queue for each target group + events_.resize(num_groups_local_); + + // make maps for converting lid to gid + target_gid_map_ = communication_policy_.make_map(num_targets_local_); + group_gid_map_ = communication_policy_.make_map(num_groups_local_); + + // transform the target ids from lid to gid + auto first_target = target_gid_map_[communicator_id_]; + for(auto &id : target_map_) { + id += first_target; + } + } + + void add_connection(connection con) { + EXPECTS(is_local_target(con.destination())); + connections_.push_back(con); + } + + bool is_local_target(index_type gid) { + return gid>=target_gid_map_[communicator_id_] + && gid<target_gid_map_[communicator_id_+1]; + } + + bool is_local_group(index_type gid) { + return gid>=group_gid_map_[communicator_id_] + && gid<group_gid_map_[communicator_id_+1]; + } + + /// return the global id of the first group in domain d + /// the groups in domain d are in the contiguous half open range + /// [domain_first_group(d), domain_first_group(d+1)) + index_type group_gid_first(int d) const { + return group_gid_map_[d]; + } + + index_type target_lid(index_type gid) { + EXPECTS(is_local_group(gid)); + + return gid - target_gid_map_[communicator_id_]; + } + + // builds the optimized data structure + void construct() { + if(!std::is_sorted(connections_.begin(), connections_.end())) { + std::sort(connections_.begin(), connections_.end()); + } + } + + float min_delay() { + auto local_min = std::numeric_limits<float>::max(); + for(auto& con : connections_) { + local_min = std::min(local_min, con.delay()); + } + + return communication_policy_.min(local_min); + } + + // return the local group index of the group which hosts the target with + // global id gid + index_type local_group_from_global_target(index_type gid) { + // assert that gid is in range + EXPECTS(is_local_target(gid)); + + return + std::distance( + target_map_.begin(), + std::upper_bound(target_map_.begin(), target_map_.end(), gid) + ) - 1; + } + + void add_spike(spike_type s) { + thread_spikes().push_back(s); + } + + void add_spikes(const std::vector<spike_type>& s) { + auto& v = thread_spikes(); + v.insert(v.end(), s.begin(), s.end()); + } + + std::vector<spike_type>& thread_spikes() { + return thread_spikes_.local(); + } + + void exchange() { + + // global all-to-all to gather a local copy of the global spike list + // on each node + //profiler_.enter("global exchange"); + auto global_spikes = communication_policy_.gather_spikes(local_spikes()); + clear_thread_spike_buffers(); + //profiler_.leave(); + + for(auto& q : events_) { + q.clear(); + } + + //profiler_.enter("events"); + + //profiler_.enter("make events"); + // check all global spikes to see if they will generate local events + for(auto spike : global_spikes) { + // search for targets + auto targets = + std::equal_range( + connections_.begin(), connections_.end(), spike.source + ); + + // generate an event for each target + for(auto it=targets.first; it!=targets.second; ++it) { + auto gidx = local_group_from_global_target(it->destination()); + + events_[gidx].push_back(it->make_event(spike)); + } + } + + //profiler_.leave(); // make events + + //profiler_.leave(); // event generation + } + + const std::vector<local_event>& queue(int i) + { + return events_[i]; + } + + std::vector<connection>const& connections() const { + return connections_; + } + + communication_policy_type& communication_policy() { + return communication_policy_; + } + + const std::vector<index_type>& local_target_map() const { + return target_map_; + } + + std::vector<spike_type> local_spikes() { + std::vector<spike_type> spikes; + for(auto& v : thread_spikes_) { + spikes.insert(spikes.end(), v.begin(), v.end()); + } + return spikes; + } + + void clear_thread_spike_buffers() { + for(auto& v : thread_spikes_) { + v.clear(); + } + } + +private: + + // FIXME : race condition on the thread_spikes_ buffers when exchange() modifies/access them + // ... other threads will be pushing to them simultaneously + // FIXME : race condition on the group-specific event queues when exchange pusheds to them + // ... other threads will be accessing them to get events + + // thread private storage for accumulating spikes + using local_spike_store_type = + nest::mc::threading:: enumerable_thread_specific<std::vector<spike_type>>; + /* +#ifdef WITH_TBB + using local_spike_store_type = + tbb::enumerable_thread_specific<std::vector<spike_type>>; +#else + using local_spike_store_type = + std::array<std::vector<spike_type>, 1>; +#endif + */ + local_spike_store_type thread_spikes_; + + std::vector<connection> connections_; + std::vector<std::vector<nest::mc::local_event>> events_; + + // local target group i has targets in the half open range + // [target_map_[i], target_map_[i+1]) + std::vector<index_type> target_map_; + + // for keeping track of how time is spent where + //util::Profiler profiler_; + + // the number of groups and targets handled by this communicator + index_type num_groups_local_; + index_type num_targets_local_; + + // index maps for the global distribution of groups and targets + + // communicator i has the groups in the half open range : + // [group_gid_map_[i], group_gid_map_[i+1]) + std::vector<index_type> group_gid_map_; + + // communicator i has the targets in the half open range : + // [target_gid_map_[i], target_gid_map_[i+1]) + std::vector<index_type> target_gid_map_; + + // each communicator has a unique id + // e.g. for MPI this could be the MPI rank + index_type communicator_id_; + + communication_policy_type communication_policy_; +}; + +} // namespace communication +} // namespace mc +} // namespace nest diff --git a/src/communication/connection.hpp b/src/communication/connection.hpp new file mode 100644 index 00000000..7519a286 --- /dev/null +++ b/src/communication/connection.hpp @@ -0,0 +1,78 @@ +#pragma once + +#include <cstdint> + +#include <event_queue.hpp> +#include <communication/spike.hpp> + +namespace nest { +namespace mc { +namespace communication { + +class connection { +public: + using index_type = uint32_t; + connection(index_type src, index_type dest, float w, float d) + : source_(src), + destination_(dest), + weight_(w), + delay_(d) + { } + + float weight() const { + return weight_; + } + float delay() const { + return delay_; + } + + index_type source() const { + return source_; + } + index_type destination() const { + return destination_; + } + + local_event make_event(spike<index_type> s) { + return {destination_, s.time + delay_, weight_}; + } + +private: + + index_type source_; + index_type destination_; + float weight_; + float delay_; +}; + +// connections are sorted by source id +// these operators make for easy interopability with STL algorithms + +static inline +bool operator< (connection lhs, connection rhs) { + return lhs.source() < rhs.source(); +} + +static inline +bool operator< (connection lhs, connection::index_type rhs) { + return lhs.source() < rhs; +} + +static inline +bool operator< (connection::index_type lhs, connection rhs) { + return lhs < rhs.source(); +} + +} // namespace communication +} // namespace mc +} // namespace nest + +static inline +std::ostream& operator<<(std::ostream& o, nest::mc::communication::connection const& con) { + char buff[512]; + snprintf( + buff, sizeof(buff), "con [%10u -> %10u : weight %8.4f, delay %8.4f]", + con.source(), con.destination(), con.weight(), con.delay() + ); + return o << buff; +} diff --git a/src/communication/mpi.hpp b/src/communication/mpi.hpp new file mode 100644 index 00000000..398e6033 --- /dev/null +++ b/src/communication/mpi.hpp @@ -0,0 +1,155 @@ +#pragma once + +#include <algorithm> +#include <iostream> +#include <vector> + +#include <cassert> + +#include <mpi.h> +#include "utils.hpp" +#include "utils.hpp" +namespace mpi { + + template <typename T> + struct mpi_traits { + constexpr static size_t count() { + return sizeof(T); + } + constexpr static MPI_Datatype mpi_type() { + return MPI_CHAR; + } + constexpr static bool is_mpi_native_type() { + return false; + } + }; + + #define MAKE_TRAITS(T,M) \ + template <> \ + struct mpi_traits<T> { \ + constexpr static size_t count() { return 1; } \ + constexpr static MPI_Datatype mpi_type() { return M; } \ + constexpr static bool is_mpi_native_type() { return true; } \ + }; + + MAKE_TRAITS(double, MPI_DOUBLE) + MAKE_TRAITS(float, MPI_FLOAT) + MAKE_TRAITS(int, MPI_INT) + MAKE_TRAITS(long int, MPI_LONG) + MAKE_TRAITS(char, MPI_CHAR) + MAKE_TRAITS(size_t, MPI_UNSIGNED_LONG) + static_assert(sizeof(size_t)==sizeof(unsigned long), + "size_t and unsigned long are not equivalent"); + + bool init(int *argc, char ***argv); + bool finalize(); + bool is_root(); + int rank(); + int size(); + void barrier(); + + template <typename T> + std::vector<T> gather(T value, int root) { + using traits = mpi_traits<T>; + auto buffer_size = (rank()==root) ? size() : 0; + std::vector<T> buffer(buffer_size); + + MPI_Gather( &value, traits::count(), traits::mpi_type(), // send buffer + buffer.data(), traits::count(), traits::mpi_type(), // receive buffer + root, MPI_COMM_WORLD); + + return buffer; + } + + template <typename T> + std::vector<T> gather_all(T value) { + using traits = mpi_traits<T>; + std::vector<T> buffer(size()); + + MPI_Allgather( &value, traits::count(), traits::mpi_type(), // send buffer + buffer.data(), traits::count(), traits::mpi_type(), // receive buffer + MPI_COMM_WORLD); + + return buffer; + } + + template <typename T> + std::vector<T> gather_all(const std::vector<T> &values) { + using traits = mpi_traits<T>; + auto counts = gather_all(int(values.size())); + for(auto& c : counts) { + c *= traits::count(); + } + auto displs = algorithms::make_map(counts); + + std::vector<T> buffer(displs.back()/traits::count()); + + MPI_Allgatherv( + // send buffer + values.data(), counts[rank()], traits::mpi_type(), + // receive buffer + buffer.data(), counts.data(), displs.data(), traits::mpi_type(), + MPI_COMM_WORLD + ); + + return buffer; + } + + template <typename T> + T reduce(T value, MPI_Op op, int root) { + using traits = mpi_traits<T>; + static_assert(traits::is_mpi_native_type(), + "can only perform reductions on MPI native types"); + + T result; + + MPI_Reduce(&value, &result, 1, traits::mpi_type(), op, root, MPI_COMM_WORLD); + + return result; + } + + template <typename T> + T reduce(T value, MPI_Op op) { + using traits = mpi_traits<T>; + static_assert(traits::is_mpi_native_type(), + "can only perform reductions on MPI native types"); + + T result; + + MPI_Allreduce(&value, &result, 1, traits::mpi_type(), op, MPI_COMM_WORLD); + + return result; + } + + template <typename T> + std::pair<T,T> minmax(T value) { + return {reduce<T>(value, MPI_MIN), reduce<T>(value, MPI_MAX)}; + } + + template <typename T> + std::pair<T,T> minmax(T value, int root) { + return {reduce<T>(value, MPI_MIN, root), reduce<T>(value, MPI_MAX, root)}; + } + + bool ballot(bool vote); + + template <typename T> + T broadcast(T value, int root) { + using traits = mpi_traits<T>; + + MPI_Bcast(&value, traits::count(), traits::mpi_type(), root, MPI_COMM_WORLD); + + return value; + } + + template <typename T> + T broadcast(int root) { + using traits = mpi_traits<T>; + T value; + + MPI_Bcast(&value, traits::count(), traits::mpi_type(), root, MPI_COMM_WORLD); + + return value; + } + +} // namespace mpi diff --git a/src/communication/mpi_global_policy.hpp b/src/communication/mpi_global_policy.hpp new file mode 100644 index 00000000..43d830b5 --- /dev/null +++ b/src/communication/mpi_global_policy.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include <type_traits> +#include <vector> + +#include <cstdint> + +#include <communication/spike.hpp> +#include <algorithms.hpp> + +#include "mpi.hpp" + +namespace nest { +namespace mc { +namespace communication { + +struct mpi_global_policy { + std::vector<spike<uint32_t>> const + gather_spikes(const std::vector<spike<uint32_t>>& local_spikes) { + return mpi::gather_all(local_spikes); + } + + int id() const { + return mpi::rank(); + } + + /* + template <typename T> + T min(T value) const { + } + */ + + int num_communicators() const { + return mpi::size(); + } + + template < + typename T, + typename = typename std::enable_if<std::is_integral<T>::value> + > + std::vector<T> make_map(T local) { + return algorithms::make_index(mpi::gather_all(local)); + } +}; + +} // namespace communication +} // namespace mc +} // namespace nest diff --git a/src/communication/serial_global_policy.hpp b/src/communication/serial_global_policy.hpp new file mode 100644 index 00000000..94c2c430 --- /dev/null +++ b/src/communication/serial_global_policy.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include <type_traits> +#include <vector> + +#include <cstdint> + +#include <communication/spike.hpp> + +namespace nest { +namespace mc { +namespace communication { + +struct serial_global_policy { + std::vector<spike<uint32_t>> const + gather_spikes(const std::vector<spike<uint32_t>>& local_spikes) { + return local_spikes; + } + + static int id() { + return 0; + } + + static int num_communicators() { + return 1; + } + + template <typename T> + T min(T value) const { + return value; + } + + template <typename T> + T max(T value) const { + return value; + } + + template < + typename T, + typename = typename std::enable_if<std::is_integral<T>::value> + > + std::vector<T> make_map(T local) { + return {T(0), local}; + } +}; + +} // namespace communication +} // namespace mc +} // namespace nest diff --git a/src/spike.hpp b/src/communication/spike.hpp similarity index 76% rename from src/spike.hpp rename to src/communication/spike.hpp index cd250ba5..09c66522 100644 --- a/src/spike.hpp +++ b/src/communication/spike.hpp @@ -5,6 +5,7 @@ namespace nest { namespace mc { +namespace communication { template < typename I, @@ -24,17 +25,21 @@ struct spike { } // namespace mc } // namespace nest +} // namespace communication /// custom stream operator for printing nest::mc::spike<> values template <typename I> -std::ostream& operator <<(std::ostream& o, nest::mc::spike<I> s) { +std::ostream& operator <<(std::ostream& o, nest::mc::communication::spike<I> s) { return o << "spike[t " << s.time << ", src " << s.source << "]"; } /// less than comparison operator for nest::mc::spike<> values /// spikes are ordered by spike time, for use in sorting and queueing template <typename I> -bool operator <(nest::mc::spike<I> lhs, nest::mc::spike<I> rhs) { +bool operator <( + nest::mc::communication::spike<I> lhs, + nest::mc::communication::spike<I> rhs) +{ return lhs.time < rhs.time; } diff --git a/src/spike_source.hpp b/src/communication/spike_source.hpp similarity index 77% rename from src/spike_source.hpp rename to src/communication/spike_source.hpp index 73c54995..d3258b7c 100644 --- a/src/spike_source.hpp +++ b/src/communication/spike_source.hpp @@ -1,44 +1,39 @@ #pragma once #include <cell.hpp> -#include <fvm_cell.hpp> #include <util/optional.hpp> namespace nest { namespace mc { -// generic spike source -class spike_source { - public: - - virtual util::optional<float> test(float t) = 0; -}; - // spike detector for a lowered cell template <typename Cell> -class spike_detector : public spike_source +class spike_detector { - public: +public: using cell_type = Cell; spike_detector( - cell_type const* cell, + const cell_type& cell, segment_location loc, double thresh, float t_init ) - : cell_(cell), - location_(loc), - threshold_(thresh), - previous_t_(t_init) + : location_(loc), + threshold_(thresh), + previous_t_(t_init) { - previous_v_ = cell->voltage(location_); + previous_v_ = cell.voltage(location_); is_spiking_ = previous_v_ >= thresh ? true : false; } - util::optional<float> test(float t) override { + util::optional<float> test(const cell_type& cell, float t) + { util::optional<float> result = util::nothing; - auto v = cell_->voltage(location_); + auto v = cell.voltage(location_); + + // these if statements could be simplified, but I keep them like + // this to clearly reflect the finite state machine if (!is_spiking_) { if (v>=threshold_) { // the threshold has been passed, so estimate the time using @@ -69,10 +64,18 @@ class spike_detector : public spike_source return location_; } - private: + float t() const { + return previous_t_; + } + + float v() const { + return previous_v_; + } + +private: // parameters/data - cell_type* cell_; + //const cell_type* cell_; segment_location location_; double threshold_; @@ -82,6 +85,7 @@ class spike_detector : public spike_source bool is_spiking_; }; +/* // spike generator according to a Poisson process class poisson_generator : public spike_source { @@ -111,6 +115,7 @@ class poisson_generator : public spike_source float firing_rate_; float previous_t_; }; +*/ } // namespace mc diff --git a/src/fvm_cell.hpp b/src/fvm_cell.hpp index c21ed5b3..4bd21571 100644 --- a/src/fvm_cell.hpp +++ b/src/fvm_cell.hpp @@ -29,6 +29,8 @@ template <typename T, typename I> class fvm_cell { public : + fvm_cell() = default; + /// the real number type using value_type = T; /// the integral index type @@ -83,54 +85,30 @@ class fvm_cell { } /// return the voltage in each CV - vector_view voltage() { - return voltage_; - } - const_vector_view voltage() const { - return voltage_; - } + vector_view voltage() { return voltage_; } + const_vector_view voltage() const { return voltage_; } - std::size_t size() const { - return matrix_.size(); - } + std::size_t size() const { return matrix_.size(); } /// return reference to in iterable container of the mechanisms - std::vector<mechanism_type>& mechanisms() { - return mechanisms_; - } + std::vector<mechanism_type>& mechanisms() { return mechanisms_; } /// return reference to list of ions //std::map<mechanisms::ionKind, ion_type> ions_; - std::map<mechanisms::ionKind, ion_type>& ions() { - return ions_; - } - std::map<mechanisms::ionKind, ion_type> const& ions() const { - return ions_; - } + std::map<mechanisms::ionKind, ion_type>& ions() { return ions_; } + std::map<mechanisms::ionKind, ion_type> const& ions() const { return ions_; } /// return reference to sodium ion - ion_type& ion_na() { - return ions_[mechanisms::ionKind::na]; - } - ion_type const& ion_na() const { - return ions_[mechanisms::ionKind::na]; - } + ion_type& ion_na() { return ions_[mechanisms::ionKind::na]; } + ion_type const& ion_na() const { return ions_[mechanisms::ionKind::na]; } /// return reference to calcium ion - ion_type& ion_ca() { - return ions_[mechanisms::ionKind::ca]; - } - ion_type const& ion_ca() const { - return ions_[mechanisms::ionKind::ca]; - } + ion_type& ion_ca() { return ions_[mechanisms::ionKind::ca]; } + ion_type const& ion_ca() const { return ions_[mechanisms::ionKind::ca]; } /// return reference to pottasium ion - ion_type& ion_k() { - return ions_[mechanisms::ionKind::k]; - } - ion_type const& ion_k() const { - return ions_[mechanisms::ionKind::k]; - } + ion_type& ion_k() { return ions_[mechanisms::ionKind::k]; } + ion_type const& ion_k() const { return ions_[mechanisms::ionKind::k]; } /// make a time step void advance(value_type dt); @@ -138,24 +116,25 @@ class fvm_cell { /// advance solution to target time tfinal with maximum step size dt void advance_to(value_type tfinal, value_type dt); - /// set initial states - void initialize(); + /// pass an event to the appropriate synapse and call net_receive + void apply_event(local_event e) { + mechanisms_[synapse_index_]->net_receive(e.target, e.weight); + } - event_queue& queue() { - return events_; + mechanism_type& synapses() { + return mechanisms_[synapse_index_]; } - // returns the compartment index of a segment location - int compartment_index(segment_location loc) { - EXPECTS(loc.segment < segment_index_.size()); + /// set initial states + void initialize(); - const auto seg = loc.segment; + /// returns the compartment index of a segment location + int compartment_index(segment_location loc) const; - auto first = segment_index_[seg]; - auto n = segment_index_[seg+1] - first; - auto index = std::floor(n*loc.position); - return index<n ? first+index : first+n-1; - } + /// returns voltage at a segment location + value_type voltage(segment_location loc) const; + + value_type time() const { return t_; } private: @@ -466,13 +445,30 @@ void fvm_cell<T, I>::setup_matrix(T dt) rhs[i] = cv_areas_[i]*(voltage_[i] - factor/cv_capacitance_[i]*current_[i]); } } +template <typename T, typename I> +int fvm_cell<T, I>::compartment_index(segment_location loc) const +{ + EXPECTS(loc.segment < segment_index_.size()); + + const auto seg = loc.segment; + + auto first = segment_index_[seg]; + auto n = segment_index_[seg+1] - first; + auto index = std::floor(n*loc.position); + return index<n ? first+index : first+n-1; +} + +template <typename T, typename I> +T fvm_cell<T, I>::voltage(segment_location loc) const +{ + return voltage_[compartment_index(loc)]; +} template <typename T, typename I> void fvm_cell<T, I>::initialize() { t_ = 0.; - // initialize mechanism states for(auto& m : mechanisms_) { m->nrn_init(); } @@ -500,15 +496,12 @@ void fvm_cell<T, I>::advance(T dt) current_[loc] -= 100.*ie/cv_areas_[loc]; } - // set matrix diagonals and rhs - setup_matrix(dt); - // solve the linear system + setup_matrix(dt); matrix_.solve(); - voltage_(all) = matrix_.rhs(); - // update states + // integrate state of gating variables etc. for(auto& m : mechanisms_) { m->nrn_state(); } @@ -516,6 +509,7 @@ void fvm_cell<T, I>::advance(T dt) t_ += dt; } +/* template <typename T, typename I> void fvm_cell<T, I>::advance_to(T tfinal, T dt) { @@ -535,6 +529,7 @@ void fvm_cell<T, I>::advance_to(T tfinal, T dt) } } while(t_<tfinal); } +*/ } // namespace fvm } // namespace mc diff --git a/src/lowered_cell.hpp b/src/lowered_cell.hpp deleted file mode 100644 index 07137655..00000000 --- a/src/lowered_cell.hpp +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -namespace nest { -namespace mc { - -template <typename Cell> -class lowered_cell { - public : - - using cell_type = Cell; - using value_type = typename cell_type::value_type; - using size_type = typename cell_type::value_type; - - private : - - cell_type cell_; -}; - -} // namespace mc -} // namespace nest diff --git a/src/profiling/profiler.hpp b/src/profiling/profiler.hpp new file mode 100644 index 00000000..11fbfb4f --- /dev/null +++ b/src/profiling/profiler.hpp @@ -0,0 +1,397 @@ +#pragma once + +#include <algorithm> +#include <unordered_map> +#include <map> +#include <memory> +#include <stdexcept> +#include <fstream> +#include <iostream> +#include <vector> + +#include <cassert> +#include <cstdlib> + +#include <threading/threading.hpp> + +namespace nest { +namespace mc { +namespace util { + +static inline std::string green(std::string s) { return s; } +static inline std::string yellow(std::string s) { return s; } +static inline std::string white(std::string s) { return s; } +static inline std::string red(std::string s) { return s; } +static inline std::string cyan(std::string s) { return s; } + +namespace impl { + + static inline + size_t hash(std::string const& s) + { + size_t h = 5381; + for(auto c: s) { + h = ((h << 5) + h) + int(c); + } + return h; + } + + static inline + size_t hash(char* s) + { + size_t h = 5381; + + while(*s) { + h = ((h << 5) + h) + int(*s); + ++s; + } + return h; + } + + struct profiler_node { + double value; + std::string name; + std::vector<profiler_node> children; + + profiler_node() + : value(0.), name("") + {} + + profiler_node(double v, std::string const& n) + : value(v), name(n) + {} + + void print(int indent=0) + { + std::string s = std::string(indent, ' ') + name; + std::cout << s + << std::string(60-s.size(), '.') + << value + << "\n"; + for(auto &n: children) { + n.print(indent+2); + } + } + + friend profiler_node operator +(profiler_node const& lhs, profiler_node const& rhs) + { + assert(lhs.name == rhs.name); + auto node = lhs; + node.fuse(rhs); + return node; + } + + friend bool operator ==(profiler_node const& lhs, profiler_node const& rhs) + { + return lhs.name == rhs.name; + } + + void print(std::ostream& stream, double threshold) + { + // convert threshold from proportion to time + threshold *= value; + print_sub(stream, 0, threshold, value); + } + + void print_sub(std::ostream& stream, + int indent, + double threshold, + double total) + { + char buffer[512]; + + if(value < threshold) { + std::cout << green("not printing ") << name << std::endl; + return; + } + + auto max_contribution = + std::accumulate( + children.begin(), children.end(), -1., + [] (double lhs, profiler_node const& rhs) { + return lhs > rhs.value ? lhs : rhs.value; + } + ); + + // print the table row + auto const indent_str = std::string(indent, ' '); + auto label = indent_str + name; + float percentage = 100.*value/total; + snprintf(buffer, sizeof(buffer), "%-25s%10.3f%10.1f", + label.c_str(), + float(value), + float(percentage)); + bool print_children = + threshold==0. ? children.size()>0 + : max_contribution >= threshold; + + if(print_children) { + stream << white(buffer) << std::endl; + } + else { + stream << buffer << std::endl; + } + + if(print_children) { + auto other = 0.; + for(auto &n : children) { + if(n.value<threshold || n.name=="other") { + other += n.value; + } + else { + n.print_sub(stream, indent + 2, threshold, total); + } + } + if(other >= threshold && children.size()) { + label = indent_str + " other"; + percentage = 100.*other/total; + snprintf(buffer, sizeof(buffer), "%-25s%10.3f%10.1f", + label.c_str(), float(other), percentage); + stream << buffer << std::endl; + } + } + } + + void fuse(profiler_node const& other) + { + for(auto const& n : other.children) { + // linear search isn't ideal... + auto const it = std::find(children.begin(), children.end(), n); + if(it!=children.end()) { + (*it).fuse(n); + } + else { + children.push_back(n); + } + } + + value += other.value; + } + + }; + + +} // namespace impl + +using timer_type = nest::mc::threading::timer; + +// a region in the profiler, has +// - name +// - accumulated timer +// - nested sub-regions +class region_type { + region_type *parent_ = nullptr; + std::string name_; + size_t hash_; + std::unordered_map< + size_t, + std::unique_ptr<region_type> + > subregions_; + timer_type::time_point start_time_; + double total_time_ = 0; + +public: + + using profiler_node = impl::profiler_node; + + explicit region_type(std::string const& n) + : name_(n) + { + start_time_ = timer_type::tic(); + hash_ = impl::hash(n); + } + + + explicit region_type(const char* n) + : region_type(std::string(n)) + {} + + std::string const& name() const { + return name_; + } + + void name(std::string const& n) { + name_ = n; + } + + region_type* parent() { + return parent_; + } + + void start_time() { start_time_ = timer_type::tic(); } + void end_time () { total_time_ += timer_type::toc(start_time_); } + + region_type(std::string const& n, region_type* p) + : region_type(n) + { + parent_ = p; + } + + bool has_subregions() const { + return subregions_.size() > 0; + } + + size_t hash () const { + return hash_; + } + + region_type* subregion(const char* n) + { + size_t hsh = impl::hash(n); + auto s = subregions_.find(hsh); + if(s == subregions_.end()) { + subregions_[hsh] = util::make_unique<region_type>(n, this); + return subregions_[hsh].get(); + } + return s->second.get(); + } + + double subregion_contributions() const + { + return + std::accumulate( + subregions_.begin(), subregions_.end(), 0., + [](double l, decltype(*(subregions_.begin())) r) { + return l+r.second->total(); + } + ); + } + + double total() const + { + return total_time_; + } + + profiler_node populate_performance_tree() const { + profiler_node tree(total(), name()); + + for(auto &it : subregions_) { + tree.children.push_back(it.second->populate_performance_tree()); + } + + // sort the contributions in descending order + std::stable_sort( + tree.children.begin(), tree.children.end(), + [](profiler_node const& lhs, profiler_node const& rhs) { + return lhs.value>rhs.value; + } + ); + + if(tree.children.size()) { + // find the contribution of parts of the code that were not explicitly profiled + auto contributions = + std::accumulate( + tree.children.begin(), tree.children.end(), 0., + [](double v, profiler_node& n) { + return v+n.value; + } + ); + auto other = total() - contributions; + + // add the "other" category + tree.children.emplace_back(other, std::string("other")); + } + + return tree; + } +}; + +class Profiler { +public: + Profiler(std::string const& name) + : root_region_(name) + { } + + // the copy constructor doesn't do a "deep copy" + // it simply creates a new Profiler with the same name + // This is needed for tbb to create a list of thread local profilers + Profiler(Profiler const& other) + : Profiler(other.root_region_.name()) + {} + + void enter(const char* name) + { + if(!is_activated()) return; + auto start = timer_type::tic(); + current_region_ = current_region_->subregion(name); + current_region_->start_time(); + self_time_ += timer_type::toc(start); + } + + void leave() + { + if(!is_activated()) return; + auto start = timer_type::tic(); + if(current_region_->parent()==nullptr) { + std::cout << "error" << std::endl; + throw std::out_of_range("attempt to leave root memory tracing region"); + } + current_region_->end_time(); + current_region_ = current_region_->parent(); + self_time_ += timer_type::toc(start); + } + + region_type& regions() + { + return root_region_; + } + + region_type* current_region() + { + return current_region_; + } + + double self_time() const + { + return self_time_; + } + + bool is_in_root() const + { + return &root_region_ == current_region_; + } + + bool is_activated() const { + return activated_; + } + + void start() { + if(is_activated()) { + throw std::out_of_range( + "attempt to start an already running profiler" + ); + } + activate(); + root_region_.start_time(); + } + + void stop() { + if(!is_in_root()) { + throw std::out_of_range( + "attempt to profiler that is not in the root region" + ); + } + root_region_.end_time(); + disactivate(); + } + + region_type::profiler_node performance_tree() { + if(is_activated()) { + stop(); + } + return root_region_.populate_performance_tree(); + } + +private: + void activate() { activated_ = true; } + void disactivate() { activated_ = false; } + + bool activated_ = false; + region_type root_region_; + region_type* current_region_ = &root_region_; + double self_time_ = 0.; +}; + +} // namespace util +} // namespace mc +} // namespace nest diff --git a/src/threading/serial.hpp b/src/threading/serial.hpp new file mode 100644 index 00000000..ebdede84 --- /dev/null +++ b/src/threading/serial.hpp @@ -0,0 +1,87 @@ +#pragma once + +#if !defined(WITH_SERIAL) + #error this header can only be loaded if WITH_SERIAL is set +#endif + +#include <array> +#include <chrono> +#include <string> + +namespace nest { +namespace mc { +namespace threading { + +/////////////////////////////////////////////////////////////////////// +// types +/////////////////////////////////////////////////////////////////////// +template <typename T> +class enumerable_thread_specific { + std::array<T, 1> data; + + public : + + T& local() { + return data[0]; + } + const T& local() const { + return data[0]; + } + + auto begin() -> decltype(data.begin()) + { + return data.begin(); + } + auto end() -> decltype(data.end()) + { + return data.end(); + } + auto cbegin() -> decltype(data.cbegin()) + { + return data.cbegin(); + } + auto cend() -> decltype(data.cend()) + { + return data.cend(); + } +}; + + +/////////////////////////////////////////////////////////////////////// +// algorithms +/////////////////////////////////////////////////////////////////////// +struct parallel_for { + template <typename F> + static void apply(int left, int right, F f) { + for(int i=left; i<right; ++i) { + f(i); + } + } +}; + +static inline +std::string description() { + return "serial"; +} + +struct timer { + using time_point = std::chrono::time_point<std::chrono::system_clock>; + + static + inline time_point tic() + { + return std::chrono::system_clock::now(); + } + + static + inline double toc(time_point t) + { + return std::chrono::duration<double>(tic() - t).count(); + } +}; + + +} // threading +} // mc +} // nest + diff --git a/src/threading/tbb.hpp b/src/threading/tbb.hpp new file mode 100644 index 00000000..56c09913 --- /dev/null +++ b/src/threading/tbb.hpp @@ -0,0 +1,51 @@ +#pragma once + +#if !defined(WITH_TBB) + #error this header can only be loaded if WITH_TBB is set +#endif + +#include <string> + +#include <tbb/tbb.h> +#include <tbb/compat/thread> +#include <tbb/enumerable_thread_specific.h> + +namespace nest { +namespace mc { +namespace threading { + +template <typename T> +class enumerable_thread_specific; + +struct parallel_for { + template <typename F> + static void apply(int left, int right, F f) { + tbb::parallel_for(left, right, f); + } +}; + +static +std::string description() { + return "TBB"; +} + +struct timer { + using time_point = tbb::tick_count; + + static + inline time_point tic() + { + return tbb::tick_count::now(); + } + + static + inline double toc(time_point t) + { + return (tic() - t).seconds(); + } +}; + +} // threading +} // mc +} // nest + diff --git a/src/threading/threading.hpp b/src/threading/threading.hpp new file mode 100644 index 00000000..76431d7d --- /dev/null +++ b/src/threading/threading.hpp @@ -0,0 +1,9 @@ +#pragma once + +#if defined(WITH_TBB) + #include "tbb.hpp" +#else + #define WITH_SERIAL + #include "serial.hpp" +#endif + diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 33e44398..b859950e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -18,12 +18,14 @@ set(TEST_SOURCES test_compartments.cpp test_event_queue.cpp test_fvm.cpp + test_cell_group.cpp test_matrix.cpp test_mechanisms.cpp test_optional.cpp test_parameters.cpp test_point.cpp test_segment.cpp + test_spikes.cpp test_stimulus.cpp test_swcio.cpp test_synapses.cpp @@ -38,7 +40,7 @@ set(VALIDATION_SOURCES # unit tests validate_ball_and_stick.cpp validate_soma.cpp - validate_synapses.cpp + #validate_synapses.cpp # unit test driver validate.cpp diff --git a/tests/test_cell_group.cpp b/tests/test_cell_group.cpp new file mode 100644 index 00000000..d9787a19 --- /dev/null +++ b/tests/test_cell_group.cpp @@ -0,0 +1,48 @@ +#include <limits> + +#include "gtest.h" + +#include <fvm_cell.hpp> +#include <cell_group.hpp> + +nest::mc::cell make_cell() { + using namespace nest::mc; + + // setup global state for the mechanisms + mechanisms::setup_mechanism_helpers(); + + nest::mc::cell cell; + + // Soma with diameter 12.6157 um and HH channel + auto soma = cell.add_soma(12.6157/2.0); + soma->add_mechanism(hh_parameters()); + + // add dendrite of length 200 um and diameter 1 um with passive channel + auto dendrite = cell.add_cable(0, segmentKind::dendrite, 0.5, 0.5, 200); + dendrite->add_mechanism(pas_parameters()); + dendrite->set_compartments(101); + + dendrite->mechanism("membrane").set("r_L", 100); + + // add stimulus + cell.add_stimulus({1,1}, {5., 80., 0.3}); + + cell.add_detector({0,0}, 0); + + return cell; +} + +TEST(cell_group, test) +{ + using namespace nest::mc; + + using cell_type = cell_group<fvm::fvm_cell<double, int>>; + + auto cell = cell_type{make_cell()}; + + cell.advance(50, 0.01); + + // a bit lame... + EXPECT_EQ(cell.spikes().size(), 4u); +} + diff --git a/tests/test_spikes.cpp b/tests/test_spikes.cpp new file mode 100644 index 00000000..b16f89cd --- /dev/null +++ b/tests/test_spikes.cpp @@ -0,0 +1,70 @@ +#include "gtest.h" + +#include <communication/spike.hpp> +#include <communication/spike_source.hpp> + +struct cell_proxy { + double voltage(nest::mc::segment_location loc) const { + return v; + } + + double v = -65.; +}; + +TEST(spikes, spike_detector) +{ + using namespace nest::mc; + using detector_type = spike_detector<cell_proxy>; + cell_proxy proxy; + float threshold = 10.f; + float t = 0.f; + float dt = 1.f; + auto loc = segment_location(1, 0.1); + + auto detector = detector_type(proxy, loc, threshold, t); + + EXPECT_FALSE(detector.is_spiking()); + EXPECT_EQ(loc, detector.location()); + EXPECT_EQ(proxy.v, detector.v()); + EXPECT_EQ(t, detector.t()); + + { + t += dt; + proxy.v = 0; + auto spike = detector.test(proxy, t); + EXPECT_FALSE(spike); + + EXPECT_FALSE(detector.is_spiking()); + EXPECT_EQ(loc, detector.location()); + EXPECT_EQ(proxy.v, detector.v()); + EXPECT_EQ(t, detector.t()); + } + + { + t += dt; + proxy.v = 20; + auto spike = detector.test(proxy, t); + + EXPECT_TRUE(spike); + EXPECT_EQ(spike.get(), 1.5); + + EXPECT_TRUE(detector.is_spiking()); + EXPECT_EQ(loc, detector.location()); + EXPECT_EQ(proxy.v, detector.v()); + EXPECT_EQ(t, detector.t()); + } + + { + t += dt; + proxy.v = 0; + auto spike = detector.test(proxy, t); + + EXPECT_FALSE(spike); + + EXPECT_FALSE(detector.is_spiking()); + EXPECT_EQ(loc, detector.location()); + EXPECT_EQ(proxy.v, detector.v()); + EXPECT_EQ(t, detector.t()); + } +} + -- GitLab