diff --git a/include/arbor/simd/avx512.hpp b/include/arbor/simd/avx512.hpp
index 0f062d11b9c9b35cf7d209722e1332c02fdc67b1..e00871d4a49d52ad4899921fd2b633af83d2616f 100644
--- a/include/arbor/simd/avx512.hpp
+++ b/include/arbor/simd/avx512.hpp
@@ -381,19 +381,19 @@ struct avx512_int8: implbase<avx512_int8> {
 
     // Specialized 8-wide gather and scatter for avx512_int8 implementation.
 
-    static __m512i gather(avx512_int8, const int32* p, const __m512i& index) {
+    static __m512i gather(tag<avx512_int8>, const int32* p, const __m512i& index) {
         return _mm512_mask_i32gather_epi32(_mm512_setzero_epi32(), lo(), index, p, 4);
     }
 
-    static __m512i gather(avx512_int8, __m512i a, const int32* p, const __m512i& index, const __mmask8& mask) {
+    static __m512i gather(tag<avx512_int8>, __m512i a, const int32* p, const __m512i& index, const __mmask8& mask) {
         return _mm512_mask_i32gather_epi32(a, mask, index, p, 4);
     }
 
-    static void scatter(avx512_int8, const __m512i& s, int32* p, const __m512i& index) {
+    static void scatter(tag<avx512_int8>, const __m512i& s, int32* p, const __m512i& index) {
         _mm512_mask_i32scatter_epi32(p, lo(), index, s, 4);
     }
 
-    static void scatter(avx512_int8, const __m512i& s, int32* p, const __m512i& index, const __mmask8& mask) {
+    static void scatter(tag<avx512_int8>, const __m512i& s, int32* p, const __m512i& index, const __mmask8& mask) {
         _mm512_mask_i32scatter_epi32(p, mask, index, s, 4);
     }
 };
@@ -561,19 +561,19 @@ struct avx512_double8: implbase<avx512_double8> {
 
     // Specialized 8-wide gather and scatter for avx512_int8 implementation.
 
-    static __m512d gather(avx512_int8, const double* p, const __m512i& index) {
+    static __m512d gather(tag<avx512_int8>, const double* p, const __m512i& index) {
         return _mm512_i32gather_pd(_mm512_castsi512_si256(index), p, 8);
     }
 
-    static __m512d gather(avx512_int8, __m512d a, const double* p, const __m512i& index, const __mmask8& mask) {
+    static __m512d gather(tag<avx512_int8>, __m512d a, const double* p, const __m512i& index, const __mmask8& mask) {
         return _mm512_mask_i32gather_pd(a, mask, _mm512_castsi512_si256(index), p, 8);
     }
 
-    static void scatter(avx512_int8, const __m512d& s, double* p, const __m512i& index) {
+    static void scatter(tag<avx512_int8>, const __m512d& s, double* p, const __m512i& index) {
         _mm512_i32scatter_pd(p, _mm512_castsi512_si256(index), s, 8);
     }
 
-    static void scatter(avx512_int8, const __m512d& s, double* p, const __m512i& index, const __mmask8& mask) {
+    static void scatter(tag<avx512_int8>, const __m512d& s, double* p, const __m512i& index, const __mmask8& mask) {
         _mm512_mask_i32scatter_pd(p, mask, _mm512_castsi512_si256(index), s, 8);
     }
 