diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py index 5df112257e..9dc3372ce1 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py @@ -209,7 +209,7 @@ def extract_params( for bs in batch_size_per_feature_per_rank for b in bs ] - ) + ).float() ) ) )