Commit

internal changes
PiperOrigin-RevId: 731308975
The gemma Authors committed Feb 26, 2025
1 parent 6b347f2 commit 79671dc
Showing 4 changed files with 129 additions and 55 deletions.
7 changes: 5 additions & 2 deletions examples/sampling.py
@@ -54,7 +54,7 @@
)
_STRING_TO_SAMPLE = flags.DEFINE_string(
"string_to_sample",
"Where is Paris ?",
"Where is Paris?",
help="Input string to sample.",
)

@@ -71,7 +71,9 @@ def _load_and_sample(
) -> None:
"""Loads and samples a string from a checkpoint."""
print(f"Loading the parameters from {path_checkpoint}")
parameters = params_lib.load_and_format_params(path_checkpoint)
parameters = params_lib.load_and_format_params(
path_checkpoint,
)
print("Parameters loaded.")
# Create a sampler with the right param shapes.
vocab = spm.SentencePieceProcessor()
@@ -85,6 +87,7 @@ def _load_and_sample(
transformer=transformer,
vocab=vocab,
params=parameters["transformer"],
cache_length=_CACHE_SIZE,
)
sampled_str = sampler(
input_strings=[input_string],
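For orientation, a minimal sketch of how the updated example wires the pieces together now that the cache length is passed to the Sampler rather than to the transformer config. The paths, the 1024 cache size, and the TransformerConfig.from_params helper are placeholders and assumptions for illustration, not the flags or exact construction used by the real script.

import sentencepiece as spm

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib

# Placeholder paths; the real script reads these from command-line flags.
CKPT_PATH = '/tmp/gemma/ckpt'
VOCAB_PATH = '/tmp/gemma/tokenizer.model'

parameters = params_lib.load_and_format_params(CKPT_PATH)
vocab = spm.SentencePieceProcessor()
vocab.Load(VOCAB_PATH)

# Assumed helper for building the model from the checkpoint params; the exact
# construction may differ between gemma versions.
config = transformer_lib.TransformerConfig.from_params(parameters)
transformer = transformer_lib.Transformer(config=config)

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=parameters['transformer'],
    cache_length=1024,  # replaces the deprecated TransformerConfig.max_cache_length
)
print(sampler(input_strings=['Where is Paris?'], total_generation_steps=32).text)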
94 changes: 49 additions & 45 deletions gemma/sampler.py
@@ -30,32 +30,7 @@

import sentencepiece as spm


def _compute_attention_masks(
time_step: jax.Array, seq_len: int, input_mask: jax.Array
) -> jax.Array:
"""Computes causal attention mask."""
bsz = input_mask.shape[0]
batch_time_step = jnp.full((bsz, 1), time_step, dtype=jnp.uint32)
causal_mask = jnp.less_equal(
jnp.expand_dims(jnp.arange(seq_len), 0), batch_time_step
)
max_seq_len = min(input_mask.shape[-1], seq_len)
input_mask = jax.lax.dynamic_slice(
input_mask,
(0, jnp.maximum(time_step - seq_len + 1, 0)),
(bsz, max_seq_len),
)
input_mask = (
jnp.ones((bsz, seq_len), dtype=jnp.bool_)
.at[:, :max_seq_len]
.set(input_mask)
)

causal_mask = jnp.logical_and(causal_mask, input_mask)
attention_mask = causal_mask[:, jnp.newaxis, :].astype(jnp.bool_)

return attention_mask
_compute_attention_masks = transformer_lib.compute_attention_masks


@chex.dataclass
@@ -129,7 +104,8 @@ def __init__(
if cache_length is None:
warnings.warn(
'TransformerConfig.max_cache_length is deprecated and will be'
' REMOVED!!! Instead, set the `cache_length` in the `Sampler` class.',
' REMOVED!!! Instead, set the `cache_length` in the `Sampler`'
' class.',
DeprecationWarning,
stacklevel=2,
)
@@ -146,17 +122,16 @@ def _sample_step(
) -> _SamplingState:
"""Performs a single sampling step."""
batch_size = sampler_state.token_buffer.shape[0]
input_mask = sampler_state.token_buffer != self.vocab.pad_id()
decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32)
last_token = sampler_state.token_buffer[:, decoding_step]
input_mask = sampler_state.token_buffer != self.vocab.pad_id()
attention_mask = _compute_attention_masks(
attention_mask = transformer_lib.compute_attention_masks(
decoding_step, self.cache_length, input_mask
)
step_positions = jnp.expand_dims(
sampler_state.positions[:, decoding_step], -1
)
last_token = last_token.reshape((batch_size, 1))

logits, cache = self.transformer.apply(
{'params': params},
last_token,
@@ -225,12 +200,7 @@ def init_sample_state(
buffer_size = total_sampling_steps + 1

token_buffer = jnp.full(
(
bsz,
buffer_size,
),
self.vocab.pad_id(),
dtype=jnp.int32,
(bsz, buffer_size), self.vocab.pad_id(), dtype=jnp.int32
)
input_mask = jnp.ones_like(token_buffer, dtype=jnp.bool_)
for i, (input_ids, num_tokens) in enumerate(
@@ -240,25 +210,61 @@
input_mask = input_mask.at[i, :num_tokens].set(
input_ids != self.vocab.pad_id()
)

positions = transformer_lib.build_positions_from_mask(input_mask)

done = jnp.zeros((bsz,), dtype=jnp.bool_)
num_input_tokens = jnp.array(num_input_tokens, dtype=jnp.int32)

input_mask = token_buffer != self.vocab.pad_id()
decoding_step = num_input_tokens[0]
logits, cache = self.transformer.apply(
{'params': self.params},
jax.lax.dynamic_slice(token_buffer, (0, 0), (bsz, decoding_step)),
positions[:, :decoding_step],
self.init_cache(bsz),
transformer_lib.compute_sequence_attention_mask(
time_step=decoding_step,
seq_len=self.cache_length,
input_mask=input_mask,
),
)
if include_logits:
logits_buffer = jnp.zeros(
(bsz, buffer_size, self.transformer.config.num_embed),
dtype=jnp.float32,
logits_buffer = jnp.concatenate(
[
jnp.zeros(
(bsz, 1, self.transformer.config.num_embed),
dtype=jnp.float32,
),
logits,
jnp.zeros(
(
bsz,
buffer_size - decoding_step - 1,
self.transformer.config.num_embed,
),
dtype=jnp.float32,
),
],
axis=1,
)
else:
logits_buffer = None

# We decoded one token here:
if forbidden_token_ids:
logits = logits.at[:, :, forbidden_token_ids].set(-jnp.inf)
next_token_candidate = jnp.argmax(logits, axis=-1)
next_token_candidate = next_token_candidate[:, -1]
token_buffer = token_buffer.at[:, decoding_step].set(next_token_candidate)

return _SamplingState(
decoding_step=0,
num_input_tokens=jnp.array(num_input_tokens, dtype=jnp.int32),
decoding_step=decoding_step,
num_input_tokens=num_input_tokens,
token_buffer=token_buffer,
positions=positions,
logits_buffer=logits_buffer,
cache=self.init_cache(bsz),
cache=cache,
done=done,
total_sampling_steps=total_sampling_steps,
forbidden_token_ids=forbidden_token_ids,
Expand All @@ -284,7 +290,7 @@ def mask_tokens_after_eos_ids(self, token_buffer):
mask = jnp.less_equal(
jnp.arange(token_buffer.shape[-1]), eos_indices[:, None]
)
masked_token_buffer = token_buffer * mask + self.vocab.pad_id()*(1 - mask)
masked_token_buffer = token_buffer * mask + self.vocab.pad_id() * (1 - mask)

return masked_token_buffer

@@ -373,9 +379,7 @@ def __call__(
out_logits.append(
logits_buffer[start_idx:total_sampling_steps].tolist()
)

decoded_outputs = [self.vocab.DecodeIds(tokens) for tokens in out_tokens]

result = SamplerOutput(
text=decoded_outputs,
logits=out_logits,
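The substantive change in init_sample_state above is that the prompt is now prefilled in a single transformer.apply call, the cache produced by that call is reused, and the first output token is chosen greedily from the logits at the last prompt position before the decode loop starts (decoding_step begins at num_input_tokens[0] instead of 0). Below is a small self-contained illustration of that final greedy-pick step; the fake logits stand in for the model output, and nothing here is the actual Sampler API.

import jax.numpy as jnp

bsz, prompt_len, vocab_size = 2, 3, 5
# Stand-in for the logits returned by the prefill call over the whole prompt.
fake_logits = jnp.arange(bsz * prompt_len * vocab_size, dtype=jnp.float32)
fake_logits = fake_logits.reshape(bsz, prompt_len, vocab_size)

forbidden_token_ids = [4]  # token ids that must never be sampled
fake_logits = fake_logits.at[:, :, forbidden_token_ids].set(-jnp.inf)

# As in init_sample_state: greedy pick at the last prompt position, which is
# then written into the token buffer at position decoding_step.
next_token = jnp.argmax(fake_logits, axis=-1)[:, -1]
print(next_token)  # [3 3], because token id 4 is forbidden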
20 changes: 12 additions & 8 deletions gemma/sampler_test.py
@@ -107,7 +107,10 @@ def test_samples(self):
params=params['params'],
)

result = sampler(['input string', 'hello world'], total_generation_steps=10)
result = sampler(
['input string', 'hello world'],
total_generation_steps=10,
)
self.assertIsNotNone(result)

def test_forbidden_tokens(self):
@@ -220,11 +223,13 @@ def test_forward_equivalence(self):
)

output_transformer = sampler(
[raw_input], total_generation_steps=10, echo=True
[raw_input],
total_generation_steps=10,
echo=True,
)
out_logits = np.array(output_transformer.logits)[0, 1 : n_input_tokens + 1]

np.testing.assert_almost_equal(output_forward, out_logits)
np.testing.assert_almost_equal(output_forward, out_logits, decimal=2)

def test_sampler_init_sample_state(self):
vocab = MockVocab()
@@ -327,8 +332,8 @@ def test_compute_attention_mask(self):
time_step, seq_len, input_mask
)
expected_attn_mask = jnp.array(
[[0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_)
[[0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_
)

self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())

@@ -339,10 +344,9 @@ def test_compute_attention_mask(self):
attn_mask = sampler_lib._compute_attention_masks(
time_step, seq_len, input_mask
)
print(attn_mask)
expected_attn_mask = jnp.array(
[[0, 1, 1, 1],
[0, 1, 0, 1]], dtype=jnp.bool_)
[[0, 1, 1, 1], [0, 1, 0, 1]], dtype=jnp.bool_
)

self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())

63 changes: 63 additions & 0 deletions gemma/transformer.py
@@ -349,6 +349,69 @@ def __call__(
return logits, cache # pytype: disable=bad-return-type


def compute_sequence_attention_mask(
time_step: jax.Array,
*,
seq_len: int,
input_mask: jax.Array,
bi_directional_mask: jax.Array | None = None,
) -> jax.Array:
"""Computes sequence attention mask."""
bsz = input_mask.shape[0]
attention_mask = jnp.tile(
jnp.expand_dims(jnp.tri(N=int(time_step), M=int(seq_len)), axis=0),
(bsz, 1, 1),
)
if bi_directional_mask is not None:
bi_directional_mask = jnp.expand_dims(
jnp.concatenate([
bi_directional_mask[0],
jnp.zeros((seq_len - len(bi_directional_mask))),
]),
axis=0,
)
bi_directional_mask = jnp.tile(
jnp.expand_dims(
jnp.outer(bi_directional_mask, bi_directional_mask)[
:time_step, :seq_len
],
axis=0,
),
(bsz, 1, 1),
).astype(jnp.bool_)
attention_mask = jnp.logical_or(attention_mask, bi_directional_mask).astype(
jnp.bool_
)
return attention_mask


def compute_attention_masks(
time_step: jax.Array, seq_len: int, input_mask: jax.Array
) -> jax.Array:
"""Computes causal attention mask."""
bsz = input_mask.shape[0]
batch_time_step = jnp.full((bsz, 1), time_step, dtype=jnp.uint32)
causal_mask = jnp.less_equal(
jnp.expand_dims(jnp.arange(seq_len), 0), batch_time_step
)
max_seq_len = min(input_mask.shape[-1], seq_len)
input_mask = jax.lax.dynamic_slice(
input_mask,
(0, jnp.maximum(time_step - seq_len + 1, 0)),
(bsz, max_seq_len),
)
input_mask = (
jnp.ones((bsz, seq_len), dtype=jnp.bool_)
.at[:, :max_seq_len]
.set(input_mask)
)

causal_mask = jnp.logical_and(causal_mask, input_mask)
attention_mask = causal_mask[:, jnp.newaxis, :].astype(jnp.bool_)

return attention_mask


def make_causal_attn_mask(
input_mask: jax.Array, # Shape [B, L]
) -> jax.Array:
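Since compute_attention_masks now lives in gemma/transformer.py (sampler.py keeps only an alias to it), here is a small self-contained repro of what it returns for a single decode step. The function body is copied from the diff above; the toy inputs are made up for illustration.

import jax
import jax.numpy as jnp


def compute_attention_masks(
    time_step: jax.Array, seq_len: int, input_mask: jax.Array
) -> jax.Array:
  """Computes causal attention mask (body copied from the diff above)."""
  bsz = input_mask.shape[0]
  batch_time_step = jnp.full((bsz, 1), time_step, dtype=jnp.uint32)
  causal_mask = jnp.less_equal(
      jnp.expand_dims(jnp.arange(seq_len), 0), batch_time_step
  )
  max_seq_len = min(input_mask.shape[-1], seq_len)
  input_mask = jax.lax.dynamic_slice(
      input_mask,
      (0, jnp.maximum(time_step - seq_len + 1, 0)),
      (bsz, max_seq_len),
  )
  input_mask = (
      jnp.ones((bsz, seq_len), dtype=jnp.bool_)
      .at[:, :max_seq_len]
      .set(input_mask)
  )
  causal_mask = jnp.logical_and(causal_mask, input_mask)
  attention_mask = causal_mask[:, jnp.newaxis, :].astype(jnp.bool_)
  return attention_mask


# Toy example: batch of one, cache length 4, first token is padding, decoding
# at time step 2. Positions after the current step and padded positions are
# masked out.
input_mask = jnp.array([[0, 1, 1, 1]], dtype=jnp.bool_)
mask = compute_attention_masks(time_step=2, seq_len=4, input_mask=input_mask)
print(mask)  # shape (1, 1, 4) -> [[[False  True  True False]]]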
