From aee813d9a6f12a954759389767ae3a3b9b2a4806 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 11 Nov 2024 21:29:37 -0800 Subject: [PATCH 01/12] fix PG --- mttl/logging.py | 5 + mttl/models/containers/base.py | 2 + mttl/models/containers/selectors/base.py | 2 +- .../selectors/phatgoose_selector.py | 7 +- mttl/models/get_optimizer.py | 22 ++- mttl/models/library/library_transforms.py | 154 ++++++++---------- mttl/models/modifiers/lora.py | 4 +- 7 files changed, 91 insertions(+), 105 deletions(-) diff --git a/mttl/logging.py b/mttl/logging.py index 9c87ef1d9..3b80daade 100644 --- a/mttl/logging.py +++ b/mttl/logging.py @@ -21,6 +21,11 @@ def warn_once(msg: str, **kwargs): logger.warning(msg, **kwargs) +@lru_cache +def debug_once(msg: str, **kwargs): + logger.debug(msg, **kwargs) + + def setup_logging(log_dir: str = None): logging.basicConfig( format="%(asctime)s %(levelname)s --> %(message)s", diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py index c88f3124f..d64c74d88 100644 --- a/mttl/models/containers/base.py +++ b/mttl/models/containers/base.py @@ -54,6 +54,7 @@ def __init__(self, config, layer, selector=None): self._enabled = True self.config = config self.layer = layer + self.device = next(layer.parameters()).device self.selector = selector or TaskNameSelector() self._default_expert_name = None self.expert_infos = {} @@ -89,6 +90,7 @@ def assign_selector(self, selector: Selector) -> None: self.selector = selector # dependency injection on layer name self.selector.__layer_name__ = self.layer_name + ".selector" + self.selector.device = self.device for expert_name, expert_info in self.expert_infos.items(): self.selector.add_expert( diff --git a/mttl/models/containers/selectors/base.py b/mttl/models/containers/selectors/base.py index cad86183b..e09dac240 100644 --- a/mttl/models/containers/selectors/base.py +++ b/mttl/models/containers/selectors/base.py @@ -266,6 +266,7 @@ def __init__(self, config=None, **kwargs): self._task_to_expert_name = {} # dependency injection filled from ExpertContainer self.__layer_name__ = None + self.device = None @property def expert_names(self) -> list: @@ -326,7 +327,6 @@ def info_container(self): @property def routing_infos(self): - info_container = self.info_container if not info_container: return None diff --git a/mttl/models/containers/selectors/phatgoose_selector.py b/mttl/models/containers/selectors/phatgoose_selector.py index 8c91b94e6..e7d4f66ee 100644 --- a/mttl/models/containers/selectors/phatgoose_selector.py +++ b/mttl/models/containers/selectors/phatgoose_selector.py @@ -86,9 +86,10 @@ class PhatgooseTrainerSelectorConfig(SelectorConfig): class SigmoidGate(nn.Module): - def __init__(self, input_dim, output_dim=1, **kwargs): + def __init__(self, input_dim, output_dim=1, device="cpu", **kwargs): super().__init__() - self.v = nn.Parameter(torch.zeros(output_dim, input_dim)) + + self.v = nn.Parameter(torch.zeros(output_dim, input_dim, device=device)) def forward(self, x): return torch.sigmoid(torch.nn.functional.linear(x, self.v, bias=None)) @@ -130,7 +131,7 @@ def forward(self, input, **kwargs) -> BatchSequenceExpertsAndWeightsSelectorOutp def on_add_expert( self, expert_name: str, expert_info: "ExpertInfo", is_default: bool = False ): - self.gates[expert_name] = SigmoidGate(self.input_dim) + self.gates[expert_name] = SigmoidGate(self.input_dim, device=self.device) def get_merging_weights(self, **selector_kwargs) -> Dict: raise ValueError( diff --git a/mttl/models/get_optimizer.py b/mttl/models/get_optimizer.py index f058894fe..3d2393ae6 100644 --- a/mttl/models/get_optimizer.py +++ b/mttl/models/get_optimizer.py @@ -123,21 +123,19 @@ def get_optimizer_and_scheduler(model, args, num_train_examples, no_decay=None): math.ceil(num_train_examples / global_bs) * args.num_train_epochs ) - if args.warmup_steps == -1 or args.warmup_proportion > 0.0: - logger.warning( - "Warmup proportion is set to {}, has priority over warmup_steps".format( - args.warmup_proportion - ) + if args.warmup_steps == -1 or args.warmup_proportion > 0.0: + logger.warning( + "Warmup proportion is set to {}, has priority over warmup_steps".format( + args.warmup_proportion ) + ) - args.warmup_steps = int(args.warmup_proportion * args.total_steps) - - logger.info("Optimizer setup:") - logger.info("Total steps: {}".format(args.total_steps)) - logger.info("Warmup steps: {}".format(args.warmup_steps)) - logger.info("Scheduler: {}".format(args.scheduler)) + args.warmup_steps = int(args.warmup_proportion * args.total_steps) - scheduler = get_scheduler(optimizer, args) + logger.info("Optimizer setup:") + logger.info("Total steps: {}".format(args.total_steps)) + logger.info("Warmup steps: {}".format(args.warmup_steps)) + logger.info("Scheduler: {}".format(args.scheduler)) optimizer, trainable_param_names = get_optimizer(model, args, no_decay=no_decay) scheduler = get_scheduler(optimizer, args) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 0bc7332cd..782e8bec8 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -19,7 +19,11 @@ from mttl.datamodule.base import get_datamodule from mttl.logging import logger from mttl.models.containers.lora_containers import ExpertContainer +from mttl.models.containers.selectors.phatgoose_selector import ( + PhatgooseTrainerSelectorConfig, +) from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig +from mttl.models.get_optimizer import get_optimizer_and_scheduler from mttl.models.library.expert import Expert from mttl.models.library.expert_library import ExpertLibrary from mttl.models.lightning.callbacks import LiveCheckpointCallback @@ -31,60 +35,45 @@ from mttl.serializable import Serializable -def train_module(args: "ExpertConfig", module: "ExpertModule", dm): - loggers = get_pl_loggers(args) - callbacks = get_monitors(args) +def train_phatgoose(args, model, datamodule): + import tqdm - monitor = "val/loss" - mode = "min" - - checkpoint_callback = LiveCheckpointCallback( - dirpath=args.output_dir, - monitor=monitor, - save_last=True, - mode=mode, - ) - callbacks.append(checkpoint_callback) - - val_check_interval = args.eval_every - if val_check_interval == -1 or val_check_interval is None: - val_check_interval = None - else: - val_check_interval = args.gradient_accumulation_steps * args.eval_every - - if val_check_interval > len(dm.train_dataloader()): - val_check_interval = len(dm.train_dataloader()) - - if val_check_interval > args.total_steps and args.total_steps != -1: - val_check_interval = args.total_steps - - trainer = Trainer( - devices=1, - accelerator="cpu" if args.device_map == "cpu" else "gpu", - num_sanity_val_steps=0, - default_root_dir=args.output_dir, - max_epochs=args.num_train_epochs, - max_steps=args.total_steps, - gradient_clip_val=args.max_grad_norm, - strategy=args.compute_strategy, - callbacks=callbacks, - logger=loggers, - enable_checkpointing=False, - log_every_n_steps=args.gradient_accumulation_steps, - accumulate_grad_batches=args.gradient_accumulation_steps, - precision=args.precision, - val_check_interval=val_check_interval, - ) - - trainer.fit(module, dm) - - checkpoint = ( - checkpoint_callback.best_model_path or checkpoint_callback.last_model_path + (optimizer, scheduler), _ = get_optimizer_and_scheduler( + model, args, num_train_examples=len(datamodule.train_dataset) ) - - # reload the best/last model from the checkpoint - module.load_from_checkpoint(checkpoint) - return checkpoint + iter_train = iter(datamodule.train_dataloader()) + + for step in tqdm.tqdm(range(args.total_steps)): + loss_accum = 0.0 + model.train() + optimizer.zero_grad() + + for micro_step in range(args.gradient_accumulation_steps): + try: + batch = next(iter_train) + except StopIteration: + iter_train = iter(datamodule.train_dataloader()) + batch = next(iter_train) + + with torch.autocast( + device_type="cuda", + dtype=torch.bfloat16, + ): + batch = transfer_batch_to_device(batch, model.device) + loss = model.forward(**batch).loss + loss = loss / args.gradient_accumulation_steps + loss_accum += loss.detach() + loss.backward() + + if loss_accum: + norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + scheduler.step() + optimizer.step() + torch.cuda.synchronize() + logger.debug( + f"Step {step}/{args.total_steps}, Loss: {loss_accum.item():.4f}" + ) + return model class LibraryTransform(abc.ABC, Registrable): @@ -633,7 +622,7 @@ def transform( @dataclass class PhatgooseConfig(LibraryTransformConfig): - n_steps: int = 1000 + n_steps: int = 100 learning_rate: float = 3e-3 warmup_ratio: float = 0.1 # 0.9999999 # 0.1 micro_batch_size: int = 1 @@ -702,14 +691,8 @@ def transform( if default_args is not None: self._update_args(training_config, default_args) - training_config.router_selector = "phatgoose_trainer_selector" training_config.trainable_param_names = ".*selector.*" - training_config.logging_prefix = expert_name + "/" training_config.weight_decay = 0.0 - # for training, we set this to true even if there is just a single expert. - # This ensures that we do (gate * AB * x) instead of ((gate * A) * (gate * B) * x) - training_config.lora_merge_after = True - training_config.eval_every = -1 training_config.total_steps = self.config.n_steps training_config.learning_rate = self.config.learning_rate training_config.warmup_proportion = self.config.warmup_ratio @@ -730,7 +713,16 @@ def transform( logger.info("Training config: {}".format(vars(training_config))) - model = MultiExpertModule(**vars(training_config)) + model = MultiExpertModel( + MultiExpertModelConfig( + base_model=training_config.model, + selector_config=PhatgooseTrainerSelectorConfig( + lora_merge_after=True, + ), + ), + precision="bf16", + device_map="cuda", + ) model.add_expert_instance(expert, is_default=True) # for checksum @@ -745,34 +737,22 @@ def transform( frozen_sum += value.sum() value.requires_grad = False - checkpoint = train_module(training_config, model, dm) - - if ( - training_config.compute_strategy - and training_config.compute_strategy != "deepspeed" - ): - from mttl.models.lightning.expert_module import MultiExpertModule - - model_after = MultiExpertModule(**vars(training_config)) - model_after.add_expert_instance(expert, is_default=True) - model_after.load_state_dict( - torch.load(checkpoint, weights_only=False)["state_dict"] - ) + train_phatgoose(training_config, model, dm) - # for checksum - frozen_sum_after, unfrozen_sum_after = 0, 0 - for key, value in model_after.state_dict().items(): - if re.match(".*selector.gates.*.v", key): - unfrozen_sum_after += value.sum() - else: - frozen_sum_after += value.sum() - - assert ( - frozen_sum == frozen_sum_after - ), "Frozen params changed during training" - assert ( - unfrozen_sum != unfrozen_sum_after - ), "Unfrozen params did not change during training" + # for checksum + frozen_sum_after, unfrozen_sum_after = 0, 0 + for key, value in model.state_dict().items(): + if re.match(".*selector.gates.*.v", key): + unfrozen_sum_after += value.sum() + else: + frozen_sum_after += value.sum() + + assert ( + frozen_sum == frozen_sum_after + ), "Frozen params changed during training" + assert ( + unfrozen_sum != unfrozen_sum_after + ), "Unfrozen params did not change during training" # extract prototypes prototypes = {} diff --git a/mttl/models/modifiers/lora.py b/mttl/models/modifiers/lora.py index 67f3757d3..ab463ab01 100644 --- a/mttl/models/modifiers/lora.py +++ b/mttl/models/modifiers/lora.py @@ -7,7 +7,7 @@ import torch from torch import nn -from mttl.logging import warn_once +from mttl.logging import debug_once, warn_once from mttl.models.modifiers.base import MergeableModifierMixin, Modifier, ModifierConfig @@ -392,7 +392,7 @@ def parallel_linear_weighted_forward( if n_skills == 1: # For Phatgoose, we have a single skill, but we still need a selector - warn_once( + debug_once( f"You are using Skilled LoRA with only one skill. Make sure this is needed" ) From d3d2096bed3329feeb1d9336f64a999742593b99 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 11 Nov 2024 21:33:38 -0800 Subject: [PATCH 02/12] working! --- mttl/models/library/library_transforms.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 782e8bec8..d2cf51daf 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -622,15 +622,12 @@ def transform( @dataclass class PhatgooseConfig(LibraryTransformConfig): - n_steps: int = 100 + n_steps: int = 200 learning_rate: float = 3e-3 - warmup_ratio: float = 0.1 # 0.9999999 # 0.1 + warmup_ratio: float = 0.1 micro_batch_size: int = 1 batch_size: int = 1 - def __post_init__(self): - self.gradient_accumulation_steps = self.batch_size // self.micro_batch_size - @LibraryTransform.register("phatgoose", PhatgooseConfig) class PhatgooseTransform(HiddenStateComputer): @@ -699,7 +696,7 @@ def transform( training_config.train_batch_size = self.config.batch_size training_config.micro_batch_size = self.config.micro_batch_size training_config.gradient_accumulation_steps = ( - self.config.gradient_accumulation_steps + self.config.batch_size // self.config.micro_batch_size ) training_config.dataset = expert.expert_info.dataset From 0ec0b1eab0e98cc099400938fcfc7901093acf94 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 06:14:44 -0800 Subject: [PATCH 03/12] device none if layer is not module --- mttl/models/containers/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py index d64c74d88..53acf0d22 100644 --- a/mttl/models/containers/base.py +++ b/mttl/models/containers/base.py @@ -54,7 +54,10 @@ def __init__(self, config, layer, selector=None): self._enabled = True self.config = config self.layer = layer - self.device = next(layer.parameters()).device + if isinstance(layer, nn.Module): + self.device = next(layer.parameters()).device + else: + self.device = None self.selector = selector or TaskNameSelector() self._default_expert_name = None self.expert_infos = {} From bc14f13a04d6bb9d2982b65e8ebd7942d284e6db Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 08:14:56 -0800 Subject: [PATCH 04/12] utils! --- mttl/models/utils.py | 4 ++-- tests/test_library_transforms.py | 7 ++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/mttl/models/utils.py b/mttl/models/utils.py index dedf78bea..55651284f 100644 --- a/mttl/models/utils.py +++ b/mttl/models/utils.py @@ -158,10 +158,10 @@ def model_loader_helper( torch_dtype=torch_dtype, ) break - except: + except Exception as e: continue if model_object is None: - raise ValueError(f"Couldn't load {model_name}!") + raise ValueError(f"Couldn't load {model_name}! Exception: {e}") if bnb_config is not None: model_object = prepare_model_for_kbit_training(model_object) diff --git a/tests/test_library_transforms.py b/tests/test_library_transforms.py index b45bb1afe..f385027bb 100644 --- a/tests/test_library_transforms.py +++ b/tests/test_library_transforms.py @@ -121,16 +121,13 @@ def patch_expert_weights(expert, offset=0): def test_phatgoose(tiny_flan, tmp_path, create_dummy_expert, monkeypatch): - # disable wandb - monkeypatch.setenv("WANDB_MODE", "disabled") - dataset, dataset_id = tiny_flan config = ExpertConfig( **{ "model_modifier": "lora", - "lora_rank": 32, - "lora_alpha": 16, + "lora_rank": 4, + "lora_alpha": 1, "warmup_steps": 0, "modify_layers": "k_proj|v_proj|q_proj|o_proj", "trainable_param_names": ".*lora_[ab].*", From ec885ebfda72ae0460b93d968d5c66cb8433af28 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 08:33:18 -0800 Subject: [PATCH 05/12] min(n_proc , ...) --- mttl/datamodule/mt_seq_to_seq_module.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mttl/datamodule/mt_seq_to_seq_module.py b/mttl/datamodule/mt_seq_to_seq_module.py index 40c49aa3c..5dd13ef48 100644 --- a/mttl/datamodule/mt_seq_to_seq_module.py +++ b/mttl/datamodule/mt_seq_to_seq_module.py @@ -138,7 +138,9 @@ def apply_source_template(dataset, source_template): class FlatMultiTaskModule(DataModule): def setup_dataset(self): self.dataset = DatasetLibrary.pull_dataset_with_retry(self.config.dataset) - n_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16)) + n_proc = min( + len(self.dataset), int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16)) + ) if "split" not in self.dataset.column_names["train"]: logger.warning( From 706494d01db0b8520943f5b11c133e9ce74db3a1 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 09:08:52 -0800 Subject: [PATCH 06/12] pathgoose fixes --- mttl/arguments.py | 6 +++++- mttl/models/library/library_transforms.py | 20 +++++++++++++------- tests/test_library_transforms.py | 1 - 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mttl/arguments.py b/mttl/arguments.py index 736acbbd0..ae2724756 100644 --- a/mttl/arguments.py +++ b/mttl/arguments.py @@ -461,7 +461,10 @@ def __post_init__(self): self.train_batch_size = self.micro_batch_size if self.finetune_task_path is not None: - if not os.path.exists(self.finetune_task_path): + if ( + not os.path.exists(self.finetune_task_path) + and self.finetune_task_name is None + ): raise ValueError(f"Task path {self.finetune_task_path} does not exist!") # resolve task keys @@ -479,6 +482,7 @@ def __post_init__(self): task_names.append(task_name) self.finetune_task_name = ",".join(task_names) + self.finetune_task_path = None n_devices = torch.cuda.device_count() if n_devices > 1: diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index d2cf51daf..724a3ac59 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -43,7 +43,9 @@ def train_phatgoose(args, model, datamodule): ) iter_train = iter(datamodule.train_dataloader()) - for step in tqdm.tqdm(range(args.total_steps)): + bar = tqdm.tqdm(range(args.total_steps)) + running_loss = 0.0 + for step in bar: loss_accum = 0.0 model.train() optimizer.zero_grad() @@ -67,11 +69,12 @@ def train_phatgoose(args, model, datamodule): if loss_accum: norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - scheduler.step() + running_loss += loss_accum.item() optimizer.step() + scheduler.step() torch.cuda.synchronize() - logger.debug( - f"Step {step}/{args.total_steps}, Loss: {loss_accum.item():.4f}" + bar.set_description_str( + f"Step {step + 1}/{args.total_steps}, Loss: {running_loss / (step + 1):.4f}, Lr: {scheduler.get_last_lr()[0]:.4f}" ) return model @@ -628,6 +631,9 @@ class PhatgooseConfig(LibraryTransformConfig): micro_batch_size: int = 1 batch_size: int = 1 + def __post_init__(self): + self.gradient_accumulation_steps = self.batch_size // self.micro_batch_size + @LibraryTransform.register("phatgoose", PhatgooseConfig) class PhatgooseTransform(HiddenStateComputer): @@ -643,7 +649,7 @@ def fetch(self, library: Union[str, ExpertLibrary]): output = library.get_auxiliary_data(data_type=self.config.save_name) if len(output) != len(library): - logger.warn( + logger.warning( "Found {} precomputed Phatgoose prototypes. Some experts might not have prototypes.".format( len(output) ) @@ -696,7 +702,7 @@ def transform( training_config.train_batch_size = self.config.batch_size training_config.micro_batch_size = self.config.micro_batch_size training_config.gradient_accumulation_steps = ( - self.config.batch_size // self.config.micro_batch_size + self.config.gradient_accumulation_steps ) training_config.dataset = expert.expert_info.dataset @@ -717,7 +723,7 @@ def transform( lora_merge_after=True, ), ), - precision="bf16", + precision=training_config.precision, device_map="cuda", ) model.add_expert_instance(expert, is_default=True) diff --git a/tests/test_library_transforms.py b/tests/test_library_transforms.py index f385027bb..842524875 100644 --- a/tests/test_library_transforms.py +++ b/tests/test_library_transforms.py @@ -153,7 +153,6 @@ def test_phatgoose(tiny_flan, tmp_path, create_dummy_expert, monkeypatch): pg_config = PhatgooseConfig(n_steps=1, warmup_ratio=0.0, learning_rate=1e-2) phatgoose = PhatgooseTransform(pg_config) - phatgoose.transform(library, persist=True, recompute=True, default_args=config) # now try to load a selector with the same config From 895b1cb1b3647be175bf1f75ef5de10fa2c89966 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 09:49:35 -0800 Subject: [PATCH 07/12] update reqs --- mttl/arguments.py | 5 +---- mttl/models/utils.py | 4 +++- requirements.txt | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mttl/arguments.py b/mttl/arguments.py index ae2724756..203407e50 100644 --- a/mttl/arguments.py +++ b/mttl/arguments.py @@ -461,10 +461,7 @@ def __post_init__(self): self.train_batch_size = self.micro_batch_size if self.finetune_task_path is not None: - if ( - not os.path.exists(self.finetune_task_path) - and self.finetune_task_name is None - ): + if not os.path.exists(self.finetune_task_path): raise ValueError(f"Task path {self.finetune_task_path} does not exist!") # resolve task keys diff --git a/mttl/models/utils.py b/mttl/models/utils.py index 55651284f..01e294603 100644 --- a/mttl/models/utils.py +++ b/mttl/models/utils.py @@ -147,6 +147,7 @@ def model_loader_helper( logger.info(f"Loading phi-2 model from {os.environ['PHI_PATH']}") else: model_object = None + exception = None for klass in [AutoModelForCausalLM, AutoModelForSeq2SeqLM]: try: model_object = klass.from_pretrained( @@ -159,9 +160,10 @@ def model_loader_helper( ) break except Exception as e: + exception = e continue if model_object is None: - raise ValueError(f"Couldn't load {model_name}! Exception: {e}") + raise ValueError(f"Couldn't load {model_name}! Exception: {exception}") if bnb_config is not None: model_object = prepare_model_for_kbit_training(model_object) diff --git a/requirements.txt b/requirements.txt index 4f462e05e..e69ba5ae5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers>=4.42.0 +transformers>=4.44.2 torch>=2.3.1 datasets>=2.20.0 pytorch-lightning>=2.3.3 From ce9b738e11277f173dc4436409252c5688ee369c Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 10:55:19 -0800 Subject: [PATCH 08/12] only use cuda if it is available --- mttl/models/library/library_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 724a3ac59..ba778e327 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -724,7 +724,7 @@ def transform( ), ), precision=training_config.precision, - device_map="cuda", + device_map="cuda" if torch.cuda.is_available() else "cpu", ) model.add_expert_instance(expert, is_default=True) From 4e23212147b26912a9e45b5b4b250d8c3697c04f Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 10:58:26 -0800 Subject: [PATCH 09/12] cpu support --- mttl/models/library/library_transforms.py | 8 +++++++- mttl/models/utils.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index ba778e327..18cdc90c7 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -42,6 +42,12 @@ def train_phatgoose(args, model, datamodule): model, args, num_train_examples=len(datamodule.train_dataset) ) iter_train = iter(datamodule.train_dataloader()) + if args.precision == "bf16": + dtype = torch.bfloat16 + elif args.precision == "16": + dtype = torch.float16 + else: + dtype = torch.float32 bar = tqdm.tqdm(range(args.total_steps)) running_loss = 0.0 @@ -59,7 +65,7 @@ def train_phatgoose(args, model, datamodule): with torch.autocast( device_type="cuda", - dtype=torch.bfloat16, + dtype=dtype, ): batch = transfer_batch_to_device(batch, model.device) loss = model.forward(**batch).loss diff --git a/mttl/models/utils.py b/mttl/models/utils.py index 01e294603..bc3cdf8cc 100644 --- a/mttl/models/utils.py +++ b/mttl/models/utils.py @@ -160,10 +160,11 @@ def model_loader_helper( ) break except Exception as e: + logger.warning(f"Couldn't load {model_name}! Exception: {e}") exception = e continue if model_object is None: - raise ValueError(f"Couldn't load {model_name}! Exception: {exception}") + raise ValueError(f"Couldn't load {model_name}!") if bnb_config is not None: model_object = prepare_model_for_kbit_training(model_object) From f8e891fed659dc9a942d195cc044080b754d75a6 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 11:13:11 -0800 Subject: [PATCH 10/12] only use cuda if it is available --- mttl/models/library/library_transforms.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 18cdc90c7..93cf69a86 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -42,12 +42,6 @@ def train_phatgoose(args, model, datamodule): model, args, num_train_examples=len(datamodule.train_dataset) ) iter_train = iter(datamodule.train_dataloader()) - if args.precision == "bf16": - dtype = torch.bfloat16 - elif args.precision == "16": - dtype = torch.float16 - else: - dtype = torch.float32 bar = tqdm.tqdm(range(args.total_steps)) running_loss = 0.0 @@ -64,8 +58,8 @@ def train_phatgoose(args, model, datamodule): batch = next(iter_train) with torch.autocast( - device_type="cuda", - dtype=dtype, + device_type=model.device.type, + dtype=model.dtype, ): batch = transfer_batch_to_device(batch, model.device) loss = model.forward(**batch).loss From 4fbcc78abfefd8d3eec876560fbc78555ed004bf Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 11:36:14 -0800 Subject: [PATCH 11/12] surface dtype --- mttl/models/base_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mttl/models/base_model.py b/mttl/models/base_model.py index a0d37d346..2d3253571 100644 --- a/mttl/models/base_model.py +++ b/mttl/models/base_model.py @@ -177,6 +177,10 @@ def forward( def device(self): return self.model.device + @property + def dtype(self): + return self.model.dtype + @property def generation_config(self): return self.model.generation_config From f1bf6af34debfb4f6d9b426578c28f5d04215643 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Tue, 12 Nov 2024 11:53:04 -0800 Subject: [PATCH 12/12] only use cuda if it is available --- mttl/models/library/library_transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 93cf69a86..b45a0e885 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -72,7 +72,8 @@ def train_phatgoose(args, model, datamodule): running_loss += loss_accum.item() optimizer.step() scheduler.step() - torch.cuda.synchronize() + if model.device.type == "cuda": + torch.cuda.synchronize() bar.set_description_str( f"Step {step + 1}/{args.total_steps}, Loss: {running_loss / (step + 1):.4f}, Lr: {scheduler.get_last_lr()[0]:.4f}" )