44from typing import TYPE_CHECKING
55
66import pytest
7+ import torch
8+ from torch import nn
79
810from tiatoolbox import rcParam
9- from tiatoolbox .models .architecture import get_pretrained_model
11+ from tiatoolbox .models .architecture import (
12+ fetch_pretrained_weights ,
13+ get_pretrained_model ,
14+ )
1015from tiatoolbox .models .models_abc import ModelABC
1116from tiatoolbox .utils import env_detection as toolbox_env
1217
1318if TYPE_CHECKING :
1419 import numpy as np
1520
1621
22+ class ProtoRaisesTypeError (ModelABC ):
23+ """Intentionally created to check for TypeError."""
24+
25+ # skipcq
26+ def __init__ (self : Proto ) -> None :
27+ """Initialize ProtoRaisesTypeError."""
28+ super ().__init__ ()
29+
30+ @staticmethod
31+ # skipcq
32+ def infer_batch () -> None :
33+ """Define infer batch."""
34+ # base class definition pass
35+
36+
37+ class ProtoNoPostProcess (ModelABC ):
38+ """Intentionally created to check No Post Processing."""
39+
40+ def forward (self : ProtoNoPostProcess ) -> None :
41+ """Define forward function."""
42+
43+ @staticmethod
44+ # skipcq
45+ def infer_batch () -> None :
46+ """Define infer batch."""
47+
48+
49+ class Proto (ModelABC ):
50+ """Intentionally created to check error."""
51+
52+ def __init__ (self : Proto ) -> None :
53+ """Initialize Proto."""
54+ super ().__init__ ()
55+ self .dummy_param = nn .Parameter (torch .empty (0 ))
56+
57+ @staticmethod
58+ # skipcq
59+ def postproc (image : np .ndarray ) -> np .ndarray :
60+ """Define postproc function."""
61+ return image - 2
62+
63+ # skipcq
64+ def forward (self : Proto ) -> None :
65+ """Define forward function."""
66+
67+ @staticmethod
68+ # skipcq
69+ def infer_batch () -> None :
70+ """Define infer batch."""
71+ pass # base class definition pass # noqa: PIE790
72+
73+
1774@pytest .mark .skipif (
1875 toolbox_env .running_on_ci () or not toolbox_env .has_gpu (),
1976 reason = "Local test on machine with GPU." ,
@@ -25,67 +82,37 @@ def test_get_pretrained_model() -> None:
2582 get_pretrained_model (pretrained_name , overwrite = True )
2683
2784
85+ @pytest .mark .skipif (
86+ toolbox_env .running_on_ci () or not toolbox_env .has_gpu (),
87+ reason = "Local test on CLI" ,
88+ )
89+ def test_model_to_cuda () -> None :
90+ """This Test should pass locally if GPU is available."""
91+ # Test on GPU
92+ # no GPU on Travis so this will crash
93+ model = Proto () # skipcq
94+ assert model .dummy_param .device .type == "cpu"
95+ model = model .to (device = "cuda" )
96+ assert isinstance (model , nn .Module )
97+ assert model .dummy_param .device .type == "cuda"
98+
99+
28100def test_model_abc () -> None :
29101 """Test API in model ABC."""
30102 # test missing definition for abstract
31103 with pytest .raises (TypeError ):
32104 # crash due to not defining forward, infer_batch, postproc
33105 ModelABC () # skipcq
34106
35- # intentionally created to check error
36- # skipcq
37- class Proto (ModelABC ):
38- # skipcq
39- def __init__ (self : Proto ) -> None :
40- super ().__init__ ()
41-
42- @staticmethod
43- # skipcq
44- def infer_batch () -> None :
45- pass # base class definition pass
46-
47107 # skipcq
48108 with pytest .raises (TypeError ):
49109 # crash due to not defining forward and postproc
50- Proto () # skipcq
110+ ProtoRaisesTypeError () # skipcq
51111
52- # intentionally create to check inheritance
53- # skipcq
54- class Proto (ModelABC ):
55- # skipcq
56- def forward (self : Proto ) -> None :
57- pass # base class definition pass
58-
59- @staticmethod
60- # skipcq
61- def infer_batch () -> None :
62- pass # base class definition pass
63-
64- model = Proto ()
112+ model = ProtoNoPostProcess ()
65113 assert model .preproc (1 ) == 1 , "Must be unchanged!"
66114 assert model .postproc (1 ) == 1 , "Must be unchanged!"
67115
68- # intentionally created to check error
69- # skipcq
70- class Proto (ModelABC ):
71- # skipcq
72- def __init__ (self : Proto ) -> None :
73- super ().__init__ ()
74-
75- @staticmethod
76- # skipcq
77- def postproc (image : np .ndarray ) -> None :
78- return image - 2
79-
80- # skipcq
81- def forward (self : Proto ) -> None :
82- pass # base class definition pass
83-
84- @staticmethod
85- # skipcq
86- def infer_batch () -> None :
87- pass # base class definition pass
88-
89116 model = Proto () # skipcq
90117 # test assign un-callable to preproc_func/postproc_func
91118 with pytest .raises (ValueError , match = r".*callable*" ):
@@ -111,3 +138,13 @@ def infer_batch() -> None:
111138 # coverage setter check
112139 model .postproc_func = None # skipcq: PYL-W0201
113140 assert model .postproc_func (2 ) == 0
141+
142+ # Test on CPU
143+ model = model .to (device = "cpu" )
144+ assert isinstance (model , nn .Module )
145+ assert model .dummy_param .device .type == "cpu"
146+
147+ # Test load_weights_from_file() method
148+ weights_path = fetch_pretrained_weights ("alexnet-kather100k" )
149+ with pytest .raises (RuntimeError , match = r".*loading state_dict*" ):
150+ _ = model .load_weights_from_file (weights_path )
0 commit comments