diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 76b0c2a988727..8a8775be36583 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1297,6 +1297,19 @@ extern "C" { struct ggml_tensor * a, float s); + // x = s * a + b + GGML_API struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + + GGML_API struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + // b -> view(a,offset,nb1,nb2,3), return modified a GGML_API struct ggml_tensor * ggml_set( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index eae575cc040cd..ccb17eb072eb2 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2188,7 +2188,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: - case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_CLAMP: @@ -2210,6 +2209,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_SCALE: + float bias; + memcpy(&bias, (float*)op->op_params + 1, sizeof(float)); + return bias == 0.0f; // TODO: support bias != 0.0f case GGML_OP_SOFT_MAX: // TODO: support broadcast // ref: https://github.com/ggml-org/llama.cpp/pull/14435 diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index aaeee614ab993..5a07819038d30 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4643,9 +4643,11 @@ static void ggml_compute_forward_scale_f32( GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - // scale factor - float v; - memcpy(&v, dst->op_params, sizeof(float)); + float s; // scale factor + float b; // bias + + memcpy(&s, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&b, (float *) dst->op_params + 1, sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -4664,12 +4666,22 @@ static void ggml_compute_forward_scale_f32( const size_t nb1 = dst->nb[1]; - for (int i1 = ir0; i1 < ir1; i1++) { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + if (b == 0.0f) { + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s); + } + } else { + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_mad1_f32(nc, (float *) ((char *) dst->data + i1*nb1), s, b); } - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); } } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 1f5857a23e35c..4652598ead13c 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int #endif } +inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, const float b) { +#if defined(GGML_USE_ACCELERATE) + vDSP_vsmsa(y, 1, &s, &b, y, 1, n); +#elif defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + // scalar ; TODO: Write SVE code + for (int i = 0; i < n; ++i) { + y[i] = y[i]*s + b; + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); + GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = y[i]*s + b; + } + #endif +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = y[i]*s + b; + } +#endif +} + //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_USE_ACCELERATE) diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 1405e066e86a2..2ee9e588992f4 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -1,18 +1,18 @@ #include "scale.cuh" -static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } - dst[i] = scale * x[i]; + dst[i] = scale * x[i] + bias; } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, k); + scale_f32<<>>(x, dst, scale, bias, k); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float bias; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); - scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream); + scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream); } diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 40fc315e82fd1..83a0739809a6e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2256,7 +2256,9 @@ static bool ggml_metal_encode_node( GGML_ASSERT(ggml_is_contiguous(src0)); float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); + float bias; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float)); int64_t n = ggml_nelements(dst); @@ -2273,6 +2275,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + [encoder setBytes:&bias length:sizeof(bias) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 22240bab47249..239ec31fbcb58 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1014,16 +1014,18 @@ kernel void kernel_scale( device const float * src0, device float * dst, constant float & scale, + constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; + dst[tpig] = src0[tpig] * scale + bias; } kernel void kernel_scale_4( device const float4 * src0, device float4 * dst, constant float & scale, + constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; + dst[tpig] = src0[tpig] * scale + bias; } kernel void kernel_clamp( diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a9fc039038705..43d8e5c72c937 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5587,7 +5587,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); + float bias; + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float)); ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; @@ -5602,6 +5604,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias)); int n = ggml_nelements(dst)/4; diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl index 8cfd518fa5a3e..aeca8a456e4fe 100644 --- a/ggml/src/ggml-opencl/kernels/scale.cl +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -8,9 +8,10 @@ kernel void kernel_scale( ulong offset0, global float4 * dst, ulong offsetd, - float scale + float scale, + float bias ) { src0 = (global float4*)((global char*)src0 + offset0); dst = (global float4*)((global char*)dst + offsetd); - dst[get_global_id(0)] = src0[get_global_id(0)] * scale; + dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 21c81e99a19aa..cd15bbdb29fa2 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; } -static void scale_f32(const float * x, float * dst, const float scale, const int k, +static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int return; } - dst[i] = scale * x[i]; + dst[i] = scale * x[i] + bias; } @@ -1842,7 +1842,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl( -static void scale_f32_sycl(const float *x, float *dst, const float scale, +static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE; stream->parallel_for( @@ -1850,7 +1850,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale, sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - scale_f32(x, dst, scale, k, item_ct1); + scale_f32(x, dst, scale, bias, k, item_ct1); }); } @@ -2319,9 +2319,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds float * dst_dd = static_cast(dst->data); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float bias; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); - scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream); + scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream); /* DPCT1010:87: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2245a655498c5..c36e1a6d3bfc2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7508,7 +7508,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, + op_params[0], op_params[1], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp index 4663428dee0a2..f10b0a02b5076 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -18,7 +18,7 @@ void main() { continue; } - data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2)); idx += num_threads; } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 75fc1e7072970..5ae1c527df639 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3069,12 +3069,14 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, float s, + float b, bool inplace) { GGML_ASSERT(ggml_is_padded_1d(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_set_op_params(result, &s, sizeof(s)); + float params[2] = { s, b }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_SCALE; result->src[0] = a; @@ -3086,14 +3088,30 @@ struct ggml_tensor * ggml_scale( struct ggml_context * ctx, struct ggml_tensor * a, float s) { - return ggml_scale_impl(ctx, a, s, false); + return ggml_scale_impl(ctx, a, s, 0.0, false); } struct ggml_tensor * ggml_scale_inplace( struct ggml_context * ctx, struct ggml_tensor * a, float s) { - return ggml_scale_impl(ctx, a, s, true); + return ggml_scale_impl(ctx, a, s, 0.0, true); +} + +struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b) { + return ggml_scale_impl(ctx, a, s, b, false); +} + +struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b) { + return ggml_scale_impl(ctx, a, s, b, true); } // ggml_set @@ -5777,7 +5795,7 @@ static void ggml_compute_backward( } break; case GGML_OP_MEAN: { if (src0_needs_grads) { - ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); + ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false)); } } break; case GGML_OP_REPEAT: { @@ -5854,7 +5872,7 @@ static void ggml_compute_backward( if (src0_needs_grads) { float s; memcpy(&s, tensor->op_params, sizeof(float)); - ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false)); } } break; case GGML_OP_SET: { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b54bcc8a35e64..1d837b4322cfa 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2368,22 +2368,24 @@ struct test_scale : public test_case { const ggml_type type; const std::array ne; float scale; + float bias; std::string vars() override { - return VARS_TO_STR3(type, ne, scale); + return VARS_TO_STR4(type, ne, scale, bias); } test_scale(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, - float scale = 2.0f) - : type(type), ne(ne), scale(scale) {} + float scale = 2.0f, + float bias = 0.0f) + : type(type), ne(ne), scale(scale), bias(bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); ggml_set_name(a, "a"); - ggml_tensor * out = ggml_scale(ctx, a, scale); + ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias); ggml_set_name(out, "out"); return out; @@ -5044,6 +5046,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_scale()); + test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_silu_back()); for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {