From e0f0b5d79417946116adf726e86f52b06791c35f Mon Sep 17 00:00:00 2001
From: Ben Cumming <bcumming@cscs.ch>
Date: Wed, 9 May 2018 13:08:32 +0200
Subject: [PATCH] CUDA back end for the new mechanism infrastructure (#487)

Completes CUDA printing in modcc.
* Add CudaPrinter visitor, overriding CPrinter.
* Add `ostream` `operator<<` overloads for `arb::gpu::shared_state` and `device_view` for debugging.
* Fix GPU back-end bugs: the argument order of the `nernst_impl` call, `set_dt_impl` being passed `time_to` twice instead of `time`, and unguarded zero-size kernel launches in the threshold watcher.
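
For orientation, the sketch below shows the kernel/launch pattern that the emitted `.cu` source now follows. It is a minimal, self-contained illustration with hypothetical names (`example_ppack`, `example_nrn_state`, the toy state update), not generated output: each non-empty API method becomes a bounds-checked `__global__` kernel over the mechanism width, and a thin host wrapper computes the grid size from a 128-thread block before launching it.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

namespace impl {
// Plays the same role as gpu::impl::block_count in the generated wrappers.
unsigned block_count(unsigned n, unsigned block_dim) {
    return (n + block_dim - 1)/block_dim;
}
}

// Stand-in for a generated parameter pack (a mechanism_ppack_base derivative).
struct example_ppack {
    int width_;        // number of CVs/sites covered by the mechanism instance
    double* state_;    // one state field, stored in device memory
};

// Each non-empty API method becomes a bounds-checked __global__ kernel.
__global__ void example_nrn_state(example_ppack params_) {
    int n_ = params_.width_;
    int tid_ = threadIdx.x + blockDim.x*blockIdx.x;
    if (tid_ < n_) {
        params_.state_[tid_] *= 0.5;   // placeholder for the translated NMODL body
    }
}

// Host-side wrapper, analogous to mechanism_gpu_<name>_nrn_state_(ppack&).
void example_nrn_state_(example_ppack& p) {
    unsigned block_dim = 128;
    unsigned grid_dim = impl::block_count(p.width_, block_dim);
    example_nrn_state<<<grid_dim, block_dim>>>(p);
}

int main() {
    example_ppack p{256, nullptr};
    cudaMalloc(&p.state_, p.width_*sizeof(double));
    cudaMemset(p.state_, 0, p.width_*sizeof(double));
    example_nrn_state_(p);
    cudaDeviceSynchronize();
    std::printf("launched over %d sites\n", p.width_);
    cudaFree(p.state_);
}
```

The real wrappers operate on the generated parameter-pack struct (derived from `arb::gpu::mechanism_ppack_base`) and use `gpu::impl::block_count` from `cuda_common.hpp`.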
---
 example/generators/event_gen.cpp          |   4 +-
 modcc/printer/cprinter.cpp                |  43 ----
 modcc/printer/cprinter.hpp                |   2 +-
 modcc/printer/cudaprinter.cpp             | 272 +++++++++++++++++++---
 modcc/printer/cudaprinter.hpp             |  12 +
 modcc/printer/printerutil.cpp             |  50 ++++
 modcc/printer/printerutil.hpp             |  12 +
 src/backends/gpu/mechanism.cpp            |  42 ++--
 src/backends/gpu/mechanism.hpp            |   2 +-
 src/backends/gpu/mechanism_ppack_base.hpp |   2 +
 src/backends/gpu/shared_state.cpp         |  27 ++-
 src/backends/gpu/shared_state.hpp         |   7 +-
 src/backends/gpu/threshold_watcher.cu     |  18 +-
 src/backends/gpu/threshold_watcher.hpp    |  38 +--
 src/backends/multicore/shared_state.cpp   |  29 +--
 src/memory/memory.hpp                     |  13 ++
 tests/unit/test_algorithms.cpp            |   1 +
 17 files changed, 443 insertions(+), 131 deletions(-)

diff --git a/example/generators/event_gen.cpp b/example/generators/event_gen.cpp
index 0eac3718..38f3094e 100644
--- a/example/generators/event_gen.cpp
+++ b/example/generators/event_gen.cpp
@@ -145,8 +145,8 @@ int main() {
     // Now attach the sampler at probe_id, with sampling schedule sched, writing to voltage
     sim.add_sampler(arb::one_probe(probe_id), sched, arb::make_simple_sampler(voltage));
 
-    // Run the simulation for 1 s (1000 ms), with time steps of 0.01 ms.
-    sim.run(50, 0.01);
+    // Run the simulation for 100 ms, with time steps of 0.01 ms.
+    sim.run(100, 0.01);
 
     // Write the samples to a json file.
     write_trace_json(voltage);
diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp
index b17184de..699485a7 100644
--- a/modcc/printer/cprinter.cpp
+++ b/modcc/printer/cprinter.cpp
@@ -313,49 +313,6 @@ std::string emit_cpp_source(const Module& module_, const std::string& ns, simd_s
     return out.str();
 }
 
-struct indexed_variable_info {
-    std::string data_var;
-    std::string index_var;
-};
-
-indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
-    std::string data_var, ion_pfx;
-    std::string index_var = "node_index_";
-
-    if (sym->is_ion()) {
-        ion_pfx = "ion_"+to_string(sym->ion_channel())+"_";
-        index_var = ion_pfx+"index_";
-    }
-
-    switch (sym->data_source()) {
-    case sourceKind::voltage:
-        data_var="vec_v_";
-        break;
-    case sourceKind::current:
-        data_var="vec_i_";
-        break;
-    case sourceKind::dt:
-        data_var="vec_dt_";
-        break;
-    case sourceKind::ion_current:
-        data_var=ion_pfx+".current_density";
-        break;
-    case sourceKind::ion_revpot:
-        data_var=ion_pfx+".reversal_potential";
-        break;
-    case sourceKind::ion_iconc:
-        data_var=ion_pfx+".internal_concentration";
-        break;
-    case sourceKind::ion_econc:
-        data_var=ion_pfx+".external_concentration";
-        break;
-    default:
-        throw compiler_exception("unrecognized indexed data source", sym->location());
-    }
-
-    return {data_var, index_var};
-}
-
 // Scalar printing:
 
 void CPrinter::visit(IdentifierExpression *e) {
diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp
index a92d014d..77b2e37c 100644
--- a/modcc/printer/cprinter.hpp
+++ b/modcc/printer/cprinter.hpp
@@ -34,7 +34,7 @@ public:
     void visit(BinaryExpression* e) override { cexpr_emit(e, out_, this); }
     void visit(IfExpression* e) override { cexpr_emit(e, out_, this); }
 
-private:
+protected:
     std::ostream& out_;
 };
 
diff --git a/modcc/printer/cudaprinter.cpp b/modcc/printer/cudaprinter.cpp
index c1272015..8051ab09 100644
--- a/modcc/printer/cudaprinter.cpp
+++ b/modcc/printer/cudaprinter.cpp
@@ -3,6 +3,7 @@
 #include <string>
 #include <unordered_set>
 
+#include "cudaprinter.hpp"
 #include "expression.hpp"
 #include "io/ostream_wrappers.hpp"
 #include "io/prefixbuf.hpp"
@@ -13,8 +14,23 @@ using io::indent;
 using io::popindent;
 using io::quote;
 
-// Emit stream_stat alis, parameter pack struct.
 void emit_common_defs(std::ostream&, const Module& module_);
+void emit_api_body_cu(std::ostream& out, APIMethod* method, bool is_point_proc);
+void emit_procedure_body_cu(std::ostream& out, ProcedureExpression* proc);
+void emit_state_read_cu(std::ostream& out, LocalVariable* local);
+void emit_state_update_cu(std::ostream& out, Symbol* from,
+                          IndexedVariable* external, bool is_point_proc);
+const char* index_id(Symbol *s);
+
+struct cuprint {
+    Expression* expr_;
+    explicit cuprint(Expression* expr): expr_(expr) {}
+
+    friend std::ostream& operator<<(std::ostream& out, const cuprint& w) {
+        CudaPrinter printer(out);
+        return w.expr_->accept(&printer), out;
+    }
+};
 
 std::string make_class_name(const std::string& module_name) {
     return "mechanism_gpu_"+module_name;
@@ -53,20 +69,21 @@ std::string emit_cuda_cpp_source(const Module& module_, const std::string& ns) {
 
     out <<
         "#include <" << arb_header_prefix() << "backends/gpu/mechanism.hpp>\n"
-        "#include <" << arb_header_prefix() << "backends/gpu/mechanism_ppack_base.hpp>\n"
-        "\n" << namespace_declaration_open(ns_components) <<
-        "\n";
+        "#include <" << arb_header_prefix() << "backends/gpu/mechanism_ppack_base.hpp>\n";
+
+    out << "\n" << namespace_declaration_open(ns_components) << "\n";
 
     emit_common_defs(out, module_);
 
     out <<
-        "void " << class_name << "_nrn_init_(int, " << ppack_name << "&);\n"
-        "void " << class_name << "_nrn_state_(int, " << ppack_name << "&);\n"
-        "void " << class_name << "_nrn_current_(int, " << ppack_name << "&);\n"
-        "void " << class_name << "_write_ions_(int, " << ppack_name << "&);\n";
+        "void " << class_name << "_nrn_init_(" << ppack_name << "&);\n"
+        "void " << class_name << "_nrn_state_(" << ppack_name << "&);\n"
+        "void " << class_name << "_nrn_current_(" << ppack_name << "&);\n"
+        "void " << class_name << "_write_ions_(" << ppack_name << "&);\n";
 
     net_receive && out <<
-        "void " << class_name << "_deliver_events_(int, " << ppack_name << "&, deliverable_event_stream_state events);\n";
+        "void " << class_name << "_deliver_events_(int mech_id, "
+        << ppack_name << "&, deliverable_event_stream_state events);\n";
 
     out <<
         "\n"
@@ -81,21 +98,21 @@ std::string emit_cuda_cpp_source(const Module& module_, const std::string& ns) {
         "mechanism_ptr clone() const override { return mechanism_ptr(new " << class_name << "()); }\n"
         "\n"
         "void nrn_init() override {\n" << indent <<
-        class_name << "_nrn_init_(width_, pp_);\n" << popindent <<
+        class_name << "_nrn_init_(pp_);\n" << popindent <<
         "}\n\n"
         "void nrn_state() override {\n" << indent <<
-        class_name << "_nrn_state_(width_, pp_);\n" << popindent <<
+        class_name << "_nrn_state_(pp_);\n" << popindent <<
         "}\n\n"
         "void nrn_current() override {\n" << indent <<
-        class_name << "_nrn_current_(width_, pp_);\n" << popindent <<
+        class_name << "_nrn_current_(pp_);\n" << popindent <<
         "}\n\n"
         "void write_ions() override {\n" << indent <<
-        class_name << "_write_ions_(width_, pp_);\n" << popindent <<
+        class_name << "_write_ions_(pp_);\n" << popindent <<
         "}\n\n";
 
     net_receive && out <<
         "void deliver_events(deliverable_event_stream_state events) override {\n" << indent <<
-        class_name << "_deliver_events_(width_, pp_, events);\n" << popindent <<
+        class_name << "_deliver_events_(mechanism_id_, pp_, events);\n" << popindent <<
         "}\n\n";
 
     out << popindent <<
@@ -182,6 +199,7 @@ std::string emit_cuda_cu_source(const Module& module_, const std::string& ns) {
     std::string class_name = make_class_name(name);
     std::string ppack_name = make_ppack_name(name);
     auto ns_components = namespace_components(ns);
+    const bool is_point_proc = module_.kind() == moduleKind::point;
 
     NetReceiveExpression* net_receive = find_net_receive(module_);
     APIMethod* init_api = find_api_method(module_, "nrn_init");
@@ -196,29 +214,122 @@ std::string emit_cuda_cu_source(const Module& module_, const std::string& ns) {
     io::pfxstringstream out;
 
     out <<
+        "#include <iostream>\n"
         "#include <" << arb_header_prefix() << "backends/event.hpp>\n"
         "#include <" << arb_header_prefix() << "backends/multi_event_stream_state.hpp>\n"
-        "#include <" << arb_header_prefix() << "backends/gpu/mechanism_ppack_base.hpp>\n"
-        "\n" << namespace_declaration_open(ns_components) <<
-        "\n";
+        "#include <" << arb_header_prefix() << "backends/gpu/cuda_common.hpp>\n"
+        "#include <" << arb_header_prefix() << "backends/gpu/math.hpp>\n"
+        "#include <" << arb_header_prefix() << "backends/gpu/mechanism_ppack_base.hpp>\n";
 
-    emit_common_defs(out, module_);
+    is_point_proc && out <<
+        "#include <" << arb_header_prefix() << "backends/gpu/reduce_by_key.hpp>\n";
+
+    out << "\n" << namespace_declaration_open(ns_components) << "\n";
 
     out <<
         "using value_type = ::arb::gpu::mechanism_ppack_base::value_type;\n"
         "using index_type = ::arb::gpu::mechanism_ppack_base::index_type;\n"
         "\n";
 
-    out <<
-        "void " << class_name << "_nrn_init_(int, " << ppack_name << "&) {};\n"
-        "void " << class_name << "_nrn_state_(int, " << ppack_name << "&) {};\n"
-        "void " << class_name << "_nrn_current_(int, " << ppack_name << "&) {};\n"
-        "void " << class_name << "_write_ions_(int, " << ppack_name << "&) {};\n";
+    emit_common_defs(out, module_);
 
-    net_receive && out <<
-        "void " << class_name << "_deliver_events_(int, " << ppack_name << "&, deliverable_event_stream_state events) {}\n";
+    // Print the CUDA code and kernels:
+    //  - first __device__ functions that implement NMODL PROCEDUREs.
+    //  - then __global__ kernels that implement API methods and call the procedures.
+
+    out << "namespace {\n\n"; // place inside an anonymous namespace
+
+    out << "using arb::gpu::exprelr;\n";
+    out << "using arb::gpu::min;\n";
+    out << "using arb::gpu::max;\n\n";
+
+    // Procedures as __device__ functions.
+    auto emit_procedure_kernel = [&] (ProcedureExpression* e) {
+        out << "__device__\n"
+            << "void " << e->name()
+            << "(" << ppack_name << " params_, int tid_";
+        for(auto& arg: e->args()) {
+            out << ", value_type " << arg->is_argument()->name();
+        }
+        out << ") {\n" << indent
+            << cuprint(e->body())
+            << popindent << "}\n\n";
+    };
 
-    (void)write_ions_api;
+    for (auto& p: module_normal_procedures(module_)) {
+        emit_procedure_kernel(p);
+    }
+
+    // API methods as __global__ kernels.
+    auto emit_api_kernel = [&] (APIMethod* e) {
+        // Only print the kernel if the method is not empty.
+        if (!e->body()->statements().empty()) {
+            out << "__global__\n"
+                << "void " << e->name() << "(" << ppack_name << " params_) {\n" << indent
+                << "int n_ = params_.width_;\n";
+            emit_api_body_cu(out, e, is_point_proc);
+            out << popindent << "}\n\n";
+        }
+    };
+
+    emit_api_kernel(init_api);
+    emit_api_kernel(state_api);
+    emit_api_kernel(current_api);
+    emit_api_kernel(write_ions_api);
+
+    // event delivery
+    if (net_receive) {
+        out << "__global__\n"
+            << "void deliver_events(int mech_id_, " <<  ppack_name << " params_, "
+            << "deliverable_event_stream_state events) {\n" << indent
+            << "auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n"
+            << "auto const ncell_ = events.n;\n\n"
+
+            << "if(tid_<ncell_) {\n" << indent
+            << "auto begin = events.ev_data+events.begin_offset[tid_];\n"
+            << "auto end = events.ev_data+events.end_offset[tid_];\n"
+            << "for (auto p = begin; p<end; ++p) {\n" << indent
+            << "if (p->mech_id==mech_id_) {\n" << indent
+            << "auto tid_ = p->mech_index;\n"
+            << "auto weight = p->weight;\n"
+            << cuprint(net_receive->body())
+            << popindent << "}\n"
+            << popindent << "}\n"
+            << popindent << "}\n"
+            << popindent << "}\n";
+    }
+
+    out << "} // namspace\n\n"; // close anonymous namespace
+
+    // Write wrappers.
+    auto emit_api_wrapper = [&] (APIMethod* e) {
+        out << "void " << class_name << "_" << e->name() << "_(" << ppack_name << "& p) {";
+
+        // Only launch the kernel if the method body is non-empty.
+        !e->body()->statements().empty() && out
+            << "\n" << indent
+            << "auto n = p.width_;\n"
+            << "unsigned block_dim = 128;\n"
+            << "unsigned grid_dim = gpu::impl::block_count(n, block_dim);\n"
+            << e->name() << "<<<grid_dim, block_dim>>>(p);\n"
+            << popindent;
+
+        out << "}\n\n";
+    };
+    emit_api_wrapper(init_api);
+    emit_api_wrapper(current_api);
+    emit_api_wrapper(state_api);
+    emit_api_wrapper(write_ions_api);
+
+    net_receive && out
+        << "void " << class_name << "_deliver_events_("
+        << "int mech_id, "
+        << ppack_name << "& p, deliverable_event_stream_state events) {\n" << indent
+        << "auto n = events.n;\n"
+        << "unsigned block_dim = 128;\n"
+        << "unsigned grid_dim = gpu::impl::block_count(n, block_dim);\n"
+        << "deliver_events<<<grid_dim, block_dim>>>(mech_id, p, events);\n"
+        << popindent << "}\n\n";
 
     out << namespace_declaration_close(ns_components);
     return out.str();
@@ -232,11 +343,10 @@ void emit_common_defs(std::ostream& out, const Module& module_) {
     auto ion_deps = module_.ion_deps();
 
     find_net_receive(module_) && out <<
-        "using deliverable_event_stream_state = ::arb::multi_event_stream_state<::arb::deliverable_event_data>;\n"
-        "\n";
+        "using deliverable_event_stream_state =\n"
+        "    ::arb::multi_event_stream_state<::arb::deliverable_event_data>;\n\n";
 
-    out <<
-        "struct " << ppack_name << ": ::arb::gpu::mechanism_ppack_base {\n" << indent;
+    out << "struct " << ppack_name << ": ::arb::gpu::mechanism_ppack_base {\n" << indent;
 
     for (const auto& scalar: vars.scalars) {
         out << "value_type " << scalar->name() <<  " = " << as_c_double(scalar->value()) << ";\n";
@@ -252,4 +362,106 @@ void emit_common_defs(std::ostream& out, const Module& module_) {
     out << popindent << "};\n\n";
 }
 
+void emit_api_body_cu(std::ostream& out, APIMethod* e, bool is_point_proc) {
+    auto body = e->body();
+    auto indexed_vars = indexed_locals(e->scope());
+
+    if (!body->statements().empty()) {
+        out << "int tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n";
+        out << "if (tid_<n_) {\n" << indent;
+        out << "auto gid_ __attribute__((unused)) = params_.node_index_[tid_];\n";
+        out << "auto cid_ __attribute__((unused)) = params_.vec_ci_[gid_];\n";
+
+        auto uses_ion = [&] (ionKind k) {
+            for(auto &symbol : e->scope()->locals()) {
+                if (k == symbol.second->is_local_variable()->ion_channel()) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        uses_ion(ionKind::K)  && out << "auto kid_  = params_.ion_k_index_[tid_];\n";
+        uses_ion(ionKind::Ca) && out << "auto caid_ = params_.ion_ca_index_[tid_];\n";
+        uses_ion(ionKind::Na) && out << "auto naid_ = params_.ion_na_index_[tid_];\n";
+
+        for (auto& sym: indexed_vars) {
+            emit_state_read_cu(out, sym);
+        }
+
+        out << cuprint(body);
 
+        for (auto& sym: indexed_vars) {
+            emit_state_update_cu(out, sym, sym->external_variable(), is_point_proc);
+        }
+        out << popindent << "}\n";
+    }
+}
+
+void emit_procedure_body_cu(std::ostream& out, ProcedureExpression* e) {
+    out << cuprint(e->body());
+}
+
+void emit_state_read_cu(std::ostream& out, LocalVariable* local) {
+    out << "value_type " << cuprint(local) << " = ";
+
+    if (local->is_read()) {
+        out << cuprint(local->external_variable()) << ";\n";
+    }
+    else {
+        out << "0;\n";
+    }
+}
+
+void emit_state_update_cu(std::ostream& out, Symbol* from,
+                          IndexedVariable* external, bool is_point_proc)
+{
+    if (!external->is_write()) return;
+
+    const bool is_minus = external->op()==tok::minus;
+    if (is_point_proc) {
+        out << "arb::gpu::reduce_by_key(";
+        is_minus && out << "-";
+        out << from->name()
+            << ", params_." << decode_indexed_variable(external).data_var << ", gid_);\n";
+    }
+    else {
+        out << cuprint(external) << (is_minus? " -= ": " += ") << from->name() << ";\n";
+    }
+}
+
+// Gives the name of the local variable used to index into a field.
+const char* index_id(Symbol *s) {
+    if(s->is_variable()) return "tid_";
+    else if(auto var = s->is_indexed_variable()) {
+        switch(var->ion_channel()) {
+            case ionKind::none: return "gid_";
+            case ionKind::Ca:   return "caid_";
+            case ionKind::Na:   return "naid_";
+            case ionKind::K:    return "kid_";
+            default :
+                throw compiler_exception("CudaPrinter unknown ion type", s->location());
+        }
+    }
+    return "";
+}
+
+// CUDA Printer visitors
+
+void CudaPrinter::visit(VariableExpression *sym) {
+    out_ << "params_." << sym->name() << (sym->is_range()? "[tid_]": "");
+}
+
+void CudaPrinter::visit(IndexedVariable *e) {
+    auto d = decode_indexed_variable(e);
+    out_ << "params_." << d.data_var << "[" << index_id(e) <<  "]";
+}
+
+void CudaPrinter::visit(CallExpression* e) {
+    out_ << e->name() << "(params_, tid_";
+    for (auto& arg: e->args()) {
+        out_ << ", ";
+        arg->accept(this);
+    }
+    out_ << ")";
+}
diff --git a/modcc/printer/cudaprinter.hpp b/modcc/printer/cudaprinter.hpp
index 767a4eae..a4ec7df4 100644
--- a/modcc/printer/cudaprinter.hpp
+++ b/modcc/printer/cudaprinter.hpp
@@ -2,7 +2,19 @@
 
 #include <string>
 
+#include "cprinter.hpp"
 #include "module.hpp"
+#include "cexpr_emit.hpp"
 
 std::string emit_cuda_cpp_source(const Module& m, const std::string& ns);
 std::string emit_cuda_cu_source(const Module& m, const std::string& ns);
+
+class CudaPrinter: public CPrinter {
+public:
+    CudaPrinter(std::ostream& out): CPrinter(out) {}
+
+    void visit(CallExpression*) override;
+    void visit(VariableExpression*) override;
+    void visit(IndexedVariable*) override;
+};
+
diff --git a/modcc/printer/printerutil.cpp b/modcc/printer/printerutil.cpp
index 17cd445d..12df7d87 100644
--- a/modcc/printer/printerutil.cpp
+++ b/modcc/printer/printerutil.cpp
@@ -91,6 +91,18 @@ module_variables_t local_module_variables(const Module& m) {
     return mv;
 }
 
+std::vector<ProcedureExpression*> module_normal_procedures(const Module& m) {
+    std::vector<ProcedureExpression*> procs;
+    for (auto& sym: m.symbols()) {
+        auto p = sym.second->is_procedure();
+        if (p && p->kind()==procedureKind::normal) {
+            procs.push_back(p);
+        }
+    }
+
+    return procs;
+}
+
 APIMethod* find_api_method(const Module& m, const char* which) {
     auto it = m.symbols().find(which);
     return  it==m.symbols().end()? nullptr: it->second->is_api_method();
@@ -100,3 +112,41 @@ NetReceiveExpression* find_net_receive(const Module& m) {
     auto it = m.symbols().find("net_receive");
     return it==m.symbols().end()? nullptr: it->second->is_net_receive();
 }
+
+indexed_variable_info decode_indexed_variable(IndexedVariable* sym) {
+    std::string data_var, ion_pfx;
+    std::string index_var = "node_index_";
+
+    if (sym->is_ion()) {
+        ion_pfx = "ion_"+to_string(sym->ion_channel())+"_";
+        index_var = ion_pfx+"index_";
+    }
+
+    switch (sym->data_source()) {
+    case sourceKind::voltage:
+        data_var="vec_v_";
+        break;
+    case sourceKind::current:
+        data_var="vec_i_";
+        break;
+    case sourceKind::dt:
+        data_var="vec_dt_";
+        break;
+    case sourceKind::ion_current:
+        data_var=ion_pfx+".current_density";
+        break;
+    case sourceKind::ion_revpot:
+        data_var=ion_pfx+".reversal_potential";
+        break;
+    case sourceKind::ion_iconc:
+        data_var=ion_pfx+".internal_concentration";
+        break;
+    case sourceKind::ion_econc:
+        data_var=ion_pfx+".external_concentration";
+        break;
+    default:
+        throw compiler_exception("unrecognized indexed data source: "+to_string(sym), sym->location());
+    }
+
+    return {data_var, index_var};
+}
diff --git a/modcc/printer/printerutil.hpp b/modcc/printer/printerutil.hpp
index 49cf97fb..ccc0c593 100644
--- a/modcc/printer/printerutil.hpp
+++ b/modcc/printer/printerutil.hpp
@@ -94,9 +94,21 @@ struct module_variables_t {
 
 module_variables_t local_module_variables(const Module&);
 
+// "normal" procedures in a module.
+// A normal procedure is one that has been declared with the
+// PROCEDURE keyword in NMODL.
+
+std::vector<ProcedureExpression*> module_normal_procedures(const Module& m);
+
 // Extract key procedures from module.
 
 APIMethod* find_api_method(const Module& m, const char* which);
 
 NetReceiveExpression* find_net_receive(const Module& m);
 
+struct indexed_variable_info {
+    std::string data_var;
+    std::string index_var;
+};
+
+indexed_variable_info decode_indexed_variable(IndexedVariable* sym);
diff --git a/src/backends/gpu/mechanism.cpp b/src/backends/gpu/mechanism.cpp
index 8cf643cd..e050001f 100644
--- a/src/backends/gpu/mechanism.cpp
+++ b/src/backends/gpu/mechanism.cpp
@@ -36,11 +36,24 @@ memory::const_device_view<T> device_view(const T* ptr, std::size_t n) {
     return memory::const_device_view<T>(ptr, n);
 }
 
-// The derived class (typically generated code from modcc) holds pointers that need
-// to be set to point inside the shared state, or into the allocated parameter/variable
-// data block.
-
-void mechanism::instantiate(fvm_size_type id, backend::shared_state& shared, const layout& pos_data) {
+// The derived class (typically generated code from modcc) holds pointers to
+// data fields. These point to either:
+//   * shared fields read/written by all mechanisms in a cell group
+//     (e.g. the per-compartment voltage vec_v_);
+//   * or mechanism-specific parameter or variable fields stored inside the
+//     mechanism.
+// These pointers need to be set to point inside the shared state of the cell
+// group, or into the allocated parameter/variable data block.
+//
+// The mechanism::instantiate() method takes a reference to the cell group
+// shared state and discretised cell layout information, and sets the
+// pointers. This also involves setting the pointers in the parameter pack,
+// which is used to pass pointers to CUDA kernels.
+
+void mechanism::instantiate(fvm_size_type id,
+                            backend::shared_state& shared,
+                            const layout& pos_data)
+{
     mechanism_id_ = id;
     width_ = pos_data.cv.size();
 
@@ -51,6 +64,8 @@ void mechanism::instantiate(fvm_size_type id, backend::shared_state& shared, con
 
     mechanism_ppack_base* pp = ppack_ptr(); // From derived class instance.
 
+    pp->width_ = width_;
+
     pp->vec_ci_   = shared.cv_to_cell.data();
     pp->vec_t_    = shared.time.data();
     pp->vec_t_to_ = shared.time_to.data();
@@ -60,7 +75,7 @@ void mechanism::instantiate(fvm_size_type id, backend::shared_state& shared, con
     pp->vec_i_    = shared.current_density.data();
 
     auto ion_state_tbl = ion_state_table();
-    n_ion_ = ion_state_tbl.size();
+    num_ions_ = ion_state_tbl.size();
 
     for (auto i: ion_state_tbl) {
         util::optional<ion_state&> oion = value_by_key(shared.ion_data, i.first);
@@ -86,13 +101,13 @@ void mechanism::instantiate(fvm_size_type id, backend::shared_state& shared, con
     // (First sub-array of data_ is used for width_.)
 
     auto fields = field_table();
-    std::size_t n_field = fields.size();
+    std::size_t num_fields = fields.size();
 
-    data_ = array((1+n_field)*width_padded_, NAN);
+    data_ = array((1+num_fields)*width_padded_, NAN);
     memory::copy(make_const_view(pos_data.weight), device_view(data_.data(), width_));
     pp->weight_ = data_.data();
 
-    for (auto i: make_span(0, n_field)) {
+    for (auto i: make_span(0, num_fields)) {
         // Take reference to corresponding derived (generated) mechanism value pointer member.
         fvm_value_type*& field_ptr = *std::get<1>(fields[i]);
         field_ptr = data_.data()+(i+1)*width_padded_;
@@ -105,21 +120,22 @@ void mechanism::instantiate(fvm_size_type id, backend::shared_state& shared, con
     // Allocate and initialize index vectors, viz. node_index_ and any ion indices.
     // (First sub-array of indices_ is used for node_index_.)
 
-    indices_ = iarray((1+n_ion_)*width_padded_);
+    indices_ = iarray((1+num_ions_)*width_padded_);
 
     memory::copy(make_const_view(pos_data.cv), device_view(indices_.data(), width_));
     pp->node_index_ = indices_.data();
 
     auto ion_index_tbl = ion_index_table();
-    EXPECTS(n_ion_==ion_index_tbl.size());
+    EXPECTS(num_ions_==ion_index_tbl.size());
 
-    for (auto i: make_span(0, n_ion_)) {
+    for (auto i: make_span(0, num_ions_)) {
         util::optional<ion_state&> oion = value_by_key(shared.ion_data, ion_index_tbl[i].first);
         if (!oion) {
             throw std::logic_error("mechanism holds ion with no corresponding shared state");
         }
 
-        auto indices = util::index_into(pos_data.cv, memory::on_host(oion->node_index_));
+        auto ni = memory::on_host(oion->node_index_);
+        auto indices = util::index_into(pos_data.cv, ni);
         std::vector<index_type> mech_ion_index(indices.begin(), indices.end());
 
         // Take reference to derived (generated) mechanism ion index pointer.
diff --git a/src/backends/gpu/mechanism.hpp b/src/backends/gpu/mechanism.hpp
index f4adf68e..3c28116c 100644
--- a/src/backends/gpu/mechanism.hpp
+++ b/src/backends/gpu/mechanism.hpp
@@ -61,7 +61,7 @@ public:
 
 protected:
     size_type width_ = 0;        // Instance width (number of CVs/sites)
-    size_type n_ion_ = 0;
+    size_type num_ions_ = 0;
 
     // Returns pointer to (derived) parameter-pack object that holds:
     // * pointers to shared cell state `vec_ci_` et al.,
diff --git a/src/backends/gpu/mechanism_ppack_base.hpp b/src/backends/gpu/mechanism_ppack_base.hpp
index 68095dd3..5147d7ef 100644
--- a/src/backends/gpu/mechanism_ppack_base.hpp
+++ b/src/backends/gpu/mechanism_ppack_base.hpp
@@ -27,6 +27,8 @@ struct mechanism_ppack_base {
     using index_type = fvm_index_type;
     using ion_state_view = ::arb::gpu::ion_state_view;
 
+    index_type width_;
+
     const index_type* vec_ci_;
     const value_type* vec_t_;
     const value_type* vec_t_to_;
diff --git a/src/backends/gpu/shared_state.cpp b/src/backends/gpu/shared_state.cpp
index 183abd9f..1f59bccc 100644
--- a/src/backends/gpu/shared_state.cpp
+++ b/src/backends/gpu/shared_state.cpp
@@ -76,7 +76,7 @@ ion_state::ion_state(
 
 void ion_state::nernst(fvm_value_type temperature_K) {
     // Nernst equation: reversal potenial eX given by:
-    //     
+    //
     //     eX = RT/zF * ln(Xo/Xi)
     //
     // where:
@@ -90,7 +90,7 @@ void ion_state::nernst(fvm_value_type temperature_K) {
     constexpr fvm_value_type RF = 1e3*constant::gas_constant/constant::faraday;
 
     fvm_value_type factor = RF*temperature_K/charge;
-    nernst_impl(Xi_.size(), factor, Xi_.data(), Xo_.data(), eX_.data());
+    nernst_impl(Xi_.size(), factor, Xo_.data(), Xi_.data(), eX_.data());
 }
 
 void ion_state::init_concentration() {
@@ -170,7 +170,7 @@ void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) {
 }
 
 void shared_state::set_dt() {
-    set_dt_impl(n_cell, n_cv, dt_cell.data(), dt_cv.data(), time_to.data(), time_to.data(), cv_to_cell.data());
+    set_dt_impl(n_cell, n_cv, dt_cell.data(), dt_cv.data(), time_to.data(), time.data(), cv_to_cell.data());
 }
 
 std::pair<fvm_value_type, fvm_value_type> shared_state::time_bounds() const {
@@ -185,5 +185,26 @@ void shared_state::take_samples(const sample_event_stream::state& s, array& samp
     take_samples_impl(s, time.data(), sample_time.data(), sample_value.data());
 }
 
+// Debug interface
+std::ostream& operator<<(std::ostream& o, shared_state& s) {
+    o << " cv_to_cell " << s.cv_to_cell << "\n";
+    o << " time       " << s.time << "\n";
+    o << " time_to    " << s.time_to << "\n";
+    o << " dt_cell    " << s.dt_cell << "\n";
+    o << " dt_cv      " << s.dt_cv << "\n";
+    o << " voltage    " << s.voltage << "\n";
+    o << " current    " << s.current_density << "\n";
+    for (auto& ki: s.ion_data) {
+        auto kn = to_string(ki.first);
+        auto& i = const_cast<ion_state&>(ki.second);
+        o << " " << kn << ".current_density        " << i.iX_ << "\n";
+        o << " " << kn << ".reversal_potential     " << i.eX_ << "\n";
+        o << " " << kn << ".internal_concentration " << i.Xi_ << "\n";
+        o << " " << kn << ".external_concentration " << i.Xo_ << "\n";
+        o << " " << kn << ".node_index             " << i.node_index_ << "\n";
+    }
+    return o;
+}
+
 } // namespace gpu
 } // namespace arb
diff --git a/src/backends/gpu/shared_state.hpp b/src/backends/gpu/shared_state.hpp
index 642dd1eb..256198f6 100644
--- a/src/backends/gpu/shared_state.hpp
+++ b/src/backends/gpu/shared_state.hpp
@@ -5,9 +5,10 @@
 #include <utility>
 #include <vector>
 
-#include <util/enumhash.hpp>
 #include <backends/fvm_types.hpp>
 #include <backends/gpu/gpu_store_types.hpp>
+#include <ion.hpp>
+#include <util/enumhash.hpp>
 
 namespace arb {
 namespace gpu {
@@ -121,8 +122,8 @@ struct shared_state {
     void reset(fvm_value_type initial_voltage, fvm_value_type temperature_K);
 };
 
-// For debugging only:
-std::ostream& operator<<(std::ostream& o, const shared_state& s);
+// For debugging only
+std::ostream& operator<<(std::ostream& o, shared_state& s);
 
 } // namespace gpu
 } // namespace arb
diff --git a/src/backends/gpu/threshold_watcher.cu b/src/backends/gpu/threshold_watcher.cu
index 12c43676..734bb192 100644
--- a/src/backends/gpu/threshold_watcher.cu
+++ b/src/backends/gpu/threshold_watcher.cu
@@ -90,19 +90,23 @@ void test_thresholds_impl(
     fvm_index_type* is_crossed, fvm_value_type* prev_values,
     const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds)
 {
-    constexpr int block_dim = 128;
-    const int grid_dim = impl::block_count(size, block_dim);
-    kernel::test_thresholds_impl<<<grid_dim, block_dim>>>(
-        size, cv_to_cell, t_after, t_before, stack, is_crossed, prev_values, cv_index, values, thresholds);
+    if (size>0) {
+        constexpr int block_dim = 128;
+        const int grid_dim = impl::block_count(size, block_dim);
+        kernel::test_thresholds_impl<<<grid_dim, block_dim>>>(
+            size, cv_to_cell, t_after, t_before, stack, is_crossed, prev_values, cv_index, values, thresholds);
+    }
 }
 
 void reset_crossed_impl(
     int size, fvm_index_type* is_crossed,
     const fvm_index_type* cv_index, const fvm_value_type* values, const fvm_value_type* thresholds)
 {
-    constexpr int block_dim = 128;
-    const int grid_dim = impl::block_count(size, block_dim);
-    kernel::reset_crossed_impl<<<grid_dim, block_dim>>>(size, is_crossed, cv_index, values, thresholds);
+    if (size>0) {
+        constexpr int block_dim = 128;
+        const int grid_dim = impl::block_count(size, block_dim);
+        kernel::reset_crossed_impl<<<grid_dim, block_dim>>>(size, is_crossed, cv_index, values, thresholds);
+    }
 }
 
 } // namespace gpu
diff --git a/src/backends/gpu/threshold_watcher.hpp b/src/backends/gpu/threshold_watcher.hpp
index b38e850f..898bcca8 100644
--- a/src/backends/gpu/threshold_watcher.hpp
+++ b/src/backends/gpu/threshold_watcher.hpp
@@ -58,6 +58,7 @@ public:
         // A more robust approach might be needed to avoid overflows.
         stack_(10*size())
     {
+        crossings_.reserve(stack_.capacity());
         reset();
     }
 
@@ -74,7 +75,9 @@ public:
     /// calling, because the values are used to determine the initial state
     void reset() {
         clear_crossings();
-        reset_crossed_impl((int)size(), is_crossed_.data(), cv_index_.data(), values_, thresholds_.data());
+        if (size()>0) {
+            reset_crossed_impl((int)size(), is_crossed_.data(), cv_index_.data(), values_, thresholds_.data());
+        }
     }
 
     // Testing-only interface.
@@ -82,12 +85,14 @@ public:
         return is_crossed_[i];
     }
 
-    const std::vector<threshold_crossing> crossings() const {
+    const std::vector<threshold_crossing>& crossings() const {
         if (stack_.overflow()) {
             throw std::runtime_error("GPU spike buffer overflow.");
         }
 
-        return std::vector<threshold_crossing>(stack_.begin(), stack_.end());
+        crossings_.clear();
+        crossings_.insert(crossings_.end(), stack_.begin(), stack_.end());
+        return crossings_;
     }
 
     /// Tests each target for changed threshold state.
@@ -95,17 +100,19 @@ public:
     /// crossed since current time t, and the last time the test was
     /// performed.
     void test() {
-        test_thresholds_impl(
-            (int)size(),
-            cv_to_cell_, t_after_, t_before_,
-            stack_.storage(),
-            is_crossed_.data(), v_prev_.data(),
-            cv_index_.data(), values_, thresholds_.data());
-
-        // Check that the number of spikes has not exceeded capacity.
-        // ATTENTION: requires cudaDeviceSynchronize to avoid simultaneous
-        // host-device managed memory access.
-        EXPECTS((cudaDeviceSynchronize(), !stack_.overflow()));
+        if (size()>0) {
+            test_thresholds_impl(
+                (int)size(),
+                cv_to_cell_, t_after_, t_before_,
+                stack_.storage(),
+                is_crossed_.data(), v_prev_.data(),
+                cv_index_.data(), values_, thresholds_.data());
+
+            // Check that the number of spikes has not exceeded capacity.
+            // ATTENTION: requires cudaDeviceSynchronize to avoid simultaneous
+            // host-device managed memory access.
+            EXPECTS((cudaDeviceSynchronize(), !stack_.overflow()));
+        }
     }
 
     /// the number of threashold values that are being monitored
@@ -129,6 +136,9 @@ private:
 
     // Hybrid host/gpu data structure for accumulating threshold crossings.
     stack_type stack_;
+
+    // Host-side storage for crossings copied back from the stack.
+    mutable std::vector<threshold_crossing> crossings_;
 };
 
 } // namespace gpu
diff --git a/src/backends/multicore/shared_state.cpp b/src/backends/multicore/shared_state.cpp
index a1588692..bc02cc9c 100644
--- a/src/backends/multicore/shared_state.cpp
+++ b/src/backends/multicore/shared_state.cpp
@@ -69,7 +69,7 @@ ion_state::ion_state(
 
 void ion_state::nernst(fvm_value_type temperature_K) {
     // Nernst equation: reversal potenial eX given by:
-    //     
+    //
     //     eX = RT/zF * ln(Xo/Xi)
     //
     // where:
@@ -233,22 +233,23 @@ void shared_state::take_samples(
 std::ostream& operator<<(std::ostream& out, const shared_state& s) {
     using util::csv;
 
-    out << "n_cell " << s.n_cell << "\n----\n";
-    out << "n_cv " << s.n_cv << "\n----\n";
-    out << "cv_to_cell:\n" << csv(s.cv_to_cell) << "\n";
-    out << "time:\n" << csv(s.time) << "\n";
-    out << "time_to:\n" << csv(s.time_to) << "\n";
-    out << "dt:\n" << csv(s.dt_cell) << "\n";
-    out << "dt_comp:\n" << csv(s.dt_cv) << "\n";
-    out << "voltage:\n" << csv(s.voltage) << "\n";
-    out << "current_density:\n" << csv(s.current_density) << "\n";
+    out << "n_cell     " << s.n_cell << "\n";
+    out << "n_cv       " << s.n_cv << "\n";
+    out << "cv_to_cell " << csv(s.cv_to_cell) << "\n";
+    out << "time       " << csv(s.time) << "\n";
+    out << "time_to    " << csv(s.time_to) << "\n";
+    out << "dt_cell    " << csv(s.dt_cell) << "\n";
+    out << "dt_cv      " << csv(s.dt_cv) << "\n";
+    out << "voltage    " << csv(s.voltage) << "\n";
+    out << "current    " << csv(s.current_density) << "\n";
     for (auto& ki: s.ion_data) {
         auto kn = to_string(ki.first);
         auto& i = const_cast<ion_state&>(ki.second);
-        out << kn << ".current_density:\n" << csv(i.iX_) << "\n";
-        out << kn << ".reversal_potential:\n" << csv(i.eX_) << "\n";
-        out << kn << ".internal_concentration:\n" << csv(i.Xi_) << "\n";
-        out << kn << ".external_concentration:\n" << csv(i.Xo_) << "\n";
+        out << kn << ".current_density        " << csv(i.iX_) << "\n";
+        out << kn << ".reversal_potential     " << csv(i.eX_) << "\n";
+        out << kn << ".internal_concentration " << csv(i.Xi_) << "\n";
+        out << kn << ".external_concentration " << csv(i.Xo_) << "\n";
+        out << kn << ".node_index             " << csv(i.node_index_) << "\n";
     }
 
     return out;
diff --git a/src/memory/memory.hpp b/src/memory/memory.hpp
index 4e9d393f..3404ef7c 100644
--- a/src/memory/memory.hpp
+++ b/src/memory/memory.hpp
@@ -46,6 +46,19 @@ template <typename T>
 using device_view = array_view<T, device_coordinator<T, cuda_allocator<T>>>;
 template <typename T>
 using const_device_view = const_array_view<T, device_coordinator<T, cuda_allocator<T>>>;
+
+template <typename T>
+std::ostream& operator<<(std::ostream& o, device_view<T> v) {
+    std::size_t i=0u;
+    for (; i<v.size()-1; ++i) o << v[i] << ", ";
+    return o << v[i];
+}
+template <typename T>
+std::ostream& operator<<(std::ostream& o, const_device_view<T> v) {
+    std::size_t i=0u;
+    for (; i<v.size()-1; ++i) o << v[i] << ", ";
+    return o << v[i];
+}
 #endif
 
 #ifdef WITH_KNL
diff --git a/tests/unit/test_algorithms.cpp b/tests/unit/test_algorithms.cpp
index 230fd0af..a8919dd7 100644
--- a/tests/unit/test_algorithms.cpp
+++ b/tests/unit/test_algorithms.cpp
@@ -578,6 +578,7 @@ TEST(algorithms, index_into)
         {{}, {}},
         {{100}, {}},
         // Strictly monotonic sequences:
+        {{0, 7}, {0, 7}},
         {{0, 1, 3, 4, 6, 7, 10, 11}, {0, 4, 6, 7, 11}},
         {{0, 1, 3, 4, 6, 7, 10, 11}, {0}},
         {{0, 1, 3, 4, 6, 7, 10, 11}, {11}},
-- 
GitLab