Add GrassiaIIGeometric Distribution #528


Open · ColtAllen wants to merge 33 commits into base: main

Conversation

@ColtAllen (Author) commented Jun 20, 2025

This PR closes #438. Still a fair bit of work left to do on this, but creating a draft PR for now to resolve a merge conflict.
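For reference, the PMF being implemented here, reconstructed from the logp expression visible in the error trace below (the exact definition of the cumulative covariate term $C(t)$ is discussed later in this thread):

$$P(T = t \mid r, \alpha) = \left(\frac{\alpha}{\alpha + C(t-1)}\right)^{r} - \left(\frac{\alpha}{\alpha + C(t)}\right)^{r}, \qquad t = 1, 2, \ldots$$

where $C(t)$ reduces to $t$ when no time covariates are supplied.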

@ColtAllen ColtAllen closed this Jun 20, 2025
@ColtAllen ColtAllen reopened this Jun 20, 2025
@ColtAllen ColtAllen marked this pull request as draft June 20, 2025 14:32
@ColtAllen ColtAllen changed the title from "resolve merge conflicts with main" to "Add GrassiaIIGeometric Distribution" Jun 20, 2025
@ColtAllen ColtAllen marked this pull request as ready for review July 11, 2025 14:06
@ColtAllen (Author) commented Jul 11, 2025

Tests are passing, but when tested in a notebook with pymc.Censored, I get the following error:

Error Trace
TypeError                                 Traceback (most recent call last)
Cell In[40], line 30
     20     churn_raw = clv.distributions.GrassiaIIGeometric.dist(r=r, alpha=alpha)
     21     cens = pm.Censored(
     22         "censored",
     23         churn_raw,
   (...)
     27         dims=("customer_id",),
     28     )
---> 30     idata = pm.sample(idata_kwargs={"log_likelihood": True})

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/sampling/mcmc.py:789, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    786     msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
    787     _log.warning(msg)
--> 789 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
    790 exclusive_nuts = (
    791     # User provided an instantiated NUTS step, and nothing else is needed
    792     (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
   (...)
    799     )
    800 )
    802 if nuts_sampler != "pymc":

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/sampling/mcmc.py:247, in assign_step_methods(model, step, methods)
    245 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    246 selected_steps: dict[type[BlockedStep], list] = {}
--> 247 model_logp = model.logp()
    249 for var in model.value_vars:
    250     if var not in assigned_vars:
    251         # determine if a gradient can be computed

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/model/core.py:696, in Model.logp(self, vars, jacobian, sum)
    694 rv_logps: list[TensorVariable] = []
    695 if rvs:
--> 696     rv_logps = transformed_conditional_logp(
    697         rvs=rvs,
    698         rvs_to_values=self.rvs_to_values,
    699         rvs_to_transforms=self.rvs_to_transforms,
    700         jacobian=jacobian,
    701     )
    702     assert isinstance(rv_logps, list)
    704 # Replace random variables by their value variables in potential terms

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/logprob/basic.py:595, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    592     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore[arg-type]
    594 kwargs.setdefault("warn_rvs", False)
--> 595 temp_logp_terms = conditional_logp(
    596     rvs_to_values,
    597     extra_rewrites=transform_rewrite,
    598     use_jacobian=jacobian,
    599     **kwargs,
    600 )
    602 # The function returns the logp for every single value term we provided to it.
    603 # This includes the extra values we plugged in above, so we filter those we
    604 # actually wanted in the same order they were given in.
    605 logp_terms = {}

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/logprob/basic.py:529, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    526 node_values = remapped_vars[: len(node_values)]
    527 node_inputs = remapped_vars[len(node_values) :]
--> 529 node_logprobs = _logprob(
    530     node.op,
    531     node_values,
    532     *node_inputs,
    533     **kwargs,
    534 )
    536 if not isinstance(node_logprobs, list | tuple):
    537     node_logprobs = [node_logprobs]

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/logprob/censoring.py:111, in clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs)
    108 base_rv_op = base_rv.owner.op
    109 base_rv_inputs = base_rv.owner.inputs
--> 111 logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
    112 logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
    114 if base_rv_op.name:

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/distributions/distribution.py:140, in DistributionMeta.__new__.<locals>.logp(op, values, *dist_params, **kwargs)
    138     dist_params = [dist_params[i] for i in params_idxs]
    139 [value] = values
--> 140 return class_logp(value, *dist_params)

File ~/Projects/pymc_extras/distributions/discrete.py:542, in GrassiaIIGeometric.logp(value, r, alpha, time_covariate_vector)
   539          covariate_value = time_covariate_vector[safe_idx]
   540          return t * pt.exp(covariate_value)
   541   logp = pt.log(
-> 542     pt.pow(alpha / (alpha + C_t(value - 1)), r)
   543      - pt.pow(alpha / (alpha + C_t(value)), r)
   544  )

File ~/Projects//pymc_extras/distributions/discrete.py:538, in GrassiaIIGeometric.logp.<locals>.C_t(t)
   536  max_idx = pt.shape(time_covariate_vector)[0] - 1
   537  safe_idx = pt.minimum(t_idx, max_idx)
-> 538  covariate_value = time_covariate_vector[safe_idx]
   539  return t * pt.exp(covariate_value)

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/tensor/variable.py:557, in _tensor_py_operators.__getitem__(self, args)
    554                 advanced = True
    556 if advanced:
--> 557     return pt.subtensor.advanced_subtensor(self, *args)
    558 else:
    559     if np.newaxis in args or NoneConst in args:
    560         # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
    561         # broadcastable dimension at this location".  Since PyTensor adds
   (...)
    564         # then uses recursion to apply any other indices and add any
    565         # remaining new axes.

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
    249 def __call__(
    250     self, *inputs: Any, name=None, return_list=False, **kwargs
    251 ) -> Variable | list[Variable]:
    252     r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254     This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    291 
    292     """
--> 293     node = self.make_node(*inputs, **kwargs)
    294     if name is not None:
    295         if len(node.outputs) == 1:

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/tensor/subtensor.py:2817, in AdvancedSubtensor.make_node(self, x, *index)
   2815 def make_node(self, x, *index):
   2816     x = as_tensor_variable(x)
-> 2817     index = tuple(map(as_index_variable, index))
   2819     # We create a fake symbolic shape tuple and identify the broadcast
   2820     # dimensions from the shape result of this entire subtensor operation.
   2821     with config.change_flags(compute_test_value="off"):

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/tensor/subtensor.py:2774, in as_index_variable(idx)
   2772 idx = as_tensor_variable(idx)
   2773 if idx.type.dtype not in discrete_dtypes:
-> 2774     raise TypeError("index must be integers or a boolean mask")
   2775 if idx.type.dtype == "bool" and idx.type.ndim == 0:
   2776     raise NotImplementedError(
   2777         "Boolean scalar indexing not implemented. "
   2778         "Open an issue in https://github.com/pymc-devs/pytensor/issues if you need this behavior."
   2779     )

TypeError: index must be integers or a boolean mask
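The root cause is the advanced indexing inside C_t: safe_idx evidently ends up with a non-integer dtype here (value flows in as a float under pm.Censored), and PyTensor only accepts integer or boolean indices. A minimal sketch of the failure mode and one possible fix (the cast placement is an assumption, not the PR's code):

```python
import pytensor.tensor as pt

x = pt.vector("x")
i_float = pt.scalar("i")  # float64 by default
# x[i_float] raises: TypeError: index must be integers or a boolean mask
i_int = pt.cast(i_float, "int64")  # casting the index first makes it valid
y = x[i_int]
```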

@fonnesbeck (Member) left a comment:

This is great. Just a couple suggestions.

time_covariate_vector = pt.constant(0.0)
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)

def C_t(t):
Member:

Can this be moved outside the function so that it can be reused by logp?

@ColtAllen (Author) replied Jul 12, 2025:

Not sure how kosher this is, but due to how instantiation is handled in logp and logcdf, I had to move C_t outside the distribution class altogether to get it to work.
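A sketch of what the module-level helper might look like (the body mirrors the inline C_t from the trace above; the t_idx derivation and the integer cast are assumptions):

```python
import pytensor.tensor as pt

def C_t(t, time_covariate_vector):
    # Cumulative covariate exposure at time t, shared by logp and logcdf
    t_idx = pt.cast(pt.maximum(t - 1, 0), "int64")  # clamp, then force int dtype
    max_idx = pt.shape(time_covariate_vector)[0] - 1
    safe_idx = pt.minimum(t_idx, max_idx)           # guard against out-of-bounds reads
    covariate_value = time_covariate_vector[safe_idx]
    return t * pt.exp(covariate_value)
```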

samples = rng.geometric(p)

# Clip samples to reasonable bounds to prevent infinite values
max_sample = 10000 # Reasonable upper bound for discrete time-to-event data
Member:

Can we make this an argument?

@@ -208,3 +208,239 @@ def test_logp(self):
{"mu1": Rplus_small, "mu2": Rplus_small},
lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
)


class TestGrassiaIIGeometric:
Member:

Can a test for logcdf be added?
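A sketch using pymc's testing helpers (the same helpers the author mentions later in this thread; the import path for the distribution and the domains shown are assumptions):

```python
from pymc.testing import Nat, Rplus, check_selfconsistency_discrete_logcdf

from pymc_extras.distributions import GrassiaIIGeometric  # path assumed

# Checks that exp(logcdf(t)) matches the cumulative sum of exp(logp) over 1..t
check_selfconsistency_discrete_logcdf(
    GrassiaIIGeometric,
    Nat,
    {"r": Rplus, "alpha": Rplus},
)
```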


# Ensure p is in valid range for geometric distribution
min_p = max(1e-6, np.finfo(float).tiny) # Minimum probability to prevent infinite values
p = np.clip(p, min_p, 1.0)
@ricardoV94 (Member) commented Jul 11, 2025:

Both clips are suspicious. Can't we compute draws in a more stable way? You could use log1mexp to get p on the log scale. Is there a way to sample from a geometric with log_p instead of p?
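For context, the log1mexp trick computes log(p) = log(1 - exp(-lam)) without the cancellation of the naive form. A NumPy sketch (PyTensor exposes a symbolic equivalent as pt.log1mexp):

```python
import numpy as np

def log1mexp(x):
    # log(1 - exp(x)) for x < 0, switching formulas at -log(2) for accuracy
    return np.where(x > -np.log(2), np.log(-np.expm1(x)), np.log1p(-np.exp(x)))

lam = 1e-18
print(log1mexp(-lam))            # ~ log(lam) = -41.45, computed stably
print(np.log(1 - np.exp(-lam)))  # naive form underflows to log(0) = -inf
```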

Author: Is it possible to sample from a custom PMF rather than a Geometric?

Member: You need an inverse CDF or a custom sampling algorithm. There's no generic way of sampling from a PMF (other than running some MCMC algorithm).

Author: I can derive the inverse CDF (another task for the trusty whiteboard). Would adding this to rng_fn look something like rng.inverse_cdf?

Member: If you have a stable inverse CDF (or, even better, an inverse log CDF), you could then take a uniform (or log-uniform) draw and pass it through the icdf to get a draw. I'm not saying that's the best route, just that those clips are very indicative of a poor random implementation.
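The generic pattern being described: draw U ~ Uniform(0, 1) and push it through the inverse CDF. A NumPy sketch, using the plain geometric icdf as the example (the G2G icdf itself is what would need deriving):

```python
import numpy as np

def draw_via_icdf(rng, icdf, size=None):
    # Inverse-transform sampling: a uniform draw mapped through the icdf
    u = rng.uniform(size=size)
    return icdf(u)

# Example: geometric with success probability p, icdf(q) = ceil(log(1-q)/log(1-p))
p = 0.3
geom_icdf = lambda q: np.ceil(np.log1p(-q) / np.log1p(-p))
print(draw_via_icdf(np.random.default_rng(0), geom_icdf, size=5))
```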

Author: After doing the derivations, neither the inverse CDF nor its log are stable, unfortunately. I also ran into the same issue as below of having to average the covariate vector.

assert np.mean(draws) > 0
assert np.var(draws) > 0

def test_logp_basic(self):
Member: Any reason not to use the existing frameworks to test logp/logcdf and support point?

Author: check_logp? There is some inconsistency in how the other distributions are tested.

Member: Yes. But if it works for your distribution, that's the best; if it doesn't, it would be good to know why.

Author: I don't think I'll be able to use check_logp or any of the others expecting a scipy equivalent. check_selfconsistency_discrete_logcdf and assert_support_point_is_expected could still be useful, though.

@@ -509,7 +509,7 @@ def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
         time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
         return super().dist([r, alpha, time_covariate_vector], *args, **kwargs)
 
-    def logp(value, r, alpha, time_covariate_vector=None):
+    def logp(value, r, alpha, time_covariate_vector):
         if time_covariate_vector is None:
Member: Seems like your logp doesn't handle ndim > 1, right? In that case, raise NotImplementedError if value.ndim > 1?
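A minimal sketch of that guard (placement at the top of logp assumed):

```python
import pytensor.tensor as pt

def logp_ndim_guard(value):
    # Reject batched values explicitly rather than computing a wrong logp
    if value.ndim > 1:
        raise NotImplementedError("logp not implemented for value.ndim > 1")

logp_ndim_guard(pt.vector("value"))   # fine
# logp_ndim_guard(pt.matrix("value")) # raises NotImplementedError
```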

Author: Would hierarchical models still be supported if this were the case?

@ricardoV94 (Member) commented Jul 12, 2025

@ColtAllen I think you can simplify the rng and get more stability. To take draws from a geometric with p you can do np.ceil(np.log(U) / np.log(1 - p)), which simplifies away your 1 - exp in p. I think this is equivalent:

u = rng.uniform(size=size)
samples = np.ceil(np.log(u) / -lam_covar)

where lam_covar is defined as in your current impl.

@ricardoV94 (Member) commented Jul 12, 2025

Or equivalently: samples = np.ceil(rng.exponential(size=size) / lam_covar), but I don't think it matters.
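A quick NumPy check that the three forms agree, assuming p = 1 - exp(-lam) per draw (lam standing in for lam_covar):

```python
import numpy as np

rng = np.random.default_rng(42)
lam = 0.3
p = 1 - np.exp(-lam)

u = rng.uniform(size=1_000_000)
via_p = np.ceil(np.log(u) / np.log1p(-p))              # classic inverse transform
via_lam = np.ceil(np.log(u) / -lam)                    # uses log(1 - p) = -lam
via_exp = np.ceil(rng.exponential(size=u.size) / lam)  # -log(U) ~ Exponential(1)

# All three means should be close to E[T] = 1/p
print(via_p.mean(), via_lam.mean(), via_exp.mean(), 1 / p)
```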

@ColtAllen (Author) commented:

> Or equivalently: samples = np.ceil(rng.exponential(size=size) / lam_covar), but I don't think it matters

Thanks; I've had the best luck with this approach. A caveat is that the p expression is per time period, so I had to settle for averaging the covariate vector to return a single sample per vector.

I'm encountering a cryptic error for test_random_edge_cases:

E       RuntimeWarning: invalid value encountered in cast
E       Apply node that caused the error: g2g_rv{"(),(),()->()"}(RNG(<Generator(PCG64) at 0x152884E40>), [1000], [0.1], [5.], [0.])
E       Toposort index: 0
E       Inputs types: [RandomGeneratorType, TensorType(int64, shape=(1,)), TensorType(float64, shape=(1,)), TensorType(float32, shape=(1,)), TensorType(float32, shape=(1,))]
E       Inputs shapes: ['No shapes', (1,), (1,), (1,), (1,)]
E       Inputs strides: ['No strides', (8,), (8,), (4,), (4,)]
E       Inputs values: [Generator(PCG64) at 0x152884E40, array([1000]), array([0.1]), array([5.], dtype=float32), array([0.], dtype=float32)]
E       Outputs clients: [[], [output[0](g2g_rv{"(),(),()->()"}.out)]]

# Calculate exp(time_covariate_vector) for all samples
exp_time_covar = np.exp(
time_covariate_vector
).mean() # must average over time for correct broadcasting
Member: I don't understand this. Why do you take the mean? What is this distribution supposed to do with a time_covariate_vector in the theoretical sense?

Member: Also, if it's a vector, your signature is wrong. It should be (),(),(a)->() or perhaps (),(),(a)->(a); it's no longer a univariate distribution.

Member: Also, you should try to make your RV/logp work with batch dims, so you should be doing things like mean(axis=-1) and indexing with [..., safe_idx], or explicitly raise NotImplementedError if you don't want to support batch dims. By default we assume implementations work with batch parameters.
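A minimal sketch of those two idioms (axis layout assumed: covariates on the last axis of a possibly batched tensor):

```python
import pytensor.tensor as pt

time_covariate_vector = pt.matrix("tcv")  # e.g. shape (batch, time)
safe_idx = pt.lscalar("safe_idx")         # integer index into the core axis

exp_mean = pt.exp(time_covariate_vector).mean(axis=-1)  # reduce over time only
covariate_value = time_covariate_vector[..., safe_idx]  # leave batch dims intact
```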

@ColtAllen (Author) replied:

> Why do you take the mean? What is this distribution supposed to do with a time_covariate_vector in the theoretical sense?

For the PMF and survival function, the covariate vector is exponentiated and summed over all active time periods. However, the research note implies geometric samples are drawn for each time period. If batch parameters are supported by default, perhaps there's nothing to worry about.

> Also if it's a vector your signature is wrong, it should be (),(),(a)->() or perhaps (),(),(a)->(a), it's no longer a univariate distribution

I'll have to investigate this more; changing the signature breaks most of the tests. Also, doesn't pymc.Censored still only support univariate distributions?

@ricardoV94 (Member) commented Jul 13, 2025

What is the RV actually representing in plain English? What would be the meaning of each geometric per covariate? Would their sum (or some other aggregation) make sense?

> Also, doesn't pymc.Censored still only support univariate distributions?

If it's (),(),(a)->(), it's still univariate; it just has a vector parameter in the core case (like Categorical, which takes a vector of p and returns an integer).

Still, we can relax that constraint of Censored; we just never did, because there was no multivariate distribution with a CDF, so it didn't make sense to bother back then.

Member: Regarding the tests breaking: yes, the signature changes the meaning of some things, so we need to restructure a bit. The first thing we should clarify is what the RV is supposed to do.

@ColtAllen (Author) replied:

> What is the RV actually representing in plain English?

The time period at which an event occurs.

> What would be the meaning of each geometric per covariate? Would their sum (or some other aggregation) make sense?

p is typically fixed for Geometric draws, but the covariate vector allows it to vary over time, similar to an Accelerated Failure Time model. Summing would make the most sense, but I'll have to think about the formulation to ensure 0 < p <= 1.
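Note that the exponential-link form already used for draws in this thread keeps that constraint satisfied automatically: for any $\lambda > 0$ and covariate $x_t$,

$$p_t = 1 - e^{-\lambda e^{x_t}} \in (0, 1),$$

since the effective rate $\lambda e^{x_t}$ is always positive.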

@ricardoV94 (Member) commented Jul 13, 2025

So if you have a covariate vector, you also observe multiple times, and in your logp you have a value vector that's as long as the covariates (for a single subject)?

The thing that makes this a univariate distribution is that a subject has a constant lambda over all these events?

In that case it sounds like you have "(),(),(a)->(a)" indeed. You just have to adjust, because in that case a should always be at least a vector (even if there's only one constant 0), and size doesn't include a; it's what is batched on top of that, just as in the mvnormal, where size doesn't include the shape implied by the last dimensions of mu or cov.
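For concreteness, a sketch of how the discussed signature could be declared on the RV op (the class layout is an assumption; only the scalar signature string "(),(),()->()" is confirmed by the trace above):

```python
from pytensor.tensor.random.op import RandomVariable

class GrassiaIIGeometricRV(RandomVariable):
    name = "g2g"
    # Vector-covariate variant discussed here; the scalar form was "(),(),()->()"
    signature = "(),(),(a)->(a)"  # r: (), alpha: (), covariates: (a) -> draw: (a)
    dtype = "int64"
```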

@ColtAllen (Author) replied:

> in your logp you have a value vector that's as long as the covariate for the values (for a single subject)?

It's a scalar value: value == len(covariate_vector). Covariates are actually optional for this distribution.

> a subject has a constant lambda over all these events?

Yes.

> In that case it sounds like you have a "(),(),(a)->(a)" indeed?

This might be true in the case of RV draws. Is a == len(covariate_vector)?

@ricardoV94 (Member) commented Jul 14, 2025

You can't have a scalar and a multivariate distribution with the same signature, so you could consider len(covar) == 1 in the scalar case and set it to zero like you already do.

Otherwise you could have two instances of the RV with two signatures, but I'm not sure it's worth the trouble.

Yes, a in the signature is covar. You can give it whatever name you want in the signature; usually we try to keep it short, or one letter, but it doesn't really matter. a isn't great, I guess.
