Skip to content

Commit

Permalink
debug upgrade it to pytorch_lightning2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
zhansu committed Aug 16, 2023
1 parent 2b7e22c commit de9991c
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 143 deletions.
6 changes: 3 additions & 3 deletions finetune_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,17 +336,17 @@ def run_multitask(args):
callbacks.append(checkpoint_callback)

trainer = Trainer(
gpus=1,
devices=1,
accelerator="gpu",
logger=loggers,
num_sanity_val_steps=5,
amp_backend="native",
# amp_backend="native",
default_root_dir=args.output_dir,
max_epochs=args.num_train_epochs,
max_steps=args.total_steps + 1 if args.total_steps != -1 else -1,
gradient_clip_val=args.max_grad_norm,
log_every_n_steps=20,
strategy=args.compute_strategy if args.compute_strategy else None,
strategy="ddp" if not args.compute_strategy else args.compute_strategy,
callbacks=callbacks,
accumulate_grad_batches=args.gradient_accumulation_steps,
precision=int(args.precision)
Expand Down
Loading

0 comments on commit de9991c

Please sign in to comment.