@@ -4486,3 +4486,131 @@ def test_interval_censored_validate_sample(
44864486 censored_dist .log_prob (value )
44874487 else :
44884488 censored_dist .log_prob (value ) # Should not raise
4489+
4490+
4491+ @pytest .mark .parametrize (
4492+ argnames = "concentration1,concentration0,value" ,
4493+ argvalues = [
4494+ (1.0 , 8.0 , 0.0 ),
4495+ (8.0 , 1.0 , 1.0 ),
4496+ ],
4497+ ids = ["Beta(1,8) at x=0" , "Beta(8,1) at x=1" ],
4498+ )
4499+ def test_beta_logprob_edge_cases (concentration1 , concentration0 , value ):
4500+ """Test Beta distribution with concentration=1 gives finite log probability at boundary."""
4501+ beta_dist = dist .Beta (concentration1 , concentration0 )
4502+ log_prob = beta_dist .log_prob (value )
4503+
4504+ assert not jnp .isnan (log_prob ), (
4505+ f"Beta({ concentration1 } ,{ concentration0 } ).log_prob({ value } ) should not be NaN"
4506+ )
4507+ assert jnp .isfinite (log_prob ), (
4508+ f"Beta({ concentration1 } ,{ concentration0 } ).log_prob({ value } ) should be finite"
4509+ )
4510+
4511+
4512+ def test_beta_logprob_edge_case_consistency_small_values ():
4513+ """Test that edge case values are consistent with small deviation values."""
4514+ beta_dist = dist .Beta (1.0 , 8.0 )
4515+ beta_dist2 = dist .Beta (8.0 , 1.0 )
4516+
4517+ # At boundary
4518+ log_prob_at_zero = beta_dist .log_prob (0.0 )
4519+ log_prob_at_one = beta_dist2 .log_prob (1.0 )
4520+
4521+ # Very close to boundary
4522+ small_value = 1e-10
4523+ log_prob_small = beta_dist .log_prob (small_value )
4524+ log_prob_close_to_one = beta_dist2 .log_prob (1.0 - small_value )
4525+
4526+ # Edge case values should be close to small deviation values
4527+ assert jnp .abs (log_prob_at_zero - log_prob_small ) < 1e-5
4528+ assert jnp .abs (log_prob_at_one - log_prob_close_to_one ) < 1e-5
4529+
4530+
4531+ def test_beta_logprob_edge_case_non_boundary_values ():
4532+ """Test that Beta with concentration=1 still works for non-boundary values."""
4533+ beta_dist = dist .Beta (1.0 , 8.0 )
4534+ beta_dist2 = dist .Beta (8.0 , 1.0 )
4535+
4536+ assert jnp .isfinite (beta_dist .log_prob (0.5 ))
4537+ assert jnp .isfinite (beta_dist2 .log_prob (0.5 ))
4538+
4539+
4540+ def test_beta_logprob_boundary_non_edge_cases ():
4541+ """Test that non-edge cases (concentration > 1) still give -inf at boundaries."""
4542+ beta_dist3 = dist .Beta (2.0 , 8.0 )
4543+ beta_dist4 = dist .Beta (8.0 , 2.0 )
4544+
4545+ assert jnp .isneginf (beta_dist3 .log_prob (0.0 ))
4546+ assert jnp .isneginf (beta_dist4 .log_prob (1.0 ))
4547+
4548+
4549+ @pytest .mark .parametrize (
4550+ argnames = "concentration1,concentration0,value,grad_param,grad_value" ,
4551+ argvalues = [
4552+ (1.0 , 8.0 , 0.0 , "value" , 0.0 ),
4553+ (8.0 , 1.0 , 1.0 , "value" , 1.0 ),
4554+ (1.0 , 8.0 , 0.0 , "concentration1" , 1.0 ),
4555+ (1.0 , 8.0 , 0.0 , "concentration0" , 8.0 ),
4556+ (8.0 , 1.0 , 1.0 , "concentration1" , 8.0 ),
4557+ (8.0 , 1.0 , 1.0 , "concentration0" , 1.0 ),
4558+ ],
4559+ ids = [
4560+ "Beta(1,8) at x=0" ,
4561+ "Beta(8,1) at x=1" ,
4562+ "Beta(1,8) at concentration1=1" ,
4563+ "Beta(1,8) at concentration0=8" ,
4564+ "Beta(8,1) at concentration1=8" ,
4565+ "Beta(8,1) at concentration0=1" ,
4566+ ],
4567+ )
4568+ def test_beta_gradient_edge_cases_single_param (
4569+ concentration1 , concentration0 , value , grad_param , grad_value
4570+ ):
4571+ """Test that gradients w.r.t. individual parameters are finite at edge cases."""
4572+ if grad_param == "value" :
4573+
4574+ def log_prob_fn (x ):
4575+ return dist .Beta (concentration1 , concentration0 ).log_prob (x )
4576+
4577+ grad = jax .grad (log_prob_fn )(value )
4578+ elif grad_param == "concentration1" :
4579+
4580+ def log_prob_fn (c1 ):
4581+ return dist .Beta (c1 , concentration0 ).log_prob (value )
4582+
4583+ grad = jax .grad (log_prob_fn )(grad_value )
4584+ else : # concentration0
4585+
4586+ def log_prob_fn (c0 ):
4587+ return dist .Beta (concentration1 , c0 ).log_prob (value )
4588+
4589+ grad = jax .grad (log_prob_fn )(grad_value )
4590+
4591+ assert jnp .isfinite (grad ), (
4592+ f"Gradient w.r.t. { grad_param } for Beta({ concentration1 } ,{ concentration0 } ) "
4593+ f"at x={ value } should be finite"
4594+ )
4595+
4596+
4597+ @pytest .mark .parametrize (
4598+ argnames = "concentration1,concentration0,value" ,
4599+ argvalues = [
4600+ (1.0 , 8.0 , 0.0 ),
4601+ (8.0 , 1.0 , 1.0 ),
4602+ ],
4603+ ids = ["Beta(1,8) at x=0" , "Beta(8,1) at x=1" ],
4604+ )
4605+ def test_beta_gradient_edge_cases_all_params (concentration1 , concentration0 , value ):
4606+ """Test that all gradients are finite when computed simultaneously at edge cases."""
4607+
4608+ def log_prob_fn (params ):
4609+ c1 , c0 , v = params
4610+ return dist .Beta (c1 , c0 ).log_prob (v )
4611+
4612+ grads = jax .grad (log_prob_fn )(jnp .array ([concentration1 , concentration0 , value ]))
4613+ assert jnp .all (jnp .isfinite (grads )), (
4614+ f"All gradients for Beta({ concentration1 } ,{ concentration0 } ) at x={ value } "
4615+ f"should be finite"
4616+ )
0 commit comments