Feature/message_filter #687

Draft: wants to merge 4 commits into main

3 changes: 3 additions & 0 deletions .gitignore
@@ -143,3 +143,6 @@ cython_debug/
# PyPI configuration file
.pypirc
.aider*

# VSCode Local history
.history
2 changes: 1 addition & 1 deletion src/agents/models/openai_chatcompletions.py
@@ -252,7 +252,7 @@ async def _fetch_response(
stream_options=self._non_null_or_not_given(stream_options),
store=self._non_null_or_not_given(store),
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
- extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
+ extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
metadata=self._non_null_or_not_given(model_settings.metadata),
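The change above is only a formatting cleanup; the merge semantics are unchanged: caller-supplied headers override the SDK defaults because later `**` unpacking wins. A minimal sketch of that behavior (the `HEADERS` value here is a stand-in, not the SDK's actual constant):

```python
HEADERS = {"User-Agent": "Agents/Python"}  # stand-in for the SDK default headers

def merge_headers(extra_headers: dict[str, str] | None) -> dict[str, str]:
    # Later unpacking overrides earlier keys, so caller values win.
    return {**HEADERS, **(extra_headers or {})}

assert merge_headers({"User-Agent": "custom"})["User-Agent"] == "custom"
assert merge_headers(None) == HEADERS
```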
94 changes: 94 additions & 0 deletions src/agents/run.py
@@ -2,7 +2,9 @@

import asyncio
import copy
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from inspect import iscoroutinefunction
from typing import Any, cast

from openai.types.responses import ResponseCompletedEvent
@@ -70,6 +72,19 @@ class RunConfig:
agent. See the documentation in `Handoff.input_filter` for more details.
"""

run_step_input_filter: (
Callable[[str | list[TResponseInputItem]], str | list[TResponseInputItem]] | None
) = None
"""A global input filter to apply between agent steps. If set, the input to the agent will be
passed through this function before being sent to the model. This is useful for modifying the
input to the model, for example, to manage the context window size."""

run_step_input_filter_raise_error: bool = False
"""What to do if the input filter raises an exception. If False (the default), we'll continue
with the original input. If True, we'll raise the exception. This is useful for debugging, but
generally you want to set this to False so that the agent can continue running even if
the input filter fails."""

input_guardrails: list[InputGuardrail[Any]] | None = None
"""A list of input guardrails to run on the initial run input."""

@@ -214,6 +229,12 @@ async def run(
f"Running agent {current_agent.name} (turn {current_turn})",
)

original_input = await cls._run_step_input_filter(
original_input=original_input,
run_config=run_config,
span=current_span,
)

if current_turn == 1:
input_guardrail_results, turn_result = await asyncio.gather(
cls._run_input_guardrails(
@@ -546,6 +567,10 @@ async def _run_streamed_impl(
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
break

streamed_result.input = await cls._run_step_input_filter(
original_input=streamed_result.input, run_config=run_config, span=current_span
)

if current_turn == 1:
# Run the input guardrails in the background and put the results on the queue
streamed_result._input_guardrails_task = asyncio.create_task(
@@ -966,3 +991,72 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
return agent.model

return run_config.model_provider.get_model(agent.model)

@classmethod
async def _run_step_input_filter(
cls,
original_input: str | list[TResponseInputItem],
run_config: RunConfig,
span: Span[AgentSpanData],
) -> str | list[TResponseInputItem]:
input_filter = run_config.run_step_input_filter
_raise = run_config.run_step_input_filter_raise_error

def is_acceptable_response(response: object) -> bool:
# Minimal structural check for str | list[TResponseInputItem]:
# items are typed dicts at runtime, so require dicts with a "type" key.
return isinstance(response, str) or (
isinstance(response, Iterable)
and all(isinstance(item, dict) and "type" in item for item in response)
)

if not input_filter:
return original_input

if not callable(input_filter):
_error_tracing.attach_error_to_span(
span,
SpanError(
message="Input step filter is not callable",
data={"input_step_filter": filter},
),
)
if _raise:
raise ModelBehaviorError("Input step filter is not callable")
return original_input
try:
if iscoroutinefunction(input_filter):
input_filter_response = await input_filter(original_input)
else:
input_filter_response = input_filter(original_input)
except Exception as e:
_error_tracing.attach_error_to_span(
span,
SpanError(
message="Input step filter raised an exception",
data={
"input_step_filter": filter,
"exception": str(e),
},
),
)
if _raise:
raise ModelBehaviorError("Input step filter raised an exception") from e
return original_input

if not is_acceptable_response(input_filter_response):
_error_tracing.attach_error_to_span(
span,
SpanError(
message=(
"Input step filter did not return a string or list of ResponseInputItems"
),
data={"input_step_filter": filter, "response": input_filter_response},
),
)
if _raise:
raise ModelBehaviorError("Input step filter did not return a string or list")
return original_input

return input_filter_response
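Since the helper checks `iscoroutinefunction`, the filter may also be async. A sketch of an async filter whose return value passes the `is_acceptable_response` check above (the summarizer is a stand-in, not part of the SDK):

```python
from agents import RunConfig, TResponseInputItem

async def summarize(items: list[TResponseInputItem]) -> str:
    # Stand-in summarizer; real code might call a cheap model here.
    return f"[{len(items)} earlier items elided]"

async def compact_history(
    items: str | list[TResponseInputItem],
) -> str | list[TResponseInputItem]:
    # Short histories pass through; longer ones collapse the older items
    # into a single synthetic user message.
    if isinstance(items, str) or len(items) <= 10:
        return items
    summary = await summarize(items[:-10])
    # The "type" key matters: is_acceptable_response() requires it on
    # every returned item.
    return [{"type": "message", "role": "user", "content": summary}, *items[-10:]]

run_config = RunConfig(run_step_input_filter=compact_history)
```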
2 changes: 2 additions & 0 deletions src/agents/voice/model.py
@@ -17,9 +17,11 @@
TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
"""Exportable type for the TTSModelSettings voice enum"""


@dataclass
class TTSModelSettings:
"""Settings for a TTS model."""

voice: TTSVoice | None = None
"""
The voice to use for the TTS model. If not provided, the default voice for the respective model
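For context, the dataclass above might be constructed like this (assuming `TTSModelSettings` is re-exported from `agents.voice`):

```python
from agents.voice import TTSModelSettings

# Any TTSVoice literal works here, e.g. "alloy", "nova", or "sage".
settings = TTSModelSettings(voice="nova")
```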
9 changes: 4 additions & 5 deletions tests/test_extra_headers.py
@@ -17,21 +17,21 @@ class DummyResponses:
async def create(self, **kwargs):
nonlocal called_kwargs
called_kwargs = kwargs

class DummyResponse:
id = "dummy"
output = []
usage = type(
"Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
)()

return DummyResponse()

class DummyClient:
def __init__(self):
self.responses = DummyResponses()



- model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient())  # type: ignore
extra_headers = {"X-Test-Header": "test-value"}
await model.get_response(
system_instructions=None,
@@ -47,7 +47,6 @@ def __init__(self):
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"



@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_openai_client():
@@ -76,7 +75,7 @@ def __init__(self):
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
self.base_url = "https://api.openai.com"

- model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
+ model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient())  # type: ignore
extra_headers = {"X-Test-Header": "test-value"}
await model.get_response(
system_instructions=None,