diff --git a/docs/datasets.md b/docs/datasets.md
index 86aa3af2..0022eb31 100644
--- a/docs/datasets.md
+++ b/docs/datasets.md
@@ -69,6 +69,7 @@ guidellm benchmark \
 - `output_tokens_stdev`: Standard deviation for output tokens. If not supplied and min/max are not specified, no deviation is applied. If not supplied and min/max are specified, a uniform distribution is used.
 - `output_tokens_min`: Minimum number of tokens in outputs. If unset and `output_tokens_stdev` is set, the minimum is 1.
 - `output_tokens_max`: Maximum number of tokens in outputs. If unset and `output_tokens_stdev` is set, the maximum is 5 times the standard deviation.
+- `prefix_tokens`: Number of tokens to share as a prefix across all prompts. This is additive to the prompt tokens distribution, so each request's total prompt length is `prefix_tokens + prompt_tokens_sample()`. If unset, defaults to 0.
 - `samples`: Number of samples to generate (default: 1000). More samples will increase the time taken to generate the dataset before benchmarking, but will also decrease the likelihood of caching requests.
 - `source`: Source text for generation (default: `data:prideandprejudice.txt.gz`). This can be any text file, URL containing a text file, or a compressed text file. The text is used to sample from at a word and punctuation granularity and then combined into a single string of the desired lengths.
diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py
index 9868ab52..94dd3aa6 100644
--- a/src/guidellm/dataset/synthetic.py
+++ b/src/guidellm/dataset/synthetic.py
@@ -25,6 +25,11 @@ class SyntheticDatasetConfig(BaseModel):
+    prefix_tokens: int = Field(
+        description="The number of shared prefix tokens to prepend to each prompt.",
+        ge=0,
+        default=0,
+    )
     prompt_tokens: int = Field(
         description="The average number of text tokens generated for prompts.",
         gt=0,
@@ -164,21 +169,28 @@ def __iter__(
         # ensure diff distribution from output tokens
         rand = random.Random(self.random_seed + 2)  # noqa: S311
 
+        prefix_index = rand.randint(0, len(self.text_creator.words))
+        prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
+
         for _, prompt_tokens, output_tokens in zip(
             range(self.config.samples),
             prompt_tokens_sampler,
             output_tokens_sampler,
         ):
             start_index = rand.randint(0, len(self.text_creator.words))
+            prompt_text = self.processor.decode(
+                prefix_tokens + self._create_prompt(prompt_tokens, start_index),
+                skip_special_tokens=True,
+            )
             yield {
-                "prompt": self._create_prompt(prompt_tokens, start_index),
-                "prompt_tokens_count": prompt_tokens,
+                "prompt": prompt_text,
+                "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
                 "output_tokens_count": output_tokens,
             }
 
-    def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
+    def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]:
         if prompt_tokens <= 0:
-            return ""
+            return []
 
         left = start_index
         right = start_index + 4 * prompt_tokens
@@ -186,16 +198,17 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
         while left < right:
             mid = (left + right) // 2
             test_prompt = self.text_creator.create_text(start_index, mid - start_index)
-            test_tokens = len(self.processor.tokenize(test_prompt))
+            test_tokens = self.processor.encode(test_prompt)
 
-            if test_tokens == prompt_tokens:
-                return test_prompt
-            elif test_tokens < prompt_tokens:
+            if len(test_tokens) == prompt_tokens:
+                return test_tokens
+            elif len(test_tokens) < prompt_tokens:
                 left = mid + 1
             else:
                 right = mid
 
-        return self.text_creator.create_text(start_index, left - start_index)
+        final_text = self.text_creator.create_text(start_index, left - start_index)
+        return self.processor.encode(final_text)
 
 
 class SyntheticDatasetCreator(DatasetCreator):
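
A minimal, self-contained sketch of the prompt construction this diff introduces, assuming a Hugging Face tokenizer (the `gpt2` checkpoint is an arbitrary choice here). `words` and `create_text` are hypothetical stand-ins for `EndlessTextCreator`, which is not part of this diff; the binary search widens or narrows the word window until the encoded length matches the requested token count, and the shared prefix is built once and prepended to every per-request token list before decoding:

```python
from transformers import AutoTokenizer

processor = AutoTokenizer.from_pretrained("gpt2")
words = (
    "It is a truth universally acknowledged that a single man in possession "
    "of a good fortune must be in want of a wife".split()
)

def create_text(start: int, length: int) -> str:
    # EndlessTextCreator samples at word granularity and wraps around its
    # source text; emulate that with a modulo over the toy word list.
    return " ".join(words[(start + i) % len(words)] for i in range(max(length, 0)))

def create_prompt(prompt_tokens: int, start_index: int) -> list[int]:
    # Binary-search the word-window width until the encoded length hits the
    # requested token count; 4 * prompt_tokens words is a generous upper bound.
    if prompt_tokens <= 0:
        return []
    left, right = start_index, start_index + 4 * prompt_tokens
    while left < right:
        mid = (left + right) // 2
        tokens = processor.encode(create_text(start_index, mid - start_index))
        if len(tokens) == prompt_tokens:
            return tokens
        if len(tokens) < prompt_tokens:
            left = mid + 1
        else:
            right = mid
    # No exact hit: return the closest window found by the search.
    return processor.encode(create_text(start_index, left - start_index))

# The shared prefix is built once, then prepended to every per-request prompt
# before decoding, mirroring the new __iter__ logic:
prefix = create_prompt(8, 0)
prompt_ids = prefix + create_prompt(32, 5)
print(processor.decode(prompt_ids, skip_special_tokens=True))
```

Working in token-id space (`encode`/`decode`) rather than concatenating strings is what guarantees the prefix contributes exactly `prefix_tokens` tokens; joining two strings and re-tokenizing could merge tokens at the boundary.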
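To illustrate the additive accounting described in the docs change, a hypothetical config (field names from `SyntheticDatasetConfig` above; `output_tokens=128` is included on the assumption that it may be required, and the remaining fields are assumed to carry defaults):

```python
from guidellm.dataset.synthetic import SyntheticDatasetConfig

config = SyntheticDatasetConfig(prefix_tokens=32, prompt_tokens=256, output_tokens=128)
# Every request then carries the same 32 prefix tokens plus a sample drawn
# from the prompt-tokens distribution, and reports their sum:
#   prompt_tokens_count == config.prefix_tokens + sampled_prompt_tokens
```

Because the prefix is identical across all requests, raising `prefix_tokens` is a natural way to exercise server-side prefix caching while the per-request tails stay distinct.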