|
1 | 1 | import os |
2 | | -from typing import Any, Callable |
| 2 | +from typing import Any |
3 | 3 |
|
4 | 4 | import aiosqlite |
5 | 5 | from jupyter_ai_persona_manager import BasePersona, PersonaDefaults |
6 | 6 | from jupyter_core.paths import jupyter_data_dir |
7 | 7 | from jupyterlab_chat.models import Message |
8 | 8 | from langchain.agents import create_agent |
9 | | -from langchain.agents.middleware import AgentMiddleware |
10 | | -from langchain.messages import ToolMessage |
11 | | -from langchain.tools.tool_node import ToolCallRequest |
12 | 9 | from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver |
13 | | -from langgraph.types import Command |
14 | 10 |
|
15 | 11 | from .chat_models import ChatLiteLLM |
16 | 12 | from .prompt_template import ( |
|
30 | 26 | ) |
31 | 27 |
|
32 | 28 |
|
33 | | -def format_tool_args_compact(args_dict, threshold=25): |
34 | | - """ |
35 | | - Create a more compact string representation of tool call args. |
36 | | - Each key-value pair is on its own line for better readability. |
37 | | -
|
38 | | - Args: |
39 | | - args_dict (dict): Dictionary of tool arguments |
40 | | - threshold (int): Maximum number of lines before truncation (default: 25) |
41 | | -
|
42 | | - Returns: |
43 | | - str: Formatted string representation of arguments |
44 | | - """ |
45 | | - if not args_dict: |
46 | | - return "{}" |
47 | | - |
48 | | - formatted_pairs = [] |
49 | | - |
50 | | - for key, value in args_dict.items(): |
51 | | - value_str = str(value) |
52 | | - lines = value_str.split("\n") |
53 | | - |
54 | | - if len(lines) <= threshold: |
55 | | - if len(lines) == 1 and len(value_str) > 80: |
56 | | - # Single long line - truncate |
57 | | - truncated = value_str[:77] + "..." |
58 | | - formatted_pairs.append(f" {key}: {truncated}") |
59 | | - else: |
60 | | - # Add indentation for multi-line values |
61 | | - if len(lines) > 1: |
62 | | - indented_value = "\n ".join([""] + lines) |
63 | | - formatted_pairs.append(f" {key}:{indented_value}") |
64 | | - else: |
65 | | - formatted_pairs.append(f" {key}: {value_str}") |
66 | | - else: |
67 | | - # Truncate and add summary |
68 | | - truncated_lines = lines[:threshold] |
69 | | - remaining_lines = len(lines) - threshold |
70 | | - indented_value = "\n ".join([""] + truncated_lines) |
71 | | - formatted_pairs.append( |
72 | | - f" {key}:{indented_value}\n [+{remaining_lines} more lines]" |
73 | | - ) |
74 | | - |
75 | | - return "{\n" + ",\n".join(formatted_pairs) + "\n}" |
76 | | - |
77 | | - |
78 | | -class ToolMonitoringMiddleware(AgentMiddleware): |
79 | | - def __init__(self, *, persona: BasePersona): |
80 | | - self.stream_message = persona.stream_message |
81 | | - self.log = persona.log |
82 | | - |
83 | | - async def awrap_tool_call( |
84 | | - self, |
85 | | - request: ToolCallRequest, |
86 | | - handler: Callable[[ToolCallRequest], ToolMessage | Command], |
87 | | - ) -> ToolMessage | Command: |
88 | | - args = format_tool_args_compact(request.tool_call["args"]) |
89 | | - self.log.info(f"{request.tool_call['name']}({args})") |
90 | | - |
91 | | - try: |
92 | | - result = await handler(request) |
93 | | - self.log.info(f"{request.tool_call['name']} Done!") |
94 | | - return result |
95 | | - except Exception as e: |
96 | | - self.log.info(f"{request.tool_call['name']} failed: {e}") |
97 | | - return ToolMessage( |
98 | | - tool_call_id=request.tool_call["id"], status="error", content=f"{e}" |
99 | | - ) |
100 | | - |
101 | | - |
102 | 29 | class JupyternautPersona(BasePersona): |
103 | 30 | """ |
104 | 31 | The Jupyternaut persona, the main persona provided by Jupyter AI. |
|
0 commit comments