fix PG #136

Merged · 12 commits · Nov 12, 2024
1 change: 1 addition & 0 deletions mttl/arguments.py
@@ -479,6 +479,7 @@ def __post_init__(self):
task_names.append(task_name)

self.finetune_task_name = ",".join(task_names)
self.finetune_task_path = None

n_devices = torch.cuda.device_count()
if n_devices > 1:
4 changes: 3 additions & 1 deletion mttl/datamodule/mt_seq_to_seq_module.py
@@ -138,7 +138,9 @@ def apply_source_template(dataset, source_template):
class FlatMultiTaskModule(DataModule):
def setup_dataset(self):
self.dataset = DatasetLibrary.pull_dataset_with_retry(self.config.dataset)
n_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16))
n_proc = min(
len(self.dataset), int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16))
)

if "split" not in self.dataset.column_names["train"]:
logger.warning(
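
The clamp keeps the number of dataset worker processes from exceeding the size of the freshly pulled dataset, which matters for tiny test datasets where spawning the default 16 processes is pointless (and datasets complains about it). A minimal sketch of the same rule; the helper name is mine, not part of mttl:

import os

def resolve_num_proc(num_items: int, default: int = 16) -> int:
    # Never request more worker processes than there are items to shard;
    # otherwise fall back to MTTL_NUM_PROC_DATASETS or the default of 16.
    return min(num_items, int(os.environ.get("MTTL_NUM_PROC_DATASETS", default)))

print(resolve_num_proc(3))        # 3  -> tiny dataset, the clamp wins
print(resolve_num_proc(100_000))  # 16 -> env var / default wins
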
5 changes: 5 additions & 0 deletions mttl/logging.py
@@ -21,6 +21,11 @@ def warn_once(msg: str, **kwargs):
logger.warning(msg, **kwargs)


@lru_cache
def debug_once(msg: str, **kwargs):
logger.debug(msg, **kwargs)


def setup_logging(log_dir: str = None):
logging.basicConfig(
format="%(asctime)s %(levelname)s --> %(message)s",
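
debug_once mirrors the existing warn_once helper: functools.lru_cache memoizes on the message, so repeated calls with the same message hit the cache and the underlying logger.debug fires only once. A small self-contained illustration of the idea:

from functools import lru_cache
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("demo")

@lru_cache
def debug_once(msg: str):
    # Identical messages are served from the cache, so the debug line below
    # is emitted only on the first call per unique message.
    logger.debug(msg)

debug_once("routing to expert A")  # logged
debug_once("routing to expert A")  # cached, not logged again
debug_once("routing to expert B")  # new message, logged
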
4 changes: 4 additions & 0 deletions mttl/models/base_model.py
@@ -177,6 +177,10 @@ def forward(
def device(self):
return self.model.device

@property
def dtype(self):
return self.model.dtype

@property
def generation_config(self):
return self.model.generation_config
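
The new dtype property is the same pass-through pattern as the existing device property: expose the wrapped model's dtype so callers (for example the autocast context in the new phatgoose training loop later in this diff) never have to reach into .model directly. A toy stand-in, not mttl's actual class:

import torch
from torch import nn

class Wrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @property
    def device(self):
        # Plain nn.Modules have no .device/.dtype attributes, so the toy reads
        # them off a parameter; HF models expose them directly.
        return next(self.model.parameters()).device

    @property
    def dtype(self):
        return next(self.model.parameters()).dtype

w = Wrapper(nn.Linear(4, 4).to(torch.bfloat16))
print(w.device, w.dtype)  # cpu torch.bfloat16
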
5 changes: 5 additions & 0 deletions mttl/models/containers/base.py
@@ -54,6 +54,10 @@ def __init__(self, config, layer, selector=None):
self._enabled = True
self.config = config
self.layer = layer
if isinstance(layer, nn.Module):
self.device = next(layer.parameters()).device
else:
self.device = None
self.selector = selector or TaskNameSelector()
self._default_expert_name = None
self.expert_infos = {}
@@ -89,6 +93,7 @@ def assign_selector(self, selector: Selector) -> None:
self.selector = selector
# dependency injection on layer name
self.selector.__layer_name__ = self.layer_name + ".selector"
self.selector.device = self.device

for expert_name, expert_info in self.expert_infos.items():
self.selector.add_expert(
2 changes: 1 addition & 1 deletion mttl/models/containers/selectors/base.py
@@ -266,6 +266,7 @@ def __init__(self, config=None, **kwargs):
self._task_to_expert_name = {}
# dependency injection filled from ExpertContainer
self.__layer_name__ = None
self.device = None

@property
def expert_names(self) -> list:
@@ -326,7 +327,6 @@ def info_container(self):

@property
def routing_infos(self):

info_container = self.info_container
if not info_container:
return None
7 changes: 4 additions & 3 deletions mttl/models/containers/selectors/phatgoose_selector.py
@@ -86,9 +86,10 @@ class PhatgooseTrainerSelectorConfig(SelectorConfig):


class SigmoidGate(nn.Module):
def __init__(self, input_dim, output_dim=1, **kwargs):
def __init__(self, input_dim, output_dim=1, device="cpu", **kwargs):
super().__init__()
self.v = nn.Parameter(torch.zeros(output_dim, input_dim))

self.v = nn.Parameter(torch.zeros(output_dim, input_dim, device=device))

def forward(self, x):
return torch.sigmoid(torch.nn.functional.linear(x, self.v, bias=None))
@@ -130,7 +131,7 @@ def forward(self, input, **kwargs) -> BatchSequenceExpertsAndWeightsSelectorOutp
def on_add_expert(
self, expert_name: str, expert_info: "ExpertInfo", is_default: bool = False
):
self.gates[expert_name] = SigmoidGate(self.input_dim)
self.gates[expert_name] = SigmoidGate(self.input_dim, device=self.device)

def get_merging_weights(self, **selector_kwargs) -> Dict:
raise ValueError(
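
Together with the ExpertContainer and Selector changes above, this threads the wrapped layer's device down to the gates, so new SigmoidGate parameters are allocated where the layer already lives instead of defaulting to CPU and needing a later transfer. A rough sketch of the flow using toy objects rather than mttl's classes:

import torch
from torch import nn

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
layer = nn.Linear(16, 16).to(device)

container_device = next(layer.parameters()).device                 # ExpertContainer.__init__
selector_device = container_device                                 # assign_selector()
gate_v = nn.Parameter(torch.zeros(1, 16, device=selector_device))  # SigmoidGate.__init__

assert gate_v.device == container_device  # gate is born on the layer's device
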
22 changes: 10 additions & 12 deletions mttl/models/get_optimizer.py
@@ -123,21 +123,19 @@ def get_optimizer_and_scheduler(model, args, num_train_examples, no_decay=None):
math.ceil(num_train_examples / global_bs) * args.num_train_epochs
)

if args.warmup_steps == -1 or args.warmup_proportion > 0.0:
logger.warning(
"Warmup proportion is set to {}, has priority over warmup_steps".format(
args.warmup_proportion
)
if args.warmup_steps == -1 or args.warmup_proportion > 0.0:
logger.warning(
"Warmup proportion is set to {}, has priority over warmup_steps".format(
args.warmup_proportion
)
)

args.warmup_steps = int(args.warmup_proportion * args.total_steps)

logger.info("Optimizer setup:")
logger.info("Total steps: {}".format(args.total_steps))
logger.info("Warmup steps: {}".format(args.warmup_steps))
logger.info("Scheduler: {}".format(args.scheduler))
args.warmup_steps = int(args.warmup_proportion * args.total_steps)

scheduler = get_scheduler(optimizer, args)
logger.info("Optimizer setup:")
logger.info("Total steps: {}".format(args.total_steps))
logger.info("Warmup steps: {}".format(args.warmup_steps))
logger.info("Scheduler: {}".format(args.scheduler))

optimizer, trainable_param_names = get_optimizer(model, args, no_decay=no_decay)
scheduler = get_scheduler(optimizer, args)
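
Reordering aside, the warmup rule in this hunk is unchanged: a positive warmup_proportion (or warmup_steps == -1) takes priority and is converted to an absolute step count before the scheduler is built. A small illustration; the helper name is mine:

def resolve_warmup_steps(warmup_steps: int, warmup_proportion: float, total_steps: int) -> int:
    # warmup_proportion has priority over warmup_steps, mirroring the warning
    # logged by get_optimizer_and_scheduler.
    if warmup_steps == -1 or warmup_proportion > 0.0:
        warmup_steps = int(warmup_proportion * total_steps)
    return warmup_steps

assert resolve_warmup_steps(-1, 0.1, 1000) == 100  # proportion converted to steps
assert resolve_warmup_steps(50, 0.0, 1000) == 50   # explicit step count kept
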
162 changes: 73 additions & 89 deletions mttl/models/library/library_transforms.py
@@ -19,7 +19,11 @@
from mttl.datamodule.base import get_datamodule
from mttl.logging import logger
from mttl.models.containers.lora_containers import ExpertContainer
from mttl.models.containers.selectors.phatgoose_selector import (
PhatgooseTrainerSelectorConfig,
)
from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig
from mttl.models.get_optimizer import get_optimizer_and_scheduler
from mttl.models.library.expert import Expert
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.lightning.callbacks import LiveCheckpointCallback
@@ -31,60 +35,49 @@
from mttl.serializable import Serializable


def train_module(args: "ExpertConfig", module: "ExpertModule", dm):
loggers = get_pl_loggers(args)
callbacks = get_monitors(args)
def train_phatgoose(args, model, datamodule):
import tqdm

monitor = "val/loss"
mode = "min"

checkpoint_callback = LiveCheckpointCallback(
dirpath=args.output_dir,
monitor=monitor,
save_last=True,
mode=mode,
)
callbacks.append(checkpoint_callback)

val_check_interval = args.eval_every
if val_check_interval == -1 or val_check_interval is None:
val_check_interval = None
else:
val_check_interval = args.gradient_accumulation_steps * args.eval_every

if val_check_interval > len(dm.train_dataloader()):
val_check_interval = len(dm.train_dataloader())

if val_check_interval > args.total_steps and args.total_steps != -1:
val_check_interval = args.total_steps

trainer = Trainer(
devices=1,
accelerator="cpu" if args.device_map == "cpu" else "gpu",
num_sanity_val_steps=0,
default_root_dir=args.output_dir,
max_epochs=args.num_train_epochs,
max_steps=args.total_steps,
gradient_clip_val=args.max_grad_norm,
strategy=args.compute_strategy,
callbacks=callbacks,
logger=loggers,
enable_checkpointing=False,
log_every_n_steps=args.gradient_accumulation_steps,
accumulate_grad_batches=args.gradient_accumulation_steps,
precision=args.precision,
val_check_interval=val_check_interval,
)

trainer.fit(module, dm)

checkpoint = (
checkpoint_callback.best_model_path or checkpoint_callback.last_model_path
(optimizer, scheduler), _ = get_optimizer_and_scheduler(
model, args, num_train_examples=len(datamodule.train_dataset)
)

# reload the best/last model from the checkpoint
module.load_from_checkpoint(checkpoint)
return checkpoint
iter_train = iter(datamodule.train_dataloader())

bar = tqdm.tqdm(range(args.total_steps))
running_loss = 0.0
for step in bar:
loss_accum = 0.0
model.train()
optimizer.zero_grad()

for micro_step in range(args.gradient_accumulation_steps):
try:
batch = next(iter_train)
except StopIteration:
iter_train = iter(datamodule.train_dataloader())
batch = next(iter_train)

with torch.autocast(
device_type=model.device.type,
dtype=model.dtype,
):
batch = transfer_batch_to_device(batch, model.device)
loss = model.forward(**batch).loss
loss = loss / args.gradient_accumulation_steps
loss_accum += loss.detach()
loss.backward()

if loss_accum:
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
running_loss += loss_accum.item()
optimizer.step()
scheduler.step()
if model.device.type == "cuda":
torch.cuda.synchronize()
bar.set_description_str(
f"Step {step + 1}/{args.total_steps}, Loss: {running_loss / (step + 1):.4f}, Lr: {scheduler.get_last_lr()[0]:.4f}"
)
return model


class LibraryTransform(abc.ABC, Registrable):
@@ -633,9 +626,9 @@ def transform(

@dataclass
class PhatgooseConfig(LibraryTransformConfig):
n_steps: int = 1000
n_steps: int = 200
learning_rate: float = 3e-3
warmup_ratio: float = 0.1 # 0.9999999 # 0.1
warmup_ratio: float = 0.1
micro_batch_size: int = 1
batch_size: int = 1

@@ -657,7 +650,7 @@ def fetch(self, library: Union[str, ExpertLibrary]):
output = library.get_auxiliary_data(data_type=self.config.save_name)

if len(output) != len(library):
logger.warn(
logger.warning(
"Found {} precomputed Phatgoose prototypes. Some experts might not have prototypes.".format(
len(output)
)
@@ -702,14 +695,8 @@ def transform(
if default_args is not None:
self._update_args(training_config, default_args)

training_config.router_selector = "phatgoose_trainer_selector"
training_config.trainable_param_names = ".*selector.*"
training_config.logging_prefix = expert_name + "/"
training_config.weight_decay = 0.0
# for training, we set this to true even if there is just a single expert.
# This ensures that we do (gate * AB * x) instead of ((gate * A) * (gate * B) * x)
training_config.lora_merge_after = True
training_config.eval_every = -1
training_config.total_steps = self.config.n_steps
training_config.learning_rate = self.config.learning_rate
training_config.warmup_proportion = self.config.warmup_ratio
@@ -730,7 +717,16 @@ def transform(

logger.info("Training config: {}".format(vars(training_config)))

model = MultiExpertModule(**vars(training_config))
model = MultiExpertModel(
MultiExpertModelConfig(
base_model=training_config.model,
selector_config=PhatgooseTrainerSelectorConfig(
lora_merge_after=True,
),
),
precision=training_config.precision,
device_map="cuda" if torch.cuda.is_available() else "cpu",
)
model.add_expert_instance(expert, is_default=True)

# for checksum
@@ -745,34 +741,22 @@ def transform(
frozen_sum += value.sum()
value.requires_grad = False

checkpoint = train_module(training_config, model, dm)

if (
training_config.compute_strategy
and training_config.compute_strategy != "deepspeed"
):
from mttl.models.lightning.expert_module import MultiExpertModule

model_after = MultiExpertModule(**vars(training_config))
model_after.add_expert_instance(expert, is_default=True)
model_after.load_state_dict(
torch.load(checkpoint, weights_only=False)["state_dict"]
)
train_phatgoose(training_config, model, dm)

# for checksum
frozen_sum_after, unfrozen_sum_after = 0, 0
for key, value in model_after.state_dict().items():
if re.match(".*selector.gates.*.v", key):
unfrozen_sum_after += value.sum()
else:
frozen_sum_after += value.sum()

assert (
frozen_sum == frozen_sum_after
), "Frozen params changed during training"
assert (
unfrozen_sum != unfrozen_sum_after
), "Unfrozen params did not change during training"
# for checksum
frozen_sum_after, unfrozen_sum_after = 0, 0
for key, value in model.state_dict().items():
if re.match(".*selector.gates.*.v", key):
unfrozen_sum_after += value.sum()
else:
frozen_sum_after += value.sum()

assert (
frozen_sum == frozen_sum_after
), "Frozen params changed during training"
assert (
unfrozen_sum != unfrozen_sum_after
), "Unfrozen params did not change during training"

# extract prototypes
prototypes = {}
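
The Lightning Trainer-based train_module is replaced by the plain PyTorch loop in train_phatgoose above, and the checksum asserts now inspect the trained model directly instead of reloading it from a checkpoint. The core of that loop is a standard gradient-accumulation step under autocast; a condensed, hypothetical distillation of the inner loop, not the exact code above:

import torch

def accumulation_step(model, optimizer, scheduler, batch_iter, accum_steps, max_grad_norm=1.0):
    # One outer step: accumulate scaled micro-batch losses, then clip, step
    # the optimizer, and advance the LR scheduler once.
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    for _ in range(accum_steps):
        batch = next(batch_iter)
        with torch.autocast(device_type=model.device.type, dtype=model.dtype):
            loss = model(**batch).loss / accum_steps  # scale so summed grads match a full batch
        loss_accum += loss.detach()
        loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()
    return float(loss_accum)
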
4 changes: 2 additions & 2 deletions mttl/models/modifiers/lora.py
@@ -7,7 +7,7 @@
import torch
from torch import nn

from mttl.logging import warn_once
from mttl.logging import debug_once, warn_once
from mttl.models.modifiers.base import MergeableModifierMixin, Modifier, ModifierConfig


@@ -392,7 +392,7 @@ def parallel_linear_weighted_forward(

if n_skills == 1:
# For Phatgoose, we have a single skill, but we still need a selector
warn_once(
debug_once(
f"You are using Skilled LoRA with only one skill. Make sure this is needed"
)

5 changes: 4 additions & 1 deletion mttl/models/utils.py
@@ -147,6 +147,7 @@ def model_loader_helper(
logger.info(f"Loading phi-2 model from {os.environ['PHI_PATH']}")
else:
model_object = None
exception = None
for klass in [AutoModelForCausalLM, AutoModelForSeq2SeqLM]:
try:
model_object = klass.from_pretrained(
@@ -158,7 +159,9 @@
torch_dtype=torch_dtype,
)
break
except:
except Exception as e:
logger.warning(f"Couldn't load {model_name}! Exception: {e}")
exception = e
continue
if model_object is None:
raise ValueError(f"Couldn't load {model_name}!")
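
The bare except around the AutoModel fallback is replaced with one that logs the failure and keeps the last exception, so a model that fails to load no longer fails silently before the final ValueError. A sketch of the pattern; chaining the saved exception into the final error is my own illustration, not necessarily how the repo surfaces it:

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

def load_any(model_name: str, **kwargs):
    model_object, exception = None, None
    for klass in [AutoModelForCausalLM, AutoModelForSeq2SeqLM]:
        try:
            model_object = klass.from_pretrained(model_name, **kwargs)
            break
        except Exception as e:
            print(f"Couldn't load {model_name} with {klass.__name__}: {e}")
            exception = e
            continue
    if model_object is None:
        raise ValueError(f"Couldn't load {model_name}!") from exception
    return model_object
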
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
transformers>=4.42.0
transformers>=4.44.2
torch>=2.3.1
datasets>=2.20.0
pytorch-lightning>=2.3.3