Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

[transformers] Prompt masking #2192

Closed
wants to merge 13 commits into from
19 changes: 15 additions & 4 deletions src/sparseml/transformers/finetune/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def tokenize_fn(data):
padding=self.padding,
max_length=self.max_seq_length,
truncation=True,
return_offsets_mapping=True,
)

# store unpadded prompt so we can mask out correct number of elements
Expand Down Expand Up @@ -156,16 +157,29 @@ def group_text_fn(data):
def label_fn(data):
    """
    Add a ``labels`` entry to a tokenized sample.

    The labels start as a copy of ``input_ids``; positions that must not
    contribute to the loss calculation are then overwritten with
    ``LABELS_MASK_VALUE``:

    * tokens whose character span overlaps a ``"0"`` in the optional
      character-level ``mask`` (requires ``offset_mapping`` in the sample)
    * the leading ``len(data[self.PROMPT_KEY])`` tokens, when the sample
      carries a prompt entry
    * the trailing padding tokens, counted from ``attention_mask``

    :param data: a single tokenized sample; must contain ``input_ids`` and
        ``attention_mask``
    :return: the same sample with ``labels`` set
    """
    # if the dataset uses prompts, mask them out so they don't contribute
    # to the loss calculation
    labels = data["input_ids"].copy()

    # NOTE(review): a "0" anywhere in a token's character span removes that
    # token from the loss -- confirm that whatever produces ``mask`` uses "0"
    # for the text that should be ignored (and not for the text to learn).
    if "offset_mapping" in data:
        char_mask = data.get("mask")
        if char_mask is not None:
            for token_idx, (start, end) in enumerate(data["offset_mapping"]):
                # an empty span (start == end) yields an empty slice and is
                # therefore never masked by this step
                if "0" in char_mask[start:end]:
                    labels[token_idx] = LABELS_MASK_VALUE

    prompt_len = len(data[self.PROMPT_KEY]) if self.PROMPT_KEY in data else 0
    labels[:prompt_len] = [LABELS_MASK_VALUE] * prompt_len

    # mask out padding in the labels as well; the padding is taken from the
    # tail of the sequence, i.e. right-side padding is assumed
    padding = len(data["attention_mask"]) - sum(data["attention_mask"])
    if padding > 0:
        labels[-padding:] = [LABELS_MASK_VALUE] * padding

    # single assignment: the previous extra ``input_ids`` copy stored under
    # ``labels`` was immediately overwritten and has been dropped
    data["labels"] = labels
    return data

dataset = self.map(
Expand Down Expand Up @@ -206,8 +220,6 @@ def label_fn(data):
load_from_cache_file=not self.data_args.overwrite_cache,
desc="Adding labels",
)
print(dataset.column_names)

return dataset

def map(
Expand All @@ -226,5 +238,4 @@ def map(
kwargs.pop("num_proc", None)
kwargs.pop("load_from_cache_file", None)
kwargs.pop("desc", None)

return dataset.map(**kwargs)
3 changes: 2 additions & 1 deletion src/sparseml/transformers/finetune/data/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
num_proc=self.data_args.preprocessing_num_workers,
desc="Removing unneeded columns",
)

return raw_dataset

def get_remove_columns_from_dataset(
Expand All @@ -108,5 +107,7 @@ def get_remove_columns_from_dataset(
remove_columns.remove(self.text_column)
if self.PROMPT_KEY in remove_columns:
remove_columns.remove(self.PROMPT_KEY)
if "mask" in remove_columns:
remove_columns.remove("mask")

return list(remove_columns)
47 changes: 47 additions & 0 deletions src/sparseml/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"ALL_TASK_NAMES",
"create_fake_dataloader",
"POSSIBLE_TOKENIZER_FILES",
"generate_mask",
]


Expand Down Expand Up @@ -554,3 +555,49 @@ def fetch_recipe_path(target: str):
recipe_path = hf_hub_download(repo_id=target, filename=DEFAULT_RECIPE_NAME)

return recipe_path


def generate_mask(string: str, prompt: str, censor: str) -> str:
    """
    Generate a character-level mask for ``string`` based on the provided
    prompt and censor marker strings.

    The string is scanned left to right, starting in the visible state.
    Every occurrence of ``censor`` switches to the obscured state and every
    subsequent occurrence of ``prompt`` switches back to the visible state.
    A marker itself takes the state it switches to. Markers are matched
    exactly and without overlapping each other. An empty marker never
    matches, so the state it would switch to is simply never entered.

    :param string: The input string to be masked.
    :param prompt: Marker that starts a visible section.
    :param censor: Marker that starts an obscured section.
    :return: A string with the same length as ``string`` where '1' indicates
        visible characters and '0' indicates obscured characters.
    """
    pieces = []
    position = 0
    visible = True
    total = len(string)

    while position < total:
        # while visible only the censor marker can flip the state, and
        # vice versa; matching the whole marker with str.find avoids the
        # partial-match bookkeeping that mixed up the two markers
        marker = censor if visible else prompt
        found = string.find(marker, position) if marker else -1

        flag = "1" if visible else "0"
        if found == -1:
            # no further switch: the current state runs to the end
            pieces.append(flag * (total - position))
            break

        # text before the marker keeps the current state ...
        pieces.append(flag * (found - position))
        # ... and the marker itself already belongs to the new state
        visible = not visible
        pieces.append(("1" if visible else "0") * len(marker))
        position = found + len(marker)

    return "".join(pieces)
4 changes: 4 additions & 0 deletions src/sparseml/transformers/utils/preprocessing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Dict

from sparseml.transformers.utils.helpers import generate_mask
from sparsezoo.utils.registry import RegistryMixin


Expand All @@ -26,4 +27,7 @@ def custom_evolved_codealpaca_dataset(data: Dict):
PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
data["prompt"] = PROMPT_DICT.format_map(data)
data["text"] = data["prompt"] + data["output"]
data["mask"] = generate_mask(
data["text"], prompt="[Instruction]", censor="[Response]"
)
return data
27 changes: 27 additions & 0 deletions tests/sparseml/transformers/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from accelerate import init_empty_weights
from sparseml.transformers.utils.helpers import (
create_fake_dataloader,
generate_mask,
infer_recipe_from_model_path,
is_transformer_model,
resolve_recipe_file,
Expand Down Expand Up @@ -166,3 +167,29 @@ def test_save_zoo_directory(tmp_path, stub):
assert zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
shutil.rmtree(path_to_training_outputs)
shutil.rmtree(save_dir)


@pytest.mark.parametrize(
    "string, prompt, censor, expected_mask",
    [
        # one visible section followed by one obscured section
        ("[foo]hello\n\n[bar]world", "[foo]", "[bar]", "1111111111110000000000"),
        # two alternating visible / obscured turns
        (
            (
                "[Instruction]python is\n\n"  # 24
                "[Response]great\n\n"  # 17
                "[Instruction]What about Java"  # 28
                "[Response]Meh"  # 13
            ),
            "[Instruction]",
            "[Response]",
            (
                "111111111111111111111111"  # 24
                "00000000000000000"  # 17
                "1111111111111111111111111111"  # 28
                "0000000000000"  # 13
            ),
        ),
    ],
)
def test_generate_mask(string, prompt, censor, expected_mask):
    """
    The mask produced for ``string`` must equal ``expected_mask``
    character for character.
    """
    produced_mask = generate_mask(string, prompt, censor)
    assert produced_mask == expected_mask