perf: Reduce memory footprint for ChunkedDistributedLogProb #1895
Conversation
📝 Walkthrough: Modified the backward gradient computation in `nemo_rl/distributed/model_utils.py` to preallocate the gradient tensor and copy each chunk's gradient into it in place, instead of accumulating per-chunk gradients in a list and concatenating them.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
🚥 Pre-merge checks: ❌ Failed checks (2 warnings) | ✅ Passed checks (2 passed)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `nemo_rl/distributed/model_utils.py`:
- Line 223: The preallocated `grad_input` created with `torch.empty_like(vocab_parallel_logits)` inherits the bf16/fp16 dtype of the logits and therefore silently truncates the float32 gradient arithmetic. Allocate `grad_input` explicitly as float32 (same shape and device, but `dtype=torch.float32`) so that the subsequent operations (`is_chosen.float().sub_(softmax_output)`, the `copy_` into `grad_input`, and the `mul_` call) run in float32 and preserve gradient precision. A hedged sketch of this change follows below.
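A minimal sketch of what the suggested change could look like; the tensor names follow the comment above, but the shapes, the mask construction, and the final scaling are hypothetical placeholders rather than the actual `model_utils.py` code:

```python
import torch

# Rough sketch of the suggested fix; shapes, mask, and scaling are placeholders.
vocab_parallel_logits = torch.randn(2, 8, 16, dtype=torch.bfloat16)
softmax_output = torch.softmax(vocab_parallel_logits.float(), dim=-1)
is_chosen = torch.zeros(vocab_parallel_logits.shape, dtype=torch.bool)
is_chosen[..., 0] = True  # placeholder for the chosen-token mask

# Allocate the gradient buffer explicitly in float32 instead of letting
# torch.empty_like(vocab_parallel_logits) inherit bf16/fp16.
grad_input = torch.empty(
    vocab_parallel_logits.shape,
    dtype=torch.float32,
    device=vocab_parallel_logits.device,
)

# The gradient arithmetic now stays in float32.
grad_input.copy_(is_chosen.float().sub_(softmax_output))
grad_input.mul_(1.0)  # placeholder for the upstream grad_output scaling
```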
🧹 Nitpick comments (1)
nemo_rl/distributed/model_utils.py (1)
209-260: Consider applying the same preallocation pattern to `ChunkedDistributedGatherLogprob.backward` and `ChunkedDistributedEntropy.backward`. Both sibling backward methods (lines 334–386 and 1041–1060) still use the list-accumulate + `torch.cat` pattern; they'd benefit from the same preallocation optimization for consistency and memory reduction (see the sketch of that pattern below).
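For context, a simplified sketch of the list-accumulate + `torch.cat` pattern the nitpick refers to; the per-chunk math and function name are illustrative only, not the actual NeMo-RL code:

```python
import torch

def backward_with_list_and_cat(logit_chunks):
    """Illustrative 'before' pattern: per-chunk gradients are collected in a
    Python list and concatenated at the end, so the list and the concatenated
    result coexist at peak memory."""
    grad_chunks = []
    for chunk in logit_chunks:
        grad_chunks.append(torch.softmax(chunk.float(), dim=-1))  # stand-in math
    return torch.cat(grad_chunks, dim=0)  # second full-size copy created here

chunks = list(torch.randn(4, 8, 16).chunk(2, dim=0))
grad = backward_with_list_and_cat(chunks)  # shape: (4, 8, 16)
```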
Should we implement this to …
What does this PR do?
Reduces the peak memory footprint when chunking is used in the loss function (when sequence packing is disabled).
Issues
Improve loss function memory usage
Details
Problem
The current approach stores each chunk's gradient in a list and calls `torch.cat` at the end to return the full gradient tensor. As a result, at least two copies of the gradient tensor exist at peak.
Modification
Preallocate the gradient tensor and copy each chunk's gradient into it in place.
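A minimal sketch of that preallocate-and-copy idea, assuming a flat token dimension and stand-in per-chunk math; this is not the actual NeMo-RL implementation:

```python
import torch

def backward_preallocated(logit_chunks, total_tokens):
    """Sketch of the preallocate-and-copy approach described above."""
    grad_input = torch.empty(
        total_tokens, *logit_chunks[0].shape[1:],
        dtype=torch.float32, device=logit_chunks[0].device,
    )
    offset = 0
    for chunk in logit_chunks:
        chunk_grad = torch.softmax(chunk.float(), dim=-1)  # stand-in math
        # Write this chunk's gradient into its slice of the preallocated
        # buffer instead of appending to a list for a final torch.cat.
        grad_input[offset : offset + chunk.shape[0]].copy_(chunk_grad)
        offset += chunk.shape[0]
    return grad_input

chunks = list(torch.randn(4, 8, 16).chunk(2, dim=0))
grad = backward_preallocated(chunks, total_tokens=4)  # shape: (4, 8, 16)
```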
Additional Gains
With the current approach, reducing the chunk size yields less reduction in peak memory than expected because some intermediate tensors overlap in time (lazy deletion).
With explicit deallocation of the intermediate tensors (using `del`), the memory footprint reduction is more significant (-0.4 GiB to -0.6 GiB).
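As an illustration of the explicit-deallocation point, the chunk loop from the sketch above could `del` the intermediate as soon as it has been copied; again illustrative names, not the actual code:

```python
import torch

# Variant of the chunk loop with explicit deallocation of the intermediate.
chunks = list(torch.randn(4, 8, 16).chunk(2, dim=0))
grad_input = torch.empty(4, 8, 16, dtype=torch.float32)
offset = 0
for chunk in chunks:
    chunk_grad = torch.softmax(chunk.float(), dim=-1)  # stand-in math
    grad_input[offset : offset + chunk.shape[0]].copy_(chunk_grad)
    del chunk_grad  # free the intermediate before the next chunk is processed
    offset += chunk.shape[0]
```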
Caveats
The modification does not reduce peak memory when sequence packing is enabled. This is because `SequencePackingLossWrapper` is used and the default `torch.autograd` handles the backprop, which results in this undesirable behavior. This issue should be solvable once there is a custom `torch.autograd.Function` that can handle sequence packing.