
sync : ggml #12104

Merged: 11 commits merged from sync-ggml-25-02-28 into master on Mar 3, 2025
Conversation

ggerganov (Member)

No description provided.

ggerganov and others added 7 commits February 28, 2025 09:09
… backend (ggml/1121)

* Support float16-to-float16 add/sub/mul/div operations in the CUDA backend

* Add fp16 support for add/sub/mul/div on the CPU backend

* Add test cases for fp16 add/sub/mul/div
It is used by the Whisper talk-llama example.

Co-authored-by: Petter Reinholdtsen <[email protected]>
* Add small comment re: VSX to readme

Co-authored-by: midnight <[email protected]>
* whisper : support GGML_BACKEND_DL

* fix DTW crash

* whisper.objc : fix build - add ggml-cpp.h

---------

Co-authored-by: Georgi Gerganov <[email protected]>
* Support fp16 unary operations in the CUDA backend

* cpu: increase fp16 support for unary operators in the CPU backend

* cuda: increase fp16 support for unary operators in the CUDA backend

* Add test cases for fp16 unary operators

* metal: update supports_op for unary operators that don't support fp16, to prevent test-backend-ops from failing

* metal: fix PR comments for unary op support after fp16 unary tests
ggml-ci
github-actions bot added the script, testing, Nvidia GPU, ggml, and Apple Metal labels on Feb 28, 2025
ggerganov (Member, Author)

@cmdr2 The CUDA builds are failing with the following error after the recent changes (ggml-org/ggml#1125):

https://github.com/ggml-org/llama.cpp/actions/runs/13583176617/job/37972693333?pr=12104#step:7:140

Any suggestions on how to fix it?

github-actions bot added the Vulkan label on Feb 28, 2025
cmdr2 (Contributor) commented Feb 28, 2025

Taking a look

cmdr2 (Contributor) commented Feb 28, 2025

Hi, some notes:

  1. It failed on CUDA 11.7 but passed on CUDA 12.4. I'm looking into this.
  2. Vulkan needs a few more supports_op entries. I'll post a patch for this shortly in this PR.
  3. Sorry for the typo in the SILU_BACK test case. The CUDA and CPU values are pretty close, but beyond the test threshold, and the failure is transient - the test passes on occasion. Maybe SILU_BACK can be restricted to FP32 again for now, since fp16 probably isn't a key target for training?

SILU_BACK results (for {10, 1, 1, 1} in the test):

f1: 0.449463, 0.209595, -0.402588, 0.226440, -0.096497, -0.304199, 0.871582, 0.040558, 0.272705, -0.024597,
f2: 0.449219, 0.209717, -0.402344, 0.226318, -0.096436, -0.304443, 0.871582, 0.040588, 0.272705, -0.024597,

Different enough to cross the threshold.
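
For reference, a quick way to put a number on that gap (a standalone sketch; the NMSE definition and the ~1e-7 threshold are assumptions about how test-backend-ops compares results, not values taken from this thread):

// Standalone host-side sketch: computes the normalized mean squared error
// (NMSE) between the two SILU_BACK result rows above. Compiles with nvcc or
// any C++ compiler; nothing here is ggml code.
#include <cstdio>

int main() {
    const double f1[10] = { 0.449463,  0.209595, -0.402588,  0.226440, -0.096497,
                           -0.304199,  0.871582,  0.040558,  0.272705, -0.024597 };
    const double f2[10] = { 0.449219,  0.209717, -0.402344,  0.226318, -0.096436,
                           -0.304443,  0.871582,  0.040588,  0.272705, -0.024597 };

    double err = 0.0, ref = 0.0;
    for (int i = 0; i < 10; ++i) {
        err += (f1[i] - f2[i]) * (f1[i] - f2[i]);
        ref += f1[i] * f1[i];
    }

    // Prints roughly 1.5e-07, i.e. right around a 1e-7-style threshold, which
    // would explain a test that only fails some of the time.
    printf("NMSE = %g\n", err / ref);
    return 0;
}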

If this reasoning makes sense, I could remove the fp16 test case for SILU_BACK, and update supports_op to expect only FP32.

Thanks

ggerganov (Member, Author)

Thanks, sounds good.

cmdr2 (Contributor) commented Feb 28, 2025

Update: The CUDA 11.7 failures are due to the lack of operator overloading for half data types before CUDA 12.2.

It fails only when compiling for arch 50; arch 60 onwards compiles and works fine (even with CUDA 11.7).

CUDA 12.1 doc - no operator overloading

CUDA 12.2 doc - overloaded operators

I'm still working on this.

cmdr2 (Contributor) commented Feb 28, 2025

tl;dr: Since this is blocking ggml's sync with other repos, maybe we can put the new FP16 function calls behind the GGML_CUDA_F16 flag (to unblock the build), and I'll work in parallel on replicating the binbcast.cu approach for this?

Short version:
How about masking the FP16 function calls inside #ifdef GGML_CUDA_F16, so that those lines are never compiled for arch 50? If a user wants to use ggml with float16, they would have to compile with -D GGML_CUDA_F16=1, which targets arch 60 onwards.
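
Roughly, the guard would look something like this (a minimal sketch with placeholder names and a made-up scale op, not the actual ggml-cuda code):

// Sketch of the "plan B" guard: the __half instantiation is only compiled
// when GGML_CUDA_F16 is defined, so a default build that still targets
// arch 50 never sees half arithmetic. Function names here are placeholders.
#include <cuda_fp16.h>
#include <cstdint>

template <typename T>
static __global__ void scale_kernel(const T * x, T * y, const T factor, const int64_t n) {
    const int64_t i = blockIdx.x * (int64_t) blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = x[i] * factor; // operator* on __half is what fails for arch 50 with older CUDA
    }
}

void scale_f32_cuda(const float * x, float * y, float factor, int64_t n, cudaStream_t stream) {
    scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(x, y, factor, n);
}

#ifdef GGML_CUDA_F16 // only defined with -D GGML_CUDA_F16=1, i.e. arch 60+ builds
void scale_f16_cuda(const __half * x, __half * y, __half factor, int64_t n, cudaStream_t stream) {
    scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(x, y, factor, n);
}
#endif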

I don't like this solution, to be honest (I'd prefer always-on fp16 support), and would rather find a better one if available.

A better solution might be to replicate the binbcast.cu approach, because it manages to compile on arch 50 with CUDA 11.7 even though it uses half and overloaded operators. As noted in the tl;dr, the flag would just unblock the build while I work on that in parallel.

Details:
This fails only when compiling for arch 50. It compiles fine for arch 60 onwards, and operator overloading works too (even with CUDA 11.7).

For example, using CUDA 11.7's nvcc and half_add.cu:

  • This fails to compile: nvcc -o half_add.exe half_add.cu "--generate-code=arch=compute_50,code=[compute_50,sm_50]"
  • This works: nvcc -o half_add.exe half_add.cu "--generate-code=arch=compute_61,code=[compute_61,sm_61]"

And even if I use __hadd() (instead of the + operator), it compiles for arch 50 but produces incorrect results.
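
For concreteness, a minimal half_add.cu along these lines looks roughly like this (an illustrative sketch, not the exact file referenced above):

// half_add.cu (illustrative sketch): adds two __half arrays with operator+,
// which is the construct that fails to compile for arch 50 here; __hadd() is
// the intrinsic alternative mentioned above.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void half_add(const __half * a, const __half * b, __half * c, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i]; // or: c[i] = __hadd(a[i], b[i]);
    }
}

int main() {
    const int n = 8;
    __half ha[n], hb[n], hc[n];
    for (int i = 0; i < n; ++i) {
        ha[i] = __float2half(0.5f * i);
        hb[i] = __float2half(1.0f);
    }

    __half *da, *db, *dc;
    cudaMalloc(&da, n * sizeof(__half));
    cudaMalloc(&db, n * sizeof(__half));
    cudaMalloc(&dc, n * sizeof(__half));
    cudaMemcpy(da, ha, n * sizeof(__half), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb, n * sizeof(__half), cudaMemcpyHostToDevice);

    half_add<<<1, n>>>(da, db, dc, n);
    cudaMemcpy(hc, dc, n * sizeof(__half), cudaMemcpyDeviceToHost);

    for (int i = 0; i < n; ++i) {
        printf("%f ", __half2float(hc[i])); // expected: 1.0 1.5 2.0 ...
    }
    printf("\n");

    cudaFree(da); cudaFree(db); cudaFree(dc);
    return 0;
}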

So I suppose we shouldn't compile the half version for arch 50 anyway.

Unfortunately it looks like __CUDA_ARCH__ is device-side only, i.e. it isn't set in the host-side code.
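
For context, __CUDA_ARCH__ only steers code inside __global__/__device__ functions, while host-side dispatch has to rely on a build flag or a runtime capability query; a minimal sketch:

// __CUDA_ARCH__ is defined only while nvcc compiles device code, so it can
// select per-arch paths inside a kernel, but host code cannot branch on it.
#include <cstdio>

__global__ void which_arch() {
#if __CUDA_ARCH__ >= 600
    printf("running a kernel variant compiled for arch >= 60\n");
#else
    printf("running a kernel variant compiled for arch < 60\n");
#endif
}

int main() {
    // Host side: __CUDA_ARCH__ is not defined here, so choosing whether to
    // launch an fp16 kernel has to use something else, e.g. a compile-time
    // flag like GGML_CUDA_F16 or cudaGetDeviceProperties() at runtime.
    which_arch<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}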

slaren (Member) commented Feb 28, 2025

It should work if you cast everything to float manually. You are not going to get any extra performance from using F16 math one element at a time anyway.

cmdr2 (Contributor) commented Feb 28, 2025

@slaren Are you referring to the binbcast.cu approach, whose kernel takes float args but still operates on fp16 tensors without consuming extra VRAM?

With binbcast, fp16-fp16 addition is about 30-35% faster for me than fp32-fp32.

Same for clamp: I just tested the latest implementation with a 1 GB tensor, and it takes 8 ms with fp16 and 13 ms with fp32 on my 3060 12 GB.
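
For reference, the pattern in question is roughly the following (a paraphrased sketch from memory, not the actual binbcast.cu source): the kernel is templated on the storage types, but the arithmetic itself happens in float.

// Paraphrased sketch of the binbcast-style pattern: storage may be __half or
// float, every element is converted to float, combined in fp32, and converted
// back on store. No half arithmetic is emitted, so it compiles for arch 50,
// and fp16 tensors still move half the bytes of fp32 ones.
#include <cuda_fp16.h>
#include <cstdint>

static __device__ __forceinline__ float op_add(const float a, const float b) {
    return a + b;
}

template <float (*bin_op)(const float, const float), typename src_t, typename dst_t>
static __global__ void k_bin_op(const src_t * src0, const src_t * src1, dst_t * dst, const int64_t n) {
    const int64_t i = blockIdx.x * (int64_t) blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = (dst_t) bin_op((float) src0[i], (float) src1[i]);
    }
}

void add_f16_cuda(const __half * a, const __half * b, __half * c, int64_t n, cudaStream_t stream) {
    k_bin_op<op_add><<<(n + 255) / 256, 256, 0, stream>>>(a, b, c, n);
}

void add_f32_cuda(const float * a, const float * b, float * c, int64_t n, cudaStream_t stream) {
    k_bin_op<op_add><<<(n + 255) / 256, 256, 0, stream>>>(a, b, c, n);
}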

slaren (Member) commented Feb 28, 2025

What I mean is to change the implementation of clamp to something like this:

dst[i] = (T)fminf(fmaxf((float)x[i], (float)min), (float)max);

You get better performance with F16 tensors since this kernel is entirely memory bound, but the math itself is not any faster with F16 either, unless you are computing multiple values at a time (with half2 functions such as __hadd2). Casting all the operands to float and doing the math in F32 is not really any slower than doing the math in F16 directly.
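
Spelled out as a full kernel, that suggestion could look roughly like this (a sketch under the same idea, not the exact ggml-cuda code):

// Sketch of the suggested clamp: the storage type T may be __half or float,
// but the comparison itself is done in fp32, so no half arithmetic is
// generated and the same template covers both types on every arch.
#include <cuda_fp16.h>
#include <cstdint>

template <typename T>
static __global__ void clamp_kernel(const T * x, T * dst, const float min, const float max, const int64_t n) {
    const int64_t i = blockIdx.x * (int64_t) blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = (T) fminf(fmaxf((float) x[i], min), max);
    }
}

void clamp_f16_cuda(const __half * x, __half * dst, float min, float max, int64_t n, cudaStream_t stream) {
    clamp_kernel<<<(n + 255) / 256, 256, 0, stream>>>(x, dst, min, max, n);
}

void clamp_f32_cuda(const float * x, float * dst, float min, float max, int64_t n, cudaStream_t stream) {
    clamp_kernel<<<(n + 255) / 256, 256, 0, stream>>>(x, dst, min, max, n);
}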

cmdr2 (Contributor) commented Feb 28, 2025

@slaren Makes sense. But this problem will also affect all the operators in unary.cu, which now have fp16 support.

cmdr2 (Contributor) commented Feb 28, 2025

@slaren Would you suggest just casting to float in all those unary operators as well? I'd have to try that; I'm not sure.

slaren (Member) commented Feb 28, 2025

> @slaren Would you suggest just casting to float in all those unary operators as well? I'd have to try that; I'm not sure.

Yes, that's probably a good idea. All of the unary operators are almost certainly memory bound, so I don't think it would even be worth writing specific versions for half2.
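
For the unary ops, that could look roughly like this (illustrative only, not the actual unary.cu): one templated kernel, with each op written as a float-to-float device function.

// Sketch: a single templated kernel covers every unary op; the op itself
// works in fp32, so __half tensors never require half arithmetic and no
// half2-specialized kernels are needed.
#include <cuda_fp16.h>
#include <cstdint>

static __device__ __forceinline__ float op_relu(float x)    { return fmaxf(x, 0.0f); }
static __device__ __forceinline__ float op_sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); }

template <float (*op)(float), typename T>
static __global__ void unary_kernel(const T * x, T * dst, const int64_t n) {
    const int64_t i = blockIdx.x * (int64_t) blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = (T) op((float) x[i]);
    }
}

template <float (*op)(float), typename T>
static void unary_cuda(const T * x, T * dst, int64_t n, cudaStream_t stream) {
    unary_kernel<op><<<(n + 255) / 256, 256, 0, stream>>>(x, dst, n);
}

// usage (illustrative): unary_cuda<op_relu>(src_f16, dst_f16, n, stream);
//                       unary_cuda<op_sigmoid>(src_f32, dst_f32, n, stream);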

cmdr2 (Contributor) commented Feb 28, 2025

@slaren Thanks, I'm giving that a try now.

@ggerganov In the meantime, I've also pushed a "plan B" to my fp16-fix branch, which puts the new FP16 unary code paths behind a compile-time GGML_CUDA_F16 flag.

I've tested that test-backend-ops passes in both scenarios (with and without the flag). Without the flag, the fp16 unary tests are skipped, and nvcc does not attempt to compile the half version of those functions.

This is simply a "plan B" if the build needs to be unblocked urgently. I'd definitely prefer a better solution that doesn't need to do this.

I'm looking at @slaren 's suggestion now.

ggerganov (Member, Author)

No worries, there is nothing urgent. slaren's suggestion should work.

cmdr2 (Contributor) commented Feb 28, 2025

Thanks. As a thought, would it be possible to use the same runners on ggml's branches too? That would help catch problems earlier and make syncs less likely to bring surprises.

cmdr2 (Contributor) commented Mar 3, 2025

Submitted a PR with the suggested change, thanks - ggml-org/ggml#1130

ggerganov merged commit dfd6b2c into master on Mar 3, 2025 (53 checks passed).
ggerganov deleted the sync-ggml-25-02-28 branch on Mar 3, 2025 at 16:18.