@@ -118,8 +118,6 @@ struct webgpu_context_struct {
     wgpu::Limits limits;
 
     std::recursive_mutex mutex;
-    std::mutex get_tensor_mutex;
-    std::mutex init_mutex;
 
     bool device_init = false;
 
@@ -139,6 +137,8 @@ struct webgpu_context_struct {
 
     // Parameter buffers associated with the staged command buffers
     std::vector<webgpu_param_bufs> staged_param_bufs;
+
+    std::vector<wgpu::FutureWaitInfo> callback_futures;
 };
 
 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -221,25 +221,39 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
 
 /** WebGPU Actions */
 
+// Wait for the queue to finish processing all submitted work
 static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
-    // Wait for the queue to finish processing all commands
-    ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
-                              wgpu::CallbackMode::AllowSpontaneous,
-                              [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
-                                  if (status != wgpu::QueueWorkDoneStatus::Success) {
-                                      GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
-                                  }
-                              }),
-                          UINT64_MAX);
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    if (ctx->callback_futures.empty()) {
+        // no existing callbacks, wait on queue submission
+        ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
+                                  wgpu::CallbackMode::AllowSpontaneous,
+                                  [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+                                      if (status != wgpu::QueueWorkDoneStatus::Success) {
+                                          GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
+                                      }
+                                  }),
+                              UINT64_MAX);
+    } else {
+        // existing callbacks, wait on them
+        ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
+        ctx->callback_futures.clear();
+    }
 }
 
 static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
     std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
+    if (ctx->staged_command_bufs.empty()) {
+        // Nothing to submit
+        return;
+    }
     ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
     ctx->staged_command_bufs.clear();
     std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
+
     // Free the staged parameter buffers once the submission completes
-    ctx->queue.OnSubmittedWorkDone(
+    wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
         [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
             if (status != wgpu::QueueWorkDoneStatus::Success) {
@@ -248,6 +262,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
             // Free the staged parameter buffers
             ctx->param_buf_pool.free_bufs(staged_param_bufs);
         });
+    ctx->callback_futures.push_back({ f });
 }
 
 static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
@@ -273,7 +288,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
                                                   std::vector<uint32_t> params,
                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                   uint32_t wg_x,
-                                                  bool submit_imm = false) {
+                                                  bool submit_and_wait = false) {
     webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
 
     ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
@@ -304,17 +319,18 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
     pass.DispatchWorkgroups(wg_x, 1, 1);
     pass.End();
     wgpu::CommandBuffer commands = encoder.Finish();
-    if (submit_imm) {
-        // Submit immediately
+    if (submit_and_wait) {
+        // Submit and wait immediately
         ctx->queue.Submit(1, &commands);
-        ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
-                                       [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
-                                           if (status != wgpu::QueueWorkDoneStatus::Success) {
-                                               GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
-                                                              message.data);
-                                           }
-                                           ctx->param_buf_pool.free_bufs({ params_bufs });
-                                       });
+        ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
+                                  wgpu::CallbackMode::AllowSpontaneous,
+                                  [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+                                      if (status != wgpu::QueueWorkDoneStatus::Success) {
+                                          GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
+                                      }
+                                      ctx->param_buf_pool.free_bufs({ params_bufs });
+                                  }),
+                              UINT64_MAX);
     } else {
         // Lock the context mutex when pushing to the staging vectors.
         std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
@@ -579,6 +595,9 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
         // memset the remaining bytes
         ggml_backend_webgpu_buffer_memset(
             webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
+    } else {
+        // wait for WriteBuffer to complete
+        ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
     }
 }
 
@@ -602,7 +621,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
         final_size = size + (4 - (size % 4));
     }
 
-    std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex);
+    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
 
     if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
         // Create a new staging buffer if it doesn't exist or is too small
@@ -768,10 +787,11 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
     webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
 
     // Multiple threads may try to initialize the device
-    std::lock_guard<std::mutex> lock(webgpu_ctx->init_mutex);
+    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
     if (!webgpu_ctx->device_init) {
         // Initialize device
-        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization };
+        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
+                                                             wgpu::FeatureName::ImplicitDeviceSynchronization };
         wgpu::DeviceDescriptor dev_desc;
         dev_desc.requiredLimits = &webgpu_ctx->limits;
         dev_desc.requiredFeatures = required_features.data();
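
Taken together, the change folds the separate get_tensor_mutex and init_mutex into the single recursive mutex and records a wgpu::Future for each queue submission in callback_futures, so that a later wait can block on all in-flight work at once. Below is a minimal standalone sketch of that deferred-wait pattern; it is not part of the commit, it reuses the webgpu_context type from the file above, it assumes Dawn's C++ futures API (wgpu::Future, wgpu::FutureWaitInfo, Instance::WaitAny), and the flush_and_wait helper name is illustrative only.

// Illustrative sketch only (not part of the commit). Uses the webgpu_context
// type defined above and Dawn's C++ futures API.
static void flush_and_wait(webgpu_context & ctx) {
    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);

    // Submit any staged command buffers and remember the completion future.
    if (!ctx->staged_command_bufs.empty()) {
        ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
        ctx->staged_command_bufs.clear();
        wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
            wgpu::CallbackMode::AllowSpontaneous,
            [](wgpu::QueueWorkDoneStatus /*status*/, wgpu::StringView /*message*/) {
                // Completion is observed through the future below.
            });
        ctx->callback_futures.push_back({ f });
    }

    // Block until every outstanding submission has completed, then drop the futures.
    if (!ctx->callback_futures.empty()) {
        ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
        ctx->callback_futures.clear();
    }
}

In the commit itself these two halves live in ggml_backend_webgpu_submit_queue() and ggml_backend_webgpu_wait_on_submission(); the switch to a recursive mutex presumably lets those helpers be nested under a lock that a caller already holds.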