diff --git a/arbor/include/arbor/simd/neon.hpp b/arbor/include/arbor/simd/neon.hpp index 22ae0e7c11ad4881f17e1820c9e2ff7cffb0f5b8..beaa00d113a1d51d377276588975946f637dca11 100644 --- a/arbor/include/arbor/simd/neon.hpp +++ b/arbor/include/arbor/simd/neon.hpp @@ -395,12 +395,7 @@ struct neon_double2 : implbase<neon_double2> { auto is_large = cmp_gt(x, broadcast(exp_maxarg)); auto is_small = cmp_lt(x, broadcast(exp_minarg)); - - bool a[2]; - a[0] = isnan(vgetq_lane_f64(x, 0)) == 0 ? 0 : 1; - a[1] = isnan(vgetq_lane_f64(x, 1)) == 0 ? 0 : 1; - - auto is_nan = mask_copy_from(a); + auto is_not_nan = cmp_eq(x, x); // Compute n and g. @@ -430,7 +425,7 @@ struct neon_double2 : implbase<neon_double2> { return ifelse(is_large, broadcast(HUGE_VAL), ifelse(is_small, broadcast(0), - ifelse(is_nan, broadcast(NAN), result))); + ifelse(is_not_nan, result, broadcast(NAN)))); } // Use same rational polynomial expansion as for exp(x), without @@ -443,12 +438,7 @@ struct neon_double2 : implbase<neon_double2> { static float64x2_t expm1(const float64x2_t& x) { auto is_large = cmp_gt(x, broadcast(exp_maxarg)); auto is_small = cmp_lt(x, broadcast(expm1_minarg)); - - bool a[2]; - a[0] = isnan(vgetq_lane_f64(x, 0)) == 0 ? 0 : 1; - a[1] = isnan(vgetq_lane_f64(x, 1)) == 0 ? 0 : 1; - - auto is_nan = mask_copy_from(a); + auto is_not_nan = cmp_eq(x, x); auto half = broadcast(0.5); auto one = broadcast(1.); @@ -484,9 +474,7 @@ struct neon_double2 : implbase<neon_double2> { return ifelse(is_large, broadcast(HUGE_VAL), ifelse(is_small, broadcast(-1), - ifelse(is_nan, broadcast(NAN), - ifelse(nzero, expgm1, scaled)))); - } + ifelse(is_not_nan, ifelse(nzero, expgm1, scaled), broadcast(NAN)))); // Natural logarithm: // @@ -514,11 +502,7 @@ struct neon_double2 : implbase<neon_double2> { auto is_small = cmp_lt(x, broadcast(log_minarg)); auto is_domainerr = cmp_lt(x, broadcast(0)); - bool a[2]; - a[0] = isnan(vgetq_lane_f64(x, 0)) == 0 ? 0 : 1; - a[1] = isnan(vgetq_lane_f64(x, 0)) == 0 ? 0 : 1; - - auto is_nan = mask_copy_from(a); + auto is_nan = logical_not(cmp_eq(x, x)); is_domainerr = logical_or(is_nan, is_domainerr); float64x2_t g = vcvt_f64_f32(vcvt_f32_s32(logb_normal(x)));