This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5386fb1

Lukasz Kaiser authored and copybara-github committed
Small corrections.
PiperOrigin-RevId: 264694817
1 parent 67ca605 · commit 5386fb1

3 files changed: +28 −26 lines

tensor2tensor/trax/layers/base.py (+2 −2)

@@ -258,12 +258,12 @@ def __call__(self, x, params=(), state=(), **kwargs):
     # JAX.

     assert state is (), (  # pylint: disable=literal-comparison
-        'Custom gradients do not allow non-trivial start state.')
+        'Custom gradients require trivial start state. Got %s' % str(state))

     def check_end_state(output_state):
       output, state = output_state
       assert state is (), (  # pylint: disable=literal-comparison
-          'Custom gradients do not allow non-trivial end state.')
+          'Custom gradients require trivial end state. Got %s' % str(state))
       return output

     # See this link for how custom transformations are defined in JAX:
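The two message changes above do not alter the requirement itself: a layer run through the custom-gradient path must still carry the empty state (). A minimal, hypothetical reproduction of that check in plain Python (illustrative only, not code from the files in this commit):

def check_trivial_state(state, where='start'):
  # Same requirement as the asserts above: custom gradients only accept the
  # empty tuple as layer state; anything else trips the (new) error message.
  assert state is (), (  # pylint: disable=literal-comparison
      'Custom gradients require trivial %s state. Got %s' % (where, str(state)))

check_trivial_state(())              # passes: state is the empty tuple
# check_trivial_state({'step': 3})   # would fail with the new message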

tensor2tensor/trax/layers/combinators.py (+1 −1)

@@ -181,7 +181,7 @@ def call(self, xs, params=(), state=(), **kwargs):
       raise ValueError('number of params ({}) not equal to number of layers '
                        '({})'.format(len(params), n_layers))
     if n_layers != 1 and len(state) != n_layers:
-      raise ValueError('number of params ({}) not equal to number of layers '
+      raise ValueError('length of state ({}) not equal to number of layers '
                        '({})'.format(len(state), n_layers))
     for layer, p, s, rng in zip(self._sublayers, params, state, rngs):
       is_stack_just_one_item = (_count_items(stack) == 1)
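The corrected message matches the contract this check enforces: a serial combinator expects one state entry per sublayer and threads each through its layer. A minimal sketch of that threading with hypothetical stand-in layers (plain Python, not the real trax combinator):

def serial_call(layers, x, params, state):
  # Each sublayer consumes its own (params, state) slice and returns
  # (output, new_state), so the state sequence must match the layer count.
  if len(state) != len(layers):
    raise ValueError('length of state ({}) not equal to number of layers '
                     '({})'.format(len(state), len(layers)))
  new_states = []
  for layer, p, s in zip(layers, params, state):
    x, s = layer(x, p, s)
    new_states.append(s)
  return x, tuple(new_states)

# Two trivial stand-in "layers", each stateless.
double = lambda x, p, s: (2 * x, s)
add_one = lambda x, p, s: (x + 1, s)
out, new_state = serial_call([double, add_one], 3, [(), ()], [(), ()])
# out == 7, new_state == ((), ())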

tensor2tensor/trax/models/research/transformer_revnet.py (+25 −23)

@@ -65,11 +65,12 @@ def n_outputs(self):
     """Specifies how many data tensors this layer promises as output."""
     return self._n_sections

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     rngs = _pop_rng_and_split(kwargs, len(inputs))
-    result = [self._layer(x, params=params, rng=r, **kwargs)
-              for x, r in zip(inputs, rngs)]
-    return tuple(result)
+    results = [self._layer(x, params=params, state=state, rng=r, **kwargs)
+               for x, r in zip(inputs, rngs)]
+    result_outputs, result_states = zip(*results)
+    return tuple(result_outputs), tuple(result_states)

   def new_parameters(self, input_shape, input_dtype, rng):
     first_shape = input_shape[0]

@@ -122,12 +123,13 @@ def __init__(self, n_sections=2, axis=-1):
     self._n_sections = n_sections
     self._axis = axis

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     del params, kwargs
-    return tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+    res = tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+    return res, state

   def new_parameters(self, input_shapes, input_dtype, rng):
-    return ()
+    return (), ()

   def n_inputs(self):
     """Specifies how many data tensors this layer expects as input."""

@@ -167,17 +169,17 @@ def n_outputs(self):
     return self._n_sections

   def new_parameters(self, input_shape, input_dtype, rng):
-    return ()
+    return (), ()

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     del params, kwargs
     x1, x2 = inputs

     x1_split = backend.numpy.split(x1, self._n_sections, self._axis)
     x2_split = backend.numpy.split(x2, self._n_sections, self._axis)

     res = [backend.numpy.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
-    return tuple(res)
+    return tuple(res), state

   def reverse(self, output, params=(), **kwargs):
     del params, kwargs

@@ -288,7 +290,7 @@ def __init__(self, n_heads=1, d_head=64,
     # The lack of a bias term here is consistent with the tensor2tensor
     # implementation, and shouldn't have an effect on modeling quality.

-  def call(self, x, params, **kwargs):
+  def call(self, x, params, state, **kwargs):
     del kwargs
     seqlen = x.shape[1]
     res = np.dot(x, params)

@@ -300,13 +302,13 @@ def call(self, x, params, **kwargs):
     # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
     res = np.reshape(res, (-1, seqlen, self._d_head))

-    return res
+    return res, state

   def new_parameters(self, input_shape, input_dtype, rng):
     del input_dtype
     w = self._kernel_initializer(
         (input_shape[-1], self._n_heads * self._d_head), rng)
-    return w
+    return w, ()


 class ComputeAttentionOutput(tl.Layer):

@@ -321,7 +323,7 @@ def __init__(self, n_heads=1, d_model=1024,
     # The lack of a bias term here is consistent with the tensor2tensor
     # implementation, and shouldn't have an effect on modeling quality.

-  def call(self, x, params, **kwargs):
+  def call(self, x, params, state, **kwargs):
     del kwargs
     seqlen = x.shape[1]
     d_head = x.shape[2]

@@ -330,13 +332,13 @@ def call(self, x, params, **kwargs):
     x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
     x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))

-    return np.dot(x, params)
+    return np.dot(x, params), state

   def new_parameters(self, input_shape, input_dtype, rng):
     del input_dtype
     w = self._kernel_initializer(
         (input_shape[-1] * self._n_heads, self._d_model), rng)
-    return w
+    return w, ()


 class ApplyAttentionWrapper(tl.Parallel):

@@ -374,14 +376,14 @@ def __init__(self, dropout, mode):
     self._dropout = dropout
     self._mode = mode

-  def call(self, inputs, params=(), rng=None, **kwargs):
+  def call(self, inputs, params=(), state=(), rng=None, **kwargs):
     del params
     q, k, v = inputs
     mask_size = q.shape[-2]
     mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
     res = tl.DotProductAttention(
         q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
-    return res
+    return res, state

   def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
     # Simultaneous forward pass and backprop through the attention mechanism.

@@ -391,7 +393,7 @@ def do_call(x):
     return output, vjpfun(ct)[0]

   def new_parameters(self, input_shapes, input_dtype, rng):
-    return ()
+    return (), ()

   def n_inputs(self):
     return 3

@@ -413,9 +415,9 @@ def __init__(self, loop_stride, dropout, mode):
     else:
       self.dropout = None

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-    return output
+    return output, state

   def forward_and_vjp(self, inputs, ct, params=(), rng=None, **kwargs):
     # This is the core of the memory-efficient attention implementation, where

@@ -547,9 +549,9 @@ def __init__(self, dropout, mode, n_bins=64):
     super(DummyHashedAttention, self).__init__(dropout, mode)
     self.n_bins = n_bins

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-    return output
+    return output, state

   def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
     del params, kwargs
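Taken together, the edits in this file move every layer to one convention: call accepts an explicit state argument and returns an (output, state) pair, and new_parameters returns a (params, state) pair. A minimal sketch of a stateless layer written against that convention (hypothetical example; the tl alias and the Layer base-class details are assumed from the diff, not shown in it):

import numpy as onp

from tensor2tensor.trax import layers as tl  # assumed import, matching the tl alias above


class Scale(tl.Layer):
  """Hypothetical layer: multiplies its input by one learned scalar."""

  def call(self, x, params, state=(), **kwargs):
    del kwargs
    return x * params, state      # output is now paired with (unchanged) state

  def new_parameters(self, input_shape, input_dtype, rng):
    del input_shape, input_dtype, rng
    return onp.array(1.0), ()     # parameters plus (empty) initial state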
