diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index ff2a9fcfe..915f4399d 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -994,6 +994,68 @@ jobs: shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + e2e-test-lora: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-lora')) + runs-on: self-hosted + container: + image: radixark/miles:dev + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + MILES_TEST_FEW_GPU: '0' + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null 
|| true + sleep 3 + + - name: Install + shell: bash + run: | + cd /sgl-workspace/sglang && git fetch origin sglang-miles && git checkout FETCH_HEAD && git log --oneline -1 && pip install -e python --no-deps --break-system-packages + cd /root/Megatron-LM && git reset --hard HEAD && git log --oneline -1 && git apply $GITHUB_WORKSPACE/docker/patch/dev/megatron.patch && pip install -e . --no-deps --break-system-packages + cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + e2e-test-image: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-image')) runs-on: self-hosted @@ -1016,7 +1078,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": 
"e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}] + info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": 
"e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 68d8a8921..b07b7ee00 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -39,6 +39,10 @@ {'test_file': 'e2e/long/test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 8}, ] %> +<% set lora_tests = [ + {'test_file': 'e2e/lora/test_lora_qwen2.5_0.5B.py', 'num_gpus': 8}, +] %> + <% set jobs = { 'fast': { 'test_executor': 'pytest', @@ -83,9 +87,13 @@ 'label': 'run-ci-long', 'tests': long_tests, }, + 'e2e-test-lora': { + 'label': 'run-ci-lora', + 'tests': lora_tests, + }, 'e2e-test-image': { 'label': 'run-ci-image', - 'tests': fsdp_tests + megatron_tests + short_tests + precision_tests + ckpt_tests + long_tests, + 'tests': fsdp_tests + megatron_tests + lora_tests + short_tests + precision_tests + ckpt_tests + long_tests, }, } %> name: PR Test diff --git a/examples/lora/run-qwen2.5-0.5B-megatron-lora.sh b/examples/lora/run-qwen2.5-0.5B-megatron-lora.sh new file mode 100644 index 000000000..18fbe40b9 --- /dev/null +++ b/examples/lora/run-qwen2.5-0.5B-megatron-lora.sh @@ -0,0 +1,165 @@ +#!/bin/bash +export FLASHINFER_DISABLE_VERSION_CHECK=1 +export GPUS_PER_NODE=8 +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +# for rerun the task +pkill sglang +ray stop --force +sleep 5 # Wait for processes to terminate gracefully +# Force kill any remaining processes. 
+# Note: `pkill -9 python` is broad and can be risky. +pkill -9 sglang +pkill -9 ray +pkill -9 python + +set -ex + + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/qwen2.5-0.5B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/ + --megatron-to-hf-mode bridge +) + +LORA_ARGS=( + --lora-rank 32 # LoRA rank (typical values: 8, 16, 32, 64) + --lora-alpha 32 # LoRA alpha (usually 2x rank) + --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) + --target-modules "all-linear" + --megatron-to-hf-mode bridge +) +############################## +############################## +############################## + +ROLLOUT_ARGS=( + --prompt-data /root/gsm8k/train.parquet + --input-key messages + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type math + --num-rollout 100 + # --num-rollout 10 # onyl train 10 stesp + --rollout-batch-size 32 + # --rollout-batch-size 16 # for testing + --n-samples-per-prompt 8 + --rollout-max-response-len 1024 + --rollout-temperature 1 + + --global-batch-size 256 + # --global-batch-size 32 # for testing +) + +EVAL_ARGS=( + # --eval-interval 20 + --eval-interval 10 + --eval-prompt-data gsm8k /root/gsm8k/test.parquet + --n-samples-per-eval-prompt 1 + --eval-max-response-len 1024 + --eval-top-k 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + # --use-kl-loss # if use kl loss, should use --ref-load + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + # --lr 1e-6 + --lr 1e-5 # Higher LR often works better for LoRA + --lr-decay-style constant + --weight-decay 0.1 + 
--adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-team miles-lora + --wandb-project miles-lora-megatron + --wandb-group qwen2.5-0.5B-gsm8k-test +) + + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + # --sglang-mem-fraction-static 0.7 + --sglang-mem-fraction-static 0.4 + + # --sglang-enable-deterministic-inference + # --sglang-attention-backend flashinfer + # --deterministic-mode +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + + +# launch the master node of ray in container +ray start --head --node-ip-address 127.0.0.1 --num-gpus $GPUS_PER_NODE --disable-usage-stats +# ray start --head --node-ip-address 127.0.0.1 --num-gpus 1 --disable-usage-stats + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node $GPUS_PER_NODE \ + --colocate \ + --calculate-per-token-loss \ + --use-miles-router \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${LORA_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} + + +# colocate : update from tesnor +# disaggrate : update from distributed \ No newline at end of file diff --git a/examples/lora/run-qwen3-4B-megatron-lora.sh b/examples/lora/run-qwen3-4B-megatron-lora.sh new file mode 100644 index 000000000..0980bee4d --- /dev/null +++ b/examples/lora/run-qwen3-4B-megatron-lora.sh @@ -0,0 +1,200 @@ 
+#!/bin/bash + +# Example launcher that reuses the Qwen3-4B recipe but delegates evaluation to an +# external Nemo Skills server via the eval_delegate_rollout wrapper. + +# Clean up any stale processes from a previous run. +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +SKILLS_OPENAI_MODEL_NAME=${SKILLS_OPENAI_MODEL_NAME:-"miles-openai-model"} + +export GPUS_PER_NODE=4 +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." &>/dev/null && pwd)" +source "${REPO_ROOT}/miles/scripts/models/qwen3-4B.sh" + +# Store eval/delegate settings in a YAML config similar to examples/eval_multi_task. 
+# EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/multi_tasks.yaml"} +EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/miles/examples/eval/scripts/multi_tasks.yaml"} + + +CKPT_ARGS=( + # --hf-checkpoint /root/Qwen3-4B + --hf-checkpoint /root/models/Qwen3-4B + --megatron-to-hf-mode bridge +) + + +LORA_ARGS=( + --lora-rank 32 # LoRA rank (typical values: 8, 16, 32, 64) + --lora-alpha 32 # LoRA alpha (usually 2x rank) + --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) + --target-modules "all-linear" + --megatron-to-hf-mode bridge +) + + + +ROLLOUT_ARGS=( + # --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + # --rollout-batch-size 32 + --rollout-batch-size 16 + --n-samples-per-prompt 8 + # --rollout-max-response-len 8192 + --rollout-max-response-len 2048 + --rollout-temperature 1 + --over-sampling-batch-size 64 + + --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std + # --global-batch-size 256 + --global-batch-size 128 + --balance-data +) + +# EVAL_ARGS=( +# --eval-interval 5 +# --eval-config "${EVAL_CONFIG_PATH}" +# --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout +# ) + +EVAL_ARGS=( + --eval-interval 5 + --eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 2 + --eval-max-response-len 16384 + --eval-top-k 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + # 
--use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + # --lr 1e-6 + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-team miles-lora + --wandb-project miles-lora-megatron + --wandb-group qwen3-4B-test +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + # --sglang-mem-fraction-static 0.4 + + --sglang-enable-deterministic-inference + --sglang-attention-backend flashinfer + --deterministic-mode +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +# export CUDA_VISIBLE_DEVICES=0,1 +# Set Up Your GPUs for Training + +# export GPUS_PER_NODE=2 #default + + +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus $GPUS_PER_NODE --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + + +# ray job submit --address="http://127.0.0.1:8265" \ + # --runtime-env-json="${RUNTIME_ENV_JSON}" \ + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node $GPUS_PER_NODE \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + 
${LORA_ARGS[@]} diff --git a/examples/lora/run-qwen3-4b-megatron-lora-result.sh b/examples/lora/run-qwen3-4b-megatron-lora-result.sh new file mode 100644 index 000000000..b8e2f7596 --- /dev/null +++ b/examples/lora/run-qwen3-4b-megatron-lora-result.sh @@ -0,0 +1,173 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# export SGLANG_LORA_PROFILE=1 +# export SGLANG_LORA_PROFILE_INTERVAL=10 +# export SGLANG_LORA_ENABLE_FUSION=1 + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 +# export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" +# export PYTORCH_ALLOC_CONF="expandable_segments:True" +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} +NUM_GPUS=$(echo $CUDA_VISIBLE_DEVICES | tr ',' '\n' | wc -l) +NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +LR=2e-5 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source /root/miles/scripts/models/qwen3-4B.sh + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-4B + --save /root/Qwen3-4B-lora-ckpt + --save-interval 50 +) + +LORA_ARGS=( + --lora-rank 64 + --lora-alpha 32 + --lora-dropout 0.0 # +fsdp + --target-modules all-linear +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --balance-data + --rm-type deepscaler + --num-rollout 100 + --rollout-batch-size 8 + --n-samples-per-prompt 8 + --rollout-max-response-len 4096 + --rollout-temperature 1 + --global-batch-size 64 +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime24 /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 16 + --eval-max-response-len 16384 + --eval-top-p 1 +) + +GRPO_ARGS=( + 
--advantage-estimator grpo + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr ${LR} + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-team miles-lora + --wandb-project miles-lora-test + --wandb-group qwen3-4B-megatron-lora-dapo-lr${LR} + --disable-wandb-random-suffix +) + + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-decode-log-interval 1000 + # --sglang-enable-metrics # -fsdp + --sglang-mem-fraction-static 0.4 # +fsdp, memory usage on H200 = 140*0.4=56GB per GPU + # --sglang-attention-backend fa3 # +fsdp + # --sglang-attention-backend flashinfer + --sglang-chunked-prefill-size 4096 +) + +MEGATRON_ARGS=( + # --no-offload-train + # --no-offload-rollout + --megatron-to-hf-mode bridge + # --offload-rollout-level kv_cache weight # -fsdp: not supported in megatron + # --train-backend fsdp # -fsdp: use megatron instead + --train-backend megatron # +fsdp + --attention-dropout 0.0 # +fsdp: default dropout in megatron is 0.1 + --hidden-dropout 0.0 # +fsdp: default dropout in megatron is 0.1 + --accumulate-allreduce-grads-in-fp32 # +fsdp, megatron specific + --attention-softmax-in-fp32 # +fsdp, megatron specific + # --attention-backend flash # +fsdp, megatron specific + # --train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' # +fsdp, otherwise OOM +) + +PERF_ARGS=( + # --gradient-checkpointing # +fsdp + # --sequence-parallel # +fsdp + # --use-dynamic-batch-size # +fsdpF + # --max-tokens-per-gpu 9216 # +fsdp, perf +) + +MISC_ARGS=( + --actor-num-nodes 1 + --actor-num-gpus-per-node ${NUM_GPUS} + --colocate + --calculate-per-token-loss # +fsdp + --use-miles-router # +fsdp +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address 
${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats + + +RUNTIME_ENV_JSON='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } +}' + + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${LORA_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${EVAL_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${WANDB_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${MEGATRON_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${MISC_ARGS[@]}" \ No newline at end of file diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index f7d53bcf3..4c430c152 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -32,6 +32,7 @@ from ..training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, get_values from .checkpoint import load_checkpoint from .initialize import init, is_megatron_main_rank +from .lora_utils import is_lora_enabled from .model import forward_only, initialize_model_and_optimizer, save, train from .parallel import create_megatron_parallel_state from .replay_utils import get_register_replay_list_func @@ -77,6 +78,7 @@ def init( if args.offload_train: if (x := args.train_memory_margin_bytes) > 0: + # --train-memory-margin-bytes can tune this logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") torch_memory_saver.memory_margin_bytes = x @@ -138,6 +140,7 @@ def init( weights_getter=lambda: self.weights_backuper.get("actor"), model_name=type(self.hf_config).__name__.lower() if self.args.model_name is None else self.args.model_name, quantization_config=getattr(self.hf_config, "quantization_config", None), + is_lora=is_lora_enabled(args), ) # empty cache after initialization @@ -483,12 +486,19 @@ def update_weights(self) -> None: if 
dist.get_rank() == 0: ray.get(self.rollout_manager.clear_num_new_engines.remote()) + if self.args.offload_train and is_lora_enabled(self.args): + # For LoRA, we must resume() to restore GPU memory backing for adapter + # weights. Unlike base model weights (which are read from CPU backups), + # LoRA adapter weights are accessed directly from GPU model parameters. + # The disable() context alone only prevents new allocations from being + # tracked -- it does NOT restore previously paused/offloaded tensors. + torch_memory_saver.resume() with torch_memory_saver.disable() if self.args.offload_train else nullcontext(): print_memory("before update_weights") self.weight_updater.update_weights() print_memory("after update_weights") - if self.args.ci_test and len(rollout_engines) > 0: + if self.args.ci_test and len(rollout_engines) > 0 and not is_lora_enabled(self.args): engine = random.choice(rollout_engines) engine_version = ray.get(engine.get_weight_version.remote()) if str(engine_version) != str(self.weight_updater.weight_version): @@ -508,6 +518,8 @@ def update_weights(self) -> None: self.weights_backuper.backup("old_actor") if self.args.offload_train: + if is_lora_enabled(self.args): + torch_memory_saver.pause() destroy_process_groups() def load_other_checkpoint(self, model_tag: str, path: str) -> None: diff --git a/miles/backends/megatron_utils/bridge_lora_helpers.py b/miles/backends/megatron_utils/bridge_lora_helpers.py new file mode 100644 index 000000000..5244faa35 --- /dev/null +++ b/miles/backends/megatron_utils/bridge_lora_helpers.py @@ -0,0 +1,122 @@ +"""Bridge / LoRA model setup helpers. + +Extracted from ``model.py`` to keep the main training module focused on +forward / backward / optimizer logic. 
+""" + +from __future__ import annotations + +from argparse import Namespace +from dataclasses import dataclass + +from megatron.core.utils import get_attr_wrapped_model + +from .lora_utils import create_lora_instance + + +@dataclass +class _BridgeWrapperConfig: + """Configuration for Megatron-Bridge module wrapping.""" + + is_value_model: bool = False + wrap_with_ddp: bool = True + use_distributed_optimizer: bool = True + + +def _ensure_model_list(model): + return model if isinstance(model, list) else [model] + + +def _make_value_model_hook(hidden_size: int, sequence_parallel: bool): + """Create a pre-wrap hook that replaces the output layer with a value head.""" + from megatron.core import parallel_state + + from .model_provider import LinearForLastLayer + + def hook(model): + model_post_process = [] + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None + ): + for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): + model_post_process.append(parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)) + else: + model_post_process.append(parallel_state.is_pipeline_last_stage()) + + model_list = _ensure_model_list(model) + assert len(model_post_process) == len(model_list), "Model list length and post process list length must match." + + for index, model_chunk in enumerate(model_list): + if not model_post_process[index]: + continue + model_chunk.output_layer = LinearForLastLayer( + input_size=hidden_size, + output_size=1, + sequence_parallel=sequence_parallel, + ) + + return hook + + +def _get_model_config_from_wrapped(model): + return get_attr_wrapped_model(model, "config", allow_none=False) + + +def _setup_lora_model_via_bridge(args: Namespace) -> list: + """Build Megatron model with LoRA using Megatron-Bridge. + + This handles: + 1. Creating the Bridge and Provider + 2. Creating and registering the LoRA pre-wrap hook + 3. 
Registering value-model hooks if needed + 4. Building the DDP-wrapped model + + Args: + args: Training arguments. + + Returns: + List of DDP-wrapped model chunks with LoRA applied. + """ + from megatron.bridge import AutoBridge + from megatron.bridge.training.config import DistributedDataParallelConfig + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True) + bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + provider = bridge.to_megatron_provider(load_weights=False) + + provider.tensor_model_parallel_size = args.tensor_model_parallel_size + provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size + provider.expert_model_parallel_size = args.expert_model_parallel_size + provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size + provider.sequence_parallel = args.sequence_parallel + provider.virtual_pipeline_model_parallel_size = args.virtual_pipeline_model_parallel_size + provider.context_parallel_size = args.context_parallel_size + provider.variable_seq_lengths = True + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_router_load_balancing_type = "none" + provider.finalize() + + lora = create_lora_instance(args) + + def apply_lora_hook(model_chunks): + transformed = lora(model_chunks, training=True) + lora.set_params_to_save(transformed) + return transformed + + provider.register_pre_wrap_hook(apply_lora_hook) + + is_value_model = ( + "ForTokenClassification" in hf_config.architectures[0] + or "ForSequenceClassification" in hf_config.architectures[0] + ) + if is_value_model: + hidden_size = hf_config.text_config.hidden_size if hasattr(hf_config, "text_config") else hf_config.hidden_size + provider.register_pre_wrap_hook(_make_value_model_hook(hidden_size, provider.sequence_parallel)) + + ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) + ddp_config.finalize() + + model = 
provider.provide_distributed_model(wrap_with_ddp=True, ddp_config=ddp_config) + return model diff --git a/miles/backends/megatron_utils/checkpoint.py b/miles/backends/megatron_utils/checkpoint.py index 87495b0d0..03a9dd7bb 100644 --- a/miles/backends/megatron_utils/checkpoint.py +++ b/miles/backends/megatron_utils/checkpoint.py @@ -3,6 +3,8 @@ import re from pathlib import Path +import torch.distributed as dist + # TODO: may need to copy those 2 functions and do refactoring. from megatron.training.checkpointing import load_checkpoint as _load_checkpoint_megatron from megatron.training.checkpointing import save_checkpoint @@ -10,11 +12,12 @@ from miles.utils import megatron_bridge_utils +from .lora_utils import is_lora_enabled, is_lora_model, load_lora_adapter, save_lora_checkpoint + try: # Here we patch out the `validate_non_overlapping_shards_metadata` in both functions # because it is really slow for large models with many shards. # TODO: find a less hacky way to do this. - import torch.distributed as dist import torch.distributed._shard.sharding_spec as shard_spec from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata @@ -91,7 +94,7 @@ def _init_from_local_shards_and_global_metadata( # type: ignore[override] logger = logging.getLogger(__name__) -__all__ = ["save_checkpoint"] +__all__ = ["save_checkpoint", "save_checkpoint_with_lora", "load_checkpoint"] def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_context, skip_load_to_model_and_opt): @@ -104,7 +107,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con ), f"{args.load=} does not exist or is an empty directory. Did you specify the wrong folder?" 
if _is_megatron_checkpoint(load_path): - return _load_checkpoint_megatron( + result = _load_checkpoint_megatron( ddp_model=ddp_model, optimizer=optimizer, opt_param_scheduler=opt_param_scheduler, @@ -112,13 +115,55 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con skip_load_to_model_and_opt=skip_load_to_model_and_opt, ) else: - return _load_checkpoint_hf( + result = _load_checkpoint_hf( ddp_model=ddp_model, optimizer=optimizer, args=args, load_path=load_path, ) + # Load LoRA adapter weights if available + if is_lora_enabled(args): + adapter_path = getattr(args, "lora_adapter_path", None) + if adapter_path is not None: + loaded, iteration = load_lora_adapter( + ddp_model, + adapter_path, + optimizer=optimizer, + opt_param_scheduler=opt_param_scheduler, + ) + if loaded: + logger.info(f"Successfully loaded LoRA adapter from {adapter_path}") + if iteration is not None: + result = (iteration, result[1]) + else: + logger.warning( + f"LoRA is enabled and --lora-adapter-path={adapter_path} was specified, " + f"but adapter weights could not be loaded. " + f"Training will start with freshly initialized adapter weights." 
+ ) + + return result + + +def save_checkpoint_with_lora(iteration, model, optimizer, opt_param_scheduler): + """Extended save that handles LoRA adapters separately.""" + args = get_args() + + if is_lora_model(model): + save_dir = Path(args.save) / f"iter_{iteration:07d}" / "adapter" + logger.info(f"Saving LoRA checkpoint to {save_dir}") + save_lora_checkpoint( + model, + args, + str(save_dir), + optimizer=optimizer, + opt_param_scheduler=opt_param_scheduler, + iteration=iteration, + ) + else: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + def _is_megatron_checkpoint(path: str | Path) -> bool: return (Path(path) / "latest_checkpointed_iteration.txt").is_file() or bool( diff --git a/miles/backends/megatron_utils/lora_utils.py b/miles/backends/megatron_utils/lora_utils.py new file mode 100644 index 000000000..15bda708e --- /dev/null +++ b/miles/backends/megatron_utils/lora_utils.py @@ -0,0 +1,470 @@ +"""LoRA utilities for Megatron backend using Megatron-Bridge PEFT integration.""" + +import logging +import os +from argparse import Namespace +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +from megatron.core import mpu + +logger = logging.getLogger(__name__) + +LORA_ADAPTER_NAME = "miles_lora" + +# --------------------------------------------------------------------------- +# Unified HF <-> Megatron module name mappings +# --------------------------------------------------------------------------- + +# Standard LoRA: merged Q/K/V and merged up/gate +_STANDARD_LORA_HF_TO_MEGATRON = { + "q_proj": "linear_qkv", + "k_proj": "linear_qkv", + "v_proj": "linear_qkv", + "o_proj": "linear_proj", + "gate_proj": "linear_fc1", + "up_proj": "linear_fc1", + "down_proj": "linear_fc2", +} + +_STANDARD_LORA_ALL_MODULES = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + +# CanonicalLoRA: Split Q/K/V and up/gate +_CANONICAL_LORA_HF_TO_MEGATRON = { + "q_proj": 
def is_lora_enabled(args: Namespace) -> bool:
    """Check if LoRA is enabled based on arguments.

    LoRA is considered enabled when a positive rank is configured, or when an
    adapter path was supplied (so a pre-trained adapter can be loaded even if
    the rank flag is absent).
    """
    rank = getattr(args, "lora_rank", 0)
    adapter_path = getattr(args, "lora_adapter_path", None)
    return rank > 0 or adapter_path is not None
in name + + +def _is_adapter_param_name(name: str) -> bool: + """Check if a parameter name belongs to a LoRA adapter (Megatron internal naming).""" + return "lora_" in name or (".adapter." in name and ("linear_in" in name or "linear_out" in name)) + + +# --------------------------------------------------------------------------- +# Module name conversion +# --------------------------------------------------------------------------- + + +def _get_lora_class_name(lora_type: type | object | None) -> str: + """Resolve LoRA type to its class name string.""" + if lora_type is None: + return "CanonicalLoRA" + if isinstance(lora_type, type): + return lora_type.__name__ + return type(lora_type).__name__ + + +def convert_target_modules_to_megatron( + hf_modules: str | list[str], + lora_type: type | object | None = None, +) -> list[str]: + """Convert HuggingFace LoRA target module names to Megatron format. + + HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj + Megatron (LoRA): linear_qkv, linear_proj, linear_fc1, linear_fc2 + Megatron (CanonicalLoRA): linear_q, linear_k, linear_v, linear_proj, + linear_fc1_up, linear_fc1_gate, linear_fc2 + + Special values: "all", "all-linear", "all_linear" -> all standard linear modules. + If input is already in Megatron format, returns as-is. 
def convert_target_modules_to_megatron(
    hf_modules: str | list[str],
    lora_type: type | object | None = None,
) -> list[str]:
    """Convert HuggingFace LoRA target module names to Megatron format.

    HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
    Megatron (LoRA): linear_qkv, linear_proj, linear_fc1, linear_fc2
    Megatron (CanonicalLoRA): linear_q, linear_k, linear_v, linear_proj,
        linear_fc1_up, linear_fc1_gate, linear_fc2

    Special values: "all", "all-linear", "all_linear" -> all standard linear modules.
    If input is already in Megatron format, returns as-is.
    """
    all_aliases = ("all", "all-linear", "all_linear")
    if _get_lora_class_name(lora_type) == "CanonicalLoRA":
        all_modules = _CANONICAL_LORA_ALL_MODULES
        hf_to_megatron = _CANONICAL_LORA_HF_TO_MEGATRON
    else:
        all_modules = _STANDARD_LORA_ALL_MODULES
        hf_to_megatron = _STANDARD_LORA_HF_TO_MEGATRON

    # Normalize: a lone string becomes a singleton list; "all-*" expands to
    # every standard linear module for the selected LoRA flavor.
    if isinstance(hf_modules, str):
        if hf_modules in all_aliases:
            return list(all_modules)
        hf_modules = [hf_modules]
    elif isinstance(hf_modules, list) and len(hf_modules) == 1 and hf_modules[0] in all_aliases:
        return list(all_modules)

    # If no non-wildcard entry is an HF name, assume Megatron format already.
    if not any(m in _HF_MODULE_NAMES for m in hf_modules if "*" not in m):
        return hf_modules

    # Map HF -> Megatron, deduplicating while preserving first-seen order
    # (merged layers like linear_qkv appear once even for q/k/v inputs).
    converted: list[str] = []
    for module in hf_modules:
        mapped = hf_to_megatron.get(module, module)
        if mapped not in converted:
            converted.append(mapped)
    return converted
def parse_exclude_modules(args: Namespace, lora_type=None) -> list[str]:
    """Parse and convert exclude_modules argument.

    Accepts either a comma-separated string or an iterable of module names;
    the result is normalized to Megatron naming via
    ``convert_target_modules_to_megatron``.
    """
    raw = getattr(args, "exclude_modules", None)
    if not raw:
        modules: list[str] = []
    elif isinstance(raw, str):
        modules = [part.strip() for part in raw.split(",")]
    else:
        modules = list(raw)
    return convert_target_modules_to_megatron(modules, lora_type=lora_type)
def save_lora_checkpoint(
    model: Sequence[torch.nn.Module],
    args: Namespace,
    save_dir: str,
    *,
    optimizer: Any | None = None,
    opt_param_scheduler: Any | None = None,
    iteration: int | None = None,
) -> str:
    """Save LoRA adapter checkpoint to disk.

    Saves in two formats:
    1. **HF PEFT format** (``adapter_model.bin`` + ``adapter_config.json``) for
       external tool compatibility. Uses Megatron-Bridge's ``export_adapter_weights``
       which correctly handles fused QKV / gate-up weight splitting and TP gathering.
    2. **Megatron-native format** (``adapter_megatron_tp{tp}_pp{pp}.pt``) for fast
       checkpoint resume without name/weight conversion. Each TP/PP rank saves its
       own shard with original parameter names.

    When ``optimizer`` is provided, training state (optimizer + LR scheduler) is
    also saved per-rank for checkpoint resume. Base model weights are frozen and
    never change, so they are not saved.

    This function is collective: **all ranks must call it** because the bridge
    export performs TP all-gather internally. Only ``dp_rank == 0`` writes files.

    Returns:
        The checkpoint directory path as a string.
    """
    import json

    from megatron.bridge import AutoBridge

    from miles.utils import megatron_bridge_utils

    save_path = Path(save_dir)
    is_dp_rank_0 = mpu.get_data_parallel_rank() == 0
    tp_rank = mpu.get_tensor_model_parallel_rank()
    # pp_rank is only used to name the per-rank native shard file below.
    pp_rank = mpu.get_pipeline_model_parallel_rank()

    # Create directory on dp_rank=0, then synchronize so no rank writes into
    # a directory that does not exist yet.
    if is_dp_rank_0:
        save_path.mkdir(parents=True, exist_ok=True)
    if dist.is_initialized():
        dist.barrier()

    # ---- Megatron-native format (per TP/PP rank, fast resume) ----
    # Each dp_rank==0 rank dumps its local adapter params under their original
    # Megatron names; replicas on other DP ranks hold identical data.
    if is_dp_rank_0:
        adapter_state: dict[str, torch.Tensor] = {}
        for model_chunk in model:
            for name, param in model_chunk.named_parameters():
                if _is_adapter_param_name(name):
                    adapter_state[name] = param.data.cpu()

        native_path = save_path / f"adapter_megatron_tp{tp_rank}_pp{pp_rank}.pt"
        torch.save(adapter_state, native_path)
        logger.info(f"Saved {len(adapter_state)} adapter tensors (native) to {native_path}")

    # ---- HF PEFT format (uses bridge for correct name/weight conversion) ----
    # Bridge export is collective: all TP ranks participate in the all-gather,
    # so every rank must call export_adapter_weights.
    bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)

    lora_state_dict: dict[str, torch.Tensor] = {}
    with megatron_bridge_utils.patch_megatron_model(model):
        for hf_name, weight, _megatron_name in bridge.export_adapter_weights(
            model,
            cpu=True,
            show_progress=False,
        ):
            lora_state_dict[hf_name] = weight

    # Only one rank writes the HF PEFT files (bridge already gathered across TP)
    if is_dp_rank_0 and tp_rank == 0:
        torch.save(lora_state_dict, save_path / "adapter_model.bin")

        # Fallback target list covers every standard attention/MLP projection.
        target_modules_hf = (
            convert_target_modules_to_hf(list(args.target_modules))
            if args.target_modules
            else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        )
        config = {
            "peft_type": "LORA",
            "r": args.lora_rank,
            "lora_alpha": args.lora_alpha,
            "target_modules": target_modules_hf,
            "lora_dropout": args.lora_dropout,
            "bias": "none",
            "task_type": "CAUSAL_LM",
        }
        with open(save_path / "adapter_config.json", "w") as f:
            json.dump(config, f, indent=2)

        # Flush OS buffers so the files are durable before other ranks proceed.
        os.sync()
        logger.info(f"Saved HF PEFT adapter to {save_path} with {len(lora_state_dict)} tensors")

    # ---- Training state (optimizer + scheduler) for resume ----
    # Saved per global rank: optimizer state is sharded differently than the
    # adapter weights, so every rank keeps its own file.
    if optimizer is not None:
        rank = dist.get_rank() if dist.is_initialized() else 0
        torch.save(
            {
                "iteration": iteration,
                "optimizer": optimizer.state_dict(),
                "opt_param_scheduler": opt_param_scheduler.state_dict() if opt_param_scheduler else None,
            },
            save_path / f"training_state_rank{rank}.pt",
        )
        logger.info(f"Saved optimizer/scheduler state to {save_path}")

    if dist.is_initialized():
        dist.barrier()

    return str(save_path)
def load_lora_adapter(
    model: Sequence[torch.nn.Module],
    adapter_path: str,
    *,
    optimizer: Any | None = None,
    opt_param_scheduler: Any | None = None,
) -> tuple[bool, int | None]:
    """Load LoRA adapter weights from a saved checkpoint into the model.

    Attempts to load from Megatron-native format first (per-rank ``.pt`` files),
    which preserves the exact TP/PP sharding and requires no name conversion.
    Falls back to HF PEFT ``adapter_model.bin`` if native files are not found
    (not yet implemented for HF PEFT format).

    When ``optimizer`` is provided, also restores training state (optimizer +
    LR scheduler) from a co-located ``training_state_rank*.pt`` file.

    Args:
        model: List of DDP-wrapped model chunks with LoRA layers already applied.
        adapter_path: Path to the adapter checkpoint directory.
        optimizer: If provided, restore optimizer state for training resume.
        opt_param_scheduler: If provided, restore LR scheduler state.

    Returns:
        ``(loaded, iteration)`` — *loaded* is True if adapter weights were
        successfully loaded; *iteration* is the saved iteration number (or None
        if no training state was found).
    """
    adapter_dir = Path(adapter_path)
    if not adapter_dir.exists():
        logger.warning(f"LoRA adapter path does not exist: {adapter_dir}")
        return False, None

    # NOTE(review): the native shard is keyed by TP/PP rank only — assumes the
    # checkpoint was saved with the same TP/PP layout as the current run.
    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()

    # ---- Try Megatron-native format first (fast, no conversion needed) ----
    native_path = adapter_dir / f"adapter_megatron_tp{tp_rank}_pp{pp_rank}.pt"
    if native_path.exists():
        # weights_only=True: the shard holds pure tensors, so safe loading works.
        state_dict = torch.load(native_path, map_location="cpu", weights_only=True)
        loaded = 0
        for model_chunk in model:
            for name, param in model_chunk.named_parameters():
                if name in state_dict:
                    # Copy in place so DDP/optimizer keep their param references.
                    param.data.copy_(state_dict[name].to(device=param.device))
                    loaded += 1
        logger.info(f"Loaded {loaded} adapter tensors from Megatron-native checkpoint: {native_path}")

        iteration = _load_training_state(adapter_dir, optimizer, opt_param_scheduler)
        return True, iteration

    # ---- HF PEFT format (future work) ----
    hf_path = adapter_dir / "adapter_model.bin"
    if hf_path.exists():
        logger.warning(
            f"Found HF PEFT adapter at {hf_path} but direct HF PEFT loading into "
            f"Megatron is not yet supported. Please save using Megatron-native format "
            f"(adapter_megatron_tp*_pp*.pt files) for checkpoint resume."
        )
        return False, None

    logger.warning(f"No adapter checkpoint found at {adapter_dir}")
    return False, None
def _load_training_state(
    adapter_dir: Path,
    optimizer: Any | None,
    opt_param_scheduler: Any | None,
) -> int | None:
    """Restore optimizer/scheduler state saved alongside a LoRA adapter checkpoint.

    Returns the saved iteration number, or ``None`` when there is nothing to
    restore (no optimizer supplied, or no per-rank state file on disk).
    """
    if optimizer is None:
        return None

    rank = dist.get_rank() if dist.is_initialized() else 0
    state_path = adapter_dir / f"training_state_rank{rank}.pt"
    if not state_path.exists():
        return None

    # Optimizer state dicts may contain non-tensor objects (e.g. step counts,
    # param group metadata), so full unpickling is required here.
    state = torch.load(state_path, map_location="cpu", weights_only=False)

    optimizer.load_state_dict(state["optimizer"])
    logger.info("Restored optimizer state from LoRA checkpoint")

    scheduler_state = state.get("opt_param_scheduler")
    if opt_param_scheduler is not None and scheduler_state is not None:
        opt_param_scheduler.load_state_dict(scheduler_state)
        logger.info("Restored LR scheduler state from LoRA checkpoint")

    iteration = state.get("iteration")
    if iteration is not None:
        logger.info(f"Resuming LoRA training from iteration {iteration}")
    return iteration
def convert_lora_to_hf(args, model_name, name, param):
    """Convert Megatron LoRA parameter to HuggingFace PEFT format.

    .. deprecated::
        This function uses hardcoded string replacements that do NOT correctly
        handle fused layers (e.g. linear_qkv -> q/k/v_proj, linear_fc1 ->
        gate/up_proj). Use ``AutoBridge.export_adapter_weights`` instead, which
        properly splits fused adapter weights and handles TP gathering.

    Megatron format: module.module.decoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight
    HF PEFT format:  base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight

    Returns a single-element list ``[(hf_name, param)]``; non-adapter names
    pass through unchanged.
    """
    import warnings

    warnings.warn(
        "convert_lora_to_hf uses incorrect hardcoded name mapping for fused layers. "
        "Use AutoBridge.export_adapter_weights instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    # Classify as lora_A (input projection) vs lora_B (output projection);
    # anything else is not an adapter weight and passes through untouched.
    if ".linear_in." in name or ".lora_A." in name:
        lora_suffix = "lora_A.weight"
    elif ".linear_out." in name or ".lora_B." in name:
        lora_suffix = "lora_B.weight"
    else:
        return [(name, param)]

    # Ordered textual rewrites from Megatron naming to HF PEFT naming.
    rewrites = (
        ("module.module.", "base_model.model."),
        (".decoder.layers.", ".model.layers."),
        (".self_attention.linear_qkv", ".self_attn.q_proj"),
        (".self_attention.linear_proj", ".self_attn.o_proj"),
        (".mlp.linear_fc1", ".mlp.gate_proj"),
        (".mlp.linear_fc2", ".mlp.down_proj"),
        (".adapter.linear_in.weight", f".{lora_suffix}"),
        (".adapter.linear_out.weight", f".{lora_suffix}"),
        (".lora_A.weight", f".{lora_suffix}"),
        (".lora_B.weight", f".{lora_suffix}"),
    )
    hf_name = name
    for old, new in rewrites:
        hf_name = hf_name.replace(old, new)

    return [(hf_name, param)]
@@ -83,6 +88,11 @@ def get_optimizer_param_scheduler(args: Namespace, optimizer: MegatronOptimizer) return opt_param_scheduler +# --------------------------------------------------------------------------- +# Model + Optimizer setup +# --------------------------------------------------------------------------- + + def setup_model_and_optimizer( args: Namespace, role: str = "actor", @@ -92,11 +102,6 @@ def setup_model_and_optimizer( Args: args (Namespace): Training/runtime arguments (argparse namespace). role (str): Logical role of the model (e.g., "actor", "critic"). - no_wd_decay_cond (Callable[..., bool] | None): Predicate to exclude - parameters from weight decay. - scale_lr_cond (Callable[..., bool] | None): Predicate to scale LR for - selected parameter groups. - lr_mult (float): Global learning-rate multiplier for the optimizer. Returns: tuple[list[DDP], MegatronOptimizer, OptimizerParamScheduler]: @@ -107,7 +112,10 @@ def setup_model_and_optimizer( assert not args.moe_use_upcycling assert args.load is not None or args.pretrained_checkpoint is not None - model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + if is_lora_enabled(args) and role == "actor" and args.megatron_to_hf_mode == "bridge": + model = _setup_lora_model_via_bridge(args) + else: + model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) # Optimizer kwargs = {} @@ -116,7 +124,6 @@ def setup_model_and_optimizer( kwargs[f.name] = getattr(args, f.name) config = OptimizerConfig(**kwargs) config.timers = None - optimizer = get_megatron_optimizer( config=config, model_chunks=model, @@ -126,6 +133,11 @@ def setup_model_and_optimizer( return model, optimizer, opt_param_scheduler +# --------------------------------------------------------------------------- +# Forward pre-hook helpers +# --------------------------------------------------------------------------- + + def enable_forward_pre_hook(model_chunks: Sequence[DDP]) -> None: """Enable 
def should_disable_forward_pre_hook(args: Namespace) -> bool:
    """Block forward pre-hook for certain configurations.

    True exactly when the distributed optimizer is active AND param-gather
    overlap is enabled (short-circuit semantics preserved).
    """
    if args.use_distributed_optimizer:
        return args.overlap_param_gather
    return args.use_distributed_optimizer
+ Aggregated outputs keyed by ``store_prefix + key``. """ - # reset data iterator for iterator in data_iterator: iterator.reset() @@ -304,18 +319,17 @@ def train_one_step( one scheduler step when gradients are valid. Args: - args (Namespace): Runtime arguments. - rollout_id (int): Rollout identifier. - step_id (int): Step index within the current rollout. - data_iterator (Sequence[DataIterator]): Iterable(s) yielding training batches. - model (Sequence[DDP]): Sequence of DDP-wrapped model chunks. - optimizer (MegatronOptimizer): Optimizer instance. - opt_param_scheduler (OptimizerParamScheduler): LR/WD scheduler. - num_microbatches (int): Number of microbatches to process. + args: Runtime arguments. + rollout_id: Rollout identifier. + step_id: Step index within the current rollout. + data_iterator: Iterable(s) yielding training batches. + model: Sequence of DDP-wrapped model chunks. + optimizer: Optimizer instance. + opt_param_scheduler: LR/WD scheduler. + num_microbatches: Number of microbatches to process. Returns: - tuple[dict[str, float], float]: Reduced loss dictionary (last stage only) - and gradient norm for logging. + Reduced loss dictionary (last stage only) and gradient norm for logging. """ args = get_args() @@ -462,13 +476,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p return {}, grad_norm -def should_disable_forward_pre_hook(args: Namespace) -> bool: - """Block forward pre-hook for certain configurations.""" - return args.use_distributed_optimizer and args.overlap_param_gather - - def finalize_model_grads_with_empty_cache(*args, **kwargs): - # trigger empty cache when there are less than 10% free memory before the final reduce scatter. # TODO: this is an ad-hoc method and we should figure out why the oom happens in the first place. 
device = torch.cuda.current_device() free, total = torch.cuda.mem_get_info(device) @@ -678,16 +686,21 @@ def save( args = get_args() if should_disable_forward_pre_hook(args): disable_forward_pre_hook(model) - save_checkpoint( - iteration, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far=0, - checkpointing_context=None, - train_data_iterator=None, - preprocess_common_state_dict_fn=None, - ) + + if is_lora_model(model): + save_checkpoint_with_lora(iteration, model, optimizer, opt_param_scheduler) + else: + save_checkpoint( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far=0, + checkpointing_context=None, + train_data_iterator=None, + preprocess_common_state_dict_fn=None, + ) + if should_disable_forward_pre_hook(args): enable_forward_pre_hook(model) @@ -695,7 +708,16 @@ def save( def save_hf_model(args, rollout_id: int, model: Sequence[DDP]) -> None: """Save Megatron model in HuggingFace format. + For LoRA models this saves both: + - A **merged** HF model (adapter weights folded into base) at ``{path}/`` + so it can be loaded directly with ``AutoModelForCausalLM.from_pretrained``. + - An **adapter-only** HF PEFT checkpoint at ``{path}/adapter/`` + so it can be loaded with ``PeftModel.from_pretrained``. + + This function is collective — all ranks must call it. + Args: + args: Runtime arguments. model (Sequence[DDP]): Sequence of DDP-wrapped model chunks. rollout_id (int): Rollout ID for path formatting. """ @@ -718,17 +740,29 @@ def save_hf_model(args, rollout_id: int, model: Sequence[DDP]) -> None: path.mkdir(parents=True, exist_ok=True) with patch_megatron_model(model): - bridge.save_hf_pretrained( - model, - path=path, - ) + # For LoRA models, merge_adapter_weights=True (default) merges + # adapter weights into base weights for a standalone HF model. 
+ bridge.save_hf_pretrained(model, path=path) if should_log: - logger.info(f"Successfully saved HuggingFace model to {path}") + logger.info(f"Successfully saved merged HuggingFace model to {path}") except Exception as e: if should_log: logger.error(f"Failed to save HuggingFace format: {e}") + # Additionally save adapter-only checkpoint for LoRA models + if is_lora_model(model): + try: + adapter_path = Path(args.save_hf.format(rollout_id=rollout_id)) / "adapter" + if should_log: + logger.info(f"Saving LoRA adapter (HF PEFT format) to {adapter_path}") + save_lora_checkpoint(model, args, str(adapter_path)) + if should_log: + logger.info(f"Successfully saved LoRA adapter to {adapter_path}") + except Exception as e: + if should_log: + logger.error(f"Failed to save LoRA adapter: {e}") + def initialize_model_and_optimizer( args: Namespace, role: str = "actor" @@ -743,7 +777,6 @@ def initialize_model_and_optimizer( tuple[list[DDP], MegatronOptimizer, OptimizerParamScheduler, int]: DDP-wrapped model chunks, optimizer, scheduler, and iteration index. 
""" - if torch.version.hip: import megatron.core.dist_checkpointing.strategies.filesystem_async as filesystem_async_module diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py index ef7d62e8a..93059e8e9 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py @@ -3,7 +3,7 @@ class HfWeightIteratorBase(ABC): @staticmethod - def create(args, model, **kwargs): + def create(args, model, *, is_lora=False, **kwargs): from .hf_weight_iterator_bridge import HfWeightIteratorBridge from .hf_weight_iterator_direct import HfWeightIteratorDirect @@ -12,13 +12,14 @@ def create(args, model, **kwargs): "bridge": HfWeightIteratorBridge, }[args.megatron_to_hf_mode] - return c(args, model, **kwargs) + return c(args, model, is_lora=is_lora, **kwargs) - def __init__(self, args, model, model_name, quantization_config): + def __init__(self, args, model, model_name, quantization_config, *, is_lora=False): self.args = args self.model = model self.model_name = model_name self.quantization_config = quantization_config + self.is_lora = is_lora @abstractmethod def get_hf_weight_chunks(self, megatron_local_weights): diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 7e0a4817e..aa6d2b730 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -19,14 +19,25 @@ def __init__(self, *args, **kwargs): self._bridge = AutoBridge.from_hf_pretrained(self.args.hf_checkpoint, trust_remote_code=True) def get_hf_weight_chunks(self, megatron_local_weights): - # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name) + # TODO: support quantization (e.g. 
modify megatron-bridge to provide megatron param name) renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} with megatron_bridge_utils.patch_megatron_model(self.model): - conversion_tasks = self._bridge.get_conversion_tasks(self.model) - conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - - named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) + if self.is_lora: + named_weights = self._bridge.export_adapter_weights( + self.model, + cpu=False, + show_progress=False, + ) + else: + conversion_tasks = self._bridge.get_conversion_tasks(self.model) + conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + named_weights = self._bridge.export_hf_weights( + self.model, + cpu=False, + conversion_tasks=conversion_tasks, + ) + # TODO: verify if postprocess_hf_param is needed for LoRA weights named_weights = ( ( hf_param_name, diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 117c3d418..caf6ae54f 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -31,6 +31,7 @@ def __init__( *, model_name: str, quantization_config: dict[str, int | str | list[str]] | None, + is_lora: bool = False, ) -> None: """ Initialize. Groups created in connect_rollout_engines. 
diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 28cc19c97..c21f53d76 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -1,3 +1,4 @@ +import logging from argparse import Namespace from collections.abc import Callable, Mapping, Sequence from typing import Any @@ -9,6 +10,7 @@ from ray import ObjectRef from ray.actor import ActorHandle +from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, build_lora_sync_config, is_lora_weight_name from miles.utils.distributed_utils import get_gloo_group from ..sglang import FlattenedTensorBucket, MultiprocessingSerializer @@ -20,12 +22,14 @@ update_weights_from_distributed, ) +logger = logging.getLogger(__name__) + class UpdateWeightFromTensor: """ Update rollout engines from tensor dict: - load(dict→GPU) → broadcast PP/EP(GPU NCCL) → gather TP(GPU NCCL) → convert HF(GPU) → send. - Colocated: GPU→CPU serialize → gather_object(Gloo CPU, collects from rollout_num_gpus_per_engine ranks) → Ray IPC to engine. + load(dict->GPU) -> broadcast PP/EP(GPU NCCL) -> gather TP(GPU NCCL) -> convert HF(GPU) -> send. + Colocated: GPU->CPU serialize -> gather_object(Gloo CPU) -> Ray IPC to engine. Distributed: GPU NCCL broadcast to remote engines. """ @@ -37,6 +41,7 @@ def __init__( *, model_name: str, quantization_config: dict[str, int | str | list[str]] | None, + is_lora: bool = False, ) -> None: """ Compute param buckets, create IPC Gloo groups (rollout_num_gpus_per_engine ranks/group). 
@@ -47,12 +52,19 @@ def __init__( self.model_name = model_name self.quantization_config = quantization_config self.weight_version = 0 + self.is_lora = is_lora + self._lora_loaded = False self._hf_weight_iterator = HfWeightIteratorBase.create( - args=args, model=model, model_name=model_name, quantization_config=quantization_config + args=args, + model=model, + model_name=model_name, + quantization_config=quantization_config, + is_lora=self.is_lora, ) - # create the group within megatron. + self._lora_config = build_lora_sync_config(args) if self.is_lora else None + # Create IPC gather groups within megatron. for start_rank in range(0, dist.get_world_size(), self.args.rollout_num_gpus_per_engine): end_rank = start_rank + self.args.rollout_num_gpus_per_engine group_ranks = list(range(start_rank, end_rank)) @@ -90,7 +102,6 @@ def connect_rollout_engines( disconnect_rollout_engines_from_distributed( self.args, self._group_name, self._model_update_groups, self.distributed_rollout_engines ) - self._model_update_groups = connect_rollout_engines_from_distributed( self.args, self._group_name, self.distributed_rollout_engines ) @@ -147,23 +158,47 @@ def update_weights(self) -> None: def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: all_refs = [] + long_lived_tensors = [] - refs_colocated, long_lived_tensors = _send_to_colocated_engine( - hf_named_tensors, + # Separate LoRA weights from base weights + if self.is_lora: + weight_tensors = [(n, t) for n, t in hf_named_tensors if is_lora_weight_name(n)] + else: + weight_tensors = hf_named_tensors + + kwargs = dict( + hf_named_tensors=weight_tensors, ipc_engine=self._ipc_engine, ipc_gather_src=self._ipc_gather_src, ipc_gather_group=self._ipc_gather_group, - weight_version=self.weight_version, ) + if self.is_lora: + kwargs |= dict( + lora_config=self._lora_config, + lora_name=LORA_ADAPTER_NAME, + lora_loaded=self._lora_loaded, + ) + else: + kwargs |= dict( + weight_version=self.weight_version, + ) + + 
refs_colocated, long_lived_tensors = _send_to_colocated_engine(**kwargs) all_refs.extend(refs_colocated) + if self.is_lora: + self._lora_loaded = True + + if self.is_lora and self.use_distribute and self._is_distributed_src_rank: + raise NotImplementedError("LoRA weight sync is not yet supported for distributed (non-colocated) engines") + if self.use_distribute and self._is_distributed_src_rank: refs_distributed = update_weights_from_distributed( self._group_name, self._model_update_groups, self.weight_version, self.distributed_rollout_engines, - hf_named_tensors, + weight_tensors, ) if refs_distributed: all_refs.extend(refs_distributed) @@ -177,9 +212,12 @@ def _send_to_colocated_engine( ipc_engine, ipc_gather_src, ipc_gather_group, - weight_version, + weight_version=None, + lora_config: dict | None = None, + lora_name: str | None = None, + lora_loaded: bool = False, ) -> tuple[list[ObjectRef], Any]: - # TODO improve + is_lora = lora_config is not None long_live_tensors = [] if getattr(FlattenedTensorBucket, "supports_multi_dtypes", False): @@ -195,10 +233,9 @@ def _send_to_colocated_engine( serialized_tensors = [] for _dtype, named_tensors in converted_named_tensors_by_dtypes.items(): flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=named_tensors) - metadata = flattened_tensor_bucket.get_metadata() flattened_tensor_data = { "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, + "metadata": flattened_tensor_bucket.get_metadata(), } long_live_tensors.append(flattened_tensor_data) serialized_tensors.append(MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True)) @@ -215,14 +252,32 @@ def _send_to_colocated_engine( refs = [] if dist.get_rank() == ipc_gather_src: - # TODO: here we assume all ranks have the same number of dtypes, not sure if that is correct. 
- num_dtypes = len(serialized_named_tensors[0]) - for i in range(num_dtypes): - kwargs = { - "serialized_named_tensors": [tensors[i] for tensors in serialized_named_tensors], - "load_format": "flattened_bucket", - "weight_version": str(weight_version), - } - refs.append(ipc_engine.update_weights_from_tensor.remote(**kwargs)) + if is_lora: + if lora_loaded: + ray.get(ipc_engine.unload_lora_adapter.remote(lora_name=lora_name)) + + # (Yusheng) to-do-1: update lora weights from tensors should support multiple dtypes (bf16, fp8, fp16, fp32) + # currently, we only support 1 type. If there are multiple dtypes, we need to serialize the tensors for each dtype. + # Thus, we need to apply the same way as `ipc_engine.update_weights_from_tensor` in future + # (Yusheng) to-do-2: need to add ci test acc here - now it will pass but fail to update lora weights + + refs.append( + ipc_engine.load_lora_adapter_from_tensors.remote( + lora_name=lora_name, + config_dict=lora_config, + serialized_tensors=serialized_named_tensors[0][0], + load_format="flattened_bucket", + ) + ) + + else: + num_dtypes = len(serialized_named_tensors[0]) + for i in range(num_dtypes): + kwargs = { + "serialized_named_tensors": [tensors[i] for tensors in serialized_named_tensors], + "load_format": "flattened_bucket", + "weight_version": str(weight_version), + } + refs.append(ipc_engine.update_weights_from_tensor.remote(**kwargs)) return refs, long_live_tensors diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 06d821831..2d5ea7345 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -13,6 +13,7 @@ from sglang.srt.utils import kill_process_tree from urllib3.exceptions import NewConnectionError +from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, convert_target_modules_to_hf, is_lora_enabled from miles.ray.ray_actor import RayActor from miles.utils.http_utils import 
get_host_info @@ -272,6 +273,32 @@ def update_weights_from_tensor( payload, ) + def load_lora_adapter_from_tensors( + self, + lora_name: str, + serialized_tensors: str, + config_dict: dict, + load_format: str | None = None, + pinned: bool = False, + added_tokens_config: dict | None = None, + ): + """Load a LoRA adapter from serialized tensor data.""" + payload = { + "lora_name": lora_name, + "serialized_tensors": serialized_tensors, + "config_dict": config_dict, + "pinned": pinned, + } + if load_format is not None: + payload["load_format"] = load_format + if added_tokens_config is not None: + payload["added_tokens_config"] = added_tokens_config + + return self._make_request( + "load_lora_adapter_from_tensors", + payload, + ) + def flush_cache(self): """Flush the cache of the server.""" if self.node_rank != 0: @@ -336,9 +363,20 @@ def get_weight_version(self): return response.json()["weight_version"] response.raise_for_status() - def release_memory_occupation(self): + def unload_lora_adapter(self, lora_name: str): + """Unload LoRA adapter.""" + return self._make_request( + "unload_lora_adapter", + {"lora_name": lora_name}, + ) + + def release_memory_occupation(self, tags: list[str] = None): + """Release memory occupation. 
Available tags: weights, kv_cache.""" self.flush_cache() - return self._make_request("release_memory_occupation") + return self._make_request( + "release_memory_occupation", + {"tags": tags}, + ) def resume_memory_occupation(self, tags: list[str] = None): """ @@ -490,6 +528,7 @@ def _compute_server_args( "random_seed": args.seed + rank, # memory "enable_memory_saver": args.offload_rollout, + "enable_weights_cpu_backup": args.offload_rollout, # distributed "host": host, "port": port, @@ -527,6 +566,17 @@ def _compute_server_args( kwargs["dtype"] = "float16" external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] + if is_lora_enabled(args): + kwargs["enable_lora"] = True + kwargs["max_loras_per_batch"] = 1 + kwargs["max_lora_rank"] = max(getattr(args, "lora_rank", 0), 1) + kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + + if args.lora_adapter_path is not None: + kwargs["lora_paths"] = {LORA_ADAPTER_NAME: args.lora_adapter_path} + else: + logger.info("No pre-trained LoRA adapter_path provided, will use random initial weights") + unused_keys = set(kwargs.keys()) for attr in dataclasses.fields(ServerArgs): if worker_type == "decode" and attr.name == "enable_hierarchical_cache": diff --git a/miles/backends/training_utils/ci_utils.py b/miles/backends/training_utils/ci_utils.py index e4bd0c083..9124889be 100644 --- a/miles/backends/training_utils/ci_utils.py +++ b/miles/backends/training_utils/ci_utils.py @@ -14,6 +14,10 @@ def check_kl(args: Namespace, log_dict: dict[str, float], step_id: int, accumula if args.multi_latent_attention: # TODO: mla currently have non-zero kl, need further investigation assert log_dict["train/ppo_kl"] < 1e-8, f"{log_dict=}" + elif getattr(args, "lora_rank", 0) > 0: + # LoRA weight conversion (Megatron → HF for SGLang) introduces + # small floating-point differences, so use a relaxed threshold. 
+ assert abs(log_dict["train/ppo_kl"]) < 1e-8 and abs(log_dict["train/pg_clipfrac"]) < 1e-10, f"{log_dict=}" else: assert abs(log_dict["train/ppo_kl"]) < 1e-10 and abs(log_dict["train/pg_clipfrac"]) < 1e-10, f"{log_dict=}" if accumulated_step_id == 0 and "train/kl_loss" in log_dict and not args.use_rollout_routing_replay: diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 58967c1ee..32728d47a 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -173,10 +173,14 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self): + def offload(self, tags: list[str] | None = None): self.health_monitoring_pause() return ray.get( - [engine.release_memory_occupation.remote() for engine in self.rollout_engines if engine is not None] + [ + engine.release_memory_occupation.remote(tags=tags) + for engine in self.rollout_engines + if engine is not None + ] ) def onload(self, tags: list[str] | None = None): @@ -188,6 +192,14 @@ def onload(self, tags: list[str] | None = None): ] ) + def health_monitoring_pause(self): + if self.args.use_fault_tolerance and self._health_monitor is not None: + self._health_monitor.pause() + + def health_monitoring_resume(self): + if self.args.use_fault_tolerance and self._health_monitor is not None: + self._health_monitor.resume() + def onload_weights(self): self.onload(tags=[GPU_MEMORY_TYPE_WEIGHTS]) @@ -215,14 +227,6 @@ def clear_num_new_engines(self): # when fault tolerance is not enabled, we need to manually clear num_new_engines after update_weights self.num_new_engines = 0 - def health_monitoring_pause(self) -> None: - if self._health_monitor is not None: - self._health_monitor.pause() - - def health_monitoring_resume(self) -> None: - if self._health_monitor is not None: - self._health_monitor.resume() - def check_weights(self, action: str): return ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) diff --git 
a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index e80c144a1..0fbb0a563 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -13,6 +13,7 @@ from packaging.version import parse from tqdm import tqdm +from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, is_lora_enabled from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter from miles.utils.async_utils import run @@ -136,6 +137,9 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": True, } + if is_lora_enabled(args): + payload["lora_path"] = LORA_ADAPTER_NAME + if args.use_rollout_routing_replay: payload["return_routed_experts"] = True diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 4158b2e8c..3d80574f6 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -106,6 +106,18 @@ def add_cluster_arguments(parser): ), ) + parser.add_argument( + "--offload-rollout-level", + type=str, + nargs="+", + default=["kv_cache", "weight"], + help=( + "Specifies what to offload during rollout when offload-rollout is set. " + "Possible values: 'kv_cache', 'weight'. Default: both 'kv_cache' and 'weight'. " + "Example: --offload-rollout-level kv_cache weight" + ), + ) + reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -931,6 +943,60 @@ def add_algo_arguments(parser): ) return parser + def add_lora_arguments(parser): + """Add LoRA-related arguments for Megatron backend.""" + parser.add_argument( + "--lora-rank", + type=int, + default=0, + help="LoRA rank. 
Set to 0 to disable LoRA (default: 0)", + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha for scaling (default: 16)", + ) + parser.add_argument( + "--lora-dropout", + type=float, + default=0.0, + help="LoRA dropout rate (default: 0.0)", + ) + parser.add_argument( + "--lora-type", + type=str, + default="lora", + choices=["lora", "canonical_lora"], + help="LoRA variant to use: 'lora' (standard) or 'canonical_lora' (split Q/K/V) (default: lora)", + ) + parser.add_argument( + "--target-modules", + type=str, + default=None, + help="Target modules for LoRA. Use 'all-linear' or comma-separated module names " + "(e.g., 'q_proj,k_proj,v_proj,o_proj' for HF naming or 'linear_qkv,linear_proj' for Megatron naming)", + ) + parser.add_argument( + "--exclude-modules", + type=str, + default=None, + help="Modules to exclude from LoRA (comma-separated)", + ) + parser.add_argument( + "--lora-adapter-path", + type=str, + default=None, + help="Path to load pre-trained LoRA adapter weights (default: None)", + ) + parser.add_argument( + "--lora-sync-from-tensor", + action="store_true", + default=False, + help="Sync LoRA weights via tensor instead of file (more efficient)", + ) + return parser + def add_router_arguments(parser): parser.add_argument( "--use-miles-router", @@ -1396,6 +1462,7 @@ def add_sglang_tp_size(): parser = add_data_arguments(parser) parser = add_eval_arguments(parser) parser = add_algo_arguments(parser) + parser = add_lora_arguments(parser) parser = add_wandb_arguments(parser) parser = add_tensorboard_arguments(parser) parser = add_router_arguments(parser) @@ -1570,6 +1637,27 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." + # Parse LoRA target modules + if args.lora_rank > 0: + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." 
+ + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/tests/e2e/lora/test_lora_qwen2.5_0.5B.py b/tests/e2e/lora/test_lora_qwen2.5_0.5B.py new file mode 100644 index 000000000..31ca5f19e --- /dev/null +++ b/tests/e2e/lora/test_lora_qwen2.5_0.5B.py @@ -0,0 +1,135 @@ +"""E2E test for LoRA training with Qwen2.5-0.5B on GSM8K. + +Uses the Megatron backend with bridge mode. Runs a short GRPO training loop +with LoRA enabled (rank=32, all-linear) to validate: + - LoRA model setup via Bridge + - LoRA weight sync to SGLang rollout engines + - LoRA checkpoint save (native + HF PEFT format) + - Training completes without errors + +Requires: 8 GPUs, Qwen2.5-0.5B-Instruct model, GSM8K dataset. 
+Triggered by label: run-ci-lora +""" + +import os + +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command("hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " "--megatron-to-hf-mode bridge " + + lora_args = "--lora-rank 32 " "--lora-alpha 32 " "--lora-dropout 0.0 " '--target-modules "all-linear" ' + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1.0 " + "--global-batch-size 32 " + ) + + eval_args = ( + f"{'--eval-interval 2 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 4096 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-5 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + 
) + + sglang_args = "--rollout-num-gpus-per-engine 1 " "--sglang-mem-fraction-static 0.4 " + + ci_args = "--ci-test " + + save_args = "--save-interval 2 " "--save /root/checkpoints/lora-qwen2.5-0.5B-ci " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--calculate-per-token-loss " + "--use-miles-router " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{lora_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{save_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/fast/backends/__init__.py b/tests/fast/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/backends/megatron_utils/__init__.py b/tests/fast/backends/megatron_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/backends/megatron_utils/test_lora_checkpoint_helpers.py b/tests/fast/backends/megatron_utils/test_lora_checkpoint_helpers.py new file mode 100644 index 000000000..2a4baab3b --- /dev/null +++ b/tests/fast/backends/megatron_utils/test_lora_checkpoint_helpers.py @@ -0,0 +1,97 @@ +"""Unit tests for LoRA-related helpers in miles.backends.megatron_utils.checkpoint. + +Covers pure path-detection functions and the LoRA branch routing in +save_checkpoint_with_lora / load_checkpoint — the latter using mocks to avoid +GPU / distributed requirements. 
+""" + +import sys +from argparse import Namespace +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert(0, "/root/Megatron-LM") + +from miles.backends.megatron_utils.checkpoint import _is_megatron_checkpoint, save_checkpoint_with_lora + +# --------------------------------------------------------------------------- +# _is_megatron_checkpoint +# --------------------------------------------------------------------------- + + +class TestIsMegatronCheckpoint: + def test_has_latest_file(self, tmp_path): + (tmp_path / "latest_checkpointed_iteration.txt").write_text("100") + assert _is_megatron_checkpoint(tmp_path) is True + + def test_iter_dir_name(self, tmp_path): + iter_dir = tmp_path / "iter_0000100" + iter_dir.mkdir() + assert _is_megatron_checkpoint(iter_dir) is True + + def test_regular_dir(self, tmp_path): + assert _is_megatron_checkpoint(tmp_path) is False + + def test_hf_checkpoint_dir(self, tmp_path): + (tmp_path / "config.json").write_text("{}") + (tmp_path / "model.safetensors").write_text("") + assert _is_megatron_checkpoint(tmp_path) is False + + @pytest.mark.parametrize( + "name", + [ + "iter_0000001", + "iter_0000000", + "iter_9999999", + ], + ) + def test_valid_iter_patterns(self, tmp_path, name): + d = tmp_path / name + d.mkdir() + assert _is_megatron_checkpoint(d) is True + + @pytest.mark.parametrize( + "name", + [ + "iter_123", # too short + "iter_00000001", # too long + "iteration_0000001", + "checkpoint", + ], + ) + def test_invalid_iter_patterns(self, tmp_path, name): + d = tmp_path / name + d.mkdir() + assert _is_megatron_checkpoint(d) is False + + +# --------------------------------------------------------------------------- +# save_checkpoint_with_lora — branch routing +# --------------------------------------------------------------------------- + + +class TestSaveCheckpointWithLoRA: + @patch("miles.backends.megatron_utils.checkpoint.get_args") + @patch("miles.backends.megatron_utils.checkpoint.save_lora_checkpoint") + 
@patch("miles.backends.megatron_utils.checkpoint.is_lora_model", return_value=True) + def test_lora_model_saves_adapter(self, mock_is_lora, mock_save_lora, mock_get_args, tmp_path): + mock_get_args.return_value = Namespace(save=str(tmp_path)) + model = [MagicMock()] + + save_checkpoint_with_lora(42, model, MagicMock(), MagicMock()) + + mock_save_lora.assert_called_once() + call_args = mock_save_lora.call_args + assert "adapter" in call_args[1].get("save_dir", call_args[0][2] if len(call_args[0]) > 2 else "") + + @patch("miles.backends.megatron_utils.checkpoint.get_args") + @patch("miles.backends.megatron_utils.checkpoint.save_checkpoint") + @patch("miles.backends.megatron_utils.checkpoint.is_lora_model", return_value=False) + def test_non_lora_model_saves_regular(self, mock_is_lora, mock_save_ckpt, mock_get_args, tmp_path): + mock_get_args.return_value = Namespace(save=str(tmp_path)) + model = [MagicMock()] + + save_checkpoint_with_lora(42, model, MagicMock(), MagicMock()) + + mock_save_ckpt.assert_called_once() diff --git a/tests/fast/backends/megatron_utils/test_lora_hf_weight_iterator.py b/tests/fast/backends/megatron_utils/test_lora_hf_weight_iterator.py new file mode 100644 index 000000000..9319b4a9d --- /dev/null +++ b/tests/fast/backends/megatron_utils/test_lora_hf_weight_iterator.py @@ -0,0 +1,80 @@ +"""Unit tests for HfWeightIteratorBase factory routing with LoRA flag. + +Validates that the is_lora flag is correctly propagated through the factory +and that the right iterator subclass is selected based on megatron_to_hf_mode. 
+""" + +from argparse import Namespace +from unittest.mock import MagicMock, patch + +import pytest + +from miles.backends.megatron_utils.update_weight.hf_weight_iterator_base import HfWeightIteratorBase + + +class _ConcreteIterator(HfWeightIteratorBase): + def get_hf_weight_chunks(self, megatron_local_weights): + return [] + + +_BASE_MODULE = "miles.backends.megatron_utils.update_weight.hf_weight_iterator_base" + + +class TestHfWeightIteratorFactory: + def _make_args(self, mode="bridge"): + return Namespace( + megatron_to_hf_mode=mode, + hf_checkpoint="/fake/path", + update_weight_buffer_size=1, + ) + + @patch(f"{_BASE_MODULE}.HfWeightIteratorBase.__init__", return_value=None) + def test_bridge_mode_creates_bridge_iterator(self, mock_init): + """Factory should select HfWeightIteratorBridge for 'bridge' mode.""" + from miles.backends.megatron_utils.update_weight.hf_weight_iterator_bridge import HfWeightIteratorBridge + + with patch.object(HfWeightIteratorBridge, "__init__", return_value=None): + args = self._make_args("bridge") + iterator = HfWeightIteratorBase.create( + args=args, model=[MagicMock()], is_lora=True, model_name="qwen", quantization_config=None + ) + assert isinstance(iterator, HfWeightIteratorBridge) + + @patch(f"{_BASE_MODULE}.HfWeightIteratorBase.__init__", return_value=None) + def test_raw_mode_creates_direct_iterator(self, mock_init): + """Factory should select HfWeightIteratorDirect for 'raw' mode.""" + from miles.backends.megatron_utils.update_weight.hf_weight_iterator_direct import HfWeightIteratorDirect + + with patch.object(HfWeightIteratorDirect, "__init__", return_value=None): + args = self._make_args("raw") + iterator = HfWeightIteratorBase.create( + args=args, model=[MagicMock()], is_lora=False, model_name="qwen", quantization_config=None + ) + assert isinstance(iterator, HfWeightIteratorDirect) + + def test_invalid_mode_raises(self): + args = self._make_args("invalid_mode") + with pytest.raises(KeyError): + HfWeightIteratorBase.create( 
+ args=args, model=[MagicMock()], is_lora=False, model_name="qwen", quantization_config=None + ) + + def test_is_lora_stored_on_instance(self): + """Verify is_lora attribute is stored on the base class.""" + instance = _ConcreteIterator( + args=self._make_args(), + model=[MagicMock()], + model_name="qwen", + quantization_config=None, + is_lora=True, + ) + assert instance.is_lora is True + + def test_is_lora_default_false(self): + instance = _ConcreteIterator( + args=self._make_args(), + model=[MagicMock()], + model_name="qwen", + quantization_config=None, + ) + assert instance.is_lora is False diff --git a/tests/fast/backends/megatron_utils/test_lora_model_branches.py b/tests/fast/backends/megatron_utils/test_lora_model_branches.py new file mode 100644 index 000000000..dc555bbdb --- /dev/null +++ b/tests/fast/backends/megatron_utils/test_lora_model_branches.py @@ -0,0 +1,203 @@ +"""Mock-based tests for LoRA branch logic in miles.backends.megatron_utils.model. + +Validates that setup_model_and_optimizer, save, and save_hf_model correctly +route to LoRA-specific code paths depending on configuration — without GPU. 
+""" + +from argparse import Namespace +from unittest.mock import MagicMock, patch + + +# --------------------------------------------------------------------------- +# _ensure_model_list +# --------------------------------------------------------------------------- + + +class TestEnsureModelList: + def test_list_passthrough(self): + from miles.backends.megatron_utils.model import _ensure_model_list + + models = [MagicMock(), MagicMock()] + assert _ensure_model_list(models) is models + + def test_non_list_wrapped(self): + from miles.backends.megatron_utils.model import _ensure_model_list + + model = MagicMock() + result = _ensure_model_list(model) + assert isinstance(result, list) + assert result[0] is model + + +# --------------------------------------------------------------------------- +# should_disable_forward_pre_hook +# --------------------------------------------------------------------------- + + +class TestShouldDisableForwardPreHook: + def test_both_true(self): + from miles.backends.megatron_utils.model import should_disable_forward_pre_hook + + args = Namespace(use_distributed_optimizer=True, overlap_param_gather=True) + assert should_disable_forward_pre_hook(args) is True + + def test_optimizer_false(self): + from miles.backends.megatron_utils.model import should_disable_forward_pre_hook + + args = Namespace(use_distributed_optimizer=False, overlap_param_gather=True) + assert should_disable_forward_pre_hook(args) is False + + def test_overlap_false(self): + from miles.backends.megatron_utils.model import should_disable_forward_pre_hook + + args = Namespace(use_distributed_optimizer=True, overlap_param_gather=False) + assert should_disable_forward_pre_hook(args) is False + + +# --------------------------------------------------------------------------- +# setup_model_and_optimizer — LoRA branch routing +# --------------------------------------------------------------------------- + + +_MODEL_MODULE = "miles.backends.megatron_utils.model" + + +class 
TestSetupModelAndOptimizerLoraBranch: + """Verify that LoRA-enabled actor + bridge mode routes to _setup_lora_model_via_bridge.""" + + def _make_args(self, lora_rank=32, role="actor", mode="bridge"): + return Namespace( + lora_rank=lora_rank, + lora_adapter_path=None, + megatron_to_hf_mode=mode, + moe_use_upcycling=False, + load="/some/path", + pretrained_checkpoint=None, + # optimizer fields + num_rollout=10, + rollout_batch_size=8, + n_samples_per_prompt=8, + global_batch_size=32, + lr_decay_iters=None, + lr_wsd_decay_iters=None, + lr_warmup_fraction=None, + lr_warmup_iters=0, + lr_warmup_init=0, + lr=1e-5, + min_lr=0, + lr_decay_style="constant", + start_weight_decay=0, + end_weight_decay=0, + weight_decay_incr_style="constant", + use_checkpoint_opt_param_scheduler=False, + override_opt_param_scheduler=False, + lr_wsd_decay_style="linear", + enable_gloo_process_groups=False, + ) + + @patch(f"{_MODEL_MODULE}.get_optimizer_param_scheduler") + @patch(f"{_MODEL_MODULE}.get_megatron_optimizer") + @patch(f"{_MODEL_MODULE}._setup_lora_model_via_bridge") + def test_lora_actor_bridge_routes_to_lora_setup(self, mock_lora_setup, mock_opt, mock_sched): + from miles.backends.megatron_utils.model import setup_model_and_optimizer + + mock_lora_setup.return_value = [MagicMock()] + mock_opt.return_value = MagicMock(param_groups=[]) + mock_sched.return_value = MagicMock() + + args = self._make_args(lora_rank=32, role="actor", mode="bridge") + model, _, _ = setup_model_and_optimizer(args, role="actor") + + mock_lora_setup.assert_called_once_with(args) + + @patch(f"{_MODEL_MODULE}.get_optimizer_param_scheduler") + @patch(f"{_MODEL_MODULE}.get_megatron_optimizer") + @patch(f"{_MODEL_MODULE}.get_model") + @patch(f"{_MODEL_MODULE}.get_model_provider_func") + @patch(f"{_MODEL_MODULE}._setup_lora_model_via_bridge") + def test_lora_critic_skips_lora_setup(self, mock_lora_setup, mock_provider, mock_get_model, mock_opt, mock_sched): + from miles.backends.megatron_utils.model import 
setup_model_and_optimizer + + mock_get_model.return_value = [MagicMock()] + mock_opt.return_value = MagicMock(param_groups=[]) + mock_sched.return_value = MagicMock() + + args = self._make_args(lora_rank=32, role="critic", mode="bridge") + setup_model_and_optimizer(args, role="critic") + + mock_lora_setup.assert_not_called() + mock_get_model.assert_called_once() + + @patch(f"{_MODEL_MODULE}.get_optimizer_param_scheduler") + @patch(f"{_MODEL_MODULE}.get_megatron_optimizer") + @patch(f"{_MODEL_MODULE}.get_model") + @patch(f"{_MODEL_MODULE}.get_model_provider_func") + @patch(f"{_MODEL_MODULE}._setup_lora_model_via_bridge") + def test_non_lora_skips_lora_setup(self, mock_lora_setup, mock_provider, mock_get_model, mock_opt, mock_sched): + from miles.backends.megatron_utils.model import setup_model_and_optimizer + + mock_get_model.return_value = [MagicMock()] + mock_opt.return_value = MagicMock(param_groups=[]) + mock_sched.return_value = MagicMock() + + args = self._make_args(lora_rank=0, role="actor", mode="bridge") + setup_model_and_optimizer(args, role="actor") + + mock_lora_setup.assert_not_called() + mock_get_model.assert_called_once() + + @patch(f"{_MODEL_MODULE}.get_optimizer_param_scheduler") + @patch(f"{_MODEL_MODULE}.get_megatron_optimizer") + @patch(f"{_MODEL_MODULE}.get_model") + @patch(f"{_MODEL_MODULE}._setup_lora_model_via_bridge") + def test_lora_raw_mode_skips_bridge(self, mock_lora_setup, mock_get_model, mock_opt, mock_sched): + from miles.backends.megatron_utils.model import setup_model_and_optimizer + + mock_get_model.return_value = [MagicMock()] + mock_opt.return_value = MagicMock(param_groups=[]) + mock_sched.return_value = MagicMock() + + args = self._make_args(lora_rank=32, role="actor", mode="raw") + setup_model_and_optimizer(args, role="actor") + + mock_lora_setup.assert_not_called() + mock_get_model.assert_called_once() + + +# --------------------------------------------------------------------------- +# save — LoRA vs regular branch +# 
--------------------------------------------------------------------------- + + +class TestSaveLoRaBranch: + @patch(f"{_MODEL_MODULE}.enable_forward_pre_hook") + @patch(f"{_MODEL_MODULE}.disable_forward_pre_hook") + @patch(f"{_MODEL_MODULE}.should_disable_forward_pre_hook", return_value=False) + @patch(f"{_MODEL_MODULE}.get_args") + @patch(f"{_MODEL_MODULE}.save_checkpoint_with_lora") + @patch(f"{_MODEL_MODULE}.is_lora_model", return_value=True) + def test_lora_model_calls_lora_save( + self, mock_is_lora, mock_save_lora, mock_get_args, mock_should, mock_disable, mock_enable + ): + from miles.backends.megatron_utils.model import save + + model = [MagicMock()] + save(42, model, MagicMock(), MagicMock()) + + mock_save_lora.assert_called_once() + + @patch(f"{_MODEL_MODULE}.enable_forward_pre_hook") + @patch(f"{_MODEL_MODULE}.disable_forward_pre_hook") + @patch(f"{_MODEL_MODULE}.should_disable_forward_pre_hook", return_value=False) + @patch(f"{_MODEL_MODULE}.get_args") + @patch(f"{_MODEL_MODULE}.save_checkpoint") + @patch(f"{_MODEL_MODULE}.is_lora_model", return_value=False) + def test_non_lora_model_calls_regular_save( + self, mock_is_lora, mock_save_ckpt, mock_get_args, mock_should, mock_disable, mock_enable + ): + from miles.backends.megatron_utils.model import save + + model = [MagicMock()] + save(42, model, MagicMock(), MagicMock()) + + mock_save_ckpt.assert_called_once() diff --git a/tests/fast/backends/megatron_utils/test_lora_update_weight.py b/tests/fast/backends/megatron_utils/test_lora_update_weight.py new file mode 100644 index 000000000..c4b3f328a --- /dev/null +++ b/tests/fast/backends/megatron_utils/test_lora_update_weight.py @@ -0,0 +1,120 @@ +"""Mock-based tests for LoRA weight-sync logic in update_weight_from_tensor.py. + +Validates that _send_hf_params correctly separates LoRA vs base weights +and that UpdateWeightFromTensor initialises _lora_config only when LoRA is active. 
+""" + +from argparse import Namespace +from unittest.mock import MagicMock, patch + +import torch + +from miles.backends.megatron_utils.lora_utils import is_lora_weight_name + +# --------------------------------------------------------------------------- +# LoRA / base weight separation (pure logic, no distributed deps) +# --------------------------------------------------------------------------- + + +class TestLoraWeightSeparation: + """Test the filtering logic that _send_hf_params relies on.""" + + SAMPLE_WEIGHTS = [ + ("model.layers.0.self_attn.q_proj.weight", torch.randn(4, 4)), + ("model.layers.0.self_attn.q_proj.lora_A.weight", torch.randn(4, 2)), + ("model.layers.0.self_attn.q_proj.lora_B.weight", torch.randn(2, 4)), + ("model.layers.0.mlp.gate_proj.weight", torch.randn(8, 4)), + ("model.layers.0.mlp.gate_proj.lora_A.weight", torch.randn(8, 2)), + ("model.layers.0.mlp.gate_proj.lora_B.weight", torch.randn(2, 8)), + ] + + def test_separation_when_lora(self): + base = [(n, t) for n, t in self.SAMPLE_WEIGHTS if not is_lora_weight_name(n)] + lora = [(n, t) for n, t in self.SAMPLE_WEIGHTS if is_lora_weight_name(n)] + assert len(base) == 2 + assert len(lora) == 4 + + def test_no_separation_when_not_lora(self): + base = self.SAMPLE_WEIGHTS + lora = [] + assert len(base) == 6 + assert len(lora) == 0 + + def test_lora_names_contain_lora_A_or_B(self): + lora = [(n, t) for n, t in self.SAMPLE_WEIGHTS if is_lora_weight_name(n)] + for name, _ in lora: + assert ".lora_A." in name or ".lora_B." in name + + def test_base_names_do_not_contain_lora(self): + base = [(n, t) for n, t in self.SAMPLE_WEIGHTS if not is_lora_weight_name(n)] + for name, _ in base: + assert ".lora_A." not in name + assert ".lora_B." 
not in name + + +# --------------------------------------------------------------------------- +# UpdateWeightFromTensor._lora_config initialisation +# --------------------------------------------------------------------------- + + +_UW_MODULE = "miles.backends.megatron_utils.update_weight.update_weight_from_tensor" + + +class TestUpdateWeightFromTensorLoraConfig: + """Verify _lora_config is set only when is_lora=True.""" + + def _make_args(self): + return Namespace( + lora_rank=32, + lora_alpha=32, + lora_dropout=0.0, + target_modules=["linear_qkv", "linear_proj"], + megatron_to_hf_mode="bridge", + rollout_num_gpus_per_engine=2, + hf_checkpoint="/fake/path", + update_weight_buffer_size=1, + ) + + @patch(f"{_UW_MODULE}.dist") + @patch(f"{_UW_MODULE}.HfWeightIteratorBase") + def test_lora_true_sets_config(self, mock_iter_base, mock_dist): + from miles.backends.megatron_utils.update_weight.update_weight_from_tensor import UpdateWeightFromTensor + + mock_dist.get_world_size.return_value = 2 + mock_dist.get_rank.return_value = 0 + mock_dist.new_group.return_value = MagicMock() + mock_iter_base.create.return_value = MagicMock() + + args = self._make_args() + updater = UpdateWeightFromTensor( + args=args, + model=[MagicMock()], + weights_getter=lambda: {}, + model_name="qwen", + quantization_config=None, + is_lora=True, + ) + assert updater._lora_config is not None + assert updater._lora_config["peft_type"] == "LORA" + assert updater._lora_config["r"] == 32 + + @patch(f"{_UW_MODULE}.dist") + @patch(f"{_UW_MODULE}.HfWeightIteratorBase") + def test_lora_false_no_config(self, mock_iter_base, mock_dist): + from miles.backends.megatron_utils.update_weight.update_weight_from_tensor import UpdateWeightFromTensor + + mock_dist.get_world_size.return_value = 2 + mock_dist.get_rank.return_value = 0 + mock_dist.new_group.return_value = MagicMock() + mock_iter_base.create.return_value = MagicMock() + + args = self._make_args() + updater = UpdateWeightFromTensor( + args=args, + 
model=[MagicMock()], + weights_getter=lambda: {}, + model_name="qwen", + quantization_config=None, + is_lora=False, + ) + assert updater._lora_config is None diff --git a/tests/fast/backends/megatron_utils/test_lora_utils.py b/tests/fast/backends/megatron_utils/test_lora_utils.py new file mode 100644 index 000000000..779ade00e --- /dev/null +++ b/tests/fast/backends/megatron_utils/test_lora_utils.py @@ -0,0 +1,335 @@ +"""Unit tests for miles.backends.megatron_utils.lora_utils. + +Tests cover module name conversion, LoRA detection helpers, parameter identification, +exclude-module parsing, and LoRA sync config building — all without GPU. +""" + +from argparse import Namespace +from unittest.mock import MagicMock + +import pytest + +from miles.backends.megatron_utils.lora_utils import ( + LORA_ADAPTER_NAME, + _get_lora_class_name, + _is_adapter_param_name, + build_lora_sync_config, + convert_target_modules_to_hf, + convert_target_modules_to_megatron, + is_lora_enabled, + is_lora_weight_name, + parse_exclude_modules, +) + + +# --------------------------------------------------------------------------- +# _get_lora_class_name +# --------------------------------------------------------------------------- + + +class TestGetLoraClassName: + def test_none_returns_canonical(self): + assert _get_lora_class_name(None) == "CanonicalLoRA" + + def test_type_returns_class_name(self): + class FakeLoRA: + pass + + assert _get_lora_class_name(FakeLoRA) == "FakeLoRA" + + def test_instance_returns_class_name(self): + class FakeLoRA: + pass + + assert _get_lora_class_name(FakeLoRA()) == "FakeLoRA" + + +# --------------------------------------------------------------------------- +# convert_target_modules_to_megatron +# --------------------------------------------------------------------------- + + +def _make_lora_type(name: str): + """Helper to create a mock lora_type whose class name matches *name*.""" + mock = MagicMock() + type(mock).__name__ = name + return mock + + +class 
TestConvertTargetModulesToMegatron: + # --- "all-linear" variants ------------------------------------------------ + + @pytest.mark.parametrize("shorthand", ["all", "all-linear", "all_linear"]) + def test_all_linear_string_canonical(self, shorthand): + result = convert_target_modules_to_megatron(shorthand, lora_type=None) + assert result == [ + "linear_q", + "linear_k", + "linear_v", + "linear_proj", + "linear_fc1_up", + "linear_fc1_gate", + "linear_fc2", + ] + + @pytest.mark.parametrize("shorthand", ["all", "all-linear", "all_linear"]) + def test_all_linear_string_standard(self, shorthand): + lora_type = _make_lora_type("LoRA") + result = convert_target_modules_to_megatron(shorthand, lora_type=lora_type) + assert result == ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + + @pytest.mark.parametrize("shorthand", ["all", "all-linear", "all_linear"]) + def test_all_linear_single_element_list(self, shorthand): + result = convert_target_modules_to_megatron([shorthand], lora_type=None) + assert len(result) == 7 # CanonicalLoRA has 7 modules + + # --- HF -> Megatron conversion (standard LoRA) ---------------------------- + + def test_hf_to_megatron_standard_dedup(self): + lora = _make_lora_type("LoRA") + result = convert_target_modules_to_megatron(["q_proj", "k_proj", "v_proj"], lora_type=lora) + assert result == ["linear_qkv"] + + def test_hf_to_megatron_standard_all_modules(self): + lora = _make_lora_type("LoRA") + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + result = convert_target_modules_to_megatron(modules, lora_type=lora) + assert result == ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + + # --- HF -> Megatron conversion (CanonicalLoRA) ---------------------------- + + def test_hf_to_megatron_canonical_split(self): + result = convert_target_modules_to_megatron(["q_proj", "k_proj", "v_proj"], lora_type=None) + assert result == ["linear_q", "linear_k", "linear_v"] + + def 
test_hf_to_megatron_canonical_gate_up(self): + result = convert_target_modules_to_megatron(["gate_proj", "up_proj"], lora_type=None) + assert result == ["linear_fc1_gate", "linear_fc1_up"] + + # --- Already in Megatron format ------------------------------------------- + + def test_megatron_format_passthrough(self): + modules = ["linear_qkv", "linear_proj"] + result = convert_target_modules_to_megatron(modules, lora_type=None) + assert result == modules + + def test_megatron_format_passthrough_canonical(self): + modules = ["linear_q", "linear_k", "linear_v"] + result = convert_target_modules_to_megatron(modules, lora_type=None) + assert result == modules + + # --- Single string input -------------------------------------------------- + + def test_single_hf_string_input(self): + lora = _make_lora_type("LoRA") + result = convert_target_modules_to_megatron("o_proj", lora_type=lora) + assert result == ["linear_proj"] + + +# --------------------------------------------------------------------------- +# convert_target_modules_to_hf +# --------------------------------------------------------------------------- + + +class TestConvertTargetModulesToHf: + def test_standard_linear_qkv(self): + assert convert_target_modules_to_hf(["linear_qkv"]) == ["q_proj", "k_proj", "v_proj"] + + def test_standard_linear_proj(self): + assert convert_target_modules_to_hf(["linear_proj"]) == ["o_proj"] + + def test_standard_linear_fc1(self): + assert convert_target_modules_to_hf(["linear_fc1"]) == ["gate_proj", "up_proj"] + + def test_standard_linear_fc2(self): + assert convert_target_modules_to_hf(["linear_fc2"]) == ["down_proj"] + + def test_canonical_split_modules(self): + result = convert_target_modules_to_hf(["linear_q", "linear_k", "linear_v"]) + assert result == ["q_proj", "k_proj", "v_proj"] + + def test_canonical_fc1_gate_up(self): + result = convert_target_modules_to_hf(["linear_fc1_gate", "linear_fc1_up"]) + assert result == ["gate_proj", "up_proj"] + + def 
test_unknown_module_passthrough(self): + assert convert_target_modules_to_hf(["some_custom_module"]) == ["some_custom_module"] + + def test_roundtrip_canonical_all_linear(self): + megatron = convert_target_modules_to_megatron("all-linear", lora_type=None) + hf = convert_target_modules_to_hf(megatron) + assert set(hf) == {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"} + + def test_roundtrip_standard_all_linear(self): + lora = _make_lora_type("LoRA") + megatron = convert_target_modules_to_megatron("all-linear", lora_type=lora) + hf = convert_target_modules_to_hf(megatron) + assert set(hf) == {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"} + + +# --------------------------------------------------------------------------- +# is_lora_enabled +# --------------------------------------------------------------------------- + + +class TestIsLoraEnabled: + def test_enabled_by_rank(self): + args = Namespace(lora_rank=32, lora_adapter_path=None) + assert is_lora_enabled(args) is True + + def test_enabled_by_adapter_path(self): + args = Namespace(lora_rank=0, lora_adapter_path="/some/path") + assert is_lora_enabled(args) is True + + def test_enabled_by_both(self): + args = Namespace(lora_rank=16, lora_adapter_path="/some/path") + assert is_lora_enabled(args) is True + + def test_disabled(self): + args = Namespace(lora_rank=0, lora_adapter_path=None) + assert is_lora_enabled(args) is False + + def test_disabled_missing_attrs(self): + args = Namespace() + assert is_lora_enabled(args) is False + + +# --------------------------------------------------------------------------- +# is_lora_weight_name / _is_adapter_param_name +# --------------------------------------------------------------------------- + + +class TestIsLoraWeightName: + @pytest.mark.parametrize( + "name", + [ + "model.layers.0.self_attn.q_proj.lora_A.weight", + "model.layers.0.self_attn.q_proj.lora_B.weight", + 
"base_model.model.layers.5.mlp.gate_proj.lora_A.default.weight", + "base_model.model.layers.5.mlp.gate_proj.lora_B.default.weight", + ], + ) + def test_positive(self, name): + assert is_lora_weight_name(name) is True + + @pytest.mark.parametrize( + "name", + [ + "model.layers.0.self_attn.q_proj.weight", + "model.embed_tokens.weight", + "lm_head.weight", + "model.layers.0.mlp.gate_proj.weight", + ], + ) + def test_negative(self, name): + assert is_lora_weight_name(name) is False + + +class TestIsAdapterParamName: + @pytest.mark.parametrize( + "name", + [ + "module.decoder.layers.0.self_attention.linear_qkv.lora_A.weight", + "module.decoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight", + "module.decoder.layers.0.self_attention.linear_qkv.adapter.linear_out.weight", + ], + ) + def test_positive(self, name): + assert _is_adapter_param_name(name) is True + + @pytest.mark.parametrize( + "name", + [ + "module.decoder.layers.0.self_attention.linear_qkv.weight", + "module.decoder.layers.0.mlp.linear_fc1.weight", + "module.embedding.word_embeddings.weight", + ], + ) + def test_negative(self, name): + assert _is_adapter_param_name(name) is False + + +# --------------------------------------------------------------------------- +# parse_exclude_modules +# --------------------------------------------------------------------------- + + +class TestParseExcludeModules: + def test_none(self): + args = Namespace(exclude_modules=None) + assert parse_exclude_modules(args) == [] + + def test_single_module_string(self): + args = Namespace(exclude_modules="o_proj") + result = parse_exclude_modules(args, lora_type=_make_lora_type("LoRA")) + assert result == ["linear_proj"] + + def test_comma_separated(self): + args = Namespace(exclude_modules="o_proj, down_proj") + result = parse_exclude_modules(args, lora_type=_make_lora_type("LoRA")) + assert set(result) == {"linear_proj", "linear_fc2"} + + def test_list_input(self): + args = Namespace(exclude_modules=["o_proj", 
"down_proj"]) + result = parse_exclude_modules(args, lora_type=_make_lora_type("LoRA")) + assert set(result) == {"linear_proj", "linear_fc2"} + + def test_missing_attr(self): + args = Namespace() + assert parse_exclude_modules(args) == [] + + +# --------------------------------------------------------------------------- +# build_lora_sync_config +# --------------------------------------------------------------------------- + + +class TestBuildLoraSyncConfig: + def test_basic_config(self): + args = Namespace( + lora_rank=32, + lora_alpha=32, + lora_dropout=0.0, + target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"], + ) + config = build_lora_sync_config(args) + assert config["peft_type"] == "LORA" + assert config["r"] == 32 + assert config["lora_alpha"] == 32 + assert config["lora_dropout"] == 0.0 + assert config["bias"] == "none" + assert config["task_type"] == "CAUSAL_LM" + assert set(config["target_modules"]) == { + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + } + + def test_no_target_modules_uses_default(self): + args = Namespace(lora_rank=16, lora_alpha=16, lora_dropout=0.0, target_modules=None) + config = build_lora_sync_config(args) + assert len(config["target_modules"]) == 7 + + def test_canonical_target_modules(self): + args = Namespace( + lora_rank=8, + lora_alpha=8, + lora_dropout=0.1, + target_modules=["linear_q", "linear_k"], + ) + config = build_lora_sync_config(args) + assert config["target_modules"] == ["q_proj", "k_proj"] + assert config["r"] == 8 + + +# --------------------------------------------------------------------------- +# LORA_ADAPTER_NAME constant +# --------------------------------------------------------------------------- + + +def test_lora_adapter_name_constant(): + assert LORA_ADAPTER_NAME == "miles_lora" diff --git a/tests/fast/utils/test_lora_arguments.py b/tests/fast/utils/test_lora_arguments.py new file mode 100644 index 000000000..f4e78839e --- /dev/null +++ 
b/tests/fast/utils/test_lora_arguments.py @@ -0,0 +1,119 @@ +"""Unit tests for LoRA-related argument parsing in miles.utils.arguments. + +Covers the target-module expansion and exclude-module filtering logic +inside miles_validate_args (lines 1634-1653 of arguments.py). +We isolate the LoRA parsing logic to avoid triggering unrelated validations. +""" + +from argparse import Namespace +from copy import deepcopy + +import pytest + + +def _apply_lora_arg_parsing(args: Namespace) -> Namespace: + """Extract and apply only the LoRA target-module parsing logic from + miles_validate_args, avoiding unrelated assertions.""" + args = deepcopy(args) + if args.lora_rank > 0: + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." + + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + return args + + +# --------------------------------------------------------------------------- +# Target modules expansion +# --------------------------------------------------------------------------- + + +class TestLoraTargetModuleParsing: + def test_all_linear_expands_to_seven_modules(self): + args = Namespace(lora_rank=32, target_modules="all-linear", exclude_modules=None) + result = _apply_lora_arg_parsing(args) + assert result.target_modules == ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + + def test_comma_separated_split(self): + args = Namespace(lora_rank=16, target_modules="q_proj, k_proj, v_proj", exclude_modules=None) + result = 
_apply_lora_arg_parsing(args) + assert result.target_modules == ["q_proj", "k_proj", "v_proj"] + + def test_comma_separated_no_spaces(self): + args = Namespace(lora_rank=16, target_modules="q_proj,k_proj", exclude_modules=None) + result = _apply_lora_arg_parsing(args) + assert result.target_modules == ["q_proj", "k_proj"] + + def test_single_module(self): + args = Namespace(lora_rank=8, target_modules="q_proj", exclude_modules=None) + result = _apply_lora_arg_parsing(args) + assert result.target_modules == ["q_proj"] + + def test_lora_rank_zero_skips_parsing(self): + args = Namespace(lora_rank=0, target_modules="all-linear", exclude_modules=None) + result = _apply_lora_arg_parsing(args) + assert result.target_modules == "all-linear" # unchanged + + def test_missing_target_modules_asserts(self): + args = Namespace(lora_rank=32, target_modules=None, exclude_modules=None) + with pytest.raises(AssertionError, match="--target-modules"): + _apply_lora_arg_parsing(args) + + +# --------------------------------------------------------------------------- +# Exclude modules filtering +# --------------------------------------------------------------------------- + + +class TestLoraExcludeModules: + def test_single_exclude(self): + args = Namespace(lora_rank=32, target_modules="all-linear", exclude_modules="o_proj") + result = _apply_lora_arg_parsing(args) + assert "o_proj" not in result.target_modules + assert len(result.target_modules) == 6 + + def test_multiple_exclude_comma_separated(self): + args = Namespace(lora_rank=32, target_modules="all-linear", exclude_modules="o_proj, down_proj") + result = _apply_lora_arg_parsing(args) + assert "o_proj" not in result.target_modules + assert "down_proj" not in result.target_modules + assert len(result.target_modules) == 5 + + def test_exclude_all_results_in_empty(self): + args = Namespace( + lora_rank=32, + target_modules="q_proj,k_proj", + exclude_modules="q_proj,k_proj", + ) + result = _apply_lora_arg_parsing(args) + assert 
result.target_modules == [] + + def test_exclude_nonexistent_module_no_effect(self): + args = Namespace(lora_rank=32, target_modules="q_proj,k_proj", exclude_modules="nonexistent") + result = _apply_lora_arg_parsing(args) + assert result.target_modules == ["q_proj", "k_proj"] + + def test_no_exclude_modules(self): + args = Namespace(lora_rank=32, target_modules="q_proj,k_proj", exclude_modules=None) + result = _apply_lora_arg_parsing(args) + assert result.target_modules == ["q_proj", "k_proj"] + + def test_empty_string_exclude(self): + """Empty string is falsy; the exclude branch is skipped and modules are unchanged.""" + args = Namespace(lora_rank=32, target_modules="q_proj,k_proj", exclude_modules="") + result = _apply_lora_arg_parsing(args) + assert result.target_modules == ["q_proj", "k_proj"] diff --git a/train.py b/train.py index 745dcbed6..bd22ba561 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ import ray +from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from miles.utils.arguments import parse_args @@ -70,7 +71,12 @@ def save(rollout_id): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - ray.get(rollout_manager.offload.remote()) + offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] + if "kv_cache" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) + if "weight" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) + ray.get(rollout_manager.offload.remote(tags=offload_tags)) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)