Skip to content

Commit cca7443

Browse files
AbishekRajVGpre-commit-ci[bot]shaneahmed
authored
♻️ Update model_to() and load_torch_model() methods in ModelABC (#733)
- Adds `model.to(device)` and `model.load_model_from_file()` functionality. --------- Signed-off-by: Shan E Ahmed Raza <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Shan E Ahmed Raza <[email protected]>
1 parent 4a041ae commit cca7443

File tree

2 files changed

+139
-51
lines changed

2 files changed

+139
-51
lines changed

tests/models/test_abc.py

Lines changed: 85 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,73 @@
44
from typing import TYPE_CHECKING
55

66
import pytest
7+
import torch
8+
from torch import nn
79

810
from tiatoolbox import rcParam
9-
from tiatoolbox.models.architecture import get_pretrained_model
11+
from tiatoolbox.models.architecture import (
12+
fetch_pretrained_weights,
13+
get_pretrained_model,
14+
)
1015
from tiatoolbox.models.models_abc import ModelABC
1116
from tiatoolbox.utils import env_detection as toolbox_env
1217

1318
if TYPE_CHECKING:
1419
import numpy as np
1520

1621

22+
class ProtoRaisesTypeError(ModelABC):
23+
"""Intentionally created to check for TypeError."""
24+
25+
# skipcq
26+
def __init__(self: Proto) -> None:
27+
"""Initialize ProtoRaisesTypeError."""
28+
super().__init__()
29+
30+
@staticmethod
31+
# skipcq
32+
def infer_batch() -> None:
33+
"""Define infer batch."""
34+
# base class definition pass
35+
36+
37+
class ProtoNoPostProcess(ModelABC):
38+
"""Intentionally created to check No Post Processing."""
39+
40+
def forward(self: ProtoNoPostProcess) -> None:
41+
"""Define forward function."""
42+
43+
@staticmethod
44+
# skipcq
45+
def infer_batch() -> None:
46+
"""Define infer batch."""
47+
48+
49+
class Proto(ModelABC):
50+
"""Intentionally created to check error."""
51+
52+
def __init__(self: Proto) -> None:
53+
"""Initialize Proto."""
54+
super().__init__()
55+
self.dummy_param = nn.Parameter(torch.empty(0))
56+
57+
@staticmethod
58+
# skipcq
59+
def postproc(image: np.ndarray) -> np.ndarray:
60+
"""Define postproc function."""
61+
return image - 2
62+
63+
# skipcq
64+
def forward(self: Proto) -> None:
65+
"""Define forward function."""
66+
67+
@staticmethod
68+
# skipcq
69+
def infer_batch() -> None:
70+
"""Define infer batch."""
71+
pass # base class definition pass # noqa: PIE790
72+
73+
1774
@pytest.mark.skipif(
1875
toolbox_env.running_on_ci() or not toolbox_env.has_gpu(),
1976
reason="Local test on machine with GPU.",
@@ -25,67 +82,37 @@ def test_get_pretrained_model() -> None:
2582
get_pretrained_model(pretrained_name, overwrite=True)
2683

2784

85+
@pytest.mark.skipif(
86+
toolbox_env.running_on_ci() or not toolbox_env.has_gpu(),
87+
reason="Local test on CLI",
88+
)
89+
def test_model_to_cuda() -> None:
90+
"""This Test should pass locally if GPU is available."""
91+
# Test on GPU
92+
# no GPU on Travis so this will crash
93+
model = Proto() # skipcq
94+
assert model.dummy_param.device.type == "cpu"
95+
model = model.to(device="cuda")
96+
assert isinstance(model, nn.Module)
97+
assert model.dummy_param.device.type == "cuda"
98+
99+
28100
def test_model_abc() -> None:
29101
"""Test API in model ABC."""
30102
# test missing definition for abstract
31103
with pytest.raises(TypeError):
32104
# crash due to not defining forward, infer_batch, postproc
33105
ModelABC() # skipcq
34106

35-
# intentionally created to check error
36-
# skipcq
37-
class Proto(ModelABC):
38-
# skipcq
39-
def __init__(self: Proto) -> None:
40-
super().__init__()
41-
42-
@staticmethod
43-
# skipcq
44-
def infer_batch() -> None:
45-
pass # base class definition pass
46-
47107
# skipcq
48108
with pytest.raises(TypeError):
49109
# crash due to not defining forward and postproc
50-
Proto() # skipcq
110+
ProtoRaisesTypeError() # skipcq
51111

52-
# intentionally create to check inheritance
53-
# skipcq
54-
class Proto(ModelABC):
55-
# skipcq
56-
def forward(self: Proto) -> None:
57-
pass # base class definition pass
58-
59-
@staticmethod
60-
# skipcq
61-
def infer_batch() -> None:
62-
pass # base class definition pass
63-
64-
model = Proto()
112+
model = ProtoNoPostProcess()
65113
assert model.preproc(1) == 1, "Must be unchanged!"
66114
assert model.postproc(1) == 1, "Must be unchanged!"
67115

68-
# intentionally created to check error
69-
# skipcq
70-
class Proto(ModelABC):
71-
# skipcq
72-
def __init__(self: Proto) -> None:
73-
super().__init__()
74-
75-
@staticmethod
76-
# skipcq
77-
def postproc(image: np.ndarray) -> None:
78-
return image - 2
79-
80-
# skipcq
81-
def forward(self: Proto) -> None:
82-
pass # base class definition pass
83-
84-
@staticmethod
85-
# skipcq
86-
def infer_batch() -> None:
87-
pass # base class definition pass
88-
89116
model = Proto() # skipcq
90117
# test assign un-callable to preproc_func/postproc_func
91118
with pytest.raises(ValueError, match=r".*callable*"):
@@ -111,3 +138,13 @@ def infer_batch() -> None:
111138
# coverage setter check
112139
model.postproc_func = None # skipcq: PYL-W0201
113140
assert model.postproc_func(2) == 0
141+
142+
# Test on CPU
143+
model = model.to(device="cpu")
144+
assert isinstance(model, nn.Module)
145+
assert model.dummy_param.device.type == "cpu"
146+
147+
# Test load_weights_from_file() method
148+
weights_path = fetch_pretrained_weights("alexnet-kather100k")
149+
with pytest.raises(RuntimeError, match=r".*loading state_dict*"):
150+
_ = model.load_weights_from_file(weights_path)

tiatoolbox/models/models_abc.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
from abc import ABC, abstractmethod
55
from typing import TYPE_CHECKING, Any, Callable
66

7-
from torch import nn
7+
import torch
8+
from torch import device as torch_device
89

910
if TYPE_CHECKING: # pragma: no cover
11+
from pathlib import Path
12+
1013
import numpy as np
1114

1215

@@ -31,7 +34,7 @@ def output_resolutions(self: IOConfigABC) -> None:
3134
raise NotImplementedError
3235

3336

34-
class ModelABC(ABC, nn.Module):
37+
class ModelABC(ABC, torch.nn.Module):
3538
"""Abstract base class for models used in tiatoolbox."""
3639

3740
def __init__(self: ModelABC) -> None:
@@ -48,7 +51,12 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:
4851

4952
@staticmethod
5053
@abstractmethod
51-
def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> None:
54+
def infer_batch(
55+
model: torch.nn.Module,
56+
batch_data: np.ndarray,
57+
*,
58+
on_gpu: bool,
59+
) -> None:
5260
"""Run inference on an input batch.
5361
5462
Contains logic for forward operation as well as I/O aggregation.
@@ -135,3 +143,46 @@ def postproc_func(self: ModelABC, func: Callable) -> None:
135143
self._postproc = self.postproc
136144
else:
137145
self._postproc = func
146+
147+
def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module:
148+
"""Transfers model to cpu/gpu.
149+
150+
Args:
151+
model (torch.nn.Module):
152+
PyTorch defined model.
153+
device (str):
154+
Transfers model to the specified device. Default is "cpu".
155+
156+
Returns:
157+
torch.nn.Module:
158+
The model after being moved to cpu/gpu.
159+
160+
"""
161+
device = torch_device(device)
162+
model = super().to(device)
163+
164+
# If target device istorch.cuda and more
165+
# than one GPU is available, use DataParallel
166+
if device.type == "cuda" and torch.cuda.device_count() > 1:
167+
model = torch.nn.DataParallel(model) # pragma: no cover
168+
169+
return model
170+
171+
def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Module:
172+
"""Helper function to load a torch model.
173+
174+
Args:
175+
self (ModelABC):
176+
A torch model as :class:`ModelABC`.
177+
weights (str or Path):
178+
Path to pretrained weights.
179+
180+
Returns:
181+
torch.nn.Module:
182+
Torch model with pretrained weights loaded on CPU.
183+
184+
"""
185+
# ! assume to be saved in single GPU mode
186+
# always load on to the CPU
187+
saved_state_dict = torch.load(weights, map_location="cpu")
188+
return super().load_state_dict(saved_state_dict, strict=True)

0 commit comments

Comments
 (0)