diff --git a/packages/rivetkit/src/client/config.ts b/packages/rivetkit/src/client/config.ts index bc4c15b9d..7e9240438 100644 --- a/packages/rivetkit/src/client/config.ts +++ b/packages/rivetkit/src/client/config.ts @@ -14,6 +14,13 @@ export const ClientConfigSchema = z.object({ }) .default({}), + token: z + .string() + .nullable() + .default(() => getEnvUniversal("RIVET_TOKEN") ?? null), + + headers: z.record(z.string()).optional().default({}), + /** Endpoint to connect to the Rivet engine. Can be configured via RIVET_ENGINE env var. */ endpoint: z .string() diff --git a/packages/rivetkit/src/common/actor-router-consts.ts b/packages/rivetkit/src/common/actor-router-consts.ts index 294389f6d..0abd47204 100644 --- a/packages/rivetkit/src/common/actor-router-consts.ts +++ b/packages/rivetkit/src/common/actor-router-consts.ts @@ -18,6 +18,8 @@ export const HEADER_CONN_ID = "x-rivet-conn"; export const HEADER_CONN_TOKEN = "x-rivet-conn-token"; +export const HEADER_RIVET_TOKEN = "x-rivet-token"; + // MARK: Manager Gateway Headers export const HEADER_RIVET_TARGET = "x-rivet-target"; export const HEADER_RIVET_ACTOR = "x-rivet-actor"; @@ -29,6 +31,7 @@ export const WS_PROTOCOL_TARGET = "rivet_target."; export const WS_PROTOCOL_ACTOR = "rivet_actor."; export const WS_PROTOCOL_ENCODING = "rivet_encoding."; export const WS_PROTOCOL_CONN_PARAMS = "rivet_conn_params."; +export const WS_PROTOCOL_TOKEN = "rivet_token."; // MARK: WebSocket Inline Test Protocol Prefixes export const WS_PROTOCOL_TRANSPORT = "test_transport."; @@ -50,4 +53,5 @@ export const ALLOWED_PUBLIC_HEADERS = [ HEADER_CONN_TOKEN, HEADER_RIVET_TARGET, HEADER_RIVET_ACTOR, + HEADER_RIVET_TOKEN, ]; diff --git a/packages/rivetkit/src/remote-manager-driver/actor-http-client.ts b/packages/rivetkit/src/remote-manager-driver/actor-http-client.ts index 1e5e7eee1..a74d1300d 100644 --- a/packages/rivetkit/src/remote-manager-driver/actor-http-client.ts +++ b/packages/rivetkit/src/remote-manager-driver/actor-http-client.ts @@ -1,4 +1,9 @@ import type { ClientConfig } from "@/client/config"; +import { + HEADER_RIVET_ACTOR, + HEADER_RIVET_TARGET, + HEADER_RIVET_TOKEN, +} from "@/common/actor-router-consts"; import { combineUrlPath } from "@/utils"; import { getEndpoint } from "./api-utils"; @@ -14,7 +19,11 @@ export async function sendHttpRequestToActor( // Handle body properly based on method and presence let bodyToSend: ArrayBuffer | null = null; - const guardHeaders = buildGuardHeadersForHttp(actorRequest, actorId); + const guardHeaders = buildGuardHeadersForHttp( + runConfig, + actorRequest, + actorId, + ); if (actorRequest.method !== "GET" && actorRequest.method !== "HEAD") { if (actorRequest.bodyUsed) { @@ -53,6 +62,7 @@ function mutableResponse(fetchRes: Response): Response { } function buildGuardHeadersForHttp( + runConfig: ClientConfig, actorRequest: Request, actorId: string, ): Headers { @@ -61,9 +71,15 @@ function buildGuardHeadersForHttp( for (const [key, value] of actorRequest.headers.entries()) { headers.set(key, value); } + // Add extra headers from config + for (const [key, value] of Object.entries(runConfig.headers)) { + headers.set(key, value); + } // Add guard-specific headers - headers.set("x-rivet-target", "actor"); - headers.set("x-rivet-actor", actorId); - headers.set("x-rivet-port", "main"); + headers.set(HEADER_RIVET_TARGET, "actor"); + headers.set(HEADER_RIVET_ACTOR, actorId); + if (runConfig.token) { + headers.set(HEADER_RIVET_TOKEN, runConfig.token); + } return headers; } diff --git a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts index 7b14c3b20..023db8930 100644 --- a/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts +++ b/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts @@ -7,6 +7,7 @@ import { WS_PROTOCOL_ENCODING, WS_PROTOCOL_STANDARD as WS_PROTOCOL_RIVETKIT, WS_PROTOCOL_TARGET, + WS_PROTOCOL_TOKEN, } from "@/common/actor-router-consts"; import { importWebSocket } from "@/common/websocket"; import type { Encoding, UniversalWebSocket } from "@/mod"; @@ -37,7 +38,7 @@ export async function openWebSocketToActor( // Create WebSocket connection const ws = new WebSocket( guardUrl, - buildWebSocketProtocols(actorId, encoding, params), + buildWebSocketProtocols(runConfig, actorId, encoding, params), ); // Set binary type to arraybuffer for proper encoding support @@ -49,6 +50,7 @@ export async function openWebSocketToActor( } export function buildWebSocketProtocols( + runConfig: ClientConfig, actorId: string, encoding: Encoding, params?: unknown, @@ -58,6 +60,9 @@ export function buildWebSocketProtocols( protocols.push(`${WS_PROTOCOL_TARGET}actor`); protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`); protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`); + if (runConfig.token) { + protocols.push(`${WS_PROTOCOL_TOKEN}${runConfig.token}`); + } if (params) { protocols.push( `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`, diff --git a/packages/rivetkit/src/remote-manager-driver/api-utils.ts b/packages/rivetkit/src/remote-manager-driver/api-utils.ts index 243c2fad4..196627050 100644 --- a/packages/rivetkit/src/remote-manager-driver/api-utils.ts +++ b/packages/rivetkit/src/remote-manager-driver/api-utils.ts @@ -33,10 +33,19 @@ export async function apiCall( logger().debug({ msg: "making api call", method, url }); + const headers: Record = { + ...config.headers, + }; + + // Add Authorization header if token is provided + if (config.token) { + headers.Authorization = `Bearer ${config.token}`; + } + return await sendHttpRequest({ method, url, - headers: {}, + headers, body, encoding: "json", skipParseResponse: false, diff --git a/packages/rivetkit/src/remote-manager-driver/mod.ts b/packages/rivetkit/src/remote-manager-driver/mod.ts index cc0fd382f..ac63db79c 100644 --- a/packages/rivetkit/src/remote-manager-driver/mod.ts +++ b/packages/rivetkit/src/remote-manager-driver/mod.ts @@ -246,7 +246,12 @@ export class RemoteManagerDriver implements ManagerDriver { }); // Build protocols - const protocols = buildWebSocketProtocols(actorId, encoding, params); + const protocols = buildWebSocketProtocols( + this.#config, + actorId, + encoding, + params, + ); const args = await createWebSocketProxy(c, wsGuardUrl, protocols); return await upgradeWebSocket(() => args)(c, noopNext());