Skip to content

Commit e86f440

Browse files
committed
Improving engine_abc based on review comments
1 parent 1f85c50 commit e86f440

File tree

2 files changed

+30
-16
lines changed

2 files changed

+30
-16
lines changed

tests/engines/test_engine_abc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def test_engine_abc_incorrect_model_type() -> NoReturn:
6060
TypeError,
6161
match="Input model must be a string or 'torch.nn.Module'.",
6262
):
63-
# Can't instantiate abstract class with abstract methods
6463
TestEngineABC(model=1)
6564

6665

tiatoolbox/models/engine/engine_abc.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class EngineABC(ABC):
112112
Whether to output logging information.
113113
114114
Attributes:
115-
images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`):
115+
images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`):
116116
A NHWC image or a path to WSI.
117117
patch_mode (str):
118118
Whether to treat input image as a patch or WSI.
@@ -144,7 +144,7 @@ class EngineABC(ABC):
144144
from either `level`, `power` or `mpp`. Please see
145145
:obj:`WSIReader` for details.
146146
patch_input_shape (tuple):
147-
Size of patches input to the model. Patches are at
147+
Shape of patches input to the model as tupled of HW. Patches are at
148148
requested read resolution, not with respect to level 0,
149149
and must be positive.
150150
stride_shape (tuple):
@@ -176,11 +176,6 @@ class EngineABC(ABC):
176176
>>> engine = EngineABC(model="resnet18-kather100k")
177177
>>> output = engine.run(data, patch_mode=True)
178178
179-
>>> # list of 2 image patch files as input
180-
>>> data = ['path/img.png', 'path/img.png']
181-
>>> engine = EngineABC(model="resnet18-kather100k")
182-
>>> output = engine.run(data, patch_mode=False)
183-
184179
>>> # list of 2 image files as input
185180
>>> image = ['path/image1.png', 'path/image2.png']
186181
>>> engine = EngineABC(model="resnet18-kather100k")
@@ -343,23 +338,43 @@ def infer_patches(
343338

344339
return raw_predictions
345340

346-
def post_process_patches(
341+
def setup_patch_dataset(
347342
self: EngineABC,
348343
raw_predictions: dict,
349344
output_type: str,
350345
save_dir: Path | None = None,
351346
**kwargs: dict,
352347
) -> Path | AnnotationStore:
353-
"""Post-process image patches."""
354-
"""Stores as an Annotation Store or Zarr (default) and returns the Path"""
348+
"""Post-process image patches.
355349
356-
if not save_dir and self.patch_mode and output_type != "AnnotationStore":
357-
return raw_predictions
350+
Args:
351+
raw_predictions (dict):
352+
A dictionary of patch prediction information.
353+
save_dir (Path):
354+
Optional Output Path to directory to save the patch dataset output to a
355+
`.zarr` or `.db` file, provided patch_mode is True. if the patch_mode is
356+
False then save_dir is required.
357+
output_type (str):
358+
The desired output type for resulting patch dataset.
359+
**kwargs (dict):
360+
Keyword Args to update setup_patch_dataset() method attributes.
361+
362+
Returns: (dict, Path, :class:`SQLiteStore`):
363+
if the output_type is "AnnotationStore", the function returns the patch
364+
predictor output as an SQLiteStore containing Annotations for each or the
365+
Path to a `.db` file depending on whether a save_dir Path is provided.
366+
Otherwise, the function defaults to returning patch predictor output, either
367+
as a dict or the Path to a `.zarr` file depending on whether a save_dir Path
368+
is provided.
358369
359-
if not save_dir:
360-
msg = "`save_dir` not specified."
370+
"""
371+
if not save_dir and not self.patch_mode:
372+
msg = "`save_dir` must be specified when patch_mode is False."
361373
raise OSError(msg)
362374

375+
if not save_dir and output_type != "AnnotationStore":
376+
return raw_predictions
377+
363378
output_file = (
364379
kwargs["output_file"] and kwargs.pop("output_file")
365380
if "output_file" in kwargs
@@ -618,7 +633,7 @@ def run(
618633
raw_predictions = self.infer_patches(
619634
data_loader=data_loader,
620635
)
621-
return self.post_process_patches(
636+
return self.setup_patch_dataset(
622637
raw_predictions=raw_predictions,
623638
output_type=output_type,
624639
save_dir=save_dir,

0 commit comments

Comments
 (0)