
Commit b5afddb

Author: Ryan Sepassi (committed)
Reproduce reported virtual adversarial text results
1 parent fc7342b commit b5afddb

17 files changed: 208 additions, 116 deletions

adversarial_text/BUILD

Lines changed: 21 additions & 0 deletions
@@ -1,10 +1,14 @@
+licenses(["notice"])  # Apache 2.0
+
 # Binaries
 # ==============================================================================
 py_binary(
     name = "evaluate",
     srcs = ["evaluate.py"],
     deps = [
         ":graphs",
+        # google3 file dep,
+        # tensorflow dep,
     ],
 )
 
@@ -14,6 +18,8 @@ py_binary(
     deps = [
         ":graphs",
         ":train_utils",
+        # google3 file dep,
+        # tensorflow dep,
     ],
 )
 
@@ -25,6 +31,8 @@ py_binary(
     deps = [
         ":graphs",
         ":train_utils",
+        # google3 file dep,
+        # tensorflow dep,
     ],
 )
 
@@ -37,30 +45,42 @@ py_library(
         ":adversarial_losses",
         ":inputs",
         ":layers",
+        # tensorflow dep,
     ],
 )
 
 py_library(
     name = "adversarial_losses",
     srcs = ["adversarial_losses.py"],
+    deps = [
+        # tensorflow dep,
+    ],
 )
 
 py_library(
     name = "inputs",
     srcs = ["inputs.py"],
     deps = [
+        # tensorflow dep,
         "//adversarial_text/data:data_utils",
     ],
 )
 
 py_library(
     name = "layers",
     srcs = ["layers.py"],
+    deps = [
+        # tensorflow dep,
+    ],
 )
 
 py_library(
     name = "train_utils",
     srcs = ["train_utils.py"],
+    deps = [
+        # numpy dep,
+        # tensorflow dep,
+    ],
 )
 
 # Tests
@@ -71,6 +91,7 @@ py_test(
     srcs = ["graphs_test.py"],
     deps = [
         ":graphs",
+        # tensorflow dep,
         "//adversarial_text/data:data_utils",
     ],
 )

adversarial_text/README.md

Lines changed: 9 additions & 10 deletions
@@ -56,7 +56,6 @@ $ bazel run :pretrain -- \
   --embedding_dims=256 \
   --rnn_cell_size=1024 \
   --num_candidate_samples=1024 \
-  --optimizer=adam \
   --batch_size=256 \
   --learning_rate=0.001 \
   --learning_rate_decay_factor=0.9999 \
@@ -87,7 +86,6 @@ $ bazel run :train_classifier -- \
   --rnn_cell_size=1024 \
   --cl_num_layers=1 \
   --cl_hidden_size=30 \
-  --optimizer=adam \
   --batch_size=64 \
   --learning_rate=0.0005 \
   --learning_rate_decay_factor=0.9998 \
@@ -96,7 +94,8 @@ $ bazel run :train_classifier -- \
   --num_timesteps=400 \
   --keep_prob_emb=0.5 \
   --normalize_embeddings \
-  --adv_training_method=vat
+  --adv_training_method=vat \
+  --perturb_norm_length=5.0
 ```
 
 ### Evaluate on test data
@@ -136,21 +135,21 @@ adversarial training losses). The training loop itself is defined in
 ### Command-Line Flags
 
 Flags related to distributed training and the training loop itself are defined
-in `train_utils.py`.
+in [`train_utils.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/train_utils.py).
 
-Flags related to model hyperparameters are defined in `graphs.py`.
+Flags related to model hyperparameters are defined in [`graphs.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/graphs.py).
 
-Flags related to adversarial training are defined in `adversarial_losses.py`.
+Flags related to adversarial training are defined in [`adversarial_losses.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/adversarial_losses.py).
 
 Flags particular to each job are defined in the main binary files.
 
 ### Data Generation
 
-* Vocabulary generation: `gen_vocab.py`
-* Data generation: `gen_data.py`
+* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_vocab.py)
+* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_data.py)
 
-Command-line flags defined in `document_generators.py` control which dataset is
-processed and how.
+Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/document_generators.py)
+control which dataset is processed and how.
 
 ## Contact for Issues
 

adversarial_text/adversarial_losses.py

Lines changed: 27 additions & 41 deletions
@@ -1,4 +1,4 @@
-# Copyright 2017 Google, Inc. All Rights Reserved.
+# Copyright 2017 Google Inc. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,25 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Adversarial losses for text models."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+# Dependency imports
+
 import tensorflow as tf
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
 
 # Adversarial and virtual adversarial training parameters.
-flags.DEFINE_float('perturb_norm_length', 0.1,
+flags.DEFINE_float('perturb_norm_length', 5.0,
                    'Norm length of adversarial perturbation to be '
-                   'optimized with validation')
+                   'optimized with validation. '
+                   '5.0 is optimal on IMDB with virtual adversarial training. ')
 
 # Virtual adversarial training parameters
 flags.DEFINE_integer('num_power_iteration', 1, 'The number of power iteration')
-flags.DEFINE_float('small_constant_for_finite_diff', 1e-3,
+flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
                    'Small constant for finite difference method')
 
 # Parameters for building the graph
@@ -83,19 +85,22 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   """
   # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
   logits = tf.stop_gradient(logits)
+
   # Only care about the KL divergence on the final timestep.
-  weights = _end_of_seq_mask(inputs.labels)
+  weights = inputs.eos_weights
+  assert weights is not None
 
   # Initialize perturbation with random noise.
   # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
-  d = _mask_by_length(tf.random_normal(shape=tf.shape(embedded)), inputs.length)
+  d = tf.random_normal(shape=tf.shape(embedded))
 
   # Perform finite difference method and power iteration.
   # See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf,
   # Adding small noise to input and taking gradient with respect to the noise
   # corresponds to 1 power iteration.
   for _ in xrange(FLAGS.num_power_iteration):
-    d = _scale_l2(d, FLAGS.small_constant_for_finite_diff)
+    d = _scale_l2(
+        _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
     d_logits = logits_from_embedding_fn(embedded + d)
     kl = _kl_divergence_with_logits(logits, d_logits, weights)
     d, = tf.gradients(
@@ -104,8 +109,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
         aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
     d = tf.stop_gradient(d)
 
-  perturb = _scale_l2(
-      _mask_by_length(d, inputs.length), FLAGS.perturb_norm_length)
+  perturb = _scale_l2(d, FLAGS.perturb_norm_length)
   vadv_logits = logits_from_embedding_fn(embedded + perturb)
   return _kl_divergence_with_logits(logits, vadv_logits, weights)
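Read together, the two hunks above change where masking happens in the VAT loss: the random direction is now masked (with the EOS position trimmed, per the `_mask_by_length` change further down) inside each power iteration, and the converged direction is only rescaled at the end. Below is a minimal sketch of the resulting function, assembled from the diff for illustration; the exact signature and the arguments to `tf.gradients` are abbreviated from context, not quoted verbatim from the file.

def virtual_adversarial_loss(logits, embedded, inputs, logits_from_embedding_fn):
  """Sketch of the VAT loss as it reads after this commit (illustrative)."""
  logits = tf.stop_gradient(logits)

  # KL divergence is only measured on the EOS timestep.
  weights = inputs.eos_weights
  assert weights is not None

  # Random initial direction; masking now happens inside the loop instead.
  d = tf.random_normal(shape=tf.shape(embedded))

  for _ in xrange(FLAGS.num_power_iteration):
    # Mask padded/EOS positions, then shrink to the finite-difference radius.
    d = _scale_l2(
        _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
    d_logits = logits_from_embedding_fn(embedded + d)
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    d, = tf.gradients(
        kl, d,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    d = tf.stop_gradient(d)

  # The converged direction is already masked, so it is only rescaled here.
  perturb = _scale_l2(d, FLAGS.perturb_norm_length)
  vadv_logits = logits_from_embedding_fn(embedded + perturb)
  return _kl_divergence_with_logits(logits, vadv_logits, weights)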

@@ -136,7 +140,8 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
   """Virtual adversarial loss for bidirectional models."""
   logits = tf.stop_gradient(logits)
   f_inputs, _ = inputs
-  weights = _end_of_seq_mask(f_inputs.labels)
+  weights = f_inputs.eos_weights
+  assert weights is not None
 
   perturbs = [
       _mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)
@@ -155,10 +160,7 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
       aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
   perturbs = [tf.stop_gradient(d) for d in perturbs]
 
-  perturbs = [
-      _scale_l2(_mask_by_length(d, f_inputs.length), FLAGS.perturb_norm_length)
-      for d in perturbs
-  ]
+  perturbs = [_scale_l2(d, FLAGS.perturb_norm_length) for d in perturbs]
   vadv_logits = logits_from_embedding_fn(
       [emb + d for (emb, d) in zip(embedded, perturbs)])
   return _kl_divergence_with_logits(logits, vadv_logits, weights)
@@ -167,40 +169,26 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
 def _mask_by_length(t, length):
   """Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
   maxlen = t.get_shape().as_list()[1]
-  mask = tf.sequence_mask(length, maxlen=maxlen)
+
+  # Subtract 1 from length to prevent the perturbation from going on 'eos'
+  mask = tf.sequence_mask(length - 1, maxlen=maxlen)
   mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
   # shape(mask) = (batch, num_timesteps, 1)
   return t * mask
 
 
 def _scale_l2(x, norm_length):
   # shape(x) = (batch, num_timesteps, d)
-
   # Divide x by max(abs(x)) for a numerically stable L2 norm.
   # 2norm(x) = a * 2norm(x/a)
   # Scale over the full sequence, dims (1, 2)
   alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
-  l2_norm = alpha * tf.sqrt(tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2),
-                                          keep_dims=True) + 1e-6)
+  l2_norm = alpha * tf.sqrt(
+      tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
   x_unit = x / l2_norm
   return norm_length * x_unit
 
 
-def _end_of_seq_mask(tokens):
-  """Generate a mask for the EOS token (1.0 on EOS, 0.0 otherwise).
-
-  Args:
-    tokens: 1-D integer tensor [num_timesteps*batch_size]. Each element is an
-      id from the vocab.
-
-  Returns:
-    Float tensor same shape as tokens, whose values are 1.0 on the end of
-    sequence and 0.0 on the others.
-  """
-  eos_id = FLAGS.vocab_size - 1
-  return tf.cast(tf.equal(tokens, eos_id), tf.float32)
-
-
 def _kl_divergence_with_logits(q_logits, p_logits, weights):
   """Returns weighted KL divergence between distributions q and p.
 
@@ -218,21 +206,19 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
   # For logistic regression
   if FLAGS.num_classes == 2:
     q = tf.nn.sigmoid(q_logits)
-    p = tf.nn.sigmoid(p_logits)
     kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
           tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
+    kl = tf.squeeze(kl)
 
   # For softmax regression
   else:
-    q = tf.nn.softmax(q_logits)
-    p = tf.nn.softmax(p_logits)
-    kl = tf.reduce_sum(q * (tf.log(q) - tf.log(p)), 1)
+    kl = tf.reduce_sum(
+        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
 
   num_labels = tf.reduce_sum(weights)
   num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
 
-  kl.get_shape().assert_has_rank(2)
+  kl.get_shape().assert_has_rank(1)
   weights.get_shape().assert_has_rank(1)
-  loss = tf.identity(tf.reduce_sum(tf.expand_dims(weights, -1) * kl) /
-                     num_labels, name='kl')
+  loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
   return loss
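The softmax branch above now builds the KL term from `tf.nn.log_softmax` rather than `tf.log(tf.nn.softmax(...))`, which avoids taking the log of a zero probability when a class saturates, and the per-example KL is kept rank-1 so it can be weighted directly by the scalar EOS weights. A small self-contained sketch of that weighted-KL computation, written independently of the file (the function and variable names here are illustrative only):

import tensorflow as tf

def weighted_kl_from_logits(q_logits, p_logits, weights):
  """KL(q || p) per example, averaged over positions with nonzero weight."""
  q = tf.nn.softmax(q_logits)
  # log_softmax is numerically stabler than tf.log(tf.nn.softmax(...)).
  kl = tf.reduce_sum(
      q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)

  # Normalize by the number of weighted positions, guarding against zero.
  num_labels = tf.reduce_sum(weights)
  num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
  return tf.reduce_sum(weights * kl) / num_labels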

adversarial_text/data/BUILD

Lines changed: 11 additions & 0 deletions
@@ -1,3 +1,5 @@
+licenses(["notice"])  # Apache 2.0
+
 package(
     default_visibility = [
         "//adversarial_text:__subpackages__",
@@ -10,6 +12,7 @@ py_binary(
     deps = [
         ":data_utils",
         ":document_generators",
+        # tensorflow dep,
     ],
 )
 
@@ -19,23 +22,31 @@ py_binary(
     deps = [
         ":data_utils",
         ":document_generators",
+        # tensorflow dep,
     ],
 )
 
 py_library(
     name = "document_generators",
     srcs = ["document_generators.py"],
+    deps = [
+        # tensorflow dep,
+    ],
 )
 
 py_library(
     name = "data_utils",
     srcs = ["data_utils.py"],
+    deps = [
+        # tensorflow dep,
+    ],
 )
 
 py_test(
     name = "data_utils_test",
     srcs = ["data_utils_test.py"],
     deps = [
         ":data_utils",
+        # tensorflow dep,
     ],
 )

adversarial_text/data/data_utils.py

Lines changed: 13 additions & 7 deletions
@@ -1,4 +1,4 @@
-# Copyright 2017 Google, Inc. All Rights Reserved.
+# Copyright 2017 Google Inc. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Utilities for generating/preprocessing data for adversarial text models."""
 
 import operator
 import os
 import random
 import re
+
+# Dependency imports
+
 import tensorflow as tf
 
 EOS_TOKEN = '</s>'
@@ -215,13 +217,17 @@ def build_lm_sequence(seq):
 
   Returns:
     SequenceWrapper with `seq` tokens copied over to output sequence tokens and
-    labels (offset by 1, i.e. predict next token) with weights set to 1.0.
+    labels (offset by 1, i.e. predict next token) with weights set to 1.0,
+    except for <eos> token.
   """
   lm_seq = SequenceWrapper()
-  for i, timestep in enumerate(seq[:-1]):
-    lm_seq.add_timestep().set_token(timestep.token).set_label(
-        seq[i + 1].token).set_weight(1.0)
-
+  for i, timestep in enumerate(seq):
+    if i == len(seq) - 1:
+      lm_seq.add_timestep().set_token(timestep.token).set_label(
+          seq[i].token).set_weight(0.0)
+    else:
+      lm_seq.add_timestep().set_token(timestep.token).set_label(
+          seq[i + 1].token).set_weight(1.0)
   return lm_seq
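To make the `build_lm_sequence` change concrete, here is the kind of (token, label, weight) triple it now emits for a short sequence; the tokens are hypothetical and the SequenceWrapper internals are elided:

# seq tokens: ['the', 'movie', '</s>']
#
#   token     label     weight
#   'the'     'movie'   1.0   # predict the next token, as before
#   'movie'   '</s>'    1.0
#   '</s>'    '</s>'    0.0   # EOS timestep kept, but masked out of the LM loss
#
# Previously the loop ran over seq[:-1], so the final EOS timestep was dropped
# entirely; emitting it with weight 0.0 keeps the EOS position in the sequence
# (presumably so downstream consumers such as the eos_weights used by the
# adversarial losses still see it) while contributing nothing to the
# language-model loss.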
0 commit comments