Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions docs/tutorial/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import skimage.color
import skimage.data
import tifffile
import zarr
from loguru import logger

import stack_to_chunk
Expand Down Expand Up @@ -99,11 +100,47 @@


# %%
# The levels property can be inspected to show we've added the first level. Ekach level
# is downsampled by a factor of ``2**level``, so level 0 is downsampled by a factor of
# 1, which is just a copy of the original data (as expected).
# The levels property can be inspected to show we've added the first level (level 0):
print(group.levels)

# %%
# Next, we will add some downsampled levels to the group. This is done by calling
# ``add_downsample_level`` on the group object. Each level is linearly downsampled by a
# factor of :math:`2^{level}`, so level 0 is downsampled by a factor of 1 (original
# resolution), level 1 is downsampled by a factor of 2, level 2 by a factor of 4,
# and so on.

group.add_downsample_level(1)
group.add_downsample_level(2)

# %%
# We can see from the progress messages that the shape of each level is half the size of
# the previous level (rounded up to the nearest integer). The chunk sizes are maintained
# the same as in the original data (until the image size is less than the chunk size).

# %%
# We can inspect the levels property again to check that levels 0, 1, and 2 are present:
print(group.levels)

# %%
# Note that the downampled levels have to be added in order, so we can't add level 3
# before adding level 2 (because the previous level is needed to calculate the next).

# %%
# Let's plot the downsampled data to see what it looks like.
# We turn off interpolation (which ``imshow`` does by default) to make the pixelation
# at downsampled levels more clearly visible.

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
for i, level in enumerate(group.levels):
data = zarr.open(temp_dir_path / "chunked.zarr" / str(level))
first_slice = data[:].transpose(2, 1, 0)[0]
ax[i].imshow(first_slice, cmap="gray", interpolation="none")
ax[i].set_title(f"Level {level}")
ax[i].axis("off")
fig.tight_layout()
fig.show()

# %%
# Cleanup
# -------
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ classifiers = [
dependencies = [
"dask",
"loguru",
"scikit-image",
"zarr",
]
description = "Convert stacks of images to chunked datasets"
Expand Down
1 change: 1 addition & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ per-file-ignores = {"docs/*" = [
"D103", # Missing docstring in public function
"INP001", # is part of an implicit namespace package
"S101",
"SLF001", # Private member accessed
]}
select = [
"ALL",
Expand Down
82 changes: 82 additions & 0 deletions src/stack_to_chunk/downsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
Utilities for downsampling images.

These are based on the ``ome_zarr.dask_utils.py`` module of the ome-zarr-py library,
originally contributed by by Andreas Eisenbarth @aeisenbarth.
See https://github.com/toloudis/ome-zarr-py/pull/1
"""

import numpy as np
import skimage.transform
from dask import array as da


def _rechunk_to_even(image: da.Array) -> da.Array:
"""
Rechunk the input image so that chunk sizes are even in each dimension.

This guarantees integer chunk sizes after downsampling by two.
"""
factors = np.array([0.5] * image.ndim)
even_chunksize = tuple(
np.maximum(1, np.round(np.array(image.chunksize) * factors) / factors).astype(
int
)
)
return image.rechunk(even_chunksize)


def _half_shape(input_shape: tuple[int, int, int]) -> tuple[int, int, int]:
"""
Calculate the output shape after downsampling by two in each dimension.

Rounds up to the nearest integer after division.
"""
return tuple(np.ceil(np.array(input_shape) / 2).astype(int))


def _resize_block(block: da.Array) -> da.Array:
"""
Resize a block by a factor of 2 in each dimension using linear interpolation.
"""
new_block_shape = _half_shape(block.shape)
return skimage.transform.resize(
block,
new_block_shape,
order=1,
anti_aliasing=False,
).astype(block.dtype)


def downsample_by_two(image: da.Array) -> da.Array:
"""
Downsample a dask array by two in each dimension.

Parameters
----------
image : da.Array
The input image.

Returns
-------
da.Array
The downsampled image, which has half the size of the input image in each
dimension.

"""
new_image_shape = _half_shape(image.shape)
new_image_slices = tuple(slice(0, d) for d in new_image_shape)

# Rechunk the image so that chunk sizes will be whole numbers after downsampling
image_rechunked = _rechunk_to_even(image)
new_image_chunksize = _half_shape(image_rechunked.chunksize)

new_image = da.map_blocks(
_resize_block,
image_rechunked,
chunks=new_image_chunksize,
dtype=image.dtype,
)[new_image_slices]

# restore the original chunking and type
return new_image.rechunk(image.chunksize).astype(image.dtype)
72 changes: 54 additions & 18 deletions src/stack_to_chunk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Any, Literal

import dask.array as da
import numpy as np
import zarr
from dask.array.core import Array
Expand All @@ -14,6 +15,7 @@
from numcodecs.abc import Codec

from stack_to_chunk._array_helpers import _copy_slab
from stack_to_chunk.downsample import downsample_by_two
from stack_to_chunk.ome_ngff import SPATIAL_UNIT


Expand Down Expand Up @@ -99,7 +101,7 @@ def levels(self) -> list[int]:

def add_full_res_data(
self,
data: Array,
data: da.Array,
*,
chunk_size: int,
compressor: Literal["default"] | Codec,
Expand Down Expand Up @@ -178,18 +180,9 @@ def add_full_res_data(
p.join()

blosc.use_threads = blosc_use_threads
self._add_level_metadata(0)
logger.info("Finished full resolution copy to zarr.")

multiscales = self._group.attrs["multiscales"]
multiscales[0]["datasets"].append(
{
"path": "0",
"coordinateTransformations": [{"type": "scale", "scale": [1, 1, 1]}],
}
)

self._group.attrs["multiscales"] = multiscales

def add_downsample_level(self, level: int) -> None:
"""
Add a level of downsampling.
Expand Down Expand Up @@ -221,12 +214,55 @@ def add_downsample_level(self, level: int) -> None:
msg,
)

source_data = self._group[level_minus_one]
new_shape = np.ceil(np.array(source_data.shape) / 2)
logger.info(f"Downsampling level {level_minus_one} to level {level_str}...")
# Get the source data from the level below as a dask array
source_store = self._group[level_minus_one]
source_data = da.from_zarr(source_store, chunks=source_store.chunks)

self._group[level_str] = zarr.create(
new_shape,
chunks=source_data.chunks,
dtype=source_data.dtype,
compressor=source_data.compressor,
# Linearly downsample the data by a factor of 2 in each dimension
new_data = downsample_by_two(source_data)
logger.info(
f"Generated level {level_str} array with shape {new_data.shape} "
f"and chunk sizes {new_data.chunksize}, using linear interpolation."
)

# Create the new zarr store for the downsampled data
new_store = self._group.require_dataset(
level_str,
shape=new_data.shape,
chunks=source_store.chunks,
dtype=source_store.dtype,
compressor=source_store.compressor,
)
# Write the downsampled data to the new store
new_data.to_zarr(new_store, compute=True)
self._add_level_metadata(level)
logger.info(f"Saved level {level_str} to zarr.")

def _add_level_metadata(self, level: int = 0) -> None:
"""
Add the required multiscale metadata for the corresponding level.

Parameters
----------
level :
Level of downsampling. Level 0 corresponds to full resolution data.

"""
# we assume that the scale factor is always 2 in each dimension
scale_factors = [float(s * 2**level) for s in self._voxel_size]
new_dataset = {
"path": str(level),
"coordinateTransformations": [
{
"type": "scale",
"scale": scale_factors,
}
],
}

multiscales = self._group.attrs["multiscales"][0]
existing_dataset_paths = [d["path"] for d in multiscales["datasets"]]
if new_dataset["path"] not in existing_dataset_paths:
multiscales["datasets"].append(new_dataset)
self._group.attrs["multiscales"] = [multiscales]
69 changes: 69 additions & 0 deletions src/stack_to_chunk/tests/test_downsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Tests for the downsample.py module."""

import dask.array as da
import numpy as np
import pytest
from skimage.transform import resize

from stack_to_chunk.downsample import (
_half_shape,
_rechunk_to_even,
_resize_block,
downsample_by_two,
)


class TestDownsample:
"""Tests for the downsample.py module."""

shape_3d = tuple[int, int, int]
image_shape: shape_3d = (583, 245, 156)
image_chunksize: shape_3d = (64, 64, 64)
image_darray: da.Array = da.random.randint(
low=0, high=2**16, dtype=np.uint16, size=image_shape
)

@pytest.mark.parametrize(
"chunksize",
[
(64, 64, 64),
(64, 63, 63),
(63, 63, 63),
],
)
def test_rechunk_to_even(self, chunksize: shape_3d) -> None:
"""Test rechunking to even chunk sizes."""
chunked_array = self.image_darray.rechunk(chunksize)
even_chunksize = _rechunk_to_even(chunked_array).chunksize
assert even_chunksize == self.image_chunksize

def test_half_shape(self) -> None:
"""Test calculating the half shape of an input shape."""
assert _half_shape(self.image_shape) == (292, 123, 78)
assert _half_shape(self.image_chunksize) == (32, 32, 32)

def test_resize_block(self) -> None:
"""Test resizing a single block by a factor of 2 in each dimension."""
block = self.image_darray[:64, :64, :64].compute()
resized_block = _resize_block(block)
assert resized_block.shape == (32, 32, 32)
assert resized_block.dtype == block.dtype

def test_downsample_by_two(self) -> None:
"""Test downsampling a chunked image by a factor of 2 in each dimension."""
input_array = self.image_darray.rechunk(self.image_chunksize)
downsampled = downsample_by_two(input_array)
assert downsampled.chunksize == input_array.chunksize
assert downsampled.dtype == input_array.dtype
assert downsampled.ndim == input_array.ndim
assert downsampled.shape == _half_shape(input_array.shape)

# directly downsample image (without parallelization)
directly_downsampled = resize(
input_array,
output_shape=_half_shape(input_array.shape),
order=1,
anti_aliasing=False,
).astype(input_array.dtype)

np.testing.assert_equal(downsampled.compute(), directly_downsampled)
Loading