Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 93 additions & 19 deletions megatron/core/ssm/gated_delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,29 +296,75 @@ def forward(
raise NotImplementedError("GDN does not support inference for now.")

if packed_seq_params is not None:
# TODO: support packed sequence
raise NotImplementedError("GDN does not support packed sequence for now.")
assert batch == 1, "Packed sequence expects batch dimension to be 1"
assert (
not self.config.deterministic_mode
), "Packed sequence does not support deterministic mode."

# Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
if packed_seq_params.cu_seqlens_q_padded is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
else:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
# Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
if packed_seq_params.cu_seqlens_kv_padded is not None:
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
else:
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
assert torch.equal(cu_seqlens_q, cu_seqlens_kv), (
"Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
)
num_packed_seqs = cu_seqlens_q.shape[0] - 1
assert num_packed_seqs > 0, (
"Number of packed sequences must be greater than 0, "
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
)
else:
cu_seqlens_q = None
cu_seqlens_kv = None

# Input projection
nvtx_range_push(suffix="in_proj")
qkvzba, _ = self.in_proj(hidden_states)
nvtx_range_pop(suffix="in_proj")

# CP All to All: CP to HP
qkvzba = tensor_a2a_cp2hp(
qkvzba,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
split_sections=[
self.qk_dim_local_tp,
self.qk_dim_local_tp,
self.v_dim_local_tp,
self.v_dim_local_tp,
self.num_value_heads // self.tp_size,
self.num_value_heads // self.tp_size,
],
)
if packed_seq_params is not None:
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0)
outputs = []
for qkvzba_i in unpacked_qkvzba:
qkvzba_i = tensor_a2a_cp2hp(
qkvzba_i,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
split_sections=[
self.qk_dim_local_tp,
self.qk_dim_local_tp,
self.v_dim_local_tp,
self.v_dim_local_tp,
self.num_value_heads // self.tp_size,
self.num_value_heads // self.tp_size,
],
)
outputs.append(qkvzba_i)
qkvzba = torch.cat(outputs, dim=0)
else:
qkvzba = tensor_a2a_cp2hp(
qkvzba,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
split_sections=[
self.qk_dim_local_tp,
self.qk_dim_local_tp,
self.v_dim_local_tp,
self.v_dim_local_tp,
self.num_value_heads // self.tp_size,
self.num_value_heads // self.tp_size,
],
)

# Transpose: s b x --> b s x
# From sbhd to bshd format
Expand Down Expand Up @@ -385,6 +431,7 @@ def forward(
activation=self.activation,
initial_state=None,
output_final_state=False,
cu_seqlens=cu_seqlens_q,
)
nvtx_range_pop(suffix="conv1d")

Expand Down Expand Up @@ -440,6 +487,7 @@ def forward(
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=False,
cu_seqlens=cu_seqlens_q,
)
nvtx_range_pop(suffix="gated_delta_rule")

Expand All @@ -454,9 +502,19 @@ def forward(
norm_out = norm_out.transpose(0, 1).contiguous()

# CP all to all: HP to CP
norm_out = tensor_a2a_hp2cp(
norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
)
if packed_seq_params is not None:
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0)
outputs = []
for norm_out_i in unpacked_norm_out:
norm_out_i = tensor_a2a_hp2cp(
norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
)
outputs.append(norm_out_i)
norm_out = torch.cat(outputs, dim=0)
else:
norm_out = tensor_a2a_hp2cp(
norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
)

# Output projection
nvtx_range_push(suffix="out_proj")
Expand Down Expand Up @@ -575,6 +633,17 @@ def _backward_out_proj(self):
self.out_proj.backward_dw()


def _unpack_sequence(x, cu_seqlens, dim=1):
unpacked_x = []
num_seqs = cu_seqlens.shape[0] - 1
for i in range(num_seqs):
idx_start = cu_seqlens[i].item()
idx_end = cu_seqlens[i + 1].item()
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
unpacked_x.append(x[tuple(chunked_index)])
return unpacked_x


####################
# Sharded state dict utilities
####################
Expand Down Expand Up @@ -826,6 +895,7 @@ def torch_chunk_gated_delta_rule(
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=False,
cu_seqlens=None,
):
# pylint: disable=line-too-long
'''
Expand All @@ -835,6 +905,10 @@ def torch_chunk_gated_delta_rule(
Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
'''

assert cu_seqlens is None, (
"cu_seqlens is not supported for torch_chunk_gated_delta_rule for now."
)

initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = l2norm(query, dim=-1, eps=1e-6)
Expand Down
48 changes: 47 additions & 1 deletion tests/unit_tests/ssm/test_gated_delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from tests.unit_tests.test_utilities import Utils
from tests.unit_tests.transformer.test_attention import _test_parallel_attention_correctness
from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params

try:
import fla
Expand Down Expand Up @@ -138,7 +139,51 @@ def test_gpu_forward(self):
output.dtype == hidden_states.dtype
), f"Output dtype {output.dtype=} mismatch with {hidden_states.dtype=}"

def test_gpu_forward_thd_correctness(self):
    """Check that a packed (THD) forward pass through the GDN layer matches
    the equivalent batched (SBHD) forward pass on the same data, within
    bf16-level tolerances."""
    if self.sp_size > 1:
        pytest.skip("Sequence parallel is not supported for this test case.")

    # Loose tolerances: the forward runs in bfloat16.
    atol, rtol = 3e-4, 3e-4

    # Input shape
    sequence_length = 32
    micro_batch_size = 4
    # Four equal-length sequences packed back to back; lengths refer to the
    # full (un-sharded) sequence — the GDN layer rescales by cp_size
    # internally when CP > 1.
    cu_seqlens = [0, 32, 64, 96, 128]
    # sbhd input shape: [sequence length, batch size, hidden size]
    # Each CP rank only holds its shard of the sequence dimension.
    sub_sequence_length = sequence_length // self.cp_size
    hidden_states_sbhd = torch.rand(
        (sub_sequence_length, micro_batch_size, self.gdn.config.hidden_size)
    )
    attention_mask_sbhd = None
    hidden_states_sbhd = hidden_states_sbhd.cuda().bfloat16()
    # thd input shape: [sequence length * batch size, 1, hidden size]
    # Transpose to [b, s, h] first so each packed sequence is contiguous,
    # then flatten batch and sequence into one packed token dimension.
    hidden_states_thd = hidden_states_sbhd.transpose(0, 1).contiguous()
    hidden_states_thd = hidden_states_thd.view(-1, 1, self.gdn.config.hidden_size)
    attention_mask_thd = None
    packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)

    # THD format
    output_thd, _ = self.gdn(
        hidden_states_thd, attention_mask_thd, packed_seq_params=packed_seq_params
    )
    # SBHD format
    output_sbhd, _ = self.gdn(hidden_states_sbhd, attention_mask_sbhd)
    # Reorder the SBHD output into the same packed layout as output_thd.
    output_sbhd_T = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape)

    # Rank only used to label assertion failures in multi-rank runs.
    rank = torch.distributed.get_rank()
    assert output_thd.shape[0] == sub_sequence_length * micro_batch_size
    assert output_thd.shape[1] == 1
    assert output_thd.shape[2] == self.gdn.config.hidden_size
    torch.testing.assert_close(
        output_sbhd_T,
        output_thd,
        atol=atol,
        rtol=rtol,
        msg=lambda msg: f"Output mismatch ({rank=}): {msg}",
    )


@pytest.mark.parametrize("sequence_packing", [False, True])
@pytest.mark.parametrize(
("tp", "sp", "cp"),
[
Expand All @@ -150,7 +195,7 @@ def test_gpu_forward(self):
],
)
@pytest.mark.skipif(not HAVE_FLA, reason="FLA is not installed.")
def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp):
def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, sequence_packing, tp, sp, cp):
transformer_config = TransformerConfig(
hidden_size=128,
linear_conv_kernel_dim=2,
Expand Down Expand Up @@ -191,4 +236,5 @@ def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp):
seed=123,
sequence_length=256,
micro_batch_size=4,
sequence_packing=sequence_packing,
)
17 changes: 14 additions & 3 deletions tests/unit_tests/transformer/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
init_checkpointing_mock_args,
)
from tests.unit_tests.test_utilities import Utils
from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params

try:
from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb
Expand Down Expand Up @@ -710,6 +711,7 @@ def _test_parallel_attention_correctness(
seed=123,
sequence_length=256,
micro_batch_size=4,
sequence_packing=False,
):
# Model initialization function
def initialize_gpt_model(
Expand Down Expand Up @@ -803,17 +805,24 @@ def initialize_gpt_model(
def get_tensor_on_this_rank(tensor):
    """Shard/reshape the full reference tensor into the layout this rank's
    parallel attention module expects as input (closure over the test's
    tp/sp/cp configuration)."""
    # Take this rank's context-parallel shard along the sequence dim (0).
    if cp > 1:
        tensor = get_tensor_on_this_cp_rank(tensor, 0, cp_group)
    # Packed (THD) layout: [s, b, ...] -> [s * b, 1, ...], so each
    # sequence of the batch becomes one packed segment.
    if sequence_packing:
        tensor = tensor.transpose(0, 1).contiguous().view(-1, 1, *tensor.shape[2:])
    # With sequence parallel, each TP rank holds a contiguous slice of the
    # (possibly packed) leading dimension; using tensor.shape[0] keeps the
    # slice size correct for both packed and unpacked layouts.
    if tp > 1 and sp:
        sp_seg = tensor.shape[0] // tp
        tensor = tensor[tp_rank * sp_seg : (tp_rank + 1) * sp_seg]
    return tensor

# Calculate parallel model output
if sequence_packing:
cu_seqlens = [i * sequence_length for i in range(micro_batch_size + 1)]
packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)
else:
packed_seq_params = None
input_hidden_states = get_tensor_on_this_rank(input_hidden_states)
input_hidden_states = input_hidden_states.detach().requires_grad_(True)
parallel_attention = gpt_model[0].decoder.layers[0].self_attention
output_hidden_states_parallel, bias_hidden_states_parallel = parallel_attention(
input_hidden_states, attention_mask=None
input_hidden_states, attention_mask=None, packed_seq_params=packed_seq_params
)
output_hidden_states_parallel.sum().backward()
input_grad_parallel = input_hidden_states.grad.detach()
Expand Down Expand Up @@ -879,6 +888,7 @@ def get_tensor_on_this_rank(tensor):


# TODO(yuzhongw): Add test case for fallback_to_eager_attn
@pytest.mark.parametrize("sequence_packing", [False, True])
@pytest.mark.parametrize("apply_rope_fusion", [False, True])
@pytest.mark.parametrize(
("tp", "sp", "cp"),
Expand All @@ -893,7 +903,7 @@ def get_tensor_on_this_rank(tensor):
@pytest.mark.parametrize("qk_layernorm", [False, True])
@pytest.mark.parametrize("output_gate", [False, True])
def test_parallel_attention_correctness(
tmp_path_dist_ckpt, apply_rope_fusion, tp, sp, cp, qk_layernorm, output_gate
tmp_path_dist_ckpt, sequence_packing, apply_rope_fusion, tp, sp, cp, qk_layernorm, output_gate
):
transformer_config = TransformerConfig(
num_layers=1,
Expand Down Expand Up @@ -922,6 +932,7 @@ def test_parallel_attention_correctness(
cp=cp,
seed=123,
sequence_length=256,
sequence_packing=sequence_packing,
)


Expand Down