
[RVV] Add qs8-gemm/igemm support for risc-v #7639

Open

wants to merge 5 commits into base: master

Conversation

ken-unger

  • Added qs8-qc8w-gemm and qs8-qc8w-igemm support for RVV.
  • The generator script can support qu8, but I've left those kernels out given past comments about qu8 being deprecated.
  • Note that qd8 support was never productized (in gemm-config), but I will take care of that in a subsequent PR.
  • Tested on QEMU and the SpacemiT K1.

//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QD8", "QC4"]
$assert DATATYPE in ["QD8", "QC4"] or REQUANTIZATION in ["FP32"]
$assert DATATYPE in ["QC8", "QD8", "QC4", "QU8", "QS8"]
$assert MR >= 1
$assert NR in ["m4", "m8"]
$OUT_LMUL = NR
$IN_LMUL = {"m4": "m1", "m8": "m2"}[OUT_LMUL]
Collaborator

This doesn't look like most XNNPACK kernels with respect to LMUL, but I do see that it is similar to f32-gemm.
Consider taking LMUL as a parameter and generating 1v, 2v, 4v, and 8v variations of the kernel.
This can be done as a follow-up.

Contributor

I think it already takes it as a parameter, but under the name NR:
$assert NR in ["m4", "m8"]

Author

Right. I followed the pattern used for other RVV generators, which seemed to make sense to me. NR provides the LMUL to use, so rather than NR=4, using NR=m4 signifies this better (imho). On the K1 (VLEN=256) that results in nr = 4 * (256/8) / sizeof(int32_t) = 32.
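
For reference, a minimal sketch (not part of the PR; the function name is hypothetical) of how that tile width is obtained at runtime with the RVV C intrinsics:

#include <riscv_vector.h>
#include <stddef.h>

// With NR = m4 the output tile width is VLMAX for SEW=32/LMUL=4, i.e.
// LMUL * (VLEN / 8) / sizeof(int32_t). On the K1 (VLEN=256) that is
// 4 * 32 / 4 = 32 int32 lanes.
static size_t output_tile_width(void) {
  return __riscv_vsetvlmax_e32m4();  // 32 when VLEN = 256
}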

$if DATATYPE == "QU8":
  const int32_t va${M} = (int32_t)(uint32_t) *a${M}++;
$else:
  const int32_t va${M} = (int32_t) *a${M}++;
Collaborator

The input is read a single byte at a time, so this will be extremely slow. It may be a good starting point, but it will not be fast enough to use in production. There are various alternatives:

  1. NEON reads a full vector and uses lanes to get at each byte in a vector. Is there something similar where multiple bytes are read into a vector and then multiplied against values in the vector?
    e.g. read 64 bits, unroll the loop to 8 channels, and multiply the weights against each byte in the input?
  2. KR=8 with partial sums. On Intel we read 8 bytes of input, broadcast those 8 bytes to a full vector, and multiply full vectors. Then outside the loop, we do horizontal sums of 8 values (int32) to produce the final single output value.
  3. SR=4. On Intel the 'shuffle' kernels read full vectors, multiply/accumulate full vectors, and rotate the input after each multiply. This idea could be extended to RVV, but it would be a little complicated, since the vlen isn't known until runtime and it affects packing.
  4. Dot product. Ultimately we want to use a dot product, which would be C4. On Intel we extend that to C8 using a partial sum, and we also have a Python generator for CPUs that don't have a dot product, implementing the dot-product instruction with a function.

Contributor

> 1. NEON reads a full vector and uses lanes to get at each byte in a vector. Is there something similar where multiple bytes are read into a vector and then multiplied against values in the vector?

On the RISC-V side we can use vle plus several vrgather.vx to achieve this. For instance, load the 8-bit input with vl 4 and unroll the k-loop 8 times:

vi = vle(a, 8 /*vl*/);
va_0 = vrgather(vi, 0, vlmax);
va_1 = vrgather(vi, 1, vlmax);
va_2 = vrgather(vi, 2, vlmax);
va_3 = vrgather(vi, 3, vlmax);
> e.g. read 64 bits, unroll the loop to 8 channels, and multiply the weights against each byte in the input?

=> This is a fast and easy way to speed up the current code.
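
As a rough sketch of that idea with the C intrinsics (LMUL m1 and the helper name are assumptions, not the PR's code):

#include <riscv_vector.h>
#include <stdint.h>

// Illustrative only: load 4 consecutive qs8 input bytes once, then splat
// each byte across a full vector with vrgather.vx so it can be multiplied
// against a full vector of weights.
static inline void broadcast_four_inputs(const int8_t* a) {
  const size_t vlmax = __riscv_vsetvlmax_e8m1();
  const vint8m1_t vi = __riscv_vle8_v_i8m1(a, 4);                // one load of 4 bytes
  const vint8m1_t va0 = __riscv_vrgather_vx_i8m1(vi, 0, vlmax);  // splat byte 0
  const vint8m1_t va1 = __riscv_vrgather_vx_i8m1(vi, 1, vlmax);  // splat byte 1
  const vint8m1_t va2 = __riscv_vrgather_vx_i8m1(vi, 2, vlmax);  // splat byte 2
  const vint8m1_t va3 = __riscv_vrgather_vx_i8m1(vi, 3, vlmax);  // splat byte 3
  (void) va0; (void) va1; (void) va2; (void) va3;  // these would feed the multiply-accumulates
}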


> 2. KR=8 with partial sums. On Intel we read 8 bytes of input, broadcast those 8 bytes to a full vector, and multiply full vectors. Then outside the loop, we do horizontal sums of 8 values (int32) to produce the final single output value.

This complicates the code when we want to combine results outside the k-loop.

Collaborator

KR=8 would require 3 paired adds to reduce horizontally.
I'm pretty sure this would be a substantial performance improvement, but the change is large enough that it would be a different kernel.
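
For illustration (plain C arithmetic, not the kernel itself), the three paired-add levels reduce 8 int32 partial sums like this:

#include <stdint.h>

// Three levels of paired adds collapse the 8 int32 partial sums of one
// output element: 8 -> 4 -> 2 -> 1.
static inline int32_t reduce8(const int32_t s[8]) {
  const int32_t t0 = s[0] + s[1], t1 = s[2] + s[3];  // level 1
  const int32_t t2 = s[4] + s[5], t3 = s[6] + s[7];
  const int32_t u0 = t0 + t1, u1 = t2 + t3;          // level 2
  return u0 + u1;                                    // level 3
}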

Collaborator

Is vrgather reasonably fast?
Another option is using int64 math for 'a':

uint64_t a = load;  // 64-bit load of 8 input bytes
a0 = a & 255;
a1 = (a >> 8) & 255;
// etc.

With __riscv_vmacc_vx_i32, does it use the entire 'x' register, or just the low 8 bits? If it's an 8-bit multiply that ignores the upper bits, you could remove the AND.
On most CPUs (e.g. Cortex-A53) you can issue an integer instruction as well as a vector instruction per cycle, so the int64 math may be relatively fast.
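
A hedged sketch of the int64 idea for qs8 (the names, LMUL, and two-step unroll are assumptions, not the PR's code); note the (int8_t) cast that sign-extends each signed byte before the vector multiply-accumulate:

#include <riscv_vector.h>
#include <stdint.h>
#include <string.h>

// Read 8 signed input bytes with one 64-bit load, peel two of them off with
// scalar shifts, and feed them to vmacc.vx against pre-widened weight vectors.
static inline vint32m4_t accumulate_two_k(const int8_t* a,
                                          vint32m4_t vb0, vint32m4_t vb1,
                                          vint32m4_t vacc, size_t vl) {
  uint64_t bits;
  memcpy(&bits, a, sizeof(bits));                       // unaligned-safe 64-bit load
  const int32_t va0 = (int32_t) (int8_t) (bits >> 0);   // sign-extend byte 0
  const int32_t va1 = (int32_t) (int8_t) (bits >> 8);   // sign-extend byte 1
  vacc = __riscv_vmacc_vx_i32m4(vacc, va0, vb0, vl);    // k+0
  vacc = __riscv_vmacc_vx_i32m4(vacc, va1, vb1, vl);    // k+1
  return vacc;
}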

Has performance been benchmarked?
Typically a qs8 gemm would be 4x faster than an f32 gemm and 2x faster than an f16 gemm, but it depends a lot on implementation details.
A vectorized kernel should also be substantially faster than a scalar kernel... on the order of 10x, or 2-4x if the scalar kernel auto-vectorized.
My concern is that the 1-byte input reads are impractical and this kernel will only be marginally faster than scalar. That's why we don't have ARM and Intel kernels like this. Even though ARM doesn't support lanes for bytes, we convert bytes to shorts and use lanes, or implement KR=8.
The KR=8 on ARM uses vpadal to do paired adds and accumulate.

Author

Performance isn't quite as bad as you might think, @fbarchard, since nr=32. Nevertheless, I'll take a look at the K=8 unrolling, as Bruce suggests, which is what I did a long time ago for the OpenBLAS RVV gemm. At that time I found vrgather to be quite slow, so I don't think that will be a better option here.

The test cases all pass, but I am running into an apparent heap corruption with qs8-qc8w-gemm-fp32.cc which I'll try to track down. Running qu8-gemm-fp32-bench produces the snippet of results below, roughly 10x faster. (I'll add qu8 to this PR shortly.) We could go up to 7x4v without register spilling.

Scalar:
qu8_gemm_minmax_fp32_ukernel_1x4__scalar_lrintf/mobilenet_v1/M:12544/N:32/K:27/real_time 44677547 ns 44678404 ns 16 OPS=485.166M/s cpufreq=1.6G
qu8_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf/mobilenet_v1/M:12544/N:32/K:27/real_time 33292034 ns 33293643 ns 21 OPS=651.088M/s cpufreq=1.6G

RVV:
qu8_gemm_minmax_fp32_ukernel_1x4v__rvv/mobilenet_v1/M:12544/N:32/K:27/real_time 5287535 ns 5291038 ns 131 OPS=4.09946G/s cpufreq=1.6G
qu8_gemm_minmax_fp32_ukernel_4x4v__rvv/mobilenet_v1/M:12544/N:32/K:27/real_time 2245284 ns 2245234 ns 311 OPS=9.65403G/s cpufreq=1.6G

For "dotproduct" there is no RVV standard vector instruction. However there is SiFive's "XSfvqmaccqoq" which is effectively that for this particular use case, and which I hope to give a try here down the line. https://github.com/llvm/llvm-project/blob/main/llvm/docs/RISCVUsage.rst#experimental-extensions

Contributor

The current code LGTM.
I think Frank's proposal is a good alternative way to optimize the load of a in the int8 case:

uint64_t a = load;  // 64-bit load of 8 input bytes
a0 = a & 255;
a1 = (a >> 8) & 255;

Maybe we can have further optimizations (optimizing the a loads, k != 1, XSfvqmaccqoq, ...) in future patches?

Author

Thanks @bhbruce. It would be nice to get this PR approved and merged rather than adding too much more to it. I do have a few additional changes that I was going to push, adding qu8_gemm/qu8_igemm and completing the config of qd8-f32-qc8w-gemm by adding qd8-f32-qc8w-igemm, but I can push those into this PR if you and @fbarchard would prefer.

I would greatly prefer to hold off on further optimizations until the next PR. With this one we are >10x scalar.

Author

The unit tests pass but as noted there is an issue with qs8-qc8w-bench-fp32 which results in heap corruption by pack() and a subsequent malloc failure.

The allocation of the w buffer here does not appear to be sufficient for the weights when nr is large (as in the rvv case). Increasing the w buffer size resolves the problem, although I'm not clear on the exact calculation that should be used.

https://github.com/google/XNNPACK/blob/master/bench/gemm-benchmark.cc#L150
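
For what it's worth, a hedged sketch of the kind of sizing I believe is needed (not the benchmark's actual code; round_up and the parameter names are placeholders): packed qs8-qc8w weights carry, per output column, the int8 weights padded to the KR*SR boundary plus an int32 bias and a float per-channel scale, with the column count rounded up to nr, which grows quickly when nr is 32.

#include <stddef.h>
#include <stdint.h>

// Hypothetical sizing sketch only.
static size_t round_up(size_t n, size_t q) { return (n + q - 1) / q * q; }

static size_t packed_weights_bytes(size_t nc, size_t kc,
                                   size_t nr, size_t kr, size_t sr) {
  const size_t rounded_nc = round_up(nc, nr);       // pad columns to the NR tile
  const size_t rounded_kc = round_up(kc, kr * sr);  // pad depth to KR*SR
  // int8 weights + int32 bias + float scale per (padded) output column.
  return rounded_nc * (rounded_kc * sizeof(int8_t)
                       + sizeof(int32_t) + sizeof(float));
}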

Collaborator

@fbarchard left a comment

Thanks for the PR... just some quick comments for now.

@bhbruce
Contributor

bhbruce commented Jan 10, 2025

Thanks for the PR completing the missing part of my previous PR. I'll continue reviewing it next week.

@@ -1,46 +1,59 @@
// Copyright 2024 SiFive, Inc.
// Copyright 2024 Microchip
Collaborator

// Copyright 2025 Microchip

Collaborator

@fbarchard left a comment

Did you run the unit tests and benchmarks to make sure it works and is an improvement over the current gemm-config'ed scalar kernel?

@ken-unger
Author

@fbarchard please approve so this can be merged. I will continue with a new RVV PR to add qu8_gemm/qu8_igemm, and complete the config of qd8-f32-qc8w-gemm by adding qd8-f32-qc8w-igemm. Additional performance improvements can be made in subsequent PRs.

On the K1 (VLEN=256), with this PR we see roughly a 10x-18x improvement. Test cases pass.

RVV:

qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv/mobilenet_v1/M:12544/N:32/K:27/real_time 3857486 ns 3857206 ns 181 OPS=5.61921G/s cpufreq=1.6G
qs8_qc8w_gemm_minmax_fp32_ukernel_4x4v__rvv/mobilenet_v1/M:12544/N:32/K:27/real_time 1809814 ns 1809522 ns 386 OPS=11.9769G/s cpufreq=1.6G

Scalar:

qs8_qc8w_gemm_minmax_fp32_ukernel_1x4__scalar_lrintf/mobilenet_v1/M:12544/N:32/K:27/real_time 38987444 ns 38990368 ns 18 OPS=555.975M/s cpufreq=1.6G
qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf/mobilenet_v1/M:12544/N:32/K:27/real_time 32470450 ns 32472973 ns 22 OPS=667.562M/s cpufreq=1.6G

@ken-unger
Author

Just a friendly ping @fbarchard for your approval here, and also PR #7638. Both have the requested changes integrated and I'd like to move on to the next iteration of RVV support. Thank you.
