diff --git a/clickhouse_db_schema/misc.torchagent_feedback/schema.sql b/clickhouse_db_schema/misc.torchagent_feedback/schema.sql index a67f534672..122f328a60 100644 --- a/clickhouse_db_schema/misc.torchagent_feedback/schema.sql +++ b/clickhouse_db_schema/misc.torchagent_feedback/schema.sql @@ -1,10 +1,12 @@ -CREATE TABLE misc.torchagent_feedback ( - `user` String, - `session_id` String, - `feedback` Int8, - `time_inserted` DateTime64(0, 'UTC') -) -ENGINE = SharedMergeTree('/clickhouse/tables/{uuid}/{shard}', '{replica}') -PARTITION BY toYYYYMM(time_inserted) -ORDER BY (user, session_id, time_inserted) -SETTINGS index_granularity = 8192 +CREATE TABLE + misc.torchagent_feedback ( + `user` String, + `session_id` String, + `torchagent_feedback` String, + `feedback` Int8, + `time_inserted` DateTime64 (0, 'UTC') + ) ENGINE = SharedMergeTree ('/clickhouse/tables/{uuid}/{shard}', '{replica}') +PARTITION BY + toYYYYMM (time_inserted) +ORDER BY + (user, session_id, time_inserted) SETTINGS index_granularity = 8192 \ No newline at end of file diff --git a/torchci/components/TorchAgentPage.tsx b/torchci/components/TorchAgentPage.tsx index 8fa5ca0d8b..2a72c784f2 100644 --- a/torchci/components/TorchAgentPage.tsx +++ b/torchci/components/TorchAgentPage.tsx @@ -1,9 +1,17 @@ -import { Box, Button, Typography, useTheme } from "@mui/material"; +import MenuIcon from "@mui/icons-material/Menu"; +import { + Box, + Button, + IconButton, + Tooltip, + Typography, + useMediaQuery, + useTheme, +} from "@mui/material"; import { useSession } from "next-auth/react"; -import { useEffect, useRef, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; import AISpinner from "./AISpinner"; import { ChatHistorySidebar } from "./TorchAgentPage/ChatHistorySidebar"; -import { FeedbackButtons } from "./TorchAgentPage/FeedbackButtons"; import { GrafanaEmbed } from "./TorchAgentPage/GrafanaEmbed"; import { HeaderSection } from "./TorchAgentPage/HeaderSection"; import { @@ -16,24 +24,26 @@ import { LoadingDisplay } from "./TorchAgentPage/LoadingDisplay"; import { processContentBlockDelta, processMessageLine, - processUserMessages, } from "./TorchAgentPage/messageProcessor"; import { QueryInputSection } from "./TorchAgentPage/QueryInputSection"; import { + ChatMain, + ChatMessages, ChunkMetadata, LoaderWrapper, + MessageBubble, QuerySection, ResponseText, - ResultsSection, TorchAgentPageContainer, } from "./TorchAgentPage/styles"; import { TodoList } from "./TorchAgentPage/TodoList"; import { ToolUse } from "./TorchAgentPage/ToolUse"; import { ParsedContent } from "./TorchAgentPage/types"; import { + extractGrafanaLinks, formatElapsedTime, formatTokenCount, - renderTextWithLinks, + renderMarkdownWithLinks, } from "./TorchAgentPage/utils"; import { WelcomeSection } from "./TorchAgentPage/WelcomeSection"; @@ -62,6 +72,11 @@ const hasAuthCookie = () => { export const TorchAgentPage = () => { const session = useSession(); const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("lg")); // Below 1200px + + // Constants + const typingSpeed = 30; + const sidebarWidth = 300; const featureRequestUrl = "https://github.com/pytorch/test-infra/issues/new?title=" + @@ -82,7 +97,6 @@ export const TorchAgentPage = () => { {} ); const [allToolsExpanded, setAllToolsExpanded] = useState(false); - const [typingSpeed] = useState(10); const [thinkingMessageIndex, setThinkingMessageIndex] = useState(0); const [startTime, setStartTime] = useState(null); const [elapsedTime, setElapsedTime] = useState(0); @@ -92,21 +106,63 @@ export const TorchAgentPage = () => { const [error, setError] = useState(""); const [debugVisible, setDebugVisible] = useState(false); const [currentSessionId, setCurrentSessionId] = useState(null); - const [feedbackVisible, setFeedbackVisible] = useState(false); const [chatHistory, setChatHistory] = useState([]); const [selectedSession, setSelectedSession] = useState(null); const [isHistoryLoading, setIsHistoryLoading] = useState(false); const [isSessionLoading, setIsSessionLoading] = useState(false); const [drawerOpen, setDrawerOpen] = useState(true); + const [headerHeight, setHeaderHeight] = useState(80); // Default fallback + + const contentRef = useRef(null); + + // Auto-collapse sidebar on mobile screens + useEffect(() => { + if (isMobile && drawerOpen) { + setDrawerOpen(false); + } else if (!isMobile && !drawerOpen) { + setDrawerOpen(true); + } + // Don't auto-expand on desktop - preserve user preference + }, [isMobile]); + + // Measure header height dynamically by calculating offset from viewport top + useEffect(() => { + const measureHeaderHeight = () => { + if (contentRef.current) { + const rect = contentRef.current.getBoundingClientRect(); + const height = Math.max(rect.top, 80); // Ensure minimum of 80px + setHeaderHeight(height); + } + }; + + // Use requestAnimationFrame for proper timing after render + const rafId = requestAnimationFrame(measureHeaderHeight); + + // Re-measure on window resize + const handleResize = () => { + requestAnimationFrame(measureHeaderHeight); + }; + + window.addEventListener("resize", handleResize); + return () => { + cancelAnimationFrame(rafId); + window.removeEventListener("resize", handleResize); + }; + }, []); + + const toggleSidebar = () => { + setDrawerOpen(!drawerOpen); + }; const fetchControllerRef = useRef(null); + const chatContainerRef = useRef(null); const thinkingMessages = useThinkingMessages(); const displayedTokens = useAnimatedCounter(totalTokens); const calculateTotalTokens = useTokenCalculator(); const { showScrollButton, scrollToBottomAndEnable, resetAutoScroll } = - useAutoScroll(isLoading, parsedResponses); + useAutoScroll(isLoading, parsedResponses, chatContainerRef); const handleQueryChange = (event: React.ChangeEvent) => { setQuery(event.target.value); @@ -127,7 +183,7 @@ export const TorchAgentPage = () => { ); - const fetchChatHistory = async () => { + const fetchChatHistory = useCallback(async () => { if (!session.data?.user && !hasAuthCookie()) return; setIsHistoryLoading(true); @@ -144,7 +200,7 @@ export const TorchAgentPage = () => { } finally { setIsHistoryLoading(false); } - }; + }, [session.data?.user]); const loadChatSession = async (sessionId: string) => { // Cancel any active stream first @@ -160,6 +216,11 @@ export const TorchAgentPage = () => { setResponse(""); setError(""); + // Close mobile sidebar when loading a session + if (isMobile && drawerOpen) { + setDrawerOpen(false); + } + try { const response = await fetch( `/api/torchagent-get-chat-history?sessionId=${sessionId}` @@ -168,13 +229,6 @@ export const TorchAgentPage = () => { const sessionData = await response.json(); if (sessionData.messages && Array.isArray(sessionData.messages)) { - const userMessage = sessionData.messages.find( - (msg: any) => msg.type === "user_message" && msg.content - ); - if (userMessage) { - setQuery(userMessage.content); - } - setParsedResponses([]); let fullResponse = ""; @@ -185,17 +239,51 @@ export const TorchAgentPage = () => { }); setResponse(fullResponse); - processUserMessages(sessionData.messages, setParsedResponses); - - const lines = fullResponse.split("\n").filter((line) => line.trim()); - lines.forEach((line) => { - processMessageLine(line, setParsedResponses, false); + // Process all messages in chronological order + sessionData.messages.forEach((msg: any) => { + if (msg.type === "user_message" || msg.type === "user") { + // Process user message + const textContent = msg.content; + const grafanaLinks = extractGrafanaLinks(textContent); + + setParsedResponses((prev) => [ + ...prev, + { + type: "user_message", + content: textContent, + displayedContent: textContent, + isAnimating: false, + timestamp: Date.now(), + grafanaLinks: + grafanaLinks.length > 0 ? grafanaLinks : undefined, + }, + ]); + } else if (msg.content) { + // Process assistant message content line by line + const lines = msg.content + .split("\n") + .filter((line: string) => line.trim()); + lines.forEach((line: string) => { + processMessageLine( + line, + setParsedResponses, + false, + undefined, + (sessionId: string) => { + console.log( + "Setting session ID from loadChatSession:", + sessionId + ); + setCurrentSessionId(sessionId); + } + ); + }); + } }); } setSelectedSession(sessionId); setCurrentSessionId(sessionId); - setFeedbackVisible(sessionData.status === "completed"); } else { console.error("Failed to load chat session"); setError("Failed to load chat session"); @@ -222,41 +310,65 @@ export const TorchAgentPage = () => { setParsedResponses([]); setSelectedSession(null); setCurrentSessionId(null); - setFeedbackVisible(false); setError(""); setTotalTokens(0); setCompletedTokens(0); setElapsedTime(0); setCompletedTime(0); setIsSessionLoading(false); + + // Close mobile sidebar when starting new chat + if (isMobile && drawerOpen) { + setDrawerOpen(false); + } }; useEffect(() => { if (session.data?.user) { fetchChatHistory(); } - }, [session.data?.user]); + }, [session.data?.user, fetchChatHistory]); useEffect(() => { if (!session.data?.user) return; - const hasActiveChat = - isLoading || - currentSessionId !== null || - chatHistory.some( - (chat) => - chat.title === "New Chat..." || - (chat.status && chat.status === "in_progress") - ); + let timeoutId: NodeJS.Timeout; + + const scheduleNextUpdate = () => { + // Only refresh history if there's an actual active streaming chat or in_progress sessions + const hasActiveChat = + isLoading || // Currently streaming a response + // OR there's a current session that's in progress (user sent message, waiting for response) + (currentSessionId && + chatHistory.some( + (chat) => + chat.sessionId === currentSessionId && + chat.status === "in_progress" + )) || + // OR there are other sessions marked as in_progress or temporary "New Chat..." entries + chatHistory.some( + (chat) => + chat.title === "New Chat..." || + (chat.status && chat.status === "in_progress") + ); + + if (hasActiveChat) { + timeoutId = setTimeout(async () => { + await fetchChatHistory(); + scheduleNextUpdate(); // Schedule the next check + }, 10000); + } + }; - if (hasActiveChat) { - const interval = setInterval(() => { - fetchChatHistory(); - }, 10000); + // Start the first check + scheduleNextUpdate(); - return () => clearInterval(interval); - } - }, [session.data?.user, isLoading, currentSessionId, chatHistory]); + return () => { + if (timeoutId) { + clearTimeout(timeoutId); + } + }; + }, [session.data?.user, isLoading, currentSessionId, fetchChatHistory]); useEffect(() => { if (!isLoading) return; @@ -349,22 +461,33 @@ export const TorchAgentPage = () => { setResponse((prev) => prev + line + "\n"); const json = JSON.parse(line) as any; - if (json.status === "connecting" && json.userUuid) { - setCurrentSessionId(json.userUuid); - - const now = new Date(); - const timestamp = now.toISOString(); - const tempSession: ChatSession = { - sessionId: json.userUuid, - timestamp: timestamp, - date: timestamp.slice(0, 10), - filename: `${timestamp}_${json.userUuid}.json`, - key: `history/user/${timestamp}_${json.userUuid}.json`, - title: "New Chat...", - }; - - setChatHistory((prev) => [tempSession, ...prev]); - setSelectedSession(json.userUuid); + if (json.status === "connecting" && json.sessionId) { + setCurrentSessionId(json.sessionId); + + // Only add to chat history if this is a new session (not resuming) + if (!json.resumeSession) { + const now = new Date(); + const timestamp = now.toISOString(); + const tempSession: ChatSession = { + sessionId: json.sessionId, + timestamp: timestamp, + date: timestamp.slice(0, 10), + filename: `${timestamp}_${json.sessionId}.json`, + key: `history/user/${timestamp}_${json.sessionId}.json`, + title: "New Chat...", + }; + + setChatHistory((prev) => [tempSession, ...prev]); + setSelectedSession(json.sessionId); + } + // For resumed sessions, we keep the existing selectedSession and just update currentSessionId + return; + } + + // Handle system messages with session_id + if (json.type === "agent_mgmt" && json.sessionId) { + console.log("Received session_id from system message:", json.sessionId); + setCurrentSessionId(json.sessionId); return; } @@ -378,7 +501,19 @@ export const TorchAgentPage = () => { } if (json.type === "assistant" || json.type === "user") { - processMessageLine("", setParsedResponses, true, json); + processMessageLine( + "", + setParsedResponses, + true, + json, + (sessionId: string) => { + console.log( + "Setting session ID from processMessageLine:", + sessionId + ); + setCurrentSessionId(sessionId); + } + ); } else if (json.type === "content_block_delta") { processContentBlockDelta(json, setParsedResponses); } else if (json.error) { @@ -394,9 +529,6 @@ export const TorchAgentPage = () => { fetchControllerRef.current.abort(); fetchControllerRef.current = null; setIsLoading(false); - if (currentSessionId) { - setFeedbackVisible(true); - } } }; @@ -410,10 +542,25 @@ export const TorchAgentPage = () => { cancelRequest(); + // Add user query to parsed responses immediately and clear input + const userMessage = { + type: "user_message" as const, + content: query, + timestamp: Date.now(), + }; + setIsLoading(true); setResponse(""); - setParsedResponses([]); - setFeedbackVisible(false); + + // For new chats or when no session exists, start fresh + // For continued sessions, append to existing responses + if (!currentSessionId || !selectedSession) { + setParsedResponses([userMessage]); // Start fresh for new chats + } else { + setParsedResponses((prev) => [...prev, userMessage]); // Append for continued chats + } + + setQuery(""); // Clear the input immediately setError(""); setAllToolsExpanded(false); resetAutoScroll(); @@ -428,6 +575,18 @@ export const TorchAgentPage = () => { fetchControllerRef.current = new AbortController(); try { + const requestBody: any = { query: userMessage.content }; + + // Include sessionId if this is a continued session + if (currentSessionId) { + console.log("Continuing session with sessionId:", currentSessionId); + requestBody.sessionId = currentSessionId; + } else { + console.log("Starting new session"); + } + + console.log("Sending request body:", requestBody); + const response = await fetch("/api/torchagent-api", { method: "POST", headers: { @@ -436,7 +595,7 @@ export const TorchAgentPage = () => { Connection: "keep-alive", "X-Requested-With": "XMLHttpRequest", }, - body: JSON.stringify({ query }), + body: JSON.stringify(requestBody), signal: fetchControllerRef.current.signal, cache: "no-store", // @ts-ignore @@ -477,16 +636,11 @@ export const TorchAgentPage = () => { parseJsonLine(buffer.trim()); } - setTimeout(() => { - const finalTokens = calculateTotalTokens(parsedResponses); - setCompletedTime(elapsedTime); - setCompletedTokens(finalTokens); - setTotalTokens(finalTokens); - setIsLoading(false); - if (currentSessionId) { - setFeedbackVisible(true); - } - }, 500); + const finalTokens = calculateTotalTokens(parsedResponses); + setCompletedTime(elapsedTime); + setCompletedTokens(finalTokens); + setTotalTokens(finalTokens); + setIsLoading(false); break; } @@ -512,9 +666,6 @@ export const TorchAgentPage = () => { setError(`Error: ${err instanceof Error ? err.message : String(err)}`); } setIsLoading(false); - if (currentSessionId) { - setFeedbackVisible(true); - } } }; @@ -575,66 +726,57 @@ export const TorchAgentPage = () => { .filter((item) => item.type !== "todo_list") .map((item, index) => (
+ {" "} {item.type === "user_message" ? ( - - - User Query: - - - {renderTextWithLinks(item.content, false)} - - + <> + + {renderMarkdownWithLinks( + item.content, + false, + theme.palette.mode === "dark" + )} + {item.grafanaLinks && item.grafanaLinks.length > 0 && ( - + {item.grafanaLinks.map((link, i) => ( ))} - + )} - + ) : item.type === "text" ? ( <> - - {renderTextWithLinks( - (item.displayedContent !== undefined - ? item.displayedContent - : item.content - )?.trim() || "", - item.isAnimating + + + {renderMarkdownWithLinks( + (item.displayedContent !== undefined + ? item.displayedContent + : item.content + )?.trim() || "", + item.isAnimating, + theme.palette.mode === "dark" + )} + + {!item.isAnimating && ( + + {item.outputTokens + ? `${formatTokenCount(item.outputTokens)} tokens` + : ""} + )} - - - {!item.isAnimating && ( - - {/* For historical chats, we skip timing calculations since timestamps are strings */} - {item.outputTokens - ? `${formatTokenCount(item.outputTokens)} tokens` - : ""} - - )} - + {item.grafanaLinks && item.grafanaLinks.length > 0 && ( - + {item.grafanaLinks.map((link, i) => ( ))} - + )} - ) : item.type === "tool_use" && item.toolName ? ( + ) : item.type === "tool_use" && + item.toolName && + item.toolName !== "TodoWrite" && + item.toolName !== "TodoRead" ? ( {
))} - {isLoading - ? renderLoader() - : (completedTokens > 0 || feedbackVisible) && ( - - {completedTokens > 0 && ( - - Completed in {formatElapsedTime(completedTime)} • Total:{" "} - {formatTokenCount(completedTokens)} tokens - - )} - - - )} + {isLoading && renderLoader()} + + {completedTokens > 0 && !isLoading && ( + + + Completed in {formatElapsedTime(completedTime)} • Total:{" "} + {formatTokenCount(completedTokens)} tokens + + + )} ); }; - const sidebarWidth = 300; - return ( - {/* Sidebar */} + {/* Hamburger button for collapsed sidebar */} + {!drawerOpen && ( + + + + + + + + )} + - {/* Main Content */} - + {isSessionLoading ? ( - + ) : ( - + { bugReportUrl={bugReportUrl} /> + + {parsedResponses.length > 0 && ( + <> + + Results + {parsedResponses.length > 0 && + parsedResponses.some( + (item) => item.type === "tool_use" + ) && ( + + )} + + + {error && ( + + {error} + + )} + + {renderContent()} + + {debugVisible && ( + + + Debug: Raw Response + +
+                        {response || "(No data yet)"}
+                      
+
+ )} + + )} +
+ {/* Show welcome message for completely new chats */} {!selectedSession && ( { query={query} isLoading={isLoading} debugVisible={debugVisible} - isReadOnly={selectedSession !== currentSessionId} onQueryChange={handleQueryChange} onSubmit={handleSubmit} onToggleDebug={() => setDebugVisible(!debugVisible)} onCancel={cancelRequest} + currentSessionId={currentSessionId} /> )} - - - - Results - {parsedResponses.length > 0 && - parsedResponses.some((item) => item.type === "tool_use") && ( - - )} - - - {error && ( - - {error} - - )} - - {renderContent()} - - {debugVisible && ( - - - Debug: Raw Response - -
-                    {response || "(No data yet)"}
-                  
-
- )} -
)} -
+
); }; diff --git a/torchci/components/TorchAgentPage/ChatHistorySidebar.tsx b/torchci/components/TorchAgentPage/ChatHistorySidebar.tsx index c7e1ad9dc3..020652cf47 100644 --- a/torchci/components/TorchAgentPage/ChatHistorySidebar.tsx +++ b/torchci/components/TorchAgentPage/ChatHistorySidebar.tsx @@ -1,5 +1,6 @@ import AddIcon from "@mui/icons-material/Add"; import ChatIcon from "@mui/icons-material/Chat"; +import MenuIcon from "@mui/icons-material/Menu"; import { Box, CircularProgress, @@ -31,8 +32,11 @@ interface ChatHistorySidebarProps { chatHistory: ChatSession[]; selectedSession: string | null; isHistoryLoading: boolean; + isMobile: boolean; + headerHeight: number; onStartNewChat: () => void; onLoadChatSession: (sessionId: string) => void; + onToggleSidebar: () => void; } export const ChatHistorySidebar: React.FC = ({ @@ -41,22 +45,30 @@ export const ChatHistorySidebar: React.FC = ({ chatHistory, selectedSession, isHistoryLoading, + isMobile, + headerHeight, onStartNewChat, onLoadChatSession, + onToggleSidebar, }) => { return ( @@ -69,7 +81,18 @@ export const ChatHistorySidebar: React.FC = ({ mb: 2, }} > - Chat History + + + + + + + Chat History + diff --git a/torchci/components/TorchAgentPage/HeaderSection.tsx b/torchci/components/TorchAgentPage/HeaderSection.tsx index 8b0455d643..d3868ef5bf 100644 --- a/torchci/components/TorchAgentPage/HeaderSection.tsx +++ b/torchci/components/TorchAgentPage/HeaderSection.tsx @@ -33,34 +33,36 @@ export const HeaderSection: React.FC = ({ )} - - TorchAgent - + + + TorchAgent + - - - - - - - + + + + + + + + ); diff --git a/torchci/components/TorchAgentPage/LoadingDisplay.tsx b/torchci/components/TorchAgentPage/LoadingDisplay.tsx index d80ac5dd30..ec0f4f60cb 100644 --- a/torchci/components/TorchAgentPage/LoadingDisplay.tsx +++ b/torchci/components/TorchAgentPage/LoadingDisplay.tsx @@ -5,26 +5,39 @@ interface LoadingDisplayProps { message: string; size?: number; showFullScreen?: boolean; + drawerOpen?: boolean; + sidebarWidth?: number; } export const LoadingDisplay: React.FC = ({ message, size = 60, showFullScreen = false, + drawerOpen = false, + sidebarWidth = 300, }) => { + const baseStyles = { + display: "flex", + flexDirection: "column", + alignItems: "center", + justifyContent: "center", + } as const; + const containerSx = showFullScreen ? { - display: "flex", - flexDirection: "column", - alignItems: "center", - justifyContent: "center", + ...baseStyles, height: "100vh", + width: "100%", + maxWidth: "900px", + marginLeft: drawerOpen + ? `calc(50% + ${-sidebarWidth / 2}px)` + : "calc(50%)", + marginRight: "auto", + transform: "translateX(-50%)", + transition: "margin-left 0.3s ease, transform 0.3s ease", } : { - display: "flex", - flexDirection: "column", - alignItems: "center", - justifyContent: "center", + ...baseStyles, height: "300px", }; diff --git a/torchci/components/TorchAgentPage/QueryInputSection.tsx b/torchci/components/TorchAgentPage/QueryInputSection.tsx index 59308c2499..efde236627 100644 --- a/torchci/components/TorchAgentPage/QueryInputSection.tsx +++ b/torchci/components/TorchAgentPage/QueryInputSection.tsx @@ -1,51 +1,46 @@ import { Box, Button, TextField } from "@mui/material"; import React from "react"; +import { FeedbackButtons } from "./FeedbackButtons"; import { QuerySection } from "./styles"; interface QueryInputSectionProps { query: string; isLoading: boolean; debugVisible: boolean; - isReadOnly?: boolean; onQueryChange: (event: React.ChangeEvent) => void; onSubmit: (event: React.FormEvent) => void; onToggleDebug: () => void; onCancel: () => void; + currentSessionId: string | null; } export const QueryInputSection: React.FC = ({ query, isLoading, debugVisible, - isReadOnly = false, onQueryChange, onSubmit, onToggleDebug, onCancel, + currentSessionId, }) => { return ( { - if (!isReadOnly && (e.ctrlKey || e.metaKey) && e.key === "Enter") { + if ((e.ctrlKey || e.metaKey) && e.key === "Enter") { e.preventDefault(); if (!isLoading && query.trim()) { onSubmit(e); @@ -60,9 +55,19 @@ export const QueryInputSection: React.FC = ({ mt: 2, }} > - + + + + {isLoading && ( )} - {!isReadOnly && ( + { - )} + } diff --git a/torchci/components/TorchAgentPage/hooks.ts b/torchci/components/TorchAgentPage/hooks.ts index 8fdb79b025..9cff4e2946 100644 --- a/torchci/components/TorchAgentPage/hooks.ts +++ b/torchci/components/TorchAgentPage/hooks.ts @@ -84,7 +84,8 @@ export const useTokenCalculator = () => { export const useAutoScroll = ( isLoading: boolean, - parsedResponses: ParsedContent[] + parsedResponses: ParsedContent[], + containerRef: React.RefObject ) => { const [autoScrollEnabled, setAutoScrollEnabled] = useState(true); const [showScrollButton, setShowScrollButton] = useState(false); @@ -92,11 +93,13 @@ export const useAutoScroll = ( const scrollToBottomAndEnable = useCallback(() => { setAutoScrollEnabled(true); setShowScrollButton(false); - window.scrollTo({ - top: document.body.scrollHeight, - behavior: "smooth", - }); - }, []); + if (containerRef.current) { + containerRef.current.scrollTo({ + top: containerRef.current.scrollHeight, + behavior: "smooth", + }); + } + }, [containerRef]); const resetAutoScroll = useCallback(() => { setAutoScrollEnabled(true); @@ -105,9 +108,9 @@ export const useAutoScroll = ( useEffect(() => { const isAtBottom = () => { - const scrollPosition = window.innerHeight + window.scrollY; - const bottomOfPage = document.body.offsetHeight - 100; - return scrollPosition >= bottomOfPage; + if (!containerRef.current) return true; + const { scrollTop, scrollHeight, clientHeight } = containerRef.current; + return scrollHeight - scrollTop - clientHeight <= 100; }; const handleScroll = () => { @@ -124,9 +127,16 @@ export const useAutoScroll = ( } }; - window.addEventListener("scroll", handleScroll); - return () => window.removeEventListener("scroll", handleScroll); - }, [isLoading, showScrollButton, parsedResponses, autoScrollEnabled]); + const container = containerRef.current; + container?.addEventListener("scroll", handleScroll); + return () => container?.removeEventListener("scroll", handleScroll); + }, [ + isLoading, + showScrollButton, + parsedResponses, + autoScrollEnabled, + containerRef, + ]); useEffect(() => { if (!isLoading || !autoScrollEnabled || parsedResponses.length === 0) @@ -140,22 +150,24 @@ export const useAutoScroll = ( } const isAtBottom = () => { - const scrollPosition = window.innerHeight + window.scrollY; - const bottomOfPage = document.body.offsetHeight - 50; - return scrollPosition >= bottomOfPage; + if (!containerRef.current) return true; + const { scrollTop, scrollHeight, clientHeight } = containerRef.current; + return scrollHeight - scrollTop - clientHeight <= 50; }; if (!isAtBottom()) { const scrollToBottom = () => { - window.scrollTo({ - top: document.body.scrollHeight, - behavior: "smooth", - }); + if (containerRef.current) { + containerRef.current.scrollTo({ + top: containerRef.current.scrollHeight, + behavior: "smooth", + }); + } }; requestAnimationFrame(scrollToBottom); } - }, [parsedResponses, isLoading, autoScrollEnabled]); + }, [parsedResponses, isLoading, autoScrollEnabled, containerRef]); useEffect(() => { if (!isLoading && parsedResponses.length > 0 && autoScrollEnabled) { @@ -168,16 +180,24 @@ export const useAutoScroll = ( const finalScrollTimer = setTimeout(() => { if (autoScrollEnabled) { - window.scrollTo({ - top: document.body.scrollHeight, - behavior: "smooth", - }); + if (containerRef.current) { + containerRef.current.scrollTo({ + top: containerRef.current.scrollHeight, + behavior: "smooth", + }); + } } }, 200); return () => clearTimeout(finalScrollTimer); } - }, [isLoading, parsedResponses.length, autoScrollEnabled, parsedResponses]); + }, [ + isLoading, + parsedResponses.length, + autoScrollEnabled, + parsedResponses, + containerRef, + ]); return { autoScrollEnabled, diff --git a/torchci/components/TorchAgentPage/messageProcessor.ts b/torchci/components/TorchAgentPage/messageProcessor.ts index 00c8ff5277..c511441098 100644 --- a/torchci/components/TorchAgentPage/messageProcessor.ts +++ b/torchci/components/TorchAgentPage/messageProcessor.ts @@ -6,7 +6,8 @@ export const processMessageLine = ( line: string, setParsedResponses: React.Dispatch>, isStreaming: boolean = false, - json?: any + json?: any, + onSessionIdReceived?: (sessionId: string) => void ): void => { try { // If json is not provided, parse the line @@ -14,6 +15,15 @@ export const processMessageLine = ( json = JSON.parse(line); } + // Handle system messages with session_id + if (json.type === "system" && json.subtype === "init" && json.session_id) { + console.log("Received session_id from system message:", json.session_id); + if (onSessionIdReceived) { + onSessionIdReceived(json.session_id); + } + return; + } + // Handle assistant messages if (json.type === "assistant" && json.message?.content) { json.message.content.forEach((item: any) => { @@ -249,7 +259,8 @@ export const processUserMessages = ( setParsedResponses: React.Dispatch> ): void => { messages.forEach((msg: any) => { - if (msg.type === "user_message" && msg.content) { + // Handle both "user_message" and "user" type messages + if ((msg.type === "user_message" || msg.type === "user") && msg.content) { const textContent = msg.content; const grafanaLinks = extractGrafanaLinks(textContent); diff --git a/torchci/components/TorchAgentPage/styles.ts b/torchci/components/TorchAgentPage/styles.ts index 9114ce580b..a850845657 100644 --- a/torchci/components/TorchAgentPage/styles.ts +++ b/torchci/components/TorchAgentPage/styles.ts @@ -1,17 +1,35 @@ import { Box, Button, Paper, Typography } from "@mui/material"; import { styled } from "@mui/material/styles"; -export const TorchAgentPageContainer = styled("div")({ - fontFamily: "Roboto", - padding: "20px", - maxWidth: "1200px", - margin: "0 auto", +export const TorchAgentPageContainer = styled("div")<{ + drawerOpen?: boolean; + sidebarWidth?: number; +}>(({ drawerOpen = false, sidebarWidth = 300 }) => { + // When drawer is open, we want to center the content in the remaining space + // The sidebar takes up sidebarWidth, so we shift left by half of that to center + const leftOffset = drawerOpen ? -sidebarWidth / 2 : 0; + + return { + fontFamily: "Roboto", + padding: "20px", + width: "100%", + maxWidth: "900px", + marginTop: "0", + marginBottom: "0", + marginLeft: `calc(50% + ${leftOffset}px)`, + marginRight: "auto", + transform: "translateX(-50%)", + transition: "margin-left 0.3s ease, transform 0.3s ease", + }; }); -export const QuerySection = styled(Paper)({ +export const QuerySection = styled(Paper)(({ theme }) => ({ padding: "20px", - marginBottom: "20px", -}); + position: "sticky", + bottom: 0, + zIndex: 5, + borderTop: `1px solid ${theme.palette.divider}`, +})); export const ResultsSection = styled(Paper)(({ theme }) => ({ padding: "20px", @@ -21,14 +39,76 @@ export const ResultsSection = styled(Paper)(({ theme }) => ({ scrollBehavior: "smooth", })); +export const ChatMain = styled(Box)({ + flexGrow: 1, + display: "flex", + flexDirection: "column", + height: "100vh", +}); + +export const ChatMessages = styled(Box)(({ theme }) => ({ + flexGrow: 1, + overflowY: "auto", + padding: "20px", + backgroundColor: theme.palette.mode === "dark" ? "#1a1a1a" : "#f5f5f5", + display: "flex", + flexDirection: "column", +})); + +export const MessageBubble = styled(Box)<{ + from: "user" | "agent"; + fullWidth?: boolean; +}>(({ theme, from, fullWidth }) => ({ + maxWidth: fullWidth ? "100%" : "80%", + padding: "12px", + borderRadius: 12, + marginBottom: "10px", + alignSelf: from === "user" ? "flex-end" : "flex-start", + marginLeft: from === "user" ? "auto" : "0", + marginRight: from === "user" ? "0" : "auto", + backgroundColor: + from === "user" + ? "#059669" // Green color instead of red + : theme.palette.mode === "dark" + ? "#333" + : "#e0e0e0", + color: from === "user" ? "white" : theme.palette.text.primary, +})); + export const ResponseText = styled("div")(({ theme }) => ({ - whiteSpace: "pre-wrap", wordBreak: "break-word", fontFamily: "Roboto, 'Helvetica Neue', Arial, sans-serif", margin: 0, lineHeight: 1.5, - paddingTop: "1em", color: theme.palette.mode === "dark" ? "#e0e0e0" : "inherit", + // Reset styles for markdown content + "& > *:first-of-type": { + marginTop: 0, + }, + "& > *:last-child": { + marginBottom: 0, + }, + // Dark mode adjustments for code blocks + "& code": { + backgroundColor: + theme.palette.mode === "dark" + ? "rgba(255, 255, 255, 0.1)" + : "rgba(0, 0, 0, 0.1)", + color: theme.palette.mode === "dark" ? "#e0e0e0" : "inherit", + }, + "& pre": { + backgroundColor: + theme.palette.mode === "dark" + ? "rgba(255, 255, 255, 0.05)" + : "rgba(0, 0, 0, 0.05)", + }, + "& blockquote": { + borderLeftColor: theme.palette.mode === "dark" ? "#666" : "#ccc", + color: + theme.palette.mode === "dark" + ? "rgba(255, 255, 255, 0.7)" + : "rgba(0, 0, 0, 0.7)", + }, })); export const ToolUseBlock = styled(Paper)(({ theme }) => ({ @@ -107,7 +187,7 @@ export const ChunkMetadata = styled(Typography)(({ theme }) => ({ : "rgba(0, 0, 0, 0.5)", textAlign: "right", marginTop: "4px", - marginBottom: "16px", + marginBottom: "-5px", fontFamily: "Roboto, 'Helvetica Neue', Arial, sans-serif", })); @@ -161,7 +241,7 @@ export const ScrollToBottomButton = styled(Button)(({ theme }) => ({ height: "48px", minWidth: "48px", borderRadius: "50%", - backgroundColor: theme.palette.primary.main, + backgroundColor: "#059669", // Green color color: "white", display: "flex", justifyContent: "center", @@ -172,7 +252,7 @@ export const ScrollToBottomButton = styled(Button)(({ theme }) => ({ transition: "all 0.2s ease-in-out", padding: 0, "&:hover": { - backgroundColor: theme.palette.primary.dark, + backgroundColor: "#047857", // Darker green transform: "scale(1.1)", boxShadow: "0 6px 10px rgba(0, 0, 0, 0.4)", }, diff --git a/torchci/components/TorchAgentPage/types.ts b/torchci/components/TorchAgentPage/types.ts index 4b479d039c..2eddc9c06a 100644 --- a/torchci/components/TorchAgentPage/types.ts +++ b/torchci/components/TorchAgentPage/types.ts @@ -46,6 +46,8 @@ export interface MessageWrapper { is_error?: boolean; result?: string; session_id?: string; + sessionId?: string; + resumeSession?: boolean; usage?: { output_tokens: number; input_tokens?: number; diff --git a/torchci/components/TorchAgentPage/utils.tsx b/torchci/components/TorchAgentPage/utils.tsx index 9af2ecbd16..fb0f20290e 100644 --- a/torchci/components/TorchAgentPage/utils.tsx +++ b/torchci/components/TorchAgentPage/utils.tsx @@ -1,4 +1,5 @@ import React from "react"; +import ReactMarkdown from "react-markdown"; import { GrafanaLink } from "./types"; const GRAFANA_LINK_REGEX = @@ -18,72 +19,297 @@ export const extractGrafanaLinks = (text: string): GrafanaLink[] => { return links; }; -export const renderTextWithLinks = ( - text: string, - isAnimating?: boolean -): React.ReactNode => { - if (!text) return null; +const convertPlainUrlsToMarkdown = (text: string): string => { + // Simple and reliable URL regex: https?:// followed by any non-whitespace characters + const urlRegex = /https?:\/\/\S+/gi; - const result: React.ReactNode[] = []; - let lastIndex = 0; + // Create a copy to work with + let result = text; + let processedUrls = new Set(); + + // Find all URLs and process them let match; - let counter = 0; + const matches = []; - // Reset regex lastIndex to avoid issues with global regex - GRAFANA_LINK_REGEX.lastIndex = 0; + // Reset regex + urlRegex.lastIndex = 0; - while ((match = GRAFANA_LINK_REGEX.exec(text)) !== null) { - if (match.index > lastIndex) { - result.push(text.substring(lastIndex, match.index)); + while ((match = urlRegex.exec(text)) !== null) { + matches.push({ + url: match[0], + index: match.index, + length: match[0].length, + }); + } + + // Process matches in reverse order to avoid index shifting + for (let i = matches.length - 1; i >= 0; i--) { + const { url, index } = matches[i]; + + // Skip if we've already processed this exact URL + if (processedUrls.has(url)) continue; + processedUrls.add(url); + + // Get context around the URL + const beforeUrl = text.substring(Math.max(0, index - 2), index); + const afterUrl = text.substring(index + url.length, index + url.length + 1); + + // Check if it's already in markdown link format + if (beforeUrl.endsWith("](") || afterUrl.startsWith(")")) { + continue; } - result.push( - - {match[0]} - - ); + // Check if it's in an existing markdown link by looking for bracket patterns + const textBeforeUrl = text.substring(0, index); + const lastOpenBracket = textBeforeUrl.lastIndexOf("["); + const lastCloseBracket = textBeforeUrl.lastIndexOf("]("); - lastIndex = match.index + match[0].length; - } + // If we found '](' after the last '[' and close to this URL, skip it + if (lastCloseBracket > lastOpenBracket && index - lastCloseBracket < 20) { + continue; + } + + // Replace this specific occurrence in the result string + const beforeReplace = result.substring(0, index); + const afterReplace = result.substring(index + url.length); + result = beforeReplace + `[${url}](${url})` + afterReplace; - if (lastIndex < text.length) { - result.push(text.substring(lastIndex)); + // Adjust indices for remaining matches + const lengthDiff = `[${url}](${url})`.length - url.length; + for (let j = 0; j < i; j++) { + if (matches[j].index > index) { + matches[j].index += lengthDiff; + } + } } - if (text.length > 0 && result.length > 0 && isAnimating) { - const lastItem = result[result.length - 1]; + return result; +}; + +export const renderMarkdownWithLinks = ( + text: string, + isAnimating?: boolean, + isDarkMode?: boolean +): React.ReactNode => { + if (!text) return null; + + // Convert plain URLs to markdown links if they're not already markdown links + const processedText = convertPlainUrlsToMarkdown(text); - if (typeof lastItem === "string") { - result[result.length - 1] = ( - <> - {lastItem} - ( + - - - ); - } + > + {children} + + ), + // Custom code block styling + code: ({ children, className, ...props }) => { + const isBlock = className?.includes("language-"); + return ( + + {children} + + ); + }, + // Custom pre styling for code blocks + pre: ({ children, ...props }) => ( +
+            {children}
+          
+ ), + // Custom paragraph styling + p: ({ children, ...props }) => ( +

+ {children} +

+ ), + // Custom list styling + ul: ({ children, ...props }) => ( +
    + {children} +
+ ), + ol: ({ children, ...props }) => ( +
    + {children} +
+ ), + // Custom heading styling + h1: ({ children, ...props }) => ( +

+ {children} +

+ ), + h2: ({ children, ...props }) => ( +

+ {children} +

+ ), + h3: ({ children, ...props }) => ( +

+ {children} +

+ ), + // Custom blockquote styling + blockquote: ({ children, ...props }) => ( +
+ {children} +
+ ), + // Custom table styling + table: ({ children, ...props }) => ( + + {children} +
+ ), + th: ({ children, ...props }) => { + // Filter out react-markdown specific props that conflict with HTML props + const { node, ...htmlProps } = props as any; + return ( + + {children} + + ); + }, + td: ({ children, ...props }) => { + // Filter out react-markdown specific props that conflict with HTML props + const { node, ...htmlProps } = props as any; + return ( + + {children} + + ); + }, + }} + > + {processedText} + + ); + + // Add animation cursor if needed + if (isAnimating && text.length > 0) { + return ( +
+ {markdownElement} + + +
+ ); } - return result; + return markdownElement; }; export const formatElapsedTime = (seconds: number): string => { diff --git a/torchci/pages/api/torchagent-api/index.ts b/torchci/pages/api/torchagent-api/index.ts index cb15e61c2a..adb6c6b8ea 100644 --- a/torchci/pages/api/torchagent-api/index.ts +++ b/torchci/pages/api/torchagent-api/index.ts @@ -1,4 +1,3 @@ -import { randomUUID } from "crypto"; import { NextApiRequest, NextApiResponse } from "next"; import { getAuthorizedUsername } from "../../../lib/getAuthorizedUsername"; import { authOptions } from "../auth/[...nextauth]"; @@ -38,8 +37,8 @@ export default async function handler( return; } - // Get query from request body - const { query } = req.body; + // Get query and optional sessionId from request body + const { query, sessionId } = req.body; if (!query || typeof query !== "string") { console.log("Rejected: Invalid query parameter"); @@ -48,8 +47,17 @@ export default async function handler( .json({ error: "Query parameter is required and must be a string" }); } + if (sessionId && typeof sessionId !== "string") { + console.log("Rejected: Invalid sessionId parameter"); + return res + .status(400) + .json({ error: "SessionId parameter must be a string if provided" }); + } + console.log( - `Processing query (${query.length} chars) - forwarding to Lambda` + `Processing query (${query.length} chars) - ${ + sessionId ? `continuing session ${sessionId}` : "new session" + } - forwarding to Lambda` ); // CRITICAL STREAMING HEADERS @@ -76,14 +84,19 @@ export default async function handler( }; try { - // Generate a session ID for this user (could be made more sophisticated) - const userUuid = randomUUID(); - - console.log(`Calling Lambda with userUuid: ${userUuid}`); + const resumeSession = !!sessionId; + console.log( + `Calling Lambda with sessionId: ${sessionId} (${ + resumeSession ? "resumed session" : "new session" + })` + ); console.log("and token: ", AUTH_TOKEN); // Write initial message to start the stream - res.write(`{"status":"connecting","userUuid":"${userUuid}"}\n`); + // For continued sessions, this helps the frontend know the sessionId is being used + res.write( + `{"status":"connecting","sessionId":"${sessionId}","resumeSession":${resumeSession}}\n` + ); flushStream(res); @@ -96,7 +109,7 @@ export default async function handler( }, body: JSON.stringify({ query: query, - userUuid: userUuid, + sessionId: sessionId, username: username, }), });