-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path debug_interp6_torch_compile.py
41 lines (25 loc) · 1.08 KB
/
debug_interp6_torch_compile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import PIL.Image
import torch
# Lookup from resampling-mode name to the corresponding PIL filter constant.
# NOTE(review): not referenced by the code shown in this file; presumably kept
# for comparing torch's resize against PIL.Image.resize — confirm or remove.
resampling_map = dict(
    bilinear=PIL.Image.BILINEAR,
    nearest=PIL.Image.NEAREST,
    bicubic=PIL.Image.BICUBIC,
)
def resize(x: torch.Tensor, oh: int, ow: int) -> torch.Tensor:
    """Resize ``x`` spatially to ``(oh, ow)`` using antialiased bilinear filtering.

    ``x`` is handed directly to ``torch.nn.functional.interpolate`` with a
    two-element output size in ``"bilinear"`` mode with ``antialias=True``,
    which requires a 4-D (N, C, H, W) tensor.

    Args:
        x: Input image batch.
        oh: Output height.
        ow: Output width.

    Returns:
        The resized tensor.
    """
    target_size = (oh, ow)
    return torch.nn.functional.interpolate(
        input=x,
        size=target_size,
        mode="bilinear",
        antialias=True,
    )
def main():
    """Run ``torch.compile(resize)`` on a channels_last image and report layouts.

    Builds a deterministic ramp image of ``h * w * c`` float32 values
    (0, 1, 2, ... laid out in HWC order). It is reshaped to (1, h, w, c),
    permuted to (1, c, h, w) and stored in channels_last memory format, then
    resized to ``(oh, ow)`` with the compiled ``resize``. For both the input
    and the output, prints the shape and whether the tensor is
    channels_last-contiguous.
    """
    h, w, c = 256, 256, 3
    oh, ow = 224, 224

    compiled_resize = torch.compile(resize)

    # Build the ramp directly as a float32 tensor. This avoids materializing a
    # ~200k-element Python list first. Every value is < 2**24, so float32
    # represents it exactly and the data matches torch.tensor(list(range(...))).
    t_input = (
        torch.arange(h * w * c, dtype=torch.float32)
        .reshape(1, h, w, c)  # use c rather than a hard-coded channel count
        .permute(0, 3, 1, 2)
        .contiguous(memory_format=torch.channels_last)
    )
    print(t_input.shape, t_input.is_contiguous(memory_format=torch.channels_last))

    t_output = compiled_resize(t_input, oh, ow)
    print(t_output.shape, t_output.is_contiguous(memory_format=torch.channels_last))
if __name__ == "__main__":
    # Restrict torch to a single intra-op thread before any work runs.
    # NOTE(review): presumably so the repro exercises the single-threaded
    # kernel path deterministically — confirm intent.
    torch.set_num_threads(1)

    # Report library versions and build configuration ahead of the repro output.
    print("")
    print(f"Torch version: {torch.__version__}")
    print(f"Torch config: {torch.__config__.show()}")
    print(f"Num threads: {torch.get_num_threads()}")
    print("")
    # Same "label: value" f-string format as the torch lines above (no doubled space).
    print(f"PIL version: {PIL.__version__}")

    main()