|
 from typing import Any, Union
 
 import torch
+from torchtitan.config import JobConfig
 
 
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+def get_compile_backend(job_config: JobConfig) -> Union[str, callable]:
     # Return the compile backend used in SimpleFSDP training.
     # Step 1: check if backend_name is one of the available torch.compile backends.
     # Step 2: check if backend_name has been registered as a customized backend.
+    backend_name = (
+        getattr(job_config.compile, "model_backend_override", None)
+        or job_config.compile.backend
+    )
+
     available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
     if backend_name in available_torch_backend:
         return backend_name
@@ -41,6 +44,32 @@ def aten_autobucketing_reordering_pass(
             bw_compiler=aten_autobucketing_reordering_pass,
             keep_inference_input_mutations=True,
         )
+    elif backend_name == "aot_eager_manualbucketing":
+        # Perform manual optimization at the aten FX level and execute the code
+        # with the aot_eager backend. The manual bucketing logic lives in
+        # torch._inductor.fx_passes.overlap_manual_scheduling.
+        from functools import partial
+
+        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+        from torch._inductor.fx_passes.overlap_manual_scheduling import (
+            manual_overlap_bucketing,
+        )
+
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
+        torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
+        torch._inductor.config.allow_buffer_reuse = False
+        manual_overlap_bucketing = partial(
+            manual_overlap_bucketing,
+            module_bucket_plans=job_config.compile.manual_bucketed_modules,
+        )
+
+        def aten_manualbucketing_reordering_pass(
+            gm: torch.fx.GraphModule, example_inputs: Any
+        ) -> torch.fx.GraphModule:
+            # Bucket and reorder collectives in place on the captured graph.
+            manual_overlap_bucketing(gm)
+            return gm
+
+        backend = aot_autograd_backend(
+            fw_compiler=aten_manualbucketing_reordering_pass,
+            bw_compiler=aten_manualbucketing_reordering_pass,
+            keep_inference_input_mutations=True,
+        )
     else:
         raise AssertionError(f"Unsupported customized backend: {backend_name}")
 
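A quick sanity check of the new call shape, for reviewers: the sketch below builds a stand-in job_config carrying only the three compile fields this diff reads (backend, model_backend_override, manual_bucketed_modules) and hands it to get_compile_backend, which now resolves the backend name internally instead of taking a raw string. SimpleNamespace, the toy model, and the bucket-plan entries are illustrative assumptions, not torchtitan's actual config wiring.

from types import SimpleNamespace

import torch

# Stand-in for torchtitan's JobConfig; only the fields read above are stubbed.
job_config = SimpleNamespace(
    compile=SimpleNamespace(
        backend="aot_eager_manualbucketing",  # routes into the new elif branch
        model_backend_override=None,  # None -> falls back to `backend`
        manual_bucketed_modules=["layers.0", "layers.1"],  # hypothetical bucket plan
    )
)

# Assumes get_compile_backend (the function modified above) is in scope.
model = torch.nn.Linear(8, 8)  # toy module; real use targets an FSDP-wrapped model
compiled = torch.compile(model, backend=get_compile_backend(job_config))

One consequence of the signature change: call sites can no longer pass a plain backend string, so callers that previously passed job_config.compile.backend presumably need updating alongside this diff.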
|
|