diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 601a72a4356..d347697bae6 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -296,8 +296,33 @@ def forward( raise NotImplementedError("GDN does not support inference for now.") if packed_seq_params is not None: - # TODO: support packed sequence - raise NotImplementedError("GDN does not support packed sequence for now.") + assert batch == 1, "Packed sequence expects batch dimension to be 1" + assert ( + not self.config.deterministic_mode + ), "Packed sequence does not support deterministic mode." + + # Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + # Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + assert torch.equal(cu_seqlens_q, cu_seqlens_kv), ( + "Currently only support cu_seqlens_q equals to cu_seqlens_kv, " + f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}" + ) + num_packed_seqs = cu_seqlens_q.shape[0] - 1 + assert num_packed_seqs > 0, ( + "Number of packed sequences must be greater than 0, " + f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}" + ) + else: + cu_seqlens_q = None + cu_seqlens_kv = None # Input projection nvtx_range_push(suffix="in_proj") @@ -305,20 +330,41 @@ def forward( nvtx_range_pop(suffix="in_proj") # CP All to All: CP to HP - qkvzba = tensor_a2a_cp2hp( - qkvzba, - seq_dim=0, - head_dim=-1, - cp_group=self.pg_collection.cp, - split_sections=[ - self.qk_dim_local_tp, - self.qk_dim_local_tp, - self.v_dim_local_tp, - self.v_dim_local_tp, - self.num_value_heads // self.tp_size, - self.num_value_heads // self.tp_size, - ], - ) + if packed_seq_params is not None: + unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0) + outputs = [] + for qkvzba_i in unpacked_qkvzba: + qkvzba_i = tensor_a2a_cp2hp( + qkvzba_i, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + split_sections=[ + self.qk_dim_local_tp, + self.qk_dim_local_tp, + self.v_dim_local_tp, + self.v_dim_local_tp, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], + ) + outputs.append(qkvzba_i) + qkvzba = torch.cat(outputs, dim=0) + else: + qkvzba = tensor_a2a_cp2hp( + qkvzba, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + split_sections=[ + self.qk_dim_local_tp, + self.qk_dim_local_tp, + self.v_dim_local_tp, + self.v_dim_local_tp, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], + ) # Transpose: s b x --> b s x # From sbhd to bshd format @@ -385,6 +431,7 @@ def forward( activation=self.activation, initial_state=None, output_final_state=False, + cu_seqlens=cu_seqlens_q, ) nvtx_range_pop(suffix="conv1d") @@ -440,6 +487,7 @@ def forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=False, + cu_seqlens=cu_seqlens_q, ) nvtx_range_pop(suffix="gated_delta_rule") @@ -454,9 +502,19 @@ def forward( norm_out = norm_out.transpose(0, 1).contiguous() # CP all to all: HP to CP - norm_out = tensor_a2a_hp2cp( - norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp - ) + if packed_seq_params is not None: + unpacked_norm_out = _unpack_sequence(norm_out, 
cu_seqlens_q, dim=0) + outputs = [] + for norm_out_i in unpacked_norm_out: + norm_out_i = tensor_a2a_hp2cp( + norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp + ) + outputs.append(norm_out_i) + norm_out = torch.cat(outputs, dim=0) + else: + norm_out = tensor_a2a_hp2cp( + norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp + ) # Output projection nvtx_range_push(suffix="out_proj") @@ -575,6 +633,17 @@ def _backward_out_proj(self): self.out_proj.backward_dw() +def _unpack_sequence(x, cu_seqlens, dim=1): + unpacked_x = [] + num_seqs = cu_seqlens.shape[0] - 1 + for i in range(num_seqs): + idx_start = cu_seqlens[i].item() + idx_end = cu_seqlens[i + 1].item() + chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)] + unpacked_x.append(x[tuple(chunked_index)]) + return unpacked_x + + #################### # Sharded state dict utilities #################### @@ -826,6 +895,7 @@ def torch_chunk_gated_delta_rule( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=False, + cu_seqlens=None, ): # pylint: disable=line-too-long ''' @@ -835,6 +905,10 @@ def torch_chunk_gated_delta_rule( Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547 ''' + assert cu_seqlens is None, ( + "cu_seqlens is not supported for torch_chunk_gated_delta_rule for now." + ) + initial_dtype = query.dtype if use_qk_l2norm_in_kernel: query = l2norm(query, dim=-1, eps=1e-6) diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py index 81f8eed0574..e3f7b0c4f20 100644 --- a/tests/unit_tests/ssm/test_gated_delta_net.py +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -32,6 +32,7 @@ ) from tests.unit_tests.test_utilities import Utils from tests.unit_tests.transformer.test_attention import _test_parallel_attention_correctness +from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params try: import fla @@ -138,7 +139,51 @@ def test_gpu_forward(self): output.dtype == hidden_states.dtype ), f"Output dtype {output.dtype=} mismatch with {hidden_states.dtype=}" + def test_gpu_forward_thd_correctness(self): + if self.sp_size > 1: + pytest.skip("Sequence parallel is not supported for this test case.") + atol, rtol = 3e-4, 3e-4 + + # Input shape + sequence_length = 32 + micro_batch_size = 4 + cu_seqlens = [0, 32, 64, 96, 128] + # sbhd input shape: [sequence length, batch size, hidden size] + sub_sequence_length = sequence_length // self.cp_size + hidden_states_sbhd = torch.rand( + (sub_sequence_length, micro_batch_size, self.gdn.config.hidden_size) + ) + attention_mask_sbhd = None + hidden_states_sbhd = hidden_states_sbhd.cuda().bfloat16() + # thd input shape: [sequence length * batch size, 1, hidden size] + hidden_states_thd = hidden_states_sbhd.transpose(0, 1).contiguous() + hidden_states_thd = hidden_states_thd.view(-1, 1, self.gdn.config.hidden_size) + attention_mask_thd = None + packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens) + + # THD format + output_thd, _ = self.gdn( + hidden_states_thd, attention_mask_thd, packed_seq_params=packed_seq_params + ) + # SBHD format + output_sbhd, _ = self.gdn(hidden_states_sbhd, attention_mask_sbhd) + output_sbhd_T = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape) + + rank = torch.distributed.get_rank() + assert output_thd.shape[0] == sub_sequence_length * micro_batch_size + assert output_thd.shape[1] == 1 + 
assert output_thd.shape[2] == self.gdn.config.hidden_size + torch.testing.assert_close( + output_sbhd_T, + output_thd, + atol=atol, + rtol=rtol, + msg=lambda msg: f"Output mismatch ({rank=}): {msg}", + ) + + +@pytest.mark.parametrize("sequence_packing", [False, True]) @pytest.mark.parametrize( ("tp", "sp", "cp"), [ @@ -150,7 +195,7 @@ def test_gpu_forward(self): ], ) @pytest.mark.skipif(not HAVE_FLA, reason="FLA is not installed.") -def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp): +def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, sequence_packing, tp, sp, cp): transformer_config = TransformerConfig( hidden_size=128, linear_conv_kernel_dim=2, @@ -191,4 +236,5 @@ def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp): seed=123, sequence_length=256, micro_batch_size=4, + sequence_packing=sequence_packing, ) diff --git a/tests/unit_tests/transformer/test_attention.py b/tests/unit_tests/transformer/test_attention.py index 0fbc6b4da23..d760b314c0a 100644 --- a/tests/unit_tests/transformer/test_attention.py +++ b/tests/unit_tests/transformer/test_attention.py @@ -40,6 +40,7 @@ init_checkpointing_mock_args, ) from tests.unit_tests.test_utilities import Utils +from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params try: from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb @@ -710,6 +711,7 @@ def _test_parallel_attention_correctness( seed=123, sequence_length=256, micro_batch_size=4, + sequence_packing=False, ): # Model initialization function def initialize_gpt_model( @@ -803,17 +805,24 @@ def initialize_gpt_model( def get_tensor_on_this_rank(tensor): if cp > 1: tensor = get_tensor_on_this_cp_rank(tensor, 0, cp_group) + if sequence_packing: + tensor = tensor.transpose(0, 1).contiguous().view(-1, 1, *tensor.shape[2:]) if tp > 1 and sp: - sp_seg = sequence_length // tp // cp + sp_seg = tensor.shape[0] // tp tensor = tensor[tp_rank * sp_seg : (tp_rank + 1) * sp_seg] return tensor # Calculate parallel model output + if sequence_packing: + cu_seqlens = [i * sequence_length for i in range(micro_batch_size + 1)] + packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens) + else: + packed_seq_params = None input_hidden_states = get_tensor_on_this_rank(input_hidden_states) input_hidden_states = input_hidden_states.detach().requires_grad_(True) parallel_attention = gpt_model[0].decoder.layers[0].self_attention output_hidden_states_parallel, bias_hidden_states_parallel = parallel_attention( - input_hidden_states, attention_mask=None + input_hidden_states, attention_mask=None, packed_seq_params=packed_seq_params ) output_hidden_states_parallel.sum().backward() input_grad_parallel = input_hidden_states.grad.detach() @@ -879,6 +888,7 @@ def get_tensor_on_this_rank(tensor): # TODO(yuzhongw): Add test case for fallback_to_eager_attn +@pytest.mark.parametrize("sequence_packing", [False, True]) @pytest.mark.parametrize("apply_rope_fusion", [False, True]) @pytest.mark.parametrize( ("tp", "sp", "cp"), @@ -893,7 +903,7 @@ def get_tensor_on_this_rank(tensor): @pytest.mark.parametrize("qk_layernorm", [False, True]) @pytest.mark.parametrize("output_gate", [False, True]) def test_parallel_attention_correctness( - tmp_path_dist_ckpt, apply_rope_fusion, tp, sp, cp, qk_layernorm, output_gate + tmp_path_dist_ckpt, sequence_packing, apply_rope_fusion, tp, sp, cp, qk_layernorm, output_gate ): transformer_config = TransformerConfig( num_layers=1, @@ -922,6 +932,7 @@ def 
test_parallel_attention_correctness( cp=cp, seed=123, sequence_length=256, + sequence_packing=sequence_packing, )
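
Note on the packed-sequence path added above: the new `_unpack_sequence` helper slices the packed (THD) tensor at the `cu_seqlens` boundaries so the CP all-to-all can be applied to each sub-sequence independently, after which the chunks are concatenated back into a single packed tensor. Below is a minimal, standalone sketch of that split/re-pack round trip; the `unpack_sequence` name, the shapes, and the assertions are illustrative assumptions, not part of the patch.

import torch

def unpack_sequence(x, cu_seqlens, dim=0):
    # Slice a packed tensor into per-sequence chunks along `dim`,
    # using cumulative sequence-length boundaries (cu_seqlens).
    chunks = []
    for i in range(cu_seqlens.shape[0] - 1):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        index = [slice(None)] * dim + [slice(start, end)]
        chunks.append(x[tuple(index)])
    return chunks

# Four sequences of 32 tokens each, packed into one THD tensor of shape
# [total_tokens, 1, hidden_size], mirroring the new unit test.
hidden_size = 128
cu_seqlens = torch.tensor([0, 32, 64, 96, 128])
packed = torch.rand(int(cu_seqlens[-1]), 1, hidden_size)

# Split at the cu_seqlens boundaries (the patch applies the CP<->HP
# all-to-all per chunk at this point), then concatenate back into packed form.
chunks = unpack_sequence(packed, cu_seqlens, dim=0)
assert [c.shape[0] for c in chunks] == [32, 32, 32, 32]
repacked = torch.cat(chunks, dim=0)
assert torch.equal(repacked, packed)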