diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py index 7fc9d13bf4..09bd6a1b3c 100644 --- a/torchtitan/experiments/simple_fsdp/backend.py +++ b/torchtitan/experiments/simple_fsdp/backend.py @@ -52,7 +52,7 @@ def get_compile_backend_with_passes( def aot_eager_autobucketing_reordering_pass( gm: torch.fx.GraphModule, example_inputs: Any ) -> torch.fx.GraphModule: - schedule_overlap_bucketing(gm) + schedule_overlap_bucketing(gm, collective_bucketing=True) gm.recompile() return gm @@ -67,7 +67,11 @@ def aot_eager_autobucketing_reordering_pass( def inductor_autobucketing_reordering_pass( gm: torch.fx.Graph, ) -> torch.fx.GraphModule: - return schedule_overlap_bucketing(gm.owning_module) + return schedule_overlap_bucketing( + gm.owning_module, + collective_bucketing=True, + insert_overlap_deps=True, + ) dist_opts.insert_overlap_deps = True torch._inductor.config.reorder_for_peak_memory = False