Can't add two float16 tensors on CUDA? #1117

Closed
cmdr2 opened this issue Feb 19, 2025 · 7 comments

@cmdr2 (Collaborator) commented Feb 19, 2025

Hi, I'm new to ggml, so apologies if I'm missing something obvious.

I wrote a simple program to add two float32 tensors in ggml using CUDA, and that works fine.

But when I changed the two tensor types to GGML_TYPE_F16 and tried to add them, I got a GGML assertion error:
ggml-cuda\binbcast.cu:297: GGML_ASSERT(src1->type == GGML_TYPE_F32) failed

Key snippets (and I've included the complete program at the bottom):

struct ggml_tensor* a = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 3);
struct ggml_tensor* b = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 3);
struct ggml_tensor* result = ggml_add(ctx, a, b);
printf("Computing graph...\n");
ggml_backend_graph_compute(backend, gf); // <---- fails here
printf("Finished computing\n");

I'm sending float16 data, but that doesn't seem to matter.

I have an NVIDIA 3060 12 GB, with compute capability 8.6. PyTorch works just fine in float16 for me.

Digging into the code, it looks like a lot of operations enforce F32 for the second tensor (add, sub, mul, div, etc.).

Am I missing something, and if not, why can't we add two float16 tensors using ggml?

Thanks for your help! :)

Complete program
#include "ggml.h"
#include "ggml-cpu.h"

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

#include <vector>
#include <iostream>

ggml_backend_t backend = NULL;
ggml_gallocr_t allocr = NULL;

void init_backend() {
#ifdef GGML_USE_CUDA
    fprintf(stderr, "%s: using CUDA backend\n", __func__);
    backend = ggml_backend_cuda_init(0); // init device 0
    if (!backend) {
        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
    }
#endif

    if (!backend) {
        backend = ggml_backend_cpu_init();
    }
}

void init_mem_allocator() {
    allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
}

void predict() {
    // create a context
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context* ctx = ggml_init(params);

    // 1. Define the tensor variables
    struct ggml_tensor* a = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 3);
    struct ggml_tensor* b = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 3);

    // 2. Define the computation graph
    struct ggml_tensor* result = ggml_add(ctx, a, b);

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, result);

    // 3. Allocate memory for the tensor variables, and assign the data
    ggml_gallocr_alloc_graph(allocr, gf);

    std::vector<float> a_data_f32 = {1, 2, 3};
    std::vector<float> b_data_f32 = {10, 20, 30};

    // Convert float data to ggml_fp16_t
    std::vector<ggml_fp16_t> a_data(a_data_f32.size());
    std::vector<ggml_fp16_t> b_data(b_data_f32.size());

    for (size_t i = 0; i < a_data_f32.size(); ++i) {
        a_data[i] = ggml_fp32_to_fp16(a_data_f32[i]);
        b_data[i] = ggml_fp32_to_fp16(b_data_f32[i]);
    }

    ggml_backend_tensor_set(a, a_data.data(), 0, ggml_nbytes(a));
    ggml_backend_tensor_set(b, b_data.data(), 0, ggml_nbytes(b));

    // 4. Run the computation, and read the result
    printf("Computing graph...\n");
    ggml_backend_graph_compute(backend, gf);
    printf("Finished computing\n");

    struct ggml_tensor* result_node = ggml_graph_node(gf, -1);  // get the last node in the graph

    int n = ggml_nelements(result_node); // create an array to store the result data
    std::vector<ggml_fp16_t> result_data(n);

    // copy the data from the backend memory into the result array
    ggml_backend_tensor_get(result_node, result_data.data(), 0, ggml_nbytes(result_node));

    // print the data
    for (int i = 0; i < n; i++) {
        std::cout<<ggml_fp16_to_fp32(result_data[i])<<", ";
    }
    std::cout<<std::endl;

    // free the resources
    ggml_free(ctx);
}

int main(int argc, char* argv[]) {
    init_backend();
    init_mem_allocator();

    predict();

    // free the resources
    ggml_gallocr_free(allocr);
    ggml_backend_free(backend);

    return 0;
}

@JohannesGaessler (Collaborator) commented:

ggml (CUDA) operators tend to be added as they are needed. The addition of two FP16 tensors has so far not seen any use, so it is not implemented. PRs welcome.

@cmdr2 (Collaborator, Author) commented Feb 20, 2025

Thanks, fair enough. If this is okay to add, I'd be happy to take a stab at this.

I'm not sure if this is a strong-enough justification, but:

  1. The Stable Diffusion world (where I come from) still has a lot of fp16 weights floating around (pre-Flux/SD3 models).
  2. I'm writing a series of simple guides for ggml (as I learn about it myself). Part 1, Part 2. The next part will focus on reducing the model size for inference (i.e. fp16 and quantization).

Before introducing quantization to readers, fp16 is the obvious first way to reduce a model's size for inference.

Thanks!

@JohannesGaessler (Collaborator) commented:

FP16 + FP16 -> FP16/FP32 is not implemented, but FP16 + FP32 -> FP32 should be. That is the more commonly used operation, since src1 is usually a temporary tensor where reducing the size does not make a meaningful difference.
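
A minimal sketch of the already-supported combination, adapted from the program above (a sketch, assuming the rest of the setup stays the same): keep the F16 tensor as src0 and pass src1 as F32.

struct ggml_tensor* a = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 3); // F16 "weights"
struct ggml_tensor* b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3); // F32 temporary
struct ggml_tensor* result = ggml_add(ctx, a, b);

// src1 can now be filled with plain float data, no fp16 conversion needed:
std::vector<float> b_data = {10, 20, 30};
ggml_backend_tensor_set(b, b_data.data(), 0, ggml_nbytes(b));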

@cmdr2 (Collaborator, Author) commented Feb 21, 2025

Thanks! Sorry about the long reply.

Part 1

I agree that src1 can be temporary, and that this is a valid workaround for binary ops. But what about softmax? I'd have to double the memory usage of my input tensors, as well as of the result.
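
A hypothetical sketch of the softmax case (assuming an F32-only op forces an explicit up-cast; not code from ggml):

// If soft_max only accepts F32, an F16 activation tensor needs an extra F32
// copy, and the result is F32 too, doubling the memory held for that tensor.
struct ggml_tensor* x_f16 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 1024, 1024);
struct ggml_tensor* x_f32 = ggml_cast(ctx, x_f16, GGML_TYPE_F32); // extra F32 copy
struct ggml_tensor* probs = ggml_soft_max(ctx, x_f32);            // F32 result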

Part 2

I actually got fp16 ops working in ggml for CUDA (for the binbcast ops: add/sub/mul/div), and verified that it doesn't consume more VRAM than necessary and is about 35% faster than fp32. The change was fairly small.

Tensor additions that took 4 GB in fp32 took 2 GB with fp16, and CUDA's peak VRAM usage didn't exceed the expected amount (measured using ggml_backend_cuda_get_device_memory). And it took 13 ms instead of 20 ms.
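
A simplified sketch of that VRAM check (not the exact benchmark code; it builds on the program above):

// query free/total device memory before and after computing the graph and
// compare the difference to the expected tensor sizes
size_t free_before, free_after, total;
ggml_backend_cuda_get_device_memory(0, &free_before, &total);
ggml_backend_graph_compute(backend, gf);
ggml_backend_cuda_get_device_memory(0, &free_after, &total);
printf("VRAM used: %.2f MiB\n", (free_before - free_after) / (1024.0 * 1024.0));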

General change in binbcast.cu (I could make my change cleaner):

// relax the assert to also allow F16 for src1
GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

...
// new dispatch case: F16 + F16 -> F16
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
    op()(src0, src1, dst, (const half *) src0_dd, (const half *) src1_dd, (half *) dst_dd, stream);
}

The next challenge was getting test-backend-ops to work correctly: when I added binbcast test cases for GGML_TYPE_F16, it also ran fp16 on the CPU backend (for comparison), and for some test cases the mean squared error was larger than the threshold. I might be wrong, but that seems expected, right? fp16 on the CPU is unlikely to exactly match a different device (but I might be very wrong here).
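
A tiny illustration of the precision point (not the test-backend-ops code): rounding to fp16 already changes the value, so two backends that round or accumulate slightly differently can disagree by more than a tight error threshold.

float x = 0.1f + 0.2f;                 // ~0.30000001 in fp32
ggml_fp16_t h = ggml_fp32_to_fp16(x);  // nearest representable fp16 value
printf("fp32: %.8f, fp16 round-trip: %.8f\n", x, ggml_fp16_to_fp32(h));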

Part 3

My real motivation here (and behind learning ggml) is to get Stable Diffusion inference working at competitive speeds. I'm the maintainer of Easy Diffusion, and have been seriously considering stable-diffusion.cpp as our new backend. But its performance is half that of vanilla diffusers (for pre-Flux models), so I dug further into it.

I saw that a lot of operations in sd.cpp are done in fp32, even if the weights are fp16, right from the initial latent.

If I run a diffusers pipeline at fp32, its performance is similar to sd.cpp's. If I run diffusers at fp16, the performance (it/sec) doubles. In sd.cpp, there's no real perf difference between fp16 and fp32.

I admit there's an assumption here: that fp16 is the reason for the poor performance. It might also be due to implementation differences. But it just seems odd that the speed of diffusers doubles with fp16, while its fp32 perf is the same as sd.cpp's.

A vast number of the popular SD models are fp16.

My main blocker is performance: with sd.cpp it's half of what we currently get.

I'm definitely happy to work on increasing fp16 support; it's fun to hack on ggml code. I found ggml really simple and intuitive, especially after the new backend API. :)

Thanks!

@JohannesGaessler (Collaborator) commented:

If your goal is to optimize performance, use NVIDIA Nsight Systems to determine how much runtime each kernel takes up, because that imposes a hard limit on how much performance can be gained from optimizing it. The performance bottlenecks for neural networks are usually either matrix multiplications or convolutions. IIRC Stable Diffusion uses convolutional layers, which are poorly supported in ggml as of right now: instead of a dedicated convolution operator, they are converted to matrix multiplications using IM2COL.
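
A rough sketch of why IM2COL shows up so prominently in profiles (a paraphrase of the lowering, not the exact ggml internals):

// a 2D convolution in ggml is lowered to an IM2COL transform followed by a
// matrix multiplication, rather than running as a dedicated convolution kernel
struct ggml_tensor* kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, 320, 320); // KW, KH, Cin, Cout
struct ggml_tensor* input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 320, 1); // W, H, Cin, N
struct ggml_tensor* conv   = ggml_conv_2d(ctx, kernel, input, 1, 1, 1, 1, 1, 1);     // strides, padding, dilation
// internally this builds (roughly): im2col = ggml_im2col(...); conv = ggml_mul_mat(kernel_2d, im2col); plus reshapes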

@cmdr2 (Collaborator, Author) commented Feb 26, 2025

Thanks! You're right, IM2COL is where it spends the most time right now (as per Nsight). I'll dig into that further.

Resolving this issue since fp16 addition now works in ggml+CUDA (after #1121) and more unary fp16 operations are being tracked in #1125. Thanks for your help!

@cmdr2 closed this as completed Feb 26, 2025

@JohannesGaessler (Collaborator) commented:

Also take a look at #971.
