63 changes: 5 additions & 58 deletions megatron/rl/sequence_packing_utils.py
@@ -51,7 +51,6 @@ class PackingContext:
original_trajs: All trajectories before packing
packed_trajs: Packed trajectories tensor [num_bins, bin_size]
packed_position_ids: Position IDs for packed sequences [num_bins, bin_size]
packed_attention_mask: Attention mask for packed sequences [num_bins, 1, bin_size, bin_size]
packed_loss_mask: Loss mask for packed sequences [num_bins, bin_size]
original_inference_logprobs: Inference logprobs for all sequences before packing (optional)
bin_advantages: List of advantage tensors for each bin
@@ -64,7 +63,6 @@ class PackingContext:
original_trajs: torch.Tensor
packed_trajs: torch.Tensor
packed_position_ids: torch.Tensor
packed_attention_mask: torch.Tensor
packed_loss_mask: torch.Tensor
original_inference_logprobs: Optional[torch.Tensor] = None
bin_advantages: List[torch.Tensor] = field(default_factory=list)
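After this change, PackingContext is constructed without a packed_attention_mask field. A minimal sketch of the updated construction, showing only the fields visible in this diff (the values are assumed to be in scope, mirroring pack_all_trajectories further down; any other required fields are elided):

context = PackingContext(
    original_trajs=trajs,
    packed_trajs=packed_trajs,
    packed_position_ids=packed_position_ids,
    packed_loss_mask=packed_loss_mask,
    original_inference_logprobs=inference_logprobs,
    bin_advantages=bin_advantages,
)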
@@ -314,9 +312,8 @@ def create_empty_bins(
packed_trajs : torch.Tensor,
packed_position_ids : torch.Tensor,
packed_loss_mask : torch.Tensor,
packed_attention_mask : torch.Tensor,
tokenizer,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
"""Create empty bins for padding to ensure all ranks have the same number of bins.

Args:
@@ -325,19 +322,17 @@ def create_empty_bins(
packed_trajs: Packed trajectories tensor (for dtype/device reference)
packed_position_ids: Packed position IDs tensor (for dtype/device reference)
packed_loss_mask: Packed loss mask tensor (for dtype/device reference)
packed_attention_mask: Packed attention mask tensor (can be None)
tokenizer: Tokenizer for pad token

Returns:
Tuple of (empty_trajs, empty_position_ids, empty_loss_mask, empty_attention_mask, empty_packing_info_entries)
Tuple of (empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info_entries)
"""
device = packed_trajs.device

# Create empty bins with proper shape
empty_bins = []
empty_position_ids_list = []
empty_loss_mask_list = []
empty_attention_mask_list = []
empty_packing_info_entries = []

for i in range(num_empty_bins):
@@ -355,14 +350,6 @@ def create_empty_bins(
empty_loss = torch.zeros(1, bin_size, dtype=packed_loss_mask.dtype, device=device)
empty_loss_mask_list.append(empty_loss)

# Zero attention mask if needed
if packed_attention_mask is not None:
# Attention mask is always 4D: [num_bins, 1, bin_size, bin_size]
empty_attn = torch.zeros(
1, 1, bin_size, bin_size, dtype=packed_attention_mask.dtype, device=device
)
empty_attention_mask_list.append(empty_attn)

# Empty packing info entries
empty_packing_info_entries.append(
{
@@ -376,22 +363,15 @@ def create_empty_bins(
empty_trajs = torch.cat(empty_bins, dim=0)
empty_position_ids = torch.cat(empty_position_ids_list, dim=0)
empty_loss_mask = torch.cat(empty_loss_mask_list, dim=0)
empty_attention_mask = (
torch.cat(empty_attention_mask_list, dim=0)
if packed_attention_mask is not None
else None
)
else:
empty_trajs = None
empty_position_ids = None
empty_loss_mask = None
empty_attention_mask = None

return (
empty_trajs,
empty_position_ids,
empty_loss_mask,
empty_attention_mask,
empty_packing_info_entries,
)
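For reference, a call against the updated signature now unpacks four values instead of five; a minimal sketch mirroring the unit test below (the packed tensors and tokenizer are assumed to be in scope):

empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info = create_empty_bins(
    num_empty_bins=num_empty_bins,
    bin_size=bin_size,
    packed_trajs=packed_trajs,
    packed_position_ids=packed_position_ids,
    packed_loss_mask=packed_loss_mask,
    tokenizer=tokenizer,
)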

@@ -708,9 +688,6 @@ def pack_sequences(
position_ids = torch.zeros(
(num_bins, self.bin_size), dtype=torch.long, device=device, requires_grad=False
)
attention_mask = torch.zeros(
(num_bins, 1, self.bin_size, self.bin_size), dtype=torch.bool, device=device
)
loss_mask = torch.zeros((num_bins, self.bin_size), dtype=torch.float, device=device)

# Track packing information for unpacking later
@@ -741,12 +718,6 @@ def pack_sequences(
len(seq), device=device, requires_grad=False
)

# Causal attention mask within each sequence
seq_len = end - start
attention_mask[bin_idx, 0, start:end, start:end] = torch.tril(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
)

# Loss mask (excluding padding)
loss_mask[bin_idx, start:end] = 1.0

@@ -761,12 +732,6 @@ def pack_sequences(
seq_starts.append(current_pos)
seq_starts_dict[bin_idx] = seq_starts

# Note: We'll store the actual padded length later when we know it
# (it depends on the original trajectories passed to pack_sequences)

# Invert attention mask, before inversion: (True = attend, False = mask)
attention_mask.bitwise_not_()

# Create the PackingInfo dataclass
packing_info = PackingInfo(
bin_seq_indices=bin_seq_indices,
@@ -795,15 +760,14 @@ def pack_sequences(
)
log_single_rank(logger, logging.DEBUG, f" - First 20 bins: {seq_per_bin[:20]}")

return packed_sequences, position_ids, attention_mask, loss_mask, packing_info
return packed_sequences, position_ids, loss_mask, packing_info
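Because pack_sequences no longer builds or returns an attention mask, a consumer that still needs the old inverted block-causal mask can reconstruct it from packing_info and the loss mask. A rough sketch, assuming sequences are packed contiguously from position 0 (as seq_starts implies) and that the loss mask is 1 exactly over packed tokens; rebuild_inverted_block_causal_mask is a hypothetical helper, not part of this module:

import torch

def rebuild_inverted_block_causal_mask(packing_info, packed_loss_mask, bin_size):
    # Reproduce the [num_bins, 1, bin_size, bin_size] mask that pack_sequences
    # used to return: True marks positions to mask out (i.e. already inverted).
    device = packed_loss_mask.device
    num_bins = packed_loss_mask.shape[0]
    mask = torch.zeros(num_bins, 1, bin_size, bin_size, dtype=torch.bool, device=device)
    for bin_idx, starts in packing_info.seq_starts.items():
        used = int(packed_loss_mask[bin_idx].sum().item())  # total packed length in this bin
        bounds = list(starts) + [used]
        for start, end in zip(bounds[:-1], bounds[1:]):
            seq_len = end - start
            # Causal attention within each packed sequence, no cross-sequence attention
            mask[bin_idx, 0, start:end, start:end] = torch.tril(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
            )
    return mask.bitwise_not_()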

def distribute_packed_bins(
packed_trajs: torch.Tensor,
packed_position_ids: torch.Tensor,
packed_attention_mask: torch.Tensor,
packed_loss_mask: torch.Tensor,
packing_info: PackingInfo,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]:
"""Distribute packed bins across the data parallel ranks."""
rank = mpu.get_data_parallel_rank()
world_size = mpu.get_data_parallel_world_size()
@@ -840,7 +804,6 @@ def distribute_packed_bins(
# Extract this rank's bins
my_packed_trajs = []
my_packed_position_ids = []
my_packed_attention_mask = []
my_packed_loss_mask = []
my_bin_seq_indices = []
my_seq_starts = {}
@@ -850,8 +813,6 @@ def distribute_packed_bins(
for new_idx, old_idx in enumerate(my_bin_indices):
my_packed_trajs.append(packed_trajs[old_idx])
my_packed_position_ids.append(packed_position_ids[old_idx])
if packed_attention_mask is not None:
my_packed_attention_mask.append(packed_attention_mask[old_idx])
my_packed_loss_mask.append(packed_loss_mask[old_idx])
my_bin_seq_indices.append(packing_info.bin_seq_indices[old_idx])
my_seq_starts[new_idx] = packing_info.seq_starts[old_idx]
@@ -877,9 +838,6 @@ def distribute_packed_bins(
device=packed_position_ids.device,
)
)
packed_attention_mask = (
torch.stack(my_packed_attention_mask) if my_packed_attention_mask else None
)
packed_loss_mask = (
torch.stack(my_packed_loss_mask)
if my_packed_loss_mask
@@ -937,15 +895,13 @@ def distribute_packed_bins(
empty_trajs,
empty_position_ids,
empty_loss_mask,
empty_attention_mask,
empty_packing_entries,
) = create_empty_bins(
num_empty_bins,
bin_size,
packed_trajs,
packed_position_ids,
packed_loss_mask,
packed_attention_mask,
tokenizer,
)

@@ -956,18 +912,13 @@ def distribute_packed_bins(
)
packed_loss_mask = torch.cat([packed_loss_mask, empty_loss_mask], dim=0)

if packed_attention_mask is not None and empty_attention_mask is not None:
packed_attention_mask = torch.cat(
[packed_attention_mask, empty_attention_mask], dim=0
)

# Add empty entries to packing_info
for i, entry in enumerate(empty_packing_entries):
bin_idx = current_bins + i
new_packing_info.bin_seq_indices.append(entry['bin_seq_indices'])
new_packing_info.seq_starts[bin_idx] = entry['seq_starts']

return packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, new_packing_info
return packed_trajs, packed_position_ids, packed_loss_mask, new_packing_info
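Taken together, the updated packing flow drops the attention mask at every step; a sketch assuming a packer instance whose pack_sequences matches the signature above:

packed_trajs, packed_position_ids, packed_loss_mask, packing_info = packer.pack_sequences(
    trajs, generation_masks
)
packed_trajs, packed_position_ids, packed_loss_mask, packing_info = distribute_packed_bins(
    packed_trajs, packed_position_ids, packed_loss_mask, packing_info
)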


def pack_all_trajectories(trajs, generation_masks, inference_logprobs, global_advantages, bin_size, max_sequences_per_bin, packing_algo):
@@ -1000,7 +951,6 @@ def _gather(data):
(
packed_trajs,
packed_position_ids,
packed_attention_mask,
packed_loss_mask,
packing_info,
) = packer.pack_sequences(trajs, generation_masks)
@@ -1010,13 +960,11 @@ def _gather(data):
(
packed_trajs,
packed_position_ids,
packed_attention_mask,
packed_loss_mask,
packing_info,
) = distribute_packed_bins(
packed_trajs,
packed_position_ids,
packed_attention_mask,
packed_loss_mask,
packing_info,
)
@@ -1053,7 +1001,6 @@ def _gather(data):
original_trajs=trajs,
packed_trajs=packed_trajs,
packed_position_ids=packed_position_ids,
packed_attention_mask=packed_attention_mask,
packed_loss_mask=packed_loss_mask,
original_inference_logprobs=inference_logprobs,
bin_advantages=bin_advantages,
17 changes: 7 additions & 10 deletions tests/unit_tests/rl/test_sequence_packing_utils.py
@@ -98,13 +98,12 @@ def test_sequence_packing_basic():
rewards = torch.tensor([1.0, 2.0, 3.0, 4.0])

sequences_tensor = torch.stack(sequences)
packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
packer.pack_sequences(sequences_tensor, generation_masks)
packed_trajs, packed_position_ids, packed_loss_mask, packing_info = packer.pack_sequences(
sequences_tensor, generation_masks
)

assert packed_trajs is not None
assert packed_position_ids is not None
assert packed_attention_mask is not None
assert packed_loss_mask is not None
assert packing_info is not None

@@ -140,8 +139,8 @@ def test_sequence_packing_with_generation_masks():
)

padded_sequences_tensor = torch.stack(padded_sequences)
packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
packer.pack_sequences(padded_sequences_tensor, generation_masks)
packed_trajs, packed_position_ids, packed_loss_mask, packing_info = packer.pack_sequences(
padded_sequences_tensor, generation_masks
)

assert packed_trajs.shape[0] == 1
@@ -162,16 +161,14 @@ def test_sequence_packing_empty_bins():
)
packed_position_ids = torch.tensor([[0, 1, 2, 3, 0, 0, 0, 0]])
packed_loss_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.float)
packed_attention_mask = torch.ones(1, bin_size, bin_size)

empty_trajs, empty_position_ids, empty_loss_mask, empty_attention_mask, empty_packing_info = (
empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info = (
sequence_packing_utils.create_empty_bins(
num_empty_bins=num_empty_bins,
bin_size=bin_size,
packed_trajs=packed_trajs,
packed_position_ids=packed_position_ids,
packed_loss_mask=packed_loss_mask,
packed_attention_mask=packed_attention_mask,
tokenizer=tokenizer,
)
)
@@ -220,8 +217,8 @@ def test_sequence_packing_integration():
]

sequences_tensor = torch.stack(sequences)
packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
packer.pack_sequences(sequences_tensor, generation_masks)
packed_trajs, packed_position_ids, packed_loss_mask, packing_info = packer.pack_sequences(
sequences_tensor, generation_masks
)

assert packed_trajs is not None