-
-
Notifications
You must be signed in to change notification settings - Fork 69
Open
Labels
bug (Something isn't working) · enhancements (New feature or request) · help wanted (Extra attention is needed)
Description
Support for batched transition matrices. I have a dataset of shape (T, N) with integer observations over time, and I want to model one transition matrix for each of the N time series, as in:
import numpy as np
import pymc as pm
import pymc_extras as pmx

# Minimal reproduction: one discrete Markov chain, with its own transition
# matrix, for each of the N series.
T = 100  # number of time steps per series
N = 30  # number of series
n_states = 3  # number of discrete states

coords = {
    "series": np.arange(N),
    "time": np.arange(T),
    "state": np.arange(n_states),
    "next_state": np.arange(n_states),
}

with pm.Model(coords=coords) as m:
    # One initial-state draw per series.
    init_dist = pm.Categorical.dist(logit_p=np.ones(n_states), shape=(N,))
    # Batched transition matrices, one (state, next_state) matrix per series.
    P = pm.Dirichlet(
        "P",
        a=np.ones((N, n_states, n_states)),
        dims=["series", "state", "next_state"],
    )
    # This call currently raises the scan ValueError shown in the traceback
    # below; the request is for a batched P to be supported here.
    obs = pmx.DiscreteMarkovChain(
        "obs", P=P, init_dist=init_dist, dims=["series", "time"]
    )
This currently fails with a scan error, but should be allowed.
The scan error
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[25], line 14
12 init_dist = pm.Categorical.dist(logit_p=np.ones(n_states), shape=(N,))
13 P = pm.Dirichlet('P', a=np.ones((N, n_states, n_states)), dims=['series', 'state', 'next_state'])
---> 14 obs = pmx.DiscreteMarkovChain('obs', P=P, init_dist=init_dist, dims=['series', 'time'])
File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc_extras/distributions/timeseries.py:132, in DiscreteMarkovChain.__new__(cls, steps, n_lags, *args, **kwargs)
123 def __new__(cls, *args, steps=None, n_lags=1, **kwargs):
124 steps = get_support_shape_1d(
125 support_shape=steps,
126 shape=None,
(...) 129 support_shape_offset=n_lags,
130 )
--> 132 return super().__new__(cls, *args, steps=steps, n_lags=n_lags, **kwargs)
File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc/distributions/distribution.py:505, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, default_transform, *args, **kwargs)
502 elif observed is not None:
503 kwargs["shape"] = tuple(observed.shape)
--> 505 rv_out = cls.dist(*args, **kwargs)
507 rv_out = model.register_rv(
508 rv_out,
509 name,
(...) 515 initval=initval,
516 )
518 # add in pretty-printing support
File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc_extras/distributions/timeseries.py:178, in DiscreteMarkovChain.dist(cls, P, logit_P, steps, init_dist, n_lags, **kwargs)
175 k = P.shape[-1]
176 init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))
--> 178 return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)
File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc/distributions/distribution.py:571, in Distribution.dist(cls, dist_params, shape, **kwargs)
568 ndim_supp = getattr(cls.rv_op, "ndim_supp", getattr(cls.rv_type, "ndim_supp", None))
569 if ndim_supp is None:
570 # Initialize Ops and check the ndim_supp that is now required to exist
--> 571 ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp
573 create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
574 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pymc_extras/distributions/timeseries.py:202, in DiscreteMarkovChain.rv_op(cls, P, steps, init_dist, n_lags, size)
199 next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
200 return next_state, {old_rng: next_rng}
--> 202 markov_chain, state_updates = pytensor.scan(
203 transition,
204 non_sequences=[P_, state_rng],
205 outputs_info=_make_outputs_info(n_lags, init_dist_),
206 n_steps=steps_,
207 strict=True,
208 )
210 (state_next_rng,) = tuple(state_updates.values())
212 discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)
File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pytensor/scan/basic.py:1165, in scan(fn, sequences, outputs_info, non_sequences, n_steps, truncate_gradient, go_backwards, mode, name, profile, allow_gc, strict, return_list)
1163 pass
1164 scan_inputs += [arg]
-> 1165 scan_outs = local_op(*scan_inputs)
1166 if not isinstance(scan_outs, list | tuple):
1167 scan_outs = [scan_outs]
File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
249 def __call__(
250 self, *inputs: Any, name=None, return_list=False, **kwargs
251 ) -> Variable | list[Variable]:
252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
253
254 This method is just a wrapper around :meth:`Op.make_node`.
(...) 291
292 """
--> 293 node = self.make_node(*inputs, **kwargs)
294 if name is not None:
295 if len(node.outputs) == 1:
File ~/.conda/envs/rx-bonds/lib/python3.11/site-packages/pytensor/scan/op.py:1116, in Scan.make_node(self, *inputs)
1106 raise ValueError(
1107 err_msg2
1108 % (
(...) 1113 )
1114 )
1115 if inner_sitsot_out.ndim != outer_sitsot.ndim - 1:
-> 1116 raise ValueError(
1117 err_msg3
1118 % (
1119 str(outer_sitsot),
1120 argoffset + idx,
1121 outer_sitsot.type.ndim,
1122 inner_sitsot_out.type.ndim,
1123 )
1124 )
1126 argoffset += len(self.outer_sitsot(inputs))
1127 # Check that the shared variable and their update rule have the same
1128 # dtype. Maybe even same type ?!
ValueError: When compiling the inner function of scan (the function called by scan in each of its iterations) the following error has been encountered: The initial state (`outputs_info` in scan nomenclature) of variable SetSubtensor{:stop}.0 (argument number 0) has 2 dimension(s), while the corresponding variable in the result of the inner function of scan (`fn`) has 2 dimension(s) (it should be one less than the initial state). For example, if the inner function of scan returns a vector of size d and scan uses the values of the previous time-step, then the initial state in scan should be a matrix of shape (1, d). The first dimension of this matrix corresponds to the number of previous time-steps that scan uses in each of its iterations. In order to solve this issue if the two variable currently have the same dimensionality, you can increase the dimensionality of the variable in the initial state of scan by using dimshuffle or shape_padleft.
Metadata
Metadata
Assignees
Labels
bug (Something isn't working) · enhancements (New feature or request) · help wanted (Extra attention is needed)