@@ -65,11 +65,12 @@ def n_outputs(self):
     """Specifies how many data tensors this layer promises as output."""
     return self._n_sections

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     rngs = _pop_rng_and_split(kwargs, len(inputs))
-    result = [self._layer(x, params=params, rng=r, **kwargs)
-              for x, r in zip(inputs, rngs)]
-    return tuple(result)
+    results = [self._layer(x, params=params, state=state, rng=r, **kwargs)
+               for x, r in zip(inputs, rngs)]
+    result_outputs, result_states = zip(*results)
+    return tuple(result_outputs), tuple(result_states)

   def new_parameters(self, input_shape, input_dtype, rng):
     first_shape = input_shape[0]
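Note: the pattern this commit applies throughout the file is that call() gains a state argument and returns an (outputs, state) pair, while new_parameters() returns (params, state) instead of bare params. A minimal sketch of a stateless layer written against that convention (the tl.Layer base class and hook names come from this file's context; the layer itself is purely illustrative):

class ScaleByTwo(tl.Layer):
  """Hypothetical stateless layer following the new (outputs, state) convention."""

  def call(self, x, params=(), state=(), **kwargs):
    del params, kwargs
    # A stateless layer simply threads the (empty) state through unchanged.
    return x * 2, state

  def new_parameters(self, input_shape, input_dtype, rng):
    # No trainable parameters and no state to initialize.
    return (), ()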
@@ -122,12 +123,13 @@ def __init__(self, n_sections=2, axis=-1):
     self._n_sections = n_sections
     self._axis = axis

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     del params, kwargs
-    return tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+    res = tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+    return res, state

   def new_parameters(self, input_shapes, input_dtype, rng):
-    return ()
+    return (), ()

   def n_inputs(self):
     """Specifies how many data tensors this layer expects as input."""
@@ -167,17 +169,17 @@ def n_outputs(self):
     return self._n_sections

   def new_parameters(self, input_shape, input_dtype, rng):
-    return ()
+    return (), ()

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     del params, kwargs
     x1, x2 = inputs

     x1_split = backend.numpy.split(x1, self._n_sections, self._axis)
     x2_split = backend.numpy.split(x2, self._n_sections, self._axis)

     res = [backend.numpy.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
-    return tuple(res)
+    return tuple(res), state

   def reverse(self, output, params=(), **kwargs):
     del params, kwargs
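For intuition on the split-and-interleave in this call(): x1 and x2 are each split into n_sections and concatenated pairwise along the last axis, which reverse() undoes. A plain-numpy round-trip sketch, assuming the default axis of -1 (numpy stands in for backend.numpy and the shapes are illustrative; the actual reverse() body lies outside this hunk):

import numpy as onp

n_sections = 2
x1 = onp.arange(8.).reshape(1, 2, 4)    # (batch, seq, d)
x2 = -onp.arange(8.).reshape(1, 2, 4)

x1_split = onp.split(x1, n_sections, -1)
x2_split = onp.split(x2, n_sections, -1)

# Interleave: each output section holds matching halves of x1 and x2.
res = [onp.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
assert [r.shape for r in res] == [(1, 2, 4), (1, 2, 4)]

# Inverting the interleave recovers the original tensors exactly.
halves = [onp.split(r, 2, -1) for r in res]
x1_rec = onp.concatenate([h[0] for h in halves], -1)
x2_rec = onp.concatenate([h[1] for h in halves], -1)
assert onp.array_equal(x1, x1_rec) and onp.array_equal(x2, x2_rec)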
@@ -288,7 +290,7 @@ def __init__(self, n_heads=1, d_head=64,
     # The lack of a bias term here is consistent with the tensor2tensor
     # implementation, and shouldn't have an effect on modeling quality.

-  def call(self, x, params, **kwargs):
+  def call(self, x, params, state, **kwargs):
     del kwargs
     seqlen = x.shape[1]
     res = np.dot(x, params)
@@ -300,13 +302,13 @@ def call(self, x, params, **kwargs):
     # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
     res = np.reshape(res, (-1, seqlen, self._d_head))

-    return res
+    return res, state

   def new_parameters(self, input_shape, input_dtype, rng):
     del input_dtype
     w = self._kernel_initializer(
         (input_shape[-1], self._n_heads * self._d_head), rng)
-    return w
+    return w, ()


 class ComputeAttentionOutput(tl.Layer):
@@ -321,7 +323,7 @@ def __init__(self, n_heads=1, d_model=1024,
     # The lack of a bias term here is consistent with the tensor2tensor
     # implementation, and shouldn't have an effect on modeling quality.

-  def call(self, x, params, **kwargs):
+  def call(self, x, params, state, **kwargs):
     del kwargs
     seqlen = x.shape[1]
     d_head = x.shape[2]
@@ -330,13 +332,13 @@ def call(self, x, params, **kwargs):
     x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
     x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))

-    return np.dot(x, params)
+    return np.dot(x, params), state

   def new_parameters(self, input_shape, input_dtype, rng):
     del input_dtype
     w = self._kernel_initializer(
         (input_shape[-1] * self._n_heads, self._d_model), rng)
-    return w
+    return w, ()


 class ApplyAttentionWrapper(tl.Parallel):
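As a sanity check on the reshapes in ComputeAttentionHeads and ComputeAttentionOutput: the former maps (n_batch, seqlen, n_heads * d_head) activations to (n_batch * n_heads, seqlen, d_head), and the latter maps that layout back before its output projection. A plain-numpy sketch of the round trip (the np.dot projections are omitted and the shapes are illustrative):

import numpy as onp

n_batch, seqlen, n_heads, d_head = 2, 5, 4, 8
x = onp.arange(n_batch * seqlen * n_heads * d_head, dtype=onp.float32)
x = x.reshape(n_batch, seqlen, n_heads * d_head)

# ComputeAttentionHeads-style split into per-head batches.
h = x.reshape(n_batch, seqlen, n_heads, d_head)
h = h.transpose(0, 2, 1, 3)        # -> n_batch, n_heads, seqlen, d_head
h = h.reshape(-1, seqlen, d_head)  # -> n_batch*n_heads, seqlen, d_head

# ComputeAttentionOutput-style merge back to the original layout.
y = h.reshape(n_batch, n_heads, seqlen, d_head)
y = y.transpose(0, 2, 1, 3)        # -> n_batch, seqlen, n_heads, d_head
y = y.reshape(-1, seqlen, n_heads * d_head)

assert onp.array_equal(x, y)  # the two layouts are exact inverses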
@@ -374,14 +376,14 @@ def __init__(self, dropout, mode):
     self._dropout = dropout
     self._mode = mode

-  def call(self, inputs, params=(), rng=None, **kwargs):
+  def call(self, inputs, params=(), state=(), rng=None, **kwargs):
     del params
     q, k, v = inputs
     mask_size = q.shape[-2]
     mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
     res = tl.DotProductAttention(
         q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
-    return res
+    return res, state

   def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
     # Simultaneous forward pass and backprop through the attention mechanism.
@@ -391,7 +393,7 @@ def do_call(x):
     return output, vjpfun(ct)[0]

   def new_parameters(self, input_shapes, input_dtype, rng):
-    return ()
+    return (), ()

   def n_inputs(self):
     return 3
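The forward_and_vjp methods above compute the attention output together with its vector-Jacobian product in one pass, which the memory-efficient variants below build on. A hedged sketch of that pattern with jax.vjp (do_call and the returned values mirror the hunk context; the attention_fn argument and function name are illustrative):

import jax

def forward_and_vjp_sketch(attention_fn, inputs, ct):
  # inputs is a (q, k, v) tuple; ct is the cotangent for the attention output.
  def do_call(x):
    return attention_fn(*x)
  output, vjpfun = jax.vjp(do_call, inputs)
  # vjpfun returns one cotangent per primal; here that is the (dq, dk, dv) tuple.
  return output, vjpfun(ct)[0]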
@@ -413,9 +415,9 @@ def __init__(self, loop_stride, dropout, mode):
     else:
       self.dropout = None

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-    return output
+    return output, state

   def forward_and_vjp(self, inputs, ct, params=(), rng=None, **kwargs):
     # This is the core of the memory-efficient attention implementation, where
@@ -547,9 +549,9 @@ def __init__(self, dropout, mode, n_bins=64):
     super(DummyHashedAttention, self).__init__(dropout, mode)
     self.n_bins = n_bins

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-    return output
+    return output, state

   def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
     del params, kwargs