
Commit ab08612

[Flux] Enable FSDP for flux model training (#1074)
## Context
Enable FSDP for Flux model training.

## Test
An ablation study using the Flux-dev model (with FSDP enabled) on 8 H100 GPUs:

| Test id | Full AC? | Shard T5? | T5 on/off load? | Batch size | Result |
| -- | -- | -- | -- | -- | -- |
| 1 | Yes | No | Yes | 4 * 8 = 32 | ✅ |
| 2 | Yes | No | Yes | 8 * 8 = 64 | ❌ OOM |
| 3 | Yes | Yes | Yes | 4 * 8 = 32 | ✅ GPU memory 78.35GiB (82.48%) |
| 4 | Yes | Yes | No (T5 always on GPU) | 4 * 8 = 32 | ✅ GPU memory 78.41GiB (82.54%). See profiler analysis below. |
| 5 | Yes | Yes | No (T5 always on GPU) | 8 * 8 = 64 | ❌ OOM |
| 6 | Yes | No | No (T5 always on GPU) | 4 * 8 = 32 | ✅ GPU memory 84.98GiB (89.46%). See profiler analysis below. |
| 7 | Yes | No | No (T5 always on GPU) | 8 * 8 = 64 | ❌ OOM |

- T5 encoder on/off loading saves a small amount of GPU memory, but may cost extra time moving the encoder between GPU and CPU. **Thus we don't recommend enabling on/off loading for the T5 model.**
- For end-to-end training, if a user neither shards T5 with FSDP nor uses T5 on/off loading, the maximum batch size is 32.

Profiler observation of test No. 4 **[Recommended]**:
<img width="1738" alt="Screenshot 2025-04-09 at 1 51 10 PM" src="https://github.com/user-attachments/assets/3b836d7b-089d-4c41-9069-755b6c3ee0bf" />

Profiler observation of test No. 6:
<img width="1738" alt="Screenshot 2025-04-09 at 1 50 57 PM" src="https://github.com/user-attachments/assets/1a3fbde5-33b7-444d-878b-74d789ea53d8" />

- From the profiling comparison above, sharding T5 with FSDP did not hurt throughput (no bubbles in computation) and saves GPU memory. **So we recommend enabling FSDP sharding for T5 by default.**
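For reference, the "T5 on/off load" column above means moving the T5 encoder between CPU and GPU around each encoding step. A minimal sketch of the idea, with a hypothetical helper name and no FSDP involved (torchtitan's actual implementation differs):

```python
import torch
import torch.nn as nn


def encode_with_offload(t5_encoder: nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of T5 on/off loading: the encoder lives on CPU
    and visits the GPU only for the forward pass."""
    t5_encoder.to("cuda")  # on-load before encoding
    with torch.no_grad():
        out = t5_encoder(tokens.to("cuda"))
    t5_encoder.to("cpu")  # off-load afterwards to free GPU memory
    return out
```

As the ablation shows, the memory saved this way is small relative to the transfer cost, which is why sharding T5 with FSDP is the recommended default.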
1 parent 981552e commit ab08612

File tree

14 files changed: +339 −66 lines
torchtitan/experiments/flux/README.md

Lines changed: 10 additions & 3 deletions
@@ -1,23 +1,30 @@
 # FLUX model in torchtitan
 
 ## Overview
+This directory contains the implementation of the [FLUX](https://github.com/black-forest-labs/flux/tree/main) model in torchtitan. In torchtitan, we showcase the pre-training process of the text-to-image part of the FLUX model.
 
 ## Usage
 First, download the autoencoder model from HuggingFace with your own access token:
 ```bash
 python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
 ```
+
 This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
 
 Run the following command to train the model on a single GPU:
 ```bash
-PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
+./torchtitan/experiments/flux/run_train.sh
+
 ```
 
+## Supported Features
+- Parallelism: The model supports FSDP and HSDP for training on multiple GPUs.
+- Activation checkpointing: The model uses activation checkpointing to reduce memory usage during training.
+
+
 ## TODO
-- [ ] Supporting for multiple GPUs is comming soon (FSDP, etc)
-- [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
 - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc)
 - [ ] Support for distributed checkpointing and loading
 - [ ] Implement init_weights() function to initialize the model weights
 - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
+- [ ] Implement test cases in CI for the FLUX model, and add more unit tests (e.g., for the preprocessor)

torchtitan/experiments/flux/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -6,6 +6,7 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
+
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
@@ -29,7 +30,7 @@
     in_channels=64,
     out_channels=64,
     vec_in_dim=768,
-    context_in_dim=512,
+    context_in_dim=4096,
     hidden_size=3072,
     mlp_ratio=4.0,
     num_heads=24,
@@ -81,10 +82,10 @@
     in_channels=64,
     out_channels=64,
     vec_in_dim=768,
-    context_in_dim=512,
-    hidden_size=512,
+    context_in_dim=4096,
+    hidden_size=3072,
     mlp_ratio=4.0,
-    num_heads=4,
+    num_heads=24,
     depth=2,
     depth_single_blocks=2,
     axes_dim=(16, 56, 56),
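The `context_in_dim` bump from 512 to 4096 matches the hidden size of the T5-XXL text encoder whose output feeds the model's `txt_in` projection. A hedged sanity check, assuming the `google/t5-v1_1-xxl` checkpoint commonly used with FLUX:

```python
from transformers import AutoConfig

# Assumption: the T5 encoder is google/t5-v1_1-xxl; its d_model (4096)
# must equal context_in_dim so the txt_in Linear accepts its embeddings.
t5_cfg = AutoConfig.from_pretrained("google/t5-v1_1-xxl")
assert t5_cfg.d_model == 4096
```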

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 2 additions & 2 deletions
@@ -56,8 +56,8 @@ def _process_cc12m_image(
 
     assert resized_img.size[0] == resized_img.size[1] == output_size
 
-    # Skip grayscale images
-    if resized_img.mode == "L":
+    # Skip grayscale, RGBA, and CMYK images
+    if resized_img.mode != "RGB":
         return None
 
     np_img = np.array(resized_img).transpose((2, 0, 1))
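The stricter mode check matters because the transpose below assumes a 3-channel HWC array. A small standalone illustration (not the dataset code):

```python
import numpy as np
from PIL import Image

# Grayscale ("L") images produce 2-D arrays, and RGBA/CMYK produce 4
# channels, so only "RGB" survives the (2, 0, 1) HWC -> CHW transpose
# with the expected (3, H, W) shape.
for mode in ("RGB", "L", "RGBA", "CMYK"):
    arr = np.array(Image.new(mode, (8, 8)))
    print(mode, arr.shape)  # RGB -> (8, 8, 3); L -> (8, 8); RGBA/CMYK -> (8, 8, 4)
```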

torchtitan/experiments/flux/flux_argparser.py

Lines changed: 5 additions & 0 deletions
@@ -40,3 +40,8 @@ def extend_parser(parser: argparse.ArgumentParser) -> None:
         default=512,
         help="Maximum length of the T5 encoding.",
     )
+    parser.add_argument(
+        "--encoder.offload_encoder",
+        action="store_true",
+        help="Whether to offload the text encoders to CPU when not in use",
+    )
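A note on the torchtitan-style dotted option name: argparse keeps the dot in the destination attribute, so the value is read back with `getattr` rather than plain attribute access. A standalone sketch (not torchtitan's config code):

```python
import argparse

# Dotted long options land in the Namespace under the literal name
# "encoder.offload_encoder", which is not a valid Python identifier.
parser = argparse.ArgumentParser()
parser.add_argument("--encoder.offload_encoder", action="store_true")
args = parser.parse_args(["--encoder.offload_encoder"])
print(getattr(args, "encoder.offload_encoder"))  # True
```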

torchtitan/experiments/flux/model/layers.py

Lines changed: 6 additions & 5 deletions
@@ -43,11 +43,12 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(
-        -math.log(max_period)
-        * torch.arange(start=0, end=half, dtype=torch.float32)
-        / half
-    ).to(t.device)
+    with torch.device(t.device):
+        freqs = torch.exp(
+            -math.log(max_period)
+            * torch.arange(start=0, end=half, dtype=torch.float32)
+            / half
+        )
 
     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
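The change above swaps a CPU allocation followed by a `.to(t.device)` copy for direct allocation on the target device: inside a `torch.device` context, factory functions such as `torch.arange` create their output on that device. A minimal standalone sketch:

```python
import torch

# Inside a torch.device context, factory calls like torch.arange
# allocate directly on the chosen device, skipping the CPU-then-copy
# round trip of `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(device):
    freqs = torch.arange(8, dtype=torch.float32)
print(freqs.device)  # matches `device`
```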

torchtitan/experiments/flux/model/model.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def __init__(self, model_args: FluxModelArgs):
         super().__init__()
 
         self.model_args = model_args
+
         self.in_channels = model_args.in_channels
         self.out_channels = model_args.out_channels
         if model_args.hidden_size % model_args.num_heads != 0:

torchtitan/experiments/flux/parallelize_flux.py

Lines changed: 136 additions & 5 deletions
@@ -4,16 +4,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# This file applies the PT-D parallelisms (except pipeline parallelism) and various
-# training techniques (e.g. activation checkpointing and compile) to the Llama model.
-
 
+import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper as ptd_checkpoint_wrapper,
+)
 
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+from torchtitan.tools.logging import logger
 
 
 def parallelize_flux(
@@ -22,5 +25,133 @@ def parallelize_flux(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
-    # TODO: Add model parallel strategy here
+    if job_config.activation_checkpoint.mode != "none":
+        apply_ac(model, job_config.activation_checkpoint)
+
+    if (
+        parallel_dims.dp_shard_enabled or parallel_dims.dp_replicate_enabled
+    ):  # apply FSDP or HSDP
+        if parallel_dims.dp_replicate_enabled:
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+        else:
+            dp_mesh_dim_names = ("dp_shard_cp",)
+
+        apply_fsdp(
+            model,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+            cpu_offload=job_config.training.enable_cpu_offload,
+        )
+
+        if parallel_dims.dp_replicate_enabled:
+            logger.info("Applied HSDP to the model")
+        else:
+            logger.info("Applied FSDP to the model")
+
     return model
+
+
+def apply_fsdp(
+    model: nn.Module,
+    dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
+    cpu_offload: bool = False,
+):
+    """
+    Apply data parallelism (via FSDP2) to the model.
+
+    Args:
+        model (nn.Module): The model to apply data parallelism to.
+        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
+        param_dtype (torch.dtype): The data type to use for model parameters.
+        reduce_dtype (torch.dtype): The data type to use for reduction operations.
+        cpu_offload (bool): Whether to offload model parameters to CPU. Defaults to False.
+    """
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+    if cpu_offload:
+        fsdp_config["offload_policy"] = CPUOffloadPolicy()
+
+    linear_layers = [
+        model.img_in,
+        model.time_in,
+        model.guidance_in,
+        model.vector_in,
+        model.txt_in,
+    ]
+    for layer in linear_layers:
+        fully_shard(layer, **fsdp_config)
+
+    for block in model.double_blocks:
+        fully_shard(
+            block,
+            **fsdp_config,
+        )
+
+    for block in model.single_blocks:
+        fully_shard(
+            block,
+            **fsdp_config,
+        )
+    # apply FSDP to the last layer
+    fully_shard(model.final_layer, **fsdp_config)
+    # wrap the rest of the model
+    fully_shard(model, **fsdp_config)
+
+
+def apply_ac(model: nn.Module, ac_config):
+    """Apply activation checkpointing to the model."""
+
+    for layer_id, block in model.double_blocks.named_children():
+        block = ptd_checkpoint_wrapper(block, preserve_rng_state=False)
+        model.double_blocks.register_module(layer_id, block)
+
+    for layer_id, block in model.single_blocks.named_children():
+        block = ptd_checkpoint_wrapper(block, preserve_rng_state=False)
+        model.single_blocks.register_module(layer_id, block)
+
+    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+
+
+def parallelize_encoders(
+    t5_model: nn.Module,
+    clip_model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    if (
+        parallel_dims.dp_shard_enabled or parallel_dims.dp_replicate_enabled
+    ):  # apply FSDP or HSDP
+        if parallel_dims.dp_replicate_enabled:
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+        else:
+            dp_mesh_dim_names = ("dp_shard_cp",)
+
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+        )
+        fsdp_config = {
+            "mesh": world_mesh[tuple(dp_mesh_dim_names)],
+            "mp_policy": mp_policy,
+        }
+        if job_config.training.enable_cpu_offload:
+            fsdp_config["offload_policy"] = CPUOffloadPolicy()
+        # FSDP for encoder blocks
+        for block in clip_model.hf_module.text_model.encoder.layers:
+            fully_shard(block, **fsdp_config)
+        fully_shard(clip_model, **fsdp_config)
+
+        for block in t5_model.hf_module.encoder.block:
+            fully_shard(block, **fsdp_config)
+        fully_shard(t5_model.hf_module, **fsdp_config)
+
+        if parallel_dims.dp_replicate_enabled:
+            logger.info("Applied HSDP to the T5 and CLIP models")
+        else:
+            logger.info("Applied FSDP to the T5 and CLIP models")
+
+    return t5_model, clip_model
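The sharding order in `apply_fsdp` follows the usual FSDP2 pattern: wrap the inner blocks first so each becomes its own parameter/communication group, then wrap the root module to catch the remaining parameters. A toy sketch of the same pattern (assumes `torch.distributed` is already initialized, e.g. under torchrun; not the Flux model code):

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard


def shard_toy_model() -> nn.Module:
    # Toy stand-in for the Flux blocks; requires an initialized default
    # process group (e.g. launched via torchrun), or fully_shard will fail.
    model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])
    for layer in model:
        fully_shard(layer)  # each block gets its own FSDP param group
    fully_shard(model)  # root wrap covers any remaining parameters
    return model
```

Per-block wrapping is what lets FSDP overlap the all-gather of the next block's parameters with the current block's compute.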
Lines changed: 2 additions & 1 deletion
@@ -1,2 +1,3 @@
-transformers
+transformers>=4.51.1
 einops
+sentencepiece
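The `sentencepiece` dependency is needed by the T5 tokenizer in `transformers`; without it, loading the T5 text encoder's tokenizer fails.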
torchtitan/experiments/flux/run_train.sh

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+#!/usr/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -ex
+
+# use envs as local overrides for convenience
+# e.g.
+# LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_train.sh
+NGPU=${NGPU:-"8"}
+export LOG_RANK=${LOG_RANK:-0}
+CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"}
+
+overrides=""
+if [ $# -ne 0 ]; then
+    overrides="$*"
+fi
+
+
+PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+    --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+    -m torchtitan.experiments.flux.train --job.config_file ${CONFIG_FILE} $overrides

torchtitan/experiments/flux/train.py

Lines changed: 17 additions & 7 deletions
@@ -14,6 +14,7 @@
 from torchtitan.experiments.flux.model.autoencoder import load_ae
 from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
 from torchtitan.experiments.flux.model.model import FluxModel
+from torchtitan.experiments.flux.parallelize_flux import parallelize_encoders
 from torchtitan.experiments.flux.utils import (
     create_position_encoding_for_latents,
     pack_latents,
@@ -29,24 +30,36 @@ def __init__(self, job_config: JobConfig):
         super().__init__(job_config)
 
         self.preprocess_fn = preprocess_flux_data
-        # self.dtype = job_config.encoder.dtype
+        # NOTE: self._dtype is the data type used for the encoders (image encoder, T5 text encoder, CLIP text encoder).
+        # We cast the encoders and their inputs/outputs to this dtype.
+        # For the Flux model itself, we use FSDP with mixed precision training.
         self._dtype = torch.bfloat16
         self._seed = job_config.training.seed
        self._guidance = job_config.training.guidance
 
         # load components
         model_config = self.train_spec.config[job_config.model.flavor]
+
         self.autoencoder = load_ae(
             job_config.encoder.auto_encoder_path,
             model_config.autoencoder_params,
-            device="cpu",
+            device=self.device,
             dtype=self._dtype,
         )
         self.clip_encoder = FluxEmbedder(version=job_config.encoder.clip_encoder).to(
-            dtype=self._dtype
+            device=self.device, dtype=self._dtype
         )
         self.t5_encoder = FluxEmbedder(version=job_config.encoder.t5_encoder).to(
-            dtype=self._dtype
+            device=self.device, dtype=self._dtype
+        )
+
+        # Apply FSDP to the T5 and CLIP models
+        self.t5_encoder, self.clip_encoder = parallelize_encoders(
+            t5_model=self.t5_encoder,
+            clip_model=self.clip_encoder,
+            world_mesh=self.world_mesh,
+            parallel_dims=self.parallel_dims,
+            job_config=job_config,
         )
 
     def _predict_noise(
@@ -120,7 +133,6 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
             clip_encoder=self.clip_encoder,
             t5_encoder=self.t5_encoder,
             batch=input_dict,
-            offload=True,
         )
         labels = input_dict["img_encodings"]
 
@@ -148,8 +160,6 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         target = noise - labels
 
         assert len(model_parts) == 1
-        # TODO(jianiw): model_parts will be wrapped by FSDP, which will cacluate
-        model_parts[0] = model_parts[0].to(dtype=self._dtype)
 
         pred = self._predict_noise(
             model_parts[0],
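The deleted manual cast (`model_parts[0].to(dtype=self._dtype)`) is no longer needed because FSDP2's mixed precision policy handles dtype conversion: parameters stay sharded in fp32 and are cast to the compute dtype for each forward/backward. A minimal sketch of that policy (assumes an initialized process group; not the trainer's exact code):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy


def shard_with_bf16_compute(model: nn.Module) -> nn.Module:
    # Params are stored in fp32; FSDP casts them to bf16 for compute and
    # reduces gradients in fp32, so no manual model.to(torch.bfloat16).
    policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    fully_shard(model, mp_policy=policy)
    return model
```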
