Skip to content

Commit 7e1cd8b

Browse files
authored
Handle pytorch pybind11 changes and bump nightly pins (#4352)
I think this is the simplest approach, for now, to resolve #4343 It would be good to eventually finish #4348 ; however, it became a bit too much to rework the generated sources scripts in a timely fashion. See also another parallel attempt to address the CI problems: #4345 This PR modifies the CMake PyTorch configure function to simply not set any `TORCH_CXX_FLAGS` whenever PyTorch is missing the old `_PYBIND11_BUILD_ABI` attribute. I think whatever compiler flags we were pushing through to make pybind11 think we are GCC and to use a specific ABI version is just completely unnecessary now. I was worried we might need to update our pybind11 version in the requirements, but it appears to not be relevant. Additionally, nightly pins are updated and small fixes are made to resolve misc failures in tests after the bump. --------- Signed-off-by: zjgarvey <[email protected]>
1 parent 18e6b7f commit 7e1cd8b

File tree

8 files changed

+33
-57
lines changed

8 files changed

+33
-57
lines changed

build_tools/cmake/TorchMLIRPyTorch.cmake

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ endfunction()
3939
# Separately, pybind11 keeps an internal variable which records its ABI info
4040
# (PYBIND11_INTERNALS_ID in include/pybind11/detail/internals.h). Differences
4141
# in this variable between torch-mlir and PyTorch will cause type errors.
42-
# Thus, our best option is to:
42+
# Note: as of version 2.9.0.dev20250826, torch has updated to pybind11 ver 3.0.
43+
# This simplifies compatibility considerably. For reference, see
44+
# https://github.com/pybind/pybind11/pull/5439
45+
# For pre-version 3.0 pybind11, our best option is to:
4346
# a) Identify which ABI version PyTorch was compiled with
4447
# b) Tell gcc to use that version
4548
# or
@@ -70,23 +73,27 @@ function(TorchMLIRConfigurePyTorch)
7073
# Check ABI compatibility version
7174
execute_process(
7275
COMMAND ${Python3_EXECUTABLE}
73-
-c "import torch; import sys; abi=torch._C._PYBIND11_BUILD_ABI; abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
76+
-c "import torch; import sys; abi=getattr(torch._C, '_PYBIND11_BUILD_ABI', '-1'); abi=='-1' or abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
7477
RESULT_VARIABLE _result
7578
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
7679
OUTPUT_VARIABLE _cxx_abi_version)
7780
if(_result)
7881
message(FATAL_ERROR "Failed to determine C++ ABI version")
79-
endif()
80-
message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
81-
82-
# Specialize compile flags for compiler
83-
if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
84-
set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
85-
elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
86-
set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
82+
elseif(${_cxx_abi_version} STREQUAL "-1")
83+
message(STATUS "Could not find `torch._C._PYBIND_BUILD_ABI`. This was removed in torch 2.9.0 (as of nightly release: dev20250826), and the TORCH_CXX_FLAGS manipulation is no longer required.")
84+
# Everyone involved should be using cxx11 abi by default, but specify this just in case.
85+
set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi}")
8786
else()
88-
message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
89-
return()
87+
message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
88+
# Specialize compile flags for compiler
89+
if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
90+
set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
91+
elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
92+
set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
93+
else()
94+
message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
95+
return()
96+
endif()
9097
endif()
9198
set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
9299
endif()

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15960,9 +15960,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1596015960
" return %0 : !torch.int\n"
1596115961
" }\n"
1596215962
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
15963-
" %int6 = torch.constant.int 6\n"
15964-
" %int15 = torch.constant.int 15\n"
15965-
" %int5 = torch.constant.int 5\n"
1596615963
" %true = torch.constant.bool true\n"
1596715964
" %none = torch.constant.none\n"
1596815965
" %str = torch.constant.str \"AssertionError: \"\n"
@@ -16011,22 +16008,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1601116008
" }\n"
1601216009
" torch.prim.If.yield %9 : !torch.int\n"
1601316010
" } else {\n"
16014-
" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
16015-
" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
16016-
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
16017-
" torch.prim.If.yield %int6 : !torch.int\n"
16018-
" } else {\n"
16019-
" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
16020-
" torch.prim.If.yield %8 : !torch.int\n"
16021-
" }\n"
16022-
" torch.prim.If.yield %7 : !torch.int\n"
16011+
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
16012+
" torch.prim.If.yield %5 : !torch.int\n"
1602316013
" }\n"
1602416014
" return %4 : !torch.int\n"
1602516015
" }\n"
1602616016
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
16027-
" %int6 = torch.constant.int 6\n"
16028-
" %int15 = torch.constant.int 15\n"
16029-
" %int5 = torch.constant.int 5\n"
1603016017
" %true = torch.constant.bool true\n"
1603116018
" %none = torch.constant.none\n"
1603216019
" %str = torch.constant.str \"AssertionError: \"\n"
@@ -16075,15 +16062,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1607516062
" }\n"
1607616063
" torch.prim.If.yield %9 : !torch.int\n"
1607716064
" } else {\n"
16078-
" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
16079-
" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
16080-
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
16081-
" torch.prim.If.yield %int6 : !torch.int\n"
16082-
" } else {\n"
16083-
" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
16084-
" torch.prim.If.yield %8 : !torch.int\n"
16085-
" }\n"
16086-
" torch.prim.If.yield %7 : !torch.int\n"
16065+
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
16066+
" torch.prim.If.yield %5 : !torch.int\n"
1608716067
" }\n"
1608816068
" return %4 : !torch.int\n"
1608916069
" }\n"
@@ -16107,8 +16087,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1610716087
" }\n"
1610816088
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
1610916089
" %true = torch.constant.bool true\n"
16110-
" %int6 = torch.constant.int 6\n"
16111-
" %int15 = torch.constant.int 15\n"
1611216090
" %int5 = torch.constant.int 5\n"
1611316091
" %int8 = torch.constant.int 8\n"
1611416092
" %none = torch.constant.none\n"
@@ -16126,15 +16104,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1612616104
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
1612716105
" torch.prim.If.yield %int5 : !torch.int\n"
1612816106
" } else {\n"
16129-
" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
16130-
" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
16131-
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
16132-
" torch.prim.If.yield %int6 : !torch.int\n"
16133-
" } else {\n"
16134-
" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
16135-
" torch.prim.If.yield %8 : !torch.int\n"
16136-
" }\n"
16137-
" torch.prim.If.yield %7 : !torch.int\n"
16107+
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
16108+
" torch.prim.If.yield %5 : !torch.int\n"
1613816109
" }\n"
1613916110
" return %4 : !torch.int\n"
1614016111
" }\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5544,8 +5544,6 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
55445544
return aten〇std〡dtype((self_rank, dtype))
55455545
assert not is_complex_dtype(dtype)
55465546
return dtype
5547-
if self_dtype in [torch.float16, torch.bfloat16]:
5548-
return torch.float32
55495547
return aten〇std〡dtype(self_rank_dtype)
55505548

55515549
@check_dtype_function(
@@ -5569,8 +5567,6 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
55695567
return aten〇std〡dtype((self_rank, dtype))
55705568
assert not is_complex_dtype(dtype)
55715569
return dtype
5572-
if self_dtype in [torch.float16, torch.bfloat16]:
5573-
return torch.float32
55745570
return aten〇std〡dtype(self_rank_dtype)
55755571

55765572
def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
@@ -5604,8 +5600,6 @@ def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int,
56045600
# Should possibly be added to aten〇std〡dtype.
56055601
if self_dtype == torch.complex32:
56065602
return torch.half
5607-
if self_dtype in [torch.float16, torch.bfloat16]:
5608-
return torch.float32
56095603
return aten〇std〡dtype(self_rank_dtype)
56105604

56115605
@check_dtype_function([Invocation(0.0),

projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,12 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
149149
)
150150
module = self._backend.compile(module)
151151
backend_module = self._backend.load(module)
152+
input_buffers = prog.graph_signature.inputs_to_buffers.values()
152153
params = {
153154
# **dict(artifact.named_parameters(remove_duplicate=False)),
154-
**dict(artifact.named_buffers(remove_duplicate=False)),
155+
name: value
156+
for (name, value) in artifact.named_buffers(remove_duplicate=False)
157+
if name in input_buffers
155158
}
156159
params_flat, params_spec = pytree.tree_flatten(params)
157160
params_flat = list(params_flat)

projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def convert_onnx(model, inputs):
8585
input_names=input_names,
8686
dynamic_axes=dynamic_tensors,
8787
opset_version=max_opset_ver,
88+
dynamo=False,
8889
)
8990
buffer = buffer.getvalue()
9091
return import_onnx(buffer)

pytorch-hash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
7956a1d1d0dc7cdaaaa42d0863eebb1b1e75eb65
1+
0dfcb1a118dd45c544a156e1d86566368e528e69

pytorch-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
-f https://download.pytorch.org/whl/nightly/cpu/torch/
22
--pre
3-
torch==2.9.0.dev20250820
3+
torch==2.10.0.dev20251016

torchvision-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
-f https://download.pytorch.org/whl/nightly/cpu/torchvision/
22
--pre
3-
torchvision==0.24.0.dev20250820
3+
torchvision==0.25.0.dev20251016

0 commit comments

Comments
 (0)