
Commit e203fa0

Fix TorchScript warning "doesn't support instance-level annotations" (e3nn#437)
1 parent 19dca23 commit e203fa0

8 files changed (+87, -48 lines)


e3nn/o3/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -103,7 +103,8 @@
     "FullyConnectedTensorProduct",
     "ElementwiseTensorProduct",
     "FullTensorProduct",
-    "FullTensorProductv2" "TensorSquare",
+    "FullTensorProductv2",
+    "TensorSquare",
     "SphericalHarmonics",
     "spherical_harmonics",
     "SphericalHarmonicsAlphaBeta",

e3nn/o3/experimental/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 from ._full_tp import FullTensorProduct as FullTensorProductv2
+
+__all__ = [FullTensorProductv2]

e3nn/o3/experimental/_full_tp.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+# flake8: noqa
+
 from e3nn.util.datatypes import Path, Chunk
 from e3nn import o3

e3nn/util/codegen/_mixin.py

Lines changed: 5 additions & 4 deletions
@@ -1,9 +1,8 @@
 import io
 from typing import Dict

-import e3nn
+import e3nn.util.jit
 import torch
-from opt_einsum_fx import jitable
 from torch import fx


@@ -40,7 +39,9 @@ def _codegen_register(
             assert isinstance(graphmod, fx.GraphModule)

             if opt_defaults["jit_script_fx"]:
-                scriptmod = torch.jit.script(jitable(graphmod))
+                # With recurse=False, this more or less is equivalent to
+                # torch.jit.script(jitable(graphmod))
+                scriptmod = e3nn.util.jit.compile(graphmod, recurse=False)
                 assert isinstance(scriptmod, torch.jit.ScriptModule)
             else:
                 scriptmod = graphmod
@@ -73,7 +74,7 @@ def __getstate__(self):
             # Get the module
             smod = getattr(self, fname)
             if isinstance(smod, fx.GraphModule):
-                smod = torch.jit.script(jitable(smod))
+                smod = e3nn.util.jit.compile(smod, recurse=False)
                 assert isinstance(smod, torch.jit.ScriptModule)
             # Save the compiled code as TorchScript IR
             buffer = io.BytesIO()
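
Not part of the commit, but a self-contained sketch of the situation being worked around (assuming a recent PyTorch, roughly >= 2.0; MyModule is a toy module invented here): scripting an fx.GraphModule with torch.jit.script can emit the instance-level-annotations UserWarning that this commit filters out.

import warnings

import torch
from torch import fx, nn


class MyModule(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x + 1


# Symbolic tracing produces an fx.GraphModule, like e3nn's generated code.
graphmod = fx.symbolic_trace(MyModule())

# On recent PyTorch versions this may emit:
#   UserWarning: The TorchScript type system doesn't support instance-level annotations ...
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    scripted = torch.jit.script(graphmod)

print([str(w.message) for w in caught])
assert isinstance(scripted, torch.jit.ScriptModule)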

e3nn/util/jit.py

Lines changed: 53 additions & 15 deletions
@@ -1,6 +1,7 @@
 import copy
 import inspect
 import warnings
+import re
 from typing import Optional

 import torch
@@ -59,6 +60,7 @@ def compile(
     script_options: dict = None,
     trace_options: dict = None,
     in_place: bool = True,
+    recurse: bool = True,
 ):
     """Recursively compile a module and all submodules according to their decorators.

@@ -76,6 +78,10 @@ def compile(
         Extra kwargs for ``torch.jit.script``.
     trace_options : dict, default = {}
         Extra kwargs for ``torch.jit.trace``.
+    in_place : bool, default True
+        Whether to insert the recursively compiled submodules in-place, or do a deepcopy first.
+    recurse : bool, default True
+        Whether to recurse through the module's children before passing the parent to TorchScript

     Returns
     -------
@@ -92,25 +98,57 @@ def compile(
         mod = copy.deepcopy(mod)
     # TODO: debug logging
     assert n_trace_checks >= 1
-    # == recurse to children ==
-    # This allows us to trace compile submodules of modules we are going to script
-    for submod_name, submod in mod.named_children():
-        setattr(
-            mod,
-            submod_name,
-            compile(
-                submod,
-                n_trace_checks=n_trace_checks,
-                script_options=script_options,
-                trace_options=trace_options,
-                in_place=True, # since we deepcopied the module above, we can do inplace
-            ),
-        )
+
+    if recurse:
+        # == recurse to children ==
+        # This allows us to trace compile submodules of modules we are going to script
+        for submod_name, submod in mod.named_children():
+            setattr(
+                mod,
+                submod_name,
+                compile(
+                    submod,
+                    n_trace_checks=n_trace_checks,
+                    script_options=script_options,
+                    trace_options=trace_options,
+                    in_place=True, # since we deepcopied the module above, we can do inplace
+                    recurse=recurse, # always true in this branch
+                ),
+            )
+
     # == Compile this module now ==
     if mode == "script":
         if isinstance(mod, fx.GraphModule):
             mod = jitable(mod)
-        mod = torch.jit.script(mod, **script_options)
+            # In recent PyTorch versions (probably >1.12, definitely >=2.0), PyTorch's implementation of fx.GraphModule
+            # causes a warning to be raised when fx.GraphModules are compiled to TorchScript with `torch.jit.script`:
+            #
+            # torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level
+            # annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the
+            # class body, or 2) wrap the type in `torch.jit.Attribute`.
+            #
+            # Using the debugger traces this back to the following line in PyTorch:
+            # https://github.com/pytorch/pytorch/blob/v2.3.1/torch/fx/graph_module.py#L446
+            # Because the metadata stored by GraphModule is not relevant to the compiled TorchScript module
+            # we need, it should be safe to ignore this warning. The below code suppresses this warning as
+            # narrowly as possible to ensure it can still be raised from user code.
+            # See also: https://github.com/pytorch/pytorch/issues/89064
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore",
+                    # warnings treats this argument as a regex, but we want to match a string literal exactly, so escape it:
+                    message=re.escape(
+                        "The TorchScript type system doesn't support instance-level annotations on empty non-base types "
+                        "in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type "
+                        "in `torch.jit.Attribute`."
+                    ),
+                    # Being specific is good form, even though matching the message should be enough:
+                    category=UserWarning,
+                    module="torch",
+                )
+                mod = torch.jit.script(mod, **script_options)
+        else:
+            mod = torch.jit.script(mod, **script_options)
     elif mode == "trace":
         # These are always modules, so we're always using trace_module
         # We need tracing inputs:
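
The suppression technique added above, shown in isolation as a stdlib-only sketch (the noisy() helper and its warning text are invented here, not the real TorchScript message): filterwarnings treats message as a regular expression matched against the start of the warning text, so re.escape gives an exact literal match, and catch_warnings keeps the filter from leaking outside the block.

import re
import warnings


def noisy():
    # Invented helper standing in for torch.jit.script's warning.
    warnings.warn("metadata (x) isn't supported", UserWarning)
    return 42


with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        # message is a regex matched against the start of the warning text,
        # so escape it to match the literal string exactly:
        message=re.escape("metadata (x) isn't supported"),
        category=UserWarning,
    )
    noisy()  # silenced inside the block

noisy()  # warns normally again out here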

examples/conftest.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

tests/o3/angular_spherical_harmonics_test.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ def test_jit(float_tolerance) -> None:

     a = torch.randn(5, 4)
     b = torch.randn(5, 4)
-    return (sh(a, b) - jited(a, b)).abs().max() < float_tolerance
+    assert (sh(a, b) - jited(a, b)).abs().max() < float_tolerance


 def test_sh_equivariance1(float_tolerance) -> None:

tests/o3/experimental/benchmark_pt2.py

Lines changed: 22 additions & 18 deletions
@@ -1,28 +1,32 @@
-import torch
-from torch._inductor.utils import print_performance
+# flake8: noqa

-# Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
-torch.set_float32_matmul_precision("high")
-import torch._dynamo.config
-import torch._inductor.config

-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.triton.unique_kernel_names = True
+def main():
+    import torch
+    from torch._inductor.utils import print_performance

-device = "cuda"
-compile_mode = "max-autotune" # Bringing out all of the tricks that Torch 2.0 has but "reduce-overhead" should work as well
+    # Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
+    torch.set_float32_matmul_precision("high")
+    import torch._dynamo.config
+    import torch._inductor.config

-from e3nn import o3, util
-import numpy as np
-from torch import nn
-import time
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True

-LMAX = 8
-CHANNEL = 128
-BATCH = 100
+    device = "cuda"
+    compile_mode = (
+        "max-autotune" # Bringing out all of the tricks that Torch 2.0 has but "reduce-overhead" should work as well
+    )

+    from e3nn import o3, util
+    import numpy as np
+    from torch import nn
+    import time
+
+    LMAX = 8
+    CHANNEL = 128
+    BATCH = 100

-def main():
     for lmax in range(1, LMAX + 1):
         irreps = o3.Irreps.spherical_harmonics(lmax)
         irreps_x = (CHANNEL * irreps).regroup()
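
One plausible motivation for this refactor (an assumption; the commit does not state it): with all setup moved inside main(), merely importing or collecting the file no longer configures inductor or touches CUDA. A generic sketch of the deferred-execution pattern, not the benchmark file itself:

def main():
    # Expensive setup and benchmarking would happen here, only when run directly.
    print("running benchmark")


if __name__ == "__main__":
    main()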

0 commit comments