diff --git a/examples/peft/merge_lora.py b/examples/peft/merge_lora.py
index 5cfb139b33..c91724dd3e 100644
--- a/examples/peft/merge_lora.py
+++ b/examples/peft/merge_lora.py
@@ -27,12 +27,23 @@
 
 Usage
 -----
-python merge_lora.py \
-    --lora-checkpoint path/to/finetune_ckpt \
-    --hf-model-path path/to/hf_model \
-    --output path/to/merged_ckpt \
-    [--pretrained path/to/base_ckpt] \
-    [--tp 1] [--pp 1] [--ep 1] [--cpu]
+CPU-only (single process, no GPU required)::
+
+    python merge_lora.py \
+        --lora-checkpoint path/to/finetune_ckpt \
+        --hf-model-path path/to/hf_model \
+        --output path/to/merged_ckpt \
+        [--pretrained path/to/base_ckpt] \
+        --cpu
+
+GPU with tensor/pipeline/expert parallelism::
+
+    torchrun --nproc_per_node <num_gpus> merge_lora.py \
+        --lora-checkpoint path/to/finetune_ckpt \
+        --hf-model-path path/to/hf_model \
+        --output path/to/merged_ckpt \
+        [--pretrained path/to/base_ckpt] \
+        [--tp 1] [--pp 1] [--ep 1]
 """
 
 from __future__ import annotations
@@ -54,7 +65,7 @@
 )
 from megatron.bridge.training.model_load_save import save_megatron_model
 from megatron.bridge.training.utils.checkpoint_utils import read_run_config
-from megatron.bridge.utils.common_utils import print_rank_0
+from megatron.bridge.utils.common_utils import print_rank_0, resolve_path
 
 logger = logging.getLogger(__name__)
 
@@ -100,7 +111,7 @@ def parse_args() -> argparse.Namespace:
 
 def _resolve_pretrained(lora_dir: Path, explicit: Optional[str]) -> Path:
     if explicit:
-        return Path(explicit).expanduser().resolve()
+        return resolve_path(explicit)
     cfg_path = lora_dir / "run_config.yaml"
     if not cfg_path.exists():
         raise FileNotFoundError("run_config.yaml not found in LoRA checkpoint and --pretrained not supplied")
@@ -108,7 +119,7 @@ def _resolve_pretrained(lora_dir: Path, explicit: Optional[str]) -> Path:
     base = cfg.get("checkpoint", {}).get("pretrained_checkpoint")
     if base is None:
         raise ValueError("pretrained_checkpoint missing in run_config.yaml; pass --pretrained")
-    return Path(base).expanduser().resolve()
+    return resolve_path(base)
 
 
 # -----------------------------------------------------------------------------
@@ -150,7 +161,11 @@ def merge_lora(
         model_provider.expert_tensor_parallel_size = 1
         model_provider.pipeline_dtype = torch.bfloat16
     if args.cpu:
-        assert args.tp == args.pp == args.ep == 1, "TP, PP, and EP must be 1 when using CPU merge"
+        if args.tp != 1 or args.pp != 1 or args.ep != 1:
+            logger.warning("TP, PP, and EP must be 1 when using CPU merge. Setting to 1.")
+            args.tp = 1
+            args.pp = 1
+            args.ep = 1
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("gloo")
         model_provider.initialize_model_parallel(seed=0)
@@ -255,7 +270,7 @@ def main() -> None:
         level=logging.DEBUG if args.debug else logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
     )
 
-    lora_dir = Path(args.lora_checkpoint).expanduser().resolve()
+    lora_dir = resolve_path(args.lora_checkpoint)
    if not lora_dir.exists():
         raise FileNotFoundError(f"LoRA checkpoint not found: {lora_dir}")
     base_dir = _resolve_pretrained(lora_dir, args.pretrained)
@@ -265,19 +280,16 @@ def main() -> None:
         merge_lora(
             base_dir=base_dir,
             lora_dir=lora_dir,
-            out_dir=Path(args.output).expanduser().resolve(),
+            out_dir=resolve_path(args.output),
             hf_model_path=args.hf_model_path,
             args=args,
         )
     except torch.cuda.OutOfMemoryError:
-        logger.error("Out of memory error during merge. Trying CPU merge...")
-        merge_lora(
-            base_dir=base_dir,
-            lora_dir=lora_dir,
-            out_dir=Path(args.output).expanduser().resolve(),
-            hf_model_path=args.hf_model_path,
-            args=args,
-        )
+        logger.warning("CUDA out of memory during merge. Please rerun this script on CPU by adding the `--cpu` flag.")
+        raise SystemExit(1)
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
 
 
 if __name__ == "__main__":
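
Review note: every call site this patch rewrites previously inlined `Path(x).expanduser().resolve()`, so the new `resolve_path` helper imported from `megatron.bridge.utils.common_utils` is presumably a thin wrapper over that idiom. A minimal sketch of the assumed behavior, for reviewers reading the patch without the library checked out (the real helper's signature and extras may differ):

```python
from pathlib import Path
from typing import Union


def resolve_path(path: Union[str, Path]) -> Path:
    """Expand ~ and return an absolute, symlink-resolved Path.

    Sketch of the helper assumed by this patch; it centralizes the
    Path(x).expanduser().resolve() idiom used at each rewritten call site.
    """
    return Path(path).expanduser().resolve()
```

Under that assumption, `resolve_path("~/ckpts/finetune")` yields the same absolute `Path` the old inline code produced, so the rewrite is behavior-preserving and just removes the repetition.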