Skip to content

Commit 3aed387

Browse files
authored
Fix Beta with concentration1=1 gives nan log_prob at value=0 (#2089)
* fix * add custom gradient * simplify tests * another approach * clean up merge * simplyfy with double where trick * use dirichlet * simplify comments * simplify return
1 parent 47335b8 commit 3aed387

File tree

2 files changed

+152
-1
lines changed

2 files changed

+152
-1
lines changed

numpyro/distributions/continuous.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,30 @@ def sample(
216216

217217
@validate_sample
218218
def log_prob(self, value: ArrayLike) -> ArrayLike:
219-
return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1))
219+
# Use double-where trick to avoid NaN gradients at boundary conditions
220+
# Reference: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
221+
is_boundary = (value == 0.0) | (value == 1.0)
222+
223+
# Mask boundary values (0 or 1) to safe value (0.5) for gradient computation
224+
safe_value = jnp.where(is_boundary, 0.5, value)
225+
safe_complement = jnp.where(is_boundary, 0.5, 1.0 - value)
226+
227+
# Compute log_prob with safe values (gradients flow through this path)
228+
safe_dirichlet_value = jnp.stack([safe_value, safe_complement], axis=-1)
229+
safe_log_prob = self._dirichlet.log_prob(safe_dirichlet_value)
230+
231+
# At boundaries, compute correct forward value using xlogy (handles 0*log(0)=0)
232+
# Use stop_gradient so gradients come only from safe_log_prob
233+
correct_value = (
234+
xlogy(self.concentration1 - 1.0, value)
235+
+ xlogy(self.concentration0 - 1.0, 1.0 - value)
236+
- betaln(self.concentration1, self.concentration0)
237+
)
238+
239+
# Apply correction at boundaries, return safe value elsewhere
240+
return jnp.where(
241+
is_boundary, jax.lax.stop_gradient(correct_value), safe_log_prob
242+
)
220243

221244
@property
222245
def mean(self) -> ArrayLike:

test/test_distributions.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4486,3 +4486,131 @@ def test_interval_censored_validate_sample(
44864486
censored_dist.log_prob(value)
44874487
else:
44884488
censored_dist.log_prob(value) # Should not raise
4489+
4490+
4491+
@pytest.mark.parametrize(
4492+
argnames="concentration1,concentration0,value",
4493+
argvalues=[
4494+
(1.0, 8.0, 0.0),
4495+
(8.0, 1.0, 1.0),
4496+
],
4497+
ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
4498+
)
4499+
def test_beta_logprob_edge_cases(concentration1, concentration0, value):
4500+
"""Test Beta distribution with concentration=1 gives finite log probability at boundary."""
4501+
beta_dist = dist.Beta(concentration1, concentration0)
4502+
log_prob = beta_dist.log_prob(value)
4503+
4504+
assert not jnp.isnan(log_prob), (
4505+
f"Beta({concentration1},{concentration0}).log_prob({value}) should not be NaN"
4506+
)
4507+
assert jnp.isfinite(log_prob), (
4508+
f"Beta({concentration1},{concentration0}).log_prob({value}) should be finite"
4509+
)
4510+
4511+
4512+
def test_beta_logprob_edge_case_consistency_small_values():
4513+
"""Test that edge case values are consistent with small deviation values."""
4514+
beta_dist = dist.Beta(1.0, 8.0)
4515+
beta_dist2 = dist.Beta(8.0, 1.0)
4516+
4517+
# At boundary
4518+
log_prob_at_zero = beta_dist.log_prob(0.0)
4519+
log_prob_at_one = beta_dist2.log_prob(1.0)
4520+
4521+
# Very close to boundary
4522+
small_value = 1e-10
4523+
log_prob_small = beta_dist.log_prob(small_value)
4524+
log_prob_close_to_one = beta_dist2.log_prob(1.0 - small_value)
4525+
4526+
# Edge case values should be close to small deviation values
4527+
assert jnp.abs(log_prob_at_zero - log_prob_small) < 1e-5
4528+
assert jnp.abs(log_prob_at_one - log_prob_close_to_one) < 1e-5
4529+
4530+
4531+
def test_beta_logprob_edge_case_non_boundary_values():
4532+
"""Test that Beta with concentration=1 still works for non-boundary values."""
4533+
beta_dist = dist.Beta(1.0, 8.0)
4534+
beta_dist2 = dist.Beta(8.0, 1.0)
4535+
4536+
assert jnp.isfinite(beta_dist.log_prob(0.5))
4537+
assert jnp.isfinite(beta_dist2.log_prob(0.5))
4538+
4539+
4540+
def test_beta_logprob_boundary_non_edge_cases():
4541+
"""Test that non-edge cases (concentration > 1) still give -inf at boundaries."""
4542+
beta_dist3 = dist.Beta(2.0, 8.0)
4543+
beta_dist4 = dist.Beta(8.0, 2.0)
4544+
4545+
assert jnp.isneginf(beta_dist3.log_prob(0.0))
4546+
assert jnp.isneginf(beta_dist4.log_prob(1.0))
4547+
4548+
4549+
@pytest.mark.parametrize(
4550+
argnames="concentration1,concentration0,value,grad_param,grad_value",
4551+
argvalues=[
4552+
(1.0, 8.0, 0.0, "value", 0.0),
4553+
(8.0, 1.0, 1.0, "value", 1.0),
4554+
(1.0, 8.0, 0.0, "concentration1", 1.0),
4555+
(1.0, 8.0, 0.0, "concentration0", 8.0),
4556+
(8.0, 1.0, 1.0, "concentration1", 8.0),
4557+
(8.0, 1.0, 1.0, "concentration0", 1.0),
4558+
],
4559+
ids=[
4560+
"Beta(1,8) at x=0",
4561+
"Beta(8,1) at x=1",
4562+
"Beta(1,8) at concentration1=1",
4563+
"Beta(1,8) at concentration0=8",
4564+
"Beta(8,1) at concentration1=8",
4565+
"Beta(8,1) at concentration0=1",
4566+
],
4567+
)
4568+
def test_beta_gradient_edge_cases_single_param(
4569+
concentration1, concentration0, value, grad_param, grad_value
4570+
):
4571+
"""Test that gradients w.r.t. individual parameters are finite at edge cases."""
4572+
if grad_param == "value":
4573+
4574+
def log_prob_fn(x):
4575+
return dist.Beta(concentration1, concentration0).log_prob(x)
4576+
4577+
grad = jax.grad(log_prob_fn)(value)
4578+
elif grad_param == "concentration1":
4579+
4580+
def log_prob_fn(c1):
4581+
return dist.Beta(c1, concentration0).log_prob(value)
4582+
4583+
grad = jax.grad(log_prob_fn)(grad_value)
4584+
else: # concentration0
4585+
4586+
def log_prob_fn(c0):
4587+
return dist.Beta(concentration1, c0).log_prob(value)
4588+
4589+
grad = jax.grad(log_prob_fn)(grad_value)
4590+
4591+
assert jnp.isfinite(grad), (
4592+
f"Gradient w.r.t. {grad_param} for Beta({concentration1},{concentration0}) "
4593+
f"at x={value} should be finite"
4594+
)
4595+
4596+
4597+
@pytest.mark.parametrize(
4598+
argnames="concentration1,concentration0,value",
4599+
argvalues=[
4600+
(1.0, 8.0, 0.0),
4601+
(8.0, 1.0, 1.0),
4602+
],
4603+
ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
4604+
)
4605+
def test_beta_gradient_edge_cases_all_params(concentration1, concentration0, value):
4606+
"""Test that all gradients are finite when computed simultaneously at edge cases."""
4607+
4608+
def log_prob_fn(params):
4609+
c1, c0, v = params
4610+
return dist.Beta(c1, c0).log_prob(v)
4611+
4612+
grads = jax.grad(log_prob_fn)(jnp.array([concentration1, concentration0, value]))
4613+
assert jnp.all(jnp.isfinite(grads)), (
4614+
f"All gradients for Beta({concentration1},{concentration0}) at x={value} "
4615+
f"should be finite"
4616+
)

0 commit comments

Comments
 (0)