
Conversation

@leisuzz (Contributor) commented Dec 18, 2025

What does this PR do?

The text encoder in Flux2 is very large, and offloading it to the CPU makes computing the prompt embeddings very slow.

  1. I added the option to shard the text encoder with FSDP, so the prompt embeddings can be computed efficiently across multiple GPUs.
  2. Checkpointing does not support FSDP yet, so I added handling for the case where accelerate is configured with FSDP (a minimal sketch of the idea is shown after this list).
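
For context on point 2, here is a minimal sketch of the kind of branch this enables when saving. It is illustrative only and not the code in this PR; the placeholder model and file name are assumptions:

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedType

accelerator = Accelerator()

# Placeholder module standing in for the trained model; the real script works
# with the Flux2 transformer / LoRA parameters instead.
model = accelerator.prepare(torch.nn.Linear(8, 8))

if accelerator.distributed_type == DistributedType.FSDP:
    # Under FSDP the parameters are sharded across ranks, so let accelerate
    # gather a full (unsharded) state dict before saving.
    state_dict = accelerator.get_state_dict(model)
else:
    state_dict = accelerator.unwrap_model(model).state_dict()

if accelerator.is_main_process:
    torch.save(state_dict, "checkpoint.pt")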

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

leisuzz force-pushed the fsdp branch 3 times, most recently from 559a7a3 to 343b12a on December 18, 2025, 12:31
@leisuzz (Contributor, Author) commented Dec 18, 2025

@sayakpaul Please take a look at this PR. Thank you for your help :)

@sayakpaul (Member) left a comment


Very cool work, thank you for this!

Just confirming -- this is FSDP2, right?

Also, could you provide an example command and your setup so that we can test?

Additionally, can we similarly wrap the denoiser like this?

Comment on lines 1536 to 1549
original_text_encoder = text_encoding_pipeline.text_encoder
transformer_layer = type(original_text_encoder.model.language_model.layers[0])
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={transformer_layer})

text_encoder_fsdp = FSDP(
    original_text_encoder,
    device_id=accelerator.device,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=args.offload),
    auto_wrap_policy=auto_wrap_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    limit_all_gathers=True,
    use_orig_params=True,
)
Member

Let's try to wrap this into a utility function?
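
For illustration, such a utility could look like the sketch below, based directly on the snippet above; the helper name wrap_text_encoder_with_fsdp and its signature are assumptions, not necessarily what the PR ends up with:

from functools import partial

from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def wrap_text_encoder_with_fsdp(text_encoder, device, offload_params=False):
    # Infer the transformer block class from the text encoder itself, then
    # auto-wrap those blocks and fully shard their parameters across ranks.
    transformer_layer = type(text_encoder.model.language_model.layers[0])
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={transformer_layer})
    return FSDP(
        text_encoder,
        device_id=device,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=offload_params),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        use_orig_params=True,
    )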

Contributor Author

I've modified it, please take a look

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

original_text_encoder = text_encoding_pipeline.text_encoder
transformer_layer = type(original_text_encoder.model.language_model.layers[0])
Member

Should this be configurable?
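
For illustration, one way to make it configurable would be an optional override of the wrapped layer class; this is a sketch, and the layer_cls_path argument (and any CLI flag backing it) is hypothetical:

import importlib


def resolve_wrap_layer_cls(text_encoder, layer_cls_path=None):
    # layer_cls_path would come from a hypothetical CLI flag, given as a fully
    # qualified class name such as "some_module.SomeDecoderLayer".
    if layer_cls_path:
        module_name, _, cls_name = layer_cls_path.rpartition(".")
        return getattr(importlib.import_module(module_name), cls_name)
    # Fall back to inferring the block class from the model, as the PR does.
    return type(text_encoder.model.language_model.layers[0])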

Contributor Author

I've modified it, please take a look

@leisuzz (Contributor, Author) commented Dec 19, 2025

> Very cool work, thank you for this!
>
> Just confirming -- this is FSDP2, right?
>
> Also, could you provide an example command and your setup so that we can test?
>
> Additionally, can we similarly wrap the denoiser like this?

It is FSDP2, and the script is:

accelerate launch --config_file ${config_file} \
  ./train_dreambooth_lora_flux2_img2img.py \
  --pretrained_model_name_or_path=$model_name  \
  --dataset_name=$dataset_name \
  --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
  --resolution=$resolution \
  --train_batch_size=$batch_size \
  --guidance_scale=1 \
  --mixed_precision=$mixed_precision \
  --max_grad_norm=1 \
  --dataloader_num_workers=0 \
  --gradient_accumulation_steps=$gradient_accumulation_steps \
  --learning_rate=1e-05 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --gradient_checkpointing \
  --max_train_steps=$max_train_steps \
  --checkpointing_steps=5000 \
  --enable_npu_flash_attention \
  --rank=16 \
  --seed="0" \
  --skip_final_inference \
  --cache_latents \
  --offload \
  --fsdp_text_encoder \
  --output_dir=${output_path}

The accelerate config is:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock,Flux2SingleTransformerBlock
  fsdp_forward_prefetch: true
  fsdp_sync_module_states: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_use_orig_params: false
  fsdp_activation_checkpointing: true
  fsdp_reshard_after_forward: true
  fsdp_cpu_ram_efficient_loading: false
main_training_function: main
machine_rank: 0
main_process_ip: localhost
main_process_port: 6878
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
