Add FSDP option for Flux2 #12860
base: main
Conversation
Force-pushed from 559a7a3 to 343b12a
@sayakpaul Please take a look at this PR. Thank you for your help :)
sayakpaul left a comment
Very cool work, thank you for this!
Just confirming -- this is FSDP2, right?
Also, could you provide an example command and your setup so that we can test?
Additionally, can we wrap the denoiser in a similar way?
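(For illustration only: wrapping the denoiser analogously might look like the sketch below. The `transformer.transformer_blocks` attribute and the variable names are assumptions, not something confirmed in this thread; `transformer` and `accelerator` are taken to be the training-script objects.)

```python
from functools import partial

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Hypothetical sketch: shard the Flux2 denoiser (transformer) with the same kind of policy.
# `transformer.transformer_blocks` is an assumed attribute name for the model's blocks.
denoiser_layer = type(transformer.transformer_blocks[0])
denoiser_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={denoiser_layer})
transformer_fsdp = FSDP(
    transformer,
    device_id=accelerator.device,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=denoiser_policy,
    use_orig_params=True,
)
```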
```python
# FSDP, ShardingStrategy, CPUOffload and BackwardPrefetch come from torch.distributed.fsdp;
# transformer_auto_wrap_policy from torch.distributed.fsdp.wrap (see the next hunk).
original_text_encoder = text_encoding_pipeline.text_encoder
# Treat each decoder layer of the text encoder as one FSDP unit.
transformer_layer = type(original_text_encoder.model.language_model.layers[0])
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={transformer_layer})

text_encoder_fsdp = FSDP(
    original_text_encoder,
    device_id=accelerator.device,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=args.offload),
    auto_wrap_policy=auto_wrap_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    limit_all_gathers=True,
    use_orig_params=True,
)
```
Let's try to wrap this into a utility function?
I've modified it, please take a look
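For reference, a minimal sketch of what such a helper could look like (the name `wrap_with_fsdp` and its signature are assumptions for illustration, not necessarily what the PR ended up with):

```python
from functools import partial

import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def wrap_with_fsdp(model: nn.Module, layer_cls: type, device, offload_params: bool = False) -> FSDP:
    """Shard `model` with FSDP, treating each `layer_cls` block as one FSDP unit."""
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={layer_cls})
    return FSDP(
        model,
        device_id=device,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=offload_params),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        use_orig_params=True,
    )
```

The call site would then shrink to something like `text_encoder_fsdp = wrap_with_fsdp(original_text_encoder, transformer_layer, accelerator.device, args.offload)`.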
```python
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

original_text_encoder = text_encoding_pipeline.text_encoder
transformer_layer = type(original_text_encoder.model.language_model.layers[0])
```
Should this be configurable?
I've modified it, please take a look
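For reference, making the wrapped layer class configurable could look roughly like this. The flag name `--fsdp_transformer_layer_cls_to_wrap` only mirrors the similarly named Accelerate FSDP option and is an assumption here, not necessarily what the PR implements; `parser`, `args`, and `original_text_encoder` are the training-script objects.

```python
parser.add_argument(
    "--fsdp_transformer_layer_cls_to_wrap",
    type=str,
    default=None,
    help="Name of the text-encoder block class to treat as one FSDP unit; "
    "defaults to the class of the first decoder layer.",
)

# Later, when building the auto-wrap policy:
if args.fsdp_transformer_layer_cls_to_wrap is not None:
    matches = {
        type(m)
        for m in original_text_encoder.modules()
        if type(m).__name__ == args.fsdp_transformer_layer_cls_to_wrap
    }
    if not matches:
        raise ValueError(
            f"No module named {args.fsdp_transformer_layer_cls_to_wrap!r} found in the text encoder."
        )
    transformer_layer = matches.pop()
else:
    # Fall back to the first decoder layer's class, as in the original snippet.
    transformer_layer = type(original_text_encoder.model.language_model.layers[0])
```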
It is FSDP2, and the script is:
The accelerate config is:
What does this PR do?
The text encoder in Flux2 is very large, and offloading it to the CPU makes computing the prompt embeddings slow.
It is FSDP2, and the script is:
The accelerate config is:
Before submitting
Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.