diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index 9fac86784e..2eedc89dc7 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -53,6 +53,11 @@ def parse_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', dy
         action='store_true',
         help="enable group fusion in Inductor"
     )
+    parser.add_argument(
+        "--torchinductor_enable_batch_fusion",
+        action='store_true',
+        help="enable batch fusion in Inductor"
+    )
     parser.add_argument(
         "--dynamo_disable_optimizer_step",
         type=distutils.util.strtobool,
@@ -86,6 +91,8 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
         # torchinductor.config.triton.use_bmm = True
         if args.torchinductor_enable_group_fusion:
             torchinductor.config.group_fusion = True
+        if args.torchinductor_enable_batch_fusion:
+            torchinductor.config.batch_fusion = True
 
         # used for correctness checks, to avoid triton rand() behaving differently from torch rand().
         torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random)
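
For context, the new --torchinductor_enable_batch_fusion flag follows the same pattern as the existing group-fusion flag: an argparse store_true option that, when set, flips the corresponding Inductor config attribute. Below is a minimal standalone sketch of that flag-to-config mapping; it uses a stand-in object in place of torch._inductor.config so it runs outside the benchmark harness, and is not part of the change itself.

import argparse
from types import SimpleNamespace

# Stand-in for torch._inductor.config; attribute names mirror the ones set in the diff.
inductor_config = SimpleNamespace(group_fusion=False, batch_fusion=False)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--torchinductor_enable_group_fusion",
    action='store_true',
    help="enable group fusion in Inductor"
)
parser.add_argument(
    "--torchinductor_enable_batch_fusion",
    action='store_true',
    help="enable batch fusion in Inductor"
)

# Example invocation: enable only batch fusion.
args = parser.parse_args(["--torchinductor_enable_batch_fusion"])

# Same mapping that the diff adds to apply_torchdynamo_args().
if args.torchinductor_enable_group_fusion:
    inductor_config.group_fusion = True
if args.torchinductor_enable_batch_fusion:
    inductor_config.batch_fusion = True

print(inductor_config.batch_fusion)  # True
print(inductor_config.group_fusion)  # False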