From 98209f0d21842fcf1ff50e50a1f8225f638e0371 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Wed, 3 Dec 2025 10:11:35 -0800 Subject: [PATCH 1/2] [simple_fsdp] Turn on bucketing by default [ghstack-poisoned] --- torchtitan/experiments/simple_fsdp/backend.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py index 7fc9d13bf4..76897f1914 100644 --- a/torchtitan/experiments/simple_fsdp/backend.py +++ b/torchtitan/experiments/simple_fsdp/backend.py @@ -52,7 +52,9 @@ def get_compile_backend_with_passes( def aot_eager_autobucketing_reordering_pass( gm: torch.fx.GraphModule, example_inputs: Any ) -> torch.fx.GraphModule: - schedule_overlap_bucketing(gm) + schedule_overlap_bucketing( + gm, collective_bucketing=True, insert_overlap_deps=True + ) gm.recompile() return gm @@ -67,7 +69,11 @@ def aot_eager_autobucketing_reordering_pass( def inductor_autobucketing_reordering_pass( gm: torch.fx.Graph, ) -> torch.fx.GraphModule: - return schedule_overlap_bucketing(gm.owning_module) + return schedule_overlap_bucketing( + gm.owning_module, + collective_bucketing=True, + insert_overlap_deps=True, + ) dist_opts.insert_overlap_deps = True torch._inductor.config.reorder_for_peak_memory = False From 56dbbaf384418a10803c279ab3e3ed5acf3d7509 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Wed, 3 Dec 2025 10:47:24 -0800 Subject: [PATCH 2/2] Update on "[simple_fsdp] Turn on bucketing by default" [ghstack-poisoned] --- torchtitan/experiments/simple_fsdp/backend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py index 76897f1914..09bd6a1b3c 100644 --- a/torchtitan/experiments/simple_fsdp/backend.py +++ b/torchtitan/experiments/simple_fsdp/backend.py @@ -52,9 +52,7 @@ def get_compile_backend_with_passes( def aot_eager_autobucketing_reordering_pass( gm: torch.fx.GraphModule, example_inputs: Any ) -> torch.fx.GraphModule: - schedule_overlap_bucketing( - gm, collective_bucketing=True, insert_overlap_deps=True - ) + schedule_overlap_bucketing(gm, collective_bucketing=True) gm.recompile() return gm