
Commit 1d12b13

feat: support metal cpp (#295)
This PR supersedes #291

---

This PR adds support for a new `metal-cpp` kernel dependency. This is a follow-up to the metal-cpp support in hf-nix (huggingface/hf-nix#128) and enables kernels to use the C++ headers to drive Metal kernels.

Changes:

- adds the dependency to build2cmake
- adds a new relu-metal-cpp example
- builds the example in CI

Example usage:

```bash
cd examples/relu-metal-cpp
nix build -L .
cd ...
uv run test_relu_metal_cpp.py
```

`test_relu_metal_cpp.py`:

```python
# /// script
# requires-python = ">=3.10"
# dependencies = ["kernels", "torch", "numpy"]
# ///
from kernels import get_local_kernel
import torch
from pathlib import Path

relu = get_local_kernel(Path("examples/relu-metal-cpp/result"), "relu").relu

input = torch.tensor([-1.0, -1.5, 0.0, 2.0, 3.5], device="mps", dtype=torch.float16)

out = relu(input)
ref = torch.relu(input)

assert torch.allclose(out, ref), f"Float16 failed: {out} != {ref}"
print(out.cpu().numpy())
print(ref.cpu().numpy())
print("PASS")
```

Output:

```
[0. 0. 0. 2. 3.5]
[0. 0. 0. 2. 3.5]
PASS
```
1 parent 3902f51 commit 1d12b13

File tree

15 files changed: +291 -9 lines changed


.github/workflows/build_kernel_macos.yaml

Lines changed: 3 additions & 0 deletions

```diff
@@ -26,3 +26,6 @@ jobs:
       # kernels. Also run tests once we have a macOS runner.
       - name: Build relu kernel
         run: ( cd examples/relu && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )
+
+      - name: Build relu metal cpp kernel
+        run: ( cd examples/relu-metal-cpp && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )
```
build2cmake/src/config/v2.rs

Lines changed: 2 additions & 0 deletions

```diff
@@ -254,6 +254,8 @@ pub enum Dependencies {
     Cutlass4_0,
     #[serde(rename = "cutlass_sycl")]
     CutlassSycl,
+    #[serde(rename = "metal-cpp")]
+    MetalCpp,
     Torch,
 }
```

examples/relu-metal-cpp/build.toml

Lines changed: 20 additions & 0 deletions

```toml
[general]
name = "relu"
universal = false

[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]


[kernel.relu_metal]
backend = "metal"
src = [
  "relu/relu.cpp",
  "relu/metallib_loader.mm",
  "relu/relu_cpp.metal",
  "relu/common.h",
]
depends = [ "torch", "metal-cpp" ]
```
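One metal-cpp detail worth noting (visible at the top of `relu/relu.cpp` below): exactly one translation unit must define the `NS_PRIVATE_IMPLEMENTATION`/`MTL_PRIVATE_IMPLEMENTATION` macros before including the headers, so that metal-cpp's implementations are emitted exactly once. A minimal standalone sketch, independent of the torch extension:

```cpp
// Exactly one .cpp file defines these macros before including metal-cpp;
// every other file includes <Metal/Metal.hpp> without them.
#define NS_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include <Metal/Metal.hpp>

int main() {
  // Grab the default GPU just to confirm the headers compile and link.
  MTL::Device* device = MTL::CreateSystemDefaultDevice();
  device->release();
  return 0;
}
```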

examples/relu-metal-cpp/flake.nix

Lines changed: 17 additions & 0 deletions

```nix
{
  description = "Flake for ReLU metal cpp kernel";

  inputs = {
    kernel-builder.url = "path:../..";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
```
examples/relu-metal-cpp/relu/common.h

Lines changed: 7 additions & 0 deletions

```metal
#pragma once

#include <metal_stdlib>
using namespace metal;

// Common constants and utilities for Metal kernels
constant float RELU_THRESHOLD = 0.0f;
```
examples/relu-metal-cpp/relu/metallib_loader.mm

Lines changed: 40 additions & 0 deletions

```objc
#import <Metal/Metal.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>

#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

// C++ interface to load the embedded metallib without exposing ObjC types
extern "C" {
void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg) {
  id<MTLDevice> mtlDevice = (__bridge id<MTLDevice>)device;
  NSError* error = nil;

  id<MTLLibrary> library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(mtlDevice, &error);

  if (!library && errorMsg && error) {
    *errorMsg = strdup([error.localizedDescription UTF8String]);
  }

  // Manually retain since we're not using ARC.
  // The caller will wrap the pointer in NS::TransferPtr, which assumes ownership.
  if (library) {
    [library retain];
  }
  return (__bridge void*)library;
}

// Get PyTorch's MPS device (returns id<MTLDevice> as void*)
void* getMPSDevice() {
  return (__bridge void*)at::mps::MPSDevice::getInstance()->device();
}

// Get PyTorch's current MPS command queue (returns id<MTLCommandQueue> as void*)
void* getMPSCommandQueue() {
  return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue();
}
}
```
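The manual `retain` pairs with `NS::TransferPtr` on the C++ side: the loader returns a +1 reference, and `NS::TransferPtr` adopts it without retaining again, so the library is released exactly once when the `NS::SharedPtr` goes out of scope. A sketch of the consuming side, which is what `relu/relu.cpp` below does:

```cpp
// Sketch: adopt the +1 reference returned by loadEmbeddedMetalLibrary.
// NS::TransferPtr takes ownership without an extra retain, so the release
// in ~NS::SharedPtr balances the [library retain] in the loader.
const char* errorMsg = nullptr;
MTL::Library* raw = reinterpret_cast<MTL::Library*>(
    loadEmbeddedMetalLibrary(reinterpret_cast<void*>(device), &errorMsg));
NS::SharedPtr<MTL::Library> library = NS::TransferPtr(raw);
```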
examples/relu-metal-cpp/relu/relu.cpp

Lines changed: 118 additions & 0 deletions

```cpp
#define NS_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION

// Include metal-cpp headers from system
#include <Metal/Metal.hpp>
#include <Foundation/NSSharedPtr.hpp>

#include <torch/torch.h>

// C interface from metallib_loader.mm
extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);
extern "C" void* getMPSDevice();
extern "C" void* getMPSCommandQueue();

namespace {

MTL::Buffer* getMTLBuffer(const torch::Tensor& tensor) {
  return reinterpret_cast<MTL::Buffer*>(const_cast<void*>(tensor.storage().data()));
}

NS::String* makeNSString(const std::string& value) {
  return NS::String::string(value.c_str(), NS::StringEncoding::UTF8StringEncoding);
}

MTL::Library* loadLibrary(MTL::Device* device) {
  const char* errorMsg = nullptr;
  void* library = loadEmbeddedMetalLibrary(reinterpret_cast<void*>(device), &errorMsg);

  TORCH_CHECK(library != nullptr, "Failed to create Metal library from embedded data: ",
              errorMsg ? errorMsg : "Unknown error");

  if (errorMsg) {
    free(const_cast<char*>(errorMsg));
  }

  return reinterpret_cast<MTL::Library*>(library);
}

}  // namespace

void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
  // Use PyTorch's MPS device and command queue (these are borrowed references, not owned)
  MTL::Device* device = reinterpret_cast<MTL::Device*>(getMPSDevice());
  TORCH_CHECK(device != nullptr, "Failed to get MPS device");

  MTL::CommandQueue* commandQueue = reinterpret_cast<MTL::CommandQueue*>(getMPSCommandQueue());
  TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue");

  MTL::Library* libraryPtr = loadLibrary(device);
  NS::SharedPtr<MTL::Library> library = NS::TransferPtr(libraryPtr);

  const std::string kernelName =
      std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half");
  NS::SharedPtr<NS::String> kernelNameString = NS::TransferPtr(makeNSString(kernelName));

  NS::SharedPtr<MTL::Function> computeFunction =
      NS::TransferPtr(library->newFunction(kernelNameString.get()));
  TORCH_CHECK(computeFunction.get() != nullptr, "Failed to create Metal function for ", kernelName);

  NS::Error* pipelineError = nullptr;
  NS::SharedPtr<MTL::ComputePipelineState> pipelineState =
      NS::TransferPtr(device->newComputePipelineState(computeFunction.get(), &pipelineError));
  TORCH_CHECK(pipelineState.get() != nullptr,
              "Failed to create compute pipeline state: ",
              pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error");

  // Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue
  MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
  TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer");

  MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
  TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder");

  encoder->setComputePipelineState(pipelineState.get());

  auto* inputBuffer = getMTLBuffer(input);
  auto* outputBuffer = getMTLBuffer(output);
  TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null");
  TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null");

  encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0);
  encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1);

  const NS::UInteger totalThreads = input.numel();
  NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup();
  if (threadGroupSize > totalThreads) {
    threadGroupSize = totalThreads;
  }

  const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1);
  const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1);

  encoder->dispatchThreads(gridSize, threadsPerThreadgroup);
  encoder->endEncoding();

  commandBuffer->commit();
}

void relu(torch::Tensor& out, const torch::Tensor& input) {
  TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.scalar_type() == torch::kFloat || input.scalar_type() == torch::kHalf,
              "Unsupported data type: ", input.scalar_type());

  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape. Got input shape: ",
              input.sizes(), " and output shape: ", out.sizes());

  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same data type. Got input dtype: ",
              input.scalar_type(), " and output dtype: ", out.scalar_type());

  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device. Got input device: ",
              input.device(), " and output device: ", out.device());

  dispatchReluKernel(input, out);
}
```
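For reference, a hypothetical C++ caller of the `relu` entry point above, assuming libtorch was built with MPS support (`torch::kMPS` is the MPS device type in libtorch); the shipped example instead reaches it through the Python binding declared in `torch-ext/torch_binding.cpp`:

```cpp
#include <torch/torch.h>

void relu(torch::Tensor& out, const torch::Tensor& input);  // from relu.cpp above

// Hypothetical caller: allocate both tensors on the MPS device, then invoke
// the kernel; `out` is written in place by dispatchReluKernel.
void run_relu_example() {
  auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kMPS);
  torch::Tensor input = torch::randn({1024}, opts);
  torch::Tensor out = torch::empty_like(input);
  relu(out, input);
}
```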
examples/relu-metal-cpp/relu/relu_cpp.metal

Lines changed: 17 additions & 0 deletions

```metal
#include <metal_stdlib>
#include "common.h"
using namespace metal;

kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]],
                                      device float *outC [[buffer(1)]],
                                      uint index [[thread_position_in_grid]]) {
  // Explicitly write to output
  outC[index] = max(RELU_THRESHOLD, inA[index]);
}

kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]],
                                     device half *outC [[buffer(1)]],
                                     uint index [[thread_position_in_grid]]) {
  // Explicitly write to output
  outC[index] = max(static_cast<half>(0.0), inA[index]);
}
```

examples/relu-metal-cpp/tests/__init__.py

Whitespace-only changes.
Lines changed: 19 additions & 0 deletions

```python
import platform

import torch
import torch.nn.functional as F

import relu


def test_relu():
    if platform.system() == "Darwin":
        device = torch.device("mps")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device("xpu")
    elif torch.version.cuda is not None and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
    torch.testing.assert_allclose(F.relu(x), relu.relu(x))
```
