diff --git a/flight-booking-app/app/api/chat/[id]/route.ts b/flight-booking-app/app/api/chat/[id]/route.ts new file mode 100644 index 0000000..3b6f2a5 --- /dev/null +++ b/flight-booking-app/app/api/chat/[id]/route.ts @@ -0,0 +1,26 @@ +import { chatMessageHook } from "@/workflows/chat/hooks/chat-message"; + +export async function POST( + req: Request, + { params }: { params: Promise<{ id: string }> } +) { + const { message } = await req.json(); + const { id: threadId } = await params; + + console.log("Resuming hook for thread:", threadId, "with message:", message); + + try { + await chatMessageHook.resume(`thread:${threadId}`, { message }); + return Response.json({ success: true }); + } catch (error) { + console.error("Error resuming hook for thread:", threadId, error); + return Response.json( + { + error: `Failed to resume hook: ${ + error instanceof Error ? error.message : "Unknown error" + }`, + }, + { status: 500 } + ); + } +} diff --git a/flight-booking-app/app/api/chat/[id]/stream/route.ts b/flight-booking-app/app/api/chat/[id]/stream/route.ts index 658b6a6..7de4d18 100644 --- a/flight-booking-app/app/api/chat/[id]/stream/route.ts +++ b/flight-booking-app/app/api/chat/[id]/stream/route.ts @@ -7,18 +7,20 @@ import { getRun } from "workflow/api"; //export const maxDuration = 5; export async function GET( - request: Request, - { params }: { params: Promise<{ id: string }> }, + request: Request, + { params }: { params: Promise<{ id: string }> } ) { - const { id } = await params; - const { searchParams } = new URL(request.url); - const startIndexParam = searchParams.get("startIndex"); - const startIndex = - startIndexParam !== null ? parseInt(startIndexParam, 10) : undefined; - const run = getRun(id); - const stream = run.getReadable({ startIndex }); + const { id } = await params; + const { searchParams } = new URL(request.url); - return createUIMessageStreamResponse({ - stream, - }); + const startIndexParam = searchParams.get("startIndex"); + const startIndex = + startIndexParam !== null ? parseInt(startIndexParam, 10) : undefined; + + const run = getRun(id); + const stream = run.getReadable({ startIndex }); + + return createUIMessageStreamResponse({ + stream, + }); } diff --git a/flight-booking-app/app/api/chat/route.ts b/flight-booking-app/app/api/chat/route.ts index 17e8430..7687cb5 100644 --- a/flight-booking-app/app/api/chat/route.ts +++ b/flight-booking-app/app/api/chat/route.ts @@ -8,17 +8,33 @@ import { chat } from "@/workflows/chat"; //export const maxDuration = 8; export async function POST(req: Request) { - const { messages }: { messages: UIMessage[] } = await req.json(); + const body = await req.json(); - const run = await start(chat, [messages]); - const workflowStream = run.readable; + // Extract threadId from body or generate one if not provided + const threadId: string = + body.threadId || + body.messages?.[0]?.metadata?.threadId || + crypto.randomUUID(); + const messages: UIMessage[] = body.messages || []; - return createUIMessageStreamResponse({ - stream: workflowStream, - headers: { - // The workflow run ID is stored into `localStorage` on the client side, - // which influences the `resume` flag in the `useChat` hook. - "x-workflow-run-id": run.runId, - }, - }); + console.log( + "Starting chat workflow for thread:", + threadId, + "with", + messages.length, + "messages" + ); + + const run = await start(chat, [threadId, messages]); + const workflowStream = run.readable; + + return createUIMessageStreamResponse({ + stream: workflowStream, + headers: { + // The workflow run ID is stored on the client side for reconnection + "x-workflow-run-id": run.runId, + // The thread ID is used for sending follow-up messages via hooks + "x-thread-id": threadId, + }, + }); } diff --git a/flight-booking-app/app/api/hooks/approval/route.ts b/flight-booking-app/app/api/hooks/approval/route.ts index 51e4e89..0ced6dc 100644 --- a/flight-booking-app/app/api/hooks/approval/route.ts +++ b/flight-booking-app/app/api/hooks/approval/route.ts @@ -1,4 +1,4 @@ -import { bookingApprovalHook } from '@/workflows/chat/hooks/approval'; +import { bookingApprovalHook } from "@/workflows/chat/hooks/approval"; export async function POST(request: Request) { const { toolCallId, approved, comment } = await request.json(); diff --git a/flight-booking-app/app/page.tsx b/flight-booking-app/app/page.tsx index 909704e..de56ab2 100644 --- a/flight-booking-app/app/page.tsx +++ b/flight-booking-app/app/page.tsx @@ -1,8 +1,7 @@ "use client"; -import { useChat } from "@ai-sdk/react"; import { WorkflowChatTransport } from "@workflow/ai"; -import { useEffect, useMemo, useRef } from "react"; +import { useEffect, useMemo, useRef, useCallback } from "react"; import { Conversation, ConversationContent, @@ -22,6 +21,7 @@ import { import ChatInput from "@/components/chat-input"; import type { MyUIMessage } from "@/schemas/chat"; import { BookingApproval } from "@/components/booking-approval"; +import { useMultiTurnChat } from "@/components/use-multi-turn-chat"; const SUGGESTIONS = [ "Find me flights from San Francisco to Los Angeles", @@ -36,83 +36,83 @@ const FULL_EXAMPLE_PROMPT = export default function ChatPage() { const textareaRef = useRef(null); + // Check for an active workflow run that we should reconnect to const activeWorkflowRunId = useMemo(() => { - if (typeof window === "undefined") return; + if (typeof window === "undefined") return undefined; return localStorage.getItem("active-workflow-run-id") ?? undefined; }, []); - const { stop, error, messages, sendMessage, status, setMessages } = - useChat({ - resume: !!activeWorkflowRunId, - onError(error) { - console.error("onError", error); - }, - onFinish(data) { - console.log("onFinish", data); - - // Update the chat history in `localStorage` to include the latest bot message - console.log("Saving chat history to localStorage", data.messages); - localStorage.setItem("chat-history", JSON.stringify(data.messages)); - - requestAnimationFrame(() => { - textareaRef.current?.focus(); - }); - }, - - transport: new WorkflowChatTransport({ - onChatSendMessage: (response, options) => { - console.log("onChatSendMessage", response, options); + const { + stop, + messages, + sendMessage, + status, + setMessages, + error, + setThreadId, + pendingMessage, + endSession, + } = useMultiTurnChat({ + // Resume existing session if we have an active workflow run + resume: !!activeWorkflowRunId, + onError(error) { + console.error("Chat error:", error); + }, + onFinish() { + // Focus input after response completes + requestAnimationFrame(() => { + textareaRef.current?.focus(); + }); + }, - // Update the chat history in `localStorage` to include the latest user message - localStorage.setItem( - "chat-history", - JSON.stringify(options.messages) - ); + transport: new WorkflowChatTransport({ + onChatSendMessage: (response) => { + // Capture the thread ID from the server for follow-up messages + const serverThreadId = response.headers.get("x-thread-id"); + if (serverThreadId) { + setThreadId(serverThreadId); + localStorage.setItem("active-thread-id", serverThreadId); + } - // We'll store the workflow run ID in `localStorage` to allow the client - // to resume the chat session after a page refresh or network interruption - const workflowRunId = response.headers.get("x-workflow-run-id"); - if (!workflowRunId) { - throw new Error( - 'Workflow run ID not found in "x-workflow-run-id" response header' - ); - } + // Store the workflow run ID for reconnection after page refresh + const workflowRunId = response.headers.get("x-workflow-run-id"); + if (workflowRunId) { localStorage.setItem("active-workflow-run-id", workflowRunId); - }, - onChatEnd: ({ chatId, chunkIndex }) => { - console.log("onChatEnd", chatId, chunkIndex); - - // Once the chat stream ends, we can remove the workflow run ID from `localStorage` - localStorage.removeItem("active-workflow-run-id"); - }, - // Configure reconnection to use the stored workflow run ID - prepareReconnectToStreamRequest: ({ id, api, ...rest }) => { - console.log("prepareReconnectToStreamRequest", id); - const workflowRunId = localStorage.getItem("active-workflow-run-id"); - if (!workflowRunId) { - throw new Error("No active workflow run ID found"); - } - // Use the workflow run ID instead of the chat ID for reconnection - return { - ...rest, - api: `/api/chat/${encodeURIComponent(workflowRunId)}/stream`, - }; - }, - // Optional: Configure error handling for reconnection attempts - maxConsecutiveErrors: 5, - }), - }); + } + }, + onChatEnd: () => { + // Workflow completed - clear the run ID + localStorage.removeItem("active-workflow-run-id"); + }, + // Configure reconnection to use the stored workflow run ID + prepareReconnectToStreamRequest: ({ api, ...rest }) => { + const workflowRunId = localStorage.getItem("active-workflow-run-id"); + if (!workflowRunId) { + throw new Error("No active workflow run ID found for reconnection"); + } + return { + ...rest, + api: `/api/chat/${encodeURIComponent(workflowRunId)}/stream`, + }; + }, + maxConsecutiveErrors: 5, + }), + }); - // Load chat history from `localStorage`. In a real-world application, - // this would likely be done on the server side and loaded from a database, - // but for the purposes of this demo, we'll load it from `localStorage`. - useEffect(() => { - const chatHistory = localStorage.getItem("chat-history"); - if (!chatHistory) return; - setMessages(JSON.parse(chatHistory) as MyUIMessage[]); - }, [setMessages]); + // Clear session and start fresh + const handleNewChat = useCallback(async () => { + try { + await stop(); + } catch { + // Ignore abort errors when stopping + } + await endSession(); + localStorage.removeItem("active-workflow-run-id"); + localStorage.removeItem("active-thread-id"); + setMessages([]); + }, [stop, endSession, setMessages]); - // Activate the input field + // Focus input on mount useEffect(() => { textareaRef.current?.focus(); }, []); @@ -134,7 +134,7 @@ export default function ChatPage() { )} - {messages.length === 0 && ( + {messages.length === 0 && !pendingMessage && (

@@ -178,6 +178,7 @@ export default function ChatPage() {

)} + {messages.map((message, index) => { @@ -200,19 +201,18 @@ export default function ChatPage() { // Render workflow data messages if (part.type === "data-workflow" && "data" in part) { - const data = part.data as any; + const data = part.data as Record; return (
- {data.message} + {String(data.message ?? JSON.stringify(data))}
); } // Render tool parts - // Type guard to check if this is a tool invocation part if ( part.type === "tool-searchFlights" || part.type === "tool-checkFlightStatus" || @@ -221,7 +221,6 @@ export default function ChatPage() { part.type === "tool-checkBaggageAllowance" || part.type === "tool-sleep" ) { - // Additional type guard to ensure we have the required properties if (!("toolCallId" in part) || !("state" in part)) { return null; } @@ -233,7 +232,9 @@ export default function ChatPage() { {part.input ? ( - + } + /> ) : null} ); } + if (part.type === "tool-bookingApproval") { return ( ); })} + + {/* Show pending follow-up message while waiting for stream */} + {pendingMessage && ( + + + {pendingMessage} + Sending... + + + )} + {/* Show loading indicator when message is sent but no assistant response yet */} {messages.length > 0 && messages[messages.length - 1].role === "user" && @@ -314,20 +328,21 @@ export default function ChatPage() { }); }} stop={stop} + onNewChat={handleNewChat} /> ); } // Helper function to render tool outputs with proper formatting -function renderToolOutput(part: any) { - const partOutput = part.output as any; +function renderToolOutput(part: Record) { + const partOutput = part.output; if (!partOutput) { return null; } - const parsedPartOutput = JSON.parse(partOutput); - const output = parsedPartOutput.output.value; - const parsedOutput = JSON.parse(output); + const parsedPartOutput = JSON.parse(String(partOutput)); + const output = parsedPartOutput.output?.value; + const parsedOutput = output ? JSON.parse(output) : null; switch (part.type) { case "tool-searchFlights": { @@ -335,19 +350,19 @@ function renderToolOutput(part: any) { return (

{parsedOutput?.message}

- {flights.map((flight: any) => ( + {flights.map((flight: Record) => (
- {flight.airline} - {flight.flightNumber} + {String(flight.airline)} - {String(flight.flightNumber)}
- {flight.from} → {flight.to} + {String(flight.from)} → {String(flight.to)}
- Departure: {new Date(flight.departure).toLocaleString()} + Departure: {new Date(String(flight.departure)).toLocaleString()}
Status:{" "} @@ -358,10 +373,10 @@ function renderToolOutput(part: any) { : "text-orange-600" } > - {flight.status} + {String(flight.status)}
-
${flight.price}
+
${String(flight.price)}
))}
@@ -369,33 +384,35 @@ function renderToolOutput(part: any) { } case "tool-checkFlightStatus": { - const status = parsedOutput; + const flightStatus = parsedOutput; return (
-
Flight {status.flightNumber}
+
Flight {flightStatus.flightNumber}
Status:{" "} - {status.status} + {flightStatus.status}
- {status.from} → {status.to} + {flightStatus.from} → {flightStatus.to}
-
Airline: {status.airline}
- Departure: {new Date(status.departure).toLocaleString()} + Airline: {flightStatus.airline}
- Arrival: {new Date(status.arrival).toLocaleString()} + Departure: {new Date(flightStatus.departure).toLocaleString()}
-
Gate: {status.gate}
+
+ Arrival: {new Date(flightStatus.arrival).toLocaleString()} +
+
Gate: {flightStatus.gate}
); } @@ -431,7 +448,6 @@ function renderToolOutput(part: any) { case "tool-bookFlight": { const booking = parsedOutput; - return (
✅ Booking Confirmed!
@@ -467,10 +483,11 @@ function renderToolOutput(part: any) { } case "tool-sleep": { + const input = part.input as { durationMs?: number }; return (

- Sleeping for {part.input.durationMs}ms... + Sleeping for {input?.durationMs}ms...

); diff --git a/flight-booking-app/components/booking-approval.tsx b/flight-booking-app/components/booking-approval.tsx index 68a1946..ef43bfa 100644 --- a/flight-booking-app/components/booking-approval.tsx +++ b/flight-booking-app/components/booking-approval.tsx @@ -1,5 +1,5 @@ -'use client'; -import { useState } from 'react'; +"use client"; +import { useState } from "react"; interface BookingApprovalProps { toolCallId: string; input?: { @@ -14,7 +14,7 @@ export function BookingApproval({ input, output, }: BookingApprovalProps) { - const [comment, setComment] = useState(''); + const [comment, setComment] = useState(""); const [isSubmitting, setIsSubmitting] = useState(false); const [error, setError] = useState(null); @@ -42,18 +42,21 @@ export function BookingApproval({ setIsSubmitting(true); setError(null); try { - const response = await fetch('/api/hooks/approval', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, + const response = await fetch("/api/hooks/approval", { + method: "POST", + headers: { "Content-Type": "application/json" }, body: JSON.stringify({ toolCallId, approved, comment }), }); - + if (!response.ok) { const errorData = await response.text(); - throw new Error(`API error: ${response.status} - ${errorData || response.statusText}`); + throw new Error( + `API error: ${response.status} - ${errorData || response.statusText}` + ); } } catch (err) { - const errorMessage = err instanceof Error ? err.message : 'Failed to submit approval'; + const errorMessage = + err instanceof Error ? err.message : "Failed to submit approval"; setError(errorMessage); setIsSubmitting(false); return; @@ -93,7 +96,7 @@ export function BookingApproval({ disabled={isSubmitting} className="px-4 py-2 bg-green-600 text-white rounded hover:bg-green-700 disabled:opacity-50" > - {isSubmitting ? 'Submitting...' : 'Approve'} + {isSubmitting ? "Submitting..." : "Approve"}
diff --git a/flight-booking-app/components/chat-input.tsx b/flight-booking-app/components/chat-input.tsx index db5c408..43a4a9d 100644 --- a/flight-booking-app/components/chat-input.tsx +++ b/flight-booking-app/components/chat-input.tsx @@ -18,6 +18,7 @@ export default function ChatInput({ setMessages, sendMessage, stop, + onNewChat, }: { status: ChatStatus; textareaRef: React.RefObject; @@ -26,6 +27,7 @@ export default function ChatInput({ message: PromptInputMessage & { metadata?: { createdAt: number } } ) => void; stop: () => void; + onNewChat?: () => void | Promise; }) { const [text, setText] = useState(''); @@ -36,6 +38,8 @@ export default function ChatInput({ const hasText = Boolean(message.text); if (!hasText) return; + // Always send the message - the hook will handle routing + // (either as new message or follow-up to existing thread) sendMessage({ text: message.text || '', metadata: { createdAt: Date.now() }, @@ -58,9 +62,14 @@ export default function ChatInput({ size="sm" onClick={async () => { await stop(); - localStorage.removeItem('active-workflow-run-id'); - localStorage.removeItem('chat-history'); - setMessages([]); + if (onNewChat) { + await onNewChat(); + } else { + // Fallback if onNewChat not provided + localStorage.removeItem('active-workflow-run-id'); + localStorage.removeItem('active-thread-id'); + setMessages([]); + } setText(''); }} > diff --git a/flight-booking-app/components/use-multi-turn-chat.ts b/flight-booking-app/components/use-multi-turn-chat.ts new file mode 100644 index 0000000..ef2b3b3 --- /dev/null +++ b/flight-booking-app/components/use-multi-turn-chat.ts @@ -0,0 +1,254 @@ +'use client'; + +// A hook for multi-turn chat sessions with workflow-based agents. +// +// In multi-turn mode: +// - The first message starts a new workflow and is handled by useChat +// - Follow-up messages are sent via hook to the running workflow +// - The stream provides assistant messages; user messages are tracked locally +// +// This approach avoids the complexity of emitting user messages to the stream, +// which would require step-level persistence in the workflow. + +import { useChat, type UseChatOptions } from '@ai-sdk/react'; +import type { UIMessage } from 'ai'; +import { useState, useCallback, useRef, useEffect, useMemo } from 'react'; + +// A follow-up message that was sent via hook +interface FollowUpMessage { + id: string; + content: string; + timestamp: number; + // Whether this message has been "acknowledged" by an assistant response + acknowledged: boolean; +} + +/** + * A hook for multi-turn chat sessions. + * + * This hook: + * - Starts a new workflow for the first message + * - Sends follow-up messages to the running workflow via hooks + * - Tracks follow-up user messages locally for display + * + * @example + * ```typescript + * const { messages, sendMessage, status } = useMultiTurnChat({ + * transport: new WorkflowChatTransport({ + * onChatSendMessage: (response) => { + * const threadId = response.headers.get('x-thread-id'); + * if (threadId) setThreadId(threadId); + * }, + * }), + * }); + * ``` + */ +export function useMultiTurnChat( + options: UseChatOptions = {} +) { + // Track the current thread ID for multi-turn conversations + const [threadId, setThreadId] = useState(null); + const threadIdRef = useRef(null); + + // Track follow-up messages that were sent via hook + const [followUpMessages, setFollowUpMessages] = useState([]); + + // Track pending message for UI feedback while waiting + const [pendingMessage, setPendingMessage] = useState(null); + + // Use the underlying useChat hook with all options passed through + const chatHelpers = useChat(options); + + const { sendMessage: originalSendMessage, messages: streamMessages, status } = chatHelpers; + + // Update ref when threadId changes + useEffect(() => { + threadIdRef.current = threadId; + }, [threadId]); + + // Mark follow-up messages as acknowledged when we get a new assistant message + // This is a heuristic: when assistant responds, we assume pending messages were received + useEffect(() => { + if (status === 'streaming' || status === 'ready') { + // Find the last assistant message + const lastAssistantIdx = streamMessages.findLastIndex(m => m.role === 'assistant'); + if (lastAssistantIdx >= 0) { + setFollowUpMessages(prev => + prev.map(msg => ({ ...msg, acknowledged: true })) + ); + setPendingMessage(null); + } + } + }, [streamMessages, status]); + + // Combine stream messages with follow-up user messages for display + const messages = useMemo(() => { + // If no follow-up messages, just return stream messages + if (followUpMessages.length === 0) { + return streamMessages; + } + + // Insert acknowledged follow-up messages before the corresponding assistant responses + // For now, just append acknowledged follow-ups after the last assistant message + const result: TUIMessage[] = []; + + for (const msg of streamMessages) { + result.push(msg); + } + + // Add acknowledged follow-up messages as synthetic user messages + for (const followUp of followUpMessages.filter(f => f.acknowledged)) { + // Create a synthetic user message + const syntheticMsg = { + id: followUp.id, + role: 'user' as const, + content: followUp.content, + parts: [{ type: 'text' as const, text: followUp.content }], + } as TUIMessage; + + // Insert before the last assistant message if possible + const lastAssistantIdx = result.findLastIndex(m => m.role === 'assistant'); + if (lastAssistantIdx >= 0) { + result.splice(lastAssistantIdx, 0, syntheticMsg); + } else { + result.push(syntheticMsg); + } + } + + return result; + }, [streamMessages, followUpMessages]); + + // Internal function to send follow-up message via hook endpoint + const sendFollowUpInternal = useCallback( + async (messageText: string) => { + const currentThreadId = threadIdRef.current; + if (!currentThreadId) { + throw new Error('No active thread'); + } + + // Set pending message for UI feedback + setPendingMessage(messageText); + + // Track this message locally + const followUpId = `followup-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`; + setFollowUpMessages(prev => [ + ...prev, + { id: followUpId, content: messageText, timestamp: Date.now(), acknowledged: false }, + ]); + + // Send message to resume workflow via hook endpoint + try { + const response = await fetch(`/api/chat/${currentThreadId}`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ message: messageText }), + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(`Failed to send follow-up message: ${errorText}`); + } + } catch (error) { + console.error('Error sending follow-up message:', error); + // Remove the failed message from tracking + setFollowUpMessages(prev => prev.filter(m => m.id !== followUpId)); + setPendingMessage(null); + throw error; + } + }, + [] + ); + + // Smart sendMessage - uses follow-up for subsequent messages in same thread + const sendMessage = useCallback( + async ( + message?: Parameters[0], + requestOptions?: Parameters[1] + ) => { + // Extract message text + const messageText = + typeof message === 'string' + ? message + : message && 'text' in message + ? message.text || '' + : ''; + + // If we already have a thread, send as follow-up to existing workflow + if (threadIdRef.current && messageText) { + console.log( + 'Sending follow-up to existing thread:', + threadIdRef.current + ); + return sendFollowUpInternal(messageText); + } + + // First message - start a new workflow + console.log( + 'Starting new chat (threadId will be set from server response)' + ); + + return originalSendMessage(message, requestOptions); + }, + [originalSendMessage, sendFollowUpInternal] + ); + + // Function to update threadId from server response + const updateThreadId = useCallback((newThreadId: string) => { + console.log('Setting threadId from server:', newThreadId); + threadIdRef.current = newThreadId; + setThreadId(newThreadId); + }, []); + + // End the current multi-turn session + const endSession = useCallback(async () => { + const currentThreadId = threadIdRef.current; + if (!currentThreadId) return; + + try { + await fetch(`/api/chat/${currentThreadId}`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ message: '/done' }), + }); + } catch (error) { + console.error('Error ending session:', error); + } finally { + threadIdRef.current = null; + setThreadId(null); + setFollowUpMessages([]); + } + }, []); + + // Reset thread when messages are explicitly cleared + const wrappedSetMessages = useCallback( + ( + newMessages: TUIMessage[] | ((messages: TUIMessage[]) => TUIMessage[]) + ) => { + // If messages are being cleared, also clear the thread + if (Array.isArray(newMessages) && newMessages.length === 0) { + console.log('Clearing chat state'); + threadIdRef.current = null; + setThreadId(null); + setFollowUpMessages([]); + setPendingMessage(null); + } + // Call the underlying setMessages directly + chatHelpers.setMessages(newMessages); + }, + [chatHelpers.setMessages] + ); + + return { + ...chatHelpers, + messages, + sendMessage, + setMessages: wrappedSetMessages, + // Multi-turn specific + threadId, + setThreadId: updateThreadId, + sendFollowUp: sendFollowUpInternal, + endSession, + isSessionActive: threadId !== null, + pendingMessage, // For showing "sending..." feedback + }; +} diff --git a/flight-booking-app/workflows/chat/hooks/chat-message.ts b/flight-booking-app/workflows/chat/hooks/chat-message.ts new file mode 100644 index 0000000..3582b52 --- /dev/null +++ b/flight-booking-app/workflows/chat/hooks/chat-message.ts @@ -0,0 +1,8 @@ +import { defineHook } from 'workflow'; +import { z } from 'zod'; + +export const chatMessageHook = defineHook({ + schema: z.object({ + message: z.string(), + }), +}); diff --git a/flight-booking-app/workflows/chat/index.ts b/flight-booking-app/workflows/chat/index.ts index 0d38ba9..bdf041d 100644 --- a/flight-booking-app/workflows/chat/index.ts +++ b/flight-booking-app/workflows/chat/index.ts @@ -1,28 +1,88 @@ -import { convertToModelMessages, UIMessageChunk, type UIMessage } from "ai"; import { DurableAgent } from "@workflow/ai/agent"; -import { FLIGHT_ASSISTANT_PROMPT, flightBookingTools } from "./steps/tools"; +import { + type UIMessageChunk, + type UIMessage, + type ModelMessage, + convertToModelMessages, +} from "ai"; + import { getWritable } from "workflow"; +import { FLIGHT_ASSISTANT_PROMPT, flightBookingTools } from "./steps/tools"; +import { chatMessageHook } from "./hooks/chat-message"; + /** - * The main chat workflow + * The main chat workflow with multi-turn support. + * + * This workflow: + * 1. Takes initial messages and streams assistant responses + * 2. Waits for follow-up messages via hook + * 3. Loops until "/done" is received + * + * Note: User messages are NOT emitted to the stream. The client is responsible + * for displaying user messages based on: + * - Initial messages passed to the workflow (shown from client state) + * - Follow-up messages (shown as pending until assistant responds) */ -export async function chat(messages: UIMessage[]) { +export async function chat(threadId: string, initialMessages: UIMessage[]) { "use workflow"; - console.log("Starting workflow"); + console.log("Starting workflow for thread:", threadId); const writable = getWritable(); + // Keep track of messages in ModelMessage format for the agent + let modelMessages: ModelMessage[] = await convertToModelMessages( + initialMessages + ); + const agent = new DurableAgent({ model: "bedrock/claude-haiku-4-5-20251001-v1", system: FLIGHT_ASSISTANT_PROMPT, tools: flightBookingTools, }); - await agent.stream({ - messages: convertToModelMessages(messages), - writable, - }); + // Create hook with thread-specific token for resumption + const hook = chatMessageHook.create({ token: `thread:${threadId}` }); + + while (true) { + // Process current messages and get assistant response + const { messages: resultMessages } = await agent.stream({ + messages: modelMessages, + writable, + preventClose: true, // Keep stream open for follow-ups + }); + + // Update model messages with the result + modelMessages = resultMessages; + + // Wait for next user message via hook + const { message } = await hook; + + // Check if session should end + if (message === "/done") { + console.log("Ending workflow session for thread:", threadId); + break; + } + + // Add user message to conversation in ModelMessage format + modelMessages.push({ + role: "user", + content: message, + }); + + console.log("Received follow-up message, continuing conversation..."); + } + + console.log( + "Finished workflow session with", + modelMessages.length, + "messages" + ); - console.log("Finished workflow"); + return { + threadId, + messageCount: modelMessages.length, + status: "completed" as const, + }; } diff --git a/flight-booking-app/workflows/chat/steps/tools.ts b/flight-booking-app/workflows/chat/steps/tools.ts index bc0a404..293a096 100644 --- a/flight-booking-app/workflows/chat/steps/tools.ts +++ b/flight-booking-app/workflows/chat/steps/tools.ts @@ -1,36 +1,36 @@ -import { FatalError, sleep } from 'workflow'; -import { z } from 'zod'; -import { bookingApprovalHook } from '../hooks/approval'; +import { FatalError, sleep } from "workflow"; +import { z } from "zod"; +import { bookingApprovalHook } from "../hooks/approval"; export const mockAirports: Record< string, { name: string; city: string; timezone: string } > = { SFO: { - name: 'San Francisco International Airport', - city: 'San Francisco', - timezone: 'PST', + name: "San Francisco International Airport", + city: "San Francisco", + timezone: "PST", }, LAX: { - name: 'Los Angeles International Airport', - city: 'Los Angeles', - timezone: 'PST', + name: "Los Angeles International Airport", + city: "Los Angeles", + timezone: "PST", }, JFK: { - name: 'John F. Kennedy International Airport', - city: 'New York', - timezone: 'EST', + name: "John F. Kennedy International Airport", + city: "New York", + timezone: "EST", }, - MIA: { name: 'Miami International Airport', city: 'Miami', timezone: 'EST' }, + MIA: { name: "Miami International Airport", city: "Miami", timezone: "EST" }, ATL: { - name: 'Hartsfield-Jackson Atlanta International Airport', - city: 'Atlanta', - timezone: 'EST', + name: "Hartsfield-Jackson Atlanta International Airport", + city: "Atlanta", + timezone: "EST", }, ORD: { name: "O'Hare International Airport", - city: 'Chicago', - timezone: 'CST', + city: "Chicago", + timezone: "CST", }, }; @@ -44,7 +44,7 @@ export async function searchFlights({ to: string; date: string; }) { - 'use step'; + "use step"; console.log(`Searching flights from ${from} to ${to} on ${date}`); @@ -53,13 +53,13 @@ export async function searchFlights({ // Generate 3 flights with different price points and statuses const airlines = [ - 'United Airlines', - 'American Airlines', - 'Delta Airlines', - 'Southwest Airlines', - 'JetBlue', + "United Airlines", + "American Airlines", + "Delta Airlines", + "Southwest Airlines", + "JetBlue", ]; - const statuses = ['On Time', 'Delayed', 'On Time']; + const statuses = ["On Time", "Delayed", "On Time"]; const priceMultipliers = [1, 1.5, 2.2]; // Budget, mid-range, premium // Base price calculation (could be based on distance, popularity, etc.) @@ -78,8 +78,10 @@ export async function searchFlights({ const arrivalTime = new Date(departureTime.getTime() + duration * 60000); // Generate flight number - const airlineCode = ['UA', 'AA', 'DL', 'WN', 'B6'][index % 5]; - const flightNumber = `${airlineCode}${Math.floor(Math.random() * 900) + 100}`; + const airlineCode = ["UA", "AA", "DL", "WN", "B6"][index % 5]; + const flightNumber = `${airlineCode}${ + Math.floor(Math.random() * 900) + 100 + }`; return { flightNumber, @@ -105,42 +107,42 @@ export async function checkFlightStatus({ }: { flightNumber: string; }) { - 'use step'; + "use step"; console.log(`Checking status for flight ${flightNumber}`); // 10% chance of error to demonstrate retry if (Math.random() < 0.1) { - throw new Error('Flight status service temporarily unavailable'); + throw new Error("Flight status service temporarily unavailable"); } // Generate random flight details const airlines = [ - 'United Airlines', - 'American Airlines', - 'Delta Airlines', - 'Southwest Airlines', - 'JetBlue', + "United Airlines", + "American Airlines", + "Delta Airlines", + "Southwest Airlines", + "JetBlue", ]; const airports = [ - 'LAX', - 'JFK', - 'ORD', - 'ATL', - 'DFW', - 'SFO', - 'MIA', - 'DEN', - 'BOS', - 'SEA', + "LAX", + "JFK", + "ORD", + "ATL", + "DFW", + "SFO", + "MIA", + "DEN", + "BOS", + "SEA", ]; const statuses = [ - 'On Time', - 'Delayed', - 'Boarding', - 'Departed', - 'In Flight', - 'Landed', + "On Time", + "Delayed", + "Boarding", + "Departed", + "In Flight", + "Landed", ]; // Random selections @@ -160,27 +162,31 @@ export async function checkFlightStatus({ // Determine gate based on status const status = statuses[Math.floor(Math.random() * statuses.length)]; - const gate = ['Boarding', 'Departed', 'In Flight', 'Landed'].includes(status) - ? `${['A', 'B', 'C', 'D'][Math.floor(Math.random() * 4)]}${Math.floor(Math.random() * 30) + 1}` + const gate = ["Boarding", "Departed", "In Flight", "Landed"].includes(status) + ? `${["A", "B", "C", "D"][Math.floor(Math.random() * 4)]}${ + Math.floor(Math.random() * 30) + 1 + }` : Math.random() < 0.7 - ? `${['A', 'B', 'C', 'D'][Math.floor(Math.random() * 4)]}${Math.floor(Math.random() * 30) + 1}` - : 'TBD'; + ? `${["A", "B", "C", "D"][Math.floor(Math.random() * 4)]}${ + Math.floor(Math.random() * 30) + 1 + }` + : "TBD"; // Add delay information if status is "Delayed" const delayMinutes = - status === 'Delayed' ? Math.floor(Math.random() * 120) + 15 : 0; + status === "Delayed" ? Math.floor(Math.random() * 120) + 15 : 0; const actualDepartureTime = - status === 'Delayed' + status === "Delayed" ? new Date(departureTime.getTime() + delayMinutes * 60 * 1000) : departureTime; const actualArrivalTime = - status === 'Delayed' + status === "Delayed" ? new Date(arrivalTime.getTime() + delayMinutes * 60 * 1000) : arrivalTime; return { flightNumber: flightNumber.toUpperCase(), - status: status + (status === 'Delayed' ? ` (${delayMinutes} minutes)` : ''), + status: status + (status === "Delayed" ? ` (${delayMinutes} minutes)` : ""), departure: departureTime.toISOString(), arrival: arrivalTime.toISOString(), actualDeparture: actualDepartureTime.toISOString(), @@ -195,7 +201,7 @@ export async function checkFlightStatus({ /** Get airport information */ export async function getAirportInfo({ airportCode }: { airportCode: string }) { - 'use step'; + "use step"; console.log(`Getting information for airport ${airportCode}`); @@ -204,7 +210,7 @@ export async function getAirportInfo({ airportCode }: { airportCode: string }) { if (!airport) { return { error: `Airport code ${airportCode} not found`, - suggestion: `Try one of these: ${Object.keys(mockAirports).join(', ')}`, + suggestion: `Try one of these: ${Object.keys(mockAirports).join(", ")}`, }; } @@ -226,7 +232,7 @@ export async function bookFlight({ passengerName: string; seatPreference?: string; }) { - 'use step'; + "use step"; console.log(`Booking flight ${flightNumber} for ${passengerName}`); @@ -236,17 +242,20 @@ export async function bookFlight({ // 5% chance of seat unavailable if (Math.random() < 0.05) { throw new FatalError( - 'Selected seat preference not available. Please try a different preference.' + "Selected seat preference not available. Please try a different preference." ); } - const confirmationNumber = `BK${Math.random().toString(36).substring(2, 8).toUpperCase()}`; + const confirmationNumber = `BK${Math.random() + .toString(36) + .substring(2, 8) + .toUpperCase()}`; const seatNumber = - seatPreference === 'window' + seatPreference === "window" ? `${Math.floor(Math.random() * 30) + 1}A` - : seatPreference === 'aisle' - ? `${Math.floor(Math.random() * 30) + 1}C` - : `${Math.floor(Math.random() * 30) + 1}B`; + : seatPreference === "aisle" + ? `${Math.floor(Math.random() * 30) + 1}C` + : `${Math.floor(Math.random() * 30) + 1}B`; return { success: true, @@ -254,7 +263,7 @@ export async function bookFlight({ passengerName, flightNumber, seatNumber, - message: 'Flight booked successfully! Check your email for confirmation.', + message: "Flight booked successfully! Check your email for confirmation.", }; } @@ -266,14 +275,14 @@ export async function checkBaggageAllowance({ airline: string; ticketClass: string; }) { - 'use step'; + "use step"; console.log(`Checking baggage allowance for ${airline} ${ticketClass} class`); const allowances = { - economy: { carryOn: 1, checked: 1, maxWeight: '50 lbs' }, - business: { carryOn: 2, checked: 2, maxWeight: '70 lbs' }, - first: { carryOn: 2, checked: 3, maxWeight: '70 lbs' }, + economy: { carryOn: 1, checked: 1, maxWeight: "50 lbs" }, + business: { carryOn: 2, checked: 2, maxWeight: "70 lbs" }, + first: { carryOn: 2, checked: 3, maxWeight: "70 lbs" }, }; const classKey = ticketClass.toLowerCase() as keyof typeof allowances; @@ -285,7 +294,7 @@ export async function checkBaggageAllowance({ carryOnBags: allowance.carryOn, checkedBags: allowance.checked, maxWeightPerBag: allowance.maxWeight, - oversizeFee: '$150 per bag', + oversizeFee: "$150 per bag", }; } @@ -309,73 +318,75 @@ async function executeBookingApproval( // Workflow pauses here until the hook is resolved const { approved, comment } = await hook; if (!approved) { - return `Booking rejected: ${comment || 'No reason provided'}`; + return `Booking rejected: ${comment || "No reason provided"}`; } - return `Booking approved for ${passengerName} on flight ${flightNumber} (Price: ${price})${comment ? ` - Note: ${comment}` : ''}`; + return `Booking approved for ${passengerName} on flight ${flightNumber} (Price: ${price})${ + comment ? ` - Note: ${comment}` : "" + }`; } // Tool definitions export const flightBookingTools = { searchFlights: { description: - 'Search for available flights between two cities on a specific date', + "Search for available flights between two cities on a specific date", inputSchema: z.object({ - from: z.string().describe('Departure city or airport code'), - to: z.string().describe('Arrival city or airport code'), - date: z.string().describe('Travel date in YYYY-MM-DD format'), + from: z.string().describe("Departure city or airport code"), + to: z.string().describe("Arrival city or airport code"), + date: z.string().describe("Travel date in YYYY-MM-DD format"), }), execute: searchFlights, }, checkFlightStatus: { - description: 'Check the current status of a specific flight', + description: "Check the current status of a specific flight", inputSchema: z.object({ - flightNumber: z.string().describe('Flight number (e.g., UA123)'), + flightNumber: z.string().describe("Flight number (e.g., UA123)"), }), execute: checkFlightStatus, }, getAirportInfo: { - description: 'Get information about a specific airport', + description: "Get information about a specific airport", inputSchema: z.object({ - airportCode: z.string().describe('3-letter airport code (e.g., LAX)'), + airportCode: z.string().describe("3-letter airport code (e.g., LAX)"), }), execute: getAirportInfo, }, bookFlight: { - description: 'Book a flight for a passenger', + description: "Book a flight for a passenger", inputSchema: z.object({ - flightNumber: z.string().describe('Flight number to book'), - passengerName: z.string().describe('Full name of the passenger'), + flightNumber: z.string().describe("Flight number to book"), + passengerName: z.string().describe("Full name of the passenger"), seatPreference: z .string() .optional() - .describe('Seat preference: window, aisle, or middle'), + .describe("Seat preference: window, aisle, or middle"), }), execute: bookFlight, }, checkBaggageAllowance: { description: - 'Check baggage allowance for a specific airline and ticket class', + "Check baggage allowance for a specific airline and ticket class", inputSchema: z.object({ - airline: z.string().describe('Name of the airline'), + airline: z.string().describe("Name of the airline"), ticketClass: z .string() - .describe('Ticket class: economy, business, or first'), + .describe("Ticket class: economy, business, or first"), }), execute: checkBaggageAllowance, }, sleep: { - description: 'Pause execution for a specified duration', + description: "Pause execution for a specified duration", inputSchema: z.object({ - durationMs: z.number().describe('Duration to sleep in milliseconds'), + durationMs: z.number().describe("Duration to sleep in milliseconds"), }), execute: executeSleep, }, bookingApproval: { - description: 'Request human approval before booking a flight', + description: "Request human approval before booking a flight", inputSchema: z.object({ - flightNumber: z.string().describe('Flight number to book'), - passengerName: z.string().describe('Name of the passenger'), - price: z.number().describe('Total price of the booking'), + flightNumber: z.string().describe("Flight number to book"), + passengerName: z.string().describe("Name of the passenger"), + price: z.number().describe("Total price of the booking"), }), execute: executeBookingApproval, },