Add MIMO forward step and per-token loss for the stock schedule #5376
Draft
yashaswikarnati wants to merge 1 commit into
Add the role-agnostic forward_step_func consumed by the stock MCore pipeline schedule (forward_backward_pipelining_without_interleaving) for MIMO training. It pulls a batch, moves it (including PackedSeqParams tensor fields) to CUDA, calls model(**batch), and returns the per-token-loss closure.

num_tokens is returned as an integer tensor to match the stock schedule's int total_num_tokens accumulation (a float dtype raises a cast error on last-PP-stage ranks and deadlocks the collective).

This is the non-colocated path. The optional mark_modality_participation hook (from MM3, in-flight NVIDIA#5286) is imported under try/except so this module stands alone. The unit test travels with it.

Validated in the 8-GPU 20L Nemotron VLM e2e (trains + checkpoint save/resume, lm loss 12.18->11.54 across resume).

Signed-off-by: ykarnati <ykarnati@nvidia.com>
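The flow described above (pull a batch, move it to the device, call the model, return a loss closure) can be sketched in plain Python. This is only an illustrative sketch, not the code in this PR: `FakeTensor`, `PackedParams`, `move_to_device`, and `forward_step_sketch` are hypothetical stand-ins so the example runs without PyTorch or a GPU, and leaving `loss_mask` inside the batch that is passed to the model is an assumption.

```python
from dataclasses import dataclass, fields, is_dataclass, replace
from functools import partial
from typing import Optional


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (values plus a device tag)."""

    def __init__(self, values, device="cpu"):
        self.values = list(values)
        self.device = device

    def to(self, device):
        # Same shape as Tensor.to(device): returns a copy tagged with `device`.
        return FakeTensor(self.values, device)


@dataclass
class PackedParams:
    """Hypothetical stand-in for PackedSeqParams: tensor and non-tensor fields."""

    cu_seqlens: Optional[FakeTensor] = None
    qkv_format: str = "thd"


def move_to_device(obj, device):
    """Recursively move tensor-like values, including dataclass fields."""
    if isinstance(obj, FakeTensor):
        return obj.to(device)
    if is_dataclass(obj) and not isinstance(obj, type):
        moved = {f.name: move_to_device(getattr(obj, f.name), device) for f in fields(obj)}
        return replace(obj, **moved)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    return obj  # strings, ints, None and similar are left untouched


def loss_func(output, *, loss_mask):
    """Return (masked loss sum, integer token count, logging dict)."""
    loss_sum = sum(l * m for l, m in zip(output.values, loss_mask.values))
    num_tokens = int(sum(loss_mask.values))  # kept as an integer count
    return loss_sum, num_tokens, {"lm loss": (loss_sum, num_tokens)}


def forward_step_sketch(data_iterator, model, device="cuda"):
    batch = move_to_device(next(data_iterator), device)
    output = model(**batch)
    # The caller is expected to invoke the closure with the output only.
    return output, partial(loss_func, loss_mask=batch["loss_mask"])
```

Binding `loss_mask` with `functools.partial` means the caller only needs the model output to compute the loss, which is why the closure is returned alongside the output instead of a finished loss value.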
yashaswikarnati
commented
Jun 17, 2026
# no-op when the module is absent.
# TODO(NMFW-516): make this import + call unconditional once MM3 lands.
try:
    from examples.mimo.training.grad_sync import mark_modality_participation
Contributor
Author
Always put imports at the top of the file if possible. Also, why do we need the try/except?
yashaswikarnati
commented
Jun 17, 2026
@@ -0,0 +1,141 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

"""Forward step and per-token loss for MIMO training on the stock schedule.
Contributor
Author
Remove the word "stock". Also, the docstring and comments are very verbose; a concise couple of lines is fine.
yashaswikarnati
commented
Jun 17, 2026
triggers ``RuntimeError: result type Float can't be cast to the desired
output type Int`` on last-PP-stage ranks, which deadlocks the collective).

Scope note: this is the NON-COLOCATED path. The encoder-prefetch fast-path
Contributor
Author
Remove comments explaining what's omitted.
yashaswikarnati
commented
Jun 17, 2026
def loss_func(output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
    """Return terminal language-model loss sum, local token count, and logging tensors.

    Returns the stock per-token-loss 3-tuple ``(loss_sum, num_tokens,
Contributor
Author
Don't use the word "stock".
yashaswikarnati
commented
Jun 17, 2026
)


def mimo_forward_step(data_iterator, model):
Contributor
Author
What is the call site of mimo_forward_step? Where are the output tensor (is it the loss value at each token index in the batch?) and the loss_func closure used?
yashaswikarnati
commented
Jun 17, 2026
raise AssertionError("expected RuntimeError on missing loss_mask")


def test_loss_func_shape_mismatch_raises():
Contributor
Author
Remove these kinds of dead tests; they are not adding value.
What
Add the role-agnostic forward_step_func consumed by the stock MCore pipeline schedule (forward_backward_pipelining_without_interleaving) for MIMO training. It pulls a batch, moves it (including PackedSeqParams tensor fields) to CUDA, calls model(**batch), and returns the per-token-loss closure. num_tokens is returned as an integer tensor to match the stock schedule's int total_num_tokens accumulation (a float dtype raises a cast error on last-PP-stage ranks and deadlocks).

Why
Leaf module. The optional mark_modality_participation hook (from in-flight #5286) is imported under try/except ImportError so this module stands alone. The unit test travels with it.

Validation
Validated in the 8-GPU 20L Nemotron VLM e2e (trains + checkpoint save/resume, lm loss 12.18->11.54 across resume).

CODEOWNERS
examples/mimo/... + tests/unit_tests/... -> repo default owners.

🤖 Generated with Claude Code
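
The optional-import pattern mentioned under Why might look like the sketch below. The import path is the one shown in the diff; the `HAS_MODALITY_HOOK` flag, the no-op fallback, and its `*args, **kwargs` signature are assumptions made for illustration, since the hook's real signature is not shown here.

```python
try:
    # Path as it appears in the diff; importable only when that module exists.
    from examples.mimo.training.grad_sync import mark_modality_participation

    HAS_MODALITY_HOOK = True
except ImportError:
    HAS_MODALITY_HOOK = False

    def mark_modality_participation(*args, **kwargs):
        """Fallback no-op (assumed signature) used when the module is absent."""
        return None
```

With a fallback like this, call sites can invoke the hook unconditionally, and the guarded import can later be replaced by a plain top-of-file import once the dependency is always present.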