Skip to content

Commit ff14b58

Browse files
committed
wip: use torch from a wheel
1 parent 1d40939 commit ff14b58

File tree

5 files changed

+207
-33
lines changed

5 files changed

+207
-33
lines changed

WORKSPACE

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,23 @@ python_configure(
3535
################################ PyTorch Setup ################################
3636

3737
load("//bazel:dependencies.bzl", "PYTORCH_LOCAL_DIR")
38+
load("//bazel:torch_repo.bzl", "torch_repo")
3839

39-
new_local_repository(
40+
torch_repo(
4041
name = "torch",
41-
build_file = "//bazel:torch.BUILD",
42-
path = PYTORCH_LOCAL_DIR,
42+
dist_dir = "../dist",
4343
)
44+
##new_local_repository(
45+
## name = "torch",
46+
## build_file = "//bazel:torch.BUILD",
47+
## path = PYTORCH_LOCAL_DIR,
48+
##)
4449

4550
############################# OpenXLA Setup ###############################
4651

4752
# To build PyTorch/XLA with a new revision of OpenXLA, update the xla_hash to
4853
# the openxla git commit hash and note the date of the commit.
49-
xla_hash = '9ac36592456e7be0d66506be75fbdacc90dd4e91' # Committed on 2025-06-11.
54+
xla_hash = "9ac36592456e7be0d66506be75fbdacc90dd4e91" # Committed on 2025-06-11.
5055

5156
http_archive(
5257
name = "xla",
@@ -66,8 +71,6 @@ http_archive(
6671
],
6772
)
6873

69-
70-
7174
# For development, one often wants to make changes to the OpenXLA repository as well
7275
# as the PyTorch/XLA repository. You can override the pinned repository above with a
7376
# local checkout by either:
@@ -89,14 +92,14 @@ python_init_rules()
8992
load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
9093

9194
python_init_repositories(
95+
default_python_version = "system",
96+
local_wheel_workspaces = ["@torch//:WORKSPACE"],
9297
requirements = {
9398
"3.8": "//:requirements_lock_3_8.txt",
9499
"3.9": "//:requirements_lock_3_9.txt",
95100
"3.10": "//:requirements_lock_3_10.txt",
96101
"3.11": "//:requirements_lock_3_11.txt",
97102
},
98-
local_wheel_workspaces = ["@torch//:WORKSPACE"],
99-
default_python_version = "system",
100103
)
101104

102105
load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
@@ -111,8 +114,6 @@ load("@pypi//:requirements.bzl", "install_deps")
111114

112115
install_deps()
113116

114-
115-
116117
# Initialize OpenXLA's external dependencies.
117118
load("@xla//:workspace4.bzl", "xla_workspace4")
118119

@@ -134,7 +135,6 @@ load("@xla//:workspace0.bzl", "xla_workspace0")
134135

135136
xla_workspace0()
136137

137-
138138
load(
139139
"@xla//third_party/gpus:cuda_configure.bzl",
140140
"cuda_configure",

bazel/torch.BUILD

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,36 +23,44 @@ cc_library(
2323
filegroup(
2424
name = "torchgen_deps",
2525
srcs = [
26-
"aten/src/ATen/native/native_functions.yaml",
27-
"aten/src/ATen/native/tags.yaml",
28-
"aten/src/ATen/native/ts_native_functions.yaml",
29-
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
30-
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
31-
"aten/src/ATen/templates/LazyIr.h",
32-
"aten/src/ATen/templates/LazyNonNativeIr.h",
33-
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
34-
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
35-
"torch/csrc/lazy/core/shape_inference.h",
36-
"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
26+
# torchgen/packaged/ instead of aten/src
27+
"torchgen/packaged/ATen/native/native_functions.yaml",
28+
"torchgen/packaged/ATen/native/tags.yaml",
29+
##"torchgen/packaged/ATen/native/ts_native_functions.yaml",
30+
"torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp",
31+
"torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h",
32+
"torchgen/packaged/ATen/templates/LazyIr.h",
33+
"torchgen/packaged/ATen/templates/LazyNonNativeIr.h",
34+
"torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini",
35+
"torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp",
36+
# Add torch/include prefix
37+
"torch/include/torch/csrc/lazy/core/shape_inference.h",
38+
##"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
3739
],
3840
)
3941

40-
cc_import(
42+
# Changed to cc_library from cc_import
43+
44+
cc_library(
4145
name = "libtorch",
42-
shared_library = "build/lib/libtorch.so",
46+
srcs = ["torch/lib/libtorch.so"],
4347
)
4448

45-
cc_import(
49+
cc_library(
4650
name = "libtorch_cpu",
47-
shared_library = "build/lib/libtorch_cpu.so",
51+
srcs = ["torch/lib/libtorch_cpu.so"],
4852
)
4953

50-
cc_import(
54+
cc_library(
5155
name = "libtorch_python",
52-
shared_library = "build/lib/libtorch_python.so",
56+
srcs = [
57+
# Added this
58+
"torch/lib/libshm.so",
59+
"torch/lib/libtorch_python.so",
60+
],
5361
)
5462

55-
cc_import(
63+
cc_library(
5664
name = "libc10",
57-
shared_library = "build/lib/libc10.so",
65+
srcs = ["torch/lib/libc10.so"],
5866
)

bazel/torch_repo.bzl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
"""Repository rule to setup a torch repo.

The repository is populated by extracting a prebuilt torch wheel — taken
from the TORCH_WHL environment variable (URL or local path) or discovered
in a local dist directory — and writing a BUILD file that delegates target
definitions to define_torch_targets().
"""

_BUILD_TEMPLATE = """

load("@//bazel:torch_repo_targets.bzl", "define_torch_targets")

package(
    default_visibility = [
        "//visibility:public",
    ],
)

define_torch_targets()
"""

def _get_url_basename(url):
    """Returns the final path component of `url`, approximately URL-decoded."""
    basename = url.rpartition("/")[2]

    # Starlark doesn't have any URL decode functions, so just approximate
    # one with the cases we see.
    return basename.replace("%2B", "+")

def _torch_repo_impl(rctx):
    """Materializes the torch repository.

    Wheel resolution order:
      1. TORCH_WHL environment variable: an http(s) URL or a local path.
      2. The first *.whl file found in rctx.attr.dist_dir (resolved
         relative to the workspace root).
    Fails if neither yields a wheel. The wheel is placed under dist/ and
    then extracted into the repository root.
    """
    rctx.file("BUILD.bazel", _BUILD_TEMPLATE)

    env_torch_whl = rctx.os.environ.get("TORCH_WHL", "")

    urls = None
    local_path = None
    dist_dir = None
    if env_torch_whl:
        if env_torch_whl.startswith("http"):
            urls = [env_torch_whl]
        else:
            local_path = rctx.path(env_torch_whl)
    else:
        dist_dir = rctx.workspace_root.get_child(rctx.attr.dist_dir)

        if dist_dir.exists:
            for child in dist_dir.readdir():
                # For lack of a better option, take the first match
                if child.basename.endswith(".whl"):
                    local_path = child
                    break

    if not local_path and not urls:
        fail((
            "No torch wheel source configured:\n" +
            "* Set TORCH_WHL environment variable to a local path or URL.\n" +
            "* Or ensure the {dist_dir} directory is present with a torch wheel." +
            "\n"
        ).format(
            dist_dir = dist_dir,
        ))

    if local_path:
        whl_path = local_path
        if not whl_path.exists:
            fail("File not found: {}".format(whl_path))

        # The dist/ directory is necessary for XLA's python_init_repositories
        # to discover the wheel and add it to requirements.txt
        rctx.symlink(whl_path, "dist/{}".format(whl_path.basename))
    else:
        whl_basename = _get_url_basename(urls[0])

        # The dist/ directory is necessary for XLA's python_init_repositories
        # to discover the wheel and add it to requirements.txt
        whl_path = rctx.path("dist/{}".format(whl_basename))
        result = rctx.download(
            url = urls,
            output = whl_path,
        )
        if not result.success:
            # Format the message explicitly: fail() does not interpolate
            # "{}" from extra positional arguments.
            fail("Failed to download: {}".format(urls))

    # Extract into the repo root. Also use .zip as the extension so that extract
    # recognizes the file type.
    # Use the whl basename so progress messages are more informative.
    whl_zip = whl_path.basename.replace(".whl", ".zip")
    rctx.symlink(whl_path, whl_zip)
    rctx.extract(whl_zip)
    rctx.delete(whl_zip)

torch_repo = repository_rule(
    implementation = _torch_repo_impl,
    doc = """
Creates a repository with torch headers, shared libraries, and wheel
for integration with Bazel.
""",
    attrs = {
        "dist_dir": attr.string(
            doc = "Directory with a prebuilt torch wheel. Typically points to " +
                  "a source checkout that built a torch wheel.",
        ),
    },
    environ = ["TORCH_WHL"],
)

bazel/torch_repo_targets.bzl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
"""Handles the loading phase to define targets for torch_repo."""

def define_torch_targets():
    """Declares header, torchgen, and shared-library targets for the torch repo."""

    # Every public header shipped in the wheel, minus the bundled protobuf
    # headers (second glob argument is the exclude list).
    native.cc_library(
        name = "headers",
        hdrs = native.glob(
            ["torch/include/**/*.h"],
            exclude = ["torch/include/google/protobuf/**/*.h"],
        ),
        strip_include_prefix = "torch/include",
    )

    # Runtime headers, for importing <torch/torch.h>.
    native.cc_library(
        name = "runtime_headers",
        hdrs = native.glob(["torch/include/torch/csrc/api/include/**/*.h"]),
        strip_include_prefix = "torch/include/torch/csrc/api/include",
    )

    # Inputs consumed by torchgen. The wheel layout differs from a source
    # checkout: yaml files and templates live under torchgen/packaged/
    # (instead of aten/src), and headers carry a torch/include prefix.
    # ts_native_functions.yaml and ts_native_functions.cpp were dropped
    # from the list relative to the source-checkout layout.
    native.filegroup(
        name = "torchgen_deps",
        srcs = [
            "torch/include/torch/csrc/lazy/core/shape_inference.h",
            "torchgen/packaged/ATen/native/native_functions.yaml",
            "torchgen/packaged/ATen/native/tags.yaml",
            "torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp",
            "torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h",
            "torchgen/packaged/ATen/templates/LazyIr.h",
            "torchgen/packaged/ATen/templates/LazyNonNativeIr.h",
            "torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini",
            "torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp",
        ],
    )

    # Shared libraries from the wheel, wrapped as cc_library (not cc_import)
    # so the .so files are passed through as plain srcs. libtorch_python
    # additionally bundles libshm.so.
    shared_libs = [
        ("libtorch", ["torch/lib/libtorch.so"]),
        ("libtorch_cpu", ["torch/lib/libtorch_cpu.so"]),
        ("libtorch_python", [
            "torch/lib/libshm.so",
            "torch/lib/libtorch_python.so",
        ]),
        ("libc10", ["torch/lib/libc10.so"]),
    ]
    for lib_name, lib_srcs in shared_libs:
        native.cc_library(
            name = lib_name,
            srcs = lib_srcs,
        )

codegen/lazy_tensor_generator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
kernel_signature,
2222
)
2323

24-
aten_path = os.path.join(torch_root, "aten", "src", "ATen")
25-
shape_inference_hdr = os.path.join(torch_root, "torch", "csrc", "lazy", "core",
26-
"shape_inference.h")
24+
##aten_path = os.path.join(torch_root, "aten", "src", "ATen")
25+
aten_path = os.path.join(torch_root, "torchgen", "packaged", "ATen")
26+
##shape_inference_hdr = os.path.join(torch_root, "torch", "csrc", "lazy", "core",
27+
## "shape_inference.h")
28+
shape_inference_hdr = os.path.join(torch_root, "torch", "include",
29+
"torch", "csrc", "lazy", "core", "shape_inference.h")
2730
impl_path = os.path.join(xla_root, "__main__",
2831
"torch_xla/csrc/aten_xla_type.cpp")
2932
source_yaml = sys.argv[2]

0 commit comments

Comments
 (0)