Skip to content

Commit 77921c9

Browse files
committed
🐛 Fix tiatoolbox/models/dataset/dataset_abc.py for annotations
1 parent 02115bf commit 77921c9

File tree

1 file changed

+40
-35
lines changed

1 file changed

+40
-35
lines changed

tiatoolbox/models/dataset/dataset_abc.py

Lines changed: 40 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -20,13 +20,14 @@
2020
from multiprocessing.managers import Namespace
2121

2222
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
23+
from tiatoolbox.typing import IntPair, Resolution, Units
2324

2425

2526
class PatchDatasetABC(ABC, torch.utils.data.Dataset):
2627
"""Define abstract base class for patch dataset."""
2728

2829
def __init__(
29-
self,
30+
self: PatchDatasetABC,
3031
) -> None:
3132
"""Initialize :class:`PatchDatasetABC`."""
3233
super().__init__()
@@ -36,7 +37,7 @@ def __init__(
3637
self.labels = []
3738

3839
@staticmethod
39-
def _check_shape_integrity(shapes):
40+
def _check_shape_integrity(shapes: list | np.ndarray) -> None:
4041
"""Checks the integrity of input shapes.
4142
4243
Args:
@@ -56,7 +57,7 @@ def _check_shape_integrity(shapes):
5657
msg = "Images must have the same dimensions."
5758
raise ValueError(msg)
5859

59-
def _check_input_integrity(self, mode):
60+
def _check_input_integrity(self: PatchDatasetABC, mode: str) -> None:
6061
"""Check that variables received during init are valid.
6162
6263
These checks include:
@@ -113,11 +114,15 @@ def _check_input_integrity(self, mode):
113114
raise ValueError(msg)
114115

115116
@staticmethod
116-
def load_img(path):
117+
def load_img(path: str | Path) -> np.ndarray:
117118
"""Load an image from a provided path.
118119
119120
Args:
120-
path (str): Path to an image file.
121+
path (str or Path): Path to an image file.
122+
123+
Returns:
124+
:class:`numpy.ndarray`:
125+
Image as a numpy array.
121126
122127
"""
123128
path = Path(path)
@@ -129,12 +134,12 @@ def load_img(path):
129134
return imread(path, as_uint8=False)
130135

131136
@staticmethod
132-
def preproc(image):
137+
def preproc(image: np.ndarray) -> np.ndarray:
133138
"""Define the pre-processing of this class of loader."""
134139
return image
135140

136141
@property
137-
def preproc_func(self):
142+
def preproc_func(self: PatchDatasetABC) -> Callable:
138143
"""Return the current pre-processing function of this instance.
139144
140145
The returned function is expected to behave as follows:
@@ -144,7 +149,7 @@ def preproc_func(self):
144149
return self._preproc
145150

146151
@preproc_func.setter
147-
def preproc_func(self, func):
152+
def preproc_func(self: PatchDatasetABC, func: Callable) -> None:
148153
"""Set the pre-processing function for this instance.
149154
150155
If `func=None`, the method will default to `self.preproc`.
@@ -162,12 +167,12 @@ def preproc_func(self, func):
162167
msg = f"{func} is not callable!"
163168
raise ValueError(msg)
164169

165-
def __len__(self) -> int:
170+
def __len__(self: PatchDatasetABC) -> int:
166171
"""Return the length of the instance attributes."""
167172
return len(self.inputs)
168173

169174
@abstractmethod
170-
def __getitem__(self, idx):
175+
def __getitem__(self: PatchDatasetABC, idx: int) -> None:
171176
"""Get an item from the dataset."""
172177
... # pragma: no cover
173178

@@ -213,12 +218,12 @@ class WSIStreamDataset(torch_data.Dataset):
213218
"""
214219

215220
def __init__(
216-
self,
221+
self: WSIStreamDataset,
217222
ioconfig: IOSegmentorConfig,
218223
wsi_paths: list[str | Path],
219224
mp_shared_space: Namespace,
220225
preproc: Callable[[np.ndarray], np.ndarray] | None = None,
221-
mode="wsi",
226+
mode: str = "wsi",
222227
) -> None:
223228
"""Initialize :class:`WSIStreamDataset`."""
224229
super().__init__()
@@ -240,7 +245,7 @@ def __init__(
240245
self.wsi_idx = None # to be received externally via thread communication
241246
self.reader = None
242247

243-
def _get_reader(self, img_path):
248+
def _get_reader(self: WSIStreamDataset, img_path: str | Path) -> WSIReader:
244249
"""Get appropriate reader for input path."""
245250
img_path = Path(img_path)
246251
if self.mode == "wsi":
@@ -261,12 +266,12 @@ def _get_reader(self, img_path):
261266
info=metadata,
262267
)
263268

264-
def __len__(self) -> int:
269+
def __len__(self: WSIStreamDataset) -> int:
265270
"""Return the length of the instance attributes."""
266271
return len(self.mp_shared_space.patch_inputs)
267272

268273
@staticmethod
269-
def collate_fn(batch):
274+
def collate_fn(batch: list | np.ndarray) -> torch.Tensor:
270275
"""Prototype to handle reading exception.
271276
272277
This will exclude any sample with `None` from the batch. As
@@ -278,7 +283,7 @@ def collate_fn(batch):
278283
batch = [v for v in batch if v is not None]
279284
return torch.utils.data.dataloader.default_collate(batch)
280285

281-
def __getitem__(self, idx: int):
286+
def __getitem__(self: WSIStreamDataset, idx: int) -> tuple:
282287
"""Get an item from the dataset."""
283288
# ! no need to lock as we do not modify source value in shared space
284289
if self.wsi_idx != self.mp_shared_space.wsi_idx:
@@ -341,18 +346,18 @@ class WSIPatchDataset(PatchDatasetABC):
341346
"""
342347

343348
def __init__( # noqa: PLR0913, PLR0915
344-
self,
345-
img_path,
346-
mode="wsi",
347-
mask_path=None,
348-
patch_input_shape=None,
349-
stride_shape=None,
350-
resolution=None,
351-
units=None,
352-
min_mask_ratio=0,
353-
preproc_func=None,
349+
self: WSIPatchDataset,
350+
img_path: str | Path,
351+
mode: str = "wsi",
352+
mask_path: str | Path | None = None,
353+
patch_input_shape: IntPair = None,
354+
stride_shape: IntPair = None,
355+
resolution: Resolution = None,
356+
units: Units = None,
357+
min_mask_ratio: float = 0,
358+
preproc_func: Callable | None = None,
354359
*,
355-
auto_get_mask=True,
360+
auto_get_mask: bool = True,
356361
) -> None:
357362
"""Create a WSI-level patch dataset.
358363
@@ -377,20 +382,20 @@ def __init__( # noqa: PLR0913, PLR0915
377382
stride shape to read at requested `resolution` and
378383
`units`. Expected to be positive and of (height, width).
379384
Note, this is not at level 0.
380-
resolution:
385+
resolution (Resolution):
381386
Check (:class:`.WSIReader`) for details. When
382387
`mode='tile'`, value is fixed to be `resolution=1.0` and
383388
`units='baseline'`. See (:class:`.WSIReader`) for
384389
details.
385-
units:
390+
units (Units):
386391
Units in which `resolution` is defined.
387-
auto_get_mask:
392+
auto_get_mask (bool):
388393
If `True`, then automatically get a simple threshold mask using the
389394
WSIReader.tissue_mask() function.
390-
min_mask_ratio:
395+
min_mask_ratio (float):
391396
Only patches with positive area percentage above this value are
392397
included. Defaults to 0.
393-
preproc_func:
398+
preproc_func (Callable):
394399
Preprocessing function used to transform the input data. If
395400
supplied, the function will be called on each patch before
396401
returning it.
@@ -521,7 +526,7 @@ def __init__( # noqa: PLR0913, PLR0915
521526
# Perform check on the input
522527
self._check_input_integrity(mode="wsi")
523528

524-
def __getitem__(self, idx):
529+
def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
525530
"""Get an item from the dataset."""
526531
coords = self.inputs[idx]
527532
# Read image patch from the whole-slide image
@@ -546,11 +551,11 @@ class PatchDataset(PatchDatasetABC):
546551
`torch.utils.data.Dataset` class.
547552
548553
Attributes:
549-
inputs:
554+
inputs (list or np.ndarray):
550555
Either a list of patches, where each patch is an ndarray, or a
551556
list of valid paths, each with an extension in (".jpg", ".jpeg",
552557
".tif", ".tiff", ".png") pointing to an image.
553-
labels:
558+
labels (list):
554559
List of labels for sample at the same index in `inputs`.
555560
Default is `None`.
556561

0 commit comments

Comments
 (0)