@@ -499,9 +499,9 @@ def register_metric(
499
499
):
500
500
def decorator (func ):
501
501
metric_name = func .__name__
502
- if not metric_name in BUILTIN_METRICS :
502
+ if metric_name not in BUILTIN_METRICS :
503
503
operator_name = _find_op_name_from_module_path (func .__module__ )
504
- if not operator_name in REGISTERED_METRICS :
504
+ if operator_name not in REGISTERED_METRICS :
505
505
REGISTERED_METRICS [operator_name ] = []
506
506
REGISTERED_METRICS [operator_name ].append (func .__name__ )
507
507
if skip_baseline :
@@ -1232,8 +1232,6 @@ def service_exists(service_name):
1232
1232
return str (ncu_output_file .resolve ())
1233
1233
1234
1234
def kineto_trace (self , input_id : int , fn : Callable ) -> str :
1235
- from pathlib import Path
1236
-
1237
1235
from torchbenchmark ._components .kineto import do_bench_kineto
1238
1236
1239
1237
kineto_output_dir = self .get_temp_path (f"kineto_traces/{ fn ._name } _{ input_id } " )
@@ -1332,7 +1330,7 @@ def work_func():
1332
1330
return total_flops
1333
1331
1334
1332
fn = self ._get_bm_func (fn_name )
1335
- if not fn in self ._op_flops :
1333
+ if fn not in self ._op_flops :
1336
1334
self ._op_flops [fn ] = _get_flops (self , fn )
1337
1335
op_flops = self ._op_flops [fn ]
1338
1336
return op_flops / metrics .latency / 1e12 * 1e3
0 commit comments