diff --git a/.gitignore b/.gitignore index a18a8b7d7af992216fe6d6bd229358fe0aa88b95..e9d54cb796031fdcbba7935bec51511be3f6de6b 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,5 @@ external/modparser-update external/tmp mechanisms/*.hpp +# build path +build diff --git a/.ycm_extra_conf.py b/.ycm_extra_conf.py index c356012c50dfe37e3a849f16809e390433511306..9033c063eb811ac190125cd94e6babcfa5518612 100644 --- a/.ycm_extra_conf.py +++ b/.ycm_extra_conf.py @@ -36,6 +36,7 @@ import ycm_core # CHANGE THIS LIST OF FLAGS. YES, THIS IS THE DROID YOU HAVE BEEN LOOKING FOR. flags = [ '-DNDEBUG', + '-DWITH_TBB', '-std=c++11', '-x', 'c++', @@ -47,6 +48,8 @@ flags = [ 'include', '-I', 'external', + '-I', + 'miniapp', # '-isystem', # '/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.10.sdk/usr/include/c++/4.2.1', # '-I', diff --git a/CMakeLists.txt b/CMakeLists.txt index 706a1baab35ab43d9e4f659aa8bb39f666e3a142..a432f4e8d3cadf2899f90b13ff687018e4fecd0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,37 @@ if(WITH_ASSERTIONS) add_definitions("-DWITH_ASSERTIONS") endif() +# TBB support +set(WITH_TBB OFF CACHE BOOL "use TBB for on-node threading" ) +if(WITH_TBB) + find_package(TBB REQUIRED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TBB_DEFINITIONS}") + add_definitions(-DWITH_TBB) +endif() + +# MPI support +set(WITH_MPI OFF CACHE BOOL "use MPI for distrubuted parallelism") +if(WITH_MPI) + find_package(MPI REQUIRED) + include_directories(SYSTEM ${MPI_C_INCLUDE_PATH}) + add_definitions(-DWITH_MPI) + # unfortunate workaround for C++ detection in system mpi.h + add_definitions(-DMPICH_SKIP_MPICXX=1 -DOMPI_SKIP_MPICXX=1) + set_property(DIRECTORY APPEND_STRING PROPERTY COMPILE_OPTIONS "${MPI_C_COMPILE_FLAGS}") +endif() + +# Profiler support +set(WITH_PROFILING OFF CACHE BOOL "use built in profiling of miniapp" ) +if(WITH_PROFILING) + add_definitions(-DWITH_PROFILING) +endif() + +# Cray systems +set(SYSTEM_CRAY OFF CACHE BOOL "add flags for compilation on Cray systems") +if(SYSTEM_CRAY) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -dynamic") +endif() + # targets for extermal dependencies include(ExternalProject) externalproject_add(modparser @@ -42,9 +73,15 @@ externalproject_add(modparser include_directories(${CMAKE_SOURCE_DIR}/external) include_directories(${CMAKE_SOURCE_DIR}/include) include_directories(${CMAKE_SOURCE_DIR}/src) +include_directories(${CMAKE_SOURCE_DIR}/miniapp) include_directories(${CMAKE_SOURCE_DIR}) +if( "${WITH_TBB}" STREQUAL "ON" ) + include_directories(${TBB_INCLUDE_DIRS}) +endif() + add_subdirectory(mechanisms) add_subdirectory(src) add_subdirectory(tests) +add_subdirectory(miniapp) diff --git a/README.md b/README.md index 0cb0c0d8dd27b41cf9358aebb5761cf221c72bef..dc2e7d720fa5452c15f2db6b2e7ad6001928cb85 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,9 @@ git submodule init git submodule update # setup environment -module load gcc +# on a desktop system this might not be required +# on a cluster this could be required +module load gcc module load cmake export CC=`which gcc` export CXX=`which g++` @@ -25,3 +27,53 @@ make -j cd tests ./test.exe ``` + +## MPI + +Set the `WITH_MPI` option either via the ccmake interface, or via the command line as shown below. +For the time being our CMake configuration does not try to detect MPI. +Instead, it relies on the user specifying an MPI wrapper for the compiler by setting the `CXX` and `CC` environment variables. + +``` +export CXX=mpicxx +export CC=mpicc +cmake <path to CMakeLists.txt> -DWITH_MPI=ON + +``` + +## TBB + +Support for multi-threading requires Intel Threading Building Blocks (TBB). +When TBB is installed, it comes with some scripts that can be run to set up the user environment. +The scripts set the `TBB_ROOT` environment variable, which is used by the CMake configuration to find TBB. + + +``` +source <path to TBB installation>/tbbvars.sh +cmake <path to CMakeLists.txt> -DWITH_TBB=ON +``` + +### TBB on Cray systems + +To compile with TBB on Cray systems, load the intel module, which will automatically configure the environment. + +``` +# load the gnu environment for compiling the application +module load PrgEnv-gnu +# gcc 5.x does not work with the version of TBB installed on Cray +# requires at least version 4.4 of TBB +module swap gcc/4.9.3 +# load the intel programming module +# on Cray systems this automatically sets `TBB_ROOT` environment variable +module load intel +module load cmake +export CXX=`which CC` +export CC=`which cc` + +# multithreading only +cmake <path to CMakeLists.txt> -DWITH_TBB=ON -DSYSTEM_CRAY=ON + +# multithreading and MPI +cmake <path to CMakeLists.txt> -DWITH_TBB=ON -DWITH_MPI=ON -DSYSTEM_CRAY=ON + +``` diff --git a/cmake/FindTBB.cmake b/cmake/FindTBB.cmake new file mode 100644 index 0000000000000000000000000000000000000000..7cb10d64935022bf78e4b5d8ffcac3328ebe4ac5 --- /dev/null +++ b/cmake/FindTBB.cmake @@ -0,0 +1,249 @@ +#----------------------------------------------------------------------------- +# from github justusc/FindTBB +#----------------------------------------------------------------------------- + +# The MIT License (MIT) +# +# Copyright (c) 2015 Justus Calvin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# +# FindTBB +# ------- +# +# Find TBB include directories and libraries. +# +# Usage: +# +# find_package(TBB [major[.minor]] [EXACT] +# [QUIET] [REQUIRED] +# [[COMPONENTS] [components...]] +# [OPTIONAL_COMPONENTS components...]) +# +# where the allowed components are tbbmalloc and tbb_preview. Users may modify +# the behavior of this module with the following variables: +# +# * TBB_ROOT_DIR - The base directory the of TBB installation. +# * TBB_INCLUDE_DIR - The directory that contains the TBB headers files. +# * TBB_LIBRARY - The directory that contains the TBB library files. +# * TBB_<library>_LIBRARY - The path of the TBB the corresponding TBB library. +# These libraries, if specified, override the +# corresponding library search results, where <library> +# may be tbb, tbb_debug, tbbmalloc, tbbmalloc_debug, +# tbb_preview, or tbb_preview_debug. +# * TBB_USE_DEBUG_BUILD - The debug version of tbb libraries, if present, will +# be used instead of the release version. +# +# Users may modify the behavior of this module with the following environment +# variables: +# +# * TBB_INSTALL_DIR +# * TBBROOT +# * LIBRARY_PATH +# +# This module will set the following variables: +# +# * TBB_FOUND - Set to false, or undefined, if we haven’t found, or +# don’t want to use TBB. +# * TBB_<component>_FOUND - If False, optional <component> part of TBB sytem is +# not available. +# * TBB_VERSION - The full version string +# * TBB_VERSION_MAJOR - The major version +# * TBB_VERSION_MINOR - The minor version +# * TBB_INTERFACE_VERSION - The interface version number defined in +# tbb/tbb_stddef.h. +# * TBB_<library>_LIBRARY_RELEASE - The path of the TBB release version of +# <library>, where <library> may be tbb, tbb_debug, +# tbbmalloc, tbbmalloc_debug, tbb_preview, or +# tbb_preview_debug. +# * TBB_<library>_LIBRARY_DEGUG - The path of the TBB release version of +# <library>, where <library> may be tbb, tbb_debug, +# tbbmalloc, tbbmalloc_debug, tbb_preview, or +# tbb_preview_debug. +# +# The following varibles should be used to build and link with TBB: +# +# * TBB_INCLUDE_DIRS - The include directory for TBB. +# * TBB_LIBRARIES - The libraries to link against to use TBB. +# * TBB_DEFINITIONS - Definitions to use when compiling code that uses TBB. + +include(FindPackageHandleStandardArgs) + +if(NOT TBB_FOUND) + + ################################## + # Check the build type + ################################## + + if(NOT DEFINED TBB_USE_DEBUG_BUILD) + if(CMAKE_BUILD_TYPE MATCHES "[Debug|DEBUG|debug|RelWithDebInfo|RELWITHDEBINFO|relwithdebinfo]") + set(TBB_USE_DEBUG_BUILD TRUE) + else() + set(TBB_USE_DEBUG_BUILD FALSE) + endif() + endif() + + ################################## + # Set the TBB search directories + ################################## + + # Define search paths based on user input and environment variables + set(TBB_SEARCH_DIR ${TBB_ROOT_DIR} $ENV{TBB_INSTALL_DIR} $ENV{TBBROOT}) + + # Define the search directories based on the current platform + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + set(TBB_DEFAULT_SEARCH_DIR "C:/Program Files/Intel/TBB" + "C:/Program Files (x86)/Intel/TBB") + + # Set the target architecture + if(CMAKE_SIZEOF_VOID_P EQUAL 8) + set(TBB_ARCHITECTURE "intel64") + else() + set(TBB_ARCHITECTURE "ia32") + endif() + + # Set the TBB search library path search suffix based on the version of VC + if(WINDOWS_STORE) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11_ui") + elseif(MSVC14) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc14") + elseif(MSVC12) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc12") + elseif(MSVC11) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11") + elseif(MSVC10) + set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc10") + endif() + + # Add the library path search suffix for the VC independent version of TBB + list(APPEND TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc_mt") + + elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + # OS X + set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb") + + # TODO: Check to see which C++ library is being used by the compiler. + if(NOT ${CMAKE_SYSTEM_VERSION} VERSION_LESS 13.0) + # The default C++ library on OS X 10.9 and later is libc++ + set(TBB_LIB_PATH_SUFFIX "lib/libc++") + else() + set(TBB_LIB_PATH_SUFFIX "lib") + endif() + elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") + # Linux + set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb") + + # TODO: Check compiler version to see the suffix should be <arch>/gcc4.1 or + # <arch>/gcc4.1. For now, assume that the compiler is more recent than + # gcc 4.4.x or later. + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + set(TBB_LIB_PATH_SUFFIX "lib/intel64/gcc4.4") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^i.86$") + set(TBB_LIB_PATH_SUFFIX "lib/ia32/gcc4.4") + endif() + endif() + + ################################## + # Find the TBB include dir + ################################## + + find_path(TBB_INCLUDE_DIRS tbb/tbb.h + HINTS ${TBB_INCLUDE_DIR} ${TBB_SEARCH_DIR} + PATHS ${TBB_DEFAULT_SEARCH_DIR} + PATH_SUFFIXES include) + + ################################## + # Find TBB components + ################################## + + # Find each component + foreach(_comp tbb_preview tbbmalloc tbb) + # Search for the libraries + find_library(TBB_${_comp}_LIBRARY_RELEASE ${_comp} + HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR} + PATHS ${TBB_DEFAULT_SEARCH_DIR} + PATH_SUFFIXES "${TBB_LIB_PATH_SUFFIX}") + + find_library(TBB_${_comp}_LIBRARY_DEBUG ${_comp}_debug + HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR} + PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH + PATH_SUFFIXES "${TBB_LIB_PATH_SUFFIX}") + + + # Set the library to be used for the component + if(NOT TBB_${_comp}_LIBRARY) + if(TBB_USE_DEBUG_BUILD AND TBB_${_comp}_LIBRARY_DEBUG) + set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_DEBUG}") + elseif(TBB_${_comp}_LIBRARY_RELEASE) + set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_RELEASE}") + elseif(TBB_${_comp}_LIBRARY_DEBUG) + set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_DEBUG}") + endif() + endif() + + # Set the TBB library list and component found variables + if(TBB_${_comp}_LIBRARY) + list(APPEND TBB_LIBRARIES "${TBB_${_comp}_LIBRARY}") + set(TBB_${_comp}_FOUND TRUE) + else() + set(TBB_${_comp}_FOUND FALSE) + endif() + + mark_as_advanced(TBB_${_comp}_LIBRARY_RELEASE) + mark_as_advanced(TBB_${_comp}_LIBRARY_DEBUG) + mark_as_advanced(TBB_${_comp}_LIBRARY) + + endforeach() + + ################################## + # Set compile flags + ################################## + + if(TBB_tbb_LIBRARY MATCHES "debug") + set(TBB_DEFINITIONS "-DTBB_USE_DEBUG=1") + endif() + + ################################## + # Set version strings + ################################## + + if(TBB_INCLUDE_DIRS) + file(READ "${TBB_INCLUDE_DIRS}/tbb/tbb_stddef.h" _tbb_version_file) + string(REGEX REPLACE ".*#define TBB_VERSION_MAJOR ([0-9]+).*" "\\1" + TBB_VERSION_MAJOR "${_tbb_version_file}") + string(REGEX REPLACE ".*#define TBB_VERSION_MINOR ([0-9]+).*" "\\1" + TBB_VERSION_MINOR "${_tbb_version_file}") + string(REGEX REPLACE ".*#define TBB_INTERFACE_VERSION ([0-9]+).*" "\\1" + TBB_INTERFACE_VERSION "${_tbb_version_file}") + set(TBB_VERSION "${TBB_VERSION_MAJOR}.${TBB_VERSION_MINOR}") + endif() + + find_package_handle_standard_args(TBB + REQUIRED_VARS TBB_INCLUDE_DIRS TBB_LIBRARIES + HANDLE_COMPONENTS + VERSION_VAR TBB_VERSION) + + mark_as_advanced(TBB_INCLUDE_DIRS TBB_LIBRARIES) + + unset(TBB_ARCHITECTURE) + unset(TBB_LIB_PATH_SUFFIX) + unset(TBB_DEFAULT_SEARCH_DIR) + +endif() diff --git a/commit.msg b/commit.msg new file mode 100644 index 0000000000000000000000000000000000000000..3abb1935698a6efe63d5ea651d8c7d5c017541fe --- /dev/null +++ b/commit.msg @@ -0,0 +1,9 @@ +* change `probe_sort` enum to scoped enum + - renamed to `probeKind` + - refactored out of class + - updated coding guidelines wiki with enum rules +* refactor `std::pair` into structs with meaningfull name types in `cell.hpp` + - not just probes: stimulii and detectors too. +* add profiling region around sampling in `cell_group` +* change output data format for traces to json +* remove white space at end of lines (looking at you Sam) diff --git a/miniapp/CMakeLists.txt b/miniapp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a5b9561198875a6515b4e8a69f7685ea75c19803 --- /dev/null +++ b/miniapp/CMakeLists.txt @@ -0,0 +1,23 @@ +set(HEADERS +) +set(MINIAPP_SOURCES + # mpi.cpp + io.cpp + miniapp.cpp +) + +add_executable(miniapp.exe ${MINIAPP_SOURCES} ${HEADERS}) + +target_link_libraries(miniapp.exe LINK_PUBLIC cellalgo) +target_link_libraries(miniapp.exe LINK_PUBLIC ${TBB_LIBRARIES}) + +if(WITH_MPI) + target_link_libraries(miniapp.exe LINK_PUBLIC ${MPI_C_LIBRARIES}) + set_property(TARGET miniapp.exe APPEND_STRING PROPERTY LINK_FLAGS "${MPI_C_LINK_FLAGS}") +endif() + +set_target_properties(miniapp.exe + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/miniapp" +) + diff --git a/miniapp/io.cpp b/miniapp/io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..887848b549b7bad46574b92ac9fcde74faae5d5c --- /dev/null +++ b/miniapp/io.cpp @@ -0,0 +1,25 @@ +#include "io.hpp" + +namespace nest { +namespace mc { +namespace io { + +// read simulation options from json file with name fname +// for now this is just a placeholder +options read_options(std::string fname) { + // 10 cells, 1 synapses per cell, 10 compartments per segment + return {200, 1, 100}; +} + +std::ostream& operator<<(std::ostream& o, const options& opt) { + o << "simultion options:\n"; + o << " cells : " << opt.cells << "\n"; + o << " compartments/segment : " << opt.compartments_per_segment << "\n"; + o << " synapses/cell : " << opt.synapses_per_cell << "\n"; + + return o; +} + +} // namespace io +} // namespace mc +} // namespace nest diff --git a/miniapp/io.hpp b/miniapp/io.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e8a8762ceb34b7d7148dfb7eb831217ac633d3d3 --- /dev/null +++ b/miniapp/io.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include <json/src/json.hpp> + +namespace nest { +namespace mc { +namespace io { + +// holds the options for a simulation run +struct options { + int cells; + int synapses_per_cell; + int compartments_per_segment; +}; + +std::ostream& operator<<(std::ostream& o, const options& opt); + +options read_options(std::string fname); + +} // namespace io +} // namespace mc +} // namespace nest diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..74b551ff4f07fe795514caafde901de1721108dc --- /dev/null +++ b/miniapp/miniapp.cpp @@ -0,0 +1,425 @@ +#include <iostream> +#include <fstream> +#include <sstream> + +#include <cell.hpp> +#include <cell_group.hpp> +#include <fvm_cell.hpp> +#include <mechanism_interface.hpp> + +#include "io.hpp" +#include "threading/threading.hpp" +#include "profiling/profiler.hpp" +#include "communication/communicator.hpp" +#include "communication/global_policy.hpp" +#include "util/optional.hpp" + +using namespace nest; + +using real_type = double; +using index_type = int; +using id_type = uint32_t; +using numeric_cell = mc::fvm::fvm_cell<real_type, index_type>; +using cell_group = mc::cell_group<numeric_cell>; + +using global_policy = nest::mc::communication::global_policy; +using communicator_type = + mc::communication::communicator<global_policy>; + +using nest::mc::util::optional; + +struct model { + communicator_type communicator; + std::vector<cell_group> cell_groups; + + int num_groups() const { + return cell_groups.size(); + } + + void run(double tfinal, double dt) { + auto t = 0.; + auto delta = communicator.min_delay(); + while(t<tfinal) { + mc::threading::parallel_for::apply( + 0, num_groups(), + [&](int i) { + mc::util::profiler_enter("stepping","events"); + cell_groups[i].enqueue_events(communicator.queue(i)); + mc::util::profiler_leave(); + cell_groups[i].advance(t+delta, dt); + mc::util::profiler_enter("events"); + communicator.add_spikes(cell_groups[i].spikes()); + cell_groups[i].clear_spikes(); + mc::util::profiler_leave(2); + } + ); + + mc::util::profiler_enter("stepping", "exchange"); + communicator.exchange(); + mc::util::profiler_leave(2); + + t += delta; + } + } + + void init_communicator() { + mc::util::profiler_enter("setup", "communicator"); + + // calculate the source and synapse distribution serially + std::vector<id_type> target_counts(num_groups()); + std::vector<id_type> source_counts(num_groups()); + for (auto i=0; i<num_groups(); ++i) { + target_counts[i] = cell_groups[i].cell().synapses()->size(); + source_counts[i] = cell_groups[i].spike_sources().size(); + } + + target_map = mc::algorithms::make_index(target_counts); + source_map = mc::algorithms::make_index(source_counts); + + // create connections + communicator = communicator_type(num_groups(), target_counts); + + mc::util::profiler_leave(2); + } + + void update_gids() { + mc::util::profiler_enter("setup", "globalize"); + auto com_policy = communicator.communication_policy(); + auto global_source_map = com_policy.make_map(source_map.back()); + auto domain_idx = communicator.domain_id(); + for (auto i=0; i<num_groups(); ++i) { + cell_groups[i].set_source_gids(source_map[i]+global_source_map[domain_idx]); + cell_groups[i].set_target_gids(target_map[i]+communicator.target_gid_from_group_lid(0)); + } + mc::util::profiler_leave(2); + } + + // TODO : only stored here because init_communicator() and update_gids() are split + std::vector<id_type> source_map; + std::vector<id_type> target_map; + + // traces from probes + struct trace_data { + struct sample_type { + float time; + double value; + }; + std::string name; + index_type id; + std::vector<sample_type> samples; + }; + + // different traces may be written to by different threads; + // during simulation, each trace_sampler will be responsible for its + // corresponding element in the traces vector. + + std::vector<trace_data> traces; + + // make a sampler that records to traces + struct simple_sampler_functor { + std::vector<trace_data> &traces_; + size_t trace_index_ = 0; + float requested_sample_time_ = 0; + float dt_ = 0; + + simple_sampler_functor(std::vector<trace_data> &traces, size_t index, float dt) : + traces_(traces), trace_index_(index), dt_(dt) + {} + + optional<float> operator()(float t, double v) { + traces_[trace_index_].samples.push_back({t,v}); + return requested_sample_time_ += dt_; + } + }; + + mc::sampler make_simple_sampler( + index_type probe_gid, const std::string name, index_type id, float dt) + { + traces.push_back(trace_data{name, id}); + return {probe_gid, simple_sampler_functor(traces, traces.size()-1, dt)}; + } + + void reset_traces() { + // do not call during simulation: thread-unsafe access to traces. + traces.clear(); + } + + void dump_traces() { + // do not call during simulation: thread-unsafe access to traces. + for (const auto& trace: traces) { + auto path = "trace_" + std::to_string(trace.id) + + "_" + trace.name + ".json"; + + nlohmann::json json; + json["name"] = trace.name; + for (const auto& sample: trace.samples) { + json["time"].push_back(sample.time); + json["value"].push_back(sample.value); + } + std::ofstream file(path); + file << std::setw(1) << json << std::endl; + } + } +}; + +// define some global model parameters +namespace parameters { +namespace synapses { + // synapse delay + constexpr double delay = 5.0; // ms + + // connection weight + constexpr double weight = 0.005; // uS +} +} + +/////////////////////////////////////// +// prototypes +/////////////////////////////////////// + +/// make a single abstract cell +mc::cell make_cell(int compartments_per_segment, int num_synapses); + +/// do basic setup (initialize global state, print banner, etc) +void setup(); + +/// helper function for initializing cells +cell_group make_lowered_cell(int cell_index, const mc::cell& c); + +/// models +void ring_model(nest::mc::io::options& opt, model& m); +void all_to_all_model(nest::mc::io::options& opt, model& m); + + +/////////////////////////////////////// +// main +/////////////////////////////////////// +int main(int argc, char** argv) { + nest::mc::communication::global_policy_guard global_guard(argc, argv); + + setup(); + + // read parameters + mc::io::options opt; + try { + opt = mc::io::read_options(""); + if (!global_policy::id()) { + std::cout << opt << "\n"; + } + } + catch (std::exception& e) { + std::cerr << e.what() << std::endl; + exit(1); + } + + model m; + all_to_all_model(opt, m); + + // + // time stepping + // + auto tfinal = 20.; + auto dt = 0.01; + + auto id = m.communicator.domain_id(); + + if (!id) { + m.communicator.add_spike({0, 5}); + } + + m.run(tfinal, dt); + + mc::util::profiler_output(0.001); + + if (!id) { + std::cout << "there were " << m.communicator.num_spikes() << " spikes\n"; + } + m.dump_traces(); + +#ifdef SPLAT + if (!global_policy::id()) { + //for (auto i=0u; i<m.cell_groups.size(); ++i) { + m.cell_groups[0].splat("cell0.txt"); + m.cell_groups[1].splat("cell1.txt"); + m.cell_groups[2].splat("cell2.txt"); + //} + } +#endif +} + +/////////////////////////////////////// +// models +/////////////////////////////////////// + +void ring_model(nest::mc::io::options& opt, model& m) { + // + // make cells + // + + // make a basic cell + auto basic_cell = make_cell(opt.compartments_per_segment, 1); + + // make a vector for storing all of the cells + m.cell_groups = std::vector<cell_group>(opt.cells); + + // initialize the cells in parallel + mc::threading::parallel_for::apply( + 0, opt.cells, + [&](int i) { + // initialize cell + mc::util::profiler_enter("setup"); + mc::util::profiler_enter("make cell"); + m.cell_groups[i] = make_lowered_cell(i, basic_cell); + mc::util::profiler_leave(); + mc::util::profiler_leave(); + } + ); + + // + // network creation + // + m.init_communicator(); + + for (auto i=0u; i<(id_type)opt.cells; ++i) { + m.communicator.add_connection({ + i, (i+1)%opt.cells, + parameters::synapses::weight, parameters::synapses::delay + }); + } + + m.update_gids(); +} + +void all_to_all_model(nest::mc::io::options& opt, model& m) { + // + // make cells + // + + // make a basic cell + auto basic_cell = make_cell(opt.compartments_per_segment, opt.cells-1); + + // make a vector for storing all of the cells + id_type ncell_global = opt.cells; + id_type ncell_local = ncell_global / m.communicator.num_domains(); + int remainder = ncell_global - (ncell_local*m.communicator.num_domains()); + if (m.communicator.domain_id()<remainder) { + ncell_local++; + } + + m.cell_groups = std::vector<cell_group>(ncell_local); + + // initialize the cells in parallel + mc::threading::parallel_for::apply( + 0, ncell_local, + [&](int i) { + mc::util::profiler_enter("setup", "cells"); + m.cell_groups[i] = make_lowered_cell(i, basic_cell); + mc::util::profiler_leave(2); + } + ); + + // + // network creation + // + m.init_communicator(); + + // monitor soma and dendrite on a few cells + float sample_dt = 0.1; + index_type monitor_group_gids[] = { 0, 1, 2 }; + for (auto gid : monitor_group_gids) { + if (!m.communicator.is_local_group(gid)) { + continue; + } + + auto lid = m.communicator.group_lid(gid); + auto probe_soma = m.cell_groups[lid].probe_gid_range().first; + auto probe_dend = probe_soma+1; + + m.cell_groups[lid].add_sampler(m.make_simple_sampler(probe_soma, "vsoma", gid, sample_dt)); + m.cell_groups[lid].add_sampler(m.make_simple_sampler(probe_dend, "vdend", gid, sample_dt)); + } + + mc::util::profiler_enter("setup", "connections"); + // lid is local cell/group id + for (auto lid=0u; lid<ncell_local; ++lid) { + auto target = m.communicator.target_gid_from_group_lid(lid); + auto gid = m.communicator.group_gid_from_group_lid(lid); + // tid is global cell/group id + for (auto tid=0u; tid<ncell_global; ++tid) { + if (gid!=tid) { + m.communicator.add_connection({ + tid, target++, + parameters::synapses::weight, parameters::synapses::delay + }); + } + } + } + + m.communicator.construct(); + mc::util::profiler_leave(2); + + m.update_gids(); +} + +/////////////////////////////////////// +// function definitions +/////////////////////////////////////// + +void setup() { + // print banner + if (!global_policy::id()) { + std::cout << "====================\n"; + std::cout << " starting miniapp\n"; + std::cout << " - " << mc::threading::description() << " threading support\n"; + std::cout << " - communication policy: " << global_policy::name() << "\n"; + std::cout << "====================\n"; + } + + // setup global state for the mechanisms + mc::mechanisms::setup_mechanism_helpers(); +} + +// make a high level cell description for use in simulation +mc::cell make_cell(int compartments_per_segment, int num_synapses) { + nest::mc::cell cell; + + // Soma with diameter 12.6157 um and HH channel + auto soma = cell.add_soma(12.6157/2.0); + soma->add_mechanism(mc::hh_parameters()); + + // add dendrite of length 200 um and diameter 1 um with passive channel + std::vector<mc::cable_segment*> dendrites; + dendrites.push_back(cell.add_cable(0, mc::segmentKind::dendrite, 0.5, 0.5, 200)); + //dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100)); + //dendrites.push_back(cell.add_cable(1, mc::segmentKind::dendrite, 0.5, 0.25, 100)); + + for (auto d : dendrites) { + d->add_mechanism(mc::pas_parameters()); + d->set_compartments(compartments_per_segment); + d->mechanism("membrane").set("r_L", 100); + } + + // add stimulus + //cell.add_stimulus({1,1}, {5., 80., 0.3}); + + cell.add_detector({0,0}, 30); + + for (auto i=0; i<num_synapses; ++i) { + cell.add_synapse({1, 0.5}); + } + + // add probes: + auto probe_soma = cell.add_probe({0, 0}, mc::probeKind::membrane_voltage); + auto probe_dendrite = cell.add_probe({1, 0.5}, mc::probeKind::membrane_voltage); + + EXPECTS(probe_soma==0); + EXPECTS(probe_dendrite==1); + (void)probe_soma, (void)probe_dendrite; + + return cell; +} + +cell_group make_lowered_cell(int cell_index, const mc::cell& c) { + return cell_group(c); +} + diff --git a/miniapp/plot.py b/miniapp/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..ced8f4f0bfc4d51d061fd47dde86766e5a5f8438 --- /dev/null +++ b/miniapp/plot.py @@ -0,0 +1,22 @@ +from matplotlib import pyplot +import numpy as np + +ncol = 3 + +raw = np.fromfile("cell0.txt", sep=" ") +n = raw.size/ncol +data = raw.reshape(n,ncol) + +t = data[:, 0] +soma = data[:, 1] +dend = data[:, 2] + +pyplot.plot(t, soma, 'k') +pyplot.plot(t, dend, 'r') + +pyplot.xlabel('time (ms)') +pyplot.ylabel('mV') +pyplot.xlim([t[0], t[n-1]]) +pyplot.grid() +pyplot.show() + diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b6e63096ac86d3759d14373358dac66264c83f99..7f4a588d933a5ea74342c70dc8b6ac5df95aca8b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,10 +5,15 @@ set(BASE_SOURCES cell.cpp mechanism_interface.cpp parameter_list.cpp + profiling/profiler.cpp swcio.cpp util/debug.cpp ) +if(${WITH_MPI}) + set(BASE_SOURCES ${BASE_SOURCES} communication/mpi.cpp) +endif() + add_library(cellalgo ${BASE_SOURCES} ${HEADERS}) add_dependencies(cellalgo build_all_mods) diff --git a/src/algorithms.hpp b/src/algorithms.hpp index f3cd9882ec14eb4f39baba1035a297ada14d485d..a1f6bad69f49fd9c40bc8cdea1cf2a400f8d7285 100644 --- a/src/algorithms.hpp +++ b/src/algorithms.hpp @@ -209,7 +209,7 @@ std::vector<typename C::value_type> expand_branches(const C& branch_index) std::vector<typename C::value_type> expanded(branch_index.back()); for (std::size_t i = 0; i < branch_index.size()-1; ++i) { - for (std::size_t j = branch_index[i]; j < branch_index[i+1]; ++j) { + for (auto j = branch_index[i]; j < branch_index[i+1]; ++j) { expanded[j] = i; } } diff --git a/src/cell.cpp b/src/cell.cpp index da8febafac7957bbae99edd42e43093bebbb7e0e..1a70ca3e3034cf25e9238352cb1c5dfdd97a8c49 100644 --- a/src/cell.cpp +++ b/src/cell.cpp @@ -178,7 +178,7 @@ compartment_model cell::model() const } -void cell::add_stimulus( segment_location loc, i_clamp stim) +void cell::add_stimulus(segment_location loc, i_clamp stim) { if(!(loc.segment<num_segments())) { throw std::out_of_range( @@ -191,6 +191,11 @@ void cell::add_stimulus( segment_location loc, i_clamp stim) stimulii_.push_back({loc, std::move(stim)}); } +void cell::add_detector(segment_location loc, double threshold) +{ + spike_detectors_.push_back({loc, threshold}); +} + std::vector<int> const& cell::segment_parents() const { return parents_; diff --git a/src/cell.hpp b/src/cell.hpp index a10e0a8d5ccc53993962c4ed483b9367ce5eda6e..744c9970e26af7f439f9cee7cfb2f88c96014624 100644 --- a/src/cell.hpp +++ b/src/cell.hpp @@ -27,6 +27,9 @@ struct segment_location { { EXPECTS(position>=0. && position<=1.); } + friend bool operator==(segment_location l, segment_location r) { + return l.segment==r.segment && l.position==r.position; + } int segment; double position; }; @@ -36,14 +39,31 @@ int find_compartment_index( compartment_model const& graph ); +enum class probeKind { + membrane_voltage, + membrane_current +}; + /// high-level abstract representation of a cell and its segments class cell { - public: +public: // types using index_type = int; using value_type = double; using point_type = point<value_type>; + struct probe_instance { + segment_location location; + probeKind kind; + }; + struct stimulus_instance { + segment_location location; + i_clamp clamp; + }; + struct detector_instance { + segment_location location; + double threshold; + }; // constructor cell(); @@ -100,24 +120,55 @@ class cell { compartment_model model() const; + ////////////////// + // stimulii + ////////////////// void add_stimulus(segment_location loc, i_clamp stim); - std::vector<std::pair<segment_location, i_clamp>>& + std::vector<stimulus_instance>& stimulii() { return stimulii_; } - const std::vector<std::pair<segment_location, i_clamp>>& + const std::vector<stimulus_instance>& stimulii() const { return stimulii_; } + ////////////////// + // synapses + ////////////////// void add_synapse(segment_location loc); const std::vector<segment_location>& synapses() const; + ////////////////// + // spike detectors + ////////////////// + void add_detector(segment_location loc, double threshold); + + std::vector<detector_instance>& + detectors() { + return spike_detectors_; + } + + const std::vector<detector_instance>& + detectors() const { + return spike_detectors_; + } + + ////////////////// + // probes + ////////////////// + index_type add_probe(segment_location loc, probeKind kind) { + probes_.push_back({loc, kind}); + return probes_.size()-1; + } - private: + const std::vector<probe_instance>& + probes() const { return probes_; } + +private: // storage for connections std::vector<index_type> parents_; @@ -126,10 +177,16 @@ class cell { std::vector<segment_ptr> segments_; // the stimulii - std::vector<std::pair<segment_location, i_clamp>> stimulii_; + std::vector<stimulus_instance> stimulii_; // the synapses std::vector<segment_location> synapses_; + + // the sensors + std::vector<detector_instance> spike_detectors_; + + // the probes + std::vector<probe_instance> probes_; }; // Checks that two cells have the same diff --git a/src/cell_group.hpp b/src/cell_group.hpp new file mode 100644 index 0000000000000000000000000000000000000000..70bfdeb13123c780744c5875ae125bf631bf6142 --- /dev/null +++ b/src/cell_group.hpp @@ -0,0 +1,189 @@ +#pragma once + +#include <cstdint> +#include <vector> + +#include <cell.hpp> +#include <event_queue.hpp> +#include <communication/spike.hpp> +#include <communication/spike_source.hpp> + +#include <profiling/profiler.hpp> + +namespace nest { +namespace mc { + +// samplers take a time and sample value, and return an optional time +// for the next desired sample. + +struct sampler { + using index_type = int; + using time_type = float; + using value_type = double; + + index_type probe_gid; // samplers are attached to probes + std::function<util::optional<time_type>(time_type, value_type)> sample; +}; + +template <typename Cell> +class cell_group { +public: + using index_type = uint32_t; + using cell_type = Cell; + using value_type = typename cell_type::value_type; + using size_type = typename cell_type::value_type; + using spike_detector_type = spike_detector<Cell>; + + struct spike_source_type { + index_type index; + spike_detector_type source; + }; + + cell_group() = default; + + cell_group(const cell& c) : + cell_{c} + { + cell_.voltage()(memory::all) = -65.; + cell_.initialize(); + + for (auto& d : c.detectors()) { + spike_sources_.push_back( { + 0u, spike_detector_type(cell_, d.location, d.threshold, 0.f) + }); + } + } + + void set_source_gids(index_type gid) { + for (auto& s : spike_sources_) { + s.index = gid++; + } + } + + void set_target_gids(index_type lid) { + first_target_gid_ = lid; + } + + index_type num_probes() const { + return cell_.num_probes(); + } + + void set_probe_gids(index_type gid) { + first_probe_gid_ = gid; + } + + std::pair<index_type, index_type> probe_gid_range() const { + return { first_probe_gid_, first_probe_gid_+cell_.num_probes() }; + } + + void advance(double tfinal, double dt) { + while (cell_.time()<tfinal) { + // take any pending samples + float cell_time = cell_.time(); + + nest::mc::util::profiler_enter("sampling"); + while (auto m = sample_events_.pop_if_before(cell_time)) { + auto& sampler = samplers_[m->sampler_index]; + EXPECTS((bool)sampler.sample); + + index_type probe_index = sampler.probe_gid-first_probe_gid_; + auto next = sampler.sample(cell_.time(), cell_.probe(probe_index)); + if (next) { + m->time = std::max(*next, cell_time); + sample_events_.push(*m); + } + } + nest::mc::util::profiler_leave(); + + // look for events in the next time step + auto tstep = std::min(tfinal, cell_.time()+dt); + auto next = events_.pop_if_before(tstep); + auto tnext = next ? next->time: tstep; + + // integrate cell state + cell_.advance(tnext - cell_.time()); + + nest::mc::util::profiler_enter("events"); + // check for new spikes + for (auto& s : spike_sources_) { + if (auto spike = s.source.test(cell_, cell_.time())) { + spikes_.push_back({s.index, spike.get()}); + } + } + + // apply events + if (next) { + cell_.apply_event(next.get()); + // apply events that are due within some epsilon of the current + // time step. This should be a parameter. e.g. with for variable + // order time stepping, use the minimum possible time step size. + while(auto e = events_.pop_if_before(cell_.time()+dt/10.)) { + cell_.apply_event(e.get()); + } + } + nest::mc::util::profiler_leave(); + } + + } + + template <typename R> + void enqueue_events(R events) { + for (auto e : events) { + e.target -= first_target_gid_; + events_.push(e); + } + } + + const std::vector<communication::spike<index_type>>& + spikes() const { + return spikes_; + } + + cell_type& cell() { return cell_; } + const cell_type& cell() const { return cell_; } + + const std::vector<spike_source_type>& + spike_sources() const { + return spike_sources_; + } + + void clear_spikes() { + spikes_.clear(); + } + + void add_sampler(const sampler& s, float start_time = 0) { + auto sampler_index = uint32_t(samplers_.size()); + samplers_.push_back(s); + sample_events_.push({sampler_index, start_time}); + } + +private: + + + /// the lowered cell state (e.g. FVM) of the cell + cell_type cell_; + + /// spike detectors attached to the cell + std::vector<spike_source_type> spike_sources_; + + //. spikes that are generated + std::vector<communication::spike<index_type>> spikes_; + + /// pending events to be delivered + event_queue<postsynaptic_spike_event> events_; + + /// pending samples to be taken + event_queue<sample_event> sample_events_; + + /// the global id of the first target (e.g. a synapse) in this group + index_type first_target_gid_; + + /// the global id of the first probe in this group + index_type first_probe_gid_; + + /// collection of samplers to be run against probes in this group + std::vector<sampler> samplers_; +}; + +} // namespace mc +} // namespace nest diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..34185ac0fbd795294211ced7190d8b84a3e70874 --- /dev/null +++ b/src/communication/communicator.hpp @@ -0,0 +1,273 @@ +#pragma once + +#include <algorithm> +#include <iostream> +#include <vector> +#include <random> + +#include <communication/spike.hpp> +#include <threading/threading.hpp> +#include <algorithms.hpp> +#include <event_queue.hpp> + +#include "connection.hpp" + +namespace nest { +namespace mc { +namespace communication { + +// When the communicator is constructed the number of target groups and targets +// is specified, along with a mapping between local cell id and local +// target id. +// +// The user can add connections to an existing communicator object, where +// each connection is between any global cell and any local target. +// +// Once all connections have been specified, the construct() method can be used +// to build the data structures required for efficient spike communication and +// event generation. +template <typename CommunicationPolicy> +class communicator { +public: + using id_type = uint32_t; + using communication_policy_type = CommunicationPolicy; + + using spike_type = spike<id_type>; + + communicator() = default; + + communicator(id_type n_groups, std::vector<id_type> target_counts) : + num_groups_local_(n_groups), + num_targets_local_(target_counts.size()) + { + target_map_ = nest::mc::algorithms::make_index(target_counts); + num_targets_local_ = target_map_.back(); + + // create an event queue for each target group + events_.resize(num_groups_local_); + + // make maps for converting lid to gid + target_gid_map_ = communication_policy_.make_map(num_targets_local_); + group_gid_map_ = communication_policy_.make_map(num_groups_local_); + + // transform the target ids from lid to gid + auto first_target = target_gid_map_[domain_id()]; + for (auto &id : target_map_) { + id += first_target; + } + } + + id_type target_gid_from_group_lid(id_type lid) const { + EXPECTS(lid<num_groups_local_); + return target_map_[lid]; + } + + id_type group_gid_from_group_lid(id_type lid) const { + EXPECTS(lid<num_groups_local_); + return group_gid_map_[domain_id()] + lid; + } + + void add_connection(connection con) { + EXPECTS(is_local_target(con.destination())); + connections_.push_back(con); + } + + bool is_local_target(id_type gid) { + return gid>=target_gid_map_[domain_id()] + && gid<target_gid_map_[domain_id()+1]; + } + + bool is_local_group(id_type gid) { + return gid>=group_gid_map_[domain_id()] + && gid<group_gid_map_[domain_id()+1]; + } + + /// return the global id of the first group in domain d + /// the groups in domain d are in the contiguous half open range + /// [domain_first_group(d), domain_first_group(d+1)) + id_type group_gid_first(int d) const { + return group_gid_map_[d]; + } + + id_type target_lid(id_type gid) { + EXPECTS(is_local_group(gid)); + + return gid - target_gid_map_[domain_id()]; + } + + id_type group_lid(id_type gid) { + EXPECTS(is_local_group(gid)); + + return gid - group_gid_map_[domain_id()]; + } + + // builds the optimized data structure + void construct() { + if (!std::is_sorted(connections_.begin(), connections_.end())) { + std::sort(connections_.begin(), connections_.end()); + } + } + + float min_delay() { + auto local_min = std::numeric_limits<float>::max(); + for (auto& con : connections_) { + local_min = std::min(local_min, con.delay()); + } + + return communication_policy_.min(local_min); + } + + // return the local group index of the group which hosts the target with + // global id gid + id_type local_group_from_global_target(id_type gid) { + // assert that gid is in range + EXPECTS(is_local_target(gid)); + + return + std::distance( + target_map_.begin(), + std::upper_bound(target_map_.begin(), target_map_.end(), gid) + ) - 1; + } + + void add_spike(spike_type s) { + thread_spikes().push_back(s); + } + + void add_spikes(const std::vector<spike_type>& s) { + auto& v = thread_spikes(); + v.insert(v.end(), s.begin(), s.end()); + } + + std::vector<spike_type>& thread_spikes() { + return thread_spikes_.local(); + } + + void exchange() { + + // global all-to-all to gather a local copy of the global spike list + // on each node + //profiler_.enter("global exchange"); + auto global_spikes = communication_policy_.gather_spikes(local_spikes()); + num_spikes_ += global_spikes.size(); + clear_thread_spike_buffers(); + //profiler_.leave(); + + for (auto& q : events_) { + q.clear(); + } + + //profiler_.enter("events"); + + //profiler_.enter("make events"); + // check all global spikes to see if they will generate local events + for (auto spike : global_spikes) { + // search for targets + auto targets = + std::equal_range( + connections_.begin(), connections_.end(), spike.source + ); + + // generate an event for each target + for (auto it=targets.first; it!=targets.second; ++it) { + auto gidx = local_group_from_global_target(it->destination()); + + events_[gidx].push_back(it->make_event(spike)); + } + } + + //profiler_.leave(); // make events + + //profiler_.leave(); // event generation + } + + uint64_t num_spikes() const + { + return num_spikes_; + } + + int domain_id() const { + return communication_policy_.id(); + } + + int num_domains() const { + return communication_policy_.size(); + } + + const std::vector<postsynaptic_spike_event>& queue(int i) const { + return events_[i]; + } + + const std::vector<connection>& connections() const { + return connections_; + } + + communication_policy_type communication_policy() const { + return communication_policy_; + } + + const std::vector<id_type>& local_target_map() const { + return target_map_; + } + + std::vector<spike_type> local_spikes() { + std::vector<spike_type> spikes; + for (auto& v : thread_spikes_) { + spikes.insert(spikes.end(), v.begin(), v.end()); + } + return spikes; + } + + void clear_thread_spike_buffers() { + for (auto& v : thread_spikes_) { + v.clear(); + } + } + +private: + + // + // both of these can be fixed with double buffering + // + // FIXME : race condition on the thread_spikes_ buffers when exchange() modifies/access them + // ... other threads will be pushing to them simultaneously + // FIXME : race condition on the group-specific event queues when exchange pushes to them + // ... other threads will be accessing them to update their event queues + + // thread private storage for accumulating spikes + using local_spike_store_type = + nest::mc::threading::enumerable_thread_specific<std::vector<spike_type>>; + local_spike_store_type thread_spikes_; + + std::vector<connection> connections_; + std::vector<std::vector<postsynaptic_spike_event>> events_; + + // local target group i has targets in the half open range + // [target_map_[i], target_map_[i+1]) + std::vector<id_type> target_map_; + + // for keeping track of how time is spent where + //util::Profiler profiler_; + + // the number of groups and targets handled by this communicator + id_type num_groups_local_; + id_type num_targets_local_; + + // index maps for the global distribution of groups and targets + + // communicator i has the groups in the half open range : + // [group_gid_map_[i], group_gid_map_[i+1]) + std::vector<id_type> group_gid_map_; + + // communicator i has the targets in the half open range : + // [target_gid_map_[i], target_gid_map_[i+1]) + std::vector<id_type> target_gid_map_; + + communication_policy_type communication_policy_; + + uint64_t num_spikes_ = 0u; +}; + +} // namespace communication +} // namespace mc +} // namespace nest diff --git a/src/communication/connection.hpp b/src/communication/connection.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee1acfd72dc21cffe01f28239bca29526366df31 --- /dev/null +++ b/src/communication/connection.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include <cstdint> + +#include <event_queue.hpp> +#include <communication/spike.hpp> + +namespace nest { +namespace mc { +namespace communication { + +class connection { +public: + using id_type = uint32_t; + connection(id_type src, id_type dest, float w, float d) : + source_(src), + destination_(dest), + weight_(w), + delay_(d) + {} + + float weight() const { return weight_; } + float delay() const { return delay_; } + + id_type source() const { return source_; } + id_type destination() const { return destination_; } + + postsynaptic_spike_event make_event(spike<id_type> s) { + return {destination_, s.time + delay_, weight_}; + } + +private: + + id_type source_; + id_type destination_; + float weight_; + float delay_; +}; + +// connections are sorted by source id +// these operators make for easy interopability with STL algorithms + +static inline +bool operator< (connection lhs, connection rhs) { + return lhs.source() < rhs.source(); +} + +static inline +bool operator< (connection lhs, connection::id_type rhs) { + return lhs.source() < rhs; +} + +static inline +bool operator< (connection::id_type lhs, connection rhs) { + return lhs < rhs.source(); +} + +} // namespace communication +} // namespace mc +} // namespace nest + +static inline +std::ostream& operator<<(std::ostream& o, nest::mc::communication::connection const& con) { + char buff[512]; + snprintf( + buff, sizeof(buff), "con [%10u -> %10u : weight %8.4f, delay %8.4f]", + con.source(), con.destination(), con.weight(), con.delay() + ); + return o << buff; +} diff --git a/src/communication/global_policy.hpp b/src/communication/global_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b771e8ed24d5d022b6a23cc7d77171ca477a7dbf --- /dev/null +++ b/src/communication/global_policy.hpp @@ -0,0 +1,42 @@ +#pragma once + +#ifdef WITH_MPI + #include "communication/mpi_global_policy.hpp" +#else + #include "communication/serial_global_policy.hpp" +#endif + +namespace nest { +namespace mc { +namespace communication { + +#ifdef WITH_MPI +using global_policy = nest::mc::communication::mpi_global_policy; +#else +using global_policy = nest::mc::communication::serial_global_policy; +#endif + +template <typename Policy> +struct policy_guard { + using policy_type = Policy; + + policy_guard(int argc, char**& argv) { + policy_type::setup(argc, argv); + } + + policy_guard() = delete; + policy_guard(policy_guard&&) = delete; + policy_guard(const policy_guard&) = delete; + policy_guard& operator=(policy_guard&&) = delete; + policy_guard& operator=(const policy_guard&) = delete; + + ~policy_guard() { + Policy::teardown(); + } +}; + +using global_policy_guard = policy_guard<global_policy>; + +} // namespace communication +} // namespace mc +} // namespace nest diff --git a/src/communication/mpi.cpp b/src/communication/mpi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0481ec15d17cb9a84fe88745c3bbe7533b25945e --- /dev/null +++ b/src/communication/mpi.cpp @@ -0,0 +1,59 @@ +#include <mpi.h> + +#include <communication/mpi.hpp> + +namespace nest { +namespace mc { +namespace mpi { + +// global state +namespace state { + int size = -1; + int rank = -1; +} // namespace state + +void init(int *argc, char ***argv) { + int provided; + + // initialize with thread serialized level of thread safety + MPI_Init_thread(argc, argv, MPI_THREAD_SERIALIZED, &provided); + assert(provided>=MPI_THREAD_SERIALIZED); + + MPI_Comm_rank(MPI_COMM_WORLD, &state::rank); + MPI_Comm_size(MPI_COMM_WORLD, &state::size); +} + +void finalize() { + MPI_Finalize(); +} + +bool is_root() { + return state::rank == 0; +} + +int rank() { + return state::rank; +} + +int size() { + return state::size; +} + +void barrier() { + MPI_Barrier(MPI_COMM_WORLD); +} + +bool ballot(bool vote) { + using traits = mpi_traits<char>; + + char result; + char value = vote ? 1 : 0; + + MPI_Allreduce(&value, &result, 1, traits::mpi_type(), MPI_LAND, MPI_COMM_WORLD); + + return result; +} + +} // namespace mpi +} // namespace mc +} // namespace nest diff --git a/src/communication/mpi.hpp b/src/communication/mpi.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5be71b6381bf8311bcab85d5396d64ba2c3eb26b --- /dev/null +++ b/src/communication/mpi.hpp @@ -0,0 +1,191 @@ +#pragma once + +#include <algorithm> +#include <iostream> +#include <type_traits> +#include <vector> + +#include <cassert> + +#include <mpi.h> + +#include <algorithms.hpp> + +namespace nest { +namespace mc { +namespace mpi { + + // prototypes + void init(int *argc, char ***argv); + void finalize(); + bool is_root(); + int rank(); + int size(); + void barrier(); + bool ballot(bool vote); + + // type traits for automatically setting MPI_Datatype information + // for C++ types + template <typename T> + struct mpi_traits { + constexpr static size_t count() { + return sizeof(T); + } + constexpr static MPI_Datatype mpi_type() { + return MPI_CHAR; + } + constexpr static bool is_mpi_native_type() { + return false; + } + }; + + #define MAKE_TRAITS(T,M) \ + template <> \ + struct mpi_traits<T> { \ + constexpr static size_t count() { return 1; } \ + /* constexpr */ static MPI_Datatype mpi_type() { return M; } \ + constexpr static bool is_mpi_native_type() { return true; } \ + }; + + MAKE_TRAITS(double, MPI_DOUBLE) + MAKE_TRAITS(float, MPI_FLOAT) + MAKE_TRAITS(int, MPI_INT) + MAKE_TRAITS(long int, MPI_LONG) + MAKE_TRAITS(char, MPI_CHAR) + MAKE_TRAITS(size_t, MPI_UNSIGNED_LONG) + static_assert(sizeof(size_t)==sizeof(unsigned long), + "size_t and unsigned long are not equivalent"); + + // Gather individual values of type T from each rank into a std::vector on + // the root rank. + // T must be trivially copyable + template<typename T> + std::vector<T> gather(T value, int root) { + static_assert( + true,//std::is_trivially_copyable<T>::value, + "gather can only be performed on trivally copyable types"); + + using traits = mpi_traits<T>; + auto buffer_size = (rank()==root) ? size() : 0; + std::vector<T> buffer(buffer_size); + + MPI_Gather( &value, traits::count(), traits::mpi_type(), // send buffer + buffer.data(), traits::count(), traits::mpi_type(), // receive buffer + root, MPI_COMM_WORLD); + + return buffer; + } + + // Gather individual values of type T from each rank into a std::vector on + // the every rank. + // T must be trivially copyable + template <typename T> + std::vector<T> gather_all(T value) { + static_assert( + true,//std::is_trivially_copyable<T>::value, + "gather_all can only be performed on trivally copyable types"); + + using traits = mpi_traits<T>; + std::vector<T> buffer(size()); + + MPI_Allgather( &value, traits::count(), traits::mpi_type(), // send buffer + buffer.data(), traits::count(), traits::mpi_type(), // receive buffer + MPI_COMM_WORLD); + + return buffer; + } + + template <typename T> + std::vector<T> gather_all(const std::vector<T> &values) { + static_assert( + true,//std::is_trivially_copyable<T>::value, + "gather_all can only be performed on trivally copyable types"); + + using traits = mpi_traits<T>; + auto counts = gather_all(int(values.size())); + for (auto& c : counts) { + c *= traits::count(); + } + auto displs = algorithms::make_index(counts); + + std::vector<T> buffer(displs.back()/traits::count()); + + MPI_Allgatherv( + // send buffer + values.data(), counts[rank()], traits::mpi_type(), + // receive buffer + buffer.data(), counts.data(), displs.data(), traits::mpi_type(), + MPI_COMM_WORLD + ); + + return buffer; + } + + template <typename T> + T reduce(T value, MPI_Op op, int root) { + using traits = mpi_traits<T>; + static_assert( + traits::is_mpi_native_type(), + "can only perform reductions on MPI native types"); + + T result; + + MPI_Reduce(&value, &result, 1, traits::mpi_type(), op, root, MPI_COMM_WORLD); + + return result; + } + + template <typename T> + T reduce(T value, MPI_Op op) { + using traits = mpi_traits<T>; + static_assert( + traits::is_mpi_native_type(), + "can only perform reductions on MPI native types"); + + T result; + + MPI_Allreduce(&value, &result, 1, traits::mpi_type(), op, MPI_COMM_WORLD); + + return result; + } + + template <typename T> + std::pair<T,T> minmax(T value) { + return {reduce<T>(value, MPI_MIN), reduce<T>(value, MPI_MAX)}; + } + + template <typename T> + std::pair<T,T> minmax(T value, int root) { + return {reduce<T>(value, MPI_MIN, root), reduce<T>(value, MPI_MAX, root)}; + } + + template <typename T> + T broadcast(T value, int root) { + static_assert( + true,//std::is_trivially_copyable<T>::value, + "broadcast can only be performed on trivally copyable types"); + + using traits = mpi_traits<T>; + + MPI_Bcast(&value, traits::count(), traits::mpi_type(), root, MPI_COMM_WORLD); + + return value; + } + + template <typename T> + T broadcast(int root) { + static_assert( + true,//std::is_trivially_copyable<T>::value, + "broadcast can only be performed on trivally copyable types"); + + using traits = mpi_traits<T>; + T value; + + MPI_Bcast(&value, traits::count(), traits::mpi_type(), root, MPI_COMM_WORLD); + + return value; + } + +} // namespace mpi +} // namespace mc +} // namespace nest diff --git a/src/communication/mpi_global_policy.hpp b/src/communication/mpi_global_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7409585447ae888eec2aa47ddc1fa0f137df7d6e --- /dev/null +++ b/src/communication/mpi_global_policy.hpp @@ -0,0 +1,66 @@ +#pragma once + +#ifndef WITH_MPI +#error "mpi_global_policy.hpp should only be compiled in a WITH_MPI build" +#endif + +#include <type_traits> +#include <vector> + +#include <cstdint> + +#include <communication/spike.hpp> +#include <communication/mpi.hpp> +#include <algorithms.hpp> + +namespace nest { +namespace mc { +namespace communication { + +struct mpi_global_policy { + using id_type = uint32_t; + + std::vector<spike<id_type>> + static gather_spikes(const std::vector<spike<id_type>>& local_spikes) { + return mpi::gather_all(local_spikes); + } + + static int id() { return mpi::rank(); } + + static int size() { return mpi::size(); } + + template <typename T> + static T min(T value) { + return nest::mc::mpi::reduce(value, MPI_MIN); + } + + template <typename T> + static T max(T value) { + return nest::mc::mpi::reduce(value, MPI_MAX); + } + + template < + typename T, + typename = typename std::enable_if<std::is_integral<T>::value> + > + static std::vector<T> make_map(T local) { + return algorithms::make_index(mpi::gather_all(local)); + } + + static void setup(int& argc, char**& argv) { + nest::mc::mpi::init(&argc, &argv); + } + + static void teardown() { + nest::mc::mpi::finalize(); + } + + static const char* name() { return "MPI"; } + +private: +}; + +} // namespace communication +} // namespace mc +} // namespace nest + diff --git a/src/communication/serial_global_policy.hpp b/src/communication/serial_global_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a1c7fee9671cb79768e255a710e0082386302e8a --- /dev/null +++ b/src/communication/serial_global_policy.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include <type_traits> +#include <vector> + +#include <cstdint> + +#include <communication/spike.hpp> + +namespace nest { +namespace mc { +namespace communication { + +struct serial_global_policy { + using id_type = uint32_t; + + std::vector<spike<id_type>> const + static gather_spikes(const std::vector<spike<id_type>>& local_spikes) { + return local_spikes; + } + + static int id() { + return 0; + } + + static int size() { + return 1; + } + + template <typename T> + static T min(T value) { + return value; + } + + template <typename T> + static T max(T value) { + return value; + } + + template < + typename T, + typename = typename std::enable_if<std::is_integral<T>::value> + > + static std::vector<T> make_map(T local) { + return {T(0), local}; + } + + static void setup(int& argc, char**& argv) {} + static void teardown() {} + static const char* name() { return "serial"; } +}; + +} // namespace communication +} // namespace mc +} // namespace nest diff --git a/src/communication/spike.hpp b/src/communication/spike.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03b5ed75def7a2fb68bd491ae7f1208d4b230b8a --- /dev/null +++ b/src/communication/spike.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include <type_traits> +#include <ostream> + +namespace nest { +namespace mc { +namespace communication { + +template < + typename I, + typename = typename std::enable_if<std::is_integral<I>::value> +> +struct spike { + using id_type = I; + id_type source = 0; + float time = -1.; + + spike() = default; + + spike(id_type s, float t) : + source(s), time(t) + {} +}; + +} // namespace mc +} // namespace nest +} // namespace communication + +/// custom stream operator for printing nest::mc::spike<> values +template <typename I> +std::ostream& operator<<( + std::ostream& o, + nest::mc::communication::spike<I> s) +{ + return o << "spike[t " << s.time << ", src " << s.source << "]"; +} + +/// less than comparison operator for nest::mc::spike<> values +/// spikes are ordered by spike time, for use in sorting and queueing +template <typename I> +bool operator<( + nest::mc::communication::spike<I> lhs, + nest::mc::communication::spike<I> rhs) +{ + return lhs.time < rhs.time; +} + diff --git a/src/communication/spike_source.hpp b/src/communication/spike_source.hpp new file mode 100644 index 0000000000000000000000000000000000000000..95099cb547de24e76386c2253960b8fffe4bf569 --- /dev/null +++ b/src/communication/spike_source.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include <cell.hpp> +#include <util/optional.hpp> + +namespace nest { +namespace mc { + +// spike detector for a lowered cell +template <typename Cell> +class spike_detector +{ +public: + using cell_type = Cell; + + spike_detector( const cell_type& cell, segment_location loc, double thresh, float t_init) : + location_(loc), + threshold_(thresh), + previous_t_(t_init) + { + previous_v_ = cell.voltage(location_); + is_spiking_ = previous_v_ >= thresh ? true : false; + } + + util::optional<float> test(const cell_type& cell, float t) { + util::optional<float> result = util::nothing; + auto v = cell.voltage(location_); + + // these if statements could be simplified, but I keep them like + // this to clearly reflect the finite state machine + if (!is_spiking_) { + if (v>=threshold_) { + // the threshold has been passed, so estimate the time using + // linear interpolation + auto pos = (threshold_ - previous_v_)/(v - previous_v_); + result = previous_t_ + pos*(t - previous_t_); + + is_spiking_ = true; + } + } + else { + if (v<threshold_) { + is_spiking_ = false; + } + } + + previous_v_ = v; + previous_t_ = t; + + return result; + } + + bool is_spiking() const { return is_spiking_; } + + segment_location location() const { return location_; } + + float t() const { return previous_t_; } + + float v() const { return previous_v_; } + +private: + + // parameters/data + segment_location location_; + double threshold_; + + // state + float previous_t_; + float previous_v_; + bool is_spiking_; +}; + + +} // namespace mc +} // namespace nest + diff --git a/src/event_queue.hpp b/src/event_queue.hpp index 989e34873a1f3ff851d0a3a9f85276cea9d9bf7a..7f3034f3f08bbb2235bf457c95e8eedfe8d0895d 100644 --- a/src/event_queue.hpp +++ b/src/event_queue.hpp @@ -9,22 +9,31 @@ namespace nest { namespace mc { -struct local_event { +struct postsynaptic_spike_event { uint32_t target; float time; float weight; }; -inline bool operator < (local_event const& l, local_event const& r) { - return l.time < r.time; -} +inline float event_time(const postsynaptic_spike_event &ev) { return ev.time; } -inline bool operator > (local_event const& l, local_event const& r) { - return l.time > r.time; -} +struct sample_event { + uint32_t sampler_index; + float time; +}; +inline float event_time(const sample_event &ev) { return ev.time; } + +/* Event objects must have a method event_time() which returns a value + * from a type with a total ordering with respect to <, >, etc. + */ + +template <typename Event> class event_queue { public : + using value_type = Event; + using time_type = decltype(event_time(std::declval<Event>())); + // create event_queue() {} @@ -37,7 +46,7 @@ public : } // push thing - void push(local_event e) { + void push(const value_type &e) { queue_.push(e); } @@ -46,8 +55,8 @@ public : } // pop until - util::optional<local_event> pop_if_before(float t_until) { - if (!queue_.empty() && queue_.top().time < t_until) { + util::optional<value_type> pop_if_before(time_type t_until) { + if (!queue_.empty() && event_time(queue_.top()) < t_until) { auto ev = queue_.top(); queue_.pop(); return ev; @@ -58,10 +67,16 @@ public : } private: + struct event_greater { + bool operator()(const Event &a, const Event &b) { + return event_time(a) > event_time(b); + } + }; + std::priority_queue< - local_event, - std::vector<local_event>, - std::greater<local_event> + Event, + std::vector<Event>, + event_greater > queue_; }; @@ -69,7 +84,7 @@ private: } // namespace mc inline -std::ostream& operator<< (std::ostream& o, nest::mc::local_event e) +std::ostream& operator<< (std::ostream& o, const nest::mc::postsynaptic_spike_event& e) { return o << "event[" << e.target << "," << e.time << "," << e.weight << "]"; } diff --git a/src/fvm.hpp b/src/fvm_cell.hpp similarity index 78% rename from src/fvm.hpp rename to src/fvm_cell.hpp index 603196c0cc66ebb2293c09dfd9a153fd7a806267..90e7f7be0d373274aba0b4ff28b908465e674cb5 100644 --- a/src/fvm.hpp +++ b/src/fvm_cell.hpp @@ -17,17 +17,21 @@ #include <segment.hpp> #include <stimulus.hpp> #include <util.hpp> +#include <profiling/profiler.hpp> #include <vector/include/Vector.hpp> #include <mechanisms/expsyn.hpp> + namespace nest { namespace mc { namespace fvm { template <typename T, typename I> class fvm_cell { - public : +public: + + fvm_cell() = default; /// the real number type using value_type = T; @@ -83,69 +87,62 @@ class fvm_cell { } /// return the voltage in each CV - vector_view voltage() { - return voltage_; - } - const_vector_view voltage() const { - return voltage_; - } + vector_view voltage() { return voltage_; } + const_vector_view voltage() const { return voltage_; } - std::size_t size() const { - return matrix_.size(); - } + std::size_t size() const { return matrix_.size(); } /// return reference to in iterable container of the mechanisms - std::vector<mechanism_type>& mechanisms() { - return mechanisms_; - } + std::vector<mechanism_type>& mechanisms() { return mechanisms_; } /// return reference to list of ions //std::map<mechanisms::ionKind, ion_type> ions_; - std::map<mechanisms::ionKind, ion_type>& ions() { - return ions_; - } - std::map<mechanisms::ionKind, ion_type> const& ions() const { - return ions_; - } + std::map<mechanisms::ionKind, ion_type>& ions() { return ions_; } + std::map<mechanisms::ionKind, ion_type> const& ions() const { return ions_; } /// return reference to sodium ion - ion_type& ion_na() { - return ions_[mechanisms::ionKind::na]; - } - ion_type const& ion_na() const { - return ions_[mechanisms::ionKind::na]; - } + ion_type& ion_na() { return ions_[mechanisms::ionKind::na]; } + ion_type const& ion_na() const { return ions_[mechanisms::ionKind::na]; } /// return reference to calcium ion - ion_type& ion_ca() { - return ions_[mechanisms::ionKind::ca]; - } - ion_type const& ion_ca() const { - return ions_[mechanisms::ionKind::ca]; - } + ion_type& ion_ca() { return ions_[mechanisms::ionKind::ca]; } + ion_type const& ion_ca() const { return ions_[mechanisms::ionKind::ca]; } /// return reference to pottasium ion - ion_type& ion_k() { - return ions_[mechanisms::ionKind::k]; - } - ion_type const& ion_k() const { - return ions_[mechanisms::ionKind::k]; - } + ion_type& ion_k() { return ions_[mechanisms::ionKind::k]; } + ion_type const& ion_k() const { return ions_[mechanisms::ionKind::k]; } /// make a time step void advance(value_type dt); - /// advance solution to target time tfinal with maximum step size dt - void advance_to(value_type tfinal, value_type dt); + /// pass an event to the appropriate synapse and call net_receive + void apply_event(postsynaptic_spike_event e) { + mechanisms_[synapse_index_]->net_receive(e.target, e.weight); + } + + mechanism_type& synapses() { + return mechanisms_[synapse_index_]; + } /// set initial states void initialize(); - event_queue& queue() { - return events_; + /// returns the compartment index of a segment location + int compartment_index(segment_location loc) const; + + /// returns voltage at a segment location + value_type voltage(segment_location loc) const; + + value_type time() const { return t_; } + + value_type probe(uint32_t i) const { + auto p = probes_[i]; + return (this->*p.first)[p.second]; } - private: + std::size_t num_probes() const { return probes_.size(); } + +private: /// current time value_type t_ = value_type{0}; @@ -153,6 +150,9 @@ class fvm_cell { /// the linear system for implicit time stepping of cell state matrix_type matrix_; + /// index for fast lookup of compartment index ranges of segments + index_type segment_index_; + /// cv_areas_[i] is the surface area of CV i vector_type cv_areas_; @@ -186,8 +186,7 @@ class fvm_cell { std::vector<std::pair<uint32_t, i_clamp>> stimulii_; - /// event queue - event_queue events_; + std::vector<std::pair<const vector_type fvm_cell::*, uint32_t>> probes_; }; //////////////////////////////////////////////////////////////////////////////// @@ -216,7 +215,7 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) matrix_ = matrix_type(graph.parent_index); auto parent_index = matrix_.p(); - auto const& segment_index = graph.segment_index; + segment_index_ = graph.segment_index; auto seg_idx = 0; for(auto const& s : cell.segments()) { @@ -249,7 +248,7 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) auto c_m = cable->mechanism("membrane").get("c_m").value; auto r_L = cable->mechanism("membrane").get("r_L").value; for(auto c : cable->compartments()) { - auto i = segment_index[seg_idx] + c.index; + auto i = segment_index_[seg_idx] + c.index; auto j = parent_index[i]; auto radius_center = math::mean(c.radius); @@ -308,7 +307,7 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) // calculate the number of compartments that contain the mechanism auto num_comp = 0u; for(auto seg : mech.second) { - num_comp += segment_index[seg+1] - segment_index[seg]; + num_comp += segment_index_[seg+1] - segment_index_[seg]; } // build a vector of the indexes of the compartments that contain @@ -316,11 +315,11 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) index_type compartment_index(num_comp); auto pos = 0u; for(auto seg : mech.second) { - auto seg_size = segment_index[seg+1] - segment_index[seg]; + auto seg_size = segment_index_[seg+1] - segment_index_[seg]; std::iota( compartment_index.data() + pos, compartment_index.data() + pos + seg_size, - segment_index[seg] + segment_index_[seg] ); pos += seg_size; } @@ -383,8 +382,8 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) // add the stimulii for(const auto& stim : cell.stimulii()) { - auto idx = find_compartment_index(stim.first, graph); - stimulii_.push_back( {idx, stim.second} ); + auto idx = find_compartment_index(stim.location, graph); + stimulii_.push_back( {idx, stim.clamp} ); } // add the synapses @@ -404,6 +403,21 @@ fvm_cell<T, I>::fvm_cell(nest::mc::cell const& cell) synapse_index_ = mechanisms_.size()-1; // don't forget to give point processes access to cv_areas_ mechanisms_[synapse_index_]->set_areas(cv_areas_); + + // record probe locations by index into corresponding state vector + for (auto probe : cell.probes()) { + uint32_t comp = find_compartment_index(probe.location, graph); + switch (probe.kind) { + case probeKind::membrane_voltage: + probes_.push_back({&fvm_cell::voltage_, comp}); + break; + case probeKind::membrane_current: + probes_.push_back({&fvm_cell::current_, comp}); + break; + default: + throw std::logic_error("unrecognized probeKind"); + } + } } template <typename T, typename I> @@ -451,13 +465,30 @@ void fvm_cell<T, I>::setup_matrix(T dt) rhs[i] = cv_areas_[i]*(voltage_[i] - factor/cv_capacitance_[i]*current_[i]); } } +template <typename T, typename I> +int fvm_cell<T, I>::compartment_index(segment_location loc) const +{ + EXPECTS(loc.segment < segment_index_.size()); + + const auto seg = loc.segment; + + auto first = segment_index_[seg]; + auto n = segment_index_[seg+1] - first; + auto index = std::floor(n*loc.position); + return index<n ? first+index : first+n-1; +} + +template <typename T, typename I> +T fvm_cell<T, I>::voltage(segment_location loc) const +{ + return voltage_[compartment_index(loc)]; +} template <typename T, typename I> void fvm_cell<T, I>::initialize() { t_ = 0.; - // initialize mechanism states for(auto& m : mechanisms_) { m->nrn_init(); } @@ -468,12 +499,15 @@ void fvm_cell<T, I>::advance(T dt) { using memory::all; + mc::util::profiler_enter("current"); current_(all) = 0.; // update currents from ion channels for(auto& m : mechanisms_) { + mc::util::profiler_enter(m->name().c_str()); m->set_params(t_, dt); m->nrn_current(); + mc::util::profiler_leave(); } // add current contributions from stimulii @@ -484,43 +518,29 @@ void fvm_cell<T, I>::advance(T dt) // the factor of 100 scales the injected current to 10^2.nA current_[loc] -= 100.*ie/cv_areas_[loc]; } + mc::util::profiler_leave(); - // set matrix diagonals and rhs - setup_matrix(dt); - + mc::util::profiler_enter("matrix", "setup"); // solve the linear system + setup_matrix(dt); + mc::util::profiler_leave(); mc::util::profiler_enter("solve"); matrix_.solve(); - + mc::util::profiler_leave(); voltage_(all) = matrix_.rhs(); + mc::util::profiler_leave(); - // update states + mc::util::profiler_enter("state"); + // integrate state of gating variables etc. for(auto& m : mechanisms_) { + mc::util::profiler_enter(m->name().c_str()); m->nrn_state(); + mc::util::profiler_leave(); } + mc::util::profiler_leave(); t_ += dt; } -template <typename T, typename I> -void fvm_cell<T, I>::advance_to(T tfinal, T dt) -{ - if(t_>=tfinal) { - return; - } - - do { - auto tstep = std::min(tfinal, t_+dt); - auto next = events_.pop_if_before(tstep); - auto tnext = next? next->time: tstep; - - advance(tnext-t_); - t_ = tnext; - if (next) { // handle event - mechanisms_[synapse_index_]->net_receive(next->target, next->weight); - } - } while(t_<tfinal); -} - } // namespace fvm } // namespace mc } // namespace nest diff --git a/src/profiling/profiler.cpp b/src/profiling/profiler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c07351e296ba8873ed7064ca4b605d2e9bb0cb8d --- /dev/null +++ b/src/profiling/profiler.cpp @@ -0,0 +1,375 @@ +#include "profiler.hpp" + +#include <communication/global_policy.hpp> + +namespace nest { +namespace mc { +namespace util { + +///////////////////////////////////////////////////////// +// profiler_node +///////////////////////////////////////////////////////// +void profiler_node::print(int indent) { + std::string s = std::string(indent, ' ') + name; + std::cout << s + << std::string(60-s.size(), '.') + << value + << "\n"; + for (auto& n : children) { + n.print(indent+2); + } +} + +void profiler_node::print(std::ostream& stream, double threshold) { + // convert threshold from proportion to time + threshold *= value; + print_sub(stream, 0, threshold, value); +} + +void profiler_node::print_sub( + std::ostream& stream, + int indent, + double threshold, + double total) +{ + char buffer[512]; + + if (value < threshold) { + std::cout << green("not printing ") << name << std::endl; + return; + } + + auto max_contribution = + std::accumulate( + children.begin(), children.end(), -1., + [] (double lhs, const profiler_node& rhs) { + return lhs > rhs.value ? lhs : rhs.value; + } + ); + + // print the table row + auto const indent_str = std::string(indent, ' '); + auto label = indent_str + name; + float percentage = 100.*value/total; + snprintf(buffer, sizeof(buffer), "%-25s%10.3f%10.1f", + label.c_str(), + float(value), + float(percentage)); + bool print_children = + threshold==0. ? children.size()>0 + : max_contribution >= threshold; + + stream << (print_children ? white(buffer) : buffer) << "\n"; + + if (print_children) { + auto other = 0.; + for (auto &n : children) { + if (n.value<threshold || n.name=="other") { + other += n.value; + } + else { + n.print_sub(stream, indent + 2, threshold, total); + } + } + if (other>=std::max(threshold, 0.001) && children.size()) { + label = indent_str + " other"; + percentage = 100.*other/total; + snprintf(buffer, sizeof(buffer), "%-25s%10.3f%10.1f", + label.c_str(), float(other), percentage); + stream << buffer << std::endl; + } + } +} + +void profiler_node::fuse(const profiler_node& other) { + for (auto& n : other.children) { + auto it = std::find(children.begin(), children.end(), n); + if (it!=children.end()) { + (*it).fuse(n); + } + else { + children.push_back(n); + } + } + + value += other.value; +} + +double profiler_node::time_in_other() const { + auto o = std::find_if( + children.begin(), children.end(), + [](const profiler_node& n) { + return n.name == std::string("other"); + } + ); + return o==children.end() ? 0. : o->value; +} + +void profiler_node::scale(double factor) { + value *= factor; + for (auto& n : children) { + n.scale(factor); + } +} + +profiler_node::json profiler_node::as_json() const { + json node; + node["name"] = name; + node["time"] = value; + for (const auto& n : children) { + node["regions"].push_back(n.as_json()); + } + return node; +} + +profiler_node operator+ (const profiler_node& lhs, const profiler_node& rhs) { + assert(lhs.name == rhs.name); + auto node = lhs; + node.fuse(rhs); + return node; +} + +bool operator== (const profiler_node& lhs, const profiler_node& rhs) { + return lhs.name == rhs.name; +} + +///////////////////////////////////////////////////////// +// region_type +///////////////////////////////////////////////////////// +region_type* region_type::subregion(const char* n) { + size_t hsh = impl::hash(n); + auto s = subregions_.find(hsh); + if (s == subregions_.end()) { + subregions_[hsh] = util::make_unique<region_type>(n, this); + return subregions_[hsh].get(); + } + return s->second.get(); +} + +double region_type::subregion_contributions() const { + return + std::accumulate( + subregions_.begin(), subregions_.end(), 0., + [](double l, decltype(*(subregions_.begin())) r) { + return l+r.second->total(); + } + ); +} + +profiler_node region_type::populate_performance_tree() const { + profiler_node tree(total(), name()); + + for (auto &it : subregions_) { + tree.children.push_back(it.second->populate_performance_tree()); + } + + // sort the contributions in descending order + std::stable_sort( + tree.children.begin(), tree.children.end(), + [](const profiler_node& lhs, const profiler_node& rhs) { + return lhs.value>rhs.value; + } + ); + + if (tree.children.size()) { + // find the contribution of parts of the code that were not explicitly profiled + auto contributions = + std::accumulate( + tree.children.begin(), tree.children.end(), 0., + [](double v, profiler_node& n) { + return v+n.value; + } + ); + auto other = total() - contributions; + + // add the "other" category + tree.children.emplace_back(other, std::string("other")); + } + + return tree; +} + +///////////////////////////////////////////////////////// +// region_type +///////////////////////////////////////////////////////// +void profiler::enter(const char* name) { + if (!is_activated()) return; + current_region_ = current_region_->subregion(name); + current_region_->start_time(); +} + +void profiler::leave() { + if (!is_activated()) return; + if (current_region_->parent()==nullptr) { + throw std::out_of_range("attempt to leave root memory tracing region"); + } + current_region_->end_time(); + current_region_ = current_region_->parent(); +} + +void profiler::leave(int n) { + EXPECTS(n>=1); + + while(n--) { + leave(); + } +} + +void profiler::start() { + if (is_activated()) { + throw std::out_of_range( + "attempt to start an already running profiler" + ); + } + activate(); + start_time_ = timer_type::tic(); + root_region_.start_time(); +} + +void profiler::stop() { + if (!is_in_root()) { + throw std::out_of_range( + "profiler must be in root region when stopped" + ); + } + root_region_.end_time(); + stop_time_ = timer_type::tic(); + + deactivate(); +} + +profiler_node profiler::performance_tree() { + if (is_activated()) { + stop(); + } + return root_region_.populate_performance_tree(); +} + + +#ifdef WITH_PROFILING +namespace data { + profiler_wrapper profilers_(profiler("root")); +} + +profiler& get_profiler() { + auto& p = data::profilers_.local(); + if (!p.is_activated()) { + p.start(); + } + return p; +} + +// this will throw an exception if the profler has already been started +void profiler_start() { + data::profilers_.local().start(); +} +void profiler_stop() { + get_profiler().stop(); +} +void profiler_enter(const char* n) { + get_profiler().enter(n); +} + +void profiler_leave() { + get_profiler().leave(); +} +void profiler_leave(int nlevels) { + get_profiler().leave(nlevels); +} + +// iterate over all profilers and ensure that they have the same start stop times +void stop_profilers() { + for (auto& p : data::profilers_) { + p.stop(); + } +} + +void profiler_output(double threshold) { + stop_profilers(); + + // Find the earliest start time and latest stop time over all profilers + // This can be used to calculate the wall time for this communicator. + // The min-max values are used because, for example, the individual + // profilers might start at different times. In this case, the time stamp + // when the first profiler started is taken as the start time of the whole + // measurement period. Likewise for the last profiler to stop. + auto start_time = data::profilers_.begin()->start_time(); + auto stop_time = data::profilers_.begin()->stop_time(); + for(auto& p : data::profilers_) { + start_time = std::min(start_time, p.start_time()); + stop_time = std::max(stop_time, p.stop_time()); + } + // calculate the wall time + auto wall_time = timer_type::difference(start_time, stop_time); + // calculate the accumulated wall time over all threads + auto nthreads = data::profilers_.size(); + auto thread_wall = wall_time * nthreads; + + // gather the profilers into one accumulated profile over all threads + auto thread_measured = 0.; // accumulator for the time measured in each thread + auto p = profiler_node(0, "total"); + for(auto& thread_profiler : data::profilers_) { + auto tree = thread_profiler.performance_tree(); + thread_measured += tree.value - tree.time_in_other(); + p.fuse(thread_profiler.performance_tree()); + } + auto efficiency = 100. * thread_measured / thread_wall; + + p.scale(1./nthreads); + + auto ncomms = communication::global_policy::size(); + auto comm_rank = communication::global_policy::id(); + bool print = comm_rank==0 ? true : false; + if(print) { + std::cout << " ---------------------------------------------------- \n"; + std::cout << "| profiler |\n"; + std::cout << " ---------------------------------------------------- \n"; + char line[128]; + std::snprintf( + line, sizeof(line), "%-18s%10.3f s\n", + "wall time", float(wall_time)); + std::cout << line; + std::snprintf( + line, sizeof(line), "%-18s%10d\n", + "communicators", int(ncomms)); + std::cout << line; + std::snprintf( + line, sizeof(line), "%-18s%10d\n", + "threads", int(nthreads)); + std::cout << line; + std::snprintf( + line, sizeof(line), "%-18s%10.2f %%\n", + "thread efficiency", float(efficiency)); + std::cout << line << "\n"; + p.print(std::cout, threshold); + + std::cout << "\n\n"; + } + + nlohmann::json as_json; + as_json["wall time"] = wall_time; + as_json["threads"] = nthreads; + as_json["efficiency"] = efficiency; + as_json["communicators"] = ncomms; + as_json["rank"] = comm_rank; + as_json["regions"] = p.as_json(); + + auto fname = std::string("profile_" + std::to_string(comm_rank)); + std::ofstream fid(fname); + fid << std::setw(1) << as_json; +} + +#else +void profiler_start() {} +void profiler_stop() {} +void profiler_enter(const char*) {} +void profiler_leave() {} +void profiler_leave(int) {} +void stop_profilers() {} +void profiler_output(double threshold) {} +#endif + +} // namespace util +} // namespace mc +} // namespace nest + diff --git a/src/profiling/profiler.hpp b/src/profiling/profiler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7230697baab21f914fa3b92f4ffbed4382b062b6 --- /dev/null +++ b/src/profiling/profiler.hpp @@ -0,0 +1,240 @@ +#pragma once + +#include <algorithm> +#include <unordered_map> +#include <map> +#include <memory> +#include <stdexcept> +#include <fstream> +#include <iostream> +#include <vector> + +#include <cassert> +#include <cstdlib> + +#include <json/src/json.hpp> + +#include <threading/threading.hpp> +#include <util.hpp> + +namespace nest { +namespace mc { +namespace util { + +inline std::string green(std::string s) { return s; } +inline std::string yellow(std::string s) { return s; } +inline std::string white(std::string s) { return s; } +inline std::string red(std::string s) { return s; } +inline std::string cyan(std::string s) { return s; } + +using timer_type = nest::mc::threading::timer; + +namespace impl { + /// simple hashing function for strings + static inline + size_t hash(const char* s) { + size_t h = 5381; + while (*s) { + h = ((h << 5) + h) + int(*s); + ++s; + } + return h; + } + + /// std::string overload for hash + static inline + size_t hash(const std::string& s) { + return hash(s.c_str()); + } +} // namespace impl + +/// The tree data structure that is generated by post-processing of +/// a profiler. +struct profiler_node { + double value; + std::string name; + std::vector<profiler_node> children; + using json = nlohmann::json; + + profiler_node() : + value(0.), name("") + {} + + profiler_node(double v, const std::string& n) : + value(v), name(n) + {} + + void print(int indent=0); + void print(std::ostream& stream, double threshold); + void print_sub(std::ostream& stream, int indent, double threshold, double total); + void fuse(const profiler_node& other); + /// return wall time spend in "other" region + double time_in_other() const; + /// scale the value in each node by factor + /// performed to all children recursively + void scale(double factor); + + json as_json() const; +}; + +profiler_node operator+ (const profiler_node& lhs, const profiler_node& rhs); +bool operator== (const profiler_node& lhs, const profiler_node& rhs); + +// a region in the profiler, has +// - name +// - accumulated timer +// - nested sub-regions +class region_type { + region_type *parent_ = nullptr; + std::string name_; + size_t hash_; + std::unordered_map<size_t, std::unique_ptr<region_type>> subregions_; + timer_type::time_point start_time_; + double total_time_ = 0; + +public: + + explicit region_type(std::string n) : + name_(std::move(n)), + hash_(impl::hash(n)), + start_time_(timer_type::tic()) + {} + + explicit region_type(const char* n) : + region_type(std::string(n)) + {} + + region_type(std::string n, region_type* p) : + region_type(std::move(n)) + { + parent_ = p; + } + + const std::string& name() const { return name_; } + void name(std::string n) { name_ = std::move(n); } + + region_type* parent() { return parent_; } + + void start_time() { start_time_ = timer_type::tic(); } + void end_time () { total_time_ += timer_type::toc(start_time_); } + double total() const { return total_time_; } + + bool has_subregions() const { return subregions_.size() > 0; } + + size_t hash() const { return hash_; } + + region_type* subregion(const char* n); + + double subregion_contributions() const; + + profiler_node populate_performance_tree() const; +}; + +class profiler { +public: + profiler(std::string name) : + root_region_(std::move(name)) + {} + + // the copy constructor doesn't do a "deep copy" + // it simply creates a new profiler with the same name + // This is needed for tbb to create a list of thread local profilers + profiler(const profiler& other) : + profiler(other.root_region_.name()) + {} + + /// step down into level with name + void enter(const char* name); + + /// step up one level + void leave(); + + /// step up multiple n levels in one call + void leave(int n); + + /// return a reference to the root region + region_type& regions() { return root_region_; } + + /// return a pointer to the current region + region_type* current_region() { return current_region_; } + + /// return if in the root region (i.e. the highest level) + bool is_in_root() const { return &root_region_ == current_region_; } + + /// return if the profiler has been activated + bool is_activated() const { return activated_; } + + /// start (activate) the profiler + void start(); + + /// stop (deactivate) the profiler + void stop(); + + /// the time stamp at which the profiler was started (avtivated) + timer_type::time_point start_time() const { return start_time_; } + + /// the time stamp at which the profiler was stopped (deavtivated) + timer_type::time_point stop_time() const { return stop_time_; } + + /// the time in seconds between activation and deactivation of the profiler + double wall_time() const { + return timer_type::difference(start_time_, stop_time_); + } + + /// stop the profiler then generate the performance tree ready for output + profiler_node performance_tree(); + +private: + void activate() { activated_ = true; } + void deactivate() { activated_ = false; } + + timer_type::time_point start_time_; + timer_type::time_point stop_time_; + bool activated_ = false; + region_type root_region_; + region_type* current_region_ = &root_region_; +}; + +#ifdef WITH_PROFILING +namespace data { + using profiler_wrapper = nest::mc::threading::enumerable_thread_specific<profiler>; + extern profiler_wrapper profilers_; +} +#endif + +/// get a reference to the thread private profiler +/// will lazily create and start the profiler it it has not already been done so +profiler& get_profiler(); + +/// start thread private profiler +void profiler_start(); + +/// stop thread private profiler +void profiler_stop(); + +/// enter a profiling region with name n +void profiler_enter(const char* n); + +/// enter nested profiler regions in a single call +template <class...Args> +void profiler_enter(const char* n, Args... args) { +#ifdef WITH_PROFILING + get_profiler().enter(n); + profiler_enter(args...); +#endif +} + +/// move up one level in the profiler +void profiler_leave(); +/// move up multiple profiler levels in one call +void profiler_leave(int nlevels); + +/// iterate and stop them +void stop_profilers(); + +/// print the collated profiler to std::cout +void profiler_output(double threshold); + +} // namespace util +} // namespace mc +} // namespace nest diff --git a/src/threading/serial.hpp b/src/threading/serial.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eef783949049c77f3c06bfa059b02c06d96b3670 --- /dev/null +++ b/src/threading/serial.hpp @@ -0,0 +1,83 @@ +#pragma once + +#if !defined(WITH_SERIAL) + #error this header can only be loaded if WITH_SERIAL is set +#endif + +#include <array> +#include <chrono> +#include <string> + +namespace nest { +namespace mc { +namespace threading { + +/////////////////////////////////////////////////////////////////////// +// types +/////////////////////////////////////////////////////////////////////// +template <typename T> +class enumerable_thread_specific { + std::array<T, 1> data; + + public : + + enumerable_thread_specific() = default; + + enumerable_thread_specific(const T& init) : + data{init} + {} + + enumerable_thread_specific(T&& init) : + data{std::move(init)} + {} + + T& local() { return data[0]; } + const T& local() const { return data[0]; } + + auto size() -> decltype(data.size()) const { return data.size(); } + + auto begin() -> decltype(data.begin()) { return data.begin(); } + auto end() -> decltype(data.end()) { return data.end(); } + + auto cbegin() -> decltype(data.cbegin()) const { return data.cbegin(); } + auto cend() -> decltype(data.cend()) const { return data.cend(); } +}; + + +/////////////////////////////////////////////////////////////////////// +// algorithms +/////////////////////////////////////////////////////////////////////// +struct parallel_for { + template <typename F> + static void apply(int left, int right, F f) { + for(int i=left; i<right; ++i) { + f(i); + } + } +}; + +inline std::string description() { + return "serial"; +} + +struct timer { + using time_point = std::chrono::time_point<std::chrono::system_clock>; + + static inline time_point tic() { + return std::chrono::system_clock::now(); + } + + static inline double toc(time_point t) { + return std::chrono::duration<double>(tic() - t).count(); + } + + static inline double difference(time_point b, time_point e) { + return std::chrono::duration<double>(e-b).count(); + } +}; + + +} // threading +} // mc +} // nest + diff --git a/src/threading/tbb.hpp b/src/threading/tbb.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed82e26ec87617f7453f4548efe3eec3f858507f --- /dev/null +++ b/src/threading/tbb.hpp @@ -0,0 +1,58 @@ +#pragma once + +#if !defined(WITH_TBB) + #error this header can only be loaded if WITH_TBB is set +#endif + +#include <string> + +#include <tbb/tbb.h> +#include <tbb/compat/thread> +#include <tbb/enumerable_thread_specific.h> + +namespace nest { +namespace mc { +namespace threading { + +template <typename T> +using enumerable_thread_specific = tbb::enumerable_thread_specific<T>; + +struct parallel_for { + template <typename F> + static void apply(int left, int right, F f) { + tbb::parallel_for(left, right, f); + } +}; + +inline std::string description() { + return "TBB"; +} + +struct timer { + using time_point = tbb::tick_count; + + static inline time_point tic() { + return tbb::tick_count::now(); + } + + static inline double toc(time_point t) { + return (tic() - t).seconds(); + } + + static inline double difference(time_point b, time_point e) { + return (e-b).seconds(); + } +}; + +} // threading +} // mc +} // nest + +namespace tbb { + /// comparison operator for tbb::tick_count type + /// returns true iff time stamp l occurred before timestamp r + inline bool operator< (tbb::tick_count l, tbb::tick_count r) { + return (l-r).seconds() < 0.; + } +} + diff --git a/src/threading/threading.hpp b/src/threading/threading.hpp new file mode 100644 index 0000000000000000000000000000000000000000..76431d7de8cb85aebcf18ce1988ab7f6f2ea118f --- /dev/null +++ b/src/threading/threading.hpp @@ -0,0 +1,9 @@ +#pragma once + +#if defined(WITH_TBB) + #include "tbb.hpp" +#else + #define WITH_SERIAL + #include "serial.hpp" +#endif + diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 33e4439834b902c2ad775dc157e03fb452c8f8ac..ebebb306e1a896f1cd815751324971ea4e67cfe8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -18,12 +18,14 @@ set(TEST_SOURCES test_compartments.cpp test_event_queue.cpp test_fvm.cpp + test_cell_group.cpp test_matrix.cpp test_mechanisms.cpp test_optional.cpp test_parameters.cpp test_point.cpp test_segment.cpp + test_spikes.cpp test_stimulus.cpp test_swcio.cpp test_synapses.cpp @@ -38,7 +40,7 @@ set(VALIDATION_SOURCES # unit tests validate_ball_and_stick.cpp validate_soma.cpp - validate_synapses.cpp + #validate_synapses.cpp # unit test driver validate.cpp @@ -48,15 +50,23 @@ add_definitions("-DDATADIR=\"${CMAKE_SOURCE_DIR}/data\"") add_executable(test.exe ${TEST_SOURCES} ${HEADERS}) add_executable(validate.exe ${VALIDATION_SOURCES} ${HEADERS}) -target_link_libraries(test.exe LINK_PUBLIC cellalgo gtest) -target_link_libraries(validate.exe LINK_PUBLIC cellalgo gtest) +set(TARGETS test.exe validate.exe) -set_target_properties(test.exe - PROPERTIES - RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/tests" -) +foreach(target ${TARGETS}) + target_link_libraries(${target} LINK_PUBLIC cellalgo gtest) + + if(WITH_TBB) + target_link_libraries(${target} LINK_PUBLIC ${TBB_LIBRARIES}) + endif() + + if(WITH_MPI) + target_link_libraries(${target} LINK_PUBLIC ${MPI_C_LIBRARIES}) + set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "${MPI_C_LINK_FLAGS}") + endif() + + set_target_properties(${target} + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/tests" + ) +endforeach() -set_target_properties(validate.exe - PROPERTIES - RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/tests" -) diff --git a/tests/test_cell_group.cpp b/tests/test_cell_group.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9787a1935657391be82f2e674b13bab9ef908ed --- /dev/null +++ b/tests/test_cell_group.cpp @@ -0,0 +1,48 @@ +#include <limits> + +#include "gtest.h" + +#include <fvm_cell.hpp> +#include <cell_group.hpp> + +nest::mc::cell make_cell() { + using namespace nest::mc; + + // setup global state for the mechanisms + mechanisms::setup_mechanism_helpers(); + + nest::mc::cell cell; + + // Soma with diameter 12.6157 um and HH channel + auto soma = cell.add_soma(12.6157/2.0); + soma->add_mechanism(hh_parameters()); + + // add dendrite of length 200 um and diameter 1 um with passive channel + auto dendrite = cell.add_cable(0, segmentKind::dendrite, 0.5, 0.5, 200); + dendrite->add_mechanism(pas_parameters()); + dendrite->set_compartments(101); + + dendrite->mechanism("membrane").set("r_L", 100); + + // add stimulus + cell.add_stimulus({1,1}, {5., 80., 0.3}); + + cell.add_detector({0,0}, 0); + + return cell; +} + +TEST(cell_group, test) +{ + using namespace nest::mc; + + using cell_type = cell_group<fvm::fvm_cell<double, int>>; + + auto cell = cell_type{make_cell()}; + + cell.advance(50, 0.01); + + // a bit lame... + EXPECT_EQ(cell.spikes().size(), 4u); +} + diff --git a/tests/test_event_queue.cpp b/tests/test_event_queue.cpp index e67ade0bea22a417ca417a78c550b67bb0c3e45b..5ac4c75747d761638b0bbd4b90a9b6fed6fa3c18 100644 --- a/tests/test_event_queue.cpp +++ b/tests/test_event_queue.cpp @@ -7,8 +7,9 @@ TEST(event_queue, push) { using namespace nest::mc; + using ps_event_queue = event_queue<postsynaptic_spike_event>; - event_queue q; + ps_event_queue q; q.push({1u, 2.f, 2.f}); q.push({4u, 1.f, 2.f}); @@ -30,15 +31,16 @@ TEST(event_queue, push) TEST(event_queue, push_range) { using namespace nest::mc; + using ps_event_queue = event_queue<postsynaptic_spike_event>; - local_event events[] = { + postsynaptic_spike_event events[] = { {1u, 2.f, 2.f}, {4u, 1.f, 2.f}, {8u, 20.f, 2.f}, {2u, 8.f, 2.f} }; - event_queue q; + ps_event_queue q; q.push(std::begin(events), std::end(events)); std::vector<float> times; @@ -54,15 +56,16 @@ TEST(event_queue, push_range) TEST(event_queue, pop_if_before) { using namespace nest::mc; + using ps_event_queue = event_queue<postsynaptic_spike_event>; - local_event events[] = { + postsynaptic_spike_event events[] = { {1u, 1.f, 2.f}, {2u, 2.f, 2.f}, {3u, 3.f, 2.f}, {4u, 4.f, 2.f} }; - event_queue q; + ps_event_queue q; q.push(std::begin(events), std::end(events)); EXPECT_EQ(q.size(), 4u); diff --git a/tests/test_fvm.cpp b/tests/test_fvm.cpp index 85a25f3579aa25e2e84860b3144294269ff0895a..bc7374cbd0ea8f16b2b8adb02c4d27382aa83c61 100644 --- a/tests/test_fvm.cpp +++ b/tests/test_fvm.cpp @@ -3,8 +3,8 @@ #include "gtest.h" #include "util.hpp" -#include "../src/cell.hpp" -#include "../src/fvm.hpp" +#include <cell.hpp> +#include <fvm_cell.hpp> TEST(fvm, cable) { diff --git a/tests/test_spikes.cpp b/tests/test_spikes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b16f89cda31768f43b8509b9f9cdde375ed9bc25 --- /dev/null +++ b/tests/test_spikes.cpp @@ -0,0 +1,70 @@ +#include "gtest.h" + +#include <communication/spike.hpp> +#include <communication/spike_source.hpp> + +struct cell_proxy { + double voltage(nest::mc::segment_location loc) const { + return v; + } + + double v = -65.; +}; + +TEST(spikes, spike_detector) +{ + using namespace nest::mc; + using detector_type = spike_detector<cell_proxy>; + cell_proxy proxy; + float threshold = 10.f; + float t = 0.f; + float dt = 1.f; + auto loc = segment_location(1, 0.1); + + auto detector = detector_type(proxy, loc, threshold, t); + + EXPECT_FALSE(detector.is_spiking()); + EXPECT_EQ(loc, detector.location()); + EXPECT_EQ(proxy.v, detector.v()); + EXPECT_EQ(t, detector.t()); + + { + t += dt; + proxy.v = 0; + auto spike = detector.test(proxy, t); + EXPECT_FALSE(spike); + + EXPECT_FALSE(detector.is_spiking()); + EXPECT_EQ(loc, detector.location()); + EXPECT_EQ(proxy.v, detector.v()); + EXPECT_EQ(t, detector.t()); + } + + { + t += dt; + proxy.v = 20; + auto spike = detector.test(proxy, t); + + EXPECT_TRUE(spike); + EXPECT_EQ(spike.get(), 1.5); + + EXPECT_TRUE(detector.is_spiking()); + EXPECT_EQ(loc, detector.location()); + EXPECT_EQ(proxy.v, detector.v()); + EXPECT_EQ(t, detector.t()); + } + + { + t += dt; + proxy.v = 0; + auto spike = detector.test(proxy, t); + + EXPECT_FALSE(spike); + + EXPECT_FALSE(detector.is_spiking()); + EXPECT_EQ(loc, detector.location()); + EXPECT_EQ(proxy.v, detector.v()); + EXPECT_EQ(t, detector.t()); + } +} + diff --git a/tests/test_synapses.cpp b/tests/test_synapses.cpp index 5a8f2ffac903289663dc21e105523a593bd695a2..87ac3aeca2e828035dea39c7cab4bf156a4e04bf 100644 --- a/tests/test_synapses.cpp +++ b/tests/test_synapses.cpp @@ -2,7 +2,7 @@ #include "util.hpp" #include <cell.hpp> -#include <fvm.hpp> +#include <fvm_cell.hpp> // compares results with those generated by nrn/ball_and_stick.py TEST(synapses, add_to_cell) diff --git a/tests/validate_ball_and_stick.cpp b/tests/validate_ball_and_stick.cpp index 1a9bddb7e3aaa031cbf4212e0c1898767f3682bb..5b509525b053fe389908b150b87333a27f5241ec 100644 --- a/tests/validate_ball_and_stick.cpp +++ b/tests/validate_ball_and_stick.cpp @@ -4,8 +4,8 @@ #include "gtest.h" #include "util.hpp" -#include "../src/cell.hpp" -#include "../src/fvm.hpp" +#include <cell.hpp> +#include <fvm_cell.hpp> // compares results with those generated by nrn/ball_and_stick.py TEST(ball_and_stick, neuron_baseline) diff --git a/tests/validate_soma.cpp b/tests/validate_soma.cpp index 44d8235a4463e153c87fa58b325b7694440ed3a1..2353cb483967fb74757dab81ffcf96c68f3c649a 100644 --- a/tests/validate_soma.cpp +++ b/tests/validate_soma.cpp @@ -4,8 +4,8 @@ #include "gtest.h" #include "util.hpp" -#include "../src/cell.hpp" -#include "../src/fvm.hpp" +#include <cell.hpp> +#include <fvm_cell.hpp> // compares results with those generated by nrn/soma.py // single compartment model with HH channels diff --git a/tests/validate_synapses.cpp b/tests/validate_synapses.cpp index 7b882a689c94cb5e6b70e8f89115917fc7a6ba0d..2ccde863cb4cbf16efbe75dd509e9c7c2095ddbe 100644 --- a/tests/validate_synapses.cpp +++ b/tests/validate_synapses.cpp @@ -4,7 +4,7 @@ #include "util.hpp" #include <cell.hpp> -#include <fvm.hpp> +#include <fvm_cell.hpp> #include <json/src/json.hpp>