Skip to content

Commit 98b6924

Browse files
committed
Simplify
1 parent 7a06dbf commit 98b6924

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

examples/tokenize_data.py

Lines changed: 21 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -81,20 +81,29 @@ def pack_sequences(
8181
position_buffer = []
8282

8383
for input_ids in batch["input_ids"]:
84-
# Position IDs reset to 0 at the start of each sub-sequence; EOS gets the next position.
85-
seq_positions = list(range(len(input_ids) + 1))
86-
buffer.extend(input_ids)
87-
buffer.append(eos_token_id) # Add EOS at the end of each sequence
84+
# Truncate sequences that individually exceed max_seq_len (including EOS token).
85+
seq_with_eos = (input_ids + [eos_token_id])[:max_seq_len]
86+
# Position IDs reset to 0 at the start of each sub-sequence.
87+
seq_positions = list(range(len(seq_with_eos)))
88+
89+
# If adding this sequence would overflow, flush the current buffer first.
90+
# This ensures every chunk starts at a sequence boundary (position_ids[0] == 0).
91+
if buffer and len(buffer) + len(seq_with_eos) > max_seq_len:
92+
padding_length = max_seq_len - len(buffer)
93+
packed_sequences.append(buffer + [pad_token_id] * padding_length)
94+
packed_position_ids.append(position_buffer + [0] * padding_length)
95+
buffer = []
96+
position_buffer = []
97+
98+
buffer.extend(seq_with_eos)
8899
position_buffer.extend(seq_positions)
89100

90-
# Check if buffer needs to be split into chunks
91-
while len(buffer) > max_seq_len:
92-
# Take a full chunk from the buffer and append it to packed_sequences
93-
packed_sequences.append(buffer[:max_seq_len])
94-
packed_position_ids.append(position_buffer[:max_seq_len])
95-
# Remove the processed chunk from the buffer
96-
buffer = buffer[max_seq_len:]
97-
position_buffer = position_buffer[max_seq_len:]
101+
# Flush immediately if exactly full (no padding needed).
102+
if len(buffer) == max_seq_len:
103+
packed_sequences.append(buffer)
104+
packed_position_ids.append(position_buffer)
105+
buffer = []
106+
position_buffer = []
98107

99108
# Add the last buffer if it's exactly chunk_size
100109
if len(buffer) == max_seq_len:

0 commit comments

Comments
 (0)