140 changes: 101 additions & 39 deletions miles/backends/fsdp_utils/actor.py
@@ -508,6 +508,8 @@ def _log_rollout_data(self, rollout_id: int, rollout_data, packed_batches):
if isinstance(unpacked_batch[metric_key], torch.Tensor):
loss_masks_tensor = unpacked_batch["loss_masks"].to(device=torch.cuda.current_device())
metric_tensor = unpacked_batch[metric_key].to(device=torch.cuda.current_device())
if metric_tensor.shape[0] == 0:
continue
val += (metric_tensor * loss_masks_tensor).sum() / loss_masks_tensor.sum().clamp_min(1)
else:
val += unpacked_batch[metric_key]
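The guard added in this hunk skips empty microbatches before the masked average. A minimal sketch of that aggregation pattern, with hypothetical tensors standing in for the `unpacked_batch` entries (illustration only, not the actual helper in actor.py):

import torch

def masked_metric_mean(metric: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # Skip empty microbatches entirely (mirrors the `shape[0] == 0` guard above).
    if metric.shape[0] == 0:
        return torch.zeros((), device=metric.device)
    # Average only over unmasked (response) tokens; clamp_min avoids division
    # by zero when a sample's mask happens to be all zeros.
    return (metric * loss_mask).sum() / loss_mask.sum().clamp_min(1)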
@@ -578,27 +580,42 @@ def _train_core(self, rollout_id: int, rollout_data) -> None:
self.ref_model.load_state_dict(actor_state)
self.ref_model.cpu()

def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
# Prepare model inputs
model_args = self._get_model_inputs_args(packed_batch)
logits = self.model(**model_args).logits.squeeze(0).float()
def _compute_sft_loss(self, unpacked_batches, logits):
loss_masks = [batch["loss_masks"].to(device=logits.device) for batch in unpacked_batches]
response_lengths = [batch["response_lengths"] for batch in unpacked_batches]
log_probs = torch.cat([batch["cur_log_probs"] for batch in unpacked_batches], dim=0)
loss = -sum_of_sample_mean(log_probs, response_lengths, loss_masks)
# make sure the gradient can backprop correctly.
if log_probs.numel() == 0:
loss += 0 * logits.sum()

# Compute log probs and entropy (unified for both CP and non-CP modes)
log_probs, entropy_result = get_logprob_and_entropy_with_cp(
logits=logits,
target_tokens=packed_batch["tokens"],
cp_rank=self.cp_rank,
cp_size=self.cp_size,
cp_group=self.cp_group,
model_input_ids=model_args["input_ids"],
allow_compile=not self.args.true_on_policy_mode,
temperature=self.args.rollout_temperature,
)
packed_batch["cur_log_probs"] = log_probs
packed_batch["entropy"] = entropy_result
kl_loss = 0
if self.args.use_kl_loss:
old_log_prob_key = "rollout_log_probs" if self.args.use_rollout_logprobs else "log_probs"
missing_old_log_probs = [
idx
for idx, batch in enumerate(unpacked_batches)
if old_log_prob_key not in batch or not isinstance(batch[old_log_prob_key], torch.Tensor)
]
if missing_old_log_probs:
raise KeyError(
f"{old_log_prob_key} must be provided as torch.Tensor for all microbatches when "
f"use_rollout_logprobs is set to {self.args.use_rollout_logprobs}. Missing in batches: {missing_old_log_probs}"
)
old_log_probs = torch.cat([batch[old_log_prob_key] for batch in unpacked_batches], dim=0)
old_log_probs = old_log_probs.to(device=log_probs.device)
kl_loss = self._compute_kl_loss(unpacked_batches, log_probs, old_log_probs, response_lengths, loss_masks)
loss += kl_loss

unpacked_batches = unpack_sequences(packed_batch)
reported = {
"loss": loss.detach(),
}

if self.args.use_kl_loss:
reported["kl_loss"] = kl_loss.detach()
return loss, reported
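For readers unfamiliar with `sum_of_sample_mean`, the new `_compute_sft_loss` boils down to a per-sample, token-averaged negative log-likelihood summed over samples, plus a trick to keep the autograd graph connected when a microbatch has no response tokens. A minimal standalone sketch; the helper below is a simplified stand-in for the real `sum_of_sample_mean`, whose behaviour is assumed from its name and usage here:

import torch

def sum_of_sample_mean(values, response_lengths, loss_masks):
    # Assumed behaviour: split the flat per-token tensor by sample, take a
    # masked mean within each sample, then sum across samples.
    total = values.new_zeros(())
    for chunk, mask in zip(values.split(response_lengths), loss_masks):
        total = total + (chunk * mask).sum() / mask.sum().clamp_min(1)
    return total

def sft_loss(log_probs, response_lengths, loss_masks, logits):
    # Empty microbatch: return a zero loss that is still connected to the
    # graph so backward() reaches the model parameters (same trick as the
    # `loss += 0 * logits.sum()` line above).
    if log_probs.numel() == 0:
        return 0 * logits.sum()
    return -sum_of_sample_mean(log_probs, response_lengths, loss_masks)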

def _compute_policy_loss(self, unpacked_batches, logits):
old_log_prob_key = "rollout_log_probs" if self.args.use_rollout_logprobs else "log_probs"
missing_old_log_probs = [
idx
@@ -611,11 +628,11 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
f"use_rollout_logprobs is set to {self.args.use_rollout_logprobs}. Missing in batches: {missing_old_log_probs}"
)
old_log_probs = torch.cat([batch[old_log_prob_key] for batch in unpacked_batches], dim=0)
loss_masks = [batch["loss_masks"].to(device=logits.device) for batch in unpacked_batches]

response_lengths = [batch["response_lengths"] for batch in unpacked_batches]
log_probs = torch.cat([batch["cur_log_probs"] for batch in unpacked_batches], dim=0)
advantages = torch.cat([batch["advantages"] for batch in unpacked_batches], dim=0)
loss_masks = [batch["loss_masks"].to(device=log_probs.device) for batch in unpacked_batches]
response_lengths = [batch["response_lengths"] for batch in unpacked_batches]

advantages = advantages.to(device=log_probs.device)
old_log_probs = old_log_probs.to(device=log_probs.device)
ppo_kl = old_log_probs - log_probs
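The `ppo_kl = old_log_probs - log_probs` line feeds the clipped policy-gradient surrogate computed further down (not shown in this hunk). As a reference, a standard PPO-clip objective over the same quantities would look roughly like the sketch below; the clip-range name `eps_clip` is an assumption, not an argument taken from this diff:

import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, eps_clip=0.2):
    ratio = torch.exp(log_probs - old_log_probs)                    # importance ratio
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantages
    pg_loss = -torch.min(surr1, surr2)                              # per-token loss
    pg_clipfrac = (surr2 < surr1).float()                           # fraction of clipped tokens
    return pg_loss, pg_clipfrac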
@@ -668,7 +685,7 @@ def _has_rollout_log_probs(batch) -> bool:
tis_clipfrac = tis_clip != tis

pg_loss = pg_loss * tis_clip

assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
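The truncated importance sampling (TIS) factor in this hunk reweights the policy-gradient loss by a train/rollout probability ratio, capped so a single token cannot dominate the gradient. A hedged sketch of that correction; the cap value and the exact log-prob pair used to form `tis` are assumptions, not taken from this diff:

import torch

def tis_correction(train_log_probs, rollout_log_probs, cap=2.0):
    # Ratio between the training engine's and the rollout engine's probabilities.
    tis = torch.exp(train_log_probs - rollout_log_probs)
    tis_clip = torch.clamp(tis, max=cap)        # truncate large ratios
    tis_clipfrac = (tis_clip != tis)            # which tokens got truncated
    return tis_clip, tis_clipfrac

# pg_loss would then be multiplied token-wise by tis_clip, as in the hunk above.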
@@ -687,20 +704,10 @@ def _has_rollout_log_probs(batch) -> bool:

loss = pg_loss - self.args.entropy_coef * entropy_loss

kl_loss = 0
if self.args.use_kl_loss:
ref_log_probs = torch.cat([batch["ref_log_probs"] for batch in unpacked_batches], dim=0)
importance_ratio = None
if self.args.use_unbiased_kl:
importance_ratio = torch.exp(log_probs - old_log_probs)
kl = compute_approx_kl(
log_probs,
ref_log_probs,
kl_loss_type=self.args.kl_loss_type,
importance_ratio=importance_ratio,
)
kl_loss = sum_of_sample_mean(kl, response_lengths, loss_masks)

loss = loss + self.args.kl_loss_coef * kl_loss
kl_loss = self._compute_kl_loss(unpacked_batches, log_probs, old_log_probs, response_lengths, loss_masks)
loss += kl_loss

reported = {
"loss": loss.detach(),
Expand All @@ -713,9 +720,6 @@ def _has_rollout_log_probs(batch) -> bool:
if train_rollout_logprob_abs_diff is not None:
reported["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff

if self.args.use_kl_loss:
reported["kl_loss"] = kl_loss.detach()

if self.args.use_opsm:
reported["opsm_clipfrac"] = opsm_clipfrac

@@ -724,6 +728,64 @@ def _has_rollout_log_probs(batch) -> bool:
reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach()
reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach()

if self.args.use_kl_loss:
reported["kl_loss"] = kl_loss.detach()
return loss, reported

def _compute_kl_loss(self, unpacked_batches, log_probs, old_log_probs, response_lengths, loss_masks):
ref_log_probs = torch.cat([batch["ref_log_probs"] for batch in unpacked_batches], dim=0)
importance_ratio = None
if self.args.use_unbiased_kl:
importance_ratio = torch.exp(log_probs - old_log_probs)
kl = compute_approx_kl(
log_probs,
ref_log_probs,
kl_loss_type=self.args.kl_loss_type,
importance_ratio=importance_ratio,
)
kl_loss = sum_of_sample_mean(kl, response_lengths, loss_masks)

kl_term = self.args.kl_loss_coef * kl_loss
return kl_term
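`_compute_kl_loss` delegates the per-token estimate to `compute_approx_kl`. A minimal sketch of one common estimator family it could implement (the k3 form from Schulman's approximate-KL note), with the optional importance ratio used for the unbiased variant; this illustrates the idea and is not the library's actual implementation:

import torch

def approx_kl_k3(log_probs, ref_log_probs, importance_ratio=None):
    # k3 estimator: exp(ref - cur) - (ref - cur) - 1, which is always
    # non-negative and low-variance when the two policies are close.
    log_ratio = ref_log_probs - log_probs
    kl = torch.exp(log_ratio) - log_ratio - 1
    if importance_ratio is not None:
        # Reweight when the tokens were sampled from the old (behaviour) policy.
        kl = importance_ratio * kl
    return kl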

def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
# Prepare model inputs
model_args = self._get_model_inputs_args(packed_batch)
logits = self.model(**model_args).logits.squeeze(0).float()

# Compute log probs and entropy (unified for both CP and non-CP modes)
log_probs, entropy_result = get_logprob_and_entropy_with_cp(
logits=logits,
target_tokens=packed_batch["tokens"],
cp_rank=self.cp_rank,
cp_size=self.cp_size,
cp_group=self.cp_group,
model_input_ids=model_args["input_ids"],
allow_compile=not self.args.true_on_policy_mode,
temperature=self.args.rollout_temperature,
)
packed_batch["cur_log_probs"] = log_probs
packed_batch["entropy"] = entropy_result

unpacked_batches = unpack_sequences(packed_batch)

old_log_prob_key = "rollout_log_probs" if self.args.use_rollout_logprobs else "log_probs"
missing_old_log_probs = [
idx
for idx, batch in enumerate(unpacked_batches)
if old_log_prob_key not in batch or not isinstance(batch[old_log_prob_key], torch.Tensor)
]
if missing_old_log_probs:
raise KeyError(
f"{old_log_prob_key} must be provided as torch.Tensor for all microbatches when "
f"use_rollout_logprobs is set to {self.args.use_rollout_logprobs}. Missing in batches: {missing_old_log_probs}"
)

if self.args.loss_type == "sft_loss":
loss, reported = self._compute_sft_loss(unpacked_batches, logits)
else:
loss, reported = self._compute_policy_loss(unpacked_batches, logits)

# Scale loss for gradient accumulation
loss = loss * self.dp_size / self.args.global_batch_size
loss.backward()
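The scaling just before `backward()` converts the per-rank sum-of-sample-mean loss into a global per-sample average: FSDP's gradient reduction averages over the `dp_size` data-parallel ranks, so multiplying by `dp_size` and dividing by `global_batch_size` leaves every sample contributing `1 / global_batch_size` to the synchronized gradient, however the microbatches are split. A tiny arithmetic illustration (numbers are hypothetical):

# Illustration only: 2 DP ranks, each holding 4 of the 8 samples in the global batch.
dp_size, global_batch_size = 2, 8

# Per rank, _train_step produces a SUM of per-sample mean losses, then scales it:
#   loss * dp_size / global_batch_size
# Gradient reduction across the ranks then AVERAGES over dp_size, so each
# sample ends up weighted by (dp_size / global_batch_size) / dp_size:
per_sample_weight = (dp_size / global_batch_size) / dp_size
assert per_sample_weight == 1 / global_batch_size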
@@ -760,7 +822,7 @@ def _has_rollout_log_probs(batch) -> bool:
# Log learning rate per parameter group; use scheduler's last computed LRs
lr_values = self.lr_scheduler.get_last_lr()
for gid, _group in enumerate(self.optimizer.param_groups):
log_dict[f"train/lr-pg_{gid}"] = lr_values[gid]
log_dict[f"train/lr_{gid}"] = lr_values[gid]

kl_info = ""
if self.args.use_kl_loss and "kl_loss" in aggregated:
6 changes: 3 additions & 3 deletions miles/ray/rollout.py
@@ -160,7 +160,7 @@ def _get_rollout_data(self, rollout_id):
data = sum(data, [])

if self.args.disable_rollout_trim_samples:
logger.info(f"Collectd {len(data)} samples from rollout to train")
logger.info(f"Collected {len(data)} samples from rollout to train")
elif len(data) % self.args.global_batch_size != 0:
trim_len = (len(data) // self.args.global_batch_size) * self.args.global_batch_size
origin_data_length = len(data)
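The trimming branch above drops the remainder so the collected samples split evenly into global batches (unless `disable_rollout_trim_samples` is set). A minimal sketch of the arithmetic with hypothetical numbers:

def trim_to_global_batches(data, global_batch_size):
    # Keep only the largest multiple of global_batch_size that fits.
    trim_len = (len(data) // global_batch_size) * global_batch_size
    return data[:trim_len]

samples = list(range(35))
assert len(trim_to_global_batches(samples, global_batch_size=16)) == 32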
@@ -332,7 +332,7 @@ def _split_train_data_by_dp(self, data, dp_size):

def init_rollout_engines(args, pg, all_rollout_engines):
if args.debug_train_only:
return 0, None
return 0

num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
num_engines = args.rollout_num_gpus // num_gpu_per_engine
@@ -412,7 +412,7 @@ def init_rollout_engines(args, pg, all_rollout_engines):
num_new_engines = len(rollout_engines)

if num_new_engines == 0:
return num_new_engines, None
return num_new_engines

if args.rollout_external:
addr_and_ports = _allocate_rollout_engine_addr_and_ports_external(args=args, rollout_engines=rollout_engines)
168 changes: 168 additions & 0 deletions scripts/run-qwen3-0.6B-torch-sft-crowd-code.sh
@@ -0,0 +1,168 @@
#!/bin/bash

# for rerunning the task: kill any leftover processes
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

# set -ex

# prevent Ray from buffering stdout/stderr
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0,1

NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
HAS_NVLINK=1
else
HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"


# --- 1. DYNAMIC HOST IP DETECTION (CRITICAL FOR SLURM) ---
# Don't hardcode IP. Get the actual IP of the current node.
export HEAD_NODE_IP=$(hostname -I | awk '{print $1}')
echo "Detected Head Node IP: ${HEAD_NODE_IP}"

# --- 2. PROXY CONFIGURATION ---
# Ensure local traffic doesn't go through a corporate proxy
export no_proxy="${HEAD_NODE_IP},localhost,127.0.0.1,0.0.0.0"
export NO_PROXY="${HEAD_NODE_IP},localhost,127.0.0.1,0.0.0.0"

# --- 3. DEBUGGING & STABILITY ENV VARS ---
# Force NCCL/Distributed into a robust mode to prevent initialization hangs
# export NCCL_P2P_DISABLE=1
# export NCCL_IB_DISABLE=1
export NCCL_DEBUG=INFO
export TORCH_DISTRIBUTED_DEBUG=INFO

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"

RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"}
LOAD_SAVE_PATH="/fast/project/HFMI_SynergyUnit/tab_model/huggingface/shared_data/${RUN_ID}/checkpoints"

CKPT_ARGS=(
--hf-checkpoint /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B
--load /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B
--ref-load /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B
)

SFT_ARGS=(
--rollout-function-path miles.rollout.sft_rollout.generate_rollout
--prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/nemo_hf_part_jsonl_4k_tokens.parquet
--input-key messages
--apply-chat-template
--rollout-shuffle
--num-epoch 3
--rollout-batch-size 16
--global-batch-size 16

--loss-type sft_loss
--calculate-per-token-loss
--disable-compute-advantages-and-returns
--num-rollout 2000
--debug-train-only
)

OPTIMIZER_ARGS=(
--optimizer adam
--lr 1e-5
--lr-decay-style WSD
--lr-wsd-decay-style linear
--lr-warmup-iters 100
--lr-decay-iters 2000
--lr-wsd-decay-iters 500
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.98
)

WANDB_ARGS=(
--use-wandb
--wandb-project crowd-pilot-miles
--wandb-team instant-uv
--wandb-group qwen3-0.6b-sft-torch
)

TRAIN_BACKEND_ARGS=(
--train-backend fsdp
--update-weight-buffer-size 536870912
--gradient-checkpointing
--attn-implementation flash_attention_3
--train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}'
--actor-num-gpus-per-node 2
)

PERF_ARGS=(
--use-dynamic-batch-size
--max-tokens-per-gpu 9216
)

MISC_ARGS=(
--actor-num-nodes 1
--actor-num-gpus-per-node 2
--colocate
--rollout-max-context-len 8192
--rollout-max-prompt-len 8000
--rollout-max-response-len 8192
--use-fault-tolerance
--dump-details /fast/project/HFMI_SynergyUnit/tab_model/huggingface/shared_data/qwen3-600M-fsdp-1116-noref/dump_details
)

# launch the master node of ray in container - 2 GPUs for training
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
python3 -m ray.scripts.scripts start --head \
--node-ip-address=${HEAD_NODE_IP} \
--num-gpus 2 \
--num-cpus 4 \
--memory=214748364800 \
--disable-usage-stats \
--dashboard-host=0.0.0.0 \
--dashboard-port=8265 \
--port=6379

echo "Ray started. Waiting for Dashboard to be ready..."

# --- 4. WAIT FOR DASHBOARD (FIX FOR 504 ERROR) ---
# Loop until the dashboard port accepts connections
for i in {1..30}; do
if curl -s "http://${HEAD_NODE_IP}:8265" > /dev/null; then
echo "Dashboard is up!"
break
fi
echo "Waiting for Ray Dashboard..."
sleep 2
done
# Add a small safety buffer
sleep 5

# Build runtime env
RUNTIME_ENV_JSON="{
\"env_vars\": {
\"PYTHONPATH\": \"/fast/project/HFMI_SynergyUnit/mihir/Megatron-LM/\",
\"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
\"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
\"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\"
}
}"

python3 -m ray.scripts.scripts job submit --address="http://${HEAD_NODE_IP}:8265" \
--runtime-env-json="${RUNTIME_ENV_JSON}" \
-- python3 train.py \
${CKPT_ARGS[@]} \
${SFT_ARGS[@]} \
${OPTIMIZER_ARGS[@]} \
${WANDB_ARGS[@]} \
${SGLANG_ARGS[@]} \
${TRAIN_BACKEND_ARGS[@]} \
${PERF_ARGS[@]} \
${MISC_ARGS[@]}
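For completeness, the same submission can also be driven from Python via Ray's job submission SDK. A rough equivalent of the `ray job submit` call above; the address and env vars mirror the script, while the abbreviated entrypoint arguments are placeholders rather than the full argument list:

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")  # http://${HEAD_NODE_IP}:8265 in the script

client.submit_job(
    entrypoint="python3 train.py --train-backend fsdp --loss-type sft_loss",  # abbreviated args
    runtime_env={
        "env_vars": {
            "PYTHONPATH": "/fast/project/HFMI_SynergyUnit/mihir/Megatron-LM/",
            "CUDA_DEVICE_MAX_CONNECTIONS": "1",
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        }
    },
)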


