@@ -22,10 +22,17 @@ def __init__(
22
22
workers = 10 ,
23
23
batch_size = 128 ,
24
24
seed = 1 ,
25
- increment = 10
25
+ increment = 10 ,
26
+ validation_split = 0.
26
27
):
27
28
datasets = _get_datasets (dataset_name )
28
- self ._setup_data (datasets , random_order = random_order , seed = seed , increment = increment )
29
+ self ._setup_data (
30
+ datasets ,
31
+ random_order = random_order ,
32
+ seed = seed ,
33
+ increment = increment ,
34
+ validation_split = validation_split
35
+ )
29
36
self .train_transforms = datasets [0 ].train_transforms # FIXME handle multiple datasets
30
37
self .common_transforms = datasets [0 ].common_transforms
31
38
@@ -48,6 +55,9 @@ def new_task(self, memory=None):
48
55
x_train , y_train = self ._select (
49
56
self .data_train , self .targets_train , low_range = min_class , high_range = max_class
50
57
)
58
+ x_val , y_val = self ._select (
59
+ self .data_val , self .targets_val , low_range = min_class , high_range = max_class
60
+ )
51
61
x_test , y_test = self ._select (self .data_test , self .targets_test , high_range = max_class )
52
62
53
63
if memory is not None :
@@ -57,6 +67,7 @@ def new_task(self, memory=None):
57
67
y_train = np .concatenate ((y_train , targets_memory ))
58
68
59
69
train_loader = self ._get_loader (x_train , y_train , mode = "train" )
70
+ val_loader = self ._get_loader (x_val , y_val , mode = "train" ) if len (x_val ) > 0 else None
60
71
test_loader = self ._get_loader (x_test , y_test , mode = "test" )
61
72
62
73
task_info = {
@@ -71,13 +82,40 @@ def new_task(self, memory=None):
71
82
72
83
self ._current_task += 1
73
84
74
- return task_info , train_loader , test_loader
85
+ return task_info , train_loader , val_loader , test_loader
86
+
87
def get_custom_loader(self, class_indexes, mode="test", data_source="train"):
    """Returns a custom loader.

    :param class_indexes: A list of class indexes that we want.
    :param mode: Various mode for the transformations applied on it.
    :param data_source: Whether to fetch from the train, val, or test set.
    :return: The raw data and a loader.
    """
    if not isinstance(class_indexes, list):  # TODO: deprecated, should always give a list
        class_indexes = [class_indexes]

    # Dispatch table instead of an if/elif chain; every split is set up
    # by _setup_data before this method can be called.
    source_map = {
        "train": (self.data_train, self.targets_train),
        "val": (self.data_val, self.targets_val),
        "test": (self.data_test, self.targets_test),
    }
    if data_source not in source_map:
        raise ValueError("Unknown data source <{}>.".format(data_source))
    x, y = source_map[data_source]

    # Gather every requested class, one _select call per class index.
    data_chunks, target_chunks = [], []
    for class_index in class_indexes:
        chunk_x, chunk_y = self._select(
            x, y, low_range=class_index, high_range=class_index + 1
        )
        data_chunks.append(chunk_x)
        target_chunks.append(chunk_y)

    data = np.concatenate(data_chunks)
    targets = np.concatenate(target_chunks)

    return data, self._get_loader(data, targets, shuffle=False, mode=mode)
81
119
82
120
def _select (self , x , y , low_range = 0 , high_range = 0 ):
83
121
idxes = np .where (np .logical_and (y >= low_range , y < high_range ))[0 ]
@@ -102,10 +140,11 @@ def _get_loader(self, x, y, shuffle=True, mode="train"):
102
140
num_workers = self ._workers
103
141
)
104
142
105
- def _setup_data (self , datasets , random_order = False , seed = 1 , increment = 10 ):
143
+ def _setup_data (self , datasets , random_order = False , seed = 1 , increment = 10 , validation_split = 0. ):
106
144
# FIXME: handles online loading of images
107
145
self .data_train , self .targets_train = [], []
108
146
self .data_test , self .targets_test = [], []
147
+ self .data_val , self .targets_val = [], []
109
148
self .increments = []
110
149
self .class_order = []
111
150
@@ -115,6 +154,9 @@ def _setup_data(self, datasets, random_order=False, seed=1, increment=10):
115
154
test_dataset = dataset .base_dataset ("data" , train = False , download = True )
116
155
117
156
x_train , y_train = train_dataset .data , np .array (train_dataset .targets )
157
+ x_val , y_val , x_train , y_train = self ._split_per_class (
158
+ x_train , y_train , validation_split
159
+ )
118
160
x_test , y_test = test_dataset .data , np .array (test_dataset .targets )
119
161
120
162
order = [i for i in range (len (np .unique (y_train )))]
@@ -126,10 +168,12 @@ def _setup_data(self, datasets, random_order=False, seed=1, increment=10):
126
168
127
169
self .class_order .append (order )
128
170
129
- y_train = np .array (list (map (lambda x : order .index (x ), y_train )))
130
- y_test = np .array (list (map (lambda x : order .index (x ), y_test )))
171
+ y_train = self ._map_new_class_index (y_train , order )
172
+ y_val = self ._map_new_class_index (y_val , order )
173
+ y_test = self ._map_new_class_index (y_test , order )
131
174
132
175
y_train += current_class_idx
176
+ y_val += current_class_idx
133
177
y_test += current_class_idx
134
178
135
179
current_class_idx += len (order )
@@ -140,14 +184,53 @@ def _setup_data(self, datasets, random_order=False, seed=1, increment=10):
140
184
141
185
self .data_train .append (x_train )
142
186
self .targets_train .append (y_train )
187
+ self .data_val .append (x_val )
188
+ self .targets_val .append (y_val )
143
189
self .data_test .append (x_test )
144
190
self .targets_test .append (y_test )
145
191
146
192
self .data_train = np .concatenate (self .data_train )
147
193
self .targets_train = np .concatenate (self .targets_train )
194
+ self .data_val = np .concatenate (self .data_val )
195
+ self .targets_val = np .concatenate (self .targets_val )
148
196
self .data_test = np .concatenate (self .data_test )
149
197
self .targets_test = np .concatenate (self .targets_test )
150
198
199
+ @staticmethod
200
+ def _map_new_class_index (y , order ):
201
+ """Transforms targets for new class order."""
202
+ return np .array (list (map (lambda x : order .index (x ), y )))
203
+
204
+ @staticmethod
205
+ def _split_per_class (x , y , validation_split = 0. ):
206
+ """Splits train data for a subset of validation data.
207
+
208
+ Split is done so that each class has a much data.
209
+ """
210
+ shuffled_indexes = np .random .permutation (x .shape [0 ])
211
+ x = x [shuffled_indexes ]
212
+ y = y [shuffled_indexes ]
213
+
214
+ x_val , y_val = [], []
215
+ x_train , y_train = [], []
216
+
217
+ for class_id in np .unique (y ):
218
+ class_indexes = np .where (y == class_id )[0 ]
219
+ nb_val_elts = int (class_indexes .shape [0 ] * validation_split )
220
+
221
+ val_indexes = class_indexes [:nb_val_elts ]
222
+ train_indexes = class_indexes [nb_val_elts :]
223
+
224
+ x_val .append (x [val_indexes ])
225
+ y_val .append (y [val_indexes ])
226
+ x_train .append (x [train_indexes ])
227
+ y_train .append (y [train_indexes ])
228
+
229
+ x_val , y_val = np .concatenate (x_val ), np .concatenate (y_val )
230
+ x_train , y_train = np .concatenate (x_train ), np .concatenate (y_train )
231
+
232
+ return x_val , y_val , x_train , y_train
233
+
151
234
152
235
class DummyDataset (torch .utils .data .Dataset ):
153
236
0 commit comments