Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/rivetkit/src/client/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ export const ClientConfigSchema = z.object({
})
.default({}),

token: z
.string()
.nullable()
.default(() => getEnvUniversal("RIVET_TOKEN") ?? null),

headers: z.record(z.string()).optional().default({}),

/** Endpoint to connect to the Rivet engine. Can be configured via RIVET_ENGINE env var. */
endpoint: z
.string()
Expand Down
4 changes: 4 additions & 0 deletions packages/rivetkit/src/common/actor-router-consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ export const HEADER_CONN_ID = "x-rivet-conn";

export const HEADER_CONN_TOKEN = "x-rivet-conn-token";

export const HEADER_RIVET_TOKEN = "x-rivet-token";

// MARK: Manager Gateway Headers
export const HEADER_RIVET_TARGET = "x-rivet-target";
export const HEADER_RIVET_ACTOR = "x-rivet-actor";
Expand All @@ -29,6 +31,7 @@ export const WS_PROTOCOL_TARGET = "rivet_target.";
export const WS_PROTOCOL_ACTOR = "rivet_actor.";
export const WS_PROTOCOL_ENCODING = "rivet_encoding.";
export const WS_PROTOCOL_CONN_PARAMS = "rivet_conn_params.";
export const WS_PROTOCOL_TOKEN = "rivet_token.";

// MARK: WebSocket Inline Test Protocol Prefixes
export const WS_PROTOCOL_TRANSPORT = "test_transport.";
Expand All @@ -50,4 +53,5 @@ export const ALLOWED_PUBLIC_HEADERS = [
HEADER_CONN_TOKEN,
HEADER_RIVET_TARGET,
HEADER_RIVET_ACTOR,
HEADER_RIVET_TOKEN,
];
24 changes: 20 additions & 4 deletions packages/rivetkit/src/remote-manager-driver/actor-http-client.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import type { ClientConfig } from "@/client/config";
import {
HEADER_RIVET_ACTOR,
HEADER_RIVET_TARGET,
HEADER_RIVET_TOKEN,
} from "@/common/actor-router-consts";
import { combineUrlPath } from "@/utils";
import { getEndpoint } from "./api-utils";

Expand All @@ -14,7 +19,11 @@ export async function sendHttpRequestToActor(

// Handle body properly based on method and presence
let bodyToSend: ArrayBuffer | null = null;
const guardHeaders = buildGuardHeadersForHttp(actorRequest, actorId);
const guardHeaders = buildGuardHeadersForHttp(
runConfig,
actorRequest,
actorId,
);

if (actorRequest.method !== "GET" && actorRequest.method !== "HEAD") {
if (actorRequest.bodyUsed) {
Expand Down Expand Up @@ -53,6 +62,7 @@ function mutableResponse(fetchRes: Response): Response {
}

function buildGuardHeadersForHttp(
runConfig: ClientConfig,
actorRequest: Request,
actorId: string,
): Headers {
Expand All @@ -61,9 +71,15 @@ function buildGuardHeadersForHttp(
for (const [key, value] of actorRequest.headers.entries()) {
headers.set(key, value);
}
// Add extra headers from config
for (const [key, value] of Object.entries(runConfig.headers)) {
headers.set(key, value);
}
// Add guard-specific headers
headers.set("x-rivet-target", "actor");
headers.set("x-rivet-actor", actorId);
headers.set("x-rivet-port", "main");
headers.set(HEADER_RIVET_TARGET, "actor");
headers.set(HEADER_RIVET_ACTOR, actorId);
if (runConfig.token) {
headers.set(HEADER_RIVET_TOKEN, runConfig.token);
}
return headers;
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
WS_PROTOCOL_ENCODING,
WS_PROTOCOL_STANDARD as WS_PROTOCOL_RIVETKIT,
WS_PROTOCOL_TARGET,
WS_PROTOCOL_TOKEN,
} from "@/common/actor-router-consts";
import { importWebSocket } from "@/common/websocket";
import type { Encoding, UniversalWebSocket } from "@/mod";
Expand Down Expand Up @@ -37,7 +38,7 @@ export async function openWebSocketToActor(
// Create WebSocket connection
const ws = new WebSocket(
guardUrl,
buildWebSocketProtocols(actorId, encoding, params),
buildWebSocketProtocols(runConfig, actorId, encoding, params),
);

// Set binary type to arraybuffer for proper encoding support
Expand All @@ -49,6 +50,7 @@ export async function openWebSocketToActor(
}

export function buildWebSocketProtocols(
runConfig: ClientConfig,
actorId: string,
encoding: Encoding,
params?: unknown,
Expand All @@ -58,6 +60,9 @@ export function buildWebSocketProtocols(
protocols.push(`${WS_PROTOCOL_TARGET}actor`);
protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`);
protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`);
if (runConfig.token) {
protocols.push(`${WS_PROTOCOL_TOKEN}${runConfig.token}`);
}
if (params) {
protocols.push(
`${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`,
Expand Down
11 changes: 10 additions & 1 deletion packages/rivetkit/src/remote-manager-driver/api-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,19 @@ export async function apiCall<TInput = unknown, TOutput = unknown>(

logger().debug({ msg: "making api call", method, url });

const headers: Record<string, string> = {
...config.headers,
};

// Add Authorization header if token is provided
if (config.token) {
headers.Authorization = `Bearer ${config.token}`;
}

return await sendHttpRequest<TInput, TOutput>({
method,
url,
headers: {},
headers,
body,
encoding: "json",
skipParseResponse: false,
Expand Down
7 changes: 6 additions & 1 deletion packages/rivetkit/src/remote-manager-driver/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,12 @@ export class RemoteManagerDriver implements ManagerDriver {
});

// Build protocols
const protocols = buildWebSocketProtocols(actorId, encoding, params);
const protocols = buildWebSocketProtocols(
this.#config,
actorId,
encoding,
params,
);
const args = await createWebSocketProxy(c, wsGuardUrl, protocols);

return await upgradeWebSocket(() => args)(c, noopNext());
Expand Down
Loading