@@ -162,6 +162,9 @@ class GRPOConfig(TypedDict):
     reward_scaling: RewardScalingConfig
     # By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation.
     calculate_advantages_on_gpu: NotRequired[bool]
+    # Sequence-level logprob error masking for training stability. If set, sequences whose mult_prob_error exceeds this threshold are masked (same scale as the token_mult_prob_error metric, e.g. 1.5).
+    # Note that this differs slightly from Masked Importance Sampling (MIS): it uses the absolute value of the difference between the training and generation logprobs, whereas MIS uses the signed difference.
+    seq_logprob_error_threshold: float | None
     # Advantage estimator configuration (grpo or reinforce_plus_plus)
     adv_estimator: NotRequired[AdvEstimatorConfig]

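Note (not part of the diff): a minimal sketch, with made-up values, of the quantity the new threshold is compared against and of the absolute vs. signed error distinction mentioned in the config comment above.

import torch

# Hypothetical per-token logprobs for 2 sequences of 4 tokens each.
generation_logprobs = torch.tensor([[-1.0, -0.5, -2.0, -0.1], [-0.3, -1.2, -0.4, -0.9]])
prev_logprobs = torch.tensor([[-1.1, -0.6, -1.5, -0.1], [-0.3, -1.0, -0.4, -0.8]])
mask = torch.ones_like(prev_logprobs)  # all tokens valid in this toy example

abs_error = torch.abs(generation_logprobs - prev_logprobs)  # used by this change
signed_error = generation_logprobs - prev_logprobs          # what MIS would use; can be negative

# Same aggregation as token_mult_prob_error, but kept per sequence. Because the error
# is taken in absolute value, exp(abs_error) >= 1, so a threshold such as 1.5 bounds the
# train/generation mismatch in either direction.
seq_mult_prob_error = (torch.exp(abs_error * mask) * mask).sum(-1) / mask.sum(-1).clamp(min=1)
print(seq_mult_prob_error)  # roughly tensor([1.21, 1.08]); a threshold of 1.5 keeps both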
@@ -1161,6 +1164,89 @@ def _log_mixed_rewards_and_advantages_information(
     metrics["advantages/mean"] = advantages.float().mean().item()


+def compute_and_apply_seq_logprob_error_masking(
+    train_data: BatchedDataDict,
+    rewards: torch.Tensor,
+    seq_logprob_error_threshold: Optional[float],
+) -> tuple[float, int, float]:
+    """Compute sequence-level logprob error metrics and optionally mask high-error sequences.
+
+    This function computes the multiplicative probability error per sequence
+    (the same calculation as token_mult_prob_error, but aggregated per sequence) and
+    optionally masks sequences that exceed the configured threshold.
+
+    Args:
+        train_data: Training data dict containing token_mask, sample_mask,
+            prev_logprobs, and generation_logprobs. If masking is applied,
+            sample_mask is updated in place.
+        rewards: Reward tensor for computing statistics on masked sequences.
+        seq_logprob_error_threshold: If set, mask sequences with mult_prob_error
+            exceeding this threshold. If None, only compute metrics.
+
+    Returns:
+        Tuple of (max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct).
+    """
+    # Compute sequence-level logprob error metrics (always)
+    token_mask = train_data["token_mask"][:, 1:]
+    sample_mask = train_data["sample_mask"]
+    prev_logprobs = train_data["prev_logprobs"][:, 1:]
+    generation_logprobs = train_data["generation_logprobs"][:, 1:]
+    lp_error = torch.abs(generation_logprobs - prev_logprobs)
+
+    # Use the combined mask exactly as in the loss function
+    mask = token_mask * sample_mask.unsqueeze(-1)
+
+    # Calculate the sequence-level multiplicative prob error:
+    # exactly the same calculation as token_mult_prob_error, but per sequence
+    seq_mult_prob_error = (torch.exp(lp_error * mask) * mask).sum(dim=-1) / mask.sum(
+        dim=-1
+    ).clamp(min=1)
+    max_seq_mult_prob_error = (
+        seq_mult_prob_error.max().item() if seq_mult_prob_error.numel() > 0 else 0.0
+    )
+
+    # Apply sequence-level masking if configured
+    num_masked_seqs = 0
+    masked_correct_pct = 0.0
+
+    if seq_logprob_error_threshold is not None:
+        print(
+            f"▶ Applying sequence-level logprob error masking (threshold={seq_logprob_error_threshold})...",
+            flush=True,
+        )
+
+        original_sample_mask = sample_mask.clone()
+
+        # Keep only sequences whose error is at or below the threshold
+        seq_error_mask = (
+            seq_mult_prob_error <= seq_logprob_error_threshold
+        ).float() * original_sample_mask
+
+        diff_mask = original_sample_mask - seq_error_mask
+        num_masked_seqs = int(diff_mask.sum().item())
+
+        if num_masked_seqs > 0:
+            diff_mask_bool = diff_mask.bool()
+            masked_correct_count = (rewards.view(-1)[diff_mask_bool] == 1).sum().item()
+            masked_correct_pct = masked_correct_count / num_masked_seqs
+
+        # Update sample_mask in train_data
+        train_data["sample_mask"] = seq_error_mask
+
+        print(
+            f"  Masked {num_masked_seqs} sequences with mult_prob_error > {seq_logprob_error_threshold}",
+            flush=True,
+        )
+        if num_masked_seqs > 0:
+            print(
+                f"  • {masked_correct_count}/{num_masked_seqs} masked sequences were correct (reward=1)"
+                f" → {masked_correct_pct:.2%}",
+                flush=True,
+            )
+
+    return max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct
+
+
 # ===============================================================================
 # Training & Validation
 # ===============================================================================
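Note (not part of the diff): a minimal usage sketch of the new helper, assuming it is importable from this module. A plain dict stands in for BatchedDataDict and the tensor values are made up; it shows the high-error sequence having its sample_mask entry zeroed while the metrics are returned in any case.

import torch

# Toy batch: 2 sequences, 5 tokens each.
train_data = {
    "token_mask": torch.ones(2, 5),
    "sample_mask": torch.ones(2),
    "generation_logprobs": torch.zeros(2, 5),
    # Sequence 0 closely matches generation; sequence 1 is far off.
    "prev_logprobs": torch.tensor(
        [[0.0, -0.05, 0.0, -0.05, 0.0], [0.0, -1.0, -1.0, -1.0, -1.0]]
    ),
}
rewards = torch.tensor([1.0, 0.0])

max_err, n_masked, correct_pct = compute_and_apply_seq_logprob_error_masking(
    train_data, rewards, seq_logprob_error_threshold=1.5
)
print(max_err)                    # ~2.72 for sequence 1, well above the 1.5 threshold
print(n_masked)                   # 1
print(train_data["sample_mask"])  # tensor([1., 0.])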
@@ -1608,6 +1694,17 @@ def grpo_train(
         del logprob_data
         del extra_multimodal_data

+        (
+            max_seq_mult_prob_error,
+            num_masked_seqs,
+            masked_correct_pct,
+        ) = compute_and_apply_seq_logprob_error_masking(
+            train_data=train_data,
+            rewards=rewards,
+            seq_logprob_error_threshold=master_config["grpo"][
+                "seq_logprob_error_threshold"
+            ],
+        )
         # Compute advantages with adv_estimator using correct mask and logprobs
         with timer.time("advantage_calculation"):
             print("▶ Computing advantages...", flush=True)
@@ -1771,6 +1868,11 @@ def grpo_train(
         metrics["generation_logger_metrics"] = generation_logger_metrics
         total_valid_tokens += metrics["global_valid_toks"]

+        # Always log sequence-level error metrics (useful for deciding on a threshold)
+        metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error
+        metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs
+        metrics["masked_correct_pct"] = masked_correct_pct
+
         ## Checkpointing
         consumed_samples += master_config["grpo"]["num_prompts_per_step"]
         timeout.mark_iteration()
@@ -2601,6 +2703,17 @@ def async_grpo_train(
         train_data["prev_logprobs"] = fprop_logprobs
         train_data["reference_policy_logprobs"] = reference_logprobs

+        (
+            max_seq_mult_prob_error,
+            num_masked_seqs,
+            masked_correct_pct,
+        ) = compute_and_apply_seq_logprob_error_masking(
+            train_data=train_data,
+            rewards=rewards,
+            seq_logprob_error_threshold=master_config["grpo"][
+                "seq_logprob_error_threshold"
+            ],
+        )
         # Compute advantages with adv_estimator using correct mask and logprobs
         with timer.time("advantage_calculation"):
             print("▶ Computing advantages...", flush=True)
@@ -2775,6 +2888,11 @@ def async_grpo_train(
         metrics["generation_logger_metrics"] = generation_logger_metrics
         total_valid_tokens += metrics["global_valid_toks"]

+        # Always log sequence-level error metrics (useful for deciding on a threshold)
+        metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error
+        metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs
+        metrics["masked_correct_pct"] = masked_correct_pct
+
         # Checkpointing (same as sync version)
         consumed_samples += master_config["grpo"]["num_prompts_per_step"]
         timeout.mark_iteration()
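Note (not part of the diff): since the new metrics are always logged, one possible way to choose seq_logprob_error_threshold is to run a few steps with it left as None, collect max_seq_mult_prob_error from the logs, and pick a cutoff above the typical range so only outlier sequences get masked. The values and heuristic below are hypothetical.

import statistics

# Hypothetical max_seq_mult_prob_error values taken from training logs.
logged_max_errors = [1.08, 1.12, 1.35, 1.10, 4.70]

typical = statistics.median(logged_max_errors)
# Heuristic: stay well above the typical value so only outliers are masked.
candidate = max(1.5, 2.0 * typical)
print(f"candidate seq_logprob_error_threshold: {candidate:.2f}")  # 2.24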