Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,9 @@ endif()
if(EXECUTORCH_BUILD_CUDA)
# Build CUDA-specific AOTI functionality
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cuda)
# Add aoti_cuda to backends - it already depends on aoti_common
list(APPEND _executorch_backends aoti_cuda)
# Add aoti_cuda_backend to backends - it transitively includes aoti_cuda_shims
# and cuda_platform
list(APPEND _executorch_backends aoti_cuda_backend)
endif()

if(EXECUTORCH_BUILD_METAL)
Expand Down
3 changes: 3 additions & 0 deletions backends/aoti/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ target_compile_options(
PUBLIC $<$<CXX_COMPILER_ID:MSVC>:/EHsc /GR>
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-fexceptions -frtti -fPIC>
)
target_compile_definitions(
aoti_common PRIVATE $<$<PLATFORM_ID:Windows>:EXPORT_AOTI_FUNCTIONS>
)
# Ensure symbols are exported properly
if(APPLE)
target_link_options(aoti_common PUBLIC -Wl,-export_dynamic)
Expand Down
69 changes: 67 additions & 2 deletions backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ namespace aoti {

namespace internal {
// Global storage for tensor metadata
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
AOTI_SHIM_EXPORT std::unordered_map<Tensor*, std::vector<int64_t>>
tensor_to_sizes;
AOTI_SHIM_EXPORT std::unordered_map<Tensor*, std::vector<int64_t>>
tensor_to_strides;
} // namespace internal

extern "C" {
Expand Down Expand Up @@ -204,6 +206,69 @@ void cleanup_tensor_metadata() {
internal::tensor_to_strides.clear();
}

AOTI_SHIM_EXPORT void aoti_torch_warn(
const char* func,
const char* file,
uint32_t line,
const char* msg) {
ET_LOG(Error, "[%s:%u] %s: %s", file, line, func, msg);
}

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) {
(void)tensor;
(void)ret_size;
throw std::runtime_error("Not implemented");
return Error::Internal;
}

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor) {
(void)self;
(void)ret_new_tensor;
throw std::runtime_error("Not implemented");
return Error::Internal;
}

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) {
(void)self;
(void)ret_new_tensor;
throw std::runtime_error("Not implemented");
return Error::Internal;
}

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle) {
(void)orig_handle;
(void)new_handle;
throw std::runtime_error("Not implemented");
return Error::Internal;
}

AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
void* data_ptr,
int64_t ndim,
const int64_t* sizes,
const int64_t* strides,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor) {
(void)data_ptr;
(void)ndim;
(void)sizes;
(void)strides;
(void)storage_offset;
(void)dtype;
(void)device_type;
(void)device_index;
(void)ret_new_tensor;
throw std::runtime_error("Not implemented");
return Error::Internal;
}

} // extern "C"

} // namespace aoti
Expand Down
91 changes: 62 additions & 29 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <executorch/backends/aoti/export.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
Expand All @@ -23,57 +24,89 @@ namespace aoti {
using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

// Global storage for tensor metadata
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;

extern "C" {

// Common AOTI type aliases
using AOTIRuntimeError = Error;
using AOTITorchError = Error;

// Global storage for tensor metadata
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;

// Attribute-related operations (memory-irrelevant)
AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);

AOTITorchError aoti_torch_get_storage_offset(
Tensor* tensor,
int64_t* ret_storage_offset);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset);

AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);

AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);

AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);

AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);

AOTITorchError aoti_torch_get_device_index(
Tensor* tensor,
int32_t* ret_device_index);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index);

AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);

// Utility functions for device and layout information
int32_t aoti_torch_device_type_cpu();
int32_t aoti_torch_layout_strided();
int32_t aoti_torch_dtype_float32();
int32_t aoti_torch_dtype_bfloat16();
int32_t aoti_torch_dtype_int8();
int32_t aoti_torch_dtype_int16();
int32_t aoti_torch_dtype_int32();
int32_t aoti_torch_dtype_int64();
int32_t aoti_torch_dtype_bool();
AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64();

// Dtype utility function needed by Metal backend
size_t aoti_torch_dtype_element_size(int32_t dtype);
AOTI_SHIM_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype);

// Autograd mode functions
int32_t aoti_torch_grad_mode_is_enabled();
void aoti_torch_grad_mode_set_enabled(bool enabled);
AOTI_SHIM_EXPORT int32_t aoti_torch_grad_mode_is_enabled();
AOTI_SHIM_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled);

// Cleanup functions for clearing global state
void cleanup_tensor_metadata();
AOTI_SHIM_EXPORT void cleanup_tensor_metadata();

AOTI_SHIM_EXPORT void aoti_torch_warn(
const char* func,
const char* file,
uint32_t line,
const char* msg);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);

AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
void* data_ptr,
int64_t ndim,
const int64_t* sizes,
const int64_t* strides,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor);

} // extern "C"

Expand Down
25 changes: 25 additions & 0 deletions backends/aoti/export.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

// Define export macro for Windows DLL
// When building the aoti_cuda library, EXPORT_AOTI_FUNCTIONS is defined by
// CMake, which causes this macro to export symbols using __declspec(dllexport).
// When consuming the library, the macro imports symbols using
// __declspec(dllimport). On non-Windows platforms, the macro is empty and has
// no effect.
#ifdef _WIN32
#ifdef EXPORT_AOTI_FUNCTIONS
#define AOTI_SHIM_EXPORT __declspec(dllexport)
#else
#define AOTI_SHIM_EXPORT __declspec(dllimport)
#endif
#else
#define AOTI_SHIM_EXPORT
#endif
Loading
Loading