|
1 |
| -# Copyright (c) 2021 PPViT Authors. All Rights Reserved. |
| 1 | +# Copyright (c) 2021 PPViT Authors. All Rights Reserved. |
2 | 2 | #
|
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License");
|
4 | 4 | # you may not use this file except in compliance with the License.
|
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +"""Generate images using trained models""" |
15 | 16 | import argparse
|
16 |
| -import numpy |
17 |
| -import paddle |
| 17 | +import os |
18 | 18 | from PIL import Image
|
| 19 | +import paddle |
19 | 20 | from generator import Generator
|
20 |
| -from config import * |
| 21 | +from config import get_config |
| 22 | +from config import update_config |
21 | 23 |
|
22 | 24 |
|
23 | 25 | def main():
|
24 |
| - # get default config |
25 |
| - parser = argparse.ArgumentParser('') |
| 26 | + """ generate sample images using pretrained model |
| 27 | + The following args are required: |
| 28 | + -cfg: str, path of yaml model config file |
| 29 | + -pretrained: str, path of the pretrained model (ends with .pdparams) |
| 30 | + -num_out_images: int, the num of output images to be saved in file |
| 31 | + -out_folder: str, output folder path. |
| 32 | + """ |
| 33 | + paddle.set_device('gpu') |
| 34 | + # get config |
| 35 | + parser = argparse.ArgumentParser('Generate samples images') |
26 | 36 | parser.add_argument('-cfg', type=str, default='./configs/styleformer_cifar10.yaml')
|
27 |
| - parser.add_argument('-dataset', type=str, default="cifar10") |
| 37 | + parser.add_argument('-pretrained', type=str, default='./lsun.pdparams') |
| 38 | + parser.add_argument('-num_out_images', type=int, default=16) |
| 39 | + parser.add_argument('-out_folder', type=str, default='./out_images_lsun') |
| 40 | + |
| 41 | + parser.add_argument('-dataset', type=str, default=None) |
28 | 42 | parser.add_argument('-batch_size', type=int, default=None)
|
29 | 43 | parser.add_argument('-image_size', type=int, default=None)
|
30 | 44 | parser.add_argument('-ngpus', type=int, default=None)
|
31 |
| - parser.add_argument('-data_path', type=str, default='/dataset/cifar10/') |
| 45 | + parser.add_argument('-data_path', type=str, default=None) |
32 | 46 | parser.add_argument('-eval', action="store_true")
|
33 |
| - parser.add_argument('-pretrained', type=str, default=None) |
| 47 | + |
34 | 48 | args = parser.parse_args()
|
35 | 49 | config = get_config()
|
36 | 50 | config = update_config(config, args)
|
37 |
| - |
38 |
| - paddle.set_device('cpu') |
| 51 | + # get model |
| 52 | + print(f'----- Creating model...') |
39 | 53 | paddle_model = Generator(config)
|
40 | 54 | paddle_model.eval()
|
41 |
| - |
42 |
| - pre=paddle.load(r'./cifar10.pdparams') |
43 |
| - paddle_model.load_dict(pre) |
44 |
| - |
45 |
| - x = paddle.randn([32, 512]) |
46 |
| - x_paddle = paddle.to_tensor(x) |
47 |
| - out_paddle = paddle_model(x_paddle, c=paddle.randint(0, 10, [32])) |
48 |
| - |
49 |
| - gen_imgs=paddle.multiply(out_paddle,paddle.to_tensor(127.5)) |
50 |
| - gen_imgs=paddle.clip(paddle.add(gen_imgs,paddle.to_tensor(127.5)).transpose((0,2,3,1)), |
51 |
| - min=0.0,max=255.0).astype('uint8').cpu().numpy() |
52 |
| - |
53 |
| - for i in range(len(gen_imgs)): |
54 |
| - im = Image.fromarray(gen_imgs[i], 'RGB') |
55 |
| - im.save("./image/"+str(i)+".png") |
| 55 | + # load model weights |
| 56 | + print(f'----- Loading model form {config.MODEL.PRETRAINED}...') |
| 57 | + model_state_dict = paddle.load(config.MODEL.PRETRAINED) |
| 58 | + paddle_model.load_dict(model_state_dict) |
| 59 | + # get random input tensor |
| 60 | + x_paddle = paddle.randn([args.num_out_images, paddle_model.z_dim]) |
| 61 | + # inference |
| 62 | + print(f'----- Inferencing...') |
| 63 | + out_paddle = paddle_model( |
| 64 | + z=x_paddle, c=paddle.randint(0, config.MODEL.NUM_CLASSES, [args.num_out_images])) |
| 65 | + # post processing to obtain image |
| 66 | + print('----- Postprocessing') |
| 67 | + gen_imgs = (out_paddle * 127.5 + 128).clip(0, 255) |
| 68 | + gen_imgs = gen_imgs.transpose((0, 2, 3, 1)).astype('uint8') |
| 69 | + gen_imgs = gen_imgs.cpu().numpy() |
| 70 | + # save images to file |
| 71 | + os.makedirs(args.out_folder, exist_ok=True) |
| 72 | + print(f'----- Saving images to {args.out_folder}') |
| 73 | + for i, gen_img in enumerate(gen_imgs): |
| 74 | + img = Image.fromarray(gen_img, 'RGB') |
| 75 | + out_path = os.path.join(args.out_folder, str(i) + '.png') |
| 76 | + img.save(out_path) |
56 | 77 |
|
57 | 78 |
|
58 | 79 | if __name__ == "__main__":
|
|
0 commit comments