Skip to content

Commit

Permalink
metal : add debug capture backend function
Browse files Browse the repository at this point in the history
  • Loading branch information
jmousseau committed Jan 11, 2024
1 parent c75db1e commit 21815c6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/ggml-metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);

// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);

#ifdef __cplusplus
}
#endif
Expand Down
35 changes: 33 additions & 2 deletions src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

struct ggml_metal_context {
int n_cb;
bool should_capture_next_compute;

id<MTLDevice> device;
id<MTLCommandQueue> queue;
Expand Down Expand Up @@ -1011,6 +1012,20 @@ bool ggml_metal_graph_compute(

const int n_cb = ctx->n_cb;

const bool should_capture = ctx->should_capture_next_compute;
if (should_capture) {
ctx->should_capture_next_compute = false;

MTLCaptureDescriptor *descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->queue;

NSError *error = nil;
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
GGML_ASSERT(!"capture failed");
}
}

for (int i = 0; i < n_cb; ++i) {
ctx->command_buffers[i] = [ctx->queue commandBuffer];

Expand Down Expand Up @@ -1067,7 +1082,10 @@ bool ggml_metal_graph_compute(
GGML_ASSERT(!"unsupported op");
}

[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst)]];
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst)
encoding:NSUTF8StringEncoding]];
}

const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
Expand Down Expand Up @@ -2426,7 +2444,9 @@ bool ggml_metal_graph_compute(
}
}

[encoder popDebugGroup];
if (should_capture) {
[encoder popDebugGroup];
}
}

if (encoder != nil) {
Expand All @@ -2453,6 +2473,10 @@ bool ggml_metal_graph_compute(
}
}

if (should_capture) {
[[MTLCaptureManager sharedCaptureManager] stopCapture];
}

return true;
}
}
Expand Down Expand Up @@ -2798,6 +2822,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}

void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
GGML_ASSERT(ggml_backend_is_metal(backend));

struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
ctx->should_capture_next_compute = true;
}

ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning

ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
Expand Down

0 comments on commit 21815c6

Please sign in to comment.