Merged
52 changes: 32 additions & 20 deletions examples/peft/merge_lora.py
@@ -27,12 +27,23 @@

Usage
-----
python merge_lora.py \
--lora-checkpoint path/to/finetune_ckpt \
--hf-model-path path/to/hf_model \
--output path/to/merged_ckpt \
[--pretrained path/to/base_ckpt] \
[--tp 1] [--pp 1] [--ep 1] [--cpu]
CPU-only (single process, no GPU required)::

python merge_lora.py \
--lora-checkpoint path/to/finetune_ckpt \
--hf-model-path path/to/hf_model \
--output path/to/merged_ckpt \
[--pretrained path/to/base_ckpt] \
--cpu

GPU with tensor/pipeline/expert parallelism::

torchrun --nproc_per_node <N> merge_lora.py \
--lora-checkpoint path/to/finetune_ckpt \
--hf-model-path path/to/hf_model \
--output path/to/merged_ckpt \
[--pretrained path/to/base_ckpt] \
[--tp 1] [--pp 1] [--ep 1]
"""

from __future__ import annotations
@@ -54,7 +65,7 @@
)
from megatron.bridge.training.model_load_save import save_megatron_model
from megatron.bridge.training.utils.checkpoint_utils import read_run_config
from megatron.bridge.utils.common_utils import print_rank_0
from megatron.bridge.utils.common_utils import print_rank_0, resolve_path


logger = logging.getLogger(__name__)
@@ -100,15 +111,15 @@ def parse_args() -> argparse.Namespace:

def _resolve_pretrained(lora_dir: Path, explicit: Optional[str]) -> Path:
if explicit:
return Path(explicit).expanduser().resolve()
return resolve_path(explicit)
cfg_path = lora_dir / "run_config.yaml"
if not cfg_path.exists():
raise FileNotFoundError("run_config.yaml not found in LoRA checkpoint and --pretrained not supplied")
cfg = read_run_config(str(cfg_path))
base = cfg.get("checkpoint", {}).get("pretrained_checkpoint")
if base is None:
raise ValueError("pretrained_checkpoint missing in run_config.yaml; pass --pretrained")
return Path(base).expanduser().resolve()
return resolve_path(base)
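
The diff swaps the direct `Path(...).expanduser().resolve()` calls for `resolve_path`. A minimal stand-in, under the assumption that the helper performs the same expansion (the real `megatron.bridge.utils.common_utils.resolve_path` may do more):

```python
from pathlib import Path

def resolve_path(p: str) -> Path:
    # Assumed behavior: expand "~" and return an absolute, normalized path,
    # mirroring the Path(p).expanduser().resolve() calls this replaces.
    return Path(p).expanduser().resolve()

print(resolve_path("~/ckpt").is_absolute())  # True
```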


# -----------------------------------------------------------------------------
@@ -150,7 +161,11 @@ def merge_lora(
model_provider.expert_tensor_parallel_size = 1
model_provider.pipeline_dtype = torch.bfloat16
if args.cpu:
assert args.tp == args.pp == args.ep == 1, "TP, PP, and EP must be 1 when using CPU merge"
if args.tp != 1 or args.pp != 1 or args.ep != 1:
logger.warning("TP, PP, and EP must be 1 when using CPU merge. Setting to 1.")
args.tp = 1
args.pp = 1
args.ep = 1
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("gloo")
Comment on lines 163 to 170
Contributor

⚠️ Potential issue | 🟠 Major

Supply explicit single-rank configuration for standalone CPU usage or document required environment variables.

The docstring advertises python merge_lora.py ... --cpu as a standalone single-process example, but the code calls torch.distributed.init_process_group("gloo") without explicit rank, world_size, or init_method. PyTorch defaults init_method to "env://" when not specified, requiring MASTER_PORT, MASTER_ADDR, WORLD_SIZE, and RANK environment variables—which will not be set in a standalone Python invocation. Either pass explicit parameters (e.g., rank=0, world_size=1) or document that users must export these environment variables before running the CPU example.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/peft/merge_lora.py` around lines 163-170: when `args.cpu` is true, the code calls `torch.distributed.init_process_group("gloo")` without `rank`, `world_size`, or `init_method`, so it relies on environment variables. Update the init call (the block guarded by `args.cpu` and `torch.distributed.is_initialized()`) to initialize a single-rank group explicitly (e.g., pass `backend="gloo"`, `rank=0`, `world_size=1`, or provide an explicit `init_method` suitable for local single-process use) so the CPU example runs standalone; alternatively, add a short comment or docstring note near `args.cpu` explaining that `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, and `RANK` must be set when using the `env://` init method.
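
A minimal sketch of the explicit single-rank initialization the review suggests; the helper name and TCP port are illustrative, not the repository's actual fix:

```python
import torch.distributed as dist

def init_single_rank_cpu_group(port: int = 29500) -> None:
    """Create a one-process gloo group without relying on env:// variables."""
    if dist.is_initialized():
        return
    # Passing rank/world_size plus an explicit init_method sidesteps the
    # MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE variables that the
    # default env:// method would require in a bare `python merge_lora.py` run.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://127.0.0.1:{port}",
        rank=0,
        world_size=1,
    )

init_single_rank_cpu_group()
print(dist.get_rank(), dist.get_world_size())  # 0 1
dist.destroy_process_group()
```

Under `torchrun`, the env:// variables are set for you, so this explicit form only matters for the standalone `--cpu` path.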

model_provider.initialize_model_parallel(seed=0)
@@ -255,7 +270,7 @@ def main() -> None:
level=logging.DEBUG if args.debug else logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
)

lora_dir = Path(args.lora_checkpoint).expanduser().resolve()
lora_dir = resolve_path(args.lora_checkpoint)
if not lora_dir.exists():
raise FileNotFoundError(f"LoRA checkpoint not found: {lora_dir}")
base_dir = _resolve_pretrained(lora_dir, args.pretrained)
@@ -265,19 +280,16 @@
merge_lora(
base_dir=base_dir,
lora_dir=lora_dir,
out_dir=Path(args.output).expanduser().resolve(),
out_dir=resolve_path(args.output),
hf_model_path=args.hf_model_path,
args=args,
)
except torch.cuda.OutOfMemoryError:
logger.error("Out of memory error during merge. Trying CPU merge...")
merge_lora(
base_dir=base_dir,
lora_dir=lora_dir,
out_dir=Path(args.output).expanduser().resolve(),
hf_model_path=args.hf_model_path,
args=args,
)
logger.warning("CUDA out of memory during merge. Please rerun this script on CPU by adding the `--cpu` flag.")
raise SystemExit(1)
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()


if __name__ == "__main__":