diff --git a/src/spatialdata_io/readers/_utils/_image.py b/src/spatialdata_io/readers/_utils/_image.py
new file mode 100644
index 00000000..4621d5fb
--- /dev/null
+++ b/src/spatialdata_io/readers/_utils/_image.py
@@ -0,0 +1,126 @@
+from collections.abc import Callable
+from typing import Any
+
+import dask.array as da
+import numpy as np
+from dask import delayed
+from numpy.typing import DTypeLike, NDArray
+
+
+def _compute_chunk_sizes_positions(size: int, chunk: int, min_coord: int) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
+    """Calculate chunk sizes and positions for a given dimension and chunk size."""
+    # All chunks have the same size except possibly the last one
+    positions = np.arange(min_coord, min_coord + size, chunk)
+    lengths = np.minimum(chunk, min_coord + size - positions)
+
+    return positions, lengths
+
+
+def _compute_chunks(
+    shape: tuple[int, int],
+    chunk_size: tuple[int, int],
+    min_coordinates: tuple[int, int] = (0, 0),
+) -> NDArray[np.int_]:
+    """Create all chunk specs for a given image and chunk size.
+
+    Creates specifications (x, y, width, height) with (x, y) being the upper left corner
+    of chunks of size chunk_size. Chunks at the edges are truncated to the remainder
+    between the chunk size and the image dimensions.
+
+    Parameters
+    ----------
+    shape : tuple[int, int]
+        Size of the image in (height, width), i.e. (y, x) order.
+    chunk_size : tuple[int, int]
+        Size of individual tiles in (height, width).
+    min_coordinates : tuple[int, int], optional
+        Minimum coordinates (y, x) in the image, defaults to (0, 0).
+
+    Returns
+    -------
+    np.ndarray
+        Array of shape (n_tiles_y, n_tiles_x, 4). Each entry defines a tile
+        as (x, y, width, height).
+    """
+    x_positions, widths = _compute_chunk_sizes_positions(shape[1], chunk_size[1], min_coord=min_coordinates[1])
+    y_positions, heights = _compute_chunk_sizes_positions(shape[0], chunk_size[0], min_coord=min_coordinates[0])
+
+    # Generate the tiles, indexed as [row (y), column (x)]
+    tiles = np.array(
+        [
+            [[x, y, w, h] for x, w in zip(x_positions, widths, strict=True)]
+            for y, h in zip(y_positions, heights, strict=True)
+        ],
+        dtype=int,
+    )
+    return tiles
+
+
+def _read_chunks(
+    func: Callable[..., NDArray[np.int_]],
+    slide: Any,
+    coords: NDArray[np.int_],
+    n_channel: int,
+    dtype: DTypeLike,
+    **func_kwargs: Any,
+) -> list[list[da.Array]]:
+    """Tile a large microscopy image into a grid of lazily evaluated chunks.
+
+    Parameters
+    ----------
+    func
+        Function to retrieve a rectangular tile from the slide image. Must take the
+        arguments:
+
+        - slide: Full slide image
+        - x0: x (col) coordinate of the upper left corner of the chunk
+        - y0: y (row) coordinate of the upper left corner of the chunk
+        - width: Width of the chunk
+        - height: Height of the chunk
+
+        and must return the chunk as a numpy array of shape (c, y, x).
+    slide
+        Slide image in a lazily loaded format compatible with func.
+    coords
+        Chunk coordinates as an array of shape (n_tiles_y, n_tiles_x, 4), where the
+        last dimension defines each rectangular tile as (x, y, width, height).
+        n_tiles_y is the number of chunks in the y dimension and n_tiles_x the
+        number of chunks in the x dimension.
+    n_channel
+        Number of channels in the array.
+    dtype
+        Data type of the image.
+    func_kwargs
+        Additional keyword arguments passed to func.
+
+    Returns
+    -------
+    list[list[da.Array]]
+        List (length: n_tiles_y) of lists (length: n_tiles_x) of chunks,
+        representing all chunks of the full image.
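+
+    Examples
+    --------
+    A minimal sketch (illustrative only, assuming an in-memory numpy array as the
+    slide and a plain slicing function as the reader):
+
+    >>> import dask.array as da
+    >>> import numpy as np
+    >>> arr = np.arange(24, dtype=np.uint8).reshape(1, 4, 6)
+    >>> def crop(slide, x0, y0, width, height):
+    ...     return slide[:, y0 : y0 + height, x0 : x0 + width]
+    >>> coords = _compute_chunks((4, 6), chunk_size=(4, 3))
+    >>> chunks = _read_chunks(crop, arr, coords=coords, n_channel=1, dtype=arr.dtype)
+    >>> bool((da.block(chunks).compute() == arr).all())
+    True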
+ """ + func_kwargs = func_kwargs if func_kwargs else {} + + # Collect each delayed chunk as item in list of list + # Inner list becomes dim=-1 (cols/x) + # Outer list becomes dim=-2 (rows/y) + # see dask.array.block + chunks = [ + [ + da.from_delayed( + delayed(func)( + slide, + x0=coords[y, x, 0], + y0=coords[y, x, 1], + width=coords[y, x, 2], + height=coords[y, x, 3], + **func_kwargs, + ), + dtype=dtype, + shape=(n_channel, *coords[y, x, [3, 2]]), + ) + for x in range(coords.shape[1]) + ] + for y in range(coords.shape[0]) + ] + return chunks diff --git a/src/spatialdata_io/readers/generic.py b/src/spatialdata_io/readers/generic.py index 24ee7071..b632eba7 100644 --- a/src/spatialdata_io/readers/generic.py +++ b/src/spatialdata_io/readers/generic.py @@ -3,21 +3,34 @@ import warnings from collections.abc import Sequence from pathlib import Path +from typing import Protocol, TypeVar +import dask.array as da import numpy as np +import tifffile from dask_image.imread import imread from geopandas import GeoDataFrame +from numpy.typing import NDArray from spatialdata._docs import docstring_parameter from spatialdata.models import Image2DModel, ShapesModel from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM from spatialdata.transformations import Identity from xarray import DataArray +from ._utils._image import _compute_chunks, _read_chunks + VALID_IMAGE_TYPES = [".tif", ".tiff", ".png", ".jpg", ".jpeg"] VALID_SHAPE_TYPES = [".geojson"] +DEFAULT_CHUNKSIZE = (1000, 1000) __all__ = ["generic", "geojson", "image", "VALID_IMAGE_TYPES", "VALID_SHAPE_TYPES"] +T = TypeVar("T", bound=np.generic) # Restrict to NumPy scalar types + + +class DaskArray(Protocol[T]): + dtype: np.dtype[T] + @docstring_parameter( valid_image_types=", ".join(VALID_IMAGE_TYPES), @@ -68,11 +81,69 @@ def geojson(input: Path, coordinate_system: str) -> GeoDataFrame: return ShapesModel.parse(input, transformations={coordinate_system: Identity()}) +def _tiff_to_chunks(input: Path, axes_dim_mapping: dict[str, int]) -> list[list[DaskArray[np.int_]]]: + """Chunkwise reader for tiff files. 
+
+    Parameters
+    ----------
+    input
+        Path to the image.
+    axes_dim_mapping
+        Mapping between dimension name (x, y, c) and its index in the data.
+
+    Returns
+    -------
+    list[list[DaskArray]]
+        Nested list (rows, then columns) of lazily loaded image chunks.
+    """
+    # Lazy file reader
+    slide = tifffile.memmap(input)
+
+    # Transpose to (c, y, x) order
+    slide = np.transpose(slide, (axes_dim_mapping["c"], axes_dim_mapping["y"], axes_dim_mapping["x"]))
+
+    # Get dimensions in (y, x), matching the convention of _compute_chunks
+    slide_dimensions = slide.shape[1], slide.shape[2]
+
+    # Get number of channels (c)
+    n_channel = slide.shape[0]
+
+    # Compute chunk coordinates
+    chunk_coords = _compute_chunks(slide_dimensions, chunk_size=DEFAULT_CHUNKSIZE, min_coordinates=(0, 0))
+
+    # Define reader func
+    def _reader_func(slide: NDArray[np.int_], x0: int, y0: int, width: int, height: int) -> NDArray[np.int_]:
+        return np.array(slide[:, y0 : y0 + height, x0 : x0 + width])
+
+    return _read_chunks(_reader_func, slide, coords=chunk_coords, n_channel=n_channel, dtype=slide.dtype)
+
+
 def image(input: Path, data_axes: Sequence[str], coordinate_system: str) -> DataArray:
-    """Reads an image file and returns a parsed Image2D spatial element"""
-    # this function is just a draft, the more general one will be available when
-    # https://github.com/scverse/spatialdata-io/pull/234 is merged
-    image = imread(input)
-    if len(image.shape) == len(data_axes) + 1 and image.shape[0] == 1:
-        image = np.squeeze(image, axis=0)
+    """Read an image file and return a parsed Image2DModel."""
+    # Map each passed data axis to the position of its dimension
+    axes_dim_mapping = {axes: ndim for ndim, axes in enumerate(data_axes)}
+
+    if input.suffix in [".tiff", ".tif"]:
+        try:
+            chunks = _tiff_to_chunks(input, axes_dim_mapping=axes_dim_mapping)
+            image = da.block(chunks, allow_unknown_chunksizes=True)
+
+        # Edge case: compressed images are not memory-mappable
+        except ValueError:
+            warnings.warn(
+                "Image data are not memory-mappable, potentially due to compression. "
+                "Trying to load the image into memory at once.",
+                stacklevel=2,
+            )
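+            # Fall back to dask_image.imread, which can decode compressed tiffs;
+            # transpose afterwards to (c, y, x) to match the chunked code path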
+            image = imread(input)
+            if len(image.shape) == len(data_axes) + 1 and image.shape[0] == 1:
+                image = np.squeeze(image, axis=0)
+            image = image.transpose(axes_dim_mapping["c"], axes_dim_mapping["y"], axes_dim_mapping["x"])
+
+    elif input.suffix in [".png", ".jpg", ".jpeg"]:
+        image = imread(input)
+        if len(image.shape) == len(data_axes) + 1 and image.shape[0] == 1:
+            image = np.squeeze(image, axis=0)
+
+    else:
+        raise NotImplementedError(f"File format {input.suffix} not implemented")
+
     return Image2DModel.parse(image, dims=data_axes, transformations={coordinate_system: Identity()})
diff --git a/tests/readers/test_utils_image.py b/tests/readers/test_utils_image.py
new file mode 100644
index 00000000..f1c58fa1
--- /dev/null
+++ b/tests/readers/test_utils_image.py
@@ -0,0 +1,65 @@
+import numpy as np
+import pytest
+from numpy.typing import NDArray
+
+from spatialdata_io.readers._utils._image import (
+    _compute_chunk_sizes_positions,
+    _compute_chunks,
+)
+
+
+@pytest.mark.parametrize(
+    ("size", "chunk", "min_coordinate", "positions", "lengths"),
+    [
+        (300, 100, 0, np.array([0, 100, 200]), np.array([100, 100, 100])),
+        (300, 200, 0, np.array([0, 200]), np.array([200, 100])),
+        (300, 100, -100, np.array([-100, 0, 100]), np.array([100, 100, 100])),
+        (300, 200, -100, np.array([-100, 100]), np.array([200, 100])),
+    ],
+)
+def test_compute_chunk_sizes_positions(
+    size: int,
+    chunk: int,
+    min_coordinate: int,
+    positions: NDArray[np.number],
+    lengths: NDArray[np.number],
+) -> None:
+    computed_positions, computed_lengths = _compute_chunk_sizes_positions(size, chunk, min_coordinate)
+    assert (positions == computed_positions).all()
+    assert (lengths == computed_lengths).all()
+
+
+@pytest.mark.parametrize(
+    ("dimensions", "chunk_size", "min_coordinates", "result"),
+    [
+        # Regular 2x2 grid
+        (
+            (2, 2),
+            (1, 1),
+            (0, 0),
+            np.array([[[0, 0, 1, 1], [1, 0, 1, 1]], [[0, 1, 1, 1], [1, 1, 1, 1]]]),
+        ),
+        # Tiles truncated at the edges
+        (
+            (3, 3),
+            (2, 2),
+            (0, 0),
+            np.array([[[0, 0, 2, 2], [2, 0, 1, 2]], [[0, 2, 2, 1], [2, 2, 1, 1]]]),
+        ),
+        # Negative minimum coordinate in y
+        (
+            (2, 2),
+            (1, 1),
+            (-1, 0),
+            np.array([[[0, -1, 1, 1], [1, -1, 1, 1]], [[0, 0, 1, 1], [1, 0, 1, 1]]]),
+        ),
+    ],
+)
+def test_compute_chunks(
+    dimensions: tuple[int, int],
+    chunk_size: tuple[int, int],
+    min_coordinates: tuple[int, int],
+    result: NDArray[np.number],
+) -> None:
+    tiles = _compute_chunks(dimensions, chunk_size, min_coordinates)
+
+    assert (tiles == result).all()
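+
+
+def test_read_chunks_roundtrip() -> None:
+    # Illustrative sketch added for exposition (not part of the original test
+    # suite): reassembling the chunks produced by _read_chunks with
+    # dask.array.block should reproduce the source array.
+    import dask.array as da
+
+    from spatialdata_io.readers._utils._image import _read_chunks
+
+    def _crop(slide: NDArray[np.int_], x0: int, y0: int, width: int, height: int) -> NDArray[np.int_]:
+        return slide[:, y0 : y0 + height, x0 : x0 + width]
+
+    arr = np.arange(2 * 5 * 7).reshape(2, 5, 7)
+    coords = _compute_chunks((5, 7), chunk_size=(3, 4))
+    chunks = _read_chunks(_crop, arr, coords=coords, n_channel=2, dtype=arr.dtype)
+
+    assert (da.block(chunks).compute() == arr).all()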
diff --git a/tests/test_generic.py b/tests/test_generic.py
index d4dda1a7..3251e8fd 100644
--- a/tests/test_generic.py
+++ b/tests/test_generic.py
@@ -9,9 +9,12 @@
 from PIL import Image
 from spatialdata import SpatialData
 from spatialdata.datasets import blobs
+from tifffile import imread as tiffread
+from tifffile import imwrite as tiffwrite
 
 from spatialdata_io.__main__ import read_generic_wrapper
 from spatialdata_io.converters.generic_to_zarr import generic_to_zarr
+from spatialdata_io.readers.generic import image
 
 
 @contextmanager
@@ -33,6 +36,42 @@
         yield jpg_path, geojson_path, Path(tmpdir)
 
 
+@pytest.fixture(
+    scope="module",
+    params=[
+        {"axes": ("c", "y", "x"), "compression": None},
+        {"axes": ("x", "y", "c"), "compression": None},
+        {"axes": ("c", "y", "x"), "compression": "lzw"},
+        {"axes": ("x", "y", "c"), "compression": "lzw"},
+    ],
+)
+def save_tiff_files(
+    request: pytest.FixtureRequest,
+) -> Generator[tuple[Path, tuple[str, ...], Path], None, None]:
+    with tempfile.TemporaryDirectory() as tmpdir:
+        axes = request.param["axes"]
+        compression = request.param["compression"]
+
+        sdata = blobs()
+        # Save the blobs image as tiff in the requested axis order
+        x = sdata["blobs_image"].transpose(*axes).data.compute()
+        img = np.clip(x * 255, 0, 255).astype(np.uint8)
+
+        tiff_path = Path(tmpdir) / "blobs_image.tiff"
+        tiffwrite(tiff_path, img, compression=compression)
+
+        yield tiff_path, axes, Path(tmpdir)
+
+
+def test_read_tiff(save_tiff_files: tuple[Path, tuple[str, ...], Path]) -> None:
+    tiff_path, axes, _ = save_tiff_files
+    img = image(tiff_path, data_axes=axes, coordinate_system="global")
+
+    reference = tiffread(tiff_path)
+
+    assert (img.compute() == reference).all()
+
+
 @pytest.mark.parametrize("cli", [True, False])
 @pytest.mark.parametrize("element_name", [None, "test_element"])
 def test_read_generic_image(runner: CliRunner, cli: bool, element_name: str | None) -> None: