
Commit 971606b

janeyx99 authored and pytorchmergebot committed
Add a stable TORCH_LIBRARY to C shim (pytorch#148124)
This PR adds two main parts:
- shim.h: stable C APIs into the torch::Library APIs
- a higher-level API in torch/csrc/stable/library.h that calls into this shim.h and is otherwise self-contained

Goal: custom kernel writers should be able to call the APIs in the directories above to register their library in a way that allows their custom extension to run with a different libtorch version than it was built with.

Subplots resolved:
- Do we want a whole separate StableLibrary, or do we want to freeze torch::Library and add `m.stable_impl(cstring, void (*fn)(void **, int64_t, int64_t))` into it?
  - Yes, we want a separate StableLibrary. We cannot freeze Library, and it is NOT header-only.
- Should I use uint64_t as the common denominator instead of void* to better support 32-bit architectures?
  - Yes, and done.
- Should I add a stable `def` and `fragment` when those can be done in Python?
  - I think we do want these, and now they're done.
- Where should library_stable_impl.cpp live? No longer relevant.
- I need some solid test cases to make sure everything's going OK. I've intentionally thrown a bunch of random dtypes into the signature, but I still haven't tested returning multiple things, returning nothing, complex dtypes, etc.
  - Have since tested all the torch library endpoints. The others can be tested in a follow-up, to separate the components that need to be in shim.h from those that can be added later.

Pull Request resolved: pytorch#148124
Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/atalman
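For orientation, here is a minimal sketch of the registration flow this enables, distilled from the libtorch_agnostic example extension added below. The `my_ops` namespace, the `my_identity` schema, and `boxed_my_identity` are illustrative names for this sketch, not part of the commit:

// Minimal sketch of registering an op through the stable API added in this PR.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/library.h>

// Boxed calling convention: arguments arrive as StableIValues on the stack,
// and results are written back into the same stack slots.
void boxed_my_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  AtenTensorHandle t = to<AtenTensorHandle>(stack[0]);
  stack[0] = from(t);
}

// Declare the schema and register the CPU kernel through the stable macros.
STABLE_TORCH_LIBRARY(my_ops, m) {
  m.def("my_identity(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
  m.impl("my_identity", &boxed_my_identity);
}

Per the commit message, a kernel registered this way only talks to libtorch through the C entry points in shim.h and the StableIValue stack, rather than through libtorch's C++ ABI, which is what allows it to run against a different libtorch version than it was built with.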
1 parent 4d10da7 · commit 971606b


15 files changed: +765 −9 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ test/generated_type_hints_smoketest.py
 test/htmlcov
 test/cpp_extensions/install/
 test/cpp_extensions/open_registration_extension/install
+test/cpp_extensions/libtorch_agnostic_extension/install
 test/kernel.errors.txt
 third_party/build/
 third_party/nccl/

build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ jit_core_headers = [
     "torch/csrc/jit/frontend/schema_type_parser.h",
     "torch/csrc/jit/frontend/error_report.h",
     "torch/csrc/jit/frontend/tree.h",
+    "torch/csrc/stable/library.h",
     "torch/custom_class.h",
     "torch/custom_class_detail.h",
     "torch/library.h",

docs/cpp/source/Doxyfile

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ INPUT = ../../../aten/src/ATen/ATen.h \
 ../../../torch/csrc/jit/runtime/custom_operator.h \
 ../../../torch/csrc/jit/serialization/import.h \
 ../../../torch/csrc/jit/api/module.h \
+../../../torch/csrc/stable/library.h \
 ../../../torch/library.h \
 ../../../torch/custom_class.h
 # Don't include .cpp files!

setup.py

Lines changed: 1 addition & 0 deletions
@@ -1274,6 +1274,7 @@ def main():
         "include/c10/xpu/impl/*.h",
         "include/torch/*.h",
         "include/torch/csrc/*.h",
+        "include/torch/csrc/stable/*.h",
         "include/torch/csrc/api/include/torch/*.h",
         "include/torch/csrc/api/include/torch/data/*.h",
         "include/torch/csrc/api/include/torch/data/dataloader/*.h",
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import ctypes
+from pathlib import Path
+
+import torch
+
+
+so_files = list(Path(__file__).parent.glob("_C*.so"))
+assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
+
+# use ctypes.CDLL instead of load_library to be able to test the unload logic
+# below code is reduced from the load_library code
+with torch._ops.dl_open_guard():
+    loaded_lib = ctypes.CDLL(so_files[0])
+
+from . import ops
+
+
+__all__ = [
+    "loaded_lib",
+    "ops",
+]
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_runtime/utils.h>
+#include <torch/csrc/stable/library.h>
+
+using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
+
+void inline sgd_math(
+  float* param_ptr,
+  float* grad_ptr,
+  float* out_ptr,
+  const float weight_decay,
+  const double lr,
+  const bool maximize,
+  int64_t size
+){
+  int64_t d = 0;
+  for (; d < size; d++) {
+    float grad_val = grad_ptr[d];
+    if (maximize) grad_val = -grad_val;
+    if (weight_decay != 0.0){
+      grad_val += param_ptr[d] * weight_decay;
+    }
+    out_ptr[d] = param_ptr[d] - grad_val * float(lr);
+  }
+}
+
+
+RAIIATH sgd_out_of_place(
+  const RAIIATH param,
+  const RAIIATH grad,
+  const float weight_decay,
+  const double lr,
+  const bool maximize) {
+
+  int64_t param_dim;
+  aoti_torch_get_dim(param.get(), &param_dim);
+
+  int64_t *param_sizes;
+  int64_t *param_strides;
+  aoti_torch_get_sizes(param.get(), &param_sizes);
+  aoti_torch_get_strides(param.get(), &param_strides);
+
+  int32_t param_dtype;
+  aoti_torch_get_dtype(param.get(), &param_dtype);
+
+  int32_t param_device_type;
+  int32_t param_device_index;
+  aoti_torch_get_device_type(param.get(), &param_device_type);
+  aoti_torch_get_device_index(param.get(), &param_device_index);
+
+  AtenTensorHandle out;
+  aoti_torch_empty_strided(param_dim, param_sizes, param_strides, param_dtype, param_device_type, param_device_index, &out);
+
+  void* param_ptr;
+  aoti_torch_get_data_ptr(param.get(), &param_ptr);
+  void* grad_ptr;
+  aoti_torch_get_data_ptr(grad.get(), &grad_ptr);
+  void* out_ptr;
+  aoti_torch_get_data_ptr(out, &out_ptr);
+
+  auto param_fp_ptr = reinterpret_cast<float*>(param_ptr);
+  auto grad_fp_ptr = reinterpret_cast<float*>(grad_ptr);
+  auto out_fp_ptr = reinterpret_cast<float*>(out_ptr);
+
+  int64_t param_numel;
+  aoti_torch_get_numel(param.get(), &param_numel);
+
+  sgd_math(
+    param_fp_ptr,
+    grad_fp_ptr,
+    out_fp_ptr,
+    weight_decay,
+    lr,
+    maximize,
+    param_numel
+  );
+
+  return RAIIATH(out);
+}
+
+
+void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  RAIIATH param(to<AtenTensorHandle>(stack[0]));
+  RAIIATH grad(to<AtenTensorHandle>(stack[1]));
+  auto weight_decay = to<double>(stack[2]);
+  auto lr = to<double>(stack[3]);
+  auto maximize = to<bool>(stack[4]);
+
+  RAIIATH raiiath_res = sgd_out_of_place(
+    std::move(param),
+    std::move(grad),
+    float(weight_decay),
+    lr,
+    maximize);
+
+  stack[0] = from(raiiath_res.release());
+}
+
+STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
+  m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
+  m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
+}
+
+RAIIATH identity(RAIIATH t) {
+  return std::move(t);
+}
+
+void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  RAIIATH t(to<AtenTensorHandle>(stack[0]));
+  RAIIATH raiiath_res = identity(std::move(t));
+  stack[0] = from(raiiath_res.release());
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("identity(Tensor t) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
+  m.impl("identity", &boxed_identity);
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
+  m.impl("identity", &boxed_identity);
+}
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import torch
+from torch import Tensor
+
+
+def sgd_out_of_place(param, grad, weight_decay, lr, maximize) -> Tensor:
+    """
+    Computes a single step of SGD on a single parameter Tensor with grad.
+
+    Assumes:
+    - param and grad are the same shape and are 1D.
+    - param and grad are float and on CPU
+
+    Args:
+        param: a 1D tensor of floats
+        grad: a 1D tensor of floats
+        weight_decay: a python double between 0 and 1
+        lr: a python double
+
+    Returns:
+        a 1D float Tensor the same shape as param
+
+    """
+    return torch.ops.libtorch_agnostic.sgd_out_of_place.default(
+        param, grad, weight_decay, lr, maximize
+    )
+
+
+def identity(t) -> Tensor:
+    """
+    Returns the input tensor
+
+    Args:
+        t: any Tensor
+
+    Returns:
+        a Tensor, the same as input.
+    """
+    return torch.ops.libtorch_agnostic.identity.default(t)
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+import distutils.command.clean
+import shutil
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+from torch.utils.cpp_extension import BuildExtension, CppExtension
+
+
+ROOT_DIR = Path(__file__).parent
+CSRC_DIR = ROOT_DIR / "libtorch_agnostic" / "csrc"
+
+
+class clean(distutils.command.clean.clean):
+    def run(self):
+        # Run default behavior first
+        distutils.command.clean.clean.run(self)
+
+        # Remove extension
+        for path in (ROOT_DIR / "libtorch_agnostic").glob("**/*.so"):
+            path.unlink()
+        # Remove build and dist and egg-info directories
+        dirs = [
+            ROOT_DIR / "build",
+            ROOT_DIR / "dist",
+            ROOT_DIR / "libtorch_agnostic.egg-info",
+        ]
+        for path in dirs:
+            if path.exists():
+                shutil.rmtree(str(path), ignore_errors=True)
+
+
+def get_extension():
+    extra_compile_args = {
+        "cxx": ["-fdiagnostics-color=always"],
+    }
+
+    sources = list(CSRC_DIR.glob("**/*.cpp"))
+
+    return [
+        CppExtension(
+            "libtorch_agnostic._C",
+            sources=sorted(str(s) for s in sources),
+            py_limited_api=True,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=[],
+        )
+    ]
+
+
+setup(
+    name="libtorch_agnostic",
+    version="0.0",
+    author="PyTorch Core Team",
+    description="Example of libtorch agnostic extension",
+    packages=find_packages(exclude=("test",)),
+    package_data={"libtorch_agnostic": ["*.dll", "*.dylib", "*.so"]},
+    install_requires=[
+        "torch",
+    ],
+    ext_modules=get_extension(),
+    cmdclass={
+        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
+        "clean": clean,
+    },
+    options={"bdist_wheel": {"py_limited_api": "cp39"}},
+)
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+# Owner(s): ["module: cpp"]
+
+import libtorch_agnostic  # noqa: F401
+
+import torch
+from torch.testing._internal.common_device_type import (
+    instantiate_device_type_tests,
+    onlyCPU,
+    onlyCUDA,
+)
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+
+class TestLibtorchAgnostic(TestCase):
+    @onlyCPU
+    def test_slow_sgd(self, device):
+        param = torch.rand(5, device=device)
+        grad = torch.rand_like(param)
+        weight_decay = 0.01
+        lr = 0.001
+        maximize = False
+
+        new_param = libtorch_agnostic.ops.sgd_out_of_place(
+            param, grad, weight_decay, lr, maximize
+        )
+        torch._fused_sgd_(
+            (param,),
+            (grad,),
+            (),
+            weight_decay=weight_decay,
+            momentum=0.0,
+            lr=lr,
+            dampening=0.0,
+            nesterov=False,
+            maximize=maximize,
+            is_first_step=False,
+        )
+        self.assertEqual(new_param, param)
+
+    @onlyCUDA
+    def test_identity_does_not_hog_memory(self, device):
+        def _run_identity(prior_mem):
+            t = torch.rand(32, 32, device=device)
+            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
+            identi_t = libtorch_agnostic.ops.identity(t)
+            assert identi_t is t
+
+        init_mem = torch.cuda.memory_allocated(device)
+
+        for _ in range(3):
+            _run_identity(init_mem)
+            curr_mem = torch.cuda.memory_allocated(device)
+            self.assertEqual(curr_mem, init_mem)
+
+    @onlyCUDA
+    def test_z_delete_torch_lib(self, device):
+        # Why the z + CUDA? THIS TEST MUST BE RUN LAST
+        # We are testing that unloading the library properly deletes the registrations, so running this test
+        # earlier will cause all other tests in this file to fail
+        lib = libtorch_agnostic.loaded_lib
+
+        # code for unloading a library inspired from
+        # https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
+        lib_handle = lib._handle
+        lib.dlclose(lib_handle)
+
+        t = torch.tensor([-2.0, 0.5])
+        with self.assertRaises(RuntimeError):
+            libtorch_agnostic.ops.identity(
+                t
+            )  # errors as identity shouldn't be registered anymore
+
+
+instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
+
+if __name__ == "__main__":
+    run_tests()
