Cherry picking keep_fp8_weight_transpose_cache flag refactor and fsdp2 fp8 autocast all gather commits #389
base: release_v2.4_rocm
Conversation
* Initial commit
* Removed rocm_utils
* Added comment and bug fixes
* Grouped IS_HIP_EXTENSION with the property assignment
* Reverted transpose.cpp, removed keep_fp8_transpose_cache flag from grouped_linear, removed manual clearing of tensors in modules
* Aligning grouped_linear module with upstream
* Reverted tests to use _test_granular_accuracy_with_fp8 multiple times as needed
* Added comments back
* Moved comment to the test

Co-authored-by: sudhu2k <[email protected]>
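The first cherry-picked change removes the keep_fp8_weight_transpose_cache handling from grouped_linear, so callers use GroupedLinear exactly as upstream does. Below is a minimal usage sketch, not the PR's test code; the API names follow recent Transformer Engine releases and are assumptions about what is available on this branch.

```python
# Minimal sketch (assumed upstream-style API): GroupedLinear is called without
# any keep_fp8_weight_transpose_cache flag after this refactor.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

num_gemms, in_features, out_features = 4, 256, 256
layer = te.GroupedLinear(num_gemms, in_features, out_features, bias=True)

# Rows of the input are partitioned across the grouped GEMMs;
# m_splits must sum to the leading dimension of the input.
inp = torch.randn(128, in_features, device="cuda")
m_splits = [32, 32, 32, 32]

recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp, m_splits)
out.sum().backward()
```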
* Initial commit
* Removed Print statements, added keep_fp8_transpose cache integration with fsdp2
* Added use_fsdp flag to Linear module, added profile code, added test code, added all reduce for amax
* Fixed unit test
* Removing all reduce code for amax since by default TE does all reduce when torch.distributed is initialized.
* reverting case where out is already present
* Added unit test with regular single-GPU training
* Modified unit test to compare FSDP2 with DDP
* bug fixes
* Code cleaning up
* Initial commit to add MXFP8
* Added fp8 current scaling.
* Added MXFP8, Modified unit test to run based on recipes
* Extended use_fsdp to layernorm linear and layernorm mlp
* Moved amax reduce from forward to backward for fsdp2
* Added automatic detection of use fsdp from base module
* Use SKIP_FP8_REDUCTION_FOR_FSDP2 in backward to check if the forward reduce is needed
* Added memory profile code, added a check before setting SKIP_FP8_REDUCTION_FOR_FSDP2
* Fix for fused optimizer, changed _elem to _data, code clean up
* Fixed layernorm mlp
* Code cleanup and added test to pytorch.sh
* Removed whitespaces
* Fixed comments and license
* Added guards
* Added reduce for forward in cuda graph backward, added code to remove test artifacts, reverted upstream test file

Co-authored-by: sudhu2k <[email protected]>
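The second set of commits wires FP8 autocast and amax reduction into FSDP2 training and adds tests that compare FSDP2 against a DDP baseline across recipes. The sketch below illustrates such a setup; the fully_shard import path, the recipe, the single-layer model, and the launch command are assumptions for illustration and may not match the tests in this PR.

```python
# Illustrative sketch only: training a TE Linear layer under FSDP2 with FP8 autocast.
import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# FSDP2's fully_shard has moved between modules across PyTorch releases.
try:
    from torch.distributed.fsdp import fully_shard              # newer PyTorch
except ImportError:
    from torch.distributed._composable.fsdp import fully_shard  # PyTorch ~2.4

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # A single TE layer stands in for the model; the PR's tests compare FSDP2
    # against DDP across recipes (delayed scaling, current scaling, MXFP8).
    model = te.Linear(1024, 1024, bias=True)
    fully_shard(model)  # shard parameters with FSDP2

    recipe = DelayedScaling(fp8_format=Format.HYBRID)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for _ in range(10):
        inp = torch.randn(16, 1024, device="cuda")
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = model(inp)
        # Per the commits above, the FP8 amax reduction for FSDP2 is triggered
        # from the backward pass rather than the forward pass.
        out.float().sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Such a script would be launched with something like `torchrun --nproc_per_node=2 test_fsdp2_fp8.py` (script name hypothetical).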
@ipanfilo, @wangye805 The PR is ready for review.
AMD copyright is needed
But nothing significant was added from our side; this PR actually removes the code that was added for keep_fp8_weight_transpose_cache, which means technically we are reverting to the upstream code for grouped_linear.py.
Does it match upstream now?
It does, yes.
ipanfilo left a comment
Let's wait for CI to run and pass.
Description
This PR cherry-picks #349 and #328 into the release_v2.4 branch.
Type of change
Checklist: