|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +__author__ = 'maxim' |
| 5 | + |
| 6 | +import tensorflow as tf |
| 7 | + |
| 8 | +import hyperengine as hype |
| 9 | + |
| 10 | + |
def rnn_model(params):
    """Assemble a stacked-LSTM classifier out of TensorFlow ops.

    One ``BasicLSTMCell`` is created for every entry of ``params.lstm.layers``
    and the cells are stacked with ``MultiRNNCell``. The final ``h`` state of
    the top cell is fed to a 10-unit dense layer. Nothing is returned; the
    ops are created under the names ``input``, ``label``, ``loss``,
    ``minimize`` and ``accuracy``.

    Args:
        params: hyper-parameter object; ``params.lstm.layers`` (iterable of
            unit counts) and ``params.learning_rate`` are read here.
    """
    images = tf.placeholder(tf.float32, [None, 28, 28], name='input')
    labels = tf.placeholder(tf.int32, [None], name='label')

    # One LSTM cell per requested layer size, stacked bottom to top.
    cells = []
    for units in params.lstm.layers:
        cells.append(tf.nn.rnn_cell.BasicLSTMCell(num_units=units))
    stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
    # The per-step outputs are not used, only the final states.
    _, final_states = tf.nn.dynamic_rnn(stacked_cell, images, dtype=tf.float32)

    # `final_states` carries one final state per stacked layer, so the last
    # entry belongs to the top layer; index 1 of that entry is its `h`
    # (short-term) state.
    top_h = final_states[-1][1]
    logits = tf.layers.dense(top_h, 10)

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels,
        logits=logits,
    )
    loss = tf.reduce_mean(per_example_loss, name='loss')

    adam = tf.train.AdamOptimizer(learning_rate=params.learning_rate)
    adam.minimize(loss, name='minimize')

    # A prediction counts as a hit when the true label is the top-1 logit.
    hits = tf.nn.in_top_k(logits, labels, 1)
    tf.reduce_mean(tf.cast(hits, tf.float32), name='accuracy')
| 29 | + |
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST with integer (not one-hot) labels from the local data directory.
tf_data_sets = input_data.read_data_sets('temp-mnist/data', one_hot=False)


def convert(data_set):
    """Wrap one MNIST split in a ``hype.DataSet``.

    The split's images are reshaped to ``(-1, 28, 28)`` to match the model's
    ``input`` placeholder; the labels are passed through unchanged.

    Args:
        data_set: a split exposing ``images`` and ``labels`` attributes.

    Returns:
        A ``hype.DataSet`` built from the reshaped images and the labels.
    """
    return hype.DataSet(data_set.images.reshape((-1, 28, 28)), data_set.labels)


data = hype.Data(
    train=convert(tf_data_sets.train),
    validation=convert(tf_data_sets.validation),
    test=convert(tf_data_sets.test),
)
| 36 | + |
def solver_generator(params):
    """Create the model for ``params`` and return a solver configured for it.

    Args:
        params: hyper-parameter point; forwarded to ``rnn_model`` and to the
            solver as ``hyper_params``.

    Returns:
        A ``hype.TensorflowSolver`` constructed with the module-level ``data``
        and the fixed training settings below.
    """
    # The model ops must exist before the solver is constructed.
    rnn_model(params)

    return hype.TensorflowSolver(
        data=data,
        hyper_params=params,
        batch_size=1000,
        eval_batch_size=2500,
        epochs=10,
        evaluate_test=True,
        eval_flexible=False,
        save_dir='temp-mnist/model-zoo/example-3-1-{date}-{random_id}',
        save_accuracy_limit=0.99,
    )
| 51 | + |
| 52 | + |
# Search space: the learning rate is 10 raised to a `uniform(-2, -3)` node
# (presumably a log-scale draw between 1e-3 and 1e-2 -- confirm against
# hyperengine), and each of the three LSTM layers picks its own size.
learning_rate_spec = 10 ** hype.spec.uniform(-2, -3)
lstm_spec = hype.spec.new(
    layers=[
        hype.spec.choice([128, 160, 256]),
        hype.spec.choice([128, 160, 256]),
        hype.spec.choice([128, 160, 256]),
    ],
)
hyper_params_spec = hype.spec.new(
    learning_rate=learning_rate_spec,
    lstm=lstm_spec,
)

# The tuner loads from and saves to the same directory.
strategy_params = dict(
    io_load_dir='temp-mnist/example-3-1',
    io_save_dir='temp-mnist/example-3-1',
)

tuner = hype.HyperTuner(hyper_params_spec, solver_generator, **strategy_params)
tuner.tune()
0 commit comments