[RVV] Add qs8-gemm/igemm support for risc-v #7639

Status: Open · wants to merge 5 commits into base: master
3 changes: 2 additions & 1 deletion bench/gemm-benchmark.cc
@@ -141,7 +141,7 @@ void GEMMBenchmark(benchmark::State& state,
std::generate(b.begin(), b.end(), std::ref(i32rng));

const size_t w_element_size = sizeof(int8_t);
const size_t w_size = nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size;
const size_t w_size = nc_stride * (sizeof(float) + sizeof(int32_t)) + kc_stride * nc_stride * w_element_size;
const size_t c_elements = mc * nc;
const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp<size_t>(
benchmark::utils::GetMaxCacheSize(),
@@ -152,6 +152,7 @@ void GEMMBenchmark(benchmark::State& state,
const xnn_qs8_packing_params packing_params = {int8_t(127 - 0x80)};
pack(/*g=*/1, nc, kc, nr, kr, sr, k.data(), b.data(), /*scale=*/nullptr,
w.data(), nr * sizeof(float), &packing_params);

xnnpack::Buffer<int8_t> c(c_elements * num_buffers);

union xnn_qs8_qc8w_conv_minmax_params quantization_params;
26 changes: 26 additions & 0 deletions bench/qs8-qc8w-gemm-fp32.cc
@@ -4648,6 +4648,32 @@ static void qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf(benchmark::Stat

BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf)

#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR

static void qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv,
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/1, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
/*isa_check=*/nullptr);
}

BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv)

static void qs8_qc8w_gemm_minmax_fp32_ukernel_4x4v__rvv(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4v__rvv,
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
xnn_pack_qs8_gemm_goi_w,
/*mr=*/4, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
/*isa_check=*/nullptr);
}

BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_4x4v__rvv)

#endif // XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR

#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif
4 changes: 4 additions & 0 deletions cmake/gen/rvv_microkernels.cmake
@@ -53,6 +53,10 @@ SET(PROD_RVV_MICROKERNEL_SRCS
src/f32-vrnd/gen/f32-vrndz-rvv-u4v.c
src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-rvv-u2v.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x4v-minmax-fp32-rvv.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x4v-minmax-fp32-rvv.c
src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x4v-minmax-fp32-rvv.c
src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4v-minmax-fp32-rvv.c
src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c
src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c
src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c
4 changes: 4 additions & 0 deletions scripts/generate-qs8-gemm.sh
@@ -174,6 +174,10 @@ tools/xngen src/qs8-gemm/rvv.c.in -D MR=5 -D NR=m4 -D -D DATATYPE=QD8 -o src/qd
tools/xngen src/qs8-gemm/rvv.c.in -D MR=6 -D NR=m4 -D -D DATATYPE=QD8 -o src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c &
tools/xngen src/qs8-gemm/rvv.c.in -D MR=7 -D NR=m4 -D -D DATATYPE=QD8 -o src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c &
tools/xngen src/qs8-gemm/rvv.c.in -D MR=8 -D NR=m4 -D -D DATATYPE=QD8 -o src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c &

tools/xngen src/qs8-gemm/rvv.c.in -D MR=1 -D NR=m4 -D REQUANTIZATION=FP32 -D DATATYPE=QC8 -o src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x4v-minmax-fp32-rvv.c &
tools/xngen src/qs8-gemm/rvv.c.in -D MR=4 -D NR=m4 -D REQUANTIZATION=FP32 -D DATATYPE=QC8 -o src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x4v-minmax-fp32-rvv.c &

################################## WAsm SIMD ##################################
### C2 micro-kernels
tools/xngen src/qs8-gemm/MRx4c2-wasmsimd-dot16x2.c.in -D MR=1 -D VARIANT=LD64 -D REQUANTIZATION= -D DATATYPE=QD8 -o src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x4c2-minmax-wasmsimd-dot16x2-ld64.c &
5 changes: 5 additions & 0 deletions scripts/generate-qs8-igemm.sh
@@ -1152,4 +1152,9 @@ tools/xngen src/qs8-igemm/c4-avx512amx.c.in -D GFNI=0 -D PREFETCH=0 -D MR=7 -D
tools/xngen src/qs8-igemm/c4-avx512amx.c.in -D GFNI=0 -D PREFETCH=0 -D MR=16 -D NR=64 -D DATATYPE=QD8_F16 -D REQUANTIZATION= -o src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx.c &
tools/xngen src/qs8-igemm/c4-avx512amx.c.in -D GFNI=0 -D PREFETCH=1 -D MR=16 -D NR=64 -D DATATYPE=QD8_F16 -D REQUANTIZATION= -o src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c &

################################ RISC-V Vector ################################
tools/xngen src/qs8-igemm/rvv.c.in -D MR=1 -D NR=m4 -D REQUANTIZATION=FP32 -D DATATYPE=QC8 -o src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x4v-minmax-fp32-rvv.c &
tools/xngen src/qs8-igemm/rvv.c.in -D MR=4 -D NR=m4 -D REQUANTIZATION=FP32 -D DATATYPE=QC8 -o src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4v-minmax-fp32-rvv.c &


wait
12 changes: 12 additions & 0 deletions src/configs/gemm-config.c
@@ -3785,6 +3785,18 @@ static void init_qs8_qc8w_gemm_config(void) {
qs8_qc8w_gemm_config.mr = 4;
qs8_qc8w_gemm_config.nr = 4;
}
#elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv);
qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4v__rvv);
qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4v__rvv);
qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4v__rvv);
qs8_qc8w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params;
qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w;
qs8_qc8w_gemm_config.pack_igemm_kgo = (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w;
qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_deconv_goki_w;
qs8_qc8w_gemm_config.mr = 4;
qs8_qc8w_gemm_config.nr = 4 * hardware_config->vlenb / sizeof(int32_t);
#else
qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4__scalar_lrintf);
qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(3)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4__scalar_lrintf);
215 changes: 152 additions & 63 deletions src/qs8-gemm/rvv.c.in
@@ -1,46 +1,59 @@
// Copyright 2024 SiFive, Inc.
// Copyright 2024 Microchip
// Copyright 2025 Microchip

//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QD8", "QC4"]
$assert DATATYPE in ["QD8", "QC4"] or REQUANTIZATION in ["FP32"]
$assert DATATYPE in ["QC8", "QD8", "QC4", "QU8", "QS8"]
$assert MR >= 1
$assert NR in ["m4", "m8"]
$OUT_LMUL = NR
$IN_LMUL = {"m4": "m1", "m8": "m2"}[OUT_LMUL]
Collaborator:

This doesn't look like most XNNPACK kernels with respect to LMUL, though I do see that it is similar to f32-gemm.
Consider taking LMUL as a parameter and generating 1v, 2v, 4v, and 8v variations of the kernel.
This can be done as a follow-up.

Contributor:

I think it does take it as a parameter, just under the name NR:
$assert NR in ["m4", "m8"]

Author:

Right. I followed the pattern used for the other RVV generators, which seemed to make sense to me: NR provides the LMUL to use. So rather than NR=4, NR=m4 signifies this better (imho); on the K1 (vlen=256) it results in nr = 4 * (256/8) / sizeof(int32_t) = 32.
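
(For reference, and not part of the diff: a minimal C sketch of how the runtime tile width falls out of the m4 choice, using the same expression this PR adds to gemm-config.c and the benchmark registrations; vlenb is the vector register length in bytes.)

#include <stddef.h>
#include <stdint.h>

// nr for an NR=m4 (LMUL=4) kernel: one register group spans 4 * vlenb bytes,
// i.e. (4 * vlenb) / sizeof(int32_t) int32 accumulators across the output tile.
size_t rvv_nr_m4(size_t vlenb) {
  return 4 * vlenb / sizeof(int32_t);
}
// Example: vlen = 256 bits -> vlenb = 32 -> nr = 4 * 32 / 4 = 32 (the K1 case above).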

$INTERMEDIATE_MLUL = {"m4": "m2", "m8": "m4"}[OUT_LMUL]
$INTER_LMUL = {"m4": "m2", "m8": "m4"}[OUT_LMUL]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/gemm.h"
#include "xnnpack/math.h"

$DATATYPE_SPEC = {"QD8": "qd8_f32_qc8w", "QC4": "qd8_f32_qc4w"}[DATATYPE]
$PARAMS_TYPE = {"QD8": "union xnn_f32_minmax_params", "QC4": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE]
void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${OUT_LMUL[1]}v__rvv(
$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QD8": "qd8_f32_qc8w", "QC4": "qd8_f32_qc4w", "QS8": "qs8", "QU8": "qu8"}[DATATYPE]
$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QD8": "union xnn_f32_minmax_params", "QC4": "struct xnn_f32_qc4w_minmax_params", "QS8": "union xnn_qs8_conv_minmax_params", "QU8": "union xnn_qu8_conv_minmax_params"}[DATATYPE]
$if DATATYPE in ["QC8", "QS8", "QU8"]:
$REQUANTIZATION_SPEC = "_" + REQUANTIZATION.lower() if REQUANTIZATION else ""
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar"
$else:
$REQUANTIZATION_SPEC = ""
$PARAMS_STRUCT = ""
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$OUT_T = {"QC8": "int8_t", "QD8": "float", "QC4": "float", "QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${OUT_LMUL[1]}v__rvv(
size_t mr,
size_t nc,
size_t kc,
const int8_t* restrict a,
const ${XINT8_T}* restrict a,
size_t a_stride,
const void* restrict w,
float* restrict c,
${OUT_T}* restrict c,
size_t cm_stride,
size_t cn_stride,
const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)],
const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)])
$if DATATYPE in ["QD8", "QC4"]:
const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)],
const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)])
$else:
const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(mr != 0);
assert(mr <= ${MR});
assert(nc != 0);
assert(kc != 0);

const int8_t* a0 = a;
float* c0 = c;
const ${XINT8_T}* a0 = a;
${OUT_T}* c0 = c;
$for M in range(1, MR):
const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride);
${OUT_T}* c${M} = (${OUT_T}*) ((uintptr_t) c${M-1} + cm_stride);
$if M % 2 == 0:
if XNN_UNPREDICTABLE(mr <= ${M}) {
a${M} = a${M-1};
@@ -59,20 +59,36 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${OUT_LMUL[1]}v__rvv(

const size_t nr = __riscv_vsetvlmax_e32${OUT_LMUL}();
size_t vl = nr;
$if DATATYPE == "QC4":

$if DATATYPE not in ["QD8", "QC4"]:
$if REQUANTIZATION == "FP32":
$if DATATYPE != "QC8":
const float vscale = params->${PARAMS_STRUCT}.scale;
const int32_t output_min_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_min - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
const int32_t output_max_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
const int32_t output_zero_point = params->${PARAMS_STRUCT}.output_zero_point;
$if DATATYPE == "QU8":
const int32_t vb_zero_point = params->${PARAMS_STRUCT}.kernel_zero_point;
$elif DATATYPE == "QC4":
kc = round_up_po2(kc, 2);
do {
if XNN_UNLIKELY(nc < nr) {
vl = __riscv_vsetvl_e32${OUT_LMUL}(nc);
}
nc = nc - vl;

vint32${OUT_LMUL}_t vksum = __riscv_vle32_v_i32${OUT_LMUL}((const int32_t*)w, vl);
$if DATATYPE in ["QD8", "QC4"]:
vint32${OUT_LMUL}_t vksum = __riscv_vle32_v_i32${OUT_LMUL}((const int32_t*)w, vl);
$for M in range(MR):
const int32_t vinput_zero_point${M} = quantization_params[${M}].zero_point;
$for M in range(MR):
vint32${OUT_LMUL}_t vacc${M} = __riscv_vmul_vx_i32${OUT_LMUL}(vksum, vinput_zero_point${M}, vl);
$else:
vint32${OUT_LMUL}_t vacc0 = __riscv_vle32_v_i32${OUT_LMUL}((const int32_t*)w, vl);
$for M in range(1, MR):
vint32${OUT_LMUL}_t vacc${M} = vacc0;

w = (const int32_t*) w + nr;
$for M in range(MR):
const int32_t vinput_zero_point${M} = quantization_params[${M}].zero_point;
$for M in range(MR):
vint32${OUT_LMUL}_t vacc${M} = __riscv_vmul_vx_i32${OUT_LMUL}(vksum, vinput_zero_point${M}, vl);

size_t k = kc;
$if DATATYPE == "QC4":
@@ -87,60 +87,120 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${OUT_LMUL[1]}v__rvv(
const vint8${IN_LMUL}_t vbc1 = __riscv_vand_vx_i8${IN_LMUL}(vbi, 0xF0, vl);

$for M in range(MR):
vint16${INTERMEDIATE_MLUL}_t va${M}bc0 = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vbc0, va${M}c0, vl);
vint16${INTER_LMUL}_t va${M}bc0 = __riscv_vwmul_vx_i16${INTER_LMUL}(vbc0, va${M}c0, vl);
vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}bc0, vl);
vint16${INTERMEDIATE_MLUL}_t va${M}bc1 = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vbc1, va${M}c1, vl);
vint16${INTER_LMUL}_t va${M}bc1 = __riscv_vwmul_vx_i16${INTER_LMUL}(vbc1, va${M}c1, vl);
vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}bc1, vl);
}
$else:
do {
$for M in range(MR):
const int8_t va${M} = *a${M}++;
const vint8${IN_LMUL}_t vb = __riscv_vle8_v_i8${IN_LMUL}((const int8_t*) w, vl);
w = (const int8_t*) w + nr;
$if DATATYPE == "QU8":
const int32_t va${M} = (int32_t)(uint32_t) *a${M}++;
$else:
const int32_t va${M} = (int32_t) *a${M}++;
Collaborator:

The input is read a single byte at a time, so this will be extremely slow. It may be a good starting point, but it will not be fast enough to use in production. There are various alternatives:

  1. NEON reads a full vector and uses lanes to get at each byte of the vector. Is there something similar where multiple bytes are read into a vector and then multiplied against values in that vector? e.g. read 64 bits, unroll the loop to 8 channels, and multiply the weights against each byte of the input.
  2. KR=8 with partial sums (sketched below). On Intel we read 8 bytes of input, broadcast those 8 bytes to a full vector and multiply full vectors. Then, outside the loop, do horizontal sums of 8 values (int32) to produce the final single output value.
  3. SR=4. On Intel the 'shuffle' kernels read full vectors, multiply/accumulate full vectors, and rotate the input after each multiply. This idea could be extended to RVV, but it would be a little complicated, not knowing the vlen until runtime, and it affects packing.
  4. Dot product. Ultimately we want to use a dot-product instruction, which would be C4. On Intel we extend that to C8 using a partial sum, and we also have a python generator for CPUs that don't have dot product, implementing the dot-product instruction with a function.
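
(Not part of the PR: a minimal scalar C sketch of the KR=8 partial-sums data flow from point 2, for a single output value. The vector broadcast, the matching weight repacking, and the reduction instructions are left out; the helper name and the kc % 8 == 0 assumption are mine.)

#include <stddef.h>
#include <stdint.h>

// KR=8 idea: keep 8 running partial sums across the k-loop and reduce them
// once after the loop, instead of accumulating one value per k step.
int32_t dot_kr8(const int8_t* a0, const int8_t* w8, size_t kc) {
  int32_t partial[8] = {0};
  for (size_t k = 0; k < kc; k += 8) {      // assumes kc % 8 == 0
    for (size_t j = 0; j < 8; j++) {
      partial[j] += (int32_t) a0[k + j] * (int32_t) w8[k + j];
    }
  }
  int32_t acc = 0;
  for (size_t j = 0; j < 8; j++) {
    acc += partial[j];                      // the horizontal-sum step outside the k-loop
  }
  return acc;
}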

Contributor:

> 1. NEON reads a full vector and uses lanes to get at each byte of the vector. Is there something similar where multiple bytes are read into a vector and then multiplied against values in that vector?

On the RISC-V side we can use vle plus several vrgather.vx to achieve this.
For instance, load 8-bit input with vl 4 and unroll the k-loop 8 times:

vi = vle(a, 8 /*vl*/);
va_0 = vrgather(vi, 0, vlmax);
va_1 = vrgather(vi, 1, vlmax);
va_2 = vrgather(vi, 2, vlmax);
va_3 = vrgather(vi, 3, vlmax);

> e.g. read 64 bits, unroll the loop to 8 channels, and multiply the weights against each byte of the input.

=> This is a fast and easy way to speed up the current code.

> 2. KR=8 with partial sums. On Intel we read 8 bytes of input, broadcast those 8 bytes to a full vector and multiply full vectors. Then, outside the loop, do horizontal sums of 8 values (int32) to produce the final single output value.

This complicates the code when we want to combine results outside the k-loop.

Collaborator:

KR=8 would require 3 paired adds to reduce horizontally.
I'm pretty sure this would be a substantial performance improvement, but the change is substantial enough that it would be a different kernel.

Collaborator:

Is the vrgather reasonably fast?
Another option is using int64 math for 'a':

uint64_t a = load
a0 = a & 255;
a1 = (a >> 8) & 255;
etc.

With __riscv_vmacc_vx_i32, does it use the entire 'x' register, or just the low 8 bits? If it's an 8-bit multiply that ignores the upper bits, you could remove the AND.
On most CPUs (e.g. Cortex-A53) you can issue an integer instruction as well as a vector instruction per cycle, so the int64 math may be relatively fast.

Has performance been benchmarked?
Typically qs8 gemm would be 4x faster than f32 gemm and 2x faster than f16 gemm, but it depends a lot on implementation details.
A vectorized kernel should also be substantially faster than a scalar kernel, on the order of 10x, but if the scalar kernel auto-vectorizes, then 2-4x.
My concern is that the 1-byte input reads are impractical and this kernel will only be marginally faster than scalar. That's why we don't have ARM and Intel kernels like this. Even though ARM doesn't support lanes for bytes, we convert bytes to shorts and use lanes, or implement KR=8.
The KR=8 kernels on ARM use vpadal to do paired adds and accumulate.
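
(Not from the PR, just to make the suggestion above concrete: a hedged fragment of what the QS8 inner k-loop could look like with a single 64-bit input load per 8 k-steps, written for the m1/m4 register groups of the NR=m4 kernel and only for row 0. It assumes kc is a multiple of 8 and a little-endian target; the names a_bits/va and the memcpy-based load are mine, while vacc0, a0, w, nr, vl come from the surrounding kernel.)

uint64_t a_bits;
memcpy(&a_bits, a0, sizeof(a_bits));  // one 64-bit load instead of eight 1-byte loads (needs <string.h>)
a0 += 8;
for (size_t j = 0; j < 8; ++j) {
  const int32_t va = (int32_t) (int8_t) (a_bits >> (8 * j));   // extract byte j and sign-extend
  const vint8m1_t vb = __riscv_vle8_v_i8m1((const int8_t*) w, vl);
  const vint32m4_t vb0 = __riscv_vsext_vf4(vb, vl);            // same widening as the current kernel
  w = (const int8_t*) w + nr;
  vacc0 = __riscv_vmacc_vx_i32m4(vacc0, va, vb0, vl);
}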

Author:

Performance isn't quite as bad as you think @fbarchard, since nr=32. Nevertheless I'll take a look at the K=8 unrolling, as Bruce suggests, which is what I had done a long time ago for the OpenBLAS RVV gemm. At that time I found vrgather to be quite slow, so I don't think that will be a better option here.

The test cases all pass, but I am running into an apparent heap corruption with qs8-qc8w-gemm-fp32.cc which I'll try to track down. Running qu8-gemm-fp32-bench produces this snippet of results: roughly 10x faster. (I'll add qu8 to this PR shortly.) We could go up to 7x4v without register spilling.

Scalar:
qu8_gemm_minmax_fp32_ukernel_1x4__scalar_lrintf/mobilenet_v1/M:12544/N:32/K:27/real_time 44677547 ns 44678404 ns 16 OPS=485.166M/s cpufreq=1.6G
qu8_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf/mobilenet_v1/M:12544/N:32/K:27/real_time 33292034 ns 33293643 ns 21 OPS=651.088M/s cpufreq=1.6G

RVV:
qu8_gemm_minmax_fp32_ukernel_1x4v__rvv/mobilenet_v1/M:12544/N:32/K:27/real_time 5287535 ns 5291038 ns 131 OPS=4.09946G/s cpufreq=1.6G
qu8_gemm_minmax_fp32_ukernel_4x4v__rvv/mobilenet_v1/M:12544/N:32/K:27/real_time 2245284 ns 2245234 ns 311 OPS=9.65403G/s cpufreq=1.6G

For "dotproduct" there is no RVV standard vector instruction. However there is SiFive's "XSfvqmaccqoq" which is effectively that for this particular use case, and which I hope to give a try here down the line. https://github.com/llvm/llvm-project/blob/main/llvm/docs/RISCVUsage.rst#experimental-extensions

Contributor:

Current code LGTM.
I think Frank's proposal is a good alternative way to optimize the load of 'a' in the int8 case:

uint64_t a = load
a0 = a & 255;
a1 = (a >> 8) & 255;

Maybe we can do further optimizations (optimize 'a', k != 1, XSfvqmaccqoq, ...) in future patches?

Author:

Thanks @bhbruce. It would be nice to get this PR approved and merged instead of adding too much more to it. I do have a few additional changes that I was going to push: adding qu8_gemm/qu8_igemm, and completing the config of qd8-f32-qc8w-gemm by adding qd8-f32-qc8w-igemm. But I can push all of this into this PR if you and @fbarchard would prefer.

I would greatly prefer to hold off on further optimizations until a follow-up PR. With this one we are >10x faster than scalar.

Author:

The unit tests pass but as noted there is an issue with qs8-qc8w-bench-fp32 which results in heap corruption by pack() and a subsequent malloc failure.

The allocation of the w buffer here does not appear to be sufficient for the weights when nr is large (as in the rvv case). Increasing the w buffer size resolves the problem, although I'm not clear on the exact calculation that should be used.

https://github.com/google/XNNPACK/blob/master/bench/gemm-benchmark.cc#L150
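
(For reference: the sizing change at the top of this PR in bench/gemm-benchmark.cc looks like it addresses exactly this; reading the extra nc_stride * sizeof(float) as the per-channel scale slot requested via extra_bytes = nr * sizeof(float) in the pack call is my interpretation, not stated in the PR.)

// Before: only int32 bias + int8 weights were accounted for.
// const size_t w_size = nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size;
// After (this PR): also reserve one float per output channel for the packed scales.
const size_t w_size = nc_stride * (sizeof(float) + sizeof(int32_t)) +
                      kc_stride * nc_stride * w_element_size;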


$if DATATYPE == "QU8":
const vuint8${IN_LMUL}_t vb = __riscv_vle8_v_u8${IN_LMUL}((const uint8_t*) w, vl);
const vint32${OUT_LMUL}_t vb0 = __riscv_vsub_vx_i32${OUT_LMUL}(__riscv_vreinterpret_i32${OUT_LMUL}(__riscv_vzext_vf4(vb, vl)), vb_zero_point, vl);
$else:
const vint8${IN_LMUL}_t vb = __riscv_vle8_v_i8${IN_LMUL}((const int8_t*) w, vl);
const vint32${OUT_LMUL}_t vb0 = __riscv_vsext_vf4(vb, vl);

w = (const ${XINT8_T}*) w + nr;

$for M in range(MR):
vint16${INTERMEDIATE_MLUL}_t va${M}b = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vb, va${M}, vl);
vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}b, vl);
vacc${M} = __riscv_vmacc_vx_i32${OUT_LMUL}(vacc${M}, va${M}, vb0, vl);

k -= sizeof(int8_t);
k -= sizeof(${XINT8_T});
} while (k != 0);
$if DATATYPE == "QC4":

$if DATATYPE in ["QD8", "QC4"]:
$if DATATYPE == "QC4":
$for M in range(MR):
vacc${M} = __riscv_vsra_vx_i32${OUT_LMUL}(vacc${M}, 4, vl);
// i32 -> f32
$for M in range(MR):
vfloat32${OUT_LMUL}_t vout${M} = __riscv_vfcvt_f_x_v_f32${OUT_LMUL}(vacc${M}, vl);

// vout * input_scale
$for M in range(MR):
const float vinput_scale${M} = quantization_params[${M}].inv_scale;
$for M in range(MR):
vacc${M} = __riscv_vsra_vx_i32${OUT_LMUL}(vacc${M}, 4, vl);
// i32 -> f32
$for M in range(MR):
vfloat32${OUT_LMUL}_t vout${M} = __riscv_vfcvt_f_x_v_f32${OUT_LMUL}(vacc${M}, vl);

// vout * input_scale
$for M in range(MR):
const float vinput_scale${M} = quantization_params[${M}].inv_scale;
$for M in range(MR):
vout${M} = __riscv_vfmul_vf_f32${OUT_LMUL}(vout${M}, vinput_scale${M}, vl);

const vfloat32${OUT_LMUL}_t vfilter_output_scale = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl);
w = (const float*) w + nr;
$for M in range(MR):
vout${M} = __riscv_vfmul_vv_f32${OUT_LMUL}(vout${M}, vfilter_output_scale, vl);

const vfloat32${OUT_LMUL}_t vbias = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl);
w = (const float*) w + nr;
$for M in range(MR):
vout${M} = __riscv_vfadd_vv_f32${OUT_LMUL}(vout${M}, vbias, vl);

const float vmin = params->scalar.min;
$for M in range(MR):
vout${M} = __riscv_vfmax_vf_f32${OUT_LMUL}(vout${M}, vmin, vl);
const float vmax = params->scalar.max;
$for M in range(MR):
vout${M} = __riscv_vfmin_vf_f32${OUT_LMUL}(vout${M}, vmax, vl);

// store ${MR} x vl results to c
$for M in range(MR):
__riscv_vse32_v_f32${OUT_LMUL}(c${M}, vout${M}, vl);
c${M} = (float*) ((uintptr_t) c${M} + cn_stride);

$for M in range(MR):
a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
vout${M} = __riscv_vfmul_vf_f32${OUT_LMUL}(vout${M}, vinput_scale${M}, vl);

const vfloat32${OUT_LMUL}_t vfilter_output_scale = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl);
w = (const float*) w + nr;
$for M in range(MR):
vout${M} = __riscv_vfmul_vv_f32${OUT_LMUL}(vout${M}, vfilter_output_scale, vl);

const vfloat32${OUT_LMUL}_t vbias = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl);
w = (const float*) w + nr;
$for M in range(MR):
vout${M} = __riscv_vfadd_vv_f32${OUT_LMUL}(vout${M}, vbias, vl);

const float vmin = params->scalar.min;
$for M in range(MR):
vout${M} = __riscv_vfmax_vf_f32${OUT_LMUL}(vout${M}, vmin, vl);
const float vmax = params->scalar.max;
$for M in range(MR):
vout${M} = __riscv_vfmin_vf_f32${OUT_LMUL}(vout${M}, vmax, vl);

// store ${MR} x vl results to c
$for M in range(MR):
__riscv_vse32_v_f32${OUT_LMUL}(c${M}, vout${M}, vl);
c${M} = (float*) ((uintptr_t) c${M} + cn_stride);

$for M in range(MR):
a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);
$else:
$if REQUANTIZATION == "FP32":
$for M in range(MR):
vfloat32${OUT_LMUL}_t vfacc${M} = __riscv_vfcvt_f_x_v_f32${OUT_LMUL}(vacc${M}, vl);

$if DATATYPE == "QC8":
const vfloat32${OUT_LMUL}_t vscale = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl);
$for M in range(MR):
vfacc${M} = __riscv_vfmul_vv_f32${OUT_LMUL}(vfacc${M}, vscale, vl);
w = (const float*) w + nr;
$else:
$for M in range(MR):
vfacc${M} = __riscv_vfmul_vf_f32${OUT_LMUL}(vfacc${M}, vscale, vl);

$for M in range(MR):
vfacc${M} = __riscv_vfmax_vf_f32${OUT_LMUL}(vfacc${M}, output_min_less_zero_point, vl);
$for M in range(MR):
vfacc${M} = __riscv_vfmin_vf_f32${OUT_LMUL}(vfacc${M}, output_max_less_zero_point, vl);

$if DATATYPE == "QU8":
$for M in range(MR):
vuint16${INTER_LMUL}_t vout${M} = __riscv_vfncvt_xu(vfacc${M}, vl);

$for M in range(MR):
vout${M} = __riscv_vadd_vx_u16${INTER_LMUL}(vout${M}, (uint16_t) output_zero_point, vl);

$for M in range(MR):
vuint8${IN_LMUL}_t vout8${M} = __riscv_vnclipu_wx_u8${IN_LMUL}(vout${M}, 0, vl);

$for M in range(MR):
__riscv_vse8_v_u8${IN_LMUL}(c${M}, vout8${M}, vl);
c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
$else:
$for M in range(MR):
vint16${INTER_LMUL}_t vout${M} = __riscv_vfncvt_x(vfacc${M}, vl);

$for M in range(MR):
vout${M} = __riscv_vadd_vx_i16${INTER_LMUL}(vout${M}, (int16_t) output_zero_point, vl);

$for M in range(MR):
vint8${IN_LMUL}_t vout8${M} = __riscv_vncvt_x_x_w_i8${IN_LMUL}(vout${M}, vl);

$for M in range(MR):
__riscv_vse8_v_i8${IN_LMUL}(c${M}, vout8${M}, vl);
c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

$for M in range(MR):
a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);

} while (nc != 0);
}