
Commit 443c419

add cifar and lincls
1 parent b9b1e97 commit 443c419

File tree

9 files changed: 250 additions & 90 deletions

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+.ipynb_checkpoints/
+__pycache__/
+result/
+*.ipynb
+*.sh

augment.py

Lines changed: 37 additions & 4 deletions
@@ -18,14 +18,29 @@ def _augment_simsiam(self, x, shape, coord=[[[0., 0., 1., 1.]]]):
         x = self._resize(x)
         x = self._random_color_jitter(x, p=.8)
         x = self._random_grayscale(x, p=.2)
-        x = self._random_gaussian_blur(x, p=.5)
+        if self.args.dataset == 'imagenet':
+            x = self._random_gaussian_blur(x, p=.5)
         x = self._random_hflip(x)
         x = self._standardize(x)
         return x

     def _augment_lincls(self, x, shape, coord=[[[0., 0., 1., 1.]]]):
-        x = self._crop(x, shape, coord)
-        x = self._resize(x)
+        x = tf.saturate_cast(x, tf.uint8)
+        if self.args.dataset == 'imagenet':
+            if self.mode == 'train':
+                x = self._crop(x, shape, coord)
+            else:
+                x = self._centercrop(x, shape)
+
+        x = self._resize(x)
+
+        if self.mode == 'train':
+            x = self._random_color_jitter(x, p=.8)
+            x = self._random_grayscale(x, p=.2)
+            if self.args.dataset == 'imagenet':
+                x = self._random_gaussian_blur(x, p=.5)
+            x = self._random_hflip(x)
+
         x = self._standardize(x)
         return x

@@ -46,7 +61,25 @@ def _crop(self, x, shape, coord=[[[0., 0., 1., 1.]]]):

         offset_height, offset_width, _ = tf.unstack(bbox_begin)
         target_height, target_width, _ = tf.unstack(bbox_size)
-        x = tf.slice(x, [offset_height, offset_width, 0], [target_height, target_width, 3])
+        x = tf.slice(x, [offset_height, offset_width, 0], [target_height, target_width, -1])
+        return x
+
+    def _centercrop(self, x, shape):
+        if tf.less(shape[0], self.args.img_size):
+            offset_height = 0
+            target_height = shape[0]
+        else:
+            offset_height = tf.maximum(0, shape[0]-self.args.img_size) // 2
+            target_height = self.args.img_size
+
+        if tf.less(shape[1], self.args.img_size):
+            offset_width = 0
+            target_width = shape[1]
+        else:
+            offset_width = tf.maximum(0, shape[1]-self.args.img_size) // 2
+            target_width = self.args.img_size
+
+        x = tf.slice(x, [offset_height, offset_width, 0], [target_height, target_width, -1])
         return x

     def _resize(self, x):
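
Note on the new _centercrop: each axis is handled independently, so an image smaller than the target on one axis keeps its full extent there. A quick sanity check of the offset arithmetic (a sketch; the img_size value of 224 and the center_window helper are illustrative, not from this commit):

import tensorflow as tf

shape, img_size = (300, 200, 3), 224  # assumed eval crop size

def center_window(dim, img_size):
    # mirrors the per-axis branch in _centercrop above
    if dim < img_size:
        return 0, dim                          # image smaller than crop: keep all of it
    return max(0, dim - img_size) // 2, img_size  # otherwise center the window

offset_h, target_h = center_window(shape[0], img_size)  # (38, 224)
offset_w, target_w = center_window(shape[1], img_size)  # (0, 200)
print(offset_h, target_h, offset_w, target_w)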

callback.py

Lines changed: 6 additions & 3 deletions
@@ -1,6 +1,7 @@
 import os
 import six
 import yaml
+import tqdm
 import numpy as np
 import pandas as pd
 import tensorflow as tf
@@ -9,6 +10,8 @@
 from tensorflow.keras.callbacks import CSVLogger
 from tensorflow.keras.callbacks import TensorBoard
 from tensorflow.keras.experimental import CosineDecay
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.metrics import accuracy_score

 from common import create_stamp

@@ -90,7 +93,7 @@ def create_callbacks(args, logger, initial_epoch):
                 f'history - {args.history} | '
                 f'tensorboard - {args.tensorboard}')

-    callbacks = []
+    callbacks = []
     if args.checkpoint:
         if args.task == 'pretext':
             callbacks.append(ModelCheckpoint(
@@ -108,13 +111,13 @@ def create_callbacks(args, logger, initial_epoch):
                 save_best_only=True))
         else:
             callbacks.append(ModelCheckpoint(
-                filepath='{args.result_path}/{args.task}/{args.stamp}/checkpoint/latest',
+                filepath=f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/latest',
                 monitor='val_acc1',
                 mode='max',
                 verbose=1,
                 save_weights_only=True))
             callbacks.append(ModelCheckpoint(
-                filepath='{args.result_path}/{args.task}/{args.stamp}/checkpoint/best',
+                filepath=f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/best',
                 monitor='val_acc1',
                 mode='max',
                 verbose=1,
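
The new sklearn imports (KNeighborsClassifier, accuracy_score) are not used in the hunks shown on this page; they point toward a kNN evaluation of the learned embeddings. A minimal sketch of what such a callback could look like; the class name, constructor arguments, and encoder attribute are hypothetical, not part of this commit:

import tensorflow as tf
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

class KNNEval(tf.keras.callbacks.Callback):   # hypothetical callback
    def __init__(self, encoder, train_data, val_data, k=20):
        super().__init__()
        self.encoder = encoder          # backbone producing embeddings
        self.train_data = train_data    # (images, labels) for fitting the index
        self.val_data = val_data        # (images, labels) for scoring
        self.k = k

    def on_epoch_end(self, epoch, logs=None):
        x_tr, y_tr = self.train_data
        x_va, y_va = self.val_data
        z_tr = self.encoder.predict(x_tr, verbose=0)   # embed train set
        z_va = self.encoder.predict(x_va, verbose=0)   # embed val set
        knn = KNeighborsClassifier(n_neighbors=self.k).fit(z_tr, y_tr)
        acc = accuracy_score(y_va, knn.predict(z_va))
        print(f'epoch {epoch}: knn@{self.k} acc = {acc:.4f}')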

common.py

Lines changed: 7 additions & 2 deletions
@@ -21,6 +21,7 @@ def get_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default='pretext',
                         choices=['pretext', 'lincls'])
+    parser.add_argument("--dataset", type=str, default='imagenet')
     parser.add_argument("--freeze", action='store_true')
     parser.add_argument("--backbone", type=str, default='resnet50')
     parser.add_argument("--batch_size", type=int, default=256)
@@ -38,6 +39,7 @@ def get_arguments():
     parser.add_argument("--steps", type=int, default=0)
     parser.add_argument("--epochs", type=int, default=200)

+    parser.add_argument("--evaluate", action='store_true')
     parser.add_argument("--checkpoint", action='store_true')
     parser.add_argument("--history", action='store_true')
     parser.add_argument("--tensorboard", action='store_true')
@@ -107,7 +109,7 @@ def create_stamp():


 def search_same(args):
-    search_ignore = ['checkpoint', 'history', 'tensorboard',
+    search_ignore = ['evaluate', 'checkpoint', 'history', 'tensorboard',
                      'tb_interval', 'snapshot', 'summary',
                      'src_path', 'data_path', 'result_path',
                      'resume', 'stamp', 'gpus', 'ignore_search']
@@ -127,6 +129,9 @@ def search_same(args):
     for k, v in vars(args).items():
         if k in search_ignore:
             continue
+
+        if k == 'dataset' and k not in desc:
+            desc[k] = 'imagenet'

         if v != desc[k]:
             # if stamp == '210120_Wed_05_19_52':
@@ -157,7 +162,7 @@ def search_same(args):

     if len(ckpt_list) > 0:
         args.snapshot = f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/{ckpt_list[-1]}'
-        initial_epoch = int(ckpt_list[-1].split('_')[0])
+        initial_epoch = int(df['epoch'].iloc[-1]) + 1
     else:
         print('{} Training already finished!!!'.format(stamp))
         return args, -1
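
The 'dataset' backfill in search_same keeps experiment configs saved before this commit comparable: old runs have no 'dataset' key, so they are treated as ImageNet runs instead of raising a KeyError. A toy illustration (the dicts here are made up for the example):

desc = {'task': 'pretext', 'backbone': 'resnet50'}   # config saved before --dataset existed
args = {'task': 'pretext', 'backbone': 'resnet50', 'dataset': 'imagenet'}

for k, v in args.items():
    if k == 'dataset' and k not in desc:
        desc[k] = 'imagenet'        # old runs are assumed to be ImageNet runs
    if v != desc[k]:
        print('differs on', k)      # nothing printed: the configs match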

dataloader.py

Lines changed: 31 additions & 19 deletions
@@ -10,27 +10,32 @@
 AUTO = tf.data.experimental.AUTOTUNE


-def set_dataset(task, data_path):
-    trainset = pd.read_csv(
-        os.path.join(
-            data_path, 'imagenet_trainset.csv'
-        )).values.tolist()
-    trainset = [[os.path.join(data_path, t[0]), t[1]] for t in trainset]
-
-    if task == 'lincls':
+def set_dataset(task, dataset, data_path):
+    if dataset == 'imagenet':
+        trainset = pd.read_csv(
+            os.path.join(
+                data_path, 'imagenet_trainset.csv'
+            )).values.tolist()
+        trainset = [[os.path.join(data_path, t[0]), t[1]] for t in trainset]
+
         valset = pd.read_csv(
             os.path.join(
                 data_path, 'imagenet_valset.csv'
             )).values.tolist()
         valset = [[os.path.join(data_path, t[0]), t[1]] for t in valset]
-        return np.array(trainset, dtype='object'), np.array(valset, dtype='object')
+
+    elif dataset == 'cifar10':
+        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
+        trainset = [[i, l] for i, l in zip(x_train, y_train.flatten())]
+        valset = [[i, l] for i, l in zip(x_test, y_test.flatten())]

-    return np.array(trainset, dtype='object')
+    return np.array(trainset, dtype='object'), np.array(valset, dtype='object')


 class DataLoader:
-    def __init__(self, args, mode, datalist, batch_size, num_workers=1, shuffle=True):
+    def __init__(self, args, task, mode, datalist, batch_size, num_workers=1, shuffle=True):
         self.args = args
+        self.task = task
         self.mode = mode
         self.datalist = datalist
         self.batch_size = batch_size
@@ -50,7 +55,7 @@ def fetch_dataset(self, path, y=None):
         return tf.data.Dataset.from_tensors(x)

     def augmentation(self, img, shape):
-        if self.args.task == 'pretext':
+        if self.task == 'pretext':
             img_list = []
             for _ in range(2): # query, key
                 aug_img = tf.identity(img)
@@ -61,20 +66,25 @@ def augmentation(self, img, shape):
         return self.augset._augment_lincls(img, shape)

     def dataset_parser(self, value, label=None):
-        shape = tf.image.extract_jpeg_shape(value)
-        img = tf.io.decode_jpeg(value, channels=3)
+        if self.args.dataset == 'imagenet':
+            shape = tf.image.extract_jpeg_shape(value)
+            img = tf.io.decode_jpeg(value, channels=3)
+        elif self.args.dataset == 'cifar10':
+            shape = (32, 32, 3)
+            img = value
+
         if label is None:
-            # moco
+            # pretext
            return self.augmentation(img, shape)
         else:
             # lincls
             inputs = self.augmentation(img, shape)
-            labels = tf.one_hot(label, self.args.classes)
-            return (inputs, labels)
+            # labels = tf.one_hot(label, self.args.classes)
+            return (inputs, label)

     def _dataloader(self):
         self.imglist = self.datalist[:,0].tolist()
-        if self.args.task == 'pretext':
+        if self.task == 'pretext':
             dataset = tf.data.Dataset.from_tensor_slices(self.imglist)
         else:
             self.labellist = self.datalist[:,1].tolist()
@@ -84,7 +94,9 @@ def _dataloader(self):
         if self.shuffle:
             dataset = dataset.shuffle(len(self.datalist))

-        dataset = dataset.interleave(self.fetch_dataset, num_parallel_calls=AUTO)
+        if self.args.dataset == 'imagenet':
+            dataset = dataset.interleave(self.fetch_dataset, num_parallel_calls=AUTO)
+
         dataset = dataset.map(self.dataset_parser, num_parallel_calls=AUTO)
         dataset = dataset.batch(self.batch_size)
         dataset = dataset.prefetch(AUTO)
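
Taken together, the CIFAR-10 branches mean raw uint8 arrays flow straight from from_tensor_slices into dataset_parser, skipping the JPEG fetch/interleave stage entirely. A self-contained sketch of the resulting pipeline (the normalization stub stands in for the real augment call; the batch size is illustrative):

import tensorflow as tf

AUTO = tf.data.experimental.AUTOTUNE

(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()

def parse(img, label):
    img = tf.cast(img, tf.float32) / 255.   # placeholder for _augment_lincls
    return img, label

dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train.flatten()))
           .shuffle(len(x_train))           # no interleave/decode for cifar10
           .map(parse, num_parallel_calls=AUTO)
           .batch(256)
           .prefetch(AUTO))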

layer.py

Lines changed: 40 additions & 10 deletions
@@ -1,16 +1,46 @@
 import tensorflow as tf
+from tensorflow.keras.layers import Conv2D
+from tensorflow.keras.layers import Dense
+from tensorflow.keras.layers import BatchNormalization
+from tensorflow.keras.layers.experimental import SyncBatchNormalization
 from tensorflow.keras.initializers import Constant


-class Conv2D(tf.keras.layers.Conv2D):
-    def build(self, input_shape):
-        k = 1 / input_shape[-1]
-        self.kernel_initializer = Constant(tf.random.uniform([], -tf.sqrt(k), tf.sqrt(k)))
-        super(Conv2D, self).build(input_shape)
+BatchNorm_DICT = {
+    "bn": BatchNormalization,
+    "syncbn": SyncBatchNormalization}


-class Dense(tf.keras.layers.Dense):
-    def build(self, input_shape):
-        k = 1 / input_shape[-1]
-        self.kernel_initializer = Constant(tf.random.uniform([], -tf.sqrt(k), tf.sqrt(k)))
-        super(Dense, self).build(input_shape)
+def _conv2d(**custom_kwargs):
+    def _func(*args, **kwargs):
+        kwargs.update(**custom_kwargs)
+        return Conv2D(*args, **kwargs)
+    return _func
+
+
+def _batchnorm(norm='bn', **custom_kwargs):
+    def _func(*args, **kwargs):
+        kwargs.update(**custom_kwargs)
+        return BatchNorm_DICT[norm](*args, **kwargs)
+    return _func
+
+
+def _dense(**custom_kwargs):
+    def _func(*args, **kwargs):
+        kwargs.update(**custom_kwargs)
+        return Dense(*args, **kwargs)
+    return _func
+
+
+# class Conv2D(tf.keras.layers.Conv2D):
+#     def build(self, input_shape):
+#         k = 1 / input_shape[-1]
+#         self.kernel_initializer = Constant(tf.random.uniform([], -tf.sqrt(k), tf.sqrt(k)))
+#         super(Conv2D, self).build(input_shape)
+
+
+# class Dense(tf.keras.layers.Dense):
+#     def build(self, input_shape):
+#         k = 1 / input_shape[-1]
+#         self.kernel_initializer = Constant(tf.random.uniform([], -tf.sqrt(k), tf.sqrt(k)))
+#         super(Dense, self).build(input_shape)
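
A sketch of how these factories might be used elsewhere in the repo; the kwarg values are illustrative, not taken from this commit. Each factory pins shared kwargs once and then behaves like the underlying layer class:

from layer import _conv2d, _dense, _batchnorm

Conv2D = _conv2d(kernel_initializer='he_normal', use_bias=False)
Dense = _dense(kernel_initializer='he_normal')
BatchNorm = _batchnorm(norm='syncbn', momentum=.9, epsilon=1e-5)

conv = Conv2D(64, 3, strides=1, padding='same', name='conv1')   # tf.keras Conv2D
bn = BatchNorm(name='bn1')                                      # SyncBatchNormalization
fc = Dense(10, name='logits')                                   # tf.keras Dense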

main.py

Lines changed: 24 additions & 13 deletions
@@ -21,14 +21,17 @@ def train_pretext(args, logger, initial_epoch, strategy, num_workers):
     ##########################
     # Dataset
     ##########################
-    trainset = set_dataset(args.task, args.data_path)
+    trainset, valset = set_dataset(args.task, args.dataset, args.data_path)
     steps_per_epoch = args.steps or len(trainset) // args.batch_size

     logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
     logger.info("========== TRAINSET ==========")
     logger.info(f" --> {len(trainset)}")
     logger.info(f" --> {steps_per_epoch}")

+    logger.info("=========== VALSET ===========")
+    logger.info(f" --> {len(valset)}")
+

     ##########################
     # Model & Generator
@@ -41,8 +44,9 @@ def train_pretext(args, logger, initial_epoch, strategy, num_workers):
             optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
             loss=tf.keras.losses.cosine_similarity,
             run_eagerly=False)
-
-    train_generator = DataLoader(args, 'train', trainset, args.batch_size, num_workers).dataloader
+
+    train_generator = DataLoader(args, args.task, 'train', trainset, args.batch_size, num_workers).dataloader
+

     ##########################
     # Train
@@ -63,11 +67,11 @@ def train_pretext(args, logger, initial_epoch, strategy, num_workers):


 def train_lincls(args, logger, initial_epoch, strategy, num_workers):
-    assert args.snapshot is not None, 'pretrained weight is needed!'
+    # assert args.snapshot is not None, 'pretrained weight is needed!'
     ##########################
     # Dataset
     ##########################
-    trainset, valset = set_dataset(args.task, args.data_path)
+    trainset, valset = set_dataset(args.task, args.dataset, args.data_path)
     steps_per_epoch = args.steps or len(trainset) // args.batch_size
     validation_steps = len(valset) // args.batch_size

@@ -84,19 +88,20 @@ def train_lincls(args, logger, initial_epoch, strategy, num_workers):
     ##########################
     # Model & Generator
     ##########################
+    train_generator = DataLoader(args, args.task, 'train', trainset, args.batch_size, num_workers).dataloader
+    val_generator = DataLoader(args, args.task, 'val', valset, args.batch_size, num_workers).dataloader
+
     with strategy.scope():
         backbone = SimSiam(args, logger)
         model = set_lincls(args, backbone.encoder)

         lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
         model.compile(
             optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
-            metrics=[tf.keras.metrics.TopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
-                     tf.keras.metrics.TopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32)],
-            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, name='loss'))
-
-    train_generator = DataLoader(args, 'train', trainset, args.batch_size, num_workers).dataloader
-    val_generator = DataLoader(args, 'val', valset, args.batch_size, num_workers).dataloader
+            metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
+                     tf.keras.metrics.SparseTopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32)],
+            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, name='loss'),
+            run_eagerly=False)


     ##########################
@@ -123,9 +128,15 @@ def main():
     set_seed()
     args = get_arguments()
     if args.task == 'pretext':
-        args.lr = 0.5 * float(args.batch_size / 256)
+        if args.dataset == 'imagenet':
+            args.lr = 0.5 * float(args.batch_size / 256)
+        elif args.dataset == 'cifar10':
+            args.lr = 0.03 * float(args.batch_size / 256)
     else:
-        args.lr = 30. * float(args.batch_size / 256)
+        if args.dataset == 'imagenet' and args.freeze:
+            args.lr = 30. * float(args.batch_size / 256)
+        else:  # args.dataset == 'cifar10'
+            args.lr = 1.8 * float(args.batch_size / 256)

     args, initial_epoch = search_same(args)
     if initial_epoch == -1:
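
All of the learning rates above follow the linear scaling rule lr = base_lr * batch_size / 256, with the base defined at batch size 256. Worked for two batch sizes using the CIFAR-10 bases from this diff:

for bs in (256, 512):
    print(bs, 0.03 * bs / 256,   # cifar10 pretext lr
              1.8 * bs / 256)    # unfrozen / cifar10 lincls lr
# 256 -> 0.03 (pretext), 1.8 (lincls)
# 512 -> 0.06 (pretext), 3.6 (lincls)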
