ggml/src/ggml-webgpu/ggml-webgpu.cpp (215 changes: 94 additions & 121 deletions)
@@ -146,8 +146,13 @@ struct webgpu_submission_futures {
 struct webgpu_buf_pool {
     std::vector<webgpu_pool_bufs> free;
 
-    std::mutex mutex;
-
+    // The pool must be synchronized because
+    // 1. The memset pool is shared globally by every ggml buffer,
+    //    since allocating a pool per ggml buffer would consume too much memory.
+    // 2. For the per-thread buffer pools in webgpu_context,
+    //    buffers are allocated and freed in Dawn callbacks,
+    //    which can run on a different thread than the calling thread.
+    std::mutex mutex;
     std::condition_variable cv;
 
     void init(wgpu::Device device,
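
The comment added above compresses the key constraint: a buffer can be handed out on one thread and returned from a Dawn completion callback on another, so both paths must take the same lock, and allocation must be able to wait for a return. A minimal standalone sketch of that shape, where guarded_pool, Buf, alloc, and release are illustrative stand-ins rather than the actual webgpu_buf_pool API:

#include <condition_variable>
#include <mutex>
#include <vector>

struct Buf {};  // stand-in for webgpu_pool_bufs

struct guarded_pool {
    std::vector<Buf>        free;
    std::mutex              mutex;
    std::condition_variable cv;

    // Called on the submitting thread: blocks until a buffer is available.
    Buf alloc() {
        std::unique_lock<std::mutex> lock(mutex);
        cv.wait(lock, [&] { return !free.empty(); });
        Buf b = free.back();
        free.pop_back();
        return b;
    }

    // Called from a completion callback, potentially on a different thread.
    void release(Buf b) {
        std::lock_guard<std::mutex> lock(mutex);
        free.push_back(b);
        cv.notify_one();  // wake one waiting alloc()
    }
};
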
@@ -266,7 +271,7 @@ struct webgpu_command
 #endif
 };
 
-struct webgpu_capabilities_base {
+struct webgpu_capabilities {
     wgpu::Limits limits;
     bool supports_subgroup_matrix = false;
 
@@ -286,11 +291,11 @@ struct webgpu_global_context_struct {
     wgpu::Device device;
     wgpu::Queue queue;
 
-    webgpu_capabilities_base capabilities;
+    webgpu_capabilities capabilities;
     // Shared buffer to move data from device to host
-    wgpu::Buffer get_tensor_staging_buf;
+    wgpu::Buffer get_tensor_staging_buf;
     // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
-    std::recursive_mutex mutex;
+    std::recursive_mutex mutex;
 
     webgpu_buf_pool memset_buf_pool;
     std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
@@ -361,7 +366,6 @@ struct webgpu_context_struct {
     std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
 
     size_t memset_bytes_per_thread;
-
 };
 
 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -383,9 +387,8 @@ struct ggml_backend_webgpu_device_context {
 
 // Per-thread data required to actually run WebGPU operations in a backend instance
 struct ggml_backend_webgpu_context {
-    webgpu_context webgpu_ctx;
-    std::once_flag init_once;
-    std::string name;
+    webgpu_context webgpu_ctx;
+    std::string name;
 };
 
 // Per-thread data related to buffers
@@ -861,20 +864,15 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
     };
 
     webgpu_pipeline pipeline;
-    {
-        // TODO: remove guard once pipeline caches are per-thread
-        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
-        auto it = ctx->pad_pipelines.find(pipeline_key);
-        if (it != ctx->pad_pipelines.end()) {
-            pipeline = it->second;
-        } else {
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
-            pipeline =
-                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-            pipeline.context = processed.decisions;
-            ctx->pad_pipelines.emplace(pipeline_key, pipeline);
-        }
+    auto it = ctx->pad_pipelines.find(pipeline_key);
+    if (it != ctx->pad_pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
+        pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        pipeline.context = processed.decisions;
+        ctx->pad_pipelines.emplace(pipeline_key, pipeline);
     }
 
     ggml_webgpu_generic_shader_decisions decisions =
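
The same find-or-create shape recurs in every op below (set_rows, flash_attn, unary, argmax, argsort, cumsum, sum_rows): the lookup drops the global recursive_mutex because each pipeline cache now lives in the per-thread webgpu_context, exactly what the removed TODO comments anticipated. A hedged sketch of the pattern, where find_or_create, Key, and Pipeline are hypothetical placeholders and not names from this file:

#include <functional>
#include <unordered_map>

// Hypothetical helper illustrating the unguarded lookup; this is safe only
// because the cache map is owned by exactly one thread.
template <typename Key, typename Pipeline>
Pipeline find_or_create(std::unordered_map<Key, Pipeline> & cache,
                        const Key &                         key,
                        const std::function<Pipeline()> &   make) {
    auto it = cache.find(key);
    if (it != cache.end()) {
        return it->second;  // hit: no lock needed on a per-thread map
    }
    Pipeline p = make();    // miss: compile the shader variant once per thread
    cache.emplace(key, p);
    return p;
}
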
@@ -944,20 +942,16 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
     };
 
     webgpu_pipeline pipeline;
-    // TODO: remove guard once pipeline caches are per-thread
-    {
-        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
-        auto it = ctx->set_rows_pipelines.find(key);
-        if (it != ctx->set_rows_pipelines.end()) {
-            pipeline = it->second;
-        } else {
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
-            pipeline =
-                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-            pipeline.context = processed.decisions;
-            ctx->set_rows_pipelines.emplace(key, pipeline);
-        }
+    auto it = ctx->set_rows_pipelines.find(key);
+    if (it != ctx->set_rows_pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
+        pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        pipeline.context = processed.decisions;
+        ctx->set_rows_pipelines.emplace(key, pipeline);
     }
 
     ggml_webgpu_generic_shader_decisions decisions =
@@ -1261,29 +1255,25 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
     };
 
     webgpu_pipeline pipeline;
-    // TODO: remove guard once pipeline caches are per-thread
-    {
-        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
-        auto it = ctx->flash_attn_pipelines.find(key);
-        if (it != ctx->flash_attn_pipelines.end()) {
-            pipeline = it->second;
-        } else {
-            ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
-                .key = key,
-                .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
-                .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
-                .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
-                .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
-                .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
-            };
-
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
-            pipeline =
-                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-            pipeline.context = processed.decisions;
-            ctx->flash_attn_pipelines.emplace(key, pipeline);
-        }
+    auto it = ctx->flash_attn_pipelines.find(key);
+    if (it != ctx->flash_attn_pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
+            .key = key,
+            .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
+            .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
+            .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
+            .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+            .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
+        };
+
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
+        pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        pipeline.context = processed.decisions;
+        ctx->flash_attn_pipelines.emplace(key, pipeline);
     }
 
     ggml_webgpu_flash_attn_shader_decisions decisions =
@@ -1308,20 +1298,16 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
     };
 
     webgpu_pipeline pipeline;
-    {
-        // TODO: remove guard once pipeline caches are per-thread
-        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
-        auto it = ctx->unary_pipelines.find(pipeline_key);
-        if (it != ctx->unary_pipelines.end()) {
-            pipeline = it->second;
-        } else {
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
-            pipeline =
-                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-            pipeline.context = processed.decisions;
-            ctx->unary_pipelines.emplace(pipeline_key, pipeline);
-        }
+    auto it = ctx->unary_pipelines.find(pipeline_key);
+    if (it != ctx->unary_pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
+        pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        pipeline.context = processed.decisions;
+        ctx->unary_pipelines.emplace(pipeline_key, pipeline);
     }
 
     ggml_webgpu_generic_shader_decisions decisions =
@@ -1743,19 +1729,15 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
     };
 
     webgpu_pipeline pipeline;
-    {
-        // TODO: remove guard once pipeline caches are per-thread
-        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
-        auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
-        if (it != ctx->argmax_pipelines.end()) {
-            pipeline = it->second;
-        } else {
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
-            pipeline =
-                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-            ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
-        }
+    auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
+    if (it != ctx->argmax_pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
+        pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
     }
     uint32_t wg_x = ggml_nelements(dst);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@@ -1772,9 +1754,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
         .order = order
     };
 
-    std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
-    webgpu_pipeline argsort_pipeline;
-    auto it = ctx->argsort_pipelines.find(order);
+    webgpu_pipeline argsort_pipeline;
+    auto it = ctx->argsort_pipelines.find(order);
     if (it != ctx->argsort_pipelines.end()) {
         argsort_pipeline = it->second;
     } else {
@@ -1963,19 +1944,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
         .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
     };
     webgpu_pipeline pipeline;
-    // TODO: remove guard once pipeline caches are per-thread
-    {
-        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
-        auto it = ctx->cumsum_pipelines.find(1);
-        if (it != ctx->cumsum_pipelines.end()) {
-            pipeline = it->second;
-        } else {
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
-            pipeline =
-                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-            ctx->cumsum_pipelines.emplace(1, pipeline);
-        }
+    auto it = ctx->cumsum_pipelines.find(1);
+    if (it != ctx->cumsum_pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
+        pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        ctx->cumsum_pipelines.emplace(1, pipeline);
     }
     uint32_t wg_x = ggml_nrows(dst);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@@ -2009,19 +1986,15 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
     };
 
     webgpu_pipeline pipeline;
-    {
-        // TODO: remove guard once pipeline caches are per-thread
-        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
-        auto it = ctx->sum_rows_pipelines.find(1);
-        if (it != ctx->sum_rows_pipelines.end()) {
-            pipeline = it->second;
-        } else {
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
-            pipeline =
-                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-            ctx->sum_rows_pipelines.emplace(1, pipeline);
-        }
+    auto it = ctx->sum_rows_pipelines.find(1);
+    if (it != ctx->sum_rows_pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
+        pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        ctx->sum_rows_pipelines.emplace(1, pipeline);
     }
     uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@@ -3016,10 +2989,10 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
     // Initialize buffer pool for timestamp queries, used for profiling
-    ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
-                                                          WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
-                                                          wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
-                                                          wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
+    ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
+        ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
+        wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
+        wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
 #endif
 
     GGML_LOG_INFO(
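
Taken together, the diff relies on a simple ownership split: shared device state (device, queue, staging buffer, the memset pool) stays behind the global context and its mutex, while each backend thread owns its pipeline caches outright, so cache access needs no lock. A minimal sketch of that split under stated assumptions; global_state and per_thread_ctx are illustrative stand-ins, not the structs above:

#include <map>
#include <memory>
#include <string>
#include <thread>

struct global_state {  // shared: would hold device/queue/pools, guarded where mutated
};

struct per_thread_ctx {  // one per backend thread: its caches need no lock
    std::shared_ptr<global_state> global;
    std::map<int, std::string>    pipelines;  // stand-in for the pipeline caches
};

int main() {
    auto global = std::make_shared<global_state>();
    auto worker = [global](int variant) {
        per_thread_ctx ctx{ global, {} };
        // Each thread fills only its own cache, so there is no data race on the map.
        ctx.pipelines.emplace(variant, "compiled-shader");
    };
    std::thread t1(worker, 0), t2(worker, 1);
    t1.join();
    t2.join();
}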