@ruisizhang123 ruisizhang123 commented Oct 15, 2025

This PR adds support for aten-level manual bucketing in SimpleFSDP+aot_eager backend. Dependent on PyTorch PR

TODO List:

  • We should have a better way of handling region info than a list of string FQNs in the current manual_bucketed_modules. It is very easy to miss some of the model's modules. (cc. @xmfan @SherlockNoMad )
  • Currently, the reordering happens under the hood and overlaps with the previous/next compute. We should allow users to specify which modules they want to reorder.
  • Loss difference on multi-node training
  • DSV3 manual bucketing

I'll address the TODO items in follow-up PRs. Let's start with this simple FSDP+TP+llama3 PR.

  1. Performance (FSDP2 under eager mode, SimpleFSDP uses aot_eager backend)

Llama 3-8B

  • Performance (batch size = 1 for all runs). The slightly lower TPS on a single node is expected, since FSDP2 handles copy-in/copy-out in two separate streams, whereas SimpleFSDP handles copy-in/copy-out in the same stream.
| Node | Method | Parallelism | Memory | TPS | Trace |
|---|---|---|---|---|---|
| 1-Node (8H100) | SimpleFSDP | FSDP=8 | 40.96GiB (43.12%) | 7,227 | LINK |
| 1-Node (8H100) | FSDP2-eager | FSDP=8 | 47.82GiB (50.35%) | 7,380 | LINK |
| 8-Node (64H100) | SimpleFSDP | FSDP=64 | 29.37GiB | 4,984 | |
| 8-Node (64H100) | FSDP2 | FSDP=64 | 31.41GiB | 5,097 | |
| 1-Node (8H100) | SimpleFSDP | FSDP=4 TP=2 | 28.28GiB (29.77%) | 5,881 | LINK |
| 1-Node (8H100) | FSDP2 | FSDP=4 TP=2 | 35.33GiB (37.20%) | 5,898 | LINK |
| 8-Node (64H100) | SimpleFSDP | FSDP=8 TP=8 | | | |
| 8-Node (64H100) | FSDP2 | FSDP=8 TP=8 | | | |
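
To make the trade-off in these numbers easier to read, here is a small sketch that computes SimpleFSDP's memory and TPS relative to FSDP2 for the rows that have data. The figures are copied from the results above; the dictionary layout and the function names are illustrative only and are not part of the PR.

```python
# Numbers copied from the results above; rows without data are omitted.
# Each value is (peak memory in GiB, tokens per second).
RESULTS = {
    "1-node FSDP=8": {"SimpleFSDP": (40.96, 7227), "FSDP2": (47.82, 7380)},
    "8-node FSDP=64": {"SimpleFSDP": (29.37, 4984), "FSDP2": (31.41, 5097)},
    "1-node FSDP=4 TP=2": {"SimpleFSDP": (28.28, 5881), "FSDP2": (35.33, 5898)},
}


def relative_change(new: float, base: float) -> float:
    """Percent change of `new` relative to `base`."""
    return (new - base) / base * 100.0


def summarize(results=RESULTS):
    """Return {setup: (memory change %, TPS change %)} for SimpleFSDP vs FSDP2."""
    summary = {}
    for setup, methods in results.items():
        mem_simple, tps_simple = methods["SimpleFSDP"]
        mem_fsdp2, tps_fsdp2 = methods["FSDP2"]
        summary[setup] = (
            relative_change(mem_simple, mem_fsdp2),
            relative_change(tps_simple, tps_fsdp2),
        )
    return summary


if __name__ == "__main__":
    for setup, (mem, tps) in summarize().items():
        print(f"{setup}: memory {mem:+.2f}%, TPS {tps:+.2f}%")
```

On these three configurations SimpleFSDP trails FSDP2 by roughly 0.3% to 2.2% in TPS while using roughly 6% to 20% less memory.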

Example SimpleFSDP 1D overlapping trace:

Screenshot 2025-10-16 at 10 49 55 AM

Example SimpleFSDP 2D overlapping trace:
Screenshot 2025-10-26 at 6 00 51 PM

  • Bitwise Loss:

FSDP-only:
Screenshot 2025-10-17 at 10 41 56 AM

FSDP+TP:
Screenshot 2025-10-26 at 9 03 58 PM

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 15, 2025
@ruisizhang123 ruisizhang123 marked this pull request as draft October 15, 2025 17:41
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 3 times, most recently from c20775e to a5c4027 Compare October 23, 2025 21:27
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 4 times, most recently from 8fa2426 to 71cb39b Compare October 27, 2025 04:51
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from 71cb39b to 27bcc7d Compare October 28, 2025 07:06
@ruisizhang123 ruisizhang123 marked this pull request as ready for review October 28, 2025 07:06
@ruisizhang123 ruisizhang123 changed the title [WIP][SimpleFSDP] add manual bucketing pass [SimpleFSDP] add manual bucketing pass Oct 28, 2025
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from 27bcc7d to 3c46d64 Compare October 28, 2025 17:57
@tianyu-l tianyu-l left a comment

Looks nice. Had some comments.

"1D+aot_eager_autobucketing",
"1d_aot_eager_autobucketing",
),
# TODO(ruisizhang123): add back after autobucketing pass is mature
Contributor

shall we add a manual bucketing test?

we should also add one in the loss unit test.

Contributor Author

I have a few to do items for reordering. I think it'd be better to add the tests after the API is stable?

"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """

manual_bucketed_modules: list[str] = field(default_factory=list)
Contributor

We need instructions for this field. E.g. it's not obvious what "tok_embeddings,layers.[0-5],norm+output" means. As it involves regex I have a guess, but users might not.

btw, is the list separated by ,?

Contributor Author

The list is separated by ,; but I didn't do explicit splitting here. Essentially, it's similar to filter_fqns here

Contributor

Should we add an fsdp_ prefix? Or do we imagine this field will be used for other use cases? If so, what are the use cases?
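
For readers following this thread, here is a minimal sketch of how a value such as "tok_embeddings,layers.[0-5],norm+output" might be interpreted. It is not the PR's implementation. The range expansion follows the docstring in this PR (layers.[0-2] -> [layers.0], [layers.1], [layers.2]); treating "+" as "put these modules in one bucket" is an assumption, and `expand_entry` and `parse_bucket_spec` are hypothetical helper names.

```python
from __future__ import annotations

import re

# Matches a single inclusive numeric range such as "layers.[0-5]".
_RANGE = re.compile(
    r"^(?P<prefix>[^\[\]]*)\[(?P<lo>\d+)-(?P<hi>\d+)\](?P<suffix>[^\[\]]*)$"
)


def expand_entry(entry: str) -> list[list[str]]:
    """Expand one entry into a list of buckets (each bucket is a list of FQNs)."""
    entry = entry.strip()
    if "+" in entry:
        # ASSUMPTION: "+" merges the listed modules into a single bucket.
        # Ranges inside a "+" entry are not expanded in this sketch.
        return [[name.strip() for name in entry.split("+") if name.strip()]]
    match = _RANGE.match(entry)
    if match is None:
        # A plain FQN forms its own bucket.
        return [[entry]]
    lo, hi = int(match["lo"]), int(match["hi"])
    if lo > hi:
        raise ValueError(f"empty range in {entry!r}")
    # One bucket per index, as described in the PR docstring.
    return [[f"{match['prefix']}{i}{match['suffix']}"] for i in range(lo, hi + 1)]


def parse_bucket_spec(spec: str | list[str]) -> list[list[str]]:
    """Accept either a comma-separated string or an already split list."""
    entries = spec.split(",") if isinstance(spec, str) else spec
    buckets: list[list[str]] = []
    for entry in entries:
        if entry.strip():
            buckets.extend(expand_entry(entry))
    return buckets
```

Under these assumptions, `parse_bucket_spec("tok_embeddings,layers.[0-2],norm+output")` yields five buckets: `tok_embeddings`, `layers.0`, `layers.1`, `layers.2`, and a combined `norm` + `output` bucket.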

@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from 3c46d64 to d62eb25 Compare October 30, 2025 04:44
manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False
Member

what happens by default?

Contributor Author

In bucketing, we shouldn't allow buffer reuse; otherwise newly created comm copy-in/copy-out buffers will reuse a previous buffer, which corrupts the copied-out data and makes the loss NaN.
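
The hazard described here can be illustrated with a toy model in plain Python. This is not Inductor's memory planner. `ToyPool` and `run` are hypothetical names, and the premise that a buffer is recycled before a pending copy-out has read it is an assumption made purely for illustration.

```python
class ToyPool:
    """Hands out list buffers and optionally recycles released ones."""

    def __init__(self, allow_reuse: bool):
        self.allow_reuse = allow_reuse
        self._free: list[list[float]] = []

    def alloc(self, size: int) -> list[float]:
        if self.allow_reuse:
            for i, buf in enumerate(self._free):
                if len(buf) == size:
                    # Recycle a released buffer of the same size.
                    return self._free.pop(i)
        return [0.0] * size

    def release(self, buf: list[float]) -> None:
        self._free.append(buf)


def run(allow_reuse: bool) -> list[float]:
    pool = ToyPool(allow_reuse)

    # Bucket A: copy-in fills one communication buffer.
    buf_a = pool.alloc(4)
    buf_a[:] = [1.0, 1.0, 2.0, 2.0]

    # The copy-out that reads buf_a has not run yet; it is modeled as a
    # deferred closure. ASSUMPTION: the buffer is released before it runs.
    pending_copy_out = lambda: list(buf_a)
    pool.release(buf_a)

    # Bucket B: with reuse enabled this is the same storage as buf_a.
    buf_b = pool.alloc(4)
    buf_b[:] = [9.0, 9.0, 9.0, 9.0]

    return pending_copy_out()
```

In this toy, `run(allow_reuse=False)` returns bucket A's original values, while `run(allow_reuse=True)` returns bucket B's values because both buckets share storage. The PR avoids this class of problem by setting `torch._inductor.config.allow_buffer_reuse = False`.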

class Compile:
model_backend_override: str | None = None
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """
Member

should make this subclass torchtitan.config.job_config.Compile

Contributor Author

It's an additional config extended from job_config.Compile. Not sure wdym here.

Member

something like class Compile(torchtitan.config.job_config.Compile)
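
A minimal sketch of what the suggested subclassing could look like. To keep it self-contained, `BaseCompile` is a hypothetical stand-in for torchtitan.config.job_config.Compile and its `enable` field is an assumption; only `model_backend_override` and `manual_bucketed_modules` come from this PR.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class BaseCompile:
    """Hypothetical stand-in for torchtitan.config.job_config.Compile."""

    enable: bool = False  # ASSUMPTION: illustrative base field only


@dataclass
class Compile(BaseCompile):
    """SimpleFSDP-specific options layered on top of the base config."""

    model_backend_override: str | None = None
    """Override backend to compile in simplefsdp."""

    manual_bucketed_modules: list[str] = field(default_factory=list)
    """Module FQNs (or abbreviations) to bucket manually."""


if __name__ == "__main__":
    cfg = Compile(
        enable=True,
        model_backend_override="aot_eager_manualbucketing",
        manual_bucketed_modules=["tok_embeddings", "layers.[0-5]", "norm+output"],
    )
    print(cfg)
```

With inheritance, the extended config keeps every base field and passes `isinstance` checks against the base class, so code that expects the base config keeps working.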

@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from d62eb25 to 1453136 Compare October 30, 2025 05:21
elif backend_name == "aot_eager_manualbucketing":
# Perform manual optimization in aten fx-level and execute code in aot_eager backend
# The manualbucketing logic is here:
bucketing_modules = compile_config.manual_bucketed_modules
Contributor

This local variable is not used.

Manual bucket modules based on user specified FQNs
Abbreviations are supported to make specifying modules easier.
Currently, the following abbreviations are available:
(1) layers.[0-2] -> [layers.0], [layers.1], [layers.2]
Contributor

Right now the user has to know how many layers a particular flavor of model has when applying manual bucketing. Do you think we can improve the UX by automatically resolving the number of layers?

I even think we shouldn't expose this option in toml. In toml, users should just need to specify bucketing_mode = "none", "transformer_block", or "auto".
If it's transformer_block, we explicitly iterate over all the transformer blocks and pass the expanded FQNs to manual_overlap_bucketing. That means manual_overlap_bucketing doesn't need to be smart about abbreviations.

Happy to hear people's thoughts.

Contributor

I mean we could have another "manual" mode supporting Manual bucket modules if people really want to override, but a good default of transformer block level bucketing should be enabled more easily.
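
A rough sketch of the bucketing_mode idea discussed in this thread, under stated assumptions. `resolve_buckets` is a hypothetical helper, not an existing torchtitan function. It takes module FQNs of the kind `model.named_modules()` would yield and assumes transformer blocks live under an attribute named `layers`. Giving each remaining top-level module its own bucket is also an assumption.

```python
from __future__ import annotations

import re


def resolve_buckets(
    mode: str,
    module_fqns: list[str],
    manual: list[list[str]] | None = None,
    layers_attr: str = "layers",
) -> list[list[str]] | None:
    """Map a bucketing mode to explicit buckets of module FQNs.

    Returns None for "auto" to signal that an automatic pass should decide.
    """
    if mode == "none":
        return []
    if mode == "auto":
        return None
    if mode == "manual":
        # User-provided buckets are passed through unchanged.
        return [list(bucket) for bucket in (manual or [])]
    if mode == "transformer_block":
        block = re.compile(rf"^{re.escape(layers_attr)}\.\d+$")
        buckets: list[list[str]] = []
        for fqn in module_fqns:
            is_block = block.match(fqn) is not None
            # ASSUMPTION: other top-level modules each get their own bucket.
            is_top_level = bool(fqn) and "." not in fqn and fqn != layers_attr
            if is_block or is_top_level:
                buckets.append([fqn])
        return buckets
    raise ValueError(f"unknown bucketing_mode: {mode!r}")


if __name__ == "__main__":
    fqns = ["", "tok_embeddings", "layers", "layers.0", "layers.0.attention",
            "layers.1", "layers.1.attention", "norm", "output"]
    print(resolve_buckets("transformer_block", fqns))
```

Because the block list is derived from the model itself, the user does not need to know the layer count, and the bucketing pass only ever sees fully expanded FQNs.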
