@@ -76,30 +76,40 @@ def pack_sequences(
7676 ['▁toys', '▁.', '</s>', '<s>', '▁but', '▁just', '▁one', '▁look']
7777 """
7878 packed_sequences = []
79+ packed_position_ids = []
7980 buffer = []
81+ position_buffer = []
8082
8183 for input_ids in batch ["input_ids" ]:
82- # Add the current sequence to the buffer
84+ # Position IDs reset to 0 at the start of each sub-sequence; EOS gets the next position.
85+ seq_positions = list (range (len (input_ids ) + 1 ))
8386 buffer .extend (input_ids )
8487 buffer .append (eos_token_id ) # Add EOS at the end of each sequence
88+ position_buffer .extend (seq_positions )
8589
8690 # Check if buffer needs to be split into chunks
8791 while len (buffer ) > max_seq_len :
8892 # Take a full chunk from the buffer and append it to packed_sequences
8993 packed_sequences .append (buffer [:max_seq_len ])
94+ packed_position_ids .append (position_buffer [:max_seq_len ])
9095 # Remove the processed chunk from the buffer
9196 buffer = buffer [max_seq_len :]
97+ position_buffer = position_buffer [max_seq_len :]
9298
9399 # Add the last buffer if it's exactly chunk_size
94100 if len (buffer ) == max_seq_len :
95101 packed_sequences .append (buffer )
102+ packed_position_ids .append (position_buffer )
96103 elif len (buffer ) > cutoff_size :
97104 # if the buffer is larger than the cutoff size, pad it to the chunk_size
98105 # if not, we do not include in the packed_sequences
99- buffer .extend ([pad_token_id ] * (max_seq_len - len (buffer )))
106+ padding_length = max_seq_len - len (buffer )
107+ buffer .extend ([pad_token_id ] * padding_length )
108+ position_buffer .extend ([0 ] * padding_length )
100109 packed_sequences .append (buffer )
110+ packed_position_ids .append (position_buffer )
101111
102- output = {"input_ids" : packed_sequences }
112+ output = {"input_ids" : packed_sequences , "position_ids" : packed_position_ids }
103113 if add_labels :
104114 output ["labels" ] = [
105115 [
@@ -109,7 +119,6 @@ def pack_sequences(
109119 for example in output ["input_ids" ]
110120 ]
111121
112- # mask attention for padding tokens, a better version would also mask cross-sequence dependencies
113122 output ["attention_mask" ] = [
114123 [0 if token_id == pad_token_id else 1 for token_id in example ]
115124 for example in output ["input_ids" ]
0 commit comments