diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 4ff6988f1..6118bb9bb 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -508,6 +508,8 @@ def _log_rollout_data(self, rollout_id: int, rollout_data, packed_batches): if isinstance(unpacked_batch[metric_key], torch.Tensor): loss_masks_tensor = unpacked_batch["loss_masks"].to(device=torch.cuda.current_device()) metric_tensor = unpacked_batch[metric_key].to(device=torch.cuda.current_device()) + if metric_tensor.shape[0] == 0: + continue val += (metric_tensor * loss_masks_tensor).sum() / loss_masks_tensor.sum().clamp_min(1) else: val += unpacked_batch[metric_key] @@ -578,27 +580,42 @@ def _train_core(self, rollout_id: int, rollout_data) -> None: self.ref_model.load_state_dict(actor_state) self.ref_model.cpu() - def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): - # Prepare model inputs - model_args = self._get_model_inputs_args(packed_batch) - logits = self.model(**model_args).logits.squeeze(0).float() + def _compute_sft_loss(self, unpacked_batches, logits): + loss_masks = [batch["loss_masks"].to(device=logits.device) for batch in unpacked_batches] + response_lengths = [batch["response_lengths"] for batch in unpacked_batches] + log_probs = torch.cat([batch["cur_log_probs"] for batch in unpacked_batches], dim=0) + loss = -sum_of_sample_mean(log_probs, response_lengths, loss_masks) + # make sure the gradient could backprop correctly. + if log_probs.numel() == 0: + loss += 0 * logits.sum() - # Compute log probs and entropy (unified for both CP and non-CP modes) - log_probs, entropy_result = get_logprob_and_entropy_with_cp( - logits=logits, - target_tokens=packed_batch["tokens"], - cp_rank=self.cp_rank, - cp_size=self.cp_size, - cp_group=self.cp_group, - model_input_ids=model_args["input_ids"], - allow_compile=not self.args.true_on_policy_mode, - temperature=self.args.rollout_temperature, - ) - packed_batch["cur_log_probs"] = log_probs - packed_batch["entropy"] = entropy_result + kl_loss = 0 + if self.args.use_kl_loss: + old_log_prob_key = "rollout_log_probs" if self.args.use_rollout_logprobs else "log_probs" + missing_old_log_probs = [ + idx + for idx, batch in enumerate(unpacked_batches) + if old_log_prob_key not in batch or not isinstance(batch[old_log_prob_key], torch.Tensor) + ] + if missing_old_log_probs: + raise KeyError( + f"{old_log_prob_key} must be provided as torch.Tensor for all microbatches when " + f"use_rollout_logprobs is set to {self.args.use_rollout_logprobs}. Missing in batches: {missing_old_log_probs}" + ) + old_log_probs = torch.cat([batch[old_log_prob_key] for batch in unpacked_batches], dim=0) + old_log_probs = old_log_probs.to(device=log_probs.device) + kl_loss = self._compute_kl_loss(unpacked_batches, log_probs, old_log_probs, response_lengths, loss_masks) + loss += kl_loss - unpacked_batches = unpack_sequences(packed_batch) + reported = { + "loss": loss.detach(), + } + if self.args.use_kl_loss: + reported["kl_loss"] = kl_loss.detach() + return loss, reported + + def _compute_policy_loss(self, unpacked_batches, logits): old_log_prob_key = "rollout_log_probs" if self.args.use_rollout_logprobs else "log_probs" missing_old_log_probs = [ idx @@ -611,11 +628,11 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): f"use_rollout_logprobs is set to {self.args.use_rollout_logprobs}. 
Missing in batches: {missing_old_log_probs}" ) old_log_probs = torch.cat([batch[old_log_prob_key] for batch in unpacked_batches], dim=0) + loss_masks = [batch["loss_masks"].to(device=logits.device) for batch in unpacked_batches] + + response_lengths = [batch["response_lengths"] for batch in unpacked_batches] log_probs = torch.cat([batch["cur_log_probs"] for batch in unpacked_batches], dim=0) advantages = torch.cat([batch["advantages"] for batch in unpacked_batches], dim=0) - loss_masks = [batch["loss_masks"].to(device=log_probs.device) for batch in unpacked_batches] - response_lengths = [batch["response_lengths"] for batch in unpacked_batches] - advantages = advantages.to(device=log_probs.device) old_log_probs = old_log_probs.to(device=log_probs.device) ppo_kl = old_log_probs - log_probs @@ -668,7 +685,7 @@ def _has_rollout_log_probs(batch) -> bool: tis_clipfrac = tis_clip != tis pg_loss = pg_loss * tis_clip - + assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented" pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) @@ -687,20 +704,10 @@ def _has_rollout_log_probs(batch) -> bool: loss = pg_loss - self.args.entropy_coef * entropy_loss + kl_loss = 0 if self.args.use_kl_loss: - ref_log_probs = torch.cat([batch["ref_log_probs"] for batch in unpacked_batches], dim=0) - importance_ratio = None - if self.args.use_unbiased_kl: - importance_ratio = torch.exp(log_probs - old_log_probs) - kl = compute_approx_kl( - log_probs, - ref_log_probs, - kl_loss_type=self.args.kl_loss_type, - importance_ratio=importance_ratio, - ) - kl_loss = sum_of_sample_mean(kl, response_lengths, loss_masks) - - loss = loss + self.args.kl_loss_coef * kl_loss + kl_loss = self._compute_kl_loss(unpacked_batches, log_probs, old_log_probs, response_lengths, loss_masks) + loss += kl_loss reported = { "loss": loss.detach(), @@ -713,9 +720,6 @@ def _has_rollout_log_probs(batch) -> bool: if train_rollout_logprob_abs_diff is not None: reported["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff - if self.args.use_kl_loss: - reported["kl_loss"] = kl_loss.detach() - if self.args.use_opsm: reported["opsm_clipfrac"] = opsm_clipfrac @@ -724,6 +728,64 @@ def _has_rollout_log_probs(batch) -> bool: reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach() reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach() + if self.args.use_kl_loss: + reported["kl_loss"] = kl_loss.detach() + return loss, reported + + def _compute_kl_loss(self, unpacked_batches, log_probs, old_log_probs, response_lengths, loss_masks): + ref_log_probs = torch.cat([batch["ref_log_probs"] for batch in unpacked_batches], dim=0) + importance_ratio = None + if self.args.use_unbiased_kl: + importance_ratio = torch.exp(log_probs - old_log_probs) + kl = compute_approx_kl( + log_probs, + ref_log_probs, + kl_loss_type=self.args.kl_loss_type, + importance_ratio=importance_ratio, + ) + kl_loss = sum_of_sample_mean(kl, response_lengths, loss_masks) + + kl_term = self.args.kl_loss_coef * kl_loss + return kl_term + + def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): + # Prepare model inputs + model_args = self._get_model_inputs_args(packed_batch) + logits = self.model(**model_args).logits.squeeze(0).float() + + # Compute log probs and entropy (unified for both CP and non-CP modes) + log_probs, entropy_result = get_logprob_and_entropy_with_cp( + 
logits=logits, + target_tokens=packed_batch["tokens"], + cp_rank=self.cp_rank, + cp_size=self.cp_size, + cp_group=self.cp_group, + model_input_ids=model_args["input_ids"], + allow_compile=not self.args.true_on_policy_mode, + temperature=self.args.rollout_temperature, + ) + packed_batch["cur_log_probs"] = log_probs + packed_batch["entropy"] = entropy_result + + unpacked_batches = unpack_sequences(packed_batch) + + old_log_prob_key = "rollout_log_probs" if self.args.use_rollout_logprobs else "log_probs" + missing_old_log_probs = [ + idx + for idx, batch in enumerate(unpacked_batches) + if old_log_prob_key not in batch or not isinstance(batch[old_log_prob_key], torch.Tensor) + ] + if missing_old_log_probs: + raise KeyError( + f"{old_log_prob_key} must be provided as torch.Tensor for all microbatches when " + f"use_rollout_logprobs is set to {self.args.use_rollout_logprobs}. Missing in batches: {missing_old_log_probs}" + ) + + if self.args.loss_type == "sft_loss": + loss, reported = self._compute_sft_loss(unpacked_batches, logits) + else: + loss, reported = self._compute_policy_loss(unpacked_batches, logits) + # Scale loss for gradient accumulation loss = loss * self.dp_size / self.args.global_batch_size loss.backward() @@ -760,7 +822,7 @@ def _has_rollout_log_probs(batch) -> bool: # Log learning rate per parameter group; use scheduler's last computed LRs lr_values = self.lr_scheduler.get_last_lr() for gid, _group in enumerate(self.optimizer.param_groups): - log_dict[f"train/lr-pg_{gid}"] = lr_values[gid] + log_dict[f"train/lr_{gid}"] = lr_values[gid] kl_info = "" if self.args.use_kl_loss and "kl_loss" in aggregated: diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index b300ad167..13ea337c0 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -160,7 +160,7 @@ def _get_rollout_data(self, rollout_id): data = sum(data, []) if self.args.disable_rollout_trim_samples: - logger.info(f"Collectd {len(data)} samples from rollout to train") + logger.info(f"Collected {len(data)} samples from rollout to train") elif len(data) % self.args.global_batch_size != 0: trim_len = (len(data) // self.args.global_batch_size) * self.args.global_batch_size origin_data_length = len(data) @@ -332,7 +332,7 @@ def _split_train_data_by_dp(self, data, dp_size): def init_rollout_engines(args, pg, all_rollout_engines): if args.debug_train_only: - return 0, None + return 0 num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node) num_engines = args.rollout_num_gpus // num_gpu_per_engine @@ -412,7 +412,7 @@ def init_rollout_engines(args, pg, all_rollout_engines): num_new_engines = len(rollout_engines) if num_new_engines == 0: - return num_new_engines, None + return num_new_engines if args.rollout_external: addr_and_ports = _allocate_rollout_engine_addr_and_ports_external(args=args, rollout_engines=rollout_engines) diff --git a/scripts/run-qwen3-0.6B-torch-sft-crowd-code.sh b/scripts/run-qwen3-0.6B-torch-sft-crowd-code.sh new file mode 100644 index 000000000..689883bb6 --- /dev/null +++ b/scripts/run-qwen3-0.6B-torch-sft-crowd-code.sh @@ -0,0 +1,168 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +# set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONUNBUFFERED=1 +export CUDA_VISIBLE_DEVICES=0,1 + +NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK 
(detected $NVLINK_COUNT NVLink references)" + + +# --- 1. DYNAMIC HOST IP DETECTION (CRITICAL FOR SLURM) --- +# Don't hardcode IP. Get the actual IP of the current node. +export HEAD_NODE_IP=$(hostname -I | awk '{print $1}') +echo "Detected Head Node IP: ${HEAD_NODE_IP}" + +# --- 2. PROXY CONFIGURATION --- +# Ensure local traffic doesn't go through a corporate proxy +export no_proxy="${HEAD_NODE_IP},localhost,127.0.0.1,0.0.0.0" +export NO_PROXY="${HEAD_NODE_IP},localhost,127.0.0.1,0.0.0.0" + +# --- 3. DEBUGGING & STABILITY ENV VARS --- +# Force NCCL/Distributed into a robust mode to prevent initialization hangs +# export NCCL_P2P_DISABLE=1 +# export NCCL_IB_DISABLE=1 +export NCCL_DEBUG=INFO +export TORCH_DISTRIBUTED_DEBUG=INFO + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" + +RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"} +LOAD_SAVE_PATH="/fast/project/HFMI_SynergyUnit/tab_model/huggingface/shared_data/${RUN_ID}/checkpoints" + +CKPT_ARGS=( + --hf-checkpoint /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B + --load /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B + --ref-load /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B +) + +SFT_ARGS=( + --rollout-function-path miles.rollout.sft_rollout.generate_rollout + --prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/nemo_hf_part_jsonl_4k_tokens.parquet + --input-key messages + --apply-chat-template + --rollout-shuffle + --num-epoch 3 + --rollout-batch-size 16 + --global-batch-size 16 + + --loss-type sft_loss + --calculate-per-token-loss + --disable-compute-advantages-and-returns + --num-rollout 2000 + --debug-train-only +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style WSD + --lr-wsd-decay-style linear + --lr-warmup-iters 100 + --lr-decay-iters 2000 + --lr-wsd-decay-iters 500 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project crowd-pilot-miles + --wandb-team instant-uv + --wandb-group qwen3-0.6b-sft-torch +) + +TRAIN_BACKEND_ARGS=( + --train-backend fsdp + --update-weight-buffer-size 536870912 + --gradient-checkpointing + --attn-implementation flash_attention_3 + --train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' + --actor-num-gpus-per-node 2 +) + +PERF_ARGS=( + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +MISC_ARGS=( + --actor-num-nodes 1 + --actor-num-gpus-per-node 2 + --colocate + --rollout-max-context-len 8192 + --rollout-max-prompt-len 8000 + --rollout-max-response-len 8192 + --use-fault-tolerance + --dump-details /fast/project/HFMI_SynergyUnit/tab_model/huggingface/shared_data/qwen3-600M-fsdp-1116-noref/dump_details +) + +# launch the master node of ray in container - 2 GPUs for training +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +python3 -m ray.scripts.scripts start --head \ + --node-ip-address=${HEAD_NODE_IP} \ + --num-gpus 2 \ + --num-cpus 4 \ + --memory=214748364800 \ + --disable-usage-stats \ + --dashboard-host=0.0.0.0 \ + --dashboard-port=8265 \ + --port=6379 + +echo "Ray started. Waiting for Dashboard to be ready..." + +# --- 4. WAIT FOR DASHBOARD (FIX FOR 504 ERROR) --- +# Loop until the dashboard port accepts connections +for i in {1..30}; do + if curl -s "http://${HEAD_NODE_IP}:8265" > /dev/null; then + echo "Dashboard is up!" + break + fi + echo "Waiting for Ray Dashboard..." 
+ sleep 2 +done +# Add a small safety buffer +sleep 5 + +# Build runtime env +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/fast/project/HFMI_SynergyUnit/mihir/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\" + } +}" + +python3 -m ray.scripts.scripts job submit --address="http://${HEAD_NODE_IP}:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + ${CKPT_ARGS[@]} \ + ${SFT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${TRAIN_BACKEND_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${MISC_ARGS[@]} + + + diff --git a/scripts/run-qwen3-0.6B-torch-sft.sh b/scripts/run-qwen3-0.6B-torch-sft.sh new file mode 100644 index 000000000..c7afed1ba --- /dev/null +++ b/scripts/run-qwen3-0.6B-torch-sft.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +# set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONUNBUFFERED=1 +export CUDA_VISIBLE_DEVICES=0,1 + +NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + + +# --- 1. DYNAMIC HOST IP DETECTION (CRITICAL FOR SLURM) --- +# Don't hardcode IP. Get the actual IP of the current node. +export HEAD_NODE_IP=$(hostname -I | awk '{print $1}') +echo "Detected Head Node IP: ${HEAD_NODE_IP}" + +# --- 2. PROXY CONFIGURATION --- +# Ensure local traffic doesn't go through a corporate proxy +export no_proxy="${HEAD_NODE_IP},localhost,127.0.0.1,0.0.0.0" +export NO_PROXY="${HEAD_NODE_IP},localhost,127.0.0.1,0.0.0.0" + +# --- 3. 
DEBUGGING & STABILITY ENV VARS --- +# Force NCCL/Distributed into a robust mode to prevent initialization hangs +# export NCCL_P2P_DISABLE=1 +# export NCCL_IB_DISABLE=1 +export NCCL_DEBUG=INFO +export TORCH_DISTRIBUTED_DEBUG=INFO + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" + +RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"} +LOAD_SAVE_PATH="/fast/project/HFMI_SynergyUnit/tab_model/huggingface/shared_data/${RUN_ID}/checkpoints" + +CKPT_ARGS=( + --hf-checkpoint /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B + --load /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B + --ref-load /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B +) + +SFT_ARGS=( + --rollout-function-path miles.rollout.sft_rollout.generate_rollout + --prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/openhermes2_5.parquet + --input-key messages + --apply-chat-template + --rollout-shuffle + --num-epoch 3 + --rollout-batch-size 128 + --global-batch-size 128 + + --loss-type sft_loss + --calculate-per-token-loss + --disable-compute-advantages-and-returns + --debug-train-only +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project crowd-pilot-miles + --wandb-team instant-uv + --wandb-group qwen3-0.6b-sft-torch +) + +SGLANG_ARGS=( + +) + +TRAIN_BACKEND_ARGS=( + --train-backend fsdp + --update-weight-buffer-size 536870912 + --gradient-checkpointing + --attn-implementation flash_attention_3 + --train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' + --actor-num-gpus-per-node 2 +) + +PERF_ARGS=( + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +MISC_ARGS=( + --actor-num-nodes 1 + --actor-num-gpus-per-node 2 + --rollout-batch-size 128 + --colocate + --use-fault-tolerance + --dump-details /fast/project/HFMI_SynergyUnit/tab_model/huggingface/shared_data/qwen3-600M-fsdp-1116-noref/dump_details +) + +# launch the master node of ray in container - 2 GPUs for training +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +python3 -m ray.scripts.scripts start --head \ + --node-ip-address=${HEAD_NODE_IP} \ + --num-gpus 2 \ + --num-cpus 4 \ + --memory=214748364800 \ + --disable-usage-stats \ + --dashboard-host=0.0.0.0 \ + --dashboard-port=8265 \ + --port=6379 + +echo "Ray started. Waiting for Dashboard to be ready..." + +# --- 4. WAIT FOR DASHBOARD (FIX FOR 504 ERROR) --- +# Loop until the dashboard port accepts connections +for i in {1..30}; do + if curl -s "http://${HEAD_NODE_IP}:8265" > /dev/null; then + echo "Dashboard is up!" + break + fi + echo "Waiting for Ray Dashboard..." + sleep 2 +done +# Add a small safety buffer +sleep 5 + +# Build runtime env +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/fast/project/HFMI_SynergyUnit/mihir/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\" + } +}" + +python3 -m ray.scripts.scripts job submit --address="http://${HEAD_NODE_IP}:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + ${CKPT_ARGS[@]} \ + ${SFT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${TRAIN_BACKEND_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${MISC_ARGS[@]} + + +
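
A minimal sketch of the per-sample masked-mean aggregation that _compute_sft_loss and _compute_policy_loss lean on, assuming sum_of_sample_mean averages each sample's masked token values and then sums across the samples in the microbatch (the helper lives elsewhere in the repo and its exact implementation may differ):

import torch

def sum_of_sample_mean_sketch(values, response_lengths, loss_masks):
    # Split the flat packed token tensor back into per-sample chunks,
    # take a masked mean inside each sample, then sum across samples.
    total = values.new_zeros(())
    for chunk, mask in zip(torch.split(values, response_lengths), loss_masks):
        total = total + (chunk * mask).sum() / mask.sum().clamp_min(1)
    return total

# SFT objective on one packed microbatch: negative masked mean log-likelihood per sample.
log_probs = torch.tensor([-0.2, -0.5, -0.1, -0.3])  # token log-probs for two packed samples
response_lengths = [2, 2]
loss_masks = [torch.tensor([1.0, 1.0]), torch.tensor([1.0, 0.0])]
loss = -sum_of_sample_mean_sketch(log_probs, response_lengths, loss_masks)
print(loss)  # tensor(0.4500): 0.35 from the first sample, 0.10 from the second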
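
The 0 * logits.sum() guard in _compute_sft_loss presumably covers microbatches whose responses are entirely empty after masking: in that case the loss can end up as a constant with no path back to the model's forward output, and backward() either errors out or leaves FSDP waiting on gradients that never arrive. Adding a zero-weighted term keeps the graph connected without changing the loss value. A toy illustration of the trick (not the actor's exact code path):

import torch

w = torch.nn.Parameter(torch.ones(3))
logits = w * 2                      # stands in for the model forward output

# Degenerate microbatch: no response tokens, so the "loss" is just a constant.
loss = torch.zeros(())

# Without the next line, loss.backward() raises because nothing requires grad.
loss = loss + 0 * logits.sum()      # re-attach the graph without changing the value
loss.backward()                     # now succeeds; gradients are simply zero
print(w.grad)                       # tensor([0., 0., 0.])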
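
_compute_kl_loss consolidates the KL penalty that was previously inlined in the policy-loss branch; the estimator itself is still whatever compute_approx_kl implements for the configured kl_loss_type. As a rough sketch of one common choice (a k3-style estimator of KL(pi || pi_ref), optionally importance-weighted when the current log-probs were produced under a stale policy; the repo's function may compute something different):

import torch

def approx_kl_sketch(log_probs, ref_log_probs, importance_ratio=None):
    # k3 estimator: exp(r) - r - 1 with r = log pi_ref - log pi, pointwise nonnegative.
    log_ratio = ref_log_probs - log_probs
    kl = torch.exp(log_ratio) - log_ratio - 1
    if importance_ratio is not None:
        # Optional de-biasing when samples were drawn under an older policy.
        kl = importance_ratio * kl
    return kl

log_probs = torch.tensor([-1.0, -0.5])
ref_log_probs = torch.tensor([-1.2, -0.4])
old_log_probs = torch.tensor([-1.1, -0.45])
print(approx_kl_sketch(log_probs, ref_log_probs, torch.exp(log_probs - old_log_probs)))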
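
The final scaling in _train_step, loss * self.dp_size / self.args.global_batch_size, makes the accumulated gradients come out as a mean over the global batch, assuming each microbatch loss is a sum of per-sample terms (as above) and that the backend averages gradients across data-parallel ranks, which is the FSDP/DDP default. Dividing by global_batch_size turns the summed sample terms into a global mean, and multiplying by dp_size cancels the cross-rank averaging. A tiny numerical check of that arithmetic under those assumptions:

# Each rank sums its scaled microbatch losses, then ranks average their gradients.
dp_size = 2
global_batch_size = 8
per_sample_losses = [float(i) for i in range(global_batch_size)]

# Reference: plain mean over the global batch.
reference = sum(per_sample_losses) / global_batch_size

# What the training loop effectively computes.
per_rank = [per_sample_losses[:4], per_sample_losses[4:]]
rank_totals = [sum(s) * dp_size / global_batch_size for s in per_rank]
averaged = sum(rank_totals) / dp_size

print(reference, averaged)  # both 3.5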