@@ -29,7 +29,7 @@ def TransformerEncoder(mode='train', # pylint: disable=invalid-name
                        feature_depth=512,
                        feedforward_depth=2048,
                        num_heads=8,
-                       dropout=0.9):
+                       dropout=0.1):
   """Transformer Encoder Stack.

   Args:
@@ -38,20 +38,22 @@ def TransformerEncoder(mode='train', # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out; note that stax follows
+      Tensorflow's keep_rate convention, so we use 1 - dropout in calls below)

   Returns:
     A staxlayer for implementing a raw Transformer encoder stack. No embedding
     or positional signals are added by this layer.
   """
+  keep_rate = 1.0 - dropout
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )

@@ -74,11 +76,11 @@ def encoder(embedded_source, source_mask):
                                      stax.Identity,  # value
                                      source_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         embedded_source,
@@ -95,8 +97,8 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
                   feature_depth=512,
                   feedforward_depth=2048,
                   num_heads=8,
-                  dropout=0.9,
-                  max_len=256):
+                  dropout=0.1,
+                  max_len=512):
   """Transformer language model (only uses the decoder part of Transformer).

   Args:
@@ -106,20 +108,21 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out)
     max_len: int: maximum symbol length for positional encoding

   Returns:
     init and apply.
   """
+  keep_rate = 1.0 - dropout
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )

@@ -132,18 +135,18 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
                                    stax.Identity,  # value
                                    stax.CausalMask(axis=-2)),  # attention mask
                     multi_attention,
-                    stax.Dropout(dropout, mode=mode)),
+                    stax.Dropout(keep_rate, mode=mode)),
       # feed-forward
       stax.residual(stax.LayerNorm(feature_depth),
                     feed_forward,
-                    stax.Dropout(dropout, mode=mode))
+                    stax.Dropout(keep_rate, mode=mode))
   )

   return stax.serial(
       stax.ShiftRight(),
       stax.Embedding(feature_depth, vocab_size),
       stax.PositionalEncoding(feature_depth, max_len=max_len),
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.repeat(decoder_layer, num_layers),
       stax.LayerNorm(feature_depth),
       stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
@@ -158,7 +161,7 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
                 feature_depth=512,
                 feedforward_depth=2048,
                 num_heads=8,
-                dropout=0.9,
+                dropout=0.1,
                 shared_embedding=True,
                 max_len=200,
                 return_evals=False):
@@ -172,7 +175,7 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out)
     shared_embedding: bool: specify whether source/target embeddings are tied.
     max_len: int: maximum symbol length for positional encoding
     return_evals: bool: whether to generate decode-time evaluation functions
@@ -182,11 +185,11 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
     the 'evals' functions that itself returns a namedtuple containing evaluation
     functions for the trained encoder, decoder, and generator substax.
   """
-
+  keep_rate = 1.0 - dropout
   # Input embedding and positional encoding
   inject_position = stax.serial(
       stax.PositionalEncoding(feature_depth, max_len=max_len),
-      stax.Dropout(dropout, mode=mode)
+      stax.Dropout(keep_rate, mode=mode)
   )
   if shared_embedding:
     assert source_vocab_size == target_vocab_size
@@ -202,12 +205,12 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name

   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )

@@ -231,11 +234,11 @@ def encoder(source, source_mask):
                                      stax.Identity,  # value
                                      source_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         source,
@@ -266,19 +269,19 @@ def decoder(memory, target, target_mask, memory_mask):
                                      stax.Identity,  # value
                                      target_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # target attends to encoded source
         stax.residual(stax.LayerNorm(feature_depth),
                       stax.multiplex(stax.Identity,  # query
                                      memory,  # key
                                      memory,  # value
                                      memory_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         target,
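A note on the convention this change works around: the stax layers called here take a *keep* probability (TensorFlow-style), while the public `dropout` argument is now the fraction of activations to *drop*, so each model constructor converts once with `keep_rate = 1.0 - dropout`. The sketch below illustrates that convention using the stax bundled with JAX (`jax.example_libraries.stax.Dropout`, whose `rate` argument is likewise a keep probability); treating it as a stand-in for this repo's extended `stax` module is an assumption based on the updated docstring, not something shown in the diff itself.

```python
# Minimal sketch of the keep-rate convention (assumes a recent JAX install).
import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # older JAX releases: jax.experimental.stax

dropout = 0.1               # user-facing rate: fraction of units to drop
keep_rate = 1.0 - dropout   # what stax.Dropout actually expects

init_fun, apply_fun = stax.Dropout(keep_rate, mode='train')
_, params = init_fun(jax.random.PRNGKey(0), (2, 8))  # Dropout has no params

x = jnp.ones((2, 8))
y = apply_fun(params, x, rng=jax.random.PRNGKey(1))
# Roughly 10% of entries are zeroed; survivors are scaled by 1 / keep_rate.
print(y)
```

Exposing `dropout` as a drop rate also lines the new default (`0.1`) up with the dropout value quoted in the Transformer paper, which is presumably why the defaults move away from `0.9` in the same commit.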