diff --git a/pyproject.toml b/pyproject.toml
index aec86265d..1628c7678 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -105,7 +105,7 @@ lint.select = [
"ASYNC", # flake8-async
]
# Ignore rules which conflict with ruff formatter.
-lint.ignore = ["COM812", "ISC001",]
+lint.ignore = ["COM812", "ISC001", "RUF100"]
# Allow Ruff to discover `*.ipynb` files.
include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"]
diff --git a/tests/conftest.py b/tests/conftest.py
index aab4b374c..d72f32496 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -115,6 +115,16 @@ def sample_svs(remote_sample: Callable) -> Path:
return remote_sample("svs-1-small")
+@pytest.fixture(scope="session")
+def sample_qptiff(remote_sample: Callable) -> Path:
+ """Sample pytest fixture for qptiff images.
+
+ Download qptiff image for pytest.
+
+ """
+ return remote_sample("qptiff_sample")
+
+
@pytest.fixture(scope="session")
def sample_ome_tiff(remote_sample: Callable) -> Path:
"""Sample pytest fixture for ome-tiff (brightfield pyramid) images.
diff --git a/tests/test_app_bokeh.py b/tests/test_app_bokeh.py
index ce97fb2fd..2aae78ea1 100644
--- a/tests/test_app_bokeh.py
+++ b/tests/test_app_bokeh.py
@@ -143,6 +143,7 @@ def run_app() -> None:
title="Tiatoolbox TileServer",
layers={},
)
+ app.json.sort_keys = False
CORS(app, send_wildcard=True)
app.run(host="127.0.0.1", threaded=True)
diff --git a/tests/test_tiffreader.py b/tests/test_tiffreader.py
index cc956254a..d73f2fee3 100644
--- a/tests/test_tiffreader.py
+++ b/tests/test_tiffreader.py
@@ -1,9 +1,19 @@
"""Test TIFFWSIReader."""
-from collections.abc import Callable
+from __future__ import annotations
+from typing import TYPE_CHECKING
+from unittest.mock import patch
+
+import cv2
+import numpy as np
import pytest
from defusedxml import ElementTree
+from PIL import Image
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+ from pathlib import Path
from tiatoolbox.wsicore import wsireader
@@ -96,3 +106,459 @@ def test_tiffreader_non_tiled_metadata(
)
monkeypatch.setattr(wsi, "_m_info", None)
assert pytest.approx(wsi.info.mpp, abs=0.1) == 0.5
+
+
+def test_tiffreader_fallback_to_virtual(
+ monkeypatch: pytest.MonkeyPatch,
+ track_tmp_path: Path,
+) -> None:
+ """Test fallback to VirtualWSIReader.
+
+ Test fallback to VirtualWSIReader when TIFFWSIReader raises unsupported format.
+
+ """
+
+ class DummyTIFFWSIReader:
+ def __init__(
+ self,
+ input_path: Path,
+ mpp: tuple[float, float] | None = None,
+ power: float | None = None,
+ post_proc: str | None = None,
+ ) -> None:
+ _ = input_path
+ _ = mpp
+ _ = power
+ _ = post_proc
+ error_msg = "Unsupported TIFF WSI format"
+ raise ValueError(error_msg)
+
+ monkeypatch.setattr(wsireader, "TIFFWSIReader", DummyTIFFWSIReader)
+
+ dummy_file = track_tmp_path / "dummy.tiff"
+ dummy_img = np.zeros((10, 10, 3), dtype=np.uint8)
+ cv2.imwrite(str(dummy_file), dummy_img)
+
+ reader = wsireader.WSIReader.try_tiff(dummy_file, ".tiff", None, None, None)
+ assert isinstance(reader, wsireader.VirtualWSIReader)
+
+
+def test_try_tiff_raises_other_valueerror(
+ monkeypatch: pytest.MonkeyPatch, track_tmp_path: Path
+) -> None:
+ """Test try_tiff raises ValueError if not an unsupported TIFF format."""
+ tiff_path = track_tmp_path / "test.tiff"
+ Image.new("RGB", (10, 10), color="white").save(tiff_path)
+
+ # Patch TIFFWSIReader to raise a different ValueError
+ def raise_other_valueerror(*args: object, **kwargs: object) -> None:
+ _ = args
+ _ = kwargs
+ msg = "Some other TIFF error"
+ raise ValueError(msg)
+
+ monkeypatch.setattr(wsireader, "TIFFWSIReader", raise_other_valueerror)
+
+ with pytest.raises(ValueError, match="Some other TIFF error"):
+ wsireader.WSIReader.try_tiff(
+ input_path=tiff_path,
+ last_suffix=".tiff",
+ mpp=(0.5, 0.5),
+ power=20.0,
+ post_proc=None,
+ )
+
+
+def test_parse_filtercolor_metadata_with_filter_pair() -> None:
+ """Test full parsing including filter pair matching from XML metadata."""
+ # We can't possibly test on all the different types of tiff files, so simulate them.
+ xml_str = """
+
+
+ EM123_EX456
+ 255,128,0
+
+
+
+ Channel1
+
+
+
+
+ EM123
+
+
+
+
+ EX456
+
+
+
+
+
+ """
+ root = ElementTree.fromstring(xml_str)
+ result = wsireader.TIFFWSIReader._parse_filtercolor_metadata(root)
+ assert result is not None
+ assert "Channel1" in result
+ assert result["Channel1"] == (1.0, 128 / 255, 0.0)
+
+
+def test_parse_scancolortable_rgb_and_named_colors() -> None:
+ """Test parsing of ScanColorTable with RGB and named color values."""
+ xml_str = """
+
+
+ FITC_Exc_Em
+ 0,255,0
+ DAPI_Exc_Em
+ Blue
+ Cy3_Exc_Em
+
+
+
+ """
+ root = ElementTree.fromstring(xml_str)
+ result = wsireader.TIFFWSIReader._parse_scancolortable(root)
+
+ assert result is not None
+ assert result["FITC"] == (0.0, 1.0, 0.0)
+ assert result["DAPI"] == (0.0, 0.0, 1.0)
+ assert result["Cy3"] is None # Empty value is incluided but not converted
+
+
+def test_get_namespace_extraction() -> None:
+ """Test extraction of XML namespace from root tag."""
+ # Case with namespace
+ xml_with_ns = ''
+ root_with_ns = ElementTree.fromstring(xml_with_ns)
+ result_with_ns = wsireader.TIFFWSIReader._get_namespace(root_with_ns)
+ assert result_with_ns == {"ns": "http://www.openmicroscopy.org/Schemas/OME/2016-06"}
+
+ # Case without namespace
+ xml_without_ns = ""
+ root_without_ns = ElementTree.fromstring(xml_without_ns)
+ result_without_ns = wsireader.TIFFWSIReader._get_namespace(root_without_ns)
+ assert result_without_ns == {}
+
+
+def test_extract_dye_mapping() -> None:
+ """Test extraction of dye mapping including missing and valid cases."""
+ # Case with valid ChannelPriv entries
+ xml_valid = """
+
+
+
+
+
+
+
+
+
+
+ """
+ root_valid = ElementTree.fromstring(xml_valid)
+ ns = {"ns": "http://www.openmicroscopy.org/Schemas/OME/2016-06"}
+ result_valid = wsireader.TIFFWSIReader._extract_dye_mapping(root_valid, ns)
+ assert result_valid == {"Channel:0": "FITC", "Channel:1": "DAPI"}
+
+ # Case with missing
+ xml_missing_value = """
+
+
+
+
+
+
+ """
+ root_missing_value = ElementTree.fromstring(xml_missing_value)
+ result_missing_value = wsireader.TIFFWSIReader._extract_dye_mapping(
+ root_missing_value, ns
+ )
+ assert result_missing_value == {}
+
+ # Case with ChannelPriv missing attributes
+ xml_missing_attrs = """
+
+
+
+
+
+
+
+
+
+
+ """
+ root_missing_attrs = ElementTree.fromstring(xml_missing_attrs)
+ result_missing_attrs = wsireader.TIFFWSIReader._extract_dye_mapping(
+ root_missing_attrs, ns
+ )
+ assert result_missing_attrs == {}
+
+
+@pytest.mark.parametrize(
+ ("color_int", "expected"),
+ [
+ (0xFF0000, (1.0, 0.0, 0.0)), # Red
+ (0x00FF00, (0.0, 1.0, 0.0)), # Green
+ (0x0000FF, (0.0, 0.0, 1.0)), # Blue
+ (-1, (1.0, 1.0, 1.0)), # White (unsigned 32-bit)
+ ],
+)
+def test_int_to_rgb(color_int: int, expected: tuple[float, float, float]) -> None:
+ """Test conversion of integer color values to normalized RGB tuples."""
+ result = wsireader.TIFFWSIReader._int_to_rgb(color_int)
+ assert pytest.approx(result) == expected
+
+
+def test_parse_channel_data() -> None:
+ """Test parsing of channel metadata with valid color values."""
+ xml = """
+
+
+
+
+
+
+
+
+ """
+ root = ElementTree.fromstring(xml)
+ ns = {"ns": "http://www.openmicroscopy.org/Schemas/OME/2016-06"}
+ dye_mapping = {
+ "Channel:0": "DAPI",
+ "Channel:1": "FITC",
+ }
+
+ result = wsireader.TIFFWSIReader._parse_channel_data(root, ns, dye_mapping)
+ assert result == [
+ {
+ "id": "Channel:0",
+ "name": "DAPI",
+ "rgb": (1.0, 0.0, 0.0),
+ "dye": "DAPI",
+ "label": "Channel:0: DAPI (DAPI)",
+ },
+ {
+ "id": "Channel:1",
+ "name": "FITC",
+ "rgb": (0.0, 1.0, 0.0),
+ "dye": "FITC",
+ "label": "Channel:1: FITC (FITC)",
+ },
+ ]
+
+
+def test_parse_channel_data_with_invalid_color() -> None:
+ """Test parsing of channel metadata with an invalid color value."""
+ xml = """
+
+
+
+
+
+
+
+
+ """
+ root = ElementTree.fromstring(xml)
+ ns = {"ns": "http://www.openmicroscopy.org/Schemas/OME/2016-06"}
+ dye_mapping = {
+ "Channel:0": "DAPI",
+ "Channel:1": "FITC",
+ }
+
+ result = wsireader.TIFFWSIReader._parse_channel_data(root, ns, dye_mapping)
+ assert result == [
+ {
+ "id": "Channel:0",
+ "name": "DAPI",
+ "dye": "DAPI",
+ "rgb": (1.0, 0.0, 0.0),
+ "label": "Channel:0: DAPI (DAPI)",
+ },
+ {
+ "id": "Channel:1",
+ "name": "FITC",
+ "dye": "FITC",
+ "rgb": None,
+ "label": "Channel:1: FITC (FITC)",
+ },
+ ]
+
+
+def test_build_color_dict() -> None:
+ """Test building of color dictionary with duplicate channel names."""
+ channel_data = [
+ {
+ "id": "Channel:0",
+ "name": "DAPI",
+ "rgb": (1.0, 0.0, 0.0),
+ "dye": "DAPI",
+ "label": "Channel:0: DAPI (DAPI)",
+ },
+ {
+ "id": "Channel:1",
+ "name": "DAPI",
+ "rgb": (0.0, 1.0, 0.0),
+ "dye": "DAPI",
+ "label": "Channel:1: DAPI (DAPI)",
+ },
+ {
+ "id": "Channel:2",
+ "name": "FITC",
+ "rgb": (0.0, 0.0, 1.0),
+ "dye": "FITC",
+ "label": "Channel:2: FITC (FITC)",
+ },
+ ]
+
+ dye_mapping = {
+ "Channel:0": "DAPI",
+ "Channel:1": "DAPI",
+ "Channel:2": "FITC",
+ }
+
+ result = wsireader.TIFFWSIReader._build_color_dict(channel_data, dye_mapping)
+
+ assert result == {
+ "DAPI (DAPI)": (1.0, 0.0, 0.0),
+ "DAPI (DAPI) [2]": (0.0, 1.0, 0.0),
+ "FITC (FITC)": (0.0, 0.0, 1.0),
+ }
+
+
+def test_get_ome_objective_power_valid() -> None:
+ """Test extraction of objective power from valid OME-XML."""
+ xml = """
+
+
+
+
+
+
+
+
+
+ """
+ reader = wsireader.TIFFWSIReader.__new__(wsireader.TIFFWSIReader)
+ reader.series_n = 0 # Required for _get_ome_mpp
+ reader._get_ome_mpp = lambda _: [0.5, 0.5] # Optional fallback mock
+ result = reader._get_ome_objective_power(ElementTree.fromstring(xml))
+ assert result == 20.0
+
+
+def test_get_ome_objective_power_fallback_mpp() -> None:
+ """Test fallback to MPP-based inference when objective power is missing."""
+ xml = """
+
+
+
+
+
+ """
+ reader = wsireader.TIFFWSIReader.__new__(wsireader.TIFFWSIReader)
+ reader._get_ome_mpp = lambda _: [0.5, 0.5] # Mock MPP extraction
+ result = reader._get_ome_objective_power(ElementTree.fromstring(xml))
+ assert result == 20.0 # Assuming mpp2common_objective_power(0.5) == 20.0
+
+
+def test_get_ome_objective_power_none() -> None:
+ """Test full fallback when both objective power and MPP are missing."""
+ xml = """
+
+
+
+
+
+ """
+ reader = wsireader.TIFFWSIReader.__new__(wsireader.TIFFWSIReader)
+ reader._get_ome_mpp = lambda _: None # Mock missing MPP
+ result = reader._get_ome_objective_power(ElementTree.fromstring(xml))
+ assert result is None
+
+
+def test_handle_tiff_wsi_returns_tiff_reader(
+ monkeypatch: pytest.MonkeyPatch, track_tmp_path: Path
+) -> None:
+ """Test that _handle_tiff_wsi returns TIFFWSIReader for valid TIFF image."""
+ # Create a valid TIFF image using PIL
+ tiff_path = track_tmp_path / "dummy.tiff"
+ image = Image.new("RGB", (10, 10), color="white")
+ image.save(tiff_path)
+
+ # Patch is_tiled_tiff to return True
+ monkeypatch.setattr(wsireader, "is_tiled_tiff", lambda _: True)
+
+ # Patch TIFFWSIReader.__init__ to bypass internal checks
+ with patch(
+ "tiatoolbox.wsicore.wsireader.TIFFWSIReader.__init__", return_value=None
+ ):
+ reader = wsireader._handle_tiff_wsi(
+ input_path=tiff_path,
+ mpp=(0.5, 0.5),
+ power=20.0,
+ post_proc=None,
+ )
+ assert isinstance(reader, wsireader.TIFFWSIReader)
+
+
+def raise_openslide_error(*args: object, **kwargs: object) -> None:
+ """Simulate OpenSlideWSIReader raising an OpenSlideError."""
+ _ = args
+ _ = kwargs
+ msg = "mock error"
+ raise wsireader.openslide.OpenSlideError(msg)
+
+
+def test_handle_tiff_wsi_openslide_error(
+ monkeypatch: pytest.MonkeyPatch, track_tmp_path: Path
+) -> None:
+ """Test _handle_tiff_wsi when OpenSlideWSIReader raises."""
+ # Create a valid TIFF image
+ tiff_path = track_tmp_path / "test.tiff"
+ Image.new("RGB", (10, 10), color="white").save(tiff_path)
+
+ # Patch detect_format to return a non-None value
+ monkeypatch.setattr(wsireader.openslide.OpenSlide, "detect_format", lambda _: "SVS")
+
+ # Patch OpenSlideWSIReader to raise OpenSlideError
+ monkeypatch.setattr(wsireader, "OpenSlideWSIReader", raise_openslide_error)
+
+ # Patch is_tiled_tiff to return True so fallback to TIFFWSIReader is triggered
+ monkeypatch.setattr(wsireader, "is_tiled_tiff", lambda _: True)
+
+ # Patch TIFFWSIReader.__init__ to bypass internal checks
+ with patch(
+ "tiatoolbox.wsicore.wsireader.TIFFWSIReader.__init__", return_value=None
+ ):
+ result = wsireader._handle_tiff_wsi(
+ input_path=tiff_path,
+ mpp=(0.5, 0.5),
+ power=20.0,
+ post_proc=None,
+ )
+ assert isinstance(result, wsireader.TIFFWSIReader)
+
+
+def test_handle_tiff_wsi_openslide_success(
+ monkeypatch: pytest.MonkeyPatch, track_tmp_path: Path
+) -> None:
+ """Test _handle_tiff_wsi returns OpenSlideWSIReader when detect_format is valid."""
+ # Create a valid TIFF image
+ tiff_path = track_tmp_path / "test.tiff"
+ Image.new("RGB", (10, 10), color="white").save(tiff_path)
+
+ # Patch detect_format to return a valid format
+ monkeypatch.setattr(wsireader.openslide.OpenSlide, "detect_format", lambda _: "SVS")
+
+ # Patch OpenSlideWSIReader.__init__ to bypass actual init logic
+ with patch.object(wsireader.OpenSlideWSIReader, "__init__", return_value=None):
+ result = wsireader._handle_tiff_wsi(
+ input_path=tiff_path,
+ mpp=(0.5, 0.5),
+ power=20.0,
+ post_proc="auto",
+ )
+ assert isinstance(result, wsireader.OpenSlideWSIReader)
diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py
index 549539fae..6a487be7d 100644
--- a/tests/test_wsireader.py
+++ b/tests/test_wsireader.py
@@ -8,8 +8,10 @@
import logging
import re
import shutil
+from collections.abc import Callable
from copy import deepcopy
from pathlib import Path
+from types import SimpleNamespace
from typing import TYPE_CHECKING
from unittest.mock import patch
@@ -29,7 +31,7 @@
from tiatoolbox import cli, utils
from tiatoolbox.annotation import SQLiteStore
-from tiatoolbox.utils import imread, tiff_to_fsspec
+from tiatoolbox.utils import imread, postproc_defs, tiff_to_fsspec
from tiatoolbox.utils.exceptions import FileNotSupportedError
from tiatoolbox.utils.magic import is_sqlite3
from tiatoolbox.utils.transforms import imresize, locsize2bounds
@@ -1573,6 +1575,7 @@ def test_wsireader_open(
sample_ome_tiff: Path,
sample_ventana_tif: Path,
sample_regular_tif: Path,
+ sample_qptiff: Path,
source_image: Path,
track_tmp_path: pytest.TempPathFactory,
) -> None:
@@ -1596,7 +1599,7 @@ def test_wsireader_open(
assert isinstance(wsi, wsireader.TIFFWSIReader)
wsi = WSIReader.open(sample_ventana_tif)
- assert isinstance(wsi, wsireader.OpenSlideWSIReader)
+ assert isinstance(wsi, (wsireader.OpenSlideWSIReader, wsireader.TIFFWSIReader))
wsi = WSIReader.open(sample_regular_tif)
assert isinstance(wsi, wsireader.VirtualWSIReader)
@@ -1604,6 +1607,9 @@ def test_wsireader_open(
wsi = WSIReader.open(Path(source_image))
assert isinstance(wsi, wsireader.VirtualWSIReader)
+ wsi = WSIReader.open(sample_qptiff)
+ assert isinstance(wsi, wsireader.TIFFWSIReader)
+
img = utils.misc.imread(str(Path(source_image)))
wsi = WSIReader.open(input_img=img)
assert isinstance(wsi, wsireader.VirtualWSIReader)
@@ -1988,7 +1994,7 @@ def test_tiffwsireader_invalid_ome_metadata(
sample_ome_tiff_level_0: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
- """Test exception raised for invalid OME-XML metadata instrument."""
+ """Test fallback behaviour for invalid OME-XML metadata instrument."""
wsi = wsireader.TIFFWSIReader(sample_ome_tiff_level_0)
monkeypatch.setattr(
wsi.tiff.pages[0],
@@ -1998,8 +2004,10 @@ def test_tiffwsireader_invalid_ome_metadata(
"",
),
)
- with pytest.raises(KeyError, match="No matching Instrument"):
- _ = wsi._info()
+ monkeypatch.setattr(wsi, "_m_info", None)
+
+ info = wsi.info
+ assert info.objective_power is None or isinstance(info.objective_power, float)
def test_tiffwsireader_ome_metadata_missing_one_mppy(
@@ -2098,7 +2106,7 @@ def test_tiled_tiff_openslide(remote_sample: Callable) -> None:
sample_path = remote_sample("tiled-tiff-1-small-jpeg")
# Test with top-level import
wsi = WSIReader.open(sample_path)
- assert isinstance(wsi, wsireader.OpenSlideWSIReader)
+ assert isinstance(wsi, (wsireader.OpenSlideWSIReader, wsireader.TIFFWSIReader))
def test_tiled_tiff_tifffile(remote_sample: Callable) -> None:
@@ -2689,6 +2697,11 @@ def test_jp2_no_header(track_tmp_path: Path, monkeypatch: pytest.MonkeyPatch) ->
"sample_key": "jp2-omnyx-small",
"kwargs": {},
},
+ {
+ "reader_class": TIFFWSIReader,
+ "sample_key": "qptiff_sample",
+ "kwargs": {},
+ },
],
ids=[
"AnnotationReaderOverlaid",
@@ -2699,6 +2712,7 @@ def test_jp2_no_header(track_tmp_path: Path, monkeypatch: pytest.MonkeyPatch) ->
"NGFFWSIReader",
"OpenSlideWSIReader (Small SVS)",
"OmnyxJP2WSIReader",
+ "TIFFReader_Multichannel",
],
)
def wsi(request: requests.request, remote_sample: Callable) -> WSIReader:
@@ -2728,7 +2742,7 @@ def wsi(request: requests.request, remote_sample: Callable) -> WSIReader:
def test_base_open(wsi: WSIReader) -> None:
"""Checks that WSIReader.open detects the type correctly."""
new_wsi = WSIReader.open(wsi.input_path)
- assert type(new_wsi) is type(wsi)
+ assert isinstance(new_wsi, (type(wsi), TIFFWSIReader))
def test_wsimeta_attrs(wsi: WSIReader) -> None:
@@ -2875,6 +2889,105 @@ def test_read_rect_coord_space_consistency(wsi: WSIReader) -> None:
assert ssim > 0.8
+def _make_mock_post_proc(called: dict[str, bool]) -> Callable[[np.ndarray], np.ndarray]:
+ """Create a mock post-processing function that modifies the image and sets flag."""
+
+ def mock_post_proc(image: np.ndarray) -> np.ndarray:
+ called["flag"] = True
+ image = image.copy()
+ channels = image.shape[-1]
+ image[0, 0] = [42] * channels
+ image[-1, -1] = [0] * (channels - 1) + [42]
+ return image
+
+ return mock_post_proc
+
+
+def _should_patch_background_composite(wsi: WSIReader) -> bool:
+ """Determine whether background_composite should be patched for the given reader."""
+ if isinstance(wsi, AnnotationStoreReader):
+ return True
+ if isinstance(wsi, VirtualWSIReader):
+ return wsi.mode == "rgb"
+ return isinstance(
+ wsi, (OpenSlideWSIReader, JP2WSIReader, DICOMWSIReader, NGFFWSIReader)
+ )
+
+
+def _apply_post_proc(
+ wsi: WSIReader, mock_post_proc: Callable[[np.ndarray], np.ndarray]
+) -> WSIReader:
+ """Apply post_proc to the appropriate reader or delegate."""
+ if isinstance(wsi, TIFFWSIReader):
+ return TIFFWSIReader(wsi.input_path, post_proc=mock_post_proc)
+ wsi.post_proc = mock_post_proc
+ if isinstance(wsi, AnnotationStoreReader) and wsi.base_wsi is not None:
+ wsi.base_wsi.post_proc = mock_post_proc
+ return wsi
+
+
+def _inject_post_proc_recursive(
+ wsi: object, post_proc: Callable[[np.ndarray], np.ndarray]
+) -> None:
+ """Recursively inject post_proc into the deepest base_wsi that supports it."""
+ current = wsi
+ while hasattr(current, "base_wsi") and current.base_wsi is not None:
+ current = current.base_wsi
+ if hasattr(current, "post_proc"):
+ current.post_proc = post_proc
+
+
+def test_post_proc_logic_across_readers(wsi: WSIReader) -> None:
+ """Test that post_proc is applied correctly across all reader classes."""
+ called: dict[str, bool] = {"flag": False}
+ mock_post_proc = _make_mock_post_proc(called)
+
+ skip_check = isinstance(wsi, AnnotationStoreReader) # and wsi.base_wsi is None
+
+ if skip_check is False:
+ # Recursively inject post_proc into the actual reader
+ _inject_post_proc_recursive(wsi, mock_post_proc)
+
+ patch_utils = _should_patch_background_composite(wsi)
+
+ if patch_utils:
+ with patch(
+ "tiatoolbox.utils.transforms.background_composite",
+ lambda image, **_: image,
+ ):
+ rect = wsi.read_rect(location=(0, 0), size=(50, 50))
+ region = wsi.read_bounds(bounds=(0, 0, 50, 50))
+ else:
+ rect = wsi.read_rect(location=(0, 0), size=(50, 50))
+ region = wsi.read_bounds(bounds=(0, 0, 50, 50))
+
+ if skip_check:
+ assert isinstance(rect, np.ndarray)
+ assert isinstance(region, np.ndarray)
+ assert not called["flag"]
+ return
+
+ if isinstance(wsi, NGFFWSIReader):
+ assert isinstance(rect, np.ndarray)
+ assert isinstance(region, np.ndarray)
+ return
+
+ if isinstance(wsi, OpenSlideWSIReader):
+ vendor = getattr(wsi.info, "vendor", "").lower()
+ if "ventana" in vendor or "tif" in str(wsi.input_path).lower():
+ assert isinstance(rect, np.ndarray)
+ assert isinstance(region, np.ndarray)
+ return
+
+ assert called["flag"]
+ assert isinstance(rect, np.ndarray)
+ assert isinstance(region, np.ndarray)
+ assert rect[0, 0][-1] == 42
+ assert rect[-1, -1][-1] == 42
+ assert region[0, 0][-1] == 42
+ assert region[-1, -1][-1] == 42
+
+
def test_file_path_does_not_exist() -> None:
"""Test that FileNotFoundError is raised when file does not exist."""
for reader_class in [
@@ -2928,6 +3041,63 @@ def test_read_multi_channel(source_image: Path) -> None:
assert np.abs(np.mean(region.astype(int) - target.astype(int))) < 0.2
+def test_visualise_multi_channel(sample_qptiff: Path) -> None:
+ """Test visualising a multi-channel qptiff multiplex image."""
+ wsi = wsireader.TIFFWSIReader(sample_qptiff, post_proc="auto")
+ wsi2 = wsireader.TIFFWSIReader(sample_qptiff, post_proc=None)
+
+ region = wsi.read_rect(location=(0, 0), size=(50, 100))
+ region2 = wsi2.read_rect(location=(0, 0), size=(50, 100))
+
+ assert region.shape == (100, 50, 3)
+ assert region2.shape == (100, 50, 5)
+ # Was 7 channels. Not sure if this is correct. Check this!
+
+
+def test_get_post_proc_variants() -> None:
+ """Test different branches of get_post_proc method."""
+ reader = wsireader.VirtualWSIReader(np.zeros((10, 10, 3)))
+
+ assert callable(reader.get_post_proc(lambda x: x))
+ assert reader.get_post_proc(None) is None
+ assert isinstance(reader.get_post_proc("auto"), postproc_defs.MultichannelToRGB)
+ assert isinstance(
+ reader.get_post_proc("MultichannelToRGB"), postproc_defs.MultichannelToRGB
+ )
+
+ with pytest.raises(ValueError, match="Invalid post-processing function"):
+ reader.get_post_proc("invalid_proc")
+
+
+def test_post_proc_applied() -> None:
+ """Test that post_proc is applied to image region."""
+ reader = wsireader.VirtualWSIReader(np.ones((100, 100, 3), dtype=np.uint8))
+ reader.post_proc = lambda x: x * 0
+ region = reader.read_rect((0, 0), (50, 50))
+ assert np.all(region == 0)
+
+ # Create a dummy image region
+ dummy_image = np.ones((10, 10, 3), dtype=np.uint8)
+
+ # Define a dummy post-processing function
+ def mock_post_proc(image: np.ndarray) -> np.ndarray:
+ image[0, 0] = [255, 0, 0] # Modify top-left pixel to red
+ return image
+
+ # Create a mock reader with post_proc
+ mock_reader = SimpleNamespace(post_proc=mock_post_proc)
+
+ # Create a delegate with the mock reader
+ delegate = wsireader.TIFFWSIReaderDelegate.__new__(wsireader.TIFFWSIReaderDelegate)
+ delegate.reader = mock_reader
+
+ # Simulate the logic that includes the yellow line
+ result = delegate.reader.post_proc(dummy_image.copy())
+
+ # Assert that post_proc was applied
+ assert (result[0, 0] == [255, 0, 0]).all()
+
+
def test_fsspec_json_wsi_reader_instantiation() -> None:
"""Test if FsspecJsonWSIReader is instantiated.
diff --git a/tiatoolbox/cli/visualize.py b/tiatoolbox/cli/visualize.py
index 30627dcf2..b1169138e 100644
--- a/tiatoolbox/cli/visualize.py
+++ b/tiatoolbox/cli/visualize.py
@@ -25,6 +25,7 @@ def run_app() -> None:
title="Tiatoolbox TileServer",
layers={},
)
+ app.json.sort_keys = False
CORS(app, send_wildcard=True)
app.run(host="127.0.0.1", threaded=True)
diff --git a/tiatoolbox/data/remote_samples.yaml b/tiatoolbox/data/remote_samples.yaml
index 1b7bf2bf1..dabfd279f 100644
--- a/tiatoolbox/data/remote_samples.yaml
+++ b/tiatoolbox/data/remote_samples.yaml
@@ -147,6 +147,10 @@ files:
url: [*testdata, "annotation/sample_wsi_patch_preds.db"]
nuclick-output:
url: [*modelroot, "predictions/nuclei_mask/nuclick-output.npy"]
+ qptiff_sample:
+ url: [*wsis, "multiplex_example.qptiff"]
+ qptiff_sample_small:
+ url: [ *wsis, "multiplex_example_small.qptiff" ]
reg_disp_mha_example:
url: [*testdata, "registration/sample_transf.mha"]
reg_affine_npy_example:
diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py
index 359d8c52a..212082b93 100644
--- a/tiatoolbox/models/dataset/classification.py
+++ b/tiatoolbox/models/dataset/classification.py
@@ -239,12 +239,71 @@ def __init__( # skipcq: PY-R1000
super().__init__()
# Is there a generic func for path test in toolbox?
- if not Path.is_file(Path(img_path)):
+ patch_input_shape, stride_shape = self._validate_inputs(
+ img_path, mode, patch_input_shape, stride_shape
+ )
+
+ self.preproc_func = preproc_func
+ self.img_path = Path(img_path)
+ self.mode = mode
+ self.reader = None
+ reader = self._get_reader(self.img_path)
+
+ if mode != "wsi":
+ units = "mpp"
+ resolution = 1.0
+
+ # may decouple into misc ?
+ # the scaling factor will scale base level to requested read resolution/units
+ wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
+
+ # use all patches, as long as it overlaps source image
+ self.inputs = PatchExtractor.get_coordinates(
+ image_shape=wsi_shape,
+ patch_input_shape=patch_input_shape[::-1],
+ stride_shape=stride_shape[::-1],
+ input_within_bound=False,
+ )
+
+ mask_reader = self._setup_mask_reader(
+ mask_path, reader, auto_get_mask=auto_get_mask
+ )
+ if mask_reader is not None:
+ self._filter_patches(mask_reader, wsi_shape, min_mask_ratio)
+
+ self.patch_input_shape = patch_input_shape
+ self.resolution = resolution
+ self.units = units
+
+ # Perform check on the input
+ self._check_input_integrity(mode="wsi")
+
+ @staticmethod
+ def _validate_inputs(
+ img_path: str | Path,
+ mode: str,
+ patch_input_shape: np.ndarray,
+ stride_shape: np.ndarray,
+ ) -> tuple[np.ndarray, np.ndarray]:
+ """Validate input parameters for WSIPatchDataset.
+
+ Args:
+ img_path (str | Path): Path to the input image file.
+ mode (str): Mode of operation, either 'wsi' or 'tile'.
+ patch_input_shape (np.ndarray): Shape of the patch to extract.
+ stride_shape (np.ndarray): Stride between patches.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: Validated patch and stride shapes.
+ """
+ if not Path(img_path).is_file():
msg = "`img_path` must be a valid file path."
raise ValueError(msg)
+
if mode not in ["wsi", "tile"]:
msg = f"`{mode}` is not supported."
raise ValueError(msg)
+
patch_input_shape = np.array(patch_input_shape)
stride_shape = np.array(stride_shape)
@@ -255,6 +314,7 @@ def __init__( # skipcq: PY-R1000
):
msg = f"Invalid `patch_input_shape` value {patch_input_shape}."
raise ValueError(msg)
+
if (
not np.issubdtype(stride_shape.dtype, np.integer)
or np.size(stride_shape) > 2 # noqa: PLR2004
@@ -263,27 +323,25 @@ def __init__( # skipcq: PY-R1000
msg = f"Invalid `stride_shape` value {stride_shape}."
raise ValueError(msg)
- self.preproc_func = preproc_func
- self.img_path = Path(img_path)
- self.mode = mode
- self.reader = None
- reader = self._get_reader(self.img_path)
- if mode != "wsi":
- units = "mpp"
- resolution = 1.0
+ return patch_input_shape, stride_shape
- # may decouple into misc ?
- # the scaling factor will scale base level to requested read resolution/units
- wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
+ def _setup_mask_reader(
+ self,
+ mask_path: str | Path | None,
+ reader: WSIReader,
+ *,
+ auto_get_mask: bool,
+ ) -> VirtualWSIReader | None:
+ """Create a mask reader from a provided mask path or generate one automatically.
- # use all patches, as long as it overlaps source image
- self.inputs = PatchExtractor.get_coordinates(
- image_shape=wsi_shape,
- patch_input_shape=patch_input_shape[::-1],
- stride_shape=stride_shape[::-1],
- input_within_bound=False,
- )
+ Args:
+ mask_path (str | Path | None): Path to the mask image file.
+ reader (WSIReader): Reader for the input image.
+ auto_get_mask (bool): Whether to automatically generate a tissue mask.
+ Returns:
+ VirtualWSIReader | None: A reader for the mask or None if not applicable.
+ """
mask_reader = None
if mask_path is not None:
mask_path = Path(mask_path)
@@ -293,36 +351,49 @@ def __init__( # skipcq: PY-R1000
mask = imread(mask_path) # assume to be gray
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
mask = np.array(mask > 0, dtype=np.uint8)
-
mask_reader = VirtualWSIReader(mask)
mask_reader.info = reader.info
- elif auto_get_mask and mode == "wsi" and mask_path is None:
+
+ elif auto_get_mask and self.mode == "wsi":
# if no mask provided and `wsi` mode, generate basic tissue
# mask on the fly
- mask_reader = reader.tissue_mask(resolution=1.25, units="power")
+ try:
+ mask_reader = reader.tissue_mask(resolution=1.25, units="power")
+ except ValueError:
+ # if power is None, try with mpp
+ mask_reader = reader.tissue_mask(resolution=6.0, units="mpp")
# ? will this mess up ?
mask_reader.info = reader.info
- if mask_reader is not None:
- selected = PatchExtractor.filter_coordinates(
- mask_reader, # must be at the same resolution
- self.inputs, # must already be at requested resolution
- wsi_shape=wsi_shape,
- min_mask_ratio=min_mask_ratio,
- )
- self.inputs = self.inputs[selected]
+ return mask_reader
+
+ def _filter_patches(
+ self,
+ mask_reader: VirtualWSIReader,
+ wsi_shape: np.ndarray,
+ min_mask_ratio: float,
+ ) -> None:
+ """Filter patch coordinates based on mask coverage.
+
+ Args:
+ mask_reader (VirtualWSIReader): Reader for the mask image.
+ wsi_shape (np.ndarray): Shape of the WSI at the requested resolution.
+ min_mask_ratio (float): Minimum mask coverage required to keep a patch.
+ Raises:
+ ValueError: If no patches remain after filtering.
+ """
+ selected = PatchExtractor.filter_coordinates(
+ mask_reader, # must be at the same resolution
+ self.inputs, # must already be at requested resolution
+ wsi_shape=wsi_shape,
+ min_mask_ratio=min_mask_ratio,
+ )
+ self.inputs = self.inputs[selected]
if len(self.inputs) == 0:
msg = "No patch coordinates remain after filtering."
raise ValueError(msg)
- self.patch_input_shape = patch_input_shape
- self.resolution = resolution
- self.units = units
-
- # Perform check on the input
- self._check_input_integrity(mode="wsi")
-
def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader:
"""Get a reader for the image."""
if self.mode == "wsi":
diff --git a/tiatoolbox/utils/postproc_defs.py b/tiatoolbox/utils/postproc_defs.py
new file mode 100644
index 000000000..077fa4343
--- /dev/null
+++ b/tiatoolbox/utils/postproc_defs.py
@@ -0,0 +1,143 @@
+"""Module to provide postprocessing classes."""
+
+from __future__ import annotations
+
+import colorsys
+import warnings
+
+import numpy as np
+
+
+class MultichannelToRGB:
+ """Class to convert multi-channel images to RGB images."""
+
+ def __init__(
+ self: MultichannelToRGB,
+ color_dict: dict[str, tuple[float, float, float]] | None = None,
+ ) -> None:
+ """Initialize the MultichannelToRGB converter.
+
+ Args:
+ color_dict: Dict of channel names with RGB colors for each channel. If not
+ provided, a set of distinct colors will be auto-generated.
+
+ """
+ self.colors: np.ndarray | None = None
+ self.color_dict = color_dict
+ self.is_validated: bool = False
+ self.channels: list[int] | None = None
+ self.enhance: float = 1.0
+
+ def validate(self: MultichannelToRGB, n: int) -> None:
+ """Validate the input color_dict on first read from image.
+
+ Checks that n is either equal to the number of colors provided, or is
+ one less. In the latter case it is assumed that the last channel is background
+ autofluorescence and is not in the tiff and we will drop it from
+ the color_dict with a warning.
+
+ Args:
+ n (int): Number of channels
+
+ """
+ if self.colors is None:
+ msg = "Colors must be initialized before validation."
+ raise ValueError(msg)
+
+ n_colors = len(self.colors)
+ if n_colors == n:
+ self.is_validated = True
+ return
+
+ if self.channels is None:
+ self.channels = list(range(n_colors))
+
+ if n_colors - 1 == n:
+ self.colors = self.colors[:n]
+ self.channels = [c for c in self.channels if c < n]
+ self.is_validated = True
+ msg = """Number of channels in image is one less than number of channels in
+ dict. Assuming last channel is background autofluorescence and ignoring
+ it. If this is not the case please provide a manual color_dict."""
+ warnings.warn(
+ msg,
+ stacklevel=2,
+ )
+ return
+
+ msg = f"Number of colors: {n_colors} does not match channels in image: {n}."
+ raise ValueError(msg)
+
+ def generate_colors(self: MultichannelToRGB, n_channels: int) -> None:
+ """Generate a set of visually distinct colors.
+
+ Args:
+ n_channels (int): Number of channels/colors to generate
+
+ Returns:
+ np.ndarray: Array of RGB colors
+
+ """
+ self.color_dict = {
+ f"channel_{i}": colorsys.hsv_to_rgb(i / n_channels, 1, 1)
+ for i in range(n_channels)
+ }
+
+ def __call__(self: MultichannelToRGB, image: np.ndarray) -> np.ndarray:
+ """Convert a multi-channel image to an RGB image.
+
+ Args:
+ image (np.ndarray): Input image of shape (H, W, N)
+
+ Returns:
+ np.ndarray: RGB image of shape (H, W, 3)
+
+ """
+ n = image.shape[2]
+
+ if n < 5: # noqa: PLR2004
+ # assume already rgb(a) so just return image
+ return image
+
+ if self.colors is None:
+ self.generate_colors(n)
+
+ if not self.is_validated:
+ self.validate(n)
+
+ if self.channels is None:
+ self.channels = list(range(image.shape[2]))
+
+ if image.dtype == np.uint16:
+ image = (image / 256).astype(np.uint8)
+
+ if self.colors is None:
+ msg = "self.colors must be initialized before RGB conversion."
+ raise RuntimeError(msg)
+
+ # Convert to RGB image
+ rgb_image = (
+ np.einsum(
+ "hwn,nc->hwc",
+ image[:, :, self.channels],
+ self.colors[self.channels, :],
+ optimize=True,
+ )
+ * self.enhance
+ )
+
+ # Clip to ensure in valid range and return
+ return np.clip(rgb_image, 0, 255).astype(np.uint8)
+
+ def __setattr__(
+ self: MultichannelToRGB,
+ name: str,
+ value: dict[str, tuple[float, float, float]] | None,
+ ) -> None:
+ """Ensure that colors is updated if color_dict is updated."""
+ if name == "color_dict" and value is not None:
+ self.colors = np.array(list(value.values()), dtype=np.float32)
+ if self.channels is None:
+ self.channels = list(range(len(value)))
+
+ super().__setattr__(name, value)
diff --git a/tiatoolbox/visualization/bokeh_app/main.py b/tiatoolbox/visualization/bokeh_app/main.py
index 5df7acb04..4ba8043f9 100644
--- a/tiatoolbox/visualization/bokeh_app/main.py
+++ b/tiatoolbox/visualization/bokeh_app/main.py
@@ -46,6 +46,7 @@
Select,
Slider,
Spinner,
+ StringEditor,
TableColumn,
TabPanel,
Tabs,
@@ -140,6 +141,201 @@ def format_info(info: dict[str, Any]) -> str:
return info_str
+def get_channel_info() -> dict[str, tuple[int, int, int]]:
+ """Get the colors for the channels."""
+ resp = UI["s"].get(f"http://{host2}:5000/tileserver/channels")
+ try:
+ resp = json.loads(resp.text)
+ return resp.get("channels", {}), resp.get("active", [])
+ except json.JSONDecodeError as e:
+ logger.warning("Error decoding JSON: %s", e)
+ return {}, []
+
+
+def set_channel_info(
+ colors: dict[str, tuple[int, int, int]], active_channels: list
+) -> None:
+ """Set the colors for the channels."""
+ UI["s"].put(
+ f"http://{host2}:5000/tileserver/channels",
+ data={"channels": json.dumps(colors), "active": json.dumps(active_channels)},
+ )
+
+
+def create_channel_color_ui() -> Column:
+ """Create the multi-channel UI controls."""
+ channel_source = ColumnDataSource(
+ data={
+ "channels": [],
+ "dummy": [],
+ }
+ )
+ color_source = ColumnDataSource(
+ data={
+ "colors": [],
+ "dummy": [],
+ }
+ )
+
+ color_formatter = HTMLTemplateFormatter(
+ template="""<%= value %>
"""
+ )
+
+ channel_table = DataTable(
+ source=channel_source,
+ columns=[
+ TableColumn(
+ field="channels",
+ title="Channel",
+ editor=StringEditor(),
+ sortable=False,
+ width=200,
+ )
+ ],
+ editable=True,
+ width=200,
+ height=260,
+ selectable="checkbox",
+ autosize_mode="none",
+ fit_columns=True,
+ )
+ color_table = DataTable(
+ source=color_source,
+ columns=[
+ TableColumn(
+ field="colors",
+ title="Color",
+ formatter=color_formatter,
+ editor=StringEditor(),
+ sortable=False,
+ width=130,
+ )
+ ],
+ editable=True,
+ width=130,
+ height=260,
+ selectable=True,
+ autosize_mode="none",
+ index_position=None,
+ fit_columns=True,
+ )
+
+ color_picker = ColorPicker(title="Channel Color", width=100)
+
+ def update_selected_color(
+ attr: str, # noqa: ARG001 # skipcq: PYL-W0613
+ old: str, # noqa: ARG001 # skipcq: PYL-W0613
+ new: str,
+ ) -> None:
+ """Update the selected color in multichannel ui."""
+ selected = color_source.selected.indices
+ if selected:
+ color_source.patch({"colors": [(selected[0], new)]})
+
+ color_picker.on_change("color", update_selected_color)
+
+ apply_button = Button(
+ label="Apply Changes", button_type="success", margin=(20, 5, 5, 5)
+ )
+
+ def apply_changes() -> None:
+ """Apply the changes to the image."""
+ colors = dict(
+ zip(
+ channel_source.data["channels"],
+ color_source.data["colors"],
+ strict=False,
+ )
+ )
+ active_channels = channel_source.selected.indices
+
+ set_channel_info({ch: hex2rgb(colors[ch]) for ch in colors}, active_channels)
+ change_tiles("slide")
+
+ apply_button.on_click(apply_changes)
+
+ def update_color_picker(
+ attr: str, # noqa: ARG001 # skipcq: PYL-W0613
+ old: str, # noqa: ARG001 # skipcq: PYL-W0613
+ new: str,
+ ) -> None:
+ """Update the color picker when a new channel is selected."""
+ if new:
+ selected_color = color_source.data["colors"][new[0]]
+ color_picker.color = selected_color
+ else:
+ color_picker.color = None
+
+ color_source.selected.on_change("indices", update_color_picker)
+
+ enhance_slider = Slider(
+ start=0.1,
+ end=10,
+ value=1,
+ step=0.1,
+ title="Enhance",
+ width=200,
+ )
+
+ def enhance_cb(
+ attr: str, # noqa: ARG001 # skipcq: PYL-W0613
+ old: str, # noqa: ARG001 # skipcq: PYL-W0613
+ new: str,
+ ) -> None:
+ """Enhance slider callback."""
+ UI["s"].put(
+ f"http://{host2}:5000/tileserver/enhance",
+ data={"val": json.dumps(new)},
+ )
+ UI["vstate"].update_state = 1
+ UI["vstate"].to_update.update(["slide"])
+
+ enhance_slider.on_change("value", enhance_cb)
+
+ instructions = Div(
+ text="""
+ Instructions:
+
+ - Double-click on the 'Active' column to toggle channel visibility
+ - Click on a row to select it for color editing
+ - Use 'Select All' or 'Deselect All' for quick selection
+ - Enable 'Solo Mode' and select a channel to view it alone
+ - Use the color picker to change the color of the selected channel
+ - Click 'Apply Changes' to update the image
+
+ """
+ )
+
+ return column(
+ instructions,
+ column(
+ row(channel_table, color_table),
+ row(color_picker, apply_button),
+ enhance_slider,
+ ),
+ )
+
+
+def populate_table() -> None:
+ """Populate the channel color table."""
+ # Access the ColumnDataSource from the UI dictionary
+ tables = UI["channel_select"].children[1].children[0].children
+ colors, active_channels = get_channel_info()
+
+ if colors is not None:
+ if active_channels:
+ tables[0].source.selected.indices = active_channels
+ tables[0].source.data = {
+ "channels": list(colors.keys()),
+ "dummy": list(colors.keys()),
+ }
+ tables[1].source.data = {
+ "colors": [rgb2hex(color) for color in colors.values()],
+ "dummy": list(colors.keys()),
+ }
+
+
def get_view_bounds(
dims: tuple[float, float],
plot_size: tuple[float, float],
@@ -734,12 +930,13 @@ def populate_slide_list(slide_folder: Path, search_txt: str | None = None) -> No
len_slidepath = len(slide_folder.parts)
for ext in [
"*.svs",
- "*ndpi",
+ "*.ndpi",
"*.tiff",
"*.mrxs",
"*.jpg",
"*.png",
"*.tif",
+ "*.qptiff",
"*.dcm",
]:
file_list.extend(list(Path(slide_folder).glob(str(Path("*") / ext))))
@@ -759,14 +956,22 @@ def populate_slide_list(slide_folder: Path, search_txt: str | None = None) -> No
UI["slide_select"].options = file_list
-def filter_input_cb(attr: str, old: str, new: str) -> None: # noqa: ARG001
+def filter_input_cb(
+ attr: str, # noqa: ARG001 # skipcq: PYL-W0613
+ old: str, # noqa: ARG001 # skipcq: PYL-W0613
+ new: str, # noqa: ARG001 # skipcq: PYL-W0613
+) -> None:
"""Change predicate to be used to filter annotations."""
build_predicate()
UI["vstate"].update_state = 1
UI["vstate"].to_update.update(["overlay"])
-def cprop_input_cb(attr: str, old: str, new: list[str]) -> None: # noqa: ARG001
+def cprop_input_cb(
+ attr: str, # noqa: ARG001 # skipcq: PYL-W0613
+ old: str, # noqa: ARG001 # skipcq: PYL-W0613
+ new: list[str],
+) -> None:
"""Change property to color by."""
if len(new) == 0:
return
@@ -884,6 +1089,7 @@ def slide_select_cb(attr: str, old: str, new: str) -> None: # noqa: ARG001
fname = make_safe_name(str(slide_path))
UI["s"].put(f"http://{host2}:5000/tileserver/slide", data={"slide_path": fname})
change_tiles("slide")
+ populate_table()
# Load the overlay and graph automatically if set in config
if doc_config["auto_load"]:
@@ -1663,12 +1869,14 @@ def gather_ui_elements( # noqa: PLR0915
"pt_size_spinner",
"edge_size_spinner",
"res_switch",
+ "channel_select",
],
[
opt_buttons,
pt_size_spinner,
edge_size_spinner,
res_switch,
+ create_channel_color_ui(),
],
strict=False,
),
@@ -2109,12 +2317,13 @@ def setup_doc(self: DocConfig, base_doc: Document) -> tuple[Row, Tabs]:
slide_list = []
for ext in [
"*.svs",
- "*ndpi",
+ "*.ndpi",
"*.tiff",
"*.tif",
"*.mrxs",
"*.png",
"*.jpg",
+ "*.qptiff",
"*.dcm",
]:
slide_list.extend(list(doc_config["slide_folder"].glob(ext)))
diff --git a/tiatoolbox/visualization/tileserver.py b/tiatoolbox/visualization/tileserver.py
index 236868f17..1ff6a0c6d 100644
--- a/tiatoolbox/visualization/tileserver.py
+++ b/tiatoolbox/visualization/tileserver.py
@@ -24,6 +24,7 @@
from tiatoolbox.annotation import AnnotationStore, SQLiteStore
from tiatoolbox.tools.pyramid import AnnotationTileGenerator, ZoomifyGenerator
from tiatoolbox.utils.misc import add_from_dat, store_from_dat
+from tiatoolbox.utils.postproc_defs import MultichannelToRGB
from tiatoolbox.utils.visualization import AnnotationRenderer, colourise_image
from tiatoolbox.wsicore.wsireader import (
OpenSlideWSIReader,
@@ -170,6 +171,9 @@ def __init__( # noqa: PLR0915
)
self.route("/tileserver/tap_query//")(self.tap_query)
self.route("/tileserver/prop_range", methods=["PUT"])(self.prop_range)
+ self.route("/tileserver/channels", methods=["GET"])(self.get_channels)
+ self.route("/tileserver/channels", methods=["PUT"])(self.set_channels)
+ self.route("/tileserver/enhance", methods=["PUT"])(self.set_enhance)
self.route("/tileserver/shutdown", methods=["POST"])(self.shutdown)
self.route("/tileserver/sessions", methods=["GET"])(self.sessions)
self.route("/tileserver/healthcheck", methods=["GET"])(self.healthcheck)
@@ -814,6 +818,41 @@ def prop_range(self: TileServer) -> str:
self.renderers[session_id].score_fn = lambda x: (x - minv) / (maxv - minv)
return "done"
+ def get_channels(self: TileServer) -> Response:
+ """Get the channels of the slide."""
+ session_id = self._get_session_id()
+ if isinstance(self.layers[session_id]["slide"].post_proc, MultichannelToRGB):
+ if not self.layers[session_id]["slide"].post_proc.is_validated:
+ _ = self.layers[session_id]["slide"].slide_thumbnail(
+ resolution=8.0, units="mpp"
+ )
+ return jsonify(
+ {
+ "channels": self.layers[session_id]["slide"].post_proc.color_dict,
+ "active": self.layers[session_id]["slide"].post_proc.channels,
+ },
+ )
+ return jsonify({"channels": {}, "active": []})
+
+ def set_channels(self: TileServer) -> str:
+ """Set the channels of the slide."""
+ session_id = self._get_session_id()
+ if isinstance(self.layers[session_id]["slide"].post_proc, MultichannelToRGB):
+ channels = json.loads(request.form["channels"])
+ active = json.loads(request.form["active"])
+ self.layers[session_id]["slide"].post_proc.color_dict = channels
+ self.layers[session_id]["slide"].post_proc.channels = active
+ self.layers[session_id]["slide"].post_proc.is_validated = False
+ return "done"
+
+ def set_enhance(self: TileServer) -> str:
+ """Set the enhance factor of the slide."""
+ session_id = self._get_session_id()
+ enhance = json.loads(request.form["val"])
+ if isinstance(self.layers[session_id]["slide"].post_proc, MultichannelToRGB):
+ self.layers[session_id]["slide"].post_proc.enhance = enhance
+ return "done"
+
def sessions(self: TileServer) -> Response:
"""Retrieve a mapping of session keys to their corresponding slide file paths.
diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py
index 0791bc4fa..c2f212df6 100644
--- a/tiatoolbox/wsicore/wsireader.py
+++ b/tiatoolbox/wsicore/wsireader.py
@@ -8,6 +8,7 @@
import math
import os
import re
+from collections import defaultdict
from datetime import datetime
from numbers import Number
from pathlib import Path
@@ -15,6 +16,7 @@
import cv2
import fsspec
+import matplotlib.colors as mcolors
import numpy as np
import openslide
import pandas as pd
@@ -31,6 +33,7 @@
from tiatoolbox import logger, utils
from tiatoolbox.annotation import AnnotationStore, SQLiteStore
+from tiatoolbox.utils import postproc_defs
from tiatoolbox.utils.env_detection import pixman_warning
from tiatoolbox.utils.exceptions import FileNotSupportedError
from tiatoolbox.utils.magic import is_sqlite3
@@ -272,7 +275,10 @@ def np_virtual_wsi(
def _handle_tiff_wsi(
- input_path: Path, mpp: tuple[Number, Number] | None, power: Number | None
+ input_path: Path,
+ mpp: tuple[Number, Number] | None,
+ power: Number | None,
+ post_proc: str | callable | None,
) -> TIFFWSIReader | OpenSlideWSIReader | None:
"""Handle TIFF WSI cases.
@@ -285,6 +291,8 @@ def _handle_tiff_wsi(
power (:obj:`float` or :obj:`None`, optional):
The objective power of the WSI. If not provided, the power
is approximated from the MPP.
+ post_proc (str | callable | None):
+ Post-processing function to apply to the image.
Returns:
OpenSlideWSIReader | TIFFWSIReader | None:
@@ -294,11 +302,13 @@ def _handle_tiff_wsi(
"""
if openslide.OpenSlide.detect_format(input_path) is not None:
try:
- return OpenSlideWSIReader(input_path, mpp=mpp, power=power)
+ return OpenSlideWSIReader(
+ input_path, mpp=mpp, power=power, post_proc=post_proc
+ )
except openslide.OpenSlideError:
pass
if is_tiled_tiff(input_path):
- return TIFFWSIReader(input_path, mpp=mpp, power=power)
+ return TIFFWSIReader(input_path, mpp=mpp, power=power, post_proc=post_proc)
return None
@@ -322,6 +332,10 @@ class WSIReader:
power (:obj:`float` or :obj:`None`, optional):
The objective power of the WSI. If not provided, the power
is approximated from the MPP.
+ post_proc (str | callable | None):
+ Post-processing function to apply to the image. If None,
+ no post-processing is applied. If 'auto', the post-processing
+ function is automatically selected based on the reader type.
"""
@@ -330,6 +344,7 @@ def open( # noqa: PLR0911
input_img: str | Path | np.ndarray | WSIReader,
mpp: tuple[Number, Number] | None = None,
power: Number | None = None,
+ post_proc: str | callable | None = "auto",
**kwargs: dict,
) -> WSIReader:
"""Return an appropriate :class:`.WSIReader` object.
@@ -348,6 +363,10 @@ def open( # noqa: PLR0911
(x, y) tuple of the MPP in the units of the input image.
power (float):
Objective power of the input image.
+ post_proc (str | callable | None):
+ Post-processing function to apply to the image. If None,
+ no post-processing is applied. If 'auto', the post-processing
+ function is automatically selected based on the reader type.
kwargs (dict):
Key-word arguments.
@@ -360,14 +379,12 @@ def open( # noqa: PLR0911
>>> wsi = WSIReader.open(input_img="./sample.svs")
"""
- # Validate inputs
- if not isinstance(input_img, (WSIReader, np.ndarray, str, Path)):
- msg = "Invalid input: Must be a WSIRead, numpy array, string or Path"
- raise TypeError(
- msg,
- )
+ WSIReader._validate_input(input_img)
+
if isinstance(input_img, np.ndarray):
- return VirtualWSIReader(input_img, mpp=mpp, power=power)
+ return VirtualWSIReader(
+ input_img, mpp=mpp, power=power, post_proc=post_proc
+ )
if isinstance(input_img, WSIReader):
return input_img
@@ -377,45 +394,29 @@ def open( # noqa: PLR0911
WSIReader.verify_supported_wsi(input_path)
# Handle special cases first (DICOM, Zarr/NGFF, OME-TIFF)
- if is_dicom(input_path):
- return DICOMWSIReader(input_path, mpp=mpp, power=power)
-
- _, _, suffixes = utils.misc.split_path_name_ext(input_path)
- last_suffix = suffixes[-1]
-
- if FsspecJsonWSIReader.is_valid_zarr_fsspec(input_img):
- return FsspecJsonWSIReader(input_img, mpp=mpp, power=power)
-
- if last_suffix == ".db":
- return AnnotationStoreReader(input_path, **kwargs)
-
- if last_suffix in (".zarr",):
- if not is_ngff(input_path):
- msg = f"File {input_path} does not appear to be a v0.4 NGFF zarr."
- raise FileNotSupportedError(
- msg,
- )
- return NGFFWSIReader(input_path, mpp=mpp, power=power)
-
- if suffixes[-2:] in ([".ome", ".tiff"],) or suffixes[-2:] in (
- [".ome", ".tif"],
- ):
- return TIFFWSIReader(input_path, mpp=mpp, power=power)
+ special_reader = WSIReader._handle_special_cases(
+ input_path, input_img, mpp, power, post_proc, **kwargs
+ )
+ if special_reader is not None:
+ return special_reader
- if last_suffix in (".tif", ".tiff"):
- tiff_wsi = _handle_tiff_wsi(input_path, mpp=mpp, power=power)
- if tiff_wsi is not None:
- return tiff_wsi
+ # Try openslide last
+ return OpenSlideWSIReader(input_path, mpp=mpp, power=power, post_proc=post_proc)
- virtual_wsi = _handle_virtual_wsi(
- last_suffix=last_suffix, input_path=input_path, mpp=mpp, power=power
- )
+ @staticmethod
+ def _validate_input(input_img: str | Path | np.ndarray) -> None:
+ """Validate the input image type.
- if virtual_wsi is not None:
- return virtual_wsi
+ Args:
+ input_img (str | Path | np.ndarray): The input image, which
+ must be a path, string, numpy array, or WSIReader.
- # Try openslide last
- return OpenSlideWSIReader(input_path, mpp=mpp, power=power)
+ Raises:
+ TypeError: If the input is not one of the accepted types.
+ """
+ if not isinstance(input_img, (WSIReader, np.ndarray, str, Path)):
+ msg = "Invalid input: Must be a WSIReader, numpy array, string or Path"
+ raise TypeError(msg)
@staticmethod
def verify_supported_wsi(input_path: Path) -> None:
@@ -448,6 +449,7 @@ def verify_supported_wsi(input_path: Path) -> None:
".jpeg",
".zarr",
".db",
+ ".qptiff",
".json",
]:
msg = f"File {input_path} is not a supported file format."
@@ -455,11 +457,153 @@ def verify_supported_wsi(input_path: Path) -> None:
msg,
)
+ @staticmethod
+ def _handle_special_cases(
+ input_path: Path,
+ input_img: str | Path | np.ndarray,
+ mpp: tuple[Number, Number] | None = None,
+ power: Number | None = None,
+ post_proc: str | callable | None = "auto",
+ **kwargs: dict,
+ ) -> WSIReader | None:
+ """Handle special cases for selecting the appropriate WSIReader.
+
+ Args:
+ input_path (Path): Path to the input image file.
+ input_img (str | Path | np.ndarray): The input image or path.
+ mpp (tuple[Number, Number] | None, optional): Microns per pixel resolution.
+ power (Number | None, optional): Objective power.
+ post_proc (str | callable | None, optional): Post-processing method
+ or identifier.
+ **kwargs (dict): Additional keyword arguments for specific reader types.
+
+ Returns:
+ WSIReader | None: An appropriate WSIReader instance if a match is found,
+ otherwise None.
+
+ Raises:
+ FileNotSupportedError: If the file format is not supported for NGFF Zarr.
+
+ """
+ _, _, suffixes = utils.misc.split_path_name_ext(input_path)
+ last_suffix = suffixes[-1]
+
+ reader = (
+ WSIReader.try_dicom(input_path, mpp, power, post_proc)
+ or WSIReader.try_fsspec(input_img, mpp, power)
+ or WSIReader.try_annotation_store(
+ input_path, last_suffix, post_proc, kwargs
+ )
+ or WSIReader.try_ngff(input_path, last_suffix, mpp, power)
+ or WSIReader.try_ome_tiff(
+ input_path, suffixes, last_suffix, mpp, power, post_proc
+ )
+ or WSIReader.try_tiff(input_path, last_suffix, mpp, power, post_proc)
+ )
+
+ if reader is None:
+ reader = _handle_virtual_wsi(last_suffix, input_path, mpp, power)
+
+ return reader
+
+ @staticmethod
+ def try_dicom(
+ input_path: Path,
+ mpp: tuple[Number, Number] | None,
+ power: Number | None,
+ post_proc: str | callable | None,
+ ) -> DICOMWSIReader | None:
+ """Try to create a DICOMWSIReader if the input is a DICOM file."""
+ if is_dicom(input_path):
+ return DICOMWSIReader(input_path, mpp=mpp, power=power, post_proc=post_proc)
+ return None
+
+ @staticmethod
+ def try_fsspec(
+ input_img: str | Path | np.ndarray,
+ mpp: tuple[Number, Number] | None,
+ power: Number | None,
+ ) -> FsspecJsonWSIReader | None:
+ """Try to create a FsspecJsonWSIReader if the input is a valid Zarr fsspec."""
+ if FsspecJsonWSIReader.is_valid_zarr_fsspec(input_img):
+ return FsspecJsonWSIReader(input_img, mpp=mpp, power=power)
+ return None
+
+ @staticmethod
+ def try_annotation_store(
+ input_path: Path,
+ last_suffix: str,
+ post_proc: str | callable | None,
+ kwargs: dict,
+ ) -> AnnotationStoreReader | None:
+ """Try to create an AnnotationStoreReader if the file is a .db."""
+ if last_suffix == ".db":
+ kwargs["post_proc"] = post_proc
+ return AnnotationStoreReader(input_path, **kwargs)
+ return None
+
+ @staticmethod
+ def try_ngff(
+ input_path: Path,
+ last_suffix: str,
+ mpp: tuple[Number, Number] | None,
+ power: Number | None,
+ ) -> NGFFWSIReader | None:
+ """Try to create an NGFFWSIReader if the file is a valid NGFF Zarr."""
+ if last_suffix == ".zarr":
+ if not is_ngff(input_path):
+ msg = f"File {input_path} does not appear to be a v0.4 NGFF zarr."
+ raise FileNotSupportedError(msg)
+ return NGFFWSIReader(input_path, mpp=mpp, power=power)
+ return None
+
+ @staticmethod
+ def try_ome_tiff(
+ input_path: Path,
+ suffixes: list[str],
+ last_suffix: str,
+ mpp: tuple[Number, Number] | None,
+ power: Number | None,
+ post_proc: str | callable | None,
+ ) -> TIFFWSIReader | None:
+ """Try to create a TIFFWSIReader for OME-TIFF or QPTIFF formats."""
+ if (
+ suffixes[-2:] in ([".ome", ".tiff"], [".ome", ".tif"])
+ or last_suffix == ".qptiff"
+ ):
+ return TIFFWSIReader(input_path, mpp=mpp, power=power, post_proc=post_proc)
+ return None
+
+ @staticmethod
+ def try_tiff(
+ input_path: Path,
+ last_suffix: str,
+ mpp: tuple[Number, Number] | None,
+ power: Number | None,
+ post_proc: str | callable | None,
+ ) -> TIFFWSIReader | None:
+ """Try to create a TIFFWSIReader.
+
+ Try to create a TIFFWSIReader for standard TIFF formats,
+ or fallback to virtual WSI.
+ """
+ if last_suffix in (".tif", ".tiff"):
+ try:
+ return TIFFWSIReader(
+ input_path, mpp=mpp, power=power, post_proc=post_proc
+ )
+ except ValueError as e:
+ if "Unsupported TIFF WSI format" in str(e):
+ return _handle_virtual_wsi(last_suffix, input_path, mpp, power)
+ raise
+ return None
+
def __init__(
self: WSIReader,
input_img: str | Path | np.ndarray | AnnotationStore,
mpp: tuple[Number, Number] | None = None,
power: Number | None = None,
+ post_proc: callable | None = None,
) -> None:
"""Initialize :class:`WSIReader`."""
if isinstance(input_img, (np.ndarray, AnnotationStore)):
@@ -484,6 +628,7 @@ def __init__(
msg = "`power` must be a number."
raise TypeError(msg)
self._manual_power = power
+ self.post_proc = self.get_post_proc(post_proc)
@property
def info(self: WSIReader) -> WSIMeta:
@@ -515,6 +660,35 @@ def info(self: WSIReader, meta: WSIMeta) -> None:
"""
self._m_info = meta
+ def get_post_proc(self: WSIReader, post_proc: str | callable | None) -> callable:
+ """Get the post-processing function.
+
+ Args:
+ post_proc (str | callable | None):
+ Post-processing function to apply to the image. If auto,
+ will use no post_proc unless reader is TIFF or Virtual Reader,
+ in which case it will use MultichannelToRGB.
+
+ Returns:
+ callable:
+ Post-processing function.
+
+ """
+ if callable(post_proc):
+ return post_proc
+ if post_proc is None:
+ return None
+ if post_proc == "auto":
+ # if its TIFFWSIReader or VirtualWSIReader, return fn to
+ # allow multichannel, else return None
+ if isinstance(self, (TIFFWSIReader, VirtualWSIReader)):
+ return postproc_defs.MultichannelToRGB()
+ return None
+ if isinstance(post_proc, str) and hasattr(postproc_defs, post_proc):
+ return getattr(postproc_defs, post_proc)()
+ msg = f"Invalid post-processing function: {post_proc}"
+ raise ValueError(msg)
+
def _info(self: WSIReader) -> WSIMeta:
"""WSI metadata internal getter used to update info property.
@@ -1744,9 +1918,10 @@ def __init__(
input_img: str | Path | np.ndarray,
mpp: tuple[Number, Number] | None = None,
power: Number | None = None,
+ post_proc: str | callable | None = "auto",
) -> None:
"""Initialize :class:`OpenSlideWSIReader`."""
- super().__init__(input_img=input_img, mpp=mpp, power=power)
+ super().__init__(input_img=input_img, mpp=mpp, power=power, post_proc=post_proc)
self.openslide_wsi = openslide.OpenSlide(filename=str(self.input_path))
def read_rect(
@@ -1989,6 +2164,8 @@ def read_rect(
interpolation=interpolation,
)
+ if self.post_proc is not None:
+ im_region = self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
def read_bounds(
@@ -2174,6 +2351,8 @@ class docstrings for more information.
interpolation=interpolation,
)
+ if self.post_proc is not None:
+ im_region = self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
@staticmethod
@@ -2282,9 +2461,10 @@ def __init__(
input_img: str | Path | np.ndarray,
mpp: tuple[Number, Number] | None = None,
power: Number | None = None,
+ post_proc: str | callable | None = "auto",
) -> None:
"""Initialize :class:`OmnyxJP2WSIReader`."""
- super().__init__(input_img=input_img, mpp=mpp, power=power)
+ super().__init__(input_img=input_img, mpp=mpp, power=power, post_proc=post_proc)
import glymur # noqa: PLC0415
glymur.set_option("lib.num_threads", os.cpu_count() or 1)
@@ -2528,6 +2708,8 @@ def read_rect(
interpolation=interpolation,
)
+ if self.post_proc is not None:
+ im_region = self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
def read_bounds(
@@ -2702,6 +2884,8 @@ class docstrings for more information.
interpolation=interpolation,
)
+ if self.post_proc is not None:
+ im_region = self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
@staticmethod
@@ -2899,6 +3083,8 @@ class VirtualWSIReader(WSIReader):
"bool" mode supports binary masks,
interpolation in this case will be "nearest" instead of "bicubic".
"feature" mode allows multichannel features.
+ post_proc (str, callable):
+ Post-processing function to apply to the output image.
"""
@@ -2909,12 +3095,14 @@ def __init__(
power: Number | None = None,
info: WSIMeta | None = None,
mode: str = "rgb",
+ post_proc: str | callable | None = "auto",
) -> None:
"""Initialize :class:`VirtualWSIReader`."""
super().__init__(
input_img=input_img,
mpp=mpp,
power=power,
+ post_proc=post_proc,
)
if mode.lower() not in ["rgb", "bool", "feature"]:
msg = "Invalid mode."
@@ -3236,6 +3424,8 @@ def read_rect(
)
if self.mode == "rgb":
+ if self.post_proc is not None:
+ im_region = self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
return im_region
@@ -3413,6 +3603,8 @@ class docstrings for more information.
)
if self.mode == "rgb":
+ if self.post_proc is not None:
+ im_region = self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
return im_region
@@ -3476,12 +3668,13 @@ def __init__(
mpp: tuple[Number, Number] | None = None,
power: Number | None = None,
series: str = "auto",
- cache_size: int = 2**28,
+ cache_size: int = 2**28, # noqa: ARG002
+ post_proc: str | callable | None = "auto",
) -> None:
"""Initialize :class:`TIFFWSIReader`."""
- super().__init__(input_img=input_img, mpp=mpp, power=power)
+ super().__init__(input_img=input_img, mpp=mpp, power=power, post_proc=post_proc)
self.tiff = tifffile.TiffFile(self.input_path)
- self._axes = self.tiff.pages[0].axes
+ self._axes = self.tiff.series[0].axes
# Flag which is True if the image is a simple single page tile TIFF
is_single_page_tiled = all(
[
@@ -3514,7 +3707,8 @@ def __init__(
def page_area(page: tifffile.TiffPage) -> float:
"""Calculate the area of a page."""
return np.prod(
- TIFFWSIReaderDelegate.canonical_shape(self._axes, page.shape)[:2]
+ TIFFWSIReaderDelegate.canonical_shape(self._axes, page.shape)[:2],
+ dtype=float,
)
series_areas = [page_area(s.pages[0]) for s in all_series] # skipcq
@@ -3525,8 +3719,8 @@ def page_area(page: tifffile.TiffPage) -> float:
series=self.series_n,
aszarr=True,
)
- self._zarr_lru_cache = zarr.LRUStoreCache(self._zarr_store, max_size=cache_size)
- self._zarr_group = zarr.open(self._zarr_lru_cache)
+ # remove LRU cache for now as seems to cause issues on windows
+ self._zarr_group = zarr.open(self._zarr_store)
if not isinstance(self._zarr_group, zarr.hierarchy.Group): # pragma: no cover
group = zarr.hierarchy.group()
group[0] = self._zarr_group
@@ -3542,12 +3736,301 @@ def page_area(page: tifffile.TiffPage) -> float:
key=lambda x: -np.prod(
TIFFWSIReaderDelegate.canonical_shape(
self._axes, x[1].array.shape[:2]
- )
+ ),
+ dtype=float,
),
)
)
+ # maybe get colors if they exist in metadata
+ self._get_colors_from_meta()
+
self.tiff_reader_delegate = TIFFWSIReaderDelegate(self, self.level_arrays)
+ def _get_colors_from_meta(self: TIFFWSIReader) -> None:
+ """Get colors from metadata if they exist."""
+ if not isinstance(self.post_proc, postproc_defs.MultichannelToRGB):
+ return
+
+ try:
+ xml = self.info.raw["Description"]
+ root = ElementTree.fromstring(xml)
+ except ElementTree.ParseError:
+ return
+
+ # Try multiple formats
+ for parser in (
+ TIFFWSIReader._parse_scancolortable,
+ TIFFWSIReader._parse_filtercolor_metadata,
+ TIFFWSIReader._parse_ome_metadata_mapping,
+ ):
+ color_dict = parser(root)
+ if color_dict:
+ self.post_proc.color_dict = color_dict
+ return
+
+ @staticmethod
+ def _parse_scancolortable(
+ root: ElementTree,
+ ) -> dict[str, tuple[float, float, float]] | None:
+ """Parse ScanColorTable metadata from XML and convert color values to RGB.
+
+ Args:
+ root (ElementTree): The root of the parsed XML tree.
+
+ Returns:
+ dict[str, tuple[float, float, float]] | None: A mapping of channel
+ names to RGB tuples, or None if not found.
+ """
+ color_info = root.find(".//ScanColorTable")
+ if color_info is None:
+ return None
+
+ color_dict = {
+ k.text.split("_")[0]: v.text
+ for k, v in zip(
+ color_info.iterfind("ScanColorTable-k"),
+ color_info.iterfind("ScanColorTable-v"),
+ strict=False,
+ )
+ }
+ # values will be either a string of 3 ints e.g 155, 128, 0, or
+ # a color name e.g Lime. Convert them all to RGB tuples.
+ for key, value in color_dict.items():
+ if value is None:
+ continue
+ if "," in value:
+ color_dict[key] = tuple(int(x) / 255 for x in value.split(","))
+ else:
+ color_dict[key] = mcolors.to_rgb(value)
+
+ return color_dict
+
+ @staticmethod
+ def _parse_filtercolor_metadata(
+ root: ElementTree,
+ ) -> dict[str, tuple[float, float, float]] | None:
+ """Parse FilterColors metadata from XML and convert color values to RGB.
+
+ Args:
+ root (ElementTree): The root of the parsed XML tree.
+
+ Returns:
+ dict[str, tuple[float, float, float]] | None: A mapping of channel
+ names to RGB tuples, or None if not found.
+ """
+ # try alternate metadata format
+ # Build a map from filter pair string -> color label or RGB string
+ # from the section
+ filter_colors = {}
+ filter_colors_section = root.find(".//FilterColors")
+ if filter_colors_section is None:
+ return None
+
+ keys = filter_colors_section.findall(".//FilterColors-k")
+ vals = filter_colors_section.findall(".//FilterColors-v")
+ for k, v in zip(keys, vals, strict=False):
+ filter_colors[k.text] = v.text
+
+ # Helper function to convert color strings like "Lime" or
+ # "255, 128, 0" into (R,G,B)
+ def color_string_to_rgb(s: str) -> tuple[float, float, float]:
+ """Convert a color string (e.g., 'Lime' or '255, 128, 0') to an RGB tuple.
+
+ Args:
+ s (str): The color string.
+
+ Returns:
+ tuple[float, float, float]: RGB values normalized to [0, 1].
+ """
+ if "," in s:
+ return tuple(int(x.strip()) / 255 for x in s.split(","))
+ return mcolors.to_rgb(s)
+
+ # 2) For each , find the channel's name and figure out
+ # which filter pair it uses, then match that to a color.
+ channel_dict = {}
+ for scan_band in root.findall(".//ScanBands-i"):
+ # Inside a there is a with a tag
+ bands_i = scan_band.find(".//Bands-i")
+ if bands_i is not None:
+ band_name_element = bands_i.find("Name")
+ if band_name_element is not None:
+ channel_name = band_name_element.text.strip()
+
+ # Grab the filter pair manufacturer info
+ filter_pair = scan_band.find(".//FilterPair")
+ if filter_pair is not None:
+ emission_part = filter_pair.find(
+ ".//EmissionFilter/FixedFilter/PartNumber"
+ )
+ excitation_part = filter_pair.find(
+ ".//ExcitationFilter/FixedFilter/PartNumber"
+ )
+ if emission_part is not None and excitation_part is not None:
+ matching_rgb = (1.0, 1.0, 1.0) # default white
+ for fc_key, fc_val in filter_colors.items():
+ # if both part numbers appear in the FilterColors-k
+ # string, assume it's the match
+ if (
+ emission_part.text in fc_key
+ and excitation_part.text in fc_key
+ ):
+ matching_rgb = color_string_to_rgb(fc_val)
+ break
+
+ channel_dict[channel_name] = matching_rgb
+
+ return channel_dict if channel_dict else None
+
+ @staticmethod
+ def _get_namespace(root: ElementTree) -> dict:
+ """Extract the XML namespace from the root element.
+
+ Args:
+ root (ElementTree): Root of the parsed XML tree.
+
+ Returns:
+ dict: Dictionary containing the namespace prefix and URI.
+ """
+ if root.tag.startswith("{"):
+ ns_uri = root.tag.split("}")[0].strip("{")
+ return {"ns": ns_uri}
+
+ return {}
+
+ @staticmethod
+ def _extract_dye_mapping(root: ElementTree, ns: dict) -> dict:
+ """Extract dye mapping from OME-XML annotations.
+
+ Args:
+ root (ElementTree): Root of the parsed XML tree.
+ ns (dict): XML namespace dictionary.
+
+ Returns:
+ dict: Mapping of channel IDs to dye names.
+ """
+ dye_mapping = {}
+ for annotation in root.findall(
+ ".//ns:StructuredAnnotations/ns:XMLAnnotation", ns
+ ):
+ value_elem = annotation.find("ns:Value", ns)
+ if value_elem is not None:
+ for chan_priv in value_elem.findall(".//ns:ChannelPriv", ns):
+ chan_id = chan_priv.attrib.get("ID")
+ dye = chan_priv.attrib.get("FluorescenceChannel")
+ if chan_id and dye:
+ dye_mapping[chan_id] = dye
+ return dye_mapping
+
+ @staticmethod
+ def _int_to_rgb(color_int: int) -> tuple[float, float, float]:
+ """Convert an integer color value to an RGB tuple.
+
+ Args:
+ color_int (int): Integer representation of a color.
+
+ Returns:
+ tuple[float, float, float]: RGB values normalized to [0, 1].
+ """
+ if color_int < 0:
+ color_int += 1 << 32
+ r = (color_int >> 16) & 0xFF
+ g = (color_int >> 8) & 0xFF
+ b = color_int & 0xFF
+
+ return (r / 255, g / 255, b / 255)
+
+ @staticmethod
+ def _parse_channel_data(
+ root: ElementTree, ns: dict, dye_mapping: dict
+ ) -> list[dict]:
+ """Parse channel metadata from OME-XML.
+
+ Extract RGB color and dye information for each channel defined in the metadata.
+
+ Args:
+ root (ElementTree): Root of the parsed XML tree.
+ ns (dict): XML namespace dictionary.
+ dye_mapping (dict): Mapping of channel IDs to dye names.
+
+ Returns:
+ list[dict]: List of dictionaries containing channel metadata.
+ """
+ channel_data = []
+ for pixels in root.findall(".//ns:Pixels", ns):
+ for channel in pixels.findall("ns:Channel", ns):
+ chan_id = channel.attrib.get("ID")
+ name = channel.attrib.get("Name")
+ color = channel.attrib.get("Color")
+ if chan_id and name and color:
+ try:
+ color_int = int(color)
+ rgb = TIFFWSIReader._int_to_rgb(color_int)
+ except ValueError:
+ rgb = None
+ dye = dye_mapping.get(chan_id, "Unknown")
+ label = f"{chan_id}: {name} ({dye})"
+ channel_data.append(
+ {
+ "id": chan_id,
+ "name": name,
+ "dye": dye,
+ "rgb": rgb,
+ "label": label,
+ }
+ )
+ return channel_data
+
+ @staticmethod
+ def _build_color_dict(
+ channel_data: list[dict], dye_mapping: dict
+ ) -> dict[str, tuple[float, float, float]]:
+ """Build a dictionary mapping channel names to RGB color tuples.
+
+ Args:
+ channel_data (list[dict]): List of channel metadata dictionaries.
+ dye_mapping (dict): Mapping of channel IDs to dye names.
+
+ Returns:
+ dict[str, tuple[float, float, float]]: Dictionary mapping channel labels to
+ RGB values.
+ """
+ color_dict = {}
+ key_counts = defaultdict(int)
+ for c_data in channel_data:
+ chan_id = c_data["id"]
+ name = c_data["name"]
+ dye = dye_mapping.get(chan_id)
+ rgb = c_data["rgb"]
+ base_key = f"{name} ({dye})" if dye else name
+ count = key_counts[base_key]
+ key = base_key if count == 0 else f"{base_key} [{count + 1}]"
+ color_dict[key] = rgb
+ key_counts[base_key] += 1
+
+ return color_dict
+
+ @staticmethod
+ def _parse_ome_metadata_mapping(
+ root: ElementTree,
+ ) -> dict[str, tuple[float, float, float]] | None:
+ """Parse OME metadata from the given XML root element.
+
+ Args:
+ root (ElementTree): The root of the parsed XML tree.
+
+ Returns:
+ dict[str, tuple[float, float, float]] | None: A mapping
+ of channel names to RGB tuples, or None if not found.
+ """
+ # 3) Try OME/Lunaphore format e.g. for COMET
+ ns = TIFFWSIReader._get_namespace(root)
+ dye_mapping = TIFFWSIReader._extract_dye_mapping(root, ns)
+ channel_data = TIFFWSIReader._parse_channel_data(root, ns, dye_mapping)
+ color_dict = TIFFWSIReader._build_color_dict(channel_data, dye_mapping)
+
+ return color_dict if color_dict else None
+
def _get_ome_xml(self: TIFFWSIReader) -> ElementTree.Element:
"""Parse OME-XML from the description of the first IFD (page).
@@ -3602,32 +4085,49 @@ def _get_ome_objective_power(
"""
xml = xml or self._get_ome_xml()
namespaces = {"ome": "http://www.openmicroscopy.org/Schemas/OME/2016-06"}
- xml_series = xml.findall("ome:Image", namespaces)[self.series_n]
- instrument_ref = xml_series.find("ome:InstrumentRef", namespaces)
- if instrument_ref is None:
- return None
-
- objective_settings = xml_series.find("ome:ObjectiveSettings", namespaces)
- instrument_ref_id = instrument_ref.attrib["ID"]
- objective_settings_id = objective_settings.attrib["ID"]
- instruments = {
- instrument.attrib["ID"]: instrument
- for instrument in xml.findall("ome:Instrument", namespaces)
- }
- objectives = {
- (instrument_id, objective.attrib["ID"]): objective
- for instrument_id, instrument in instruments.items()
- for objective in instrument.findall("ome:Objective", namespaces)
- }
try:
- objective = objectives[(instrument_ref_id, objective_settings_id)]
- return float(objective.attrib.get("NominalMagnification"))
- except KeyError as e:
- msg = "No matching Instrument for image InstrumentRef in OME-XML."
- raise KeyError(
- msg,
- ) from e
+ xml_series = xml.findall("ome:Image", namespaces)[self.series_n]
+ instrument_ref = xml_series.find("ome:InstrumentRef", namespaces)
+ objective_settings = xml_series.find("ome:ObjectiveSettings", namespaces)
+ if objective_settings is None:
+ # try alternative tag
+ objective_settings = xml_series.find("ome:Objective", namespaces)
+
+ instrument_ref_id = instrument_ref.attrib.get("ID")
+ objective_settings_id = (
+ objective_settings.attrib.get("ID")
+ if objective_settings is not None
+ else "Objective:0"
+ )
+
+ instruments = {
+ instrument.attrib.get("ID"): instrument
+ for instrument in xml.findall("ome:Instrument", namespaces)
+ }
+ objectives = {
+ (instrument_id, objective.attrib.get("ID")): objective
+ for instrument_id, instrument in instruments.items()
+ for objective in instrument.findall("ome:Objective", namespaces)
+ }
+
+ objective = objectives.get((instrument_ref_id, objective_settings_id))
+ if objective is not None:
+ return float(objective.attrib.get("NominalMagnification"))
+
+ except (IndexError, AttributeError, ValueError, TypeError, KeyError) as e:
+ logger.warning("OME objective power extraction failed: %s", e)
+
+ # Fallback: try to infer from MPP
+ mpp = self._get_ome_mpp(xml)
+ if mpp is not None:
+ try:
+ return utils.misc.mpp2common_objective_power(float(np.mean(mpp)))
+ except (TypeError, ValueError) as e:
+ logger.warning("Failed to infer objective power from MPP: %s", e)
+
+ logger.warning("Objective power could not be determined from OME-XML.")
+ return None
def _get_ome_mpp(
self: TIFFWSIReader,
@@ -4097,9 +4597,9 @@ def canonical_shape(axes: str, shape: tuple[int, int]) -> tuple[int, int]:
Returns:
tuple[int, int]: Shape in YXS order.
"""
- if axes == "YXS":
+ if axes in ("YXS", "YXC"):
return shape
- if axes == "SYX":
+ if axes in ("SYX", "CYX"):
return np.roll(shape, -1)
msg = f"Unsupported axes `{axes}`."
raise ValueError(msg)
@@ -4306,7 +4806,9 @@ def read_rect(
pad_mode=pad_mode,
pad_constant_values=pad_constant_values,
)
- return utils.transforms.background_composite(im_region, alpha=False)
+ if self.reader.post_proc is not None:
+ im_region = self.reader.post_proc(im_region)
+ return im_region
# Find parameters for optimal read
(
@@ -4339,7 +4841,9 @@ def read_rect(
interpolation=interpolation,
)
- return utils.transforms.background_composite(image=im_region, alpha=False)
+ if self.reader.post_proc is not None:
+ im_region = self.reader.post_proc(im_region)
+ return im_region
def read_bounds(
self: TIFFWSIReaderDelegate,
@@ -4511,6 +5015,8 @@ class docstrings for more information.
output_size=size_at_requested,
)
+ if self.reader.post_proc is not None:
+ return self.reader.post_proc(im_region)
return im_region
@staticmethod
@@ -4559,11 +5065,12 @@ def __init__(
input_img: str | Path | np.ndarray,
mpp: tuple[Number, Number] | None = None,
power: Number | None = None,
+ post_proc: str | callable | None = "auto",
) -> None:
"""Initialize :class:`DICOMWSIReader`."""
from wsidicom import WsiDicom # noqa: PLC0415
- super().__init__(input_img, mpp, power)
+ super().__init__(input_img, mpp, power, post_proc)
self.wsi = WsiDicom.open(input_img)
def _info(self: DICOMWSIReader) -> WSIMeta:
@@ -4867,6 +5374,8 @@ def read_rect(
interpolation=interpolation,
)
+ if self.post_proc is not None:
+ im_region = self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
def read_bounds(
@@ -5061,6 +5570,8 @@ class docstrings for more information.
interpolation=interpolation,
)
+ if self.post_proc is not None:
+ return self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
@@ -5384,6 +5895,8 @@ def read_rect(
pad_mode=pad_mode,
pad_constant_values=pad_constant_values,
)
+ if self.post_proc is not None:
+ return self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
# Find parameters for optimal read
@@ -5418,6 +5931,8 @@ def read_rect(
interpolation=interpolation,
)
+ if self.post_proc is not None:
+ im_region = self.post_proc(im_region)
return utils.transforms.background_composite(image=im_region, alpha=False)
def read_bounds(
@@ -5955,6 +6470,8 @@ def read_rect(
coord_space=coord_space,
**kwargs,
)
+ if self.post_proc is not None:
+ base_region = self.post_proc(base_region)
base_region = Image.fromarray(
utils.transforms.background_composite(base_region, alpha=True),
)
@@ -6148,6 +6665,8 @@ class docstrings for more information.
coord_space=coord_space,
**kwargs,
)
+ if self.post_proc is not None:
+ base_region = self.post_proc(base_region)
base_region = Image.fromarray(
utils.transforms.background_composite(base_region, alpha=True),
)