diff --git a/ts/packages/agents/browser/src/agent/actionHandler.mts b/ts/packages/agents/browser/src/agent/actionHandler.mts index fdce4aa82..9b02da80d 100644 --- a/ts/packages/agents/browser/src/agent/actionHandler.mts +++ b/ts/packages/agents/browser/src/agent/actionHandler.mts @@ -83,7 +83,7 @@ async function updateBrowserContext( return; } - const webSocket = await createWebSocket(); + const webSocket = await createWebSocket("browser", "dispatcher"); if (webSocket) { context.agentContext.webSocket = webSocket; context.agentContext.browserConnector = new BrowserConnector(context); diff --git a/ts/packages/agents/browser/src/extension/serviceWorker.ts b/ts/packages/agents/browser/src/extension/serviceWorker.ts index 4f1db532b..0de119015 100644 --- a/ts/packages/agents/browser/src/extension/serviceWorker.ts +++ b/ts/packages/agents/browser/src/extension/serviceWorker.ts @@ -39,7 +39,7 @@ export async function createWebSocket() { let socketEndpoint = configValues["WEBSOCKET_HOST"] ?? "ws://localhost:8080/"; - socketEndpoint += "?clientId=" + chrome.runtime.id; + socketEndpoint += `?channel=browser&role=client&clientId=${chrome.runtime.id}`; return new Promise((resolve, reject) => { const webSocket = new WebSocket(socketEndpoint); console.log("Connected to: " + socketEndpoint); @@ -133,11 +133,13 @@ async function ensureWebsocketConnected() { } }; - webSocket.onclose = (event: object) => { + webSocket.onclose = (event: any) => { console.log("websocket connection closed"); webSocket = undefined; showBadgeError(); - reconnectWebSocket(); + if (event.reason !== "duplicate") { + reconnectWebSocket(); + } }; resolve(webSocket); diff --git a/ts/packages/agents/code/src/codeActionHandler.ts b/ts/packages/agents/code/src/codeActionHandler.ts index f02515a95..f05f41eff 100644 --- a/ts/packages/agents/code/src/codeActionHandler.ts +++ b/ts/packages/agents/code/src/codeActionHandler.ts @@ -52,7 +52,7 @@ async function updateCodeContext( return; } - const webSocket = await createWebSocket(); + const webSocket = await createWebSocket("code", "dispatcher"); if (webSocket) { agentContext.webSocket = webSocket; agentContext.pendingCall = new Map(); diff --git a/ts/packages/coda/src/webSocket.ts b/ts/packages/coda/src/webSocket.ts index 321392079..a3f4fb9d0 100644 --- a/ts/packages/coda/src/webSocket.ts +++ b/ts/packages/coda/src/webSocket.ts @@ -11,9 +11,19 @@ export type WebSocketMessage = { body: any; }; -export async function createWebSocket() { +export async function createWebSocket( + channel: string, + role: string, + clientId?: string, +) { return new Promise((resolve, reject) => { - const webSocket = new WebSocket("ws://localhost:8080/"); + let endpoint = "ws://localhost:8080"; + endpoint += `?channel=${channel}&role=${role}`; + if (clientId) { + endpoint += `clientId=${clientId}`; + } + + const webSocket = new WebSocket(endpoint); webSocket.onopen = (event: object) => { console.log("websocket open"); diff --git a/ts/packages/coda/src/wsConnect.ts b/ts/packages/coda/src/wsConnect.ts index 843fa33a3..173bef92e 100644 --- a/ts/packages/coda/src/wsConnect.ts +++ b/ts/packages/coda/src/wsConnect.ts @@ -20,7 +20,7 @@ async function ensureWebsocketConnected() { return; } - webSocket = await createWebSocket(); + webSocket = await createWebSocket("code", "client"); if (!webSocket) { return; } diff --git a/ts/packages/commonUtils/src/webSockets.ts b/ts/packages/commonUtils/src/webSockets.ts index 4cebce079..089523f4f 100644 --- a/ts/packages/commonUtils/src/webSockets.ts +++ b/ts/packages/commonUtils/src/webSockets.ts @@ -17,7 +17,11 @@ export type WebSocketMessage = { body: any; }; -export async function createWebSocket() { +export async function createWebSocket( + channel: string, + role: string, + clientId?: string, +) { return new Promise((resolve, reject) => { let endpoint = "ws://localhost:8080"; if (process.env["WEBSOCKET_HOST"]) { @@ -32,6 +36,11 @@ export async function createWebSocket() { } } + endpoint += `?channel=${channel}&role=${role}`; + if (clientId) { + endpoint += `clientId=${clientId}`; + } + const webSocket = new WebSocket(endpoint); webSocket.onopen = (event: object) => { diff --git a/ts/packages/dispatcher/src/context/system/handlers/serviceHost/service.ts b/ts/packages/dispatcher/src/context/system/handlers/serviceHost/service.ts index d7e84e94f..d62bf8815 100644 --- a/ts/packages/dispatcher/src/context/system/handlers/serviceHost/service.ts +++ b/ts/packages/dispatcher/src/context/system/handlers/serviceHost/service.ts @@ -6,11 +6,26 @@ import { WebSocketMessage } from "common-utils"; import registerDebug from "debug"; import { IncomingMessage } from "node:http"; +interface Client { + id: string | null; + role: string; + socket: WebSocket; + channelName: string; +} + +interface Channel { + name: string; + clients: Set; +} + const debug = registerDebug("typeagent:serviceHost"); const hostEndpoint = process.env["WEBSOCKET_HOST"] ?? "ws://localhost:8080"; const url = new URL(hostEndpoint); +// Channels organized by agentType +const channels: Map = new Map(); + try { const wss = new WebSocketServer({ port: parseInt(url.port), @@ -32,29 +47,73 @@ try { wss.on("connection", (ws: WebSocket, req: IncomingMessage) => { debug("New client connected"); - if (req.url) { - const params = new URLSearchParams(req.url.split("?")[1]); - const clientId = params.get("clientId"); - if (clientId) { - for (var client of wss.clients) { - if ((client as any).clientId) { - wss.clients.delete(client); - } - } + const params = new URLSearchParams(req.url?.split("?")[1]); + const clientId = params.get("clientId"); + const channelName = params.get("channel"); + const role = params.get("role"); - (ws as any).clientId = clientId; + if (!channelName || !role) { + ws.send(JSON.stringify({ error: "Missing agentName or role" })); + ws.close(); + return; + } + + // Ensure the channel exists + if (!channels.has(channelName)) { + channels.set(channelName, { + name: channelName, + clients: new Set(), + }); + } + + const channel = channels.get(channelName)!; + const client: Client = { + id: clientId, + role: role, + socket: ws, + channelName: channelName, + }; + + if (clientId) { + for (var socket of wss.clients) { + if ((socket as any).clientId == clientId && socket !== ws) { + debug( + "Closing duplicate socket instance for id " + clientId, + ); + socket.close(1013, "duplicate"); + } } + + (ws as any).clientId = clientId; } - debug(`Connection count: ${wss.clients.size}`); + channel.clients.add(client); + debug(`Client ${clientId} joined channel ${channelName}.`); ws.on("message", (message: string) => { try { const data = JSON.parse(message) as WebSocketMessage; if (data.messageType != "keepAlive") { - // broadcast to all connected clients - // TO DO: add routing to send messages to specific clients - wss.clients.forEach((client) => client.send(message)); + let foundAtLeastOneTarget = false; + + // Broadcast message to all clients in the same channel that have a different role + channel.clients.forEach((client) => { + if ( + client.role !== role && + client.socket.readyState === WebSocket.OPEN + ) { + client.socket.send(message); + foundAtLeastOneTarget = true; + } + }); + + if (!foundAtLeastOneTarget) { + const errorMessage = + data.source === channelName + ? `The ${channelName} agent is not connected. The message cannot be processed.` + : `No ${channelName} clients are listening for messaages on this channel`; + ws.send(JSON.stringify({ error: errorMessage })); + } } } catch { debug("WebSocket message not parsed."); @@ -62,7 +121,14 @@ try { }); ws.on("close", () => { - debug("Client disconnected"); + debug(`Client ${clientId} disconnected.`); + channel.clients.delete(client); + + // Cleanup empty channels + if (channel.clients.size === 0) { + channels.delete(channelName); + debug(`Channel ${channelName} deleted.`); + } }); }); diff --git a/ts/packages/shell/src/main/browserIpc.ts b/ts/packages/shell/src/main/browserIpc.ts index 5904dd03c..0d9310db5 100644 --- a/ts/packages/shell/src/main/browserIpc.ts +++ b/ts/packages/shell/src/main/browserIpc.ts @@ -40,7 +40,11 @@ export class BrowserAgentIpc { } catch {} } - this.webSocket = await createWebSocket(); + this.webSocket = await createWebSocket( + "browser", + "client", + "inlineBrowser", + ); if (!this.webSocket) { resolve(undefined); return;