-
Notifications
You must be signed in to change notification settings - Fork 137
Fix mean, var and std for XTensorVariables #1533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -366,21 +366,25 @@ def __trunc__(self): | |||||||||||
# https://docs.xarray.dev/en/latest/api.html#id1 | ||||||||||||
@property | ||||||||||||
def values(self) -> TensorVariable: | ||||||||||||
"""Convert to a TensorVariable with the same data.""" | ||||||||||||
return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self)) | ||||||||||||
|
||||||||||||
# Can't provide property data because that's already taken by Constants! | ||||||||||||
# data = values | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def coords(self): | ||||||||||||
"""Not implemented.""" | ||||||||||||
raise NotImplementedError("coords not implemented for XTensorVariable") | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def dims(self) -> tuple[str, ...]: | ||||||||||||
"""The names of the dimensions of the variable.""" | ||||||||||||
return self.type.dims | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def sizes(self) -> dict[str, TensorVariable]: | ||||||||||||
"""The sizes of the dimensions of the variable.""" | ||||||||||||
return dict(zip(self.dims, self.shape)) | ||||||||||||
|
||||||||||||
@property | ||||||||||||
|
@@ -392,18 +396,22 @@ def as_numpy(self): | |||||||||||
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes | ||||||||||||
@property | ||||||||||||
def ndim(self) -> int: | ||||||||||||
"""The number of dimensions of the variable.""" | ||||||||||||
return self.type.ndim | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def shape(self) -> tuple[TensorVariable, ...]: | ||||||||||||
"""The shape of the variable.""" | ||||||||||||
return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def size(self) -> TensorVariable: | ||||||||||||
"""The total number of elements in the variable.""" | ||||||||||||
return typing.cast(TensorVariable, variadic_mul(*self.shape)) | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def dtype(self): | ||||||||||||
def dtype(self) -> str: | ||||||||||||
"""The data type of the variable.""" | ||||||||||||
return self.type.dtype | ||||||||||||
|
||||||||||||
@property | ||||||||||||
|
@@ -414,6 +422,7 @@ def broadcastable(self): | |||||||||||
# DataArray contents | ||||||||||||
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents | ||||||||||||
def rename(self, new_name_or_name_dict=None, **names): | ||||||||||||
"""Rename the variable or its dimension(s).""" | ||||||||||||
if isinstance(new_name_or_name_dict, str): | ||||||||||||
new_name = new_name_or_name_dict | ||||||||||||
name_dict = None | ||||||||||||
|
@@ -425,31 +434,41 @@ def rename(self, new_name_or_name_dict=None, **names): | |||||||||||
return new_out | ||||||||||||
|
||||||||||||
def copy(self, name: str | None = None): | ||||||||||||
"""Create a copy of the variable. | ||||||||||||
|
||||||||||||
This is just an identity operation, as XTensorVariables are immutable. | ||||||||||||
""" | ||||||||||||
out = px.math.identity(self) | ||||||||||||
out.name = name | ||||||||||||
return out | ||||||||||||
|
||||||||||||
def astype(self, dtype): | ||||||||||||
"""Convert the variable to a different data type.""" | ||||||||||||
return px.math.cast(self, dtype) | ||||||||||||
|
||||||||||||
def item(self): | ||||||||||||
"""Not implemented.""" | ||||||||||||
raise NotImplementedError("item not implemented for XTensorVariable") | ||||||||||||
|
||||||||||||
# Indexing | ||||||||||||
# https://docs.xarray.dev/en/latest/api.html#id2 | ||||||||||||
def __setitem__(self, idx, value): | ||||||||||||
"""Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead.""" | ||||||||||||
raise TypeError( | ||||||||||||
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead." | ||||||||||||
) | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def loc(self): | ||||||||||||
"""Not implemented.""" | ||||||||||||
raise NotImplementedError("loc not implemented for XTensorVariable") | ||||||||||||
|
||||||||||||
def sel(self, *args, **kwargs): | ||||||||||||
"""Not implemented.""" | ||||||||||||
raise NotImplementedError("sel not implemented for XTensorVariable") | ||||||||||||
|
||||||||||||
def __getitem__(self, idx): | ||||||||||||
"""Index the variable positionally.""" | ||||||||||||
if isinstance(idx, dict): | ||||||||||||
return self.isel(idx) | ||||||||||||
|
||||||||||||
|
@@ -465,6 +484,7 @@ def isel( | |||||||||||
missing_dims: Literal["raise", "warn", "ignore"] = "raise", | ||||||||||||
**indexers_kwargs, | ||||||||||||
): | ||||||||||||
"""Index the variable along the specified dimension(s).""" | ||||||||||||
if indexers_kwargs: | ||||||||||||
if indexers is not None: | ||||||||||||
raise ValueError( | ||||||||||||
|
@@ -505,6 +525,48 @@ def isel( | |||||||||||
return px.indexing.index(self, *indices) | ||||||||||||
|
||||||||||||
def set(self, value): | ||||||||||||
"""Return a copy of the variable indexed by self with the indexed values set to y. | ||||||||||||
|
||||||||||||
The original variable is not modified. | ||||||||||||
|
||||||||||||
Raises | ||||||||||||
------ | ||||||||||||
Value: | ||||||||||||
If self is not the result of an index operation | ||||||||||||
|
||||||||||||
Examples | ||||||||||||
-------- | ||||||||||||
|
||||||||||||
.. test-code:: | ||||||||||||
|
||||||||||||
import pytensor.xtensor as ptx | ||||||||||||
|
||||||||||||
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b")) | ||||||||||||
idx = ptx.as_xtensor([0, 1], dims=("a",)) | ||||||||||||
out = x[:, idx].set(1) | ||||||||||||
out.eval() | ||||||||||||
|
||||||||||||
.. test-output:: | ||||||||||||
|
||||||||||||
array([[1, 0], | ||||||||||||
[0, 1]]) | ||||||||||||
|
||||||||||||
|
||||||||||||
.. test-code:: | ||||||||||||
|
||||||||||||
import pytensor.xtensor as ptx | ||||||||||||
|
||||||||||||
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b")) | ||||||||||||
idx = ptx.as_xtensor([0, 1], dims=("a",)) | ||||||||||||
out = x.isel({"b": idx}).set(-1) | ||||||||||||
out.eval() | ||||||||||||
|
||||||||||||
.. test-output:: | ||||||||||||
|
||||||||||||
array([[-1, 0], | ||||||||||||
[0, -1]]) | ||||||||||||
|
||||||||||||
""" | ||||||||||||
if not ( | ||||||||||||
self.owner is not None and isinstance(self.owner.op, px.indexing.Index) | ||||||||||||
): | ||||||||||||
|
@@ -516,6 +578,48 @@ def set(self, value): | |||||||||||
return px.indexing.index_assignment(x, value, *idxs) | ||||||||||||
|
||||||||||||
def inc(self, value): | ||||||||||||
"""Return a copy of the variable indexed by self with the indexed values incremented by value. | ||||||||||||
|
||||||||||||
The original variable is not modified. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The indentation got messed up somehow |
||||||||||||
|
||||||||||||
Raises | ||||||||||||
------ | ||||||||||||
Value: | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
If self is not the result of an index operation | ||||||||||||
|
||||||||||||
Examples | ||||||||||||
-------- | ||||||||||||
|
||||||||||||
.. test-code:: | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
None of these blocks are being shown in the rendered version: https://pytensor--1533.org.readthedocs.build/en/1533/library/xtensor/type.html#pytensor.xtensor.type.XTensorVariable.inc I think it is because the directive name is without hyphen: https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html#directive-testcode |
||||||||||||
|
||||||||||||
import pytensor.xtensor as ptx | ||||||||||||
|
||||||||||||
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b")) | ||||||||||||
idx = ptx.as_xtensor([0, 1], dims=("a",)) | ||||||||||||
out = x[:, idx].inc(1) | ||||||||||||
out.eval() | ||||||||||||
|
||||||||||||
.. test-output:: | ||||||||||||
|
||||||||||||
array([[2, 1], | ||||||||||||
[1, 2]]) | ||||||||||||
|
||||||||||||
|
||||||||||||
.. test-code:: | ||||||||||||
|
||||||||||||
import pytensor.xtensor as ptx | ||||||||||||
|
||||||||||||
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b")) | ||||||||||||
idx = ptx.as_xtensor([0, 1], dims=("a",)) | ||||||||||||
out = x.isel({"b": idx}).inc(-1) | ||||||||||||
out.eval() | ||||||||||||
|
||||||||||||
.. test-output:: | ||||||||||||
|
||||||||||||
array([[0, 1], | ||||||||||||
[1, 0]]) | ||||||||||||
|
||||||||||||
""" | ||||||||||||
if not ( | ||||||||||||
self.owner is not None and isinstance(self.owner.op, px.indexing.Index) | ||||||||||||
): | ||||||||||||
|
@@ -579,7 +683,7 @@ def squeeze( | |||||||||||
drop=None, | ||||||||||||
axis: int | Sequence[int] | None = None, | ||||||||||||
): | ||||||||||||
"""Remove dimensions of size 1 from an XTensorVariable. | ||||||||||||
"""Remove dimensions of size 1. | ||||||||||||
|
||||||||||||
Parameters | ||||||||||||
---------- | ||||||||||||
|
@@ -606,7 +710,7 @@ def expand_dims( | |||||||||||
axis: int | Sequence[int] | None = None, | ||||||||||||
**dim_kwargs, | ||||||||||||
): | ||||||||||||
"""Add one or more new dimensions to the tensor. | ||||||||||||
"""Add one or more new dimensions to the variable. | ||||||||||||
|
||||||||||||
Parameters | ||||||||||||
---------- | ||||||||||||
|
@@ -616,14 +720,10 @@ def expand_dims( | |||||||||||
- int: the new size | ||||||||||||
- sequence: coordinates (length determines size) | ||||||||||||
Comment on lines
720
to
721
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
no extra indentation and empty line between previous line and list should render as list instead of current quoted list render |
||||||||||||
create_index_for_new_dim : bool, default: True | ||||||||||||
Currently ignored. Reserved for future coordinate support. | ||||||||||||
In xarray, when True (default), creates a coordinate index for the new dimension | ||||||||||||
with values from 0 to size-1. When False, no coordinate index is created. | ||||||||||||
Ignored by PyTensor | ||||||||||||
axis : int | Sequence[int] | None, default: None | ||||||||||||
Not implemented yet. In xarray, specifies where to insert the new dimension(s). | ||||||||||||
By default (None), new dimensions are inserted at the beginning (axis=0). | ||||||||||||
Symbolic axis is not supported yet. | ||||||||||||
Negative values count from the end. | ||||||||||||
**dim_kwargs : int | Sequence | ||||||||||||
Alternative to `dim` dict. Only used if `dim` is None. | ||||||||||||
|
||||||||||||
|
@@ -643,65 +743,75 @@ def expand_dims( | |||||||||||
# ndarray methods | ||||||||||||
# https://docs.xarray.dev/en/latest/api.html#id7 | ||||||||||||
def clip(self, min, max): | ||||||||||||
"""Clip the values of the variable to a specified range.""" | ||||||||||||
return px.math.clip(self, min, max) | ||||||||||||
|
||||||||||||
def conj(self): | ||||||||||||
"""Return the complex conjugate of the variable.""" | ||||||||||||
return px.math.conj(self) | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def imag(self): | ||||||||||||
"""Return the imaginary part of the variable.""" | ||||||||||||
return px.math.imag(self) | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def real(self): | ||||||||||||
"""Return the real part of the variable.""" | ||||||||||||
return px.math.real(self) | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def T(self): | ||||||||||||
"""Return the full transpose of the tensor. | ||||||||||||
"""Return the full transpose of the variable. | ||||||||||||
|
||||||||||||
This is equivalent to calling transpose() with no arguments. | ||||||||||||
|
||||||||||||
Returns | ||||||||||||
------- | ||||||||||||
XTensorVariable | ||||||||||||
Fully transposed tensor. | ||||||||||||
""" | ||||||||||||
return self.transpose() | ||||||||||||
|
||||||||||||
# Aggregation | ||||||||||||
# https://docs.xarray.dev/en/latest/api.html#id6 | ||||||||||||
def all(self, dim=None): | ||||||||||||
"""Reduce the variable by applying `all` along some dimension(s).""" | ||||||||||||
return px.reduction.all(self, dim) | ||||||||||||
|
||||||||||||
def any(self, dim=None): | ||||||||||||
"""Reduce the variable by applying `any` along some dimension(s).""" | ||||||||||||
return px.reduction.any(self, dim) | ||||||||||||
|
||||||||||||
def max(self, dim=None): | ||||||||||||
"""Compute the maximum along the given dimension(s).""" | ||||||||||||
return px.reduction.max(self, dim) | ||||||||||||
|
||||||||||||
def min(self, dim=None): | ||||||||||||
"""Compute the minimum along the given dimension(s).""" | ||||||||||||
return px.reduction.min(self, dim) | ||||||||||||
|
||||||||||||
def mean(self, dim=None): | ||||||||||||
"""Compute the mean along the given dimension(s).""" | ||||||||||||
return px.reduction.mean(self, dim) | ||||||||||||
|
||||||||||||
def prod(self, dim=None): | ||||||||||||
"""Compute the product along the given dimension(s).""" | ||||||||||||
return px.reduction.prod(self, dim) | ||||||||||||
|
||||||||||||
def sum(self, dim=None): | ||||||||||||
"""Compute the sum along the given dimension(s).""" | ||||||||||||
return px.reduction.sum(self, dim) | ||||||||||||
|
||||||||||||
def std(self, dim=None): | ||||||||||||
return px.reduction.std(self, dim) | ||||||||||||
def std(self, dim=None, ddof=0): | ||||||||||||
"""Compute the standard deviation along the given dimension(s).""" | ||||||||||||
return px.reduction.std(self, dim, ddof=ddof) | ||||||||||||
|
||||||||||||
def var(self, dim=None): | ||||||||||||
return px.reduction.var(self, dim) | ||||||||||||
def var(self, dim=None, ddof=0): | ||||||||||||
"""Compute the variance along the given dimension(s).""" | ||||||||||||
return px.reduction.var(self, dim, ddof=ddof) | ||||||||||||
|
||||||||||||
def cumsum(self, dim=None): | ||||||||||||
"""Compute the cumulative sum along the given dimension(s).""" | ||||||||||||
return px.reduction.cumsum(self, dim) | ||||||||||||
|
||||||||||||
def cumprod(self, dim=None): | ||||||||||||
"""Compute the cumulative product along the given dimension(s).""" | ||||||||||||
return px.reduction.cumprod(self, dim) | ||||||||||||
|
||||||||||||
def diff(self, dim, n=1): | ||||||||||||
|
@@ -720,7 +830,7 @@ def transpose( | |||||||||||
*dim: str | EllipsisType, | ||||||||||||
missing_dims: Literal["raise", "warn", "ignore"] = "raise", | ||||||||||||
): | ||||||||||||
"""Transpose dimensions of the tensor. | ||||||||||||
"""Transpose the dimensions of the variable. | ||||||||||||
|
||||||||||||
Parameters | ||||||||||||
---------- | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the description of missing dims, if you want the bullet point list to be rendered as a list instead of as inline text (see current render https://pytensor--1533.org.readthedocs.build/en/1533/library/xtensor/type.html#pytensor.xtensor.type.XTensorVariable.transpose) there should be an empty line between previous line and start of list:
|
||||||||||||
|
@@ -747,21 +857,38 @@ def transpose( | |||||||||||
return px.shape.transpose(self, *dim, missing_dims=missing_dims) | ||||||||||||
|
||||||||||||
def stack(self, dim, **dims): | ||||||||||||
"""Stack existing dimensions into a single new dimension.""" | ||||||||||||
return px.shape.stack(self, dim, **dims) | ||||||||||||
|
||||||||||||
def unstack(self, dim, **dims): | ||||||||||||
"""Unstack a dimension into multiple dimensions of a given size. | ||||||||||||
|
||||||||||||
Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified. | ||||||||||||
Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions. | ||||||||||||
|
||||||||||||
.. test-code:: | ||||||||||||
|
||||||||||||
import pytensor.xtensor as ptx | ||||||||||||
|
||||||||||||
x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b")) | ||||||||||||
stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c") | ||||||||||||
unstacked_cumsum = stacked_x.unstack({"c": x.sizes}) | ||||||||||||
unstacked_cumsum.eval() | ||||||||||||
Comment on lines
+871
to
+876
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be indented |
||||||||||||
|
||||||||||||
.. test-output:: | ||||||||||||
|
||||||||||||
array([[ 1, 3], | ||||||||||||
[ 6, 10]]) | ||||||||||||
|
||||||||||||
""" | ||||||||||||
return px.shape.unstack(self, dim, **dims) | ||||||||||||
|
||||||||||||
def dot(self, other, dim=None): | ||||||||||||
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims.""" | ||||||||||||
"""Generalized dot product with another XTensorVariable.""" | ||||||||||||
return px.math.dot(self, other, dim=dim) | ||||||||||||
|
||||||||||||
def broadcast(self, *others, exclude=None): | ||||||||||||
"""Broadcast this tensor against other XTensorVariables.""" | ||||||||||||
return px.shape.broadcast(self, *others, exclude=exclude) | ||||||||||||
|
||||||||||||
def broadcast_like(self, other, exclude=None): | ||||||||||||
"""Broadcast this tensor against another XTensorVariable.""" | ||||||||||||
"""Broadcast against another XTensorVariable.""" | ||||||||||||
_, self_bcast = px.shape.broadcast(other, self, exclude=exclude) | ||||||||||||
return self_bcast | ||||||||||||
|
||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.