Commit ea60845

Author: Arthur Douillard

[common] Add proper val split for incremental loader.

1 parent b2a814c · commit ea60845

File tree: 4 files changed (+123, -41 lines)


inclearn/lib/data.py (+94, -11)

@@ -22,10 +22,17 @@ def __init__(
         workers=10,
         batch_size=128,
         seed=1,
-        increment=10
+        increment=10,
+        validation_split=0.
     ):
         datasets = _get_datasets(dataset_name)
-        self._setup_data(datasets, random_order=random_order, seed=seed, increment=increment)
+        self._setup_data(
+            datasets,
+            random_order=random_order,
+            seed=seed,
+            increment=increment,
+            validation_split=validation_split
+        )
         self.train_transforms = datasets[0].train_transforms  # FIXME handle multiple datasets
         self.common_transforms = datasets[0].common_transforms

@@ -48,6 +55,9 @@ def new_task(self, memory=None):
         x_train, y_train = self._select(
             self.data_train, self.targets_train, low_range=min_class, high_range=max_class
         )
+        x_val, y_val = self._select(
+            self.data_val, self.targets_val, low_range=min_class, high_range=max_class
+        )
         x_test, y_test = self._select(self.data_test, self.targets_test, high_range=max_class)

         if memory is not None:

@@ -57,6 +67,7 @@ def new_task(self, memory=None):
             y_train = np.concatenate((y_train, targets_memory))

         train_loader = self._get_loader(x_train, y_train, mode="train")
+        val_loader = self._get_loader(x_val, y_val, mode="train") if len(x_val) > 0 else None
         test_loader = self._get_loader(x_test, y_test, mode="test")

         task_info = {

@@ -71,13 +82,40 @@ def new_task(self, memory=None):

         self._current_task += 1

-        return task_info, train_loader, test_loader
+        return task_info, train_loader, val_loader, test_loader
+
+    def get_custom_loader(self, class_indexes, mode="test", data_source="train"):
+        """Returns a custom loader.
+
+        :param class_indexes: A list of class indexes that we want.
+        :param mode: Various mode for the transformations applied on it.
+        :param data_source: Whether to fetch from the train, val, or test set.
+        :return: The raw data and a loader.
+        """
+        if not isinstance(class_indexes, list):  # TODO: deprecated, should always give a list
+            class_indexes = [class_indexes]
+
+        if data_source == "train":
+            x, y = self.data_train, self.targets_train
+        elif data_source == "val":
+            x, y = self.data_val, self.targets_val
+        elif data_source == "test":
+            x, y = self.data_test, self.targets_test
+        else:
+            raise ValueError("Unknown data source <{}>.".format(data_source))

-    def get_class_loader(self, class_idx, mode="test"):
-        x, y = self._select(
-            self.data_train, self.targets_train, low_range=class_idx, high_range=class_idx + 1
-        )
-        return x, self._get_loader(x, y, shuffle=False, mode=mode)
+        data, targets = [], []
+        for class_index in class_indexes:
+            class_data, class_targets = self._select(
+                x, y, low_range=class_index, high_range=class_index + 1
+            )
+            data.append(class_data)
+            targets.append(class_targets)
+
+        data = np.concatenate(data)
+        targets = np.concatenate(targets)
+
+        return data, self._get_loader(data, targets, shuffle=False, mode=mode)

     def _select(self, x, y, low_range=0, high_range=0):
         idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]

@@ -102,10 +140,11 @@ def _get_loader(self, x, y, shuffle=True, mode="train"):
             num_workers=self._workers
         )

-    def _setup_data(self, datasets, random_order=False, seed=1, increment=10):
+    def _setup_data(self, datasets, random_order=False, seed=1, increment=10, validation_split=0.):
         # FIXME: handles online loading of images
         self.data_train, self.targets_train = [], []
         self.data_test, self.targets_test = [], []
+        self.data_val, self.targets_val = [], []
         self.increments = []
         self.class_order = []

@@ -115,6 +154,9 @@ def _setup_data(self, datasets, random_order=False, seed=1, increment=10):
             test_dataset = dataset.base_dataset("data", train=False, download=True)

             x_train, y_train = train_dataset.data, np.array(train_dataset.targets)
+            x_val, y_val, x_train, y_train = self._split_per_class(
+                x_train, y_train, validation_split
+            )
             x_test, y_test = test_dataset.data, np.array(test_dataset.targets)

             order = [i for i in range(len(np.unique(y_train)))]

@@ -126,10 +168,12 @@ def _setup_data(self, datasets, random_order=False, seed=1, increment=10):

             self.class_order.append(order)

-            y_train = np.array(list(map(lambda x: order.index(x), y_train)))
-            y_test = np.array(list(map(lambda x: order.index(x), y_test)))
+            y_train = self._map_new_class_index(y_train, order)
+            y_val = self._map_new_class_index(y_val, order)
+            y_test = self._map_new_class_index(y_test, order)

             y_train += current_class_idx
+            y_val += current_class_idx
             y_test += current_class_idx

             current_class_idx += len(order)

@@ -140,14 +184,53 @@ def _setup_data(self, datasets, random_order=False, seed=1, increment=10):

             self.data_train.append(x_train)
             self.targets_train.append(y_train)
+            self.data_val.append(x_val)
+            self.targets_val.append(y_val)
             self.data_test.append(x_test)
             self.targets_test.append(y_test)

         self.data_train = np.concatenate(self.data_train)
         self.targets_train = np.concatenate(self.targets_train)
+        self.data_val = np.concatenate(self.data_val)
+        self.targets_val = np.concatenate(self.targets_val)
         self.data_test = np.concatenate(self.data_test)
         self.targets_test = np.concatenate(self.targets_test)

+    @staticmethod
+    def _map_new_class_index(y, order):
+        """Transforms targets for new class order."""
+        return np.array(list(map(lambda x: order.index(x), y)))
+
+    @staticmethod
+    def _split_per_class(x, y, validation_split=0.):
+        """Splits train data for a subset of validation data.
+
+        The split is done per class, so that every class keeps the same
+        proportion of its data.
+        """
+        shuffled_indexes = np.random.permutation(x.shape[0])
+        x = x[shuffled_indexes]
+        y = y[shuffled_indexes]
+
+        x_val, y_val = [], []
+        x_train, y_train = [], []
+
+        for class_id in np.unique(y):
+            class_indexes = np.where(y == class_id)[0]
+            nb_val_elts = int(class_indexes.shape[0] * validation_split)
+
+            val_indexes = class_indexes[:nb_val_elts]
+            train_indexes = class_indexes[nb_val_elts:]
+
+            x_val.append(x[val_indexes])
+            y_val.append(y[val_indexes])
+            x_train.append(x[train_indexes])
+            y_train.append(y[train_indexes])
+
+        x_val, y_val = np.concatenate(x_val), np.concatenate(y_val)
+        x_train, y_train = np.concatenate(x_train), np.concatenate(y_train)
+
+        return x_val, y_val, x_train, y_train
+

 class DummyDataset(torch.utils.data.Dataset):
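Taken together, the data.py changes give each task an optional validation loader and generalize the old per-class loader. Below is a short usage sketch; the class name IncrementalDataset and the "cifar100" argument are assumptions for illustration (neither is visible in the hunks above), while the keyword arguments and return values follow the diff:

# Usage sketch; assumed class and dataset names, see note above.
inc_dataset = IncrementalDataset(
    "cifar100",                # assumed dataset name
    random_order=False,
    batch_size=128,
    workers=10,
    seed=1,
    increment=10,
    validation_split=0.1       # hold out 10% of every class for validation
)

# new_task() now returns four values; val_loader is None when the split is 0.
task_info, train_loader, val_loader, test_loader = inc_dataset.new_task()

# get_custom_loader() replaces get_class_loader(): it takes a list of class
# indexes plus a data_source of "train", "val", or "test".
data, loader = inc_dataset.get_custom_loader([0, 1], mode="test", data_source="val")

One detail worth noting from the diff itself: the validation loader is built with mode="train", so validation batches pass through the train-time transforms.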
inclearn/lib/factory.py (+2, -1)

@@ -51,7 +51,8 @@ def get_data(args):
         random_order=args["random_classes"],
         shuffle=True,
         batch_size=args["batch_size"],
-        workers=args["workers"]
+        workers=args["workers"],
+        validation_split=args["validation"]
     )
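get_data() now forwards args["validation"] as the split fraction. The definition of that argument is not part of this commit; purely as an assumption, a matching argparse entry could look like this:

import argparse

# Hypothetical parser entry backing args["validation"]; not shown in this commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--validation", type=float, default=0.,
    help="Fraction of each class's training samples held out for validation."
)
args = vars(parser.parse_args())  # get_data() reads args["validation"]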
inclearn/models/icarl.py (+24, -26)

@@ -46,7 +46,7 @@ def __init__(self, args):
         self._clf_loss = F.binary_cross_entropy_with_logits
         self._distil_loss = F.binary_cross_entropy_with_logits

-        self._herding_matrix = np.zeros((100, 500))  # FIXME: nb classes
+        self._herding_matrix = []

     def eval(self):
         self._network.eval()

@@ -75,21 +75,15 @@ def _train_task(self, train_loader, val_loader):
         print("nb ", len(train_loader.dataset))

         for epoch in range(self._n_epochs):
-            _loss = 0.
+            _loss, val_loss = 0., 0.

             self._scheduler.step()

             prog_bar = tqdm(train_loader)
-            c = 0
-            for inputs, targets in prog_bar:
-                c += 1
+            for i, (inputs, targets) in enumerate(prog_bar, start=1):
                 self._optimizer.zero_grad()

-                inputs, targets = inputs.to(self._device), targets.to(self._device)
-                targets = utils.to_onehot(targets, self._n_classes).to(self._device)
-                logits = self._network(inputs)
-
-                loss = self._compute_loss(inputs, logits, targets)
+                loss = self._forward_loss(inputs, targets)

                 if not utils._check_loss(loss):
                     import pdb

@@ -100,14 +94,26 @@ def _train_task(self, train_loader, val_loader):

                 _loss += loss.item()
+
+                if val_loader is not None and i == len(train_loader):
+                    for inputs, targets in val_loader:
+                        val_loss += self._forward_loss(inputs, targets).item()

                 prog_bar.set_description(
-                    "Task {}/{}, Epoch {}/{} => Clf loss: {}".format(
+                    "Task {}/{}, Epoch {}/{} => Clf loss: {}, Val loss: {}".format(
                         self._task + 1, self._n_tasks,
                         epoch + 1, self._n_epochs,
-                        round(_loss / c, 3)
+                        round(_loss / i, 3),
+                        round(val_loss, 3)
                     )
                 )

+    def _forward_loss(self, inputs, targets):
+        inputs, targets = inputs.to(self._device), targets.to(self._device)
+        targets = utils.to_onehot(targets, self._n_classes).to(self._device)
+        logits = self._network(inputs)
+
+        return self._compute_loss(inputs, logits, targets)
+
     def _after_task(self, inc_dataset):
         self.build_examplars(inc_dataset)

@@ -182,23 +188,24 @@ def _memory_per_class(self):
     # -----------------

     def build_examplars(self, inc_dataset):
+        print("Building & updating memory.")
+
         self._data_memory, self._targets_memory = [], []
         self._class_means = np.zeros((100, self._network.features_dim))

         for class_idx in range(self._n_classes):
-            inputs, loader = inc_dataset.get_class_loader(class_idx, mode="test")
+            inputs, loader = inc_dataset.get_custom_loader(class_idx, mode="test")
             features, targets = extract_features(
                 self._network, loader
             )
             features_flipped, _ = extract_features(
-                self._network, inc_dataset.get_class_loader(class_idx, mode="flip")[1]
+                self._network, inc_dataset.get_custom_loader(class_idx, mode="flip")[1]
             )

             if class_idx >= self._n_classes - self._task_size:
-                print("Finding examplars for", class_idx)
-                self._herding_matrix[class_idx, :] = select_examplars(
+                self._herding_matrix.append(select_examplars(
                     features, self._memory_per_class
-                )
+                ))

             examplar_mean, alph = compute_examplar_mean(
                 features, features_flipped, self._herding_matrix[class_idx], self._memory_per_class

@@ -281,13 +288,4 @@ def compute_accuracy(model, loader, class_means):
     sqd = cdist(class_means, features, 'sqeuclidean')
     score_icarl = (-sqd).T

-    # Compute the accuracy over the batch
-    stat_icarl = [
-        ll in best
-        for ll, best in zip(targets_.astype('int32'),
-                            np.argsort(score_icarl, axis=1)[:, -1:])
-    ]
-
-    print("stats ", np.average(stat_icarl))
-
     return np.argsort(score_icarl, axis=1)[:, -1], targets_
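Two notes on the icarl.py changes. First, the validation pass in _train_task fires inside the batch loop, on the iteration that consumes the last training batch, so it runs exactly once per epoch. The stand-alone sketch below isolates that pattern; run_epoch and compute_loss are stand-in names, and the torch.no_grad() guard is an addition of this sketch, not something the diff shows:

import torch

def run_epoch(network, optimizer, train_loader, val_loader, compute_loss):
    """Sketch of the epoch pattern used by _train_task (stand-in names)."""
    train_loss, val_loss = 0., 0.

    for i, (inputs, targets) in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        loss = compute_loss(network, inputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Runs once per epoch, when the final training batch is reached.
        if val_loader is not None and i == len(train_loader):
            with torch.no_grad():  # assumption: the diff has no no_grad() guard
                for val_inputs, val_targets in val_loader:
                    val_loss += compute_loss(network, val_inputs, val_targets).item()

    return train_loss / len(train_loader), val_loss

Second, _herding_matrix changes from a fixed np.zeros((100, 500)) array to a plain list: build_examplars() appends one herding result per newly seen class, and because new classes always carry the highest indexes, the later self._herding_matrix[class_idx] lookups still line up.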
inclearn/train.py (+3, -3)

@@ -36,7 +36,7 @@ def _train(args):
     memory = None

     for _ in range(inc_dataset.n_tasks):
-        task_info, train_loader, test_loader = inc_dataset.new_task(memory)
+        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(memory)
         if task_info["task"] == args["max_task"]:
             break

@@ -50,10 +50,10 @@ def _train(args):
         )

         model.eval()
-        model.before_task(train_loader, None)
+        model.before_task(train_loader, val_loader)
         print("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"]))
         model.train()
-        model.train_task(train_loader, None)
+        model.train_task(train_loader, val_loader)
         model.eval()
         model.after_task(inc_dataset)
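End to end, the per-task flow after this commit condenses to the loop below; it restates the hunks above with logging and model construction elided:

# Condensed per-task flow; all names appear in the diffs above.
for _ in range(inc_dataset.n_tasks):
    task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(memory)

    model.eval()
    model.before_task(train_loader, val_loader)   # previously (train_loader, None)
    model.train()
    model.train_task(train_loader, val_loader)    # previously (train_loader, None)
    model.eval()
    model.after_task(inc_dataset)                 # e.g. iCaRL rebuilds its exemplar memory here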