Skip to content

Commit b542c9a

Browse files
✨ Define SemanticSegmentor with the New EngineABC (#866)
## Summary of Changes ### Major Additions - **Dask Integration:** - Added `dask` as a dependency and integrated Dask arrays and lazy computation throughout the engine and patch predictor code. - Added Dask-based merging, chunking, and memory-aware processing for large images and WSIs. - **Zarr Output Support:** - Added support for saving model predictions and intermediate results directly to Zarr format. - New CLI options and internal logic for Zarr output, including memory thresholding and chunked writes. - **SemanticSegmentor Engine:** - Added a new `SemanticSegmentor` engine with Dask/Zarr support and new test coverage (`test_semantic_segmentor.py`). - Added CLI entrypoint for `semantic_segmentor` and removed the old `semantic_segment` CLI. - **Enhanced CLI and Config:** - Added CLI options for memory threshold, unified worker options, and improved mask handling. - Updated YAML configs and sample data for new models and test images. - **Utilities and Validation:** - Added utility functions for minimal dtype casting, patch/stride validation, and improved error handling (e.g., `DimensionMismatchError`). - Improved annotation store conversion for Dask arrays and Zarr-backed outputs. - **Changes to `kwarg`** - Add `memory-threshold` - Unified `num-loader-workers` and `num-postproc-workers` into `num-workers` - Removed `cache_mode` as cache mode is automatically handled. --- ### Major Removals/Refactors - **Removed Old CLI and Redundant Code:** - Deleted the old `semantic_segment.py` CLI and replaced it with `semantic_segmentor.py`. - Removed legacy cache mode and patch prediction Zarr store tests. - **Refactored Model and Dataset APIs:** - Unified and simplified model inference APIs to always return arrays (not dicts) for batch outputs. - Refactored dataset classes to enforce patch shape validation and remove legacy “mode” logic. - **Test Cleanup:** - Removed or updated tests that relied on old APIs or cache mode. - Refactored test assertions for new output types and Dask array handling. - **API Consistency:** - Standardized function and argument names across engines, CLI, and utility modules. - Updated docstrings and type hints for clarity and consistency. --- ### Notable File Changes - **New:** - `tiatoolbox/cli/semantic_segmentor.py` - `tests/engines/test_semantic_segmentor.py` - **Removed:** - `tiatoolbox/cli/semantic_segment.py` - Old cache mode and patch Zarr store tests - **Heavily Modified:** - `engine_abc.py`, `patch_predictor.py`, `semantic_segmentor.py` - CLI modules and test suites - Dataset and utility modules for Dask/Zarr compatibility --- ### Impact - Enables scalable, parallel, and memory-efficient inference and output saving for large images. - Simplifies downstream analysis by supporting Zarr as a native output format. - Lays the groundwork for further Dask-based optimizations in TIAToolbox. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 31b7995 commit b542c9a

29 files changed

+3318
-2742
lines changed

requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ aiohttp>=3.8.1
44
albumentations>=1.3.0
55
bokeh>=3.1.1, <3.6.0
66
Click>=8.1.3, <8.2.0
7+
dask>=2025.10.0
78
defusedxml>=0.7.1
89
filelock>=3.9.0
910
flask>=2.2.2

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def sample_wsi_dict(remote_sample: Callable) -> dict:
534534
"wsi4_4k_4k_svs",
535535
"wsi3_20k_20k_pred",
536536
"wsi4_4k_4k_pred",
537+
"wsi4_1k_1k_svs",
537538
]
538539
return {name: remote_sample(name) for name in file_names}
539540

tests/engines/test_engine_abc.py

Lines changed: 49 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import logging
77
import shutil
88
from pathlib import Path
9-
from typing import TYPE_CHECKING, NoReturn
9+
from typing import NoReturn
1010

1111
import numpy as np
1212
import pytest
13+
import torch
1314
import torchvision.models as torch_models
1415
from typing_extensions import Unpack
1516

@@ -26,8 +27,7 @@
2627
)
2728
from tiatoolbox.models.engine.io_config import ModelIOConfigABC
2829

29-
if TYPE_CHECKING:
30-
import torch.nn
30+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
3131

3232

3333
class TestEngineABC(EngineABC):
@@ -69,6 +69,8 @@ def post_process_wsi(
6969
"""Post process WSI output."""
7070
return super().post_process_wsi(
7171
raw_predictions=raw_predictions,
72+
prediction_shape=(self.batch_size, 1),
73+
prediction_dtype=int,
7274
**kwargs,
7375
)
7476

@@ -79,7 +81,7 @@ def infer_wsi(
7981
**kwargs: dict,
8082
) -> dict | np.ndarray:
8183
"""Test infer_wsi."""
82-
return super().infer_wsi(
84+
return super().infer_wsi( # skipcq: PYL-E1121
8385
dataloader,
8486
save_path,
8587
**kwargs,
@@ -169,26 +171,26 @@ def test_ioconfig() -> NoReturn:
169171

170172

171173
def test_prepare_engines_save_dir(
172-
tmp_path: pytest.TempPathFactory,
174+
track_tmp_path: pytest.TempPathFactory,
173175
caplog: pytest.LogCaptureFixture,
174176
) -> NoReturn:
175177
"""Test prepare save directory for engines."""
176178
out_dir = prepare_engines_save_dir(
177-
save_dir=tmp_path / "patch_output",
179+
save_dir=track_tmp_path / "patch_output",
178180
patch_mode=True,
179181
overwrite=False,
180182
)
181183

182-
assert out_dir == tmp_path / "patch_output"
184+
assert out_dir == track_tmp_path / "patch_output"
183185
assert out_dir.exists()
184186

185187
out_dir = prepare_engines_save_dir(
186-
save_dir=tmp_path / "patch_output",
188+
save_dir=track_tmp_path / "patch_output",
187189
patch_mode=True,
188190
overwrite=True,
189191
)
190192

191-
assert out_dir == tmp_path / "patch_output"
193+
assert out_dir == track_tmp_path / "patch_output"
192194
assert out_dir.exists()
193195

194196
out_dir = prepare_engines_save_dir(
@@ -209,43 +211,43 @@ def test_prepare_engines_save_dir(
209211
)
210212

211213
out_dir = prepare_engines_save_dir(
212-
save_dir=tmp_path / "wsi_single_output",
214+
save_dir=track_tmp_path / "wsi_single_output",
213215
patch_mode=False,
214216
overwrite=False,
215217
)
216218

217-
assert out_dir == tmp_path / "wsi_single_output"
219+
assert out_dir == track_tmp_path / "wsi_single_output"
218220
assert out_dir.exists()
219221
assert r"When providing multiple whole-slide images / tiles" not in caplog.text
220222

221223
out_dir = prepare_engines_save_dir(
222-
save_dir=tmp_path / "wsi_multiple_output",
224+
save_dir=track_tmp_path / "wsi_multiple_output",
223225
patch_mode=False,
224226
overwrite=False,
225227
)
226228

227-
assert out_dir == tmp_path / "wsi_multiple_output"
229+
assert out_dir == track_tmp_path / "wsi_multiple_output"
228230
assert out_dir.exists()
229231
assert r"When providing multiple whole slide images" in caplog.text
230232

231233
# test for file overwrite with Path.mkdirs() method
232234
out_path = prepare_engines_save_dir(
233-
save_dir=tmp_path / "patch_output" / "output.zarr",
235+
save_dir=track_tmp_path / "patch_output" / "output.zarr",
234236
patch_mode=True,
235237
overwrite=True,
236238
)
237239
assert out_path.exists()
238240

239241
out_path = prepare_engines_save_dir(
240-
save_dir=tmp_path / "patch_output" / "output.zarr",
242+
save_dir=track_tmp_path / "patch_output" / "output.zarr",
241243
patch_mode=True,
242244
overwrite=True,
243245
)
244246
assert out_path.exists()
245247

246248
with pytest.raises(FileExistsError):
247249
out_path = prepare_engines_save_dir(
248-
save_dir=tmp_path / "patch_output" / "output.zarr",
250+
save_dir=track_tmp_path / "patch_output" / "output.zarr",
249251
patch_mode=True,
250252
overwrite=False,
251253
)
@@ -362,16 +364,16 @@ def test_engine_run_with_verbose() -> NoReturn:
362364
out = eng.run(
363365
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
364366
labels=list(range(10)),
365-
on_gpu=False,
367+
device=device,
366368
)
367369

368370
assert "probabilities" in out
369371
assert "labels" in out
370372

371373

372-
def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn:
374+
def test_patch_pred_zarr_store(track_tmp_path: pytest.TempPathFactory) -> NoReturn:
373375
"""Test the engine run and patch pred store."""
374-
save_dir = tmp_path / "patch_output"
376+
save_dir = track_tmp_path / "patch_output"
375377

376378
eng = TestEngineABC(model="alexnet-kather100k")
377379
out = eng.run(
@@ -457,37 +459,6 @@ def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn:
457459
)
458460

459461

460-
def test_cache_mode_patches(tmp_path: pytest.TempPathFactory) -> NoReturn:
461-
"""Test the caching mode."""
462-
save_dir = tmp_path / "patch_output"
463-
464-
eng = TestEngineABC(model="alexnet-kather100k")
465-
out = eng.run(
466-
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
467-
on_gpu=False,
468-
save_dir=save_dir,
469-
overwrite=True,
470-
cache_mode=True,
471-
)
472-
assert out.exists(), "Zarr output file does not exist"
473-
474-
output_file_name = "output2.zarr"
475-
cache_size = 4
476-
out = eng.run(
477-
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
478-
on_gpu=False,
479-
save_dir=save_dir,
480-
overwrite=True,
481-
cache_mode=True,
482-
cache_size=4,
483-
batch_size=8,
484-
output_file=output_file_name,
485-
)
486-
assert out.stem == output_file_name.split(".")[0]
487-
assert eng.batch_size == cache_size
488-
assert out.exists(), "Zarr output file does not exist"
489-
490-
491462
def test_get_dataloader(sample_svs: Path) -> None:
492463
"""Test the get_dataloader function."""
493464
eng = TestEngineABC(model="alexnet-kather100k")
@@ -514,82 +485,84 @@ def test_get_dataloader(sample_svs: Path) -> None:
514485
assert isinstance(dataloader.dataset, WSIPatchDataset)
515486

516487

517-
def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
488+
def test_io_config_delegation(
489+
track_tmp_path: Path, caplog: pytest.LogCaptureFixture
490+
) -> None:
518491
"""Test for delegating args to io config."""
519492
# test not providing config / full input info for not pretrained models
520493
model = CNNModel("resnet50")
521494
eng = TestEngineABC(model=model)
522495

523496
kwargs = {
524-
"patch_input_shape": [512, 512],
497+
"patch_input_shape": [224, 224],
525498
"input_resolutions": [{"units": "mpp", "resolution": 1.75}],
526499
}
527500
with caplog.at_level(logging.WARNING):
528501
eng.run(
529502
np.zeros((10, 224, 224, 3)),
530503
patch_mode=True,
531-
save_dir=tmp_path / "dump",
504+
save_dir=track_tmp_path / "dump",
532505
patch_input_shape=kwargs["patch_input_shape"],
533506
input_resolutions=kwargs["input_resolutions"],
534507
)
535508
assert "provide a valid ModelIOConfigABC" in caplog.text
536-
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
509+
shutil.rmtree(track_tmp_path / "dump", ignore_errors=True)
537510

538511
# test providing config / full input info for non pretrained models
539512
ioconfig = ModelIOConfigABC(
540-
patch_input_shape=(512, 512),
513+
patch_input_shape=(224, 224),
541514
stride_shape=(256, 256),
542515
input_resolutions=[{"resolution": 1.35, "units": "mpp"}],
543516
)
544517
eng.run(
545518
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
546519
patch_mode=True,
547-
save_dir=f"{tmp_path}/dump",
520+
save_dir=f"{track_tmp_path}/dump",
548521
ioconfig=ioconfig,
549522
)
550-
assert eng._ioconfig.patch_input_shape == (512, 512)
523+
assert eng._ioconfig.patch_input_shape == (224, 224)
551524
assert eng._ioconfig.stride_shape == (256, 256)
552525
assert eng._ioconfig.input_resolutions == [{"resolution": 1.35, "units": "mpp"}]
553-
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
526+
shutil.rmtree(track_tmp_path / "dump", ignore_errors=True)
554527

555528
eng.run(
556529
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
557530
patch_mode=True,
558-
save_dir=f"{tmp_path}/dump",
531+
save_dir=f"{track_tmp_path}/dump",
559532
**kwargs,
560533
)
561-
assert eng._ioconfig.patch_input_shape == [512, 512]
562-
assert eng._ioconfig.stride_shape == [512, 512]
534+
assert eng._ioconfig.patch_input_shape == [224, 224]
535+
assert eng._ioconfig.stride_shape == [224, 224]
563536
assert eng._ioconfig.input_resolutions == [{"resolution": 1.75, "units": "mpp"}]
564-
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
537+
shutil.rmtree(track_tmp_path / "dump", ignore_errors=True)
565538

566539
# test overwriting pretrained ioconfig
567540
eng = TestEngineABC(model="alexnet-kather100k")
568541
eng.run(
569-
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
542+
images=np.zeros((10, 300, 300, 3), dtype=np.uint8),
570543
patch_input_shape=(300, 300),
571544
stride_shape=(300, 300),
572545
input_resolutions=[{"units": "baseline", "resolution": 1.99}],
573546
patch_mode=True,
574-
save_dir=f"{tmp_path}/dump",
547+
save_dir=f"{track_tmp_path}/dump",
575548
)
576549
assert eng._ioconfig.patch_input_shape == (300, 300)
577550
assert eng._ioconfig.stride_shape == (300, 300)
578551
assert eng._ioconfig.input_resolutions[0]["resolution"] == 1.99
579552
assert eng._ioconfig.input_resolutions[0]["units"] == "baseline"
580-
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
553+
shutil.rmtree(track_tmp_path / "dump", ignore_errors=True)
581554

582555
eng.run(
583-
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
556+
images=np.zeros((10, 300, 300, 3), dtype=np.uint8),
584557
patch_input_shape=(300, 300),
585558
stride_shape=(300, 300),
586559
input_resolutions=None,
587560
patch_mode=True,
588-
save_dir=f"{tmp_path}/dump",
561+
save_dir=f"{track_tmp_path}/dump",
589562
)
590563
assert eng._ioconfig.patch_input_shape == (300, 300)
591564
assert eng._ioconfig.stride_shape == (300, 300)
592-
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
565+
shutil.rmtree(track_tmp_path / "dump", ignore_errors=True)
593566

594567
eng.ioconfig = None
595568
_ioconfig = eng._update_ioconfig(
@@ -618,3 +591,11 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
618591
stride_shape=(1, 1),
619592
input_resolutions=_kwargs["input_resolutions"],
620593
)
594+
595+
596+
def test_save_predictions_incorrect_output_type() -> None:
597+
"""Engine should raise TypeError if incorrect output type is requested."""
598+
eng = TestEngineABC(model="alexnet-kather100k")
599+
600+
with pytest.raises(TypeError, match=r".*Unsupported output type.* "):
601+
eng.save_predictions({"predictions": np.zeros((20, 9))}, output_type="random")

0 commit comments

Comments
 (0)