From 20c7216b0b7be6af7ed6837ac5acf96042509429 Mon Sep 17 00:00:00 2001 From: Sam Yates <halfflat@gmail.com> Date: Tue, 4 May 2021 15:01:43 +0200 Subject: [PATCH] Clarify/rewrite pw_elements zip. (#1515) * Use a struct with meaningful field names instead of a `std::pair` to represent elements in `pw_elements`. * Describe explicitly the semantics of the zip operation. * Update implementation and add unit tests to suit. * Add comment explaining `pw_elements::equal_range`. * Remove parameter names from unused parameters in `impl::piecewise_pairify`. --- arbor/fvm_layout.cpp | 2 +- arbor/fvm_layout.hpp | 2 +- arbor/morph/embed_pwlin.cpp | 12 +-- arbor/util/piecewise.hpp | 196 ++++++++++++++++++++--------------- test/unit/test_piecewise.cpp | 145 +++++++++++++++++++------- 5 files changed, 228 insertions(+), 129 deletions(-) diff --git a/arbor/fvm_layout.cpp b/arbor/fvm_layout.cpp index 1f917939..1ec58645 100644 --- a/arbor/fvm_layout.cpp +++ b/arbor/fvm_layout.cpp @@ -1164,7 +1164,7 @@ fvm_mechanism_data fvm_build_mechanism_data(const cable_cell_global_properties& const mcable_map<init_reversal_potential>& rvpot_on_cable = initial_rvpot_map[ion]; auto pw_times = [](const pw_elements<double>& a, const pw_elements<double>& b) { - return zip(a, b, [](double left, double right, pw_element<double> a, pw_element<double> b) { return a.second*b.second; }); + return zip(a, b, [](double left, double right, pw_element<double> a, pw_element<double> b) { return a.element*b.element; }); }; for (auto i: count_along(config.cv)) { diff --git a/arbor/fvm_layout.hpp b/arbor/fvm_layout.hpp index 8d8ba818..88cf14dd 100644 --- a/arbor/fvm_layout.hpp +++ b/arbor/fvm_layout.hpp @@ -145,7 +145,7 @@ struct cv_geometry { } index_type cv_base = cell_cv_divs.at(cell_idx); - return cv_base+pw_cv_offset[i].second; + return cv_base+pw_cv_offset[i].element; } }; diff --git a/arbor/morph/embed_pwlin.cpp b/arbor/morph/embed_pwlin.cpp index 03acfaa7..03393acc 100644 --- a/arbor/morph/embed_pwlin.cpp +++ b/arbor/morph/embed_pwlin.cpp @@ -93,9 +93,9 @@ mcable_list data_cmp(const branch_pw_ratpoly<1, 0>& f, unsigned bid, double val, mcable_list L; const auto& pw = f.at(bid); for (const auto& piece: pw) { - auto extents = piece.first; - auto left_val = piece.second(0); - auto right_val = piece.second(1); + auto extents = piece.interval; + auto left_val = piece.element(0); + auto right_val = piece.element(1); if (!op(left_val, val) && !op(right_val, val)) { continue; @@ -275,11 +275,11 @@ embed_pwlin::embed_pwlin(const arb::morphology& m) { segment_cables_[seg.id] = mcable{bid, pos0, pos1}; } - double length_0 = parent==mnpos? 0: data_->length[parent].back().second[1]; + double length_0 = parent==mnpos? 0: data_->length[parent].back().element[1]; data_->length[bid].push_back(0., 1, rat_element<1, 0>(length_0, length_0+branch_length)); - double area_0 = parent==mnpos? 0: data_->area[parent].back().second[2]; - double ixa_0 = parent==mnpos? 0: data_->ixa[parent].back().second[2]; + double area_0 = parent==mnpos? 0: data_->area[parent].back().element[2]; + double ixa_0 = parent==mnpos? 0: data_->ixa[parent].back().element[2]; for (auto i: util::count_along(segments)) { auto prox = segments[i].prox; diff --git a/arbor/util/piecewise.hpp b/arbor/util/piecewise.hpp index 9f157989..f250c1b7 100644 --- a/arbor/util/piecewise.hpp +++ b/arbor/util/piecewise.hpp @@ -72,11 +72,20 @@ struct pw_elements { using size_type = pw_size_type; static constexpr size_type npos = pw_npos; - using value_type = std::pair<std::pair<double, double>, X>; + struct value_type { + std::pair<double, double> interval; + X element; + + bool operator==(const value_type& other) const { return interval==other.interval && element==other.element; } + bool operator!=(const value_type& other) const { return interval!=other.interval || element!=other.element; } + }; + + using const_iterator = indexed_const_iterator<pw_elements<X>>; + using iterator = const_iterator; // Consistency requirements: - // 1. empty() || element.size()+1 = vertex.size() - // 2. vertex[i]<=vertex[j] for i<=j. + // 1. empty() || element_.size()+1 = vertex_.size() + // 2. vertex_[i]<=vertex_[j] for i<=j. std::vector<double> vertex_; std::vector<X> element_; @@ -128,9 +137,6 @@ struct pw_elements { const X& element(size_type i) const & { return element_[i]; } value_type operator[](size_type i) const { return value_type{interval(i), element(i)}; } - using const_iterator = indexed_const_iterator<pw_elements<X>>; - using iterator = const_iterator; - const_iterator cbegin() const { return const_iterator{this, 0}; } const_iterator begin() const { return cbegin(); } const_iterator cend() const { return const_iterator{this, size()}; } @@ -151,6 +157,20 @@ struct pw_elements { std::pair<iterator, iterator> equal_range(double x) const { auto eq = std::equal_range(vertex_.begin(), vertex_.end(), x); + // Let n be the number of elements, indexed from 0 to n-1, with + // vertices indexed from 0 to n. Observe: + // * eq.first points to least vertex v_i ≥ x. + // * eq.second points to vertex_.end() if the last vertex v_n ≤ x, + // or else to the least vertex v_k > x. + // + // Elements then correspond to the index range [b, e), where: + // * b=0 if i=0, else b=i-1, as v_i will be the upper vertex for + // the first element whose (closed) support contains x. + // * e=k if k<n, since v_k will be the upper vertex for the + // the last element (index k-1) whose support contains x. + // Otherwise, if k==n or eq.second is vertex_.end(), the + // last element (index n-1) contains x, and so e=n. + if (eq.first==vertex_.end()) return {end(), end()}; if (eq.first>vertex_.begin()) --eq.first; if (eq.second==vertex_.end()) --eq.second; @@ -261,7 +281,15 @@ template <> struct pw_elements<void> { std::vector<double> vertex_; - using value_type = std::pair<double, double>; + struct value_type { + std::pair<double, double> interval; + + bool operator==(const value_type& other) const { return interval==other.interval; } + bool operator!=(const value_type& other) const { return interval!=other.interval; } + }; + + using const_iterator = indexed_const_iterator<pw_elements<void>>; + using iterator = const_iterator; // ctors and assignment: @@ -285,7 +313,7 @@ template <> struct pw_elements<void> { auto intervals() const { return util::partition_view(vertex_); } auto interval(size_type i) const { return intervals()[i]; } - value_type operator[](size_type i) const { return interval(i); } + value_type operator[](size_type i) const { return value_type{interval(i)}; } auto bounds() const { return intervals().bounds(); } @@ -305,8 +333,16 @@ template <> struct pw_elements<void> { else return partn.index(x); } - using const_iterator = indexed_const_iterator<pw_elements<void>>; - using iterator = const_iterator; + // Return iterator pair spanning elements whose corresponding closed intervals contain x. + std::pair<iterator, iterator> equal_range(double x) const { + auto eq = std::equal_range(vertex_.begin(), vertex_.end(), x); + + if (eq.first==vertex_.end()) return {end(), end()}; + if (eq.first>vertex_.begin()) --eq.first; + if (eq.second==vertex_.end()) --eq.second; + + return {begin()+(eq.first-vertex_.begin()), begin()+(eq.second-vertex_.begin())}; + } const_iterator cbegin() const { return const_iterator{this, 0}; } const_iterator begin() const { return cbegin(); } @@ -381,48 +417,79 @@ namespace impl { template <typename A, typename B> struct piecewise_pairify { std::pair<A, B> operator()( - double left, double right, - const pw_element<A> a_elem, - const pw_element<B> b_elem) const + double, double, + const pw_element<A>& a_elem, + const pw_element<B>& b_elem) const { - return {a_elem.second, b_elem.second}; + return {a_elem.element, b_elem.element}; } }; template <typename X> struct piecewise_pairify<X, void> { X operator()( - double left, double right, - const pw_element<X> a_elem, - const pw_element<void> b_elem) const + double, double, + const pw_element<X>& a_elem, + const pw_element<void>& b_elem) const { - return a_elem.second; + return a_elem.element; } }; template <typename X> struct piecewise_pairify<void, X> { X operator()( - double left, double right, - const pw_element<void> a_elem, - const pw_element<X> b_elem) const + double, double, + const pw_element<void>& a_elem, + const pw_element<X>& b_elem) const { - return b_elem.second; + return b_elem.element; } }; -} -// TODO: Consider making a lazy `zip_view` version of zip. + template <> + struct piecewise_pairify<void, void> { + void operator()( + double, double, + const pw_element<void>&, + const pw_element<void>&) const {} + }; +} -// Combine functional takes four arguments: -// double left, double right, pw_elements<A>::value_type, pw_elements<B>::value_type b> +// Zip combines successive elements from two pw_elements sequences where the +// elements overlap. Let A_i and B_i denote the ordered elements from each of +// the sequences A and B, and Z_i denote the elements from the resulting +// sequence Z. +// +// * The support (`bounds()`) of the zipped result is the intersections of the +// support of the two sequences. +// +// * Each element Z_k in the zip corresponds to the intersection of some +// element A_i and B_j. The extent of Z_k is the intersection of the +// extents of A_i and B_j, and its value is determined by the supplied +// `combine` function. +// +// * For every element in A_i in A, if A_i intersects with an element of +// B, then there will be an Z_k corresponding to the intersection of A_i +// with some element of B. Likewise for elements B_i of B. // -// Default combine functional returns std::pair<A, B>, unless one of A and B is void. +// * Elements of Z respect the ordering of elements in A and B, and do +// not repeat. If Z_k is derived from A_i and B_j, then Z_(k+1) is derived +// from A_(i+1) and B_(j+1) if possible, or else from A_i and B_(j+1) or +// A_(i+1) and B_j. +// +// The Combine functional takes four arguments: double left, double right, +// pw_elements<A>::value_type, pw_elements<B>::value_type b. The default +// combine functional returns std::pair<A, B>, unless one of A and B is void. +// +// TODO: Consider making a lazy `zip_view` version of zip. template <typename A, typename B, typename Combine = impl::piecewise_pairify<A, B>> auto zip(const pw_elements<A>& a, const pw_elements<B>& b, Combine combine = {}) { using Out = decltype(combine(0., 0., a.front(), b.front())); + constexpr bool is_void = std::is_void_v<Out>; + pw_elements<Out> z; if (a.empty() || b.empty()) return z; @@ -430,77 +497,34 @@ auto zip(const pw_elements<A>& a, const pw_elements<B>& b, Combine combine = {}) double rmin = std::min(a.bounds().second, b.bounds().second); if (rmin<lmax) return z; - double left = lmax; - pw_size_type ai = a.index_of(left); - pw_size_type bi = b.index_of(left); - - arb_assert(ai!=(pw_size_type)-1); - arb_assert(bi!=(pw_size_type)-1); - - if (rmin==left) { - z.push_back(left, left, combine(left, left, a[ai], b[bi])); - return z; - } + auto ai = a.equal_range(lmax).first; + auto bi = b.equal_range(lmax).first; - double a_right = a.interval(ai).second; - double b_right = b.interval(bi).second; + auto a_end = a.equal_range(rmin).second; + auto b_end = b.equal_range(rmin).second; + double left = lmax; + double a_right = ai->interval.second; + double b_right = bi->interval.second; for (;;) { double right = std::min(a_right, b_right); - right = std::min(right, rmin); - - z.push_back(left, right, combine(left, right, a[ai], b[bi])); - if (right==rmin) break; - - if (a_right==right) { - a_right = a.interval(++ai).second; + if constexpr (is_void) { + z.push_back(left, right); } - if (b_right==right) { - b_right = b.interval(++bi).second; + else { + z.push_back(left, right, combine(left, right, *ai, *bi)); } - left = right; - } - return z; -} - -inline pw_elements<void> zip(const pw_elements<void>& a, const pw_elements<void>& b) { - pw_elements<void> z; - if (a.empty() || b.empty()) return z; - - double lmax = std::max(a.bounds().first, b.bounds().first); - double rmin = std::min(a.bounds().second, b.bounds().second); - if (rmin<lmax) return z; - - double left = lmax; - pw_size_type ai = a.intervals().index(left); - pw_size_type bi = b.intervals().index(left); - - if (rmin==left) { - z.push_back(left, left); - return z; - } - - double a_right = a.interval(ai).second; - double b_right = b.interval(bi).second; - - while (left<rmin) { - double right = std::min(a_right, b_right); - right = std::min(right, rmin); - - z.push_back(left, right); - if (a_right<=right) { - a_right = a.interval(++ai).second; - } - if (b_right<=right) { - b_right = b.interval(++bi).second; - } + bool advance_a = a_right==right && std::next(ai)!=a_end; + bool advance_b = b_right==right && std::next(bi)!=b_end; + if (!advance_a && !advance_b) break; + if (advance_a) a_right = (++ai)->interval.second; + if (advance_b) b_right = (++bi)->interval.second; left = right; } return z; } - } // namespace util } // namespace arb diff --git a/test/unit/test_piecewise.cpp b/test/unit/test_piecewise.cpp index cac24581..f05b540a 100644 --- a/test/unit/test_piecewise.cpp +++ b/test/unit/test_piecewise.cpp @@ -82,10 +82,10 @@ TEST(piecewise, assign) { ASSERT_EQ(4u, p.size()); - EXPECT_EQ(10, p[0].second); - EXPECT_EQ( 8, p[1].second); - EXPECT_EQ( 9, p[2].second); - EXPECT_EQ( 4, p[3].second); + EXPECT_EQ(10, p[0].element); + EXPECT_EQ( 8, p[1].element); + EXPECT_EQ( 9, p[2].element); + EXPECT_EQ( 4, p[3].element); using dp = std::pair<double, double>; EXPECT_EQ(dp(1.0, 1.5), p.interval(0)); @@ -168,13 +168,13 @@ TEST(piecewise, access) { p.assign(v, x); for (unsigned i = 0; i<4; ++i) { - EXPECT_EQ(v[i], p[i].first.first); - EXPECT_EQ(v[i+1], p[i].first.second); + EXPECT_EQ(v[i], p[i].interval.first); + EXPECT_EQ(v[i+1], p[i].interval.second); EXPECT_EQ(v[i], p.interval(i).first); EXPECT_EQ(v[i+1], p.interval(i).second); - EXPECT_EQ(x[i], p[i].second); + EXPECT_EQ(x[i], p[i].element); EXPECT_EQ(x[i], p.element(i)); } @@ -240,21 +240,21 @@ TEST(piecewise, equal_range) { auto er1 = p.equal_range(1.0); ASSERT_EQ(1, er1.second-er1.first); - EXPECT_EQ(10, er1.first->second); + EXPECT_EQ(10, er1.first->element); auto er2 = p.equal_range(2.0); ASSERT_EQ(2, er2.second-er2.first); auto iter = er2.first; - EXPECT_EQ(10, iter++->second); - EXPECT_EQ(9, iter->second); + EXPECT_EQ(10, iter++->element); + EXPECT_EQ(9, iter->element); auto er3_5 = p.equal_range(3.5); ASSERT_EQ(1, er3_5.second-er3_5.first); - EXPECT_EQ(8, er3_5.first->second); + EXPECT_EQ(8, er3_5.first->element); auto er4 = p.equal_range(4.0); ASSERT_EQ(1, er4.second-er4.first); - EXPECT_EQ(8, er4.first->second); + EXPECT_EQ(8, er4.first->element); auto er5 = p.equal_range(5.0); ASSERT_EQ(er5.first, er5.second); @@ -269,22 +269,22 @@ TEST(piecewise, equal_range) { auto er1 = p.equal_range(1.0); ASSERT_EQ(2, er1.second-er1.first); auto iter = er1.first; - EXPECT_EQ(10, iter++->second); - EXPECT_EQ(11, iter++->second); + EXPECT_EQ(10, iter++->element); + EXPECT_EQ(11, iter++->element); auto er2 = p.equal_range(2.0); ASSERT_EQ(4, er2.second-er2.first); iter = er2.first; - EXPECT_EQ(11, iter++->second); - EXPECT_EQ(12, iter++->second); - EXPECT_EQ(13, iter++->second); - EXPECT_EQ(14, iter++->second); + EXPECT_EQ(11, iter++->element); + EXPECT_EQ(12, iter++->element); + EXPECT_EQ(13, iter++->element); + EXPECT_EQ(14, iter++->element); auto er3 = p.equal_range(3.0); ASSERT_EQ(2, er3.second-er3.first); iter = er3.first; - EXPECT_EQ(14, iter++->second); - EXPECT_EQ(15, iter++->second); + EXPECT_EQ(14, iter++->element); + EXPECT_EQ(15, iter++->element); auto er5 = p.equal_range(5.0); ASSERT_EQ(er5.first, er5.second); @@ -304,12 +304,12 @@ TEST(piecewise, push) { q.push_back(3.1, 4.3, 5); EXPECT_EQ(dp(1.1, 3.1), q.interval(0)); EXPECT_EQ(dp(3.1, 4.3), q.interval(1)); - EXPECT_EQ(4, q[0].second); - EXPECT_EQ(5, q[1].second); + EXPECT_EQ(4, q[0].element); + EXPECT_EQ(5, q[1].element); q.push_back(7.2, 6); EXPECT_EQ(dp(4.3, 7.2), q.interval(2)); - EXPECT_EQ(6, q[2].second); + EXPECT_EQ(6, q[2].element); // Supplied left side doesn't match current right. EXPECT_THROW(q.push_back(7.4, 9.1, 7), std::runtime_error); @@ -340,7 +340,7 @@ TEST(piecewise, zip) { p03.assign((double [3]){0., 1.5, 3.}, (int [2]){10, 11}); pw_elements<int> p14; - p14.assign((double [5]){1, 2.25, 3., 3.5, 4.}, (int [4]){3, 4, 5, 6}); + p14.assign((double [5]){1, 2.25, 3.25, 3.5, 4.}, (int [4]){3, 4, 5, 6}); using ii = std::pair<int, int>; pw_elements<ii> p03_14 = zip(p03, p14); @@ -367,27 +367,102 @@ TEST(piecewise, zip) { EXPECT_EQ((std::vector<double>{1., 1.5, 2.25, 3.}), zip(p14, v03).vertices()); auto project = [](double l, double r, pw_element<void>, const pw_element<int>& b) -> double { - double b_width = b.first.second-b.first.first; - return b.second*(r-l)/b_width; + double b_width = b.interval.second-b.interval.first; + return b.element*(r-l)/b_width; }; pw_elements<void> vxx; // elements cover bounds of p14 vxx.assign((double [6]){0.2, 1.7, 1.95, 2.325, 2.45, 4.9}); pw_elements<double> pxx = zip(vxx, p14, project); - double p14_sum = util::sum(util::transform_view(p14, [](auto v) { return v.second; })); - double pxx_sum = util::sum(util::transform_view(pxx, [](auto v) { return v.second; })); + double p14_sum = util::sum(util::transform_view(p14, [](auto v) { return v.element; })); + double pxx_sum = util::sum(util::transform_view(pxx, [](auto v) { return v.element; })); EXPECT_DOUBLE_EQ(p14_sum, pxx_sum); } -TEST(piecewise, zip_void) { - pw_elements<> p03; - p03.assign((double [3]){0., 1.5, 3.}); +TEST(piecewise, zip_zero_length_elements) { + pw_elements<int> p03a; + p03a.assign((double [5]){0, 0, 1.5, 3, 3}, (int [4]){10, 11, 12, 13}); - pw_elements<> p14; - p14.assign((double [5]){1, 2.25, 3., 3.5, 4.}); + pw_elements<int> p03b; + p03b.assign((double [7]){0, 0, 0, 1, 3, 3, 3.}, (int [6]){20, 21, 22, 23, 24, 25}); - EXPECT_EQ((std::vector<double>{1., 1.5, 2.25, 3.}), zip(p03, p14).vertices()); - EXPECT_EQ((std::vector<double>{1., 1.5, 2.25, 3.}), zip(p14, p03).vertices()); + pw_elements<int> p33; + p33.assign((double [3]){3, 3, 3}, (int [2]){30, 31}); + + pw_elements<int> p14; + p14.assign((double [3]){1, 2, 4}, (int [2]){40, 41}); + + auto flip = [](auto& pairs) { for (auto& [l, r]: pairs) std::swap(l, r); }; + using ii = std::pair<int, int>; + + { + pw_elements<ii> zz = zip(p03a, p03b); + EXPECT_EQ(0., zz.bounds().first); + EXPECT_EQ(3., zz.bounds().second); + + std::vector<double> expected_vertices = {0, 0, 0, 1, 1.5, 3, 3, 3}; + std::vector<ii> expected_elements = {ii(10, 20), ii(11, 21), ii(11,22), ii(11,23), ii(12, 23), ii(13,24), ii(13,25)}; + + EXPECT_EQ(expected_vertices, zz.vertices()); + EXPECT_EQ(expected_elements, zz.elements()); + + pw_elements<ii> yy = zip(p03b, p03a); + flip(expected_elements); + + EXPECT_EQ(expected_vertices, yy.vertices()); + EXPECT_EQ(expected_elements, yy.elements()); + } + + { + pw_elements<ii> zz = zip(p03a, p33); + EXPECT_EQ(3., zz.bounds().first); + EXPECT_EQ(3., zz.bounds().second); + + std::vector<double> expected_vertices = {3, 3, 3}; + std::vector<ii> expected_elements = {ii(12, 30), ii(13, 31)}; + + EXPECT_EQ(expected_vertices, zz.vertices()); + EXPECT_EQ(expected_elements, zz.elements()); + + pw_elements<ii> yy = zip(p33, p03a); + flip(expected_elements); + + EXPECT_EQ(expected_vertices, yy.vertices()); + EXPECT_EQ(expected_elements, yy.elements()); + + } + + { + pw_elements<ii> zz = zip(p03a, p14); + EXPECT_EQ(1., zz.bounds().first); + EXPECT_EQ(3., zz.bounds().second); + + std::vector<double> expected_vertices = {1, 1.5, 2, 3, 3}; + std::vector<ii> expected_elements = {ii(11, 40), ii(12, 40), ii(12, 41), ii(13, 41)}; + + EXPECT_EQ(expected_vertices, zz.vertices()); + EXPECT_EQ(expected_elements, zz.elements()); + + pw_elements<ii> yy = zip(p14, p03a); + flip(expected_elements); + + EXPECT_EQ(expected_vertices, yy.vertices()); + EXPECT_EQ(expected_elements, yy.elements()); + } + + { + // Check void version too! + pw_elements<> v03a(p03a), v03b(p03b); + pw_elements<> zz = zip(v03a, v03b); + EXPECT_EQ(0., zz.bounds().first); + EXPECT_EQ(3., zz.bounds().second); + + std::vector<double> expected_vertices = {0, 0, 0, 1, 1.5, 3, 3, 3}; + EXPECT_EQ(expected_vertices, zz.vertices()); + + pw_elements<> yy = zip(v03b, v03a); + EXPECT_EQ(expected_vertices, yy.vertices()); + } } -- GitLab