Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/packages/ag-ui/agent_framework_ag_ui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
RecipeConfirmationStrategy,
TaskPlannerConfirmationStrategy,
)
from ._endpoint import add_agent_framework_fastapi_endpoint
from ._endpoint import DEFAULT_TAGS, add_agent_framework_fastapi_endpoint
from ._event_converters import AGUIEventConverter
from ._http_service import AGUIHttpService
from ._types import AGUIRequest

try:
__version__ = importlib.metadata.version(__name__)
Expand All @@ -24,11 +25,13 @@

__all__ = [
"AgentFrameworkAgent",
"AGUIRequest",
"add_agent_framework_fastapi_endpoint",
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"ConfirmationStrategy",
"DEFAULT_TAGS",
"DefaultConfirmationStrategy",
"TaskPlannerConfirmationStrategy",
"RecipeConfirmationStrategy",
Expand Down
28 changes: 18 additions & 10 deletions python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
"""FastAPI endpoint creation for AG-UI agents."""

import logging
from typing import Any
from typing import Any, Sequence

from ag_ui.encoder import EventEncoder
from agent_framework import AgentProtocol
from fastapi import FastAPI, Request
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from ._agent import AgentFrameworkAgent
from ._types import AGUIRequest

logger = logging.getLogger(__name__)

DEFAULT_TAGS: Sequence[str] = ["AG-UI"]


def add_agent_framework_fastapi_endpoint(
app: FastAPI,
Expand All @@ -22,17 +25,20 @@ def add_agent_framework_fastapi_endpoint(
state_schema: dict[str, Any] | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
allow_origins: list[str] | None = None,
tags: Sequence[str] | None = None,
) -> None:
"""Add an AG-UI endpoint to a FastAPI app.

Args:
app: The FastAPI application
agent: The agent to expose (can be raw AgentProtocol or wrapped)
path: The endpoint path
state_schema: Optional state schema for shared state management
app: The FastAPI application.
agent: The agent to expose (can be raw AgentProtocol or wrapped).
path: The endpoint path.
state_schema: Optional state schema for shared state management.
predict_state_config: Optional predictive state update configuration.
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
allow_origins: CORS origins (not yet implemented)
allow_origins: CORS origins (not yet implemented).
tags: Optional list of tags for OpenAPI documentation grouping.
Defaults to ["AG-UI"].
"""
if isinstance(agent, AgentProtocol):
wrapped_agent = AgentFrameworkAgent(
Expand All @@ -43,15 +49,17 @@ def add_agent_framework_fastapi_endpoint(
else:
wrapped_agent = agent

@app.post(path)
async def agent_endpoint(request: Request): # type: ignore[misc]
endpoint_tags: list[str] = list(tags) if tags is not None else list(DEFAULT_TAGS)

@app.post(path, tags=endpoint_tags) # type: ignore[arg-type]
async def agent_endpoint(request_body: AGUIRequest): # type: ignore[misc]
"""Handle AG-UI agent requests.

Note: Function is accessed via FastAPI's decorator registration,
despite appearing unused to static analysis.
"""
try:
input_data = await request.json()
input_data = request_body.model_dump(exclude_none=True)
logger.debug(
f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, "
f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, "
Expand Down
35 changes: 35 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from typing import Any, TypedDict

from pydantic import BaseModel, Field


class PredictStateConfig(TypedDict):
"""Configuration for predictive state updates."""
Expand All @@ -25,3 +27,36 @@ class AgentState(TypedDict):
"""Base state for AG-UI agents."""

messages: list[Any] | None


class AGUIRequest(BaseModel):
"""AG-UI request body schema for FastAPI endpoint.

This model defines the structure of incoming requests to AG-UI endpoints,
providing proper OpenAPI schema generation for Swagger UI documentation.

Attributes:
messages: List of AG-UI format messages for the conversation.
run_id: Optional identifier for the current run.
thread_id: Optional identifier for the conversation thread.
state: Optional shared state dictionary for agentic generative UI.
"""

messages: list[Any] = Field(
default_factory=list,
description="AG-UI format messages for the conversation",
)
run_id: str | None = Field(
default=None,
description="Optional identifier for the current run",
)
thread_id: str | None = Field(
default=None,
description="Optional identifier for the conversation thread",
)
state: dict[str, Any] | None = Field(
default=None,
description="Optional shared state for agentic generative UI",
)

model_config = {"extra": "allow"}
112 changes: 105 additions & 7 deletions python/packages/ag-ui/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fastapi.testclient import TestClient

from agent_framework_ag_ui._agent import AgentFrameworkAgent
from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint
from agent_framework_ag_ui._endpoint import DEFAULT_TAGS, add_agent_framework_fastapi_endpoint


class MockChatClient:
Expand Down Expand Up @@ -139,22 +139,25 @@ async def test_endpoint_event_streaming():


async def test_endpoint_error_handling():
"""Test endpoint error handling during request parsing."""
"""Test endpoint error handling during request parsing.

With Pydantic model validation, FastAPI returns 422 Unprocessable Entity
for invalid request bodies, which is the correct HTTP semantics.
"""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())

add_agent_framework_fastapi_endpoint(app, agent, path="/failing")

client = TestClient(app)

# Send invalid JSON to trigger parsing error before streaming
# Send invalid JSON to trigger validation error
response = client.post("/failing", data="invalid json", headers={"content-type": "application/json"})

# The exception handler catches it and returns JSON error
assert response.status_code == 200
# FastAPI returns 422 for validation errors with Pydantic models
assert response.status_code == 422
content = json.loads(response.content)
assert "error" in content
assert content["error"] == "An internal error has occurred."
assert "detail" in content


async def test_endpoint_multiple_paths():
Expand Down Expand Up @@ -240,3 +243,98 @@ async def test_endpoint_complex_input():
)

assert response.status_code == 200


async def test_endpoint_default_tags():
"""Test that endpoint uses default tags when not specified."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())

add_agent_framework_fastapi_endpoint(app, agent, path="/default-tags")

# Check OpenAPI schema for default tags
openapi_schema = app.openapi()
path_item = openapi_schema["paths"]["/default-tags"]["post"]

assert "tags" in path_item
assert path_item["tags"] == DEFAULT_TAGS


async def test_endpoint_custom_tags():
"""Test that endpoint uses custom tags when specified."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
custom_tags = ["Custom", "Agent"]

add_agent_framework_fastapi_endpoint(app, agent, path="/custom-tags", tags=custom_tags)

# Check OpenAPI schema for custom tags
openapi_schema = app.openapi()
path_item = openapi_schema["paths"]["/custom-tags"]["post"]

assert "tags" in path_item
assert path_item["tags"] == custom_tags


async def test_endpoint_openapi_schema_includes_request_body():
"""Test that endpoint OpenAPI schema includes request body definition."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())

add_agent_framework_fastapi_endpoint(app, agent, path="/schema-test")

# Check OpenAPI schema for request body
openapi_schema = app.openapi()
path_item = openapi_schema["paths"]["/schema-test"]["post"]

assert "requestBody" in path_item
assert "content" in path_item["requestBody"]
assert "application/json" in path_item["requestBody"]["content"]

# Verify schema reference exists
json_content = path_item["requestBody"]["content"]["application/json"]
assert "schema" in json_content


async def test_endpoint_openapi_schema_has_agui_request_properties():
"""Test that OpenAPI schema includes AGUIRequest properties."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())

add_agent_framework_fastapi_endpoint(app, agent, path="/properties-test")

openapi_schema = app.openapi()

# Find the AGUIRequest schema in components
assert "components" in openapi_schema
assert "schemas" in openapi_schema["components"]
assert "AGUIRequest" in openapi_schema["components"]["schemas"]

agui_schema = openapi_schema["components"]["schemas"]["AGUIRequest"]
assert "properties" in agui_schema
assert "messages" in agui_schema["properties"]
assert "run_id" in agui_schema["properties"]
assert "thread_id" in agui_schema["properties"]
assert "state" in agui_schema["properties"]


async def test_endpoint_multiple_agents_different_tags():
"""Test multiple agents with different tags."""
app = FastAPI()
agent1 = ChatAgent(name="agent1", instructions="First agent", chat_client=MockChatClient())
agent2 = ChatAgent(name="agent2", instructions="Second agent", chat_client=MockChatClient())

add_agent_framework_fastapi_endpoint(app, agent1, path="/agent1", tags=["Agent1"])
add_agent_framework_fastapi_endpoint(app, agent2, path="/agent2", tags=["Agent2"])

openapi_schema = app.openapi()

assert openapi_schema["paths"]["/agent1"]["post"]["tags"] == ["Agent1"]
assert openapi_schema["paths"]["/agent2"]["post"]["tags"] == ["Agent2"]


def test_default_tags_constant():
"""Test that DEFAULT_TAGS constant is correct."""
assert DEFAULT_TAGS == ["AG-UI"]
assert isinstance(DEFAULT_TAGS, list)
assert len(DEFAULT_TAGS) == 1
Loading