
Add MIMO forward step and per-token loss for the stock schedule #5376

Draft
yashaswikarnati wants to merge 1 commit into NVIDIA:main from yashaswikarnati:x3-forward-step
Conversation

@yashaswikarnati

Contributor

What

Add the role-agnostic forward_step_func consumed by the stock MCore pipeline schedule (forward_backward_pipelining_without_interleaving) for MIMO training. It pulls a batch, moves it (including PackedSeqParams tensor fields) to CUDA, calls model(**batch), and returns the per-token-loss closure. num_tokens is returned as an integer tensor to match the stock schedule's int total_num_tokens accumulation (a float dtype raises a cast error on last-PP-stage ranks, which deadlocks the collective).
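As a rough illustration of the flow described above (not the PR's actual code), a forward step of this shape might look like the sketch below. The names `forward_step_sketch` and `move_to_device`, the `device` parameter, the batch keys, and the layout of the logging dict are all assumptions made for the example. It also assumes the loss mask travels inside the batch, and it leaves out the PackedSeqParams handling.

```python
from functools import partial

import torch


def loss_func(output_tensor, *, loss_mask):
    """Return (loss_sum, num_tokens, logging dict) for one microbatch."""
    losses = output_tensor.float().reshape(-1)
    mask = loss_mask.reshape(-1).float()
    loss_sum = torch.sum(losses * mask)
    # Integer count, so it can be added in place to an int accumulator.
    num_tokens = mask.sum().to(torch.int)
    # Assumed logging layout: (detached loss sum, token count).
    report = {"lm loss": (loss_sum.detach(), num_tokens)}
    return loss_sum, num_tokens, report


def move_to_device(value, device):
    """Recursively move tensors (and dicts of tensors) to `device`."""
    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=True)
    if isinstance(value, dict):
        return {k: move_to_device(v, device) for k, v in value.items()}
    return value


def forward_step_sketch(data_iterator, model, device=None):
    """Pull a batch, move it to the device, run the model, return the closure."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = move_to_device(next(data_iterator), device)
    output_tensor = model(**batch)
    # Bind the mask now; the caller only has to pass the output tensor later.
    return output_tensor, partial(loss_func, loss_mask=batch["loss_mask"])
```

The closure is returned rather than called so that the caller can decide on which ranks the loss is actually computed.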

Why

Leaf module. The optional mark_modality_participation hook (from in-flight #5286) is imported under try/except ImportError so this module stands alone. The unit test travels with it.

Validation

Validated in the 8-GPU 20L Nemotron VLM e2e (trains + checkpoint save/resume, lm loss 12.18->11.54 across resume).

CODEOWNERS

examples/mimo/... + tests/unit_tests/... -> repo default owners.

🤖 Generated with Claude Code

Add the role-agnostic forward_step_func consumed by the stock MCore
pipeline schedule (forward_backward_pipelining_without_interleaving) for
MIMO training. It pulls a batch, moves it (including PackedSeqParams tensor
fields) to CUDA, calls model(**batch), and returns the per-token-loss
closure. num_tokens is returned as an integer tensor to match the stock
schedule's int total_num_tokens accumulation (a float dtype raises a cast
error on last-PP-stage ranks, which deadlocks the collective).
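The dtype issue can be reproduced in isolation with plain PyTorch. The sketch below assumes the accumulator is an int tensor updated in place, as the message above describes; `accumulate` is a hypothetical helper, not the schedule's own code.

```python
import torch


def accumulate(total, num_tokens):
    """In-place add, mimicking a running int token count."""
    total += num_tokens
    return total


total_num_tokens = torch.zeros([], dtype=torch.int)

# Integer count: accumulates fine.
accumulate(total_num_tokens, torch.tensor(7, dtype=torch.int))

# Float count: the promoted Float result cannot be written into an Int tensor.
try:
    accumulate(total_num_tokens, torch.tensor(7.0))
    float_error = None
except RuntimeError as exc:
    float_error = str(exc)
```

Only the ranks that run the accumulation see the exception, which is consistent with the other ranks then waiting forever in a collective.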

This is the non-colocated path. The optional mark_modality_participation
hook (from MM3, in-flight NVIDIA#5286) is imported under try/except so this
module stands alone. The unit test travels with it.

Validated in the 8-GPU 20L Nemotron VLM e2e (trains + checkpoint
save/resume, lm loss 12.18->11.54 across resume).

Signed-off-by: ykarnati <ykarnati@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

# no-op when the module is absent.
# TODO(NMFW-516): make this import + call unconditional once MM3 lands.
try:
    from examples.mimo.training.grad_sync import mark_modality_participation

Contributor Author

Always put imports at the top of the file if possible. Also, why do we need the try/except?

@@ -0,0 +1,141 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

"""Forward step and per-token loss for MIMO training on the stock schedule.

Contributor Author

Remove the word "stock". Also, the docstring and comments are very verbose; a concise couple of lines is fine.

triggers ``RuntimeError: result type Float can't be cast to the desired
output type Int`` on last-PP-stage ranks, which deadlocks the collective).

Scope note: this is the NON-COLOCATED path. The encoder-prefetch fast-path

Contributor Author

Remove comments explaining what's omitted.

def loss_func(output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
    """Return terminal language-model loss sum, local token count, and logging tensors.

    Returns the stock per-token-loss 3-tuple ``(loss_sum, num_tokens,

Contributor Author

Don't use the word "stock".

)


def mimo_forward_step(data_iterator, model):

Contributor Author

What is the call site of mimo_forward_step? The output tensor (is it the loss value at each token index in the batch?) and the loss_func closure: where are they used?
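For context on this question, here is a toy stand-in for how a pipeline schedule could consume the two return values. It is not the MCore schedule's code: `run_microbatches` and `toy_forward_step` are hypothetical, and the sketch assumes that the last-stage output is a per-token loss tensor, which the thread does not confirm.

```python
import torch


def toy_forward_step(data_iterator, model):
    """Return (output_tensor, closure) in the shape a schedule expects."""
    batch = next(data_iterator)
    output_tensor = model(batch["tokens"])  # assumed: per-token losses

    def loss_closure(out):
        mask = batch["loss_mask"].float()
        return (out * mask).sum(), mask.sum().to(torch.int), {}

    return output_tensor, loss_closure


def run_microbatches(forward_step_func, data_iterator, model, num_microbatches):
    """Call the forward step per microbatch and reduce via the closure."""
    total_loss = torch.zeros([])
    total_num_tokens = torch.zeros([], dtype=torch.int)
    for _ in range(num_microbatches):
        output_tensor, loss_closure = forward_step_func(data_iterator, model)
        # A real schedule would apply the closure on the last stage only.
        loss_sum, num_tokens, _report = loss_closure(output_tensor)
        total_loss += loss_sum.detach()
        total_num_tokens += num_tokens  # int += int, no dtype promotion
    return total_loss, total_num_tokens
```

In this toy version the per-token loss is the summed loss divided by the accumulated token count, taken once after all microbatches.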

raise AssertionError("expected RuntimeError on missing loss_mask")


def test_loss_func_shape_mismatch_raises():

Contributor Author

Remove these kinds of dead tests; they don't add value.
