diff --git a/examples/counter/src/registry.ts b/examples/counter/src/registry.ts index 07ee4d22f..6d144151d 100644 --- a/examples/counter/src/registry.ts +++ b/examples/counter/src/registry.ts @@ -7,6 +7,7 @@ const counter = actor({ onAuth: () => { return true; }, + onConnect: (c) => {}, actions: { increment: (c, x: number) => { c.state.count += x; diff --git a/packages/actor/package.json b/packages/actor/package.json index d691727a4..57adb8fb6 100644 --- a/packages/actor/package.json +++ b/packages/actor/package.json @@ -75,7 +75,7 @@ "scripts": { "build": "tsup src/mod.ts src/client.ts src/log.ts src/errors.ts src/test.ts", "check-types": "tsc --noEmit", - "test": "vitest run" + "test": "vitest run --passWithNoTests" }, "dependencies": { "@rivetkit/core": "workspace:*" diff --git a/packages/core/src/actor/connection.ts b/packages/core/src/actor/connection.ts index b920ca1c5..d27b62e5a 100644 --- a/packages/core/src/actor/connection.ts +++ b/packages/core/src/actor/connection.ts @@ -16,6 +16,10 @@ export function generateConnToken(): string { return generateSecureToken(32); } +export function generatePing(): string { + return crypto.randomUUID(); +} + export type ConnId = string; export type AnyConn = Conn; diff --git a/packages/core/src/actor/instance.ts b/packages/core/src/actor/instance.ts index 9b8f1c5f0..30c9d63fe 100644 --- a/packages/core/src/actor/instance.ts +++ b/packages/core/src/actor/instance.ts @@ -9,10 +9,10 @@ import type { Logger } from "@/common/log"; import { isCborSerializable, stringifyError } from "@/common/utils"; import type { UniversalWebSocket } from "@/common/websocket-interface"; import { ActorInspector } from "@/inspector/actor"; -import type { Registry, RegistryConfig } from "@/mod"; +import type { Registry } from "@/mod"; import type { ActionContext } from "./action"; import type { ActorConfig } from "./config"; -import { Conn, type ConnId } from "./connection"; +import { Conn, type ConnId, generatePing } from "./connection"; import { ActorContext } from "./context"; import type { AnyDatabaseProvider, InferDatabaseClient } from "./database"; import type { ActorDriver, ConnDriver, ConnDrivers } from "./driver"; @@ -157,6 +157,11 @@ export class ActorInstance< #ready = false; #connections = new Map>(); + // This is used to track the last ping sent to the client, when restoring a connection + #connectionsPingRequests = new Map< + ConnId, + { ping: string; resolve: () => void } + >(); #subscriptionIndex = new Map>>(); #schedule!: Schedule; @@ -591,6 +596,8 @@ export class ActorInstance< // Set initial state this.#setPersist(persistData); + const restorePromises = []; + // Load connections for (const connPersist of this.#persist.c) { // Create connections @@ -601,13 +608,60 @@ export class ActorInstance< driver, this.#connStateEnabled, ); - this.#connections.set(conn.id, conn); - // Register event subscriptions - for (const sub of connPersist.su) { - this.#addSubscription(sub.n, conn, true); - } + // send ping, to ensure the connection is alive + + const { promise, resolve } = Promise.withResolvers(); + restorePromises.push( + Promise.race([ + promise, + new Promise((_, reject) => { + setTimeout(() => { + reject(); + }, 2500); + }), + ]) + .catch(() => { + process.nextTick(() => { + logger().debug("connection restore failed", { + connId: conn.id, + }); + this.#connections.delete(conn.id); + conn.disconnect( + "Connection restore failed, connection is not alive", + ); + this.__removeConn(conn); + }); + }) + .then(() => { + logger().debug("connection restored", { + connId: conn.id, + }); + this.#connections.set(conn.id, conn); + + // Register event subscriptions + for (const sub of connPersist.su) { + this.#addSubscription(sub.n, conn, true); + } + }), + ); + + const ping = generatePing(); + this.#connectionsPingRequests.set(conn.id, { ping, resolve }); + conn._sendMessage( + new CachedSerializer({ + b: { + p: ping, + }, + }), + ); } + + const result = await Promise.allSettled(restorePromises); + logger().info("connections restored", { + success: result.filter((r) => r.status === "fulfilled").length, + failed: result.filter((r) => r.status === "rejected").length, + }); } else { logger().info("actor creating"); @@ -818,6 +872,8 @@ export class ActorInstance< this.#persist.c.push(persist); this.saveState({ immediate: true }); + this.inspector.emitter.emit("connectionUpdated"); + // Handle connection if (this.#config.onConnect) { try { @@ -841,8 +897,6 @@ export class ActorInstance< } } - this.inspector.emitter.emit("connectionUpdated"); - // Send init message conn._sendMessage( new CachedSerializer({ @@ -890,6 +944,14 @@ export class ActorInstance< }); this.#removeSubscription(eventName, conn, false); }, + onPong: async (pong, conn) => { + const pingRequest = this.#connectionsPingRequests.get(conn.id); + if (pingRequest?.ping === pong) { + // Resolve the ping request if it matches the last sent ping + pingRequest.resolve(); + this.#connectionsPingRequests.delete(conn.id); + } + }, }); } diff --git a/packages/core/src/actor/protocol/message/mod.ts b/packages/core/src/actor/protocol/message/mod.ts index a01935bae..04a73054e 100644 --- a/packages/core/src/actor/protocol/message/mod.ts +++ b/packages/core/src/actor/protocol/message/mod.ts @@ -90,6 +90,7 @@ export interface ProcessMessageHandler< eventName: string, conn: Conn, ) => Promise; + onPong?: (pong: string, conn: Conn) => Promise; } export async function processMessage< @@ -183,6 +184,17 @@ export async function processMessage< eventName, subscribe, }); + } else if ("p" in message.b) { + if (handler.onPong === undefined) { + throw new errors.Unsupported("Ping"); + } + + const { p: pong } = message.b; + await handler.onPong(pong, conn); + logger().debug("pong response received", { + connId: conn.id, + pong, + }); } else { assertUnreachable(message.b); } diff --git a/packages/core/src/actor/protocol/message/to-client.ts b/packages/core/src/actor/protocol/message/to-client.ts index a0a5330a9..cdb5d6e36 100644 --- a/packages/core/src/actor/protocol/message/to-client.ts +++ b/packages/core/src/actor/protocol/message/to-client.ts @@ -36,6 +36,11 @@ export const EventSchema = z.object({ a: z.array(z.unknown()), }); +export const PingSchema = z.object({ + // Unique identifier for the ping + p: z.string(), +}); + export const ToClientSchema = z.object({ // Body b: z.union([ @@ -43,6 +48,7 @@ export const ToClientSchema = z.object({ z.object({ e: ErrorSchema }), z.object({ ar: ActionResponseSchema }), z.object({ ev: EventSchema }), + PingSchema, ]), }); diff --git a/packages/core/src/actor/protocol/message/to-server.ts b/packages/core/src/actor/protocol/message/to-server.ts index 257f2792b..be8045cee 100644 --- a/packages/core/src/actor/protocol/message/to-server.ts +++ b/packages/core/src/actor/protocol/message/to-server.ts @@ -16,11 +16,17 @@ const SubscriptionRequestSchema = z.object({ s: z.boolean(), }); +export const PingResponseSchema = z.object({ + // Ping + p: z.string(), +}); + export const ToServerSchema = z.object({ // Body b: z.union([ z.object({ ar: ActionRequestSchema }), z.object({ sr: SubscriptionRequestSchema }), + PingResponseSchema, ]), }); diff --git a/packages/core/src/actor/router-endpoints.ts b/packages/core/src/actor/router-endpoints.ts index 925e06a0f..bd13277fa 100644 --- a/packages/core/src/actor/router-endpoints.ts +++ b/packages/core/src/actor/router-endpoints.ts @@ -400,12 +400,13 @@ export async function handleSseConnect( .getGenericConnGlobalState(actorId) .sseStreams.delete(connId); } - if (conn && actor !== undefined) { - actor.__removeConn(conn); - } // Close the stream on error stream.close(); + } finally { + if (conn) { + actor?.__removeConn(conn); + } } }); } diff --git a/packages/core/src/client/actor-conn.ts b/packages/core/src/client/actor-conn.ts index 167cd2bcc..77fc92204 100644 --- a/packages/core/src/client/actor-conn.ts +++ b/packages/core/src/client/actor-conn.ts @@ -416,6 +416,11 @@ enc argsCount: response.b.ev.a?.length, }); this.#dispatchEvent(response.b.ev); + } else if ("p" in response.b) { + // Ping request + const ping = response.b.p; + logger().trace("received ping request", { ping }); + this.#sendMessage({ b: { p: ping } }); } else { assertUnreachable(response.b); } diff --git a/packages/core/src/driver-test-suite/tests/actor-conn.ts b/packages/core/src/driver-test-suite/tests/actor-conn.ts index 9d3bfbd62..5f11e087a 100644 --- a/packages/core/src/driver-test-suite/tests/actor-conn.ts +++ b/packages/core/src/driver-test-suite/tests/actor-conn.ts @@ -1,4 +1,4 @@ -import { describe, expect, test } from "vitest"; +import { describe, expect, test, vi } from "vitest"; import type { DriverTestConfig } from "../mod"; import { setupDriverTest } from "../utils"; @@ -257,5 +257,8 @@ export function runActorConnTests(driverTestConfig: DriverTestConfig) { ]); }); }); + describe("Ping", () => { + test.skip("should restore connections after server restart", async (c) => {}); + }); }); } diff --git a/packages/core/src/driver-test-suite/utils.ts b/packages/core/src/driver-test-suite/utils.ts index d00cd99fa..7c2fb9390 100644 --- a/packages/core/src/driver-test-suite/utils.ts +++ b/packages/core/src/driver-test-suite/utils.ts @@ -14,6 +14,7 @@ export async function setupDriverTest( ): Promise<{ client: Client; endpoint: string; + cleanup?: () => Promise; }> { if (!driverTestConfig.useRealTimers) { vi.useFakeTimers(); @@ -21,8 +22,9 @@ export async function setupDriverTest( // Build drivers const projectPath = resolve(__dirname, "../../fixtures/driver-test-suite"); - const { endpoint, cleanup } = await driverTestConfig.start(projectPath); - c.onTestFinished(cleanup); + const { endpoint, cleanup: driverCleanup } = + await driverTestConfig.start(projectPath); + c.onTestFinished(driverCleanup); let client: Client; if (driverTestConfig.clientType === "http") { @@ -49,6 +51,7 @@ export async function setupDriverTest( return { client, endpoint, + cleanup: driverCleanup, }; }