From 7163ec636a7290d1fa3544ead20c645bc8e80f6c Mon Sep 17 00:00:00 2001 From: He-Pin Date: Mon, 8 Sep 2025 21:47:05 +0800 Subject: [PATCH 1/2] chore: Extract DefaultMcpServerSessionFactory from McpAsyncServer. Signed-off-by: He-Pin --- .../server/McpAsyncServer.java | 47 +++++------------ .../spec/DefaultMcpServerSessionFactory.java | 50 +++++++++++++++++++ 2 files changed, 62 insertions(+), 35 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpServerSessionFactory.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 3c8057a72..f671b8527 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,47 +4,25 @@ package io.modelcontextprotocol.server; -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.function.BiFunction; - -import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; -import io.modelcontextprotocol.spec.McpServerTransportProviderBase; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.spec.JsonSchemaValidator; -import io.modelcontextprotocol.spec.McpClientSession; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse; -import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; -import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; -import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; -import io.modelcontextprotocol.spec.McpSchema.TextContent; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.*; +import io.modelcontextprotocol.spec.McpSchema.*; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.time.Duration; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + /** * The Model Context Protocol (MCP) server implementation that provides asynchronous * communication using Project Reactor's Mono and Flux types. @@ -148,9 +126,8 @@ public class McpAsyncServer { Map notificationHandlers = prepareNotificationHandlers(features); this.protocolVersions = mcpTransportProvider.protocolVersions(); - - mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), - requestTimeout, transport, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory(new DefaultMcpServerSessionFactory(requestTimeout, + this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); } McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpServerSessionFactory.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpServerSessionFactory.java new file mode 100644 index 000000000..2aae285b8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpServerSessionFactory.java @@ -0,0 +1,50 @@ +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.server.McpInitRequestHandler; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; + +import java.time.Duration; +import java.util.Map; + +/** + * The default implementation of {@link McpServerSession.Factory}. + * + * @author He-Pin + */ +public class DefaultMcpServerSessionFactory implements McpServerSession.Factory { + + Duration requestTimeout; + + McpInitRequestHandler initHandler; + + Map> requestHandlers; + + Map notificationHandlers; + + public DefaultMcpServerSessionFactory(final Duration requestTimeout, final McpInitRequestHandler initHandler, + final Map> requestHandlers, + final Map notificationHandlers) { + this.requestTimeout = requestTimeout; + this.initHandler = initHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public McpServerSession create(final McpServerTransport sessionTransport) { + final String sessionId = generateSessionId(sessionTransport); + return new McpServerSession(sessionId, requestTimeout, sessionTransport, initHandler, requestHandlers, + notificationHandlers); + } + + /** + * Generate a unique session ID for the given transport. + * @param sessionTransport the transport + * @return unique session ID + */ + protected String generateSessionId(final McpServerTransport sessionTransport) { + return java.util.UUID.randomUUID().toString(); + } + +} From 453820138f6ad881497ebbffdeca3879b34e337d Mon Sep 17 00:00:00 2001 From: He-Pin Date: Mon, 8 Sep 2025 22:25:53 +0800 Subject: [PATCH 2/2] chore: Rewirte lambda in McpAsyncServer with local methods. Signed-off-by: He-Pin --- .../server/McpAsyncServer.java | 276 +++++++++--------- 1 file changed, 137 insertions(+), 139 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index f671b8527..9c9c727df 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -158,7 +158,7 @@ public class McpAsyncServer { private Map prepareNotificationHandlers(McpServerFeatures.Async features) { Map notificationHandlers = new HashMap<>(); - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, this::handleNotificationInitialized); List, Mono>> rootsChangeConsumers = features .rootsChangeConsumers(); @@ -173,45 +173,58 @@ private Map prepareNotificationHandlers(McpServe return notificationHandlers; } - private Map> prepareRequestHandlers() { - Map> requestHandlers = new HashMap<>(); + protected Mono handleNotificationInitialized(final McpAsyncServerExchange exchange, final Object params) { + return Mono.just(Map.of()).then(); + } + @SuppressWarnings("rawtypes") + private Map> prepareRequestHandlers() { + final Map> requestHandlers = new HashMap<>(); // Initialize request handlers for standard MCP methods // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + requestHandlers.put(McpSchema.METHOD_PING, (McpRequestHandler) this::handlePing); // Add tools API handlers if the tool capability is enabled if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, (McpRequestHandler) this::handleListTools); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, (McpRequestHandler) this::handleToolCall); } // Add resources API handlers if provided if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, (McpRequestHandler) this::handleListResources); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, (McpRequestHandler) this::handleReadResources); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, + (McpRequestHandler) this::handleListResourceTemplates); } // Add prompts API handlers if provider exists if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, (McpRequestHandler) this::handleListPrompts); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, (McpRequestHandler) this::handleGetPrompt); } // Add logging API handlers if the logging capability is enabled if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, (McpRequestHandler) this::handleSetLogger); } // Add completion API handlers if the completion capability is enabled if (this.serverCapabilities.completions() != null) { - requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, + (McpRequestHandler) this::handleCompletionComplete); } return requestHandlers; } + /** + * Handle the ping request. + */ + protected Mono> handlePing(McpAsyncServerExchange exchange, Object params) { + return Mono.just(Map.of()); + } + // --------------------------------------- // Lifecycle Management // --------------------------------------- @@ -472,31 +485,27 @@ public Mono notifyToolsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); } - private McpRequestHandler toolsListRequestHandler() { - return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + protected Mono handleListTools(final McpAsyncServerExchange exchange, final Object params) { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; + return Mono.just(new McpSchema.ListToolsResult(tools, null)); } - private McpRequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { - }); + protected Mono handleToolCall(final McpAsyncServerExchange exchange, final Object params) { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); - Optional toolSpecification = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); - if (toolSpecification.isEmpty()) { - return Mono.error(new McpError(new JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INVALID_PARAMS, - "Unknown tool: invalid_tool_name", "Tool not found: " + callToolRequest.name()))); - } + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError(new JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INVALID_PARAMS, + "Unknown tool: invalid_tool_name", "Tool not found: " + callToolRequest.name()))); + } - return toolSpecification.get().callHandler().apply(exchange, callToolRequest); - }; + return toolSpecification.get().callHandler().apply(exchange, callToolRequest); } // --------------------------------------- @@ -573,21 +582,19 @@ public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification); } - private McpRequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceSpecification::resource) - .filter(resource -> !resource.uri().contains("{")) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; + protected Mono handleListResources(final McpAsyncServerExchange exchange, + final Object params) { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .filter(resource -> !resource.uri().contains("{")) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); } - private McpRequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - + protected Mono handleListResourceTemplates(final McpAsyncServerExchange exchange, + final Object params) { + return Mono.just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); } private List getResourceTemplates() { @@ -608,23 +615,21 @@ private List getResourceTemplates() { return list; } - private McpRequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); + protected Mono handleReadResources(final McpAsyncServerExchange exchange, final Object params) { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() - .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) - .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - return Mono.defer(() -> specification.readHandler().apply(exchange, resourceRequest)); - }; + return Mono.defer(() -> specification.readHandler().apply(exchange, resourceRequest)); } // --------------------------------------- @@ -701,36 +706,32 @@ public Mono notifyPromptsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); } - private McpRequestHandler promptsListRequestHandler() { - return (exchange, params) -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, - // new TypeReference() { - // }); + protected Mono handleListPrompts(final McpAsyncServerExchange exchange, final Object params) { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptSpecification::prompt) - .toList(); + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); } - private McpRequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { - }); + protected Mono handleGetPrompt(final McpAsyncServerExchange exchange, final Object params) { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); - if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); - } + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } - return Mono.defer(() -> specification.promptHandler().apply(exchange, promptRequest)); - }; + return Mono.defer(() -> specification.promptHandler().apply(exchange, promptRequest)); } // --------------------------------------- @@ -763,79 +764,76 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN loggingMessageNotification); } - private McpRequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { - return Mono.defer(() -> { + protected Mono> handleSetLogger(final McpAsyncServerExchange exchange, final Object params) { + return Mono.defer(() -> { - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, - new TypeReference() { - }); + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + new TypeReference() { + }); - exchange.setMinLoggingLevel(newMinLoggingLevel.level()); + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - // FIXME: this field is deprecated and should be removed together - // with the broadcasting loggingNotification. - this.minLoggingLevel = newMinLoggingLevel.level(); + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); - return Mono.just(Map.of()); - }); - }; + return Mono.just(Map.of()); + }); } - private McpRequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { - McpSchema.CompleteRequest request = parseCompletionParams(params); + protected Mono handleCompletionComplete(final McpAsyncServerExchange exchange, + final Object params) { + McpSchema.CompleteRequest request = parseCompletionParams(params); - if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); - } + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); + } - if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); - } + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); + } - String type = request.ref().type(); + String type = request.ref().type(); - String argumentName = request.argument().name(); + String argumentName = request.argument().name(); - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); - if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); - } - if (!promptSpec.prompt() - .arguments() - .stream() - .filter(arg -> arg.name().equals(argumentName)) - .findFirst() - .isPresent()) { - - return Mono.error(new McpError("Argument not found: " + argumentName)); - } + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { - if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); - } + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + } + if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); + if (resourceSpec == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); + } + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); } - McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + } - if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); - } + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); - return Mono.defer(() -> specification.completionHandler().apply(exchange, request)); - }; + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } + + return Mono.defer(() -> specification.completionHandler().apply(exchange, request)); } /**