diff --git a/arbor/backends/multicore/mechanism.cpp b/arbor/backends/multicore/mechanism.cpp index 87597207a70e29dcc0591e06792ff7b46255a352..a38be492b0775317ca497a9f3cbc016d09273bf1 100644 --- a/arbor/backends/multicore/mechanism.cpp +++ b/arbor/backends/multicore/mechanism.cpp @@ -28,9 +28,6 @@ namespace multicore { using util::make_range; using util::value_by_key; -constexpr unsigned simd_width = S::simd_abi::native_width<fvm_value_type>::value; - - // Copy elements from source sequence into destination sequence, // and fill the remaining elements of the destination sequence // with the given fill value. @@ -154,7 +151,7 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me copy_extend(pos_data.cv, node_index_, pos_data.cv.back()); copy_extend(pos_data.weight, make_range(data_.data(), data_.data()+width_padded_), 0); - index_constraints_ = make_constraint_partition(node_index_, width_, simd_width); + index_constraints_ = make_constraint_partition(node_index_, width_, simd_width()); if (mult_in_place_) { multiplicity_ = iarray(width_padded_, pad); @@ -176,7 +173,7 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me ion_index = iarray(width_padded_, pad); copy_extend(indices, ion_index, util::back(indices)); - arb_assert(compatible_index_constraints(node_index_, ion_index, simd_width)); + arb_assert(compatible_index_constraints(node_index_, ion_index, simd_width())); } } diff --git a/arbor/backends/multicore/mechanism.hpp b/arbor/backends/multicore/mechanism.hpp index 0d1357304e357f9b56f7f1c5736a6957b1f4b7d1..cf4a331cb9584992eac5245d08b8901c7847e69a 100644 --- a/arbor/backends/multicore/mechanism.hpp +++ b/arbor/backends/multicore/mechanism.hpp @@ -137,6 +137,10 @@ protected: virtual mechanism_ion_state_table ion_state_table() { return {}; } virtual mechanism_ion_index_table ion_index_table() { return {}; } + // Simd width used in mechanism. + + virtual unsigned simd_width() const { return 1; } + // Report raw size in bytes of mechanism object. virtual std::size_t object_sizeof() const = 0; diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index 10a1295d25b36a5601f9ca35e90e380124f526c1..8749ac6950a1eff51b938f0d68fded24d1aa5c7f 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -26,9 +26,10 @@ namespace arb { namespace multicore { -constexpr unsigned simd_width = simd::simd_abi::native_width<fvm_value_type>::value; -using simd_value_type = simd::simd<fvm_value_type, simd_width>; -using simd_index_type = simd::simd<fvm_index_type, simd_width>; +constexpr unsigned vector_length = (unsigned) simd::simd_abi::native_width<fvm_value_type>::value; +using simd_value_type = simd::simd<fvm_value_type, vector_length, simd::simd_abi::default_abi>; +using simd_index_type = simd::simd<fvm_index_type, vector_length, simd::simd_abi::default_abi>; +const int simd_width = simd::width<simd_value_type>(); // Pick alignment compatible with native SIMD width for explicitly // vectorized operations below. 
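+// Example of the value-free interface used below (illustrative sketch only; `src` and `dst` stand for hypothetical fvm_value_type pointers): +//     simd_value_type v; +//     simd::assign(v, simd::indirect(src, simd_width));     // vector load +//     simd::indirect(dst, simd_width) = simd::add(v, 1.0);  // vector store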
@@ -168,27 +169,38 @@ void shared_state::ions_init_concentration() { } void shared_state::update_time_to(fvm_value_type dt_step, fvm_value_type tmax) { + using simd::assign; + using simd::indirect; + using simd::add; + using simd::min; for (fvm_size_type i = 0; i<n_intdom; i+=simd_width) { - simd_value_type t(time.data()+i); - t = min(t+dt_step, simd_value_type(tmax)); - t.copy_to(time_to.data()+i); + simd_value_type t; + assign(t, indirect(time.data()+i, simd_width)); + t = min(add(t, dt_step), tmax); + indirect(time_to.data()+i, simd_width) = t; } } void shared_state::set_dt() { + using simd::assign; + using simd::indirect; + using simd::sub; for (fvm_size_type j = 0; j<n_intdom; j+=simd_width) { - simd_value_type t(time.data()+j); - simd_value_type t_to(time_to.data()+j); + simd_value_type t, t_to; + assign(t, indirect(time.data()+j, simd_width)); + assign(t_to, indirect(time_to.data()+j, simd_width)); - auto dt = t_to-t; - dt.copy_to(dt_intdom.data()+j); + auto dt = sub(t_to,t); + indirect(dt_intdom.data()+j, simd_width) = dt; } for (fvm_size_type i = 0; i<n_cv; i+=simd_width) { - simd_index_type intdom_idx(cv_to_intdom.data()+i); + simd_index_type intdom_idx; + assign(intdom_idx, indirect(cv_to_intdom.data()+i, simd_width)); - simd_value_type dt(simd::indirect(dt_intdom.data(), intdom_idx)); - dt.copy_to(dt_cv.data()+i); + simd_value_type dt; + assign(dt, indirect(dt_intdom.data(), intdom_idx, simd_width)); + indirect(dt_cv.data()+i, simd_width) = dt; } } diff --git a/arbor/include/arbor/simd/avx.hpp b/arbor/include/arbor/simd/avx.hpp index 31fcd5139909467bd825f945776c7ed8d0ce3c47..1dd39c9cfc78eff96bb4d40df5bc11e03ad6edff 100644 --- a/arbor/include/arbor/simd/avx.hpp +++ b/arbor/include/arbor/simd/avx.hpp @@ -74,7 +74,7 @@ struct avx_int4: implbase<avx_int4> { return _mm_cvtsi128_si32(a); } - static __m128i negate(const __m128i& a) { + static __m128i neg(const __m128i& a) { __m128i zero = _mm_setzero_si128(); return _mm_sub_epi32(zero, a); } @@ -163,7 +163,7 @@ struct avx_int4: implbase<avx_int4> { // bottom 4 bytes. 
__m128i s = _mm_setr_epi32(0x0c080400ul,0,0,0); - __m128i p = _mm_shuffle_epi8(negate(m), s); + __m128i p = _mm_shuffle_epi8(neg(m), s); std::memcpy(y, &p, 4); } @@ -172,7 +172,7 @@ struct avx_int4: implbase<avx_int4> { std::memcpy(&r, w, 4); __m128i s = _mm_setr_epi32(0x80808000ul, 0x80808001ul, 0x80808002ul, 0x80808003ul); - return negate(_mm_shuffle_epi8(r, s)); + return neg(_mm_shuffle_epi8(r, s)); } static __m128i max(const __m128i& a, const __m128i& b) { @@ -245,7 +245,7 @@ struct avx_double4: implbase<avx_double4> { return _mm_cvtsd_f64(_mm256_castpd256_pd128(a)); } - static __m256d negate(const __m256d& a) { + static __m256d neg(const __m256d& a) { return _mm256_sub_pd(zero(), a); } diff --git a/arbor/include/arbor/simd/avx512.hpp b/arbor/include/arbor/simd/avx512.hpp index 4f3a524e58caefc9f72663b6bd9211906f241699..0416d8a0d7a1979995624d2cb922e8e46aa127ec 100644 --- a/arbor/include/arbor/simd/avx512.hpp +++ b/arbor/include/arbor/simd/avx512.hpp @@ -98,7 +98,7 @@ struct avx512_mask8: implbase<avx512_mask8> { // max(a, b) a | b // min(a, b) a & b - static __mmask8 negate(const __mmask8& a) { + static __mmask8 neg(const __mmask8& a) { return a; } @@ -249,7 +249,7 @@ struct avx512_int8: implbase<avx512_int8> { return _mm_cvtsi128_si32(_mm512_castsi512_si128(a)); } - static __m512i negate(const __m512i& a) { + static __m512i neg(const __m512i& a) { return sub(_mm512_setzero_epi32(), a); } @@ -445,7 +445,7 @@ struct avx512_double8: implbase<avx512_double8> { return _mm_cvtsd_f64(_mm512_castpd512_pd128(a)); } - static __m512d negate(const __m512d& a) { + static __m512d neg(const __m512d& a) { return _mm512_sub_pd(_mm512_setzero_pd(), a); } diff --git a/arbor/include/arbor/simd/implbase.hpp b/arbor/include/arbor/simd/implbase.hpp index ff79e777eb955a6dcb1548b36fb825961daa7cbe..544846cd274f61334db8bd55e9aa3b2f4e48083a 100644 --- a/arbor/include/arbor/simd/implbase.hpp +++ b/arbor/include/arbor/simd/implbase.hpp @@ -14,9 +14,9 @@ // // Function | Default implemention by // ---------------------------------- -// min | negate, cmp_gt, ifelse -// max | negate, cmp_gt, ifelse -// abs | negate, max +// min | neg, cmp_gt, ifelse +// max | neg, cmp_gt, ifelse +// abs | neg, max // sin | lane-wise std::sin // cos | lane-wise std::cos // exp | lane-wise std::exp @@ -181,7 +181,7 @@ struct implbase { return I::copy_from(a); } - static vector_type negate(const vector_type& u) { + static vector_type neg(const vector_type& u) { store a, r; I::copy_to(u, a); diff --git a/arbor/include/arbor/simd/native.hpp b/arbor/include/arbor/simd/native.hpp index 8164ceaa427d87172e1cc832204225bdcbedb20a..1d945574e5285aa13cc3021274423ca6a20a70b9 100644 --- a/arbor/include/arbor/simd/native.hpp +++ b/arbor/include/arbor/simd/native.hpp @@ -65,7 +65,13 @@ ARB_DEF_NATIVE_SIMD_(double, 8, avx512) #endif -#if defined(__ARM_NEON) +#if defined(__ARM_FEATURE_SVE) + +#include "sve.hpp" +ARB_DEF_NATIVE_SIMD_(int, 0, sve) +ARB_DEF_NATIVE_SIMD_(double, 0, sve) + +#elif defined(__ARM_NEON) #include <arbor/simd/neon.hpp> ARB_DEF_NATIVE_SIMD_(int, 2, neon) @@ -73,7 +79,6 @@ ARB_DEF_NATIVE_SIMD_(double, 2, neon) #endif - namespace arb { namespace simd { namespace simd_abi { @@ -87,15 +92,15 @@ struct native_width; template <typename Value, int k> struct native_width { - static constexpr int value = - std::is_same<void, typename native<Value, k>::type>::value? - native_width<Value, k/2>::value: - k; + static constexpr int value = + std::is_same<void, typename native<Value, k>::type>::value? 
+ native_width<Value, k/2>::value: + k; }; template <typename Value> struct native_width<Value, 1> { - static constexpr int value = 1; + static constexpr int value = std::is_same<void, typename native<Value, 0>::type>::value; }; } // namespace simd_abi diff --git a/arbor/include/arbor/simd/neon.hpp b/arbor/include/arbor/simd/neon.hpp index f2710461581030314eb01cf044f93fcda332ded6..32d0701234cf7af842ed24cfafc55b4ddac60eed 100644 --- a/arbor/include/arbor/simd/neon.hpp +++ b/arbor/include/arbor/simd/neon.hpp @@ -68,7 +68,7 @@ struct neon_int2 : implbase<neon_int2> { return a; } - static int32x2_t negate(const int32x2_t& a) { return vneg_s32(a); } + static int32x2_t neg(const int32x2_t& a) { return vneg_s32(a); } static int32x2_t add(const int32x2_t& a, const int32x2_t& b) { return vadd_s32(a, b); @@ -222,7 +222,7 @@ struct neon_double2 : implbase<neon_double2> { return a; } - static float64x2_t negate(const float64x2_t& a) { return vnegq_f64(a); } + static float64x2_t neg(const float64x2_t& a) { return vnegq_f64(a); } static float64x2_t add(const float64x2_t& a, const float64x2_t& b) { return vaddq_f64(a, b); diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index 3e0fb573387a6c2e8f291cc1779ba04949b23cd9..58cb0540a649f2c4bcf3e3fa3aedf4eaecaa2c28 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -7,6 +7,7 @@ #include <arbor/simd/implbase.hpp> #include <arbor/simd/generic.hpp> #include <arbor/simd/native.hpp> +#include <arbor/util/pp_util.hpp> namespace arb { namespace simd { @@ -17,41 +18,350 @@ namespace detail { template <typename Impl> struct simd_mask_impl; + + template <typename To> + struct simd_cast_impl; + + template <typename I, typename V> + class indirect_indexed_expression; + + template <typename V> + class indirect_expression; + + template <typename T, typename M> + class where_expression; + + template <typename T, typename M> + class const_where_expression; +} + +// Top level functions for second API +using detail::simd_impl; +using detail::simd_mask_impl; + +template <typename Impl, typename V> +void assign(simd_impl<Impl>& a, const detail::indirect_expression<V>& b) { + a.copy_from(b); +} + +template <typename Impl, typename ImplIndex, typename V> +void assign(simd_impl<Impl>& a, const detail::indirect_indexed_expression<ImplIndex, V>& b) { + a.copy_from(b); +} + +template <typename Impl> +typename simd_impl<Impl>::scalar_type sum(const simd_impl<Impl>& a) { + return a.sum(); +}; + +#define ARB_UNARY_ARITHMETIC_(name)\ +template <typename Impl>\ +simd_impl<Impl> name(const simd_impl<Impl>& a) {\ + return simd_impl<Impl>::wrap(Impl::name(a.value_));\ +}; + +#define ARB_BINARY_ARITHMETIC_(name)\ +template <typename Impl>\ +simd_impl<Impl> name(const simd_impl<Impl>& a, simd_impl<Impl> b) {\ + return simd_impl<Impl>::wrap(Impl::name(a.value_, b.value_));\ +};\ +template <typename Impl>\ +simd_impl<Impl> name(const simd_impl<Impl>& a, typename simd_impl<Impl>::scalar_type b) {\ + return simd_impl<Impl>::wrap(Impl::name(a.value_, Impl::broadcast(b)));\ +};\ +template <typename Impl>\ +simd_impl<Impl> name(const typename simd_impl<Impl>::scalar_type a, simd_impl<Impl> b) {\ + return simd_impl<Impl>::wrap(Impl::name(Impl::broadcast(a), b.value_));\ +}; + +#define ARB_BINARY_COMPARISON_(name)\ +template <typename Impl>\ +typename simd_impl<Impl>::simd_mask name(const simd_impl<Impl>& a, simd_impl<Impl> b) {\ + return simd_impl<Impl>::mask(Impl::name(a.value_, b.value_));\ +};\ +template <typename Impl>\ +typename 
simd_impl<Impl>::simd_mask name(const simd_impl<Impl>& a, typename simd_impl<Impl>::scalar_type b) {\ + return simd_impl<Impl>::mask(Impl::name(a.value_, Impl::broadcast(b)));\ +};\ +template <typename Impl>\ +typename simd_impl<Impl>::simd_mask name(const typename simd_impl<Impl>::scalar_type a, simd_impl<Impl> b) {\ + return simd_impl<Impl>::mask(Impl::name(Impl::broadcast(a), b.value_));\ +}; + +ARB_PP_FOREACH(ARB_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) +ARB_PP_FOREACH(ARB_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_leq, cmp_lt, cmp_geq, cmp_gt) +ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr) + +#undef ARB_BINARY_ARITHMETIC_ +#undef ARB_BINARY_COMPARISON_ +#undef ARB_UNARY_ARITHMETIC_ + +template <typename T> +simd_mask_impl<T> logical_and(const simd_mask_impl<T>& a, simd_mask_impl<T> b) { + return a && b; +} + +template <typename T> +simd_mask_impl<T> logical_or(const simd_mask_impl<T>& a, simd_mask_impl<T> b) { + return a || b; +} + +template <typename T> +simd_mask_impl<T> logical_not(const simd_mask_impl<T>& a) { + return !a; +} + +template <typename T> +simd_impl<T> fma(const simd_impl<T> a, simd_impl<T> b, simd_impl<T> c) { + return simd_impl<T>::wrap(T::fma(a.value_, b.value_, c.value_)); } namespace detail { - template <typename Impl, typename V> - struct indirect_expression { + /// Indirect Expressions + template <typename V> + class indirect_expression { + public: + indirect_expression(V* p, unsigned width): p(p), width(width) {} + + indirect_expression& operator=(V s) { + for (unsigned i = 0; i < width; ++i) { + p[i] = s; + } + return *this; + } + + template <typename Other> + indirect_expression& operator=(const Other& s) { + indirect_copy_to(s, p, width); + return *this; + } + + template <typename Impl, typename ImplMask> + indirect_expression& operator=(const const_where_expression<Impl, ImplMask>& s) { + indirect_copy_to(s.data_, s.mask_, p, width); + return *this; + } + + template <typename Impl, typename ImplMask> + indirect_expression& operator=(const where_expression<Impl, ImplMask>& s) { + indirect_copy_to(s.data_, s.mask_, p, width); + return *this; + } + + template <typename Impl> friend struct simd_impl; + template <typename Impl> friend struct simd_mask_impl; + template <typename To> friend struct simd_cast_impl; + template <typename T, typename M> friend class where_expression; + + private: V* p; - typename simd_traits<Impl>::vector_type index; - index_constraint constraint; + unsigned width; + }; + + template <typename Impl, typename V> + static void indirect_copy_to(const simd_mask_impl<Impl>& s, V* p, unsigned width) { + Impl::mask_copy_to(s.value_, p); + } - indirect_expression() = default; - indirect_expression(V* p, const simd_impl<Impl>& index_simd, index_constraint constraint): - p(p), index(index_simd.value_), constraint(constraint) + template <typename Impl, typename V> + static void indirect_copy_to(const simd_impl<Impl>& s, V* p, unsigned width) { + Impl::copy_to(s.value_, p); + } + + template <typename Impl, typename ImplMask, typename V> + static void indirect_copy_to(const simd_impl<Impl>& data, const simd_mask_impl<ImplMask>& mask, V* p, unsigned width) { + Impl::copy_to_masked(data.value_, p, mask.value_); + } + + /// Indirect Indexed Expressions + template <typename ImplIndex, typename V> + class indirect_indexed_expression { + public: + indirect_indexed_expression(V* p, const ImplIndex& index_simd, unsigned width, index_constraint constraint): + p(p), index(index_simd), width(width),
constraint(constraint) {} - // Simple assignment included for consistency with compound assignment interface. + indirect_indexed_expression& operator=(V s) { + typename simd_traits<ImplIndex>::scalar_type idx[width]; + ImplIndex::copy_to(index.value_, idx); + for (unsigned i = 0; i < width; ++i) { + p[idx[i]] = s; + } + return *this; + } template <typename Other> - indirect_expression& operator=(const simd_impl<Other>& s) { - s.copy_to(*this); + indirect_indexed_expression& operator=(const Other& s) { + indirect_indexed_copy_to(s, p, index, width); return *this; } - // Compound assignment (currently only addition and subtraction!): + template <typename Impl, typename ImplMask> + indirect_indexed_expression& operator=(const const_where_expression<Impl, ImplMask>& s) { + indirect_indexed_copy_to(s.data_, s.mask_, p, index, width); + return *this; + } + + template <typename Impl, typename ImplMask> + indirect_indexed_expression& operator=(const where_expression<Impl, ImplMask>& s) { + indirect_indexed_copy_to(s.data_, s.mask_, p, index, width); + return *this; + } + + template <typename Other> + indirect_indexed_expression& operator+=(const Other& s) { + compound_indexed_add(s, p, index, width, constraint); + return *this; + } template <typename Other> - indirect_expression& operator+=(const simd_impl<Other>& s) { - simd_impl<Other>::compound_indexed_add(tag<Impl>{}, s.value_, p, index, constraint); + indirect_indexed_expression& operator-=(const Other& s) { + compound_indexed_add(neg(s), p, index, width, constraint); return *this; } + template <typename Impl> friend struct simd_impl; + template <typename To> friend struct simd_cast_impl; + template <typename T, typename M> friend class where_expression; + + private: + V* p; + const ImplIndex& index; + unsigned width; + index_constraint constraint; + }; + + template <typename Impl, typename ImplIndex, typename V> + static void indirect_indexed_copy_to(const simd_impl<Impl>& s, V* p, const simd_impl<ImplIndex>& index, unsigned width) { + Impl::scatter(tag<ImplIndex>{}, s.value_, p, index.value_); + } + + template <typename Impl, typename ImplIndex, typename ImplMask, typename V> + static void indirect_indexed_copy_to(const simd_impl<Impl>& data, const simd_mask_impl<ImplMask>& mask, V* p, const simd_impl<ImplIndex>& index, unsigned width) { + Impl::scatter(tag<ImplIndex>{}, data.value_, p, index.value_, mask.value_); + } + + template <typename ImplIndex, typename Impl, typename V> + static void compound_indexed_add( + const simd_impl<Impl>& s, + V* p, + const simd_impl<ImplIndex>& index, + unsigned width, + index_constraint constraint) + { + switch (constraint) { + case index_constraint::none: + { + typename ImplIndex::scalar_type o[width]; + ImplIndex::copy_to(index.value_, o); + + V a[width]; + Impl::copy_to(s.value_, a); + + V temp = 0; + for (unsigned i = 0; i<width-1; ++i) { + temp += a[i]; + if (o[i] != o[i+1]) { + p[o[i]] += temp; + temp = 0; + } + } + temp += a[width-1]; + p[o[width-1]] += temp; + } + break; + case index_constraint::independent: + { + auto v = Impl::add(Impl::gather(tag<ImplIndex>{}, p, index.value_), s.value_); + Impl::scatter(tag<ImplIndex>{}, v, p, index.value_); + } + break; + case index_constraint::contiguous: + { + p += ImplIndex::element0(index.value_); + auto v = Impl::add(Impl::copy_from(p), s.value_); + Impl::copy_to(v, p); + } + break; + case index_constraint::constant: + p += ImplIndex::element0(index.value_); + *p += Impl::reduce_add(s.value_); + break; + } + } + + /// Where Expressions + template 
<typename Impl, typename ImplMask> + class where_expression { + public: + where_expression(const ImplMask& m, Impl& s): + mask_(m), data_(s) {} + template <typename Other> - indirect_expression& operator-=(const simd_impl<Other>& s) { - simd_impl<Other>::compound_indexed_add(tag<Impl>{}, (-s).value_, p, index, constraint); + where_expression& operator=(const Other& v) { + where_copy_to(mask_, data_, v); + return *this; + } + + template <typename V> + where_expression& operator=(const indirect_expression<V>& v) { + where_copy_to(mask_, data_, v.p, v.width); + return *this; + } + + template <typename ImplIndex, typename V> + where_expression& operator=(const indirect_indexed_expression<ImplIndex, V>& v) { + where_copy_to(mask_, data_, v.p, v.index, v.width); return *this; } + + template <typename T> friend struct simd_impl; + template <typename To> friend struct simd_cast_impl; + template <typename V> friend class indirect_expression; + template <typename I, typename V> friend class indirect_indexed_expression; + + private: + const ImplMask& mask_; + Impl& data_; + }; + + template <typename Impl, typename ImplMask, typename V> + static void where_copy_to(const simd_mask_impl<ImplMask>& mask, simd_impl<Impl>& f, const V& t) { + f.value_ = Impl::ifelse(mask.value_, Impl::broadcast(t), f.value_); + } + + template <typename Impl, typename ImplMask> + static void where_copy_to(const simd_mask_impl<ImplMask>& mask, simd_impl<Impl>& f, const simd_impl<Impl>& t) { + f.value_ = Impl::ifelse(mask.value_, t.value_, f.value_); + } + + template <typename Impl, typename ImplMask, typename V> + static void where_copy_to(const simd_mask_impl<ImplMask>& mask, simd_impl<Impl>& f, const V* t, unsigned width) { + f.value_ = Impl::ifelse(mask.value_, Impl::copy_from_masked(t, mask.value_), f.value_); + } + + template <typename Impl, typename ImplIndex, typename ImplMask, typename V> + static void where_copy_to(const simd_mask_impl<ImplMask>& mask, simd_impl<Impl>& f, const V* p, const simd_impl<ImplIndex>& index, unsigned width) { + simd_impl<Impl> temp = Impl::broadcast(0); + temp.value_ = Impl::gather(tag<ImplIndex>{}, temp.value_, p, index.value_, mask.value_); + f.value_ = Impl::ifelse(mask.value_, temp.value_, f.value_); + } + + /// Const Where Expressions + template <typename Impl, typename ImplMask> + class const_where_expression { + public: + const_where_expression(const ImplMask& m, const Impl& s): + mask_(m), data_(s) {} + + template <typename T> friend struct simd_impl; + template <typename To> friend struct simd_cast_impl; + template <typename V> friend class indirect_expression; + template <typename I, typename V> friend class indirect_indexed_expression; + + private: + const ImplMask& mask_; + const Impl& data_; }; template <typename Impl> @@ -68,6 +378,7 @@ namespace detail { using scalar_type = typename simd_traits<Impl>::scalar_type; using simd_mask = simd_mask_impl<typename simd_traits<Impl>::mask_impl>; + using simd_base = Impl; protected: using vector_type = typename simd_traits<Impl>::vector_type; @@ -80,7 +391,10 @@ namespace detail { friend struct simd_impl; template <typename Other, typename V> - friend struct indirect_expression; + friend class indirect_indexed_expression; + + template <typename V> + friend class indirect_expression; simd_impl() = default; @@ -117,12 +431,12 @@ namespace detail { // Construct from indirect expression (gather). 
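+        // For example (illustrative sketch; `p` points at scalar_type data and `idx` is an index simd of matching width): +        //     simd_impl v(indirect(p, idx, width, index_constraint::independent));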
template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - explicit simd_impl(indirect_expression<IndexImpl, scalar_type> pi) { + explicit simd_impl(indirect_indexed_expression<IndexImpl, scalar_type> pi) { copy_from(pi); } template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - explicit simd_impl(indirect_expression<IndexImpl, const scalar_type> pi) { + explicit simd_impl(indirect_indexed_expression<IndexImpl, const scalar_type> pi) { copy_from(pi); } @@ -149,8 +463,9 @@ namespace detail { Impl::copy_to(value_, p); } - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { + template <typename Index, typename = std::enable_if_t<width==simd_traits<typename Index::simd_base>::width>> + void copy_to(indirect_indexed_expression<Index, scalar_type> pi) const { + using IndexImpl = typename Index::simd_base; Impl::scatter(tag<IndexImpl>{}, value_, pi.p, pi.index); } @@ -158,24 +473,26 @@ namespace detail { value_ = Impl::copy_from(p); } - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_from(indirect_expression<IndexImpl, scalar_type> pi) { + + template <typename Index, typename = std::enable_if_t<width==simd_traits<typename Index::simd_base>::width>> + void copy_from(indirect_indexed_expression<Index, scalar_type> pi) { + using IndexImpl = typename Index::simd_base; switch (pi.constraint) { case index_constraint::none: - value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index.value_); break; case index_constraint::independent: - value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index.value_); break; case index_constraint::contiguous: { - scalar_type* p = IndexImpl::element0(pi.index) + pi.p; + scalar_type* p = IndexImpl::element0(pi.index.value_) + pi.p; value_ = Impl::copy_from(p); } break; case index_constraint::constant: { - scalar_type* p = IndexImpl::element0(pi.index) + pi.p; + scalar_type* p = IndexImpl::element0(pi.index.value_) + pi.p; scalar_type l = (*p); value_ = Impl::broadcast(l); } @@ -183,80 +500,54 @@ namespace detail { } } - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_from(indirect_expression<IndexImpl, const scalar_type> pi) { + template <typename Index, typename = std::enable_if_t<width==simd_traits<typename Index::simd_base>::width>> + void copy_from(indirect_indexed_expression<Index, const scalar_type> pi) { + using IndexImpl = typename Index::simd_base; switch (pi.constraint) { case index_constraint::none: - value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index.value_); break; case index_constraint::independent: - value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index); + value_ = Impl::gather(tag<IndexImpl>{}, pi.p, pi.index.value_); break; case index_constraint::contiguous: { - const scalar_type* p = IndexImpl::element0(pi.index) + pi.p; + const scalar_type* p = IndexImpl::element0(pi.index.value_) + pi.p; value_ = Impl::copy_from(p); } break; case index_constraint::constant: { - const scalar_type *p = IndexImpl::element0(pi.index) + pi.p; + const scalar_type *p = IndexImpl::element0(pi.index.value_) + pi.p; scalar_type l = (*p); value_ = Impl::broadcast(l); } break; } + } + void 
copy_from(indirect_expression<scalar_type> pi) { + value_ = Impl::copy_from(pi.p); } - template <typename ImplIndex> - static void compound_indexed_add(tag<ImplIndex> tag, const vector_type& s, scalar_type* p, const typename ImplIndex::vector_type& index, index_constraint constraint) { - switch (constraint) { - case index_constraint::none: - { - typename ImplIndex::scalar_type o[width]; - ImplIndex::copy_to(index, o); - - scalar_type a[width]; - Impl::copy_to(s, a); - - scalar_type temp = 0; - for (unsigned i = 0; i<width-1; ++i) { - temp += a[i]; - if (o[i] != o[i+1]) { - p[o[i]] += temp; - temp = 0; - } - } - temp += a[width-1]; - p[o[width-1]] += temp; - } - break; - case index_constraint::independent: - { - vector_type v = Impl::add(Impl::gather(tag, p, index), s); - Impl::scatter(tag, v, p, index); - } - break; - case index_constraint::contiguous: - { - p += ImplIndex::element0(index); - vector_type v = Impl::add(Impl::copy_from(p), s); - Impl::copy_to(v, p); - } - break; - case index_constraint::constant: - p += ImplIndex::element0(index); - *p += Impl::reduce_add(s); - break; - } + void copy_from(indirect_expression<const scalar_type> pi) { + value_ = Impl::copy_from(pi.p); + } + + template <typename T, typename M> + void copy_from(const_where_expression<T, M> w) { + value_ = Impl::ifelse(w.mask_.value_, w.data_.value_, value_); } + template <typename T, typename M> + void copy_from(where_expression<T, M> w) { + value_ = Impl::ifelse(w.mask_.value_, w.data_.value_, value_); + } // Arithmetic operations: +, -, *, /, fma. simd_impl operator-() const { - return wrap(Impl::negate(value_)); + return simd_impl::wrap(Impl::neg(value_)); } friend simd_impl operator+(const simd_impl& a, simd_impl b) { @@ -364,115 +655,67 @@ namespace detail { return Impl::reduce_add(value_); } - // Masked assignment (via where expressions). - - struct where_expression { - where_expression(const where_expression&) = default; - where_expression& operator=(const where_expression&) = delete; - - where_expression(const simd_mask& m, simd_impl& s): - mask_(m), data_(s) {} - - where_expression& operator=(scalar_type v) { - data_.value_ = Impl::ifelse(mask_.value_, simd_impl(v).value_, data_.value_); - return *this; - } - - where_expression& operator=(const simd_impl& v) { - data_.value_ = Impl::ifelse(mask_.value_, v.value_, data_.value_); - return *this; - } - - void copy_to(scalar_type* p) const { - Impl::copy_to_masked(data_.value_, p, mask_.value_); - } - - void copy_from(const scalar_type* p) { - data_.value_ = Impl::copy_from_masked(data_.value_, p, mask_.value_); - } - - // Gather and scatter. 
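+        // (The masked gather and scatter operations formerly defined here are now provided by the free functions where_copy_to and indirect_indexed_copy_to above.)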
- - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_from(indirect_expression<IndexImpl, scalar_type> pi) { - data_.value_ = Impl::gather(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); - } - - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { - Impl::scatter(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); - } - - private: - const simd_mask& mask_; - simd_impl& data_; - }; + // Maths functions are implemented as top-level functions; declare as friends for access to `wrap` - struct const_where_expression { - const_where_expression(const const_where_expression&) = default; - const_where_expression& operator=(const const_where_expression&) = delete; + #define ARB_DECLARE_UNARY_ARITHMETIC_(name)\ + template <typename T>\ + friend simd_impl<T> arb::simd::name(const simd_impl<T>& a); - const_where_expression(const simd_mask& m, const simd_impl& s): - mask_(m), data_(s) {} + #define ARB_DECLARE_BINARY_ARITHMETIC_(name)\ + template <typename T>\ + friend simd_impl<T> arb::simd::name(const simd_impl<T>& a, simd_impl<T> b);\ + template <typename T>\ + friend simd_impl<T> arb::simd::name(const simd_impl<T>& a, typename simd_impl<T>::scalar_type b);\ + template <typename T>\ + friend simd_impl<T> arb::simd::name(const typename simd_impl<T>::scalar_type a, simd_impl<T> b); - void copy_to(scalar_type* p) const { - Impl::copy_to_masked(data_.value_, p, mask_.value_); - } - - template <typename IndexImpl, typename = std::enable_if_t<width==simd_traits<IndexImpl>::width>> - void copy_to(indirect_expression<IndexImpl, scalar_type> pi) const { - Impl::scatter(tag<IndexImpl>{}, data_.value_, pi.p, pi.index, mask_.value_); - } + #define ARB_DECLARE_BINARY_COMPARISON_(name)\ + template <typename T>\ + friend typename simd_impl<T>::simd_mask arb::simd::name(const simd_impl<T>& a, simd_impl<T> b);\ + template <typename T>\ + friend typename simd_impl<T>::simd_mask arb::simd::name(const simd_impl<T>& a, typename simd_impl<T>::scalar_type b);\ + template <typename T>\ + friend typename simd_impl<T>::simd_mask arb::simd::name(const typename simd_impl<T>::scalar_type a, simd_impl<T> b); - private: - const simd_mask& mask_; - const simd_impl& data_; - }; + ARB_PP_FOREACH(ARB_DECLARE_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) + ARB_PP_FOREACH(ARB_DECLARE_BINARY_COMPARISON_, cmp_eq, cmp_neq, cmp_lt, cmp_leq, cmp_gt, cmp_geq) + ARB_PP_FOREACH(ARB_DECLARE_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, exprelr) + #undef ARB_DECLARE_UNARY_ARITHMETIC_ + #undef ARB_DECLARE_BINARY_ARITHMETIC_ + #undef ARB_DECLARE_BINARY_COMPARISON_ - // Maths functions are implemented as top-level functions; declare as friends - // for access to `wrap` and to enjoy ADL, allowing implicit conversion from - // scalar_type in binary operation arguments.
+ template <typename T> + friend simd_impl<T> arb::simd::fma(const simd_impl<T> a, simd_impl<T> b, simd_impl<T> c); - friend simd_impl abs(const simd_impl& s) { - return simd_impl::wrap(Impl::abs(s.value_)); - } + // Declare Indirect/Indirect indexed/Where Expression copy function as friends - friend simd_impl sin(const simd_impl& s) { - return simd_impl::wrap(Impl::sin(s.value_)); - } + template <typename T, typename I, typename V> + friend void compound_indexed_add(const simd_impl<I>& s, V* p, const simd_impl<T>& index, unsigned width, index_constraint constraint); - friend simd_impl cos(const simd_impl& s) { - return simd_impl::wrap(Impl::cos(s.value_)); - } + template <typename I, typename V> + friend void indirect_copy_to(const simd_impl<I>& s, V* p, unsigned width); - friend simd_impl exp(const simd_impl& s) { - return simd_impl::wrap(Impl::exp(s.value_)); - } + template <typename T, typename M, typename V> + friend void indirect_copy_to(const simd_impl<T>& data, const simd_mask_impl<M>& mask, V* p, unsigned width); - friend simd_impl log(const simd_impl& s) { - return simd_impl::wrap(Impl::log(s.value_)); - } + template <typename T, typename I, typename V> + friend void indirect_indexed_copy_to(const simd_impl<T>& s, V* p, const simd_impl<I>& index, unsigned width); - friend simd_impl expm1(const simd_impl& s) { - return simd_impl::wrap(Impl::expm1(s.value_)); - } + template <typename T, typename I, typename M, typename V> + friend void indirect_indexed_copy_to(const simd_impl<T>& data, const simd_mask_impl<M>& mask, V* p, const simd_impl<I>& index, unsigned width); - friend simd_impl exprelr(const simd_impl& s) { - return simd_impl::wrap(Impl::exprelr(s.value_)); - } + template <typename T, typename M, typename V> + friend void where_copy_to(const simd_mask_impl<M>& mask, simd_impl<T>& f, const V& t); - friend simd_impl pow(const simd_impl& s, const simd_impl& t) { - return simd_impl::wrap(Impl::pow(s.value_, t.value_)); - } + template <typename T, typename M> + friend void where_copy_to(const simd_mask_impl<M>& mask, simd_impl<T>& f, const simd_impl<T>& t); - friend simd_impl min(const simd_impl& s, const simd_impl& t) { - return simd_impl::wrap(Impl::min(s.value_, t.value_)); - } + template <typename T, typename M, typename V> + friend void where_copy_to(const simd_mask_impl<M>& mask, simd_impl<T>& f, const V* p, unsigned width); - friend simd_impl max(const simd_impl& s, const simd_impl& t) { - return simd_impl::wrap(Impl::max(s.value_, t.value_)); - } + template <typename T, typename I, typename M, typename V> + friend void where_copy_to(const simd_mask_impl<M>& mask, simd_impl<T>& f, const V* p, const simd_impl<I>& index, unsigned width); protected: vector_type value_; @@ -541,6 +784,10 @@ namespace detail { value_ = Impl::mask_copy_from(y); } + void copy_from(indirect_expression<bool> pi) { + value_ = Impl::mask_copy_from(pi.p); + } + // Array subscript operations. 
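+    // For example (illustrative): simd_mask m(false); m[0] = true; bool b0 = m[0];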
struct reference { @@ -605,7 +852,33 @@ namespace detail { }; template <typename To> - struct simd_cast_impl {}; + struct simd_cast_impl; + + template <typename ImplTo> + struct simd_cast_impl<simd_mask_impl<ImplTo>> { + static constexpr unsigned N = simd_traits<ImplTo>::width; + using scalar_type = typename simd_traits<ImplTo>::scalar_type; + + template <typename ImplFrom, typename = std::enable_if_t<N==simd_traits<ImplFrom>::width>> + static simd_mask_impl<ImplTo> cast(const simd_mask_impl<ImplFrom>& v) { + return simd_mask_impl<ImplTo>(v); + } + + static simd_mask_impl<ImplTo> cast(const std::array<scalar_type, N>& a) { + return simd_mask_impl<ImplTo>(a.data()); + } + + static simd_mask_impl<ImplTo> cast(scalar_type a) { + simd_mask_impl<ImplTo> r = a; + return r; + } + + static simd_mask_impl<ImplTo> cast(const indirect_expression<bool>& a) { + simd_mask_impl<ImplTo> r; + r.copy_from(a); + return r; + } + }; template <typename ImplTo> struct simd_cast_impl<simd_impl<ImplTo>> { @@ -620,6 +893,32 @@ namespace detail { static simd_impl<ImplTo> cast(const std::array<scalar_type, N>& a) { return simd_impl<ImplTo>(a.data()); } + + static simd_impl<ImplTo> cast(scalar_type a) { + simd_impl<ImplTo> r = a; + return r; + } + + template <typename V> + static simd_impl<ImplTo> cast(const indirect_expression<V>& a) { + simd_impl<ImplTo> r; + r.copy_from(a); + return r; + } + + template <typename Impl, typename V> + static simd_impl<ImplTo> cast(const indirect_indexed_expression<Impl,V>& a) { + simd_impl<ImplTo> r; + r.copy_from(a); + return r; + } + + template <typename Impl, typename V> + static simd_impl<ImplTo> cast(const const_where_expression<Impl,V>& a) { + simd_impl<ImplTo> r = 0; + r.copy_from(a); + return r; + } }; template <typename V, std::size_t N> @@ -652,27 +951,17 @@ namespace simd_abi { }; } -template <typename Value, unsigned N, template <class, unsigned> class Abi = simd_abi::default_abi> -using simd = detail::simd_impl<typename Abi<Value, N>::type>; - -template <typename Value, unsigned N> -using simd_mask = typename simd<Value, N>::simd_mask; +template <typename Value, unsigned N, template <class, unsigned> class Abi> +struct simd_wrap { using type = detail::simd_impl<typename Abi<Value, N>::type>; }; -template <typename Simd> -using where_expression = typename Simd::where_expression; +template <typename Value, unsigned N, template <class, unsigned> class Abi> +using simd = typename simd_wrap<Value, N, Abi>::type; -template <typename Simd> -using const_where_expression = typename Simd::const_where_expression; +template <typename Value, unsigned N, template <class, unsigned> class Abi> +struct simd_mask_wrap { using type = typename simd<Value, N, Abi>::simd_mask; }; -template <typename Simd> -where_expression<Simd> where(const typename Simd::simd_mask& m, Simd& v) { - return where_expression<Simd>(m, v); -} - -template <typename Simd> -const_where_expression<Simd> where(const typename Simd::simd_mask& m, const Simd& v) { - return const_where_expression<Simd>(m, v); -} +template <typename Value, unsigned N, template <class, unsigned> class Abi> +using simd_mask = typename simd_mask_wrap<Value, N, Abi>::type; template <typename> struct is_simd: std::false_type {}; @@ -688,6 +977,11 @@ To simd_cast(const From& s) { return detail::simd_cast_impl<To>::cast(s); } +template <typename S, std::enable_if_t<is_simd<S>::value, int> = 0> +inline constexpr int width(const S a = S{}) { + return S::width; +}; + // Gather/scatter indexed memory specification. 
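+// For example (illustrative sketch; `p` is a double*, `idx` an index simd, `w` the lane count, and `simd_t` a hypothetical simd<double, N, Abi> alias): +//     auto v = simd_cast<simd_t>(indirect(p, w));                                       // contiguous load +//     auto g = simd_cast<simd_t>(indirect(p, idx, w, index_constraint::independent));   // gather +//     indirect(p, idx, w) += g;                                                         // indexed accumulation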
template < @@ -695,14 +989,35 @@ template < typename PtrLike, typename V = std::remove_reference_t<decltype(*std::declval<PtrLike>())> > -detail::indirect_expression<IndexImpl, V> indirect( +detail::indirect_indexed_expression<IndexImpl, V> indirect( PtrLike p, - const detail::simd_impl<IndexImpl>& index, + const IndexImpl& index, + unsigned width, index_constraint constraint = index_constraint::none) { - return detail::indirect_expression<IndexImpl, V>(p, index, constraint); + return detail::indirect_indexed_expression<IndexImpl, V>(p, index, width, constraint); +} + +template < + typename PtrLike, + typename V = std::remove_reference_t<decltype(*std::declval<PtrLike>())> +> +detail::indirect_expression<V> indirect( + PtrLike p, + unsigned width) +{ + return detail::indirect_expression<V>(p, width); } +template <typename Impl, typename ImplMask> +detail::where_expression<Impl, ImplMask> where(const ImplMask& m, Impl& v) { + return detail::where_expression<Impl, ImplMask>(m, v); +} + +template <typename Impl, typename ImplMask> +detail::const_where_expression<Impl, ImplMask> where(const ImplMask& m, const Impl& v) { + return detail::const_where_expression<Impl, ImplMask>(m, v); +} } // namespace simd } // namespace arb diff --git a/arbor/include/arbor/simd/sve.hpp b/arbor/include/arbor/simd/sve.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7a5c0ec9d59bf7c2beae5a638421306b1196e19a --- /dev/null +++ b/arbor/include/arbor/simd/sve.hpp @@ -0,0 +1,904 @@ +#pragma once + +// SVE SIMD intrinsics implementation. + +#ifdef __ARM_FEATURE_SVE + +#include <arm_sve.h> +#include <cmath> +#include <cstdint> +#include <iostream> + +#include <arbor/util/pp_util.hpp> + +#include "approx.hpp" + +namespace arb { +namespace simd { +namespace detail { + +struct sve_double; +struct sve_int; +struct sve_mask; + +template<typename Type> struct sve_type_to_impl; +template<> struct sve_type_to_impl<svint64_t> { using type = detail::sve_int;}; +template<> struct sve_type_to_impl<svfloat64_t> { using type = detail::sve_double;}; +template<> struct sve_type_to_impl<svbool_t> { using type = detail::sve_mask;}; + +template<typename> struct is_sve : std::false_type {}; +template<> struct is_sve<svint64_t> : std::true_type {}; +template<> struct is_sve<svfloat64_t> : std::true_type {}; +template<> struct is_sve<svbool_t> : std::true_type {}; + +template <> +struct simd_traits<sve_mask> { + static constexpr unsigned width = 8; + using scalar_type = bool; + using vector_type = svbool_t; + using mask_impl = sve_mask; +}; + +template <> +struct simd_traits<sve_double> { + static constexpr unsigned width = 8; + using scalar_type = double; + using vector_type = svfloat64_t; + using mask_impl = sve_mask; +}; + +template <> +struct simd_traits<sve_int> { + static constexpr unsigned width = 8; + using scalar_type = int32_t; + using vector_type = svint64_t; + using mask_impl = sve_mask; +}; + +struct sve_mask { + static svbool_t broadcast(bool b) { + return svdup_b64(-b); + } + + static void copy_to(const svbool_t& k, bool* b) { + svuint64_t a = svdup_u64_z(k, 1); + svst1b_u64(svptrue_b64(), reinterpret_cast<uint8_t*>(b), a); + } + + static void copy_to_masked(const svbool_t& k, bool* b, const svbool_t& mask) { + svuint64_t a = svdup_u64_z(k, 1); + svst1b_u64(mask, reinterpret_cast<uint8_t*>(b), a); + } + + static svbool_t copy_from(const bool* p) { + svuint64_t a = svld1ub_u64(svptrue_b64(), reinterpret_cast<const uint8_t*>(p)); + svuint64_t ones = svdup_n_u64(1); + return svcmpeq_u64(svptrue_b64(), a, 
ones); + } + + static svbool_t copy_from_masked(const bool* p, const svbool_t& mask) { + svuint64_t a = svld1ub_u64(mask, reinterpret_cast<const uint8_t*>(p)); + svuint64_t ones = svdup_n_u64(1); + return svcmpeq_u64(mask, a, ones); + } + + static svbool_t logical_not(const svbool_t& k, const svbool_t& mask = svptrue_b64()) { + return svnot_b_z(mask, k); + } + + static svbool_t logical_and(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svand_b_z(mask, a, b); + } + + static svbool_t logical_or(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svorr_b_z(mask, a, b); + } + + // Arithmetic operations not necessarily appropriate for + // packed bit mask, but implemented for completeness/testing, + // with Z modulo 2 semantics: + // a + b is equivalent to a ^ b + // a * b a & b + // a / b a + // a - b a ^ b + // -a a + // max(a, b) a | b + // min(a, b) a & b + + static svbool_t neg(const svbool_t& a) { + return a; + } + + static svbool_t add(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return sveor_b_z(mask, a, b); + } + + static svbool_t sub(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return sveor_b_z(mask, a, b); + } + + static svbool_t mul(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svand_b_z(mask, a, b); + } + + static svbool_t div(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return a; + } + + static svbool_t fma(const svbool_t& a, const svbool_t& b, const svbool_t& c, const svbool_t& mask = svptrue_b64()) { + return add(mul(a, b, mask), c, mask); + } + + static svbool_t max(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svorr_b_z(mask, a, b); + } + + static svbool_t min(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svand_b_z(mask, a, b); + } + + // Comparison operators are also taken as operating on Z modulo 2, + // with 1 > 0: + // + // a > b is equivalent to a & ~b + // a >= b a | ~b, ~(~a & b) + // a < b ~a & b + // a <= b ~a | b, ~(a & ~b) + // a == b ~(a ^ b) + // a != b a ^ b + + static svbool_t cmp_eq(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svnot_b_z(mask, sveor_b_z(mask, a, b)); + } + + static svbool_t cmp_neq(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return sveor_b_z(mask, a, b); + } + + static svbool_t cmp_lt(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return svbic_b_z(mask, b, a); + } + + static svbool_t cmp_gt(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return cmp_lt(b, a); + } + + static svbool_t cmp_geq(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return logical_not(cmp_lt(a, b)); + } + + static svbool_t cmp_leq(const svbool_t& a, const svbool_t& b, const svbool_t& mask = svptrue_b64()) { + return logical_not(cmp_gt(a, b)); + } + + static svbool_t ifelse(const svbool_t& m, const svbool_t& u, const svbool_t& v) { + return svsel_b(m, u, v); + } + + static svbool_t mask_broadcast(bool b) { + return broadcast(b); + } + + static void mask_copy_to(const svbool_t& m, bool* y) { + copy_to(m, y); + } + + static svbool_t mask_copy_from(const bool* y) { + return copy_from(y); + } + + static svbool_t true_mask(unsigned width) { + return svwhilelt_b64_u64(0, (uint64_t)width); + } +}; + +struct 
sve_int { + // Use default implementations for: + // element, set_element. + + using int32 = std::int32_t; + + static svint64_t broadcast(int32 v) { + return svreinterpret_s64_s32(svdup_n_s32(v)); + } + + static void copy_to(const svint64_t& v, int32* p) { + svst1w_s64(svptrue_b64(), p, v); + } + + static void copy_to_masked(const svint64_t& v, int32* p, const svbool_t& mask) { + svst1w_s64(mask, p, v); + } + + static svint64_t copy_from(const int32* p) { + return svld1sw_s64(svptrue_b64(), p); + } + + static svint64_t copy_from_masked(const int32* p, const svbool_t& mask) { + return svld1sw_s64(mask, p); + } + + static svint64_t copy_from_masked(const svint64_t& v, const int32* p, const svbool_t& mask) { + return svsel_s64(mask, svld1sw_s64(mask, p), v); + } + + static int element0(const svint64_t& a) { + return svlasta_s64(svptrue_b64(), a); + } + + static svint64_t neg(const svint64_t& a, const svbool_t& mask = svptrue_b64()) { + return svneg_s64_z(mask, a); + } + + static svint64_t add(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svadd_s64_z(mask, a, b); + } + + static svint64_t sub(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svsub_s64_m(mask, a, b); + } + + static svint64_t mul(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + //May overflow + return svmul_s64_z(mask, a, b); + } + + static svint64_t div(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svdiv_s64_z(mask, a, b); + } + + static svint64_t fma(const svint64_t& a, const svint64_t& b, const svint64_t& c, const svbool_t& mask = svptrue_b64()) { + return add(mul(a, b, mask), c, mask); + } + + static svbool_t cmp_eq(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpeq_s64(mask, a, b); + } + + static svbool_t cmp_neq(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpne_s64(mask, a, b); + } + + static svbool_t cmp_gt(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpgt_s64(mask, a, b); + } + + static svbool_t cmp_geq(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpge_s64(mask, a, b); + } + + static svbool_t cmp_lt(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmplt_s64(mask, a, b); + } + + static svbool_t cmp_leq(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmple_s64(mask, a, b); + } + + static svint64_t ifelse(const svbool_t& m, const svint64_t& u, const svint64_t& v) { + return svsel_s64(m, u, v); + } + + static svint64_t max(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmax_s64_x(mask, a, b); + } + + static svint64_t min(const svint64_t& a, const svint64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmin_s64_x(mask, a, b); + } + + static svint64_t abs(const svint64_t& a, const svbool_t& mask = svptrue_b64()) { + return svabs_s64_z(mask, a); + } + + static int reduce_add(const svint64_t& a, const svbool_t& mask = svptrue_b64()) { + return svaddv_s64(mask, a); + } + + static svint64_t pow(const svint64_t& x, const svint64_t& y, const svbool_t& mask = svptrue_b64()) { + auto len = svlen_s64(x); + int32 a[len], b[len], r[len]; + copy_to_masked(x, a, mask); + copy_to_masked(y, b, mask); + + for (unsigned i = 0; i<len; ++i) { + r[i] = std::pow(a[i], b[i]); + 
} + return copy_from_masked(r, mask); + } + + static svint64_t gather(tag<sve_int>, const int32* p, const svint64_t& index, const svbool_t& mask = svptrue_b64()) { + return svld1sw_gather_s64index_s64(mask, p, index); + } + + static svint64_t gather(tag<sve_int>, svint64_t a, const int32* p, const svint64_t& index, const svbool_t& mask) { + return svsel_s64(mask, svld1sw_gather_s64index_s64(mask, p, index), a); + } + + static void scatter(tag<sve_int>, const svint64_t& s, int32* p, const svint64_t& index, const svbool_t& mask = svptrue_b64()) { + svst1w_scatter_s64index_s64(mask, p, index, s); + } + + static unsigned simd_width(const svint64_t& m) { + return svlen_s64(m); + } +}; + +struct sve_double { + // Use default implementations for: + // element, set_element. + + static svfloat64_t broadcast(double v) { + return svdup_n_f64(v); + } + + static void copy_to(const svfloat64_t& v, double* p) { + svst1_f64(svptrue_b64(), p, v); + } + + static void copy_to_masked(const svfloat64_t& v, double* p, const svbool_t& mask) { + svst1_f64(mask, p, v); + } + + static svfloat64_t copy_from(const double* p) { + return svld1_f64(svptrue_b64(), p); + } + + static svfloat64_t copy_from_masked(const double* p, const svbool_t& mask) { + return svld1_f64(mask, p); + } + + static svfloat64_t copy_from_masked(const svfloat64_t& v, const double* p, const svbool_t& mask) { + return svsel_f64(mask, svld1_f64(mask, p), v); + } + + static double element0(const svfloat64_t& a) { + return svlasta_f64(svptrue_b64(), a); + } + + static svfloat64_t neg(const svfloat64_t& a, const svbool_t& mask = svptrue_b64()) { + return svneg_f64_z(mask, a); + } + + static svfloat64_t add(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svadd_f64_z(mask, a, b); + } + + static svfloat64_t sub(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svsub_f64_z(mask, a, b); + } + + static svfloat64_t mul(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmul_f64_z(mask, a, b); + } + + static svfloat64_t div(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svdiv_f64_z(mask, a, b); + } + + static svfloat64_t fma(const svfloat64_t& a, const svfloat64_t& b, const svfloat64_t& c, const svbool_t& mask = svptrue_b64()) { + return svmad_f64_z(mask, a, b, c); + } + + static svbool_t cmp_eq(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpeq_f64(mask, a, b); + } + + static svbool_t cmp_neq(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpne_f64(mask, a, b); + } + + static svbool_t cmp_gt(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpgt_f64(mask, a, b); + } + + static svbool_t cmp_geq(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmpge_f64(mask, a, b); + } + + static svbool_t cmp_lt(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmplt_f64(mask, a, b); + } + + static svbool_t cmp_leq(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svcmple_f64(mask, a, b); + } + + static svfloat64_t ifelse(const svbool_t& m, const svfloat64_t& u, const svfloat64_t& v) { + return svsel_f64(m, u, v); + } + + static svfloat64_t max(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { 
+ return svmax_f64_x(mask, a, b); + } + + static svfloat64_t min(const svfloat64_t& a, const svfloat64_t& b, const svbool_t& mask = svptrue_b64()) { + return svmin_f64_x(mask, a, b); + } + + static svfloat64_t abs(const svfloat64_t& x, const svbool_t& mask = svptrue_b64()) { + return svabs_f64_x(mask, x); + } + + static double reduce_add(const svfloat64_t& a, const svbool_t& mask = svptrue_b64()) { + return svaddv_f64(mask, a); + } + + static svfloat64_t gather(tag<sve_int>, const double* p, const svint64_t& index, const svbool_t& mask = svptrue_b64()) { + return svld1_gather_s64index_f64(mask, p, index); + } + + static svfloat64_t gather(tag<sve_int>, svfloat64_t a, const double* p, const svint64_t& index, const svbool_t& mask) { + return svsel_f64(mask, svld1_gather_s64index_f64(mask, p, index), a); + } + + static void scatter(tag<sve_int>, const svfloat64_t& s, double* p, const svint64_t& index, const svbool_t& mask = svptrue_b64()) { + svst1_scatter_s64index_f64(mask, p, index, s); + } + + // Refer to avx/avx2 code for details of the exponential and log + // implementations. + + static svfloat64_t exp(const svfloat64_t& x) { + // Masks for exceptional cases. + + auto is_large = cmp_gt(x, broadcast(exp_maxarg)); + auto is_small = cmp_lt(x, broadcast(exp_minarg)); + + // Compute n and g. + + auto n = svrintz_f64_z(svptrue_b64(), add(mul(broadcast(ln2inv), x), broadcast(0.5))); + + auto g = fma(n, broadcast(-ln2C1), x); + g = fma(n, broadcast(-ln2C2), g); + + auto gg = mul(g, g); + + // Compute g*P(g^2) and Q(g^2). + auto odd = mul(g, horner(gg, P0exp, P1exp, P2exp)); + auto even = horner(gg, Q0exp, Q1exp, Q2exp, Q3exp); + + // Compute R(g)/R(-g) = 1 + 2*g*P(g^2) / (Q(g^2)-g*P(g^2)) + + auto expg = fma(broadcast(2), div(odd, sub(even, odd)), broadcast(1)); + + // Scale by 2^n, propagating NANs. + + auto result = svscale_f64_z(svptrue_b64(), expg, svcvt_s64_f64_z(svptrue_b64(), n)); + + return + ifelse(is_large, broadcast(HUGE_VAL), + ifelse(is_small, broadcast(0), + result)); + } + + static svfloat64_t expm1(const svfloat64_t& x) { + auto is_large = cmp_gt(x, broadcast(exp_maxarg)); + auto is_small = cmp_lt(x, broadcast(expm1_minarg)); + + auto half = broadcast(0.5); + auto one = broadcast(1.); + + auto nnz = cmp_gt(abs(x), half); + auto n = svrinta_f64_z(nnz, mul(broadcast(ln2inv), x)); + + auto g = fma(n, broadcast(-ln2C1), x); + g = fma(n, broadcast(-ln2C2), g); + + auto gg = mul(g, g); + + auto odd = mul(g, horner(gg, P0exp, P1exp, P2exp)); + auto even = horner(gg, Q0exp, Q1exp, Q2exp, Q3exp); + + // Compute R(g)/R(-g) -1 = 2*g*P(g^2) / (Q(g^2)-g*P(g^2)) + + auto expgm1 = div(mul(broadcast(2), odd), sub(even, odd)); + + // For small x (n zero), bypass scaling step to avoid underflow. + // Otherwise, compute result 2^n * expgm1 + (2^n-1) by: + // result = 2 * ( 2^(n-1)*expgm1 + (2^(n-1)+0.5) ) + // to avoid overflow when n=1024. + + auto nm1 = svcvt_s64_f64_z(svptrue_b64(), sub(n, one)); + + auto result = + svscale_f64_z(svptrue_b64(), + add(sub(svscale_f64_z(svptrue_b64(),one, nm1), half), + svscale_f64_z(svptrue_b64(),expgm1, nm1)), + svcvt_s64_f64_z(svptrue_b64(), one)); + + return + ifelse(is_large, broadcast(HUGE_VAL), + ifelse(is_small, broadcast(-1), + ifelse(nnz, result, expgm1))); + } + + static svfloat64_t exprelr(const svfloat64_t& x) { + auto ones = broadcast(1); + return ifelse(cmp_eq(ones, add(ones, x)), ones, div(x, expm1(x))); + } + + static svfloat64_t log(const svfloat64_t& x) { + // Masks for exceptional cases.
+ + auto is_large = cmp_geq(x, broadcast(HUGE_VAL)); + auto is_small = cmp_lt(x, broadcast(log_minarg)); + auto is_domainerr = cmp_lt(x, broadcast(0)); + + auto is_nan = svnot_b_z(svptrue_b64(), cmp_eq(x, x)); + is_domainerr = svorr_b_z(svptrue_b64(), is_nan, is_domainerr); + + svfloat64_t g = svcvt_f64_s32_z(svptrue_b64(), logb_normal(x)); + svfloat64_t u = fraction_normal(x); + + svfloat64_t one = broadcast(1.); + svfloat64_t half = broadcast(0.5); + // Here x = 2^g·u with u in [1,2); if u >= sqrt(2), fold one factor of two into g so that z = u-1 below stays small. + auto gtsqrt2 = cmp_geq(u, broadcast(sqrt2)); + g = ifelse(gtsqrt2, add(g, one), g); + u = ifelse(gtsqrt2, mul(u, half), u); + + auto z = sub(u, one); + auto pz = horner(z, P0log, P1log, P2log, P3log, P4log, P5log); + auto qz = horner1(z, Q0log, Q1log, Q2log, Q3log, Q4log); + + auto z2 = mul(z, z); + auto z3 = mul(z2, z); + + auto r = div(mul(z3, pz), qz); + r = add(r, mul(g, broadcast(ln2C4))); + r = sub(r, mul(z2, half)); + r = add(r, z); + r = add(r, mul(g, broadcast(ln2C3))); + + // r is already NaN if x is NaN or negative, otherwise + // return +inf if x is +inf, or -inf if zero or (positive) denormal. + + return ifelse(is_domainerr, broadcast(NAN), + ifelse(is_large, broadcast(HUGE_VAL), + ifelse(is_small, broadcast(-HUGE_VAL), r))); + } + + static svfloat64_t pow(const svfloat64_t& x, const svfloat64_t& y) { + auto len = svlen_f64(x); + double a[len], b[len], r[len]; + copy_to(x, a); + copy_to(y, b); + + for (unsigned i = 0; i<len; ++i) { + r[i] = std::pow(a[i], b[i]); + } + return copy_from(r); + } + + static unsigned simd_width(const svfloat64_t& m) { + return svlen_f64(m); + } + +protected: + // Compute n and f such that x = 2^n·f, with |f| ∈ [1,2), given x is finite and normal. + static svint32_t logb_normal(const svfloat64_t& x) { + svuint32_t xw = svtrn2_u32(svreinterpret_u32_f64(x), svreinterpret_u32_f64(x)); + svuint64_t lmask = svdup_n_u64(0x00000000ffffffff); + svuint64_t xt = svand_u64_z(svptrue_b64(), svreinterpret_u64_u32(xw), lmask); + svuint32_t xhi = svreinterpret_u32_u64(xt); + auto emask = svdup_n_u32(0x7ff00000); + auto ebiased = svlsr_n_u32_z(svptrue_b64(), svand_u32_z(svptrue_b64(), xhi, emask), 20); + + return svsub_s32_z(svptrue_b64(), svreinterpret_s32_u32(ebiased), svdup_n_s32(1023)); + } + + static svfloat64_t fraction_normal(const svfloat64_t& x) { + svuint64_t emask = svdup_n_u64(-0x7ff0000000000001); + svuint64_t bias = svdup_n_u64(0x3ff0000000000000); + return svreinterpret_f64_u64( + svorr_u64_z(svptrue_b64(), bias, svand_u64_z(svptrue_b64(), emask, svreinterpret_u64_f64(x)))); + } + + static inline svfloat64_t horner1(svfloat64_t x, double a0) { + return add(x, broadcast(a0)); + } + + static inline svfloat64_t horner(svfloat64_t x, double a0) { + return broadcast(a0); + } + + template <typename... T> + static svfloat64_t horner(svfloat64_t x, double a0, T... tail) { + return fma(x, horner(x, tail...), broadcast(a0)); + } + + template <typename... T> + static svfloat64_t horner1(svfloat64_t x, double a0, T...
tail) { + return fma(x, horner1(x, tail...), broadcast(a0)); + } + + static svfloat64_t fms(const svfloat64_t& a, const svfloat64_t& b, const svfloat64_t& c) { + return svnmsb_f64_z(svptrue_b64(), a, b, c); + } +}; + +} // namespace detail + +namespace simd_abi { +template <typename T, unsigned N> struct sve; +template <> struct sve<double, 0> {using type = detail::sve_double;}; +template <> struct sve<int, 0> {using type = detail::sve_int;}; +}; // namespace simd_abi + +template <typename Value, unsigned N, template <class, unsigned> class Abi> +struct simd_wrap; + +template <typename Value, template <class, unsigned> class Abi> +struct simd_wrap<Value, (unsigned)0, Abi> { using type = typename detail::simd_traits<typename simd_abi::sve<Value, 0u>::type>::vector_type; }; + +template <typename Value, unsigned N, template <class, unsigned> class Abi> +struct simd_mask_wrap; + +template <typename Value, template <class, unsigned> class Abi> +struct simd_mask_wrap<Value, (unsigned)0, Abi> { + using type = typename detail::simd_traits< + typename detail::simd_traits<typename simd_abi::sve<Value, 0u>::type>::mask_impl + >::vector_type; }; + +// Math functions exposed for SVE types + +#define ARB_SVE_UNARY_ARITHMETIC_(name)\ +template <typename T>\ +T name(const T& a) {\ + return detail::sve_type_to_impl<T>::type::name(a);\ +}; + +#define ARB_SVE_BINARY_ARITHMETIC_(name)\ +template <typename T>\ +auto name(const T& a, const T& b) {\ + return detail::sve_type_to_impl<T>::type::name(a, b);\ +};\ +template <typename T>\ +auto name(const T& a, const typename detail::simd_traits<typename detail::sve_type_to_impl<T>::type>::scalar_type& b) {\ + return name(a, detail::sve_type_to_impl<T>::type::broadcast(b));\ +};\ +template <typename T>\ +auto name(const typename detail::simd_traits<typename detail::sve_type_to_impl<T>::type>::scalar_type& a, const T& b) {\ + return name(detail::sve_type_to_impl<T>::type::broadcast(a), b);\ +}; + + +ARB_PP_FOREACH(ARB_SVE_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) +ARB_PP_FOREACH(ARB_SVE_BINARY_ARITHMETIC_, cmp_eq, cmp_neq, cmp_leq, cmp_lt, cmp_geq, cmp_gt, logical_and, logical_or) +ARB_PP_FOREACH(ARB_SVE_UNARY_ARITHMETIC_, logical_not, neg, abs, exp, log, expm1, exprelr) + +#undef ARB_SVE_UNARY_ARITHMETIC_ +#undef ARB_SVE_BINARY_ARITHMETIC_ + +template <typename T> +T fma(const T& a, T b, T c) { + return detail::sve_type_to_impl<T>::type::fma(a, b, c); +} + +template <typename T> +auto sum(const T& a) { + return detail::sve_type_to_impl<T>::type::reduce_add(a); +} + +// Indirect/Indirect indexed/Where Expression copy methods + +template <typename T, typename V> +static void indirect_copy_to(const T& s, V* p, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + Impl::copy_to_masked(s, p, ImplMask::true_mask(width)); +} + +template <typename T, typename M, typename V> +static void indirect_copy_to(const T& data, const M& mask, V* p, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplMask = typename detail::sve_type_to_impl<M>::type; + + Impl::copy_to_masked(data, p, ImplMask::logical_and(mask, ImplMask::true_mask(width))); +} + +template <typename T, typename I, typename V> +static void indirect_indexed_copy_to(const T& s, V* p, const I& index, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplIndex = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename 
detail::simd_traits<Impl>::mask_impl; + + Impl::scatter(detail::tag<ImplIndex>{}, s, p, index, ImplMask::true_mask(width)); +} + +template <typename T, typename I, typename M, typename V> +static void indirect_indexed_copy_to(const T& data, const M& mask, V* p, const I& index, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplIndex = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::sve_type_to_impl<M>::type; + + Impl::scatter(detail::tag<ImplIndex>{}, data, p, index, ImplMask::logical_and(mask, ImplMask::true_mask(width))); +} + +template <typename T, typename M, typename V> +static void where_copy_to(const M& mask, T& f, const V& t) { + using Impl = typename detail::sve_type_to_impl<T>::type; + f = Impl::ifelse(mask, Impl::broadcast(t), f); +} + +template <typename T, typename M> +static void where_copy_to(const M& mask, T& f, const T& t) { + f = detail::sve_type_to_impl<T>::type::ifelse(mask, t, f); +} + +template <typename T, typename M, typename V> +static void where_copy_to(const M& mask, T& f, const V* p, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplMask = typename detail::sve_type_to_impl<M>::type; + + auto m = ImplMask::logical_and(mask, ImplMask::true_mask(width)); + f = Impl::ifelse(mask, Impl::copy_from_masked(p, m), f); +} + +template <typename T, typename I, typename M, typename V> +static void where_copy_to(const M& mask, T& f, const V* p, const I& index, unsigned width) { + using Impl = typename detail::sve_type_to_impl<T>::type; + using IndexImpl = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::sve_type_to_impl<M>::type; + + auto m = ImplMask::logical_and(mask, ImplMask::true_mask(width)); + T temp = Impl::gather(detail::tag<IndexImpl>{}, p, index, m); + f = Impl::ifelse(mask, temp, f); +} + +template <typename I, typename T, typename V> +void compound_indexed_add( + const T& s, + V* p, + const I& index, + unsigned width, + index_constraint constraint) +{ + using Impl = typename detail::sve_type_to_impl<T>::type; + using ImplIndex = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + + auto mask = ImplMask::true_mask(width); + switch (constraint) { + case index_constraint::none: + { + typename detail::simd_traits<ImplIndex>::scalar_type o[width]; + ImplIndex::copy_to_masked(index, o, mask); + + V a[width]; + Impl::copy_to_masked(s, a, mask); + + V temp = 0; + for (unsigned i = 0; i<width-1; ++i) { + temp += a[i]; + if (o[i] != o[i+1]) { + p[o[i]] += temp; + temp = 0; + } + } + temp += a[width-1]; + p[o[width-1]] += temp; + } + break; + case index_constraint::independent: + { + auto v = Impl::add(Impl::gather(detail::tag<ImplIndex>{}, p, index, mask), s, mask); + Impl::scatter(detail::tag<ImplIndex>{}, v, p, index, mask); + } + break; + case index_constraint::contiguous: + { + p += ImplIndex::element0(index); + auto v = Impl::add(Impl::copy_from_masked(p, mask), s, mask); + Impl::copy_to_masked(v, p, mask); + } + break; + case index_constraint::constant: + p += ImplIndex::element0(index); + *p += Impl::reduce_add(s, mask); + break; + } +} + +static int width(const svfloat64_t& v) { + return svlen_f64(v); +}; + +static int width(const svint64_t& v) { + return svlen_s64(v); +}; + +template <typename S, typename std::enable_if_t<detail::is_sve<S>::value, int> = 0> +static int width() { S v; return width(v); } + +namespace detail { + +template <typename I, typename V> +class 
indirect_indexed_expression; + +template <typename V> +class indirect_expression; + +template <typename T, typename M> +class where_expression; + +template <typename T, typename M> +class const_where_expression; + +template <typename To> +struct simd_cast_impl { + template <typename V> + static To cast(const V& a) { + return detail::sve_type_to_impl<To>::type::broadcast(a); + } + + template <typename V> + static To cast(const indirect_expression<V>& a) { + using Impl = typename detail::sve_type_to_impl<To>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + + return Impl::copy_from_masked(a.p, ImplMask::true_mask(a.width)); + } + + template <typename I, typename V> + static To cast(const indirect_indexed_expression<I,V>& a) { + using Impl = typename detail::sve_type_to_impl<To>::type; + using IndexImpl = typename detail::sve_type_to_impl<I>::type; + using ImplMask = typename detail::simd_traits<Impl>::mask_impl; + + To r; + auto mask = ImplMask::true_mask(a.width); + switch (a.constraint) { + case index_constraint::none: + r = Impl::gather(tag<IndexImpl>{}, a.p, a.index, mask); + break; + case index_constraint::independent: + r = Impl::gather(tag<IndexImpl>{}, a.p, a.index, mask); + break; + case index_constraint::contiguous: + { + const auto* p = IndexImpl::element0(a.index) + a.p; + r = Impl::copy_from_masked(p, mask); + } + break; + case index_constraint::constant: + { + const auto *p = IndexImpl::element0(a.index) + a.p; + auto l = (*p); + r = Impl::broadcast(l); + } + break; + } + return r; + } + + template <typename T, typename V> + static To cast(const const_where_expression<T,V>& a) { + auto r = detail::sve_type_to_impl<To>::type::broadcast(0); + r = detail::sve_type_to_impl<To>::type::ifelse(a.mask_, a.data_, r); + return r; + } + + template <typename T, typename V> + static To cast(const where_expression<T,V>& a) { + auto r = detail::sve_type_to_impl<To>::type::broadcast(0); + r = detail::sve_type_to_impl<To>::type::ifelse(a.mask_, a.data_, r); + return r; + } +}; + +template <typename T, typename V> +void assign(T& a, const detail::indirect_expression<V>& b) { + a = detail::simd_cast_impl<T>::cast(b); +} + +template <typename T, typename I, typename V> +void assign(T& a, const detail::indirect_indexed_expression<I, V>& b) { + a = detail::simd_cast_impl<T>::cast(b); +} + +} // namespace detail +} // namespace simd +} // namespace arb + +#endif // def __ARM_FEATURE_SVE diff --git a/mechanisms/default/kdrmt.mod b/mechanisms/default/kdrmt.mod index 54c070441deb82925b7468c0e04d62de9fe4a205..86a697a051eaa5ba1440abd05592ebad40862521 100644 --- a/mechanisms/default/kdrmt.mod +++ b/mechanisms/default/kdrmt.mod @@ -61,12 +61,12 @@ DERIVATIVE states { PROCEDURE trates(v,celsius) { LOCAL qt - LOCAL alpm, betm + LOCAL alpm_t, betm_t LOCAL tmp qt=q10^((celsius-24)/10) minf = 1/(1 + exp(-(v-21)/10)) tmp = zetam*(v-vhalfm) - alpm = exp(tmp) - betm = exp(gmm*tmp) - mtau = betm/(qt*a0m*(1+alpm)) + alpm_t = exp(tmp) + betm_t = exp(gmm*tmp) + mtau = betm_t/(qt*a0m*(1+alpm_t)) } diff --git a/modcc/modcc.cpp b/modcc/modcc.cpp index 488f0ed0b4803f65fee0324623ef13001ce45e5a..d1ec2c76a44abbdbfaf29e0ee4cef95a8e8ea27e 100644 --- a/modcc/modcc.cpp +++ b/modcc/modcc.cpp @@ -48,6 +48,7 @@ std::unordered_map<std::string, targetKind> targetKindMap = { std::unordered_map<std::string, enum simd_spec::simd_abi> simdAbiMap = { {"none", simd_spec::none}, {"neon", simd_spec::neon}, + {"sve", simd_spec::sve}, {"avx", simd_spec::avx}, {"avx2", simd_spec::avx2}, {"avx512", simd_spec::avx512}, @@ 
-118,7 +119,7 @@ std::istream& operator>> (std::istream& i, simd_spec& spec) { auto npos = std::string::npos; std::string s; i >> s; - unsigned width = 0; + unsigned width = no_size; auto suffix = s.find_last_of('/'); if (suffix!=npos) { diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index f6edfe9ad6adbae8116c5192101b2a545577c44e..4c947d6088179c5870c1aa9ba5941e214863e50e 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -170,6 +170,157 @@ void CExprEmitter::visit(IfExpression* e) { /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// std::unordered_set<std::string> SimdExprEmitter::mask_names_; +void SimdExprEmitter::visit(PowBinaryExpression* e) { + out_ << "S::pow("; + e->lhs()->accept(this); + out_ << ", "; + e->rhs()->accept(this); + out_ << ')'; +} + +void SimdExprEmitter::visit(NumberExpression* e) { + out_ << " (double)" << as_c_double(e->value()); +} + +void SimdExprEmitter::visit(UnaryExpression* e) { + static std::unordered_map<tok, const char*> unaryop_tbl = { + {tok::minus, "S::neg"}, + {tok::exp, "S::exp"}, + {tok::cos, "S::cos"}, + {tok::sin, "S::sin"}, + {tok::log, "S::log"}, + {tok::abs, "S::abs"}, + {tok::exprelr, "S::exprelr"}, + {tok::safeinv, "safeinv"} + }; + + if (!unaryop_tbl.count(e->op())) { + throw compiler_exception( + "CExprEmitter: unsupported unary operator "+token_string(e->op()), e->location()); + } + + const char* op_spelling = unaryop_tbl.at(e->op()); + Expression* inner = e->expression(); + + auto iden = inner->is_identifier(); + bool is_scalar = iden && scalars_.count(iden->name()); + if (e->op()==tok::minus && is_scalar) { + out_ << "simd_cast<simd_value>(-"; + inner->accept(this); + out_ << ")"; + } + else { + emit_as_call(op_spelling, inner); + } +} + +void SimdExprEmitter::visit(BinaryExpression* e) { + static std::unordered_map<tok, const char *> func_tbl = { + {tok::minus, "S::sub"}, + {tok::plus, "S::add"}, + {tok::times, "S::mul"}, + {tok::divide, "S::div"}, + {tok::lt, "S::cmp_lt"}, + {tok::lte, "S::cmp_leq"}, + {tok::gt, "S::cmp_gt"}, + {tok::gte, "S::cmp_geq"}, + {tok::equality, "S::cmp_eq"}, + {tok::land, "S::logical_and"}, + {tok::lor, "S::logical_or"}, + {tok::ne, "S::cmp_neq"}, + {tok::min, "S::min"}, + {tok::max, "S::max"}, + }; + + static std::unordered_map<tok, const char *> binop_tbl = { + {tok::minus, "-"}, + {tok::plus, "+"}, + {tok::times, "*"}, + {tok::divide, "/"}, + {tok::lt, "<"}, + {tok::lte, "<="}, + {tok::gt, ">"}, + {tok::gte, ">="}, + {tok::equality, "=="}, + {tok::land, "&&"}, + {tok::lor, "||"}, + {tok::ne, "!="}, + {tok::min, "min"}, + {tok::max, "max"}, + }; + + + if (!binop_tbl.count(e->op())) { + throw compiler_exception( + "CExprEmitter: unsupported binary operator " + token_string(e->op()), e->location()); + } + + std::string rhs_name, lhs_name; + + auto rhs = e->rhs(); + auto lhs = e->lhs(); + + const char *op_spelling = binop_tbl.at(e->op()); + const char *func_spelling = func_tbl.at(e->op()); + + if (rhs->is_identifier()) { + rhs_name = rhs->is_identifier()->name(); + } + if (lhs->is_identifier()) { + lhs_name = lhs->is_identifier()->name(); + } + + if (scalars_.count(rhs_name) && scalars_.count(lhs_name)) { + if (e->is_infix()) { + associativityKind assoc = Lexer::operator_associativity(e->op()); + int op_prec = Lexer::binop_precedence(e->op()); + + auto need_paren = [op_prec](Expression *subexpr, bool assoc_side) -> bool { + if (auto b = subexpr->is_binary()) { + int sub_prec = 
Lexer::binop_precedence(b->op()); + return sub_prec < op_prec || (!assoc_side && sub_prec == op_prec); + } + return false; + }; + + out_ << "simd_cast<simd_value>("; + if (need_paren(lhs, assoc == associativityKind::left)) { + emit_as_call("", lhs); + } else { + lhs->accept(this); + } + + out_ << op_spelling; + + if (need_paren(rhs, assoc == associativityKind::right)) { + emit_as_call("", rhs); + } else { + rhs->accept(this); + } + out_ << ")"; + } else { + out_ << "simd_cast<simd_value>("; + emit_as_call(op_spelling, lhs, rhs); + out_ << ")"; + } + } else if (scalars_.count(rhs_name) && !scalars_.count(lhs_name)) { + out_ << func_spelling << '('; + lhs->accept(this); + out_ << ", simd_cast<simd_value>(" << rhs_name ; + out_ << "))"; + } else if (!scalars_.count(rhs_name) && scalars_.count(lhs_name)) { + out_ << func_spelling << "(simd_cast<simd_value>(" << lhs_name << "), "; + rhs->accept(this); + out_ << ")"; + } else { + out_ << func_spelling << '('; + lhs->accept(this); + out_ << ", "; + rhs->accept(this); + out_ << ')'; + } +} + void SimdExprEmitter::visit(BlockExpression* block) { for (auto& stmt: block->statements()) { if (!stmt->is_local_declaration()) { @@ -209,15 +360,22 @@ void SimdExprEmitter::visit(AssignmentExpression* e) { if (lhs->is_variable() && lhs->is_variable()->is_range()) { if (!input_mask_.empty()) { - mask = mask + " && " + input_mask_; + mask = "S::logical_and(" + mask + ", " + input_mask_ + ")"; } - out_ << "S::where(" << mask << ", " << "simd_value("; - e->rhs()->accept(this); - out_ << "))"; if(is_indirect_) - out_ << ".copy_to(" << lhs->name() << "+index_)"; + out_ << "indirect(" << lhs->name() << "+index_, simd_width_) = "; else - out_ << ".copy_to(" << lhs->name() << "+i_)"; + out_ << "indirect(" << lhs->name() << "+i_, simd_width_) = "; + + out_ << "S::where(" << mask << ", "; + + bool cast = e->rhs()->is_number(); + if (cast) out_ << "simd_cast<simd_value>("; + e->rhs()->accept(this); + + out_ << ")"; + + if (cast) out_ << ")"; } else { out_ << "S::where(" << mask << ", "; e->lhs()->accept(this); @@ -237,17 +395,17 @@ void SimdExprEmitter::visit(IfExpression* e) { auto new_mask = make_unique_var(e->scope(), "mask_"); // Set new masks - out_ << "simd_value::simd_mask " << new_mask << " = "; + out_ << "simd_mask " << new_mask << " = "; e->condition()->accept(this); out_ << ";\n"; if (!current_mask_.empty()) { auto base_mask = processing_true_ ? current_mask_ : current_mask_bar_; - current_mask_bar_ = base_mask + " && !" + new_mask; - current_mask_ = base_mask + " && " + new_mask; + current_mask_bar_ = "S::logical_and(" + base_mask + ", S::logical_not(" + new_mask + "))"; + current_mask_ = "S::logical_and(" + base_mask + ", " + new_mask + ")"; } else { - current_mask_bar_ = "!" 
+ new_mask; + current_mask_bar_ = "S::logical_not(" + new_mask + ")"; current_mask_ = new_mask; } diff --git a/modcc/printer/cexpr_emit.hpp b/modcc/printer/cexpr_emit.hpp index fbbfd448d017dfde85b69397c967f7b323e500eb..9286a5c94a97ea0cebb846c36f8539a86e0cfece 100644 --- a/modcc/printer/cexpr_emit.hpp +++ b/modcc/printer/cexpr_emit.hpp @@ -40,12 +40,21 @@ inline void cexpr_emit(Expression* e, std::ostream& out, Visitor* fallback) { class SimdExprEmitter: public CExprEmitter { using CExprEmitter::visit; public: - SimdExprEmitter(std::ostream& out, bool is_indirect, std::string input_mask, Visitor* fallback): - CExprEmitter(out, fallback), is_indirect_(is_indirect), input_mask_(input_mask) {} + SimdExprEmitter( + std::ostream& out, + bool is_indirect, + std::string input_mask, + const std::unordered_set<std::string>& scalars, + Visitor* fallback): + CExprEmitter(out, fallback), is_indirect_(is_indirect), input_mask_(input_mask), scalars_(scalars), fallback_(fallback) {} void visit(BlockExpression *e) override; void visit(CallExpression *e) override; + void visit(UnaryExpression *e) override; + void visit(BinaryExpression *e) override; void visit(AssignmentExpression *e) override; + void visit(PowBinaryExpression *e) override; + void visit(NumberExpression *e) override; void visit(IfExpression *e) override; protected: @@ -53,6 +62,8 @@ protected: bool processing_true_; bool is_indirect_; std::string current_mask_, current_mask_bar_, input_mask_; + std::unordered_set<std::string> scalars_; + Visitor* fallback_; private: std::string make_unique_var(scope_ptr scope, std::string prefix) { @@ -66,8 +77,15 @@ private: }; }; -inline void simd_expr_emit(Expression* e, std::ostream& out, bool is_indirect, std::string input_mask, Visitor* fallback) { - SimdExprEmitter emitter(out, is_indirect, input_mask, fallback); +inline void simd_expr_emit( + Expression* e, + std::ostream& out, + bool is_indirect, + std::string input_mask, + const std::unordered_set<std::string>& scalars, + Visitor* fallback) +{ + SimdExprEmitter emitter(out, is_indirect, input_mask, scalars, fallback); e->accept(&emitter); } diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 2ebd8db35a9342513d304513b39d00b7fa9ef4e5..73ab685dc1b98e5d9b951d0064b27e1a69e9a0b2 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -29,7 +29,7 @@ void emit_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::s void emit_masked_simd_procedure_proto(std::ostream&, ProcedureExpression*, const std::string& qualified = ""); void emit_api_body(std::ostream&, APIMethod*); -void emit_simd_api_body(std::ostream&, APIMethod*, moduleKind); +void emit_simd_api_body(std::ostream&, APIMethod*, const std::vector<VariableExpression*>& scalars); void emit_index_initialize(std::ostream& out, const std::unordered_set<std::string>& indices, simd_expr_constraint constraint); @@ -59,8 +59,13 @@ struct simdprint { Expression* expr_; bool is_indirect_ = false; bool is_masked_ = false; + std::unordered_set<std::string> scalars_; - explicit simdprint(Expression* expr): expr_(expr) {} + explicit simdprint(Expression* expr, const std::vector<VariableExpression*>& scalars): expr_(expr) { + for (const auto& s: scalars) { + scalars_.insert(s->name()); + } + } void set_indirect_index() { is_indirect_ = true; @@ -75,6 +80,7 @@ struct simdprint { printer.set_input_mask("mask_input_"); } printer.set_var_indexed(w.is_indirect_); + printer.save_scalar_names(w.scalars_); return w.expr_->accept(&printer), out; } }; @@ -147,6 +153,8 
@@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { if (with_simd) { out << "#include <" << arb_header_prefix() << "simd/simd.hpp>\n"; + out << "#undef NDEBUG\n"; + out << "#include <cassert>\n"; } out << @@ -173,12 +181,20 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { out << "namespace S = ::arb::simd;\n" "using S::index_constraint;\n" - "static constexpr unsigned simd_width_ = "; + "using S::simd_cast;\n" + "using S::indirect;\n"; - if (!opt.simd.width) { + out << "static constexpr unsigned vector_length_ = "; + if (opt.simd.size == no_size) { out << "S::simd_abi::native_width<::arb::fvm_value_type>::value;\n"; + } else { + out << opt.simd.size << ";\n"; } - else { + + out << "static constexpr unsigned simd_width_ = "; + if (opt.simd.width == no_size) { + out << " vector_length_ ? vector_length_ : " << opt.simd.default_width << ";\n"; + } else { out << opt.simd.width << ";\n"; } @@ -188,18 +204,22 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { case simd_spec::avx2: abi += "avx2"; break; case simd_spec::avx512: abi += "avx512"; break; case simd_spec::neon: abi += "neon"; break; + case simd_spec::sve: abi += "sve"; break; case simd_spec::native: abi += "native"; break; default: abi += "default_abi"; break; } out << - "using simd_value = S::simd<::arb::fvm_value_type, simd_width_, " << abi << ">;\n" - "using simd_index = S::simd<::arb::fvm_index_type, simd_width_, " << abi << ">;\n" + "using simd_value = S::simd<::arb::fvm_value_type, vector_length_, " << abi << ">;\n" + "using simd_index = S::simd<::arb::fvm_index_type, vector_length_, " << abi << ">;\n" + "using simd_mask = S::simd_mask<::arb::fvm_value_type, vector_length_, "<< abi << ">;\n" "\n" "inline simd_value safeinv(simd_value x) {\n" - " S::where(x+1==1, x) = DBL_EPSILON;\n" - " return 1/x;\n" + " simd_value ones = simd_cast<simd_value>(1.0);\n" + " auto mask = S::cmp_eq(S::add(x,ones), ones);\n" + " S::where(mask, x) = simd_cast<simd_value>(DBL_EPSILON);\n" + " return S::div(ones, x);\n" "}\n" "\n"; } @@ -224,6 +244,8 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "void deliver_events(deliverable_event_stream::state events) override;\n" "void net_receive(int i_, value_type weight);\n"; + with_simd && out << "unsigned simd_width() const override { return simd_width_; }\n"; + out << "\n" << popindent << "protected:\n" << indent << @@ -355,7 +377,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { auto emit_body = [&](APIMethod *p) { if (with_simd) { - emit_simd_api_body(out, p, module_.kind()); + emit_simd_api_body(out, p, vars.scalars); } else { emit_api_body(out, p); @@ -387,11 +409,11 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { for (auto proc: normal_procedures(module_)) { if (with_simd) { emit_simd_procedure_proto(out, proc, class_name); - auto simd_print = simdprint(proc->body()); + auto simd_print = simdprint(proc->body(), vars.scalars); out << " {\n" << indent << simd_print << popindent << "}\n\n"; emit_masked_simd_procedure_proto(out, proc, class_name); - auto masked_print = simdprint(proc->body()); + auto masked_print = simdprint(proc->body(), vars.scalars); masked_print.set_masked(); out << " {\n" << indent << masked_print << popindent << "}\n\n"; } else { @@ -553,9 +575,9 @@ void SimdPrinter::visit(LocalVariable* sym) { void SimdPrinter::visit(VariableExpression *sym) { if (sym->is_range()) { 
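+ // Range variables are read through the new indirect expression over
+ // simd_width_ lanes and converted with simd_cast, replacing the old
+ // simd_value(pointer) constructor.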
if(is_indirect_) - out_ << "simd_value(" << sym->name() << "+index_)"; + out_ << "simd_cast<simd_value>(indirect(" << sym->name() << "+index_, simd_width_))"; else - out_ << "simd_value(" << sym->name() << "+i_)"; + out_ << "simd_cast<simd_value>(indirect(" << sym->name() << "+i_, simd_width_))"; } else { out_ << sym->name(); @@ -568,26 +590,35 @@ void SimdPrinter::visit(AssignmentExpression* e) { } Symbol* lhs = e->lhs()->is_identifier()->symbol(); + + bool cast = false; + if (auto id = e->rhs()->is_identifier()) { + if (scalars_.count(id->name())) cast = true; + } + if (e->rhs()->is_number()) cast = true; + if (scalars_.count(e->lhs()->is_identifier()->name())) cast = false; if (lhs->is_variable() && lhs->is_variable()->is_range()) { - if (!input_mask_.empty()) - out_ << "S::where(" << input_mask_ << ", simd_value("; + if(is_indirect_) + out_ << "indirect(" << lhs->name() << "+index_, simd_width_) = "; else - out_ << "simd_value("; + out_ << "indirect(" << lhs->name() << "+i_, simd_width_) = "; + + if (!input_mask_.empty()) + out_ << "S::where(" << input_mask_ << ", "; + if (cast) out_ << "simd_cast<simd_value>("; e->rhs()->accept(this); + if (cast) out_ << ")"; if (!input_mask_.empty()) out_ << ")"; - - if(is_indirect_) - out_ << ").copy_to(" << lhs->name() << "+index_)"; - else - out_ << ").copy_to(" << lhs->name() << "+i_)"; } else { out_ << lhs->name() << " = "; + if (cast) out_ << "simd_cast<simd_value>("; e->rhs()->accept(this); + if (cast) out_ << ")"; } } @@ -637,7 +668,7 @@ void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const void emit_masked_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { out << "void " << qualified << (qualified.empty()? "": "::") << e->name() - << "(index_type i_, simd_value::simd_mask mask_input_"; + << "(index_type i_, simd_mask mask_input_"; for (auto& arg: e->args()) { out << ", const simd_value& " << arg->is_argument()->name(); } @@ -650,29 +681,29 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con if (local->is_read()) { auto d = decode_indexed_variable(local->external_variable()); if (d.scalar()) { - out << "(" << d.data_var + out << " = simd_cast<simd_value>(" << d.data_var << "[0]);\n"; } else if (constraint == simd_expr_constraint::contiguous) { - out << "(" << d.data_var + out << " = simd_cast<simd_value>(indirect(" << d.data_var << " + " << d.index_var - << "[index_]);\n"; + << "[index_], simd_width_));\n"; } else if (constraint == simd_expr_constraint::constant) { - out << "(" << d.data_var + out << " = simd_cast<simd_value>(" << d.data_var << "[" << d.index_var << "element0]);\n"; } else { - out << "(S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_));\n"; + out << " = simd_cast<simd_value>(indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", simd_width_, constraint_category_));\n"; } if (d.scale != 1) { - out << local->name() << " *= " << d.scale << ";\n"; + out << local->name() << " = S::mul(" << local->name() << ", simd_cast<simd_value>(" << d.scale << "));\n"; } } else { - out << " = 0;\n"; + out << " = simd_cast<simd_value>(0);\n"; } } @@ -690,40 +721,50 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex std::string tempvar = "t_"+external->name(); if (constraint == simd_expr_constraint::contiguous) { - out << "simd_value "<< tempvar <<"(" << d.data_var << " + " << d.index_var << "[index_]);\n" - << tempvar << " += w_*"; + out << "simd_value "<< 
tempvar <<" = simd_cast<simd_value>(indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_));\n"; - if (coeff!=1) out << as_c_double(coeff) << "*"; + if (coeff!=1) { + out << tempvar << " = S::fma(S::mul(w_, simd_cast<simd_value>(" + << as_c_double(coeff) << "))," + << from->name() << ", " << tempvar << ");\n"; + } else { + out << tempvar << " = S::fma(w_, " << from->name() << ", " << tempvar << ");\n"; + } - out << from->name() << ";\n" - << tempvar << ".copy_to(" << d.data_var << " + " << d.index_var << "[index_]);\n"; + out << "indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_) = " << tempvar << ";\n"; } else { - out << "S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_)" - << " += w_*"; + out << "indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", simd_width_, constraint_category_)"; - if (coeff!=1) out << as_c_double(coeff) << "*"; - - out << from->name() << ";\n"; + if (coeff!=1) { + out << " += S::mul(w_, S::mul(simd_cast<simd_value>(" + << as_c_double(coeff) << "), " + << from->name() << "));\n"; + } else { + out << " += S::mul(w_, " << from->name() << ");\n"; + } } } else { if (constraint == simd_expr_constraint::contiguous) { + out << "indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_) = "; if (coeff!=1) { - out << "(" << as_c_double(coeff) << "*" << from->name() << ")"; + out << "(S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << ")," << from->name() << "));\n"; } else { - out << from->name(); + out << from->name() << ";\n"; } - out << ".copy_to(" << d.data_var << " + " << d.index_var << "[index_]);\n"; } else { - out << "S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_)" + out << "indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", simd_width_, constraint_category_)" << " = "; - if (coeff!=1) out << as_c_double(coeff) << "*"; - - out << from->name() << ";\n"; + if (coeff!=1) { + out << "(S::mul(simd_cast<simd_value>(" << as_c_double(coeff) << ")," << from->name() << "));\n"; + } + else { + out << from->name() << ";\n"; + } } } } @@ -735,28 +776,33 @@ void emit_index_initialize(std::ostream& out, const std::unordered_set<std::stri break; case simd_expr_constraint::constant: for (auto& index: indices) { - out << "simd_index::scalar_type " << index << "element0 = " << index << "[index_];\n"; - out << index_i_name(index) << " = " << index << "element0;\n"; + out << "arb::fvm_index_type " << index << "element0 = " << index << "[index_];\n"; + out << index_i_name(index) << " = simd_cast<simd_index>(" << index << "element0);\n"; } break; case simd_expr_constraint::other: for (auto& index: indices) { - out << index_i_name(index) << ".copy_from(" << index << ".data() + index_);\n"; + out << index_i_name(index) << " = simd_cast<simd_index>(indirect(" << index << ".data() + index_, simd_width_));\n"; } break; } } -void emit_body_for_loop(std::ostream& out, BlockExpression* body, const std::vector<LocalVariable*>& indexed_vars, - const std::unordered_set<std::string>& indices, const simd_expr_constraint& read_constraint, - const simd_expr_constraint& write_constraint) { +void emit_body_for_loop( + std::ostream& out, + BlockExpression* body, + const std::vector<LocalVariable*>& indexed_vars, + const std::vector<VariableExpression*>& scalars, + const std::unordered_set<std::string>& indices, + const simd_expr_constraint& read_constraint, + const simd_expr_constraint& write_constraint) { 
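+ // Emit one specialized loop body: initialize the index vectors for this
+ // constraint, read the indexed state into SIMD locals, print the body with
+ // indirect indexing, then write the indexed state back.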
emit_index_initialize(out, indices, read_constraint); for (auto& sym: indexed_vars) { emit_simd_state_read(out, sym, read_constraint); } - simdprint printer(body); + simdprint printer(body, scalars); printer.set_indirect_index(); out << printer; @@ -768,6 +814,7 @@ void emit_body_for_loop(std::ostream& out, BlockExpression* body, const std::vec void emit_for_loop_per_constraint(std::ostream& out, BlockExpression* body, const std::vector<LocalVariable*>& indexed_vars, + const std::vector<VariableExpression*>& scalars, bool requires_weight, const std::unordered_set<std::string>& indices, const simd_expr_constraint& read_constraint, @@ -781,21 +828,36 @@ void emit_for_loop_per_constraint(std::ostream& out, BlockExpression* body, out << "index_type index_ = index_constraints_." << underlying_constraint_name << "[i_];\n"; if (requires_weight) { - out << "simd_value w_(weight_+index_);\n"; + out << "simd_value w_ = simd_cast<simd_value>(indirect((weight_+index_), simd_width_));\n"; } - emit_body_for_loop(out, body, indexed_vars, indices, read_constraint, write_constraint); + emit_body_for_loop(out, body, indexed_vars, scalars, indices, read_constraint, write_constraint); out << popindent << "}\n"; } -void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_kind) { +void emit_simd_api_body(std::ostream& out, APIMethod* method, const std::vector<VariableExpression*>& scalars) { auto body = method->body(); auto indexed_vars = indexed_locals(method->scope()); bool requires_weight = false; std::vector<LocalVariable*> scalar_indexed_vars; std::unordered_set<std::string> indices; + + for (auto& s: body->is_block()->statements()) { + if (s->is_assignment()) { + for (auto& v: indexed_vars) { + if (s->is_assignment()->lhs()->is_identifier()->name() == v->external_variable()->name()) { + auto info = decode_indexed_variable(v->external_variable()); + if (info.accumulate) { + requires_weight = true; + } + break; + } + } + } + } + for (auto& sym: indexed_vars) { auto info = decode_indexed_variable(sym->external_variable()); if (!info.scalar()) { @@ -804,12 +866,9 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ else { scalar_indexed_vars.push_back(sym); } - if (info.accumulate) { - requires_weight = true; - } } - if (!body->statements().empty()) { + out << "assert(simd_width_ <= (unsigned)S::width(simd_cast<simd_value>(0)));\n"; if (!indices.empty()) { for (auto& index: indices) { out << "simd_index " << index_i_name(index) << ";\n"; @@ -821,21 +880,21 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ simd_expr_constraint constraint = simd_expr_constraint::contiguous; std::string underlying_constraint = "contiguous"; - emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, scalars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all independent simd_vectors constraint = simd_expr_constraint::other; underlying_constraint = "independent"; - emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, scalars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all simd_vectors that have no optimizing constraints constraint = simd_expr_constraint::other; underlying_constraint = "none"; - emit_for_loop_per_constraint(out, body, indexed_vars, 
requires_weight, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, scalars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all constant simd_vectors @@ -843,7 +902,7 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ simd_expr_constraint write_constraint = simd_expr_constraint::other; underlying_constraint = "constant"; - emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, read_constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, scalars, requires_weight, indices, read_constraint, write_constraint, underlying_constraint); } @@ -856,7 +915,7 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ out << "unsigned n_ = width_;\n\n" "for (unsigned i_ = 0; i_ < n_; i_ += simd_width_) {\n" << indent << - simdprint(body) << popindent << + simdprint(body, scalars) << popindent << "}\n"; } } diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp index 5b12957f0a90f84fa421fa3008945ddfbfe8d2fc..ca450eb284f53326c2f625974d50ea18bfcef258 100644 --- a/modcc/printer/cprinter.hpp +++ b/modcc/printer/cprinter.hpp @@ -57,6 +57,9 @@ public: void set_input_mask(std::string input_mask) { input_mask_ = input_mask; } + void save_scalar_names(const std::unordered_set<std::string>& scalars) { + scalars_ = scalars; + } void visit(BlockExpression*) override; void visit(CallExpression*) override; @@ -65,13 +68,14 @@ public: void visit(LocalVariable*) override; void visit(AssignmentExpression*) override; - void visit(NumberExpression* e) override { cexpr_emit(e, out_, this); } - void visit(UnaryExpression* e) override { cexpr_emit(e, out_, this); } - void visit(BinaryExpression* e) override { cexpr_emit(e, out_, this); } - void visit(IfExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, this); } + void visit(NumberExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, scalars_, this); } + void visit(UnaryExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, scalars_, this); } + void visit(BinaryExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, scalars_, this); } + void visit(IfExpression* e) override { simd_expr_emit(e, out_, is_indirect_, input_mask_, scalars_, this); } private: std::ostream& out_; std::string input_mask_; bool is_indirect_ = false; + std::unordered_set<std::string> scalars_; }; diff --git a/modcc/printer/simd.hpp b/modcc/printer/simd.hpp index a218842f967d233c014264821207b60772c3b429..0b17a7f042d9ade114d15641bc110718eaa379fb 100644 --- a/modcc/printer/simd.hpp +++ b/modcc/printer/simd.hpp @@ -1,14 +1,18 @@ #pragma once +constexpr unsigned no_size = unsigned(-1); + struct simd_spec { - enum simd_abi { none, neon, avx, avx2, avx512, native, default_abi } abi = none; - unsigned width = 0; // zero => use `simd::native_width` to determine. + enum simd_abi { none, neon, avx, avx2, avx512, sve, native, default_abi } abi = none; + static constexpr unsigned default_width = 8; + unsigned size = no_size; // -1 => use `simd::native_width` to determine. + unsigned width = no_size; simd_spec() = default; - simd_spec(enum simd_abi a, unsigned w = 0): + simd_spec(enum simd_abi a, unsigned w = no_size): abi(a), width(w) { - if (width==0) { + if (width==no_size) { // Pick a width based on abi, if applicable. 
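+ // `width` is the unroll factor used by the generated loops; `size` (set
+ // in the second switch below) is the element count of the SIMD value
+ // type, with 0 denoting the sizeless SVE case.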
switch (abi) { case avx: @@ -24,5 +28,23 @@ struct simd_spec { default: ; } } + + switch (abi) { + case avx: + case avx2: + size = 4; + break; + case avx512: + size = 8; + break; + case neon: + size = 2; + break; + case sve: + size = 0; + break; + default: ; + } + } }; diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index fa8bd3ab608184fa8e8dae6658823386e600ddd2..7db8dac6b7a06c3032a855b5c2f7f31a1f3b660a 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -267,26 +267,26 @@ TEST(CPrinter, proc_body_inlined) { TEST(SimdPrinter, simd_if_else) { std::vector<const char*> expected_procs = { "simd_value u;\n" - "simd_value::simd_mask mask_0_ = i > 2;\n" - "S::where(mask_0_,u) = 7;\n" - "S::where(!mask_0_,u) = 5;\n" - "S::where(!mask_0_,simd_value(42)).copy_to(s+i_);\n" - "simd_value(u).copy_to(s+i_);" + "simd_mask mask_0_ = S::cmp_gt(i, (double)2);\n" + "S::where(mask_0_,u) = (double)7;\n" + "S::where(S::logical_not(mask_0_),u) = (double)5;\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_not(mask_0_),simd_cast<simd_value>((double)42));\n" + "indirect(s+i_, simd_width_) = u;" , "simd_value u;\n" - "simd_value::simd_mask mask_1_ = i > 2;\n" - "S::where(mask_1_,u) = 7;\n" - "S::where(!mask_1_,u) = 5;\n" - "S::where(!mask_1_ && mask_input_,simd_value(42)).copy_to(s+i_);\n" - "S::where(mask_input_, simd_value(u)).copy_to(s+i_);" + "simd_mask mask_1_ = S::cmp_gt(i, (double)2);\n" + "S::where(mask_1_,u) = (double)7;\n" + "S::where(S::logical_not(mask_1_),u) = (double)5;\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_1_), mask_input_),simd_cast<simd_value>((double)42));\n" + "indirect(s+i_, simd_width_) = S::where(mask_input_, u);" , - "simd_value::simd_mask mask_2_ = simd_value(g+i_)>2;\n" - "simd_value::simd_mask mask_3_ = simd_value(g+i_)>3;\n" - "S::where(mask_2_&&mask_3_,i) = 0.;\n" - "S::where(mask_2_&&!mask_3_,i) = 1;\n" - "simd_value::simd_mask mask_4_ = simd_value(g+i_)<1;\n" - "S::where(!mask_2_&& mask_4_,simd_value(2)).copy_to(s+i_);\n" - "rates(i_, !mask_2_&&!mask_4_, i);" + "simd_mask mask_2_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)2);\n" + "simd_mask mask_3_ = S::cmp_gt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)3);\n" + "S::where(S::logical_and(mask_2_,mask_3_),i) = (double)0.;\n" + "S::where(S::logical_and(mask_2_,S::logical_not(mask_3_)),i) = (double)1;\n" + "simd_mask mask_4_ = S::cmp_lt(simd_cast<simd_value>(indirect(g+i_, simd_width_)), (double)1);\n" + "indirect(s+i_, simd_width_) = S::where(S::logical_and(S::logical_not(mask_2_),mask_4_),simd_cast<simd_value>((double)2));\n" + "rates(i_, S::logical_and(S::logical_not(mask_2_),S::logical_not(mask_4_)), i);" }; Module m(io::read_all(DATADIR "/mod_files/test7.mod"), "test7.mod"); diff --git a/test/unit/test_partition_by_constraint.cpp b/test/unit/test_partition_by_constraint.cpp index 77fa0c72732aa7efd3a67351a18f1204b621b700..2e8a29ad283757ebc89971519482d8e52fef76bf 100644 --- a/test/unit/test_partition_by_constraint.cpp +++ b/test/unit/test_partition_by_constraint.cpp @@ -13,7 +13,9 @@ using namespace arb; using iarray = multicore::iarray; -static constexpr unsigned simd_width_ = arb::simd::simd_abi::native_width<fvm_value_type>::value; +constexpr unsigned vector_length = (unsigned) simd::simd_abi::native_width<fvm_value_type>::value; +using simd_value_type = simd::simd<fvm_value_type, vector_length, simd::simd_abi::default_abi>; +const int simd_width_ = 
simd::width<simd_value_type>(); const int input_size_ = 1024; diff --git a/test/unit/test_simd.cpp b/test/unit/test_simd.cpp index 3d611b260f6b1df9b8a34444ea1af223008ca078..67ac6a1adefcc54fc04c8d3e34b3cab03ff182b5 100644 --- a/test/unit/test_simd.cpp +++ b/test/unit/test_simd.cpp @@ -5,9 +5,10 @@ #include <random> #include <unordered_set> -#include <arbor/simd/simd.hpp> #include <arbor/simd/avx.hpp> #include <arbor/simd/neon.hpp> +#include <arbor/simd/sve.hpp> +#include <arbor/simd/simd.hpp> #include <arbor/util/compat.hpp> #include "common.hpp" @@ -203,21 +204,22 @@ TYPED_TEST_P(simd_value, copy_to_from_masked) { } simd s(buf1); - where(m1, s).copy_from(buf2); + where(m1, s) = indirect(buf2, N); EXPECT_TRUE(testing::indexed_eq_n(N, expected, s)); for (unsigned i = 0; i<N; ++i) { if (!mbuf2[i]) expected[i] = buf3[i]; } - where(m2, s).copy_to(buf3); + indirect(buf3, N) = where(m2, s); EXPECT_TRUE(testing::indexed_eq_n(N, expected, buf3)); for (unsigned i = 0; i<N; ++i) { expected[i] = mbuf2[i]? buf1[i]: buf4[i]; } - where(m2, simd(buf1)).copy_to(buf4); + simd b(buf1); + indirect(buf4, N) = where(m2, b); EXPECT_TRUE(testing::indexed_eq_n(N, expected, buf4)); } } @@ -875,7 +877,7 @@ typedef ::testing::Types< #ifdef __AVX512F__ simd<double, 8, simd_abi::avx512>, #endif -#if defined(__ARM_NEON) +#ifdef __ARM_NEON simd<double, 2, simd_abi::neon>, #endif @@ -921,7 +923,7 @@ TYPED_TEST_P(simd_indirect, gather) { fill_random(array, rng); fill_random(offset, rng, 0, (int)(buflen-1)); - simd s(indirect(array, simd_index(offset))); + simd s = simd_cast<simd>(indirect(array, simd_index(offset), N)); scalar test[N]; for (unsigned j = 0; j<N; ++j) { @@ -961,7 +963,8 @@ TYPED_TEST_P(simd_indirect, masked_gather) { simd s(original); simd_mask m(mask); - where(m, s).copy_from(indirect(array, simd_index(offset))); + + where(m, s) = indirect(array, simd_index(offset), N); EXPECT_TRUE(::testing::indexed_eq_n(N, test, s)); } @@ -995,7 +998,7 @@ TYPED_TEST_P(simd_indirect, scatter) { } simd s(values); - s.copy_to(indirect(array, simd_index(offset))); + indirect(array, simd_index(offset), N) = s; EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); } @@ -1033,7 +1036,7 @@ TYPED_TEST_P(simd_indirect, masked_scatter) { simd s(values); simd_mask m(mask); - where(m, s).copy_to(indirect(array, simd_index(offset))); + indirect(array, simd_index(offset), N) = where(m, s); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); @@ -1041,7 +1044,8 @@ TYPED_TEST_P(simd_indirect, masked_scatter) { array[j] = test[j]; } - where(m, simd(values)).copy_to(indirect(array, simd_index(offset))); + simd v(values); + indirect(array, simd_index(offset), N) = where(m, v); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); } @@ -1074,7 +1078,7 @@ TYPED_TEST_P(simd_indirect, add_and_subtract) { test[offset[j]] += values[j]; } - indirect(array, simd_index(offset)) += simd(values); + indirect(array, simd_index(offset), N) += simd(values); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); fill_random(offset, rng, 0, (int)(buflen-1)); @@ -1086,7 +1090,7 @@ TYPED_TEST_P(simd_indirect, add_and_subtract) { test[offset[j]] -= values[j]; } - indirect(array, simd_index(offset)) -= simd(values); + indirect(array, simd_index(offset), N) -= simd(values); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); } } @@ -1137,7 +1141,7 @@ TYPED_TEST_P(simd_indirect, constrained_add) { } while (!unique_elements(offset)); make_test_array(); - indirect(array, simd_index(offset), index_constraint::independent) += simd(values); + 
indirect(array, simd_index(offset), N, index_constraint::independent) += simd(values); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); @@ -1149,7 +1153,7 @@ TYPED_TEST_P(simd_indirect, constrained_add) { } make_test_array(); - indirect(array, simd_index(offset), index_constraint::contiguous) += simd(values); + indirect(array, simd_index(offset), N, index_constraint::contiguous) += simd(values); EXPECT_TRUE(::testing::indexed_eq_n(buflen, test, array)); @@ -1168,7 +1172,7 @@ TYPED_TEST_P(simd_indirect, constrained_add) { } make_test_array(); - indirect(array, simd_index(offset), index_constraint::constant) += simd(values); + indirect(array, simd_index(offset), N, index_constraint::constant) += simd(values); EXPECT_TRUE(::testing::indexed_almost_eq_n(buflen, test, array)); @@ -1202,7 +1206,7 @@ typedef ::testing::Types< simd_and_index<simd<int, 8, simd_abi::avx512>, simd<int, 8, simd_abi::avx512>>, #endif -#if defined(__ARM_NEON) +#ifdef __ARM_NEON simd_and_index<simd<double, 2, simd_abi::neon>, simd<int, 2, simd_abi::neon>>, @@ -1288,7 +1292,7 @@ typedef ::testing::Types< simd_pair<simd<double, 8, simd_abi::avx512>, simd<int, 8, simd_abi::avx512>>, #endif -#if defined(__ARM_NEON) +#ifdef __ARM_NEON simd_pair<simd<double, 2, simd_abi::neon>, simd<int, 2, simd_abi::neon>>, #endif @@ -1298,3 +1302,525 @@ typedef ::testing::Types< > simd_casting_test_types; INSTANTIATE_TYPED_TEST_CASE_P(S, simd_casting, simd_casting_test_types); + + +// Sizeless simd types API tests + +template <typename T, typename V, unsigned N> +struct simd_t { + using simd_type = T; + using scalar_type = V; + static constexpr unsigned width = N; +}; + +template <typename T, typename I, typename M> +struct simd_types_t { + using simd_value = T; + using simd_index = I; + using simd_value_mask = M; +}; + +template <typename SI> +struct sizeless_api: public ::testing::Test {}; + +TYPED_TEST_CASE_P(sizeless_api); + +TYPED_TEST_P(sizeless_api, construct) { + using simd_value = typename TypeParam::simd_value::simd_type; + using scalar_value = typename TypeParam::simd_value::scalar_type; + + using simd_index = typename TypeParam::simd_index::simd_type; + using scalar_index = typename TypeParam::simd_index::scalar_type; + + constexpr unsigned N = TypeParam::simd_value::width; + + std::minstd_rand rng(1001); + + { + scalar_value a_in[N], a_out[N]; + fill_random(a_in, rng); + + simd_value av = simd_cast<simd_value>(indirect(a_in, N)); + + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, a_in, a_out)); + } + { + scalar_value a_in[2*N], b_in[N], a_out[N], exp_0[N], exp_1[2*N]; + fill_random(a_in, rng); + fill_random(b_in, rng); + + scalar_index idx[N]; + + auto make_test_indirect2simd = [&]() { + for (unsigned i = 0; i<N; ++i) { + exp_0[i] = a_in[idx[i]]; + } + }; + + auto make_test_simd2indirect = [&]() { + for (unsigned i = 0; i<2*N; ++i) { + exp_1[i] = a_in[i]; + } + for (unsigned i = 0; i<N; ++i) { + exp_1[idx[i]] = b_in[i]; + } + }; + + // Independent + for (unsigned i = 0; i < N; ++i) { + idx[i] = i*2; + } + simd_index idxv = simd_cast<simd_index>(indirect(idx, N)); + + make_test_indirect2simd(); + + simd_value av = simd_cast<simd_value>(indirect(a_in, idxv, N, index_constraint::independent)); + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, exp_0, a_out)); + + make_test_simd2indirect(); + + indirect(a_in, idxv, N, index_constraint::independent) = simd_cast<simd_value>(indirect(b_in, N)); + + EXPECT_TRUE(testing::indexed_eq_n(2*N, exp_1, a_in)); + + // contiguous + for (unsigned i = 
0; i < N; ++i) { + idx[i] = i; + } + idxv = simd_cast<simd_index>(indirect(idx, N)); + + make_test_indirect2simd(); + + av = simd_cast<simd_value>(indirect(a_in, idxv, N, index_constraint::contiguous)); + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, exp_0, a_out)); + + make_test_simd2indirect(); + + indirect(a_in, idxv, N, index_constraint::contiguous) = simd_cast<simd_value>(indirect(b_in, N)); + + EXPECT_TRUE(testing::indexed_eq_n(2*N, exp_1, a_in)); + + // none + for (unsigned i = 0; i < N; ++i) { + idx[i] = i/2; + } + + idxv = simd_cast<simd_index>(indirect(idx, N)); + + make_test_indirect2simd(); + + av = simd_cast<simd_value>(indirect(a_in, idxv, N, index_constraint::none)); + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, exp_0, a_out)); + + make_test_simd2indirect(); + + indirect(a_in, idxv, N, index_constraint::none) = simd_cast<simd_value>(indirect(b_in, N)); + + EXPECT_TRUE(testing::indexed_eq_n(2*N, exp_1, a_in)); + + // constant + for (unsigned i = 0; i < N; ++i) { + idx[i] = 0; + } + + idxv = simd_cast<simd_index>(indirect(idx, N)); + + make_test_indirect2simd(); + + av = simd_cast<simd_value>(indirect(a_in, idxv, N, index_constraint::constant)); + indirect(a_out, N) = av; + + EXPECT_TRUE(testing::indexed_eq_n(N, exp_0, a_out)); + + make_test_simd2indirect(); + + indirect(a_in, idxv, N, index_constraint::constant) = simd_cast<simd_value>(indirect(b_in, N)); + + EXPECT_TRUE(testing::indexed_eq_n(2*N, exp_1, a_in)); + } +} + +TYPED_TEST_P(sizeless_api, where_exp) { + using simd_value = typename TypeParam::simd_value::simd_type; + using scalar_value = typename TypeParam::simd_value::scalar_type; + + using simd_index = typename TypeParam::simd_index::simd_type; + using scalar_index = typename TypeParam::simd_index::scalar_type; + + using mask_simd = typename TypeParam::simd_value_mask::simd_type; + + constexpr unsigned N = TypeParam::simd_value::width; + + std::minstd_rand rng(201); + + bool m[N]; + fill_random(m, rng); + mask_simd mv = simd_cast<mask_simd>(indirect(m, N)); + + scalar_value a[N], b[N], exp[N]; + fill_random(a, rng); + fill_random(b, rng); + + { + bool c[N]; + indirect(c, N) = mv; + EXPECT_TRUE(testing::indexed_eq_n(N, c, m)); + } + + // where = constant + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + + where(mv, av) = 42.3; + indirect(c, N) = av; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? 42.3 : a[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // where = simd + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + simd_value bv = simd_cast<simd_value>(indirect(b, N)); + + where(mv, av) = bv; + indirect(c, N) = av; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? b[i] : a[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // simd = where + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + simd_value bv = simd_cast<simd_value>(indirect(b, N)); + + simd_value cv = simd_cast<simd_value>(where(mv, add(av, bv))); + indirect(c, N) = cv; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? (a[i] + b[i]) : 0; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // where = indirect + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + + where(mv, av) = indirect(b, N); + indirect(c, N) = av; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? 
b[i] : a[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // indirect = where + { + scalar_value c[N]; + fill_random(c, rng); + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + + indirect(c, N) = where(mv, av); + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? a[i] : c[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + + indirect(c, N) = where(mv, neg(av)); + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? -a[i] : c[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // where = indirect indexed + { + scalar_value c[N]; + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + + scalar_index idx[N]; + for (unsigned i =0; i<N; ++i) { + idx[i] = i/2; + } + simd_index idxv = simd_cast<simd_index>(indirect(idx, N)); + + where(mv, av) = indirect(b, idxv, N, index_constraint::none); + indirect(c, N) = av; + + for (unsigned i = 0; i<N; ++i) { + exp[i] = m[i]? b[idx[i]] : a[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } + + // indirect indexed = where + { + scalar_value c[N]; + fill_random(c, rng); + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + simd_value bv = simd_cast<simd_value>(indirect(b, N)); + + scalar_index idx[N]; + for (unsigned i =0; i<N; ++i) { + idx[i] = i; + } + simd_index idxv = simd_cast<simd_index>(indirect(idx, N)); + + indirect(c, idxv, N, index_constraint::contiguous) = where(mv, av); + + for (unsigned i = 0; i<N; ++i) { + exp[idx[i]] = m[i]? a[i] : c[idx[i]]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + + indirect(c, idxv, N, index_constraint::contiguous) = where(mv, sub(av, bv)); + + for (unsigned i = 0; i<N; ++i) { + exp[idx[i]] = m[i]? a[i] - b[i] : c[idx[i]]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, c, exp)); + } +} + +TYPED_TEST_P(sizeless_api, arithmetic) { + using simd_value = typename TypeParam::simd_value::simd_type; + using scalar_value = typename TypeParam::simd_value::scalar_type; + + constexpr unsigned N = TypeParam::simd_value::width; + + std::minstd_rand rng(201); + + scalar_value a[N], b[N], c[N], expected[N]; + fill_random(a, rng); + fill_random(b, rng); + + bool m[N], expected_m[N]; + fill_random(m, rng); + + simd_value av = simd_cast<simd_value>(indirect(a, N)); + simd_value bv = simd_cast<simd_value>(indirect(b, N)); + + // add + { + indirect(c, N) = add(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = a[i] + b[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // sub + { + indirect(c, N) = sub(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = a[i] - b[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // mul + { + indirect(c, N) = mul(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = a[i] * b[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // div + { + indirect(c, N) = div(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = a[i] / b[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // pow + { + indirect(c, N) = pow(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::pow(a[i], b[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // min + { + indirect(c, N) = min(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::min(a[i], b[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // max + { + indirect(c, N) = max(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::max(a[i], b[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, 
expected)); + } + // cmp_eq + { + indirect(m, N) = cmp_eq(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected_m[i] = a[i] == b[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, m, expected_m)); + } + // cmp_neq + { + indirect(m, N) = cmp_neq(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected_m[i] = a[i] != b[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, m, expected_m)); + } + // cmp_leq + { + indirect(m, N) = cmp_leq(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected_m[i] = a[i] <= b[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, m, expected_m)); + } + // cmp_geq + { + indirect(m, N) = cmp_geq(av, bv); + for (unsigned i = 0; i<N; ++i) { + expected_m[i] = a[i] >= b[i]; + } + EXPECT_TRUE(testing::indexed_eq_n(N, m, expected_m)); + } + // sum + { + auto s = sum(av); + scalar_value expected_sum = 0; + + for (unsigned i = 0; i<N; ++i) { + expected_sum += a[i]; + } + EXPECT_FLOAT_EQ(expected_sum, s); + } + // neg + { + indirect(c, N) = neg(av); + for (unsigned i = 0; i<N; ++i) { + expected[i] = -a[i]; + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // abs + { + indirect(c, N) = abs(av); + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::abs(a[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + // exp + { + indirect(c, N) = exp(av); + for (unsigned i = 0; i<N; ++i) { + EXPECT_NEAR(std::exp(a[i]), c[i], 1e-6); + } + } + // expm1 + { + indirect(c, N) = expm1(av); + for (unsigned i = 0; i<N; ++i) { + EXPECT_NEAR(std::expm1(a[i]), c[i], 1e-6); + } + } + // exprelr + { + indirect(c, N) = exprelr(av); + for (unsigned i = 0; i<N; ++i) { + EXPECT_NEAR(a[i]/(std::expm1(a[i])), c[i], 1e-6); + } + } + // log + { + scalar_value l[N]; + int max_exponent = std::numeric_limits<scalar_value>::max_exponent; + fill_random(l, rng, -max_exponent*std::log(2.), max_exponent*std::log(2.)); + for (auto& x: l) { + x = std::exp(x); + // SIMD log implementation may treat subnormal as zero + if (std::fpclassify(x)==FP_SUBNORMAL) x = 0; + } + simd_value lv = simd_cast<simd_value>(indirect(l, N)); + + indirect(c, N) = log(lv); + + for (unsigned i = 0; i<N; ++i) { + expected[i] = std::log(l[i]); + } + EXPECT_TRUE(testing::indexed_almost_eq_n(N, c, expected)); + } + +} + +REGISTER_TYPED_TEST_CASE_P(sizeless_api, construct, where_exp, arithmetic); + +typedef ::testing::Types< + +#ifdef __AVX__ + simd_types_t< simd_t< simd<double, 4, simd_abi::avx>, double, 4>, + simd_t< simd<int, 4, simd_abi::avx>, int, 4>, + simd_t<simd_mask<double, 4, simd_abi::avx>, int, 4>>, +#endif +#ifdef __AVX2__ + simd_types_t< simd_t< simd<double, 4, simd_abi::avx2>, double, 4>, + simd_t< simd<int, 4, simd_abi::avx2>, int, 4>, + simd_t<simd_mask<double, 4, simd_abi::avx2>, int, 4>>, +#endif +#ifdef __AVX512F__ + simd_types_t< simd_t< simd<double, 8, simd_abi::avx512>, double, 8>, + simd_t< simd<int, 8, simd_abi::avx512>, int, 8>, + simd_t<simd_mask<double, 8, simd_abi::avx512>, int, 8>>, +#endif +#ifdef __ARM_NEON + simd_types_t< simd_t< simd<double, 2, simd_abi::neon>, double, 2>, + simd_t< simd<int, 2, simd_abi::neon>, int, 2>, + simd_t<simd_mask<double, 2, simd_abi::neon>, double, 2>>, +#endif +#ifdef __ARM_FEATURE_SVE + simd_types_t< simd_t< simd<double, 0, simd_abi::sve>, double, 4>, + simd_t< simd<int, 0, simd_abi::sve>, int, 4>, + simd_t<simd_mask<double, 0, simd_abi::sve>, bool, 4>>, + + simd_types_t< simd_t< simd<double, 0, simd_abi::sve>, double, 8>, + simd_t< simd<int, 0, simd_abi::sve>, int, 8>, + simd_t<simd_mask<double, 0, simd_abi::sve>, bool, 8>>, +#endif + simd_types_t< 
simd_t< simd<double, 8, simd_abi::default_abi>, double, 8>, + simd_t< simd<int, 8, simd_abi::default_abi>, int, 8>, + simd_t<simd_mask<double, 8, simd_abi::default_abi>, bool, 8>> +> sizeless_api_test_types; + +INSTANTIATE_TYPED_TEST_CASE_P(S, sizeless_api, sizeless_api_test_types);
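A minimal sketch of the value-based interface these tests exercise, assuming the default ABI at width 8 (the last instantiation above); buffer names and fill values are illustrative only:

    #include <arbor/simd/simd.hpp>

    using namespace arb::simd;

    int main() {
        using simd_value = simd<double, 8, simd_abi::default_abi>;
        using simd_index = simd<int, 8, simd_abi::default_abi>;
        using mask_type  = simd_mask<double, 8, simd_abi::default_abi>;

        double a[8], b[16];
        int idx[8];
        bool m[8];
        for (int i = 0; i < 8; ++i) {
            a[i] = i; b[i] = b[i+8] = 10.*i; idx[i] = 2*i; m[i] = i%2==0;
        }

        simd_value av = simd_cast<simd_value>(indirect(a, 8));    // contiguous load of 8 lanes
        simd_index iv = simd_cast<simd_index>(indirect(idx, 8));  // load index vector
        simd_value gv = simd_cast<simd_value>(indirect(b, iv, 8, index_constraint::none)); // gather b[idx[i]]
        mask_type  mv = simd_cast<mask_type>(indirect(m, 8));     // load per-lane mask

        where(mv, av) = add(av, gv);   // a[i] += b[idx[i]] only where m[i] holds
        indirect(a, 8) = av;           // contiguous store back
    }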