Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/counter/src/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const counter = actor({
onAuth: () => {
return true;
},
onConnect: (c) => {},
actions: {
increment: (c, x: number) => {
c.state.count += x;
Expand Down
2 changes: 1 addition & 1 deletion packages/actor/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"scripts": {
"build": "tsup src/mod.ts src/client.ts src/log.ts src/errors.ts src/test.ts",
"check-types": "tsc --noEmit",
"test": "vitest run"
"test": "vitest run --passWithNoTests"
},
"dependencies": {
"@rivetkit/core": "workspace:*"
Expand Down
4 changes: 4 additions & 0 deletions packages/core/src/actor/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ export function generateConnToken(): string {
return generateSecureToken(32);
}

export function generatePing(): string {
return crypto.randomUUID();
}

export type ConnId = string;

export type AnyConn = Conn<any, any, any, any, any, any, any>;
Expand Down
80 changes: 71 additions & 9 deletions packages/core/src/actor/instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import type { Logger } from "@/common/log";
import { isCborSerializable, stringifyError } from "@/common/utils";
import type { UniversalWebSocket } from "@/common/websocket-interface";
import { ActorInspector } from "@/inspector/actor";
import type { Registry, RegistryConfig } from "@/mod";
import type { Registry } from "@/mod";
import type { ActionContext } from "./action";
import type { ActorConfig } from "./config";
import { Conn, type ConnId } from "./connection";
import { Conn, type ConnId, generatePing } from "./connection";
import { ActorContext } from "./context";
import type { AnyDatabaseProvider, InferDatabaseClient } from "./database";
import type { ActorDriver, ConnDriver, ConnDrivers } from "./driver";
Expand Down Expand Up @@ -157,6 +157,11 @@ export class ActorInstance<
#ready = false;

#connections = new Map<ConnId, Conn<S, CP, CS, V, I, AD, DB>>();
// This is used to track the last ping sent to the client, when restoring a connection
#connectionsPingRequests = new Map<
ConnId,
{ ping: string; resolve: () => void }
>();
#subscriptionIndex = new Map<string, Set<Conn<S, CP, CS, V, I, AD, DB>>>();

#schedule!: Schedule;
Expand Down Expand Up @@ -591,6 +596,8 @@ export class ActorInstance<
// Set initial state
this.#setPersist(persistData);

const restorePromises = [];

// Load connections
for (const connPersist of this.#persist.c) {
// Create connections
Expand All @@ -601,13 +608,60 @@ export class ActorInstance<
driver,
this.#connStateEnabled,
);
this.#connections.set(conn.id, conn);

// Register event subscriptions
for (const sub of connPersist.su) {
this.#addSubscription(sub.n, conn, true);
}
// send ping, to ensure the connection is alive

const { promise, resolve } = Promise.withResolvers<void>();
restorePromises.push(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tihs promise chain is illegible. rewrite cleaner

Promise.race([
promise,
new Promise<void>((_, reject) => {
setTimeout(() => {
reject();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

() => { reject() }

to

reject

}, 2500);
}),
])
.catch(() => {
process.nextTick(() => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this portable? i think we need to use setTimeout(x, 0) instead

need comment on why this is running in the next tick

logger().debug("connection restore failed", {
connId: conn.id,
});
this.#connections.delete(conn.id);
conn.disconnect(
"Connection restore failed, connection is not alive",
);
this.__removeConn(conn);
});
})
.then(() => {
logger().debug("connection restored", {
connId: conn.id,
});
this.#connections.set(conn.id, conn);

// Register event subscriptions
for (const sub of connPersist.su) {
this.#addSubscription(sub.n, conn, true);
}
}),
);

const ping = generatePing();
this.#connectionsPingRequests.set(conn.id, { ping, resolve });
conn._sendMessage(
new CachedSerializer<wsToClient.ToClient>({
b: {
p: ping,
},
}),
);
}

const result = await Promise.allSettled(restorePromises);
logger().info("connections restored", {
success: result.filter((r) => r.status === "fulfilled").length,
failed: result.filter((r) => r.status === "rejected").length,
});
} else {
logger().info("actor creating");

Expand Down Expand Up @@ -818,6 +872,8 @@ export class ActorInstance<
this.#persist.c.push(persist);
this.saveState({ immediate: true });

this.inspector.emitter.emit("connectionUpdated");

// Handle connection
if (this.#config.onConnect) {
try {
Expand All @@ -841,8 +897,6 @@ export class ActorInstance<
}
}

this.inspector.emitter.emit("connectionUpdated");

// Send init message
conn._sendMessage(
new CachedSerializer<wsToClient.ToClient>({
Expand Down Expand Up @@ -890,6 +944,14 @@ export class ActorInstance<
});
this.#removeSubscription(eventName, conn, false);
},
onPong: async (pong, conn) => {
const pingRequest = this.#connectionsPingRequests.get(conn.id);
if (pingRequest?.ping === pong) {
// Resolve the ping request if it matches the last sent ping
pingRequest.resolve();
this.#connectionsPingRequests.delete(conn.id);
}
},
});
}

Expand Down
12 changes: 12 additions & 0 deletions packages/core/src/actor/protocol/message/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export interface ProcessMessageHandler<
eventName: string,
conn: Conn<S, CP, CS, V, I, AD, DB>,
) => Promise<void>;
onPong?: (pong: string, conn: Conn<S, CP, CS, V, I, AD, DB>) => Promise<void>;
}

export async function processMessage<
Expand Down Expand Up @@ -183,6 +184,17 @@ export async function processMessage<
eventName,
subscribe,
});
} else if ("p" in message.b) {
if (handler.onPong === undefined) {
throw new errors.Unsupported("Ping");
}

const { p: pong } = message.b;
await handler.onPong(pong, conn);
logger().debug("pong response received", {
connId: conn.id,
pong,
});
} else {
assertUnreachable(message.b);
}
Expand Down
6 changes: 6 additions & 0 deletions packages/core/src/actor/protocol/message/to-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ export const EventSchema = z.object({
a: z.array(z.unknown()),
});

export const PingSchema = z.object({
// Unique identifier for the ping
p: z.string(),
});

export const ToClientSchema = z.object({
// Body
b: z.union([
z.object({ i: InitSchema }),
z.object({ e: ErrorSchema }),
z.object({ ar: ActionResponseSchema }),
z.object({ ev: EventSchema }),
PingSchema,
]),
});

Expand Down
6 changes: 6 additions & 0 deletions packages/core/src/actor/protocol/message/to-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@ const SubscriptionRequestSchema = z.object({
s: z.boolean(),
});

export const PingResponseSchema = z.object({
// Ping
p: z.string(),
});

export const ToServerSchema = z.object({
// Body
b: z.union([
z.object({ ar: ActionRequestSchema }),
z.object({ sr: SubscriptionRequestSchema }),
PingResponseSchema,
]),
});

Expand Down
7 changes: 4 additions & 3 deletions packages/core/src/actor/router-endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,13 @@ export async function handleSseConnect(
.getGenericConnGlobalState(actorId)
.sseStreams.delete(connId);
}
if (conn && actor !== undefined) {
actor.__removeConn(conn);
}

// Close the stream on error
stream.close();
} finally {
if (conn) {
actor?.__removeConn(conn);
}
}
});
}
Expand Down
5 changes: 5 additions & 0 deletions packages/core/src/client/actor-conn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,11 @@ enc
argsCount: response.b.ev.a?.length,
});
this.#dispatchEvent(response.b.ev);
} else if ("p" in response.b) {
// Ping request
const ping = response.b.p;
logger().trace("received ping request", { ping });
this.#sendMessage({ b: { p: ping } });
} else {
assertUnreachable(response.b);
}
Expand Down
5 changes: 4 additions & 1 deletion packages/core/src/driver-test-suite/tests/actor-conn.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, expect, test } from "vitest";
import { describe, expect, test, vi } from "vitest";
import type { DriverTestConfig } from "../mod";
import { setupDriverTest } from "../utils";

Expand Down Expand Up @@ -257,5 +257,8 @@ export function runActorConnTests(driverTestConfig: DriverTestConfig) {
]);
});
});
describe("Ping", () => {
test.skip("should restore connections after server restart", async (c) => {});
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test for connection restoration is currently skipped with test.skip(). To ensure the new ping/pong functionality is properly validated, consider implementing this test to verify that connections are correctly restored after server restart, as outlined in the PR description. This would provide confidence that the connection restoration mechanism works as expected in real-world scenarios.

Suggested change
test.skip("should restore connections after server restart", async (c) => {});
test("should restore connections after server restart", async (c) => {
const driver = c.driver();
const actor = await driver.createActor();
// Create a connection and verify it works
await actor.call("test", {});
// Simulate server restart
await driver.simulateServerRestart();
// Verify connection is restored and works after restart
await actor.call("test", {});
// Clean up
await actor.delete();
});

Spotted by Diamond

Is this helpful? React 👍 or 👎 to let us know.

});
});
}
7 changes: 5 additions & 2 deletions packages/core/src/driver-test-suite/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@ export async function setupDriverTest(
): Promise<{
client: Client<typeof registry>;
endpoint: string;
cleanup?: () => Promise<void>;
}> {
if (!driverTestConfig.useRealTimers) {
vi.useFakeTimers();
}

// Build drivers
const projectPath = resolve(__dirname, "../../fixtures/driver-test-suite");
const { endpoint, cleanup } = await driverTestConfig.start(projectPath);
c.onTestFinished(cleanup);
const { endpoint, cleanup: driverCleanup } =
await driverTestConfig.start(projectPath);
c.onTestFinished(driverCleanup);

let client: Client<typeof registry>;
if (driverTestConfig.clientType === "http") {
Expand All @@ -49,6 +51,7 @@ export async function setupDriverTest(
return {
client,
endpoint,
cleanup: driverCleanup,
};
}

Expand Down
Loading