Skip to content

Commit 7a06dbf

Browse files
committed
Add position ids
1 parent 35fd835 commit 7a06dbf

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

examples/tokenize_data.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -76,30 +76,40 @@ def pack_sequences(
7676
['▁toys', '▁.', '</s>', '<s>', '▁but', '▁just', '▁one', '▁look']
7777
"""
7878
packed_sequences = []
79+
packed_position_ids = []
7980
buffer = []
81+
position_buffer = []
8082

8183
for input_ids in batch["input_ids"]:
82-
# Add the current sequence to the buffer
84+
# Position IDs reset to 0 at the start of each sub-sequence; EOS gets the next position.
85+
seq_positions = list(range(len(input_ids) + 1))
8386
buffer.extend(input_ids)
8487
buffer.append(eos_token_id) # Add EOS at the end of each sequence
88+
position_buffer.extend(seq_positions)
8589

8690
# Check if buffer needs to be split into chunks
8791
while len(buffer) > max_seq_len:
8892
# Take a full chunk from the buffer and append it to packed_sequences
8993
packed_sequences.append(buffer[:max_seq_len])
94+
packed_position_ids.append(position_buffer[:max_seq_len])
9095
# Remove the processed chunk from the buffer
9196
buffer = buffer[max_seq_len:]
97+
position_buffer = position_buffer[max_seq_len:]
9298

9399
# Add the last buffer if it's exactly chunk_size
94100
if len(buffer) == max_seq_len:
95101
packed_sequences.append(buffer)
102+
packed_position_ids.append(position_buffer)
96103
elif len(buffer) > cutoff_size:
97104
# if the buffer is larger than the cutoff size, pad it to the chunk_size
98105
# if not, we do not include in the packed_sequences
99-
buffer.extend([pad_token_id] * (max_seq_len - len(buffer)))
106+
padding_length = max_seq_len - len(buffer)
107+
buffer.extend([pad_token_id] * padding_length)
108+
position_buffer.extend([0] * padding_length)
100109
packed_sequences.append(buffer)
110+
packed_position_ids.append(position_buffer)
101111

102-
output = {"input_ids": packed_sequences}
112+
output = {"input_ids": packed_sequences, "position_ids": packed_position_ids}
103113
if add_labels:
104114
output["labels"] = [
105115
[
@@ -109,7 +119,6 @@ def pack_sequences(
109119
for example in output["input_ids"]
110120
]
111121

112-
# mask attention for padding tokens, a better version would also mask cross-sequence dependencies
113122
output["attention_mask"] = [
114123
[0 if token_id == pad_token_id else 1 for token_id in example]
115124
for example in output["input_ids"]

0 commit comments

Comments (0)