commit 79c6782 (parent 44426b3)
Major refactoring: add classes, remove pretty tensor

Showing 5 changed files with 319 additions and 0 deletions.
gan.py
@@ -0,0 +1,82 @@
'''TensorFlow implementation of http://arxiv.org/pdf/1511.06434.pdf'''

from __future__ import absolute_import, division, print_function

import math

import numpy as np
from tensorflow.contrib import layers
from tensorflow.contrib import losses
from tensorflow.contrib.framework import arg_scope
import tensorflow as tf

from utils import discriminator, decoder
from generator import Generator

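# concat_elu applies ELU to the concatenation of the inputs and their negation
# along the channel axis (axis 3), CReLU-style, so the output has twice as many
# channels as the input. Note that tf.concat(axis, values) uses the pre-1.0
# TensorFlow argument order.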
def concat_elu(inputs):
    return tf.nn.elu(tf.concat(3, [-inputs, inputs]))

class GAN(Generator):

    def __init__(self, hidden_size, batch_size, learning_rate):
        self.input_tensor = tf.placeholder(tf.float32, [None, 28 * 28])

        with arg_scope([layers.conv2d, layers.conv2d_transpose],
                       activation_fn=concat_elu,
                       normalizer_fn=layers.batch_norm,
                       normalizer_params={'scale': True}):
            with tf.variable_scope("model"):
                D1 = discriminator(self.input_tensor)  # positive examples
                D_params_num = len(tf.trainable_variables())
                G = decoder(tf.random_normal([batch_size, hidden_size]))
                self.sampled_tensor = G

            with tf.variable_scope("model", reuse=True):
                D2 = discriminator(G)  # generated examples

        D_loss = self.__get_discriminator_loss(D1, D2)
        G_loss = self.__get_generator_loss(D2)

        params = tf.trainable_variables()
        D_params = params[:D_params_num]
        G_params = params[D_params_num:]
        # train_discriminator = optimizer.minimize(loss=D_loss, var_list=D_params)
        # train_generator = optimizer.minimize(loss=G_loss, var_list=G_params)
        global_step = tf.contrib.framework.get_or_create_global_step()
        # The discriminator uses a ten times smaller learning rate than the generator.
        self.train_discriminator = layers.optimize_loss(
            D_loss, global_step, learning_rate / 10, 'Adam', variables=D_params, update_ops=[])
        self.train_generator = layers.optimize_loss(
            G_loss, global_step, learning_rate, 'Adam', variables=G_params, update_ops=[])

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def __get_discriminator_loss(self, D1, D2):
        '''Loss for the discriminator network.

        Args:
            D1: logits computed by the discriminator network on real images
            D2: logits computed by the discriminator network on generated images

        Returns:
            Cross-entropy loss; real samples carry the implicit label 1,
            generated samples the implicit label 0.
        '''
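        # Standard GAN discriminator loss:
        #   L_D = E_x[-log sigmoid(D(x))] + E_z[-log(1 - sigmoid(D(G(z))))],
        # i.e. sigmoid cross-entropy against labels 1 (real) and 0 (generated).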
        return (losses.sigmoid_cross_entropy(D1, tf.ones(tf.shape(D1))) +
                losses.sigmoid_cross_entropy(D2, tf.zeros(tf.shape(D2))))

    def __get_generator_loss(self, D2):
        '''Loss for the generator. Maximize the probability that the
        discriminator classifies the generated images as real.

        Args:
            D2: logits computed by the discriminator network on generated images

        Returns:
            Cross-entropy loss against the implicit label 1.
        '''
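        # Non-saturating generator loss:
        #   L_G = E_z[-log sigmoid(D(G(z)))]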
        return losses.sigmoid_cross_entropy(D2, tf.ones(tf.shape(D2)))

    def update_params(self, inputs):
        '''Run one discriminator update and one generator update.

        Args:
            inputs: a batch of flattened images [batch_size, 28*28]

        Returns:
            Current generator loss value
        '''
        d_loss_value = self.sess.run(self.train_discriminator, {
            self.input_tensor: inputs})

        g_loss_value = self.sess.run(self.train_generator)

        return g_loss_value
generator.py
@@ -0,0 +1,31 @@
import os
from scipy.misc import imsave


class Generator(object):

    def update_params(self, input_tensor):
        '''Update the parameters of the network.

        Args:
            input_tensor: a batch of flattened images

        Returns:
            Current loss value
        '''
        raise NotImplementedError()

    def generate_and_save_images(self, num_samples, directory):
        '''Generate images with the model and save them in the directory.

        Args:
            num_samples: number of samples to generate
            directory: a directory to save the images in
        '''
        imgs = self.sess.run(self.sampled_tensor)

        imgs_folder = os.path.join(directory, 'imgs')
        if not os.path.exists(imgs_folder):
            os.makedirs(imgs_folder)

        for k in range(imgs.shape[0]):
            imsave(os.path.join(imgs_folder, '%d.png' % k),
                   imgs[k].reshape(28, 28))
Training script
@@ -0,0 +1,63 @@
'''Training script for the VAE (http://arxiv.org/pdf/1312.6114v10.pdf) and GAN models on MNIST.'''

from __future__ import absolute_import, division, print_function

import math
import os

import numpy as np
import scipy.misc
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib import losses
from tensorflow.contrib.framework import arg_scope
from scipy.misc import imsave
from tensorflow.examples.tutorials.mnist import input_data

from deconv import deconv2d
from progressbar import ETA, Bar, Percentage, ProgressBar

from vae import VAE
from gan import GAN

flags = tf.flags
logging = tf.logging

flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("updates_per_epoch", 1000, "number of updates per epoch")
flags.DEFINE_integer("max_epoch", 100, "max epoch")
flags.DEFINE_float("learning_rate", 1e-2, "learning rate")
flags.DEFINE_string("working_directory", "", "directory for the MNIST data and generated images")
flags.DEFINE_integer("hidden_size", 128, "size of the hidden VAE unit")
flags.DEFINE_string("model", "gan", "gan or vae")

FLAGS = flags.FLAGS

if __name__ == "__main__":
    data_directory = os.path.join(FLAGS.working_directory, "MNIST")
    if not os.path.exists(data_directory):
        os.makedirs(data_directory)
    mnist = input_data.read_data_sets(data_directory, one_hot=True)

    assert FLAGS.model in ['vae', 'gan']
    if FLAGS.model == 'vae':
        model = VAE(FLAGS.hidden_size, FLAGS.batch_size, FLAGS.learning_rate)
    elif FLAGS.model == 'gan':
        model = GAN(FLAGS.hidden_size, FLAGS.batch_size, FLAGS.learning_rate)

    for epoch in range(FLAGS.max_epoch):
        training_loss = 0.0

        pbar = ProgressBar()
        for i in pbar(range(FLAGS.updates_per_epoch)):
            images, _ = mnist.train.next_batch(FLAGS.batch_size)
            loss_value = model.update_params(images)
            training_loss += loss_value

        # Average the summed batch losses over all images seen this epoch.
        training_loss = training_loss / \
            (FLAGS.updates_per_epoch * FLAGS.batch_size)

        print("Loss %f" % training_loss)

        model.generate_and_save_images(
            FLAGS.batch_size, FLAGS.working_directory)
utils.py
@@ -0,0 +1,56 @@
import tensorflow as tf
from tensorflow.contrib import layers


def encoder(input_tensor, output_size):
    '''Create the encoder network.

    Args:
        input_tensor: a batch of flattened images [batch_size, 28*28]
        output_size: number of units in the final fully connected layer

    Returns:
        A tensor that represents the encoder network
    '''
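    # Spatial dimensions for 28x28 inputs:
    #   28x28x1 -> 14x14x32 -> 7x7x64 -> 2x2x128 (VALID) -> flatten (512) -> output_size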
    net = tf.reshape(input_tensor, [-1, 28, 28, 1])
    net = layers.conv2d(net, 32, 5, stride=2)
    net = layers.conv2d(net, 64, 5, stride=2)
    net = layers.conv2d(net, 128, 5, stride=2, padding='VALID')
    net = layers.dropout(net, keep_prob=0.9)
    net = layers.flatten(net)
    return layers.fully_connected(net, output_size, activation_fn=None)


def discriminator(input_tensor):
    '''Create a network that discriminates between images from a dataset and
    generated ones.

    Args:
        input_tensor: a batch of flattened images [batch_size, 28*28]

    Returns:
        A single-unit logit tensor that represents the discriminator network
    '''
    return encoder(input_tensor, 1)


def decoder(input_tensor):
    '''Create the decoder network, which maps a batch of latent vectors to
    flattened 28x28 images.

    Args:
        input_tensor: a batch of latent vectors to decode [batch_size, hidden_size]

    Returns:
        A tensor that represents the decoder network
    '''
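    # Spatial dimensions of the transposed convolutions:
    #   1x1xC -> 3x3x128 (VALID) -> 7x7x64 (VALID) -> 14x14x32 -> 28x28x1 -> flatten (784)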
    net = tf.expand_dims(input_tensor, 1)
    net = tf.expand_dims(net, 1)
    net = layers.conv2d_transpose(net, 128, 3, padding='VALID')
    net = layers.conv2d_transpose(net, 64, 5, padding='VALID')
    net = layers.conv2d_transpose(net, 32, 5, stride=2)
    net = layers.conv2d_transpose(
        net, 1, 5, stride=2, activation_fn=tf.nn.sigmoid)
    net = layers.flatten(net)
    return net
vae.py
@@ -0,0 +1,87 @@
'''TensorFlow implementation of http://arxiv.org/pdf/1312.6114v10.pdf'''

from __future__ import absolute_import, division, print_function

import math

import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib import losses
from tensorflow.contrib.framework import arg_scope

from utils import encoder, decoder
from generator import Generator


class VAE(Generator):

    def __init__(self, hidden_size, batch_size, learning_rate):
        self.input_tensor = tf.placeholder(
            tf.float32, [None, 28 * 28])

        with arg_scope([layers.conv2d, layers.conv2d_transpose],
                       activation_fn=tf.nn.elu,
                       normalizer_fn=layers.batch_norm,
                       normalizer_params={'scale': True}):
            with tf.variable_scope("model") as scope:
                encoded = encoder(self.input_tensor, hidden_size * 2)

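                # The encoder outputs 2 * hidden_size units: the first half is
                # the posterior mean, the second half is interpreted as the
                # log-variance, so stddev = sqrt(exp(log_var)). Sampling uses
                # the reparameterization trick, z = mean + stddev * epsilon
                # with epsilon ~ N(0, I), which keeps the sample differentiable
                # with respect to mean and stddev.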
                mean = encoded[:, :hidden_size]
                stddev = tf.sqrt(tf.exp(encoded[:, hidden_size:]))

                epsilon = tf.random_normal([tf.shape(mean)[0], hidden_size])
                input_sample = mean + epsilon * stddev

                output_tensor = decoder(input_sample)

            with tf.variable_scope("model", reuse=True) as scope:
                self.sampled_tensor = decoder(tf.random_normal(
                    [batch_size, hidden_size]))

        vae_loss = self.__get_vae_cost(mean, stddev)
        rec_loss = self.__get_reconstruction_cost(
            output_tensor, self.input_tensor)

        loss = vae_loss + rec_loss
        self.train = layers.optimize_loss(loss, tf.contrib.framework.get_or_create_global_step(
        ), learning_rate=learning_rate, optimizer='Adam', update_ops=[])

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def __get_vae_cost(self, mean, stddev, epsilon=1e-8):
        '''VAE latent loss: KL divergence between the approximate posterior
        N(mean, stddev^2) and the standard normal prior N(0, I).

        Args:
            mean: mean of the approximate posterior
            stddev: standard deviation of the approximate posterior
            epsilon: small constant for numerical stability of the log
        '''
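        # Closed-form KL divergence for diagonal Gaussians:
        #   KL(N(mu, sigma^2) || N(0, I)) = 0.5 * sum(mu^2 + sigma^2 - 2*log(sigma) - 1)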
        return tf.reduce_sum(0.5 * (tf.square(mean) + tf.square(stddev) -
                                    2.0 * tf.log(stddev + epsilon) - 1.0))

    def __get_reconstruction_cost(self, output_tensor, target_tensor, epsilon=1e-8):
        '''Reconstruction loss: Bernoulli cross-entropy between the decoder
        output and the target images.

        Args:
            output_tensor: tensor produced by the decoder
            target_tensor: the target tensor that we want to reconstruct
            epsilon: small constant for numerical stability of the logs
        '''
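        # Summed binary cross-entropy over pixels:
        #   -sum(x * log(x_hat) + (1 - x) * log(1 - x_hat))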
        return tf.reduce_sum(-target_tensor * tf.log(output_tensor + epsilon) -
                             (1.0 - target_tensor) * tf.log(1.0 - output_tensor + epsilon))

    def update_params(self, input_tensor):
        '''Update the parameters of the network.

        Args:
            input_tensor: a batch of flattened images [batch_size, 28*28]

        Returns:
            Current loss value
        '''
        return self.sess.run(self.train, {self.input_tensor: input_tensor})