Skip to content

Commit 9c9dcf9

Browse files
authored
Merge pull request #42 from mitodl/mb/syllabus_bot
Syllabus Chatbot backend
2 parents 96a8d41 + 2568c88 commit 9c9dcf9

File tree

13 files changed

+994
-145
lines changed

13 files changed

+994
-145
lines changed

ai_chatbots/api.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
"""AI-specific functions for ai_agents."""
22

3+
import json
4+
import logging
5+
36
from django.conf import settings
7+
from langchain_core.messages import ToolMessage
48
from langgraph.checkpoint.memory import MemorySaver
59
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
610
from psycopg_pool import AsyncConnectionPool
11+
from typing_extensions import TypedDict
712

813
from main.utils import Singleton
914

15+
log = logging.getLogger(__name__)
16+
1017

1118
class ChatMemory(metaclass=Singleton):
1219
"""
@@ -54,3 +61,39 @@ def get_postgres_saver() -> AsyncPostgresSaver:
5461
kwargs=connection_kwargs,
5562
)
5663
return AsyncPostgresSaver(pool)
64+
65+
66+
def get_search_tool_metadata(thread_id: str, latest_state: TypedDict) -> str:
67+
"""
68+
Return the metadata for a bot search tool.
69+
"""
70+
tool_messages = (
71+
[]
72+
if not latest_state
73+
else [
74+
t
75+
for t in latest_state.values.get("messages", [])
76+
if t and t.__class__ == ToolMessage
77+
]
78+
)
79+
if tool_messages:
80+
msg_content = tool_messages[-1].content
81+
try:
82+
content = json.loads(msg_content or {})
83+
metadata = {
84+
"metadata": {
85+
"search_parameters": content.get("metadata", {}).get(
86+
"parameters", []
87+
),
88+
"search_results": content.get("results", []),
89+
"thread_id": thread_id,
90+
}
91+
}
92+
return json.dumps(metadata)
93+
except json.JSONDecodeError:
94+
log.exception("Error parsing tool metadata, not valid JSON")
95+
return json.dumps(
96+
{"error": "Error parsing tool metadata", "content": msg_content}
97+
)
98+
else:
99+
return "{}"

ai_chatbots/chatbots.py

Lines changed: 111 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,29 @@
44
import logging
55
from abc import ABC, abstractmethod
66
from collections.abc import AsyncGenerator
7-
from typing import Optional
7+
from operator import add
8+
from typing import Annotated, Optional
89
from uuid import uuid4
910

1011
import posthog
1112
from django.conf import settings
1213
from django.utils.module_loading import import_string
1314
from langchain_core.language_models.chat_models import BaseChatModel
14-
from langchain_core.messages import SystemMessage, ToolMessage
15+
from langchain_core.messages import HumanMessage, SystemMessage
1516
from langchain_core.messages.ai import AIMessageChunk
1617
from langchain_core.tools.base import BaseTool
1718
from langgraph.constants import END
1819
from langgraph.graph import MessagesState, StateGraph
1920
from langgraph.graph.graph import CompiledGraph
20-
from langgraph.prebuilt import ToolNode
21+
from langgraph.prebuilt import ToolNode, create_react_agent
22+
from langgraph.prebuilt.chat_agent_executor import AgentState
2123
from openai import BadRequestError
24+
from typing_extensions import TypedDict
2225

2326
from ai_chatbots import tools
24-
from ai_chatbots.api import ChatMemory
27+
from ai_chatbots.api import ChatMemory, get_search_tool_metadata
2528
from ai_chatbots.constants import LLMClassEnum
29+
from ai_chatbots.tools import search_content_files
2630

2731
log = logging.getLogger(__name__)
2832

@@ -39,6 +43,7 @@ class BaseChatbot(ABC):
3943
INSTRUCTIONS = "Provide instructions for the LLM"
4044

4145
# For LiteLLM tracking purposes
46+
TASK_NAME = "BASE_TASK"
4247
JOB_ID = "BASECHAT_JOB"
4348

4449
def __init__( # noqa: PLR0913
@@ -170,12 +175,19 @@ def continue_on_tool_call(state: MessagesState) -> str:
170175
# compile and return the agent graph
171176
return agent_graph.compile(checkpointer=self.memory)
172177

173-
@abstractmethod
174-
async def get_comment_metadata(self) -> str:
175-
"""Yield markdown comments to send hidden metdata in the response"""
178+
async def get_latest_history(self) -> dict:
179+
"""Get the most recent state history"""
180+
async for state in self.agent.aget_state_history(self.config):
181+
if state:
182+
return state
183+
return None
176184

177185
async def get_completion(
178-
self, message: str, *, debug: bool = settings.AI_DEBUG
186+
self,
187+
message: str,
188+
*,
189+
extra_state: Optional[TypedDict] = None,
190+
debug: bool = settings.AI_DEBUG,
179191
) -> AsyncGenerator[str, None]:
180192
"""
181193
Send the user message to the agent and yield the response as
@@ -188,8 +200,12 @@ async def get_completion(
188200
error = "Create agent before running"
189201
raise ValueError(error)
190202
try:
203+
state = {
204+
"messages": [HumanMessage(message)],
205+
**(extra_state or {}),
206+
}
191207
response_generator = self.agent.astream(
192-
{"messages": [{"role": "user", "content": message}]},
208+
state,
193209
self.config,
194210
stream_mode="messages",
195211
)
@@ -217,7 +233,7 @@ async def get_completion(
217233
yield '<!-- {"error":{"message":"An error occurred, please try again"}} -->'
218234
log.exception("Error running AI agent")
219235
if debug:
220-
yield f"\n\n<!-- {await self.get_comment_metadata()} -->\n\n"
236+
yield f"\n\n<!-- {await self.get_tool_metadata()} -->\n\n"
221237
if settings.POSTHOG_PROJECT_API_KEY:
222238
hog_client = posthog.Posthog(
223239
settings.POSTHOG_PROJECT_API_KEY, host=settings.POSTHOG_API_HOST
@@ -228,11 +244,18 @@ async def get_completion(
228244
properties={
229245
"question": message,
230246
"answer": full_response,
231-
"metadata": await self.get_comment_metadata(),
247+
"metadata": await self.get_tool_metadata(),
232248
"user": self.user_id,
233249
},
234250
)
235251

252+
@abstractmethod
253+
async def get_tool_metadata(self) -> str:
254+
"""
255+
Yield markdown comments to send hidden metadata in the response
256+
"""
257+
raise NotImplementedError
258+
236259

237260
class ResourceRecommendationBot(BaseChatbot):
238261
"""
@@ -241,6 +264,7 @@ class ResourceRecommendationBot(BaseChatbot):
241264
"""
242265

243266
TASK_NAME = "RECOMMENDATION_TASK"
267+
JOB_ID = "RECOMMENDATION_JOB"
244268

245269
INSTRUCTIONS = """You are an assistant helping users find courses from a catalog
246270
of learning resources. Users can ask about specific topics, levels, or recommendations
@@ -372,37 +396,84 @@ def create_tools(self) -> list[BaseTool]:
372396
"""Create tools required by the agent"""
373397
return [tools.search_courses]
374398

375-
async def get_latest_history(self) -> dict:
376-
async for state in self.agent.aget_state_history(self.config):
377-
if state:
378-
return state
379-
return None
399+
async def get_tool_metadata(self) -> str:
400+
"""Return the metadata for the search tool"""
401+
thread_id = self.config["configurable"]["thread_id"]
402+
latest_state = await self.get_latest_history()
403+
return get_search_tool_metadata(thread_id, latest_state)
404+
380405

381-
async def get_comment_metadata(self) -> str:
406+
class SyllabusAgentState(AgentState):
407+
"""
408+
State for the syllabus bot. Passes course_id and
409+
collection_name to the associated tool function.
410+
"""
411+
412+
course_id: Annotated[list[str], add]
413+
collection_name: Annotated[list[str], add]
414+
415+
416+
class SyllabusBot(BaseChatbot):
417+
"""Service class for the AI syllabus agent"""
418+
419+
TASK_NAME = "SYLLABUS_TASK"
420+
JOB_ID = "SYLLABUS_JOB"
421+
422+
INSTRUCTIONS = """You are an assistant helping users answer questions related
423+
to a syllabus.
424+
425+
Your job:
426+
1. Use the available function to gather relevant information about the user's question.
427+
2. Provide a clear, user-friendly summary of the information retrieved by the tool to
428+
answer the user's question.
429+
430+
Always run the tool to answer questions, and answer only based on the tool
431+
output. Do not include the course id in the query parameter.
432+
VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE TOOL OUTPUT TO
433+
ANSWER QUESTIONS. If no results are returned, say you could not find any relevant
434+
information.
435+
"""
436+
437+
def __init__( # noqa: PLR0913
438+
self,
439+
user_id: str,
440+
*,
441+
name: str = "MIT Open Learning Syllabus Chatbot",
442+
model: Optional[str] = None,
443+
temperature: Optional[float] = None,
444+
instructions: Optional[str] = None,
445+
thread_id: Optional[str] = None,
446+
):
447+
super().__init__(
448+
user_id,
449+
name=name,
450+
model=model or settings.AI_MODEL,
451+
temperature=temperature,
452+
instructions=instructions,
453+
thread_id=thread_id,
454+
)
455+
self.agent = self.create_agent_graph()
456+
457+
def create_tools(self):
458+
"""Create tools required by the agent"""
459+
return [search_content_files]
460+
461+
def create_agent_graph(self) -> CompiledGraph:
382462
"""
383-
Yield markdown comments to send hidden metadata in the response
463+
Generate a standard react agent graph for the syllabus agent.
464+
Use the custom SyllabusAgentState to pass course_id and collection_name
465+
to the associated tool function.
384466
"""
467+
return create_react_agent(
468+
self.llm,
469+
tools=self.tools,
470+
checkpointer=self.memory,
471+
state_schema=SyllabusAgentState,
472+
state_modifier=self.instructions,
473+
)
474+
475+
async def get_tool_metadata(self) -> str:
476+
"""Return the metadata for the search tool"""
385477
thread_id = self.config["configurable"]["thread_id"]
386-
metadata = {"thread_id": thread_id}
387478
latest_state = await self.get_latest_history()
388-
tool_messages = (
389-
[]
390-
if not latest_state
391-
else [
392-
t
393-
for t in latest_state.values.get("messages", [])
394-
if t and t.__class__ == ToolMessage
395-
]
396-
)
397-
if tool_messages:
398-
content = json.loads(tool_messages[-1].content or {})
399-
metadata = {
400-
"metadata": {
401-
"search_parameters": content.get("metadata", {}).get(
402-
"parameters", []
403-
),
404-
"search_results": content.get("results", []),
405-
"thread_id": thread_id,
406-
}
407-
}
408-
return json.dumps(metadata)
479+
return get_search_tool_metadata(thread_id, latest_state)

0 commit comments

Comments
 (0)