cuda/cpu: Increase support for fp16 unary operations #1125

Merged: 7 commits, Feb 28, 2025

Conversation

cmdr2 (Collaborator) commented on Feb 26, 2025:

// apologies for adding even more lines to ggml-cpu.c, I will templatize these in my next PR (to shrink the code again). I kept them as simple copies for easier readability during review, and to keep the PR focused on one thing.

This PR increases support for pure float16 operations (in CUDA and CPU) for the following unary operations: abs, sgn, neg, step, tanh, gelu, silu, silu_back, gelu_quick, relu, sigmoid, hardsigmoid, exp, hardswish, leaky_relu, sqr, sqrt, sin, cos, log, clamp.

Sorry for the large PR, but it's actually just three simple things (repeated several times):

  1. cpu: copy the fp32 version and use the corresponding fp16 operation (a minimal sketch follows this list)
  2. cuda: templatize the unary kernels to work with float or half
  3. cuda: add kernels for abs, sgn and log
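
As an illustration of item 1, here is a minimal sketch of what such an fp16 copy looks like (the op and the exact function name are illustrative, not a verbatim excerpt from the PR): each element is widened to fp32, the fp32 math is applied, and the result is narrowed back to fp16.

inline static void ggml_vec_sqr_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        const float v = GGML_FP16_TO_FP32(x[i]); // widen to fp32
        y[i] = GGML_FP32_TO_FP16(v*v);           // apply the op, narrow back to fp16
    }
}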

Also: Added test cases for these operations in test-backend-ops, and updated supports_op for Metal (to avoid CI failure).

test-backend-ops passes.

Thanks!
PS: unary.cu could also be templatized further, since most of the functions are duplicates.

Comment on lines +1435 to +1439
inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_FP32_TO_FP16(-GGML_FP16_TO_FP32(x[i]));
    }
}
slaren (Member) commented:

If you are converting entire vectors to F16, there may be a significant performance advantage in using ggml_fp32_to_fp16_row, since it can use F16C to convert multiple values in a single cycle. The inverse (ggml_fp16_to_fp32_row) could also be accelerated with F16C, but it does not have such an implementation at the moment. However, the F16C implementation of these functions would need to be moved to the CPU backend; currently it is in ggml.c, which normally isn't compiled with support for these instructions (a leftover from when the CPU backend was split into a separate target).

Not very important right now, just something to consider. When the code is ported to C++, it should be possible to abstract these conversions, and reuse the same code for all F16 ops.
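
For context, a rough sketch of what an F16C-accelerated fp16-to-fp32 row conversion could look like (illustrative only; the function name is hypothetical and this is not how ggml currently implements it):

#include <immintrin.h> // F16C intrinsics: needs -mf16c (GCC/Clang) or /arch:AVX2 (MSVC)
#include <stdint.h>

static void fp16_to_fp32_row_f16c(const uint16_t * x, float * y, int n) {
    int i = 0;
    for (; i + 8 <= n; i += 8) {
        // load 8 packed fp16 values and widen all of them to fp32 with one vcvtph2ps
        const __m128i h = _mm_loadu_si128((const __m128i *)(x + i));
        _mm256_storeu_ps(y + i, _mm256_cvtph_ps(h));
    }
    for (; i < n; ++i) {
        y[i] = _cvtsh_ss(x[i]); // scalar tail
    }
}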

cmdr2 (Collaborator, Author) replied on Feb 27, 2025:

Good point, thanks! Strangely enough, the row-based version seems to be 10-15% slower than converting element-by-element.

Maybe MSVC isn't enabling F16C? Not sure. I'm using cmake -B build; cmake --build build --config Release on Windows with the MSVC command line. CMake says: Adding CPU backend variant ggml-cpu: /arch:AVX512 GGML_AVX512

I'm using a 1 GB fp16 tensor, and it takes 1088ms without row optimization, and 1330ms with row optimization.

I used this for vec_neg_f16:

inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    float * y_f32 = (float *)malloc(sizeof(float) * n);
    for (int i = 0; i < n; ++i) {
        y_f32[i] = -GGML_FP16_TO_FP32(x[i]);
    }
    ggml_fp32_to_fp16_row(y_f32, y, n);
    free(y_f32); // release the temporary fp32 buffer
}

I'm happy to continue optimizing it in this PR, if you wish. Or we can pursue the optimization in another PR, after using templates, so that all the functions would be optimized in one shot, like you said.

I'm okay with either, please let me know. Thanks!

slaren (Member) replied:

Did you move the function to the CPU backend? Otherwise it will be built without the /arch flag.

Comment on lines +3 to +12
template <class T>
static __global__ void op_abs(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fabsf(x[i]);
}
slaren (Member) commented:

Not super important right now, but it would be nice to generalize all of these functions to a single kernel that takes the unary function to perform as a parameter. This could be done in a similar way as it is done in binbcast.cu.
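
As a rough sketch of that suggestion (illustrative only; binbcast.cu uses its own, more elaborate machinery), a single kernel could be templated on the element type and on the device function to apply:

#include <cuda_fp16.h>

template <class T, float (*op)(float)>
static __global__ void op_unary(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    // compute in fp32 and convert back to T (float or half) on assignment
    dst[i] = (T)op((float)x[i]);
}

static __device__ __forceinline__ float op_abs_impl(float x) { return fabsf(x); }

// hypothetical launch: op_unary<half, op_abs_impl><<<num_blocks, block_size, 0, stream>>>(x, dst, k);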

cmdr2 (Collaborator, Author) replied:

Yes, I'm working on that PR next, using binbcast.cu as the template :)

For the unary operators in the CPU backend as well as in unary.cu.

I kept these functions separate for now to keep this a single-focus PR, and to do the refactoring in a second PR.

cmdr2 (Collaborator, Author) commented on Feb 27, 2025:

@slaren Thanks for your comments! I've replied to them, please let me know if you'd like any changes.

ggerganov (Member) left a comment:

You can fix the Metal backend with this patch:

diff --git a/src/ggml-metal/ggml-metal.m b/src/ggml-metal/ggml-metal.m
index 63e944e..ff88d2c 100644
--- a/src/ggml-metal/ggml-metal.m
+++ b/src/ggml-metal/ggml-metal.m
@@ -1210,11 +1210,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_TRANSPOSE:
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
+            return true;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
-        case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CONV_TRANSPOSE_1D:
@@ -1225,8 +1227,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
-        case GGML_OP_LOG:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_LOG:
+           return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
@@ -1256,11 +1259,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_PAD_REFLECT_1D:
-        case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARANGE:
+            return true;
         case GGML_OP_FLASH_ATTN_EXT:
             if (op->src[1]->type != op->src[2]->type) {
                 return false;

cmdr2 (Collaborator, Author) commented on Feb 27, 2025:

You can fix the Metal backend with this patch

@ggerganov Thanks, I appreciate the patch and help! :) I've pushed it to the PR's branch.

Sorry, yeah, it's a bit embarrassing for me to commit changes for Metal blindly :) In the future, I'll try to get hold of a cloud runner to test on, if possible. My ancient MacBook Pro (early 2013) fails to compile ggml with Metal, since it's stuck at macOS 10.15 (Catalina).

cmdr2 (Collaborator, Author) commented on Feb 27, 2025:

Also, it looks like the 'test-macos-metal' CI on PRs is only running on BLAS, instead of Metal. For example, this is the CI run for my latest commit in this PR: https://github.com/ggml-org/ggml/actions/runs/13564669680/job/37914974037?pr=1125

Testing 2 devices
Backend 1/2: BLAS
...
Backend 2/2: CPU
...

ggerganov (Member) replied:

Also, it looks like the 'test-macos-metal' CI on PRs is only running on BLAS, instead of Metal.

Yes, the GitHub runners don't have Metal - this CI workflow is incorrectly named.

Now that you are a collaborator on the project, you can push branches in this repo and they will run the ggml-ci, which covers a wide range of hardware, including Metal on Apple Silicon.

cmdr2 (Collaborator, Author) commented on Feb 27, 2025:

Thanks! That's very helpful, and I appreciate the trust. I won't misuse it.

ggerganov merged commit 0d1ea2e into ggml-org:master on Feb 28, 2025
3 checks passed