diff --git a/finetune_llama.py b/finetune_llama.py
index b81738f54..09bbb78d2 100644
--- a/finetune_llama.py
+++ b/finetune_llama.py
@@ -336,17 +336,17 @@ def run_multitask(args):
         callbacks.append(checkpoint_callback)
 
     trainer = Trainer(
-        gpus=1,
+        devices=1,
         accelerator="gpu",
         logger=loggers,
         num_sanity_val_steps=5,
-        amp_backend="native",
+        # amp_backend="native",
         default_root_dir=args.output_dir,
         max_epochs=args.num_train_epochs,
         max_steps=args.total_steps + 1 if args.total_steps != -1 else -1,
         gradient_clip_val=args.max_grad_norm,
         log_every_n_steps=20,
-        strategy=args.compute_strategy if args.compute_strategy else None,
+        strategy="ddp" if not args.compute_strategy else args.compute_strategy,
         callbacks=callbacks,
         accumulate_grad_batches=args.gradient_accumulation_steps,
         precision=int(args.precision)
diff --git a/inst_follow/models/clm.py b/inst_follow/models/clm.py
index 9d4d51eea..21fc5525b 100644
--- a/inst_follow/models/clm.py
+++ b/inst_follow/models/clm.py
@@ -1,22 +1,27 @@
 import json
 import os
 import sys
-import torch
+import torch
 import copy
-from torch.optim.optimizer import Optimizer
-import wandb
+from torch.optim.optimizer import Optimizer
+import wandb
 from collections import defaultdict
 from torch import Tensor, nn
 from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
 
 sys.path.append("/home/v-oostapenko/dev/mttl")
 from mttl.models.get_optimizer import get_optimizer
-from mttl.models.get_scheduler import get_scheduler
+from mttl.models.get_scheduler import get_scheduler
 from mttl.models.modify_model import modify_transformer
-from mttl.models.utils import EfficientCheckpointModule, RoutingInfo, get_global_batch_size
+from mttl.models.utils import (
+    EfficientCheckpointModule,
+    RoutingInfo,
+    get_global_batch_size,
+)
 from mttl.utils import freeze_embeds
 from pytorch_lightning.utilities.parsing import AttributeDict
 
+
 class CLM(EfficientCheckpointModule):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -29,10 +34,11 @@ def __init__(self, **kwargs):
         self.pad_token_id = self.tokenizer.pad_token_id
         self.model: AutoModelForCausalLM = None
         self.accumulate_metrics = defaultdict(list)
-        if kwargs.get('model_object') is None:
+        if kwargs.get("model_object") is None:
             raise NotImplementedError()
             self.model = AutoModelForSeq2SeqLM.from_pretrained(
-                self.args.model, cache_dir=os.environ.get('TRANSFORMERS_CACHE', "/tmp/hf-cache")
+                self.args.model,
+                cache_dir=os.environ.get("TRANSFORMERS_CACHE", "/tmp/hf-cache"),
             )
             # free-up temporary space
             os.system("/bin/rm -rf /tmp/hf-cache")
@@ -62,7 +68,7 @@ def __init__(self, **kwargs):
                 print("Freezing embeddings")
                 freeze_embeds(self.model)
         else:
-            self.model = kwargs.get('model_object')
+            self.model = kwargs.get("model_object")
 
         self.loss_plugins = nn.ModuleDict({})
         self.test_results = []
@@ -74,153 +80,194 @@ def add_loss_plugin(self, plugin):
         else:
             self.loss_plugins = nn.ModuleDict({plugin.name: plugin})
 
-    def forward(self, batch, reduction='mean'):
-        input_ids, labels = batch["input_ids"], batch["labels"]
-        self.model.task_id_container["routing_infos"] = RoutingInfo.from_batch(batch) # pad tokens also have -100
-        padding_mask = self.calculate_routing_mask(batch["input_ids"], self.model.task_id_container["routing_infos"])
-        setattr(self.model.task_id_container["routing_infos"], "pad_token_mask", padding_mask)
+    def forward(self, batch, reduction="mean"):
+        input_ids, labels = batch["input_ids"], batch["labels"]
+        self.model.task_id_container["routing_infos"] = RoutingInfo.from_batch(
+            batch
+        )  # pad tokens also have -100
+        padding_mask = self.calculate_routing_mask(
+            batch["input_ids"], self.model.task_id_container["routing_infos"]
+        )
+        setattr(
+            self.model.task_id_container["routing_infos"],
+            "pad_token_mask",
+            padding_mask,
+        )
         outputs = self.model.forward(
             input_ids,
             attention_mask=(input_ids != self.pad_token_id).float(),
-        )
+        )
         # output_ids = outputs.logits.argmax(-1)
         # calculate loss, could also be done inside of the model
         bs = input_ids.size(0)
         logits = outputs.logits
         vocab_size = logits.size(-1)
-        labels = labels.squeeze(-1)
+        labels = labels.squeeze(-1)
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
+        # Flatten the tokens
         loss_fct = torch.nn.CrossEntropyLoss(reduction=reduction)
         shift_logits = shift_logits.view(-1, vocab_size)
         shift_labels = shift_labels.view(-1)
         # Enable model parallelism
         shift_labels = shift_labels.to(shift_logits.device)
         loss = loss_fct(shift_logits, shift_labels)
-        #reshape back
-        if reduction == 'none':
-            loss = loss.view((bs,-1))
+        # reshape back
+        if reduction == "none":
+            loss = loss.view((bs, -1))
             # mean only non-zero
             non_zero_loss = (loss != 0).sum(dim=-1)
             non_zero_loss[non_zero_loss == 0] = 1
             loss = loss.sum(dim=-1) / non_zero_loss
         del outputs, shift_logits, shift_labels
-        aux_loss = torch.mean(torch.stack(self.model.task_id_container["routing_infos"].aux_loss)) if len(self.model.task_id_container["routing_infos"].aux_loss)>0 else 0
-
-
+        aux_loss = (
+            torch.mean(
+                torch.stack(self.model.task_id_container["routing_infos"].aux_loss)
+            )
+            if len(self.model.task_id_container["routing_infos"].aux_loss) > 0
+            else 0
+        )
+
         # log metrics if training
-        for k,v in self.model.task_id_container["routing_infos"].metrics.items():
-            self.accumulate_metrics[k].append(torch.tensor(v).mean())
-
-
+        for k, v in self.model.task_id_container["routing_infos"].metrics.items():
+            self.accumulate_metrics[k].append(torch.tensor(v).mean())
+
         return loss, aux_loss
-
-    def calculate_routing_mask(self, x, routing_infos):
-        # self.args.x_routing_option=4
+
+    def calculate_routing_mask(self, x, routing_infos):
+        # self.args.x_routing_option=4
         padding_mask = None
         bs, seq = x.shape
-        gen_mode = 0
+        gen_mode = 0
         if hasattr(routing_infos, "gen_mode"):
             gen_mode = routing_infos.gen_mode
-        if self.args.x_routing_option>0:
-            padding_mask = routing_infos.pad_token_mask # 1 if the token is not a pad token, so its either instruciton or output
-            instruction_mask = torch.ones_like(padding_mask) # 1 if the token is part of instruction or pad token (so outputs are 0s)
-
+        if self.args.x_routing_option > 0:
+            padding_mask = (
+                routing_infos.pad_token_mask
+            )  # 1 if the token is not a pad token, so its either instruciton or output
+            instruction_mask = torch.ones_like(
+                padding_mask
+            )  # 1 if the token is part of instruction or pad token (so outputs are 0s)
+
             if routing_infos.labels is not None:
-                instruction_mask = (routing_infos.labels==-100).float() # 1 if the token is part of instruction or pad token (so outputs are 0s)
-            if self.args.x_routing_option==1 or (self.args.x_routing_option==2 and gen_mode): # here we only use instruction to decide about the routing
-                padding_mask = padding_mask * instruction_mask # 1 if the token is part of instruction
-
-            elif self.args.x_routing_option==2: # or self.args.x_routing_option==3: # here we use instruction and part of the output sofar to decide about the routing, routing will be different for each token
-                padding_mask = padding_mask * instruction_mask # only the instruction
+                instruction_mask = (
+                    routing_infos.labels == -100
+                ).float()  # 1 if the token is part of instruction or pad token (so outputs are 0s)
+            if self.args.x_routing_option == 1 or (
+                self.args.x_routing_option == 2 and gen_mode
+            ):  # here we only use instruction to decide about the routing
+                padding_mask = (
+                    padding_mask * instruction_mask
+                )  # 1 if the token is part of instruction
+
+            elif (
+                self.args.x_routing_option == 2
+            ):  # or self.args.x_routing_option==3: # here we use instruction and part of the output sofar to decide about the routing, routing will be different for each token
+                padding_mask = padding_mask * instruction_mask  # only the instruction
                 # Find the indices of the last occurrence of 1 in tensor A along the last dimension
-                last_ones_indices = padding_mask.sum(dim=1).unsqueeze(-1)#.cpu()
-
+                last_ones_indices = padding_mask.sum(dim=1).unsqueeze(-1)  # .cpu()
+
                 # Expand dimensions of last_ones_indices to match the shape of B
                 expanded_indices = last_ones_indices
                 expanded_indices = expanded_indices.repeat(1, seq)
                 expanded_indices_inverse = seq - expanded_indices
-                expanded_indices_inverse-= torch.arange(seq).unsqueeze(0).to(x.device)
-                expanded_indices_inverse = torch.max(expanded_indices_inverse, torch.zeros_like(expanded_indices_inverse))
+                expanded_indices_inverse -= torch.arange(seq).unsqueeze(0).to(x.device)
+                expanded_indices_inverse = torch.max(
+                    expanded_indices_inverse, torch.zeros_like(expanded_indices_inverse)
+                )
                 expanded_indices_inverse = expanded_indices_inverse.flip(1)
                 mask = expanded_indices + expanded_indices_inverse
-                mask = mask.unsqueeze(-1).repeat(1,1,seq)
+                mask = mask.unsqueeze(-1).repeat(1, 1, seq)
                 # shape like mask
                 ar = torch.arange(seq).to(x.device)
                 ar = ar.unsqueeze(0).unsqueeze(0).repeat(bs, seq, 1)
-
+
                 A = torch.zeros(bs, seq, seq).to(mask.device)
                 B = torch.ones(bs, seq, seq).to(mask.device)
-                padding_mask = torch.where(ar
-    def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
-        # if wandb.run is not None:
-        # pl logger logs every 50 steps, we want to log before every update
-        for k,v in self.accumulate_metrics.items():
-            # log to wandb directly
-            # wandb.log({f"train/{k}": torch.tensor(v).mean()})
-            self.log(f"train/{k}", torch.tensor(v).mean(), on_step=True)
+        setattr(
+            self.model.task_id_container["routing_infos"],
+            "pad_token_mask",
+            padding_mask,
+        )
+        return self.model.generate(inputs=batch["input_ids"], **kwargs)
+
+    def on_before_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int
+    ) -> None:
+        # if wandb.run is not None:
+        # pl logger logs every 50 steps, we want to log before every update
+        for k, v in self.accumulate_metrics.items():
+            # log to wandb directly
+            # wandb.log({f"train/{k}": torch.tensor(v).mean()})
+            self.log(f"train/{k}", torch.tensor(v).mean(), on_step=True)
         self.accumulate_metrics = defaultdict(list)
         return super().on_before_optimizer_step(optimizer, optimizer_idx)
-
-    def training_step(self, batch, _):
+
+    def training_step(self, batch, _):
         loss, aux_loss = self.forward(batch)
-        total_loss = loss+aux_loss
-        # outputs = self.model.forward(**batch)
+        total_loss = loss + aux_loss
+        # outputs = self.model.forward(**batch)
         # loss= outputs.loss
-        self.log("train/loss", loss, on_epoch=True, prog_bar=True)
+        self.log("train/loss", loss, on_epoch=True, prog_bar=True)
         self.log("train/aux_loss", aux_loss, on_epoch=True, prog_bar=True)
         self.log("train/total_loss", total_loss, on_epoch=True, prog_bar=True)
-
-        for plugin in self.loss_plugins.values():
+
+        for plugin in self.loss_plugins.values():
             plugin_loss = plugin.compute_loss(self.model, batch)
             loss += plugin.factor * plugin_loss
             self.log(
@@ -231,33 +278,35 @@ def training_step(self, batch, _):
             self.log(f"train/lr_{i}", pg["lr"])
         return total_loss
 
-    def validation_step(self, batch, batch_idx, log=True):
+    def validation_step(self, batch, batch_idx, log=True):
         self.accumulate_metrics = defaultdict(list)
-        loss, aux_loss = self.forward(batch, reduction='none')
-        total_loss = loss #+aux_loss
+        loss, aux_loss = self.forward(batch, reduction="none")
+        total_loss = loss  # +aux_loss
         # outputs = self.model.forward(**batch, reduction='none')
-        # loss= outputs.loss
+        # loss= outputs.loss
         mean_loss = total_loss.sum() / loss.size(0)
         if log:
-            self.log("val/loss", mean_loss, on_epoch=True, prog_bar=True)
+            self.log("val/loss", mean_loss, on_epoch=True, prog_bar=True)
             self.log("val/aux_loss", aux_loss, on_epoch=True, prog_bar=True)
-            for k,v in self.model.task_id_container["routing_infos"].metrics.items():
-                self.log(f"val/{k}", torch.tensor(v).mean(), on_epoch=True, prog_bar=True)
+            for k, v in self.model.task_id_container["routing_infos"].metrics.items():
+                self.log(
+                    f"val/{k}", torch.tensor(v).mean(), on_epoch=True, prog_bar=True
+                )
             # self.log("val/total_loss", total_loss, on_epoch=True, prog_bar=True)
-        return loss, batch['task_ids']
+        return loss, batch["task_ids"]
 
     def on_before_backward(self, loss: Tensor) -> None:
         return super().on_before_backward(loss)
-
-    def test_step(self, batch, batch_idx):
-        loss, aux_loss = self.forward(batch, reduction='none')
-        return loss, batch['task_ids']
-
-    def test_epoch_end(self, outputs):
+
+    def test_step(self, batch, batch_idx):
+        loss, aux_loss = self.forward(batch, reduction="none")
+        return loss, batch["task_ids"]
+
+    def test_epoch_end(self, outputs):
         losses = torch.cat([out[0] for out in outputs], 0)
         task_ids = torch.cat([out[1] for out in outputs], 0)
         log_name = f"test/loss"
-        if hasattr(self.model, "checkpoint_tested"):
+        if hasattr(self.model, "checkpoint_tested"):
             log_name = f"test/{self.model.checkpoint_tested}/loss"
         # log per task loss and overall loss
         self.log(log_name, losses.mean(), on_epoch=True, prog_bar=True)
@@ -266,32 +315,36 @@ def test_epoch_end(self, outputs):
             if hasattr(self.model, "checkpoint_tested"):
                 log_name = f"test/{self.model.checkpoint_tested}/loss_{task_id.item()}"
             self.log(
-                log_name,
+                log_name,
                 losses[task_ids == task_id].mean(),
                 on_epoch=True,
                 prog_bar=True,
             )
         return losses
-
-    def validation_epoch_end(self, outputs):
+
+    def on_validation_epoch_end(self, outputs):
         losses = torch.cat([out[0] for out in outputs], 0)
         task_ids = torch.cat([out[1] for out in outputs], 0)
         # compute the loss per task id
-        with open(os.path.join(self.args.output_dir, "val_loss_by_task.txt"), "a+") as f:
+        with open(
+            os.path.join(self.args.output_dir, "val_loss_by_task.txt"), "a+"
+        ) as f:
             task_losses = {}
             for task_id in torch.unique(task_ids):
                 task_losses[task_id.item()] = losses[task_ids == task_id].mean().item()
             f.write(json.dumps(task_losses) + "\n")
 
-    def configure_optimizers(self):
-        args = self.args
+    def configure_optimizers(self):
+        args = self.args
         self.ml_optimizer = self.ml_scheduler = None
 
         optimizer, self.trainable_param_names = get_optimizer(
             self, args, no_decay=["bias", "LayerNorm.weight"]
         )
 
-        global_bs = get_global_batch_size(args.train_batch_size, args.gradient_accumulation_steps)
+        global_bs = get_global_batch_size(
+            args.train_batch_size, args.gradient_accumulation_steps
+        )
 
         if args.total_steps == -1:
             args.total_steps = (
@@ -310,10 +363,12 @@ def configure_optimizers(self):
                 "scheduler": scheduler,
                 "interval": "step",
             },
-        }
-
+        }
+
     @property
-    def hparams_initial(self): # to make wandb logger work we need to override this method
+    def hparams_initial(
+        self,
+    ):  # to make wandb logger work we need to override this method
         """The collection of hyperparameters saved with :meth:`save_hyperparameters`.
         These contents are read-only. Manual updates to the saved hyperparameters can
         instead be performed through :attr:`hparams`.
@@ -322,8 +377,8 @@ def hparams_initial(self): # to make wandb logger work we need to override this
         """
         if not hasattr(self, "_hparams_initial"):
             return AttributeDict()
-        # prevent any change
-        hparams_initial=copy.deepcopy(self._hparams_initial)
+        # prevent any change
+        hparams_initial = copy.deepcopy(self._hparams_initial)
         # pop anything that is not json serializable
-        hparams_initial.pop('_updated_kwargs')
+        hparams_initial.pop("_updated_kwargs")
         return hparams_initial
diff --git a/requirements.txt b/requirements.txt
index ff87fd248..2fcb79bbb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,37 +1,23 @@
-datasets
 transformers==4.28.1
 pytorch-lightning==2.0.5
+scikit-learn
+higher
 wandb
-rouge
+datasets==2.12.0
 wget
+rouge
 tqdm
 pandas
-scikit-learn
-higher
-wandb
-datasets==2.11.0
-transformers==4.28.1
-pytorch-lightning==1.8.6
-wget==3.2
-rouge==1.0.0
-tqdm==4.64.0
-pandas==1.4.2
-scikit-learn==1.1.0
-higher==0.2.1
 promptsource
 deepspeed
 sentencepiece
 torch_kmeans
 sentence-transformers
 fsspec[adl]
-rich
 openai
-ray
-shortuuid
 nomic==1.1.6
 evaluate
 click
-click
 rich
 ray
 shortuuid
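For context on the epoch-end hooks touched above: a minimal sketch, assuming the pytorch-lightning==2.0.5 pin from requirements.txt, of the accumulation pattern Lightning 2.x uses, where on_validation_epoch_end is called without an outputs argument. The class name and the compute_loss helper are illustrative and not part of this diff.

import torch
import pytorch_lightning as pl


class EpochEndSketch(pl.LightningModule):
    """Illustrative only: accumulate validation outputs on the module in Lightning 2.x."""

    def __init__(self):
        super().__init__()
        self._val_outputs = []  # replaces the `outputs` argument removed from 2.x hooks

    def compute_loss(self, batch):
        # Hypothetical stand-in for the per-example loss CLM.forward(reduction="none") returns.
        return batch["input_ids"].float().mean(dim=-1)

    def validation_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        self._val_outputs.append((loss.detach(), batch["task_ids"]))
        return loss.mean()

    def on_validation_epoch_end(self):
        # Called with no arguments in Lightning 2.x; read and clear the accumulated list.
        losses = torch.cat([out[0] for out in self._val_outputs], 0)
        task_ids = torch.cat([out[1] for out in self._val_outputs], 0)
        self._val_outputs.clear()
        for task_id in torch.unique(task_ids):
            self.log(f"val/loss_{task_id.item()}", losses[task_ids == task_id].mean())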