diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index da7d911119..11aebb9e84 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -37,6 +37,7 @@ grpo:
     source_max: 1.0
     target_min: 0.0
     target_max: 1.0
+  seq_logprob_error_threshold: null
   async_grpo:
     enabled: false # Set to true to enable async training mode
diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml
index 5d0c1aae2a..a6522d6c8e 100644
--- a/examples/configs/vlm_grpo_3B.yaml
+++ b/examples/configs/vlm_grpo_3B.yaml
@@ -38,6 +38,7 @@ grpo:
   async_grpo:
     enabled: false
     max_trajectory_age_steps: 1
+  seq_logprob_error_threshold: null

 loss_fn:
   reference_policy_kl_penalty: 0.01
diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml
index 6d9016503a..a38b6e15a8 100644
--- a/examples/configs/vlm_grpo_3B_megatron.yaml
+++ b/examples/configs/vlm_grpo_3B_megatron.yaml
@@ -36,6 +36,7 @@ grpo:
   async_grpo:
     enabled: false
     max_trajectory_age_steps: 1
+  seq_logprob_error_threshold: null

 loss_fn:
   reference_policy_kl_penalty: 0.01 # Can be set to k1, k2, k3
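All three configs ship with the new key set to null, i.e. masking disabled and only the metrics logged. To actually enable it, a run would override the grpo section, for example (an illustrative override, not part of this patch; 1.5 is the reference scale named in the GRPOConfig comment below):

    grpo:
      # Drop sequences whose per-sequence multiplicative logprob error exceeds 1.5
      seq_logprob_error_threshold: 1.5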
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index bc56d21326..d349500516 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -162,6 +162,9 @@ class GRPOConfig(TypedDict):
     reward_scaling: RewardScalingConfig
     # By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation.
    calculate_advantages_on_gpu: NotRequired[bool]
+    # Sequence-level logprob error masking for training stability. If set, mask sequences with mult_prob_error exceeding this threshold (same scale as the token_mult_prob_error metric, e.g. 1.5).
+    # Note that this is slightly different from Masked Importance Sampling (MIS): it uses the absolute value of the difference between the training and generation logprobs, whereas MIS uses the signed difference.
+    seq_logprob_error_threshold: float | None
     # Advantage estimator configuration (grpo or reinforce_plus_plus)
     adv_estimator: NotRequired[AdvEstimatorConfig]
@@ -1161,6 +1164,89 @@ def _log_mixed_rewards_and_advantages_information(
     metrics["advantages/mean"] = advantages.float().mean().item()


+def compute_and_apply_seq_logprob_error_masking(
+    train_data: BatchedDataDict,
+    rewards: torch.Tensor,
+    seq_logprob_error_threshold: Optional[float],
+) -> tuple[float, int, float]:
+    """Compute sequence-level logprob error metrics and optionally mask high-error sequences.
+
+    This function computes the multiplicative probability error per sequence
+    (the same calculation as token_mult_prob_error, but aggregated per sequence)
+    and optionally masks sequences that exceed the configured threshold.
+
+    Args:
+        train_data: Training data dict containing token_mask, sample_mask,
+            prev_logprobs, and generation_logprobs. If masking is applied,
+            sample_mask is updated in place.
+        rewards: Reward tensor for computing statistics on masked sequences.
+        seq_logprob_error_threshold: If set, mask sequences with mult_prob_error
+            exceeding this threshold. If None, only compute metrics.
+
+    Returns:
+        Tuple of (max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct).
+    """
+    # Compute sequence-level logprob error metrics (always)
+    token_mask = train_data["token_mask"][:, 1:]
+    sample_mask = train_data["sample_mask"]
+    prev_logprobs = train_data["prev_logprobs"][:, 1:]
+    generation_logprobs = train_data["generation_logprobs"][:, 1:]
+    lp_error = torch.abs(generation_logprobs - prev_logprobs)
+
+    # Use the combined mask exactly as in the loss function
+    mask = token_mask * sample_mask.unsqueeze(-1)
+
+    # Calculate the sequence-level multiplicative prob error:
+    # the exact same calculation as token_mult_prob_error, but per sequence
+    seq_mult_prob_error = (torch.exp(lp_error * mask) * mask).sum(dim=-1) / mask.sum(
+        dim=-1
+    ).clamp(min=1)
+    max_seq_mult_prob_error = (
+        seq_mult_prob_error.max().item() if seq_mult_prob_error.numel() > 0 else 0.0
+    )
+
+    # Apply sequence-level masking if configured
+    num_masked_seqs = 0
+    masked_correct_pct = 0.0
+
+    if seq_logprob_error_threshold is not None:
+        print(
+            f"▶ Applying sequence-level logprob error masking (threshold={seq_logprob_error_threshold})...",
+            flush=True,
+        )
+
+        original_sample_mask = sample_mask.clone()
+
+        # Create mask for sequences at or below the threshold
+        seq_error_mask = (
+            seq_mult_prob_error <= seq_logprob_error_threshold
+        ).float() * original_sample_mask
+
+        diff_mask = original_sample_mask - seq_error_mask
+        num_masked_seqs = int(diff_mask.sum().item())
+
+        if num_masked_seqs > 0:
+            diff_mask_bool = diff_mask.bool()
+            masked_correct_count = (rewards.view(-1)[diff_mask_bool] == 1).sum().item()
+            masked_correct_pct = masked_correct_count / num_masked_seqs
+
+        # Update sample_mask in train_data
+        train_data["sample_mask"] = seq_error_mask
+
+        print(
+            f"  Masked {num_masked_seqs} sequences with mult_prob_error > {seq_logprob_error_threshold}",
+            flush=True,
+        )
+        if num_masked_seqs > 0:
+            print(
+                f"  • {masked_correct_count}/{num_masked_seqs} masked sequences were correct (reward=1)"
+                f" → {masked_correct_pct:.2%}",
+                flush=True,
+            )
+
+    return max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct
+
+
 # ===============================================================================
 # Training & Validation
 # ===============================================================================
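For reference, the quantity the function computes for each sequence $i$, with $m_{i,t}$ the combined token/sample mask, is

    e_i = \frac{\sum_t \exp\big(\,\big|\log p^{\mathrm{gen}}_{i,t} - \log p^{\mathrm{train}}_{i,t}\big| \cdot m_{i,t}\big)\, m_{i,t}}{\max\big(\sum_t m_{i,t},\ 1\big)}

and sequence $i$ is dropped when $e_i$ exceeds seq_logprob_error_threshold. When the training and generation logprobs agree exactly, every valid token contributes $\exp(0)=1$, so $e_i = 1$; meaningful thresholds therefore sit somewhat above 1 (the config comment above uses 1.5 as its reference scale).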
@@ -1608,6 +1694,17 @@ def grpo_train(
         del logprob_data
         del extra_multimodal_data

+        (
+            max_seq_mult_prob_error,
+            num_masked_seqs,
+            masked_correct_pct,
+        ) = compute_and_apply_seq_logprob_error_masking(
+            train_data=train_data,
+            rewards=rewards,
+            seq_logprob_error_threshold=master_config["grpo"][
+                "seq_logprob_error_threshold"
+            ],
+        )
         # Compute advantages with adv_estimator using correct mask and logprobs
         with timer.time("advantage_calculation"):
             print("▶ Computing advantages...", flush=True)
@@ -1771,6 +1868,11 @@
             metrics["generation_logger_metrics"] = generation_logger_metrics
         total_valid_tokens += metrics["global_valid_toks"]

+        # Always log sequence-level error metrics (useful for deciding on a threshold)
+        metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error
+        metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs
+        metrics["masked_correct_pct"] = masked_correct_pct
+
         ## Checkpointing
         consumed_samples += master_config["grpo"]["num_prompts_per_step"]
         timeout.mark_iteration()
@@ -2601,6 +2703,17 @@ def async_grpo_train(
         train_data["prev_logprobs"] = fprop_logprobs
         train_data["reference_policy_logprobs"] = reference_logprobs

+        (
+            max_seq_mult_prob_error,
+            num_masked_seqs,
+            masked_correct_pct,
+        ) = compute_and_apply_seq_logprob_error_masking(
+            train_data=train_data,
+            rewards=rewards,
+            seq_logprob_error_threshold=master_config["grpo"][
+                "seq_logprob_error_threshold"
+            ],
+        )
         # Compute advantages with adv_estimator using correct mask and logprobs
         with timer.time("advantage_calculation"):
             print("▶ Computing advantages...", flush=True)
@@ -2775,6 +2888,11 @@
             metrics["generation_logger_metrics"] = generation_logger_metrics
         total_valid_tokens += metrics["global_valid_toks"]

+        # Always log sequence-level error metrics (useful for deciding on a threshold)
+        metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error
+        metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs
+        metrics["masked_correct_pct"] = masked_correct_pct
+
         # Checkpointing (same as sync version)
         consumed_samples += master_config["grpo"]["num_prompts_per_step"]
         timeout.mark_iteration()
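To make the reduction concrete, here is a minimal self-contained sketch (toy values invented for illustration; it mirrors the slicing and masking in compute_and_apply_seq_logprob_error_masking above):

    import torch

    # Toy batch: 2 sequences, 5 positions (position 0 is dropped by the [:, 1:] slice)
    generation_logprobs = torch.tensor(
        [[0.0, -1.0, -1.0, -1.0, -1.0], [0.0, -1.0, -2.0, -1.0, -2.0]]
    )
    prev_logprobs = torch.tensor(
        [[0.0, -1.0, -1.0, -1.0, -1.0],  # identical to generation -> error of 1.0x
         [0.0, -1.5, -1.5, -1.5, -1.5]]  # off by 0.5 nats per token
    )
    token_mask = torch.ones(2, 5)[:, 1:]
    sample_mask = torch.ones(2)

    lp_error = torch.abs(generation_logprobs[:, 1:] - prev_logprobs[:, 1:])
    mask = token_mask * sample_mask.unsqueeze(-1)
    seq_mult_prob_error = (torch.exp(lp_error * mask) * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    print(seq_mult_prob_error)  # tensor([1.0000, 1.6487]) -- exp(0) = 1, exp(0.5) ≈ 1.649

With seq_logprob_error_threshold: 1.5, the second sequence would have its sample_mask entry zeroed (excluded from the loss) while the first is kept.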
mock_grpo_components["policy"].train.call_count == 12 @@ -1440,21 +1459,25 @@ def test_grpo_exit_on_max_epochs(mock_grpo_components, train_func): ) as mock_async_rollout: mock_async_rollout.return_value = (mock_batch, mock_rollout_metrics) - # Run training - train_func( - mock_grpo_components["policy"], - None, # policy_generation - mock_grpo_components["train_dataloader"], - mock_grpo_components["val_dataloader"], - mock_grpo_components["tokenizer"], - mock_grpo_components["loss_fn"], - mock_grpo_components["task_to_env"], - mock_grpo_components["val_task_to_env"], - mock_grpo_components["logger"], - mock_grpo_components["checkpointer"], - grpo_save_state, - mock_grpo_components["master_config"], - ) + with patch( + "nemo_rl.algorithms.grpo.compute_and_apply_seq_logprob_error_masking", + return_value=(0.0, 0, 0.0), + ): + # Run training + train_func( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) # Verify we trained for exactly two epochs (20 batches) assert mock_grpo_components["policy"].train.call_count == 20 @@ -1515,20 +1538,24 @@ def test_grpo_exit_on_timeout(mock_grpo_components, train_func, capsys): "nemo_rl.algorithms.grpo.run_async_multi_turn_rollout", return_value=(mock_batch, mock_rollout_metrics), ): - train_func( - mock_grpo_components["policy"], - None, # policy_generation - mock_grpo_components["train_dataloader"], - mock_grpo_components["val_dataloader"], - mock_grpo_components["tokenizer"], - mock_grpo_components["loss_fn"], - mock_grpo_components["task_to_env"], - mock_grpo_components["val_task_to_env"], - mock_grpo_components["logger"], - mock_grpo_components["checkpointer"], - grpo_save_state, - mock_grpo_components["master_config"], - ) + with patch( + "nemo_rl.algorithms.grpo.compute_and_apply_seq_logprob_error_masking", + return_value=(0.0, 0, 0.0), + ): + train_func( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) # Verify training stopped at 8 steps (when check_save returned True) assert mock_grpo_components["policy"].train.call_count == 8 @@ -2008,3 +2035,329 @@ def test_validate_returns_empty_when_no_dataloader(self): assert val_metrics == {} assert timing == {} + + +# ============================================================================ +# Tests for compute_and_apply_seq_logprob_error_masking function +# ============================================================================ + + +class TestComputeAndApplySeqLogprobErrorMasking: + """Tests for the compute_and_apply_seq_logprob_error_masking function.""" + + def _create_train_data( + self, + batch_size: int, + seq_length: int, + prev_logprobs: torch.Tensor, + generation_logprobs: torch.Tensor, + token_mask: torch.Tensor = None, + sample_mask: torch.Tensor = None, + ) -> BatchedDataDict: + """Helper to create mock train_data for 
testing.""" + if token_mask is None: + token_mask = torch.ones(batch_size, seq_length) + if sample_mask is None: + sample_mask = torch.ones(batch_size) + + return BatchedDataDict( + { + "token_mask": token_mask, + "sample_mask": sample_mask, + "prev_logprobs": prev_logprobs, + "generation_logprobs": generation_logprobs, + } + ) + + def test_no_threshold_only_computes_metrics(self): + """Test that when threshold is None, only metrics are computed (no masking).""" + batch_size, seq_length = 4, 10 + + # Create logprobs with varying errors + prev_logprobs = torch.zeros(batch_size, seq_length) + generation_logprobs = torch.zeros(batch_size, seq_length) + # Add small errors to sequences + generation_logprobs[0, 1:5] = 0.1 # Small error + generation_logprobs[1, 1:5] = 0.5 # Medium error + generation_logprobs[2, 1:5] = 1.0 # Large error + generation_logprobs[3, 1:5] = 2.0 # Very large error + + train_data = self._create_train_data( + batch_size, seq_length, prev_logprobs, generation_logprobs + ) + rewards = torch.tensor([1.0, 0.0, 1.0, 0.0]) + original_sample_mask = train_data["sample_mask"].clone() + + max_error, num_masked, masked_pct = compute_and_apply_seq_logprob_error_masking( + train_data, rewards, seq_logprob_error_threshold=None + ) + + # Verify metrics are computed + assert max_error > 0.0, "Should compute max error" + assert num_masked == 0, "Should not mask any sequences when threshold is None" + assert masked_pct == 0.0, "Should have 0% masked" + # Verify sample_mask is unchanged + assert torch.equal(train_data["sample_mask"], original_sample_mask) + + def test_masking_with_threshold(self): + """Test that sequences exceeding threshold are masked.""" + batch_size, seq_length = 4, 10 + + # Create logprobs with specific errors + # Note: The metric is averaged over all tokens, so errors get diluted. + # Formula: seq_mult_prob_error = sum(exp(error) * mask) / sum(mask) + # With seq_length=10 and slicing [:, 1:], we have 9 tokens per sequence. 
+
+    def test_no_sequences_masked_when_all_below_threshold(self):
+        """Test that no sequences are masked when all are below the threshold."""
+        batch_size, seq_length = 3, 8
+
+        # Create logprobs with small errors (all below threshold)
+        prev_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs[:, 1:4] = 0.05  # Very small error for all
+
+        train_data = self._create_train_data(
+            batch_size, seq_length, prev_logprobs, generation_logprobs
+        )
+        rewards = torch.tensor([1.0, 1.0, 1.0])
+        original_sample_mask = train_data["sample_mask"].clone()
+
+        _max_error, num_masked, masked_pct = (
+            compute_and_apply_seq_logprob_error_masking(
+                train_data, rewards, seq_logprob_error_threshold=2.0
+            )
+        )
+
+        # Verify no masking occurred
+        assert num_masked == 0, "Should not mask any sequences"
+        assert masked_pct == 0.0
+        # All sequences should remain in sample_mask
+        assert torch.equal(train_data["sample_mask"], original_sample_mask)
+
+    def test_all_sequences_masked_when_all_above_threshold(self):
+        """Test that all sequences are masked when all exceed the threshold."""
+        batch_size, seq_length = 3, 8
+
+        # Create logprobs with large errors (all above threshold)
+        prev_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs[:, 1:4] = 1.0  # Large error for all (exp(1) ≈ 2.7)
+
+        train_data = self._create_train_data(
+            batch_size, seq_length, prev_logprobs, generation_logprobs
+        )
+        rewards = torch.tensor([1.0, 0.0, 1.0])  # 2 correct, 1 incorrect
+
+        _max_error, num_masked, masked_pct = (
+            compute_and_apply_seq_logprob_error_masking(
+                train_data, rewards, seq_logprob_error_threshold=1.0
+            )
+        )
+
+        # Verify all sequences are masked
+        assert num_masked == 3, "Should mask all 3 sequences"
+        assert masked_pct == pytest.approx(2 / 3, rel=1e-5), (
+            "2/3 of masked should be correct"
+        )
+        # All sequences should be zeroed in sample_mask
+        assert torch.equal(train_data["sample_mask"], torch.zeros(batch_size))
+
+    def test_respects_existing_sample_mask(self):
+        """Test that masking respects already-masked sequences in sample_mask."""
+        batch_size, seq_length = 4, 8
+
+        # Create logprobs with large errors
+        prev_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs[:, 1:4] = 1.0  # Large error for all
+
+        # Pre-mask sequence 1 (it's already excluded)
+        sample_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
+
+        train_data = self._create_train_data(
+            batch_size,
+            seq_length,
+            prev_logprobs,
+            generation_logprobs,
+            sample_mask=sample_mask,
+        )
+        rewards = torch.tensor([1.0, 1.0, 0.0, 1.0])
+
+        _max_error, num_masked, masked_pct = (
+            compute_and_apply_seq_logprob_error_masking(
+                train_data, rewards, seq_logprob_error_threshold=1.0
+            )
+        )
+
+        # Only 3 sequences were originally unmasked; all of them should be masked now
+        assert num_masked == 3, "Should mask 3 sequences (indices 0, 2, 3)"
+        # Sequences 0 and 3 had reward=1, sequence 2 had reward=0
+        assert masked_pct == pytest.approx(2 / 3, rel=1e-5), (
+            "2/3 of newly masked should be correct"
+        )
+        # All should be zeroed (including the already-masked sequence 1)
+        assert torch.equal(train_data["sample_mask"], torch.zeros(batch_size))
+
+    def test_masked_correct_pct_calculation(self):
+        """Test that masked_correct_pct is calculated correctly."""
+        batch_size, seq_length = 5, 8
+
+        prev_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs = torch.zeros(batch_size, seq_length)
+        # Make sequences 2, 3, 4 have high error (they will be masked)
+        generation_logprobs[2:5, 1:4] = 1.5
+
+        train_data = self._create_train_data(
+            batch_size, seq_length, prev_logprobs, generation_logprobs
+        )
+        # Rewards: seq 2 correct, seq 3 incorrect, seq 4 correct
+        rewards = torch.tensor([0.0, 0.0, 1.0, 0.0, 1.0])
+
+        _max_error, num_masked, masked_pct = (
+            compute_and_apply_seq_logprob_error_masking(
+                train_data, rewards, seq_logprob_error_threshold=1.5
+            )
+        )
+
+        assert num_masked == 3, "Should mask 3 sequences"
+        # 2 out of 3 masked sequences were correct (reward=1)
+        assert masked_pct == pytest.approx(2 / 3, rel=1e-5), (
+            "2/3 of masked should be correct"
+        )
+
+    def test_token_mask_is_respected(self):
+        """Test that token_mask affects the error calculation correctly."""
+        batch_size, seq_length = 2, 8
+
+        prev_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs = torch.zeros(batch_size, seq_length)
+        # Add a large error to both sequences at positions 1:6
+        generation_logprobs[:, 1:6] = 1.0
+
+        # But mask out tokens 3-5 for sequence 0 (reducing its effective error).
+        # After the [:, 1:] slice, this affects positions 2-4 of the 7-token sequence.
+        token_mask = torch.ones(batch_size, seq_length)
+        token_mask[0, 3:6] = 0.0  # Mask out high-error tokens for seq 0
+
+        # After the [:, 1:] slice and accounting for token_mask:
+        # Seq 0: 4 valid tokens (positions 0, 1, 5, 6), 2 with error -> avg ≈ 1.859
+        # Seq 1: 7 valid tokens, 5 with error -> avg ≈ 2.227
+        # Use threshold 2.0 so seq 0 passes but seq 1 fails
+
+        train_data = self._create_train_data(
+            batch_size,
+            seq_length,
+            prev_logprobs,
+            generation_logprobs,
+            token_mask=token_mask,
+        )
+        rewards = torch.tensor([1.0, 0.0])
+
+        # Sequence 0 should have lower error due to its masked tokens;
+        # sequence 1 should have higher error
+        _max_error, num_masked, masked_pct = (
+            compute_and_apply_seq_logprob_error_masking(
+                train_data, rewards, seq_logprob_error_threshold=2.0
+            )
+        )
+
+        # Only sequence 1 should be masked (seq 0 has reduced error due to token_mask)
+        assert num_masked == 1, "Should mask only sequence 1"
+        assert masked_pct == 0.0, "Masked sequence had reward=0"
+        assert train_data["sample_mask"][0] == 1.0, "Sequence 0 should remain unmasked"
+        assert train_data["sample_mask"][1] == 0.0, "Sequence 1 should be masked"
+
+    def test_empty_batch_returns_zero_metrics(self):
+        """Test handling of the edge case of an empty batch."""
+        # Create empty train_data
+        train_data = BatchedDataDict(
+            {
+                "token_mask": torch.zeros(0, 8),
+                "sample_mask": torch.zeros(0),
+                "prev_logprobs": torch.zeros(0, 8),
+                "generation_logprobs": torch.zeros(0, 8),
+            }
+        )
+        rewards = torch.zeros(0)
+
+        max_error, num_masked, masked_pct = compute_and_apply_seq_logprob_error_masking(
+            train_data, rewards, seq_logprob_error_threshold=1.5
+        )
+
+        assert max_error == 0.0, "Empty batch should have max_error=0"
+        assert num_masked == 0, "Empty batch should have no masked sequences"
+        assert masked_pct == 0.0, "Empty batch should have 0% masked"
+
+    def test_threshold_boundary_values(self):
+        """Test behavior at the exact threshold boundary."""
+        batch_size, seq_length = 3, 8
+
+        # Create logprobs where the error sits right around the threshold
+        prev_logprobs = torch.zeros(batch_size, seq_length)
+        generation_logprobs = torch.zeros(batch_size, seq_length)
+
+        # With a single valid token, the sequence-level mult_prob_error is ≈ exp(error):
+        # error=0.4  -> mult_prob_error ≈ exp(0.4)  ≈ 1.49
+        # error=0.41 -> mult_prob_error ≈ exp(0.41) ≈ 1.51
+        generation_logprobs[0, 1] = 0.4  # Below threshold 1.5
+        generation_logprobs[1, 1] = 0.405  # Very close to the threshold
+        generation_logprobs[2, 1] = 0.41  # Just above threshold 1.5
+
+        # Only consider position 1 as a valid token
+        token_mask = torch.zeros(batch_size, seq_length)
+        token_mask[:, 1] = 1.0
+
+        train_data = self._create_train_data(
+            batch_size,
+            seq_length,
+            prev_logprobs,
+            generation_logprobs,
+            token_mask=token_mask,
+        )
+        rewards = torch.tensor([1.0, 1.0, 1.0])
+
+        # A threshold of 1.5 should mask sequence 2 (exp(0.41) > 1.5)
+        max_error, num_masked, masked_pct = compute_and_apply_seq_logprob_error_masking(
+            train_data, rewards, seq_logprob_error_threshold=1.5
+        )
+
+        # At least sequence 2 should be masked
+        assert num_masked >= 1, "At least one sequence should be masked"
+        assert train_data["sample_mask"][0] == 1.0, "Sequence 0 should be kept"
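The new tests build their BatchedDataDict inputs directly and need no Ray or policy mocks, so they can be run in isolation; assuming a standard pytest setup, something like:

    # Hypothetical invocation; any pytest runner targeting the new class works
    import pytest

    pytest.main(
        ["tests/unit/algorithms/test_grpo.py", "-k", "TestComputeAndApplySeqLogprobErrorMasking"]
    )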