
Commit 60296a9

LSTM example
1 parent 4b63d6a commit 60296a9

File tree

hyperengine/examples/3_1_lstm_mnist.py
hyperengine/examples/README.md

2 files changed: +71 -0
hyperengine/examples/3_1_lstm_mnist.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

__author__ = 'maxim'

import tensorflow as tf

import hyperengine as hype

def rnn_model(params):
  x = tf.placeholder(tf.float32, [None, 28, 28], name='input')
  y = tf.placeholder(tf.int32, [None], name='label')

  # Each 28x28 image is fed to the LSTM stack as a sequence of 28 rows of 28 pixels.
  lstm_cells = [tf.nn.rnn_cell.BasicLSTMCell(num_units=layer) for layer in params.lstm.layers]
  multi_cell = tf.nn.rnn_cell.MultiRNNCell(lstm_cells)
  outputs, states = tf.nn.dynamic_rnn(multi_cell, x, dtype=tf.float32)

  # `states` holds the final state of each LSTM layer, so `states[-1]` is the state of the
  # top layer: an LSTMStateTuple whose element [1] is the `h` (short-term) state. Hence the name.
  top_layer_h_state = states[-1][1]
  logits = tf.layers.dense(top_layer_h_state, 10)
  xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
  loss = tf.reduce_mean(xentropy, name='loss')
  optimizer = tf.train.AdamOptimizer(learning_rate=params.learning_rate)
  optimizer.minimize(loss, name='minimize')
  correct = tf.nn.in_top_k(logits, y, 1)
  # The key ops are named ('loss', 'minimize', 'accuracy') so they can be found in the graph later.
  tf.reduce_mean(tf.cast(correct, tf.float32), name='accuracy')

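A side note on the `states` structure returned by `tf.nn.dynamic_rnn`: with a three-layer `MultiRNNCell`, it is a 3-tuple of `LSTMStateTuple(c, h)` objects. A minimal standalone sketch to verify the shapes, assuming TensorFlow 1.x and picking 256 units for the top layer (one of the spec choices below) purely for illustration:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28])
cells = [tf.nn.rnn_cell.BasicLSTMCell(n) for n in [128, 160, 256]]
outputs, states = tf.nn.dynamic_rnn(tf.nn.rnn_cell.MultiRNNCell(cells), x, dtype=tf.float32)

print(len(states))         # 3: one LSTMStateTuple(c, h) per layer
print(states[-1].h.shape)  # (?, 256): the same tensor as states[-1][1]
print(outputs.shape)       # (?, 28, 256): per-timestep h of the top layer
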
from tensorflow.examples.tutorials.mnist import input_data
tf_data_sets = input_data.read_data_sets('temp-mnist/data', one_hot=False)
# one_hot=False yields integer labels, matching sparse_softmax_cross_entropy_with_logits above.
# MNIST images arrive flattened to 784 values; reshape them back to 28x28 for the RNN.
convert = lambda data_set: hype.DataSet(data_set.images.reshape((-1, 28, 28)), data_set.labels)
data = hype.Data(train=convert(tf_data_sets.train),
                 validation=convert(tf_data_sets.validation),
                 test=convert(tf_data_sets.test))

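The reshape in `convert` just restores the 2-D image layout; a toy numpy check (illustrative only; 55000 is the standard MNIST train split size):

import numpy as np

flat = np.zeros((55000, 784), dtype=np.float32)  # images as returned by read_data_sets
images = flat.reshape((-1, 28, 28))              # one 28-pixel row per RNN time step
assert images.shape == (55000, 28, 28)
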
def solver_generator(params):
  rnn_model(params)

  solver_params = {
    'batch_size': 1000,
    'eval_batch_size': 2500,
    'epochs': 10,
    'evaluate_test': True,
    'eval_flexible': False,
    'save_dir': 'temp-mnist/model-zoo/example-3-1-{date}-{random_id}',
    # Only save models that reach at least 99% accuracy.
    'save_accuracy_limit': 0.99,
  }
  solver = hype.TensorflowSolver(data=data, hyper_params=params, **solver_params)
  return solver

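Note that rnn_model builds the graph but returns nothing, which suggests TensorflowSolver locates the model through the op names. A sketch of the lookup it presumably performs after rnn_model(params) has run (standard TF 1.x graph API; the lookup itself is my assumption about hyperengine's internals):

import tensorflow as tf

graph = tf.get_default_graph()
loss     = graph.get_tensor_by_name('loss:0')        # scalar loss tensor
accuracy = graph.get_tensor_by_name('accuracy:0')    # scalar accuracy tensor
minimize = graph.get_operation_by_name('minimize')   # training op
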
hyper_params_spec = hype.spec.new(
  # The exponent is sampled uniformly, making the rate log-uniform (see the sketch after this file).
  learning_rate = 10**hype.spec.uniform(-2, -3),
  lstm = hype.spec.new(
    # Three stacked LSTM layers; each layer's size is tuned independently.
    layers = [hype.spec.choice([128, 160, 256]),
              hype.spec.choice([128, 160, 256]),
              hype.spec.choice([128, 160, 256])]
  )
)
strategy_params = {
  # Persist and resume tuning progress between runs.
  'io_load_dir': 'temp-mnist/example-3-1',
  'io_save_dir': 'temp-mnist/example-3-1',
}

tuner = hype.HyperTuner(hyper_params_spec, solver_generator, **strategy_params)
tuner.tune()
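
For intuition, the learning-rate spec above is equivalent to this plain-Python sampler (a sketch assuming numpy, purely illustrative):

import numpy as np

def sample_learning_rate():
  # exponent uniform in [-3, -2]  =>  rate log-uniform in [1e-3, 1e-2]
  return 10 ** np.random.uniform(-3, -2)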

hyperengine/examples/README.md

Lines changed: 4 additions & 0 deletions
@@ -28,3 +28,7 @@ All examples are self-contained to ease understanding.
exploring different variations of the all-conv-nets, which achieve state-of-the-art accuracy
with few parameters and low computational cost.
See the ["Striving for Simplicity: The All Convolutional Net"](https://arxiv.org/abs/1412.6806) paper for details.

#### 3. Recurrent Neural Networks
- [**LSTM to classify MNIST digits**](3_1_lstm_mnist.py):
  recurrent neural networks can process images too. Let's see if an LSTM can reach 99% accuracy with the right hyper-parameters.
