Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,9 +1510,12 @@ def __init__(
self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")

# total_tokens and seq_idx are only for Mamba and should not be forwarded to TE attention.
# total_tokens, seq_idx, and max_seqlen tensors are only for Mamba / CUDA graph
# buffer management and should not be forwarded to TE attention.
self.kept_packed_seq_params.discard("total_tokens")
self.kept_packed_seq_params.discard("seq_idx")
self.kept_packed_seq_params.discard("max_seqlen_q_tensor")
self.kept_packed_seq_params.discard("max_seqlen_kv_tensor")

if config.qk_clip or config.log_max_attention_logit:
# qk-clip is only supported in TE 2.9.0 and later
Expand Down
195 changes: 174 additions & 21 deletions megatron/core/packed_seq_params.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor


# Maximum number of packed sequences supported by CUDA graph capture.
# cu_seqlens tensors are padded to this length + 1 for fixed-shape graph inputs.
# Override at runtime with --cuda-graph-max-packed-seqs.
CUDA_GRAPH_MAX_PACKED_SEQS: int = 2048


# Module-level cache for shared CUDA graph buffer tensors.
# Key: (tag, seq_length, max_seqs, device_id) -> shared buffer dict (or tensor for seq_idx).
# All layers with the same key share the SAME dict and SAME underlying tensor objects.
# Updating these tensors once per micro-batch propagates to ALL layers' CUDA graphs.
_CG_SHARED_BUFFERS: dict = {}


@dataclass
class PackedSeqParams:
'''
parameters to TEDotProductAttention and fused rope kernels for the
`thd` (packed) sequence format
Parameters for TEDotProductAttention and fused rope kernels for the
`thd` (packed) sequence format.
'''

qkv_format: str = None
Expand All @@ -20,40 +36,177 @@ class PackedSeqParams:
cu_seqlens_kv_padded: Tensor = None
max_seqlen_q: int = None
max_seqlen_kv: int = None
# Tensor versions of max_seqlen for CUDA graph buffer updates (avoids int->tensor inside CG).
max_seqlen_q_tensor: Tensor = None
max_seqlen_kv_tensor: Tensor = None
local_cp_size: int = None
cp_group: dist.ProcessGroup = None
total_tokens: int = None
seq_idx: Tensor = None
# Pre-computed seq_idx for Mamba. When set, mamba_mixer reads it directly,
# avoiding dynamic allocations that are forbidden inside CUDA graph capture.
seq_idx: Optional[Tensor] = None

def __post_init__(self):
    """Pre-compute seq_idx for Mamba mixer.

    Converts cu_seqlens into a per-token sequence index tensor. For example,
    cu_seqlens=[0, 5, 7, 11] with total_tokens=16 produces:
    [0,0,0,0,0, 1,1, 2,2,2,2, 3,3,3,3,3]

    An extra sequence index is appended for tokens beyond the last cu_seqlens
    entry, so trailing padding tokens are counted as one more sequence.
    No-op unless both cu_seqlens_q and total_tokens are provided.
    """
    if self.seq_idx is not None:
        return  # Already set (e.g. CG dummy PSP with pre-allocated buffer)

    cu_seqlens = self.cu_seqlens_q
    if isinstance(cu_seqlens, Tensor) and self.total_tokens is not None:
        # Skip seq_idx computation when cu_seqlens has been CG-padded.
        # CG-padded cu_seqlens contain entries at the global seq_len
        # (e.g. 262144) while total_tokens is CP-local (e.g. 8192).
        # In CG mode, seq_idx is managed separately by mamba_layer.py's
        # _te_cuda_graph_replay via shared CG buffers.
        if cu_seqlens[-1] > self.total_tokens:
            return  # CG-padded: skip, let mamba_layer handle seq_idx

        total_tokens_tensor = torch.tensor(
            [self.total_tokens], dtype=cu_seqlens.dtype, device=cu_seqlens.device
        )
        # Example: [0, 5, 7, 11] -> [0, 5, 7, 11, 16]
        cu_seqlens_with_max = torch.cat([cu_seqlens, total_tokens_tensor])
        # Example: [0, 5, 7, 11, 16] -> [5, 2, 4, 5]
        seq_lengths = cu_seqlens_with_max[1:] - cu_seqlens_with_max[:-1]
        # Example: [5, 2, 4, 5] -> [0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3]
        # Pass output_size to avoid a GPU->CPU sync that repeat_interleave
        # performs when the output length is unknown.
        self.seq_idx = (
            torch.repeat_interleave(
                torch.arange(seq_lengths.numel(), device=cu_seqlens.device),
                seq_lengths,
                output_size=self.total_tokens,
            )
            .to(torch.int32)
            .unsqueeze(0)  # Add a batch dimension
        )

# ----------------------------------------------------------------
# CUDA graph padding utilities
# ----------------------------------------------------------------

@staticmethod
def pad_cu_seqlens(cu_seqlens: Tensor, target_len: int) -> Tensor:
"""Pad cu_seqlens to a fixed length using the last element as fill value.

CUDA graphs require fixed-shape inputs. By padding cu_seqlens to a
constant size (bucket_size + 1), the graph captures a single shape and
replays it for all batches that fit within the bucket.
"""
actual_len = cu_seqlens.shape[0]
if actual_len >= target_len:
return cu_seqlens[:target_len]
padded = cu_seqlens.new_empty(target_len)
padded[:actual_len] = cu_seqlens
padded[actual_len:] = cu_seqlens[-1]
return padded

def ensure_cg_padded(self, target_len: int) -> None:
    """Lazily compute and cache CG-padded cu_seqlens on this instance.

    Invoked once per layer during CUDA graph replay, but the padding work
    runs only once per micro-batch: the same PSP object is shared by every
    layer of an iteration, and the result is memoized on the instance, so
    any later call with the same target_len returns immediately.
    """
    if getattr(self, '_cg_pad_target', None) == target_len:
        return  # padding already cached for this target_len
    pad = PackedSeqParams.pad_cu_seqlens
    self._cg_pad_target = target_len
    self._cg_padded_q = pad(self.cu_seqlens_q, target_len)
    self._cg_padded_kv = pad(self.cu_seqlens_kv, target_len)
    qp, kvp = self.cu_seqlens_q_padded, self.cu_seqlens_kv_padded
    self._cg_padded_qp = None if qp is None else pad(qp, target_len)
    self._cg_padded_kvp = None if kvp is None else pad(kvp, target_len)

# ----------------------------------------------------------------
# Shared CUDA graph buffer management
# ----------------------------------------------------------------

@classmethod
def get_or_create_shared_cg_buffers(
    cls,
    seq_length: int,
    max_seqs: int,
    device: torch.device,
    *,
    tag: str = 'attn',
) -> Dict[str, Tensor]:
    """Return the shared PSP buffer dict for CUDA graph replay.

    Every layer that asks with the same (tag, seq_length, max_seqs, device)
    key receives the SAME dict object, and therefore the SAME underlying
    tensors. A single per-micro-batch copy_() into those tensors is thereby
    visible to every layer's captured CUDA graph at once.
    """
    key = (tag, seq_length, max_seqs, int(device.index or 0))
    buffers = _CG_SHARED_BUFFERS.get(key)
    if buffers is None:
        _, buffers = cls.create_dummy_for_cuda_graph(seq_length, max_seqs=max_seqs)
        # Object-identity gate; None forces the first update.
        buffers['_last_updated_psp'] = None
        _CG_SHARED_BUFFERS[key] = buffers
    return buffers

@classmethod
def get_or_create_shared_seq_idx_buffer(
    cls, total_tokens: int, device: torch.device
) -> Tensor:
    """Return the shared seq_idx buffer tensor for Mamba CUDA graph replay."""
    key = ('seq_idx', total_tokens, int(device.index or 0))
    buf = _CG_SHARED_BUFFERS.get(key)
    if buf is None:
        # One int32 row of per-token sequence indices, shared by all layers.
        buf = torch.zeros(1, total_tokens, dtype=torch.int32, device=device)
        _CG_SHARED_BUFFERS[key] = buf
    return buf

@classmethod
def create_dummy_for_cuda_graph(
    cls, seq_length: int, max_seqs: int = CUDA_GRAPH_MAX_PACKED_SEQS
) -> Tuple[PackedSeqParams, Dict[str, Tensor]]:
    """Create a dummy PackedSeqParams for CUDA graph capture.

    Returns the dummy PSP and a dict of tensor buffer references that can
    be updated via copy_() during graph replay. The PSP's tensor fields are
    the very same objects stored in the returned dict.
    """
    device = torch.cuda.current_device()
    dtype = torch.int32
    n_entries = max_seqs + 1

    buffers: Dict[str, Tensor] = {}
    for name in (
        'cu_seqlens_q',
        'cu_seqlens_kv',
        'cu_seqlens_q_padded',
        'cu_seqlens_kv_padded',
    ):
        # Shape [max_seqs + 1]: [0, seq_length, seq_length, ...] — i.e. one
        # full-length dummy sequence followed by zero-length filler entries.
        buf = torch.full((n_entries,), seq_length, dtype=dtype, device=device)
        buf[0] = 0
        buffers[name] = buf
    for name in ('max_seqlen_q_tensor', 'max_seqlen_kv_tensor'):
        buffers[name] = torch.tensor([seq_length], dtype=dtype, device=device)

    psp = cls(
        qkv_format="thd",
        max_seqlen_q=seq_length,
        max_seqlen_kv=seq_length,
        **buffers,
    )
    return psp, buffers
135 changes: 130 additions & 5 deletions megatron/core/ssm/mamba_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,143 @@ def sharded_state_dict(
apply_prefix_mapping(sharded_state_dict, prefixed_map)
return sharded_state_dict

def get_layer_static_inputs(self, seq_length, micro_batch_size):
    """Prepare static inputs for CUDA graph capture.

    When packed sequences are in use (SFT), also sets up shared CUDA graph
    buffer tensors and a dummy PackedSeqParams so the graph captures the
    packed-sequence code path (Mamba with seq_idx).
    """
    static_inputs = super().get_layer_static_inputs(seq_length, micro_batch_size)

    from megatron.core.packed_seq_params import CUDA_GRAPH_MAX_PACKED_SEQS
    from megatron.training import get_args

    # Fetch global args once instead of once per use.
    args = get_args()
    if getattr(args, 'sft', False):
        self._cuda_graph_seq_length = seq_length
        max_seqs = (
            getattr(args, 'cuda_graph_max_packed_seqs', None) or CUDA_GRAPH_MAX_PACKED_SEQS
        )
        # Compute total_tokens as seen by Mamba SSM after CP all_to_all.
        mamba_cp_size = self.mixer.cp.cp_size
        total_tokens = (seq_length // self.config.context_parallel_size) * mamba_cp_size
        device = torch.device('cuda', torch.cuda.current_device())

        # All Mamba layers with the same config share the SAME dict and tensors.
        shared_bufs = PackedSeqParams.get_or_create_shared_cg_buffers(
            seq_length, max_seqs, device, tag='mamba'
        )
        self._cuda_graph_psp_buffers = shared_bufs

        # Shared seq_idx buffer for all Mamba layers.
        seq_idx_buf = PackedSeqParams.get_or_create_shared_seq_idx_buffer(
            total_tokens, device
        )
        shared_bufs['seq_idx'] = seq_idx_buf

        # Build dummy PSP whose tensor fields point to the shared buffers.
        dummy_psp = PackedSeqParams(
            qkv_format="thd",
            cu_seqlens_q=shared_bufs['cu_seqlens_q'],
            cu_seqlens_kv=shared_bufs['cu_seqlens_kv'],
            cu_seqlens_q_padded=shared_bufs['cu_seqlens_q_padded'],
            cu_seqlens_kv_padded=shared_bufs['cu_seqlens_kv_padded'],
            max_seqlen_q=seq_length,
            max_seqlen_kv=seq_length,
            max_seqlen_q_tensor=shared_bufs['max_seqlen_q_tensor'],
            max_seqlen_kv_tensor=shared_bufs['max_seqlen_kv_tensor'],
        )
        dummy_psp.seq_idx = seq_idx_buf
        self._cuda_graph_psp = dummy_psp

        # Correct cu_seqlens for Mamba's CP all_to_all sequence gathering.
        # pre_conv_ssm gathers: [seq_length/cp, b, d] -> [seq_length, b, d/cp]
        if mamba_cp_size > 1:
            for k in (
                'cu_seqlens_q',
                'cu_seqlens_kv',
                'cu_seqlens_q_padded',
                'cu_seqlens_kv_padded',
            ):
                shared_bufs[k][1:] = total_tokens
            dummy_psp.max_seqlen_q = total_tokens
            dummy_psp.max_seqlen_kv = total_tokens
            shared_bufs['max_seqlen_q_tensor'].fill_(total_tokens)
            shared_bufs['max_seqlen_kv_tensor'].fill_(total_tokens)

    return static_inputs

def _te_cuda_graph_capture(self, *args, **kwargs):
    """Run capture-time forward with the dummy PSP injected, so the CUDA
    graph records Mamba's packed-sequence code path."""
    if kwargs.get('packed_seq_params') is None and hasattr(self, '_cuda_graph_psp'):
        # Copy kwargs so the caller's dict is never mutated.
        kwargs = {**kwargs, 'packed_seq_params': self._cuda_graph_psp}
    return self.forward(*args, **kwargs)

def _te_cuda_graph_replay(self, *args, **kwargs):
    """
    CUDA graph replay for Mamba layer using TE interface.

    Copies PackedSeqParams tensor fields (cu_seqlens, seq_idx) into the
    captured graph's shared buffers. Falls back to non-CG forward when
    the actual packed-sequence count exceeds the CG bucket size.
    """
    assert kwargs.get('inference_context') is None, (
        "CUDA graph accepts only Tensor inputs. inference_context is excluded from input list. "
        "For inference cuda graph, please use cuda_graph_impl=local instead."
    )
    psp = kwargs.get('packed_seq_params')
    if psp is not None and hasattr(self, '_cuda_graph_psp_buffers'):
        bufs = self._cuda_graph_psp_buffers
        target_len = bufs['cu_seqlens_q'].shape[0]  # max_seqs + 1
        if psp.cu_seqlens_q.shape[0] > target_len:
            # Actual N_docs exceeds bucket -> fall back to non-CG forward.
            return self.forward(*args, **kwargs)

        # PSP-identity gate: shared buffers need only be updated ONCE per
        # micro-batch. Use 'is' to avoid false-positive cache hits from
        # CPython id() recycling.
        if bufs.get('_last_updated_psp') is not psp:
            psp.ensure_cg_padded(target_len)
            bufs['cu_seqlens_q'].copy_(psp._cg_padded_q)
            bufs['cu_seqlens_kv'].copy_(psp._cg_padded_kv)
            bufs['cu_seqlens_q_padded'].copy_(
                psp._cg_padded_qp if psp._cg_padded_qp is not None else psp._cg_padded_q
            )
            bufs['cu_seqlens_kv_padded'].copy_(
                psp._cg_padded_kvp if psp._cg_padded_kvp is not None else psp._cg_padded_kv
            )
            if psp.max_seqlen_q_tensor is not None:
                bufs['max_seqlen_q_tensor'].copy_(psp.max_seqlen_q_tensor)
            if psp.max_seqlen_kv_tensor is not None:
                bufs['max_seqlen_kv_tensor'].copy_(psp.max_seqlen_kv_tensor)
            # Copy seq_idx into shared buffer (computed by __post_init__).
            if 'seq_idx' in bufs and psp.seq_idx is not None:
                bufs['seq_idx'].copy_(psp.seq_idx)
            bufs['_last_updated_psp'] = psp

        # Set int constants on dummy PSP (captured as Python constants in graph).
        self._cuda_graph_psp.max_seqlen_q = self._cuda_graph_seq_length
        self._cuda_graph_psp.max_seqlen_kv = self._cuda_graph_seq_length

        # Replace real PSP with fixed-size dummy PSP.
        kwargs = dict(kwargs)
        kwargs['packed_seq_params'] = self._cuda_graph_psp

    # CUDA graph replay accepts only Tensor / PackedSeqParams / None inputs.
    kwargs_filtered = {
        k: v
        for k, v in kwargs.items()
        if v is None or isinstance(v, (torch.Tensor, PackedSeqParams))
    }

    cg_index = getattr(self, 'current_microbatch', 0) % len(self.cuda_graphs)
    cudagraph_args, cudagraph_kwargs = self._get_te_cuda_graph_replay_args(
        *args, **kwargs_filtered
    )
    for hook, hook_args in self.cuda_graph_manual_hooks:
        hook(*hook_args)
    return self.cuda_graphs[cg_index](*cudagraph_args, **cudagraph_kwargs)

def _should_call_local_cudagraph(self, *args, **kwargs):
"""
Expand Down
Loading