
Commit 0d840ee

Lukasz Kaiser authored and Copybara-Service committed
Make TransformerLM train reasonably well in trax. Add loss and metric masking and refactor dropout in Transformer.
PiperOrigin-RevId: 239692595
1 parent eedd6d7 commit 0d840ee

5 files changed: +60 −39 lines changed

tensor2tensor/trax/configs/resnet50_imagenet_8gb.gin

+1-1
@@ -42,5 +42,5 @@ train.eval_steps = 20
 train.inputs = @trax.inputs.inputs
 train.model = @trax.models.Resnet50
 train.optimizer = @trax.optimizers.momentum
-train.train_steps = 500000
+train.train_steps = 1000000
 train.lr_schedule = @learning_rate.EvalAdjustingSchedule

tensor2tensor/trax/configs/transformer_lm1b_8gb.gin

+11-7
@@ -5,28 +5,32 @@ import tensor2tensor.trax.trax

 # Parameters for batch_fun:
 # ==============================================================================
-batch_fun.batch_size = 32
-batch_fun.eval_batch_size = 32
+batch_fun.batch_size = 128
+batch_fun.eval_batch_size = 128

 # Parameters for inputs:
 # ==============================================================================
 inputs.data_dir = None
 inputs.dataset_name = 't2t_languagemodel_lm1b32k'

+# Parameters for mask:
+# ==============================================================================
+mask.mask_id = 0
+
 # Parameters for MultifactorSchedule:
 # ==============================================================================
-MultifactorSchedule.constant = 0.05
+MultifactorSchedule.constant = 0.1
 MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay'
 MultifactorSchedule.warmup_steps = 8000

 # Parameters for preprocess_fun:
 # ==============================================================================
-preprocess_fun.max_target_length = 256
+preprocess_fun.max_target_length = 512

 # Parameters for train:
 # ==============================================================================
 train.eval_frequency = 1000
-train.eval_steps = 1
+train.eval_steps = 5
 train.inputs = @trax.inputs.inputs
 train.model = @trax.models.TransformerLM
 train.run_debug_step = False
@@ -38,10 +42,10 @@ train_and_eval_batches.input_name = 'targets'

 # Parameters for TransformerLM:
 # ==============================================================================
-TransformerLM.dropout = 0.1
+TransformerLM.dropout = 0.2
 TransformerLM.feature_depth = 512
 TransformerLM.feedforward_depth = 2048
-TransformerLM.max_len = 256
+TransformerLM.max_len = 512
 TransformerLM.mode = 'train'
 TransformerLM.num_heads = 8
 TransformerLM.num_layers = 6
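
For context, a minimal sketch of how bindings like the ones above reach Python code: a function decorated with @gin.configurable picks up any parameter bound in the .gin file once the config is parsed. The demo_mask function and the binding string below are hypothetical, used only to illustrate the mechanism; which trax function the real mask.mask_id binding configures is defined elsewhere in the source.

import gin

@gin.configurable
def demo_mask(values, mask_id=None):
  # Hypothetical stand-in for a trax configurable; drops entries equal to mask_id.
  return [v for v in values if v != mask_id]

gin.parse_config("demo_mask.mask_id = 0")  # same binding style as the .gin file above
print(demo_mask([0, 3, 0, 7]))  # -> [3, 7]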

tensor2tensor/trax/inputs.py

+4-2
@@ -191,10 +191,12 @@ def batch_fun(dataset, training, shapes, target_names,
   if variable_target_shapes:
     bucket_boundaries = [bucket_length // 4, bucket_length // 2,
                          bucket_length, bucket_length * 2,
-                         bucket_length * 4, bucket_length * 8]
+                         bucket_length * 4, bucket_length * 8,
+                         bucket_length * 16]
     bucket_batch_sizes = [cur_batch_size * 4, cur_batch_size * 2,
                           cur_batch_size, cur_batch_size // 2,
-                          cur_batch_size // 4, cur_batch_size // 8, 1]
+                          cur_batch_size // 4, cur_batch_size // 8,
+                          max(1, cur_batch_size // 16), 1]
     buckets = (bucket_boundaries, bucket_batch_sizes)

   if buckets:
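
A rough standalone illustration of the bucketing scheme above (the bucket_length and cur_batch_size values are made up; in batch_fun they come from the caller): each boundary gets its own batch size, plus one final entry for sequences longer than the last boundary, so the batch-size list must be exactly one element longer than the boundary list.

# Hypothetical values; batch_fun derives the real ones from its arguments.
bucket_length, cur_batch_size = 32, 128
bucket_boundaries = [bucket_length // 4, bucket_length // 2,
                     bucket_length, bucket_length * 2,
                     bucket_length * 4, bucket_length * 8,
                     bucket_length * 16]
bucket_batch_sizes = [cur_batch_size * 4, cur_batch_size * 2,
                      cur_batch_size, cur_batch_size // 2,
                      cur_batch_size // 4, cur_batch_size // 8,
                      max(1, cur_batch_size // 16), 1]
# One batch size per bucket, including the overflow bucket past the last boundary.
assert len(bucket_batch_sizes) == len(bucket_boundaries) + 1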

tensor2tensor/trax/models/transformer.py

+28-25
@@ -29,7 +29,7 @@ def TransformerEncoder(mode='train',  # pylint: disable=invalid-name
                        feature_depth=512,
                        feedforward_depth=2048,
                        num_heads=8,
-                       dropout=0.9):
+                       dropout=0.1):
   """Transformer Encoder Stack.

   Args:
@@ -38,20 +38,22 @@ def TransformerEncoder(mode='train',  # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out; note that stax follows
+      Tensorflow's keep_rate convention, so we use 1 - dropout in calls below)

   Returns:
     A staxlayer for implementing a raw Transformer encoder stack. No embedding
     or positional signals are added by this layer.
   """
+  keep_rate = 1.0 - dropout
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )

@@ -74,11 +76,11 @@ def encoder(embedded_source, source_mask):
                                     stax.Identity,  # value
                                     source_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         embedded_source,
@@ -95,8 +97,8 @@ def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                   feature_depth=512,
                   feedforward_depth=2048,
                   num_heads=8,
-                  dropout=0.9,
-                  max_len=256):
+                  dropout=0.1,
+                  max_len=512):
   """Transformer language model (only uses the decoder part of Transformer).

   Args:
@@ -106,20 +108,21 @@ def TransformerLM(vocab_size,  # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out)
     max_len: int: maximum symbol length for positional encoding

   Returns:
     init and apply.
   """
+  keep_rate = 1.0 - dropout
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )

@@ -132,18 +135,18 @@ def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                                     stax.Identity,  # value
                                     stax.CausalMask(axis=-2)),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
       # feed-forward
       stax.residual(stax.LayerNorm(feature_depth),
                     feed_forward,
-                    stax.Dropout(dropout, mode=mode))
+                    stax.Dropout(keep_rate, mode=mode))
   )

   return stax.serial(
       stax.ShiftRight(),
       stax.Embedding(feature_depth, vocab_size),
       stax.PositionalEncoding(feature_depth, max_len=max_len),
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.repeat(decoder_layer, num_layers),
       stax.LayerNorm(feature_depth),
       stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
@@ -158,7 +161,7 @@ def Transformer(source_vocab_size,  # pylint: disable=invalid-name
                 feature_depth=512,
                 feedforward_depth=2048,
                 num_heads=8,
-                dropout=0.9,
+                dropout=0.1,
                 shared_embedding=True,
                 max_len=200,
                 return_evals=False):
@@ -172,7 +175,7 @@ def Transformer(source_vocab_size,  # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out)
     shared_embedding: bool: specify whether source/target embeddings are tied.
     max_len: int: maximum symbol length for positional encoding
     return_evals: bool: whether to generate decode-time evaluation functions
@@ -182,11 +185,11 @@ def Transformer(source_vocab_size,  # pylint: disable=invalid-name
     the 'evals' functions that itself returns a namedtuple containing evaluation
     functions for the trained encoder, decoder, and generator substax.
   """
-
+  keep_rate = 1.0 - dropout
   # Input embedding and positional encoding
   inject_position = stax.serial(
       stax.PositionalEncoding(feature_depth, max_len=max_len),
-      stax.Dropout(dropout, mode=mode)
+      stax.Dropout(keep_rate, mode=mode)
   )
   if shared_embedding:
     assert source_vocab_size == target_vocab_size
@@ -202,12 +205,12 @@ def Transformer(source_vocab_size,  # pylint: disable=invalid-name

   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )

@@ -231,11 +234,11 @@ def encoder(source, source_mask):
                                     stax.Identity,  # value
                                     source_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         source,
@@ -266,19 +269,19 @@ def decoder(memory, target, target_mask, memory_mask):
                                     stax.Identity,  # value
                                     target_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # target attends to encoded source
         stax.residual(stax.LayerNorm(feature_depth),
                       stax.multiplex(stax.Identity,  # query
                                      memory,  # key
                                      memory,  # value
                                      memory_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         target,
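
A small NumPy sketch of the convention change above, assuming (as the updated docstring states) that the stax layers treat their rate argument as a keep probability: callers now pass a drop probability and each model converts it once via keep_rate = 1.0 - dropout. The helper below is illustrative only, not the stax implementation.

import numpy as np

rng = np.random.default_rng(0)

def dropout_keep(x, keep_rate):
  # Keep each entry with probability keep_rate; scale survivors by 1/keep_rate.
  keep = rng.random(x.shape) < keep_rate
  return np.where(keep, x / keep_rate, 0.0)

dropout = 0.1              # user-facing rate: how much to drop
keep_rate = 1.0 - dropout  # what the stax layers expect
print(dropout_keep(np.ones((2, 4)), keep_rate))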

tensor2tensor/trax/trax.py

+16-4
@@ -46,26 +46,38 @@
 from tensorflow.io import gfile


+@gin.configurable(blacklist=["inputs", "targets"])
+def masked_mean(inputs, targets, mask_id=None):
+  """Mean of the inputs but counting only those where targets != mask_id."""
+  x = inputs.astype(np.float32)
+  if mask_id is None:
+    return np.mean(x)
+  unmask = 1.0 - np.equal(targets, mask_id).astype(np.float32)
+  return np.sum(x * unmask) / np.sum(unmask)
+
+
 def accuracy(batch, model_predictions):
   """Calculate accuracy."""
   _, targets = batch
   predicted_class = np.argmax(model_predictions, axis=-1)
-  return np.mean(predicted_class == targets)
+  correct = np.equal(predicted_class, targets)
+  return masked_mean(correct, targets)


 def neg_log_perplexity(batch, model_predictions):
   """Calculate negative log perplexity."""
   _, targets = batch
   hot_targets = stax.one_hot(targets, model_predictions.shape[-1])
-  return np.mean(np.sum(model_predictions * hot_targets, axis=-1))
+  xent = np.sum(model_predictions * hot_targets, axis=-1)
+  return masked_mean(xent, targets)


 def loss(params, batch, model_predict):
   """Calculate loss."""
   inputs, targets = batch
   preds = model_predict(params, inputs)
-  return - np.mean(np.sum(preds * stax.one_hot(targets, preds.shape[-1]),
-                          axis=-1))
+  xent = np.sum(preds * stax.one_hot(targets, preds.shape[-1]), axis=-1)
+  return - masked_mean(xent, targets)


 def log(s, stdout=True):
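
A standalone NumPy re-implementation of the new masked_mean, with a made-up padded batch, to show why the metrics change: without masking, padding positions (targets == mask_id) inflate accuracy.

import numpy as np

def masked_mean(inputs, targets, mask_id=None):
  # Mirrors the function added above.
  x = inputs.astype(np.float32)
  if mask_id is None:
    return np.mean(x)
  unmask = 1.0 - np.equal(targets, mask_id).astype(np.float32)
  return np.sum(x * unmask) / np.sum(unmask)

targets = np.array([5, 7, 0, 0])         # 0 is padding (mask_id)
predicted = np.array([5, 2, 0, 0])       # hypothetical argmax predictions
correct = np.equal(predicted, targets)
print(np.mean(correct))                  # 0.75 -- padding counted as correct
print(masked_mean(correct, targets, 0))  # 0.5  -- only real tokens counted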
