From 8a70a535f29b094e697194e08dca68863534921c Mon Sep 17 00:00:00 2001 From: "Hoenicke, Florian (DE - Berlin)" Date: Tue, 7 Aug 2018 12:18:54 +0200 Subject: [PATCH 1/2] python 2 to 3 --- .gitignore | 4 ++++ input_data.py | 8 ++++---- main.py | 4 ++-- utils.py | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..294075e --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.idea/ +__pycache__/ +training/ +MNIST_data/ diff --git a/input_data.py b/input_data.py index c3195ee..c5c75de 100644 --- a/input_data.py +++ b/input_data.py @@ -12,9 +12,9 @@ def maybe_download(filename, work_directory): os.mkdir(work_directory) filepath = os.path.join(work_directory, filename) if not os.path.exists(filepath): - filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath) + filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) statinfo = os.stat(filepath) - print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.' + print ('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') return filepath @@ -25,7 +25,7 @@ def _read32(bytestream): def extract_images(filename): """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" - print 'Extracting', filename + print ('Extracting', filename) with gzip.open(filename) as bytestream: magic = _read32(bytestream) if magic != 2051: @@ -52,7 +52,7 @@ def dense_to_one_hot(labels_dense, num_classes=10): def extract_labels(filename, one_hot=False): """Extract the labels into a 1D uint8 numpy array [index].""" - print 'Extracting', filename + print ('Extracting', filename) with gzip.open(filename) as bytestream: magic = _read32(bytestream) if magic != 2049: diff --git a/main.py b/main.py index 1c2678b..3bb3471 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,6 @@ import tensorflow as tf import numpy as np -import input_data +from tensorflow.examples.tutorials.mnist import input_data import matplotlib.pyplot as plt import os from scipy.misc import imsave as ims @@ -69,7 +69,7 @@ def train(self): _, gen_loss, lat_loss = sess.run((self.optimizer, self.generation_loss, self.latent_loss), feed_dict={self.images: batch}) # dumb hack to print cost every epoch if idx % (self.n_samples - 3) == 0: - print "epoch %d: genloss %f latloss %f" % (epoch, np.mean(gen_loss), np.mean(lat_loss)) + print ("epoch {}: genloss {} latloss {}".format(epoch, np.mean(gen_loss), np.mean(lat_loss))) saver.save(sess, os.getcwd()+"/training/train",global_step=epoch) generated_test = sess.run(self.generated_images, feed_dict={self.images: visualization}) generated_test = generated_test.reshape(self.batchsize,28,28) diff --git a/utils.py b/utils.py index e5ce3db..eeacd3c 100644 --- a/utils.py +++ b/utils.py @@ -7,7 +7,7 @@ def merge(images, size): for idx, image in enumerate(images): i = idx % size[1] - j = idx / size[1] + j = int(idx / size[1]) img[j*h:j*h+h, i*w:i*w+w] = image return img From cdcea1551e511e143bef6f764b20e58b1b0dda3e Mon Sep 17 00:00:00 2001 From: "Hoenicke, Florian (DE - Berlin)" Date: Tue, 7 Aug 2018 12:46:40 +0200 Subject: [PATCH 2/2] remove hack --- main.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 3bb3471..904edf2 100644 --- a/main.py +++ b/main.py @@ -64,16 +64,18 @@ def train(self): with tf.Session() as sess: sess.run(tf.initialize_all_variables()) for epoch in range(10): + gen_loss = None + lat_loss = None + for idx in range(int(self.n_samples / self.batchsize)): batch = self.mnist.train.next_batch(self.batchsize)[0] _, gen_loss, lat_loss = sess.run((self.optimizer, self.generation_loss, self.latent_loss), feed_dict={self.images: batch}) - # dumb hack to print cost every epoch - if idx % (self.n_samples - 3) == 0: - print ("epoch {}: genloss {} latloss {}".format(epoch, np.mean(gen_loss), np.mean(lat_loss))) - saver.save(sess, os.getcwd()+"/training/train",global_step=epoch) - generated_test = sess.run(self.generated_images, feed_dict={self.images: visualization}) - generated_test = generated_test.reshape(self.batchsize,28,28) - ims("results/"+str(epoch)+".jpg",merge(generated_test[:64],[8,8])) + + print ("epoch {}: genloss {} latloss {}".format(epoch, np.mean(gen_loss), np.mean(lat_loss))) + saver.save(sess, os.getcwd()+"/training/train",global_step=epoch) + generated_test = sess.run(self.generated_images, feed_dict={self.images: visualization}) + generated_test = generated_test.reshape(self.batchsize,28,28) + ims("results/"+str(epoch)+".jpg",merge(generated_test[:64],[8,8])) model = LatentAttention()