@@ -81,20 +81,29 @@ def pack_sequences(
8181 position_buffer = []
8282
8383 for input_ids in batch ["input_ids" ]:
84- # Position IDs reset to 0 at the start of each sub-sequence; EOS gets the next position.
85- seq_positions = list (range (len (input_ids ) + 1 ))
86- buffer .extend (input_ids )
87- buffer .append (eos_token_id ) # Add EOS at the end of each sequence
84+ # Truncate sequences that individually exceed max_seq_len (including EOS token).
85+ seq_with_eos = (input_ids + [eos_token_id ])[:max_seq_len ]
86+ # Position IDs reset to 0 at the start of each sub-sequence.
87+ seq_positions = list (range (len (seq_with_eos )))
88+
89+ # If adding this sequence would overflow, flush the current buffer first.
90+ # This ensures every chunk starts at a sequence boundary (position_ids[0] == 0).
91+ if buffer and len (buffer ) + len (seq_with_eos ) > max_seq_len :
92+ padding_length = max_seq_len - len (buffer )
93+ packed_sequences .append (buffer + [pad_token_id ] * padding_length )
94+ packed_position_ids .append (position_buffer + [0 ] * padding_length )
95+ buffer = []
96+ position_buffer = []
97+
98+ buffer .extend (seq_with_eos )
8899 position_buffer .extend (seq_positions )
89100
90- # Check if buffer needs to be split into chunks
91- while len (buffer ) > max_seq_len :
92- # Take a full chunk from the buffer and append it to packed_sequences
93- packed_sequences .append (buffer [:max_seq_len ])
94- packed_position_ids .append (position_buffer [:max_seq_len ])
95- # Remove the processed chunk from the buffer
96- buffer = buffer [max_seq_len :]
97- position_buffer = position_buffer [max_seq_len :]
101+ # Flush immediately if exactly full (no padding needed).
102+ if len (buffer ) == max_seq_len :
103+ packed_sequences .append (buffer )
104+ packed_position_ids .append (position_buffer )
105+ buffer = []
106+ position_buffer = []
98107
99108 # Add the last buffer if it's exactly chunk_size
100109 if len (buffer ) == max_seq_len :
0 commit comments