|  | 
| 7 | 7 | from typing import Any, Union | 
| 8 | 8 | 
 | 
| 9 | 9 | import torch | 
|  | 10 | +from torchtitan.config import JobConfig | 
| 10 | 11 | 
 | 
| 11 | 12 | 
 | 
| 12 |  | -def get_compile_backend(backend_name: str) -> Union[str, callable]: | 
|  | 13 | +def get_compile_backend(job_config: JobConfig) -> Union[str, callable]: | 
| 13 | 14 |     # return the compile backends used in SimpleFSDP training | 
| 14 | 15 |     # Step1: check if backend_name is inside available torch.compile backends | 
| 15 | 16 |     # Step2: check if the backend_name has been registered as a customized backend | 
|  | 17 | +    backend_name = ( | 
|  | 18 | +        getattr(job_config.compile, "model_backend_override", None) | 
|  | 19 | +        or job_config.compile.backend | 
|  | 20 | +    ) | 
|  | 21 | + | 
| 16 | 22 |     available_torch_backend = torch._dynamo.list_backends(exclude_tags=()) | 
| 17 | 23 |     if backend_name in available_torch_backend: | 
| 18 | 24 |         return backend_name | 
| @@ -41,6 +47,36 @@ def aten_autobucketing_reordering_pass( | 
| 41 | 47 |             bw_compiler=aten_autobucketing_reordering_pass, | 
| 42 | 48 |             keep_inference_input_mutations=True, | 
| 43 | 49 |         ) | 
|  | 50 | +    elif backend_name == "aot_eager_manualbucketing": | 
|  | 51 | +        # Perform manual optimization in aten fx-level and execute code in aot_eager backend | 
|  | 52 | +        # The manual bucketing logic lives in torch._inductor.fx_passes.overlap_manual_scheduling: | 
|  | 53 | +        bucketing_modules = job_config.compile.manual_bucketed_modules | 
|  | 54 | +        from functools import partial | 
|  | 55 | + | 
|  | 56 | +        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend | 
|  | 57 | +        from torch._inductor.fx_passes.overlap_manual_scheduling import ( | 
|  | 58 | +            manual_overlap_bucketing, | 
|  | 59 | +        ) | 
|  | 60 | + | 
|  | 61 | +        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True | 
|  | 62 | +        torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False | 
|  | 63 | +        torch._inductor.config.allow_buffer_reuse = False | 
|  | 64 | +        manual_overlap_bucketing = partial( | 
|  | 65 | +            manual_overlap_bucketing, | 
|  | 66 | +            module_bucket_plans=job_config.compile.manual_bucketed_modules, | 
|  | 67 | +        ) | 
|  | 68 | + | 
|  | 69 | +        def aten_manualbucketing_reordering_pass( | 
|  | 70 | +            gm: torch.fx.GraphModule, example_inputs: Any | 
|  | 71 | +        ) -> torch.fx.GraphModule: | 
|  | 72 | +            manual_overlap_bucketing(gm) | 
|  | 73 | +            return gm | 
|  | 74 | + | 
|  | 75 | +        backend = aot_autograd_backend( | 
|  | 76 | +            fw_compiler=aten_manualbucketing_reordering_pass, | 
|  | 77 | +            bw_compiler=aten_manualbucketing_reordering_pass, | 
|  | 78 | +            keep_inference_input_mutations=True, | 
|  | 79 | +        ) | 
| 44 | 80 |     else: | 
| 45 | 81 |         raise AssertionError(f"Unsupported customized backend: {backend_name}") | 
| 46 | 82 | 
 | 
|  | 
0 commit comments