basic knot merging
oleksost committed Nov 13, 2024
1 parent 8dff094 commit 55768b2
Showing 2 changed files with 245 additions and 21 deletions.
212 changes: 192 additions & 20 deletions mttl/models/library/library_transforms.py
@@ -1,9 +1,11 @@
import abc
import copy
import dataclasses
import os
import re
from abc import abstractmethod
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Union

@@ -24,7 +26,9 @@
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.lightning.callbacks import LiveCheckpointCallback
from mttl.models.lightning.loggers import get_pl_loggers
from mttl.models.modifiers._delta_mod import WDeltaRAConfig
from mttl.models.modifiers.base import get_target_2_source_param_mapping
from mttl.models.modifiers.lora import LoRAConfig
from mttl.models.monitors import get_monitors
from mttl.models.utils import transfer_batch_to_device
from mttl.registrable import Registrable
@@ -320,6 +324,166 @@ def transform(self, library) -> Expert:
return base_expert


@dataclass
class KnotMergeConfig(WeightedLinearMergeConfig):
path: str = None # path to store SVD components


@LibraryTransform.register("weighted_knot_merge", KnotMergeConfig)
class KnotMerge(LibraryTransform):
"""
Computes a weighted KnoT merge for LoRA experts as in https://arxiv.org/pdf/2410.19735
"""

def __init__(self, config: KnotMergeConfig = None):
super().__init__(config or KnotMergeConfig())
self.ingredients = None

def transform(self, library) -> Expert:
if type(library) == str:
library = ExpertLibrary.get_expert_library(library)
# TODO: this should probably be stored in the library. It's not related to any expert, but the current library.add_auxiliary_data requires that aux data is associated with an expert.
if not os.path.exists(self.config.path):
U, task_Ss, task_sVs, UsV_dict = self.apply_svd(library)

self.ingredients = {
"U": U,
"task_Ss": task_Ss,
"task_sVs": task_sVs, # premultiplied with s, cause Vs alone do not have the scale info.
"UsV_dict": UsV_dict,
}

torch.save(self.ingredients, self.config.path)
self.ingredients = torch.load(self.config.path)
task_sVs = self.ingredients["task_sVs"]
U = self.ingredients["U"]
ties_merger = TiesMerge()

# Prepare for Ties merging of sVs
expert_vectors = []
for params in task_sVs:
expert_vectors += [
torch.nn.utils.parameters_to_vector(
[params[k] for k in params.keys()]
)
]

state_dict = {}
expert_vectors = torch.stack(expert_vectors, dim=0)
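# Per-expert magnitude threshold: keep only the top_k fraction of entries
# (by absolute value) of each expert's flattened sV vector, as in TIES.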
per_exp_th = expert_vectors.abs().quantile(
1.0 - ties_merger.config.top_k, dim=1
)
param_names = list(task_sVs[0].keys())

for p_name in param_names:
expert_weights = torch.stack([expert[p_name] for expert in task_sVs], dim=0)
TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1)))
final_param, _, _ = ties_merger.merge_param(TH, expert_weights)
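# Map the TIES-merged sV block back through the shared basis U to recover
# a dense weight delta for this layer.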
delta_W = U[p_name] @ final_param # out_features, in_features
state_dict[p_name + ".weight"] = delta_W

expert_names = list(library.keys())
base_expert = copy.deepcopy(library[expert_names[0]])
expert_config = WDeltaRAConfig(
modify_modules=base_expert.expert_config.modify_modules,
modify_layers=base_expert.expert_config.modify_layers,
)
expert_info = base_expert.expert_info
expert_info.expert_name = "knoted_expert"
expert_info.expert_task_name = "+".join(expert_names)
expert_info.expert_config = expert_config
return Expert(expert_info, state_dict)

def apply_svd(self, library):
"""
Reused from https://github.com/gstoica27/KnOTS/blob/main/task_merger.py
"""
expert_names = list(library.keys())
experts = [library[name] for name in expert_names]

logger.info("Knotting {} experts".format(len(experts)))

base_expert = copy.deepcopy(experts[0])
base_expert.name = "weighted_expert"

if self.config.weights is not None:
assert set(self.config.weights.keys()) == set(
expert_names
), "Weights must have the same keys as the experts"
if not (1 - 1e-6) <= sum(self.config.weights.values()) <= (1 + 1e-6):
logger.warning(
"Weights do not sum to 1.0, please make sure this is intended"
)

layers = set(
[
k.split(".lora")[0]
for k in base_expert.expert_weights.keys()
if ".lora" in k
]
)
d_in, d_out = (
base_expert.expert_weights[f"{list(layers)[0]}.lora_a"].shape[0],
base_expert.expert_weights[f"{list(layers)[0]}.lora_b"].shape[1],
)

UsV_dict = {}
basis_dict = {} # basis for reconstruction
s_compositions_dict = [
dict() for _ in range(len(experts))
] # singular values composition information per task
V_compositions_dict = [
dict() for _ in range(len(experts))
] # basis composition information per task

for layer in layers:
Ws = []
logger.info(f"Computing KnoT merge for layer {layer}")
# retrieve lora A and B from all experts
# and build the dense delta W for this layer
for _, expert in zip(expert_names, experts):
# Validate that the expert is compatible
assert (
type(expert.expert_info.expert_config) == LoRAConfig
), "Expert config must be a LoRAConfig"
assert set(expert.expert_weights.keys()) == set(
base_expert.expert_weights.keys()
), "Expert weights must have the same keys"
lora_a = expert.expert_weights[f"{layer}.lora_a"]
lora_b = expert.expert_weights[f"{layer}.lora_b"]
rank = expert.expert_config.lora_rank
assert (
lora_b.shape[0] == lora_a.shape[1] == rank
), "lora_a and lora_b must have rank equal to the expert's lora_rank"
W = (lora_a @ lora_b).T # out_features, in_features
Ws.append(W)

# SVD
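# A single SVD of the column-wise concatenation of all task deltas yields a
# basis U shared across tasks; each task's block of columns in Vt holds its
# task-specific coordinates in that basis.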
device = "cuda" if torch.cuda.is_available() else "cpu"
W_l = torch.cat(Ws, dim=1).to(device)
U, s, Vt = torch.linalg.svd(W_l.to(torch.float64), full_matrices=False)
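# Keep only directions with non-negligible singular values.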
U = U[:, s > 1e-5].type(torch.float32)
Vt = Vt[s > 1e-5].type(torch.float32)
s = s[s > 1e-5].type(torch.float32)
UsV_dict[layer] = {"U": deepcopy(U), "s": deepcopy(s), "V": deepcopy(Vt)}
# zero out any remaining near-zero singular values (a no-op after the filtering above, kept as a safeguard)
s[s <= 1e-5] = 0
cat_hidden_dim = Vt.shape[1] // len(experts)
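# Vt holds one contiguous block of columns per task; split it back into per-task pieces below.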

basis_dict[layer] = U.cpu()
sV_concat = Vt
Vs = list(torch.split(sV_concat, cat_hidden_dim, dim=1))
for idx, V in enumerate(Vs):
V = (
torch.diag(s) @ V
) # We use TIES merging, which relies on magnitude info that is not present in V alone. Comment from the original code base: simple and safe for all merging methods we use.
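# With the scale folded into V, keep a normalized (all-ones) singular-value vector per task.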
s_model = s / s

s_compositions_dict[idx][layer] = s_model.cpu()
V_compositions_dict[idx][layer] = V.cpu()
return basis_dict, s_compositions_dict, V_compositions_dict, UsV_dict


@dataclass
class TiesMergeConfig(LibraryTransformConfig):
top_k: float = 0.2
@@ -337,6 +501,30 @@ def __init__(self, config: TiesMergeConfig = None):

assert self.config.top_k > 0.0 and self.config.top_k <= 1.0

@torch.no_grad()
def merge_param(self, TH, expert_weights):
# keep weights over the threshold
keep_mask = expert_weights.abs() >= TH
expert_weights = expert_weights * keep_mask
used = 0

if self.config.only_sparsify:
final_param = expert_weights.mean(0)
used += keep_mask.sum().item()
else:
# sign majority vote
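# Note: the second assignment below overrides the first, so the elected sign comes
# from the sign of the summed weights (magnitude-weighted) rather than a per-element
# sign count; this mirrors the original inline implementation.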
sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
sign_per_dim = expert_weights.sum(0, keepdim=True).sign()

# keep only weights whose sign agree with the majority
use_for_avg = expert_weights.sign() == sign_per_dim

deno = use_for_avg.sum(0).clamp(min=1.0)
sum_param = (expert_weights * use_for_avg).sum(0)
final_param = sum_param / deno
used += (use_for_avg & (sign_per_dim != 0.0)).sum().item()
return final_param, used, expert_weights

@torch.no_grad()
def transform(self, library) -> Expert:
if type(library) == str:
@@ -376,28 +564,12 @@ def transform(self, library) -> Expert:
expert_weights = torch.stack(
[expert.expert_weights[param_name] for expert in experts], dim=0
)

# keep weights over the threshold
TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1)))
keep_mask = expert_weights.abs() >= TH
expert_weights = expert_weights * keep_mask

if self.config.only_sparsify:
final_param = expert_weights.mean(0)
used += keep_mask.sum().item()
else:
# sign majority vote
sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
sign_per_dim = expert_weights.sum(0, keepdim=True).sign()

# keep only weights whose sign agree with the majority
use_for_avg = expert_weights.sign() == sign_per_dim

deno = use_for_avg.sum(0).clamp(min=1.0)
sum_param = (expert_weights * use_for_avg).sum(0)
final_param = sum_param / deno
used += (use_for_avg & (sign_per_dim != 0.0)).sum().item()
final_param, used_per_param, expert_weights = self.merge_param(
TH, expert_weights
)

used += used_per_param
kept += (expert_weights.abs() > TH).sum()
total += expert_weights.numel()

54 changes: 53 additions & 1 deletion tests/test_library_transforms.py
@@ -11,13 +11,20 @@
from mttl.arguments import ExpertConfig
from mttl.logging import logger
from mttl.models.containers.selectors import PhatgooseSelector, PhatgooseSelectorConfig
from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig
from mttl.models.expert_model import (
ExpertModel,
ExpertModelConfig,
MultiExpertModel,
MultiExpertModelConfig,
)
from mttl.models.library.expert_library import HFExpertLibrary, LocalExpertLibrary
from mttl.models.library.library_transforms import (
ArrowConfig,
ArrowTransform,
HiddenStateComputer,
HiddenStateComputerConfig,
KnotMerge,
KnotMergeConfig,
MBClusteringTransformConfig,
MBCWithCosSimTransform,
PhatgooseConfig,
@@ -39,6 +46,51 @@ def test_config():
cfg3 = ArrowConfig(ab_only=True, scale=False)
assert cfg3.save_name == cfg.save_name

def test_knot_merge(tmp_path, create_dummy_expert):
config = ExpertConfig(
**{
"model_modifier": "lora",
"lora_rank": 8,
"lora_alpha": 16,
"warmup_steps": 0,
"modify_layers": "c_fc",
"trainable_param_names": ".*lora_[ab].*",
"output_dir": tmp_path,
"precision": "32",
"model": "EleutherAI/gpt-neo-125m",
"device_map": "cpu",
"dataset_type": "flan",
"lora_init_b_random": True, # this is important otw phatgoose gates are 0 given that the experts are not trained
}
)
model = ExpertModel(ExpertModelConfig(base_model="EleutherAI/gpt-neo-125m"))

config.finetune_task_name = "cot_creak"
expert1 = create_dummy_expert(config, "cot_creak")

config.finetune_task_name = "cot_creak_ii"
expert2 = create_dummy_expert(config, "cot_creak_ii")
# only keep 1 layer to speed things up.
expert1.expert_weights = {k: v for k, v in expert1.expert_weights.items() if "8.mlp" in k}
expert2.expert_weights = {k: v for k, v in expert2.expert_weights.items() if "8.mlp" in k}

library = LocalExpertLibrary(tmp_path)
library.add_expert(expert1)
library.add_expert(expert2)

transform = KnotMerge(KnotMergeConfig(path=f"{tmp_path}/knot_ingredients.pt"))
exp = transform.transform(library)
state_dict = model.model.state_dict()

# TODO: this could be implemented as a separate modifier or a utility function.
merged_layers = []
for p_name, value in exp.expert_weights.items():
if p_name in state_dict:
merged_layers.append(p_name)
state_dict[p_name] += value
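# Only the block-8 c_fc layer was kept above, so exactly one merged delta is expected.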
assert len(merged_layers) == len(exp.expert_weights.keys()) == 1


def test_arrow():
logger.setLevel(logging.DEBUG)
