Skip to content

【Hackathon 8th No.23】Improved Training of Wasserstein GANs 论文复现 #1152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
| 天气预报 | [DGMR 气象预报](./zh/examples/dgmr.md) | 数据驱动 | GAN | 监督学习 | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) |
| 地震波形反演 | [VelocityGAN 地震波形反演](./zh/examples/velocity_gan.md) | 数据驱动 | VelocityGAN | 监督学习 | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) |
| 交通预测 | [TGCN 交通流量预测](./zh/examples/tgcn.md) | 数据驱动 | GCN & CNN | 监督学习 | [PEMSD4 & PEMSD8](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tgcn/tgcn_data.zip) | - |
| 生成模型| [图像生成中的梯度惩罚应用](./zh/examples/wgan_gp.md)|数据驱动|WGAN GP|监督学习|[Data1](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)<br>[Data2](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz)| [Paper](https://github.com/igul222/improved_wgan_training) |

<br>
<p align="center"><b>化学科学 (AI for Chemistry)</b></p>
Expand Down
350 changes: 350 additions & 0 deletions docs/zh/examples/wgan_gp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,350 @@
# WGANGP

!!! note

1. 运行之前将[Cifar10](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)下载,并更新wgangp_cifar10.yaml中的data_path
2. 运行之前将[MINST](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz)下载,并更新wgangp_mnist.yaml中的data_path

=== "模型训练命令"
```sh
python wgangp_cifar10.py
```
```sh
python wgangp_mnist.py
```
```sh
python wgangp_toy.py
```

=== "模型评估命令"
```sh
python wgangp_cifar10.py mode=eval
```
```sh
python wgangp_mnist.py mode=eval
```
```sh
python wgangp_toy.py mode=eval
```


| 预训练模型 |
|:-----------------------------------|
| wgangp_cifar10_pretrained.pdparams |
| wgangp_mnist_pretrained.pdparams |
| wgangp_toy_pretrained.pdparams |
## 1. 背景简介
在数字图像处理和机器学习领域,生成对抗网络(GANs)因其卓越的图像生成能力而受到广泛关注。然而,传统的GAN架构在训练过程中可能会遇到不稳定的问题,尤其是在生成高分辨率或复杂场景的图像时。为了解决这些问题,研究人员提出了带有梯度惩罚的Wasserstein生成对抗网络(WGAN-GP),它不仅增强了训练过程的稳定性,还显著提升了生成图像的质量。

WGAN-GP通过改进损失函数来最小化真实数据分布与生成数据分布之间的差异,并引入梯度惩罚机制以确保训练过程中的平滑性和稳定性。这种优化方法克服了传统GAN中常见的模式崩溃问题,同时促进了更高效的训练和更逼真的图像生成。

## 2. 模型原理
WGAN-GP提出一种替代权重剪裁的方法:对评论者输入梯度的范数施加惩罚。在几乎无需超参数调整的情况下稳定训练多种GAN架构.

### 2.1 模型结构

WGAN-GP是一个条件对抗网络,包含了一个noise-to-image的生成器和一个CNN的判别器。下面显示了模型的整体结构。

```
noise===>generator===>fake_image==
==>discriminator===>Wasserstein Loss+Gradient Penalty
image==
```

- `Generator`是一种卷积神经网络。

- `Discriminator`是由卷积块组成的模型。输入图像,输出图像的真实性分数。

### 2.2 损失函数

判别器的损失函数采用了Wasserstein损失和梯度惩罚。其表达式为:

$$
L_d = \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}} D(\tilde{x}) - \underset{x \sim \mathbb{P}_r}{\mathbb{E}}D(x) + \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \left[ \left( \| \nabla_{\hat{x}} D(\hat{x}) \|_2 - 1 \right)^2 \right]
$$

其中$\mathbb{P}_g$是生成器的分布,$\mathbb{P}_r$是真实数据的分布,$\mathbb{P}_{\hat{x}}$是来自$\mathbb{P}_g$和$\mathbb{P}_r$的混合插值样本。

生成器的损失函数是对抗性损失[$- \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x})$]和内容损失(MAE、MSE)的组合。其表达式为:

$$
L_g = - \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x})
$$
其中$\mathbb{P}_g$是生成器的分布

## 3. 模型构建

接下来开始讲解如何使用PaddleScience框架实现WGAN-GP。以下内容仅对关键步骤进行阐述,其余细节请参考 [API文档](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/api/arch/)。

### 3.1 数据集介绍

数据集采用了[Cifar10](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)数据集、[MNIST](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz)和玩具数据集(swissroll/8gaussians/25gaussians)。

Cifar10数据集包含60000张32x32彩色图像,共分为10个类别,每个类别6000张图像。

Cifar10数据集有3个版本

| Version | Size | md5sum |
|:-----------------|:------------|:-------------------------------------|
| CIFAR-100 python | 161 MB | eb9058c3a382ffc7106e4002c42a8d85 |
| CIFAR-100 Matlab | 175 MB | 6a4bfa1dcd5c9453dda6bb54194911f4 |
| CIFAR-100 binary | 161 MB | 03b5dce01913d631647c71ecec9e9cb8 |

本实现使用的为CIFAR-100 python版本

MNIST数据集包含60000张28x28灰度图像,共分为10个类别,每个类别6000张图像。

玩具数据集

Swissroll:三维非线性流形数据集,呈现连续卷曲的螺旋结构,

8gaussians:二维合成数据集,包含八个对称分布的高斯簇,各簇中心均匀分布于圆周,

25gaussians:高密度高斯混合数据集,由25个规则排列的二维高斯分布构成,簇间距紧凑。

### 3.2 构建dataset API

由于Cifar10数据集由5个数据文件组成,由于数据集组织方式,我们无法直接使用PaddleScience内置的dataset API,所以先把所有数据读取出来,再使用```ppsci.data.dataset.array_dataset.NamedArrayDataset```。

下面给出Cifar10数据集读取的代码:
``` py linenums="167"
--8<--
examples/wgangp/functions.py:167:177
--8<--
```
其中`data_path`传入的是CIFAR-10的路径。

下面给出dataloader的配置代码:
``` py linenums="112"
--8<--
examples/wgangp/wgangp_cifar10.py:112:126
--8<--
```

由于MNIST数据集无法直接使用PaddleScience内置的dataset API,所以先把所有数据读取出来,再使用```ppsci.data.dataset.array_dataset.NamedArrayDataset```。

下面给出MNIST数据集读取的代码:
``` py linenums="368"
--8<--
examples/wgangp/functions.py:368:377
--8<--
```

下面给出dataloader的配置代码:
``` py linenums="101"
--8<--
examples/wgangp/wgangp_mnist.py:101:114
--8<--
```

由于玩具数据集无法直接使用PaddleScience内置的dataset API,所以先把所有数据生成出来,再使用```ppsci.data.dataset.array_dataset.NamedArrayDataset```。

下面给出玩具数据集的生成代码
``` py linenums="194"
--8<--
examples/wgangp/functions.py:194:236
--8<--
```

下面给出dataloader的配置代码:
``` py linenums="94"
--8<--
examples/wgangp/wgangp_toy.py:94:107
--8<--
```

### 3.3 模型构建

本案例的WGAN-GP没有被内置在PaddleScience中,需要额外实现,因此我们自定义了`WganGpCifar10Generator`和`WganGpCifar10Discriminator`、`WganGpMnistGenerator`和`WganGpMnistDiscriminator`、`WganGpToyGenerator`和`WganGpToyDiscriminator`。

模型的构建代码如下:

``` py
--8<--
examples/wgangp/wgangp_cifar10.py:96:98
examples/wgangp/wgangp_mnist.py:87:88
examples/wgangp/wgangp_toy.py:80:81
--8<--
```

参数配置如下:

``` yaml
--8<--
examples/wgangp/conf/wgangp_cifar10.yaml:29:43
examples/wgangp/conf/wgangp_mnist.yaml:29:38
examples/wgangp/conf/wgangp_toy.yaml:29:37
--8<--
```

### 3.4 自定义loss

WGAN-GP的损失函数较复杂,需要我们自定义实现。PaddleScience提供了用于自定loss函数的API——`ppsci.loss.FunctionalLoss`。方法为先定义loss函数,再将函数名作为参数传给 `FunctionalLoss`。需要注意,自定义loss函数的输入输出需要是字典的格式。

#### 3.4.1 Generator的loss

Cifar10_Generator的loss包含了对抗性损失和分类损失。这两项loss都有对应的权重,如果某一项 loss 的权重为 0,则表示训练中不添加该 loss 项。

``` py linenums="16"
--8<--
examples/wgangp/functions.py:16:44
--8<--
```

MNIST_Generator的loss只包含了对抗性损失。
``` py linenums="313"
--8<--
examples/wgangp/functions.py:313:328
--8<--
```

Toy_Generator的loss只包含了对抗性损失。
``` py linenums="238"
--8<--
examples/wgangp/functions.py:238:254
--8<--
```

#### 3.4.2 Discriminator的loss

Cifar10_Discriminator的loss包含了Wasserstein损失和梯度惩罚以及分类损失。其中,只有分类损失项有权重参数。
``` py linenums="46"
--8<--
examples/wgangp/functions.py:46:95
--8<--
```

MNIST_Discriminator的loss包含了Wasserstein损失和梯度惩罚。
``` py linenums="330"
--8<--
examples/wgangp/functions.py:330:366
--8<--
```

Toy_Discriminator的loss包含了Wasserstein损失和梯度惩罚。
``` py linenums="256"
--8<--
examples/wgangp/functions.py:256:292
--8<--
```

### 3.5 约束构建

所有案例均使用`ppsci.constraint.SupervisedConstraint`构建约束。

构建代码如下:

``` py
--8<--
examples/wgangp/wgangp_cifar10.py:129:145
examples/wgangp/wgangp_mnist.py:117:132
examples/wgangp/wgangp_toy.py:110:125
--8<--
```

### 3.6 优化器构建

WGANGP使用Adam优化器,可直接调用`ppsci.optimizer.Adam`构建,代码如下:

``` py
--8<--
examples/wgangp/wgangp_cifar10.py:148:162
examples/wgangp/wgangp_mnist.py:134:137
examples/wgangp/wgangp_toy.py:128:131
--8<--
```

### 3.7 Solver构建

将构建好的模型、约束、优化器和其它参数传递给 `ppsci.solver.Solver`。

``` py
--8<--
examples/wgangp/wgangp_cifar10.py:164:182
examples/wgangp/wgangp_mnist.py:139:157
examples/wgangp/wgangp_toy.py:134:151
--8<--
```

### 3.8 模型训练

``` py
--8<--
examples/wgangp/wgangp_cifar10.py:185:190
examples/wgangp/wgangp_mnist.py:160:165
examples/wgangp/wgangp_toy.py:154:159
--8<--
```

### 3.9 自定义metric

案例中只有针对Cifar10的案例有评估指标为Inception Score,MNIST和Toy案例没有评估指标。由于metric为空会报错所以自定义了一个无效metric
所以我们额外实现了两个metric

PaddleScience提供了用于自定metric函数的API——`ppsci.metric.FunctionalMetric`。方法为先定义metric函数,再将函数名作为参数传给 `FunctionalMetric`。需要注意,自定义metric函数的输入输出需要是字典的格式。

Inception Score的实现代码如下:
``` py linenums="97"
--8<--
examples/wgangp/functions.py:97:154
--8<--
```

invalid_metric的代码如下
``` py linenums="389"
--8<--
examples/wgangp/functions.py:389:391
--8<--
```

### 3.10 Validator构建

本案例使用`ppsci.validate.SupervisedValidator`构建评估器。

``` py
--8<--
examples/wgangp/wgangp_cifar10.py:53:70
examples/wgangp/wgangp_mnist.py:46:54
examples/wgangp/wgangp_toy.py:46:52
--8<--
```

### 3.11 模型评估

将模型、评估器和权重路径传递给`ppsci.solver.Solver`后,通过`solver.eval()`启动评估。

``` py
--8<--
examples/wgangp/wgangp_cifar10.py:65:74
examples/wgangp/wgangp_mnist.py:56:65
examples/wgangp/wgangp_toy.py:55:63
--8<--
```

### 3.12 可视化

评估完成后,我们以图片的形式对结果进行可视化,代码如下:

``` py
--8<--
examples/wgangp/wgangp_cifar10.py:76:92
examples/wgangp/wgangp_mnist.py:67:83
examples/wgangp/wgangp_toy.py:65:75
--8<--
```

## 4. 完整代码

``` py
--8<--
examples/wgangp/wgangp_cifar10.py
examples/wgangp/wgangp_mnist.py
examples/wgangp/wgangp_toy.py
--8<--
```

## 6. 参考文献

- [Improved Training of Wasserstein GANs 论文](https://arxiv.org/abs/1704.00028)

- [参考代码](https://github.com/igul222/improved_wgan_training)
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ nav:
- Pang-Weather: zh/examples/pangu_weather.md
- FengWu: zh/examples/fengwu.md
- FuXi: zh/examples/fuxi.md
- WGAN_GP: zh/examples/wgan_gp.md
- 化学科学(AI for Chemistry):
- Moflow: zh/examples/moflow.md
- IFM: zh/examples/ifm.md
Expand Down