Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
117 changes: 53 additions & 64 deletions samples/python/agents/headless_agent_auth/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import json
import logging
import os
import sys
Expand All @@ -16,15 +15,8 @@
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import (
AgentAuthentication,
AgentCapabilities,
AgentCard,
AgentSkill,
# ClientCredentialsOAuthFlow,
# OAuth2SecurityScheme,
# OAuthFlows,
)
from agent import HRAgent
from agent_executor import HRAgentExecutor
from api import hr_api
from oauth2_middleware import OAuth2Middleware
Expand All @@ -35,11 +27,14 @@


@click.command()
@click.option('--host', default='0.0.0.0')
@click.option('--host', default='0.0.0.0') # noqa: S104
@click.option('--port_agent', default=10050)
@click.option('--port_api', default=10051)
def main(host: str, port_agent: int, port_api: int):
async def run_all():
def main(host: str, port_agent: int, port_api: int) -> None:
"""Run the HR Agent and the HR API."""

async def run_all() -> None:
"""Run all components concurrently."""
await asyncio.gather(
start_agent(host, port_agent),
start_api(host, port_api),
Expand All @@ -48,56 +43,46 @@ async def run_all():
asyncio.run(run_all())


async def start_agent(host: str, port):
agent_card = AgentCard(
name='Staff0 HR Agent',
description='This agent handles external verification requests about Staff0 employees made by third parties.',
url=f'http://{host}:{port}/',
version='0.1.0',
default_input_modes=HRAgent.SUPPORTED_CONTENT_TYPES,
default_output_modes=HRAgent.SUPPORTED_CONTENT_TYPES,
capabilities=AgentCapabilities(streaming=True),
skills=[
AgentSkill(
id='is_active_employee',
name='Check Employment Status Tool',
description='Confirm whether a person is an active employee of the company.',
tags=['employment status'],
examples=[
'Does John Doe with email jdoe@staff0.com work at Staff0?'
],
)
async def start_agent(host: str, port: int) -> None:
"""Start the HR Agent server."""
# We define the configuration as a raw dictionary first.
# This avoids the "no attribute root" error by letting the AgentCard
# constructor handle the internal Pydantic mapping itself.
card_config = {
'name': 'Staff0 HR Agent',
'description': 'This agent handles external verification requests...',
'url': f'http://{host}:{port}/',
'version': '0.1.0',
'default_input_modes': ['application/json'],
'default_output_modes': ['application/json'],
'capabilities': {'streaming': True},
'skills': [
{
'id': 'is_active_employee',
'name': 'Check Employment Status Tool',
'description': 'Confirm whether a person is an active employee.',
'tags': ['employment status'],
'examples': ['Does John Doe work at Staff0?'],
}
],
authentication=AgentAuthentication(
schemes=['oauth2'],
credentials=json.dumps(
{
'tokenUrl': f'https://{os.getenv("HR_AUTH0_DOMAIN")}/oauth/token',
'scopes': {
'read:employee_status': 'Allows confirming whether a person is an active employee of the company.'
},
}
),
),
# security_schemes={
# 'oauth2_m2m_client': OAuth2SecurityScheme(
# description='',
# flows=OAuthFlows(
# authorization_code=ClientCredentialsOAuthFlow(
# token_url=f'https://{os.getenv("HR_AUTH0_DOMAIN")}/oauth/token',
# scopes={
# 'read:employee_status': 'Allows confirming whether a person is an active employee of the company.',
# },
# ),
# ),
# ),
# },
# security=[{
# 'oauth2_m2m_client': [
# 'read:employee_status',
# ],
# }],
)
'security_schemes': {
'oauth2_m2m': {
'type': 'oauth2',
'flows': {
'client_credentials': {
'token_url': f'https://{os.getenv("HR_AUTH0_DOMAIN")}/oauth/token',
'scopes': {'read:employee_status': 'Verify status'},
}
},
}
},
'security': [{'oauth2_m2m': ['read:employee_status']}],
}

# Now, pass the WHOLE dictionary into the constructor.
# The SDK will convert the nested dicts into the proper
# OAuth2SecurityScheme objects internally.
agent_card = AgentCard(**card_config)

request_handler = DefaultRequestHandler(
agent_executor=HRAgentExecutor(),
Expand All @@ -112,15 +97,19 @@ async def start_agent(host: str, port):
app.add_middleware(
OAuth2Middleware,
agent_card=agent_card,
public_paths=['/.well-known/agent.json'],
public_paths=[
'/.well-known/agent.json',
'/.well-known/agent-card.json',
],
)

logger.info(f'Starting HR Agent server on {host}:{port}')
logger.info('Starting HR Agent server on %s:%s', host, port)
await uvicorn.Server(uvicorn.Config(app=app, host=host, port=port)).serve()


async def start_api(host: str, port):
logger.info(f'Starting HR API server on {host}:{port}')
async def start_api(host: str, port: int) -> None:
"""Start the HR API server."""
logger.info('Starting HR API server on %s:%s', host, port)
await uvicorn.Server(
uvicorn.Config(app=hr_api, host=host, port=port)
).serve()
Expand Down
57 changes: 40 additions & 17 deletions samples/python/agents/headless_agent_auth/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from pydantic import BaseModel


HTTP_STATUS_OK = 200
HTTP_STATUS_NOT_FOUND = 404


auth0_ai = Auth0AI(
auth0={
'domain': os.getenv('HR_AUTH0_DOMAIN'),
Expand All @@ -32,7 +36,7 @@
scopes=['read:employee'],
user_id=lambda employee_id, **__: employee_id,
audience=os.getenv('HR_API_AUTH0_AUDIENCE'),
on_authorization_request='block', # TODO: this is just for demo purposes
on_authorization_request='block', # note: this is just for demo purposes
)


Expand Down Expand Up @@ -63,14 +67,14 @@ async def is_active_employee(employee_id: str) -> dict[str, Any]:
},
)

if response.status_code == 404:
if response.status_code == HTTP_STATUS_NOT_FOUND:
return {'active': False}
if response.status_code == 200:
if response.status_code == HTTP_STATUS_OK:
return {'active': True}
response.raise_for_status()
except httpx.HTTPError as e:
return {'error': f'HR API request failed: {e}'}
except Exception:
except Exception: # noqa: BLE001
return {'error': 'Unexpected response from HR API.'}


Expand All @@ -95,7 +99,7 @@ def get_employee_id_by_email(work_email: str) -> dict[str, Any] | None:
)[0]

return {'employee_id': user['user_id']} if user else None
except Exception:
except Exception: # noqa: BLE001
return {'error': 'Unexpected response from Auth0 Management API.'}


Expand All @@ -109,6 +113,8 @@ class ResponseFormat(BaseModel):


class HRAgent:
"""HR Agent that handles external verification requests."""

SUPPORTED_CONTENT_TYPES = ['text', 'text/plain']

SYSTEM_INSTRUCTION: str = (
Expand All @@ -125,7 +131,8 @@ class HRAgent:
'For any other tool error, set the status to "failed".'
)

def __init__(self):
def __init__(self) -> None:
"""Initialize the HR Agent."""
self.model = ChatGoogleGenerativeAI(model='gemini-2.0-flash')
self.tools = [
get_employee_id_by_email,
Expand All @@ -141,41 +148,57 @@ def __init__(self):
)

async def invoke(self, query: str, context_id: str) -> dict[str, Any]:
"""Invoke the agent with a query."""
config: RunnableConfig = {'configurable': {'thread_id': context_id}}
await self.graph.ainvoke({'messages': [('user', query)]}, config)
return self.get_agent_response(config)

async def stream(
self, query: str, context_id: str
) -> AsyncIterable[dict[str, Any]]:
inputs: dict[str, Any] = {'messages': [('user', query)]}
"""Stream the agent's response."""
inputs: dict[str, any] = {'messages': [('user', query)]}
config: RunnableConfig = {'configurable': {'thread_id': context_id}}

async for item in self.graph.astream(
inputs, config, stream_mode='values'
async for chunk in self.graph.astream(
inputs, config, stream_mode='updates'
):
message = item['messages'][-1] if 'messages' in item else None
node_name = next(iter(chunk.keys()))
data = chunk[node_name]

# 1. Handle intermediate tool steps
messages = data.get('messages', [])
message = messages[-1] if messages else None
if message:
if (
isinstance(message, AIMessage)
and message.tool_calls
and len(message.tool_calls) > 0
):
if isinstance(message, AIMessage) and message.tool_calls:
yield {
'is_task_complete': False,
'task_state': 'working',
'content': 'Looking up the employment status...',
'content': 'Looking up...',
}
elif isinstance(message, ToolMessage):
yield {
'is_task_complete': False,
'task_state': 'working',
'content': 'Processing the employment status...',
'content': 'Processing...',
}

# 2. ADD THIS: Handle the final structured response node
if node_name == 'generate_structured_response':
# The response is usually in a variable called 'structured_response'
resp = data.get('structured_response')
if resp:
yield {
'is_task_complete': resp.status == 'completed',
'task_state': resp.status,
'content': resp.message,
}

# 3. Fallback: Always try to get the final state if the loop finishes
yield self.get_agent_response(config)

def get_agent_response(self, config: RunnableConfig) -> dict[str, Any]:
"""Get the final agent response from the state."""
current_state = self.graph.get_state(config)
structured_response = current_state.values.get('structured_response')

Expand Down
14 changes: 12 additions & 2 deletions samples/python/agents/headless_agent_auth/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
class HRAgentExecutor(AgentExecutor):
"""HR AgentExecutor Example."""

def __init__(self):
def __init__(self) -> None:
"""Initialize the HR Agent Executor."""
self.agent = HRAgent()

async def execute(
self,
context: RequestContext,
event_queue: EventQueue,
) -> None:
"""Execute the HR Agent."""
query = context.get_user_input()
task = context.current_task

Expand All @@ -47,7 +49,14 @@ async def execute(
)
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
status=TaskStatus(state=task_state),
status=TaskStatus(
state=task_state,
message=new_agent_text_message(
event['content'],
task.context_id,
task.id,
),
),
final=True,
context_id=task.context_id,
task_id=task.id,
Expand Down Expand Up @@ -78,4 +87,5 @@ async def execute(
async def cancel(
self, request: RequestContext, event_queue: EventQueue
) -> Task | None:
"""Cancel the HR Agent (not supported)."""
raise Exception('cancel not supported')
16 changes: 11 additions & 5 deletions samples/python/agents/headless_agent_auth/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from typing import Annotated

from fastapi import Depends, FastAPI
from fastapi_plugin import Auth0FastAPI

Expand All @@ -12,12 +14,16 @@
app = FastAPI()


@app.get('/employees/{id}')
@app.get('/employees/{employee_id}')
def get_employee(
id: str, _claims: dict = Depends(auth0.require_auth(scopes='read:employee'))
):
# TODO: if needed, return more employee details
return {'employee_id': id}
employee_id: str,
_claims: Annotated[
dict, Depends(auth0.require_auth(scopes='read:employee'))
],
) -> dict:
"""Get employee details by ID."""
# Note: if needed, return more employee details here
return {'employee_id': employee_id}


hr_api = app
Loading