-
Notifications
You must be signed in to change notification settings - Fork 599
feat(dspy): add dspy a2a example #416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ae633d5
eb2236e
43291b2
fe37c4e
9d4c24b
b066adb
59bf2e5
015bc84
c3a2d04
60847bf
f6c6c07
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| BRAINTRUST_API_KEY=sk-... | ||
| MEM0_API_KEY=m0-... | ||
| OPENAI_API_KEY=sk-proj-... |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
|
|
||
| # A2A DSPy Agent with Memory | ||
|
|
||
| This sample demonstrates an Agent-to-Agent (A2A) server built using [DSPy](https://github.com/stanfordnlp/dspy), a framework for programming with language models. The agent features conversational memory powered by [Mem0](https://mem0.ai/) and observability through [Braintrust](https://www.braintrust.dev/). | ||
|
|
||
| ## Core Functionality | ||
|
|
||
| * **DSPy Integration:** Uses DSPy's `ChainOfThought` module for structured reasoning and response generation. | ||
| * **Memory Management:** Leverages Mem0 to store and retrieve user interactions, enabling context-aware conversations across sessions. | ||
| * **Observability:** Integrated with Braintrust for tracing agent execution, LLM calls, and memory operations. | ||
| * **A2A Protocol:** Fully compliant A2A server that supports both stateful conversations and task completion. | ||
|
|
||
| ## Files | ||
|
|
||
| * `__main__.py`: The main entry point that configures and starts the A2A server with CORS support. | ||
| * `executor.py`: The `DspyAgentExecutor` class that implements the A2A `AgentExecutor` interface, handling task execution and memory operations. | ||
| * `agents/dspy_example.py`: DSPy agent definition using a custom `AgentSignature` with Chain-of-Thought reasoning. | ||
| * `memory/base.py`: Abstract base class for memory implementations. | ||
| * `memory/mem0.py`: Mem0 memory implementation for storing and retrieving conversation context. | ||
| * `logger.py`: Logging configuration for the agent. | ||
| * `test_client.py`: Test client to interact with the agent. | ||
|
|
||
| ## Prerequisites | ||
|
|
||
| * Python 3.13 | ||
| * OpenAI API Key | ||
| * Mem0 API Key | ||
| * Braintrust API Key (optional, for observability) | ||
|
|
||
| ## Setup | ||
|
|
||
| 1. **Set Environment Variables:** | ||
|
|
||
| Create a `.env` file or export the following environment variables: | ||
|
|
||
| ```bash | ||
| export OPENAI_API_KEY="your-openai-api-key" | ||
| export MEM0_API_KEY="your-mem0-api-key" | ||
| export BRAINTRUST_API_KEY="your-braintrust-api-key" # Optional | ||
| ``` | ||
|
|
||
| Replace the placeholder values with your actual API keys. | ||
|
|
||
| 2. **Install Dependencies:** | ||
|
|
||
| The project uses `uv` for dependency management. Dependencies are defined in `pyproject.toml`. | ||
|
|
||
| ## Running the Application | ||
|
|
||
| 1. **Start the A2A Server:** | ||
|
|
||
| ```bash | ||
| uv run . | ||
| ``` | ||
|
|
||
| By default, the server will start on `http://localhost:10020`. You can customize the host and port: | ||
|
|
||
| ```bash | ||
| uv run . --host 0.0.0.0 --port 8080 | ||
| ``` | ||
|
|
||
| 2. **Interact with the Agent:** | ||
|
|
||
| You can use the included test client: | ||
|
|
||
| ```bash | ||
| uv run test_client.py | ||
| ``` | ||
|
|
||
| Or use the CLI host from the samples: | ||
|
|
||
| ```bash | ||
| cd samples/python/hosts/cli | ||
| uv run . --agent http://localhost:10020 | ||
| ``` | ||
|
|
||
| ## How It Works | ||
|
|
||
| 1. **User Input:** The agent receives a question or message through the A2A protocol. | ||
| 2. **Memory Retrieval:** The agent queries Mem0 for relevant past interactions using the user's context ID. | ||
| 3. **DSPy Processing:** The question and retrieved context are passed to the DSPy `ChainOfThought` module. | ||
| 4. **Response Generation:** DSPy generates a response using GPT-4o-mini, determining if the task is complete or requires more input. | ||
| 5. **Memory Storage:** The interaction (user input and agent response) is saved to Mem0 for future context. | ||
| 6. **Task Completion:** The agent either completes the task with an artifact or requests additional input. | ||
|
|
||
| ## Agent Capabilities | ||
|
|
||
| The agent exposes a single skill: | ||
|
|
||
| * **Skill ID:** `dspy_agent` | ||
| * **Name:** DSPy Agent | ||
| * **Description:** A simple DSPy agent that can answer questions and remember user interactions. | ||
| * **Tags:** DSPy, Memory, Mem0 | ||
| * **Example Queries:** | ||
| - "What is the capital of France?" | ||
| - "What did I ask you about earlier?" | ||
| - "Remember that I prefer morning meetings." | ||
|
|
||
| ## Memory Features | ||
|
|
||
| The agent uses Mem0 to provide: | ||
|
|
||
| * **User-specific Memory:** Each user (identified by `context_id`) has their own memory space. | ||
| * **Semantic Retrieval:** Memories are retrieved based on semantic similarity to the current query. | ||
| * **Persistent Context:** Conversations are remembered across sessions, enabling continuity. | ||
|
|
||
| ## Observability | ||
|
|
||
| With Braintrust integration, you can: | ||
|
|
||
| * Track each agent execution with detailed spans | ||
| * Monitor LLM calls and their inputs/outputs | ||
| * View memory retrieval and storage operations | ||
| * Analyze performance and debug issues | ||
|
|
||
| Visit the Braintrust dashboard to view traces after interacting with the agent. | ||
|
|
||
| ## Disclaimer | ||
|
|
||
| Important: The sample code provided is for demonstration purposes and illustrates the mechanics of the Agent-to-Agent (A2A) protocol. When building production applications, it is critical to treat any agent operating outside of your direct control as a potentially untrusted entity. | ||
|
|
||
| All data received from an external agent—including but not limited to its AgentCard, messages, artifacts, and task statuses—should be handled as untrusted input. | ||
| For example, a malicious agent could provide an AgentCard containing crafted data in its fields (e.g., description, name, skills.description). | ||
| If this data is used without sanitization to construct prompts for a Large Language Model (LLM), it could expose your application to prompt injection attacks. | ||
| Failure to properly validate and sanitize this data before use can introduce security vulnerabilities into your application. | ||
|
|
||
| Developers are responsible for implementing appropriate security measures, such as input validation and secure handling of credentials to protect their systems and users. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| import os | ||
|
|
||
| import dspy | ||
|
|
||
| from braintrust.wrappers.dspy import BraintrustDSpyCallback | ||
| from dotenv import load_dotenv | ||
|
|
||
|
|
||
| load_dotenv() | ||
|
|
||
| lm = dspy.LM(model='gpt-4o-mini', api_key=os.getenv('OPENAI_API_KEY')) | ||
| dspy.configure(lm=lm, callbacks=[BraintrustDSpyCallback()]) | ||
|
|
||
|
|
||
| class AgentSignature(dspy.Signature): | ||
| """You are a helpful assistant that can answer any question.""" | ||
|
|
||
| question: str = dspy.InputField(description='The question to answer') | ||
| ctx: list[dict] = dspy.InputField( | ||
| description='The context to use for the question' | ||
| ) | ||
| answer: str = dspy.OutputField(description='The answer to the question') | ||
| completed_task: bool = dspy.OutputField( | ||
| description='Whether the task is complete or need more input' | ||
| ) | ||
|
|
||
|
|
||
| agent = dspy.ChainOfThought(signature=AgentSignature) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| from a2a.server.agent_execution import AgentExecutor, RequestContext | ||
| from a2a.server.events import EventQueue | ||
| from a2a.server.tasks import TaskUpdater | ||
| from a2a.types import ( | ||
| InternalError, | ||
| InvalidParamsError, | ||
| Task, | ||
| TaskState, | ||
| TextPart, | ||
| UnsupportedOperationError, | ||
| ) | ||
| from a2a.utils import ( | ||
| new_agent_text_message, | ||
| ) | ||
| from a2a.utils.errors import ServerError | ||
| from braintrust import current_span, traced | ||
|
|
||
| from agents.dspy_example import agent | ||
| from logger import logger | ||
| from memory.mem0 import Mem0Memory | ||
|
|
||
|
|
||
| class DspyAgentExecutor(AgentExecutor): | ||
| """Memory-aware DSPy AgentExecutor with per-user context.""" | ||
|
|
||
| def __init__(self) -> None: | ||
| self.agent = agent | ||
| self.memory = Mem0Memory() | ||
|
|
||
| @traced | ||
| async def execute( | ||
| self, | ||
| context: RequestContext, | ||
| event_queue: EventQueue, | ||
| ) -> None: | ||
| """Execute the task.""" | ||
| with logger.start_span(): | ||
| error = self._validate_request(context) | ||
| if error: | ||
| raise ServerError(error=InvalidParamsError()) | ||
|
|
||
| updater = TaskUpdater( | ||
| event_queue, context.task_id, context.context_id | ||
| ) | ||
| if not context.current_task: | ||
| await updater.submit() | ||
|
|
||
| await updater.start_work() | ||
|
|
||
| query = context.get_user_input() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The user input from You should process the input to extract the text content. For example: user_input = context.get_user_input()
query_text = (
' '.join(p.text for p in user_input if isinstance(p, TextPart))
if isinstance(user_input, list)
else str(user_input)
)
# Then use query_text in calls to memory.retrieve, agent, etc. |
||
| try: | ||
| ctx = await self.memory.retrieve( | ||
| query=query, user_id=context.context_id | ||
| ) | ||
| result = self.agent(question=str(query), ctx=ctx) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| current_span().log(input=query, output=result.answer) | ||
| await self.memory.save( | ||
| user_id=context.context_id, | ||
| user_input=query, | ||
| assistant_response=result.answer, | ||
| ) | ||
| except Exception as e: | ||
| current_span().log(error=e) | ||
| raise ServerError(error=InternalError()) from e | ||
| if result.completed_task: | ||
| await updater.add_artifact( | ||
| [TextPart(text=result.answer)], | ||
| name='answer', | ||
| ) | ||
| await updater.complete() | ||
| else: | ||
| await updater.update_status( | ||
| TaskState.input_required, | ||
| message=new_agent_text_message( | ||
| result.answer, context.context_id, context.task_id | ||
| ), | ||
| ) | ||
|
|
||
| async def cancel( | ||
| self, request: RequestContext, event_queue: EventQueue | ||
| ) -> Task | None: | ||
| """Cancel the task.""" | ||
| raise ServerError(error=UnsupportedOperationError()) | ||
|
|
||
| def _validate_request(self, context: RequestContext) -> bool: | ||
| return False | ||
|
Comment on lines
+85
to
+86
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Comment on lines
+85
to
+86
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The def _validate_request(self, context: RequestContext) -> bool:
"""Validate the incoming request. For this example, we assume all requests are valid.
Returns:
bool: True if there is a validation error, False otherwise.
"""
# TODO: Implement request validation for production use.
return False |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| import os | ||
| import re | ||
|
|
||
| from typing import Any | ||
|
|
||
| from braintrust import init_logger, set_masking_function | ||
| from braintrust.wrappers.litellm import patch_litellm | ||
| from dotenv import load_dotenv | ||
|
|
||
|
|
||
| patch_litellm() | ||
|
|
||
| load_dotenv() | ||
|
|
||
|
|
||
| def mask_sensitive_data(data: Any) -> Any: | ||
| """Mask sensitive data.""" | ||
| if isinstance(data, str): | ||
| return re.sub( | ||
| r'\b(api[_-]?key|password|token)[\s:=]+\S+', | ||
| r'\1: [REDACTED]', | ||
| data, | ||
| flags=re.IGNORECASE, | ||
| ) | ||
|
|
||
| if isinstance(data, dict): | ||
| masked = {} | ||
| for key, value in data.items(): | ||
| if re.match( | ||
| r'^(api[_-]?key|password|secret|token|auth|credential)$', | ||
| key, | ||
| re.IGNORECASE, | ||
| ): | ||
| masked[key] = '[REDACTED]' | ||
| else: | ||
| masked[key] = mask_sensitive_data(value) | ||
| return masked | ||
|
|
||
| if isinstance(data, list): | ||
| return [mask_sensitive_data(item) for item in data] | ||
|
|
||
| return data | ||
|
|
||
|
|
||
| set_masking_function(mask_sensitive_data) | ||
|
|
||
| logger = init_logger( | ||
| project='My Project', api_key=os.getenv('BRAINTRUST_API_KEY') | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any | ||
|
|
||
|
|
||
| class Memory(ABC): | ||
| """Base class for memory.""" | ||
|
|
||
| @abstractmethod | ||
| async def save( | ||
| self, user_id: str, user_input: str, assistant_response: str | ||
| ) -> Any: | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| async def retrieve(self, query: str, user_id: str) -> list[dict]: | ||
| pass | ||
|
Comment on lines
+8
to
+16
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The @abstractmethod
async def save(self, user_id: str, user_input: str, assistant_response: str) -> None:
pass
@abstractmethod
async def retrieve(self, query: str, user_id: str) -> List[Dict]:
pass |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import os | ||
| import traceback | ||
|
|
||
| from braintrust import current_span, traced | ||
| from dotenv import load_dotenv | ||
| from mem0 import AsyncMemoryClient | ||
|
|
||
| from memory.base import Memory | ||
|
|
||
|
|
||
| load_dotenv() | ||
|
|
||
| mem0 = AsyncMemoryClient(api_key=os.getenv('MEM0_API_KEY')) | ||
|
|
||
|
|
||
| class Mem0Memory(Memory): | ||
| """Mem0 memory implementation.""" | ||
|
|
||
| @traced | ||
| async def retrieve(self, query: str, user_id: str) -> list[dict]: | ||
| """Retrieve relevant context from Mem0.""" | ||
| try: | ||
| memories = await mem0.search(query=query, user_id=user_id) | ||
| serialized_memories = ' '.join([mem['memory'] for mem in memories]) | ||
| context = [ | ||
| { | ||
| 'role': 'system', | ||
| 'content': f'Relevant information: {serialized_memories}', | ||
| }, | ||
| {'role': 'user', 'content': query}, | ||
| ] | ||
| current_span().log( | ||
| metadata={ | ||
| 'memory_retrieved': context, | ||
| 'query': query, | ||
| 'user_id': user_id, | ||
| } | ||
| ) | ||
| except Exception as e: # noqa: BLE001 | ||
| current_span().log( | ||
| metadata={'error': e, 'traceback': traceback.format_exc()} | ||
| ) | ||
| return [{'role': 'user', 'content': query}] | ||
| else: | ||
| return context | ||
|
|
||
| @traced | ||
| async def save( | ||
| self, user_id: str, user_input: str, assistant_response: str | ||
| ) -> None: | ||
| """Save the interaction to Mem0.""" | ||
| try: | ||
| interaction = [ | ||
| {'role': 'user', 'content': user_input}, | ||
| {'role': 'assistant', 'content': assistant_response}, | ||
| ] | ||
| result = await mem0.add(interaction, user_id=user_id) | ||
| current_span().log( | ||
| metadata={'memory_saved': result, 'user_id': user_id} | ||
| ) | ||
| except Exception as e: # noqa: BLE001 | ||
| current_span().log( | ||
| metadata={'error': e, 'traceback': traceback.format_exc()} | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation refers to
__main__.pyas the main entry point, but the actual file is namedrun.py. This should be corrected to avoid confusion for users trying to understand the project structure.