From 479cde8bbb311248ff821c6dd4b2fbee9931f5a7 Mon Sep 17 00:00:00 2001 From: blaginin Date: Mon, 27 Mar 2023 08:36:53 +0100 Subject: [PATCH 01/16] :recycle: Add checks for input_path existence --- tiatoolbox/wsicore/wsireader.py | 40 ++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index b5dabcada..cd76489d3 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -11,7 +11,7 @@ import warnings from datetime import datetime from numbers import Number -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Literal, Optional, Tuple, Union import numpy as np import openslide @@ -36,6 +36,7 @@ Bounds = Tuple[Number, Number, Number, Number] IntBounds = Tuple[int, int, int, int] Resolution = Union[Number, Tuple[Number, Number], np.ndarray] +Units = Literal["mpp", "power", "baseline", "level"] def is_dicom(path: pathlib.Path) -> bool: @@ -204,6 +205,7 @@ def open( # noqa: A003 raise TypeError( "Invalid input: Must be a WSIRead, numpy array, string or pathlib.Path" ) + if isinstance(input_img, np.ndarray): return VirtualWSIReader(input_img, mpp=mpp, power=power) @@ -212,10 +214,22 @@ def open( # noqa: A003 # Input is a string or pathlib.Path, normalise to pathlib.Path input_path = pathlib.Path(input_img) + if not os.path.exists(input_path): + raise ValueError("`input_img` path must exist") + WSIReader.verify_supported_wsi(input_path) + return WSIReader.get_reader_by_filepath( + input_path, mpp=mpp, power=power, **kwargs + ) + @staticmethod + def get_reader_by_filepath( + input_path: pathlib.Path, + mpp: Optional[Tuple[Number, Number]] = None, + power: Optional[Number] = None, + **kwargs, + ) -> WSIReader: # Handle special cases first (DICOM, Zarr/NGFF, OME-TIFF) - if is_dicom(input_path): return DICOMWSIReader(input_path, mpp=mpp, power=power) @@ -373,7 +387,7 @@ def _info(self) -> WSIMeta: 
raise NotImplementedError def _find_optimal_level_and_downsample( - self, resolution: Resolution, units: str, precision: int = 3 + self, resolution: Resolution, units: Units, precision: int = 3 ) -> Tuple[int, np.ndarray]: """Find the optimal level to read at for a desired resolution and units. @@ -436,7 +450,7 @@ def find_read_rect_params( location: IntPair, size: IntPair, resolution: Resolution, - units: str, + units: Units, precision: int = 3, ) -> Tuple[int, IntPair, IntPair, NumPair, IntPair]: """Find optimal parameters for reading a rect at a given resolution. @@ -513,7 +527,7 @@ def find_read_rect_params( ) def _find_read_params_at_resolution( - self, location: IntPair, size: IntPair, resolution: Resolution, units: str + self, location: IntPair, size: IntPair, resolution: Resolution, units: Units ) -> Tuple[int, NumPair, IntPair, IntPair, IntPair, IntPair]: """Works similarly to `_find_read_rect_params`. @@ -600,7 +614,7 @@ def _find_read_params_at_resolution( ) + output def _bounds_at_resolution_to_baseline( - self, bounds: Bounds, resolution: Resolution, units: str + self, bounds: Bounds, resolution: Resolution, units: Units ) -> Bounds: """Find corresponding bounds in baseline. @@ -628,7 +642,7 @@ def _bounds_at_resolution_to_baseline( return np.concatenate([tl_at_baseline, br_at_baseline]) # bounds at baseline def slide_dimensions( - self, resolution: Resolution, units: str, precisions: int = 3 + self, resolution: Resolution, units: Units, precisions: int = 3 ) -> IntPair: """Return the size of WSI at requested resolution. @@ -662,7 +676,7 @@ def slide_dimensions( return wsi_shape_at_resolution def _find_read_bounds_params( - self, bounds: Bounds, resolution: Resolution, units: str, precision: int = 3 + self, bounds: Bounds, resolution: Resolution, units: Units, precision: int = 3 ) -> Tuple[int, IntBounds, IntPair, IntPair, np.ndarray]: """Find optimal parameters for reading bounds at a given resolution. 
@@ -901,7 +915,7 @@ def _read_rect_at_resolution( location: NumPair, size: NumPair, resolution: Resolution = 0, - units: str = "level", + units: Units = "level", interpolation: str = "optimise", pad_mode: str = "constant", pad_constant_values: Union[Number, Iterable[NumPair]] = 0, @@ -933,7 +947,7 @@ def read_rect( location: IntPair, size: IntPair, resolution: Resolution = 0, - units: str = "level", + units: Units = "level", interpolation: str = "optimise", pad_mode: str = "constant", pad_constant_values: Union[Number, Iterable[NumPair]] = 0, @@ -1123,7 +1137,7 @@ def read_bounds( self, bounds: Bounds, resolution: Resolution = 0, - units: str = "level", + units: Units = "level", interpolation: str = "optimise", pad_mode: str = "constant", pad_constant_values: Union[Number, Iterable[NumPair]] = 0, @@ -1257,7 +1271,7 @@ def read_region(self, location: NumPair, level: int, size: IntPair) -> np.ndarra location=location, size=size, resolution=level, units="level" ) - def slide_thumbnail(self, resolution: Resolution = 1.25, units: str = "power"): + def slide_thumbnail(self, resolution: Resolution = 1.25, units: Units = "power"): """Read the whole slide image thumbnail (1.25x by default). For more information on resolution and units see @@ -1288,7 +1302,7 @@ def tissue_mask( self, method: str = "otsu", resolution: Resolution = 1.25, - units: str = "power", + units: Units = "power", **masker_kwargs, ) -> "VirtualWSIReader": """Create a tissue mask and wrap it in a VirtualWSIReader. 
From 87b43b850f99389c579ba0ac85710d60642b3107 Mon Sep 17 00:00:00 2001 From: blaginin Date: Mon, 27 Mar 2023 08:38:48 +0100 Subject: [PATCH 02/16] :recycle: Refactor PatchPredictor --- tiatoolbox/models/engine/patch_predictor.py | 195 +++++++++++--------- 1 file changed, 104 insertions(+), 91 deletions(-) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 96ada6692..5c5af9cdb 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -5,7 +5,7 @@ import pathlib import warnings from collections import OrderedDict -from typing import Callable, Tuple, Union +from typing import Callable, List, Literal, Tuple, Union import numpy as np import torch @@ -16,7 +16,7 @@ from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig from tiatoolbox.utils import misc from tiatoolbox.utils.misc import save_as_json -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from tiatoolbox.wsicore.wsireader import Units, WSIReader class IOPatchPredictorConfig(IOSegmentorConfig): @@ -158,8 +158,6 @@ class PatchPredictor: Whether to output logging information. Attributes: - img (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): - A HWC image or a path to WSI. mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. @@ -229,7 +227,6 @@ def __init__( ): super().__init__() - self.imgs = None self.mode = None if model is None and pretrained_model is None: @@ -251,7 +248,7 @@ def __init__( @staticmethod def merge_predictions( - img: Union[str, pathlib.Path, np.ndarray], + input_img: Union[str, pathlib.Path, np.ndarray, WSIReader], output: dict, resolution: float = None, units: str = None, @@ -267,8 +264,8 @@ def merge_predictions( predicted by the model. Args: - img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`): - A HWC image or a path to WSI. 
+ input_img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`): + Image to be processed. This can be a WSI, tile or patch. output (dict): Output generated by the model. resolution (float): @@ -308,16 +305,7 @@ def merge_predictions( ... [0, 0, 1, 1]]) """ - reader = WSIReader.open(img) - if isinstance(reader, VirtualWSIReader): - warnings.warn( - "Image is not pyramidal hence read is forced to be " - "at `units='baseline'` and `resolution=1.0`.", - stacklevel=2, - ) - resolution = 1.0 - units = "baseline" - + reader = WSIReader.open(input_img) canvas_shape = reader.slide_dimensions(resolution=resolution, units=units) canvas_shape = canvas_shape[::-1] # XY to YX @@ -555,18 +543,18 @@ def _prepare_save_dir(save_dir, imgs): save_dir = pathlib.Path(save_dir) save_dir.mkdir(parents=True, exist_ok=False) - return save_dir + return save_dir or pathlib.Path(os.getcwd()) - def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_gpu): + def _predict_patch( + self, input_imgs, labels, return_probabilities, return_labels, on_gpu + ): """Process patch mode. Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. + input_imgs (list): + List of inputs to process. Must be either a list of + images, a list of image file paths, WSIReader objects, + or a numpy array of an image list. labels: List of labels. 
If using `tile` or `wsi` mode, then only a single label per image tile or whole-slide image is @@ -587,21 +575,21 @@ def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_g # if a labels is provided, then return with the prediction return_labels = bool(labels) - if labels and len(labels) != len(imgs): + if labels and len(labels) != len(input_imgs): raise ValueError( - f"len(labels) != len(imgs) : " f"{len(labels)} != {len(imgs)}" + f"len(labels) != len(imgs) : " f"{len(labels)} != {len(input_imgs)}" ) # don't return coordinates if patches are already extracted return_coordinates = False - dataset = PatchDataset(imgs, labels) + dataset = PatchDataset(input_imgs, labels) return self._predict_engine( dataset, return_probabilities, return_labels, return_coordinates, on_gpu ) def _predict_tile_wsi( self, - imgs, + input_imgs, masks, labels, mode, @@ -616,12 +604,10 @@ def _predict_tile_wsi( """Predict on Tile and WSIs. Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. + input_imgs (list): + List of inputs to process. Must be either a list of + images, a list of image file paths, WSIReader objects, + or a numpy array of an image list. masks (list): List of masks. Only utilised when processing image tiles and whole-slide images. 
Patches are only processed if @@ -678,22 +664,22 @@ def _predict_tile_wsi( # generate a list of output file paths if number of input images > 1 file_dict = OrderedDict() - if len(imgs) > 1: + if len(input_imgs) > 1: save_output = True - for idx, img_path in enumerate(imgs): - img_path = pathlib.Path(img_path) + for idx, input_img in enumerate(input_imgs): img_label = None if labels is None else labels[idx] img_mask = None if masks is None else masks[idx] dataset = WSIPatchDataset( - img_path, + input_img, mode=mode, - mask_path=img_mask, + mask=img_mask, patch_input_shape=ioconfig.patch_input_shape, stride_shape=ioconfig.stride_shape, resolution=ioconfig.input_resolutions[0]["resolution"], units=ioconfig.input_resolutions[0]["units"], + auto_get_mask=True, ) output_model = self._predict_engine( dataset, @@ -712,7 +698,7 @@ def _predict_tile_wsi( merged_prediction = None if merge_predictions: merged_prediction = self.merge_predictions( - img_path, + input_img, output_model, resolution=output_model["resolution"], units=output_model["units"], @@ -721,49 +707,79 @@ def _predict_tile_wsi( outputs.append(merged_prediction) if save_output: - # dynamic 0 padding - img_code = f"{idx:0{len(str(len(imgs)))}d}" - - save_info = {} - save_path = os.path.join(str(save_dir), img_code) - raw_save_path = f"{save_path}.raw.json" - save_info["raw"] = raw_save_path - save_as_json(output_model, raw_save_path) - if merge_predictions: - merged_file_path = f"{save_path}.merged.npy" - np.save(merged_file_path, merged_prediction) - save_info["merged"] = merged_file_path - file_dict[str(img_path)] = save_info + img_id, save_info = self._save_output( + output_model, + idx, + merged_prediction, + input_img, + input_imgs, + save_dir, + merge_predictions, + ) + file_dict[img_id] = save_info return file_dict if save_output else outputs + def _save_output( + self, + output_model, + idx, + merged_prediction, + input_img, + input_imgs, + save_dir, + merge_predictions, + ): + # dynamic 0 padding + 
img_code = f"{idx:0{len(str(len(input_imgs)))}d}" + + save_info = {} + save_path = os.path.join(str(save_dir), img_code) + raw_save_path = f"{save_path}.raw.json" + save_info["raw"] = raw_save_path + save_as_json(output_model, raw_save_path) + if merge_predictions: + merged_file_path = f"{save_path}.merged.npy" + np.save(merged_file_path, merged_prediction) + save_info["merged"] = merged_file_path + + img_id = None + if isinstance(input_img, WSIReader): + img_id = str(input_img.input_path) + + elif isinstance(input_img, (str, pathlib.Path)): + img_id = str(input_img) + + if img_id is None: + img_id = idx + + return save_info + def predict( self, - imgs, - masks=None, - labels=None, - mode="patch", - return_probabilities=False, - return_labels=False, - on_gpu=True, + input_imgs: List[Union[str, pathlib.Path, np.ndarray, WSIReader]], + masks: List[Union[str, pathlib.Path, np.ndarray, WSIReader]] = None, + labels: List = None, + mode: Literal["patch", "tile", "wsi"] = "patch", + return_probabilities: bool = False, + return_labels: bool = False, + on_gpu: bool = True, ioconfig: IOPatchPredictorConfig = None, patch_input_shape: Tuple[int, int] = None, stride_shape: Tuple[int, int] = None, - resolution=None, - units=None, - merge_predictions=False, - save_dir=None, - save_output=False, + resolution: float = None, + units: Units = None, + merge_predictions: bool = False, + save_dir: bool = None, + save_output: bool = False, ): """Make a prediction for a list of input data. Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. + input_imgs (list): + List of inputs to process. Must be either a list of + images, a list of image file paths, WSIReader objects, + or a numpy array of an image list. masks (list): List of masks. 
Only utilised when processing image tiles and whole-slide images. Patches are only processed if @@ -796,7 +812,7 @@ def predict( level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. resolution (float): - Resolution used for reading the image. Please see + Resolution used for reading the images. Please see :obj:`WSIReader` for details. units (str): Units of resolution used for reading the image. Choose @@ -842,36 +858,33 @@ def predict( ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} """ + if mode not in ["patch", "wsi", "tile"]: raise ValueError( f"{mode} is not a valid mode. Use either `patch`, `tile` or `wsi`" ) - if mode == "patch": - return self._predict_patch( - imgs, labels, return_probabilities, return_labels, on_gpu - ) - if not isinstance(imgs, list): + if mode == "patch" and masks is not None: + raise ValueError("masks are not supported for `patch` mode. ") + + if not isinstance(input_imgs, list): raise ValueError( "Input to `tile` and `wsi` mode must be a list of file paths." ) - if mode == "wsi" and masks is not None and len(masks) != len(imgs): + if mode == "wsi" and masks is not None and len(masks) != len(input_imgs): raise ValueError( - f"len(masks) != len(imgs) : " f"{len(masks)} != {len(imgs)}" + f"len(masks) != len(imgs) : " f"{len(masks)} != {len(input_imgs)}" + ) + + if mode == "patch": + return self._predict_patch( + input_imgs, labels, return_probabilities, return_labels, on_gpu ) ioconfig = self._update_ioconfig( ioconfig, patch_input_shape, stride_shape, resolution, units ) - if mode == "tile": - warnings.warn( - "WSIPatchDataset only reads image tile at " - '`units="baseline"`. 
Resolutions will be converted ' - "to baseline value.", - stacklevel=2, - ) - ioconfig = ioconfig.to_baseline() fx_list = ioconfig.scale_to_highest( ioconfig.input_resolutions, ioconfig.input_resolutions[0]["units"] @@ -880,10 +893,10 @@ def predict( fx_list = sorted(fx_list, key=lambda x: x[0]) highest_input_resolution = fx_list[0][1] - save_dir = self._prepare_save_dir(save_dir, imgs) + save_dir = self._prepare_save_dir(save_dir, input_imgs) return self._predict_tile_wsi( - imgs, + input_imgs, masks, labels, mode, From 1daf19e1084fd20a17f2c8183e6d5feeb50d2c2e Mon Sep 17 00:00:00 2001 From: blaginin Date: Mon, 27 Mar 2023 08:57:23 +0100 Subject: [PATCH 03/16] :adhesive_bandage: Add support for numpy array and WSIReader in datasets --- tiatoolbox/cli/patch_predictor.py | 2 +- tiatoolbox/models/dataset/classification.py | 162 +++++++++++--------- tiatoolbox/models/dataset/dataset_abc.py | 19 ++- 3 files changed, 108 insertions(+), 75 deletions(-) diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index a0fd49eaf..97021e672 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -87,7 +87,7 @@ def patch_predictor( ) output = predictor.predict( - imgs=files_all, + input_imgs=files_all, masks=masks_all, mode=mode, return_probabilities=return_probabilities, diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index 9b3e804bb..b3fef2930 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -1,6 +1,6 @@ import os import pathlib -import warnings +from typing import Literal, Tuple, Union import cv2 import numpy as np @@ -10,8 +10,7 @@ from tiatoolbox.models.dataset import dataset_abc from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils.misc import imread -from tiatoolbox.wsicore.wsimeta import WSIMeta -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from 
tiatoolbox.wsicore.wsireader import Units, VirtualWSIReader, WSIReader class _TorchPreprocCaller: @@ -115,9 +114,9 @@ def __getitem__(self, idx): data = { "image": patch, } + if self.labels is not None: data["label"] = self.labels[idx] - return data return data @@ -154,13 +153,13 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC): def __init__( self, - img_path, - mode="wsi", - mask_path=None, - patch_input_shape=None, - stride_shape=None, - resolution=None, - units=None, + input_img: Union[str, pathlib.Path, np.ndarray, WSIReader], + mode: Literal["wsi", "tile"] = "wsi", + mask: Union[str, pathlib.Path, np.ndarray, VirtualWSIReader] = None, + patch_input_shape: Union[Tuple[int, int], np.ndarray] = None, + stride_shape: Union[Tuple[int, int], np.ndarray] = None, + resolution: float = 1, + units: Units = "baseline", auto_get_mask=True, ): """Create a WSI-level patch dataset. @@ -170,10 +169,16 @@ def __init__( Can be either `wsi` or `tile` to denote the image to read is either a whole-slide image or a large image tile. - img_path (:obj:`str` or :obj:`pathlib.Path`): + input_img: (:obj:`str` or + :obj:`pathlib.Path` or + :obj:`ndarray` or + :obj:`WSIReader`): Valid to pyramidal whole-slide image or large tile to read. - mask_path (:obj:`str` or :obj:`pathlib.Path`): + mask (:obj:`str` or + :obj:`pathlib.Path` or + :obj:`ndarray` or + :obj:`VirtualWSIReader`): Valid mask image. patch_input_shape: A tuple (int, int) or ndarray of shape (2,). Expected @@ -187,10 +192,12 @@ def __init__( `units`. Expected to be positive and of (height, width). Note, this is not at level 0. resolution: - Check (:class:`.WSIReader`) for details. When - `mode='tile'`, value is fixed to be `resolution=1.0` and - `units='baseline'` units: check (:class:`.WSIReader`) for - details. + Check (:class:`.WSIReader`) for details. + If reading from an image without specified metadata, + use `resolution=1.0` and`units='baseline'` units: + check (:class:`.WSIReader`) for details. 
+ units: + Check (:class:`.WSIReader`) for details. preproc_func: Preprocessing function used to transform the input data. @@ -212,11 +219,12 @@ def __init__( """ super().__init__() - # Is there a generic func for path test in toolbox? - if not os.path.isfile(img_path): - raise ValueError("`img_path` must be a valid file path.") if mode not in ["wsi", "tile"]: raise ValueError(f"`{mode}` is not supported.") + + if units not in ["baseline", "power", "mpp"]: + raise ValueError(f"`{units}` is not supported.") + patch_input_shape = np.array(patch_input_shape) stride_shape = np.array(stride_shape) @@ -233,38 +241,7 @@ def __init__( ): raise ValueError(f"Invalid `stride_shape` value {stride_shape}.") - img_path = pathlib.Path(img_path) - if mode == "wsi": - self.reader = WSIReader.open(img_path) - else: - warnings.warn( - "WSIPatchDataset only reads image tile at " - '`units="baseline"` and `resolution=1.0`.', - stacklevel=2, - ) - units = "baseline" - resolution = 1.0 - img = imread(img_path) - axes = "YXS"[: len(img.shape)] - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - # ! should we expose this so that use can provide their metadata ? - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - axes=axes, - objective_power=10, - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - # hack value such that read if mask is provided is through - # 'mpp' or 'power' as varying 'baseline' is locked atm - units = "mpp" - resolution = 1.0 - self.reader = VirtualWSIReader( - img, - info=metadata, - ) + self.reader = WSIReader.open(input_img) # may decouple into misc ? 
# the scaling factor will scale base level to requested read resolution/units @@ -278,17 +255,70 @@ def __init__( input_within_bound=False, ) + self._apply_mask( + mask=mask, + auto_get_mask=auto_get_mask, + resolution=resolution, + units=units, + mode=mode, + ) + + if len(self.inputs) == 0: + raise ValueError("No patch coordinates remain after filtering.") + + self.patch_input_shape = patch_input_shape + self.resolution = resolution + self.units = units + + # Perform check on the input + self._check_input_integrity(mode="wsi") + + def _apply_mask( + self, + mask: Union[str, pathlib.Path, np.ndarray, VirtualWSIReader], + auto_get_mask: bool = True, + resolution: float = 1, + units: Units = "baseline", + mode: Literal["wsi", "tile"] = "wsi", + ): + """Reads or generates a mask for the input image and + applies it to the dataset.""" + mask_reader = None - if mask_path is not None: - if not os.path.isfile(mask_path): - raise ValueError("`mask_path` must be a valid file path.") - mask = imread(mask_path) # assume to be gray - mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) - mask = np.array(mask > 0, dtype=np.uint8) - - mask_reader = VirtualWSIReader(mask) + if mask is not None: + if not isinstance(mask, (str, pathlib.Path, np.ndarray, VirtualWSIReader)): + raise ValueError( + "`mask` must be file path, np.ndarray or VirtualWSIReader." + ) + + if isinstance(mask, VirtualWSIReader): + if mask.mode != "bool": + raise ValueError( + "`mask` must be binary, " + "i.e. VirtualWSIReader's mode has to be 'bool'" + ) + + mask_reader = mask + + elif isinstance(mask, np.ndarray): + if mask.dtype != np.bool: + raise ValueError( + "`mask` must be binary, i.e. 
`ndarray.dtype` has to be bool" + ) + + mask_reader = VirtualWSIReader(mask.astype(np.uint8)) + + else: # assume to be file path + if not os.path.isfile(mask): + raise ValueError("`mask` must be a valid file path.") + + mask = imread(mask) # assume to be gray + mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) + mask = np.array(mask > 0, dtype=np.uint8) + mask_reader = VirtualWSIReader(mask) + mask_reader.info = self.reader.info - elif auto_get_mask and mode == "wsi" and mask_path is None: + elif auto_get_mask and mode == "wsi" and mask is None: # if no mask provided and `wsi` mode, generate basic tissue # mask on the fly mask_reader = self.reader.tissue_mask(resolution=1.25, units="power") @@ -304,16 +334,6 @@ def __init__( ) self.inputs = self.inputs[selected] - if len(self.inputs) == 0: - raise ValueError("No patch coordinates remain after filtering.") - - self.patch_input_shape = patch_input_shape - self.resolution = resolution - self.units = units - - # Perform check on the input - self._check_input_integrity(mode="wsi") - def __getitem__(self, idx): coords = self.inputs[idx] # Read image patch from the whole-slide image diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 94751538d..771952561 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -1,6 +1,7 @@ import os import pathlib from abc import ABC, abstractmethod +from typing import Union import numpy as np import torch @@ -98,14 +99,26 @@ def _check_input_integrity(self, mode): raise ValueError("`inputs` should be a list of patch coordinates.") @staticmethod - def load_img(path): + def load_img(input_img: Union[str, pathlib.Path, np.ndarray]) -> np.ndarray: """Load an image from a provided path. Args: - path (str): Path to an image file. + input_img (str, pathlib.Path, np.ndarray): path to image or image data. + + Returns: + np.ndarray: image data. 
""" - path = pathlib.Path(path) + + if not isinstance(input_img, (str, pathlib.Path, np.ndarray)): + raise ValueError( + f"Cannot load image data from `{type(input_img)}` objects." + ) + + if isinstance(input_img, np.ndarray): + return input_img + + path = pathlib.Path(input_img) if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"): raise ValueError(f"Cannot load image data from `{path.suffix}` files.") From ac0a9e81d01c661c1a5df5df7a1f52dcbb13d792 Mon Sep 17 00:00:00 2001 From: blaginin Date: Mon, 27 Mar 2023 09:42:36 +0100 Subject: [PATCH 04/16] :white_check_mark: Update tests for np.ndarray and WSIReader support in PatchPredictor --- tests/conftest.py | 22 ++++ tests/models/test_patch_predictor.py | 146 ++++++++++++++++++++++++--- tests/test_patch_extraction.py | 8 +- tests/test_wsireader.py | 46 +++++---- 4 files changed, 183 insertions(+), 39 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1f1f0e6ae..fb6862720 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import pathlib import shutil from pathlib import Path +from tempfile import NamedTemporaryFile from typing import Callable import pytest @@ -59,6 +60,27 @@ def __remote_sample(key: str) -> pathlib.Path: return __remote_sample +@pytest.fixture(scope="session") +def blank_sample(tmp_path_factory: TempPathFactory): + """Factory fixture for creating blank sample files.""" + + class BlankSample: + """Sample file. Automatically deleted after use.""" + + def __init__(self, suffix: str): + self.suffix = suffix + + def __enter__(self) -> pathlib.Path: + folder = tmp_path_factory.mktemp("data") + self.file = NamedTemporaryFile(suffix=self.suffix, dir=folder, delete=True) + return pathlib.Path(self.file.name) + + def __exit__(self, exc_type, exc_value, traceback): + self.file.close() + + return BlankSample + + @pytest.fixture(scope="session") def sample_ndpi(remote_sample) -> pathlib.Path: """Sample pytest fixture for ndpi images. 
diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index 4639353d7..a03e9d7a0 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -25,7 +25,7 @@ ) from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils.misc import download_data, imread, imwrite -from tiatoolbox.wsicore.wsireader import WSIReader +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader ON_GPU = toolbox_env.has_gpu() @@ -222,9 +222,9 @@ def test_wsi_patch_dataset(sample_wsi_dict, tmp_path): mini_wsi_jpg = pathlib.Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) mini_wsi_msk = pathlib.Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - def reuse_init(img_path=mini_wsi_svs, **kwargs): + def reuse_init(input_img=mini_wsi_svs, **kwargs): """Testing function.""" - return WSIPatchDataset(img_path=img_path, **kwargs) + return WSIPatchDataset(input_img=input_img, **kwargs) def reuse_init_wsi(**kwargs): """Testing function.""" @@ -249,9 +249,9 @@ def __getitem__(self, idx): Proto() # skipcq # invalid path input - with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): + with pytest.raises(ValueError, match=r".*`input_img` path must exist.*"): WSIPatchDataset( - img_path="aaaa", + input_img="aaaa", mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], @@ -259,10 +259,23 @@ def __getitem__(self, idx): ) # invalid mask path input - with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): + with pytest.raises(ValueError, match=r".*`mask` must be a valid file path.*"): WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path="aaaa", + input_img=mini_wsi_svs, + mask="aaaa", + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + resolution=1.0, + units="mpp", + auto_get_mask=False, + ) + + # mask as not VirtualWSIReader + with pytest.raises(ValueError, match=r".*`mask` must be .* VirtualWSIReader.*"): + WSIPatchDataset( + 
input_img=mini_wsi_svs, + mask=WSIReader.open(mini_wsi_svs), mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], @@ -275,6 +288,10 @@ def __getitem__(self, idx): with pytest.raises(ValueError, match="`X` is not supported."): reuse_init(mode="X") + # invalid units + with pytest.raises(ValueError, match="`X` is not supported."): + reuse_init(units="X") + # invalid patch with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): reuse_init() @@ -346,8 +363,8 @@ def __getitem__(self, idx): ) assert len(ds) > 0 ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=mini_wsi_msk, + input_img=mini_wsi_svs, + mask=mini_wsi_msk, mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], @@ -361,8 +378,8 @@ def __getitem__(self, idx): imwrite(negative_mask_path, negative_mask) with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=negative_mask_path, + input_img=mini_wsi_svs, + mask=negative_mask_path, mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], @@ -374,7 +391,7 @@ def __getitem__(self, idx): # * for tile reader = WSIReader.open(mini_wsi_jpg) tile_ds = WSIPatchDataset( - img_path=mini_wsi_jpg, + input_img=mini_wsi_jpg, mode="tile", patch_input_shape=patch_size, stride_shape=stride_size, @@ -395,6 +412,58 @@ def __getitem__(self, idx): assert roi1.shape[1] == roi2.shape[1] assert np.min(correlation) > 0.9, correlation + positive_mask = (negative_mask + 1).astype(bool) + # check mask as np array + with pytest.raises(ValueError, match=r".*`mask` must be binary.*"): + WSIPatchDataset( + input_img=mini_wsi_svs, + mask=np.array([[0, 0, 1]]), + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + ds = WSIPatchDataset( + input_img=mini_wsi_svs, + mask=positive_mask, + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + 
auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + assert len(ds) > 0 + + # check mask VirtualWSIReader + with pytest.raises(ValueError, match=r".*`mask` must be binary.*"): + WSIPatchDataset( + input_img=mini_wsi_svs, + mask=VirtualWSIReader(np.array([[0, 0, 5]])), + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + ds = WSIPatchDataset( + input_img=mini_wsi_svs, + mask=VirtualWSIReader(positive_mask, mode="bool"), + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + assert len(ds) > 0 + def test_patch_dataset_abc(): """Test for ABC methods.""" @@ -489,6 +558,12 @@ def test_predictor_crash(): predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi") with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): predictor.predict([1, 2, 3], labels=[1, 2], mode="patch") + # mask on patch are not supported + with pytest.raises(ValueError, match=r".*masks are not supported .* `patch`.*"): + predictor.predict( + [np.array([1, 2, 3])], masks=[np.array([1, 2, 3])], mode="patch" + ) + # remove previously generated data _rm_dir("output") @@ -673,7 +748,7 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path): # test prediction predictor = PatchPredictor(model=model, batch_size=1, verbose=False) output = predictor.predict( - inputs, + input_imgs=inputs, return_probabilities=True, labels=[1, "a"], return_labels=True, @@ -799,6 +874,49 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path): # remove previously generated data _rm_dir("output") + # check that predictor can take in WSIReader object + svs_objects = [WSIReader.open(i) for i in [mini_wsi_svs, mini_wsi_svs]] + output = predictor.predict( + svs_objects, + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **kwargs, + ) + assert str(mini_wsi_svs) in output + # remove previously generated data + _rm_dir(kwargs["save_dir"]) + 
+ # check that predictor can take in ndarray object + img_objects = [ + WSIReader.open(i).slide_thumbnail(1, "baseline") + for i in [mini_wsi_svs, mini_wsi_svs] + ] + + with pytest.raises(ValueError, match=".*Cannot determine scale.*"): + predictor.predict( + img_objects, + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **kwargs, + ) + _rm_dir(kwargs["save_dir"]) + + _kwargs = copy.deepcopy(kwargs) + _kwargs["units"] = "baseline" + _kwargs["resolution"] = 1.0 + + output = predictor.predict( + img_objects, + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **_kwargs, + ) + + assert len(output) == 2 + assert 0 in output + assert 1 in output + _rm_dir(_kwargs["save_dir"]) + def test_wsi_predictor_merge_predictions(sample_wsi_dict): """Test normal run of wsi predictor with merge predictions option.""" diff --git a/tests/test_patch_extraction.py b/tests/test_patch_extraction.py index 1f00bcf92..0d4b3b299 100644 --- a/tests/test_patch_extraction.py +++ b/tests/test_patch_extraction.py @@ -97,10 +97,9 @@ def test_get_patch_extractor(source_image, patch_extr_csv): def test_points_patch_extractor_image_format( - sample_svs, sample_jp2, source_image, patch_extr_csv + sample_svs, sample_jp2, source_image, patch_extr_csv, blank_sample ): """Test PointsPatchExtractor returns the right object.""" - file_parent_dir = pathlib.Path(__file__).parent locations_list = pathlib.Path(patch_extr_csv) points = patchextraction.get_patch_extractor( @@ -130,10 +129,9 @@ def test_points_patch_extractor_image_format( assert isinstance(points.wsi, OmnyxJP2WSIReader) - false_image = pathlib.Path(file_parent_dir.joinpath("data/source_image.test")) - with pytest.raises(FileNotSupported): + with blank_sample(".test") as false_image_path, pytest.raises(FileNotSupported): _ = patchextraction.get_patch_extractor( - input_img=false_image, + input_img=false_image_path, locations_list=locations_list, method_name="point", patch_size=(200, 200), diff --git a/tests/test_wsireader.py 
b/tests/test_wsireader.py index 0cb332f35..2b60788d7 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -1373,11 +1373,11 @@ def test_invalid_masker_method(sample_svs): def test_wsireader_open( - sample_svs, sample_ndpi, sample_jp2, sample_ome_tiff, source_image + sample_svs, sample_ndpi, sample_jp2, sample_ome_tiff, source_image, blank_sample ): """Test WSIReader.open() to return correct object.""" - with pytest.raises(FileNotSupported): - _ = WSIReader.open("./sample.csv") + with blank_sample(".csv") as path, pytest.raises(FileNotSupported): + _ = WSIReader.open(path) with pytest.raises(TypeError): _ = WSIReader.open([1, 2]) @@ -1643,13 +1643,15 @@ def test_command_line_read_bounds(sample_ndpi, tmp_path): def test_command_line_jp2_read_bounds(sample_jp2, tmp_path): """Test JP2 read_bounds.""" + input_img = pathlib.Path(sample_jp2) + runner = CliRunner() read_bounds_result = runner.invoke( cli.main, [ "read-bounds", "--img-input", - str(pathlib.Path(sample_jp2)), + str(input_img), "--resolution", "0", "--units", @@ -1660,7 +1662,9 @@ def test_command_line_jp2_read_bounds(sample_jp2, tmp_path): ) assert read_bounds_result.exit_code == 0 - assert pathlib.Path(tmp_path).joinpath("../im_region.jpg").is_file() + input_dir = pathlib.Path(input_img).parent.parent + output_path = os.path.join(input_dir, "im_region.jpg") + assert pathlib.Path(output_path).is_file() @pytest.mark.skipif( @@ -1688,23 +1692,25 @@ def test_command_line_jp2_read_bounds_show(sample_jp2, tmp_path): assert read_bounds_result.exit_code == 0 -def test_command_line_unsupported_file_read_bounds(sample_svs, tmp_path): +def test_command_line_unsupported_file_read_bounds(sample_svs, tmp_path, blank_sample): """Test unsupported file read bounds.""" runner = CliRunner() - read_bounds_result = runner.invoke( - cli.main, - [ - "read-bounds", - "--img-input", - str(pathlib.Path(sample_svs))[:-1], - "--resolution", - "0", - "--units", - "level", - "--mode", - "save", - ], - ) + + with 
blank_sample(".csv") as file: + read_bounds_result = runner.invoke( + cli.main, + [ + "read-bounds", + "--img-input", + str(file), + "--resolution", + "0", + "--units", + "level", + "--mode", + "save", + ], + ) assert read_bounds_result.output == "" assert read_bounds_result.exit_code == 1 From 7e9ac19a6dee06d5dfdc539a4bc9fb72d7bbbf80 Mon Sep 17 00:00:00 2001 From: blaginin Date: Mon, 27 Mar 2023 11:43:44 +0100 Subject: [PATCH 05/16] :recycle: Refactor code. --- tests/conftest.py | 1 + tests/test_wsireader.py | 10 +++++++--- tiatoolbox/models/dataset/classification.py | 1 - tiatoolbox/models/dataset/dataset_abc.py | 1 - tiatoolbox/models/engine/patch_predictor.py | 5 +++-- tiatoolbox/wsicore/wsireader.py | 5 ++++- 6 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fb6862720..4d1a01f46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -69,6 +69,7 @@ class BlankSample: def __init__(self, suffix: str): self.suffix = suffix + self.file = None # will be set in __enter__ def __enter__(self) -> pathlib.Path: folder = tmp_path_factory.mktemp("data") diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index e6d71d1a9..f55e879c6 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -1980,7 +1980,9 @@ def test_store_reader_alpha(remote_sample): def test_store_reader_no_types(tmp_path, remote_sample): - """Test AnnotationStoreReader with no types.""" + """ + Test AnnotationStoreReader with no types. 
+ """ SQLiteStore(tmp_path / "store.db") wsi_reader = WSIReader.open(remote_sample("svs-1-small")) reader = AnnotationStoreReader(tmp_path / "store.db", wsi_reader.info) @@ -1989,8 +1991,10 @@ def test_store_reader_no_types(tmp_path, remote_sample): def test_store_reader_info_from_base(tmp_path, remote_sample): - """Test that AnnotationStoreReader will correctly get metadata - from a provided base_wsi if the store has no wsi metadata.""" + """ + Test that AnnotationStoreReader will correctly get metadata + from a provided base_wsi if the store has no wsi metadata. + """ SQLiteStore(tmp_path / "store.db") wsi_reader = WSIReader.open(remote_sample("svs-1-small")) store_reader = AnnotationStoreReader(tmp_path / "store.db", base_wsi=wsi_reader) diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index b3fef2930..8522d17d1 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -283,7 +283,6 @@ def _apply_mask( ): """Reads or generates a mask for the input image and applies it to the dataset.""" - mask_reader = None if mask is not None: if not isinstance(mask, (str, pathlib.Path, np.ndarray, VirtualWSIReader)): diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 771952561..22e3c9cc0 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -109,7 +109,6 @@ def load_img(input_img: Union[str, pathlib.Path, np.ndarray]) -> np.ndarray: np.ndarray: image data. """ - if not isinstance(input_img, (str, pathlib.Path, np.ndarray)): raise ValueError( f"Cannot load image data from `{type(input_img)}` objects." 
diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 5c5af9cdb..36eeefa20 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -720,8 +720,8 @@ def _predict_tile_wsi( return file_dict if save_output else outputs + @staticmethod def _save_output( - self, output_model, idx, merged_prediction, @@ -730,6 +730,7 @@ def _save_output( save_dir, merge_predictions, ): + """Save prediction to json and/or numpy file.""" # dynamic 0 padding img_code = f"{idx:0{len(str(len(input_imgs)))}d}" @@ -753,7 +754,7 @@ def _save_output( if img_id is None: img_id = idx - return save_info + return img_id, save_info def predict( self, diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index 1708eeeb3..d7fb72936 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -279,6 +279,10 @@ def get_reader_by_filepath( power: Optional[Number] = None, **kwargs, ) -> WSIReader: + """ + Returns an appropriate :class:`.WSIReader` object + based on the file extension. 
+ """ # Handle special cases first (DICOM, Zarr/NGFF, OME-TIFF) if is_dicom(input_path): return DICOMWSIReader(input_path, mpp=mpp, power=power) @@ -1432,7 +1436,6 @@ def save_tiles( >>> slide_param = wsi.info """ - if verbose: logger.setLevel(logging.DEBUG) From d711ed158969054a89b473b98e677eab0b575e80 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 5 Apr 2023 09:12:52 +0100 Subject: [PATCH 06/16] =?UTF-8?q?=F0=9F=93=8C=20Pin=20Pandas=20Version=20t?= =?UTF-8?q?o=20`>=3D2.0.0`=20-=20Pin=20Pandas=20Version=20to=20`>=3D2.0.0`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 2 +- tiatoolbox/annotation/storage.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index fdab6d53e..f85adf366 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ matplotlib>=3.6.2 numpy>=1.23.5, <1.24 # v1.24 produces error on Windows opencv-python>=4.6.0 openslide-python>=1.2.0 -pandas>=1.5.2 +pandas>=2.0.0 pillow>=9.3.0 pydicom>=2.3.1 # Used by wsidicom pyyaml>=6.0 diff --git a/tiatoolbox/annotation/storage.py b/tiatoolbox/annotation/storage.py index df3e03fd4..2d1a1d8ae 100644 --- a/tiatoolbox/annotation/storage.py +++ b/tiatoolbox/annotation/storage.py @@ -2768,7 +2768,7 @@ def to_dataframe(self) -> pd.DataFrame: } for key, annotation in self.items() ) - df = df.append(pd.json_normalize(df_rows)) + df = pd.concat([df, pd.json_normalize(df_rows)]) return df.set_index("key") def features(self) -> Generator[Dict[str, Any], None, None]: From 3aec8e0a50e5c45c119406da4d1920e4ef369a98 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 6 Apr 2023 13:33:33 +0100 Subject: [PATCH 07/16] :adhesive_bandage: add ignore_resolutions for compatibility --- tiatoolbox/models/engine/patch_predictor.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git 
a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 36eeefa20..0c9af2ca7 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -773,6 +773,7 @@ def predict( merge_predictions: bool = False, save_dir: bool = None, save_output: bool = False, + ignore_resolutions: bool = False, ): """Make a prediction for a list of input data. @@ -829,7 +830,11 @@ def predict( where the running script is invoked. save_output (bool): Whether to save output for a single file. default=False - + ignore_resolutions (bool): + Whether to ignore the resolution of the input images. + PatchPredictor won't rescale the input images and will + use them in the original resolution. + Works with `mode='tile'` only. Returns: (:class:`numpy.ndarray`, dict): Model predictions of the input dataset. If multiple @@ -887,6 +892,16 @@ def predict( ioconfig, patch_input_shape, stride_shape, resolution, units ) + if mode == "tile" and ignore_resolutions: + warnings.warn( + "WSIPatchDataset only reads image tile at " + '`units="baseline"`. Resolutions will be converted ' + "to baseline value. 
" + "Set ignore_resolutions to False to change this behaviour.", + stacklevel=2, + ) + ioconfig = ioconfig.to_baseline() + fx_list = ioconfig.scale_to_highest( ioconfig.input_resolutions, ioconfig.input_resolutions[0]["units"] ) From 15790fa9342ce3b8d17eaaa556b4ea83163ea398 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 6 Apr 2023 13:34:35 +0100 Subject: [PATCH 08/16] :white_check_mark: Add tests for different types of input for `WSIPatchDataset` --- tests/models/test_patch_predictor.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index a03e9d7a0..e328040d5 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -214,6 +214,9 @@ def test_patch_dataset_crash(tmp_path): ): predefined_preproc_func("secret-dataset") + with pytest.raises(ValueError, match=r".*Cannot load image data from.*"): + _ = PatchDataset(-1) # not a file path or image + def test_wsi_patch_dataset(sample_wsi_dict, tmp_path): """A test for creation and bare output.""" @@ -451,7 +454,7 @@ def __getitem__(self, idx): units="mpp", ) - ds = WSIPatchDataset( + ds_from_fp = WSIPatchDataset( input_img=mini_wsi_svs, mask=VirtualWSIReader(positive_mask, mode="bool"), mode="wsi", @@ -462,7 +465,21 @@ def __getitem__(self, idx): units="mpp", ) - assert len(ds) > 0 + assert len(ds_from_fp) > 0 + + mini_wsi_svs_np = imread(mini_wsi_svs) + ds_from_np = WSIPatchDataset( + input_img=mini_wsi_svs_np, + mask=VirtualWSIReader(positive_mask, mode="bool"), + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="baseline", + ) + + assert len(ds_from_np) > 0 def test_patch_dataset_abc(): From 50cc400fc0c527aaa3fbae6cea37d5c0a84baf09 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 6 Apr 2023 14:16:33 +0100 Subject: [PATCH 09/16] :recycle: removed the blank line --- 
tiatoolbox/models/engine/patch_predictor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 0c9af2ca7..ca539c5f9 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -864,7 +864,6 @@ def predict( ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} """ - if mode not in ["patch", "wsi", "tile"]: raise ValueError( f"{mode} is not a valid mode. Use either `patch`, `tile` or `wsi`" From 95f8faa3b1d9aa81c3a8d736bed8e27759ba9ce5 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 6 Apr 2023 14:32:34 +0100 Subject: [PATCH 10/16] :white_check_mark: Fix tests for `PatchDataset` --- tests/models/test_patch_predictor.py | 3 --- tiatoolbox/models/dataset/dataset_abc.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index e328040d5..ff33bdded 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -214,9 +214,6 @@ def test_patch_dataset_crash(tmp_path): ): predefined_preproc_func("secret-dataset") - with pytest.raises(ValueError, match=r".*Cannot load image data from.*"): - _ = PatchDataset(-1) # not a file path or image - def test_wsi_patch_dataset(sample_wsi_dict, tmp_path): """A test for creation and bare output.""" diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 22e3c9cc0..c57fcbc69 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -109,11 +109,6 @@ def load_img(input_img: Union[str, pathlib.Path, np.ndarray]) -> np.ndarray: np.ndarray: image data. """ - if not isinstance(input_img, (str, pathlib.Path, np.ndarray)): - raise ValueError( - f"Cannot load image data from `{type(input_img)}` objects." 
- ) - if isinstance(input_img, np.ndarray): return input_img From bb063d739023b955286223a97cf7e2a59bdbd6a7 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 6 Apr 2023 15:32:17 +0100 Subject: [PATCH 11/16] :rewind: Undo _prepare_save_dir bugfix --- tiatoolbox/models/engine/patch_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index ca539c5f9..aa51ea75b 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -543,7 +543,7 @@ def _prepare_save_dir(save_dir, imgs): save_dir = pathlib.Path(save_dir) save_dir.mkdir(parents=True, exist_ok=False) - return save_dir or pathlib.Path(os.getcwd()) + return save_dir def _predict_patch( self, input_imgs, labels, return_probabilities, return_labels, on_gpu From 2823b4bb0c5c802ef29a8e55b261a1136892b93d Mon Sep 17 00:00:00 2001 From: blaginin Date: Sat, 8 Apr 2023 21:57:27 +0100 Subject: [PATCH 12/16] :white_check_mark: Add tests for ignore_resolutions mode --- tests/models/test_patch_predictor.py | 9 +++++++++ tiatoolbox/models/dataset/dataset_abc.py | 3 --- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index ff33bdded..3dfd5d7b4 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -915,6 +915,15 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path): ) _rm_dir(kwargs["save_dir"]) + _ = predictor.predict( + img_objects, + masks=[mini_wsi_msk, mini_wsi_msk], + mode="tile", + ignore_resolutions=True, + **kwargs, + ) + _rm_dir(kwargs["save_dir"]) + _kwargs = copy.deepcopy(kwargs) _kwargs["units"] = "baseline" _kwargs["resolution"] = 1.0 diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index c57fcbc69..f5f51b870 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ 
b/tiatoolbox/models/dataset/dataset_abc.py @@ -109,9 +109,6 @@ def load_img(input_img: Union[str, pathlib.Path, np.ndarray]) -> np.ndarray: np.ndarray: image data. """ - if isinstance(input_img, np.ndarray): - return input_img - path = pathlib.Path(input_img) if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"): From 5e9ae4fec8ae1f2ac274c0995ecdb4219295e706 Mon Sep 17 00:00:00 2001 From: blaginin Date: Tue, 2 May 2023 21:54:30 +0100 Subject: [PATCH 13/16] :recycle: refactor patch predictor and related methods --- requirements/requirements.txt | 2 +- tests/models/test_patch_predictor.py | 2 +- tiatoolbox/cli/patch_predictor.py | 2 +- tiatoolbox/models/dataset/classification.py | 4 ++-- tiatoolbox/models/engine/patch_predictor.py | 20 ++++++++-------- tiatoolbox/wsicore/wsireader.py | 26 ++++++++++----------- 6 files changed, 28 insertions(+), 28 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 77358c6bb..999b445b8 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -14,7 +14,7 @@ opencv-python>=4.6.0 openslide-python>=1.2.0 pandas>=2.0.0 pillow>=9.3.0 -pydicom>=2.3.1 # Used by wsidef test_store_reader_no_types(tmp_path, remote_sample):dicom +pydicom>=2.3.1 # Used by wsidicom pyyaml>=6.0 requests>=2.28.1 scikit-image>=0.20 diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index f71341270..e7b4693a5 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -764,7 +764,7 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path): # test prediction predictor = PatchPredictor(model=model, batch_size=1, verbose=False) output = predictor.predict( - input_imgs=inputs, + imgs=inputs, return_probabilities=True, labels=[1, "a"], return_labels=True, diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index 97021e672..a0fd49eaf 100644 --- 
a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -87,7 +87,7 @@ def patch_predictor( ) output = predictor.predict( - input_imgs=files_all, + imgs=files_all, masks=masks_all, mode=mode, return_probabilities=return_probabilities, diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index 265740e50..c888de782 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -10,7 +10,7 @@ from tiatoolbox.models.dataset import dataset_abc from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils.misc import imread -from tiatoolbox.wsicore.wsireader import Units, VirtualWSIReader, WSIReader +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader class _TorchPreprocCaller: @@ -154,7 +154,7 @@ def __init__( patch_input_shape: Union[Tuple[int, int], np.ndarray] = None, stride_shape: Union[Tuple[int, int], np.ndarray] = None, resolution: float = 1, - units: Units = "baseline", + units: str = "baseline", auto_get_mask=True, min_mask_ratio=0, preproc_func=None, diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index aa51ea75b..3ff758c32 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -16,7 +16,7 @@ from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig from tiatoolbox.utils import misc from tiatoolbox.utils.misc import save_as_json -from tiatoolbox.wsicore.wsireader import Units, WSIReader +from tiatoolbox.wsicore.wsireader import WSIReader class IOPatchPredictorConfig(IOSegmentorConfig): @@ -758,7 +758,7 @@ def _save_output( def predict( self, - input_imgs: List[Union[str, pathlib.Path, np.ndarray, WSIReader]], + imgs: List[Union[str, pathlib.Path, np.ndarray, WSIReader]], masks: List[Union[str, pathlib.Path, np.ndarray, WSIReader]] = None, labels: List = None, mode: Literal["patch", 
"tile", "wsi"] = "patch", @@ -769,7 +769,7 @@ def predict( patch_input_shape: Tuple[int, int] = None, stride_shape: Tuple[int, int] = None, resolution: float = None, - units: Units = None, + units: str = None, merge_predictions: bool = False, save_dir: bool = None, save_output: bool = False, @@ -778,7 +778,7 @@ def predict( """Make a prediction for a list of input data. Args: - input_imgs (list): + imgs (list): List of inputs to process. Must be either a list of images, a list of image file paths, WSIReader objects, or a numpy array of an image list. @@ -872,19 +872,19 @@ def predict( if mode == "patch" and masks is not None: raise ValueError("masks are not supported for `patch` mode. ") - if not isinstance(input_imgs, list): + if not isinstance(imgs, list): raise ValueError( "Input to `tile` and `wsi` mode must be a list of file paths." ) - if mode == "wsi" and masks is not None and len(masks) != len(input_imgs): + if mode == "wsi" and masks is not None and len(masks) != len(imgs): raise ValueError( - f"len(masks) != len(imgs) : " f"{len(masks)} != {len(input_imgs)}" + f"len(masks) != len(imgs) : " f"{len(masks)} != {len(imgs)}" ) if mode == "patch": return self._predict_patch( - input_imgs, labels, return_probabilities, return_labels, on_gpu + imgs, labels, return_probabilities, return_labels, on_gpu ) ioconfig = self._update_ioconfig( @@ -908,10 +908,10 @@ def predict( fx_list = sorted(fx_list, key=lambda x: x[0]) highest_input_resolution = fx_list[0][1] - save_dir = self._prepare_save_dir(save_dir, input_imgs) + save_dir = self._prepare_save_dir(save_dir, imgs) return self._predict_tile_wsi( - input_imgs, + imgs, masks, labels, mode, diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index bf40de390..d365f94f0 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -10,7 +10,7 @@ import re from datetime import datetime from numbers import Number -from typing import Iterable, List, Literal, Optional, 
Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import openslide @@ -36,7 +36,6 @@ Bounds = Tuple[Number, Number, Number, Number] IntBounds = Tuple[int, int, int, int] Resolution = Union[Number, Tuple[Number, Number], np.ndarray] -Units = Literal["mpp", "power", "baseline", "level"] MIN_NGFF_VERSION = Version("0.4") MAX_NGFF_VERSION = Version("0.4") @@ -282,6 +281,7 @@ def get_reader_by_filepath( Returns an appropriate :class:`.WSIReader` object based on the file extension. """ + # Handle special cases first (DICOM, Zarr/NGFF, OME-TIFF) if is_dicom(input_path): return DICOMWSIReader(input_path, mpp=mpp, power=power) @@ -440,7 +440,7 @@ def _info(self) -> WSIMeta: raise NotImplementedError def _find_optimal_level_and_downsample( - self, resolution: Resolution, units: Units, precision: int = 3 + self, resolution: Resolution, units: str, precision: int = 3 ) -> Tuple[int, np.ndarray]: """Find the optimal level to read at for a desired resolution and units. @@ -502,7 +502,7 @@ def find_read_rect_params( location: IntPair, size: IntPair, resolution: Resolution, - units: Units, + units: str, precision: int = 3, ) -> Tuple[int, IntPair, IntPair, NumPair, IntPair]: """Find optimal parameters for reading a rect at a given resolution. @@ -579,7 +579,7 @@ def find_read_rect_params( ) def _find_read_params_at_resolution( - self, location: IntPair, size: IntPair, resolution: Resolution, units: Units + self, location: IntPair, size: IntPair, resolution: Resolution, units: str ) -> Tuple[int, NumPair, IntPair, IntPair, IntPair, IntPair]: """Works similarly to `_find_read_rect_params`. @@ -666,7 +666,7 @@ def _find_read_params_at_resolution( ) + output def _bounds_at_resolution_to_baseline( - self, bounds: Bounds, resolution: Resolution, units: Units + self, bounds: Bounds, resolution: Resolution, units: str ) -> Bounds: """Find corresponding bounds in baseline. 
@@ -694,7 +694,7 @@ def _bounds_at_resolution_to_baseline( return np.concatenate([tl_at_baseline, br_at_baseline]) # bounds at baseline def slide_dimensions( - self, resolution: Resolution, units: Units, precisions: int = 3 + self, resolution: Resolution, units: str, precisions: int = 3 ) -> IntPair: """Return the size of WSI at requested resolution. @@ -728,7 +728,7 @@ def slide_dimensions( return wsi_shape_at_resolution def _find_read_bounds_params( - self, bounds: Bounds, resolution: Resolution, units: Units, precision: int = 3 + self, bounds: Bounds, resolution: Resolution, units: str, precision: int = 3 ) -> Tuple[int, IntBounds, IntPair, IntPair, np.ndarray]: """Find optimal parameters for reading bounds at a given resolution. @@ -961,7 +961,7 @@ def _read_rect_at_resolution( location: NumPair, size: NumPair, resolution: Resolution = 0, - units: Units = "level", + units: str = "level", interpolation: str = "optimise", pad_mode: str = "constant", pad_constant_values: Union[Number, Iterable[NumPair]] = 0, @@ -993,7 +993,7 @@ def read_rect( location: IntPair, size: IntPair, resolution: Resolution = 0, - units: Units = "level", + units: str = "level", interpolation: str = "optimise", pad_mode: str = "constant", pad_constant_values: Union[Number, Iterable[NumPair]] = 0, @@ -1187,7 +1187,7 @@ def read_bounds( self, bounds: Bounds, resolution: Resolution = 0, - units: Units = "level", + units: str = "level", interpolation: str = "optimise", pad_mode: str = "constant", pad_constant_values: Union[Number, Iterable[NumPair]] = 0, @@ -1325,7 +1325,7 @@ def read_region(self, location: NumPair, level: int, size: IntPair) -> np.ndarra location=location, size=size, resolution=level, units="level" ) - def slide_thumbnail(self, resolution: Resolution = 1.25, units: Units = "power"): + def slide_thumbnail(self, resolution: Resolution = 1.25, units: str = "power"): """Read the whole slide image thumbnail (1.25x by default). 
For more information on resolution and units see @@ -1356,7 +1356,7 @@ def tissue_mask( self, method: str = "otsu", resolution: Resolution = 1.25, - units: Units = "power", + units: str = "power", **masker_kwargs, ) -> "VirtualWSIReader": """Create a tissue mask and wrap it in a VirtualWSIReader. From 5bcdc682a4210820ed0533def25c8fa8a2f8862a Mon Sep 17 00:00:00 2001 From: Dima Blaginin Date: Tue, 2 May 2023 22:59:09 +0100 Subject: [PATCH 14/16] :recycle: add a blank line to `get_reader_by_filepath` Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/wsicore/wsireader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index d365f94f0..27e14dd5d 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -280,6 +280,7 @@ def get_reader_by_filepath( """ Returns an appropriate :class:`.WSIReader` object based on the file extension. + """ # Handle special cases first (DICOM, Zarr/NGFF, OME-TIFF) From c31d07840be26e529eea31d2f4c9ee77d94e50ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 May 2023 21:59:39 +0000 Subject: [PATCH 15/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/wsicore/wsireader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index 27e14dd5d..74b725577 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -280,7 +280,7 @@ def get_reader_by_filepath( """ Returns an appropriate :class:`.WSIReader` object based on the file extension. 
- + """ # Handle special cases first (DICOM, Zarr/NGFF, OME-TIFF) From 11674c4fe93549bc317381124fc5fc1870ddf9ff Mon Sep 17 00:00:00 2001 From: blaginin Date: Mon, 8 May 2023 18:15:46 +0100 Subject: [PATCH 16/16] :bug: fix bool dtype --- tiatoolbox/models/dataset/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index c888de782..72a6c43ea 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -306,7 +306,7 @@ def _apply_mask( mask_reader = mask elif isinstance(mask, np.ndarray): - if mask.dtype != np.bool: + if mask.dtype != bool: raise ValueError( "`mask` must be binary, i.e. `ndarray.dtype` has to be bool" )