diff --git a/.vscode/launch.json b/.vscode/launch.json index f8eaa53f6..8eec7d6e5 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -18,8 +18,8 @@ "request": "launch", "name": "Launch Program", "skipFiles": ["/**"], - "program": "${workspaceFolder}/dist/index.js", - "args": ["--transport", "http", "--loggers", "stderr", "mcp"], + "runtimeExecutable": "npm", + "runtimeArgs": ["start"], "preLaunchTask": "tsc: build - tsconfig.build.json", "outFiles": ["${workspaceFolder}/dist/**/*.js"] } diff --git a/README.md b/README.md index 78169a00a..6a91e158f 100644 --- a/README.md +++ b/README.md @@ -302,20 +302,22 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow ### Configuration Options -| CLI Option | Environment Variable | Default | Description | -| ------------------ | --------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `apiClientId` | `MDB_MCP_API_CLIENT_ID` | | Atlas API client ID for authentication. Required for running Atlas tools. | -| `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | | Atlas API client secret for authentication. Required for running Atlas tools. | -| `connectionString` | `MDB_MCP_CONNECTION_STRING` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | -| `loggers` | `MDB_MCP_LOGGERS` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | -| `logPath` | `MDB_MCP_LOG_PATH` | see note\* | Folder to store logs. | -| `disabledTools` | `MDB_MCP_DISABLED_TOOLS` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | -| `readOnly` | `MDB_MCP_READ_ONLY` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | -| `indexCheck` | `MDB_MCP_INDEX_CHECK` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | -| `telemetry` | `MDB_MCP_TELEMETRY` | enabled | When set to disabled, disables telemetry collection. | -| `transport` | `MDB_MCP_TRANSPORT` | stdio | Either 'stdio' or 'http'. | -| `httpPort` | `MDB_MCP_HTTP_PORT` | 3000 | Port number. | -| `httpHost` | `MDB_MCP_HTTP_HOST` | 127.0.0.1 | Host to bind the http server. | +| CLI Option | Environment Variable | Default | Description | +| ----------------------- | --------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `apiClientId` | `MDB_MCP_API_CLIENT_ID` | | Atlas API client ID for authentication. Required for running Atlas tools. | +| `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | | Atlas API client secret for authentication. Required for running Atlas tools. | +| `connectionString` | `MDB_MCP_CONNECTION_STRING` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | +| `loggers` | `MDB_MCP_LOGGERS` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | +| `logPath` | `MDB_MCP_LOG_PATH` | see note\* | Folder to store logs. | +| `disabledTools` | `MDB_MCP_DISABLED_TOOLS` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | +| `readOnly` | `MDB_MCP_READ_ONLY` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | +| `indexCheck` | `MDB_MCP_INDEX_CHECK` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | +| `telemetry` | `MDB_MCP_TELEMETRY` | enabled | When set to disabled, disables telemetry collection. | +| `transport` | `MDB_MCP_TRANSPORT` | stdio | Either 'stdio' or 'http'. | +| `httpPort` | `MDB_MCP_HTTP_PORT` | 3000 | Port number. | +| `httpHost` | `MDB_MCP_HTTP_HOST` | 127.0.0.1 | Host to bind the http server. | +| `idleTimeoutMs` | `MDB_MCP_IDLE_TIMEOUT_MS` | 600000 | Idle timeout for a client to disconnect (only applies to http transport). | +| `notificationTimeoutMs` | `MDB_MCP_NOTIFICATION_TIMEOUT_MS` | 540000 | Notification timeout for a client to be aware of diconnect (only applies to http transport). | #### Logger Options diff --git a/package.json b/package.json index cafa5e9b4..b58c1191e 100644 --- a/package.json +++ b/package.json @@ -16,7 +16,7 @@ }, "type": "module", "scripts": { - "start": "node dist/index.js --transport http", + "start": "node dist/index.js --transport http --loggers stderr mcp", "prepare": "npm run build", "build:clean": "rm -rf dist", "build:compile": "tsc --project tsconfig.build.json", diff --git a/src/common/config.ts b/src/common/config.ts index 98c13cfcb..3406a440a 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -28,6 +28,8 @@ export interface UserConfig { httpPort: number; httpHost: string; loggers: Array<"stderr" | "disk" | "mcp">; + idleTimeoutMs: number; + notificationTimeoutMs: number; } const defaults: UserConfig = { @@ -47,6 +49,8 @@ const defaults: UserConfig = { httpPort: 3000, httpHost: "127.0.0.1", loggers: ["disk", "mcp"], + idleTimeoutMs: 600000, // 10 minutes + notificationTimeoutMs: 540000, // 9 minutes }; export const config = { diff --git a/src/common/logger.ts b/src/common/logger.ts index 259d173ea..7ed1a9acf 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -43,8 +43,9 @@ export const LogId = { streamableHttpTransportStarted: mongoLogId(1_006_001), streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002), - streamableHttpTransportRequestFailure: mongoLogId(1_006_003), - streamableHttpTransportCloseFailure: mongoLogId(1_006_004), + streamableHttpTransportSessionCloseNotification: mongoLogId(1_006_003), + streamableHttpTransportRequestFailure: mongoLogId(1_006_004), + streamableHttpTransportCloseFailure: mongoLogId(1_006_005), } as const; export abstract class LoggerBase { diff --git a/src/common/sessionStore.ts b/src/common/sessionStore.ts index 9159f6335..9ad9d9bba 100644 --- a/src/common/sessionStore.ts +++ b/src/common/sessionStore.ts @@ -1,31 +1,92 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; -import logger, { LogId } from "./logger.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import logger, { LogId, McpLogger } from "./logger.js"; +import { TimeoutManager } from "./timeoutManager.js"; export class SessionStore { - private sessions: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + private sessions: { + [sessionId: string]: { + mcpServer: McpServer; + transport: StreamableHTTPServerTransport; + abortTimeout: TimeoutManager; + notificationTimeout: TimeoutManager; + }; + } = {}; + + constructor( + private readonly idleTimeoutMS: number, + private readonly notificationTimeoutMS: number + ) { + if (idleTimeoutMS <= 0) { + throw new Error("idleTimeoutMS must be greater than 0"); + } + if (notificationTimeoutMS <= 0) { + throw new Error("notificationTimeoutMS must be greater than 0"); + } + if (idleTimeoutMS <= notificationTimeoutMS) { + throw new Error("idleTimeoutMS must be greater than notificationTimeoutMS"); + } + } getSession(sessionId: string): StreamableHTTPServerTransport | undefined { - return this.sessions[sessionId]; + this.resetTimeout(sessionId); + return this.sessions[sessionId]?.transport; + } + + private resetTimeout(sessionId: string): void { + const session = this.sessions[sessionId]; + if (!session) { + return; + } + + session.abortTimeout.reset(); + + session.notificationTimeout.reset(); + } + + private sendNotification(sessionId: string): void { + const session = this.sessions[sessionId]; + if (!session) { + return; + } + const logger = new McpLogger(session.mcpServer); + logger.info( + LogId.streamableHttpTransportSessionCloseNotification, + "sessionStore", + "Session is about to be closed due to inactivity" + ); } - setSession(sessionId: string, transport: StreamableHTTPServerTransport): void { + setSession(sessionId: string, transport: StreamableHTTPServerTransport, mcpServer: McpServer): void { if (this.sessions[sessionId]) { throw new Error(`Session ${sessionId} already exists`); } - this.sessions[sessionId] = transport; + const abortTimeout = new TimeoutManager(async () => { + const logger = new McpLogger(mcpServer); + logger.info( + LogId.streamableHttpTransportSessionCloseNotification, + "sessionStore", + "Session closed due to inactivity" + ); + + await this.closeSession(sessionId); + }, this.idleTimeoutMS); + const notificationTimeout = new TimeoutManager( + () => this.sendNotification(sessionId), + this.notificationTimeoutMS + ); + this.sessions[sessionId] = { mcpServer, transport, abortTimeout, notificationTimeout }; } async closeSession(sessionId: string, closeTransport: boolean = true): Promise { if (!this.sessions[sessionId]) { throw new Error(`Session ${sessionId} not found`); } + this.sessions[sessionId].abortTimeout.clear(); + this.sessions[sessionId].notificationTimeout.clear(); if (closeTransport) { - const transport = this.sessions[sessionId]; - if (!transport) { - throw new Error(`Session ${sessionId} not found`); - } try { - await transport.close(); + await this.sessions[sessionId].transport.close(); } catch (error) { logger.error( LogId.streamableHttpTransportSessionCloseFailure, @@ -38,11 +99,6 @@ export class SessionStore { } async closeAllSessions(): Promise { - await Promise.all( - Object.values(this.sessions) - .filter((transport) => transport !== undefined) - .map((transport) => transport.close()) - ); - this.sessions = {}; + await Promise.all(Object.keys(this.sessions).map((sessionId) => this.closeSession(sessionId))); } } diff --git a/src/common/timeoutManager.ts b/src/common/timeoutManager.ts new file mode 100644 index 000000000..03161dfc0 --- /dev/null +++ b/src/common/timeoutManager.ts @@ -0,0 +1,63 @@ +/** + * A class that manages timeouts for a callback function. + * It is used to ensure that a callback function is called after a certain amount of time. + * If the callback function is not called after the timeout, it will be called with an error. + */ +export class TimeoutManager { + private timeoutId?: NodeJS.Timeout; + + /** + * A callback function that is called when the timeout is reached. + */ + public onerror?: (error: unknown) => void; + + /** + * Creates a new TimeoutManager. + * @param callback - A callback function that is called when the timeout is reached. + * @param timeoutMS - The timeout in milliseconds. + */ + constructor( + private readonly callback: () => Promise | void, + private readonly timeoutMS: number + ) { + if (timeoutMS <= 0) { + throw new Error("timeoutMS must be greater than 0"); + } + this.reset(); + } + + /** + * Clears the timeout. + */ + clear() { + if (this.timeoutId) { + clearTimeout(this.timeoutId); + this.timeoutId = undefined; + } + } + + /** + * Runs the callback function. + */ + private async runCallback() { + if (this.callback) { + try { + await this.callback(); + } catch (error: unknown) { + this.onerror?.(error); + } + } + } + + /** + * Resets the timeout. + */ + reset() { + this.clear(); + this.timeoutId = setTimeout(() => { + void this.runCallback().finally(() => { + this.timeoutId = undefined; + }); + }, this.timeoutMS); + } +} diff --git a/src/index.ts b/src/index.ts index 4f81c0cd1..fca9b83fa 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,7 +6,7 @@ import { StdioRunner } from "./transports/stdio.js"; import { StreamableHttpRunner } from "./transports/streamableHttp.js"; async function main() { - const transportRunner = config.transport === "stdio" ? new StdioRunner() : new StreamableHttpRunner(); + const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config); const shutdown = () => { logger.info(LogId.serverCloseRequested, "server", `Server close requested`); diff --git a/src/transports/base.ts b/src/transports/base.ts index 442db18a4..cc58f7502 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -1,4 +1,4 @@ -import { config } from "../common/config.js"; +import { UserConfig } from "../common/config.js"; import { packageInfo } from "../common/packageInfo.js"; import { Server } from "../server.js"; import { Session } from "../common/session.js"; @@ -6,14 +6,14 @@ import { Telemetry } from "../telemetry/telemetry.js"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; export abstract class TransportRunnerBase { - protected setupServer(): Server { + protected setupServer(userConfig: UserConfig): Server { const session = new Session({ - apiBaseUrl: config.apiBaseUrl, - apiClientId: config.apiClientId, - apiClientSecret: config.apiClientSecret, + apiBaseUrl: userConfig.apiBaseUrl, + apiClientId: userConfig.apiClientId, + apiClientSecret: userConfig.apiClientSecret, }); - const telemetry = Telemetry.create(session, config); + const telemetry = Telemetry.create(session, userConfig); const mcpServer = new McpServer({ name: packageInfo.mcpServerName, @@ -24,7 +24,7 @@ export abstract class TransportRunnerBase { mcpServer, session, telemetry, - userConfig: config, + userConfig, }); } diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 9f18627ce..870ec73cc 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -4,6 +4,7 @@ import { TransportRunnerBase } from "./base.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { EJSON } from "bson"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { UserConfig } from "../common/config.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk // but it uses EJSON.parse instead of JSON.parse to handle BSON types @@ -52,9 +53,13 @@ export function createStdioTransport(): StdioServerTransport { export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; + constructor(private userConfig: UserConfig) { + super(); + } + async start() { try { - this.server = this.setupServer(); + this.server = this.setupServer(this.userConfig); const transport = createStdioTransport(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index fbe01a559..282cd7bc2 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -3,7 +3,7 @@ import http from "http"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { TransportRunnerBase } from "./base.js"; -import { config } from "../common/config.js"; +import { UserConfig } from "../common/config.js"; import logger, { LogId } from "../common/logger.js"; import { randomUUID } from "crypto"; import { SessionStore } from "../common/sessionStore.js"; @@ -38,7 +38,12 @@ function promiseHandler( export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; - private sessionStore: SessionStore = new SessionStore(); + private sessionStore: SessionStore; + + constructor(private userConfig: UserConfig) { + super(); + this.sessionStore = new SessionStore(this.userConfig.idleTimeoutMs, this.userConfig.notificationTimeoutMs); + } async start() { const app = express(); @@ -101,11 +106,11 @@ export class StreamableHttpRunner extends TransportRunnerBase { return; } - const server = this.setupServer(); + const server = this.setupServer(this.userConfig); const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID().toString(), onsessioninitialized: (sessionId) => { - this.sessionStore.setSession(sessionId, transport); + this.sessionStore.setSession(sessionId, transport, server.mcpServer); }, onsessionclosed: async (sessionId) => { try { @@ -140,7 +145,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { app.delete("/mcp", promiseHandler(handleRequest)); this.httpServer = await new Promise((resolve, reject) => { - const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { + const result = app.listen(this.userConfig.httpPort, this.userConfig.httpHost, (err?: Error) => { if (err) { reject(err); return; @@ -152,7 +157,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { logger.info( LogId.streamableHttpTransportStarted, "streamableHttpTransport", - `Server started on http://${config.httpHost}:${config.httpPort}` + `Server started on http://${this.userConfig.httpHost}:${this.userConfig.httpPort}` ); } diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index c295705ee..d5b6e0be7 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -14,7 +14,7 @@ describe("StreamableHttpRunner", () => { oldLoggers = config.loggers; config.telemetry = "disabled"; config.loggers = ["stderr"]; - runner = new StreamableHttpRunner(); + runner = new StreamableHttpRunner(config); await runner.start(); }); diff --git a/tests/unit/common/timeoutManager.test.ts b/tests/unit/common/timeoutManager.test.ts new file mode 100644 index 000000000..a0cc5b307 --- /dev/null +++ b/tests/unit/common/timeoutManager.test.ts @@ -0,0 +1,79 @@ +import { afterAll, beforeAll, describe, expect, it, vi } from "vitest"; +import { TimeoutManager } from "../../../src/common/timeoutManager.js"; + +describe("TimeoutManager", () => { + beforeAll(() => { + vi.useFakeTimers(); + }); + + afterAll(() => { + vi.useRealTimers(); + }); + + it("calls the timeout callback", () => { + const callback = vi.fn(); + + new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(1000); + expect(callback).toHaveBeenCalled(); + }); + + it("does not call the timeout callback if the timeout is cleared", () => { + const callback = vi.fn(); + + const timeoutManager = new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(500); + timeoutManager.clear(); + vi.advanceTimersByTime(500); + + expect(callback).not.toHaveBeenCalled(); + }); + + it("does not call the timeout callback if the timeout is reset", () => { + const callback = vi.fn(); + + const timeoutManager = new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(500); + timeoutManager.reset(); + vi.advanceTimersByTime(500); + expect(callback).not.toHaveBeenCalled(); + }); + + it("calls the onerror callback", () => { + const onerrorCallback = vi.fn(); + + const tm = new TimeoutManager(() => { + throw new Error("test"); + }, 1000); + tm.onerror = onerrorCallback; + + vi.advanceTimersByTime(1000); + expect(onerrorCallback).toHaveBeenCalled(); + }); + + describe("if timeout is reset", () => { + it("does not call the timeout callback within the timeout period", () => { + const callback = vi.fn(); + + const timeoutManager = new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(500); + timeoutManager.reset(); + vi.advanceTimersByTime(500); + expect(callback).not.toHaveBeenCalled(); + }); + it("calls the timeout callback after the timeout period", () => { + const callback = vi.fn(); + + const timeoutManager = new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(500); + timeoutManager.reset(); + vi.advanceTimersByTime(1000); + expect(callback).toHaveBeenCalled(); + }); + }); +});