@@ -7,11 +7,8 @@
 from jupyterlab_chat.models import Message
 from langchain.agents import create_agent
 from langchain.agents.middleware import AgentMiddleware
-from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
-from langchain.agents.middleware.shell_tool import ShellToolMiddleware
 from langchain.messages import ToolMessage
 from langchain.tools.tool_node import ToolCallRequest
-from langchain_core.messages import ToolMessage
 from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
 from langgraph.types import Command
 
@@ -26,9 +23,11 @@
 
 MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")
 
-JUPYTERNAUT_AVATAR_PATH = str(os.path.abspath(
-    os.path.join(os.path.dirname(__file__), "../static", "jupyternaut.svg")
-))
+JUPYTERNAUT_AVATAR_PATH = str(
+    os.path.abspath(
+        os.path.join(os.path.dirname(__file__), "../static", "jupyternaut.svg")
+    )
+)
 
 
 def format_tool_args_compact(args_dict, threshold=25):
@@ -50,7 +49,7 @@ def format_tool_args_compact(args_dict, threshold=25):
 
     for key, value in args_dict.items():
         value_str = str(value)
-        lines = value_str.split('\n')
+        lines = value_str.split("\n")
 
         if len(lines) <= threshold:
             if len(lines) == 1 and len(value_str) > 80:
@@ -60,16 +59,18 @@ def format_tool_args_compact(args_dict, threshold=25):
             else:
                 # Add indentation for multi-line values
                 if len(lines) > 1:
-                    indented_value = '\n'.join([''] + lines)
+                    indented_value = "\n".join([""] + lines)
                     formatted_pairs.append(f"{key}:{indented_value}")
                 else:
                     formatted_pairs.append(f"{key}: {value_str}")
         else:
             # Truncate and add summary
             truncated_lines = lines[:threshold]
             remaining_lines = len(lines) - threshold
-            indented_value = '\n'.join([''] + truncated_lines)
-            formatted_pairs.append(f"{key}:{indented_value}\n[+{remaining_lines} more lines]")
+            indented_value = "\n".join([""] + truncated_lines)
+            formatted_pairs.append(
+                f"{key}:{indented_value}\n[+{remaining_lines} more lines]"
+            )
 
     return "{\n" + ",\n".join(formatted_pairs) + "\n}"
 
@@ -84,7 +85,7 @@ async def awrap_tool_call(
         request: ToolCallRequest,
         handler: Callable[[ToolCallRequest], ToolMessage | Command],
     ) -> ToolMessage | Command:
-        args = format_tool_args_compact(request.tool_call['args'])
+        args = format_tool_args_compact(request.tool_call["args"])
         self.log.info(f"{request.tool_call['name']}({args})")
 
         try:
@@ -115,6 +116,10 @@ def defaults(self):
             system_prompt="...",
         )
 
+    @property
+    def yroom_manager(self):
+        return self.parent.serverapp.web_app.settings["yroom_manager"]
+
     async def get_memory_store(self):
         if not hasattr(self, "_memory_store"):
             conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False)
@@ -130,7 +135,7 @@ def get_tools(self):
     async def get_agent(self, model_id: str, model_args, system_prompt: str):
         model = ChatLiteLLM(**model_args, model=model_id, streaming=True)
         memory_store = await self.get_memory_store()
-        
+
         return create_agent(
             model,
             system_prompt=system_prompt,
@@ -158,18 +163,20 @@ async def process_message(self, message: Message) -> None:
             model_id=model_id, model_args=model_args, system_prompt=system_prompt
         )
 
+        context = {
+            "thread_id": self.ychat.get_id(),
+            "username": message.sender
+        }
+
         async def create_aiter():
             async for token, metadata in agent.astream(
                 {"messages": [{"role": "user", "content": message.body}]},
-                {"configurable": {"thread_id": self.ychat.get_id()}},
+                {"configurable": context},
                 stream_mode="messages",
             ):
                 node = metadata["langgraph_node"]
                 content_blocks = token.content_blocks
-                if (
-                    node == "model"
-                    and content_blocks
-                ):
+                if node == "model" and content_blocks:
                     if token.text:
                         yield token.text
 
@@ -195,6 +202,6 @@ def get_system_prompt(
         return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args)
 
     def shutdown(self):
-        if hasattr(self,"_memory_store"):
+        if hasattr(self, "_memory_store"):
             self.parent.event_loop.create_task(self._memory_store.conn.close())
         super().shutdown()
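
Reviewer note: a minimal sketch (not part of this commit) of how the per-thread `configurable` context added in `process_message` is consumed by the LangGraph agent. It only reuses calls visible in the diff above; `agent`, `ychat`, and `message` are assumed to be the objects built elsewhere in this persona, and `stream_reply` is a hypothetical helper name.

# Sketch only: assumes `agent` was built by get_agent(...) above and that
# `ychat` / `message` are the chat room and incoming message from process_message.
async def stream_reply(agent, ychat, message) -> str:
    context = {
        "thread_id": ychat.get_id(),  # one LangGraph checkpoint thread per chat
        "username": message.sender,   # attributes the turn to its sender
    }
    reply = ""
    async for token, metadata in agent.astream(
        {"messages": [{"role": "user", "content": message.body}]},
        {"configurable": context},
        stream_mode="messages",
    ):
        # Keep only text emitted by the model node, mirroring create_aiter() above.
        if metadata["langgraph_node"] == "model" and token.content_blocks and token.text:
            reply += token.text
    return reply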