diff --git a/docs/guides/index.md b/docs/guides/index.md index a85ef6294..7005e268a 100644 --- a/docs/guides/index.md +++ b/docs/guides/index.md @@ -68,6 +68,14 @@ Whether you're interested in understanding the system architecture, exploring su [:octicons-arrow-right-24: Over-Saturation Guide](over_saturation_stopping.md) +- :material-wrench:{ .lg .middle } Tool Calling + + ______________________________________________________________________ + + Benchmark multi-turn tool calling workloads with pre-anticipated tool call turns, synthetic data, and dataset-driven tool definitions. + + [:octicons-arrow-right-24: Tool Calling Guide](tool_calling.md) + - :material-image-multiple:{ .lg .middle } Multimodal Benchmarking ______________________________________________________________________ diff --git a/docs/guides/multiturn.md b/docs/guides/multiturn.md index c32f5d420..35dd629ff 100644 --- a/docs/guides/multiturn.md +++ b/docs/guides/multiturn.md @@ -173,6 +173,10 @@ When enabled, GuideLLM sends only the current turn's input and references the pr - If the server does not support response storage, requests on turn 2+ will fail with an error (typically a 404). - This option is only valid with `/v1/responses`. Using it with other request formats raises an error at startup. +## Tool Calling + +Multi-turn tool calling is supported as part of multi-turn benchmarks. See the dedicated [Tool Calling Guide](tool_calling.md) for full documentation on server setup, tool definitions, tool choice configuration, and edge cases. + ## The TurnPivot Preprocessor GuideLLM supports passing multiple `--data` options, each pointing to a separate dataset. If there are matches for the same column type across multiple datasets, they are treated as separate batches. Normally this is useful for layering columns from different datasets within the same request. For example adding a text column from one dataset to another with images or combining multiple normally-distributed synthetic datasets into a multimodal distribution. We can use the **TurnPivot** preprocessor to transpose turn columns and dataset batches. diff --git a/docs/guides/tool_calling.md b/docs/guides/tool_calling.md new file mode 100644 index 000000000..a84769140 --- /dev/null +++ b/docs/guides/tool_calling.md @@ -0,0 +1,168 @@ +# Tool Calling + +GuideLLM supports benchmarking multi-turn tool calling workloads. Tool call turns are **pre-anticipated**: the data pipeline decides upfront which turns expect a tool call and which expect plain text. GuideLLM does not dynamically create follow-up turns at runtime. Instead, the full conversation structure is planned during data generation, and the worker executes each turn in order, with each tool call being scheduled like any other turn by the profile. + +When a tool-call turn completes, GuideLLM appends a tool result to the conversation history and proceeds to the next pre-planned turn. The tool result content comes from one of three sources (in priority order): the dataset's tool response column, synthetic data configured via `tool_response_tokens`, or a short placeholder (`{"status": "ok"}`). All turns where a tool call is not anticipated have `tool_choice` overridden to `"none"` for predictability. + +## Mocked client-side tool calls + +GuideLLM currently supports mocked client-side tool calls. This means that the inference server runs the model and may return real `tool_calls`, but GuideLLM **does not execute** those functions against live APIs or other runtimes. 
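+Instead, GuideLLM fabricates the tool result itself. In the follow-up request it appears as a `role: "tool"` message like this sketch (`call_abc123` stands in for the id the model actually returned; the content comes from the sources listed above):
+
+```json
+{"role": "tool", "tool_call_id": "call_abc123", "content": "{\"status\": \"ok\"}"}
+```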
+The benchmark worker acts as the mock client here: after each tool-call turn it injects that `role: "tool"` message into the client-side chat history for the following request. This isolates LLM throughput with tool-call handling from external tool latency and side effects.
+
+## Server Setup
+
+Tool calling requires server-side support. For vLLM, enable auto tool choice and a parser matching your model:
+
+```bash
+vllm serve Qwen/Qwen3-0.6B \
+  --enable-auto-tool-choice \
+  --tool-call-parser hermes
+```
+
+Common parsers: `hermes` (Qwen/Hermes), `llama3_json` (Llama 3.x), `mistral` (Mistral). Without these flags, vLLM will reject tool-call requests with grammar errors.
+
+## Providing Tool Definitions
+
+Tool definitions are always provided through the data pipeline rather than as a global CLI flag. There are two ways to supply them:
+
+**1. Synthetic data** -- set `tool_call_turns` (and optionally `tools`) in the data configuration:
+
+```bash
+guidellm benchmark run \
+  --target "http://localhost:8000" \
+  --model "Qwen/Qwen3-0.6B" \
+  --request-format /v1/chat/completions \
+  --data '{"prompt_tokens": 200, "output_tokens": 100, "turns": 3, "tool_call_turns": 2}' \
+  --max-requests 30 \
+  --profile constant \
+  --rate 1
+```
+
+To specify non-contiguous tool-call turns, pass a list of 0-based turn indices:
+
+```bash
+guidellm benchmark run \
+  --target "http://localhost:8000" \
+  --model "Qwen/Qwen3-0.6B" \
+  --request-format /v1/chat/completions \
+  --data '{"prompt_tokens": 200, "output_tokens": 100, "turns": 4, "tool_call_turns": [0, 2]}' \
+  --max-requests 30 \
+  --profile constant \
+  --rate 1
+```
+
+Synthetic data configuration fields for tool calling:
+
+| Field                        | Type               | Default | Description |
+| ---------------------------- | ------------------ | ------- | ----------- |
+| `tool_call_turns`            | `int \| list[int]` | `0`     | Which turns include tool definitions and expect tool-call responses. An int N means "the first N turns"; a list of ints specifies explicit 0-based turn indices (e.g. `[0, 2]`). Normalized to a sorted list internally. When `0` or `[]`, no tool calling. |
+| `tools`                      | `list`             | `None`  | Tool definitions in OpenAI format. When `None`, a built-in placeholder tool is used. Custom definitions can be provided inline: `"tools": [{"type": "function", ...}]`. |
+| `tool_response_tokens`       | `int`              | `None`  | Average number of tokens for synthetic tool call responses. When `None`, a short placeholder (`{"status": "ok"}`) is used. |
+| `tool_response_tokens_stdev` | `int`              | `None`  | Standard deviation for tool response token count. |
+| `tool_response_tokens_min`   | `int`              | `None`  | Minimum number of tokens for tool response. |
+| `tool_response_tokens_max`   | `int`              | `None`  | Maximum number of tokens for tool response. |
+
+Note: The token count applies to the content field of the mock tool response; the surrounding JSON structure adds roughly 5 more tokens.
+
+**Configuring tool response content** -- by default, tool results use a short placeholder (`{"status": "ok"}`). This default can be changed via the `GUIDELLM__DEFAULT_SYNTHETIC_TOOL_RESPONSE` environment variable.
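+For example, to swap in a different static placeholder (the value shown is just an illustration; any string works):
+
+```bash
+export GUIDELLM__DEFAULT_SYNTHETIC_TOOL_RESPONSE='{"result": "success"}'
+```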
+For more realistic benchmarks, set `tool_response_tokens` to generate variable-length JSON responses:
+
+```bash
+guidellm benchmark run \
+  --target "http://localhost:8000" \
+  --model "Qwen/Qwen3-0.6B" \
+  --request-format /v1/chat/completions \
+  --data '{"prompt_tokens": 200, "output_tokens": 100, "turns": 3, "tool_call_turns": 2, "tool_response_tokens": 50}' \
+  --max-requests 30 \
+  --profile constant \
+  --rate 1
+```
+
+The `tool_response_tokens_stdev`, `tool_response_tokens_min`, and `tool_response_tokens_max` fields work identically to the corresponding `prompt_tokens_*` / `output_tokens_*` variance parameters.
+
+**2. Datasets with a tools column** -- datasets that already contain tool definitions (e.g. `madroid/glaive-function-calling-openai`) work directly. The column mapper auto-detects columns named `tools`, `functions`, or `tool_definitions`:
+
+```bash
+guidellm benchmark run \
+  --target "http://localhost:8000" \
+  --data "madroid/glaive-function-calling-openai" \
+  --data-column-mapper '{"text_column": "messages", "tools_column": "tools"}' \
+  --data-preprocessors "tool_calling_message_extractor,encode_media" \
+  --max-requests 50 \
+  --profile constant \
+  --rate 1
+```
+
+The `tool_calling_message_extractor` preprocessor must be explicitly enabled via `--data-preprocessors` (it is not included by default). It parses each row's `messages` array and extracts prompts, system messages, and tool results into the appropriate columns. If the dataset has no tool result messages, the placeholder (`{"status": "ok"}`) is used as a fallback.
+
+## Tool Choice and Missing Tool Call Behavior
+
+Two backend settings control how tool-call turns are handled at runtime. Both are configured via `--backend-kwargs`:
+
+| Setting                      | Values                                         | Default      | Description |
+| ---------------------------- | ---------------------------------------------- | ------------ | ----------- |
+| `extras.body.tool_choice`    | `required`, `auto`, `none`                     | `required`   | Sent as the `tool_choice` API parameter on turns that expect a tool call. On non-tool turns, it is always overridden to `none`. |
+| `tool_call_missing_behavior` | `ignore_continue`, `ignore_stop`, `error_stop` | `error_stop` | What the backend does when a tool call was expected but the model produced plain text instead. |
+
+**Setting `tool_choice` via `--backend-kwargs`:**
+
+```bash
+guidellm benchmark run \
+  --target "http://localhost:8000" \
+  --data '{"prompt_tokens": 200, "output_tokens": 100, "turns": 3, "tool_call_turns": 2}' \
+  --backend-kwargs '{"extras": {"body": {"tool_choice": "auto"}}}' \
+  --max-requests 30
+```
+
+**Setting `tool_call_missing_behavior` via `--backend-kwargs`:**
+
+```bash
+guidellm benchmark run \
+  --target "http://localhost:8000" \
+  --data '{"prompt_tokens": 200, "output_tokens": 100, "turns": 3, "tool_call_turns": 2}' \
+  --backend-kwargs '{"tool_call_missing_behavior": "ignore_continue", "extras": {"body": {"tool_choice": "auto"}}}' \
+  --max-requests 30
+```
+
+**`tool_choice` implications:**
+
+- `required` (default) -- the model **must** produce a tool call. This gives the most predictable benchmarks and the fewest errors, since the server constrains the output to valid tool call JSON. Use this when you don't want to rely on the model choosing to use tools. However, it may slow the server down, because forcing a tool call can push the model toward low-probability tokens.
+- `auto` -- the model decides whether to call a tool. Useful for testing how often a model chooses to invoke tools, but increases the chance of missing tool calls (see `tool_call_missing_behavior`).
+- `none` -- tools are present in the request but the model cannot call them. This is set automatically on turns that don't expect a tool call, including the final plain-text turn; setting it globally disables tool calling entirely.
+
+Note that `required` vs `auto` can also result in different model behavior. For example, the Qwen models only show their pre-tool-call thinking with `auto`.
+
+**`tool_call_missing_behavior` implications:**
+
+This setting only matters when `tool_choice` is `auto` (or `required` on a server that doesn't enforce it):
+
+- `error_stop` (default) -- the current turn is marked as **errored** and all remaining turns are cancelled. Surfaces problems immediately. Best for validating that the model and server are correctly configured.
+- `ignore_stop` -- the current turn is marked as **cancelled** (incomplete) and all remaining turns are cancelled. The model's response is preserved in the output, but the turn's status reflects that the expected tool call was not produced. Use this when a missing tool call means the conversation can't continue meaningfully but shouldn't be treated as an error.
+- `ignore_continue` -- the current turn is treated as **completed** and the conversation continues to the next turn. Each future tool-call turn is evaluated independently. Use this when you want to measure how many tool calls actually happen under `auto` mode without aborting the conversation.
+
+### Recommended scenarios
+
+| Tool Choice | Missing Behavior  | Current Turn Status | Description |
+| ----------- | ----------------- | ------------------- | ----------- |
+| `required`  | `error_stop`      | errored             | (default) Good for consistent and predictable behavior with synthetic data. May slow down the server. |
+| `auto`      | `ignore_continue` | completed           | Good for testing `auto` behavior without errors when the model declines to call a tool. |
+| `auto`      | `ignore_stop`     | cancelled           | Good for testing `auto` behavior, but ends the conversation early once the model produces a non-tool response. |
+
+## Output Token Limits on Tool-Call Turns
+
+When `output_tokens` is configured (either via synthetic data or a dataset column), GuideLLM normally sets `ignore_eos=True` and clears stop sequences to force the model to generate exactly N tokens. On tool-call turns, these settings are **automatically removed** because they are incompatible with vLLM's constrained decoding grammar:
+
+- **`ignore_eos`** conflicts with the grammar's terminal state. Constrained decoding guides token selection via a finite-state machine that marks EOS as the only valid token once the JSON is complete. `ignore_eos` masks out EOS, creating an impossible state with no valid tokens, which causes server errors or runaway generation.
+- **`stop=None`** removes stop sequences that the tool-call parser may rely on internally (e.g. `<|eot_id|>` for Llama models).
+- **`max_tokens` / `max_completion_tokens`** would truncate mid-JSON, producing invalid tool call arguments that corrupt the conversation history on follow-up turns.
+
+As a result, tool-call turns generate output whose length is determined by the model and tool schema rather than the configured `output_tokens`. The model stops as soon as it produces valid JSON for the function name and arguments.
This typically results in shorter output (20–80 tokens) compared to the configured target. + +Plain-text turns (where `tool_choice="none"`) are unaffected and continue to respect `output_tokens` / `ignore_eos` as normal, even when tool definitions are present in the request. + +## Edge Cases + +- **Single-turn tool calling** (`turns=1, tool_call_turns=1` or `tool_call_turns=[0]`) is supported. The conversation has one turn that expects a tool call and no plain-text response. +- **All-tool conversations** (e.g. `tool_call_turns=3` with `turns=3`, or `tool_call_turns=[0,1,2]`) are supported. Every turn is a tool-call turn and the model never produces a final plain-text response. The `output` field in `benchmarks.json` will be `None` for every request; use the `tool_calls` field to inspect model output. +- **Non-contiguous tool turns** (e.g. `tool_call_turns=[0, 2]` with `turns=4`) are supported. Only the specified turns expect tool calls; other turns produce plain text. +- **Tool definitions on non-tool turns** are still sent in the request (they're part of the data), but `tool_choice` is forced to `none` so the model produces text. This matches real-world agentic patterns where the tools remain available but the model is instructed to respond in natural language. +- **Mixed datasets** where only some rows have a `tools_column` work correctly. Rows without tools are treated as plain text conversations; rows with tools follow the tool-call flow. +- **Rate-limited profiles** (e.g. `--profile constant --rate 1`) pace follow-up tool turns through the same scheduler as any other request. The follow-up turn is requeued and waits for the next available scheduling slot, so the effective delay between turns is determined by the profile, not by the tool calling logic. diff --git a/src/guidellm/backends/openai/http.py b/src/guidellm/backends/openai/http.py index 527bb53aa..9536e9402 100644 --- a/src/guidellm/backends/openai/http.py +++ b/src/guidellm/backends/openai/http.py @@ -13,7 +13,7 @@ import asyncio import time from collections.abc import AsyncIterator -from typing import Any +from typing import Any, Literal import httpx from pydantic import Field, field_validator @@ -82,6 +82,17 @@ class OpenAIHttpBackendArgs(BackendArgs): "multi-turn requests. Only supported with /v1/responses." ), ) + tool_call_missing_behavior: Literal[ + "ignore_continue", "ignore_stop", "error_stop" + ] = Field( + default="error_stop", + description=( + "What happens when a tool call is expected but the model does not " + "produce one. Options: ignore_continue (continue to next turn), " + "ignore_stop (cancel remaining turns), error_stop (error and " + "cancel remaining turns)." + ), + ) @field_validator("request_format") @classmethod @@ -175,6 +186,7 @@ def __init__( max_tokens: int | None = None, max_completion_tokens: int | None = None, server_history: bool = False, + tool_call_missing_behavior: str = "error_stop", ): """ Initialize OpenAI HTTP backend with server configuration. @@ -191,6 +203,9 @@ def __init__( :param validate_backend: Backend validation configuration :param server_history: Use server-side conversation history (previous_response_id) for multi-turn. Only with /v1/responses. + :param tool_call_missing_behavior: What happens when a tool call is + expected but missing. One of: ignore_continue, ignore_stop, + error_stop. 
""" super().__init__(type_="openai_http") @@ -240,6 +255,7 @@ def __init__( else extras ) self.max_tokens: int | None = max_tokens or max_completion_tokens + self.tool_call_missing_behavior: str = tool_call_missing_behavior # Runtime state self._in_process = False @@ -401,6 +417,7 @@ async def resolve( # type: ignore[override, misc] extras=self.extras, max_tokens=self.max_tokens, server_history=self.server_history, + turn_index=request_info.turn_index, ) request_url = f"{self.target}/{request_path}" @@ -431,10 +448,11 @@ async def resolve( # type: ignore[override, misc] request_info.timings.request_end = time.time() response.raise_for_status() data = response.json() - yield ( - request_handler.compile_non_streaming(request, arguments, data), - request_info, + gen_response = request_handler.compile_non_streaming( + request, arguments, data ) + yield gen_response, request_info + self._check_tool_call_expectations(request, gen_response) return try: @@ -480,7 +498,9 @@ async def resolve( # type: ignore[override, misc] request_info.timings.token_iterations += iterations request_info.timings.request_end = time.time() - yield request_handler.compile_streaming(request, arguments), request_info + gen_response = request_handler.compile_streaming(request, arguments) + yield gen_response, request_info + self._check_tool_call_expectations(request, gen_response) except asyncio.CancelledError as err: # Yield current result to store iterative results before propagating yield request_handler.compile_streaming(request, arguments), request_info @@ -522,6 +542,37 @@ def _build_headers( return headers or None + def _check_tool_call_expectations( + self, + request: GenerationRequest, + response: GenerationResponse, + ) -> None: + """Validate that a tool-call turn actually produced tool calls. + + Called after the final yield in ``resolve`` so the response is + delivered to the worker before any exception propagates. When the + request expected a tool call but the model didn't produce one, + raises an exception according to ``tool_call_missing_behavior``: + + * ``ignore_continue`` -- no-op; the conversation proceeds normally. + * ``ignore_stop`` -- raises :class:`asyncio.CancelledError` so the + worker cancels remaining turns. + * ``error_stop`` -- raises :class:`ValueError` so the worker marks + the current turn as errored and cancels remaining turns. + + :param request: The generation request that was resolved. + :param response: The compiled response from the model. 
+ """ + if not request.expects_tool_call or response.tool_calls: + return + + if self.tool_call_missing_behavior == "ignore_continue": + pass + elif self.tool_call_missing_behavior == "ignore_stop": + raise asyncio.CancelledError("Expected tool call but model produced none") + elif self.tool_call_missing_behavior == "error_stop": + raise ValueError("Expected tool call but model produced none") + def _resolve_validate_kwargs( self, validate_backend: bool | str | dict[str, Any] ) -> dict[str, Any] | None: diff --git a/src/guidellm/backends/openai/request_handlers.py b/src/guidellm/backends/openai/request_handlers.py index d833bdff5..04a9cdc88 100644 --- a/src/guidellm/backends/openai/request_handlers.py +++ b/src/guidellm/backends/openai/request_handlers.py @@ -18,6 +18,8 @@ from guidellm.scheduler import HistoryT from guidellm.schemas import GenerationRequest, GenerationResponse, UsageMetrics from guidellm.schemas.request import GenerationRequestArguments +from guidellm.schemas.tool_call import StreamingToolCall, StreamingToolCallFunction +from guidellm.settings import settings from guidellm.utils.imports import json from guidellm.utils.registry import RegistryMixin @@ -29,6 +31,8 @@ "OpenAIRequestHandlerFactory", "PoolingRequestHandler", "ResponsesRequestHandler", + "StreamingToolCall", + "StreamingToolCallFunction", "TextCompletionsRequestHandler", ] @@ -407,7 +411,12 @@ class ChatCompletionsRequestHandler(TextCompletionsRequestHandler): def __init__(self): super().__init__() + # Indices seen so far, used to count distinct tool calls self.streaming_tool_call_indices: set[int] = set() + # Full tool call payloads accumulated across streaming deltas, + # keyed by the delta ``index`` field. Needed for multi-turn tool + # calling so the response carries the id/name/arguments of each call. + self.streaming_tool_calls: dict[int, StreamingToolCall] = {} def _format_prompts( self, column_data: list[dict[str, Any]], column_type: str @@ -451,7 +460,95 @@ def _format_prompts( return formatted_data - def format( # noqa: C901 + @staticmethod + def _build_tool_response_messages( + tool_calls: list[StreamingToolCall], + tool_response_columns: list[Any], + ) -> list[dict[str, Any]]: + """Build synthetic ``role: "tool"`` messages for each tool call. + + Uses per-request tool response content from ``tool_response_columns`` + when available, falling back to + :attr:`settings.default_synthetic_tool_response`. + + :param tool_calls: The tool call objects from the prior assistant response. + :param tool_response_columns: Per-tool-call response content from the + dataset, which may be ``str`` or ``bytes`` (orjson). + :return: List of tool-role message dicts ready to append to messages. + """ + messages: list[dict[str, Any]] = [] + for idx, tc in enumerate(tool_calls): + raw_content = ( + tool_response_columns[idx] + if idx < len(tool_response_columns) + else settings.default_synthetic_tool_response + ) + # orjson.dumps returns bytes; ensure content is a string. + content = ( + raw_content.decode("utf-8") + if isinstance(raw_content, bytes) + else raw_content + ) + messages.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": content, + } + ) + return messages + + @staticmethod + def _apply_tool_call_overrides( + body: dict[str, Any], + data: GenerationRequest, + ) -> None: + """Inject tool definitions and constrain the request body for tool calling. + + Handles three concerns: + + 1. Deserializes and injects tool definitions from dataset columns. + 2. 
Sets ``tool_choice`` to ``"required"`` or ``"none"`` depending on + whether the current turn expects a tool call. + 3. Removes body keys that are incompatible with tool calling + (``ignore_eos``, ``stop``, and token-limit keys on tool-call turns). + + :param body: The mutable request body dict being built. + :param data: The current generation request. + """ + tools_column = data.columns.get("tools_column", []) + if tools_column: + tools_value = tools_column[0] + # JSON-serialized tool definitions (e.g. from synthetic data + # generators that store tools as strings for HuggingFace + # Features compatibility). orjson produces bytes; stdlib + # json produces str. + if isinstance(tools_value, str | bytes): + tools_value = json.loads(tools_value) + if isinstance(tools_value, list): + body["tools"] = tools_value + body.setdefault("tool_choice", "required") + + if "tools" not in body: + return + + # Override tool_choice to "none" on turns that don't expect tool calls, + # so the model produces a plain text response instead. + if not data.expects_tool_call: + body["tool_choice"] = "none" + + # Tool calling requires the model to stop naturally after producing + # valid JSON; ignore_eos would force generation past that point and + # break the server's constrained decoding grammar. + # max_completion_tokens would truncate output mid-JSON and corrupt + # the arguments sent in conversation history on follow-up turns. + if data.expects_tool_call: + body.pop("ignore_eos", None) + body.pop("stop", None) + body.pop("max_completion_tokens", None) + body.pop("max_tokens", None) + + def format( # noqa: C901, PLR0912, PLR0915 self, data: GenerationRequest, response: GenerationResponse | None = None, @@ -461,7 +558,9 @@ def format( # noqa: C901 """ Format the chat completion generation request into the appropriate structure. - :param request: The generation request to format + :param data: The generation request to format + :param response: Optional prior response for multi-turn history + :param history: Prior (request, response) pairs in the conversation :param **kwargs: Additional keyword arguments for request formatting :return: The formatted request arguments """ @@ -522,18 +621,36 @@ def format( # noqa: C901 self._format_prompts(data.columns.get(col, []), col) for col in ("text_column", "image_column", "video_column", "audio_column") ] - if prompts: - # Interleave prompt types + user_content = list(roundrobin(*prompts)) + if user_content: + arguments.body["messages"].append({"role": "user", "content": user_content}) + + # Append the prior assistant response to the message history. + # For tool call responses, include the assistant's tool_calls and + # synthetic tool result messages so the model sees the full + # multi-turn exchange. For plain text responses, just add content. + if response and response.tool_calls: arguments.body["messages"].append( - {"role": "user", "content": list(roundrobin(*prompts))} + { + "role": "assistant", + "content": response.text, + "tool_calls": [tc.model_dump() for tc in response.tool_calls], + } ) - - # Add the response to the current prompt if available - if response and response.text: + tool_response_columns = data.columns.get("tool_response_column", []) + arguments.body["messages"].extend( + self._build_tool_response_messages( + response.tool_calls, tool_response_columns + ) + ) + elif response and response.text: arguments.body["messages"].append( {"role": "assistant", "content": response.text} ) + # Inject tool definitions and apply tool-call-specific overrides. 
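+        # (tools deserialized from tools_column, tool_choice set to
+        # "required" or "none", and ignore_eos/stop/token-limit keys
+        # stripped on tool-call turns).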
+ self._apply_tool_call_overrides(arguments.body, data) + return arguments def compile_non_streaming( @@ -549,6 +666,7 @@ def compile_non_streaming( structure specific to chat completion endpoints. :param request: Original generation request + :param arguments: The request arguments that were sent :param response: Complete API response containing choices and usage data :return: Standardized GenerationResponse with extracted content and metrics """ @@ -560,7 +678,10 @@ def compile_non_streaming( if text is None and not raw_tool_calls: text = "" # Edge case: null content and no tools input_metrics, output_metrics = self.extract_metrics(usage, text) + + tool_calls: list[StreamingToolCall] | None = None if raw_tool_calls: + tool_calls = [StreamingToolCall.model_validate(tc) for tc in raw_tool_calls] output_metrics.tool_call_count = len(raw_tool_calls) if text is None: # tool-only turn output_metrics.tool_call_tokens = output_metrics.text_tokens @@ -572,6 +693,7 @@ def compile_non_streaming( request_args=arguments.model_dump_json(), response_id=response.get("id"), # use vLLM ID if available text=text, + tool_calls=tool_calls, input_metrics=input_metrics, output_metrics=output_metrics, ) @@ -602,11 +724,10 @@ def add_streaming_line(self, line: str) -> int | None: self.streaming_texts.append(content) updated = True - # tool_calls is an optional field for when the server is requesting a tool + # Accumulate streamed tool_calls deltas. Each tool call may be split + # across multiple chunks; we reassemble by ``index``. for tc_delta in delta.get("tool_calls", []): - # Keep track of the index to properly count tool usage, since a tool call - # can be split into multiple chunks when streaming. - self.streaming_tool_call_indices.add(tc_delta["index"]) + self._accumulate_tool_call_delta(tc_delta) updated = True if usage: @@ -614,6 +735,35 @@ def add_streaming_line(self, line: str) -> int | None: return 1 if updated else 0 + def _accumulate_tool_call_delta(self, tc_delta: dict[str, Any]) -> None: + """Merge a single streaming tool_call delta into accumulated state. + + Each tool call is split across multiple SSE chunks. This method + creates or updates the :class:`StreamingToolCall` entry keyed by the + delta's ``index`` field. + + :param tc_delta: A single element from the ``tool_calls`` array in a + streaming chat completion delta. 
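+
+        Example deltas for one streamed call (the name arrives first and
+        the arguments are split across chunks)::
+
+            {"index": 0, "id": "call_1", "function": {"name": "get_weather"}}
+            {"index": 0, "function": {"arguments": "{\"city\""}}
+            {"index": 0, "function": {"arguments": ": \"NYC\"}"}}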
+ """ + idx = tc_delta["index"] + self.streaming_tool_call_indices.add(idx) + + if idx not in self.streaming_tool_calls: + self.streaming_tool_calls[idx] = StreamingToolCall( + id=tc_delta.get("id", ""), + type=tc_delta.get("type", "function"), + ) + + tc = self.streaming_tool_calls[idx] + fn_delta = tc_delta.get("function", {}) + + if fn_id := tc_delta.get("id"): + tc.id = fn_id + if fn_name := fn_delta.get("name"): + tc.function.name += fn_name + if fn_args := fn_delta.get("arguments"): + tc.function.arguments += fn_args + def compile_streaming( self, request: GenerationRequest, arguments: GenerationRequestArguments ) -> GenerationResponse: @@ -635,11 +785,18 @@ def compile_streaming( else: # mixed content + tool call turn output_metrics.mixed_content_tool_tokens = output_metrics.text_tokens + tool_calls: list[StreamingToolCall] | None = None + if self.streaming_tool_calls: + tool_calls = [ + self.streaming_tool_calls[i] for i in sorted(self.streaming_tool_calls) + ] + return GenerationResponse( request_id=request.request_id, request_args=arguments.model_dump_json(), response_id=self.streaming_response_id, # use vLLM ID if available text=text, + tool_calls=tool_calls, input_metrics=input_metrics, output_metrics=output_metrics, ) diff --git a/src/guidellm/benchmark/schemas/generative/accumulator.py b/src/guidellm/benchmark/schemas/generative/accumulator.py index cd82685c2..d45466c3b 100644 --- a/src/guidellm/benchmark/schemas/generative/accumulator.py +++ b/src/guidellm/benchmark/schemas/generative/accumulator.py @@ -693,6 +693,7 @@ def clear_stats_data(self, stats: GenerativeRequestStats | int): stats.request_args = None if self.clear_nonsampled_outputs: stats.output = None + stats.tool_calls = None @classmethod def compile_stats( diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py index 3a1d0d255..7d6b7bcc6 100644 --- a/src/guidellm/data/deserializers/synthetic.py +++ b/src/guidellm/data/deserializers/synthetic.py @@ -18,6 +18,8 @@ DatasetDeserializerFactory, ) from guidellm.data.schemas import SyntheticTextDatasetConfig +from guidellm.settings import settings +from guidellm.utils.imports import json from guidellm.utils.random import IntegerRangeSampler __all__ = [ @@ -25,6 +27,23 @@ "SyntheticTextDatasetDeserializer", ] +# Placeholder tool definition used when the user doesn't supply their own +# tools but configures tool_call_turns with at least one turn. 
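+# The definition follows the OpenAI function-tool schema, so any
+# OpenAI-compatible server accepts it unchanged.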
+DEFAULT_SYNTHETIC_TOOLS: list[dict[str, Any]] = [ + { + "type": "function", + "function": { + "name": "get_data", + "description": "Retrieve data from the system", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string", "description": "The query"}}, + "required": ["query"], + }, + }, + } +] + class _SyntheticTextExamplesIterable(_BaseExamplesIterable): """Custom examples iterable for synthetic text generation.""" @@ -75,6 +94,25 @@ def __iter__(self) -> Iterator[tuple[int, dict[str, Any]]]: prefix_iter = self._create_prefix_iter(faker, rand) samples_count = 0 + # Resolve tool definitions for tool-call turns + tool_call_turns_set = set(self.config.tool_call_turns) + tools_defs: list[dict[str, Any]] | None = None + if tool_call_turns_set: + tools_defs = self.config.tools or DEFAULT_SYNTHETIC_TOOLS + + # Optional sampler for variable-length tool responses + tool_response_sampler: Iterator[int] | None = None + if self.config.tool_response_tokens is not None: + tool_response_sampler = iter( + IntegerRangeSampler( + average=self.config.tool_response_tokens, + variance=self.config.tool_response_tokens_stdev, + min_value=self.config.tool_response_tokens_min, + max_value=self.config.tool_response_tokens_max, + random_seed=iter_random_seed + 2, + ) + ) + while True: prompt_tokens_count = next(prompt_tokens_sampler) output_tokens_count = ( @@ -93,6 +131,19 @@ def __iter__(self) -> Iterator[tuple[int, dict[str, Any]]]: row[f"prompt_tokens_count_{turn}"] = prompt_tokens_count if output_tokens_count is not None: row[f"output_tokens_count_{turn}"] = output_tokens_count + + if tools_defs is not None and turn in tool_call_turns_set: + row[f"tools_{turn}"] = json.dumps(tools_defs) + + if tool_response_sampler is not None: + tr_tokens = next(tool_response_sampler) + body = self._create_prompt(tr_tokens, faker) + row[f"tool_response_{turn}"] = json.dumps({"result": body}) + else: + row[f"tool_response_{turn}"] = ( + settings.default_synthetic_tool_response + ) + samples_count += 1 yield samples_count, row @@ -103,12 +154,19 @@ def is_typed(self) -> bool: @property def features(self) -> Features: - features = {"prefix": Value("string")} + features: dict[str, Any] = {"prefix": Value("string")} for i in range(self.config.turns): features[f"prompt_{i}"] = Value("string") features[f"prompt_tokens_count_{i}"] = Value("int32") if self.config.output_tokens is not None: features[f"output_tokens_count_{i}"] = Value("int32") + + if i in set(self.config.tool_call_turns): + # Tools column is a JSON-serialised list; store as string + # to keep the HuggingFace Features schema simple. + features[f"tools_{i}"] = Value("large_string") + features[f"tool_response_{i}"] = Value("large_string") + return Features(features) @property diff --git a/src/guidellm/data/finalizers.py b/src/guidellm/data/finalizers.py index 5fda041ec..ef67c4f0b 100644 --- a/src/guidellm/data/finalizers.py +++ b/src/guidellm/data/finalizers.py @@ -112,8 +112,14 @@ def finalize_turn( # noqa: C901 PLR0912 input_metrics.audio_bytes or 0 ) + audio_bytes + # A turn expects a tool call if it has tool definitions. + # Which turns carry tools_column is controlled by the data pipeline + # (synthetic generator or dataset columns). 
+ expects_tool_call = bool(columns.get("tools_column")) + return GenerationRequest( columns=columns, + expects_tool_call=expects_tool_call, input_metrics=input_metrics, output_metrics=output_metrics, ) diff --git a/src/guidellm/data/preprocessors/__init__.py b/src/guidellm/data/preprocessors/__init__.py index 1a2eecb71..1c1d64366 100644 --- a/src/guidellm/data/preprocessors/__init__.py +++ b/src/guidellm/data/preprocessors/__init__.py @@ -5,6 +5,7 @@ DatasetPreprocessor, PreprocessorRegistry, ) +from .tool_calling import ToolCallingMessageExtractor from .turn_pivot import TurnPivot __all__ = [ @@ -13,5 +14,6 @@ "GenerativeColumnMapper", "MediaEncoder", "PreprocessorRegistry", + "ToolCallingMessageExtractor", "TurnPivot", ] diff --git a/src/guidellm/data/preprocessors/mappers.py b/src/guidellm/data/preprocessors/mappers.py index 721ab6272..b57234de6 100644 --- a/src/guidellm/data/preprocessors/mappers.py +++ b/src/guidellm/data/preprocessors/mappers.py @@ -68,6 +68,16 @@ class GenerativeColumnMapper(DataDependentPreprocessor): "wav", "mp3", ], + "tools_column": [ + "tools", + "functions", + "tool_definitions", + ], + "tool_response_column": [ + "tool_response", + "tool_result", + "tool_output", + ], } column_name_pattern: str = ( r"^(?P(?P({name})(es|s)?)([-_](?P\d+))?)$" @@ -168,8 +178,10 @@ def datasets_mappings( dataset_columns_str, ) - # Re-enumerate to ensure we don't have a gap in turns - for turn, (_, column_name) in enumerate(sorted(turn_columns)): + # Preserve the original turn index from the column name so + # that sparse columns (e.g. tools_0, tools_3) stay aligned + # with the turns they belong to. + for turn, column_name in sorted(turn_columns): column_type = cast("GenerativeDatasetColumnType", column_type) mappings[(column_type, turn)].append((index, column_name)) diff --git a/src/guidellm/data/preprocessors/tool_calling.py b/src/guidellm/data/preprocessors/tool_calling.py new file mode 100644 index 000000000..d6ce390a3 --- /dev/null +++ b/src/guidellm/data/preprocessors/tool_calling.py @@ -0,0 +1,101 @@ +"""Preprocessor for extracting prompts from tool calling datasets. + +Handles HuggingFace datasets where prompts are stored as OpenAI-format +``messages`` arrays rather than plain text columns. +""" + +from __future__ import annotations + +from typing import Any + +from guidellm.data.preprocessors.preprocessor import ( + DatasetPreprocessor, + PreprocessorRegistry, +) + +__all__ = ["ToolCallingMessageExtractor"] + + +@PreprocessorRegistry.register("tool_calling_message_extractor") +class ToolCallingMessageExtractor(DatasetPreprocessor): + """Extract user prompts, system prompts, and tool responses from messages. + + Many tool calling datasets (e.g. ``madroid/glaive-function-calling-openai``) + store conversations as a ``messages`` column containing an array of + ``{"role": ..., "content": ...}`` dicts. This preprocessor replaces the + ``text_column`` value with the extracted user content, populates + ``prefix_column`` with the system prompt when present, and populates + ``tool_response_column`` with ``role: "tool"`` response content. 
+ + Usage:: + + guidellm benchmark run \\ + --data madroid/glaive-function-calling-openai \\ + --data-column-mapper \\ + '{"text_column": "messages", "tools_column": "tools"}' \\ + --data-preprocessors tool_calling_message_extractor,encode_media + """ + + def __init__(self, **_: Any) -> None: + pass + + def __call__( # noqa: C901 + self, items: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + for item in items: + text_values = item.get("text_column") + if not text_values or not isinstance(text_values, list): + continue + + new_texts: list[str] = [] + prefixes: list[str] = [] + tool_responses: list[str] = [] + + for value in text_values: + if isinstance(value, list): + user_parts, system_parts, tool_parts = _extract_from_messages(value) + if user_parts: + new_texts.append(" ".join(user_parts)) + if system_parts: + prefixes.append(" ".join(system_parts)) + tool_responses.extend(tool_parts) + elif isinstance(value, str): + new_texts.append(value) + + if new_texts: + item["text_column"] = new_texts + if prefixes: + item.setdefault("prefix_column", []).extend(prefixes) + if tool_responses: + item.setdefault("tool_response_column", []).extend(tool_responses) + + return items + + +def _extract_from_messages( + messages: list[dict[str, Any]], +) -> tuple[list[str], list[str], list[str]]: + """Pull user, system, and tool response content from an OpenAI messages array. + + :return: Tuple of (user_parts, system_parts, tool_response_parts). + """ + user_parts: list[str] = [] + system_parts: list[str] = [] + tool_response_parts: list[str] = [] + + for msg in messages: + if not isinstance(msg, dict): + continue + role = msg.get("role", "") + content = msg.get("content", "") + if not content or not isinstance(content, str): + continue + + if role == "user": + user_parts.append(content) + elif role == "system": + system_parts.append(content) + elif role == "tool": + tool_response_parts.append(content) + + return user_parts, system_parts, tool_response_parts diff --git a/src/guidellm/data/schemas.py b/src/guidellm/data/schemas.py index 02330c91a..dc8d30e63 100644 --- a/src/guidellm/data/schemas.py +++ b/src/guidellm/data/schemas.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal -from pydantic import ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, field_validator, model_validator from guidellm.schemas import StandardBaseModel @@ -23,6 +23,8 @@ "image_column", "video_column", "audio_column", + "tools_column", + "tool_response_column", ] @@ -150,6 +152,40 @@ class SyntheticTextDatasetConfig(DataConfig): gt=0, default=1, ) + tool_call_turns: list[int] = Field( + description="Which turns should include tool definitions and expect " + "tool-call responses. An int N means 'the first N turns'; a list " + "of ints specifies explicit 0-based turn indices (e.g. [0, 2]). " + "Normalized to a sorted list after validation. " + "When 0 or [] (default), no tool calling is configured.", + default_factory=list, + ) + tools: list[dict[str, Any]] | None = Field( + description="Tool definitions in OpenAI format. When tool_call_turns is " + "non-empty and this is None, a static placeholder tool definition is used.", + default=None, + ) + tool_response_tokens: int | None = Field( + description="Average number of tokens for synthetic tool call responses. 
" + "When None (default), a short placeholder response is used.", + gt=0, + default=None, + ) + tool_response_tokens_stdev: int | None = Field( + description="Standard deviation for tool response token count.", + gt=0, + default=None, + ) + tool_response_tokens_min: int | None = Field( + description="Minimum number of tokens for tool response.", + gt=0, + default=None, + ) + tool_response_tokens_max: int | None = Field( + description="Maximum number of tokens for tool response.", + gt=0, + default=None, + ) model_config = ConfigDict( extra="allow", @@ -160,6 +196,28 @@ class SyntheticTextDatasetConfig(DataConfig): default=None, ) + @field_validator("tool_call_turns", mode="before") + @classmethod + def _coerce_tool_call_turns(cls, v: int | list[int]) -> list[int]: + """Convert an int N to [0, ..., N-1]; pass lists through sorted.""" + if isinstance(v, int): + if v < 0: + raise ValueError("tool_call_turns int must be >= 0") + return list(range(v)) + if len(v) != len(set(v)): + raise ValueError("tool_call_turns list must not contain duplicates") + return sorted(v) + + @model_validator(mode="after") + def _validate_tool_call_turn_indices(self) -> SyntheticTextDatasetConfig: + """Ensure all tool_call_turns indices are within [0, turns).""" + for idx in self.tool_call_turns: + if idx < 0 or idx >= self.turns: + raise ValueError( + f"tool_call_turns index {idx} out of range [0, {self.turns})" + ) + return self + @model_validator(mode="after") def check_prefix_options(self) -> SyntheticTextDatasetConfig: if self.__pydantic_extra__ is not None: diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index a7e191c45..eec6db62e 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -50,6 +50,7 @@ __all__ = ["WorkerProcess"] + ProcessRequestT = TypeAliasType( "ProcessRequestT", tuple[ @@ -368,7 +369,7 @@ async def _cancel_requests_loop(self): request_info.timings.resolve_end = time.time() self._send_update("cancelled", None, request, request_info) - async def _process_next_request( # noqa: C901 + async def _process_next_request( # noqa: C901, PLR0912, PLR0915 self, target_start: float ) -> ProcessRequestT[RequestT, ResponseT]: """ @@ -377,6 +378,10 @@ async def _process_next_request( # noqa: C901 Retrieves request from messaging queue, applies timing strategy, processes through backend, and publishes status updates throughout the lifecycle. + After resolve completes the worker inspects ``request_info.error`` + to decide whether to continue with remaining conversation turns, cancel + them, or mark the current turn as errored. + :param target_start: Unix timestamp when request should begin processing """ conversation: ConversationT[RequestT] = [] @@ -409,13 +414,23 @@ async def _process_next_request( # noqa: C901 response = resp - # Complete the request request_info.timings.resolve_end = time.time() - self._send_update("completed", response, request, request_info) - # Record Turn + if request_info.error: + self._send_update("errored", response, request, request_info) + else: + self._send_update("completed", response, request, request_info) + history.append((request, response)) + # Cancel remaining conversation turns when the backend signals an error. 
+ if request_info.error: + for skip_req, skip_info in conversation: + skip_info.error = f"Cancelled: {request_info.error}" + skip_info.timings.resolve_end = time.time() + self._send_update("cancelled", None, skip_req, skip_info) + conversation.clear() + response = request = request_info = None except asyncio.CancelledError: premature_exit = True @@ -434,6 +449,12 @@ async def _process_next_request( # noqa: C901 logger.opt(exception=True).debug( f"Backend exception for request {request_info.request_id}" ) + # Cancel remaining conversation turns on backend error + for skip_req, skip_info in conversation: + skip_info.error = f"Cancelled: {request_info.error}" + skip_info.timings.resolve_end = time.time() + self._send_update("cancelled", None, skip_req, skip_info) + conversation.clear() finally: if request_info is not None: self.strategy.request_completed(request_info) diff --git a/src/guidellm/schemas/__init__.py b/src/guidellm/schemas/__init__.py index 4c78446fe..f49220e3c 100644 --- a/src/guidellm/schemas/__init__.py +++ b/src/guidellm/schemas/__init__.py @@ -36,6 +36,7 @@ Percentiles, StatusDistributionSummary, ) +from .tool_call import StreamingToolCall, StreamingToolCallFunction __all__ = [ "BaseModelT", @@ -57,6 +58,8 @@ "StandardBaseModel", "StatusBreakdown", "StatusDistributionSummary", + "StreamingToolCall", + "StreamingToolCallFunction", "SuccessfulT", "TotalT", "UsageMetrics", diff --git a/src/guidellm/schemas/request.py b/src/guidellm/schemas/request.py index 91120409d..f9e438295 100644 --- a/src/guidellm/schemas/request.py +++ b/src/guidellm/schemas/request.py @@ -238,6 +238,12 @@ class GenerationRequest(StandardBaseModel): "where keys are column names and values are lists of column entries." ), ) + expects_tool_call: bool = Field( + default=False, + description="Whether this turn is expected to produce a tool call response. " + "Set by the data pipeline for pre-planned tool-call turns. 
" + "Derived from the presence of tools_column in the turn data.", + ) input_metrics: UsageMetrics = Field( default_factory=UsageMetrics, description="Input statistics including counts, sizes, and durations.", diff --git a/src/guidellm/schemas/request_stats.py b/src/guidellm/schemas/request_stats.py index e94e31c38..1b1379550 100644 --- a/src/guidellm/schemas/request_stats.py +++ b/src/guidellm/schemas/request_stats.py @@ -18,6 +18,7 @@ from guidellm.schemas.base import StandardBaseDict from guidellm.schemas.info import RequestInfo from guidellm.schemas.request import UsageMetrics +from guidellm.schemas.tool_call import StreamingToolCall __all__ = ["GenerativeRequestStats"] @@ -53,6 +54,10 @@ class GenerativeRequestStats(StandardBaseDict): output: str | None = Field( default=None, description="Generated text output from the request" ) + tool_calls: list[StreamingToolCall] | None = Field( + default=None, + description="Raw tool call payloads from the model response in OpenAI format", + ) info: RequestInfo = Field(description="Request metadata and timing information") input_metrics: UsageMetrics = Field( description="Token usage statistics for the input prompt" diff --git a/src/guidellm/schemas/response.py b/src/guidellm/schemas/response.py index 202fdab19..c52210d2a 100644 --- a/src/guidellm/schemas/response.py +++ b/src/guidellm/schemas/response.py @@ -15,8 +15,13 @@ from guidellm.schemas.info import RequestInfo from guidellm.schemas.request import GenerationRequest, UsageMetrics from guidellm.schemas.request_stats import GenerativeRequestStats +from guidellm.schemas.tool_call import StreamingToolCall, StreamingToolCallFunction -__all__ = ["GenerationResponse"] +__all__ = [ + "GenerationResponse", + "StreamingToolCall", + "StreamingToolCallFunction", +] class GenerationResponse(StandardBaseModel): @@ -24,8 +29,8 @@ class GenerationResponse(StandardBaseModel): Response model for backend generation operations. Captures the output and metrics from a generation request, providing structured - data for text output, token usage statistics, and compilation of detailed - request statistics for analysis and monitoring purposes. + data for text output, tool call payloads, token usage statistics, and compilation + of detailed request statistics for analysis and monitoring purposes. Example: :: @@ -52,6 +57,13 @@ class GenerationResponse(StandardBaseModel): default=None, description="The generated response text.", ) + tool_calls: list[StreamingToolCall] | None = Field( + default=None, + description=( + "Raw tool call payloads from the model response, each containing " + "id, type, and function (name + arguments) in OpenAI format." + ), + ) input_metrics: UsageMetrics = Field( default_factory=UsageMetrics, description="Token usage statistics from the input prompt.", @@ -115,6 +127,7 @@ def compile_stats( response_id=self.response_id, request_args=self.request_args, output=self.text, + tool_calls=self.tool_calls, info=info, input_metrics=UsageMetrics(**input_metrics_dict), output_metrics=UsageMetrics(**output_metrics_dict), diff --git a/src/guidellm/schemas/tool_call.py b/src/guidellm/schemas/tool_call.py new file mode 100644 index 000000000..331878056 --- /dev/null +++ b/src/guidellm/schemas/tool_call.py @@ -0,0 +1,33 @@ +""" +Tool call data models for streaming and non-streaming responses. + +Provides Pydantic models for representing tool calls returned by OpenAI-compatible +APIs. Used by both the response and request statistics schemas to carry tool call +payloads through the benchmarking pipeline. 
+""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + +__all__ = [ + "StreamingToolCall", + "StreamingToolCallFunction", +] + + +class StreamingToolCallFunction(BaseModel): + """Accumulated function name and arguments for a single streamed tool call.""" + + name: str = "" + arguments: str = "" + + +class StreamingToolCall(BaseModel): + """A single tool call reassembled from streaming deltas.""" + + id: str = "" + type: str = "function" + function: StreamingToolCallFunction = Field( + default_factory=StreamingToolCallFunction + ) diff --git a/src/guidellm/settings.py b/src/guidellm/settings.py index f16ac5b6c..ed8ff713d 100644 --- a/src/guidellm/settings.py +++ b/src/guidellm/settings.py @@ -107,6 +107,7 @@ class Settings(BaseSettings): # Data settings dataset: DatasetSettings = DatasetSettings() + default_synthetic_tool_response: str = '{"status": "ok"}' # Report settings report_generation: ReportGenerationSettings = ReportGenerationSettings() diff --git a/tests/unit/backends/openai/test_request_handlers.py b/tests/unit/backends/openai/test_request_handlers.py index d03122065..cdad1cd0e 100644 --- a/tests/unit/backends/openai/test_request_handlers.py +++ b/tests/unit/backends/openai/test_request_handlers.py @@ -679,11 +679,9 @@ def test_format_messages_prefix(self, valid_instances): result = instance.format(data) - assert len(result.body["messages"]) == 2 + assert len(result.body["messages"]) == 1 assert result.body["messages"][0]["role"] == "system" assert result.body["messages"][0]["content"] == "You are a helpful assistant." - assert result.body["messages"][1]["role"] == "user" - assert result.body["messages"][1]["content"] == [] @pytest.mark.sanity def test_format_messages_image(self, valid_instances): diff --git a/tests/unit/schemas/test_response.py b/tests/unit/schemas/test_response.py index 853f92520..08b4cdd3a 100644 --- a/tests/unit/schemas/test_response.py +++ b/tests/unit/schemas/test_response.py @@ -14,7 +14,10 @@ GenerationResponse, GenerativeRequestStats, RequestInfo, + RequestTimings, StandardBaseModel, + StreamingToolCall, + StreamingToolCallFunction, UsageMetrics, ) @@ -205,6 +208,53 @@ def test_compile_stats_failed_request(self): assert stats.request_id == "test-123" assert stats.info.status == "errored" + @pytest.mark.smoke + def test_compile_stats_persists_tool_calls(self): + """ + compile_stats carries tool_calls from the response to the stats. + + ## WRITTEN BY AI ## + """ + tool_calls = [ + StreamingToolCall( + id="call_1", + function=StreamingToolCallFunction( + name="get_weather", arguments='{"city": "NYC"}' + ), + ) + ] + response = GenerationResponse( + request_id="test-tc", + request_args="test_args", + text=None, + tool_calls=tool_calls, + ) + + request = GenerationRequest(request_id="test-tc") + info = RequestInfo(request_id="test-tc", status="completed") + + stats = response.compile_stats(request, info) + assert stats.tool_calls == tool_calls + + @pytest.mark.smoke + def test_compile_stats_tool_calls_none_when_absent(self): + """ + compile_stats leaves tool_calls as None for plain-text responses. 
+ + ## WRITTEN BY AI ## + """ + response = GenerationResponse( + request_id="test-plain", + request_args="test_args", + text="Hello world", + ) + + request = GenerationRequest(request_id="test-plain") + info = RequestInfo(request_id="test-plain", status="completed") + + stats = response.compile_stats(request, info) + assert stats.tool_calls is None + @pytest.mark.sanity def test_compile_stats_mismatched_request_id(self): """Test compile_stats with mismatched request IDs.""" @@ -263,3 +313,60 @@ def test_marshalling( reconstructed.output_metrics.model_dump() == instance.output_metrics.model_dump() ) + + +class TestGenerativeRequestStatsToolCalls: + """ + Tests for tool_calls field on GenerativeRequestStats. + + ## WRITTEN BY AI ## + """ + + @pytest.mark.smoke + def test_tool_calls_round_trips_through_serialization(self): + """ + tool_calls survives model_dump / model_validate. + + ## WRITTEN BY AI ## + """ + tool_calls = [ + StreamingToolCall( + id="call_1", + function=StreamingToolCallFunction(name="fn", arguments="{}"), + ) + ] + tool_calls_dicts = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + } + ] + timings = RequestTimings(resolve_start=1.0, resolve_end=2.0) + stats = GenerativeRequestStats( + request_id="rt-1", + info=RequestInfo(request_id="rt-1", status="completed", timings=timings), + input_metrics=UsageMetrics(), + output_metrics=UsageMetrics(), + tool_calls=tool_calls, + ) + data = stats.model_dump() + assert data["tool_calls"] == tool_calls_dicts + + restored = GenerativeRequestStats.model_validate(data) + assert restored.tool_calls == tool_calls + + @pytest.mark.smoke + def test_tool_calls_defaults_to_none(self): + """ + tool_calls is None when not provided. + + ## WRITTEN BY AI ## + """ + stats = GenerativeRequestStats( + request_id="rt-2", + info=RequestInfo(request_id="rt-2", status="completed"), + input_metrics=UsageMetrics(), + output_metrics=UsageMetrics(), + ) + assert stats.tool_calls is None diff --git a/tests/unit/test_pre_anticipated_tool_calls.py b/tests/unit/test_pre_anticipated_tool_calls.py new file mode 100644 index 000000000..ffa296d8e --- /dev/null +++ b/tests/unit/test_pre_anticipated_tool_calls.py @@ -0,0 +1,1523 @@ +""" +Tests for pre-anticipated tool call design: tool-call turns are pre-planned at +data generation time rather than dynamically created in the worker loop. + +Covers config validation, synthetic data generation, finalizer flag setting, +request handler tool_choice overrides, and worker missing-tool-call behavior. 
+ +## WRITTEN BY AI ## +""" + +from __future__ import annotations + +import asyncio +import time +from multiprocessing import Barrier, Event +from typing import Any +from unittest.mock import MagicMock + +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from guidellm.backends.openai.request_handlers import ( + ChatCompletionsRequestHandler, +) +from guidellm.data.finalizers import GenerativeRequestFinalizer +from guidellm.data.schemas import SyntheticTextDatasetConfig +from guidellm.scheduler import SynchronousStrategy, WorkerProcess +from guidellm.schemas import GenerationRequest, RequestInfo, UsageMetrics +from guidellm.schemas.tool_call import StreamingToolCall, StreamingToolCallFunction +from guidellm.utils.imports import json +from guidellm.utils.messaging import InterProcessMessagingQueue +from tests.unit.testing_utils import async_timeout + +# --------------------------------------------------------------------------- +# SyntheticTextDatasetConfig validation +# --------------------------------------------------------------------------- + + +class TestSyntheticTextDatasetConfigToolCallFields: + """Validate tool_call_turns and tools fields on SyntheticTextDatasetConfig. + + ## WRITTEN BY AI ## + """ + + @pytest.mark.smoke + def test_defaults_no_tool_calling(self): + """Default config has no tool calling enabled. + + ## WRITTEN BY AI ## + """ + config = SyntheticTextDatasetConfig(prompt_tokens=50, output_tokens=50) + assert config.tool_call_turns == [] + assert config.tools is None + + @pytest.mark.smoke + def test_tool_call_turns_less_than_turns(self): + """tool_call_turns int is normalized to a list of indices. + + ## WRITTEN BY AI ## + """ + config = SyntheticTextDatasetConfig( + prompt_tokens=50, output_tokens=50, turns=3, tool_call_turns=2 + ) + assert config.tool_call_turns == [0, 1] + + @pytest.mark.sanity + def test_tool_call_turns_equal_to_turns_accepted(self): + """ + tool_call_turns == turns is valid (all turns are tool-call turns, + no final plain-text response). + + ## WRITTEN BY AI ## + """ + config = SyntheticTextDatasetConfig( + prompt_tokens=50, output_tokens=50, turns=3, tool_call_turns=3 + ) + assert config.tool_call_turns == [0, 1, 2] + + @pytest.mark.sanity + def test_custom_tools_accepted(self): + """Custom tools with valid tool_call_turns are accepted. + + ## WRITTEN BY AI ## + """ + custom_tools = [{"type": "function", "function": {"name": "my_func"}}] + config = SyntheticTextDatasetConfig( + prompt_tokens=50, + output_tokens=50, + turns=3, + tool_call_turns=1, + tools=custom_tools, + ) + assert config.tools == custom_tools + assert config.tool_call_turns == [0] + + @pytest.mark.smoke + def test_list_tool_call_turns_accepted(self): + """Explicit list of turn indices is accepted and sorted. + + ## WRITTEN BY AI ## + """ + config = SyntheticTextDatasetConfig( + prompt_tokens=50, output_tokens=50, turns=4, tool_call_turns=[2, 0] + ) + assert config.tool_call_turns == [0, 2] + + @pytest.mark.sanity + def test_list_tool_call_turns_validation_out_of_range(self): + """List indices must be within [0, turns). + + ## WRITTEN BY AI ## + """ + with pytest.raises(ValidationError, match="out of range"): + SyntheticTextDatasetConfig( + prompt_tokens=50, output_tokens=50, turns=3, tool_call_turns=[0, 3] + ) + + @pytest.mark.sanity + def test_list_tool_call_turns_validation_duplicates(self): + """Duplicate indices in the list are rejected. 
+ + ## WRITTEN BY AI ## + """ + with pytest.raises(ValidationError, match="duplicates"): + SyntheticTextDatasetConfig( + prompt_tokens=50, output_tokens=50, turns=3, tool_call_turns=[0, 0] + ) + + @pytest.mark.sanity + def test_int_tool_call_turns_exceeds_turns_rejected(self): + """An int greater than turns is rejected. + + ## WRITTEN BY AI ## + """ + with pytest.raises(ValidationError, match="out of range"): + SyntheticTextDatasetConfig( + prompt_tokens=50, output_tokens=50, turns=2, tool_call_turns=3 + ) + + +# --------------------------------------------------------------------------- +# Synthetic data generation +# --------------------------------------------------------------------------- + + +class TestSyntheticDataToolColumns: + """Verify synthetic data emits tools_{turn} columns for tool_call_turns. + + ## WRITTEN BY AI ## + """ + + @pytest.fixture + def processor(self): + """Minimal mock processor for token encoding/decoding. + + ## WRITTEN BY AI ## + """ + proc = MagicMock() + proc.encode.return_value = list(range(100)) + proc.decode.return_value = "mock text" + return proc + + @pytest.mark.smoke + def test_no_tools_columns_when_tool_call_turns_zero(self, processor): + """With tool_call_turns=0, no tools columns are emitted. + + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + _SyntheticTextExamplesIterable, + ) + + config = SyntheticTextDatasetConfig(prompt_tokens=10, output_tokens=10, turns=3) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + _, row = next(iter(iterable)) + + assert "tools_0" not in row + assert "tools_1" not in row + assert "tools_2" not in row + + @pytest.mark.smoke + def test_tools_columns_emitted_for_tool_call_turns(self, processor): + """With tool_call_turns=2 and turns=3, tools_0 and tools_1 are emitted. + + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + DEFAULT_SYNTHETIC_TOOLS, + _SyntheticTextExamplesIterable, + ) + + config = SyntheticTextDatasetConfig( + prompt_tokens=10, output_tokens=10, turns=3, tool_call_turns=2 + ) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + _, row = next(iter(iterable)) + + assert "tools_0" in row + assert "tools_1" in row + assert "tools_2" not in row + + # Values are JSON-serialized lists + tools_0 = json.loads(row["tools_0"]) + assert tools_0 == DEFAULT_SYNTHETIC_TOOLS + + @pytest.mark.smoke + def test_non_contiguous_tool_call_turns_list(self, processor): + """With tool_call_turns=[0, 2] and turns=4, only turns 0 and 2 get tools. + + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + DEFAULT_SYNTHETIC_TOOLS, + _SyntheticTextExamplesIterable, + ) + + config = SyntheticTextDatasetConfig( + prompt_tokens=10, output_tokens=10, turns=4, tool_call_turns=[0, 2] + ) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + _, row = next(iter(iterable)) + + assert "tools_0" in row + assert "tools_1" not in row + assert "tools_2" in row + assert "tools_3" not in row + + tools_0 = json.loads(row["tools_0"]) + assert tools_0 == DEFAULT_SYNTHETIC_TOOLS + + @pytest.mark.sanity + def test_custom_tools_used_in_synthetic_data(self, processor): + """User-provided tools are used instead of the default placeholder. 
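+
+ For example, tools=[{"type": "function", "function": {"name":
+ "custom_fn"}}] should appear verbatim in the generated tools_0 column
+ rather than DEFAULT_SYNTHETIC_TOOLS.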
+ + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + _SyntheticTextExamplesIterable, + ) + + custom_tools = [{"type": "function", "function": {"name": "custom_fn"}}] + config = SyntheticTextDatasetConfig( + prompt_tokens=10, + output_tokens=10, + turns=2, + tool_call_turns=1, + tools=custom_tools, + ) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + _, row = next(iter(iterable)) + + tools_0 = json.loads(row["tools_0"]) + assert tools_0 == custom_tools + + @pytest.mark.sanity + def test_features_include_tools_columns(self, processor): + """Features property includes tools_{i} entries for tool_call_turns. + + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + _SyntheticTextExamplesIterable, + ) + + config = SyntheticTextDatasetConfig( + prompt_tokens=10, output_tokens=10, turns=3, tool_call_turns=2 + ) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + features = iterable.features + + assert "tools_0" in features + assert "tools_1" in features + assert "tools_2" not in features + + @pytest.mark.sanity + def test_features_non_contiguous_tool_call_turns(self, processor): + """Features property includes tools_{i} only for listed turn indices. + + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + _SyntheticTextExamplesIterable, + ) + + config = SyntheticTextDatasetConfig( + prompt_tokens=10, output_tokens=10, turns=4, tool_call_turns=[1, 3] + ) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + features = iterable.features + + assert "tools_0" not in features + assert "tools_1" in features + assert "tools_2" not in features + assert "tools_3" in features + + +# --------------------------------------------------------------------------- +# Finalizer: expects_tool_call flag +# --------------------------------------------------------------------------- + + +class TestFinalizerExpectsToolCall: + """Verify GenerativeRequestFinalizer sets expects_tool_call correctly. + + ## WRITTEN BY AI ## + """ + + @pytest.fixture + def finalizer(self): + """ + ## WRITTEN BY AI ## + """ + return GenerativeRequestFinalizer() + + @pytest.mark.smoke + def test_expects_tool_call_matches_tools_column_presence(self, finalizer): + """expects_tool_call is True only on turns that have tools_column. + + ## WRITTEN BY AI ## + """ + items = [ + {"text_column": ["hello"], "tools_column": ['[{"type": "function"}]']}, + {"text_column": ["world"]}, + ] + results = finalizer(items) + + assert results[0].expects_tool_call is True + assert results[1].expects_tool_call is False + + @pytest.mark.smoke + def test_all_turns_with_tools_all_expect_tool_call(self, finalizer): + """When every turn has tools_column, every turn expects a tool call. + + ## WRITTEN BY AI ## + """ + items = [ + {"text_column": ["hello"], "tools_column": ['[{"type": "function"}]']}, + {"text_column": ["world"], "tools_column": ['[{"type": "function"}]']}, + ] + results = finalizer(items) + + assert results[0].expects_tool_call is True + assert results[1].expects_tool_call is True + + @pytest.mark.sanity + def test_expects_tool_call_false_without_tools(self, finalizer): + """Turns without tools_column have expects_tool_call=False. 
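+
+ With no tools_column on any turn, every resulting request should report
+ expects_tool_call=False regardless of turn position.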
+ + ## WRITTEN BY AI ## + """ + items = [ + {"text_column": ["hello"]}, + {"text_column": ["world"]}, + ] + results = finalizer(items) + + assert results[0].expects_tool_call is False + assert results[1].expects_tool_call is False + + @pytest.mark.sanity + def test_single_turn_with_tools_expects_tool_call(self, finalizer): + """A single-turn conversation with tools has expects_tool_call=True. + + ## WRITTEN BY AI ## + """ + items = [ + {"text_column": ["hello"], "tools_column": ['[{"type": "function"}]']}, + ] + results = finalizer(items) + assert results[0].expects_tool_call is True + + +# --------------------------------------------------------------------------- +# Request handler: tool_choice override +# --------------------------------------------------------------------------- + + +class TestChatCompletionsToolChoiceOverride: + """Verify tool_choice is overridden to 'none' on non-tool-call turns. + + ## WRITTEN BY AI ## + """ + + @pytest.fixture + def handler(self): + """ + ## WRITTEN BY AI ## + """ + return ChatCompletionsRequestHandler() + + @pytest.mark.smoke + def test_tool_choice_none_when_expects_false(self, handler): + """When expects_tool_call=False and tools come from dataset, tool_choice='none'. + + ## WRITTEN BY AI ## + """ + tools = [{"type": "function", "function": {"name": "fn"}}] + data = GenerationRequest( + columns={ + "text_column": ["test"], + "tools_column": [json.dumps(tools)], + }, + expects_tool_call=False, + ) + extras = {"body": {"tool_choice": "required"}} + result = handler.format(data, extras=extras) + + assert result.body["tool_choice"] == "none" + + @pytest.mark.smoke + def test_tool_choice_preserved_when_expects_true(self, handler): + """When expects_tool_call=True, the configured tool_choice is kept. + + ## WRITTEN BY AI ## + """ + tools = [{"type": "function", "function": {"name": "fn"}}] + data = GenerationRequest( + columns={ + "text_column": ["test"], + "tools_column": [json.dumps(tools)], + }, + expects_tool_call=True, + ) + extras = {"body": {"tool_choice": "required"}} + result = handler.format(data, extras=extras) + + assert result.body["tool_choice"] == "required" + + @pytest.mark.sanity + def test_auto_tool_choice_preserved_when_expects_true(self, handler): + """When expects_tool_call=True with auto mode, tool_choice stays 'auto'. + + ## WRITTEN BY AI ## + """ + tools = [{"type": "function", "function": {"name": "fn"}}] + data = GenerationRequest( + columns={ + "text_column": ["test"], + "tools_column": [json.dumps(tools)], + }, + expects_tool_call=True, + ) + extras = {"body": {"tool_choice": "auto"}} + result = handler.format(data, extras=extras) + + assert result.body["tool_choice"] == "auto" + + @pytest.mark.sanity + def test_no_override_without_tools(self, handler): + """Without tools in body, no tool_choice override happens. + + ## WRITTEN BY AI ## + """ + data = GenerationRequest( + columns={"text_column": ["test"]}, + expects_tool_call=False, + ) + result = handler.format(data) + + assert "tool_choice" not in result.body + + @pytest.mark.sanity + def test_per_request_tools_deserialized_from_json(self, handler): + """JSON-serialized tools from synthetic data are deserialized. 
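+
+ Synthetic rows carry tools as a JSON string (the output of json.dumps),
+ so the handler must decode tools_column back into a list of tool
+ definitions before placing it in the request body.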
+ + ## WRITTEN BY AI ## + """ + tools = [{"type": "function", "function": {"name": "get_data"}}] + data = GenerationRequest( + columns={ + "text_column": ["test"], + "tools_column": [json.dumps(tools)], + }, + expects_tool_call=True, + ) + result = handler.format(data) + + assert result.body["tools"] == tools + + @pytest.mark.smoke + def test_max_completion_tokens_stripped_on_tool_call_turn(self, handler): + """On tool-call turns, max_completion_tokens is removed so the model + can finish producing valid tool call JSON without truncation. + + ## WRITTEN BY AI ## + """ + tools = [{"type": "function", "function": {"name": "fn"}}] + data = GenerationRequest( + columns={ + "text_column": ["test"], + "tools_column": [json.dumps(tools)], + }, + expects_tool_call=True, + output_metrics=UsageMetrics(text_tokens=100), + ) + result = handler.format(data) + + assert "max_completion_tokens" not in result.body + assert "max_tokens" not in result.body + + @pytest.mark.smoke + def test_max_completion_tokens_kept_on_plain_text_turn(self, handler): + """On the final plain-text turn, max_completion_tokens is preserved + to control output length. + + ## WRITTEN BY AI ## + """ + tools = [{"type": "function", "function": {"name": "fn"}}] + data = GenerationRequest( + columns={ + "text_column": ["test"], + "tools_column": [json.dumps(tools)], + }, + expects_tool_call=False, + output_metrics=UsageMetrics(text_tokens=100), + ) + result = handler.format(data) + + assert result.body["max_completion_tokens"] == 100 + + +# --------------------------------------------------------------------------- +# Worker: missing tool call handling +# --------------------------------------------------------------------------- + + +class _MockToolBackend: + """Mock backend that raises exceptions for missing tool calls. + + Mimics what OpenAIHTTPBackend._check_tool_call_expectations does: + raises CancelledError or ValueError during resolve based on + tool_call_missing_behavior, so the worker reacts via its exception + handlers. + """ + + def __init__( + self, + has_tool_calls: bool = True, + tool_call_missing_behavior: str = "error_stop", + ): + self.has_tool_calls = has_tool_calls + self.tool_call_missing_behavior = tool_call_missing_behavior + self.process_startup_called = False + self.validate_called = False + self.process_shutdown_called = False + + @property + def processes_limit(self): + return None + + @property + def requests_limit(self): + return None + + @property + def info(self): + return {"type": "mock_tool"} + + async def process_startup(self): + self.process_startup_called = True + + async def validate(self): + self.validate_called = True + + async def process_shutdown(self): + self.process_shutdown_called = True + + async def resolve(self, request, request_info, history=None): + response = MagicMock() + response.tool_calls = ( + [{"id": "call_1", "type": "function", "function": {"name": "fn"}}] + if self.has_tool_calls + else None + ) + + yield response, request_info + + # Replicate what the real backend does in + # _check_tool_call_expectations: raise exceptions when a tool + # call was expected but not produced. 
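+ # Behavior mapping exercised by the tests below:
+ #   ignore_continue -> no exception; the worker keeps the
+ #                      remaining turns
+ #   ignore_stop     -> asyncio.CancelledError; the worker cancels
+ #                      the current and remaining turns
+ #   error_stop      -> ValueError; the worker errors the current
+ #                      turn and cancels the rest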
+ if request.expects_tool_call and not self.has_tool_calls: + if self.tool_call_missing_behavior == "ignore_stop": + raise asyncio.CancelledError( + "Expected tool call but model produced none" + ) + elif self.tool_call_missing_behavior == "error_stop": + raise ValueError("Expected tool call but model produced none") + + +def _make_conversation( + num_turns: int, tool_call_turns: int +) -> list[tuple[Any, RequestInfo]]: + """Build a pre-planned conversation list for testing. + + ## WRITTEN BY AI ## + """ + conv = [] + for i in range(num_turns): + req = GenerationRequest( + columns={"text_column": [f"turn_{i}"]}, + expects_tool_call=(i < tool_call_turns), + ) + info = RequestInfo( + request_id=req.request_id, + conversation_id="conv_1", + turn_index=i, + status="queued", + ) + conv.append((req, info)) + return conv + + +class TestWorkerMissingToolCallBehavior: + """Test worker handling of missing tool calls for all 3 behaviors. + + ## WRITTEN BY AI ## + """ + + @pytest_asyncio.fixture + async def make_worker(self): + """Factory fixture that creates a worker with the given backend. + + ## WRITTEN BY AI ## + """ + workers = [] + + async def _factory(backend): + messaging = InterProcessMessagingQueue( + serialization="dict", + encoding=None, + max_buffer_receive_size=10, + poll_interval=0.01, + ) + await messaging.start(pydantic_models=[]) + + worker = WorkerProcess( + worker_index=0, + messaging=messaging.create_worker_copy(0), + backend=backend, + strategy=SynchronousStrategy(), + async_limit=1, + fut_scheduling_time_limit=10.0, + startup_barrier=Barrier(1), + requests_generated_event=Event(), + constraint_reached_event=Event(), + shutdown_event=Event(), + error_event=Event(), + ) + workers.append((worker, messaging)) + return worker, messaging + + yield _factory + + for _, msg in workers: + await msg.stop() + + @async_timeout(5.0) + @pytest.mark.asyncio + @pytest.mark.smoke + async def test_ignore_continue_keeps_remaining_turns(self, make_worker): + """ignore_continue: conversation continues to next turn normally. + + ## WRITTEN BY AI ## + """ + backend = _MockToolBackend( + has_tool_calls=False, + tool_call_missing_behavior="ignore_continue", + ) + worker, messaging = await make_worker(backend) + + conv = _make_conversation(num_turns=3, tool_call_turns=2) + await messaging.put(conv) + + await worker._processing_startup() + + history, remaining_conv, info = await worker._process_next_request( + target_start=time.time() + ) + + # ignore_continue: remaining turns are preserved + assert len(remaining_conv) == 2 + assert remaining_conv[0][0].expects_tool_call is True + assert remaining_conv[1][0].expects_tool_call is False + + @async_timeout(5.0) + @pytest.mark.asyncio + @pytest.mark.smoke + async def test_ignore_stop_cancels_all_turns(self, make_worker): + """ignore_stop: current turn and all remaining turns are cancelled. + + The backend raises CancelledError which the worker catches and + marks the current request as cancelled, then cancels remaining turns. 
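+
+ With a 3-turn conversation (2 tool turns), that means three
+ "cancelled" status updates in total: the failing turn plus the two
+ remaining turns.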
+ + ## WRITTEN BY AI ## + """ + backend = _MockToolBackend( + has_tool_calls=False, + tool_call_missing_behavior="ignore_stop", + ) + worker, messaging = await make_worker(backend) + + conv = _make_conversation(num_turns=3, tool_call_turns=2) + await messaging.put(conv) + + await worker._processing_startup() + + # Collect status updates + updates = [] + original_send = worker._send_update + + def capture_send(status, response, request, request_info): + updates.append((status, request_info.request_id)) + original_send(status, response, request, request_info) + + worker._send_update = capture_send + + with pytest.raises(asyncio.CancelledError): + await worker._process_next_request(target_start=time.time()) + + # Current turn cancelled + remaining 2 turns cancelled + cancelled_updates = [u for u in updates if u[0] == "cancelled"] + assert len(cancelled_updates) == 3 + + @async_timeout(5.0) + @pytest.mark.asyncio + @pytest.mark.smoke + async def test_error_stop_errors_and_cancels(self, make_worker): + """error_stop: current turn errored via ValueError, remaining cancelled. + + The backend raises ValueError which the worker catches via its + generic exception handler, setting request_info.error and sending + an "errored" status update. + + ## WRITTEN BY AI ## + """ + backend = _MockToolBackend( + has_tool_calls=False, + tool_call_missing_behavior="error_stop", + ) + worker, messaging = await make_worker(backend) + + conv = _make_conversation(num_turns=3, tool_call_turns=2) + await messaging.put(conv) + + await worker._processing_startup() + + updates = [] + original_send = worker._send_update + + def capture_send(status, response, request, request_info): + updates.append((status, request_info.request_id)) + original_send(status, response, request, request_info) + + worker._send_update = capture_send + + history, remaining_conv, info = await worker._process_next_request( + target_start=time.time() + ) + + # error_stop: conversation should be empty (cancelled in finally block) + assert len(remaining_conv) == 0 + + # Should have errored the current turn + errored_updates = [u for u in updates if u[0] == "errored"] + assert len(errored_updates) == 1 + + @async_timeout(5.0) + @pytest.mark.asyncio + @pytest.mark.sanity + async def test_tool_call_present_continues_normally(self, make_worker): + """When tool call IS present, conversation continues normally. + + ## WRITTEN BY AI ## + """ + backend = _MockToolBackend( + has_tool_calls=True, + tool_call_missing_behavior="error_stop", + ) + worker, messaging = await make_worker(backend) + + conv = _make_conversation(num_turns=3, tool_call_turns=2) + await messaging.put(conv) + + await worker._processing_startup() + + history, remaining_conv, info = await worker._process_next_request( + target_start=time.time() + ) + + # Tool call present: conversation continues with remaining turns + assert len(remaining_conv) == 2 + + @async_timeout(5.0) + @pytest.mark.asyncio + @pytest.mark.sanity + async def test_non_tool_turn_ignores_behavior(self, make_worker): + """Non-tool turns don't trigger missing-tool-call logic at all. 
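+
+ Even with the strictest behavior (error_stop) and a backend that never
+ produces tool calls, a conversation where every turn has
+ expects_tool_call=False should complete without errors or
+ cancellations.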
+ + ## WRITTEN BY AI ## + """ + backend = _MockToolBackend( + has_tool_calls=False, + tool_call_missing_behavior="error_stop", + ) + worker, messaging = await make_worker(backend) + + # Conversation with no tool call turns + conv = _make_conversation(num_turns=2, tool_call_turns=0) + await messaging.put(conv) + + await worker._processing_startup() + + updates = [] + original_send = worker._send_update + + def capture_send(status, response, request, request_info): + updates.append((status, request_info.request_id)) + original_send(status, response, request, request_info) + + worker._send_update = capture_send + + history, remaining_conv, info = await worker._process_next_request( + target_start=time.time() + ) + + # No error or cancellation -- normal flow + assert len(remaining_conv) == 1 + errored = [u for u in updates if u[0] == "errored"] + cancelled = [u for u in updates if u[0] == "cancelled"] + assert len(errored) == 0 + assert len(cancelled) == 0 + + @async_timeout(5.0) + @pytest.mark.asyncio + @pytest.mark.sanity + async def test_ignore_continue_per_turn_independence(self, make_worker): + """ignore_continue: each tool turn succeeds or fails independently. + + Process two turns sequentially. First turn misses tool call but + conversation continues. Second tool turn gets its own chance. + + ## WRITTEN BY AI ## + """ + backend = _MockToolBackend( + has_tool_calls=False, + tool_call_missing_behavior="ignore_continue", + ) + worker, messaging = await make_worker(backend) + + conv = _make_conversation(num_turns=3, tool_call_turns=2) + await messaging.put(conv) + + await worker._processing_startup() + + # Process first turn (tool turn, no tool call produced) + history, remaining_conv, _ = await worker._process_next_request( + target_start=time.time() + ) + + # Still have 2 remaining turns + assert len(remaining_conv) == 2 + + # The next turn is also a tool turn + assert remaining_conv[0][0].expects_tool_call is True + + +# --------------------------------------------------------------------------- +# Backend: tool_call_missing_behavior validation +# --------------------------------------------------------------------------- + + +class TestOpenAIBackendToolCallMissingBehavior: + """Validate tool_call_missing_behavior field on the backend. + + ## WRITTEN BY AI ## + """ + + @pytest.mark.smoke + def test_default_is_error_stop(self): + """Default tool_call_missing_behavior is error_stop. + + ## WRITTEN BY AI ## + """ + from guidellm.backends.openai.http import OpenAIHTTPBackend + + backend = OpenAIHTTPBackend(target="http://localhost:8000") + assert backend.tool_call_missing_behavior == "error_stop" + + @pytest.mark.sanity + def test_valid_behaviors_accepted(self): + """All valid tool_call_missing_behavior values are accepted. + + ## WRITTEN BY AI ## + """ + from guidellm.backends.openai.http import OpenAIHTTPBackend + + for behavior in ("ignore_continue", "ignore_stop", "error_stop"): + backend = OpenAIHTTPBackend( + target="http://localhost:8000", + tool_call_missing_behavior=behavior, + ) + assert backend.tool_call_missing_behavior == behavior + + @pytest.mark.sanity + def test_invalid_behavior_rejected(self): + """Invalid tool_call_missing_behavior is rejected by the Literal type. 
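+
+ For example, tool_call_missing_behavior="invalid_mode" should fail
+ pydantic validation instead of being silently accepted.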
+ + ## WRITTEN BY AI ## + """ + from guidellm.backends.openai.http import OpenAIHttpBackendArgs + + with pytest.raises(ValidationError): + OpenAIHttpBackendArgs( + target="http://localhost:8000", + tool_call_missing_behavior="invalid_mode", + ) + + +# --------------------------------------------------------------------------- +# Backend: check tool call expectations +# --------------------------------------------------------------------------- + + +class TestCheckToolCallExpectations: + """Verify _check_tool_call_expectations raises the right exceptions. + + ## WRITTEN BY AI ## + """ + + def _make_backend(self, behavior: str): + """ + ## WRITTEN BY AI ## + """ + from guidellm.backends.openai.http import OpenAIHTTPBackend + + return OpenAIHTTPBackend( + target="http://localhost:8000", + tool_call_missing_behavior=behavior, + ) + + def _make_request(self, expects_tool_call: bool) -> GenerationRequest: + """ + ## WRITTEN BY AI ## + """ + return GenerationRequest( + columns={"text_column": ["test"]}, + expects_tool_call=expects_tool_call, + ) + + def _make_response(self, has_tool_calls: bool): + """ + ## WRITTEN BY AI ## + """ + resp = MagicMock() + resp.tool_calls = ( + [ + StreamingToolCall( + id="call_1", + function=StreamingToolCallFunction(name="fn"), + ) + ] + if has_tool_calls + else None + ) + return resp + + @pytest.mark.smoke + def test_no_op_when_tool_call_present(self): + """No exception when the model produced a tool call. + + ## WRITTEN BY AI ## + """ + backend = self._make_backend("error_stop") + req = self._make_request(expects_tool_call=True) + resp = self._make_response(has_tool_calls=True) + + backend._check_tool_call_expectations(req, resp) + + @pytest.mark.smoke + def test_no_op_when_not_expecting_tool_call(self): + """No exception when the turn doesn't expect a tool call. + + ## WRITTEN BY AI ## + """ + backend = self._make_backend("error_stop") + req = self._make_request(expects_tool_call=False) + resp = self._make_response(has_tool_calls=False) + + backend._check_tool_call_expectations(req, resp) + + @pytest.mark.smoke + def test_ignore_continue_raises_nothing(self): + """ignore_continue: no exception even when tool call is missing. + + ## WRITTEN BY AI ## + """ + backend = self._make_backend("ignore_continue") + req = self._make_request(expects_tool_call=True) + resp = self._make_response(has_tool_calls=False) + + backend._check_tool_call_expectations(req, resp) + + @pytest.mark.smoke + def test_ignore_stop_raises_cancelled_error(self): + """ignore_stop: raises CancelledError when tool call is missing. + + ## WRITTEN BY AI ## + """ + backend = self._make_backend("ignore_stop") + req = self._make_request(expects_tool_call=True) + resp = self._make_response(has_tool_calls=False) + + with pytest.raises(asyncio.CancelledError, match="tool call"): + backend._check_tool_call_expectations(req, resp) + + @pytest.mark.smoke + def test_error_stop_raises_value_error(self): + """error_stop: raises ValueError when tool call is missing. 
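+
+ Since error_stop is the default, this is the out-of-the-box response
+ to a tool-call turn that comes back as plain text.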
+ + ## WRITTEN BY AI ## + """ + backend = self._make_backend("error_stop") + req = self._make_request(expects_tool_call=True) + resp = self._make_response(has_tool_calls=False) + + with pytest.raises(ValueError, match="tool call"): + backend._check_tool_call_expectations(req, resp) + + +# --------------------------------------------------------------------------- +# SyntheticTextDatasetConfig: tool_response_tokens validation +# --------------------------------------------------------------------------- + + +class TestSyntheticTextDatasetConfigToolResponseFields: + """Validate tool_response_tokens fields on SyntheticTextDatasetConfig. + + ## WRITTEN BY AI ## + """ + + @pytest.mark.smoke + def test_tool_response_tokens_defaults_to_none(self): + """Default config has no tool_response_tokens. + + ## WRITTEN BY AI ## + """ + config = SyntheticTextDatasetConfig(prompt_tokens=50, output_tokens=50) + assert config.tool_response_tokens is None + assert config.tool_response_tokens_stdev is None + assert config.tool_response_tokens_min is None + assert config.tool_response_tokens_max is None + + @pytest.mark.smoke + def test_tool_response_tokens_accepted_with_tool_call_turns(self): + """tool_response_tokens is valid when tool_call_turns > 0. + + ## WRITTEN BY AI ## + """ + config = SyntheticTextDatasetConfig( + prompt_tokens=50, + output_tokens=50, + turns=3, + tool_call_turns=2, + tool_response_tokens=50, + ) + assert config.tool_response_tokens == 50 + + @pytest.mark.sanity + def test_tool_response_tokens_variance_fields(self): + """All variance fields are accepted together. + + ## WRITTEN BY AI ## + """ + config = SyntheticTextDatasetConfig( + prompt_tokens=50, + output_tokens=50, + turns=3, + tool_call_turns=2, + tool_response_tokens=100, + tool_response_tokens_stdev=20, + tool_response_tokens_min=50, + tool_response_tokens_max=150, + ) + assert config.tool_response_tokens == 100 + assert config.tool_response_tokens_stdev == 20 + assert config.tool_response_tokens_min == 50 + assert config.tool_response_tokens_max == 150 + + +# --------------------------------------------------------------------------- +# Synthetic data: tool_response_{i} columns +# --------------------------------------------------------------------------- + + +class TestSyntheticDataToolResponseColumns: + """Verify synthetic data emits tool_response_{turn} columns. + + ## WRITTEN BY AI ## + """ + + @pytest.fixture + def processor(self): + """Minimal mock processor for token encoding/decoding. + + ## WRITTEN BY AI ## + """ + proc = MagicMock() + proc.encode.return_value = list(range(100)) + proc.decode.return_value = "mock text" + return proc + + @pytest.mark.smoke + def test_default_tool_response_columns_emitted(self, processor): + """When tool_response_tokens is None, placeholder responses are used. + + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + _SyntheticTextExamplesIterable, + ) + from guidellm.settings import settings + + config = SyntheticTextDatasetConfig( + prompt_tokens=10, output_tokens=10, turns=3, tool_call_turns=2 + ) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + _, row = next(iter(iterable)) + + assert row["tool_response_0"] == settings.default_synthetic_tool_response + assert row["tool_response_1"] == settings.default_synthetic_tool_response + assert "tool_response_2" not in row + + @pytest.mark.smoke + def test_variable_length_tool_response_columns(self, processor): + """When tool_response_tokens is set, generated JSON responses are used. 
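+
+ With tool_response_tokens=30, each tool_response_{i} column should
+ contain generated JSON (an object with a "result" key) instead of the
+ short placeholder from settings.default_synthetic_tool_response.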
+ + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + _SyntheticTextExamplesIterable, + ) + + config = SyntheticTextDatasetConfig( + prompt_tokens=10, + output_tokens=10, + turns=3, + tool_call_turns=2, + tool_response_tokens=30, + ) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + _, row = next(iter(iterable)) + + # Should be valid JSON with a "result" key + parsed_0 = json.loads(row["tool_response_0"]) + parsed_1 = json.loads(row["tool_response_1"]) + assert "result" in parsed_0 + assert "result" in parsed_1 + assert "tool_response_2" not in row + + @pytest.mark.sanity + def test_features_include_tool_response_columns(self, processor): + """Features property includes tool_response_{i} for tool_call_turns. + + ## WRITTEN BY AI ## + """ + from guidellm.data.deserializers.synthetic import ( + _SyntheticTextExamplesIterable, + ) + + config = SyntheticTextDatasetConfig( + prompt_tokens=10, output_tokens=10, turns=3, tool_call_turns=2 + ) + iterable = _SyntheticTextExamplesIterable(config, processor, random_seed=42) + features = iterable.features + + assert "tool_response_0" in features + assert "tool_response_1" in features + assert "tool_response_2" not in features + + +# --------------------------------------------------------------------------- +# ToolCallingMessageExtractor: tool response extraction +# --------------------------------------------------------------------------- + + +class TestToolCallingMessageExtractorToolResponses: + """Verify the extractor populates tool_response_column from messages. + + ## WRITTEN BY AI ## + """ + + @pytest.mark.smoke + def test_extracts_tool_role_content(self): + """Messages with role=tool have their content extracted. + + ## WRITTEN BY AI ## + """ + from guidellm.data.preprocessors.tool_calling import ( + ToolCallingMessageExtractor, + ) + + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Call the tool."}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "call_1", "function": {"name": "fn", "arguments": "{}"}} + ], + }, + { + "role": "tool", + "content": '{"status": "success", "data": [1, 2]}', + "tool_call_id": "call_1", + }, + {"role": "user", "content": "Thanks!"}, + ] + + items = [{"text_column": [messages]}] + extractor = ToolCallingMessageExtractor() + result = extractor(items) + + assert "tool_response_column" in result[0] + assert result[0]["tool_response_column"] == [ + '{"status": "success", "data": [1, 2]}' + ] + + @pytest.mark.sanity + def test_no_tool_responses_when_absent(self): + """When no role=tool messages exist, tool_response_column is not set. + + ## WRITTEN BY AI ## + """ + from guidellm.data.preprocessors.tool_calling import ( + ToolCallingMessageExtractor, + ) + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + items = [{"text_column": [messages]}] + extractor = ToolCallingMessageExtractor() + result = extractor(items) + + assert "tool_response_column" not in result[0] + + @pytest.mark.sanity + def test_multiple_tool_responses_extracted(self): + """Multiple role=tool messages are all extracted in order. 
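+
+ Two role=tool messages with contents '{"first": true}' and
+ '{"second": true}' should produce a tool_response_column holding
+ exactly those strings, in message order.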
+ + ## WRITTEN BY AI ## + """ + from guidellm.data.preprocessors.tool_calling import ( + ToolCallingMessageExtractor, + ) + + messages = [ + {"role": "user", "content": "Do two things."}, + {"role": "tool", "content": '{"first": true}', "tool_call_id": "c1"}, + {"role": "tool", "content": '{"second": true}', "tool_call_id": "c2"}, + ] + + items = [{"text_column": [messages]}] + extractor = ToolCallingMessageExtractor() + result = extractor(items) + + assert result[0]["tool_response_column"] == [ + '{"first": true}', + '{"second": true}', + ] + + +# --------------------------------------------------------------------------- +# Request handler: tool_response_column usage +# --------------------------------------------------------------------------- + + +class TestChatCompletionsToolResponseColumn: + """Verify request handler uses tool_response_column instead of hardcoded default. + + ## WRITTEN BY AI ## + """ + + @pytest.fixture + def handler(self): + """ + ## WRITTEN BY AI ## + """ + return ChatCompletionsRequestHandler() + + @pytest.mark.smoke + def test_uses_tool_response_from_column(self, handler): + """Tool response content from tool_response_column is used in history. + + ## WRITTEN BY AI ## + """ + from guidellm.schemas import GenerationResponse + + tools = [{"type": "function", "function": {"name": "fn"}}] + prior_request = GenerationRequest( + columns={ + "text_column": ["call the tool"], + "tools_column": [json.dumps(tools)], + "tool_response_column": ['{"result": "custom data"}'], + }, + expects_tool_call=True, + ) + prior_response = MagicMock(spec=GenerationResponse) + prior_response.tool_calls = [ + StreamingToolCall( + id="call_1", + function=StreamingToolCallFunction(name="fn"), + ) + ] + prior_response.text = None + + current_request = GenerationRequest( + columns={"text_column": ["now respond"]}, + expects_tool_call=False, + ) + + result = handler.format( + current_request, + history=[(prior_request, prior_response)], + ) + + # Find the tool role message in the history + tool_messages = [m for m in result.body["messages"] if m.get("role") == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0]["content"] == '{"result": "custom data"}' + + @pytest.mark.sanity + def test_falls_back_to_default_without_column(self, handler): + """Without tool_response_column, the default placeholder is used. + + ## WRITTEN BY AI ## + """ + from guidellm.schemas import GenerationResponse + from guidellm.settings import settings + + tools = [{"type": "function", "function": {"name": "fn"}}] + prior_request = GenerationRequest( + columns={ + "text_column": ["call the tool"], + "tools_column": [json.dumps(tools)], + }, + expects_tool_call=True, + ) + prior_response = MagicMock(spec=GenerationResponse) + prior_response.tool_calls = [ + StreamingToolCall( + id="call_1", + function=StreamingToolCallFunction(name="fn"), + ) + ] + prior_response.text = None + + current_request = GenerationRequest( + columns={"text_column": ["now respond"]}, + expects_tool_call=False, + ) + + result = handler.format( + current_request, + history=[(prior_request, prior_response)], + ) + + tool_messages = [m for m in result.body["messages"] if m.get("role") == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0]["content"] == settings.default_synthetic_tool_response + + @pytest.mark.sanity + def test_bytes_tool_response_decoded(self, handler): + """Tool response content stored as bytes (from orjson) is decoded to str. 
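+
+ Columns that round-trip through orjson can arrive as bytes (e.g.
+ b'{"result": "bytes data"}'); the handler should decode them so the
+ tool message content is always a str.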
+ + ## WRITTEN BY AI ## + """ + from guidellm.schemas import GenerationResponse + + tools = [{"type": "function", "function": {"name": "fn"}}] + prior_request = GenerationRequest( + columns={ + "text_column": ["call the tool"], + "tools_column": [json.dumps(tools)], + "tool_response_column": [b'{"result": "bytes data"}'], + }, + expects_tool_call=True, + ) + prior_response = MagicMock(spec=GenerationResponse) + prior_response.tool_calls = [ + StreamingToolCall( + id="call_1", + function=StreamingToolCallFunction(name="fn"), + ) + ] + prior_response.text = None + + current_request = GenerationRequest( + columns={"text_column": ["now respond"]}, + expects_tool_call=False, + ) + + result = handler.format( + current_request, + history=[(prior_request, prior_response)], + ) + + tool_messages = [m for m in result.body["messages"] if m.get("role") == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0]["content"] == '{"result": "bytes data"}' + assert isinstance(tool_messages[0]["content"], str) + + +# --------------------------------------------------------------------------- +# JSONL pipeline integration: multi-turn tool call columns +# --------------------------------------------------------------------------- + + +def _run_row_through_pipeline(row: dict[str, Any]) -> list[GenerationRequest]: + """Push a single dataset row through the column mapper and finalizer. + + ## WRITTEN BY AI ## + """ + from datasets import Dataset + + from guidellm.data.preprocessors.mappers import GenerativeColumnMapper + + dataset = Dataset.from_dict({k: [v] for k, v in row.items()}) + + mapper = GenerativeColumnMapper() + mapper.setup_data([dataset], [{}]) + + finalizer = GenerativeRequestFinalizer() + mapped_turns = mapper([{"dataset": row}]) + return finalizer(mapped_turns) + + +class TestJsonlMultiTurnToolCallPipeline: + """Integration tests: JSONL with turn-indexed tool columns through the + full mapper-to-finalizer pipeline. + + ## WRITTEN BY AI ## + """ + + @pytest.mark.smoke + def test_consecutive_tool_turns(self): + """3-turn JSONL: turns 0-1 are tool calls, turn 2 is plain text. + + ## WRITTEN BY AI ## + """ + row = { + "prompt_0": "Call the weather tool", + "output_tokens_count_0": 50, + "tools_0": '[{"type": "function", "function": {"name": "get_weather"}}]', + "tool_response_0": '{"temp": 72}', + "prompt_1": "Now call the stock tool", + "output_tokens_count_1": 50, + "tools_1": '[{"type": "function", "function": {"name": "get_stock"}}]', + "tool_response_1": '{"price": 150}', + "prompt_2": "Summarize everything", + "output_tokens_count_2": 100, + } + + requests = _run_row_through_pipeline(row) + + assert len(requests) == 3 + + assert requests[0].expects_tool_call is True + assert "tools_column" in requests[0].columns + assert requests[0].columns["tool_response_column"] == ['{"temp": 72}'] + + assert requests[1].expects_tool_call is True + assert "tools_column" in requests[1].columns + assert requests[1].columns["tool_response_column"] == ['{"price": 150}'] + + assert requests[2].expects_tool_call is False + assert "tools_column" not in requests[2].columns + assert "tool_response_column" not in requests[2].columns + + @pytest.mark.smoke + def test_interleaved_tool_turns(self): + """4-turn JSONL: tool calls on turns 0 and 3, plain text on 1 and 2. 
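+
+ Only turns 0 and 3 carry tools_{i}/tool_response_{i} columns, so only
+ those requests should have expects_tool_call=True and the tool columns
+ mapped through; turns 1 and 2 get neither.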
+ + ## WRITTEN BY AI ## + """ + row = { + "prompt_0": "Look up the weather", + "output_tokens_count_0": 50, + "tools_0": '[{"type": "function", "function": {"name": "get_weather"}}]', + "tool_response_0": '{"temp": 72}', + "prompt_1": "Tell me about it", + "output_tokens_count_1": 60, + "prompt_2": "Any other thoughts?", + "output_tokens_count_2": 60, + "prompt_3": "Now check stocks", + "output_tokens_count_3": 50, + "tools_3": '[{"type": "function", "function": {"name": "get_stock"}}]', + "tool_response_3": '{"price": 150}', + } + + requests = _run_row_through_pipeline(row) + + assert len(requests) == 4 + + assert requests[0].expects_tool_call is True + assert "tools_column" in requests[0].columns + assert "tool_response_column" in requests[0].columns + + assert requests[1].expects_tool_call is False + assert "tools_column" not in requests[1].columns + assert "tool_response_column" not in requests[1].columns + + assert requests[2].expects_tool_call is False + assert "tools_column" not in requests[2].columns + assert "tool_response_column" not in requests[2].columns + + assert requests[3].expects_tool_call is True + assert "tools_column" in requests[3].columns + assert "tool_response_column" in requests[3].columns