From 55768b2f8ccdb537a940548038638d86cb359416 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 13 Nov 2024 11:32:18 -0500 Subject: [PATCH] basic knot merging --- mttl/models/library/library_transforms.py | 212 ++++++++++++++++++++-- tests/test_library_transforms.py | 54 +++++- 2 files changed, 245 insertions(+), 21 deletions(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index f2dce145c..0095bc958 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -1,9 +1,11 @@ import abc import copy import dataclasses +import os import re from abc import abstractmethod from collections import defaultdict +from copy import deepcopy from dataclasses import dataclass from typing import Dict, List, Union @@ -24,7 +26,9 @@ from mttl.models.library.expert_library import ExpertLibrary from mttl.models.lightning.callbacks import LiveCheckpointCallback from mttl.models.lightning.loggers import get_pl_loggers +from mttl.models.modifiers._delta_mod import WDeltaRAConfig from mttl.models.modifiers.base import get_target_2_source_param_mapping +from mttl.models.modifiers.lora import LoRAConfig from mttl.models.monitors import get_monitors from mttl.models.utils import transfer_batch_to_device from mttl.registrable import Registrable @@ -320,6 +324,166 @@ def transform(self, library) -> Expert: return base_expert +@dataclass +class KnotMergeConfig(WeightedLinearMergeConfig): + path: str = None # path to store SVD components + + +@LibraryTransform.register("weighted_knot_merge", KnotMergeConfig) +class KnotMerge(LibraryTransform): + """ + Computes a weighted KnoT merge for LoRA ezperts as in https://arxiv.org/pdf/2410.19735 + """ + + def __init__(self, config: KnotMergeConfig = None): + super().__init__(config or KnotMergeConfig()) + self.ingredients = None + + def transform(self, library) -> Expert: + if type(library) == str: + library = ExpertLibrary.get_expert_library(library) + # TODO: this should probably be stored in the library. Its not related to any expert, but current libary.add_auxiliary_data requires that aux data is associated with an expert. + if not os.path.exists(self.config.path): + U, task_Ss, task_sVs, UsV_dict = self.apply_svd(library) + + self.ingredients = { + "U": U, + "task_Ss": task_Ss, + "task_sVs": task_sVs, # premultiplied with s, cause Vs alone do not have the scale info. + "UsV_dict": UsV_dict, + } + + torch.save(self.ingredients, self.config.path) + self.ingredients = torch.load(self.config.path) + task_sVs = self.ingredients["task_sVs"] + U = self.ingredients["U"] + ties_mergert = TiesMerge() + + # Prepare for Ties merging of sVs + expert_vectors = [] + for expert, params in enumerate(task_sVs): + expert_vectors += [ + torch.nn.utils.parameters_to_vector( + list(params[k] for k in params.keys()) + ) + ] + + state_dict = {} + expert_vectors = torch.stack(expert_vectors, dim=0) + per_exp_th = expert_vectors.abs().quantile( + 1.0 - ties_mergert.config.top_k, dim=1 + ) + param_names = list(task_sVs[0].keys()) + + for p_name in param_names: + expert_weights = torch.stack([expert[p_name] for expert in task_sVs], dim=0) + TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1))) + final_param, _, _ = ties_mergert.merge_param(TH, expert_weights) + delta_W = U[p_name] @ final_param # out_features, in_features + state_dict[p_name + ".weight"] = delta_W + + expert_names = list(library.keys()) + base_expert = copy.deepcopy(library[expert_names[0]]) + expert_config = WDeltaRAConfig( + modify_modules=base_expert.expert_config.modify_modules, + modify_layers=base_expert.expert_config.modify_layers, + ) + expert_info = base_expert.expert_info + expert_info.expert_name = "knoted_expert" + expert_info.expert_task_name = "+".join(expert_names) + expert_info.expert_config = expert_config + return Expert(expert_info, state_dict) + + def apply_svd(self, library): + """ + Reused from https://github.com/gstoica27/KnOTS/blob/main/task_merger.py + """ + expert_names = list(library.keys()) + experts = [library[name] for name in expert_names] + + logger.info("Knotting {} experts".format(len(experts))) + + base_expert = copy.deepcopy(experts[0]) + base_expert.name = "weighted_expert" + + if self.config.weights is not None: + assert set(self.config.weights.keys()) == set( + expert_names + ), "Weights must have the same keys as the experts" + if not (1 - 1e-6) <= sum(self.config.weights.values()) <= (1 + 1e-6): + logger.warning( + "Weights do not sum to 1.0, please make sure this is intended" + ) + + layers = set( + [ + k.split(".lora")[0] + for k in base_expert.expert_weights.keys() + if ".lora" in k + ] + ) + d_in, d_out = ( + base_expert.expert_weights[f"{list(layers)[0]}.lora_a"].shape[0], + base_expert.expert_weights[f"{list(layers)[0]}.lora_b"].shape[1], + ) + + UsV_dict = {} + basis_dict = {} # basis for reconstruction + s_compositions_dict = [ + dict() for _ in range(len(experts)) + ] # singular values composition information per task + V_compositions_dict = [ + dict() for _ in range(len(experts)) + ] # basis composition information per task + + for layer in layers: + Ws = [] + logger.info(f"Computing KnoT merge for layer {layer}") + # retreieve lora A and B from all experts + # create W + for _, expert in zip(expert_names, experts): + # Validate that the expert is compatible + assert ( + type(expert.expert_info.expert_config) == LoRAConfig + ), "Expert configs must be the same type" + assert set(expert.expert_weights.keys()) == set( + base_expert.expert_weights.keys() + ), "Expert weights must have the same keys" + lora_a = expert.expert_weights[f"{layer}.lora_a"] + lora_b = expert.expert_weights[f"{layer}.lora_b"] + rank = expert.expert_config.lora_rank + assert ( + lora_b.shape[0] == lora_a.shape[1] == rank + ), "lora_a and lora_a must have the same rank as the expert" + W = (lora_a @ lora_b).T # out_features, in_features + Ws.append(W) + + # SVD + device = "cuda" if torch.cuda.is_available() else "cpu" + W_l = torch.cat(Ws, dim=1).to(device) + U, s, Vt = torch.linalg.svd(W_l.to(torch.float64), full_matrices=False) + U = U[:, s > 1e-5].type(torch.float32) + Vt = Vt[s > 1e-5].type(torch.float32) + s = s[s > 1e-5].type(torch.float32) + UsV_dict[layer] = {"U": deepcopy(U), "s": deepcopy(s), "V": deepcopy(Vt)} + # Set all s to be the same scale + s[s <= 1e-5] = 0 + cat_hidden_dim = Vt.shape[1] // len(experts) + + basis_dict[layer] = U.cpu() + sV_concat = Vt + Vs = list(torch.split(sV_concat, cat_hidden_dim, dim=1)) + for idx, V in enumerate(Vs): + V = ( + torch.diag(s) @ V + ) # WE use Ties merging hat relies on magnitde info, which is not present in Vs only. Comment from original code base: Simple and safe for all merging methods we use. + s_model = s / s + + s_compositions_dict[idx][layer] = s_model.cpu() + V_compositions_dict[idx][layer] = V.cpu() + return basis_dict, s_compositions_dict, V_compositions_dict, UsV_dict + + @dataclass class TiesMergeConfig(LibraryTransformConfig): top_k: float = 0.2 @@ -337,6 +501,30 @@ def __init__(self, config: TiesMergeConfig = None): assert self.config.top_k > 0.0 and self.config.top_k <= 1.0 + @torch.no_grad() + def merge_param(self, TH, expert_weights): + # keep weights over the threshold + keep_mask = expert_weights.abs() >= TH + expert_weights = expert_weights * keep_mask + used = 0 + + if self.config.only_sparsify: + final_param = expert_weights.mean(0) + used += keep_mask.sum().item() + else: + # sign majority vote + sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign() + sign_per_dim = expert_weights.sum(0, keepdim=True).sign() + + # keep only weights whose sign agree with the majority + use_for_avg = expert_weights.sign() == sign_per_dim + + deno = use_for_avg.sum(0).clamp(min=1.0) + sum_param = (expert_weights * use_for_avg).sum(0) + final_param = sum_param / deno + used += (use_for_avg & (sign_per_dim != 0.0)).sum().item() + return final_param, used, expert_weights + @torch.no_grad() def transform(self, library) -> Expert: if type(library) == str: @@ -376,28 +564,12 @@ def transform(self, library) -> Expert: expert_weights = torch.stack( [expert.expert_weights[param_name] for expert in experts], dim=0 ) - - # keep weights over the threshold TH = per_exp_th.view(-1, *((1,) * (expert_weights.ndim - 1))) - keep_mask = expert_weights.abs() >= TH - expert_weights = expert_weights * keep_mask - - if self.config.only_sparsify: - final_param = expert_weights.mean(0) - used += keep_mask.sum().item() - else: - # sign majority vote - sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign() - sign_per_dim = expert_weights.sum(0, keepdim=True).sign() - - # keep only weights whose sign agree with the majority - use_for_avg = expert_weights.sign() == sign_per_dim - - deno = use_for_avg.sum(0).clamp(min=1.0) - sum_param = (expert_weights * use_for_avg).sum(0) - final_param = sum_param / deno - used += (use_for_avg & (sign_per_dim != 0.0)).sum().item() + final_param, used_per_pa, expert_weights = self.merge_param( + TH, expert_weights + ) + used += used_per_pa kept += (expert_weights.abs() > TH).sum() total += expert_weights.numel() diff --git a/tests/test_library_transforms.py b/tests/test_library_transforms.py index b45bb1afe..66ffc1287 100644 --- a/tests/test_library_transforms.py +++ b/tests/test_library_transforms.py @@ -11,13 +11,20 @@ from mttl.arguments import ExpertConfig from mttl.logging import logger from mttl.models.containers.selectors import PhatgooseSelector, PhatgooseSelectorConfig -from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig +from mttl.models.expert_model import ( + ExpertModel, + ExpertModelConfig, + MultiExpertModel, + MultiExpertModelConfig, +) from mttl.models.library.expert_library import HFExpertLibrary, LocalExpertLibrary from mttl.models.library.library_transforms import ( ArrowConfig, ArrowTransform, HiddenStateComputer, HiddenStateComputerConfig, + KnotMerge, + KnotMergeConfig, MBClusteringTransformConfig, MBCWithCosSimTransform, PhatgooseConfig, @@ -39,6 +46,51 @@ def test_config(): cfg3 = ArrowConfig(ab_only=True, scale=False) assert cfg3.save_name == cfg.save_name +def test_knot_merge(tmp_path, create_dummy_expert): + config = ExpertConfig( + **{ + "model_modifier": "lora", + "lora_rank": 8, + "lora_alpha": 16, + "warmup_steps": 0, + "modify_layers": "c_fc", + "trainable_param_names": ".*lora_[ab].*", + "output_dir": tmp_path, + "precision": "32", + "model": "EleutherAI/gpt-neo-125m", + "device_map": "cpu", + "dataset_type": "flan", + "lora_init_b_random": True, # this is important otw phatgoose gates are 0 given that the experts are not trained + } + ) + model = ExpertModel(ExpertModelConfig(base_model="EleutherAI/gpt-neo-125m")) + + config.finetune_task_name = "cot_creak" + expert1 = create_dummy_expert(config, "cot_creak") + + config.finetune_task_name = "cot_creak_ii" + expert2 = create_dummy_expert(config, "cot_creak_ii") + # only leave 1 layer to speed up things. + expert1.expert_weights = {k:v for k,v in expert1.expert_weights.items() if "8.mlp" in k} + expert2.expert_weights = {k:v for k,v in expert2.expert_weights.items() if "8.mlp" in k} + + + library = LocalExpertLibrary(tmp_path) + library.add_expert(expert1) + library.add_expert(expert2) + + transform = KnotMerge(KnotMergeConfig(path=f"{tmp_path}/knot_ingredients.pt")) + exp = transform.transform(library) + state_dict = model.model.state_dict() + + # TODO: this can be implemented as a seperate modifier maybe or utils func. + merged_layers = [] + for p_name, value in exp.expert_weights.items(): + if p_name in state_dict: + merged_layers.append(p_name) + state_dict[p_name]+=value + assert len(merged_layers) == len(exp.expert_weights.keys()) == 1 + def test_arrow(): logger.setLevel(logging.DEBUG)