
[ENHANCEMENT] Add support for Apex RMSNorm for use in qk-norm #1261

Open — wdevazelhes wants to merge 1 commit into main

Conversation

wdevazelhes

This PR allows using qk-layernorm even when normalization: "RMSNorm" is set (which previously threw an error, cf. the message from @SeunghyunSEO in the original PR): indeed, since this commit, only LayerNorm was allowed when using qk-layernorm in GPT, not RMSNorm.

Concretely, this PR:

  • adds FusedRMSNorm in megatron/core/fusions/fused_layer_norm.py, which serves as a wrapper around Apex's FusedRMSNormAffineFunction, the same way the existing FusedLayerNorm in that file wraps Apex's FusedLayerNormAffineFunction. (Note that while FusedLayerNorm can also use the persistent fused layer norm from Apex, apex.contrib.layer_norm.FastLayerNorm, Apex has no persistent RMSNorm kernel, so we simply use the non-persistent one. A sketch of the wrapper follows this list.)
  • adds a wrapper similar to TENorm but for Apex, which we call ApexNorm (TENorm resolves to either te.pytorch.LayerNorm or te.pytorch.RMSNorm depending on whether normalization: 'LayerNorm' or normalization: 'RMSNorm' is set in the config). For that we add an ApexFusedNorm that resolves to either FusedLayerNorm in megatron/core/fusions/fused_layer_norm.py or the FusedRMSNorm just added above.
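For illustration, here is a minimal sketch of what the FusedRMSNorm wrapper looks like. It assumes Apex's FusedRMSNormAffineFunction with the argument order (input, weight, normalized_shape, eps); the actual PR mirrors the structure of the existing FusedLayerNorm wrapper, so constructor arguments and config plumbing may differ.

```python
import torch
from torch.nn import Parameter, init

try:
    # Non-persistent fused RMSNorm kernel from Apex (no persist variant exists).
    from apex.normalization.fused_layer_norm import FusedRMSNormAffineFunction
    HAVE_FUSED_RMS_NORM = True
except ImportError:
    HAVE_FUSED_RMS_NORM = False


class FusedRMSNorm(torch.nn.Module):
    """RMSNorm with a learned scale (gamma) only, dispatching to Apex's fused kernel."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        if not HAVE_FUSED_RMS_NORM:
            raise ValueError("Apex must be installed to use FusedRMSNorm.")
        if isinstance(hidden_size, int):
            hidden_size = (hidden_size,)
        self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.weight = Parameter(torch.empty(*self.hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, input):
        # RMSNorm has no bias, so only input, weight, shape, and eps are passed.
        return FusedRMSNormAffineFunction.apply(
            input, self.weight, self.hidden_size, self.eps
        )
```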

Advantage: this way, if we specify LayerNorm or RMSNorm for --normalization, the same normalization type is used for qk-normalization; it first tries the Apex implementation and falls back to the pure-PyTorch one if Apex is not installed (we don't use the TE one, as it seems to be unstable, per the comments in the original code and as was enforced by this commit). A sketch of this selection logic follows.
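For concreteness, a minimal sketch of this dispatch-and-fallback logic, in the spirit of how Megatron picks a layer-norm implementation in its layer specs. Class and module names here (ApexNorm, WrappedTorchNorm, the fusions module path) follow the PR description and Megatron's conventions, but treat the exact signatures and import paths as assumptions rather than the literal code in the PR.

```python
try:
    import apex  # noqa: F401  # only used to detect whether Apex is available

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm, FusedRMSNorm

    class ApexNorm:
        """Mirror of TENorm: return the Apex-fused norm matching config.normalization."""

        def __new__(cls, config, hidden_size, eps=1e-5):
            if config.normalization == "LayerNorm":
                return FusedLayerNorm(config=config, hidden_size=hidden_size, eps=eps)
            elif config.normalization == "RMSNorm":
                return FusedRMSNorm(hidden_size=hidden_size, eps=eps)
            raise Exception("Only LayerNorm and RMSNorm are currently supported")

    LNImpl = ApexNorm
except ImportError:
    # Pure-PyTorch fallback when Apex is not installed.
    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    LNImpl = WrappedTorchNorm
```

With this, q_layernorm and k_layernorm can simply be set to LNImpl and they honor whatever --normalization selects.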

Note: for the implementation of FusedRMSNorm, I copy/pasted the code from FusedLayerNorm and changed it to perform an RMSNorm instead.
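For reference, the functional difference the copy/paste has to capture (unfused, pure-PyTorch reference versions; the PR itself calls the fused Apex kernels):

```python
import torch

def layer_norm_ref(x, weight, bias, eps=1e-5):
    # LayerNorm: subtract the mean, divide by the standard deviation,
    # then apply an affine transform with weight and bias.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

def rms_norm_ref(x, weight, eps=1e-5):
    # RMSNorm: no mean subtraction and no bias; scale by the
    # root-mean-square of the activations and the learned weight.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight
```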

Tagging @SeunghyunSEO and @ftgreat, as you could be interested in this PR given your earlier PR. Tagging @jaredcasper and @jon-barker as well; the authors of this commit (Mike Chrzanowski and Shanmugam Ramasamy) may also be interested.

wdevazelhes marked this pull request as draft on October 28, 2024 11:01
wdevazelhes marked this pull request as ready for review on October 28, 2024 11:40
k_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
# for QKLayerNorm; we instead use the Apex implementation (or pytorch
# one if Apex is not installed).
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
wdevazelhes (Author)
I saw that in a few other places in the code (here, here, and here), TENorm is still used for qk-normalization (even though, according to the comment above and this commit, using TENorm for qk-layernorm is unstable).
Let me know if I should also modify these other places 👍

@SeunghyunSEO commented on Oct 31, 2024

In my case, I made exactly the same patch in my own Megatron fork, so the changes in this PR look good to me. But I think we should clarify why this is happening: like you said, some people still use TENorm for qk-norm and their models converge.

wdevazelhes (Author)

@SeunghyunSEO thanks!
Regarding the clarification on why this is happening, do you mean we should check why the TE implementation diverges? (I didn't try it myself; I just assumed it does, based on your PR and on the comment in this commit.)

@SeunghyunSEO commented on Nov 6, 2024

> @SeunghyunSEO thanks!
> Regarding the clarification on why this is happening, do you mean we should check why the TE implementation diverges? (I didn't try it myself; I just assumed it does, based on your PR and on the comment in this commit.)

I mean that when an additional feature is added, we should at least know whether it is necessary or not. Do any Megatron or TE maintainers know why TENorm for qk-norm sometimes diverges? I'm cc'ing @deepakn94 because he is the only maintainer I communicate with (sorry for the possibly wrong tag; please forward this to an expert on numerical precision issues).

wdevazelhes (Author)

I see, that makes sense, I agree 👍 Thanks for tagging @deepakn94 🙏
Also, Mike Chrzanowski and Shanmugam Ramasamy could be tagged if NVIDIA folks know how to reach them (I couldn't find their GitHub handles), since they authored the commit that prevented the use of TENorm, and Mike Chrzanowski also wrote a paper using qk-layernorm 👍
