diff --git a/arbor/backends/gpu/mechanism.cpp b/arbor/backends/gpu/mechanism.cpp index aa500ddcf5ce27445249e094e84b4bd62eee14f9..a0304fe7975cefc2f656a1c994f8658b86b62871 100644 --- a/arbor/backends/gpu/mechanism.cpp +++ b/arbor/backends/gpu/mechanism.cpp @@ -147,17 +147,24 @@ void mechanism::instantiate(unsigned id, // Allocate and initialize index vectors, viz. node_index_ and any ion indices. // (First sub-array of indices_ is used for node_index_, last sub-array used for multiplicity_ if it is not empty) - size_type num_elements = mult_in_place_ ? 2 + num_ions_ : 1 + num_ions_; - indices_ = iarray((num_elements)*width_padded_); + size_type num_elements = (mult_in_place_ ? 1 : 0) + 1 + num_ions_; + indices_ = iarray(num_elements*width_padded_); - memory::copy(make_const_view(pos_data.cv), device_view(indices_.data(), width_)); - pp->node_index_ = indices_.data(); + auto base_ptr = indices_.data(); + + auto append_chunk = [&](const auto& input, auto& output) { + memory::copy(make_const_view(input), device_view(base_ptr, width_)); + output = base_ptr; + base_ptr += width_padded_; + }; + + append_chunk(pos_data.cv, pp->node_index_); auto ion_index_tbl = ion_index_table(); arb_assert(num_ions_==ion_index_tbl.size()); - for (auto i: make_span(0, num_ions_)) { - auto ion_binding = value_by_key(overrides.ion_rebind, ion_index_tbl[i].first).value_or(ion_index_tbl[i].first); + for (auto& [ion, ion_ptr]: ion_index_tbl) { + auto ion_binding = value_by_key(overrides.ion_rebind, ion).value_or(ion); ion_state* oion = ptr_by_key(shared.ion_data, ion_binding); @@ -170,15 +177,11 @@ void mechanism::instantiate(unsigned id, std::vector<index_type> mech_ion_index(indices.begin(), indices.end()); // Take reference to derived (generated) mechanism ion index pointer. - auto& ion_index_ptr = *ion_index_tbl[i].second; - auto index_start = indices_.data()+(i+1)*width_padded_; - ion_index_ptr = index_start; - memory::copy(make_const_view(mech_ion_index), device_view(index_start, width_)); + append_chunk(mech_ion_index, *ion_ptr); } if (mult_in_place_) { - memory::copy(make_const_view(pos_data.multiplicity), device_view(indices_.data() + width_padded_, width_)); - pp->multiplicity_ = indices_.data() + (num_ions_ + 1)*width_padded_; + append_chunk(pos_data.multiplicity, pp->multiplicity_); } }