feat: CUDA graph support for packed sequence (variable-length) training#3869

Draft
seonjinn wants to merge 2 commits into NVIDIA:main from seonjinn:sj/cudagraph-packedseq
Conversation


@seonjinn seonjinn commented Mar 14, 2026

Enable CUDA graph capture/replay for packed sequence (SFT) training
with Mamba-Transformer hybrid models.

Problem

CUDA graphs require fixed-shape tensor inputs, but packed sequences
have a variable number of documents per micro-batch, so cu_seqlens
varies in length. This is incompatible with CUDA graph capture.

Solution

Pad cu_seqlens to a configurable fixed size for CUDA graph replay.
If a batch exceeds this size, fall back to eager forward. This gives
CG benefits for most batches while maintaining correctness for all.
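
The padding scheme can be sketched as follows. This is a plain-Python sketch; the actual implementation operates on torch tensors inside PackedSeqParams, and the function name here is illustrative:

```python
def pad_cu_seqlens(cu_seqlens, max_seqs):
    """Pad cu_seqlens to a fixed length of max_seqs + 1 entries.

    cu_seqlens holds cumulative document boundaries, e.g. [0, 5, 9, 12]
    for three documents. Padding repeats the final offset, so the extra
    "documents" have zero length and do not change attention results.
    Returns None when the batch has more documents than max_seqs,
    signalling the caller to fall back to an eager forward pass.
    """
    n_docs = len(cu_seqlens) - 1
    if n_docs > max_seqs:
        return None  # too many documents: replay shape would not fit
    return list(cu_seqlens) + [cu_seqlens[-1]] * (max_seqs - n_docs)
```

Because every replay then sees a cu_seqlens of identical length, the captured graph's input shapes stay fixed across micro-batches.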

Key Changes

  • PackedSeqParams: cu_seqlens padding, shared CG buffers across
    layers, dummy PSP for graph capture
  • TransformerLayer: CG capture/replay with fallback for attention
  • MambaLayer: CG capture/replay with pre-computed seq_idx
  • MambaMixer: Avoid dynamic allocations inside CG (seq_idx reuse,
    output_size parameter to avoid GPU->CPU sync)
  • pretrain_mamba: cu_seqlens padding in get_batch()
  • New arg: --cuda-graph-max-packed-seqs
  • te_patches/: Patch TE context_parallel to avoid GPU->CPU sync
    during CUDA graph capture (applied via PYTHONPATH import hook)

Usage

1. Training script arguments

--cuda-graph-impl transformer_engine \
--cuda-graph-scope mamba attn \
--cuda-graph-max-packed-seqs <MAX_SEQS>

MAX_SEQS controls the fixed cu_seqlens size:

  • Smaller value = less flash_attn padding overhead, more fallbacks
  • Larger value = fewer fallbacks, more padding overhead
  • Set based on your dataset's N_docs distribution (e.g., P90 or P99)

2. Apply TE context_parallel patch

export PYTHONPATH=<repo_root>/te_patches:${PYTHONPATH}

Required to avoid GPU->CPU sync errors during CG capture.
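
The patch mechanism relies on standard Python module resolution: a directory placed ahead of site-packages on sys.path (which is what prepending to PYTHONPATH does) wins the import lookup, so the patched copy of a module shadows the installed one. A minimal, self-contained sketch of the idea; the module name here is made up for illustration, with te_patches/ playing the role of the patch directory:

```python
import os
import sys
import tempfile

# Create a "patch directory" containing a module that shadows an
# installed one of the same name.
patch_dir = tempfile.mkdtemp()
with open(os.path.join(patch_dir, "shadowed_mod.py"), "w") as f:
    f.write("PATCHED = True\n")

# Prepending to sys.path is what PYTHONPATH=<dir>:$PYTHONPATH does at
# interpreter startup: this copy is found before any installed copy.
sys.path.insert(0, patch_dir)

import shadowed_mod
print(shadowed_mod.PATCHED)  # True
```

This is why the patch takes effect without modifying the installed TE package: the shadowed module simply wins import resolution.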

3. Example

export PYTHONPATH=/path/to/Megatron-LM/te_patches:${PYTHONPATH}

torchrun pretrain_mamba.py \
  --sft \
  --cuda-graph-impl transformer_engine \
  --cuda-graph-scope mamba attn \
  --cuda-graph-max-packed-seqs 64 \
  --context-parallel-size 32 \
  --tensor-model-parallel-size 8 \
  --expert-model-parallel-size 64 \
...

4. Choosing MAX_SEQS

Analyze your dataset's packed sequence distribution:

  • Example: P50=12, P90=43, P99=106, max=358
  • 64 covers ~96% of batches
  • 106 covers ~99% but more padding overhead
  • Setting to max covers 100% but padding overhead may
    outweigh CG benefit
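
The analysis above can be reproduced by computing coverage over measured per-micro-batch document counts. The synthetic distribution below is only a stand-in for your dataset's actual N_docs histogram:

```python
import random

# Stand-in for measured per-micro-batch document counts; replace
# n_docs with real counts collected from your packed dataset.
random.seed(0)
n_docs = [min(int(random.lognormvariate(2.7, 0.9)) + 1, 358)
          for _ in range(10_000)]

def coverage(max_seqs):
    """Fraction of micro-batches that fit without an eager fallback."""
    return sum(n <= max_seqs for n in n_docs) / len(n_docs)

for candidate in (32, 64, 106, 358):
    print(f"MAX_SEQS={candidate}: covers {coverage(candidate):.1%}")
```

Pick the smallest candidate whose coverage makes the remaining eager fallbacks rare enough that the CUDA graph speedup dominates.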

What does this PR do ?

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or tag @mcore-oncall in a comment to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@seonjinn seonjinn requested review from a team as code owners March 14, 2026 00:25

copy-pr-bot bot commented Mar 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft March 14, 2026 00:25
@github-actions

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

(commit message: duplicates the PR description above)
Signed-off-by: Seonjin Na <sna@nvidia.com>
@seonjinn seonjinn force-pushed the sj/cudagraph-packedseq branch 2 times, most recently from 61d6cbb to 9b5382b Compare March 15, 2026 08:08
When cu_seqlens is CG-padded, the last entry exceeds total_tokens
(CP-local). Skip seq_idx computation entirely — in CG mode,
mamba_layer.py manages seq_idx via shared CG buffers in
_te_cuda_graph_replay.

For non-CG (unpadded cu_seqlens), __post_init__ works as before.

Signed-off-by: Seonjin Na <sna@nvidia.com>
@seonjinn seonjinn force-pushed the sj/cudagraph-packedseq branch from 9b5382b to f12d221 Compare March 15, 2026 08:10