
Commit 17f8a39

Author: Kyu-Young
Commit message: init
0 parents  commit 17f8a39

File tree

4 files changed: +460 -0 lines changed


input_fn.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
"""Input pipeline using the Dataset API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from absl import flags

import tensorflow as tf
import tensorflow_datasets as tfds

FLAGS = flags.FLAGS

flags.DEFINE_integer('shuffle_buffer_size', 5000, 'Size of the shuffle buffer.')


class InputDataset(object):
  """Input pipeline for the IMDB dataset.

  Attributes:
    tokenizer: Tokenizer used to encode and decode text.
  """

  def __init__(self, encoding, max_length=None):
    """Creates an InputDataset instance.

    Args:
      encoding: Type of encoding to use. Should be one of 'plain_text',
        'bytes', 'subwords8k', or 'subwords32k'.
      max_length: Optional maximum sequence length; longer examples are
        filtered out.
    """
    if encoding not in ('plain_text', 'bytes', 'subwords8k', 'subwords32k'):
      raise ValueError('Unsupported encoding type %s' % encoding)

    loaded_imdb = tfds.load(
        'imdb_reviews/{}'.format(encoding), with_info=True, as_supervised=True)
    self._dataset, self._info = loaded_imdb
    self.tokenizer = self._info.features['text'].encoder
    self.max_length = max_length

  def input_fn(self, mode, batch_size, bucket_boundaries=None, bow=False):
    """Returns an instance of tf.data.Dataset.

    Args:
      mode: One of 'train' or 'test'.
      batch_size: Size of a batch.
      bucket_boundaries: List of boundaries for bucketing.
      bow: True to process the input as a bag-of-words.
    """
    if mode not in ('train', 'test'):
      raise ValueError('Unsupported mode type %s' % mode)
    dataset = self._dataset[mode]

    # Transform into a bag-of-words input if applicable.
    def bag_of_words(tokens, label):
      # Scatter a count of 1 for each token id into a vocab-sized vector.
      indices = tf.expand_dims(tokens, axis=-1)
      updates = tf.ones([tf.shape(indices)[0]])
      shape = tf.constant([self.tokenizer.vocab_size], dtype=indices.dtype)
      scatter = tf.scatter_nd(indices, updates, shape)
      return scatter, label
    if bow:
      dataset = dataset.map(bag_of_words, num_parallel_calls=12)

    # Filter out overlong examples and shuffle the data.
    if self.max_length:
      dataset = dataset.filter(lambda f, l: tf.shape(f)[0] < self.max_length)
    dataset = dataset.shuffle(
        buffer_size=FLAGS.shuffle_buffer_size, reshuffle_each_iteration=True)

    # Create batches of examples and pad.
    if mode == 'train' and bucket_boundaries:
      bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
      dataset = dataset.apply(
          tf.data.experimental.bucket_by_sequence_length(
              lambda feature, label: tf.shape(feature)[0],
              bucket_boundaries=bucket_boundaries,
              bucket_batch_sizes=bucket_batch_sizes,
              padded_shapes=dataset.output_shapes))
    else:
      output_shapes = dataset.output_shapes
      if self.max_length:
        output_shapes = (tf.TensorShape([tf.Dimension(self.max_length)]),
                         tf.TensorShape([]))
      dataset = dataset.padded_batch(batch_size, output_shapes)

    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset
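Note (not part of the commit): a minimal sketch of how this pipeline could be exercised, assuming TensorFlow 1.x with eager execution, tensorflow_datasets installed, and that the IMDB data can be downloaded on first use. The max_length and bucket boundary values are illustrative, and FLAGS is parsed manually so the shuffle_buffer_size flag default is readable outside absl's app.run.

import tensorflow as tf
from absl import flags

import input_fn

tf.enable_eager_execution()
flags.FLAGS(['example'])  # Parse flag defaults so FLAGS.shuffle_buffer_size is accessible.

dataset = input_fn.InputDataset('subwords8k', max_length=500)
train = dataset.input_fn('train', batch_size=16, bucket_boundaries=[100, 200, 400])
for features, labels in train.take(1):
  # features: [batch, padded_length] int token ids; labels: [batch] 0/1 sentiment.
  print(features.shape, labels.shape)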

knn.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
"""Sentiment analysis using KNN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import heapq
import random
import time

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_enum('mode', 'knn', ['knn', 'analyze'], 'Execution mode.')

flags.DEFINE_string('train_data', None, 'Train file in LIBSVM format.')

flags.DEFINE_string('test_data', None, 'Test file in LIBSVM format.')

flags.DEFINE_integer('k_value', 0, 'Value of k (number of neighbors).')


def parse_libsvm_file(filename):
  """Parses a file in LIBSVM format."""
  # Each data point holds its sparse features, squared norm, and label.
  data_points = []
  with open(filename) as f:
    for line in f:
      line = line.split()
      assert len(line) > 1
      d = {'features': {}, 'norm': 0.0, 'label': int(line[0])}
      for bow in line[1:]:
        word_id, num_occur = bow.split(':')
        num_occur = float(num_occur)
        d['features'][word_id] = num_occur
        d['norm'] += num_occur ** 2
      data_points.append(d)
  return data_points


def l2_dist(d1, d2):
  """Squared L2 distance between two sparse vectors represented as dicts."""
  # Iterate over the smaller feature dict; only shared keys contribute to the
  # dot product.
  if len(d1['features']) < len(d2['features']):
    return l2_dist(d2, d1)
  d1_norm, d2_norm = d1['norm'], d2['norm']
  return (d1_norm + d2_norm - 2 * sum(
      d1['features'].get(key, 0.0) * d2['features'].get(key, 0.0)
      for key in d2['features'].keys()))


def find_knn(data_points, d, k):
  """Finds the k nearest data points."""
  # Max-heap on distance (stored negated). The index breaks ties so that the
  # data-point dicts themselves are never compared.
  neighbors = []
  for i, data_point in enumerate(data_points):
    l2d = l2_dist(data_point, d)
    if len(neighbors) < k:
      heapq.heappush(neighbors, (-l2d, i, data_point))
    else:
      heapq.heappushpop(neighbors, (-l2d, i, data_point))
  return [item[2] for item in neighbors]


def run_knn(train_data_points, test_data_points, k):
  """Runs KNN and reports the overall error rates."""
  count = 0
  num_pos, num_neg = 0.0, 0.0
  num_pos_correct, num_neg_correct = 0.0, 0.0
  for test_data_point in test_data_points:
    count += 1
    if count % 1000 == 0:
      print('Processed {} examples.'.format(count))
    neighbors = find_knn(train_data_points, test_data_point, k)
    score = sum(neighbor['label'] for neighbor in neighbors)
    score /= float(len(neighbors))
    true_score = test_data_point['label']
    if true_score >= 7:
      num_pos += 1
      if score >= 7:
        num_pos_correct += 1
    if true_score <= 4:
      num_neg += 1
      if score <= 4:
        num_neg_correct += 1

  pos_error_rate = 1.0 - num_pos_correct / (num_pos + 1e-8)
  neg_error_rate = 1.0 - num_neg_correct / (num_neg + 1e-8)
  tot_error_rate = (
      1.0 - (num_pos_correct + num_neg_correct) / (num_pos + num_neg + 1e-8))
  print('Pos error rate: {}'.format(round(pos_error_rate, 5)))
  print('Neg error rate: {}'.format(round(neg_error_rate, 5)))
  print('Tot error rate: {}'.format(round(tot_error_rate, 5)))


def run_analysis(data_points):
  """Analyzes input data."""
  num_unique_words_dict = collections.defaultdict(int)
  num_total_words_dict = collections.defaultdict(int)
  for d in data_points:
    num_unique_words_dict[len(d['features']) // 100] += 1
    num_words = sum(d['features'].values())
    num_total_words_dict[num_words // 100] += 1
  num_total = float(len(data_points))
  # Averages are over the bucketed counts, i.e. in units of hundreds of words.
  avg_unique_words = (
      sum(k * v for k, v in num_unique_words_dict.items()) / num_total)
  avg_total_words = (
      sum(k * v for k, v in num_total_words_dict.items()) / num_total)
  print('Dist of unique words count: {}'.format(num_unique_words_dict))
  print('Dist of total words count: {}'.format(num_total_words_dict))
  print('Avg unique words (hundreds): {}'.format(round(avg_unique_words, 3)))
  print('Avg total words (hundreds): {}'.format(round(avg_total_words, 3)))


def main(unused_argv):
  print('Start parsing input data...')
  train_data_points = parse_libsvm_file(FLAGS.train_data)
  test_data_points = parse_libsvm_file(FLAGS.test_data)
  if FLAGS.mode == 'knn':
    random.shuffle(train_data_points)
    random.shuffle(test_data_points)
    print('Start running knn...')
    start = time.time()
    run_knn(train_data_points, test_data_points, FLAGS.k_value)
    end = time.time()
    print('Run time: {} secs'.format(round(end - start, 2)))
  elif FLAGS.mode == 'analyze':
    print('Analyze train data:')
    run_analysis(train_data_points)
    print('Analyze test data:')
    run_analysis(test_data_points)


if __name__ == '__main__':
  app.run(main)
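Note (not part of the commit): parse_libsvm_file expects standard LIBSVM-style lines, i.e. a label followed by sparse index:value pairs. A small self-contained sketch on hypothetical toy data, showing the parser and the squared-distance helper:

import os
import tempfile

import knn

# Two toy reviews: '10 3:2 7:1' means rating 10, word 3 occurs twice, word 7 once.
with tempfile.NamedTemporaryFile('w', suffix='.libsvm', delete=False) as f:
  f.write('10 3:2 7:1\n2 3:1 9:4\n')
  path = f.name

points = knn.parse_libsvm_file(path)
# Squared L2 distance between the two points: (2-1)^2 + (1-0)^2 + (0-4)^2 = 18.
print(knn.l2_dist(points[0], points[1]))
# The nearest neighbor of the first point is itself (distance 0).
print(knn.find_knn(points, points[0], k=1)[0]['label'])  # -> 10
os.remove(path)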

main.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
"""Main to run TensorFlow models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags

import tensorflow as tf

from . import input_fn
from . import model
from . import util

FLAGS = flags.FLAGS

flags.DEFINE_enum('mode', None, ['train', 'eval'], 'Execution mode.')

flags.DEFINE_string('logdir', '/tmp/sentiment-analysis', 'Model directory.')

flags.DEFINE_enum('model', 'rnn', ['mlp', 'rnn'], 'Type of model to use.')

flags.DEFINE_enum('optimizer', 'adam',
                  ['sgd', 'rmsprop', 'adam'],
                  'Type of optimizer to use for training.')

flags.DEFINE_enum('encoding', 'subwords8k',
                  ['plain_text', 'bytes', 'subwords8k', 'subwords32k'],
                  'Type of text encoding to use.')

flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to run for training.')

flags.DEFINE_integer('num_layers', 1, 'Number of hidden layers.')

flags.DEFINE_list('num_units', [64], 'Number of hidden units per layer.')

flags.DEFINE_enum('cell_type', 'lstm',
                  ['gru', 'lstm', 'bidi-gru', 'bidi-lstm'],
                  'Type of RNN cell to use.')

flags.DEFINE_integer('embedding_size', 32, 'Size of the input embedding.')

flags.DEFINE_integer('batch_size', 16, 'Size of the batch.')

flags.DEFINE_bool('verbose', True, 'Verbosity.')

flags.DEFINE_integer('max_length', None, 'Maximum length input to train on.')

flags.DEFINE_bool('early_stop', False, 'True to stop training early.')


def create_model(vocab_size):
  """Creates a Keras model."""
  num_units = [int(num_unit) for num_unit in FLAGS.num_units]
  if FLAGS.model == 'rnn':
    new_model = model.rnn_model(FLAGS.num_layers, FLAGS.cell_type, num_units,
                                vocab_size, FLAGS.embedding_size)
  else:
    new_model = model.mlp_model(FLAGS.num_layers, num_units, vocab_size)
  new_model.compile(optimizer=FLAGS.optimizer, loss='binary_crossentropy',
                    metrics=['accuracy'])
  new_model.summary()
  return new_model


def run_train():
  """Trains a model."""
  # Set up the input pipeline.
  input_dataset = input_fn.InputDataset(FLAGS.encoding)
  tokenizer = input_dataset.tokenizer

  use_bow = (FLAGS.model == 'mlp')
  train_dataset = input_dataset.input_fn('train', FLAGS.batch_size, bow=use_bow)
  test_dataset = input_dataset.input_fn('test', 10, bow=use_bow)

  new_model = create_model(tokenizer.vocab_size)
  latest_checkpoint = tf.train.latest_checkpoint(FLAGS.logdir)
  if latest_checkpoint:
    print('Reloading from {}'.format(latest_checkpoint))
    new_model.load_weights(latest_checkpoint)

  # Define callbacks to run during training.
  callbacks = []

  checkpoint = util.CNSModelCheckpoint(os.path.join(FLAGS.logdir, FLAGS.model))
  callbacks.append(checkpoint)

  tensorboard = tf.keras.callbacks.TensorBoard(
      log_dir=FLAGS.logdir, update_freq='batch')
  callbacks.append(tensorboard)

  if FLAGS.early_stop:
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy', min_delta=0.0001, patience=10)
    callbacks.append(early_stop)

  # Start training.
  history = new_model.fit(train_dataset, epochs=FLAGS.num_epochs,
                          callbacks=callbacks,
                          validation_data=test_dataset,
                          validation_steps=25,
                          verbose=int(FLAGS.verbose))

  # Write out the training history under the model directory.
  if not tf.gfile.Exists(FLAGS.logdir):
    tf.gfile.MakeDirs(FLAGS.logdir)
  with tf.gfile.GFile(os.path.join(FLAGS.logdir, 'history.txt'), 'w') as f:
    f.write(str(history.history))


def run_eval():
  """Evaluates a model."""
  # Set up the input pipeline.
  input_dataset = input_fn.InputDataset(FLAGS.encoding)
  tokenizer = input_dataset.tokenizer

  use_bow = (FLAGS.model == 'mlp')
  dataset = input_dataset.input_fn('test', FLAGS.batch_size, bow=use_bow)

  new_model = create_model(tokenizer.vocab_size)
  latest_checkpoint = tf.train.latest_checkpoint(FLAGS.logdir)
  if latest_checkpoint:
    print('Reloading from {}'.format(latest_checkpoint))
    new_model.load_weights(latest_checkpoint)

  ret = new_model.evaluate(dataset)
  print('Eval results: {}'.format(ret))


def main(unused_argv):
  if FLAGS.mode == 'train':
    run_train()
  elif FLAGS.mode == 'eval':
    run_eval()


if __name__ == '__main__':
  tf.app.run(main)
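Note (not part of the commit): main.py uses relative imports (from . import input_fn), so it is intended to be run as a module from the directory containing the package. The package name below (sentiment) is an assumption, as is the presence of the model.py and util.py modules, which are referenced above but not shown in this excerpt. Example invocations under those assumptions:

# Train a bidirectional LSTM on the subword-8k encoding.
python -m sentiment.main --mode=train --model=rnn --cell_type=bidi-lstm \
    --num_layers=2 --num_units=64,32 --encoding=subwords8k \
    --batch_size=32 --logdir=/tmp/sentiment-analysis

# Evaluate the latest checkpoint in the log directory.
python -m sentiment.main --mode=eval --model=rnn --cell_type=bidi-lstm \
    --num_layers=2 --num_units=64,32 --encoding=subwords8k \
    --logdir=/tmp/sentiment-analysis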
