Skip to content

Commit

Permalink
cuda: unary ops as float + de-duplicate (#1130)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmdr2 authored Mar 3, 2025
1 parent ff90529 commit 58ecf6b
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 567 deletions.
10 changes: 7 additions & 3 deletions src/ggml-cuda/clamp.cu
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
#include "clamp.cuh"

static __device__ __forceinline__ float op_clamp(float x, float min, float max) {
return fminf(fmaxf(x, min), max);
}

template <class T>
static __global__ void op_clamp(const T * x, T * dst, const T min, const T max, const int k) {
static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}

dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
}

template <class T>
static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
op_clamp<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
}


Expand Down
Loading

0 comments on commit 58ecf6b

Please sign in to comment.