Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions ggml/src/ggml-sycl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
});
}

template<typename T>
static void arange_kernel(T * dst, const int k, T start, T step,
const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = start + static_cast<T>(i) * step;
}
}

template<typename T>
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
Expand Down Expand Up @@ -631,6 +639,29 @@ static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, gg
}
}

static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);

float start, stop, step;
memcpy(&start, dst->op_params, sizeof(float));
memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
memcpy(&step, (float *) dst->op_params + 2, sizeof(float));

dpct::queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));

float * dst_ptr = (float *)dst->data;
const int k = (int)ggml_nelements(dst); // הוספה חשובה!

const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
arange_kernel(dst_ptr, k, start, step, item_ct1);
});
}

} // namespace ggml_sycl_detail


Expand Down Expand Up @@ -1168,3 +1199,8 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_quick(ctx, dst);
}

void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-sycl/element_wise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,6 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_ELEMENTWISE_HPP
5 changes: 5 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3768,6 +3768,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_ARANGE:
ggml_sycl_arange(ctx, dst);
break;
default:
return false;
}
Expand Down Expand Up @@ -4416,6 +4419,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
default:
return false;
}
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-sycl/presets.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#define SYCL_ARGMAX_BLOCK_SIZE 256
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
#define SYCL_ARANGE_BLOCK_SIZE 256

// dmmv = dequantize_mul_mat_vec
#ifndef GGML_SYCL_DMMV_X
Expand Down