diff --git a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py index 143f2499a0..5c9646e5a3 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py @@ -13,9 +13,10 @@ RecipeConfirmationStrategy, TaskPlannerConfirmationStrategy, ) -from ._endpoint import add_agent_framework_fastapi_endpoint +from ._endpoint import DEFAULT_TAGS, add_agent_framework_fastapi_endpoint from ._event_converters import AGUIEventConverter from ._http_service import AGUIHttpService +from ._types import AGUIRequest try: __version__ = importlib.metadata.version(__name__) @@ -24,11 +25,13 @@ __all__ = [ "AgentFrameworkAgent", + "AGUIRequest", "add_agent_framework_fastapi_endpoint", "AGUIChatClient", "AGUIEventConverter", "AGUIHttpService", "ConfirmationStrategy", + "DEFAULT_TAGS", "DefaultConfirmationStrategy", "TaskPlannerConfirmationStrategy", "RecipeConfirmationStrategy", diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py index d1baad5561..79b812b90f 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py @@ -3,17 +3,20 @@ """FastAPI endpoint creation for AG-UI agents.""" import logging -from typing import Any +from typing import Any, Sequence from ag_ui.encoder import EventEncoder from agent_framework import AgentProtocol -from fastapi import FastAPI, Request +from fastapi import FastAPI from fastapi.responses import StreamingResponse from ._agent import AgentFrameworkAgent +from ._types import AGUIRequest logger = logging.getLogger(__name__) +DEFAULT_TAGS: Sequence[str] = ["AG-UI"] + def add_agent_framework_fastapi_endpoint( app: FastAPI, @@ -22,17 +25,20 @@ def add_agent_framework_fastapi_endpoint( state_schema: dict[str, Any] | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, allow_origins: list[str] | None = None, + tags: Sequence[str] | None = None, ) -> None: """Add an AG-UI endpoint to a FastAPI app. Args: - app: The FastAPI application - agent: The agent to expose (can be raw AgentProtocol or wrapped) - path: The endpoint path - state_schema: Optional state schema for shared state management + app: The FastAPI application. + agent: The agent to expose (can be raw AgentProtocol or wrapped). + path: The endpoint path. + state_schema: Optional state schema for shared state management. predict_state_config: Optional predictive state update configuration. Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}} - allow_origins: CORS origins (not yet implemented) + allow_origins: CORS origins (not yet implemented). + tags: Optional list of tags for OpenAPI documentation grouping. + Defaults to ["AG-UI"]. """ if isinstance(agent, AgentProtocol): wrapped_agent = AgentFrameworkAgent( @@ -43,15 +49,17 @@ def add_agent_framework_fastapi_endpoint( else: wrapped_agent = agent - @app.post(path) - async def agent_endpoint(request: Request): # type: ignore[misc] + endpoint_tags: list[str] = list(tags) if tags is not None else list(DEFAULT_TAGS) + + @app.post(path, tags=endpoint_tags) # type: ignore[arg-type] + async def agent_endpoint(request_body: AGUIRequest): # type: ignore[misc] """Handle AG-UI agent requests. Note: Function is accessed via FastAPI's decorator registration, despite appearing unused to static analysis. """ try: - input_data = await request.json() + input_data = request_body.model_dump(exclude_none=True) logger.debug( f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, " f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, " diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py index da7d80ea66..84d31cee3c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_types.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py @@ -4,6 +4,8 @@ from typing import Any, TypedDict +from pydantic import BaseModel, Field + class PredictStateConfig(TypedDict): """Configuration for predictive state updates.""" @@ -25,3 +27,36 @@ class AgentState(TypedDict): """Base state for AG-UI agents.""" messages: list[Any] | None + + +class AGUIRequest(BaseModel): + """AG-UI request body schema for FastAPI endpoint. + + This model defines the structure of incoming requests to AG-UI endpoints, + providing proper OpenAPI schema generation for Swagger UI documentation. + + Attributes: + messages: List of AG-UI format messages for the conversation. + run_id: Optional identifier for the current run. + thread_id: Optional identifier for the conversation thread. + state: Optional shared state dictionary for agentic generative UI. + """ + + messages: list[Any] = Field( + default_factory=list, + description="AG-UI format messages for the conversation", + ) + run_id: str | None = Field( + default=None, + description="Optional identifier for the current run", + ) + thread_id: str | None = Field( + default=None, + description="Optional identifier for the conversation thread", + ) + state: dict[str, Any] | None = Field( + default=None, + description="Optional shared state for agentic generative UI", + ) + + model_config = {"extra": "allow"} diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py index 1ae364f818..a56e780732 100644 --- a/python/packages/ag-ui/tests/test_endpoint.py +++ b/python/packages/ag-ui/tests/test_endpoint.py @@ -11,7 +11,7 @@ from fastapi.testclient import TestClient from agent_framework_ag_ui._agent import AgentFrameworkAgent -from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint +from agent_framework_ag_ui._endpoint import DEFAULT_TAGS, add_agent_framework_fastapi_endpoint class MockChatClient: @@ -139,7 +139,11 @@ async def test_endpoint_event_streaming(): async def test_endpoint_error_handling(): - """Test endpoint error handling during request parsing.""" + """Test endpoint error handling during request parsing. + + With Pydantic model validation, FastAPI returns 422 Unprocessable Entity + for invalid request bodies, which is the correct HTTP semantics. + """ app = FastAPI() agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) @@ -147,14 +151,13 @@ async def test_endpoint_error_handling(): client = TestClient(app) - # Send invalid JSON to trigger parsing error before streaming + # Send invalid JSON to trigger validation error response = client.post("/failing", data="invalid json", headers={"content-type": "application/json"}) - # The exception handler catches it and returns JSON error - assert response.status_code == 200 + # FastAPI returns 422 for validation errors with Pydantic models + assert response.status_code == 422 content = json.loads(response.content) - assert "error" in content - assert content["error"] == "An internal error has occurred." + assert "detail" in content async def test_endpoint_multiple_paths(): @@ -240,3 +243,98 @@ async def test_endpoint_complex_input(): ) assert response.status_code == 200 + + +async def test_endpoint_default_tags(): + """Test that endpoint uses default tags when not specified.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + + add_agent_framework_fastapi_endpoint(app, agent, path="/default-tags") + + # Check OpenAPI schema for default tags + openapi_schema = app.openapi() + path_item = openapi_schema["paths"]["/default-tags"]["post"] + + assert "tags" in path_item + assert path_item["tags"] == DEFAULT_TAGS + + +async def test_endpoint_custom_tags(): + """Test that endpoint uses custom tags when specified.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + custom_tags = ["Custom", "Agent"] + + add_agent_framework_fastapi_endpoint(app, agent, path="/custom-tags", tags=custom_tags) + + # Check OpenAPI schema for custom tags + openapi_schema = app.openapi() + path_item = openapi_schema["paths"]["/custom-tags"]["post"] + + assert "tags" in path_item + assert path_item["tags"] == custom_tags + + +async def test_endpoint_openapi_schema_includes_request_body(): + """Test that endpoint OpenAPI schema includes request body definition.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + + add_agent_framework_fastapi_endpoint(app, agent, path="/schema-test") + + # Check OpenAPI schema for request body + openapi_schema = app.openapi() + path_item = openapi_schema["paths"]["/schema-test"]["post"] + + assert "requestBody" in path_item + assert "content" in path_item["requestBody"] + assert "application/json" in path_item["requestBody"]["content"] + + # Verify schema reference exists + json_content = path_item["requestBody"]["content"]["application/json"] + assert "schema" in json_content + + +async def test_endpoint_openapi_schema_has_agui_request_properties(): + """Test that OpenAPI schema includes AGUIRequest properties.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient()) + + add_agent_framework_fastapi_endpoint(app, agent, path="/properties-test") + + openapi_schema = app.openapi() + + # Find the AGUIRequest schema in components + assert "components" in openapi_schema + assert "schemas" in openapi_schema["components"] + assert "AGUIRequest" in openapi_schema["components"]["schemas"] + + agui_schema = openapi_schema["components"]["schemas"]["AGUIRequest"] + assert "properties" in agui_schema + assert "messages" in agui_schema["properties"] + assert "run_id" in agui_schema["properties"] + assert "thread_id" in agui_schema["properties"] + assert "state" in agui_schema["properties"] + + +async def test_endpoint_multiple_agents_different_tags(): + """Test multiple agents with different tags.""" + app = FastAPI() + agent1 = ChatAgent(name="agent1", instructions="First agent", chat_client=MockChatClient()) + agent2 = ChatAgent(name="agent2", instructions="Second agent", chat_client=MockChatClient()) + + add_agent_framework_fastapi_endpoint(app, agent1, path="/agent1", tags=["Agent1"]) + add_agent_framework_fastapi_endpoint(app, agent2, path="/agent2", tags=["Agent2"]) + + openapi_schema = app.openapi() + + assert openapi_schema["paths"]["/agent1"]["post"]["tags"] == ["Agent1"] + assert openapi_schema["paths"]["/agent2"]["post"]["tags"] == ["Agent2"] + + +def test_default_tags_constant(): + """Test that DEFAULT_TAGS constant is correct.""" + assert DEFAULT_TAGS == ["AG-UI"] + assert isinstance(DEFAULT_TAGS, list) + assert len(DEFAULT_TAGS) == 1 diff --git a/python/packages/ag-ui/tests/test_types.py b/python/packages/ag-ui/tests/test_types.py index 3c61278d9e..5305354f68 100644 --- a/python/packages/ag-ui/tests/test_types.py +++ b/python/packages/ag-ui/tests/test_types.py @@ -2,7 +2,7 @@ """Tests for type definitions in _types.py.""" -from agent_framework_ag_ui._types import AgentState, PredictStateConfig, RunMetadata +from agent_framework_ag_ui._types import AgentState, AGUIRequest, PredictStateConfig, RunMetadata class TestPredictStateConfig: @@ -143,3 +143,109 @@ def test_agent_state_complex_messages(self) -> None: assert len(state["messages"]) == 2 assert "metadata" in state["messages"][0] assert "tool_calls" in state["messages"][1] + + +class TestAGUIRequest: + """Test AGUIRequest Pydantic model.""" + + def test_agui_request_creation_with_defaults(self) -> None: + """Test creating AGUIRequest with default values.""" + request = AGUIRequest() + + assert request.messages == [] + assert request.run_id is None + assert request.thread_id is None + assert request.state is None + + def test_agui_request_with_messages(self) -> None: + """Test AGUIRequest with messages.""" + request = AGUIRequest( + messages=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + ) + + assert len(request.messages) == 2 + assert request.messages[0]["role"] == "user" + assert request.messages[1]["role"] == "assistant" + + def test_agui_request_with_all_fields(self) -> None: + """Test AGUIRequest with all fields populated.""" + request = AGUIRequest( + messages=[{"role": "user", "content": "Test"}], + run_id="run-123", + thread_id="thread-456", + state={"document": "Hello world"}, + ) + + assert len(request.messages) == 1 + assert request.run_id == "run-123" + assert request.thread_id == "thread-456" + assert request.state == {"document": "Hello world"} + + def test_agui_request_model_dump(self) -> None: + """Test AGUIRequest model_dump method.""" + request = AGUIRequest( + messages=[{"role": "user", "content": "Test"}], + run_id="run-123", + ) + + data = request.model_dump() + assert "messages" in data + assert "run_id" in data + assert "thread_id" in data + assert "state" in data + + def test_agui_request_model_dump_exclude_none(self) -> None: + """Test AGUIRequest model_dump with exclude_none.""" + request = AGUIRequest( + messages=[{"role": "user", "content": "Test"}], + run_id="run-123", + ) + + data = request.model_dump(exclude_none=True) + assert "messages" in data + assert "run_id" in data + assert "thread_id" not in data + assert "state" not in data + + def test_agui_request_allows_extra_fields(self) -> None: + """Test AGUIRequest allows extra fields.""" + request = AGUIRequest( + messages=[], + custom_field="custom_value", + ) + + data = request.model_dump() + assert data["custom_field"] == "custom_value" + + def test_agui_request_from_dict(self) -> None: + """Test creating AGUIRequest from dict.""" + data = { + "messages": [{"role": "user", "content": "Hello"}], + "run_id": "run-789", + "thread_id": "thread-012", + "state": {"key": "value"}, + } + + request = AGUIRequest(**data) + + assert len(request.messages) == 1 + assert request.run_id == "run-789" + assert request.thread_id == "thread-012" + assert request.state == {"key": "value"} + + def test_agui_request_json_schema(self) -> None: + """Test AGUIRequest generates proper JSON schema.""" + schema = AGUIRequest.model_json_schema() + + assert "properties" in schema + assert "messages" in schema["properties"] + assert "run_id" in schema["properties"] + assert "thread_id" in schema["properties"] + assert "state" in schema["properties"] + + # Verify descriptions are present + assert "description" in schema["properties"]["messages"] + assert "description" in schema["properties"]["run_id"]