Skip to content
Snippets Groups Projects
Commit 02d41881 authored by Vasileios Karakasis's avatar Vasileios Karakasis Committed by Ben Cumming
Browse files

modcc: AVX2 backend for mechanisms (#282)

Add AVX2 intrinsics back end for Haswell and Broadwell architectures.

We're still 3.5% and 5% slower than the icc `#pragma` version on Haswell and Broadwell,
respectively.
parent bd1e56a5
No related branches found
No related tags found
No related merge requests found
......@@ -15,8 +15,7 @@ elseif(NMC_VECTORIZE_TARGET STREQUAL "AVX")
set(modcc_opt "-O")
set(modcc_target "cpu")
elseif(NMC_VECTORIZE_TARGET STREQUAL "AVX2")
set(modcc_opt "-O")
set(modcc_target "cpu")
set(modcc_target "avx2")
else()
set(modcc_target "cpu")
endif()
......
//
// AVX2 backend
//
#pragma once
#include "backends/base.hpp"
namespace nest {
namespace mc {
namespace modcc {
// Specialize for the different architectures
// AVX2 specialization of the SIMD code-emission back end.
//
// Every emit_* function only appends the *text* of an intrinsic call (or of
// a small scalar fallback) to the supplied TextBuffer; nothing is evaluated
// here. Values are emitted as `__m256d` and indices as `__m128i`.
template<>
struct simd_intrinsics<targetKind::avx2> {
    // No scatter intrinsic is emitted for this target; emit_scatter() writes
    // a scalar loop instead, so callers are told scatter is unavailable.
    static bool has_scatter() {
        return false;
    }

    static bool has_gather() {
        return true;
    }

    // Include line required by the generated code for the intrinsics below.
    static std::string emit_headers() {
        return "#include <immintrin.h>";
    }

    // Vector register width, as text.
    static std::string emit_simd_width() {
        return "256";
    }

    static std::string emit_value_type() {
        return "__m256d";
    }

    static std::string emit_index_type() {
        return "__m128i";
    }

    // Emits `<intrinsic>(arg1, arg2)` for the arithmetic operator `op`.
    // Throws std::invalid_argument for any operator other than
    // plus, minus, times and divide.
    template<typename T1, typename T2>
    static void emit_binary_op(TextBuffer& tb, tok op,
                               const T1& arg1, const T2& arg2) {
        switch (op) {
        case tok::plus:
            tb << "_mm256_add_pd(";
            break;
        case tok::minus:
            tb << "_mm256_sub_pd(";
            break;
        case tok::times:
            tb << "_mm256_mul_pd(";
            break;
        case tok::divide:
            tb << "_mm256_div_pd(";
            break;
        default:
            throw std::invalid_argument("Unknown binary operator");
        }

        emit_operands(tb, arg_emitter(arg1), arg_emitter(arg2));
        tb << ")";
    }

    // Emits the call for the unary operator `op` applied to `arg`.
    // Throws std::invalid_argument for any operator other than
    // minus, exp and log.
    // NOTE(review): _mm256_exp_pd/_mm256_log_pd are vector math library
    // routines rather than plain AVX2 instructions -- confirm the target
    // compiler provides them.
    template<typename T>
    static void emit_unary_op(TextBuffer& tb, tok op, const T& arg) {
        switch (op) {
        case tok::minus:
            // Negation is emitted as a subtraction from a broadcast zero.
            tb << "_mm256_sub_pd(_mm256_set1_pd(0), ";
            break;
        case tok::exp:
            tb << "_mm256_exp_pd(";
            break;
        case tok::log:
            tb << "_mm256_log_pd(";
            break;
        default:
            throw std::invalid_argument("Unknown unary operator");
        }

        emit_operands(tb, arg_emitter(arg));
        tb << ")";
    }

    // Emits `_mm256_pow_pd(base, exp)`.
    template<typename B, typename E>
    static void emit_pow(TextBuffer& tb, const B& base, const E& exp) {
        tb << "_mm256_pow_pd(";
        emit_operands(tb, arg_emitter(base), arg_emitter(exp));
        tb << ")";
    }

    // Emits an unaligned vector store of `value` to `addr`.
    template<typename A, typename V>
    static void emit_store_unaligned(TextBuffer& tb, const A& addr,
                                     const V& value) {
        tb << "_mm256_storeu_pd(";
        emit_operands(tb, arg_emitter(addr), arg_emitter(value));
        tb << ")";
    }

    // Emits an unaligned vector load from `addr`.
    template<typename A>
    static void emit_load_unaligned(TextBuffer& tb, const A& addr) {
        tb << "_mm256_loadu_pd(";
        emit_operands(tb, arg_emitter(addr));
        tb << ")";
    }

    // Emits a load of an index vector from `addr`.
    template<typename A>
    static void emit_load_index(TextBuffer& tb, const A& addr) {
        tb << "_mm_lddqu_si128(";
        emit_operands(tb, arg_emitter(addr));
        tb << ")";
    }

    // Emits a braced scalar loop that writes each lane of `value` to
    // `addr[index[lane]]`. `scale` is accepted for interface parity with the
    // other back ends but is not used by the emitted code.
    //
    // Each call consumes two fresh temporary names (index pointer first,
    // value pointer second) so that several emitted scatters can coexist in
    // one generated scope.
    template<typename A, typename I, typename V, typename S>
    static void emit_scatter(TextBuffer& tb, const A& addr,
                             const I& index, const V& value, const S& scale) {
        // no support of scatter in AVX2, so revert to simple scalar updates
        std::string scalar_index_ptr = varprefix() + std::to_string(varcnt()++);
        std::string scalar_value_ptr = varprefix() + std::to_string(varcnt()++);

        tb.end_line("{");
        tb.increase_indentation();

        // FIXME: should probably read "index_type*"
        tb.add_gutter();
        tb << "int* " << scalar_index_ptr
           << " = (int*) &" << index;
        tb.end_line(";");

        tb.add_gutter();
        tb << "value_type* " << scalar_value_ptr
           << " = (value_type*) &" << value;
        tb.end_line(";");

        tb.add_line("for (int k_ = 0; k_ < simd_width; ++k_) {");
        tb.increase_indentation();
        tb.add_gutter();
        tb << addr << "[" << scalar_index_ptr << "[k_]] = "
           << scalar_value_ptr << "[k_]";
        tb.end_line(";");
        tb.decrease_indentation();
        tb.add_line("}");

        tb.decrease_indentation();
        tb.add_gutter();
        // The closing brace is left unterminated; the caller ends the line.
        tb << "}";
    }

    // Emits `_mm256_i32gather_pd(addr, index, scale)`.
    template<typename A, typename I, typename S>
    static void emit_gather(TextBuffer& tb, const A& addr,
                            const I& index, const S& scale) {
        tb << "_mm256_i32gather_pd(";
        emit_operands(tb, arg_emitter(addr), arg_emitter(index),
                      arg_emitter(scale));
        tb << ")";
    }

    // Emits a broadcast of the scalar `arg` to all lanes.
    template<typename T>
    static void emit_set_value(TextBuffer& tb, const T& arg) {
        tb << "_mm256_set1_pd(";
        emit_operands(tb, arg_emitter(arg));
        tb << ")";
    }

private:
    // The counter and prefix used to name emit_scatter()'s temporaries live
    // in function-local statics. Namespace-scope definitions of static data
    // members of a full specialization are ordinary (non-inline) definitions,
    // so keeping them in this header would produce duplicate symbols as soon
    // as the header is included from a second translation unit.

    // Running counter for generated temporary names; starts at 0.
    static int& varcnt() {
        static int counter = 0;
        return counter;
    }

    // Prefix of generated temporary names.
    static const std::string& varprefix() {
        static const std::string prefix = "_r";
        return prefix;
    }
};
}}} // closing namespaces
......@@ -15,7 +15,11 @@ namespace modcc {
template<>
struct simd_intrinsics<targetKind::avx512> {
static bool has_gather_scatter() {
static bool has_scatter() {
return true;
}
static bool has_gather() {
return true;
}
......
......@@ -79,7 +79,8 @@ struct simd_intrinsics {
template<typename T>
static void emit_set_value(TextBuffer& tb, const T& arg);
static bool has_gather_scatter();
static bool has_gather();
static bool has_scatter();
};
}}} // closing namespaces
#pragma once
#include "backends/avx2.hpp"
#include "backends/avx512.hpp"
......@@ -68,8 +68,12 @@ int main(int argc, char **argv) {
else if(targstr == "avx512") {
Options::instance().target = targetKind::avx512;
}
else if(targstr == "avx2") {
Options::instance().target = targetKind::avx2;
}
else {
std::cerr << red("error") << " target must be one in {cpu, gpu}\n";
std::cerr << red("error")
<< " target must be one in {cpu, gpu, avx2, avx512}\n";
return 1;
}
}
......@@ -153,6 +157,10 @@ int main(int argc, char **argv) {
text = SimdPrinter<targetKind::avx512>(
m, Options::instance().optimize).emit_source();
break;
case targetKind::avx2:
text = SimdPrinter<targetKind::avx2>(
m, Options::instance().optimize).emit_source();
break;
default :
std::cerr << red("error") << ": unknown printer" << std::endl;
exit(1);
......
......@@ -7,6 +7,7 @@ enum class targetKind {
cpu,
gpu,
// Vectorisation targets
avx2,
avx512
};
......
......@@ -259,6 +259,7 @@ void SimdPrinter<Arch>::emit_api_loop(APIMethod* e,
text_.add_gutter();
text_ << simd_backend::emit_index_type() << " "
<< vindex_name << " = ";
// FIXME: cast should better go inside `emit_load_index()`
simd_backend::emit_load_index(
text_, cast_type + "&" + index_ptr_name + "[off_]");
text_.end_line(";");
......@@ -289,7 +290,7 @@ void SimdPrinter<Arch>::emit_api_loop(APIMethod* e,
auto var = symbol.second->is_local_variable();
if (is_output(var) &&
!is_point_process() &&
simd_backend::has_gather_scatter()) {
simd_backend::has_scatter()) {
// We can safely use scatter, but we need to fetch the variable
// first
text_.add_line();
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment