Tracking entropy + routing adaptation parameters
oleksost committed Aug 13, 2023
1 parent d936017 commit dfbcfdb
Showing 5 changed files with 169 additions and 113 deletions.
16 changes: 11 additions & 5 deletions finetune_llama.py
@@ -62,7 +62,7 @@ def remove_non_serializable(d):

class Config(MTTLConfig):
def __init__(self, **kwargs):
self.rank = 1
self.rank = 1
self.prune_unused_loras = True
self.init_b_random = False
self.lora_dropout = 0
@@ -72,6 +72,7 @@ def __init__(self, **kwargs):
self.micro_batch_size = 4
self.share_lora_at_attn = 0
self.share_lora_a = False
self.x_router_init_scale = 0.02
self.merge_A_B_seperately = True
self.train_on_inputs = False
self.padding_side = "right"
@@ -83,7 +84,10 @@ def __init__(self, **kwargs):
self.wandb_project = None
self.switch_to_average = 0
# self.balanced = 0


self.router_weight_decay = None
self.normalize_xrouter_weights = False
self.normalize_xrouter_input = False
self.reverse_xrouter_kl = False
self.param_names_added_to_sd = "" # additional params that will be added to the state dict, in addition to the trainable ones.
self.xrouter_pad_token_mask = False
@@ -96,7 +100,9 @@ def __init__(self, **kwargs):
self.validation_portion = 0.03
self.per_cluster_test = False
self.use_test_set = False # whether to use examples marked as is_test = 1 in ClusterInfo as the test set


self.superni_eval_batchsize = 2
self.router_learning_rate = None
self.sep_teacher_student = False
self.x_router_sim_metric = "kl"
self.eval_superni = True
@@ -377,8 +383,8 @@ def run_multitask(args):
if args.eval_superni:
print("Evaluating on super NI")
from inst_follow.eval.gen_ni_predictions import eval_superni
rouge_L_super_ni = eval_superni(model_name="",
batch_size=2,
rouge_L_super_ni = eval_superni(model_name="",
batch_size=args.superni_eval_batchsize,
out_prefix=f"{args.exp_name}",
model_path=path_best_model,
nshot=0, use_outputs=args.eval_superni_use_outputs)
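The finetune_llama.py changes above mostly introduce routing-related hyperparameters (x_router_init_scale, router_weight_decay, router_learning_rate, normalize_xrouter_weights, normalize_xrouter_input, reverse_xrouter_kl) plus superni_eval_batchsize, which now drives the batch size of the SuperNI evaluation call instead of the hard-coded 2. Below is a minimal sketch of how such kwargs-driven defaults typically resolve; RouterConfig is a hypothetical stand-in, not the repo's MTTLConfig.

class RouterConfig:
    """Hypothetical stand-in for Config above: set defaults, then let kwargs override."""

    def __init__(self, **kwargs):
        # defaults mirroring the knobs added in this commit
        self.x_router_init_scale = 0.02      # init scale for router weights
        self.router_learning_rate = None     # None -> fall back to the global learning_rate
        self.router_weight_decay = None      # None -> fall back to the global weight_decay
        self.superni_eval_batchsize = 2      # batch size for the SuperNI evaluation
        self.normalize_xrouter_weights = False
        self.normalize_xrouter_input = False
        self.reverse_xrouter_kl = False
        # any keyword explicitly passed by the caller overrides the default above
        for key, value in kwargs.items():
            setattr(self, key, value)

cfg = RouterConfig(router_learning_rate=1e-4, superni_eval_batchsize=4)
assert cfg.router_learning_rate == 1e-4 and cfg.x_router_init_scale == 0.02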
40 changes: 33 additions & 7 deletions inst_follow/models/clm.py
@@ -2,8 +2,11 @@
import os
import sys
import torch
import copy
from torch import nn
import copy
from torch.optim.optimizer import Optimizer
import wandb
from collections import defaultdict
from torch import Tensor, nn
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM

sys.path.append("/home/v-oostapenko/dev/mttl")
Expand All @@ -25,7 +28,7 @@ def __init__(self, **kwargs):
self.tokenizer = kwargs["tokenizer"]
self.pad_token_id = self.tokenizer.pad_token_id
self.model: AutoModelForCausalLM = None

self.accumulate_metrics = defaultdict(list)
if kwargs.get('model_object') is None:
raise NotImplementedError()
self.model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -86,7 +89,7 @@ def forward(self, batch, reduction='mean'):
bs = input_ids.size(0)
logits = outputs.logits
vocab_size = logits.size(-1)
labels = labels.squeeze(-1)
labels = labels.squeeze(-1)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
@@ -105,6 +108,13 @@ def forward(self, batch, reduction='mean'):
loss = loss.sum(dim=-1) / non_zero_loss
del outputs, shift_logits, shift_labels
aux_loss = torch.mean(torch.stack(self.model.task_id_container["routing_infos"].aux_loss)) if len(self.model.task_id_container["routing_infos"].aux_loss)>0 else 0


# log metrics if training
for k,v in self.model.task_id_container["routing_infos"].metrics.items():
self.accumulate_metrics[k].append(torch.tensor(v).mean())


return loss, aux_loss

def calculate_routing_mask(self, x, routing_infos):
@@ -190,14 +200,25 @@ def generate(self, batch, **kwargs):
**kwargs
)

def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
# if wandb.run is not None:
# pl logger logs every 50 steps, we want to log before every update
for k,v in self.accumulate_metrics.items():
# log to wandb directly
# wandb.log({f"train/{k}": torch.tensor(v).mean()})
self.log(f"train/{k}", torch.tensor(v).mean(), on_step=True)
self.accumulate_metrics = defaultdict(list)
return super().on_before_optimizer_step(optimizer, optimizer_idx)

def training_step(self, batch, _):
loss, aux_loss = self.forward(batch)
total_loss = loss+aux_loss
# outputs = self.model.forward(**batch)
# loss= outputs.loss
self.log("train/loss", loss, on_epoch=True, prog_bar=True)
self.log("train/loss", loss, on_epoch=True, prog_bar=True)
self.log("train/aux_loss", aux_loss, on_epoch=True, prog_bar=True)
self.log("train/total_loss", total_loss, on_epoch=True, prog_bar=True)


for plugin in self.loss_plugins.values():
plugin_loss = plugin.compute_loss(self.model, batch)
Expand All @@ -211,17 +232,22 @@ def training_step(self, batch, _):
return total_loss

def validation_step(self, batch, batch_idx, log=True):
self.accumulate_metrics = defaultdict(list)
loss, aux_loss = self.forward(batch, reduction='none')
total_loss = loss #+aux_loss
# outputs = self.model.forward(**batch, reduction='none')
# loss= outputs.loss
# loss= outputs.loss
mean_loss = total_loss.sum() / loss.size(0)
if log:
self.log("val/loss", mean_loss, on_epoch=True, prog_bar=True)
self.log("val/loss", mean_loss, on_epoch=True, prog_bar=True)
self.log("val/aux_loss", aux_loss, on_epoch=True, prog_bar=True)
for k,v in self.model.task_id_container["routing_infos"].metrics.items():
self.log(f"val/{k}", torch.tensor(v).mean(), on_epoch=True, prog_bar=True)
# self.log("val/total_loss", total_loss, on_epoch=True, prog_bar=True)
return loss, batch['task_ids']

def on_before_backward(self, loss: Tensor) -> None:
return super().on_before_backward(loss)

def test_step(self, batch, batch_idx):
loss, aux_loss = self.forward(batch, reduction='none')
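The clm.py changes buffer per-batch routing metrics in a defaultdict during forward() and flush their means in on_before_optimizer_step, so each logged train/<metric> value averages over the micro-batches of one optimizer step before the buffer is reset. Below is a minimal sketch of that accumulate-then-flush pattern, using routing entropy as the example metric; the class and metric names are illustrative, not the repo's.

import torch
from collections import defaultdict

class MetricAccumulator:
    """Buffer per-micro-batch metrics and average them once per optimizer step."""

    def __init__(self):
        self.buffer = defaultdict(list)

    def update(self, metrics):
        # called once per forward pass / micro-batch
        for name, value in metrics.items():
            self.buffer[name].append(torch.as_tensor(value, dtype=torch.float32).mean())

    def flush(self):
        # called right before the optimizer step (cf. on_before_optimizer_step above)
        averaged = {name: torch.stack(vals).mean().item() for name, vals in self.buffer.items()}
        self.buffer = defaultdict(list)  # reset for the next accumulation window
        return averaged

acc = MetricAccumulator()
routing_probs = torch.softmax(torch.randn(8, 4), dim=-1)                  # [batch, n_experts]
entropy = -(routing_probs * routing_probs.clamp_min(1e-9).log()).sum(-1)  # per-example routing entropy
acc.update({"routing_entropy": entropy})
acc.update({"routing_entropy": entropy})  # a second micro-batch in the same accumulation window
print(acc.flush())                        # e.g. {'routing_entropy': 1.23}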
21 changes: 17 additions & 4 deletions mttl/models/get_optimizer.py
@@ -25,10 +25,12 @@ def get_optimizer(model, args, no_decay=None):
):
if any(nd in param_name for nd in no_decay):
param_groups["no_decay"]["params"].append(param)
elif "module_logits" in param_name:
elif "module_logits" in param_name:
param_groups["module_logits"]["params"].append(param)
elif "lora" in param_name:
param_groups["adapters"]["params"].append(param)
elif "selector" in param_name:
param_groups["router"]["params"].append(param)
else:
param_groups["others"]["params"].append(param)
trainable_param_names.add(param_name)
@@ -48,12 +50,23 @@
if args.adapters_weight_decay is not None
else args.weight_decay
)
param_groups[key]["lr"] = (
args.adapters_learning_rate
param_groups[key]["lr"] = (
args.adapters_learning_rate
if key in ["adapters"] and args.adapters_learning_rate
else args.learning_rate
)
else:
elif key in ["router"]:
param_groups[key]["weight_decay"] = (
args.router_weight_decay
if args.router_weight_decay is not None
else args.weight_decay
)
param_groups[key]["lr"] = (
args.router_learning_rate
if key in ["router"] and args.router_learning_rate
else args.learning_rate
)
else:
param_groups[key]["weight_decay"] = (
0.0 if key in ["module_logits", "no_decay"] else args.weight_decay
)
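The get_optimizer.py changes send parameters whose names contain "selector" to a separate "router" group, which gets router_learning_rate / router_weight_decay when those are set and otherwise falls back to the global values. Below is a minimal sketch of that name-based grouping with torch.optim.AdamW; ToyModel and build_param_groups are hypothetical, not the repo's helpers.

import torch
from torch import nn

def build_param_groups(model, base_lr, base_wd, router_lr=None, router_wd=None):
    # Parameters whose names contain 'selector' get router-specific lr / weight decay;
    # everything else keeps the base values.
    router_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (router_params if "selector" in name else other_params).append(param)
    return [
        {"params": other_params, "lr": base_lr, "weight_decay": base_wd},
        {"params": router_params,
         "lr": router_lr if router_lr is not None else base_lr,
         "weight_decay": router_wd if router_wd is not None else base_wd},
    ]

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.selector = nn.Linear(16, 4)  # name contains 'selector' -> router group

model = ToyModel()
optimizer = torch.optim.AdamW(
    build_param_groups(model, base_lr=3e-4, base_wd=0.01, router_lr=1e-3, router_wd=0.0)
)
assert len(optimizer.param_groups) == 2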