Skip to content

【Hackathon 8th No.23】Improved Training of Wasserstein GANs 论文复现 #1146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 66 commits into
base: develop
Choose a base branch
from

Conversation

XvLingWYY
Copy link

PR types

New Features

PR changes

Others

Describe

Add wgangp

Copy link

paddle-bot bot commented Apr 27, 2025

Thanks for your contribution!

@XvLingWYY
Copy link
Author

@robinbg

@XvLingWYY XvLingWYY changed the title 【Hackathon 8th No.23】RFC:Improved Training of Wasserstein GANs 论文复现 【Hackathon 8th No.23】Improved Training of Wasserstein GANs 论文复现 Apr 27, 2025
@luotao1
Copy link
Collaborator

luotao1 commented May 12, 2025

@lijialin03
Copy link
Contributor

对于相同目的的反复修改,在git push之前在本地修改的时候,可以使用git commit --amend,而不是git commit -m "新的信息",这样可以把修改改在上一个commit上,然后再push就会更新commit而不是创建新的

@XvLingWYY
Copy link
Author

通过网盘分享的文件:models.zip
链接: https://pan.baidu.com/s/1ekP99D2ylox3yzqxQkJkyw?pwd=wgan 提取码: wgan
这是更改后的模型和可视化内容,在之前的基础上增加了MNIST和CIFAR10的可视化图片。
由于paddle和tensorflow中inception v3模型不一样,以及模型转化始终报错导致使用我的实现的S分数评估模型模型只有5.2左右的分数,所以在CIFAR10中附加了模型评估代码使用的是原作者仓库中的模型。使用原作者仓库中的模型IS分数可以达到要求。
对于mnist当前参数训练后再评估,mse loss约0.47,在论文中作者并未对MNIST数据集上的实验进行mse loss评估,从可视化的图片中可以看见模型生成的图片与真实图片一样都是手写数字并且都可以识别出内容并非不同。

@lijialin03
Copy link
Contributor

通过网盘分享的文件:models.zip 链接: https://pan.baidu.com/s/1ekP99D2ylox3yzqxQkJkyw?pwd=wgan 提取码: wgan 这是更改后的模型和可视化内容,在之前的基础上增加了MNIST和CIFAR10的可视化图片。 由于paddle和tensorflow中inception v3模型不一样,以及模型转化始终报错导致使用我的实现的S分数评估模型模型只有5.2左右的分数,所以在CIFAR10中附加了模型评估代码使用的是原作者仓库中的模型。使用原作者仓库中的模型IS分数可以达到要求。 对于mnist当前参数训练后再评估,mse loss约0.47,在论文中作者并未对MNIST数据集上的实验进行mse loss评估,从可视化的图片中可以看见模型生成的图片与真实图片一样都是手写数字并且都可以识别出内容并非不同。

网盘里的包下载打开后,mnist这个目录中的图片是:
434f5fd1be3f008a0bee835be11e9e4f
imagex和image_real_x的标号是一一对应的吗,如果是的话这看起来不是都不相同吗……?

@XvLingWYY
Copy link
Author

imagex和image_real_x的图片两个并没有任何关系,只是恰好出现在一个batch_size里

@XvLingWYY
Copy link
Author

当向不了解生成对抗网络(GAN)的人解释为什么在每个batch size中真实图片和生成的图片并不一样,不能使用均方误差(MSE)损失评估时,可以这样来解释:

想象你正在教一个孩子画画。这个孩子刚开始学习如何画猫。每次你给他看一张真实的猫的照片,并让他尝试画出这张照片中的猫。但这里有个问题:孩子每次画出来的猫都不一样,而且与任何他之前看到的真实猫的照片也不完全相同。

如果我们用一种非常简单的方式来评价孩子的作品,比如说“看看你的画和这张照片里的猫有多接近”,我们实际上是在测量两幅图像之间的像素差异——这就像MSE损失所做的。但是,在这种情况下,这种方法并不是很有帮助,因为目标不是让孩子精确地复制每一张照片,而是要让他们能够画出看起来像猫的图像,即使这些图像没有一张与特定的照片一模一样。

在GANs中,生成器的任务是创建看起来逼真的图像,而不是精确复制训练集中的图像。因此,使用MSE这样的度量标准来比较生成的图像和真实图像并不合适,因为它会惩罚那些虽然看起来很真实但与具体某张训练图像不同的生成结果。相反,GANs使用了一种更智能的方法,即通过判别器网络来判断一张图像是来自真实数据分布还是由生成器创造的,从而引导生成器产生更加逼真的图像。

总结来说,关键点在于GAN的目标是让生成的图像具有足够的真实感,而不是让它们与某个特定的真实图像尽可能相似。这就需要采用不同于MSE的评价方法,比如对抗性损失,来指导模型的学习过程。

@lijialin03
Copy link
Contributor

当向不了解生成对抗网络(GAN)的人解释为什么在每个batch size中真实图片和生成的图片并不一样,不能使用均方误差(MSE)损失评估时,可以这样来解释:

想象你正在教一个孩子画画。这个孩子刚开始学习如何画猫。每次你给他看一张真实的猫的照片,并让他尝试画出这张照片中的猫。但这里有个问题:孩子每次画出来的猫都不一样,而且与任何他之前看到的真实猫的照片也不完全相同。

如果我们用一种非常简单的方式来评价孩子的作品,比如说“看看你的画和这张照片里的猫有多接近”,我们实际上是在测量两幅图像之间的像素差异——这就像MSE损失所做的。但是,在这种情况下,这种方法并不是很有帮助,因为目标不是让孩子精确地复制每一张照片,而是要让他们能够画出看起来像猫的图像,即使这些图像没有一张与特定的照片一模一样。

在GANs中,生成器的任务是创建看起来逼真的图像,而不是精确复制训练集中的图像。因此,使用MSE这样的度量标准来比较生成的图像和真实图像并不合适,因为它会惩罚那些虽然看起来很真实但与具体某张训练图像不同的生成结果。相反,GANs使用了一种更智能的方法,即通过判别器网络来判断一张图像是来自真实数据分布还是由生成器创造的,从而引导生成器产生更加逼真的图像。

总结来说,关键点在于GAN的目标是让生成的图像具有足够的真实感,而不是让它们与某个特定的真实图像尽可能相似。这就需要采用不同于MSE的评价方法,比如对抗性损失,来指导模型的学习过程。

好的,明白了,谢谢

@XvLingWYY
Copy link
Author

老师距离上次提交已经过了4个工作日了,还有需要修改的地方吗?

Copy link
Contributor

@lijialin03 lijialin03 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

运行结果上没什么问题了,但是还有下面这些点可能容易让人疑惑,还可以再优化一下吗?
另外改完之后点resolved就行,不用再逐条回复了,谢谢

Comment on lines 82 to 86
for i in range(
cfg["EVAL"]["batch_size"]
if cfg["EVAL"]["batch_size"] < cfg.VIS.num
else cfg.VIS.num
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个应该是因为batch_size变为1之后影响了BatchNorm,测试了一下batch_size最少要在8左右基本对结果没有影响。另外还有个问题是为什么batch_size变为1之后,eval过程的loss值没有变化(且都很大),我看代码loss里是判断的当前图片和真实图片的判别器结果的差异,所以说不论batch_size变不变,判别器都认为是假图片吗

@XvLingWYY
Copy link
Author

”这个应该是因为batch_size变为1之后影响了BatchNorm,测试了一下batch_size最少要在8左右基本对结果没有影响。“
generator_model.eval()
添加这一行后可以有效解决图片质量下降的问题
“另外还有个问题是为什么batch_size变为1之后,eval过程的loss值没有变化(且都很大),我看代码loss里是判断的当前图片和真实图片的判别器结果的差异,所以说不论batch_size变不变,判别器都认为是假图片吗”
我把Cifar10GenFuncs中的fake_data替换为real_data后发现输入真实图片后结果相同loss值没有变化(且都很大)所以模型只是分不清,并且图片的质量主要是靠IS分数进行评估。

Copy link
Contributor

@lijialin03 lijialin03 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants