1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -37,6 +37,7 @@ grpo:
source_max: 1.0
target_min: 0.0
target_max: 1.0
seq_logprob_error_threshold: null

async_grpo:
enabled: false # Set to true to enable async training mode
1 change: 1 addition & 0 deletions examples/configs/vlm_grpo_3B.yaml
@@ -38,6 +38,7 @@ grpo:
async_grpo:
enabled: false
max_trajectory_age_steps: 1
seq_logprob_error_threshold: null

loss_fn:
reference_policy_kl_penalty: 0.01
1 change: 1 addition & 0 deletions examples/configs/vlm_grpo_3B_megatron.yaml
@@ -36,6 +36,7 @@ grpo:
async_grpo:
enabled: false
max_trajectory_age_steps: 1
seq_logprob_error_threshold: null
loss_fn:
reference_policy_kl_penalty: 0.01
# Can be set to k1, k2, k3
117 changes: 117 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -162,6 +162,8 @@ class GRPOConfig(TypedDict):
reward_scaling: RewardScalingConfig
# By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation.
calculate_advantages_on_gpu: NotRequired[bool]
# Sequence-level logprob error masking for training stability. If set, mask sequences with mult_prob_error exceeding this threshold (same scale as token_mult_prob_error metric, e.g., 1.5)
seq_logprob_error_threshold: float | None
# Advantage estimator configuration (grpo or reinforce_plus_plus)
adv_estimator: NotRequired[AdvEstimatorConfig]

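Editor's note: the YAML diffs above default the new key to null (disabled). A minimal sketch of turning it on, using the same config path that grpo_train reads below; the value 1.5 is just the example cited in the field's comment, not a recommended setting.

# Editor's sketch, not part of the diff: enabling the new knob programmatically.
# Equivalent to setting `grpo.seq_logprob_error_threshold: 1.5` in one of the YAML
# configs above; null / None keeps the previous behavior (metrics only, no masking).
master_config["grpo"]["seq_logprob_error_threshold"] = 1.5
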
@@ -1161,6 +1163,89 @@ def _log_mixed_rewards_and_advantages_information(
metrics["advantages/mean"] = advantages.float().mean().item()


def compute_and_apply_seq_logprob_error_masking(
train_data: BatchedDataDict,
rewards: torch.Tensor,
seq_logprob_error_threshold: Optional[float],
) -> tuple[float, int, float]:
"""Compute sequence-level logprob error metrics and optionally mask high-error sequences.

This function computes the multiplicative probability error per sequence
(same calculation as token_mult_prob_error but aggregated per-sequence) and
optionally masks sequences that exceed the configured threshold.

Args:
train_data: Training data dict containing token_mask, sample_mask,
prev_logprobs, and generation_logprobs. If masking is applied,
sample_mask will be updated in-place.
rewards: Reward tensor for computing statistics on masked sequences.
seq_logprob_error_threshold: If set, mask sequences with mult_prob_error
exceeding this threshold. If None, only compute metrics.

Returns:
Tuple of (max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct)
"""
# Compute sequence-level logprob error metrics (always)
token_mask = train_data["token_mask"][:, 1:]
sample_mask = train_data["sample_mask"]
prev_logprobs = train_data["prev_logprobs"][:, 1:]
generation_logprobs = train_data["generation_logprobs"][:, 1:]
lp_error = torch.abs(generation_logprobs - prev_logprobs)
Review comment (Contributor): I thought MIS was an upper-bound masking of train/gen? This looks two-sided because of the abs value. Should we expose two knobs so we can always compare this framework against another framework that has textbook MIS vs. our variant? (A one-sided sketch follows the function below.)

# Use combined mask exactly as in loss function
mask = token_mask * sample_mask.unsqueeze(-1)

# Calculate sequence-level multiplicative prob error
# EXACT same calculation as token_mult_prob_error but per-sequence
seq_mult_prob_error = (torch.exp(lp_error * mask) * mask).sum(dim=-1) / mask.sum(
dim=-1
).clamp(min=1)
max_seq_mult_prob_error = (
seq_mult_prob_error.max().item() if seq_mult_prob_error.numel() > 0 else 0.0
)

# Apply sequence-level masking if configured
num_masked_seqs = 0
masked_correct_pct = 0.0

if seq_logprob_error_threshold is not None:
print(
f"▶ Applying sequence-level logprob error masking (threshold={seq_logprob_error_threshold})...",
flush=True,
)

original_sample_mask = sample_mask.clone()

# Create mask for sequences below threshold
seq_error_mask = (
seq_mult_prob_error <= seq_logprob_error_threshold
).float() * original_sample_mask

diff_mask = original_sample_mask - seq_error_mask
num_masked_seqs = int(diff_mask.sum().item())

if num_masked_seqs > 0:
diff_mask_bool = diff_mask.bool()
masked_correct_count = (rewards.view(-1)[diff_mask_bool] == 1).sum().item()
masked_correct_pct = masked_correct_count / num_masked_seqs

# Update sample_mask in train_data
train_data["sample_mask"] = seq_error_mask

print(
f" Masked {num_masked_seqs} sequences with mult_prob_error > {seq_logprob_error_threshold}",
flush=True,
)
if num_masked_seqs > 0:
print(
f" • {masked_correct_count}/{num_masked_seqs} masked sequences were correct (reward=1)"
f" → {masked_correct_pct:.2%}",
flush=True,
)

return max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct
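
Editor's note, not part of the diff: a minimal usage sketch of the new helper on synthetic tensors. A plain dict stands in for BatchedDataDict here, which is enough because the helper only reads and writes dict keys; the import path matches where this PR adds the function, assuming the module imports cleanly in your environment.

import torch
from nemo_rl.algorithms.grpo import compute_and_apply_seq_logprob_error_masking

B, T = 4, 6
gen_lp = torch.randn(B, T)        # logprobs recorded by the generation backend
train_lp = gen_lp.clone()
train_lp[0] += 1.0                # sequence 0: ~e^1 ≈ 2.7x train/gen mismatch per token

train_data = {
    "token_mask": torch.ones(B, T),
    "sample_mask": torch.ones(B),
    "generation_logprobs": gen_lp,
    "prev_logprobs": train_lp,
}
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])

# threshold=None: metrics only, sample_mask is left untouched
max_err, n_masked, pct = compute_and_apply_seq_logprob_error_masking(
    train_data, rewards, seq_logprob_error_threshold=None
)
assert n_masked == 0

# threshold=1.5 (the example value from the config comment): sequence 0 is masked,
# since its per-sequence mult_prob_error is roughly exp(1.0) ≈ 2.7 > 1.5
max_err, n_masked, pct = compute_and_apply_seq_logprob_error_masking(
    train_data, rewards, seq_logprob_error_threshold=1.5
)
assert n_masked == 1 and train_data["sample_mask"][0] == 0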

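Editor's note on the review comment above: the helper uses abs(), so it penalizes mismatch in both directions, whereas a textbook-MIS-style cap would only bound the train/gen ratio from above. A minimal sketch of how the two variants could be computed side by side if a second knob were exposed; names here are illustrative, not existing config keys.

# Editor's sketch, not part of the diff: one-sided vs. two-sided per-sequence error.
# diff = prev - gen is the per-token log importance ratio log(pi_train / pi_gen).
# The helper above aggregates |diff| (two-sided); as the comment suggests, a one-sided
# variant would only penalize diff > 0, i.e. train overweighting relative to generation.
import torch

def seq_mult_prob_errors(generation_logprobs, prev_logprobs, mask):
    diff = prev_logprobs - generation_logprobs
    def aggregate(err):
        return (torch.exp(err * mask) * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    two_sided = aggregate(diff.abs())           # matches compute_and_apply_seq_logprob_error_masking
    one_sided = aggregate(diff.clamp(min=0.0))  # upper-bound-only variant
    return two_sided, one_sided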

# ===============================================================================
# Training & Validation
# ===============================================================================
@@ -1608,6 +1693,17 @@ def grpo_train(
del logprob_data
del extra_multimodal_data

(
max_seq_mult_prob_error,
num_masked_seqs,
masked_correct_pct,
) = compute_and_apply_seq_logprob_error_masking(
train_data=train_data,
rewards=rewards,
seq_logprob_error_threshold=master_config["grpo"][
"seq_logprob_error_threshold"
],
)
# Compute advantages with adv_estimator using correct mask and logprobs
with timer.time("advantage_calculation"):
print("▶ Computing advantages...", flush=True)
@@ -1771,6 +1867,11 @@ def grpo_train(
metrics["generation_logger_metrics"] = generation_logger_metrics
total_valid_tokens += metrics["global_valid_toks"]

# Always log sequence-level error metrics (useful for deciding threshold)
metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error
metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs
metrics["masked_correct_pct"] = masked_correct_pct

## Checkpointing
consumed_samples += master_config["grpo"]["num_prompts_per_step"]
timeout.mark_iteration()
@@ -2601,6 +2702,17 @@ def async_grpo_train(
train_data["prev_logprobs"] = fprop_logprobs
train_data["reference_policy_logprobs"] = reference_logprobs

(
max_seq_mult_prob_error,
num_masked_seqs,
masked_correct_pct,
) = compute_and_apply_seq_logprob_error_masking(
train_data=train_data,
rewards=rewards,
seq_logprob_error_threshold=master_config["grpo"][
"seq_logprob_error_threshold"
],
)
# Compute advantages with adv_estimator using correct mask and logprobs
with timer.time("advantage_calculation"):
print("▶ Computing advantages...", flush=True)
@@ -2775,6 +2887,11 @@ def async_grpo_train(
metrics["generation_logger_metrics"] = generation_logger_metrics
total_valid_tokens += metrics["global_valid_toks"]

# Always log sequence-level error metrics (useful for deciding threshold)
metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error
metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs
metrics["masked_correct_pct"] = masked_correct_pct

# Checkpointing (same as sync version)
consumed_samples += master_config["grpo"]["num_prompts_per_step"]
timeout.mark_iteration()