Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cli/bin/postgres-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1622,10 +1622,10 @@ auth
const requestedPort = opts.port || 0; // 0 = OS assigns available port
const callbackServer = authServer.createCallbackServer(requestedPort, params.state, 120000); // 2 minute timeout

// Wait a bit for server to start and get port
await new Promise(resolve => setTimeout(resolve, 100));
const actualPort = callbackServer.getPort();
const redirectUri = `http://localhost:${actualPort}/callback`;
// Wait for server to start and get the actual port
const actualPort = await callbackServer.ready;
// Use 127.0.0.1 to match the server bind address (avoids IPv6 issues on some hosts)
const redirectUri = `http://127.0.0.1:${actualPort}/callback`;

console.log(`Callback server listening on port ${actualPort}`);

Expand Down
191 changes: 124 additions & 67 deletions cli/lib/auth-server.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import * as http from "http";

/**
* OAuth callback result
*/
Expand All @@ -12,6 +14,7 @@ export interface CallbackResult {
export interface CallbackServer {
server: { stop: () => void };
promise: Promise<CallbackResult>;
ready: Promise<number>; // Resolves with actual port when server is listening
getPort: () => number;
}

Expand All @@ -31,14 +34,17 @@ function escapeHtml(str: string | null): string {
}

/**
* Create and start callback server using Bun.serve
* Create and start callback server using Node.js http module
*
* @param port - Port to listen on (0 for random available port)
* @param expectedState - Expected state parameter for CSRF protection
* @param timeoutMs - Timeout in milliseconds
* @returns Server object with promise and getPort function
* @returns Server object with promise, ready promise, and getPort function
*
* @remarks
* The `ready` promise resolves with the actual port once the server is listening.
* Callers should await `ready` before using `getPort()` when using port 0.
*
* The server stops asynchronously ~100ms after the callback resolves/rejects.
* This delay ensures the HTTP response is fully sent before closing the connection.
* Callers should not attempt to reuse the same port immediately after the promise
Expand All @@ -53,53 +59,78 @@ export function createCallbackServer(
let actualPort = port;
let resolveCallback: (value: CallbackResult) => void;
let rejectCallback: (reason: Error) => void;
let serverInstance: ReturnType<typeof Bun.serve> | null = null;
let resolveReady: (port: number) => void;
let rejectReady: (reason: Error) => void;
let serverInstance: http.Server | null = null;

const promise = new Promise<CallbackResult>((resolve, reject) => {
resolveCallback = resolve;
rejectCallback = reject;
});

const ready = new Promise<number>((resolve, reject) => {
resolveReady = resolve;
rejectReady = reject;
});

let timeoutId: ReturnType<typeof setTimeout> | null = null;

const stopServer = () => {
// Clear timeout to prevent it firing after manual stop
if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = null;
}
if (serverInstance) {
serverInstance.close();
serverInstance = null;
}
};

// Timeout handler
const timeout = setTimeout(() => {
timeoutId = setTimeout(() => {
if (!resolved) {
resolved = true;
if (serverInstance) {
serverInstance.stop();
}
timeoutId = null; // Already fired, clear reference
stopServer();
rejectCallback(new Error("Authentication timeout. Please try again."));
}
}, timeoutMs);

serverInstance = Bun.serve({
port: port,
hostname: "127.0.0.1",
fetch(req) {
if (resolved) {
return new Response("Already handled", { status: 200 });
}
serverInstance = http.createServer((req, res) => {
if (resolved) {
res.writeHead(200, { "Content-Type": "text/plain" });
res.end("Already handled");
return;
}

const url = new URL(req.url);
const url = new URL(req.url || "/", `http://127.0.0.1:${actualPort}`);

// Only handle /callback path
if (!url.pathname.startsWith("/callback")) {
return new Response("Not Found", { status: 404 });
}
// Only handle /callback path
if (!url.pathname.startsWith("/callback")) {
res.writeHead(404, { "Content-Type": "text/plain" });
res.end("Not Found");
return;
}

const code = url.searchParams.get("code");
const state = url.searchParams.get("state");
const error = url.searchParams.get("error");
const errorDescription = url.searchParams.get("error_description");
const code = url.searchParams.get("code");
const state = url.searchParams.get("state");
const error = url.searchParams.get("error");
const errorDescription = url.searchParams.get("error_description");

// Handle OAuth error
if (error) {
resolved = true;
clearTimeout(timeout);
// Handle OAuth error
if (error) {
resolved = true;
if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = null;
}

setTimeout(() => serverInstance?.stop(), 100);
rejectCallback(new Error(`OAuth error: ${error}${errorDescription ? ` - ${errorDescription}` : ""}`));
setTimeout(() => stopServer(), 100);
rejectCallback(new Error(`OAuth error: ${error}${errorDescription ? ` - ${errorDescription}` : ""}`));

return new Response(`
res.writeHead(400, { "Content-Type": "text/html" });
res.end(`
<!DOCTYPE html>
<html>
<head>
Expand All @@ -120,12 +151,14 @@ export function createCallbackServer(
</div>
</body>
</html>
`, { status: 400, headers: { "Content-Type": "text/html" } });
}
`);
return;
}

// Validate required parameters
if (!code || !state) {
return new Response(`
// Validate required parameters
if (!code || !state) {
res.writeHead(400, { "Content-Type": "text/html" });
res.end(`
<!DOCTYPE html>
<html>
<head>
Expand All @@ -144,18 +177,23 @@ export function createCallbackServer(
</div>
</body>
</html>
`, { status: 400, headers: { "Content-Type": "text/html" } });
}
`);
return;
}

// Validate state (CSRF protection)
if (expectedState && state !== expectedState) {
resolved = true;
clearTimeout(timeout);
// Validate state (CSRF protection)
if (expectedState && state !== expectedState) {
resolved = true;
if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = null;
}

setTimeout(() => serverInstance?.stop(), 100);
rejectCallback(new Error("State mismatch (possible CSRF attack)"));
setTimeout(() => stopServer(), 100);
rejectCallback(new Error("State mismatch (possible CSRF attack)"));

return new Response(`
res.writeHead(400, { "Content-Type": "text/html" });
res.end(`
<!DOCTYPE html>
<html>
<head>
Expand All @@ -174,19 +212,24 @@ export function createCallbackServer(
</div>
</body>
</html>
`, { status: 400, headers: { "Content-Type": "text/html" } });
}
`);
return;
}

// Success!
resolved = true;
clearTimeout(timeout);
// Success!
resolved = true;
if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = null;
}

// Resolve first, then stop server asynchronously after response is sent.
// The 100ms delay ensures the HTTP response is fully written before closing.
resolveCallback({ code, state });
setTimeout(() => serverInstance?.stop(), 100);
// Resolve first, then stop server asynchronously after response is sent.
// The 100ms delay ensures the HTTP response is fully written before closing.
resolveCallback({ code, state });
setTimeout(() => stopServer(), 100);

return new Response(`
res.writeHead(200, { "Content-Type": "text/html" });
res.end(`
<!DOCTYPE html>
<html>
<head>
Expand All @@ -205,24 +248,38 @@ export function createCallbackServer(
</div>
</body>
</html>
`, { status: 200, headers: { "Content-Type": "text/html" } });
},
`);
});

actualPort = serverInstance.port;
// Handle server errors (e.g., EADDRINUSE)
serverInstance.on("error", (err: NodeJS.ErrnoException) => {
if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = null;
}
if (err.code === "EADDRINUSE") {
rejectReady(new Error(`Port ${port} is already in use`));
} else {
rejectReady(new Error(`Server error: ${err.message}`));
}
if (!resolved) {
resolved = true;
rejectCallback(err);
}
});

serverInstance.listen(port, "127.0.0.1", () => {
const address = serverInstance?.address();
if (address && typeof address === "object") {
actualPort = address.port;
}
resolveReady(actualPort);
});
Comment thread
NikolayS marked this conversation as resolved.

return {
server: { stop: () => serverInstance?.stop() },
server: { stop: stopServer },
Comment thread
NikolayS marked this conversation as resolved.
promise,
ready,
getPort: () => actualPort,
};
}

/**
* Get the actual port the server is listening on
* @param server - Bun server instance
* @returns Port number
*/
export function getServerPort(server: ReturnType<typeof Bun.serve>): number {
return server.port;
}
72 changes: 38 additions & 34 deletions cli/test/init.integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -326,41 +326,45 @@ describe.skipIf(skipTests)("integration: prepare-db", () => {
}
});

test("--verify returns 0 when ok and non-zero when missing", async () => {
pg = await createTempPostgres();

try {
// Prepare: run init
{
const r = runCliInit([pg.adminUri, "--password", "monpw", "--skip-optional-permissions"]);
expect(r.status).toBe(0);
}

// Verify should pass
{
const r = runCliInit([pg.adminUri, "--verify", "--skip-optional-permissions"]);
expect(r.status).toBe(0);
expect(r.stdout).toMatch(/prepare-db verify: OK/i);
}

// Break a required privilege and ensure verify fails
{
const c = new Client({ connectionString: pg.adminUri });
await c.connect();
await c.query("revoke select on pg_catalog.pg_index from public");
await c.query("revoke select on pg_catalog.pg_index from postgres_ai_mon");
await c.end();
}
{
const r = runCliInit([pg.adminUri, "--verify", "--skip-optional-permissions"]);
expect(r.status).not.toBe(0);
expect(r.stderr).toMatch(/prepare-db verify failed/i);
expect(r.stderr).toMatch(/pg_catalog\.pg_index/i);
test(
"--verify returns 0 when ok and non-zero when missing",
async () => {
pg = await createTempPostgres();

try {
// Prepare: run init
{
const r = runCliInit([pg.adminUri, "--password", "monpw", "--skip-optional-permissions"]);
expect(r.status).toBe(0);
}

// Verify should pass
{
const r = runCliInit([pg.adminUri, "--verify", "--skip-optional-permissions"]);
expect(r.status).toBe(0);
expect(r.stdout).toMatch(/prepare-db verify: OK/i);
}

// Break a required privilege and ensure verify fails
{
const c = new Client({ connectionString: pg.adminUri });
await c.connect();
await c.query("revoke select on pg_catalog.pg_index from public");
await c.query("revoke select on pg_catalog.pg_index from postgres_ai_mon");
await c.end();
}
{
const r = runCliInit([pg.adminUri, "--verify", "--skip-optional-permissions"]);
expect(r.status).not.toBe(0);
expect(r.stderr).toMatch(/prepare-db verify failed/i);
expect(r.stderr).toMatch(/pg_catalog\.pg_index/i);
}
} finally {
await pg.cleanup();
}
} finally {
await pg.cleanup();
}
});
},
{ timeout: 15000 }
);

test("--reset-password updates the monitoring role login password", async () => {
pg = await createTempPostgres();
Expand Down