From 79c678224d7fa14ddebd7a387970a232b9545af4 Mon Sep 17 00:00:00 2001 From: Ilya Kostrikov Date: Thu, 29 Dec 2016 16:39:56 -0500 Subject: [PATCH] Major refactoring: add classes, remove pretty tensor --- gan.py | 82 +++++++++++++++++++++++++++++++++++++++++++++++++ generator.py | 31 +++++++++++++++++++ main.py | 63 +++++++++++++++++++++++++++++++++++++ utils.py | 56 +++++++++++++++++++++++++++++++++ vae.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 319 insertions(+) create mode 100644 gan.py create mode 100644 generator.py create mode 100644 main.py create mode 100644 utils.py create mode 100644 vae.py diff --git a/gan.py b/gan.py new file mode 100644 index 0000000..2ac25f2 --- /dev/null +++ b/gan.py @@ -0,0 +1,82 @@ +'''TensorFlow implementation of http://arxiv.org/pdf/1511.06434.pdf''' + +from __future__ import absolute_import, division, print_function + +import math + +import numpy as np +from tensorflow.contrib import layers +from tensorflow.contrib import losses +from tensorflow.contrib.framework import arg_scope +import tensorflow as tf + +from utils import discriminator, decoder +from generator import Generator + +def concat_elu(inputs): + return tf.nn.elu(tf.concat(3, [-inputs, inputs])) + +class GAN(Generator): + + def __init__(self, hidden_size, batch_size, learning_rate): + self.input_tensor = tf.placeholder(tf.float32, [None, 28 * 28]) + + with arg_scope([layers.conv2d, layers.conv2d_transpose], + activation_fn=concat_elu, + normalizer_fn=layers.batch_norm, + normalizer_params={'scale': True}): + with tf.variable_scope("model"): + D1 = discriminator(self.input_tensor) # positive examples + D_params_num = len(tf.trainable_variables()) + G = decoder(tf.random_normal([batch_size, hidden_size])) + self.sampled_tensor = G + + with tf.variable_scope("model", reuse=True): + D2 = discriminator(G) # generated examples + + D_loss = self.__get_discrinator_loss(D1, D2) + G_loss = self.__get_generator_loss(D2) + + params = tf.trainable_variables() + D_params = params[:D_params_num] + G_params = params[D_params_num:] + # train_discrimator = optimizer.minimize(loss=D_loss, var_list=D_params) + # train_generator = optimizer.minimize(loss=G_loss, var_list=G_params) + global_step = tf.contrib.framework.get_or_create_global_step() + self.train_discrimator = layers.optimize_loss( + D_loss, global_step, learning_rate / 10, 'Adam', variables=D_params, update_ops=[]) + self.train_generator = layers.optimize_loss( + G_loss, global_step, learning_rate, 'Adam', variables=G_params, update_ops=[]) + + self.sess = tf.Session() + self.sess.run(tf.global_variables_initializer()) + + def __get_discrinator_loss(self, D1, D2): + '''Loss for the discriminator network + + Args: + D1: logits computed with a discriminator networks from real images + D2: logits computed with a discriminator networks from generated images + + Returns: + Cross entropy loss, positive samples have implicit labels 1, negative 0s + ''' + return (losses.sigmoid_cross_entropy(D1, tf.ones(tf.shape(D1))) + + losses.sigmoid_cross_entropy(D2, tf.zeros(tf.shape(D1)))) + + def __get_generator_loss(self, D2): + '''Loss for the genetor. Maximize probability of generating images that + discrimator cannot differentiate. + + Returns: + see the paper + ''' + return losses.sigmoid_cross_entropy(D2, tf.ones(tf.shape(D2))) + + def update_params(self, inputs): + d_loss_value = self.sess.run(self.train_discrimator, { + self.input_tensor: inputs}) + + g_loss_value = self.sess.run(self.train_generator) + + return g_loss_value diff --git a/generator.py b/generator.py new file mode 100644 index 0000000..dffc8b9 --- /dev/null +++ b/generator.py @@ -0,0 +1,31 @@ +import os +from scipy.misc import imsave + +class Generator(object): + + def update_params(self, input_tensor): + '''Update parameters of the network + + Args: + input_tensor: a batch of flattened images + + Returns: + Current loss value + ''' + raise NotImplementedError() + + def generate_and_save_images(self, num_samples, directory): + '''Generates the images using the model and saves them in the directory + + Args: + num_samples: number of samples to generate + directory: a directory to save the images + ''' + imgs = self.sess.run(self.sampled_tensor) + for k in range(imgs.shape[0]): + imgs_folder = os.path.join(directory, 'imgs') + if not os.path.exists(imgs_folder): + os.makedirs(imgs_folder) + + imsave(os.path.join(imgs_folder, '%d.png') % k, + imgs[k].reshape(28, 28)) diff --git a/main.py b/main.py new file mode 100644 index 0000000..5f0460a --- /dev/null +++ b/main.py @@ -0,0 +1,63 @@ +'''TensorFlow implementation of http://arxiv.org/pdf/1312.6114v10.pdf''' + +from __future__ import absolute_import, division, print_function + +import math +import os + +import numpy as np +import scipy.misc +import tensorflow as tf +from tensorflow.contrib import layers +from tensorflow.contrib import losses +from tensorflow.contrib.framework import arg_scope +from scipy.misc import imsave +from tensorflow.examples.tutorials.mnist import input_data + +from deconv import deconv2d +from progressbar import ETA, Bar, Percentage, ProgressBar + +from vae import VAE +from gan import GAN + +flags = tf.flags +logging = tf.logging + +flags.DEFINE_integer("batch_size", 128, "batch size") +flags.DEFINE_integer("updates_per_epoch", 1000, "number of updates per epoch") +flags.DEFINE_integer("max_epoch", 100, "max epoch") +flags.DEFINE_float("learning_rate", 1e-2, "learning rate") +flags.DEFINE_string("working_directory", "", "") +flags.DEFINE_integer("hidden_size", 128, "size of the hidden VAE unit") +flags.DEFINE_string("model", "gan", "gan or vae") + +FLAGS = flags.FLAGS + +if __name__ == "__main__": + data_directory = os.path.join(FLAGS.working_directory, "MNIST") + if not os.path.exists(data_directory): + os.makedirs(data_directory) + mnist = input_data.read_data_sets(data_directory, one_hot=True) + + assert FLAGS.model in ['vae', 'gan'] + if FLAGS.model == 'vae': + model = VAE(FLAGS.hidden_size, FLAGS.batch_size, FLAGS.learning_rate) + elif FLAGS.model == 'gan': + model = GAN(FLAGS.hidden_size, FLAGS.batch_size, FLAGS.learning_rate) + + for epoch in range(FLAGS.max_epoch): + training_loss = 0.0 + + pbar = ProgressBar() + for i in pbar(range(FLAGS.updates_per_epoch)): + images, _ = mnist.train.next_batch(FLAGS.batch_size) + loss_value = model.update_params(images) + training_loss += loss_value + + training_loss = training_loss / \ + (FLAGS.updates_per_epoch * FLAGS.batch_size) + + print("Loss %f" % training_loss) + + model.generate_and_save_images( + FLAGS.batch_size, FLAGS.working_directory) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..b9d8330 --- /dev/null +++ b/utils.py @@ -0,0 +1,56 @@ +import tensorflow as tf +from tensorflow.contrib import layers + + +def encoder(input_tensor, output_size): + '''Create encoder network. + + Args: + input_tensor: a batch of flattened images [batch_size, 28*28] + + Returns: + A tensor that expresses the encoder network + ''' + net = tf.reshape(input_tensor, [-1, 28, 28, 1]) + net = layers.conv2d(net, 32, 5, stride=2) + net = layers.conv2d(net, 64, 5, stride=2) + net = layers.conv2d(net, 128, 5, stride=2, padding='VALID') + net = layers.dropout(net, keep_prob=0.9) + net = layers.flatten(net) + return layers.fully_connected(net, output_size, activation_fn=None) + + +def discriminator(input_tensor): + '''Create a network that discriminates between images from a dataset and + generated ones. + + Args: + input: a batch of real images [batch, height, width, channels] + Returns: + A tensor that represents the network + ''' + + return encoder(input_tensor, 1) + + +def decoder(input_tensor): + '''Create decoder network. + + If input tensor is provided then decodes it, otherwise samples from + a sampled vector. + Args: + input_tensor: a batch of vectors to decode + + Returns: + A tensor that expresses the decoder network + ''' + + net = tf.expand_dims(input_tensor, 1) + net = tf.expand_dims(net, 1) + net = layers.conv2d_transpose(net, 128, 3, padding='VALID') + net = layers.conv2d_transpose(net, 64, 5, padding='VALID') + net = layers.conv2d_transpose(net, 32, 5, stride=2) + net = layers.conv2d_transpose( + net, 1, 5, stride=2, activation_fn=tf.nn.sigmoid) + net = layers.flatten(net) + return net diff --git a/vae.py b/vae.py new file mode 100644 index 0000000..207e710 --- /dev/null +++ b/vae.py @@ -0,0 +1,87 @@ +'''TensorFlow implementation of http://arxiv.org/pdf/1312.6114v10.pdf''' + +from __future__ import absolute_import, division, print_function + +import math + +import numpy as np +import tensorflow as tf +from tensorflow.contrib import layers +from tensorflow.contrib import losses +from tensorflow.contrib.framework import arg_scope + +from utils import encoder, decoder +from generator import Generator + + +class VAE(Generator): + + def __init__(self, hidden_size, batch_size, learning_rate): + self.input_tensor = tf.placeholder( + tf.float32, [None, 28 * 28]) + + with arg_scope([layers.conv2d, layers.conv2d_transpose], + activation_fn=tf.nn.elu, + normalizer_fn=layers.batch_norm, + normalizer_params={'scale': True}): + with tf.variable_scope("model") as scope: + encoded = encoder(self.input_tensor, hidden_size * 2) + + mean = encoded[:, :hidden_size] + stddev = tf.sqrt(tf.exp(encoded[:, hidden_size:])) + + epsilon = tf.random_normal([tf.shape(mean)[0], hidden_size]) + input_sample = mean + epsilon * stddev + + output_tensor = decoder(input_sample) + + with tf.variable_scope("model", reuse=True) as scope: + self.sampled_tensor = decoder(tf.random_normal( + [batch_size, hidden_size])) + + vae_loss = self.__get_vae_cost(mean, stddev) + rec_loss = self.__get_reconstruction_cost( + output_tensor, self.input_tensor) + + loss = vae_loss + rec_loss + self.train = layers.optimize_loss(loss, tf.contrib.framework.get_or_create_global_step( + ), learning_rate=learning_rate, optimizer='Adam', update_ops=[]) + + self.sess = tf.Session() + self.sess.run(tf.global_variables_initializer()) + + def __get_vae_cost(self, mean, stddev, epsilon=1e-8): + '''VAE loss + See the paper + + Args: + mean: + stddev: + epsilon: + ''' + return tf.reduce_sum(0.5 * (tf.square(mean) + tf.square(stddev) - + 2.0 * tf.log(stddev + epsilon) - 1.0)) + + def __get_reconstruction_cost(self, output_tensor, target_tensor, epsilon=1e-8): + '''Reconstruction loss + + Cross entropy reconstruction loss + + Args: + output_tensor: tensor produces by decoder + target_tensor: the target tensor that we want to reconstruct + epsilon: + ''' + return tf.reduce_sum(-target_tensor * tf.log(output_tensor + epsilon) - + (1.0 - target_tensor) * tf.log(1.0 - output_tensor + epsilon)) + + def update_params(self, input_tensor): + '''Update parameters of the network + + Args: + input_tensor: a batch of flattened images [batch_size, 28*28] + + Returns: + Current loss value + ''' + return self.sess.run(self.train, {self.input_tensor: input_tensor})