Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backends/aoti/aoti_model_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ AOTInductorModelContainerGetNumOutputsFunc
AOTInductorModelContainerGetNumOutputs = nullptr;
AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;

// Additional global function pointers for AOT Inductor model container
// operations needed by Metal backend
AOTInductorModelContainerGetInputNameFunc
AOTInductorModelContainerGetInputName = nullptr;
AOTInductorModelContainerGetNumConstantsFunc
AOTInductorModelContainerGetNumConstants = nullptr;

} // extern "C"

} // namespace aoti
Expand Down
20 changes: 20 additions & 0 deletions backends/aoti/aoti_model_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,26 @@ extern AOTInductorModelContainerGetNumOutputsFunc
AOTInductorModelContainerGetNumOutputs;
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;

// Retrieves the name of an input tensor by index from the AOTI model container.
// Needed by Metal backend
using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t input_idx,
const char** input_name);

// Retrieves the number of constants from the AOTI model container.
// Needed by Metal backend
using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t* num_constants);

// Global function pointers (will be loaded dynamically).
// Needed by Metal backend
extern AOTInductorModelContainerGetInputNameFunc
AOTInductorModelContainerGetInputName;
extern AOTInductorModelContainerGetNumConstantsFunc
AOTInductorModelContainerGetNumConstants;

} // extern "C"

// AOTI Delegate Handle structure
Expand Down
6 changes: 6 additions & 0 deletions backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ int32_t aoti_torch_dtype_int64() {
return 4; // PyTorch's int64 dtype code
}

// Dtype utility function needed by Metal backend.
// Returns the size of the dtype in bytes.
size_t aoti_torch_dtype_element_size(int32_t dtype) {
return dtype_to_element_size(dtype);
}

// Cleanup functions
void cleanup_tensor_metadata() {
internal::tensor_to_sizes.clear();
Expand Down
3 changes: 3 additions & 0 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ int32_t aoti_torch_dtype_float32();
int32_t aoti_torch_dtype_bfloat16();
int32_t aoti_torch_dtype_int64();

// Dtype utility function needed by Metal backend
size_t aoti_torch_dtype_element_size(int32_t dtype);

// Autograd mode functions
int32_t aoti_torch_grad_mode_is_enabled();
void aoti_torch_grad_mode_set_enabled(bool enabled);
Expand Down
Loading