Simplify data pipeline for next-token prediction, and sampling
PiperOrigin-RevId: 730797372
Conchylicultor authored and The gemma Authors committed Feb 25, 2025
1 parent d057dca commit ee0d556
Showing 13 changed files with 424 additions and 330 deletions.
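Across the touched files, the commit replaces the hand-rolled chain of transforms (`DecodeBytes`, `FormatText`, `Tokenize`, `AddNextTokenPredictionFields`, `Elements`, `Pad`, `Rearrange`) with a single task transform: `gm.data.Seq2SeqTask` for supervised fine-tuning and `gm.data.ContrastiveTask` for DPO. Below is a minimal sketch of the simplified fine-tuning pipeline assembled from the hunks in this commit; placing the tokenizer and padding arguments inside `Seq2SeqTask` is an inference from the diff, not a documented signature.

```
# Sketch only: the simplified pipeline as suggested by the hunks below.
# The exact Seq2SeqTask signature (tokenizer/max_length/truncate) is assumed.
from gemma import gm
from kauldron import kd

tokenizer = gm.text.Gemma2Tokenizer()

ds = kd.data.py.Tfds(
    name="mtnt/en-fr",        # small en->fr translation dataset
    split="train",
    shuffle=True,
    batch_size=8,
    transforms=[
        # One transform formats the Gemma dialog turns, tokenizes the
        # prompt/response, builds {'input', 'target', 'loss_mask'}, and pads.
        gm.data.Seq2SeqTask(
            in_prompt="src",
            in_response="dst",
            out_input="input",
            out_target="target",
            out_target_mask="loss_mask",
            tokenizer=tokenizer,
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]  # dict of padded arrays: 'input', 'target', 'loss_mask'
```

The same batch previously required seven separate transforms plus a trailing `Rearrange` for the loss, which is what the per-file hunks below remove.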
41 changes: 9 additions & 32 deletions colabs/finetuning.ipynb
@@ -141,7 +141,7 @@
"* Creating the model inputs (with `gm.data.Tokenize`))\n",
"* Adding padding (with `gm.data.Pad`) (required to batch inputs with different lengths)\n",
"\n",
"Note that in practice, you can combine multiple transforms into a higher level transform. See the `gm.data.AddContrastiveFields()` transform in the [DPO example](https://github.com/google-deepmind/gemma/tree/main/examples/dpo.py) for an example.\n",
"Note that in practice, you can combine multiple transforms into a higher level transform. See the `gm.data.ContrastiveTask()` transform in the [DPO example](https://github.com/google-deepmind/gemma/tree/main/examples/dpo.py) for an example.\n",
"\n",
"Here, we try [mtnt](https://www.tensorflow.org/datasets/catalog/mtnt), a small translation dataset. The dataset structure is `{'src': ..., 'dst': ...}`."
]
@@ -184,48 +184,25 @@
" shuffle=True,\n",
" batch_size=8,\n",
" transforms=[\n",
" # TFDS returns `bytes` rather than `str`, so need to decode them first\n",
" gm.data.DecodeBytes(key=['src', 'dst']),\n",
" # We format the input to add the special tokens\n",
" # See `\u003cstart_of_turn\u003e` section in\n",
" # https://github.com/google-deepmind/gemma/blob/main/docs/tokenizer.md\n",
" gm.data.FormatText(\n",
" key='src',\n",
" template=\"\"\"\\\n",
" \u003cstart_of_turn\u003euser\n",
" {text}\u003cend_of_turn\u003e\n",
" \u003cstart_of_turn\u003emodel\n",
" \"\"\",\n",
" ),\n",
" # Tokenize the inputs/outputs\n",
" gm.data.Tokenize(key='src', tokenizer=tokenizer, add_bos=True),\n",
" gm.data.Tokenize(key='dst', tokenizer=tokenizer, add_eos=True),\n",
" # Create the model inputs/targets/loss_mask.\n",
" gm.data.AddNextTokenPredictionFields(\n",
" gm.data.Seq2SeqTask(\n",
" # Select which field from the dataset to use.\n",
" # https://www.tensorflow.org/datasets/catalog/mtnt\n",
" in_prompt='src',\n",
" in_response='dst',\n",
" # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}\n",
" out_input='input',\n",
" out_target='target',\n",
" out_target_mask='loss_mask',\n",
" ),\n",
" # Only keep the fields we need.\n",
" kd.data.Elements(keep=[\"input\", \"target\", \"loss_mask\"]),\n",
" # Pad the sequences to support batching.\n",
" gm.data.Pad(\n",
" key=[\"input\", \"target\", \"loss_mask\"],\n",
" tokenizer=tokenizer,\n",
" # Padding parameters\n",
" max_length=200,\n",
" # In this dataset, ~1% of examples are longer than 200 tokens.\n",
" # TODO(epot): Compute statistics\n",
" truncate=True,\n",
" ),\n",
" # For shape compatibility with the loss\n",
" kd.data.Rearrange(\n",
" key=[\"target\", \"loss_mask\"], pattern=\"... -\u003e ... 1\"\n",
" ),\n",
" ],\n",
")\n",
"\n",
"(ex,) = ds.take(1)\n",
"ex = ds[0]\n",
"\n",
"treescope.show(ex)"
]
@@ -236,7 +213,7 @@
"id": "3ny2J07G2X7i"
},
"source": [
"We can decode an example from the batch to inspect the model input and check it is properly formatted:"
"We can decode an example from the batch to inspect the model input. We see that the `\u003cstart_of_turn\u003e` / `\u003cend_of_turn\u003e` where correctly added to follow Gemma dialog format."
]
},
{
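The updated markdown cell above says the decoded example shows the dialog markers. A small sketch of that inspection, assuming the `ds` pipeline and `tokenizer` defined earlier in the notebook and a `tokenizer.decode()` method that accepts a sequence of token ids:

```
# Sketch: inspect the first example of the padded batch built above.
# `ds` and `tokenizer` come from earlier notebook cells; a `decode()` taking
# a list of ids is assumed to exist on gm.text.Gemma2Tokenizer.
ex = ds[0]
print(ex["input"].shape)                 # expected (8, 200): batch x max_length

text = tokenizer.decode(ex["input"][0])  # decode the first row of the batch
print(text)
# Expected to contain the Gemma dialog markers, roughly:
#   <start_of_turn>user
#   ...English source sentence...<end_of_turn>
#   <start_of_turn>model
#   ...French target sentence...
```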
37 changes: 7 additions & 30 deletions colabs/lora_finetuning.ipynb
@@ -309,48 +309,25 @@
" shuffle=True,\n",
" batch_size=8,\n",
" transforms=[\n",
" # TFDS returns `bytes` rather than `str`, so need to decode them first\n",
" gm.data.DecodeBytes(key=['src', 'dst']),\n",
" # We format the input to add the special tokens\n",
" # See `\u003cstart_of_turn\u003e` section in\n",
" # https://github.com/google-deepmind/gemma/blob/main/docs/tokenizer.md\n",
" gm.data.FormatText(\n",
" key='src',\n",
" template=\"\"\"\\\n",
" \u003cstart_of_turn\u003euser\n",
" {text}\u003cend_of_turn\u003e\n",
" \u003cstart_of_turn\u003emodel\n",
" \"\"\",\n",
" ),\n",
" # Tokenize the inputs/outputs\n",
" gm.data.Tokenize(key='src', tokenizer=tokenizer, add_bos=True),\n",
" gm.data.Tokenize(key='dst', tokenizer=tokenizer, add_eos=True),\n",
" # Create the model inputs/targets/loss_mask.\n",
" gm.data.AddNextTokenPredictionFields(\n",
" gm.data.Seq2SeqTask(\n",
" # Select which field from the dataset to use.\n",
" # https://www.tensorflow.org/datasets/catalog/mtnt\n",
" in_prompt='src',\n",
" in_response='dst',\n",
" # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}\n",
" out_input='input',\n",
" out_target='target',\n",
" out_target_mask='loss_mask',\n",
" ),\n",
" # Only keep the fields we need.\n",
" kd.data.Elements(keep=[\"input\", \"target\", \"loss_mask\"]),\n",
" # Pad the sequences to support batching.\n",
" gm.data.Pad(\n",
" key=[\"input\", \"target\", \"loss_mask\"],\n",
" tokenizer=tokenizer,\n",
" # Padding parameters\n",
" max_length=200,\n",
" # In this dataset, ~1% of examples are longer than 200 tokens.\n",
" # TODO(epot): Compute statistics\n",
" truncate=True,\n",
" ),\n",
" # For shape compatibility with the loss\n",
" kd.data.Rearrange(\n",
" key=[\"target\", \"loss_mask\"], pattern=\"... -\u003e ... 1\"\n",
" ),\n",
" ],\n",
")\n",
"\n",
"(ex,) = ds.take(1)\n",
"ex = ds[0]\n",
"\n",
"treescope.show(ex)"
]
4 changes: 1 addition & 3 deletions examples/dpo.py
@@ -99,7 +99,7 @@ def _make_dataset(training: bool) -> kd.data.Pipeline:
batch_size=batch_size,
transforms=[
# Create the model inputs and loss mask.
gm.data.AddContrastiveFields(
gm.data.ContrastiveTask(
in_prompt="input",
in_chosen="chosen",
in_rejected="rejected",
@@ -111,7 +111,5 @@
# TODO(epot): Run stats (how many examples are we dropping?)
truncate=True,
),
# Only keep the fields we need.
kd.data.Elements(keep=["tokens", "mask"]),
],
)
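For the DPO pipeline, the only output fields visible in this file are `tokens` and `mask` (the removed `kd.data.Elements` line). A toy, library-free sketch of the kind of contrastive batch such a task produces; the field names come from the diff, while the shapes, packing scheme, and padding id are assumptions:

```
# Toy sketch (not the library implementation) of a contrastive example:
# chosen and rejected continuations stacked along a new leading axis, with a
# mask selecting the response tokens that enter the DPO loss.
import numpy as np

prompt = [2, 11, 12]        # pretend token ids for the shared prompt
chosen = [21, 22, 23, 1]    # pretend ids for the preferred response (+ eos)
rejected = [31, 32, 1]      # pretend ids for the rejected response (+ eos)

def pack(prompt, response, max_length, pad_id=0):
  """Concatenate prompt + response, pad/truncate, and mask the response."""
  ids = prompt + response
  mask = [0] * len(prompt) + [1] * len(response)
  ids = (ids + [pad_id] * max_length)[:max_length]
  mask = (mask + [0] * max_length)[:max_length]
  return ids, mask

max_length = 8
c_ids, c_mask = pack(prompt, chosen, max_length)
r_ids, r_mask = pack(prompt, rejected, max_length)

# Final example: row 0 = chosen, row 1 = rejected.
ex = {
    "tokens": np.stack([c_ids, r_ids]),
    "mask": np.stack([c_mask, r_mask]),
}
print(ex["tokens"].shape, ex["mask"].shape)  # (2, 8) (2, 8)
```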
55 changes: 19 additions & 36 deletions examples/lora.py
@@ -105,21 +105,9 @@ def get_config():
# test set.
"sampling": gm.evals.SamplerEvaluator(
run=kd.evals.EveryNSteps(1000),
# Sampling parameters
tokenizer=gm.text.Gemma2Tokenizer(),
max_new_tokens=150,
# Which examples to use for sampling
# The prompt and response indicates the fields to use within each
# dataset example.
prompt="prompt",
response="response",
max_new_tokens=150, # Sampling parameters
num_examples=1, # Only predict a single example
ds=kd.data.py.Json(
shuffle=False,
num_epochs=1,
batch_size=None,
num_workers=0,
),
ds=_make_dataset(training=False, sampling=True),
),
},
)
@@ -128,41 +116,36 @@
def _make_dataset(
*,
training: bool,
batch_size: int,
max_length: int,
sampling: bool = False,
batch_size: int | None = None,
max_length: int | None = None,
):
tokenizer = gm.text.Gemma2Tokenizer()

split = "train" if training else "test"

return kd.data.py.Json(
return kd.data.py.Tfds(
name="mtnt/en-fr",
split="train" if training else "test",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=batch_size,
batch_size=None if sampling else batch_size,
num_workers=4,
transforms=[
gm.data.Tokenize(key="prompt", tokenizer=tokenizer, add_bos=True),
gm.data.Tokenize(key="response", tokenizer=tokenizer, add_eos=True),
# Create the model inputs/targets/loss_mask.
gm.data.AddNextTokenPredictionFields(
in_prompt="prompt",
in_response="response",
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
# https://www.tensorflow.org/datasets/catalog/mtnt
in_prompt="src",
in_response="dst",
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
out_input="input",
out_target="target",
out_target_mask="loss_mask",
),
# Only keep the fields we need.
kd.data.Elements(keep=["input", "target", "loss_mask"]),
# Pad the sequences to support batching.
gm.data.Pad(
key=["input", "target", "loss_mask"],
max_length=max_length,
tokenizer=tokenizer,
# Padding parameters
max_length=None if sampling else max_length,
# In this dataset, ~1% of examples are longer than 512 tokens.
truncate=True,
),
# For shape compatibility with the loss
kd.data.Rearrange(
key=["target", "loss_mask"], pattern="... -> ... 1"
sampling=sampling,
),
],
)
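With the new keyword-only `sampling` flag, the same `_make_dataset` helper now serves both the training pipeline and the `SamplerEvaluator` (see `ds=_make_dataset(training=False, sampling=True)` in the config hunk above). A sketch of the two call sites, assuming the helper defined above; the training-side values are placeholders rather than the config's real ones:

```
# Sketch of the two call sites for the helper above.
# Training: shuffled, batched, padded/truncated to a fixed length.
train_ds = _make_dataset(training=True, batch_size=16, max_length=512)

# Evaluator: test split, single epoch, unbatched and unpadded, because
# sampling=True forces batch_size and max_length to None.
eval_ds = _make_dataset(training=False, sampling=True)
```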
78 changes: 23 additions & 55 deletions examples/next_token_prediction.py
@@ -20,34 +20,19 @@
```
<start_of_turn>user
You are a helpful assistant with access to the following functions. Use them if required -
{
"name": "search_recipes",
"description": "Search for recipes based on ingredients",
"parameters": {
...
}
}
{
"name": "get_movie_details",
"description": "Get details about a movie",
"parameters": {
...
}
}
I have some chicken, broccoli, and cheese. Can you find me a recipe?
Hello! I would love to visit France.<end_of_turn>
<start_of_turn>model
```
Output:
```
{"name": "search_recipes", "arguments": '{"ingredients": ["chicken", "broccoli", "cheese"]}'}<end_of_turn>
Bonjour ! J'adorerais visiter la France.<end_of_turn>
```
The `<start_of_turn>` and `<end_of_turn>` are special tokens used to specify
which of the user or model is speaking.
whether the user or the model is speaking. They are automatically added by the
`Seq2SeqTask` transform.
Train locally with:
@@ -116,21 +101,9 @@ def get_config():
# test set.
"sampling": gm.evals.SamplerEvaluator(
run=kd.evals.EveryNSteps(1000),
# Sampling parameters
tokenizer=gm.text.Gemma2Tokenizer(),
max_new_tokens=50,
# Which examples to use for sampling
# The prompt and response indicates the fields to use within each
# dataset example.
prompt="prompt",
response="response",
max_new_tokens=50, # Sampling parameters
num_examples=1, # Only predict a single example
ds=kd.data.py.Json(
shuffle=False,
num_epochs=1,
batch_size=None,
num_workers=0,
),
ds=_make_dataset(training=False, sampling=True),
),
},
)
@@ -139,41 +112,36 @@
def _make_dataset(
*,
training: bool,
batch_size: int,
max_length: int,
sampling: bool = False,
batch_size: int | None = None,
max_length: int | None = None,
):
tokenizer = gm.text.Gemma2Tokenizer()

split = "train" if training else "test"

return kd.data.py.Json(
return kd.data.py.Tfds(
name="mtnt/en-fr",
split="train" if training else "test",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=batch_size,
batch_size=None if sampling else batch_size,
num_workers=4,
transforms=[
gm.data.Tokenize(key="prompt", tokenizer=tokenizer, add_bos=True),
gm.data.Tokenize(key="response", tokenizer=tokenizer, add_eos=True),
# Create the model inputs/targets/loss_mask.
gm.data.AddNextTokenPredictionFields(
in_prompt="prompt",
in_response="response",
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
# https://www.tensorflow.org/datasets/catalog/mtnt
in_prompt="src",
in_response="dst",
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
out_input="input",
out_target="target",
out_target_mask="loss_mask",
),
# Only keep the fields we need.
kd.data.Elements(keep=["input", "target", "loss_mask"]),
# Pad the sequences to support batching.
gm.data.Pad(
key=["input", "target", "loss_mask"],
max_length=max_length,
tokenizer=tokenizer,
# Padding parameters
max_length=None if sampling else max_length,
# In this dataset, ~1% of examples are longer than 512 tokens.
truncate=True,
),
# For shape compatibility with the loss
kd.data.Rearrange(
key=["target", "loss_mask"], pattern="... -> ... 1"
sampling=sampling,
),
],
)
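The three fields emitted by `Seq2SeqTask` (`input`, `target`, `loss_mask`) feed a masked next-token-prediction loss; the dropped `kd.data.Rearrange` step with pattern `"... -> ... 1"` only added a trailing axis the old loss expected. A toy masked cross-entropy in JAX, as an illustration of how such fields are typically consumed rather than the project's actual loss:

```
# Toy sketch, not the kauldron/gemma loss: cross-entropy averaged over the
# positions where loss_mask == 1 (response tokens), ignoring prompt/padding.
import jax
import jax.numpy as jnp

def masked_xent(logits, target, loss_mask):
  # logits:    (batch, seq, vocab) model outputs at each input position
  # target:    (batch, seq) id of the next token at each position
  # loss_mask: (batch, seq) 1 on response tokens, 0 on prompt and padding
  logp = jax.nn.log_softmax(logits, axis=-1)
  nll = -jnp.take_along_axis(logp, target[..., None], axis=-1)[..., 0]
  return (nll * loss_mask).sum() / jnp.maximum(loss_mask.sum(), 1)

# Tiny fake batch: batch=2, seq=5, vocab=11.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 11))
target = jnp.array([[3, 4, 5, 0, 0], [6, 7, 0, 0, 0]])
loss_mask = jnp.array([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])
print(masked_xent(logits, target, loss_mask))
```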