Skip to content

Commit 14f9f38

Browse files
yfwpjin-nvidia
and others authored
feat: Mask sequences with high logprob error (#1838)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com> Co-authored-by: Peter Jin <pjin@nvidia.com>
1 parent 0e0edcf commit 14f9f38

5 files changed

Lines changed: 517 additions & 43 deletions

File tree

examples/configs/grpo_math_1B.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ grpo:
3737
source_max: 1.0
3838
target_min: 0.0
3939
target_max: 1.0
40+
seq_logprob_error_threshold: null
4041

4142
async_grpo:
4243
enabled: false # Set to true to enable async training mode

examples/configs/vlm_grpo_3B.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ grpo:
3838
async_grpo:
3939
enabled: false
4040
max_trajectory_age_steps: 1
41+
seq_logprob_error_threshold: null
4142

4243
loss_fn:
4344
reference_policy_kl_penalty: 0.01

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ grpo:
3636
async_grpo:
3737
enabled: false
3838
max_trajectory_age_steps: 1
39+
seq_logprob_error_threshold: null
3940
loss_fn:
4041
reference_policy_kl_penalty: 0.01
4142
# Can be set to k1, k2, k3

nemo_rl/algorithms/grpo.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ class GRPOConfig(TypedDict):
162162
reward_scaling: RewardScalingConfig
163163
# By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation.
164164
calculate_advantages_on_gpu: NotRequired[bool]
165+
# Sequence-level logprob error masking for training stability. If set, mask sequences with mult_prob_error exceeding this threshold (same scale as token_mult_prob_error metric, e.g., 1.5)
166+
# Note that this is slightly different than Masked Importance Sampling (MIS) because this uses the absolute value of the difference between the training and generation logprobs, whereas MIS just uses the difference between the training and generation logprobs.
167+
seq_logprob_error_threshold: float | None
165168
# Advantage estimator configuration (grpo or reinforce_plus_plus)
166169
adv_estimator: NotRequired[AdvEstimatorConfig]
167170

@@ -1161,6 +1164,89 @@ def _log_mixed_rewards_and_advantages_information(
11611164
metrics["advantages/mean"] = advantages.float().mean().item()
11621165

11631166

1167+
def compute_and_apply_seq_logprob_error_masking(
    train_data: BatchedDataDict,
    rewards: torch.Tensor,
    seq_logprob_error_threshold: Optional[float],
) -> tuple[float, int, float]:
    """Measure per-sequence logprob error and optionally drop offending sequences.

    The per-sequence multiplicative probability error mirrors the
    token_mult_prob_error metric, but is aggregated for each sequence
    individually instead of across the whole batch. When
    ``seq_logprob_error_threshold`` is set, every sequence whose error
    exceeds the threshold is zeroed out of ``sample_mask`` (``train_data``
    is mutated in place).

    Args:
        train_data: Batch dict holding ``token_mask``, ``sample_mask``,
            ``prev_logprobs`` and ``generation_logprobs``. Its
            ``sample_mask`` entry is overwritten when masking is applied.
        rewards: Per-sequence rewards; used only to report how many of the
            dropped sequences had reward == 1.
        seq_logprob_error_threshold: Masking threshold. ``None`` means
            metrics-only (no masking).

    Returns:
        Tuple of ``(max_seq_mult_prob_error, num_masked_seqs,
        masked_correct_pct)``.
    """
    # Drop the first position to line up with the loss computation
    # (the first token has no preceding-context logprob).
    tok_mask = train_data["token_mask"][:, 1:]
    samp_mask = train_data["sample_mask"]
    train_lp = train_data["prev_logprobs"][:, 1:]
    gen_lp = train_data["generation_logprobs"][:, 1:]

    abs_lp_diff = torch.abs(gen_lp - train_lp)

    # Same combined mask the loss function uses.
    combined_mask = tok_mask * samp_mask.unsqueeze(-1)

    # Per-sequence analogue of token_mult_prob_error: identical math, just
    # reduced over tokens of each sequence rather than the whole batch.
    valid_token_counts = combined_mask.sum(dim=-1).clamp(min=1)
    seq_mult_prob_error = (
        torch.exp(abs_lp_diff * combined_mask) * combined_mask
    ).sum(dim=-1) / valid_token_counts

    if seq_mult_prob_error.numel() > 0:
        max_seq_mult_prob_error = seq_mult_prob_error.max().item()
    else:
        max_seq_mult_prob_error = 0.0

    num_masked_seqs = 0
    masked_correct_pct = 0.0

    # Metrics-only mode: nothing to mask.
    if seq_logprob_error_threshold is None:
        return max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct

    print(
        f"▶ Applying sequence-level logprob error masking (threshold={seq_logprob_error_threshold})...",
        flush=True,
    )

    kept_before = samp_mask.clone()

    # Sequences at or below the threshold (and previously kept) survive.
    kept_after = (
        seq_mult_prob_error <= seq_logprob_error_threshold
    ).float() * kept_before

    newly_dropped = kept_before - kept_after
    num_masked_seqs = int(newly_dropped.sum().item())

    if num_masked_seqs > 0:
        dropped_rows = newly_dropped.bool()
        masked_correct_count = (rewards.view(-1)[dropped_rows] == 1).sum().item()
        masked_correct_pct = masked_correct_count / num_masked_seqs

    # Mutate in place so downstream advantage/loss code sees the new mask.
    train_data["sample_mask"] = kept_after

    print(
        f"  Masked {num_masked_seqs} sequences with mult_prob_error > {seq_logprob_error_threshold}",
        flush=True,
    )
    if num_masked_seqs > 0:
        print(
            f"  • {masked_correct_count}/{num_masked_seqs} masked sequences were correct (reward=1)"
            f" → {masked_correct_pct:.2%}",
            flush=True,
        )

    return max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct
1248+
1249+
11641250
# ===============================================================================
11651251
# Training & Validation
11661252
# ===============================================================================
@@ -1608,6 +1694,17 @@ def grpo_train(
16081694
del logprob_data
16091695
del extra_multimodal_data
16101696

1697+
(
1698+
max_seq_mult_prob_error,
1699+
num_masked_seqs,
1700+
masked_correct_pct,
1701+
) = compute_and_apply_seq_logprob_error_masking(
1702+
train_data=train_data,
1703+
rewards=rewards,
1704+
seq_logprob_error_threshold=master_config["grpo"][
1705+
"seq_logprob_error_threshold"
1706+
],
1707+
)
16111708
# Compute advantages with adv_estimator using correct mask and logprobs
16121709
with timer.time("advantage_calculation"):
16131710
print("▶ Computing advantages...", flush=True)
@@ -1771,6 +1868,11 @@ def grpo_train(
17711868
metrics["generation_logger_metrics"] = generation_logger_metrics
17721869
total_valid_tokens += metrics["global_valid_toks"]
17731870

1871+
# Always log sequence-level error metrics (useful for deciding threshold)
1872+
metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error
1873+
metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs
1874+
metrics["masked_correct_pct"] = masked_correct_pct
1875+
17741876
## Checkpointing
17751877
consumed_samples += master_config["grpo"]["num_prompts_per_step"]
17761878
timeout.mark_iteration()
@@ -2601,6 +2703,17 @@ def async_grpo_train(
26012703
train_data["prev_logprobs"] = fprop_logprobs
26022704
train_data["reference_policy_logprobs"] = reference_logprobs
26032705

2706+
(
2707+
max_seq_mult_prob_error,
2708+
num_masked_seqs,
2709+
masked_correct_pct,
2710+
) = compute_and_apply_seq_logprob_error_masking(
2711+
train_data=train_data,
2712+
rewards=rewards,
2713+
seq_logprob_error_threshold=master_config["grpo"][
2714+
"seq_logprob_error_threshold"
2715+
],
2716+
)
26042717
# Compute advantages with adv_estimator using correct mask and logprobs
26052718
with timer.time("advantage_calculation"):
26062719
print("▶ Computing advantages...", flush=True)
@@ -2775,6 +2888,11 @@ def async_grpo_train(
27752888
metrics["generation_logger_metrics"] = generation_logger_metrics
27762889
total_valid_tokens += metrics["global_valid_toks"]
27772890

2891+
# Always log sequence-level error metrics (useful for deciding threshold)
2892+
metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error
2893+
metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs
2894+
metrics["masked_correct_pct"] = masked_correct_pct
2895+
27782896
# Checkpointing (same as sync version)
27792897
consumed_samples += master_config["grpo"]["num_prompts_per_step"]
27802898
timeout.mark_iteration()

0 commit comments

Comments
 (0)