basic knot merging #140

Open · wants to merge 4 commits into main
213 changes: 193 additions & 20 deletions mttl/models/library/library_transforms.py
@@ -1,9 +1,11 @@
import abc
import copy
import dataclasses
import os
import re
from abc import abstractmethod
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Union

@@ -24,7 +26,9 @@
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.lightning.callbacks import LiveCheckpointCallback
from mttl.models.lightning.loggers import get_pl_loggers
from mttl.models.modifiers._delta_mod import WDeltaRAConfig
from mttl.models.modifiers.base import get_target_2_source_param_mapping
from mttl.models.modifiers.lora import LoRAConfig
from mttl.models.monitors import get_monitors
from mttl.models.utils import transfer_batch_to_device
from mttl.registrable import Registrable
@@ -320,6 +324,167 @@ def transform(self, library) -> Expert:
return base_expert


@dataclass
class KnotMergeConfig(WeightedLinearMergeConfig):
path: str = None # path to store SVD components


@LibraryTransform.register("weighted_knot_merge", KnotMergeConfig)
class KnotMerge(LibraryTransform):
"""
Computes a weighted KnoT merge for LoRA experts, as in https://arxiv.org/pdf/2410.19735
"""

def __init__(self, config: KnotMergeConfig = None):
super().__init__(config or KnotMergeConfig())
self.ingredients = None

def transform(self, library) -> Expert:
if type(library) == str:
library = ExpertLibrary.get_expert_library(library)
# TODO: this should probably be stored in the library instead of a local path.
# The current library.add_auxiliary_data requires aux data to be associated with an expert; these ingredients are not tied to any single expert.
if not os.path.exists(self.config.path):
U, task_Ss, task_sVs, UsV_dict = self.apply_svd(library)

self.ingredients = {
"U": U,
"task_Ss": task_Ss,
"task_sVs": task_sVs, # premultiplied with s, cause Vs alone do not have the scale info.
"UsV_dict": UsV_dict,
}

torch.save(self.ingredients, self.config.path)
self.ingredients = torch.load(self.config.path)
task_sVs = self.ingredients["task_sVs"]
U = self.ingredients["U"]
ties_merger = TiesMerge()

# Prepare for Ties merging of sVs
expert_vectors = []
for expert, params in enumerate(task_sVs):
expert_vectors += [
torch.nn.utils.parameters_to_vector(
list(params[k] for k in params.keys())
)
]

state_dict = {}
expert_vectors = torch.stack(expert_vectors, dim=0)
per_exp_th = expert_vectors.abs().quantile(
1.0 - ties_merger.config.top_k, dim=1
)
param_names = list(task_sVs[0].keys())

for p_name in param_names:
expert_weights = torch.stack([expert[p_name] for expert in task_sVs], dim=0)
TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1)))
final_param, _, _ = ties_merger.merge_param(TH, expert_weights)
delta_W = U[p_name] @ final_param # out_features, in_features
state_dict[p_name + ".weight"] = delta_W

expert_names = list(library.keys())
base_expert = copy.deepcopy(library[expert_names[0]])
expert_config = WDeltaRAConfig(
modify_modules=base_expert.expert_config.modify_modules,
modify_layers=base_expert.expert_config.modify_layers,
)
expert_info = base_expert.expert_info
expert_info.expert_name = "knoted_expert"
expert_info.expert_task_name = "+".join(expert_names)
expert_info.expert_config = expert_config
return Expert(expert_info, state_dict)

def apply_svd(self, library):
"""
Reused from https://github.com/gstoica27/KnOTS/blob/main/task_merger.py
"""
expert_names = list(library.keys())
experts = [library[name] for name in expert_names]

logger.info("Knotting {} experts".format(len(experts)))

base_expert = copy.deepcopy(experts[0])
base_expert.name = "weighted_expert"

if self.config.weights is not None:
assert set(self.config.weights.keys()) == set(
expert_names
), "Weights must have the same keys as the experts"
if not (1 - 1e-6) <= sum(self.config.weights.values()) <= (1 + 1e-6):
logger.warning(
"Weights do not sum to 1.0, please make sure this is intended"
)

layers = set(
[
k.split(".lora")[0]
for k in base_expert.expert_weights.keys()
if ".lora" in k
]
)
d_in, d_out = (
base_expert.expert_weights[f"{list(layers)[0]}.lora_a"].shape[0],
base_expert.expert_weights[f"{list(layers)[0]}.lora_b"].shape[1],
)

UsV_dict = {}
basis_dict = {} # basis for reconstruction
s_compositions_dict = [
dict() for _ in range(len(experts))
] # singular values composition information per task
V_compositions_dict = [
dict() for _ in range(len(experts))
] # basis composition information per task

for layer in layers:
Ws = []
logger.info(f"Computing KnoT merge for layer {layer}")
# retrieve lora A and B from all experts
# create W
for _, expert in zip(expert_names, experts):
# Validate that the expert is compatible
assert (
type(expert.expert_info.expert_config) == LoRAConfig
), "Expert configs must be the same type"
assert set(expert.expert_weights.keys()) == set(
base_expert.expert_weights.keys()
), "Expert weights must have the same keys"
lora_a = expert.expert_weights[f"{layer}.lora_a"]
lora_b = expert.expert_weights[f"{layer}.lora_b"]
rank = expert.expert_config.lora_rank
assert (
lora_b.shape[0] == lora_a.shape[1] == rank
), "lora_a and lora_a must have the same rank as the expert"
W = (lora_a @ lora_b).T # out_features, in_features
Ws.append(W)

# SVD
device = "cuda" if torch.cuda.is_available() else "cpu"
W_l = torch.cat(Ws, dim=1).to(device)
U, s, Vt = torch.linalg.svd(W_l.to(torch.float64), full_matrices=False)
U = U[:, s > 1e-5].type(torch.float32)
Vt = Vt[s > 1e-5].type(torch.float32)
s = s[s > 1e-5].type(torch.float32)
UsV_dict[layer] = {"U": deepcopy(U), "s": deepcopy(s), "V": deepcopy(Vt)}
# zero out any residual tiny singular values (no-op after the filtering above)
s[s <= 1e-5] = 0
cat_hidden_dim = Vt.shape[1] // len(experts)

basis_dict[layer] = U.cpu()
sV_concat = Vt
Vs = list(torch.split(sV_concat, cat_hidden_dim, dim=1))
for idx, V in enumerate(Vs):
V = (
torch.diag(s) @ V
)  # We use TIES merging, which relies on magnitude info that is not present in V alone. Comment from the original code base: "Simple and safe for all merging methods we use."
s_model = s / s  # all ones: the scale now lives in V

s_compositions_dict[idx][layer] = s_model.cpu()
V_compositions_dict[idx][layer] = V.cpu()
return basis_dict, s_compositions_dict, V_compositions_dict, UsV_dict
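
# ---------------------------------------------------------------------------
# Illustrative standalone sketch (not part of this module): the decomposition
# performed by `apply_svd` above, at toy scale. Per-task LoRA deltas are
# concatenated column-wise, factored with a single SVD, and each task keeps its
# own block of the s-scaled right singular vectors while U is shared, so that
# U @ sV_t reconstructs each task's delta. The dimensions below are made up.
import torch

d_out, d_in, rank, n_tasks = 16, 8, 2, 2
toy_deltas = []
for _ in range(n_tasks):
    toy_a = torch.randn(d_in, rank)       # (in_features, rank)
    toy_b = torch.randn(rank, d_out)      # (rank, out_features)
    toy_deltas.append((toy_a @ toy_b).T)  # (out_features, in_features)
W_cat = torch.cat(toy_deltas, dim=1)      # (d_out, n_tasks * d_in)
U_, s_, Vt_ = torch.linalg.svd(W_cat.to(torch.float64), full_matrices=False)
U_, s_, Vt_ = U_[:, s_ > 1e-5].float(), s_[s_ > 1e-5].float(), Vt_[s_ > 1e-5].float()
sV_blocks = torch.split(torch.diag(s_) @ Vt_, d_in, dim=1)  # one block per task
for toy_delta, sV in zip(toy_deltas, sV_blocks):
    assert torch.allclose(U_ @ sV, toy_delta, atol=1e-4)  # exact up to float error
# ---------------------------------------------------------------------------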


@dataclass
class TiesMergeConfig(LibraryTransformConfig):
top_k: float = 0.2
@@ -337,6 +502,30 @@ def __init__(self, config: TiesMergeConfig = None):

assert self.config.top_k > 0.0 and self.config.top_k <= 1.0

@torch.no_grad()
def merge_param(self, TH, expert_weights):
# keep weights over the threshold
keep_mask = expert_weights.abs() >= TH
expert_weights = expert_weights * keep_mask
used = 0

if self.config.only_sparsify:
final_param = expert_weights.mean(0)
used += keep_mask.sum().item()
else:
# sign majority vote
sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
sign_per_dim = expert_weights.sum(0, keepdim=True).sign()

# keep only weights whose sign agree with the majority
use_for_avg = expert_weights.sign() == sign_per_dim

deno = use_for_avg.sum(0).clamp(min=1.0)
sum_param = (expert_weights * use_for_avg).sum(0)
final_param = sum_param / deno
used += (use_for_avg & (sign_per_dim != 0.0)).sum().item()
return final_param, used, expert_weights
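
# ---------------------------------------------------------------------------
# Illustrative standalone sketch (not part of this module): the sign-resolution
# step of `merge_param` above on a toy 2-expert / 2-parameter tensor. For the
# first entry the proposals are +1.0 and -0.2; the summed sign is positive, so
# only +1.0 survives and the merged value is 1.0 rather than the plain mean 0.4.
import torch

toy_weights = torch.tensor([[1.0, -0.3],
                            [-0.2, -0.5]])          # (num_experts, num_params)
toy_sign = toy_weights.sum(0, keepdim=True).sign()  # majority sign per entry
toy_agree = toy_weights.sign() == toy_sign          # keep agreeing entries only
toy_merged = (toy_weights * toy_agree).sum(0) / toy_agree.sum(0).clamp(min=1.0)
assert torch.allclose(toy_merged, torch.tensor([1.0, -0.4]))
# ---------------------------------------------------------------------------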

@torch.no_grad()
def transform(self, library) -> Expert:
if type(library) == str:
@@ -376,28 +565,12 @@ def transform(self, library) -> Expert:
expert_weights = torch.stack(
[expert.expert_weights[param_name] for expert in experts], dim=0
)

# keep weights over the threshold
TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1)))
keep_mask = expert_weights.abs() >= TH
expert_weights = expert_weights * keep_mask

if self.config.only_sparsify:
final_param = expert_weights.mean(0)
used += keep_mask.sum().item()
else:
# sign majority vote
sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
sign_per_dim = expert_weights.sum(0, keepdim=True).sign()

# keep only weights whose sign agree with the majority
use_for_avg = expert_weights.sign() == sign_per_dim

deno = use_for_avg.sum(0).clamp(min=1.0)
sum_param = (expert_weights * use_for_avg).sum(0)
final_param = sum_param / deno
used += (use_for_avg & (sign_per_dim != 0.0)).sum().item()
final_param, used_per_param, expert_weights = self.merge_param(
TH, expert_weights
)

used += used_per_param
kept += (expert_weights.abs() > TH).sum()
total += expert_weights.numel()

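For reference, a minimal sketch of inspecting the SVD ingredients that KnotMerge.transform caches at KnotMergeConfig.path (the keys match those saved above; the file name is illustrative):

import torch

ingredients = torch.load("knot_ingredients.pt")
U = ingredients["U"]                # dict: layer name -> shared left singular vectors
task_sVs = ingredients["task_sVs"]  # list (one dict per expert): layer name -> s-scaled V block
for layer, basis in U.items():
    for task_idx, sV in enumerate(task_sVs):
        # (out_features, k) @ (k, in_features) -> that task's full delta for this layer
        print(layer, task_idx, (basis @ sV[layer]).shape)
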
58 changes: 57 additions & 1 deletion tests/test_library_transforms.py
@@ -11,13 +11,20 @@
from mttl.arguments import ExpertConfig
from mttl.logging import logger
from mttl.models.containers.selectors import PhatgooseSelector, PhatgooseSelectorConfig
from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig
from mttl.models.expert_model import (
ExpertModel,
ExpertModelConfig,
MultiExpertModel,
MultiExpertModelConfig,
)
from mttl.models.library.expert_library import HFExpertLibrary, LocalExpertLibrary
from mttl.models.library.library_transforms import (
ArrowConfig,
ArrowTransform,
HiddenStateComputer,
HiddenStateComputerConfig,
KnotMerge,
KnotMergeConfig,
MBClusteringTransformConfig,
MBCWithCosSimTransform,
PhatgooseConfig,
@@ -40,6 +47,55 @@ def test_config():
assert cfg3.save_name == cfg.save_name


def test_knot_merge(tmp_path, create_dummy_expert):
config = ExpertConfig(
**{
"model_modifier": "lora",
"lora_rank": 8,
"lora_alpha": 16,
"warmup_steps": 0,
"modify_layers": "c_fc",
"trainable_param_names": ".*lora_[ab].*",
"output_dir": tmp_path,
"precision": "32",
"model": "EleutherAI/gpt-neo-125m",
"device_map": "cpu",
"dataset_type": "flan",
"lora_init_b_random": True, # this is important otw phatgoose gates are 0 given that the experts are not trained
}
)
model = ExpertModel(ExpertModelConfig(base_model="EleutherAI/gpt-neo-125m"))

config.finetune_task_name = "cot_creak"
expert1 = create_dummy_expert(config, "cot_creak")

config.finetune_task_name = "cot_creak_ii"
expert2 = create_dummy_expert(config, "cot_creak_ii")
# keep only 1 layer to speed things up.
expert1.expert_weights = {
k: v for k, v in expert1.expert_weights.items() if "8.mlp" in k
}
expert2.expert_weights = {
k: v for k, v in expert2.expert_weights.items() if "8.mlp" in k
}

library = LocalExpertLibrary(tmp_path)
library.add_expert(expert1)
library.add_expert(expert2)

transform = KnotMerge(KnotMergeConfig(path=f"{tmp_path}/knot_ingredients.pt"))
exp = transform.transform(library)
state_dict = model.model.state_dict()

# TODO: this could be implemented as a separate modifier or a utils function.
merged_layers = []
for p_name, value in exp.expert_weights.items():
if p_name in state_dict:
merged_layers.append(p_name)
state_dict[p_name] += value
assert len(merged_layers) == len(exp.expert_weights.keys()) == 1
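
# A possible extra sanity check (not part of this PR), assuming `torch` is
# available in this test module: the ingredients cached by the transform should
# reconstruct deltas whose output dimension matches the merged weights.
ingredients = torch.load(f"{tmp_path}/knot_ingredients.pt")
for layer, U in ingredients["U"].items():
    assert exp.expert_weights[layer + ".weight"].shape[0] == U.shape[0]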


def test_arrow():
logger.setLevel(logging.DEBUG)
