5
5
from torch .utils .data import TensorDataset , DataLoader
6
6
7
7
import torchensemble
8
+ from torchensemble .utils import io
8
9
from torchensemble .utils .logging import set_logger
9
10
10
11
24
25
torchensemble .AdversarialTrainingRegressor ]
25
26
26
27
28
+ # Remove randomness
29
+ np .random .seed (0 )
30
+ torch .manual_seed (0 )
31
+
27
32
set_logger ("pytest_all_models" )
28
33
29
34
@@ -66,12 +71,10 @@ def forward(self, X):
66
71
67
72
# Testing data
68
73
X_test = torch .Tensor (np .array (([0.5 , 0.5 ],
69
- [0.6 , 0.6 ],
70
- [0.7 , 0.7 ],
71
- [0.8 , 0.8 ])))
74
+ [0.6 , 0.6 ])))
72
75
73
- y_test_clf = torch .LongTensor (np .array (([1 , 1 , 0 , 0 ])))
74
- y_test_reg = torch .FloatTensor (np .array (([0.5 , 0.6 , 0.7 , 0.8 ])))
76
+ y_test_clf = torch .LongTensor (np .array (([1 , 0 ])))
77
+ y_test_reg = torch .FloatTensor (np .array (([0.5 , 0.6 ])))
75
78
y_test_reg = y_test_reg .view (- 1 , 1 )
76
79
77
80
@@ -94,9 +97,9 @@ def test_clf(clf):
94
97
95
98
# Prepare data
96
99
train = TensorDataset (X_train , y_train_clf )
97
- train_loader = DataLoader (train , batch_size = 2 )
100
+ train_loader = DataLoader (train , batch_size = 2 , shuffle = False )
98
101
test = TensorDataset (X_test , y_test_clf )
99
- test_loader = DataLoader (test , batch_size = 2 )
102
+ test_loader = DataLoader (test , batch_size = 2 , shuffle = False )
100
103
101
104
# Snapshot ensemble needs more epochs
102
105
if isinstance (model , torchensemble .SnapshotEnsembleClassifier ):
@@ -109,7 +112,15 @@ def test_clf(clf):
109
112
save_model = True )
110
113
111
114
# Test
112
- model .predict (test_loader )
115
+ prev_acc = model .predict (test_loader )
116
+
117
+ # Reload
118
+ new_model = clf (estimator = MLP_clf , n_estimators = n_estimators , cuda = False )
119
+ io .load (new_model )
120
+
121
+ post_acc = new_model .predict (test_loader )
122
+
123
+ assert prev_acc == post_acc # ensure the same performance
113
124
114
125
115
126
@pytest .mark .parametrize ("reg" , all_reg )
@@ -131,9 +142,9 @@ def test_reg(reg):
131
142
132
143
# Prepare data
133
144
train = TensorDataset (X_train , y_train_reg )
134
- train_loader = DataLoader (train , batch_size = 2 )
145
+ train_loader = DataLoader (train , batch_size = 2 , shuffle = False )
135
146
test = TensorDataset (X_test , y_test_reg )
136
- test_loader = DataLoader (test , batch_size = 2 )
147
+ test_loader = DataLoader (test , batch_size = 2 , shuffle = False )
137
148
138
149
# Snapshot ensemble needs more epochs
139
150
if isinstance (model , torchensemble .SnapshotEnsembleRegressor ):
@@ -146,7 +157,15 @@ def test_reg(reg):
146
157
save_model = True )
147
158
148
159
# Test
149
- model .predict (test_loader )
160
+ prev_mse = model .predict (test_loader )
161
+
162
+ # Reload
163
+ new_model = reg (estimator = MLP_reg , n_estimators = n_estimators , cuda = False )
164
+ io .load (new_model )
165
+
166
+ post_mse = new_model .predict (test_loader )
167
+
168
+ assert prev_mse == post_mse # ensure the same performance
150
169
151
170
152
171
@pytest .mark .parametrize ("method" , all_clf + all_reg )
0 commit comments