From 385cb4948cd5e85cce58d9d05c7da3c8b50146d7 Mon Sep 17 00:00:00 2001
From: Teodor-Dumitru Ene
Date: Fri, 13 Mar 2026 11:06:21 -0500
Subject: [PATCH 1/2] Remove packed_attention_mask unused parameter

---
 megatron/rl/sequence_packing_utils.py | 63 ++-----------------
 .../rl/test_sequence_packing_utils.py | 11 ++--
 2 files changed, 9 insertions(+), 65 deletions(-)

diff --git a/megatron/rl/sequence_packing_utils.py b/megatron/rl/sequence_packing_utils.py
index 1285d0926f8..89efbc1ee1a 100644
--- a/megatron/rl/sequence_packing_utils.py
+++ b/megatron/rl/sequence_packing_utils.py
@@ -51,7 +51,6 @@ class PackingContext:
         original_trajs: All trajectories before packing
         packed_trajs: Packed trajectories tensor [num_bins, bin_size]
         packed_position_ids: Position IDs for packed sequences [num_bins, bin_size]
-        packed_attention_mask: Attention mask for packed sequences [num_bins, 1, bin_size, bin_size]
         packed_loss_mask: Loss mask for packed sequences [num_bins, bin_size]
         original_inference_logprobs: Inference logprobs for all sequences before packing (optional)
         bin_advantages: List of advantage tensors for each bin
@@ -64,7 +63,6 @@ class PackingContext:
     original_trajs: torch.Tensor
     packed_trajs: torch.Tensor
     packed_position_ids: torch.Tensor
-    packed_attention_mask: torch.Tensor
     packed_loss_mask: torch.Tensor
     original_inference_logprobs: Optional[torch.Tensor] = None
     bin_advantages: List[torch.Tensor] = field(default_factory=list)
@@ -314,9 +312,8 @@ def create_empty_bins(
     packed_trajs : torch.Tensor,
     packed_position_ids : torch.Tensor,
     packed_loss_mask : torch.Tensor,
-    packed_attention_mask : torch.Tensor,
     tokenizer,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
     """Create empty bins for padding to ensure all ranks have the same number of bins.
 
     Args:
@@ -325,11 +322,10 @@ def create_empty_bins(
         packed_trajs: Packed trajectories tensor (for dtype/device reference)
         packed_position_ids: Packed position IDs tensor (for dtype/device reference)
         packed_loss_mask: Packed loss mask tensor (for dtype/device reference)
-        packed_attention_mask: Packed attention mask tensor (can be None)
         tokenizer: Tokenizer for pad token
 
     Returns:
-        Tuple of (empty_trajs, empty_position_ids, empty_loss_mask, empty_attention_mask, empty_packing_info_entries)
+        Tuple of (empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info_entries)
     """
 
     device = packed_trajs.device
@@ -337,7 +333,6 @@ def create_empty_bins(
     empty_bins = []
     empty_position_ids_list = []
     empty_loss_mask_list = []
-    empty_attention_mask_list = []
     empty_packing_info_entries = []
 
     for i in range(num_empty_bins):
@@ -355,14 +350,6 @@ def create_empty_bins(
         empty_loss = torch.zeros(1, bin_size, dtype=packed_loss_mask.dtype, device=device)
         empty_loss_mask_list.append(empty_loss)
 
-        # Zero attention mask if needed
-        if packed_attention_mask is not None:
-            # Attention mask is always 4D: [num_bins, 1, bin_size, bin_size]
-            empty_attn = torch.zeros(
-                1, 1, bin_size, bin_size, dtype=packed_attention_mask.dtype, device=device
-            )
-            empty_attention_mask_list.append(empty_attn)
-
         # Empty packing info entries
         empty_packing_info_entries.append(
             {
@@ -376,22 +363,15 @@ def create_empty_bins(
         empty_trajs = torch.cat(empty_bins, dim=0)
         empty_position_ids = torch.cat(empty_position_ids_list, dim=0)
         empty_loss_mask = torch.cat(empty_loss_mask_list, dim=0)
-        empty_attention_mask = (
-            torch.cat(empty_attention_mask_list, dim=0)
-            if packed_attention_mask is not None
-            else None
-        )
     else:
         empty_trajs = None
         empty_position_ids = None
         empty_loss_mask = None
-        empty_attention_mask = None
 
     return (
         empty_trajs,
         empty_position_ids,
         empty_loss_mask,
-        empty_attention_mask,
         empty_packing_info_entries,
     )
 
@@ -708,9 +688,6 @@ def pack_sequences(
         position_ids = torch.zeros(
            (num_bins, self.bin_size), dtype=torch.long, device=device, requires_grad=False
        )
-        attention_mask = torch.zeros(
-            (num_bins, 1, self.bin_size, self.bin_size), dtype=torch.bool, device=device
-        )
         loss_mask = torch.zeros((num_bins, self.bin_size), dtype=torch.float, device=device)
 
         # Track packing information for unpacking later
@@ -741,12 +718,6 @@ def pack_sequences(
                     len(seq), device=device, requires_grad=False
                 )
 
-                # Causal attention mask within each sequence
-                seq_len = end - start
-                attention_mask[bin_idx, 0, start:end, start:end] = torch.tril(
-                    torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
-                )
-
                 # Loss mask (excluding padding)
                 loss_mask[bin_idx, start:end] = 1.0
 
@@ -761,12 +732,6 @@ def pack_sequences(
                 seq_starts.append(current_pos)
             seq_starts_dict[bin_idx] = seq_starts
 
-        # Note: We'll store the actual padded length later when we know it
-        # (it depends on the original trajectories passed to pack_sequences)
-
-        # Invert attention mask, before inversion: (True = attend, False = mask)
-        attention_mask.bitwise_not_()
-
         # Create the PackingInfo dataclass
         packing_info = PackingInfo(
             bin_seq_indices=bin_seq_indices,
@@ -795,15 +760,14 @@ def pack_sequences(
             )
             log_single_rank(logger, logging.DEBUG, f" - First 20 bins: {seq_per_bin[:20]}")
 
-        return packed_sequences, position_ids, attention_mask, loss_mask, packing_info
+        return packed_sequences, position_ids, loss_mask, packing_info
 
 def distribute_packed_bins(
     packed_trajs: torch.Tensor,
     packed_position_ids: torch.Tensor,
-    packed_attention_mask: torch.Tensor,
     packed_loss_mask: torch.Tensor,
     packing_info: PackingInfo,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]:
     """Distribute packed bins across the data parallel ranks."""
     rank = mpu.get_data_parallel_rank()
     world_size = mpu.get_data_parallel_world_size()
@@ -840,7 +804,6 @@ def distribute_packed_bins(
     # Extract this rank's bins
     my_packed_trajs = []
     my_packed_position_ids = []
-    my_packed_attention_mask = []
     my_packed_loss_mask = []
     my_bin_seq_indices = []
     my_seq_starts = {}
@@ -850,8 +813,6 @@ def distribute_packed_bins(
     for new_idx, old_idx in enumerate(my_bin_indices):
         my_packed_trajs.append(packed_trajs[old_idx])
         my_packed_position_ids.append(packed_position_ids[old_idx])
-        if packed_attention_mask is not None:
-            my_packed_attention_mask.append(packed_attention_mask[old_idx])
         my_packed_loss_mask.append(packed_loss_mask[old_idx])
         my_bin_seq_indices.append(packing_info.bin_seq_indices[old_idx])
         my_seq_starts[new_idx] = packing_info.seq_starts[old_idx]
@@ -877,9 +838,6 @@ def distribute_packed_bins(
             device=packed_position_ids.device,
         )
     )
-    packed_attention_mask = (
-        torch.stack(my_packed_attention_mask) if my_packed_attention_mask else None
-    )
     packed_loss_mask = (
         torch.stack(my_packed_loss_mask)
         if my_packed_loss_mask
@@ -937,7 +895,6 @@ def distribute_packed_bins(
            empty_trajs,
            empty_position_ids,
            empty_loss_mask,
-            empty_attention_mask,
            empty_packing_entries,
        ) = create_empty_bins(
            num_empty_bins,
@@ -945,7 +902,6 @@ def distribute_packed_bins(
            packed_trajs,
            packed_position_ids,
            packed_loss_mask,
-            packed_attention_mask,
            tokenizer,
        )
 
@@ -956,18 +912,13 @@ def distribute_packed_bins(
        )
        packed_loss_mask = torch.cat([packed_loss_mask, empty_loss_mask], dim=0)
 
-        if packed_attention_mask is not None and empty_attention_mask is not None:
-            packed_attention_mask = torch.cat(
-                [packed_attention_mask, empty_attention_mask], dim=0
-            )
-
        # Add empty entries to packing_info
        for i, entry in enumerate(empty_packing_entries):
            bin_idx = current_bins + i
            new_packing_info.bin_seq_indices.append(entry['bin_seq_indices'])
            new_packing_info.seq_starts[bin_idx] = entry['seq_starts']
 
-    return packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, new_packing_info
+    return packed_trajs, packed_position_ids, packed_loss_mask, new_packing_info
 
 
 def pack_all_trajectories(trajs, generation_masks, inference_logprobs, global_advantages, bin_size, max_sequences_per_bin, packing_algo):
@@ -1000,7 +951,6 @@ def _gather(data):
    (
        packed_trajs,
        packed_position_ids,
-        packed_attention_mask,
        packed_loss_mask,
        packing_info,
    ) = packer.pack_sequences(trajs, generation_masks)
@@ -1010,13 +960,11 @@ def _gather(data):
        (
            packed_trajs,
            packed_position_ids,
-            packed_attention_mask,
            packed_loss_mask,
            packing_info,
        ) = distribute_packed_bins(
            packed_trajs,
            packed_position_ids,
-            packed_attention_mask,
            packed_loss_mask,
            packing_info,
        )
@@ -1053,7 +1001,6 @@ def _gather(data):
        original_trajs=trajs,
        packed_trajs=packed_trajs,
        packed_position_ids=packed_position_ids,
-        packed_attention_mask=packed_attention_mask,
        packed_loss_mask=packed_loss_mask,
        original_inference_logprobs=inference_logprobs,
        bin_advantages=bin_advantages,
diff --git a/tests/unit_tests/rl/test_sequence_packing_utils.py b/tests/unit_tests/rl/test_sequence_packing_utils.py
index 06e63adf217..75a7981457d 100644
--- a/tests/unit_tests/rl/test_sequence_packing_utils.py
+++ b/tests/unit_tests/rl/test_sequence_packing_utils.py
@@ -98,13 +98,12 @@ def test_sequence_packing_basic():
     rewards = torch.tensor([1.0, 2.0, 3.0, 4.0])
     sequences_tensor = torch.stack(sequences)
 
-    packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
+    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
         packer.pack_sequences(sequences_tensor, generation_masks)
     )
 
     assert packed_trajs is not None
     assert packed_position_ids is not None
-    assert packed_attention_mask is not None
     assert packed_loss_mask is not None
     assert packing_info is not None
 
@@ -140,7 +139,7 @@ def test_sequence_packing_with_generation_masks():
     )
     padded_sequences_tensor = torch.stack(padded_sequences)
 
-    packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
+    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
         packer.pack_sequences(padded_sequences_tensor, generation_masks)
     )
 
@@ -162,16 +161,14 @@ def test_sequence_packing_empty_bins():
     )
     packed_position_ids = torch.tensor([[0, 1, 2, 3, 0, 0, 0, 0]])
     packed_loss_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.float)
-    packed_attention_mask = torch.ones(1, bin_size, bin_size)
 
-    empty_trajs, empty_position_ids, empty_loss_mask, empty_attention_mask, empty_packing_info = (
+    empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info = (
         sequence_packing_utils.create_empty_bins(
             num_empty_bins=num_empty_bins,
             bin_size=bin_size,
             packed_trajs=packed_trajs,
             packed_position_ids=packed_position_ids,
             packed_loss_mask=packed_loss_mask,
-            packed_attention_mask=packed_attention_mask,
             tokenizer=tokenizer,
         )
     )
@@ -220,7 +217,7 @@ def test_sequence_packing_integration():
     ]
     sequences_tensor = torch.stack(sequences)
 
-    packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
+    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
         packer.pack_sequences(sequences_tensor, generation_masks)
     )
 

From 741ed8ffb02ff1268ef1a779aae93d10aa000adc Mon Sep 17 00:00:00 2001
From: Teodor-Dumitru Ene
Date: Sun, 15 Mar 2026 05:36:21 -0500
Subject: [PATCH 2/2] lint

---
 tests/unit_tests/rl/test_sequence_packing_utils.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/unit_tests/rl/test_sequence_packing_utils.py b/tests/unit_tests/rl/test_sequence_packing_utils.py
index 75a7981457d..0a4f4cd03f9 100644
--- a/tests/unit_tests/rl/test_sequence_packing_utils.py
+++ b/tests/unit_tests/rl/test_sequence_packing_utils.py
@@ -98,8 +98,8 @@ def test_sequence_packing_basic():
     rewards = torch.tensor([1.0, 2.0, 3.0, 4.0])
     sequences_tensor = torch.stack(sequences)
 
-    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
-        packer.pack_sequences(sequences_tensor, generation_masks)
+    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = packer.pack_sequences(
+        sequences_tensor, generation_masks
     )
     assert packed_trajs is not None
     assert packed_position_ids is not None
@@ -139,8 +139,8 @@ def test_sequence_packing_with_generation_masks():
     )
     padded_sequences_tensor = torch.stack(padded_sequences)
 
-    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
-        packer.pack_sequences(padded_sequences_tensor, generation_masks)
+    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = packer.pack_sequences(
+        padded_sequences_tensor, generation_masks
     )
 
     assert packed_trajs.shape[0] == 1
@@ -217,8 +217,8 @@ def test_sequence_packing_integration():
     ]
     sequences_tensor = torch.stack(sequences)
 
-    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
-        packer.pack_sequences(sequences_tensor, generation_masks)
+    packed_trajs, packed_position_ids, packed_loss_mask, packing_info = packer.pack_sequences(
+        sequences_tensor, generation_masks
     )
 
     assert packed_trajs is not None