Skip to content

Commit 4f49c0c

Browse files
authored
[FIX] Fix missing base estimators when calling load (#45)
* fix save and load * remove randomness in unit tests
1 parent 1ede2a9 commit 4f49c0c

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Changelog
44
[Ver 0.1.*]
55
-----------
66

7+
* |Fix| Fix missing base estimators when calling :meth:`load()` for all ensembles | `@xuyxu <https://github.com/xuyxu>`__
78
* |MajorFeature| Add methods on model deserialization :meth:`load()` for all ensembles | `@mttgdd <https://github.com/mttgdd>`__
89

910
[Beta]

torchensemble/tests/test_all_models.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch.utils.data import TensorDataset, DataLoader
66

77
import torchensemble
8+
from torchensemble.utils import io
89
from torchensemble.utils.logging import set_logger
910

1011

@@ -24,6 +25,10 @@
2425
torchensemble.AdversarialTrainingRegressor]
2526

2627

28+
# Remove randomness
29+
np.random.seed(0)
30+
torch.manual_seed(0)
31+
2732
set_logger("pytest_all_models")
2833

2934

@@ -66,12 +71,10 @@ def forward(self, X):
6671

6772
# Testing data
6873
X_test = torch.Tensor(np.array(([0.5, 0.5],
69-
[0.6, 0.6],
70-
[0.7, 0.7],
71-
[0.8, 0.8])))
74+
[0.6, 0.6])))
7275

73-
y_test_clf = torch.LongTensor(np.array(([1, 1, 0, 0])))
74-
y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6, 0.7, 0.8])))
76+
y_test_clf = torch.LongTensor(np.array(([1, 0])))
77+
y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6])))
7578
y_test_reg = y_test_reg.view(-1, 1)
7679

7780

@@ -94,9 +97,9 @@ def test_clf(clf):
9497

9598
# Prepare data
9699
train = TensorDataset(X_train, y_train_clf)
97-
train_loader = DataLoader(train, batch_size=2)
100+
train_loader = DataLoader(train, batch_size=2, shuffle=False)
98101
test = TensorDataset(X_test, y_test_clf)
99-
test_loader = DataLoader(test, batch_size=2)
102+
test_loader = DataLoader(test, batch_size=2, shuffle=False)
100103

101104
# Snapshot ensemble needs more epochs
102105
if isinstance(model, torchensemble.SnapshotEnsembleClassifier):
@@ -109,7 +112,15 @@ def test_clf(clf):
109112
save_model=True)
110113

111114
# Test
112-
model.predict(test_loader)
115+
prev_acc = model.predict(test_loader)
116+
117+
# Reload
118+
new_model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False)
119+
io.load(new_model)
120+
121+
post_acc = new_model.predict(test_loader)
122+
123+
assert prev_acc == post_acc # ensure the same performance
113124

114125

115126
@pytest.mark.parametrize("reg", all_reg)
@@ -131,9 +142,9 @@ def test_reg(reg):
131142

132143
# Prepare data
133144
train = TensorDataset(X_train, y_train_reg)
134-
train_loader = DataLoader(train, batch_size=2)
145+
train_loader = DataLoader(train, batch_size=2, shuffle=False)
135146
test = TensorDataset(X_test, y_test_reg)
136-
test_loader = DataLoader(test, batch_size=2)
147+
test_loader = DataLoader(test, batch_size=2, shuffle=False)
137148

138149
# Snapshot ensemble needs more epochs
139150
if isinstance(model, torchensemble.SnapshotEnsembleRegressor):
@@ -146,7 +157,15 @@ def test_reg(reg):
146157
save_model=True)
147158

148159
# Test
149-
model.predict(test_loader)
160+
prev_mse = model.predict(test_loader)
161+
162+
# Reload
163+
new_model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False)
164+
io.load(new_model)
165+
166+
post_mse = new_model.predict(test_loader)
167+
168+
assert prev_mse == post_mse # ensure the same performance
150169

151170

152171
@pytest.mark.parametrize("method", all_clf + all_reg)

torchensemble/utils/io.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ def save(model, save_dir, logger):
1414
filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,
1515
model.base_estimator_.__name__,
1616
model.n_estimators)
17-
state = {"model": model.state_dict()}
17+
18+
# The real number of base estimators in some ensembles is not the same as
19+
# `n_estimators`.
20+
state = {"n_estimators": len(model.estimators_),
21+
"model": model.state_dict()}
1822
save_dir = os.path.join(save_dir, filename)
1923

2024
logger.info("Saving the model to `{}`".format(save_dir))
@@ -39,4 +43,11 @@ def load(model, save_dir="./", logger=None):
3943
if logger:
4044
logger.info("Loading the model from `{}`".format(save_dir))
4145

42-
model.load_state_dict(torch.load(save_dir)["model"])
46+
state = torch.load(save_dir)
47+
n_estimators = state["n_estimators"]
48+
model_params = state["model"]
49+
50+
# Pre-allocate and load all base estimators
51+
for _ in range(n_estimators):
52+
model.estimators_.append(model._make_estimator())
53+
model.load_state_dict(model_params)

0 commit comments

Comments
 (0)