|
6 | 6 | from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
|
7 | 7 | from keras.datasets import mnist
|
8 | 8 | from keras.layers import Dense, Dropout, Flatten, Input
|
9 |
| -from keras.layers import Convolution2D, MaxPooling2D |
| 9 | +from keras.layers import Conv2D, MaxPooling2D |
10 | 10 | from keras.models import Model
|
11 | 11 |
|
12 | 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
13 | 13 | from utils import angle_error, RotNetDataGenerator, binarize_images
|
14 |
| - |
15 |
| - |
16 | 14 | # we don't need the labels indicating the digit value, so we only load the images
|
17 | 15 | (X_train, _), (X_test, _) = mnist.load_data()
|
18 | 16 |
|
|
38 | 36 |
|
39 | 37 | # model definition
|
40 | 38 | input = Input(shape=(img_rows, img_cols, img_channels))
|
41 |
| -x = Convolution2D(nb_filters, kernel_size[0], kernel_size[1], |
42 |
| - activation='relu')(input) |
43 |
| -x = Convolution2D(nb_filters, kernel_size[0], kernel_size[1], |
44 |
| - activation='relu')(x) |
| 39 | +x = Conv2D(nb_filters, kernel_size, activation='relu')(input) |
| 40 | +x = Conv2D(nb_filters, kernel_size, activation='relu')(x) |
45 | 41 | x = MaxPooling2D(pool_size=(2, 2))(x)
|
46 | 42 | x = Dropout(0.25)(x)
|
47 | 43 | x = Flatten()(x)
|
48 | 44 | x = Dense(128, activation='relu')(x)
|
49 | 45 | x = Dropout(0.25)(x)
|
50 | 46 | x = Dense(nb_classes, activation='softmax')(x)
|
51 | 47 |
|
52 |
| -model = Model(input=input, output=x) |
| 48 | +model = Model(inputs=input, outputs=x) |
53 | 49 |
|
54 | 50 | model.summary()
|
55 | 51 |
|
|
82 | 78 | preprocess_func=binarize_images,
|
83 | 79 | shuffle=True
|
84 | 80 | ),
|
85 |
| - samples_per_epoch=nb_train_samples, |
86 |
| - nb_epoch=nb_epoch, |
| 81 | + steps_per_epoch=nb_train_samples / batch_size, |
| 82 | + epochs=nb_epoch, |
87 | 83 | validation_data=RotNetDataGenerator(
|
88 | 84 | X_test,
|
89 | 85 | batch_size=batch_size,
|
90 | 86 | preprocess_func=binarize_images
|
91 | 87 | ),
|
92 |
| - nb_val_samples=nb_test_samples, |
| 88 | + validation_steps=nb_test_samples / batch_size, |
93 | 89 | verbose=1,
|
94 | 90 | callbacks=[checkpointer, early_stopping, tensorboard]
|
95 | 91 | )
|
0 commit comments