diff --git a/CMakeLists.txt b/CMakeLists.txt index 02d70e06b34d688b4f8751ede06d6fdf673aa720..39c99a77d6369c85c310c4fb920c83f451f1975e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -132,6 +132,11 @@ set(ARB_WITH_CUDA FALSE) if(NOT ARB_GPU_MODEL MATCHES "none") find_package(CUDA REQUIRED) + + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} + -Xcudafe --diag_suppress=integer_sign_change + -Xcudafe --diag_suppress=unsigned_compare_with_zero) + set(ARB_WITH_CUDA TRUE) add_definitions(-DARB_HAVE_GPU) include_directories(SYSTEM ${CUDA_INCLUDE_DIRS}) @@ -154,7 +159,7 @@ endif() #---------------------------------------------------------- # Cray/BGQ/Generic Linux/other flag? #---------------------------------------------------------- -set(ARB_SYSTEM_TYPE "Generic" CACHE STRING +set(ARB_SYSTEM_TYPE "Generic" CACHE STRING "Choose a system type to customize flags") set_property(CACHE ARB_SYSTEM_TYPE PROPERTY STRINGS Generic Cray BGQ ) @@ -213,20 +218,17 @@ if(ARB_WITH_PROFILING) endif() #---------------------------------------------------------- -# vectorization target +# Modcc vectorization target #---------------------------------------------------------- -set(ARB_VECTORIZE_TARGET "none" CACHE STRING "CPU target for vectorization {none,KNL,AVX2,AVX512}") -set_property(CACHE ARB_VECTORIZE_TARGET PROPERTY STRINGS none KNL AVX2 AVX512) - -# Note: this option conflates modcc code generation options and -# the target architecture for compilation of Arbor. TODO: fix! 
+option(ARB_VECTORIZE "use explicit SIMD code in generated mechanisms" OFF) -if(ARB_VECTORIZE_TARGET STREQUAL "KNL") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXXOPT_KNL} -DSIMD_KNL") -elseif(ARB_VECTORIZE_TARGET STREQUAL "AVX2") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXXOPT_AVX2} -DSIMD_AVX2") -elseif(ARB_VECTORIZE_TARGET STREQUAL "AVX512") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXXOPT_AVX512} -DSIMD_AVX512") +#---------------------------------------------------------- +# Target microarchitecture for building arbor libraries +#---------------------------------------------------------- +set(ARB_ARCH "" CACHE STRING "Target architecture for arbor libraries") +if(ARB_ARCH) + # Sets CXXOPT_ARCH variable accordingly: + set_arch_target("${ARB_ARCH}") endif() #---------------------------------------------------------- diff --git a/cmake/CompilerOptions.cmake b/cmake/CompilerOptions.cmake index 8e51dddbdb597933a7808f7ae77685e334be2a7a..b892417f4cae18d41081eeef2e14ef818a308dd1 100644 --- a/cmake/CompilerOptions.cmake +++ b/cmake/CompilerOptions.cmake @@ -5,17 +5,28 @@ set(CXXOPT_PTHREAD "-pthread") set(CXXOPT_CXX11 "-std=c++11") set(CXXOPT_WALL "-Wall") -if(${CMAKE_CXX_COMPILER_ID} MATCHES "XL") +# CMake (at least sometimes) misidentifies XL 13 for Linux as Clang. +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + try_compile(ignore ${CMAKE_BINARY_DIR} ${PROJECT_SOURCE_DIR}/cmake/dummy.cpp COMPILE_DEFINITIONS --version OUTPUT_VARIABLE cc_out) + string(REPLACE "\n" ";" cc_out "${cc_out}") + foreach(line ${cc_out}) + if(line MATCHES "^IBM XL C") + set(CMAKE_CXX_COMPILER_ID "XL") + endif() + endforeach(line) +endif() + +if(CMAKE_CXX_COMPILER_ID MATCHES "XL") # Disable 'missing-braces' warning: this will inappropriately # flag initializations such as # std::array<int,3> a={1,2,3}; set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-missing-braces") # CMake, bless its soul, likes to insert this unsupported flag. Hilarity ensues. 
- string(REPLACE "-qhalt=e" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + string(REPLACE "-qhalt=e" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") endif() -if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") set(CXXOPT_KNL "-march=knl") set(CXXOPT_AVX2 "-mavx2 -mfma") set(CXXOPT_AVX512 "-mavx512f -mavx512cd") @@ -25,6 +36,11 @@ if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") # std::array<int,3> a={1,2,3}; set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-missing-braces") + # Disable 'potentially-evaluated-expression' warning: this warns + # on expressions of the form `typeid(expr)` when `expr` has side + # effects. + set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-potentially-evaluated-expression") + # Clang is erroneously warning that T is an 'unused type alias' in code like this: # struct X { # using T = decltype(expression); @@ -36,27 +52,95 @@ if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-format-security") endif() -if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") - # Compiler flags for generating KNL-specific AVX512 instructions - # supported in gcc 4.9.x and later. - set(CXXOPT_KNL "-march=knl") - set(CXXOPT_AVX2 "-mavx2 -mfma") - set(CXXOPT_AVX512 "-mavx512f -mavx512cd") - +if(CMAKE_CXX_COMPILER_ID MATCHES "GNU") # Disable 'maybe-uninitialized' warning: this will be raised # inappropriately in some uses of util::optional<T> when T # is a primitive type. set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-maybe-uninitialized") endif() -if(${CMAKE_CXX_COMPILER_ID} MATCHES "Intel") - # Compiler flags for generating KNL-specific AVX512 instructions. - set(CXXOPT_KNL "-xMIC-AVX512") - set(CXXOPT_AVX2 "-xCORE-AVX2") - set(CXXOPT_AVX512 "-xCORE-AVX512") - +if(CMAKE_CXX_COMPILER_ID MATCHES "Intel") # Disable warning for unused template parameter # this is raised by a templated function in the json library. set(CXXOPT_WALL "${CXXOPT_WALL} -wd488") endif() +# Set CXXOPT_ARCH in parent scope according to requested architecture. 
+# Architectures are given by the same names that GCC uses for its
+# -mcpu or -march options.
+#
+# Usage: set_arch_target(<arch>)
+#
+# Resulting value of CXXOPT_ARCH, per compiler family:
+#   * GNU, Clang: "-march=<arch>" when the compiler targets x86 (32- or 64-bit),
+#     "-mcpu=<arch>" otherwise.
+#   * Intel: <arch> is translated to an "-x<ISA>" option, accompanied by
+#     "-mtune=<name>" except for "native"; unrecognized names give
+#     "-xSSE2;-mtune=generic".
+#   * XL: "-qarch=auto" for "native", otherwise "-mcpu=<arch>".
+#   * Any other compiler: CXXOPT_ARCH keeps its current value and a warning is issued.
+
+function(set_arch_target arch)
+    if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU")
+        # Correct compiler option unfortunately depends upon the target architecture family.
+        # Extract this information from running the configured compiler with --verbose.
+
+        try_compile(ignore ${CMAKE_BINARY_DIR} ${PROJECT_SOURCE_DIR}/cmake/dummy.cpp COMPILE_DEFINITIONS --verbose OUTPUT_VARIABLE cc_out)
+        string(REPLACE "\n" ";" cc_out "${cc_out}")
+        set(target)
+        foreach(line ${cc_out})
+            if(line MATCHES "^Target:")
+                string(REGEX REPLACE "^Target: " "" target "${line}")
+            endif()
+        endforeach()
+
+        # Keep only the first component of the target triple, e.g. "x86_64".
+        string(REGEX REPLACE "-.*" "" target_model "${target}")
+
+        # No "Target:" line in the compiler output: fall back on the processor
+        # name CMake reports for the target system instead of guessing.
+        if("${target_model}" STREQUAL "")
+            string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" target_model)
+        endif()
+
+        # Use -mcpu for all supported targets _except_ for x86, where it should be -march.
+        # 32-bit x86 triples are spelled i386 ... i686 and contain neither "x86" nor
+        # "amd64", so they have to be matched explicitly.
+
+        if(target_model MATCHES "x86|amd64|i[3-6]86")
+            set(CXXOPT_ARCH "-march=${arch}")
+        else()
+            set(CXXOPT_ARCH "-mcpu=${arch}")
+        endif()
+
+    elseif(CMAKE_CXX_COMPILER_ID MATCHES "Intel")
+        # Translate target architecture names to Intel-compatible names.
+        # icc 17 recognizes the following specific microarchitecture names for -mtune:
+        #     broadwell, haswell, ivybridge, knl, sandybridge, skylake
+
+        if(arch MATCHES "sandybridge")
+            set(tune "${arch}")
+            set(arch "AVX")
+        elseif(arch MATCHES "ivybridge")
+            set(tune "${arch}")
+            set(arch "CORE-AVX-I")
+        elseif(arch MATCHES "broadwell|haswell|skylake")
+            set(tune "${arch}")
+            set(arch "CORE-AVX2")
+        elseif(arch MATCHES "knl")
+            set(tune "${arch}")
+            set(arch "MIC-AVX512")
+        elseif(arch MATCHES "nehalem|westmere")
+            set(tune "corei7")
+            set(arch "SSE4.2")
+        elseif(arch MATCHES "core2")
+            set(tune "core2")
+            set(arch "SSSE3")
+        elseif(arch MATCHES "native")
+            # "native" maps to -xHost and carries no separate -mtune option.
+            unset(tune)
+            set(arch "Host")
+        else()
+            set(tune "generic")
+            set(arch "SSE2") # default for icc
+        endif()
+
+        # CXXOPT_ARCH is a CMake list here: one element per compiler option.
+        if(tune)
+            set(CXXOPT_ARCH "-x${arch};-mtune=${tune}")
+        else()
+            set(CXXOPT_ARCH "-x${arch}")
+        endif()
+
+    elseif(CMAKE_CXX_COMPILER_ID MATCHES "XL")
+        # xlC 13 for Linux uses -mcpu. Not even attempting to get xlC 12 for BG/Q right
+        # at this point: use CXXFLAGS as required!
+        #
+        # xlC, gcc, and clang all recognize power8 and power9 as architecture keywords.
+
+        if(arch MATCHES "native")
+            set(CXXOPT_ARCH "-qarch=auto")
+        else()
+            set(CXXOPT_ARCH "-mcpu=${arch}")
+        endif()
+
+    else()
+        # Do not drop the request silently: no flag is derived for this compiler.
+        message(WARNING "set_arch_target: no architecture option known for compiler '${CMAKE_CXX_COMPILER_ID}'; ignoring requested architecture '${arch}'")
+    endif()
+
+    set(CXXOPT_ARCH "${CXXOPT_ARCH}" PARENT_SCOPE)
+endfunction()
diff --git a/cmake/dummy.cpp b/cmake/dummy.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6c20ecf844d719aa3d48c48373b6cd0bae37bdb9
--- /dev/null
+++ b/cmake/dummy.cpp
@@ -0,0 +1,3 @@
+// Used to extract system model family from gnu or clang.
+int main() { return 0; }
+
diff --git a/example/generators/event_gen.cpp b/example/generators/event_gen.cpp
index 145802bd194375c194efc6e4b0aa3a5a58c937d9..0eac37185e27468cf23acbc67e9c9e0a48de5349 100644
--- a/example/generators/event_gen.cpp
+++ b/example/generators/event_gen.cpp
@@ -54,8 +54,7 @@ public:
         // Add one synapse at the soma.
// This synapse will be the target for all events, from both // event_generators. - auto syn_spec = arb::mechanism_spec("expsyn"); - c.add_synapse({0, 0.5}, syn_spec); + c.add_synapse({0, 0.5}, "expsyn"); return std::move(c); } diff --git a/example/miniapp/miniapp.cpp b/example/miniapp/miniapp.cpp index 2ef4ef722edb5a335536eabe69ec5a97ec204ad1..1aecbf493e762345d693966d62802374e4429c38 100644 --- a/example/miniapp/miniapp.cpp +++ b/example/miniapp/miniapp.cpp @@ -11,7 +11,6 @@ #include <communication/communicator.hpp> #include <communication/global_policy.hpp> #include <cell.hpp> -#include <fvm_multicell.hpp> #include <hardware/gpu.hpp> #include <hardware/node_info.hpp> #include <io/exporter_spike_file.hpp> diff --git a/example/miniapp/miniapp_recipes.cpp b/example/miniapp/miniapp_recipes.cpp index 69690387a6b2cc3e10f530a2c32e114665b3a59f..443df8e4dff13d40a7051c70065c105fe469afd7 100644 --- a/example/miniapp/miniapp_recipes.cpp +++ b/example/miniapp/miniapp_recipes.cpp @@ -63,7 +63,7 @@ cell make_basic_cell( EXPECTS(!terminals.empty()); - arb::mechanism_spec syn_default(syn_type); + arb::mechanism_desc syn_default(syn_type); for (unsigned i=0; i<num_synapses; ++i) { unsigned id = terminals[i%terminals.size()]; cell.add_synapse({id, distribution(rng)}, syn_default); diff --git a/mechanisms/CMakeLists.txt b/mechanisms/CMakeLists.txt index a4a62da7d38391130cac5bb6dd7e1048218d0663..b08342c633e9d8e9c642c27a6f63f7505742c190 100644 --- a/mechanisms/CMakeLists.txt +++ b/mechanisms/CMakeLists.txt @@ -7,59 +7,69 @@ set(mod_srcdir "${CMAKE_CURRENT_SOURCE_DIR}/mod") # Generate mechanism implementations for host/cpu environment -set(mech_dir "${CMAKE_CURRENT_SOURCE_DIR}/multicore") +set(mech_dir "${CMAKE_CURRENT_BINARY_DIR}/generated") file(MAKE_DIRECTORY "${mech_dir}") -if(ARB_VECTORIZE_TARGET STREQUAL "none") - set(modcc_simd "") -elseif(ARB_VECTORIZE_TARGET STREQUAL "KNL") - set(modcc_simd "-s avx512") -elseif(ARB_VECTORIZE_TARGET STREQUAL "AVX512") - set(modcc_simd "-s 
avx512") -elseif(ARB_VECTORIZE_TARGET STREQUAL "AVX2") - set(modcc_simd "-s avx2") -else() - message(SEND_ERROR "Unrecognized architecture for ARB_VECTORIZE_TARGET") - set(modcc_simd "") +if(ARB_VECTORIZE) + set(modcc_simd "-s") endif() + build_modules( ${mechanisms} SOURCE_DIR "${mod_srcdir}" DEST_DIR "${mech_dir}" - MODCC_FLAGS -t cpu ${modcc_simd} - GENERATES _cpu.hpp + MODCC_FLAGS -t cpu -t gpu ${modcc_simd} + GENERATES .hpp _cpu.cpp _gpu.cpp _gpu.cu TARGET build_all_mods ) -# Generate mechanism implementations for gpu +# Generate source for default mechanism catalogue -set(mech_dir "${CMAKE_CURRENT_SOURCE_DIR}/gpu") -file(MAKE_DIRECTORY "${mech_dir}") -build_modules( - ${mechanisms} - SOURCE_DIR "${mod_srcdir}" - DEST_DIR "${mech_dir}" - MODCC_FLAGS -t gpu - GENERATES _gpu_impl.cu _gpu.hpp _gpu_impl.hpp - TARGET build_all_gpu_mods +set(catsrc ${CMAKE_CURRENT_BINARY_DIR}/default_catalogue.cpp) +set(default_catalogue_options -I ${mech_dir} -o ${catsrc} -B multicore) +if(ARB_WITH_CUDA) + list(APPEND default_catalogue_options -B gpu) +endif() + +add_custom_command( + OUTPUT ${catsrc} + COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/generate_default_catalogue ${default_catalogue_options} ${mechanisms} + DEPENDS build_all_mods generate_default_catalogue ) -# Make a library with the implementations of the mechanism kernels +# Make libraries with the implementations of the mechanism kernels. 
+ +foreach(mech ${mechanisms}) + list(APPEND cpu_mech_sources ${mech_dir}/${mech}_cpu.cpp) +endforeach() + +add_library(arbormech ${cpu_mech_sources} ${catsrc}) +target_compile_options(arbormech PRIVATE ${CXXOPT_ARCH}) +set(mech_libs arbormech) + +if (ARB_AUTO_RUN_MODCC_ON_CHANGES) + add_dependencies(arbormech build_all_mods) +endif() if(ARB_WITH_CUDA) - # make list of the .cu files that implement the mechanism kernels foreach(mech ${mechanisms}) - list(APPEND cuda_mech_sources ${mech_dir}/${mech}_gpu_impl.cu) + list(APPEND cuda_mech_sources "${mech_dir}/${mech}_gpu.cpp" "${mech_dir}/${mech}_gpu.cu") endforeach() - # compile the .cu files into a library cuda_add_library(arbormechcu ${cuda_mech_sources}) + set(mech_libs arbormech arbormechcu) # force recompilation on changes to modcc or the underlying .mod files if (ARB_AUTO_RUN_MODCC_ON_CHANGES) - add_dependencies(arbormechcu build_all_gpu_mods) + add_dependencies(arbormechcu build_all_mods) endif() - list(APPEND ARB_LIBRARIES arbormechcu) - set(ARB_LIBRARIES "${ARB_LIBRARIES}" PARENT_SCOPE) endif() + +# Until we merge our myriad static libraries, we prepend mech libs and also append mech libs +# to capture generated default catalogue code interdependencies with arbor lib. + +list(INSERT ARB_LIBRARIES 0 ${mech_libs}) +list(APPEND ARB_LIBRARIES ${mech_libs}) + +set(ARB_LIBRARIES "${ARB_LIBRARIES}" PARENT_SCOPE) diff --git a/mechanisms/generate_default_catalogue b/mechanisms/generate_default_catalogue new file mode 100755 index 0000000000000000000000000000000000000000..1f7cba510a3bd23eb939aae7e7541dfe1e1fc7cc --- /dev/null +++ b/mechanisms/generate_default_catalogue @@ -0,0 +1,130 @@ +#!/usr/bin/env python + +# Note: compatible with Python2.7 and Python3. 
+
+from __future__ import print_function
+import sys
+import string
+import argparse
+
+def parse_arguments():
+    """Parse the command line and return the option values as a dict.
+
+    The dict has the keys 'modules', 'modpfx', 'arbpfx', 'backends' and
+    'output'; 'output' is None when no -o/--output option is given.
+    """
+    def append_slash(s):
+        # Prefixes are pasted directly in front of header names when the
+        # #include lines are generated, so a non-empty prefix must end in '/'.
+        return s+'/' if s and not s.endswith('/') else s
+
+    class ConciseHelpFormatter(argparse.HelpFormatter):
+        """Help formatter that prints all option strings once, followed by a single metavar."""
+
+        def __init__(self, **kwargs):
+            super(ConciseHelpFormatter, self).__init__(max_help_position=20, **kwargs)
+
+        def _format_action_invocation(self, action):
+            if not action.option_strings:
+                return super(ConciseHelpFormatter, self)._format_action_invocation(action)
+            else:
+                optstr = ', '.join(action.option_strings)
+                if action.nargs==0:
+                    return optstr
+                else:
+                    return optstr+' '+self._format_args(action, action.dest.upper())
+
+    parser = argparse.ArgumentParser(
+        description = 'Generate global default catalogue source for Arbor build.',
+        usage = '%(prog)s [options] [module...]',
+        add_help = False,
+        formatter_class = ConciseHelpFormatter)
+
+    parser.add_argument(
+        'modules',
+        nargs = '*',
+        help = argparse.SUPPRESS)
+
+    group = parser.add_argument_group('Options')
+
+    group.add_argument(
+        '-I', '--module-prefix',
+        default = 'mechanisms',
+        metavar = 'PATH',
+        dest = 'modpfx',
+        type = append_slash,
+        help = 'directory prefix for module includes, default "%(default)s"')
+
+    group.add_argument(
+        '-A', '--arbor-prefix',
+        default = '',
+        metavar = 'PATH',
+        dest = 'arbpfx',
+        type = append_slash,
+        help = 'directory prefix for arbor includes, default "%(default)s"')
+
+    group.add_argument(
+        '-B', '--backend',
+        default = [],
+        action = 'append',
+        dest = 'backends',
+        metavar = 'BACKEND',
+        help = 'register implementations for back-end %(metavar)s')
+
+    # A single path, not a list: None means "print to stdout".
+    group.add_argument(
+        '-o', '--output',
+        default = None,
+        dest = 'output',
+        metavar = 'FILE',
+        help = 'save output to %(metavar)s (default is to print to stdout)')
+
+    group.add_argument(
+        '-h', '--help',
+        action = 'help',
+        help = 'display this help and exit')
+
+    return vars(parser.parse_args())
+
+
+def generate(modpfx='', arbpfx='', modules=(), backends=(), **rest):
+    """Return the C++ source text of the default mechanism catalogue.
+
+    modpfx    prefix placed in front of each '<module>.hpp' include.
+    arbpfx    prefix placed in front of the arbor includes.
+    modules   mechanism names; each is added to the catalogue.
+    backends  back-end names; one implementation is registered for every
+              (module, back-end) pair.
+    rest      ignored, so that the whole dict returned by parse_arguments()
+              can be passed as keyword arguments.
+    """
+    src = string.Template(\
+r'''// Automatically generated by:
+// $cmdline
+
+#include <${arbpfx}mechcat.hpp>
+$backend_includes
+$module_includes
+
+namespace arb {
+
+mechanism_catalogue build_default_catalogue() {
+    mechanism_catalogue cat;
+
+    $add_modules
+    $register_modules
+    return cat;
+}
+
+const mechanism_catalogue& global_default_catalogue() {
+    static mechanism_catalogue cat = build_default_catalogue();
+    return cat;
+}
+
+} // namespace arb''')
+
+    def indent(n, lines):
+        # Join lines with a newline followed by n spaces; the first line takes
+        # its indentation from the template itself.
+        return '{{:<{0!s}}}'.format(n+1).format('\n').join(lines)
+
+    return src.safe_substitute(dict(
+        cmdline = " ".join(sys.argv),
+        arbpfx = arbpfx,
+        backend_includes = indent(0,
+            ['#include <{}backends/{}/fvm.hpp>'.format(arbpfx, b) for b in backends]),
+        module_includes = indent(0,
+            ['#include <{}{}.hpp>'.format(modpfx, m) for m in modules]),
+        add_modules = indent(4,
+            ['cat.add("{0}", mechanism_{0}_info());'.format(m) for m in modules]),
+        register_modules = indent(4,
+            ['cat.register_implementation("{0}", make_mechanism_{0}<{1}::backend>());'.format(m, b)
+             for m in modules for b in backends])
+        ))
+
+
+args = parse_arguments()
+code = generate(**args)
+if args['output']:
+    # Close (and thereby flush) the output file explicitly instead of relying
+    # on the interpreter to do so when the anonymous file object is collected.
+    with open(args['output'], 'w') as out:
+        print(code, file = out)
+else:
+    print(code)
+
diff --git a/mechanisms/mod/test_ca.mod b/mechanisms/mod/test_ca.mod
index 67d09c028d62c35277aeacd75bf5f74d8c71e818..09630f4377bc7df67b6e30cc73425a7a47e52f1a 100644
--- a/mechanisms/mod/test_ca.mod
+++ b/mechanisms/mod/test_ca.mod
@@ -21,6 +21,10 @@ PARAMETER {
 
 ASSIGNED {}
 
+INITIAL {
+    cai = cai0
+}
+
 STATE {
     cai (mM)
 }
diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt
index 286c6729c01ad5dd985c034e792020362336a4b2..360c66cdc55df523c7b6db0bb3d92fd670aa6179 100644
--- a/modcc/CMakeLists.txt
+++ b/modcc/CMakeLists.txt
@@ -1,8 +1,5 @@
 set(MODCC_SOURCES
     astmanip.cpp
-    cexpr_emit.cpp
-    cprinter.cpp
-    cudaprinter.cpp
     errorvisitor.cpp
     expression.cpp
     functionexpander.cpp
@@ -14,10 +11,19 @@ set(MODCC_SOURCES
     solvers.cpp
     symdiff.cpp
symge.cpp - textbuffer.cpp token.cpp + + io/prefixbuf.cpp + + printer/cexpr_emit.cpp + printer/cprinter.cpp + printer/cudaprinter.cpp + printer/infoprinter.cpp + printer/printerutil.cpp ) +include_directories("${PROJECT_SOURCE_DIR}/modcc") + add_library(compiler ${MODCC_SOURCES}) add_executable(modcc modcc.cpp) diff --git a/modcc/blocks.hpp b/modcc/blocks.hpp index 9256b8a633e1124d778d8c7a622737254281019a..ad31932e3e16448595ad74884c0a83ccb52a9c01 100644 --- a/modcc/blocks.hpp +++ b/modcc/blocks.hpp @@ -18,33 +18,33 @@ struct IonDep { std::vector<Token> read; // name of channels parameters to write std::vector<Token> write; // name of channels parameters to read - bool has_variable(std::string const& name) { + bool has_variable(std::string const& name) const { return writes_variable(name) || reads_variable(name); }; - bool uses_current() { + bool uses_current() const { return has_variable("i"+name); }; - bool uses_rev_potential() { + bool uses_rev_potential() const { return has_variable("e"+name); }; - bool uses_concentration_int() { + bool uses_concentration_int() const { return has_variable(name+"i"); }; - bool uses_concentration_ext() { + bool uses_concentration_ext() const { return has_variable(name+"o"); }; - bool writes_concentration_int() { + bool writes_concentration_int() const { return writes_variable(name+"i"); }; - bool writes_concentration_ext() { + bool writes_concentration_ext() const { return writes_variable(name+"o"); }; - bool reads_variable(const std::string& name) { + bool reads_variable(const std::string& name) const { return std::find_if(read.begin(), read.end(), [&name](const Token& t) {return t.spelling==name;}) != read.end(); } - bool writes_variable(const std::string& name) { + bool writes_variable(const std::string& name) const { return std::find_if(write.begin(), write.end(), [&name](const Token& t) {return t.spelling==name;}) != write.end(); } diff --git a/modcc/cprinter.cpp b/modcc/cprinter.cpp deleted file mode 100644 index 
91ace7217e86c3ebb97e648c8aeee69015842236..0000000000000000000000000000000000000000 --- a/modcc/cprinter.cpp +++ /dev/null @@ -1,740 +0,0 @@ -#include <algorithm> -#include <string> -#include <unordered_set> - -#include "cexpr_emit.hpp" -#include "cprinter.hpp" -#include "lexer.hpp" - -/****************************************************************************** - CPrinter driver -******************************************************************************/ - -std::string CPrinter::emit_source() { - // make a list of vector types, both parameters and assigned - // and a list of all scalar types - std::vector<VariableExpression*> scalar_variables; - std::vector<VariableExpression*> array_variables; - - for(auto& sym: module_->symbols()) { - if(auto var = sym.second->is_variable()) { - if(var->is_range()) { - array_variables.push_back(var); - } - else { - scalar_variables.push_back(var); - } - } - } - - std::string module_name = module_->module_name(); - - ////////////////////////////////////////////// - ////////////////////////////////////////////// - emit_headers(); - - ////////////////////////////////////////////// - ////////////////////////////////////////////// - std::string class_name = "mechanism_" + module_name; - - text_.add_line("namespace arb { namespace multicore {"); - text_.add_line(); - text_.add_line("using math::exprelr;"); - text_.add_line("using math::min;"); - text_.add_line("using math::max;"); - text_.add_line(); - text_.add_line("template<class Backend>"); - text_.add_line("class " + class_name + " : public mechanism<Backend> {"); - text_.add_line("public:"); - text_.increase_indentation(); - text_.add_line("using base = mechanism<Backend>;"); - text_.add_line("using value_type = typename base::value_type;"); - text_.add_line("using size_type = typename base::size_type;"); - text_.add_line(); - text_.add_line("using array = typename base::array;"); - text_.add_line("using iarray = typename base::iarray;"); - text_.add_line("using view = 
typename base::view;"); - text_.add_line("using iview = typename base::iview;"); - text_.add_line("using const_view = typename base::const_view;"); - text_.add_line("using const_iview = typename base::const_iview;"); - text_.add_line("using ion_type = typename base::ion_type;"); - text_.add_line("using deliverable_event_stream_state = typename base::deliverable_event_stream_state;"); - text_.add_line(); - - ////////////////////////////////////////////// - ////////////////////////////////////////////// - for(auto& ion: module_->neuron_block().ions) { - auto tname = "Ion" + ion.name; - text_.add_line("struct " + tname + " {"); - text_.increase_indentation(); - for(auto& field : ion.read) { - text_.add_line("view " + field.spelling + ";"); - } - for(auto& field : ion.write) { - text_.add_line("view " + field.spelling + ";"); - } - text_.add_line("iarray index;"); - text_.add_line("std::size_t memory() const { return sizeof(size_type)*index.size(); }"); - text_.add_line("std::size_t size() const { return index.size(); }"); - text_.decrease_indentation(); - text_.add_line("};"); - text_.add_line(tname + " ion_" + ion.name + ";"); - } - - ////////////////////////////////////////////// - // constructor - ////////////////////////////////////////////// - int num_vars = array_variables.size(); - text_.add_line(); - text_.add_line(class_name + "(size_type mech_id, const_iview vec_ci, const_view vec_t, const_view vec_t_to, const_view vec_dt, view vec_v, view vec_i, array&& weights, iarray&& node_index)"); - text_.add_line(": base(mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, std::move(node_index))"); - text_.add_line("{"); - text_.increase_indentation(); - text_.add_gutter() << "size_type num_fields = " << num_vars << ";"; - text_.end_line(); - - text_.add_line(); - text_.add_line("// calculate the padding required to maintain proper alignment of sub arrays"); - text_.add_line("auto alignment = data_.alignment();"); - text_.add_line("auto field_size_in_bytes = 
sizeof(value_type)*size();"); - text_.add_line("auto remainder = field_size_in_bytes % alignment;"); - text_.add_line("auto padding = remainder ? (alignment - remainder)/sizeof(value_type) : 0;"); - text_.add_line("auto field_size = size()+padding;"); - - text_.add_line(); - text_.add_line("// allocate memory"); - text_.add_line("data_ = array(field_size*num_fields, std::numeric_limits<value_type>::quiet_NaN());"); - - // assign the sub-arrays - // replace this : data_(1*n, 2*n); - // with this : data_(1*field_size, 1*field_size+n); - - text_.add_line(); - text_.add_line("// asign the sub-arrays"); - for(int i=0; i<num_vars; ++i) { - char namestr[128]; - sprintf(namestr, "%-15s", array_variables[i]->name().c_str()); - text_.add_gutter() << namestr << " = data_(" - << i << "*field_size, " << i+1 << "*size());"; - text_.end_line(); - } - text_.add_line(); - - // copy in the weights - text_.add_line("// add the user-supplied weights for converting from current density"); - text_.add_line("// to per-compartment current in nA"); - text_.add_line("if (weights.size()) {"); - text_.increase_indentation(); - text_.add_line("memory::copy(weights, weights_(0, size()));"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line("else {"); - text_.increase_indentation(); - text_.add_line("memory::fill(weights_, 1.0);"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - - text_.add_line("// set initial values for variables and parameters"); - for(auto const& var : array_variables) { - double val = var->value(); - // only non-NaN fields need to be initialized, because data_ - // is NaN by default - std::string pointer_name = var->name()+".data()"; - if(val == val) { - text_.add_gutter() << "std::fill(" << pointer_name << ", " - << pointer_name << "+size(), " - << val << ");"; - text_.end_line(); - } - } - - text_.add_line(); - text_.decrease_indentation(); - text_.add_line("}"); - - ////////////////////////////////////////////// - 
////////////////////////////////////////////// - - text_.add_line(); - text_.add_line("using base::size;"); - text_.add_line(); - - text_.add_line("std::size_t memory() const override {"); - text_.increase_indentation(); - text_.add_line("auto s = std::size_t{0};"); - text_.add_line("s += data_.size()*sizeof(value_type);"); - for(auto& ion: module_->neuron_block().ions) { - text_.add_line("s += ion_" + ion.name + ".memory();"); - } - text_.add_line("return s;"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - - text_.add_line("std::string name() const override {"); - text_.increase_indentation(); - text_.add_line("return \"" + module_name + "\";"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - - std::string kind_str = module_->kind() == moduleKind::density - ? "mechanismKind::density" - : "mechanismKind::point"; - text_.add_line("mechanismKind kind() const override {"); - text_.increase_indentation(); - text_.add_line("return " + kind_str + ";"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - - // Implement `set_weights` method. 
- text_.add_line("void set_weights(array&& weights) override {"); - text_.increase_indentation(); - text_.add_line("memory::copy(weights, weights_(0, size()));"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - - /*************************************************************************** - * - * ion channels have the following fields : - * - * --------------------------------------------------- - * label Ca Na K name - * --------------------------------------------------- - * iX ica ina ik current - * eX eca ena ek reversal_potential - * Xi cai nai ki internal_concentration - * Xo cao nao ko external_concentration - * gX gca gna gk conductance - * --------------------------------------------------- - * - **************************************************************************/ - - // ion_spec uses_ion(ionKind k) const override - text_.add_line("typename base::ion_spec uses_ion(ionKind k) const override {"); - text_.increase_indentation(); - text_.add_line("bool uses = false;"); - text_.add_line("bool writes_ext = false;"); - text_.add_line("bool writes_int = false;"); - for (auto k: {ionKind::Na, ionKind::Ca, ionKind::K}) { - if (module_->has_ion(k)) { - auto ion = *module_->find_ion(k); - text_.add_line("if (k==ionKind::" + ion.name + ") {"); - text_.increase_indentation(); - text_.add_line("uses = true;"); - if (ion.writes_concentration_int()) text_.add_line("writes_int = true;"); - if (ion.writes_concentration_ext()) text_.add_line("writes_ext = true;"); - text_.decrease_indentation(); - text_.add_line("}"); - } - } - text_.add_line("return {uses, writes_int, writes_ext};"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - - // void set_ion(ionKind k, ion_type& i) override - text_.add_line("void set_ion(ionKind k, ion_type& i, std::vector<size_type>const& index) override {"); - text_.increase_indentation(); - for (auto k: {ionKind::Na, ionKind::Ca, ionKind::K}) { - if (module_->has_ion(k)) { - auto ion 
= *module_->find_ion(k); - text_.add_line("if (k==ionKind::" + ion.name + ") {"); - text_.increase_indentation(); - auto n = ion.name; - auto pre = "ion_"+n; - text_.add_line(pre+".index = memory::make_const_view(index);"); - if (ion.uses_current()) - text_.add_line(pre+".i"+n+" = i.current();"); - if (ion.uses_rev_potential()) - text_.add_line(pre+".e"+n+" = i.reversal_potential();"); - if (ion.uses_concentration_int()) - text_.add_line(pre+"."+n+"i = i.internal_concentration();"); - if (ion.uses_concentration_ext()) - text_.add_line(pre+"."+n+"o = i.external_concentration();"); - text_.add_line("return;"); - text_.decrease_indentation(); - text_.add_line("}"); - } - } - text_.add_line("throw std::domain_error(arb::util::pprintf(\"mechanism % does not support ion type\\n\", name()));"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - - ////////////////////////////////////////////// - - auto proctest = [] (procedureKind k) { - return is_in(k, {procedureKind::normal, procedureKind::api, procedureKind::net_receive}); - }; - bool override_deliver_events = false; - for(auto const& var: module_->symbols()) { - auto isproc = var.second->kind()==symbolKind::procedure; - if(isproc) { - auto proc = var.second->is_procedure(); - if(proctest(proc->kind())) { - proc->accept(this); - } - override_deliver_events |= proc->kind()==procedureKind::net_receive; - } - } - - if(override_deliver_events) { - text_.add_line("void deliver_events(const deliverable_event_stream_state& events) override {"); - text_.increase_indentation(); - text_.add_line("auto ncell = events.n_streams();"); - text_.add_line("for (size_type c = 0; c<ncell; ++c) {"); - text_.increase_indentation(); - - text_.add_line("auto begin = events.begin_marked(c);"); - text_.add_line("auto end = events.end_marked(c);"); - text_.add_line("for (auto p = begin; p<end; ++p) {"); - text_.increase_indentation(); - text_.add_line("if (p->mech_id==mech_id_) net_receive(p->mech_index, p->weight);"); 
- text_.decrease_indentation(); - text_.add_line("}"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - } - - if(module_->write_backs().size()) { - text_.add_line("void write_back() override {"); - text_.increase_indentation(); - - text_.add_line("const size_type n_ = node_index_.size();"); - for (auto& w: module_->write_backs()) { - auto& src = w.source_name; - auto tgt = w.target_name; - tgt.erase(tgt.begin(), tgt.begin()+tgt.find('_')+1); - auto istore = ion_store(w.ion_kind)+"."; - - text_.add_line(); - text_.add_line("auto "+src+"_out_ = util::indirect_view("+istore+tgt+", "+istore+"index);"); - text_.add_line("for (size_type i_ = 0; i_ < n_; ++i_) {"); - text_.increase_indentation(); - text_.add_line("// 1/10 magic number due to unit normalisation"); - text_.add_line(src+"_out_[i_] += value_type(0.1)*weights_[i_]*"+src+"[i_];"); - text_.decrease_indentation(); text_.add_line("}"); - } - text_.decrease_indentation(); text_.add_line("}"); - } - text_.add_line(); - - // TODO: replace field_info() generation implemenation with separate schema info generation - // as per #349. - auto field_info_string = [](const std::string& kind, const Id& id) { - return "field_spec{field_spec::" + kind + ", " + - "\"" + id.unit_string() + "\", " + - (id.has_value()? id.value: "0") + - (id.has_range()? ", " + id.range.first.spelling + "," + id.range.second.spelling: "") + - "}"; - }; - - std::unordered_set<std::string> scalar_set; - for (auto& v: scalar_variables) { - scalar_set.insert(v->name()); - } - - std::vector<Id> global_param_ids; - std::vector<Id> instance_param_ids; - - for (const Id& id: module_->parameter_block().parameters) { - auto var = id.token.spelling; - (scalar_set.count(var)? 
global_param_ids: instance_param_ids).push_back(id); - } - const std::vector<Id>& state_ids = module_->state_block().state_variables; - - text_.add_line("util::optional<field_spec> field_info(const char* id) const /* override */ {"); - text_.increase_indentation(); - text_.add_line("static const std::pair<const char*, field_spec> field_tbl[] = {"); - text_.increase_indentation(); - for (const auto& id: global_param_ids) { - auto var = id.token.spelling; - text_.add_line("{\""+var+"\", "+field_info_string("global",id )+"},"); - } - for (const auto& id: instance_param_ids) { - auto var = id.token.spelling; - text_.add_line("{\""+var+"\", "+field_info_string("parameter", id)+"},"); - } - for (const auto& id: state_ids) { - auto var = id.token.spelling; - text_.add_line("{\""+var+"\", "+field_info_string("state", id)+"},"); - } - text_.decrease_indentation(); - text_.add_line("};"); - text_.add_line(); - text_.add_line("auto* info = util::table_lookup(field_tbl, id);"); - text_.add_line("return info? util::just(*info): util::nullopt;"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - - if (!instance_param_ids.empty() || !state_ids.empty()) { - text_.add_line("view base::* field_view_ptr(const char* id) const override {"); - text_.increase_indentation(); - text_.add_line("static const std::pair<const char*, view "+class_name+"::*> field_tbl[] = {"); - text_.increase_indentation(); - for (const auto& id: instance_param_ids) { - auto var = id.token.spelling; - text_.add_line("{\""+var+"\", &"+class_name+"::"+var+"},"); - } - for (const auto& id: state_ids) { - auto var = id.token.spelling; - text_.add_line("{\""+var+"\", &"+class_name+"::"+var+"},"); - } - text_.decrease_indentation(); - text_.add_line("};"); - text_.add_line(); - text_.add_line("auto* pptr = util::table_lookup(field_tbl, id);"); - text_.add_line("return pptr? 
static_cast<view base::*>(*pptr): nullptr;"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - } - - if (!global_param_ids.empty()) { - text_.add_line("value_type base::* field_value_ptr(const char* id) const override {"); - text_.increase_indentation(); - text_.add_line("static const std::pair<const char*, value_type "+class_name+"::*> field_tbl[] = {"); - text_.increase_indentation(); - for (const auto& id: global_param_ids) { - auto var = id.token.spelling; - text_.add_line("{\""+var+"\", &"+class_name+"::"+var+"},"); - } - text_.decrease_indentation(); - text_.add_line("};"); - text_.add_line(); - text_.add_line("auto* pptr = util::table_lookup(field_tbl, id);"); - text_.add_line("return pptr? static_cast<value_type base::*>(*pptr): nullptr;"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - } - - ////////////////////////////////////////////// - ////////////////////////////////////////////// - - text_.add_line("array data_;"); - for(auto var: array_variables) { - text_.add_line("view " + var->name() + ";"); - } - - for(auto var: scalar_variables) { - double val = var->value(); - // test the default value for NaN - // useful for error propogation from bad initial conditions - if(val==val) { - text_.add_gutter() << "value_type " << var->name() << " = " << val << ";"; - text_.end_line(); - } - else { - text_.add_line("value_type " + var->name() + " = 0;"); - } - } - - text_.add_line(); - text_.add_line("using base::mech_id_;"); - text_.add_line("using base::vec_ci_;"); - text_.add_line("using base::vec_t_;"); - text_.add_line("using base::vec_t_to_;"); - text_.add_line("using base::vec_dt_;"); - text_.add_line("using base::vec_v_;"); - text_.add_line("using base::vec_i_;"); - text_.add_line("using base::node_index_;"); - - text_.add_line(); - text_.decrease_indentation(); - text_.add_line("};"); - text_.add_line(); - - text_.add_line("}} // namespaces"); - return text_.str(); -} - - - -void 
CPrinter::emit_headers() { - text_.add_line("#pragma once"); - text_.add_line(); - text_.add_line("#include <cmath>"); - text_.add_line("#include <limits>"); - text_.add_line(); - text_.add_line("#include <math.hpp>"); - text_.add_line("#include <mechanism.hpp>"); - text_.add_line("#include <algorithms.hpp>"); - text_.add_line("#include <backends/event.hpp>"); - text_.add_line("#include <backends/multi_event_stream_state.hpp>"); - text_.add_line("#include <util/pprintf.hpp>"); - text_.add_line("#include <util/simple_table.hpp>"); - text_.add_line(); -} - -/****************************************************************************** - CPrinter -******************************************************************************/ - -void CPrinter::visit(Expression *e) { - throw compiler_exception( - "CPrinter doesn't know how to print " + e->to_string(), - e->location()); -} - -void CPrinter::visit(LocalDeclaration *e) { -} - -void CPrinter::visit(Symbol *e) { - throw compiler_exception("I don't know how to print raw Symbol " + e->to_string(), - e->location()); -} - -void CPrinter::visit(LocalVariable *e) { - std::string const& name = e->name(); - text_ << name; - if(is_ghost_local(e)) { - text_ << "[j_]"; - } -} - -void CPrinter::visit(NumberExpression *e) { - cexpr_emit(e, text_.text(), this); -} - -void CPrinter::visit(IdentifierExpression *e) { - e->symbol()->accept(this); -} - -void CPrinter::visit(VariableExpression *e) { - text_ << e->name(); - if(e->is_range()) { - text_ << "[i_]"; - } -} - -void CPrinter::visit(IndexedVariable *e) { - text_ << e->index_name() << "[i_]"; -} - -void CPrinter::visit(CellIndexedVariable *e) { - text_ << e->index_name() << "[i_]"; -} - -void CPrinter::visit(UnaryExpression *e) { - cexpr_emit(e, text_.text(), this); -} - -void CPrinter::visit(BlockExpression *e) { - // ------------- declare local variables ------------- // - // only if this is the outer block - if(!e->is_nested()) { - std::vector<std::string> names; - for(auto& symbol 
: e->scope()->locals()) { - auto sym = symbol.second.get(); - // input variables are declared earlier, before the - // block body is printed - if(is_stack_local(sym) && !is_input(sym)) { - names.push_back(sym->name()); - } - } - if(names.size()>0) { - text_.add_gutter() << "value_type " << *(names.begin()); - for(auto it=names.begin()+1; it!=names.end(); ++it) { - text_ << ", " << *it; - } - text_.end_line(";"); - } - } - - // ------------- statements ------------- // - for(auto& stmt : e->statements()) { - if(stmt->is_local_declaration()) continue; - - // these all must be handled - text_.add_gutter(); - stmt->accept(this); - if (not stmt->is_if()) { - text_.end_line(";"); - } - } -} - -void CPrinter::visit(IfExpression *e) { - // for now we remove the brackets around the condition because - // the binary expression printer adds them, and we want to work - // around the -Wparentheses-equality warning - text_ << "if("; - e->condition()->accept(this); - text_ << ") {\n"; - increase_indentation(); - e->true_branch()->accept(this); - decrease_indentation(); - text_.add_line("}"); - // check if there is a false-branch, i.e. 
if - // there is an "else" branch to print - if (auto fb = e->false_branch()) { - text_.add_gutter() << "else "; - // use recursion for "else if" - if (fb->is_if()) { - fb->accept(this); - } - // otherwise print the "else" block - else { - text_ << "{\n"; - increase_indentation(); - fb->accept(this); - decrease_indentation(); - text_.add_line("}"); - } - } -} - -// NOTE: net_receive() is classified as a ProcedureExpression -void CPrinter::visit(ProcedureExpression *e) { - // print prototype - text_.add_gutter() << "void " << e->name() << "(int i_"; - for(auto& arg : e->args()) { - text_ << ", value_type " << arg->is_argument()->name(); - } - if(e->kind() == procedureKind::net_receive) { - text_.end_line(") override {"); - } - else { - text_.end_line(") {"); - } - - if(!e->scope()) { // error: semantic analysis has not been performed - throw compiler_exception( - "CPrinter attempt to print Procedure " + e->name() - + " for which semantic analysis has not been performed", - e->location()); - } - - // print body - increase_indentation(); - e->body()->accept(this); - - // close the function body - decrease_indentation(); - text_.add_line("}"); - text_.add_line(); -} - -void CPrinter::visit(APIMethod *e) { - // print prototype - text_.add_gutter() << "void " << e->name() << "() override {"; - text_.end_line(); - - if(!e->scope()) { // error: semantic analysis has not been performed - throw compiler_exception( - "CPrinter attempt to print APIMethod " + e->name() - + " for which semantic analysis has not been performed", - e->location()); - } - - // only print the body if it has contents - if(e->is_api_method()->body()->statements().size()) { - increase_indentation(); - - // create local indexed views - for(auto &symbol : e->scope()->locals()) { - auto var = symbol.second->is_local_variable(); - if (!var->is_indexed()) continue; - - auto external = var->external_variable(); - auto const& name = var->name(); - auto const& index_name = external->index_name(); - - 
text_.add_gutter(); - text_ << "auto " + index_name + " = "; - - if(external->is_cell_indexed_variable()) { - text_ << "util::indirect_view(util::indirect_view(" + index_name + "_, vec_ci_), node_index_);\n"; - } - else if(external->is_ion()) { - auto channel = external->ion_channel(); - auto iname = ion_store(channel); - text_ << "util::indirect_view(" << iname << "." << name << ", " << ion_store(channel) << ".index);\n"; - } - else { - text_ << "util::indirect_view(" + index_name + "_, node_index_);\n"; - } - } - - // get loop dimensions - text_.add_line("int n_ = node_index_.size();"); - - print_APIMethod(e); - } - - // close up the loop body - text_.add_line("}"); - text_.add_line(); -} - -void CPrinter::emit_api_loop(APIMethod* e, - const std::string& start, - const std::string& end, - const std::string& inc) { - text_.add_gutter(); - text_ << "for (" << start << "; " << end << "; " << inc << ") {"; - text_.end_line(); - text_.increase_indentation(); - - // loads from external indexed arrays - for(auto &symbol : e->scope()->locals()) { - auto var = symbol.second->is_local_variable(); - if(is_input(var)) { - auto ext = var->external_variable(); - text_.add_gutter() << "value_type "; - var->accept(this); - text_ << " = "; - ext->accept(this); - text_.end_line(";"); - } - } - - // print the body of the loop - e->body()->accept(this); - - // perform update of external variables (currents etc) - for(auto &symbol : e->scope()->locals()) { - auto var = symbol.second->is_local_variable(); - if(is_output(var)) { - auto ext = var->external_variable(); - text_.add_gutter(); - ext->accept(this); - text_ << (ext->op() == tok::plus ? 
" += " : " -= "); - var->accept(this); - text_.end_line(";"); - } - } - - text_.decrease_indentation(); - text_.add_line("}"); -} - -void CPrinter::print_APIMethod(APIMethod* e) { - emit_api_loop(e, "int i_ = 0", "i_ < n_", "++i_"); - decrease_indentation(); - - return; -} - -void CPrinter::visit(CallExpression *e) { - text_ << e->name() << "(i_"; - for(auto& arg: e->args()) { - text_ << ", "; - arg->accept(this); - } - text_ << ")"; -} - -void CPrinter::visit(BinaryExpression *e) { - cexpr_emit(e, text_.text(), this); -} - diff --git a/modcc/cprinter.hpp b/modcc/cprinter.hpp deleted file mode 100644 index 9c6363225cb5a57788014f8e29301bf7bb600c86..0000000000000000000000000000000000000000 --- a/modcc/cprinter.hpp +++ /dev/null @@ -1,122 +0,0 @@ -#pragma once - -#include <sstream> - -#include "module.hpp" -#include "textbuffer.hpp" -#include "visitor.hpp" - -class CPrinter: public Visitor { -public: - CPrinter() {} - explicit CPrinter(Module &m): module_(&m) {} - - void visit(Expression *e) override; - void visit(UnaryExpression *e) override; - void visit(BinaryExpression *e) override; - void visit(NumberExpression *e) override; - void visit(VariableExpression *e) override; - void visit(Symbol *e) override; - void visit(LocalVariable *e) override; - void visit(IndexedVariable *e) override; - void visit(CellIndexedVariable *e) override; - void visit(IdentifierExpression *e) override; - void visit(CallExpression *e) override; - void visit(ProcedureExpression *e) override; - void visit(APIMethod *e) override; - void visit(LocalDeclaration *e) override; - void visit(BlockExpression *e) override; - void visit(IfExpression *e) override; - - std::string text() const { - return text_.str(); - } - - void set_gutter(int w) { - text_.set_gutter(w); - } - void increase_indentation(){ - text_.increase_indentation(); - } - void decrease_indentation(){ - text_.decrease_indentation(); - } - void clear_text() { - text_.clear(); - } - - virtual ~CPrinter() { } - - virtual std::string 
emit_source(); - virtual void emit_headers(); - virtual void emit_api_loop(APIMethod* e, - const std::string& start, - const std::string& end, - const std::string& inc); - -protected: - void print_mechanism(Visitor *backend); - void print_APIMethod(APIMethod* e); - - Module *module_ = nullptr; - TextBuffer text_; - bool aliased_output_ = false; - - bool is_input(Symbol *s) { - if(auto l = s->is_local_variable() ) { - if(l->is_local()) { - if(l->is_indexed() && l->is_read()) { - return true; - } - } - } - return false; - } - - bool is_output(Symbol *s) { - if(auto l = s->is_local_variable() ) { - if(l->is_local()) { - if(l->is_indexed() && l->is_write()) { - return true; - } - } - } - return false; - } - - bool is_arg_local(Symbol *s) { - if(auto l=s->is_local_variable()) { - if(l->is_arg()) { - return true; - } - } - return false; - } - - bool is_indexed_local(Symbol *s) { - if(auto l=s->is_local_variable()) { - if(l->is_indexed()) { - return true; - } - } - return false; - } - - bool is_ghost_local(Symbol *s) { - if(!is_point_process()) return false; - if(!aliased_output_) return false; - if(is_arg_local(s)) return false; - return is_output(s); - } - - bool is_stack_local(Symbol *s) { - if(is_arg_local(s)) return false; - return !is_ghost_local(s); - } - - bool is_point_process() { - return module_ && module_->kind() == moduleKind::point; - } - - std::vector<LocalVariable*> aliased_vars(APIMethod* e); -}; diff --git a/modcc/cudaprinter.cpp b/modcc/cudaprinter.cpp deleted file mode 100644 index ba3ff2cd6be572e26505a84659336bba9c776ad5..0000000000000000000000000000000000000000 --- a/modcc/cudaprinter.cpp +++ /dev/null @@ -1,1012 +0,0 @@ -#include <algorithm> -#include <string> -#include <unordered_set> - -#include "cexpr_emit.hpp" -#include "cudaprinter.hpp" -#include "lexer.hpp" - -std::string CUDAPrinter::pack_name() { - return module_name_ + "_ParamPack"; -} - -CUDAPrinter::CUDAPrinter(Module &m, bool o) - : module_(&m) -{ - // make a list of vector types, both 
parameters and assigned - // and a list of all scalar types - std::vector<VariableExpression*> scalar_variables; - std::vector<VariableExpression*> array_variables; - for(auto& sym: m.symbols()) { - if(sym.second->kind()==symbolKind::variable) { - auto var = sym.second->is_variable(); - if(var->is_range()) { - array_variables.push_back(var); - } - else { - scalar_variables.push_back(var) ; - } - } - } - - module_name_ = module_->module_name(); - - // - // Implementation header. - // - // Contains the parameter pack and protypes of c wrappers around cuda kernels. - // - - set_buffer(impl_interface_); - - // headers - buffer().add_line("#pragma once"); - buffer().add_line("#include <backends/event.hpp>"); - buffer().add_line("#include <backends/fvm_types.hpp>"); - buffer().add_line("#include <backends/multi_event_stream_state.hpp>"); - buffer().add_line("#include <backends/gpu/kernels/detail.hpp>"); - buffer().add_line("#include <util/simple_table.hpp>"); - buffer().add_line(); - - buffer().add_line("namespace arb { namespace gpu{"); - buffer().add_line("using deliverable_event_stream_state = multi_event_stream_state<deliverable_event_data>;"); - buffer().add_line(); - - // definition of parameter pack type - std::vector<std::string> param_pack; - buffer().add_gutter() << "struct " << pack_name() << " {"; - buffer().end_line(); - buffer().increase_indentation(); - buffer().add_line("using T = arb::fvm_value_type;"); - buffer().add_line("using I = arb::fvm_size_type;"); - buffer().add_line("// array parameters"); - for(auto const &var: array_variables) { - buffer().add_line("T* " + var->name() + ";"); - param_pack.push_back(var->name() + ".data()"); - } - buffer().add_line("// scalar parameters"); - for(auto const &var: scalar_variables) { - buffer().add_line("T " + var->name() + ";"); - param_pack.push_back(var->name()); - } - buffer().add_line("// ion channel dependencies"); - for(auto& ion: m.neuron_block().ions) { - auto tname = "ion_" + ion.name; - for(auto& 
field : ion.read) { - buffer().add_line("T* ion_" + field.spelling + ";"); - param_pack.push_back(tname + "." + field.spelling + ".data()"); - } - for(auto& field : ion.write) { - buffer().add_line("T* ion_" + field.spelling + ";"); - param_pack.push_back(tname + "." + field.spelling + ".data()"); - } - buffer().add_line("I* ion_" + ion.name + "_idx_;"); - param_pack.push_back(tname + ".index.data()"); - } - - buffer().add_line("// cv index to cell mapping and cell time states"); - buffer().add_line("const I* ci;"); - buffer().add_line("const T* vec_t;"); - buffer().add_line("const T* vec_t_to;"); - buffer().add_line("const T* vec_dt;"); - param_pack.push_back("vec_ci_.data()"); - param_pack.push_back("vec_t_.data()"); - param_pack.push_back("vec_t_to_.data()"); - param_pack.push_back("vec_dt_.data()"); - - buffer().add_line("// voltage and current state within the cell"); - buffer().add_line("T* vec_v;"); - buffer().add_line("T* vec_i;"); - param_pack.push_back("vec_v_.data()"); - param_pack.push_back("vec_i_.data()"); - - buffer().add_line("// node index information"); - buffer().add_line("I* ni;"); - buffer().add_line("unsigned long n_;"); - buffer().decrease_indentation(); - buffer().add_line("};"); - buffer().add_line(); - param_pack.push_back("node_index_.data()"); - param_pack.push_back("node_index_.size()"); - - // kernel wrapper prototypes - for(auto const &var: m.symbols()) { - if (auto e = var.second->is_api_method()) { - buffer().add_line(APIMethod_prototype(e) + ";"); - } - else if (var.second->is_net_receive()) { - buffer().add_line( - "void deliver_events_" + module_name_ +"(" + pack_name() + " params_, arb::fvm_size_type mech_id, deliverable_event_stream_state state);"); - } - } - if(module_->write_backs().size()) { - buffer().add_line("void write_back_"+module_name_+"("+pack_name()+" params_);"); - } - buffer().add_line(); - buffer().add_line("}} // namespace arb::gpu"); - - // - // Implementation - // - - set_buffer(impl_); - - // kernels - 
buffer().add_line("#include \"" + module_name_ + "_gpu_impl.hpp\""); - buffer().add_line(); - buffer().add_line("#include <backends/gpu/intrinsics.hpp>"); - buffer().add_line("#include <backends/gpu/kernels/reduce_by_key.hpp>"); - buffer().add_line(); - buffer().add_line("namespace arb { namespace gpu{"); - buffer().add_line("namespace kernels {"); - buffer().increase_indentation(); - { - // forward declarations of procedures - for(auto const &var: m.symbols()) { - if( var.second->kind()==symbolKind::procedure && - var.second->is_procedure()->kind() == procedureKind::normal) - { - print_device_function_prototype(var.second->is_procedure()); - buffer().end_line(";"); - buffer().add_line(); - } - } - - // print stubs that call API method kernels that are defined in the - // kernels::name namespace - for(auto const &var: m.symbols()) { - if (var.second->kind()==symbolKind::procedure && - is_in(var.second->is_procedure()->kind(), - {procedureKind::normal, procedureKind::api, procedureKind::net_receive})) - { - auto e = var.second->is_procedure(); - e->accept(this); - } - } - } - - // print the write_back kernel - if(module_->write_backs().size()) { - buffer().add_line("__global__"); - buffer().add_line("void write_back_"+module_name_+"("+pack_name()+" params_) {"); - buffer().increase_indentation(); - buffer().add_line("using value_type = arb::fvm_value_type;"); - - buffer().add_line("auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;"); - buffer().add_line("auto const n_ = params_.n_;"); - buffer().add_line("if(tid_<n_) {"); - buffer().increase_indentation(); - - for (auto& w: module_->write_backs()) { - auto& src = w.source_name; - auto& tgt = w.target_name; - - auto idx = src + "_idx_"; - buffer().add_line("auto "+idx+" = params_.ion_"+to_string(w.ion_kind)+"_idx_[tid_];"); - buffer().add_line("// 1/10 magic number due to unit normalisation"); - buffer().add_line("params_."+tgt+"["+idx+"] = value_type(0.1)*params_.weights_[tid_]*params_."+src+"[tid_];"); - } - 
buffer().decrease_indentation(); buffer().add_line("}"); - buffer().decrease_indentation(); buffer().add_line("}"); - } - - buffer().decrease_indentation(); - buffer().add_line("} // kernel namespace"); - - // implementation of the kernel wrappers - buffer().add_line(); - for(auto const &var : m.symbols()) { - if (auto e = var.second->is_api_method()) { - buffer().add_line(APIMethod_prototype(e) + " {"); - buffer().increase_indentation(); - buffer().add_line("auto n = params_.n_;"); - buffer().add_line("constexpr int blockwidth = 128;"); - buffer().add_line("dim3 dim_block(blockwidth);"); - buffer().add_line("dim3 dim_grid(impl::block_count(n, blockwidth));"); - buffer().add_line("arb::gpu::kernels::"+e->name()+"_"+module_name_+"<<<dim_grid, dim_block>>>(params_);"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - } - else if (var.second->is_net_receive()) { - buffer().add_line("void deliver_events_" + module_name_ - + "(" + pack_name() + " params_, arb::fvm_size_type mech_id, deliverable_event_stream_state state) {"); - buffer().increase_indentation(); - buffer().add_line("const int n = state.n;"); - buffer().add_line("constexpr int blockwidth = 128;"); - buffer().add_line("const auto nblock = impl::block_count(n, blockwidth);"); - buffer().add_line("arb::gpu::kernels::deliver_events<<<nblock, blockwidth>>>(params_, mech_id, state);"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - } - } - - // add the write_back kernel wrapper if required by this module - if(module_->write_backs().size()) { - buffer().add_line("void write_back_"+module_name_+"("+pack_name()+" params_) {"); - buffer().increase_indentation(); - buffer().add_line("auto n = params_.n_;"); - buffer().add_line("constexpr int blockwidth = 128;"); - buffer().add_line("dim3 dim_block(blockwidth);"); - buffer().add_line("dim3 dim_grid(impl::block_count(n, blockwidth));"); - 
buffer().add_line("arb::gpu::kernels::write_back_"+module_name_+"<<<dim_grid, dim_block>>>(params_);"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - } - buffer().add_line("}} // namespace arb::gpu"); - - // - // Interface header - // - // Included in the front-end C++ code. - // - - set_buffer(interface_); - - buffer().add_line("#pragma once"); - buffer().add_line(); - buffer().add_line("#include <cmath>"); - buffer().add_line("#include <limits>"); - buffer().add_line(); - buffer().add_line("#include <mechanism.hpp>"); - buffer().add_line("#include <algorithms.hpp>"); - buffer().add_line("#include <backends/event.hpp>"); - buffer().add_line("#include <backends/fvm_types.hpp>"); - buffer().add_line("#include <backends/gpu/multi_event_stream.hpp>"); - buffer().add_line("#include <util/pprintf.hpp>"); - buffer().add_line(); - buffer().add_line("#include \"" + module_name_ + "_gpu_impl.hpp\""); - buffer().add_line(); - - buffer().add_line("namespace arb { namespace gpu{"); - buffer().add_line(); - - ////////////////////////////////////////////// - ////////////////////////////////////////////// - std::string class_name = "mechanism_" + module_name_; - - buffer().add_line("template <typename Backend>"); - buffer().add_line("class " + class_name + " : public mechanism<Backend> {"); - buffer().add_line("public:"); - buffer().increase_indentation(); - buffer().add_line("using base = mechanism<Backend>;"); - buffer().add_line("using typename base::value_type;"); - buffer().add_line("using typename base::size_type;"); - buffer().add_line("using typename base::array;"); - buffer().add_line("using typename base::view;"); - buffer().add_line("using typename base::iarray;"); - buffer().add_line("using host_iarray = typename Backend::host_iarray;"); - buffer().add_line("using typename base::iview;"); - buffer().add_line("using typename base::const_iview;"); - buffer().add_line("using typename base::const_view;"); - buffer().add_line("using 
typename base::ion_type;"); - buffer().add_line("using deliverable_event_stream_state = typename base::deliverable_event_stream_state;"); - buffer().add_line("using param_pack_type = " + pack_name() + ";"); - - ////////////////////////////////////////////// - ////////////////////////////////////////////// - for(auto& ion: m.neuron_block().ions) { - auto tname = "Ion" + ion.name; - buffer().add_line("struct " + tname + " {"); - buffer().increase_indentation(); - for(auto& field : ion.read) { - buffer().add_line("view " + field.spelling + ";"); - } - for(auto& field : ion.write) { - buffer().add_line("view " + field.spelling + ";"); - } - buffer().add_line("iarray index;"); - buffer().add_line("std::size_t memory() const { return sizeof(size_type)*index.size(); }"); - buffer().add_line("std::size_t size() const { return index.size(); }"); - buffer().decrease_indentation(); - buffer().add_line("};"); - buffer().add_line(tname + " ion_" + ion.name + ";"); - buffer().add_line(); - } - - ////////////////////////////////////////////// - // constructor - ////////////////////////////////////////////// - - int num_vars = array_variables.size(); - buffer().add_line(); - buffer().add_line(class_name + "(size_type mech_id, const_iview vec_ci, const_view vec_t, const_view vec_t_to, const_view vec_dt, view vec_v, view vec_i, array&& weights, iarray&& node_index):"); - buffer().add_line(" base(mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, std::move(node_index))"); - buffer().add_line("{"); - buffer().increase_indentation(); - buffer().add_gutter() << "size_type num_fields = " << num_vars << ";"; - buffer().end_line(); - - buffer().add_line(); - buffer().add_line("// calculate the padding required to maintain proper alignment of sub arrays"); - buffer().add_line("auto alignment = data_.alignment();"); - buffer().add_line("auto field_size_in_bytes = sizeof(value_type)*size();"); - buffer().add_line("auto remainder = field_size_in_bytes % alignment;"); - 
buffer().add_line("auto padding = remainder ? (alignment - remainder)/sizeof(value_type) : 0;"); - buffer().add_line("auto field_size = size()+padding;"); - - buffer().add_line(); - buffer().add_line("// allocate memory"); - buffer().add_line("data_ = array(field_size*num_fields, std::numeric_limits<value_type>::quiet_NaN());"); - - // assign the sub-arrays - // replace this : data_(1*n, 2*n); - // with this : data_(1*field_size, 1*field_size+n); - - buffer().add_line(); - buffer().add_line("// asign the sub-arrays"); - for(int i=0; i<num_vars; ++i) { - char namestr[128]; - sprintf(namestr, "%-15s", array_variables[i]->name().c_str()); - buffer().add_line( - array_variables[i]->name() + " = data_(" - + std::to_string(i) + "*field_size, " + std::to_string(i+1) + "*field_size);"); - } - buffer().add_line(); - - for(auto const& var : array_variables) { - double val = var->value(); - // only non-NaN fields need to be initialized, because data_ - // is NaN by default - if(val == val) { - buffer().add_line("memory::fill(" + var->name() + ", " + std::to_string(val) + ");"); - } - } - buffer().add_line(); - - // copy in the weights - buffer().add_line("// add the user-supplied weights for converting from current density"); - buffer().add_line("// to per-compartment current in nA"); - buffer().add_line("if (weights.size()) {"); - buffer().increase_indentation(); - buffer().add_line("memory::copy(weights, weights_(0, size()));"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line("else {"); - buffer().increase_indentation(); - buffer().add_line("memory::fill(weights_, 1.0);"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - ////////////////////////////////////////////// - ////////////////////////////////////////////// - - buffer().add_line("using base::size;"); - buffer().add_line(); - - buffer().add_line("std::size_t 
memory() const override {"); - buffer().increase_indentation(); - buffer().add_line("auto s = std::size_t{0};"); - buffer().add_line("s += data_.size()*sizeof(value_type);"); - for(auto& ion: m.neuron_block().ions) { - buffer().add_line("s += ion_" + ion.name + ".memory();"); - } - buffer().add_line("return s;"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - // print the member funtion that packs up the parameters for use on the GPU - buffer().add_line("void set_params() override {"); - buffer().increase_indentation(); - buffer().add_line("param_pack_ ="); - buffer().increase_indentation(); - buffer().add_line("param_pack_type {"); - buffer().increase_indentation(); - for(auto& str: param_pack) { - buffer().add_line(str + ","); - } - buffer().decrease_indentation(); - buffer().add_line("};"); - buffer().decrease_indentation(); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - // name member function - buffer().add_line("std::string name() const override {"); - buffer().increase_indentation(); - buffer().add_line("return \"" + module_name_ + "\";"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - std::string kind_str = m.kind() == moduleKind::density - ? 
"mechanismKind::density" - : "mechanismKind::point"; - buffer().add_line("mechanismKind kind() const override {"); - buffer().increase_indentation(); - buffer().add_line("return " + kind_str + ";"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - // Implement mechanism::set_weights method - buffer().add_line("void set_weights(array&& weights) override {"); - buffer().increase_indentation(); - buffer().add_line("memory::copy(weights, weights_(0, size()));"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - ////////////////////////////////////////////// - // print ion channel interface - ////////////////////////////////////////////// - - /*************************************************************************** - * - * ion channels have the following fields : - * - * --------------------------------------------------- - * label Ca Na K name - * --------------------------------------------------- - * iX ica ina ik current - * eX eca ena ek reversal_potential - * Xi cai nai ki internal_concentration - * Xo cao nao ko external_concentration - * gX gca gna gk conductance - * --------------------------------------------------- - * - **************************************************************************/ - - // ion_spec uses_ion(ionKind k) const override - buffer().add_line("typename base::ion_spec uses_ion(ionKind k) const override {"); - buffer().increase_indentation(); - buffer().add_line("bool uses = false;"); - buffer().add_line("bool writes_ext = false;"); - buffer().add_line("bool writes_int = false;"); - for (auto k: {ionKind::Na, ionKind::Ca, ionKind::K}) { - if (module_->has_ion(k)) { - auto ion = *module_->find_ion(k); - buffer().add_line("if (k==ionKind::" + ion.name + ") {"); - buffer().increase_indentation(); - buffer().add_line("uses = true;"); - if (ion.writes_concentration_int()) buffer().add_line("writes_int = true;"); - if (ion.writes_concentration_ext()) 
buffer().add_line("writes_ext = true;"); - buffer().decrease_indentation(); - buffer().add_line("}"); - } - } - buffer().add_line("return {uses, writes_int, writes_ext};"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - // void set_ion(ionKind k, ion_type& i) override - buffer().add_line("void set_ion(ionKind k, ion_type& i, std::vector<size_type>const& index) override {"); - buffer().increase_indentation(); - for (auto k: {ionKind::Na, ionKind::Ca, ionKind::K}) { - if (module_->has_ion(k)) { - auto ion = *module_->find_ion(k); - buffer().add_line("if (k==ionKind::" + ion.name + ") {"); - buffer().increase_indentation(); - auto n = ion.name; - auto pre = "ion_"+n; - buffer().add_line(pre+".index = memory::make_const_view(index);"); - if (ion.uses_current()) - buffer().add_line(pre+".i"+n+" = i.current();"); - if (ion.uses_rev_potential()) - buffer().add_line(pre+".e"+n+" = i.reversal_potential();"); - if (ion.uses_concentration_int()) - buffer().add_line(pre+"."+n+"i = i.internal_concentration();"); - if (ion.uses_concentration_ext()) - buffer().add_line(pre+"."+n+"o = i.external_concentration();"); - buffer().add_line("return;"); - buffer().decrease_indentation(); - buffer().add_line("}"); - } - } - buffer().add_line("throw std::domain_error(arb::util::pprintf(\"mechanism % does not support ion type\\n\", name()));"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - ////////////////////////////////////////////// - ////////////////////////////////////////////// - for(auto const &var : m.symbols()) { - if( var.second->kind()==symbolKind::procedure && - var.second->is_procedure()->kind()==procedureKind::api) - { - auto proc = var.second->is_api_method(); - auto name = proc->name(); - buffer().add_line("void " + name + "() {"); - buffer().increase_indentation(); - buffer().add_line("arb::gpu::"+name+"_"+module_name_+"(param_pack_);"); - buffer().decrease_indentation(); - 
buffer().add_line("}"); - buffer().add_line(); - } - else if( var.second->kind()==symbolKind::procedure && - var.second->is_procedure()->kind()==procedureKind::net_receive) - { - // Override `deliver_events`. - buffer().add_line("void deliver_events(const deliverable_event_stream_state& events) override {"); - buffer().increase_indentation(); - - buffer().add_line("arb::gpu::deliver_events_"+module_name_ - +"(param_pack_, mech_id_, events);"); - - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - } - } - - if(module_->write_backs().size()) { - buffer().add_line("void write_back() override {"); - buffer().increase_indentation(); - buffer().add_line("arb::gpu::write_back_"+module_name_+"(param_pack_);"); - buffer().decrease_indentation(); buffer().add_line("}"); - } - buffer().add_line(); - - std::unordered_set<std::string> scalar_set; - for (auto& v: scalar_variables) { - scalar_set.insert(v->name()); - } - - std::vector<Id> global_param_ids; - std::vector<Id> instance_param_ids; - - for (const Id& id: module_->parameter_block().parameters) { - auto var = id.token.spelling; - (scalar_set.count(var)? 
global_param_ids: instance_param_ids).push_back(id); - } - const std::vector<Id>& state_ids = module_->state_block().state_variables; - - if (!instance_param_ids.empty() || !state_ids.empty()) { - buffer().add_line("view base::* field_view_ptr(const char* id) const override {"); - buffer().increase_indentation(); - buffer().add_line("static const std::pair<const char*, view "+class_name+"::*> field_tbl[] = {"); - buffer().increase_indentation(); - for (const auto& id: instance_param_ids) { - auto var = id.token.spelling; - buffer().add_line("{\""+var+"\", &"+class_name+"::"+var+"},"); - } - for (const auto& id: state_ids) { - auto var = id.token.spelling; - buffer().add_line("{\""+var+"\", &"+class_name+"::"+var+"},"); - } - buffer().decrease_indentation(); - buffer().add_line("};"); - buffer().add_line(); - buffer().add_line("auto* pptr = util::table_lookup(field_tbl, id);"); - buffer().add_line("return pptr? static_cast<view base::*>(*pptr): nullptr;"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - } - - if (!global_param_ids.empty()) { - buffer().add_line("value_type base::* field_value_ptr(const char* id) const override {"); - buffer().increase_indentation(); - buffer().add_line("static const std::pair<const char*, value_type "+class_name+"::*> field_tbl[] = {"); - buffer().increase_indentation(); - for (const auto& id: global_param_ids) { - auto var = id.token.spelling; - buffer().add_line("{\""+var+"\", &"+class_name+"::"+var+"},"); - } - buffer().decrease_indentation(); - buffer().add_line("};"); - buffer().add_line(); - buffer().add_line("auto* pptr = util::table_lookup(field_tbl, id);"); - buffer().add_line("return pptr? 
static_cast<value_type base::*>(*pptr): nullptr;"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - } - ////////////////////////////////////////////// - ////////////////////////////////////////////// - - buffer().add_line("array data_;"); - for(auto var: array_variables) { - buffer().add_line("view " + var->name() + ";"); - } - for(auto var: scalar_variables) { - double val = var->value(); - // test the default value for NaN - // useful for error propogation from bad initial conditions - if(val==val) { - buffer().add_line("value_type " + var->name() + " = " + std::to_string(val) + ";"); - } - else { - // the cuda compiler has a bug that doesn't allow initialization of - // class members with std::numer_limites<>. So simply set to zero. - buffer().add_line("value_type " + var->name() + " = value_type{0};"); - } - } - - buffer().add_line("using base::mech_id_;"); - buffer().add_line("using base::vec_ci_;"); - buffer().add_line("using base::vec_t_;"); - buffer().add_line("using base::vec_t_to_;"); - buffer().add_line("using base::vec_dt_;"); - buffer().add_line("using base::vec_v_;"); - buffer().add_line("using base::vec_i_;"); - buffer().add_line("using base::node_index_;"); - buffer().add_line(); - buffer().add_line("param_pack_type param_pack_;"); - buffer().decrease_indentation(); - buffer().add_line("};"); - buffer().add_line(); - buffer().add_line("}} // namespaces"); -} - -void CUDAPrinter::visit(Expression *e) { - throw compiler_exception( - "CUDAPrinter doesn't know how to print " + e->to_string(), - e->location()); -} - -void CUDAPrinter::visit(LocalDeclaration *e) { -} - -void CUDAPrinter::visit(NumberExpression *e) { - cexpr_emit(e, buffer().text(), this); -} - -void CUDAPrinter::visit(IdentifierExpression *e) { - e->symbol()->accept(this); -} - -void CUDAPrinter::visit(Symbol *e) { - buffer() << e->name(); -} - -void CUDAPrinter::visit(VariableExpression *e) { - buffer() << "params_." 
<< e->name(); - if(e->is_range()) { - buffer() << "[" << index_string(e) << "]"; - } -} - -std::string CUDAPrinter::index_string(Symbol *s) { - if(s->is_variable()) { - return "tid_"; - } - else if(auto var = s->is_indexed_variable()) { - switch(var->ion_channel()) { - case ionKind::none : - return "gid_"; - case ionKind::Ca : - return "caid_"; - case ionKind::Na : - return "naid_"; - case ionKind::K : - return "kid_"; - // a nonspecific ion current should never be indexed: it is - // local to a mechanism - case ionKind::nonspecific: - break; - default : - throw compiler_exception( - "CUDAPrinter unknown ion type", - s->location()); - } - } - else if(s->is_cell_indexed_variable()) { - return "cid_"; - } - return ""; -} - -void CUDAPrinter::visit(IndexedVariable *e) { - buffer() << "params_." << e->index_name() << "[" << index_string(e) << "]"; -} - -void CUDAPrinter::visit(CellIndexedVariable *e) { - buffer() << "params_." << e->index_name() << "[" << index_string(e) << "]"; -} - - -void CUDAPrinter::visit(LocalVariable *e) { - std::string const& name = e->name(); - buffer() << name; -} - -void CUDAPrinter::visit(UnaryExpression *e) { - cexpr_emit(e, buffer().text(), this); -} - -void CUDAPrinter::visit(BlockExpression *e) { - // ------------- declare local variables ------------- // - // only if this is the outer block - if(!e->is_nested()) { - for(auto& var : e->scope()->locals()) { - auto sym = var.second.get(); - // input variables are declared earlier, before the - // block body is printed - if(is_stack_local(sym) && !is_input(sym)) { - buffer().add_line("value_type " + var.first + ";"); - } - } - } - - // ------------- statements ------------- // - for(auto& stmt : e->statements()) { - if(stmt->is_local_declaration()) continue; - // these all must be handled - buffer().add_gutter(); - stmt->accept(this); - if (not stmt->is_if()) { - buffer().end_line(";"); - } - } -} - -void CUDAPrinter::visit(IfExpression *e) { - // for now we remove the brackets around the 
condition because - // the binary expression printer adds them, and we want to work - // around the -Wparentheses-equality warning - buffer() << "if("; - e->condition()->accept(this); - buffer() << ") {\n"; - buffer().increase_indentation(); - e->true_branch()->accept(this); - buffer().decrease_indentation(); - buffer().add_line("}"); - // check if there is a false-branch, i.e. if - // there is an "else" branch to print - if (auto fb = e->false_branch()) { - buffer().add_gutter() << "else "; - // use recursion for "else if" - if (fb->is_if()) { - fb->accept(this); - } - // otherwise print the "else" block - else { - buffer() << "{\n"; - buffer().increase_indentation(); - fb->accept(this); - buffer().decrease_indentation(); - buffer().add_line("}"); - } - } -} - -void CUDAPrinter::print_device_function_prototype(ProcedureExpression *e) { - buffer().add_line("__device__"); - buffer().add_gutter() << "void " << e->name() - << "(" << module_name_ << "_ParamPack const& params_," - << "const int tid_"; - for(auto& arg : e->args()) { - buffer() << ", arb::fvm_value_type " << arg->is_argument()->name(); - } - buffer() << ")"; -} - -void CUDAPrinter::visit(ProcedureExpression *e) { - // error: semantic analysis has not been performed - if(!e->scope()) { // error: semantic analysis has not been performed - throw compiler_exception( - "CUDAPrinter attempt to print Procedure " + e->name() - + " for which semantic analysis has not been performed", - e->location()); - } - - if(e->kind() != procedureKind::net_receive) { - // print prototype - print_device_function_prototype(e); - buffer().end_line(" {"); - - // print body - buffer().increase_indentation(); - - buffer().add_line("using value_type = arb::fvm_value_type;"); - buffer().add_line(); - - e->body()->accept(this); - - // close up - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - } - else { - // net_receive() kernel is a special case, not covered by APIMethod visit. 
- - // Core `net_receive` kernel is called device-side from `kernel::deliver_events`. - buffer().add_line( "__device__"); - buffer().add_gutter() << "void net_receive(const " << module_name_ << "_ParamPack& params_, " - << "arb::fvm_size_type i_, arb::fvm_value_type weight) {"; - buffer().add_line(); - buffer().increase_indentation(); - - buffer().add_line("using value_type = arb::fvm_value_type;"); - buffer().add_line(); - - buffer().add_line("auto tid_ = i_;"); - buffer().add_line("auto gid_ __attribute__((unused)) = params_.ni[tid_];"); - buffer().add_line("auto cid_ __attribute__((unused)) = params_.ci[gid_];"); - - print_APIMethod_body(e); - - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - // Global one-thread wrapper for `net_receive` kernel is used to implement the - // `mechanism::net_receive` method. This is not called in the normal course - // of event delivery. - buffer().add_line( "__global__"); - buffer().add_gutter() << "void net_receive_global(" - << module_name_ << "_ParamPack params_, " - << "arb::fvm_size_type i_, arb::fvm_value_type weight) {"; - buffer().add_line(); - buffer().increase_indentation(); - - buffer().add_line("if (threadIdx.x || blockIdx.x) return;"); - buffer().add_line("net_receive(params_, i_, weight);"); - - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - - buffer().add_line( "__global__"); - buffer().add_gutter() << "void deliver_events(" - << module_name_ << "_ParamPack params_, " - << "arb::fvm_size_type mech_id, deliverable_event_stream_state state) {"; - buffer().add_line(); - buffer().increase_indentation(); - - buffer().add_line("auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;"); - buffer().add_line("auto const ncell_ = state.n;"); - buffer().add_line(); - buffer().add_line("if(tid_<ncell_) {"); - buffer().increase_indentation(); - - - buffer().add_line("auto begin = state.ev_data+state.begin_offset[tid_];"); - buffer().add_line("auto end = 
state.ev_data+state.end_offset[tid_];"); - buffer().add_line("for (auto p = begin; p<end; ++p) {"); - buffer().increase_indentation(); - buffer().add_line("if (p->mech_id==mech_id) {"); - buffer().increase_indentation(); - buffer().add_line("net_receive(params_, p->mech_index, p->weight);"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().decrease_indentation(); - buffer().add_line("}"); - - buffer().decrease_indentation(); - buffer().add_line("}"); - - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); - } -} - -std::string CUDAPrinter::APIMethod_prototype(APIMethod *e) { - return "void " + e->name() + "_" + module_name_ - + "(" + pack_name() + " params_)"; -} - -void CUDAPrinter::visit(APIMethod *e) { - // print prototype - buffer().add_line("__global__"); - buffer().add_line(APIMethod_prototype(e) + " {"); - - if(!e->scope()) { // error: semantic analysis has not been performed - throw compiler_exception( - "CUDAPrinter attempt to print APIMethod " + e->name() - + " for which semantic analysis has not been performed", - e->location()); - } - buffer().increase_indentation(); - - buffer().add_line("using value_type = arb::fvm_value_type;"); - buffer().add_line(); - - buffer().add_line("auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;"); - buffer().add_line("auto const n_ = params_.n_;"); - buffer().add_line(); - buffer().add_line("if(tid_<n_) {"); - buffer().increase_indentation(); - - buffer().add_line("auto gid_ __attribute__((unused)) = params_.ni[tid_];"); - buffer().add_line("auto cid_ __attribute__((unused)) = params_.ci[gid_];"); - - print_APIMethod_body(e); - - buffer().decrease_indentation(); - buffer().add_line("}"); - - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); -} - -void CUDAPrinter::print_APIMethod_body(ProcedureExpression* e) { - // load indexes of ion channels - auto uses_k = false; - auto uses_na = false; - auto uses_ca = false; - for(auto &symbol : 
e->scope()->locals()) { - auto ch = symbol.second->is_local_variable()->ion_channel(); - if(!uses_k && (uses_k = (ch == ionKind::K)) ) { - buffer().add_line("auto kid_ = params_.ion_k_idx_[tid_];"); - } - if(!uses_ca && (uses_ca = (ch == ionKind::Ca)) ) { - buffer().add_line("auto caid_ = params_.ion_ca_idx_[tid_];"); - } - if(!uses_na && (uses_na = (ch == ionKind::Na)) ) { - buffer().add_line("auto naid_ = params_.ion_na_idx_[tid_];"); - } - } - - // shadows for indexed arrays - for(auto &symbol : e->scope()->locals()) { - auto var = symbol.second->is_local_variable(); - if(is_input(var)) { - auto ext = var->external_variable(); - buffer().add_gutter() << "value_type "; - var->accept(this); - buffer() << " = "; - ext->accept(this); - buffer().end_line("; // indexed load"); - } - else if (is_output(var)) { - buffer().add_gutter() << "value_type " << var->name() << ";"; - buffer().end_line(); - } - } - - buffer().add_line(); - buffer().add_line("// the kernel computation"); - - e->body()->accept(this); - - // insert stores here - // take care to use atomic operations for the updates for point processes, where - // more than one thread may try add/subtract to the same memory location - auto has_outputs = false; - for(auto &symbol : e->scope()->locals()) { - auto in = symbol.second->is_local_variable(); - auto out = in->external_variable(); - if(out==nullptr || !is_output(in)) continue; - if(!has_outputs) { - buffer().add_line(); - buffer().add_line("// stores to indexed global memory"); - has_outputs = true; - } - buffer().add_gutter(); - if(!is_point_process()) { - out->accept(this); - buffer() << (out->op()==tok::plus ? " += " : " -= "); - in->accept(this); - } - else { - buffer() << "arb::gpu::reduce_by_key("; - if (out->op()==tok::minus) buffer() << "-"; - in->accept(this); - // reduce_by_key() takes a pointer to the start of the target - // array as a parameter. 
This requires writing the index_name of out, which - // we can safely assume is an indexed_variable by this point. - buffer() << ", params_." << out->is_indexed_variable()->index_name() << ", gid_)"; - } - buffer().end_line(";"); - } - - return; -} - -void CUDAPrinter::visit(CallExpression *e) { - buffer() << e->name() << "(params_, tid_"; - for(auto& arg: e->args()) { - buffer() << ", "; - arg->accept(this); - } - buffer() << ")"; -} - -void CUDAPrinter::visit(BinaryExpression *e) { - cexpr_emit(e, buffer().text(), this); -} diff --git a/modcc/cudaprinter.hpp b/modcc/cudaprinter.hpp deleted file mode 100644 index af7d9d133bd195e052c7889e3b504fcd737be1cc..0000000000000000000000000000000000000000 --- a/modcc/cudaprinter.hpp +++ /dev/null @@ -1,124 +0,0 @@ -#pragma once - -#include <sstream> - -#include "module.hpp" -#include "textbuffer.hpp" -#include "visitor.hpp" - -class CUDAPrinter : public Visitor { -public: - CUDAPrinter() {} - CUDAPrinter(Module &m, bool o=false); - - void visit(Expression *e) override; - void visit(UnaryExpression *e) override; - void visit(BinaryExpression *e) override; - void visit(NumberExpression *e) override; - void visit(VariableExpression *e) override; - - void visit(Symbol *e) override; - void visit(LocalVariable *e) override; - void visit(IndexedVariable *e) override; - void visit(CellIndexedVariable *e) override; - - void visit(IdentifierExpression *e) override; - void visit(CallExpression *e) override; - void visit(ProcedureExpression *e) override; - void visit(APIMethod *e) override; - void visit(LocalDeclaration *e) override; - void visit(BlockExpression *e) override; - void visit(IfExpression *e) override; - - std::string impl_header_text() const { - return impl_interface_.str(); - } - - std::string impl_text() const { - return impl_.str(); - } - - std::string interface_text() const { - return interface_.str(); - } - - // public for testing purposes: - void set_buffer(TextBuffer& buf) { - current_buffer_ = &buf; - } - -private: 
- - bool is_input(Symbol *s) { - if(auto l = s->is_local_variable() ) { - if(l->is_local()) { - if(l->is_indexed() && l->is_read()) { - return true; - } - } - } - return false; - } - - bool is_output(Symbol *s) { - if(auto l = s->is_local_variable() ) { - if(l->is_local()) { - if(l->is_indexed() && l->is_write()) { - return true; - } - } - } - return false; - } - - bool is_indexed_local(Symbol *s) { - if(auto l=s->is_local_variable()) { - if(l->is_indexed()) { - return true; - } - } - return false; - } - - bool is_arg_local(Symbol *s) { - if(auto l=s->is_local_variable()) { - if(l->is_arg()) { - return true; - } - } - return false; - } - - bool is_stack_local(Symbol *s) { - if(is_arg_local(s)) return false; - if(is_input(s)) return false; - if(is_output(s)) return false; - return true; - } - - bool is_point_process() const { - return module_->kind() == moduleKind::point; - } - - void print_APIMethod_body(ProcedureExpression* e); - std::string APIMethod_prototype(APIMethod *e); - std::string pack_name(); - void print_device_function_prototype(ProcedureExpression *e); - std::string index_string(Symbol *e); - - std::string module_name_; - Module *module_ = nullptr; - - TextBuffer interface_; - TextBuffer impl_; - TextBuffer impl_interface_; - TextBuffer* current_buffer_; - - TextBuffer& buffer() { - if (!current_buffer_) { - throw std::runtime_error("CUDAPrinter buffer must be set via CUDAPrinter::set_buffer() before accessing via CUDAPrinter::buffer()."); - } - return *current_buffer_; - } -}; - diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 8a9da2a17ff199e3b7f87b07533fe0b6bdf222bb..5689fe5599ad82c2855bf888d2c1d52755fff142 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -96,7 +96,7 @@ void IdentifierExpression::semantic(scope_ptr scp) { // indexed variable is used in this procedure. 
In which case, we create // a local variable which refers to the indexed variable, which will be // found for any subsequent variable lookup inside the procedure - if(auto sym = s->is_abstract_indexed_variable()) { + if(auto sym = s->is_indexed_variable()) { auto var = new LocalVariable(location_, spelling_); var->external_variable(sym); s = scope_->add_local_symbol(spelling_, scope_type::symbol_ptr{var}); @@ -272,15 +272,6 @@ std::string IndexedVariable::to_string() const { + ", ion" + (ion_channel()==ionKind::none ? red(ch) : green(ch)) + ") "; } -/******************************************************************************* - CellIndexedVariable -*******************************************************************************/ - -std::string CellIndexedVariable::to_string() const { - auto ch = ::to_string(ion_channel()); - return blue("cellindexed") + " " + yellow(name()) + "->" + yellow(index_name()); -} - /******************************************************************************* ReactionExpression *******************************************************************************/ @@ -889,9 +880,6 @@ void VariableExpression::accept(Visitor *v) { void IndexedVariable::accept(Visitor *v) { v->visit(this); } -void CellIndexedVariable::accept(Visitor *v) { - v->visit(this); -} void NumberExpression::accept(Visitor *v) { v->visit(this); } diff --git a/modcc/expression.hpp b/modcc/expression.hpp index db25ca0d2825edda167e12c1274d129948b5a157..dd018629a92301c44559ce7c5bf12acf91df99f1 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -45,9 +45,7 @@ class VariableExpression; class ProcedureExpression; class NetReceiveExpression; class APIMethod; -class AbstractIndexedVariable; class IndexedVariable; -class CellIndexedVariable; class LocalVariable; using expression_ptr = std::unique_ptr<Expression>; @@ -226,9 +224,7 @@ public : virtual ProcedureExpression* is_procedure() {return nullptr;} virtual NetReceiveExpression* is_net_receive() {return nullptr;} 
virtual APIMethod* is_api_method() {return nullptr;} - virtual AbstractIndexedVariable* is_abstract_indexed_variable() {return nullptr;} virtual IndexedVariable* is_indexed_variable() {return nullptr;} - virtual CellIndexedVariable* is_cell_indexed_variable() {return nullptr;} virtual LocalVariable* is_local_variable() {return nullptr;} private : @@ -268,6 +264,7 @@ public: void semantic(scope_ptr scp) override; + void symbol(Symbol* sym) { symbol_ = sym; } Symbol* symbol() { return symbol_; }; void accept(Visitor *v) override; @@ -463,6 +460,9 @@ public: void state(bool s) { is_state_ = s; } + void shadows(Symbol* s) { + shadows_ = s; + } accessKind access() const { return access_; @@ -477,6 +477,10 @@ public: return ion_channel_; } + Symbol* shadows() const { + return shadows_; + } + bool is_ion() const {return ion_channel_ != ionKind::none;} bool is_state() const {return is_state_;} bool is_range() const {return range_kind_ == rangeKind::range;} @@ -504,40 +508,24 @@ protected: rangeKind range_kind_ = rangeKind::range; ionKind ion_channel_ = ionKind::none; double value_ = std::numeric_limits<double>::quiet_NaN(); -}; - -// abstract base class for the two sorts of indexed externals -class AbstractIndexedVariable: public Symbol { -public: - AbstractIndexedVariable(Location loc, std::string name, symbolKind kind) - : Symbol(std::move(loc), std::move(name), std::move(kind)) - {} - - virtual accessKind access() const = 0; - virtual ionKind ion_channel() const = 0; - virtual std::string const& index_name() const = 0; - virtual tok op() const = 0; - - virtual bool is_ion() const = 0; - virtual bool is_read() const = 0; - virtual bool is_write() const = 0; - - AbstractIndexedVariable* is_abstract_indexed_variable() override {return this;} + Symbol* shadows_ = nullptr; }; // an indexed variable -class IndexedVariable : public AbstractIndexedVariable { +class IndexedVariable : public Symbol { public: IndexedVariable(Location loc, std::string lookup_name, std::string 
index_name, + sourceKind data_source, accessKind acc, tok o=tok::eq, ionKind channel=ionKind::none) - : AbstractIndexedVariable(loc, std::move(lookup_name), symbolKind::indexed_variable), + : Symbol(std::move(loc), std::move(lookup_name), symbolKind::indexed_variable), access_(acc), ion_channel_(channel), - index_name_(index_name), + index_name_(index_name), // (TODO: deprecate/remove this...) + data_source_(data_source), op_(o) { std::string msg; @@ -568,14 +556,15 @@ public: std::string to_string() const override; - accessKind access() const override { return access_; } - ionKind ion_channel() const override { return ion_channel_; } - std::string const& index_name() const override { return index_name_; } - tok op() const override { return op_; } + accessKind access() const { return access_; } + ionKind ion_channel() const { return ion_channel_; } + sourceKind data_source() const { return data_source_; } + std::string const& index_name() const { return index_name_; } + tok op() const { return op_; } - bool is_ion() const override { return ion_channel_ != ionKind::none; } - bool is_read() const override { return access_ == accessKind::read; } - bool is_write() const override { return access_ == accessKind::write; } + bool is_ion() const { return ion_channel_ != ionKind::none; } + bool is_read() const { return access_ == accessKind::read; } + bool is_write() const { return access_ == accessKind::write; } void accept(Visitor *v) override; IndexedVariable* is_indexed_variable() override {return this;} @@ -584,40 +573,11 @@ public: protected: accessKind access_; ionKind ion_channel_; - std::string index_name_; + std::string index_name_; // hint to printer only + sourceKind data_source_; tok op_; }; -class CellIndexedVariable : public AbstractIndexedVariable { -public: - CellIndexedVariable(Location loc, - std::string lookup_name, - std::string index_name) - : AbstractIndexedVariable(loc, std::move(lookup_name), symbolKind::indexed_variable), - 
index_name_(std::move(index_name)) - {} - - std::string to_string() const override; - - accessKind access() const override { return accessKind::read; } - ionKind ion_channel() const override { return ionKind::none; } - std::string const& index_name() const override { return index_name_; } - tok op() const override { return tok::plus; } - - bool is_ion() const override { return false; } - bool is_read() const override { return true; } - bool is_write() const override { return false; } - - void accept(Visitor *v) override; - CellIndexedVariable* is_cell_indexed_variable() override {return this;} - - ~CellIndexedVariable() {} - -protected: - std::string index_name_; -}; - - class LocalVariable : public Symbol { public : LocalVariable(Location loc, @@ -662,11 +622,11 @@ public : return kind_==localVariableKind::argument; } - AbstractIndexedVariable* external_variable() { + IndexedVariable* external_variable() { return external_; } - void external_variable(AbstractIndexedVariable *i) { + void external_variable(IndexedVariable *i) { external_ = i; } @@ -674,7 +634,7 @@ public : void accept(Visitor *v) override; private : - AbstractIndexedVariable *external_=nullptr; + IndexedVariable *external_=nullptr; localVariableKind kind_; }; diff --git a/modcc/identifier.hpp b/modcc/identifier.hpp index 78a8ce5d19279ccf2a9dbf767c7a9f3c057012f8..6642caacf668b12f17981ac2b05be0509e1b480c 100644 --- a/modcc/identifier.hpp +++ b/modcc/identifier.hpp @@ -42,6 +42,18 @@ enum class ionKind { K ///< potassium ion }; +/// possible external data source for indexed variables +enum class sourceKind { + voltage, + current, + dt, + ion_current, + ion_revpot, + ion_iconc, + ion_econc, + no_source +}; + inline std::string yesno(bool val) { return std::string(val ? 
"yes" : "no"); }; @@ -99,22 +111,18 @@ inline std::ostream& operator<< (std::ostream& os, linkageKind l) { return os << to_string(l); } -inline ionKind ion_kind_from_name(std::string field) { - if(field.substr(0,4) == "ion_") { - field = field.substr(4); - } - if(field=="ica" || field=="eca" || field=="cai" || field=="cao") { - return ionKind::Ca; - } - if(field=="ik" || field=="ek" || field=="ki" || field=="ko") { - return ionKind::K; - } - if(field=="ina" || field=="ena" || field=="nai" || field=="nao") { - return ionKind::Na; - } - return ionKind::none; +/// ion variable to data source kind + +inline sourceKind ion_source(ionKind i, const std::string& var) { + std::string ion = to_string(i); + if (var=="i"+ion) return sourceKind::ion_current; + else if (var=="e"+ion) return sourceKind::ion_revpot; + else if (var==ion+"i") return sourceKind::ion_iconc; + else if (var==ion+"e") return sourceKind::ion_econc; + else return sourceKind::no_source; } +// TODO: deprecate; back-end dependent. inline std::string ion_store(ionKind k) { switch(k) { case ionKind::Ca: diff --git a/modcc/io/bulkio.hpp b/modcc/io/bulkio.hpp index e140843d4053bfc45534cb95186e319dfc7a574f..65e227b87bc56dbffbc4cc66a1726d88cca0b74f 100644 --- a/modcc/io/bulkio.hpp +++ b/modcc/io/bulkio.hpp @@ -2,12 +2,21 @@ // Read or write the contents of a file in toto. -#include <string> -#include <iterator> #include <fstream> +#include <iterator> +#include <stdexcept> +#include <string> +#include <utility> namespace io { +// Note: catching std::ios_base::failure is broken for gcc versions before 7 +// with C++11, owing to ABI issues. 
+ +struct bulkio_error: std::runtime_error { + bulkio_error(std::string what): std::runtime_error(std::move(what)) {} +}; + template <typename HasAssign> void read_all(std::istream& in, HasAssign& A) { A.assign(std::istreambuf_iterator<char>(in), std::istreambuf_iterator<char>()); @@ -15,10 +24,15 @@ void read_all(std::istream& in, HasAssign& A) { template <typename HasAssign> void read_all(const std::string& filename, HasAssign& A) { - std::ifstream fs; - fs.exceptions(std::ios::failbit); - fs.open(filename); - read_all(fs, A); + try { + std::ifstream fs; + fs.exceptions(std::ios::failbit); + fs.open(filename); + read_all(fs, A); + } + catch (const std::exception&) { + throw bulkio_error("failure reading "+filename); + } } inline std::string read_all(std::istream& in) { @@ -40,10 +54,15 @@ void write_all(const Container& data, std::ostream& out) { template <typename Container> void write_all(const Container& data, const std::string& filename) { - std::ofstream fs; - fs.exceptions(std::ios::failbit); - fs.open(filename); - write_all(data, fs); + try { + std::ofstream fs; + fs.exceptions(std::ios::failbit); + fs.open(filename); + write_all(data, fs); + } + catch (const std::exception&) { + throw bulkio_error("failure writing "+filename); + } } -} +} // namespace io diff --git a/modcc/io/ostream_wrappers.hpp b/modcc/io/ostream_wrappers.hpp new file mode 100644 index 0000000000000000000000000000000000000000..baf88a00403273ef4a39a7a7f1627584b507dc82 --- /dev/null +++ b/modcc/io/ostream_wrappers.hpp @@ -0,0 +1,57 @@ +#pragma once + +// Convenience wrappers for ostream formatting. + +#include <locale> +#include <ostream> +#include <string> + +namespace io { + +template <typename T> +struct quoted_item { + quoted_item(const T& item): item(item) {} + const T& item; +}; + +template <typename T> +std::ostream& operator<<(std::ostream& o, const quoted_item<T>& q) { + return o << '"' << q.item << '"'; +} + +// Print item enclosed in double quotes. 
+ +template <typename T> +quoted_item<T> quote(const T& item) { return quoted_item<T>(item); } + +// Separator which prints nothing or a prefix the first time it is used, +// then prints the delimiter text each time thereafter. + +struct separator { + explicit separator(std::string d): + delimiter(std::move(d)) {} + + separator(std::string p, std::string d): + prefix(std::move(p)), delimiter(std::move(d)) {} + + std::string prefix, delimiter; + bool visited = false; + + friend std::ostream& operator<<(std::ostream& o, separator& sep) { + o << (sep.visited? sep.delimiter: sep.prefix); + sep.visited = true; + return o; + } + + void reset() { visited = false; } +}; + +// Reset locale on a stream to 'classic' for locale-independent formatting +// inline. + +inline std::ios_base& classic(std::ios_base& ios) { + ios.imbue(std::locale::classic()); + return ios; +} + +} // namespace io diff --git a/src/util/prefixbuf.cpp b/modcc/io/prefixbuf.cpp similarity index 94% rename from src/util/prefixbuf.cpp rename to modcc/io/prefixbuf.cpp index 71dfea842abf53c1d058775a31065259c9bb7f59..916e795070a221cc1a61379d9a05fae3631ea1d5 100644 --- a/src/util/prefixbuf.cpp +++ b/modcc/io/prefixbuf.cpp @@ -3,10 +3,9 @@ #include <string> #include <vector> -#include <util/prefixbuf.hpp> +#include "io/prefixbuf.hpp" -namespace arb { -namespace util { +namespace io { // prefixbuf implementation: @@ -15,7 +14,9 @@ std::streamsize prefixbuf::xsputn(const char_type* s, std::streamsize count) { while (count>0) { if (bol_) { - inner_->sputn(&prefix[0], prefix.size()); + if (prefix_empty_lines_ || s[0]!='\n') { + inner_->sputn(&prefix[0], prefix.size()); + } bol_ = false; } @@ -140,5 +141,4 @@ std::ostream& operator<<(std::ostream& os, indent_manip in) { return os; } -} // namespace util -} // namespace arb +} // namespace io diff --git a/src/util/prefixbuf.hpp b/modcc/io/prefixbuf.hpp similarity index 91% rename from src/util/prefixbuf.hpp rename to modcc/io/prefixbuf.hpp index 
5f4c342e78cab89f69029e8a6a72b930891cd594..3768779781b5bd653fc495c1adee6d2025a8d92e 100644 --- a/src/util/prefixbuf.hpp +++ b/modcc/io/prefixbuf.hpp @@ -8,8 +8,7 @@ #include <sstream> #include <string> -namespace arb { -namespace util { +namespace io { // `prefixbuf` acts an output-only filter for another streambuf, inserting // the contents of the `prefix` string before the first character in a line. @@ -25,10 +24,14 @@ namespace util { // // >>> hello // >>> world +// +// A flag determines if the prefixbuf should or should not emit the prefix +// for empty lines. class prefixbuf: public std::streambuf { public: - explicit prefixbuf(std::streambuf* inner): inner_(inner) {} + explicit prefixbuf(std::streambuf* inner, bool prefix_empty_lines=false): + inner_(inner), prefix_empty_lines_(prefix_empty_lines) {} prefixbuf(prefixbuf&&) = default; prefixbuf(const prefixbuf&) = delete; @@ -41,6 +44,7 @@ public: protected: std::streambuf* inner_; + bool prefix_empty_lines_ = false; bool bol_ = true; std::streamsize xsputn(const char_type* s, std::streamsize count) override; @@ -130,6 +134,4 @@ private: prefixbuf pbuf_; }; - -} // namespace util -} // namespace arb +} // namespace io diff --git a/modcc/modcc.cpp b/modcc/modcc.cpp index ea2dafc1c3c43a3e60625ced65f74cfbfe0fdda8..16d19d8583ccb5ffe742e33f37ddd4183cdff495 100644 --- a/modcc/modcc.cpp +++ b/modcc/modcc.cpp @@ -1,17 +1,20 @@ #include <exception> #include <iostream> +#include <stdexcept> #include <unordered_map> #include <unordered_set> #include <tclap/CmdLine.h> -#include "cprinter.hpp" -#include "cudaprinter.hpp" +#include "printer/cprinter.hpp" +#include "printer/cudaprinter.hpp" +#include "printer/infoprinter.hpp" +#include "printer/simd.hpp" + #include "modccutil.hpp" #include "module.hpp" #include "parser.hpp" #include "perfvisitor.hpp" -#include "simd_printer.hpp" #include "io/bulkio.hpp" @@ -38,13 +41,16 @@ enum class targetKind { std::unordered_map<std::string, targetKind> targetKindMap = { {"cpu", 
targetKind::cpu}, - {"gpu", targetKind::gpu} + {"gpu", targetKind::gpu}, }; -std::unordered_map<std::string, simdKind> simdKindMap = { - {"none", simdKind::none}, - {"avx2", simdKind::avx2}, - {"avx512", simdKind::avx512} +std::unordered_map<std::string, enum simd_spec::simd_abi> simdAbiMap = { + {"none", simd_spec::none}, + {"avx", simd_spec::avx}, + {"avx2", simd_spec::avx2}, + {"avx512", simd_spec::avx512}, + {"default_abi", simd_spec::default_abi}, + {"native", simd_spec::native} }; template <typename Map, typename V> @@ -61,19 +67,28 @@ struct Options { std::string modulename; bool verbose = true; bool analysis = false; - simdKind simd_arch = simdKind::none; + simd_spec simd = simd_spec::none; std::unordered_set<targetKind, enum_hash> targets; }; // Helper for formatting tabulated output (option reporting). struct table_prefix { std::string text; }; std::ostream& operator<<(std::ostream& out, const table_prefix& tb) { - return out << cyan("| "+tb.text) << std::left << std::setw(61-tb.text.size()); + return out << cyan("| "+tb.text) << std::right << std::setw(58-tb.text.size()); }; +std::ostream& operator<<(std::ostream& out, simd_spec simd) { + std::stringstream s; + s << key_by_value(simdAbiMap, simd.abi); + if (simd.width!=0) { + s << '/' << simd.width; + } + return out << s.str(); +} + std::ostream& operator<<(std::ostream& out, const Options& opt) { static const char* noyes[2] = {"no", "yes"}; - static const std::string line_end = cyan("|") + "\n"; + static const std::string line_end = cyan(" |") + "\n"; static const std::string tableline = cyan("."+std::string(60, '-')+".")+"\n"; std::string targets; @@ -87,7 +102,7 @@ std::ostream& operator<<(std::ostream& out, const Options& opt) { table_prefix{"output"} << (opt.outprefix.empty()? 
"-": opt.outprefix) << line_end << table_prefix{"verbose"} << noyes[opt.verbose] << line_end << table_prefix{"targets"} << targets << line_end << - table_prefix{"simd"} << key_by_value(simdKindMap, opt.simd_arch) << line_end << + table_prefix{"simd"} << opt.simd << line_end << table_prefix{"analysis"} << noyes[opt.analysis] << line_end << tableline; } @@ -107,6 +122,37 @@ struct MapConstraint: private std::vector<std::string>, public TCLAP::ValuesCons } }; +simd_spec parse_simd_spec(std::string spec) { + auto npos = std::string::npos; + unsigned width = 0; + + auto suffix = spec.find_last_of('/'); + if (suffix!=npos) { + width = stoul(spec.substr(suffix+1)); + spec = spec.substr(0, suffix); + } + + return simd_spec(simdAbiMap.at(spec.c_str()), width); +} + +struct SimdAbiConstraint: public TCLAP::Constraint<std::string> { + std::string description() const override { + return "simd_abi[/n]"; + } + std::string shortID() const override { + return description(); + } + bool check(const std::string& spec) const override { + try { + (void)parse_simd_spec(spec); + return true; + } + catch (...) { + return false; + } + } +}; + int main(int argc, char **argv) { Options opt; @@ -121,11 +167,14 @@ int main(int argc, char **argv) { MapConstraint targets_arg_constraint(targetKindMap); TCLAP::MultiArg<std::string> - target_arg("t", "target", "backend target={cpu, gpu}", false, &targets_arg_constraint, cmd); + target_arg("t", "target", "build module for cpu or gpu back-end", false, &targets_arg_constraint, cmd); + + TCLAP::SwitchArg + simd_arg("s", "simd", "generate code with explicit SIMD vectorization", cmd, false); - MapConstraint simd_arg_constraint(simdKindMap); + SimdAbiConstraint simd_abi_constraint; TCLAP::ValueArg<std::string> - simd_arg("s", "simd", "use SIMD intrinsics={avx512, avx2}", false, "", &simd_arg_constraint, cmd); + simd_abi_arg("S", "simd-abi", "override SIMD ABI in generated code. Use /n suffix to force SIMD width to be size n. 
Examples: 'avx2', 'native/4', ...", false, "", &simd_abi_constraint, cmd); TCLAP::SwitchArg verbose_arg("V","verbose","toggle verbose mode", cmd, false); @@ -142,8 +191,11 @@ int main(int argc, char **argv) { opt.verbose = verbose_arg.getValue(); opt.analysis = analysis_arg.getValue(); - if (!simd_arg.getValue().empty()) { - opt.simd_arch = simdKindMap.at(simd_arg.getValue()); + if (simd_arg.getValue()) { + opt.simd = simd_spec(simd_spec::native); + if (!simd_abi_arg.getValue().empty()) { + opt.simd = parse_simd_spec(simd_abi_arg.getValue()); + } } for (auto& target: target_arg.getValue()) { @@ -202,31 +254,18 @@ int main(int argc, char **argv) { // If no output prefix given, use the module name. std::string prefix = opt.outprefix.empty()? m.module_name(): opt.outprefix; + io::write_all(build_info_header(m, "arb"), prefix+".hpp"); + for (targetKind target: opt.targets) { std::string outfile = prefix; switch (target) { case targetKind::gpu: - outfile += "_gpu"; - { - CUDAPrinter printer(m); - io::write_all(printer.interface_text(), outfile+".hpp"); - io::write_all(printer.impl_header_text(), outfile+"_impl.hpp"); - io::write_all(printer.impl_text(), outfile+"_impl.cu"); - } + io::write_all(emit_cuda_cpp_source(m, "arb"), outfile+"_gpu.cpp"); + io::write_all(emit_cuda_cu_source(m, "arb"), outfile+"_gpu.cu"); break; case targetKind::cpu: - outfile += "_cpu.hpp"; - switch (opt.simd_arch) { - case simdKind::none: - io::write_all(CPrinter(m).emit_source(), outfile); - break; - case simdKind::avx2: - io::write_all(SimdPrinter<simdKind::avx2>(m).emit_source(), outfile); - break; - case simdKind::avx512: - io::write_all(SimdPrinter<simdKind::avx512>(m).emit_source(), outfile); - break; - } + io::write_all(emit_cpp_source(m, "arb", opt.simd), outfile+"_cpu.cpp"); + break; } } diff --git a/modcc/module.cpp b/modcc/module.cpp index 85030677eb745f8f64737bfd4f7bc3110e732d07..409845d4660b9c7721775111d3ce81e9909af767 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -3,6 
+3,7 @@ #include <fstream> #include <iostream> #include <set> +#include <unordered_set> #include "errorvisitor.hpp" #include "functionexpander.hpp" @@ -50,14 +51,14 @@ public: statements_.push_back(make_expression<AssignmentExpression>(loc_, id("current_"), make_expression<MulBinaryExpression>(loc_, - id("weights_"), + id("weight_"), id("current_")))); for (auto& v: ion_current_vars_) { statements_.push_back(make_expression<AssignmentExpression>(loc_, id(v), make_expression<MulBinaryExpression>(loc_, - id("weights_"), + id("weight_"), id(v)))); } } @@ -114,6 +115,10 @@ std::string Module::warning_string() const { return str; } +void Module::add_callable(symbol_ptr callable) { + callables_.push_back(std::move(callable)); +} + bool Module::semantic() { //////////////////////////////////////////////////////////////////////////// // create the symbol table @@ -152,17 +157,16 @@ bool Module::semantic() { }; // Add built in function that approximate exp use pade polynomials - functions_.push_back( + callables_.push_back( Parser{"FUNCTION exp_pade_11(z) { exp_pade_11=(1+0.5*z)/(1-0.5*z) }"}.parse_function()); - functions_.push_back( + callables_.push_back( Parser{ "FUNCTION exp_pade_22(z)" "{ exp_pade_22=(1+0.5*z+0.08333333333333333*z*z)/(1-0.5*z+0.08333333333333333*z*z) }" }.parse_function()); // move functions and procedures to the symbol table - if(!move_symbols(functions_)) return false; - if(!move_symbols(procedures_)) return false; + if(!move_symbols(callables_)) return false; // perform semantic analysis and inlining on function and procedure bodies if(auto errors = semantic_func_proc()) { @@ -206,6 +210,49 @@ bool Module::semantic() { return std::make_pair(proc, source); }; + // ... except for write_ions(), which we construct by hand here. 
+ + expr_list_type ion_assignments; + + auto sym_to_id = [](Symbol* sym) -> expression_ptr { + auto id = make_expression<IdentifierExpression>(sym->location(), sym->name()); + id->is_identifier()->symbol(sym); + return id; + }; + + auto numeric_literal = [](double v) -> expression_ptr { + return make_expression<NumberExpression>(Location{}, v); + }; + + for (auto& sym: symbols_) { + Location loc; + + auto state = sym.second->is_variable(); + if (!state || !state->is_state()) continue; + + auto shadowed = state->shadows(); + if (!shadowed) continue; + + auto ionvar = shadowed->is_indexed_variable(); + if (!ionvar || !ionvar->is_ion() || !ionvar->is_write()) continue; + + auto weight = symbols_["weight_"].get(); + if (!weight) throw compiler_exception("missing weight_ global", loc); + + ion_assignments.push_back( + make_expression<AssignmentExpression>(loc, + sym_to_id(ionvar), + make_expression<MulBinaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, + numeric_literal(0.1), + sym_to_id(weight)), + sym_to_id(state)))); + } + + symbols_["write_ions"] = make_symbol<APIMethod>(Location{}, "write_ions", + std::vector<expression_ptr>(), + make_expression<BlockExpression>(Location{}, std::move(ion_assignments), false)); + //......................................................................... // nrn_init : based on the INITIAL block (i.e. the 'initial' procedure //......................................................................... 
@@ -360,18 +407,18 @@ bool Module::semantic() { /// populate the symbol table with class scope variables void Module::add_variables_to_symbols() { - // add reserved symbols (not v, because for some reason it has to be added - // by the user) - auto create_variable = [this] (const char* name, rangeKind rng, accessKind acc) { - auto t = new VariableExpression(Location(), name); - t->state(false); - t->linkage(linkageKind::local); - t->ion_channel(ionKind::none); - t->range(rng); - t->access(acc); - t->visibility(visibilityKind::global); - symbols_[name] = symbol_ptr{t}; - }; + auto create_variable = + [this](const Token& token, accessKind a, visibilityKind v, linkageKind l, + rangeKind r, bool is_state = false) -> symbol_ptr& + { + auto var = new VariableExpression(token.location, token.spelling); + var->access(a); + var->visibility(v); + var->linkage(l); + var->range(r); + var->state(is_state); + return symbols_[var->name()] = symbol_ptr{var}; + }; // mechanisms use a vector of weights to: // density mechs: @@ -379,110 +426,65 @@ void Module::add_variables_to_symbols() { // - density or proportion of a CV's area affected by the mechansim // point procs: // - convert current in nA to current densities in A.m^-2 - create_variable("weights_", rangeKind::range, accessKind::read); + + create_variable(Token(tok::identifier, "weight_"), + accessKind::read, visibilityKind::global, linkageKind::external, rangeKind::range); // add indexed variables to the table auto create_indexed_variable = [this] - (std::string const& name, std::string const& indexed_name, - tok op, accessKind acc, ionKind ch, Location loc) + (std::string const& name, std::string const& indexed_name, sourceKind data_source, + tok op, accessKind acc, ionKind ch, Location loc) -> symbol_ptr& { if(symbols_.count(name)) { throw compiler_exception( - pprintf("the symbol % already exists", yellow(name)), - loc); + pprintf("the symbol % already exists", yellow(name)), loc); } - symbols_[name] = - 
make_symbol<IndexedVariable>(loc, name, indexed_name, acc, op, ch); + return symbols_[name] = + make_symbol<IndexedVariable>(loc, name, indexed_name, data_source, acc, op, ch); }; - create_indexed_variable("current_", "vec_i", tok::plus, + create_indexed_variable("current_", "vec_i", sourceKind::current, tok::plus, accessKind::write, ionKind::none, Location()); - create_indexed_variable("v", "vec_v", tok::eq, + create_indexed_variable("v", "vec_v", sourceKind::voltage, tok::eq, accessKind::read, ionKind::none, Location()); - create_indexed_variable("dt", "vec_dt", tok::eq, + create_indexed_variable("dt", "vec_dt", sourceKind::dt, tok::eq, accessKind::read, ionKind::none, Location()); - // add cell-indexed variables to the table - auto create_cell_indexed_variable = [this] - (std::string const& name, std::string const& indexed_name, Location loc = Location()) - { - if(symbols_.count(name)) { - throw compiler_exception( - "trying to insert a symbol that already exists", - loc); - } - symbols_[name] = make_symbol<CellIndexedVariable>(loc, name, indexed_name); - }; - - create_cell_indexed_variable("t", "vec_t"); - create_cell_indexed_variable("t_to", "vec_t_to"); - - // add state variables - for(auto const &var : state_block()) { - VariableExpression *id = new VariableExpression(Location(), var.name()); - - id->state(true); // set state to true - // state variables are private - // what about if the state variables is an ion concentration? - id->linkage(linkageKind::local); - id->visibility(visibilityKind::local); - id->ion_channel(ionKind::none); // no ion channel - id->range(rangeKind::range); // always a range - id->access(accessKind::readwrite); + // If we put back support for accessing cell time again from NMODL code, + // add indexed_variable also for "time" with appropriate cell-index based + // indirection in printers. - symbols_[var.name()] = symbol_ptr{id}; + // Add state variables. 
+ for (const Id& id: state_block_) { + create_variable(id.token, + accessKind::readwrite, visibilityKind::local, linkageKind::local, rangeKind::range, true); } - // add the parameters - for(auto const& var : parameter_block()) { - auto name = var.name(); - if(name == "v") { // global voltage values - // ignore voltage, which is added as an indexed variable by default + // Add parameters, ignoring built-in voltage variable "v". + for (const Id& id: parameter_block_) { + if (id.name() == "v") { continue; } - VariableExpression *id = new VariableExpression(Location(), name); - - id->state(false); // never a state variable - id->linkage(linkageKind::local); - // parameters are visible to Neuron - id->visibility(visibilityKind::global); - id->ion_channel(ionKind::none); - // scalar by default, may later be upgraded to range - id->range(rangeKind::scalar); - id->access(accessKind::read); - - // check for 'special' variables - if(name == "celcius") { // global celcius parameter - id->linkage(linkageKind::external); - } + + // Parameters are scalar by default, but may later be changed to range. + linkageKind linkage = linkageKind::local; + auto& sym = create_variable(id.token, + accessKind::read, visibilityKind::global, linkage, rangeKind::scalar); // set default value if one was specified - if(var.value.size()) { - id->value(std::stod(var.value)); + if (id.has_value()) { + sym->is_variable()->value(std::stod(id.value)); } - - symbols_[name] = symbol_ptr{id}; } - // add the assigned variables - for(auto const& var : assigned_block()) { - auto name = var.name(); - if(name == "v") { // global voltage values - // ignore voltage, which is added as an indexed variable by default + // Add 'assigned' variables, ignoring built-in voltage variable "v". 
+ for (const Id& id: assigned_block_) { + if (id.name() == "v") { continue; } - VariableExpression *id = new VariableExpression(var.token.location, name); - - id->state(false); // never a state variable - id->linkage(linkageKind::local); - // local visibility by default - id->visibility(visibilityKind::local); - id->ion_channel(ionKind::none); // can change later - // ranges because these are assigned to in loop - id->range(rangeKind::range); - id->access(accessKind::readwrite); - - symbols_[name] = symbol_ptr{id}; + + create_variable(id.token, + accessKind::readwrite, visibilityKind::local, linkageKind::local, rangeKind::range); } //////////////////////////////////////////////////// @@ -494,58 +496,42 @@ void Module::add_variables_to_symbols() { auto update_ion_symbols = [this, create_indexed_variable] (Token const& tkn, accessKind acc, ionKind channel) { - auto const& name = tkn.spelling; - - if(has_symbol(name)) { - auto sym = symbols_[name].get(); - - // if sym is an indexed_variable: error - // else if sym is a state variable: register a writeback call - // else if sym is a range (non parameter) variable: error - // else if sym is a parameter variable: error - // else it does not exist so make an indexed variable - - // If an indexed variable has already been created with the same name - // throw an error. - if(sym->kind()==symbolKind::indexed_variable) { + std::string name = tkn.spelling; + sourceKind data_source = ion_source(channel, name); + + // If the symbol already exists and is not a state variable, + // it is an error. + // + // Otherwise create an indexed variable and associate it + // with the state variable if present (via a different name) + // for ion state updates. 
+ + VariableExpression* state = nullptr; + if (has_symbol(name)) { + state = symbols_[name].get()->is_variable(); + if (!state || !state->is_state()) { error(pprintf("the symbol defined % at % can't be redeclared", - sym->location(), yellow(name)), - tkn.location); - return; - } - else if(sym->kind()==symbolKind::variable) { - auto var = sym->is_variable(); - - // state variable: register writeback - if(var->is_state()) { - // create writeback - write_backs_.push_back(WriteBack(name, "ion_"+name, channel)); - return; - } - - // error: a normal range variable or parameter can't have the same - // name as an indexed ion variable - error(pprintf("the ion channel variable % at % can't be redeclared", - yellow(name), sym->location()), - tkn.location); + state->location(), yellow(name)), tkn.location); return; } + name += "_shadowed_"; } - // add the ion variable's indexed shadow - create_indexed_variable(name, "ion_"+name, - acc==accessKind::read ? tok::eq : tok::plus, -acc, channel, tkn.location); + auto& sym = create_indexed_variable(name, "ion_"+name, data_source, + acc==accessKind::read ? 
tok::eq : tok::plus, acc, channel, tkn.location); + + if (state) { + state->shadows(sym.get()); + } }; // check for nonspecific current - if( neuron_block().has_nonspecific_current() ) { - auto const& i = neuron_block().nonspecific_current; + if( neuron_block_.has_nonspecific_current() ) { + auto const& i = neuron_block_.nonspecific_current; update_ion_symbols(i, accessKind::write, ionKind::nonspecific); } - - for(auto const& ion : neuron_block().ions) { + for(auto const& ion : neuron_block_.ions) { for(auto const& var : ion.read) { update_ion_symbols(var, accessKind::read, ion.kind()); } @@ -555,7 +541,7 @@ acc, channel, tkn.location); } // then GLOBAL variables - for(auto const& var : neuron_block().globals) { + for(auto const& var : neuron_block_.globals) { if(!symbols_[var.spelling]) { error( yellow(var.spelling) + " is declared as GLOBAL, but has not been declared in the" + @@ -575,7 +561,7 @@ acc, channel, tkn.location); } // then RANGE variables - for(auto const& var : neuron_block().ranges) { + for(auto const& var : neuron_block_.ranges) { if(!symbols_[var.spelling]) { error( yellow(var.spelling) + " is declared as RANGE, but has not been declared in the" + @@ -710,3 +696,4 @@ int Module::semantic_func_proc() { } return errors; } + diff --git a/modcc/module.hpp b/modcc/module.hpp index 8195caaaf48a8973b4f8e8220c6c151158938e14..8b650870efc28d6b8feadc96ca30b9341cf73842 100644 --- a/modcc/module.hpp +++ b/modcc/module.hpp @@ -1,12 +1,13 @@ #pragma once +#include <iterator> #include <string> +#include <utility> #include <vector> #include "blocks.hpp" #include "error.hpp" #include "expression.hpp" -#include "writeback.hpp" // wrapper around a .mod file class Module: public error_stack { @@ -44,39 +45,35 @@ public: void title(const std::string& t) { title_ = t; } const std::string& title() const { return title_; } -// TODO: are const and non-const methods necessary? check usage. 
- NeuronBlock & neuron_block() {return neuron_block_;} + moduleKind kind() const { return kind_; } + void kind(moduleKind k) { kind_ = k; } + + // only used for ion access - this will be done differently ... NeuronBlock const& neuron_block() const {return neuron_block_;} - StateBlock & state_block() {return state_block_;} + // Retrieve list of state variable ids. StateBlock const& state_block() const {return state_block_;} - UnitsBlock & units_block() {return units_block_;} - UnitsBlock const& units_block() const {return units_block_;} - - ParameterBlock & parameter_block() {return parameter_block_;} + // Retrieve list of parameter variable ids. ParameterBlock const& parameter_block() const {return parameter_block_;} - AssignedBlock & assigned_block() {return assigned_block_;} - AssignedBlock const& assigned_block() const {return assigned_block_;} + // Retrieve list of ion dependencies. + const std::vector<IonDep>& ion_deps() const { return neuron_block_.ions; } + // Set top-level blocks (called from Parser). void neuron_block(const NeuronBlock& n) { neuron_block_ = n; } void state_block(const StateBlock& s) { state_block_ = s; } void units_block(const UnitsBlock& u) { units_block_ = u; } void parameter_block(const ParameterBlock& p) { parameter_block_ = p; } void assigned_block(const AssignedBlock& a) { assigned_block_ = a; } - // access to the AST - std::vector<symbol_ptr>& procedures() { return procedures_; } - const std::vector<symbol_ptr>& procedures() const { return procedures_; } + // Add global procedure or function, before semantic pass (called from Parser). + void add_callable(symbol_ptr callable); - std::vector<symbol_ptr>& functions() { return functions_; } - const std::vector<symbol_ptr>& functions() const { return functions_; } - - symbol_map& symbols() { return symbols_; } + // Raw access to AST data. const symbol_map& symbols() const { return symbols_; } - // error handling + // Error and warning handling. 
using error_stack::error; void error(std::string const& msg, Location loc = Location{}) { error({msg, loc}); @@ -84,7 +81,6 @@ public: std::string error_string() const; - // warnings using error_stack::warning; void warning(std::string const& msg, Location loc = Location{}) { warning({msg, loc}); @@ -92,18 +88,10 @@ public: std::string warning_string() const; - moduleKind kind() const { return kind_; } - void kind(moduleKind k) { kind_ = k; } - - // perform semantic analysis - void add_variables_to_symbols(); + // Perform semantic analysis pass. bool semantic(); - const std::vector<WriteBack>& write_backs() const { - return write_backs_; - } - - auto find_ion(ionKind k) -> decltype(neuron_block().ions.begin()) { + auto find_ion(ionKind k) -> decltype(ion_deps().begin()) { auto& ions = neuron_block().ions; return std::find_if( ions.begin(), ions.end(), @@ -121,24 +109,29 @@ private: std::string title_; std::string module_name_; std::string source_name_; - std::vector<char> buffer_; // character buffer loaded from file + std::vector<char> buffer_; // Holds module source, zero terminated. - bool generate_initial_api(); - bool generate_current_api(); - bool generate_state_api(); + NeuronBlock neuron_block_; + StateBlock state_block_; + UnitsBlock units_block_; + ParameterBlock parameter_block_; + AssignedBlock assigned_block_; - // AST storage - std::vector<symbol_ptr> procedures_; - std::vector<symbol_ptr> functions_; + // AST storage. + std::vector<symbol_ptr> callables_; - // hash table for lookup of variable and call names + // Symbol name to symbol_ptr map. 
symbol_map symbols_; - /// tests if symbol is defined + bool generate_initial_api(); + bool generate_current_api(); + bool generate_state_api(); + void add_variables_to_symbols(); + bool has_symbol(const std::string& name) { return symbols_.find(name) != symbols_.end(); } - /// tests if symbol is defined + bool has_symbol(const std::string& name, symbolKind kind) { auto s = symbols_.find(name); return s == symbols_.end() ? false : s->second->kind() == kind; @@ -147,13 +140,4 @@ private: // Perform semantic analysis on functions and procedures. // Returns the number of errors that were encountered. int semantic_func_proc(); - - // blocks - NeuronBlock neuron_block_; - StateBlock state_block_; - UnitsBlock units_block_; - ParameterBlock parameter_block_; - AssignedBlock assigned_block_; - - std::vector<WriteBack> write_backs_; }; diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 2c52fc4c2600cb4deaa958c48f48393a4ba0388b..d912a4415cbcb7fcf5413ca5d5cfb04aea991475 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -113,14 +113,14 @@ bool Parser::parse() { { auto p = parse_procedure(); if(!p) break; - module_->procedures().emplace_back(std::move(p)); + module_->add_callable(std::move(p)); } break; case tok::function : { auto f = parse_function(); if(!f) break; - module_->functions().emplace_back(std::move(f)); + module_->add_callable(std::move(f)); } break; default : diff --git a/modcc/backends/avx2.hpp b/modcc/printer/backends/avx2.hpp similarity index 100% rename from modcc/backends/avx2.hpp rename to modcc/printer/backends/avx2.hpp diff --git a/modcc/backends/avx512.hpp b/modcc/printer/backends/avx512.hpp similarity index 99% rename from modcc/backends/avx512.hpp rename to modcc/printer/backends/avx512.hpp index 97e3ecbce14a376f07adc40444f70e9cf87f0872..2d8563497eaf52389ec2a310e773bffba2f52baa 100644 --- a/modcc/backends/avx512.hpp +++ b/modcc/printer/backends/avx512.hpp @@ -159,4 +159,4 @@ struct simd_intrinsics<simdKind::avx512> { } }; -} // namespace 
modcc; +} // namespace modcc diff --git a/modcc/backends/base.hpp b/modcc/printer/backends/base.hpp similarity index 99% rename from modcc/backends/base.hpp rename to modcc/printer/backends/base.hpp index 5149fe7904661693dc2d7b230cb595a2be8971d7..b0b9398e48e6673bc93c02dd5ee0f102e7f7c3c1 100644 --- a/modcc/backends/base.hpp +++ b/modcc/printer/backends/base.hpp @@ -7,6 +7,7 @@ #include <functional> #include <stdexcept> #include <string> +#include <type_traits> #include "token.hpp" #include "textbuffer.hpp" @@ -43,7 +44,6 @@ static operand_fn_t arg_emitter(const operand_fn_t& arg) { return arg; } - template<simdKind Arch> struct simd_intrinsics { static std::string emit_headers(); diff --git a/modcc/backends/simd.hpp b/modcc/printer/backends/simd.hpp similarity index 100% rename from modcc/backends/simd.hpp rename to modcc/printer/backends/simd.hpp diff --git a/modcc/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp similarity index 73% rename from modcc/cexpr_emit.cpp rename to modcc/printer/cexpr_emit.cpp index cb5ee70dd12c7b147e63ffba004e27743bb89242..49bf4974a61ee84cdd598e421b1881a96e43a426 100644 --- a/modcc/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -1,8 +1,29 @@ +#include <cmath> +#include <iomanip> #include <ostream> #include <unordered_map> #include "cexpr_emit.hpp" #include "error.hpp" +#include "lexer.hpp" +#include "io/ostream_wrappers.hpp" +#include "io/prefixbuf.hpp" + +std::ostream& operator<<(std::ostream& out, as_c_double wrap) { + bool neg = std::signbit(wrap.value); + + switch (std::fpclassify(wrap.value)) { + case FP_INFINITE: + return out << (neg? "-": "") << "INFINITY"; + case FP_NAN: + return out << "NAN"; + case FP_ZERO: + return out << (neg? 
"-0.": "0."); + default: + return out << + (std::stringstream{} << io::classic << std::setprecision(17) << wrap.value).rdbuf(); + } +} void CExprEmitter::emit_as_call(const char* sub, Expression* e) { out_ << sub << '('; @@ -19,7 +40,7 @@ void CExprEmitter::emit_as_call(const char* sub, Expression* e1, Expression* e2) } void CExprEmitter::visit(NumberExpression* e) { - out_ << " " << e->value(); + out_ << " " << as_c_double(e->value()); } void CExprEmitter::visit(UnaryExpression* e) { @@ -61,7 +82,7 @@ void CExprEmitter::visit(AssignmentExpression* e) { } void CExprEmitter::visit(PowBinaryExpression* e) { - emit_as_call("std::pow", e->lhs(), e->rhs()); + emit_as_call("pow", e->lhs(), e->rhs()); } void CExprEmitter::visit(BinaryExpression* e) { @@ -120,3 +141,23 @@ void CExprEmitter::visit(BinaryExpression* e) { emit_as_call(op_spelling, lhs, rhs); } } + +void CExprEmitter::visit(IfExpression* e) { + out_ << "if ("; + e->condition()->accept(this); + out_ << ") {\n" << io::indent; + e->true_branch()->accept(this); + out_ << io::popindent << "}\n"; + + if (auto fb = e->false_branch()) { + out_ << "else "; + if (fb->is_if()) { + fb->accept(this); + } + else { + out_ << "{\n" << io::indent; + fb->accept(this); + out_ << io::popindent << "}\n"; + } + } +} diff --git a/modcc/cexpr_emit.hpp b/modcc/printer/cexpr_emit.hpp similarity index 54% rename from modcc/cexpr_emit.hpp rename to modcc/printer/cexpr_emit.hpp index e52c956b85298ee4a33c3e08dc2bd235aa3f2dd2..099e3ed5f397eb35a0afb963e37ca37b1eae8395 100644 --- a/modcc/cexpr_emit.hpp +++ b/modcc/printer/cexpr_emit.hpp @@ -1,10 +1,12 @@ #pragma once +#include <iosfwd> + #include "expression.hpp" #include "visitor.hpp" // Common functionality for generating source from binary expressions -// as C expressions. +// and conditional structures with C syntax. 
class CExprEmitter: public Visitor { public: @@ -14,11 +16,12 @@ public: void visit(Expression* e) override { e->accept(fallback_); } - void visit(UnaryExpression *e) override; - void visit(BinaryExpression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; void visit(AssignmentExpression *e) override; - void visit(PowBinaryExpression *e) override; - void visit(NumberExpression *e) override; + void visit(PowBinaryExpression *e) override; + void visit(NumberExpression *e) override; + void visit(IfExpression *e) override; protected: std::ostream& out_; @@ -29,6 +32,14 @@ protected: }; inline void cexpr_emit(Expression* e, std::ostream& out, Visitor* fallback) { - CExprEmitter renderer(out, fallback); - e->accept(&renderer); + CExprEmitter emitter(out, fallback); + e->accept(&emitter); } + +// Helper for formatting of double-valued numeric constants. +struct as_c_double { + double value; + as_c_double(double value): value(value) {} +}; + +std::ostream& operator<<(std::ostream&, as_c_double); diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b17184de50887a1d650f757ec72a98597fd4c977 --- /dev/null +++ b/modcc/printer/cprinter.cpp @@ -0,0 +1,618 @@ +#include <cmath> +#include <iostream> +#include <string> +#include <unordered_set> + +#include "expression.hpp" +#include "io/ostream_wrappers.hpp" +#include "io/prefixbuf.hpp" +#include "printer/cexpr_emit.hpp" +#include "printer/cprinter.hpp" +#include "printer/printerutil.hpp" + +using io::indent; +using io::popindent; +using io::quote; + + +void emit_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); +void emit_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); + +void emit_api_body(std::ostream&, APIMethod*); +void emit_simd_api_body(std::ostream&, APIMethod*, moduleKind); + +struct cprint { + Expression* 
expr_; + explicit cprint(Expression* expr): expr_(expr) {} + + friend std::ostream& operator<<(std::ostream& out, const cprint& w) { + CPrinter printer(out); + return w.expr_->accept(&printer), out; + } +}; + +struct simdprint { + Expression* expr_; + explicit simdprint(Expression* expr): expr_(expr) {} + + friend std::ostream& operator<<(std::ostream& out, const simdprint& w) { + SimdPrinter printer(out); + return w.expr_->accept(&printer), out; + } +}; + +static std::string ion_state_field(std::string ion_name) { + return "ion_"+ion_name+"_"; +} + +static std::string ion_state_index(std::string ion_name) { + return "ion_"+ion_name+"_index_"; +} + +std::string emit_cpp_source(const Module& module_, const std::string& ns, simd_spec simd) { + std::string name = module_.module_name(); + std::string class_name = "mechanism_cpu_"+name; + auto ns_components = namespace_components(ns); + + NetReceiveExpression* net_receive = find_net_receive(module_); + APIMethod* init_api = find_api_method(module_, "nrn_init"); + APIMethod* state_api = find_api_method(module_, "nrn_state"); + APIMethod* current_api = find_api_method(module_, "nrn_current"); + APIMethod* write_ions_api = find_api_method(module_, "write_ions"); + + bool with_simd = simd.abi!=simd_spec::none; + + // init_api, state_api, current_api methods are mandatory: + + assert_has_scope(init_api, "nrn_init"); + assert_has_scope(state_api, "nrn_state"); + assert_has_scope(current_api, "nrn_current"); + + auto vars = local_module_variables(module_); + auto ion_deps = module_.ion_deps(); + std::string fingerprint = "<placeholder>"; + + io::pfxstringstream out; + + out << + "#include <cmath>\n" + "#include <cstddef>\n" + "#include <memory>\n" + "#include <" << arb_header_prefix() << "backends/multicore/mechanism.hpp>\n" + "#include <" << arb_header_prefix() << "math.hpp>\n"; + + if (with_simd) { + out << "#include <" << arb_header_prefix() << "simd/simd.hpp>\n"; + } + + out << + "\n" << 
namespace_declaration_open(ns_components) << + "\n" + "using backend = ::arb::multicore::backend;\n" + "using base = ::arb::multicore::mechanism;\n" + "using value_type = base::value_type;\n" + "using size_type = base::size_type;\n" + "using index_type = base::index_type;\n" + "using ::std::abs;\n" + "using ::std::cos;\n" + "using ::std::exp;\n" + "using ::arb::math::exprelr;\n" + "using ::std::log;\n" + "using ::arb::math::max;\n" + "using ::arb::math::min;\n" + "using ::std::pow;\n" + "using ::std::sin;\n" + "\n"; + + if (with_simd) { + out << + "namespace S = ::arb::simd;\n" + "static constexpr unsigned simd_width_ = "; + + if (!simd.width) { + out << "S::simd_abi::native_width<fvm_value_type>::value;\n"; + } + else { + out << simd.width << ";\n"; + } + + std::string abi = "S::simd_abi::"; + switch (simd.abi) { + case simd_spec::avx: abi += "avx"; break; + case simd_spec::avx2: abi += "avx2"; break; + case simd_spec::avx512: abi += "avx512"; break; + case simd_spec::native: abi += "native"; break; + default: + abi += "default_abi"; break; + } + + out << + "using simd_value = S::simd<fvm_value_type, simd_width_, " << abi << ">;\n" + "using simd_index = S::simd<fvm_index_type, simd_width_, " << abi << ">;\n" + "\n"; + } + + out << + "class " << class_name << ": public base {\n" + "public:\n" << indent << + "const mechanism_fingerprint& fingerprint() const override {\n" << indent << + "static mechanism_fingerprint hash = " << quote(fingerprint) << ";\n" + "return hash;\n" << popindent << + "}\n" + "std::string internal_name() const override { return " << quote(name) << "; }\n" + "mechanismKind kind() const override { return " << module_kind_str(module_) << "; }\n" + "mechanism_ptr clone() const override { return mechanism_ptr(new " << class_name << "()); }\n" + "\n" + "void nrn_init() override;\n" + "void nrn_state() override;\n" + "void nrn_current() override;\n" + "void write_ions() override;\n"; + + net_receive && out << + "void 
deliver_events(deliverable_event_stream::state events) override;\n" + "void net_receive(int i_, value_type weight);\n"; + + out << + "\n" << popindent << + "protected:\n" << indent << + "using ionKind = ::arb::ionKind;\n\n" + "std::size_t object_sizeof() const override { return sizeof(*this); }\n"; + + io::separator sep("\n", ",\n"); + if (!vars.scalars.empty()) { + out << + "mechanism_global_table global_table() override {\n" << indent << + "return {" << indent; + + for (const auto& scalar: vars.scalars) { + auto memb = scalar->name(); + out << sep << "{" << quote(memb) << ", &" << memb << "}"; + } + out << popindent << "\n};\n" << popindent << "}\n"; + } + + if (!vars.arrays.empty()) { + out << + "mechanism_field_table field_table() override {\n" << indent << + "return {" << indent; + + sep.reset(); + for (const auto& array: vars.arrays) { + auto memb = array->name(); + out << sep << "{" << quote(memb) << ", &" << memb << "}"; + } + out << popindent << "\n};" << popindent << "\n}\n"; + + out << + "mechanism_field_default_table field_default_table() override {\n" << indent << + "return {" << indent; + + sep.reset(); + for (const auto& array: vars.arrays) { + auto memb = array->name(); + auto dflt = array->value(); + if (!std::isnan(dflt)) { + out << sep << "{" << quote(memb) << ", " << as_c_double(dflt) << "}"; + } + } + out << popindent << "\n};" << popindent << "\n}\n"; + + } + + if (!ion_deps.empty()) { + out << + "mechanism_ion_state_table ion_state_table() override {\n" << indent << + "return {" << indent; + + sep.reset(); + for (const auto& dep: ion_deps) { + out << sep << "{ionKind::" << dep.name << ", &" << ion_state_field(dep.name) << "}"; + } + out << popindent << "\n};" << popindent << "\n}\n"; + + sep.reset(); + out << "mechanism_ion_index_table ion_index_table() override {\n" << indent << "return {" << indent; + for (const auto& dep: ion_deps) { + out << sep << "{ionKind::" << dep.name << ", &" << ion_state_index(dep.name) << "}"; + } + out << 
popindent << "\n};" << popindent << "\n}\n"; + } + + out << popindent << "\n" + "private:\n" << indent; + + for (const auto& scalar: vars.scalars) { + out << "value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; + } + for (const auto& array: vars.arrays) { + out << "value_type* " << array->name() << ";\n"; + } + for (const auto& dep: ion_deps) { + out << "ion_state_view " << ion_state_field(dep.name) << ";\n"; + out << "iarray " << ion_state_index(dep.name) << ";\n"; + } + + for (auto proc: normal_procedures(module_)) { + emit_procedure_proto(out, proc); + out << ";\n"; + if (with_simd) { + emit_simd_procedure_proto(out, proc); + out << ";\n"; + } + } + + out << popindent << + "};\n\n" + "template <typename B> ::arb::concrete_mech_ptr<B> make_mechanism_" <<name << "();\n" + "template <> ::arb::concrete_mech_ptr<backend> make_mechanism_" << name << "<backend>() {\n" << indent << + "return ::arb::concrete_mech_ptr<backend>(new " << class_name << "());\n" << popindent << + "}\n\n"; + + // Nrn methods: + + net_receive && out << + "void " << class_name << "::deliver_events(deliverable_event_stream::state events) {\n" << indent << + "auto ncell = events.n_streams();\n" + "for (size_type c = 0; c<ncell; ++c) {\n" << indent << + "auto begin = events.begin_marked(c);\n" + "auto end = events.end_marked(c);\n" + "for (auto p = begin; p<end; ++p) {\n" << indent << + "if (p->mech_id==mechanism_id_) net_receive(p->mech_index, p->weight);\n" << popindent << + "}\n" << popindent << + "}\n" << popindent << + "}\n" + "\n" + "void " << class_name << "::net_receive(int i_, value_type weight) {\n" << indent << + cprint(net_receive->body()) << popindent << + "}\n\n"; + + auto emit_body = [&](APIMethod *p) { + if (with_simd) { + emit_simd_api_body(out, p, module_.kind()); + } + else { + emit_api_body(out, p); + } + }; + + out << "void " << class_name << "::nrn_init() {\n" << indent; + emit_body(init_api); + out << popindent << "}\n\n"; + + out << "void " << 
class_name << "::nrn_state() {\n" << indent; + emit_body(state_api); + out << popindent << "}\n\n"; + + out << "void " << class_name << "::nrn_current() {\n" << indent; + emit_body(current_api); + out << popindent << "}\n\n"; + + out << "void " << class_name << "::write_ions() {\n" << indent; + emit_body(write_ions_api); + out << popindent << "}\n\n"; + + // Mechanism procedures + + for (auto proc: normal_procedures(module_)) { + emit_procedure_proto(out, proc, class_name); + out << + " {\n" << indent << + cprint(proc->body()) << popindent << + "}\n\n"; + + if (with_simd) { + emit_simd_procedure_proto(out, proc, class_name); + out << + " {\n" << indent << + simdprint(proc->body()) << popindent << + "}\n\n"; + } + } + + out << namespace_declaration_close(ns_components); + return out.str(); +} + +struct indexed_variable_info { + std::string data_var; + std::string index_var; +}; + +indexed_variable_info decode_indexed_variable(IndexedVariable* sym) { + std::string data_var, ion_pfx; + std::string index_var = "node_index_"; + + if (sym->is_ion()) { + ion_pfx = "ion_"+to_string(sym->ion_channel())+"_"; + index_var = ion_pfx+"index_"; + } + + switch (sym->data_source()) { + case sourceKind::voltage: + data_var="vec_v_"; + break; + case sourceKind::current: + data_var="vec_i_"; + break; + case sourceKind::dt: + data_var="vec_dt_"; + break; + case sourceKind::ion_current: + data_var=ion_pfx+".current_density"; + break; + case sourceKind::ion_revpot: + data_var=ion_pfx+".reversal_potential"; + break; + case sourceKind::ion_iconc: + data_var=ion_pfx+".internal_concentration"; + break; + case sourceKind::ion_econc: + data_var=ion_pfx+".external_concentration"; + break; + default: + throw compiler_exception("unrecognized indexed data source", sym->location()); + } + + return {data_var, index_var}; +} + +// Scalar printing: + +void CPrinter::visit(IdentifierExpression *e) { + e->symbol()->accept(this); +} + +void CPrinter::visit(LocalVariable* sym) { + out_ << sym->name(); +} 
+ +void CPrinter::visit(VariableExpression *sym) { + out_ << sym->name() << (sym->is_range()? "[i_]": ""); +} + +void CPrinter::visit(IndexedVariable *sym) { + indexed_variable_info v = decode_indexed_variable(sym); + out_ << v.data_var << "[" << v.index_var << "[i_]]"; +} + +void CPrinter::visit(CallExpression* e) { + out_ << e->name() << "(i_"; + for (auto& arg: e->args()) { + out_ << ", "; + arg->accept(this); + } + out_ << ")"; +} + +void CPrinter::visit(BlockExpression* block) { + // Only include local declarations in outer-most block. + if (!block->is_nested()) { + auto locals = pure_locals(block->scope()); + if (!locals.empty()) { + out_ << "value_type "; + io::separator sep(", "); + for (auto local: locals) { + out_ << sep << local->name(); + } + out_ << ";\n"; + } + } + + for (auto& stmt: block->statements()) { + if (!stmt->is_local_declaration()) { + stmt->accept(this); + out_ << (stmt->is_if()? "": ";\n"); + } + } +} + +void emit_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { + out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(int i_"; + for (auto& arg: e->args()) { + out << ", value_type " << arg->is_argument()->name(); + } + out << ")"; +} + +void emit_state_read(std::ostream& out, LocalVariable* local) { + out << "value_type " << cprint(local) << " = "; + + if (local->is_read()) { + out << cprint(local->external_variable()) << ";\n"; + } + else { + out << "0;\n"; + } +} + +void emit_state_update(std::ostream& out, Symbol* from, IndexedVariable* external) { + if (!external->is_write()) return; + + const char* op = external->op()==tok::plus? 
" += ": " -= "; + out << cprint(external) << op << from->name() << ";\n"; +} + +void emit_api_body(std::ostream& out, APIMethod* method) { + auto body = method->body(); + auto indexed_vars = indexed_locals(method->scope()); + + if (!body->statements().empty()) { + out << + "int n_ = width_;\n" + "for (int i_ = 0; i_ < n_; ++i_) {\n" << indent; + + for (auto& sym: indexed_vars) { + emit_state_read(out, sym); + } + out << cprint(body); + + for (auto& sym: indexed_vars) { + emit_state_update(out, sym, sym->external_variable()); + } + out << popindent << "}\n"; + } +} + +// SIMD printing: + +std::string index_i_name(const std::string& index_var) { + return index_var+"i_"; +} + +std::string index_constraint_name(const std::string& index_var) { + return index_var+"constraint_"; +} + + +void SimdPrinter::visit(IdentifierExpression *e) { + e->symbol()->accept(this); +} + +void SimdPrinter::visit(LocalVariable* sym) { + out_ << sym->name(); +} + +void SimdPrinter::visit(VariableExpression *sym) { + if (sym->is_range()) { + out_ << "simd_value(" << sym->name() << "+i_)"; + } + else { + out_ << sym->name(); + } +} + +void SimdPrinter::visit(AssignmentExpression* e) { + if (!e->lhs() || !e->lhs()->is_identifier() || !e->lhs()->is_identifier()->symbol()) { + throw compiler_exception("Expect symbol on lhs of assignment: "+e->to_string()); + } + + Symbol* lhs = e->lhs()->is_identifier()->symbol(); + + if (lhs->is_variable() && lhs->is_variable()->is_range()) { + out_ << "simd_value("; + e->rhs()->accept(this); + out_ << ").copy_to(" << lhs->name() << "+i_)"; + } + else { + out_ << lhs->name() << " = "; + e->rhs()->accept(this); + } +} + +void SimdPrinter::visit(IndexedVariable *sym) { + indexed_variable_info v = decode_indexed_variable(sym); + out_ << "S::indirect(" << v.data_var + << ", " << index_i_name(v.index_var) + << ", " << index_constraint_name(v.index_var) << ")"; +} + +void SimdPrinter::visit(CallExpression* e) { + out_ << e->name() << "(i_"; + for (auto& arg: 
e->args()) { + out_ << ", "; + arg->accept(this); + } + out_ << ")"; +} + +void SimdPrinter::visit(BlockExpression* block) { + // Only include local declarations in outer-most block. + if (!block->is_nested()) { + auto locals = pure_locals(block->scope()); + if (!locals.empty()) { + out_ << "simd_value "; + io::separator sep(", "); + for (auto local: locals) { + out_ << sep << local->name(); + } + out_ << ";\n"; + } + } + + for (auto& stmt: block->statements()) { + if (stmt->is_if()) { + throw compiler_exception("Conditionals not yet supported in SIMD printer: "+stmt->to_string()); + } + if (!stmt->is_local_declaration()) { + stmt->accept(this); + out_ << ";\n"; + } + } +} + +void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { + out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(int i_"; + for (auto& arg: e->args()) { + out << ", const simd_value& " << arg->is_argument()->name(); + } + out << ")"; +} + +void emit_simd_state_read(std::ostream& out, LocalVariable* local) { + out << "simd_value " << local->name(); + + if (local->is_read()) { + out << "(" << simdprint(local->external_variable()) << ");\n"; + } + else { + out << " = 0;\n"; + } +} + +void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* external) { + if (!external->is_write()) return; + + const char* op = external->op()==tok::plus? " += ": " -= "; + out << simdprint(external) << op << from->name() << ";\n"; +} + + +void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_kind) { + auto body = method->body(); + auto indexed_vars = indexed_locals(method->scope()); + + std::unordered_set<std::string> indices; + for (auto& sym: indexed_vars) { + indices.insert(decode_indexed_variable(sym->external_variable()).index_var); + } + + // Note: expect to make index constraints non-constant for point mechanisms as + // an optimization in the near future. 
+ + // Another note: (TODO) can't actually use index_constraint::independent + // for density mechanisms because of collisions in the padded part of + // the indices. Work-arounds exist, but not yet implemented. + + + std::string common_constraint = module_kind==moduleKind::density? + //"S::index_constraint::independent": + "S::index_constraint::none": + "S::index_constraint::none"; + + for (auto& index: indices) { + out << "simd_index " << index_i_name(index) << ";\n"; + out << "constexpr S::index_constraint " << index_constraint_name(index) + << " = " << common_constraint << ";\n"; + } + + if (!body->statements().empty()) { + out << + "int n_ = width_;\n" + "for (int i_ = 0; i_ < n_; i_ += simd_width_) {\n" << indent; + + for (auto& index: indices) { + out << index_i_name(index) << ".copy_from(" << index << ".data()+i_);\n"; + } + + for (auto& sym: indexed_vars) { + emit_simd_state_read(out, sym); + } + + out << simdprint(body); + + for (auto& sym: indexed_vars) { + emit_simd_state_update(out, sym, sym->external_variable()); + } + out << popindent << "}\n"; + } +} diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a92d014d574986feaed61aa16e66bde6b0e557d0 --- /dev/null +++ b/modcc/printer/cprinter.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include <iosfwd> +#include <string> + +#include "module.hpp" +#include "visitor.hpp" + +#include "printer/cexpr_emit.hpp" +#include "printer/simd.hpp" + +std::string emit_cpp_source(const Module& m, const std::string& ns, simd_spec simd); + +// CPrinter and SimdPrinter visitors exposed in header for testing purposes only. 
+ +class CPrinter: public Visitor { /* Visitor that writes expressions to out_ as scalar C++ text; the out-of-line visit() overloads are defined in cprinter.cpp. */ +public: + CPrinter(std::ostream& out): out_(out) {} /* Stores a reference to the stream, so `out` must outlive the printer. */ + + void visit(Expression* e) override { /* Fallback for expression types that reach this generic overload: reports them as untranslatable. */ + throw compiler_exception("CPrinter cannot translate expression "+e->to_string()); + } + + void visit(BlockExpression*) override; + void visit(CallExpression*) override; + void visit(IdentifierExpression*) override; + void visit(VariableExpression*) override; + void visit(LocalVariable*) override; + void visit(IndexedVariable*) override; + + // Delegate low-level emits to cexpr_emit: + void visit(NumberExpression* e) override { cexpr_emit(e, out_, this); } + void visit(UnaryExpression* e) override { cexpr_emit(e, out_, this); } + void visit(BinaryExpression* e) override { cexpr_emit(e, out_, this); } + void visit(IfExpression* e) override { cexpr_emit(e, out_, this); } + +private: + std::ostream& out_; +}; + +class SimdPrinter: public Visitor { /* SIMD counterpart of CPrinter, also writing to out_. Unlike CPrinter it overrides AssignmentExpression and declares no IfExpression overload. */ +public: + SimdPrinter(std::ostream& out): out_(out) {} /* Stores a reference to the stream, so `out` must outlive the printer. */ + + void visit(Expression* e) override { /* Fallback for expression types that reach this generic overload: reports them as untranslatable. */ + throw compiler_exception("SimdPrinter cannot translate expression "+e->to_string()); + } + + void visit(BlockExpression*) override; + void visit(CallExpression*) override; + void visit(IdentifierExpression*) override; + void visit(VariableExpression*) override; + void visit(LocalVariable*) override; + void visit(IndexedVariable*) override; + void visit(AssignmentExpression*) override; + + void visit(NumberExpression* e) override { cexpr_emit(e, out_, this); } /* Delegated to cexpr_emit with this printer as fallback, as are the next two overloads. */ + void visit(UnaryExpression* e) override { cexpr_emit(e, out_, this); } + void visit(BinaryExpression* e) override { cexpr_emit(e, out_, this); } + +private: + std::ostream& out_; +}; diff --git a/modcc/printer/cudaprinter.cpp b/modcc/printer/cudaprinter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1272015a723d58def758ae51498a2397e2dc7fa --- /dev/null +++ b/modcc/printer/cudaprinter.cpp @@ -0,0 +1,255 @@ +#include <cmath> +#include <iostream> +#include <string> +#include <unordered_set> + +#include "expression.hpp" 
+#include "io/ostream_wrappers.hpp" +#include "io/prefixbuf.hpp" +#include "printer/cexpr_emit.hpp" +#include "printer/printerutil.hpp" + +using io::indent; +using io::popindent; +using io::quote; + +// Emit stream_state alias, parameter pack struct. +void emit_common_defs(std::ostream&, const Module& module_); + +std::string make_class_name(const std::string& module_name) { + return "mechanism_gpu_"+module_name; +} + +std::string make_ppack_name(const std::string& module_name) { + return make_class_name(module_name)+"_pp_"; +} + +static std::string ion_state_field(std::string ion_name) { + return "ion_"+ion_name+"_"; +} + +static std::string ion_state_index(std::string ion_name) { + return "ion_"+ion_name+"_index_"; +} + +std::string emit_cuda_cpp_source(const Module& module_, const std::string& ns) { + std::string name = module_.module_name(); + std::string class_name = make_class_name(name); + std::string ppack_name = make_ppack_name(name); + auto ns_components = namespace_components(ns); + + NetReceiveExpression* net_receive = find_net_receive(module_); + + auto vars = local_module_variables(module_); + auto ion_deps = module_.ion_deps(); + + std::string fingerprint = "<placeholder>"; + + io::pfxstringstream out; + + net_receive && out << + "#include <" << arb_header_prefix() << "backends/event.hpp>\n" + "#include <" << arb_header_prefix() << "backends/multi_event_stream_state.hpp>\n"; + + out << + "#include <" << arb_header_prefix() << "backends/gpu/mechanism.hpp>\n" + "#include <" << arb_header_prefix() << "backends/gpu/mechanism_ppack_base.hpp>\n" + "\n" << namespace_declaration_open(ns_components) << + "\n"; + + emit_common_defs(out, module_); + + out << + "void " << class_name << "_nrn_init_(int, " << ppack_name << "&);\n" + "void " << class_name << "_nrn_state_(int, " << ppack_name << "&);\n" + "void " << class_name << "_nrn_current_(int, " << ppack_name << "&);\n" + "void " << class_name << "_write_ions_(int, " << ppack_name << "&);\n"; + + net_receive && 
out << + "void " << class_name << "_deliver_events_(int, " << ppack_name << "&, deliverable_event_stream_state events);\n"; + + out << + "\n" + "class " << class_name << ": public ::arb::gpu::mechanism {\n" + "public:\n" << indent << + "const mechanism_fingerprint& fingerprint() const override {\n" << indent << + "static mechanism_fingerprint hash = " << quote(fingerprint) << ";\n" + "return hash;\n" << popindent << + "}\n\n" + "std::string internal_name() const override { return " << quote(name) << "; }\n" + "mechanismKind kind() const override { return " << module_kind_str(module_) << "; }\n" + "mechanism_ptr clone() const override { return mechanism_ptr(new " << class_name << "()); }\n" + "\n" + "void nrn_init() override {\n" << indent << + class_name << "_nrn_init_(width_, pp_);\n" << popindent << + "}\n\n" + "void nrn_state() override {\n" << indent << + class_name << "_nrn_state_(width_, pp_);\n" << popindent << + "}\n\n" + "void nrn_current() override {\n" << indent << + class_name << "_nrn_current_(width_, pp_);\n" << popindent << + "}\n\n" + "void write_ions() override {\n" << indent << + class_name << "_write_ions_(width_, pp_);\n" << popindent << + "}\n\n"; + + net_receive && out << + "void deliver_events(deliverable_event_stream_state events) override {\n" << indent << + class_name << "_deliver_events_(width_, pp_, events);\n" << popindent << + "}\n\n"; + + out << popindent << + "protected:\n" << indent << + "using ionKind = ::arb::ionKind;\n\n" + "std::size_t object_sizeof() const override { return sizeof(*this); }\n" + "::arb::gpu::mechanism_ppack_base* ppack_ptr() { return &pp_; }\n\n"; + + io::separator sep("\n", ",\n"); + if (!vars.scalars.empty()) { + out << + "mechanism_global_table global_table() override {\n" << indent << + "return {" << indent; + + for (const auto& scalar: vars.scalars) { + auto memb = scalar->name(); + out << sep << "{" << quote(memb) << ", &pp_." 
<< memb << "}"; + } + out << popindent << "\n};\n" << popindent << "}\n"; + } + + if (!vars.arrays.empty()) { + out << + "mechanism_field_table field_table() override {\n" << indent << + "return {" << indent; + + sep.reset(); + for (const auto& array: vars.arrays) { + auto memb = array->name(); + out << sep << "{" << quote(memb) << ", &pp_." << memb << "}"; + } + out << popindent << "\n};" << popindent << "\n}\n"; + + out << + "mechanism_field_default_table field_default_table() override {\n" << indent << + "return {" << indent; + + sep.reset(); + for (const auto& array: vars.arrays) { + auto memb = array->name(); + auto dflt = array->value(); + if (!std::isnan(dflt)) { + out << sep << "{" << quote(memb) << ", " << as_c_double(dflt) << "}"; + } + } + out << popindent << "\n};" << popindent << "\n}\n"; + + } + + if (!ion_deps.empty()) { + out << + "mechanism_ion_state_table ion_state_table() override {\n" << indent << + "return {" << indent; + + sep.reset(); + for (const auto& dep: ion_deps) { + out << sep << "{ionKind::" << dep.name << ", &pp_." << ion_state_field(dep.name) << "}"; + } + out << popindent << "\n};" << popindent << "\n}\n"; + + sep.reset(); + out << "mechanism_ion_index_table ion_index_table() override {\n" << indent << "return {" << indent; + for (const auto& dep: ion_deps) { + out << sep << "{ionKind::" << dep.name << ", &pp_." 
<< ion_state_index(dep.name) << "}"; + } + out << popindent << "\n};" << popindent << "\n}\n"; + } + + out << popindent << "\n" + "private:\n" << indent << + make_ppack_name(name) << " pp_;\n" << popindent << + "};\n\n" + "template <typename B> ::arb::concrete_mech_ptr<B> make_mechanism_" << name << "();\n" + "template <> ::arb::concrete_mech_ptr<::arb::gpu::backend> make_mechanism_" << name << "<::arb::gpu::backend>() {\n" << indent << + "return ::arb::concrete_mech_ptr<::arb::gpu::backend>(new " << class_name << "());\n" << popindent << + "}\n\n"; + + out << namespace_declaration_close(ns_components); + return out.str(); +} + +std::string emit_cuda_cu_source(const Module& module_, const std::string& ns) { + std::string name = module_.module_name(); + std::string class_name = make_class_name(name); + std::string ppack_name = make_ppack_name(name); + auto ns_components = namespace_components(ns); + + NetReceiveExpression* net_receive = find_net_receive(module_); + APIMethod* init_api = find_api_method(module_, "nrn_init"); + APIMethod* state_api = find_api_method(module_, "nrn_state"); + APIMethod* current_api = find_api_method(module_, "nrn_current"); + APIMethod* write_ions_api = find_api_method(module_, "write_ions"); + + assert_has_scope(init_api, "nrn_init"); + assert_has_scope(state_api, "nrn_state"); + assert_has_scope(current_api, "nrn_current"); + + io::pfxstringstream out; + + out << + "#include <" << arb_header_prefix() << "backends/event.hpp>\n" + "#include <" << arb_header_prefix() << "backends/multi_event_stream_state.hpp>\n" + "#include <" << arb_header_prefix() << "backends/gpu/mechanism_ppack_base.hpp>\n" + "\n" << namespace_declaration_open(ns_components) << + "\n"; + + emit_common_defs(out, module_); + + out << + "using value_type = ::arb::gpu::mechanism_ppack_base::value_type;\n" + "using index_type = ::arb::gpu::mechanism_ppack_base::index_type;\n" + "\n"; + + out << + "void " << class_name << "_nrn_init_(int, " << ppack_name << "&) {};\n" + 
"void " << class_name << "_nrn_state_(int, " << ppack_name << "&) {};\n" + "void " << class_name << "_nrn_current_(int, " << ppack_name << "&) {};\n" + "void " << class_name << "_write_ions_(int, " << ppack_name << "&) {};\n"; + + net_receive && out << + "void " << class_name << "_deliver_events_(int, " << ppack_name << "&, deliverable_event_stream_state events) {}\n"; + + (void)write_ions_api; + + out << namespace_declaration_close(ns_components); + return out.str(); +} + +void emit_common_defs(std::ostream& out, const Module& module_) { + std::string class_name = make_class_name(module_.module_name()); + std::string ppack_name = make_ppack_name(module_.module_name()); + + auto vars = local_module_variables(module_); + auto ion_deps = module_.ion_deps(); + + find_net_receive(module_) && out << + "using deliverable_event_stream_state = ::arb::multi_event_stream_state<::arb::deliverable_event_data>;\n" + "\n"; + + out << + "struct " << ppack_name << ": ::arb::gpu::mechanism_ppack_base {\n" << indent; + + for (const auto& scalar: vars.scalars) { + out << "value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; + } + for (const auto& array: vars.arrays) { + out << "value_type* " << array->name() << ";\n"; + } + for (const auto& dep: ion_deps) { + out << "ion_state_view " << ion_state_field(dep.name) << ";\n"; + out << "const index_type* " << ion_state_index(dep.name) << ";\n"; + } + + out << popindent << "};\n\n"; +} + + diff --git a/modcc/printer/cudaprinter.hpp b/modcc/printer/cudaprinter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..767a4eae952c1bc33222c73a87f3e87a79a36e98 --- /dev/null +++ b/modcc/printer/cudaprinter.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include <string> + +#include "module.hpp" + +std::string emit_cuda_cpp_source(const Module& m, const std::string& ns); +std::string emit_cuda_cu_source(const Module& m, const std::string& ns); diff --git a/modcc/printer/infoprinter.cpp 
b/modcc/printer/infoprinter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afeee15953f690424de1c53784c2b62e0cfa9b54 --- /dev/null +++ b/modcc/printer/infoprinter.cpp @@ -0,0 +1,133 @@ +#include <ostream> +#include <set> +#include <string> + +#include "blocks.hpp" +#include "infoprinter.hpp" +#include "module.hpp" +#include "printerutil.hpp" + +#include "io/ostream_wrappers.hpp" +#include "io/prefixbuf.hpp" + +using io::quote; + +struct id_field_info { /* Stream wrapper: an Id plus the spec kind name to print for it. Holds a reference, so the Id must outlive the wrapper. */ + id_field_info(const Id& id, const char* kind): id(id), kind(kind) {} + + const Id& id; + const char* kind; +}; + +std::ostream& operator<<(std::ostream& out, const id_field_info& wrap) { /* Writes one table entry: {<quoted name>, spec(spec::<kind>, <quoted unit>, <value or 0>[, <lo>,<hi>])}. */ + const Id& id = wrap.id; + + out << "{" << quote(id.name()) << ", " + << "spec(spec::" << wrap.kind << ", " << quote(id.unit_string()) << ", " + << (id.has_value()? id.value: "0"); /* Ids without a value are written with 0 in the value position. */ + + if (id.has_range()) { + out << ", " << id.range.first.spelling << "," << id.range.second.spelling; + } + + out << ")}"; + return out; +} + +struct ion_dep_info { /* Stream wrapper around an IonDep. Holds a reference, so the IonDep must outlive the wrapper. */ + ion_dep_info(const IonDep& ion): ion(ion) {} + + const IonDep& ion; +}; + +std::ostream& operator<<(std::ostream& out, const ion_dep_info& wrap) { /* Writes {ionKind::<name>, {<writes_concentration_int>, <writes_concentration_ext>}} with both flags spelled as true/false. */ + const char* boolalpha[2] = {"false", "true"}; + const IonDep& ion = wrap.ion; + + return out << "{ionKind::" << ion.name << ", {" + << boolalpha[ion.writes_concentration_int()] << ", " + << boolalpha[ion.writes_concentration_ext()] << "}}"; +} + +std::string build_info_header(const Module& m, const std::string& qual_namespace) { /* Returns the text of the generated info header for module m, with its declarations wrapped in the namespace named by qual_namespace. */ + using io::indent; + using io::popindent; + + // TODO: When arbor headers are moved into a named hierarchy, change this prefix. 
+ const char* arb_header_prefix = ""; + + std::string name = m.module_name(); + auto ids = public_variable_ids(m); + auto ns_components = namespace_components(qual_namespace); + + io::pfxstringstream out; + + out << + "#pragma once\n" + "#include <memory>\n" + "\n" + "#include <" << arb_header_prefix << "mechanism.hpp>\n" + "#include <" << arb_header_prefix << "mechinfo.hpp>\n" + "\n" + << namespace_declaration_open(ns_components) << + "\n" + "template <typename Backend>\n" + "::arb::concrete_mech_ptr<Backend> make_mechanism_" << name << "();\n" + "\n" + "inline const ::arb::mechanism_info& mechanism_" << name << "_info() {\n" + << indent << + "using ::arb::ionKind;\n" + "using spec = ::arb::mechanism_field_spec;\n" + "static ::arb::mechanism_info info = {\n" /* Fully qualified, matching the emitted return type, so the generated header also compiles when qual_namespace is not nested in arb. */ + << indent << + "// globals\n" + "{\n" + << indent; + + io::separator sep(",\n"); + for (const auto& id: ids.global_parameter_ids) { + out << sep << id_field_info(id, "global"); + } + + out << popindent << + "\n},\n// parameters\n{\n" + << indent; + + sep.reset(); + for (const auto& id: ids.range_parameter_ids) { + out << sep << id_field_info(id, "parameter"); + } + + out << popindent << + "\n},\n// state variables\n{\n" + << indent; + + sep.reset(); + for (const auto& id: ids.state_ids) { + out << sep << id_field_info(id, "state"); + } + + out << popindent << + "\n},\n// ion dependencies\n{\n" + << indent; + + sep.reset(); + for (const auto& ion: m.ion_deps()) { + out << sep << ion_dep_info(ion); + } + + std::string fingerprint = "<placeholder>"; + out << popindent << "\n" + "},\n" + "// fingerprint\n" << quote(fingerprint) << "\n" + << popindent << + "};\n" + "\n" + "return info;\n" + << popindent << + "}\n" + "\n" + << namespace_declaration_close(ns_components); + + return out.str(); +} diff --git a/modcc/printer/infoprinter.hpp b/modcc/printer/infoprinter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..24ee11fcf544c5a97492def5b91f32d0af53a2c1 --- /dev/null +++ 
b/modcc/printer/infoprinter.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include <string> + +#include "module.hpp" + +// Build header file comprising mechanism metadata +// and declarations of backend-specific mechanism implementations. + +std::string build_info_header(const Module& m, const std::string& qual_namespace); + diff --git a/modcc/printer/printerutil.cpp b/modcc/printer/printerutil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..17cd445dfd9e7f144a6c7598212dab3c70635e42 --- /dev/null +++ b/modcc/printer/printerutil.cpp @@ -0,0 +1,102 @@ +#include <regex> +#include <string> +#include <unordered_set> + +#include "expression.hpp" +#include "module.hpp" +#include "printerutil.hpp" + +std::vector<std::string> namespace_components(const std::string& ns) { + static std::regex ns_regex("([^:]+)(?:::|$)"); + + std::vector<std::string> components; + auto i = std::sregex_iterator(ns.begin(), ns.end(), ns_regex); + while (i!=std::sregex_iterator()) { + components.push_back(i++->str(1)); + } + + return components; +} + +std::vector<LocalVariable*> indexed_locals(scope_ptr scope) { + std::vector<LocalVariable*> vars; + for (auto& entry: scope->locals()) { + LocalVariable* local = entry.second->is_local_variable(); + if (local && local->is_indexed()) { + vars.push_back(local); + } + } + return vars; +} + +std::vector<LocalVariable*> pure_locals(scope_ptr scope) { + std::vector<LocalVariable*> vars; + for (auto& entry: scope->locals()) { + LocalVariable* local = entry.second->is_local_variable(); + if (local && !local->is_arg() && !local->is_indexed()) { + vars.push_back(local); + } + } + return vars; +} + +std::vector<ProcedureExpression*> normal_procedures(const Module& m) { + std::vector<ProcedureExpression*> procs; + + for (auto& sym: m.symbols()) { + auto proc = sym.second->is_procedure(); + if (proc && proc->kind()==procedureKind::normal && !proc->is_api_method() && !proc->is_net_receive()) { + procs.push_back(proc); + } + } + + return procs; +} 
+ +public_variable_ids_t public_variable_ids(const Module& m) { + public_variable_ids_t ids; + ids.state_ids = m.state_block().state_variables; + + std::unordered_set<std::string> range_varnames; + for (const auto& sym: m.symbols()) { + if (auto var = sym.second->is_variable()) { + if (var->is_range()) { + range_varnames.insert(var->name()); + } + } + } + + for (const Id& id: m.parameter_block().parameters) { + if (range_varnames.count(id.token.spelling)) { + ids.range_parameter_ids.push_back(id); + } + else { + ids.global_parameter_ids.push_back(id); + } + } + + return ids; +} + +module_variables_t local_module_variables(const Module& m) { + module_variables_t mv; + + for (auto& sym: m.symbols()) { + auto v = sym.second->is_variable(); + if (v && v->linkage()==linkageKind::local) { + (v->is_range()? mv.arrays: mv.scalars).push_back(v); + } + } + + return mv; +} + +APIMethod* find_api_method(const Module& m, const char* which) { + auto it = m.symbols().find(which); + return it==m.symbols().end()? nullptr: it->second->is_api_method(); +} + +NetReceiveExpression* find_net_receive(const Module& m) { + auto it = m.symbols().find("net_receive"); + return it==m.symbols().end()? nullptr: it->second->is_net_receive(); +} diff --git a/modcc/printer/printerutil.hpp b/modcc/printer/printerutil.hpp new file mode 100644 index 0000000000000000000000000000000000000000..49cf97fbb9a3cf664e80dd9f021933d289ac9aaf --- /dev/null +++ b/modcc/printer/printerutil.hpp @@ -0,0 +1,102 @@ +#pragma once + +// Convenience routines/helpers for source printers. 
+ +#include <ostream> +#include <string> +#include <vector> + +#include "blocks.hpp" +#include "error.hpp" +#include "expression.hpp" +#include "module.hpp" + +std::vector<std::string> namespace_components(const std::string& qualified_namespace); + +inline const char* arb_header_prefix() { + static const char* prefix = ""; + return prefix; +} + +struct namespace_declaration_open { + const std::vector<std::string>& ids; + namespace_declaration_open(const std::vector<std::string>& ids): ids(ids) {} + + friend std::ostream& operator<<(std::ostream& o, const namespace_declaration_open& n) { + for (auto& id: n.ids) { + o << "namespace " << id << " {\n"; + } + return o; + } +}; + +struct namespace_declaration_close { + const std::vector<std::string>& ids; + namespace_declaration_close(const std::vector<std::string>& ids): ids(ids) {} + + friend std::ostream& operator<<(std::ostream& o, const namespace_declaration_close& n) { + for (auto i = n.ids.rbegin(); i!=n.ids.rend(); ++i) { + o << "} // namespace " << *i << "\n"; + } + return o; + } +}; + +// Enum representation: + +inline const char* module_kind_str(const Module& m) { + return m.kind()==moduleKind::density? + "::arb::mechanismKind::density": + "::arb::mechanismKind::point"; +} + +// Check expression non-null and scoped, or else throw. + +inline void assert_has_scope(Expression* expr, const std::string& context) { + return + !expr? throw compiler_exception("missing expression for "+context): + !expr->scope()? throw compiler_exception("printer invoked before semantic pass for "+context): + void(); +} + + +// Scope query functions: + +// All local variables in scope with `is_indexed()` true. +std::vector<LocalVariable*> indexed_locals(scope_ptr scope); + +// All local variables in scope with `is_arg()` and `is_indexed()` false. 
+std::vector<LocalVariable*> pure_locals(scope_ptr scope); + + +// Module state query functions: + +// Normal (not API, net_receive) procedures in module: + +std::vector<ProcedureExpression*> normal_procedures(const Module&); + +struct public_variable_ids_t { + std::vector<Id> state_ids; + std::vector<Id> global_parameter_ids; + std::vector<Id> range_parameter_ids; +}; + +// Public module variables by role. + +public_variable_ids_t public_variable_ids(const Module&); + +struct module_variables_t { + std::vector<VariableExpression*> scalars; + std::vector<VariableExpression*> arrays; +}; + +// Scalar and array variables with local linkage. + +module_variables_t local_module_variables(const Module&); + +// Extract key procedures from module. + +APIMethod* find_api_method(const Module& m, const char* which); + +NetReceiveExpression* find_net_receive(const Module& m); + diff --git a/modcc/printer/simd.hpp b/modcc/printer/simd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..800ddaccd6c485edd5154f5065c9e2c23fce2c78 --- /dev/null +++ b/modcc/printer/simd.hpp @@ -0,0 +1,25 @@ +#pragma once + +struct simd_spec { + enum simd_abi { none, avx, avx2, avx512, native, default_abi } abi = none; + unsigned width = 0; // zero => use `simd::native_width` to determine. + + simd_spec() = default; + simd_spec(enum simd_abi a, unsigned w = 0): + abi(a), width(w) + { + if (width==0) { + // Pick a width based on abi, if applicable. 
+ switch (abi) { + case avx: + case avx2: + width = 4; + break; + case avx512: + width = 8; + break; + default: ; + } + } + } +}; diff --git a/modcc/scope.hpp b/modcc/scope.hpp index ef70040073d84adae6c33e07f0565180c18a9f8e..831a3e6db4851bb733188ac168ba36efd245d9ce 100644 --- a/modcc/scope.hpp +++ b/modcc/scope.hpp @@ -22,9 +22,9 @@ public: Scope(symbol_map& s); ~Scope() {}; symbol_type* add_local_symbol(std::string const& name, symbol_ptr s); - symbol_type* find(std::string const& name); - symbol_type* find_local(std::string const& name); - symbol_type* find_global(std::string const& name); + symbol_type* find(std::string const& name) const; + symbol_type* find_local(std::string const& name) const; + symbol_type* find_global(std::string const& name) const; std::string to_string() const; symbol_map& locals(); @@ -58,14 +58,14 @@ Scope<Symbol>::add_local_symbol( std::string const& name, template<typename Symbol> Symbol* -Scope<Symbol>::find(std::string const& name) { +Scope<Symbol>::find(std::string const& name) const { auto local = find_local(name); return local ? 
local : find_global(name); } template<typename Symbol> Symbol* -Scope<Symbol>::find_local(std::string const& name) { +Scope<Symbol>::find_local(std::string const& name) const { // search in local symbols auto local = local_symbols_.find(name); @@ -78,7 +78,7 @@ Scope<Symbol>::find_local(std::string const& name) { template<typename Symbol> Symbol* -Scope<Symbol>::find_global(std::string const& name) { +Scope<Symbol>::find_global(std::string const& name) const { // search in global symbols if( global_symbols_ ) { auto global = global_symbols_->find(name); diff --git a/modcc/simd_printer.hpp b/modcc/simd_printer.hpp deleted file mode 100644 index 8208632bf381d947156a95ede4e948ab0edf141b..0000000000000000000000000000000000000000 --- a/modcc/simd_printer.hpp +++ /dev/null @@ -1,602 +0,0 @@ -#pragma once - -#include <set> -#include <sstream> - -#include "backends/simd.hpp" -#include "cprinter.hpp" -#include "modccutil.hpp" -#include "textbuffer.hpp" - -#ifdef __GNUC__ -# define ANNOT_UNUSED "__attribute__((unused))" -#else -# define ANNOT_UNUSED "" -#endif - -template <simdKind Arch> -class SimdPrinter: public CPrinter { - using CPrinter::visit; - -public: - SimdPrinter(): cprinter_(make_unique<CPrinter>()) - {} - - explicit SimdPrinter(Module& m): - CPrinter(m), - cprinter_(make_unique<CPrinter>(m)) - {} - - void visit(NumberExpression *e) override { - simd_backend::emit_set_value(text_, e->value()); - } - - void visit(UnaryExpression *e) override; - void visit(BinaryExpression *e) override; - void visit(PowBinaryExpression *e) override; - void visit(AssignmentExpression *e) override; - void visit(ProcedureExpression *e) override; - void visit(VariableExpression *e) override; - void visit(LocalVariable *e) override { - const std::string& name = e->name(); - text_ << name; - } - - void visit(CellIndexedVariable *e) override; - void visit(IndexedVariable *e) override; - void visit(APIMethod *e) override; - void visit(BlockExpression *e) override; - void 
visit(CallExpression *e) override; - - void emit_headers() override { - CPrinter::emit_headers(); - text_.add_line("#include <climits>"); - text_ << simd_backend::emit_headers(); - text_.add_line(); - } - - void emit_api_loop(APIMethod* e, - const std::string& start, - const std::string& end, - const std::string& inc) override; - -private: - using simd_backend = modcc::simd_intrinsics<Arch>; - - void emit_indexed_view(LocalVariable* var, std::set<std::string>& decls); - void emit_indexed_view_simd(LocalVariable* var, std::set<std::string>& decls); - - // variable naming conventions - std::string emit_member_name(const std::string& varname) { - return varname + "_"; - } - - - std::string emit_rawptr_name(const std::string& varname) { - return "r_" + varname; - } - - std::pair<std::string, std::string> - emit_rawptr_ion(const std::string& iname, const std::string& ifield) { - return std::make_pair(emit_rawptr_name(iname), - emit_rawptr_name(iname + "_" + ifield)); - } - - std::string emit_vindex_name(const std::string& varname) { - return "v_" + varname + "_index"; - } - - std::string emit_vtmp_name(const std::string& varname) { - return "v_" + varname; - } - - // CPrinter to delegate generation of unvectorised code - std::unique_ptr<CPrinter> cprinter_; - - // Treat range access as loads - bool range_load_ = true; -}; - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(APIMethod *e) { - text_.add_gutter() << "void " << e->name() << "() override {\n"; - if (!e->scope()) { // error: semantic analysis has not been performed - throw compiler_exception( - "SimdPrinter attempt to print APIMethod " + e->name() - + " for which semantic analysis has not been performed", - e->location()); - } - - // only print the body if it has contents - if (e->is_api_method()->body()->statements().size()) { - text_.increase_indentation(); - - // First emit the raw pointer of node_index_ and vec_ci_ - text_.add_line("constexpr int simd_width = " + - simd_backend::emit_simd_width() + 
- " / (CHAR_BIT*sizeof(value_type));"); - text_.add_line("const size_type " ANNOT_UNUSED " *" + - emit_rawptr_name("node_index_") +" = node_index_.data();"); - text_.add_line("const size_type " ANNOT_UNUSED " *" + - emit_rawptr_name("vec_ci_") + " = vec_ci_.data();"); - text_.add_line(); - - // create local indexed views - std::set<std::string> index_decls; - for (auto const& symbol : e->scope()->locals()) { - auto var = symbol.second->is_local_variable(); - if (var->is_indexed()) { - emit_indexed_view(var, index_decls); - emit_indexed_view_simd(var, index_decls); - text_.add_line(); - } - } - - // get loop dimensions - text_.add_line("int n_ = node_index_.size();"); - - // print the vectorized version of the loop - emit_api_loop(e, "int i_ = 0", "i_ < n_/simd_width", "++i_"); - text_.add_line(); - - // delegate the printing of the remainder unvectorized loop - auto cprinter = cprinter_.get(); - cprinter->clear_text(); - cprinter->set_gutter(text_.get_gutter()); - cprinter->emit_api_loop(e, "int i_ = n_ - n_ % simd_width", "i_ < n_", "++i_"); - text_ << cprinter->text(); - - text_.decrease_indentation(); - } - - text_.add_line("}\n"); -} - -template <simdKind Arch> -void SimdPrinter<Arch>::emit_indexed_view(LocalVariable* var, - std::set<std::string>& decls) { - auto const& name = var->name(); - auto external = var->external_variable(); - auto const& index_name = external->index_name(); - text_.add_gutter(); - - if (decls.find(index_name) == decls.cend()) { - text_ << "auto "; - decls.insert(index_name); - } - - text_ << index_name << " = "; - - if (external->is_cell_indexed_variable()) { - text_ << "util::indirect_view(util::indirect_view(" - << emit_member_name(index_name) << ", vec_ci_), node_index_);\n"; - } - else if (external->is_ion()) { - auto channel = external->ion_channel(); - auto iname = ion_store(channel); - text_ << "util::indirect_view(" << iname << "." 
<< name << ", " - << ion_store(channel) << ".index);\n"; - } - else { - text_ << " util::indirect_view(" + emit_member_name(index_name) + ", node_index_);\n"; - } -} - -template <simdKind Arch> -void SimdPrinter<Arch>::emit_indexed_view_simd(LocalVariable* var, - std::set<std::string>& decls) { - auto const& name = var->name(); - auto external = var->external_variable(); - auto const& index_name = external->index_name(); - - // We need to work with with raw pointers in the vectorized version - auto channel = var->external_variable()->ion_channel(); - if (channel==ionKind::none) { - auto raw_index_name = emit_rawptr_name(index_name); - if (decls.find(raw_index_name) == decls.cend()) { - text_.add_gutter(); - if (var->is_read()) - text_ << "const "; - - text_ << "value_type " ANNOT_UNUSED " *"; - decls.insert(raw_index_name); - text_ << raw_index_name << " = " - << emit_member_name(index_name) << ".data()"; - } - } - else { - auto iname = ion_store(channel); - auto ion_var_names = emit_rawptr_ion(iname, name); - if (decls.find(ion_var_names.first) == decls.cend()) { - text_.add_gutter(); - text_ << "size_type* "; - decls.insert(ion_var_names.first); - text_ << ion_var_names.first << " = " << iname << ".index.data()"; - text_.end_line(";"); - } - - if (decls.find(ion_var_names.second) == decls.cend()) { - text_.add_gutter(); - if (var->is_read()) - text_ << "const "; - - text_ << "value_type " ANNOT_UNUSED " *"; - decls.insert(ion_var_names.second); - text_ << ion_var_names.second << " = " << iname << "." 
- << name << ".data()"; - } - } - - text_.end_line(";"); -} - -template <simdKind Arch> -void SimdPrinter<Arch>::emit_api_loop(APIMethod* e, - const std::string& start, - const std::string& end, - const std::string& inc) { - text_.add_gutter(); - text_ << "for (" << start << "; " << end << "; " << inc << ") {"; - text_.end_line(); - text_.increase_indentation(); - text_.add_line("int off_ = i_*simd_width;"); - - // First load the index vectors of all involved ions - std::set<std::string> declared_ion_vars; - for (auto& symbol : e->scope()->locals()) { - auto var = symbol.second->is_local_variable(); - if (var->is_indexed()) { - auto external = var->external_variable(); - auto channel = external->ion_channel(); - std::string cast_type = - "(const " + simd_backend::emit_index_type() + " *) "; - - std::string vindex_name, index_ptr_name; - if (channel == ionKind::none) { - vindex_name = emit_vtmp_name("node_index_"); - index_ptr_name = emit_rawptr_name("node_index_"); - } - else { - auto iname = ion_store(channel); - vindex_name = emit_vindex_name(iname); - index_ptr_name = emit_rawptr_name(iname); - - } - - if (declared_ion_vars.find(vindex_name) == declared_ion_vars.cend()) { - declared_ion_vars.insert(vindex_name); - text_.add_gutter(); - text_ << simd_backend::emit_index_type() << " " ANNOT_UNUSED " " - << vindex_name << " = "; - // FIXME: cast should better go inside `emit_load_index()` - simd_backend::emit_load_index( - text_, cast_type + "&" + index_ptr_name + "[off_]"); - text_.end_line(";"); - } - - if (external->is_cell_indexed_variable()) { - std::string vci_name = emit_vtmp_name("vec_ci_"); - std::string ci_ptr_name = emit_rawptr_name("vec_ci_"); - - if (declared_ion_vars.find(vci_name) == declared_ion_vars.cend()) { - declared_ion_vars.insert(vci_name); - text_.add_gutter(); - text_ << simd_backend::emit_index_type() << " " ANNOT_UNUSED " " - << vci_name << " = "; - simd_backend::emit_gather_index(text_, "(int *)" + ci_ptr_name, vindex_name, 
"sizeof(size_type)"); - text_.end_line(";"); - } - } - } - } - - text_.add_line(); - for (auto& symbol : e->scope()->locals()) { - auto var = symbol.second->is_local_variable(); - if (is_input(var)) { - auto ext = var->external_variable(); - text_.add_gutter() << simd_backend::emit_value_type() << " "; - var->accept(this); - text_ << " = "; - ext->accept(this); - text_.end_line(";"); - } - } - - text_.add_line(); - e->body()->accept(this); - - std::vector<LocalVariable*> aliased_variables; - - // perform update of external variables (currents etc) - for (auto &symbol : e->scope()->locals()) { - auto var = symbol.second->is_local_variable(); - if (is_output(var) && - !is_point_process() && - simd_backend::has_scatter()) { - // We can safely use scatter, but we need to fetch the variable - // first - text_.add_line(); - auto ext = var->external_variable(); - auto ext_tmpname = "_" + ext->index_name(); - text_.add_gutter() << simd_backend::emit_value_type() << " " - << ext_tmpname << " = "; - ext->accept(this); - text_.end_line(";"); - text_.add_gutter(); - text_ << ext_tmpname << " = "; - simd_backend::emit_binary_op(text_, ext->op(), ext_tmpname, - [this,var](TextBuffer& tb) { - var->accept(this); - }); - text_.end_line(";"); - text_.add_gutter(); - - // Build up the index name - std::string vindex_name, raw_index_name; - auto channel = ext->ion_channel(); - if (channel != ionKind::none) { - auto iname = ion_store(channel); - vindex_name = emit_vindex_name(iname); - raw_index_name = emit_rawptr_ion(iname, ext->name()).second; - } - else { - vindex_name = emit_vtmp_name("node_index_"); - raw_index_name = emit_rawptr_name(ext->index_name()); - } - - simd_backend::emit_scatter(text_, raw_index_name, vindex_name, - ext_tmpname, "sizeof(value_type)"); - text_.end_line(";"); - } - else if (is_output(var)) { - // var is aliased; collect all the aliased variables and we will - // update them later in a fused loop all at once - aliased_variables.push_back(var); - } - } - - 
// Emit update code for the aliased variables - // First, define their scalar equivalents - constexpr auto scalar_var_prefix = "s_"; - for (auto& v: aliased_variables) { - text_.add_gutter(); - text_ << "value_type* " << scalar_var_prefix << v->name() - << " = (value_type*) &" << v->name(); - text_.end_line(";"); - } - - if (aliased_variables.size() > 0) { - // Update them all in a single loop - text_.add_line("for (int k_ = 0; k_ < simd_width; ++k_) {"); - text_.increase_indentation(); - for (auto& v: aliased_variables) { - auto ext = v->external_variable(); - text_.add_gutter(); - text_ << ext->index_name() << "[off_+k_]"; - text_ << (ext->op() == tok::plus ? " += " : " -= "); - text_ << scalar_var_prefix << v->name() << "[k_]"; - text_.end_line(";"); - } - text_.decrease_indentation(); - text_.add_line("}"); - } - - text_.decrease_indentation(); - text_.add_line("}"); -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(IndexedVariable *e) { - std::string vindex_name, value_name; - - auto channel = e->ion_channel(); - if (channel != ionKind::none) { - auto iname = ion_store(channel); - vindex_name = emit_vindex_name(iname); - value_name = emit_rawptr_ion(iname, e->name()).second; - } - else { - vindex_name = emit_vtmp_name("node_index_"); - value_name = emit_rawptr_name(e->index_name()); - } - - simd_backend::emit_gather(text_, value_name, vindex_name, "sizeof(value_type)"); -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(CellIndexedVariable *e) { - std::string vindex_name, value_name; - - vindex_name = emit_vtmp_name("vec_ci_"); -#ifdef __GNUC__ - vindex_name += " __attribute__((unused))"; -#endif - value_name = emit_rawptr_name(e->index_name()); - - simd_backend::emit_gather(text_, vindex_name, value_name, "sizeof(value_type)"); -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(BlockExpression *e) { - if (!e->is_nested()) { - std::vector<std::string> names; - for(auto& symbol : e->scope()->locals()) { - auto sym = 
symbol.second.get(); - // input variables are declared earlier, before the - // block body is printed - if (is_stack_local(sym) && !is_input(sym)) { - names.push_back(sym->name()); - } - } - - if (names.size() > 0) { - text_.add_gutter() << simd_backend::emit_value_type() << " " - << *(names.begin()); - for(auto it=names.begin()+1; it!=names.end(); ++it) { - text_ << ", " << *it; - } - text_.end_line(";"); - } - } - - for (auto& stmt : e->statements()) { - if (stmt->is_local_declaration()) - continue; - - // these all must be handled - text_.add_gutter(); - stmt->accept(this); - if (not stmt->is_if()) { - text_.end_line(";"); - } - } -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(BinaryExpression *e) { - auto lhs = e->lhs(); - auto rhs = e->rhs(); - - auto emit_lhs = [this, lhs](TextBuffer& tb) { - lhs->accept(this); - }; - auto emit_rhs = [this, rhs](TextBuffer& tb) { - rhs->accept(this); - }; - - try { - simd_backend::emit_binary_op(text_, e->op(), emit_lhs, emit_rhs); - } catch (const std::exception& exc) { - // Rethrow it as a compiler_exception - throw compiler_exception( - "SimdPrinter: " + std::string(exc.what()) + ": " + - yellow(token_string(e->op())), e->location()); - } -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(AssignmentExpression *e) { - auto is_memop = [](Expression *e) { - auto ident = e->is_identifier(); - auto var = (ident) ? 
ident->symbol()->is_variable() : nullptr; - return var != nullptr && var->is_range(); - }; - - auto lhs = e->lhs(); - auto rhs = e->rhs(); - if (is_memop(lhs)) { - // that's a store; change printer's state so as not to emit a load - // instruction for the lhs visit - simd_backend::emit_store_unaligned(text_, - [this, lhs](TextBuffer&) { - auto range_load_save = range_load_; - range_load_ = false; - lhs->accept(this); - range_load_ = range_load_save; - }, - [this, rhs](TextBuffer&) { - rhs->accept(this); - }); - } - else { - // that's an ordinary assignment - lhs->accept(this); - text_ << " = "; - rhs->accept(this); - } -} - - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(VariableExpression *e) { - if (e->is_range() && range_load_) { - simd_backend::emit_load_unaligned(text_, "&" + e->name() + "[off_]"); - } - else if (e->is_range()) { - text_ << "&" << e->name() << "[off_]"; - } - else { - simd_backend::emit_set_value(text_, e->name()); - } -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(UnaryExpression *e) { - - auto arg = e->expression(); - auto emit_arg = [this, arg](TextBuffer& tb) { arg->accept(this); }; - - try { - simd_backend::emit_unary_op(text_, e->op(), emit_arg); - } catch (std::exception& exc) { - throw compiler_exception( - "SimdPrinter: " + std::string(exc.what()) + ": " + - yellow(token_string(e->op())), e->location()); - } -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(PowBinaryExpression *e) { - auto lhs = e->lhs(); - auto rhs = e->rhs(); - auto emit_lhs = [this, lhs](TextBuffer&) { lhs->accept(this); }; - auto emit_rhs = [this, rhs](TextBuffer&) { rhs->accept(this); }; - simd_backend::emit_pow(text_, emit_lhs, emit_rhs); -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(CallExpression *e) { - text_ << e->name() << "(off_"; - for (auto& arg: e->args()) { - text_ << ", "; - arg->accept(this); - } - text_ << ")"; -} - -template <simdKind Arch> -void SimdPrinter<Arch>::visit(ProcedureExpression *e) 
{ - auto emit_procedure_unvectorized = [this](ProcedureExpression* e) { - auto cprinter = cprinter_.get(); - cprinter->clear_text(); - cprinter->set_gutter(text_.get_gutter()); - cprinter->visit(e); - text_ << cprinter->text(); - }; - - if (e->kind() == procedureKind::net_receive) { - // Use non-vectorized printer for printing net_receive - emit_procedure_unvectorized(e); - return; - } - - // Two versions of each procedure are needed: vectorized and unvectorized - text_.add_gutter() << "void " << e->name() << "(int off_"; - for(auto& arg : e->args()) { - text_ << ", " << simd_backend::emit_value_type() << " " - << arg->is_argument()->name(); - } - text_ << ") {\n"; - - if (!e->scope()) { - // error: semantic analysis has not been performed - throw compiler_exception( - "SimdPrinter attempt to print Procedure " + e->name() - + " for which semantic analysis has not been performed", - e->location()); - } - - // print body - increase_indentation(); - e->body()->accept(this); - - // close the function body - decrease_indentation(); - - text_.add_line("}"); - text_.add_line(); - - // Emit also the unvectorised version of the procedure - emit_procedure_unvectorized(e); -} diff --git a/modcc/textbuffer.cpp b/modcc/textbuffer.cpp deleted file mode 100644 index 658d20c4714e698967034de295327c20bb082702..0000000000000000000000000000000000000000 --- a/modcc/textbuffer.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "textbuffer.hpp" - -/****************************************************************************** - TextBuffer -******************************************************************************/ -TextBuffer& TextBuffer::add_gutter() { - text_ << gutter_; - return *this; -} -void TextBuffer::add_line(std::string const& line) { - text_ << gutter_ << line << std::endl; -} -void TextBuffer::add_line() { - text_ << std::endl; -} -void TextBuffer::end_line(std::string const& line) { - text_ << line << std::endl; -} -void TextBuffer::end_line() { - text_ << std::endl; -} - 
-std::string TextBuffer::str() const { - return text_.str(); -} - -void TextBuffer::set_gutter(int width) { - indent_ = width; - gutter_ = std::string(indent_, ' '); -} -int TextBuffer::get_gutter() { - return indent_; -} - -void TextBuffer::increase_indentation() { - indent_ += indentation_width_; - if(indent_<0) { - indent_=0; - } - gutter_ = std::string(indent_, ' '); -} -void TextBuffer::decrease_indentation() { - indent_ -= indentation_width_; - if(indent_<0) { - indent_=0; - } - gutter_ = std::string(indent_, ' '); -} - -std::stringstream& TextBuffer::text() { - return text_; -} - -void TextBuffer::clear() { - text_.str(""); -} diff --git a/modcc/textbuffer.hpp b/modcc/textbuffer.hpp deleted file mode 100644 index ade2f38c14165ff5c6a51b002b6e8d867349dfb6..0000000000000000000000000000000000000000 --- a/modcc/textbuffer.hpp +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once - -#include <limits> -#include <sstream> -#include <string> - -class TextBuffer { -public: - TextBuffer() { - text_.precision(std::numeric_limits<double>::max_digits10); - } - - TextBuffer(const TextBuffer& other): - indent_(other.indent_), - indentation_width_(other.indentation_width_), - gutter_(other.gutter_), - text_(other.text_.str()) - {} - - TextBuffer& add_gutter(); - void add_line(std::string const& line); - void add_line(); - void end_line(std::string const& line); - void end_line(); - - std::string str() const; - - void set_gutter(int width); - int get_gutter(); - - void increase_indentation(); - void decrease_indentation(); - std::stringstream& text(); - - void clear(); - -private: - int indent_ = 0; - const int indentation_width_=4; - std::string gutter_ = ""; - std::stringstream text_; -}; - -template <typename T> -TextBuffer& operator<<(TextBuffer& buffer, T const& v) { - buffer.text() << v; - return buffer; -} diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index 3ccbb938ae08563ff93a1a63a87b7bd77d9bb401..8179b977cba970fccf0ca6d1bf1d1a9501c648b7 100644 --- a/modcc/visitor.hpp +++ 
b/modcc/visitor.hpp @@ -30,7 +30,6 @@ public: virtual void visit(StoichExpression *e) { visit((Expression*) e); } virtual void visit(VariableExpression *e) { visit((Expression*) e); } virtual void visit(IndexedVariable *e) { visit((Expression*) e); } - virtual void visit(CellIndexedVariable *e) { visit((Expression*) e); } virtual void visit(FunctionExpression *e) { visit((Expression*) e); } virtual void visit(IfExpression *e) { visit((Expression*) e); } virtual void visit(SolveExpression *e) { visit((Expression*) e); } @@ -60,7 +59,6 @@ public: virtual void visit(DivBinaryExpression *e) { visit((BinaryExpression*) e); } virtual void visit(PowBinaryExpression *e) { visit((BinaryExpression*) e); } - virtual ~Visitor() {}; }; diff --git a/modcc/writeback.hpp b/modcc/writeback.hpp deleted file mode 100644 index f8da02100a3af006d9a250d235f3ef9ebf7f419b..0000000000000000000000000000000000000000 --- a/modcc/writeback.hpp +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -#include <expression.hpp> -#include <identifier.hpp> - -// Holds the state required to generate a write_back call in a mechanism. -struct WriteBack { - // Name of the symbol inside the mechanism used to store. - // must be a state field - std::string source_name; - // Name of the field in the ion channel being written to. - std::string target_name; - // The ion channel being written to. 
- // must not be ionKind::none - ionKind ion_kind; - - WriteBack(std::string src, std::string tgt, ionKind k): - source_name(std::move(src)), target_name(std::move(tgt)), ion_kind(k) - {} -}; - diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 757fe245c5e7fbd81245f263cc7eb48defb8eacd..b50e258609810105be8ae5417cbf4d7938c007b7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,15 +1,22 @@ -set(BASE_SOURCES - backends/multicore/fvm.cpp +set(arbor_cxx_sources + backends/multicore/mechanism.cpp + backends/multicore/shared_state.cpp + backends/multicore/stimulus.cpp + builtin_mechanisms.cpp cell_group_factory.cpp common_types_io.cpp cell.cpp event_binner.cpp - lif_cell_group.cpp + fvm_layout.cpp + fvm_lowered_cell_impl.cpp hardware/affinity.cpp hardware/gpu.cpp hardware/memory.cpp hardware/node_info.cpp hardware/power.cpp + lif_cell_group.cpp + mc_cell_group.cpp + mechcat.cpp merge_events.cpp simulation.cpp morphology.cpp @@ -24,45 +31,56 @@ set(BASE_SOURCES util/debug.cpp util/hostname.cpp util/path.cpp - util/prefixbuf.cpp util/unwind.cpp ) -set(CUDA_SOURCES - backends/gpu/fvm.cpp - backends/gpu/fill.cu - backends/gpu/multi_event_stream.cu - backends/gpu/kernels/assemble_matrix.cu - backends/gpu/kernels/interleave.cu - backends/gpu/kernels/ions.cu - backends/gpu/kernels/solve_matrix.cu - backends/gpu/kernels/stim_current.cu - backends/gpu/kernels/take_samples.cu - backends/gpu/kernels/test_thresholds.cu - backends/gpu/kernels/time_ops.cu -) + +if(ARB_WITH_CUDA) + list(APPEND arbor_cxx_sources + backends/gpu/mechanism.cpp + backends/gpu/shared_state.cpp + backends/gpu/stimulus.cpp + backends/gpu/stimulus.cu + backends/gpu/threshold_watcher.cu + ) +endif() if(ARB_WITH_MPI) - set(BASE_SOURCES ${BASE_SOURCES} communication/mpi.cpp) + list(APPEND arbor_cxx_sources + communication/mpi.cpp) elseif(ARB_WITH_DRYRUN) - set(BASE_SOURCES ${BASE_SOURCES} communication/dryrun_global_policy.cpp) + list(APPEND arbor_cxx_sources + 
communication/dryrun_global_policy.cpp) endif() if(ARB_WITH_CTHREAD) - set(BASE_SOURCES ${BASE_SOURCES} threading/cthread.cpp) + list(APPEND arbor_cxx_sources + threading/cthread.cpp) endif() -add_library(arbor ${BASE_SOURCES}) +set(arbor_cuda_sources + memory/fill.cu + backends/gpu/matrix_assemble.cu + backends/gpu/matrix_interleave.cu + backends/gpu/matrix_solve.cu + backends/gpu/multi_event_stream.cu + backends/gpu/shared_state.cu + backends/gpu/stimulus.cu + backends/gpu/threshold_watcher.cu +) + +add_library(arbor ${arbor_cxx_sources}) +target_compile_options(arbor PRIVATE ${CXXOPT_ARCH}) list(APPEND ARB_LIBRARIES arbor) if(ARB_WITH_CUDA) - cuda_add_library(arborcu ${CUDA_SOURCES}) + cuda_add_library(arborcu ${arbor_cuda_sources}) list(APPEND ARB_LIBRARIES arborcu) endif() if (ARB_AUTO_RUN_MODCC_ON_CHANGES) add_dependencies(arbor build_all_mods) if (ARB_WITH_CUDA) - add_dependencies(arborcu build_all_gpu_mods) + add_dependencies(arborcu build_all_mods) endif() endif() diff --git a/src/algorithms.hpp b/src/algorithms.hpp index 88f6ff94cc40dc7bf66c2cfab87a73789647f9c6..43ec3dcbe47388bdcd975574ef48712756786928 100644 --- a/src/algorithms.hpp +++ b/src/algorithms.hpp @@ -11,6 +11,7 @@ #include <util/debug.hpp> #include <util/meta.hpp> #include <util/range.hpp> +#include <util/rangeutil.hpp> /* * Some simple wrappers around stl algorithms to improve readability of code @@ -118,19 +119,30 @@ bool is_minimal_degree(C const& c) return it==c.end(); } -template <typename C> -bool is_positive(C const& c) -{ - static_assert( - std::is_integral<typename C::value_type>::value, - "is_positive only applies to integral types" - ); - for(auto v : c) { - if(v<1) { - return false; - } +struct generic_is_positive { + template <typename V> + bool operator()(V v) const { + static V zero = V{}; + return v>zero; } - return true; +}; + +struct generic_is_negative { + template <typename V> + bool operator()(V v) const { + static V zero = V{}; + return v<zero; + } +}; + +template 
<typename C> +bool all_positive(const C& c) { + return util::all_of(c, generic_is_positive{}); +} + +template <typename C> +bool all_negative(const C& c) { + return util::all_of(c, generic_is_negative{}); } template<typename C> @@ -296,126 +308,11 @@ std::vector<typename C::value_type> tree_reduce( return new_parent_index; } - -template<typename Seq, typename = util::enable_if_sequence_t<Seq>> -bool is_sorted(const Seq& seq) { - return std::is_sorted(std::begin(seq), std::end(seq)); -} - -template< typename Seq, typename = util::enable_if_sequence_t<Seq>> +template <typename Seq, typename = util::enable_if_sequence_t<Seq&>> bool is_unique(const Seq& seq) { return std::adjacent_find(std::begin(seq), std::end(seq)) == std::end(seq); } -template <typename SubIt, typename SupIt, typename SupEnd> -class index_into_iterator { -public: - using value_type = typename std::iterator_traits<SupIt>::difference_type; - using difference_type = value_type; - using pointer = const value_type*; - using reference = const value_type&; - using iterator_category = std::forward_iterator_tag; - -private: - using super_iterator = SupIt; - using super_senitel = SupEnd; - using sub_iterator = SubIt; - - sub_iterator sub_it_; - - mutable super_iterator super_it_; - const super_senitel super_end_; - - mutable value_type super_idx_; - -public: - index_into_iterator(sub_iterator sub, super_iterator sup, super_senitel sup_end) : - sub_it_(sub), - super_it_(sup), - super_end_(sup_end), - super_idx_(0) - {} - - value_type operator*() { - advance_super(); - return super_idx_; - } - - value_type operator*() const { - advance_super(); - return super_idx_; - } - - bool operator==(const index_into_iterator& other) { - return sub_it_ == other.sub_it_; - } - - bool operator!=(const index_into_iterator& other) { - return !(*this == other); - } - - index_into_iterator operator++() { - ++sub_it_; - return (*this); - } - - index_into_iterator operator++(int) { - auto previous = *this; - ++(*this); - return 
previous; - } - - static constexpr value_type npos = value_type(-1); - -private: - - bool is_aligned() const { - return *sub_it_ == *super_it_; - } - - void advance_super() { - while(super_it_!=super_end_ && !is_aligned()) { - ++super_it_; - ++super_idx_; - } - - // this indicates that no match was found in super for a value - // in sub, which violates the precondition that sub is a subset of super - EXPECTS(!(super_it_==super_end_)); - - // set guard for users to test for validity if assertions are disabled - if (super_it_==super_end_) { - super_idx_ = npos; - } - } -}; - -/// Return an index that maps entries in sub to their corresponding values in -/// super, where sub is a subset of super. / -/// Both sets are sorted and have unique entries. Complexity is O(n), where n is -/// size of super -template<typename Sub, typename Super> -auto index_into(const Sub& sub, const Super& super) - -> util::range< - index_into_iterator< - typename util::sequence_traits<Sub>::const_iterator, - typename util::sequence_traits<Super>::const_iterator, - typename util::sequence_traits<Super>::const_sentinel - >> -{ - - EXPECTS(is_unique(super) && is_unique(sub)); - EXPECTS(is_sorted(super) && is_sorted(sub)); - EXPECTS(util::size(sub) <= util::size(super)); - - using iterator = index_into_iterator< - typename util::sequence_traits<Sub>::const_iterator, - typename util::sequence_traits<Super>::const_iterator, - typename util::sequence_traits<Super>::const_sentinel >; - auto begin = iterator(std::begin(sub), std::begin(super), std::end(super)); - auto end = iterator(std::end(sub), std::end(super), std::end(super)); - return util::make_range(begin, end); -} /// Binary search, because std::binary_search doesn't return information /// about where a match was found. 
diff --git a/src/backends.hpp b/src/backends.hpp index b4d9c1b39f2f5dd2601439e34acd05b6db5547d9..e45f96e0bc2afe5d855f7db8b6905fc3a3015e96 100644 --- a/src/backends.hpp +++ b/src/backends.hpp @@ -1,7 +1,6 @@ #pragma once #include <string> -#include <backends/fvm.hpp> namespace arb { diff --git a/src/backends/builtin_mech_proto.hpp b/src/backends/builtin_mech_proto.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c08f6cd797f33b33450f5b1cadccd20d671b7f80 --- /dev/null +++ b/src/backends/builtin_mech_proto.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include <mechanism.hpp> +#include <mechinfo.hpp> + +namespace arb { + +// Stimulus + +inline const mechanism_info& builtin_stimulus_info() { + using spec = mechanism_field_spec; + static mechanism_info info = { + // globals + {}, + // parameters + { + {"delay", spec(spec::parameter, "ms", 0, 0)}, + {"duration", spec(spec::parameter, "ms", 0, 0)}, + {"amplitude", spec(spec::parameter, "nA", 0, 0)} + }, + // state + {}, + // ions + {}, + // fingerprint + "##builtin_stimulus" + }; + + return info; +}; + +template <typename B> +concrete_mech_ptr<B> make_builtin_stimulus(); + +} // namespace arb diff --git a/src/backends/fvm.hpp b/src/backends/fvm.hpp deleted file mode 100644 index 42ba619c3e0598ab180551fecd042dd255c1c05a..0000000000000000000000000000000000000000 --- a/src/backends/fvm.hpp +++ /dev/null @@ -1,43 +0,0 @@ -#pragma once - -#include <backends/multicore/fvm.hpp> - -namespace arb { - -// A null back end used as a placeholder for back ends that are not supported -// on the target platform. 
-struct null_backend: public multicore::backend { - static bool is_supported() { - return false; - } - - static mechanism_ptr make_mechanism( - const std::string&, - size_type, - const_iview, - const_view, const_view, const_view, - view, view, - const std::vector<value_type>&, - const std::vector<size_type>&) - { - throw std::runtime_error("attempt to use an unsupported back end"); - } - - static bool has_mechanism(const std::string& name) { - return false; - } - - static std::string name() { - return "null"; - } -}; - -} // namespace arb - -#ifdef ARB_HAVE_GPU -#include <backends/gpu/fvm.hpp> -#else -namespace arb { namespace gpu { - using backend = null_backend; -}} // namespace arb::gpu -#endif diff --git a/src/backends/fvm_types.hpp b/src/backends/fvm_types.hpp index 5252afb0448b0a0d5f80c654b81a747bd592a23d..fdb8fdb794bfb06d60267f2b42c699e5becbb5d1 100644 --- a/src/backends/fvm_types.hpp +++ b/src/backends/fvm_types.hpp @@ -8,5 +8,17 @@ namespace arb { using fvm_value_type = double; using fvm_size_type = cell_local_size_type; +using fvm_index_type = int; + +// Stores a single crossing event. 
+ +struct threshold_crossing { + fvm_size_type index; // index of variable + fvm_value_type time; // time of crossing + + friend bool operator==(threshold_crossing l, threshold_crossing r) { + return l.index==r.index && l.time==r.time; + } +}; } // namespace arb diff --git a/src/backends/gpu/intrinsics.hpp b/src/backends/gpu/cuda_atomic.hpp similarity index 77% rename from src/backends/gpu/intrinsics.hpp rename to src/backends/gpu/cuda_atomic.hpp index 6f6cf46be1288d4e03208296769a4a4a9d22dec9..1ee9e53d85cd0d034e3c86f3b652d8eb5420681f 100644 --- a/src/backends/gpu/intrinsics.hpp +++ b/src/backends/gpu/cuda_atomic.hpp @@ -39,24 +39,3 @@ inline float cuda_atomic_sub(float* address, float val) { return atomicAdd(address, -val); } -__device__ -inline double exprelr(double x) { - if (1.0+x == 1.0) { - return 1.0; - } - return x/expm1(x); -} - -// Return minimum of the two values -template <typename T> -__device__ -inline T min(T lhs, T rhs) { - return lhs<rhs? lhs: rhs; -} - -// Return maximum of the two values -template <typename T> -__device__ -inline T max(T lhs, T rhs) { - return lhs<rhs? rhs: lhs; -} diff --git a/src/backends/gpu/cuda_common.hpp b/src/backends/gpu/cuda_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3cad0410dab5c7f3a670cf15c30ad1832de047b --- /dev/null +++ b/src/backends/gpu/cuda_common.hpp @@ -0,0 +1,30 @@ +#pragma once + +#ifdef __CUDACC__ +# define HOST_DEVICE_IF_CUDA __host__ __device__ +#else +# define HOST_DEVICE_IF_CUDA +#endif + +namespace arb { +namespace gpu { + +namespace impl { +// Number of threads per warp +// This has always been 32, however it may change in future NVIDIA gpus +HOST_DEVICE_IF_CUDA +constexpr inline unsigned threads_per_warp() { + return 32u; +} + +// The minimum number of bins required to store n values where the bins have +// dimension of block_size. 
+HOST_DEVICE_IF_CUDA +constexpr inline unsigned block_count(unsigned n, unsigned block_size) { + return (n+block_size-1)/block_size; +} + +} // namespace impl + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/fvm.cpp b/src/backends/gpu/fvm.cpp deleted file mode 100644 index e70b30cfe6c566429349e98267f149138a3c5cbf..0000000000000000000000000000000000000000 --- a/src/backends/gpu/fvm.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include "fvm.hpp" - -#include <mechanisms/gpu/hh_gpu.hpp> -#include <mechanisms/gpu/pas_gpu.hpp> -#include <mechanisms/gpu/expsyn_gpu.hpp> -#include <mechanisms/gpu/exp2syn_gpu.hpp> -#include <mechanisms/gpu/test_kin1_gpu.hpp> -#include <mechanisms/gpu/test_kinlva_gpu.hpp> -#include <mechanisms/gpu/test_ca_gpu.hpp> -#include <mechanisms/gpu/nax_gpu.hpp> -#include <mechanisms/gpu/kamt_gpu.hpp> -#include <mechanisms/gpu/kdrmt_gpu.hpp> - -namespace arb { -namespace gpu { - -std::map<std::string, backend::maker_type> -backend::mech_map_ = { - { "pas", maker<mechanism_pas> }, - { "hh", maker<mechanism_hh> }, - { "expsyn", maker<mechanism_expsyn> }, - { "exp2syn", maker<mechanism_exp2syn> }, - { "test_kin1", maker<mechanism_test_kin1> }, - { "test_kinlva", maker<mechanism_test_kinlva> }, - { "test_ca", maker<mechanism_test_ca> }, - { "nax", maker<mechanism_nax> }, - { "kamt", maker<mechanism_kamt> }, - { "kdrmt", maker<mechanism_kdrmt> }, -}; - -} // namespace gpu -} // namespace arb diff --git a/src/backends/gpu/fvm.hpp b/src/backends/gpu/fvm.hpp index 0a9a2c9a38c2b18774c0026f38acfabf2f76e760..a68661e86f9a9ebd9d0d2d09a0796388ab401f96 100644 --- a/src/backends/gpu/fvm.hpp +++ b/src/backends/gpu/fvm.hpp @@ -3,166 +3,62 @@ #include <map> #include <string> -#include <backends/event.hpp> -#include <backends/fvm_types.hpp> #include <common_types.hpp> #include <mechanism.hpp> #include <memory/memory.hpp> #include <util/rangeutil.hpp> -#include "kernels/take_samples.hpp" +#include <backends/event.hpp> +#include <backends/fvm_types.hpp> + 
+#include <backends/gpu/gpu_store_types.hpp> +#include <backends/gpu/shared_state.hpp> + #include "matrix_state_interleaved.hpp" -#include "multi_event_stream.hpp" -#include "ions.hpp" -#include "stimulus.hpp" #include "threshold_watcher.hpp" -#include "time_ops.hpp" namespace arb { namespace gpu { struct backend { - static bool is_supported() { - return true; - } + static bool is_supported() { return true; } + static std::string name() { return "gpu"; } - /// define the real and index types using value_type = fvm_value_type; + using index_type = fvm_index_type; using size_type = fvm_size_type; - /// define storage types - using array = memory::device_vector<value_type>; - using iarray = memory::device_vector<size_type>; - - using view = typename array::view_type; - using const_view = typename array::const_view_type; - - using iview = typename iarray::view_type; - using const_iview = typename iarray::const_view_type; - - using host_array = typename memory::host_vector<value_type>; - using host_iarray = typename memory::host_vector<size_type>; + using array = arb::gpu::array; + using iarray = arb::gpu::iarray; - using host_view = typename host_array::view_type; - using host_iview = typename host_iarray::const_view_type; - - static std::string name() { - return "gpu"; - } - - // dereference a probe handle - static value_type dereference(probe_handle h) { - memory::const_device_reference<value_type> v(h); // h is a device-side pointer - return v; + static memory::host_vector<value_type> host_view(const array& v) { + return memory::on_host(v); } - // matrix back end implementation - using matrix_state = matrix_state_interleaved<value_type, size_type>; - - // backend-specific multi event streams. 
- using deliverable_event_stream = arb::gpu::multi_event_stream<deliverable_event>; - using sample_event_stream = arb::gpu::multi_event_stream<sample_event>; - - // mechanism infrastructure - using ion_type = ion<backend>; - using stimulus = gpu::stimulus<backend>; - - using mechanism = arb::mechanism<backend>; - using mechanism_ptr = std::unique_ptr<mechanism>; - - static mechanism_ptr make_mechanism( - const std::string& name, - size_type mech_id, - const_iview vec_ci, - const_view vec_t, const_view vec_t_to, const_view vec_dt, - view vec_v, view vec_i, - const std::vector<value_type>& weights, - const std::vector<size_type>& node_indices) - { - if (!has_mechanism(name)) { - throw std::out_of_range("no mechanism in database : " + name); - } - - return mech_map_.find(name)-> - second(mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, memory::make_const_view(weights), memory::make_const_view(node_indices)); - } - - static bool has_mechanism(const std::string& name) { - return mech_map_.count(name)>0; + static memory::host_vector<index_type> host_view(const iarray& v) { + return memory::on_host(v); } + using matrix_state = arb::gpu::matrix_state_interleaved<value_type, index_type>; using threshold_watcher = arb::gpu::threshold_watcher; - // perform min/max reductions on 'array' type - template <typename V> - static std::pair<V, V> minmax_value(const memory::device_vector<V>& v) { - // TODO: consider/test replacement with CUDA kernel (or generic reduction kernel). - auto v_copy = memory::on_host(v); - return util::minmax_value(v_copy); - } - - // perform element-wise comparison on 'array' type against `t_test`. - template <typename V> - static bool any_time_before(const memory::device_vector<V>& t, V t_test) { - // Note: ubbench benchmarking (on a P100) indicates that copying the - // time vectors to the host is faster than a device side - // implementation unless we're running over ten thousands of cells per - // cell group. 
- - auto v_copy = memory::on_host(t); - return util::minmax_value(v_copy).first<t_test; - } - - static void update_time_to(array& time_to, const_view time, value_type dt, value_type tmax) { - arb::gpu::update_time_to(time_to.size(), time_to.data(), time.data(), dt, tmax); - } - - // set the per-cell and per-compartment dt_ from time_to_ - time_. - static void set_dt(array& dt_cell, array& dt_comp, const_view time_to, const_view time, const_iview cv_to_cell) { - size_type ncell = util::size(dt_cell); - size_type ncomp = util::size(dt_comp); + using deliverable_event_stream = arb::gpu::deliverable_event_stream; + using sample_event_stream = arb::gpu::sample_event_stream; - arb::gpu::set_dt( - ncell, ncomp, dt_cell.data(), dt_comp.data(), time_to.data(), time.data(), cv_to_cell.data()); - } + using shared_state = arb::gpu::shared_state; - // perform sampling as described by marked events in a sample_event_stream - static void take_samples( - const sample_event_stream::state& s, - const_view time, - array& sample_time, - array& sample_value) + static threshold_watcher voltage_watcher( + const shared_state& state, + const std::vector<index_type>& cv, + const std::vector<value_type>& thresholds) { - arb::gpu::take_samples(s, time.data(), sample_time.data(), sample_value.data()); - } - - // Calculate the reversal potential eX (mV) using Nernst equation - // eX = RT/zF * ln(Xo/Xi) - // R: universal gas constant 8.3144598 J.K-1.mol-1 - // T: temperature in Kelvin - // z: valency of species (K, Na: +1) (Ca: +2) - // F: Faraday's constant 96485.33289 C.mol-1 - // Xo/Xi: ratio of out/in concentrations - static void nernst(int valency, value_type temperature, const_view Xo, const_view Xi, view eX) { - arb::gpu::nernst(eX.size(), valency, temperature, Xo.data(), Xi.data(), eX.data()); - } - - static void init_concentration( - view Xi, view Xo, - const_view weight_Xi, const_view weight_Xo, - value_type c_int, value_type c_ext) - { - arb::gpu::init_concentration(Xi.size(), 
Xi.data(), Xo.data(), weight_Xi.data(), weight_Xo.data(), c_int, c_ext); - } - -private: - using maker_type = mechanism_ptr (*)(size_type, const_iview, const_view, const_view, const_view, view, view, array&&, iarray&&); - static std::map<std::string, maker_type> mech_map_; - - template <template <typename> class Mech> - static mechanism_ptr maker(size_type mech_id, const_iview vec_ci, const_view vec_t, const_view vec_t_to, const_view vec_dt, view vec_v, view vec_i, array&& weights, iarray&& node_indices) { - return arb::make_mechanism<Mech<backend>> - (mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, std::move(weights), std::move(node_indices)); + return threshold_watcher( + state.cv_to_cell.data(), + state.time.data(), + state.time_to.data(), + state.voltage.data(), + cv, + thresholds); } }; diff --git a/src/backends/gpu/gpu_store_types.hpp b/src/backends/gpu/gpu_store_types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7eca723aec83b1e644c4acefdc4b8c150da6f16c --- /dev/null +++ b/src/backends/gpu/gpu_store_types.hpp @@ -0,0 +1,26 @@ +#pragma once + +// Storage classes and other common types across +// gpu back-end implementations. +// +// Defines array, iarray, and specialized multi-event stream classes. 
+ +#include <memory/memory.hpp> + +#include <backends/event.hpp> +#include <backends/fvm_types.hpp> +#include <backends/gpu/multi_event_stream.hpp> +#include <backends/gpu/multi_event_stream.hpp> + +namespace arb { +namespace gpu { + +using array = memory::device_vector<fvm_value_type>; +using iarray = memory::device_vector<fvm_index_type>; + +using deliverable_event_stream = arb::gpu::multi_event_stream<deliverable_event>; +using sample_event_stream = arb::gpu::multi_event_stream<sample_event>; + +} // namespace gpu +} // namespace arb + diff --git a/src/backends/gpu/ions.hpp b/src/backends/gpu/ions.hpp deleted file mode 100644 index 6cd18b043185681b85b1b3c17edf89dba2307611..0000000000000000000000000000000000000000 --- a/src/backends/gpu/ions.hpp +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include <cstdint> - -#include <backends/fvm_types.hpp> - -namespace arb { -namespace gpu { - -// prototype for nernst equation cacluation -void nernst(std::size_t n, int valency, - fvm_value_type temperature, - const fvm_value_type* Xo, - const fvm_value_type* Xi, - fvm_value_type* eX); - -// prototype for inializing ion species concentrations -void init_concentration(std::size_t n, - fvm_value_type* Xi, fvm_value_type* Xo, - const fvm_value_type* weight_Xi, const fvm_value_type* weight_Xo, - fvm_value_type c_int, fvm_value_type c_ext); - -} // namespace gpu -} // namespace arb - diff --git a/src/backends/gpu/kernels/ions.cu b/src/backends/gpu/kernels/ions.cu deleted file mode 100644 index a5061c5fdf13800644f809e75897f23ddffcff4d..0000000000000000000000000000000000000000 --- a/src/backends/gpu/kernels/ions.cu +++ /dev/null @@ -1,67 +0,0 @@ -#include <cstdint> - -#include <constants.hpp> - -#include "../ions.hpp" -#include "detail.hpp" - -namespace arb { -namespace gpu { - -namespace kernels { - template <typename T> - __global__ - void nernst(std::size_t n, int valency, T temperature, const T* Xo, const T* Xi, T* eX) { - auto i = threadIdx.x+blockIdx.x*blockDim.x; - - // 
factor 1e3 to scale from V -> mV - constexpr T RF = 1e3*constant::gas_constant/constant::faraday; - T factor = RF*temperature/valency; - if (i<n) { - eX[i] = factor*std::log(Xo[i]/Xi[i]); - } - } - - template <typename T> - __global__ - void init_concentration(std::size_t n, T* Xi, T* Xo, const T* weight_Xi, const T* weight_Xo, T c_int, T c_ext) { - auto i = threadIdx.x+blockIdx.x*blockDim.x; - - if (i<n) { - Xi[i] = c_int*weight_Xi[i]; - Xo[i] = c_ext*weight_Xo[i]; - } - } -} // namespace kernels - -void nernst(std::size_t n, - int valency, - fvm_value_type temperature, - const fvm_value_type* Xo, - const fvm_value_type* Xi, - fvm_value_type* eX) -{ - if (n>0) { - constexpr int block_dim = 128; - const int grid_dim = impl::block_count(n, block_dim); - kernels::nernst<<<grid_dim, block_dim>>> - (n, valency, temperature, Xo, Xi, eX); - } -} - -void init_concentration( - std::size_t n, - fvm_value_type* Xi, fvm_value_type* Xo, - const fvm_value_type* weight_Xi, const fvm_value_type* weight_Xo, - fvm_value_type c_int, fvm_value_type c_ext) -{ - if (n>0) { - constexpr int block_dim = 128; - const int grid_dim = impl::block_count(n, block_dim); - kernels::init_concentration<<<grid_dim, block_dim>>> - (n, Xi, Xo, weight_Xi, weight_Xo, c_int, c_ext); - } -} - -} // namespace gpu -} // namespace arb diff --git a/src/backends/gpu/kernels/stim_current.cu b/src/backends/gpu/kernels/stim_current.cu deleted file mode 100644 index 939164395e4e2b7e8501cb679fe79f39eaef5aa7..0000000000000000000000000000000000000000 --- a/src/backends/gpu/kernels/stim_current.cu +++ /dev/null @@ -1,48 +0,0 @@ -#include <backends/fvm_types.hpp> -#include <backends/gpu/intrinsics.hpp> - -namespace arb{ -namespace gpu { - -namespace kernels { - template <typename T, typename I> - __global__ - void stim_current( - const T* delay, const T* duration, const T* amplitude, const T* weights, - const I* node_index, int n, const I* cell_index, const T* time, T* current) - { - using value_type = T; - using 
iarray = I; - - auto i = threadIdx.x + blockDim.x*blockIdx.x; - - if (i<n) { - auto t = time[cell_index[i]]; - if (t>=delay[i] && t<delay[i]+duration[i]) { - // use subtraction because the electrode currents are specified - // in terms of current into the compartment - cuda_atomic_add(current+node_index[i], -weights[i]*amplitude[i]); - } - } - } -} // namespace kernels - - -void stim_current( - const fvm_value_type* delay, const fvm_value_type* duration, - const fvm_value_type* amplitude, const fvm_value_type* weights, - const fvm_size_type* node_index, int n, - const fvm_size_type* cell_index, const fvm_value_type* time, - fvm_value_type* current) -{ - constexpr unsigned thread_dim = 192; - dim3 dim_block(thread_dim); - dim3 dim_grid((n+thread_dim-1)/thread_dim); - - kernels::stim_current<fvm_value_type, fvm_size_type><<<dim_grid, dim_block>>> - (delay, duration, amplitude, weights, node_index, n, cell_index, time, current); - -} - -} // namespace gpu -} // namespace arb diff --git a/src/backends/gpu/kernels/take_samples.cu b/src/backends/gpu/kernels/take_samples.cu deleted file mode 100644 index 1e85f0e3805d19aabbb0e3e040bed2c99c7c483c..0000000000000000000000000000000000000000 --- a/src/backends/gpu/kernels/take_samples.cu +++ /dev/null @@ -1,46 +0,0 @@ -#include <common_types.hpp> -#include <backends/event.hpp> -#include <backends/fvm_types.hpp> -#include <backends/multi_event_stream_state.hpp> - -namespace arb { -namespace gpu { - -namespace kernels { - __global__ void take_samples( - multi_event_stream_state<raw_probe_info> s, - const fvm_value_type* time, - fvm_value_type* sample_time, - fvm_value_type* sample_value) - { - int i = threadIdx.x+blockIdx.x*blockDim.x; - - if (i<s.n) { - auto begin = s.ev_data+s.begin_offset[i]; - auto end = s.ev_data+s.end_offset[i]; - for (auto p = begin; p!=end; ++p) { - sample_time[p->offset] = time[i]; - sample_value[p->offset] = *p->handle; - } - } - } -} - -void take_samples( - const 
multi_event_stream_state<raw_probe_info>& s, - const fvm_value_type* time, - fvm_value_type* sample_time, - fvm_value_type* sample_value) -{ - if (!s.n_streams()) { - return; - } - - constexpr int blockwidth = 128; - int nblock = 1+(s.n_streams()-1)/blockwidth; - kernels::take_samples<<<nblock, blockwidth>>>(s, time, sample_time, sample_value); -} - -} // namespace gpu -} // namespace arb - diff --git a/src/backends/gpu/kernels/take_samples.hpp b/src/backends/gpu/kernels/take_samples.hpp deleted file mode 100644 index f174c99cd6d210f4d7092a638f52763b8414681d..0000000000000000000000000000000000000000 --- a/src/backends/gpu/kernels/take_samples.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include <common_types.hpp> -#include <backends/event.hpp> -#include <backends/fvm_types.hpp> -#include <backends/multi_event_stream_state.hpp> - -namespace arb { -namespace gpu { - -void take_samples( - const multi_event_stream_state<raw_probe_info>& s, - const fvm_value_type* time, - fvm_value_type* sample_time, - fvm_value_type* sample_value); - -} // namespace gpu -} // namespace arb - diff --git a/src/backends/gpu/kernels/test_thresholds.hpp b/src/backends/gpu/kernels/test_thresholds.hpp deleted file mode 100644 index d7f9479e5e4615c7042f7b6d7c89d3b99a006dd3..0000000000000000000000000000000000000000 --- a/src/backends/gpu/kernels/test_thresholds.hpp +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include <backends/fvm_types.hpp> - -#include "stack.hpp" - -namespace arb { -namespace gpu { - -extern void test_thresholds( - const fvm_size_type* cv_to_cell, const fvm_value_type* t_after, const fvm_value_type* t_before, - int size, - stack_storage<threshold_crossing>& stack, - fvm_size_type* is_crossed, fvm_value_type* prev_values, - const fvm_size_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds); - -} // namespace gpu -} // namespace arb diff --git a/src/backends/gpu/kernels/time_ops.cu b/src/backends/gpu/kernels/time_ops.cu deleted file mode 
100644 index 7b78152219771abc7a01003a1d4e89395064daa8..0000000000000000000000000000000000000000 --- a/src/backends/gpu/kernels/time_ops.cu +++ /dev/null @@ -1,78 +0,0 @@ -#include <backends/fvm_types.hpp> - -namespace arb { -namespace gpu { - -namespace kernels { - template <typename T, typename I> - __global__ void update_time_to(I n, T* time_to, const T* time, T dt, T tmax) { - int i = threadIdx.x+blockIdx.x*blockDim.x; - if (i<n) { - auto t = time[i]+dt; - time_to[i] = t<tmax? t: tmax; - } - } - - template <typename T> - struct less { - __device__ __host__ - bool operator()(const T& a, const T& b) const { return a<b; } - }; - - // vector minus: x = y - z - template <typename T, typename I> - __global__ void vec_minus(I n, T* x, const T* y, const T* z) { - int i = threadIdx.x+blockIdx.x*blockDim.x; - if (i<n) { - x[i] = y[i]-z[i]; - } - } - - // vector gather: x[i] = y[index[i]] - template <typename T, typename I> - __global__ void gather(I n, T* x, const T* y, const I* index) { - int i = threadIdx.x+blockIdx.x*blockDim.x; - if (i<n) { - x[i] = y[index[i]]; - } - } -} - -void update_time_to(fvm_size_type n, - fvm_value_type* time_to, - const fvm_value_type* time, - fvm_value_type dt, - fvm_value_type tmax) -{ - if (!n) { - return; - } - - constexpr int blockwidth = 128; - int nblock = 1+(n-1)/blockwidth; - kernels::update_time_to<<<nblock, blockwidth>>> - (n, time_to, time, dt, tmax); -} - -void set_dt(fvm_size_type ncell, - fvm_size_type ncomp, - fvm_value_type* dt_cell, - fvm_value_type* dt_comp, - const fvm_value_type* time_to, - const fvm_value_type* time, - const fvm_size_type* cv_to_cell) -{ - if (!ncell || !ncomp) { - return; - } - - constexpr int blockwidth = 128; - int nblock = 1+(ncell-1)/blockwidth; - kernels::vec_minus<<<nblock, blockwidth>>>(ncell, dt_cell, time_to, time); - - nblock = 1+(ncomp-1)/blockwidth; - kernels::gather<<<nblock, blockwidth>>>(ncomp, dt_comp, dt_cell, cv_to_cell); -} - -} // namespace gpu -} // namespace arb diff --git 
a/src/backends/gpu/math.hpp b/src/backends/gpu/math.hpp new file mode 100644 index 0000000000000000000000000000000000000000..13abd7eaf2a8bd95df9f62dec43c11d626499eef --- /dev/null +++ b/src/backends/gpu/math.hpp @@ -0,0 +1,32 @@ +#pragma once + +// Implementations of mathematical operations required +// by generated CUDA mechanisms. + +namespace arb { +namespace gpu { + +__device__ +inline double exprelr(double x) { + if (1.0+x == 1.0) { + return 1.0; + } + return x/expm1(x); +} + +// Return minimum of the two values +template <typename T> +__device__ +inline T min(T lhs, T rhs) { + return lhs<rhs? lhs: rhs; +} + +// Return maximum of the two values +template <typename T> +__device__ +inline T max(T lhs, T rhs) { + return lhs<rhs? rhs: lhs; +} + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/kernels/assemble_matrix.cu b/src/backends/gpu/matrix_assemble.cu similarity index 94% rename from src/backends/gpu/kernels/assemble_matrix.cu rename to src/backends/gpu/matrix_assemble.cu index 17146ebccb99a7739712ccb1fde5422e5d295bc0..040724a3d5e8f236e3849c47598cd0340a8a2907 100644 --- a/src/backends/gpu/kernels/assemble_matrix.cu +++ b/src/backends/gpu/matrix_assemble.cu @@ -1,6 +1,7 @@ #include <backends/fvm_types.hpp> -#include "detail.hpp" +#include "cuda_common.hpp" +#include "matrix_common.hpp" namespace arb { namespace gpu { @@ -153,7 +154,7 @@ void assemble_matrix_flat( const fvm_value_type* current, const fvm_value_type* cv_capacitance, const fvm_value_type* area, - const fvm_size_type* cv_to_cell, + const fvm_index_type* cv_to_cell, const fvm_value_type* dt_cell, unsigned n) { @@ -161,7 +162,7 @@ void assemble_matrix_flat( const unsigned grid_dim = impl::block_count(n, block_dim); kernels::assemble_matrix_flat - <fvm_value_type, fvm_size_type> + <fvm_value_type, fvm_index_type> <<<grid_dim, block_dim>>> (d, rhs, invariant_d, voltage, current, cv_capacitance, area, cv_to_cell, dt_cell, n); @@ -176,13 +177,13 @@ void assemble_matrix_interleaved( 
const fvm_value_type* current, const fvm_value_type* cv_capacitance, const fvm_value_type* area, - const fvm_size_type* sizes, - const fvm_size_type* starts, - const fvm_size_type* matrix_to_cell, + const fvm_index_type* sizes, + const fvm_index_type* starts, + const fvm_index_type* matrix_to_cell, const fvm_value_type* dt_cell, unsigned padded_size, unsigned num_mtx) { - constexpr unsigned bd = impl::block_dim(); + constexpr unsigned bd = impl::matrices_per_block(); constexpr unsigned lw = impl::load_width(); constexpr unsigned block_dim = bd*lw; @@ -190,7 +191,7 @@ void assemble_matrix_interleaved( const unsigned grid_dim = impl::block_count(num_mtx*lw, block_dim); kernels::assemble_matrix_interleaved - <fvm_value_type, fvm_size_type, bd, lw, block_dim> + <fvm_value_type, fvm_index_type, bd, lw, block_dim> <<<grid_dim, block_dim>>> (d, rhs, invariant_d, voltage, current, cv_capacitance, area, sizes, starts, matrix_to_cell, diff --git a/src/backends/gpu/kernels/detail.hpp b/src/backends/gpu/matrix_common.hpp similarity index 62% rename from src/backends/gpu/kernels/detail.hpp rename to src/backends/gpu/matrix_common.hpp index 454dfe21602f13bdf0f4273ca147c11e1eb8e8de..13b4e7c426e2ca6dec7166cc1c09a42d8beb6f16 100644 --- a/src/backends/gpu/kernels/detail.hpp +++ b/src/backends/gpu/matrix_common.hpp @@ -1,8 +1,6 @@ #pragma once #include <cfloat> -#include <cmath> -#include <cstdint> #include <climits> #ifdef __CUDACC__ @@ -17,7 +15,7 @@ namespace gpu { namespace impl { // Number of matrices per block in block-interleaved storage HOST_DEVICE_IF_CUDA -constexpr inline unsigned block_dim() { +constexpr inline unsigned matrices_per_block() { return 32u; } @@ -34,26 +32,6 @@ constexpr inline unsigned matrix_padding() { return load_width(); } -// Number of threads per warp -// This has always been 32, however it may change in future NVIDIA gpus -HOST_DEVICE_IF_CUDA -constexpr inline unsigned threads_per_warp() { - return 32u; -} - -// The minimum number of bins required to 
store n values where the bins have -// dimension of block_size. -HOST_DEVICE_IF_CUDA -constexpr inline unsigned block_count(unsigned n, unsigned block_size) { - return (n+block_size-1)/block_size; -} - -// The smallest size of a buffer required to store n items in such that the -// buffer has size that is a multiple of block_dim. -constexpr inline unsigned padded_size(unsigned n, unsigned block_dim) { - return block_dim*block_count(n, block_dim); -} - // Placeholders to use for mark padded locations in data structures that use // padding. Using such markers makes it easier to test that padding is // performed correctly. @@ -78,13 +56,6 @@ constexpr bool is_npos(T v) { return v == npos<T>(); } -/// Cuda lerp by u on [a,b]: (1-u)*a + u*b. -template <typename T> -HOST_DEVICE_IF_CUDA -inline T lerp(T a, T b, T u) { - return std::fma(u, b, std::fma(-u, a, a)); -} - } // namespace impl } // namespace gpu diff --git a/src/backends/gpu/kernels/interleave.cu b/src/backends/gpu/matrix_interleave.cu similarity index 62% rename from src/backends/gpu/kernels/interleave.cu rename to src/backends/gpu/matrix_interleave.cu index 4bf923f7b71fc4d8f3778049566a474fe1a48d63..daad91f87afeb653f3526c2ab7affa37e2a9faed 100644 --- a/src/backends/gpu/kernels/interleave.cu +++ b/src/backends/gpu/matrix_interleave.cu @@ -1,6 +1,7 @@ #include <backends/fvm_types.hpp> -#include "detail.hpp" -#include "interleave.hpp" + +#include "matrix_common.hpp" +#include "matrix_interleave.hpp" namespace arb { namespace gpu { @@ -9,16 +10,16 @@ namespace gpu { void flat_to_interleaved( const fvm_value_type* in, fvm_value_type* out, - const fvm_size_type* sizes, - const fvm_size_type* starts, + const fvm_index_type* sizes, + const fvm_index_type* starts, unsigned padded_size, unsigned num_vec) { - constexpr unsigned BlockWidth = impl::block_dim(); + constexpr unsigned BlockWidth = impl::matrices_per_block(); constexpr unsigned LoadWidth = impl::load_width(); flat_to_interleaved - <fvm_value_type, 
fvm_size_type, BlockWidth, LoadWidth> + <fvm_value_type, fvm_index_type, BlockWidth, LoadWidth> (in, out, sizes, starts, padded_size, num_vec); } @@ -26,16 +27,16 @@ void flat_to_interleaved( void interleaved_to_flat( const fvm_value_type* in, fvm_value_type* out, - const fvm_size_type* sizes, - const fvm_size_type* starts, + const fvm_index_type* sizes, + const fvm_index_type* starts, unsigned padded_size, unsigned num_vec) { - constexpr unsigned BlockWidth = impl::block_dim(); + constexpr unsigned BlockWidth = impl::matrices_per_block(); constexpr unsigned LoadWidth = impl::load_width(); interleaved_to_flat - <fvm_value_type, fvm_size_type, BlockWidth, LoadWidth> + <fvm_value_type, fvm_index_type, BlockWidth, LoadWidth> (in, out, sizes, starts, padded_size, num_vec); } diff --git a/src/backends/gpu/kernels/interleave.hpp b/src/backends/gpu/matrix_interleave.hpp similarity index 98% rename from src/backends/gpu/kernels/interleave.hpp rename to src/backends/gpu/matrix_interleave.hpp index 76e41af6d7d9d99687276be4e4a02b5d17f9cd82..dc9c5674e4797f0732648ac5559518c2d62840ff 100644 --- a/src/backends/gpu/kernels/interleave.hpp +++ b/src/backends/gpu/matrix_interleave.hpp @@ -1,6 +1,7 @@ #pragma once -#include "detail.hpp" +#include "cuda_common.hpp" +#include "matrix_common.hpp" namespace arb { namespace gpu { diff --git a/src/backends/gpu/kernels/solve_matrix.cu b/src/backends/gpu/matrix_solve.cu similarity index 86% rename from src/backends/gpu/kernels/solve_matrix.cu rename to src/backends/gpu/matrix_solve.cu index 30026e9fe7b12bc1d9c8566b95d3137c4554ebaf..eaf5724ff9b451c0d50cdd56b885a97b7c0d9d16 100644 --- a/src/backends/gpu/kernels/solve_matrix.cu +++ b/src/backends/gpu/matrix_solve.cu @@ -1,6 +1,7 @@ #include <cassert> -#include "detail.hpp" +#include "cuda_common.hpp" +#include "matrix_common.hpp" #include <backends/fvm_types.hpp> namespace arb { @@ -89,14 +90,14 @@ void solve_matrix_flat( fvm_value_type* rhs, fvm_value_type* d, const fvm_value_type* u, - const 
fvm_size_type* p, - const fvm_size_type* cell_cv_divs, + const fvm_index_type* p, + const fvm_index_type* cell_cv_divs, int num_mtx) { constexpr unsigned block_dim = 128; const unsigned grid_dim = impl::block_count(num_mtx, block_dim); kernels::solve_matrix_flat - <fvm_value_type, fvm_size_type> + <fvm_value_type, fvm_index_type> <<<grid_dim, block_dim>>> (rhs, d, u, p, cell_cv_divs, num_mtx); } @@ -105,14 +106,15 @@ void solve_matrix_interleaved( fvm_value_type* rhs, fvm_value_type* d, const fvm_value_type* u, - const fvm_size_type* p, - const fvm_size_type* sizes, + const fvm_index_type* p, + const fvm_index_type* sizes, int padded_size, int num_mtx) { - const unsigned grid_dim = impl::block_count(num_mtx, impl::block_dim()); - kernels::solve_matrix_interleaved<fvm_value_type, fvm_size_type, impl::block_dim()> - <<<grid_dim, impl::block_dim()>>> + constexpr unsigned block_dim = impl::matrices_per_block(); + const unsigned grid_dim = impl::block_count(num_mtx, block_dim); + kernels::solve_matrix_interleaved<fvm_value_type, fvm_index_type, block_dim> + <<<grid_dim, block_dim>>> (rhs, d, u, p, sizes, padded_size, num_mtx); } diff --git a/src/backends/gpu/matrix_state_flat.hpp b/src/backends/gpu/matrix_state_flat.hpp index 380ab804ceda859f554d551039e2765436b9e717..b65d6efe38ef0b19b42bdd056e53b0d0ccb6f0c3 100644 --- a/src/backends/gpu/matrix_state_flat.hpp +++ b/src/backends/gpu/matrix_state_flat.hpp @@ -10,12 +10,14 @@ namespace arb { namespace gpu { +// CUDA implementation entry points: + void solve_matrix_flat( fvm_value_type* rhs, fvm_value_type* d, const fvm_value_type* u, - const fvm_size_type* p, - const fvm_size_type* cell_cv_divs, + const fvm_index_type* p, + const fvm_index_type* cell_cv_divs, int num_mtx); void assemble_matrix_flat( @@ -26,7 +28,7 @@ void assemble_matrix_flat( const fvm_value_type* current, const fvm_value_type* cv_capacitance, const fvm_value_type* cv_area, - const fvm_size_type* cv_to_cell, + const fvm_index_type* cv_to_cell, const 
fvm_value_type* dt_cell, unsigned n); @@ -34,10 +36,10 @@ void assemble_matrix_flat( template <typename T, typename I> struct matrix_state_flat { using value_type = T; - using size_type = I; + using index_type = I; using array = memory::device_vector<value_type>; - using iarray = memory::device_vector<size_type>; + using iarray = memory::device_vector<index_type>; using view = typename array::view_type; using const_view = typename array::const_view_type; @@ -59,8 +61,8 @@ struct matrix_state_flat { matrix_state_flat() = default; - matrix_state_flat(const std::vector<size_type>& p, - const std::vector<size_type>& cell_cv_divs, + matrix_state_flat(const std::vector<index_type>& p, + const std::vector<index_type>& cell_cv_divs, const std::vector<value_type>& cv_cap, const std::vector<value_type>& face_cond, const std::vector<value_type>& area): @@ -76,13 +78,13 @@ struct matrix_state_flat { EXPECTS(cv_cap.size() == size()); EXPECTS(face_cond.size() == size()); EXPECTS(area.size() == size()); - EXPECTS(cell_cv_divs.back() == size()); + EXPECTS(cell_cv_divs.back() == (index_type)size()); EXPECTS(cell_cv_divs.size() > 1u); using memory::make_const_view; auto n = d.size(); - std::vector<size_type> cv_to_cell_tmp(n, 0); + std::vector<index_type> cv_to_cell_tmp(n, 0); std::vector<value_type> invariant_d_tmp(n, 0); std::vector<value_type> u_tmp(n, 0); @@ -94,7 +96,7 @@ struct matrix_state_flat { invariant_d_tmp[p[i]] += gij; } - size_type ci = 0; + index_type ci = 0; for (auto cv_span: util::partition_view(cell_cv_divs)) { util::fill(util::subrange_view(cv_to_cell_tmp, cv_span), ci); ++ci; diff --git a/src/backends/gpu/matrix_state_interleaved.hpp b/src/backends/gpu/matrix_state_interleaved.hpp index c7db9c42ac337430df4ad001cf8954e4279dfaeb..7463918cc63a87d80835565c9939bdc57cf591cc 100644 --- a/src/backends/gpu/matrix_state_interleaved.hpp +++ b/src/backends/gpu/matrix_state_interleaved.hpp @@ -1,6 +1,7 @@ #pragma once #include <backends/fvm_types.hpp> +#include <math.hpp> 
#include <memory/memory.hpp> #include <util/debug.hpp> #include <util/span.hpp> @@ -8,7 +9,8 @@ #include <util/rangeutil.hpp> #include <util/indirect.hpp> -#include "kernels/detail.hpp" +#include "cuda_common.hpp" +#include "matrix_common.hpp" namespace arb { namespace gpu { @@ -22,9 +24,9 @@ void assemble_matrix_interleaved( const fvm_value_type* current, const fvm_value_type* cv_capacitance, const fvm_value_type* area, - const fvm_size_type* sizes, - const fvm_size_type* starts, - const fvm_size_type* matrix_to_cell, + const fvm_index_type* sizes, + const fvm_index_type* starts, + const fvm_index_type* matrix_to_cell, const fvm_value_type* dt_cell, unsigned padded_size, unsigned num_mtx); @@ -33,8 +35,8 @@ void solve_matrix_interleaved( fvm_value_type* rhs, fvm_value_type* d, const fvm_value_type* u, - const fvm_size_type* p, - const fvm_size_type* sizes, + const fvm_index_type* p, + const fvm_index_type* sizes, int padded_size, int num_mtx); @@ -42,8 +44,8 @@ void solve_matrix_interleaved( void flat_to_interleaved( const fvm_value_type* in, fvm_value_type* out, - const fvm_size_type* sizes, - const fvm_size_type* starts, + const fvm_index_type* sizes, + const fvm_index_type* starts, unsigned padded_size, unsigned num_vec); @@ -51,8 +53,8 @@ void flat_to_interleaved( void interleaved_to_flat( const fvm_value_type* in, fvm_value_type* out, - const fvm_size_type* sizes, - const fvm_size_type* starts, + const fvm_index_type* sizes, + const fvm_index_type* starts, unsigned padded_size, unsigned num_vec); @@ -85,10 +87,10 @@ std::vector<T> flat_to_interleaved( template <typename T, typename I> struct matrix_state_interleaved { using value_type = T; - using size_type = I; + using index_type = I; using array = memory::device_vector<value_type>; - using iarray = memory::device_vector<size_type>; + using iarray = memory::device_vector<index_type>; using const_view = typename array::const_view_type; @@ -137,15 +139,15 @@ struct matrix_state_interleaved { // of indexes and 
data structures in the constructor. // cv_cap // [pF] // face_cond // [μS] - matrix_state_interleaved(const std::vector<size_type>& p, - const std::vector<size_type>& cell_cv_divs, + matrix_state_interleaved(const std::vector<index_type>& p, + const std::vector<index_type>& cell_cv_divs, const std::vector<value_type>& cv_cap, const std::vector<value_type>& face_cond, const std::vector<value_type>& area) { EXPECTS(cv_cap.size() == p.size()); EXPECTS(face_cond.size() == p.size()); - EXPECTS(cell_cv_divs.back() == p.size()); + EXPECTS(cell_cv_divs.back() == (index_type)p.size()); // Just because you never know. EXPECTS(cell_cv_divs.size() <= UINT_MAX); @@ -154,7 +156,7 @@ struct matrix_state_interleaved { using util::indirect_view; // Convenience for commonly used type in this routine. - using svec = std::vector<size_type>; + using svec = std::vector<index_type>; // // Sort matrices in descending order of size. @@ -171,7 +173,7 @@ struct matrix_state_interleaved { svec perm(num_mtx); std::iota(perm.begin(), perm.end(), 0); // calculate the permutation of matrices to put the in ascending size - util::stable_sort_by(perm, [&sizes](size_type i){ return sizes[i]; }); + util::stable_sort_by(perm, [&sizes](index_type i){ return sizes[i]; }); std::reverse(perm.begin(), perm.end()); svec sizes_p = util::assign_from(indirect_view(sizes, perm)); @@ -182,31 +184,31 @@ struct matrix_state_interleaved { // // Calculate dimensions required to store matrices. 
// - using impl::block_dim; + constexpr unsigned block_dim = impl::matrices_per_block(); using impl::matrix_padding; // To start, take simplest approach of assuming all matrices stored // in blocks of the same dimension: padded_size - padded_size = impl::padded_size(sizes_p[0], matrix_padding()); - const auto num_blocks = impl::block_count(num_mtx, block_dim()); + padded_size = math::round_up(sizes_p[0], matrix_padding()); + const auto num_blocks = impl::block_count(num_mtx, block_dim); - const auto total_storage = num_blocks*block_dim()*padded_size; + const auto total_storage = num_blocks*block_dim*padded_size; // calculate the interleaved and permuted p vector - constexpr auto npos = std::numeric_limits<size_type>::max(); - std::vector<size_type> p_tmp(total_storage, npos); + constexpr auto npos = std::numeric_limits<index_type>::max(); + std::vector<index_type> p_tmp(total_storage, npos); for (auto mtx: make_span(0, num_mtx)) { - auto block = mtx/block_dim(); - auto lane = mtx%block_dim(); + auto block = mtx/block_dim; + auto lane = mtx%block_dim; auto len = sizes_p[mtx]; auto src = cell_to_cv_p[mtx]; - auto dst = block*(block_dim()*padded_size) + lane; + auto dst = block*(block_dim*padded_size) + lane; for (auto i: make_span(0, len)) { // the p indexes are always relative to the start of the p vector. // the addition and subtraction of dst and src respectively is to convert from // the original offset to the new padded and permuted offset. - p_tmp[dst+block_dim()*i] = dst + block_dim()*(p[src+i]-src); + p_tmp[dst+block_dim*i] = dst + block_dim*(p[src+i]-src); } } @@ -235,7 +237,7 @@ struct matrix_state_interleaved { // memory, for use as an rvalue in an assignemt to a device vector. 
auto interleave = [&] (std::vector<T>const& x) { return memory::on_gpu( - flat_to_interleaved(x, sizes_p, cell_to_cv_p, block_dim(), num_mtx, padded_size)); + flat_to_interleaved(x, sizes_p, cell_to_cv_p, block_dim, num_mtx, padded_size)); }; u = interleave(u_tmp); invariant_d = interleave(invariant_d_tmp); diff --git a/src/backends/gpu/mechanism.cpp b/src/backends/gpu/mechanism.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8cf643cd2de0432964748b0f49372dff19d8b948 --- /dev/null +++ b/src/backends/gpu/mechanism.cpp @@ -0,0 +1,162 @@ +#include <algorithm> +#include <cstddef> +#include <cmath> +#include <string> +#include <utility> +#include <vector> + +#include <common_types.hpp> +#include <math.hpp> +#include <mechanism.hpp> +#include <memory/memory.hpp> +#include <util/index_into.hpp> +#include <util/optional.hpp> +#include <util/maputil.hpp> +#include <util/range.hpp> +#include <util/span.hpp> + +#include <backends/fvm_types.hpp> +#include <backends/gpu/mechanism.hpp> +#include <backends/gpu/fvm.hpp> + +namespace arb { +namespace gpu { + +using memory::make_const_view; +using util::value_by_key; +using util::make_span; + +template <typename T> +memory::device_view<T> device_view(T* ptr, std::size_t n) { + return memory::device_view<T>(ptr, n); +} + +template <typename T> +memory::const_device_view<T> device_view(const T* ptr, std::size_t n) { + return memory::const_device_view<T>(ptr, n); +} + +// The derived class (typically generated code from modcc) holds pointers that need +// to be set to point inside the shared state, or into the allocated parameter/variable +// data block. 
+ +void mechanism::instantiate(fvm_size_type id, backend::shared_state& shared, const layout& pos_data) { + mechanism_id_ = id; + width_ = pos_data.cv.size(); + + unsigned alignment = std::max(array::alignment(), iarray::alignment()); + auto width_padded_ = math::round_up(width_, alignment); /* local variable (declared here with auto), despite the member-style trailing underscore */ + + // Assign non-owning views onto shared state: + + mechanism_ppack_base* pp = ppack_ptr(); // From derived class instance. + + pp->vec_ci_ = shared.cv_to_cell.data(); + pp->vec_t_ = shared.time.data(); + pp->vec_t_to_ = shared.time_to.data(); + pp->vec_dt_ = shared.dt_cv.data(); + + pp->vec_v_ = shared.voltage.data(); + pp->vec_i_ = shared.current_density.data(); + + auto ion_state_tbl = ion_state_table(); + n_ion_ = ion_state_tbl.size(); + + for (auto i: ion_state_tbl) { + util::optional<ion_state&> oion = value_by_key(shared.ion_data, i.first); + if (!oion) { + throw std::logic_error("mechanism holds ion with no corresponding shared state"); // NOTE(review): <stdexcept> is not included directly by this file; presumably pulled in transitively -- confirm. + } + + ion_state_view& ion_view = *i.second; + ion_view.current_density = oion->iX_.data(); + ion_view.reversal_potential = oion->eX_.data(); + ion_view.internal_concentration = oion->Xi_.data(); + ion_view.external_concentration = oion->Xo_.data(); + } + + event_stream_ptr_ = &shared.deliverable_events; + + // If there are no sites (is this ever meaningful?) there is nothing more to do. + if (width_==0) { + return; + } + + // Allocate and initialize state and parameter vectors with default values. + // (First sub-array of data_ is used for weight_.) + + auto fields = field_table(); + std::size_t n_field = fields.size(); + + data_ = array((1+n_field)*width_padded_, NAN); + memory::copy(make_const_view(pos_data.weight), device_view(data_.data(), width_)); + pp->weight_ = data_.data(); + + for (auto i: make_span(0, n_field)) { + // Take reference to corresponding derived (generated) mechanism value pointer member. 
+ fvm_value_type*& field_ptr = *std::get<1>(fields[i]); + field_ptr = data_.data()+(i+1)*width_padded_; + + if (auto opt_value = value_by_key(field_default_table(), fields[i].first)) { + memory::fill(device_view(field_ptr, width_), *opt_value); + } + } + + // Allocate and initialize index vectors, viz. node_index_ and any ion indices. + // (First sub-array of indices_ is used for node_index_.) + + indices_ = iarray((1+n_ion_)*width_padded_); + + memory::copy(make_const_view(pos_data.cv), device_view(indices_.data(), width_)); + pp->node_index_ = indices_.data(); + + auto ion_index_tbl = ion_index_table(); + EXPECTS(n_ion_==ion_index_tbl.size()); + + for (auto i: make_span(0, n_ion_)) { + util::optional<ion_state&> oion = value_by_key(shared.ion_data, ion_index_tbl[i].first); + if (!oion) { + throw std::logic_error("mechanism holds ion with no corresponding shared state"); + } + + auto indices = util::index_into(pos_data.cv, memory::on_host(oion->node_index_)); + std::vector<index_type> mech_ion_index(indices.begin(), indices.end()); + + // Take reference to derived (generated) mechanism ion index pointer. + auto& ion_index_ptr = *ion_index_tbl[i].second; + auto index_start = indices_.data()+(i+1)*width_padded_; + ion_index_ptr = index_start; + memory::copy(make_const_view(mech_ion_index), device_view(index_start, width_)); + } +} + +void mechanism::set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) { + if (auto opt_ptr = value_by_key(field_table(), key)) { + if (values.size()!=width_) { + throw std::logic_error("internal error: mechanism parameter size mismatch"); + } + + if (width_>0) { + // Retrieve corresponding derived (generated) mechanism value pointer member. 
+ value_type* field_ptr = *opt_ptr.value(); + memory::copy(make_const_view(values), device_view(field_ptr, width_)); + } + } + else { + throw std::logic_error("internal error: no such mechanism parameter"); + } +} + +void mechanism::set_global(const std::string& key, fvm_value_type value) { + if (auto opt_ptr = value_by_key(global_table(), key)) { + // Take reference to corresponding derived (generated) mechanism value member. + value_type& global = *opt_ptr.value(); + global = value; + } + else { + throw std::logic_error("internal error: no such mechanism global"); + } +} + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/mechanism.hpp b/src/backends/gpu/mechanism.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f4adf68e72023c8ac9af532479e09ab191f41fa0 --- /dev/null +++ b/src/backends/gpu/mechanism.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include <algorithm> +#include <cstddef> +#include <cmath> +#include <string> +#include <utility> +#include <vector> + +#include <backends/fvm_types.hpp> +#include <common_types.hpp> +#include <mechanism.hpp> + +#include <backends/gpu/fvm.hpp> +#include <backends/gpu/gpu_store_types.hpp> +#include <backends/gpu/mechanism_ppack_base.hpp> + +namespace arb { +namespace gpu { + +// Base class for all generated mechanisms for gpu back-end. 
+ +class mechanism: public arb::concrete_mechanism<arb::gpu::backend> { +public: + using value_type = fvm_value_type; + using index_type = fvm_index_type; + using size_type = fvm_size_type; + +protected: + using backend = arb::gpu::backend; + using deliverable_event_stream = backend::deliverable_event_stream; + + using array = arb::gpu::array; + using iarray = arb::gpu::iarray; + + using ion_state_vuew = arb::gpu::ion_state_view; /* NOTE(review): alias name looks like a typo for ion_state_view -- confirm and rename */ + +public: + std::size_t size() const override { + return width_; + } + + std::size_t memory() const override { + std::size_t s = object_sizeof(); + + s += sizeof(value_type) * data_.size(); + s += sizeof(index_type) * indices_.size(); + return s; + } + + void instantiate(fvm_size_type id, backend::shared_state& shared, const layout& w) override; + + void deliver_events() override { + // Delegate to derived class, passing in event queue state. + deliver_events(event_stream_ptr_->marked_events()); + } + + void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; + + void set_global(const std::string& key, fvm_value_type value) override; + +protected: + size_type width_ = 0; // Instance width (number of CVs/sites) + size_type n_ion_ = 0; // Number of entries in ion_state_table(); set by instantiate() + + // Returns pointer to (derived) parameter-pack object that holds: + // * pointers to shared cell state `vec_ci_` et al., + // * pointer to mechanism weights `weight_`, + // * pointer to mechanism node indices `node_index_`, + // * mechanism global scalars and pointers to mechanism range parameters. + // * mechanism ion_state_view objects and pointers to mechanism ion indices. + + virtual mechanism_ppack_base* ppack_ptr() = 0; + + deliverable_event_stream* event_stream_ptr_; // Non-owning; set by instantiate() to &shared.deliverable_events (no initializer here) + + // Bulk storage for index vectors and state and parameter variables. + + iarray indices_; + array data_; + + // Generated mechanism field, global and ion table lookup types. 
+ // First component is name (ion kind for the ion tables), second is pointer to corresponding member in + // the mechanism's parameter pack, or for field_default_table, + // the scalar value used to initialize the field. + + using global_table_entry = std::pair<const char*, value_type*>; + using mechanism_global_table = std::vector<global_table_entry>; + + using field_table_entry = std::pair<const char*, value_type**>; + using mechanism_field_table = std::vector<field_table_entry>; + + using field_default_entry = std::pair<const char*, value_type>; + using mechanism_field_default_table = std::vector<field_default_entry>; + + using ion_state_entry = std::pair<ionKind, ion_state_view*>; + using mechanism_ion_state_table = std::vector<ion_state_entry>; + + using ion_index_entry = std::pair<ionKind, const index_type**>; + using mechanism_ion_index_table = std::vector<ion_index_entry>; + + // Generated mechanisms must implement the following methods, together with + // fingerprint(), clone(), kind(), nrn_init(), nrn_state(), nrn_current() + // and deliver_events() (if required) from arb::mechanism. + + // Member tables: introspection into derived mechanism fields, views etc. + // Default implementations return empty tables, i.e. no fields/globals/ions. + + virtual mechanism_field_table field_table() { return {}; } + virtual mechanism_field_default_table field_default_table() { return {}; } + virtual mechanism_global_table global_table() { return {}; } + virtual mechanism_ion_state_table ion_state_table() { return {}; } + virtual mechanism_ion_index_table ion_index_table() { return {}; } + + // Report raw size in bytes of mechanism object. 
+ + virtual std::size_t object_sizeof() const = 0; + + // Event delivery, given event queue state: + + virtual void deliver_events(deliverable_event_stream::state) {}; +}; + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/mechanism_ppack_base.hpp b/src/backends/gpu/mechanism_ppack_base.hpp new file mode 100644 index 0000000000000000000000000000000000000000..68095dd30ed5cdf9c08e212b7848bb41d74a736e --- /dev/null +++ b/src/backends/gpu/mechanism_ppack_base.hpp @@ -0,0 +1,42 @@ +#pragma once + +// Base class for parameter packs for GPU generated kernels: +// will be included by .cu generated sources. + +#include <backends/fvm_types.hpp> + +namespace arb { +namespace gpu { + +// Derived ppack structs may have ion_state_view fields: + +struct ion_state_view { + using value_type = fvm_value_type; + using index_type = fvm_index_type; + + value_type* current_density; + value_type* reversal_potential; + value_type* internal_concentration; + value_type* external_concentration; +}; + +// Parameter pack base (pointers are assigned in gpu::mechanism::instantiate()): + +struct mechanism_ppack_base { + using value_type = fvm_value_type; + using index_type = fvm_index_type; + using ion_state_view = ::arb::gpu::ion_state_view; + + const index_type* vec_ci_; // view of shared.cv_to_cell + const value_type* vec_t_; // view of shared.time + const value_type* vec_t_to_; // view of shared.time_to + const value_type* vec_dt_; // view of shared.dt_cv + const value_type* vec_v_; // view of shared.voltage + value_type* vec_i_; // view of shared.current_density + + const index_type* node_index_; // mechanism-owned copy of the layout's cv indices + const value_type* weight_; // mechanism-owned copy of the layout's weights +}; + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/multi_event_stream.cu b/src/backends/gpu/multi_event_stream.cu index 7dac8501e5893b4c3180eb4c4033aef1c33704c5..349e9da89c746607a647879c899fee2bd0f5362f 100644 --- 
a/src/backends/gpu/multi_event_stream.cu +++ b/src/backends/gpu/multi_event_stream.cu @@ -5,19 +5,21 @@ #include <memory/copy.hpp> #include <util/rangeutil.hpp> +#include "cuda_common.hpp" + namespace arb { namespace gpu { namespace kernels { template <typename T, typename I> __global__ void mark_until_after( - I n, + unsigned n, 
I* mark, const I* span_end, const T* ev_time, const T* t_until) { - I i = threadIdx.x+blockIdx.x*blockDim.x; + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; if (i<n) { auto t = t_until[i]; auto end = span_end[i]; @@ -31,13 +33,13 @@ namespace kernels { template <typename T, typename I> __global__ void mark_until( - I n, + unsigned n, I* mark, const I* span_end, const T* ev_time, const T* t_until) { - I i = threadIdx.x+blockIdx.x*blockDim.x; + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; if (i<n) { auto t = t_until[i]; auto end = span_end[i]; @@ -49,15 +51,15 @@ namespace kernels { } } - template <typename T, typename I> + template <typename I> __global__ void drop_marked_events( - I n, + unsigned n, I* n_nonempty, I* span_begin, const I* span_end, const I* mark) { - I i = threadIdx.x+blockIdx.x*blockDim.x; + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; if (i<n) { bool emptied = (span_begin[i]<span_end[i] && mark[i]==span_end[i]); span_begin[i] = mark[i]; @@ -69,13 +71,13 @@ namespace kernels { template <typename T, typename I> __global__ void event_time_if_before( - I n, + unsigned n, const I* span_begin, const I* span_end, const T* ev_time, T* t_until) { - I i = threadIdx.x+blockIdx.x*blockDim.x; + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; if (i<n) { if (span_begin[i]<span_end[i]) { auto ev_t = ev_time[span_begin[i]]; @@ -99,37 +101,42 @@ void multi_event_stream_base::clear() { void multi_event_stream_base::mark_until_after(const_view t_until) { EXPECTS(n_streams()==util::size(t_until)); - constexpr int blockwidth = 128; - int nblock = 1+(n_stream_-1)/blockwidth; - kernels::mark_until_after<value_type, size_type><<<nblock, blockwidth>>>( - n_stream_, mark_.data(), span_end_.data(), ev_time_.data(), t_until.data()); + constexpr int block_dim = 128; + + unsigned n = n_stream_; + int nblock = impl::block_count(n, block_dim); + kernels::mark_until_after<<<nblock, block_dim>>>( + n, mark_.data(), span_end_.data(), ev_time_.data(), t_until.data()); } // 
Designate for processing events `ev` at head of each event stream `i` // while `t_until[i]` > `event_time(ev)`. void multi_event_stream_base::mark_until(const_view t_until) { EXPECTS(n_streams()==util::size(t_until)); + constexpr int block_dim = 128; - constexpr int blockwidth = 128; - int nblock = 1+(n_stream_-1)/blockwidth; - kernels::mark_until<value_type, size_type><<<nblock, blockwidth>>>( - n_stream_, mark_.data(), span_end_.data(), ev_time_.data(), t_until.data()); + unsigned n = n_stream_; + int nblock = impl::block_count(n, block_dim); + kernels::mark_until<<<nblock, block_dim>>>( + n, mark_.data(), span_end_.data(), ev_time_.data(), t_until.data()); } // Remove marked events from front of each event stream. void multi_event_stream_base::drop_marked_events() { - constexpr int blockwidth = 128; - int nblock = 1+(n_stream_-1)/blockwidth; - kernels::drop_marked_events<value_type, size_type><<<nblock, blockwidth>>>( - n_stream_, n_nonempty_stream_.data(), span_begin_.data(), span_end_.data(), mark_.data()); + constexpr int block_dim = 128; + + unsigned n = n_stream_; + int nblock = impl::block_count(n, block_dim); + kernels::drop_marked_events<<<nblock, block_dim>>>( + n, n_nonempty_stream_.data(), span_begin_.data(), span_end_.data(), mark_.data()); } // If the head of `i`th event stream exists and has time less than `t_until[i]`, set // `t_until[i]` to the event time. 
void multi_event_stream_base::event_time_if_before(view t_until) { - constexpr int blockwidth = 128; - int nblock = 1+(n_stream_-1)/blockwidth; - kernels::event_time_if_before<value_type, size_type><<<nblock, blockwidth>>>( + constexpr int block_dim = 128; + int nblock = impl::block_count(n_stream_, block_dim); + kernels::event_time_if_before<<<nblock, block_dim>>>( n_stream_, span_begin_.data(), span_end_.data(), ev_time_.data(), t_until.data()); } diff --git a/src/backends/gpu/multi_event_stream.hpp b/src/backends/gpu/multi_event_stream.hpp index 3b9e28bf78e32e5a32a931d5d54bb31e0d28ea94..842c88cd8b501861ba1cff0b194d7bba0c9d65d6 100644 --- a/src/backends/gpu/multi_event_stream.hpp +++ b/src/backends/gpu/multi_event_stream.hpp @@ -20,9 +20,10 @@ class multi_event_stream_base { public: using size_type = cell_size_type; using value_type = fvm_value_type; + using index_type = fvm_index_type; using array = memory::device_vector<value_type>; - using iarray = memory::device_vector<size_type>; + using iarray = memory::device_vector<index_type>; using const_view = array::const_view_type; using view = array::view_type; @@ -83,11 +84,11 @@ protected: tmp_divs_.reserve(n_stream_+1); size_type n_nonempty = 0; - size_type ev_begin_i = 0; - size_type ev_i = 0; + index_type ev_begin_i = 0; + index_type ev_i = 0; tmp_divs_.push_back(ev_i); for (size_type s = 0; s<n_stream_; ++s) { - while (ev_i<n_ev && event_index(staged[ev_i])<s+1) ++ev_i; + while ((size_type)ev_i<n_ev && (size_type)event_index(staged[ev_i])<s+1) ++ev_i; // Within a subrange of events with the same index, events should // be sorted by time. 
@@ -113,7 +114,7 @@ protected: // Host-side vectors for staging values in init(): std::vector<value_type> tmp_ev_time_; - std::vector<size_type> tmp_divs_; + std::vector<index_type> tmp_divs_; }; template <typename Event> diff --git a/src/backends/gpu/kernels/reduce_by_key.hpp b/src/backends/gpu/reduce_by_key.hpp similarity index 98% rename from src/backends/gpu/kernels/reduce_by_key.hpp rename to src/backends/gpu/reduce_by_key.hpp index 29b054b1c16c2f767b325c8e3145e426b2fc8c7d..f7e8e4c80ffc35c86d1f0597b4e9b7c64578cff1 100644 --- a/src/backends/gpu/kernels/reduce_by_key.hpp +++ b/src/backends/gpu/reduce_by_key.hpp @@ -1,8 +1,8 @@ #pragma once #include <cstdint> -#include "detail.hpp" -#include <backends/gpu/intrinsics.hpp> +#include "cuda_atomic.hpp" +#include "cuda_common.hpp" namespace arb { namespace gpu { diff --git a/src/backends/gpu/shared_state.cpp b/src/backends/gpu/shared_state.cpp new file mode 100644 index 0000000000000000000000000000000000000000..183abd9f43899a7e782ce47a63545a4e9656a428 --- /dev/null +++ b/src/backends/gpu/shared_state.cpp @@ -0,0 +1,189 @@ +#include <cstddef> +#include <vector> + +#include <constants.hpp> +#include <ion.hpp> +#include <memory/wrappers.hpp> +#include <util/rangeutil.hpp> + +#include <backends/event.hpp> +#include <backends/fvm_types.hpp> +#include <backends/multi_event_stream_state.hpp> + +#include <backends/gpu/gpu_store_types.hpp> +#include <backends/gpu/shared_state.hpp> + +using arb::memory::make_const_view; + +namespace arb { +namespace gpu { + +// CUDA implementation entry points: + +void init_concentration_impl( + std::size_t n, fvm_value_type* Xi, fvm_value_type* Xo, const fvm_value_type* weight_Xi, + const fvm_value_type* weight_Xo, fvm_value_type iconc, fvm_value_type econc); + +void nernst_impl( + std::size_t n, fvm_value_type factor, + const fvm_value_type* Xi, const fvm_value_type* Xo, fvm_value_type* eX); + +void update_time_to_impl( + std::size_t n, fvm_value_type* time_to, const fvm_value_type* time, + 
fvm_value_type dt, fvm_value_type tmax); + +void update_time_to_impl( + std::size_t n, fvm_value_type* time_to, const fvm_value_type* time, + fvm_value_type dt, fvm_value_type tmax); + +void set_dt_impl( + fvm_size_type ncell, fvm_size_type ncomp, fvm_value_type* dt_cell, fvm_value_type* dt_comp, + const fvm_value_type* time_to, const fvm_value_type* time, const fvm_index_type* cv_to_cell); + +void take_samples_impl( + const multi_event_stream_state<raw_probe_info>& s, + const fvm_value_type* time, fvm_value_type* sample_time, fvm_value_type* sample_value); + +// GPU-side minmax: consider CUDA kernel replacement. +std::pair<fvm_value_type, fvm_value_type> minmax_value_impl(fvm_size_type n, const fvm_value_type* v) { + auto v_copy = memory::on_host(memory::const_device_view<fvm_value_type>(v, n)); + return util::minmax_value(v_copy); +} + +// Ion state methods: + +ion_state::ion_state( + ion_info info, + const std::vector<fvm_index_type>& cv, + const std::vector<fvm_value_type>& iconc_norm_area, + const std::vector<fvm_value_type>& econc_norm_area, + unsigned // alignment/padding ignored. 
+): + node_index_(make_const_view(cv)), + iX_(cv.size(), NAN), + eX_(cv.size(), NAN), + Xi_(cv.size(), NAN), + Xo_(cv.size(), NAN), + weight_Xi_(make_const_view(iconc_norm_area)), + weight_Xo_(make_const_view(econc_norm_area)), + charge(info.charge), + default_int_concentration(info.default_int_concentration), + default_ext_concentration(info.default_ext_concentration) +{ + EXPECTS(node_index_.size()==weight_Xi_.size()); + EXPECTS(node_index_.size()==weight_Xo_.size()); +} + +void ion_state::nernst(fvm_value_type temperature_K) { + // Nernst equation: reversal potential eX given by: + // + // eX = RT/zF * ln(Xo/Xi) + // + // where: + // R: universal gas constant 8.3144598 J.K-1.mol-1 + // T: temperature in Kelvin + // z: valency of species (K, Na: +1) (Ca: +2) + // F: Faraday's constant 96485.33289 C.mol-1 + // Xo/Xi: ratio of out/in concentrations + + // 1e3 factor required to scale from V -> mV. + constexpr fvm_value_type RF = 1e3*constant::gas_constant/constant::faraday; + + fvm_value_type factor = RF*temperature_K/charge; + nernst_impl(Xi_.size(), factor, Xi_.data(), Xo_.data(), eX_.data()); +} + +void ion_state::init_concentration() { + init_concentration_impl( + Xi_.size(), + Xi_.data(), Xo_.data(), + weight_Xi_.data(), weight_Xo_.data(), + default_int_concentration, default_ext_concentration); +} + +void ion_state::zero_current() { + memory::fill(iX_, 0); +} + +// Shared state methods: + +shared_state::shared_state( + fvm_size_type n_cell, + const std::vector<fvm_index_type>& cv_to_cell_vec, + unsigned // alignment parameter ignored.
+): + n_cell(n_cell), + n_cv(cv_to_cell_vec.size()), + cv_to_cell(make_const_view(cv_to_cell_vec)), + time(n_cell), + time_to(n_cell), + dt_cell(n_cell), + dt_cv(n_cv), + voltage(n_cv), + current_density(n_cv), + deliverable_events(n_cell) +{} + +void shared_state::add_ion( + ion_info info, + const std::vector<fvm_index_type>& cv, + const std::vector<fvm_value_type>& iconc_norm_area, + const std::vector<fvm_value_type>& econc_norm_area) +{ + ion_data.emplace(std::piecewise_construct, + std::forward_as_tuple(info.kind), + std::forward_as_tuple(info, cv, iconc_norm_area, econc_norm_area, 1u)); +} + +void shared_state::reset(fvm_value_type initial_voltage, fvm_value_type temperature_K) { + memory::fill(voltage, initial_voltage); + memory::fill(current_density, 0); + memory::fill(time, 0); + memory::fill(time_to, 0); + + for (auto& i: ion_data) { + i.second.reset(temperature_K); + } +} + +void shared_state::zero_currents() { + memory::fill(current_density, 0); + for (auto& i: ion_data) { + i.second.zero_current(); + } +} + +void shared_state::ions_init_concentration() { + for (auto& i: ion_data) { + i.second.init_concentration(); + } +} + +void shared_state::ions_nernst_reversal_potential(fvm_value_type temperature_K) { + for (auto& i: ion_data) { + i.second.nernst(temperature_K); + } +} + +void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) { + update_time_to_impl(n_cell, time_to.data(), time.data(), dt_step, tmax); +} + +void shared_state::set_dt() { + set_dt_impl(n_cell, n_cv, dt_cell.data(), dt_cv.data(), time_to.data(), time.data(), cv_to_cell.data()); // dt_cell = time_to - time, then gathered per CV into dt_cv +} + +std::pair<fvm_value_type, fvm_value_type> shared_state::time_bounds() const { + return minmax_value_impl(n_cell, time.data()); +} + +std::pair<fvm_value_type, fvm_value_type> shared_state::voltage_bounds() const { + return minmax_value_impl(n_cv, voltage.data()); +} + +void shared_state::take_samples(const sample_event_stream::state& s, array& sample_time, array& sample_value) { +
take_samples_impl(s, time.data(), sample_time.data(), sample_value.data()); +} + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/shared_state.cu b/src/backends/gpu/shared_state.cu new file mode 100644 index 0000000000000000000000000000000000000000..995eeb7678f0684d7b3bac23626ca5fb632ecc54 --- /dev/null +++ b/src/backends/gpu/shared_state.cu @@ -0,0 +1,141 @@ +// CUDA kernels and wrappers for shared state methods. + +#include <cstdint> + +#include <backends/event.hpp> +#include <backends/multi_event_stream_state.hpp> + +#include "cuda_common.hpp" + +namespace arb { +namespace gpu { + +namespace kernel { + +template <typename T> +__global__ +void nernst_impl(unsigned n, T factor, const T* Xo, const T* Xi, T* eX) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + + if (i<n) { + eX[i] = factor*std::log(Xo[i]/Xi[i]); + } +} + +template <typename T> +__global__ +void init_concentration_impl(unsigned n, T* Xi, T* Xo, const T* weight_Xi, const T* weight_Xo, T c_int, T c_ext) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + + if (i<n) { + Xi[i] = c_int*weight_Xi[i]; + Xo[i] = c_ext*weight_Xo[i]; + } +} + +template <typename T> +__global__ void update_time_to_impl(unsigned n, T* time_to, const T* time, T dt, T tmax) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + if (i<n) { + auto t = time[i]+dt; + time_to[i] = t<tmax? 
t: tmax; + } +} + +// Vector minus: x = y - z +template <typename T> +__global__ void vec_minus(unsigned n, T* x, const T* y, const T* z) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + if (i<n) { + x[i] = y[i]-z[i]; + } +} + +// Vector gather: x[i] = y[index[i]] +template <typename T, typename I> +__global__ void gather(unsigned n, T* x, const T* y, const I* index) { + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + if (i<n) { + x[i] = y[index[i]]; + } +} + +__global__ void take_samples_impl( + multi_event_stream_state<raw_probe_info> s, + const fvm_value_type* time, fvm_value_type* sample_time, fvm_value_type* sample_value) +{ + unsigned i = threadIdx.x+blockIdx.x*blockDim.x; + if (i<s.n) { + auto begin = s.ev_data+s.begin_offset[i]; + auto end = s.ev_data+s.end_offset[i]; + for (auto p = begin; p!=end; ++p) { + sample_time[p->offset] = time[i]; + sample_value[p->offset] = *p->handle; + } + } +} + +} // namespace kernel + +using impl::block_count; + +void nernst_impl( + std::size_t n, fvm_value_type factor, + const fvm_value_type* Xi, const fvm_value_type* Xo, fvm_value_type* eX) +{ + if (!n) return; + + constexpr int block_dim = 128; + int nblock = block_count(n, block_dim); + kernel::nernst_impl<<<nblock, block_dim>>>(n, factor, Xo, Xi, eX); // host entry point takes (Xi, Xo) as declared in shared_state.cpp; the kernel takes (Xo, Xi) +} + +void init_concentration_impl( + std::size_t n, fvm_value_type* Xi, fvm_value_type* Xo, const fvm_value_type* weight_Xi, + const fvm_value_type* weight_Xo, fvm_value_type c_int, fvm_value_type c_ext) +{ + if (!n) return; + + constexpr int block_dim = 128; + int nblock = block_count(n, block_dim); + kernel::init_concentration_impl<<<nblock, block_dim>>>(n, Xi, Xo, weight_Xi, weight_Xo, c_int, c_ext); +} + +void update_time_to_impl( + std::size_t n, fvm_value_type* time_to, const fvm_value_type* time, + fvm_value_type dt, fvm_value_type tmax) +{ + if (!n) return; + + constexpr int block_dim = 128; + const int nblock = block_count(n, block_dim); + kernel::update_time_to_impl<<<nblock, block_dim>>>(n, time_to,
time, dt, tmax); +} + +void set_dt_impl( + fvm_size_type ncell, fvm_size_type ncomp, fvm_value_type* dt_cell, fvm_value_type* dt_comp, + const fvm_value_type* time_to, const fvm_value_type* time, const fvm_index_type* cv_to_cell) +{ + if (!ncell || !ncomp) return; + + constexpr int block_dim = 128; + int nblock = block_count(ncell, block_dim); + kernel::vec_minus<<<nblock, block_dim>>>(ncell, dt_cell, time_to, time); + + nblock = block_count(ncomp, block_dim); + kernel::gather<<<nblock, block_dim>>>(ncomp, dt_comp, dt_cell, cv_to_cell); +} + +void take_samples_impl( + const multi_event_stream_state<raw_probe_info>& s, + const fvm_value_type* time, fvm_value_type* sample_time, fvm_value_type* sample_value) +{ + if (!s.n_streams()) return; + + constexpr int block_dim = 128; + const int nblock = block_count(s.n_streams(), block_dim); + kernel::take_samples_impl<<<nblock, block_dim>>>(s, time, sample_time, sample_value); +} + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/shared_state.hpp b/src/backends/gpu/shared_state.hpp new file mode 100644 index 0000000000000000000000000000000000000000..642dd1eb9755e7c7fe29b0c028712f391bca1657 --- /dev/null +++ b/src/backends/gpu/shared_state.hpp @@ -0,0 +1,128 @@ +#pragma once + +#include <iosfwd> +#include <unordered_map> +#include <utility> +#include <vector> + +#include <util/enumhash.hpp> +#include <backends/fvm_types.hpp> +#include <backends/gpu/gpu_store_types.hpp> + +namespace arb { +namespace gpu { + +/* + * Ion state fields correspond to NMODL ion variables, where X + * is replaced with the name of the ion. E.g. for calcium 'ca': + * + * Field NMODL variable Meaning + * ------------------------------------------------------- + * iX_ ica calcium ion current density + * eX_ eca calcium ion channel reversal potential + * Xi_ cai internal calcium concentration + * Xo_ cao external calcium concentration + */ + +struct ion_state { + iarray node_index_; // Instance to CV map. 
+ array iX_; // (nA) current + array eX_; // (mV) reversal potential + array Xi_; // (mM) internal concentration + array Xo_; // (mM) external concentration + array weight_Xi_; // (1) concentration weight internal + array weight_Xo_; // (1) concentration weight external + + int charge; // charge of ionic species + fvm_value_type default_int_concentration; // (mM) default internal concentration + fvm_value_type default_ext_concentration; // (mM) default external concentration + + ion_state() = default; + + ion_state( + ion_info info, + const std::vector<fvm_index_type>& cv, + const std::vector<fvm_value_type>& iconc_norm_area, + const std::vector<fvm_value_type>& econc_norm_area, + unsigned align + ); + + // Calculate the reversal potential eX (mV) using Nernst equation + void nernst(fvm_value_type temperature_K); + + // Set ion concentrations to weighted proportion of default concentrations. + void init_concentration(); + + // Set ionic current density to zero. + void zero_current(); + + void reset(fvm_value_type temperature_K) { + zero_current(); + init_concentration(); + nernst(temperature_K); + } +}; + +struct shared_state { + fvm_size_type n_cell = 0; // Number of distinct cells (integration domains). + fvm_size_type n_cv = 0; // Total number of CVs. + + iarray cv_to_cell; // Maps CV index to cell index. + array time; // Maps cell index to integration start time [ms]. + array time_to; // Maps cell index to integration stop time [ms]. + array dt_cell; // Maps cell index to (stop time) - (start time) [ms]. + array dt_cv; // Maps CV index to dt [ms]. + array voltage; // Maps CV index to membrane voltage [mV]. + array current_density; // Maps CV index to current density [A/m²]. 
+ + std::unordered_map<ionKind, ion_state, util::enum_hash> ion_data; + + deliverable_event_stream deliverable_events; + + shared_state() = default; + + shared_state( + fvm_size_type n_cell, + const std::vector<fvm_index_type>& cv_to_cell_vec, + unsigned align + ); + + void add_ion( + ion_info info, + const std::vector<fvm_index_type>& cv, + const std::vector<fvm_value_type>& iconc_norm_area, + const std::vector<fvm_value_type>& econc_norm_area); + + void zero_currents(); + + void ions_init_concentration(); + + void ions_nernst_reversal_potential(fvm_value_type temperature_K); + + // Set time_to to earliest of time+dt_step and tmax. + void update_time_to(fvm_value_type dt_step, fvm_value_type tmax); + + // Set the per-cell and per-compartment dt from time_to - time. + void set_dt(); + + // Return minimum and maximum time value [ms] across cells. + std::pair<fvm_value_type, fvm_value_type> time_bounds() const; + + // Return minimum and maximum voltage value [mV] across cells. + // (Used for solution bounds checking.) + std::pair<fvm_value_type, fvm_value_type> voltage_bounds() const; + + // Take samples according to marked events in a sample_event_stream. 
+ void take_samples( + const sample_event_stream::state& s, + array& sample_time, + array& sample_value); + + void reset(fvm_value_type initial_voltage, fvm_value_type temperature_K); +}; + +// For debugging only: +std::ostream& operator<<(std::ostream& o, const shared_state& s); + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/stack.hpp b/src/backends/gpu/stack.hpp index 3c901f2ef0ae924d12da8b0d4de5815dc50764b4..2e19d329154dc5420240dd58f4d075d42760f842 100644 --- a/src/backends/gpu/stack.hpp +++ b/src/backends/gpu/stack.hpp @@ -4,7 +4,7 @@ #include <backends/gpu/managed_ptr.hpp> #include <memory/allocator.hpp> -#include "stack_common.hpp" +#include "stack_storage.hpp" namespace arb { namespace gpu { diff --git a/src/backends/gpu/kernels/stack.hpp b/src/backends/gpu/stack_cu.hpp similarity index 96% rename from src/backends/gpu/kernels/stack.hpp rename to src/backends/gpu/stack_cu.hpp index b20f18c3b2e05fb12e21819607276399e71ff011..743835cd6a97c689bc594f3d61614320d9c6e089 100644 --- a/src/backends/gpu/kernels/stack.hpp +++ b/src/backends/gpu/stack_cu.hpp @@ -1,6 +1,6 @@ #pragma once -#include "../stack_common.hpp" +#include "stack_storage.hpp" namespace arb { namespace gpu { diff --git a/src/backends/gpu/stack_common.hpp b/src/backends/gpu/stack_storage.hpp similarity index 72% rename from src/backends/gpu/stack_common.hpp rename to src/backends/gpu/stack_storage.hpp index 1d58646f6c68af4011a86720e3e0a15dcd1b1ba1..fa586c6ce29b0f94c636072b0c1d58f97fbc361d 100644 --- a/src/backends/gpu/stack_common.hpp +++ b/src/backends/gpu/stack_storage.hpp @@ -5,16 +5,6 @@ namespace arb { namespace gpu { -// stores a single crossing event -struct threshold_crossing { - fvm_size_type index; // index of variable - fvm_value_type time; // time of crossing - - friend bool operator==(threshold_crossing l, threshold_crossing r) { - return l.index==r.index && l.time==r.time; - } -}; - // Concrete storage of gpu stack datatype. 
// The stack datatype resides in host memory, and holds a pointer to the // stack_storage in managed memory, which can be accessed by both host and diff --git a/src/backends/gpu/stim_current.hpp b/src/backends/gpu/stim_current.hpp deleted file mode 100644 index 04bebdb4958b4693d14df20d52dc165ef46aa58f..0000000000000000000000000000000000000000 --- a/src/backends/gpu/stim_current.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include <backends/fvm_types.hpp> - -namespace arb{ -namespace gpu { - -void stim_current( - const fvm_value_type* delay, const fvm_value_type* duration, - const fvm_value_type* amplitude, const fvm_value_type* weights, - const fvm_size_type* node_index, int n, - const fvm_size_type* cell_index, const fvm_value_type* time, - fvm_value_type* current); - -} // namespace gpu -} // namespace arb diff --git a/src/backends/gpu/stimulus.cpp b/src/backends/gpu/stimulus.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b08561e568b1af6ff66dda93200b43c879be9c2f --- /dev/null +++ b/src/backends/gpu/stimulus.cpp @@ -0,0 +1,66 @@ +#include <cmath> + +#include <backends/builtin_mech_proto.hpp> +#include <backends/fvm_types.hpp> +#include <backends/gpu/mechanism.hpp> +#include <backends/gpu/mechanism_ppack_base.hpp> + +#include "stimulus.hpp" + +namespace arb { +namespace gpu { + +class stimulus: public arb::gpu::mechanism { +public: + const mechanism_fingerprint& fingerprint() const override { + static mechanism_fingerprint hash = "##builtin_stimulus"; + return hash; + } + std::string internal_name() const override { return "_builtin_stimulus"; } + mechanismKind kind() const override { return ::arb::mechanismKind::point; } + mechanism_ptr clone() const override { return mechanism_ptr(new stimulus()); } + + void nrn_init() override {} + void nrn_state() override {} + void nrn_current() override { + stimulus_current_impl(size(), pp_); + } + + void write_ions() override {} + void deliver_events(deliverable_event_stream::state events) 
override {} + + mechanism_ppack_base* ppack_ptr() override { + return &pp_; + } + +protected: + std::size_t object_sizeof() const override { return sizeof(*this); } + + mechanism_field_table field_table() override { + return { + {"delay", &pp_.delay}, + {"duration", &pp_.duration}, + {"amplitude", &pp_.amplitude} + }; + } + + mechanism_field_default_table field_default_table() override { + return { + {"delay", 0}, + {"duration", 0}, + {"amplitude", 0} + }; + } + +private: + stimulus_pp pp_; +}; + +} // namespace gpu + +template <> +concrete_mech_ptr<gpu::backend> make_builtin_stimulus() { + return concrete_mech_ptr<gpu::backend>(new arb::gpu::stimulus()); +} + +} // namespace arb diff --git a/src/backends/gpu/stimulus.cu b/src/backends/gpu/stimulus.cu new file mode 100644 index 0000000000000000000000000000000000000000..a07f93f2a882184eab68731052b60349e07a87e8 --- /dev/null +++ b/src/backends/gpu/stimulus.cu @@ -0,0 +1,34 @@ +#include <backends/fvm_types.hpp> + +#include "cuda_atomic.hpp" +#include "cuda_common.hpp" +#include "stimulus.hpp" + +namespace arb { +namespace gpu { + +namespace kernel { + __global__ + void stimulus_current_impl(int n, stimulus_pp pp) { + auto i = threadIdx.x + blockDim.x*blockIdx.x; + if (i<n) { + auto t = pp.vec_t_[pp.vec_ci_[i]]; + if (t>=pp.delay[i] && t<pp.delay[i]+pp.duration[i]) { + // use subtraction because the electrode currents are specified + // in terms of current into the compartment + cuda_atomic_add(pp.vec_i_+pp.node_index_[i], -pp.weight_[i]*pp.amplitude[i]); + } + } + } +} // namespace kernel + + +void stimulus_current_impl(int n, const stimulus_pp& pp) { + constexpr unsigned block_dim = 128; + const unsigned grid_dim = impl::block_count(n, block_dim); + + kernel::stimulus_current_impl<<<grid_dim, block_dim>>>(n, pp); +} + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/stimulus.hpp b/src/backends/gpu/stimulus.hpp index 
72a3a3e0c481eebeaed622bab81c5dea16adb2c6..b5d41dada90543e967f14750f166949f0fa6c33f 100644 --- a/src/backends/gpu/stimulus.hpp +++ b/src/backends/gpu/stimulus.hpp @@ -1,109 +1,17 @@ #pragma once -#include <cmath> -#include <limits> +#include <backends/gpu/mechanism_ppack_base.hpp> -#include <mechanism.hpp> -#include <algorithms.hpp> -#include "stim_current.hpp" -#include <util/pprintf.hpp> - -namespace arb{ +namespace arb { namespace gpu { -template<class Backend> -class stimulus : public mechanism<Backend> { -public: - using base = mechanism<Backend>; - using value_type = typename base::value_type; - using size_type = typename base::size_type; - - using array = typename base::array; - using iarray = typename base::iarray; - using view = typename base::view; - using iview = typename base::iview; - using const_view = typename base::const_view; - using const_iview = typename base::const_iview; - using ion_type = typename base::ion_type; - - static constexpr size_type no_mech_id = (size_type)-1; - - stimulus(const_iview vec_ci, const_view vec_t, const_view vec_t_to, const_view vec_dt, view vec_v, view vec_i, iarray&& node_index): - base(no_mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, std::move(node_index)) - {} - - using base::size; - - std::size_t memory() const override { - return 0; - } - - void set_params() override {} - - std::string name() const override { - return "stimulus"; - } - - mechanismKind kind() const override { - return mechanismKind::point; - } - - typename base::ion_spec uses_ion(ionKind k) const override { - return {false, false, false}; - } - - void set_ion(ionKind k, ion_type& i, std::vector<size_type>const& index) override { - throw std::domain_error( - arb::util::pprintf("mechanism % does not support ion type\n", name())); - } - - void nrn_init() override {} - void nrn_state() override {} - - void net_receive(int i_, value_type weight) override { - throw std::domain_error("stimulus mechanism should never receive an event\n"); - } - - 
void set_parameters( - const std::vector<value_type>& amp, - const std::vector<value_type>& dur, - const std::vector<value_type>& del) - { - amplitude = memory::on_gpu(amp); - duration = memory::on_gpu(dur); - delay = memory::on_gpu(del); - } - - void set_weights(array&& w) override { - EXPECTS(size()==w.size()); - weights = w; - } - - void nrn_current() override { - if (amplitude.size() != size()) { - throw std::domain_error("stimulus called with mismatched parameter size\n"); - } - - // don't launch a kernel if there are no stimuli - if (!size()) return; - - stim_current(delay.data(), duration.data(), amplitude.data(), weights.data(), - node_index_.data(), size(), vec_ci_.data(), vec_t_.data(), - vec_i_.data()); - - } - - array amplitude; - array duration; - array delay; - array weights; - - using base::vec_ci_; - using base::vec_t_; - using base::vec_v_; - using base::vec_i_; - using base::node_index_; +struct stimulus_pp: mechanism_ppack_base { + fvm_value_type* delay; + fvm_value_type* duration; + fvm_value_type* amplitude; }; +void stimulus_current_impl(int n, const stimulus_pp&); + } // namespace gpu } // namespace arb diff --git a/src/backends/gpu/kernels/test_thresholds.cu b/src/backends/gpu/threshold_watcher.cu similarity index 51% rename from src/backends/gpu/kernels/test_thresholds.cu rename to src/backends/gpu/threshold_watcher.cu index ce0ef688b96e28477dc7d9737b3899d5517c2d75..12c43676791ad384d7c38fb4bf4864487db57737 100644 --- a/src/backends/gpu/kernels/test_thresholds.cu +++ b/src/backends/gpu/threshold_watcher.cu @@ -1,11 +1,21 @@ +#include <cmath> + #include <backends/fvm_types.hpp> -#include "detail.hpp" -#include "stack.hpp" +#include "cuda_common.hpp" +#include "stack_cu.hpp" namespace arb { namespace gpu { +namespace kernel { + +template <typename T> +__device__ +inline T lerp(T a, T b, T u) { + return std::fma(u, b, std::fma(-u, a, a)); +} + /// kernel used to test for threshold crossing test code. 
/// params: /// t : current time (ms) @@ -17,12 +27,12 @@ namespace gpu { /// values : values at t_prev /// thresholds : threshold values to watch for crossings __global__ -void test_thresholds_kernel( - const fvm_size_type* cv_to_cell, const fvm_value_type* t_after, const fvm_value_type* t_before, +void test_thresholds_impl( int size, + const fvm_index_type* cv_to_cell, const fvm_value_type* t_after, const fvm_value_type* t_before, stack_storage<threshold_crossing>& stack, - fvm_size_type* is_crossed, fvm_value_type* prev_values, - const fvm_size_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds) + fvm_index_type* is_crossed, fvm_value_type* prev_values, + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds) { int i = threadIdx.x + blockIdx.x*blockDim.x; @@ -42,7 +52,7 @@ void test_thresholds_kernel( // The threshold has been passed, so estimate the time using // linear interpolation auto pos = (thresh - v_prev)/(v - v_prev); - crossing_time = impl::lerp(t_before[cell], t_after[cell], pos); + crossing_time = lerp(t_before[cell], t_after[cell], pos); is_crossed[i] = 1; crossed = true; @@ -60,17 +70,39 @@ void test_thresholds_kernel( } } -void test_thresholds( - const fvm_size_type* cv_to_cell, const fvm_value_type* t_after, const fvm_value_type* t_before, +__global__ +extern void reset_crossed_impl( + int size, fvm_index_type* is_crossed, + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds) +{ + int i = threadIdx.x + blockIdx.x*blockDim.x; + if (i<size) { + is_crossed[i] = values[cv_index[i]] >= thresholds[i]; + } +} + +} // namespace kernel + +void test_thresholds_impl( int size, + const fvm_index_type* cv_to_cell, const fvm_value_type* t_after, const fvm_value_type* t_before, stack_storage<threshold_crossing>& stack, - fvm_size_type* is_crossed, fvm_value_type* prev_values, - const fvm_size_type* cv_index, const fvm_value_type* values, const 
fvm_value_type* thresholds) + fvm_index_type* is_crossed, fvm_value_type* prev_values, + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds) +{ + constexpr int block_dim = 128; + const int grid_dim = impl::block_count(size, block_dim); + kernel::test_thresholds_impl<<<grid_dim, block_dim>>>( + size, cv_to_cell, t_after, t_before, stack, is_crossed, prev_values, cv_index, values, thresholds); +} + +void reset_crossed_impl( + int size, fvm_index_type* is_crossed, + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds) { constexpr int block_dim = 128; const int grid_dim = impl::block_count(size, block_dim); - test_thresholds_kernel<<<grid_dim, block_dim>>>( - cv_to_cell, t_after, t_before, size, stack, is_crossed, prev_values, cv_index, values, thresholds); + kernel::reset_crossed_impl<<<grid_dim, block_dim>>>(size, is_crossed, cv_index, values, thresholds); } } // namespace gpu diff --git a/src/backends/gpu/threshold_watcher.hpp b/src/backends/gpu/threshold_watcher.hpp index 6315fe32571a3dec0e242f2c031c6a4136b97405..b38e850fcac37164ed9f7b12b4b04972286b8a88 100644 --- a/src/backends/gpu/threshold_watcher.hpp +++ b/src/backends/gpu/threshold_watcher.hpp @@ -4,49 +4,56 @@ #include <memory/memory.hpp> #include <util/span.hpp> -#include "managed_ptr.hpp" +#include <backends/fvm_types.hpp> + +#include <backends/gpu/gpu_store_types.hpp> +#include <backends/gpu/managed_ptr.hpp> +#include <backends/gpu/stack.hpp> + #include "stack.hpp" -#include "backends/fvm_types.hpp" -#include "kernels/test_thresholds.hpp" namespace arb { namespace gpu { -/// threshold crossing logic -/// used as part of spike detection back end -class threshold_watcher { -public: - using value_type = fvm_value_type; - using size_type = fvm_size_type; +// CUDA implementation entry point: - using array = memory::device_vector<value_type>; - using iarray = memory::device_vector<size_type>; - using const_view = typename 
array::const_view_type; - using const_iview = typename iarray::const_view_type; +void test_thresholds_impl( + int size, + const fvm_index_type* cv_to_cell, const fvm_value_type* t_after, const fvm_value_type* t_before, + stack_storage<threshold_crossing>& stack, + fvm_index_type* is_crossed, fvm_value_type* prev_values, + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds); +void reset_crossed_impl( + int size, + fvm_index_type* is_crossed, + const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds); + + +class threshold_watcher { +public: using stack_type = stack<threshold_crossing>; threshold_watcher() = default; - threshold_watcher(threshold_watcher&& other) = default; threshold_watcher& operator=(threshold_watcher&& other) = default; threshold_watcher( - const_iview vec_ci, - const_view vec_t_before, - const_view vec_t_after, - const_view values, - const std::vector<size_type>& index, - const std::vector<value_type>& thresh, - value_type t=0): - cv_to_cell_(vec_ci), - t_before_(vec_t_before), - t_after_(vec_t_after), + const fvm_index_type* cv_to_cell, + const fvm_value_type* t_before, + const fvm_value_type* t_after, + const fvm_value_type* values, + const std::vector<fvm_index_type>& cv_index, + const std::vector<fvm_value_type>& thresholds + ): + cv_to_cell_(cv_to_cell), + t_before_(t_before), + t_after_(t_after), values_(values), - cv_index_(memory::make_const_view(index)), - thresholds_(memory::make_const_view(thresh)), - prev_values_(values), - is_crossed_(size()), + cv_index_(memory::make_const_view(cv_index)), + is_crossed_(cv_index.size()), + thresholds_(memory::make_const_view(thresholds)), + v_prev_(memory::const_host_view<fvm_value_type>(values, cv_index.size())), // TODO: allocates enough space for 10 spikes per watch. // A more robust approach might be needed to avoid overflows. 
stack_(10*size()) @@ -67,24 +74,11 @@ public: /// calling, because the values are used to determine the initial state void reset() { clear_crossings(); - - // Make host-side copies of the information needed to calculate - // the initial crossed state - auto values = memory::on_host(values_); - auto thresholds = memory::on_host(thresholds_); - auto cv_index = memory::on_host(cv_index_); - - // calculate the initial crossed state in host memory - std::vector<size_type> crossed(size()); - for (auto i: util::make_span(0u, size())) { - crossed[i] = values[cv_index[i]] < thresholds[i] ? 0 : 1; - } - - // copy the initial crossed state to device memory - memory::copy(crossed, is_crossed_); + reset_crossed_impl((int)size(), is_crossed_.data(), cv_index_.data(), values_, thresholds_.data()); } - bool is_crossed(size_type i) const { + // Testing-only interface. + bool is_crossed(int i) const { return is_crossed_[i]; } @@ -101,12 +95,12 @@ public: /// crossed since current time t, and the last time the test was /// performed. void test() { - test_thresholds( - cv_to_cell_.data(), t_after_.data(), t_before_.data(), - size(), + test_thresholds_impl( + (int)size(), + cv_to_cell_, t_after_, t_before_, stack_.storage(), - is_crossed_.data(), prev_values_.data(), - cv_index_.data(), values_.data(), thresholds_.data()); + is_crossed_.data(), v_prev_.data(), + cv_index_.data(), values_, thresholds_.data()); // Check that the number of spikes has not exceeded capacity. // ATTENTION: requires cudaDeviceSynchronize to avoid simultaneous @@ -119,21 +113,21 @@ public: return cv_index_.size(); } - /// Data type used to store the crossings. - /// Provided to make type-generic calling code. 
- using crossing_list = std::vector<threshold_crossing>; - private: - const_iview cv_to_cell_; // index to cell mapping: on gpu - const_view t_before_; // times per cell corresponding to prev_values_: on gpu - const_view t_after_; // times per cell corresponding to values_: on gpu - const_view values_; // values to watch: on gpu - iarray cv_index_; // compartment indexes of values to watch: on gpu - - array thresholds_; // threshold for each watch: on gpu - array prev_values_; // values at previous sample time: on gpu - iarray is_crossed_; // bool flag for state of each watch: on gpu - + /// Non-owning pointers to gpu-side cv-to-cell map, per-cell time data, + /// and the values for to test against thresholds. + const fvm_index_type* cv_to_cell_ = nullptr; + const fvm_value_type* t_before_ = nullptr; + const fvm_value_type* t_after_ = nullptr; + const fvm_value_type* values_ = nullptr; + + // Threshold watch state, with data on gpu: + iarray cv_index_; // Compartment indexes of values to watch. + iarray is_crossed_; // Boolean flag for state of each watch. + array thresholds_; // Threshold for each watch. + array v_prev_; // Values at previous sample time. + + // Hybrid host/gpu data structure for accumulating threshold crossings. 
stack_type stack_; }; diff --git a/src/backends/gpu/time_ops.hpp b/src/backends/gpu/time_ops.hpp deleted file mode 100644 index b38760f8bb6343b6ce37adf0e08ad411274a229f..0000000000000000000000000000000000000000 --- a/src/backends/gpu/time_ops.hpp +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include <backends/fvm_types.hpp> - -namespace arb { -namespace gpu { - -void update_time_to(fvm_size_type n, - fvm_value_type* time_to, - const fvm_value_type* time, - fvm_value_type dt, - fvm_value_type tmax); - -void set_dt(fvm_size_type ncell, - fvm_size_type ncomp, - fvm_value_type* dt_cell, - fvm_value_type* dt_comp, - const fvm_value_type* time_to, - const fvm_value_type* time, - const fvm_size_type* cv_to_cell); - -} // namespace gpu -} // namespace arb diff --git a/src/backends/multi_event_stream_state.hpp b/src/backends/multi_event_stream_state.hpp index b0ef0f472c5a4e7b396d27ad92aa8c5d5b22f5e7..303d4059e5079552218fa3edf2b651c364d29e6d 100644 --- a/src/backends/multi_event_stream_state.hpp +++ b/src/backends/multi_event_stream_state.hpp @@ -1,6 +1,6 @@ #pragma once -#include <common_types.hpp> +#include <backends/fvm_types.hpp> // Pointer representation of multi-event stream marked event state, // common across CPU and GPU backends. 
@@ -11,20 +11,20 @@ template <typename EvData> struct multi_event_stream_state { using value_type = EvData; - cell_size_type n; // number of streams + fvm_size_type n; // number of streams const value_type* ev_data; // array of event data items - const cell_size_type* begin_offset; // array of offsets to beginning of marked events - const cell_size_type* end_offset; // array of offsets to end of marked events + const fvm_index_type* begin_offset; // array of offsets to beginning of marked events + const fvm_index_type* end_offset; // array of offsets to end of marked events fvm_size_type n_streams() const { return n; } - const value_type* begin_marked(fvm_size_type i) const { + const value_type* begin_marked(fvm_index_type i) const { return ev_data+begin_offset[i]; } - const value_type* end_marked(fvm_size_type i) const { + const value_type* end_marked(fvm_index_type i) const { return ev_data+end_offset[i]; } }; diff --git a/src/backends/multicore/fvm.cpp b/src/backends/multicore/fvm.cpp deleted file mode 100644 index ea471917cca568f85ef807d0011e79090fc2137f..0000000000000000000000000000000000000000 --- a/src/backends/multicore/fvm.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include "fvm.hpp" - -#include <mechanisms/multicore/hh_cpu.hpp> -#include <mechanisms/multicore/pas_cpu.hpp> -#include <mechanisms/multicore/expsyn_cpu.hpp> -#include <mechanisms/multicore/exp2syn_cpu.hpp> -#include <mechanisms/multicore/test_kin1_cpu.hpp> -#include <mechanisms/multicore/test_kinlva_cpu.hpp> -#include <mechanisms/multicore/test_ca_cpu.hpp> -#include <mechanisms/multicore/nax_cpu.hpp> -#include <mechanisms/multicore/kamt_cpu.hpp> -#include <mechanisms/multicore/kdrmt_cpu.hpp> - -namespace arb { -namespace multicore { - -std::map<std::string, backend::maker_type> -backend::mech_map_ = { - { std::string("pas"), maker<mechanism_pas> }, - { std::string("hh"), maker<mechanism_hh> }, - { std::string("expsyn"), maker<mechanism_expsyn> }, - { std::string("exp2syn"), maker<mechanism_exp2syn> }, 
- { std::string("test_kin1"), maker<mechanism_test_kin1> }, - { std::string("test_kinlva"), maker<mechanism_test_kinlva> }, - { std::string("test_ca"), maker<mechanism_test_ca> }, - { std::string("nax"), maker<mechanism_nax> }, - { std::string("kamt"), maker<mechanism_kamt> }, - { std::string("kdrmt"), maker<mechanism_kdrmt> }, -}; - -} // namespace multicore -} // namespace arb diff --git a/src/backends/multicore/fvm.hpp b/src/backends/multicore/fvm.hpp index 4a0ff6607e7bef389a1b821f8f25724df0e46de2..6dbb7a209c0228c4ebca707b8d82ae19d02e92bd 100644 --- a/src/backends/multicore/fvm.hpp +++ b/src/backends/multicore/fvm.hpp @@ -1,189 +1,61 @@ #pragma once -#include <map> #include <string> +#include <vector> #include <backends/event.hpp> -#include <backends/fvm_types.hpp> -#include <common_types.hpp> -#include <constants.hpp> -#include <event_queue.hpp> -#include <mechanism.hpp> -#include <memory/memory.hpp> -#include <memory/wrappers.hpp> -#include <util/meta.hpp> +#include <util/padded_alloc.hpp> +#include <util/range.hpp> #include <util/rangeutil.hpp> -#include <util/span.hpp> -#include "matrix_state.hpp" -#include "multi_event_stream.hpp" -#include "stimulus.hpp" -#include "threshold_watcher.hpp" +#include <backends/multicore/matrix_state.hpp> +#include <backends/multicore/multi_event_stream.hpp> +#include <backends/multicore/multicore_common.hpp> +#include <backends/multicore/shared_state.hpp> +#include <backends/multicore/threshold_watcher.hpp> namespace arb { namespace multicore { struct backend { - static bool is_supported() { - return true; - } + static bool is_supported() { return true; } + static std::string name() { return "cpu"; } - /// define the real and index types using value_type = fvm_value_type; + using index_type = fvm_index_type; using size_type = fvm_size_type; - /// define storage types - using array = memory::host_vector<value_type>; - using iarray = memory::host_vector<size_type>; - - using view = typename array::view_type; - using const_view 
= typename array::const_view_type; - - using iview = typename iarray::view_type; - using const_iview = typename iarray::const_view_type; - - using host_array = array; - using host_iarray = iarray; - - using host_view = view; - using host_iview = iview; - - using matrix_state = arb::multicore::matrix_state<value_type, size_type>; - - // backend-specific multi event streams. - using deliverable_event_stream = arb::multicore::multi_event_stream<deliverable_event>; - using sample_event_stream = arb::multicore::multi_event_stream<sample_event>; - - // - // mechanism infrastructure - // - using ion_type = ion<backend>; - using stimulus = multicore::stimulus<backend>; - - using mechanism = arb::mechanism<backend>; - using mechanism_ptr = std::unique_ptr<mechanism>; + using array = arb::multicore::array; + using iarray = arb::multicore::iarray; - static mechanism_ptr make_mechanism( - const std::string& name, - size_type mech_id, - const_iview vec_ci, - const_view vec_t, const_view vec_t_to, const_view vec_dt, - view vec_v, view vec_i, - const std::vector<value_type>& weights, - const std::vector<size_type>& node_indices) - { - if (!has_mechanism(name)) { - throw std::out_of_range("no mechanism in database : " + name); - } - - return mech_map_.find(name)->second(mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, memory::make_const_view(weights), memory::make_const_view(node_indices)); - } - - static bool has_mechanism(const std::string& name) { - return mech_map_.count(name)>0; + static util::range<const value_type*> host_view(const array& v) { + return util::range_pointer_view(v); } - static std::string name() { - return "cpu"; + static util::range<const index_type*> host_view(const iarray& v) { + return util::range_pointer_view(v); } - // dereference a probe handle - static value_type dereference(probe_handle h) { - return *h; // just a pointer! 
- } - - /// threshold crossing logic - /// used as part of spike detection back end - using threshold_watcher = - arb::multicore::threshold_watcher<value_type, size_type>; + using matrix_state = arb::multicore::matrix_state<value_type, index_type>; + using threshold_watcher = arb::multicore::threshold_watcher; + using deliverable_event_stream = arb::multicore::deliverable_event_stream; + using sample_event_stream = arb::multicore::sample_event_stream; - // perform min/max reductions on 'array' type - template <typename V> - static std::pair<V, V> minmax_value(const memory::host_vector<V>& v) { - return util::minmax_value(v); - } + using shared_state = arb::multicore::shared_state; - // perform element-wise comparison on 'array' type against `t_test`. - template <typename V> - static bool any_time_before(const memory::host_vector<V>& t, V t_test) { - return minmax_value(t).first<t_test; - } - - static void update_time_to(array& time_to, const_view time, value_type dt, value_type tmax) { - size_type ncell = util::size(time); - for (size_type i = 0; i<ncell; ++i) { - time_to[i] = std::min(time[i]+dt, tmax); - } - } - - // set the per-cell and per-compartment dt_ from time_to_ - time_. 
- static void set_dt(array& dt_cell, array& dt_comp, const_view time_to, const_view time, const_iview cv_to_cell) { - size_type ncell = util::size(dt_cell); - size_type ncomp = util::size(dt_comp); - - for (size_type j = 0; j<ncell; ++j) { - dt_cell[j] = time_to[j]-time[j]; - } - - for (size_type i = 0; i<ncomp; ++i) { - dt_comp[i] = dt_cell[cv_to_cell[i]]; - } - } - - // perform sampling as described by marked events in a sample_event_stream - static void take_samples( - const sample_event_stream::state& s, - const_view time, - array& sample_time, - array& sample_value) - { - for (size_type i = 0; i<s.n_streams(); ++i) { - auto begin = s.begin_marked(i); - auto end = s.end_marked(i); - - for (auto p = begin; p<end; ++p) { - sample_time[p->offset] = time[i]; - sample_value[p->offset] = *p->handle; - } - } - } - - // Calculate the reversal potential eX (mV) using Nernst equation - // eX = RT/zF * ln(Xo/Xi) - // R: universal gas constant 8.3144598 J.K-1.mol-1 - // T: temperature in Kelvin - // z: valency of species (K, Na: +1) (Ca: +2) - // F: Faraday's constant 96485.33289 C.mol-1 - // Xo/Xi: ratio of out/in concentrations - static void nernst(int valency, value_type temperature, const_view Xo, const_view Xi, view eX) { - // factor 1e3 to scale from V -> mV - constexpr value_type RF = 1e3*constant::gas_constant/constant::faraday; - value_type factor = RF*temperature/valency; - for (std::size_t i=0; i<Xi.size(); ++i) { - eX[i] = factor*std::log(Xo[i]/Xi[i]); - } - } - - static void init_concentration( - view Xi, view Xo, - const_view weight_Xi, const_view weight_Xo, - value_type c_int, value_type c_ext) + static threshold_watcher voltage_watcher( + const shared_state& state, + const std::vector<index_type>& cv, + const std::vector<value_type>& thresholds) { - for (std::size_t i=0u; i<Xi.size(); ++i) { - Xi[i] = c_int*weight_Xi[i]; - Xo[i] = c_ext*weight_Xo[i]; - } - } - -private: - using maker_type = mechanism_ptr (*)(value_type, const_iview, const_view, const_view, 
const_view, view, view, array&&, iarray&&); - static std::map<std::string, maker_type> mech_map_; - - template <template <typename> class Mech> - static mechanism_ptr maker(value_type mech_id, const_iview vec_ci, const_view vec_t, const_view vec_t_to, const_view vec_dt, view vec_v, view vec_i, array&& weights, iarray&& node_indices) { - return arb::make_mechanism<Mech<backend>> - (mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, std::move(weights), std::move(node_indices)); + return threshold_watcher( + state.cv_to_cell.data(), + state.time.data(), + state.time_to.data(), + state.voltage.data(), + cv, + thresholds); } }; diff --git a/src/backends/multicore/intrin.hpp b/src/backends/multicore/intrin.hpp deleted file mode 100644 index 66fc2f5dd26e0410714c68d6c719dc3a9a9e9b7b..0000000000000000000000000000000000000000 --- a/src/backends/multicore/intrin.hpp +++ /dev/null @@ -1,67 +0,0 @@ -// -// Custom transcendental intrinsics -// -// Implementation inspired by the Cephes library: -// - http://www.netlib.org/cephes/ - -#pragma once - -#include <iostream> -#include <limits> - -#include <immintrin.h> - -namespace arb { -namespace multicore { - -namespace detail { - -constexpr double exp_limit = 708; - -// P/Q polynomial coefficients for the exponential function -constexpr double P0exp = 9.99999999999999999910E-1; -constexpr double P1exp = 3.02994407707441961300E-2; -constexpr double P2exp = 1.26177193074810590878E-4; - -constexpr double Q0exp = 2.00000000000000000009E0; -constexpr double Q1exp = 2.27265548208155028766E-1; -constexpr double Q2exp = 2.52448340349684104192E-3; -constexpr double Q3exp = 3.00198505138664455042E-6; - -// P/Q polynomial coefficients for the log function -constexpr double P0log = 7.70838733755885391666E0; -constexpr double P1log = 1.79368678507819816313e1; -constexpr double P2log = 1.44989225341610930846e1; -constexpr double P3log = 4.70579119878881725854e0; -constexpr double P4log = 4.97494994976747001425e-1; -constexpr double P5log = 
1.01875663804580931796e-4; - -constexpr double Q0log = 2.31251620126765340583E1; -constexpr double Q1log = 7.11544750618563894466e1; -constexpr double Q2log = 8.29875266912776603211e1; -constexpr double Q3log = 4.52279145837532221105e1; -constexpr double Q4log = 1.12873587189167450590e1; -constexpr double ln2inv = 1.4426950408889634073599; // 1/ln(2) - -// C1 + C2 = ln(2) -constexpr double C1 = 6.93145751953125E-1; -constexpr double C2 = 1.42860682030941723212E-6; - -// C4 - C3 = ln(2) -constexpr double C3 = 2.121944400546905827679e-4; -constexpr double C4 = 0.693359375; - -constexpr uint64_t dmant_mask = ((1UL<<52) - 1) | (1UL << 63); // mantissa + sign -constexpr uint64_t dexp_mask = ((1UL<<11) - 1) << 52; -constexpr int exp_bias = 1023; -constexpr double dsqrth = 0.70710678118654752440; -} - -#include "intrin_avx2.hpp" - -#if defined(SIMD_KNL) || defined(SIMD_AVX512) -#include "intrin_avx512.hpp" -#endif - -} // end namespace multicore -} // end namespace arb diff --git a/src/backends/multicore/intrin_avx2.hpp b/src/backends/multicore/intrin_avx2.hpp deleted file mode 100644 index be8aa8b64aec6a12efa02010eb7547242301a3ea..0000000000000000000000000000000000000000 --- a/src/backends/multicore/intrin_avx2.hpp +++ /dev/null @@ -1,396 +0,0 @@ -#pragma once - -namespace detail { - // Useful constants in vector registers - const __m256d arb_m256d_zero = _mm256_set1_pd(0.0); - const __m256d arb_m256d_one = _mm256_set1_pd(1.0); - const __m256d arb_m256d_two = _mm256_set1_pd(2.0); - const __m256d arb_m256d_nan = _mm256_set1_pd(std::numeric_limits<double>::quiet_NaN()); - const __m256d arb_m256d_inf = _mm256_set1_pd(std::numeric_limits<double>::infinity()); - const __m256d arb_m256d_ninf = _mm256_set1_pd(-std::numeric_limits<double>::infinity()); -} - -inline void arb_mm256_print_pd(__m256d x, const char *name) __attribute__ ((unused)); -inline void arb_mm256_print_epi32(__m128i x, const char *name) __attribute__ ((unused)); -inline void arb_mm256_print_epi64x(__m256i x, 
const char *name) __attribute__ ((unused)); -inline __m256d arb_mm256_exp_pd(__m256d x) __attribute__ ((unused)); -inline __m256d arb_mm256_subnormal_pd(__m256d x) __attribute__ ((unused)); -inline __m256d arb_mm256_frexp_pd(__m256d x, __m128i *e) __attribute__ ((unused)); -inline __m256d arb_mm256_log_pd(__m256d x) __attribute__ ((unused)); -inline __m256d arb_mm256_abs_pd(__m256d x) __attribute__ ((unused)); -inline __m256d arb_mm256_pow_pd(__m256d x, __m256d y) __attribute__ ((unused)); -inline __m256d arb_mm256_abs_pd(__m256d x); -inline __m256d arb_mm256_min_pd(__m256d x, __m256d y); -inline __m256d arb_mm256_exprelr_pd(__m256d x); - -void arb_mm256_print_pd(__m256d x, const char *name) { - double *val = (double *) &x; - std::cout << name << " = { "; - for (size_t i = 0; i < 4; ++i) { - std::cout << val[i] << " "; - } - - std::cout << "}\n"; -} - -void arb_mm256_print_epi32(__m128i x, const char *name) { - int *val = (int *) &x; - std::cout << name << " = { "; - for (size_t i = 0; i < 4; ++i) { - std::cout << val[i] << " "; - } - - std::cout << "}\n"; -} - -void arb_mm256_print_epi64x(__m256i x, const char *name) { - uint64_t *val = (uint64_t *) &x; - std::cout << name << " = { "; - for (size_t i = 0; i < 4; ++i) { - std::cout << val[i] << " "; - } - - std::cout << "}\n"; -} - -// -// Calculates absolute value using AVX2 instructions -// -// Calculated as follows: -// abs(x) = max(x, 0-x) -// -// Other approaches that use a bitwise mask might be more efficient, but using -// max gives a simple one liner. -inline -__m256d arb_mm256_abs_pd(__m256d x) { - return _mm256_max_pd(x, _mm256_sub_pd(_mm256_set1_pd(0.), x)); -} - -// -// Calculates minimum of two values using AVX2 instructions -// -// Caluclated as follows: -// min(x,y) = x>y? 
y: x -inline -__m256d arb_mm256_min_pd(__m256d x, __m256d y) { - // substitute values in x with values from y where x>y - return _mm256_blendv_pd(x, y, _mm256_cmp_pd(x, y, 30)); // 30 -> _CMP_GT_OQ -} - -// -// Calculates exprelr value using AVX2 instructions -// -// Calculated as follows: -// exprelr(x) = x / (exp(x)-1) = x / expm1(x) -// -// TODO: currently calculates exp(x)-1 for the denominator, which will not be -// accurate for x≈0. A vectorized implementation of expm1(x) would fix this. -// An example of such an implementation is in Cephes. -inline -__m256d arb_mm256_exprelr_pd(__m256d x) { - const auto ones = _mm256_set1_pd(1); - return _mm256_blendv_pd( - _mm256_div_pd(x, _mm256_sub_pd(arb_mm256_exp_pd(x), ones)), // x / (exp(x)-1) - ones, // 1 - _mm256_cmp_pd(ones, _mm256_add_pd(x, ones), 0)); // 1+x == 1 -} - -// -// Calculates exponential using AVX2 instructions -// -// Exponential is calculated as follows: -// e^x = e^g * 2^n, -// -// where g in [-0.5, 0.5) and n is an integer. Obviously 2^n can be calculated -// fast with bit shift, whereas e^g is approximated using the following Pade' -// form: -// -// e^g = 1 + 2*g*P(g^2) / (Q(g^2)-P(g^2)) -// -// The exponents n and g are calculated using the following formulas: -// -// n = floor(x/ln(2) + 0.5) -// g = x - n*ln(2) -// -// They can be derived as follows: -// -// e^x = 2^(x/ln(2)) -// = 2^-0.5 * 2^(x/ln(2) + 0.5) -// = 2^r'-0.5 * 2^floor(x/ln(2) + 0.5) (1) -// -// Setting n = floor(x/ln(2) + 0.5), -// -// r' = x/ln(2) - n, and r' in [0, 1) -// -// Substituting r' in (1) gives us -// -// e^x = 2^(x/ln(2) - n) * 2^n, where x/ln(2) - n is now in [-0.5, 0.5) -// = e^(x-n*ln(2)) * 2^n -// = e^g * 2^n, where g = x - n*ln(2) (2) -// -// NOTE: The calculation of ln(2) in (2) is split in two operations to -// compensate for rounding errors: -// -// ln(2) = C1 + C2, where -// -// C1 = floor(2^k*ln(2))/2^k -// C2 = ln(2) - C1 -// -// We use k=32, since this is what the Cephes library does historically. 
-// Theoretically, we could use k=52 to match the IEEE-754 double accuracy, but -// the standard library seems not to do that, so we are getting differences -// compared to std::exp() for large exponents. -// -__m256d arb_mm256_exp_pd(__m256d x) { - __m256d x_orig = x; - - __m256d px = _mm256_floor_pd( - _mm256_add_pd( - _mm256_mul_pd(_mm256_set1_pd(detail::ln2inv), x), - _mm256_set1_pd(0.5) - ) - ); - - __m128i n = _mm256_cvtpd_epi32(px); - - x = _mm256_sub_pd(x, _mm256_mul_pd(px, _mm256_set1_pd(detail::C1))); - x = _mm256_sub_pd(x, _mm256_mul_pd(px, _mm256_set1_pd(detail::C2))); - - __m256d xx = _mm256_mul_pd(x, x); - - // Compute the P and Q polynomials. - - // Polynomials are computed in factorized form in order to reduce the total - // numbers of operations: - // - // P(x) = P0 + P1*x + P2*x^2 = P0 + x*(P1 + x*P2) - // Q(x) = Q0 + Q1*x + Q2*x^2 + Q3*x^3 = Q0 + x*(Q1 + x*(Q2 + x*Q3)) - - // Compute x*P(x**2) - px = _mm256_set1_pd(detail::P2exp); - px = _mm256_mul_pd(px, xx); - px = _mm256_add_pd(px, _mm256_set1_pd(detail::P1exp)); - px = _mm256_mul_pd(px, xx); - px = _mm256_add_pd(px, _mm256_set1_pd(detail::P0exp)); - px = _mm256_mul_pd(px, x); - - - // Compute Q(x**2) - __m256d qx = _mm256_set1_pd(detail::Q3exp); - qx = _mm256_mul_pd(qx, xx); - qx = _mm256_add_pd(qx, _mm256_set1_pd(detail::Q2exp)); - qx = _mm256_mul_pd(qx, xx); - qx = _mm256_add_pd(qx, _mm256_set1_pd(detail::Q1exp)); - qx = _mm256_mul_pd(qx, xx); - qx = _mm256_add_pd(qx, _mm256_set1_pd(detail::Q0exp)); - - // Compute 1 + 2*P(x**2) / (Q(x**2)-P(x**2)) - x = _mm256_div_pd(px, _mm256_sub_pd(qx, px)); - x = _mm256_add_pd(detail::arb_m256d_one, - _mm256_mul_pd(detail::arb_m256d_two, x)); - - // Finally, compute x *= 2**n - __m256i n64 = _mm256_cvtepi32_epi64(n); - n64 = _mm256_add_epi64(n64, _mm256_set1_epi64x(1023)); - n64 = _mm256_sll_epi64(n64, _mm_set_epi64x(0, 52)); - x = _mm256_mul_pd(x, _mm256_castsi256_pd(n64)); - - // Treat exceptional cases - __m256d is_large = _mm256_cmp_pd( - x_orig, 
_mm256_set1_pd(detail::exp_limit), 30 /* _CMP_GT_OQ */ - ); - __m256d is_small = _mm256_cmp_pd( - x_orig, _mm256_set1_pd(-detail::exp_limit), 17 /* _CMP_LT_OQ */ - ); - __m256d is_nan = _mm256_cmp_pd(x_orig, x_orig, 3 /* _CMP_UNORD_Q */ ); - - x = _mm256_blendv_pd(x, detail::arb_m256d_inf, is_large); - x = _mm256_blendv_pd(x, detail::arb_m256d_zero, is_small); - x = _mm256_blendv_pd(x, detail::arb_m256d_nan, is_nan); - return x; - -} - -__m256d arb_mm256_subnormal_pd(__m256d x) { - __m256i x_raw = _mm256_castpd_si256(x); - __m256i exp_mask = _mm256_set1_epi64x(detail::dexp_mask); - __m256d x_exp = _mm256_castsi256_pd(_mm256_and_si256(x_raw, exp_mask)); - - // Subnormals have a zero exponent - return _mm256_cmp_pd(x_exp, detail::arb_m256d_zero, 0 /* _CMP_EQ_OQ */); -} - -__m256d arb_mm256_frexp_pd(__m256d x, __m128i *e) { - __m256i exp_mask = _mm256_set1_epi64x(detail::dexp_mask); - __m256i mant_mask = _mm256_set1_epi64x(detail::dmant_mask); - - __m256d x_orig = x; - - // we will work on the raw bits of x - __m256i x_raw = _mm256_castpd_si256(x); - __m256i x_exp = _mm256_and_si256(x_raw, exp_mask); - x_exp = _mm256_srli_epi64(x_exp, 52); - - // We need bias-1 since frexp returns base values in (-1, -0.5], [0.5, 1) - x_exp = _mm256_sub_epi64(x_exp, _mm256_set1_epi64x(detail::exp_bias-1)); - - // IEEE-754 floats are in 1.<mantissa> form, but frexp needs to return a - // float in (-1, -0.5], [0.5, 1). 
We convert x_ret in place by adding it - // an 2^-1 exponent, i.e., 1022 in IEEE-754 format - __m256i x_ret = _mm256_and_si256(x_raw, mant_mask); - - __m256i exp_bits = _mm256_slli_epi64(_mm256_set1_epi64x(detail::exp_bias-1), 52); - x_ret = _mm256_or_si256(x_ret, exp_bits); - x = _mm256_castsi256_pd(x_ret); - - // Treat special cases - __m256d is_zero = _mm256_cmp_pd( - x_orig, detail::arb_m256d_zero, 0 /* _CMP_EQ_OQ */ - ); - __m256d is_inf = _mm256_cmp_pd( - x_orig, detail::arb_m256d_inf, 0 /* _CMP_EQ_OQ */ - ); - __m256d is_ninf = _mm256_cmp_pd( - x_orig, detail::arb_m256d_ninf, 0 /* _CMP_EQ_OQ */ - ); - __m256d is_nan = _mm256_cmp_pd(x_orig, x_orig, 3 /* _CMP_UNORD_Q */ ); - - // Denormalized numbers have a zero exponent. Here we expect -1022 since we - // have already prepared it as a power of 2 - __m256i is_denorm = _mm256_cmpeq_epi64(x_exp, _mm256_set1_epi64x(-1022)); - - x = _mm256_blendv_pd(x, detail::arb_m256d_zero, is_zero); - x = _mm256_blendv_pd(x, detail::arb_m256d_inf, is_inf); - x = _mm256_blendv_pd(x, detail::arb_m256d_ninf, is_ninf); - x = _mm256_blendv_pd(x, detail::arb_m256d_nan, is_nan); - - // FIXME: We treat denormalized numbers as zero here - x = _mm256_blendv_pd(x, detail::arb_m256d_zero, - _mm256_castsi256_pd(is_denorm)); - x_exp = _mm256_blendv_epi8(x_exp, _mm256_set1_epi64x(0), is_denorm); - - x_exp = _mm256_blendv_epi8(x_exp, _mm256_set1_epi64x(0), - _mm256_castpd_si256(is_zero)); - - - // We need to "compress" x_exp into the first 128 bits before casting it - // safely to __m128i and return to *e - x_exp = _mm256_permutevar8x32_epi32( - x_exp, _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0) - ); - *e = _mm256_castsi256_si128(x_exp); - return x; -} - -// -// Calculates natural logarithm using AVX2 instructions -// -// ln(x) = ln(x'*2^g), x' in [0,1), g in N -// = ln(x') + g*ln(2) -// -// The logarithm in [0,1) is computed using the following Pade' form: -// -// ln(1+x) = x - 0.5*x^2 + x^3*P(x)/Q(x) -// -__m256d arb_mm256_log_pd(__m256d x) { - 
__m256d x_orig = x; - __m128i x_exp; - - // x := x', x_exp := g - x = arb_mm256_frexp_pd(x, &x_exp); - - // convert x_exp to packed double - __m256d dx_exp = _mm256_cvtepi32_pd(x_exp); - - // blending - __m256d lt_sqrth = _mm256_cmp_pd( - x, _mm256_set1_pd(detail::dsqrth), 17 /* _CMP_LT_OQ */); - - // Adjust the argument and the exponent - // | 2*x - 1; e := e -1 , if x < sqrt(2)/2 - // x := | - // | x - 1, otherwise - - // Precompute both branches - // 2*x - 1 - __m256d x2m1 = _mm256_sub_pd(_mm256_add_pd(x, x), detail::arb_m256d_one); - - // x - 1 - __m256d xm1 = _mm256_sub_pd(x, detail::arb_m256d_one); - - // dx_exp - 1 - __m256d dx_exp_m1 = _mm256_sub_pd(dx_exp, detail::arb_m256d_one); - - x = _mm256_blendv_pd(xm1, x2m1, lt_sqrth); - dx_exp = _mm256_blendv_pd(dx_exp, dx_exp_m1, lt_sqrth); - - // compute P(x) - __m256d px = _mm256_set1_pd(detail::P5log); - px = _mm256_mul_pd(px, x); - px = _mm256_add_pd(px, _mm256_set1_pd(detail::P4log)); - px = _mm256_mul_pd(px, x); - px = _mm256_add_pd(px, _mm256_set1_pd(detail::P3log)); - px = _mm256_mul_pd(px, x); - px = _mm256_add_pd(px, _mm256_set1_pd(detail::P2log)); - px = _mm256_mul_pd(px, x); - px = _mm256_add_pd(px, _mm256_set1_pd(detail::P1log)); - px = _mm256_mul_pd(px, x); - px = _mm256_add_pd(px, _mm256_set1_pd(detail::P0log)); - - // xx := x^2 - // px := P(x)*x^3 - __m256d xx = _mm256_mul_pd(x, x); - px = _mm256_mul_pd(px, x); - px = _mm256_mul_pd(px, xx); - - // compute Q(x) - __m256d qx = x; - qx = _mm256_add_pd(qx, _mm256_set1_pd(detail::Q4log)); - qx = _mm256_mul_pd(qx, x); - qx = _mm256_add_pd(qx, _mm256_set1_pd(detail::Q3log)); - qx = _mm256_mul_pd(qx, x); - qx = _mm256_add_pd(qx, _mm256_set1_pd(detail::Q2log)); - qx = _mm256_mul_pd(qx, x); - qx = _mm256_add_pd(qx, _mm256_set1_pd(detail::Q1log)); - qx = _mm256_mul_pd(qx, x); - qx = _mm256_add_pd(qx, _mm256_set1_pd(detail::Q0log)); - - // x^3*P(x)/Q(x) - __m256d ret = _mm256_div_pd(px, qx); - - // x^3*P(x)/Q(x) - g*ln(2) - ret = _mm256_sub_pd( - ret, 
_mm256_mul_pd(dx_exp, _mm256_set1_pd(detail::C3)) - ); - - // -.5*x^ + x^3*P(x)/Q(x) - g*ln(2) - ret = _mm256_sub_pd(ret, _mm256_mul_pd(_mm256_set1_pd(0.5), xx)); - - // x -.5*x^ + x^3*P(x)/Q(x) - g*ln(2) - ret = _mm256_add_pd(ret, x); - - // rounding error correction for ln(2) - ret = _mm256_add_pd(ret, _mm256_mul_pd(dx_exp, _mm256_set1_pd(detail::C4))); - - // Treat exceptional cases - __m256d is_inf = _mm256_cmp_pd( - x_orig, detail::arb_m256d_inf, 0 /* _CMP_EQ_OQ */); - __m256d is_zero = _mm256_cmp_pd( - x_orig, detail::arb_m256d_zero, 0 /* _CMP_EQ_OQ */); - __m256d is_neg = _mm256_cmp_pd( - x_orig, detail::arb_m256d_zero, 17 /* _CMP_LT_OQ */); - __m256d is_denorm = arb_mm256_subnormal_pd(x_orig); - - ret = _mm256_blendv_pd(ret, detail::arb_m256d_inf, is_inf); - ret = _mm256_blendv_pd(ret, detail::arb_m256d_ninf, is_zero); - - // We treat denormalized cases as zeros - ret = _mm256_blendv_pd(ret, detail::arb_m256d_ninf, is_denorm); - ret = _mm256_blendv_pd(ret, detail::arb_m256d_nan, is_neg); - return ret; -} - -// Equivalent to exp(y*log(x)) -__m256d arb_mm256_pow_pd(__m256d x, __m256d y) { - return arb_mm256_exp_pd(_mm256_mul_pd(y, arb_mm256_log_pd(x))); -} diff --git a/src/backends/multicore/intrin_avx512.hpp b/src/backends/multicore/intrin_avx512.hpp deleted file mode 100644 index 1ef6354dfd55f594f6f760b932c01b1f6445964a..0000000000000000000000000000000000000000 --- a/src/backends/multicore/intrin_avx512.hpp +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -// vector types for avx512 - -// double precision avx512 register -using vecd_avx512 = __m512d; -// 8 way mask for avx512 register (for use with double precision) -using mask8_avx512 = __mmask8; - -inline vecd_avx512 set(double x) { - return _mm512_set1_pd(x); -} - -namespace detail { - // Useful constants in vector registers - const vecd_avx512 vecd_avx512_zero = set(0.0); - const vecd_avx512 vecd_avx512_one = set(1.0); - const vecd_avx512 vecd_avx512_two = set(2.0); - const vecd_avx512 vecd_avx512_nan = 
set(std::numeric_limits<double>::quiet_NaN()); - const vecd_avx512 vecd_avx512_inf = set(std::numeric_limits<double>::infinity()); - const vecd_avx512 vecd_avx512_ninf = set(-std::numeric_limits<double>::infinity()); -} - -// -// Operations on vector registers. -// -// shorter, less verbose wrappers around intrinsics -// - -inline vecd_avx512 blend(mask8_avx512 m, vecd_avx512 x, vecd_avx512 y) { - return _mm512_mask_blend_pd(m, x, y); -} - -inline vecd_avx512 add(vecd_avx512 x, vecd_avx512 y) { - return _mm512_add_pd(x, y); -} - -inline vecd_avx512 sub(vecd_avx512 x, vecd_avx512 y) { - return _mm512_sub_pd(x, y); -} - -inline vecd_avx512 mul(vecd_avx512 x, vecd_avx512 y) { - return _mm512_mul_pd(x, y); -} - -inline vecd_avx512 div(vecd_avx512 x, vecd_avx512 y) { - return _mm512_div_pd(x, y); -} - -inline vecd_avx512 max(vecd_avx512 x, vecd_avx512 y) { - return _mm512_max_pd(x, y); -} - -inline vecd_avx512 expm1(vecd_avx512 x) { - // Assume that we are using the Intel compiler, and use the vectorized expm1 - // defined in the Intel SVML library. 
- return _mm512_expm1_pd(x); -} - -inline mask8_avx512 less(vecd_avx512 x, vecd_avx512 y) { - return _mm512_cmp_pd_mask(x, y, 0); -} - -inline mask8_avx512 greater(vecd_avx512 x, vecd_avx512 y) { - return _mm512_cmp_pd_mask(x, y, 30); -} - -inline vecd_avx512 abs(vecd_avx512 x) { - return max(x, sub(set(0.), x)); -} - -inline vecd_avx512 min(vecd_avx512 x, vecd_avx512 y) { - // substitute values in x with values from y where x>y - return blend(greater(x, y), x, y); -} - -inline vecd_avx512 exprelr(vecd_avx512 x) { - const auto ones = set(1); - return blend(less(ones, add(x, ones)), div(x, expm1(x)), ones); -} diff --git a/src/backends/multicore/matrix_state.hpp b/src/backends/multicore/matrix_state.hpp index 45004231a0d481d0442e29b2498354520610c5b2..67976ee401827e15d2b4afcc79f12840453d1612 100644 --- a/src/backends/multicore/matrix_state.hpp +++ b/src/backends/multicore/matrix_state.hpp @@ -1,9 +1,10 @@ #pragma once -#include <memory/memory.hpp> #include <util/partition.hpp> #include <util/span.hpp> +#include "multicore_common.hpp" + namespace arb { namespace multicore { @@ -11,11 +12,12 @@ template <typename T, typename I> struct matrix_state { public: using value_type = T; - using size_type = I; + using index_type = I; + + using array = padded_vector<value_type>; + using const_view = const array&; - using array = memory::host_vector<value_type>; - using const_view = typename array::const_view_type; - using iarray = memory::host_vector<size_type>; + using iarray = padded_vector<index_type>; iarray parent_index; iarray cell_cv_divs; @@ -32,21 +34,21 @@ public: matrix_state() = default; - matrix_state(const std::vector<size_type>& p, - const std::vector<size_type>& cell_cv_divs, + matrix_state(const std::vector<index_type>& p, + const std::vector<index_type>& cell_cv_divs, const std::vector<value_type>& cap, const std::vector<value_type>& cond, const std::vector<value_type>& area): - parent_index(memory::make_const_view(p)), - 
cell_cv_divs(memory::make_const_view(cell_cv_divs)), + parent_index(p.begin(), p.end()), + cell_cv_divs(cell_cv_divs.begin(), cell_cv_divs.end()), d(size(), 0), u(size(), 0), rhs(size()), - cv_capacitance(memory::make_const_view(cap)), - face_conductance(memory::make_const_view(cond)), - cv_area(memory::make_const_view(area)) + cv_capacitance(cap.begin(), cap.end()), + face_conductance(cond.begin(), cond.end()), + cv_area(area.begin(), area.end()) { EXPECTS(cap.size() == size()); EXPECTS(cond.size() == size()); - EXPECTS(cell_cv_divs.back() == size()); + EXPECTS(cell_cv_divs.back() == (index_type)size()); auto n = size(); invariant_d = array(n, 0); @@ -62,7 +64,7 @@ public: const_view solution() const { // In this back end the solution is a simple view of the rhs, which // contains the solution after the matrix_solve is performed. - return const_view(rhs); + return rhs; } @@ -73,7 +75,7 @@ public: // current density [A.m^-2] (per compartment) void assemble(const_view dt_cell, const_view voltage, const_view current) { auto cell_cv_part = util::partition_view(cell_cv_divs); - const size_type ncells = cell_cv_part.size(); + const index_type ncells = cell_cv_part.size(); // loop over submatrices for (auto m: util::make_span(0, ncells)) { diff --git a/src/backends/multicore/mechanism.cpp b/src/backends/multicore/mechanism.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73f955291f04e33151821e10ff6fccbdbba5cc05 --- /dev/null +++ b/src/backends/multicore/mechanism.cpp @@ -0,0 +1,178 @@ +#include <algorithm> +#include <cstddef> +#include <cmath> +#include <string> +#include <utility> +#include <vector> + +#include <backends/fvm_types.hpp> +#include <common_types.hpp> + +#include <math.hpp> +#include <mechanism.hpp> +#include <util/index_into.hpp> +#include <util/optional.hpp> +#include <util/maputil.hpp> +#include <util/padded_alloc.hpp> +#include <util/range.hpp> + +#include <backends/multicore/mechanism.hpp> +#include 
<backends/multicore/multicore_common.hpp> +#include <backends/multicore/fvm.hpp> + +namespace arb { +namespace multicore { + +using util::make_range; +using util::value_by_key; + +// Copy elements from source sequence into destination sequence, +// and fill the remaining elements of the destination sequence +// with the given fill value. +// +// Assumes that the iterators for these sequences are at least +// forward iterators. + +template <typename Source, typename Dest, typename Fill> +void copy_extend(const Source& source, Dest&& dest, const Fill& fill) { + using std::begin; + using std::end; + + auto dest_n = util::size(dest); + auto source_n = util::size(source); + + auto n = source_n<dest_n? source_n: dest_n; + auto tail = std::copy_n(begin(source), n, begin(dest)); + std::fill(tail, end(dest), fill); +} + +// The derived class (typically generated code from modcc) holds pointers that need +// to be set to point inside the shared state, or into the allocated parameter/variable +// data block. +// +// In the SIMD case, there may be a 'tail' of values that correspond to a partial +// SIMD value when the width is not a multiple of the SIMD data width. In this +// implementation we do not use SIMD masking to avoid tail values, but instead +// extend the vectors to a multiple of the SIMD width: sites/CVs corresponding to +// these past-the-end values are given a weight of zero, and any corresponding +// indices into shared state point to the last valid slot. 
+ +void mechanism::instantiate(fvm_size_type id, backend::shared_state& shared, const layout& pos_data) { + using util::make_range; + + util::padded_allocator<> pad(shared.alignment); + mechanism_id_ = id; + width_ = pos_data.cv.size(); + + // Assign non-owning views onto shared state: + + vec_ci_ = shared.cv_to_cell.data(); + vec_t_ = shared.time.data(); + vec_t_to_ = shared.time_to.data(); + vec_dt_ = shared.dt_cv.data(); + + vec_v_ = shared.voltage.data(); + vec_i_ = shared.current_density.data(); + + auto ion_state_tbl = ion_state_table(); + n_ion_ = ion_state_tbl.size(); + for (auto i: ion_state_tbl) { + util::optional<ion_state&> oion = value_by_key(shared.ion_data, i.first); + if (!oion) { + throw std::logic_error("mechanism holds ion with no corresponding shared state"); + } + + ion_state_view& ion_view = *i.second; + ion_view.current_density = oion->iX_.data(); + ion_view.reversal_potential = oion->eX_.data(); + ion_view.internal_concentration = oion->Xi_.data(); + ion_view.external_concentration = oion->Xo_.data(); + } + + event_stream_ptr_ = &shared.deliverable_events; + + // If there are no sites (is this ever meaningful?) there is nothing more to do. + if (width_==0) { + return; + } + + // Extend width to account for requisite SIMD padding. + width_padded_ = math::round_up(width_, shared.alignment); + + // Allocate and initialize state and parameter vectors with default values. + + auto fields = field_table(); + std::size_t n_field = fields.size(); + + // (First sub-array of data_ is used for width_, below.) + data_ = array((1+n_field)*width_padded_, NAN, pad); + for (std::size_t i = 0; i<n_field; ++i) { + // Take reference to corresponding derived (generated) mechanism value pointer member. 
+ fvm_value_type*& field_ptr = *(fields[i].second); + field_ptr = data_.data()+(i+1)*width_padded_; + + if (auto opt_value = value_by_key(field_default_table(), fields[i].first)) { + std::fill(field_ptr, field_ptr+width_padded_, *opt_value); + } + } + weight_ = data_.data(); + + // Allocate and copy local state: weight, node indices, ion indices. + // The tail comprises those elements between width_ and width_padded_: + // + // * For entries in the padded tail of weight_, set weight to zero. + // * For indices in the padded tail of node_index_, set index to last valid CV index. + // * For indices in the padded tail of ion index maps, set index to last valid ion index. + + node_index_ = iarray(width_padded_, pad); + copy_extend(pos_data.cv, node_index_, pos_data.cv.back()); + copy_extend(pos_data.weight, make_range(data_.data(), data_.data()+width_padded_), 0); + + for (auto i: ion_index_table()) { + util::optional<ion_state&> oion = value_by_key(shared.ion_data, i.first); + if (!oion) { + throw std::logic_error("mechanism holds ion with no corresponding shared state"); + } + + auto indices = util::index_into(node_index_, oion->node_index_); + + // Take reference to derived (generated) mechanism ion index member. + auto& ion_index = *i.second; + ion_index = iarray(width_padded_, pad); + copy_extend(indices, ion_index, util::back(indices)); + } + +} + +void mechanism::set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) { + if (auto opt_ptr = value_by_key(field_table(), key)) { + if (values.size()!=width_) { + throw std::logic_error("internal error: mechanism parameter size mismatch"); + } + + if (width_>0) { + // Retrieve corresponding derived (generated) mechanism value pointer member. 
+ value_type* field_ptr = *opt_ptr.value(); + util::range<value_type*> field(field_ptr, field_ptr+width_padded_); + + copy_extend(values, field, values.back()); + } + } + else { + throw std::logic_error("internal error: no such mechanism parameter"); + } +} + +void mechanism::set_global(const std::string& key, fvm_value_type value) { + if (auto opt_ptr = value_by_key(global_table(), key)) { + // Take reference to corresponding derived (generated) mechanism value member. + value_type& global = *opt_ptr.value(); + global = value; + } + else { + throw std::logic_error("internal error: no such mechanism global"); + } +} + +} // namespace multicore +} // namespace arb diff --git a/src/backends/multicore/mechanism.hpp b/src/backends/multicore/mechanism.hpp new file mode 100644 index 0000000000000000000000000000000000000000..066e287f76430f6d1898c84922f782ee751b21c0 --- /dev/null +++ b/src/backends/multicore/mechanism.hpp @@ -0,0 +1,133 @@ +#pragma once + +#include <algorithm> +#include <cstddef> +#include <cmath> +#include <string> +#include <utility> +#include <vector> + +#include <backends/fvm_types.hpp> +#include <common_types.hpp> +#include <mechanism.hpp> + +#include <backends/multicore/multicore_common.hpp> +#include <backends/multicore/fvm.hpp> + +namespace arb { +namespace multicore { + +// Base class for all generated mechanisms for multicore back-end. 
+ +class mechanism: public arb::concrete_mechanism<arb::multicore::backend> { +public: + using value_type = fvm_value_type; + using index_type = fvm_index_type; + using size_type = fvm_size_type; + +protected: + using backend = arb::multicore::backend; + using deliverable_event_stream = backend::deliverable_event_stream; + + using array = arb::multicore::array; + using iarray = arb::multicore::iarray; + + struct ion_state_view { + value_type* current_density; + value_type* reversal_potential; + value_type* internal_concentration; + value_type* external_concentration; + }; + +public: + std::size_t size() const override { + return width_; + } + + std::size_t memory() const override { + std::size_t s = object_sizeof(); + + s += sizeof(value_type) * data_.size(); + s += sizeof(size_type) * width_padded_ * (n_ion_ + 1); // node and ion indices. + return s; + } + + void instantiate(fvm_size_type id, backend::shared_state& shared, const layout& w) override; + + void deliver_events() override { + // Delegate to derived class, passing in event queue state. + deliver_events(event_stream_ptr_->marked_events()); + } + + void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; + + void set_global(const std::string& key, fvm_value_type value) override; + +protected: + size_type width_ = 0; // Instance width (number of CVs/sites) + size_type width_padded_ = 0; // Width rounded up to multiple of pad/alignment. + size_type n_ion_ = 0; + + // Non-owning views onto shared cell state, excepting ion state. + + const index_type* vec_ci_; // CV to cell index. + const value_type* vec_t_; // Cell index to cell-local time. + const value_type* vec_t_to_; // Cell index to cell-local integration step time end. + const value_type* vec_dt_; // CV to integration time step. + const value_type* vec_v_; // CV to cell membrane voltage. + value_type* vec_i_; // CV to cell membrane current density. 
+ deliverable_event_stream* event_stream_ptr_; + + // Per-mechanism index and weight data, excepting ion indices. + + iarray node_index_; + const value_type* weight_; // Points within data_ after instantiation. + + // Bulk storage for state and parameter variables. + + array data_; + + // Generated mechanism field, global and ion table lookup types. + // First component is name, second is pointer to corresponding member in + // the mechanism's parameter pack, or for field_default_table, + // the scalar value used to initialize the field. + + using global_table_entry = std::pair<const char*, value_type*>; + using mechanism_global_table = std::vector<global_table_entry>; + + using field_table_entry = std::pair<const char*, value_type**>; + using mechanism_field_table = std::vector<field_table_entry>; + + using field_default_entry = std::pair<const char*, value_type>; + using mechanism_field_default_table = std::vector<field_default_entry>; + + using ion_state_entry = std::pair<ionKind, ion_state_view*>; + using mechanism_ion_state_table = std::vector<ion_state_entry>; + + using ion_index_entry = std::pair<ionKind, iarray*>; + using mechanism_ion_index_table = std::vector<ion_index_entry>; + + // Generated mechanisms must implement the following methods, together with + // fingerprint(), clone(), kind(), nrn_init(), nrn_state(), nrn_current() + // and deliver_events() (if required) from arb::mechanism. + + // Member tables: introspection into derived mechanism fields, views etc. + // Default implementations correspond to no corresponding fields/globals/ions. + + virtual mechanism_field_table field_table() { return {}; } + virtual mechanism_field_default_table field_default_table() { return {}; } + virtual mechanism_global_table global_table() { return {}; } + virtual mechanism_ion_state_table ion_state_table() { return {}; } + virtual mechanism_ion_index_table ion_index_table() { return {}; } + + // Report raw size in bytes of mechanism object. 
+ + virtual std::size_t object_sizeof() const = 0; + + // Event delivery, given event queue state: + + virtual void deliver_events(deliverable_event_stream::state) {}; +}; + +} // namespace multicore +} // namespace arb diff --git a/src/backends/multicore/multi_event_stream.hpp b/src/backends/multicore/multi_event_stream.hpp index dc6683faf6fdb3d6d8a356c1fe6ea0dbe465520c..0b188c63b7f82e24c5fb1374dc9dc5af0653512a 100644 --- a/src/backends/multicore/multi_event_stream.hpp +++ b/src/backends/multicore/multi_event_stream.hpp @@ -6,8 +6,8 @@ #include <ostream> #include <utility> -#include <common_types.hpp> #include <backends/event.hpp> +#include <backends/fvm_types.hpp> #include <backends/multi_event_stream_state.hpp> #include <generic_event.hpp> #include <algorithms.hpp> @@ -22,7 +22,8 @@ namespace multicore { template <typename Event> class multi_event_stream { public: - using size_type = cell_size_type; + using size_type = fvm_size_type; + using index_type = fvm_index_type; using event_type = Event; using event_time_type = ::arb::event_time_type<Event>; @@ -44,9 +45,9 @@ public: ev_data_.clear(); remaining_ = 0; - util::fill(span_begin_, 0u); - util::fill(span_end_, 0u); - util::fill(mark_, 0u); + util::fill(span_begin_, 0); + util::fill(span_end_, 0); + util::fill(mark_, 0); } // Initialize event streams from a vector of events, sorted by time. @@ -72,10 +73,10 @@ public: EXPECTS(n_streams() == span_end_.size()); EXPECTS(n_streams() == mark_.size()); - size_type ev_begin_i = 0; - size_type ev_i = 0; + index_type ev_begin_i = 0; + index_type ev_i = 0; for (size_type s = 0; s<n_streams(); ++s) { - while (ev_i<n_ev && event_index(staged[ev_i])<s+1) ++ev_i; + while ((size_type)ev_i<n_ev && (size_type)(event_index(staged[ev_i]))<s+1) ++ev_i; // Within a subrange of events with the same index, events should // be sorted by time. 
@@ -196,9 +197,9 @@ public: private: std::vector<event_time_type> ev_time_; - std::vector<size_type> span_begin_; - std::vector<size_type> span_end_; - std::vector<size_type> mark_; + std::vector<index_type> span_begin_; + std::vector<index_type> span_end_; + std::vector<index_type> mark_; std::vector<event_data_type> ev_data_; size_type remaining_ = 0; }; diff --git a/src/backends/multicore/multicore_common.hpp b/src/backends/multicore/multicore_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dfb75c43c8d150191c0d7f8e49891868b50d8a87 --- /dev/null +++ b/src/backends/multicore/multicore_common.hpp @@ -0,0 +1,33 @@ +#pragma once + +// Storage classes and other common types across +// multicore back end implementations. +// +// Defines array, iarray, and specialized multi-event stream classes. + +#include <utility> +#include <vector> + +#include <backends/event.hpp> +#include <backends/fvm_types.hpp> +#include <math.hpp> +#include <simd/simd.hpp> +#include <util/padded_alloc.hpp> + +#include "multi_event_stream.hpp" + +namespace arb { +namespace multicore { + +template <typename V> +using padded_vector = std::vector<V, util::padded_allocator<V>>; + +using array = padded_vector<fvm_value_type>; +using iarray = padded_vector<fvm_index_type>; + +using deliverable_event_stream = arb::multicore::multi_event_stream<deliverable_event>; +using sample_event_stream = arb::multicore::multi_event_stream<sample_event>; + +} // namespace multicore +} // namespace arb + diff --git a/src/backends/multicore/shared_state.cpp b/src/backends/multicore/shared_state.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a15886923cf94fc81892e0bf4e9f726b0ca2acc0 --- /dev/null +++ b/src/backends/multicore/shared_state.cpp @@ -0,0 +1,258 @@ +#include <cmath> +#include <iostream> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include <backends/event.hpp> +#include <backends/fvm_types.hpp> +#include 
<common_types.hpp> +#include <constants.hpp> +#include <ion.hpp> +#include <math.hpp> +#include <simd/simd.hpp> +#include <util/padded_alloc.hpp> +#include <util/rangeutil.hpp> + +#include <util/debug.hpp> + +#include "multi_event_stream.hpp" +#include "multicore_common.hpp" +#include "shared_state.hpp" + +namespace arb { +namespace multicore { + +constexpr unsigned simd_width = simd::simd_abi::native_width<fvm_value_type>::value; +using simd_value_type = simd::simd<fvm_value_type, simd_width>; +using simd_index_type = simd::simd<fvm_index_type, simd_width>; + +// Pick alignment compatible with native SIMD width for explicitly +// vectorized operations below. +// +// TODO: Is SIMD use here a win? Test and compare; may be better to leave +// these up to the compiler to optimize/auto-vectorize. + +inline unsigned min_alignment(unsigned align) { + unsigned simd_align = sizeof(fvm_value_type)*simd_width; + return math::next_pow2(std::max(align, simd_align)); +} + +using pad = util::padded_allocator<>; + + +// ion_state methods: + +ion_state::ion_state( + ion_info info, + const std::vector<fvm_index_type>& cv, + const std::vector<fvm_value_type>& iconc_norm_area, + const std::vector<fvm_value_type>& econc_norm_area, + unsigned align +): + alignment(min_alignment(align)), + node_index_(cv.begin(), cv.end(), pad(alignment)), + iX_(cv.size(), NAN, pad(alignment)), + eX_(cv.size(), NAN, pad(alignment)), + Xi_(cv.size(), NAN, pad(alignment)), + Xo_(cv.size(), NAN, pad(alignment)), + weight_Xi_(iconc_norm_area.begin(), iconc_norm_area.end(), pad(alignment)), + weight_Xo_(econc_norm_area.begin(), econc_norm_area.end(), pad(alignment)), + charge(info.charge), + default_int_concentration(info.default_int_concentration), + default_ext_concentration(info.default_ext_concentration) +{ + EXPECTS(node_index_.size()==weight_Xi_.size()); + EXPECTS(node_index_.size()==weight_Xo_.size()); +} + +void ion_state::nernst(fvm_value_type temperature_K) { + // Nernst equation: reversal potential 
eX given by: + // + // eX = RT/zF * ln(Xo/Xi) + // + // where: + // R: universal gas constant 8.3144598 J.K-1.mol-1 + // T: temperature in Kelvin + // z: valency of species (K, Na: +1) (Ca: +2) + // F: Faraday's constant 96485.33289 C.mol-1 + // Xo/Xi: ratio of out/in concentrations + + // 1e3 factor required to scale from V -> mV. + constexpr fvm_value_type RF = 1e3*constant::gas_constant/constant::faraday; + + simd_value_type factor = RF*temperature_K/charge; + for (std::size_t i=0; i<Xi_.size(); i+=simd_width) { + simd_value_type xi(Xi_.data()+i); + simd_value_type xo(Xo_.data()+i); + + auto ex = factor*log(xo/xi); + ex.copy_to(eX_.data()+i); + } +} + +void ion_state::init_concentration() { + for (std::size_t i=0u; i<Xi_.size(); i+=simd_width) { + simd_value_type weight_xi(weight_Xi_.data()+i); + simd_value_type weight_xo(weight_Xo_.data()+i); + + auto xi = default_int_concentration*weight_xi; + xi.copy_to(Xi_.data()+i); + + auto xo = default_ext_concentration*weight_xo; + xo.copy_to(Xo_.data()+i); + } +} + +void ion_state::zero_current() { + util::fill(iX_, 0); +} + + +// shared_state methods: + +shared_state::shared_state( + fvm_size_type n_cell, + const std::vector<fvm_index_type>& cv_to_cell_vec, + unsigned align +): + alignment(min_alignment(align)), + alloc(alignment), + n_cell(n_cell), + n_cv(cv_to_cell_vec.size()), + cv_to_cell(math::round_up(n_cv, alignment), pad(alignment)), + time(n_cell, pad(alignment)), + time_to(n_cell, pad(alignment)), + dt_cell(n_cell, pad(alignment)), + dt_cv(n_cv, pad(alignment)), + voltage(n_cv, pad(alignment)), + current_density(n_cv, pad(alignment)), + deliverable_events(n_cell) +{ + // For indices in the padded tail of cv_to_cell, set index to last valid cell index. 
+ + if (n_cv>0) { + std::copy(cv_to_cell_vec.begin(), cv_to_cell_vec.end(), cv_to_cell.begin()); + std::fill(cv_to_cell.begin()+n_cv, cv_to_cell.end(), cv_to_cell_vec.back()); + } +} + +void shared_state::add_ion( + ion_info info, + const std::vector<fvm_index_type>& cv, + const std::vector<fvm_value_type>& iconc_norm_area, + const std::vector<fvm_value_type>& econc_norm_area) +{ + ion_data.emplace(std::piecewise_construct, + std::forward_as_tuple(info.kind), + std::forward_as_tuple(info, cv, iconc_norm_area, econc_norm_area, alignment)); +} + +void shared_state::reset(fvm_value_type initial_voltage, fvm_value_type temperature_K) { + util::fill(voltage, initial_voltage); + util::fill(current_density, 0); + util::fill(time, 0); + util::fill(time_to, 0); + + for (auto& i: ion_data) { + i.second.reset(temperature_K); + } +} + +void shared_state::zero_currents() { + util::fill(current_density, 0); + for (auto& i: ion_data) { + i.second.zero_current(); + } +} + +void shared_state::ions_init_concentration() { + for (auto& i: ion_data) { + i.second.init_concentration(); + } +} + +void shared_state::ions_nernst_reversal_potential(fvm_value_type temperature_K) { + for (auto& i: ion_data) { + i.second.nernst(temperature_K); + } +} + +void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) { + for (fvm_size_type i = 0; i<n_cell; i+=simd_width) { + simd_value_type t(time.data()+i); + t = min(t+dt_step, simd_value_type(tmax)); + t.copy_to(time_to.data()+i); + } +} + +void shared_state::set_dt() { + for (fvm_size_type j = 0; j<n_cell; j+=simd_width) { + simd_value_type t(time.data()+j); + simd_value_type t_to(time_to.data()+j); + + auto dt = t_to-t; + dt.copy_to(dt_cell.data()+j); + } + + for (fvm_size_type i = 0; i<n_cv; i+=simd_width) { + simd_index_type cell_idx(cv_to_cell.data()+i); + + simd_value_type dt(simd::indirect(dt_cell.data(), cell_idx)); + dt.copy_to(dt_cv.data()+i); + } +} + +std::pair<fvm_value_type, fvm_value_type> 
shared_state::time_bounds() const { + return util::minmax_value(time); +} + +std::pair<fvm_value_type, fvm_value_type> shared_state::voltage_bounds() const { + return util::minmax_value(voltage); +} + +void shared_state::take_samples( + const sample_event_stream::state& s, + array& sample_time, + array& sample_value) +{ + for (fvm_size_type i = 0; i<s.n_streams(); ++i) { + auto begin = s.begin_marked(i); + auto end = s.end_marked(i); + + // (Note: probably not worth explicitly vectorizing this.) + for (auto p = begin; p<end; ++p) { + sample_time[p->offset] = time[i]; + sample_value[p->offset] = *p->handle; + } + } +} + +// (Debug interface only.) +std::ostream& operator<<(std::ostream& out, const shared_state& s) { + using util::csv; + + out << "n_cell " << s.n_cell << "\n----\n"; + out << "n_cv " << s.n_cv << "\n----\n"; + out << "cv_to_cell:\n" << csv(s.cv_to_cell) << "\n"; + out << "time:\n" << csv(s.time) << "\n"; + out << "time_to:\n" << csv(s.time_to) << "\n"; + out << "dt:\n" << csv(s.dt_cell) << "\n"; + out << "dt_comp:\n" << csv(s.dt_cv) << "\n"; + out << "voltage:\n" << csv(s.voltage) << "\n"; + out << "current_density:\n" << csv(s.current_density) << "\n"; + for (auto& ki: s.ion_data) { + auto kn = to_string(ki.first); + auto& i = const_cast<ion_state&>(ki.second); + out << kn << ".current_density:\n" << csv(i.iX_) << "\n"; + out << kn << ".reversal_potential:\n" << csv(i.eX_) << "\n"; + out << kn << ".internal_concentration:\n" << csv(i.Xi_) << "\n"; + out << kn << ".external_concentration:\n" << csv(i.Xo_) << "\n"; + } + + return out; +} + +} // namespace multicore +} // namespace arb diff --git a/src/backends/multicore/shared_state.hpp b/src/backends/multicore/shared_state.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a95605c54a8fd185b286945454e7da0532814294 --- /dev/null +++ b/src/backends/multicore/shared_state.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include <cmath> +#include <iosfwd> +#include <string> +#include 
<unordered_map> +#include <utility> +#include <vector> + +#include <backends/event.hpp> +#include <backends/fvm_types.hpp> +#include <common_types.hpp> +#include <constants.hpp> +#include <event_queue.hpp> +#include <ion.hpp> +#include <math.hpp> +#include <simd/simd.hpp> +#include <util/enumhash.hpp> +#include <util/padded_alloc.hpp> +#include <util/rangeutil.hpp> + +#include <util/debug.hpp> + +#include "matrix_state.hpp" +#include "multi_event_stream.hpp" +#include "threshold_watcher.hpp" + +#include "multicore_common.hpp" + +namespace arb { +namespace multicore { + +/* + * Ion state fields correspond to NMODL ion variables, where X + * is replaced with the name of the ion. E.g. for calcium 'ca': + * + * Field NMODL variable Meaning + * ------------------------------------------------------- + * iX_ ica calcium ion current density + * eX_ eca calcium ion channel reversal potential + * Xi_ cai internal calcium concentration + * Xo_ cao external calcium concentration + */ + +struct ion_state { + unsigned alignment = 1; // Alignment and padding multiple. + + iarray node_index_; // Instance to CV map. 
+ array iX_; // (nA) current + array eX_; // (mV) reversal potential + array Xi_; // (mM) internal concentration + array Xo_; // (mM) external concentration + array weight_Xi_; // (1) concentration weight internal + array weight_Xo_; // (1) concentration weight external + + int charge; // charge of ionic species + fvm_value_type default_int_concentration; // (mM) default internal concentration + fvm_value_type default_ext_concentration; // (mM) default external concentration + + ion_state() = default; + + ion_state( + ion_info info, + const std::vector<fvm_index_type>& cv, + const std::vector<fvm_value_type>& iconc_norm_area, + const std::vector<fvm_value_type>& econc_norm_area, + unsigned align + ); + + // Calculate the reversal potential eX (mV) using Nernst equation + void nernst(fvm_value_type temperature_K); + + // Set ion concentrations to weighted proportion of default concentrations. + void init_concentration(); + + // Set ionic current density to zero. + void zero_current(); + + void reset(fvm_value_type temperature_K) { + zero_current(); + init_concentration(); + nernst(temperature_K); + } +}; + +struct shared_state { + unsigned alignment = 1; // Alignment and padding multiple. + util::padded_allocator<> alloc; // Allocator with corresponding alignment/padding. + + fvm_size_type n_cell = 0; // Number of distinct cells (integration domains). + fvm_size_type n_cv = 0; // Total number of CVs. + + iarray cv_to_cell; // Maps CV index to cell index. + array time; // Maps cell index to integration start time [ms]. + array time_to; // Maps cell index to integration stop time [ms]. + array dt_cell; // Maps cell index to (stop time) - (start time) [ms]. + array dt_cv; // Maps CV index to dt [ms]. + array voltage; // Maps CV index to membrane voltage [mV]. + array current_density; // Maps CV index to current density [A/m²]. 
+ + std::unordered_map<ionKind, ion_state, util::enum_hash> ion_data; + + deliverable_event_stream deliverable_events; + + shared_state() = default; + + shared_state( + fvm_size_type n_cell, + const std::vector<fvm_index_type>& cv_to_cell_vec, + unsigned align + ); + + void add_ion( + ion_info info, + const std::vector<fvm_index_type>& cv, + const std::vector<fvm_value_type>& iconc_norm_area, + const std::vector<fvm_value_type>& econc_norm_area); + + void zero_currents(); + + void ions_init_concentration(); + + void ions_nernst_reversal_potential(fvm_value_type temperature_K); + + // Set time_to to earliest of time+dt_step and tmax. + void update_time_to(fvm_value_type dt_step, fvm_value_type tmax); + + // Set the per-cell and per-compartment dt from time_to - time. + void set_dt(); + + // Return minimum and maximum time value [ms] across cells. + std::pair<fvm_value_type, fvm_value_type> time_bounds() const; + + // Return minimum and maximum voltage value [mV] across cells. + // (Used for solution bounds checking.) + std::pair<fvm_value_type, fvm_value_type> voltage_bounds() const; + + // Take samples according to marked events in a sample_event_stream. 
+ void take_samples( + const sample_event_stream::state& s, + array& sample_time, + array& sample_value); + + void reset(fvm_value_type initial_voltage, fvm_value_type temperature_K); +}; + +// For debugging only: +std::ostream& operator<<(std::ostream& o, const shared_state& s); + + +} // namespace multicore +} // namespace arb diff --git a/src/backends/multicore/stimulus.cpp b/src/backends/multicore/stimulus.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f072a11df73a966c20a3cedc58d29629447fe90 --- /dev/null +++ b/src/backends/multicore/stimulus.cpp @@ -0,0 +1,67 @@ +#include <cmath> + +#include <backends/builtin_mech_proto.hpp> +#include <backends/fvm_types.hpp> +#include <backends/multicore/mechanism.hpp> + +namespace arb { + +namespace multicore { +class stimulus: public arb::multicore::mechanism { +public: + const mechanism_fingerprint& fingerprint() const override { + static mechanism_fingerprint hash = "##builtin_stimulus"; + return hash; + } + std::string internal_name() const override { return "_builtin_stimulus"; } + mechanismKind kind() const override { return ::arb::mechanismKind::point; } + mechanism_ptr clone() const override { return mechanism_ptr(new stimulus()); } + + void nrn_init() override {} + void nrn_state() override {} + void nrn_current() override { + size_type n = size(); + for (size_type i=0; i<n; ++i) { + auto cv = node_index_[i]; + auto t = vec_t_[vec_ci_[cv]]; + + if (t>=delay[i] && t<delay[i]+duration[i]) { + // Amplitudes are given as a current into a compartment, so subtract. 
+ vec_i_[cv] -= weight_[i]*amplitude[i]; + } + } + } + void write_ions() override {} + void deliver_events(deliverable_event_stream::state events) override {} + +protected: + std::size_t object_sizeof() const override { return sizeof(*this); } + + mechanism_field_table field_table() override { + return { + {"delay", &delay}, + {"duration", &duration}, + {"amplitude", &amplitude} + }; + } + + mechanism_field_default_table field_default_table() override { + return { + {"delay", 0}, + {"duration", 0}, + {"amplitude", 0} + }; + } +private: + fvm_value_type* delay; + fvm_value_type* duration; + fvm_value_type* amplitude; +}; +} // namespace multicore + +template <> +concrete_mech_ptr<multicore::backend> make_builtin_stimulus() { + return concrete_mech_ptr<multicore::backend>(new arb::multicore::stimulus()); +} + +} // namespace arb diff --git a/src/backends/multicore/stimulus.hpp b/src/backends/multicore/stimulus.hpp deleted file mode 100644 index c7c1e9cd7ae9b7c209027a315c37dbd6b8e016ec..0000000000000000000000000000000000000000 --- a/src/backends/multicore/stimulus.hpp +++ /dev/null @@ -1,111 +0,0 @@ -#pragma once - -#include <cmath> -#include <limits> - -#include <mechanism.hpp> -#include <algorithms.hpp> -#include <util/indirect.hpp> -#include <util/pprintf.hpp> - -namespace arb{ -namespace multicore{ - -template<class Backend> -class stimulus : public mechanism<Backend> { -public: - using base = mechanism<Backend>; - using value_type = typename base::value_type; - using size_type = typename base::size_type; - - using array = typename base::array; - using iarray = typename base::iarray; - using view = typename base::view; - using iview = typename base::iview; - using const_view = typename base::const_view; - using const_iview = typename base::const_iview; - using ion_type = typename base::ion_type; - - static constexpr size_type no_mech_id = (size_type)-1; - - stimulus(const_iview vec_ci, const_view vec_t, const_view vec_t_to, const_view vec_dt, view vec_v, view vec_i, 
iarray&& node_index): - base(no_mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, std::move(node_index)) - {} - - using base::size; - - std::size_t memory() const override { - return 0; - } - - std::string name() const override { - return "stimulus"; - } - - mechanismKind kind() const override { - return mechanismKind::point; - } - - typename base::ion_spec uses_ion(ionKind k) const override { - return {false, false, false}; - } - - void set_ion(ionKind k, ion_type& i, std::vector<size_type>const& index) override { - throw std::domain_error( - arb::util::pprintf("mechanism % does not support ion type\n", name())); - } - - void nrn_init() override {} - void nrn_state() override {} - - void net_receive(int i_, value_type weight) override { - throw std::domain_error("stimulus mechanism should never receive an event\n"); - } - - void set_parameters( - const std::vector<value_type>& amp, - const std::vector<value_type>& dur, - const std::vector<value_type>& del) - { - amplitude = amp; - duration = dur; - delay = del; - } - - void set_weights(array&& w) override { - EXPECTS(size()==w.size()); - weights.resize(size()); - std::copy(w.begin(), w.end(), weights.begin()); - } - - void nrn_current() override { - if (amplitude.size() != size()) { - throw std::domain_error("stimulus called with mismatched parameter size\n"); - } - auto vec_t = util::indirect_view(util::indirect_view(vec_t_, vec_ci_), node_index_); - auto vec_i = util::indirect_view(vec_i_, node_index_); - size_type n = size(); - for (size_type i=0; i<n; ++i) { - auto t = vec_t[i]; - if (t>=delay[i] && t<delay[i]+duration[i]) { - // use subtraction because the electrod currents are specified - // in terms of current into the compartment - vec_i[i] -= weights[i]*amplitude[i]; - } - } - } - - std::vector<value_type> amplitude; - std::vector<value_type> duration; - std::vector<value_type> delay; - std::vector<value_type> weights; - - using base::vec_ci_; - using base::vec_t_; - using base::vec_v_; - using 
base::vec_i_; - using base::node_index_; -}; - -} // namespace multicore -} // namespace arb diff --git a/src/backends/multicore/threshold_watcher.hpp b/src/backends/multicore/threshold_watcher.hpp index ad97a5e88dcf79620212861da1caf161d59d4b29..5a000385c72009a908581b8654a1c28bd9116b85 100644 --- a/src/backends/multicore/threshold_watcher.hpp +++ b/src/backends/multicore/threshold_watcher.hpp @@ -1,51 +1,37 @@ #pragma once +#include <backends/fvm_types.hpp> #include <math.hpp> -#include <memory/memory.hpp> +#include <util/debug.hpp> + +#include "multicore_common.hpp" namespace arb { namespace multicore { -template <typename T, typename I> class threshold_watcher { public: - using value_type = T; - using size_type = I; - - using array = memory::host_vector<value_type>; - using const_view = typename array::const_view_type; - using iarray = memory::host_vector<size_type>; - using const_iview = typename iarray::const_view_type; - - /// stores a single crossing event - struct threshold_crossing { - size_type index; // index of variable - value_type time; // time of crossing - friend bool operator== ( - const threshold_crossing& lhs, const threshold_crossing& rhs) - { - return lhs.index==rhs.index && lhs.time==rhs.time; - } - }; - threshold_watcher() = default; threshold_watcher( - const_iview vec_ci, - const_view vec_t_before, - const_view vec_t_after, - const_view vals, - const std::vector<size_type>& indxs, - const std::vector<value_type>& thresh): - cv_to_cell_(vec_ci), - t_before_(vec_t_before), - t_after_(vec_t_after), - values_(vals), - cv_index_(memory::make_const_view(indxs)), - thresholds_(memory::make_const_view(thresh)), - v_prev_(vals) + const fvm_index_type* cv_to_cell, + const fvm_value_type* t_before, + const fvm_value_type* t_after, + const fvm_value_type* values, + const std::vector<fvm_index_type>& cv_index, + const std::vector<fvm_value_type>& thresholds + ): + cv_to_cell_(cv_to_cell), + t_before_(t_before), + t_after_(t_after), + values_(values), + 
n_cv_(cv_index.size()), + cv_index_(cv_index), + is_crossed_(n_cv_), + thresholds_(thresholds), + v_prev_(values_, values_+n_cv_) { - is_crossed_ = iarray(size()); + EXPECTS(n_cv_==thresholds.size()); reset(); } @@ -60,7 +46,7 @@ public: /// calling, because the values are used to determine the initial state void reset() { clear_crossings(); - for (auto i=0u; i<size(); ++i) { + for (fvm_size_type i = 0; i<n_cv_; ++i) { is_crossed_[i] = values_[cv_index_[i]]>=thresholds_[i]; } } @@ -73,7 +59,7 @@ public: /// Crossing events are recorded for each threshold that /// is crossed since the last call to test void test() { - for (auto i=0u; i<size(); ++i) { + for (fvm_size_type i = 0; i<n_cv_; ++i) { auto cv = cv_index_[i]; auto cell = cv_to_cell_[cv]; auto v_prev = v_prev_[i]; @@ -101,30 +87,30 @@ public: } } - bool is_crossed(size_type i) const { + bool is_crossed(fvm_size_type i) const { return is_crossed_[i]; } - /// the number of threashold values that are being monitored + /// The number of threshold values that are monitored. std::size_t size() const { - return cv_index_.size(); + return n_cv_; } - /// Data type used to store the crossings. - /// Provided to make type-generic calling code. - using crossing_list = std::vector<threshold_crossing>; - private: - const_iview cv_to_cell_; - const_view t_before_; - const_view t_after_; - const_view values_; - iarray cv_index_; - - array thresholds_; - array v_prev_; - crossing_list crossings_; - iarray is_crossed_; + /// Non-owning pointers to cv-to-cell map, per-cell time data, + /// and the values to test against thresholds. + const fvm_index_type* cv_to_cell_ = nullptr; + const fvm_value_type* t_before_ = nullptr; + const fvm_value_type* t_after_ = nullptr; + const fvm_value_type* values_ = nullptr; + + /// Threshold watcher state. 
+ fvm_size_type n_cv_ = 0; + std::vector<fvm_index_type> cv_index_; + std::vector<fvm_size_type> is_crossed_; + std::vector<fvm_value_type> thresholds_; + std::vector<fvm_value_type> v_prev_; + std::vector<threshold_crossing> crossings_; }; } // namespace multicore diff --git a/src/builtin_mechanisms.cpp b/src/builtin_mechanisms.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2187987e1fc45f98dbcdd9e1b4eb36f99a50520e --- /dev/null +++ b/src/builtin_mechanisms.cpp @@ -0,0 +1,34 @@ +#include <mechcat.hpp> +#include <backends/builtin_mech_proto.hpp> + +#include <backends/multicore/fvm.hpp> +#if ARB_HAVE_GPU +#include <backends/gpu/fvm.hpp> +#endif + +namespace arb { + +template <typename B> +concrete_mech_ptr<B> make_builtin_stimulus(); + +mechanism_catalogue build_builtin_mechanisms() { + mechanism_catalogue cat; + + cat.add("_builtin_stimulus", builtin_stimulus_info()); + + cat.register_implementation("_builtin_stimulus", make_builtin_stimulus<multicore::backend>()); + +#if ARB_HAVE_GPU + cat.register_implementation("_builtin_stimulus", make_builtin_stimulus<gpu::backend>()); +#endif + + return cat; +} + +const mechanism_catalogue& builtin_mechanisms() { + static mechanism_catalogue cat = build_builtin_mechanisms(); + return cat; +} + +} // namespace arb + diff --git a/src/builtin_mechanisms.hpp b/src/builtin_mechanisms.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a2bdd70074219849fb607da690cfa2edc54970b7 --- /dev/null +++ b/src/builtin_mechanisms.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include <mechcat.hpp> + +namespace arb { + +const mechanism_catalogue& builtin_mechanisms(); + +} // namespace arb diff --git a/src/cell.cpp b/src/cell.cpp index 12380329636119a67801372158e450b4d2dc07a5..470af08dac0200b21b0d7b1be1ec019817f5841c 100644 --- a/src/cell.cpp +++ b/src/cell.cpp @@ -1,33 +1,23 @@ #include <cell.hpp> #include <morphology.hpp> #include <tree.hpp> -#include <util/debug.hpp> +#include <util/rangeutil.hpp> 
namespace arb { -int find_compartment_index( - segment_location const& location, - compartment_model const& graph -) { - EXPECTS(unsigned(location.segment)<graph.segment_index.size()); - const auto& si = graph.segment_index; - const auto seg = location.segment; - - auto first = si[seg]; - auto n = si[seg+1] - first; - auto index = std::floor(n*location.position); - return index<n ? first+index : first+n-1; -} - -cell::cell() -{ +cell::cell() { // insert a placeholder segment for the soma segments_.push_back(make_segment<placeholder_segment>()); parents_.push_back(0); } -cell::size_type cell::num_segments() const -{ +void cell::assert_valid_segment(index_type i) const { + if (i>=num_segments()) { + throw std::out_of_range("no such segment"); + } +} + +cell::size_type cell::num_segments() const { return segments_.size(); } @@ -35,69 +25,40 @@ cell::size_type cell::num_segments() const // note: I think that we have to enforce that the soma is the first // segment that is added // -soma_segment* cell::add_soma(value_type radius, point_type center) -{ - if(has_soma()) { - throw std::domain_error( - "attempt to add a soma to a cell that already has one" - ); - } - - // add segment for the soma - if(center.is_set()) { - segments_[0] = make_segment<soma_segment>(radius, center); +soma_segment* cell::add_soma(value_type radius, point_type center) { + if (has_soma()) { + throw std::runtime_error("cell already has soma"); } - else { - segments_[0] = make_segment<soma_segment>(radius); - } - + segments_[0] = make_segment<soma_segment>(radius, center); return segments_[0]->as_soma(); } -cable_segment* cell::add_cable(cell::index_type parent, segment_ptr&& cable) -{ - // check for a valid parent id - if(cable->is_soma()) { - throw std::domain_error( - "attempt to add a soma as a segment" - ); +cable_segment* cell::add_cable(cell::index_type parent, segment_ptr&& cable) { + if (!cable->as_cable()) { + throw std::invalid_argument("segment is not a cable segment"); } - // check for 
a valid parent id - if(parent>num_segments()) { - throw std::out_of_range( - "parent index of cell segment is out of range" - ); + if (parent>num_segments()) { + throw std::out_of_range("parent index out of range"); } + segments_.push_back(std::move(cable)); parents_.push_back(parent); return segments_.back()->as_cable(); } -segment* cell::segment(index_type index) -{ - if (index>=num_segments()) { - throw std::out_of_range( - "attempt to access a segment with invalid index" - ); - } +segment* cell::segment(index_type index) { + assert_valid_segment(index); return segments_[index].get(); } -segment const* cell::segment(index_type index) const -{ - if (index>=num_segments()) { - throw std::out_of_range( - "attempt to access a segment with invalid index" - ); - } +segment const* cell::segment(index_type index) const { + assert_valid_segment(index); return segments_[index].get(); } - -bool cell::has_soma() const -{ +bool cell::has_soma() const { return !segment(0)->is_placeholder(); } @@ -109,64 +70,37 @@ const soma_segment* cell::soma() const { return has_soma()? segment(0)->as_soma(): nullptr; } -cable_segment* cell::cable(index_type index) -{ - if(index>0 && index<num_segments()) { - return segment(index)->as_cable(); - } - return nullptr; +cable_segment* cell::cable(index_type index) { + assert_valid_segment(index); + auto cable = segment(index)->as_cable(); + return cable? 
cable: throw std::runtime_error("segment is not a cable segment"); } -cell::value_type cell::volume() const -{ - return - std::accumulate( - segments_.begin(), segments_.end(), - 0., - [](double value, segment_ptr const& seg) { - return seg->volume() + value; - } - ); +cell::value_type cell::volume() const { + return util::sum_by(segments_, + [](const segment_ptr& s) { return s->volume(); }); } -cell::value_type cell::area() const -{ - return - std::accumulate( - segments_.begin(), segments_.end(), - 0., - [](double value, segment_ptr const& seg) { - return seg->area() + value; - } - ); +cell::value_type cell::area() const { + return util::sum_by(segments_, + [](const segment_ptr& s) { return s->area(); }); } -std::vector<segment_ptr> const& cell::segments() const -{ - return segments_; -} - -std::vector<cell::size_type> cell::compartment_counts() const -{ +std::vector<cell::size_type> cell::compartment_counts() const { std::vector<size_type> comp_count; comp_count.reserve(num_segments()); - for(auto const& s : segments()) { + for (const auto& s: segments()) { comp_count.push_back(s->num_compartments()); } return comp_count; } -cell::size_type cell::num_compartments() const -{ - auto n = 0u; - for(auto& s : segments_) { - n += s->num_compartments(); - } - return n; +cell::size_type cell::num_compartments() const { + return util::sum_by(segments_, + [](const segment_ptr& s) { return s->num_compartments(); }); } -compartment_model cell::model() const -{ +compartment_model cell::model() const { compartment_model m; m.tree = tree(parents_); @@ -177,30 +111,15 @@ compartment_model cell::model() const return m; } - -void cell::add_stimulus(segment_location loc, i_clamp stim) -{ - if(!(loc.segment<num_segments())) { - throw std::out_of_range( - util::pprintf( - "can't insert stimulus in segment % of a cell with % segments", - loc.segment, num_segments() - ) - ); - } +void cell::add_stimulus(segment_location loc, i_clamp stim) { + (void)segment(loc.segment); // assert 
loc.segment in range stimuli_.push_back({loc, std::move(stim)}); } -void cell::add_detector(segment_location loc, double threshold) -{ +void cell::add_detector(segment_location loc, double threshold) { spike_detectors_.push_back({loc, threshold}); } -std::vector<cell::index_type> const& cell::segment_parents() const -{ - return parents_; -} - // Rough and ready comparison of two cells. // We don't use an operator== because equality of two cells is open to // interpretation. For example, it is possible to have two viable representations @@ -210,14 +129,11 @@ std::vector<cell::index_type> const& cell::segment_parents() const // - number and type of segments // - volume and area properties of each segment // - number of compartments in each segment -bool cell_basic_equality(cell const& lhs, cell const& rhs) -{ - if (lhs.num_segments() != rhs.num_segments()) { - return false; - } - if (lhs.segment_parents() != rhs.segment_parents()) { +bool cell_basic_equality(const cell& lhs, const cell& rhs) { + if (lhs.parents_ != rhs.parents_) { return false; } + for (cell::index_type i=0; i<lhs.num_segments(); ++i) { // a quick and dirty test auto& l = *lhs.segment(i); diff --git a/src/cell.hpp b/src/cell.hpp index a4d81592e35864c580662d4b30f43bc522347d30..ff286e005306cf4babd7f47eadc6c2b178d8edf3 100644 --- a/src/cell.hpp +++ b/src/cell.hpp @@ -1,36 +1,53 @@ #pragma once -#include <map> -#include <mutex> +#include <unordered_map> #include <stdexcept> -#include <thread> #include <vector> #include <common_types.hpp> -#include <tree.hpp> +#include <constants.hpp> +#include <ion.hpp> +#include <mechcat.hpp> #include <morphology.hpp> #include <segment.hpp> -#include <stimulus.hpp> -#include <util/debug.hpp> -#include <util/pprintf.hpp> +#include <tree.hpp> #include <util/rangeutil.hpp> namespace arb { -/// wrapper around compartment layout information derived from a high level cell -/// description -struct compartment_model { - arb::tree tree; - std::vector<tree::int_type> 
parent_index; - std::vector<tree::int_type> segment_index; +// Location specification for point processes. + +struct segment_location { + segment_location(cell_lid_type s, double l): + segment(s), position(l) + { + EXPECTS(position>=0. && position<=1.); + } + + bool operator==(segment_location other) const { + return segment==other.segment && position==other.position; + } + + cell_lid_type segment; + double position; }; -int find_compartment_index( - segment_location const& location, - compartment_model const& graph -); +// Current clamp description for stimulus specification. + +struct i_clamp { + using value_type = double; + + value_type delay = 0; // [ms] + value_type duration = 0; // [ms] + value_type amplitude = 0; // [nA] + + i_clamp(value_type delay, value_type duration, value_type amplitude): + delay(delay), duration(duration), amplitude(amplitude) + {} +}; // Probe type for cell descriptions. + struct cell_probe_address { enum probe_kind { membrane_voltage, membrane_current @@ -42,16 +59,35 @@ struct cell_probe_address { // Global parameter type for cell descriptions. -struct specialized_mechanism { - std::string mech_name; // underlying mechanism +struct cell_global_properties { + const mechanism_catalogue* catalogue = &global_default_catalogue(); + + // TODO: consider making some/all of the following parameters + // cell or even segment-local. + // + // Consider also a model-level dictionary of default values that + // can be used to initialize per-cell-kind info? + // + // Defaults below chosen to match NEURON. + + // Ion species currently limited to just "ca", "na", "k". + std::unordered_map<std::string, ion_info> ion_default = { + {"ca", { ionKind::ca, 2, 5e-5, 2. 
}}, + {"na", { ionKind::na, 1, 10., 140.}}, + {"k", { ionKind::k, 1, 54.4, 2.5 }} + }; - // parameters specify global constants for the specialized mechanism - std::vector<std::pair<std::string, double>> parameters; + double temperature_K = constant::hh_squid_temp; // [K] + double init_membrane_potential_mV = -65; // [mV] }; -struct cell_global_properties { - // Mechanisms specialized by mechanism-global parameter settings. - std::map<std::string, specialized_mechanism> special_mechs; +/// Wrapper around compartment layout information derived from a high level cell +/// description. + +struct compartment_model { + arb::tree tree; + std::vector<tree::int_type> parent_index; + std::vector<tree::int_type> segment_index; }; /// high-level abstract representation of a cell and its segments @@ -64,7 +100,7 @@ public: struct synapse_instance { segment_location location; - mechanism_spec mechanism; + mechanism_desc mechanism; }; struct stimulus_instance { @@ -144,11 +180,9 @@ public: /// the total number of compartments over all segments size_type num_compartments() const; - std::vector<segment_ptr> const& segments() const; - - /// return reference to array that enumerates the index of the parent of - /// each segment - std::vector<index_type> const& segment_parents() const; + std::vector<segment_ptr> const& segments() const { + return segments_; + } /// return a vector with the compartment count for each segment in the cell std::vector<size_type> compartment_counts() const; @@ -173,7 +207,7 @@ public: ////////////////// // synapses ////////////////// - void add_synapse(segment_location loc, mechanism_spec p) + void add_synapse(segment_location loc, mechanism_desc p) { synapses_.push_back(synapse_instance{loc, std::move(p)}); } @@ -196,7 +230,16 @@ public: return spike_detectors_; } + // Checks that two cells have the same + // - number and type of segments + // - volume and area properties of each segment + // - number of compartments in each segment + // (note: just used 
for testing: move to test code?) + friend bool cell_basic_equality(const cell&, const cell&); + private: + void assert_valid_segment(index_type) const; + // storage for connections std::vector<index_type> parents_; @@ -213,12 +256,6 @@ private: std::vector<detector_instance> spike_detectors_; }; -// Checks that two cells have the same -// - number and type of segments -// - volume and area properties of each segment -// - number of compartments in each segment -bool cell_basic_equality(cell const& lhs, cell const& rhs); - // create a cable by forwarding cable construction parameters provided by the user template <typename... Args> cable_segment* cell::add_cable(cell::index_type parent, Args&&... args) diff --git a/src/cell_group_factory.cpp b/src/cell_group_factory.cpp index f1f9711fc8c0b07a0cf0c5a0ac4f860d368d57ba..fbdd4dd9df5f8a6d9f7f5b3f1623e05d856ca37a 100644 --- a/src/cell_group_factory.cpp +++ b/src/cell_group_factory.cpp @@ -4,7 +4,7 @@ #include <cell_group.hpp> #include <domain_decomposition.hpp> #include <dss_cell_group.hpp> -#include <fvm_multicell.hpp> +#include <fvm_lowered_cell.hpp> #include <lif_cell_group.hpp> #include <mc_cell_group.hpp> #include <recipe.hpp> @@ -13,18 +13,10 @@ namespace arb { -using gpu_fvm_cell = mc_cell_group<fvm::fvm_multicell<gpu::backend>>; -using mc_fvm_cell = mc_cell_group<fvm::fvm_multicell<multicore::backend>>; - cell_group_ptr cell_group_factory(const recipe& rec, const group_description& group) { switch (group.kind) { case cell_kind::cable1d_neuron: - if (group.backend == backend_kind::gpu) { - return make_cell_group<gpu_fvm_cell>(group.gids, rec); - } - else { - return make_cell_group<mc_fvm_cell>(group.gids, rec); - } + return make_cell_group<mc_cell_group>(group.gids, rec, make_fvm_lowered_cell(group.backend)); case cell_kind::regular_spike_source: return make_cell_group<rss_cell_group>(group.gids, rec); diff --git a/src/communication/gathered_vector.hpp b/src/communication/gathered_vector.hpp index 
19c51fb930ac2ad4bdb0dc4bd9a1ec969e9e6150..3f5b550528e2aa892968bd75baf697ec6b4d3f48 100644 --- a/src/communication/gathered_vector.hpp +++ b/src/communication/gathered_vector.hpp @@ -4,7 +4,7 @@ #include <numeric> #include <vector> -#include <algorithms.hpp> +#include <util/rangeutil.hpp> namespace arb { @@ -18,8 +18,8 @@ public: values_(std::move(v)), partition_(std::move(p)) { - EXPECTS(std::is_sorted(partition_.begin(), partition_.end())); - EXPECTS(std::size_t(partition_.back()) == values_.size()); + EXPECTS(util::is_sorted(partition_)); + EXPECTS(partition_.back() == values_.size()); } /// the partition of distribution diff --git a/src/constants.hpp b/src/constants.hpp index 02da70144ecd44c185d1a8bf1e3d1ffb6fd7fcda..a761f36914eb84948f5e947e7091f7331f905925 100644 --- a/src/constants.hpp +++ b/src/constants.hpp @@ -18,7 +18,7 @@ namespace constant { // Universal gas constant (R) // https://physics.nist.gov/cgi-bin/cuu/Value?r -constexpr double gas_constant = 8.3144598; // J.°K^-1.mol^-1 +constexpr double gas_constant = 8.3144598; // J.K^-1.mol^-1 // Faraday's constant (F) // https://physics.nist.gov/cgi-bin/cuu/Value?f @@ -26,7 +26,7 @@ constexpr double faraday = 96485.33289; // C.mol^-1 // Temperature used in original Hodgkin-Huxley paper // doi:10.1113/jphysiol.1952.sp004764 -constexpr double hh_squid_temp = 6.3+273.15; // °K +constexpr double hh_squid_temp = 6.3+273.15; // K } // namespace arb } // namespace arb diff --git a/src/domain_decomposition.hpp b/src/domain_decomposition.hpp index 52173d2c0643c64e4bcab036bf5d0ec3cbcfea1d..166822f20c6632f98de5b4438b2564727d32fe18 100644 --- a/src/domain_decomposition.hpp +++ b/src/domain_decomposition.hpp @@ -12,6 +12,7 @@ #include <recipe.hpp> #include <util/optional.hpp> #include <util/partition.hpp> +#include <util/range.hpp> #include <util/transform.hpp> namespace arb { @@ -37,7 +38,7 @@ struct group_description { group_description(cell_kind k, std::vector<cell_gid_type> g, backend_kind b): kind(k),
gids(std::move(g)), backend(b) { - EXPECTS(std::is_sorted(gids.begin(), gids.end())); + EXPECTS(util::is_sorted(gids)); } }; diff --git a/src/epoch.hpp b/src/epoch.hpp index b54101971a5ecc7685d4eeeb3ae2ce851326a6f9..3f9eb4f8350a60b7b76316f840bf63a1c8f8667d 100644 --- a/src/epoch.hpp +++ b/src/epoch.hpp @@ -3,6 +3,7 @@ #include <cstdint> #include <common_types.hpp> +#include <util/debug.hpp> namespace arb { diff --git a/src/event_generator.hpp b/src/event_generator.hpp index 60819ef951544a76bd7dd3d5834e6212dc65d785..393e391a8b4901034a9a981ce89030fd66786cd4 100644 --- a/src/event_generator.hpp +++ b/src/event_generator.hpp @@ -132,7 +132,7 @@ struct vector_backed_generator { events_(std::move(events)), it_(events_.begin()) { - if (!std::is_sorted(events_.begin(), events_.end())) { + if (!util::is_sorted(events_)) { util::sort(events_); } } @@ -171,7 +171,7 @@ struct seq_generator { events_(events), it_(std::begin(events_)) { - EXPECTS(std::is_sorted(events_.begin(), events_.end())); + EXPECTS(util::is_sorted(events_)); } postsynaptic_spike_event next() { diff --git a/src/fvm_layout.cpp b/src/fvm_layout.cpp new file mode 100644 index 0000000000000000000000000000000000000000..90d6509752d8b9725fe02a2ef2d8b2d79afdf20b --- /dev/null +++ b/src/fvm_layout.cpp @@ -0,0 +1,587 @@ +#include <set> +#include <stdexcept> +#include <unordered_set> +#include <vector> + +#include <fvm_layout.hpp> +#include <util/enumhash.hpp> +#include <util/maputil.hpp> +#include <util/meta.hpp> +#include <util/partition.hpp> +#include <util/rangeutil.hpp> +#include <util/transform.hpp> + +namespace arb { + +using util::count_along; +using util::make_span; +using util::subrange_view; +using util::transform_view; +using util::value_by_key; + +// Convenience routines + +template <typename ResizableContainer, typename Index> +void extend_to(ResizableContainer& c, const Index& i) { + if (util::size(c)<=i) { + c.resize(i+1); + } +} + +// Cable segment discretization +// ---------------------------- +// 
+// +// Each compartment i straddles the ith control volume on the right +// and the jth control volume on the left, where j is the parent index +// of i. +// +// Dividing the compartment into two halves, the centre face C +// corresponds to the shared face between the two control volumes, +// the surface areas in each half contribute to the surface area of +// the respective control volumes, and the volumes and lengths of +// each half are used to calculate the flux coefficient +// for the connection between the two control volumes, which +// is stored in `face_conductance[i]`. +// +// +// +------- cv j --------+------- cv i -------+ +// | | | +// v v v +// ____________________________________________ +// | ........ | ........ | | | +// | ........ L ........ C R | +// |__________|__________|__________|_________| +// ^ ^ +// | | +// +--- compartment i ---+ +// +// The first control volume of any cell corresponds to the soma +// and the first half of the first cable compartment of that cell. +// +// +// Face conductance computation +// ---------------------------- +// +// The conductance between two adjacent CVs is computed as follows, +// in terms of the two half CVs on either side of the interface, +// corresponding to the regions L–C and C–R in the diagram above. +// +// The conductance itself is approximated by the weighted harmonic mean +// of the mean linear conductivities in each half, corresponding to +// the two-point flux approximation in 1-D. +// +// Mean linear conductivities: +// +// g_1 = 1/h_1 \int_1 A(x)/R dx +// g_2 = 1/h_2 \int_2 A(x)/R dx +// +// where A(x) is the cross-sectional area, R is the bulk resistivity, +// and h is the width of the region. The integrals are taken over the +// half CVs as described above.
+// +// Equivalently, in terms of the semi-compartment volumes V_1 and V_2: +// +// g_1 = 1/R \cdot V_1/h_1 +// g_2 = 1/R \cdot V_2/h_2 +// +// Weighted harmonic mean, with h = h_1 + h_2: +// +// g = (h_1/h \cdot g_1^{-1} + h_2/h \cdot g_2^{-1})^{-1} +// = 1/R \cdot h V_1 V_2 / (h_2^2 V_1 + h_1^2 V_2) +// + +fvm_discretization fvm_discretize(const std::vector<cell>& cells) { + using value_type = fvm_value_type; + using index_type = fvm_index_type; + using size_type = fvm_size_type; + + fvm_discretization D; + + util::make_partition(D.cell_segment_bounds, + transform_view(cells, [](const cell& c) { return c.num_segments(); })); + + std::vector<index_type> cell_comp_bounds; + auto cell_comp_part = make_partition(cell_comp_bounds, + transform_view(cells, [](const cell& c) { return c.num_compartments(); })); + + D.ncell = cells.size(); + D.ncomp = cell_comp_part.bounds().second; + + D.face_conductance.assign(D.ncomp, 0.); + D.cv_area.assign(D.ncomp, 0.); + D.cv_capacitance.assign(D.ncomp, 0.); + D.parent_cv.assign(D.ncomp, index_type(-1)); + D.cv_to_cell.resize(D.ncomp); + for (auto i: make_span(0, D.ncell)) { + util::fill(subrange_view(D.cv_to_cell, cell_comp_part[i]), static_cast<index_type>(i)); + } + + std::vector<size_type> seg_comp_bounds; + for (auto i: make_span(0, D.ncell)) { + const auto& c = cells[i]; + auto cell_graph = c.model(); + auto cell_comp_ival = cell_comp_part[i]; + + auto cell_comp_base = cell_comp_ival.first; + for (auto k: make_span(cell_comp_ival)) { + D.parent_cv[k] = cell_graph.parent_index[k-cell_comp_base]+cell_comp_base; + } + + // Compartment index range for each segment in this cell. + seg_comp_bounds.clear(); + auto seg_comp_part = make_partition( + seg_comp_bounds, + transform_view(c.segments(), [](const segment_ptr& s) { return s->num_compartments(); }), + cell_comp_base); + + const auto nseg = seg_comp_part.size(); + if (nseg==0) { + throw std::invalid_argument("cannot discretrize cell with no segments"); + } + + // Handle soma (first segment and root of tree) specifically.
+ const auto soma = c.segment(0)->as_soma(); + if (!soma) { + throw std::logic_error("First segment of cell must be soma"); + } + else if (soma->num_compartments()!=1) { + throw std::logic_error("Soma must have exactly one compartment"); + } + + segment_info soma_info; + + size_type soma_cv = cell_comp_base; + value_type soma_area = math::area_sphere(soma->radius()); + + D.cv_area[soma_cv] = soma_area; // [µm²] + D.cv_capacitance[soma_cv] = soma_area*soma->cm; // [pF] + + soma_info.proximal_cv = soma_cv; + soma_info.distal_cv = soma_cv; + soma_info.distal_cv_area = soma_area; + D.segments.push_back(soma_info); + + // Other segments must all be cable segments. + for (size_type j = 1; j<nseg; ++j) { + const auto& seg_comp_ival = seg_comp_part[j]; + const auto ncomp = seg_comp_ival.second-seg_comp_ival.first; + + segment_info seg_info; + + const auto cable = c.segment(j)->as_cable(); + if (!cable) { + throw std::logic_error("Non-root segments of cell must be cable segments"); + } + auto cm = cable->cm; // [F/m²] + auto rL = cable->rL; // [Ω·cm] + + auto divs = div_compartments<div_compartment_integrator>(cable, ncomp); + + seg_info.parent_cv = D.parent_cv[seg_comp_ival.first]; + seg_info.parent_cv_area = divs(0).left.area; + + seg_info.proximal_cv = seg_comp_ival.first; + seg_info.distal_cv = seg_comp_ival.second-1; + seg_info.distal_cv_area = divs(ncomp-1).right.area; + + D.segments.push_back(seg_info); + + for (auto i: make_span(seg_comp_ival)) { + const auto& div = divs(i-seg_comp_ival.first); + auto j = D.parent_cv[i]; + + auto h1 = div.left.length; // [µm] + auto V1 = div.left.volume; // [µm³] + auto h2 = div.right.length; // [µm] + auto V2 = div.right.volume; // [µm³] + auto h = h1+h2; + + auto linear_conductivity = 1/rL*h*V1*V2/(h2*h2*V1+h1*h1*V2); // [S·cm¯¹·µm²] ≡ [10²·µS·µm] + constexpr double unit_scale = 1e2; + D.face_conductance[i] = unit_scale * linear_conductivity / h; // [µS] + + auto al = div.left.area; // [µm²] + auto ar = div.right.area; // [µm²] +
+ D.cv_area[j] += al; // [µm²] + D.cv_capacitance[j] += al * cm; // [pF] + D.cv_area[i] += ar; // [µm²] + D.cv_capacitance[i] += ar * cm; // [pF] + } + } + } + + // Number of CVs per cell is exactly number of compartments. + D.cell_cv_bounds = std::move(cell_comp_bounds); + return D; +} + + +// Build up mechanisms. +// +// Processing proceeds in the following stages: +// +// I. Collect segment mechanism info from the cell descriptions into temporary +// data structures for density mechanism, point mechanisms, and ion channels. +// +// II. Build mechanism and ion configuration in `fvm_mechanism_data`: +// IIa. Ion channel CVs. +// IIb. Density mechanism CVs, parameter values; ion channel default concentration contributions. +// IIc. Point mechanism CVs, parameter values, and targets. + +fvm_mechanism_data fvm_build_mechanism_data(const mechanism_catalogue& catalogue, const std::vector<cell>& cells, const fvm_discretization& D) { + using util::assign; + using util::sort_by; + using util::optional; + + using value_type = fvm_value_type; + using index_type = fvm_index_type; + using size_type = fvm_size_type; + + using string_set = std::unordered_set<std::string>; + using string_index_map = std::unordered_map<std::string, size_type>; + + fvm_mechanism_data mechdata; + + // I. Collect segment mechanism info from cells. + + // Temporary table for density mechanism info, mapping mechanism name to tuple of: + // 1. Vector of segment indices and mechanism parameter settings where mechanism occurs. + // 2. Set of the names of parameters that are anywhere modified. + // 3. Pointer to mechanism metadata from catalogue. + + struct density_mech_data { + std::vector<std::pair<size_type, const mechanism_desc*>> segments; // + string_set paramset; + const mechanism_info* info = nullptr; + }; + std::unordered_map<std::string, density_mech_data> density_mech_table; + + // Temporary table for point mechanism info, mapping mechanism name to tuple: + // 1.
Vector of point info: CV, index into cell group targets, parameter settings. + // 2. Set of the names of parameters that are anywhere modified. + // 3. Mechanism parameter settings. + + struct point_mech_data { + struct point_data { + size_type cv; + size_type target_index; + const mechanism_desc* desc; + }; + std::vector<point_data> points; + string_set paramset; + const mechanism_info* info = nullptr; + }; + std::unordered_map<std::string, point_mech_data> point_mech_table; + + // Built-in stimulus mechanism data is dealt with especially below. + // Record for each stimulus the CV and clamp data. + std::vector<std::pair<size_type, i_clamp>> stimuli; + + // Temporary table for presence of ion channels, mapping ionKind to _sorted_ + // collection of segment indices. + + std::unordered_map<ionKind, std::set<size_type>, util::enum_hash> ion_segments; + + auto update_paramset_and_validate = + [&catalogue] + (const mechanism_desc& desc, const mechanism_info*& info, string_set& paramset) + { + auto& name = desc.name(); + if (!info) { + if (!catalogue.has(name)) { + throw std::out_of_range("No mechanism "+name+" in mechanism catalogue"); + } + info = &catalogue[name]; + } + for (const auto& pv: desc.values()) { + if (!paramset.count(pv.first)) { + if (!info->parameters.count(pv.first)) { + throw std::out_of_range("Mechanism "+name+" has no parameter "+pv.first); + } + if (!info->parameters.at(pv.first).valid(pv.second)) { + throw std::out_of_range("Value out of range for mechanism "+name+" parameter "+pv.first); + } + paramset.insert(pv.first); + } + } + }; + + auto cell_segment_part = D.cell_segment_part(); + size_type target_id = 0; + + for (auto cell_idx: make_span(0, D.ncell)) { + auto& cell = cells[cell_idx]; + auto seg_range = cell_segment_part[cell_idx]; + + for (auto segment_idx: make_span(seg_range)) { + for (const mechanism_desc& desc: cell.segments()[segment_idx-seg_range.first]->mechanisms()) { + const auto& name = desc.name(); + + density_mech_data& entry = 
density_mech_table[name]; + update_paramset_and_validate(desc, entry.info, entry.paramset); + entry.segments.emplace_back(segment_idx, &desc); + + for (const auto& ion: entry.info->ions) { + ion_segments[ion.first].insert(segment_idx); + } + } + } + + for (const auto& cellsyn: cell.synapses()) { + const mechanism_desc& desc = cellsyn.mechanism; + size_type cv = D.segment_location_cv(cell_idx, cellsyn.location); + const auto& name = desc.name(); + + point_mech_data& entry = point_mech_table[name]; + update_paramset_and_validate(desc, entry.info, entry.paramset); + entry.points.push_back({cv, target_id++, &desc}); + + size_type segment_idx = D.cell_segment_bounds[cell_idx]+cellsyn.location.segment; + for (const auto& ion: entry.info->ions) { + ion_segments[ion.first].insert(segment_idx); + } + } + + for (const auto& stimulus: cell.stimuli()) { + size_type cv = D.segment_location_cv(cell_idx, stimulus.location); + stimuli.push_back({cv, stimulus.clamp}); + } + } + + // II. Build ion and mechanism configs. + + // Shared temporary lookup info across mechanism instances, set by build_param_data. + string_index_map param_index; + std::vector<std::string> param_name; + std::vector<value_type> param_default; + + auto build_param_data = + [&param_name, &param_index, &param_default](const string_set& paramset, const mechanism_info* info) + { + assign(param_name, paramset); + auto nparam = paramset.size(); + + assign(param_default, transform_view(param_name, + [info](const std::string& p) { return info->parameters.at(p).default_value; })); + + param_index.clear(); + for (auto i: make_span(0, nparam)) { + param_index[param_name[i]] = i; + } + return nparam; + }; + + // IIa. Ion channel CVs.
+ + for (auto& ionseg: ion_segments) { + auto& ion = mechdata.ions[ionseg.first]; + + for (size_type segment: ionseg.second) { + const segment_info& seg_info = D.segments[segment]; + + if (seg_info.has_parent()) { + index_type cv = seg_info.parent_cv; + optional<std::size_t> parent_idx = util::binary_search_index(ion.cv, cv); + if (!parent_idx) { + ion.cv.push_back(cv); + ion.iconc_norm_area.push_back(D.cv_area[cv]); + ion.econc_norm_area.push_back(D.cv_area[cv]); + } + } + + for (auto cv: make_span(seg_info.cv_range())) { + ion.cv.push_back(cv); + ion.iconc_norm_area.push_back(D.cv_area[cv]); + ion.econc_norm_area.push_back(D.cv_area[cv]); + } + } + } + + // IIb. Density mechanism CVs, parameters and ionic default concentration contributions. + + // Ameliorate area sum rounding errors by clamping normalized area contributions to [0, 1] + // and rounding values within an epsilon of 0 or 1 to that value. + auto trim = [](value_type& v) { + constexpr value_type eps = std::numeric_limits<value_type>::epsilon()*4; + v = v<eps? 0: v+eps>1? 1: v; + }; + + for (const auto& entry: density_mech_table) { + const std::string& name = entry.first; + fvm_mechanism_config& config = mechdata.mechanisms[name]; + config.kind = mechanismKind::density; + + auto nparam = build_param_data(entry.second.paramset, entry.second.info); + + // In order to properly account for partially overridden parameters in CVs + // that are shared between segments, we need to track not only the area-weighted + // sum of parameter values, but also the total area for each CV for each parameter + // that has been overridden — the remaining area demands a contribution from the + // parameter default value.
+ + std::vector<std::vector<value_type>> param_value(nparam); + std::vector<std::vector<value_type>> param_area_contrib(nparam); + + const auto& info = *entry.second.info; // TODO: C++14, use lambda capture with initializer + auto& ion_configs = mechdata.ions; // TODO: C++14 ditto + auto accumulate_mech_data = + [&param_index, &param_value, &param_area_contrib, &config, &info, &ion_configs] + (size_type index, index_type cv, value_type area, const mechanism_desc& desc) + { + for (auto& kv: desc.values()) { + int pidx = param_index.at(kv.first); + value_type v = kv.second; + + extend_to(param_area_contrib[pidx], index); + param_area_contrib[pidx][index] += area; + + extend_to(param_value[pidx], index); + param_value[pidx][index] += area*v; + } + + for (auto& ion: info.ions) { + fvm_ion_config& ion_config = ion_configs[ion.first]; + size_type index = util::binary_search_index(ion_config.cv, cv).value(); + if (ion.second.write_concentration_int) { + ion_config.iconc_norm_area[index] -= area; + } + if (ion.second.write_concentration_ext) { + ion_config.econc_norm_area[index] -= area; + } + } + + extend_to(config.norm_area, index); + config.norm_area[index] += area; + }; + + for (auto& seg_entry: entry.second.segments) { + const segment_info& seg_info = D.segments[seg_entry.first]; + const mechanism_desc& mech_desc = *seg_entry.second; + + if (seg_info.has_parent()) { + index_type cv = seg_info.parent_cv; + optional<std::size_t> parent_idx = util::binary_search_index(config.cv, cv); + if (!parent_idx) { + parent_idx = config.cv.size(); + config.cv.push_back(cv); + } + + accumulate_mech_data(*parent_idx, cv, seg_info.parent_cv_area, mech_desc); + } + + for (auto cv: make_span(seg_info.cv_range())) { + size_type idx = config.cv.size(); + config.cv.push_back(cv); + + value_type area = cv==seg_info.distal_cv? seg_info.distal_cv_area: D.cv_area[cv]; + accumulate_mech_data(idx, cv, area, mech_desc); + } + } + + // Complete parameter values with default values.
+ + config.param_values.resize(nparam); + for (auto pidx: make_span(0, nparam)) { + value_type default_value = param_default[pidx]; + config.param_values[pidx].first = param_name[pidx]; + + auto& values = config.param_values[pidx].second; + values.resize(config.cv.size()); + + for (auto i: count_along(config.cv)) { + value_type v = param_value[pidx][i]; + value_type cv_area = D.cv_area[config.cv[i]]; + value_type remaining_area = cv_area-param_area_contrib[pidx][i]; + + values[i] = (v+remaining_area*default_value)/cv_area; + } + } + + // Normalize norm_area entries. + + for (auto i: count_along(config.cv)) { + config.norm_area[i] /= D.cv_area[config.cv[i]]; + trim(config.norm_area[i]); + } + } + + // Normalize ion norm_area entries. + + for (auto& entry: mechdata.ions) { + auto& ion_config = entry.second; + for (auto i: count_along(ion_config.cv)) { + auto cv_area = D.cv_area[ion_config.cv[i]]; + ion_config.iconc_norm_area[i] /= cv_area; + trim(ion_config.iconc_norm_area[i]); + + ion_config.econc_norm_area[i] /= cv_area; + trim(ion_config.econc_norm_area[i]); + } + } + + // II.3 Point mechanism CVs, targets, parameters and stimuli. + + for (const auto& entry: point_mech_table) { + const std::string& name = entry.first; + const auto& points = entry.second.points; + + auto nparam = build_param_data(entry.second.paramset, entry.second.info); + std::vector<std::vector<value_type>> param_value(nparam); + + // Permute points in this mechanism so that they are in increasing CV order; + // cv_order[i] is the index of the ith point by increasing CV. 
+ + mechdata.ntarget += points.size(); + + std::vector<size_type> cv_order; + assign(cv_order, count_along(points)); + sort_by(cv_order, [&](size_type i) { return points[i].cv; }); + + fvm_mechanism_config& config = mechdata.mechanisms[name]; + config.kind = mechanismKind::point; + + assign(config.cv, transform_view(cv_order, [&](size_type j) { return points[j].cv; })); + assign(config.target, transform_view(cv_order, [&](size_type j) { return points[j].target_index; })); + + config.param_values.resize(nparam); + for (auto pidx: make_span(0, nparam)) { + value_type pdefault = param_default[pidx]; + const std::string& pname = param_name[pidx]; + + config.param_values[pidx].first = pname; + + auto& values = config.param_values[pidx].second; + assign(values, transform_view(cv_order, + [&](size_type j) { return value_by_key(points[j].desc->values(), pname).value_or(pdefault); })); + } + } + + // Sort stimuli by ascending CV and construct parameter vectors. + if (!stimuli.empty()) { + fvm_mechanism_config& stim_config = mechdata.mechanisms["_builtin_stimulus"]; + using cv_clamp = const std::pair<size_type, i_clamp>&; + + auto stim_cv_field = [](cv_clamp p) { return p.first; }; + sort_by(stimuli, stim_cv_field); + assign(stim_config.cv, transform_view(stimuli, stim_cv_field)); + + stim_config.param_values.resize(3); + + stim_config.param_values[0].first = "delay"; + assign(stim_config.param_values[0].second, + transform_view(stimuli, [](cv_clamp p) { return p.second.delay; })); + + stim_config.param_values[1].first = "duration"; + assign(stim_config.param_values[1].second, + transform_view(stimuli, [](cv_clamp p) { return p.second.duration; })); + + stim_config.param_values[2].first = "amplitude"; + assign(stim_config.param_values[2].second, + transform_view(stimuli, [](cv_clamp p) { return p.second.amplitude; })); + } + + return mechdata; +} + +} // namespace arb diff --git a/src/fvm_layout.hpp b/src/fvm_layout.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..22c1cf10e2c8bc3d67e0aa45dd9fc44a216a73fa --- /dev/null +++ b/src/fvm_layout.hpp @@ -0,0 +1,140 @@ +#pragma once + +#include <backends/fvm_types.hpp> +#include <cell.hpp> +#include <compartment.hpp> +#include <mechanism.hpp> +#include <mechinfo.hpp> +#include <mechcat.hpp> +#include <util/deduce_return.hpp> +#include <util/enumhash.hpp> +#include <util/span.hpp> + +namespace arb { + +// Discretization data for an unbranched segment. + +struct segment_info { + using value_type = fvm_value_type; + using index_type = fvm_index_type; + + value_type parent_cv_area = 0; + value_type distal_cv_area = 0; + + static constexpr index_type npos = -1; + + index_type parent_cv = npos; // npos => no parent. + index_type proximal_cv = 0; // First CV in segment, excluding parent. + index_type distal_cv = 0; // Last CV in segment (may be shared with other segments). + + bool has_parent() const { return parent_cv!=npos; } + + // Range of CV-indices for segment, excluding parent. + std::pair<index_type, index_type> cv_range() const { + return {proximal_cv, 1+distal_cv}; + } + + // Position is proportional distal distance along segment, in [0, 1). + index_type cv_by_position(double pos) const { + index_type n = distal_cv+1-proximal_cv; + index_type i = static_cast<index_type>(n*pos+0.5); + if (i>0) { + return proximal_cv+(i-1); + } + else { + return parent_cv==npos? proximal_cv: parent_cv; + } + } +}; + +// Discretization of morphologies and electrical properties for +// cells in a cell group. + +struct fvm_discretization { + using value_type = fvm_value_type; + using size_type = fvm_size_type; + using index_type = fvm_index_type; // In particular, used for CV indices. + + size_type ncell; + size_type ncomp; + + // Note: if CV j has no parent, parent_cv[j] = j. TODO: confirm! 
+ std::vector<index_type> parent_cv; + std::vector<index_type> cv_to_cell; + + std::vector<value_type> face_conductance; // [µS] + std::vector<value_type> cv_area; // [µm²] + std::vector<value_type> cv_capacitance; // [pF] + + std::vector<segment_info> segments; + std::vector<size_type> cell_segment_bounds; // Partitions segment indices by cell. + std::vector<index_type> cell_cv_bounds; // Partitions CV indices by cell. + + auto cell_segment_part() const + DEDUCED_RETURN_TYPE(util::partition_view(cell_segment_bounds)) + + auto cell_cv_part() const + DEDUCED_RETURN_TYPE(util::partition_view(cell_cv_bounds)) + + size_type segment_location_cv(size_type cell_index, segment_location segloc) const { + auto cell_segs = cell_segment_part()[cell_index]; + + size_type seg = segloc.segment+cell_segs.first; + EXPECTS(seg<cell_segs.second); + return segments[seg].cv_by_position(segloc.position); + } +}; + +fvm_discretization fvm_discretize(const std::vector<cell>& cells); + + +// Post-discretization data for point and density mechanism instantiation. + +struct fvm_mechanism_config { + using value_type = fvm_value_type; + using index_type = fvm_index_type; + + mechanismKind kind; + + // Ordered CV indices where mechanism is present; may contain + // duplicates for point mechanisms. + std::vector<index_type> cv; + + // Normalized area contribution in corresponding CV (density mechanisms only). + std::vector<value_type> norm_area; + + // Synapse target number (point mechanisms only). + std::vector<index_type> target; + + // (Non-global) parameters and parameter values across the mechanism instance. + std::vector<std::pair<std::string, std::vector<value_type>>> param_values; +}; + +// Post-discretization data for ion channel state. + +struct fvm_ion_config { + using value_type = fvm_value_type; + using index_type = fvm_index_type; + + // Ordered CV indices where ion must be present. 
+ std::vector<index_type> cv; + + // Normalized area contribution of default concentration contribution in corresponding CV. + std::vector<value_type> iconc_norm_area; + std::vector<value_type> econc_norm_area; +}; + +struct fvm_mechanism_data { + // Mechanism config, indexed by mechanism name. + std::unordered_map<std::string, fvm_mechanism_config> mechanisms; + + // Ion config, indexed by ionKind. + std::unordered_map<ionKind, fvm_ion_config, util::enum_hash> ions; + + // Total number of targets (point-mechanism points) + std::size_t ntarget = 0; +}; + +fvm_mechanism_data fvm_build_mechanism_data(const mechanism_catalogue& catalogue, const std::vector<cell>& cells, const fvm_discretization& D); + +} // namespace arb diff --git a/src/fvm_lowered_cell.hpp b/src/fvm_lowered_cell.hpp new file mode 100644 index 0000000000000000000000000000000000000000..66c04b7f785b89b890f17f681d650e72e2b53391 --- /dev/null +++ b/src/fvm_lowered_cell.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include <memory> +#include <vector> + +#include <backends.hpp> +#include <backends/event.hpp> +#include <backends/fvm_types.hpp> +#include <recipe.hpp> +#include <sampler_map.hpp> +#include <util/range.hpp> + +namespace arb { + +struct fvm_integration_result { + util::range<const threshold_crossing*> crossings; + util::range<const fvm_value_type*> sample_time; + util::range<const fvm_value_type*> sample_value; +}; + +// Common base class for FVM implementation on host or gpu back-end. 
+ +struct fvm_lowered_cell { + virtual void reset() = 0; + + virtual void initialize( + const std::vector<cell_gid_type>& gids, + const recipe& rec, + std::vector<target_handle>& target_handles, + probe_association_map<probe_handle>& probe_map) = 0; + + virtual fvm_integration_result integrate( + fvm_value_type tfinal, + fvm_value_type max_dt, + std::vector<deliverable_event> staged_events, + std::vector<sample_event> staged_samples, + bool check_physical = false) = 0; + + virtual fvm_value_type time() const = 0; + + virtual ~fvm_lowered_cell() {} +}; + +using fvm_lowered_cell_ptr = std::unique_ptr<fvm_lowered_cell>; + +fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p); + +} // namespace arb diff --git a/src/fvm_lowered_cell_impl.cpp b/src/fvm_lowered_cell_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f2529be36ab4493dab3a8c0796c65b9b78012c56 --- /dev/null +++ b/src/fvm_lowered_cell_impl.cpp @@ -0,0 +1,27 @@ +#include <memory> +#include <stdexcept> + +#include <backends.hpp> +#include <backends/multicore/fvm.hpp> +#ifdef ARB_HAVE_GPU +#include <backends/gpu/fvm.hpp> +#endif +#include <fvm_lowered_cell_impl.hpp> + +namespace arb { + +fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p) { + switch (p) { + case backend_kind::multicore: + return fvm_lowered_cell_ptr(new fvm_lowered_cell_impl<multicore::backend>); + case backend_kind::gpu: +#ifdef ARB_HAVE_GPU + return fvm_lowered_cell_ptr(new fvm_lowered_cell_impl<gpu::backend>); +#endif + ; // fall through + default: + throw std::logic_error("unsupported back-end"); + } +} + +} // namespace arb diff --git a/src/fvm_lowered_cell_impl.hpp b/src/fvm_lowered_cell_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..22b8550c856e416a1ba73b94d3decdc492b9968e --- /dev/null +++ b/src/fvm_lowered_cell_impl.hpp @@ -0,0 +1,429 @@ +#pragma once + +// Implementations for fvm_lowered_cell are parameterized +// on the back-end class. 
+// +// Classes here are exposed in a header only so that +// implementation details may be tested in the unit tests. +// It should otherwise only be used in `fvm_lowered_cell.cpp`. + +#include <cmath> +#include <iterator> +#include <utility> +#include <vector> +#include <stdexcept> + +#include <common_types.hpp> +#include <builtin_mechanisms.hpp> +#include <fvm_layout.hpp> +#include <fvm_lowered_cell.hpp> +#include <ion.hpp> +#include <matrix.hpp> +#include <profiling/profiler.hpp> +#include <recipe.hpp> +#include <sampler_map.hpp> +#include <util/meta.hpp> +#include <util/range.hpp> +#include <util/rangeutil.hpp> +#include <util/transform.hpp> + +#include <util/debug.hpp> + +namespace arb { + +template <class Backend> +class fvm_lowered_cell_impl: public fvm_lowered_cell { +public: + using backend = Backend; + using value_type = fvm_value_type; + using index_type = fvm_index_type; + using size_type = fvm_size_type; + + void reset() override; + + void initialize( + const std::vector<cell_gid_type>& gids, + const recipe& rec, + std::vector<target_handle>& target_handles, + probe_association_map<probe_handle>& probe_map) override; + + fvm_integration_result integrate( + value_type tfinal, + value_type max_dt, + std::vector<deliverable_event> staged_events, + std::vector<sample_event> staged_samples, + bool check_physical = false) override; + + value_type time() const override { return tmin_; } + +private: + // Host or GPU-side back-end dependent storage. + using array = typename backend::array; + using shared_state = typename backend::shared_state; + using sample_event_stream = typename backend::sample_event_stream; + using threshold_watcher = typename backend::threshold_watcher; + + std::unique_ptr<shared_state> state_; // Cell state shared across mechanisms. + + // TODO: Can we move the backend-dependent data structures below into state_? 
+ sample_event_stream sample_events_; + array sample_time_; + array sample_value_; + matrix<backend> matrix_; + threshold_watcher threshold_watcher_; + + value_type tmin_ = 0; + value_type initial_voltage_ = NAN; + value_type temperature_ = NAN; + std::vector<mechanism_ptr> mechanisms_; + + // Host-side views/copies and local state. + decltype(backend::host_view(sample_time_)) sample_time_host_; + decltype(backend::host_view(sample_value_)) sample_value_host_; + + void update_ion_state(); + + // Throw if absolute value of membrane voltage exceeds bounds. + void assert_voltage_bounded(fvm_value_type bound); + + // Throw if any cell time not equal to tmin_ + void assert_tmin(); + + // Assign tmin_ and call assert_tmin() if assertions on. + void set_tmin(value_type t) { + tmin_ = t; + EXPECTS((assert_tmin(), true)); + } + + static unsigned dt_steps(value_type t0, value_type t1, value_type dt) { + return t0>=t1? 0: 1+(unsigned)((t1-t0)/dt); + } +}; + +template <typename Backend> +void fvm_lowered_cell_impl<Backend>::assert_tmin() { + auto time_minmax = state_->time_bounds(); + if (time_minmax.first != time_minmax.second) { + throw std::logic_error("inconsistent times across cells"); + } + if (time_minmax.first != tmin_) { + throw std::logic_error("out of synchronziation with cell state time"); + } +} + +template <typename Backend> +void fvm_lowered_cell_impl<Backend>::reset() { + state_->reset(initial_voltage_, temperature_); + set_tmin(0); + + for (auto& m: mechanisms_) { + m->nrn_init(); + } + + update_ion_state(); + + // NOTE: Threshold watcher reset must come after the voltage values are set, + // as voltage is implicitly read by watcher to set initial state. 
+ threshold_watcher_.reset(); +} + +template <typename Backend> +fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate( + value_type tfinal, + value_type dt_max, + std::vector<deliverable_event> staged_events, + std::vector<sample_event> staged_samples, + bool check_physical) +{ + using util::as_const; + + // Integration setup + PE(advance_integrate_setup); + threshold_watcher_.clear_crossings(); + + auto n_samples = staged_samples.size(); + if (sample_time_.size() < n_samples) { + sample_time_ = array(n_samples); + sample_value_ = array(n_samples); + } + + state_->deliverable_events.init(std::move(staged_events)); + sample_events_.init(std::move(staged_samples)); + + EXPECTS((assert_tmin(), true)); + unsigned remaining_steps = dt_steps(tmin_, tfinal, dt_max); + PL(); + + // TODO: Consider devolving more of this to back-end routines (e.g. + // per-compartment dt probably not a win on GPU), possibly rumbling + // complete fvm state into shared state object. + + while (remaining_steps) { + // Deliver events and accumulate mechanism current contributions. + + PE(advance_integrate_events); + state_->deliverable_events.mark_until_after(state_->time); + PL(); + + PE(advance_integrate_current); + state_->zero_currents(); + for (auto& m: mechanisms_) { + m->deliver_events(); + m->nrn_current(); + } + PL(); + + PE(advance_integrate_events); + state_->deliverable_events.drop_marked_events(); + + // Update event list and integration step times. + + state_->update_time_to(dt_max, tfinal); + state_->deliverable_events.event_time_if_before(state_->time_to); + state_->set_dt(); + PL(); + + // Take samples at cell time if sample time in this step interval. + + PE(advance_integrate_samples); + sample_events_.mark_until(state_->time_to); + state_->take_samples(sample_events_.marked_events(), sample_time_, sample_value_); + sample_events_.drop_marked_events(); + PL(); + + // Integrate voltage by matrix solve. 
+ + PE(advance_integrate_matrix_build); + matrix_.assemble(state_->dt_cell, state_->voltage, state_->current_density); + PL(); + PE(advance_integrate_matrix_solve); + matrix_.solve(); + memory::copy(matrix_.solution(), state_->voltage); + PL(); + + // Integrate mechanism state. + + PE(advance_integrate_state); + for (auto& m: mechanisms_) { + m->nrn_state(); + } + PL(); + + // Update ion concentrations. + + PE(advance_integrate_ionupdate); + update_ion_state(); + PL(); + + // Update time and test for spike threshold crossings. + + PE(advance_integrate_threshold); + memory::copy(state_->time_to, state_->time); + threshold_watcher_.test(); + PL(); + + // Check for non-physical solutions: + + if (check_physical) { + PE(advance_integrate_physicalcheck); + assert_voltage_bounded(1000.); + PL(); + } + + // Check for end of integration. + + PE(advance_integrate_stepsupdate); + if (!--remaining_steps) { + tmin_ = state_->time_bounds().first; + remaining_steps = dt_steps(tmin_, tfinal, dt_max); + } + PL(); + } + + set_tmin(tfinal); + + const auto& crossings = threshold_watcher_.crossings(); + sample_time_host_ = backend::host_view(sample_time_); + sample_value_host_ = backend::host_view(sample_value_); + + return fvm_integration_result{ + util::range_pointer_view(crossings), + util::range_pointer_view(sample_time_host_), + util::range_pointer_view(sample_value_host_) + }; +} + +template <typename B> +void fvm_lowered_cell_impl<B>::update_ion_state() { + state_->ions_init_concentration(); + for (auto& m: mechanisms_) { + m->write_ions(); + } + state_->ions_nernst_reversal_potential(temperature_); +} + +template <typename B> +void fvm_lowered_cell_impl<B>::assert_voltage_bounded(fvm_value_type bound) { + auto v_minmax = state_->voltage_bounds(); + if (v_minmax.first>=-bound && v_minmax.second<=bound) { + return; + } + + auto t_minmax = state_->time_bounds(); + throw std::out_of_range("voltage solution out of bounds for t in ["+ + std::to_string(t_minmax.first)+", 
"+std::to_string(t_minmax.second)+"]"); +} + +template <typename B> +void fvm_lowered_cell_impl<B>::initialize( + const std::vector<cell_gid_type>& gids, + const recipe& rec, + std::vector<target_handle>& target_handles, + probe_association_map<probe_handle>& probe_map) +{ + using util::any_cast; + using util::count_along; + using util::make_span; + using util::value_by_key; + using util::keys; + + std::vector<cell> cells; + const std::size_t ncell = gids.size(); + + cells.reserve(ncell); + for (auto gid: gids) { + cells.push_back(any_cast<cell>(rec.get_cell_description(gid))); + } + + auto rec_props = rec.get_global_properties(cell_kind::cable1d_neuron); + auto global_props = rec_props.has_value()? any_cast<cell_global_properties>(rec_props): cell_global_properties{}; + + const mechanism_catalogue* catalogue = global_props.catalogue; + initial_voltage_ = global_props.init_membrane_potential_mV; + temperature_ = global_props.temperature_K; + + // Mechanism instantiator helper. + auto mech_instance = [&catalogue](const std::string& name) { + auto cat = builtin_mechanisms().has(name)? &builtin_mechanisms(): catalogue; + return cat->instance<backend>(name); + }; + + // Discretize cells, build matrix. + + fvm_discretization D = fvm_discretize(cells); + EXPECTS(D.ncell == ncell); + matrix_ = matrix<backend>(D.parent_cv, D.cell_cv_bounds, D.cv_capacitance, D.face_conductance, D.cv_area); + sample_events_ = sample_event_stream(ncell); + + // Discretize mechanism data. + + fvm_mechanism_data mech_data = fvm_build_mechanism_data(*catalogue, cells, D); + + // Create shared cell state. + // (SIMD padding requires us to check each mechanism for alignment/padding constraints.) + + unsigned data_alignment = util::max_value( + util::transform_view(keys(mech_data.mechanisms), + [&](const std::string& name) { return mech_instance(name)->data_alignment(); })); + + state_ = util::make_unique<shared_state>(ncell, D.cv_to_cell, data_alignment? 
data_alignment: 1u); + + // Instantiate mechanisms and ions. + + for (auto& i: mech_data.ions) { + ionKind kind = i.first; + + if (auto ion = value_by_key(global_props.ion_default, to_string(kind))) { + state_->add_ion(ion.value(), i.second.cv, i.second.iconc_norm_area, i.second.econc_norm_area); + } + } + + target_handles.resize(mech_data.ntarget); + + for (auto& m: mech_data.mechanisms) { + auto& name = m.first; + auto& config = m.second; + unsigned mech_id = mechanisms_.size(); + + mechanism::layout layout; + layout.cv = config.cv; + layout.weight.resize(layout.cv.size()); + + // Mechanism weights are F·α where α ∈ [0, 1] is the proportional + // contribution in the CV, and F is the scaling factor required + // to convert from the mechanism current contribution units to A/m². + + if (config.kind==mechanismKind::point) { + // Point mechanism contributions are in [nA]; CV area A in [µm^2]. + // F = 1/A * [nA/µm²] / [A/m²] = 1000/A. + + for (auto i: count_along(layout.cv)) { + auto cv = layout.cv[i]; + layout.weight[i] = 1000/D.cv_area[cv]; + + // (builtin stimulus, for example, has no targets) + if (!config.target.empty()) { + target_handles[config.target[i]] = target_handle(mech_id, i, D.cv_to_cell[cv]); + } + } + } + else { + // Density Current density contributions from mechanism are in [mA/cm²] + // (NEURON compatibility). F = [mA/cm²] / [A/m²] = 10. + + for (auto i: count_along(layout.cv)) { + layout.weight[i] = 10*config.norm_area[i]; + } + } + + auto mech = mech_instance(name); + mech->instantiate(mech_id, *state_, layout); + + for (auto& pv: config.param_values) { + mech->set_parameter(pv.first, pv.second); + } + mechanisms_.push_back(mechanism_ptr(mech.release())); + } + + // Collect detectors, probe handles. 
+ + std::vector<index_type> detector_cv; + std::vector<value_type> detector_threshold; + + for (auto cell_idx: make_span(ncell)) { + cell_gid_type gid = gids[cell_idx]; + + for (auto detector: cells[cell_idx].detectors()) { + detector_cv.push_back(D.segment_location_cv(cell_idx, detector.location)); + detector_threshold.push_back(detector.threshold); + } + + for (cell_lid_type j: make_span(rec.num_probes(gid))) { + probe_info pi = rec.get_probe({gid, j}); + auto where = any_cast<cell_probe_address>(pi.address); + + auto cv = D.segment_location_cv(cell_idx, where.location); + probe_handle handle; + + switch (where.kind) { + case cell_probe_address::membrane_voltage: + handle = state_->voltage.data()+cv; + break; + case cell_probe_address::membrane_current: + handle = state_->current_density.data()+cv; + break; + default: + throw std::logic_error("unrecognized probeKind"); + } + + probe_map.insert({pi.id, {handle, pi.tag}}); + } + } + + threshold_watcher_ = threshold_watcher(state_->cv_to_cell.data(), state_->time.data(), + state_->time_to.data(), state_->voltage.data(), detector_cv, detector_threshold); + + reset(); +} + +} // namespace arb diff --git a/src/fvm_multicell.hpp b/src/fvm_multicell.hpp deleted file mode 100644 index 2049c2942e969dfb0dac98f3aa4e5f1bdab1d1c3..0000000000000000000000000000000000000000 --- a/src/fvm_multicell.hpp +++ /dev/null @@ -1,1256 +0,0 @@ -#pragma once - -#include <algorithm> -#include <iterator> -#include <map> -#include <set> -#include <string> -#include <vector> - -#include <algorithms.hpp> -#include <backends/event.hpp> -#include <backends/fvm_types.hpp> -#include <cell.hpp> -#include <compartment.hpp> -#include <constants.hpp> -#include <event_queue.hpp> -#include <ion.hpp> -#include <math.hpp> -#include <matrix.hpp> -#include <memory/memory.hpp> -#include <profiling/profiler.hpp> -#include <recipe.hpp> -#include <sampler_map.hpp> -#include <segment.hpp> -#include <stimulus.hpp> -#include <util/meta.hpp> -#include 
<util/partition.hpp> -#include <util/rangeutil.hpp> -#include <util/span.hpp> - -namespace arb { -namespace fvm { - -inline int find_cv_index(const segment_location& loc, const compartment_model& graph) { - const auto& si = graph.segment_index; - const auto seg = loc.segment; - - auto first = si[seg]; - auto n = si[seg+1] - first; - - int index = static_cast<int>(n*loc.position+0.5); - index = index==0? graph.parent_index[first]: first+(index-1); - - return index; -}; - -template<class Backend> -class fvm_multicell { -public: - using backend = Backend; - - /// the real number type - using value_type = fvm_value_type; - - /// the integral index type - using size_type = fvm_size_type; - - /// the container used for values - using array = typename backend::array; - using host_array = typename backend::host_array; - - /// the container used for indexes - using iarray = typename backend::iarray; - - using view = typename array::view_type; - using const_view = typename array::const_view_type; - - /// the type (view or copy) for a const host-side view of an array - using host_view = decltype(memory::on_host(std::declval<array>())); - - // handles and events are currently common across implementations; - // re-expose definitions from `backends/event.hpp`. - using target_handle = ::arb::target_handle; - using probe_handle = ::arb::probe_handle; - using deliverable_event = ::arb::deliverable_event; - - fvm_multicell() = default; - - void resting_potential(value_type potential_mV) { - resting_potential_ = potential_mV; - } - - // Set up data structures for a fixed collection of cells identified by `gids` - // with descriptions taken from the recipe `rec`. - // - // Lowered-cell specific handles for targets and probes are stored in the - // caller-provided vector `target_handles` and map `probe_map`. 
- void initialize( - const std::vector<cell_gid_type>& gids, - const recipe& rec, - std::vector<target_handle>& target_handles, - probe_association_map<probe_handle>& probe_map); - - void reset(); - - // fvm_multicell::deliver_event is used only for testing. - void deliver_event(target_handle h, value_type weight) { - mechanisms_[h.mech_id]->net_receive(h.mech_index, weight); - } - - // fvm_multicell::probe is used only for testing. - value_type probe(probe_handle h) const { - return backend::dereference(h); // h is a pointer, but might be device-side. - } - - // Initialize state prior to a sequence of integration steps. - // `staged_events` and `staged_samples` are expected to be - // sorted by event time. - void setup_integration( - value_type tfinal, value_type dt_max, - const std::vector<deliverable_event>& staged_events, - const std::vector<sample_event>& staged_samples) - { - PE(advance_integrate_setup); - - EXPECTS(dt_max>0); - - tfinal_ = tfinal; - dt_max_ = dt_max; - - compute_min_remaining(); - - EXPECTS(!has_pending_events()); - - n_samples_ = staged_samples.size(); - - events_.init(staged_events); - sample_events_.init(staged_samples); - - // Reallocate sample buffers if necessary. - if (sample_value_.size()<n_samples_) { - sample_value_ = array(n_samples_); - sample_time_ = array(n_samples_); - } - - PL(); - } - - // Advance one integration step. - void step_integration(); - - // Query integration completion state. - bool integration_complete() const { - return min_remaining_steps_==0; - } - - // Access to sample data post-integration. 
- decltype(memory::make_const_view(std::declval<host_view>())) sample_value() const { - EXPECTS(sample_events_.empty()); - host_sample_value_ = memory::on_host(sample_value_); - return host_sample_value_; - } - - decltype(memory::make_const_view(std::declval<host_view>())) sample_time() const { - EXPECTS(sample_events_.empty()); - host_sample_time_ = memory::on_host(sample_time_); - return host_sample_time_; - } - - // Query per-cell time state. - // Placeholder: external time queries will no longer be required when - // complete integration loop is in lowered cell. - value_type time(size_type cell_idx) const { - refresh_time_cache(); - return cached_time_[cell_idx]; - } - - value_type min_time() const { - return backend::minmax_value(time_).first; - } - - value_type max_time() const { - return backend::minmax_value(time_).second; - } - - bool state_synchronized() const { - auto mm = backend::minmax_value(time_); - return mm.first==mm.second; - } - - /// Set times for all cells (public for testing purposes only). 
- void set_time_global(value_type t) { - memory::fill(time_, t); - invalidate_time_cache(); - } - - void set_time_to_global(value_type t) { - memory::fill(time_to_, t); - invalidate_time_cache(); - } - - /// Following types and methods are public only for testing: - - /// the type used to store matrix information - using matrix_type = matrix<backend>; - - /// mechanism type - using mechanism = typename backend::mechanism; - using mechanism_ptr = typename backend::mechanism_ptr; - - /// stimulus type - using stimulus = typename backend::stimulus; - - /// ion species storage - using ion_type = typename backend::ion_type; - - /// view into index container - using iview = typename backend::iview; - using const_iview = typename backend::const_iview; - - const matrix_type& jacobian() { return matrix_; } - - /// return list of CV areas in : - /// um^2 - /// 1e-6.mm^2 - /// 1e-8.cm^2 - const_view cv_areas() const { return cv_areas_; } - - /// return the voltage in each CV - view voltage() { return voltage_; } - const_view voltage() const { return voltage_; } - - /// return the current density in each CV: A.m^-2 - view current() { return current_; } - const_view current() const { return current_; } - - std::size_t size() const { return matrix_.size(); } - - /// return reference to in iterable container of the mechanisms - std::vector<mechanism_ptr>& mechanisms() { return mechanisms_; } - - /// return reference to list of ions - std::map<ionKind, ion_type>& ions() { return ions_; } - std::map<ionKind, ion_type> const& ions() const { return ions_; } - - /// return reference to sodium ion - ion_type& ion_na() { return ions_[ionKind::na]; } - ion_type const& ion_na() const { return ions_[ionKind::na]; } - - /// return reference to calcium ion - ion_type& ion_ca() { return ions_[ionKind::ca]; } - ion_type const& ion_ca() const { return ions_[ionKind::ca]; } - - /// return reference to pottasium ion - ion_type& ion_k() { return ions_[ionKind::k]; } - ion_type const& ion_k() const 
{ return ions_[ionKind::k]; } - - /// flags if solution is physically realistic. - /// here we define physically realistic as the voltage being within reasonable bounds. - /// use a simple test of the voltage at the soma is reasonable, i.e. in the range - /// v_soma \in (-1000mv, 1000mv) - bool is_physical_solution() const { - auto v = voltage_[0]; - return (v>-1000.) && (v<1000.); - } - - /// Return reference to the mechanism that matches name. - /// The reference is const, because this information should not be - /// modified by the caller, however it is needed for unit testing. - util::optional<const mechanism_ptr&> find_mechanism(const std::string& name) const { - auto it = std::find_if( - std::begin(mechanisms_), std::end(mechanisms_), - [&name](const mechanism_ptr& m) {return m->name()==name;}); - return it==mechanisms_.end() ? util::nullopt: util::just(*it); - } - - // - // Threshold crossing interface. - // Used by calling code to perform spike detection - // - - /// types defined by the back end for threshold detection - using threshold_watcher = typename backend::threshold_watcher; - using crossing_list = typename backend::threshold_watcher::crossing_list; - - /// Forward the list of threshold crossings from the back end. - /// The list is passed by value, because we don't want the calling code - /// to depend on references to internal state of the solver, and because - /// for some backends the results might have to be collated before returning. - crossing_list get_spikes() const { - return threshold_watcher_.crossings(); - } - - /// clear all spikes: aka threshold crossings. 
- void clear_spikes() { - threshold_watcher_.clear_crossings(); - } - -private: - /// number of distinct cells (integration domains) - size_type ncell_; - - threshold_watcher threshold_watcher_; - - /// resting potential (initial voltage condition) - value_type resting_potential_ = -65; - - /// final time in integration round [ms] - value_type tfinal_ = 0; - - /// max time step for integration [ms] - value_type dt_max_ = 0; - - /// minimum number of integration steps left in integration period. - // zero => integration complete. - unsigned min_remaining_steps_ = 0; - - void compute_min_remaining() { - auto tmin = min_time(); - min_remaining_steps_ = tmin>=tfinal_? 0: 1 + (unsigned)((tfinal_-tmin)/dt_max_); - } - - void decrement_min_remaining() { - EXPECTS(min_remaining_steps_>0); - if (!--min_remaining_steps_) { - compute_min_remaining(); - } - } - - /// event queue for integration period - using deliverable_event_stream = typename backend::deliverable_event_stream; - deliverable_event_stream events_; - - bool has_pending_events() const { - return !events_.empty(); - } - - /// sample events for integration period - using sample_event_stream = typename backend::sample_event_stream; - sample_event_stream sample_events_; - - /// sample buffers - size_type n_samples_ = 0; - array sample_value_; - array sample_time_; - - mutable host_view host_sample_value_; // host-side views/copies of sample data - mutable host_view host_sample_time_; - - /// the linear system for implicit time stepping of cell state - matrix_type matrix_; - - /// cv_areas_[i] is the surface area of CV i [µm^2] - array cv_areas_; - - /// the map from compartment index to cell index - iarray cv_to_cell_; - - /// the per-cell simulation time - array time_; - - /// the per-cell integration period end point - array time_to_; - - // the per-compartment dt - // (set to dt_cell_[j] for each compartment in cell j). - array dt_comp_; - - // the per-cell dt - // (set to time_to_[j]-time_[j] for each cell j). 
- array dt_cell_; - - // Maintain cached copy of time vector for querying by - // cell_group. This will no longer be necessary when full - // integration loop is in lowered cell. - mutable std::vector<value_type> cached_time_; - mutable bool cached_time_valid_ = false; - - void invalidate_time_cache() { cached_time_valid_ = false; } - void refresh_time_cache() const { - if (!cached_time_valid_) { - memory::copy(time_, memory::make_view(cached_time_)); - } - cached_time_valid_ = true; - } - - /// the transmembrane current density over the surface of each CV [A.m^-2] - /// I = i_m - I_e/area - array current_; - - /// the potential in each CV [mV] - array voltage_; - - /// the set of mechanisms present in the cell - std::vector<mechanism_ptr> mechanisms_; - - /// the ion species - std::map<ionKind, ion_type> ions_; - - /// Compact representation of the control volumes into which a segment is - /// decomposed. Used to reconstruct the weights used to convert current - /// densities to currents for density channels. - struct segment_cv_range { - // the contribution to the surface area of the CVs that - // are at the beginning and end of the segment - std::pair<value_type, value_type> areas; - - // the range of CVs in the segment, excluding the parent CV - std::pair<size_type, size_type> segment_cvs; - - // The last CV in the parent segment, which corresponds to the - // first CV in this segment. - // Set to npos() if there is no parent (i.e. if soma) - size_type parent_cv; - - static constexpr size_type npos() { - return std::numeric_limits<size_type>::max(); - } - - // the number of CVs (including the parent) - std::size_t size() const { - return segment_cvs.second-segment_cvs.first + (parent_cv==npos() ? 
0 : 1); - } - - bool has_parent() const { - return parent_cv != npos(); - } - }; - - // perform area and capacitance calculation on initialization - segment_cv_range compute_cv_area_capacitance( - std::pair<size_type, size_type> comp_ival, - const segment* seg, - const std::vector<size_type>& parent, - std::vector<value_type>& face_conductance, - std::vector<value_type>& tmp_cv_areas, - std::vector<value_type>& cv_capacitance - ); - - // TODO: This process should be simpler when we can deal with mechanism prototypes and have - // separate initialization. - // - // Create possibly-specialized mechanism and add to mechanism set. - // Weights are unset, and should be set specifically with mechanism::set_weights(). - mechanism& make_mechanism( - const std::string& name, - const std::map<std::string, specialized_mechanism>& special_mechs, - const std::vector<size_type>& node_indices) - { - std::string impl_name = name; - std::vector<std::pair<std::string, double>> global_params; - - if (special_mechs.count(name)) { - const auto& spec_mech = special_mechs.at(name); - impl_name = spec_mech.mech_name; - global_params = spec_mech.parameters; - } - - size_type mech_id = mechanisms_.size(); - auto m = backend::make_mechanism(impl_name, mech_id, cv_to_cell_, time_, time_to_, dt_comp_, voltage_, current_, {}, node_indices); - if (impl_name!=name) { - m->set_alias(name); - } - - for (const auto& pv: global_params) { - auto field = m->field_value_ptr(pv.first); - if (!field) { - throw std::invalid_argument("no scalar parameter "+pv.first+" in mechanism "+m->name()); - } - m.get()->*field = pv.second; - } - - mechanisms_.push_back(std::move(m)); - return *mechanisms_.back(); - } - - // Throwing-wrapper around mechanism (range) parameter look up. 
- static view mech_field(mechanism& m, const std::string& param_name) { - auto p = m.field_view_ptr(param_name); - if (!p) { - throw std::invalid_argument("no parameter "+param_name+" in mechanism "+m.name()); - } - return m.*p; - } -}; - -//////////////////////////////////////////////////////////////////////////////// -//////////////////////////////// Implementation //////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// -template <typename Backend> -typename fvm_multicell<Backend>::segment_cv_range -fvm_multicell<Backend>::compute_cv_area_capacitance( - std::pair<size_type, size_type> comp_ival, - const segment* seg, - const std::vector<size_type>& parent, - std::vector<value_type>& face_conductance, - std::vector<value_type>& tmp_cv_areas, - std::vector<value_type>& cv_capacitance) -{ - // precondition: group_parent_index[j] holds the correct value for - // j in [base_comp, base_comp+segment.num_compartments()]. - - auto ncomp = comp_ival.second-comp_ival.first; - - segment_cv_range cv_range; - - auto cm = seg->cm; - auto rL = seg->rL; - - if (auto soma = seg->as_soma()) { - // confirm assumption that there is one compartment in soma - if (ncomp!=1) { - throw std::logic_error("soma allocated more than one compartment"); - } - auto i = comp_ival.first; - auto area = math::area_sphere(soma->radius()); - - tmp_cv_areas[i] += area; - cv_capacitance[i] += area*cm; - - cv_range.segment_cvs = {comp_ival.first, comp_ival.first+1}; - cv_range.areas = {0.0, area}; - cv_range.parent_cv = segment_cv_range::npos(); - } - else if (auto cable = seg->as_cable()) { - // Loop over each compartment in the cable - // - // Each compartment i straddles the ith control volume on the right - // and the jth control volume on the left, where j is the parent index - // of i. 
- // - // Dividing the comparment into two halves, the centre face C - // corresponds to the shared face between the two control volumes, - // the surface areas in each half contribute to the surface area of - // the respective control volumes, and the volumes and lengths of - // each half are used to calculate the flux coefficients that - // for the connection between the two control volumes and which - // is stored in `face_conductance[i]`. - // - // - // +------- cv j --------+------- cv i -------+ - // | | | - // v v v - // ____________________________________________ - // | ........ | ........ | | | - // | ........ L ........ C R | - // |__________|__________|__________|_________| - // ^ ^ - // | | - // +--- compartment i ---+ - // - // The first control volume of any cell corresponds to the soma - // and the first half of the first cable compartment of that cell. - - auto divs = div_compartments<div_compartment_integrator>(cable, ncomp); - - // assume that this segment has a parent, which is the case so long - // as the soma is the root of all cell trees. - cv_range.parent_cv = parent[comp_ival.first]; - cv_range.segment_cvs = comp_ival; - cv_range.areas = {divs(0).left.area, divs(ncomp-1).right.area}; - - for (auto i: util::make_span(comp_ival)) { - const auto& div = divs(i-comp_ival.first); - auto j = parent[i]; - - // Conductance approximated by weighted harmonic mean of mean - // conductances in each half. - // - // Mean conductances: - // g₁ = 1/h₁ ∫₁ A(x)/R dx - // g₂ = 1/h₂ ∫₂ A(x)/R dx - // - // where A(x) is the cross-sectional area, R is the bulk - // resistivity, h is the length of the interval and the - // integrals are taken over the intervals respectively.
- // Equivalently, in terms of the semi-compartment volumes - // V₁ and V₂: - // - // g₁ = 1/R·V₁/h₁ - // g₂ = 1/R·V₂/h₂ - // - // Weighted harmonic mean, with h = h₁+h₂: - // - // g = (h₁/h·g₁¯¹+h₂/h·g₂¯¹)¯¹ - // = 1/R · hV₁V₂/(h₂²V₁+h₁²V₂) - // - // the following units are used - // lengths : μm - // areas : μm^2 - // volumes : μm^3 - - auto h1 = div.left.length; - auto V1 = div.left.volume; - auto h2 = div.right.length; - auto V2 = div.right.volume; - auto h = h1+h2; - - auto conductance = 1/rL*h*V1*V2/(h2*h2*V1+h1*h1*V2); - // the scaling factor of 10^2 is to convert the quantity - // to micro Siemens [μS] - face_conductance[i] = 1e2 * conductance / h; - - auto al = div.left.area; - auto ar = div.right.area; - - tmp_cv_areas[j] += al; - tmp_cv_areas[i] += ar; - cv_capacitance[j] += al * cm; - cv_capacitance[i] += ar * cm; - } - } - else { - throw std::domain_error("FVM lowering encountered unsuported segment type"); - } - - return cv_range; -} - -template <typename Backend> -void fvm_multicell<Backend>::initialize( - const std::vector<cell_gid_type>& gids, - const recipe& rec, - std::vector<target_handle>& target_handles, - probe_association_map<probe_handle>& probe_map) -{ - using memory::make_const_view; - using util::any_cast; - using util::assign_by; - using util::make_partition; - using util::make_span; - using util::size; - using util::sort_by; - using util::transform_view; - using util::subrange_view; - - ncell_ = size(gids); - std::size_t targets_count = 0u; - - // Handle any global parameters for these cell groups. - // (Currently: just specialized mechanisms). - std::map<std::string, specialized_mechanism> special_mechs; - util::any gprops = rec.get_global_properties(cell_kind::cable1d_neuron); - if (gprops.has_value()) { - special_mechs = util::any_cast<cell_global_properties&>(gprops).special_mechs; - } - - // Take cell descriptions from recipe.
These are used initially - // to count compartments for allocation of data structures, and - // then interrogated again for the details for each cell in turn. - - std::vector<cell> cells; - cells.reserve(gids.size()); - for (auto gid: gids) { - cells.push_back(std::move(any_cast<cell>(rec.get_cell_description(gid)))); - } - - auto cell_num_compartments = - transform_view(cells, [](const cell& c) { return c.num_compartments(); }); - - std::vector<cell_lid_type> cell_comp_bounds; - auto cell_comp_part = make_partition(cell_comp_bounds, cell_num_compartments); - auto ncomp = cell_comp_part.bounds().second; - - // initialize storage from total compartment count - current_ = array(ncomp, 0); - voltage_ = array(ncomp, resting_potential_); - cv_to_cell_ = iarray(ncomp, 0); - time_ = array(ncell_, 0); - time_to_ = array(ncell_, 0); - cached_time_.resize(ncell_); - cached_time_valid_ = false; - dt_cell_ = array(ncell_, 0); - dt_comp_ = array(ncomp, 0); - - // initialize cv_to_cell_ values from compartment partition - std::vector<size_type> cv_to_cell_tmp(ncomp); - for (size_type i = 0; i<ncell_; ++i) { - util::fill(util::subrange_view(cv_to_cell_tmp, cell_comp_part[i]), i); - } - memory::copy(cv_to_cell_tmp, cv_to_cell_); - - // TODO: mechanism parameters are currently indexed by string keys; more efficient - // to use the mechanism member pointers, when these become easily accessible via - // the mechanism catalogue interface. - - // look up table: mechanism name -> list of cv_range objects and parameter settings. - - struct mech_area_contrib { - size_type index; - value_type area; - }; - - struct mech_info { - segment_cv_range cv_range; - // Note: owing to linearity constraints, the only parameters for which it is - // sensible to modify are those which make a linear contribution to currents - // (or ion fluxes, etc.) 
- std::map<std::string, value_type> param_map; - std::vector<mech_area_contrib> contributions; - }; - - std::map<std::string, std::vector<mech_info>> mech_map; - - // look up table: point mechanism (synapse) name -> list of CV indices, target numbers, parameters. - struct syn_info { - cell_lid_type cv; - cell_lid_type target; - std::map<std::string, value_type> param_map; - }; - - std::map<std::string, std::vector<syn_info>> syn_mech_map; - - // initialize vector used for matrix creation. - std::vector<size_type> group_parent_index(ncomp); - - // setup per-cell event stores. - events_ = deliverable_event_stream(ncell_); - sample_events_ = sample_event_stream(ncell_); - - // Create each cell: - - // Allocate scratch storage for calculating quantities used to build the - // linear system: these will later be copied into target-specific storage. - - // face_conductance_[i] = area_face / (rL * delta_x); - std::vector<value_type> face_conductance(ncomp); // [µS] - /// cv_capacitance_[i] is the capacitance of CV membrane - std::vector<value_type> cv_capacitance(ncomp); // [µm^2*F*m^-2 = pF] - /// membrane area of each cv - std::vector<value_type> tmp_cv_areas(ncomp); // [µm^2] - - // used to build the information required to construct spike detectors - std::vector<size_type> spike_detector_index; - std::vector<value_type> thresholds; - - // Iterate over the input cells and build the indexes etc that descrbe the - // fused cell group. On completion: - // - group_paranet_index contains the full parent index for the fused cells. - // - mech_to_cv_range and syn_mech_map provide a map from mechanism names to an - // iterable container of compartment ranges, which are used later to - // generate the node index for each mechanism kind. - // - the tmp_* vectors contain compartment-specific information for each - // compartment in the fused cell group (areas, capacitance, etc). - // - each probe, stimulus and detector is attached to its compartment.
- for (auto i: make_span(0, ncell_)) { - const auto& c = cells[i]; - auto gid = gids[i]; - auto comp_ival = cell_comp_part[i]; - - auto graph = c.model(); - - for (auto k: make_span(comp_ival)) { - group_parent_index[k] = graph.parent_index[k-comp_ival.first]+comp_ival.first; - } - - auto seg_num_compartments = - transform_view(c.segments(), [](const segment_ptr& s) { return s->num_compartments(); }); - const auto nseg = seg_num_compartments.size(); - - std::vector<cell_lid_type> seg_comp_bounds; - auto seg_comp_part = - make_partition(seg_comp_bounds, seg_num_compartments, comp_ival.first); - - for (size_type j = 0; j<nseg; ++j) { - const auto& seg = c.segment(j); - const auto& seg_comp_ival = seg_comp_part[j]; - - auto cv_range = compute_cv_area_capacitance( - seg_comp_ival, seg, group_parent_index, - face_conductance, tmp_cv_areas, cv_capacitance); - - for (const auto& mech: seg->mechanisms()) { - mech_map[mech.name()].push_back({cv_range, mech.values()}); - } - } - - for (const auto& syn: c.synapses()) { - const auto& name = syn.mechanism.name(); - - cell_lid_type syn_cv = comp_ival.first + find_cv_index(syn.location, graph); - cell_lid_type target_index = targets_count++; - - syn_mech_map[name].push_back({syn_cv, target_index, syn.mechanism.values()}); - } - - // - // add the stimuli - // - - // TODO: use same process as for synapses! 
- // step 1: pack the index and parameter information into flat vectors - std::vector<size_type> stim_index; - std::vector<value_type> stim_durations; - std::vector<value_type> stim_delays; - std::vector<value_type> stim_amplitudes; - std::vector<value_type> stim_weights; - for (const auto& stim: c.stimuli()) { - auto idx = comp_ival.first+find_cv_index(stim.location, graph); - stim_index.push_back(idx); - stim_durations.push_back(stim.clamp.duration()); - stim_delays.push_back(stim.clamp.delay()); - stim_amplitudes.push_back(stim.clamp.amplitude()); - stim_weights.push_back(1e3/tmp_cv_areas[idx]); - } - - // step 2: create the stimulus mechanism and initialize the stimulus - // parameters - // NOTE: the indexes and associated metadata (durations, delays, - // amplitudes) have not been permuted to ascending cv index order, - // as is the case with other point processes. - // This is because the hard-coded stimulus mechanism makes no - // optimizations that rely on this assumption. - if (stim_index.size()) { - auto stim = new stimulus( - cv_to_cell_, time_, time_to_, dt_comp_, - voltage_, current_, memory::make_const_view(stim_index)); - stim->set_parameters(stim_amplitudes, stim_durations, stim_delays); - stim->set_weights(memory::make_const_view(stim_weights)); - mechanisms_.push_back(mechanism_ptr(stim)); - } - - // calculate spike detector handles are their corresponding compartment indices - for (const auto& detector: c.detectors()) { - auto comp = comp_ival.first+find_cv_index(detector.location, graph); - spike_detector_index.push_back(comp); - thresholds.push_back(detector.threshold); - } - - // Retrieve probe addresses and tags from recipe for this cell. 
- for (cell_lid_type j: make_span(0, rec.num_probes(gid))) { - probe_info pi = rec.get_probe({gid, j}); - auto where = any_cast<cell_probe_address>(pi.address); - - auto comp = comp_ival.first+find_cv_index(where.location, graph); - probe_handle handle; - - switch (where.kind) { - case cell_probe_address::membrane_voltage: - handle = fvm_multicell::voltage_.data()+comp; - break; - case cell_probe_address::membrane_current: - handle = fvm_multicell::current_.data()+comp; - break; - default: - throw std::logic_error("unrecognized probeKind"); - } - - probe_map.insert({pi.id, {handle, pi.tag}}); - } - } - - // set a back-end supplied watcher on the voltage vector - threshold_watcher_ = - threshold_watcher(cv_to_cell_, time_, time_to_, voltage_, spike_detector_index, thresholds); - - // store the geometric information in target-specific containers - cv_areas_ = make_const_view(tmp_cv_areas); - - // initalize matrix - matrix_ = matrix_type( - group_parent_index, cell_comp_bounds, cv_capacitance, face_conductance, tmp_cv_areas); - - // Keep cv index list for each mechanism for ion set up below. - std::map<std::string, std::vector<size_type>> mech_to_cv_index; - // Keep area of each cv occupied by each mechanism, which may be less than - // the total area of the cv. - std::map<std::string, std::vector<value_type>> mech_to_area; - - // Working vectors (re-used per mechanism). - std::vector<size_type> mech_cv(ncomp); - std::vector<value_type> mech_weight(ncomp); - - for (auto& entry: mech_map) { - const auto& mech_name = entry.first; - auto& segments = entry.second; - - mech_cv.clear(); - mech_weight.clear(); - - // Three passes are performed over the segment list: - // 1. Compute the CVs and area contributions where the mechanism is instanced. - // 2. Build table of modified parameters together with default values. - // 3. Compute weights and parameters. 
- // The mechanism is instantiated after the first pass, in order to gain - // access to default mechanism parameter values. - - for (auto& seg: segments) { - const auto& rng = seg.cv_range; - seg.contributions.reserve(rng.size()); - - if (rng.has_parent()) { - auto cv = rng.parent_cv; - - auto it = algorithms::binary_find(mech_cv, cv); - size_type pos = it - mech_cv.begin(); - - if (it == mech_cv.end()) { - mech_cv.push_back(cv); - } - - seg.contributions.push_back({pos, rng.areas.first}); - } - - for (auto cv: make_span(rng.segment_cvs)) { - size_type pos = mech_cv.size(); - mech_cv.push_back(cv); - seg.contributions.push_back({pos, tmp_cv_areas[cv]}); - } - - // Last CV contribution may be only partial, so adjust. - seg.contributions.back().area = rng.areas.second; - } - - auto nindex = mech_cv.size(); - - EXPECTS(std::is_sorted(mech_cv.begin(), mech_cv.end())); - EXPECTS(nindex>0); - - auto& mech = make_mechanism(mech_name, special_mechs, mech_cv); - - // Save the indices for ion set up below. - - mech_to_cv_index[mech_name] = mech_cv; - - // Build modified (non-global) parameter table. - - struct param_tbl_entry { - std::vector<value_type> values; // staged for writing to mechanism - view data; // view to corresponding data in mechanism - value_type dflt; // default value for parameter - }; - - std::map<std::string, param_tbl_entry> param_tbl; - - for (const auto& seg: segments) { - for (const auto& pv: seg.param_map) { - if (param_tbl.count(pv.first)) { - continue; - } - - // Grab default value from mechanism data. - auto& entry = param_tbl[pv.first]; - entry.data = mech_field(mech, pv.first); - entry.dflt = entry.data[0]; - entry.values.assign(nindex, 0.); - } - } - - // Perform another pass of segment list to compute weights and (non-global) parameters. 
- - mech_weight.assign(nindex, 0.); - - for (const auto& seg: segments) { - for (auto cw: seg.contributions) { - mech_weight[cw.index] += cw.area; - - for (auto& entry: param_tbl) { - value_type v = entry.second.dflt; - const auto& name = entry.first; - - auto it = seg.param_map.find(name); - if (it != seg.param_map.end()) { - v = it->second; - } - - entry.second.values[cw.index] += cw.area*v; - } - } - } - - // Save the areas for ion setup below. - mech_to_area[mech_name] = mech_weight; - - for (auto& entry: param_tbl) { - for (size_type i = 0; i<nindex; ++i) { - entry.second.values[i] /= mech_weight[i]; - } - memory::copy(entry.second.values, entry.second.data); - } - - // Scale the weights by the CV area to get the proportion of the CV surface - // on which the mechanism is present. After scaling, the current will have - // units A.m^-2. - for (auto i: make_span(0, mech_weight.size())) { - mech_weight[i] *= 10/tmp_cv_areas[mech_cv[i]]; - } - mech.set_weights(memory::make_const_view(mech_weight)); - } - - target_handles.resize(targets_count); - - // Create point (synapse) mechanisms. - for (auto& map_entry: syn_mech_map) { - size_type mech_id = mechanisms_.size(); - - const auto& mech_name = map_entry.first; - auto& syn_data = map_entry.second; - auto n_instance = syn_data.size(); - - // Build permutation p such that p[j] is the index into - // syn_data for the jth synapse of this mechanism type as ordered by cv index. - - auto cv_of = [&](cell_lid_type i) { return syn_data[i].cv; }; - - std::vector<cell_lid_type> p(n_instance); - std::iota(p.begin(), p.end(), 0u); - util::sort_by(p, cv_of); - - std::vector<cell_lid_type> mech_cv; - std::vector<value_type> mech_weight; - mech_cv.reserve(n_instance); - mech_weight.reserve(n_instance); - - // Build mechanism cv index vector, weights and targets. 
- for (auto i: make_span(0u, n_instance)) { - const auto& syn = syn_data[p[i]]; - mech_cv.push_back(syn.cv); - // The weight for each synapses is 1/cv_area, scaled by 100 to match the units - // of 10.A.m^-2 used to store current densities in current_. - mech_weight.push_back(1e3/tmp_cv_areas[syn.cv]); - target_handles[syn.target] = target_handle(mech_id, i, cv_to_cell_tmp[syn.cv]); - } - - auto& mech = make_mechanism(mech_name, special_mechs, mech_cv); - mech.set_weights(memory::make_const_view(mech_weight)); - - // Save the indices for ion set up below. - mech_to_cv_index[mech_name] = mech_cv; - - // Update the mechanism parameters. - std::map<std::string, std::vector<std::pair<cell_lid_type, value_type>>> param_assigns; - for (auto i: make_span(0u, n_instance)) { - for (const auto& pv: syn_data[p[i]].param_map) { - param_assigns[pv.first].push_back({i, pv.second}); - } - } - - for (const auto& pa: param_assigns) { - view field_data = mech_field(mech, pa.first); - host_array field_values = field_data; - for (const auto &iv: pa.second) { - field_values[iv.first] = iv.second; - } - memory::copy(field_values, field_data); - } - } - - // build the ion species - for (auto ion : ion_kinds()) { - // find the compartment indexes of all compartments that have a - // mechanism that depends on/influences ion - std::set<size_type> index_set; - for (auto const& mech : mechanisms_) { - if(mech->uses_ion(ion).uses) { - auto const& ni = mech_to_cv_index[mech->name()]; - index_set.insert(ni.begin(), ni.end()); - } - } - std::vector<size_type> indexes(index_set.begin(), index_set.end()); - const auto n = indexes.size(); - - if (n==0u) continue; - - // create the ion state - ions_[ion] = indexes; - - std::vector<value_type> w_int; - w_int.reserve(n); - for (auto i: indexes) { - w_int.push_back(tmp_cv_areas[i]); - } - std::vector<value_type> w_out = w_int; - - // Join the ion reference in each mechanism into the cell-wide ion state. 
- for (auto& mech : mechanisms_) { - const auto spec = mech->uses_ion(ion); - if (spec.uses) { - const auto& ni = mech_to_cv_index[mech->name()]; - const auto m = ni.size(); // number of CVs - const std::vector<size_type> sub_index = - util::assign_from(algorithms::index_into(ni, indexes)); - mech->set_ion(ion, ions_[ion], sub_index); - - const auto& ai = mech_to_area[mech->name()]; - if (spec.write_concentration_in) { - for (auto i: make_span(0, m)) { - w_int[sub_index[i]] -= ai[i]; - } - } - if (spec.write_concentration_out) { - for (auto i: make_span(0, m)) { - w_out[sub_index[i]] -= ai[i]; - } - } - } - } - // Normalise the weights. - for (auto i: make_span(0, n)) { - w_int[i] /= tmp_cv_areas[indexes[i]]; - w_out[i] /= tmp_cv_areas[indexes[i]]; - } - ions_[ion].set_weights(w_int, w_out); - } - - // Note: NEURON defined default values for reversal potential as follows, - // with units mV: - // - // const auto DEF_vrest = -65.0 - // ena = 115.0 + DEF_vrest - // ek = -12.0 + DEF_vrest - // eca = 12.5*std::log(2.0/5e-5) - // - // Whereas we use the Nernst equation to calculate reversal potentials at - // the start of each time step. 
- - ion_na().default_int_concentration = 10; - ion_na().default_ext_concentration =140; - ion_na().valency = 1; - - ion_k().default_int_concentration =54.4; - ion_k().default_ext_concentration = 2.5; - ion_k().valency = 1; - - ion_ca().default_int_concentration =5e-5; - ion_ca().default_ext_concentration = 2.0; - ion_ca().valency = 2; - - // initialize mechanism and voltage state - reset(); -} - -template <typename Backend> -void fvm_multicell<Backend>::reset() { - memory::fill(voltage_, resting_potential_); - - set_time_global(0); - set_time_to_global(0); - - // Update ion species: - // - clear currents - // - reset concentrations to defaults - // - recalculate reversal potentials - for (auto& i: ions_) { - i.second.reset(); - } - - for (auto& m : mechanisms_) { - m->set_params(); - m->nrn_init(); - m->write_back(); - } - - // Update reversal potential to account for changes to concentrations made - // by calls to nrn_init() in mechansisms. - for (auto& i: ions_) { - i.second.nernst_reversal_potential(constant::hh_squid_temp); // TODO: use temperature specfied in model - } - - // Reset state of the threshold watcher. - // NOTE: this has to come after the voltage_ values have been reinitialized, - // because these values are used by the watchers to set their initial state. - threshold_watcher_.reset(); - - // Reset integration state. 
- tfinal_ = 0; - dt_max_ = 0; - min_remaining_steps_ = 0; - events_.clear(); - sample_events_.clear(); - - EXPECTS(integration_complete()); - EXPECTS(!has_pending_events()); -} - -template <typename Backend> -void fvm_multicell<Backend>::step_integration() { - EXPECTS(!integration_complete()); - - PE(advance_integrate_events); - // mark pending events for delivery - events_.mark_until_after(time_); - PL(); - - PE(advance_integrate_current); - memory::fill(current_, 0.); - - // clear currents and recalculate reversal potentials for all ion channels - for (auto& i: ions_) { - auto& ion = i.second; - memory::fill(ion.current(), 0.); - ion.nernst_reversal_potential(constant::hh_squid_temp); // TODO: use temperature specfied in model - } - - // deliver pending events and update current contributions from mechanisms - for (auto& m: mechanisms_) { - m->deliver_events(events_.marked_events()); - m->nrn_current(); - } - PL(); - - PE(advance_integrate_events); - // remove delivered events from queue and set time_to_ - events_.drop_marked_events(); - - backend::update_time_to(time_to_, time_, dt_max_, tfinal_); - invalidate_time_cache(); - events_.event_time_if_before(time_to_); - - // set per-cell and per-compartment dt (constant within a cell) - backend::set_dt(dt_cell_, dt_comp_, time_to_, time_, cv_to_cell_); - PL(); - - PE(advance_integrate_samples); - // take samples if they lie within the integration step; they will be provided - // with the values (post-event delivery) at the beginning of the interval. - sample_events_.mark_until(time_to_); - backend::take_samples(sample_events_.marked_events(), time_, sample_time_, sample_value_); - sample_events_.drop_marked_events(); - PL(); - - // solve the linear system - PE(advance_integrate_matrix_build); - matrix_.assemble(dt_cell_, voltage_, current_); - PL(); - - PE(advance_integrate_matrix_solve); - matrix_.solve(); - memory::copy(matrix_.solution(), voltage_); - PL(); - - // integrate state of gating variables etc. 
- PE(advance_integrate_state); - for(auto& m: mechanisms_) { - m->nrn_state(); - } - PL(); - - PE(advance_integrate_ionupdate); - for(auto& i: ions_) { - i.second.init_concentration(); - } - for(auto& m: mechanisms_) { - m->write_back(); - } - PL(); - - PE(advance_integrate_events); - // update time stepping variables - memory::copy(time_to_, time_); - invalidate_time_cache(); - PL(); - - PE(advance_integrate_threshold); - // update spike detector thresholds - threshold_watcher_.test(); - PL(); - - // are we there yet? - decrement_min_remaining(); - - EXPECTS(!integration_complete() || !has_pending_events()); -} - -} // namespace fvm -} // namespace arb diff --git a/src/io/exporter.hpp b/src/io/exporter.hpp index 9cf817f3c530e94c5fa2e0900deba556d17dbae2..9435cea60dac82f82ce765848cae26083f8e4a32 100644 --- a/src/io/exporter.hpp +++ b/src/io/exporter.hpp @@ -21,6 +21,8 @@ public: // Returns the status of the exporter virtual bool good() const = 0; + + virtual ~exporter() {} }; } //communication diff --git a/src/ion.hpp b/src/ion.hpp index 3d64d0370371bce0a66c0ae6779bea84cfe49909..b0bd0c3c59591e14f2011a13d21dcf506caece33 100644 --- a/src/ion.hpp +++ b/src/ion.hpp @@ -1,154 +1,29 @@ #pragma once -#include <array> -#include <constants.hpp> -#include <memory/memory.hpp> -#include <util/indirect.hpp> +#include <stdexcept> +#include <string> namespace arb { -/* - Ion channels have the following fields, whose label corresponds to that - in NEURON. We give them more easily understood accessors. 
+// Fixed set of ion species (to be generalized in the future): - --------------------------------------------------- - label Ca Na K name - --------------------------------------------------- - iX ica ina ik current - eX eca ena ek reversal_potential - Xi cai nai ki internal_concentration - Xo cao nao ko external_concentration - gX gca gna gk conductance - --------------------------------------------------- -*/ - -/// enumerate the ion channel types enum class ionKind {ca, na, k}; - -inline static -std::string to_string(ionKind k) { - switch(k) { - case ionKind::na : return "sodium"; - case ionKind::ca : return "calcium"; - case ionKind::k : return "pottasium"; +inline std::string to_string(ionKind k) { + switch (k) { + case ionKind::ca: return "ca"; + case ionKind::na: return "na"; + case ionKind::k: return "k"; + default: throw std::out_of_range("unknown ionKind"); } - return "unkown"; } -/// a helper for iterting over the ion species -constexpr std::array<ionKind, 3> ion_kinds() { - return {ionKind::ca, ionKind::na, ionKind::k}; -} - -/// storage for ion channel information in a cell group -template<typename Backend> -class ion { -public : - using backend = Backend; - - // expose tempalte parameters - using value_type = typename backend::value_type; - using size_type = typename backend::size_type; - - // define storage types - using array = typename backend::array; - using iarray = typename backend::iarray; - using view = typename backend::view; - using const_iview = typename backend::const_iview; - - ion() = default; - - ion(const std::vector<size_type>& idx) : - node_index_{memory::make_const_view(idx)}, - iX_{idx.size(), std::numeric_limits<value_type>::quiet_NaN()}, - eX_{idx.size(), std::numeric_limits<value_type>::quiet_NaN()}, - Xi_{idx.size(), std::numeric_limits<value_type>::quiet_NaN()}, - Xo_{idx.size(), std::numeric_limits<value_type>::quiet_NaN()}, - valency(0), - default_int_concentration(0), - default_ext_concentration(0) - {} - - // Set the 
weights used when setting default concentration values in each CV. - // The concentration of an ion species in a CV is a linear combination of - // default concentration and contributions from mechanisms that update the - // concentration. The weight is a value between 0 and 1 that represents the - // proportion of the CV area for which the default value is to be used - // (i.e. the proportion of the CV where the concentration is prescribed by a - // mechanism). - void set_weights(const std::vector<value_type>& win, const std::vector<value_type>& wout) { - EXPECTS(win.size() == size()); - EXPECTS(wout.size() == size()); - weight_Xi_ = memory::make_const_view(win); - weight_Xo_ = memory::make_const_view(wout); - } - - view current() { - return iX_; - } - - view reversal_potential() { - return eX_; - } - - view internal_concentration() { - return Xi_; - } - - view external_concentration() { - return Xo_; - } - - view internal_concentration_weights() { - return weight_Xi_; - } - - view external_concentration_weights() { - return weight_Xo_; - } - - void reset() { - // The Nernst equation uses the assumption of nonzero concentrations: - EXPECTS(default_int_concentration > value_type(0)); - EXPECTS(default_ext_concentration > value_type(0)); - memory::fill(iX_, 0); // reset current - init_concentration(); // reset internal and external concentrations - nernst_reversal_potential(constant::hh_squid_temp); // TODO: use temperature specfied in model - } - - /// Calculate the reversal potential for all compartments using Nernst equation - /// temperature is in degrees Kelvin - void nernst_reversal_potential(value_type temperature) { - backend::nernst(valency, temperature, Xo_, Xi_, eX_); - } - - void init_concentration() { - backend::init_concentration( - Xi_, Xo_, weight_Xi_, weight_Xo_, - default_int_concentration, default_ext_concentration); - } - - const_iview node_index() const { - return node_index_; - } - - std::size_t size() const { - return node_index_.size(); - } - 
-private: - iarray node_index_; - array iX_; // (nA) current - array eX_; // (mV) reversal potential - array Xi_; // (mM) internal concentration - array Xo_; // (mM) external concentration - array weight_Xi_; // (1) concentration weight internal - array weight_Xo_; // (1) concentration weight external +// Ion (species) description -public: - int valency; // valency of ionic species - value_type default_int_concentration; // (mM) default internal concentration - value_type default_ext_concentration; // (mM) default external concentration +struct ion_info { + ionKind kind; + int charge; // charge of ionic species + double default_int_concentration; // (mM) default internal concentration + double default_ext_concentration; // (mM) default external concentration }; } // namespace arb diff --git a/src/math.hpp b/src/math.hpp index 120b7fef02526526a4766712537ba8a5742670c6..f5a8cfc3bdf410e8de12e2b956ab13dc5a85f65b 100644 --- a/src/math.hpp +++ b/src/math.hpp @@ -2,6 +2,7 @@ #include <cmath> #include <limits> +#include <type_traits> #include <utility> namespace arb { @@ -80,6 +81,53 @@ int signum(T x) { return (x>T(0)) - (x<T(0)); } +// Next integral power of 2 for unsigned integers: +// +// next_pow2(x) returns 0 if x==0, else returns smallest 2^k such +// that 2^k>=x. + +template <typename U, typename = typename std::enable_if<std::is_unsigned<U>::value>::type> +U next_pow2(U x) { + --x; + for (unsigned s=1; s<std::numeric_limits<U>::digits; s<<=1) { + x|=(x>>s); + } + return ++x; +} + +namespace impl { + template <typename T> + T abs_if_signed(const T& x, std::true_type) { + return std::abs(x); + } + + template <typename T> + T abs_if_signed(const T& x, std::false_type) { + return x; + } +} + +// round_up(v, b) returns r, the smallest magnitude multiple of b +// such that v lies between 0 and r inclusive. 
+// +// Examples: +// round_up( 7, 3) == 9 +// round_up( 7, -3) == 9 +// round_up(-7, 3) == -9 +// round_up(-7, -3) == -9 +// round_up( 8, 4) == 8 + +template < + typename T, + typename U, + typename C = typename std::common_type<T, U>::type, + typename Signed = typename std::is_signed<C>::type +> +C round_up(T v, U b) { + C m = v%b; + return v-m+signum(m)*impl::abs_if_signed(b, Signed{}); +} + // Return minimum of the two values template <typename T> T min(const T& lhs, const T& rhs) { diff --git a/src/matrix.hpp b/src/matrix.hpp index d8968b2f277c1a7016156e8512044265cdd9b3dd..949ebc7cbeb3d5df12dc1fd2718ce10bc3fee0ef 100644 --- a/src/matrix.hpp +++ b/src/matrix.hpp @@ -19,32 +19,28 @@ public: // define basic types using value_type = typename backend::value_type; + using index_type = typename backend::index_type; using size_type = typename backend::size_type; // define storage types using array = typename backend::array; using iarray = typename backend::iarray; - using const_view = typename backend::const_view; - using const_iview = typename backend::const_iview; - - using host_array = typename backend::host_array; - // back end specific storage for matrix state using state = State; matrix() = default; - matrix(const std::vector<size_type>& pi, - const std::vector<size_type>& ci, + matrix(const std::vector<index_type>& pi, + const std::vector<index_type>& ci, const std::vector<value_type>& cv_capacitance, const std::vector<value_type>& face_conductance, const std::vector<value_type>& cv_area): - parent_index_(memory::make_const_view(pi)), - cell_index_(memory::make_const_view(ci)), + parent_index_(pi.begin(), pi.end()), + cell_index_(ci.begin(), ci.end()), state_(pi, ci, cv_capacitance, face_conductance, cv_area) { - EXPECTS(cell_index_[num_cells()] == parent_index_.size()); + EXPECTS(cell_index_[num_cells()] == index_type(parent_index_.size())); } /// the dimension of the matrix (i.e. 
the number of rows or colums) @@ -58,10 +54,10 @@ public: } /// the vector holding the parent index - const_iview p() const { return parent_index_; } + const iarray& p() const { return parent_index_; } /// the partition of the parent index over the cells - const_iview cell_index() const { return cell_index_; } + const iarray& cell_index() const { return cell_index_; } /// Solve the linear system. void solve() { @@ -69,12 +65,12 @@ public: } /// Assemble the matrix for given dt - void assemble(const_view dt_cell, const_view voltage, const_view current) { + void assemble(const array& dt_cell, const array& voltage, const array& current) { state_.assemble(dt_cell, voltage, current); } /// Get a view of the solution - const_view solution() const { + typename State::const_view solution() const { return state_.solution(); } diff --git a/src/mc_cell_group.cpp b/src/mc_cell_group.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d96f9062942cecba4fb3ece3f9f782da214e5593 --- /dev/null +++ b/src/mc_cell_group.cpp @@ -0,0 +1,205 @@ +#include <functional> +#include <unordered_map> +#include <vector> + +#include <backends/event.hpp> +#include <cell.hpp> +#include <cell_group.hpp> +#include <common_types.hpp> +#include <cell_group.hpp> +#include <event_binner.hpp> +#include <event_queue.hpp> +#include <fvm_lowered_cell.hpp> +#include <mc_cell_group.hpp> +#include <recipe.hpp> +#include <sampler_map.hpp> +#include <sampling.hpp> +#include <spike.hpp> +#include <util/filter.hpp> +#include <util/partition.hpp> + +#include <profiling/profiler.hpp> +#include <util/debug.hpp> + +namespace arb { + +mc_cell_group::mc_cell_group(std::vector<cell_gid_type> gids, const recipe& rec, fvm_lowered_cell_ptr lowered): + gids_(std::move(gids)), lowered_(std::move(lowered)) +{ + // Default to no binning of events + set_binning_policy(binning_kind::none, 0); + + // Build lookup table for gid to local index. 
+ for (auto i: util::make_span(0, gids_.size())) { + gid_index_map_[gids_[i]] = i; + } + + // Create lookup structure for target ids. + util::make_partition(target_handle_divisions_, + util::transform_view(gids_, [&rec](cell_gid_type i) { return rec.num_targets(i); })); + std::size_t n_targets = target_handle_divisions_.back(); + + // Pre-allocate space to store handles, probe map. + auto n_probes = util::sum_by(gids_, [&rec](cell_gid_type i) { return rec.num_probes(i); }); + probe_map_.reserve(n_probes); + target_handles_.reserve(n_targets); + + // Construct cell implementation, retrieving handles and maps. + lowered_->initialize(gids_, rec, target_handles_, probe_map_); + + // Create a list of the global identifiers for the spike sources + for (auto source_gid: gids_) { + for (cell_lid_type lid = 0; lid<rec.num_sources(source_gid); ++lid) { + spike_sources_.push_back({source_gid, lid}); + } + } + spike_sources_.shrink_to_fit(); +} + +void mc_cell_group::reset() { + spikes_.clear(); + + sample_events_.clear(); + for (auto &assoc: sampler_map_) { + assoc.sched.reset(); + } + + for (auto& b: binners_) { + b.reset(); + } + + lowered_->reset(); +} + +void mc_cell_group::set_binning_policy(binning_kind policy, time_type bin_interval) { + binners_.clear(); + binners_.resize(gids_.size(), event_binner(policy, bin_interval)); +} + +void mc_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) { + time_type tstart = lowered_->time(); + + PE(advance_eventsetup); + staged_events_.clear(); + // skip event binning if empty lanes are passed + if (event_lanes.size()) { + for (auto lid: util::make_span(0, gids_.size())) { + auto& lane = event_lanes[lid]; + for (auto e: lane) { + if (e.time>=ep.tfinal) break; + e.time = binners_[lid].bin(e.time, tstart); + auto h = target_handles_[target_handle_divisions_[lid]+e.target.index]; + auto ev = deliverable_event(e.time, h, e.weight); + staged_events_.push_back(ev); + } + } + } + PL(); + + // Create sample 
events and delivery information. + // + // For each (schedule, sampler, probe set) in the sampler association + // map that will be triggered in this integration interval, create + // sample events for the lowered cell, one for each scheduled sample + // time and probe in the probe set. + // + // Each event is associated with an offset into the sample data and + // time buffers; these are assigned contiguously such that one call to + // a sampler callback can be represented by a `sampler_call_info` + // value as defined below, grouping together all the samples of the + // same probe for this callback in this association. + + struct sampler_call_info { + sampler_function sampler; + cell_member_type probe_id; + probe_tag tag; + + // Offsets are into lowered cell sample time and event arrays. + sample_size_type begin_offset; + sample_size_type end_offset; + }; + + PE(advance_samplesetup); + std::vector<sampler_call_info> call_info; + + std::vector<sample_event> sample_events; + sample_size_type n_samples = 0; + sample_size_type max_samples_per_call = 0; + + for (auto& sa: sampler_map_) { + auto sample_times = sa.sched.events(tstart, ep.tfinal); + if (sample_times.empty()) { + continue; + } + + sample_size_type n_times = sample_times.size(); + max_samples_per_call = std::max(max_samples_per_call, n_times); + + for (cell_member_type pid: sa.probe_ids) { + auto cell_index = gid_index_map_.at(pid.gid); + auto p = probe_map_[pid]; + + call_info.push_back({sa.sampler, pid, p.tag, n_samples, n_samples+n_times}); + + for (auto t: sample_times) { + sample_event ev{t, cell_index, {p.handle, n_samples++}}; + sample_events.push_back(ev); + } + } + } + + // Sample events must be ordered by time for the lowered cell. + util::sort_by(sample_events, [](const sample_event& ev) { return event_time(ev); }); + PL(); + + // Run integration and collect samples, spikes. 
+ auto result = lowered_->integrate(ep.tfinal, dt, staged_events_, std::move(sample_events), util::is_debug_mode()); + + // For each sampler callback registered in `call_info`, construct the + // vector of sample entries from the lowered cell sample times and values + // and then call the callback. + + PE(advance_sampledeliver); + std::vector<sample_record> sample_records; + sample_records.reserve(max_samples_per_call); + + for (auto& sc: call_info) { + sample_records.clear(); + for (auto i = sc.begin_offset; i!=sc.end_offset; ++i) { + sample_records.push_back(sample_record{time_type(result.sample_time[i]), &result.sample_value[i]}); + } + + sc.sampler(sc.probe_id, sc.tag, sc.end_offset-sc.begin_offset, sample_records.data()); + } + PL(); + + // Copy out spike voltage threshold crossings from the back end, then + // generate spikes with global spike source ids. The threshold crossings + // record the local spike source index, which must be converted to a + // global index for spike communication. 
+ + for (auto c: result.crossings) { + spikes_.push_back({spike_sources_[c.index], time_type(c.time)}); + } +} + +void mc_cell_group::add_sampler(sampler_association_handle h, cell_member_predicate probe_ids, + schedule sched, sampler_function fn, sampling_policy policy) +{ + std::vector<cell_member_type> probeset = + util::assign_from(util::filter(util::keys(probe_map_), probe_ids)); + + if (!probeset.empty()) { + sampler_map_.add(h, sampler_association{std::move(sched), std::move(fn), std::move(probeset)}); + } +} + +void mc_cell_group::remove_sampler(sampler_association_handle h) { + sampler_map_.remove(h); +} + +void mc_cell_group::remove_all_samplers() { + sampler_map_.clear(); +} + +} // namespace arb diff --git a/src/mc_cell_group.hpp b/src/mc_cell_group.hpp index a377358b4ea6eaaa4cf9ded5b8f8934592c8c898..e98fc1596b06522dc26f69105f7ff4a8e4e0bf92 100644 --- a/src/mc_cell_group.hpp +++ b/src/mc_cell_group.hpp @@ -6,13 +6,13 @@ #include <unordered_map> #include <vector> -#include <algorithms.hpp> #include <backends/event.hpp> #include <cell_group.hpp> #include <cell.hpp> #include <common_types.hpp> #include <event_binner.hpp> #include <event_queue.hpp> +#include <fvm_lowered_cell.hpp> #include <recipe.hpp> #include <sampler_map.hpp> #include <sampling.hpp> @@ -28,189 +28,21 @@ namespace arb { -template <typename LoweredCell> class mc_cell_group: public cell_group { public: - using lowered_cell_type = LoweredCell; - using value_type = typename lowered_cell_type::value_type; - using size_type = typename lowered_cell_type::value_type; - mc_cell_group() = default; - mc_cell_group(std::vector<cell_gid_type> gids, const recipe& rec): - gids_(std::move(gids)) - { - // Default to no binning of events - set_binning_policy(binning_kind::none, 0); - - // Build lookup table for gid to local index. - for (auto i: util::make_span(0, gids_.size())) { - gid_index_map_[gids_[i]] = i; - } - - // Create lookup structure for target ids. 
- build_target_handle_partition(rec); - std::size_t n_targets = target_handle_divisions_.back(); - - // Pre-allocate space to store handles, probe map. - auto n_probes = util::sum_by(gids_, [&rec](cell_gid_type i) { return rec.num_probes(i); }); - probe_map_.reserve(n_probes); - target_handles_.reserve(n_targets); - - // Construct cell implementation, retrieving handles and maps. - lowered_.initialize(gids_, rec, target_handles_, probe_map_); - - // Create a list of the global identifiers for the spike sources - for (auto source_gid: gids_) { - for (cell_lid_type lid = 0; lid<rec.num_sources(source_gid); ++lid) { - spike_sources_.push_back({source_gid, lid}); - } - } - spike_sources_.shrink_to_fit(); - } + mc_cell_group(std::vector<cell_gid_type> gids, const recipe& rec, fvm_lowered_cell_ptr lowered); cell_kind get_cell_kind() const override { return cell_kind::cable1d_neuron; } - void reset() override { - spikes_.clear(); - reset_samplers(); - for (auto& b: binners_) { - b.reset(); - } - lowered_.reset(); - } - - void set_binning_policy(binning_kind policy, time_type bin_interval) override { - binners_.clear(); - binners_.resize(gids_.size(), event_binner(policy, bin_interval)); - } - - void advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) override { - EXPECTS(lowered_.state_synchronized()); - time_type tstart = lowered_.min_time(); - - PE(advance_eventsetup); - staged_events_.clear(); - // skip event binning if empty lanes are passed - if (event_lanes.size()) { - for (auto lid: util::make_span(0, gids_.size())) { - auto& lane = event_lanes[lid]; - for (auto e: lane) { - if (e.time>=ep.tfinal) break; - e.time = binners_[lid].bin(e.time, tstart); - auto h = target_handles_[target_handle_divisions_[lid]+e.target.index]; - auto ev = deliverable_event(e.time, h, e.weight); - staged_events_.push_back(ev); - } - } - } - PL(); - - // Create sample events and delivery information. 
- // - // For each (schedule, sampler, probe set) in the sampler association - // map that will be triggered in this integration interval, create - // sample events for the lowered cell, one for each scheduled sample - // time and probe in the probe set. - // - // Each event is associated with an offset into the sample data and - // time buffers; these are assigned contiguously such that one call to - // a sampler callback can be represented by a `sampler_call_info` - // value as defined below, grouping together all the samples of the - // same probe for this callback in this association. - - struct sampler_call_info { - sampler_function sampler; - cell_member_type probe_id; - probe_tag tag; - - // Offsets are into lowered cell sample time and event arrays. - sample_size_type begin_offset; - sample_size_type end_offset; - }; - - PE(advance_samplesetup); - std::vector<sampler_call_info> call_info; - - std::vector<sample_event> sample_events; - sample_size_type n_samples = 0; - sample_size_type max_samples_per_call = 0; - - for (auto& sa: sampler_map_) { - auto sample_times = sa.sched.events(tstart, ep.tfinal); - if (sample_times.empty()) { - continue; - } - - sample_size_type n_times = sample_times.size(); - max_samples_per_call = std::max(max_samples_per_call, n_times); - - for (cell_member_type pid: sa.probe_ids) { - auto cell_index = gid_to_index(pid.gid); - auto p = probe_map_[pid]; - - call_info.push_back({sa.sampler, pid, p.tag, n_samples, n_samples+n_times}); - - for (auto t: sample_times) { - sample_event ev{t, cell_index, {p.handle, n_samples++}}; - sample_events.push_back(ev); - } - } - } - - // Sample events must be ordered by time for the lowered cell. - util::sort_by(sample_events, [](const sample_event& ev) { return event_time(ev); }); - PL(); + void reset() override; - // Run integration. 
- lowered_.setup_integration(ep.tfinal, dt, staged_events_, std::move(sample_events)); - while (!lowered_.integration_complete()) { - lowered_.step_integration(); - if (util::is_debug_mode() && !lowered_.is_physical_solution()) { - std::cerr << "warning: solution out of bounds at (max) t " - << lowered_.max_time() << " ms\n"; - } - } + void set_binning_policy(binning_kind policy, time_type bin_interval) override; - // For each sampler callback registered in `call_info`, construct the - // vector of sample entries from the lowered cell sample times and values - // and then call the callback. - - PE(advance_sampledeliver); - std::vector<sample_record> sample_records; - sample_records.reserve(max_samples_per_call); - - auto sample_time = lowered_.sample_time(); - auto sample_value = lowered_.sample_value(); - - for (auto& sc: call_info) { - sample_records.clear(); - for (auto i = sc.begin_offset; i!=sc.end_offset; ++i) { - sample_records.push_back(sample_record{time_type(sample_time[i]), &sample_value[i]}); - } - - sc.sampler(sc.probe_id, sc.tag, sc.end_offset-sc.begin_offset, sample_records.data()); - } - PL(); - - // Copy out spike voltage threshold crossings from the back end, then - // generate spikes with global spike source ids. The threshold crossings - // record the local spike source index, which must be converted to a - // global index for spike communication. - - PE(advance_spikes); - for (auto c: lowered_.get_spikes()) { - spikes_.push_back({spike_sources_[c.index], time_type(c.time)}); - } - - // Now that the spikes have been generated, clear the old crossings - // to get ready to record spikes from the next integration period. 
- - lowered_.clear_spikes(); - PL(); - } + void advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) override; const std::vector<spike>& spikes() const override { return spikes_; @@ -220,28 +52,12 @@ public: spikes_.clear(); } - const std::vector<cell_member_type>& spike_sources() const { - return spike_sources_; - } - void add_sampler(sampler_association_handle h, cell_member_predicate probe_ids, - schedule sched, sampler_function fn, sampling_policy policy) override - { - std::vector<cell_member_type> probeset = - util::assign_from(util::filter(util::keys(probe_map_), probe_ids)); + schedule sched, sampler_function fn, sampling_policy policy) override; - if (!probeset.empty()) { - sampler_map_.add(h, sampler_association{std::move(sched), std::move(fn), std::move(probeset)}); - } - } + void remove_sampler(sampler_association_handle h) override; - void remove_sampler(sampler_association_handle h) override { - sampler_map_.remove(h); - } - - void remove_all_samplers() override { - sampler_map_.clear(); - } + void remove_all_samplers() override; private: // List of the gids of the cells in the group. @@ -251,7 +67,7 @@ private: std::unordered_map<cell_gid_type, cell_gid_type> gid_index_map_; // The lowered cell state (e.g. FVM) of the cell. - lowered_cell_type lowered_; + fvm_lowered_cell_ptr lowered_; // Spike detectors attached to the cell. std::vector<cell_member_type> spike_sources_; @@ -269,11 +85,9 @@ private: event_queue<sample_event> sample_events_; // Handles for accessing lowered cell. - using target_handle = typename lowered_cell_type::target_handle; std::vector<target_handle> target_handles_; // Maps probe ids to probe handles (from lowered cell) and tags (from probe descriptions). - using probe_handle = typename lowered_cell_type::probe_handle; probe_association_map<probe_handle> probe_map_; // Collection of samplers to be run against probes in this group. 
@@ -281,31 +95,6 @@ private: // Lookup table for target ids -> local target handle indices. std::vector<std::size_t> target_handle_divisions_; - - // Build handle index lookup tables. - void build_target_handle_partition(const recipe& rec) { - util::make_partition(target_handle_divisions_, - util::transform_view(gids_, [&rec](cell_gid_type i) { return rec.num_targets(i); })); - } - - // Get target handle from target id. - target_handle get_target_handle(cell_member_type id) const { - return target_handles_[target_handle_divisions_[gid_to_index(id.gid)]+id.index]; - } - - void reset_samplers() { - // clear all pending sample events and reset to start at time 0 - sample_events_.clear(); - for (auto &assoc: sampler_map_) { - assoc.sched.reset(); - } - } - - cell_gid_type gid_to_index(cell_gid_type gid) const { - auto it = gid_index_map_.find(gid); - EXPECTS(it!=gid_index_map_.end()); - return it->second; - } }; } // namespace arb diff --git a/src/mechanism.hpp b/src/mechanism.hpp index b0c4105213d22801556a841217c6eefb6ac02cee..8b4f2fd5b935443dcf8a18f15298e35f913f12f0 100644 --- a/src/mechanism.hpp +++ b/src/mechanism.hpp @@ -1,180 +1,94 @@ #pragma once -#include <algorithm> #include <memory> #include <string> +#include <vector> #include <backends/fvm_types.hpp> -#include <backends/event.hpp> -#include <backends/multi_event_stream_state.hpp> #include <ion.hpp> -#include <util/indirect.hpp> -#include <util/meta.hpp> -#include <util/make_unique.hpp> +#include <mechinfo.hpp> namespace arb { -struct field_spec { - enum field_kind { - parameter, // defined in 'PARAMETER' block and a 'RANGE' variable. - global, // defined in 'PARAMETER' block and a 'GLOBAL' variable. - state, // defined in 'STATE' block; run-time, read only values. 
- }; - enum field_kind kind = parameter; - - std::string units; - - fvm_value_type default_value = 0; - fvm_value_type lower_bound = std::numeric_limits<fvm_value_type>::lowest(); - fvm_value_type upper_bound = std::numeric_limits<fvm_value_type>::max(); - - // Until C++14, we need a ctor to provide default values instead of using - // default member initializers and aggregate initialization. - field_spec( - enum field_kind kind = parameter, - std::string units = "", - fvm_value_type default_value = 0., - fvm_value_type lower_bound = std::numeric_limits<fvm_value_type>::lowest(), - fvm_value_type upper_bound = std::numeric_limits<fvm_value_type>::max() - ): - kind(kind), units(units), default_value(default_value), lower_bound(lower_bound), upper_bound(upper_bound) - {} -}; +enum class mechanismKind {point, density}; +class mechanism; +using mechanism_ptr = std::unique_ptr<mechanism>; -enum class mechanismKind {point, density}; +template <typename B> class concrete_mechanism; +template <typename B> +using concrete_mech_ptr = std::unique_ptr<concrete_mechanism<B>>; -/// The mechanism type is templated on a memory policy type. -/// The only difference between the abstract definition of a mechanism on host -/// or gpu is the information is stored, and how it is accessed. -template <typename Backend> class mechanism { public: - struct ion_spec { - bool uses; - bool write_concentration_in; - bool write_concentration_out; + mechanism() = default; + mechanism(const mechanism&) = delete; + + // Description of layout of mechanism across cell group: used as parameter in + // `concrete_mechanism<B>::instantiate` (v.i.) + struct layout { + std::vector<fvm_index_type> cv; // Maps in-instance index to CV index. + std::vector<fvm_value_type> weight; // Maps in-instance index to compartment contribution. 
}; - using backend = Backend; - - using value_type = typename backend::value_type; - using size_type = typename backend::size_type; - - // define storage types - using array = typename backend::array; - using iarray = typename backend::iarray; - - using view = typename backend::view; - using iview = typename backend::iview; - - using const_view = typename backend::const_view; - using const_iview = typename backend::const_iview; + // Return fingerprint of mechanism dynamics source description for validation/replication. + virtual const mechanism_fingerprint& fingerprint() const = 0; - using ion_type = ion<backend>; + // Name as given in mechanism source. + virtual std::string internal_name() const { return ""; } - using deliverable_event_stream_state = multi_event_stream_state<deliverable_event_data>; + // Density or point mechanism? + virtual mechanismKind kind() const = 0; - mechanism(size_type mech_id, const_iview vec_ci, const_view vec_t, const_view vec_t_to, const_view vec_dt, view vec_v, view vec_i, iarray&& node_index): - mech_id_(mech_id), - vec_ci_(vec_ci), - vec_t_(vec_t), - vec_t_to_(vec_t_to), - vec_dt_(vec_dt), - vec_v_(vec_v), - vec_i_(vec_i), - node_index_(std::move(node_index)) - {} + // Does the implementation require padding and alignment of shared data structures? + virtual unsigned data_alignment() const { return 1; } - std::size_t size() const { - return node_index_.size(); - } + // Memory use in bytes. + virtual std::size_t memory() const = 0; - const_iview node_index() const { - return node_index_; - } + // Width of an instance: number of CVs (density mechanism) or sites (point mechanism) + // that the mechanism covers. + virtual std::size_t size() const = 0; - // Save pointers to data for use with GPU-side mechanisms; - // TODO: might be able to remove this method if we separate instantiation - // from initialization. 
- virtual void set_params() {} - virtual void set_weights(array&& weights) {} // override for density mechanisms + // Cloning makes a new object of the derived concrete mechanism type, but does not + // copy any state. + virtual mechanism_ptr clone() const = 0; - virtual std::string name() const = 0; - virtual std::size_t memory() const = 0; - virtual void nrn_init() = 0; - virtual void nrn_state() = 0; - virtual void nrn_current() = 0; - virtual void deliver_events(const deliverable_event_stream_state& events) {}; - virtual ion_spec uses_ion(ionKind) const = 0; - virtual void set_ion(ionKind k, ion_type& i, const std::vector<size_type>& index) = 0; - virtual mechanismKind kind() const = 0; + // Parameter setting + virtual void set_global(const std::string& param, fvm_value_type value) = 0; - // Used by mechanisms that update ion concentrations. - // Calling will copy the concentration, stored as internal state of the - // mechanism, to the "global" copy of ion species state. - virtual void write_back() {}; + // Non-global parameters can be set post-instantiation: + virtual void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) = 0; - // Mechanism instances with different global parameter settings can be distinguished by alias. - std::string alias() const { - return alias_.empty()? name(): alias_; - } + // Simulation interfaces: + virtual void nrn_init() = 0; + virtual void nrn_state() = 0; + virtual void nrn_current() = 0; + virtual void deliver_events() {}; + virtual void write_ions() = 0; - void set_alias(std::string alias) { - alias_ = std::move(alias); - } + virtual ~mechanism() = default; - // For non-global fields: - virtual view mechanism::* field_view_ptr(const char* id) const { return nullptr; } - // For global fields: - virtual value_type mechanism::* field_value_ptr(const char* id) const { return nullptr; } + // Per-cell group identifier for an instantiated mechanism. 
+ fvm_size_type mechanism_id() const { return mechanism_id_; } - // Convenience wrappers for field access methods with string parameter. - view mechanism::* field_view_ptr(const std::string& id) const { return field_view_ptr(id.c_str()); } - value_type mechanism::* field_value_ptr(const std::string& id) const { return field_value_ptr(id.c_str()); } +protected: + // Per-cell group identifier for an instantiation of a mechanism; set by + // concrete_mechanism<B>::instantiate() + fvm_size_type mechanism_id_ = -1; +}; - // net_receive() is used internally by deliver_events(), but - // is exposed primarily for unit testing. - virtual void net_receive(int, value_type) {}; +// Backend-specific implementations provide mechanisms that are derived from `concrete_mechanism<Backend>`, +// likely via an intermediate class that captures common behaviour for that backend. - virtual ~mechanism() = default; +template <typename Backend> +class concrete_mechanism: public mechanism { +public: + using backend = Backend; - // Mechanism identifier: index into list of mechanisms on cell group. - size_type mech_id_; - - // Maps compartment index to cell index. - const_iview vec_ci_; - // Maps cell index to integration start time. - const_view vec_t_; - // Maps cell index to integration stop time. - const_view vec_t_to_; - // Maps compartment index to (stop time) - (start time). - const_view vec_dt_; - // Maps compartment index to voltage. - view vec_v_; - // Maps compartment index to current. - view vec_i_; - // Maps mechanism instance index to compartment index. - iarray node_index_; - - std::string alias_; + // Instantiation: allocate per-instance state; set views/pointers to shared data. 
+ virtual void instantiate(fvm_size_type id, typename backend::shared_state&, const layout&) = 0; }; -template <class Backend> -using mechanism_ptr = std::unique_ptr<mechanism<Backend>>; - -template <typename M> -auto make_mechanism( - typename M::size_type mech_id, - typename M::const_iview vec_ci, - typename M::const_view vec_t, - typename M::const_view vec_t_to, - typename M::const_view vec_dt, - typename M::view vec_v, - typename M::view vec_i, - typename M::array&& weights, - typename M::iarray&& node_indices -) -DEDUCED_RETURN_TYPE(util::make_unique<M>(mech_id, vec_ci, vec_t, vec_t_to, vec_dt, vec_v, vec_i, std::move(weights), std::move(node_indices))) } // namespace arb
diff --git a/src/mechcat.cpp b/src/mechcat.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..83360802a9cf7867c490c4d7e751bf2bb5253084
--- /dev/null
+++ b/src/mechcat.cpp
@@ -0,0 +1,218 @@
+#include <cstddef>
+#include <functional>
+#include <map>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <typeindex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include <mechcat.hpp>
+#include <util/maputil.hpp>
+#include <util/make_unique.hpp>
+
+namespace arb {
+
+using util::value_by_key;
+using util::make_unique;
+
+// Add a new (un-derived) mechanism schema under `name`.
+// Throws std::invalid_argument if `name` is already used by a base or derived entry.
+void mechanism_catalogue::add(const std::string& name, mechanism_info info) {
+    if (has(name)) {
+        throw std::invalid_argument("mechanism '"+name+"' already exists in catalogue");
+    }
+
+    info_map_[name] = mechanism_info_ptr(new mechanism_info(std::move(info)));
+}
+
+// Return the schema for `name`, looking in derived entries first, then base entries.
+// Throws std::invalid_argument if there is no such mechanism.
+const mechanism_info& mechanism_catalogue::operator[](const std::string& name) const {
+    if (const auto& deriv = value_by_key(derived_map_, name)) {
+        return *(deriv->derived_info.get());
+    }
+    else if (auto p = value_by_key(info_map_, name)) {
+        return *(p->get());
+    }
+
+    throw std::invalid_argument("no mechanism with name '"+name+"' in catalogue");
+}
+
+// Return the fingerprint of the base mechanism found by following the chain
+// of parents upwards from `name`.
+// Throws std::invalid_argument if that chain does not end at a base mechanism.
+const mechanism_fingerprint& mechanism_catalogue::fingerprint(const std::string& name) const {
+    std::string base = name;
+    while (auto deriv = value_by_key(derived_map_, base)) {
+        base = deriv->parent;
+    }
+
+    if (const auto& p = value_by_key(info_map_, base)) {
+        return p.value()->fingerprint;
+    }
+
+    throw std::invalid_argument("no mechanism with name '"+name+"' in catalogue");
+}
+
+// Register `name` as a derivation of `parent` with the given global parameter overrides.
+// Throws std::invalid_argument if `name` already exists, if `parent` does not, or if an
+// override names an unknown global parameter or a value rejected by its valid() check.
+void mechanism_catalogue::derive(const std::string& name, const std::string& parent, const std::vector<std::pair<std::string, double>>& global_params) {
+    if (has(name)) {
+        throw std::invalid_argument("mechanism with name '"+name+"' already exists in catalogue");
+    }
+
+    if (!has(parent)) {
+        throw std::invalid_argument("no mechanism with name '"+parent+"' in catalogue");
+    }
+
+    derivation deriv = {parent, {}, nullptr};
+    // Start from a copy of the parent's schema; each override becomes the new default.
+    mechanism_info_ptr info = mechanism_info_ptr(new mechanism_info((*this)[deriv.parent]));
+
+    for (const auto& kv: global_params) {
+        const auto& param = kv.first;
+        const auto& value = kv.second;
+
+        if (auto p = value_by_key(info->globals, param)) {
+            if (!p->valid(value)) {
+                throw std::invalid_argument("invalid value for parameter '"+param+"' in mechanism '"+name+"'");
+            }
+        }
+        else {
+            throw std::invalid_argument("mechanism '"+name+"' has no global parameter '"+param+"'");
+        }
+
+        deriv.globals[param] = value;
+        info->globals.at(param).default_value = value;
+    }
+
+    deriv.derived_info = std::move(info);
+    derived_map_[name] = std::move(deriv);
+}
+
+// Remove `name` and, transitively, every mechanism derived from it, along with
+// any implementation prototypes registered under the removed names.
+// Throws std::invalid_argument if there is no such mechanism.
+void mechanism_catalogue::remove(const std::string& name) {
+    if (!has(name)) {
+        throw std::invalid_argument("no mechanism with name '"+name+"' in catalogue");
+    }
+
+    if (is_derived(name)) {
+        derived_map_.erase(name);
+    }
+    else {
+        info_map_.erase(name);
+    }
+
+    // Prototypes may be registered under derived names as well as base names,
+    // so drop them in either case rather than leave a stale entry behind.
+    impl_map_.erase(name);
+
+    // Erase any dangling derivation map entries. Each removal can orphan further
+    // derivations, so repeat until a full pass deletes nothing.
+    std::size_t n_delete;
+    do {
+        n_delete = 0;
+        for (auto it = derived_map_.begin(); it!=derived_map_.end(); ) {
+            const auto& parent = it->second.parent;
+            if (info_map_.count(parent) || derived_map_.count(parent)) {
+                ++it;
+            }
+            else {
+                impl_map_.erase(it->first);
+                derived_map_.erase(it++);
+                ++n_delete;
+            }
+        }
+    } while (n_delete>0);
+}
+
+// Clone the prototype registered for backend `tidx` under `name` or, failing that,
+// under its closest ancestor, then apply the global overrides of the derivation chain.
+// Throws std::invalid_argument if no prototype is found along the chain.
+std::unique_ptr<mechanism> mechanism_catalogue::instance_impl(std::type_index tidx, const std::string& name) const {
+    // Find implementation associated with this name or its closest ancestor.
+
+    auto impl_name = name;
+    const mechanism* prototype = nullptr;
+
+    for (;;) {
+        if (const auto mech_impls = value_by_key(impl_map_, impl_name)) {
+            if (auto p = value_by_key(mech_impls.value(), tidx)) {
+                prototype = p->get();
+                break;
+            }
+        }
+
+        // Try parent instead.
+        if (const auto p = value_by_key(derived_map_, impl_name)) {
+            impl_name = p->parent;
+        }
+        else {
+            throw std::invalid_argument("missing implementation for mechanism named '"+name+"'");
+        }
+    }
+
+    std::unique_ptr<mechanism> mech = prototype->clone();
+    // Overrides are applied root-first, so more-derived settings take precedence.
+    // TODO: make recursive lambda without std::function in C++14
+    std::function<void (const std::string&, mechanism*)> apply_globals =
+        [this, &apply_globals](const std::string& name, mechanism* mptr)
+    {
+        if (auto p = value_by_key(derived_map_, name)) {
+            apply_globals(p->parent, mptr);
+
+            for (auto& kv: p->globals) {
+                mptr->set_global(kv.first, kv.second);
+            }
+        }
+    };
+
+    apply_globals(name, mech.get());
+    return mech;
+}
+
+// Store `mech` as the prototype for backend `tidx` under `name`.
+// Throws std::invalid_argument if `name` is unknown or if the fingerprint of `mech`
+// differs from that of the schema.
+void mechanism_catalogue::register_impl(std::type_index tidx, const std::string& name, std::unique_ptr<mechanism> mech) {
+    const mechanism_info& info = (*this)[name];
+
+    if (mech->fingerprint()!=info.fingerprint) {
+        throw std::invalid_argument("implementation fingerprint does not match schema");
+    }
+
+    impl_map_[name][tidx] = std::move(mech);
+}
+
+// Replace the contents of this catalogue with deep copies of the entries of `other`.
+// NOTE: the maps are cleared before `other` is read, so `other` must not alias *this.
+void mechanism_catalogue::copy_impl(const mechanism_catalogue& other) {
+    info_map_.clear();
+    for (const auto& kv: other.info_map_) {
+        info_map_[kv.first] = make_unique<mechanism_info>(*kv.second);
+    }
+
+    derived_map_.clear();
+    for (const auto& kv: other.derived_map_) {
+        const derivation& v = kv.second;
+        derived_map_[kv.first] = {v.parent, v.globals, make_unique<mechanism_info>(*v.derived_info)};
+    }
+
+    impl_map_.clear();
+    for (const auto& name_impls: other.impl_map_) {
+        std::unordered_map<std::type_index, std::unique_ptr<mechanism>> impls;
+        for (const auto& tidx_mptr: name_impls.second) {
+            impls[tidx_mptr.first] = tidx_mptr.second->clone();
+        }
+
+        impl_map_[name_impls.first] = std::move(impls);
+    }
+}
+
+} // namespace arb
diff --git a/src/mechcat.hpp b/src/mechcat.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..9e30eb98e84fd743964f76faf7a636b97e0c7d8f
--- /dev/null
+++ b/src/mechcat.hpp
@@ -0,0 +1,127 @@
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <typeindex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include <mechinfo.hpp>
+#include <mechanism.hpp>
+
+// Mechanism catalogue maintains:
+//
+// 1. Collection of mechanism metadata indexed by name.
+//
+// 2. A further hierarchy of 'derived' mechanisms, that allow specialization of
+//    global parameters and implementations.
+//
+// 3. A map taking mechanism names x back-end class -> mechanism implementation
+//    prototype object.
+//
+// Implementations for a backend `B` are represented by a pointer to a
+// `concrete_mechanism<B>` object.
+//
+// References to mechanism_info and mechanism_fingerprint objects are invalidated
+// after any modification to the catalogue.
+//
+// There is in addition a global default mechanism catalogue object that is
+// populated with any builtin mechanisms and mechanisms generated from
+// module files included with arbor.
+
+namespace arb {
+
+class mechanism_catalogue {
+public:
+    using value_type = double;
+
+    mechanism_catalogue() = default;
+    mechanism_catalogue(mechanism_catalogue&& other) = default;
+    mechanism_catalogue& operator=(mechanism_catalogue&& other) = default;
+
+    // Copying a catalogue requires cloning the prototypes.
+    mechanism_catalogue(const mechanism_catalogue& other) {
+        copy_impl(other);
+    }
+
+    mechanism_catalogue& operator=(const mechanism_catalogue& other) {
+        if (this!=&other) copy_impl(other); // copy_impl() clears *this before it reads `other`
+        return *this;
+    }
+
+    void add(const std::string& name, mechanism_info info); // throws std::invalid_argument if name is taken
+
+    bool has(const std::string& name) const { // true for base and derived mechanisms alike
+        return info_map_.count(name) || is_derived(name);
+    }
+
+    bool is_derived(const std::string& name) const {
+        return derived_map_.count(name);
+    }
+
+    // Read-only access to mechanism info; throws std::invalid_argument if there is no such mechanism.
+    const mechanism_info& operator[](const std::string& name) const;
+
+    // Read-only access to mechanism fingerprint; throws std::invalid_argument if there is no such mechanism.
+    const mechanism_fingerprint& fingerprint(const std::string& name) const;
+
+    // Construct a schema for a mechanism derived from an existing entry,
+    // with a sequence of overrides for global scalar parameter settings.
+    void derive(const std::string& name, const std::string& parent, const std::vector<std::pair<std::string, double>>& global_params);
+
+    // Remove mechanism from catalogue, together with any mechanisms derived from it.
+    void remove(const std::string& name);
+
+    // Clone the implementation associated with name (search derivation hierarchy starting from
+    // most derived) and set global parameters according to derivations.
+    template <typename B>
+    std::unique_ptr<concrete_mechanism<B>> instance(const std::string& name) const {
+        mechanism_ptr mech = instance_impl(std::type_index(typeid(B)), name);
+        // NOTE(review): a failed dynamic_cast returns null and drops the released clone.
+        return std::unique_ptr<concrete_mechanism<B>>(dynamic_cast<concrete_mechanism<B>*>(mech.release()));
+    }
+
+    // Associate a concrete (prototype) mechanism for a given back-end B with a (possibly derived)
+    // mechanism name.
+    template <typename B>
+    void register_implementation(const std::string& name, std::unique_ptr<concrete_mechanism<B>> proto) {
+        mechanism_ptr generic_proto = mechanism_ptr(proto.release());
+        register_impl(std::type_index(typeid(B)), name, std::move(generic_proto));
+    }
+
+private:
+    using mechanism_info_ptr = std::unique_ptr<mechanism_info>;
+
+    template <typename V>
+    using string_map = std::unordered_map<std::string, V>;
+
+    // Schemata for (un-derived) mechanisms.
+    string_map<mechanism_info_ptr> info_map_;
+
+    struct derivation {
+        std::string parent;
+        string_map<value_type> globals; // global overrides relative to parent
+        mechanism_info_ptr derived_info;
+    };
+
+    // Parent and global setting values for derived mechanisms.
+    string_map<derivation> derived_map_;
+
+    // Prototype register, keyed on mechanism name, then backend type (index).
+    string_map<std::unordered_map<std::type_index, mechanism_ptr>> impl_map_;
+
+    // Concrete-type erased helper methods.
+    mechanism_ptr instance_impl(std::type_index, const std::string&) const;
+    void register_impl(std::type_index, const std::string&, mechanism_ptr);
+
+    // Perform copy and prototype clone from other catalogue (overwrites all entries).
+    void copy_impl(const mechanism_catalogue&);
+};
+
+// Reference to global default mechanism catalogue.
+
+const mechanism_catalogue& global_default_catalogue();
+
+} // namespace arb
diff --git a/src/mechinfo.hpp b/src/mechinfo.hpp index cd40c38728c8e1bca67c0812be80f1694c9d2755..6e64bd88c42d09c358ba31ff27dc6a8f83030eb3 100644 --- a/src/mechinfo.hpp +++ b/src/mechinfo.hpp @@ -1,108 +1,79 @@ #pragma once -/* Mechanism schema classes, catalogue and parameter specification. - * - * Catalogue and schemata have placeholder implementations, to be - * completed in future work. - * - * The `mechanism_spec` class is the public interface for describing - * a mechanism and its (non-global) parameters in cable1d_cell - * recipes.
It presents a map-like interface for accessing and querying - * parameter values, and parameter assignments will be validated - * against a corresponding schema when that infrastructure is in - * place. +/* Classes for representing a mechanism schema, including those + * generated automatically by modcc. */ +#include <limits> #include <string> +#include <unordered_map> #include <utility> #include <vector> +#include <ion.hpp> +#include <util/enumhash.hpp> + namespace arb { -struct mechanism_schema_field { - void validate(double) const {} +struct mechanism_field_spec { + enum field_kind { + parameter, + global, + state, + }; + enum field_kind kind = parameter; + + std::string units; + double default_value = 0; + double lower_bound = std::numeric_limits<double>::lowest(); + double upper_bound = std::numeric_limits<double>::max(); + + bool valid(double x) const { return x>=lower_bound && x<=upper_bound; } + + // TODO: C++14 - no need for ctor below, as aggregate initialization + // will work with default member initializers. 
+ + mechanism_field_spec( + enum field_kind kind = parameter, + std::string units = "", + double default_value = 0., + double lower_bound = std::numeric_limits<double>::lowest(), + double upper_bound = std::numeric_limits<double>::max() + ): + kind(kind), units(units), default_value(default_value), lower_bound(lower_bound), upper_bound(upper_bound) + {} }; -struct mechanism_schema { - const mechanism_schema_field* field(const std::string& key) const { - static mechanism_schema_field dummy_field; - return &dummy_field; - } - - static const mechanism_schema* dummy_schema() { - static mechanism_schema d; - return &d; - } +struct ion_dependency { + bool write_concentration_int; + bool write_concentration_ext; }; -class mechanism_spec { -public: - struct field_proxy { - mechanism_spec* m; - std::string key; +// A hash of the mechanism dynamics description is used to ensure that offline-compiled +// mechanism implementations are correctly associated with their corresponding generated +// mechanism information. +// +// Use a textual representation to ease readability. +using mechanism_fingerprint = std::string; - field_proxy& operator=(double v) { - m->set(key, v); - return *this; - } +struct mechanism_info { + // Global fields have one value common to an instance of a mechanism, are + // constant in time and set at instantiation. + std::unordered_map<std::string, mechanism_field_spec> globals; - operator double() const { - return m->get(key); - } - }; + // Parameter fields may vary across the extent of a mechanism, but are + // constant in time and set at instantiation. + std::unordered_map<std::string, mechanism_field_spec> parameters; + + // State fields vary in time and across the extent of a mechanism, and + // potentially can be sampled at run-time. + std::unordered_map<std::string, mechanism_field_spec> state; + + // Ion dependencies. 
+ std::unordered_map<ionKind, ion_dependency, util::enum_hash> ions; - // implicit - mechanism_spec(std::string name): name_(std::move(name)) { - // get schema pointer from global catalogue, or throw - schema_ = mechanism_schema::dummy_schema(); - if (!schema_) { - throw std::runtime_error("no mechanism "+name_); - } - } - - // implicit - mechanism_spec(const char* name): mechanism_spec(std::string(name)) {} - - mechanism_spec& set(std::string key, double value) { - auto field_schema = schema_->field(key); - if (!field_schema) { - throw std::runtime_error("no field "+key+" in mechanism "+name_); - } - - field_schema->validate(value); - param_[key] = value; - return *this; - } - - double operator[](const std::string& key) const { - return get(key); - } - - double get(const std::string& key) const { - auto field_schema = schema_->field(key); - if (!field_schema) { - throw std::runtime_error("no field "+key+" in mechanism "+name_); - } - - auto it = param_.find(key); - return it==param_.end()? 
field_schema->default_value: it->second; - } - - field_proxy operator[](const std::string& key) { - return {this, key}; - } - - const std::map<std::string, double>& values() const { - return param_; - } - - const std::string& name() const { return name_; } - -private: - std::string name_; - std::map<std::string, double> param_; - const mechanism_schema* schema_; // non-owning; schema must have longer lifetime + mechanism_fingerprint fingerprint; }; } // namespace arb diff --git a/src/memory/array.hpp b/src/memory/array.hpp index c5e358e6f5b736ce481d12774b4fa26de4109df1..4382b9653aeed084c935ac34312370ee7141e2c1 100644 --- a/src/memory/array.hpp +++ b/src/memory/array.hpp @@ -6,22 +6,28 @@ #pragma once +#include <cstdlib> #include <iostream> #include <type_traits> +#include <util/debug.hpp> #include <util/range.hpp> #include "definitions.hpp" #include "util.hpp" +#include "allocator.hpp" #include "array_view.hpp" namespace arb { -namespace memory{ +namespace memory { // forward declarations -template<typename T, typename Coord> +template <typename T, typename Coord> class array; +template <typename T, class Allocator> +class host_coordinator; + namespace util { template <typename T, typename Coord> struct type_printer<array<T,Coord>>{ @@ -209,7 +215,7 @@ public: template < typename It, - typename = arb::util::enable_if_t<arb::util::is_forward_iterator<It>::value> > + typename = arb::util::enable_if_t<arb::util::is_random_access_iterator<It>::value> > array(It b, It e) : base(coordinator_type().allocate(std::distance(b, e))) { @@ -219,8 +225,13 @@ public: << "\n this " << util::pretty_printer<array>::print(*this) << "\n"; //<< "\n other " << util::pretty_printer<Other>::print(other) << std::endl; #endif - //auto canon = arb::util::canonical_view(rng); - std::copy(b, e, this->begin()); + // Only valid for contiguous range, but we can't test that at compile time. + // Can check though that taking &*b+n = &*e where n = e-b, while acknowledging + // this is not fail safe. 
+ EXPECTS(&*b+(e-b)==&*e); + + using V = typename std::iterator_traits<iterator>::value_type; + coordinator_.copy(const_array_view<V, host_coordinator<V, aligned_allocator<V>>>(&*b, e-b), view_type(*this)); } // use the accessors provided by array_view diff --git a/src/memory/device_coordinator.hpp b/src/memory/device_coordinator.hpp index f710d6bd04ab0e3c64569f460182a351eb8e7cbf..ff45f28644055ed45fdc7d7e8f5e7f3e793a1220 100644 --- a/src/memory/device_coordinator.hpp +++ b/src/memory/device_coordinator.hpp @@ -4,11 +4,11 @@ #include <exception> #include <util/debug.hpp> -#include <backends/gpu/fill.hpp> #include "allocator.hpp" #include "array.hpp" #include "definitions.hpp" +#include "fill.hpp" #include "gpu.hpp" #include "util.hpp" @@ -23,7 +23,6 @@ template <typename T, class Allocator> class host_coordinator; namespace util { - template <typename T, typename Allocator> struct type_printer<device_coordinator<T,Allocator>>{ static std::string print() { @@ -252,7 +251,7 @@ public: // fill memory void set(view_type &rng, value_type value) { if (rng.size()) { - arb::gpu::fill<value_type>(rng.data(), value, rng.size()); + gpu::fill<value_type>(rng.data(), value, rng.size()); } } diff --git a/src/backends/gpu/fill.cu b/src/memory/fill.cu similarity index 100% rename from src/backends/gpu/fill.cu rename to src/memory/fill.cu diff --git a/src/backends/gpu/fill.hpp b/src/memory/fill.hpp similarity index 100% rename from src/backends/gpu/fill.hpp rename to src/memory/fill.hpp diff --git a/src/memory/memory.hpp b/src/memory/memory.hpp index fe5f9b7090ecdbff609d73c9403ba0271a3113e8..4e9d393fdc1daf433fce62a9d67bedf048e94ab2 100644 --- a/src/memory/memory.hpp +++ b/src/memory/memory.hpp @@ -18,6 +18,8 @@ template <typename T> using host_vector = array<T, host_coordinator<T>>; template <typename T> using host_view = array_view<T, host_coordinator<T>>; +template <typename T> +using const_host_view = const_array_view<T, host_coordinator<T>>; template <typename T> std::ostream& 
operator<< (std::ostream& o, host_view<T> const& v) { @@ -42,6 +44,8 @@ template <typename T> using device_vector = array<T, device_coordinator<T, cuda_allocator<T>>>; template <typename T> using device_view = array_view<T, device_coordinator<T, cuda_allocator<T>>>; +template <typename T> +using const_device_view = const_array_view<T, device_coordinator<T, cuda_allocator<T>>>; #endif #ifdef WITH_KNL diff --git a/src/profiling/meter_manager.cpp b/src/profiling/meter_manager.cpp index aa869108a74760dd757580a35c3fbcdeda372a49..d745db6b35d9832be7170a5671491dffc961d1a5 100644 --- a/src/profiling/meter_manager.cpp +++ b/src/profiling/meter_manager.cpp @@ -1,3 +1,4 @@ +#include <algorithms.hpp> #include <communication/global_policy.hpp> #include <util/hostname.hpp> #include <util/strprintf.hpp> diff --git a/src/recipe.hpp b/src/recipe.hpp index caaddaddb69a2a8c4c4154808b656bac54365a5c..a03518fd69734476e61af57702c84e7376280bb5 100644 --- a/src/recipe.hpp +++ b/src/recipe.hpp @@ -74,9 +74,10 @@ public: throw std::logic_error("no probes"); } - // Global property type will be specific to given cell kind. virtual util::any get_global_properties(cell_kind) const { return util::any{}; }; + + virtual ~recipe() {} }; } // namespace arb diff --git a/src/schedule.hpp b/src/schedule.hpp index 06937357881d44280659dae61fe223ad30c8ee9b..3f7788c2c74337861975a4c540cb4942a1d606e3 100644 --- a/src/schedule.hpp +++ b/src/schedule.hpp @@ -101,7 +101,7 @@ inline schedule regular_schedule(time_type dt) { // Schedule at times given explicitly via a provided sorted sequence. 
class explicit_schedule_impl { public: - template <typename Seq, typename = util::enable_if_sequence_t<Seq>> + template <typename Seq, typename = util::enable_if_sequence_t<const Seq&>> explicit explicit_schedule_impl(const Seq& seq): start_index_(0), times_(std::begin(seq), compat::end(seq)) diff --git a/src/segment.hpp b/src/segment.hpp index 6c2e2e2b107b45b51fe22ab13c2eae026e80ee5f..43f8b21ebf7f69218f3ca29decf0ffd2589096c9 100644 --- a/src/segment.hpp +++ b/src/segment.hpp @@ -1,6 +1,9 @@ #pragma once #include <cmath> +#include <stdexcept> +#include <string> +#include <unordered_map> #include <vector> #include <algorithms.hpp> @@ -11,9 +14,60 @@ #include <mechinfo.hpp> #include <point.hpp> #include <util/make_unique.hpp> +#include <util/maputil.hpp> namespace arb { +// Mechanism information attached to a segment. + +struct mechanism_desc { + struct field_proxy { + mechanism_desc* m; + std::string key; + + field_proxy& operator=(double v) { + m->set(key, v); + return *this; + } + + operator double() const { + return m->get(key); + } + }; + + // implicit + mechanism_desc(std::string name): name_(std::move(name)) {} + mechanism_desc(const char* name): name_(name) {} + + mechanism_desc& set(const std::string& key, double value) { + param_[key] = value; + return *this; + } + + double operator[](const std::string& key) const { + return get(key); + } + + field_proxy operator[](const std::string& key) { + return {this, key}; + } + + double get(const std::string& key) const { + auto optv = util::value_by_key(param_, key); + return optv? 
*optv: throw std::out_of_range("no field "+key+" set"); + } + + const std::unordered_map<std::string, double>& values() const { + return param_; + } + + const std::string& name() const { return name_; } + +private: + std::string name_; + std::unordered_map<std::string, double> param_; +}; + // forward declarations of segment specializations class soma_segment; class cable_segment; @@ -80,13 +134,13 @@ public: return false; } - util::optional<mechanism_spec&> mechanism(const std::string& name) { + util::optional<mechanism_desc&> mechanism(const std::string& name) { auto it = std::find_if(mechanisms_.begin(), mechanisms_.end(), - [&](mechanism_spec& m) { return m.name()==name; }); + [&](mechanism_desc& m) { return m.name()==name; }); return it==mechanisms_.end()? util::nullopt: util::just(*it); } - void add_mechanism(mechanism_spec mech) { + void add_mechanism(mechanism_desc mech) { auto m = mechanism(mech.name()); if (m) { *m = std::move(mech); @@ -96,11 +150,11 @@ public: } } - const std::vector<mechanism_spec>& mechanisms() { + const std::vector<mechanism_desc>& mechanisms() { return mechanisms_; } - const std::vector<mechanism_spec>& mechanisms() const { + const std::vector<mechanism_desc>& mechanisms() const { return mechanisms_; } @@ -112,7 +166,7 @@ protected: segment(section_kind kind): kind_(kind) {} section_kind kind_; - std::vector<mechanism_spec> mechanisms_; + std::vector<mechanism_desc> mechanisms_; }; class placeholder_segment : public segment { @@ -218,11 +272,10 @@ public: cable_segment(section_kind k, std::vector<value_type> r, std::vector<value_type> lens): segment(k), radii_(std::move(r)), lengths_(std::move(lens)) { - assert(kind_==section_kind::dendrite || kind_==section_kind::axon); + EXPECTS(kind_==section_kind::dendrite || kind_==section_kind::axon); } cable_segment(section_kind k, value_type r1, value_type r2, value_type len): - //cable_segment{k, std::vector<value_type>{r1, r2}, std::vector<value_type>{len}} cable_segment{k, {r1, r2}, 
decltype(lengths_){len}} {} @@ -231,7 +284,7 @@ public: cable_segment(section_kind k, std::vector<value_type> r, std::vector<point_type> p): segment(k), radii_(std::move(r)), locations_(std::move(p)) { - assert(kind_==section_kind::dendrite || kind_==section_kind::axon); + EXPECTS(kind_==section_kind::dendrite || kind_==section_kind::axon); update_lengths(); } @@ -393,17 +446,4 @@ DivCompClass div_compartments(const cable_segment* cable) { return DivCompClass(cable->num_compartments(), cable->radii(), cable->lengths()); } -struct segment_location { - segment_location(cell_lid_type s, double l): - segment(s), position(l) - { - EXPECTS(position>=0. && position<=1.); - } - friend bool operator==(segment_location l, segment_location r) { - return l.segment==r.segment && l.position==r.position; - } - cell_lid_type segment; - double position; -}; - } // namespace arb diff --git a/src/simd/avx.hpp b/src/simd/avx.hpp index dfef7c0a6f0b2a2aa9826eaaeda29723edc5bd1c..ae3dd27e3ffa8dceed113a07e55c0c17e60ea257 100644 --- a/src/simd/avx.hpp +++ b/src/simd/avx.hpp @@ -37,6 +37,8 @@ struct simd_traits<avx_double4> { struct avx_int4: implbase<avx_int4> { // Use default implementations for: element, set_element, div. 
+ using implbase<avx_int4>::cast_from; + using int32 = std::int32_t; static __m128i broadcast(int32 v) { @@ -64,6 +66,14 @@ struct avx_int4: implbase<avx_int4> { return ifelse(mask, _mm_castps_si128(d), v); } + static __m128i cast_from(tag<avx_double4>, const __m256d& v) { + return _mm256_cvttpd_epi32(v); + } + + static int element0(const __m128i& a) { + return _mm_cvtsi128_si32(a); + } + static __m128i negate(const __m128i& a) { __m128i zero = _mm_setzero_si128(); return _mm_sub_epi32(zero, a); @@ -172,12 +182,23 @@ struct avx_int4: implbase<avx_int4> { static __m128i min(const __m128i& a, const __m128i& b) { return _mm_min_epi32(a, b); } + + static int reduce_add(const __m128i& a) { + // Add [a3|a2|a1|a0] to [a2|a3|a0|a1] + __m128i b = add(a, _mm_shuffle_epi32(a, 0xb1)); + // Add [b3|b2|b1|b0] to [b1|b0|b3|b2] + __m128i c = add(b, _mm_shuffle_epi32(b, 0x4e)); + + return element0(c); + } }; struct avx_double4: implbase<avx_double4> { // Use default implementations for: // element, set_element, fma. 
+ using implbase<avx_double4>::cast_from; + using int64 = std::int64_t; // CMPPD predicates: @@ -216,6 +237,14 @@ struct avx_double4: implbase<avx_double4> { return ifelse(mask, d, v); } + static __m256d cast_from(tag<avx_int4>, const __m128i& v) { + return _mm256_cvtepi32_pd(v); + } + + static double element0(const __m256d& a) { + return _mm_cvtsd_f64(_mm256_castpd256_pd128(a)); + } + static __m256d negate(const __m256d& a) { return _mm256_sub_pd(zero(), a); } @@ -348,6 +377,15 @@ struct avx_double4: implbase<avx_double4> { return _mm256_and_pd(x, _mm256_castsi256_pd(m)); } + static double reduce_add(const __m256d& a) { + // add [a3|a2|a1|a0] to [a1|a0|a3|a2] + __m256d b = add(a, _mm256_permute2f128_pd(a, a, 0x01)); + // add [b3|b2|b1|b0] to [b2|b3|b0|b1] + __m256d c = add(b, _mm256_permute_pd(b, 0x05)); + + return element0(c); + } + // Exponential is calculated as follows: // // e^x = e^g · 2^n, @@ -650,11 +688,17 @@ protected: #if defined(__AVX2__) && defined(__FMA__) -// Same implementation as for AVX. -using avx2_int4 = avx_int4; - +struct avx2_int4; struct avx2_double4; +template <> +struct simd_traits<avx2_int4> { + static constexpr unsigned width = 4; + using scalar_type = std::int32_t; + using vector_type = __m128i; + using mask_impl = avx_int4; +}; + template <> struct simd_traits<avx2_double4> { static constexpr unsigned width = 4; @@ -663,12 +707,34 @@ struct simd_traits<avx2_double4> { using mask_impl = avx2_double4; }; +// Note: we derive from avx_int4 only as an implementation shortcut. +// Because `avx2_int4` does not derive from `implbase<avx2_int4>`, +// any fallback methods in `implbase` will use the `avx_int4` +// functions rather than the `avx2_int4` functions. 
+ +struct avx2_int4: avx_int4 { + using implbase<avx_int4>::cast_from; + + // Need to provide a cast overload for avx2_double4 tag: + static __m128i cast_from(tag<avx2_double4>, const __m256d& v) { + return _mm256_cvttpd_epi32(v); + } +}; + // Note: we derive from avx_double4 only as an implementation shortcut. // Because `avx2_double4` does not derive from `implbase<avx2_double4>`, // any fallback methods in `implbase` will use the `avx_double4` // functions rather than the `avx2_double4` functions. struct avx2_double4: avx_double4 { + using implbase<avx_double4>::cast_from; + using implbase<avx_double4>::gather; + + // Need to provide a cast overload for avx2_int4 tag: + static __m256d cast_from(tag<avx2_int4>, const __m128i& v) { + return _mm256_cvtepi32_pd(v); + } + static __m256d fma(const __m256d& a, const __m256d& b, const __m256d& c) { return _mm256_fmadd_pd(a, b, c); } @@ -705,11 +771,11 @@ struct avx2_double4: avx_double4 { return _mm256_castsi256_pd(_mm256_sub_epi64(zero, _mm256_cvtepi8_epi64(r))); } - static __m256d gather(avx2_int4, const double* p, const __m128i& index) { + static __m256d gather(tag<avx2_int4>, const double* p, const __m128i& index) { return _mm256_i32gather_pd(p, index, 8); } - static __m256d gather(avx2_int4, __m256d a, const double* p, const __m128i& index, const __m256d& mask) { + static __m256d gather(tag<avx2_int4>, __m256d a, const double* p, const __m128i& index, const __m256d& mask) { return _mm256_mask_i32gather_pd(a, p, index, mask, 8); }; diff --git a/src/simd/avx512.hpp b/src/simd/avx512.hpp index 4b1a877eee82d3053e96bbd044d8d403ab3d0208..6f1236127c8ee0544990b4027608448e88dd3d97 100644 --- a/src/simd/avx512.hpp +++ b/src/simd/avx512.hpp @@ -44,6 +44,10 @@ struct simd_traits<avx512_int8> { }; struct avx512_mask8: implbase<avx512_mask8> { + using implbase<avx512_mask8>::gather; + using implbase<avx512_mask8>::scatter; + using implbase<avx512_mask8>::cast_from; + static __mmask8 broadcast(bool b) { return 
_mm512_int2mask(-b); } @@ -207,6 +211,10 @@ struct avx512_int8: implbase<avx512_int8> { // to __mmask8 seem to produce a lot of ultimately unnecessary // operations. + using implbase<avx512_int8>::gather; + using implbase<avx512_int8>::scatter; + using implbase<avx512_int8>::cast_from; + using int32 = std::int32_t; static __mmask8 lo() { @@ -237,6 +245,10 @@ struct avx512_int8: implbase<avx512_int8> { return _mm512_mask_loadu_epi32(v, mask, p); } + static int element0(const __m512i& a) { + return _mm_cvtsi128_si32(_mm512_castsi512_si128(a)); + } + static __m512i negate(const __m512i& a) { return sub(_mm512_setzero_epi32(), a); } @@ -312,6 +324,20 @@ struct avx512_int8: implbase<avx512_int8> { return _mm512_abs_epi32(a); } + static int reduce_add(const __m512i& a) { + // Add [...|a7|a6|a5|a4|a3|a2|a1|a0] to [...|a3|a2|a1|a0|a7|a6|a5|a4] + //__m512i b = add(a, _mm512_shuffle_i32x4(a, a, 0xb1)); + __m512i b = add(a, _mm512_shuffle_i32x4(a, a, _MM_PERM_CDAB)); // 0xb1 + // Add [...|b7|b6|b5|b4|b3|b2|b1|b0] to [...|b6|b7|b4|b5|b2|b3|b0|b1] + //__m512i c = add(b, _mm512_shuffle_epi32(b, 0xb1)); + __m512i c = add(b, _mm512_shuffle_epi32(b, _MM_PERM_CDAB)); // 0xb1 + // Add [...|c7|c6|c5|c4|c3|c2|c1|c0] to [...|c5|c4|c7|c6|c1|c0|c3|c2] + //__m512i d = add(c, _mm512_shuffle_epi32(c, 0x4e)); + __m512i d = add(c, _mm512_shuffle_epi32(c, _MM_PERM_BADC)); // 0x4e + + return element0(d); + } + // Generic 8-wide int solutions for gather and scatter. 
template <typename Impl> @@ -319,7 +345,7 @@ struct avx512_int8: implbase<avx512_int8> { template <typename ImplIndex, typename = typename std::enable_if<is_int8_simd<ImplIndex>::value>::type> - static __m512i gather(ImplIndex, const int32* p, const typename ImplIndex::vector_type& index) { + static __m512i gather(tag<ImplIndex>, const int32* p, const typename ImplIndex::vector_type& index) { int32 o[16]; ImplIndex::copy_to(index, o); auto op = reinterpret_cast<const __m512i*>(o); @@ -328,7 +354,7 @@ struct avx512_int8: implbase<avx512_int8> { template <typename ImplIndex, typename = typename std::enable_if<is_int8_simd<ImplIndex>::value>::type> - static __m512i gather(ImplIndex, __m512i a, const int32* p, const typename ImplIndex::vector_type& index, const __mmask8& mask) { + static __m512i gather(tag<ImplIndex>, const __m512i& a, const int32* p, const typename ImplIndex::vector_type& index, const __mmask8& mask) { int32 o[16]; ImplIndex::copy_to(index, o); auto op = reinterpret_cast<const __m512i*>(o); @@ -337,7 +363,7 @@ struct avx512_int8: implbase<avx512_int8> { template <typename ImplIndex, typename = typename std::enable_if<is_int8_simd<ImplIndex>::value>::type> - static void scatter(ImplIndex, const __m512i& s, int32* p, const typename ImplIndex::vector_type& index) { + static void scatter(tag<ImplIndex>, const __m512i& s, int32* p, const typename ImplIndex::vector_type& index) { int32 o[16]; ImplIndex::copy_to(index, o); auto op = reinterpret_cast<const __m512i*>(o); @@ -346,7 +372,7 @@ struct avx512_int8: implbase<avx512_int8> { template <typename ImplIndex, typename = typename std::enable_if<is_int8_simd<ImplIndex>::value>::type> - static void scatter(ImplIndex, const __m512i& s, int32* p, const typename ImplIndex::vector_type& index, const __mmask8& mask) { + static void scatter(tag<ImplIndex>, const __m512i& s, int32* p, const typename ImplIndex::vector_type& index, const __mmask8& mask) { int32 o[16]; ImplIndex::copy_to(index, o); auto op = 
reinterpret_cast<const __m512i*>(o); @@ -376,6 +402,10 @@ struct avx512_double8: implbase<avx512_double8> { // Use default implementations for: // element, set_element. + using implbase<avx512_double8>::gather; + using implbase<avx512_double8>::scatter; + using implbase<avx512_double8>::cast_from; + // CMPPD predicates: static constexpr int cmp_eq_oq = 0; static constexpr int cmp_unord_q = 3; @@ -411,6 +441,10 @@ struct avx512_double8: implbase<avx512_double8> { return _mm512_mask_loadu_pd(v, mask, p); } + static double element0(const __m512d& a) { + return _mm_cvtsd_f64(_mm512_castpd512_pd128(a)); + } + static __m512d negate(const __m512d& a) { return _mm512_sub_pd(_mm512_setzero_pd(), a); } @@ -477,13 +511,24 @@ struct avx512_double8: implbase<avx512_double8> { return _mm512_castsi512_pd(_mm512_and_epi64(_mm512_castpd_si512(x), m)); } + static double reduce_add(const __m512d& a) { + // add [a7|a6|a5|a4|a3|a2|a1|a0] to [a3|a2|a1|a0|a7|a6|a5|a4] + __m512d b = add(a, _mm512_shuffle_f64x2(a, a, 0x4e)); + // add [b7|b6|b5|b4|b3|b2|b1|b0] to [b5|b4|b7|b6|b1|b0|b3|b2] + __m512d c = add(b, _mm512_permutex_pd(b, 0x4e)); + // add [c7|c6|c5|c4|c3|c2|c1|c0] to [c6|c7|c4|c5|c2|c3|c0|c1] + __m512d d = add(c, _mm512_permute_pd(c, 0x55)); + + return element0(d); + } + // Generic 8-wide int solutions for gather and scatter. 
template <typename Impl> using is_int8_simd = std::integral_constant<bool, std::is_same<int, typename Impl::scalar_type>::value && Impl::width==8>; template <typename ImplIndex, typename = typename std::enable_if<is_int8_simd<ImplIndex>::value>::type> - static __m512d gather(ImplIndex, const double* p, const typename ImplIndex::vector_type& index) { + static __m512d gather(tag<ImplIndex>, const double* p, const typename ImplIndex::vector_type& index) { int o[8]; ImplIndex::copy_to(index, o); auto op = reinterpret_cast<const __m256i*>(o); @@ -491,7 +536,7 @@ struct avx512_double8: implbase<avx512_double8> { } template <typename ImplIndex, typename = typename std::enable_if<is_int8_simd<ImplIndex>::value>::type> - static __m512d gather(ImplIndex, __m512d a, const double* p, const typename ImplIndex::vector_type& index, const __mmask8& mask) { + static __m512d gather(tag<ImplIndex>, const __m512d& a, const double* p, const typename ImplIndex::vector_type& index, const __mmask8& mask) { int o[8]; ImplIndex::copy_to(index, o); auto op = reinterpret_cast<const __m256i*>(o); @@ -499,7 +544,7 @@ struct avx512_double8: implbase<avx512_double8> { } template <typename ImplIndex, typename = typename std::enable_if<is_int8_simd<ImplIndex>::value>::type> - static void scatter(ImplIndex, const __m512d& s, double* p, const typename ImplIndex::vector_type& index) { + static void scatter(tag<ImplIndex>, const __m512d& s, double* p, const typename ImplIndex::vector_type& index) { int o[8]; ImplIndex::copy_to(index, o); auto op = reinterpret_cast<const __m256i*>(o); @@ -507,7 +552,7 @@ struct avx512_double8: implbase<avx512_double8> { } template <typename ImplIndex, typename = typename std::enable_if<is_int8_simd<ImplIndex>::value>::type> - static void scatter(ImplIndex, const __m512d& s, double* p, const typename ImplIndex::vector_type& index, const __mmask8& mask) { + static void scatter(tag<ImplIndex>, const __m512d& s, double* p, const typename ImplIndex::vector_type& index, const 
__mmask8& mask) { int o[8]; ImplIndex::copy_to(index, o); auto op = reinterpret_cast<const __m256i*>(o); @@ -572,8 +617,7 @@ struct avx512_double8: implbase<avx512_double8> { auto gg = mul(g, g); // Compute the g*P(g^2) and Q(g^2). - - auto odd = mul(g, horner(gg, P0exp, P1exp, P2exp)); +auto odd = mul(g, horner(gg, P0exp, P1exp, P2exp)); auto even = horner(gg, Q0exp, Q1exp, Q2exp, Q3exp); // Compute R(g)/R(-g) = 1 + 2*g*P(g^2) / (Q(g^2)-g*P(g^2)) diff --git a/src/simd/implbase.hpp b/src/simd/implbase.hpp index f5b968f95fe834c67a2a8b87e1f47fb9cf041bc4..9253add19f6bee25ab13bbfffda803e90a608fab 100644 --- a/src/simd/implbase.hpp +++ b/src/simd/implbase.hpp @@ -51,6 +51,19 @@ namespace arb { namespace simd { + +// Constraints on possible index conflicts can be used to select a more +// efficient indexed update, gather or scatter. + +enum class index_constraint { + none = 0, + // For indices k[0], k[1],...: + + independent, // k[i]==k[j] => i=j. + contiguous, // k[i]==k[0]+i + constant // k[i]==k[j] ∀ i, j +}; + namespace simd_detail { // The simd_traits class provides the mapping between a concrete SIMD @@ -65,6 +78,13 @@ struct simd_traits { using mask_impl = void; }; +// The `tag` template is used to dispatch gather, scatter and cast_from +// operations that involve a (possibly) different SIMD implementation +// class. 
+ +template <typename I> +struct tag {}; + template <typename I> struct implbase { constexpr static unsigned width = simd_traits<I>::width; @@ -77,6 +97,31 @@ struct implbase { using store = scalar_type[width]; using mask_store = bool[width]; + template <typename ImplFrom> + static vector_type cast_from_(tag<ImplFrom>, const typename ImplFrom::vector_type& v, std::true_type) { + store a; + ImplFrom::copy_to(v, a); + return I::copy_from(a); + } + + template <typename ImplFrom> + static vector_type cast_from_(tag<ImplFrom>, const typename ImplFrom::vector_type& v, std::false_type) { + using other_scalar_type = typename simd_traits<ImplFrom>::scalar_type; + other_scalar_type b[width]; + ImplFrom::copy_to(v, b); + store a; + std::copy(b, b+width, a); + return I::copy_from(a); + } + + template < + typename ImplFrom, + typename other_scalar_type = typename simd_traits<ImplFrom>::scalar_type + > + static vector_type cast_from(tag<ImplFrom> tag, const typename ImplFrom::vector_type& v) { + return cast_from_(tag, v, typename std::is_same<scalar_type, other_scalar_type>::type{}); + } + static vector_type broadcast(scalar_type x) { store a; std::fill(std::begin(a), std::end(a), x); @@ -89,6 +134,10 @@ struct implbase { return a[i]; } + static scalar_type element0(const vector_type&v) { + return element(v, 0); + } + static void set_element(vector_type& v, int i, scalar_type x) { store a; I::copy_to(v, a); @@ -329,7 +378,7 @@ struct implbase { } template <typename ImplIndex> - static vector_type gather(ImplIndex, const scalar_type* p, const typename ImplIndex::vector_type& index) { + static vector_type gather(tag<ImplIndex>, const scalar_type* p, const typename ImplIndex::vector_type& index) { typename ImplIndex::scalar_type o[width]; ImplIndex::copy_to(index, o); @@ -341,7 +390,7 @@ struct implbase { } template <typename ImplIndex> - static vector_type gather(ImplIndex, const vector_type& s, const scalar_type* p, const typename ImplIndex::vector_type& index, const mask_type& 
mask) { + static vector_type gather(tag<ImplIndex>, const vector_type& s, const scalar_type* p, const typename ImplIndex::vector_type& index, const mask_type& mask) { mask_store m; mask_impl::mask_copy_to(mask, m); @@ -355,10 +404,10 @@ struct implbase { if (m[i]) { a[i] = p[o[i]]; } } return I::copy_from(a); - }; + } template <typename ImplIndex> - static void scatter(ImplIndex, const vector_type& s, scalar_type* p, const typename ImplIndex::vector_type& index) { + static void scatter(tag<ImplIndex>, const vector_type& s, scalar_type* p, const typename ImplIndex::vector_type& index) { typename ImplIndex::scalar_type o[width]; ImplIndex::copy_to(index, o); @@ -371,7 +420,7 @@ struct implbase { } template <typename ImplIndex> - static void scatter(ImplIndex, const vector_type& s, scalar_type* p, const typename ImplIndex::vector_type& index, const mask_type& mask) { + static void scatter(tag<ImplIndex>, const vector_type& s, scalar_type* p, const typename ImplIndex::vector_type& index, const mask_type& mask) { mask_store m; mask_impl::mask_copy_to(mask, m); @@ -386,6 +435,54 @@ struct implbase { } } + template <typename ImplIndex> + static void compound_indexed_add(tag<ImplIndex> tag, const vector_type& s, scalar_type* p, const typename ImplIndex::vector_type& index, index_constraint constraint) { + switch (constraint) { + case index_constraint::none: + { + typename ImplIndex::scalar_type o[width]; + ImplIndex::copy_to(index, o); + + store a; + I::copy_to(s, a); + + for (unsigned i = 0; i<width; ++i) { + p[o[i]] += a[i]; + } + } + break; + case index_constraint::independent: + { + vector_type v = I::add(I::gather(tag, p, index), s); + I::scatter(tag, v, p, index); + } + break; + case index_constraint::contiguous: + { + p += ImplIndex::element0(index); + vector_type v = I::add(I::copy_from(p), s); + I::copy_to(v, p); + } + break; + case index_constraint::constant: + { + p += ImplIndex::element0(index); + *p += I::reduce_add(s); + } + break; + } + } + + static 
scalar_type reduce_add(const vector_type& s) { + store a; + I::copy_to(s, a); + scalar_type r = a[0]; + for (unsigned i=1; i<width; ++i) { + r += a[i]; + } + return r; + } + // Maths static vector_type abs(const vector_type& u) { @@ -399,11 +496,11 @@ struct implbase { } static vector_type min(const vector_type& s, const vector_type& t) { - return ifelse(cmp_gt(t, s), s, t); + return I::ifelse(I::cmp_gt(t, s), s, t); } static vector_type max(const vector_type& s, const vector_type& t) { - return ifelse(cmp_gt(t, s), t, s); + return I::ifelse(I::cmp_gt(t, s), t, s); } static vector_type sin(const vector_type& s) { @@ -458,7 +555,7 @@ struct implbase { static vector_type exprelr(const vector_type& s) { vector_type ones = I::broadcast(1); - return ifelse(cmp_eq(ones, add(ones, s)), ones, div(s, expm1(s))); + return I::ifelse(I::cmp_eq(ones, I::add(ones, s)), ones, I::div(s, I::expm1(s))); } static vector_type pow(const vector_type& s, const vector_type &t) { diff --git a/src/simd/simd.hpp b/src/simd/simd.hpp index a7a1c65505c3c181e74d84fc0ea850b42a965070..17922040523f1d381885ec8c13f2b6b41d2d58fa 100644 --- a/src/simd/simd.hpp +++ b/src/simd/simd.hpp @@ -1,5 +1,7 @@ #pragma once +#include <array> +#include <cstddef> #include <type_traits> #include <simd/implbase.hpp> @@ -17,40 +19,41 @@ namespace simd_detail { struct simd_mask_impl; } -// Forward declarations for top-level maths functions. -// (these require access to private simd_impl<Impl>::wrap method). 
- -template <typename Impl> -simd_detail::simd_impl<Impl> abs(const simd_detail::simd_impl<Impl>&); - -template <typename Impl> -simd_detail::simd_impl<Impl> sin(const simd_detail::simd_impl<Impl>&); - -template <typename Impl> -simd_detail::simd_impl<Impl> cos(const simd_detail::simd_impl<Impl>&); - -template <typename Impl> -simd_detail::simd_impl<Impl> exp(const simd_detail::simd_impl<Impl>&); +namespace simd_detail { + template <typename Impl, typename V> + struct indirect_expression { + V* p; + typename simd_traits<Impl>::vector_type index; + index_constraint constraint; -template <typename Impl> -simd_detail::simd_impl<Impl> log(const simd_detail::simd_impl<Impl>&); + indirect_expression() = default; + indirect_expression(V* p, const simd_impl<Impl>& index_simd, index_constraint constraint): + p(p), index(index_simd.value_), constraint(constraint) + {} -template <typename Impl> -simd_detail::simd_impl<Impl> expm1(const simd_detail::simd_impl<Impl>&); + // Simple assignment included for consistency with compound assignment interface. 
-template <typename Impl> -simd_detail::simd_impl<Impl> exprelr(const simd_detail::simd_impl<Impl>&); + template <typename Other> + indirect_expression& operator=(const simd_impl<Other>& s) { + s.copy_to(*this); + return *this; + } -template <typename Impl> -simd_detail::simd_impl<Impl> pow(const simd_detail::simd_impl<Impl>&, const simd_detail::simd_impl<Impl>&); + // Compound assignment (currently only addition and subtraction!): -template <typename Impl> -simd_detail::simd_impl<Impl> min(const simd_detail::simd_impl<Impl>&, const simd_detail::simd_impl<Impl>&); + template <typename Other> + indirect_expression& operator+=(const simd_impl<Other>& s) { + Other::compound_indexed_add(tag<Impl>{}, s.value_, p, index, constraint); + return *this; + } -template <typename Impl> -simd_detail::simd_impl<Impl> max(const simd_detail::simd_impl<Impl>&, const simd_detail::simd_impl<Impl>&); + template <typename Other> + indirect_expression& operator-=(const simd_impl<Other>& s) { + Other::compound_indexed_add(tag<Impl>{}, (-s).value_, p, index, constraint); + return *this; + } + }; -namespace simd_detail { template <typename Impl> struct simd_impl { // Type aliases: @@ -72,7 +75,10 @@ namespace simd_detail { static constexpr unsigned width = simd_traits<Impl>::width; template <typename Other> - friend class simd_impl; + friend struct simd_impl; + + template <typename Other, typename V> + friend struct indirect_expression; simd_impl() = default; @@ -101,6 +107,23 @@ namespace simd_detail { value_ = Impl::copy_from_masked(a, m.value_); } + // Construct from a different SIMD value by casting. + template <typename Other, typename = typename std::enable_if<width==simd_traits<Other>::width>::type> + explicit simd_impl(const simd_impl<Other>& x) { + value_ = Impl::cast_from(tag<Other>{}, x.value_); + } + + // Construct from indirect expression (gather). 
+ template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> + explicit simd_impl(indirect_expression<IndexImpl, scalar_type> pi) { + copy_from(pi); + } + + template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> + explicit simd_impl(indirect_expression<IndexImpl, const scalar_type> pi) { + copy_from(pi); + } + // Copy constructor. simd_impl(const simd_impl& other) { std::memcpy(&value_, &other.value_, sizeof(vector_type)); @@ -124,10 +147,25 @@ namespace simd_detail { Impl::copy_to(value_, p); } + template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> + void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { + Impl::scatter(tag<IndexImpl>{}, value_, pi.p, pi.index); + } + void copy_from(const scalar_type* p) { value_ = Impl::copy_from(p); } + template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> + void copy_from(indirect_expression<IndexImpl, scalar_type> pi) { + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + } + + template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> + void copy_from(indirect_expression<IndexImpl, const scalar_type> pi) { + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + } + // Arithmetic operations: +, -, *, /, fma. simd_impl operator-() const { @@ -202,18 +240,6 @@ namespace simd_detail { return *this; } - // Gather and scatter. 
- - template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> - void gather(const scalar_type* p, const simd_impl<IndexImpl>& index) { - value_ = Impl::gather(IndexImpl{}, p, index.value_); - } - - template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> - void scatter(scalar_type* p, const simd_impl<IndexImpl>& index) { - Impl::scatter(IndexImpl{}, value_, p, index.value_); - } - // Array subscript operations. struct reference { @@ -245,6 +271,12 @@ namespace simd_detail { return Impl::element(value_, i); } + // Reductions (horizontal operations). + + scalar_type sum() const { + return Impl::reduce_add(value_); + } + // Masked assignment (via where expressions). struct where_expression { @@ -272,14 +304,16 @@ namespace simd_detail { data_.value_ = Impl::copy_from_masked(data_.value_, p, mask_.value_); } + // Gather and scatter. + template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> - void gather(const scalar_type* p, const simd_impl<IndexImpl>& index) { - data_.value_ = Impl::gather(IndexImpl{}, data_.value_, p, index.value_, mask_.value_); + void copy_from(indirect_expression<IndexImpl, scalar_type> pi) { + data_.value_ = Impl::gather(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); } template <typename IndexImpl, typename = typename std::enable_if<width==simd_traits<IndexImpl>::width>::type> - void scatter(scalar_type* p, const simd_impl<IndexImpl>& index) { - Impl::scatter(IndexImpl{}, data_.value_, p, index.value_, mask_.value_); + void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { + Impl::scatter(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); } private: @@ -287,19 +321,49 @@ namespace simd_detail { simd_impl& data_; }; - // Maths functions are implemented as top-level functions, but require - // access to `wrap`. 
+ // Maths functions are implemented as top-level functions; declare as friends + // for access to `wrap` and to enjoy ADL, allowing implicit conversion from + // scalar_type in binary operation arguments. - friend simd_impl abs<Impl>(const simd_impl&); - friend simd_impl sin<Impl>(const simd_impl&); - friend simd_impl cos<Impl>(const simd_impl&); - friend simd_impl exp<Impl>(const simd_impl&); - friend simd_impl log<Impl>(const simd_impl&); - friend simd_impl expm1<Impl>(const simd_impl&); - friend simd_impl exprelr<Impl>(const simd_impl&); - friend simd_impl min<Impl>(const simd_impl&, const simd_impl&); - friend simd_impl max<Impl>(const simd_impl&, const simd_impl&); - friend simd_impl pow<Impl>(const simd_impl&, const simd_impl&); + friend simd_impl abs(const simd_impl& s) { + return simd_impl::wrap(Impl::abs(s.value_)); + } + + friend simd_impl sin(const simd_impl& s) { + return simd_impl::wrap(Impl::sin(s.value_)); + } + + friend simd_impl cos(const simd_impl& s) { + return simd_impl::wrap(Impl::cos(s.value_)); + } + + friend simd_impl exp(const simd_impl& s) { + return simd_impl::wrap(Impl::exp(s.value_)); + } + + friend simd_impl log(const simd_impl& s) { + return simd_impl::wrap(Impl::log(s.value_)); + } + + friend simd_impl expm1(const simd_impl& s) { + return simd_impl::wrap(Impl::expm1(s.value_)); + } + + friend simd_impl exprelr(const simd_impl& s) { + return simd_impl::wrap(Impl::exprelr(s.value_)); + } + + friend simd_impl pow(const simd_impl& s, const simd_impl& t) { + return simd_impl::wrap(Impl::pow(s.value_, t.value_)); + } + + friend simd_impl min(const simd_impl& s, const simd_impl& t) { + return simd_impl::wrap(Impl::min(s.value_, t.value_)); + } + + friend simd_impl max(const simd_impl& s, const simd_impl& t) { + return simd_impl::wrap(Impl::max(s.value_, t.value_)); + } protected: vector_type value_; @@ -422,7 +486,7 @@ namespace simd_detail { private: simd_mask_impl(const vector_type& v): base(v) {} - template <class> friend class 
simd_impl; + template <class> friend struct simd_impl; static simd_mask_impl wrap(const vector_type& v) { simd_mask_impl m; @@ -430,6 +494,40 @@ namespace simd_detail { return m; } }; + + template <typename To> + struct simd_cast_impl {}; + + template <typename ImplTo> + struct simd_cast_impl<simd_impl<ImplTo>> { + static constexpr unsigned N = simd_traits<ImplTo>::width; + using scalar_type = typename simd_traits<ImplTo>::scalar_type; + + template <typename ImplFrom, typename = typename std::enable_if<N==simd_traits<ImplFrom>::width>::type> + static simd_impl<ImplTo> cast(const simd_impl<ImplFrom>& v) { + return simd_impl<ImplTo>(v); + } + + static simd_impl<ImplTo> cast(const std::array<scalar_type, N>& a) { + return simd_impl<ImplTo>(a.data()); + } + }; + + template <typename V, std::size_t N> + struct simd_cast_impl<std::array<V, N>> { + template < + typename ImplFrom, + typename = typename std::enable_if< + N==simd_traits<ImplFrom>::width && + std::is_same<V, typename simd_traits<ImplFrom>::scalar_type>::value + >::type + > + static std::array<V, N> cast(const simd_impl<ImplFrom>& s) { + std::array<V, N> a; + s.copy_to(a.data()); + return a; + } + }; } // namespace simd_detail namespace simd_abi { @@ -465,57 +563,29 @@ struct is_simd: std::false_type {}; template <typename Impl> struct is_simd<simd_detail::simd_impl<Impl>>: std::true_type {}; -// Top-level maths functions: forward to underlying Impl. 
- -template <typename Impl> -simd_detail::simd_impl<Impl> abs(const simd_detail::simd_impl<Impl>& s) { - return simd_detail::simd_impl<Impl>::wrap(Impl::abs(s.value_)); -} - -template <typename Impl> -simd_detail::simd_impl<Impl> sin(const simd_detail::simd_impl<Impl>& s) { - return simd_detail::simd_impl<Impl>::wrap(Impl::sin(s.value_)); -} - -template <typename Impl> -simd_detail::simd_impl<Impl> cos(const simd_detail::simd_impl<Impl>& s) { - return simd_detail::simd_impl<Impl>::wrap(Impl::cos(s.value_)); -} +// Casting is dispatched to simd_cast_impl in order to handle conversions to +// and from std::array. -template <typename Impl> -simd_detail::simd_impl<Impl> exp(const simd_detail::simd_impl<Impl>& s) { - return simd_detail::simd_impl<Impl>::wrap(Impl::exp(s.value_)); +template <typename To, typename From> +To simd_cast(const From& s) { + return simd_detail::simd_cast_impl<To>::cast(s); } -template <typename Impl> -simd_detail::simd_impl<Impl> log(const simd_detail::simd_impl<Impl>& s) { - return simd_detail::simd_impl<Impl>::wrap(Impl::log(s.value_)); +// Gather/scatter indexed memory specification. 
+ +template < + typename IndexImpl, + typename PtrLike, + typename V = typename std::remove_reference<decltype(*std::declval<PtrLike>())>::type +> +simd_detail::indirect_expression<IndexImpl, V> indirect( + PtrLike p, + const simd_detail::simd_impl<IndexImpl>& index, + index_constraint constraint = index_constraint::none) +{ + return simd_detail::indirect_expression<IndexImpl, V>(p, index, constraint); } -template <typename Impl> -simd_detail::simd_impl<Impl> expm1(const simd_detail::simd_impl<Impl>& s) { - return simd_detail::simd_impl<Impl>::wrap(Impl::expm1(s.value_)); -} - -template <typename Impl> -simd_detail::simd_impl<Impl> exprelr(const simd_detail::simd_impl<Impl>& s) { - return simd_detail::simd_impl<Impl>::wrap(Impl::exprelr(s.value_)); -} - -template <typename Impl> -simd_detail::simd_impl<Impl> pow(const simd_detail::simd_impl<Impl>& s, const simd_detail::simd_impl<Impl>& t) { - return simd_detail::simd_impl<Impl>::wrap(Impl::pow(s.value_, t.value_)); -} - -template <typename Impl> -simd_detail::simd_impl<Impl> min(const simd_detail::simd_impl<Impl>& s, const simd_detail::simd_impl<Impl>& t) { - return simd_detail::simd_impl<Impl>::wrap(Impl::min(s.value_, t.value_)); -} - -template <typename Impl> -simd_detail::simd_impl<Impl> max(const simd_detail::simd_impl<Impl>& s, const simd_detail::simd_impl<Impl>& t) { - return simd_detail::simd_impl<Impl>::wrap(Impl::max(s.value_, t.value_)); -} } // namespace simd } // namespace arb diff --git a/src/simd/simd_io.hpp b/src/simd/simd_io.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3cc4c84149c8d2f77ea19d94c6190aca038f6389 --- /dev/null +++ b/src/simd/simd_io.hpp @@ -0,0 +1,28 @@ +#pragma once + +// Overloads for iostream formatted output for SIMD value classes. 
+ +#include <iostream> + +#include <simd/simd.hpp> + +namespace arb { +namespace simd { +namespace simd_detail { + +template <typename Impl> +std::ostream& operator<<(std::ostream& o, const simd_impl<Impl>& s) { + using Simd = simd_impl<Impl>; + + typename Simd::scalar_type data[Simd::width]; + s.copy_to(data); + o << data[0]; + for (unsigned i = 1; i<Simd::width; ++i) { + o << ' ' << data[i]; + } + return o; +} + +} // namespace simd_detail +} // namespace simd +} // namespace arb diff --git a/src/stimulus.hpp b/src/stimulus.hpp deleted file mode 100644 index 071966c6b60c7b03b2a6cbda6953fea9824ca4b0..0000000000000000000000000000000000000000 --- a/src/stimulus.hpp +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -namespace arb { - -class i_clamp { - public: - - using value_type = double; - - i_clamp(value_type del, value_type dur, value_type amp) - : delay_(del), - duration_(dur), - amplitude_(amp) - {} - - value_type delay() const { - return delay_; - } - value_type duration() const { - return duration_; - } - value_type amplitude() const { - return amplitude_; - } - - void set_delay(value_type d) { - delay_ = d; - } - void set_duration(value_type d) { - duration_ = d; - } - void set_amplitude(value_type a) { - amplitude_ = a; - } - - // current is set to amplitude for time in the half open interval: - // t \in [delay, delay+duration) - value_type amplitude(double t) { - if(t>=delay_ && t<(delay_+duration_)) { - return amplitude_; - } - return 0; - } - - private: - - value_type delay_ = 0; // [ms] - value_type duration_ = 0; // [ms] - value_type amplitude_ = 0; // [nA] -}; - -} // namespace arb diff --git a/src/tree.hpp b/src/tree.hpp index 96f90043bc81d57afd503fc1244927233770957f..75a9cac24a742398a82c19dff57959e0980302f3 100644 --- a/src/tree.hpp +++ b/src/tree.hpp @@ -289,7 +289,7 @@ std::vector<tree::int_type> make_parent_index(tree const& t, C const& counts) using int_type = tree::int_type; constexpr auto no_parent = tree::no_parent; - if 
(!algorithms::is_positive(counts) || counts.size() != t.num_segments()) { + if (!algorithms::all_positive(counts) || counts.size() != t.num_segments()) { throw std::domain_error( "make_parent_index requires one non-zero count per segment" ); diff --git a/src/util/compat.hpp b/src/util/compat.hpp index 977bce1205489d5ecd258119e106b98b326121c3..7784f04724296a54a62cb57d01d615c86655becb 100644 --- a/src/util/compat.hpp +++ b/src/util/compat.hpp @@ -27,8 +27,14 @@ constexpr bool using_gnu_compiler(int major=0, int minor=0, int patchlevel=0) { // std::end() broken with (at least) xlC 13.1.4. +namespace impl { + using std::end; + template <typename T> + auto end_(T& x) -> decltype(end(x)) { return end(x); } +} + template <typename T> -auto end(T& x) -> decltype(x.end()) { return x.end(); } +auto end(T& x) -> decltype(impl::end_(x)) { return impl::end_(x); } template <typename T, std::size_t N> T* end(T (&x)[N]) { return &x[0]+N; } diff --git a/src/util/index_into.hpp b/src/util/index_into.hpp new file mode 100644 index 0000000000000000000000000000000000000000..41ed48ea05656b6c98ac322ca16fdc99941858ab --- /dev/null +++ b/src/util/index_into.hpp @@ -0,0 +1,154 @@ +#pragma once + +// Iterator and range that represents the indices within a super-sequence +// of elements in a sub-sequence. +// +// It is a prerequisite that the elements of the sub-sequence do indeed +// exist within the super-sequence with the same order. +// +// Example: +// +// Given sequence S = { 1, 3, 5, 5, 2 } +// and T = { 0, 5, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 2, 8 }, +// then index_into(S, T) would present the indices +// { 3, 4, 8, 8, 13 }. 
+ +#include <iterator> +#include <type_traits> + +#include <util/compat.hpp> +#include <util/debug.hpp> +#include <util/meta.hpp> +#include <util/range.hpp> + +namespace arb { +namespace util { + +template <typename Sub, typename Sup, typename SupEnd> +struct index_into_iterator { + using value_type = typename std::iterator_traits<Sup>::difference_type; + using difference_type = value_type; + using pointer = const value_type*; + using reference = const value_type&; + using iterator_category = typename + std::conditional< + std::is_same<Sup, SupEnd>::value + && is_bidirectional_iterator_t<Sup>::value + && is_bidirectional_iterator_t<Sub>::value, + std::bidirectional_iterator_tag, + std::forward_iterator_tag + >::type; + + index_into_iterator(const Sub& sub, const Sub& sub_end, const Sup& sup, const SupEnd& sup_end): + sub(sub), sub_end(sub_end), sup(sup), sup_end(sup_end), idx(0) + { + align_fwd(); + } + + value_type operator*() const { + return idx; + } + + index_into_iterator& operator++() { + EXPECTS(sup!=sup_end); + + ++sub; + align_fwd(); + return *this; + } + + index_into_iterator operator++(int) { + auto keep = *this; + ++(*this); + return keep; + } + + index_into_iterator& operator--() { + if (sub==sub_end) { + // decrementing one-past-the-end iterator: sup may sit anywhere in the super-sequence here, so advance idx by the remaining distance rather than overwriting it + idx += std::distance(sup, sup_end)-1; + sup = std::prev(sup_end); + } + + --sub; + align_rev(); + return *this; + } + + index_into_iterator operator--(int) { + auto keep = *this; + --(*this); + return keep; + } + + template <typename A, typename C, typename D> + friend struct index_into_iterator; + + template <typename OSub> + bool operator==(const index_into_iterator<OSub, Sup, SupEnd>& other) const { + return sub==other.sub; + } + + template <typename OSub> + bool operator!=(const index_into_iterator<OSub, Sup, SupEnd>& other) const { + return !(*this==other); + } + +private: + Sub sub; + Sub sub_end; + Sup sup; + SupEnd sup_end; + difference_type idx; + + void align_fwd() { + if (sub!=sub_end) { + while 
(sup!=sup_end && !(*sub==*sup)) { + ++idx; + ++sup; + } + } + } + + void align_rev() { + while (idx>0 && !(*sub==*sup)) { + --idx; + --sup; + } + + EXPECTS(*sub==*sup); + } +}; + +template < + typename Sub, + typename Super, + typename Canon = decltype(canonical_view(std::declval<Sub>())) +> +auto index_into(const Sub& sub, const Super& sup) + -> range< + index_into_iterator< + typename sequence_traits<Canon>::const_iterator, + typename sequence_traits<Super>::const_iterator, + typename sequence_traits<Super>::const_sentinel + > + > +{ + using iterator = + index_into_iterator< + typename sequence_traits<Canon>::const_iterator, + typename sequence_traits<Super>::const_iterator, + typename sequence_traits<Super>::const_sentinel + >; + + using std::begin; + + auto canon = canonical_view(sub); + iterator b(begin(canon), compat::end(canon), begin(sup), compat::end(sup)); + iterator e(compat::end(canon), compat::end(canon), begin(sup), compat::end(sup)); + + return range<iterator>(b, e); +} + +} // namespace util +} // namespace arb diff --git a/src/util/maputil.hpp b/src/util/maputil.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f6dd91b7aff0fbdf14e1d7e35caadf3e661303ba --- /dev/null +++ b/src/util/maputil.hpp @@ -0,0 +1,138 @@ +#pragma once + +#include <algorithm> +#include <cstring> +#include <iterator> +#include <utility> +#include <type_traits> + +#include <util/deduce_return.hpp> +#include <util/meta.hpp> +#include <util/optional.hpp> +#include <util/transform.hpp> + +// Convenience views, algorithms for maps and map-like containers. + +namespace arb { +namespace util { + +// View over the keys (first elements) in a sequence of pairs or tuples. + +template <typename Seq> +auto keys(Seq&& m) DEDUCED_RETURN_TYPE(util::transform_view(std::forward<Seq>(m), util::first)) + +// Is a container/sequence a map? 
+ +namespace impl { + template < + typename C, + typename seq_value = typename sequence_traits<C>::value_type, + typename K = typename std::tuple_element<0, seq_value>::type, + typename V = typename std::tuple_element<1, seq_value>::type, + typename find_value = decay_t<decltype(*std::declval<C>().find(std::declval<K>()))> + > + struct assoc_test: std::integral_constant<bool, std::is_same<seq_value, find_value>::value> {}; +} + +template <typename Seq, typename = void> +struct is_associative_container: std::false_type {}; + +template <typename Seq> +struct is_associative_container<Seq, void_t<impl::assoc_test<Seq>>>: impl::assoc_test<Seq> {}; + +// Find value in a sequence of key-value pairs or in a key-value association map, with +// optional explicit comparator. +// +// If no comparator is given, and the container is associative, use the `find` method, otherwise +// perform a linear search. +// +// Returns optional<value> or optional<value&>. A reference optional is returned if: +// 1. the sequence is an lvalue reference, and +// 2. if the deduced return type from calling `get` on an entry from the sequence is an lvalue reference. + +namespace impl { + // import std::get for ADL below. + using std::get; + + // TODO: C++14 use std::equal_to<void> for this. 
+ struct generic_equal_to { + template <typename A, typename B> + bool operator()(A&& a, B&& b) { + return std::forward<A>(a)==std::forward<B>(b); + } + }; + + // use linear search + template < + typename Seq, + typename Key, + typename Eq = generic_equal_to, + typename Ret0 = decltype(get<1>(*std::begin(std::declval<Seq&&>()))), + typename Ret = typename std::conditional< + std::is_rvalue_reference<Seq&&>::value || !std::is_lvalue_reference<Ret0>::value, + typename std::remove_reference<Ret0>::type, + Ret0 + >::type + > + optional<Ret> value_by_key(std::false_type, Seq&& seq, const Key& key, Eq eq=Eq{}) { + for (auto&& entry: seq) { + if (eq(get<0>(entry), key)) { + return get<1>(entry); + } + } + return nullopt; + } + + // use map find + template < + typename Assoc, + typename Key, + typename FindRet = decltype(std::declval<Assoc&&>().find(std::declval<Key>())), + typename Ret0 = decltype(get<1>(*std::declval<FindRet>())), + typename Ret = typename std::conditional< + std::is_rvalue_reference<Assoc&&>::value || !std::is_lvalue_reference<Ret0>::value, + typename std::remove_reference<Ret0>::type, + Ret0 + >::type + > + optional<Ret> value_by_key(std::true_type, Assoc&& map, const Key& key) { + auto it = map.find(key); + if (it!=std::end(map)) { + return get<1>(*it); + } + return nullopt; + } +} + +template <typename C, typename Key, typename Eq> +auto value_by_key(C&& c, const Key& k, Eq eq) + DEDUCED_RETURN_TYPE(impl::value_by_key(std::false_type{}, std::forward<C>(c), k, eq)) + +template <typename C, typename Key> +auto value_by_key(C&& c, const Key& k) + DEDUCED_RETURN_TYPE( + impl::value_by_key( + std::integral_constant<bool, is_associative_container<C>::value>{}, + std::forward<C>(c), k)) + +// Find the index into an ordered sequence of a value by binary search; +// returns optional<difference_type> for the difference_type associated with the sequence. +// (Note: this is pretty much all we use algorithms::binary_find for.) 
+ +template <typename C, typename Key> +optional<typename sequence_traits<C>::difference_type> binary_search_index(const C& c, const Key& key) { + auto strict = strict_view(c); + auto it = std::lower_bound(strict.begin(), strict.end(), key); + return it!=strict.end() && key==*it? util::just(std::distance(strict.begin(), it)): util::nullopt; +} + +// Key equality helper for NUL-terminated strings. + +struct cstr_equal { + bool operator()(const char* u, const char* v) { + return !std::strcmp(u, v); + } +}; + +} // namespace util +} // namespace arb diff --git a/src/util/meta.hpp b/src/util/meta.hpp index 18820aa91e31a35d3803600df75f43230c5dc704..e4ed6bbccfe72088791e11c954d39066af44ed68 100644 --- a/src/util/meta.hpp +++ b/src/util/meta.hpp @@ -7,11 +7,19 @@ #include <type_traits> #include <util/compat.hpp> +#include <util/deduce_return.hpp> namespace arb { namespace util { -// Until C++14 ... +// The following classes and functions can be replaced +// with std functions when we migrate to later versions of C++. +// +// C++14: +// result_of_t, enable_if_t, decay_t, size, cbegin, cend. 
+// +// C++17: +// void_t, empty, data, as_const template <typename T> using result_of_t = typename std::result_of<T>::type; @@ -26,26 +34,47 @@ template <typename T> using decay_t = typename std::decay<T>::type; template <typename X> -std::size_t size(const X& x) { return x.size(); } +constexpr std::size_t size(const X& x) { return x.size(); } template <typename X, std::size_t N> -constexpr std::size_t size(X (&)[N]) { return N; } +constexpr std::size_t size(X (&)[N]) noexcept { return N; } + +template <typename C> +constexpr auto data(C& c) -> decltype(c.data()) { return c.data(); } + +template <typename C> +constexpr auto data(const C& c) -> decltype(c.data()) { return c.data(); } + +template <typename T, std::size_t N> +constexpr T* data(T (&a)[N]) noexcept { return a; } template <typename T> -constexpr auto cbegin(const T& c) -> decltype(std::begin(c)) { - return std::begin(c); -} +void as_const(T&& t) = delete; template <typename T> -constexpr auto cend(const T& c) -> decltype(compat::end(c)) { - // COMPAT: use own `end` implementation to work around xlC 13.1 bug. - return compat::end(c); +constexpr typename std::add_const<T>::type& as_const(T& t) { + return t; +} + +// Wrap cbegin in inner namespace in order to properly invoke ADL. + +namespace impl { + using std::begin; + + template <typename T> + constexpr auto cbegin_(const T& c) DEDUCED_RETURN_TYPE(begin(c)) } +template <typename T> +constexpr auto cbegin(const T& c) DEDUCED_RETURN_TYPE(impl::cbegin_(c)) + +template <typename T> +constexpr auto cend(const T& c) DEDUCED_RETURN_TYPE(compat::end(c)) + // Use sequence `empty() const` method if exists, otherwise // compare begin and end. 
-namespace impl { +namespace impl_empty { template <typename C> struct has_const_empty_method { template <typename T> @@ -56,13 +85,11 @@ namespace impl { using type = decltype(test<C>(0)); }; - // For correct ADL on begin and end: using std::begin; - using std::end; template <typename Seq> constexpr bool empty(const Seq& seq, std::false_type) { - return begin(seq)==end(seq); + return begin(seq)==compat::end(seq); } template <typename Seq> @@ -73,7 +100,7 @@ namespace impl { template <typename Seq> constexpr bool empty(const Seq& seq) { - return impl::empty(seq, typename impl::has_const_empty_method<Seq>::type{}); + return impl_empty::empty(seq, typename impl_empty::has_const_empty_method<Seq>::type{}); } template <typename T, std::size_t N> @@ -83,18 +110,34 @@ constexpr bool empty(const T (& c)[N]) noexcept { // Types associated with a container or sequence +namespace impl_seqtrait { + using std::begin; + + template <typename Seq, typename = void> + struct data_returns_pointer: std::false_type {}; + + template <typename T> + struct data_returns_pointer<T, void_t<decltype(util::data(std::declval<T>()))>>: + public std::is_pointer<decltype(util::data(std::declval<T>()))>::type {}; + + template <typename Seq> + struct sequence_traits { + using iterator = decltype(begin(std::declval<Seq&>())); + using const_iterator = decltype(begin(std::declval<const Seq&>())); + using value_type = typename std::iterator_traits<iterator>::value_type; + using reference = typename std::iterator_traits<iterator>::reference; + using difference_type = typename std::iterator_traits<iterator>::difference_type; + using size_type = decltype(size(std::declval<Seq&>())); + // For use with heterogeneous ranges: + using sentinel = decltype(compat::end(std::declval<Seq&>())); + using const_sentinel = decltype(compat::end(std::declval<const Seq&>())); + + static constexpr bool is_contiguous = data_returns_pointer<Seq>::value; + }; +} + template <typename Seq> -struct sequence_traits { - using 
iterator = decltype(std::begin(std::declval<Seq&>())); - using const_iterator = decltype(util::cbegin(std::declval<Seq&>())); - using value_type = typename std::iterator_traits<iterator>::value_type; - using reference = typename std::iterator_traits<iterator>::reference; - using difference_type = typename std::iterator_traits<iterator>::difference_type; - using size_type = decltype(size(std::declval<Seq&>())); - // for use with heterogeneous ranges - using sentinel = decltype(std::end(std::declval<Seq&>())); - using const_sentinel = decltype(util::cend(std::declval<Seq&>())); -}; +using sequence_traits = impl_seqtrait::sequence_traits<Seq>; // Convenience short cuts for `enable_if` @@ -137,7 +180,7 @@ struct is_iterator<T, void_t<typename std::iterator_traits<T>::iterator_category public std::true_type {}; template <typename T> -using is_iterator_t = typename is_iterator<T>::type; +using is_iterator_t = typename util::is_iterator<T>::type; // Random access iterator test @@ -152,7 +195,7 @@ struct is_random_access_iterator<T, enable_if_t< >> : public std::true_type {}; template <typename T> -using is_random_access_iterator_t = typename is_random_access_iterator<T>::type; +using is_random_access_iterator_t = typename util::is_random_access_iterator<T>::type; // Bidirectional iterator test @@ -171,7 +214,7 @@ struct is_bidirectional_iterator<T, enable_if_t< >> : public std::true_type {}; template <typename T> -using is_bidirectional_iterator_t = typename is_bidirectional_iterator<T>::type; +using is_bidirectional_iterator_t = typename util::is_bidirectional_iterator<T>::type; // Forward iterator test @@ -194,7 +237,7 @@ struct is_forward_iterator<T, enable_if_t< >> : public std::true_type {}; template <typename T> -using is_forward_iterator_t = typename is_forward_iterator<T>::type; +using is_forward_iterator_t = typename util::is_forward_iterator<T>::type; template <typename I, typename E, typename = void, typename = void> @@ -204,33 +247,27 @@ template <typename I, 
typename E> struct common_random_access_iterator< I, E, - void_t<decltype(false ? std::declval<I>() : std::declval<E>())>, - enable_if_t< + void_t<decltype(false? std::declval<I>(): std::declval<E>())>, + util::enable_if_t< is_random_access_iterator< - decay_t<decltype(false ? std::declval<I>() : std::declval<E>())> + decay_t<decltype(false? std::declval<I>(): std::declval<E>())> >::value > > { - using type = decay_t< + using type = util::decay_t< decltype(false ? std::declval<I>() : std::declval<E>()) >; }; template <typename I, typename E> -using common_random_access_iterator_t = typename common_random_access_iterator<I, E>::type; - -namespace impl { - /// Helper for SFINAE tests that can "sink" any type - template<typename T> - using sink = void; -} +using common_random_access_iterator_t = typename util::common_random_access_iterator<I, E>::type; template <typename I, typename E, typename V=void> struct has_common_random_access_iterator: std::false_type {}; template <typename I, typename E> -struct has_common_random_access_iterator<I, E, impl::sink<typename common_random_access_iterator<I, E>::type>>: +struct has_common_random_access_iterator<I, E, void_t<util::common_random_access_iterator_t<I, E>>>: std::true_type {}; template<typename T, typename V=void> @@ -238,12 +275,11 @@ struct is_sequence: std::false_type {}; template<typename T> -struct is_sequence<T, impl::sink<decltype(cbegin(std::declval<T>()))>>: +struct is_sequence<T, void_t<decltype(std::begin(std::declval<T>()))>>: std::true_type {}; template <typename T> -using enable_if_sequence_t = - enable_if_t<is_sequence<T>::value>; +using enable_if_sequence_t = util::enable_if_t<util::is_sequence<T>::value>; // No generic lambdas in C++11, so some convenience accessors for pairs that // are type-generic diff --git a/src/util/padded_alloc.hpp b/src/util/padded_alloc.hpp index 679e918e152cea168e40b8b99f89f723166c114f..9dfe7ea38a03c623328cdf5c84355ad6af7215ea 100644 --- a/src/util/padded_alloc.hpp +++ 
b/src/util/padded_alloc.hpp @@ -15,31 +15,27 @@ // // Any alignment `n` specified must be a power of two. // -// Assignment does not change the alignment property of the -// allocator on the left hand side of the assignment, so that +// Assignment operations propagate the alignment/padding, so that // e.g. // ``` // std::vector<int, padded_allocator<int>> a(100, 32), b(50, 64); // a = b; -// assert(a.get_allocator().alignment()==32); +// assert(a.get_allocator().alignment()==64); // ``` -// will pass, and the vector `a` will not require reallocation. -// -// For move assignment, this means we cannot allow a simple ownership -// transfer if the left hand side has a stronger alignment guarantee -// that the right hand side. Correspondingly, we have to return `false` +// will pass, and the vector `a` will require reallocation. +// Correspondingly, we have to return `false` // for the allocator equality test if the alignments differ. namespace arb { namespace util { -template <typename T> +template <typename T = void> struct padded_allocator { using value_type = T; using pointer = T*; - using propagate_on_container_copy_assignment = std::false_type; - using propagate_on_container_move_assignment = std::false_type; - using propagate_on_container_swap = std::false_type; + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; using is_always_equal = std::false_type; padded_allocator() noexcept {} diff --git a/src/util/range.hpp b/src/util/range.hpp index 34eac57ab2b4ddf465ef5092fe0bf6d88824329b..7e36e6d136c07f58161d0a7a0016ba59af3aaff0 100644 --- a/src/util/range.hpp +++ b/src/util/range.hpp @@ -63,9 +63,27 @@ struct range { left(std::forward<U1>(l)), right(std::forward<U2>(r)) {} + template < + typename U1, + typename U2, + typename = enable_if_t< + std::is_constructible<iterator, U1>::value && + std::is_constructible<sentinel, U2>::value> + > + 
range(const range<U1, U2>& other): + left(other.left), right(other.right) + {} + range& operator=(const range&) = default; range& operator=(range&&) = default; + template <typename U1, typename U2> + range& operator=(const range<U1, U2>& other) { + left = other.left; + right = other.right; + return *this; + } + bool empty() const { return left == right; } iterator begin() const { return left; } @@ -108,6 +126,13 @@ struct range { return (*this)[n]; } + // Expose `data` method if a pointer range. + template <typename V = iterator, typename W = sentinel> + enable_if_t<std::is_same<V, W>::value && std::is_pointer<V>::value, iterator> + data() const { + return left; + } + #ifdef ARB_HAVE_TBB template < typename V = iterator, @@ -176,16 +201,18 @@ auto canonical_view(const Seq& s) -> // iterators. Note: O(N) behaviour with forward iterator ranges or sentinel-terminated ranges. template <typename Seq> -auto strict_view(Seq& s) -> range<decltype(std::begin(s))> +auto strict_view(Seq&& s) -> range<decltype(std::begin(s))> { - return make_range(std::begin(s), std::next(util::upto(std::begin(s), std::end(s)))); + return make_range(std::begin(s), std::begin(s)==std::end(s)? std::begin(s): std::next(util::upto(std::begin(s), std::end(s)))); } +#if 0 template <typename Seq> auto strict_view(const Seq& s) -> range<decltype(std::begin(s))> { - return make_range(std::begin(s), std::next(util::upto(std::begin(s), std::end(s)))); + return make_range(std::begin(s), std::begin(s)==std::end(s)? 
std::begin(s): std::next(util::upto(std::begin(s), std::end(s)))); } +#endif } // namespace util } // namespace arb diff --git a/src/util/rangeutil.hpp b/src/util/rangeutil.hpp index e9f70699140abc2be75dd85cf55627ccfe00c7ad..1979380cf7b246a1f3212236854b73e340e23f69 100644 --- a/src/util/rangeutil.hpp +++ b/src/util/rangeutil.hpp @@ -34,19 +34,23 @@ range<const T*> singleton_view(const T& item) { // Non-owning views and subviews template <typename Seq> -range<typename sequence_traits<Seq>::iterator, typename sequence_traits<Seq>::sentinel> -range_view(Seq& seq) { +range<typename sequence_traits<Seq&&>::iterator, typename sequence_traits<Seq&&>::sentinel> +range_view(Seq&& seq) { return make_range(std::begin(seq), std::end(seq)); } +template <typename Seq, typename = enable_if_t<sequence_traits<Seq&&>::is_contiguous>> +auto range_pointer_view(Seq&& seq) + DEDUCED_RETURN_TYPE(make_range(util::data(seq), util::data(seq)+util::size(seq))) + template < typename Seq, typename Offset1, typename Offset2, - typename Iter = typename sequence_traits<Seq>::iterator + typename Iter = typename sequence_traits<Seq&&>::iterator > enable_if_t<is_forward_iterator<Iter>::value, range<Iter>> -subrange_view(Seq& seq, Offset1 bi, Offset2 ei) { +subrange_view(Seq&& seq, Offset1 bi, Offset2 ei) { Iter b = std::begin(seq); std::advance(b, bi); @@ -59,11 +63,11 @@ template < typename Seq, typename Offset1, typename Offset2, - typename Iter = typename sequence_traits<Seq>::iterator + typename Iter = typename sequence_traits<Seq&&>::iterator > enable_if_t<is_forward_iterator<Iter>::value, range<Iter>> -subrange_view(Seq& seq, std::pair<Offset1, Offset2> index) { - return subrange_view(seq, index.first, index.second); +subrange_view(Seq&& seq, std::pair<Offset1, Offset2> index) { + return subrange_view(std::forward<Seq>(seq), index.first, index.second); } // helper for determining the type of a subrange_view @@ -74,13 +78,7 @@ using subrange_view_type = 
decltype(subrange_view(std::declval<Seq&>(), 0, 0)); // Fill container or range. template <typename Seq, typename V> -void fill(Seq& seq, const V& value) { - auto canon = canonical_view(seq); - std::fill(canon.begin(), canon.end(), value); -} - -template <typename Range, typename V> -void fill(const Range& seq, const V& value) { +void fill(Seq&& seq, const V& value) { auto canon = canonical_view(seq); std::fill(canon.begin(), canon.end(), value); } @@ -99,7 +97,7 @@ Container& append(Container &c, const Seq& seq) { template <typename AssignableContainer, typename Seq> AssignableContainer& assign(AssignableContainer& c, const Seq& seq) { auto canon = canonical_view(seq); - c.assign(std::begin(canon), std::end(canon)); + c.assign(canon.begin(), canon.end()); return c; } @@ -114,7 +112,8 @@ namespace impl { // This requires that C supports construction from a pair of iterators template <typename C> operator C() const { - return C(std::begin(ref), std::end(ref)); + auto canon = canonical_view(ref); + return C(canon.begin(), canon.end()); } const Seq& ref; @@ -140,29 +139,15 @@ AssignableContainer& assign_by(AssignableContainer& c, const Seq& seq, const Pro // Note that a const range reference may wrap non-const iterators. 
template <typename Seq> -enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> -sort(Seq& seq) { +enable_if_t<!std::is_const<typename sequence_traits<Seq&&>::reference>::value> +sort(Seq&& seq) { auto canon = canonical_view(seq); std::sort(std::begin(canon), std::end(canon)); } template <typename Seq, typename Less> -enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> -sort(Seq& seq, const Less& less) { - auto canon = canonical_view(seq); - std::sort(std::begin(canon), std::end(canon), less); -} - -template <typename Seq> -enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> -sort(const Seq& seq) { - auto canon = canonical_view(seq); - std::sort(std::begin(canon), std::end(canon)); -} - -template <typename Seq, typename Less> -enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> -sort(const Seq& seq, const Less& less) { +enable_if_t<!std::is_const<typename sequence_traits<Seq&&>::reference>::value> +sort(Seq&& seq, const Less& less) { auto canon = canonical_view(seq); std::sort(std::begin(canon), std::end(canon), less); } @@ -170,21 +155,9 @@ sort(const Seq& seq, const Less& less) { // Sort in-place by projection `proj` template <typename Seq, typename Proj> -enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> -sort_by(Seq& seq, const Proj& proj) { - using value_type = typename sequence_traits<Seq>::value_type; - auto canon = canonical_view(seq); - - std::sort(std::begin(canon), std::end(canon), - [&proj](const value_type& a, const value_type& b) { - return proj(a) < proj(b); - }); -} - -template <typename Seq, typename Proj> -enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> -sort_by(const Seq& seq, const Proj& proj) { - using value_type = typename sequence_traits<Seq>::value_type; +enable_if_t<!std::is_const<typename sequence_traits<Seq&&>::reference>::value> +sort_by(Seq&& seq, const Proj& proj) { + using value_type 
= typename sequence_traits<Seq&&>::value_type; auto canon = canonical_view(seq); std::sort(std::begin(canon), std::end(canon), @@ -196,21 +169,9 @@ sort_by(const Seq& seq, const Proj& proj) { // Stable sort in-place by projection `proj` template <typename Seq, typename Proj> -enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> -stable_sort_by(Seq& seq, const Proj& proj) { - using value_type = typename sequence_traits<Seq>::value_type; - auto canon = canonical_view(seq); - - std::stable_sort(std::begin(canon), std::end(canon), - [&proj](const value_type& a, const value_type& b) { - return proj(a) < proj(b); - }); -} - -template <typename Seq, typename Proj> -enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> -stable_sort_by(const Seq& seq, const Proj& proj) { - using value_type = typename sequence_traits<Seq>::value_type; +enable_if_t<!std::is_const<typename sequence_traits<Seq&&>::reference>::value> +stable_sort_by(Seq&& seq, const Proj& proj) { + using value_type = typename sequence_traits<Seq&&>::value_type; auto canon = canonical_view(seq); std::stable_sort(std::begin(canon), std::end(canon), @@ -238,7 +199,7 @@ bool any_of(const Seq& seq, const Predicate& pred) { template < typename Seq, typename Proj, - typename Value = typename transform_iterator<typename sequence_traits<Seq>::const_iterator, Proj>::value_type + typename Value = typename transform_iterator<typename sequence_traits<const Seq&>::const_iterator, Proj>::value_type > Value sum_by(const Seq& seq, const Proj& proj, Value base = Value{}) { auto canon = canonical_view(transform_view(seq, proj)); @@ -250,21 +211,9 @@ Value sum_by(const Seq& seq, const Proj& proj, Value base = Value{}) { // value of `proj(*i)`. 
template <typename Seq, typename Proj> -typename sequence_traits<Seq>::iterator -max_element_by(Seq& seq, const Proj& proj) { - using value_type = typename sequence_traits<Seq>::value_type; - auto canon = canonical_view(seq); - - return std::max_element(std::begin(canon), std::end(canon), - [&proj](const value_type& a, const value_type& b) { - return proj(a) < proj(b); - }); -} - -template <typename Seq, typename Proj> -typename sequence_traits<Seq>::iterator -max_element_by(const Seq& seq, const Proj& proj) { - using value_type = typename sequence_traits<Seq>::value_type; +typename sequence_traits<Seq&&>::iterator +max_element_by(Seq&& seq, const Proj& proj) { + using value_type = typename sequence_traits<Seq&&>::value_type; auto canon = canonical_view(seq); return std::max_element(std::begin(canon), std::end(canon), @@ -284,7 +233,7 @@ max_element_by(const Seq& seq, const Proj& proj) { template < typename Seq, - typename Value = typename sequence_traits<Seq>::value_type, + typename Value = typename sequence_traits<const Seq&>::value_type, typename Compare = std::less<Value> > Value max_value(const Seq& seq, Compare cmp = Compare{}) { @@ -308,7 +257,7 @@ Value max_value(const Seq& seq, Compare cmp = Compare{}) { template < typename Seq, - typename Value = typename sequence_traits<Seq>::value_type, + typename Value = typename sequence_traits<const Seq&>::value_type, typename Compare = std::less<Value> > std::pair<Value, Value> minmax_value(const Seq& seq, Compare cmp = Compare{}) { @@ -332,10 +281,14 @@ std::pair<Value, Value> minmax_value(const Seq& seq, Compare cmp = Compare{}) { return {lower, upper}; } -// View over the keys in an associative container. +// Range-wrapper for std::is_sorted. 
+ +template <typename Seq, typename = util::enable_if_sequence_t<const Seq&>> +bool is_sorted(const Seq& seq) { + auto canon = canonical_view(seq); + return std::is_sorted(std::begin(canon), std::end(canon)); +} -template <typename Map> -auto keys(Map& m) DEDUCED_RETURN_TYPE(util::transform_view(m, util::first)); // Test if sequence is sorted after apply projection `proj` to elements. // (TODO: this will perform unnecessary copies if `proj` returns a reference; @@ -344,7 +297,7 @@ auto keys(Map& m) DEDUCED_RETURN_TYPE(util::transform_view(m, util::first)); template < typename Seq, typename Proj, - typename Compare = std::less<typename std::result_of<Proj (typename sequence_traits<Seq>::value_type)>::type> + typename Compare = std::less<typename std::result_of<Proj (typename sequence_traits<const Seq&>::value_type)>::type> > bool is_sorted_by(const Seq& seq, const Proj& proj, Compare cmp = Compare{}) { auto i = std::begin(seq); @@ -390,6 +343,18 @@ C make_copy(Seq const& seq) { return C{std::begin(seq), std::end(seq)}; } +// Present a view of a finite sequence in reverse order, provided +// the sequence iterator is bidirectional. +template < + typename Seq, + typename It = typename sequence_traits<Seq&&>::iterator, + typename Rev = std::reverse_iterator<It> +> +range<Rev, Rev> reverse_view(Seq&& seq) { + auto strict = strict_view(seq); + return range<Rev, Rev>(Rev(strict.right), Rev(strict.left)); +} + } // namespace util } // namespace arb diff --git a/src/util/simple_table.hpp b/src/util/simple_table.hpp deleted file mode 100644 index 32d6ee42c09eeb5c0cf62b90ab4fc63293bc2ad9..0000000000000000000000000000000000000000 --- a/src/util/simple_table.hpp +++ /dev/null @@ -1,44 +0,0 @@ -#pragma once - -#include <iterator> -#include <cstring> -#include <utility> - -// Linear lookup of values in a table indexed by a key. -// -// Tables are any sequence of pairs or tuples, with the key as the first element. 
- -namespace arb { -namespace util { - -namespace impl { - struct key_equal { - template <typename U, typename V> - bool operator()(U&& u, V&& v) const { - return std::forward<U>(u)==std::forward<V>(v); - } - - // special case for C strings: - bool operator()(const char* u, const char* v) const { - return !std::strcmp(u, v); - } - }; -}; - -// Return pointer to value (second element in entry) in table if key found, -// otherwise nullptr. - -template <typename PairSeq, typename Key, typename Eq = impl::key_equal> -auto table_lookup(PairSeq&& seq, const Key& key, Eq eq = Eq{}) - -> decltype(&std::get<1>(*std::begin(seq))) -{ - for (auto&& entry: seq) { - if (eq(std::get<0>(entry), key)) { - return &std::get<1>(entry); - } - } - return nullptr; -} - -} // namespace util -} // namespace arb diff --git a/src/util/span.hpp b/src/util/span.hpp index 6767ba346e14f7ac16083ebf61e1ea5596ccd230..7a572cbcc87a8cffbc6334a65fe94337ce2e01aa 100644 --- a/src/util/span.hpp +++ b/src/util/span.hpp @@ -8,11 +8,18 @@ #include <utility> #include <util/counter.hpp> +#include <util/deduce_return.hpp> +#include <util/meta.hpp> #include <util/range.hpp> namespace arb { namespace util { +// TODO: simplify span-using code by: +// 1. replace type alias `span` with `span_type` alias; +// 2. rename `make_span` as `span` +// 3. add another `span(I n)` overload equivalent to `span(I{}, n)`. 
+ template <typename I> using span = range<counter<I>>; @@ -26,6 +33,13 @@ span<typename std::common_type<I, J>::type> make_span(std::pair<I, J> interval) return span<typename std::common_type<I, J>::type>(interval.first, interval.second); } +template <typename I> +span<I> make_span(I right) { + return span<I>(I{}, right); +} + +template <typename Seq> +auto count_along(const Seq& s) DEDUCED_RETURN_TYPE(util::make_span(util::size(s))) } // namespace util } // namespace arb diff --git a/tests/common_cells.hpp b/tests/common_cells.hpp index 088967dbb86008940738bdb6e241ec1be7a73e75..4f5b6b39a6a848da4caef430d6ed51590dd5ecb1 100644 --- a/tests/common_cells.hpp +++ b/tests/common_cells.hpp @@ -258,7 +258,7 @@ inline cell make_cell_simple_cable(bool with_stim = true) { double gbar = 0.000025; double I = 0.1; - mechanism_spec pas("pas"); + mechanism_desc pas("pas"); pas["g"] = gbar; for (auto& seg: c.segments()) { diff --git a/tests/global_communication/CMakeLists.txt b/tests/global_communication/CMakeLists.txt index 88061eec44842ffe1f19ab8e1303989d8a793bcd..bd7c32d235417a85d7c27de34255c1cde0cc3724 100644 --- a/tests/global_communication/CMakeLists.txt +++ b/tests/global_communication/CMakeLists.txt @@ -16,7 +16,8 @@ add_executable(global_communication.exe ${COMMUNICATION_SOURCES} ${HEADERS}) set(TARGETS global_communication.exe) foreach(target ${TARGETS}) - target_link_libraries(${target} LINK_PUBLIC arbor gtest) + target_link_libraries(${target} LINK_PUBLIC gtest) + target_link_libraries(${target} LINK_PUBLIC ${ARB_LIBRARIES}) target_link_libraries(${target} LINK_PUBLIC ${EXTERNAL_LIBRARIES}) if(ARB_WITH_MPI) diff --git a/tests/global_communication/test_communicator.cpp b/tests/global_communication/test_communicator.cpp index a914befcb608befbc77dc192934ac7f568261fb5..90d199131c98997619172a2190c282b272d9167c 100644 --- a/tests/global_communication/test_communicator.cpp +++ b/tests/global_communication/test_communicator.cpp @@ -132,12 +132,12 @@ TEST(communicator, 
gather_spikes_variant) { // Parameter used to scale the number of spikes generated on successive // ranks. - const auto scale = 10; + constexpr int scale = 10; // Calculates the number of spikes generated by the first n ranks. // Can be used to calculate the index of the range of spikes // generated by a given rank, and to determine the total number of // spikes generated globally. - auto sumn = [scale](int n) {return scale*n*(n+1)/2;}; + auto sumn = [](int n) {return scale*n*(n+1)/2;}; const auto n_local_spikes = scale*rank; // Create local spikes for communication. @@ -397,9 +397,9 @@ TEST(communicator, ring) // last cell in each domain fires EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return (g+1)%n_local == 0u;})); // even-numbered cells fire - EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return g%2==0;})); + EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return g%2==0;})); // odd-numbered cells fire - EXPECT_TRUE(test_ring(D, C, [n_local](cell_gid_type g){return g%2==1;})); + EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return g%2==1;})); } template <typename F> @@ -490,9 +490,9 @@ TEST(communicator, all2all) // every cell fires EXPECT_TRUE(test_all2all(D, C, [](cell_gid_type g){return true;})); // only cell 0 fires - EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g==0u;})); + EXPECT_TRUE(test_all2all(D, C, [](cell_gid_type g){return g==0u;})); // even-numbered cells fire - EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g%2==0;})); + EXPECT_TRUE(test_all2all(D, C, [](cell_gid_type g){return g%2==0;})); // odd-numbered cells fire - EXPECT_TRUE(test_all2all(D, C, [n_local](cell_gid_type g){return g%2==1;})); + EXPECT_TRUE(test_all2all(D, C, [](cell_gid_type g){return g%2==1;})); } diff --git a/tests/modcc/CMakeLists.txt b/tests/modcc/CMakeLists.txt index 261e9b2bc6794b2d4e8fb7254b4846a7d5eaae67..64d5e06c04789a5da97178be469bb21ace5e4ba7 100644 --- a/tests/modcc/CMakeLists.txt +++ 
b/tests/modcc/CMakeLists.txt @@ -5,6 +5,7 @@ set(MODCC_TEST_SOURCES test_module.cpp test_msparse.cpp test_parser.cpp + test_prefixbuf.cpp test_printers.cpp test_removelocals.cpp test_symdiff.cpp @@ -21,6 +22,8 @@ set(MODCC_TEST_SOURCES test.cpp ) +include_directories("${PROJECT_SOURCE_DIR}/modcc") + add_definitions("-DDATADIR=\"${PROJECT_SOURCE_DIR}/data\"") add_executable(test_modcc ${MODCC_TEST_SOURCES}) diff --git a/tests/unit/test_prefixbuf.cpp b/tests/modcc/test_prefixbuf.cpp similarity index 83% rename from tests/unit/test_prefixbuf.cpp rename to tests/modcc/test_prefixbuf.cpp index b211c8161338da03ea164f599c5fd066a52a9436..a4a90d8c396462c3334174d418481bb724e285f6 100644 --- a/tests/unit/test_prefixbuf.cpp +++ b/tests/modcc/test_prefixbuf.cpp @@ -1,13 +1,12 @@ -#include "../gtest.h" - #include <cstring> #include <iomanip> #include <string> #include <sstream> -#include <util/prefixbuf.hpp> +#include "io/prefixbuf.hpp" +#include "test.hpp" -using namespace arb::util; +using namespace io; // Test public std::stringbuf 'put' interfaces on prefixbuf. @@ -49,6 +48,36 @@ TEST(prefixbuf, prefix) { EXPECT_EQ(expected, s.str()); } +// A prefixbuf can be configured to emit or not emit the +// prefix for empty lines. + +TEST(prefixbuf, empty_lines) { + auto write_sputn = [](std::streambuf& b, const char* c) { + b.sputn(c, std::strlen(c)); + }; + + std::stringbuf s; + + prefixbuf p1(&s, false); // omit prefix on blank lines + p1.prefix = "1> "; + write_sputn(p1, "hello\n\nfishies!\n\n"); + + prefixbuf p2(&s, true); // include prefix on blank lines + p2.prefix = "2> "; + write_sputn(p2, "hello\n\nbunnies!\n"); + + std::string expected = + "1> hello\n" + "\n" + "1> fishies!\n" + "\n" + "2> hello\n" + "2> \n" + "2> bunnies!\n"; + + EXPECT_EQ(expected, s.str()); +} + // Test `pfxstringstream` basic functionality: // // 1. `rdbuf()` method gives pointer to `prefixbuf`. 
diff --git a/tests/modcc/test_printers.cpp b/tests/modcc/test_printers.cpp index 1160cbbea047e2c56c16cd07e58266efde29c664..410c0b184a9b99dc212ffbc59bbc184413cf96e3 100644 --- a/tests/modcc/test_printers.cpp +++ b/tests/modcc/test_printers.cpp @@ -1,12 +1,16 @@ #include <regex> #include <string> +#include <sstream> #include "test.hpp" -#include "cprinter.hpp" -#include "cudaprinter.hpp" +#include "printer/cexpr_emit.hpp" +#include "printer/cprinter.hpp" #include "expression.hpp" -#include "textbuffer.hpp" +#include "symdiff.hpp" + +// Note: CUDA printer disabled until new implementation finished. +//#include "printer/cudaprinter.hpp" struct testcase { const char* source; @@ -32,6 +36,26 @@ static std::string strip(std::string text) { return text; } +TEST(scalar_printer, constants) { + testcase testcases[] = { + {"1./0.", "INFINITY"}, + {"-1./0.", "-INFINITY"}, + {"(-1)^0.5", "NAN"}, + {"1/(-1./0.)", "-0."}, + {"1-1", "0"}, + }; + + for (const auto& tc: testcases) { + auto expr = constant_simplify(parse_expression(tc.source)); + ASSERT_TRUE(expr && expr->is_number()); + + std::stringstream s; + s << as_c_double(expr->is_number()->value()); + + EXPECT_EQ(std::string(tc.expected), s.str()); + } +} + TEST(scalar_printer, statement) { std::vector<testcase> testcases = { {"y=x+3", "y=x+3"}, @@ -64,14 +88,16 @@ TEST(scalar_printer, statement) { { SCOPED_TRACE("CPrinter"); - auto printer = make_unique<CPrinter>(); + std::stringstream out; + auto printer = make_unique<CPrinter>(out); e->accept(printer.get()); - std::string text = printer->text(); + std::string text = out.str(); verbose_print(e->to_string(), " :--: ", text); EXPECT_EQ(strip(tc.expected), strip(text)); } +#if 0 { SCOPED_TRACE("CUDAPrinter"); TextBuffer buf; @@ -84,27 +110,26 @@ TEST(scalar_printer, statement) { verbose_print(e->to_string(), " :--: ", text); EXPECT_EQ(strip(tc.expected), strip(text)); } +#endif } } -TEST(CPrinter, proc) { +TEST(CPrinter, proc_body) { std::vector<testcase> testcases = { { 
"PROCEDURE trates(v) {\n" " LOCAL k\n" - " minf=1-1/(1+exp((v-k)/k))\n" - " hinf=1/(1+exp((v-k)/k))\n" - " mtau = 0.6\n" + " minf = 1-1/(1+exp((v-k)/k))\n" + " hinf = 1/(1+exp((v-k)/k))\n" + " mtau = 0.5\n" " htau = 1500\n" "}" , - "void trates(int i_, value_type v) {\n" "value_type k;\n" "minf[i_] = 1-1/(1+exp((v-k)/k));\n" "hinf[i_] = 1/(1+exp((v-k)/k));\n" - "mtau[i_] = 0.6;\n" + "mtau[i_] = 0.5;\n" "htau[i_] = 1500;\n" - "}" } }; @@ -124,12 +149,14 @@ TEST(CPrinter, proc) { auto& proc = (globals[procname] = symbol_ptr(e.release()->is_symbol())); proc->semantic(globals); - auto v = make_unique<CPrinter>(); - proc->accept(v.get()); + std::stringstream out; + auto v = make_unique<CPrinter>(out); + proc->is_procedure()->body()->accept(v.get()); + std::string text = out.str(); - verbose_print(proc->to_string()); - verbose_print(" :--: ", v->text()); + verbose_print(proc->is_procedure()->body()->to_string()); + verbose_print(" :--: ", text); - EXPECT_EQ(strip(tc.expected), strip(v->text())); + EXPECT_EQ(strip(tc.expected), strip(text)); } } diff --git a/tests/modcc/test_simd_backend.cpp b/tests/modcc/test_simd_backend.cpp index f01fc2c5d25660c657998815eddde514130a0d28..8f2ac112a84affb136c191e77a8edca40466e85e 100644 --- a/tests/modcc/test_simd_backend.cpp +++ b/tests/modcc/test_simd_backend.cpp @@ -1,3 +1,6 @@ +#if 0 +// Disabled pending new SIMD printer code. 
+ #include "backends/simd.hpp" #include "textbuffer.hpp" #include "token.hpp" @@ -61,3 +64,4 @@ TEST(avx512, emit_unary_op) { simd_backend::emit_load_index(tb, "&a"); EXPECT_EQ("_mm256_lddqu_si256(&a)", tb.str()); } +#endif diff --git a/tests/performance/io/disk_io.cpp b/tests/performance/io/disk_io.cpp index 562ab2c9ba870a6b947ca9370c99873c06059ffa..a3bec2b8caf80fa8a5359046fb6b690920b71c4d 100644 --- a/tests/performance/io/disk_io.cpp +++ b/tests/performance/io/disk_io.cpp @@ -9,7 +9,6 @@ #include <common_types.hpp> #include <communication/communicator.hpp> #include <communication/global_policy.hpp> -#include <fvm_multicell.hpp> #include <io/exporter_spike_file.hpp> #include <profiling/profiler.hpp> #include <spike.hpp> diff --git a/tests/simple_recipes.hpp b/tests/simple_recipes.hpp index ef6e04fca3bd43111d9d874c7f57a1b27d60335e..9830dee0ec828593adfeac1fb39714ae782cf4cc 100644 --- a/tests/simple_recipes.hpp +++ b/tests/simple_recipes.hpp @@ -16,6 +16,12 @@ namespace arb { class simple_recipe_base: public recipe { public: + simple_recipe_base(): + catalogue_(global_default_catalogue()) + { + cell_gprop_.catalogue = &catalogue_; + } + cell_size_type num_probes(cell_gid_type i) const override { return probes_.count(i)? 
probes_.at(i).size(): 0; } @@ -31,22 +37,23 @@ public: pvec_.push_back({probe_id, tag, std::move(address)}); } - void add_specialized_mechanism(std::string name, specialized_mechanism m) { - cell_gprop.special_mechs[name] = std::move(m); - } - util::any get_global_properties(cell_kind k) const override { switch (k) { - case cell_kind::cable1d_neuron: - return cell_gprop; + case cell_kind::cable1d_neuron: + return cell_gprop_; default: return util::any{}; } } + mechanism_catalogue& catalogue() { + return catalogue_; + } + protected: std::unordered_map<cell_gid_type, std::vector<probe_info>> probes_; - cell_global_properties cell_gprop; + cell_global_properties cell_gprop_; + mechanism_catalogue catalogue_; }; // Convenience derived recipe class for wrapping n copies of a single diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 5133c8963b418d0f083abd0c17fb2ae51fea0692..adaae9002a2a04fcd68e24ba16d132140165107a 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -17,21 +17,24 @@ build_modules( # Unit test sources -set(TEST_CUDA_SOURCES +set(test_cuda_sources test_intrin.cu - test_mc_cell_group.cu test_gpu_stack.cu test_matrix.cu - test_multi_event_stream.cu + test_matrix_cpuvsgpu.cpp test_reduce_by_key.cu - test_spikes.cu test_vector.cu + test_mc_cell_group_gpu.cpp + test_multi_event_stream_gpu.cpp + test_multi_event_stream_gpu.cu + test_spikes_gpu.cpp + # unit test driver test.cpp ) -set(TEST_SOURCES +set(test_sources # unit tests test_algorithms.cpp test_any.cpp @@ -48,14 +51,17 @@ set(TEST_SOURCES test_event_generators.cpp test_event_queue.cpp test_filter.cpp - test_fvm_multi.cpp - test_lif_cell_group.cpp + test_fvm_layout.cpp + test_fvm_lowered.cpp test_mc_cell_group.cpp test_lexcmp.cpp + test_lif_cell_group.cpp + test_maputil.cpp test_mask_stream.cpp test_math.cpp test_matrix.cpp test_mechanisms.cpp + test_mechcat.cpp test_merge_events.cpp test_multi_event_stream.cpp test_nop.cpp @@ -65,7 +71,6 @@ set(TEST_SOURCES 
test_partition.cpp test_path.cpp test_point.cpp - test_prefixbuf.cpp test_probe.cpp test_range.cpp test_segment.cpp @@ -76,7 +81,6 @@ set(TEST_SOURCES test_spikes.cpp test_spike_store.cpp test_stats.cpp - test_stimulus.cpp test_strprintf.cpp test_swcio.cpp test_synapses.cpp @@ -93,13 +97,10 @@ set(TEST_SOURCES stats.cpp ) -if(ARB_VECTORIZE_TARGET STREQUAL "AVX2") - list(APPEND TEST_SOURCES test_intrin.cpp) -endif() - -set(TARGETS test.exe) +set(targets test.exe) -add_executable(test.exe ${TEST_SOURCES}) +add_executable(test.exe ${test_sources}) +target_compile_options(test.exe PRIVATE ${CXXOPT_ARCH}) target_compile_definitions(test.exe PUBLIC "-DDATADIR=\"${PROJECT_SOURCE_DIR}/data\"") if (ARB_AUTO_RUN_MODCC_ON_CHANGES) @@ -109,11 +110,11 @@ endif() target_include_directories(test.exe PRIVATE "${mech_proto_dir}/..") if(ARB_WITH_CUDA) - set(TARGETS ${TARGETS} test_cuda.exe) - cuda_add_executable(test_cuda.exe ${TEST_CUDA_SOURCES}) + list(APPEND targets test_cuda.exe) + cuda_add_executable(test_cuda.exe ${test_cuda_sources}) endif() -foreach(target ${TARGETS}) +foreach(target ${targets}) target_link_libraries(${target} LINK_PUBLIC gtest) target_link_libraries(${target} LINK_PUBLIC ${ARB_LIBRARIES}) target_link_libraries(${target} LINK_PUBLIC ${EXTERNAL_LIBRARIES}) diff --git a/tests/unit/common.hpp b/tests/unit/common.hpp index ad6c68739d4711831b2c47ba14ee8832265a8275..7f9b73d9bfcf18d7958b21c00b3d1bfc3dce96a2 100644 --- a/tests/unit/common.hpp +++ b/tests/unit/common.hpp @@ -13,7 +13,7 @@ namespace testing { -// string ctor suffix (until C++14!) +// String ctor suffix (until C++14!). namespace string_literals { inline std::string operator ""_s(const char* s, std::size_t n) { @@ -21,9 +21,13 @@ namespace string_literals { } } -// sentinel for use with range-related tests + +// Sentinel for C-style strings, for use with range-related tests. 
struct null_terminated_t { + bool operator==(null_terminated_t) const { return true; } + bool operator!=(null_terminated_t) const { return false; } + bool operator==(const char *p) const { return !*p; } bool operator!=(const char *p) const { return !!*p; } @@ -120,10 +124,74 @@ int nomove<V>::copy_ctor_count; template <typename V> int nomove<V>::copy_assign_count; + +// Subvert class access protections. Demo: +// +// class foo { +// int secret = 7; +// }; +// +// int foo::* secret_mptr; +// template class access::bind<int foo::*, secret_mptr, &foo::secret>; +// +// int seven = foo{}.*secret_mptr; +// +// Or with shortcut define (places global in anonymous namespace): +// +// ACCESS_BIND(int foo::*, secret_mptr, &foo::secret) +// +// int seven = foo{}.*secret_mptr; + +namespace access { + template <typename V, V& store, V value> + struct bind { + static struct binder { + binder() { store = value; } + } init; + }; + + template <typename V, V& store, V value> + typename bind<V, store, value>::binder bind<V, store, value>::init; +} // namespace access + +#define ACCESS_BIND(type, global, value)\ +namespace { using global ## _type_ = type; global ## _type_ global; }\ +template struct ::testing::access::bind<type, global, value>; + + // Google Test assertion-returning predicates: -// Assert two sequences of floating point values are almost equal. +// Assert two values are 'almost equal', with exact test for non-floating point types. // (Uses internal class `FloatingPoint` from gtest.) 
+ +template <typename FPType> +::testing::AssertionResult almost_eq_(FPType a, FPType b, std::true_type) { + using FP = testing::internal::FloatingPoint<FPType>; + + if ((std::isnan(a) && std::isnan(b)) || FP{a}.AlmostEquals(FP{b})) { + return ::testing::AssertionSuccess(); + } + + return ::testing::AssertionFailure() << "floating point numbers " << a << " and " << b << " differ"; +} + +template <typename X> +::testing::AssertionResult almost_eq_(const X& a, const X& b, std::false_type) { + if (a==b) { + return ::testing::AssertionSuccess(); + } + + return ::testing::AssertionFailure() << "values " << a << " and " << b << " differ"; +} + +template <typename X> +::testing::AssertionResult almost_eq(const X& a, const X& b) { + return almost_eq_(a, b, typename std::is_floating_point<X>::type{}); +} + +// Assert two sequences of floating point values are almost equal, with explicit +// specification of floating point type. + template <typename FPType, typename Seq1, typename Seq2> ::testing::AssertionResult seq_almost_eq(Seq1&& seq1, Seq2&& seq2) { using std::begin; @@ -136,7 +204,6 @@ template <typename FPType, typename Seq1, typename Seq2> auto e2 = end(seq2); for (std::size_t j = 0; i1!=e1 && i2!=e2; ++i1, ++i2, ++j) { - using FP = testing::internal::FloatingPoint<FPType>; auto v1 = *i1; auto v2 = *i2; @@ -144,10 +211,8 @@ template <typename FPType, typename Seq1, typename Seq2> // Cast to FPType to avoid warnings about lowering conversion // if FPType has lower precision than Seq{12}::value_type. 
- if (!(std::isnan(v1) && std::isnan(v2)) && !FP{v1}.AlmostEquals(FP{v2})) { - return ::testing::AssertionFailure() << "floating point numbers " << v1 << " and " << v2 << " differ at index " << j; - } - + auto status = almost_eq((FPType)(v1), (FPType)(v2)); + if (!status) return status << " at index " << j; } if (i1!=e1 || i2!=e2) { @@ -194,6 +259,7 @@ template <typename Seq1, typename Seq2> } // Assert elements 0..n-1 inclusive of two indexed collections are exactly equal. + template <typename Arr1, typename Arr2> ::testing::AssertionResult indexed_eq_n(int n, Arr1&& a1, Arr2&& a2) { for (int i = 0; i<n; ++i) { @@ -208,8 +274,23 @@ template <typename Arr1, typename Arr2> return ::testing::AssertionSuccess(); } +// Assert elements 0..n-1 inclusive of two indexed collections are almost equal. + +template <typename Arr1, typename Arr2> +::testing::AssertionResult indexed_almost_eq_n(int n, Arr1&& a1, Arr2&& a2) { + for (int i = 0; i<n; ++i) { + auto v1 = a1[i]; + auto v2 = a2[i]; + + auto status = almost_eq(v1, v2); + if (!status) return status << " at index " << i; + } + + return ::testing::AssertionSuccess(); +} // Assert two floating point values are within a relative tolerance. + inline ::testing::AssertionResult near_relative(double a, double b, double relerr) { double tol = relerr*std::max(std::abs(a), std::abs(b)); if (std::abs(a-b)>tol) { diff --git a/tests/unit/instrument_malloc.hpp b/tests/unit/instrument_malloc.hpp index b36b8f83ef3e28e120ddbfd40562d3cdcf040691..825d4df92bdb591c802e51cf45e9438cee744352 100644 --- a/tests/unit/instrument_malloc.hpp +++ b/tests/unit/instrument_malloc.hpp @@ -25,6 +25,19 @@ #define CAN_INSTRUMENT_MALLOC #endif +// Disable if using address sanitizer though: + +// This is how clang tells us. +#if defined(__has_feature) +#if __has_feature(address_sanitizer) +#undef CAN_INSTRUMENT_MALLOC +#endif +#endif +// This is how gcc tells us. 
+#if defined(__SANITIZE_ADDRESS__) +#undef CAN_INSTRUMENT_MALLOC +#endif + namespace testing { #ifdef CAN_INSTRUMENT_MALLOC @@ -35,6 +48,10 @@ namespace testing { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#if defined(__INTEL_COMPILER) +#pragma warning push +#pragma warning disable 1478 +#endif // Totally not thread safe! struct with_instrumented_malloc { @@ -119,6 +136,9 @@ private: }; #pragma GCC diagnostic pop +#if defined(__INTEL_COMPILER) +#pragma warning pop +#endif #else diff --git a/tests/unit/test_algorithms.cpp b/tests/unit/test_algorithms.cpp index c3fe55baf47ea2a2e8d88ba7c665be550bbd03e9..230fd0afc713da4f7b7b9fdf2e1fb628730e78f2 100644 --- a/tests/unit/test_algorithms.cpp +++ b/tests/unit/test_algorithms.cpp @@ -1,14 +1,19 @@ +#include <forward_list> #include <iterator> #include <random> +#include <string> #include <vector> #include "../gtest.h" #include <algorithms.hpp> -#include "../test_util.hpp" +#include <util/compat.hpp> #include <util/debug.hpp> +#include <util/index_into.hpp> #include <util/meta.hpp> +#include "common.hpp" + /// tests the sort implementation in threading /// is only parallel if TBB is being used TEST(algorithms, parallel_sort) @@ -176,28 +181,39 @@ TEST(algorithms, is_strictly_monotonic_decreasing) ); } -TEST(algorithms, is_positive) -{ - EXPECT_TRUE( - arb::algorithms::is_positive( - std::vector<int>{} - ) - ); - EXPECT_TRUE( - arb::algorithms::is_positive( - std::vector<int>{3, 2, 1} - ) - ); - EXPECT_FALSE( - arb::algorithms::is_positive( - std::vector<int>{3, 2, 1, 0} - ) - ); - EXPECT_FALSE( - arb::algorithms::is_positive( - std::vector<int>{-1} - ) - ); +TEST(algorithms, all_positive) { + using arb::algorithms::all_positive; + + EXPECT_TRUE(all_positive(std::vector<int>{})); + EXPECT_TRUE(all_positive(std::vector<int>{3, 2, 1})); + EXPECT_FALSE(all_positive(std::vector<int>{3, 2, 1, 0})); + EXPECT_FALSE(all_positive(std::vector<int>{-1})); + + EXPECT_TRUE(all_positive((double 
[]){1., 2.})); + EXPECT_FALSE(all_positive((double []){1., 0.})); + EXPECT_FALSE(all_positive((double []){NAN})); + + EXPECT_TRUE(all_positive((std::string []){"a", "b"})); + EXPECT_FALSE(all_positive((std::string []){"a", "", "b"})); +} + +TEST(algorithms, all_negative) { + using arb::algorithms::all_negative; + + EXPECT_TRUE(all_negative(std::vector<int>{})); + EXPECT_TRUE(all_negative(std::vector<int>{-3, -2, -1})); + EXPECT_FALSE(all_negative(std::vector<int>{-3, -2, -1, 0})); + EXPECT_FALSE(all_negative(std::vector<int>{1})); + + double negzero = std::copysign(0., -1.); + + EXPECT_TRUE(all_negative((double []){-1., -2.})); + EXPECT_FALSE(all_negative((double []){-1., 0.})); + EXPECT_FALSE(all_negative((double []){-1., negzero})); + EXPECT_FALSE(all_negative((double []){NAN})); + + EXPECT_FALSE(all_negative((std::string []){"", "b"})); + EXPECT_FALSE(all_negative((std::string []){""})); } TEST(algorithms, has_contiguous_compartments) @@ -332,50 +348,6 @@ TEST(algorithms, is_unique) ); } -TEST(algorithms, is_sorted) -{ - EXPECT_TRUE( - arb::algorithms::is_sorted( - std::vector<int>{} - ) - ); - EXPECT_TRUE( - arb::algorithms::is_sorted( - std::vector<int>{100} - ) - ); - EXPECT_TRUE( - arb::algorithms::is_sorted( - std::vector<int>{0,1,2} - ) - ); - EXPECT_TRUE( - arb::algorithms::is_sorted( - std::vector<int>{0,2,100} - ) - ); - EXPECT_TRUE( - arb::algorithms::is_sorted( - std::vector<int>{0,0} - ) - ); - EXPECT_TRUE( - arb::algorithms::is_sorted( - std::vector<int>{0,1,2,2,2,2,3,4,5,5,5} - ) - ); - EXPECT_FALSE( - arb::algorithms::is_sorted( - std::vector<int>{0,1,2,1} - ) - ); - EXPECT_FALSE( - arb::algorithms::is_sorted( - std::vector<int>{1,0} - ) - ); -} - TEST(algorithms, child_count) { { @@ -543,41 +515,130 @@ struct test_index_into { } }; +template <typename Sub, typename Sup> +::testing::AssertionResult validate_index_into(const Sub& sub, const Sup& sup) { + using namespace arb; + + auto indices = util::index_into(sub, sup); + auto n_indices = 
util::size(indices); + auto n_sub = util::size(sub); + if (util::size(indices)!=util::size(sub)) { + return ::testing::AssertionFailure() + << "index_into size " << n_indices << " does not equal sub-sequence size " << n_sub; + } + + using std::begin; + using compat::end; + + auto sub_i = begin(sub); + auto sup_i = begin(sup); + auto sup_end = end(sup); + std::ptrdiff_t sup_idx = 0; + + for (auto i: indices) { + if (sup_idx>i) { + return ::testing::AssertionFailure() << "indices in index_into sequence not monotonic"; + } + + while (sup_idx<i && sup_i!=sup_end) ++sup_idx, ++sup_i; + + if (sup_i==sup_end) { + return ::testing::AssertionFailure() << "index " << i << " in index_into sequence is past the end"; + } + + if (!(*sub_i==*sup_i)) { + return ::testing::AssertionFailure() + << "value mismatch: sub-sequence element " << *sub_i + << " not equal to super-sequence element " << *sup_i << " at index " << i; + } + + ++sub_i; + } + + return ::testing::AssertionSuccess(); +} + +template <typename I> +arb::util::range<std::reverse_iterator<I>> reverse_range(arb::util::range<I> r) { + using reviter = std::reverse_iterator<I>; + return arb::util::make_range(reviter(r.end()), reviter(r.begin())); +} + TEST(algorithms, index_into) { - using C = std::vector<int>; + using ivector = std::vector<std::ptrdiff_t>; using arb::util::size; - - // by default index_into assumes that the inputs satisfy - // quite a strong set of prerequisites - auto tests = { - std::make_pair(C{}, C{}), - std::make_pair(C{100}, C{}), - std::make_pair(C{0,1,3,4,6,7,10,11}, C{0,4,6,7,11}), - std::make_pair(C{0,1,3,4,6,7,10,11}, C{0}), - std::make_pair(C{0,1,3,4,6,7,10,11}, C{11}), - std::make_pair(C{0,1,3,4,6,7,10,11}, C{4}), - std::make_pair(C{0,1,3,4,6,7,10,11}, C{0,11}), - std::make_pair(C{0,1,3,4,6,7,10,11}, C{4,11}), - std::make_pair(C{0,1,3,4,6,7,10,11}, C{0,1,3,4,6,7,10,11}) + using arb::util::index_into; + using arb::util::assign_from; + using arb::util::make_range; + using arb::util::all_of; + + 
std::vector<std::pair<std::vector<int>, std::vector<int>>> vector_tests = { + // Empty sequences: + {{}, {}}, + {{100}, {}}, + // Strictly monotonic sequences: + {{0, 1, 3, 4, 6, 7, 10, 11}, {0, 4, 6, 7, 11}}, + {{0, 1, 3, 4, 6, 7, 10, 11}, {0}}, + {{0, 1, 3, 4, 6, 7, 10, 11}, {11}}, + {{0, 1, 3, 4, 6, 7, 10, 11}, {4}}, + {{0, 1, 3, 4, 6, 7, 10, 11}, {0, 11}}, + {{0, 1, 3, 4, 6, 7, 10, 11}, {4, 11}}, + {{0, 1, 3, 4, 6, 7, 10, 11}, {0, 1, 3, 4, 6, 7, 10, 11}}, + // Sequences with duplicates: + {{8, 8, 10, 10, 12, 12, 12, 13}, {8, 10, 13}}, + {{8, 8, 10, 10, 12, 12, 12, 13}, {10, 10, 13}}, + // Unordered sequences: + {{10, 3, 7, -8, 11, -8, 1, 2}, {3, -8, -8, 1}} }; - test_index_into tester; - for(auto& t : tests) { - EXPECT_TRUE( - tester(t.second, t.first, arb::algorithms::index_into(t.second, t.first)) - ); + for (auto& testcase: vector_tests) { + EXPECT_TRUE(validate_index_into(testcase.second, testcase.first)); } - // test for arrays - int sub[] = {2, 3, 5, 9}; - int sup[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - auto idx = arb::algorithms::index_into(sub, sup); - EXPECT_EQ(size(sub), size(idx)); - auto it = idx.begin(); - for (auto i: sub) { - EXPECT_EQ(i, *it++); + // Test across array types. + + int subarr1[] = {2, 3, 9, 5}; + int suparr1[] = {10, 2, 9, 3, 8, 5, 9, 3, 5, 10}; + EXPECT_TRUE(validate_index_into(subarr1, suparr1)); + + // Test bidirectionality. 
+ + auto arr_indices = index_into(subarr1, suparr1); + ivector arridx = assign_from(arr_indices); + + ivector revidx; + for (auto i = arr_indices.end(); i!=arr_indices.begin(); ) { + revidx.push_back(*--i); } + + std::vector<std::ptrdiff_t> expected(arridx); + std::reverse(expected.begin(), expected.end()); + EXPECT_EQ(expected, revidx); + + int subarr2[] = {8, 8, 8, 8, 8}; + int suparr2[] = {8}; + + auto z_indices = index_into(subarr2, suparr2); + EXPECT_TRUE(all_of(z_indices, [](std::ptrdiff_t n) { return n==0; })); + EXPECT_EQ(0, z_indices.back()); + + // Test: strictly forward sequences; heterogenous sequences; sentinel-terminated ranges. + + std::forward_list<double> sup_flist = {10.0, 2.1, 8.0, 3.8, 4.0, 4.0, 7.0, 1.0}; + std::forward_list<int> sub_flist = {8, 4, 4, 1}; + + auto flist_indices = index_into(sub_flist, sup_flist); + ivector idx_flist = assign_from(flist_indices); + EXPECT_EQ((ivector{2, 4, 4, 7}), idx_flist); + + const char* hello_world = "hello world"; + const char* lol = "lol"; + auto sup_cstr = make_range(hello_world, testing::null_terminated); + auto sub_cstr = make_range(lol, testing::null_terminated); + auto cstr_indices = index_into(sub_cstr, sup_cstr); + ivector idx_cstr = assign_from(cstr_indices); + EXPECT_EQ((ivector{2, 4, 9}), idx_cstr); } TEST(algorithms, binary_find) diff --git a/tests/unit/test_backend.cpp b/tests/unit/test_backend.cpp index 441429221acf202ecec6f5cc7f12ef9fa64640a3..f3d846c15869b5d00a9dba2e1343f65794b41c65 100644 --- a/tests/unit/test_backend.cpp +++ b/tests/unit/test_backend.cpp @@ -1,23 +1,18 @@ #include <type_traits> -#include <backends/fvm.hpp> -#include <memory/memory.hpp> +#include <backends.hpp> +#include <fvm_lowered_cell.hpp> #include <util/config.hpp> #include "../gtest.h" -TEST(backends, gpu_is_null) { - using backend = arb::gpu::backend; +using namespace arb; - static_assert(std::is_same<backend, arb::null_backend>::value || arb::config::has_cuda, - "gpu back should be defined as null when compiling 
without gpu support."); - - if (not arb::config::has_cuda) { - EXPECT_FALSE(backend::is_supported()); - - EXPECT_FALSE(backend::has_mechanism("hh")); - EXPECT_THROW( - backend::make_mechanism("hh", 0, backend::const_iview(), backend::const_view(), backend::const_view(), backend::const_view(), backend::view(), backend::view(), {}, {}), - std::runtime_error); +TEST(backends, gpu_test) { + if (!arb::config::has_cuda) { + EXPECT_ANY_THROW(make_fvm_lowered_cell(backend_kind::gpu)); + } + else { + EXPECT_NO_THROW(make_fvm_lowered_cell(backend_kind::gpu)); } } diff --git a/tests/unit/test_cell.cpp b/tests/unit/test_cell.cpp index 58981bfa626550896e0adfed71bd1d1279bbdca3..fee64731befd8123c0dbb023dc006bdc371e7345 100644 --- a/tests/unit/test_cell.cpp +++ b/tests/unit/test_cell.cpp @@ -218,7 +218,7 @@ TEST(cell, clone) c.add_cable(1, section_kind::dendrite, 0.2, 0.15, 20); c.segment(2)->set_compartments(5); - c.add_synapse({1, 0.3}, mechanism_spec("expsyn")); + c.add_synapse({1, 0.3}, "expsyn"); c.add_detector({0, 0.5}, 10.0); diff --git a/tests/unit/test_fvm_layout.cpp b/tests/unit/test_fvm_layout.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd4774f6f42590b9135654c87078e3dcc913ce2b --- /dev/null +++ b/tests/unit/test_fvm_layout.cpp @@ -0,0 +1,579 @@ +#include <vector> + +#include <cell.hpp> +#include <fvm_layout.hpp> +#include <math.hpp> +#include <mechcat.hpp> +#include <util/maputil.hpp> +#include <util/optional.hpp> +#include <util/rangeutil.hpp> +#include <util/span.hpp> + +#include "common.hpp" +#include "../common_cells.hpp" + +using namespace arb; +using namespace testing::string_literals; + +using util::make_span; +using util::count_along; +using util::value_by_key; + +std::vector<cell> two_cell_system() { + std::vector<cell> cells; + + // Cell 0: simple ball and stick (see common_cells.hpp) + cells.push_back(make_cell_ball_and_stick()); + + // Cell 1: ball and 3-stick, but with uneven dendrite + // length and heterogeneous electrical 
properties: + // + // Bulk resistivity: 90 Ω·cm + // capacitance: + // soma: 0.01 F/m² [default] + // segment 1: 0.017 F/m² + // segment 2: 0.013 F/m² + // segment 3: 0.018 F/m² + // + // Soma diameter: 14 µm + // Soma mechanisms: HH (default params) + // + // Segment 1 diameter: 1 µm + // Segment 1 length: 200 µm + // + // Segment 2 diameter: 0.8 µm + // Segment 2 length: 300 µm + // + // Segment 3 diameter: 0.7 µm + // Segment 3 length: 180 µm + // + // Dendrite mechanisms: passive (default params). + // Stimulus at end of segment 2, amplitude 0.45. + // Stimulus at end of segment 3, amplitude -0.2. + // + // All dendrite segments with 4 compartments. + + cell c2; + segment* s; + + s = c2.add_soma(14./2); + s->add_mechanism("hh"); + + s = c2.add_cable(0, section_kind::dendrite, 1.0/2, 1.0/2, 200); + s->cm = 0.017; + + s = c2.add_cable(1, section_kind::dendrite, 0.8/2, 0.8/2, 300); + s->cm = 0.013; + + s = c2.add_cable(1, section_kind::dendrite, 0.7/2, 0.7/2, 180); + s->cm = 0.018; + + c2.add_stimulus({2,1}, {5., 80., 0.45}); + c2.add_stimulus({3,1}, {40., 10.,-0.2}); + + for (auto& seg: c2.segments()) { + seg->rL = 90.; + if (seg->is_dendrite()) { + seg->add_mechanism("pas"); + seg->set_compartments(4); + } + } + cells.push_back(std::move(c2)); + return cells; +} + +void check_two_cell_system(std::vector<cell>& cells) { + ASSERT_EQ(2u, cells[0].num_segments()); + ASSERT_EQ(cells[0].segment(1)->num_compartments(), 4u); + ASSERT_EQ(cells[1].num_segments(), 4u); + ASSERT_EQ(cells[1].segment(1)->num_compartments(), 4u); + ASSERT_EQ(cells[1].segment(2)->num_compartments(), 4u); + ASSERT_EQ(cells[1].segment(3)->num_compartments(), 4u); +} + +TEST(fvm_layout, topology) { + std::vector<cell> cells = two_cell_system(); + check_two_cell_system(cells); + + fvm_discretization D = fvm_discretize(cells); + + // Expected CV layouts for cells, segment indices in paren. 
+ // + // Cell 0: + // + // CV: | 0 | 1 | 2 | 3 | 4| + // [soma (0)][ segment (1) ] + // + // Cell 1: + // + // CV: | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13| + // [soma (2)][ segment (3) ][ segment (4) ] + // [ segment (5) ] + // | 14 | 15 | 16 | 17| + + EXPECT_EQ(2u, D.ncell); + EXPECT_EQ(18u, D.ncomp); + + unsigned nseg = 6; + EXPECT_EQ(nseg, D.segments.size()); + + // General sanity checks: + + ASSERT_EQ(D.ncell, D.cell_segment_part().size()); + ASSERT_EQ(D.ncell, D.cell_cv_part().size()); + + ASSERT_EQ(D.ncomp, D.parent_cv.size()); + ASSERT_EQ(D.ncomp, D.cv_to_cell.size()); + ASSERT_EQ(D.ncomp, D.face_conductance.size()); + ASSERT_EQ(D.ncomp, D.cv_area.size()); + ASSERT_EQ(D.ncomp, D.cv_capacitance.size()); + + // Partitions of CVs and segments by cell: + + using spair = std::pair<fvm_size_type, fvm_size_type>; + using ipair = std::pair<fvm_index_type, fvm_index_type>; + + EXPECT_EQ(spair(0, 2), D.cell_segment_part()[0]); + EXPECT_EQ(spair(2, nseg), D.cell_segment_part()[1]); + + EXPECT_EQ(ipair(0, 5), D.cell_cv_part()[0]); + EXPECT_EQ(ipair(5, D.ncomp), D.cell_cv_part()[1]); + + // Segment and CV parent relationships: + + using ivec = std::vector<fvm_index_type>; + + EXPECT_EQ(ivec({0,0,1,2,3,5,5,6,7,8,9,10,11,12,9,14,15,16}), D.parent_cv); + + EXPECT_FALSE(D.segments[0].has_parent()); + EXPECT_EQ(0, D.segments[1].parent_cv); + + EXPECT_FALSE(D.segments[2].has_parent()); + EXPECT_EQ(5, D.segments[3].parent_cv); + EXPECT_EQ(9, D.segments[4].parent_cv); + EXPECT_EQ(9, D.segments[5].parent_cv); + + // Segment CV ranges (half-open, excluding parent): + + EXPECT_EQ(ipair(0,1), D.segments[0].cv_range()); + EXPECT_EQ(ipair(1,5), D.segments[1].cv_range()); + EXPECT_EQ(ipair(5,6), D.segments[2].cv_range()); + EXPECT_EQ(ipair(6,10), D.segments[3].cv_range()); + EXPECT_EQ(ipair(10,14), D.segments[4].cv_range()); + EXPECT_EQ(ipair(14,18), D.segments[5].cv_range()); + + // CV to cell index: + + for (auto ci: make_span(D.ncell)) { + for (auto cv: 
make_span(D.cell_cv_part()[ci])) { + EXPECT_EQ(ci, (fvm_size_type)D.cv_to_cell[cv]); + } + } +} + +TEST(fvm_layout, area) { + std::vector<cell> cells = two_cell_system(); + check_two_cell_system(cells); + + fvm_discretization D = fvm_discretize(cells); + + // Note: stick models have constant diameter segments. + // Refer to comment above for CV vs. segment layout. + + std::vector<double> A; + for (auto ci: make_span(D.ncell)) { + for (auto si: make_span(cells[ci].num_segments())) { + A.push_back(cells[ci].segment(si)->area()); + } + } + + unsigned n = 4; // compartments per dendritic segment + EXPECT_FLOAT_EQ(A[0]+A[1]/(2*n), D.cv_area[0]); + EXPECT_FLOAT_EQ(A[1]/n, D.cv_area[1]); + EXPECT_FLOAT_EQ(A[1]/n, D.cv_area[2]); + EXPECT_FLOAT_EQ(A[1]/n, D.cv_area[3]); + EXPECT_FLOAT_EQ(A[1]/(2*n), D.cv_area[4]); + + EXPECT_FLOAT_EQ(A[2]+A[3]/(2*n), D.cv_area[5]); + EXPECT_FLOAT_EQ(A[3]/n, D.cv_area[6]); + EXPECT_FLOAT_EQ(A[3]/n, D.cv_area[7]); + EXPECT_FLOAT_EQ(A[3]/n, D.cv_area[8]); + EXPECT_FLOAT_EQ((A[3]+A[4]+A[5])/(2*n), D.cv_area[9]); + EXPECT_FLOAT_EQ(A[4]/n, D.cv_area[10]); + EXPECT_FLOAT_EQ(A[4]/n, D.cv_area[11]); + EXPECT_FLOAT_EQ(A[4]/n, D.cv_area[12]); + EXPECT_FLOAT_EQ(A[4]/(2*n), D.cv_area[13]); + EXPECT_FLOAT_EQ(A[5]/n, D.cv_area[14]); + EXPECT_FLOAT_EQ(A[5]/n, D.cv_area[15]); + EXPECT_FLOAT_EQ(A[5]/n, D.cv_area[16]); + EXPECT_FLOAT_EQ(A[5]/(2*n), D.cv_area[17]); + + // Confirm proportional allocation of surface capacitance: + + // CV 9 should have area-weighted sum of the specific + // capacitance from segments 3, 4 and 5 (cell 1 segments + // 1, 2 and 3 respectively). + + double cm1 = cells[1].segment(1)->cm; + double cm2 = cells[1].segment(2)->cm; + double cm3 = cells[1].segment(3)->cm; + + double c = A[3]/(2*n)*cm1+A[4]/(2*n)*cm2+A[5]/(2*n)*cm3; + EXPECT_FLOAT_EQ(c, D.cv_capacitance[9]); + + // CV 5 should be a weighted sum of soma and first segment + // capacitance from cell 1. 
+ + double cm0 = cells[1].soma()->cm; + c = A[2]*cm0+A[3]/(2*n)*cm1; + EXPECT_FLOAT_EQ(c, D.cv_capacitance[5]); + + // Confirm face conductance within a constant diameter + // segment equals a/h·1/rL where a is the cross sectional + // area, and h is the compartment length (given the + // regular discretization). + + cable_segment* cable = cells[1].segment(2)->as_cable(); + double a = cable->volume()/cable->length(); + EXPECT_FLOAT_EQ(math::pi<double>()*0.8*0.8/4, a); + + double h = cable->length()/4; + double g = a/h/cable->rL; // [µm·S/cm] + g *= 100; // [µS] + + EXPECT_FLOAT_EQ(g, D.face_conductance[11]); +} + +TEST(fvm_layout, mech_index) { + std::vector<cell> cells = two_cell_system(); + check_two_cell_system(cells); + + // Add four synapses of two varieties across the cells. + cells[0].add_synapse({1, 0.4}, "expsyn"); + cells[0].add_synapse({1, 0.4}, "expsyn"); + cells[1].add_synapse({2, 0.4}, "exp2syn"); + cells[1].add_synapse({3, 0.4}, "expsyn"); + + fvm_discretization D = fvm_discretize(cells); + fvm_mechanism_data M = fvm_build_mechanism_data(global_default_catalogue(), cells, D); + + auto& hh_config = M.mechanisms.at("hh"); + auto& expsyn_config = M.mechanisms.at("expsyn"); + auto& exp2syn_config = M.mechanisms.at("exp2syn"); + + using ivec = std::vector<fvm_index_type>; + using fvec = std::vector<fvm_value_type>; + + // HH on somas of two cells, with CVs 0 and 5. + // Proportional area contrib: soma area/CV area. + + EXPECT_EQ(mechanismKind::density, hh_config.kind); + EXPECT_EQ(ivec({0,5}), hh_config.cv); + + fvec norm_area({cells[0].soma()->area()/D.cv_area[0], cells[1].soma()->area()/D.cv_area[5]}); + EXPECT_TRUE(testing::seq_almost_eq<double>(norm_area, hh_config.norm_area)); + + // Three expsyn synapses, two 0.4 along segment 1, and one 0.4 along segment 5. + // 0.4 along => second (non-parent) CV for segment. + + EXPECT_EQ(ivec({2, 2, 15}), expsyn_config.cv); + + // One exp2syn synapse, 0.4 along segment 4. 
+ + EXPECT_EQ(ivec({11}), exp2syn_config.cv); + + // There should be a K and Na ion channel associated with each + // hh mechanism node. + + ASSERT_EQ(1u, M.ions.count(ionKind::na)); + ASSERT_EQ(1u, M.ions.count(ionKind::k)); + EXPECT_EQ(0u, M.ions.count(ionKind::ca)); + + EXPECT_EQ(ivec({0,5}), M.ions.at(ionKind::na).cv); + EXPECT_EQ(ivec({0,5}), M.ions.at(ionKind::k).cv); +} + +TEST(fvm_layout, synapse_targets) { + std::vector<cell> cells = two_cell_system(); + + // Add synapses with different parameter values so that we can + // ensure: 1) CVs for each synapse mechanism are sorted while + // 2) the target index for each synapse corresponds to the + // original ordering. + + const unsigned nsyn = 7; + std::vector<double> syn_e(nsyn); + for (auto i: count_along(syn_e)) { + syn_e[i] = 0.1*(1+i); + } + + auto syn_desc = [&](const char* name, int idx) { + return mechanism_desc(name).set("e", syn_e.at(idx)); + }; + + cells[0].add_synapse({1, 0.9}, syn_desc("expsyn", 0)); + cells[0].add_synapse({0, 0.5}, syn_desc("expsyn", 1)); + cells[0].add_synapse({1, 0.4}, syn_desc("expsyn", 2)); + + cells[1].add_synapse({2, 0.4}, syn_desc("exp2syn", 3)); + cells[1].add_synapse({1, 0.4}, syn_desc("exp2syn", 4)); + cells[1].add_synapse({3, 0.4}, syn_desc("expsyn", 5)); + cells[1].add_synapse({3, 0.7}, syn_desc("exp2syn", 6)); + + fvm_discretization D = fvm_discretize(cells); + fvm_mechanism_data M = fvm_build_mechanism_data(global_default_catalogue(), cells, D); + + ASSERT_EQ(1u, M.mechanisms.count("expsyn")); + ASSERT_EQ(1u, M.mechanisms.count("exp2syn")); + + auto& expsyn_cv = M.mechanisms.at("expsyn").cv; + auto& expsyn_target = M.mechanisms.at("expsyn").target; + auto& expsyn_e = value_by_key(M.mechanisms.at("expsyn").param_values, "e"_s).value(); + + auto& exp2syn_cv = M.mechanisms.at("exp2syn").cv; + auto& exp2syn_target = M.mechanisms.at("exp2syn").target; + auto& exp2syn_e = value_by_key(M.mechanisms.at("exp2syn").param_values, "e"_s).value(); + + 
EXPECT_TRUE(util::is_sorted(expsyn_cv)); + EXPECT_TRUE(util::is_sorted(exp2syn_cv)); + + using uvec = std::vector<fvm_size_type>; + uvec all_target_indices; + util::append(all_target_indices, expsyn_target); + util::append(all_target_indices, exp2syn_target); + util::sort(all_target_indices); + + uvec nsyn_iota; + util::assign(nsyn_iota, make_span(nsyn)); + EXPECT_EQ(nsyn_iota, all_target_indices); + + for (auto i: count_along(expsyn_target)) { + EXPECT_EQ(syn_e[expsyn_target[i]], expsyn_e[i]); + } + + for (auto i: count_along(exp2syn_target)) { + EXPECT_EQ(syn_e[exp2syn_target[i]], exp2syn_e[i]); + } +} + + +// TODO: migrate tests for proportional parameter setting. + + +namespace { + double wm_impl(double wa, double xa) { + return wa? xa/wa: 0; + } + + template <typename... R> + double wm_impl(double wa, double xa, double w, double x, R... rest) { + return wm_impl(wa+w, xa+w*x, rest...); + } + + // Computed weighted mean (w*x + ...) / (w + ...). + template <typename... R> + double wmean(double w, double x, R... rest) { + return wm_impl(w, w*x, rest...); + } +} + +TEST(fvm_layout, density_norm_area) { + // Test area-weighted linear combination of density mechanism parameters. + + // Create a cell with 4 segments: + // - Soma (segment 0) plus three dendrites (1, 2, 3) meeting at a branch point. + // - HH mechanism on all segments. + // - Dendritic segments are given 3 compartments each. + // + // The CV corresponding to the branch point should comprise the terminal + // 1/6 of segment 1 and the initial 1/6 of segments 2 and 3. + // + // The HH mechanism current density parameters ('gnabar', 'gkbar' and 'gl') are set + // differently for each segment: + // + // soma: all default values (gnabar = 0.12, gkbar = .036, gl = .0003) + // segment 1: gl = .0002 + // segment 2: gkbar = .05 + // segment 3: gkbar = .0004, gl = .0004 + // + // Geometry: + // segment 1: 100 µm long, 1 µm diameter cylinder. + // segment 2: 200 µm long, diameter linear taper from 1 µm to 0.2 µm.
+ // segment 3: 150 µm long, 0.8 µm diameter cylinder. + // + // Use divided compartment view on segments to compute area contributions. + + std::vector<cell> cells(1); + cell& c = cells[0]; + auto soma = c.add_soma(12.6157/2.0); + + c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); + c.add_cable(1, section_kind::dendrite, 0.5, 0.1, 200); + c.add_cable(1, section_kind::dendrite, 0.4, 0.4, 150); + + auto& segs = c.segments(); + + double dflt_gkbar = .036; + double dflt_gl = 0.0003; + + double seg1_gl = .0002; + double seg2_gkbar = .05; + double seg3_gkbar = .0004; + double seg3_gl = .0004; + + for (int i = 0; i<4; ++i) { + segment& seg = *segs[i]; + seg.set_compartments(3); + + mechanism_desc hh("hh"); + switch (i) { + case 1: + hh["gl"] = seg1_gl; + break; + case 2: + hh["gkbar"] = seg2_gkbar; + break; + case 3: + hh["gkbar"] = seg3_gkbar; + hh["gl"] = seg3_gl; + break; + default: ; + } + seg.add_mechanism(hh); + } + + int ncv = 10; + std::vector<double> expected_gkbar(ncv, dflt_gkbar); + std::vector<double> expected_gl(ncv, dflt_gl); + + double soma_area = soma->area(); + auto seg1_divs = div_compartments<div_compartment_by_ends>(segs[1]->as_cable()); + auto seg2_divs = div_compartments<div_compartment_by_ends>(segs[2]->as_cable()); + auto seg3_divs = div_compartments<div_compartment_by_ends>(segs[3]->as_cable()); + + // CV 0: mix of soma and left of segment 1 + expected_gl[0] = wmean(soma_area, dflt_gl, seg1_divs(0).left.area, seg1_gl); + + expected_gl[1] = seg1_gl; + expected_gl[2] = seg1_gl; + + // CV 3: mix of right of segment 1 and left of segments 2 and 3. 
+ expected_gkbar[3] = wmean(seg1_divs(2).right.area, dflt_gkbar, seg2_divs(0).left.area, seg2_gkbar, seg3_divs(0).left.area, seg3_gkbar); + expected_gl[3] = wmean(seg1_divs(2).right.area, seg1_gl, seg2_divs(0).left.area, dflt_gl, seg3_divs(0).left.area, seg3_gl); + + // CV 4-6: just segment 2 + expected_gkbar[4] = seg2_gkbar; + expected_gkbar[5] = seg2_gkbar; + expected_gkbar[6] = seg2_gkbar; + + // CV 7-9: just segment 3 + expected_gkbar[7] = seg3_gkbar; + expected_gkbar[8] = seg3_gkbar; + expected_gkbar[9] = seg3_gkbar; + expected_gl[7] = seg3_gl; + expected_gl[8] = seg3_gl; + expected_gl[9] = seg3_gl; + + fvm_discretization D = fvm_discretize(cells); + fvm_mechanism_data M = fvm_build_mechanism_data(global_default_catalogue(), cells, D); + + // Check CV area assumptions. + // Note: area integrator used here and in `fvm_multicell` may differ, and so areas computed may + // differ some due to rounding error, even given that we're dealing with simple truncated cones + // for segments. Check relative error within a tolerance of (say) 10 epsilon. + + double area_relerr = 10*std::numeric_limits<double>::epsilon(); + EXPECT_TRUE(testing::near_relative(D.cv_area[0], + soma_area+seg1_divs(0).left.area, area_relerr)); + EXPECT_TRUE(testing::near_relative(D.cv_area[1], + seg1_divs(0).right.area+seg1_divs(1).left.area, area_relerr)); + EXPECT_TRUE(testing::near_relative(D.cv_area[3], + seg1_divs(2).right.area+seg2_divs(0).left.area+seg3_divs(0).left.area, area_relerr)); + EXPECT_TRUE(testing::near_relative(D.cv_area[6], + seg2_divs(2).right.area, area_relerr)); + + // Grab the HH parameters from the mechanism.
+ + EXPECT_EQ(1u, M.mechanisms.size()); + ASSERT_EQ(1u, M.mechanisms.count("hh")); + auto& hh_params = M.mechanisms.at("hh").param_values; + + auto& gkbar = value_by_key(hh_params, "gkbar"_s).value(); + auto& gl = value_by_key(hh_params, "gl"_s).value(); + + EXPECT_TRUE(testing::seq_almost_eq<double>(expected_gkbar, gkbar)); + EXPECT_TRUE(testing::seq_almost_eq<double>(expected_gl, gl)); +} + +TEST(fvm_layout, ion_weights) { + // Create a cell with 4 segments: + // - Soma (segment 0) plus three dendrites (1, 2, 3) meeting at a branch point. + // - Dendritic segments are given 1 compartment each. + // + // / + // d2 + // / + // s0-d1 + // \. + // d3 + // + // The CV corresponding to the branch point should comprise the terminal + // 1/2 of segment 1 and the initial 1/2 of segments 2 and 3. + // + // Geometry: + // soma 0: radius 5 µm + // dend 1: 100 µm long, 1 µm diameter cylinder + // dend 2: 200 µm long, 1 µm diameter cylinder + // dend 3: 100 µm long, 1 µm diameter cylinder + // + // The radius of the soma is chosen such that the surface area of soma is + // the same as a 100µm dendrite, which makes it easier to describe the + // expected weights.
+ + auto construct_cell = [](cell& c) { + c.add_soma(5); + + c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); + c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 200); + c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 100); + + for (auto& s: c.segments()) s->set_compartments(1); + }; + + using uvec = std::vector<fvm_size_type>; + using ivec = std::vector<fvm_index_type>; + using fvec = std::vector<fvm_value_type>; + + uvec mech_segs[] = { + {0}, {0,2}, {2, 3}, {0, 1, 2, 3}, {3} + }; + + ivec expected_ion_cv[] = { + {0}, {0, 1, 2}, {1, 2, 3}, {0, 1, 2, 3}, {1, 3} + }; + + fvec expected_iconc_norm_area[] = { + {1./3}, {1./3, 1./2, 0.}, {1./4, 0., 0.}, {0., 0., 0., 0.}, {3./4, 0.} + }; + + for (auto run: count_along(mech_segs)) { + std::vector<cell> cells(1); + cell& c = cells[0]; + construct_cell(c); + + for (auto i: mech_segs[run]) { + c.segments()[i]->add_mechanism("test_ca"); + } + + fvm_discretization D = fvm_discretize(cells); + fvm_mechanism_data M = fvm_build_mechanism_data(global_default_catalogue(), cells, D); + + ASSERT_EQ(1u, M.ions.count(ionKind::ca)); + auto& ca = M.ions.at(ionKind::ca); + + EXPECT_EQ(expected_ion_cv[run], ca.cv); + EXPECT_TRUE(testing::seq_almost_eq<fvm_value_type>(expected_iconc_norm_area[run], ca.iconc_norm_area)); + + EXPECT_TRUE(util::all_of(ca.econc_norm_area, [](fvm_value_type v) { return v==1.; })); + } +} diff --git a/tests/unit/test_fvm_lowered.cpp b/tests/unit/test_fvm_lowered.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72f305c0c8f68abb041043f7f017c635678eb564 --- /dev/null +++ b/tests/unit/test_fvm_lowered.cpp @@ -0,0 +1,439 @@ +#include <vector> + +#include "../gtest.h" + +#include <algorithms.hpp> +#include <backends/fvm_types.hpp> +#include <backends/multicore/fvm.hpp> +#include <backends/multicore/mechanism.hpp> +#include <cell.hpp> +#include <common_types.hpp> +#include <fvm_lowered_cell.hpp> +#include <fvm_lowered_cell_impl.hpp> +#include <load_balance.hpp> +#include <math.hpp> 
+#include <simulation.hpp> +#include <recipe.hpp> +#include <sampler_map.hpp> +#include <sampling.hpp> +#include <schedule.hpp> +#include <segment.hpp> +#include <util/meta.hpp> +#include <util/maputil.hpp> +#include <util/rangeutil.hpp> + +#include "common.hpp" +#include "../common_cells.hpp" +#include "../simple_recipes.hpp" + +using namespace testing::string_literals; + +using backend = arb::multicore::backend; +using fvm_cell = arb::fvm_lowered_cell_impl<backend>; + +// Access to fvm_cell private data: + +using shared_state = backend::shared_state; +ACCESS_BIND(std::unique_ptr<shared_state> fvm_cell::*, private_state_ptr, &fvm_cell::state_) + +using matrix = arb::matrix<arb::multicore::backend>; +ACCESS_BIND(matrix fvm_cell::*, private_matrix_ptr, &fvm_cell::matrix_) + +ACCESS_BIND(std::vector<arb::mechanism_ptr> fvm_cell::*, private_mechanisms_ptr, &fvm_cell::mechanisms_) + +arb::mechanism* find_mechanism(fvm_cell& fvcell, const std::string& name) { + for (auto& mech: fvcell.*private_mechanisms_ptr) { + if (mech->internal_name()==name) { + return mech.get(); + } + } + return nullptr; +} + +// Access to mechanism-internal data: + +using mechanism_global_table = std::vector<std::pair<const char*, arb::fvm_value_type*>>; +using mechanism_field_table = std::vector<std::pair<const char*, arb::fvm_value_type**>>; +using mechanism_ion_index_table = std::vector<std::pair<arb::ionKind, backend::iarray*>>; + +ACCESS_BIND(\ + mechanism_global_table (arb::multicore::mechanism::*)(),\ + private_global_table_ptr,\ + &arb::multicore::mechanism::global_table) + +ACCESS_BIND(\ + mechanism_field_table (arb::multicore::mechanism::*)(),\ + private_field_table_ptr,\ + &arb::multicore::mechanism::field_table) + +ACCESS_BIND(\ + mechanism_ion_index_table (arb::multicore::mechanism::*)(),\ + private_ion_index_table_ptr,\ + &arb::multicore::mechanism::ion_index_table) + + +// TODO: C++14 replace use with generic lambda +struct generic_isnan { + template <typename V> + bool 
operator()(V& v) const { return std::isnan(v); } +} isnan_; + +using namespace arb; + +TEST(fvm_lowered, matrix_init) +{ + algorithms::generic_is_positive ispos; + algorithms::generic_is_negative isneg; + + arb::cell cell = make_cell_ball_and_stick(); + + ASSERT_EQ(2u, cell.num_segments()); + cell.segment(1)->set_compartments(10); + + std::vector<target_handle> targets; + probe_association_map<probe_handle> probe_map; + + fvm_cell fvcell; + fvcell.initialize({0}, cable1d_recipe(cell), targets, probe_map); + + auto& J = fvcell.*private_matrix_ptr; + EXPECT_EQ(J.size(), 11u); + + // Test that the matrix is initialized with sensible values + + fvcell.integrate(0.01, 0.01, {}, {}); + + auto n = J.size(); + auto& mat = J.state_; + + EXPECT_FALSE(util::any_of(util::subrange_view(mat.u, 1, n), isnan_)); + EXPECT_FALSE(util::any_of(mat.d, isnan_)); + EXPECT_FALSE(util::any_of(J.solution(), isnan_)); + + EXPECT_FALSE(util::any_of(util::subrange_view(mat.u, 1, n), ispos)); + EXPECT_FALSE(util::any_of(mat.d, isneg)); +} + +TEST(fvm_lowered, target_handles) { + using namespace arb; + + arb::cell cells[] = { + make_cell_ball_and_stick(), + make_cell_ball_and_3stick() + }; + + EXPECT_EQ(cells[0].num_segments(), 2u); + EXPECT_EQ(cells[1].num_segments(), 4u); + + // (in increasing target order) + cells[0].add_synapse({1, 0.4}, "expsyn"); + cells[0].add_synapse({0, 0.5}, "expsyn"); + cells[1].add_synapse({2, 0.2}, "exp2syn"); + cells[1].add_synapse({2, 0.8}, "expsyn"); + + cells[1].add_detector({0, 0}, 3.3); + + std::vector<target_handle> targets; + probe_association_map<probe_handle> probe_map; + + fvm_cell fvcell; + fvcell.initialize({0, 1}, cable1d_recipe(cells), targets, probe_map); + + mechanism* expsyn = find_mechanism(fvcell, "expsyn"); + ASSERT_TRUE(expsyn); + mechanism* exp2syn = find_mechanism(fvcell, "exp2syn"); + ASSERT_TRUE(exp2syn); + + unsigned expsyn_id = expsyn->mechanism_id(); + unsigned exp2syn_id = exp2syn->mechanism_id(); + + EXPECT_EQ(4u, targets.size()); + + 
EXPECT_EQ(expsyn_id, targets[0].mech_id); + EXPECT_EQ(1u, targets[0].mech_index); + EXPECT_EQ(0u, targets[0].cell_index); + + EXPECT_EQ(expsyn_id, targets[1].mech_id); + EXPECT_EQ(0u, targets[1].mech_index); + EXPECT_EQ(0u, targets[1].cell_index); + + EXPECT_EQ(exp2syn_id, targets[2].mech_id); + EXPECT_EQ(0u, targets[2].mech_index); + EXPECT_EQ(1u, targets[2].cell_index); + + EXPECT_EQ(expsyn_id, targets[3].mech_id); + EXPECT_EQ(2u, targets[3].mech_index); + EXPECT_EQ(1u, targets[3].cell_index); +} + +TEST(fvm_lowered, stimulus) { + // Ball-and-stick with two stimuli: + // + // |stim0 |stim1 + // ----------------------- + // delay | 5 | 1 + // duration | 80 | 2 + // amplitude | 0.3 | 0.1 + // CV | 4 | 0 + + std::vector<cell> cells; + cells.push_back(make_cell_ball_and_stick(false)); + + cells[0].add_stimulus({1,1}, {5., 80., 0.3}); + cells[0].add_stimulus({0,0.5}, {1., 2., 0.1}); + + const fvm_size_type soma_cv = 0u; + const fvm_size_type tip_cv = 4u; + + // now we have two stims : + // + // + // The implementation of the stimulus is tested by creating a lowered cell, then + // testing that the correct currents are injected at the correct control volumes + // during the stimulus windows. + + fvm_discretization D = fvm_discretize(cells); + const auto& A = D.cv_area; + + std::vector<target_handle> targets; + probe_association_map<probe_handle> probe_map; + + fvm_cell fvcell; + fvcell.initialize({0}, cable1d_recipe(cells), targets, probe_map); + + mechanism* stim = find_mechanism(fvcell, "_builtin_stimulus"); + ASSERT_TRUE(stim); + EXPECT_EQ(2u, stim->size()); + + auto& state = *(fvcell.*private_state_ptr).get(); + auto& J = state.current_density; + auto& T = state.time; + + // Test that no current is injected at t=0. + memory::fill(J, 0.); + memory::fill(T, 0.); + stim->nrn_current(); + + for (auto j: J) { + EXPECT_EQ(j, 0.); + } + + // Test that 0.1 nA current is injected at soma at t=1.
+ memory::fill(J, 0.); + memory::fill(T, 1.); + stim->nrn_current(); + constexpr double unit_factor = 1e-3; // scale A/m²·µm² to nA + EXPECT_DOUBLE_EQ(-0.1, J[soma_cv]*A[soma_cv]*unit_factor); + + // Test that 0.1 nA is again injected at t=1.5, for a total of 0.2 nA (J is not cleared between calls). + memory::fill(T, 1.5); + stim->nrn_current(); + EXPECT_DOUBLE_EQ(-0.2, J[soma_cv]*A[soma_cv]*unit_factor); + + // Test that at t=10, no more current is injected at soma, and that + // 0.3 nA is injected at dendrite tip. + memory::fill(T, 10.); + stim->nrn_current(); + EXPECT_DOUBLE_EQ(-0.2, J[soma_cv]*A[soma_cv]*unit_factor); + EXPECT_DOUBLE_EQ(-0.3, J[tip_cv]*A[tip_cv]*unit_factor); +} + +// Test derived mechanism behaviour. + +TEST(fvm_lowered, derived_mechs) { + // Create ball and stick cells with the 'test_kin1' mechanism, which produces + // a voltage-independent current density of the form a + exp(-t/tau) as a function + // of time t. + // + // 1. Default 'test_kin1': tau = 10 [ms]. + // + // 2. Specialized version 'custom_kin1' with tau = 20 [ms]. + // + // 3. Cell with both test_kin1 and custom_kin1. + + std::vector<cell> cells(3); + for (int i = 0; i<3; ++i) { + cell& c = cells[i]; + c.add_soma(6.0); + c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); + + c.segment(1)->set_compartments(4); + for (auto& seg: c.segments()) { + if (!seg->is_soma()) { + seg->as_cable()->set_compartments(4); + } + switch (i) { + case 0: + seg->add_mechanism("test_kin1"); + break; + case 1: + seg->add_mechanism("custom_kin1"); + break; + case 2: + seg->add_mechanism("test_kin1"); + seg->add_mechanism("custom_kin1"); + break; + } + } + } + + cable1d_recipe rec(cells); + rec.catalogue().derive("custom_kin1", "test_kin1", {{"tau", 20.0}}); + + cell_probe_address where{{1, 0.3}, cell_probe_address::membrane_current}; + rec.add_probe(0, 0, where); + rec.add_probe(1, 0, where); + rec.add_probe(2, 0, where); + + { + // Test initialization and global parameter values.
+ + std::vector<target_handle> targets; + probe_association_map<probe_handle> probe_map; + + fvm_cell fvcell; + fvcell.initialize({0, 1, 2}, rec, targets, probe_map); + + // Both mechanisms will have the same internal name, "test_kin1". + + using fvec = std::vector<fvm_value_type>; + fvec tau_values; + for (auto& mech: fvcell.*private_mechanisms_ptr) { + EXPECT_EQ("test_kin1"_s, mech->internal_name()); + + auto cmech = dynamic_cast<multicore::mechanism*>(mech.get()); + ASSERT_TRUE(cmech); + + auto opt_tau_ptr = util::value_by_key((cmech->*private_global_table_ptr)(), "tau"_s); + ASSERT_TRUE(opt_tau_ptr); + tau_values.push_back(*opt_tau_ptr.value()); + } + util::sort(tau_values); + EXPECT_EQ(fvec({10., 20.}), tau_values); + } + + { + // Test dynamics: + // 1. Current at same point on cell 0 at time 10 ms should equal that + // on cell 1 at time 20 ms. + // 2. Current for cell 2 should be sum of currents for cells 0 and 1 at any given time. + + std::vector<double> samples[3]; + + sampler_function sampler = [&](cell_member_type pid, probe_tag, std::size_t n, const sample_record* records) { + for (std::size_t i = 0; i<n; ++i) { + double v = *util::any_cast<const double*>(records[i].data); + samples[pid.gid].push_back(v); + } + }; + + float times[] = {10.f, 20.f}; + + auto decomp = partition_load_balance(rec, hw::node_info{1u, 0u}); + simulation sim(rec, decomp); + sim.add_sampler(all_probes, explicit_schedule(times), sampler); + sim.run(30.0, 1.f/1024); + + ASSERT_EQ(2u, samples[0].size()); + ASSERT_EQ(2u, samples[1].size()); + ASSERT_EQ(2u, samples[2].size()); + + // Integration isn't exact: let's aim for one part in 10'000. 
+ double relerr = 0.0001; + EXPECT_TRUE(testing::near_relative(samples[0][0], samples[1][1], relerr)); + EXPECT_TRUE(testing::near_relative(samples[0][0]+samples[1][0], samples[2][0], relerr)); + EXPECT_TRUE(testing::near_relative(samples[0][1]+samples[1][1], samples[2][1], relerr)); + } +} + +// Test area-weighted linear combination of ion species concentrations + +TEST(fvm_lowered, weighted_write_ion) { + // Create a cell with 4 segments (same morphology as in fvm_layout.ion_weights test): + // - Soma (segment 0) plus three dendrites (1, 2, 3) meeting at a branch point. + // - Dendritic segments are given 1 compartment each. + // + // / + // d2 + // / + // s0-d1 + // \. + // d3 + // + // The CV corresponding to the branch point should comprise the terminal + // 1/2 of segment 1 and the initial 1/2 of segments 2 and 3. + // + // Geometry: + // soma 0: radius 5 µm + // dend 1: 100 µm long, 1 µm diameter cylinder + // dend 2: 200 µm long, 1 µm diameter cylinder + // dend 3: 100 µm long, 1 µm diameter cylinder + // + // The radius of the soma is chosen such that the surface area of soma is + // the same as a 100µm dendrite, which makes it easier to describe the + // expected weights.
+ + cell c; + c.add_soma(5); + + c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); + c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 200); + c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 100); + + for (auto& s: c.segments()) s->set_compartments(1); + + const double con_int = 80; + const double con_ext = 120; + + // Ca ion reader test_kinlva on CV 1 and 2 via segment 2: + c.segments()[2] ->add_mechanism("test_kinlva"); + + // Ca ion writer test_ca on CV 1 and 3 via segment 3: + c.segments()[3] ->add_mechanism("test_ca"); + + std::vector<target_handle> targets; + probe_association_map<probe_handle> probe_map; + + fvm_cell fvcell; + fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); + + auto& state = *(fvcell.*private_state_ptr).get(); + auto& ion = state.ion_data.at(ionKind::ca); + ion.default_int_concentration = con_int; + ion.default_ext_concentration = con_ext; + ion.init_concentration(); + + std::vector<unsigned> ion_nodes = util::assign_from(ion.node_index_); + std::vector<unsigned> expected_ion_nodes = {1, 2, 3}; + EXPECT_EQ(expected_ion_nodes, ion_nodes); + + std::vector<double> ion_iconc_weights = util::assign_from(ion.weight_Xi_); + std::vector<double> expected_ion_iconc_weights = {0.75, 1., 0}; + EXPECT_EQ(expected_ion_iconc_weights, ion_iconc_weights); + + auto test_ca = dynamic_cast<multicore::mechanism*>(find_mechanism(fvcell, "test_ca")); + + auto opt_cai_ptr = util::value_by_key((test_ca->*private_field_table_ptr)(), "cai"_s); + ASSERT_TRUE(opt_cai_ptr); + auto& test_ca_cai = *opt_cai_ptr.value(); + + auto opt_ca_index_ptr = util::value_by_key((test_ca->*private_ion_index_table_ptr)(), ionKind::ca); + ASSERT_TRUE(opt_ca_index_ptr); + auto& test_ca_ca_index = *opt_ca_index_ptr.value(); + + double cai_contrib[3] = {200., 0., 300.}; + for (int i = 0; i<2; ++i) { + test_ca_cai[i] = cai_contrib[test_ca_ca_index[i]]; + } + + std::vector<double> expected_iconc(3); + for (int i = 0; i<3; ++i) { + expected_iconc[i] = 
math::lerp(cai_contrib[i], con_int, ion_iconc_weights[i]); + } + + ion.init_concentration(); + test_ca->write_ions(); + std::vector<double> ion_iconc = util::assign_from(ion.Xi_); + EXPECT_EQ(expected_iconc, ion_iconc); +} + diff --git a/tests/unit/test_fvm_multi.cpp b/tests/unit/test_fvm_multi.cpp deleted file mode 100644 index 0245ae1e7af2f93ae99b723dab2d950fc83dad63..0000000000000000000000000000000000000000 --- a/tests/unit/test_fvm_multi.cpp +++ /dev/null @@ -1,953 +0,0 @@ -#include <vector> - -#include "../gtest.h" - -#include <backends/multicore/fvm.hpp> -#include <cell.hpp> -#include <common_types.hpp> -#include <fvm_multicell.hpp> -#include <load_balance.hpp> -#include <math.hpp> -#include <simulation.hpp> -#include <recipe.hpp> -#include <sampler_map.hpp> -#include <sampling.hpp> -#include <schedule.hpp> -#include <segment.hpp> -#include <util/meta.hpp> -#include <util/rangeutil.hpp> - -#include "common.hpp" -#include "../common_cells.hpp" -#include "../simple_recipes.hpp" - -#include "mechanisms/multicore/test_ca_cpu.hpp" - -using fvm_cell = - arb::fvm::fvm_multicell<arb::multicore::backend>; - -TEST(fvm_multi, cable) -{ - using namespace arb; - - arb::cell cell=make_cell_ball_and_3stick(); - - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0}, cable1d_recipe(cell), targets, probe_map); - - auto& J = fvcell.jacobian(); - - // 1 (soma) + 3 (dendritic segments) × 4 compartments - EXPECT_EQ(cell.num_compartments(), 13u); - - // assert that the matrix has one row for each compartment - EXPECT_EQ(J.size(), cell.num_compartments()); - - // assert that the number of cv areas is the same as the matrix size - // i.e. 
both should equal the number of compartments - EXPECT_EQ(fvcell.cv_areas().size(), J.size()); -} - -TEST(fvm_multi, init) -{ - using namespace arb; - - arb::cell cell = make_cell_ball_and_stick(); - - const auto m = cell.model(); - EXPECT_EQ(m.tree.num_segments(), 2u); - - auto& soma_hh = (cell.soma()->mechanism("hh")).value(); - - soma_hh.set("gnabar", 0.12); - soma_hh.set("gkbar", 0.036); - soma_hh.set("gl", 0.0003); - soma_hh.set("el", -54.3); - - cell.segment(1)->set_compartments(10); - - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0}, cable1d_recipe(cell), targets, probe_map); - - // This is naughty: removing const from the matrix reference, but is needed - // to test the build_matrix() method below (which is only accessable - // through non-const interface). - //auto& J = const_cast<fvm_cell::matrix_type&>(fvcell.jacobian()); - auto& J = fvcell.jacobian(); - EXPECT_EQ(J.size(), 11u); - - // test that the matrix is initialized with sensible values - //J.build_matrix(0.01); - fvcell.setup_integration(0.01, 0.01, {}, {}); - fvcell.step_integration(); - - auto& mat = J.state_; - auto test_nan = [](decltype(mat.u) v) { - for(auto val : v) if(val != val) return false; - return true; - }; - EXPECT_TRUE(test_nan(mat.u(1, J.size()))); - EXPECT_TRUE(test_nan(mat.d)); - EXPECT_TRUE(test_nan(J.solution())); - - // test matrix diagonals for sign - auto is_pos = [](decltype(mat.u) v) { - for(auto val : v) if(val<=0.) return false; - return true; - }; - auto is_neg = [](decltype(mat.u) v) { - for(auto val : v) if(val>=0.) 
return false; - return true; - }; - EXPECT_TRUE(is_neg(mat.u(1, J.size()))); - EXPECT_TRUE(is_pos(mat.d)); -} - -TEST(fvm_multi, multi_init) -{ - using namespace arb; - - arb::cell cells[] = { - make_cell_ball_and_stick(), - make_cell_ball_and_3stick() - }; - - EXPECT_EQ(cells[0].num_segments(), 2u); - EXPECT_EQ(cells[0].segment(1)->num_compartments(), 4u); - EXPECT_EQ(cells[1].num_segments(), 4u); - EXPECT_EQ(cells[1].segment(1)->num_compartments(), 4u); - EXPECT_EQ(cells[1].segment(2)->num_compartments(), 4u); - EXPECT_EQ(cells[1].segment(3)->num_compartments(), 4u); - - cells[0].add_synapse({1, 0.4}, "expsyn"); - cells[0].add_synapse({1, 0.4}, "expsyn"); - cells[1].add_synapse({2, 0.4}, "exp2syn"); - cells[1].add_synapse({3, 0.4}, "expsyn"); - - cells[1].add_detector({0, 0}, 3.3); - - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0, 1}, cable1d_recipe(cells), targets, probe_map); - - EXPECT_EQ(4u, targets.size()); - - auto& J = fvcell.jacobian(); - EXPECT_EQ(J.size(), 5u+13u); - - // check indices in instantiated mechanisms - for (const auto& mech: fvcell.mechanisms()) { - if (mech->name()=="hh") { - // HH on somas of two cells, with group compartment indices - // 0 and 5. - ASSERT_EQ(mech->node_index().size(), 2u); - EXPECT_EQ(mech->node_index()[0], 0u); - EXPECT_EQ(mech->node_index()[1], 5u); - } - if (mech->name()=="expsyn") { - // Three expsyn synapses, two in second compartment - // of dendrite segment of first cell, one in second compartment - // of last segment of second cell. - ASSERT_EQ(mech->node_index().size(), 3u); - EXPECT_EQ(mech->node_index()[0], 2u); - EXPECT_EQ(mech->node_index()[1], 2u); - EXPECT_EQ(mech->node_index()[2], 15u); - } - if (mech->name()=="exp2syn") { - // One exp2syn synapse, in second compartment - // of penultimate segment of second cell. 
- ASSERT_EQ(mech->node_index().size(), 1u); - EXPECT_EQ(mech->node_index()[0], 11u); - } - } -} - -// test that stimuli are added correctly -TEST(fvm_multi, stimulus) -{ - using namespace arb; - - // the default ball and stick has one stimulus at the terminal end of the dendrite - auto cell = make_cell_ball_and_stick(); - - // ... so add a second at the soma to make things more interesting - cell.add_stimulus({0,0.5}, {1., 2., 0.1}); - - // now we have two stims : - // - // |stim0 |stim1 - // ----------------------- - // delay | 5 | 1 - // duration | 80 | 2 - // amplitude | 0.3 | 0.1 - // CV | 4 | 0 - // - // The implementation of the stimulus is tested by creating a lowered cell, then - // testing that the correct currents are injected at the correct control volumes - // as during the stimulus windows. - - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0}, cable1d_recipe(cell), targets, probe_map); - - auto ref = fvcell.find_mechanism("stimulus"); - ASSERT_TRUE(ref) << "no stimuli retrieved from lowered fvm cell: expected 2"; - - auto& stims = ref.value(); - EXPECT_EQ(stims->size(), 2u); - - auto I = fvcell.current(); - auto A = fvcell.cv_areas(); - - auto soma_idx = 0u; - auto dend_idx = 4u; - - // test 1: Test that no current is injected at t=0 - memory::fill(I, 0.); - fvcell.set_time_global(0.); - fvcell.set_time_to_global(0.1); - stims->set_params(); - stims->nrn_current(); - for (auto i: I) { - EXPECT_EQ(i, 0.); - } - - // test 2: Test that current is injected at soma at t=1 - fvcell.set_time_global(1.); - fvcell.set_time_to_global(1.1); - stims->nrn_current(); - // take care to convert from A.m^-2 to nA - EXPECT_EQ(I[soma_idx]/(1e3/A[soma_idx]), -0.1); - - // test 3: Test that current is still injected at soma at t=1.5. 
- // Note that we test for injection of -0.2, because the - // current contributions are accumulative, and the current - // values have not been cleared since the last update. - fvcell.set_time_global(1.5); - fvcell.set_time_to_global(1.6); - stims->set_params(); - stims->nrn_current(); - EXPECT_EQ(I[soma_idx]/(1e3/A[soma_idx]), -0.2); - - // test 4: test at t=10ms, when the the soma stim is not active, and - // dendrite stimulus is injecting a current of 0.3 nA - fvcell.set_time_global(10.); - fvcell.set_time_to_global(10.1); - stims->nrn_current(); - EXPECT_EQ(I[soma_idx]/(1e3/A[soma_idx]), -0.2); - EXPECT_EQ(I[dend_idx]/(1e3/A[dend_idx]), -0.3); -} - -// test that mechanism indexes are computed correctly -TEST(fvm_multi, mechanism_indexes) -{ - using namespace arb; - - // create a cell with 4 sements: - // a soma with a branching dendrite - // - hh on soma and first branch of dendrite (segs 0 and 2) - // - pas on main dendrite and second branch (segs 1 and 3) - // - // / - // pas - // / - // hh---pas--. - // \. - // hh - // \. 
- - cell c; - auto soma = c.add_soma(12.6157/2.0); - soma->add_mechanism("hh"); - - // add dendrite of length 200 um and diameter 1 um with passive channel - c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); - c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 100); - c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 100); - - auto& segs = c.segments(); - segs[1]->add_mechanism("pas"); - segs[2]->add_mechanism("hh"); - segs[3]->add_mechanism("pas"); - - for (auto& seg: segs) { - if (seg->is_dendrite()) { - seg->rL = 100; - seg->set_compartments(4); - } - } - - // generate the lowered fvm cell - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); - - // make vectors with the expected CV indexes for each mechanism - std::vector<unsigned> hh_index = {0u, 4u, 5u, 6u, 7u, 8u}; - std::vector<unsigned> pas_index = {0u, 1u, 2u, 3u, 4u, 9u, 10u, 11u, 12u}; - // iterate over mechanisms and test whether they were assigned to the correct CVs - // TODO : this fails because we do not handle CVs at branching points (including soma) correctly - for(auto& mech : fvcell.mechanisms()) { - auto const& n = mech->node_index(); - std::vector<unsigned> ni(n.begin(), n.end()); - if(mech->name()=="hh") { - EXPECT_EQ(ni, hh_index); - } - else if(mech->name()=="pas") { - EXPECT_EQ(ni, pas_index); - } - } - - // similarly, test that the different ion channels were assigned to the correct - // compartments. In this case, the passive channel has no ion species - // associated with it, while the hh channel has both pottassium and sodium - // channels. Hence, we expect sodium and potassium to be present in the same - // compartments as the hh mechanism. 
- { - auto ni = fvcell.ion_na().node_index(); - std::vector<unsigned> na(ni.begin(), ni.end()); - EXPECT_EQ(na, hh_index); - } - { - auto ni = fvcell.ion_k().node_index(); - std::vector<unsigned> k(ni.begin(), ni.end()); - EXPECT_EQ(k, hh_index); - } - { - // calcium channel should be empty - EXPECT_EQ(0u, fvcell.ion_ca().node_index().size()); - } -} - -namespace { - double wm_impl(double wa, double xa) { - return wa? xa/wa: 0; - } - - template <typename... R> - double wm_impl(double wa, double xa, double w, double x, R... rest) { - return wm_impl(wa+w, xa+w*x, rest...); - } - - // Computed weighted mean (w*x + ...) / (w + ...). - template <typename... R> - double wmean(double w, double x, R... rest) { - return wm_impl(w, w*x, rest...); - } -} - -// Test area-weighted linear combination of density mechanism parameters. - -TEST(fvm_multi, density_weights) { - using namespace arb; - - // Create a cell with 4 segments: - // - Soma (segment 0) plus three dendrites (1, 2, 3) meeting at a branch point. - // - HH mechanism on all segments. - // - Dendritic segments are given 3 compartments each. - // - // The CV corresponding to the branch point should comprise the terminal - // 1/6 of segment 1 and the initial 1/6 of segments 2 and 3. - // - // The HH mechanism current density parameters ('gnabar', 'gkbar' and 'gl') are set - // differently for each segment: - // - // soma: all default values (gnabar = 0.12, gkbar = .036, gl = .0003) - // segment 1: gl = .0002 - // segment 2: gkbar = .05 - // segment 3: gkbar = .07, gl = .0004 - // - // Geometry: - // segment 1: 100 µm long, 1 µm diameter cylinder. - // segment 2: 200 µm long, diameter linear taper from 1 µm to 0.2 µm. - // segment 3: 150 µm long, 0.8 µm diameter cylinder. - // - // Use divided compartment view on segments to compute area contributions. 
- - cell c; - auto soma = c.add_soma(12.6157/2.0); - - c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); - c.add_cable(1, section_kind::dendrite, 0.5, 0.1, 200); - c.add_cable(1, section_kind::dendrite, 0.4, 0.4, 150); - - auto& segs = c.segments(); - - double dflt_gkbar = .036; - double dflt_gnabar = 0.12; - double dflt_gl = 0.0003; - - double seg1_gl = .0002; - double seg2_gkbar = .05; - double seg3_gkbar = .0004; - double seg3_gl = .0004; - - for (int i = 0; i<4; ++i) { - segment& seg = *segs[i]; - seg.set_compartments(3); - - mechanism_spec hh("hh"); - switch (i) { - case 1: - hh["gl"] = seg1_gl; - break; - case 2: - hh["gkbar"] = seg2_gkbar; - break; - case 3: - hh["gkbar"] = seg3_gkbar; - hh["gl"] = seg3_gl; - break; - default: ; - } - seg.add_mechanism(hh); - } - - int ncv = 10; - std::vector<double> expected_gkbar(ncv, dflt_gkbar); - std::vector<double> expected_gnabar(ncv, dflt_gnabar); - std::vector<double> expected_gl(ncv, dflt_gl); - - double soma_area = soma->area(); - auto seg1_divs = div_compartments<div_compartment_by_ends>(segs[1]->as_cable()); - auto seg2_divs = div_compartments<div_compartment_by_ends>(segs[2]->as_cable()); - auto seg3_divs = div_compartments<div_compartment_by_ends>(segs[3]->as_cable()); - - // CV 0: mix of soma and left of segment 1 - expected_gl[0] = wmean(soma_area, dflt_gl, seg1_divs(0).left.area, seg1_gl); - - expected_gl[1] = seg1_gl; - expected_gl[2] = seg1_gl; - - // CV 3: mix of right of segment 1 and left of segments 2 and 3. 
- expected_gkbar[3] = wmean(seg1_divs(2).right.area, dflt_gkbar, seg2_divs(0).left.area, seg2_gkbar, seg3_divs(0).left.area, seg3_gkbar); - expected_gl[3] = wmean(seg1_divs(2).right.area, seg1_gl, seg2_divs(0).left.area, dflt_gl, seg3_divs(0).left.area, seg3_gl); - - // CV 4-6: just segment 2 - expected_gkbar[4] = seg2_gkbar; - expected_gkbar[5] = seg2_gkbar; - expected_gkbar[6] = seg2_gkbar; - - // CV 7-9: just segment 3 - expected_gkbar[7] = seg3_gkbar; - expected_gkbar[8] = seg3_gkbar; - expected_gkbar[9] = seg3_gkbar; - expected_gl[7] = seg3_gl; - expected_gl[8] = seg3_gl; - expected_gl[9] = seg3_gl; - - // Generate the lowered fvm cell. - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); - - // Check CV area assumptions. - // Note: area integrator used here and in `fvm_multicell` may differ, and so areas computed may - // differ some due to rounding area, even given that we're dealing with simple truncated cones - // for segments. Check relative error within a tolerance of (say) 10 epsilon. - auto cv_areas = fvcell.cv_areas(); - double area_relerr = 10*std::numeric_limits<double>::epsilon(); - EXPECT_TRUE(testing::near_relative(cv_areas[0], - soma_area+seg1_divs(0).left.area, area_relerr)); - EXPECT_TRUE(testing::near_relative(cv_areas[1], - seg1_divs(0).right.area+seg1_divs(1).left.area, area_relerr)); - EXPECT_TRUE(testing::near_relative(cv_areas[3], - seg1_divs(2).right.area+seg2_divs(0).left.area+seg3_divs(0).left.area, area_relerr)); - EXPECT_TRUE(testing::near_relative(cv_areas[6], - seg2_divs(2).right.area, area_relerr)); - - // Grab the HH parameters from the mechanism. 
- EXPECT_EQ(1u, fvcell.mechanisms().size()); - auto& hh_mech = *fvcell.mechanisms().front(); - - auto gnabar_field = hh_mech.field_view_ptr("gnabar"); - auto gkbar_field = hh_mech.field_view_ptr("gkbar"); - auto gl_field = hh_mech.field_view_ptr("gl"); - - EXPECT_TRUE(testing::seq_almost_eq<double>(expected_gnabar, hh_mech.*gnabar_field)); - EXPECT_TRUE(testing::seq_almost_eq<double>(expected_gkbar, hh_mech.*gkbar_field)); - EXPECT_TRUE(testing::seq_almost_eq<double>(expected_gl, hh_mech.*gl_field)); -} - -// Test specialized mechanism behaviour. - -TEST(fvm_multi, specialized_mechs) { - using namespace arb; - - // Create ball and stick cells with the 'test_kin1' mechanism, which produces - // a voltage-independent current density of the form a + exp(-t/tau) as a function - // of time t. - // - // 1. Default 'test_kin1': tau = 10 [ms]. - // - // 2. Specialized version 'custom_kin1' with tau = 20 [ms]. - // - // 3. Cell with both test_kin1 and custom_kin1. - - specialized_mechanism custom_kin1 = {"test_kin1", {{"tau", 20.0}}}; - - cell cells[3]; - - for (int i = 0; i<3; ++i) { - cell& c = cells[i]; - c.add_soma(6.0); - c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); - - c.segment(1)->set_compartments(4); - for (auto& seg: c.segments()) { - if (!seg->is_soma()) { - seg->as_cable()->set_compartments(4); - } - - switch (i) { - case 0: - seg->add_mechanism("test_kin1"); - break; - case 1: - seg->add_mechanism("custom_kin1"); - break; - case 2: - seg->add_mechanism("test_kin1"); - seg->add_mechanism("custom_kin1"); - break; - } - } - } - - cable1d_recipe rec(cells); - rec.add_specialized_mechanism("custom_kin1", custom_kin1); - - cell_probe_address where{{1, 0.3}, cell_probe_address::membrane_current}; - rec.add_probe(0, 0, where); - rec.add_probe(1, 0, where); - rec.add_probe(2, 0, where); - - { - // Test initialization and global parameter values. 
- - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0, 1, 2}, rec, targets, probe_map); - - std::map<std::string, fvm_cell::mechanism*> mechmap; - for (auto& m: fvcell.mechanisms()) { - // (names of mechanisms should _all_ be 'test_kin1', but aliases will differ) - EXPECT_EQ("test_kin1", m->name()); - mechmap[m->alias()] = m.get(); - } - - ASSERT_EQ(2u, mechmap.size()); - EXPECT_NE(0u, mechmap.count("test_kin1")); - EXPECT_NE(0u, mechmap.count("custom_kin1")); - - // Both mechanisms are of the same type, so we can use the - // same member pointer. - auto fptr = mechmap.begin()->second->field_value_ptr("tau"); - - ASSERT_NE(nullptr, fptr); - - EXPECT_EQ(10.0, mechmap["test_kin1"]->*fptr); - EXPECT_EQ(20.0, mechmap["custom_kin1"]->*fptr); - } - - { - // Test dynamics: - // 1. Current at same point on cell 0 at time 10 ms should equal that - // on cell 1 at time 20 ms. - // 2. Current for cell 2 should be sum of currents for cells 0 and 1 at any given time. - - std::vector<double> samples[3]; - - sampler_function sampler = [&](cell_member_type pid, probe_tag, std::size_t n, const sample_record* records) { - for (std::size_t i = 0; i<n; ++i) { - double v = *util::any_cast<const double*>(records[i].data); - samples[pid.gid].push_back(v); - } - }; - - float times[] = {10.f, 20.f}; - - auto decomp = partition_load_balance(rec, hw::node_info{1u, 0u}); - simulation sim(rec, decomp); - sim.add_sampler(all_probes, explicit_schedule(times), sampler); - sim.run(30.0, 1.f/1024); - - ASSERT_EQ(2u, samples[0].size()); - ASSERT_EQ(2u, samples[1].size()); - ASSERT_EQ(2u, samples[2].size()); - - // Integration isn't exact: let's aim for one part in 10'000. 
- double relerr = 0.0001; - EXPECT_TRUE(testing::near_relative(samples[0][0], samples[1][1], relerr)); - EXPECT_TRUE(testing::near_relative(samples[0][0]+samples[1][0], samples[2][0], relerr)); - EXPECT_TRUE(testing::near_relative(samples[0][1]+samples[1][1], samples[2][1], relerr)); - } -} - -// Test synapses with differing parameter settings. - -TEST(fvm_multi, synapse_parameters) { - using namespace arb; - - cell c; - c.add_soma(6.0); - c.add_cable(0, section_kind::dendrite, 0.4, 0.4, 100.0); - c.segment(1)->set_compartments(4); - - // Add synapses out-of-order, with a parameter value a function of position - // on the segment, so we can test that parameters are properly associated - // after re-ordering. - - struct pset { - double x; // segment position - double tau1; - double tau2; - }; - - pset settings[] = { - {0.8, 1.5, 2.5}, - {0.1, 1.6, 3.7}, - {0.5, 1.7, 3.6}, - {0.6, 0.8, 2.5}, - {0.4, 0.9, 3.4}, - {0.9, 1.1, 2.3} - }; - - for (auto s: settings) { - mechanism_spec m("exp2syn"); - c.add_synapse({1, s.x}, m.set("tau1", s.tau1).set("tau2", s.tau2)); - } - - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); - - EXPECT_EQ(1u, fvcell.mechanisms().size()); - auto& exp2syn_mech = *fvcell.mechanisms().front(); - - auto tau1_ptr = exp2syn_mech.field_view_ptr("tau1"); - auto tau2_ptr = exp2syn_mech.field_view_ptr("tau2"); - - // Compare tau1, tau2 values from settings and from mechanism, ignoring order. 
- std::set<std::pair<double, double>> expected; - for (auto s: settings) { - expected.insert({s.tau1, s.tau2}); - } - - unsigned n = exp2syn_mech.size(); - ASSERT_EQ(util::size(settings), n); - - std::set<std::pair<double, double>> values; - for (unsigned i = 0; i<n; ++i) { - values.insert({(exp2syn_mech.*tau1_ptr)[i], (exp2syn_mech.*tau2_ptr)[i]}); - } - - EXPECT_EQ(expected, values); -} - -struct handle_info { - unsigned cell; - std::string mech; - unsigned cv; -}; - -// test handle <-> mechanism/index correspondence -// on a two-cell ball-and-stick system. - -void run_target_handle_test(std::vector<handle_info> all_handles) { - using namespace arb; - - arb::cell cells[] = { - make_cell_ball_and_stick(), - make_cell_ball_and_stick() - }; - - EXPECT_EQ(2u, cells[0].num_segments()); - EXPECT_EQ(4u, cells[0].segment(1)->num_compartments()); - EXPECT_EQ(5u, cells[0].num_compartments()); - - EXPECT_EQ(2u, cells[1].num_segments()); - EXPECT_EQ(4u, cells[1].segment(1)->num_compartments()); - EXPECT_EQ(5u, cells[1].num_compartments()); - - std::vector<std::vector<handle_info>> handles(2); - - for (auto x: all_handles) { - unsigned seg_id; - double pos; - - ASSERT_TRUE(x.cell==0 || x.cell==1); - ASSERT_TRUE(x.cv<5); - ASSERT_TRUE(x.mech=="expsyn" || x.mech=="exp2syn"); - - if (x.cv==0) { - // place on soma - seg_id = 0; - pos = 0; - } - else { - // place on dendrite - seg_id = 1; - pos = x.cv/4.0; - } - - if (x.cell==1) { - x.cv += 5; // offset for cell 1 - } - - cells[x.cell].add_synapse({seg_id, pos}, x.mech); - handles[x.cell].push_back(x); - } - - auto n = all_handles.size(); - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0, 1}, cable1d_recipe(cells), targets, probe_map); - - ASSERT_EQ(n, util::size(targets)); - unsigned i = 0; - for (unsigned ci = 0; ci<=1; ++ci) { - for (auto h: handles[ci]) { - // targets are represented by a pair of mechanism index and instance 
index - const auto& mech = fvcell.mechanisms()[targets[i].mech_id]; - const auto& cvidx = mech->node_index(); - EXPECT_EQ(h.mech, mech->name()); - EXPECT_EQ(h.cv, cvidx[targets[i].mech_index]); - EXPECT_EQ(h.cell, targets[i].cell_index); - ++i; - } - } -} - -TEST(fvm_multi, target_handles_onecell) -{ - { - SCOPED_TRACE("handles: exp2syn only on cell 0"); - std::vector<handle_info> handles0 = { - {0, "exp2syn", 4}, - {0, "exp2syn", 4}, - {0, "exp2syn", 3}, - {0, "exp2syn", 2}, - {0, "exp2syn", 0}, - {0, "exp2syn", 1}, - {0, "exp2syn", 2} - }; - run_target_handle_test(handles0); - } - - { - SCOPED_TRACE("handles: expsyn only on cell 1"); - std::vector<handle_info> handles1 = { - {1, "expsyn", 4}, - {1, "expsyn", 4}, - {1, "expsyn", 3}, - {1, "expsyn", 2}, - {1, "expsyn", 0}, - {1, "expsyn", 1}, - {1, "expsyn", 2} - }; - run_target_handle_test(handles1); - } -} - -TEST(fvm_multi, target_handles_twocell) -{ - SCOPED_TRACE("handles: expsyn only on cells 0 and 1"); - std::vector<handle_info> handles = { - {0, "expsyn", 0}, - {1, "expsyn", 3}, - {0, "expsyn", 2}, - {1, "expsyn", 2}, - {0, "expsyn", 4}, - {1, "expsyn", 1}, - {1, "expsyn", 4} - }; - run_target_handle_test(handles); -} - -TEST(fvm_multi, target_handles_mixed_synapse) -{ - SCOPED_TRACE("handles: expsyn and exp2syn on cells 0"); - std::vector<handle_info> handles = { - {0, "expsyn", 4}, - {0, "exp2syn", 4}, - {0, "expsyn", 3}, - {0, "exp2syn", 2}, - {0, "exp2syn", 0}, - {0, "expsyn", 1}, - {0, "expsyn", 2} - }; - run_target_handle_test(handles); -} - -TEST(fvm_multi, target_handles_general) -{ - SCOPED_TRACE("handles: expsyn and exp2syn on cells 0 and 1"); - std::vector<handle_info> handles = { - {0, "expsyn", 4}, - {0, "exp2syn", 2}, - {0, "exp2syn", 0}, - {1, "exp2syn", 4}, - {1, "expsyn", 3}, - {1, "expsyn", 1}, - {1, "expsyn", 2} - }; - run_target_handle_test(handles); -} - -// Test area-weighted linear combination of ion species concentrations - -TEST(fvm_multi, ion_weights) { - using namespace arb; - - 
// Create a cell with 4 segments: - // - Soma (segment 0) plus three dendrites (1, 2, 3) meeting at a branch point. - // - Dendritic segments are given 1 compartments each. - // - // / - // d2 - // / - // s0-d1 - // \. - // d3 - // - // The CV corresponding to the branch point should comprise the terminal - // 1/2 of segment 1 and the initial 1/2 of segments 2 and 3. - // - // Geometry: - // soma 0: radius 5 µm - // dend 1: 100 µm long, 1 µm diameter cynlinder - // dend 2: 200 µm long, 1 µm diameter cynlinder - // dend 3: 100 µm long, 1 µm diameter cynlinder - // The radius of the soma is chosen such that the surface area of soma is - // the same as a 100µm dendrite, which makes it easier to describe the - // expected weights. - - std::vector<std::vector<int>> seg_sets = { - {0}, {0,2}, {2, 3}, {0, 1, 2, 3}, {3} - }; - std::vector<std::vector<unsigned>> expected_nodes = { - {0}, {0, 1, 2}, {1, 2, 3}, {0, 1, 2, 3}, {1, 3}, - }; - std::vector<std::vector<fvm_value_type>> expected_wght = { - {1./3}, {1./3, 1./2, 0.}, {1./4, 0., 0.}, {0., 0., 0., 0.}, {3./4, 0.} - }; - - double con_int = 80; - double con_ext = 120; - - auto construct_cell = [](cell& c) { - c.add_soma(5); - - c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); - c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 200); - c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 100); - - for (auto& s: c.segments()) s->set_compartments(1); - }; - - for (auto run=0u; run<seg_sets.size(); ++run) { - cell c; - construct_cell(c); - - for (auto i: seg_sets[run]) { - c.segments()[i]->add_mechanism(mechanism_spec("test_ca")); - } - - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); - - auto& ion = fvcell.ion_ca(); - ion.default_int_concentration = con_int; - ion.default_ext_concentration = con_ext; - ion.init_concentration(); - - auto& nodes = expected_nodes[run]; - auto& weights 
= expected_wght[run]; - auto ncv = nodes.size(); - EXPECT_EQ(ncv, ion.node_index().size()); - for (auto i: util::make_span(0, ncv)) { - EXPECT_EQ(nodes[i], ion.node_index()[i]); - EXPECT_FLOAT_EQ(weights[i], ion.internal_concentration_weights()[i]); - - EXPECT_EQ(con_ext, ion.external_concentration()[i]); - EXPECT_FLOAT_EQ(1.0, ion.external_concentration_weights()[i]); - } - } - - // Check correct indexing when writing mechanism nodes are a subset - // of ion nodes. - - { - cell c; - construct_cell(c); - - // reader on CV 1 and 2 via segment 2: - c.segments()[2] ->add_mechanism(mechanism_spec("test_kinlva")); - - // test_ca writer on CV 1 and 3 via segment 3: - c.segments()[3] ->add_mechanism(mechanism_spec("test_ca")); - - std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; - - fvm_cell fvcell; - fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); - - auto& ion = fvcell.ion_ca(); - ion.default_int_concentration = con_int; - ion.default_ext_concentration = con_ext; - ion.init_concentration(); - - //auto ni = ion.node_index(); - std::vector<unsigned> ion_nodes = util::assign_from(ion.node_index()); - std::vector<unsigned> expected_ion_nodes = {1, 2, 3}; - EXPECT_EQ(expected_ion_nodes, ion_nodes); - - std::vector<double> ion_iconc_weights = util::assign_from(ion.internal_concentration_weights()); - std::vector<double> expected_ion_iconc_weights = {0.75, 1., 0}; - EXPECT_EQ(expected_ion_iconc_weights, ion_iconc_weights); - - multicore::mechanism_test_ca<multicore::backend>* test_ca = - dynamic_cast<multicore::mechanism_test_ca<multicore::backend>*>( - fvcell.find_mechanism("test_ca").value().get()); - - double cai_contrib[3] = {200., 0., 300.}; - for (int i = 0; i<2; ++i) { - test_ca->cai[i] = cai_contrib[test_ca->ion_ca.index[i]]; - } - - std::vector<double> expected_iconc(3); - for (int i = 0; i<3; ++i) { - expected_iconc[i] = math::lerp(cai_contrib[i], con_int, ion_iconc_weights[i]); - } - - 
ion.init_concentration(); - test_ca->write_back(); - std::vector<double> ion_iconc = util::assign_from(ion.internal_concentration()); - EXPECT_EQ(expected_iconc, ion_iconc); - } -} diff --git a/tests/unit/test_gpu_stack.cu b/tests/unit/test_gpu_stack.cu index 214a86122810841748ee1695839e9e5f27079d2c..18d9c180489b16d6b146eb8ffe9a4f509344cccb 100644 --- a/tests/unit/test_gpu_stack.cu +++ b/tests/unit/test_gpu_stack.cu @@ -1,7 +1,7 @@ #include "../gtest.h" -#include <backends/gpu/kernels/stack.hpp> #include <backends/gpu/stack.hpp> +#include <backends/gpu/stack_cu.hpp> #include <backends/gpu/managed_ptr.hpp> using namespace arb; diff --git a/tests/unit/test_intrin.cu b/tests/unit/test_intrin.cu index 5787b0bdd2054982df0a2ae019d1b16c71aa04f6..a5203bc940ce9609ac1024cd93547651096689ed 100644 --- a/tests/unit/test_intrin.cu +++ b/tests/unit/test_intrin.cu @@ -2,7 +2,8 @@ #include <limits> -#include <backends/gpu/intrinsics.hpp> +#include <backends/gpu/cuda_atomic.hpp> +#include <backends/gpu/math.hpp> #include <backends/gpu/managed_ptr.hpp> #include <memory/memory.hpp> #include <util/rangeutil.hpp> @@ -24,19 +25,19 @@ namespace kernels { __global__ void test_min(double* x, double* y, double* result) { const auto i = threadIdx.x; - result[i] = min(x[i], y[i]); + result[i] = arb::gpu::min(x[i], y[i]); } __global__ void test_max(double* x, double* y, double* result) { const auto i = threadIdx.x; - result[i] = max(x[i], y[i]); + result[i] = arb::gpu::max(x[i], y[i]); } __global__ void test_exprelr(double* x, double* result) { const auto i = threadIdx.x; - result[i] = exprelr(x[i]); + result[i] = arb::gpu::exprelr(x[i]); } } diff --git a/tests/unit/test_maputil.cpp b/tests/unit/test_maputil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..17e8693842f6b8a133e65958c17cdfd7faa9462e --- /dev/null +++ b/tests/unit/test_maputil.cpp @@ -0,0 +1,190 @@ +#include "../gtest.h" + +#include <map> +#include <set> +#include <unordered_map> +#include <unordered_set> 
+#include <vector> +#include <utility> + +#include <util/maputil.hpp> +#include <util/rangeutil.hpp> + +#include "common.hpp" + +using namespace arb; + +using namespace testing::string_literals; +using testing::nocopy; +using testing::nomove; + +// TODO: Add unit tests for other new functionality in maputil. + +TEST(maputil, keys) { + { + std::map<int, double> map = {{10, 2.0}, {3, 8.0}}; + std::vector<int> expected = {3, 10}; + std::vector<int> keys = util::assign_from(util::keys(map)); + EXPECT_EQ(expected, keys); + } + + { + struct cmp { + bool operator()(const nocopy<int>& a, const nocopy<int>& b) const { + return a.value<b.value; + } + }; + std::map<nocopy<int>, double, cmp> map; + map.insert(std::pair<nocopy<int>, double>(11, 2.0)); + map.insert(std::pair<nocopy<int>, double>(2, 0.3)); + map.insert(std::pair<nocopy<int>, double>(2, 0.8)); + map.insert(std::pair<nocopy<int>, double>(5, 0.1)); + + std::vector<int> expected = {2, 5, 11}; + std::vector<int> keys; + for (auto& k: util::keys(map)) { + keys.push_back(k.value); + } + EXPECT_EQ(expected, keys); + } + + { + std::unordered_multimap<int, double> map = {{3, 0.1}, {5, 0.4}, {11, 0.8}, {5, 0.2}}; + std::vector<int> expected = {3, 5, 5, 11}; + std::vector<int> keys = util::assign_from(util::keys(map)); + util::sort(keys); + EXPECT_EQ(expected, keys); + } +} + +TEST(maputil, is_assoc) { + using util::is_associative_container; + + EXPECT_TRUE((is_associative_container<std::map<int, double>>::value)); + EXPECT_TRUE((is_associative_container<std::unordered_map<int, double>>::value)); + EXPECT_TRUE((is_associative_container<std::unordered_multimap<int, double>>::value)); + + EXPECT_FALSE((is_associative_container<std::set<int>>::value)); + EXPECT_FALSE((is_associative_container<std::set<std::pair<int, double>>>::value)); + EXPECT_FALSE((is_associative_container<std::vector<std::pair<int, double>>>::value)); +} + +// Sub-class map to check that find method is being properly used. 
+ +namespace { + struct static_counter { + static int count; + } S; + int static_counter::count = 0; + + template <typename K, typename V> + struct check_map: std::map<K, V>, static_counter { + using typename std::map<K, V>::value_type; + using typename std::map<K, V>::iterator; + using typename std::map<K, V>::const_iterator; + + check_map(): std::map<K, V>() {} + check_map(std::initializer_list<value_type> init): std::map<K, V>(init) {} + + const_iterator find(const K& key) const { + ++count; + return std::map<K, V>::find(key); + } + + iterator find(const K& key) { + ++count; + return std::map<K, V>::find(key); + } + }; + + template <typename X> + constexpr bool is_optional_reference(X) { return false; } + + template <typename X> + constexpr bool is_optional_reference(util::optional<X&>) { return true; } +} + +TEST(maputil, value_by_key_map) { + using util::value_by_key; + + check_map<std::string, int> map_s2i = { + {"fish", 4}, + {"sheep", 5}, + }; + + S.count = 0; + EXPECT_FALSE(value_by_key(map_s2i, "deer")); + EXPECT_EQ(1, S.count); + + // Should get an optional reference if argument is an lvalue. + + S.count = 0; + auto r1 = value_by_key(map_s2i, "sheep"); + EXPECT_EQ(1, S.count); + EXPECT_TRUE(r1); + EXPECT_EQ(5, r1.value()); + EXPECT_TRUE(is_optional_reference(r1)); + r1.value() = 6; + EXPECT_EQ(6, value_by_key(map_s2i, "sheep").value()); + + // Should not get an optional reference if argument is an rvalue. + + S.count = 0; + auto r2 = value_by_key(check_map<std::string, int>(map_s2i), "fish"); + EXPECT_EQ(1, S.count); + EXPECT_TRUE(r2); + EXPECT_EQ(4, r2.value()); + EXPECT_FALSE(is_optional_reference(r2)); + + // Providing an explicit comparator should fall-back to serial search. 
+ + S.count = 0; + auto str_cmp = [](const std::string& k1, const char* k2) { return k1==k2; }; + auto r3 = value_by_key(map_s2i, "fish", str_cmp); + EXPECT_EQ(0, S.count); + EXPECT_TRUE(r3); + EXPECT_EQ(4, r3.value()); +} + +TEST(maputil, value_by_key_sequence) { + using util::value_by_key; + + // Note: value_by_key returns the value of `get<1>` on the + // entries in the map or sequence. + + using entry = std::tuple<int, std::string, char>; + std::vector<entry> table = { + entry(1, "one", '1'), + entry(3, "three", '3'), + entry(5, "five", '5') + }; + + EXPECT_FALSE(value_by_key(table, 2)); + + auto r1 = value_by_key(table, 3); + EXPECT_TRUE(r1); + EXPECT_EQ("three"_s, r1.value()); + EXPECT_TRUE(is_optional_reference(r1)); + r1.value() = "four"; + EXPECT_EQ("four"_s, value_by_key(table, 3).value()); + + auto r2 = value_by_key(std::move(table), 1); + EXPECT_TRUE(r2); + EXPECT_EQ("one", r2.value()); + EXPECT_FALSE(is_optional_reference(r2)); +} + +TEST(maputil, binary_search_index) { + using util::binary_search_index; + const int ns[] = {2, 3, 3, 5, 6, 113, 114, 114, 116}; + + for (int x: {7, 1, 117}) { + EXPECT_FALSE(binary_search_index(ns, x)); + } + + for (int x: {114, 3, 5, 2, 116}) { + auto opti = binary_search_index(ns, x); + ASSERT_TRUE(opti); + EXPECT_EQ(x, ns[*opti]); + } +} diff --git a/tests/unit/test_math.cpp b/tests/unit/test_math.cpp index 37e38594b277cce75b4ac333ce46a29c498ac9aa..88e979ed2e376361a7ec4c96bda0103db60ae6ab 100644 --- a/tests/unit/test_math.cpp +++ b/tests/unit/test_math.cpp @@ -140,6 +140,89 @@ TEST(math, signum) { EXPECT_EQ(-1, signum(-infinity<float>())); } +TEST(math, next_pow2) { + EXPECT_EQ(0u, next_pow2(0u)); + EXPECT_EQ(1u, next_pow2(1u)); + EXPECT_EQ(2u, next_pow2(2u)); + EXPECT_EQ(4u, next_pow2(3u)); + EXPECT_EQ(4u, next_pow2(3u)); + EXPECT_EQ(64u, next_pow2(53u)); + EXPECT_EQ(0u, next_pow2(unsigned(-1))); + + auto unsigned_bits = std::numeric_limits<unsigned>::digits; + unsigned big = 1u<<(unsigned_bits-1); + + EXPECT_EQ(big, 
next_pow2(big)); + EXPECT_EQ(big, next_pow2(big-1)); + EXPECT_EQ(big/2, next_pow2(big/2)); + EXPECT_EQ(big/2, next_pow2(big/2-35)); + EXPECT_EQ(0u, next_pow2(big+1)); + EXPECT_EQ(0u, next_pow2(big+big/2)); + + + EXPECT_EQ(0ull, next_pow2(0ull)); + EXPECT_EQ(1ull, next_pow2(1ull)); + EXPECT_EQ(2ull, next_pow2(2ull)); + EXPECT_EQ(4ull, next_pow2(3ull)); + EXPECT_EQ(4ull, next_pow2(3ull)); + EXPECT_EQ(64ull, next_pow2(53ull)); + EXPECT_EQ(0ull, next_pow2((unsigned long long)(-1))); + + auto ull_bits = std::numeric_limits<unsigned long long>::digits; + unsigned long long bigll = 1ull<<(ull_bits-1); + + EXPECT_EQ(bigll, next_pow2(bigll)); + EXPECT_EQ(bigll, next_pow2(bigll-1)); + EXPECT_EQ(bigll/2, next_pow2(bigll/2)); + EXPECT_EQ(bigll/2, next_pow2(bigll/2-35)); + EXPECT_EQ(0ull, next_pow2(bigll+1)); + EXPECT_EQ(0ull, next_pow2(bigll+bigll/2)); +} + +TEST(math, round_up) { + // signed tests + + EXPECT_EQ(0, round_up(0, 23)); + EXPECT_EQ(0, round_up(0, -23)); + + EXPECT_EQ(99, round_up(99, 1)); + EXPECT_EQ(99, round_up(99, -1)); + EXPECT_EQ(-99, round_up(-99, 1)); + EXPECT_EQ(-99, round_up(-99, -1)); + + int base1 = 100; + EXPECT_EQ(5*base1, round_up(5*base1, base1)); + EXPECT_EQ(5*base1, round_up(5*base1-1, base1)); + EXPECT_EQ(5*base1, round_up(4*base1+1, base1)); + EXPECT_EQ(-5*base1, round_up(-5*base1, base1)); + EXPECT_EQ(-5*base1, round_up(-5*base1+1, base1)); + EXPECT_EQ(-5*base1, round_up(-4*base1-1, base1)); + + int base2 = -23; + EXPECT_EQ(7*base2, round_up(7*base2, base2)); + EXPECT_EQ(7*base2, round_up(7*base2+1, base2)); + EXPECT_EQ(7*base2, round_up(6*base2-1, base2)); + EXPECT_EQ(-7*base2, round_up(-7*base2, base2)); + EXPECT_EQ(-7*base2, round_up(-7*base2-1, base2)); + EXPECT_EQ(-7*base2, round_up(-6*base2+1, base2)); + + // unsigned tests + + EXPECT_EQ(0u, round_up(0u, 23u)); + EXPECT_EQ(99u, round_up(99, 1u)); + + unsigned base3 = 100; + EXPECT_EQ(5*base3, round_up(5*base3, base3)); + EXPECT_EQ(5*base3, round_up(5*base3-1, base3)); + EXPECT_EQ(5*base3, 
round_up(4*base3+1, base3)); + + // promotion works? + ASSERT_GT(sizeof(unsigned long long), sizeof(int)); + unsigned long long v = 1ull << (std::numeric_limits<unsigned long long>::digits-1); + int base = 4; + EXPECT_EQ(v, round_up(v, base)); + EXPECT_EQ(v-base, round_up(v-base-1, base)); +} TEST(quaternion, ctor) { // scalar diff --git a/tests/unit/test_matrix.cpp b/tests/unit/test_matrix.cpp index 643c914679994a859278efa9b23923da8429c5d7..aa5d17e75eb648fbb032921f505e155fbd0d80b2 100644 --- a/tests/unit/test_matrix.cpp +++ b/tests/unit/test_matrix.cpp @@ -6,6 +6,7 @@ #include <math.hpp> #include <matrix.hpp> #include <backends/multicore/fvm.hpp> +#include <util/rangeutil.hpp> #include <util/span.hpp> #include "common.hpp" @@ -13,29 +14,29 @@ using namespace arb; using matrix_type = matrix<arb::multicore::backend>; -using size_type = matrix_type::size_type; +using index_type = matrix_type::index_type; using value_type = matrix_type::value_type; using vvec = std::vector<value_type>; TEST(matrix, construct_from_parent_only) { - std::vector<size_type> p = {0,0,1}; + std::vector<index_type> p = {0,0,1}; matrix_type m(p, {0, 3}, vvec(3), vvec(3), vvec(3)); EXPECT_EQ(m.num_cells(), 1u); EXPECT_EQ(m.size(), 3u); EXPECT_EQ(p.size(), 3u); auto mp = m.p(); - EXPECT_EQ(mp[0], 0u); - EXPECT_EQ(mp[1], 0u); - EXPECT_EQ(mp[2], 1u); + EXPECT_EQ(mp[0], index_type(0)); + EXPECT_EQ(mp[1], index_type(0)); + EXPECT_EQ(mp[2], index_type(1)); } TEST(matrix, solve_host) { using util::make_span; - using memory::fill; + using util::fill; // trivial case : 1x1 matrix { @@ -52,12 +53,12 @@ TEST(matrix, solve_host) // matrices in the range of 2x2 to 1000x1000 { - for(auto n : make_span(2u,1001u)) { - auto p = std::vector<size_type>(n); + for(auto n : make_span(2, 1001)) { + auto p = std::vector<index_type>(n); std::iota(p.begin()+1, p.end(), 0); matrix_type m(p, {0, n}, vvec(n), vvec(n), vvec(n)); - EXPECT_EQ(m.size(), n); + EXPECT_EQ(m.size(), (unsigned)n); EXPECT_EQ(m.num_cells(), 1u); 
auto& A = m.state_; @@ -87,20 +88,20 @@ TEST(matrix, zero_diagonal) // elements should be ignored). // These submatrices should leave the rhs as-is when solved. - using memory::make_const_view; + using util::assign; // Three matrices, sizes 3, 3 and 2, with no branching. - std::vector<size_type> p = {0, 0, 1, 3, 3, 5, 5}; - std::vector<size_type> c = {0, 3, 5, 7}; + std::vector<index_type> p = {0, 0, 1, 3, 3, 5, 5}; + std::vector<index_type> c = {0, 3, 5, 7}; matrix_type m(p, c, vvec(7), vvec(7), vvec(7)); EXPECT_EQ(7u, m.size()); EXPECT_EQ(3u, m.num_cells()); auto& A = m.state_; - A.d = make_const_view(vvec({2, 3, 2, 0, 0, 4, 5})); - A.u = make_const_view(vvec({0, -1, -1, 0, -1, 0, -2})); - A.rhs = make_const_view(vvec({3, 5, 7, 7, 8, 16, 32})); + assign(A.d, vvec({2, 3, 2, 0, 0, 4, 5})); + assign(A.u, vvec({0, -1, -1, 0, -1, 0, -2})); + assign(A.rhs, vvec({3, 5, 7, 7, 8, 16, 32})); // Expected solution: std::vector<value_type> expected = {4, 5, 6, 7, 8, 9, 10}; @@ -117,7 +118,7 @@ TEST(matrix, zero_diagonal_assembled) // test case from CV data. using util::assign; - using memory::make_view; + using array = matrix_type::array; // Combined matrix may have zero-blocks, corresponding to a zero dt. // Zero-blocks are indicated by zero value in the diagonal (the off-diagonal @@ -125,22 +126,22 @@ TEST(matrix, zero_diagonal_assembled) // These submatrices should leave the rhs as-is when solved. // Three matrices, sizes 3, 3 and 2, with no branching. - std::vector<size_type> p = {0, 0, 1, 3, 3, 5, 5}; - std::vector<size_type> c = {0, 3, 5, 7}; + std::vector<index_type> p = {0, 0, 1, 3, 3, 5, 5}; + std::vector<index_type> c = {0, 3, 5, 7}; // Face conductances. vvec g = {0, 1, 1, 0, 1, 0, 2}; // dt of 1e-3. - vvec dt(3, 1.0e-3); + array dt(3, 1.0e-3); // Capacitances. vvec Cm = {1, 1, 1, 1, 1, 2, 3}; // Intial voltage of zero; currents alone determine rhs. 
- vvec v(7, 0.0); + array v(7, 0.0); vvec area(7, 1.0); - vvec i = {-3000, -5000, -7000, -6000, -9000, -16000, -32000}; + array i = {-3000, -5000, -7000, -6000, -9000, -16000, -32000}; // Expected matrix and rhs: // u = [ 0 -1 -1 0 -1 0 -2] @@ -151,11 +152,11 @@ TEST(matrix, zero_diagonal_assembled) // x = [ 4 5 6 7 8 9 10] matrix_type m(p, c, Cm, g, area); - m.assemble(make_view(dt), make_view(v), make_view(i)); + m.assemble(dt, v, i); m.solve(); vvec x; - assign(x, on_host(m.solution())); + assign(x, m.solution()); std::vector<value_type> expected = {4, 5, 6, 7, 8, 9, 10}; EXPECT_TRUE(testing::seq_almost_eq<double>(expected, x)); @@ -166,7 +167,7 @@ TEST(matrix, zero_diagonal_assembled) dt[1] = 0; v[3] = 20; v[4] = 30; - m.assemble(make_view(dt), make_view(v), make_view(i)); + m.assemble(dt, v, i); m.solve(); assign(x, m.solution()); diff --git a/tests/unit/test_matrix.cu b/tests/unit/test_matrix.cu index 5222bb1b51b6d86488c56991877dc43617ab7d3d..a24863ac557a05b81058de3aea7515be572bd004 100644 --- a/tests/unit/test_matrix.cu +++ b/tests/unit/test_matrix.cu @@ -5,16 +5,16 @@ #include "../gtest.h" #include "common.hpp" +#include <algorithms.hpp> #include <math.hpp> #include <matrix.hpp> -#include <backends/gpu/fvm.hpp> -#include <backends/multicore/fvm.hpp> #include <memory/memory.hpp> #include <util/span.hpp> +#include <backends/gpu/cuda_common.hpp> #include <backends/gpu/matrix_state_flat.hpp> #include <backends/gpu/matrix_state_interleaved.hpp> -#include <backends/gpu/kernels/interleave.hpp> +#include <backends/gpu/matrix_interleave.hpp> #include <cuda.h> @@ -31,7 +31,8 @@ using testing::seq_almost_eq; using std::begin; using std::end; -// will test the flat_to_interleaved and interleaved_to_flat operations for the + +// Test the flat_to_interleaved and interleaved_to_flat operations for the // set of matrices defined by sizes and starts. 
// Applies the interleave to the vector in values, and checks this against // a reference result generated using a host side reference implementation. @@ -203,118 +204,16 @@ TEST(matrix, interleave) } } -// Test that matrix assembly works. -// The test proceeds by assembling a reference matrix on the host and -// device backends, then performs solve, and compares solution. -// -// limitations of test -// * matrices all have same size and structure -TEST(matrix, assemble) -{ - using gpu_state = gpu::backend::matrix_state; - using mc_state = multicore::backend::matrix_state; - - using T = typename gpu::backend::value_type; - using I = typename gpu::backend::size_type; - - using gpu_array = typename gpu::backend::array; - using host_array = typename multicore::backend::array; - - // There are two matrix structures: - // - // p_1: 3 branches, 6 compartments - // - // 3 - // /. - // 0 - 1 - 2 - // \. - // 4 - // \. - // 5 - // - // p_2: 5 branches, 8 compartments - // - // 4 - // /. - // 3 - // / \. - // 0 - 1 - 2 5 - // \. - // 6 - // \. - // 7 - - // The parent indexes that define the two matrix structures - std::vector<std::vector<I>> - p_base = { {0,0,1,2,2,4}, {0,0,1,2,3,3,2,6} }; - - // Make a set of matrices based on repeating this pattern. - // We assign the patterns round-robin, i.e. so that the input - // matrices will have alternating sizes of 6 and 8, which will - // test the solver with variable matrix size, and exercise - // solvers that reorder matrices according to size. - const int num_mtx = 8; - - std::vector<I> p; - std::vector<I> cell_index; - for (auto m=0; m<num_mtx; ++m) { - auto &p_ref = p_base[m%2]; - auto first = p.size(); - for (auto i: p_ref) { - p.push_back(i + first); - } - cell_index.push_back(first); - } - cell_index.push_back(p.size()); - - auto group_size = cell_index.back(); - - // Build the capacitance and conductance vectors and - // populate with nonzero random values. 
- auto gen = std::mt19937(); - auto dist = std::uniform_real_distribution<T>(1, 2); - - std::vector<T> Cm(group_size); - std::generate(Cm.begin(), Cm.end(), [&](){return dist(gen);}); - - std::vector<T> g(group_size); - std::generate(g.begin(), g.end(), [&](){return dist(gen);}); - - std::vector<T> area(group_size, 1e3); - - // Make the reference matrix and the gpu matrix - auto m_mc = mc_state( p, cell_index, Cm, g, area); // on host - auto m_gpu = gpu_state(p, cell_index, Cm, g, area); // on gpu - - // Set the integration times for the cells to be between 0.1 and 0.2 ms. - std::vector<T> dt(num_mtx); - - auto dt_dist = std::uniform_real_distribution<T>(0.1, 0.2); - std::generate(dt.begin(), dt.end(), [&](){return dt_dist(gen);}); - - // Voltage and current values - m_mc.assemble(on_host(dt), host_array(group_size, -64), host_array(group_size, 10)); - m_mc.solve(); - m_gpu.assemble(on_gpu(dt), gpu_array(group_size, -64), gpu_array(group_size, 10)); - m_gpu.solve(); - - // Compare the GPU and CPU results. 
- // Cast result to float, because we are happy to ignore small differencs - std::vector<float> result_h = util::assign_from(m_mc.solution()); - std::vector<float> result_g = util::assign_from(on_host(m_gpu.solution())); - EXPECT_TRUE(seq_almost_eq<float>(result_h, result_g)); -} - // test that the flat and interleaved storage back ends produce identical results TEST(matrix, backends) { - using T = typename gpu::backend::value_type; - using I = typename gpu::backend::size_type; + using T = fvm_value_type; + using I = fvm_index_type; using state_flat = gpu::matrix_state_flat<T, I>; using state_intl = gpu::matrix_state_interleaved<T, I>; - using gpu_array = typename gpu::backend::array; + using gpu_array = memory::device_vector<T>; // There are two matrix structures: // diff --git a/tests/unit/test_matrix_cpuvsgpu.cpp b/tests/unit/test_matrix_cpuvsgpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b4959bee0470c165efe7419afe1d8dcc4167f22 --- /dev/null +++ b/tests/unit/test_matrix_cpuvsgpu.cpp @@ -0,0 +1,131 @@ +#include <numeric> +#include <random> +#include <vector> + +#include "../gtest.h" +#include "common.hpp" + +#include <algorithms.hpp> +#include <math.hpp> +#include <matrix.hpp> +#include <memory/memory.hpp> +#include <util/span.hpp> + +#include <backends/gpu/fvm.hpp> +#include <backends/multicore/fvm.hpp> + +using namespace arb; + +using util::make_span; +using util::assign_from; +using memory::on_gpu; +using memory::on_host; + +using testing::seq_almost_eq; + +using std::begin; +using std::end; + + +// Test that matrix assembly works. +// The test proceeds by assembling a reference matrix on the host and +// device backends, then performs solve, and compares solution. 
+// +// limitations of test +// * matrices all have same size and structure +TEST(matrix, assemble) +{ + using gpu_state = gpu::backend::matrix_state; + using mc_state = multicore::backend::matrix_state; + + using T = fvm_value_type; + using I = fvm_index_type; + + using gpu_array = typename gpu::backend::array; + using host_array = typename multicore::backend::array; + + // There are two matrix structures: + // + // p_1: 3 branches, 6 compartments + // + // 3 + // /. + // 0 - 1 - 2 + // \. + // 4 + // \. + // 5 + // + // p_2: 5 branches, 8 compartments + // + // 4 + // /. + // 3 + // / \. + // 0 - 1 - 2 5 + // \. + // 6 + // \. + // 7 + + // The parent indexes that define the two matrix structures + std::vector<std::vector<I>> + p_base = { {0,0,1,2,2,4}, {0,0,1,2,3,3,2,6} }; + + // Make a set of matrices based on repeating this pattern. + // We assign the patterns round-robin, i.e. so that the input + // matrices will have alternating sizes of 6 and 8, which will + // test the solver with variable matrix size, and exercise + // solvers that reorder matrices according to size. + const int num_mtx = 8; + + std::vector<I> p; + std::vector<I> cell_index; + for (auto m=0; m<num_mtx; ++m) { + auto &p_ref = p_base[m%2]; + auto first = p.size(); + for (auto i: p_ref) { + p.push_back(i + first); + } + cell_index.push_back(first); + } + cell_index.push_back(p.size()); + + auto group_size = cell_index.back(); + + // Build the capacitance and conductance vectors and + // populate with nonzero random values. 
+ auto gen = std::mt19937(); + auto dist = std::uniform_real_distribution<T>(1, 2); + + std::vector<T> Cm(group_size); + std::generate(Cm.begin(), Cm.end(), [&](){return dist(gen);}); + + std::vector<T> g(group_size); + std::generate(g.begin(), g.end(), [&](){return dist(gen);}); + + std::vector<T> area(group_size, 1e3); + + // Make the reference matrix and the gpu matrix + auto m_mc = mc_state( p, cell_index, Cm, g, area); // on host + auto m_gpu = gpu_state(p, cell_index, Cm, g, area); // on gpu + + // Set the integration times for the cells to be between 0.1 and 0.2 ms. + std::vector<T> dt(num_mtx); + + auto dt_dist = std::uniform_real_distribution<T>(0.1, 0.2); + std::generate(dt.begin(), dt.end(), [&](){return dt_dist(gen);}); + + // Voltage and current values + m_mc.assemble(host_array(dt.begin(), dt.end()), host_array(group_size, -64), host_array(group_size, 10)); + m_mc.solve(); + m_gpu.assemble(on_gpu(dt), gpu_array(group_size, -64), gpu_array(group_size, 10)); + m_gpu.solve(); + + // Compare the GPU and CPU results. 
+ // Cast result to float, because we are happy to ignore small differences + std::vector<float> result_h = util::assign_from(m_mc.solution()); + std::vector<float> result_g = util::assign_from(on_host(m_gpu.solution())); + EXPECT_TRUE(seq_almost_eq<float>(result_h, result_g)); +} + diff --git a/tests/unit/test_mc_cell_group.cpp b/tests/unit/test_mc_cell_group.cpp index 92d4f415e68478830fd6b532ce6b9e6a7745af93..4bef15599e15dd11dda1646cd5a0ba6b2a949dd3 100644 --- a/tests/unit/test_mc_cell_group.cpp +++ b/tests/unit/test_mc_cell_group.cpp @@ -1,9 +1,9 @@ #include "../gtest.h" -#include <backends/multicore/fvm.hpp> +#include <backends.hpp> #include <common_types.hpp> #include <epoch.hpp> -#include <fvm_multicell.hpp> +#include <fvm_lowered_cell.hpp> #include <mc_cell_group.hpp> #include <util/rangeutil.hpp> @@ -12,31 +12,39 @@ #include "../simple_recipes.hpp" using namespace arb; -using fvm_cell = fvm::fvm_multicell<arb::multicore::backend>; -cell make_cell() { - auto c = make_cell_ball_and_stick(); +namespace { + fvm_lowered_cell_ptr lowered_cell() { + return make_fvm_lowered_cell(backend_kind::multicore); + } + + cell make_cell() { + auto c = make_cell_ball_and_stick(); - c.add_detector({0, 0}, 0); - c.segment(1)->set_compartments(101); + c.add_detector({0, 0}, 0); + c.segment(1)->set_compartments(101); - return c; + return c; + } } +ACCESS_BIND( + std::vector<cell_member_type> mc_cell_group::*, + private_spike_sources_ptr, + &mc_cell_group::spike_sources_) + TEST(mc_cell_group, get_kind) { - mc_cell_group<fvm_cell> group{{0}, cable1d_recipe(make_cell()) }; + mc_cell_group group{{0}, cable1d_recipe(make_cell()), lowered_cell()}; - // we are generating a mc_cell_group which should be of the correct type EXPECT_EQ(cell_kind::cable1d_neuron, group.get_cell_kind()); } TEST(mc_cell_group, test) { - mc_cell_group<fvm_cell> group{{0}, cable1d_recipe(make_cell()) }; - + mc_cell_group group{{0}, cable1d_recipe(make_cell()), lowered_cell()}; group.advance(epoch(0, 50), 0.01, 
{}); - // the model is expected to generate 4 spikes as a result of the - // fixed stimulus over 50 ms + // Model is expected to generate 4 spikes as a result of the + // fixed stimulus over 50 ms. EXPECT_EQ(4u, group.spikes().size()); } @@ -55,12 +63,12 @@ TEST(mc_cell_group, sources) { } std::vector<cell_gid_type> gids = {3u, 4u, 10u, 16u, 17u, 18u}; - mc_cell_group<fvm_cell> group{gids, cable1d_recipe(cells)}; + mc_cell_group group{gids, cable1d_recipe(cells), lowered_cell()}; // Expect group sources to be lexicographically sorted by source id // with gids in cell group's range and indices starting from zero. - const auto& sources = group.spike_sources(); + const auto& sources = group.*private_spike_sources_ptr; for (unsigned j = 0; j<sources.size(); ++j) { auto id = sources[j]; if (j==0) { diff --git a/tests/unit/test_mc_cell_group.cu b/tests/unit/test_mc_cell_group.cu deleted file mode 100644 index 8c5a2e112c6bcb1a39effc3379c54c603f8885c1..0000000000000000000000000000000000000000 --- a/tests/unit/test_mc_cell_group.cu +++ /dev/null @@ -1,34 +0,0 @@ -#include "../gtest.h" - -#include <backends/gpu/fvm.hpp> -#include <common_types.hpp> -#include <epoch.hpp> -#include <fvm_multicell.hpp> -#include <mc_cell_group.hpp> -#include <util/rangeutil.hpp> - -#include "../common_cells.hpp" -#include "../simple_recipes.hpp" - -using namespace arb; -using fvm_cell = fvm::fvm_multicell<arb::gpu::backend>; - -cell make_cell() { - auto c = make_cell_ball_and_stick(); - - c.add_detector({0, 0}, 0); - c.segment(1)->set_compartments(101); - - return c; -} - -TEST(mc_cell_group, test) -{ - mc_cell_group<fvm_cell> group({0u}, cable1d_recipe(make_cell())); - - group.advance(epoch(0, 50), 0.01, {}); - - // the model is expected to generate 4 spikes as a result of the - // fixed stimulus over 50 ms - EXPECT_EQ(4u, group.spikes().size()); -} diff --git a/tests/unit/test_mc_cell_group_gpu.cpp b/tests/unit/test_mc_cell_group_gpu.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..841db8ed2e9c6e6145e5fe134ec37c6125408a45 --- /dev/null +++ b/tests/unit/test_mc_cell_group_gpu.cpp @@ -0,0 +1,37 @@ +#include "../gtest.h" + +#include <backends.hpp> +#include <common_types.hpp> +#include <epoch.hpp> +#include <fvm_lowered_cell.hpp> +#include <mc_cell_group.hpp> + +#include "../common_cells.hpp" +#include "../simple_recipes.hpp" + +using namespace arb; + +namespace { + fvm_lowered_cell_ptr lowered_cell() { + return make_fvm_lowered_cell(backend_kind::gpu); + } + + cell make_cell() { + auto c = make_cell_ball_and_stick(); + + c.add_detector({0, 0}, 0); + c.segment(1)->set_compartments(101); + + return c; + } +} + +TEST(mc_cell_group, test) +{ + mc_cell_group group({0}, cable1d_recipe(make_cell()), lowered_cell()); + group.advance(epoch(0, 50), 0.01, {}); + + // the model is expected to generate 4 spikes as a result of the + // fixed stimulus over 50 ms + EXPECT_EQ(4u, group.spikes().size()); +} diff --git a/tests/unit/test_mechanisms.cpp b/tests/unit/test_mechanisms.cpp index a49557f7adef09d392ee3bca1898887613942827..1c991d59dc0fc001e59ba40037851776a2e7d27b 100644 --- a/tests/unit/test_mechanisms.cpp +++ b/tests/unit/test_mechanisms.cpp @@ -1,5 +1,8 @@ #include "../gtest.h" +// TODO: Amend for new mechanism architecture +#if 0 + // Prototype mechanisms in tests #include "mech_proto/expsyn_cpu.hpp" #include "mech_proto/exp2syn_cpu.hpp" @@ -237,3 +240,5 @@ using mechanism_types = ::testing::Types< >; INSTANTIATE_TYPED_TEST_CASE_P(mechanism_types, mechanisms, mechanism_types); + +#endif // 0 diff --git a/tests/unit/test_mechcat.cpp b/tests/unit/test_mechcat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f7f6db969189f922c060ca788d54396a1936b8e6 --- /dev/null +++ b/tests/unit/test_mechcat.cpp @@ -0,0 +1,312 @@ +#include <backends/fvm_types.hpp> +#include <mechanism.hpp> +#include <mechcat.hpp> +#include <mechinfo.hpp> + +#include "common.hpp" + +using namespace arb; +using namespace 
testing::string_literals; + +// Set up a small system of mechanisms and backends for testing, +// comprising: +// +// * Two mechanisms: burble and fleeb. +// +// * Two backends: foo and bar. +// +// * Three implementations of fleeb: +// - Two for the foo backend: fleeb_foo and special_fleeb_foo. +// - One for the bar backend: fleeb_bar. +// +// * One implementation of burble for the bar backend, with +// a mismatched fingerprint: burble_bar. + +// Mechanism info: + +using field_kind = mechanism_field_spec::field_kind; + +mechanism_info burble_info = { + {{"quux", {field_kind::global, "nA", 2.3, 0, 10.}}, + {"xyzzy", {field_kind::global, "mV", 5.1, -20, 20.}}}, + {}, + {}, + {}, + "burbleprint" +}; + +mechanism_info fleeb_info = { + {{"plugh", {field_kind::global, "C", 2.3, 0, 10.}}, + {"norf", {field_kind::global, "mGy", 0.1, 0, 5000.}}}, + {}, + {}, + {}, + "fleebprint" +}; + +// Backend classes: + +template <typename B> +struct common_impl: concrete_mechanism<B> { + void instantiate(fvm_size_type id, typename B::shared_state& state, const mechanism::layout& l) override { + width_ = l.cv.size(); + // Write mechanism global values to shared state to test instantiation call and catalogue global + // variable overrides. 
+ for (auto& kv: overrides_) { + state.overrides.insert(kv); + } + } + + std::size_t memory() const override { return 10u; } + std::size_t size() const override { return width_; } + + void set_parameter(const std::string& key, const std::vector<fvm_value_type>& vs) override {} + + void set_global(const std::string& key, fvm_value_type v) override { + overrides_[key] = v; + } + + void nrn_init() override {} + void nrn_state() override {} + void nrn_current() override {} + void deliver_events() override {} + void write_ions() override {} + + std::unordered_map<std::string, fvm_value_type> overrides_; + std::size_t width_ = 0; +}; + +struct foo_backend { + struct shared_state { + std::unordered_map<std::string, fvm_value_type> overrides; + }; +}; + +using foo_mechanism = common_impl<foo_backend>; + +struct bar_backend { + struct shared_state { + std::unordered_map<std::string, fvm_value_type> overrides; + }; +}; + +using bar_mechanism = common_impl<bar_backend>; + +// Fleeb implementations: + +struct fleeb_foo: foo_mechanism { + const mechanism_fingerprint& fingerprint() const override { + static mechanism_fingerprint hash = "fleebprint"; + return hash; + } + + std::string internal_name() const override { return "fleeb"; } + mechanismKind kind() const override { return mechanismKind::density; } + mechanism_ptr clone() const override { return mechanism_ptr(new fleeb_foo()); } +}; + +struct special_fleeb_foo: foo_mechanism { + const mechanism_fingerprint& fingerprint() const override { + static mechanism_fingerprint hash = "fleebprint"; + return hash; + } + + std::string internal_name() const override { return "special fleeb"; } + mechanismKind kind() const override { return mechanismKind::density; } + mechanism_ptr clone() const override { return mechanism_ptr(new special_fleeb_foo()); } +}; + +struct fleeb_bar: bar_mechanism { + const mechanism_fingerprint& fingerprint() const override { + static mechanism_fingerprint hash = "fleebprint"; + return hash; + } + + 
std::string internal_name() const override { return "fleeb"; } + mechanismKind kind() const override { return mechanismKind::density; } + mechanism_ptr clone() const override { return mechanism_ptr(new fleeb_bar()); } +}; + +// Burble implementation: + +struct burble_bar: bar_mechanism { + const mechanism_fingerprint& fingerprint() const override { + static mechanism_fingerprint hash = "fnord"; + return hash; + } + + std::string internal_name() const override { return "burble"; } + mechanismKind kind() const override { return mechanismKind::density; } + mechanism_ptr clone() const override { return mechanism_ptr(new burble_bar()); } +}; + +// Implementation register helper: + +template <typename B, typename M> +std::unique_ptr<concrete_mechanism<B>> make_mech() { + return std::unique_ptr<concrete_mechanism<B>>(new M()); +} + +// Mechinfo equality test: + +namespace arb { +static bool operator==(const mechanism_field_spec& a, const mechanism_field_spec& b) { + return a.kind==b.kind && a.units==b.units && a.default_value==b.default_value && a.lower_bound==b.lower_bound && a.upper_bound==b.upper_bound; +} + +static bool operator==(const ion_dependency& a, const ion_dependency& b) { + return a.write_concentration_int==b.write_concentration_int && a.write_concentration_ext==b.write_concentration_ext; +} + +static bool operator==(const mechanism_info& a, const mechanism_info& b) { + return a.globals==b.globals && a.parameters==b.parameters && a.state==b.state && a.ions==b.ions && a.fingerprint==b.fingerprint; +} +} + +mechanism_catalogue build_fake_catalogue() { + mechanism_catalogue cat; + + cat.add("fleeb", fleeb_info); + cat.add("burble", burble_info); + + // Add derived versions with global overrides: + + cat.derive("fleeb1", "fleeb", {{"plugh", 1.0}}); + cat.derive("special_fleeb", "fleeb", {{"plugh", 2.0}}); + cat.derive("fleeb2", "special_fleeb", {{"norf", 11.0}}); + cat.derive("bleeble", "burble", {{"quux", 10.}, {"xyzzy", -20.}}); + + // Attach implementations: 
+ + cat.register_implementation<bar_backend>("fleeb", make_mech<bar_backend, fleeb_bar>()); + cat.register_implementation<foo_backend>("fleeb", make_mech<foo_backend, fleeb_foo>()); + cat.register_implementation<foo_backend>("special_fleeb", make_mech<foo_backend, special_fleeb_foo>()); + + return cat; +} + +TEST(mechcat, fingerprint) { + auto cat = build_fake_catalogue(); + + EXPECT_EQ("fleebprint", cat.fingerprint("fleeb")); + EXPECT_EQ("fleebprint", cat.fingerprint("special_fleeb")); + EXPECT_EQ("burbleprint", cat.fingerprint("burble")); + EXPECT_EQ("burbleprint", cat.fingerprint("bleeble")); + + EXPECT_THROW(cat.register_implementation<bar_backend>("burble", make_mech<bar_backend, burble_bar>()), + std::invalid_argument); +} + +TEST(mechcat, derived_info) { + auto cat = build_fake_catalogue(); + + EXPECT_EQ(fleeb_info, cat["fleeb"]); + EXPECT_EQ(burble_info, cat["burble"]); + + mechanism_info expected_special_fleeb = fleeb_info; + expected_special_fleeb.globals["plugh"].default_value = 2.0; + EXPECT_EQ(expected_special_fleeb, cat["special_fleeb"]); + + mechanism_info expected_fleeb2 = fleeb_info; + expected_fleeb2.globals["plugh"].default_value = 2.0; + expected_fleeb2.globals["norf"].default_value = 11.0; + EXPECT_EQ(expected_fleeb2, cat["fleeb2"]); +} + +TEST(mechcat, queries) { + auto cat = build_fake_catalogue(); + + EXPECT_TRUE(cat.has("fleeb")); + EXPECT_TRUE(cat.has("special_fleeb")); + EXPECT_TRUE(cat.has("fleeb1")); + EXPECT_TRUE(cat.has("fleeb2")); + EXPECT_TRUE(cat.has("burble")); + EXPECT_TRUE(cat.has("bleeble")); + EXPECT_FALSE(cat.has("corge")); + + EXPECT_TRUE(cat.is_derived("special_fleeb")); + EXPECT_TRUE(cat.is_derived("fleeb1")); + EXPECT_TRUE(cat.is_derived("fleeb2")); + EXPECT_TRUE(cat.is_derived("bleeble")); + EXPECT_FALSE(cat.is_derived("fleeb")); + EXPECT_FALSE(cat.is_derived("burble")); +} + +TEST(mechcat, remove) { + auto cat = build_fake_catalogue(); + + cat.remove("special_fleeb"); + EXPECT_TRUE(cat.has("fleeb")); + 
EXPECT_TRUE(cat.has("fleeb1")); + EXPECT_FALSE(cat.has("special_fleeb")); + EXPECT_FALSE(cat.has("fleeb2")); // fleeb2 derived from special_fleeb. +} + +TEST(mechcat, instance) { + auto cat = build_fake_catalogue(); + + EXPECT_THROW(cat.instance<bar_backend>("burble"), std::invalid_argument); + + // All fleebs on the bar backend have the same implementation: + + auto fleeb_bar_mech = cat.instance<bar_backend>("fleeb"); + auto fleeb1_bar_mech = cat.instance<bar_backend>("fleeb1"); + auto special_fleeb_bar_mech = cat.instance<bar_backend>("special_fleeb"); + auto fleeb2_bar_mech = cat.instance<bar_backend>("fleeb2"); + + EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb_bar_mech.get())); + EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb1_bar_mech.get())); + EXPECT_EQ(typeid(fleeb_bar), typeid(*special_fleeb_bar_mech.get())); + EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb2_bar_mech.get())); + + EXPECT_EQ("fleeb"_s, fleeb2_bar_mech->internal_name()); + + // special_fleeb and fleeb2 (deriving from special_fleeb) have a specialized + // implementation: + + auto fleeb_foo_mech = cat.instance<foo_backend>("fleeb"); + auto fleeb1_foo_mech = cat.instance<foo_backend>("fleeb1"); + auto special_fleeb_foo_mech = cat.instance<foo_backend>("special_fleeb"); + auto fleeb2_foo_mech = cat.instance<foo_backend>("fleeb2"); + + EXPECT_EQ(typeid(fleeb_foo), typeid(*fleeb_foo_mech.get())); + EXPECT_EQ(typeid(fleeb_foo), typeid(*fleeb1_foo_mech.get())); + EXPECT_EQ(typeid(special_fleeb_foo), typeid(*special_fleeb_foo_mech.get())); + EXPECT_EQ(typeid(special_fleeb_foo), typeid(*fleeb2_foo_mech.get())); + + EXPECT_EQ("fleeb"_s, fleeb1_foo_mech->internal_name()); + EXPECT_EQ("special fleeb"_s, fleeb2_foo_mech->internal_name()); +} + +TEST(mechcat, instantiate) { + // Note: instantiating a mechanism doesn't normally have that mechanism + // write its specialized global variables to shared state, but we do in + // these tests for testing purposes. 
+ + mechanism::layout layout = {{0u, 1u, 2u}, {1., 2., 1.}}; + bar_backend::shared_state bar_state; + + auto cat = build_fake_catalogue(); + + cat.instance<bar_backend>("fleeb")->instantiate(0, bar_state, layout); + EXPECT_TRUE(bar_state.overrides.empty()); + + bar_state.overrides.clear(); + cat.instance<bar_backend>("fleeb2")->instantiate(0, bar_state, layout); + EXPECT_EQ(2.0, bar_state.overrides.at("plugh")); + EXPECT_EQ(11.0, bar_state.overrides.at("norf")); +} + +TEST(mechcat, copy) { + auto cat = build_fake_catalogue(); + mechanism_catalogue cat2 = cat; + + EXPECT_EQ(cat["fleeb2"], cat2["fleeb2"]); + + auto fleeb2_instance = cat.instance<foo_backend>("fleeb2"); + auto fleeb2_instance2 = cat2.instance<foo_backend>("fleeb2"); + + EXPECT_EQ(typeid(*fleeb2_instance.get()), typeid(*fleeb2_instance2.get())); // original vs. copied catalogue instance +} + + diff --git a/tests/unit/test_mechinfo.cpp b/tests/unit/test_mechinfo.cpp index 1e92ba0842a20b1bd4d6e7a501330e1c3ac3591a..ab8329248fe007cf75901d06ac5ae3de2d42f85c 100644 --- a/tests/unit/test_mechinfo.cpp +++ b/tests/unit/test_mechinfo.cpp @@ -2,25 +2,24 @@ #include <string> #include <vector> -#include "mechinfo.hpp" +#include <cell.hpp> +//#include "mechinfo.hpp" #include "../gtest.h" #include "../test_util.hpp" -// TODO: expand tests when we have exported mechanism schemata -// from modcc. +// TODO: This test is really checking part of the recipe description +// for cable1d cells, so move it there. Make actual tests for mechinfo +// here! 
using namespace arb; -TEST(mechanism_spec, setting) { - mechanism_spec m("foo"); +TEST(mechanism_desc, setting) { + mechanism_desc m("foo"); m.set("a", 3.2); m.set("b", 4.3); - auto dflt = m["c"]; - EXPECT_EQ(0., dflt); // note: 0 default is artefact of dummy schema - EXPECT_EQ(3.2, m["a"]); EXPECT_EQ(4.3, m["b"]); diff --git a/tests/unit/test_multi_event_stream.cu b/tests/unit/test_multi_event_stream_gpu.cpp similarity index 92% rename from tests/unit/test_multi_event_stream.cu rename to tests/unit/test_multi_event_stream_gpu.cpp index 7e1d171915295657a12cd7794f255d2097710678..62bffa02f358760e10e3a841a823ca191fd0f254 100644 --- a/tests/unit/test_multi_event_stream.cu +++ b/tests/unit/test_multi_event_stream_gpu.cpp @@ -2,12 +2,10 @@ #include <random> #include <vector> -#include <cuda.h> #include "../gtest.h" #include <backends/event.hpp> #include <backends/gpu/multi_event_stream.hpp> -#include <backends/gpu/time_ops.hpp> #include <memory/wrappers.hpp> #include <util/rangeutil.hpp> @@ -65,33 +63,20 @@ TEST(multi_event_stream, init) { EXPECT_TRUE(m.empty()); } -__global__ -void copy_marked_events_kernel( +// CUDA kernel wrapper: +void run_copy_marked_events_kernel( unsigned ci, deliverable_event_stream::state state, deliverable_event_data* store, unsigned& count, - unsigned max_ev) -{ - // use only one thread here - if (threadIdx.x || blockIdx.x) return; - - unsigned k = 0; - auto begin = state.ev_data+state.begin_offset[ci]; - auto end = state.ev_data+state.end_offset[ci]; - for (auto p = begin; p<end; ++p) { - if (k>=max_ev) break; - store[k++] = *p; - } - count = k; -} + unsigned max_ev); std::vector<deliverable_event_data> copy_marked_events(int ci, deliverable_event_stream& m) { unsigned max_ev = 1000; memory::device_vector<deliverable_event_data> store(max_ev); memory::device_vector<unsigned> counter(1); - copy_marked_events_kernel<<<1,1>>>(ci, m.marked_events(), store.data(), *counter.data(), max_ev); + run_copy_marked_events_kernel(ci, m.marked_events(), 
store.data(), *counter.data(), max_ev); unsigned n_ev = counter[0]; std::vector<deliverable_event_data> ev(n_ev); memory::copy(store(0, n_ev), ev); diff --git a/tests/unit/test_multi_event_stream_gpu.cu b/tests/unit/test_multi_event_stream_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..ed9bd83a0374cdebf37e71e74ea6385d14987912 --- /dev/null +++ b/tests/unit/test_multi_event_stream_gpu.cu @@ -0,0 +1,40 @@ +#include <backends/event.hpp> +#include <backends/multi_event_stream_state.hpp> + +using namespace arb; + +using stream_state = multi_event_stream_state<deliverable_event_data>; + +namespace kernel { +__global__ +void copy_marked_events_kernel( + unsigned ci, + stream_state state, + deliverable_event_data* store, + unsigned& count, + unsigned max_ev) +{ + // use only one thread here + if (threadIdx.x || blockIdx.x) return; + + unsigned k = 0; + auto begin = state.ev_data+state.begin_offset[ci]; + auto end = state.ev_data+state.end_offset[ci]; + for (auto p = begin; p<end; ++p) { + if (k>=max_ev) break; + store[k++] = *p; + } + count = k; +} +} + +void run_copy_marked_events_kernel( + unsigned ci, + stream_state state, + deliverable_event_data* store, + unsigned& count, + unsigned max_ev) +{ + kernel::copy_marked_events_kernel<<<1,1>>>(ci, state, store, count, max_ev); +} + diff --git a/tests/unit/test_padded.cpp b/tests/unit/test_padded.cpp index 4bdb453d86f9f2128ae6cd72d8c982ecf859afb4..ebe2049be6a34be395f1f4ebf8605223a9111560 100644 --- a/tests/unit/test_padded.cpp +++ b/tests/unit/test_padded.cpp @@ -1,10 +1,5 @@ #include <cstdint> -#if (__GLIBC__==2) -#include <malloc.h> -#define INSTRUMENT_MALLOC -#endif - #include <util/padded_alloc.hpp> #include "../gtest.h" @@ -52,23 +47,24 @@ TEST(padded_vector, allocator_propagation) { EXPECT_EQ(1u, pb.alignment()); EXPECT_NE(pa, pb); - // Don't propagate on copy- or move-assignment: + // Propagate on copy- or move-assignment: b = a; - EXPECT_EQ(pb.alignment(), b.get_allocator().alignment()); 
- EXPECT_NE(pb.alignment(), pa.alignment()); + EXPECT_NE(pb.alignment(), b.get_allocator().alignment()); + EXPECT_EQ(pa.alignment(), b.get_allocator().alignment()); pvector<double> c; c = std::move(a); - EXPECT_NE(c.get_allocator().alignment(), pa.alignment()); + EXPECT_EQ(c.get_allocator().alignment(), pa.alignment()); } -#ifdef INSTRUMENT_MALLOC +#ifdef CAN_INSTRUMENT_MALLOC struct alloc_data { unsigned n_malloc = 0; unsigned n_realloc = 0; unsigned n_memalign = 0; + unsigned n_free = 0; std::size_t last_malloc = -1; std::size_t last_realloc = -1; @@ -93,6 +89,10 @@ struct count_allocs: testing::with_instrumented_malloc { data.last_memalign = size; } + void on_free(void*, const void*) override { + ++data.n_free; + } + void reset() { data = alloc_data(); } @@ -113,15 +113,15 @@ TEST(padded_vector, instrumented) { EXPECT_EQ(0u, mdata.n_realloc); EXPECT_EQ(expected_v1_alloc, mdata.last_memalign); - // Move assignment: v2 has differing alignment guarantee, so cannot - // take ownership of v1's data. We expect that v2 will need to allocate. + // Move assignment: allocators propagate, so we do not expect v2 + // to perform a new allocation. pvector<double> v2p32(10, pad32); A.reset(); v2p32 = std::move(v1p256); mdata = A.data; - EXPECT_EQ(1u, mdata.n_memalign); + EXPECT_EQ(0u, mdata.n_memalign); EXPECT_EQ(0u, mdata.n_malloc); EXPECT_EQ(0u, mdata.n_realloc); @@ -148,12 +148,13 @@ TEST(padded_vector, instrumented) { EXPECT_EQ(expected_v5_alloc, mdata.last_memalign); A.reset(); - v5p32 = v3p256; // different alignment, but enough space, so shouldn't reallocate. + v5p32 = v3p256; // enough space, but different alignment, so should free and then allocate. 
mdata = A.data; - EXPECT_EQ(0u, mdata.n_memalign); + EXPECT_EQ(1u, mdata.n_free); + EXPECT_EQ(1u, mdata.n_memalign); EXPECT_EQ(0u, mdata.n_malloc); EXPECT_EQ(0u, mdata.n_realloc); } -#endif // ifdef INSTRUMENT_MALLOC +#endif // ifdef CAN_INSTRUMENT_MALLOC diff --git a/tests/unit/test_path.cpp b/tests/unit/test_path.cpp index 350f1b29ffb633fb0adcbdd1568d19caa4e7ccac..4fc98e58827dddedb643146ef9c9b385f243eacd 100644 --- a/tests/unit/test_path.cpp +++ b/tests/unit/test_path.cpp @@ -68,8 +68,8 @@ TEST(path, posix_native) { EXPECT_EQ(qs, qs_bis); // cstr - const char *c = posix_path{ps}.c_str(); - EXPECT_TRUE(!std::strcmp(c, ps.c_str())); + posix_path ps_path{ps}; + EXPECT_TRUE(!std::strcmp(ps_path.c_str(), ps.c_str())); } TEST(path, posix_generic) { diff --git a/tests/unit/test_probe.cpp b/tests/unit/test_probe.cpp index 38a279e3d90b1cf46a09e02932f30a4134cc6143..92ff46dd5e33c7feb1f7487c77fc5d780ee46eb3 100644 --- a/tests/unit/test_probe.cpp +++ b/tests/unit/test_probe.cpp @@ -1,20 +1,23 @@ #include "../gtest.h" +#include <backends/event.hpp> #include <backends/multicore/fvm.hpp> #include <common_types.hpp> #include <cell.hpp> -#include <fvm_multicell.hpp> +#include <fvm_lowered_cell_impl.hpp> #include <util/rangeutil.hpp> +#include "common.hpp" #include "../common_cells.hpp" #include "../simple_recipes.hpp" using namespace arb; +using fvm_cell = fvm_lowered_cell_impl<multicore::backend>; +using shared_state = multicore::backend::shared_state; -TEST(probe, fvm_multicell) -{ - using fvm_cell = fvm::fvm_multicell<arb::multicore::backend>; +ACCESS_BIND(std::unique_ptr<shared_state> fvm_cell::*, fvm_state_ptr, &fvm_cell::state_); +TEST(probe, fvm_lowered_cell) { cell bs = make_cell_ball_and_stick(false); i_clamp stim(0, 100, 0.3); @@ -30,8 +33,8 @@ TEST(probe, fvm_multicell) rec.add_probe(0, 20, cell_probe_address{loc1, cell_probe_address::membrane_voltage}); rec.add_probe(0, 30, cell_probe_address{loc2, cell_probe_address::membrane_current}); - 
std::vector<fvm_cell::target_handle> targets; - probe_association_map<fvm_cell::probe_handle> probe_map; + std::vector<target_handle> targets; + probe_association_map<probe_handle> probe_map; fvm_cell lcell; lcell.initialize({0}, rec, targets, probe_map); @@ -43,20 +46,24 @@ TEST(probe, fvm_multicell) EXPECT_EQ(20, probe_map.at({0, 1}).tag); EXPECT_EQ(30, probe_map.at({0, 2}).tag); - fvm_cell::probe_handle p0 = probe_map.at({0, 0}).handle; - fvm_cell::probe_handle p1 = probe_map.at({0, 1}).handle; - fvm_cell::probe_handle p2 = probe_map.at({0, 2}).handle; + probe_handle p0 = probe_map.at({0, 0}).handle; + probe_handle p1 = probe_map.at({0, 1}).handle; + probe_handle p2 = probe_map.at({0, 2}).handle; // Expect initial probe values to be the resting potential // for the voltage probes (cell membrane potential should // be constant), and zero for the current probe. - auto resting = lcell.voltage()[0]; + auto& state = *(lcell.*fvm_state_ptr).get(); + auto& voltage = state.voltage; + + auto resting = voltage[0]; EXPECT_NE(0.0, resting); - EXPECT_EQ(resting, lcell.probe(p0)); - EXPECT_EQ(resting, lcell.probe(p1)); - EXPECT_EQ(0.0, lcell.probe(p2)); + // (Probe handles are just pointers in this implementation). + EXPECT_EQ(resting, *p0); + EXPECT_EQ(resting, *p1); + EXPECT_EQ(0.0, *p2); // After an integration step, expect voltage probe values // to differ from resting, and between each other, and @@ -65,15 +72,13 @@ TEST(probe, fvm_multicell) // First probe, at (0,0), should match voltage in first // compartment. 
- lcell.setup_integration(0.1, 0.0025, {}, {}); - lcell.step_integration(); - lcell.step_integration(); + lcell.integrate(0.01, 0.0025, {}, {}); - EXPECT_NE(resting, lcell.probe(p0)); - EXPECT_NE(resting, lcell.probe(p1)); - EXPECT_NE(lcell.probe(p0), lcell.probe(p1)); - EXPECT_NE(0.0, lcell.probe(p2)); + EXPECT_NE(resting, *p0); + EXPECT_NE(resting, *p1); + EXPECT_NE(*p0, *p1); + EXPECT_NE(0.0, *p2); - EXPECT_EQ(lcell.voltage()[0], lcell.probe(p0)); + EXPECT_EQ(voltage[0], *p0); } diff --git a/tests/unit/test_range.cpp b/tests/unit/test_range.cpp index 956c85d643b66551499ae12e7626af678e3786cf..cb0a3c63ae794ce5c2267afec11f6bf72343b344 100644 --- a/tests/unit/test_range.cpp +++ b/tests/unit/test_range.cpp @@ -24,7 +24,7 @@ using namespace arb; -using namespace testing::string_literals; +using namespace testing::string_literals; using testing::null_terminated; using testing::nocopy; using testing::nomove; @@ -194,6 +194,40 @@ TEST(range, strictify) { EXPECT_TRUE((std::is_same<decltype(ptr_range), util::range<const char *>>::value)); EXPECT_EQ(cstr, ptr_range.left); EXPECT_EQ(cstr+11, ptr_range.right); + + std::vector<double> empty; + auto empty_vec_range = util::strict_view(empty); + EXPECT_EQ(0u, empty_vec_range.size()); + EXPECT_EQ(empty_vec_range.begin(), empty_vec_range.end()); + +} + +TEST(range, range_view) { + double a[23]; + + auto r1 = util::range_view(a); + EXPECT_EQ(std::begin(a), r1.left); + EXPECT_EQ(std::end(a), r1.right); + + std::list<int> l = {2, 3, 4}; + + auto r2 = util::range_view(l); + EXPECT_EQ(std::begin(l), r2.left); + EXPECT_EQ(std::end(l), r2.right); +} + +TEST(range, range_pointer_view) { + double a[23]; + + auto r1 = util::range_pointer_view(a); + EXPECT_EQ(&a[0], r1.left); + EXPECT_EQ(&a[0]+23, r1.right); + + std::vector<int> v = {2, 3, 4}; + + auto r2 = util::range_pointer_view(v); + EXPECT_EQ(&v[0], r2.left); + EXPECT_EQ(&v[0]+3, r2.right); } TEST(range, subrange) { @@ -465,7 +499,7 @@ TEST(range, sum_by) { TEST(range, is_sequence) { 
EXPECT_TRUE(arb::util::is_sequence<std::vector<int>>::value); EXPECT_TRUE(arb::util::is_sequence<std::string>::value); - EXPECT_TRUE(arb::util::is_sequence<int[8]>::value); + EXPECT_TRUE(arb::util::is_sequence<int (&)[8]>::value); } TEST(range, all_of_any_of) { @@ -499,41 +533,26 @@ TEST(range, all_of_any_of) { EXPECT_TRUE(util::any_of(cstr("87654x"), pred)); } -TEST(range, keys) { - { - std::map<int, double> map = {{10, 2.0}, {3, 8.0}}; - std::vector<int> expected = {3, 10}; - std::vector<int> keys = util::assign_from(util::keys(map)); - EXPECT_EQ(expected, keys); - } +TEST(range, is_sorted) { + // make a C string into a sentinel-terminated range + auto cstr = [](const char* s) { return util::make_range(s, null_terminated); }; - { - struct cmp { - bool operator()(const nocopy<int>& a, const nocopy<int>& b) const { - return a.value<b.value; - } - }; - std::map<nocopy<int>, double, cmp> map; - map.insert(std::pair<nocopy<int>, double>(11, 2.0)); - map.insert(std::pair<nocopy<int>, double>(2, 0.3)); - map.insert(std::pair<nocopy<int>, double>(2, 0.8)); - map.insert(std::pair<nocopy<int>, double>(5, 0.1)); - - std::vector<int> expected = {2, 5, 11}; - std::vector<int> keys; - for (auto& k: util::keys(map)) { - keys.push_back(k.value); - } - EXPECT_EQ(expected, keys); - } + std::vector<int> s1 = {1, 2, 2, 3, 4}; + std::vector<int> s2 = {}; - { - std::unordered_multimap<int, double> map = {{3, 0.1}, {5, 0.4}, {11, 0.8}, {5, 0.2}}; - std::vector<int> expected = {3, 5, 5, 11}; - std::vector<int> keys = util::assign_from(util::keys(map)); - util::sort(keys); - EXPECT_EQ(expected, keys); - } + std::vector<int> u1 = {1, 2, 2, 1, 4}; + + using ivec = std::vector<int>; + + EXPECT_TRUE(util::is_sorted(ivec{})); + EXPECT_TRUE(util::is_sorted(ivec({1,2,2,3,4}))); + EXPECT_FALSE(util::is_sorted(ivec({1,2,2,1,4}))); + + EXPECT_TRUE(util::is_sorted(cstr("abccd"))); + EXPECT_TRUE(util::is_sorted("abccd"_s)); + + EXPECT_FALSE(util::is_sorted(cstr("hello"))); + 
EXPECT_FALSE(util::is_sorted("hello"_s)); } template <typename C> @@ -633,6 +652,17 @@ TEST(range, is_sorted_by) { EXPECT_TRUE(util::is_sorted_by(seq, [](int x) { return x+2; }, std::greater<int>{})); } +TEST(range, reverse) { + // make a C string into a sentinel-terminated range + auto cstr = [](const char* s) { return util::make_range(s, null_terminated); }; + + std::string rev; + util::assign(rev, util::reverse_view(cstr("hello"))); + + EXPECT_EQ("olleh"_s, rev); +} + + #ifdef ARB_HAVE_TBB TEST(range, tbb_split) { diff --git a/tests/unit/test_reduce_by_key.cu b/tests/unit/test_reduce_by_key.cu index a226a5584b041a714e080bdc66d41d318f518f70..ad662fe07246af1290a9fde9fc7a3b096966a2e4 100644 --- a/tests/unit/test_reduce_by_key.cu +++ b/tests/unit/test_reduce_by_key.cu @@ -2,7 +2,7 @@ #include <vector> -#include <backends/gpu/kernels/reduce_by_key.hpp> +#include <backends/gpu/reduce_by_key.hpp> #include <memory/memory.hpp> #include <util/span.hpp> #include <util/rangeutil.hpp> diff --git a/tests/unit/test_schedule.cpp b/tests/unit/test_schedule.cpp index 0281d647baaa9456be6947e11374329582a55811..504d50cf7172bd40216c41ce31c2f1801a3e64da 100644 --- a/tests/unit/test_schedule.cpp +++ b/tests/unit/test_schedule.cpp @@ -177,7 +177,7 @@ double poisson_schedule_dispersion(int nbin, double mean_dt, RNG& G) { // random sequence were allowed to vary freely. TEST(schedule, poisson_uniformity) { - // Run Poisson dispersion test for N=101 with two-sided + // Run Poisson dispersion test for N=1001 with two-sided // χ²-test critical value α=0.01. 
// // Test based on: N·dispersion ~ χ²(N-1) (approximately) diff --git a/tests/unit/test_simd.cpp b/tests/unit/test_simd.cpp index f22de961aafac7ac946e7529d0dbcf37e793b327..1cec58f73214825274a7eff32d36fcd62dee9a3b 100644 --- a/tests/unit/test_simd.cpp +++ b/tests/unit/test_simd.cpp @@ -1,5 +1,9 @@ +#include <algorithm> +#include <array> #include <cmath> +#include <iterator> #include <random> +#include <unordered_set> #include <simd/simd.hpp> #include <simd/avx.hpp> @@ -510,7 +514,55 @@ TYPED_TEST_P(simd_value, maths) { } } -REGISTER_TYPED_TEST_CASE_P(simd_value, elements, element_lvalue, copy_to_from, copy_to_from_masked, construct_masked, arithmetic, compound_assignment, comparison, mask_elements, mask_element_lvalue, mask_copy_to_from, mask_unpack, maths); +TYPED_TEST_P(simd_value, reductions) { + // Only addition for now. + + using simd = TypeParam; + using scalar = typename simd::scalar_type; + constexpr unsigned N = simd::width; + + std::minstd_rand rng(1041); + + for (unsigned i = 0; i<nrounds; ++i) { + scalar a[N], test = 0; + + // To avoid discrepancies due to catastrophic cancelation, + // keep f.p. values non-negative. + + if (std::is_floating_point<scalar>::value) { + fill_random(a, rng, 0, 1); + } + else { + fill_random(a, rng); + } + + simd as(a); + + for (unsigned j = 0; j<N; ++j) { test += a[j]; } + EXPECT_TRUE(testing::almost_eq(test, as.sum())); + } +} + +TYPED_TEST_P(simd_value, simd_array_cast) { + // Test conversion to/from array of scalar type. 
+ + using simd = TypeParam; + using scalar = typename simd::scalar_type; + constexpr unsigned N = simd::width; + + std::minstd_rand rng(1032); + + for (unsigned i = 0; i<nrounds; ++i) { + std::array<scalar, N> a; + + fill_random(a, rng); + simd as = simd_cast<simd>(a); + EXPECT_TRUE(testing::indexed_eq_n(N, as, a)); + EXPECT_TRUE(testing::seq_eq(a, simd_cast<std::array<scalar, N>>(as))); + } +} + +REGISTER_TYPED_TEST_CASE_P(simd_value, elements, element_lvalue, copy_to_from, copy_to_from_masked, construct_masked, arithmetic, compound_assignment, comparison, mask_elements, mask_element_lvalue, mask_copy_to_from, mask_unpack, maths, simd_array_cast, reductions); typedef ::testing::Types< @@ -841,17 +893,16 @@ TYPED_TEST_P(simd_indirect, gather) { for (unsigned i = 0; i<nrounds; ++i) { scalar array[buflen]; - index indirect[N]; + index offset[N]; fill_random(array, rng); - fill_random(indirect, rng, 0, (int)(buflen-1)); + fill_random(offset, rng, 0, (int)(buflen-1)); - simd s; - s.gather(array, simd_index(indirect)); + simd s(indirect(array, simd_index(offset))); scalar test[N]; for (unsigned j = 0; j<N; ++j) { - test[j] = array[indirect[j]]; + test[j] = array[offset[j]]; } EXPECT_TRUE(::testing::indexed_eq_n(N, test, s)); @@ -873,21 +924,21 @@ TYPED_TEST_P(simd_indirect, masked_gather) { for (unsigned i = 0; i<nrounds; ++i) { scalar array[buflen], original[N], test[N]; - index indirect[N]; + index offset[N]; bool mask[N]; fill_random(array, rng); fill_random(original, rng); - fill_random(indirect, rng, 0, (int)(buflen-1)); + fill_random(offset, rng, 0, (int)(buflen-1)); fill_random(mask, rng); for (unsigned j = 0; j<N; ++j) { - test[j] = mask[j]? array[indirect[j]]: original[j]; + test[j] = mask[j]? 
array[offset[j]]: original[j]; } simd s(original); simd_mask m(mask); - where(m, s).gather(array, simd_index(indirect)); + where(m, s).copy_from(indirect(array, simd_index(offset))); EXPECT_TRUE(::testing::indexed_eq_n(N, test, s)); } @@ -907,21 +958,21 @@ TYPED_TEST_P(simd_indirect, scatter) { for (unsigned i = 0; i<nrounds; ++i) { scalar array[buflen], test[buflen], values[N]; - index indirect[N]; + index offset[N]; fill_random(array, rng); fill_random(values, rng); - fill_random(indirect, rng, 0, (int)(buflen-1)); + fill_random(offset, rng, 0, (int)(buflen-1)); for (unsigned j = 0; j<buflen; ++j) { test[j] = array[j]; } for (unsigned j = 0; j<N; ++j) { - test[indirect[j]] = values[j]; + test[offset[j]] = values[j]; } simd s(values); - s.scatter(array, simd_index(indirect)); + s.copy_to(indirect(array, simd_index(offset))); EXPECT_TRUE(::testing::indexed_eq_n(N, test, array)); } @@ -942,47 +993,182 @@ TYPED_TEST_P(simd_indirect, masked_scatter) { for (unsigned i = 0; i<nrounds; ++i) { scalar array[buflen], test[buflen], values[N]; - index indirect[N]; + index offset[N]; bool mask[N]; fill_random(array, rng); fill_random(values, rng); - fill_random(indirect, rng, 0, (int)(buflen-1)); + fill_random(offset, rng, 0, (int)(buflen-1)); fill_random(mask, rng); for (unsigned j = 0; j<buflen; ++j) { test[j] = array[j]; } for (unsigned j = 0; j<N; ++j) { - if (mask[j]) { test[indirect[j]] = values[j]; } + if (mask[j]) { test[offset[j]] = values[j]; } } simd s(values); simd_mask m(mask); - where(m, s).scatter(array, simd_index(indirect)); + where(m, s).copy_to(indirect(array, simd_index(offset))); + + EXPECT_TRUE(::testing::indexed_eq_n(N, test, array)); + } +} + +TYPED_TEST_P(simd_indirect, add_and_subtract) { + using simd = typename TypeParam::simd; + using simd_index = typename TypeParam::simd_index; + + constexpr unsigned N = simd::width; + using scalar = typename simd::scalar_type; + using index = typename simd_index::scalar_type; + + std::minstd_rand rng(1011); + + 
constexpr std::size_t buflen = 1000; + + for (unsigned i = 0; i<nrounds; ++i) { + scalar array[buflen], test[buflen], values[N]; + index offset[N]; + + fill_random(array, rng); + fill_random(values, rng); + fill_random(offset, rng, 0, (int)(buflen-1)); + + for (unsigned j = 0; j<buflen; ++j) { + test[j] = array[j]; + } + for (unsigned j = 0; j<N; ++j) { + test[offset[j]] += values[j]; + } + indirect(array, simd_index(offset)) += simd(values); + EXPECT_TRUE(::testing::indexed_eq_n(N, test, array)); + + fill_random(offset, rng, 0, (int)(buflen-1)); + + for (unsigned j = 0; j<buflen; ++j) { + test[j] = array[j]; + } + for (unsigned j = 0; j<N; ++j) { + test[offset[j]] -= values[j]; + } + + indirect(array, simd_index(offset)) -= simd(values); EXPECT_TRUE(::testing::indexed_eq_n(N, test, array)); } } +template <typename X> +bool unique_elements(const X& xs) { + using std::begin; + std::unordered_set<typename std::decay<decltype(*begin(xs))>::type> set; + for (auto& x: xs) { + if (!set.insert(x).second) return false; + } + return true; +} + +TYPED_TEST_P(simd_indirect, constrained_add) { + using simd = typename TypeParam::simd; + using simd_index = typename TypeParam::simd_index; + + constexpr unsigned N = simd::width; + using scalar = typename simd::scalar_type; + using index = typename simd_index::scalar_type; + + std::minstd_rand rng(1011); + + constexpr std::size_t buflen = 1000; + + for (unsigned i = 0; i<nrounds; ++i) { + scalar array[buflen], test[buflen], values[N]; + index offset[N]; + + fill_random(array, rng); + fill_random(values, rng); + + auto make_test_array = [&]() { + for (unsigned j = 0; j<buflen; ++j) { + test[j] = array[j]; + } + for (unsigned j = 0; j<N; ++j) { + test[offset[j]] += values[j]; + } + }; + + // Independent: + + do { + fill_random(offset, rng, 0, (int)(buflen-1)); + } while (!unique_elements(offset)); + + make_test_array(); + indirect(array, simd_index(offset), index_constraint::independent) += simd(values); + + 
EXPECT_TRUE(::testing::indexed_eq_n(N, test, array)); + + // Contiguous: + + offset[0] = make_udist<index>(0, (int)(buflen)-N)(rng); + for (unsigned j = 1; j<N; ++j) { + offset[j] = offset[0]+j; + } + + make_test_array(); + indirect(array, simd_index(offset), index_constraint::contiguous) += simd(values); -REGISTER_TYPED_TEST_CASE_P(simd_indirect, gather, masked_gather, scatter, masked_scatter); + EXPECT_TRUE(::testing::indexed_eq_n(N, test, array)); + + // Constant: + + for (unsigned j = 1; j<N; ++j) { + offset[j] = offset[0]; + } + + // Reduction can be done in a different order, so 1) use approximate test + // and 2) keep f.p. values non-negative to avoid catastrophic cancellation. + + if (std::is_floating_point<scalar>::value) { + fill_random(array, rng, 0, 1); + fill_random(values, rng, 0, 1); + } + + make_test_array(); + indirect(array, simd_index(offset), index_constraint::constant) += simd(values); + + EXPECT_TRUE(::testing::indexed_almost_eq_n(N, test, array)); + + } +} + +REGISTER_TYPED_TEST_CASE_P(simd_indirect, gather, masked_gather, scatter, masked_scatter, add_and_subtract, constrained_add); typedef ::testing::Types< #ifdef __AVX__ simd_and_index<simd<double, 4, simd_abi::avx>, simd<int, 4, simd_abi::avx>>, + + simd_and_index<simd<int, 4, simd_abi::avx>, + simd<int, 4, simd_abi::avx>>, #endif #ifdef __AVX2__ simd_and_index<simd<double, 4, simd_abi::avx2>, simd<int, 4, simd_abi::avx2>>, + + simd_and_index<simd<int, 4, simd_abi::avx2>, + simd<int, 4, simd_abi::avx2>>, #endif #ifdef __AVX512F__ simd_and_index<simd<double, 8, simd_abi::avx512>, simd<int, 8, simd_abi::avx512>>, + + simd_and_index<simd<int, 8, simd_abi::avx512>, + simd<int, 8, simd_abi::avx512>>, #endif simd_and_index<simd<float, 4, simd_abi::generic>, @@ -999,3 +1185,73 @@ typedef ::testing::Types< > simd_indirect_test_types; INSTANTIATE_TYPED_TEST_CASE_P(S, simd_indirect, simd_indirect_test_types); + + +// SIMD cast tests + +template <typename A, typename B> +struct simd_pair { + using 
simd_first = A; + using simd_second = B; +}; + +template <typename SI> +struct simd_casting: public ::testing::Test {}; + +TYPED_TEST_CASE_P(simd_casting); + +TYPED_TEST_P(simd_casting, cast) { + using simd_x = typename TypeParam::simd_first; + using simd_y = typename TypeParam::simd_second; + + constexpr unsigned N = simd_x::width; + using scalar_x = typename simd_x::scalar_type; + using scalar_y = typename simd_y::scalar_type; + + std::minstd_rand rng(1011); + + for (unsigned i = 0; i<nrounds; ++i) { + scalar_x x[N], test_x[N]; + scalar_y y[N], test_y[N]; + + fill_random(x, rng); + fill_random(y, rng); + + for (unsigned j = 0; j<N; ++j) { + test_y[j] = static_cast<scalar_y>(x[j]); + test_x[j] = static_cast<scalar_x>(y[j]); + } + + simd_x xs(x); + simd_y ys(y); + + EXPECT_TRUE(testing::indexed_eq_n(N, test_y, simd_cast<simd_y>(xs))); + EXPECT_TRUE(testing::indexed_eq_n(N, test_x, simd_cast<simd_x>(ys))); + } +} + +REGISTER_TYPED_TEST_CASE_P(simd_casting, cast); + + +typedef ::testing::Types< + +#ifdef __AVX__ + simd_pair<simd<double, 4, simd_abi::avx>, + simd<int, 4, simd_abi::avx>>, +#endif + +#ifdef __AVX2__ + simd_pair<simd<double, 4, simd_abi::avx2>, + simd<int, 4, simd_abi::avx2>>, +#endif + +#ifdef __AVX512F__ + simd_pair<simd<double, 8, simd_abi::avx512>, + simd<int, 8, simd_abi::avx512>>, +#endif + + simd_pair<simd<double, 4, simd_abi::default_abi>, + simd<float, 4, simd_abi::default_abi>> +> simd_casting_test_types; + +INSTANTIATE_TYPED_TEST_CASE_P(S, simd_casting, simd_casting_test_types); diff --git a/tests/unit/test_spikes.cpp b/tests/unit/test_spikes.cpp index e8df09d59978639606892c4f65d04e7780d81ec8..2cc3850917915bb5cabc508b5011b442a457abd7 100644 --- a/tests/unit/test_spikes.cpp +++ b/tests/unit/test_spikes.cpp @@ -7,7 +7,7 @@ using namespace arb; -// This source is included in `test_spikes.cu`, which defines +// This source is included in `test_spikes_gpu.cpp`, which defines // USE_BACKEND to override the default `multicore::backend` // used for 
CPU tests. @@ -18,17 +18,17 @@ using backend = USE_BACKEND; #endif TEST(spikes, threshold_watcher) { - using size_type = backend::size_type; using value_type = backend::value_type; + using index_type = backend::index_type; using array = backend::array; using iarray = backend::iarray; - using list = backend::threshold_watcher::crossing_list; + using list = std::vector<threshold_crossing>; // the test creates a watch on 3 values in the array values (which has 10 // elements in total). const auto n = 10; - const std::vector<size_type> index{0, 5, 7}; + const std::vector<index_type> index{0, 5, 7}; const std::vector<value_type> thresh{1., 2., 3.}; // all values are initially 0, except for values[5] which we set @@ -50,7 +50,7 @@ TEST(spikes, threshold_watcher) { list expected; // create the watch - backend::threshold_watcher watch(cell_index, time_before, time_after, values, index, thresh); + backend::threshold_watcher watch(cell_index.data(), time_before.data(), time_after.data(), values.data(), index, thresh); // initially the first and third watch should not be spiking // the second is spiking diff --git a/tests/unit/test_spikes.cu b/tests/unit/test_spikes_gpu.cpp similarity index 100% rename from tests/unit/test_spikes.cu rename to tests/unit/test_spikes_gpu.cpp diff --git a/tests/unit/test_stimulus.cpp b/tests/unit/test_stimulus.cpp deleted file mode 100644 index b706767c519e426ddd5a2afd8a3a3556d8618427..0000000000000000000000000000000000000000 --- a/tests/unit/test_stimulus.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include "../gtest.h" - -#include <stimulus.hpp> - -TEST(stimulus, i_clamp) -{ - using namespace arb; - - // stimulus with delay 2, duration 0.5, amplitude 6.0 - i_clamp stim(2.0, 0.5, 6.0); - - EXPECT_EQ(stim.delay(), 2.0); - EXPECT_EQ(stim.duration(), 0.5); - EXPECT_EQ(stim.amplitude(), 6.0); - - // test that current only turned on in the half open interval - // t \in [2, 2.5) - EXPECT_EQ(stim.amplitude(0.0), 0.0); - EXPECT_EQ(stim.amplitude(1.0), 0.0); - 
EXPECT_EQ(stim.amplitude(2.0), 6.0); - EXPECT_EQ(stim.amplitude(2.4999), 6.0); - EXPECT_EQ(stim.amplitude(2.5), 0.0); - - // update: delay 1.0, duration 1.5, amplitude 3.0 - stim.set_delay(1.0); - stim.set_duration(1.5); - stim.set_amplitude(3.0); - - EXPECT_EQ(stim.delay(), 1.0); - EXPECT_EQ(stim.duration(), 1.5); - EXPECT_EQ(stim.amplitude(), 3.0); -} - diff --git a/tests/unit/test_swcio.cpp b/tests/unit/test_swcio.cpp index a27572ee3332bdb98e5bc7a074b0ff2d95b2663b..6f415c39f13e2ef902982112062f85f310cf68fc 100644 --- a/tests/unit/test_swcio.cpp +++ b/tests/unit/test_swcio.cpp @@ -473,9 +473,9 @@ TEST(swc_io, cell_construction) { EXPECT_TRUE(cell.has_soma()); EXPECT_EQ(4u, cell.num_segments()); - EXPECT_EQ(norm(points[1]-points[2]), cell.cable(1)->length()); - EXPECT_EQ(norm(points[2]-points[3]), cell.cable(2)->length()); - EXPECT_EQ(norm(points[2]-points[4]) + norm(points[4]-points[5]), + EXPECT_DOUBLE_EQ(norm(points[1]-points[2]), cell.cable(1)->length()); + EXPECT_DOUBLE_EQ(norm(points[2]-points[3]), cell.cable(2)->length()); + EXPECT_DOUBLE_EQ(norm(points[2]-points[4]) + norm(points[4]-points[5]), cell.cable(3)->length()); diff --git a/tests/unit/test_synapses.cpp b/tests/unit/test_synapses.cpp index 45e423faad1e4df807a29ffab589e27c745eedfa..a5559409e5af8f77368fd9d80fbaaa437d9112d7 100644 --- a/tests/unit/test_synapses.cpp +++ b/tests/unit/test_synapses.cpp @@ -1,18 +1,47 @@ #include "../gtest.h" -#include "../test_util.hpp" + +#include <cmath> +#include <tuple> +#include <vector> #include <cell.hpp> +#include <constants.hpp> +#include <mechcat.hpp> #include <backends/multicore/fvm.hpp> +#include <backends/multicore/mechanism.hpp> +#include <util/optional.hpp> +#include <util/maputil.hpp> +#include <util/range.hpp> + +#include "common.hpp" +#include "../test_util.hpp" + +using namespace arb; + +using backend = ::arb::multicore::backend; +using shared_state = backend::shared_state; +using value_type = backend::value_type; +using size_type = backend::size_type; 
+ +// Access to mechanisms protected data: +using field_table_type = std::vector<std::pair<const char*, value_type**>>; +ACCESS_BIND(field_table_type (multicore::mechanism::*)(), field_table_ptr, &multicore::mechanism::field_table) + +util::range<const value_type*> mechanism_field(std::unique_ptr<multicore::mechanism>& m, const std::string& key) { + if (auto opt_ptr = util::value_by_key((m.get()->*field_table_ptr)(), key)) { + const value_type* field = *opt_ptr.value(); + return util::make_range(field, field+m->size()); + } + throw std::logic_error("internal error: no such field in mechanism"); +} -#include <mechanisms/multicore/expsyn_cpu.hpp> -#include <mechanisms/multicore/exp2syn_cpu.hpp> +ACCESS_BIND(const value_type* multicore::mechanism::*, vec_v_ptr, &multicore::mechanism::vec_v_) +ACCESS_BIND(value_type* multicore::mechanism::*, vec_i_ptr, &multicore::mechanism::vec_i_) -// compares results with those generated by nrn/ball_and_stick.py -TEST(synapses, add_to_cell) -{ +TEST(synapses, add_to_cell) { using namespace arb; - arb::cell cell; + ::arb::cell cell; // Soma with diameter 12.6157 um and HH channel auto soma = cell.add_soma(12.6157/2.0); @@ -38,129 +67,110 @@ TEST(synapses, add_to_cell) EXPECT_EQ(syns[2].mechanism.name(), "expsyn"); } -TEST(synapses, expsyn_basic_state) -{ - using namespace arb; - using memory::make_const_view; - using size_type = multicore::backend::size_type; +template <typename Seq> +static bool all_equal_to(const Seq& s, double v) { + return util::all_of(s, [v](double x) { + return (std::isnan(v) && std::isnan(x)) || v==x; + }); +} + +TEST(synapses, syn_basic_state) { + using util::fill; using value_type = multicore::backend::value_type; + using index_type = multicore::backend::index_type; - using synapse_type = multicore::mechanism_expsyn<multicore::backend>; int num_syn = 4; int num_comp = 4; int num_cell = 1; - synapse_type::iarray cell_index(num_comp, 0); - synapse_type::array time(num_cell, 0); - synapse_type::array 
time_to(num_cell, 0.1); - synapse_type::array dt(num_comp, 0.1); + auto multicore_mechanism_instance = [](const char* name) { + return std::unique_ptr<multicore::mechanism>( + dynamic_cast<multicore::mechanism*>( + global_default_catalogue().instance<backend>(name).release())); + }; - std::vector<size_type> node_index(num_syn, 0); - std::vector<value_type> weights(num_syn, 1.0); - synapse_type::array voltage(num_comp, -65.0); - synapse_type::array current(num_comp, 1.0); + auto expsyn = multicore_mechanism_instance("expsyn"); + ASSERT_TRUE(expsyn); - auto mech = make_mechanism<synapse_type>(0, cell_index, time, time_to, dt, voltage, current, make_const_view(weights), make_const_view(node_index)); - auto ptr = dynamic_cast<synapse_type*>(mech.get()); + auto exp2syn = multicore_mechanism_instance("exp2syn"); + ASSERT_TRUE(exp2syn); - auto n = ptr->size(); - using view = synapse_type::view; + auto align = std::max(expsyn->data_alignment(), exp2syn->data_alignment()); + shared_state state(num_cell, std::vector<index_type>(num_comp, 0), align); - // parameters initialized to default values - for(auto e : view(ptr->e, n)) { - EXPECT_EQ(e, 0.); - } - for(auto tau : view(ptr->tau, n)) { - EXPECT_EQ(tau, 2.0); - } + state.reset(-65., constant::hh_squid_temp); + fill(state.current_density, 1.0); + fill(state.time_to, 0.1); + state.set_dt(); - // current and voltage vectors correctly hooked up - for(auto v : view(ptr->vec_v_, n)) { - EXPECT_EQ(v, -65.); - } - for(auto i : view(ptr->vec_i_, n)) { - EXPECT_EQ(i, 1.0); - } + std::vector<index_type> syn_cv(num_syn, 0); + std::vector<value_type> syn_weight(num_syn, 1.0); - // should be initialized to NaN - for(auto g : view(ptr->g, n)) { - EXPECT_NE(g, g); - } + expsyn->instantiate(0, state, {syn_cv, syn_weight}); + exp2syn->instantiate(1, state, {syn_cv, syn_weight}); - // initialize state then check g has been set to zero - ptr->nrn_init(); - for(auto g : view(ptr->g, n)) { - EXPECT_EQ(g, 0.); - } + // Parameters initialized to 
default values? - // call net_receive on two of the synapses - ptr->net_receive(1, 3.14); - ptr->net_receive(3, 1.04); - EXPECT_EQ(ptr->g[1], 3.14); - EXPECT_EQ(ptr->g[3], 1.04); -} + EXPECT_TRUE(all_equal_to(mechanism_field(expsyn, "e"), 0.)); + EXPECT_TRUE(all_equal_to(mechanism_field(expsyn, "tau"), 2.0)); + EXPECT_TRUE(all_equal_to(mechanism_field(expsyn, "g"), NAN)); -TEST(synapses, exp2syn_basic_state) -{ - using namespace arb; - using memory::make_const_view; - using size_type = multicore::backend::size_type; - using value_type = multicore::backend::value_type; + EXPECT_TRUE(all_equal_to(mechanism_field(exp2syn, "e"), 0.)); + EXPECT_TRUE(all_equal_to(mechanism_field(exp2syn, "tau1"), 0.5)); + EXPECT_TRUE(all_equal_to(mechanism_field(exp2syn, "tau2"), 2.0)); + EXPECT_TRUE(all_equal_to(mechanism_field(exp2syn, "A"), NAN)); + EXPECT_TRUE(all_equal_to(mechanism_field(exp2syn, "B"), NAN)); - using synapse_type = multicore::mechanism_exp2syn<multicore::backend>; - int num_syn = 4; - int num_comp = 4; - int num_cell = 1; + // Current and voltage views correctly hooked up? 
- synapse_type::iarray cell_index(num_comp, 0); - synapse_type::array time(num_cell, 0); - synapse_type::array time_to(num_cell, 0.1); - synapse_type::array dt(num_comp, 0.1); + const value_type* v_ptr; + v_ptr = expsyn.get()->*vec_v_ptr; + EXPECT_TRUE(all_equal_to(util::make_range(v_ptr, v_ptr+num_comp), -65.)); - std::vector<size_type> node_index(num_syn, 0); - std::vector<value_type> weights(num_syn, 1.0); - synapse_type::array voltage(num_comp, -65.0); - synapse_type::array current(num_comp, 1.0); + v_ptr = exp2syn.get()->*vec_v_ptr; + EXPECT_TRUE(all_equal_to(util::make_range(v_ptr, v_ptr+num_comp), -65.)); - auto mech = make_mechanism<synapse_type>(0, cell_index, time, time_to, dt, voltage, current, make_const_view(weights), make_const_view(node_index)); - auto ptr = dynamic_cast<synapse_type*>(mech.get()); + const value_type* i_ptr; + i_ptr = expsyn.get()->*vec_i_ptr; + EXPECT_TRUE(all_equal_to(util::make_range(i_ptr, i_ptr+num_comp), 1.)); - auto n = ptr->size(); - using view = synapse_type::view; + i_ptr = exp2syn.get()->*vec_i_ptr; + EXPECT_TRUE(all_equal_to(util::make_range(i_ptr, i_ptr+num_comp), 1.)); - // parameters initialized to default values - for(auto e : view(ptr->e, n)) { - EXPECT_EQ(e, 0.); - } - for(auto tau1: view(ptr->tau1, n)) { - EXPECT_EQ(tau1, 0.5); - } - for(auto tau2: view(ptr->tau2, n)) { - EXPECT_EQ(tau2, 2.0); - } + // Initialize state then check g, A, B have been set to zero. 
- // should be initialized to NaN - for(auto factor: view(ptr->factor, n)) { - EXPECT_NE(factor, factor); - } + expsyn->nrn_init(); + EXPECT_TRUE(all_equal_to(mechanism_field(expsyn, "g"), 0.)); - // initialize state then check factor has sane (positive) value - // and A and B are zero - ptr->nrn_init(); - for(auto factor: view(ptr->factor, n)) { - EXPECT_GT(factor, 0.); - } - for(auto A: view(ptr->A, n)) { - EXPECT_EQ(A, 0.); - } - for(auto B: view(ptr->B, n)) { - EXPECT_EQ(B, 0.); - } + exp2syn->nrn_init(); + EXPECT_TRUE(all_equal_to(mechanism_field(exp2syn, "A"), 0.)); + EXPECT_TRUE(all_equal_to(mechanism_field(exp2syn, "B"), 0.)); + + // Deliver two events (at time 0), one each to expsyn synapses 1 and 3 + // and exp2syn synapses 0 and 2. + + std::vector<deliverable_event> events = { + {0., {0, 1, 0}, 3.14f}, + {0., {0, 3, 0}, 1.41f}, + {0., {1, 0, 0}, 2.71f}, + {0., {1, 2, 0}, 0.07f} + }; + state.deliverable_events.init(events); + state.deliverable_events.mark_until_after(state.time); - // call net_receive on two of the synapses - ptr->net_receive(1, 3.14); - ptr->net_receive(3, 1.04); + expsyn->deliver_events(); + exp2syn->deliver_events(); - EXPECT_NEAR(ptr->A[1], ptr->factor[1]*3.14, 1e-6); - EXPECT_NEAR(ptr->B[3], ptr->factor[3]*1.04, 1e-6); + using fvec = std::vector<fvm_value_type>; + + EXPECT_TRUE(testing::seq_almost_eq<fvm_value_type>( + fvec({0, 3.14f, 0, 1.41f}), mechanism_field(expsyn, "g"))); + + double factor = mechanism_field(exp2syn, "factor")[0]; + EXPECT_TRUE(factor>1.); + fvec expected = {2.71f*factor, 0, 0.07f*factor, 0}; + + EXPECT_TRUE(testing::seq_almost_eq<fvm_value_type>(expected, mechanism_field(exp2syn, "A"))); + EXPECT_TRUE(testing::seq_almost_eq<fvm_value_type>(expected, mechanism_field(exp2syn, "B"))); } + diff --git a/tests/validation/convergence_test.hpp b/tests/validation/convergence_test.hpp index 6cbb945d0ba198da5bdd987eb614b1c3773bfe0d..610d4abb5552e3675c0d4f16b826312b30e6beb0 100644 --- 
a/tests/validation/convergence_test.hpp +++ b/tests/validation/convergence_test.hpp @@ -144,8 +144,8 @@ inline std::vector<float> stimulus_ends(const cell& c) { std::vector<float> ts; for (const auto& stimulus: c.stimuli()) { - float t0 = stimulus.clamp.delay(); - float t1 = t0+stimulus.clamp.duration(); + float t0 = stimulus.clamp.delay; + float t1 = t0+stimulus.clamp.duration; ts.push_back(t0); ts.push_back(t1); } diff --git a/tests/validation/validate_synapses.cpp b/tests/validation/validate_synapses.cpp index 527262ecd826d59eb2d341d4e3ba52fd333dc092..e69934ecc22e63c18f657cafa90a95a488e606ae 100644 --- a/tests/validation/validate_synapses.cpp +++ b/tests/validation/validate_synapses.cpp @@ -37,7 +37,7 @@ void run_synapse_test( }; cell c = make_cell_ball_and_stick(false); // no stimuli - mechanism_spec syn_default(syn_type); + mechanism_desc syn_default(syn_type); c.add_synapse({1, 0.5}, syn_default); // injected spike events