# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
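    # Apply activation checkpointing first, so the FSDP units created below
    # wrap the already-checkpointed blocks.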
    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

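    # Select the data-parallel mesh dims: HSDP replicates the model across
    # "dp_replicate" and shards it across "dp_shard_cp"; plain FSDP uses only
    # the shard dim.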
    if (
        parallel_dims.dp_shard_enabled or parallel_dims.dp_replicate_enabled
    ):  # apply FSDP or HSDP
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            cpu_offload=job_config.training.enable_cpu_offload,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

    return model

def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    cpu_offload: bool = False,
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        cpu_offload (bool): Whether to offload model parameters to CPU. Defaults to False.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

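    # Shard the small input embedding/projection modules (image, timestep,
    # guidance, vector, and text inputs) as individual FSDP units.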
    linear_layers = [
        model.img_in,
        model.time_in,
        model.guidance_in,
        model.vector_in,
        model.txt_in,
    ]
    for layer in linear_layers:
        fully_shard(layer, **fsdp_config)

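    # Shard each transformer block (double_blocks and single_blocks) as its own
    # FSDP unit, so parameters are all-gathered per block during forward/backward
    # rather than for the whole model at once.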
    for block in model.double_blocks:
        fully_shard(
            block,
            **fsdp_config,
        )

    for block in model.single_blocks:
        fully_shard(
            block,
            **fsdp_config,
        )

    # Apply FSDP to the last layer
    fully_shard(model.final_layer, **fsdp_config)
    # Wrap the rest of the model at the root to cover any remaining parameters
    fully_shard(model, **fsdp_config)


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
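    # Wrap each block with the PT-D checkpoint wrapper: activations inside the
    # block are recomputed during backward instead of being stored.
    # preserve_rng_state=False skips saving/restoring RNG state around the
    # recompute, which is cheaper but assumes the recomputed forward does not
    # need bitwise-identical randomness. register_module() swaps the wrapped
    # block back into its parent container under the same name. Note that full
    # per-block checkpointing is applied for any mode other than "none";
    # ac_config.mode is otherwise only used in the log message below.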
    for layer_id, block in model.double_blocks.named_children():
        block = ptd_checkpoint_wrapper(block, preserve_rng_state=False)
        model.double_blocks.register_module(layer_id, block)

    for layer_id, block in model.single_blocks.named_children():
        block = ptd_checkpoint_wrapper(block, preserve_rng_state=False)
        model.single_blocks.register_module(layer_id, block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def parallelize_encoders(
    t5_model: nn.Module,
    clip_model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
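    # Shard the T5 and CLIP text encoders over the same data-parallel mesh and
    # with the same mixed-precision settings used for the Flux model above.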
    if (
        parallel_dims.dp_shard_enabled or parallel_dims.dp_replicate_enabled
    ):  # apply FSDP or HSDP
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        mp_policy = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        )
        fsdp_config = {
            "mesh": world_mesh[tuple(dp_mesh_dim_names)],
            "mp_policy": mp_policy,
        }
        if job_config.training.enable_cpu_offload:
            fsdp_config["offload_policy"] = CPUOffloadPolicy()
        # Shard each CLIP encoder layer as its own FSDP unit, then wrap the
        # full CLIP model
        for block in clip_model.hf_module.text_model.encoder.layers:
            fully_shard(block, **fsdp_config)
        fully_shard(clip_model, **fsdp_config)

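        # Same treatment for T5: shard each encoder block, then wrap the
        # underlying hf_module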
        for block in t5_model.hf_module.encoder.block:
            fully_shard(block, **fsdp_config)
        fully_shard(t5_model.hf_module, **fsdp_config)

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the T5 and CLIP model")
        else:
            logger.info("Applied FSDP to the T5 and CLIP model")

    return t5_model, clip_model