diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index 58cb0540a649f2c4bcf3e3fa3aedf4eaecaa2c28..5a72bab71e602ddfb63cd55dcd984d950fdf6738 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -36,56 +36,48 @@ namespace detail { } // Top level functions for second API -using detail::simd_impl; -using detail::simd_mask_impl; - -template <typename Impl, typename V> -void assign(simd_impl<Impl>& a, const detail::indirect_expression<V>& b) { - a.copy_from(b); -} - -template <typename Impl, typename ImplIndex, typename V> -void assign(simd_impl<Impl>& a, const detail::indirect_indexed_expression<ImplIndex, V>& b) { +template <typename Impl, typename Other> +void assign(detail::simd_impl<Impl>& a, const Other& b) { a.copy_from(b); } template <typename Impl> -typename simd_impl<Impl>::scalar_type sum(const simd_impl<Impl>& a) { +typename detail::simd_impl<Impl>::scalar_type sum(const detail::simd_impl<Impl>& a) { return a.sum(); }; #define ARB_UNARY_ARITHMETIC_(name)\ template <typename Impl>\ -simd_impl<Impl> name(const simd_impl<Impl>& a) {\ - return simd_impl<Impl>::wrap(Impl::name(a.value_));\ +detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a) {\ + return detail::simd_impl<Impl>::wrap(Impl::name(a.value_));\ }; #define ARB_BINARY_ARITHMETIC_(name)\ template <typename Impl>\ -simd_impl<Impl> name(const simd_impl<Impl>& a, simd_impl<Impl> b) {\ - return simd_impl<Impl>::wrap(Impl::name(a.value_, b.value_));\ +detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, detail::simd_impl<Impl> b) {\ + return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, b.value_));\ };\ template <typename Impl>\ -simd_impl<Impl> name(const simd_impl<Impl>& a, typename simd_impl<Impl>::scalar_type b) {\ - return simd_impl<Impl>::wrap(Impl::name(a.value_, Impl::broadcast(b)));\ +detail::simd_impl<Impl> name(const detail::simd_impl<Impl>& a, typename detail::simd_impl<Impl>::scalar_type b) {\ + return detail::simd_impl<Impl>::wrap(Impl::name(a.value_, Impl::broadcast(b)));\ };\ template <typename Impl>\ -simd_impl<Impl> name(const typename simd_impl<Impl>::scalar_type a, simd_impl<Impl> b) {\ - return simd_impl<Impl>::wrap(Impl::name(Impl::broadcast(a), b.value_));\ +detail::simd_impl<Impl> name(const typename detail::simd_impl<Impl>::scalar_type a, detail::simd_impl<Impl> b) {\ + return detail::simd_impl<Impl>::wrap(Impl::name(Impl::broadcast(a), b.value_));\ }; #define ARB_BINARY_COMPARISON_(name)\ template <typename Impl>\ -typename simd_impl<Impl>::simd_mask name(const simd_impl<Impl>& a, simd_impl<Impl> b) {\ - return simd_impl<Impl>::mask(Impl::name(a.value_, b.value_));\ +typename detail::simd_impl<Impl>::simd_mask name(const detail::simd_impl<Impl>& a, detail::simd_impl<Impl> b) {\ + return detail::simd_impl<Impl>::mask(Impl::name(a.value_, b.value_));\ };\ template <typename Impl>\ -typename simd_impl<Impl>::simd_mask name(const simd_impl<Impl>& a, typename simd_impl<Impl>::scalar_type b) {\ - return simd_impl<Impl>::mask(Impl::name(a.value_, Impl::broadcast(b)));\ +typename detail::simd_impl<Impl>::simd_mask name(const detail::simd_impl<Impl>& a, typename detail::simd_impl<Impl>::scalar_type b) {\ + return detail::simd_impl<Impl>::mask(Impl::name(a.value_, Impl::broadcast(b)));\ };\ template <typename Impl>\ -typename simd_impl<Impl>::simd_mask name(const typename simd_impl<Impl>::scalar_type a, simd_impl<Impl> b) {\ - return simd_impl<Impl>::mask(Impl::name(Impl::broadcast(a), b.value_));\ +typename detail::simd_impl<Impl>::simd_mask name(const typename detail::simd_impl<Impl>::scalar_type a, detail::simd_impl<Impl> b) {\ + return detail::simd_impl<Impl>::mask(Impl::name(Impl::broadcast(a), b.value_));\ }; ARB_PP_FOREACH(ARB_BINARY_ARITHMETIC_, add, sub, mul, div, pow, max, min) @@ -97,23 +89,23 @@ ARB_PP_FOREACH(ARB_UNARY_ARITHMETIC_, neg, abs, sin, cos, exp, log, expm1, expr #undef ARB_UNARY_ARITHMETIC_ template <typename T> -simd_mask_impl<T> logical_and(const simd_mask_impl<T>& a, simd_mask_impl<T> b) { +detail::simd_mask_impl<T> logical_and(const detail::simd_mask_impl<T>& a, detail::simd_mask_impl<T> b) { return a && b; } template <typename T> -simd_mask_impl<T> logical_or(const simd_mask_impl<T>& a, simd_mask_impl<T> b) { +detail::simd_mask_impl<T> logical_or(const detail::simd_mask_impl<T>& a, detail::simd_mask_impl<T> b) { return a || b; } template <typename T> -simd_mask_impl<T> logical_not(const simd_mask_impl<T>& a) { +detail::simd_mask_impl<T> logical_not(const detail::simd_mask_impl<T>& a) { return !a; } template <typename T> -simd_impl<T> fma(const simd_impl<T> a, simd_impl<T> b, simd_impl<T> c) { - return simd_impl<T>::wrap(T::fma(a.value_, b.value_, c.value_)); +detail::simd_impl<T> fma(const detail::simd_impl<T> a, detail::simd_impl<T> b, detail::simd_impl<T> c) { + return detail::simd_impl<T>::wrap(T::fma(a.value_, b.value_, c.value_)); } namespace detail { @@ -469,6 +461,15 @@ namespace detail { Impl::scatter(tag<IndexImpl>{}, value_, pi.p, pi.index); } + template <typename Other, typename = std::enable_if_t<width==simd_traits<Other>::width>> + void copy_from(const simd_impl<Other>& x) { + value_ = Impl::cast_from(tag<Other>{}, x.value_); + } + + void copy_from(const scalar_type p) { + value_ = Impl::broadcast(p); + } + void copy_from(const scalar_type* p) { value_ = Impl::copy_from(p); } diff --git a/arbor/include/arbor/simd/sve.hpp b/arbor/include/arbor/simd/sve.hpp index 7a5c0ec9d59bf7c2beae5a638421306b1196e19a..6d2bd34b3776d988be0c5339be455c6e5a4d20c8 100644 --- a/arbor/include/arbor/simd/sve.hpp +++ b/arbor/include/arbor/simd/sve.hpp @@ -827,6 +827,10 @@ class const_where_expression; template <typename To> struct simd_cast_impl { + static To cast(const To& a) { + return a; + } + template <typename V> static To cast(const V& a) { return detail::sve_type_to_impl<To>::type::broadcast(a); @@ -887,13 +891,8 @@ struct simd_cast_impl { } }; -template <typename T, typename V> -void assign(T& a, const detail::indirect_expression<V>& b) { - a = detail::simd_cast_impl<T>::cast(b); -} - -template <typename T, typename I, typename V> -void assign(T& a, const detail::indirect_indexed_expression<I, V>& b) { +template <typename T, typename Other> +void assign(T& a, const Other& b) { a = detail::simd_cast_impl<T>::cast(b); } diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 73ab685dc1b98e5d9b951d0064b27e1a69e9a0b2..b2a51c1f2b3f68185539717b317f91f2998a0619 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -182,7 +182,8 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "namespace S = ::arb::simd;\n" "using S::index_constraint;\n" "using S::simd_cast;\n" - "using S::indirect;\n"; + "using S::indirect;\n" + "using S::assign;\n"; out << "static constexpr unsigned vector_length_ = "; if (opt.simd.size == no_size) { @@ -574,10 +575,8 @@ void SimdPrinter::visit(LocalVariable* sym) { void SimdPrinter::visit(VariableExpression *sym) { if (sym->is_range()) { - if(is_indirect_) - out_ << "simd_cast<simd_value>(indirect(" << sym->name() << "+index_, simd_width_))"; - else - out_ << "simd_cast<simd_value>(indirect(" << sym->name() << "+i_, simd_width_))"; + auto index = is_indirect_? "index_": "i_"; + out_ << "simd_cast<simd_value>(indirect(" << sym->name() << "+" << index << ", simd_width_))"; } else { out_ << sym->name(); @@ -615,10 +614,19 @@ void SimdPrinter::visit(AssignmentExpression* e) { out_ << ")"; } else { - out_ << lhs->name() << " = "; - if (cast) out_ << "simd_cast<simd_value>("; + out_ << "assign(" << lhs->name() << ", "; + if (auto rhs = e->rhs()->is_identifier()) { + if (auto sym = rhs->symbol()) { + // We shouldn't call the rhs visitor in this case because it automatically casts indirect expressions + if (sym->is_variable() && sym->is_variable()->is_range()) { + auto index = is_indirect_ ? "index_" : "i_"; + out_ << "indirect(" << rhs->name() << "+" << index << ", simd_width_))"; + return; + } + } + } e->rhs()->accept(this); - if (cast) out_ << ")"; + out_ << ")"; } } @@ -685,7 +693,8 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con << "[0]);\n"; } else if (constraint == simd_expr_constraint::contiguous) { - out << " = simd_cast<simd_value>(indirect(" << d.data_var + out << ";\n" + << "assign(" << local->name() << ", indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_));\n"; } @@ -695,7 +704,8 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con << "element0]);\n"; } else { - out << " = simd_cast<simd_value>(indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", simd_width_, constraint_category_));\n"; + out << ";\n" + << "assign(" << local->name() << ", indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", simd_width_, constraint_category_));\n"; } if (d.scale != 1) { @@ -721,7 +731,8 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex std::string tempvar = "t_"+external->name(); if (constraint == simd_expr_constraint::contiguous) { - out << "simd_value "<< tempvar <<" = simd_cast<simd_value>(indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_));\n"; + out << "simd_value "<< tempvar <<";\n" + << "assign(" << tempvar << ", indirect(" << d.data_var << " + " << d.index_var << "[index_], simd_width_));\n"; if (coeff!=1) { out << tempvar << " = S::fma(S::mul(w_, simd_cast<simd_value>(" @@ -828,7 +839,8 @@ void emit_for_loop_per_constraint(std::ostream& out, BlockExpression* body, out << "index_type index_ = index_constraints_." << underlying_constraint_name << "[i_];\n"; if (requires_weight) { - out << "simd_value w_ = simd_cast<simd_value>(indirect((weight_+index_), simd_width_));\n"; + out << "simd_value w_;\n" + << "assign(w_, indirect((weight_+index_), simd_width_));\n"; } emit_body_for_loop(out, body, indexed_vars, scalars, indices, read_constraint, write_constraint);