-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcheck_interp_cf.py
148 lines (116 loc) · 5.05 KB
/
check_interp_cf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Check wt_max computation overhead. wt_max is computed in _compute_weights_aa
import PIL.Image
import torch
import torch.utils.benchmark as benchmark
import fire
def pth_downsample_i8(img, mode, size, aa=True):
    """Resize ``img`` with ``torch.nn.functional.interpolate`` in its own dtype.

    The tensor is handed to ``interpolate`` unchanged (no dtype conversion),
    so whatever kernel exists for ``img.dtype`` (e.g. uint8) is exercised.

    Args:
        img: batched tensor passed straight to ``interpolate`` (presumably
            (N, C, H, W) for the 2-D modes used in this script -- confirm
            for other modes).
        mode: interpolation mode forwarded to ``interpolate``.
        size: output spatial size forwarded to ``interpolate``.
        aa: forwarded as ``antialias``.

    Returns:
        The tensor returned by ``interpolate``.
    """
    # ``interpolate`` only accepts an explicit align_corners for the
    # interpolating modes; passing False for e.g. "area" or "nearest-exact"
    # raises ValueError, so leave it as None for every other mode.
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = False
    else:
        align_corners = None
    return torch.nn.functional.interpolate(
        img,
        size=size,
        mode=mode,
        align_corners=align_corners,
        antialias=aa,
    )
def pth_downsample(img, mode, size, aa=True):
    """Resize ``img`` by interpolating in float32, then converting back.

    For integer inputs the float result is rounded to nearest and clamped to
    the dtype's representable range before the cast. A bare ``.to()`` would
    truncate toward zero, and interpolation overshoot (e.g. bicubic ringing
    below 0 or above 255 for uint8) would otherwise be cast out of range.
    Floating-point, complex and bool dtypes are converted back unchanged.

    Args:
        img: batched tensor (presumably (N, C, H, W) for the 2-D modes used
            in this script -- confirm for other modes).
        mode: interpolation mode forwarded to ``interpolate``.
        size: output spatial size forwarded to ``interpolate``.
        aa: forwarded as ``antialias``.

    Returns:
        The interpolated tensor, in ``img.dtype``.
    """
    # ``interpolate`` only accepts an explicit align_corners for the
    # interpolating modes; passing False for e.g. "area" raises ValueError.
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = False
    else:
        align_corners = None
    out = torch.nn.functional.interpolate(
        img.float(),
        size=size,
        mode=mode,
        align_corners=align_corners,
        antialias=aa,
    )
    dtype = img.dtype
    if not (dtype.is_floating_point or dtype.is_complex or dtype == torch.bool):
        info = torch.iinfo(dtype)
        out = out.round().clamp_(info.min, info.max)
    return out.to(dtype)
# Map interpolation-mode names to Pillow resampling filters. Use the
# ``PIL.Image.Resampling`` namespace when this Pillow provides it, otherwise
# fall back to the module-level constants on ``PIL.Image``.
_resampling_ns = getattr(PIL.Image, "Resampling", PIL.Image)
resampling_map = {
    mode_name: getattr(_resampling_ns, mode_name.upper())
    for mode_name in ("bilinear", "nearest", "bicubic")
}
del _resampling_ns
def _make_input(c, size, dtype):
    """Return a random (c, size, size) tensor of ``dtype`` to resize.

    Values: {0, 1} for bool, the full int8 range for int8, and integers in
    [0, 255] otherwise (for both real and imaginary parts of complex64).
    """
    shape = (c, size, size)
    if dtype == torch.bool:
        return torch.randint(0, 2, size=shape, dtype=dtype)
    if dtype == torch.complex64:
        real = torch.randint(0, 256, size=shape, dtype=torch.float32)
        imag = torch.randint(0, 256, size=shape, dtype=torch.float32)
        return torch.complex(real, imag)
    if dtype == torch.int8:
        # randint's upper bound is exclusive, so (-128, 128) covers all of int8.
        return torch.randint(-128, 128, size=shape, dtype=dtype)
    return torch.randint(0, 256, size=shape, dtype=dtype)


def main(min_run_time=10.0, tag="PR"):
    """Time Pillow and torch resizes for the configured cases and print a table.

    Args:
        min_run_time: forwarded to ``blocked_autorange`` for every measurement.
        tag: suffix appended to the torch column description, used to tell
            runs from different builds apart in the printed comparison.
    """
    results = []
    torch.manual_seed(12)
    for mf in ["channels_first", ]:
        memory_format = (
            torch.channels_last if mf == "channels_last" else torch.contiguous_format
        )
        for size in [256, ]:
            for osize, aa, mode in [
                # ((224, 224), True, "bilinear"),
                # Vertical pass
                # ((224, 256), True, "bilinear"),
                # Horizontal pass
                ((256, 224), True, "bilinear"),
                # ((224, 224), False, "bilinear"),
            ]:
                for c, dtype in [
                    (3, torch.uint8),
                    # (4, torch.float32),
                ]:
                    sub_label = f"{c} {dtype} {mf} {mode} {size} -> {osize} aa={aa}"
                    tensor = _make_input(c, size, dtype)

                    # Pillow is only timed for 3-channel uint8 with antialiasing.
                    pil_img = None
                    if dtype == torch.uint8 and c == 3 and aa:
                        # fromarray wants (H, W, C); the tensor is (C, H, W).
                        np_array = tensor.clone().permute(1, 2, 0).contiguous().numpy()
                        pil_img = PIL.Image.fromarray(np_array)

                    tensor = tensor[None, ...].contiguous(memory_format=memory_format)
                    # Untimed call: fails fast on an unsupported mode/dtype
                    # combination before the long timing loops start.
                    pth_downsample_i8(tensor, mode=mode, size=osize, aa=aa)

                    if pil_img is not None:
                        results.append(
                            benchmark.Timer(
                                # PIL's resize takes (width, height) while osize
                                # is (height, width), hence the reversal.
                                stmt=f"data.resize({osize[::-1]}, resample=resample_val)",
                                globals={
                                    "data": pil_img,
                                    "resample_val": resampling_map[mode],
                                },
                                num_threads=torch.get_num_threads(),
                                label="Resize",
                                sub_label=sub_label,
                                description=f"Pillow ({PIL.__version__})",
                            ).blocked_autorange(min_run_time=min_run_time)
                        )

                    # Tensor interp
                    results.append(
                        benchmark.Timer(
                            stmt=f"fn(data, mode='{mode}', size={osize}, aa={aa})",
                            globals={
                                "data": tensor,
                                "fn": pth_downsample_i8,
                            },
                            num_threads=torch.get_num_threads(),
                            label="Resize",
                            sub_label=sub_label,
                            description=f"torch ({torch.__version__}) {tag}",
                        ).blocked_autorange(min_run_time=min_run_time)
                    )

    compare = benchmark.Compare(results)
    compare.print()
if __name__ == "__main__":
    # Pin torch to a single intra-op thread before reporting or benchmarking.
    torch.set_num_threads(1)
    # Environment header: blank line, torch details, blank line.
    print(
        "",
        f"Torch version: {torch.__version__}",
        f"Torch config: {torch.__config__.show()}",
        f"Num threads: {torch.get_num_threads()}",
        "",
        sep="\n",
    )
    print("PIL version: ", PIL.__version__)
    # Expose main()'s keyword arguments as command-line flags.
    fire.Fire(main)