Skip to content

Fix mean, var and std for XTensorVariables #1533

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/xtensor/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
def _infer_reduced_size(original_var, reduced_var):
reduced_dims = reduced_var.dims
return variadic_mul(
*[size for dim, size in original_var.sizes if dim not in reduced_dims]
*[size for dim, size in original_var.sizes.items() if dim not in reduced_dims]
)


Expand All @@ -96,7 +96,7 @@ def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
x = as_xtensor(x)
x_mean = mean(x, dim)
n = _infer_reduced_size(x, x_mean)
return square(x - x_mean) / (n - ddof)
return square(x - x_mean).sum(dim) / (n - ddof)


def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
Expand Down
177 changes: 152 additions & 25 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,21 +366,25 @@ def __trunc__(self):
# https://docs.xarray.dev/en/latest/api.html#id1
@property
def values(self) -> TensorVariable:
"""Convert to a TensorVariable with the same data."""
return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self))

# Can't provide property data because that's already taken by Constants!
# data = values

@property
def coords(self):
"""Not implemented."""
raise NotImplementedError("coords not implemented for XTensorVariable")

@property
def dims(self) -> tuple[str, ...]:
"""The names of the dimensions of the variable."""
return self.type.dims

@property
def sizes(self) -> dict[str, TensorVariable]:
"""The sizes of the dimensions of the variable."""
return dict(zip(self.dims, self.shape))

@property
Expand All @@ -392,18 +396,22 @@ def as_numpy(self):
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
@property
def ndim(self) -> int:
"""The number of dimensions of the variable."""
return self.type.ndim

@property
def shape(self) -> tuple[TensorVariable, ...]:
"""The shape of the variable."""
return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore

@property
def size(self) -> TensorVariable:
"""The total number of elements in the variable."""
return typing.cast(TensorVariable, variadic_mul(*self.shape))

@property
def dtype(self):
def dtype(self) -> str:
"""The data type of the variable."""
return self.type.dtype

@property
Expand All @@ -414,6 +422,7 @@ def broadcastable(self):
# DataArray contents
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
def rename(self, new_name_or_name_dict=None, **names):
"""Rename the variable or its dimension(s)."""
if isinstance(new_name_or_name_dict, str):
new_name = new_name_or_name_dict
name_dict = None
Expand All @@ -425,31 +434,41 @@ def rename(self, new_name_or_name_dict=None, **names):
return new_out

def copy(self, name: str | None = None):
"""Create a copy of the variable.

This is just an identity operation, as XTensorVariables are immutable.
"""
out = px.math.identity(self)
out.name = name
return out

def astype(self, dtype):
"""Convert the variable to a different data type."""
return px.math.cast(self, dtype)

def item(self):
"""Not implemented."""
raise NotImplementedError("item not implemented for XTensorVariable")

# Indexing
# https://docs.xarray.dev/en/latest/api.html#id2
def __setitem__(self, idx, value):
"""Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
raise TypeError(
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
)

@property
def loc(self):
"""Not implemented."""
raise NotImplementedError("loc not implemented for XTensorVariable")

def sel(self, *args, **kwargs):
"""Not implemented."""
raise NotImplementedError("sel not implemented for XTensorVariable")

def __getitem__(self, idx):
"""Index the variable positionally."""
if isinstance(idx, dict):
return self.isel(idx)

Expand All @@ -465,6 +484,7 @@ def isel(
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
**indexers_kwargs,
):
"""Index the variable along the specified dimension(s)."""
if indexers_kwargs:
if indexers is not None:
raise ValueError(
Expand Down Expand Up @@ -505,6 +525,48 @@ def isel(
return px.indexing.index(self, *indices)

def set(self, value):
"""Return a copy of the variable indexed by self with the indexed values set to value.

The original variable is not modified.

Raises
------
ValueError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Value:
ValueError

If self is not the result of an index operation

Examples
--------

.. testcode::

import pytensor.xtensor as ptx

x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].set(1)
out.eval()

.. testoutput::

array([[1, 0],
[0, 1]])


.. testcode::

import pytensor.xtensor as ptx

x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).set(-1)
out.eval()

.. testoutput::

array([[-1, 0],
[0, -1]])

"""
if not (
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
):
Expand All @@ -516,6 +578,48 @@ def set(self, value):
return px.indexing.index_assignment(x, value, *idxs)

def inc(self, value):
"""Return a copy of the variable indexed by self with the indexed values incremented by value.

The original variable is not modified.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indentation got messed up somehow


Raises
------
ValueError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Value:
ValueError

If self is not the result of an index operation

Examples
--------

.. testcode::
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.. test-code::
.. testcode::

None of these blocks are being shown in the rendered version: https://pytensor--1533.org.readthedocs.build/en/1533/library/xtensor/type.html#pytensor.xtensor.type.XTensorVariable.inc

I think it is because the directive name is without hyphen: https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html#directive-testcode


import pytensor.xtensor as ptx

x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].inc(1)
out.eval()

.. testoutput::

array([[2, 1],
[1, 2]])


.. testcode::

import pytensor.xtensor as ptx

x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).inc(-1)
out.eval()

.. testoutput::

array([[0, 1],
[1, 0]])

"""
if not (
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
):
Expand Down Expand Up @@ -579,7 +683,7 @@ def squeeze(
drop=None,
axis: int | Sequence[int] | None = None,
):
"""Remove dimensions of size 1 from an XTensorVariable.
"""Remove dimensions of size 1.

Parameters
----------
Expand All @@ -606,7 +710,7 @@ def expand_dims(
axis: int | Sequence[int] | None = None,
**dim_kwargs,
):
"""Add one or more new dimensions to the tensor.
"""Add one or more new dimensions to the variable.

Parameters
----------
Expand All @@ -616,14 +720,10 @@ def expand_dims(
- int: the new size
- sequence: coordinates (length determines size)
Comment on lines 720 to 721
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- int: the new size
- sequence: coordinates (length determines size)
- int: the new size
- sequence: coordinates (length determines size)

no extra indentation and empty line between previous line and list should render as list instead of current quoted list render

create_index_for_new_dim : bool, default: True
Currently ignored. Reserved for future coordinate support.
In xarray, when True (default), creates a coordinate index for the new dimension
with values from 0 to size-1. When False, no coordinate index is created.
Ignored by PyTensor
axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
By default (None), new dimensions are inserted at the beginning (axis=0).
Symbolic axis is not supported yet.
Negative values count from the end.
**dim_kwargs : int | Sequence
Alternative to `dim` dict. Only used if `dim` is None.

Expand All @@ -643,65 +743,75 @@ def expand_dims(
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
"""Clip the values of the variable to a specified range."""
return px.math.clip(self, min, max)

def conj(self):
"""Return the complex conjugate of the variable."""
return px.math.conj(self)

@property
def imag(self):
"""Return the imaginary part of the variable."""
return px.math.imag(self)

@property
def real(self):
"""Return the real part of the variable."""
return px.math.real(self)

@property
def T(self):
"""Return the full transpose of the tensor.
"""Return the full transpose of the variable.

This is equivalent to calling transpose() with no arguments.

Returns
-------
XTensorVariable
Fully transposed tensor.
"""
return self.transpose()

# Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6
def all(self, dim=None):
"""Reduce the variable by applying `all` along some dimension(s)."""
return px.reduction.all(self, dim)

def any(self, dim=None):
"""Reduce the variable by applying `any` along some dimension(s)."""
return px.reduction.any(self, dim)

def max(self, dim=None):
"""Compute the maximum along the given dimension(s)."""
return px.reduction.max(self, dim)

def min(self, dim=None):
"""Compute the minimum along the given dimension(s)."""
return px.reduction.min(self, dim)

def mean(self, dim=None):
"""Compute the mean along the given dimension(s)."""
return px.reduction.mean(self, dim)

def prod(self, dim=None):
"""Compute the product along the given dimension(s)."""
return px.reduction.prod(self, dim)

def sum(self, dim=None):
"""Compute the sum along the given dimension(s)."""
return px.reduction.sum(self, dim)

def std(self, dim=None):
return px.reduction.std(self, dim)
def std(self, dim=None, ddof=0):
"""Compute the standard deviation along the given dimension(s)."""
return px.reduction.std(self, dim, ddof=ddof)

def var(self, dim=None):
return px.reduction.var(self, dim)
def var(self, dim=None, ddof=0):
"""Compute the variance along the given dimension(s)."""
return px.reduction.var(self, dim, ddof=ddof)

def cumsum(self, dim=None):
"""Compute the cumulative sum along the given dimension(s)."""
return px.reduction.cumsum(self, dim)

def cumprod(self, dim=None):
"""Compute the cumulative product along the given dimension(s)."""
return px.reduction.cumprod(self, dim)

def diff(self, dim, n=1):
Expand All @@ -720,7 +830,7 @@ def transpose(
*dim: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
"""Transpose dimensions of the tensor.
"""Transpose the dimensions of the variable.

Parameters
----------
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the description of missing dims, if you want the bullet point list to be rendered as a list instead of as inline text (see current render https://pytensor--1533.org.readthedocs.build/en/1533/library/xtensor/type.html#pytensor.xtensor.type.XTensorVariable.transpose) there should be an empty line between previous line and start of list:

How to handle dimensions that don't exist in the tensor:

- "raise": Raise an error if any dimensions don't exist

Expand All @@ -747,21 +857,38 @@ def transpose(
return px.shape.transpose(self, *dim, missing_dims=missing_dims)

def stack(self, dim, **dims):
"""Stack existing dimensions into a single new dimension."""
return px.shape.stack(self, dim, **dims)

def unstack(self, dim, **dims):
"""Unstack a dimension into multiple dimensions of a given size.

Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.

.. testcode::

import pytensor.xtensor as ptx

x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
unstacked_cumsum.eval()
Comment on lines +871 to +876
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be indented


.. testoutput::

array([[ 1, 3],
[ 6, 10]])

"""
return px.shape.unstack(self, dim, **dims)

def dot(self, other, dim=None):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
"""Generalized dot product with another XTensorVariable."""
return px.math.dot(self, other, dim=dim)

def broadcast(self, *others, exclude=None):
"""Broadcast this tensor against other XTensorVariables."""
return px.shape.broadcast(self, *others, exclude=exclude)

def broadcast_like(self, other, exclude=None):
"""Broadcast this tensor against another XTensorVariable."""
"""Broadcast against another XTensorVariable."""
_, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
return self_bcast

Expand Down
Loading
Loading