
Commit 4bc6576

Upstream ReversibleEmbedding from KerasHub. (#21753)
* Upstream `ReversibleEmbedding` from KerasHub.
* Fix tests.
* Exclude test for openvino.
* Fix int4 quantization.
1 parent 2b8ddf3 commit 4bc6576

6 files changed: +537 additions, -0 deletions


keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,9 @@
 from keras.src.layers.core.input_layer import InputLayer as InputLayer
 from keras.src.layers.core.lambda_layer import Lambda as Lambda
 from keras.src.layers.core.masking import Masking as Masking
+from keras.src.layers.core.reversible_embedding import (
+    ReversibleEmbedding as ReversibleEmbedding,
+)
 from keras.src.layers.core.wrapper import Wrapper as Wrapper
 from keras.src.layers.input_spec import InputSpec as InputSpec
 from keras.src.layers.layer import Layer as Layer

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,9 @@
 from keras.src.layers.core.input_layer import InputLayer as InputLayer
 from keras.src.layers.core.lambda_layer import Lambda as Lambda
 from keras.src.layers.core.masking import Masking as Masking
+from keras.src.layers.core.reversible_embedding import (
+    ReversibleEmbedding as ReversibleEmbedding,
+)
 from keras.src.layers.core.wrapper import Wrapper as Wrapper
 from keras.src.layers.input_spec import InputSpec as InputSpec
 from keras.src.layers.layer import Layer as Layer

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ keras/src/layers/convolutional/separable_conv_test.py
 keras/src/layers/core/dense_test.py
 keras/src/layers/core/einsum_dense_test.py
 keras/src/layers/core/embedding_test.py
+keras/src/layers/core/reversible_embedding_test.py
 keras/src/layers/normalization/spectral_normalization_test.py
 keras/src/layers/normalization/unit_normalization_test.py
 keras/src/layers/pooling/average_pooling_test.py

keras/src/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 from keras.src.layers.core.input_layer import InputLayer
 from keras.src.layers.core.lambda_layer import Lambda
 from keras.src.layers.core.masking import Masking
+from keras.src.layers.core.reversible_embedding import ReversibleEmbedding
 from keras.src.layers.core.wrapper import Wrapper
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
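
With the three `__init__.py` changes above, the layer is exported as `keras.layers.ReversibleEmbedding`. The snippet below is a minimal usage sketch adapted from the new layer's docstring; shapes and variable names are illustrative only:

```python
import numpy as np
import keras

vocab_size, hidden_dim = 100, 32
token_ids = np.random.randint(vocab_size, size=(16, 50))

embedding = keras.layers.ReversibleEmbedding(vocab_size, hidden_dim)
# Forward pass: token ids (16, 50) -> embeddings (16, 50, 32).
hidden_states = embedding(token_ids)
# Reverse pass: embeddings (16, 50, 32) -> logits (16, 50, 100),
# using the transposed (tied) embedding matrix.
logits = embedding(hidden_states, reverse=True)
```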
keras/src/layers/core/reversible_embedding.py

Lines changed: 349 additions & 0 deletions
@@ -0,0 +1,349 @@
import copy

from keras.src import dtype_policies
from keras.src import layers
from keras.src import ops
from keras.src import quantizers
from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor


@keras_export("keras.layers.ReversibleEmbedding")
class ReversibleEmbedding(layers.Embedding):
    """An embedding layer which can project backwards to the input dim.

    This layer is an extension of `keras.layers.Embedding` for language models.
    This layer can be called "in reverse" with `reverse=True`, in which case the
    layer will linearly project from `output_dim` back to `input_dim`.

    By default, the reverse projection will use the transpose of the
    `embeddings` weights to project to `input_dim` (weights are "tied"). If
    `tie_weights=False`, the model will use a separate, trainable variable for
    reverse projection.

    This layer has no bias terms.

    Args:
        input_dim: Integer. Size of the vocabulary,
            i.e. maximum integer index + 1.
        output_dim: Integer. Dimension of the dense embedding.
        tie_weights: Boolean, whether or not the matrix for embedding and
            the matrix for the `reverse` projection should share the same
            weights.
        embeddings_initializer: Initializer for the `embeddings`
            matrix (see `keras.initializers`).
        embeddings_regularizer: Regularizer function applied to
            the `embeddings` matrix (see `keras.regularizers`).
        embeddings_constraint: Constraint function applied to
            the `embeddings` matrix (see `keras.constraints`).
        mask_zero: Boolean, whether or not the input value 0 is a special
            "padding" value that should be masked out.
        reverse_dtype: The dtype for the reverse projection computation.
            Defaults to the `compute_dtype` of the layer.
        logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
            output logits will be scaled by
            `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
            range of output logits and can improve training.
        **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        inputs: The tensor inputs to the layer.
        reverse: Boolean. If `True` the layer will perform a linear projection
            from `output_dim` to `input_dim`, instead of a normal embedding
            call. Defaults to `False`.

    Example:
    ```python
    batch_size = 16
    vocab_size = 100
    hidden_dim = 32
    seq_length = 50

    # Generate random inputs.
    token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))

    embedding = keras.layers.ReversibleEmbedding(vocab_size, hidden_dim)
    # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
    hidden_states = embedding(token_ids)
    # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
    logits = embedding(hidden_states, reverse=True)
    ```

    References:
    - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
    - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        tie_weights=True,
        embeddings_initializer="uniform",
        embeddings_regularizer=None,
        embeddings_constraint=None,
        mask_zero=False,
        reverse_dtype=None,
        logit_soft_cap=None,
        **kwargs,
    ):
        super().__init__(
            input_dim,
            output_dim,
            embeddings_initializer=embeddings_initializer,
            embeddings_regularizer=embeddings_regularizer,
            embeddings_constraint=embeddings_constraint,
            mask_zero=mask_zero,
            **kwargs,
        )
        self.tie_weights = tie_weights
        self.reverse_dtype = reverse_dtype
        self.logit_soft_cap = logit_soft_cap

    def build(self, inputs_shape=None):
        super().build(inputs_shape)
        if not self.tie_weights and self.quantization_mode not in (
            "int8",
            "int4",
        ):
            self.reverse_embeddings = self.add_weight(
                shape=(self.output_dim, self.input_dim),
                initializer=self.embeddings_initializer,
                name="reverse_embeddings",
                trainable=True,
            )

    def call(self, inputs, reverse=False):
        if not reverse:
            return super().call(inputs)
        else:
            if self.tie_weights:
                kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
            else:
                kernel = self.reverse_embeddings
            if self.reverse_dtype is not None:
                inputs = ops.cast(inputs, self.reverse_dtype)
                kernel = ops.cast(kernel, self.reverse_dtype)
            logits = ops.matmul(inputs, kernel)
            # Optionally soft-cap logits.
            if self.logit_soft_cap is not None:
                soft_cap = self.logit_soft_cap
                logits = ops.multiply(
                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
                )
            return logits

    def compute_output_shape(self, input_shape, reverse=False):
        output_shape = list(input_shape)
        if reverse:
            output_shape[-1] = self.input_dim
        else:
            output_shape += [self.output_dim]
        return output_shape

    def compute_output_spec(self, inputs, reverse=False):
        output_shape = list(inputs.shape)
        if reverse:
            output_shape[-1] = self.input_dim
        else:
            output_shape += [self.output_dim]
        return KerasTensor(output_shape, dtype=self.compute_dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "tie_weights": self.tie_weights,
                "reverse_dtype": self.reverse_dtype,
                "logit_soft_cap": self.logit_soft_cap,
            }
        )
        return config

    @property
    def variable_serialization_spec(self):
        # Avoid modifying the parent's spec.
        _spec = copy.deepcopy(super().variable_serialization_spec)
        if not self.tie_weights:
            for mode, variable_spec in _spec.items():
                variable_spec.append("reverse_embeddings")
                if mode in ("int4", "int8"):
                    variable_spec.append("reverse_embeddings_scale")
        return _spec

    def quantized_build(self, embeddings_shape, mode):
        if mode == "int8":
            self._int8_build(embeddings_shape)
        elif mode == "int4":
            self._int4_build(embeddings_shape)
        else:
            raise self._quantization_mode_error(mode)
        self._is_quantized = True

    def _int8_build(self, embeddings_shape):
        if embeddings_shape is None:
            embeddings_shape = (self.input_dim, self.output_dim)
        super()._int8_build(embeddings_shape=embeddings_shape)
        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
        if not self.tie_weights:
            self.reverse_embeddings = self.add_weight(
                name="reverse_embeddings",
                shape=(self.output_dim, self.input_dim),
                initializer="zeros",
                dtype="int8",
                trainable=False,
            )
            self.reverse_embeddings_scale = self.add_weight(
                name="reverse_embeddings_scale",
                shape=(self.input_dim,),
                initializer="ones",
                trainable=False,
            )

    def _int4_build(self, embeddings_shape):
        if embeddings_shape is None:
            embeddings_shape = (self.input_dim, self.output_dim)
        super()._int4_build(embeddings_shape=embeddings_shape)
        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
        if not self.tie_weights:
            packed_rows = (self.output_dim + 1) // 2  # ceil for odd dims
            self.reverse_embeddings = self.add_weight(
                name="reverse_embeddings",
                shape=(packed_rows, self.input_dim),
                initializer="zeros",
                dtype="int8",
                trainable=False,
            )
            self.reverse_embeddings_scale = self.add_weight(
                name="reverse_embeddings_scale",
                shape=(self.input_dim,),
                initializer="ones",
                trainable=False,
            )

    def _int8_call(self, inputs, reverse=False):
        if not reverse:
            return super()._int8_call(inputs)
        else:
            if self.tie_weights:
                kernel = ops.transpose(self._embeddings)
                scale = ops.transpose(self.embeddings_scale)
            else:
                kernel = self.reverse_embeddings
                scale = self.reverse_embeddings_scale
            inputs, inputs_scale = self.inputs_quantizer(inputs)
            logits = ops.matmul(inputs, kernel)
            # De-scale outputs
            logits = ops.cast(logits, self.compute_dtype)
            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
            # Optionally soft-cap logits.
            if self.logit_soft_cap is not None:
                soft_cap = self.logit_soft_cap
                logits = ops.multiply(
                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
                )
            return logits

    def _int4_call(self, inputs, reverse=False):
        if not reverse:
            return super()._int4_call(inputs)
        else:
            if self.tie_weights:
                embeddings = ops.transpose(self._embeddings)
                scale = ops.transpose(self.embeddings_scale)
            else:
                embeddings = self.reverse_embeddings
                scale = self.reverse_embeddings_scale
            unpacked_embeddings = quantizers.unpack_int4(
                embeddings, self.output_dim, axis=0
            )
            inputs, inputs_scale = self.inputs_quantizer(inputs)
            logits = ops.matmul(inputs, unpacked_embeddings)
            # De-scale outputs
            logits = ops.cast(logits, self.compute_dtype)
            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
            # Optionally soft-cap logits.
            if self.logit_soft_cap is not None:
                soft_cap = self.logit_soft_cap
                logits = ops.multiply(
                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
                )
            return logits

    def quantize(self, mode, type_check=True, config=None):
        del config
        if type_check and type(self) is not ReversibleEmbedding:
            raise self._not_implemented_error(self.quantize)

        embeddings_shape = (self.input_dim, self.output_dim)
        if mode == "int8":
            # Quantize `self._embeddings` to int8 and compute corresponding
            # scale.
            embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
                self._embeddings, axis=-1, to_numpy=True
            )
            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
            del self._embeddings
            if not self.tie_weights:
                reverse_embeddings_value, reverse_embeddings_scale = (
                    quantizers.abs_max_quantize(
                        self.reverse_embeddings, axis=0, to_numpy=True
                    )
                )
                reverse_embeddings_scale = ops.squeeze(
                    reverse_embeddings_scale, axis=0
                )
                del self.reverse_embeddings
            self.quantized_build(embeddings_shape, mode)
            self._embeddings.assign(embeddings_value)
            self.embeddings_scale.assign(embeddings_scale)
            if not self.tie_weights:
                self.reverse_embeddings.assign(reverse_embeddings_value)
                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
        elif mode == "int4":
            # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
            embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
                self._embeddings,
                axis=-1,
                value_range=(-8, 7),
                dtype="int8",
                to_numpy=True,
            )
            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
            # Pack two int4 values into a single int8 byte.
            packed_embeddings_value, _, _ = quantizers.pack_int4(
                embeddings_value, axis=-1
            )
            del self._embeddings
            if not self.tie_weights:
                reverse_embeddings_value, reverse_embeddings_scale = (
                    quantizers.abs_max_quantize(
                        self.reverse_embeddings,
                        axis=0,
                        value_range=(-8, 7),
                        dtype="int8",
                        to_numpy=True,
                    )
                )
                reverse_embeddings_scale = ops.squeeze(
                    reverse_embeddings_scale, axis=0
                )
                # Pack two int4 values into a single int8 byte.
                packed_reverse_embeddings_value, _, _ = quantizers.pack_int4(
                    reverse_embeddings_value, axis=0
                )
                del self.reverse_embeddings
            self.quantized_build(embeddings_shape, mode)
            self._embeddings.assign(packed_embeddings_value)
            self.embeddings_scale.assign(embeddings_scale)
            if not self.tie_weights:
                self.reverse_embeddings.assign(packed_reverse_embeddings_value)
                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
        else:
            raise self._quantization_mode_error(mode)

        # Set new dtype policy.
        if self.dtype_policy.quantization_mode is None:
            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
            self.dtype_policy = policy
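
Since the commit message calls out a fix for int4 quantization, here is a hedged sketch of how the quantization path in this file would be exercised. It assumes the layer is built by a forward call before `quantize()` is invoked, and that quantized calls route the `reverse` flag through `_int8_call`/`_int4_call` as their signatures above suggest; all shapes and hyperparameters below are illustrative, not taken from this diff:

```python
import numpy as np
import keras

layer = keras.layers.ReversibleEmbedding(
    input_dim=100,
    output_dim=64,
    tie_weights=False,    # adds a separate trainable `reverse_embeddings`
    logit_soft_cap=30.0,  # reverse logits become tanh(logits / 30) * 30
)
token_ids = np.random.randint(100, size=(4, 12))
hidden = layer(token_ids)  # builds the layer; shape (4, 12, 64)

# Post-training quantization; "int8" follows the same code path.
layer.quantize("int4")

# The reverse projection now runs through `_int4_call`, unpacking the
# packed int4 `reverse_embeddings` and de-scaling the matmul output.
logits = layer(hidden, reverse=True)  # shape (4, 12, 100)
```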
