merge with instr_composition
zhansu committed Aug 16, 2023
2 parents 15dcb5d + dfbcfdb commit 5ca1b92
Showing 8 changed files with 643 additions and 185 deletions.
11 changes: 9 additions & 2 deletions finetune_llama.py
@@ -16,7 +16,8 @@
from mttl.datamodule.db_dolly_module import DatBricksDollyModule
from mttl.datamodule.longform_data_module import LongFormDataModule
from mttl.datamodule.wizzard_data_module import WizzardDataModule
from mttl.datamodule.flan_module import FlanModule

# from mttl.datamodule.flan_module import FlanModule
from mttl.models.encoder_decoder import EncoderDecoder
from mttl.models.t0_encoder_decoder import T0EncoderDecoder
from mttl.config import Config as MTTLConfig
@@ -79,6 +80,7 @@ def __init__(self, **kwargs):
self.micro_batch_size = 4
self.share_lora_at_attn = 0
self.share_lora_a = False
self.x_router_init_scale = 0.02
self.merge_A_B_seperately = True
self.train_on_inputs = False
self.padding_side = "right"
@@ -91,6 +93,9 @@ def __init__(self, **kwargs):
self.switch_to_average = 0
# self.balanced = 0

self.router_weight_decay = None
self.normalize_xrouter_weights = False
self.normalize_xrouter_input = False
self.reverse_xrouter_kl = False
self.param_names_added_to_sd = ""  # additional parameter names to include in the state dict, in addition to the trainable ones
self.xrouter_pad_token_mask = False
@@ -104,6 +109,8 @@ def __init__(self, **kwargs):
self.per_cluster_test = False
self.use_test_set = False  # whether to use examples marked is_test = 1 in ClusterInfo as the test set

self.superni_eval_batchsize = 2
self.router_learning_rate = None
self.sep_teacher_student = False
self.x_router_sim_metric = "kl"
self.eval_superni = True
@@ -403,7 +410,7 @@ def run_multitask(args):

rouge_L_super_ni = eval_superni(
model_name="",
batch_size=2,
batch_size=args.superni_eval_batchsize,
out_prefix=f"{args.exp_name}",
model_path=path_best_model,
nshot=0,
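The config hunks above add `superni_eval_batchsize` (now threaded into `eval_superni` in place of the hardcoded `batch_size=2`) together with router-specific optimization knobs such as `router_learning_rate` and `router_weight_decay`. The code that consumes those two knobs is not shown in this commit view; the sketch below is only an assumption of how such options are typically wired into separate optimizer parameter groups, and both the `build_optimizer` helper and the `"router"`-in-name selection rule are hypothetical.

```python
# Hypothetical sketch (not from this commit): put router parameters into their
# own AdamW parameter group so router_learning_rate / router_weight_decay apply
# only to them, falling back to the base values when the knobs are unset.
import torch


def build_optimizer(model: torch.nn.Module, args) -> torch.optim.AdamW:
    router_params, base_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # assumption: router modules are identifiable by "router" in the parameter name
        (router_params if "router" in name else base_params).append(param)

    groups = [
        {"params": base_params, "lr": args.learning_rate, "weight_decay": args.weight_decay}
    ]
    if router_params:
        groups.append(
            {
                "params": router_params,
                "lr": args.router_learning_rate or args.learning_rate,
                "weight_decay": args.router_weight_decay
                if args.router_weight_decay is not None
                else args.weight_decay,
            }
        )
    return torch.optim.AdamW(groups)
```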
40 changes: 33 additions & 7 deletions inst_follow/models/clm.py
@@ -2,8 +2,11 @@
import os
import sys
import torch
import copy
from torch import nn
from torch.optim.optimizer import Optimizer
import wandb
from collections import defaultdict
from torch import Tensor, nn
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM

sys.path.append("/home/v-oostapenko/dev/mttl")
@@ -25,7 +28,7 @@ def __init__(self, **kwargs):
self.tokenizer = kwargs["tokenizer"]
self.pad_token_id = self.tokenizer.pad_token_id
self.model: AutoModelForCausalLM = None

self.accumulate_metrics = defaultdict(list)
if kwargs.get('model_object') is None:
raise NotImplementedError()
self.model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -86,7 +89,7 @@ def forward(self, batch, reduction='mean'):
bs = input_ids.size(0)
logits = outputs.logits
vocab_size = logits.size(-1)
labels = labels.squeeze(-1)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
@@ -105,6 +108,13 @@ def forward(self, batch, reduction='mean'):
loss = loss.sum(dim=-1) / non_zero_loss
del outputs, shift_logits, shift_labels
aux_loss = torch.mean(torch.stack(self.model.task_id_container["routing_infos"].aux_loss)) if len(self.model.task_id_container["routing_infos"].aux_loss)>0 else 0


# log metrics if training
for k,v in self.model.task_id_container["routing_infos"].metrics.items():
self.accumulate_metrics[k].append(torch.tensor(v).mean())


return loss, aux_loss
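In the hunk above, `forward()` shifts the logits and labels by one position for causal LM training and, under `reduction='none'`, turns the token-level losses into one loss value per example before returning it alongside the routing `aux_loss`. A minimal standalone sketch of that shifted-loss pattern follows; the tensor shapes and the `pad_id`/`ignore_index` masking are assumptions (the code above normalizes by a `non_zero_loss` count instead), and `per_example_lm_loss` is a hypothetical helper, not part of the commit.

```python
# Sketch under assumed shapes: logits [bs, seq, vocab], labels [bs, seq].
# Positions equal to pad_id contribute zero loss and are excluded from the average.
import torch
from torch.nn import functional as F


def per_example_lm_loss(logits: torch.Tensor, labels: torch.Tensor, pad_id: int) -> torch.Tensor:
    shift_logits = logits[..., :-1, :].contiguous()  # predict token t+1 from position t
    shift_labels = labels[..., 1:].contiguous()
    token_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=pad_id,
        reduction="none",
    ).view(shift_labels.size())
    non_pad = (shift_labels != pad_id).sum(dim=-1).clamp(min=1)
    return token_loss.sum(dim=-1) / non_pad  # one loss value per example
```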

def calculate_routing_mask(self, x, routing_infos):
Expand Down Expand Up @@ -190,14 +200,25 @@ def generate(self, batch, **kwargs):
**kwargs
)

def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
# if wandb.run is not None:
# the PL logger only logs every 50 steps; we want to log before every update
for k,v in self.accumulate_metrics.items():
# log to wandb directly
# wandb.log({f"train/{k}": torch.tensor(v).mean()})
self.log(f"train/{k}", torch.tensor(v).mean(), on_step=True)
self.accumulate_metrics = defaultdict(list)
return super().on_before_optimizer_step(optimizer, optimizer_idx)

def training_step(self, batch, _):
loss, aux_loss = self.forward(batch)
total_loss = loss+aux_loss
# outputs = self.model.forward(**batch)
# loss= outputs.loss
self.log("train/loss", loss, on_epoch=True, prog_bar=True)
self.log("train/loss", loss, on_epoch=True, prog_bar=True)
self.log("train/aux_loss", aux_loss, on_epoch=True, prog_bar=True)
self.log("train/total_loss", total_loss, on_epoch=True, prog_bar=True)


for plugin in self.loss_plugins.values():
plugin_loss = plugin.compute_loss(self.model, batch)
@@ -211,17 +232,22 @@ def training_step(self, batch, _):
return total_loss

def validation_step(self, batch, batch_idx, log=True):
self.accumulate_metrics = defaultdict(list)
loss, aux_loss = self.forward(batch, reduction='none')
total_loss = loss #+aux_loss
# outputs = self.model.forward(**batch, reduction='none')
# loss= outputs.loss
mean_loss = total_loss.sum() / loss.size(0)
if log:
self.log("val/loss", mean_loss, on_epoch=True, prog_bar=True)
self.log("val/loss", mean_loss, on_epoch=True, prog_bar=True)
self.log("val/aux_loss", aux_loss, on_epoch=True, prog_bar=True)
for k,v in self.model.task_id_container["routing_infos"].metrics.items():
self.log(f"val/{k}", torch.tensor(v).mean(), on_epoch=True, prog_bar=True)
# self.log("val/total_loss", total_loss, on_epoch=True, prog_bar=True)
return loss, batch['task_ids']

def on_before_backward(self, loss: Tensor) -> None:
return super().on_before_backward(loss)

def test_step(self, batch, batch_idx):
loss, aux_loss = self.forward(batch, reduction='none')
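Taken together, the clm.py changes accumulate routing metrics into a `defaultdict` during `forward()`, average and log them in `on_before_optimizer_step()` so every optimizer update gets a log entry (rather than the PL logger's default every-50-steps cadence), and reset the buffer at the start of `validation_step()`. Below is a minimal, self-contained sketch of that accumulate/average/reset pattern; the `MetricAccumulator` class, the metric name, and the `log_fn` callable are illustrative inventions, not part of the commit.

```python
# Standalone sketch of the accumulate-then-flush metric pattern used above.
from collections import defaultdict

import torch


class MetricAccumulator:
    def __init__(self):
        self.buffer = defaultdict(list)

    def update(self, metrics: dict) -> None:
        # called once per forward pass (e.g. per micro-batch)
        for name, value in metrics.items():
            self.buffer[name].append(torch.as_tensor(value, dtype=torch.float32).mean())

    def flush(self, log_fn) -> None:
        # called right before each optimizer step: log the running mean, then reset
        for name, values in self.buffer.items():
            log_fn(f"train/{name}", torch.stack(values).mean())
        self.buffer = defaultdict(list)


# usage
acc = MetricAccumulator()
acc.update({"router_entropy": [0.4, 0.6]})  # list of per-layer values
acc.update({"router_entropy": 0.5})         # scalar value
acc.flush(lambda key, value: print(key, float(value)))
```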
