[SimpleFSDP] add manual bucketing pass #1881
Looks nice. Had some comments.
| "1D+aot_eager_autobucketing", | ||
| "1d_aot_eager_autobucketing", | ||
| ), | ||
| # TODO(ruisizhang123): add back after autobucketing pass is mature | 
shall we add a manual bucketing test?
we should also add one in the loss unit test.
I have a few to do items for reordering. I think it'd be better to add the tests after the API is stable?
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" | ||
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """ | ||
|  | ||
| manual_bucketed_modules: list[str] = field(default_factory=list) | 
We need instructions for this field. E.g., it's not obvious what "tok_embeddings,layers.[0-5],norm+output" means. Since it involves a regex I have a guess, but users might not.
Btw, is the list separated by `,`?
The list is separated by `,`, but I didn't do explicit splitting here. Essentially, it's similar to filter_fqns here
Should we add an fsdp_ prefix? Or do we imagine this field will be used for other use cases? If so, what are the use cases?
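To make the field format in this thread concrete, here is a minimal sketch of how a spec such as "tok_embeddings,layers.[0-5],norm+output" could be read. It is not the PR's implementation: `expand_entry` and `parse_bucket_spec` are hypothetical helpers, and two points are assumptions on my side: that `+` groups modules into a single bucket, and that each expanded layer becomes its own bucket. In the PR the config parser already yields a `list[str]`, so the explicit `split(",")` is for illustration only.

```python
import re

# Matches a numeric range abbreviation such as "[0-5]".
_RANGE = re.compile(r"\[(\d+)-(\d+)\]")


def expand_entry(entry: str) -> list[str]:
    """Expand 'layers.[0-2]' into ['layers.0', 'layers.1', 'layers.2']."""
    match = _RANGE.search(entry)
    if match is None:
        return [entry]
    lo, hi = int(match.group(1)), int(match.group(2))
    prefix, suffix = entry[: match.start()], entry[match.end():]
    return [f"{prefix}{i}{suffix}" for i in range(lo, hi + 1)]


def parse_bucket_spec(spec: str) -> list[list[str]]:
    """Turn a comma-separated spec into a list of buckets (lists of FQNs)."""
    buckets: list[list[str]] = []
    for entry in spec.split(","):
        entry = entry.strip()
        if not entry:
            continue
        if "+" in entry:
            # Assumption: 'a+b' means modules a and b share one bucket.
            buckets.append([part.strip() for part in entry.split("+")])
        else:
            # Assumption: every expanded FQN is a bucket of its own.
            buckets.extend([fqn] for fqn in expand_entry(entry))
    return buckets


if __name__ == "__main__":
    for bucket in parse_bucket_spec("tok_embeddings,layers.[0-2],norm+output"):
        print(bucket)
```

A docstring on the field that spells out these rules would answer the question above without users having to read the parsing code.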
    | manual_overlap_bucketing, | ||
| ) | ||
|  | ||
| torch._inductor.config.allow_buffer_reuse = False | 
what happens by default?
In bucketing, we shouldn't allow buffer reuse; otherwise newly created comm copy-in/copy-out buffers will reuse a previous buffer, which corrupts the copied-out data and makes the loss NaN.
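The hazard described above can be illustrated with a toy model in plain Python. This is only a sketch of the aliasing problem, not how Inductor plans memory: `ToyAllocator` and `run` are hypothetical, and the "planner" here simply releases a buffer before a deferred copy-out has read it.

```python
class ToyAllocator:
    """Hands out list-backed buffers and optionally recycles released ones."""

    def __init__(self, allow_reuse: bool) -> None:
        self.allow_reuse = allow_reuse
        self._free: list[list[float]] = []

    def alloc(self, size: int) -> list[float]:
        if self.allow_reuse and self._free:
            # Reuse hands back the very same storage object.
            return self._free.pop()
        return [0.0] * size

    def release(self, buf: list[float]) -> None:
        self._free.append(buf)


def run(allow_reuse: bool) -> list[float]:
    alloc = ToyAllocator(allow_reuse)

    # Result of a bucketed all-gather, waiting to be copied out.
    gathered = alloc.alloc(4)
    gathered[:] = [1.0, 2.0, 3.0, 4.0]

    # A planner unaware of the pending copy-out treats the buffer as dead.
    alloc.release(gathered)

    # The next bucket's copy-in buffer is allocated and filled in place.
    scratch = alloc.alloc(4)
    scratch[:] = [9.0, 9.0, 9.0, 9.0]

    # The deferred copy-out finally reads the gathered data.
    return list(gathered)


if __name__ == "__main__":
    print("reuse disabled:", run(allow_reuse=False))
    print("reuse enabled: ", run(allow_reuse=True))
```

With reuse disabled the copy-out sees the gathered values; with reuse enabled it sees the data of the next bucket, which is the kind of silent corruption that can surface as a NaN loss.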
| class Compile: | ||
| model_backend_override: str | None = None | ||
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" | ||
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """ | 
should make this subclass torchtitan.config.job_config.Compile
It's an additional config extended from job_config.Compile. Not sure what you mean here.
something like class Compile(torchtitan.config.job_config.Compile)
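A minimal sketch of the subclassing suggestion, assuming the base config is a dataclass. `BaseCompile` is a stand-in for `torchtitan.config.job_config.Compile`, and its `enable` field is a placeholder rather than the real class's fields; only the two fields on the subclass come from this PR's diff.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class BaseCompile:
    """Stand-in for torchtitan.config.job_config.Compile (field is a placeholder)."""

    enable: bool = False


@dataclass
class Compile(BaseCompile):
    """SimpleFSDP-specific compile options layered on top of the base config."""

    model_backend_override: str | None = None
    manual_bucketed_modules: list[str] = field(default_factory=list)


if __name__ == "__main__":
    cfg = Compile(manual_bucketed_modules=["tok_embeddings", "layers.[0-5]"])
    # Inherited and extension fields live on the same object.
    print(cfg.enable, cfg.model_backend_override, cfg.manual_bucketed_modules)
```

Inheriting keeps the base options available on the extended config, so code that expects the base `Compile` type continues to work with the SimpleFSDP variant.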
    | elif backend_name == "aot_eager_manualbucketing": | ||
| # Perform manual optimization in aten fx-level and execute code in aot_eager backend | ||
| # The manualbucketing logic is here: | ||
| bucketing_modules = compile_config.manual_bucketed_modules | 
This local variable is not used.
| Manual bucket modules based on user specified FQNs | ||
| Abbreviations are supported to make specifying modules easier. | ||
| Currently, the following abbreviations are available: | ||
| (1) layers.[0-2] -> [layers.0], [layers.1], [layers.2] | 
Right now the user has to know how many layers a particular flavor of the model has when applying manual bucketing. Do you think we can improve the UX by automatically resolving the number of layers?
I even think we shouldn't expose this option in toml. In toml the user should just need to specify bucketing_mode = "none", "transformer_block", or "auto".
And if it's transformer_block, we explicitly iterate over all the transformer blocks and pass the expanded FQNs to manual_overlap_bucketing. That means manual_overlap_bucketing doesn't need to be smart about abbreviations.
Happy to hear people's thoughts.
I mean we could have another "manual" mode supporting Manual bucket modules if people really want to override, but a good default of transformer block level bucketing should be enabled more easily.
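The proposal in this thread could look roughly like the sketch below. `resolve_buckets` is a hypothetical helper, not part of torchtitan. The module names are taken from the example string earlier in the review. Grouping `norm` with `output`, and returning `None` to mean "defer to the auto-bucketing pass", are assumptions.

```python
from __future__ import annotations


def resolve_buckets(mode: str, num_layers: int) -> list[list[str]] | None:
    """Map a bucketing_mode to explicit bucket FQNs, with no abbreviations needed."""
    if mode == "none":
        return []
    if mode == "auto":
        # Assumption: None signals that the auto-bucketing pass decides.
        return None
    if mode == "transformer_block":
        # One bucket per transformer block, resolved from the model itself.
        blocks = [[f"layers.{i}"] for i in range(num_layers)]
        return [["tok_embeddings"], *blocks, ["norm", "output"]]
    raise ValueError(f"unknown bucketing_mode: {mode!r}")


if __name__ == "__main__":
    for bucket in resolve_buckets("transformer_block", num_layers=3) or []:
        print(bucket)
```

In practice `num_layers` would be read from the model (for example the length of its layer container), so the user never has to know the layer count for a given model flavor.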
This PR adds support for aten-level manual bucketing in the SimpleFSDP + aot_eager backend. Dependent on PyTorch PR

TODO List:
manual_bucketed_modules. It would be very easy to miss some of the model's modules. (cc @xmfan @SherlockNoMad)

I'll address the TODO items in follow-up PRs. Let's start with this simple FSDP+TP+llama3 PR.

aot_eager backend) Llama 3-8B
Example SimpleFSDP 1D overlapping trace:
Example SimpleFSDP 2D overlapping trace:

FSDP-only:

FSDP+TP:
