Skip to content

Commit d078be9

Browse files
committed
fix lsun data bug
1 parent 02ecd09 commit d078be9

32 files changed

+1863
-473
lines changed

gan/celeba_dataset.py gan/Styleformer/celeba_dataset.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""
16-
CelebA Dataset related classes and methods
16+
CelebA Dataset related classes and methods
1717
Currently only supports GAN
1818
"""
1919

@@ -44,17 +44,17 @@ def __len__(self):
4444
return len(self.img_path_list)
4545

4646
def __getitem__(self, index):
47-
data = Image.open(self.img_path_list[index]).convert('RGB')
47+
img = Image.open(self.img_path_list[index]).convert('RGB')
4848
if self.transform is not None:
49-
data = self.transform(data)
49+
img = self.transform(img)
5050
label = 0
51-
return data, label
52-
53-
if __name__ == "__main__":
54-
dataset = CelebADataset(file_folder='./celeba/img_align_celeba')
55-
for idx, (data, label) in enumerate(dataset):
56-
print(idx)
57-
print(data.size)
58-
print('-----')
59-
if idx == 10:
60-
break
51+
return img, label
52+
53+
#if __name__ == "__main__":
54+
# dataset = CelebADataset(file_folder='./celeba/img_align_celeba')
55+
# for idx, (data, label) in enumerate(dataset):
56+
# print(idx)
57+
# print(data.size)
58+
# print('-----')
59+
# if idx == 10:
60+
# break

gan/Styleformer/config.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
Configuration for data, model architecture, and training, etc.
1818
Config can be set by .yaml file or by argparser(limited usage)
1919
20-
2120
"""
2221

2322
import os
@@ -37,6 +36,8 @@
3736
_C.DATA.CHANNEL = 3 # input image channel
3837
_C.DATA.CROP_PCT = 1.0 # input image scale ratio, scale is applied before centercrop in eval mode
3938
_C.DATA.NUM_WORKERS = 2 # number of data loading threads
39+
_C.DATA.MAX_REAL_NUM = None # number of images used in the dataset (real images)
40+
_C.DATA.MAX_GEN_NUM = None # number of images used in the generator (fake images)
4041

4142
# model settings
4243
_C.MODEL = CN()
@@ -73,7 +74,6 @@
7374
_C.TRAIN.WARMUP_START_LR = 0.0
7475
_C.TRAIN.END_LR = 0.0
7576
_C.TRAIN.GRAD_CLIP = 1.0
76-
_C.TRAIN.ACCUM_ITER = 2
7777

7878
_C.TRAIN.LR_SCHEDULER = CN()
7979
_C.TRAIN.LR_SCHEDULER.NAME = 'warmupcosine'
@@ -161,7 +161,9 @@ def update_config(config, args):
161161
return config
162162

163163

164-
def get_config():
165-
"""Return a clone config"""
164+
def get_config(cfg_file=None):
165+
"""Return a clone of config or load from yaml file"""
166166
config = _C.clone()
167+
if cfg_file:
168+
_update_config_from_file(config, cfg_file)
167169
return config

gan/Styleformer/configs/styleformer_celeba.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
DATA:
22
IMAGE_SIZE: 64
3+
MAX_GEN_NUM: 50000
4+
MAX_REAL_NUM: None
35
MODEL:
46
TYPE: Styleformer
57
NAME: Styleformer_Linformer
8+
NUM_CLASSES: 10177
69
GEN:
710
RESOLUTION: 8
811
NUM_LAYERS: [1,2,1,1]

gan/Styleformer/configs/styleformer_cifar10.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
DATA:
22
IMAGE_SIZE: 32
3+
MAX_GEN_NUM: 50000
4+
MAX_REAL_NUM: None
35
MODEL:
46
TYPE: Styleformer
57
NAME: Styleformer_Large
8+
NUM_CLASSES: 10
69
GEN:
710
RESOLUTION: 8
811
NUM_LAYERS: [1,3,3]

gan/Styleformer/configs/styleformer_lsun.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
DATA:
22
IMAGE_SIZE: 128
3+
MAX_GEN_NUM: 50000
4+
MAX_REAL_NUM: None
35
MODEL:
46
TYPE: Styleformer
57
NAME: Styleformer_Linformer
8+
NUM_CLASSES: 1
69
GEN:
710
RESOLUTION: 8
811
NUM_LAYERS: [1,2,1,1,1]

gan/Styleformer/configs/styleformer_stl10.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
DATA:
22
IMAGE_SIZE: 48
3+
MAX_GEN_NUM: 50000
4+
MAX_REAL_NUM: None
35
MODEL:
46
TYPE: Styleformer
57
NAME: Styleformer_Medium
8+
NUM_CLASSES: 1 # unlabeled data, all class 0
69
GEN:
710
RESOLUTION: 12
811
NUM_LAYERS: [1,3,3]

gan/Styleformer/datasets.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919

2020
import os
2121
import math
22-
import sys
23-
import numpy as np
24-
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
25-
from paddle.vision import transforms, datasets, image_load
26-
sys.path.append('../')
27-
from stl10_dataset import *
28-
from lsun_church_dataset import *
29-
from celeba_dataset import *
22+
from paddle.io import Dataset
23+
from paddle.io import DataLoader
24+
from paddle.io import DistributedBatchSampler
25+
from paddle.vision import transforms
26+
from paddle.vision import datasets
27+
from paddle.vision import image_load
28+
from stl10_dataset import STL10Dataset
29+
from lsun_church_dataset import LSUNchurchDataset
30+
from celeba_dataset import CelebADataset
3031

3132
class ImageNet2012Dataset(Dataset):
3233
"""Build ImageNet2012 dataset
@@ -140,7 +141,8 @@ def get_dataset(config, mode='train'):
140141
mode=mode,
141142
transform=get_train_transforms(config))
142143
else:
143-
mode = 'test'
144+
#mode = 'test'
145+
mode = 'unlabeled'
144146
dataset = STL10Dataset(file_folder=config.DATA.DATA_PATH,
145147
mode=mode,
146148
transform=get_val_transforms(config))
@@ -156,10 +158,10 @@ def get_dataset(config, mode='train'):
156158
elif config.DATA.DATASET == "celeba":
157159
if mode == 'train':
158160
dataset = CelebADataset(file_folder=config.DATA.DATA_PATH,
159-
transform=get_train_transforms(config))
161+
transform=get_train_transforms(config))
160162
else:
161163
dataset = CelebADataset(file_folder=config.DATA.DATA_PATH,
162-
transform=get_val_transforms(config))
164+
transform=get_val_transforms(config))
163165
elif config.DATA.DATASET == "imagenet2012":
164166
if mode == 'train':
165167
dataset = ImageNet2012Dataset(config.DATA.DATA_PATH,
@@ -171,7 +173,7 @@ def get_dataset(config, mode='train'):
171173
transform=get_val_transforms(config))
172174
else:
173175
raise NotImplementedError(
174-
"[{config.DATA.DATASET}] Only cifar10, cifar100, imagenet2012 are supported now")
176+
"Only support cifar10, cifar100, imagenet2012, celeba, stl10, lsun")
175177
return dataset
176178

177179

gan/Styleformer/discriminator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
1+
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -167,4 +167,4 @@ def forward(self, input):
167167
out = out.reshape((batch, -1))
168168
out = self.final_linear(out)
169169

170-
return out
170+
return out

gan/Styleformer/generate.py

+47-26
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
1+
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -12,47 +12,68 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""Generate images using trained models"""
1516
import argparse
16-
import numpy
17-
import paddle
17+
import os
1818
from PIL import Image
19+
import paddle
1920
from generator import Generator
20-
from config import *
21+
from config import get_config
22+
from config import update_config
2123

2224

2325
def main():
24-
# get default config
25-
parser = argparse.ArgumentParser('')
26+
""" generate sample images using pretrained model
27+
The following args are required:
28+
-cfg: str, path of yaml model config file
29+
-pretrained: str, path of the pretrained model (ends with .pdparams)
30+
-num_out_images: int, the num of output images to be saved in file
31+
-out_folder: str, output folder path.
32+
"""
33+
paddle.set_device('gpu')
34+
# get config
35+
parser = argparse.ArgumentParser('Generate samples images')
2636
parser.add_argument('-cfg', type=str, default='./configs/styleformer_cifar10.yaml')
27-
parser.add_argument('-dataset', type=str, default="cifar10")
37+
parser.add_argument('-pretrained', type=str, default='./lsun.pdparams')
38+
parser.add_argument('-num_out_images', type=int, default=16)
39+
parser.add_argument('-out_folder', type=str, default='./out_images_lsun')
40+
41+
parser.add_argument('-dataset', type=str, default=None)
2842
parser.add_argument('-batch_size', type=int, default=None)
2943
parser.add_argument('-image_size', type=int, default=None)
3044
parser.add_argument('-ngpus', type=int, default=None)
31-
parser.add_argument('-data_path', type=str, default='/dataset/cifar10/')
45+
parser.add_argument('-data_path', type=str, default=None)
3246
parser.add_argument('-eval', action="store_true")
33-
parser.add_argument('-pretrained', type=str, default=None)
47+
3448
args = parser.parse_args()
3549
config = get_config()
3650
config = update_config(config, args)
37-
38-
paddle.set_device('cpu')
51+
# get model
52+
print(f'----- Creating model...')
3953
paddle_model = Generator(config)
4054
paddle_model.eval()
41-
42-
pre=paddle.load(r'./cifar10.pdparams')
43-
paddle_model.load_dict(pre)
44-
45-
x = paddle.randn([32, 512])
46-
x_paddle = paddle.to_tensor(x)
47-
out_paddle = paddle_model(x_paddle, c=paddle.randint(0, 10, [32]))
48-
49-
gen_imgs=paddle.multiply(out_paddle,paddle.to_tensor(127.5))
50-
gen_imgs=paddle.clip(paddle.add(gen_imgs,paddle.to_tensor(127.5)).transpose((0,2,3,1)),
51-
min=0.0,max=255.0).astype('uint8').cpu().numpy()
52-
53-
for i in range(len(gen_imgs)):
54-
im = Image.fromarray(gen_imgs[i], 'RGB')
55-
im.save("./image/"+str(i)+".png")
55+
# load model weights
56+
print(f'----- Loading model form {config.MODEL.PRETRAINED}...')
57+
model_state_dict = paddle.load(config.MODEL.PRETRAINED)
58+
paddle_model.load_dict(model_state_dict)
59+
# get random input tensor
60+
x_paddle = paddle.randn([args.num_out_images, paddle_model.z_dim])
61+
# inference
62+
print(f'----- Inferencing...')
63+
out_paddle = paddle_model(
64+
z=x_paddle, c=paddle.randint(0, config.MODEL.NUM_CLASSES, [args.num_out_images]))
65+
# post processing to obtain image
66+
print('----- Postprocessing')
67+
gen_imgs = (out_paddle * 127.5 + 128).clip(0, 255)
68+
gen_imgs = gen_imgs.transpose((0, 2, 3, 1)).astype('uint8')
69+
gen_imgs = gen_imgs.cpu().numpy()
70+
# save images to file
71+
os.makedirs(args.out_folder, exist_ok=True)
72+
print(f'----- Saving images to {args.out_folder}')
73+
for i, gen_img in enumerate(gen_imgs):
74+
img = Image.fromarray(gen_img, 'RGB')
75+
out_path = os.path.join(args.out_folder, str(i) + '.png')
76+
img.save(out_path)
5677

5778

5879
if __name__ == "__main__":

0 commit comments

Comments
 (0)