11"""Agent service classes for the AI chatbots"""
22
3+ import asyncio
34import json
45import logging
56from abc import ABC , abstractmethod
3233)
3334from openai import BadRequestError
3435from posthog .ai .langchain import CallbackHandler
35- from typing_extensions import TypedDict
3636
3737from ai_chatbots import tools
3838from ai_chatbots .api import (
4444)
4545from ai_chatbots .posthog import TokenTrackingCallbackHandler
4646from ai_chatbots .prompts import SYSTEM_PROMPT_MAPPING
47- from ai_chatbots .utils import get_django_cache , request_with_token
47+ from ai_chatbots .utils import (
48+ async_request_with_token ,
49+ get_django_cache ,
50+ )
4851
4952log = logging .getLogger (__name__ )
5053
@@ -225,7 +228,7 @@ async def get_completion(
225228 self ,
226229 message : str ,
227230 * ,
228- extra_state : Optional [TypedDict ] = None ,
231+ extra_state : Optional [dict [ str , Any ] ] = None ,
229232 debug : bool = settings .AI_DEBUG ,
230233 ) -> AsyncGenerator [str , None ]:
231234 """
@@ -498,21 +501,18 @@ def __init__( # noqa: PLR0913
498501 self .problem_set_title = problem_set_title
499502
500503 if not self .edx_module_id :
501- self .problem_set = get_canvas_problem_set (
502- self .run_readable_id , self .problem_set_title
503- )
504-
505- self .problem = ""
506504 self .variant = "canvas"
507-
505+ self . problem = ""
508506 else :
509- self .problem , self .problem_set = get_problem_from_edx_block (
510- edx_module_id , block_siblings
511- )
512507 self .variant = "edx"
508+ self .problem = None
509+
510+ self .problem_set = None
511+ self .problem_data_loaded = False
513512
514513 async def get_tool_metadata (self ) -> str :
515514 """Return the metadata for the tool"""
515+ await self .load_problem_data ()
516516 return {
517517 "edx_module_id" : self .edx_module_id ,
518518 "block_siblings" : self .block_siblings ,
@@ -529,11 +529,13 @@ async def get_completion(
529529 self ,
530530 message : str ,
531531 * ,
532- extra_state : Optional [TypedDict ] = None , # noqa: ARG002
532+ extra_state : Optional [dict [ str , Any ] ] = None , # noqa: ARG002
533533 debug : bool = settings .AI_DEBUG ,
534534 ) -> AsyncGenerator [str , None ]:
535535 """Call message_tutor with the user query and return the response"""
536536
537+ await self .load_problem_data ()
538+
537539 history = await self .get_latest_history ()
538540 message_id = str (uuid4 ())
539541 if history :
@@ -555,7 +557,12 @@ async def get_completion(
555557 full_response = ""
556558 new_history = []
557559 try :
558- generator , new_intent_history , new_assessment_history = message_tutor (
560+ (
561+ generator ,
562+ new_intent_history ,
563+ new_assessment_history ,
564+ ) = await asyncio .to_thread (
565+ message_tutor ,
559566 self .problem ,
560567 self .problem_set ,
561568 self .llm ,
@@ -603,8 +610,36 @@ async def get_completion(
603610 yield '<!-- {"error":{"message":"An error occurred, please try again"}} -->'
604611 log .exception ("Error running AI agent" )
605612
613+ async def load_problem_data (self ) -> None :
614+ """Fetch problem content if it has not already been loaded."""
615+ if self .problem_data_loaded :
616+ return
617+
618+ if self .variant == "canvas" :
619+ if not self .run_readable_id or not self .problem_set_title :
620+ msg = "Canvas tutor requires run_readable_id and problem_set_title"
621+ raise ValueError (msg )
622+ self .problem_set = await get_canvas_problem_set (
623+ self .run_readable_id ,
624+ self .problem_set_title ,
625+ )
626+ else :
627+ if not self .edx_module_id or not self .block_siblings :
628+ msg = "Edx tutor requires edx_module_id and block_siblings"
629+ raise ValueError (msg )
630+ problem , problem_set = await get_problem_from_edx_block (
631+ self .edx_module_id ,
632+ self .block_siblings ,
633+ )
634+ self .problem = problem
635+ self .problem_set = problem_set
636+
637+ self .problem_data_loaded = True
638+
606639
607- def get_problem_from_edx_block (edx_module_id : str , block_siblings : list [str ]):
640+ async def get_problem_from_edx_block (
641+ edx_module_id : str , block_siblings : list [str ]
642+ ) -> tuple [str , str ]:
608643 """
609644 Make an call to the learn contentfiles api to get the problem xml and problem
610645 set xml using the block id
@@ -622,7 +657,7 @@ def get_problem_from_edx_block(edx_module_id: str, block_siblings: list[str]):
622657 api_url = settings .AI_MIT_CONTENTFILE_URL
623658 params = {"edx_module_id" : block_siblings }
624659
625- response = request_with_token (api_url , params , timeout = 10 )
660+ response = await async_request_with_token (api_url , params , timeout = 10 )
626661
627662 response = response .json ()
628663
@@ -635,7 +670,9 @@ def get_problem_from_edx_block(edx_module_id: str, block_siblings: list[str]):
635670 return problem , problem_set
636671
637672
638- def get_canvas_problem_set (run_readable_id : str , problem_set_title : str ) -> str :
673+ async def get_canvas_problem_set (
674+ run_readable_id : str , problem_set_title : str
675+ ) -> dict [str , list [dict [str , str ]]]:
639676 """
640677 Make an call to the learn tutor probalem api to get the problem set and solution
641678 using run_readable_id and problem_set_title
@@ -650,7 +687,7 @@ def get_canvas_problem_set(run_readable_id: str, problem_set_title: str) -> str:
650687
651688 api_url = f"{ settings .PROBLEM_SET_URL } { run_readable_id } /{ problem_set_title } /"
652689
653- response = request_with_token (api_url , {}, timeout = 10 )
690+ response = await async_request_with_token (api_url , {}, timeout = 10 )
654691 response = response .json ()
655692
656693 return {
0 commit comments