Add GrassiaIIGeometric Distribution #528


Open

ColtAllen wants to merge 41 commits into pymc-devs:main from ColtAllen:grassia2geo-dist
Commits (41)
269dd75
dist and rv init commit
ColtAllen Mar 29, 2025
b264161
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Apr 11, 2025
d734c68
docstrings
ColtAllen Apr 15, 2025
71bd632
Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
ColtAllen Apr 15, 2025
48e93f3
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Apr 15, 2025
93c4a60
unit tests
ColtAllen Apr 20, 2025
d2e72b5
alpha min value
ColtAllen Apr 20, 2025
8685005
revert alpha lim
ColtAllen Apr 21, 2025
026f182
small lam value tests
ColtAllen Apr 22, 2025
d12dd0b
ruff formatting
ColtAllen Apr 22, 2025
bcd9cac
TODOs
ColtAllen Apr 22, 2025
78be107
WIP add covar support to RV
ColtAllen Apr 22, 2025
f3ae359
Merge branch 'main' into grassia2geo-dist
ColtAllen Jun 20, 2025
8a30459
WIP time indexing
ColtAllen Jun 20, 2025
7c7afc8
WIP time indexing
ColtAllen Jun 20, 2025
fa9c1ec
Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
ColtAllen Jun 20, 2025
b957333
WIP symbolic indexing
ColtAllen Jun 20, 2025
d0c1d98
delete test_simple.py
ColtAllen Jun 20, 2025
264c55e
fix symbolic indexing errors
ColtAllen Jul 11, 2025
05e7c55
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Jul 11, 2025
0fa3390
clean up cursor code
ColtAllen Jul 11, 2025
5baa6f7
warn for ndims deprecation
ColtAllen Jul 11, 2025
a715ec7
clean up comments and final TODO
ColtAllen Jul 11, 2025
f3c0f29
remove ndims deprecation and extraneous code
ColtAllen Jul 11, 2025
a232e4c
revert changes to irrelevant test
ColtAllen Jul 12, 2025
ffc059f
remove time_covariate_vector default args
ColtAllen Jul 12, 2025
1d41eb7
revert remaining changes in irrelevant tests
ColtAllen Jul 12, 2025
47ad523
remove test_sampling_consistency
ColtAllen Jul 12, 2025
5b77263
checkpoint commit for log_cdf and test frameworks
ColtAllen Jul 12, 2025
eb7222f
checkpoint commit for log_cdf and test frameworks
ColtAllen Jul 12, 2025
b34e3d8
make C_t external function, code cleanup
ColtAllen Jul 12, 2025
9803321
rng_fn cleanup
ColtAllen Jul 13, 2025
5ff6853
WIP test frameworks
ColtAllen Jul 13, 2025
63a0b10
inverse cdf
ColtAllen Jul 15, 2025
932a046
covariate pos constraint and WIP RV
ColtAllen Jul 15, 2025
b78a5c4
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Jul 28, 2025
ab45a9c
WIP rng_fn testing
ColtAllen Jul 28, 2025
0d1dcea
WIP time covars required param
ColtAllen Jul 29, 2025
434e5a5
C_t for RV time covar support
ColtAllen Aug 10, 2025
c66c8a6
time_covar optional param
ColtAllen Aug 10, 2025
fb96220
restore GPT5 code
ColtAllen Aug 13, 2025
2 changes: 2 additions & 0 deletions pymc_extras/distributions/__init__.py
@@ -22,6 +22,7 @@
BetaNegativeBinomial,
GeneralizedPoisson,
Skellam,
GrassiaIIGeometric,
)
from pymc_extras.distributions.histogram_utils import histogram_approximation
from pymc_extras.distributions.multivariate import R2D2M2CP
@@ -38,5 +39,6 @@
"R2D2M2CP",
"Skellam",
"histogram_approximation",
"GrassiaIIGeometric",
"PartialOrder",
]
223 changes: 223 additions & 0 deletions pymc_extras/distributions/discrete.py
@@ -16,6 +16,7 @@
import pymc as pm

from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
from pymc.distributions.distribution import Discrete
from pymc.distributions.shape_utils import rv_size_is_none
from pytensor import tensor as pt
from pytensor.tensor.random.op import RandomVariable
@@ -399,3 +400,225 @@ def dist(cls, mu1, mu2, **kwargs):
class_name="Skellam",
**kwargs,
)


class GrassiaIIGeometricRV(RandomVariable):
name = "g2g"
signature = "(),(),(t)->()"

dtype = "int64"
_print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")

@classmethod
def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
# Aggregate time covariates for each sample before broadcasting
time_cov = np.asarray(time_covariate_vector)
if np.ndim(time_cov) == 0:
exp_time_covar = np.asarray(1.0)
else:
# Collapse all time/feature axes to a scalar multiplier for RNG
exp_time_covar = np.asarray(np.exp(time_cov).sum())

# Determine output size
if size is None:
size = np.broadcast_shapes(r.shape, alpha.shape, exp_time_covar.shape)

# Broadcast parameters to output size
r = np.broadcast_to(r, size)
alpha = np.broadcast_to(alpha, size)
exp_time_covar = np.broadcast_to(exp_time_covar, size)

lam = rng.gamma(shape=r, scale=1 / alpha, size=size)

lam_covar = lam * exp_time_covar

p = 1 - np.exp(-lam_covar)
# TODO: Clipping to the interval (0, 1] is a hack to ensure a valid
# probability; we should find a better way to do this.
tiny = np.finfo(p.dtype).tiny
p = np.clip(p, tiny, 1.0)
samples = rng.geometric(p)
# samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))

return samples


g2g = GrassiaIIGeometricRV()


# TODO: Add covariate expressions to docstrings.
class GrassiaIIGeometric(Discrete):
r"""Grassia(II)-Geometric distribution.

This distribution is a flexible alternative to the Geometric distribution for the number of trials until a
discrete event, and can be extended to support both static and time-varying covariates.

Hardie and Fader describe this distribution with the following PMF and survival function in [1]_:

.. math::
\begin{align}
\mathbb{P}(T=t \mid r,\alpha,\beta;Z(t)) &= \left(\frac{\alpha}{\alpha+C(t-1)}\right)^{r} - \left(\frac{\alpha}{\alpha+C(t)}\right)^{r} \\
\mathbb{S}(t \mid r,\alpha,\beta;Z(t)) &= \left(\frac{\alpha}{\alpha+C(t)}\right)^{r}
\end{align}

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import arviz as az
plt.style.use('arviz-darkgrid')
t = np.arange(1, 11)
alpha_vals = [1., 1., 2., 2.]
r_vals = [.1, .25, .5, 1.]
for alpha, r in zip(alpha_vals, r_vals):
pmf = (alpha/(alpha + t - 1))**r - (alpha/(alpha+t))**r
plt.plot(t, pmf, '-o', label=r'$\alpha$ = {}, $r$ = {}'.format(alpha, r))
plt.xlabel('t', fontsize=12)
plt.ylabel('p(t)', fontsize=12)
plt.legend(loc=1)
plt.show()

======== ===============================================
Support :math:`t \in \mathbb{N}_{>0}`
======== ===============================================

Parameters
----------
r : tensor_like of float
Shape parameter (r > 0).
alpha : tensor_like of float
Scale parameter (alpha > 0).
time_covariate_vector : tensor_like of float, optional
Vector containing dot products of time-varying covariates and coefficients. If omitted, C(t) reduces to t.

References
----------
.. [1] Fader, Peter S. & Hardie, Bruce G. S. (2020).
"Incorporating Time-Varying Covariates in a Simple Mixture Model for Discrete-Time Duration Data."
https://www.brucehardie.com/notes/037/time-varying_covariates_in_BG.pdf
"""

rv_op = g2g

@classmethod
def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
r = pt.as_tensor_variable(r)
alpha = pt.as_tensor_variable(alpha)

if time_covariate_vector is None:
Member: Seems like your logp doesn't handle ndim > 1 right? In that case raise NotImplementedError if value.ndim > 1?

Author: Would hierarchical models still be supported if this were the case?

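# Illustrative sketch of the guard discussed in the thread above (not part of
# this diff): restricting only the dimensionality of `value` would still allow
# hierarchical (batched) priors on r and alpha, since those enter logp by
# broadcasting.
#
#     if value.ndim > 1:
#         raise NotImplementedError(
#             "GrassiaIIGeometric.logp currently supports value.ndim <= 1 only"
#         )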
time_covariate_vector = pt.constant(0.0)
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
# Normalize covariate to be 1D over time
if time_covariate_vector.ndim == 0:
time_covariate_vector = pt.reshape(time_covariate_vector, (1,))
elif time_covariate_vector.ndim > 1:
feature_axes = tuple(range(time_covariate_vector.ndim - 1))
time_covariate_vector = pt.sum(time_covariate_vector, axis=feature_axes)

return super().dist([r, alpha, time_covariate_vector], *args, **kwargs)

def logp(value, r, alpha, time_covariate_vector):
v = pt.as_tensor_variable(value)
ct_prev = C_t(v - 1, time_covariate_vector)
ct_curr = C_t(v, time_covariate_vector)
logS_prev = r * (pt.log(alpha) - pt.log(alpha + ct_prev))
logS_curr = r * (pt.log(alpha) - pt.log(alpha + ct_curr))
# Compute log(exp(logS_prev) - exp(logS_curr)) stably
max_logS = pt.maximum(logS_prev, logS_curr)
diff = pt.exp(logS_prev - max_logS) - pt.exp(logS_curr - max_logS)
logp = max_logS + pt.log(diff)

# Handle invalid / out-of-domain values
logp = pt.switch(value < 1, -np.inf, logp)

return check_parameters(
logp,
r > 0,
alpha > 0,
msg="r > 0, alpha > 0",
)

def logcdf(value, r, alpha, time_covariate_vector):
# Log CDF: log(1 - (alpha / (alpha + C(t)))**r)
t = pt.as_tensor_variable(value)
ct = C_t(t, time_covariate_vector)
logS = r * (pt.log(alpha) - pt.log(alpha + ct))
# Numerically stable log(1 - exp(logS))
logcdf = pt.switch(
pt.lt(logS, np.log(0.5)),
pt.log1p(-pt.exp(logS)),
pt.log(-pt.expm1(logS)),
)

return check_parameters(
logcdf,
r > 0,
alpha > 0,
msg="r > 0, alpha > 0",
)

def support_point(rv, size, r, alpha, time_covariate_vector):
"""Calculate a reasonable starting point for sampling.

For the GrassiaIIGeometric distribution, we use a point estimate based on
the expected value of the mixing distribution. Since the mixing distribution
is Gamma(r, 1/alpha), its mean is r/alpha. We then transform this through
the geometric link function and round to ensure an integer value.

When time_covariate_vector is provided, it affects the expected value through
the exponential link function: exp(time_covariate_vector).
"""
base_lambda = r / alpha

# Approximate expected value of geometric distribution
mean = pt.switch(
base_lambda < 0.1,
1.0 / base_lambda, # Approximation for small lambda
1.0 / (1.0 - pt.exp(-base_lambda)), # Full expression for larger lambda
)

# Apply time covariates if provided: multiply by exp(sum over axis=0)
# This yields a scalar for 1D covariates and a time-length vector for 2D (features x time)
tcv = pt.as_tensor_variable(time_covariate_vector)
if tcv.ndim != 0:
mean = mean * pt.exp(tcv.sum(axis=0))

# Round up to nearest integer and ensure >= 1
mean = pt.maximum(pt.ceil(mean), 1.0)

# Handle size parameter
if not rv_size_is_none(size):
mean = pt.full(size, mean)

return mean


def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.TensorVariable:
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
# If unspecified (scalar), simply return t
if time_covariate_vector.ndim == 0:
return t

# Sum exp(covariates) across feature axes, keep last axis as time
if time_covariate_vector.ndim == 1:
per_time_sum = pt.exp(time_covariate_vector)
else:
# If axis=0 is time and axis>0 are features, sum over features (axis>0)
per_time_sum = pt.sum(pt.exp(time_covariate_vector), axis=0)

# Build cumulative sum up to each t without advanced indexing
time_length = pt.shape(per_time_sum)[0]
# Ensure t is at least 1D int64 for broadcasting
t_vec = pt.cast(t, "int64")
t_vec = pt.shape_padleft(t_vec) if t_vec.ndim == 0 else t_vec
# Create time indices [0, 1, ..., T-1]
time_idx = pt.arange(time_length, dtype="int64")
# Mask where time index < t (exclusive upper bound)
mask = pt.lt(time_idx, pt.shape_padright(t_vec, 1))
# Sum per-time contributions over time axis
base_sum = pt.sum(pt.shape_padleft(per_time_sum) * mask, axis=-1)
# If original t was scalar, return scalar (saturate at last time step)
return pt.squeeze(base_sum)
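For reviewers, a minimal usage sketch of the distribution added in this diff; the data and priors below are illustrative assumptions, not taken from the PR:

import numpy as np
import pymc as pm

from pymc_extras.distributions import GrassiaIIGeometric

# Hypothetical discrete waiting-time data (e.g., periods until churn)
observed_t = np.array([1, 2, 2, 3, 5, 1, 4, 2])

with pm.Model():
    r = pm.HalfNormal("r", sigma=2.0)          # shape parameter of the Gamma mixture (r > 0)
    alpha = pm.HalfNormal("alpha", sigma=2.0)  # scale parameter (alpha > 0)
    # time_covariate_vector is optional; omitting it reduces C(t) to t
    GrassiaIIGeometric("t", r=r, alpha=alpha, observed=observed_t)
    idata = pm.sample()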