This repository was archived by the owner on Jan 10, 2023. It is now read-only.

Commit 25c4945

Major update adding new architectures and TPU support.
1 parent 51029dc commit 25c4945

126 files changed

+10399
-10677
lines changed

README.md

Lines changed: 113 additions & 73 deletions
@@ -1,73 +1,113 @@
1-
## Compare GAN code.
2-
3-
This is the code that was used in "Are GANs Created Equal? A Large-Scale Study"
4-
paper (https://arxiv.org/abs/1711.10337) and in "The GAN Landscape: Losses,
5-
Architectures, Regularization, and Normalization"
6-
(https://arxiv.org/abs/1807.04720).
7-
8-
If you want to see the version used only in the first paper - please see the
9-
*v1* branch of this repository.
10-
11-
## Pre-trained models
12-
13-
The pre-trained models are available on TensorFlow Hub. Please see
14-
[this colab](https://colab.research.google.com/github/google/compare_gan/blob/master/compare_gan/src/tfhub_models.ipynb)
15-
for an example how to use them.
16-
17-
### Best hyperparameters
18-
19-
This repository also contains the values for the best hyperparameters for
20-
different combinations of models, regularizations and penalties. You can see
21-
them in `generate_tasks_lib.py` file and train using
22-
`--experiment=best_models_sndcgan`
23-
24-
### Installation:
25-
26-
To install, run:
27-
28-
```shell
29-
python -m pip install -e . --user
30-
```
31-
32-
After installing, make sure to run
33-
34-
```shell
35-
compare_gan_prepare_datasets.sh
36-
```
37-
38-
It will download all the necessary datasets and frozen TF graphs. By default it
39-
will store them in `/tmp/datasets`.
40-
41-
WARNING: by default this script only downloads and installs small datasets - it
42-
doesn't download celebaHQ or lsun bedrooms.
43-
44-
* **Lsun bedrooms dataset**: If you want to install lsun-bedrooms you need to
45-
run t2t-datagen yourself (this dataset will take couple hours to download
46-
and unpack).
47-
48-
* **CelebaHQ dataset**: currently it is not available in tensor2tensor. Please
49-
use the
50-
[ProgressiveGAN github](https://github.com/tkarras/progressive_growing_of_gans)
51-
for instructions on how to prepare it.
52-
53-
### Running
54-
55-
compare_gan has two binaries:
56-
57-
* `generate_tasks` - that creates a list of files with parameters to execute
58-
* `run_one_task` - that executes a given task, both training and evaluation,
59-
and stores results in the CSV file.
60-
61-
```shell
62-
# Create tasks for experiment "test" in directory /tmp/results. See "src/generate_tasks_lib.py" to see other possible experiments.
63-
compare_gan_generate_tasks --workdir=/tmp/results --experiment=test
64-
65-
# Run task 0 (training and eval)
66-
compare_gan_run_one_task --workdir=/tmp/results --task_num=0 --dataset_root=/tmp/datasets
67-
68-
# Run task 1 (training and eval)
69-
compare_gan_run_one_task --workdir=/tmp/results --task_num=1 --dataset_root=/tmp/datasets
70-
```
71-
72-
Results (all computed metrics) will be stored in
73-
`/tmp/results/TASK_NUM/scores.csv`.
1+
# Compare GAN
2+
3+
This repository offers TensorFlow implementations for many components related to
4+
**Generative Adversarial Networks**:
5+
6+
* losses (such as non-saturating GAN, least-squares GAN, and WGAN),
7+
* penalties (such as the gradient penalty),
8+
* normalization techniques (such as spectral normalization, batch
9+
normalization, and layer normalization),
10+
* neural architectures (BigGAN, ResNet, DCGAN), and
11+
* evaluation metrics (FID score, Inception Score, precision-recall, and KID
12+
score).
13+
14+
The code is **configurable via [Gin](https://github.com/google/gin-config)** and
15+
runs on **GPU/TPU/CPUs**. Several research papers make use of this repository,
16+
including:
17+
18+
1. [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337)
19+
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v1)
20+
\
21+
Mario Lucic*, Karol Kurach*, Marcin Michalski, Sylvain Gelly, Olivier
22+
Bousquet **[NeurIPS 2018]**
23+
24+
2. [The GAN Landscape: Losses, Architectures, Regularization, and Normalization](https://arxiv.org/abs/1807.04720)
25+
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v2)
26+
\
27+
Karol Kurach*, Mario Lucic*, Xiaohua Zhai, Marcin Michalski, Sylvain Gelly
28+
**[2018]**
29+
30+
3. [Assessing Generative Models via Precision and Recall](https://arxiv.org/abs/1806.00035)
31+
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/prd_score.py)
32+
\
33+
Mehdi S. M. Sajjadi, Olivier Bachem, Mario Lucic, Olivier Bousquet, Sylvain
34+
Gelly **[NeurIPS 2018]**
35+
36+
4. [GILBO: One Metric to Measure Them All](https://arxiv.org/abs/1802.04874)
37+
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/gilbo.py)
38+
\
39+
Alexander A. Alemi, Ian Fischer **[NeurIPS 2018]**
40+
41+
5. [A Case for Object Compositionality in Deep Generative Models of Images](https://arxiv.org/abs/1810.10340)
42+
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v2_multigan)
43+
\
44+
Sjoerd van Steenkiste, Karol Kurach, Sylvain Gelly **[2018]**
45+
46+
6. [On Self Modulation for Generative Adversarial Networks](https://arxiv.org/abs/1810.01365)
47+
[<font color="green">[Code]</font>](https://github.com/google/compare_gan) \
48+
Ting Chen, Mario Lucic, Neil Houlsby, Sylvain Gelly **[ICLR 2019]**
49+
50+
7. [Self-Supervised Generative Adversarial Networks](https://arxiv.org/abs/1811.11212)
51+
[<font color="green">[Code]</font>](https://github.com/google/compare_gan) \
52+
Ting Chen, Xiaohua Zhai, Marvin Ritter, Mario Lucic, Neil Houlsby **[CVPR
53+
2019]**
54+
55+
56+
## Installation
57+
58+
You can easily install the library and all necessary dependencies by running:
59+
`pip install -e .` from the `compare_gan/` folder.
60+
61+
## Running experiments
62+
63+
Simply run `main.py`, passing a `--model_dir` (this is where checkpoints are
64+
stored) and a `--gin_config` (defines which model on which data set and other
65+
options). We provide several example configurations in the `example_configs/`
66+
folder, namely:
67+
68+
* **dcgan_celeba64**: DCGAN architecture with non-saturating loss on CelebA
69+
64x64px
70+
* **resnet_cifar10**: ResNet architecture with non-saturating loss and
71+
spectral normalization on CIFAR-10
72+
* **resnet_lsun-bedroom128**: ResNet architecture with WGAN loss and gradient
73+
penalty on LSUN-bedrooms 128x128px
74+
* **sndcgan_celebahq128**: SN-DCGAN architecture with non-saturating loss and
75+
spectral normalization on CelebA-HQ 128x128px
76+
* **biggan_imagenet128**: BigGAN architecture with hinge loss and spectral
77+
normalization on ImageNet 128x128px
78+
79+
### Training and evaluation
80+
81+
To see all available options please run `python main.py --help`. Main options:
82+
83+
* To **train** the model use `--schedule=train` (default). Training is resumed
84+
from the last saved checkpoint.
85+
* To **evaluate** all checkpoints use `--schedule=continuous_eval
86+
--eval_every_steps=0`. To evaluate only checkpoints where the step count is
87+
divisible by 5000, use `--schedule=continuous_eval --eval_every_steps=5000`.
88+
By default, 3 averaging runs are used to estimate the Inception Score and
89+
the FID score. Keep in mind that when running locally on a single GPU it may
90+
not be possible to run training and evaluation simultaneously due to memory
91+
constraints.
92+
* To **train and evaluate** the model use `--schedule=eval_after_train
93+
--eval_every_steps=0`.
94+
95+
### Training on Cloud TPUs
96+
97+
We recommend using the
98+
[ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu) to create
99+
a Cloud TPU and corresponding Compute Engine VM. We use a v3-128 Cloud TPU v3 Pod
100+
for training models on ImageNet at 128x128 resolution. You can use smaller
101+
slices if you reduce the batch size (`options.batch_size` in the Gin config) or
102+
model parameters. Keep in mind that the model quality might change. Before
103+
training make sure that the environment variable `TPU_NAME` is set. Running
104+
evaluation on TPUs is currently not supported. Use a VM with a single GPU
105+
instead.
106+
107+
### Datasets
108+
109+
Compare GAN uses [TensorFlow Datasets](https://www.tensorflow.org/datasets) and
110+
it will automatically download and prepare the data. For ImageNet you will need
111+
to download the archive yourself. For CelebA-HQ you need to download and prepare
112+
the images on your own. If you are using TPUs make sure to point the training
113+
script to your Google Storage Bucket (`--tfds_data_dir`).
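
Review note on `--eval_every_steps`: the README text in this diff says that `0` evaluates every checkpoint, while a positive value evaluates only checkpoints whose step is divisible by it. The following is a minimal Python sketch of that selection rule as described, not the repository's actual implementation. The function name `checkpoints_to_evaluate` and the sample step values are assumptions made for illustration.

```python
def checkpoints_to_evaluate(checkpoint_steps, eval_every_steps):
    """Return the checkpoint steps that would be selected for evaluation.

    Hypothetical helper mirroring the README wording only:
    - eval_every_steps == 0 selects every checkpoint,
    - eval_every_steps > 0 selects steps divisible by that value.
    """
    if eval_every_steps < 0:
        raise ValueError("eval_every_steps must be non-negative")
    if eval_every_steps == 0:
        return list(checkpoint_steps)
    return [step for step in checkpoint_steps if step % eval_every_steps == 0]


# Illustrative checkpoint steps (made up for this example).
saved_steps = [1000, 2500, 5000, 7500, 10000]

all_selected = checkpoints_to_evaluate(saved_steps, 0)
sparse_selected = checkpoints_to_evaluate(saved_steps, 5000)
```

With the sample steps above, `--eval_every_steps=5000` would skip the checkpoints at 1000, 2500, and 7500 under this reading of the README.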
File renamed without changes.
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
1+
# coding=utf-8
2+
# Copyright 2018 Google LLC & Hwalsuk Lee.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Defines interfaces for generator and discriminator networks."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import abc
23+
from compare_gan import utils
24+
import gin
25+
import six
26+
import tensorflow as tf
27+
28+
29+
@gin.configurable("G", blacklist=["name", "image_shape"])
30+
@six.add_metaclass(abc.ABCMeta)
31+
class AbstractGenerator(object):
32+
"""Interface for generator architectures."""
33+
34+
def __init__(self,
35+
name="generator",
36+
image_shape=None,
37+
batch_norm_fn=None,
38+
spectral_norm=False):
39+
"""Constructor for all generator architectures.
40+
41+
Args:
42+
name: Scope name of the generator.
43+
image_shape: Image shape to be generated, [height, width, colors].
44+
batch_norm_fn: Function for batch normalization or None.
45+
spectral_norm: If True use spectral normalization for all weights.
46+
"""
47+
self._name = name
48+
self._image_shape = image_shape
49+
self._batch_norm_fn = batch_norm_fn
50+
self._spectral_norm = spectral_norm
51+
52+
def __call__(self, z, y, is_training, reuse=tf.AUTO_REUSE):
53+
with tf.variable_scope(self._name, values=[z, y], reuse=reuse):
54+
outputs = self.apply(z=z, y=y, is_training=is_training)
55+
return outputs
56+
57+
def batch_norm(self, inputs, **kwargs):
58+
if self._batch_norm_fn is None:
59+
return inputs
60+
args = kwargs.copy()
61+
args["inputs"] = inputs
62+
if "use_sn" not in args:
63+
args["use_sn"] = self._spectral_norm
64+
return utils.call_with_accepted_args(self._batch_norm_fn, **args)
65+
66+
@abc.abstractmethod
67+
def apply(self, z, y, is_training):
68+
    """Apply the generator on an input.
69+
70+
Args:
71+
z: `Tensor` of shape [batch_size, z_dim] with latent code.
72+
y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
73+
labels.
74+
is_training: Boolean, whether the architecture should be constructed for
75+
training or inference.
76+
77+
Returns:
78+
Generated images of shape [batch_size] + self.image_shape.
79+
"""
80+
81+
82+
@gin.configurable("D", blacklist=["name"])
83+
@six.add_metaclass(abc.ABCMeta)
84+
class AbstractDiscriminator(object):
85+
"""Interface for discriminator architectures."""
86+
87+
def __init__(self,
88+
name="discriminator",
89+
batch_norm_fn=None,
90+
layer_norm=False,
91+
spectral_norm=False):
92+
self._name = name
93+
self._batch_norm_fn = batch_norm_fn
94+
self._layer_norm = layer_norm
95+
self._spectral_norm = spectral_norm
96+
97+
def __call__(self, x, y, is_training, reuse=tf.AUTO_REUSE):
98+
with tf.variable_scope(self._name, values=[x, y], reuse=reuse):
99+
outputs = self.apply(x=x, y=y, is_training=is_training)
100+
return outputs
101+
102+
def batch_norm(self, inputs, **kwargs):
103+
if self._batch_norm_fn is None:
104+
return inputs
105+
args = kwargs.copy()
106+
args["inputs"] = inputs
107+
if "use_sn" not in args:
108+
args["use_sn"] = self._spectral_norm
109+
return utils.call_with_accepted_args(self._batch_norm_fn, **args)
110+
111+
112+
@abc.abstractmethod
113+
def apply(self, x, y, is_training):
114+
    """Apply the discriminator on an input.
115+
116+
Args:
117+
x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
118+
y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
119+
labels.
120+
is_training: Boolean, whether the architecture should be constructed for
121+
training or inference.
122+
123+
Returns:
124+
Tuple of 3 Tensors, the final prediction of the discriminator, the logits
125+
      before the final output activation function and logits from the second
126+
last layer.
127+
"""

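Review note on `batch_norm`: both abstract classes add `use_sn` to the keyword arguments and then delegate to `utils.call_with_accepted_args`, which is not shown in this diff. Judging by its name, it probably forwards only the keyword arguments that the configured `batch_norm_fn` accepts, so normalization functions without a `use_sn` parameter still work. Below is a minimal Python 3 sketch of such a helper under that assumption. The real `compare_gan.utils` implementation may differ, and `plain_norm` and `sn_aware_norm` are toy functions invented for the example.

```python
import inspect


def call_with_accepted_args(fn, **kwargs):
    """Call `fn` with only the keyword arguments its signature accepts.

    Hypothetical stand-in for `utils.call_with_accepted_args`; the real
    helper is not part of this diff and may behave differently.
    """
    params = inspect.signature(fn).parameters
    # If fn declares **kwargs, every argument can be forwarded unchanged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**kwargs)
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                     inspect.Parameter.KEYWORD_ONLY)
    accepted = {
        name: value
        for name, value in kwargs.items()
        if name in params and params[name].kind in keyword_kinds
    }
    return fn(**accepted)


def plain_norm(inputs, is_training=True):
    """Toy normalization function that has no `use_sn` parameter."""
    return ("plain", inputs, is_training)


def sn_aware_norm(inputs, use_sn=False):
    """Toy normalization function that accepts `use_sn`."""
    return ("sn_aware", inputs, use_sn)


# `use_sn` is dropped for plain_norm and forwarded to sn_aware_norm.
result_plain = call_with_accepted_args(
    plain_norm, inputs=[1.0], is_training=False, use_sn=True)
result_sn = call_with_accepted_args(
    sn_aware_norm, inputs=[1.0], is_training=False, use_sn=True)
```

Filtering by signature lets the base classes pass one superset of arguments to any normalization function chosen through Gin, at the cost of silently ignoring misspelled argument names.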