-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_benchmark_nhug.py
73 lines (55 loc) · 2.19 KB
/
run_benchmark_nhug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Microbenchmarks for the ``torch.nn.functional.interpolate`` operator.

The string is placed before the imports so Python treats it as the real
module docstring (``__doc__``) rather than a discarded expression statement.
"""
import operator_benchmark as op_bench
import torch
class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.nn.functional.interpolate`` on a random uint8 image.

    The input is always a random uint8 CPU tensor. When ``dtype`` is
    ``torch.float`` the timed ``forward`` also includes the uint8 -> float
    conversion before interpolation and the round/clamp/cast back to uint8
    afterwards, so the float path is measured end-to-end against the path
    that interpolates the uint8 tensor directly.
    """

    def init(
        self,
        input_size,
        output_size,
        channels_last=False,
        mode='linear',
        antialias=False,
        dtype=torch.float,
    ):
        """Create the input tensor and the kwargs passed to ``forward``.

        Args:
            input_size: shape of the random uint8 input, e.g. ``(N, C, H, W)``.
            output_size: target spatial size, passed as ``size=`` to
                ``interpolate``.
            channels_last: if True, store the input in
                ``torch.channels_last`` memory format (only defined for 4-D
                tensors).
            mode: interpolation mode forwarded to ``interpolate``.
                NOTE(review): the default ``'linear'`` only accepts 3-D
                inputs; the configs in this file always pass ``mode``
                explicitly.
            antialias: forwarded to ``interpolate``.
            dtype: ``torch.float`` to interpolate in float32 (with
                conversion to/from uint8); any other value interpolates the
                uint8 tensor directly.
        """
        input_image = torch.randint(0, 256, size=input_size, dtype=torch.uint8, device='cpu')
        if channels_last:
            input_image = input_image.contiguous(memory_format=torch.channels_last)
        # Keys must match the parameter names of ``forward``.
        self.inputs = {
            "input_image": input_image,
            "output_size": output_size,
            "mode": mode,
            "antialias": antialias,
            "dtype": dtype,
        }
        self.set_module_name("interpolate")

    def forward(self, input_image, output_size, mode, antialias, dtype):
        """Resize ``input_image`` to ``output_size`` and return a uint8 tensor."""
        if dtype == torch.float:
            input_image = input_image.float()
        out = torch.nn.functional.interpolate(
            input_image,
            size=output_size,
            mode=mode,
            align_corners=False,
            antialias=antialias,
        )
        if dtype == torch.float:
            # uint8 holds [0, 255]. Clamping to 256 (as before) let a value
            # that rounds to 256 (e.g. bicubic overshoot) reach the uint8
            # cast out of range instead of saturating at 255.
            out = out.round().clamp(min=0, max=255).to(torch.uint8)
        return out
def make_config(
    sizes=(
        ((224, 224), (64, 64)),
        ((270, 268), (224, 224)),
        ((256, 256), (1024, 1024)),
    ),
    num_channels=(3,),
    channels_last=(True,),
    modes=("bilinear",),
    antialias=(True,),
    dtypes=(torch.uint8,),
    tags=("short",),
):
    """Build the operator_benchmark config list for ``InterpolateBenchmark``.

    Each ``(HW1, HW2)`` pair in ``sizes`` is benchmarked in both directions
    (``HW1 -> HW2`` and ``HW2 -> HW1``) with batch size 1, once per entry in
    ``num_channels``. The remaining arguments are crossed with those shapes
    through ``cross_product_configs``.

    The defaults give the short sweep: 3 channels, channels_last only,
    bilinear, antialias on, uint8 only. A wider sweep can be requested
    without editing code, e.g.::

        make_config(num_channels=(3, 1), channels_last=(True, False),
                    modes=("bilinear", "bicubic"), antialias=(True, False),
                    dtypes=(torch.float, torch.uint8))

    Returns:
        The value produced by ``op_bench.config_list``.
    """
    attrs = []
    for hw1, hw2 in sizes:
        # Forward direction first, then the reverse, for each channel count.
        for src_hw, dst_hw in ((hw1, hw2), (hw2, hw1)):
            for channels in num_channels:
                attrs.append([(1, channels, *src_hw), dst_hw])

    return op_bench.config_list(
        attr_names=["input_size", "output_size"],
        attrs=attrs,
        cross_product_configs={
            'channels_last': list(channels_last),
            'mode': list(modes),
            'antialias': list(antialias),
            'dtype': list(dtypes),
        },
        tags=list(tags),
    )
# Build the benchmark configs at import time and hand them, together with
# the benchmark class, to operator_benchmark's PyTorch test generator.
config = make_config()
op_bench.generate_pt_test(config, InterpolateBenchmark)

# When run as a script, delegate to the operator_benchmark runner.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()