diff --git a/.changeset/typed-state-support.md b/.changeset/typed-state-support.md new file mode 100644 index 0000000..83bfa9c --- /dev/null +++ b/.changeset/typed-state-support.md @@ -0,0 +1,7 @@ +--- +"mcp-lite": major +--- + +Add typed state support to McpServer via generic TConfig parameter, enabling type-safe state access in middleware and handlers. + +**Breaking change:** The `ctx.env` property has been removed and replaced with `ctx.state`. Migrate by renaming all `ctx.env` references to `ctx.state`. diff --git a/packages/core/src/context.ts b/packages/core/src/context.ts index c22d71e..66250fc 100644 --- a/packages/core/src/context.ts +++ b/packages/core/src/context.ts @@ -62,24 +62,25 @@ export function getProgressToken( return undefined; } -export function createContext( +export function createContext< + TConfig extends { State: unknown } = { State: Record }, +>( message: JsonRpcMessage, requestId: JsonRpcId | undefined, options: CreateContextOptions = {}, -): MCPServerContext { +): MCPServerContext { // Prefer explicit option, otherwise derive from the request message const progressToken = options.progressToken !== undefined ? options.progressToken : getProgressToken(message); - const context: MCPServerContext = { + const context: MCPServerContext = { request: message, authInfo: options.authInfo, requestId, response: null, - env: {}, - state: {}, + state: {} as TConfig["State"], progressToken, validate: (validator: unknown, input: unknown): T => createValidationFunction(validator, input), diff --git a/packages/core/src/core.ts b/packages/core/src/core.ts index c0ae738..5698dc1 100644 --- a/packages/core/src/core.ts +++ b/packages/core/src/core.ts @@ -66,9 +66,11 @@ function isSupportedVersion(version: string): version is SupportedVersion { ); } -async function runMiddlewares( - middlewares: Middleware[], - ctx: MCPServerContext, +async function runMiddlewares< + TConfig extends { State: unknown } = { State: Record }, +>( + middlewares: Middleware[], + ctx: MCPServerContext, tail: () => Promise, ): Promise { const dispatch = async (i: number): Promise => { @@ -291,19 +293,21 @@ export interface McpServerOptions { * @see {@link ToolCallResult} For tool return value format * @see {@link MCPServerContext} For request context interface */ -export class McpServer { - private methods: Record = {}; +export class McpServer< + TConfig extends { State: unknown } = { State: Record }, +> { + private methods: Record> = {}; private initialized = false; private serverInfo: { name: string; version: string }; - private middlewares: Middleware[] = []; + private middlewares: Middleware[] = []; private capabilities: InitializeResult["capabilities"] = {}; - private onErrorHandler?: OnError; + private onErrorHandler?: OnError; private schemaAdapter?: SchemaAdapter; private logger: Logger; - private tools = new Map(); - private prompts = new Map(); - private resources = new Map(); + private tools = new Map>(); + private prompts = new Map>(); + private resources = new Map>(); private notificationSender?: ( sessionId: string | undefined, @@ -396,7 +400,7 @@ export class McpServer { * }); * ``` */ - use(middleware: Middleware): this { + use(middleware: Middleware): this { this.middlewares.push(middleware); return this; } @@ -424,7 +428,7 @@ export class McpServer { * }); * ``` */ - onError(handler: OnError): this { + onError(handler: OnError): this { this.onErrorHandler = handler; return this; } @@ -535,7 +539,7 @@ export class McpServer { outputSchema: SOutput; handler: ( args: InferOutput, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => | Promise>> | ToolCallResult>; @@ -553,7 +557,7 @@ export class McpServer { outputSchema?: unknown; handler: ( args: InferOutput, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => Promise | ToolCallResult; }, ): this; @@ -569,7 +573,7 @@ export class McpServer { outputSchema: S; handler: ( args: unknown, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => | Promise>> | ToolCallResult>; @@ -587,7 +591,7 @@ export class McpServer { outputSchema?: unknown; handler: ( args: TArgs, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => Promise> | ToolCallResult; }, ): this; @@ -603,7 +607,7 @@ export class McpServer { outputSchema?: unknown | StandardSchemaV1; handler: ( args: TArgs, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => Promise | ToolCallResult; }, ): this { @@ -638,10 +642,10 @@ export class McpServer { metadata.outputSchema = outputSchemaResolved.resolvedSchema; } - const entry: ToolEntry = { + const entry: ToolEntry = { metadata, // TODO - We could avoid this cast if MethodHandler had a generic type for `params` that defaulted to unknown, but here we could pass TArgs - handler: def.handler as MethodHandler, + handler: def.handler as MethodHandler, validator, outputValidator: outputSchemaResolved.validator, }; @@ -714,7 +718,7 @@ export class McpServer { resource( template: string, meta: ResourceMeta, - handler: ResourceHandler, + handler: ResourceHandler, ): this; /** @@ -742,20 +746,21 @@ export class McpServer { template: string, meta: ResourceMeta, validators: ResourceVarValidators, - handler: ResourceHandler, + handler: ResourceHandler, ): this; resource( template: string, meta: ResourceMeta, - validatorsOrHandler: ResourceVarValidators | ResourceHandler, - handler?: ResourceHandler, + validatorsOrHandler: ResourceVarValidators | ResourceHandler, + handler?: ResourceHandler, ): this { if (!this.capabilities.resources) { this.capabilities.resources = { listChanged: true }; } - const actualHandler = handler || (validatorsOrHandler as ResourceHandler); + const actualHandler = + handler || (validatorsOrHandler as ResourceHandler); const validators = handler ? (validatorsOrHandler as ResourceVarValidators) : undefined; @@ -775,7 +780,7 @@ export class McpServer { ...meta, }; - const entry: ResourceEntry = { + const entry: ResourceEntry = { metadata, handler: actualHandler, validators, @@ -873,7 +878,7 @@ export class McpServer { _meta?: { [key: string]: unknown }; arguments?: unknown | StandardSchemaV1; inputSchema?: unknown | StandardSchemaV1; - handler: PromptHandler; + handler: PromptHandler; }, ): this { if (!this.capabilities.prompts) { @@ -917,9 +922,9 @@ export class McpServer { metadata._meta = def._meta; } - const entry: PromptEntry = { + const entry: PromptEntry = { metadata, - handler: def.handler as PromptHandler, + handler: def.handler as PromptHandler, validator, }; @@ -961,7 +966,7 @@ export class McpServer { * See examples/composing-servers for a full working example with multiple * child servers, middleware composition, and real-world patterns. */ - group(child: McpServer): this; + group(child: McpServer): this; /** * Mount a child server with namespaced tools and prompts. @@ -986,7 +991,7 @@ export class McpServer { * See examples/composing-servers for a full working example with multiple * child servers, middleware composition, and real-world patterns. */ - group(prefix: string, child: McpServer): this; + group(prefix: string, child: McpServer): this; /** * Mount a child server with flexible namespacing options. @@ -1012,23 +1017,26 @@ export class McpServer { * .group({ prefix: 'ai', suffix: 'v2' }, server); // 'ai/generateText_v2' * ``` */ - group(options: { prefix?: string; suffix?: string }, child: McpServer): this; + group( + options: { prefix?: string; suffix?: string }, + child: McpServer, + ): this; group( prefixOrOptionsOrChild: | string | { prefix?: string; suffix?: string } - | McpServer, - child?: McpServer, + | McpServer, + child?: McpServer, ): this { let prefix = ""; let suffix = ""; - let childServer: McpServer; + let childServer: McpServer; if (typeof prefixOrOptionsOrChild === "string") { // .group("prefix", child) prefix = prefixOrOptionsOrChild; - childServer = child as McpServer; + childServer = child as McpServer; } else if (prefixOrOptionsOrChild instanceof McpServer) { // .group(child) childServer = prefixOrOptionsOrChild; @@ -1036,7 +1044,7 @@ export class McpServer { // .group({ prefix?, suffix? }, child) prefix = prefixOrOptionsOrChild.prefix || ""; suffix = prefixOrOptionsOrChild.suffix || ""; - childServer = child as McpServer; + childServer = child as McpServer; } this.mountChild(prefix, suffix, childServer); @@ -1049,9 +1057,9 @@ export class McpServer { * @internal */ protected _exportRegistries(): { - tools: Array<{ name: string; entry: ToolEntry }>; - prompts: Array<{ name: string; entry: PromptEntry }>; - resources: Array<{ template: string; entry: ResourceEntry }>; + tools: Array<{ name: string; entry: ToolEntry }>; + prompts: Array<{ name: string; entry: PromptEntry }>; + resources: Array<{ template: string; entry: ResourceEntry }>; } { return { tools: Array.from(this.tools.entries()).map(([name, entry]) => ({ @@ -1073,7 +1081,7 @@ export class McpServer { * Used internally by .group() to compose middleware chains. * @internal */ - protected _exportMiddlewares(): Middleware[] { + protected _exportMiddlewares(): Middleware[] { return [...this.middlewares]; } @@ -1084,9 +1092,9 @@ export class McpServer { * @internal */ private wrapWithMiddlewares( - mws: Middleware[], - handler: MethodHandler, - ): MethodHandler { + mws: Middleware[], + handler: MethodHandler, + ): MethodHandler { return async (params, ctx) => { let result: unknown; let handlerCalled = false; @@ -1117,9 +1125,9 @@ export class McpServer { * @internal */ private wrapResourceHandler( - mws: Middleware[], - handler: ResourceHandler, - ): ResourceHandler { + mws: Middleware[], + handler: ResourceHandler, + ): ResourceHandler { return async (uri, vars, ctx) => { let result: ResourceReadResult | undefined; let handlerCalled = false; @@ -1159,7 +1167,11 @@ export class McpServer { * duplicates are silently skipped. * @internal */ - private mountChild(prefix: string, suffix: string, child: McpServer): void { + private mountChild( + prefix: string, + suffix: string, + child: McpServer, + ): void { /** * Adds prefix or suffix to a tool name before mounting */ @@ -1183,7 +1195,7 @@ export class McpServer { ? this.wrapWithMiddlewares(childMWs, entry.handler) : entry.handler; - const wrappedEntry: ToolEntry = { + const wrappedEntry: ToolEntry = { metadata: { ...entry.metadata, name: qualifiedName }, handler: wrappedHandler, validator: entry.validator, @@ -1207,11 +1219,11 @@ export class McpServer { childMWs.length > 0 ? (this.wrapWithMiddlewares( childMWs, - entry.handler as MethodHandler, - ) as PromptHandler) + entry.handler as MethodHandler, + ) as PromptHandler) : entry.handler; - const wrappedEntry: PromptEntry = { + const wrappedEntry: PromptEntry = { metadata: { ...entry.metadata, name: qualifiedName }, handler: wrappedHandler, validator: entry.validator, @@ -1234,7 +1246,7 @@ export class McpServer { ? this.wrapResourceHandler(childMWs, entry.handler) : entry.handler; - const wrappedEntry: ResourceEntry = { + const wrappedEntry: ResourceEntry = { ...entry, handler: wrappedHandler, }; @@ -1332,7 +1344,7 @@ export class McpServer { ) : undefined; - const ctx = createContext(message as JsonRpcMessage, requestId, { + const ctx = createContext(message as JsonRpcMessage, requestId, { sessionId, sessionProtocolVersion: contextOptions.sessionProtocolVersion, progressToken, @@ -1408,7 +1420,7 @@ export class McpServer { private async handleToolsList( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise { return { tools: Array.from(this.tools.values()).map((t) => t.metadata), @@ -1417,7 +1429,7 @@ export class McpServer { private async handleToolsCall( params: unknown, - ctx: MCPServerContext, + ctx: MCPServerContext, ): Promise { if (!isObject(params)) { throw new RpcError( @@ -1482,7 +1494,7 @@ export class McpServer { private async handlePromptsList( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise { return { prompts: Array.from(this.prompts.values()).map((p) => p.metadata), @@ -1491,7 +1503,7 @@ export class McpServer { private async handlePromptsGet( params: unknown, - ctx: MCPServerContext, + ctx: MCPServerContext, ): Promise { if (!isObject(params)) { throw new RpcError( @@ -1531,7 +1543,7 @@ export class McpServer { private async handleResourcesList( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise { const resources = Array.from(this.resources.values()) .filter((entry) => entry.type === "resource") @@ -1542,7 +1554,7 @@ export class McpServer { private async handleResourceTemplatesList( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise { const resourceTemplates = Array.from(this.resources.values()) .filter((entry) => entry.type === "resource_template") @@ -1553,7 +1565,7 @@ export class McpServer { private async handleResourcesRead( params: unknown, - ctx: MCPServerContext, + ctx: MCPServerContext, ): Promise { if (typeof params !== "object" || params === null) { throw new RpcError( @@ -1573,7 +1585,7 @@ export class McpServer { const uri = readParams.uri; - let matchedEntry: ResourceEntry | null = null; + let matchedEntry: ResourceEntry | null = null; let vars: Record = {}; const directEntry = this.resources.get(uri); @@ -1642,7 +1654,7 @@ export class McpServer { private async handleInitialize( params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise { if (!isInitializeParams(params)) { throw new RpcError( @@ -1683,42 +1695,42 @@ export class McpServer { private async handleNotificationCancelled( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise> { return {}; } private async handleNotificationInitialized( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise> { return {}; } private async handleNotificationProgress( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise> { return {}; } private async handleNotificationRootsListChanged( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise> { return {}; } private async handleLoggingSetLevel( _params: unknown, - _ctx: MCPServerContext, + _ctx: MCPServerContext, ): Promise> { return {}; } private async handleNotImplemented( _params: unknown, - ctx: MCPServerContext, + ctx: MCPServerContext, ): Promise { throw new RpcError(JSON_RPC_ERROR_CODES.INTERNAL_ERROR, "Not implemented", { method: ctx.request.method, diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 87aaa06..557dbea 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -52,9 +52,11 @@ export interface JsonRpcError { data?: unknown; } -export type OnError = ( +export type OnError< + TConfig extends { State: unknown } = { State: Record }, +> = ( err: unknown, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => JsonRpcError | undefined | Promise; export interface InitializeParams { @@ -90,12 +92,13 @@ export interface ProgressUpdate { message?: string; } -export interface MCPServerContext { +export interface MCPServerContext< + TConfig extends { State: unknown } = { State: Record }, +> { request: JsonRpcMessage; requestId: JsonRpcId | undefined; response: JsonRpcRes | null; - env: Record; - state: Record; + state: TConfig["State"]; /** * Info on the authenticated user, if any */ @@ -123,14 +126,18 @@ export interface MCPClientFeatures { supports(feature: ClientCapabilities | string): boolean; } -export type Middleware = ( - ctx: MCPServerContext, +export type Middleware< + TConfig extends { State: unknown } = { State: Record }, +> = ( + ctx: MCPServerContext, next: () => Promise, ) => Promise | void; -export type MethodHandler = ( +export type MethodHandler< + TConfig extends { State: unknown } = { State: Record }, +> = ( params: unknown, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => Promise | unknown; export function isJsonRpcNotification( @@ -308,14 +315,19 @@ export interface PromptMetadata { _meta?: { [key: string]: unknown }; } -export type PromptHandler = ( +export type PromptHandler< + TArgs = unknown, + TConfig extends { State: unknown } = { State: Record }, +> = ( args: TArgs, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => Promise | PromptGetResult; -export interface PromptEntry { +export interface PromptEntry< + TConfig extends { State: unknown } = { State: Record }, +> { metadata: PromptMetadata; - handler: PromptHandler; + handler: PromptHandler; validator?: unknown; } @@ -338,16 +350,20 @@ export interface ResourceProvider { ) => unknown; } -export interface ToolEntry { +export interface ToolEntry< + TConfig extends { State: unknown } = { State: Record }, +> { metadata: Tool; - handler: MethodHandler; + handler: MethodHandler; validator?: unknown; outputValidator?: unknown; } -export interface ResourceEntry { +export interface ResourceEntry< + TConfig extends { State: unknown } = { State: Record }, +> { metadata: Resource | ResourceTemplate; - handler: ResourceHandler; + handler: ResourceHandler; validators?: ResourceVarValidators; matcher?: UriMatcher; type: "resource" | "resource_template"; @@ -513,10 +529,12 @@ export interface ResourceMeta { export type ResourceVarValidators = Record; -export type ResourceHandler = ( +export type ResourceHandler< + TConfig extends { State: unknown } = { State: Record }, +> = ( uri: URL, vars: ResourceVars, - ctx: MCPServerContext, + ctx: MCPServerContext, ) => Promise; export interface NotificationSenderOptions {