Skip to content

Commit

Permalink
Allow to customize the sampling method
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731608909
  • Loading branch information
Conchylicultor authored and The gemma Authors committed Feb 27, 2025
1 parent 79671dc commit 6878297
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 14 deletions.
13 changes: 8 additions & 5 deletions colabs/sampling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,16 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vCmNJUbHBGvK"
"id": "uOQHd6eFd67c"
},
"outputs": [],
"source": []
"cell_type": "markdown",
"source": [
"By default, greedy decoding is used. To use a different method, pass it through the `sampling=` keyword argument:\n",
"\n",
"* `gm.text.Greedy()`: (default) Greedy decoding\n",
"* `gm.text.RandomSampling()`: Simple random sampling"
]
}
],
"metadata": {
Expand Down
6 changes: 6 additions & 0 deletions gemma/gm/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@
from gemma.gm.text._tokenizer import Gemma2Tokenizer
from gemma.gm.text._tokenizer import Tokenizer
from gemma.gm.text._sampler import Sampler

# Sampling methods
# TODO(epot): Add `TopK`,...
from gemma.gm.text._sampling import SamplingMethod
from gemma.gm.text._sampling import Greedy
from gemma.gm.text._sampling import RandomSampling
25 changes: 23 additions & 2 deletions gemma/gm/text/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
from collections.abc import Sequence
import dataclasses
import functools
import random as py_random
import typing

from gemma import params as params_lib
from gemma.gm.nn import _transformer
from gemma.gm.text import _sampler_impl
from gemma.gm.text import _sampling
from gemma.gm.text import _tokenizer
import numpy as np

Expand Down Expand Up @@ -51,12 +53,19 @@ class Sampler:
model: Gemma transformer model.
params: Model parameters.
tokenizer: Tokenizer.
cache_length: Size of the attention cache.
sampling: Sampling method to use. Defaults to greedy decoding.
seed: Seed to use for sampling.
"""

model: _transformer.Transformer
params: params_lib.Params
tokenizer: _tokenizer.Tokenizer
sampling: _sampling.SamplingMethod = dataclasses.field(
default_factory=_sampling.Greedy
)
seed: int = dataclasses.field(
default_factory=lambda: py_random.randint(0, 1000000000)
)

# TODO(epot): Add a `max_length` argument to the `sample()` method.
@typing.overload
Expand All @@ -65,6 +74,7 @@ def sample(
prompt: str,
*,
max_new_tokens: int = ...,
sampling: _sampling.SamplingMethod = ...,
) -> str:
...

Expand All @@ -74,6 +84,7 @@ def sample(
prompt: Sequence[str],
*,
max_new_tokens: int = ...,
sampling: _sampling.SamplingMethod = ...,
) -> list[str]:
...

Expand All @@ -82,6 +93,7 @@ def sample(
prompt,
*,
max_new_tokens=200,
sampling: _sampling.SamplingMethod | None = None,
):
"""Samples a string from the model.
Expand All @@ -90,10 +102,14 @@ def sample(
strings.
max_new_tokens: Maximum number of new tokens to generate. The transformer
will process `input_length + max_new_tokens`.
sampling: Sampling method to use. If given, will override the default
sampling method.
Returns:
The sampled output.
"""
sampling = sampling or self.sampling

if _is_str_array(prompt): # Supports batched input array
assert isinstance(prompt, np.ndarray)
prompt = prompt.tolist()
Expand All @@ -104,7 +120,11 @@ def sample(
else:
is_single_prompt = False

output = self._sampler(prompt, total_generation_steps=max_new_tokens)
output = self._sampler(
prompt,
sampling=sampling,
total_generation_steps=max_new_tokens,
)
output = output.text

if is_single_prompt:
Expand All @@ -123,6 +143,7 @@ def _sampler(self) -> _sampler_impl.Sampler:
# `max_length=` in `def sample()` ? No need to allocate extra memory
# when the sequence is smaller.
cache_length=1024,
seed=self.seed,
)


Expand Down
39 changes: 32 additions & 7 deletions gemma/gm/text/_sampler_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from gemma import params as params_lib
from gemma import transformer as transformer_lib
from gemma.gm.nn import _transformer
from gemma.gm.text import _sampling
from gemma.gm.text import _tokenizer
import jax
import jax.numpy as jnp
from kauldron import random


@flax.struct.dataclass
@flax.struct.dataclass(kw_only=True)
class _SamplingState:
"""Internal sampling state."""

Expand Down Expand Up @@ -61,6 +63,8 @@ class _SamplingState:
# List of tokens that are forbidden to be generated.
forbidden_token_ids: Sequence[int] | None = None

rng: random.PRNGKey


@dataclasses.dataclass
class SamplerOutput:
Expand All @@ -86,6 +90,7 @@ def __init__(
tokenizer: _tokenizer.Tokenizer,
params: params_lib.Params,
cache_length: int = 1024,
seed: int,
):
"""Initializes a sampler for a Gemma model.
Expand All @@ -94,19 +99,27 @@ def __init__(
tokenizer: tokenizer of the given model.
params: weights of the model.
cache_length: Max length of the cache.
seed: Random seed for the sampler.
"""
self.transformer = transformer
self.tokenizer = tokenizer
self.params = params
self.cache_length = cache_length
self._compiled_sample_fn = jax.jit(self._sample_fn)
self._compiled_sample_fn = jax.jit(
self._sample_fn, static_argnames=('sampling',)
)
self._rng = random.PRNGKey(seed)

@property
def dtype(self) -> jnp.dtype:
return jax.tree_util.tree_leaves(self.params)[0].dtype

def _sample_step(
self, params, sampler_state: _SamplingState
self,
params,
sampler_state: _SamplingState,
*,
sampling: _sampling.SamplingMethod,
) -> _SamplingState:
"""Performs a single sampling step."""
batch_size = sampler_state.token_buffer.shape[0]
Expand All @@ -133,8 +146,10 @@ def _sample_step(
if sampler_state.forbidden_token_ids:
logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)

next_token_candidate = jnp.argmax(logits, axis=-1) # [B, 1]
next_token_candidate = next_token_candidate[:, 0] # [B,]
# Logits are `B L V` with `L=1`
next_rng, curr_rng = sampler_state.rng.split()
next_token_candidate = sampling.get_next_tokens(logits, rng=curr_rng)
next_token_candidate = next_token_candidate[:, 0]

next_token_candidate = jnp.where(
decoding_step < sampler_state.num_input_tokens - 1,
Expand Down Expand Up @@ -170,6 +185,7 @@ def _sample_step(
done=done,
total_sampling_steps=sampler_state.total_sampling_steps,
forbidden_token_ids=sampler_state.forbidden_token_ids,
rng=next_rng,
)

def init_cache(self, bsz) -> dict[str, modules.LayerCache]:
Expand Down Expand Up @@ -230,6 +246,7 @@ def init_sample_state(
done=done,
total_sampling_steps=total_sampling_steps,
forbidden_token_ids=forbidden_token_ids,
rng=self._rng,
)

def tokenize(self, input_string: str) -> jax.Array:
Expand Down Expand Up @@ -263,11 +280,13 @@ def _sample_fn(
self,
params: params_lib.Params,
initial_sampling_state: _SamplingState,
*,
sampling: _sampling.SamplingMethod,
) -> _SamplingState:
"""Internal sampling function (to be jitted)."""

def sample_with_params(sampler_state: _SamplingState):
return self._sample_step(params, sampler_state)
return self._sample_step(params, sampler_state, sampling=sampling)

def cond_fn(sampler_state: _SamplingState):
return (
Expand All @@ -281,10 +300,12 @@ def cond_fn(sampler_state: _SamplingState):
def __call__(
self,
input_strings: Sequence[str],
*,
total_generation_steps: int,
echo: bool = False,
return_logits: bool = False,
forbidden_tokens: Sequence[str] | None = None,
sampling: _sampling.SamplingMethod,
) -> SamplerOutput:
"""Samples a completion of the input string.
Expand All @@ -296,6 +317,7 @@ def __call__(
return_logits: whether to return per-step logits used during generation.
forbidden_tokens: list of tokens that are forbidden to be generated. Each
token must map to a single token id in the vocab.
sampling: Sampling method to use.
Returns:
sampler_output: A SamplerOutput object containing the generated samples.
Expand All @@ -322,7 +344,7 @@ def __call__(
)

sampling_state = self._compiled_sample_fn(
self.params, initial_sampling_state
self.params, initial_sampling_state, sampling=sampling
)

masked_token_buffer = self.mask_tokens_after_eos_ids(
Expand All @@ -347,6 +369,9 @@ def __call__(

decoded_outputs = [self.tokenizer.decode(tokens) for tokens in out_tokens]

# Update the rng for the next call.
self._rng = sampling_state.rng

result = SamplerOutput(
text=decoded_outputs,
logits=out_logits,
Expand Down
59 changes: 59 additions & 0 deletions gemma/gm/text/_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sampling methods."""

import abc
import dataclasses

import jax
import jax.numpy as jnp
from kauldron.typing import Float, Int, PRNGKey, typechecked # pylint: disable=g-multiple-import,g-importing-member


class SamplingMethod(abc.ABC):
  """Base class for sampling methods.

  A sampling method converts the logits produced by the model into the ids of
  the next tokens to emit. Subclasses implement `get_next_tokens`.
  """

  @abc.abstractmethod
  def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']:
    """Returns the next tokens to generate.

    The shape annotation uses a variadic leading axis (`*B`) so that it agrees
    with the concrete implementations, which accept logits with any number of
    leading axes in front of the vocabulary axis.

    Args:
      logits: Logits, as returned by the model (i.e. before softmax). The last
        axis is the vocabulary axis.
      rng: A random key.

    Returns:
      The next tokens to generate, one per position of the leading axes.
    """
    raise NotImplementedError()


@dataclasses.dataclass(frozen=True, kw_only=True)
class Greedy(SamplingMethod):
  """Greedy decoding: always selects the highest-scoring token."""

  @typechecked
  def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']:
    """Returns the index of the largest logit along the vocabulary axis.

    Args:
      logits: Logits whose last axis is the vocabulary axis.
      rng: Random key. Unused, since the selection is deterministic.

    Returns:
      The argmax over the last axis of `logits`.
    """
    # The key is part of the shared interface only; drop it explicitly.
    del rng
    best_token_ids = jnp.argmax(logits, axis=-1)
    return best_token_ids


@dataclasses.dataclass(frozen=True, kw_only=True)
class RandomSampling(SamplingMethod):
  """Simple random sampling.

  Draws the next token from the categorical distribution defined by
  `logits / temperature`.

  Attributes:
    temperature: Value the logits are divided by before sampling. Must be
      strictly positive.
  """

  temperature: float = 1.0

  def __post_init__(self):
    """Validates the temperature.

    Raises:
      ValueError: If `temperature` is not strictly positive (including NaN).
    """
    # Dividing by zero would turn the logits into inf/NaN, and a negative value
    # would invert the distribution. `not (x > 0)` also rejects NaN.
    if not self.temperature > 0:
      raise ValueError(
          f'`temperature` must be strictly positive, got {self.temperature!r}.'
      )

  @typechecked
  def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']:
    """Samples the next tokens from the temperature-scaled logits.

    Args:
      logits: Logits whose last axis is the vocabulary axis.
      rng: Random key consumed by the categorical draw.

    Returns:
      The sampled token ids, one per position of the leading axes.
    """
    return jax.random.categorical(rng, logits / self.temperature, axis=-1)

0 comments on commit 6878297

Please sign in to comment.