From 9ad97a762f57220030e46727a141475793c80573 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Tue, 23 Sep 2025 21:03:39 -0700 Subject: [PATCH] chore(core): migrate gateway to use websocket protocols --- .../cloudflare-workers/src/manager-driver.ts | 43 +- packages/cloudflare-workers/src/websocket.ts | 8 +- .../driver-test-suite/raw-websocket.ts | 10 - packages/rivetkit/src/actor/instance.ts | 2 - packages/rivetkit/src/actor/persisted.ts | 1 - .../rivetkit/src/actor/router-endpoints.ts | 12 - packages/rivetkit/src/actor/router.ts | 40 +- .../src/common/actor-router-consts.ts | 31 +- packages/rivetkit/src/driver-helpers/mod.ts | 13 +- .../test-inline-client-driver.ts | 36 +- .../driver-test-suite/tests/raw-websocket.ts | 35 -- .../src/drivers/engine/actor-driver.ts | 5 - .../src/drivers/file-system/manager.ts | 4 - packages/rivetkit/src/manager/driver.ts | 1 - packages/rivetkit/src/manager/gateway.ts | 397 ++++++++++++++++++ packages/rivetkit/src/manager/router.ts | 386 +++-------------- .../actor-websocket-client.ts | 37 +- .../rivetkit/src/remote-manager-driver/mod.ts | 14 +- .../src/remote-manager-driver/ws-proxy.ts | 11 +- 19 files changed, 565 insertions(+), 521 deletions(-) create mode 100644 packages/rivetkit/src/manager/gateway.ts diff --git a/packages/cloudflare-workers/src/manager-driver.ts b/packages/cloudflare-workers/src/manager-driver.ts index 95d9175a3..1edc785ab 100644 --- a/packages/cloudflare-workers/src/manager-driver.ts +++ b/packages/cloudflare-workers/src/manager-driver.ts @@ -6,11 +6,13 @@ import { type GetForIdInput, type GetOrCreateWithKeyInput, type GetWithKeyInput, - HEADER_AUTH_DATA, - HEADER_CONN_PARAMS, - HEADER_ENCODING, type ManagerDisplayInformation, type ManagerDriver, + WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, + WS_PROTOCOL_STANDARD, + WS_PROTOCOL_TARGET, } from "rivetkit/driver-helpers"; import { ActorAlreadyExists, InternalError } from "rivetkit/errors"; import { getCloudflareAmbientEnv } from "./handler"; @@ -81,16 +83,22 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { const id = env.ACTOR_DO.idFromString(actorId); const stub = env.ACTOR_DO.get(id); + const protocols: string[] = []; + protocols.push(WS_PROTOCOL_STANDARD); + protocols.push(`${WS_PROTOCOL_TARGET}actor`); + protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`); + protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); + if (params) { + protocols.push( + `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, + ); + } + const headers: Record = { Upgrade: "websocket", Connection: "Upgrade", - [HEADER_ENCODING]: encoding, + "sec-websocket-protocol": protocols.join(", "), }; - if (params) { - headers[HEADER_CONN_PARAMS] = JSON.stringify(params); - } - // HACK: See packages/drivers/cloudflare-workers/src/websocket.ts - headers["sec-websocket-protocol"] = "rivetkit"; // Use the path parameter to determine the URL const normalizedPath = path.startsWith("/") ? path : `/${path}`; @@ -152,7 +160,6 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, - authData: unknown, ): Promise { logger().debug({ msg: "forwarding websocket to durable object", @@ -188,14 +195,18 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { } } - // Add RivetKit headers - actorRequest.headers.set(HEADER_ENCODING, encoding); + // Build protocols for WebSocket connection + const protocols: string[] = []; + protocols.push(WS_PROTOCOL_STANDARD); + protocols.push(`${WS_PROTOCOL_TARGET}actor`); + protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`); + protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); if (params) { - actorRequest.headers.set(HEADER_CONN_PARAMS, JSON.stringify(params)); - } - if (authData) { - actorRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData)); + protocols.push( + `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, + ); } + actorRequest.headers.set("sec-websocket-protocol", protocols.join(", ")); const id = c.env.ACTOR_DO.idFromString(actorId); const stub = c.env.ACTOR_DO.get(id); diff --git a/packages/cloudflare-workers/src/websocket.ts b/packages/cloudflare-workers/src/websocket.ts index 930bf1fee..02d8b22e9 100644 --- a/packages/cloudflare-workers/src/websocket.ts +++ b/packages/cloudflare-workers/src/websocket.ts @@ -4,6 +4,7 @@ import type { UpgradeWebSocket, WSEvents, WSReadyState } from "hono/ws"; import { defineWebSocketHelper, WSContext } from "hono/ws"; +import { WS_PROTOCOL_STANDARD } from "rivetkit/driver-helpers"; // Based on https://github.com/honojs/hono/issues/1153#issuecomment-1767321332 export const upgradeWebSocket: UpgradeWebSocket< @@ -62,8 +63,11 @@ export const upgradeWebSocket: UpgradeWebSocket< // Set Sec-WebSocket-Protocol if does not exist const protocols = c.req.header("Sec-WebSocket-Protocol"); - if (typeof protocols === "string" && protocols.includes("rivetkit")) { - headers["Sec-WebSocket-Protocol"] = "rivetkit"; + if ( + typeof protocols === "string" && + protocols.includes(WS_PROTOCOL_STANDARD) + ) { + headers["Sec-WebSocket-Protocol"] = WS_PROTOCOL_STANDARD; } return new Response(null, { diff --git a/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts b/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts index 7194c35a8..53e31f5c3 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts @@ -48,16 +48,6 @@ export const rawWebSocketActor = actor({ messageCount: ctx.state.messageCount, }), ); - } else if (parsed.type === "getAuthData") { - // Auth data is not directly available in raw WebSocket handler - // Send a message indicating this limitation - websocket.send( - JSON.stringify({ - type: "authData", - authData: null, - message: "Auth data not available in raw WebSocket handler", - }), - ); } else if (parsed.type === "getRequestInfo") { // Send back the request URL info websocket.send( diff --git a/packages/rivetkit/src/actor/instance.ts b/packages/rivetkit/src/actor/instance.ts index 06b4d7a4e..8908131c5 100644 --- a/packages/rivetkit/src/actor/instance.ts +++ b/packages/rivetkit/src/actor/instance.ts @@ -918,7 +918,6 @@ export class ActorInstance { state: CS, driverId: ConnectionDriver, driverState: unknown, - authData: unknown, ): Promise> { this.#assertReady(); @@ -935,7 +934,6 @@ export class ActorInstance { connDriverState: driverState, params: params, state: state, - authData: authData, lastSeen: Date.now(), subscriptions: [], }; diff --git a/packages/rivetkit/src/actor/persisted.ts b/packages/rivetkit/src/actor/persisted.ts index b3287b41f..f7cb61d0d 100644 --- a/packages/rivetkit/src/actor/persisted.ts +++ b/packages/rivetkit/src/actor/persisted.ts @@ -17,7 +17,6 @@ export interface PersistedConn { connDriverState: unknown; params: CP; state: CS; - authData?: unknown; subscriptions: PersistedSubscription[]; lastSeen: number; } diff --git a/packages/rivetkit/src/actor/router-endpoints.ts b/packages/rivetkit/src/actor/router-endpoints.ts index 050abbc24..a97e47966 100644 --- a/packages/rivetkit/src/actor/router-endpoints.ts +++ b/packages/rivetkit/src/actor/router-endpoints.ts @@ -53,7 +53,6 @@ export interface ConnectWebSocketOpts { encoding: Encoding; actorId: string; params: unknown; - authData: unknown; } export interface ConnectWebSocketOutput { @@ -67,7 +66,6 @@ export interface ConnectSseOpts { encoding: Encoding; params: unknown; actorId: string; - authData: unknown; } export interface ConnectSseOutput { @@ -81,7 +79,6 @@ export interface ActionOpts { actionName: string; actionArgs: unknown[]; actorId: string; - authData: unknown; } export interface ActionOutput { @@ -99,14 +96,12 @@ export interface ConnsMessageOpts { export interface FetchOpts { request: Request; actorId: string; - authData: unknown; } export interface WebSocketOpts { request: Request; websocket: UniversalWebSocket; actorId: string; - authData: unknown; } /** @@ -119,7 +114,6 @@ export async function handleWebSocketConnect( actorId: string, encoding: Encoding, parameters: unknown, - authData: unknown, ): Promise { const exposeInternalError = req ? getRequestExposeInternalError(req) : false; @@ -189,7 +183,6 @@ export async function handleWebSocketConnect( connState, CONNECTION_DRIVER_WEBSOCKET, { encoding } satisfies GenericWebSocketDriverState, - authData, ); // Unblock other handlers @@ -339,7 +332,6 @@ export async function handleSseConnect( _runConfig: RunConfig, actorDriver: ActorDriver, actorId: string, - authData: unknown, ) { c.header("Content-Encoding", "Identity"); @@ -376,7 +368,6 @@ export async function handleSseConnect( connState, CONNECTION_DRIVER_SSE, { encoding } satisfies GenericSseDriverState, - authData, ); // Wait for close @@ -459,7 +450,6 @@ export async function handleAction( actorDriver: ActorDriver, actionName: string, actorId: string, - authData: unknown, ) { const encoding = getRequestEncoding(c.req); const parameters = getRequestConnParams(c.req); @@ -491,7 +481,6 @@ export async function handleAction( connState, CONNECTION_DRIVER_HTTP, {} satisfies GenericHttpDriverState, - authData, ); // Call action @@ -562,7 +551,6 @@ export async function handleRawWebSocketHandler( path: string, actorDriver: ActorDriver, actorId: string, - authData: unknown, ): Promise { const actor = await actorDriver.loadActor(actorId); diff --git a/packages/rivetkit/src/actor/router.ts b/packages/rivetkit/src/actor/router.ts index 903fbda9f..2f4104530 100644 --- a/packages/rivetkit/src/actor/router.ts +++ b/packages/rivetkit/src/actor/router.ts @@ -17,7 +17,6 @@ import { handleWebSocketConnect, } from "@/actor/router-endpoints"; import { - HEADER_AUTH_DATA, HEADER_CONN_ID, HEADER_CONN_PARAMS, HEADER_CONN_TOKEN, @@ -84,13 +83,11 @@ export function createActorRouter( return upgradeWebSocket(async (c) => { const encodingRaw = c.req.header(HEADER_ENCODING); const connParamsRaw = c.req.header(HEADER_CONN_PARAMS); - const authDataRaw = c.req.header(HEADER_AUTH_DATA); const encoding = EncodingSchema.parse(encodingRaw); const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined; - const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined; return await handleWebSocketConnect( c.req.raw, @@ -99,7 +96,6 @@ export function createActorRouter( c.env.actorId, encoding, connParams, - authData, ); })(c, noopNext()); } else { @@ -111,32 +107,13 @@ export function createActorRouter( }); router.get("/connect/sse", async (c) => { - const authDataRaw = c.req.header(HEADER_AUTH_DATA); - let authData: unknown; - if (authDataRaw) { - authData = JSON.parse(authDataRaw); - } - - return handleSseConnect(c, runConfig, actorDriver, c.env.actorId, authData); + return handleSseConnect(c, runConfig, actorDriver, c.env.actorId); }); router.post("/action/:action", async (c) => { const actionName = c.req.param("action"); - const authDataRaw = c.req.header(HEADER_AUTH_DATA); - let authData: unknown; - if (authDataRaw) { - authData = JSON.parse(authDataRaw); - } - - return handleAction( - c, - runConfig, - actorDriver, - actionName, - c.env.actorId, - authData, - ); + return handleAction(c, runConfig, actorDriver, actionName, c.env.actorId); }); router.post("/connections/message", async (c) => { @@ -157,12 +134,6 @@ export function createActorRouter( // Raw HTTP endpoints - /http/* router.all("/raw/http/*", async (c) => { - const authDataRaw = c.req.header(HEADER_AUTH_DATA); - let authData: unknown; - if (authDataRaw) { - authData = JSON.parse(authDataRaw); - } - const actor = await actorDriver.loadActor(c.env.actorId); // TODO: This is not a clean way of doing this since `/http/` might exist mid-path @@ -186,9 +157,7 @@ export function createActorRouter( }); // Call the actor's onFetch handler - it will throw appropriate errors - const response = await actor.handleFetch(correctedRequest, { - auth: authData, - }); + const response = await actor.handleFetch(correctedRequest, {}); // This should never happen now since handleFetch throws errors if (!response) { @@ -205,13 +174,11 @@ export function createActorRouter( return upgradeWebSocket(async (c) => { const encodingRaw = c.req.header(HEADER_ENCODING); const connParamsRaw = c.req.header(HEADER_CONN_PARAMS); - const authDataRaw = c.req.header(HEADER_AUTH_DATA); const encoding = EncodingSchema.parse(encodingRaw); const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined; - const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined; const url = new URL(c.req.url); const pathWithQuery = c.req.path + url.search; @@ -229,7 +196,6 @@ export function createActorRouter( pathWithQuery, actorDriver, c.env.actorId, - authData, ); })(c, noopNext()); } else { diff --git a/packages/rivetkit/src/common/actor-router-consts.ts b/packages/rivetkit/src/common/actor-router-consts.ts index b40dd1282..294389f6d 100644 --- a/packages/rivetkit/src/common/actor-router-consts.ts +++ b/packages/rivetkit/src/common/actor-router-consts.ts @@ -5,21 +5,34 @@ export const PATH_CONNECT_WEBSOCKET = "/connect/websocket"; export const PATH_RAW_WEBSOCKET_PREFIX = "/raw/websocket/"; // MARK: Headers -export const HEADER_ACTOR_QUERY = "X-RivetKit-Query"; +export const HEADER_ACTOR_QUERY = "x-rivet-query"; -export const HEADER_ENCODING = "X-RivetKit-Encoding"; +export const HEADER_ENCODING = "x-rivet-encoding"; // IMPORTANT: Params must be in headers or in an E2EE part of the request (i.e. NOT the URL or query string) in order to ensure that tokens can be securely passed in params. -export const HEADER_CONN_PARAMS = "X-RivetKit-Conn-Params"; +export const HEADER_CONN_PARAMS = "x-rivet-conn-params"; -// Internal header -export const HEADER_AUTH_DATA = "X-RivetKit-Auth-Data"; +export const HEADER_ACTOR_ID = "x-rivet-actor"; -export const HEADER_ACTOR_ID = "X-RivetKit-Actor"; +export const HEADER_CONN_ID = "x-rivet-conn"; -export const HEADER_CONN_ID = "X-RivetKit-Conn"; +export const HEADER_CONN_TOKEN = "x-rivet-conn-token"; -export const HEADER_CONN_TOKEN = "X-RivetKit-Conn-Token"; +// MARK: Manager Gateway Headers +export const HEADER_RIVET_TARGET = "x-rivet-target"; +export const HEADER_RIVET_ACTOR = "x-rivet-actor"; + +// MARK: WebSocket Protocol Prefixes +/** Some servers (such as node-ws & Cloudflare) require explicitly match a certain WebSocket protocol. This gives us a static protocol to match against. */ +export const WS_PROTOCOL_STANDARD = "rivet"; +export const WS_PROTOCOL_TARGET = "rivet_target."; +export const WS_PROTOCOL_ACTOR = "rivet_actor."; +export const WS_PROTOCOL_ENCODING = "rivet_encoding."; +export const WS_PROTOCOL_CONN_PARAMS = "rivet_conn_params."; + +// MARK: WebSocket Inline Test Protocol Prefixes +export const WS_PROTOCOL_TRANSPORT = "test_transport."; +export const WS_PROTOCOL_PATH = "test_path."; /** * Headers that publics can send from public clients. @@ -35,4 +48,6 @@ export const ALLOWED_PUBLIC_HEADERS = [ HEADER_ACTOR_ID, HEADER_CONN_ID, HEADER_CONN_TOKEN, + HEADER_RIVET_TARGET, + HEADER_RIVET_ACTOR, ]; diff --git a/packages/rivetkit/src/driver-helpers/mod.ts b/packages/rivetkit/src/driver-helpers/mod.ts index 3b6e49826..c1ac9f757 100644 --- a/packages/rivetkit/src/driver-helpers/mod.ts +++ b/packages/rivetkit/src/driver-helpers/mod.ts @@ -1,13 +1,24 @@ export type { ActorDriver } from "@/actor/driver"; export type { ActorInstance, AnyActorInstance } from "@/actor/instance"; export { + ALLOWED_PUBLIC_HEADERS, HEADER_ACTOR_ID, HEADER_ACTOR_QUERY, - HEADER_AUTH_DATA, HEADER_CONN_ID, HEADER_CONN_PARAMS, HEADER_CONN_TOKEN, HEADER_ENCODING, + HEADER_RIVET_ACTOR, + HEADER_RIVET_TARGET, + PATH_CONNECT_WEBSOCKET, + PATH_RAW_WEBSOCKET_PREFIX, + WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, + WS_PROTOCOL_PATH, + WS_PROTOCOL_STANDARD, + WS_PROTOCOL_TARGET, + WS_PROTOCOL_TRANSPORT, } from "@/common/actor-router-consts"; export type { ActorOutput, diff --git a/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts b/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts index 247bf0515..e32030efa 100644 --- a/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts +++ b/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts @@ -10,6 +10,12 @@ import { HEADER_ACTOR_QUERY, HEADER_CONN_PARAMS, HEADER_ENCODING, + WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, + WS_PROTOCOL_PATH, + WS_PROTOCOL_TARGET, + WS_PROTOCOL_TRANSPORT, } from "@/common/actor-router-consts"; import type { UniversalEventSource } from "@/common/eventsource-interface"; import type { DeconstructedError } from "@/common/utils"; @@ -161,16 +167,9 @@ export function createTestInlineClientDriver( const normalizedPath = path.startsWith("/") ? path.slice(1) : path; // Create WebSocket connection to the test endpoint - // Use a placeholder path and pass the actual path as a query param to avoid mixing user query params with internal ones const wsUrl = new URL( `${endpoint}/.test/inline-driver/connect-websocket/ws`, ); - wsUrl.searchParams.set("path", normalizedPath); - wsUrl.searchParams.set("actorId", actorId); - if (params !== undefined) - wsUrl.searchParams.set("params", JSON.stringify(params)); - wsUrl.searchParams.set("encodingKind", encoding); - wsUrl.searchParams.set("transport", transport); logger().debug({ msg: "creating websocket connection via test inline driver", @@ -179,16 +178,28 @@ export function createTestInlineClientDriver( // Convert http/https to ws/wss const wsProtocol = wsUrl.protocol === "https:" ? "wss:" : "ws:"; - const finalWsUrl = `${wsProtocol}//${wsUrl.host}${wsUrl.pathname}${wsUrl.search}`; + const finalWsUrl = `${wsProtocol}//${wsUrl.host}${wsUrl.pathname}`; logger().debug({ msg: "connecting to websocket", url: finalWsUrl }); + // Build protocols for the connection + const protocols: string[] = []; + protocols.push(`${WS_PROTOCOL_TARGET}actor`); + protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`); + protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); + protocols.push(`${WS_PROTOCOL_TRANSPORT}${transport}`); + protocols.push( + `${WS_PROTOCOL_PATH}${encodeURIComponent(normalizedPath)}`, + ); + if (params !== undefined) { + protocols.push( + `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, + ); + } + // Create and return the WebSocket // Node & browser WebSocket types are incompatible - const ws = new WebSocket(finalWsUrl, [ - // HACK: See packages/drivers/cloudflare-workers/src/websocket.ts - "rivetkit", - ]) as any; + const ws = new WebSocket(finalWsUrl, protocols) as any; return ws; }, @@ -205,7 +216,6 @@ export function createTestInlineClientDriver( _actorId: string, _encoding: Encoding, _params: unknown, - _authData: unknown, ): Promise { throw "UNIMPLEMENTED"; // const upgradeWebSocket = this.#runConfig.getUpgradeWebSocket?.(); diff --git a/packages/rivetkit/src/driver-test-suite/tests/raw-websocket.ts b/packages/rivetkit/src/driver-test-suite/tests/raw-websocket.ts index c4196a305..a0ef07afc 100644 --- a/packages/rivetkit/src/driver-test-suite/tests/raw-websocket.ts +++ b/packages/rivetkit/src/driver-test-suite/tests/raw-websocket.ts @@ -277,41 +277,6 @@ export function runRawWebSocketTests(driverTestConfig: DriverTestConfig) { ws.close(); }); - test("should pass connection parameters through subprotocols", async (c) => { - const { client } = await setupDriverTest(c, driverTestConfig); - - // Create actor with connection parameters - const testParams = { userId: "test123", role: "admin" }; - const actor = client.rawWebSocketActor.getOrCreate(["params"], { - params: testParams, - }); - - const ws = await actor.websocket(); - - await new Promise((resolve) => { - ws.addEventListener("open", () => resolve(), { once: true }); - }); - - // Send a request to echo the auth data (which should include conn params from auth) - ws.send(JSON.stringify({ type: "getAuthData" })); - - const response = await new Promise((resolve, reject) => { - ws.addEventListener("message", (event: any) => { - const data = JSON.parse(event.data as string); - if (data.type === "authData") { - resolve(data); - } - }); - ws.addEventListener("close", reject); - }); - - // For now, just verify we get a response - // The actual connection params handling needs to be implemented - expect(response).toBeDefined(); - - ws.close(); - }); - test("should handle connection close properly", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); const actor = client.rawWebSocketActor.getOrCreate(["close-test"]); diff --git a/packages/rivetkit/src/drivers/engine/actor-driver.ts b/packages/rivetkit/src/drivers/engine/actor-driver.ts index b1672523e..44c6b0950 100644 --- a/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -29,7 +29,6 @@ import type { UniversalWebSocket } from "@/common/websocket-interface"; import { type ActorDriver, type AnyActorInstance, - HEADER_AUTH_DATA, HEADER_CONN_PARAMS, HEADER_ENCODING, type ManagerDriver, @@ -297,11 +296,9 @@ export class EngineActorDriver implements ActorDriver { // Parse headers const encodingRaw = request.headers.get(HEADER_ENCODING); const connParamsRaw = request.headers.get(HEADER_CONN_PARAMS); - const authDataRaw = request.headers.get(HEADER_AUTH_DATA); const encoding = EncodingSchema.parse(encodingRaw); const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined; - const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined; // Fetch WS handler // @@ -315,7 +312,6 @@ export class EngineActorDriver implements ActorDriver { actorId, encoding, connParams, - authData, ); } else if (url.pathname.startsWith(PATH_RAW_WEBSOCKET_PREFIX)) { wsHandlerPromise = handleRawWebSocketHandler( @@ -323,7 +319,6 @@ export class EngineActorDriver implements ActorDriver { url.pathname + url.search, this, actorId, - authData, ); } else { throw new Error(`Unreachable path: ${url.pathname}`); diff --git a/packages/rivetkit/src/drivers/file-system/manager.ts b/packages/rivetkit/src/drivers/file-system/manager.ts index 820078018..47e34d023 100644 --- a/packages/rivetkit/src/drivers/file-system/manager.ts +++ b/packages/rivetkit/src/drivers/file-system/manager.ts @@ -154,7 +154,6 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId, encoding, params, - undefined, ); return new InlineWebSocketAdapter2(wsHandler); } else if ( @@ -168,7 +167,6 @@ export class FileSystemManagerDriver implements ManagerDriver { path, this.#actorDriver, actorId, - undefined, ); return new InlineWebSocketAdapter2(wsHandler); } else { @@ -208,7 +206,6 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId, encoding, connParams, - undefined, ); return upgradeWebSocket(() => wsHandler)(c, noopNext()); } else if ( @@ -222,7 +219,6 @@ export class FileSystemManagerDriver implements ManagerDriver { path, this.#actorDriver, actorId, - undefined, ); return upgradeWebSocket(() => wsHandler)(c, noopNext()); } else { diff --git a/packages/rivetkit/src/manager/driver.ts b/packages/rivetkit/src/manager/driver.ts index 79df91ac1..bd8916c35 100644 --- a/packages/rivetkit/src/manager/driver.ts +++ b/packages/rivetkit/src/manager/driver.ts @@ -33,7 +33,6 @@ export interface ManagerDriver { actorId: string, encoding: Encoding, params: unknown, - authData: unknown, ): Promise; displayInformation(): ManagerDisplayInformation; diff --git a/packages/rivetkit/src/manager/gateway.ts b/packages/rivetkit/src/manager/gateway.ts new file mode 100644 index 000000000..bfa73c070 --- /dev/null +++ b/packages/rivetkit/src/manager/gateway.ts @@ -0,0 +1,397 @@ +import type { Context as HonoContext, Next } from "hono"; +import type { WSContext } from "hono/ws"; +import { MissingActorHeader, WebSocketsNotEnabled } from "@/actor/errors"; +import type { Encoding, Transport } from "@/client/mod"; +import { + HEADER_RIVET_ACTOR, + HEADER_RIVET_TARGET, + WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, + WS_PROTOCOL_TARGET, +} from "@/common/actor-router-consts"; +import { deconstructError, noopNext } from "@/common/utils"; +import type { UniversalWebSocket, UpgradeWebSocketArgs } from "@/mod"; +import type { RunConfig } from "@/registry/run-config"; +import { promiseWithResolvers, stringifyError } from "@/utils"; +import type { ManagerDriver } from "./driver"; +import { logger } from "./log"; + +/** + * Provides an endpoint to connect to individual actors. + * + * Routes requests based on the Upgrade header: + * - WebSocket requests: Uses sec-websocket-protocol for routing (target.actor, actor.{id}) + * - HTTP requests: Uses x-rivet-target and x-rivet-actor headers for routing + */ +export async function actorGateway( + runConfig: RunConfig, + managerDriver: ManagerDriver, + c: HonoContext, + next: Next, +) { + // Skip test routes - let them be handled by their specific handlers + if (c.req.path.startsWith("/.test/")) { + return next(); + } + + // Check if this is a WebSocket upgrade request + if (c.req.header("upgrade") === "websocket") { + return await handleWebSocketGateway(runConfig, managerDriver, c); + } + + // Handle regular HTTP requests + return await handleHttpGateway(managerDriver, c, next); +} + +/** + * Handle WebSocket requests using sec-websocket-protocol for routing + */ +async function handleWebSocketGateway( + runConfig: RunConfig, + managerDriver: ManagerDriver, + c: HonoContext, +) { + const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); + if (!upgradeWebSocket) { + throw new WebSocketsNotEnabled(); + } + + // Parse configuration from Sec-WebSocket-Protocol header + const protocols = c.req.header("sec-websocket-protocol"); + let target: string | undefined; + let actorId: string | undefined; + let encodingRaw: string | undefined; + let connParamsRaw: string | undefined; + + if (protocols) { + const protocolList = protocols.split(",").map((p) => p.trim()); + for (const protocol of protocolList) { + if (protocol.startsWith(WS_PROTOCOL_TARGET)) { + target = protocol.substring(WS_PROTOCOL_TARGET.length); + } else if (protocol.startsWith(WS_PROTOCOL_ACTOR)) { + actorId = protocol.substring(WS_PROTOCOL_ACTOR.length); + } else if (protocol.startsWith(WS_PROTOCOL_ENCODING)) { + encodingRaw = protocol.substring(WS_PROTOCOL_ENCODING.length); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) { + connParamsRaw = decodeURIComponent( + protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), + ); + } + } + } + + if (target !== "actor") { + return c.text("WebSocket upgrade requires target.actor protocol", 400); + } + + if (!actorId) { + throw new MissingActorHeader(); + } + + logger().debug({ + msg: "proxying websocket to actor", + actorId, + path: c.req.path, + encoding: encodingRaw, + }); + + const encoding = encodingRaw || "json"; + const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined; + + // Include query string if present + const pathWithQuery = c.req.url.includes("?") + ? c.req.path + c.req.url.substring(c.req.url.indexOf("?")) + : c.req.path; + + return await managerDriver.proxyWebSocket( + c, + pathWithQuery, + actorId, + encoding as any, // Will be validated by driver + connParams, + ); +} + +/** + * Handle HTTP requests using x-rivet headers for routing + */ +async function handleHttpGateway( + managerDriver: ManagerDriver, + c: HonoContext, + next: Next, +) { + const target = c.req.header(HEADER_RIVET_TARGET); + const actorId = c.req.header(HEADER_RIVET_ACTOR); + + if (target !== "actor") { + return next(); + } + + if (!actorId) { + throw new MissingActorHeader(); + } + + logger().debug({ + msg: "proxying request to actor", + actorId, + path: c.req.path, + method: c.req.method, + }); + + // Preserve all headers except the routing headers + const proxyHeaders = new Headers(c.req.raw.headers); + proxyHeaders.delete(HEADER_RIVET_TARGET); + proxyHeaders.delete(HEADER_RIVET_ACTOR); + + // Build the proxy request with the actor URL format + const url = new URL(c.req.url); + const proxyUrl = new URL(`http://actor${url.pathname}${url.search}`); + + const proxyRequest = new Request(proxyUrl, { + method: c.req.raw.method, + headers: proxyHeaders, + body: c.req.raw.body, + signal: c.req.raw.signal, + }); + + return await managerDriver.proxyRequest(c, proxyRequest, actorId); +} + +/** + * Creates a WebSocket proxy for test endpoints that forwards messages between server and client WebSockets + */ +export async function createTestWebSocketProxy( + clientWsPromise: Promise, +): Promise { + // Store a reference to the resolved WebSocket + let clientWs: UniversalWebSocket | null = null; + const { + promise: serverWsPromise, + resolve: serverWsResolve, + reject: serverWsReject, + } = promiseWithResolvers(); + try { + // Resolve the client WebSocket promise + logger().debug({ msg: "awaiting client websocket promise" }); + const ws = await clientWsPromise; + clientWs = ws; + logger().debug({ + msg: "client websocket promise resolved", + constructor: ws?.constructor.name, + }); + + // Wait for ws to open + await new Promise((resolve, reject) => { + const onOpen = () => { + logger().debug({ msg: "test websocket connection to actor opened" }); + resolve(); + }; + const onError = (error: any) => { + logger().error({ msg: "test websocket connection failed", error }); + reject( + new Error(`Failed to open WebSocket: ${error.message || error}`), + ); + serverWsReject(); + }; + + ws.addEventListener("open", onOpen); + + ws.addEventListener("error", onError); + + ws.addEventListener("message", async (clientEvt: MessageEvent) => { + const serverWs = await serverWsPromise; + + logger().debug({ + msg: `test websocket connection message from client`, + dataType: typeof clientEvt.data, + isBlob: clientEvt.data instanceof Blob, + isArrayBuffer: clientEvt.data instanceof ArrayBuffer, + dataConstructor: clientEvt.data?.constructor?.name, + dataStr: + typeof clientEvt.data === "string" + ? clientEvt.data.substring(0, 100) + : undefined, + }); + + if (serverWs.readyState === 1) { + // OPEN + // Handle Blob data + if (clientEvt.data instanceof Blob) { + clientEvt.data + .arrayBuffer() + .then((buffer) => { + logger().debug({ + msg: "converted client blob to arraybuffer, sending to server", + bufferSize: buffer.byteLength, + }); + serverWs.send(buffer as any); + }) + .catch((error) => { + logger().error({ + msg: "failed to convert blob to arraybuffer", + error, + }); + }); + } else { + logger().debug({ + msg: "sending client data directly to server", + dataType: typeof clientEvt.data, + dataLength: + typeof clientEvt.data === "string" + ? clientEvt.data.length + : undefined, + }); + serverWs.send(clientEvt.data as any); + } + } + }); + + ws.addEventListener("close", async (clientEvt: any) => { + const serverWs = await serverWsPromise; + + logger().debug({ + msg: `test websocket connection closed`, + }); + + if (serverWs.readyState !== 3) { + // Not CLOSED + serverWs.close(clientEvt.code, clientEvt.reason); + } + }); + + ws.addEventListener("error", async () => { + const serverWs = await serverWsPromise; + + logger().debug({ + msg: `test websocket connection error`, + }); + + if (serverWs.readyState !== 3) { + // Not CLOSED + serverWs.close(1011, "Error in client websocket"); + } + }); + }); + } catch (error) { + logger().error({ + msg: `failed to establish client websocket connection`, + error, + }); + return { + onOpen: (_evt, serverWs) => { + serverWs.close(1011, "Failed to establish connection"); + }, + onMessage: () => {}, + onError: () => {}, + onClose: () => {}, + }; + } + + // Create WebSocket proxy handlers to relay messages between client and server + return { + onOpen: (_evt: any, serverWs: WSContext) => { + logger().debug({ + msg: `test websocket connection from client opened`, + }); + + // Check WebSocket type + logger().debug({ + msg: "clientWs info", + constructor: clientWs.constructor.name, + hasAddEventListener: typeof clientWs.addEventListener === "function", + readyState: clientWs.readyState, + }); + + serverWsResolve(serverWs); + }, + onMessage: (evt: { data: any }) => { + logger().debug({ + msg: "received message from server", + dataType: typeof evt.data, + isBlob: evt.data instanceof Blob, + isArrayBuffer: evt.data instanceof ArrayBuffer, + dataConstructor: evt.data?.constructor?.name, + dataStr: + typeof evt.data === "string" ? evt.data.substring(0, 100) : undefined, + }); + + // Forward messages from server websocket to client websocket + if (clientWs.readyState === 1) { + // OPEN + // Handle Blob data + if (evt.data instanceof Blob) { + evt.data + .arrayBuffer() + .then((buffer) => { + logger().debug({ + msg: "converted blob to arraybuffer, sending", + bufferSize: buffer.byteLength, + }); + clientWs.send(buffer); + }) + .catch((error) => { + logger().error({ + msg: "failed to convert blob to arraybuffer", + error, + }); + }); + } else { + logger().debug({ + msg: "sending data directly", + dataType: typeof evt.data, + dataLength: + typeof evt.data === "string" ? evt.data.length : undefined, + }); + clientWs.send(evt.data); + } + } + }, + onClose: ( + event: { + wasClean: boolean; + code: number; + reason: string; + }, + serverWs: WSContext, + ) => { + logger().debug({ + msg: `server websocket closed`, + wasClean: event.wasClean, + code: event.code, + reason: event.reason, + }); + + // HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state + // https://github.com/cloudflare/workerd/issues/2569 + serverWs.close(1000, "hack_force_close"); + + // Close the client websocket when the server websocket closes + if ( + clientWs && + clientWs.readyState !== clientWs.CLOSED && + clientWs.readyState !== clientWs.CLOSING + ) { + // Don't pass code/message since this may affect how close events are triggered + clientWs.close(1000, event.reason); + } + }, + onError: (error: unknown) => { + logger().error({ + msg: `error in server websocket`, + error, + }); + + // Close the client websocket on error + if ( + clientWs && + clientWs.readyState !== clientWs.CLOSED && + clientWs.readyState !== clientWs.CLOSING + ) { + clientWs.close(1011, "Error in server websocket"); + } + + serverWsReject(); + }, + }; +} diff --git a/packages/rivetkit/src/manager/router.ts b/packages/rivetkit/src/manager/router.ts index fd6ed94a4..01f44abea 100644 --- a/packages/rivetkit/src/manager/router.ts +++ b/packages/rivetkit/src/manager/router.ts @@ -1,19 +1,20 @@ import { createRoute, OpenAPIHono } from "@hono/zod-openapi"; import * as cbor from "cbor-x"; -import { Hono } from "hono"; +import { Hono, Context as HonoContext, Next } from "hono"; import { cors as corsMiddleware } from "hono/cors"; import { createMiddleware } from "hono/factory"; -import type { WSContext } from "hono/ws"; import invariant from "invariant"; import { z } from "zod"; -import { - ActorNotFound, - MissingActorHeader, - Unsupported, - WebSocketsNotEnabled, -} from "@/actor/errors"; +import { ActorNotFound, Unsupported } from "@/actor/errors"; import { serializeActorKey } from "@/actor/keys"; import type { Encoding, Transport } from "@/client/mod"; +import { + WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, + WS_PROTOCOL_PATH, + WS_PROTOCOL_TRANSPORT, +} from "@/common/actor-router-consts"; import { handleRouteError, handleRouteNotFound, @@ -41,11 +42,11 @@ import { type Actor as ApiActor, } from "@/manager-api/actors"; import { RivetIdSchema } from "@/manager-api/common"; -import type { UniversalWebSocket, UpgradeWebSocketArgs } from "@/mod"; import type { RegistryConfig } from "@/registry/config"; import type { RunConfig } from "@/registry/run-config"; -import { promiseWithResolvers, stringifyError } from "@/utils"; +import { stringifyError } from "@/utils"; import type { ActorOutput, ManagerDriver } from "./driver"; +import { actorGateway, createTestWebSocketProxy } from "./gateway"; import { logger } from "./log"; function buildOpenApiResponses(schema: T) { @@ -82,80 +83,8 @@ export function createManagerRouter( ? corsMiddleware(runConfig.cors) : createMiddleware((_c, next) => next()); - // Actor proxy middleware - intercept requests with x-rivet-target=actor - router.use("*", cors, async (c, next) => { - const target = c.req.header("x-rivet-target"); - const actorId = c.req.header("x-rivet-actor"); - - if (target === "actor") { - if (!actorId) { - throw new MissingActorHeader(); - } - - logger().debug({ - msg: "proxying request to actor", - actorId, - path: c.req.path, - method: c.req.method, - }); - - // Handle WebSocket upgrade - if (c.req.header("upgrade") === "websocket") { - const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); - if (!upgradeWebSocket) { - throw new WebSocketsNotEnabled(); - } - - // For WebSocket, use the driver's proxyWebSocket method - // Extract any additional headers that might be needed - const encoding = - c.req.header("X-RivetKit-Encoding") || - c.req.header("x-rivet-encoding") || - "json"; - const connParams = - c.req.header("X-RivetKit-Conn-Params") || - c.req.header("x-rivet-conn-params"); - const authData = - c.req.header("X-RivetKit-Auth-Data") || - c.req.header("x-rivet-auth-data"); - - // Include query string if present - const pathWithQuery = c.req.url.includes("?") - ? c.req.path + c.req.url.substring(c.req.url.indexOf("?")) - : c.req.path; - - return await managerDriver.proxyWebSocket( - c, - pathWithQuery, - actorId, - encoding as any, // Will be validated by driver - connParams ? JSON.parse(connParams) : undefined, - authData ? JSON.parse(authData) : undefined, - ); - } - - // Handle regular HTTP requests - // Preserve all headers except the routing headers - const proxyHeaders = new Headers(c.req.raw.headers); - proxyHeaders.delete("x-rivet-target"); - proxyHeaders.delete("x-rivet-actor"); - - // Build the proxy request with the actor URL format - const url = new URL(c.req.url); - const proxyUrl = new URL(`http://actor${url.pathname}${url.search}`); - - const proxyRequest = new Request(proxyUrl, { - method: c.req.raw.method, - headers: proxyHeaders, - body: c.req.raw.body, - signal: c.req.raw.signal, - }); - - return await managerDriver.proxyRequest(c, proxyRequest, actorId); - } - - return next(); - }); + // Actor gateway + router.use("*", cors, actorGateway.bind(undefined, runConfig, managerDriver)); // GET / router.get("/", cors, (c) => { @@ -389,27 +318,45 @@ export function createManagerRouter( invariant(upgradeWebSocket, "websockets not supported on this platform"); return upgradeWebSocket(async (c: any) => { - const { - path, - actorId, - params: paramsRaw, - encodingKind, - transport, - } = c.req.query() as { - path: string; - actorId: string; - params?: string; - encodingKind: Encoding; - transport: Transport; - }; - const params = - paramsRaw !== undefined ? JSON.parse(paramsRaw) : undefined; + // Extract information from sec-websocket-protocol header + const protocolHeader = c.req.header("sec-websocket-protocol") || ""; + const protocols = protocolHeader.split(/,\s*/); + + // Parse protocols to extract connection info + let actorId = ""; + let encoding: Encoding = "bare"; + let transport: Transport = "websocket"; + let path = ""; + let params: unknown; + + for (const protocol of protocols) { + if (protocol.startsWith(WS_PROTOCOL_ACTOR)) { + actorId = protocol.substring(WS_PROTOCOL_ACTOR.length); + } else if (protocol.startsWith(WS_PROTOCOL_ENCODING)) { + encoding = protocol.substring( + WS_PROTOCOL_ENCODING.length, + ) as Encoding; + } else if (protocol.startsWith(WS_PROTOCOL_TRANSPORT)) { + transport = protocol.substring( + WS_PROTOCOL_TRANSPORT.length, + ) as Transport; + } else if (protocol.startsWith(WS_PROTOCOL_PATH)) { + path = decodeURIComponent( + protocol.substring(WS_PROTOCOL_PATH.length), + ); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) { + const paramsRaw = decodeURIComponent( + protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), + ); + params = JSON.parse(paramsRaw); + } + } logger().debug({ msg: "received test inline driver websocket", actorId, params, - encodingKind, + encodingKind: encoding, transport, path: path, }); @@ -418,7 +365,7 @@ export function createManagerRouter( const clientWsPromise = managerDriver.openWebSocket( path, actorId, - encodingKind, + encoding, params, ); @@ -512,243 +459,6 @@ export function createManagerRouter( return { router: router as Hono, openapi: router }; } -/** - * Creates a WebSocket proxy for test endpoints that forwards messages between server and client WebSockets - */ -async function createTestWebSocketProxy( - clientWsPromise: Promise, -): Promise { - // Store a reference to the resolved WebSocket - let clientWs: UniversalWebSocket | null = null; - const { - promise: serverWsPromise, - resolve: serverWsResolve, - reject: serverWsReject, - } = promiseWithResolvers(); - try { - // Resolve the client WebSocket promise - logger().debug({ msg: "awaiting client websocket promise" }); - const ws = await clientWsPromise; - clientWs = ws; - logger().debug({ - msg: "client websocket promise resolved", - constructor: ws?.constructor.name, - }); - - // Wait for ws to open - await new Promise((resolve, reject) => { - const onOpen = () => { - logger().debug({ msg: "test websocket connection to actor opened" }); - resolve(); - }; - const onError = (error: any) => { - logger().error({ msg: "test websocket connection failed", error }); - reject( - new Error(`Failed to open WebSocket: ${error.message || error}`), - ); - serverWsReject(); - }; - - ws.addEventListener("open", onOpen); - - ws.addEventListener("error", onError); - - ws.addEventListener("message", async (clientEvt: MessageEvent) => { - const serverWs = await serverWsPromise; - - logger().debug({ - msg: `test websocket connection message from client`, - dataType: typeof clientEvt.data, - isBlob: clientEvt.data instanceof Blob, - isArrayBuffer: clientEvt.data instanceof ArrayBuffer, - dataConstructor: clientEvt.data?.constructor?.name, - dataStr: - typeof clientEvt.data === "string" - ? clientEvt.data.substring(0, 100) - : undefined, - }); - - if (serverWs.readyState === 1) { - // OPEN - // Handle Blob data - if (clientEvt.data instanceof Blob) { - clientEvt.data - .arrayBuffer() - .then((buffer) => { - logger().debug({ - msg: "converted client blob to arraybuffer, sending to server", - bufferSize: buffer.byteLength, - }); - serverWs.send(buffer as any); - }) - .catch((error) => { - logger().error({ - msg: "failed to convert blob to arraybuffer", - error, - }); - }); - } else { - logger().debug({ - msg: "sending client data directly to server", - dataType: typeof clientEvt.data, - dataLength: - typeof clientEvt.data === "string" - ? clientEvt.data.length - : undefined, - }); - serverWs.send(clientEvt.data as any); - } - } - }); - - ws.addEventListener("close", async (clientEvt: any) => { - const serverWs = await serverWsPromise; - - logger().debug({ - msg: `test websocket connection closed`, - }); - - if (serverWs.readyState !== 3) { - // Not CLOSED - serverWs.close(clientEvt.code, clientEvt.reason); - } - }); - - ws.addEventListener("error", async () => { - const serverWs = await serverWsPromise; - - logger().debug({ - msg: `test websocket connection error`, - }); - - if (serverWs.readyState !== 3) { - // Not CLOSED - serverWs.close(1011, "Error in client websocket"); - } - }); - }); - } catch (error) { - logger().error({ - msg: `failed to establish client websocket connection`, - error, - }); - return { - onOpen: (_evt, serverWs) => { - serverWs.close(1011, "Failed to establish connection"); - }, - onMessage: () => {}, - onError: () => {}, - onClose: () => {}, - }; - } - - // Create WebSocket proxy handlers to relay messages between client and server - return { - onOpen: (_evt: any, serverWs: WSContext) => { - logger().debug({ - msg: `test websocket connection from client opened`, - }); - - // Check WebSocket type - logger().debug({ - msg: "clientWs info", - constructor: clientWs.constructor.name, - hasAddEventListener: typeof clientWs.addEventListener === "function", - readyState: clientWs.readyState, - }); - - serverWsResolve(serverWs); - }, - onMessage: (evt: { data: any }) => { - logger().debug({ - msg: "received message from server", - dataType: typeof evt.data, - isBlob: evt.data instanceof Blob, - isArrayBuffer: evt.data instanceof ArrayBuffer, - dataConstructor: evt.data?.constructor?.name, - dataStr: - typeof evt.data === "string" ? evt.data.substring(0, 100) : undefined, - }); - - // Forward messages from server websocket to client websocket - if (clientWs.readyState === 1) { - // OPEN - // Handle Blob data - if (evt.data instanceof Blob) { - evt.data - .arrayBuffer() - .then((buffer) => { - logger().debug({ - msg: "converted blob to arraybuffer, sending", - bufferSize: buffer.byteLength, - }); - clientWs.send(buffer); - }) - .catch((error) => { - logger().error({ - msg: "failed to convert blob to arraybuffer", - error, - }); - }); - } else { - logger().debug({ - msg: "sending data directly", - dataType: typeof evt.data, - dataLength: - typeof evt.data === "string" ? evt.data.length : undefined, - }); - clientWs.send(evt.data); - } - } - }, - onClose: ( - event: { - wasClean: boolean; - code: number; - reason: string; - }, - serverWs: WSContext, - ) => { - logger().debug({ - msg: `server websocket closed`, - wasClean: event.wasClean, - code: event.code, - reason: event.reason, - }); - - // HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state - // https://github.com/cloudflare/workerd/issues/2569 - serverWs.close(1000, "hack_force_close"); - - // Close the client websocket when the server websocket closes - if ( - clientWs && - clientWs.readyState !== clientWs.CLOSED && - clientWs.readyState !== clientWs.CLOSING - ) { - // Don't pass code/message since this may affect how close events are triggered - clientWs.close(1000, event.reason); - } - }, - onError: (error: unknown) => { - logger().error({ - msg: `error in server websocket`, - error, - }); - - // Close the client websocket on error - if ( - clientWs && - clientWs.readyState !== clientWs.CLOSED && - clientWs.readyState !== clientWs.CLOSING - ) { - clientWs.close(1011, "Error in server websocket"); - } - - serverWsReject(); - }, - }; -} function createApiActor(actor: ActorOutput): ApiActor { return { diff --git a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts index ebaa344ff..7b14c3b20 100644 --- a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts +++ b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts @@ -1,8 +1,12 @@ import type { ClientConfig } from "@/client/config"; import { - HEADER_AUTH_DATA, HEADER_CONN_PARAMS, HEADER_ENCODING, + WS_PROTOCOL_ACTOR, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, + WS_PROTOCOL_STANDARD as WS_PROTOCOL_RIVETKIT, + WS_PROTOCOL_TARGET, } from "@/common/actor-router-consts"; import { importWebSocket } from "@/common/websocket"; import type { Encoding, UniversalWebSocket } from "@/mod"; @@ -31,9 +35,10 @@ export async function openWebSocketToActor( }); // Create WebSocket connection - const ws = new WebSocket(guardUrl, { - headers: buildGuardHeadersForWebSocket(actorId, encoding, params), - }); + const ws = new WebSocket( + guardUrl, + buildWebSocketProtocols(actorId, encoding, params), + ); // Set binary type to arraybuffer for proper encoding support ws.binaryType = "arraybuffer"; @@ -43,22 +48,20 @@ export async function openWebSocketToActor( return ws as UniversalWebSocket; } -export function buildGuardHeadersForWebSocket( +export function buildWebSocketProtocols( actorId: string, encoding: Encoding, params?: unknown, - authData?: unknown, -): Record { - const headers: Record = {}; - headers["x-rivet-target"] = "actor"; - headers["x-rivet-actor"] = actorId; - headers["x-rivet-port"] = "main"; - headers[HEADER_ENCODING] = encoding; +): string[] { + const protocols: string[] = []; + protocols.push(WS_PROTOCOL_RIVETKIT); + protocols.push(`${WS_PROTOCOL_TARGET}actor`); + protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`); + protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); if (params) { - headers[HEADER_CONN_PARAMS] = JSON.stringify(params); - } - if (authData) { - headers[HEADER_AUTH_DATA] = JSON.stringify(authData); + protocols.push( + `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, + ); } - return headers; + return protocols; } diff --git a/packages/rivetkit/src/remote-manager-driver/mod.ts b/packages/rivetkit/src/remote-manager-driver/mod.ts index 68cf0692a..397c9bf00 100644 --- a/packages/rivetkit/src/remote-manager-driver/mod.ts +++ b/packages/rivetkit/src/remote-manager-driver/mod.ts @@ -17,7 +17,7 @@ import type { Encoding, UniversalWebSocket } from "@/mod"; import { combineUrlPath } from "@/utils"; import { sendHttpRequestToActor } from "./actor-http-client"; import { - buildGuardHeadersForWebSocket, + buildWebSocketProtocols, openWebSocketToActor, } from "./actor-websocket-client"; import { @@ -227,7 +227,6 @@ export class RemoteManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, - authData: unknown, ): Promise { const upgradeWebSocket = this.#config.getUpgradeWebSocket?.(); invariant(upgradeWebSocket, "missing getUpgradeWebSocket"); @@ -243,14 +242,9 @@ export class RemoteManagerDriver implements ManagerDriver { guardUrl, }); - // Build headers - const headers = buildGuardHeadersForWebSocket( - actorId, - encoding, - params, - authData, - ); - const args = await createWebSocketProxy(c, wsGuardUrl, headers); + // Build protocols + const protocols = buildWebSocketProtocols(actorId, encoding, params); + const args = await createWebSocketProxy(c, wsGuardUrl, protocols); return await upgradeWebSocket(() => args)(c, noopNext()); } diff --git a/packages/rivetkit/src/remote-manager-driver/ws-proxy.ts b/packages/rivetkit/src/remote-manager-driver/ws-proxy.ts index c631e1d5a..dce82cd74 100644 --- a/packages/rivetkit/src/remote-manager-driver/ws-proxy.ts +++ b/packages/rivetkit/src/remote-manager-driver/ws-proxy.ts @@ -11,17 +11,10 @@ import { logger } from "./log"; export async function createWebSocketProxy( c: HonoContext, targetUrl: string, - headers: Record, + protocols: string[], ): Promise { const WebSocket = await importWebSocket(); - // HACK: Sanitize WebSocket-specific headers. If we don't do this, some WebSocket implementations (i.e. native WebSocket in Node.js) will fail to connect. - for (const [k, v] of c.req.raw.headers.entries()) { - if (!k.startsWith("sec-") && k !== "connection" && k !== "upgrade") { - headers[k] = v; - } - } - // WebSocket state interface WsState { targetWs?: WebSocket; @@ -43,7 +36,7 @@ export async function createWebSocketProxy( } // Create WebSocket - const targetWs = new WebSocket(targetUrl, { headers }); + const targetWs = new WebSocket(targetUrl, protocols); state.targetWs = targetWs; // Setup connection promise