Skip to content

Commit

Permalink
metal : add debug capture backend function (#694)
Browse files Browse the repository at this point in the history
Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
jmousseau and ggerganov authored Jan 29, 2024
1 parent 53558f9 commit a120e20
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/ggml-metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(voi
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);

// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
// note: the request is one-shot - the flag is cleared as soon as that compute call begins the capture
// note: `backend` must be a Metal backend (this is asserted by the implementation)
GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);

#ifdef __cplusplus
}
#endif
Expand Down
40 changes: 34 additions & 6 deletions src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@

bool support_simdgroup_reduction;
bool support_simdgroup_mm;

bool should_capture_next_compute;
};

// MSL code
Expand Down Expand Up @@ -684,6 +686,20 @@ static bool ggml_metal_graph_compute(
const int n_cb = ctx->n_cb;
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

const bool should_capture = ctx->should_capture_next_compute;
if (should_capture) {
ctx->should_capture_next_compute = false;

MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->queue;

NSError * error = nil;
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
GGML_ASSERT(!"capture failed");
}
}

id<MTLCommandBuffer> command_buffer_builder[n_cb];
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
Expand All @@ -692,6 +708,7 @@ static bool ggml_metal_graph_compute(
// enqueue the command buffers in order to specify their execution order
[command_buffer enqueue];
}

const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;

dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
Expand Down Expand Up @@ -738,9 +755,9 @@ static bool ggml_metal_graph_compute(
GGML_ASSERT(!"unsupported op");
}

#ifndef GGML_METAL_NDEBUG
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
#endif
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
}

const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
Expand Down Expand Up @@ -2190,9 +2207,9 @@ static bool ggml_metal_graph_compute(
}
}

#ifndef GGML_METAL_NDEBUG
[encoder popDebugGroup];
#endif
if (should_capture) {
[encoder popDebugGroup];
}
}

[encoder endEncoding];
Expand All @@ -2214,6 +2231,10 @@ static bool ggml_metal_graph_compute(
}
}

if (should_capture) {
[[MTLCaptureManager sharedCaptureManager] stopCapture];
}

return true;
}

Expand Down Expand Up @@ -2575,6 +2596,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}

// Request a capture of the next graph compute executed on this Metal backend.
//
// This only raises the `should_capture_next_compute` flag on the backend's Metal context;
// ggml_metal_graph_compute reads the flag, clears it and starts/stops the capture itself.
//
// backend: must be a Metal backend (asserted)
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    // the backend context of a Metal backend is a ggml_metal_context, so flag it in place
    ((struct ggml_metal_context *) backend->context)->should_capture_next_compute = true;
}

GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning

GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
Expand Down

0 comments on commit a120e20

Please sign in to comment.