
Conversation


@nujoug nujoug commented Feb 6, 2026

What does this PR do?

Reduces the peak memory footprint when using chunked computation in the loss function (when sequence packing is disabled).

Issues

Improve loss function memory usage

Details

Problem

The current approach stores each chunk's gradient in a list and calls torch.cat at the end to return the full gradient tensor. As a result, at least two copies of the gradient tensor exist at the memory peak; a minimal sketch of this pattern is shown below.
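For illustration only, a minimal sketch of the list-accumulate-and-concatenate pattern (not the project's actual code); `compute_chunk_grad` is a hypothetical stand-in for the per-chunk gradient computation:

```python
import torch

def backward_list_and_cat(chunks, compute_chunk_grad):
    # Accumulate every chunk's gradient before building the final tensor.
    grad_chunks = [compute_chunk_grad(chunk) for chunk in chunks]
    # torch.cat allocates a brand-new full-size tensor while all per-chunk
    # gradients are still alive in the list, so roughly two full copies of
    # the gradient coexist at the peak.
    return torch.cat(grad_chunks, dim=0)
```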

[Screenshot: memory profile of the current approach]

Modification

Preallocate the full gradient tensor and copy each chunk's gradient into it in place; a sketch of this pattern is shown below.
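A minimal sketch of the preallocate-and-copy pattern (again not the project's actual code; `compute_chunk_grad` is a hypothetical stand-in):

```python
import torch

def backward_preallocated(logits, chunk_size, compute_chunk_grad):
    # One full-size buffer is allocated up front.
    grad_input = torch.empty_like(logits)
    for start in range(0, logits.shape[0], chunk_size):
        end = min(start + chunk_size, logits.shape[0])
        # copy_ writes directly into the preallocated slice, so at any point
        # only one chunk-sized temporary exists on top of the full buffer.
        grad_input[start:end].copy_(compute_chunk_grad(logits[start:end]))
    return grad_input
```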

[Screenshot: memory profile with the preallocated gradient tensor]

Additional Gains

With the current approach, reducing the chunk size yields less peak-memory reduction than expected, because the lifetimes of some intermediate tensors overlap (they are only freed lazily).

[Screenshots: memory profiles with a smaller chunk size, current approach]

With explicit deallocation of the intermediate tensors (using del), the memory footprint reduction is more significant (-0.4 GiB to -0.6 GiB); a sketch is shown below.
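A sketch of the per-chunk loop with explicit deletion; the tensor names (softmax_output, is_chosen) mirror those mentioned in the review comment further down, but the exact arithmetic here is simplified for illustration:

```python
import torch

def backward_with_explicit_del(logits, is_chosen, chunk_size):
    grad_input = torch.empty_like(logits)
    for start in range(0, logits.shape[0], chunk_size):
        end = min(start + chunk_size, logits.shape[0])
        softmax_output = torch.softmax(logits[start:end].float(), dim=-1)
        grad_chunk = is_chosen[start:end].float().sub_(softmax_output)
        grad_input[start:end].copy_(grad_chunk)
        # Drop the references now so the caching allocator can reuse these
        # blocks for the next chunk, instead of keeping them alive until the
        # names are rebound on the next iteration.
        del softmax_output, grad_chunk
    return grad_input
```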

[Screenshots: memory profiles with a smaller chunk size and explicit del]

Caveats

The modification does not reduce peak memory when sequence packing is enabled. In that case SequencePackingLossWrapper is used and the default torch.autograd machinery handles the backward pass, which leads to this undesirable behavior.

[Screenshot: memory profile with sequence packing enabled]

This should be solvable once there is a custom torch.autograd.Function that handles sequence packing; a hypothetical skeleton is sketched below.
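Purely as a hypothetical sketch of that direction (the class name, signature, and math are illustrative and not an existing API in this repo), a custom autograd.Function could preallocate and fill the gradient buffer itself; the packed-sequence bookkeeping (e.g., cu_seqlens handling and chunking) is elided here:

```python
import torch

class PackedChunkedLogprob(torch.autograd.Function):
    @staticmethod
    def forward(ctx, packed_logits, packed_targets):
        log_probs = torch.log_softmax(packed_logits.float(), dim=-1)
        token_logprobs = log_probs.gather(-1, packed_targets.unsqueeze(-1)).squeeze(-1)
        ctx.save_for_backward(packed_logits, packed_targets)
        return token_logprobs

    @staticmethod
    def backward(ctx, grad_output):
        packed_logits, packed_targets = ctx.saved_tensors
        # Preallocate the full gradient and write into it, rather than letting
        # default autograd accumulate per-op buffers.
        grad_input = torch.empty_like(packed_logits)
        softmax_output = torch.softmax(packed_logits.float(), dim=-1)
        one_hot = torch.zeros_like(softmax_output).scatter_(
            -1, packed_targets.unsqueeze(-1), 1.0
        )
        # d(log p_t)/d(logit_j) = 1[j == t] - p_j, scaled by the incoming grad.
        grad_input.copy_((one_hot - softmax_output) * grad_output.unsqueeze(-1))
        del softmax_output, one_hot
        return grad_input, None
```

Usage in this sketch would be `PackedChunkedLogprob.apply(packed_logits, packed_targets)`.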

Summary by CodeRabbit

  • Performance Improvements
    • Optimized gradient computation during the backward pass to reduce memory usage and improve efficiency through better memory allocation and cleanup strategies.

@nujoug nujoug self-assigned this Feb 6, 2026
@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from aaa5b0b to 3a6cd1a Compare February 6, 2026 18:31
@nujoug nujoug changed the title perf: Reduce memory footprint for ChunkedDistribuedLogProb Draft: perf: Reduce memory footprint for ChunkedDistribuedLogProb Feb 6, 2026
@nujoug nujoug linked an issue Feb 6, 2026 that may be closed by this pull request
@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from 3a6cd1a to d1ce6b9 Compare February 6, 2026 19:47
@nujoug nujoug changed the title Draft: perf: Reduce memory footprint for ChunkedDistribuedLogProb perf: Reduce memory footprint for ChunkedDistribuedLogProb Feb 6, 2026
@nujoug nujoug marked this pull request as ready for review February 6, 2026 19:49
@nujoug nujoug requested a review from a team as a code owner February 6, 2026 19:49
Signed-off-by: mloh <mloh@nvidia.com>
@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from d1ce6b9 to f10e4e4 Compare February 6, 2026 19:51
@coderabbitai
Contributor

coderabbitai bot commented Feb 6, 2026

📝 Walkthrough

Modified the backward gradient computation in ChunkedDistributedLogprob to use preallocated gradient tensors with in-place chunk copying instead of accumulating gradients in a list and concatenating them. Added explicit in-loop deallocation of temporary tensors to optimize memory usage.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Backward Pass Optimization: nemo_rl/distributed/model_utils.py | Refactored ChunkedDistributedLogprob.backward to preallocate the grad_input tensor and copy chunk results into the corresponding slices, replacing list accumulation and final concatenation. Added in-loop deletion of temporary tensors (softmax_output, is_chosen, logits) for explicit memory deallocation. Updated the single-chunk path to use preallocated chunk views with copy_ operations. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Test Results For Major Changes | ⚠️ Warning | The PR lacks test results for numerical correctness and convergence despite demonstrating memory improvements. A dtype precision concern exists where bf16/fp16 preallocation conflicts with float32 gradient computations. | Add unit tests verifying gradient equivalence and small-scale convergence comparisons, and address the dtype precision issue by either adding dtype=torch.float32 to the preallocation or justifying the dtype change with test results. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title directly describes the main optimization: reducing the memory footprint of ChunkedDistributedLogProb, which is the primary objective of this performance-focused PR. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |




Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@nemo_rl/distributed/model_utils.py`:
- Line 223: The preallocated grad_input created with torch.empty_like(vocab_parallel_logits) inherits bf16/fp16 and therefore silently truncates the float32 gradient arithmetic. Allocate grad_input explicitly as float32 (e.g., an empty tensor with the same shape and device but dtype=torch.float32) so that the subsequent operations (is_chosen.float().sub_(softmax_output), the copy_ into grad_input, and the mul_ call) run in float32 and preserve gradient precision. Update the grad_input creation site (near the vocab_parallel_logits usage) to allocate float32 while keeping the shape and device matched; a sketch of this suggestion follows.
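A minimal sketch of the suggested allocation change (the exact creation site in model_utils.py may differ):

```python
import torch

def preallocate_grad_fp32(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
    # Same shape and device as the logits, but forced to float32 so the
    # in-place gradient arithmetic is not truncated to bf16/fp16.
    return torch.empty(
        vocab_parallel_logits.shape,
        dtype=torch.float32,
        device=vocab_parallel_logits.device,
    )

# Illustrative usage: bf16 logits, float32 gradient buffer.
logits = torch.randn(2, 8, 128, dtype=torch.bfloat16)
grad_input = preallocate_grad_fp32(logits)
assert grad_input.dtype == torch.float32
```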
🧹 Nitpick comments (1)
nemo_rl/distributed/model_utils.py (1)

209-260: Consider applying the same preallocation pattern to ChunkedDistributedGatherLogprob.backward and ChunkedDistributedEntropy.backward.

Both sibling backward methods (lines 334–386 and 1041–1060) still use the list-accumulate + torch.cat pattern. They'd benefit from the same preallocation optimization for consistency and memory reduction.

Signed-off-by: mloh <mloh@nvidia.com>
@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from 7725656 to ac29781 Compare February 6, 2026 21:21
@nujoug
Author

nujoug commented Feb 6, 2026

> Nitpick comments (1)
>
> nemo_rl/distributed/model_utils.py (1)
>
> 209-260: Consider applying the same preallocation pattern to ChunkedDistributedGatherLogprob.backward and ChunkedDistributedEntropy.backward.
>
> Both sibling backward methods (lines 334–386 and 1041–1060) still use the list-accumulate + torch.cat pattern. They'd benefit from the same preallocation optimization for consistency and memory reduction.

Should we also apply this to ChunkedDistributedGatherLogprob.backward and ChunkedDistributedEntropy.backward?

