Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,4 @@
- Client & inline client
- ManagerDriver
- ActorDriver
- GenericConnGlobalState & other generic drivers: tracks actual connections separately from the actual conn state
- TODO: Can we remove the "generic" prefix?

19 changes: 1 addition & 18 deletions packages/cloudflare-workers/src/actor-driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@ import type {
RegistryConfig,
RunConfig,
} from "rivetkit";
import {
createGenericConnDrivers,
GenericConnGlobalState,
lookupInRegistry,
} from "rivetkit";
import { lookupInRegistry } from "rivetkit";
import type { Client } from "rivetkit/client";
import type {
ActorDriver,
Expand Down Expand Up @@ -52,7 +48,6 @@ class ActorHandler {
actor?: AnyActorInstance;
actorPromise?: ReturnType<typeof promiseWithResolvers<void>> =
promiseWithResolvers();
genericConnGlobalState = new GenericConnGlobalState();
}

export class CloudflareActorsActorDriver implements ActorDriver {
Expand Down Expand Up @@ -116,11 +111,7 @@ export class CloudflareActorsActorDriver implements ActorDriver {
handler.actor = definition.instantiate();

// Start actor
const connDrivers = createGenericConnDrivers(
handler.genericConnGlobalState,
);
await handler.actor.start(
connDrivers,
this,
this.#inlineClient,
actorId,
Expand All @@ -136,14 +127,6 @@ export class CloudflareActorsActorDriver implements ActorDriver {
return handler.actor;
}

getGenericConnGlobalState(actorId: string): GenericConnGlobalState {
const handler = this.#actors.get(actorId);
if (!handler) {
throw new Error(`Actor ${actorId} not loaded`);
}
return handler.genericConnGlobalState;
}

getContext(actorId: string): DriverContext {
const state = this.#globalState.getDOState(actorId);
return { state: state.ctx };
Expand Down
49 changes: 0 additions & 49 deletions packages/rivetkit/fixtures/driver-test-suite/conn-liveness.ts

This file was deleted.

3 changes: 0 additions & 3 deletions packages/rivetkit/fixtures/driver-test-suite/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
syncActionActor,
} from "./action-types";
import { onStateChangeActor } from "./actor-onstatechange";
import { connLivenessActor } from "./conn-liveness";
import { counterWithParams } from "./conn-params";
import { connStateActor } from "./conn-state";
// Import actors from individual files
Expand Down Expand Up @@ -82,8 +81,6 @@ export const registry = setup({
counterWithParams,
// From conn-state.ts
connStateActor,
// From actor-conn.ts
connLivenessActor,
// From metadata.ts
metadataActor,
// From vars.ts
Expand Down
4 changes: 0 additions & 4 deletions packages/rivetkit/schemas/actor-persist/v1.bare
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ type PersistedConnection struct {
id: str
# Connection token
token: str
# Connection driver type
driver: str
# Connection driver state
driverState: data
# Connection parameters
parameters: data
# Connection state
Expand Down
2 changes: 1 addition & 1 deletion packages/rivetkit/src/actor/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { ActorKey } from "@/actor/mod";
import type { Client } from "@/client/client";
import type { Logger } from "@/common/log";
import type { Registry } from "@/registry/mod";
import type { Conn, ConnId } from "./connection";
import type { Conn, ConnId } from "./conn";
import type { ActorContext } from "./context";
import type { AnyDatabaseProvider, InferDatabaseClient } from "./database";
import type { SaveStateOptions } from "./instance";
Expand Down
2 changes: 1 addition & 1 deletion packages/rivetkit/src/actor/config.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { z } from "zod";
import type { UniversalWebSocket } from "@/common/websocket-interface";
import type { ActionContext } from "./action";
import type { Conn } from "./connection";
import type { Conn } from "./conn";
import type { ActorContext } from "./context";
import type { AnyDatabaseProvider } from "./database";

Expand Down
205 changes: 205 additions & 0 deletions packages/rivetkit/src/actor/conn-drivers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import type { SSEStreamingApi } from "hono/streaming";
import type { WSContext } from "hono/ws";
import type { WebSocket } from "ws";
import type { AnyConn } from "@/actor/conn";
import type { AnyActorInstance } from "@/actor/instance";
import type { CachedSerializer, Encoding } from "@/actor/protocol/serde";
import { encodeDataToString } from "@/actor/protocol/serde";
import type * as protocol from "@/schemas/client-protocol/mod";
import { assertUnreachable, type promiseWithResolvers } from "@/utils";

export enum ConnDriverKind {
WEBSOCKET = 0,
SSE = 1,
HTTP = 2,
}

export enum ConnReadyState {
UNKNOWN = -1,
CONNECTING = 0,
OPEN = 1,
CLOSING = 2,
CLOSED = 3,
}

export interface ConnDriverWebSocketState {
encoding: Encoding;
websocket: WSContext;
closePromise: ReturnType<typeof promiseWithResolvers<void>>;
}

export interface ConnDriverSseState {
encoding: Encoding;
stream: SSEStreamingApi;
}

export type ConnDriverHttpState = Record<never, never>;

export type ConnDriverState =
| { [ConnDriverKind.WEBSOCKET]: ConnDriverWebSocketState }
| { [ConnDriverKind.SSE]: ConnDriverSseState }
| { [ConnDriverKind.HTTP]: ConnDriverHttpState };

export interface ConnDriver<State> {
sendMessage?(
actor: AnyActorInstance,
conn: AnyConn,
state: State,
message: CachedSerializer<protocol.ToClient>,
): void;

/**
* This returns a promise since we commonly disconnect at the end of a program, and not waiting will cause the socket to not close cleanly.
*/
disconnect(
actor: AnyActorInstance,
conn: AnyConn,
state: State,
reason?: string,
): Promise<void>;

/**
* Returns the ready state of the connection.
* This is used to determine if the connection is ready to send messages, or if the connection is stale.
*/
getConnectionReadyState(
actor: AnyActorInstance,
conn: AnyConn,
state: State,
): ConnReadyState | undefined;
}

// MARK: WebSocket
const WEBSOCKET_DRIVER: ConnDriver<ConnDriverWebSocketState> = {
sendMessage: (
actor: AnyActorInstance,
_conn: AnyConn,
state: ConnDriverWebSocketState,
message: CachedSerializer<protocol.ToClient>,
) => {
const serialized = message.serialize(state.encoding);

actor.rLog.debug({
msg: "sending websocket message",
encoding: state.encoding,
dataType: typeof serialized,
isUint8Array: serialized instanceof Uint8Array,
isArrayBuffer: serialized instanceof ArrayBuffer,
dataLength: (serialized as any).byteLength || (serialized as any).length,
});

// Convert Uint8Array to ArrayBuffer for proper transmission
if (serialized instanceof Uint8Array) {
const buffer = serialized.buffer.slice(
serialized.byteOffset,
serialized.byteOffset + serialized.byteLength,
);
// Handle SharedArrayBuffer case
if (buffer instanceof SharedArrayBuffer) {
const arrayBuffer = new ArrayBuffer(buffer.byteLength);
new Uint8Array(arrayBuffer).set(new Uint8Array(buffer));
actor.rLog.debug({
msg: "converted SharedArrayBuffer to ArrayBuffer",
byteLength: arrayBuffer.byteLength,
});
state.websocket.send(arrayBuffer);
} else {
actor.rLog.debug({
msg: "sending ArrayBuffer",
byteLength: buffer.byteLength,
});
state.websocket.send(buffer);
}
} else {
actor.rLog.debug({
msg: "sending string data",
length: (serialized as string).length,
});
state.websocket.send(serialized);
}
},

disconnect: async (
_actor: AnyActorInstance,
_conn: AnyConn,
state: ConnDriverWebSocketState,
reason?: string,
) => {
// Close socket
state.websocket.close(1000, reason);

// Create promise to wait for socket to close gracefully
await state.closePromise.promise;
},

getConnectionReadyState: (
_actor: AnyActorInstance,
_conn: AnyConn,
state: ConnDriverWebSocketState,
): ConnReadyState | undefined => {
return state.websocket.readyState;
},
};

// MARK: SSE
const SSE_DRIVER: ConnDriver<ConnDriverSseState> = {
sendMessage: (
_actor: AnyActorInstance,
_conn: AnyConn,
state: ConnDriverSseState,
message: CachedSerializer<protocol.ToClient>,
) => {
state.stream.writeSSE({
data: encodeDataToString(message.serialize(state.encoding)),
});
},

disconnect: async (
_actor: AnyActorInstance,
_conn: AnyConn,
state: ConnDriverSseState,
_reason?: string,
) => {
state.stream.close();
},

getConnectionReadyState: (
_actor: AnyActorInstance,
_conn: AnyConn,
state: ConnDriverSseState,
): ConnReadyState | undefined => {
if (state.stream.aborted || state.stream.closed) {
return ConnReadyState.CLOSED;
}

return ConnReadyState.OPEN;
},
};

// MARK: HTTP
const HTTP_DRIVER: ConnDriver<ConnDriverHttpState> = {
getConnectionReadyState(_actor, _conn) {
// TODO: This might not be the correct logic
return ConnReadyState.OPEN;
},
disconnect: async () => {
// Noop
// TODO: Abort the request
},
};

/** List of all connection drivers. */
export const CONN_DRIVERS: Record<ConnDriverKind, ConnDriver<unknown>> = {
[ConnDriverKind.WEBSOCKET]: WEBSOCKET_DRIVER,
[ConnDriverKind.SSE]: SSE_DRIVER,
[ConnDriverKind.HTTP]: HTTP_DRIVER,
};

export function getConnDriverFromState(
state: ConnDriverState,
): ConnDriver<unknown> {
if (ConnDriverKind.WEBSOCKET in state) return WEBSOCKET_DRIVER;
else if (ConnDriverKind.SSE in state) return SSE_DRIVER;
else if (ConnDriverKind.HTTP in state) return SSE_DRIVER;
else assertUnreachable(state);
}
Loading
Loading