Skip to content
This repository was archived by the owner on Oct 22, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions packages/cloudflare-workers/src/manager-driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import {
type GetForIdInput,
type GetOrCreateWithKeyInput,
type GetWithKeyInput,
HEADER_AUTH_DATA,
HEADER_CONN_PARAMS,
HEADER_ENCODING,
type ManagerDisplayInformation,
type ManagerDriver,
WS_PROTOCOL_ACTOR,
WS_PROTOCOL_CONN_PARAMS,
WS_PROTOCOL_ENCODING,
WS_PROTOCOL_STANDARD,
WS_PROTOCOL_TARGET,
} from "rivetkit/driver-helpers";
import { ActorAlreadyExists, InternalError } from "rivetkit/errors";
import { getCloudflareAmbientEnv } from "./handler";
Expand Down Expand Up @@ -81,16 +83,22 @@ export class CloudflareActorsManagerDriver implements ManagerDriver {
const id = env.ACTOR_DO.idFromString(actorId);
const stub = env.ACTOR_DO.get(id);

const protocols: string[] = [];
protocols.push(WS_PROTOCOL_STANDARD);
protocols.push(`${WS_PROTOCOL_TARGET}actor`);
protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`);
protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`);
if (params) {
protocols.push(
`${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`,
);
}

const headers: Record<string, string> = {
Upgrade: "websocket",
Connection: "Upgrade",
[HEADER_ENCODING]: encoding,
"sec-websocket-protocol": protocols.join(", "),
};
if (params) {
headers[HEADER_CONN_PARAMS] = JSON.stringify(params);
}
// HACK: See packages/drivers/cloudflare-workers/src/websocket.ts
headers["sec-websocket-protocol"] = "rivetkit";

// Use the path parameter to determine the URL
const normalizedPath = path.startsWith("/") ? path : `/${path}`;
Expand Down Expand Up @@ -152,7 +160,6 @@ export class CloudflareActorsManagerDriver implements ManagerDriver {
actorId: string,
encoding: Encoding,
params: unknown,
authData: unknown,
): Promise<Response> {
logger().debug({
msg: "forwarding websocket to durable object",
Expand Down Expand Up @@ -188,14 +195,18 @@ export class CloudflareActorsManagerDriver implements ManagerDriver {
}
}

// Add RivetKit headers
actorRequest.headers.set(HEADER_ENCODING, encoding);
// Build protocols for WebSocket connection
const protocols: string[] = [];
protocols.push(WS_PROTOCOL_STANDARD);
protocols.push(`${WS_PROTOCOL_TARGET}actor`);
protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`);
protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`);
if (params) {
actorRequest.headers.set(HEADER_CONN_PARAMS, JSON.stringify(params));
}
if (authData) {
actorRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData));
protocols.push(
`${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`,
);
}
actorRequest.headers.set("sec-websocket-protocol", protocols.join(", "));

const id = c.env.ACTOR_DO.idFromString(actorId);
const stub = c.env.ACTOR_DO.get(id);
Expand Down
8 changes: 6 additions & 2 deletions packages/cloudflare-workers/src/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import type { UpgradeWebSocket, WSEvents, WSReadyState } from "hono/ws";
import { defineWebSocketHelper, WSContext } from "hono/ws";
import { WS_PROTOCOL_STANDARD } from "rivetkit/driver-helpers";

// Based on https://github.com/honojs/hono/issues/1153#issuecomment-1767321332
export const upgradeWebSocket: UpgradeWebSocket<
Expand Down Expand Up @@ -62,8 +63,11 @@ export const upgradeWebSocket: UpgradeWebSocket<

// Set Sec-WebSocket-Protocol if does not exist
const protocols = c.req.header("Sec-WebSocket-Protocol");
if (typeof protocols === "string" && protocols.includes("rivetkit")) {
headers["Sec-WebSocket-Protocol"] = "rivetkit";
if (
typeof protocols === "string" &&
protocols.includes(WS_PROTOCOL_STANDARD)
) {
headers["Sec-WebSocket-Protocol"] = WS_PROTOCOL_STANDARD;
}

return new Response(null, {
Expand Down
10 changes: 0 additions & 10 deletions packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,6 @@ export const rawWebSocketActor = actor({
messageCount: ctx.state.messageCount,
}),
);
} else if (parsed.type === "getAuthData") {
// Auth data is not directly available in raw WebSocket handler
// Send a message indicating this limitation
websocket.send(
JSON.stringify({
type: "authData",
authData: null,
message: "Auth data not available in raw WebSocket handler",
}),
);
} else if (parsed.type === "getRequestInfo") {
// Send back the request URL info
websocket.send(
Expand Down
2 changes: 0 additions & 2 deletions packages/rivetkit/src/actor/instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,6 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
state: CS,
driverId: ConnectionDriver,
driverState: unknown,
authData: unknown,
): Promise<Conn<S, CP, CS, V, I, DB>> {
this.#assertReady();

Expand All @@ -935,7 +934,6 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
connDriverState: driverState,
params: params,
state: state,
authData: authData,
lastSeen: Date.now(),
subscriptions: [],
};
Expand Down
1 change: 0 additions & 1 deletion packages/rivetkit/src/actor/persisted.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ export interface PersistedConn<CP, CS> {
connDriverState: unknown;
params: CP;
state: CS;
authData?: unknown;
subscriptions: PersistedSubscription[];
lastSeen: number;
}
Expand Down
12 changes: 0 additions & 12 deletions packages/rivetkit/src/actor/router-endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ export interface ConnectWebSocketOpts {
encoding: Encoding;
actorId: string;
params: unknown;
authData: unknown;
}

export interface ConnectWebSocketOutput {
Expand All @@ -67,7 +66,6 @@ export interface ConnectSseOpts {
encoding: Encoding;
params: unknown;
actorId: string;
authData: unknown;
}

export interface ConnectSseOutput {
Expand All @@ -81,7 +79,6 @@ export interface ActionOpts {
actionName: string;
actionArgs: unknown[];
actorId: string;
authData: unknown;
}

export interface ActionOutput {
Expand All @@ -99,14 +96,12 @@ export interface ConnsMessageOpts {
export interface FetchOpts {
request: Request;
actorId: string;
authData: unknown;
}

export interface WebSocketOpts {
request: Request;
websocket: UniversalWebSocket;
actorId: string;
authData: unknown;
}

/**
Expand All @@ -119,7 +114,6 @@ export async function handleWebSocketConnect(
actorId: string,
encoding: Encoding,
parameters: unknown,
authData: unknown,
): Promise<UpgradeWebSocketArgs> {
const exposeInternalError = req ? getRequestExposeInternalError(req) : false;

Expand Down Expand Up @@ -189,7 +183,6 @@ export async function handleWebSocketConnect(
connState,
CONNECTION_DRIVER_WEBSOCKET,
{ encoding } satisfies GenericWebSocketDriverState,
authData,
);

// Unblock other handlers
Expand Down Expand Up @@ -339,7 +332,6 @@ export async function handleSseConnect(
_runConfig: RunConfig,
actorDriver: ActorDriver,
actorId: string,
authData: unknown,
) {
c.header("Content-Encoding", "Identity");

Expand Down Expand Up @@ -376,7 +368,6 @@ export async function handleSseConnect(
connState,
CONNECTION_DRIVER_SSE,
{ encoding } satisfies GenericSseDriverState,
authData,
);

// Wait for close
Expand Down Expand Up @@ -459,7 +450,6 @@ export async function handleAction(
actorDriver: ActorDriver,
actionName: string,
actorId: string,
authData: unknown,
) {
const encoding = getRequestEncoding(c.req);
const parameters = getRequestConnParams(c.req);
Expand Down Expand Up @@ -491,7 +481,6 @@ export async function handleAction(
connState,
CONNECTION_DRIVER_HTTP,
{} satisfies GenericHttpDriverState,
authData,
);

// Call action
Expand Down Expand Up @@ -562,7 +551,6 @@ export async function handleRawWebSocketHandler(
path: string,
actorDriver: ActorDriver,
actorId: string,
authData: unknown,
): Promise<UpgradeWebSocketArgs> {
const actor = await actorDriver.loadActor(actorId);

Expand Down
40 changes: 3 additions & 37 deletions packages/rivetkit/src/actor/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import {
handleWebSocketConnect,
} from "@/actor/router-endpoints";
import {
HEADER_AUTH_DATA,
HEADER_CONN_ID,
HEADER_CONN_PARAMS,
HEADER_CONN_TOKEN,
Expand Down Expand Up @@ -84,13 +83,11 @@ export function createActorRouter(
return upgradeWebSocket(async (c) => {
const encodingRaw = c.req.header(HEADER_ENCODING);
const connParamsRaw = c.req.header(HEADER_CONN_PARAMS);
const authDataRaw = c.req.header(HEADER_AUTH_DATA);

const encoding = EncodingSchema.parse(encodingRaw);
const connParams = connParamsRaw
? JSON.parse(connParamsRaw)
: undefined;
const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined;

return await handleWebSocketConnect(
c.req.raw,
Expand All @@ -99,7 +96,6 @@ export function createActorRouter(
c.env.actorId,
encoding,
connParams,
authData,
);
})(c, noopNext());
} else {
Expand All @@ -111,32 +107,13 @@ export function createActorRouter(
});

router.get("/connect/sse", async (c) => {
const authDataRaw = c.req.header(HEADER_AUTH_DATA);
let authData: unknown;
if (authDataRaw) {
authData = JSON.parse(authDataRaw);
}

return handleSseConnect(c, runConfig, actorDriver, c.env.actorId, authData);
return handleSseConnect(c, runConfig, actorDriver, c.env.actorId);
});

router.post("/action/:action", async (c) => {
const actionName = c.req.param("action");

const authDataRaw = c.req.header(HEADER_AUTH_DATA);
let authData: unknown;
if (authDataRaw) {
authData = JSON.parse(authDataRaw);
}

return handleAction(
c,
runConfig,
actorDriver,
actionName,
c.env.actorId,
authData,
);
return handleAction(c, runConfig, actorDriver, actionName, c.env.actorId);
});

router.post("/connections/message", async (c) => {
Expand All @@ -157,12 +134,6 @@ export function createActorRouter(

// Raw HTTP endpoints - /http/*
router.all("/raw/http/*", async (c) => {
const authDataRaw = c.req.header(HEADER_AUTH_DATA);
let authData: unknown;
if (authDataRaw) {
authData = JSON.parse(authDataRaw);
}

const actor = await actorDriver.loadActor(c.env.actorId);

// TODO: This is not a clean way of doing this since `/http/` might exist mid-path
Expand All @@ -186,9 +157,7 @@ export function createActorRouter(
});

// Call the actor's onFetch handler - it will throw appropriate errors
const response = await actor.handleFetch(correctedRequest, {
auth: authData,
});
const response = await actor.handleFetch(correctedRequest, {});

// This should never happen now since handleFetch throws errors
if (!response) {
Expand All @@ -205,13 +174,11 @@ export function createActorRouter(
return upgradeWebSocket(async (c) => {
const encodingRaw = c.req.header(HEADER_ENCODING);
const connParamsRaw = c.req.header(HEADER_CONN_PARAMS);
const authDataRaw = c.req.header(HEADER_AUTH_DATA);

const encoding = EncodingSchema.parse(encodingRaw);
const connParams = connParamsRaw
? JSON.parse(connParamsRaw)
: undefined;
const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined;

const url = new URL(c.req.url);
const pathWithQuery = c.req.path + url.search;
Expand All @@ -229,7 +196,6 @@ export function createActorRouter(
pathWithQuery,
actorDriver,
c.env.actorId,
authData,
);
})(c, noopNext());
} else {
Expand Down
31 changes: 23 additions & 8 deletions packages/rivetkit/src/common/actor-router-consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,34 @@ export const PATH_CONNECT_WEBSOCKET = "/connect/websocket";
export const PATH_RAW_WEBSOCKET_PREFIX = "/raw/websocket/";

// MARK: Headers
export const HEADER_ACTOR_QUERY = "X-RivetKit-Query";
export const HEADER_ACTOR_QUERY = "x-rivet-query";

export const HEADER_ENCODING = "X-RivetKit-Encoding";
export const HEADER_ENCODING = "x-rivet-encoding";

// IMPORTANT: Params must be in headers or in an E2EE part of the request (i.e. NOT the URL or query string) in order to ensure that tokens can be securely passed in params.
export const HEADER_CONN_PARAMS = "X-RivetKit-Conn-Params";
export const HEADER_CONN_PARAMS = "x-rivet-conn-params";

// Internal header
export const HEADER_AUTH_DATA = "X-RivetKit-Auth-Data";
export const HEADER_ACTOR_ID = "x-rivet-actor";

export const HEADER_ACTOR_ID = "X-RivetKit-Actor";
export const HEADER_CONN_ID = "x-rivet-conn";

export const HEADER_CONN_ID = "X-RivetKit-Conn";
export const HEADER_CONN_TOKEN = "x-rivet-conn-token";

export const HEADER_CONN_TOKEN = "X-RivetKit-Conn-Token";
// MARK: Manager Gateway Headers
export const HEADER_RIVET_TARGET = "x-rivet-target";
export const HEADER_RIVET_ACTOR = "x-rivet-actor";

// MARK: WebSocket Protocol Prefixes
/** Some servers (such as node-ws & Cloudflare) require explicitly match a certain WebSocket protocol. This gives us a static protocol to match against. */
export const WS_PROTOCOL_STANDARD = "rivet";
export const WS_PROTOCOL_TARGET = "rivet_target.";
export const WS_PROTOCOL_ACTOR = "rivet_actor.";
export const WS_PROTOCOL_ENCODING = "rivet_encoding.";
export const WS_PROTOCOL_CONN_PARAMS = "rivet_conn_params.";

// MARK: WebSocket Inline Test Protocol Prefixes
export const WS_PROTOCOL_TRANSPORT = "test_transport.";
export const WS_PROTOCOL_PATH = "test_path.";

/**
* Headers that publics can send from public clients.
Expand All @@ -35,4 +48,6 @@ export const ALLOWED_PUBLIC_HEADERS = [
HEADER_ACTOR_ID,
HEADER_CONN_ID,
HEADER_CONN_TOKEN,
HEADER_RIVET_TARGET,
HEADER_RIVET_ACTOR,
];
Loading
Loading