[QEff Finetune] Adding dataset padding changes #478
base: main
Conversation
Signed-off-by: Swati Allabadi <[email protected]>
b8104e1 to adebe02
Please generate the perplexity (ppl) numbers across different numbers of DDP devices and gradient accumulation steps to make this change concrete.
@@ -64,19 +64,35 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):

def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)
    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
    dataset = dataset.select(range(0, 10))
Why is the dataset sliced to pick only the first 10 samples?
This was added for local experiments. Removed in the PR.
    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
    dataset = dataset.select(range(0, 10))
    dataset = dataset.map(lambda x: {"input_length": len(x["input_ids"])})
    dataset = dataset.sort("input_length")
Why sorting?
The sorting is done here on the non-padded dataset, instead of in sampler.py on the padded dataset, so that the dummy samples stay at the end.
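As a side note for reviewers, a minimal sketch of that ordering rationale (the toy data and names below are illustrative, not the PR's code): real samples are sorted by token length before the dummy rows are concatenated, so the padding rows always end up at the tail of the combined dataset.

import datasets

real = datasets.Dataset.from_list([
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [1], "labels": [1]},
])
# Record each sample's tokenized length, then sort ascending on it.
real = real.map(lambda x: {"input_length": len(x["input_ids"])})
real = real.sort("input_length")

# Dummy padding rows are appended afterwards, so they stay last.
dummy = datasets.Dataset.from_list([{"input_ids": [0], "labels": [-100], "input_length": 1}])
combined = datasets.concatenate_datasets([real, dummy])
print(combined["labels"])  # the all -100 dummy row is the final entry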
dummy_row["labels"] = [-100] * len(dummy_row["labels"]) | ||
padding_size = 0 | ||
num_replicas = dist.get_world_size() | ||
if len(dataset) % num_replicas > 0: |
bs>1 is not considered here.
It had been skipped since we are not supporting bs > 1 as of now. Made this change for the sake of completeness.
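For completeness, a sketch of what the padding-size computation could look like once bs > 1 is taken into account (hypothetical helper; the PR itself may structure this differently): the dataset is padded up to a multiple of world_size * batch_size so that every rank receives the same number of full batches.

import torch.distributed as dist

def compute_padding_size(dataset_len: int, batch_size: int) -> int:
    """Number of dummy rows needed so dataset_len is a multiple of world_size * batch_size."""
    num_replicas = dist.get_world_size() if dist.is_initialized() else 1
    global_batch = num_replicas * batch_size
    remainder = dataset_len % global_batch
    # e.g. 100 samples, 4 ranks, batch_size 8 -> global_batch 32, remainder 4, padding 28
    return 0 if remainder == 0 else global_batch - remainder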
    if len(dataset) % num_replicas > 0:
        padding_size = num_replicas - len(dataset) % num_replicas

    dummy_data = [dummy_row.copy() for _ in range(padding_size)]
L78 to L80 can be refactored.
I found this way cleaner. Please suggest if you have a better idea.
@@ -192,6 +192,9 @@ def train(
    ) as verifier:
        model_outputs = model(**batch)
        loss = model_outputs.loss  # Forward call
        if (batch["labels"] != -100).sum() == 0:
            loss = loss.nan_to_num(nan=0.0)
The loss is zeroed for dummy samples, but the total loss is still averaged across all samples, including the dummy ones. Please correct it.
Done.
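A minimal sketch of one way to keep dummy batches out of the reported average, assuming the training loop tracks a running total (the helper name and variables are assumptions, not the merged code):

import torch

def accumulate_loss(loss: torch.Tensor, labels: torch.Tensor,
                    total_loss: float, num_real_batches: int):
    """Zero the loss of dummy batches (all labels == -100) and exclude them from the average."""
    if (labels != -100).sum() == 0:
        # Dummy padding batch: its loss is meaningless (NaN), so zero it and do not count it.
        loss = loss.nan_to_num(nan=0.0)
    else:
        total_loss += float(loss.detach())
        num_real_batches += 1
    return loss, total_loss, num_real_batches

# Usage sketch: after the epoch, average over real batches only.
# avg_loss = total_loss / max(num_real_batches, 1)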
@@ -237,6 +242,9 @@ def train(
    step_metric_val = float(torch.exp(loss.detach().float()))
    train_step_metric.append(step_metric_val)

    # Accumulate gradients
    loss = loss / train_config.gradient_accumulation_steps
This should change. E.g., with 100 samples and a global batch size of 30:
For the first 30 samples, loss = loss / 30
For the next 30 samples, loss = loss / 30
For the next 30 samples, loss = loss / 30
For the last 10 samples, loss = loss / 10, not loss = loss / 30
Done.
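A sketch of the per-window divisor the reviewer describes (illustrative names; the actual fix in the PR may differ): the final, partial accumulation window divides by the number of micro-batches it actually contains instead of the full gradient_accumulation_steps.

def accumulation_divisor(step: int, steps_per_epoch: int, grad_accum_steps: int) -> int:
    """Divisor for the loss at this micro-batch step within its accumulation window."""
    window_start = (step // grad_accum_steps) * grad_accum_steps
    return min(grad_accum_steps, steps_per_epoch - window_start)

# Reviewer's example: 100 micro-batches, grad_accum_steps = 30
#   steps  0-29 -> divide by 30
#   steps 30-59 -> divide by 30
#   steps 60-89 -> divide by 30
#   steps 90-99 -> divide by 10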
@@ -439,6 +447,9 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
    outputs = model(**batch)
    loss = outputs.loss

    if (batch["labels"] != -100).sum() == 0:
        loss = loss.nan_to_num(nan=0.0)
Same comment as above
Done.
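The same correction, sketched for evaluation_helper: dummy batches are skipped when averaging the evaluation loss (names are illustrative, not the PR's code).

import torch

def mean_eval_loss(batch_losses):
    """batch_losses: iterable of (loss_tensor, labels_tensor) pairs from the eval loop."""
    total, real = 0.0, 0
    for loss, labels in batch_losses:
        if (labels != -100).sum() == 0:
            continue  # dummy padding batch: excluded from the average
        total += float(loss.detach())
        real += 1
    return total / max(real, 1)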
    dummy_data = [dummy_row.copy() for _ in range(padding_size)]
    dummy_dataset = datasets.Dataset.from_list(dummy_data)
    combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
Try to enclose this padding logic in a separate function.
Done.
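A sketch of the kind of standalone helper the reviewer asks for; the name pad_dataset and the exact dummy-row construction below are assumptions based on the diff, not the merged code.

import datasets
import torch.distributed as dist

def pad_dataset(dataset: datasets.Dataset, batch_size: int = 1) -> datasets.Dataset:
    """Append dummy rows (labels all -100) so every rank gets the same number of full batches."""
    num_replicas = dist.get_world_size() if dist.is_initialized() else 1
    global_batch = num_replicas * batch_size
    remainder = len(dataset) % global_batch
    if remainder == 0:
        return dataset
    padding_size = global_batch - remainder
    dummy_row = dict(dataset[0])                               # reuse an existing sample's schema
    dummy_row["labels"] = [-100] * len(dummy_row["labels"])    # mask out every label
    dummy_data = [dummy_row.copy() for _ in range(padding_size)]
    dummy_dataset = datasets.Dataset.from_list(dummy_data)
    return datasets.concatenate_datasets([dataset, dummy_dataset])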
Signed-off-by: Swati Allabadi <[email protected]>
No description provided.