2020 from multiprocessing .managers import Namespace
2121
2222 from tiatoolbox .models .engine .io_config import IOSegmentorConfig
23+ from tiatoolbox .typing import IntPair , Resolution , Units
2324
2425
2526class PatchDatasetABC (ABC , torch .utils .data .Dataset ):
2627 """Define abstract base class for patch dataset."""
2728
2829 def __init__ (
29- self ,
30+ self : PatchDatasetABC ,
3031 ) -> None :
3132 """Initialize :class:`PatchDatasetABC`."""
3233 super ().__init__ ()
@@ -36,7 +37,7 @@ def __init__(
3637 self .labels = []
3738
3839 @staticmethod
39- def _check_shape_integrity (shapes ) :
40+ def _check_shape_integrity (shapes : list | np . ndarray ) -> None :
4041 """Checks the integrity of input shapes.
4142
4243 Args:
@@ -56,7 +57,7 @@ def _check_shape_integrity(shapes):
5657 msg = "Images must have the same dimensions."
5758 raise ValueError (msg )
5859
59- def _check_input_integrity (self , mode ) :
60+ def _check_input_integrity (self : PatchDatasetABC , mode : str ) -> None :
6061 """Check that variables received during init are valid.
6162
6263 These checks include:
@@ -113,11 +114,15 @@ def _check_input_integrity(self, mode):
113114 raise ValueError (msg )
114115
115116 @staticmethod
116- def load_img (path ) :
117+ def load_img (path : str | Path ) -> np . ndarray :
117118 """Load an image from a provided path.
118119
119120 Args:
120- path (str): Path to an image file.
121+ path (str or Path): Path to an image file.
122+
123+ Returns:
124+ :class:`numpy.ndarray`:
125+ Image as a numpy array.
121126
122127 """
123128 path = Path (path )
@@ -129,12 +134,12 @@ def load_img(path):
129134 return imread (path , as_uint8 = False )
130135
131136 @staticmethod
132- def preproc (image ) :
137+ def preproc (image : np . ndarray ) -> np . ndarray :
133138 """Define the pre-processing of this class of loader."""
134139 return image
135140
136141 @property
137- def preproc_func (self ) :
142+ def preproc_func (self : PatchDatasetABC ) -> Callable :
138143 """Return the current pre-processing function of this instance.
139144
140145 The returned function is expected to behave as follows:
@@ -144,7 +149,7 @@ def preproc_func(self):
144149 return self ._preproc
145150
146151 @preproc_func .setter
147- def preproc_func (self , func ) :
152+ def preproc_func (self : PatchDatasetABC , func : Callable ) -> None :
148153 """Set the pre-processing function for this instance.
149154
150155 If `func=None`, the method will default to `self.preproc`.
@@ -162,12 +167,12 @@ def preproc_func(self, func):
162167 msg = f"{ func } is not callable!"
163168 raise ValueError (msg )
164169
165- def __len__ (self ) -> int :
170+ def __len__ (self : PatchDatasetABC ) -> int :
166171 """Return the length of the instance attributes."""
167172 return len (self .inputs )
168173
169174 @abstractmethod
170- def __getitem__ (self , idx ) :
175+ def __getitem__ (self : PatchDatasetABC , idx : int ) -> None :
171176 """Get an item from the dataset."""
172177 ... # pragma: no cover
173178
@@ -213,12 +218,12 @@ class WSIStreamDataset(torch_data.Dataset):
213218 """
214219
215220 def __init__ (
216- self ,
221+ self : WSIStreamDataset ,
217222 ioconfig : IOSegmentorConfig ,
218223 wsi_paths : list [str | Path ],
219224 mp_shared_space : Namespace ,
220225 preproc : Callable [[np .ndarray ], np .ndarray ] | None = None ,
221- mode = "wsi" ,
226+ mode : str = "wsi" ,
222227 ) -> None :
223228 """Initialize :class:`WSIStreamDataset`."""
224229 super ().__init__ ()
@@ -240,7 +245,7 @@ def __init__(
240245 self .wsi_idx = None # to be received externally via thread communication
241246 self .reader = None
242247
243- def _get_reader (self , img_path ) :
248+ def _get_reader (self : WSIStreamDataset , img_path : str | Path ) -> WSIReader :
244249 """Get appropriate reader for input path."""
245250 img_path = Path (img_path )
246251 if self .mode == "wsi" :
@@ -261,12 +266,12 @@ def _get_reader(self, img_path):
261266 info = metadata ,
262267 )
263268
264- def __len__ (self ) -> int :
269+ def __len__ (self : WSIStreamDataset ) -> int :
265270 """Return the length of the instance attributes."""
266271 return len (self .mp_shared_space .patch_inputs )
267272
268273 @staticmethod
269- def collate_fn (batch ) :
274+ def collate_fn (batch : list | np . ndarray ) -> torch . Tensor :
270275 """Prototype to handle reading exception.
271276
272277 This will exclude any sample with `None` from the batch. As
@@ -278,7 +283,7 @@ def collate_fn(batch):
278283 batch = [v for v in batch if v is not None ]
279284 return torch .utils .data .dataloader .default_collate (batch )
280285
281- def __getitem__ (self , idx : int ):
286+ def __getitem__ (self : WSIStreamDataset , idx : int ) -> tuple :
282287 """Get an item from the dataset."""
283288 # ! no need to lock as we do not modify source value in shared space
284289 if self .wsi_idx != self .mp_shared_space .wsi_idx :
@@ -341,18 +346,18 @@ class WSIPatchDataset(PatchDatasetABC):
341346 """
342347
343348 def __init__ ( # noqa: PLR0913, PLR0915
344- self ,
345- img_path ,
346- mode = "wsi" ,
347- mask_path = None ,
348- patch_input_shape = None ,
349- stride_shape = None ,
350- resolution = None ,
351- units = None ,
352- min_mask_ratio = 0 ,
353- preproc_func = None ,
349+ self : WSIPatchDataset ,
350+ img_path : str | Path ,
351+ mode : str = "wsi" ,
352+ mask_path : str | Path | None = None ,
353+ patch_input_shape : IntPair = None ,
354+ stride_shape : IntPair = None ,
355+ resolution : Resolution = None ,
356+ units : Units = None ,
357+ min_mask_ratio : float = 0 ,
358+ preproc_func : Callable | None = None ,
354359 * ,
355- auto_get_mask = True ,
360+ auto_get_mask : bool = True ,
356361 ) -> None :
357362 """Create a WSI-level patch dataset.
358363
@@ -377,20 +382,20 @@ def __init__( # noqa: PLR0913, PLR0915
377382 stride shape to read at requested `resolution` and
378383 `units`. Expected to be positive and of (height, width).
379384 Note, this is not at level 0.
380- resolution:
385+ resolution (Resolution) :
381386 Check (:class:`.WSIReader`) for details. When
382387 `mode='tile'`, value is fixed to be `resolution=1.0` and
383388 `units='baseline'` units: check (:class:`.WSIReader`) for
384389 details.
385- units:
390+ units (Units) :
386391 Units in which `resolution` is defined.
387- auto_get_mask:
392+ auto_get_mask (bool) :
388393 If `True`, then automatically get simple threshold mask using
389394 WSIReader.tissue_mask() function.
390- min_mask_ratio:
395+ min_mask_ratio (float) :
391396 Only patches with positive area percentage above this value are
392397 included. Defaults to 0.
393- preproc_func:
398+ preproc_func (Callable) :
394399 Preprocessing function used to transform the input data. If
395400 supplied, the function will be called on each patch before
396401 returning it.
@@ -521,7 +526,7 @@ def __init__( # noqa: PLR0913, PLR0915
521526 # Perform check on the input
522527 self ._check_input_integrity (mode = "wsi" )
523528
524- def __getitem__ (self , idx ) :
529+ def __getitem__ (self : WSIPatchDataset , idx : int ) -> dict :
525530 """Get an item from the dataset."""
526531 coords = self .inputs [idx ]
527532 # Read image patch from the whole-slide image
@@ -546,11 +551,11 @@ class PatchDataset(PatchDatasetABC):
546551 `torch.utils.data.Dataset` class.
547552
548553 Attributes:
549- inputs:
554+ inputs (list or np.ndarray) :
550555 Either a list of patches, where each patch is a ndarray or a
551556 list of valid path with its extension be (".jpg", ".jpeg",
552557 ".tif", ".tiff", ".png") pointing to an image.
553- labels:
558+ labels (list) :
554559 List of labels for sample at the same index in `inputs`.
555560 Default is `None`.
556561
0 commit comments