Skip to content

Commit d390f78

Browse files
Re-sync with internal repository
1 parent abaca22 commit d390f78

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

torchbenchmark/util/triton_op.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -499,9 +499,9 @@ def register_metric(
499499
):
500500
def decorator(func):
501501
metric_name = func.__name__
502-
if not metric_name in BUILTIN_METRICS:
502+
if metric_name not in BUILTIN_METRICS:
503503
operator_name = _find_op_name_from_module_path(func.__module__)
504-
if not operator_name in REGISTERED_METRICS:
504+
if operator_name not in REGISTERED_METRICS:
505505
REGISTERED_METRICS[operator_name] = []
506506
REGISTERED_METRICS[operator_name].append(func.__name__)
507507
if skip_baseline:
@@ -1232,8 +1232,6 @@ def service_exists(service_name):
12321232
return str(ncu_output_file.resolve())
12331233

12341234
def kineto_trace(self, input_id: int, fn: Callable) -> str:
1235-
from pathlib import Path
1236-
12371235
from torchbenchmark._components.kineto import do_bench_kineto
12381236

12391237
kineto_output_dir = self.get_temp_path(f"kineto_traces/{fn._name}_{input_id}")
@@ -1332,7 +1330,7 @@ def work_func():
13321330
return total_flops
13331331

13341332
fn = self._get_bm_func(fn_name)
1335-
if not fn in self._op_flops:
1333+
if fn not in self._op_flops:
13361334
self._op_flops[fn] = _get_flops(self, fn)
13371335
op_flops = self._op_flops[fn]
13381336
return op_flops / metrics.latency / 1e12 * 1e3

0 commit comments

Comments
 (0)