
Commit adebe02
Author: Swati Allabadi
Adding dataset padding changes
Signed-off-by: Swati Allabadi <[email protected]>
Parent: 740f7c2

3 files changed: +36 additions, -12 deletions

QEfficient/finetune/data/sampler.py (3 additions, 6 deletions)

@@ -4,7 +4,6 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-
 import random
 from itertools import islice

@@ -22,17 +21,15 @@ def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool
         self.batch_size = batch_size
         self.drop_last = drop_last
         self.shuffle = shuffle
+        self.data_source = data_source

     def __iter__(self):
-        ids = np.argsort(self.lengths, kind="mergesort")
+        ids = [i for i in range(len(self.data_source))]
         if self.drop_last:
             ids = ids[: len(ids) // self.batch_size * self.batch_size]

         batches = [ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)]

-        if self.shuffle:
-            random.shuffle(batches)
-
         for b in batches:
             yield b

@@ -49,7 +46,7 @@ def __init__(
     ) -> None:
         random.seed(seed)
         self.batch_sampler = LengthBasedBatchSampler(
-            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            data_source, batch_size=batch_size, drop_last=False, shuffle=shuffle
         )
         self.num_replicas = num_replicas
         self.rank = rank
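
With this change, LengthBasedBatchSampler no longer sorts indices by sample length or shuffles the resulting batches: ids simply follow dataset order, and the distributed wrapper now passes drop_last=False so the trailing partial batch is kept. A minimal standalone sketch of the new batching behaviour (sequential_batches is an illustrative helper, not part of the repository):

def sequential_batches(num_samples, batch_size, drop_last=False):
    # Indices follow dataset order; no length sort, no shuffling of batches.
    ids = list(range(num_samples))
    if drop_last:
        ids = ids[: len(ids) // batch_size * batch_size]
    return [ids[i : i + batch_size] for i in range(0, len(ids), batch_size)]

print(sequential_batches(10, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- tail kept when drop_last=False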

QEfficient/finetune/utils/dataset_utils.py (20 additions, 4 deletions)

@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-
+import datasets
 import torch
 import torch.distributed as dist
 from transformers.data import DataCollatorForSeq2Seq

@@ -64,19 +64,35 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):

 def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
     dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)
-    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
+    dataset = dataset.select(range(0, 10))
+    dataset = dataset.map(lambda x: {"input_length": len(x["input_ids"])})
+    dataset = dataset.sort("input_length")
+    dataset = dataset.remove_columns("input_length")
+    dummy_row = next(iter(dataset))
+    dummy_row["labels"] = [-100] * len(dummy_row["labels"])
+    padding_size = 0
+    num_replicas = dist.get_world_size()
+    if len(dataset) % num_replicas > 0:
+        padding_size = num_replicas - len(dataset) % num_replicas
+
+    dummy_data = [dummy_row.copy() for _ in range(padding_size)]
+    dummy_dataset = datasets.Dataset.from_list(dummy_data)
+    combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
+
+    dl_kwargs = get_dataloader_kwargs(train_config, combined_dataset, tokenizer, split)

     # FIXME (Meet): Add custom data collator registration from the outside by the user.
     custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)
+
     if custom_data_collator:
         print("custom_data_collator is used")
         dl_kwargs["collate_fn"] = custom_data_collator

     print(f"length of dataset_{split}", len(dataset))
-
     # Create data loader
+
     dataloader = torch.utils.data.DataLoader(
-        dataset,
+        combined_dataset,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
         **dl_kwargs,
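
get_dataloader now sorts the preprocessed samples by input length and then pads the dataset with copies of one row whose labels are all -100, so that the sample count divides evenly across DDP ranks. A condensed sketch of that padding step, assuming the Hugging Face datasets package (pad_to_world_size and world_size are illustrative names; the commit calls dist.get_world_size() inline):

import datasets

def pad_to_world_size(dataset: datasets.Dataset, world_size: int) -> datasets.Dataset:
    remainder = len(dataset) % world_size
    if remainder == 0:
        return dataset  # already divisible across ranks, nothing to pad
    padding_size = world_size - remainder
    dummy_row = dict(dataset[0])
    dummy_row["labels"] = [-100] * len(dummy_row["labels"])  # every label ignored by the loss
    dummy_rows = [dummy_row.copy() for _ in range(padding_size)]
    return datasets.concatenate_datasets([dataset, datasets.Dataset.from_list(dummy_rows)])

# e.g. 10 samples across 4 ranks -> 2 dummy rows appended, so 12 % 4 == 0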

QEfficient/finetune/utils/train_utils.py (13 additions, 2 deletions)

@@ -192,6 +192,9 @@ def train(
                         ) as verifier:
                             model_outputs = model(**batch)
                             loss = model_outputs.loss  # Forward call
+                            if (batch["labels"] != -100).sum() == 0:
+                                loss = loss.nan_to_num(nan=0.0)
+
                             if train_config.task_type == "seq_classification":
                                 logits = model_outputs.logits
                                 labels = batch["labels"][:, 0]

@@ -201,15 +204,17 @@
                     else:
                         model_outputs = model(**batch)
                         loss = model_outputs.loss  # Forward call
+                        if (batch["labels"] != -100).sum() == 0:
+                            loss = loss.nan_to_num(nan=0.0)
+
                         if train_config.task_type == "seq_classification":
                             logits = model_outputs.logits
                             labels = batch["labels"][:, 0]
                             preds = torch.nn.functional.softmax(logits, dim=-1)
                             acc_helper.forward(preds, labels)

                     total_loss += loss.detach().float()
-                    # Accumalate gradients
-                    loss = loss / train_config.gradient_accumulation_steps
+
                     if train_config.enable_ddp:
                         if local_rank == 0:
                             if loss <= train_config.convergence_loss:

@@ -237,6 +242,9 @@
                         step_metric_val = float(torch.exp(loss.detach().float()))
                     train_step_metric.append(step_metric_val)

+                    # Accumalate gradients
+                    loss = loss / train_config.gradient_accumulation_steps
+
                     if train_config.grad_scaler:
                         scaler.scale(loss).backward()  # backward pass
                     else:

@@ -439,6 +447,9 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
             outputs = model(**batch)
             loss = outputs.loss

+            if (batch["labels"] != -100).sum() == 0:
+                loss = loss.nan_to_num(nan=0.0)
+
             if train_config.task_type == "seq_classification":
                 logits = outputs.logits
                 labels = batch["labels"][:, 0]
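
The dummy rows added above produce batches whose labels are entirely -100; with ignore_index=-100 there are zero valid tokens, so the mean cross-entropy comes out as NaN (0/0), and the new guard replaces it with 0.0 so a padding batch adds zero to total_loss instead of poisoning the running loss and perplexity. The division by gradient_accumulation_steps is also moved below the metric computation, so the step loss and perplexity are logged from the unscaled loss. A small sketch of the failure mode and the guard (an explicit F.cross_entropy stands in for the loss the model computes internally; shapes are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 8, 32)        # (batch, seq_len, vocab)
labels = torch.full((1, 8), -100)     # padding batch: every position ignored

loss = F.cross_entropy(logits.view(-1, 32), labels.view(-1), ignore_index=-100)
print(loss)                           # tensor(nan) -- zero valid tokens, mean is 0/0

if (labels != -100).sum() == 0:
    loss = loss.nan_to_num(nan=0.0)   # same guard as the commit
print(loss)                           # tensor(0.)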
