
Commit

pathgoose fixes
Alessandro Sordoni committed Nov 12, 2024
1 parent ec885eb commit 706494d
Showing 3 changed files with 18 additions and 9 deletions.
mttl/arguments.py: 6 changes (5 additions & 1 deletion)

@@ -461,7 +461,10 @@ def __post_init__(self):
             self.train_batch_size = self.micro_batch_size
 
         if self.finetune_task_path is not None:
-            if not os.path.exists(self.finetune_task_path):
+            if (
+                not os.path.exists(self.finetune_task_path)
+                and self.finetune_task_name is None
+            ):
                 raise ValueError(f"Task path {self.finetune_task_path} does not exist!")
 
             # resolve task keys

@@ -479,6 +482,7 @@ def __post_init__(self):
                 task_names.append(task_name)
 
             self.finetune_task_name = ",".join(task_names)
+            self.finetune_task_path = None
 
         n_devices = torch.cuda.device_count()
         if n_devices > 1:

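For reference, a minimal sketch of the task-resolution behavior this file now implements: a missing task path only raises when no explicit task name was given, and the path is cleared once task names have been resolved. The resolve_tasks helper and the one-task-name-per-line file format below are assumptions for illustration, not code from mttl.

import os


def resolve_tasks(finetune_task_path, finetune_task_name):
    # Hypothetical helper mirroring the new guard: a missing path is only an
    # error when no task name was supplied explicitly.
    if finetune_task_path is not None:
        if not os.path.exists(finetune_task_path) and finetune_task_name is None:
            raise ValueError(f"Task path {finetune_task_path} does not exist!")
        if os.path.exists(finetune_task_path):
            with open(finetune_task_path) as f:
                task_names = [line.strip() for line in f if line.strip()]
            finetune_task_name = ",".join(task_names)
        # The path has been consumed; downstream code should only see task names.
        finetune_task_path = None
    return finetune_task_name, finetune_task_path
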
mttl/models/library/library_transforms.py: 20 changes (13 additions & 7 deletions)

@@ -43,7 +43,9 @@ def train_phatgoose(args, model, datamodule):
     )
     iter_train = iter(datamodule.train_dataloader())
 
-    for step in tqdm.tqdm(range(args.total_steps)):
+    bar = tqdm.tqdm(range(args.total_steps))
+    running_loss = 0.0
+    for step in bar:
         loss_accum = 0.0
         model.train()
         optimizer.zero_grad()

@@ -67,11 +69,12 @@ def train_phatgoose(args, model, datamodule):
 
         if loss_accum:
             norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            scheduler.step()
+            running_loss += loss_accum.item()
             optimizer.step()
+            scheduler.step()
             torch.cuda.synchronize()
-            logger.debug(
-                f"Step {step}/{args.total_steps}, Loss: {loss_accum.item():.4f}"
+            bar.set_description_str(
+                f"Step {step + 1}/{args.total_steps}, Loss: {running_loss / (step + 1):.4f}, Lr: {scheduler.get_last_lr()[0]:.4f}"
             )
 
     return model
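
The loop above replaces per-step logger.debug output with a tqdm progress-bar description showing the running average loss and the current learning rate, and it moves scheduler.step() after optimizer.step(), which is the call order PyTorch expects. A self-contained sketch of that progress-bar pattern, using a toy model, optimizer, and scheduler rather than the actual train_phatgoose setup:

import torch
import tqdm

# Toy stand-ins for the real model, optimizer, and scheduler.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=100)
total_steps = 100

bar = tqdm.tqdm(range(total_steps))
running_loss = 0.0
for step in bar:
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    running_loss += loss.item()
    optimizer.step()
    scheduler.step()  # after optimizer.step(), matching the fix above
    bar.set_description_str(
        f"Step {step + 1}/{total_steps}, "
        f"Loss: {running_loss / (step + 1):.4f}, "
        f"Lr: {scheduler.get_last_lr()[0]:.4f}"
    )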

@@ -628,6 +631,9 @@ class PhatgooseConfig(LibraryTransformConfig):
     micro_batch_size: int = 1
     batch_size: int = 1
 
+    def __post_init__(self):
+        self.gradient_accumulation_steps = self.batch_size // self.micro_batch_size
+
 
 @LibraryTransform.register("phatgoose", PhatgooseConfig)
 class PhatgooseTransform(HiddenStateComputer):
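
With the added __post_init__, gradient_accumulation_steps is derived once from batch_size and micro_batch_size instead of being recomputed inside transform() (see the later hunk). A quick check of the arithmetic, assuming PhatgooseConfig can be constructed standalone and that batch_size is an integer multiple of micro_batch_size:

from mttl.models.library.library_transforms import PhatgooseConfig

# batch_size 8 with micro_batch_size 2 accumulates gradients over
# 4 micro-batches before each optimizer step.
cfg = PhatgooseConfig(batch_size=8, micro_batch_size=2)
assert cfg.gradient_accumulation_steps == 4
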
@@ -643,7 +649,7 @@ def fetch(self, library: Union[str, ExpertLibrary]):
         output = library.get_auxiliary_data(data_type=self.config.save_name)
 
         if len(output) != len(library):
-            logger.warn(
+            logger.warning(
                 "Found {} precomputed Phatgoose prototypes. Some experts might not have prototypes.".format(
                     len(output)
                 )

@@ -696,7 +702,7 @@ def transform(
         training_config.train_batch_size = self.config.batch_size
         training_config.micro_batch_size = self.config.micro_batch_size
         training_config.gradient_accumulation_steps = (
-            self.config.batch_size // self.config.micro_batch_size
+            self.config.gradient_accumulation_steps
         )
         training_config.dataset = expert.expert_info.dataset
 
@@ -717,7 +723,7 @@ def transform(
                 lora_merge_after=True,
             ),
         ),
-        precision="bf16",
+        precision=training_config.precision,
         device_map="cuda",
     )
     model.add_expert_instance(expert, is_default=True)

tests/test_library_transforms.py: 1 change (0 additions & 1 deletion)

@@ -153,7 +153,6 @@ def test_phatgoose(tiny_flan, tmp_path, create_dummy_expert, monkeypatch):
 
     pg_config = PhatgooseConfig(n_steps=1, warmup_ratio=0.0, learning_rate=1e-2)
     phatgoose = PhatgooseTransform(pg_config)
-
     phatgoose.transform(library, persist=True, recompute=True, default_args=config)
 
     # now try to load a selector with the same config

0 comments on commit 706494d
