diff --git a/QEfficient/finetune/utils/helper.py b/QEfficient/finetune/utils/helper.py
index 8562b2aed..9e55a16ff 100644
--- a/QEfficient/finetune/utils/helper.py
+++ b/QEfficient/finetune/utils/helper.py
@@ -5,6 +5,15 @@
 #
 # -----------------------------------------------------------------------------
 import os
+from contextlib import nullcontext
+
+import torch
+
+try:
+    import torch_qaic.debug as qaic_debug  # noqa: F401
+except ImportError as e:
+    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+
 
 TASK_TYPE = ["generation", "seq_classification"]
 PEFT_METHOD = ["lora"]
@@ -14,3 +23,34 @@ def get_num_ddp_devices():
     return int(os.getenv("WORLD_SIZE", 1))
+
+
+def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
+    return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()
+
+
+def get_op_verifier_ctx(
+    use_op_by_op_verifier,
+    train_device,
+    dump_dir,
+    step,
+    ref_device="cpu",
+    ref_dtype=torch.float32,
+    atol=1e-1,
+    rtol=1e-5,
+    use_ref_output_on_mismatch=True,
+):
+    if not use_op_by_op_verifier:
+        return nullcontext()
+
+    filter_config = qaic_debug.DispatchFilterConfig.default(train_device)
+    dump_dir = dump_dir + "_" + str(step)
+    return qaic_debug.OpByOpVerifierMode(
+        ref_device=ref_device,
+        ref_dtype=ref_dtype,
+        atol=atol,
+        rtol=rtol,
+        use_ref_output_on_mismatch=use_ref_output_on_mismatch,
+        filter_config=filter_config,
+        dump_root_dir=dump_dir,
+    )
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index f513ba5c4..6eb44dc43 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -8,8 +8,8 @@
 import json
 import os
 import time
-from contextlib import nullcontext
 from datetime import datetime
+from functools import partial
 from typing import Dict, List, Tuple
 
 import torch
@@ -19,6 +19,7 @@
 from tqdm import tqdm
 
 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx
 
 try:
     import torch_qaic  # noqa: F401
@@ -110,6 +111,9 @@ def train(
         num_classes = model.classifier.out_features
         acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
+    op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.dump_root_dir)
+
     # Start the training loop
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
@@ -174,60 +178,38 @@ def train(
                     break
                 batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device
 
-                with (
-                    torch.autocast(device_type=device_type, dtype=torch.float16)
-                    if train_config.use_autocast
-                    else nullcontext()
-                ):
-                    # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
-                    if train_config.opByOpVerifier:
-                        with qaic_debug.OpByOpVerifierMode(
-                            ref_device="cpu",
-                            ref_dtype=torch.float32,
-                            # adjust atol & rtol this as required
-                            atol=1e-1,
-                            use_ref_output_on_mismatch=True,
-                            filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                            dump_root_dir=train_config.dump_root_dir + str(step),
-                        ) as verifier:
-                            model_outputs = model(**batch)
-                            loss = model_outputs.loss  # Forward call
-                            if (batch["labels"] != -100).sum() == 0:
-                                loss = loss.nan_to_num(nan=0.0)
-                                num_dummy_samples += train_config.train_batch_size
-                            else:
-                                num_dummy_samples_per_batch = (
-                                    (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                                )
-                                if num_dummy_samples_per_batch > 0:
-                                    num_dummy_samples += num_dummy_samples_per_batch
-                                    loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
-
-                            if train_config.task_type == "seq_classification":
-                                logits = model_outputs.logits
-                                labels = batch["labels"][:, 0]
-                                preds = torch.nn.functional.softmax(logits, dim=-1)
-                                acc_helper.forward(preds, labels)
-                        print("Mismatches detected:", verifier.get_perop_mismatch_count())
+                is_optimizer_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
+                    train_dataloader
+                ) - 1
+                if train_config.enable_ddp:
+                    # Below block derived from : https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
+                    # in DDP training we only need to sync gradients at the last micro step.
+                    # the official way to do this is with model.no_sync() context manager, but
+                    # using too many context managers may bloat the code and forces us to repeat code
+                    # looking at the source of that context manager, it just toggles this variable
+                    model.require_backward_grad_sync = is_optimizer_step
+
+                with autocast_ctx, op_verifier_ctx(step) as verifier:
+                    model_outputs = model(**batch)
+                    loss = model_outputs.loss  # Forward call
+                    if (batch["labels"] != -100).sum() == 0:
+                        loss = loss.nan_to_num(nan=0.0)
+                        num_dummy_samples += train_config.train_batch_size
                     else:
-                        model_outputs = model(**batch)
-                        loss = model_outputs.loss  # Forward call
-                        if (batch["labels"] != -100).sum() == 0:
-                            loss = loss.nan_to_num(nan=0.0)
-                            num_dummy_samples += train_config.train_batch_size
-                        else:
-                            num_dummy_samples_per_batch = (
-                                (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                            )
-                            if num_dummy_samples_per_batch > 0:
-                                num_dummy_samples += num_dummy_samples_per_batch
-                                loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
+                        num_dummy_samples_per_batch = (
+                            (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
+                        )
+                        if num_dummy_samples_per_batch > 0:
+                            num_dummy_samples += num_dummy_samples_per_batch
+                            loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
 
-                        if train_config.task_type == "seq_classification":
-                            logits = model_outputs.logits
-                            labels = batch["labels"][:, 0]
-                            preds = torch.nn.functional.softmax(logits, dim=-1)
-                            acc_helper.forward(preds, labels)
+                    if train_config.task_type == "seq_classification":
+                        logits = model_outputs.logits
+                        labels = batch["labels"][:, 0]
+                        preds = torch.nn.functional.softmax(logits, dim=-1)
+                        acc_helper.forward(preds, labels)
+                if train_config.opByOpVerifier:
+                    print("Mismatches detected:", verifier.get_perop_mismatch_count())
 
                 total_loss += loss.detach().float()
 
@@ -274,7 +256,7 @@ def train(
                 else:
                     loss.backward()  # backward pass
 
-                if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                if is_optimizer_step:
                     if train_config.grad_scaler:
                         scaler.step(optimizer)
                         scaler.update()
@@ -468,6 +450,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     device_type = torch.device(device).type
 
     num_dummy_samples = 0
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
    for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
         # stop when the maximum number of eval steps is reached
         if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -478,11 +461,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
 
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
+            with autocast_ctx:
                 outputs = model(**batch)
                 loss = outputs.loss
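
For reviewers, here is a minimal, hypothetical usage sketch of how the two new helpers compose outside of `train()`. The toy `Linear` model, the CPU device, bfloat16 autocast, and the `"debug_dumps"` directory are illustrative assumptions only; they are not part of this change, and the verifier flag is left off so the sketch runs without `torch_qaic`.

```python
# Sketch only: mirrors the partial(...) binding that train() now uses.
from functools import partial

import torch

from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx

model = torch.nn.Linear(8, 2)          # stand-in for the fine-tuned model (assumption)
batch = {"x": torch.randn(4, 8)}       # stand-in batch (assumption)

device_type = "cpu"                    # train() derives this via torch.device(device).type
# Built once and re-entered on every step; bfloat16 keeps the sketch CPU-friendly.
autocast_ctx = get_autocast_ctx(True, device_type, dtype=torch.bfloat16)

# Bind everything except `step`; each call appends the step index to the dump dir.
# Verifier disabled here (first argument False) so no qaic modules are required.
op_verifier_ctx = partial(get_op_verifier_ctx, False, device_type, "debug_dumps")

for step in range(2):
    with autocast_ctx, op_verifier_ctx(step) as verifier:
        out = model(batch["x"])
    # nullcontext() yields None; with opByOpVerifier enabled, `verifier` would
    # expose get_perop_mismatch_count() as used in train_utils.train().
    if verifier is not None:
        print("Mismatches detected:", verifier.get_perop_mismatch_count())
```

The `partial` binding keeps the per-step call site to a single expression while the autocast context is constructed once and reused, which is what lets the nested `with (...)`/`if opByOpVerifier` blocks in the old loop collapse into one `with` statement.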