-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtrain.py
More file actions
59 lines (48 loc) · 2.59 KB
/
train.py
File metadata and controls
59 lines (48 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import tensorflow as tf
from tensorflow.keras.callbacks import (TensorBoard, LearningRateScheduler,
EarlyStopping, ModelCheckpoint)
from tensorflow.keras.optimizers import Adam
import numpy as np
from scipy.misc import imread, imsave, imresize
from model import SRRAM
from dataset import Dataset
import utils
from pathlib import Path
import argparse
# Command-line configuration for SRRAM super-resolution training.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='./data/General-100',
                    help='root directory containing the dataset (with a test/ subfolder)')
parser.add_argument('--save_dir', type=str, default='./saved',
                    help='base directory for logs, checkpoints and samples')
parser.add_argument('--img_size', type=int, default=96,
                    help='training image/patch size in pixels')
parser.add_argument('--extension', type=str, default='bmp',
                    help='image file extension of the dataset')
parser.add_argument('--scale_factor', type=int, default=2,
                    help='super-resolution upscaling factor')
parser.add_argument('--epochs', type=int, default=25)
parser.add_argument('--patience', type=int, default=1000,
                    help='early-stopping patience in epochs without val_loss improvement')
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--lr', type=float, default=1e-4,
                    help='initial Adam learning rate')
parser.add_argument('--beta1', type=float, default=0.9,
                    help='Adam first-moment decay')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='Adam second-moment decay')
# Default was the float literal 2e+5 despite type=int; use an int literal so
# the default and a CLI-supplied value have the same type.
parser.add_argument('--decay_step', type=int, default=200000,
                    help='steps between learning-rate decays (see utils.lr_decay)')
parser.add_argument('--decay_rate', type=int, default=2,
                    help='learning-rate decay factor (see utils.lr_decay)')
flags = parser.parse_args()
# Build the data pipeline and model, then train with TensorBoard logging,
# step-based LR decay, early stopping, and best-model checkpointing.
dataset = Dataset(flags)
srram = SRRAM(scale_factor=flags.scale_factor)
save_path = utils.build_save_path(flags)
cbs = [TensorBoard(log_dir=save_path, histogram_freq=1, write_graph=True),
       LearningRateScheduler(lambda epoch: utils.lr_decay(epoch, init_value=flags.lr,
                                                          decay_step=flags.decay_step,
                                                          decay_rate=flags.decay_rate)),
       EarlyStopping(monitor='val_loss', patience=flags.patience, verbose=0, mode='auto'),
       # Keep only the weights with the best validation loss.
       # Path built via pathlib for consistency with the rest of the file.
       ModelCheckpoint(str(Path(save_path) / 'model.h5'), save_best_only=True)]
# BUG FIX: --beta1/--beta2 were parsed but never forwarded to the optimizer,
# so Adam silently used its built-in defaults regardless of the flags.
# NOTE(review): `lr=` is the legacy kwarg; newer tf.keras spells it
# `learning_rate=` — confirm against the pinned TF version before renaming.
srram.model.compile(optimizer=Adam(lr=flags.lr, beta_1=flags.beta1,
                                   beta_2=flags.beta2, epsilon=1e-8),
                    loss='mae')
srram.model.fit(dataset.train_set, epochs=flags.epochs, steps_per_epoch=dataset.train_steps_per_epoch,
                validation_data=dataset.val_set, validation_steps=dataset.val_steps_per_epoch,
                callbacks=cbs)
# Generate super-resolved samples for the test split with the trained model.
# (The best weights were already written by ModelCheckpoint during training.)
sample_dir = Path(save_path) / 'sample'
# parents=True tolerates a save directory whose parents do not exist yet.
sample_dir.mkdir(parents=True, exist_ok=True)
# BUG FIX: honor --extension instead of hard-coding '.bmp' in both the glob
# and the output filename (default is 'bmp', so behavior is unchanged there).
for filename in (Path(flags.data_dir) / 'test').glob('*.{}'.format(flags.extension)):
    # NOTE(review): scipy.misc.imread/imresize/imsave were removed in SciPy 1.2;
    # this script requires an old SciPy + Pillow — consider migrating to imageio/PIL.
    img = imread(filename)
    # Downscale the test image so the network's upscaled output matches the
    # original resolution.
    img = imresize(img, (img.shape[0] // flags.scale_factor,
                         img.shape[1] // flags.scale_factor), interp='bicubic')
    # assumes img is H x W x C (e.g. RGB); a grayscale image would break the
    # 4-D batch indexing below — TODO confirm dataset contents.
    out = np.squeeze(srram.model.predict(img[None, :, :, :]), axis=0)
    out = np.clip(out, 0, 255).astype('uint8')
    imsave(sample_dir / '{}.{}'.format(filename.stem, flags.extension), out)