Skip to content

Commit d91518a

Browse files
committed
🔀 Merge develop into dev-define-engine-abc
2 parents 77921c9 + 2e9802b commit d91518a

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

tests/test_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from tests.test_annotation_stores import cell_polygon
2020
from tiatoolbox import utils
21+
from tiatoolbox.annotation.storage import SQLiteStore
2122
from tiatoolbox.models.architecture import fetch_pretrained_weights
2223
from tiatoolbox.utils import misc
2324
from tiatoolbox.utils.exceptions import FileNotSupportedError
@@ -734,6 +735,7 @@ def test_sub_pixel_read_incorrect_read_func_return() -> None:
734735
image = np.ones((10, 10))
735736

736737
def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001
738+
"""Dummy read function for tests."""
737739
return np.ones((5, 5))
738740

739741
with pytest.raises(ValueError, match="incorrect size"):
@@ -752,6 +754,7 @@ def test_sub_pixel_read_empty_read_func_return() -> None:
752754
image = np.ones((10, 10))
753755

754756
def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001
757+
"""Dummy read function for tests."""
755758
return np.ones((0, 0))
756759

757760
with pytest.raises(ValueError, match="is empty"):
@@ -1624,3 +1627,69 @@ def test_imwrite(tmp_path: Path) -> NoReturn:
16241627
tmp_path / "thisfolderdoesnotexist" / "test_imwrite.jpg",
16251628
img,
16261629
)
1630+
1631+
1632+
def test_patch_pred_store() -> None:
1633+
"""Test patch_pred_store."""
1634+
# Define a mock patch_output
1635+
patch_output = {
1636+
"predictions": [1, 0, 1],
1637+
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
1638+
"other": "other",
1639+
}
1640+
1641+
store = misc.patch_pred_store(patch_output, (1.0, 1.0))
1642+
1643+
# Check that its an SQLiteStore containing the expected annotations
1644+
assert isinstance(store, SQLiteStore)
1645+
assert len(store) == 3
1646+
for annotation in store.values():
1647+
assert annotation.geometry.area == 1
1648+
assert annotation.properties["type"] in [0, 1]
1649+
assert "other" not in annotation.properties
1650+
1651+
patch_output.pop("coordinates")
1652+
# check correct error is raised if coordinates are missing
1653+
with pytest.raises(ValueError, match="coordinates"):
1654+
misc.patch_pred_store(patch_output, (1.0, 1.0))
1655+
1656+
1657+
def test_patch_pred_store_cdict() -> None:
1658+
"""Test patch_pred_store with a class dict."""
1659+
# Define a mock patch_output
1660+
patch_output = {
1661+
"predictions": [1, 0, 1],
1662+
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
1663+
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
1664+
"labels": [1, 0, 1],
1665+
"other": "other",
1666+
}
1667+
class_dict = {0: "class0", 1: "class1"}
1668+
store = misc.patch_pred_store(patch_output, (1.0, 1.0), class_dict=class_dict)
1669+
1670+
# Check that its an SQLiteStore containing the expected annotations
1671+
assert isinstance(store, SQLiteStore)
1672+
assert len(store) == 3
1673+
for annotation in store.values():
1674+
assert annotation.geometry.area == 1
1675+
assert annotation.properties["label"] in ["class0", "class1"]
1676+
assert annotation.properties["type"] in ["class0", "class1"]
1677+
assert "other" not in annotation.properties
1678+
1679+
1680+
def test_patch_pred_store_sf() -> None:
1681+
"""Test patch_pred_store with scale factor."""
1682+
# Define a mock patch_output
1683+
patch_output = {
1684+
"predictions": [1, 0, 1],
1685+
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
1686+
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
1687+
"labels": [1, 0, 1],
1688+
}
1689+
store = misc.patch_pred_store(patch_output, (2.0, 2.0))
1690+
1691+
# Check that its an SQLiteStore containing the expected annotations
1692+
assert isinstance(store, SQLiteStore)
1693+
assert len(store) == 3
1694+
for annotation in store.values():
1695+
assert annotation.geometry.area == 4

tiatoolbox/utils/misc.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import yaml
1818
from filelock import FileLock
1919
from shapely.affinity import translate
20+
from shapely.geometry import Polygon
2021
from shapely.geometry import shape as feature2geometry
2122
from skimage import exposure
2223

@@ -859,7 +860,8 @@ def select_device(*, on_gpu: bool) -> str:
859860
"""Selects the appropriate device as requested.
860861
861862
Args:
862-
on_gpu (bool): Selects gpu if True.
863+
on_gpu (bool):
864+
Selects gpu if True.
863865
864866
Returns:
865867
str:
@@ -1174,3 +1176,66 @@ def add_from_dat(
11741176

11751177
logger.info("Added %d annotations.", len(anns))
11761178
store.append_many(anns)
1179+
1180+
1181+
def patch_pred_store(
1182+
patch_output: dict,
1183+
scale_factor: tuple[int, int],
1184+
class_dict: dict | None = None,
1185+
) -> AnnotationStore:
1186+
"""Create an SQLiteStore containing Annotations for each patch.
1187+
1188+
Args:
1189+
patch_output (dict): A dictionary of patch prediction information. Important
1190+
keys are "probabilities", "predictions", "coordinates", and "labels".
1191+
scale_factor (tuple[int, int]): The scale factor to use when loading the
1192+
annotations. All coordinates will be multiplied by this factor to allow
1193+
conversion of annotations saved at non-baseline resolution to baseline.
1194+
Should be model_mpp/slide_mpp.
1195+
class_dict (dict): Optional dictionary mapping class indices to class names.
1196+
1197+
Returns:
1198+
SQLiteStore: An SQLiteStore containing Annotations for each patch.
1199+
1200+
"""
1201+
if "coordinates" not in patch_output:
1202+
# we cant create annotations without coordinates
1203+
msg = "Patch output must contain coordinates."
1204+
raise ValueError(msg)
1205+
# get relevant keys
1206+
class_probs = patch_output.get("probabilities", [])
1207+
preds = patch_output.get("predictions", [])
1208+
patch_coords = np.array(patch_output.get("coordinates", []))
1209+
if not np.all(np.array(scale_factor) == 1):
1210+
patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp
1211+
labels = patch_output.get("labels", [])
1212+
# get classes to consider
1213+
if len(class_probs) == 0:
1214+
classes_predicted = np.unique(preds).tolist()
1215+
else:
1216+
classes_predicted = range(len(class_probs[0]))
1217+
if class_dict is None:
1218+
# if no class dict create a default one
1219+
class_dict = {i: i for i in np.unique(preds + labels).tolist()}
1220+
annotations = []
1221+
# find what keys we need to save
1222+
keys = ["predictions"]
1223+
keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output]
1224+
1225+
# put patch predictions into a store
1226+
annotations = []
1227+
for i, pred in enumerate(preds):
1228+
if "probabilities" in keys:
1229+
props = {
1230+
f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted
1231+
}
1232+
else:
1233+
props = {}
1234+
if "labels" in keys:
1235+
props["label"] = class_dict[labels[i]]
1236+
props["type"] = class_dict[pred]
1237+
annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props))
1238+
store = SQLiteStore()
1239+
keys = store.append_many(annotations, [str(i) for i in range(len(annotations))])
1240+
1241+
return store

0 commit comments

Comments
 (0)