
[transformers] Prompt masking #2192

Closed · wants to merge 13 commits
20 changes: 16 additions & 4 deletions src/sparseml/transformers/finetune/data/base.py
@@ -41,6 +41,7 @@ class TextGenerationDataset(RegistryMixin):
"""

PROMPT_KEY = "prompt"
MASK_KEY = "mask"

def __init__(
self,
@@ -125,6 +126,7 @@ def tokenize_fn(data):
                padding=self.padding,
                max_length=self.max_seq_length,
                truncation=True,
                return_offsets_mapping=True,
            )

            # store unpadded prompt so we can mask out correct number of elements
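
For readers unfamiliar with `return_offsets_mapping`: with a fast (Rust-backed) tokenizer it adds an `offset_mapping` entry giving the character span each token covers. A minimal sketch, not part of this diff, with `gpt2` chosen only as an example of a fast tokenizer:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer works
enc = tok("hello world", return_offsets_mapping=True)
# enc["offset_mapping"] holds one (start, end) character span per token,
# here [(0, 5), (5, 11)] for the tokens "hello" and " world"
print(enc["offset_mapping"])
```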
@@ -156,16 +158,29 @@ def group_text_fn(data):
        def label_fn(data):
            # if the dataset uses prompts, mask them out so they don't contribute
            # to the loss calculation
            labels = data["input_ids"].copy()
            if "offset_mapping" in data:
                offset_mapping = data["offset_mapping"]
                # get the character-level mask
                mask = data.get("mask")
                if mask is not None:
                    for i, (start, end) in enumerate(offset_mapping):
                        # mask out the token if any of its characters is filtered
                        if "0" in mask[start:end]:
                            labels[i] = LABELS_MASK_VALUE

            prompt_len = 0
            if self.PROMPT_KEY in data:
                prompt_len = len(data[self.PROMPT_KEY])

            data["labels"] = labels
            data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len

            # mask out padding in the labels as well
            padding = len(data["attention_mask"]) - sum(data["attention_mask"])
            if padding > 0:
                data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding

            return data

if raw_dataset is None:
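
Putting the two pieces together, a minimal sketch (outside this diff) of how the character-level mask is lifted to token-level labels via the offset mapping. `LABELS_MASK_VALUE` is assumed here to be -100, the label value that PyTorch's cross-entropy loss ignores:

```python
from transformers import AutoTokenizer

LABELS_MASK_VALUE = -100  # assumed; ignored by torch.nn.CrossEntropyLoss

tok = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer
text = "[foo]hi[bar]ok"
char_mask = "0" * 7 + "1" * 7  # "0" obscures "[foo]hi", "1" keeps "[bar]ok"

enc = tok(text, return_offsets_mapping=True)
labels = list(enc["input_ids"])
for i, (start, end) in enumerate(enc["offset_mapping"]):
    # a token is dropped from the loss if any of its characters is masked out
    if "0" in char_mask[start:end]:
        labels[i] = LABELS_MASK_VALUE
```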
@@ -209,8 +224,6 @@ def label_fn(data):
            load_from_cache_file=not self.data_args.overwrite_cache,
            desc="Adding labels",
        )

        return dataset

    def map(
@@ -229,5 +242,4 @@ def map(
kwargs.pop("num_proc", None)
kwargs.pop("load_from_cache_file", None)
kwargs.pop("desc", None)

return dataset.map(**kwargs)
3 changes: 2 additions & 1 deletion src/sparseml/transformers/finetune/data/custom.py
@@ -91,7 +91,6 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
            num_proc=self.data_args.preprocessing_num_workers,
            desc="Removing unneeded columns",
        )
        return raw_dataset

    def get_remove_columns_from_dataset(
@@ -108,5 +107,7 @@ def get_remove_columns_from_dataset(
            remove_columns.remove(self.text_column)
        if self.PROMPT_KEY in remove_columns:
            remove_columns.remove(self.PROMPT_KEY)
        if self.MASK_KEY in remove_columns:
            remove_columns.remove(self.MASK_KEY)

        return list(remove_columns)
51 changes: 51 additions & 0 deletions src/sparseml/transformers/utils/helpers.py
@@ -54,6 +54,7 @@
"ALL_TASK_NAMES",
"create_fake_dataloader",
"POSSIBLE_TOKENIZER_FILES",
"generate_mask",
"download_repo_from_huggingface_hub",
"download_model_directory",
]
@@ -544,6 +545,56 @@ def fetch_recipe_path(target: str):
    return recipe_path


def generate_mask(string: str, response: str, prompt: str = "") -> str:
    """
    Generate a character-level mask for the input string based on the provided
    prompt and response markers. Sections introduced by the prompt marker are
    obscured and sections introduced by the response marker are kept, where
    "0" means remove and "1" means keep. By default, text that is not preceded
    by a response marker is obscured.

    :param string: the input string to be masked
    :param response: the marker identifying sections to keep visible
    :param prompt: the marker identifying sections to obscure
    :return: a string of the same length as the input, where "1" marks a
        visible character and "0" marks an obscured character
    """

    mask = ["1"] * len(string)
    is_prompt = not string.startswith(response)
    counter = 0
    for i, char in enumerate(string):
        if is_prompt:
            mask[i] = "0"

        # extend a partial match against the prompt or response marker
        if counter > 0:
            if not is_prompt and len(prompt) > 1 and char == prompt[counter]:
                counter += 1
            elif is_prompt and char == response[counter]:
                counter += 1
            else:
                counter = 0

        # a full prompt marker was matched: mask it out and start obscuring
        if len(prompt) > 0 and counter == len(prompt) and not is_prompt:
            mask[i - counter + 1 : i + 1] = ["0"] * counter
            counter = 0
            is_prompt = True

        # a full response marker was matched: unmask it and stop obscuring
        if counter == len(response) and is_prompt:
            mask[i - counter + 1 : i + 1] = ["1"] * counter
            counter = 0
            is_prompt = False

        # this character may begin a new marker match
        if prompt.startswith(char) or response.startswith(char):
            counter = 1
    return "".join(mask)


def download_repo_from_huggingface_hub(repo_id, **kwargs):
"""
Download relevant model files from the Hugging Face Hub
4 changes: 4 additions & 0 deletions src/sparseml/transformers/utils/preprocessing_functions.py
@@ -14,6 +14,7 @@

from typing import Dict

from sparseml.transformers.utils.helpers import generate_mask
from sparsezoo.utils.registry import RegistryMixin


@@ -26,4 +27,7 @@ def custom_evolved_codealpaca_dataset(data: Dict):
    PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
    data["prompt"] = PROMPT_DICT.format_map(data)
    data["text"] = data["prompt"] + data["output"]
    data["mask"] = generate_mask(
        data["text"], prompt="[Instruction]", response="[Response]"
    )
    return data
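
For a concrete sense of the record this produces, a sketch with made-up field values, assuming the registry decorator leaves the function directly callable:

```python
example = {"instruction": "Say hi", "output": "hi"}
example = custom_evolved_codealpaca_dataset(example)
# example["text"] == "[Instruction]:\nSay hi\n\n[Response]:hi"
# example["mask"] is "0" over the "[Instruction]:\nSay hi\n\n" span and
# "1" over "[Response]" and everything after it
```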
33 changes: 33 additions & 0 deletions tests/sparseml/transformers/finetune/test_finetune.py
@@ -33,6 +33,7 @@
    oneshot,
    train,
)
from sparseml.transformers.utils.helpers import generate_mask


def test_oneshot_and_finetune(tmp_path: Path):
@@ -322,3 +323,35 @@ def test_oneshot_with_modifier_object(tmp_path: Path):
        splits=splits,
        oneshot_device=device,
    )


def test_finetune_wout_recipe_with_mask(tmp_path: Path):
    recipe_str = None
    model = "Xenova/llama2.c-stories15M"
    device = "cuda:0"
    if not torch.cuda.is_available():
        device = "cpu"
    dataset = "open_platypus"
    concatenate_data = False
    output_dir = tmp_path
    max_steps = 50
    splits = "train"

    def preprocessing_func(example):
        example["text"] = "[foo]" + example["text"] + "[bar] mask this"
        example["mask"] = generate_mask(
            example["text"], response="[bar]", prompt="[foo]"
        )
        return example

    train(
        model=model,
        dataset=dataset,
        output_dir=output_dir,
        recipe=recipe_str,
        max_steps=max_steps,
        concatenate_data=concatenate_data,
        splits=splits,
        oneshot_device=device,
        preprocessing_func=preprocessing_func,
    )
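
For intuition, the wrapping above produces masks of this shape (hand-worked under the same markers):

```python
text = "[foo]once upon a time[bar] mask this"
mask = generate_mask(text, response="[bar]", prompt="[foo]")
# "[foo]once upon a time" -> all "0": excluded from the loss
# "[bar] mask this"       -> all "1": contributes to the loss
```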
53 changes: 53 additions & 0 deletions tests/sparseml/transformers/utils/test_helpers.py
@@ -21,6 +21,7 @@
from accelerate import init_empty_weights
from sparseml.transformers.utils.helpers import (
    create_fake_dataloader,
    generate_mask,
    infer_recipe_from_model_path,
    is_transformer_model,
    resolve_recipe_file,
@@ -166,3 +167,55 @@ def test_save_zoo_directory(tmp_path, stub):
    assert zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
    shutil.rmtree(path_to_training_outputs)
    shutil.rmtree(save_dir)


@pytest.mark.parametrize(
    "string, response, prompt, expected_mask",
    [
        (
            ("[foo]hello\n\n" "[bar]world"),
            "[bar]",
            "[foo]",
            ("000000000000" "1111111111"),
        ),
        (
            (
                "[Instruction]python is\n\n"  # 24
                "[Response]great\n\n"  # 17
                "[Instruction]What about Java"  # 28
                "[Response]Meh"  # 13
            ),
            "[Response]",
            "[Instruction]",
            (
                "000000000000000000000000"  # 24
                "11111111111111111"  # 17
                "0000000000000000000000000000"  # 28
                "1111111111111"  # 13
            ),
        ),
        (
            ("[foo]hello\n\n" "[bar]world"),
            "[bar]",
            None,
            ("000000000000" "1111111111"),
        ),
        (
            ("hello\n\n" "[bar]world"),
            "[bar]",
            None,
            ("0000000" "1111111111"),
        ),
        (
            ("[bar]world" "[foo]hello\n\n" "[bar]world"),
            "[bar]",
            "[foo]",
            ("1111111111" "000000000000" "1111111111"),
        ),
    ],
)
def test_generate_mask(string, response, prompt, expected_mask):
    if prompt is not None:
        assert generate_mask(string, response, prompt) == expected_mask
    else:
        assert generate_mask(string, response) == expected_mask