-
Notifications
You must be signed in to change notification settings - Fork 386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RVV] Add qs8-gemm/igemm support for risc-v #7639
base: master
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,59 @@ | ||
// Copyright 2024 SiFive, Inc. | ||
// Copyright 2024 Microchip | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
$assert DATATYPE in ["QD8", "QC4"] | ||
$assert DATATYPE in ["QD8", "QC4"] or REQUANTIZATION in ["FP32"] | ||
$assert DATATYPE in ["QC8", "QD8", "QC4", "QU8", "QS8"] | ||
$assert MR >= 1 | ||
$assert NR in ["m4", "m8"] | ||
$OUT_LMUL = NR | ||
$IN_LMUL = {"m4": "m1", "m8": "m2"}[OUT_LMUL] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This doesn't seem like most xnnpack kernels with respect to LMUL, but I do see that it is similar to f32-gemm. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that it takes it as a parameter but uses the naming NR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right. I followed the pattern used for other RVV generators, which seemed to make sense to me. NR is providing the LMUL to use. So rather than NR=4, using NR=m4 signifies this better (imho), which on the K1 (vlen=256) results in nr = 4 * (256/8) / sizeof(int32_t) = 32. |
||
$INTERMEDIATE_MLUL = {"m4": "m2", "m8": "m4"}[OUT_LMUL] | ||
$INTER_LMUL = {"m4": "m2", "m8": "m4"}[OUT_LMUL] | ||
#include <assert.h> | ||
|
||
#include <riscv_vector.h> | ||
|
||
#include "xnnpack/gemm.h" | ||
#include "xnnpack/math.h" | ||
|
||
$DATATYPE_SPEC = {"QD8": "qd8_f32_qc8w", "QC4": "qd8_f32_qc4w"}[DATATYPE] | ||
$PARAMS_TYPE = {"QD8": "union xnn_f32_minmax_params", "QC4": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE] | ||
void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${OUT_LMUL[1]}v__rvv( | ||
$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QD8": "qd8_f32_qc8w", "QC4": "qd8_f32_qc4w", "QS8": "qs8", "QU8": "qu8"}[DATATYPE] | ||
$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QD8": "union xnn_f32_minmax_params", "QC4": "struct xnn_f32_qc4w_minmax_params", "QS8": "union xnn_qs8_conv_minmax_params", "QU8": "union xnn_qu8_conv_minmax_params"}[DATATYPE] | ||
$if DATATYPE in ["QC8", "QS8", "QU8"]: | ||
$REQUANTIZATION_SPEC = "_" + REQUANTIZATION.lower() if REQUANTIZATION else "" | ||
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" | ||
$else: | ||
$REQUANTIZATION_SPEC = "" | ||
$PARAMS_STRUCT = "" | ||
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" | ||
$OUT_T = {"QC8": "int8_t", "QD8": "float", "QC4": "float", "QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE] | ||
void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${OUT_LMUL[1]}v__rvv( | ||
size_t mr, | ||
size_t nc, | ||
size_t kc, | ||
const int8_t* restrict a, | ||
const ${XINT8_T}* restrict a, | ||
size_t a_stride, | ||
const void* restrict w, | ||
float* restrict c, | ||
${OUT_T}* restrict c, | ||
size_t cm_stride, | ||
size_t cn_stride, | ||
const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)], | ||
const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) | ||
$if DATATYPE in ["QD8", "QC4"]: | ||
const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)], | ||
const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) | ||
$else: | ||
const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)]) | ||
{ | ||
assert(mr != 0); | ||
assert(mr <= ${MR}); | ||
assert(nc != 0); | ||
assert(kc != 0); | ||
|
||
const int8_t* a0 = a; | ||
float* c0 = c; | ||
const ${XINT8_T}* a0 = a; | ||
${OUT_T}* c0 = c; | ||
$for M in range(1, MR): | ||
const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride); | ||
float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride); | ||
const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride); | ||
${OUT_T}* c${M} = (${OUT_T}*) ((uintptr_t) c${M-1} + cm_stride); | ||
$if M % 2 == 0: | ||
if XNN_UNPREDICTABLE(mr <= ${M}) { | ||
a${M} = a${M-1}; | ||
|
@@ -59,20 +72,36 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${OUT_LMUL[1]}v__rvv( | |
|
||
const size_t nr = __riscv_vsetvlmax_e32${OUT_LMUL}(); | ||
size_t vl = nr; | ||
$if DATATYPE == "QC4": | ||
|
||
$if DATATYPE not in ["QD8", "QC4"]: | ||
$if REQUANTIZATION == "FP32": | ||
$if DATATYPE != "QC8": | ||
const float vscale = params->${PARAMS_STRUCT}.scale; | ||
const int32_t output_min_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_min - (int32_t) params->${PARAMS_STRUCT}.output_zero_point; | ||
const int32_t output_max_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point; | ||
const int32_t output_zero_point = params->${PARAMS_STRUCT}.output_zero_point; | ||
$if DATATYPE == "QU8": | ||
const int32_t vb_zero_point = params->${PARAMS_STRUCT}.kernel_zero_point; | ||
$elif DATATYPE == "QC4": | ||
kc = round_up_po2(kc, 2); | ||
do { | ||
if XNN_UNLIKELY(nc < nr) { | ||
vl = __riscv_vsetvl_e32${OUT_LMUL}(nc); | ||
} | ||
nc = nc - vl; | ||
|
||
vint32${OUT_LMUL}_t vksum = __riscv_vle32_v_i32${OUT_LMUL}((const int32_t*)w, vl); | ||
$if DATATYPE in ["QD8", "QC4"]: | ||
vint32${OUT_LMUL}_t vksum = __riscv_vle32_v_i32${OUT_LMUL}((const int32_t*)w, vl); | ||
$for M in range(MR): | ||
const int32_t vinput_zero_point${M} = quantization_params[${M}].zero_point; | ||
$for M in range(MR): | ||
vint32${OUT_LMUL}_t vacc${M} = __riscv_vmul_vx_i32${OUT_LMUL}(vksum, vinput_zero_point${M}, vl); | ||
$else: | ||
vint32${OUT_LMUL}_t vacc0 = __riscv_vle32_v_i32${OUT_LMUL}((const int32_t*)w, vl); | ||
$for M in range(1, MR): | ||
vint32${OUT_LMUL}_t vacc${M} = vacc0; | ||
|
||
w = (const int32_t*) w + nr; | ||
$for M in range(MR): | ||
const int32_t vinput_zero_point${M} = quantization_params[${M}].zero_point; | ||
$for M in range(MR): | ||
vint32${OUT_LMUL}_t vacc${M} = __riscv_vmul_vx_i32${OUT_LMUL}(vksum, vinput_zero_point${M}, vl); | ||
|
||
size_t k = kc; | ||
$if DATATYPE == "QC4": | ||
|
@@ -87,60 +116,120 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${OUT_LMUL[1]}v__rvv( | |
const vint8${IN_LMUL}_t vbc1 = __riscv_vand_vx_i8${IN_LMUL}(vbi, 0xF0, vl); | ||
|
||
$for M in range(MR): | ||
vint16${INTERMEDIATE_MLUL}_t va${M}bc0 = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vbc0, va${M}c0, vl); | ||
vint16${INTER_LMUL}_t va${M}bc0 = __riscv_vwmul_vx_i16${INTER_LMUL}(vbc0, va${M}c0, vl); | ||
vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}bc0, vl); | ||
vint16${INTERMEDIATE_MLUL}_t va${M}bc1 = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vbc1, va${M}c1, vl); | ||
vint16${INTER_LMUL}_t va${M}bc1 = __riscv_vwmul_vx_i16${INTER_LMUL}(vbc1, va${M}c1, vl); | ||
vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}bc1, vl); | ||
} | ||
$else: | ||
do { | ||
$for M in range(MR): | ||
const int8_t va${M} = *a${M}++; | ||
const vint8${IN_LMUL}_t vb = __riscv_vle8_v_i8${IN_LMUL}((const int8_t*) w, vl); | ||
w = (const int8_t*) w + nr; | ||
$if DATATYPE == "QU8": | ||
const int32_t va${M} = (int32_t)(uint32_t) *a${M}++; | ||
$else: | ||
const int32_t va${M} = (int32_t) *a${M}++; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. input is a single byte, so this will be extremely slow. It may be a good starting point, but will not be fast enough to use in production. There are various alternatives
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
On the RISC-V side we can use vle + several vrgather.vx to achieve this.
=> This is a fast and easy way to speed up the current code.
This complicates the code when we want to combine results outside the k-loop . There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. KR=8 would require 3 paired adds to reduce horizontally. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is the vrgather reasonably fast? Has performance been benchmarked? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Performance isn't quite as bad as you think @fbarchard since nr=32. Nevertheless I'll take a look at the K=8 unrolling, as Bruce suggests, which is what I had done a long time ago for openblas gemm rvv. At that time I had found vrgather to be quite slow, and so I don't think that will be a better option here. The test cases all pass, but I am running into an apparent heap corruption with qs8-qc8w-gemm-fp32.cc which I'll try and track down. Running qu8-gemm-fp32-bench produces this snip of results. Roughly 10x faster. (I'll add qu8 to this PR shortly). We could go up to 7x4v without register spilling. Scalar: RVV: For "dotproduct" there is no RVV standard vector instruction. However there is SiFive's "XSfvqmaccqoq" which is effectively that for this particular use case, and which I hope to give a try here down the line. https://github.com/llvm/llvm-project/blob/main/llvm/docs/RISCVUsage.rst#experimental-extensions There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Current code LGTM.
Maybe we can have further optimizations (optimize a, k != 1, XSfvqmaccqoq...) in future patches? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks @bhbruce. It would be nice to get this PR approved and merged instead of adding too much more to it. I do have a few additional changes that I was going to push, adding qu8_gemm/qu8_igemm, and completing the config of qd8-f32-qc8w-gemm by adding qd8-f32-qc8w-igemm. But I can push this all into this PR if you and @fbarchard would prefer. I would greatly prefer to hold off on further optimizations until a subsequent PR. With this one we are >10x scalar. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The unit tests pass, but as noted there is an issue with qs8-qc8w-bench-fp32 which results in heap corruption by pack() and a subsequent malloc failure. The allocation of the w buffer here does not appear to be sufficient for the weights when nr is large (as in the RVV case). Increasing the w buffer size resolves the problem, although I'm not clear on the exact calculation that should be used. https://github.com/google/XNNPACK/blob/master/bench/gemm-benchmark.cc#L150 |
||
|
||
$if DATATYPE == "QU8": | ||
const vuint8${IN_LMUL}_t vb = __riscv_vle8_v_u8${IN_LMUL}((const uint8_t*) w, vl); | ||
const vint32${OUT_LMUL}_t vb0 = __riscv_vsub_vx_i32${OUT_LMUL}(__riscv_vreinterpret_i32${OUT_LMUL}(__riscv_vzext_vf4(vb, vl)), vb_zero_point, vl); | ||
$else: | ||
const vint8${IN_LMUL}_t vb = __riscv_vle8_v_i8${IN_LMUL}((const int8_t*) w, vl); | ||
const vint32${OUT_LMUL}_t vb0 = __riscv_vsext_vf4(vb, vl); | ||
|
||
w = (const ${XINT8_T}*) w + nr; | ||
|
||
$for M in range(MR): | ||
vint16${INTERMEDIATE_MLUL}_t va${M}b = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vb, va${M}, vl); | ||
vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}b, vl); | ||
vacc${M} = __riscv_vmacc_vx_i32${OUT_LMUL}(vacc${M}, va${M}, vb0, vl); | ||
|
||
k -= sizeof(int8_t); | ||
k -= sizeof(${XINT8_T}); | ||
} while (k != 0); | ||
$if DATATYPE == "QC4": | ||
|
||
$if DATATYPE in ["QD8", "QC4"]: | ||
$if DATATYPE == "QC4": | ||
$for M in range(MR): | ||
vacc${M} = __riscv_vsra_vx_i32${OUT_LMUL}(vacc${M}, 4, vl); | ||
// i32 -> f32 | ||
$for M in range(MR): | ||
vfloat32${OUT_LMUL}_t vout${M} = __riscv_vfcvt_f_x_v_f32${OUT_LMUL}(vacc${M}, vl); | ||
|
||
// vout * input_scale | ||
$for M in range(MR): | ||
const float vinput_scale${M} = quantization_params[${M}].inv_scale; | ||
$for M in range(MR): | ||
vacc${M} = __riscv_vsra_vx_i32${OUT_LMUL}(vacc${M}, 4, vl); | ||
// i32 -> f32 | ||
$for M in range(MR): | ||
vfloat32${OUT_LMUL}_t vout${M} = __riscv_vfcvt_f_x_v_f32${OUT_LMUL}(vacc${M}, vl); | ||
|
||
// vout * input_scale | ||
$for M in range(MR): | ||
const float vinput_scale${M} = quantization_params[${M}].inv_scale; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfmul_vf_f32${OUT_LMUL}(vout${M}, vinput_scale${M}, vl); | ||
|
||
const vfloat32${OUT_LMUL}_t vfilter_output_scale = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl); | ||
w = (const float*) w + nr; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfmul_vv_f32${OUT_LMUL}(vout${M}, vfilter_output_scale, vl); | ||
|
||
const vfloat32${OUT_LMUL}_t vbias = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl); | ||
w = (const float*) w + nr; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfadd_vv_f32${OUT_LMUL}(vout${M}, vbias, vl); | ||
|
||
const float vmin = params->scalar.min; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfmax_vf_f32${OUT_LMUL}(vout${M}, vmin, vl); | ||
const float vmax = params->scalar.max; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfmin_vf_f32${OUT_LMUL}(vout${M}, vmax, vl); | ||
|
||
// store ${MR} x vl results to c | ||
$for M in range(MR): | ||
__riscv_vse32_v_f32${OUT_LMUL}(c${M}, vout${M}, vl); | ||
c${M} = (float*) ((uintptr_t) c${M} + cn_stride); | ||
|
||
$for M in range(MR): | ||
a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); | ||
vout${M} = __riscv_vfmul_vf_f32${OUT_LMUL}(vout${M}, vinput_scale${M}, vl); | ||
|
||
const vfloat32${OUT_LMUL}_t vfilter_output_scale = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl); | ||
w = (const float*) w + nr; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfmul_vv_f32${OUT_LMUL}(vout${M}, vfilter_output_scale, vl); | ||
|
||
const vfloat32${OUT_LMUL}_t vbias = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl); | ||
w = (const float*) w + nr; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfadd_vv_f32${OUT_LMUL}(vout${M}, vbias, vl); | ||
|
||
const float vmin = params->scalar.min; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfmax_vf_f32${OUT_LMUL}(vout${M}, vmin, vl); | ||
const float vmax = params->scalar.max; | ||
$for M in range(MR): | ||
vout${M} = __riscv_vfmin_vf_f32${OUT_LMUL}(vout${M}, vmax, vl); | ||
|
||
// store ${MR} x vl results to c | ||
$for M in range(MR): | ||
__riscv_vse32_v_f32${OUT_LMUL}(c${M}, vout${M}, vl); | ||
c${M} = (float*) ((uintptr_t) c${M} + cn_stride); | ||
|
||
$for M in range(MR): | ||
a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); | ||
$else: | ||
$if REQUANTIZATION == "FP32": | ||
$for M in range(MR): | ||
vfloat32${OUT_LMUL}_t vfacc${M} = __riscv_vfcvt_f_x_v_f32${OUT_LMUL}(vacc${M}, vl); | ||
|
||
$if DATATYPE == "QC8": | ||
const vfloat32${OUT_LMUL}_t vscale = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl); | ||
$for M in range(MR): | ||
vfacc${M} = __riscv_vfmul_vv_f32${OUT_LMUL}(vfacc${M}, vscale, vl); | ||
w = (const float*) w + nr; | ||
$else: | ||
$for M in range(MR): | ||
vfacc${M} = __riscv_vfmul_vf_f32${OUT_LMUL}(vfacc${M}, vscale, vl); | ||
|
||
$for M in range(MR): | ||
vfacc${M} = __riscv_vfmax_vf_f32${OUT_LMUL}(vfacc${M}, output_min_less_zero_point, vl); | ||
$for M in range(MR): | ||
vfacc${M} = __riscv_vfmin_vf_f32${OUT_LMUL}(vfacc${M}, output_max_less_zero_point, vl); | ||
|
||
$if DATATYPE == "QU8": | ||
$for M in range(MR): | ||
vuint16${INTER_LMUL}_t vout${M} = __riscv_vfncvt_xu(vfacc${M}, vl); | ||
|
||
$for M in range(MR): | ||
vout${M} = __riscv_vadd_vx_u16${INTER_LMUL}(vout${M}, (uint16_t) output_zero_point, vl); | ||
|
||
$for M in range(MR): | ||
vuint8${IN_LMUL}_t vout8${M} = __riscv_vnclipu_wx_u8${IN_LMUL}(vout${M}, 0, vl); | ||
|
||
$for M in range(MR): | ||
__riscv_vse8_v_u8${IN_LMUL}(c${M}, vout8${M}, vl); | ||
c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride); | ||
$else: | ||
$for M in range(MR): | ||
vint16${INTER_LMUL}_t vout${M} = __riscv_vfncvt_x(vfacc${M}, vl); | ||
|
||
$for M in range(MR): | ||
vout${M} = __riscv_vadd_vx_i16${INTER_LMUL}(vout${M}, (int16_t) output_zero_point, vl); | ||
|
||
$for M in range(MR): | ||
vint8${IN_LMUL}_t vout8${M} = __riscv_vncvt_x_x_w_i8${IN_LMUL}(vout${M}, vl); | ||
|
||
$for M in range(MR): | ||
__riscv_vse8_v_i8${IN_LMUL}(c${M}, vout8${M}, vl); | ||
c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride); | ||
|
||
$for M in range(MR): | ||
a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); | ||
|
||
} while (nc != 0); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Copyright 2025 Microchip