WIP: Multiturn Benchmarking Support #211

Open · wants to merge 6 commits into base: feat/unified_scheduler

Changes from all commits
77 changes: 52 additions & 25 deletions src/guidellm/dataset/synthetic.py
@@ -2,7 +2,7 @@
import random
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, TypedDict, Union

import yaml
from datasets import (
@@ -63,6 +63,26 @@ class SyntheticDatasetConfig(BaseModel):
gt=0,
default=None,
)
turns: int = Field(
description="The number of turns in the conversation.",
gt=0,
default=1,
)
turns_stdev: Optional[int] = Field(
description="The standard deviation of the number of turns.",
gt=0,
default=None,
)
turns_min: Optional[int] = Field(
description="The minimum number of turns in the conversation.",
gt=0,
default=None,
)
turns_max: Optional[int] = Field(
description="The maximum number of turns in the conversation.",
gt=0,
default=None,
)
samples: int = Field(
description="The number of samples to generate for the dataset.",
gt=0,
@@ -118,14 +138,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
return SyntheticDatasetConfig(**config_dict)


class SyntheticTextItemsGenerator(
Iterable[
dict[
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
Union[str, int],
]
]
):
class SyntheticDatasetRow(TypedDict):
prompt: list[str]
prompt_tokens_count: list[int]
output_tokens_count: list[int]


class SyntheticTextItemsGenerator(Iterable[SyntheticDatasetRow]):
def __init__(
self,
config: SyntheticDatasetConfig,
@@ -141,12 +160,7 @@ def __init__(

def __iter__(
self,
) -> Iterator[
dict[
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
Union[str, int],
]
]:
) -> Iterator[SyntheticDatasetRow]:
prompt_tokens_sampler = IntegerRangeSampler(
average=self.config.prompt_tokens,
variance=self.config.prompt_tokens_stdev,
@@ -161,20 +175,33 @@ def __iter__(self):
max_value=self.config.output_tokens_max,
random_seed=self.random_seed + 1, # ensure diff dist from prompts
)
turns_sampler = IntegerRangeSampler(
average=self.config.turns,
variance=self.config.turns_stdev,
min_value=self.config.turns_min,
max_value=self.config.turns_max,
random_seed=self.random_seed + 7, # ensure diff dist
)
# ensure diff distribution from output tokens
rand = random.Random(self.random_seed + 2) # noqa: S311

for _, prompt_tokens, output_tokens in zip(
range(self.config.samples),
prompt_tokens_sampler,
output_tokens_sampler,
):
start_index = rand.randint(0, len(self.text_creator.words))
yield {
"prompt": self._create_prompt(prompt_tokens, start_index),
"prompt_tokens_count": prompt_tokens,
"output_tokens_count": output_tokens,
for _, turns in zip(range(self.config.samples), turns_sampler):
row: SyntheticDatasetRow = {
"prompt": [],
"prompt_tokens_count": [],
"output_tokens_count": [],
}
for _, prompt_tokens, output_tokens in zip(
range(turns),
prompt_tokens_sampler,
output_tokens_sampler,
):
start_index = rand.randint(0, len(self.text_creator.words))
row["prompt"].append(self._create_prompt(prompt_tokens, start_index))
row["prompt_tokens_count"].append(prompt_tokens)
row["output_tokens_count"].append(output_tokens)

yield row

def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
if prompt_tokens <= 0:
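With the new `turns*` fields, each synthetic sample becomes a whole conversation rather than a single prompt. Below is a minimal sketch of a multiturn config and the row shape it yields; the numbers are illustrative, and it assumes the generator's other constructor arguments (processor, random seed) are wired up as before:

```python
config = SyntheticDatasetConfig(
    prompt_tokens=256,   # average tokens per prompt
    output_tokens=128,   # average tokens per response
    samples=100,
    turns=3,             # new: average turns per conversation
    turns_min=2,         # new: optional clamps on the sampled turn count
    turns_max=5,
)

# Each row yielded by SyntheticTextItemsGenerator is a SyntheticDatasetRow
# whose lists are index-aligned, one entry per turn, e.g.:
# {
#     "prompt": ["<turn 1 text>", "<turn 2 text>", "<turn 3 text>"],
#     "prompt_tokens_count": [251, 263, 248],
#     "output_tokens_count": [130, 126, 131],
# }
```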
47 changes: 47 additions & 0 deletions src/guidellm/preprocess/item.py
@@ -0,0 +1,47 @@
from collections.abc import Sequence
from typing import Generic, Optional, TypeVar

from pydantic import Field

from guidellm.objects.pydantic import StandardBaseModel

PromptT = TypeVar("PromptT")


class Item(StandardBaseModel, Generic[PromptT]):
"""
Represents a single item in a dataset,
containing a prompt and its associated metadata.
"""

value: PromptT = Field(
description="The prompt text or data for the item.",
examples=[
"What is the capital of France?",
"Explain quantum computing in simple terms.",
],
)
prompt_tokens: Optional[int] = Field(
default=None, gt=0, description="Number of tokens in the prompt"
)
output_tokens: Optional[int] = Field(
default=None, gt=0, description="Number of tokens in the output"
)


class ItemList(Sequence[Item[PromptT]]):
"""
Represents a list of items, each containing a prompt and its metadata.
"""

shared_prefix: Optional[PromptT]

def __init__(self, *items: Item[PromptT], shared_prefix: Optional[PromptT] = None):
self.shared_prefix = shared_prefix
self._items = list(items)

def __getitem__(self, key):
return self._items[key]

def __len__(self) -> int:
return len(self._items)
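`Item` and `ItemList` are the building blocks the loader and session changes below consume. A short usage sketch (the conversation content is illustrative):

```python
from guidellm.preprocess.item import Item, ItemList

# One three-turn conversation; per-item token counts are optional metadata.
conversation = ItemList(
    Item(value="What is the capital of France?", output_tokens=32),
    Item(value="And what is its population?", output_tokens=32),
    Item(value="Name one landmark there.", output_tokens=48),
    shared_prefix=None,  # presumably a prefix (e.g. a system prompt) shared by all turns
)

assert len(conversation) == 3       # via __len__
first = conversation[0]             # Sequence indexing via __getitem__
assert first.prompt_tokens is None  # counts default to None
```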
25 changes: 12 additions & 13 deletions src/guidellm/request/loader.py
@@ -11,10 +11,9 @@
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from transformers import PreTrainedTokenizerBase # type: ignore[import]

from guidellm.config import settings
from guidellm.dataset import ColumnInputTypes, load_dataset
from guidellm.objects import StandardBaseModel
from guidellm.request.request import GenerationRequest
from guidellm.preprocess.item import Item, ItemList
from guidellm.request.session import GenerativeRequestSession

__all__ = [
@@ -113,7 +112,7 @@ def __iter__(self) -> Iterator[GenerativeRequestSession]:
scope_create_count += 1

for item in dataset_iter:
yield GenerativeRequestSession(self._create_request(item))
yield GenerativeRequestSession(self._create_items(item))

self._preserved_iter = None

@@ -261,7 +260,8 @@ def _get_dataset_iter(

return dataset_iter

def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
def _create_items(self, item: dict[str, Any]) -> ItemList:
prompts = item[self.column_mappings["prompt_column"]]
prompt_tokens = (
item[self.column_mappings["prompt_tokens_count_column"]]
if "prompt_tokens_count_column" in self.column_mappings
@@ -273,13 +273,12 @@ def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
else None
)

return GenerationRequest(
request_type=settings.preferred_route,
content=item[self.column_mappings["prompt_column"]],
stats=(
{"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
),
constraints=(
{"output_tokens": output_tokens} if output_tokens is not None else {}
),
items = (
Item(value=prompt, output_tokens=out_t, prompt_tokens=in_t)
for prompt, in_t, out_t in zip(
prompts if isinstance(prompts, list) else [prompts],
prompt_tokens if isinstance(prompt_tokens, list) else [prompt_tokens],
output_tokens if isinstance(output_tokens, list) else [output_tokens],
)
)
return ItemList(*items)
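`_create_items` accepts both single-prompt rows (scalar columns) and multiturn rows (list columns) by wrapping scalars in one-element lists before zipping. One caveat: `zip` stops at its shortest input, so a multi-prompt row whose token-count column is absent (`None` rather than a list) would collapse to a single `Item`. A sketch of a broadcast along the prompt axis that avoids this; `_broadcast` and `build_items` are hypothetical helpers, not part of the diff:

```python
from itertools import repeat
from typing import Any

from guidellm.preprocess.item import Item, ItemList


def _broadcast(value: Any, n: int):
    # Pass per-turn lists through unchanged; repeat a scalar (or None)
    # once per prompt so zip() never truncates a multiturn row.
    return iter(value) if isinstance(value, list) else repeat(value, n)


def build_items(prompts: Any, prompt_tokens: Any, output_tokens: Any) -> ItemList:
    prompts = prompts if isinstance(prompts, list) else [prompts]
    n = len(prompts)
    return ItemList(*(
        Item(value=p, prompt_tokens=in_t, output_tokens=out_t)
        for p, in_t, out_t in zip(
            prompts,
            _broadcast(prompt_tokens, n),
            _broadcast(output_tokens, n),
        )
    ))
```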
67 changes: 53 additions & 14 deletions src/guidellm/request/session.py
@@ -1,15 +1,18 @@
import itertools
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import TYPE_CHECKING, Generic

if TYPE_CHECKING:
from collections.abc import Sequence

from guidellm.backend.response import ResponseSummary
from guidellm.config import settings
from guidellm.preprocess.item import Item, ItemList
from guidellm.request.request import GenerationRequest
from guidellm.scheduler.types import RequestT, ResponseT

__all__ = ["GenerativeRequestSession", "RequestSession"]

# TODO: Replace with specific types that implement needed features
RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")


class RequestSession(ABC, Generic[RequestT, ResponseT]):
@abstractmethod
@@ -29,24 +32,60 @@ def push_response(self, response: ResponseT) -> None: ...
def complete(self) -> bool: ...


# TODO: Implement multiturn support
class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
def __init__(self, request: GenerationRequest) -> None:
self.request = request
self._complete = False
def __init__(self, items: ItemList) -> None:
if len(items) < 1:
raise ValueError("Prompts cannot be empty")

self.prompts: Sequence[Item] = items
self.responses: list[Item] = []

def __len__(self) -> int:
return 1
return len(self.prompts)

def get_next_request(self) -> GenerationRequest:
return self.request
completed_responses = len(self.responses)

# FIXME: Can only handle string requests
content = "".join(
itertools.chain.from_iterable(
(x.value, y.value)
for x, y in zip(self.prompts, self.responses + [Item(value="")])
)
)

prev_prompt_tokens = sum(
(x.prompt_tokens or 0) + (x.output_tokens or 0) for x in self.responses
)
prompt_tokens = (
self.prompts[completed_responses].prompt_tokens or 0
) + prev_prompt_tokens

output_tokens = self.prompts[completed_responses].output_tokens

return GenerationRequest(
request_type=settings.preferred_route,
content=content,
stats=({"prompt_tokens": prompt_tokens} if prompt_tokens else {}),
constraints=({"output_tokens": output_tokens} if output_tokens else {}),
)

def get_next_delay(self) -> float:
return 0.0

def push_response(self, response: ResponseSummary) -> None: # noqa: ARG002
self._complete = True
def push_response(self, response: ResponseSummary) -> None:
if len(self.responses) < len(self.prompts):
resp = Item(
value=response.value,
prompt_tokens=response.response_prompt_tokens
or response.request_prompt_tokens,
output_tokens=response.response_output_tokens
or response.request_output_tokens,
)
self.responses.append(resp)
else:
raise ValueError("Response list full")

@property
def complete(self) -> bool:
return self._complete
return len(self.responses) >= len(self.prompts)
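Because each turn resends the transcript so far, `get_next_request` folds the prompt and output token counts of every completed turn into the next request's `prompt_tokens` stat. A sketch of the accounting and the driving loop; `backend.submit` is a stand-in for whatever executes the request and returns a `ResponseSummary`:

```python
# Token accounting (illustrative): if each recorded turn logged roughly
# 100 prompt and 50 output tokens, the request issued after two completed
# turns reports prompt_tokens = 100 + (100 + 50) + (100 + 50) = 400.

session = GenerativeRequestSession(conversation)  # the ItemList from above
while not session.complete:
    request = session.get_next_request()  # prior turns + next prompt as one string
    response = backend.submit(request)    # hypothetical backend call
    session.push_response(response)       # records the turn; complete once all prompts answered
```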