Skip to content

Commit e416cd1

Browse files
committed
feat: fix save model
1 parent ab4db46 commit e416cd1

12 files changed

+9
-5
lines changed

.DS_Store

0 Bytes
Binary file not shown.

apps/.DS_Store

6 KB
Binary file not shown.

apps/GAN/.DS_Store

0 Bytes
Binary file not shown.

apps/GAN/person_face_gen.zip

26.6 KB
Binary file not shown.

apps/GAN/person_face_gen/.DS_Store

6 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

apps/GAN/person_face_gen/const.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# 图片尺寸
66
image_size = 64
77
# 数据load多线程处理数
8-
workers = 2
8+
workers = 5
99
# 生成器特征数
1010
ngf = 64
1111
# 判别器特征数
@@ -16,10 +16,10 @@
1616
nc=3
1717

1818
# Number of GPUs available. Use 0 for CPU mode.
19-
ngpu = 0
19+
ngpu = 1
2020

2121
# Number of training epochs
22-
num_epochs = 5
22+
num_epochs = 1
2323

2424
# Learning rate for optimizers
2525
lr = 0.0002

apps/GAN/person_face_gen/gan.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from device import device
1515
from plot import plot_original_dataset,loss_plot,image_show
1616
from const import ngpu, nz, ngf, nc, ndf, batch_size, workers, image_size
17+
from generator import Generator
18+
from discriminator import Discriminator
1719

1820
if __name__ == '__main__':
1921
#加载数据
@@ -30,8 +32,9 @@
3032
#绘制生成图像
3133
image_show(img_list)
3234
#保存模型
33-
netD.save_model('./models')
34-
netG.save_model('./models')
35+
Generator.save_model(netG,'./models')
36+
Discriminator.save(netD,'./models')
37+
3538
#保存图片
3639
os.makedirs('./images', exist_ok=True)
3740
vutils.save_image(img_list[-1], './images/fake_images.png', normalize=True)

apps/GAN/person_face_gen/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def get_generator(nz, ngf, nc):
3535
# to ``mean=0``, ``stdev=0.02``.
3636
netG.apply(weights_init)
3737

38+
3839
return netG
3940

4041

0 commit comments

Comments
 (0)