Skip to content

fix:*-like creation routines take kwargs #2992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changes/2992.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix a bug preventing ``ones_like``, ``full_like``, ``empty_like``, ``zeros_like`` and ``open_like`` functions from accepting
an explicit specification of array attributes like shape, dtype, chunks etc. The functions ``full_like``,
``empty_like``, and ``open_like`` now also more consistently infer a ``fill_value`` parameter from the provided array.
51 changes: 32 additions & 19 deletions src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import asyncio
import dataclasses
import warnings
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast

import numpy as np
import numpy.typing as npt
@@ -38,13 +38,15 @@
create_hierarchy,
)
from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata
from zarr.core.metadata.v2 import _default_compressor, _default_filters
from zarr.core.metadata.v2 import CompressorLikev2, _default_compressor, _default_filters
from zarr.errors import NodeTypeValidationError
from zarr.storage._common import make_store_path

if TYPE_CHECKING:
from collections.abc import Iterable

import numcodecs

from zarr.abc.codec import Codec
from zarr.core.buffer import NDArrayLikeOrScalar
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
@@ -116,10 +118,20 @@
return shape, chunks


def _like_args(a: ArrayLike, kwargs: dict[str, Any]) -> dict[str, Any]:
class _LikeArgs(TypedDict):
shape: NotRequired[ChunkCoords]
chunks: NotRequired[ChunkCoords]
dtype: NotRequired[np.dtype[np.generic]]
order: NotRequired[Literal["C", "F"]]
filters: NotRequired[tuple[numcodecs.abc.Codec, ...] | None]
compressor: NotRequired[CompressorLikev2]
codecs: NotRequired[tuple[Codec, ...]]


def _like_args(a: ArrayLike) -> _LikeArgs:
"""Set default values for shape and chunks if they are not present in the array-like object"""

new = kwargs.copy()
new: _LikeArgs = {}

Check warning on line 134 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L134

Added line #L134 was not covered by tests

shape, chunks = _get_shape_chunks(a)
if shape is not None:
@@ -1078,7 +1090,7 @@
shape: ChunkCoords, **kwargs: Any
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
"""Create an empty array with the specified shape. The contents will be filled with the
array's fill value or zeros if no fill value is provided.
specified fill value or zeros if no fill value is provided.

Parameters
----------
@@ -1093,8 +1105,7 @@
retrieve data from an empty Zarr array, any values may be returned,
and these are not guaranteed to be stable from one access to the next.
"""

return await create(shape=shape, fill_value=None, **kwargs)
return await create(shape=shape, **kwargs)

Check warning on line 1108 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1108

Added line #L1108 was not covered by tests


async def empty_like(
@@ -1121,8 +1132,10 @@
retrieve data from an empty Zarr array, any values may be returned,
and these are not guaranteed to be stable from one access to the next.
"""
like_kwargs = _like_args(a, kwargs)
return await empty(**like_kwargs)
like_kwargs = _like_args(a) | kwargs
if isinstance(a, (AsyncArray | Array)):
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
return await empty(**like_kwargs) # type: ignore[arg-type]

Check warning on line 1138 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1135-L1138

Added lines #L1135 - L1138 were not covered by tests


# TODO: add type annotations for fill_value and kwargs
@@ -1167,10 +1180,10 @@
Array
The new array.
"""
like_kwargs = _like_args(a, kwargs)
if isinstance(a, AsyncArray):
like_kwargs = _like_args(a) | kwargs
if isinstance(a, (AsyncArray | Array)):

Check warning on line 1184 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1183-L1184

Added lines #L1183 - L1184 were not covered by tests
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
return await full(**like_kwargs)
return await full(**like_kwargs) # type: ignore[arg-type]

Check warning on line 1186 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1186

Added line #L1186 was not covered by tests


async def ones(
@@ -1211,8 +1224,8 @@
Array
The new array.
"""
like_kwargs = _like_args(a, kwargs)
return await ones(**like_kwargs)
like_kwargs = _like_args(a) | kwargs
return await ones(**like_kwargs) # type: ignore[arg-type]

Check warning on line 1228 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1227-L1228

Added lines #L1227 - L1228 were not covered by tests


async def open_array(
@@ -1292,10 +1305,10 @@
AsyncArray
The opened array.
"""
like_kwargs = _like_args(a, kwargs)
like_kwargs = _like_args(a) | kwargs

Check warning on line 1308 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1308

Added line #L1308 was not covered by tests
if isinstance(a, (AsyncArray | Array)):
kwargs.setdefault("fill_value", a.metadata.fill_value)
return await open_array(path=path, **like_kwargs)
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
return await open_array(path=path, **like_kwargs) # type: ignore[arg-type]

Check warning on line 1311 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1310-L1311

Added lines #L1310 - L1311 were not covered by tests


async def zeros(
@@ -1336,5 +1349,5 @@
Array
The new array.
"""
like_kwargs = _like_args(a, kwargs)
return await zeros(**like_kwargs)
like_kwargs = _like_args(a) | kwargs
return await zeros(**like_kwargs) # type: ignore[arg-type]

Check warning on line 1353 in src/zarr/api/asynchronous.py

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1352-L1353

Added lines #L1352 - L1353 were not covered by tests
97 changes: 96 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import zarr.codecs
import zarr.storage
@@ -77,6 +77,101 @@ def test_create(memory_store: Store) -> None:
z = create(shape=(400, 100), chunks=(16, 16.5), store=store, overwrite=True) # type: ignore [arg-type]


@pytest.mark.parametrize(
"func",
[
zarr.api.asynchronous.zeros_like,
zarr.api.asynchronous.ones_like,
zarr.api.asynchronous.empty_like,
zarr.api.asynchronous.full_like,
zarr.api.asynchronous.open_like,
],
)
@pytest.mark.parametrize("out_shape", ["keep", (10, 10)])
@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)])
@pytest.mark.parametrize("out_dtype", ["keep", "int8"])
@pytest.mark.parametrize("out_fill", ["keep", 4])
async def test_array_like_creation(
zarr_format: ZarrFormat,
func: Callable[[Any], Any],
out_shape: Literal["keep"] | tuple[int, ...],
out_chunks: Literal["keep"] | tuple[int, ...],
out_dtype: str,
out_fill: Literal["keep"] | int,
) -> None:
"""
Test zeros_like, ones_like, empty_like, full_like, ensuring that we can override the
shape, chunks, dtype and fill_value of the array-like object provided to these functions with
appropriate keyword arguments
"""
ref_fill = 100
ref_arr = zarr.create_array(
store={},
shape=(11, 12),
dtype="uint8",
chunks=(11, 12),
zarr_format=zarr_format,
fill_value=ref_fill,
)
kwargs: dict[str, object] = {}
if func is zarr.api.asynchronous.full_like:
if out_fill == "keep":
expect_fill = ref_fill
else:
expect_fill = out_fill
kwargs["fill_value"] = expect_fill
elif func is zarr.api.asynchronous.zeros_like:
expect_fill = 0
elif func is zarr.api.asynchronous.ones_like:
expect_fill = 1
elif func is zarr.api.asynchronous.empty_like:
if out_fill == "keep":
expect_fill = ref_fill
else:
kwargs["fill_value"] = out_fill
expect_fill = out_fill
elif func is zarr.api.asynchronous.open_like: # type: ignore[comparison-overlap]
if out_fill == "keep":
expect_fill = ref_fill
else:
kwargs["fill_value"] = out_fill
expect_fill = out_fill
kwargs["mode"] = "w"
else:
raise AssertionError
if out_shape != "keep":
kwargs["shape"] = out_shape
expect_shape = out_shape
else:
expect_shape = ref_arr.shape
if out_chunks != "keep":
kwargs["chunks"] = out_chunks
expect_chunks = out_chunks
else:
expect_chunks = ref_arr.chunks
if out_dtype != "keep":
kwargs["dtype"] = out_dtype
expect_dtype = out_dtype
else:
expect_dtype = ref_arr.dtype # type: ignore[assignment]

new_arr = await func(ref_arr, path="foo", **kwargs) # type: ignore[call-arg]
assert new_arr.shape == expect_shape
assert new_arr.chunks == expect_chunks
assert new_arr.dtype == expect_dtype
assert np.all(Array(new_arr)[:] == expect_fill)


async def test_invalid_full_like() -> None:
"""
Test that a fill value that is incompatible with the proposed dtype is rejected
"""
ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12))
fill = 4
with pytest.raises(ValueError, match=f"fill value {fill} is not valid for dtype DataType.bool"):
await zarr.api.asynchronous.full_like(ref_arr, path="foo", fill_value=fill, dtype="bool")


# TODO: parametrize over everything this function takes
@pytest.mark.parametrize("store", ["memory"], indirect=True)
def test_create_array(store: Store, zarr_format: ZarrFormat) -> None:
62 changes: 61 additions & 1 deletion tests/test_group.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
import re
import time
import warnings
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, get_args

import numpy as np
import pytest
@@ -668,6 +668,66 @@ def test_group_create_array(
assert np.array_equal(array[:], data)


LikeMethodName = Literal["zeros_like", "ones_like", "empty_like", "full_like"]


@pytest.mark.parametrize("method_name", get_args(LikeMethodName))
@pytest.mark.parametrize("out_shape", ["keep", (10, 10)])
@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)])
@pytest.mark.parametrize("out_dtype", ["keep", "int8"])
def test_group_array_like_creation(
zarr_format: ZarrFormat,
method_name: LikeMethodName,
out_shape: Literal["keep"] | tuple[int, ...],
out_chunks: Literal["keep"] | tuple[int, ...],
out_dtype: str,
) -> None:
"""
Test Group.{zeros_like, ones_like, empty_like, full_like}, ensuring that we can override the
shape, chunks, and dtype of the array-like object provided to these functions with
appropriate keyword arguments
"""
ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12))
group = Group.from_store({}, zarr_format=zarr_format)
kwargs = {}
if method_name == "full_like":
expect_fill = 4
kwargs["fill_value"] = expect_fill
meth = group.full_like
elif method_name == "zeros_like":
expect_fill = 0
meth = group.zeros_like
elif method_name == "ones_like":
expect_fill = 1
meth = group.ones_like
elif method_name == "empty_like":
expect_fill = ref_arr.fill_value
meth = group.empty_like
else:
raise AssertionError
if out_shape != "keep":
kwargs["shape"] = out_shape
expect_shape = out_shape
else:
expect_shape = ref_arr.shape
if out_chunks != "keep":
kwargs["chunks"] = out_chunks
expect_chunks = out_chunks
else:
expect_chunks = ref_arr.chunks
if out_dtype != "keep":
kwargs["dtype"] = out_dtype
expect_dtype = out_dtype
else:
expect_dtype = ref_arr.dtype

new_arr = meth(name="foo", data=ref_arr, **kwargs)
assert new_arr.shape == expect_shape
assert new_arr.chunks == expect_chunks
assert new_arr.dtype == expect_dtype
assert np.all(new_arr[:] == expect_fill)


def test_group_array_creation(
store: Store,
zarr_format: ZarrFormat,
Loading