fix PG
Alessandro Sordoni committed Nov 12, 2024
1 parent 7d3cffb · commit aee813d
Showing 7 changed files with 91 additions and 105 deletions.
5 changes: 5 additions & 0 deletions mttl/logging.py
@@ -21,6 +21,11 @@ def warn_once(msg: str, **kwargs):
logger.warning(msg, **kwargs)


@lru_cache
def debug_once(msg: str, **kwargs):
logger.debug(msg, **kwargs)


def setup_logging(log_dir: str = None):
logging.basicConfig(
format="%(asctime)s %(levelname)s --> %(message)s",
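The new debug_once mirrors the existing warn_once: functools.lru_cache memoizes the call, so any given message is emitted at most once per process. A minimal standalone sketch of that behaviour, assuming only that the module already imports lru_cache (it must, for warn_once):

import logging
from functools import lru_cache

logger = logging.getLogger("mttl")

@lru_cache
def debug_once(msg: str, **kwargs):
    logger.debug(msg, **kwargs)

logging.basicConfig(level=logging.DEBUG)
debug_once("single-skill LoRA in use")  # logged
debug_once("single-skill LoRA in use")  # cached call, nothing is logged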
2 changes: 2 additions & 0 deletions mttl/models/containers/base.py
@@ -54,6 +54,7 @@ def __init__(self, config, layer, selector=None):
self._enabled = True
self.config = config
self.layer = layer
self.device = next(layer.parameters()).device
self.selector = selector or TaskNameSelector()
self._default_expert_name = None
self.expert_infos = {}
@@ -89,6 +90,7 @@ def assign_selector(self, selector: Selector) -> None:
self.selector = selector
# dependency injection on layer name
self.selector.__layer_name__ = self.layer_name + ".selector"
self.selector.device = self.device

for expert_name, expert_info in self.expert_infos.items():
self.selector.add_expert(
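The container now records the device of the wrapped layer and injects it into whichever selector it is assigned, so parameters the selector creates later (such as Phatgoose gates) are allocated on the same device as the layer instead of defaulting to CPU. A stripped-down sketch of the pattern, using simplified toy classes rather than the real mttl ones:

import torch
from torch import nn

class ToySelector:
    def __init__(self):
        self.device = None  # filled in by the container

    def make_gate(self, hidden_size: int) -> nn.Parameter:
        # created directly on the injected device, no .to() needed afterwards
        return nn.Parameter(torch.zeros(1, hidden_size, device=self.device))

class ToyContainer:
    def __init__(self, layer: nn.Module, selector: ToySelector):
        self.layer = layer
        self.device = next(layer.parameters()).device
        self.selector = selector
        self.selector.device = self.device  # dependency injection, as above

container = ToyContainer(nn.Linear(16, 16), ToySelector())
print(container.selector.make_gate(16).device)  # matches the layer's device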
2 changes: 1 addition & 1 deletion mttl/models/containers/selectors/base.py
@@ -266,6 +266,7 @@ def __init__(self, config=None, **kwargs):
self._task_to_expert_name = {}
# dependency injection filled from ExpertContainer
self.__layer_name__ = None
self.device = None

@property
def expert_names(self) -> list:
@@ -326,7 +327,6 @@ def info_container(self):

@property
def routing_infos(self):

info_container = self.info_container
if not info_container:
return None
7 changes: 4 additions & 3 deletions mttl/models/containers/selectors/phatgoose_selector.py
@@ -86,9 +86,10 @@ class PhatgooseTrainerSelectorConfig(SelectorConfig):


class SigmoidGate(nn.Module):
def __init__(self, input_dim, output_dim=1, **kwargs):
def __init__(self, input_dim, output_dim=1, device="cpu", **kwargs):
super().__init__()
self.v = nn.Parameter(torch.zeros(output_dim, input_dim))

self.v = nn.Parameter(torch.zeros(output_dim, input_dim, device=device))

def forward(self, x):
return torch.sigmoid(torch.nn.functional.linear(x, self.v, bias=None))
@@ -130,7 +131,7 @@ def forward(self, input, **kwargs) -> BatchSequenceExpertsAndWeightsSelectorOutp
def on_add_expert(
self, expert_name: str, expert_info: "ExpertInfo", is_default: bool = False
):
self.gates[expert_name] = SigmoidGate(self.input_dim)
self.gates[expert_name] = SigmoidGate(self.input_dim, device=self.device)

def get_merging_weights(self, **selector_kwargs) -> Dict:
raise ValueError(
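Passing device through means each gate's zero-initialised vector v is created directly on the selector's device (injected by the container) rather than on CPU and moved afterwards. A small usage sketch of the gate as defined above; the shapes are illustrative:

import torch
from torch import nn

class SigmoidGate(nn.Module):
    def __init__(self, input_dim, output_dim=1, device="cpu", **kwargs):
        super().__init__()
        self.v = nn.Parameter(torch.zeros(output_dim, input_dim, device=device))

    def forward(self, x):
        return torch.sigmoid(nn.functional.linear(x, self.v, bias=None))

gate = SigmoidGate(input_dim=8, device="cpu")
x = torch.randn(2, 5, 8)  # (batch, sequence, hidden)
print(gate(x).shape)      # torch.Size([2, 5, 1])
# zero-initialised v -> sigmoid(0) = 0.5 for every token before training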
22 changes: 10 additions & 12 deletions mttl/models/get_optimizer.py
@@ -123,21 +123,19 @@ def get_optimizer_and_scheduler(model, args, num_train_examples, no_decay=None):
math.ceil(num_train_examples / global_bs) * args.num_train_epochs
)

if args.warmup_steps == -1 or args.warmup_proportion > 0.0:
logger.warning(
"Warmup proportion is set to {}, has priority over warmup_steps".format(
args.warmup_proportion
)
if args.warmup_steps == -1 or args.warmup_proportion > 0.0:
logger.warning(
"Warmup proportion is set to {}, has priority over warmup_steps".format(
args.warmup_proportion
)
)

args.warmup_steps = int(args.warmup_proportion * args.total_steps)

logger.info("Optimizer setup:")
logger.info("Total steps: {}".format(args.total_steps))
logger.info("Warmup steps: {}".format(args.warmup_steps))
logger.info("Scheduler: {}".format(args.scheduler))
args.warmup_steps = int(args.warmup_proportion * args.total_steps)

scheduler = get_scheduler(optimizer, args)
logger.info("Optimizer setup:")
logger.info("Total steps: {}".format(args.total_steps))
logger.info("Warmup steps: {}".format(args.warmup_steps))
logger.info("Scheduler: {}".format(args.scheduler))

optimizer, trainable_param_names = get_optimizer(model, args, no_decay=no_decay)
scheduler = get_scheduler(optimizer, args)
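For reference, the warmup logic resolves as the warning says: a positive warmup_proportion takes priority over an explicit warmup_steps. A worked example with the argument names used in the hunk above:

total_steps = 1000
warmup_steps = -1        # -1 means "derive it from the proportion"
warmup_proportion = 0.1

if warmup_steps == -1 or warmup_proportion > 0.0:
    warmup_steps = int(warmup_proportion * total_steps)

print(warmup_steps)      # 100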
154 changes: 67 additions & 87 deletions mttl/models/library/library_transforms.py
@@ -19,7 +19,11 @@
from mttl.datamodule.base import get_datamodule
from mttl.logging import logger
from mttl.models.containers.lora_containers import ExpertContainer
from mttl.models.containers.selectors.phatgoose_selector import (
PhatgooseTrainerSelectorConfig,
)
from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig
from mttl.models.get_optimizer import get_optimizer_and_scheduler
from mttl.models.library.expert import Expert
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.lightning.callbacks import LiveCheckpointCallback
@@ -31,60 +35,45 @@
from mttl.serializable import Serializable


def train_module(args: "ExpertConfig", module: "ExpertModule", dm):
loggers = get_pl_loggers(args)
callbacks = get_monitors(args)
def train_phatgoose(args, model, datamodule):
import tqdm

monitor = "val/loss"
mode = "min"

checkpoint_callback = LiveCheckpointCallback(
dirpath=args.output_dir,
monitor=monitor,
save_last=True,
mode=mode,
)
callbacks.append(checkpoint_callback)

val_check_interval = args.eval_every
if val_check_interval == -1 or val_check_interval is None:
val_check_interval = None
else:
val_check_interval = args.gradient_accumulation_steps * args.eval_every

if val_check_interval > len(dm.train_dataloader()):
val_check_interval = len(dm.train_dataloader())

if val_check_interval > args.total_steps and args.total_steps != -1:
val_check_interval = args.total_steps

trainer = Trainer(
devices=1,
accelerator="cpu" if args.device_map == "cpu" else "gpu",
num_sanity_val_steps=0,
default_root_dir=args.output_dir,
max_epochs=args.num_train_epochs,
max_steps=args.total_steps,
gradient_clip_val=args.max_grad_norm,
strategy=args.compute_strategy,
callbacks=callbacks,
logger=loggers,
enable_checkpointing=False,
log_every_n_steps=args.gradient_accumulation_steps,
accumulate_grad_batches=args.gradient_accumulation_steps,
precision=args.precision,
val_check_interval=val_check_interval,
)

trainer.fit(module, dm)

checkpoint = (
checkpoint_callback.best_model_path or checkpoint_callback.last_model_path
(optimizer, scheduler), _ = get_optimizer_and_scheduler(
model, args, num_train_examples=len(datamodule.train_dataset)
)

# reload the best/last model from the checkpoint
module.load_from_checkpoint(checkpoint)
return checkpoint
iter_train = iter(datamodule.train_dataloader())

for step in tqdm.tqdm(range(args.total_steps)):
loss_accum = 0.0
model.train()
optimizer.zero_grad()

for micro_step in range(args.gradient_accumulation_steps):
try:
batch = next(iter_train)
except StopIteration:
iter_train = iter(datamodule.train_dataloader())
batch = next(iter_train)

with torch.autocast(
device_type="cuda",
dtype=torch.bfloat16,
):
batch = transfer_batch_to_device(batch, model.device)
loss = model.forward(**batch).loss
loss = loss / args.gradient_accumulation_steps
loss_accum += loss.detach()
loss.backward()

if loss_accum:
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scheduler.step()
optimizer.step()
torch.cuda.synchronize()
logger.debug(
f"Step {step}/{args.total_steps}, Loss: {loss_accum.item():.4f}"
)
return model
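train_phatgoose swaps the Lightning trainer for a plain PyTorch loop: bf16 autocast, gradient_accumulation_steps micro-batches per update, gradient clipping at 1.0, then a single optimizer/scheduler step. The core accumulate-then-step pattern in isolation (generic names, not tied to mttl; optimizer stepped before the scheduler, the usual PyTorch ordering):

import torch

def accumulation_step(model, micro_batches, optimizer, scheduler):
    # one optimizer update built from several micro-batches
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    k = len(micro_batches)
    for batch in micro_batches:
        loss = model(**batch).loss / k  # scale so the sum matches one full batch
        loss_accum += loss.detach()
        loss.backward()                 # gradients accumulate across micro-batches
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    return loss_accum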


class LibraryTransform(abc.ABC, Registrable):
@@ -633,7 +622,7 @@ def transform(

@dataclass
class PhatgooseConfig(LibraryTransformConfig):
n_steps: int = 1000
n_steps: int = 100
learning_rate: float = 3e-3
warmup_ratio: float = 0.1 # 0.9999999 # 0.1
micro_batch_size: int = 1
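With the new defaults each expert's gates get a short run: 100 steps at learning rate 3e-3, a 10% warmup, and micro-batches of size 1. A quick sketch of what the defaults imply (hypothetical instantiation, assuming only the dataclass fields above):

cfg = PhatgooseConfig()  # n_steps=100, learning_rate=3e-3, warmup_ratio=0.1
warmup_steps = int(cfg.warmup_ratio * cfg.n_steps)
print(warmup_steps)      # 10 warmup steps out of 100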
@@ -702,14 +691,8 @@ def transform(
if default_args is not None:
self._update_args(training_config, default_args)

training_config.router_selector = "phatgoose_trainer_selector"
training_config.trainable_param_names = ".*selector.*"
training_config.logging_prefix = expert_name + "/"
training_config.weight_decay = 0.0
# for training, we set this to true even if there is just a single expert.
# This ensures that we do (gate * AB * x) instead of ((gate * A) * (gate * B) * x)
training_config.lora_merge_after = True
training_config.eval_every = -1
training_config.total_steps = self.config.n_steps
training_config.learning_rate = self.config.learning_rate
training_config.warmup_proportion = self.config.warmup_ratio
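Only the selector parameters are trainable here: trainable_param_names is a regular expression, so the expert's own LoRA weights stay frozen while the gates learn. An illustrative check (the parameter names below are made up):

import re

param_names = [
    "model.layers.0.self_attn.q_proj.lora_a",                     # expert weight
    "model.layers.0.self_attn.q_proj.selector.gates.expert_a.v",  # gate
]
trainable = [n for n in param_names if re.match(".*selector.*", n)]
print(trainable)  # only the gate parameter matches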
@@ -730,7 +713,16 @@ def transform(

logger.info("Training config: {}".format(vars(training_config)))

model = MultiExpertModule(**vars(training_config))
model = MultiExpertModel(
MultiExpertModelConfig(
base_model=training_config.model,
selector_config=PhatgooseTrainerSelectorConfig(
lora_merge_after=True,
),
),
precision="bf16",
device_map="cuda",
)
model.add_expert_instance(expert, is_default=True)

# for checksum
@@ -745,34 +737,22 @@ def transform(
frozen_sum += value.sum()
value.requires_grad = False

checkpoint = train_module(training_config, model, dm)

if (
training_config.compute_strategy
and training_config.compute_strategy != "deepspeed"
):
from mttl.models.lightning.expert_module import MultiExpertModule

model_after = MultiExpertModule(**vars(training_config))
model_after.add_expert_instance(expert, is_default=True)
model_after.load_state_dict(
torch.load(checkpoint, weights_only=False)["state_dict"]
)
train_phatgoose(training_config, model, dm)

# for checksum
frozen_sum_after, unfrozen_sum_after = 0, 0
for key, value in model_after.state_dict().items():
if re.match(".*selector.gates.*.v", key):
unfrozen_sum_after += value.sum()
else:
frozen_sum_after += value.sum()

assert (
frozen_sum == frozen_sum_after
), "Frozen params changed during training"
assert (
unfrozen_sum != unfrozen_sum_after
), "Unfrozen params did not change during training"
# for checksum
frozen_sum_after, unfrozen_sum_after = 0, 0
for key, value in model.state_dict().items():
if re.match(".*selector.gates.*.v", key):
unfrozen_sum_after += value.sum()
else:
frozen_sum_after += value.sum()

assert (
frozen_sum == frozen_sum_after
), "Frozen params changed during training"
assert (
unfrozen_sum != unfrozen_sum_after
), "Unfrozen params did not change during training"

# extract prototypes
prototypes = {}
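The before/after sums are a cheap invariant check that training touched only the gates: parameters whose names match selector.gates.*.v should change, everything else should not. A standalone sketch of the same check over an arbitrary state dict (illustrative helper, not the library's API):

import re

def gate_and_frozen_sums(state_dict):
    gate_sum, frozen_sum = 0.0, 0.0
    for key, value in state_dict.items():
        if re.match(r".*selector.gates.*.v", key):
            gate_sum += value.float().sum().item()
        else:
            frozen_sum += value.float().sum().item()
    return gate_sum, frozen_sum

# compare gate_and_frozen_sums(model.state_dict()) before and after training:
# the frozen sum must be identical, the gate sum must differ.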
4 changes: 2 additions & 2 deletions mttl/models/modifiers/lora.py
@@ -7,7 +7,7 @@
import torch
from torch import nn

from mttl.logging import warn_once
from mttl.logging import debug_once, warn_once
from mttl.models.modifiers.base import MergeableModifierMixin, Modifier, ModifierConfig


@@ -392,7 +392,7 @@ def parallel_linear_weighted_forward(

if n_skills == 1:
# For Phatgoose, we have a single skill, but we still need a selector
warn_once(
debug_once(
f"You are using Skilled LoRA with only one skill. Make sure this is needed"
)

