Add experimental dims module with objects that follow dim-based semantics (like xarray without coordinates) #7820


Open: wants to merge 10 commits into main

Conversation

@ricardoV94 ricardoV94 commented Jun 17, 2025

This builds on top of PyTensor's xtensor module, to introduce distributions and model objects that follow xarray-like semantics (without coordinates). Example model:

import numpy as np
import pymc as pm
import pymc.dims as pmd

# Very realistic looking data!
observed_response_np = np.ones((5, 20), dtype=int)
coords = {
    "participant": range(5),
    "trial": range(20),
    "item": range(3),
}
with pm.Model(coords=coords) as dmodel:
    observed_response = pmd.Data(
        "observed_response", observed_response_np, dims=("participant", "trial")
    )

    # Participant constant preferences for each item
    participant_preference = pmd.ZeroSumNormal(
        "participant_preference", 
        core_dims="item", 
        dims=("participant", "item"),
    )

    # Shared time effects across all participants
    time_effects = pmd.Normal("time_effects", dims=("item", "trial"))

    trial_preference = pmd.Deterministic(
        "trial_pereference",
        participant_preference + time_effects,
        dims=(...,),  # No need to specify, PyMC knows them
    )

    response = pmd.Categorical(
        "response",
        p=pmd.math.softmax(trial_preference, dim="item"),
        core_dims="item",
        observed=observed_response,
        dims=(...,), # No need to specify, PyMC knows them
    )

Equivalently, with the traditional API:

with pm.Model(coords=coords) as model:
    observed_response = pm.Data(
        "observed_response", observed_response_np, dims=("participant", "trial")
    )

    # Participant constant preferences for each item
    participant_preference = pm.ZeroSumNormal(
        "participant_preference", 
        n_zerosum_axes=1,
       dims=("participant", "item"),
    )

    # Shared time effects across all participants
    time_effects = pm.Normal("time_effects", dims=("trial", "item"))

    trial_preference = pm.Deterministic(
        "trial_preference",
        participant_preference[:, None, :] + time_effects[None, :, :],
        dims=("participant", "trial", "item"),
    )

    response = pm.Categorical(
        "response",
        p=pm.math.softmax(trial_preference, axis=-1),
        observed=observed_response,
        dims=("participant", "trial"),
    )

More details in the new core notebook

Points of contention

People have expressed dissatisfaction with some aspects of the current approach. I'll try to summarize it here as fairly as possible; feel free to correct me in the comments.

Dims being strings

Current design: dims are just strings that label the axes of a tensor.

There are 3 concerns with this simplicity:

  1. Typos can introduce bugs. Say you have time and item; it may be easy to mistype one for the other. I don't think this is inherent to strings: the user can always assign the string item or time to something less confusing, like product_dim = "item", or just use other labels.
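One way to mitigate the typo concern, since dims are plain strings, is to bind the labels to Python names once and reference those afterwards. This is a hypothetical pattern, not part of this PR; `validate_dims` is an invented helper:

```python
# Bind dim labels once; a mistyped Python name then raises NameError
# immediately, instead of silently creating a new dim.
ITEM = "item"
TIME = "time"

def validate_dims(dims, known=(ITEM, TIME)):
    """Hypothetical guard: reject any dim label not declared up front."""
    unknown = [d for d in dims if d not in known]
    if unknown:
        raise ValueError(f"unknown dims: {unknown}")
    return tuple(dims)

validate_dims((ITEM, TIME))  # passes; validate_dims(("itme",)) would raise
```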
  2. Dims can be accidentally repeated if one is not careful. See the contortions one may need to go through to avoid this in the current implementation of MvNormal:
chol = cls._as_xtensor(chol)
# chol @ chol.T in xarray semantics requires a rename
# (d0, d1 are the two core dims of chol, defined earlier in the method)
safe_name = "_"
if "_" in chol.type.dims:
    safe_name *= max(map(len, chol.type.dims)) + 1
cov = chol.dot(chol.rename({d0: safe_name}), dim=d1).rename({safe_name: d1})

Obviously we can have safe_name = get_discardable_new_dim_name(chol.type.dims), but the point still stands that it can be awkward to do abstract work on dims.
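A minimal sketch of such a hypothetical get_discardable_new_dim_name helper (not in the PR) could look like:

```python
def get_discardable_new_dim_name(existing_dims):
    """Return a dim name guaranteed not to clash with any existing dim.

    Sketch of the hypothetical helper mentioned above: lengthen an
    underscore string until it no longer appears among the existing dims.
    """
    name = "_"
    while name in existing_dims:
        name += "_"
    return name
```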

  3. Dims don't carry enough structure / error-checking. Having x[1:] come out with the same dims as x is error-prone, because you may then accidentally add them back and only get an error at runtime. I partially agree with this concern.

Note: it's not that dims must be strings (although right now they are). Dims could be made any hashable / pickleable object, not just strings, without much work in PyTensor. When we talk of richer dims, we mean something like pymc-devs/pytensor#407.

Counterpoints:

  1. I think the shape problem is mostly mitigated by using static shapes in PyTensor. XTensorVariables still have the type.shape attribute, and if you try x + x[1:] when x has a static shape, it raises immediately. However, we don't use static shapes in PyMC by default: all dims start out mutable and can later be made constant with helpers like freeze_dims_and_data, which the jax / nutpie samplers apply by default, but the pymc ones do not.
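The static-shape argument can be illustrated with a toy stand-in (plain Python, not PyTensor; the `StaticVar` class is invented for illustration): when shapes are known at construction time, a mismatched elementwise operation can fail at definition rather than at runtime.

```python
class StaticVar:
    """Toy stand-in for a variable with a static 1-d shape."""

    def __init__(self, shape):
        self.shape = shape  # tuple with ints, or None for unknown sizes

    def __getitem__(self, key):
        # Supports simple 1-d slicing with non-negative bounds only.
        start = key.start or 0
        stop = key.stop if key.stop is not None else self.shape[0]
        return StaticVar((stop - start,))

    def __add__(self, other):
        for a, b in zip(self.shape, other.shape):
            if a is not None and b is not None and a != b:
                # Mismatch caught immediately, at "graph construction" time
                raise ValueError(f"shape mismatch: {self.shape} vs {other.shape}")
        return StaticVar(self.shape)

x = StaticVar((5,))
_ = x + x    # fine: shapes match
# x + x[1:]  # would raise ValueError: shapes (5,) vs (4,)
```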

  2. It's much easier to write the actual model. You can say concat([x, y, x], dim="time"), whereas with specialized dims this may not be possible, or may be actively forbidden. Alternative APIs may look like some variant of concat([x, y, x], dim=(x.dim[0], y.dim[1], x.dim[0])), which imo defeats the main purpose of having dims (not having to worry about dimension position in the underlying memory).

  3. Alternatively, richer structured dims introduced automatically by PyTensor (say, when you do x[::-1]) could be made comparable / matchable to the original dims the user provided. An operator like concat could then treat the dims of x[::-1] as identical to those of the original x for the purpose of deciding which axes to align. If this can be achieved, it should require no user-facing changes, and the existing code could be merged as is.
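Counterpoint 3 could look something like the following sketch (all names hypothetical, nothing here exists in PyTensor): a derived dim remembers which user dim it came from, so alignment logic can match it against the original.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DerivedDim:
    """Hypothetical richer dim: remembers the user dim it derives from."""
    base: str  # the user-provided dim label
    op: str    # how it was derived, e.g. "reversed" or "sliced"

    def matches(self, other):
        # Treat a derived dim as identical to its base (or to another
        # derived dim with the same base) for alignment purposes.
        other_base = other.base if isinstance(other, DerivedDim) else other
        return self.base == other_base

reversed_time = DerivedDim(base="time", op="reversed")
assert reversed_time.matches("time")
assert reversed_time.matches(DerivedDim(base="time", op="sliced"))
assert not reversed_time.matches("item")
```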

No coordinates

As you can see from the comment below by @twiecki (and it's indeed the first thing everyone asks), people wish to be able to do coordinate based operations, mainly selection like x.sel(time="yesterday"). That's not possible in the current implementation (a useful error message is raised if you try to use that method).

It also creates a hard problem for us. This is discussed in the new notebook, but in summary: the current PyMC API uses coords to define the model AND propagates them to xarray objects after sampling, whenever the user tells us that a variable has those dims. With the new dims module we always know the dims of every variable, so should we always propagate?

Not so easy. The original coordinates are not trivially valid for intermediate operations. They may have the wrong length or order. So what options do we have?

  1. Don't propagate dims by default, because of the 1% of the time they can create problems. I think this is fundamentally wrong, because if a user then goes and tries to do this, they get orthogonal broadcasting:
# Note this is regular PyMC code
with pm.Model(coords={"time": range(2)}) as m:
  x = pm.Normal("x", dims=("time",))
  x_reversed = pm.Deterministic("x_reversed", x[::-1])

  prior = pm.sample_prior_predictive(draws=1).prior.squeeze(("chain", "draw"))
  assert (prior["x"] + prior["x_reversed"]).shape == (2, 2)
  2. Don't propagate coords by default. Then the behavior is just like the one in the current PyMC model. This is obviously infuriating for users: why do you ask for coords if you don't use them beyond getting their size? Totally agree.

Also note that xarray Datasets always need coords, even if individual DataArrays do not. We have to decide what to do with coords one way or another :)

This PR takes the most pragmatic (dumb) approach: keep coords unless a shape inconsistency is detected, and in that case refuse to propagate them with a warning. Users can then either use custom coords for intermediate operations (in which case they also need a custom dim, since PyMC doesn't have a concept of per-variable coords, #7852), or use the old pm.Deterministic without dims (instead of pm.dims.Deterministic) and get the behavior mentioned above (no custom coords nor dims).

If users care about coords they have to mend them manually after the fact. This is discussed at the end of the new notebook (an example of mending still has to be added).
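The propagation rule above can be sketched roughly like this (helper name and signature hypothetical, not the actual PR code): keep the model coords for each dim unless the variable's length along that dim no longer matches, in which case drop them with a warning.

```python
import warnings

def propagate_coords(var_dims, var_shape, model_coords):
    """Keep model coords per dim, unless the lengths no longer match."""
    coords = {}
    for dim, length in zip(var_dims, var_shape):
        if dim not in model_coords:
            continue
        if len(model_coords[dim]) == length:
            coords[dim] = model_coords[dim]
        else:
            warnings.warn(
                f"coords for {dim!r} have length {len(model_coords[dim])} "
                f"but the variable has length {length}; not propagating them"
            )
    return coords

propagate_coords(("time",), (3,), {"time": [0, 1, 2]})  # coords kept
```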

So why don't we have coords in the graph?

I don't think this can realistically be done in a way that is performant and simple without restricting us to static shapes.

Should we just give up on runtime shapes? Some people really doubt we need this flexibility in PyTensor (JAX does just fine without it, right?). I strongly disagree. To give just one example: it rules out any sort of symbolic iterative algorithm, unless you pad everything to the worst-case scenario. Dismissing this example is not the same as showing we never need runtime shapes; it's just one place where they are needed, and there are more.
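As a concrete (toy) illustration of data-dependent shapes, consider an iterative algorithm whose output length is unknowable before running it; the Collatz map is just a stand-in example, not something from the PR:

```python
def collatz_trajectory(n):
    """Iterate the Collatz map starting from n.

    The trajectory length depends on the data, so a static output shape
    cannot be known ahead of time (short of padding to a worst case,
    which is unbounded here).
    """
    out = [n]
    while n != 1:
        n = 3 * n + 1 if n % 2 else n // 2
        out.append(n)
    return out

collatz_trajectory(6)   # 9 elements
collatz_trajectory(27)  # far longer trajectory, same input shape
```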

Besides, getting rid of runtime shapes doesn't make it trivially easy to implement and work with coords abstractly. A lot of PyMC magic is transforming user code to go from graph A -> graph B. Adding coordinates increases code complexity.

These are related concerns

Talking about the lack of coords circles back to the concerns about dims listed above: coords in a sense supercharge dims in xarray. Look at how the behavior of x[idx:] + x[:-idx] changes between arrays with and without coordinates:

import xarray as xr

x = xr.DataArray([0, 1, 2, 3], dims=("day",))
y = xr.DataArray([0, 1, 2, 3], dims=("day",), coords={"day": ["Mon", "Tue", "Wed", "Thu"]})

print((x[1:] + x[:-1]).values)  # [1 3 5]
print((y[1:] + y[:-1]).values)  # [2 4]

print((x - x[::-1]).values)  # [-3 -1  1  3]
print((y - y[::-1]).values)  # [0 0 0 0]

If we know the coordinates of variables, it's no longer confusing that things may have different shapes: there are clear rules for how to resolve the discrepancy (controlled by global options in xarray, such as arithmetic_join..., which may affect how much you like it).

Reminder: the current approach works the same as xarray without coords. Some proposals for richer dims without coords would still have to decide what x[idx:] + x[:-idx] should do. The options, with my opinionated downsides, are:

  1. It fails at write time (worry: writing models becomes cumbersome)
  2. It warns at write time with a helpful suggestion of how to do it correctly (worry: mypy PTSD flashbacks)
  3. Dims broadcast orthogonally (worry: it's never what the user would expect)
  4. Work like it does now (in which case it need not block us from merging the current code)

Final note

Perfection is the enemy of progress. We have had 5 solid years with deficient use of coords/dims. People seem to like it, and I don't think anyone out there is doing anything better. I believe this is a solid incremental step that brings dims into the model for real, while still handling coords deficiently.

I'm not against coords, I am against bad performance, code complexity and static-shape handicaps (in that order). Maybe these problems can be solved, but they won't be solved by me.

I'm also not against more structured dims at the PyTensor level. Happy to help out there.

I also don't think we should wait. The XTensor stuff, simple as it is, took a while: I opened a draft PR two summers ago, and there is still a lot to do. I would like to share it and start iterating on user feedback instead of purely theoretical concerns. I'm optimistic this will both inform a more advanced approach and let us move beyond the "beta" stage we are at now.

I'm also optimistic the community will accept the experimental / exploratory nature of the new approach, and that we may drop it down the road for something we believe is even better.

import datetime

day_of_conception = datetime.date(2025, 6, 17)
day_of_last_bug = datetime.date(2025, 6, 17)
today = datetime.date.today()
days_with_bugs = (day_of_last_bug - day_of_conception).days
@twiecki commented Jun 18, 2025

wtf 😆

@ricardoV94 (author) replied:

This has two purposes: distract reviewers so they don't focus on the critical changes, and prove that OSS libraries can't be fun.

@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 9 times, most recently from 1cfde5b to f571e5d Compare June 21, 2025 15:54
@twiecki commented Jun 21, 2025

Can this index using labels? x["a"]

@ricardoV94 commented Jun 21, 2025

Can this index using labels? x["a"]

I don't know what x["a"] means :).

Is "a" a coordinate? x.loc["a"] would be the xarray syntax? You can't do that.

Like in xarray, you can do x.isel(dim=idxs) or x[{dim: idxs}].

You cannot do x.sel(dim=coords) or x.loc[coords]

The new PyTensor objects have dims but not coords. It's not trivial to encode coord based operations in our backends.


@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 3 times, most recently from 3eb6738 to 67a0eda Compare June 30, 2025 07:20
@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 2 times, most recently from 3234468 to d4017fb Compare June 30, 2025 12:54
@twiecki commented Jun 30, 2025

We should make this the 6.0 release.

@ricardoV94 (author) replied:

We should make this the 6.0 release.

I agree, but would perhaps wait until we've beta-tested it to the point it no longer feels too experimental.

@ricardoV94 ricardoV94 changed the title Model with dims Add experimental dims module with objects that follow dim-based semantics (like xarray) Jun 30, 2025
@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 4 times, most recently from f1f7478 to 56c05f1 Compare June 30, 2025 16:17
@OriolAbril left a comment

Adding one extra comment about MyST syntax, but outside ReviewNB because it always messes up rendering of backticks.

I will try to play around and build/port some models one of these days, and look at the code itself while I do that.

@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 3 times, most recently from c2b3c7f to 619b1ad Compare July 8, 2025 18:34
@ricardoV94 ricardoV94 changed the title Add experimental dims module with objects that follow dim-based semantics (like xarray) Add experimental dims module with objects that follow dim-based semantics (like xarray without coordinates) Jul 11, 2025
@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 2 times, most recently from b1b624c to 05e3033 Compare July 11, 2025 11:42
@ricardoV94 ricardoV94 marked this pull request as ready for review July 11, 2025 11:43
@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 2 times, most recently from 2fb0a52 to fc635ab Compare July 11, 2025 13:51
codecov bot commented Jul 11, 2025

Codecov Report

Attention: Patch coverage is 92.44713% with 50 lines in your changes missing coverage. Please review.

Project coverage is 92.87%. Comparing base (ae43026) to head (341ffca).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
pymc/dims/distributions/core.py 90.29% 13 Missing ⚠️
pymc/dims/model.py 66.66% 10 Missing ⚠️
pymc/dims/distributions/vector.py 89.55% 7 Missing ⚠️
pymc/dims/distributions/scalar.py 94.87% 6 Missing ⚠️
pymc/testing.py 78.26% 5 Missing ⚠️
pymc/data.py 75.00% 3 Missing ⚠️
pymc/logprob/utils.py 57.14% 3 Missing ⚠️
pymc/pytensorf.py 95.45% 2 Missing ⚠️
pymc/distributions/continuous.py 75.00% 1 Missing ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7820      +/-   ##
==========================================
- Coverage   92.92%   92.87%   -0.05%     
==========================================
  Files         107      115       +8     
  Lines       18299    18816     +517     
==========================================
+ Hits        17004    17476     +472     
- Misses       1295     1340      +45     
Files with missing lines Coverage Δ
pymc/backends/arviz.py 96.01% <100.00%> (+0.25%) ⬆️
pymc/dims/__init__.py 100.00% <100.00%> (ø)
pymc/dims/distributions/__init__.py 100.00% <100.00%> (ø)
pymc/dims/distributions/transforms.py 100.00% <100.00%> (ø)
pymc/dims/math.py 100.00% <100.00%> (ø)
pymc/distributions/distribution.py 93.83% <100.00%> (+0.24%) ⬆️
pymc/distributions/multivariate.py 93.91% <100.00%> (+<0.01%) ⬆️
pymc/distributions/shape_utils.py 91.87% <100.00%> (+0.52%) ⬆️
pymc/initial_point.py 99.19% <100.00%> (+0.19%) ⬆️
pymc/logprob/basic.py 100.00% <100.00%> (ø)
... and 16 more

... and 1 file with indirect coverage changes


@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 2 times, most recently from 33db15f to 4ab90c7 Compare July 11, 2025 17:29

ricardoV94 commented Jul 14, 2025

Thanks @williambdean, it should now be listed. I also tried to add some stuff to the API docs, but it's likely broken. Help appreciated @OriolAbril (last commit)

@williambdean commented:

The perfect PR doesn't exi...


This submodule contains functions for defining distributions and operations that use explicit dimensions.

The module is presented in :doc:`dims_module`.
Review comment (Member):

should be a ref type cross-reference: :ref:`dims_module`

Comment on lines +10 to +13
.. currentmodule:: pymc.dims

.. autosummary::
:toctree: generated/
Review comment (Member):

This should be a toctree directive only, as (for now) it is listing other doc pages, not yet pymc objects.

Reply (author):

Can you propose the changes with the GitHub suggestion feature? I'm not sure whether I should keep autosummary or use something else.

Review comment (Member):

I tried to and it didn't work; now it does allow me:

Suggested change
.. currentmodule:: pymc.dims
.. autosummary::
:toctree: generated/
.. toctree::

Review comment (Member):

Current preview of this page: https://pymcio--7820.org.readthedocs.build/projects/docs/en/7820/api/dims.html. I had to enter the url manually because it is not part of the toctree, it should be added to https://github.com/pymc-devs/pymc/blob/main/docs/source/api.rst?plain=1#L7-L25 (adding it here will also make the page have a sidebar like the other api pages)

Dims
====

This submodule contains functions for defining distributions and operations that use explicit dimensions.
Review comment (Member):

I would probably add a warning about the module here too

Comment on lines +42 to +47
class Flat(DimDistribution):
xrv_op = pxr.as_xrv(flat)

@classmethod
def dist(cls, **kwargs):
return super().dist([], **kwargs)
Review comment (Member):

These currently end up having an empty docstring: https://pymcio--7820.org.readthedocs.build/projects/docs/en/7820/api/dims/generated/pymc.dims.Flat.html. Is it possible to dynamically copy the docstring from the "regular" distribution?

For vector distributions or transforms we already have specific docstrings (and I think we want that there), but scalar ones can probably share the docstring after a search/replace of tensor_like -> xtensor_like.
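The docstring-sharing idea could be sketched like this (decorator name and the sample classes are hypothetical, not PyMC code):

```python
def copy_docstring(source_cls, replacements=(("tensor_like", "xtensor_like"),)):
    """Copy source_cls's docstring onto the decorated class, applying
    simple textual replacements (e.g. tensor_like -> xtensor_like)."""
    def decorator(target_cls):
        doc = source_cls.__doc__
        if doc:
            for old, new in replacements:
                doc = doc.replace(old, new)
            target_cls.__doc__ = doc
        return target_cls
    return decorator

class Flat:
    """Flat distribution; takes tensor_like parameters."""

@copy_docstring(Flat)
class DimFlat:
    pass
```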

5 participants