QEfficient/finetune/utils/helper.py (40 additions & 0 deletions)
@@ -5,6 +5,15 @@
 #
 # -----------------------------------------------------------------------------
 import os
+from contextlib import nullcontext
+
+import torch
+
+try:
+    import torch_qaic.debug as qaic_debug  # noqa: F401
+except ImportError as e:
+    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+

 TASK_TYPE = ["generation", "seq_classification"]
 PEFT_METHOD = ["lora"]
@@ -14,3 +23,34 @@

 def get_num_ddp_devices():
     return int(os.getenv("WORLD_SIZE", 1))
+
+
+def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
+    return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()
+
+
+def get_op_verifier_ctx(
+    use_op_by_op_verifier,
+    train_device,
+    dump_dir,
+    step,
+    ref_device="cpu",
+    ref_dtype=torch.float32,
+    atol=1e-1,
+    rtol=1e-5,
+    use_ref_output_on_mismatch=True,
+):
+    if not use_op_by_op_verifier:
+        return nullcontext()
+
+    filter_config = qaic_debug.DispatchFilterConfig.default(train_device)
+    dump_dir = dump_dir + "_" + str(step)
+    return qaic_debug.OpByOpVerifierMode(
+        ref_device=ref_device,
+        ref_dtype=ref_dtype,
+        atol=atol,
+        rtol=rtol,
+        use_ref_output_on_mismatch=use_ref_output_on_mismatch,
+        filter_config=filter_config,
+        dump_root_dir=dump_dir,
+    )
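
For context, a minimal usage sketch of the two new helpers (not part of the diff), mirroring how the train_utils.py changes below wire them up; run_step, cfg, model, and batch are hypothetical stand-ins for the real training-loop code.

from functools import partial

import torch

from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx


def run_step(model, batch, cfg, device, step):
    # Hypothetical wrapper around one training step; `cfg` mimics TrainConfig fields.
    device_type = torch.device(device).type
    # Returns torch.autocast(...) when cfg.use_autocast is True, else a no-op nullcontext().
    autocast_ctx = get_autocast_ctx(cfg.use_autocast, device_type, dtype=torch.float16)
    # Bind the static arguments once; only the step number varies per iteration.
    op_verifier_ctx = partial(get_op_verifier_ctx, cfg.opByOpVerifier, device, cfg.dump_root_dir)

    with autocast_ctx, op_verifier_ctx(step) as verifier:
        loss = model(**batch).loss
    if cfg.opByOpVerifier:
        # `verifier` is an OpByOpVerifierMode instance only when the flag is set.
        print("Mismatches detected:", verifier.get_perop_mismatch_count())
    return loss
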
QEfficient/finetune/utils/train_utils.py (38 additions & 59 deletions)
@@ -8,8 +8,8 @@
 import json
 import os
 import time
-from contextlib import nullcontext
 from datetime import datetime
+from functools import partial
 from typing import Dict, List, Tuple

 import torch
Expand All @@ -19,6 +19,7 @@
from tqdm import tqdm

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx

try:
import torch_qaic # noqa: F401
@@ -110,6 +111,9 @@ def train(
             num_classes = model.classifier.out_features
         acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)

+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
+    op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.dump_root_dir)
+
     # Start the training loop
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
@@ -174,60 +178,38 @@
                 break
             batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device

-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
-                # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
-                if train_config.opByOpVerifier:
-                    with qaic_debug.OpByOpVerifierMode(
-                        ref_device="cpu",
-                        ref_dtype=torch.float32,
-                        # adjust atol & rtol this as required
-                        atol=1e-1,
-                        use_ref_output_on_mismatch=True,
-                        filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                        dump_root_dir=train_config.dump_root_dir + str(step),
-                    ) as verifier:
-                        model_outputs = model(**batch)
-                        loss = model_outputs.loss  # Forward call
-                        if (batch["labels"] != -100).sum() == 0:
-                            loss = loss.nan_to_num(nan=0.0)
-                            num_dummy_samples += train_config.train_batch_size
-                        else:
-                            num_dummy_samples_per_batch = (
-                                (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                            )
-                            if num_dummy_samples_per_batch > 0:
-                                num_dummy_samples += num_dummy_samples_per_batch
-                                loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
-
-                        if train_config.task_type == "seq_classification":
-                            logits = model_outputs.logits
-                            labels = batch["labels"][:, 0]
-                            preds = torch.nn.functional.softmax(logits, dim=-1)
-                            acc_helper.forward(preds, labels)
-                        print("Mismatches detected:", verifier.get_perop_mismatch_count())
+            is_optimizer_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
+                train_dataloader
+            ) - 1
+            if train_config.enable_ddp:
+                # Below block derived from : https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
+                # in DDP training we only need to sync gradients at the last micro step.
+                # the official way to do this is with model.no_sync() context manager, but
+                # using too many context managers may bloat the code and forces us to repeat code
+                # looking at the source of that context manager, it just toggles this variable
+                model.require_backward_grad_sync = is_optimizer_step
+
+            with autocast_ctx, op_verifier_ctx(step) as verifier:
+                model_outputs = model(**batch)
+                loss = model_outputs.loss  # Forward call
+                if (batch["labels"] != -100).sum() == 0:
+                    loss = loss.nan_to_num(nan=0.0)
+                    num_dummy_samples += train_config.train_batch_size
                 else:
-                    model_outputs = model(**batch)
-                    loss = model_outputs.loss  # Forward call
-                    if (batch["labels"] != -100).sum() == 0:
-                        loss = loss.nan_to_num(nan=0.0)
-                        num_dummy_samples += train_config.train_batch_size
-                    else:
-                        num_dummy_samples_per_batch = (
-                            (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                        )
-                        if num_dummy_samples_per_batch > 0:
-                            num_dummy_samples += num_dummy_samples_per_batch
-                            loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
+                    num_dummy_samples_per_batch = (
+                        (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
+                    )
+                    if num_dummy_samples_per_batch > 0:
+                        num_dummy_samples += num_dummy_samples_per_batch
+                        loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch

-                    if train_config.task_type == "seq_classification":
-                        logits = model_outputs.logits
-                        labels = batch["labels"][:, 0]
-                        preds = torch.nn.functional.softmax(logits, dim=-1)
-                        acc_helper.forward(preds, labels)
+                if train_config.task_type == "seq_classification":
+                    logits = model_outputs.logits
+                    labels = batch["labels"][:, 0]
+                    preds = torch.nn.functional.softmax(logits, dim=-1)
+                    acc_helper.forward(preds, labels)
+                if train_config.opByOpVerifier:
+                    print("Mismatches detected:", verifier.get_perop_mismatch_count())

             total_loss += loss.detach().float()

Expand Down Expand Up @@ -274,7 +256,7 @@ def train(
else:
loss.backward() # backward pass

if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
if is_optimizer_step:
if train_config.grad_scaler:
scaler.step(optimizer)
scaler.update()
@@ -468,6 +450,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     device_type = torch.device(device).type

     num_dummy_samples = 0
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
     for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
         # stop when the maximum number of eval steps is reached
         if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -478,11 +461,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
+            with autocast_ctx:
                 outputs = model(**batch)
             loss = outputs.loss

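
The optimizer-step bookkeeping above is hard to follow in diff form, so here is a minimal, self-contained sketch (not taken from this PR) of the same accumulation-boundary pattern: compute is_optimizer_step once, then reuse it both for DDP gradient synchronization and for deciding when to step the optimizer. All names below (train_steps, grad_accum_steps, etc.) are illustrative.

import torch


def train_steps(model, optimizer, dataloader, grad_accum_steps, enable_ddp=False):
    # Illustrative loop only; assumes HuggingFace-style outputs exposing a .loss attribute.
    for step, batch in enumerate(dataloader):
        # Last micro-step of an accumulation window, or the final batch of the epoch.
        is_optimizer_step = (step + 1) % grad_accum_steps == 0 or step == len(dataloader) - 1
        if enable_ddp:
            # DDP only needs to all-reduce gradients on the optimizer step;
            # toggling this flag is what model.no_sync() does internally.
            model.require_backward_grad_sync = is_optimizer_step
        loss = model(**batch).loss / grad_accum_steps  # common practice: average over accumulated micro-batches
        loss.backward()
        if is_optimizer_step:
            optimizer.step()
            optimizer.zero_grad()
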