From 21815c6c34244d848e86f361eee200135ddf230f Mon Sep 17 00:00:00 2001 From: Jack Mousseau Date: Wed, 10 Jan 2024 18:16:31 -0800 Subject: [PATCH] metal : add debug capture backend function --- src/ggml-metal.h | 3 +++ src/ggml-metal.m | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/ggml-metal.h b/src/ggml-metal.h index c4b7325da..7f71d4b2d 100644 --- a/src/ggml-metal.h +++ b/src/ggml-metal.h @@ -109,6 +109,9 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); +// capture all command buffers committed the next time `ggml_backend_graph_compute` is called +GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); + #ifdef __cplusplus } #endif diff --git a/src/ggml-metal.m b/src/ggml-metal.m index 161906824..ff6627392 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -37,6 +37,7 @@ struct ggml_metal_context { int n_cb; + bool should_capture_next_compute; id device; id queue; @@ -1011,6 +1012,20 @@ bool ggml_metal_graph_compute( const int n_cb = ctx->n_cb; + const bool should_capture = ctx->should_capture_next_compute; + if (should_capture) { + ctx->should_capture_next_compute = false; + + MTLCaptureDescriptor *descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->queue; + + NSError *error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + GGML_ASSERT(!"capture failed"); + } + } + for (int i = 0; i < n_cb; ++i) { ctx->command_buffers[i] = [ctx->queue commandBuffer]; @@ -1067,7 +1082,10 @@ bool ggml_metal_graph_compute( GGML_ASSERT(!"unsupported op"); } - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst)]]; + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) + encoding:NSUTF8StringEncoding]]; + } const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; @@ -2426,7 +2444,9 @@ bool ggml_metal_graph_compute( } } - [encoder popDebugGroup]; + if (should_capture) { + [encoder popDebugGroup]; + } } if (encoder != nil) { @@ -2453,6 +2473,10 @@ bool ggml_metal_graph_compute( } } + if (should_capture) { + [[MTLCaptureManager sharedCaptureManager] stopCapture]; + } + return true; } } @@ -2798,6 +2822,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; } +void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; + ctx->should_capture_next_compute = true; +} + ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {