44import logging
55from abc import ABC , abstractmethod
66from collections .abc import AsyncGenerator
7- from typing import Optional
7+ from operator import add
8+ from typing import Annotated , Optional
89from uuid import uuid4
910
1011import posthog
1112from django .conf import settings
1213from django .utils .module_loading import import_string
1314from langchain_core .language_models .chat_models import BaseChatModel
14- from langchain_core .messages import SystemMessage , ToolMessage
15+ from langchain_core .messages import HumanMessage , SystemMessage
1516from langchain_core .messages .ai import AIMessageChunk
1617from langchain_core .tools .base import BaseTool
1718from langgraph .constants import END
1819from langgraph .graph import MessagesState , StateGraph
1920from langgraph .graph .graph import CompiledGraph
20- from langgraph .prebuilt import ToolNode
21+ from langgraph .prebuilt import ToolNode , create_react_agent
22+ from langgraph .prebuilt .chat_agent_executor import AgentState
2123from openai import BadRequestError
24+ from typing_extensions import TypedDict
2225
2326from ai_chatbots import tools
24- from ai_chatbots .api import ChatMemory
27+ from ai_chatbots .api import ChatMemory , get_search_tool_metadata
2528from ai_chatbots .constants import LLMClassEnum
29+ from ai_chatbots .tools import search_content_files
2630
2731log = logging .getLogger (__name__ )
2832
@@ -39,6 +43,7 @@ class BaseChatbot(ABC):
3943 INSTRUCTIONS = "Provide instructions for the LLM"
4044
4145 # For LiteLLM tracking purposes
46+ TASK_NAME = "BASE_TASK"
4247 JOB_ID = "BASECHAT_JOB"
4348
4449 def __init__ ( # noqa: PLR0913
@@ -170,12 +175,19 @@ def continue_on_tool_call(state: MessagesState) -> str:
170175 # compile and return the agent graph
171176 return agent_graph .compile (checkpointer = self .memory )
172177
173- @abstractmethod
174- async def get_comment_metadata (self ) -> str :
175- """Yield markdown comments to send hidden metdata in the response"""
178+ async def get_latest_history (self ) -> dict :
179+ """Get the most recent state history"""
180+ async for state in self .agent .aget_state_history (self .config ):
181+ if state :
182+ return state
183+ return None
176184
177185 async def get_completion (
178- self , message : str , * , debug : bool = settings .AI_DEBUG
186+ self ,
187+ message : str ,
188+ * ,
189+ extra_state : Optional [TypedDict ] = None ,
190+ debug : bool = settings .AI_DEBUG ,
179191 ) -> AsyncGenerator [str , None ]:
180192 """
181193 Send the user message to the agent and yield the response as
@@ -188,8 +200,12 @@ async def get_completion(
188200 error = "Create agent before running"
189201 raise ValueError (error )
190202 try :
203+ state = {
204+ "messages" : [HumanMessage (message )],
205+ ** (extra_state or {}),
206+ }
191207 response_generator = self .agent .astream (
192- { "messages" : [{ "role" : "user" , "content" : message }]} ,
208+ state ,
193209 self .config ,
194210 stream_mode = "messages" ,
195211 )
@@ -217,7 +233,7 @@ async def get_completion(
217233 yield '<!-- {"error":{"message":"An error occurred, please try again"}} -->'
218234 log .exception ("Error running AI agent" )
219235 if debug :
220- yield f"\n \n <!-- { await self .get_comment_metadata ()} -->\n \n "
236+ yield f"\n \n <!-- { await self .get_tool_metadata ()} -->\n \n "
221237 if settings .POSTHOG_PROJECT_API_KEY :
222238 hog_client = posthog .Posthog (
223239 settings .POSTHOG_PROJECT_API_KEY , host = settings .POSTHOG_API_HOST
@@ -228,11 +244,18 @@ async def get_completion(
228244 properties = {
229245 "question" : message ,
230246 "answer" : full_response ,
231- "metadata" : await self .get_comment_metadata (),
247+ "metadata" : await self .get_tool_metadata (),
232248 "user" : self .user_id ,
233249 },
234250 )
235251
252+ @abstractmethod
253+ async def get_tool_metadata (self ) -> str :
254+ """
255+ Yield markdown comments to send hidden metadata in the response
256+ """
257+ raise NotImplementedError
258+
236259
237260class ResourceRecommendationBot (BaseChatbot ):
238261 """
@@ -241,6 +264,7 @@ class ResourceRecommendationBot(BaseChatbot):
241264 """
242265
243266 TASK_NAME = "RECOMMENDATION_TASK"
267+ JOB_ID = "RECOMMENDATION_JOB"
244268
245269 INSTRUCTIONS = """You are an assistant helping users find courses from a catalog
246270of learning resources. Users can ask about specific topics, levels, or recommendations
@@ -372,37 +396,84 @@ def create_tools(self) -> list[BaseTool]:
372396 """Create tools required by the agent"""
373397 return [tools .search_courses ]
374398
375- async def get_latest_history (self ) -> dict :
376- async for state in self .agent .aget_state_history (self .config ):
377- if state :
378- return state
379- return None
399+ async def get_tool_metadata (self ) -> str :
400+ """Return the metadata for the search tool"""
401+ thread_id = self .config ["configurable" ]["thread_id" ]
402+ latest_state = await self .get_latest_history ()
403+ return get_search_tool_metadata (thread_id , latest_state )
404+
380405
381- async def get_comment_metadata (self ) -> str :
406+ class SyllabusAgentState (AgentState ):
407+ """
408+ State for the syllabus bot. Passes course_id and
409+ collection_name to the associated tool function.
410+ """
411+
412+ course_id : Annotated [list [str ], add ]
413+ collection_name : Annotated [list [str ], add ]
414+
415+
416+ class SyllabusBot (BaseChatbot ):
417+ """Service class for the AI syllabus agent"""
418+
419+ TASK_NAME = "SYLLABUS_TASK"
420+ JOB_ID = "SYLLABUS_JOB"
421+
422+ INSTRUCTIONS = """You are an assistant helping users answer questions related
423+ to a syllabus.
424+
425+ Your job:
426+ 1. Use the available function to gather relevant information about the user's question.
427+ 2. Provide a clear, user-friendly summary of the information retrieved by the tool to
428+ answer the user's question.
429+
430+ Always run the tool to answer questions, and answer only based on the tool
431+ output. Do not include the course id in the query parameter.
432+ VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE TOOL OUTPUT TO
433+ ANSWER QUESTIONS. If no results are returned, say you could not find any relevant
434+ information.
435+ """
436+
437+ def __init__ ( # noqa: PLR0913
438+ self ,
439+ user_id : str ,
440+ * ,
441+ name : str = "MIT Open Learning Syllabus Chatbot" ,
442+ model : Optional [str ] = None ,
443+ temperature : Optional [float ] = None ,
444+ instructions : Optional [str ] = None ,
445+ thread_id : Optional [str ] = None ,
446+ ):
447+ super ().__init__ (
448+ user_id ,
449+ name = name ,
450+ model = model or settings .AI_MODEL ,
451+ temperature = temperature ,
452+ instructions = instructions ,
453+ thread_id = thread_id ,
454+ )
455+ self .agent = self .create_agent_graph ()
456+
457+ def create_tools (self ):
458+ """Create tools required by the agent"""
459+ return [search_content_files ]
460+
461+ def create_agent_graph (self ) -> CompiledGraph :
382462 """
383- Yield markdown comments to send hidden metadata in the response
463+ Generate a standard react agent graph for the syllabus agent.
464+ Use the custom SyllabusAgentState to pass course_id and collection_name
465+ to the associated tool function.
384466 """
467+ return create_react_agent (
468+ self .llm ,
469+ tools = self .tools ,
470+ checkpointer = self .memory ,
471+ state_schema = SyllabusAgentState ,
472+ state_modifier = self .instructions ,
473+ )
474+
475+ async def get_tool_metadata (self ) -> str :
476+ """Return the metadata for the search tool"""
385477 thread_id = self .config ["configurable" ]["thread_id" ]
386- metadata = {"thread_id" : thread_id }
387478 latest_state = await self .get_latest_history ()
388- tool_messages = (
389- []
390- if not latest_state
391- else [
392- t
393- for t in latest_state .values .get ("messages" , [])
394- if t and t .__class__ == ToolMessage
395- ]
396- )
397- if tool_messages :
398- content = json .loads (tool_messages [- 1 ].content or {})
399- metadata = {
400- "metadata" : {
401- "search_parameters" : content .get ("metadata" , {}).get (
402- "parameters" , []
403- ),
404- "search_results" : content .get ("results" , []),
405- "thread_id" : thread_id ,
406- }
407- }
408- return json .dumps (metadata )
479+ return get_search_tool_metadata (thread_id , latest_state )
0 commit comments