Skip to content

Commit 96cdc2c

Browse files
3coinsdlqqq
authored andcommitted
Updated tools, prompt to nudge agent to work with chat and active notebook
1 parent 3286283 commit 96cdc2c

File tree

3 files changed

+363
-40
lines changed

3 files changed

+363
-40
lines changed

jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@
77
from jupyterlab_chat.models import Message
88
from langchain.agents import create_agent
99
from langchain.agents.middleware import AgentMiddleware
10-
from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
11-
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
1210
from langchain.messages import ToolMessage
1311
from langchain.tools.tool_node import ToolCallRequest
14-
from langchain_core.messages import ToolMessage
1512
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
1613
from langgraph.types import Command
1714

@@ -26,9 +23,11 @@
2623

2724
MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")
2825

29-
JUPYTERNAUT_AVATAR_PATH = str(os.path.abspath(
30-
os.path.join(os.path.dirname(__file__), "../static", "jupyternaut.svg")
31-
))
26+
JUPYTERNAUT_AVATAR_PATH = str(
27+
os.path.abspath(
28+
os.path.join(os.path.dirname(__file__), "../static", "jupyternaut.svg")
29+
)
30+
)
3231

3332

3433
def format_tool_args_compact(args_dict, threshold=25):
@@ -50,7 +49,7 @@ def format_tool_args_compact(args_dict, threshold=25):
5049

5150
for key, value in args_dict.items():
5251
value_str = str(value)
53-
lines = value_str.split('\n')
52+
lines = value_str.split("\n")
5453

5554
if len(lines) <= threshold:
5655
if len(lines) == 1 and len(value_str) > 80:
@@ -60,16 +59,18 @@ def format_tool_args_compact(args_dict, threshold=25):
6059
else:
6160
# Add indentation for multi-line values
6261
if len(lines) > 1:
63-
indented_value = '\n '.join([''] + lines)
62+
indented_value = "\n ".join([""] + lines)
6463
formatted_pairs.append(f" {key}:{indented_value}")
6564
else:
6665
formatted_pairs.append(f" {key}: {value_str}")
6766
else:
6867
# Truncate and add summary
6968
truncated_lines = lines[:threshold]
7069
remaining_lines = len(lines) - threshold
71-
indented_value = '\n '.join([''] + truncated_lines)
72-
formatted_pairs.append(f" {key}:{indented_value}\n [+{remaining_lines} more lines]")
70+
indented_value = "\n ".join([""] + truncated_lines)
71+
formatted_pairs.append(
72+
f" {key}:{indented_value}\n [+{remaining_lines} more lines]"
73+
)
7374

7475
return "{\n" + ",\n".join(formatted_pairs) + "\n}"
7576

@@ -84,7 +85,7 @@ async def awrap_tool_call(
8485
request: ToolCallRequest,
8586
handler: Callable[[ToolCallRequest], ToolMessage | Command],
8687
) -> ToolMessage | Command:
87-
args = format_tool_args_compact(request.tool_call['args'])
88+
args = format_tool_args_compact(request.tool_call["args"])
8889
self.log.info(f"{request.tool_call['name']}({args})")
8990

9091
try:
@@ -115,6 +116,10 @@ def defaults(self):
115116
system_prompt="...",
116117
)
117118

119+
@property
120+
def yroom_manager(self):
121+
return self.parent.serverapp.web_app.settings["yroom_manager"]
122+
118123
async def get_memory_store(self):
119124
if not hasattr(self, "_memory_store"):
120125
conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False)
@@ -130,7 +135,7 @@ def get_tools(self):
130135
async def get_agent(self, model_id: str, model_args, system_prompt: str):
131136
model = ChatLiteLLM(**model_args, model=model_id, streaming=True)
132137
memory_store = await self.get_memory_store()
133-
138+
134139
return create_agent(
135140
model,
136141
system_prompt=system_prompt,
@@ -158,18 +163,20 @@ async def process_message(self, message: Message) -> None:
158163
model_id=model_id, model_args=model_args, system_prompt=system_prompt
159164
)
160165

166+
context = {
167+
"thread_id": self.ychat.get_id(),
168+
"username": message.sender
169+
}
170+
161171
async def create_aiter():
162172
async for token, metadata in agent.astream(
163173
{"messages": [{"role": "user", "content": message.body}]},
164-
{"configurable": {"thread_id": self.ychat.get_id()}},
174+
{"configurable": context},
165175
stream_mode="messages",
166176
):
167177
node = metadata["langgraph_node"]
168178
content_blocks = token.content_blocks
169-
if (
170-
node == "model"
171-
and content_blocks
172-
):
179+
if node == "model" and content_blocks:
173180
if token.text:
174181
yield token.text
175182

@@ -195,6 +202,6 @@ def get_system_prompt(
195202
return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args)
196203

197204
def shutdown(self):
198-
if hasattr(self,"_memory_store"):
205+
if hasattr(self, "_memory_store"):
199206
self.parent.event_loop.create_task(self._memory_store.conn.close())
200207
super().shutdown()

0 commit comments

Comments
 (0)