Skip to content

Commit

Permalink
AVX10 QS8/QD8 GEMM/IGEMM
Browse files Browse the repository at this point in the history
- Rename from avx512skx to avx256skx which enables kernels on avx10 hardware
- avx256skx kernels also run on skylake or later with avx512VL

PiperOrigin-RevId: 648191801
  • Loading branch information
fbarchard authored and xnnpack-bot committed Jun 30, 2024
1 parent 3f3f040 commit 2d98d53
Show file tree
Hide file tree
Showing 104 changed files with 2,162 additions and 2,065 deletions.
91 changes: 47 additions & 44 deletions bench/qd8-f16-qc4w-gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,53 @@
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64


#if XNN_ENABLE_AVX256SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
static void qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx,
xnn_init_f16_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx)

static void qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx,
xnn_init_f16_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx)

static void qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx,
xnn_init_f16_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx)

static void qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx,
xnn_init_f16_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx)
#endif // XNN_ENABLE_AVX256SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)


#if XNN_ENABLE_AVXVNNI && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
static void qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
Expand Down Expand Up @@ -824,50 +871,6 @@


#if XNN_ARCH_X86 || XNN_ARCH_X86_64
static void qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx512skx,
xnn_init_f16_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx512skx)

static void qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx512skx,
xnn_init_f16_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx512skx)

static void qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx512skx,
xnn_init_f16_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx512skx)

static void qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx512skx,
xnn_init_f16_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx512skx)

static void qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2,
Expand Down
91 changes: 47 additions & 44 deletions bench/qd8-f16-qc8w-gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,53 @@
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64


#if XNN_ENABLE_AVX256SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
static void qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256skx,
xnn_init_f16_minmax_avx_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256skx)

static void qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256skx,
xnn_init_f16_minmax_avx_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256skx)

static void qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256skx,
xnn_init_f16_minmax_avx_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256skx)

static void qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256skx,
xnn_init_f16_minmax_avx_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256skx)
#endif // XNN_ENABLE_AVX256SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)


#if XNN_ENABLE_AVXVNNI && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
static void qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
Expand Down Expand Up @@ -692,50 +739,6 @@


#if XNN_ARCH_X86 || XNN_ARCH_X86_64
static void qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx512skx,
xnn_init_f16_minmax_avx_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx512skx)

static void qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx512skx,
xnn_init_f16_minmax_avx_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx512skx)

static void qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx512skx,
xnn_init_f16_minmax_avx_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx512skx)

static void qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx512skx,
xnn_init_f16_minmax_avx_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx512skx)

static void qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx2,
Expand Down
91 changes: 47 additions & 44 deletions bench/qd8-f32-qc4w-gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1869,6 +1869,53 @@
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64


#if XNN_ENABLE_AVX256SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
static void qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx,
xnn_init_f32_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx)

static void qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx,
xnn_init_f32_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx)

static void qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx,
xnn_init_f32_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx)

static void qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx,
xnn_init_f32_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX256SKX);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx)
#endif // XNN_ENABLE_AVX256SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)


#if XNN_ENABLE_AVXVNNI && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
static void qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
Expand Down Expand Up @@ -2137,50 +2184,6 @@

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_prfm)

static void qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx512skx,
xnn_init_f32_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx512skx)

static void qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx512skx,
xnn_init_f32_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx512skx)

static void qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx512skx,
xnn_init_f32_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx512skx)

static void qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx512skx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx512skx,
xnn_init_f32_qc4w_minmax_avx_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1,
benchmark::utils::CheckAVX512SKX);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx512skx)

static void qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2,
Expand Down
Loading

0 comments on commit 2d98d53

Please sign in to comment.