diff --git a/examples/create_arrow_model.py b/examples/create_arrow_model.py
index 5645d85a..4b0839d7 100644
--- a/examples/create_arrow_model.py
+++ b/examples/create_arrow_model.py
@@ -5,7 +5,7 @@
 from mttl.models.containers.selectors import ArrowSelectorConfig
 from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig
 from mttl.models.library.expert_library import ExpertLibrary
-from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
+from mttl.models.library.library_transforms import ArrowTransform, ArrowTransformConfig
 
 
 @click.command()
@@ -30,7 +30,7 @@ def make_arrow(experts, push_to_hub):
         library.add_expert_from_ckpt(path)
 
     # compute arrow prototypes and store them in the library
-    arrow_config = ArrowConfig()
+    arrow_config = ArrowTransformConfig()
     transform = ArrowTransform(arrow_config)
     transform.transform(library, persist=True)
 
diff --git a/mttl/logging.py b/mttl/logging.py
index e539c76b..eedc35b3 100644
--- a/mttl/logging.py
+++ b/mttl/logging.py
@@ -10,7 +10,6 @@
 
 # warning if logger is not initialized
 logger = logging.getLogger("mttl")
-logger.setLevel(logging.WARNING)
 logging.getLogger("datasets.arrow_dataset").setLevel(logging.CRITICAL + 1)
 
 
diff --git a/mttl/models/containers/selectors/arrow_selector.py b/mttl/models/containers/selectors/arrow_selector.py
index f5a10922..fbfdb387 100644
--- a/mttl/models/containers/selectors/arrow_selector.py
+++ b/mttl/models/containers/selectors/arrow_selector.py
@@ -16,9 +16,12 @@ def compute_arrow_embeddings(
     add_base_proto=False,
     recompute_prototypes=False,
 ) -> str:
-    from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
+    from mttl.models.library.library_transforms import (
+        ArrowTransform,
+        ArrowTransformConfig,
+    )
 
-    cfg = ArrowConfig(
+    cfg = ArrowTransformConfig(
         name=selector_data_id,
         ab_only=ab_only,
         tie_params=tie_params or "default",
@@ -48,8 +51,6 @@ class ArrowSelector(PerTokenSelector):
     @artifacts_cache
     def load_from_library(cls, config):
         """Fetches prototypes from the library."""
-        from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
+        from mttl.models.library.library_transforms import ArrowTransform
 
-        return ArrowTransform(ArrowConfig(name=config.selector_data_id)).fetch(
-            config.library_id
-        )
+        return ArrowTransform.fetch(config.library_id, config.selector_data_id)
diff --git a/mttl/models/containers/selectors/average_activation_selector.py b/mttl/models/containers/selectors/average_activation_selector.py
index 21644aac..1eb256f7 100644
--- a/mttl/models/containers/selectors/average_activation_selector.py
+++ b/mttl/models/containers/selectors/average_activation_selector.py
@@ -53,11 +53,9 @@ class AverageActivationSelector(PerTokenSelector):
     @artifacts_cache
     def load_from_library(cls, config):
         """Fetches prototypes from the library."""
-        from mttl.models.library.library_transforms import (
-            HiddenStateComputer,
-            HiddenStateComputerConfig,
-        )
+        from mttl.models.library.library_transforms import HiddenStateComputer
 
-        return HiddenStateComputer(
-            HiddenStateComputerConfig(name=config.selector_data_id)
-        ).fetch(config.library_id)
+        return HiddenStateComputer.fetch(
+            config.library_id,
+            config.selector_data_id,
+        )
diff --git a/mttl/models/containers/selectors/phatgoose_selector.py b/mttl/models/containers/selectors/phatgoose_selector.py
index e7d4f66e..944b1efd 100644
--- a/mttl/models/containers/selectors/phatgoose_selector.py
+++ b/mttl/models/containers/selectors/phatgoose_selector.py
@@ -28,11 +28,11 @@ def compute_phatgoose_embeddings(
 ) -> str:
     """Computes Phatgoose embeddings for the given library."""
     from mttl.models.library.library_transforms import (
-        PhatgooseConfig,
         PhatgooseTransform,
+        PhatgooseTransformConfig,
     )
 
-    cfg = PhatgooseConfig(
+    cfg = PhatgooseTransformConfig(
         n_steps=n_steps_pg,
         learning_rate=learning_rate_pg,
         name=selector_data_id,
@@ -70,14 +70,9 @@ def __init__(self, config, **kwargs) -> None:
     @artifacts_cache
     def load_from_library(cls, config):
         """Fetches prototypes from the library."""
-        from mttl.models.library.library_transforms import (
-            PhatgooseConfig,
-            PhatgooseTransform,
-        )
+        from mttl.models.library.library_transforms import PhatgooseTransform
 
-        return PhatgooseTransform(PhatgooseConfig(name=config.selector_data_id)).fetch(
-            config.library_id
-        )
+        return PhatgooseTransform.fetch(config.library_id, config.selector_data_id)
 
 
 @dataclass
diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py
index fcf4428d..ec9cc868 100644
--- a/mttl/models/expert_model.py
+++ b/mttl/models/expert_model.py
@@ -571,7 +571,7 @@ def from_pretrained_library(
 
         if isinstance(selector_config, LoadableSelectorConfig):
             selector_config.library_id = repo_id
-        elif isinstance(selector_config, dict):
+        elif isinstance(selector_config, MultiSelectorConfig):
             for modifier_name, cfg in selector_config.items():
                 # inject the library id if it is None
                 if (
diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py
index b45a0e88..02508c70 100644
--- a/mttl/models/library/library_transforms.py
+++ b/mttl/models/library/library_transforms.py
@@ -36,8 +36,13 @@
 
 
 def train_phatgoose(args, model, datamodule):
+    """Mini-training loop for phatgoose."""
     import tqdm
 
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(args.seed)
+
     (optimizer, scheduler), _ = get_optimizer_and_scheduler(
         model, args, num_train_examples=len(datamodule.train_dataset)
     )
@@ -159,13 +164,16 @@ def __init__(self, config, random_state=None):
         super().__init__(config)
         self.random_state = random_state
 
+    @classmethod
     @torch.no_grad()
-    def fetch(self, library: Union[str, ExpertLibrary]):
+    def fetch(cls, library: Union[str, ExpertLibrary], config_hash: str = None):
         if isinstance(library, str):
             library = ExpertLibrary.get_expert_library(library)
 
+        config_hash = config_hash or SVDEmbeddingTransformConfig().save_name
+
         # try to fetch auxiliary data
-        output = library.get_auxiliary_data(data_type=self.config.save_name)
+        output = library.get_auxiliary_data(data_type=config_hash)
 
         if len(output) == len(library):
             logger.info("Found {} precomputed SVD Embeddings".format(len(output)))
@@ -180,7 +188,7 @@ def transform(self, library, persist=True, recompute=False):
             library = ExpertLibrary.get_expert_library(library)
 
         try:
-            output = self.fetch(library)
+            output = self.fetch(library, self.config.save_name)
 
             if not recompute:
                 logger.info("Found {} precomputed SVD Embeddings".format(len(output)))
@@ -485,13 +493,16 @@ def _retrieve_hidden_states(self, model):
 
         return {k: v for k, v in zip(keys, values)}
 
+    @classmethod
     @torch.no_grad()
-    def fetch(self, library: Union[str, ExpertLibrary]):
+    def fetch(cls, library: Union[str, ExpertLibrary], config_hash: str = None):
         if isinstance(library, str):
             library = ExpertLibrary.get_expert_library(library)
 
+        config_hash = config_hash or HiddenStateComputerConfig().save_name
+
         # try to fetch auxiliary data
-        output = library.get_auxiliary_data(data_type=self.config.save_name)
+        output = library.get_auxiliary_data(data_type=config_hash)
 
         if len(output) > 0:
             logger.info("Found {} precomputed centroids".format(len(output)))
@@ -517,7 +528,7 @@ def transform(
             library = ExpertLibrary.get_expert_library(library)
 
         try:
-            protos = self.fetch(library)
+            protos = self.fetch(library, self.config.save_name)
 
             if not recompute:
                 logger.info("Found {} precomputed centroids".format(len(protos)))
@@ -625,29 +636,30 @@ def transform(
 
 
 @dataclass
-class PhatgooseConfig(LibraryTransformConfig):
-    n_steps: int = 200
-    learning_rate: float = 3e-3
+class PhatgooseTransformConfig(LibraryTransformConfig):
+    n_steps: int = 100
+    learning_rate: float = 1e-3
     warmup_ratio: float = 0.1
-    micro_batch_size: int = 1
-    batch_size: int = 1
-
-    def __post_init__(self):
-        self.gradient_accumulation_steps = self.batch_size // self.micro_batch_size
+    micro_batch_size: int = 4
+    batch_size: int = 4
+    seed: int = 42
 
 
-@LibraryTransform.register("phatgoose", PhatgooseConfig)
+@LibraryTransform.register("phatgoose", PhatgooseTransformConfig)
 class PhatgooseTransform(HiddenStateComputer):
-    def __init__(self, config: PhatgooseConfig = None):
-        super().__init__(config or PhatgooseConfig())
+    def __init__(self, config: PhatgooseTransformConfig = None):
+        super().__init__(config or PhatgooseTransformConfig())
 
+    @classmethod
     @torch.no_grad()
-    def fetch(self, library: Union[str, ExpertLibrary]):
+    def fetch(cls, library: Union[str, ExpertLibrary], config_hash: str = None):
         if isinstance(library, str):
             library = ExpertLibrary.get_expert_library(library)
 
+        config_hash = config_hash or PhatgooseTransformConfig().save_name
+
         # try to fetch auxiliary data
-        output = library.get_auxiliary_data(data_type=self.config.save_name)
+        output = library.get_auxiliary_data(data_type=config_hash)
 
         if len(output) != len(library):
             logger.warning(
@@ -702,9 +714,6 @@ def transform(
         training_config.warmup_proportion = self.config.warmup_ratio
         training_config.train_batch_size = self.config.batch_size
         training_config.micro_batch_size = self.config.micro_batch_size
-        training_config.gradient_accumulation_steps = (
-            self.config.gradient_accumulation_steps
-        )
         training_config.dataset = expert.expert_info.dataset
 
         if expert.expert_info.expert_task_name:
@@ -787,28 +796,23 @@ def transform(
 
 
 @dataclass
-class ArrowConfig(LibraryTransformConfig):
+class ArrowTransformConfig(LibraryTransformConfig):
     ab_only: bool = True
     scale: bool = False  # If True, scale by eigenvalue
     tie_params: str = (
         "default"  # If default, ties the same params as during training. If a regex, processed the same way as during training
     )
     tie_op: str = "concat"  # or "sum"
-    add_base_proto: bool = False
-
-    def param_hash(self):
-        # for convenience, we exclude the add_base_proto field as it was added later
-        return param_hash(self, exclude_fields=["add_base_proto"])
 
 
-@LibraryTransform.register("arrow", ArrowConfig)
+@LibraryTransform.register("arrow", ArrowTransformConfig)
 class ArrowTransform(LibraryTransform):
     """
     Given a library of experts, extract the input direction most affected by the linear transforms
     """
 
-    def __init__(self, config: ArrowConfig = None):
-        super().__init__(config or ArrowConfig())
+    def __init__(self, config: ArrowTransformConfig = None):
+        super().__init__(config or ArrowTransformConfig())
 
     def _maybe_scale(self, vectors, eigvals):
         """
@@ -850,57 +854,6 @@ def _low_rank_svd(self, A, B):
 
         return U_W, Sigma_C, V_W_T
 
-    def _compute_base_proto(self, library, base_model=None, persist=True):
-        """Compute Arrow prototypes for base model weights"""
-
-        try:
-            base_vector = library.get_auxiliary_data(
-                data_type="vectors", expert_name="base_model"
-            )
-            base_eigval = library.get_auxiliary_data(
-                data_type="eigvals", expert_name="base_model"
-            )
-        except ValueError:
-            # `get_auxiliary_data` will throw a ValueError if the object is not found.
-            base_vector = base_eigval = {}
-
-        if len(base_vector) == len(base_eigval) > 0:
-            # TODO: should we perform some checks to see if the keys lineup
-            return base_vector, base_eigval
-
-        if base_model is None:
-            from mttl.models.lightning.expert_module import MultiExpertModule
-
-            an_expert = library[next(iter(library.keys()))]
-            training_config = an_expert.training_config
-            training_config.model_modifier = None
-            base_model = MultiExpertModule(**vars(training_config))
-
-        vectors, eigvals = {}, {}
-        for key, base_W in base_model.named_parameters():
-            if base_W.ndim != 2:  # Only compute base model for matrices
-                continue
-
-            logger.info(f"\tComputing SVD for base model parameter {key}")
-
-            base_W = base_W.float()
-            U, E, Vt = torch.linalg.svd(base_W)
-            vectors[key] = Vt[0].cpu().numpy()
-            eigvals[key] = E[0].item()
-
-        # Always persist
-        if persist:
-            for data_name, data in [("vectors", vectors), ("eigvals", eigvals)]:
-                library.add_auxiliary_data(
-                    data_type=data_name,
-                    expert_name="base_model",
-                    data=data,
-                    config=None,
-                    force=True,  # make sure we overwrite
-                )
-
-        return vectors, eigvals
-
     def _get_unique_parent_names(self, alist):
         """
         if adict.keys() = ['model.layer1.lora_a', 'model.layer.lora_b', 'model.layer2.lora_a']
@@ -909,8 +862,9 @@ def _get_unique_parent_names(self, alist):
         dict_keys = sorted(list(set(".".join(k.split(".")[:-1]) for k in alist)))
         return dict_keys
 
+    @classmethod
     @torch.no_grad()
-    def fetch(self, library: Union[str, ExpertLibrary], scale=True):
+    def fetch(cls, library: Union[str, ExpertLibrary], config_hash: str = None):
         """Fetch arrow prototypes from the library, raises ValueError if they are not computed.
 
         Args:
@@ -920,16 +874,11 @@
         if not isinstance(library, ExpertLibrary):
             library = ExpertLibrary.get_expert_library(library)
 
+        config_hash = config_hash or ArrowTransformConfig().save_name
+
         # try to fetch auxiliary data
-        vectors = library.get_auxiliary_data(
-            data_type=self.config.save_name + "_vectors"
-        )
-        eigvals = library.get_auxiliary_data(
-            data_type=self.config.save_name + "_eigvals"
-        )
-        if scale:
-            return self._maybe_scale(vectors, eigvals)
-        return vectors, eigvals
+        protos = library.get_auxiliary_data(data_type=config_hash + "_protos")
+        return protos
 
     @torch.no_grad()
     def transform(
@@ -943,14 +892,16 @@
         if isinstance(library, str):
             library = ExpertLibrary.get_expert_library(library)
 
-        add_base_proto = self.config.add_base_proto
         base_model = None
 
-        vectors, eigvals = self.fetch(library, scale=False)
+        # Try to fetch the precomputed Arrow prototypes
+        protos = self.fetch(library, self.config.save_name)
         already_computed = []
+        vectors = {}
+        eigvals = {}
         for expert_name, expert in library.items():
-            if expert_name in vectors and not recompute:
+            if expert_name in protos and not recompute:
                 logger.info(
                     "Found precomputed Arrow prototypes for expert {}".format(
                         expert_name
                     )
@@ -1098,6 +1049,8 @@
                     eigvals[expert_name][parent] = top_value.item()
 
         to_upload = [x for x in library.keys() if x not in already_computed]
+        new_protos = self._maybe_scale(vectors, eigvals)
+
         if persist and len(to_upload) > 0:
             # add embeddings to the library
             with library.batched_commit():
@@ -1108,6 +1061,7 @@
                 for data_name, data in [
                     ("vectors", vectors),
                     ("eigvals", eigvals),
+                    ("protos", new_protos),
                 ]:
                     library.add_auxiliary_data(
                         data_type=self.config.save_name + "_" + data_name,
@@ -1117,14 +1071,8 @@
                         force=True,  # make sure we overwrite
                     )
 
-        if add_base_proto:
-            base_vec, base_val = self._compute_base_proto(
-                library, base_model=base_model, persist=persist
-            )
-            vectors.update({"base_model": base_vec})
-            eigvals.update({"base_model": base_val})
-
-        return self._maybe_scale(vectors, eigvals)
+        protos.update(new_protos)
+        return protos
 
 
 @dataclass
diff --git a/projects/modular_llm/train_phatgoose_selector.py b/projects/modular_llm/train_phatgoose_selector.py
index 5ac9817d..1229ccfb 100644
--- a/projects/modular_llm/train_phatgoose_selector.py
+++ b/projects/modular_llm/train_phatgoose_selector.py
@@ -16,8 +16,8 @@ def train_with_transform(args: EvaluationConfig):
     seed_everything(args.seed, workers=True)
 
     from mttl.models.library.library_transforms import (
-        PhatgooseConfig,
         PhatgooseTransform,
+        PhatgooseTransformConfig,
     )
 
     library_id, expert_names = parse_libname(args.library_id)
     library = ExpertLibrary.get_expert_library(
         library_id,
         destination_id=args.destination_library_id,
     )
@@ -27,7 +27,9 @@ def train_with_transform(args: EvaluationConfig):
     phagoose_transform = PhatgooseTransform(
-        PhatgooseConfig(n_steps=args.n_steps_pg, learning_rate=args.learning_rate_pg)
+        PhatgooseTransformConfig(
+            n_steps=args.n_steps_pg, learning_rate=args.learning_rate_pg
+        )
     )
     embeddings = phagoose_transform.transform(
         library, expert_names=expert_names, default_args=args, recompute=True
diff --git a/tests/test_expert_model.py b/tests/test_expert_model.py
index d55f4937..2d0dd529 100644
--- a/tests/test_expert_model.py
+++ b/tests/test_expert_model.py
@@ -25,7 +25,7 @@
     MultiExpertModelConfig,
 )
 from mttl.models.library.expert import Expert
-from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
+from mttl.models.library.library_transforms import ArrowTransform, ArrowTransformConfig
 from mttl.models.modifiers.lora import LoRAConfig, SkilledLoRAConfig
 
 
@@ -184,7 +184,7 @@ def test_from_pretrained_with_arrow_save_and_reload(tmp_path):
     library = model.save_to_library(f"local://{tmp_path}")
 
     # store arrow experts
-    protos = ArrowTransform(ArrowConfig()).transform(library, persist=True)
+    protos = ArrowTransform(ArrowTransformConfig()).transform(library, persist=True)
 
     # from pretrained library
     selector_config = ArrowSelectorConfig(top_k=4)
@@ -232,7 +232,7 @@ def test_from_pretrained_with_arrow(tmp_path):
     library = model.save_to_library(f"local://{tmp_path}")
 
     # store arrow experts
-    protos = ArrowTransform(ArrowConfig()).transform(library, persist=True)
+    protos = ArrowTransform(ArrowTransformConfig()).transform(library, persist=True)
 
     # from pretrained library
     selector_config = ArrowSelectorConfig(top_k=4)
diff --git a/tests/test_library_transforms.py b/tests/test_library_transforms.py
index 84252487..2536b077 100644
--- a/tests/test_library_transforms.py
+++ b/tests/test_library_transforms.py
@@ -14,14 +14,14 @@
 from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig
 from mttl.models.library.expert_library import HFExpertLibrary, LocalExpertLibrary
 from mttl.models.library.library_transforms import (
-    ArrowConfig,
     ArrowTransform,
+    ArrowTransformConfig,
     HiddenStateComputer,
     HiddenStateComputerConfig,
     MBClusteringTransformConfig,
     MBCWithCosSimTransform,
-    PhatgooseConfig,
     PhatgooseTransform,
+    PhatgooseTransformConfig,
     TiesMerge,
     TiesMergeConfig,
     WeightedLinearMerge,
@@ -30,13 +30,11 @@
 
 
 def test_config():
-    cfg = ArrowConfig(ab_only=True, scale=False)
-    assert cfg.save_name == "arrowconfig-a8327e21d374166ceeb94c40d2e7676f"
-
-    cfg2 = ArrowConfig(ab_only=True, scale=True)
+    cfg = ArrowTransformConfig(ab_only=True, scale=False)
+    cfg2 = ArrowTransformConfig(ab_only=True, scale=True)
     assert cfg2.save_name != cfg.save_name
 
-    cfg3 = ArrowConfig(ab_only=True, scale=False)
+    cfg3 = ArrowTransformConfig(ab_only=True, scale=False)
     assert cfg3.save_name == cfg.save_name
 
 
@@ -45,7 +43,7 @@ def test_arrow():
 
     library = HFExpertLibrary("sordonia/new-test-library")
 
-    cfg = ArrowConfig(ab_only=True, scale=False)
+    cfg = ArrowTransformConfig(ab_only=True, scale=False)
     transform = ArrowTransform(cfg)
 
     protos = transform.transform(library, persist=False, recompute=True)
@@ -106,7 +104,7 @@ def patch_expert_weights(expert, offset=0):
     library.add_expert(expert1, expert1.name)
     library.add_expert(expert2, expert2.name)
 
-    cfg = ArrowConfig(ab_only=True, scale=False)
+    cfg = ArrowTransformConfig(ab_only=True, scale=False)
     transform = ArrowTransform(cfg)
 
     protos = transform.transform(library, persist=False, recompute=True)
@@ -151,7 +149,9 @@ def test_phatgoose(tiny_flan, tmp_path, create_dummy_expert, monkeypatch):
     library.add_expert(expert1)
    library.add_expert(expert2)
 
-    pg_config = PhatgooseConfig(n_steps=1, warmup_ratio=0.0, learning_rate=1e-2)
+    pg_config = PhatgooseTransformConfig(
+        n_steps=1, warmup_ratio=0.0, learning_rate=1e-2
+    )
     phatgoose = PhatgooseTransform(pg_config)
 
     phatgoose.transform(library, persist=True, recompute=True, default_args=config)
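--
Reviewer note (not part of the patch): a minimal usage sketch of the renamed Arrow API after this change. The library path below is a placeholder; everything else (ArrowTransformConfig, ArrowTransform.transform, and the new classmethod ArrowTransform.fetch with its fallback to ArrowTransformConfig().save_name) comes from the diff above.

    from mttl.models.library.expert_library import ExpertLibrary
    from mttl.models.library.library_transforms import (
        ArrowTransform,
        ArrowTransformConfig,
    )

    # open (or create) a local expert library -- placeholder path
    library = ExpertLibrary.get_expert_library("local:///tmp/my_expert_library")

    # compute Arrow prototypes; persist=True stores them in the library as
    # auxiliary data under "<save_name>_protos" (alongside raw vectors/eigvals)
    config = ArrowTransformConfig(ab_only=True, scale=False)
    protos = ArrowTransform(config).transform(library, persist=True)

    # later, fetch the persisted (already scaled) prototypes without re-running
    # the transform; config_hash defaults to ArrowTransformConfig().save_name
    protos = ArrowTransform.fetch(library, config.save_name)

PhatgooseTransform.fetch and HiddenStateComputer.fetch follow the same classmethod pattern, which is what lets the selectors' load_from_library hooks fetch prototypes without instantiating a transform.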