-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: check_jit_resize_uint8.py
43 lines (34 loc) · 1.16 KB
/
check_jit_resize_uint8.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch.utils.benchmark import Timer, Compare
from torchvision.transforms import functional as F_stable
from torchvision.transforms.v2 import functional as F_v2
from itertools import product
from functools import partial
# Factory for random uint8 tensors with values in [0, 255] (the upper bound
# 256 of torch.randint is exclusive). Called as make_arg_int(shape, device=...).
make_arg_int = partial(torch.randint, 0, 256, dtype=torch.uint8)

# Input shapes to check: an unbatched 3x400x400 tensor and a batch of 16 of them.
shapes = (
    (3, 400, 400),
    (16, 3, 400, 400),
)

# Interpolation modes to check.
modes = [
    F_stable.InterpolationMode.NEAREST,
    F_stable.InterpolationMode.BILINEAR,
    # NOTE(review): BICUBIC is disabled; the reason is not stated here —
    # confirm before re-enabling.
    # F_stable.InterpolationMode.BICUBIC,
]

# Tensor factories used to build the inputs.
makers = (make_arg_int,)

# Only include CUDA when a device is actually present, so the check can still
# run (CPU-only) on machines without a GPU instead of failing at allocation.
devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",)

# Names of functions looked up on both the v2 and the stable functional modules.
fns = ["resize"]

# Intra-op thread counts to run the check with.
threads = (1,)
def _check_scripted_matches_eager(module, fn_name, inpt, args, kwargs, context=""):
    """Assert that the TorchScript-compiled ``module.<fn_name>`` matches eager mode.

    Looks up ``fn_name`` on ``module``, compiles it with ``torch.jit.script``,
    calls the scripted and the eager function with the same inputs, and
    compares the results with ``torch.testing.assert_close(atol=1, rtol=0)``.
    The absolute tolerance of 1 admits a one-unit difference per element
    (presumably to absorb rounding differences on the uint8 outputs).

    Args:
        module: namespace exposing the function to check (here one of the
            torchvision ``functional`` modules).
        fn_name: attribute name of the function on ``module``.
        inpt: tensor passed as the first positional argument.
        args: extra positional arguments forwarded to the function.
        kwargs: keyword arguments forwarded to the function.
        context: free-form description of the case, included in the failure
            message.

    Raises:
        AssertionError: if the scripted and eager outputs differ beyond the
            tolerance.
    """
    fn = getattr(module, fn_name)
    scripted_fn = torch.jit.script(fn)
    out = scripted_fn(inpt, *args, **kwargs)
    ref = fn(inpt, *args, **kwargs)
    torch.testing.assert_close(
        ref,
        out,
        atol=1,
        rtol=0,
        msg=lambda m: f"{module.__name__}.{fn_name} scripted vs eager ({context}): {m}",
    )


# Use a distinct loop name so the `threads` tuple being iterated is not shadowed.
for make, shape, device, fn_name, num_threads, mode in product(
    makers, shapes, devices, fns, threads, modes
):
    # Apply the thread-count axis; previously the value was iterated but unused.
    torch.set_num_threads(num_threads)
    t1 = make(shape, device=device)
    args = ([64],)
    kwargs = dict(interpolation=mode, antialias=True)
    context = f"shape={shape}, device={device}, mode={mode}, threads={num_threads}"
    # Same check for both APIs, in the original order: v2 first, then stable.
    for module in (F_v2, F_stable):
        _check_scripted_matches_eager(module, fn_name, t1, args, kwargs, context)