Skip to content

Commit cfe7310

Browse files
committed
feat(core): remove stale connections after restoring an actor
1 parent 093eef4 commit cfe7310

File tree

11 files changed

+120
-16
lines changed

11 files changed

+120
-16
lines changed

examples/counter/src/registry.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ const counter = actor({
77
onAuth: () => {
88
return true;
99
},
10+
onConnect: (c) => {},
1011
actions: {
1112
increment: (c, x: number) => {
1213
c.state.count += x;

packages/actor/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
"scripts": {
7676
"build": "tsup src/mod.ts src/client.ts src/log.ts src/errors.ts src/test.ts",
7777
"check-types": "tsc --noEmit",
78-
"test": "vitest run"
78+
"test": "vitest run --passWithNoTests"
7979
},
8080
"dependencies": {
8181
"@rivetkit/core": "workspace:*"

packages/core/src/actor/connection.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ export function generateConnToken(): string {
1616
return generateSecureToken(32);
1717
}
1818

19+
export function generatePing(): string {
20+
return crypto.randomUUID();
21+
}
22+
1923
export type ConnId = string;
2024

2125
export type AnyConn = Conn<any, any, any, any, any, any, any>;

packages/core/src/actor/instance.ts

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import * as cbor from "cbor-x";
22
import invariant from "invariant";
33
import onChange from "on-change";
4+
import { success } from "zod/v4";
45
import type { ActorKey } from "@/actor/mod";
56
import type * as wsToClient from "@/actor/protocol/message/to-client";
67
import type * as wsToServer from "@/actor/protocol/message/to-server";
@@ -9,10 +10,10 @@ import type { Logger } from "@/common/log";
910
import { isCborSerializable, stringifyError } from "@/common/utils";
1011
import type { UniversalWebSocket } from "@/common/websocket-interface";
1112
import { ActorInspector } from "@/inspector/actor";
12-
import type { Registry, RegistryConfig } from "@/mod";
13+
import type { Registry } from "@/mod";
1314
import type { ActionContext } from "./action";
1415
import type { ActorConfig } from "./config";
15-
import { Conn, type ConnId } from "./connection";
16+
import { Conn, type ConnId, generatePing } from "./connection";
1617
import { ActorContext } from "./context";
1718
import type { AnyDatabaseProvider, InferDatabaseClient } from "./database";
1819
import type { ActorDriver, ConnDriver, ConnDrivers } from "./driver";
@@ -157,6 +158,11 @@ export class ActorInstance<
157158
#ready = false;
158159

159160
#connections = new Map<ConnId, Conn<S, CP, CS, V, I, AD, DB>>();
161+
// This is used to track the last ping sent to the client, when restoring a connection
162+
#connectionsPingRequests = new Map<
163+
ConnId,
164+
{ ping: string; resolve: () => void }
165+
>();
160166
#subscriptionIndex = new Map<string, Set<Conn<S, CP, CS, V, I, AD, DB>>>();
161167

162168
#schedule!: Schedule;
@@ -591,6 +597,8 @@ export class ActorInstance<
591597
// Set initial state
592598
this.#setPersist(persistData);
593599

600+
const restorePromises = [];
601+
594602
// Load connections
595603
for (const connPersist of this.#persist.c) {
596604
// Create connections
@@ -601,13 +609,60 @@ export class ActorInstance<
601609
driver,
602610
this.#connStateEnabled,
603611
);
604-
this.#connections.set(conn.id, conn);
605612

606-
// Register event subscriptions
607-
for (const sub of connPersist.su) {
608-
this.#addSubscription(sub.n, conn, true);
609-
}
613+
// send ping, to ensure the connection is alive
614+
615+
const { promise, resolve } = Promise.withResolvers<void>();
616+
restorePromises.push(
617+
Promise.race([
618+
promise,
619+
new Promise<void>((_, reject) => {
620+
setTimeout(() => {
621+
reject();
622+
}, 2500);
623+
}),
624+
])
625+
.catch(() => {
626+
process.nextTick(() => {
627+
logger().debug("connection restore failed", {
628+
connId: conn.id,
629+
});
630+
this.#connections.delete(conn.id);
631+
conn.disconnect(
632+
"Connection restore failed, connection is not alive",
633+
);
634+
this.__removeConn(conn);
635+
});
636+
})
637+
.then(() => {
638+
logger().debug("connection restored", {
639+
connId: conn.id,
640+
});
641+
this.#connections.set(conn.id, conn);
642+
643+
// Register event subscriptions
644+
for (const sub of connPersist.su) {
645+
this.#addSubscription(sub.n, conn, true);
646+
}
647+
}),
648+
);
649+
650+
const ping = generatePing();
651+
this.#connectionsPingRequests.set(conn.id, { ping, resolve });
652+
conn._sendMessage(
653+
new CachedSerializer<wsToClient.ToClient>({
654+
b: {
655+
p: ping,
656+
},
657+
}),
658+
);
610659
}
660+
661+
const result = await Promise.allSettled(restorePromises);
662+
logger().info("connections restored", {
663+
success: result.filter((r) => r.status === "fulfilled").length,
664+
failed: result.filter((r) => r.status === "rejected").length,
665+
});
611666
} else {
612667
logger().info("actor creating");
613668

@@ -818,6 +873,8 @@ export class ActorInstance<
818873
this.#persist.c.push(persist);
819874
this.saveState({ immediate: true });
820875

876+
this.inspector.emitter.emit("connectionUpdated");
877+
821878
// Handle connection
822879
if (this.#config.onConnect) {
823880
try {
@@ -841,8 +898,6 @@ export class ActorInstance<
841898
}
842899
}
843900

844-
this.inspector.emitter.emit("connectionUpdated");
845-
846901
// Send init message
847902
conn._sendMessage(
848903
new CachedSerializer<wsToClient.ToClient>({
@@ -890,6 +945,14 @@ export class ActorInstance<
890945
});
891946
this.#removeSubscription(eventName, conn, false);
892947
},
948+
onPong: async (pong, conn) => {
949+
const pingRequest = this.#connectionsPingRequests.get(conn.id);
950+
if (pingRequest?.ping === pong) {
951+
// Resolve the ping request if it matches the last sent ping
952+
pingRequest.resolve();
953+
this.#connectionsPingRequests.delete(conn.id);
954+
}
955+
},
893956
});
894957
}
895958

packages/core/src/actor/protocol/message/mod.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ export interface ProcessMessageHandler<
9090
eventName: string,
9191
conn: Conn<S, CP, CS, V, I, AD, DB>,
9292
) => Promise<void>;
93+
onPong?: (pong: string, conn: Conn<S, CP, CS, V, I, AD, DB>) => Promise<void>;
9394
}
9495

9596
export async function processMessage<
@@ -183,6 +184,17 @@ export async function processMessage<
183184
eventName,
184185
subscribe,
185186
});
187+
} else if ("p" in message.b) {
188+
if (handler.onPong === undefined) {
189+
throw new errors.Unsupported("Ping");
190+
}
191+
192+
const { p: pong } = message.b;
193+
await handler.onPong(pong, conn);
194+
logger().debug("pong response received", {
195+
connId: conn.id,
196+
pong,
197+
});
186198
} else {
187199
assertUnreachable(message.b);
188200
}

packages/core/src/actor/protocol/message/to-client.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,19 @@ export const EventSchema = z.object({
3636
a: z.array(z.unknown()),
3737
});
3838

39+
export const PingSchema = z.object({
40+
// Unique identifier for the ping
41+
p: z.string(),
42+
});
43+
3944
export const ToClientSchema = z.object({
4045
// Body
4146
b: z.union([
4247
z.object({ i: InitSchema }),
4348
z.object({ e: ErrorSchema }),
4449
z.object({ ar: ActionResponseSchema }),
4550
z.object({ ev: EventSchema }),
51+
PingSchema,
4652
]),
4753
});
4854

packages/core/src/actor/protocol/message/to-server.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@ const SubscriptionRequestSchema = z.object({
1616
s: z.boolean(),
1717
});
1818

19+
export const PingResponseSchema = z.object({
20+
// Ping
21+
p: z.string(),
22+
});
23+
1924
export const ToServerSchema = z.object({
2025
// Body
2126
b: z.union([
2227
z.object({ ar: ActionRequestSchema }),
2328
z.object({ sr: SubscriptionRequestSchema }),
29+
PingResponseSchema,
2430
]),
2531
});
2632

packages/core/src/actor/router-endpoints.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,12 +400,13 @@ export async function handleSseConnect(
400400
.getGenericConnGlobalState(actorId)
401401
.sseStreams.delete(connId);
402402
}
403-
if (conn && actor !== undefined) {
404-
actor.__removeConn(conn);
405-
}
406403

407404
// Close the stream on error
408405
stream.close();
406+
} finally {
407+
if (conn) {
408+
actor?.__removeConn(conn);
409+
}
409410
}
410411
});
411412
}

packages/core/src/client/actor-conn.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,11 @@ enc
416416
argsCount: response.b.ev.a?.length,
417417
});
418418
this.#dispatchEvent(response.b.ev);
419+
} else if ("p" in response.b) {
420+
// Ping request
421+
const ping = response.b.p;
422+
logger().trace("received ping request", { ping });
423+
this.#sendMessage({ b: { p: ping } });
419424
} else {
420425
assertUnreachable(response.b);
421426
}

packages/core/src/driver-test-suite/tests/actor-conn.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { describe, expect, test } from "vitest";
1+
import { describe, expect, test, vi } from "vitest";
22
import type { DriverTestConfig } from "../mod";
33
import { setupDriverTest } from "../utils";
44

@@ -257,5 +257,8 @@ export function runActorConnTests(driverTestConfig: DriverTestConfig) {
257257
]);
258258
});
259259
});
260+
describe("Ping", () => {
261+
test.skip("should restore connections after server restart", async (c) => {});
262+
});
260263
});
261264
}

0 commit comments

Comments
 (0)