12
12
from keras .models import Sequential , Model
13
13
from keras .layers import Reshape , Activation , Conv2D , Input , MaxPooling2D , BatchNormalization , Flatten , Dense , Lambda , ConvLSTM2D
14
14
from keras .layers .advanced_activations import LeakyReLU
15
- from keras .callbacks import EarlyStopping , ModelCheckpoint , TensorBoard
15
+ from keras .callbacks import EarlyStopping , ModelCheckpoint , TensorBoard , ReduceLROnPlateau
16
16
from keras .optimizers import SGD , Adam , RMSprop
17
17
from keras .layers .wrappers import TimeDistributed
18
18
from keras .layers .merge import concatenate
@@ -80,8 +80,8 @@ class MultiObjDetTracker:
80
80
]
81
81
82
82
LABELS = LABELS_MOT17
83
- IMAGE_H , IMAGE_W = 608 , 608 # 416
84
- GRID_H , GRID_W = 19 , 19 # 13
83
+ IMAGE_H , IMAGE_W = 416 , 416 # 416
84
+ GRID_H , GRID_W = 13 , 13 # 13
85
85
BOX = 5
86
86
CLASS = len (LABELS )
87
87
CLASS_WEIGHTS = np .ones (CLASS , dtype = 'float32' )
@@ -101,6 +101,10 @@ class MultiObjDetTracker:
101
101
SEQUENCE_LENGTH = 4
102
102
MAX_BOX_PER_IMAGE = 50
103
103
104
+ LOAD_MODEL = True
105
+ INITIAL_EPOCH = 0
106
+ SAVED_MODEL_PATH = 'models/MultiObjDetTracker-CHKPNT-03-0.55.hdf5'
107
+
104
108
# train_image_folder = 'data/ImageNet-ObjectDetection/ILSVRC2015Train/Data/VID/train/'
105
109
# train_annot_folder = 'data/ImageNet-ObjectDetection/ILSVRC2015Train/Annotations/VID/train/'
106
110
# valid_image_folder = 'data/ImageNet-ObjectDetection/ILSVRC2015Train/Data/VID/val/'
@@ -117,14 +121,16 @@ class MultiObjDetTracker:
117
121
118
122
def __init__(self, argv=None):
    """Create the wrapped KerasYOLO detector and build the tracker model.

    Args:
        argv: Optional configuration dict forwarded to ``KerasYOLO``.
            The tracker's own settings (labels, batch size, image size
            and grid size) are written into it, replacing any entries
            already stored under those keys. A dict supplied by the
            caller is updated in place; when omitted, a fresh dict is
            created for this instance.
    """
    # A literal ``{}`` default is evaluated once and would be shared --
    # and overwritten -- by every instance built without an explicit argv.
    if argv is None:
        argv = {}

    argv['LABELS'] = self.LABELS
    # The detector receives BATCH_SIZE scaled by SEQUENCE_LENGTH
    # (presumably one detector sample per frame of each sequence --
    # TODO confirm against the batch generator).
    argv['BATCH_SIZE'] = self.BATCH_SIZE * self.SEQUENCE_LENGTH
    argv['IMAGE_H'] = self.IMAGE_H
    argv['IMAGE_W'] = self.IMAGE_W
    argv['GRID_H'] = self.GRID_H
    argv['GRID_W'] = self.GRID_W

    self.detector = KerasYOLO(argv)
    self.load_model()
    # Optionally restore previously saved weights right after the model
    # has been built.
    if self.LOAD_MODEL:
        self.load_weights()
128
134
129
135
def loss_fxn(self, y_true, y_pred, tboxes, message=''):
    """Compute the loss by delegating to the wrapped detector.

    All arguments are passed through unchanged to
    ``self.detector.loss_fxn`` (``message`` as a keyword argument) and
    its result is returned as-is.
    """
    detector_loss = self.detector.loss_fxn
    return detector_loss(y_true, y_pred, tboxes, message=message)
@@ -167,7 +173,7 @@ def load_model(self):
167
173
output_det = TimeDistributed (Reshape ((self .GRID_H , self .GRID_W , self .BOX , 4 + 1 + self .CLASS )), name = 'detection' )(x_bbox )
168
174
169
175
z = concatenate ([x_bbox , x_vis ])
170
- z_vis = ConvLSTM2D (1024 , (3 ,3 ), strides = (1 ,1 ), padding = 'same' , return_sequences = True , name = 'tconv_lstm' )(z )
176
+ z_vis = ConvLSTM2D (512 , (3 ,3 ), strides = (1 ,1 ), padding = 'same' , return_sequences = True , name = 'tconv_lstm' )(z )
171
177
172
178
# z = TimeDistributed(Conv2D(1024, (3,3), strides=(1,1), padding='same', use_bias=False, name='tconv_1'), name='timedist_tconv1')(z)
173
179
# z = TimeDistributed(BatchNormalization(name='tnorm_1'), name='timedist_tnorm')(z)
@@ -207,7 +213,7 @@ def load_data_generators(self, generator_config):
207
213
pickle .dump (valid_imgs , fp )
208
214
209
215
210
- train_batch = BatchSequenceGenerator1 (train_imgs , generator_config , norm = normalize , shuffle = True , augment = False )
216
+ train_batch = BatchSequenceGenerator1 (train_imgs , generator_config , norm = normalize , shuffle = True , augment = True )
211
217
valid_batch = BatchSequenceGenerator1 (valid_imgs , generator_config , norm = normalize , augment = False )
212
218
213
219
return train_batch , valid_batch
@@ -244,38 +250,47 @@ def train(self):
244
250
mode = 'min' ,
245
251
verbose = 1 )
246
252
247
- checkpoint = ModelCheckpoint ('weights/WEIGHTS_MultiObjDetTracker.h5 ' ,
253
+ checkpoint = ModelCheckpoint ('models/MultiObjDetTracker-CHKPNT-{epoch:02d}-{val_loss:.2f}.hdf5 ' ,
248
254
monitor = 'val_loss' ,
249
255
verbose = 1 ,
250
- save_best_only = True ,
256
+ save_best_only = False ,
251
257
# save_weights_only = True,
252
258
mode = 'min' ,
253
259
period = 1 )
254
260
261
+ reduce_lr = ReduceLROnPlateau (monitor = 'val_loss' ,
262
+ factor = 0.5 ,
263
+ patience = 2 ,
264
+ verbose = 1 ,
265
+ mode = 'auto' ,
266
+ min_lr = 1e-5 )
267
+
255
268
tb_counter = len ([log for log in os .listdir (os .path .expanduser ('./logs/' )) if 'MultiObjDetTracker_' in log ]) + 1
256
269
tensorboard = TensorBoard (log_dir = os .path .expanduser ('./logs/' ) + 'MultiObjDetTracker_' + str (tb_counter ),
257
270
histogram_freq = 0 ,
258
271
write_graph = True ,
259
272
write_images = False )
260
273
261
- optimizer = Adam (lr = 1e-5 , beta_1 = 0.9 , beta_2 = 0.999 , epsilon = 1e-08 , decay = 0.0 )
274
+ optimizer = Adam (lr = 1e-4 , beta_1 = 0.9 , beta_2 = 0.999 , epsilon = 1e-08 , decay = 0.0 )
262
275
#optimizer = SGD(lr=1e-4, decay=0.0005, momentum=0.9)
263
276
#optimizer = RMSprop(lr=1e-4, rho=0.9, epsilon=1e-08, decay=0.0)
264
277
265
- self .model .compile (loss = [self .custom_loss_ttrack , self .custom_loss_dtrack ], loss_weights = [1.5 , 1.0 ], optimizer = optimizer )
278
+ self .model .compile (loss = [self .custom_loss_ttrack , self .custom_loss_dtrack ], loss_weights = [0.7 , 0.3 ], optimizer = optimizer )
266
279
self .model .fit_generator (
267
280
generator = train_batch ,
268
281
steps_per_epoch = len (train_batch ),
269
282
epochs = 100 ,
270
283
verbose = 1 ,
271
284
validation_data = valid_batch ,
272
285
validation_steps = len (valid_batch ),
273
- callbacks = [early_stop , checkpoint , tensorboard ],
274
- max_queue_size = 3 )
286
+ callbacks = [early_stop , checkpoint , tensorboard , reduce_lr ],
287
+ max_queue_size = 3 ,
288
+ initial_epoch = self .INITIAL_EPOCH )
275
289
276
290
277
def load_weights(self, weight_path=None):
    """Load saved weights into ``self.model`` and record the resume epoch.

    Args:
        weight_path: Optional path of the weights/checkpoint file to
            load. When omitted, ``self.SAVED_MODEL_PATH`` is used.

    Side effects:
        ``self.INITIAL_EPOCH`` is set to the epoch number embedded in
        the file name when the third ``-``-separated field of the base
        name is numeric (``MultiObjDetTracker-CHKPNT-03-0.55.hdf5``
        gives 3). For any other file name ``INITIAL_EPOCH`` is left
        unchanged.
    """
    # Local imports keep this method self-contained.
    import os

    if weight_path is None:
        weight_path = self.SAVED_MODEL_PATH
    self.model.load_weights(weight_path)

    # Inspect only the file name so that dashes in directory names cannot
    # shift the field that holds the epoch number.
    fields = os.path.basename(weight_path).split('-')
    if len(fields) > 2 and fields[2].isdigit():
        self.INITIAL_EPOCH = int(fields[2])
279
294
280
295
def predict (self , input_paths , output_paths ):
281
296
assert len (input_paths )== self .SEQUENCE_LENGTH
0 commit comments