Skip to content

Support batched parameter inputs to DiscreteMarkovChain #550

@jessegrabowski

Description

@jessegrabowski

I have a dataset of integer observations over time with shape (T, N) (T time steps for each of N series), and I want to model a separate transition matrix for each of the N time series, as in:

import numpy as np
import pymc as pm
import pymc_extras as pmx

# Minimal reproduction: one DiscreteMarkovChain per series, each with its own
# transition matrix. Batching P over a leading `series` dimension currently
# fails inside pytensor.scan (see the traceback below).

T = 100  # number of time steps
N = 30  # number of independent series
n_states = 3  # number of discrete states

coords = {
    "series": np.arange(N),
    "time": np.arange(T),
    "state": np.arange(n_states),
    "next_state": np.arange(n_states),
}

with pm.Model(coords=coords) as m:
    # One initial-state distribution per series, with equal logits for every
    # state. Use n_states rather than a literal so the example stays
    # consistent if the number of states changes.
    init_dist = pm.Categorical.dist(logit_p=np.ones(n_states), shape=(N,))
    # One (n_states x n_states) transition matrix per series.
    P = pm.Dirichlet(
        "P",
        a=np.ones((N, n_states, n_states)),
        dims=["series", "state", "next_state"],
    )
    obs = pmx.DiscreteMarkovChain("obs", P=P, init_dist=init_dist, dims=["series", "time"])

This currently fails with a `ValueError` raised from `pytensor.scan` (full traceback below). Batched transition matrices like this should be supported.

The scan error
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[25], line 14
     12 init_dist = pm.Categorical.dist(logit_p=np.ones(n_states), shape=(N,))
     13 P = pm.Dirichlet('P', a=np.ones((N, n_states, n_states)), dims=['series', 'state', 'next_state'])
---> 14 obs = pmx.DiscreteMarkovChain('obs', P=P, init_dist=init_dist, dims=['series', 'time'])

File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc_extras/distributions/timeseries.py:132, in DiscreteMarkovChain.__new__(cls, steps, n_lags, *args, **kwargs)
    123 def __new__(cls, *args, steps=None, n_lags=1, **kwargs):
    124     steps = get_support_shape_1d(
    125         support_shape=steps,
    126         shape=None,
   (...)    129         support_shape_offset=n_lags,
    130     )
--> 132     return super().__new__(cls, *args, steps=steps, n_lags=n_lags, **kwargs)

File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc/distributions/distribution.py:505, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, default_transform, *args, **kwargs)
    502     elif observed is not None:
    503         kwargs["shape"] = tuple(observed.shape)
--> 505 rv_out = cls.dist(*args, **kwargs)
    507 rv_out = model.register_rv(
    508     rv_out,
    509     name,
   (...)    515     initval=initval,
    516 )
    518 # add in pretty-printing support

File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc_extras/distributions/timeseries.py:178, in DiscreteMarkovChain.dist(cls, P, logit_P, steps, init_dist, n_lags, **kwargs)
    175     k = P.shape[-1]
    176     init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))
--> 178 return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)

File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc/distributions/distribution.py:571, in Distribution.dist(cls, dist_params, shape, **kwargs)
    568 ndim_supp = getattr(cls.rv_op, "ndim_supp", getattr(cls.rv_type, "ndim_supp", None))
    569 if ndim_supp is None:
    570     # Initialize Ops and check the ndim_supp that is now required to exist
--> 571     ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp
    573 create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
    574 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc_extras/distributions/timeseries.py:202, in DiscreteMarkovChain.rv_op(cls, P, steps, init_dist, n_lags, size)
    199     next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
    200     return next_state, {old_rng: next_rng}
--> 202 markov_chain, state_updates = pytensor.scan(
    203     transition,
    204     non_sequences=[P_, state_rng],
    205     outputs_info=_make_outputs_info(n_lags, init_dist_),
    206     n_steps=steps_,
    207     strict=True,
    208 )
    210 (state_next_rng,) = tuple(state_updates.values())
    212 discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)

File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pytensor/scan/basic.py:1165, in scan(fn, sequences, outputs_info, non_sequences, n_steps, truncate_gradient, go_backwards, mode, name, profile, allow_gc, strict, return_list)
   1163         pass
   1164     scan_inputs += [arg]
-> 1165 scan_outs = local_op(*scan_inputs)
   1166 if not isinstance(scan_outs, list | tuple):
   1167     scan_outs = [scan_outs]

File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
    249 def __call__(
    250     self, *inputs: Any, name=None, return_list=False, **kwargs
    251 ) -> Variable | list[Variable]:
    252     r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254     This method is just a wrapper around :meth:`Op.make_node`.
   (...)    291 
    292     """
--> 293     node = self.make_node(*inputs, **kwargs)
    294     if name is not None:
    295         if len(node.outputs) == 1:

File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pytensor/scan/op.py:1116, in Scan.make_node(self, *inputs)
   1106         raise ValueError(
   1107             err_msg2
   1108             % (
   (...)   1113             )
   1114         )
   1115     if inner_sitsot_out.ndim != outer_sitsot.ndim - 1:
-> 1116         raise ValueError(
   1117             err_msg3
   1118             % (
   1119                 str(outer_sitsot),
   1120                 argoffset + idx,
   1121                 outer_sitsot.type.ndim,
   1122                 inner_sitsot_out.type.ndim,
   1123             )
   1124         )
   1126 argoffset += len(self.outer_sitsot(inputs))
   1127 # Check that the shared variable and their update rule have the same
   1128 # dtype. Maybe even same type ?!

ValueError: When compiling the inner function of scan (the function called by scan in each of its iterations) the following error has been encountered: The initial state (`outputs_info` in scan nomenclature) of variable SetSubtensor{:stop}.0 (argument number 0) has 2 dimension(s), while the corresponding variable in the result of the inner function of scan (`fn`) has 2 dimension(s) (it should be one less than the initial state). For example, if the inner function of scan returns a vector of size d and scan uses the values of the previous time-step, then the initial state in scan should be a matrix of shape (1, d). The first dimension of this matrix corresponds to the number of previous time-steps that scan uses in each of its iterations. In order to solve this issue if the two variable currently have the same dimensionality, you can increase the dimensionality of the variable in the initial state of scan by using dimshuffle or shape_padleft. 

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), enhancements (New feature or request), help wanted (Extra attention is needed)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions