From 7fe3141d1a8bfa8f677899297d935cbc181a7e42 Mon Sep 17 00:00:00 2001 From: Nora Abi Akar <nora.abiakar@gmail.com> Date: Thu, 25 Jun 2020 13:06:43 +0200 Subject: [PATCH] Refactor simd API and add SVE backend (#1044) * Add new API for the SIMD library that is compatible with the ARM 'sizeless' SVE vectors. Language restrictions prevent the use of the operator overload interface used up to this point for SIMD. * Add `indirect_expressions` and `indirect_indexed_expressions` for describing memory reads/writes. `where_expressions` control masked access to simd vectors. * Implement the SVE SIMD back-end in accordance with the alternate SIMD API. * Retrieve vector width information from compiled mechanisms. * Use alternate SIMD API in modcc-generated mechanisms. * Add assertion in generated mechanism code that checks runtime vector width compatibility. Fixess #1021. --- arbor/backends/multicore/mechanism.cpp | 7 +- arbor/backends/multicore/mechanism.hpp | 4 + arbor/backends/multicore/shared_state.cpp | 38 +- arbor/include/arbor/simd/avx.hpp | 8 +- arbor/include/arbor/simd/avx512.hpp | 6 +- arbor/include/arbor/simd/implbase.hpp | 8 +- arbor/include/arbor/simd/native.hpp | 19 +- arbor/include/arbor/simd/neon.hpp | 4 +- arbor/include/arbor/simd/simd.hpp | 693 +++++++++++----- arbor/include/arbor/simd/sve.hpp | 904 +++++++++++++++++++++ mechanisms/default/kdrmt.mod | 8 +- modcc/modcc.cpp | 3 +- modcc/printer/cexpr_emit.cpp | 178 +++- modcc/printer/cexpr_emit.hpp | 26 +- modcc/printer/cprinter.cpp | 191 +++-- modcc/printer/cprinter.hpp | 12 +- modcc/printer/simd.hpp | 30 +- test/unit-modcc/test_printers.cpp | 34 +- test/unit/test_partition_by_constraint.cpp | 4 +- test/unit/test_simd.cpp | 560 ++++++++++++- 20 files changed, 2382 insertions(+), 355 deletions(-) create mode 100644 arbor/include/arbor/simd/sve.hpp diff --git a/arbor/backends/multicore/mechanism.cpp b/arbor/backends/multicore/mechanism.cpp index 87597207..a38be492 100644 --- a/arbor/backends/multicore/mechanism.cpp +++ b/arbor/backends/multicore/mechanism.cpp @@ -28,9 +28,6 @@ namespace multicore { using util::make_range; using util::value_by_key; -constexpr unsigned simd_width = S::simd_abi::native_width<fvm_value_type>::value; - - // Copy elements from source sequence into destination sequence, // and fill the remaining elements of the destination sequence // with the given fill value. @@ -154,7 +151,7 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me copy_extend(pos_data.cv, node_index_, pos_data.cv.back()); copy_extend(pos_data.weight, make_range(data_.data(), data_.data()+width_padded_), 0); - index_constraints_ = make_constraint_partition(node_index_, width_, simd_width); + index_constraints_ = make_constraint_partition(node_index_, width_, simd_width()); if (mult_in_place_) { multiplicity_ = iarray(width_padded_, pad); @@ -176,7 +173,7 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me ion_index = iarray(width_padded_, pad); copy_extend(indices, ion_index, util::back(indices)); - arb_assert(compatible_index_constraints(node_index_, ion_index, simd_width)); + arb_assert(compatible_index_constraints(node_index_, ion_index, simd_width())); } } diff --git a/arbor/backends/multicore/mechanism.hpp b/arbor/backends/multicore/mechanism.hpp index 0d135730..cf4a331c 100644 --- a/arbor/backends/multicore/mechanism.hpp +++ b/arbor/backends/multicore/mechanism.hpp @@ -137,6 +137,10 @@ protected: virtual mechanism_ion_state_table ion_state_table() { return {}; } virtual mechanism_ion_index_table ion_index_table() { return {}; } + // Simd width used in mechanism. + + virtual unsigned simd_width() const { return 1; } + // Report raw size in bytes of mechanism object. virtual std::size_t object_sizeof() const = 0; diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index 10a1295d..8749ac69 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -26,9 +26,10 @@ namespace arb { namespace multicore { -constexpr unsigned simd_width = simd::simd_abi::native_width<fvm_value_type>::value; -using simd_value_type = simd::simd<fvm_value_type, simd_width>; -using simd_index_type = simd::simd<fvm_index_type, simd_width>; +constexpr unsigned vector_length = (unsigned) simd::simd_abi::native_width<fvm_value_type>::value; +using simd_value_type = simd::simd<fvm_value_type, vector_length, simd::simd_abi::default_abi>; +using simd_index_type = simd::simd<fvm_index_type, vector_length, simd::simd_abi::default_abi>; +const int simd_width = simd::width<simd_value_type>(); // Pick alignment compatible with native SIMD width for explicitly // vectorized operations below. @@ -168,27 +169,38 @@ void shared_state::ions_init_concentration() { } void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) { + using simd::assign; + using simd::indirect; + using simd::add; + using simd::min; for (fvm_size_type i = 0; i<n_intdom; i+=simd_width) { - simd_value_type t(time.data()+i); - t = min(t+dt_step, simd_value_type(tmax)); - t.copy_to(time_to.data()+i); + simd_value_type t; + assign(t, indirect(time.data()+i, simd_width)); + t = min(add(t, dt_step), tmax); + indirect(time_to.data()+i, simd_width) = t; } } void shared_state::set_dt() { + using simd::assign; + using simd::indirect; + using simd::sub; for (fvm_size_type j = 0; j<n_intdom; j+=simd_width) { - simd_value_type t(time.data()+j); - simd_value_type t_to(time_to.data()+j); + simd_value_type t, t_to; + assign(t, indirect(time.data()+j, simd_width)); + assign(t_to, indirect(time_to.data()+j, simd_width)); - auto dt = t_to-t; - dt.copy_to(dt_intdom.data()+j); + auto dt = sub(t_to,t); + indirect(dt_intdom.data()+j, simd_width) = dt; } for (fvm_size_type i = 0; i<n_cv; i+=simd_width) { - simd_index_type intdom_idx(cv_to_intdom.data()+i); + simd_index_type intdom_idx; + assign(intdom_idx, indirect(cv_to_intdom.data()+i, simd_width)); - simd_value_type dt(simd::indirect(dt_intdom.data(), intdom_idx)); - dt.copy_to(dt_cv.data()+i); + simd_value_type dt; + assign(dt, indirect(dt_intdom.data(), intdom_idx, simd_width)); + indirect(dt_cv.data()+i, simd_width) = dt; } } diff --git a/arbor/include/arbor/simd/avx.hpp b/arbor/include/arbor/simd/avx.hpp index 31fcd513..1dd39c9c 100644 --- a/arbor/include/arbor/simd/avx.hpp +++ b/arbor/include/arbor/simd/avx.hpp @@ -74,7 +74,7 @@ struct avx_int4: implbase<avx_int4> { return _mm_cvtsi128_si32(a); } - static __m128i negate(const __m128i& a) { + static __m128i neg(const __m128i& a) { __m128i zero = _mm_setzero_si128(); return _mm_sub_epi32(zero, a); } @@ -163,7 +163,7 @@ struct avx_int4: implbase<avx_int4> { // bottom 4 bytes. __m128i s = _mm_setr_epi32(0x0c080400ul,0,0,0); - __m128i p = _mm_shuffle_epi8(negate(m), s); + __m128i p = _mm_shuffle_epi8(neg(m), s); std::memcpy(y, &p, 4); } @@ -172,7 +172,7 @@ struct avx_int4: implbase<avx_int4> { std::memcpy(&r, w, 4); __m128i s = _mm_setr_epi32(0x80808000ul, 0x80808001ul, 0x80808002ul, 0x80808003ul); - return negate(_mm_shuffle_epi8(r, s)); + return neg(_mm_shuffle_epi8(r, s)); } static __m128i max(const __m128i& a, const __m128i& b) { @@ -245,7 +245,7 @@ struct avx_double4: implbase<avx_double4> { return _mm_cvtsd_f64(_mm256_castpd256_pd128(a)); } - static __m256d negate(const __m256d& a) { + static __m256d neg(const __m256d& a) { return _mm256_sub_pd(zero(), a); } diff --git a/arbor/include/arbor/simd/avx512.hpp b/arbor/include/arbor/simd/avx512.hpp index 4f3a524e..0416d8a0 100644 --- a/arbor/include/arbor/simd/avx512.hpp +++ b/arbor/include/arbor/simd/avx512.hpp @@ -98,7 +98,7 @@ struct avx512_mask8: implbase<avx512_mask8> { // max(a, b) a | b // min(a, b) a & b - static __mmask8 negate(const __mmask8& a) { + static __mmask8 neg(const __mmask8& a) { return a; } @@ -249,7 +249,7 @@ struct avx512_int8: implbase<avx512_int8> { return _mm_cvtsi128_si32(_mm512_castsi512_si128(a)); } - static __m512i negate(const __m512i& a) { + static __m512i neg(const __m512i& a) { return sub(_mm512_setzero_epi32(), a); } @@ -445,7 +445,7 @@ struct avx512_double8: implbase<avx512_double8> { return _mm_cvtsd_f64(_mm512_castpd512_pd128(a)); } - static __m512d negate(const __m512d& a) { + static __m512d neg(const __m512d& a) { return _mm512_sub_pd(_mm512_setzero_pd(), a); } diff --git a/arbor/include/arbor/simd/implbase.hpp b/arbor/include/arbor/simd/implbase.hpp index ff79e777..544846cd 100644 --- a/arbor/include/arbor/simd/implbase.hpp +++ b/arbor/include/arbor/simd/implbase.hpp @@ -14,9 +14,9 @@ // // Function | Default implemention by // ---------------------------------- -// min | negate, cmp_gt, ifelse -// max | negate, cmp_gt, ifelse -// abs | negate, max +// min | neg, cmp_gt, ifelse +// max | neg, cmp_gt, ifelse +// abs | neg, max // sin | lane-wise std::sin // cos | lane-wise std::cos // exp | lane-wise std::exp @@ -181,7 +181,7 @@ struct implbase { return I::copy_from(a); } - static vector_type negate(const vector_type& u) { + static vector_type neg(const vector_type& u) { store a, r; I::copy_to(u, a); diff --git a/arbor/include/arbor/simd/native.hpp b/arbor/include/arbor/simd/native.hpp index 8164ceaa..1d945574 100644 --- a/arbor/include/arbor/simd/native.hpp +++ b/arbor/include/arbor/simd/native.hpp @@ -65,7 +65,13 @@ ARB_DEF_NATIVE_SIMD_(double, 8, avx512) #endif -#if defined(__ARM_NEON) +#if defined(__ARM_FEATURE_SVE) + +#include "sve.hpp" +ARB_DEF_NATIVE_SIMD_(int, 0, sve) +ARB_DEF_NATIVE_SIMD_(double, 0, sve) + +#elif defined(__ARM_NEON) #include <arbor/simd/neon.hpp> ARB_DEF_NATIVE_SIMD_(int, 2, neon) @@ -73,7 +79,6 @@ ARB_DEF_NATIVE_SIMD_(double, 2, neon) #endif - namespace arb { namespace simd { namespace simd_abi { @@ -87,15 +92,15 @@ struct native_width; template <typename Value, int k> struct native_width { - static constexpr int value = - std::is_same<void, typename native<Value, k>::type>::value? - native_width<Value, k/2>::value: - k; + static constexpr int value = + std::is_same<void, typename native<Value, k>::type>::value? + native_width<Value, k/2>::value: + k; }; template <typename Value> struct native_width<Value, 1> { - static constexpr int value = 1; + static constexpr int value = std::is_same<void, typename native<Value, 0>::type>::value; }; } // namespace simd_abi diff --git a/arbor/include/arbor/simd/neon.hpp b/arbor/include/arbor/simd/neon.hpp index f2710461..32d07012 100644 --- a/arbor/include/arbor/simd/neon.hpp +++ b/arbor/include/arbor/simd/neon.hpp @@ -68,7 +68,7 @@ struct neon_int2 : implbase<neon_int2> { return a; } - static int32x2_t negate(const int32x2_t& a) { return vneg_s32(a); } + static int32x2_t neg(const int32x2_t& a) { return vneg_s32(a); } static int32x2_t add(const int32x2_t& a, const int32x2_t& b) { return vadd_s32(a, b); @@ -222,7 +222,7 @@ struct neon_double2 : implbase<neon_double2> { return a; } - static float64x2_t negate(const float64x2_t& a) { return vnegq_f64(a); } + static float64x2_t neg(const float64x2_t& a) { return vnegq_f64(a); } static float64x2_t add(const float64x2_t& a, const float64x2_t& b) { return vaddq_f64(a, b); diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index 3e0fb573..58cb0540 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -7,6 +7,7 @@ #include <arbor/simd/implbase.hpp> #include <arbor/simd/generic.hpp> #include <arbor/simd/native.hpp> +#include <arbor/util/pp_util.hpp> namespace arb { namespace simd { @@ -17,41 +18,350 @@ namespace detail { template <typename Impl> struct simd_mask_impl; + + template <typename To> + struct simd_cast_impl; + + template <typename I, typename V> + class indirect_indexed_expression; + + template <typename V> + class indirect_expression; + + template <typename T, typename M> + class where_expression; + + template <typename T, typename M> + class const_where_expression; +} + +// Top level functions for second API +using detail::simd_impl; +using detail::simd_mask_impl; + +template <typename Impl, typename V> +void assign(simd_impl<Impl>& a, const detail::indirect_expression<V>& b) { + a.copy_from(b); +} + +template <typename Impl, typename ImplIndex, typename V> +void assign(simd_impl<Impl>& a, const detail::indirect_indexed_expression<ImplIndex, V>& b) { + a.copy_from(b); +} + +template <typename Impl> +typename simd_impl<Impl>::scalar_type sum(const simd_impl<Impl>& a) { + return a.sum(); +}; + +#define ARB_UNARY_ARITHMETIC_(name)\ +template <typename Impl>\ +simd_impl<Impl> name(const simd_impl<Impl>& a) {\ + return simd_impl<Impl>::wrap(Impl::name(a.value_));\ +}; + +#define ARB_BINARY_ARITHMETIC_(name)\ +template <typename Impl>\ +simd_impl<Impl> name(const simd_impl<Impl>& a, simd_impl<Impl> b) {\ + return simd_impl<Impl>::wrap(Impl::name(a.value_, b.value_));\ +};\ +template <typename Impl>\ +simd_impl<Impl> name(const simd_impl<Impl>& a, typename simd_impl<Impl>::scalar_type b) {\ + return simd_impl<Impl>::wrap(Impl::name(a.value_, Impl::broadcast(b)));\ +};\ +template <typename Impl>\ +simd_impl<Impl> name(const typename simd_impl<Impl>::scalar_type a, simd_impl<Impl> b) {\ + return simd_impl<Impl>::wrap(Impl::name(Impl::broadcast(a), b.value_));\ +}; + +#define ARB_BINARY_COMPARISON_(name)\ +template <typename Impl>\ +typename simd_impl<Impl>::simd_mask name(const simd_impl<Impl>& a, simd_impl<Impl> b) {\ + return simd_impl<Impl>::mask(Impl::name(a.value_, b.value_));\ +};\ +template <typename Impl>\ +typename simd_impl<Impl>::simd_mask name(const simd_impl<Impl>& a, typename simd_impl<Impl>::scalar_type b) {\ + return simd_impl<Impl>::mask(Impl::name(a.value_, Impl::broadcast(b)));\ +};\ +template <typename Impl>\ +typename simd_impl<Impl>::simd_mask name(const typename simd_impl<Impl>::scalar_type a, simd_impl<Impl> b) {\ + return simd_impl<Impl>::mask(Impl::name(Impl::broadcast(a), b.value_));\ +}; + +ARB_PP_FOREACH(ARB_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) +ARB_PP_FOREACH(ARB_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_leq, cmp_lt, cmp_geq, cmp_gt) +ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr) + +#undef ARB_BINARY_ARITHMETIC_ +#undef ARB_BINARY_COMPARISON__ +#undef ARB_UNARY_ARITHMETIC_ + +template <typename T> +simd_mask_impl<T> logical_and(const simd_mask_impl<T>& a, simd_mask_impl<T> b) { + return a && b; +} + +template <typename T> +simd_mask_impl<T> logical_or(const simd_mask_impl<T>& a, simd_mask_impl<T> b) { + return a || b; +} + +template <typename T> +simd_mask_impl<T> logical_not(const simd_mask_impl<T>& a) { + return !a; +} + +template <typename T> +simd_impl<T> fma(const simd_impl<T> a, simd_impl<T> b, simd_impl<T> c) { + return simd_impl<T>::wrap(T::fma(a.value_, b.value_, c.value_)); } namespace detail { - template <typename Impl, typename V> - struct indirect_expression { + /// Indirect Expressions + template <typename V> + class indirect_expression { + public: + indirect_expression(V* p, unsigned width): p(p), width(width) {} + + indirect_expression& operator=(V s) { + for (unsigned i = 0; i < width; ++i) { + p[i] = s; + } + return *this; + } + + template <typename Other> + indirect_expression& operator=(const Other& s) { + indirect_copy_to(s, p, width); + return *this; + } + + template <typename Impl, typename ImplMask> + indirect_expression& operator=(const const_where_expression<Impl, ImplMask>& s) { + indirect_copy_to(s.data_, s.mask_, p, width); + return *this; + } + + template <typename Impl, typename ImplMask> + indirect_expression& operator=(const where_expression<Impl, ImplMask>& s) { + indirect_copy_to(s.data_, s.mask_, p, width); + return *this; + } + + template <typename Impl> friend struct simd_impl; + template <typename Impl> friend struct simd_mask_impl; + template <typename To> friend struct simd_cast_impl; + template <typename T, typename M> friend class where_expression; + + private: V* p; - typename simd_traits<Impl>::vector_type index; - index_constraint constraint; + unsigned width; + }; + + template <typename Impl, typename V> + static void indirect_copy_to(const simd_mask_impl<Impl>& s, V* p, unsigned width) { + Impl::mask_copy_to(s.value_, p); + } - indirect_expression() = default; - indirect_expression(V* p, const simd_impl<Impl>& index_simd, index_constraint constraint): - p(p), index(index_simd.value_), constraint(constraint) + template <typename Impl, typename V> + static void indirect_copy_to(const simd_impl<Impl>& s, V* p, unsigned width) { + Impl::copy_to(s.value_, p); + } + + template <typename Impl, typename ImplMask, typename V> + static void indirect_copy_to(const simd_impl<Impl>& data, const simd_mask_impl<ImplMask>& mask, V* p, unsigned width) { + Impl::copy_to_masked(data.value_, p, mask.value_); + } + + /// Indirect Indexed Expressions + template <typename ImplIndex, typename V> + class indirect_indexed_expression { + public: + indirect_indexed_expression(V* p, const ImplIndex& index_simd, unsigned width, index_constraint constraint): + p(p), index(index_simd), width(width), constraint(constraint) {} - // Simple assignment included for consistency with compound assignment interface. + indirect_indexed_expression& operator=(V s) { + typename simd_traits<ImplIndex>::scalar_type idx[width]; + ImplIndex::copy_to(index.value_, idx); + for (unsigned i = 0; i < width; ++i) { + p[idx[i]] = s; + } + return *this; + } template <typename Other> - indirect_expression& operator=(const simd_impl<Other>& s) { - s.copy_to(*this); + indirect_indexed_expression& operator=(const Other& s) { + indirect_indexed_copy_to(s, p, index, width); return *this; } - // Compound assignment (currently only addition and subtraction!): + template <typename Impl, typename ImplMask> + indirect_indexed_expression& operator=(const const_where_expression<Impl, ImplMask>& s) { + indirect_indexed_copy_to(s.data_, s.mask_, p, index, width); + return *this; + } + + template <typename Impl, typename ImplMask> + indirect_indexed_expression& operator=(const where_expression<Impl, ImplMask>& s) { + indirect_indexed_copy_to(s.data_, s.mask_, p, index, width); + return *this; + } + + template <typename Other> + indirect_indexed_expression& operator+=(const Other& s) { + compound_indexed_add(s, p, index, width, constraint); + return *this; + } template <typename Other> - indirect_expression& operator+=(const simd_impl<Other>& s) { - simd_impl<Other>::compound_indexed_add(tag<Impl>{}, s.value_, p, index, constraint); + indirect_indexed_expression& operator-=(const Other& s) { + compound_indexed_add(neg(s), p, index, width, constraint); return *this; } + template <typename Impl> friend struct simd_impl; + template <typename To> friend struct simd_cast_impl; + template <typename T, typename M> friend class where_expression; + + private: + V* p; + const ImplIndex& index; + unsigned width; + index_constraint constraint; + }; + + template <typename Impl, typename ImplIndex, typename V> + static void indirect_indexed_copy_to(const simd_impl<Impl>& s, V* p, const simd_impl<ImplIndex>& index, unsigned width) { + Impl::scatter(tag<ImplIndex>{}, s.value_, p, index.value_); + } + + template <typename Impl, typename ImplIndex, typename ImplMask, typename V> + static void indirect_indexed_copy_to(const simd_impl<Impl>& data, const simd_mask_impl<ImplMask>& mask, V* p, const simd_impl<ImplIndex>& index, unsigned width) { + Impl::scatter(tag<ImplIndex>{}, data.value_, p, index.value_, mask.value_); + } + + template <typename ImplIndex, typename Impl, typename V> + static void compound_indexed_add( + const simd_impl<Impl>& s, + V* p, + const simd_impl<ImplIndex>& index, + unsigned width, + index_constraint constraint) + { + switch (constraint) { + case index_constraint::none: + { + typename ImplIndex::scalar_type o[width]; + ImplIndex::copy_to(index.value_, o); + + V a[width]; + Impl::copy_to(s.value_, a); + + V temp = 0; + for (unsigned i = 0; i<width-1; ++i) { + temp += a[i]; + if (o[i] != o[i+1]) { + p[o[i]] += temp; + temp = 0; + } + } + temp += a[width-1]; + p[o[width-1]] += temp; + } + break; + case index_constraint::independent: + { + auto v = Impl::add(Impl::gather(tag<ImplIndex>{}, p, index.value_), s.value_); + Impl::scatter(tag<ImplIndex>{}, v, p, index.value_); + } + break; + case index_constraint::contiguous: + { + p += ImplIndex::element0(index.value_); + auto v = Impl::add(Impl::copy_from(p), s.value_); + Impl::copy_to(v, p); + } + break; + case index_constraint::constant: + p += ImplIndex::element0(index.value_); + *p += Impl::reduce_add(s.value_); + break; + } + } + + /// Where Expressions + template <typename Impl, typename ImplMask> + class where_expression { + public: + where_expression(const ImplMask& m, Impl& s): + mask_(m), data_(s) {} + template <typename Other> - indirect_expression& operator-=(const simd_impl<Other>& s) { - simd_impl<Other>::compound_indexed_add(tag<Impl>{}, (-s).value_, p, index, constraint); + where_expression& operator=(const Other& v) { + where_copy_to(mask_, data_, v); + return *this; + } + + template <typename V> + where_expression& operator=(const indirect_expression<V>& v) { + where_copy_to(mask_, data_, v.p, v.width); + return *this; + } + + template <typename ImplIndex, typename V> + where_expression& operator=(const indirect_indexed_expression<ImplIndex, V>& v) { + where_copy_to(mask_, data_, v.p, v.index, v.width); return *this; } + + template <typename T> friend struct simd_impl; + template <typename To> friend struct simd_cast_impl; + template <typename V> friend class indirect_expression; + template <typename I, typename V> friend class indirect_indexed_expression; + + private: + const ImplMask& mask_; + Impl& data_; + }; + + template <typename Impl, typename ImplMask, typename V> + static void where_copy_to(const simd_mask_impl<ImplMask>& mask, simd_impl<Impl>& f, const V& t) { + f.value_ = Impl::ifelse(mask.value_, Impl::broadcast(t), f.value_); + } + + template <typename Impl, typename ImplMask> + static void where_copy_to(const simd_mask_impl<ImplMask>& mask, simd_impl<Impl>& f, const simd_impl<Impl>& t) { + f.value_ = Impl::ifelse(mask.value_, t.value_, f.value_); + } + + template <typename Impl, typename ImplMask, typename V> + static void where_copy_to(const simd_mask_impl<ImplMask>& mask, simd_impl<Impl>& f, const V* t, unsigned width) { + f.value_ = Impl::ifelse(mask.value_, Impl::copy_from_masked(t, mask.value_), f.value_); + } + + template <typename Impl, typename ImplIndex, typename ImplMask, typename V> + static void where_copy_to(const simd_mask_impl<ImplMask>& mask, simd_impl<Impl>& f, const V* p, const simd_impl<ImplIndex>& index, unsigned width) { + simd_impl<Impl> temp = Impl::broadcast(0); + temp.value_ = Impl::gather(tag<ImplIndex>{}, temp.value_, p, index.value_, mask.value_); + f.value_ = Impl::ifelse(mask.value_, temp.value_, f.value_); + } + + /// Const Where Expressions + template <typename Impl, typename ImplMask> + class const_where_expression { + public: + const_where_expression(const ImplMask& m, const Impl& s): + mask_(m), data_(s) {} + + template <typename T> friend struct simd_impl; + template <typename To> friend struct simd_cast_impl; + template <typename V> friend class indirect_expression; + template <typename I, typename V> friend class indirect_indexed_expression; + + private: + const ImplMask& mask_; + const Impl& data_; }; template <typename Impl> @@ -68,6 +378,7 @@ namespace detail { using scalar_type = typename simd_traits<Impl>::scalar_type; using simd_mask = simd_mask_impl<typename simd_traits<Impl>::mask_impl>; + using simd_base = Impl; protected: using vector_type = typename simd_traits<Impl>::vector_type; @@ -80,7 +391,10 @@ namespace detail { friend struct simd_impl; template <typename Other, typename V> - friend struct indirect_expression; + friend class indirect_indexed_expression; + + template <typename V> + friend class indirect_expression; simd_impl() = default; @@ -117,12 +431,12 @@ namespace detail { // Construct from indirect expression (gather). template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - explicit simd_impl(indirect_expression<IndexImpl, scalar_type> pi) { + explicit simd_impl(indirect_indexed_expression<IndexImpl, scalar_type> pi) { copy_from(pi); } template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - explicit simd_impl(indirect_expression<IndexImpl, const scalar_type> pi) { + explicit simd_impl(indirect_indexed_expression<IndexImpl, const scalar_type> pi) { copy_from(pi); } @@ -149,8 +463,9 @@ namespace detail { Impl::copy_to(value_, p); } - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { + template <typename Index, typename = std::enable_if_t<width==simd_traits<typename Index::simd_base>::width>> + void copy_to(indirect_indexed_expression<Index, scalar_type> pi) const { + using IndexImpl = typename Index::simd_base; Impl::scatter(tag<IndexImpl>{}, value_, pi.p, pi.index); } @@ -158,24 +473,26 @@ namespace detail { value_ = Impl::copy_from(p); } - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_from(indirect_expression<IndexImpl, scalar_type> pi) { + + template <typename Index, typename = std::enable_if_t<width==simd_traits<typename Index::simd_base>::width>> + void copy_from(indirect_indexed_expression<Index, scalar_type> pi) { + using IndexImpl = typename Index::simd_base; switch (pi.constraint) { case index_constraint::none: - value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index.value_); break; case index_constraint::independent: - value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index.value_); break; case index_constraint::contiguous: { - scalar_type* p = IndexImpl::element0(pi.index) + pi.p; + scalar_type* p = IndexImpl::element0(pi.index.value_) + pi.p; value_ = Impl::copy_from(p); } break; case index_constraint::constant: { - scalar_type* p = IndexImpl::element0(pi.index) + pi.p; + scalar_type* p = IndexImpl::element0(pi.index.value_) + pi.p; scalar_type l = (*p); value_ = Impl::broadcast(l); } @@ -183,80 +500,54 @@ namespace detail { } } - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_from(indirect_expression<IndexImpl, const scalar_type> pi) { + template <typename Index, typename = std::enable_if_t<width==simd_traits<typename Index::simd_base>::width>> + void copy_from(indirect_indexed_expression<Index, const scalar_type> pi) { + using IndexImpl = typename Index::simd_base; switch (pi.constraint) { case index_constraint::none: - value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index.value_); break; case index_constraint::independent: - value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index.value_); break; case index_constraint::contiguous: { - const scalar_type* p = IndexImpl::element0(pi.index) + pi.p; + const scalar_type* p = IndexImpl::element0(pi.index.value_) + pi.p; value_ = Impl::copy_from(p); } break; case index_constraint::constant: { - const scalar_type *p = IndexImpl::element0(pi.index) + pi.p; + const scalar_type *p = IndexImpl::element0(pi.index.value_) + pi.p; scalar_type l = (*p); value_ = Impl::broadcast(l); } break; } + } + void copy_from(indirect_expression<scalar_type> pi) { + value_ = Impl::copy_from(pi.p); } - template <typename ImplIndex> - static void compound_indexed_add(tag<ImplIndex> tag, const vector_type& s, scalar_type* p, const typename ImplIndex::vector_type& index, index_constraint constraint) { - switch (constraint) { - case index_constraint::none: - { - typename ImplIndex::scalar_type o[width]; - ImplIndex::copy_to(index, o); - - scalar_type a[width]; - Impl::copy_to(s, a); - - scalar_type temp = 0; - for (unsigned i = 0; i<width-1; ++i) { - temp += a[i]; - if (o[i] != o[i+1]) { - p[o[i]] += temp; - temp = 0; - } - } - temp += a[width-1]; - p[o[width-1]] += temp; - } - break; - case index_constraint::independent: - { - vector_type v = Impl::add(Impl::gather(tag, p, index), s); - Impl::scatter(tag, v, p, index); - } - break; - case index_constraint::contiguous: - { - p += ImplIndex::element0(index); - vector_type v = Impl::add(Impl::copy_from(p), s); - Impl::copy_to(v, p); - } - break; - case index_constraint::constant: - p += ImplIndex::element0(index); - *p += Impl::reduce_add(s); - break; - } + void copy_from(indirect_expression<const scalar_type> pi) { + value_ = Impl::copy_from(pi.p); + } + + template <typename T, typename M> + void copy_from(const_where_expression<T, M> w) { + value_ = Impl::ifelse(w.mask_.value_, w.data_.value_, value_); } + template <typename T, typename M> + void copy_from(where_expression<T, M> w) { + value_ = Impl::ifelse(w.mask_.value_, w.data_.value_, value_); + } // Arithmetic operations: +, -, *, /, fma. simd_impl operator-() const { - return wrap(Impl::negate(value_)); + return simd_impl::wrap(Impl::neg(value_)); } friend simd_impl operator+(const simd_impl& a, simd_impl b) { @@ -364,115 +655,67 @@ namespace detail { return Impl::reduce_add(value_); } - // Masked assignment (via where expressions). - - struct where_expression { - where_expression(const where_expression&) = default; - where_expression& operator=(const where_expression&) = delete; - - where_expression(const simd_mask& m, simd_impl& s): - mask_(m), data_(s) {} - - where_expression& operator=(scalar_type v) { - data_.value_ = Impl::ifelse(mask_.value_, simd_impl(v).value_, data_.value_); - return *this; - } - - where_expression& operator=(const simd_impl& v) { - data_.value_ = Impl::ifelse(mask_.value_, v.value_, data_.value_); - return *this; - } - - void copy_to(scalar_type* p) const { - Impl::copy_to_masked(data_.value_, p, mask_.value_); - } - - void copy_from(const scalar_type* p) { - data_.value_ = Impl::copy_from_masked(data_.value_, p, mask_.value_); - } - - // Gather and scatter. - - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_from(indirect_expression<IndexImpl, scalar_type> pi) { - data_.value_ = Impl::gather(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); - } - - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { - Impl::scatter(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); - } - - private: - const simd_mask& mask_; - simd_impl& data_; - }; + // Maths functions are implemented as top-level functions; declare as friends for access to `wrap` - struct const_where_expression { - const_where_expression(const const_where_expression&) = default; - const_where_expression& operator=(const const_where_expression&) = delete; + #define ARB_DECLARE_UNARY_ARITHMETIC_(name)\ + template <typename T>\ + friend simd_impl<T> arb::simd::name(const simd_impl<T>& a); - const_where_expression(const simd_mask& m, const simd_impl& s): - mask_(m), data_(s) {} + #define ARB_DECLARE_BINARY_ARITHMETIC_(name)\ + template <typename T>\ + friend simd_impl<T> arb::simd::name(const simd_impl<T>& a, simd_impl<T> b);\ + template <typename T>\ + friend simd_impl<T> arb::simd::name(const simd_impl<T>& a, typename simd_impl<T>::scalar_type b);\ + template <typename T>\ + friend simd_impl<T> arb::simd::name(const typename simd_impl<T>::scalar_type a, simd_impl<T> b); - void copy_to(scalar_type* p) const { - Impl::copy_to_masked(data_.value_, p, mask_.value_); - } - - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { - Impl::scatter(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); - } + #define ARB_DECLARE_BINARY_COMPARISON_(name)\ + template <typename T>\ + friend typename simd_impl<T>::simd_mask arb::simd::name(const simd_impl<T>& a, simd_impl<T> b);\ + template <typename T>\ + friend typename simd_impl<T>::simd_mask arb::simd::name(const simd_impl<T>& a, typename simd_impl<T>::scalar_type b);\ + template <typename T>\ + friend typename simd_impl<T>::simd_mask arb::simd::name(const typename simd_impl<T>::scalar_type a, simd_impl<T> b); - private: - const simd_mask& mask_; - const simd_impl& data_; - }; + ARB_PP_FOREACH(ARB_DECLARE_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min, cmp_eq) + ARB_PP_FOREACH(ARB_DECLARE_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_lt, cmp_leq, cmp_gt, cmp_geq) + ARB_PP_FOREACH(ARB_DECLARE_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr) + #undef ARB_DECLARE_UNARY_ARITHMETIC_ + #undef ARB_DECLARE_BINARY_ARITHMETIC_ + #undef ARB_DECLARE_BINARY_COMPARISON_ - // Maths functions are implemented as top-level functions; declare as friends - // for access to `wrap` and to enjoy ADL, allowing implicit conversion from - // scalar_type in binary operation arguments. + template <typename T> + friend simd_impl<T> arb::simd::fma(const simd_impl<T> a, simd_impl<T> b, simd_impl<T> c); - friend simd_impl abs(const simd_impl& s) { - return simd_impl::wrap(Impl::abs(s.value_)); - } + // Declare Indirect/Indirect indexed/Where Expression copy function as friends - friend simd_impl sin(const simd_impl& s) { - return simd_impl::wrap(Impl::sin(s.value_)); - } + template <typename T, typename I, typename V> + friend void compound_indexed_add(const simd_impl<I>& s, V* p, const simd_impl<T>& index, unsigned width, index_constraint constraint); - friend simd_impl cos(const simd_impl& s) { - return simd_impl::wrap(Impl::cos(s.value_)); - } + template <typename I, typename V> + friend void indirect_copy_to(const simd_impl<I>& s, V* p, unsigned width); - friend simd_impl exp(const simd_impl& s) { - return simd_impl::wrap(Impl::exp(s.value_)); - } + template <typename T, typename M, typename V> + friend void indirect_copy_to(const simd_impl<T>& data, const simd_mask_impl<M>& mask, V* p, unsigned width); - friend simd_impl log(const simd_impl& s) { - return simd_impl::wrap(Impl::log(s.value_)); - } + template <typename T, typename I, typename V> + friend void indirect_indexed_copy_to(const simd_impl<T>& s, V* p, const simd_impl<I>& index, unsigned width); - friend simd_impl expm1(const simd_impl& s) { - return simd_impl::wrap(Impl::expm1(s.value_)); - } + template <typename T, typename I, typename M, typename V> + friend void indirect_indexed_copy_to(const simd_impl<T>& data, const simd_mask_impl<M>& mask, V* p, const simd_impl<I>& index, unsigned width); - friend simd_impl exprelr(const simd_impl& s) { - return simd_impl::wrap(Impl::exprelr(s.value_)); - } + template <typename T, typename M, typename V> + friend void where_copy_to(const simd_mask_impl<M>& mask, simd_impl<T>& f, const V& t); - friend simd_impl pow(const simd_impl& s, const simd_impl& t) { - return simd_impl::wrap(Impl::pow(s.value_, t.value_)); - } + template <typename T, typename M> + friend void where_copy_to(const simd_mask_impl<M>& mask, simd_impl<T>& f, const simd_impl<T>& t); - friend simd_impl min(const simd_impl& s, const simd_impl& t) { - return simd_impl::wrap(Impl::min(s.value_, t.value_)); - } + template <typename T, typename M, typename V> + friend void where_copy_to(const simd_mask_impl<M>& mask, simd_impl<T>& f, const V* p, unsigned width); - friend simd_impl max(const simd_impl& s, const simd_impl& t) { - return simd_impl::wrap(Impl::max(s.value_, t.value_)); - } + template <typename T, typename I, typename M, typename V> + friend void where_copy_to(const simd_mask_impl<M>& mask, simd_impl<T>& f, const V* p, const simd_impl<I>& index, unsigned width); protected: vector_type value_; @@ -541,6 +784,10 @@ namespace detail { value_ = Impl::mask_copy_from(y); } + void copy_from(indirect_expression<bool> pi) { + value_ = Impl::mask_copy_from(pi.p); + } + // Array subscript operations. struct reference { @@ -605,7 +852,33 @@ namespace detail { }; template <typename To> - struct simd_cast_impl {}; + struct simd_cast_impl; + + template <typename ImplTo> + struct simd_cast_impl<simd_mask_impl<ImplTo>> { + static constexpr unsigned N = simd_traits<ImplTo>::width; + using scalar_type = typename simd_traits<ImplTo>::scalar_type; + + template <typename ImplFrom, typename = std::enable_if_t<N==simd_traits<ImplFrom>::width>> + static simd_mask_impl<ImplTo> cast(const simd_mask_impl<ImplFrom>& v) { + return simd_mask_impl<ImplTo>(v); + } + + static simd_mask_impl<ImplTo> cast(const std::array<scalar_type, N>& a) { + return simd_mask_impl<ImplTo>(a.data()); + } + + static simd_mask_impl<ImplTo> cast(scalar_type a) { + simd_mask_impl<ImplTo> r = a; + return r; + } + + static simd_mask_impl<ImplTo> cast(const indirect_expression<bool>& a) { + simd_mask_impl<ImplTo> r; + r.copy_from(a); + return r; + } + }; template <typename ImplTo> struct simd_cast_impl<simd_impl<ImplTo>> { @@ -620,6 +893,32 @@ namespace detail { static simd_impl<ImplTo> cast(const std::array<scalar_type, N>& a) { return simd_impl<ImplTo>(a.data()); } + + static simd_impl<ImplTo> cast(scalar_type a) { + simd_impl<ImplTo> r = a; + return r; + } + + template <typename V> + static simd_impl<ImplTo> cast(const indirect_expression<V>& a) { + simd_impl<ImplTo> r; + r.copy_from(a); + return r; + } + + template <typename Impl, typename V> + static simd_impl<ImplTo> cast(const indirect_indexed_expression<Impl,V>& a) { + simd_impl<ImplTo> r; + r.copy_from(a); + return r; + } + + template <typename Impl, typename V> + static simd_impl<ImplTo> cast(const const_where_expression<Impl,V>& a) { + simd_impl<ImplTo> r = 0; + r.copy_from(a); + return r; + } }; template <typename V, std::size_t N> @@ -652,27 +951,17 @@ namespace simd_abi { }; } -template <typename Value, unsigned N, template <class, unsigned> class Abi = simd_abi::default_abi> -using simd = detail::simd_impl<typename Abi<Value, N>::type>; - -template <typename Value, unsigned N> -using simd_mask = typename simd<Value, N>::simd_mask; +template <typename Value, unsigned N, template <class, unsigned> class Abi> +struct simd_wrap { using type = detail::simd_impl<typename Abi<Value, N>::type>; }; -template <typename Simd> -using where_expression = typename Simd::where_expression; +template <typename Value, unsigned N, template <class, unsigned> class Abi> +using simd = typename simd_wrap<Value, N, Abi>::type; -template <typename Simd> -using const_where_expression = typename Simd::const_where_expression; +template <typename Value, unsigned N, template <class, unsigned> class Abi> +struct simd_mask_wrap { using type = typename simd<Value, N, Abi>::simd_mask; }; -template <typename Simd> -where_expression<Simd> where(const typename Simd::simd_mask& m, Simd& v) { - return where_expression<Simd>(m, v); -} - -template <typename Simd> -const_where_expression<Simd> where(const typename Simd::simd_mask& m, const Simd& v) { - return const_where_expression<Simd>(m, v); -} +template <typename Value, unsigned N, template <class, unsigned> class Abi> +using simd_mask = typename simd_mask_wrap<Value, N, Abi>::type; template <typename> struct is_simd: std::false_type {}; @@ -688,6 +977,11 @@ To simd_cast(const From& s) { return detail::simd_cast_impl<To>::cast(s); } +template <typename S, std::enable_if_t<is_simd<S>::value, int> = 0> +inline constexpr int width(const S a = S{}) { + return S::width; +}; + // Gather/scatter indexed memory specification. template < @@ -695,14 +989,35 @@ template < typename PtrLike, typename V = std::remove_reference_t<decltype(*std::declval<PtrLike>())> > -detail::indirect_expression<IndexImpl, V> indirect( +detail::indirect_indexed_expression<IndexImpl, V> indirect( PtrLike p, - const detail::simd_impl<IndexImpl>& index, + const IndexImpl& index, + unsigned width, index_constraint constraint = index_constraint::none) { - return detail::indirect_expression<IndexImpl, V>(p, index, constraint); + return detail::indirect_indexed_expression<IndexImpl, V>(p, index, width, constraint); +} + +template < + typename PtrLike, + typename V = std::remove_reference_t<decltype(*std::declval<PtrLike>())> +> +detail::indirect_expression<V> indirect( + PtrLike p, + unsigned width) +{ + return detail::indirect_expression<V>(p, width); } +template <typename Impl, typename ImplMask> +detail::where_expression<Impl, ImplMask> where(const ImplMask& m, Impl& v) { + return detail::where_expression<Impl, ImplMask>(m, v); +} + +template <typename Impl, typename ImplMask> +detail::const_where_expression<Impl, ImplMask> where(const ImplMask& m, const Impl& v) { + return detail::const_where_expression<Impl, ImplMask>(m, v); +} } // namespace simd } // namespace arb diff --git a/arbor/include/arbor/simd/sve.hpp b/arbor/include/arbor/simd/sve.hpp new file mode 100644 index 00000000..7a5c0ec9 --- /dev/null +++ b/arbor/include/arbor/simd/sve.hpp @@ -0,0 +1,904 @@ +#pragma once + +// SVE SIMD intrinsics implementation. + +#ifdef __ARM_FEATURE_SVE + +#include <arm_sve.h> +#include <cmath> +#include <cstdint> +#include <iostream> + +#include <arbor/util/pp_util.hpp> + +#include "approx.hpp" + +namespace arb { +namespace simd { +namespace detail { + +struct sve_double; +struct sve_int; +struct sve_mask; + +template<typename Type> struct sve_type_to_impl; +template<> struct sve_type_to_impl<svint64_t> { using type = detail::sve_int;}; +template<> struct sve_type_to_impl<svfloat64_t> { using type = detail::sve_double;}; +template<> struct sve_type_to_impl<svbool_t> { using type = detail::sve_mask;}; + +template<typename> struct is_sve : std::false_type {}; +template<> struct is_sve<svint64_t> : std::true_type {}; +template<> struct is_sve<svfloat64_t> : std::true_type {}; +template<> struct is_sve<svbool_t> : std::true_type {}; + +template <> +struct simd_traits<sve_mask> { + static constexpr unsigned width = 8; + using scalar_type = bool; + using vector_type = svbool_t; + using mask_impl = sve_mask; +}; + +template <> +struct simd_traits<sve_double> { + static constexpr unsigned width = 8; + using scalar_type = double; + using vector_type = svfloat64_t; + using mask_impl = sve_mask; +}; + +template <> +struct simd_traits<sve_int> { + static constexpr unsigned width = 8; + using scalar_type = int32_t; + using vector_type = svint64_t; + using mask_impl = sve_mask; +}; + +struct sve_mask { + static svbool_t broadcast(bool b) { + return svdup_b64(-b); + } + + static void copy_to(const svbool_t& k, bool* b) { + svuint64_t a = svdup_u64_z(k, 1); + svst1b_u64(svptrue_b64(), reinterpret_cast<uint8_t*>(b), a); + } + + static void copy_to_masked(const svbool_t& k, bool* b, const svbool_t& mask) { + svuint64_t a = svdup_u64_z(k, 1); + svst1b_u64(mask, reinterpret_cast<uint8_t*>(b), a); + } + + static svbool_t copy_from(const bool* p) { + svuint64_t a = svld1ub_u64(svptrue_b64(), reinterpret_cast<const uint8_t*>(p)); + svuint64_t ones = svdup_n_u64(1); + return svcmpeq_u64(svptrue_b64(), a, ones); + } + + static svbool_t copy_from_masked(const bool* p, const svbool_t& mask) { + svuint64_t a = svld1ub_u64(mask, reinterpret_cast<const uint8_t*>(p)); + svuint64_t ones = svdup_n_u64(1); + return svcmpeq_u64(mask, a, ones); + } + + static svbool_t logical_not(const svbool_t& k, const svbool_t& mask = svptrue_b64()) { + return svnot_b_z(mask, k); + } + + static svbool_t logical_and(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svand_b_z(mask, a, b); + } + + static svbool_t logical_or(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svorr_b_z(mask, a, b); + } + + // Arithmetic operations not necessarily appropriate for + // packed bit mask, but implemented for completeness/testing, + // with Z modulo 2 semantics: + // a + b is equivalent to a ^ b + // a * b a & b + // a / b a + // a - b a ^ b + // -a a + // max(a, b) a | b + // min(a, b) a & b + + static svbool_t neg(const svbool_t& a) { + return a; + } + + static svbool_t add(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return sveor_b_z(mask, a, b); + } + + static svbool_t sub(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return sveor_b_z(mask, a, b); + } + + static svbool_t mul(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svand_b_z(mask, a, b); + } + + static svbool_t div(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return a; + } + + static svbool_t fma(const svbool_t& a, const svbool_t& b, const svbool_t& c, const svbool_t& mask = svptrue_b64()) { + return add(mul(a, b, mask), c, mask); + } + + static svbool_t max(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svorr_b_z(mask, a, b); + } + + static svbool_t min(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svand_b_z(mask, a, b); + } + + // Comparison operators are also taken as operating on Z modulo 2, + // with 1 > 0: + // + // a > b is equivalent to a & ~b + // a >= b a | ~b, ~(~a & b) + // a < b ~a & b + // a <= b ~a | b, ~(a & ~b) + // a == b ~(a ^ b) + // a != b a ^ b + + static svbool_t cmp_eq(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svnot_b_z(mask, sveor_b_z(mask, a, b)); + } + + static svbool_t cmp_neq(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return sveor_b_z(mask, a, b); + } + + static svbool_t cmp_lt(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svbic_b_z(mask, b, a); + } + + static svbool_t cmp_gt(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return cmp_lt(b, a); + } + + static svbool_t cmp_geq(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return logical_not(cmp_lt(a, b)); + } + + static svbool_t cmp_leq(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return logical_not(cmp_gt(a, b)); + } + + static svbool_t ifelse(const svbool_t& m, const svbool_t& u, const svbool_t& v) { + return svsel_b(m, u, v); + } + + static svbool_t mask_broadcast(bool b) { + return broadcast(b); + } + + static void mask_copy_to(const svbool_t& m, bool* y) { + copy_to(m, y); + } + + static svbool_t mask_copy_from(const bool* y) { + return copy_from(y); + } + + static svbool_t true_mask(unsigned width) { + return svwhilelt_b64_u64(0, (uint64_t)width); + } +}; + +struct sve_int { + // Use default implementations for: + // element, set_element. + + using int32 = std::int32_t; + + static svint64_t broadcast(int32 v) { + return svreinterpret_s64_s32(svdup_n_s32(v)); + } + + static void copy_to(const svint64_t& v, int32* p) { + svst1w_s64(svptrue_b64(), p, v); + } + + static void copy_to_masked(const svint64_t& v, int32* p, const svbool_t& mask) { + svst1w_s64(mask, p, v); + } + + static svint64_t copy_from(const int32* p) { + return svld1sw_s64(svptrue_b64(), p); + } + + static svint64_t copy_from_masked(const int32* p, const svbool_t& mask) { + return svld1sw_s64(mask, p); + } + + static svint64_t copy_from_masked(const svint64_t& v, const int32* p, const svbool_t& mask) { + return svsel_s64(mask, svld1sw_s64(mask, p), v); + } + + static int element0(const svint64_t& a) { + return svlasta_s64(svptrue_b64(), a); + } + + static svint64_t neg(const svint64_t& a, const svbool_t& mask = svptrue_b64()) { + return svneg_s64_z(mask, a); + } + + static svint64_t add(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svadd_s64_z(mask, a, b); + } + + static svint64_t sub(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svsub_s64_m(mask, a, b); + } + + static svint64_t mul(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + //May overflow + return svmul_s64_z(mask, a, b); + } + + static svint64_t div(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svdiv_s64_z(mask, a, b); + } + + static svint64_t fma(const svint64_t& a, const svint64_t& b, const svint64_t& c, const svbool_t& mask = svptrue_b64()) { + return add(mul(a, b, mask), c, mask); + } + + static svbool_t cmp_eq(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpeq_s64(mask, a, b); + } + + static svbool_t cmp_neq(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpne_s64(mask, a, b); + } + + static svbool_t cmp_gt(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpgt_s64(mask, a, b); + } + + static svbool_t cmp_geq(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpge_s64(mask, a, b); + } + + static svbool_t cmp_lt(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmplt_s64(mask, a, b); + } + + static svbool_t cmp_leq(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmple_s64(mask, a, b); + } + + static svint64_t ifelse(const svbool_t& m, const svint64_t& u, const svint64_t& v) { + return svsel_s64(m, u, v); + } + + static svint64_t max(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmax_s64_x(mask, a, b); + } + + static svint64_t min(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmin_s64_x(mask, a, b); + } + + static svint64_t abs(const svint64_t& a, const svbool_t& mask = svptrue_b64()) { + return svabs_s64_z(mask, a); + } + + static int reduce_add(const svint64_t& a, const svbool_t& mask = svptrue_b64()) { + return svaddv_s64(mask, a); + } + + static svint64_t pow(const svint64_t& x, const svint64_t& y, const svbool_t& mask = svptrue_b64()) { + auto len = svlen_s64(x); + int32 a[len], b[len], r[len]; + copy_to_masked(x, a, mask); + copy_to_masked(y, b, mask); + + for (unsigned i = 0; i<len; ++i) { + r[i] = std::pow(a[i], b[i]); + } + return copy_from_masked(r, mask); + } + + static svint64_t gather(tag<sve_int>, const int32* p, const svint64_t& index, const svbool_t& mask = svptrue_b64()) { + return svld1sw_gather_s64index_s64(mask, p, index); + } + + static svint64_t gather(tag<sve_int>, svint64_t a, const int32* p, const svint64_t& index, const svbool_t& mask) { + return svsel_s64(mask, svld1sw_gather_s64index_s64(mask, p, index), a); + } + + static void scatter(tag<sve_int>, const svint64_t& s, int32* p, const svint64_t& index, const svbool_t& mask = svptrue_b64()) { + svst1w_scatter_s64index_s64(mask, p, index, s); + } + + static unsigned simd_width(const svint64_t& m) { + return svlen_s64(m); + } +}; + +struct sve_double { + // Use default implementations for: + // element, set_element. + + static svfloat64_t broadcast(double v) { + return svdup_n_f64(v); + } + + static void copy_to(const svfloat64_t& v, double* p) { + svst1_f64(svptrue_b64(), p, v); + } + + static void copy_to_masked(const svfloat64_t& v, double* p, const svbool_t& mask) { + svst1_f64(mask, p, v); + } + + static svfloat64_t copy_from(const double* p) { + return svld1_f64(svptrue_b64(), p); + } + + static svfloat64_t copy_from_masked(const double* p, const svbool_t& mask) { + return svld1_f64(mask, p); + } + + static svfloat64_t copy_from_masked(const svfloat64_t& v, const double* p, const svbool_t& mask) { + return svsel_f64(mask, svld1_f64(mask, p), v); + } + + static double element0(const svfloat64_t& a) { + return svlasta_f64(svptrue_b64(), a); + } + + static svfloat64_t neg(const svfloat64_t& a, const svbool_t& mask = svptrue_b64()) { + return svneg_f64_z(mask, a); + } + + static svfloat64_t add(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svadd_f64_z(mask, a, b); + } + + static svfloat64_t sub(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svsub_f64_z(mask, a, b); + } + + static svfloat64_t mul(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmul_f64_z(mask, a, b); + } + + static svfloat64_t div(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svdiv_f64_z(mask, a, b); + } + + static svfloat64_t fma(const svfloat64_t& a, const svfloat64_t& b, const svfloat64_t& c, const svbool_t& mask = svptrue_b64()) { + return svmad_f64_z(mask, a, b, c); + } + + static svbool_t cmp_eq(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpeq_f64(mask, a, b); + } + + static svbool_t cmp_neq(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpne_f64(mask, a, b); + } + + static svbool_t cmp_gt(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpgt_f64(mask, a, b); + } + + static svbool_t cmp_geq(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpge_f64(mask, a, b); + } + + static svbool_t cmp_lt(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmplt_f64(mask, a, b); + } + + static svbool_t cmp_leq(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmple_f64(mask, a, b); + } + + static svfloat64_t ifelse(const svbool_t& m, const svfloat64_t& u, const svfloat64_t& v) { + return svsel_f64(m, u, v); + } + + static svfloat64_t max(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmax_f64_x(mask, a, b); + } + + static svfloat64_t min(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmin_f64_x(mask, a, b); + } + + static svfloat64_t abs(const svfloat64_t& x, const svbool_t& mask = svptrue_b64()) { + return svabs_f64_x(mask, x); + } + + static double reduce_add(const svfloat64_t& a, const svbool_t& mask = svptrue_b64()) { + return svaddv_f64(mask, a); + } + + static svfloat64_t gather(tag<sve_int>, const double* p, const svint64_t& index, const svbool_t& mask = svptrue_b64()) { + return svld1_gather_s64index_f64(mask, p, index); + } + + static svfloat64_t gather(tag<sve_int>, svfloat64_t a, const double* p, const svint64_t& index, const svbool_t& mask) { + return svsel_f64(mask, svld1_gather_s64index_f64(mask, p, index), a); + } + + static void scatter(tag<sve_int>, const svfloat64_t& s, double* p, const svint64_t& index, const svbool_t& mask = svptrue_b64()) { + svst1_scatter_s64index_f64(mask, p, index, s); + } + + // Refer to avx/avx2 code for details of the exponential and log + // implementations. + + static svfloat64_t exp(const svfloat64_t& x) { + // Masks for exceptional cases. + + auto is_large = cmp_gt(x, broadcast(exp_maxarg)); + auto is_small = cmp_lt(x, broadcast(exp_minarg)); + + // Compute n and g. + + auto n = svrintz_f64_z(svptrue_b64(), add(mul(broadcast(ln2inv), x), broadcast(0.5))); + + auto g = fma(n, broadcast(-ln2C1), x); + g = fma(n, broadcast(-ln2C2), g); + + auto gg = mul(g, g); + + // Compute the g*P(g^2) and Q(g^2). + auto odd = mul(g, horner(gg, P0exp, P1exp, P2exp)); + auto even = horner(gg, Q0exp, Q1exp, Q2exp, Q3exp); + + // Compute R(g)/R(-g) = 1 + 2*g*P(g^2) / (Q(g^2)-g*P(g^2)) + + auto expg = fma(broadcast(2), div(odd, sub(even, odd)), broadcast(1)); + + // Scale by 2^n, propogating NANs. + + auto result = svscale_f64_z(svptrue_b64(), expg, svcvt_s64_f64_z(svptrue_b64(), n)); + + return + ifelse(is_large, broadcast(HUGE_VAL), + ifelse(is_small, broadcast(0), + result)); + } + + static svfloat64_t expm1(const svfloat64_t& x) { + auto is_large = cmp_gt(x, broadcast(exp_maxarg)); + auto is_small = cmp_lt(x, broadcast(expm1_minarg)); + + auto half = broadcast(0.5); + auto one = broadcast(1.); + + auto nnz = cmp_gt(abs(x), half); + auto n = svrinta_f64_z(nnz, mul(broadcast(ln2inv), x)); + + auto g = fma(n, broadcast(-ln2C1), x); + g = fma(n, broadcast(-ln2C2), g); + + auto gg = mul(g, g); + + auto odd = mul(g, horner(gg, P0exp, P1exp, P2exp)); + auto even = horner(gg, Q0exp, Q1exp, Q2exp, Q3exp); + + // Compute R(g)/R(-g) -1 = 2*g*P(g^2) / (Q(g^2)-g*P(g^2)) + + auto expgm1 = div(mul(broadcast(2), odd), sub(even, odd)); + + // For small x (n zero), bypass scaling step to avoid underflow. + // Otherwise, compute result 2^n * expgm1 + (2^n-1) by: + // result = 2 * ( 2^(n-1)*expgm1 + (2^(n-1)+0.5) ) + // to avoid overflow when n=1024. + + auto nm1 = svcvt_s64_f64_z(svptrue_b64(), sub(n, one)); + + auto result = + svscale_f64_z(svptrue_b64(), + add(sub(svscale_f64_z(svptrue_b64(),one, nm1), half), + svscale_f64_z(svptrue_b64(),expgm1, nm1)), + svcvt_s64_f64_z(svptrue_b64(), one)); + + return + ifelse(is_large, broadcast(HUGE_VAL), + ifelse(is_small, broadcast(-1), + ifelse(nnz, result, expgm1))); + } + + static svfloat64_t exprelr(const svfloat64_t& x) { + auto ones = broadcast(1); + return ifelse(cmp_eq(ones, add(ones, x)), ones, div(x, expm1(x))); + } + + static svfloat64_t log(const svfloat64_t& x) { + // Masks for exceptional cases. + + auto is_large = cmp_geq(x, broadcast(HUGE_VAL)); + auto is_small = cmp_lt(x, broadcast(log_minarg)); + auto is_domainerr = cmp_lt(x, broadcast(0)); + + auto is_nan = svnot_b_z(svptrue_b64(), cmp_eq(x, x)); + is_domainerr = svorr_b_z(svptrue_b64(), is_nan, is_domainerr); + + svfloat64_t g = svcvt_f64_s32_z(svptrue_b64(), logb_normal(x)); + svfloat64_t u = fraction_normal(x); + + svfloat64_t one = broadcast(1.); + svfloat64_t half = broadcast(0.5); + auto gtsqrt2 = cmp_geq(u, broadcast(sqrt2)); + g = ifelse(gtsqrt2, add(g, one), g); + u = ifelse(gtsqrt2, mul(u, half), u); + + auto z = sub(u, one); + auto pz = horner(z, P0log, P1log, P2log, P3log, P4log, P5log); + auto qz = horner1(z, Q0log, Q1log, Q2log, Q3log, Q4log); + + auto z2 = mul(z, z); + auto z3 = mul(z2, z); + + auto r = div(mul(z3, pz), qz); + r = add(r, mul(g, broadcast(ln2C4))); + r = sub(r, mul(z2, half)); + r = add(r, z); + r = add(r, mul(g, broadcast(ln2C3))); + + // r is alrady NaN if x is NaN or negative, otherwise + // return +inf if x is +inf, or -inf if zero or (positive) denormal. + + return ifelse(is_domainerr, broadcast(NAN), + ifelse(is_large, broadcast(HUGE_VAL), + ifelse(is_small, broadcast(-HUGE_VAL), r))); + } + + static svfloat64_t pow(const svfloat64_t& x, const svfloat64_t& y) { + auto len = svlen_f64(x); + double a[len], b[len], r[len]; + copy_to(x, a); + copy_to(y, b); + + for (unsigned i = 0; i<len; ++i) { + r[i] = std::pow(a[i], b[i]); + } + return copy_from(r); + } + + static unsigned simd_width(const svfloat64_t& m) { + return svlen_f64(m); + } + +protected: + // Compute n and f such that x = 2^n·f, with |f| ∈ [1,2), given x is finite and normal. + static svint32_t logb_normal(const svfloat64_t& x) { + svuint32_t xw = svtrn2_u32(svreinterpret_u32_f64(x), svreinterpret_u32_f64(x)); + svuint64_t lmask = svdup_n_u64(0x00000000ffffffff); + svuint64_t xt = svand_u64_z(svptrue_b64(), svreinterpret_u64_u32(xw), lmask); + svuint32_t xhi = svreinterpret_u32_u64(xt); + auto emask = svdup_n_u32(0x7ff00000); + auto ebiased = svlsr_n_u32_z(svptrue_b64(), svand_u32_z(svptrue_b64(), xhi, emask), 20); + + return svsub_s32_z(svptrue_b64(), svreinterpret_s32_u32(ebiased), svdup_n_s32(1023)); + } + + static svfloat64_t fraction_normal(const svfloat64_t& x) { + svuint64_t emask = svdup_n_u64(-0x7ff0000000000001); + svuint64_t bias = svdup_n_u64(0x3ff0000000000000); + return svreinterpret_f64_u64( + svorr_u64_z(svptrue_b64(), bias, svand_u64_z(svptrue_b64(), emask, svreinterpret_u64_f64(x)))); + } + + static inline svfloat64_t horner1(svfloat64_t x, double a0) { + return add(x, broadcast(a0)); + } + + static inline svfloat64_t horner(svfloat64_t x, double a0) { + return broadcast(a0); + } + + template <typename... T> + static svfloat64_t horner(svfloat64_t x, double a0, T... tail) { + return fma(x, horner(x, tail...), broadcast(a0)); + } + + template <typename... T> + static svfloat64_t horner1(svfloat64_t x, double a0, T... tail) { + return fma(x, horner1(x, tail...), broadcast(a0)); + } + + static svfloat64_t fms(const svfloat64_t& a, const svfloat64_t& b, const svfloat64_t& c) { + return svnmsb_f64_z(svptrue_b64(), a, b, c); + } +}; + +} // namespace detail + +namespace simd_abi { +template <typename T, unsigned N> struct sve; +template <> struct sve<double, 0> {using type = detail::sve_double;}; +template <> struct sve<int, 0> {using type = detail::sve_int;}; +}; // namespace simd_abi + +template <typename Value, unsigned N, template <class, unsigned> class Abi> +struct simd_wrap; + +template <typename Value, template <class, unsigned> class Abi> +struct simd_wrap<Value, (unsigned)0, Abi> { using type = typename detail::simd_traits<typename simd_abi::sve<Value, 0u>::type>::vector_type; }; + +template <typename Value, unsigned N, template <class, unsigned> class Abi> +struct simd_mask_wrap; + +template <typename Value, template <class, unsigned> class Abi> +struct simd_mask_wrap<Value, (unsigned)0, Abi> { + using type = typename detail::simd_traits< + typename detail::simd_traits<typename simd_abi::sve<Value, 0u>::type>::mask_impl + >::vector_type; }; + +// Math functions exposed for SVE types + +#define ARB_SVE_UNARY_ARITHMETIC_(name)\ +template <typename T>\ +T name(const T& a) {\ + return detail::sve_type_to_impl<T>::type::name(a);\ +}; + +#define ARB_SVE_BINARY_ARITHMETIC_(name)\ +template <typename T>\ +auto name(const T& a, const T& b) {\ + return detail::sve_type_to_impl<T>::type::name(a, b);\ +};\ +template <typename T>\ +auto name(const T& a, const typename detail::simd_traits<typename detail::sve_type_to_impl<T>::type>::scalar_type& b) {\ + return name(a, detail::sve_type_to_impl<T>::type::broadcast(b));\ +};\ +template <typename T>\ +auto name(const typename detail::simd_traits<typename detail::sve_type_to_impl<T>::type>::scalar_type& a, const T& b) {\ + return name(detail::sve_type_to_impl<T>::type::broadcast(a), b);\ +}; + + +ARB_PP_FOREACH(ARB_SVE_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) +ARB_PP_FOREACH(ARB_SVE_BINARY_ARITHMETIC_, cmp_eq, cmp_neq, cmp_leq, cmp_lt, cmp_geq, cmp_gt, logical_and, logical_or) +ARB_PP_FOREACH(ARB_SVE_UNARY_ARITHMETIC_, logical_not, neg, abs, exp, log, expm1, exprelr) + +#undef ARB_SVE_UNARY_ARITHMETIC_ +#undef ARB_SVE_BINARY_ARITHMETIC_ + +template <typename T> +T fma(const T& a, T b, T c) { + return detail::sve_type_to_impl<T>::type::fma(a, b, c); +} + +template <typename T> +auto sum(const T& a) { + return detail::sve_type_to_impl<T>::type::reduce_add(a); +} + +// Indirect/Indirect indexed/Where Expression copy methods + +template <typename T, typename V> +static void indirect_copy_to(const T& s, V* p, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + Impl::copy_to_masked(s, p, ImplMask::true_mask(width)); +} + +template <typename T, typename M, typename V> +static void indirect_copy_to(const T& data, const M& mask, V* p, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplMask = typename detail::sve_type_to_impl<M>::type; + + Impl::copy_to_masked(data, p, ImplMask::logical_and(mask, ImplMask::true_mask(width))); +} + +template <typename T, typename I, typename V> +static void indirect_indexed_copy_to(const T& s, V* p, const I& index, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplIndex = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + + Impl::scatter(detail::tag<ImplIndex>{}, s, p, index, ImplMask::true_mask(width)); +} + +template <typename T, typename I, typename M, typename V> +static void indirect_indexed_copy_to(const T& data, const M& mask, V* p, const I& index, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplIndex = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::sve_type_to_impl<M>::type; + + Impl::scatter(detail::tag<ImplIndex>{}, data, p, index, ImplMask::logical_and(mask, ImplMask::true_mask(width))); +} + +template <typename T, typename M, typename V> +static void where_copy_to(const M& mask, T& f, const V& t) { + using Impl = typename detail::sve_type_to_impl<T>::type; + f = Impl::ifelse(mask, Impl::broadcast(t), f); +} + +template <typename T, typename M> +static void where_copy_to(const M& mask, T& f, const T& t) { + f = detail::sve_type_to_impl<T>::type::ifelse(mask, t, f); +} + +template <typename T, typename M, typename V> +static void where_copy_to(const M& mask, T& f, const V* p, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplMask = typename detail::sve_type_to_impl<M>::type; + + auto m = ImplMask::logical_and(mask, ImplMask::true_mask(width)); + f = Impl::ifelse(mask, Impl::copy_from_masked(p, m), f); +} + +template <typename T, typename I, typename M, typename V> +static void where_copy_to(const M& mask, T& f, const V* p, const I& index, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using IndexImpl = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::sve_type_to_impl<M>::type; + + auto m = ImplMask::logical_and(mask, ImplMask::true_mask(width)); + T temp = Impl::gather(detail::tag<IndexImpl>{}, p, index, m); + f = Impl::ifelse(mask, temp, f); +} + +template <typename I, typename T, typename V> +void compound_indexed_add( + const T& s, + V* p, + const I& index, + unsigned width, + index_constraint constraint) +{ + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplIndex = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + + auto mask = ImplMask::true_mask(width); + switch (constraint) { + case index_constraint::none: + { + typename detail::simd_traits<ImplIndex>::scalar_type o[width]; + ImplIndex::copy_to_masked(index, o, mask); + + V a[width]; + Impl::copy_to_masked(s, a, mask); + + V temp = 0; + for (unsigned i = 0; i<width-1; ++i) { + temp += a[i]; + if (o[i] != o[i+1]) { + p[o[i]] += temp; + temp = 0; + } + } + temp += a[width-1]; + p[o[width-1]] += temp; + } + break; + case index_constraint::independent: + { + auto v = Impl::add(Impl::gather(detail::tag<ImplIndex>{}, p, index, mask), s, mask); + Impl::scatter(detail::tag<ImplIndex>{}, v, p, index, mask); + } + break; + case index_constraint::contiguous: + { + p += ImplIndex::element0(index); + auto v = Impl::add(Impl::copy_from_masked(p, mask), s, mask); + Impl::copy_to_masked(v, p, mask); + } + break; + case index_constraint::constant: + p += ImplIndex::element0(index); + *p += Impl::reduce_add(s, mask); + break; + } +} + +static int width(const svfloat64_t& v) { + return svlen_f64(v); +}; + +static int width(const svint64_t& v) { + return svlen_s64(v); +}; + +template <typename S, typename std::enable_if_t<detail::is_sve<S>::value, int> = 0> +static int width() { S v; return width(v); } + +namespace detail { + +template <typename I, typename V> +class indirect_indexed_expression; + +template <typename V> +class indirect_expression; + +template <typename T, typename M> +class where_expression; + +template <typename T, typename M> +class const_where_expression; + +template <typename To> +struct simd_cast_impl { + template <typename V> + static To cast(const V& a) { + return detail::sve_type_to_impl<To>::type::broadcast(a); + } + + template <typename V> + static To cast(const indirect_expression<V>& a) { + using Impl = typename detail::sve_type_to_impl<To>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + + return Impl::copy_from_masked(a.p, ImplMask::true_mask(a.width)); + } + + template <typename I, typename V> + static To cast(const indirect_indexed_expression<I,V>& a) { + using Impl = typename detail::sve_type_to_impl<To>::type; + using IndexImpl = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + + To r; + auto mask = ImplMask::true_mask(a.width); + switch (a.constraint) { + case index_constraint::none: + r = Impl::gather(tag<IndexImpl>{}, a.p, a.index, mask); + break; + case index_constraint::independent: + r = Impl::gather(tag<IndexImpl>{}, a.p, a.index, mask); + break; + case index_constraint::contiguous: + { + const auto* p = IndexImpl::element0(a.index) + a.p; + r = Impl::copy_from_masked(p, mask); + } + break; + case index_constraint::constant: + { + const auto *p = IndexImpl::element0(a.index) + a.p; + auto l = (*p); + r = Impl::broadcast(l); + } + break; + } + return r; + } + + template <typename T, typename V> + static To cast(const const_where_expression<T,V>& a) { + auto r = detail::sve_type_to_impl<To>::type::broadcast(0); + r = detail::sve_type_to_impl<To>::type::ifelse(a.mask_, a.data_, r); + return r; + } + + template <typename T, typename V> + static To cast(const where_expression<T,V>& a) { + auto r = detail::sve_type_to_impl<To>::type::broadcast(0); + r = detail::sve_type_to_impl<To>::type::ifelse(a.mask_, a.data_, r); + return r; + } +}; + +template <typename T, typename V> +void assign(T& a, const detail::indirect_expression<V>& b) { + a = detail::simd_cast_impl<T>::cast(b); +} + +template <typename T, typename I, typename V> +void assign(T& a, const detail::indirect_indexed_expression<I, V>& b) { + a = detail::simd_cast_impl<T>::cast(b); +} + +} // namespace detail +} // namespace simd +} // namespace arb + +#endif // def __ARM_FEATURE_SVE diff --git a/mechanisms/default/kdrmt.mod b/mechanisms/default/kdrmt.mod index 54c07044..86a697a0 100644 --- a/mechanisms/default/kdrmt.mod +++ b/mechanisms/default/kdrmt.mod @@ -61,12 +61,12 @@ DERIVATIVE states { PROCEDURE trates(v,celsius) { LOCAL qt - LOCAL alpm, betm + LOCAL alpm_t, betm_t LOCAL tmp qt=q10^((celsius-24)/10) minf = 1/(1 + exp(-(v-21)/10)) tmp = zetam*(v-vhalfm) - alpm = exp(tmp) - betm = exp(gmm*tmp) - mtau = betm/(qt*a0m*(1+alpm)) + alpm_t = exp(tmp) + betm_t = exp(gmm*tmp) + mtau = betm_t/(qt*a0m*(1+alpm_t)) } diff --git a/modcc/modcc.cpp b/modcc/modcc.cpp index 488f0ed0..d1ec2c76 100644 --- a/modcc/modcc.cpp +++ b/modcc/modcc.cpp @@ -48,6 +48,7 @@ std::unordered_map<std::string, targetKind> targetKindMap = { std::unordered_map<std::string, enum simd_spec::simd_abi> simdAbiMap = { {"none", simd_spec::none}, {"neon", simd_spec::neon}, + {"sve", simd_spec::sve}, {"avx", simd_spec::avx}, {"avx2", simd_spec::avx2}, {"avx512", simd_spec::avx512}, @@ -118,7 +119,7 @@ std::istream& operator>> (std::istream& i, simd_spec& spec) { auto npos = std::string::npos; std::string s; i >> s; - unsigned width = 0; + unsigned width = no_size; auto suffix = s.find_last_of('/'); if (suffix!=npos) { diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index f6edfe9a..4c947d60 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -170,6 +170,157 @@ void CExprEmitter::visit(IfExpression* e) { /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// std::unordered_set<std::string> SimdExprEmitter::mask_names_; +void SimdExprEmitter::visit(PowBinaryExpression* e) { + out_ << "S::pow("; + e->lhs()->accept(this); + out_ << ", "; + e->rhs()->accept(this); + out_ << ')'; +} + +void SimdExprEmitter::visit(NumberExpression* e) { + out_ << " (double)" << as_c_double(e->value()); +} + +void SimdExprEmitter::visit(UnaryExpression* e) { + static std::unordered_map<tok, const char*> unaryop_tbl = { + {tok::minus, "S::neg"}, + {tok::exp, "S::exp"}, + {tok::cos, "S::cos"}, + {tok::sin, "S::sin"}, + {tok::log, "S::log"}, + {tok::abs, "S::abs"}, + {tok::exprelr, "S::exprelr"}, + {tok::safeinv, "safeinv"} + }; + + if (!unaryop_tbl.count(e->op())) { + throw compiler_exception( + "CExprEmitter: unsupported unary operator "+token_string(e->op()), e->location()); + } + + const char* op_spelling = unaryop_tbl.at(e->op()); + Expression* inner = e->expression(); + + auto iden = inner->is_identifier(); + bool is_scalar = iden && scalars_.count(iden->name()); + if (e->op()==tok::minus && is_scalar) { + out_ << "simd_cast<simd_value>(-"; + inner->accept(this); + out_ << ")"; + } + else { + emit_as_call(op_spelling, inner); + } +} + +void SimdExprEmitter::visit(BinaryExpression* e) { + static std::unordered_map<tok, const char *> func_tbl = { + {tok::minus, "S::sub"}, + {tok::plus, "S::add"}, + {tok::times, "S::mul"}, + {tok::divide, "S::div"}, + {tok::lt, "S::cmp_lt"}, + {tok::lte, "S::cmp_leq"}, + {tok::gt, "S::cmp_gt"}, + {tok::gte, "S::cmp_geq"}, + {tok::equality, "S::cmp_eq"}, + {tok::land, "S::logical_and"}, + {tok::lor, "S::logical_or"}, + {tok::ne, "S::cmp_neq"}, + {tok::min, "S::min"}, + {tok::max, "S::max"}, + }; + + static std::unordered_map<tok, const char *> binop_tbl = { + {tok::minus, "-"}, + {tok::plus, "+"}, + {tok::times, "*"}, + {tok::divide, "/"}, + {tok::lt, "<"}, + {tok::lte, "<="}, + {tok::gt, ">"}, + {tok::gte, ">="}, + {tok::equality, "=="}, + {tok::land, "&&"}, + {tok::lor, "||"}, + {tok::ne, "!="}, + {tok::min, "min"}, + {tok::max, "max"}, + }; + + + if (!binop_tbl.count(e->op())) { + throw compiler_exception( + "CExprEmitter: unsupported binary operator " + token_string(e->op()), e->location()); + } + + std::string rhs_name, lhs_name; + + auto rhs = e->rhs(); + auto lhs = e->lhs(); + + const char *op_spelling = binop_tbl.at(e->op()); + const char *func_spelling = func_tbl.at(e->op()); + + if (rhs->is_identifier()) { + rhs_name = rhs->is_identifier()->name(); + } + if (lhs->is_identifier()) { + lhs_name = lhs->is_identifier()->name(); + } + + if (scalars_.count(rhs_name) && scalars_.count(lhs_name)) { + if (e->is_infix()) { + associativityKind assoc = Lexer::operator_associativity(e->op()); + int op_prec = Lexer::binop_precedence(e->op()); + + auto need_paren = [op_prec](Expression *subexpr, bool assoc_side) -> bool { + if (auto b = subexpr->is_binary()) { + int sub_prec = Lexer::binop_precedence(b->op()); + return sub_prec < op_prec || (!assoc_side && sub_prec == op_prec); + } + return false; + }; + + out_ << "simd_cast<simd_value>("; + if (need_paren(lhs, assoc == associativityKind::left)) { + emit_as_call("", lhs); + } else { + lhs->accept(this); + } + + out_ << op_spelling; + + if (need_paren(rhs, assoc == associativityKind::right)) { + emit_as_call("", rhs); + } else { + rhs->accept(this); + } + out_ << ")"; + } else { + out_ << "simd_cast<simd_value>("; + emit_as_call(op_spelling, lhs, rhs); + out_ << ")"; + } + } else if (scalars_.count(rhs_name) && !scalars_.count(lhs_name)) { + out_ << func_spelling << '('; + lhs->accept(this); + out_ << ", simd_cast<simd_value>(" << rhs_name ; + out_ << "))"; + } else if (!scalars_.count(rhs_name) && scalars_.count(lhs_name)) { + out_ << func_spelling << "(simd_cast<simd_value>(" << lhs_name << "), "; + rhs->accept(this); + out_ << ")"; + } else { + out_ << func_spelling << '('; + lhs->accept(this); + out_ << ", "; + rhs->accept(this); + out_ << ')'; + } +} + void SimdExprEmitter::visit(BlockExpression* block) { for (auto& stmt: block->statements()) { if (!stmt->is_local_declaration()) { @@ -209,15 +360,22 @@ void SimdExprEmitter::visit(AssignmentExpression* e) { if (lhs->is_variable() && lhs->is_variable()->is_range()) { if (!input_mask_.empty()) { - mask = mask + " && " + input_mask_; + mask = "S::logical_and(" + mask + ", " + input_mask_ + ")"; } - out_ << "S::where(" << mask << ", " << "simd_value("; - e->rhs()->accept(this); - out_ << "))"; if(is_indirect_) - out_ << ".copy_to(" << lhs->name() << "+index_)"; + out_ << "indirect(" << lhs->name() << "+index_, simd_width_) = "; else - out_ << ".copy_to(" << lhs->name() << "+i_)"; + out_ << "indirect(" << lhs->name() << "+i_, simd_width_) = "; + + out_ << "S::where(" << mask << ", "; + + bool cast = e->rhs()->is_number(); + if (cast) out_ << "simd_cast<simd_value>("; + e->rhs()->accept(this); + + out_ << ")"; + + if (cast) out_ << ")"; } else { out_ << "S::where(" << mask << ", "; e->lhs()->accept(this); @@ -237,17 +395,17 @@ void SimdExprEmitter::visit(IfExpression* e) { auto new_mask = make_unique_var(e->scope(), "mask_"); // Set new masks - out_ << "simd_value::simd_mask " << new_mask << " = "; + out_ << "simd_mask " << new_mask << " = "; e->condition()->accept(this); out_ << ";\n"; if (!current_mask_.empty()) { auto base_mask = processing_true_ ? current_mask_ : current_mask_bar_; - current_mask_bar_ = base_mask + " && !" + new_mask; - current_mask_ = base_mask + " && " + new_mask; + current_mask_bar_ = "S::logical_and(" + base_mask + ", S::logical_not(" + new_mask + "))"; + current_mask_ = "S::logical_and(" + base_mask + ", " + new_mask + ")"; } else { - current_mask_bar_ = "!" + new_mask; + current_mask_bar_ = "S::logical_not(" + new_mask + ")"; current_mask_ = new_mask; } diff --git a/modcc/printer/cexpr_emit.hpp b/modcc/printer/cexpr_emit.hpp index fbbfd448..9286a5c9 100644 --- a/modcc/printer/cexpr_emit.hpp +++ b/modcc/printer/cexpr_emit.hpp @@ -40,12 +40,21 @@ inline void cexpr_emit(Expression* e, std::ostream& out, Visitor* fallback) { class SimdExprEmitter: public CExprEmitter { using CExprEmitter::visit; public: - SimdExprEmitter(std::ostream& out, bool is_indirect, std::string input_mask, Visitor* fallback): - CExprEmitter(out, fallback), is_indirect_(is_indirect), input_mask_(input_mask) {} + SimdExprEmitter( + std::ostream& out, + bool is_indirect, + std::string input_mask, + const std::unordered_set<std::string>& scalars, + Visitor* fallback): + CExprEmitter(out, fallback), is_indirect_(is_indirect), input_mask_(input_mask), scalars_(scalars), fallback_(fallback) {} void visit(BlockExpression *e) override; void visit(CallExpression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; void visit(AssignmentExpression *e) override; + void visit(PowBinaryExpression *e) override; + void visit(NumberExpression *e) override; void visit(IfExpression *e) override; protected: @@ -53,6 +62,8 @@ protected: bool processing_true_; bool is_indirect_; std::string current_mask_, current_mask_bar_, input_mask_; + std::unordered_set<std::string> scalars_; + Visitor* fallback_; private: std::string make_unique_var(scope_ptr scope, std::string prefix) { @@ -66,8 +77,15 @@ private: }; }; -inline void simd_expr_emit(Expression* e, std::ostream& out, bool is_indirect, std::string input_mask, Visitor* fallback) { - SimdExprEmitter emitter(out, is_indirect, input_mask, fallback); +inline void simd_expr_emit( + Expression* e, + std::ostream& out, + bool is_indirect, + std::string input_mask, + const std::unordered_set<std::string>& scalars, + Visitor* fallback) +{ + SimdExprEmitter emitter(out, is_indirect, input_mask, scalars, fallback); e->accept(&emitter); } diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 2ebd8db3..73ab685d 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -29,7 +29,7 @@ void emit_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::s void emit_masked_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); void emit_api_body(std::ostream&, APIMethod*); -void emit_simd_api_body(std::ostream&, APIMethod*, moduleKind); +void emit_simd_api_body(std::ostream&, APIMethod*, const std::vector<VariableExpression*>& scalars); void emit_index_initialize(std::ostream& out, const std::unordered_set<std::string>& indices, simd_expr_constraint constraint); @@ -59,8 +59,13 @@ struct simdprint { Expression* expr_; bool is_indirect_ = false; bool is_masked_ = false; + std::unordered_set<std::string> scalars_; - explicit simdprint(Expression* expr): expr_(expr) {} + explicit simdprint(Expression* expr, const std::vector<VariableExpression*>& scalars): expr_(expr) { + for (const auto& s: scalars) { + scalars_.insert(s->name()); + } + } void set_indirect_index() { is_indirect_ = true; @@ -75,6 +80,7 @@ struct simdprint { printer.set_input_mask("mask_input_"); } printer.set_var_indexed(w.is_indirect_); + printer.save_scalar_names(w.scalars_); return w.expr_->accept(&printer), out; } }; @@ -147,6 +153,8 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { if (with_simd) { out << "#include <" << arb_header_prefix() << "simd/simd.hpp>\n"; + out << "#undef NDEBUG\n"; + out << "#include <cassert>\n"; } out << @@ -173,12 +181,20 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { out << "namespace S = ::arb::simd;\n" "using S::index_constraint;\n" - "static constexpr unsigned simd_width_ = "; + "using S::simd_cast;\n" + "using S::indirect;\n"; - if (!opt.simd.width) { + out << "static constexpr unsigned vector_length_ = "; + if (opt.simd.size == no_size) { out << "S::simd_abi::native_width<::arb::fvm_value_type>::value;\n"; + } else { + out << opt.simd.size << ";\n"; } - else { + + out << "static constexpr unsigned simd_width_ = "; + if (opt.simd.width == no_size) { + out << " vector_length_ ? vector_length_ : " << opt.simd.default_width << ";\n"; + } else { out << opt.simd.width << ";\n"; } @@ -188,18 +204,22 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { case simd_spec::avx2: abi += "avx2"; break; case simd_spec::avx512: abi += "avx512"; break; case simd_spec::neon: abi += "neon"; break; + case simd_spec::sve: abi += "sve"; break; case simd_spec::native: abi += "native"; break; default: abi += "default_abi"; break; } out << - "using simd_value = S::simd<::arb::fvm_value_type, simd_width_, " << abi << ">;\n" - "using simd_index = S::simd<::arb::fvm_index_type, simd_width_, " << abi << ">;\n" + "using simd_value = S::simd<::arb::fvm_value_type, vector_length_, " << abi << ">;\n" + "using simd_index = S::simd<::arb::fvm_index_type, vector_length_, " << abi << ">;\n" + "using simd_mask = S::simd_mask<::arb::fvm_value_type, vector_length_, "<< abi << ">;\n" "\n" "inline simd_value safeinv(simd_value x) {\n" - " S::where(x+1==1, x) = DBL_EPSILON;\n" - " return 1/x;\n" + " simd_value ones = simd_cast<simd_value>(1.0);\n" + " auto mask = S::cmp_eq(S::add(x,ones), ones);\n" + " S::where(mask, x) = simd_cast<simd_value>(DBL_EPSILON);\n" + " return S::div(ones, x);\n" "}\n" "\n"; } @@ -224,6 +244,8 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "void deliver_events(deliverable_event_stream::state events) override;\n" "void net_receive(int i_, value_type weight);\n"; + with_simd && out << "unsigned simd_width() const override { return simd_width_; }\n"; + out << "\n" << popindent << "protected:\n" << indent << @@ -355,7 +377,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { auto emit_body = [&](APIMethod *p) { if (with_simd) { - emit_simd_api_body(out, p, module_.kind()); + emit_simd_api_body(out, p, vars.scalars); } else { emit_api_body(out, p); @@ -387,11 +409,11 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { for (auto proc: normal_procedures(module_)) { if (with_simd) { emit_simd_procedure_proto(out, proc, class_name); - auto simd_print = simdprint(proc->body()); + auto simd_print = simdprint(proc->body(), vars.scalars); out << " {\n" << indent << simd_print << popindent << "}\n\n"; emit_masked_simd_procedure_proto(out, proc, class_name); - auto masked_print = simdprint(proc->body()); + auto masked_print = simdprint(proc->body(), vars.scalars); masked_print.set_masked(); out << " {\n" << indent << masked_print << popindent << "}\n\n"; } else { @@ -553,9 +575,9 @@ void SimdPrinter::visit(LocalVariable* sym) { void SimdPrinter::visit(VariableExpression *sym) { if (sym->is_range()) { if(is_indirect_) - out_ << "simd_value(" << sym->name() << "+index_)"; + out_ << "simd_cast<simd_value>(indirect(" << sym->name() << "+index_, simd_width_))"; else - out_ << "simd_value(" << sym->name() << "+i_)"; + out_ << "simd_cast<simd_value>(indirect(" << sym->name() << "+i_, simd_width_))"; } else { out_ << sym->name(); @@ -568,26 +590,35 @@ void SimdPrinter::visit(AssignmentExpression* e) { } Symbol* lhs = e->lhs()->is_identifier()->symbol(); + + bool cast = false; + if (auto id = e->rhs()->is_identifier()) { + if (scalars_.count(id->name())) cast = true; + } + if (e->rhs()->is_number()) cast = true; + if (scalars_.count(e->lhs()->is_identifier()->name())) cast = false; if (lhs->is_variable() && lhs->is_variable()->is_range()) { - if (!input_mask_.empty()) - out_ << "S::where(" << input_mask_ << ", simd_value("; + if(is_indirect_) + out_ << "indirect(" << lhs->name() << "+index_, simd_width_) = "; else - out_ << "simd_value("; + out_ << "indirect(" << lhs->name() << "+i_, simd_width_) = "; + + if (!input_mask_.empty()) + out_ << "S::where(" << input_mask_ << ", "; + if (cast) out_ << "simd_cast<simd_value>("; e->rhs()->accept(this); + if (cast) out_ << ")"; if (!input_mask_.empty()) out_ << ")"; - - if(is_indirect_) - out_ << ").copy_to(" << lhs->name() << "+index_)"; - else - out_ << ").copy_to(" << lhs->name() << "+i_)"; } else { out_ << lhs->name() << " = "; + if (cast) out_ << "simd_cast<simd_value>("; e->rhs()->accept(this); + if (cast) out_ << ")"; } } @@ -637,7 +668,7 @@ void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const void emit_masked_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { out << "void " << qualified << (qualified.empty()? "": "::") << e->name() - << "(index_type i_, simd_value::simd_mask mask_input_"; + << "(index_type i_, simd_mask mask_input_"; for (auto& arg: e->args()) { out << ", const simd_value& " << arg->is_argument()->name(); } @@ -650,29 +681,29 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con if (local->is_read()) { auto d = decode_indexed_variable(local->external_variable()); if (d.scalar()) { - out << "(" << d.data_var + out << " = simd_cast<simd_value>(" << d.data_var << "[0]);\n"; } else if (constraint == simd_expr_constraint::contiguous) { - out << "(" << d.data_var + out << " = simd_cast<simd_value>(indirect(" << d.data_var << " + " << d.index_var - << "[index_]);\n"; + << "[index_], simd_width_));\n"; } else if (constraint == simd_expr_constraint::constant) { - out << "(" << d.data_var + out << " = simd_cast<simd_value>(" << d.data_var << "[" << d.index_var << "element0]);\n"; } else { - out << "(S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_));\n"; + out << " = simd_cast<simd_value>(indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", simd_width_, constraint_category_));\n"; } if (d.scale != 1) { - out << local->name() << " *= " << d.scale << ";\n"; + out << local->name() << " = S::mul(" << local->name() << ", simd_cast<simd_value>(" << d.scale << "));\n"; } } else { - out << " = 0;\n"; + out << " = simd_cast<simd_value>(0);\n"; } } @@ -690,40 +721,50 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex std::string tempvar = "t_"+external->name(); if (constraint == simd_expr_constraint::contiguous) { - out << "simd_value "<< tempvar <<"(" << d.data_var << " + " << d.index_var << "[index_]);\n" - << tempvar << " += w_*"; + out << "simd_value "<< tempvar <<" = simd_cast<simd_value>(indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_));\n"; - if (coeff!=1) out << as_c_double(coeff) << "*"; + if (coeff!=1) { + out << tempvar << " = S::fma(S::mul(w_, simd_cast<simd_value>(" + << as_c_double(coeff) << "))," + << from->name() << ", " << tempvar << ");\n"; + } else { + out << tempvar << " = S::fma(w_, " << from->name() << ", " << tempvar << ");\n"; + } - out << from->name() << ";\n" - << tempvar << ".copy_to(" << d.data_var << " + " << d.index_var << "[index_]);\n"; + out << "indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_) = " << tempvar << ";\n"; } else { - out << "S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_)" - << " += w_*"; + out << "indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", simd_width_, constraint_category_)"; - if (coeff!=1) out << as_c_double(coeff) << "*"; - - out << from->name() << ";\n"; + if (coeff!=1) { + out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" + << as_c_double(coeff) << "), " + << from->name() << "));\n"; + } else { + out << " += S::mul(w_, " << from->name() << ");\n"; + } } } else { if (constraint == simd_expr_constraint::contiguous) { + out << "indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_) = "; if (coeff!=1) { - out << "(" << as_c_double(coeff) << "*" << from->name() << ")"; + out << "(S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << ")," << from->name() << "));\n"; } else { - out << from->name(); + out << from->name() << ";\n"; } - out << ".copy_to(" << d.data_var << " + " << d.index_var << "[index_]);\n"; } else { - out << "S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_)" + out << "indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", simd_width_, constraint_category_)" << " = "; - if (coeff!=1) out << as_c_double(coeff) << "*"; - - out << from->name() << ";\n"; + if (coeff!=1) { + out << "(S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << ")," << from->name() << "));\n"; + } + else { + out << from->name() << ";\n"; + } } } } @@ -735,28 +776,33 @@ void emit_index_initialize(std::ostream& out, const std::unordered_set<std::stri break; case simd_expr_constraint::constant: for (auto& index: indices) { - out << "simd_index::scalar_type " << index << "element0 = " << index << "[index_];\n"; - out << index_i_name(index) << " = " << index << "element0;\n"; + out << "arb::fvm_index_type " << index << "element0 = " << index << "[index_];\n"; + out << index_i_name(index) << " = simd_cast<simd_index>(" << index << "element0);\n"; } break; case simd_expr_constraint::other: for (auto& index: indices) { - out << index_i_name(index) << ".copy_from(" << index << ".data() + index_);\n"; + out << index_i_name(index) << " = simd_cast<simd_index>(indirect(" << index << ".data() + index_, simd_width_));\n"; } break; } } -void emit_body_for_loop(std::ostream& out, BlockExpression* body, const std::vector<LocalVariable*>& indexed_vars, - const std::unordered_set<std::string>& indices, const simd_expr_constraint& read_constraint, - const simd_expr_constraint& write_constraint) { +void emit_body_for_loop( + std::ostream& out, + BlockExpression* body, + const std::vector<LocalVariable*>& indexed_vars, + const std::vector<VariableExpression*>& scalars, + const std::unordered_set<std::string>& indices, + const simd_expr_constraint& read_constraint, + const simd_expr_constraint& write_constraint) { emit_index_initialize(out, indices, read_constraint); for (auto& sym: indexed_vars) { emit_simd_state_read(out, sym, read_constraint); } - simdprint printer(body); + simdprint printer(body, scalars); printer.set_indirect_index(); out << printer; @@ -768,6 +814,7 @@ void emit_body_for_loop(std::ostream& out, BlockExpression* body, const std::vec void emit_for_loop_per_constraint(std::ostream& out, BlockExpression* body, const std::vector<LocalVariable*>& indexed_vars, + const std::vector<VariableExpression*>& scalars, bool requires_weight, const std::unordered_set<std::string>& indices, const simd_expr_constraint& read_constraint, @@ -781,21 +828,36 @@ void emit_for_loop_per_constraint(std::ostream& out, BlockExpression* body, out << "index_type index_ = index_constraints_." << underlying_constraint_name << "[i_];\n"; if (requires_weight) { - out << "simd_value w_(weight_+index_);\n"; + out << "simd_value w_ = simd_cast<simd_value>(indirect((weight_+index_), simd_width_));\n"; } - emit_body_for_loop(out, body, indexed_vars, indices, read_constraint, write_constraint); + emit_body_for_loop(out, body, indexed_vars, scalars, indices, read_constraint, write_constraint); out << popindent << "}\n"; } -void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_kind) { +void emit_simd_api_body(std::ostream& out, APIMethod* method, const std::vector<VariableExpression*>& scalars) { auto body = method->body(); auto indexed_vars = indexed_locals(method->scope()); bool requires_weight = false; std::vector<LocalVariable*> scalar_indexed_vars; std::unordered_set<std::string> indices; + + for (auto& s: body->is_block()->statements()) { + if (s->is_assignment()) { + for (auto& v: indexed_vars) { + if (s->is_assignment()->lhs()->is_identifier()->name() == v->external_variable()->name()) { + auto info = decode_indexed_variable(v->external_variable()); + if (info.accumulate) { + requires_weight = true; + } + break; + } + } + } + } + for (auto& sym: indexed_vars) { auto info = decode_indexed_variable(sym->external_variable()); if (!info.scalar()) { @@ -804,12 +866,9 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ else { scalar_indexed_vars.push_back(sym); } - if (info.accumulate) { - requires_weight = true; - } } - if (!body->statements().empty()) { + out << "assert(simd_width_ <= (unsigned)S::width(simd_cast<simd_value>(0)));\n"; if (!indices.empty()) { for (auto& index: indices) { out << "simd_index " << index_i_name(index) << ";\n"; @@ -821,21 +880,21 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ simd_expr_constraint constraint = simd_expr_constraint::contiguous; std::string underlying_constraint = "contiguous"; - emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, scalars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all independent simd_vectors constraint = simd_expr_constraint::other; underlying_constraint = "independent"; - emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, scalars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all simd_vectors that have no optimizing constraints constraint = simd_expr_constraint::other; underlying_constraint = "none"; - emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, scalars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all constant simd_vectors @@ -843,7 +902,7 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ simd_expr_constraint write_constraint = simd_expr_constraint::other; underlying_constraint = "constant"; - emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, read_constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, scalars, requires_weight, indices, read_constraint, write_constraint, underlying_constraint); } @@ -856,7 +915,7 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ out << "unsigned n_ = width_;\n\n" "for (unsigned i_ = 0; i_ < n_; i_ += simd_width_) {\n" << indent << - simdprint(body) << popindent << + simdprint(body, scalars) << popindent << "}\n"; } } diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp index 5b12957f..ca450eb2 100644 --- a/modcc/printer/cprinter.hpp +++ b/modcc/printer/cprinter.hpp @@ -57,6 +57,9 @@ public: void set_input_mask(std::string input_mask) { input_mask_ = input_mask; } + void save_scalar_names(const std::unordered_set<std::string>& scalars) { + scalars_ = scalars; + } void visit(BlockExpression*) override; void visit(CallExpression*) override; @@ -65,13 +68,14 @@ public: void visit(LocalVariable*) override; void visit(AssignmentExpression*) override; - void visit(NumberExpression* e) override { cexpr_emit(e, out_, this); } - void visit(UnaryExpression* e) override { cexpr_emit(e, out_, this); } - void visit(BinaryExpression* e) override { cexpr_emit(e, out_, this); } - void visit(IfExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, this); } + void visit(NumberExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, scalars_, this); } + void visit(UnaryExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, scalars_, this); } + void visit(BinaryExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, scalars_, this); } + void visit(IfExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, scalars_, this); } private: std::ostream& out_; std::string input_mask_; bool is_indirect_ = false; + std::unordered_set<std::string> scalars_; }; diff --git a/modcc/printer/simd.hpp b/modcc/printer/simd.hpp index a218842f..0b17a7f0 100644 --- a/modcc/printer/simd.hpp +++ b/modcc/printer/simd.hpp @@ -1,14 +1,18 @@ #pragma once +constexpr unsigned no_size = unsigned(-1); + struct simd_spec { - enum simd_abi { none, neon, avx, avx2, avx512, native, default_abi } abi = none; - unsigned width = 0; // zero => use `simd::native_width` to determine. + enum simd_abi { none, neon, avx, avx2, avx512, sve, native, default_abi } abi = none; + static constexpr unsigned default_width = 8; + unsigned size = no_size; // -1 => use `simd::native_width` to determine. + unsigned width = no_size; simd_spec() = default; - simd_spec(enum simd_abi a, unsigned w = 0): + simd_spec(enum simd_abi a, unsigned w = no_size): abi(a), width(w) { - if (width==0) { + if (width==no_size) { // Pick a width based on abi, if applicable. switch (abi) { case avx: @@ -24,5 +28,23 @@ struct simd_spec { default: ; } } + + switch (abi) { + case avx: + case avx2: + size = 4; + break; + case avx512: + size = 8; + break; + case neon: + size = 2; + break; + case sve: + size = 0; + break; + default: ; + } + } }; diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index fa8bd3ab..7db8dac6 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -267,26 +267,26 @@ TEST(CPrinter, proc_body_inlined) { TEST(SimdPrinter, simd_if_else) { std::vector<const char*> expected_procs = { "simd_value u;\n" - "simd_value::simd_mask mask_0_ = i > 2;\n" - "S::where(mask_0_,u) = 7;\n" - "S::where(!mask_0_,u) = 5;\n" - "S::where(!mask_0_,simd_value(42)).copy_to(s+i_);\n" - "simd_value(u).copy_to(s+i_);" + "simd_mask mask_0_ = S::cmp_gt(i, (double)2);\n" + "S::where(mask_0_,u) = (double)7;\n" + "S::where(S::logical_not(mask_0_),u) = (double)5;\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_not(mask_0_),simd_cast<simd_value>((double)42));\n" + "indirect(s+i_, simd_width_) = u;" , "simd_value u;\n" - "simd_value::simd_mask mask_1_ = i > 2;\n" - "S::where(mask_1_,u) = 7;\n" - "S::where(!mask_1_,u) = 5;\n" - "S::where(!mask_1_ && mask_input_,simd_value(42)).copy_to(s+i_);\n" - "S::where(mask_input_, simd_value(u)).copy_to(s+i_);" + "simd_mask mask_1_ = S::cmp_gt(i, (double)2);\n" + "S::where(mask_1_,u) = (double)7;\n" + "S::where(S::logical_not(mask_1_),u) = (double)5;\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_1_), mask_input_),simd_cast<simd_value>((double)42));\n" + "indirect(s+i_, simd_width_) = S::where(mask_input_, u);" , - "simd_value::simd_mask mask_2_ = simd_value(g+i_)>2;\n" - "simd_value::simd_mask mask_3_ = simd_value(g+i_)>3;\n" - "S::where(mask_2_&&mask_3_,i) = 0.;\n" - "S::where(mask_2_&&!mask_3_,i) = 1;\n" - "simd_value::simd_mask mask_4_ = simd_value(g+i_)<1;\n" - "S::where(!mask_2_&& mask_4_,simd_value(2)).copy_to(s+i_);\n" - "rates(i_, !mask_2_&&!mask_4_, i);" + "simd_mask mask_2_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)2);\n" + "simd_mask mask_3_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)3);\n" + "S::where(S::logical_and(mask_2_,mask_3_),i) = (double)0.;\n" + "S::where(S::logical_and(mask_2_,S::logical_not(mask_3_)),i) = (double)1;\n" + "simd_mask mask_4_ = S::cmp_lt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)1);\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_2_),mask_4_),simd_cast<simd_value>((double)2));\n" + "rates(i_, S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)), i);" }; Module m(io::read_all(DATADIR "/mod_files/test7.mod"), "test7.mod"); diff --git a/test/unit/test_partition_by_constraint.cpp b/test/unit/test_partition_by_constraint.cpp index 77fa0c72..2e8a29ad 100644 --- a/test/unit/test_partition_by_constraint.cpp +++ b/test/unit/test_partition_by_constraint.cpp @@ -13,7 +13,9 @@ using namespace arb; using iarray = multicore::iarray; -static constexpr unsigned simd_width_ = arb::simd::simd_abi::native_width<fvm_value_type>::value; +constexpr unsigned vector_length = (unsigned) simd::simd_abi::native_width<fvm_value_type>::value; +using simd_value_type = simd::simd<fvm_value_type, vector_length, simd::simd_abi::default_abi>; +const int simd_width_ = simd::width<simd_value_type>(); const int input_size_ = 1024; diff --git a/test/unit/test_simd.cpp b/test/unit/test_simd.cpp index 3d611b26..67ac6a1a 100644 --- a/test/unit/test_simd.cpp +++ b/test/unit/test_simd.cpp @@ -5,9 +5,10 @@ #include <random> #include <unordered_set> -#include <arbor/simd/simd.hpp> #include <arbor/simd/avx.hpp> #include <arbor/simd/neon.hpp> +#include <arbor/simd/sve.hpp> +#include <arbor/simd/simd.hpp> #include <arbor/util/compat.hpp> #include "common.hpp" @@ -203,21 +204,22 @@ TYPED_TEST_P(simd_value, copy_to_from_masked) { } simd s(buf1); - where(m1, s).copy_from(buf2); + where(m1, s) = indirect(buf2, N); EXPECT_TRUE(testing::indexed_eq_n(N, expected, s)); for (unsigned i = 0; i<N; ++i) { if (!mbuf2[i]) expected[i] = buf3[i]; } - where(m2, s).copy_to(buf3); + indirect(buf3, N) = where(m2, s); EXPECT_TRUE(testing::indexed_eq_n(N, expected, buf3)); for (unsigned i = 0; i<N; ++i) { expected[i] = mbuf2[i]? buf1[i]: buf4[i]; } - where(m2, simd(buf1)).copy_to(buf4); + simd b(buf1); + indirect(buf4, N) = where(m2, b); EXPECT_TRUE(testing::indexed_eq_n(N, expected, buf4)); } } @@ -875,7 +877,7 @@ typedef ::testing::Types< #ifdef __AVX512F__ simd<double, 8, simd_abi::avx512>, #endif -#if defined(__ARM_NEON) +#ifdef __ARM_NEON simd<double, 2, simd_abi::neon>, #endif @@ -921,7 +923,7 @@ TYPED_TEST_P(simd_indirect, gather) { fill_random(array, rng); fill_random(offset, rng, 0, (int)(buflen-1)); - simd s(indirect(array, simd_index(offset))); + simd s = simd_cast<simd>(indirect(array, simd_index(offset), N)); scalar test[N]; for (unsigned j = 0; j<N; ++j) { @@ -961,7 +963,8 @@ TYPED_TEST_P(simd_indirect, masked_gather) { simd s(original); simd_mask m(mask); - where(m, s).copy_from(indirect(array, simd_index(offset))); + + where(m, s) = indirect(array, simd_index(offset), N); EXPECT_TRUE(::testing::indexed_eq_n(N, test, s)); } @@ -995,7 +998,7 @@ TYPED_TEST_P(simd_indirect, scatter) { } simd s(values); - s.copy_to(indirect(array, simd_index(offset))); + indirect(array, simd_index(offset), N) = s; EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); } @@ -1033,7 +1036,7 @@ TYPED_TEST_P(simd_indirect, masked_scatter) { simd s(values); simd_mask m(mask); - where(m, s).copy_to(indirect(array, simd_index(offset))); + indirect(array, simd_index(offset), N) = where(m, s); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); @@ -1041,7 +1044,8 @@ TYPED_TEST_P(simd_indirect, masked_scatter) { array[j] = test[j]; } - where(m, simd(values)).copy_to(indirect(array, simd_index(offset))); + simd v(values); + indirect(array, simd_index(offset), N) = where(m, v); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); } @@ -1074,7 +1078,7 @@ TYPED_TEST_P(simd_indirect, add_and_subtract) { test[offset[j]] += values[j]; } - indirect(array, simd_index(offset)) += simd(values); + indirect(array, simd_index(offset), N) += simd(values); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); fill_random(offset, rng, 0, (int)(buflen-1)); @@ -1086,7 +1090,7 @@ TYPED_TEST_P(simd_indirect, add_and_subtract) { test[offset[j]] -= values[j]; } - indirect(array, simd_index(offset)) -= simd(values); + indirect(array, simd_index(offset), N) -= simd(values); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); } } @@ -1137,7 +1141,7 @@ TYPED_TEST_P(simd_indirect, constrained_add) { } while (!unique_elements(offset)); make_test_array(); - indirect(array, simd_index(offset), index_constraint::independent) += simd(values); + indirect(array, simd_index(offset), N, index_constraint::independent) += simd(values); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); @@ -1149,7 +1153,7 @@ TYPED_TEST_P(simd_indirect, constrained_add) { } make_test_array(); - indirect(array, simd_index(offset), index_constraint::contiguous) += simd(values); + indirect(array, simd_index(offset), N, index_constraint::contiguous) += simd(values); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); @@ -1168,7 +1172,7 @@ TYPED_TEST_P(simd_indirect, constrained_add) { } make_test_array(); - indirect(array, simd_index(offset), index_constraint::constant) += simd(values); + indirect(array, simd_index(offset), N, index_constraint::constant) += simd(values); EXPECT_TRUE(::testing::indexed_almost_eq_n(buflen, test, array)); @@ -1202,7 +1206,7 @@ typedef ::testing::Types< simd_and_index<simd<int, 8, simd_abi::avx512>, simd<int, 8, simd_abi::avx512>>, #endif -#if defined(__ARM_NEON) +#ifdef __ARM_NEON simd_and_index<simd<double, 2, simd_abi::neon>, simd<int, 2, simd_abi::neon>>, @@ -1288,7 +1292,7 @@ typedef ::testing::Types< simd_pair<simd<double, 8, simd_abi::avx512>, simd<int, 8, simd_abi::avx512>>, #endif -#if defined(__ARM_NEON) +#ifdef __ARM_NEON simd_pair<simd<double, 2, simd_abi::neon>, simd<int, 2, simd_abi::neon>>, #endif @@ -1298,3 +1302,525 @@ typedef ::testing::Types< > simd_casting_test_types; INSTANTIATE_TYPED_TEST_CASE_P(S, simd_casting, simd_casting_test_types); + + +// Sizeless simd types API tests + +template <typename T, typename V, unsigned N> +struct simd_t { + using simd_type = T; + using scalar_type = V; + static constexpr unsigned width = N; +}; + +template <typename T, typename I, typename M> +struct simd_types_t { + using simd_value = T; + using simd_index = I; + using simd_value_mask = M; +}; + +template <typename SI> +struct sizeless_api: public ::testing::Test {}; + +TYPED_TEST_CASE_P(sizeless_api); + +TYPED_TEST_P(sizeless_api, construct) { + using simd_value = typename TypeParam::simd_value::simd_type; + using scalar_value = typename TypeParam::simd_value::scalar_type; + + using simd_index = typename TypeParam::simd_index::simd_type; + using scalar_index = typename TypeParam::simd_index::scalar_type; + + constexpr unsigned N = TypeParam::simd_value::width; + + std::minstd_rand rng(1001); + + { + scalar_value a_in[N], a_out[N]; + fill_random(a_in, rng); + + simd_value av = simd_cast<simd_value>(indirect(a_in, N)); + + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, a_in, a_out)); + } + { + scalar_value a_in[2*N], b_in[N], a_out[N], exp_0[N], exp_1[2*N]; + fill_random(a_in, rng); + fill_random(b_in, rng); + + scalar_index idx[N]; + + auto make_test_indirect2simd = [&]() { + for (unsigned i = 0; i<N; ++i) { + exp_0[i] = a_in[idx[i]]; + } + }; + + auto make_test_simd2indirect = [&]() { + for (unsigned i = 0; i<2*N; ++i) { + exp_1[i] = a_in[i]; + } + for (unsigned i = 0; i<N; ++i) { + exp_1[idx[i]] = b_in[i]; + } + }; + + // Independent + for (unsigned i = 0; i < N; ++i) { + idx[i] = i*2; + } + simd_index idxv = simd_cast<simd_index>(indirect(idx, N)); + + make_test_indirect2simd(); + + simd_value av = simd_cast<simd_value>(indirect(a_in, idxv, N, index_constraint::independent)); + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, exp_0, a_out)); + + make_test_simd2indirect(); + + indirect(a_in, idxv, N, index_constraint::independent) = simd_cast<simd_value>(indirect(b_in, N)); + + EXPECT_TRUE(testing::indexed_eq_n(2*N, exp_1, a_in)); + + // contiguous + for (unsigned i = 0; i < N; ++i) { + idx[i] = i; + } + idxv = simd_cast<simd_index>(indirect(idx, N)); + + make_test_indirect2simd(); + + av = simd_cast<simd_value>(indirect(a_in, idxv, N, index_constraint::contiguous)); + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, exp_0, a_out)); + + make_test_simd2indirect(); + + indirect(a_in, idxv, N, index_constraint::contiguous) = simd_cast<simd_value>(indirect(b_in, N)); + + EXPECT_TRUE(testing::indexed_eq_n(2*N, exp_1, a_in)); + + // none + for (unsigned i = 0; i < N; ++i) { + idx[i] = i/2; + } + + idxv = simd_cast<simd_index>(indirect(idx, N)); + + make_test_indirect2simd(); + + av = simd_cast<simd_value>(indirect(a_in, idxv, N, index_constraint::none)); + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, exp_0, a_out)); + + make_test_simd2indirect(); + + indirect(a_in, idxv, N, index_constraint::none) = simd_cast<simd_value>(indirect(b_in, N)); + + EXPECT_TRUE(testing::indexed_eq_n(2*N, exp_1, a_in)); + + // constant + for (unsigned i = 0; i < N; ++i) { + idx[i] = 0; + } + + idxv = simd_cast<simd_index>(indirect(idx, N)); + + make_test_indirect2simd(); + + av = simd_cast<simd_value>(indirect(a_in, idxv, N, index_constraint::constant)); + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, exp_0, a_out)); + + make_test_simd2indirect(); + + indirect(a_in, idxv, N, index_constraint::constant) = simd_cast<simd_value>(indirect(b_in, N)); + + EXPECT_TRUE(testing::indexed_eq_n(2*N, exp_1, a_in)); + } +} + +TYPED_TEST_P(sizeless_api, where_exp) { + using simd_value = typename TypeParam::simd_value::simd_type; + using scalar_value = typename TypeParam::simd_value::scalar_type; + + using simd_index = typename TypeParam::simd_index::simd_type; + using scalar_index = typename TypeParam::simd_index::scalar_type; + + using mask_simd = typename TypeParam::simd_value_mask::simd_type; + + constexpr unsigned N = TypeParam::simd_value::width; + + std::minstd_rand rng(201); + + bool m[N]; + fill_random(m, rng); + mask_simd mv = simd_cast<mask_simd>(indirect(m, N)); + + scalar_value a[N], b[N], exp[N]; + fill_random(a, rng); + fill_random(b, rng); + + { + bool c[N]; + indirect(c, N) = mv; + EXPECT_TRUE(testing::indexed_eq_n(N, c, m)); + } + + // where = constant + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + + where(mv, av) = 42.3; + indirect(c, N) = av; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? 42.3 : a[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // where = simd + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + simd_value bv = simd_cast<simd_value>(indirect(b, N)); + + where(mv, av) = bv; + indirect(c, N) = av; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? b[i] : a[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // simd = where + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + simd_value bv = simd_cast<simd_value>(indirect(b, N)); + + simd_value cv = simd_cast<simd_value>(where(mv, add(av, bv))); + indirect(c, N) = cv; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? (a[i] + b[i]) : 0; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // where = indirect + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + + where(mv, av) = indirect(b, N); + indirect(c, N) = av; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? b[i] : a[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // indirect = where + { + scalar_value c[N]; + fill_random(c, rng); + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + + indirect(c, N) = where(mv, av); + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? a[i] : c[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + + indirect(c, N) = where(mv, neg(av)); + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? -a[i] : c[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // where = indirect indexed + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + + scalar_index idx[N]; + for (unsigned i =0; i<N; ++i) { + idx[i] = i/2; + } + simd_index idxv = simd_cast<simd_index>(indirect(idx, N)); + + where(mv, av) = indirect(b, idxv, N, index_constraint::none); + indirect(c, N) = av; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? b[idx[i]] : a[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // indirect indexed = where + { + scalar_value c[N]; + fill_random(c, rng); + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + simd_value bv = simd_cast<simd_value>(indirect(b, N)); + + scalar_index idx[N]; + for (unsigned i =0; i<N; ++i) { + idx[i] = i; + } + simd_index idxv = simd_cast<simd_index>(indirect(idx, N)); + + indirect(c, idxv, N, index_constraint::contiguous) = where(mv, av); + + for (unsigned i = 0; i<N; ++i) { + exp[idx[i]] = m[i]? a[i] : c[idx[i]]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + + indirect(c, idxv, N, index_constraint::contiguous) = where(mv, sub(av, bv)); + + for (unsigned i = 0; i<N; ++i) { + exp[idx[i]] = m[i]? a[i] - b[i] : c[idx[i]]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } +} + +TYPED_TEST_P(sizeless_api, arithmetic) { + using simd_value = typename TypeParam::simd_value::simd_type; + using scalar_value = typename TypeParam::simd_value::scalar_type; + + constexpr unsigned N = TypeParam::simd_value::width; + + std::minstd_rand rng(201); + + scalar_value a[N], b[N], c[N], expected[N]; + fill_random(a, rng); + fill_random(b, rng); + + bool m[N], expected_m[N]; + fill_random(m, rng); + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + simd_value bv = simd_cast<simd_value>(indirect(b, N)); + + // add + { + indirect(c, N) = add(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = a[i] + b[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // sub + { + indirect(c, N) = sub(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = a[i] - b[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // mul + { + indirect(c, N) = mul(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = a[i] * b[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // div + { + indirect(c, N) = div(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = a[i] / b[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // pow + { + indirect(c, N) = pow(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::pow(a[i], b[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // min + { + indirect(c, N) = min(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::min(a[i], b[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // max + { + indirect(c, N) = max(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::max(a[i], b[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // cmp_eq + { + indirect(m, N) = cmp_eq(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected_m[i] = a[i] == b[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, m, expected_m)); + } + // cmp_neq + { + indirect(m, N) = cmp_neq(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected_m[i] = a[i] != b[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, m, expected_m)); + } + // cmp_leq + { + indirect(m, N) = cmp_leq(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected_m[i] = a[i] <= b[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, m, expected_m)); + } + // cmp_geq + { + indirect(m, N) = cmp_geq(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected_m[i] = a[i] >= b[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, m, expected_m)); + } + // sum + { + auto s = sum(av); + scalar_value expected_sum = 0; + + for (unsigned i = 0; i<N; ++i) { + expected_sum += a[i]; + } + EXPECT_FLOAT_EQ(expected_sum, s); + } + // neg + { + indirect(c, N) = neg(av); + for (unsigned i = 0; i<N; ++i) { + expected[i] = -a[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // abs + { + indirect(c, N) = abs(av); + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::abs(a[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // exp + { + indirect(c, N) = exp(av); + for (unsigned i = 0; i<N; ++i) { + EXPECT_NEAR(std::exp(a[i]), c[i], 1e-6); + } + } + // expm1 + { + indirect(c, N) = expm1(av); + for (unsigned i = 0; i<N; ++i) { + EXPECT_NEAR(std::expm1(a[i]), c[i], 1e-6); + } + } + // exprelr + { + indirect(c, N) = exprelr(av); + for (unsigned i = 0; i<N; ++i) { + EXPECT_NEAR(a[i]/(std::expm1(a[i])), c[i], 1e-6); + } + } + // log + { + scalar_value l[N]; + int max_exponent = std::numeric_limits<scalar_value>::max_exponent; + fill_random(l, rng, -max_exponent*std::log(2.), max_exponent*std::log(2.)); + for (auto& x: l) { + x = std::exp(x); + // SIMD log implementation may treat subnormal as zero + if (std::fpclassify(x)==FP_SUBNORMAL) x = 0; + } + simd_value lv = simd_cast<simd_value>(indirect(l, N)); + + indirect(c, N) = log(lv); + + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::log(l[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + +} + +REGISTER_TYPED_TEST_CASE_P(sizeless_api, construct, where_exp, arithmetic); + +typedef ::testing::Types< + +#ifdef __AVX__ + simd_types_t< simd_t< simd<double, 4, simd_abi::avx>, double, 4>, + simd_t< simd<int, 4, simd_abi::avx>, int, 4>, + simd_t<simd_mask<double, 4, simd_abi::avx>, int, 4>>, +#endif +#ifdef __AVX2__ + simd_types_t< simd_t< simd<double, 4, simd_abi::avx2>, double, 4>, + simd_t< simd<int, 4, simd_abi::avx2>, int, 4>, + simd_t<simd_mask<double, 4, simd_abi::avx2>, int, 4>>, +#endif +#ifdef __AVX512F__ + simd_types_t< simd_t< simd<double, 8, simd_abi::avx512>, double, 8>, + simd_t< simd<int, 8, simd_abi::avx512>, int, 8>, + simd_t<simd_mask<double, 8, simd_abi::avx512>, int, 8>>, +#endif +#ifdef __ARM_NEON + simd_types_t< simd_t< simd<double, 2, simd_abi::neon>, double, 2>, + simd_t< simd<int, 2, simd_abi::neon>, int, 2>, + simd_t<simd_mask<double, 2, simd_abi::neon>, double, 2>>, +#endif +#ifdef __ARM_FEATURE_SVE + simd_types_t< simd_t< simd<double, 0, simd_abi::sve>, double, 4>, + simd_t< simd<int, 0, simd_abi::sve>, int, 4>, + simd_t<simd_mask<double, 0, simd_abi::sve>, bool, 4>>, + + simd_types_t< simd_t< simd<double, 0, simd_abi::sve>, double, 8>, + simd_t< simd<int, 0, simd_abi::sve>, int, 8>, + simd_t<simd_mask<double, 0, simd_abi::sve>, bool, 8>>, +#endif + simd_types_t< simd_t< simd<double, 8, simd_abi::default_abi>, double, 8>, + simd_t< simd<int, 8, simd_abi::default_abi>, int, 8>, + simd_t<simd_mask<double, 8, simd_abi::default_abi>, bool, 8>> +> sizeless_api_test_types; + +INSTANTIATE_TYPED_TEST_CASE_P(S, sizeless_api, sizeless_api_test_types); -- GitLab