diff --git a/arbor/include/arbor/util/variant.hpp b/arbor/include/arbor/util/variant.hpp index 5b2715b84bb97394366bfdbb37e600d5114853bc..22fec8f1234e35f08c1f0556c876e6c7246f3488 100644 --- a/arbor/include/arbor/util/variant.hpp +++ b/arbor/include/arbor/util/variant.hpp @@ -181,6 +181,14 @@ struct variant_dynamic_impl<> { } static void destroy(std::size_t i, char* data) {} + + static bool cmp_eq(std::size_t i, const char* left, const char* right) { + return i==std::size_t(-1)? true: throw bad_variant_access{}; + } + + static bool cmp_ne(std::size_t i, const char* left, const char* right) { + return i==std::size_t(-1)? false: throw bad_variant_access{}; + } }; template <typename H, typename... T> @@ -240,6 +248,18 @@ struct variant_dynamic_impl<H, T...> { variant_dynamic_impl<T...>::destroy(i-1, data); } } + + static bool cmp_eq(std::size_t i, const char* left, const char* right) { + return i==0? + *reinterpret_cast<const H*>(left)==*reinterpret_cast<const H*>(right): + variant_dynamic_impl<T...>::cmp_eq(i-1, left, right); + } + + static bool cmp_ne(std::size_t i, const char* left, const char* right) { + return i==0? + *reinterpret_cast<const H*>(left)!=*reinterpret_cast<const H*>(right): + variant_dynamic_impl<T...>::cmp_ne(i-1, left, right); + } }; template <typename... T> @@ -467,6 +487,16 @@ struct variant { template <typename X, std::size_t I = type_index<X, T...>::value> decltype(auto) get() const { return get<I>(); } + + // Comparisons. + + bool operator==(const variant& x) const { + return which_==x.which_ && (valueless_by_exception() || variant_dynamic_impl<T...>::cmp_eq(which_, data, x.data)); + } + + bool operator!=(const variant& x) const { + return which_!=x.which_ || (!valueless_by_exception() && variant_dynamic_impl<T...>::cmp_ne(which_, data, x.data)); + } }; template <std::size_t I, std::size_t N> @@ -584,16 +614,16 @@ namespace std { // Unambitious hash: template <typename... T> struct hash<::arb::util::variant<T...>> { - std::size_t operator()(const ::arb::util::variant<T...>& v) { + std::size_t operator()(const ::arb::util::variant<T...>& v) const { return v.index() ^ - visit([](const auto& a) { return std::hash<std::remove_cv_t<decltype(a)>>{}(a); }, v); + ::arb::util::visit([](const auto& a) { return std::hash<std::remove_cv_t<std::remove_reference_t<decltype(a)>>>{}(a); }, v); } }; // Still haven't really determined if it is okay to have a variant<>, but if we do allow it... template <> struct hash<::arb::util::variant<>> { - std::size_t operator()(const ::arb::util::variant<>& v) { return 0u; }; + std::size_t operator()(const ::arb::util::variant<>& v) const { return 0u; }; }; // std::swap specialization. diff --git a/test/unit/test_variant.cpp b/test/unit/test_variant.cpp index e8ca9331371bef4f18a802ec5367d940341a5e36..50708993064682ecf91ed6800425836300bd416d 100644 --- a/test/unit/test_variant.cpp +++ b/test/unit/test_variant.cpp @@ -270,6 +270,57 @@ TEST(variant, valueless) { EXPECT_EQ(std::size_t(-1), vi.index()); } +TEST(variant, equality) { + struct X { + int i; + X(int i): i(i) {} + X& operator=(const X&) { throw "nope"; } + bool operator==(X x) const { return i==x.i; } + // Crazy != semantics on purpose: + bool operator!=(X x) const { return i==x.i+1; } + }; + + ASSERT_TRUE(X{1} == X{1}); + ASSERT_FALSE(X{1} == X{0}); + ASSERT_FALSE(X{1} == X{2}); + + ASSERT_TRUE(X{1} != X{0}); + ASSERT_FALSE(X{1} != X{1}); + ASSERT_FALSE(X{1} != X{2}); + + using vidX = variant<int, double, X>; + auto valueless = []() { + vidX v{X{0}}; + try { v = v; } catch (...) {}; + return v; + }; + + EXPECT_TRUE(valueless() == valueless()); + EXPECT_FALSE(valueless() != valueless()); + + EXPECT_TRUE(vidX{3} == vidX{3}); + EXPECT_FALSE(vidX{3.0} == vidX{3}); + EXPECT_FALSE(vidX{X{3}} == vidX{3}); + EXPECT_FALSE(valueless() == vidX{3}); + + EXPECT_TRUE(vidX{X{2}} == vidX{X{2}}); + EXPECT_FALSE(vidX{X{2}} == vidX{2}); + EXPECT_FALSE(vidX{X{2}} == vidX{2.0}); + EXPECT_FALSE(vidX{X{2}} == valueless()); + + EXPECT_FALSE(vidX{3} != vidX{3}); + EXPECT_TRUE(vidX{3.0} != vidX{3}); + EXPECT_TRUE(vidX{X{3}} != vidX{3}); + EXPECT_TRUE(valueless() != vidX{3}); + + EXPECT_TRUE(vidX{X{2}} != vidX{X{1}}); // note custom != + EXPECT_FALSE(vidX{X{2}} != vidX{X{2}}); // note custom != + EXPECT_FALSE(vidX{X{2}} != vidX{X{3}}); // note custom != + EXPECT_TRUE(vidX{X{2}} != vidX{2}); + EXPECT_TRUE(vidX{X{2}} != vidX{2.0}); + EXPECT_TRUE(vidX{X{2}} != valueless()); +} + TEST(variant, hash) { // Just ensure we find std::hash specializations. @@ -277,6 +328,7 @@ TEST(variant, hash) { EXPECT_TRUE((std::is_same<std::size_t, decltype(h0(std::declval<variant<>>()))>::value)); std::hash<variant<int, double>> h2; + (void)h2(variant<int, double>(3.1)); EXPECT_TRUE((std::is_same<std::size_t, decltype(h2(std::declval<variant<int, double>>()))>::value)); }