diff --git a/torchrec/metrics/ne.py b/torchrec/metrics/ne.py index 41f14a92e..784f3a372 100644 --- a/torchrec/metrics/ne.py +++ b/torchrec/metrics/ne.py @@ -55,16 +55,14 @@ def compute_ne( eta: float, allow_missing_label_with_zero_weight: bool = False, ) -> torch.Tensor: - if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): - # If nan were to occur, return a dummy value instead of nan if - # allow_missing_label_with_zero_weight is True - return torch.tensor([eta]) - - # Goes into this block if all elements in weighted_num_samples > 0 - weighted_num_samples = weighted_num_samples.double().clamp(min=eta) - mean_label = pos_labels / weighted_num_samples + clamped_weighted_num_samples = weighted_num_samples.double().clamp(min=eta) + mean_label = pos_labels / clamped_weighted_num_samples ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) - return ce_sum / ce_norm + ne = ce_sum / ce_norm + if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): + # If inf or nan were to occur (zero weighted samples), return eta as a dummy value for those entries instead. 
+ return torch.where(weighted_num_samples > 0, ne, eta) + return ne def compute_logloss( diff --git a/torchrec/metrics/tests/test_ne.py b/torchrec/metrics/tests/test_ne.py index e2dcfc254..badaa9aac 100644 --- a/torchrec/metrics/tests/test_ne.py +++ b/torchrec/metrics/tests/test_ne.py @@ -194,6 +194,20 @@ def test_ne_zero_weights(self) -> None: zero_weights=True, ) + def test_ne_allow_missing_label_with_zero_weight(self) -> None: + eta = 1e-12 + ne = compute_ne( + ce_sum=torch.rand(3), + weighted_num_samples=torch.tensor([3, 0, 2]), + pos_labels=torch.tensor([1, 0, 2]), + neg_labels=torch.tensor([2, 0, 0]), + eta=eta, + allow_missing_label_with_zero_weight=True, + ) + self.assertTrue(torch.all(~ne.isinf())) + self.assertTrue(torch.all(~ne.isnan())) + self.assertTrue(torch.equal(ne.eq(eta), torch.tensor([False, True, False]))) + _logloss_metric_test_helper: Callable[..., None] = partial( metric_test_helper, include_logloss=True )