diff --git a/setup_sse.sh b/setup_sse.sh new file mode 100755 index 0000000..8701d21 --- /dev/null +++ b/setup_sse.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Add required dependencies +echo "Installing dependencies..." +npm install --save uuid express cors + +# Copy updated request handler +echo "Updating requestHandler.ts..." +mv src/requestHandler.updated.ts src/requestHandler.ts + +# Install TypeScript type definitions +echo "Installing TypeScript type definitions..." +npm install --save-dev @types/uuid @types/express + +echo "Setup completed successfully!" +echo "To run the server with SSE support, execute: npm start" \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index e363db7..80d67b6 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,10 +2,32 @@ import { Server } from "@modelcontextprotocol/sdk/server/index.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import express from "express"; +import cors from "cors"; import { createToolDefinitions } from "./tools.js"; import { setupRequestHandlers } from "./requestHandler.js"; +import { setupSSEEndpoints } from "./sseHandler.js"; async function runServer() { + // Create Express app for HTTP server + const app = express(); + const PORT = process.env.PORT || 3000; + + // Enable CORS for all routes + app.use(cors()); + + // Parse JSON bodies + app.use(express.json()); + + // Setup SSE endpoints + setupSSEEndpoints(app); + + // Start HTTP server + const httpServer = app.listen(PORT, () => { + console.log(`HTTP server running on port ${PORT}`); + }); + + // Initialize MCP server const server = new Server( { name: "executeautomation/playwright-mcp-server", @@ -28,6 +50,21 @@ async function runServer() { // Create transport and connect const transport = new StdioServerTransport(); await server.connect(transport); + + // Handle graceful shutdown + const handleShutdown = async () => { + console.log("Shutting down servers..."); + + // Close HTTP server + httpServer.close(); + + // Exit process + process.exit(0); + }; + + // Register shutdown handlers + process.on("SIGINT", handleShutdown); + process.on("SIGTERM", handleShutdown); } runServer().catch((error) => { diff --git a/src/requestHandler.updated.ts b/src/requestHandler.updated.ts new file mode 100644 index 0000000..dc8ea19 --- /dev/null +++ b/src/requestHandler.updated.ts @@ -0,0 +1,70 @@ +import { Server } from "@modelcontextprotocol/sdk/server/index.js"; +import { + ListResourcesRequestSchema, + ReadResourceRequestSchema, + ListToolsRequestSchema, + CallToolRequestSchema, + Tool +} from "@modelcontextprotocol/sdk/types.js"; +import { handleToolCall, getConsoleLogs, getScreenshots } from "./toolHandler.js"; +import { withSSENotifications } from "./sseIntegration.js"; + +export function setupRequestHandlers(server: Server, tools: Tool[]) { + // List resources handler + server.setRequestHandler(ListResourcesRequestSchema, async () => ({ + resources: [ + { + uri: "console://logs", + mimeType: "text/plain", + name: "Browser console logs", + }, + ...Array.from(getScreenshots().keys()).map(name => ({ + uri: `screenshot://${name}`, + mimeType: "image/png", + name: `Screenshot: ${name}`, + })), + ], + })); + + // Read resource handler + server.setRequestHandler(ReadResourceRequestSchema, async (request) => { + const uri = request.params.uri.toString(); + + if (uri === "console://logs") { + return { + contents: [{ + uri, + mimeType: "text/plain", + text: getConsoleLogs().join("\n"), + }], + }; + } + + if (uri.startsWith("screenshot://")) { + const name = uri.split("://")[1]; + const screenshot = getScreenshots().get(name); + if (screenshot) { + return { + contents: [{ + uri, + mimeType: "image/png", + blob: screenshot, + }], + }; + } + } + + throw new Error(`Resource not found: ${uri}`); + }); + + // List tools handler + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: tools, + })); + + // Call tool handler with SSE notifications + const enhancedToolHandler = withSSENotifications(handleToolCall); + server.setRequestHandler(CallToolRequestSchema, async (request) => + enhancedToolHandler(request.params.name, request.params.arguments ?? {}, server) + ); +} \ No newline at end of file diff --git a/src/sse.ts b/src/sse.ts new file mode 100644 index 0000000..6cf2d47 --- /dev/null +++ b/src/sse.ts @@ -0,0 +1,131 @@ +import { v4 as uuidv4 } from 'uuid'; + +interface SSEClient { + id: string; + res: any; +} + +interface SSEEvent { + id?: string; + event?: string; + data: string; + retry?: number; +} + +class SSEManager { + private static instance: SSEManager; + private clients: Map; + + private constructor() { + this.clients = new Map(); + } + + public static getInstance(): SSEManager { + if (!SSEManager.instance) { + SSEManager.instance = new SSEManager(); + } + return SSEManager.instance; + } + + /** + * Register a new SSE client connection + * @param res - Express response object + * @returns Client ID + */ + public registerClient(res: any): string { + // Set headers for SSE + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + 'X-Accel-Buffering': 'no' + }); + + // Send initial connection established message + res.write('retry: 10000\n\n'); + + // Generate unique ID for this client + const clientId = uuidv4(); + + // Store client connection + this.clients.set(clientId, { id: clientId, res }); + + // Send initial connection successful event + this.sendEventToClient(clientId, { + event: 'connected', + data: JSON.stringify({ clientId }) + }); + + console.log(`SSE client connected: ${clientId}`); + return clientId; + } + + /** + * Remove a client connection + * @param clientId - ID of client to remove + */ + public removeClient(clientId: string): void { + if (this.clients.has(clientId)) { + this.clients.delete(clientId); + console.log(`SSE client disconnected: ${clientId}`); + } + } + + /** + * Send an event to a specific client + * @param clientId - ID of client to send to + * @param event - Event to send + */ + public sendEventToClient(clientId: string, event: SSEEvent): boolean { + const client = this.clients.get(clientId); + if (!client) { + return false; + } + + try { + // Format SSE message + let message = ''; + if (event.id) message += `id: ${event.id}\n`; + if (event.event) message += `event: ${event.event}\n`; + message += `data: ${event.data}\n`; + if (event.retry) message += `retry: ${event.retry}\n`; + message += '\n'; + + // Send to client + client.res.write(message); + return true; + } catch (error) { + console.error(`Error sending event to client ${clientId}:`, error); + this.removeClient(clientId); + return false; + } + } + + /** + * Send an event to all connected clients + * @param event - Event to send + */ + public broadcast(event: SSEEvent): void { + this.clients.forEach((client) => { + this.sendEventToClient(client.id, event); + }); + } + + /** + * Get the number of connected clients + * @returns Number of connected clients + */ + public getClientCount(): number { + return this.clients.size; + } + + /** + * Get all connected client IDs + * @returns Array of client IDs + */ + public getClientIds(): string[] { + return Array.from(this.clients.keys()); + } +} + +export default SSEManager; \ No newline at end of file diff --git a/src/sseHandler.ts b/src/sseHandler.ts new file mode 100644 index 0000000..7d06990 --- /dev/null +++ b/src/sseHandler.ts @@ -0,0 +1,66 @@ +import { Express, Request, Response } from "express"; +import SSEManager from "./sse.js"; + +/** + * Setup SSE-related endpoints in Express + * @param app - Express application + */ +export function setupSSEEndpoints(app: Express): void { + const sseManager = SSEManager.getInstance(); + + // SSE connection endpoint + app.get('/sse', (req: Request, res: Response) => { + // Set headers for SSE connection + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache'); + res.setHeader('Connection', 'keep-alive'); + res.setHeader('Access-Control-Allow-Origin', '*'); + res.flushHeaders(); + + // Register client + const clientId = sseManager.registerClient(res); + + // Handle client disconnect + req.on('close', () => { + sseManager.removeClient(clientId); + }); + }); + + // Endpoint to send an event to a specific client + app.post('/sse/event/:clientId', (req: Request, res: Response) => { + const { clientId } = req.params; + const event = req.body; + + if (!event.data) { + return res.status(400).json({ error: 'Event data is required' }); + } + + const success = sseManager.sendEventToClient(clientId, event); + + if (success) { + res.status(200).json({ success: true }); + } else { + res.status(404).json({ error: 'Client not found' }); + } + }); + + // Endpoint to broadcast an event to all clients + app.post('/sse/broadcast', (req: Request, res: Response) => { + const event = req.body; + + if (!event.data) { + return res.status(400).json({ error: 'Event data is required' }); + } + + sseManager.broadcast(event); + res.status(200).json({ success: true, clientCount: sseManager.getClientCount() }); + }); + + // Endpoint to get connected client information + app.get('/sse/clients', (req: Request, res: Response) => { + res.status(200).json({ + count: sseManager.getClientCount(), + clients: sseManager.getClientIds() + }); + }); +} \ No newline at end of file diff --git a/src/sseIntegration.ts b/src/sseIntegration.ts new file mode 100644 index 0000000..3737d6d --- /dev/null +++ b/src/sseIntegration.ts @@ -0,0 +1,117 @@ +import type { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; +import SSEManager from './sse.js'; + +// Get SSE Manager instance +const sseManager = SSEManager.getInstance(); + +/** + * Wraps the original handleToolCall function to add SSE notifications + * + * @param originalHandler - Original tool handler function + * @returns Wrapped handler function with SSE integration + */ +export function withSSENotifications( + originalHandler: (name: string, args: any, server: any) => Promise +) { + return async (name: string, args: any, server: any): Promise => { + // Notify clients that a tool is being called + sseManager.broadcast({ + event: 'tool_call_start', + data: JSON.stringify({ + tool: name, + args: sanitizeArgs(args), + timestamp: new Date().toISOString() + }) + }); + + try { + // Call the original handler + const result = await originalHandler(name, args, server); + + // Broadcast tool result via SSE + sseManager.broadcast({ + event: 'tool_call_complete', + data: JSON.stringify({ + tool: name, + args: sanitizeArgs(args), + success: !result.isError, + timestamp: new Date().toISOString() + }) + }); + + return result; + } catch (error) { + // Notify about tool error via SSE + sseManager.broadcast({ + event: 'tool_call_error', + data: JSON.stringify({ + tool: name, + args: sanitizeArgs(args), + error: (error as Error).message, + timestamp: new Date().toISOString() + }) + }); + + // Re-throw the error to be handled by the original error handling + throw error; + } + }; +} + +/** + * Returns a safe copy of args for JSON serialization + * Removes any sensitive data or complex objects that can't be serialized + */ +function sanitizeArgs(args: any): any { + if (!args) return {}; + + // Create a shallow copy + const sanitized = { ...args }; + + // Remove any sensitive fields + if (sanitized.password) sanitized.password = '******'; + if (sanitized.token) sanitized.token = '******'; + + return sanitized; +} + +/** + * Send a notification about browser events through SSE + */ +export function notifyBrowserEvent(event: string, details: any): void { + sseManager.broadcast({ + event: `browser_${event}`, + data: JSON.stringify({ + event, + details, + timestamp: new Date().toISOString() + }) + }); +} + +/** + * Notify when a screenshot is taken + */ +export function notifyScreenshot(name: string, hasSelector: boolean): void { + sseManager.broadcast({ + event: 'screenshot_taken', + data: JSON.stringify({ + name, + type: hasSelector ? 'element' : 'page', + timestamp: new Date().toISOString() + }) + }); +} + +/** + * Notify when console logs are updated + */ +export function notifyConsoleLogUpdate(count: number): void { + sseManager.broadcast({ + event: 'console_logs_updated', + data: JSON.stringify({ + count, + timestamp: new Date().toISOString() + }) + }); +} \ No newline at end of file diff --git a/src/toolHandler.ts b/src/toolHandler.ts index dd8dc46..c71f1e9 100644 --- a/src/toolHandler.ts +++ b/src/toolHandler.ts @@ -4,6 +4,7 @@ import type { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { BROWSER_TOOLS, API_TOOLS } from './tools.js'; import type { ToolContext } from './tools/common/types.js'; import { ActionRecorder } from './tools/codegen/recorder.js'; +import SSEManager from './sse.js'; import { startCodegenSession, endCodegenSession,