diff --git a/external/vector b/external/vector index 193df3683a3eaa2bf203bbcfa935d7e7019b1719..8284611f05b0fbe21a1f84630e2726015cb1d96d 160000 --- a/external/vector +++ b/external/vector @@ -1 +1 @@ -Subproject commit 193df3683a3eaa2bf203bbcfa935d7e7019b1719 +Subproject commit 8284611f05b0fbe21a1f84630e2726015cb1d96d diff --git a/src/algorithms.hpp b/src/algorithms.hpp index 83eeaf8dd7e32f776384d0c47fec5e5d0b43d1a8..5d609b1effe57f12d3998e10d856b659c042120e 100644 --- a/src/algorithms.hpp +++ b/src/algorithms.hpp @@ -19,230 +19,302 @@ namespace nest { namespace mc { -namespace algorithms{ - - template <typename C> - typename C::value_type - sum(C const& c) - { - using value_type = typename C::value_type; - return std::accumulate(c.begin(), c.end(), value_type{0}); - } - - template <typename C> - typename C::value_type - mean(C const& c) - { - return sum(c)/c.size(); +namespace algorithms { + +template <typename C> +typename C::value_type +sum(C const& c) +{ + using value_type = typename C::value_type; + return std::accumulate(c.begin(), c.end(), value_type{0}); +} + +template <typename C> +typename C::value_type +mean(C const& c) +{ + return sum(c)/c.size(); +} + +template <typename C> +C make_index(C const& c) +{ + static_assert( + std::is_integral<typename C::value_type>::value, + "make_index only applies to integral types" + ); + + C out(c.size()+1); + out[0] = 0; + std::partial_sum(c.begin(), c.end(), out.begin()+1); + return out; +} + +/// works like std::is_sorted(), but with stronger condition that succesive +/// elements must be greater than those before them +template <typename C> +bool is_strictly_monotonic_increasing(C const& c) +{ + using value_type = typename C::value_type; + return std::is_sorted( + c.begin(), + c.end(), + [] (value_type const& lhs, value_type const& rhs) { + return lhs <= rhs; + } + ); +} + +template <typename C> +bool is_strictly_monotonic_decreasing(C const& c) +{ + using value_type = typename C::value_type; + return std::is_sorted( + c.begin(), + c.end(), + [] (value_type const& lhs, value_type const& rhs) { + return lhs >= rhs; + } + ); +} + +template < + typename C, + typename = typename std::enable_if<std::is_integral<typename C::value_type>::value> +> +bool is_minimal_degree(C const& c) +{ + static_assert( + std::is_integral<typename C::value_type>::value, + "is_minimal_degree only applies to integral types" + ); + + if(c.size()==0u) { + return true; } - template <typename C> - C make_index(C const& c) - { - static_assert( - std::is_integral<typename C::value_type>::value, - "make_index only applies to integral types" - ); - - C out(c.size()+1); - out[0] = 0; - std::partial_sum(c.begin(), c.end(), out.begin()+1); - return out; + using value_type = typename C::value_type; + if(c[0] != value_type(0)) { + return false; } - - /// works like std::is_sorted(), but with stronger condition that succesive - /// elements must be greater than those before them - template <typename C> - bool is_strictly_monotonic_increasing(C const& c) - { - using value_type = typename C::value_type; - return std::is_sorted( - c.begin(), - c.end(), - [] (value_type const& lhs, value_type const& rhs) { - return lhs <= rhs; - } - ); + auto i = value_type(1); + auto it = std::find_if( + c.begin()+1, c.end(), [&i](value_type v) { return v>=(i++); } + ); + return it==c.end(); +} + +template <typename C> +bool is_positive(C const& c) +{ + static_assert( + std::is_integral<typename C::value_type>::value, + "is_positive only applies to integral types" + ); + for(auto v : c) { + if(v<1) { + return false; + } } - - template <typename C> - bool is_strictly_monotonic_decreasing(C const& c) - { - using value_type = typename C::value_type; - return std::is_sorted( - c.begin(), - c.end(), - [] (value_type const& lhs, value_type const& rhs) { - return lhs >= rhs; - } - ); + return true; +} + +template<typename C> +bool has_contiguous_segments(const C &parent_index) +{ + static_assert( + std::is_integral<typename C::value_type>::value, + "integral type required" + ); + + if (!is_minimal_degree(parent_index)) { + return false; } - template < - typename C, - typename = typename std::enable_if<std::is_integral<typename C::value_type>::value> - > - bool is_minimal_degree(C const& c) - { - static_assert( - std::is_integral<typename C::value_type>::value, - "is_minimal_degree only applies to integral types" - ); - - if(c.size()==0u) { - return true; - } + int n = parent_index.size(); + std::vector<bool> is_leaf(n, false); - using value_type = typename C::value_type; - if(c[0] != value_type(0)) { + for(auto i=1; i<n; ++i) { + auto p = parent_index[i]; + if(is_leaf[p]) { return false; } - auto i = value_type(1); - auto it = std::find_if( - c.begin()+1, c.end(), [&i](value_type v) { return v>=(i++); } - ); - return it==c.end(); - } - template <typename C> - bool is_positive(C const& c) - { - static_assert( - std::is_integral<typename C::value_type>::value, - "is_positive only applies to integral types" - ); - for(auto v : c) { - if(v<1) { - return false; - } + if(p != i-1) { + // we have a branch and i-1 is a leaf node + is_leaf[i-1] = true; } - return true; } - template<typename C> - bool has_contiguous_segments(const C &parent_index) - { - static_assert( - std::is_integral<typename C::value_type>::value, - "integral type required" - ); + return true; +} - if (!is_minimal_degree(parent_index)) { - return false; - } +template<typename C> +std::vector<typename C::value_type> child_count(const C &parent_index) +{ + static_assert( + std::is_integral<typename C::value_type>::value, + "integral type required" + ); - int n = parent_index.size(); - std::vector<bool> is_leaf(n, false); + std::vector<typename C::value_type> count(parent_index.size(), 0); + for (std::size_t i = 1; i < parent_index.size(); ++i) { + ++count[parent_index[i]]; + } - for(auto i=1; i<n; ++i) { - auto p = parent_index[i]; - if(is_leaf[p]) { - return false; - } + return count; +} - if(p != i-1) { - // we have a branch and i-1 is a leaf node - is_leaf[i-1] = true; - } - } +template<typename C> +std::vector<typename C::value_type> branches(const C& parent_index) +{ + static_assert( + std::is_integral<typename C::value_type>::value, + "integral type required" + ); - return true; + EXPECTS(has_contiguous_segments(parent_index)); + + std::vector<typename C::value_type> branch_index; + if (parent_index.empty()) { + return branch_index; } - template<typename C> - std::vector<typename C::value_type> child_count(const C &parent_index) - { - static_assert( - std::is_integral<typename C::value_type>::value, - "integral type required" - ); - - std::vector<typename C::value_type> count(parent_index.size(), 0); - for (std::size_t i = 1; i < parent_index.size(); ++i) { - ++count[parent_index[i]]; + auto num_child = child_count(parent_index); + branch_index.push_back(0); + for (std::size_t i = 1; i < parent_index.size(); ++i) { + auto p = parent_index[i]; + if (num_child[p] > 1 || parent_index[p] == p) { + // parent_index[p] == p -> parent_index[i] is the soma + branch_index.push_back(i); } - - return count; } - template<typename C> - std::vector<typename C::value_type> branches(const C &parent_index) - { - static_assert( - std::is_integral<typename C::value_type>::value, - "integral type required" - ); - - EXPECTS(has_contiguous_segments(parent_index)); - - auto num_child = child_count(parent_index); - std::vector<typename C::value_type> branch_runs( - parent_index.size(), 0 - ); - - std::size_t num_branches = (num_child[0] == 1) ? 1 : 0; - for (std::size_t i = 1; i < parent_index.size(); ++i) { - auto p = parent_index[i]; - if (num_child[p] > 1) { - ++num_branches; - } - - branch_runs[i] = num_branches; - } + branch_index.push_back(parent_index.size()); + return branch_index; +} + - return branch_runs; +template<typename C> +std::vector<typename C::value_type> expand_branches(const C& branch_index) +{ + static_assert( + std::is_integral<typename C::value_type>::value, + "integral type required" + ); + + if (branch_index.empty()) + return {}; + + std::vector<typename C::value_type> expanded(branch_index.back()); + for (std::size_t i = 0; i < branch_index.size()-1; ++i) { + for (std::size_t j = branch_index[i]; j < branch_index[i+1]; ++j) { + expanded[j] = i; + } } - template<typename C> - bool is_sorted(const C& c) - { - return std::is_sorted(c.begin(), c.end()); + return expanded; +} + +template<typename C> +typename C::value_type find_branch(const C& branch_index, + typename C::value_type nid) +{ + using value_type = typename C::value_type; + static_assert( + std::is_integral<value_type>::value, + "integral type required" + ); + + auto it = std::find_if( + branch_index.begin(), branch_index.end(), + [nid](const value_type &v) { return v > nid; } + ); + + return it - branch_index.begin() - 1; +} + + +template<typename C> +std::vector<typename C::value_type> make_parent_index( + const C& parent_index, const C& branch_index) +{ + static_assert( + std::is_integral<typename C::value_type>::value, + "integral type required" + ); + + if (parent_index.empty() && branch_index.empty()) { + return {}; } - template<typename C> - bool is_unique(const C& c) - { - return std::adjacent_find(c.begin(), c.end()) == c.end(); + EXPECTS(parent_index.size() == branch_index.back()); + EXPECTS(has_contiguous_segments(parent_index)); + EXPECTS(is_strictly_monotonic_increasing(branch_index)); + + // expand the branch index + auto expanded_branch = expand_branches(branch_index); + + std::vector<typename C::value_type> new_parent_index; + for (std::size_t i = 0; i < branch_index.size()-1; ++i) { + auto p = parent_index[branch_index[i]]; + new_parent_index.push_back(expanded_branch[p]); } - /// Return and index that maps entries in sub to their corresponding - /// values in super, where sub is a subset of super. - /// - /// Both sets are sorted and have unique entries. - /// Complexity is O(n), where n is size of super - template<typename C> - // C::iterator models forward_iterator - // C::value_type is_integral - C index_into(const C& super, const C& sub) - { - //EXPECTS {s \in super : \forall s \in sub}; - EXPECTS(is_unique(super) && is_unique(sub)); - EXPECTS(is_sorted(super) && is_sorted(sub)); - EXPECTS(sub.size() <= super.size()); - - static_assert( - std::is_integral<typename C::value_type>::value, - "index_into only applies to integral types" - ); - - C out(sub.size()); // out will have one entry for each index in sub - - auto sub_it=sub.begin(); - auto super_it=super.begin(); - auto sub_idx=0u, super_idx = 0u; - - while(sub_it!=sub.end() && super_it!=super.end()) { - if(*sub_it==*super_it) { - out[sub_idx] = super_idx; - ++sub_it; ++sub_idx; - } - ++super_it; ++super_idx; + return new_parent_index; +} + + +template<typename C> +bool is_sorted(const C& c) +{ + return std::is_sorted(c.begin(), c.end()); +} + +template<typename C> +bool is_unique(const C& c) +{ + return std::adjacent_find(c.begin(), c.end()) == c.end(); +} + +/// Return and index that maps entries in sub to their corresponding +/// values in super, where sub is a subset of super. +/// +/// Both sets are sorted and have unique entries. +/// Complexity is O(n), where n is size of super +template<typename C> +// C::iterator models forward_iterator +// C::value_type is_integral +C index_into(const C& super, const C& sub) +{ + //EXPECTS {s \in super : \forall s \in sub}; + EXPECTS(is_unique(super) && is_unique(sub)); + EXPECTS(is_sorted(super) && is_sorted(sub)); + EXPECTS(sub.size() <= super.size()); + + static_assert( + std::is_integral<typename C::value_type>::value, + "index_into only applies to integral types" + ); + + C out(sub.size()); // out will have one entry for each index in sub + + auto sub_it=sub.begin(); + auto super_it=super.begin(); + auto sub_idx=0u, super_idx = 0u; + + while(sub_it!=sub.end() && super_it!=super.end()) { + if(*sub_it==*super_it) { + out[sub_idx] = super_idx; + ++sub_it; ++sub_idx; } + ++super_it; ++super_idx; + } - EXPECTS(sub_idx==sub.size()); + EXPECTS(sub_idx==sub.size()); - return out; - } + return out; +} } // namespace algorithms } // namespace mc diff --git a/src/cell_tree.hpp b/src/cell_tree.hpp index 1e04fc2c3bcc4e0fe887c7048997afd272fa4bff..4d9f70b1a934ab3b5bdb840d6bb0bd1b8419dfec 100644 --- a/src/cell_tree.hpp +++ b/src/cell_tree.hpp @@ -32,7 +32,7 @@ class cell_tree { public : // use a signed 16-bit integer for storage of indexes, which is reasonable given // that typical cells have at most 1000-2000 segments - using int_type = int16_t; + using int_type = int; using index_type = memory::HostVector<int_type>; using view_type = index_type::view_type; using const_view_type = index_type::const_view_type; @@ -298,4 +298,3 @@ private : } // namespace mc } // namespace nest - diff --git a/src/swcio.cpp b/src/swcio.cpp index 792d0fc7bdb3b4898b575dae552d0347efd85e9c..203f69eb1bdca8905145edb4edbe86d564682e9b 100644 --- a/src/swcio.cpp +++ b/src/swcio.cpp @@ -1,4 +1,5 @@ #include <algorithm> +#include <functional> #include <iomanip> #include <map> #include <sstream> @@ -230,88 +231,60 @@ swc_record_range_clean::swc_record_range_clean(std::istream &is) } } -// -// Convenience functions for returning the radii and the coordinates of a series -// of swc records -// -static std::vector<swc_record::coord_type> -swc_radii(const std::vector<swc_record> &records) -{ - std::vector<swc_record::coord_type> radii; - for (const auto &r : records) { - radii.push_back(r.radius()); - } - - return radii; -} - -static std::vector<nest::mc::point<swc_record::coord_type> > -swc_points(const std::vector<swc_record> &records) -{ - std::vector<nest::mc::point<swc_record::coord_type> > points; - for (const auto &r : records) { - points.push_back(r.coord()); - } - - return points; -} - -static void make_cable(cell &cell, - const std::vector<swc_record::id_type> &branch_index, - const std::vector<swc_record> &branch_run) -{ - auto new_parent = branch_index[branch_run.back().id()] - 1; - cell.add_cable(new_parent, nest::mc::segmentKind::dendrite, - swc_radii(branch_run), swc_points(branch_run)); -} - cell swc_read_cell(std::istream &is) { + using namespace nest::mc; + cell newcell; - std::vector<swc_record::id_type> parent_list; + std::vector<swc_record::id_type> parent_index; std::vector<swc_record> swc_records; for (const auto &r : swc_get_records<swc_io_clean>(is)) { swc_records.push_back(r); - parent_list.push_back(r.parent()); + parent_index.push_back(r.parent()); } - // The parent of soma must be 0 - if (!parent_list.empty()) { - parent_list[0] = 0; + if (parent_index.empty()) { + return newcell; } - auto branch_index = nest::mc::algorithms::branches(parent_list); - std::vector<swc_record> branch_run; - - branch_run.reserve(parent_list.size()); - auto last_branch_point = branch_index[0]; - for (auto i = 0u; i < swc_records.size(); ++i) { - if (branch_index[i] != last_branch_point) { - // New branch encountered; add to cell the current one - const auto &p = branch_run.back(); - if (p.parent() == -1) { - // This is a soma - newcell.add_soma(p.radius(), p.coord()); - last_branch_point = i; - } else { - last_branch_point = i - 1; - make_cable(newcell, branch_index, branch_run); - } - - // Reset the branch run - branch_run.clear(); - if (p.parent() != -1) { - // Add parent of the current cell to the branch, - // if not branching from soma - branch_run.push_back(swc_records[parent_list[i]]); - } + // The parent of soma must be 0, while in SWC files is -1 + parent_index[0] = 0; + auto branch_index = algorithms::branches(parent_index); + auto new_parent_index = algorithms::make_parent_index(parent_index, + branch_index); + + // sanity check + EXPECTS(new_parent_index.size() == branch_index.size() - 1); + + // Add the soma first; then the segments + newcell.add_soma(swc_records[0].radius(), swc_records[0].coord()); + for (std::size_t i = 1; i < new_parent_index.size(); ++i) { + auto b_start = std::next(swc_records.begin(), branch_index[i]); + auto b_end = std::next(swc_records.begin(), branch_index[i+1]); + + std::vector<swc_record::coord_type> radii; + std::vector<nest::mc::point<swc_record::coord_type>> points; + if (new_parent_index[i] != 0) { + // include the parent of current record if not branching from soma + auto p = parent_index[branch_index[i]]; + radii.push_back(swc_records[p].radius()); + points.push_back(swc_records[p].coord()); } - branch_run.push_back(swc_records[i]); - } + // extract the radii and the points + std::for_each(b_start, b_end, + [&radii](const swc_record& r) { + radii.push_back(r.radius()); + }); + + std::for_each(b_start, b_end, + [&points](const swc_record& r) { + points.push_back(r.coord()); + }); - if (!branch_run.empty()) { - make_cable(newcell, branch_index, branch_run); + // add the new cable + newcell.add_cable(new_parent_index[i], + nest::mc::segmentKind::dendrite, radii, points); } return newcell; diff --git a/src/tree.hpp b/src/tree.hpp index 3d76a9dc1e8171f62026320a98f71078e91f1dc5..98ff9aa43656f9d68b879297a6ffd5399348b6f0 100644 --- a/src/tree.hpp +++ b/src/tree.hpp @@ -17,7 +17,7 @@ class tree { public : - using int_type = int16_t; + using int_type = int; using index_type = memory::HostVector<int_type>; using view_type = index_type::view_type; @@ -62,92 +62,22 @@ class tree { ); } - // n = number of compartment in cell - auto n = parent_index.size(); + auto new_parent_index = algorithms::make_parent_index( + parent_index, algorithms::branches(parent_index)); - // On completion of this loop child_count[i] is the number of children - // of compartment i compensate count for compartment 0, which has itself - // as its own parent - index_type child_count(n, 0); - child_count[0] = -1; - for(auto i : parent_index) { - ++child_count[i]; - } + init(new_parent_index.size()); + parents_(memory::all) = new_parent_index; + parents_[0] = -1; - // Find the number of branches by summing the number of children of all - // compartments with more than 1 child. - auto nbranches = 1 + child_count[0]; // children of the root node are counted differently - for(auto i : range(1,n)) { - if(child_count[i]>1) { - nbranches += child_count[i]; - } - } + child_index_(memory::all) = + algorithms::make_index(algorithms::child_count(parents_)); - // allocate memory for storing the tree - init(nbranches); - - // index of the branch for each compartment - std::vector<I> branch_index(n); - - // Mark the parent of the root node as -1. - // This simplifies the implementation of some algorithms on the tree - parents_[0]=-1; - - // bcount records how many branches have been created in the loop - auto bcount=0; - - for(auto i : range(1,n)) { - // the branch index of the parent of compartment i - auto parent_node = parent_index[i]; - // index of the parent of compartment i - auto parent_branch = branch_index[parent_node]; - - // if this compartments's parent - // - has more than one child - // - or is the root branch - // this is the first compartment in a branch, so mark it as such - if(child_count[parent_node]>1 || parent_node==0) { - bcount++; - branch_index[i] = bcount; - parents_[bcount] = parent_branch; - } - // not the first compartment in a branch, - // so inherit the parent's branch number - else { - branch_index[i] = parent_branch; - } - } - - child_index_(memory::all) = 0; - // the root node has to be handled separately: all of its children must - // be counted - child_index_[1] = child_count[0]; - for(auto i : range(1, n)) { - if(child_count[i]>1) { - child_index_[branch_index[i]+1] = child_count[i]; - } - } - std::partial_sum(child_index_.begin(), child_index_.end(), child_index_.begin()); - - // Fill in the list of children of each branch. - // Requires some additional book keeping to keep track of how many - // children have already been filled in for each branch. - for(auto i : range(1, nbranches)) { - // parents_[i] is the parent of branch i, for which i must be added - // as a child, and child_index_[p] is the index into which the next - // child of p is to be stored + std::vector<int> pos(parents_.size(), 0); + for (auto i = 1u; i < parents_.size(); ++i) { auto p = parents_[i]; - children_[child_index_[p]] = i; - ++child_index_[p]; - } - - // The child index has already been calculated as a side-effect of the - // loop above, but is shifted one value to the left, so perform a - // rotation to the right. - for(auto i=nbranches-1; i>=00; --i) { - child_index_[i+1] = child_index_[i]; + children_[child_index_[p] + pos[p]] = i; + ++pos[p]; } - child_index_[0] = 0; } size_t num_children() const { diff --git a/tests/test_algorithms.cpp b/tests/test_algorithms.cpp index 584bdf8ad10db170b44604d3db8f8047d3c29860..ec25bd42c0941b8324fbea3288ed5b4f1d1f7a59 100644 --- a/tests/test_algorithms.cpp +++ b/tests/test_algorithms.cpp @@ -369,11 +369,11 @@ TEST(algorithms, branches) { // - // 0 - // /|\. - // 1 4 6 - // / | \. - // 2 5 7 + // 0 0 + // /|\. /|\. + // 1 4 6 1 2 3 + // / | \. => / \. + // 2 5 7 4 5 // / \. // 3 8 // / \. @@ -386,58 +386,100 @@ TEST(algorithms, branches) std::vector<int> parent_index = { 0, 0, 1, 2, 0, 4, 0, 6, 7, 8, 9, 8, 11, 12 }; std::vector<int> expected_branches = - { 0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5 }; + { 0, 1, 4, 6, 9, 11, 14 }; + std::vector<int> expected_parent_index = + { 0, 0, 0, 0, 3, 3 }; auto actual_branches = algorithms::branches(parent_index); EXPECT_EQ(expected_branches, actual_branches); + + auto actual_parent_index = + algorithms::make_parent_index(parent_index, actual_branches); + EXPECT_EQ(expected_parent_index, actual_parent_index); + + // Check find_branch + EXPECT_EQ(0, algorithms::find_branch(actual_branches, 0)); + EXPECT_EQ(1, algorithms::find_branch(actual_branches, 1)); + EXPECT_EQ(1, algorithms::find_branch(actual_branches, 2)); + EXPECT_EQ(1, algorithms::find_branch(actual_branches, 3)); + EXPECT_EQ(2, algorithms::find_branch(actual_branches, 4)); + EXPECT_EQ(2, algorithms::find_branch(actual_branches, 5)); + EXPECT_EQ(3, algorithms::find_branch(actual_branches, 6)); + EXPECT_EQ(3, algorithms::find_branch(actual_branches, 7)); + EXPECT_EQ(3, algorithms::find_branch(actual_branches, 8)); + EXPECT_EQ(4, algorithms::find_branch(actual_branches, 9)); + EXPECT_EQ(4, algorithms::find_branch(actual_branches, 10)); + EXPECT_EQ(5, algorithms::find_branch(actual_branches, 11)); + EXPECT_EQ(5, algorithms::find_branch(actual_branches, 12)); + EXPECT_EQ(5, algorithms::find_branch(actual_branches, 13)); + EXPECT_EQ(6, algorithms::find_branch(actual_branches, 55)); + + // Check expand_branches + auto expanded = algorithms::expand_branches(actual_branches); + EXPECT_EQ(parent_index.size(), expanded.size()); + for (std::size_t i = 0; i < parent_index.size(); ++i) { + EXPECT_EQ(algorithms::find_branch(actual_branches, i), + expanded[i]); + } } { // - // 0 - // | - // 1 + // 0 0 + // | | + // 1 => 1 // | // 2 // | // 3 // - std::vector<int> parent_index = - { 0, 0, 1, 2 }; - std::vector<int> expected_branches = - { 0, 1, 1, 1 }; + std::vector<int> parent_index = { 0, 0, 1, 2 }; + std::vector<int> expected_branches = { 0, 1, 4 }; + std::vector<int> expected_parent_index = { 0, 0 }; auto actual_branches = algorithms::branches(parent_index); EXPECT_EQ(expected_branches, actual_branches); + + auto actual_parent_index = + algorithms::make_parent_index(parent_index, actual_branches); + EXPECT_EQ(expected_parent_index, actual_parent_index); } { // - // 0 - // | - // 1 - // | - // 2 + // 0 0 + // | | + // 1 => 1 + // | / \. + // 2 2 3 // / \. // 3 4 // \. // 5 // - std::vector<int> parent_index = - { 0, 0, 1, 2, 2, 4 }; - std::vector<int> expected_branches = - { 0, 1, 1, 2, 3, 3 }; + std::vector<int> parent_index = { 0, 0, 1, 2, 2, 4 }; + std::vector<int> expected_branches = { 0, 1, 3, 4, 6 }; + std::vector<int> expected_parent_index = { 0, 0, 1, 1 }; auto actual_branches = algorithms::branches(parent_index); EXPECT_EQ(expected_branches, actual_branches); + + auto actual_parent_index = + algorithms::make_parent_index(parent_index, actual_branches); + EXPECT_EQ(expected_parent_index, actual_parent_index); } { - std::vector<int> parent_index = { 0 }; - std::vector<int> expected_branches = { 0 }; + std::vector<int> parent_index = { 0 }; + std::vector<int> expected_branches = { 0, 1 }; + std::vector<int> expected_parent_index = { 0 }; auto actual_branches = algorithms::branches(parent_index); EXPECT_EQ(expected_branches, actual_branches); + + auto actual_parent_index = + algorithms::make_parent_index(parent_index, actual_branches); + EXPECT_EQ(expected_parent_index, actual_parent_index); } } diff --git a/tests/test_swcio.cpp b/tests/test_swcio.cpp index 62c52b099e6142f9bc409ca461568dbce7ffa4d7..5ba0b991a865dfb397e5c1dbbb1e8813e71b628c 100644 --- a/tests/test_swcio.cpp +++ b/tests/test_swcio.cpp @@ -521,4 +521,3 @@ TEST(swc_parser, from_file_ball_and_stick) EXPECT_TRUE(nest::mc::cell_basic_equality(local_cell, cell)); } - diff --git a/tests/test_tree.cpp b/tests/test_tree.cpp index 1e06b3a90fdf92f96e43e2c93036d37788c3efa6..ca9c6d73b1d9e4bde0bb90fc9061ff34f69f0503 100644 --- a/tests/test_tree.cpp +++ b/tests/test_tree.cpp @@ -142,20 +142,27 @@ TEST(cell_tree, from_parent_index) { { // // 0 - // / \. - // 1 2 + // /|\. + // 1 4 5 // / \. - // 3 4 - std::vector<int> parent_index = {0,0,0,1,1}; + // 2 3 + std::vector<int> parent_index = {0,0,1,1,0,0}; cell_tree tree(parent_index); - EXPECT_EQ(tree.num_segments(), 5u); + EXPECT_EQ(tree.num_segments(), 6u); - EXPECT_EQ(tree.num_children(0), 2u); + EXPECT_EQ(tree.num_children(0), 3u); EXPECT_EQ(tree.num_children(1), 2u); EXPECT_EQ(tree.num_children(2), 0u); EXPECT_EQ(tree.num_children(3), 0u); EXPECT_EQ(tree.num_children(4), 0u); + + // Check children + EXPECT_EQ(1, tree.children(0)[0]); + EXPECT_EQ(4, tree.children(0)[1]); + EXPECT_EQ(5, tree.children(0)[2]); + EXPECT_EQ(2, tree.children(1)[0]); + EXPECT_EQ(3, tree.children(1)[1]); } { // 0