Major refactoring: add classes, remove pretty tensor
ikostrikov2 committed Dec 29, 2016
1 parent 44426b3 commit 79c6782
Showing 5 changed files with 319 additions and 0 deletions.
82 changes: 82 additions & 0 deletions gan.py
@@ -0,0 +1,82 @@
'''TensorFlow implementation of DCGAN (http://arxiv.org/pdf/1511.06434.pdf).'''

from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib import losses
from tensorflow.contrib.framework import arg_scope

from utils import discriminator, decoder
from generator import Generator

def concat_elu(inputs):
    '''Concatenated ELU: applies ELU to the inputs and their negation and
    concatenates the results along the channel axis, doubling the channels.'''
    return tf.nn.elu(tf.concat(3, [-inputs, inputs]))
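# For example (shapes only): if `inputs` is [batch, 14, 14, 32], then
# concat_elu(inputs) is [batch, 14, 14, 64].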

class GAN(Generator):

def __init__(self, hidden_size, batch_size, learning_rate):
self.input_tensor = tf.placeholder(tf.float32, [None, 28 * 28])

with arg_scope([layers.conv2d, layers.conv2d_transpose],
activation_fn=concat_elu,
normalizer_fn=layers.batch_norm,
normalizer_params={'scale': True}):
with tf.variable_scope("model"):
D1 = discriminator(self.input_tensor) # positive examples
D_params_num = len(tf.trainable_variables())
G = decoder(tf.random_normal([batch_size, hidden_size]))
self.sampled_tensor = G

with tf.variable_scope("model", reuse=True):
D2 = discriminator(G) # generated examples

        D_loss = self.__get_discriminator_loss(D1, D2)
        G_loss = self.__get_generator_loss(D2)

        # tf.trainable_variables() returns variables in creation order, so the
        # first D_params_num trainable variables belong to the discriminator
        # and the rest to the generator.
        params = tf.trainable_variables()
        D_params = params[:D_params_num]
        G_params = params[D_params_num:]

        global_step = tf.contrib.framework.get_or_create_global_step()
        # The discriminator uses a 10x smaller learning rate than the generator.
        self.train_discriminator = layers.optimize_loss(
            D_loss, global_step, learning_rate / 10, 'Adam', variables=D_params, update_ops=[])
        self.train_generator = layers.optimize_loss(
            G_loss, global_step, learning_rate, 'Adam', variables=G_params, update_ops=[])

self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())

    def __get_discriminator_loss(self, D1, D2):
        '''Loss for the discriminator network.
        Args:
            D1: logits computed with the discriminator network from real images
            D2: logits computed with the discriminator network from generated images
        Returns:
            Cross entropy loss; real samples have implicit label 1 and
            generated samples have implicit label 0.
        '''
        return (losses.sigmoid_cross_entropy(D1, tf.ones(tf.shape(D1))) +
                losses.sigmoid_cross_entropy(D2, tf.zeros(tf.shape(D2))))

    def __get_generator_loss(self, D2):
        '''Loss for the generator: maximize the probability that the
        discriminator classifies generated images as real (the non-saturating
        loss from the paper).
        Args:
            D2: logits computed with the discriminator network from generated images
        Returns:
            Cross entropy loss of D2 against implicit labels of 1
        '''
        return losses.sigmoid_cross_entropy(D2, tf.ones(tf.shape(D2)))

    def update_params(self, inputs):
        '''Run one discriminator update followed by one generator update.
        Args:
            inputs: a batch of flattened real images
        Returns:
            Current generator loss (the discriminator loss is computed but
            not returned)
        '''
        d_loss_value = self.sess.run(self.train_discriminator, {
            self.input_tensor: inputs})

        g_loss_value = self.sess.run(self.train_generator)

        return g_loss_value
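A minimal driver for this class (a sketch; the output path is hypothetical, and main.py below does the same thing behind command-line flags):

from tensorflow.examples.tutorials.mnist import input_data
from gan import GAN

mnist = input_data.read_data_sets('MNIST', one_hot=True)
model = GAN(hidden_size=128, batch_size=128, learning_rate=1e-2)
for _ in range(1000):
    images, _ = mnist.train.next_batch(128)
    g_loss = model.update_params(images)
model.generate_and_save_images(128, '/tmp/gan')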
31 changes: 31 additions & 0 deletions generator.py
@@ -0,0 +1,31 @@
import os
from scipy.misc import imsave

class Generator(object):

def update_params(self, input_tensor):
'''Update parameters of the network
Args:
input_tensor: a batch of flattened images
Returns:
Current loss value
'''
raise NotImplementedError()

    def generate_and_save_images(self, num_samples, directory):
        '''Generates images with the model and saves them in the directory.
        Args:
            num_samples: number of samples to generate (currently the whole
                sampled batch is saved regardless of this value)
            directory: a directory to save the images in
        '''
        imgs = self.sess.run(self.sampled_tensor)
        imgs_folder = os.path.join(directory, 'imgs')
        if not os.path.exists(imgs_folder):
            os.makedirs(imgs_folder)

        for k in range(imgs.shape[0]):
            imsave(os.path.join(imgs_folder, '%d.png' % k),
                   imgs[k].reshape(28, 28))
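For instance (a sketch with a hypothetical path), a trained model writes one 28x28 grayscale PNG per sampled image:

model.generate_and_save_images(128, '/tmp/out')  # /tmp/out/imgs/0.png, 1.png, ...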
63 changes: 63 additions & 0 deletions main.py
@@ -0,0 +1,63 @@
'''Train a VAE (http://arxiv.org/pdf/1312.6114v10.pdf) or a GAN
(http://arxiv.org/pdf/1511.06434.pdf) on MNIST.'''

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

from progressbar import ProgressBar

from vae import VAE
from gan import GAN

flags = tf.flags
logging = tf.logging

flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("updates_per_epoch", 1000, "number of updates per epoch")
flags.DEFINE_integer("max_epoch", 100, "max epoch")
flags.DEFINE_float("learning_rate", 1e-2, "learning rate")
flags.DEFINE_string("working_directory", "", "")
flags.DEFINE_integer("hidden_size", 128, "size of the hidden VAE unit")
flags.DEFINE_string("model", "gan", "gan or vae")

FLAGS = flags.FLAGS

if __name__ == "__main__":
data_directory = os.path.join(FLAGS.working_directory, "MNIST")
if not os.path.exists(data_directory):
os.makedirs(data_directory)
mnist = input_data.read_data_sets(data_directory, one_hot=True)

assert FLAGS.model in ['vae', 'gan']
if FLAGS.model == 'vae':
model = VAE(FLAGS.hidden_size, FLAGS.batch_size, FLAGS.learning_rate)
elif FLAGS.model == 'gan':
model = GAN(FLAGS.hidden_size, FLAGS.batch_size, FLAGS.learning_rate)

for epoch in range(FLAGS.max_epoch):
training_loss = 0.0

pbar = ProgressBar()
for i in pbar(range(FLAGS.updates_per_epoch)):
images, _ = mnist.train.next_batch(FLAGS.batch_size)
loss_value = model.update_params(images)
training_loss += loss_value

training_loss = training_loss / \
(FLAGS.updates_per_epoch * FLAGS.batch_size)

print("Loss %f" % training_loss)

model.generate_and_save_images(
FLAGS.batch_size, FLAGS.working_directory)
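Assuming the modules above are importable, a typical invocation (the directory is hypothetical) is: python main.py --model vae --working_directory /tmp/vae. Generated samples are written to <working_directory>/imgs/.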
56 changes: 56 additions & 0 deletions utils.py
@@ -0,0 +1,56 @@
import tensorflow as tf
from tensorflow.contrib import layers


def encoder(input_tensor, output_size):
    '''Create encoder network.
    Args:
        input_tensor: a batch of flattened images [batch_size, 28*28]
        output_size: size of the final fully connected output layer
    Returns:
        A tensor that expresses the encoder network
    '''
    net = tf.reshape(input_tensor, [-1, 28, 28, 1])
    net = layers.conv2d(net, 32, 5, stride=2)                     # [batch, 14, 14, 32]
    net = layers.conv2d(net, 64, 5, stride=2)                     # [batch, 7, 7, 64]
    net = layers.conv2d(net, 128, 5, stride=2, padding='VALID')   # [batch, 2, 2, 128]
    net = layers.dropout(net, keep_prob=0.9)
    net = layers.flatten(net)
    return layers.fully_connected(net, output_size, activation_fn=None)


def discriminator(input_tensor):
    '''Create a network that discriminates between images from a dataset and
    generated ones.
    Args:
        input_tensor: a batch of flattened real or generated images
            [batch_size, 28*28]
    Returns:
        A single-logit tensor that represents the network
    '''
    return encoder(input_tensor, 1)


def decoder(input_tensor):
    '''Create decoder network.
    Args:
        input_tensor: a batch of vectors to decode [batch_size, hidden_size]
    Returns:
        A tensor that expresses the decoder network; flattened images with
        values in (0, 1)
    '''
    net = tf.expand_dims(input_tensor, 1)
    net = tf.expand_dims(net, 1)                                  # [batch, 1, 1, hidden_size]
    net = layers.conv2d_transpose(net, 128, 3, padding='VALID')   # [batch, 3, 3, 128]
    net = layers.conv2d_transpose(net, 64, 5, padding='VALID')    # [batch, 7, 7, 64]
    net = layers.conv2d_transpose(net, 32, 5, stride=2)           # [batch, 14, 14, 32]
    net = layers.conv2d_transpose(
        net, 1, 5, stride=2, activation_fn=tf.nn.sigmoid)         # [batch, 28, 28, 1]
    net = layers.flatten(net)                                     # [batch, 28*28]
    return net
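A quick shape check of the two networks (a sketch, assuming the TF 1.x contrib API used throughout this commit; the latent size 128 is an example):

import tensorflow as tf
from utils import encoder, decoder

images = tf.placeholder(tf.float32, [None, 28 * 28])
codes = encoder(images, 128)
samples = decoder(tf.random_normal([16, 128]))
print(codes.get_shape())    # (?, 128)
print(samples.get_shape())  # (16, 784)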
87 changes: 87 additions & 0 deletions vae.py
@@ -0,0 +1,87 @@
'''TensorFlow implementation of http://arxiv.org/pdf/1312.6114v10.pdf'''

from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.framework import arg_scope

from utils import encoder, decoder
from generator import Generator


class VAE(Generator):

def __init__(self, hidden_size, batch_size, learning_rate):
self.input_tensor = tf.placeholder(
tf.float32, [None, 28 * 28])

with arg_scope([layers.conv2d, layers.conv2d_transpose],
activation_fn=tf.nn.elu,
normalizer_fn=layers.batch_norm,
normalizer_params={'scale': True}):
            with tf.variable_scope("model"):
                encoded = encoder(self.input_tensor, hidden_size * 2)

                # The encoder outputs the mean and the log-variance of the
                # approximate posterior q(z|x).
                mean = encoded[:, :hidden_size]
                stddev = tf.sqrt(tf.exp(encoded[:, hidden_size:]))

                # Reparameterization trick: sample z = mean + epsilon * stddev
                # so that gradients can flow through the sampling step.
                epsilon = tf.random_normal([tf.shape(mean)[0], hidden_size])
                input_sample = mean + epsilon * stddev

                output_tensor = decoder(input_sample)

            with tf.variable_scope("model", reuse=True):
                self.sampled_tensor = decoder(tf.random_normal(
                    [batch_size, hidden_size]))

vae_loss = self.__get_vae_cost(mean, stddev)
rec_loss = self.__get_reconstruction_cost(
output_tensor, self.input_tensor)

loss = vae_loss + rec_loss
        self.train = layers.optimize_loss(
            loss, tf.contrib.framework.get_or_create_global_step(),
            learning_rate=learning_rate, optimizer='Adam', update_ops=[])

self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())

    def __get_vae_cost(self, mean, stddev, epsilon=1e-8):
        '''VAE loss: KL divergence between the approximate posterior
        N(mean, stddev^2) and the prior N(0, I), summed over the batch:
            0.5 * sum(mean^2 + stddev^2 - 2*log(stddev) - 1)
        Args:
            mean: means of the approximate posterior [batch_size, hidden_size]
            stddev: standard deviations of the approximate posterior
            epsilon: small constant to avoid log(0)
        '''
        return tf.reduce_sum(0.5 * (tf.square(mean) + tf.square(stddev) -
                                    2.0 * tf.log(stddev + epsilon) - 1.0))

    def __get_reconstruction_cost(self, output_tensor, target_tensor, epsilon=1e-8):
        '''Reconstruction loss: Bernoulli cross entropy between the decoder
        output and the target, summed over pixels and the batch.
        Args:
            output_tensor: tensor produced by decoder, values in (0, 1)
            target_tensor: the target tensor that we want to reconstruct
            epsilon: small constant to avoid log(0)
        '''
        return tf.reduce_sum(-target_tensor * tf.log(output_tensor + epsilon) -
                             (1.0 - target_tensor) * tf.log(1.0 - output_tensor + epsilon))

def update_params(self, input_tensor):
'''Update parameters of the network
Args:
input_tensor: a batch of flattened images [batch_size, 28*28]
Returns:
Current loss value
'''
return self.sess.run(self.train, {self.input_tensor: input_tensor})
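A quick numerical check of the KL term above (a sketch in NumPy, not part of the commit): for mean 0 and stddev 1 the posterior equals the prior, so the cost should be ~0.

import numpy as np

mean, stddev = np.zeros(8), np.ones(8)
kl = 0.5 * np.sum(np.square(mean) + np.square(stddev) -
                  2.0 * np.log(stddev + 1e-8) - 1.0)
print(kl)  # ~0 (exactly 0 up to the epsilon term)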
