diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java index 893c1910d70..7d2f33fbf1c 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java @@ -23,23 +23,9 @@ import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema; -import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.SyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.AsyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.SyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; -import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; -import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; -import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; -import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; - -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpAsyncAnnotationCustomizer; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpSyncAnnotationCustomizer; + +import org.springframework.ai.mcp.annotation.spring.ClientMcpAsyncHandlersRegistry; +import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpAsyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpSyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; @@ -161,7 +147,8 @@ private String connectedClientName(String clientName, String serverConnectionNam matchIfMissing = true) public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientConfigurer, McpClientCommonProperties commonProperties, - ObjectProvider> transportsProvider) { + ObjectProvider> transportsProvider, + ClientMcpSyncHandlersRegistry clientMcpSyncHandlersRegistry) { List mcpSyncClients = new ArrayList<>(); @@ -176,7 +163,22 @@ public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC McpClient.SyncSpec spec = McpClient.sync(namedTransport.transport()) .clientInfo(clientInfo) - .requestTimeout(commonProperties.getRequestTimeout()); + .requestTimeout(commonProperties.getRequestTimeout()) + .sampling(samplingRequest -> clientMcpSyncHandlersRegistry.handleSampling(namedTransport.name(), + samplingRequest)) + .elicitation(elicitationRequest -> clientMcpSyncHandlersRegistry + .handleElicitation(namedTransport.name(), elicitationRequest)) + .loggingConsumer(loggingMessageNotification -> clientMcpSyncHandlersRegistry + .handleLogging(namedTransport.name(), loggingMessageNotification)) + .progressConsumer(progressNotification -> clientMcpSyncHandlersRegistry + .handleProgress(namedTransport.name(), progressNotification)) + .toolsChangeConsumer(newTools -> clientMcpSyncHandlersRegistry + .handleToolListChanged(namedTransport.name(), newTools)) + .promptsChangeConsumer(newPrompts -> clientMcpSyncHandlersRegistry + .handlePromptListChanged(namedTransport.name(), newPrompts)) + .resourcesChangeConsumer(newResources -> clientMcpSyncHandlersRegistry + .handleResourceListChanged(namedTransport.name(), newResources)) + .capabilities(clientMcpSyncHandlersRegistry.getCapabilities(namedTransport.name())); spec = mcpSyncClientConfigurer.configure(namedTransport.name(), spec); @@ -222,27 +224,14 @@ McpSyncClientConfigurer mcpSyncClientConfigurer(ObjectProvider loggingSpecs, - List samplingSpecs, List elicitationSpecs, - List progressSpecs, - List syncToolListChangedSpecifications, - List syncResourceListChangedSpecifications, - List syncPromptListChangedSpecifications) { - return new McpSyncAnnotationCustomizer(samplingSpecs, loggingSpecs, elicitationSpecs, progressSpecs, - syncToolListChangedSpecifications, syncResourceListChangedSpecifications, - syncPromptListChangedSpecifications); - } - // Async client configuration @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") public List mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncClientConfigurer, McpClientCommonProperties commonProperties, - ObjectProvider> transportsProvider) { + ObjectProvider> transportsProvider, + ClientMcpAsyncHandlersRegistry clientMcpAsyncHandlersRegistry) { List mcpAsyncClients = new ArrayList<>(); @@ -257,7 +246,22 @@ public List mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncCli McpClient.AsyncSpec spec = McpClient.async(namedTransport.transport()) .clientInfo(clientInfo) - .requestTimeout(commonProperties.getRequestTimeout()); + .requestTimeout(commonProperties.getRequestTimeout()) + .sampling(samplingRequest -> clientMcpAsyncHandlersRegistry.handleSampling(namedTransport.name(), + samplingRequest)) + .elicitation(elicitationRequest -> clientMcpAsyncHandlersRegistry + .handleElicitation(namedTransport.name(), elicitationRequest)) + .loggingConsumer(loggingMessageNotification -> clientMcpAsyncHandlersRegistry + .handleLogging(namedTransport.name(), loggingMessageNotification)) + .progressConsumer(progressNotification -> clientMcpAsyncHandlersRegistry + .handleProgress(namedTransport.name(), progressNotification)) + .toolsChangeConsumer(newTools -> clientMcpAsyncHandlersRegistry + .handleToolListChanged(namedTransport.name(), newTools)) + .promptsChangeConsumer(newPrompts -> clientMcpAsyncHandlersRegistry + .handlePromptListChanged(namedTransport.name(), newPrompts)) + .resourcesChangeConsumer(newResources -> clientMcpAsyncHandlersRegistry + .handleResourceListChanged(namedTransport.name(), newResources)) + .capabilities(clientMcpAsyncHandlersRegistry.getCapabilities(namedTransport.name())); spec = mcpAsyncClientConfigurer.configure(namedTransport.name(), spec); @@ -287,18 +291,6 @@ McpAsyncClientConfigurer mcpAsyncClientConfigurer(ObjectProvider loggingSpecs, - List samplingSpecs, List elicitationSpecs, - List progressSpecs, - List toolListChangedSpecs, - List resourceListChangedSpecs, - List promptListChangedSpecs) { - return new McpAsyncAnnotationCustomizer(samplingSpecs, loggingSpecs, elicitationSpecs, progressSpecs, - toolListChangedSpecs, resourceListChangedSpecs, promptListChangedSpecs); - } - /** * Record class that implements {@link AutoCloseable} to ensure proper cleanup of MCP * clients. diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpAsyncAnnotationCustomizer.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpAsyncAnnotationCustomizer.java deleted file mode 100644 index 292942a2d63..00000000000 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpAsyncAnnotationCustomizer.java +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client.common.autoconfigure.annotations; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Stream; - -import io.modelcontextprotocol.client.McpClient.AsyncSpec; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.AsyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; - -import org.springframework.ai.mcp.customizer.McpAsyncClientCustomizer; -import org.springframework.util.CollectionUtils; - -/** - * @author Christian Tzolov - */ -public class McpAsyncAnnotationCustomizer implements McpAsyncClientCustomizer { - - private static final Logger logger = LoggerFactory.getLogger(McpAsyncAnnotationCustomizer.class); - - private final List asyncSamplingSpecifications; - - private final List asyncLoggingSpecifications; - - private final List asyncElicitationSpecifications; - - private final List asyncProgressSpecifications; - - private final List asyncToolListChangedSpecifications; - - private final List asyncResourceListChangedSpecifications; - - private final List asyncPromptListChangedSpecifications; - - // Tracking registered specifications per client - private final Map clientElicitationSpecs = new ConcurrentHashMap<>(); - - private final Map clientSamplingSpecs = new ConcurrentHashMap<>(); - - public McpAsyncAnnotationCustomizer(List asyncSamplingSpecifications, - List asyncLoggingSpecifications, - List asyncElicitationSpecifications, - List asyncProgressSpecifications, - List asyncToolListChangedSpecifications, - List asyncResourceListChangedSpecifications, - List asyncPromptListChangedSpecifications) { - - this.asyncSamplingSpecifications = asyncSamplingSpecifications; - this.asyncLoggingSpecifications = asyncLoggingSpecifications; - this.asyncElicitationSpecifications = asyncElicitationSpecifications; - this.asyncProgressSpecifications = asyncProgressSpecifications; - this.asyncToolListChangedSpecifications = asyncToolListChangedSpecifications; - this.asyncResourceListChangedSpecifications = asyncResourceListChangedSpecifications; - this.asyncPromptListChangedSpecifications = asyncPromptListChangedSpecifications; - } - - @Override - public void customize(String name, AsyncSpec clientSpec) { - - if (!CollectionUtils.isEmpty(this.asyncElicitationSpecifications)) { - this.asyncElicitationSpecifications.forEach(elicitationSpec -> { - Stream.of(elicitationSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - - // Check if client already has an elicitation spec - if (this.clientElicitationSpecs.containsKey(name)) { - throw new IllegalArgumentException("Client '" + name - + "' already has an elicitationSpec registered. Only one elicitationSpec is allowed per client."); - } - - this.clientElicitationSpecs.put(name, Boolean.TRUE); - clientSpec.elicitation(elicitationSpec.elicitationHandler()); - - logger.info("Registered elicitationSpec for client '{}'.", name); - - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncSamplingSpecifications)) { - this.asyncSamplingSpecifications.forEach(samplingSpec -> { - Stream.of(samplingSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - - // Check if client already has a sampling spec - if (this.clientSamplingSpecs.containsKey(name)) { - throw new IllegalArgumentException("Client '" + name - + "' already has a samplingSpec registered. Only one samplingSpec is allowed per client."); - } - this.clientSamplingSpecs.put(name, Boolean.TRUE); - - clientSpec.sampling(samplingSpec.samplingHandler()); - - logger.info("Registered samplingSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncLoggingSpecifications)) { - this.asyncLoggingSpecifications.forEach(loggingSpec -> { - Stream.of(loggingSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.loggingConsumer(loggingSpec.loggingHandler()); - logger.info("Registered loggingSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncProgressSpecifications)) { - this.asyncProgressSpecifications.forEach(progressSpec -> { - Stream.of(progressSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.progressConsumer(progressSpec.progressHandler()); - logger.info("Registered progressSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncToolListChangedSpecifications)) { - this.asyncToolListChangedSpecifications.forEach(toolListChangedSpec -> { - Stream.of(toolListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.toolsChangeConsumer(toolListChangedSpec.toolListChangeHandler()); - logger.info("Registered toolListChangedSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncResourceListChangedSpecifications)) { - this.asyncResourceListChangedSpecifications.forEach(resourceListChangedSpec -> { - Stream.of(resourceListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.resourcesChangeConsumer(resourceListChangedSpec.resourceListChangeHandler()); - logger.info("Registered resourceListChangedSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncPromptListChangedSpecifications)) { - this.asyncPromptListChangedSpecifications.forEach(promptListChangedSpec -> { - Stream.of(promptListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.promptsChangeConsumer(promptListChangedSpec.promptListChangeHandler()); - logger.info("Registered promptListChangedSpec for client '{}'.", name); - } - }); - }); - } - } - -} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java index 8ce05bcbe07..449c2a9da4b 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java @@ -27,9 +27,11 @@ import org.springaicommunity.mcp.annotation.McpSampling; import org.springaicommunity.mcp.annotation.McpToolListChanged; +import org.springframework.ai.mcp.annotation.spring.ClientMcpAsyncHandlersRegistry; +import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor; -import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanPostProcessor; import org.springframework.ai.mcp.annotation.spring.scan.AbstractMcpAnnotatedBeans; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; @@ -60,15 +62,17 @@ public class McpClientAnnotationScannerAutoConfiguration { @Bean @ConditionalOnMissingBean - public ClientMcpAnnotatedBeans clientAnnotatedBeans() { - return new ClientMcpAnnotatedBeans(); + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", + matchIfMissing = true) + public ClientMcpSyncHandlersRegistry clientMcpSyncHandlersRegistry() { + return new ClientMcpSyncHandlersRegistry(); } @Bean @ConditionalOnMissingBean - public static ClientAnnotatedMethodBeanPostProcessor clientAnnotatedMethodBeanPostProcessor( - ClientMcpAnnotatedBeans clientMcpAnnotatedBeans, McpClientAnnotationScannerProperties properties) { - return new ClientAnnotatedMethodBeanPostProcessor(clientMcpAnnotatedBeans, CLIENT_MCP_ANNOTATIONS); + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") + public ClientMcpAsyncHandlersRegistry clientMcpAsyncHandlersRegistry() { + return new ClientMcpAsyncHandlersRegistry(); } @Bean @@ -90,15 +94,6 @@ public ClientAnnotatedBeanFactoryInitializationAotProcessor( } - public static class ClientAnnotatedMethodBeanPostProcessor extends AbstractAnnotatedMethodBeanPostProcessor { - - public ClientAnnotatedMethodBeanPostProcessor(ClientMcpAnnotatedBeans clientMcpAnnotatedBeans, - Set> targetAnnotations) { - super(clientMcpAnnotatedBeans, targetAnnotations); - } - - } - static class AnnotationHints implements RuntimeHintsRegistrar { @Override diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java deleted file mode 100644 index 620028f0e63..00000000000 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client.common.autoconfigure.annotations; - -import java.util.List; - -import org.springaicommunity.mcp.annotation.McpElicitation; -import org.springaicommunity.mcp.annotation.McpLogging; -import org.springaicommunity.mcp.annotation.McpProgress; -import org.springaicommunity.mcp.annotation.McpPromptListChanged; -import org.springaicommunity.mcp.annotation.McpResourceListChanged; -import org.springaicommunity.mcp.annotation.McpSampling; -import org.springaicommunity.mcp.annotation.McpToolListChanged; -import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.SyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.AsyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.SyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; -import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; -import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; -import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; -import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; - -import org.springframework.ai.mcp.annotation.spring.AsyncMcpAnnotationProviders; -import org.springframework.ai.mcp.annotation.spring.SyncMcpAnnotationProviders; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans; -import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; -import org.springframework.boot.autoconfigure.AutoConfiguration; -import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -/** - * @author Christian Tzolov - * @author Fu Jian - */ -@AutoConfiguration(after = McpClientAnnotationScannerAutoConfiguration.class) -@ConditionalOnClass(McpLogging.class) -@ConditionalOnProperty(prefix = McpClientAnnotationScannerProperties.CONFIG_PREFIX, name = "enabled", - havingValue = "true", matchIfMissing = true) -public class McpClientSpecificationFactoryAutoConfiguration { - - @Configuration(proxyBeanMethods = false) - @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", - matchIfMissing = true) - static class SyncClientSpecificationConfiguration { - - @Bean - List loggingSpecs(ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders - .loggingSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpLogging.class)); - } - - @Bean - List samplingSpecs(ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders - .samplingSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpSampling.class)); - } - - @Bean - List elicitationSpecs(ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders - .elicitationSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpElicitation.class)); - } - - @Bean - List progressSpecs(ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders - .progressSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpProgress.class)); - } - - @Bean - List syncToolListChangedSpecs( - ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders.toolListChangedSpecifications( - beansWithMcpMethodAnnotations.getBeansByAnnotation(McpToolListChanged.class)); - } - - @Bean - List syncResourceListChangedSpecs( - ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders.resourceListChangedSpecifications( - beansWithMcpMethodAnnotations.getBeansByAnnotation(McpResourceListChanged.class)); - } - - @Bean - List syncPromptListChangedSpecs( - ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders.promptListChangedSpecifications( - beansWithMcpMethodAnnotations.getBeansByAnnotation(McpPromptListChanged.class)); - } - - } - - @Configuration(proxyBeanMethods = false) - @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - static class AsyncClientSpecificationConfiguration { - - @Bean - List loggingSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.loggingSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List samplingSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.samplingSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List elicitationSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.elicitationSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List progressSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.progressSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List asyncToolListChangedSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.toolListChangedSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List asyncResourceListChangedSpecs( - ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.resourceListChangedSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List asyncPromptListChangedSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.promptListChangedSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - } - -} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizer.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizer.java deleted file mode 100644 index 69d19bfe1c0..00000000000 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizer.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client.common.autoconfigure.annotations; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Stream; - -import io.modelcontextprotocol.client.McpClient.SyncSpec; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.SyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.SyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; - -import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; -import org.springframework.util.CollectionUtils; - -/** - * @author Christian Tzolov - */ -public class McpSyncAnnotationCustomizer implements McpSyncClientCustomizer { - - private static final Logger logger = LoggerFactory.getLogger(McpSyncAnnotationCustomizer.class); - - private final List syncSamplingSpecifications; - - private final List syncLoggingSpecifications; - - private final List syncElicitationSpecifications; - - private final List syncProgressSpecifications; - - private final List syncToolListChangedSpecifications; - - private final List syncResourceListChangedSpecifications; - - private final List syncPromptListChangedSpecifications; - - // Tracking registered specifications per client - private final Map clientElicitationSpecs = new ConcurrentHashMap<>(); - - private final Map clientSamplingSpecs = new ConcurrentHashMap<>(); - - public McpSyncAnnotationCustomizer(List syncSamplingSpecifications, - List syncLoggingSpecifications, - List syncElicitationSpecifications, - List syncProgressSpecifications, - List syncToolListChangedSpecifications, - List syncResourceListChangedSpecifications, - List syncPromptListChangedSpecifications) { - - this.syncSamplingSpecifications = syncSamplingSpecifications; - this.syncLoggingSpecifications = syncLoggingSpecifications; - this.syncElicitationSpecifications = syncElicitationSpecifications; - this.syncProgressSpecifications = syncProgressSpecifications; - this.syncToolListChangedSpecifications = syncToolListChangedSpecifications; - this.syncResourceListChangedSpecifications = syncResourceListChangedSpecifications; - this.syncPromptListChangedSpecifications = syncPromptListChangedSpecifications; - } - - @Override - public void customize(String name, SyncSpec clientSpec) { - - if (!CollectionUtils.isEmpty(this.syncElicitationSpecifications)) { - this.syncElicitationSpecifications.forEach(elicitationSpec -> { - Stream.of(elicitationSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - // Check if client already has an elicitation spec - if (this.clientElicitationSpecs.containsKey(name)) { - throw new IllegalArgumentException("Client '" + name - + "' already has an elicitationSpec registered. Only one elicitationSpec is allowed per client."); - } - - this.clientElicitationSpecs.put(name, Boolean.TRUE); - clientSpec.elicitation(elicitationSpec.elicitationHandler()); - - logger.info("Registered elicitationSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncSamplingSpecifications)) { - this.syncSamplingSpecifications.forEach(samplingSpec -> { - Stream.of(samplingSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - - // Check if client already has a sampling spec - if (this.clientSamplingSpecs.containsKey(name)) { - throw new IllegalArgumentException("Client '" + name - + "' already has a samplingSpec registered. Only one samplingSpec is allowed per client."); - } - this.clientSamplingSpecs.put(name, Boolean.TRUE); - - clientSpec.sampling(samplingSpec.samplingHandler()); - - logger.info("Registered samplingSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncLoggingSpecifications)) { - this.syncLoggingSpecifications.forEach(loggingSpec -> { - Stream.of(loggingSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.loggingConsumer(loggingSpec.loggingHandler()); - logger.info("Registered loggingSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncProgressSpecifications)) { - this.syncProgressSpecifications.forEach(progressSpec -> { - Stream.of(progressSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.progressConsumer(progressSpec.progressHandler()); - logger.info("Registered progressSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncToolListChangedSpecifications)) { - this.syncToolListChangedSpecifications.forEach(toolListChangedSpec -> { - Stream.of(toolListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.toolsChangeConsumer(toolListChangedSpec.toolListChangeHandler()); - logger.info("Registered toolListChangedSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncResourceListChangedSpecifications)) { - this.syncResourceListChangedSpecifications.forEach(resourceListChangedSpec -> { - Stream.of(resourceListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.resourcesChangeConsumer(resourceListChangedSpec.resourceListChangeHandler()); - logger.info("Registered resourceListChangedSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncPromptListChangedSpecifications)) { - this.syncPromptListChangedSpecifications.forEach(promptListChangedSpec -> { - Stream.of(promptListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.promptsChangeConsumer(promptListChangedSpec.promptListChangeHandler()); - logger.info("Registered promptListChangedSpec for client '{}'.", name); - } - }); - }); - } - } - -} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 120dd1beab9..38cd4021d5c 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -17,5 +17,4 @@ org.springframework.ai.mcp.client.common.autoconfigure.StdioTransportAutoConfiguration org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration -org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java index 1d1fbb92ae4..b4a72354db7 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java @@ -29,6 +29,7 @@ import org.mockito.Mockito; import reactor.core.publisher.Mono; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpSyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; @@ -84,8 +85,9 @@ */ public class McpClientAutoConfigurationIT { - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withConfiguration( - AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class)); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, + McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class)); /** * Tests the default MCP client auto-configuration. diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java index d00e3cc6b35..d406da99dd3 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java @@ -16,15 +16,21 @@ package org.springframework.ai.mcp.client.common.autoconfigure.annotations; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import io.modelcontextprotocol.spec.McpSchema; +import org.junit.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.springaicommunity.mcp.annotation.McpPromptListChanged; import org.springaicommunity.mcp.annotation.McpResourceListChanged; import org.springaicommunity.mcp.annotation.McpToolListChanged; +import reactor.core.publisher.Mono; +import org.springframework.ai.mcp.annotation.spring.ClientMcpAsyncHandlersRegistry; +import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -45,28 +51,67 @@ public class McpClientListChangedAnnotationsScanningIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(McpClientAnnotationScannerAutoConfiguration.class)); - @ParameterizedTest - @ValueSource(strings = { "SYNC", "ASYNC" }) - void shouldScanAllThreeListChangedAnnotations(String clientType) { - String prefix = clientType.toLowerCase(); + @Test + public void shouldScanAllThreeListChangedAnnotationsSync() { + this.contextRunner.withUserConfiguration(AllListChangedConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.type=SYNC") + .run(context -> { + // Verify all three annotations were scanned + var registry = context.getBean(ClientMcpSyncHandlersRegistry.class); + var handlers = context.getBean(TestListChangedHandlers.class); + assertThat(registry).isNotNull(); + + List updatedTools = List.of(McpSchema.Tool.builder().name("tool-1").build(), + McpSchema.Tool.builder().name("tool-2").build()); + List updatedPrompts = List.of( + new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), + new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); + List updatedResources = List.of( + McpSchema.Resource.builder().name("resource-1").uri("file:///resource/1").build(), + McpSchema.Resource.builder().name("resource-2").uri("file:///resource/2").build()); + + registry.handleToolListChanged("test-client", updatedTools); + registry.handleResourceListChanged("test-client", updatedResources); + registry.handlePromptListChanged("test-client", updatedPrompts); + + assertThat(handlers.getCalls()).hasSize(3) + .containsExactlyInAnyOrder( + new TestListChangedHandlers.Call("resource-list-changed", updatedResources), + new TestListChangedHandlers.Call("prompt-list-changed", updatedPrompts), + new TestListChangedHandlers.Call("tool-list-changed", updatedTools)); + }); + } + @Test + public void shouldScanAllThreeListChangedAnnotationsAsync() { this.contextRunner.withUserConfiguration(AllListChangedConfiguration.class) - .withPropertyValues("spring.ai.mcp.client.type=" + clientType) + .withPropertyValues("spring.ai.mcp.client.type=ASYNC") .run(context -> { // Verify all three annotations were scanned - McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans annotatedBeans = context - .getBean(McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans.class); - assertThat(annotatedBeans.getBeansByAnnotation(McpToolListChanged.class)).hasSize(1); - assertThat(annotatedBeans.getBeansByAnnotation(McpResourceListChanged.class)).hasSize(1); - assertThat(annotatedBeans.getBeansByAnnotation(McpPromptListChanged.class)).hasSize(1); - - // Verify all three specification beans were created - assertThat(context).hasBean(prefix + "ToolListChangedSpecs"); - assertThat(context).hasBean(prefix + "ResourceListChangedSpecs"); - assertThat(context).hasBean(prefix + "PromptListChangedSpecs"); + var registry = context.getBean(ClientMcpAsyncHandlersRegistry.class); + var handlers = context.getBean(TestListChangedHandlers.class); + assertThat(registry).isNotNull(); + + List updatedTools = List.of(McpSchema.Tool.builder().name("tool-1").build(), + McpSchema.Tool.builder().name("tool-2").build()); + List updatedPrompts = List.of( + new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), + new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); + List updatedResources = List.of( + McpSchema.Resource.builder().name("resource-1").uri("file:///resource/1").build(), + McpSchema.Resource.builder().name("resource-2").uri("file:///resource/2").build()); + + registry.handleToolListChanged("test-client", updatedTools).block(); + registry.handleResourceListChanged("test-client", updatedResources).block(); + registry.handlePromptListChanged("test-client", updatedPrompts).block(); + + assertThat(handlers.getCalls()).hasSize(3) + .containsExactlyInAnyOrder( + new TestListChangedHandlers.Call("resource-list-changed", updatedResources), + new TestListChangedHandlers.Call("prompt-list-changed", updatedPrompts), + new TestListChangedHandlers.Call("tool-list-changed", updatedTools)); }); } @@ -80,10 +125,8 @@ void shouldNotScanAnnotationsWhenScannerDisabled(String clientType) { "spring.ai.mcp.client.annotation-scanner.enabled=false") .run(context -> { // Verify scanner beans were not created - assertThat(context).doesNotHaveBean(McpClientAnnotationScannerAutoConfiguration.class); - assertThat(context).doesNotHaveBean(prefix + "ToolListChangedSpecs"); - assertThat(context).doesNotHaveBean(prefix + "ResourceListChangedSpecs"); - assertThat(context).doesNotHaveBean(prefix + "PromptListChangedSpecs"); + assertThat(context).doesNotHaveBean(ClientMcpSyncHandlersRegistry.class); + assertThat(context).doesNotHaveBean(ClientMcpAsyncHandlersRegistry.class); }); } @@ -99,19 +142,47 @@ TestListChangedHandlers testHandlers() { static class TestListChangedHandlers { + private final List calls = new ArrayList<>(); + + public List getCalls() { + return this.calls; + } + @McpToolListChanged(clients = "test-client") public void onToolListChanged(List updatedTools) { - // Test handler for tool list changes + this.calls.add(new Call("tool-list-changed", updatedTools)); } @McpResourceListChanged(clients = "test-client") public void onResourceListChanged(List updatedResources) { - // Test handler for resource list changes + this.calls.add(new Call("resource-list-changed", updatedResources)); } @McpPromptListChanged(clients = "test-client") public void onPromptListChanged(List updatedPrompts) { - // Test handler for prompt list changes + this.calls.add(new Call("prompt-list-changed", updatedPrompts)); + } + + @McpToolListChanged(clients = "test-client") + public Mono onToolListChangedReactive(List updatedTools) { + this.calls.add(new Call("tool-list-changed", updatedTools)); + return Mono.empty(); + } + + @McpResourceListChanged(clients = "test-client") + public Mono onResourceListChangedReactive(List updatedResources) { + this.calls.add(new Call("resource-list-changed", updatedResources)); + return Mono.empty(); + } + + @McpPromptListChanged(clients = "test-client") + public Mono onPromptListChangedReactive(List updatedPrompts) { + this.calls.add(new Call("prompt-list-changed", updatedPrompts)); + return Mono.empty(); + } + + // Record calls made to this object + record Call(String name, Object callRequest) { } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizerTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizerTests.java deleted file mode 100644 index 2e6f2f39b53..00000000000 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizerTests.java +++ /dev/null @@ -1,366 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client.common.autoconfigure.annotations; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import io.modelcontextprotocol.client.McpClient.SyncSpec; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.SyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.SyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - -@ExtendWith(MockitoExtension.class) -class McpSyncAnnotationCustomizerTests { - - @Mock - private SyncSpec syncSpec; - - private List samplingSpecs; - - private List loggingSpecs; - - private List elicitationSpecs; - - private List progressSpecs; - - private List toolListChangedSpecs; - - private List resourceListChangedSpecs; - - private List promptListChangedSpecs; - - @BeforeEach - void setUp() { - this.samplingSpecs = new ArrayList<>(); - this.loggingSpecs = new ArrayList<>(); - this.elicitationSpecs = new ArrayList<>(); - this.progressSpecs = new ArrayList<>(); - this.toolListChangedSpecs = new ArrayList<>(); - this.resourceListChangedSpecs = new ArrayList<>(); - this.promptListChangedSpecs = new ArrayList<>(); - } - - @Test - void constructorShouldInitializeAllFields() { - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - assertThat(customizer).isNotNull(); - } - - @Test - void constructorShouldAcceptNullLists() { - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(null, null, null, null, null, null, - null); - - assertThat(customizer).isNotNull(); - } - - @Test - void customizeShouldNotRegisterAnythingWhenAllListsAreEmpty() { - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - customizer.customize("test-client", this.syncSpec); - - verifyNoInteractions(this.syncSpec); - } - - @Test - void customizeShouldNotRegisterElicitationSpecForNonMatchingClient() { - SyncElicitationSpecification elicitationSpec = mock(SyncElicitationSpecification.class); - when(elicitationSpec.clients()).thenReturn(new String[] { "other-client" }); - this.elicitationSpecs.add(elicitationSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - customizer.customize("test-client", this.syncSpec); - - verifyNoInteractions(this.syncSpec); - } - - @Test - void customizeShouldThrowExceptionWhenDuplicateElicitationSpecRegistered() { - SyncElicitationSpecification elicitationSpec1 = mock(SyncElicitationSpecification.class); - SyncElicitationSpecification elicitationSpec2 = mock(SyncElicitationSpecification.class); - - when(elicitationSpec1.clients()).thenReturn(new String[] { "test-client" }); - when(elicitationSpec1.elicitationHandler()).thenReturn(request -> null); - when(elicitationSpec2.clients()).thenReturn(new String[] { "test-client" }); - // No need to stub elicitationSpec2.elicitationHandler() as exception is thrown - // before it's accessed - - this.elicitationSpecs.addAll(Arrays.asList(elicitationSpec1, elicitationSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - assertThatThrownBy(() -> customizer.customize("test-client", this.syncSpec)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has an elicitationSpec registered"); - } - - @Test - void customizeShouldThrowExceptionWhenDuplicateSamplingSpecRegistered() { - SyncSamplingSpecification samplingSpec1 = mock(SyncSamplingSpecification.class); - SyncSamplingSpecification samplingSpec2 = mock(SyncSamplingSpecification.class); - - when(samplingSpec1.clients()).thenReturn(new String[] { "test-client" }); - when(samplingSpec1.samplingHandler()).thenReturn(request -> null); - when(samplingSpec2.clients()).thenReturn(new String[] { "test-client" }); - // No need to stub samplingSpec2.samplingHandler() as exception is thrown before - // it's accessed - - this.samplingSpecs.addAll(Arrays.asList(samplingSpec1, samplingSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - assertThatThrownBy(() -> customizer.customize("test-client", this.syncSpec)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has a samplingSpec registered"); - } - - @Test - void customizeShouldSkipSpecificationsWithNonMatchingClientIds() { - // Setup specs with different client IDs - SyncLoggingSpecification loggingSpec = mock(SyncLoggingSpecification.class); - SyncProgressSpecification progressSpec = mock(SyncProgressSpecification.class); - SyncElicitationSpecification elicitationSpec = mock(SyncElicitationSpecification.class); - - when(loggingSpec.clients()).thenReturn(new String[] { "other-client" }); - when(progressSpec.clients()).thenReturn(new String[] { "another-client" }); - when(elicitationSpec.clients()).thenReturn(new String[] { "different-client" }); - - this.loggingSpecs.add(loggingSpec); - this.progressSpecs.add(progressSpec); - this.elicitationSpecs.add(elicitationSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - customizer.customize("target-client", this.syncSpec); - - // None of the specifications should be registered since client IDs don't match - verifyNoInteractions(this.syncSpec); - } - - @Test - void customizeShouldAllowElicitationSpecForDifferentClients() { - SyncElicitationSpecification elicitationSpec1 = mock(SyncElicitationSpecification.class); - SyncElicitationSpecification elicitationSpec2 = mock(SyncElicitationSpecification.class); - - when(elicitationSpec1.clients()).thenReturn(new String[] { "client1" }); - when(elicitationSpec1.elicitationHandler()).thenReturn(request -> null); - when(elicitationSpec2.clients()).thenReturn(new String[] { "client2" }); - when(elicitationSpec2.elicitationHandler()).thenReturn(request -> null); - - this.elicitationSpecs.addAll(Arrays.asList(elicitationSpec1, elicitationSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception since they are for different clients - SyncSpec syncSpec1 = mock(SyncSpec.class); - customizer.customize("client1", syncSpec1); - - SyncSpec syncSpec2 = mock(SyncSpec.class); - customizer.customize("client2", syncSpec2); - - // No exception should be thrown, indicating successful registration for different - // clients - } - - @Test - void customizeShouldAllowSamplingSpecForDifferentClients() { - SyncSamplingSpecification samplingSpec1 = mock(SyncSamplingSpecification.class); - SyncSamplingSpecification samplingSpec2 = mock(SyncSamplingSpecification.class); - - when(samplingSpec1.clients()).thenReturn(new String[] { "client1" }); - when(samplingSpec1.samplingHandler()).thenReturn(request -> null); - when(samplingSpec2.clients()).thenReturn(new String[] { "client2" }); - when(samplingSpec2.samplingHandler()).thenReturn(request -> null); - - this.samplingSpecs.addAll(Arrays.asList(samplingSpec1, samplingSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception since they are for different clients - SyncSpec syncSpec1 = mock(SyncSpec.class); - customizer.customize("client1", syncSpec1); - - SyncSpec syncSpec2 = mock(SyncSpec.class); - customizer.customize("client2", syncSpec2); - - // No exception should be thrown, indicating successful registration for different - // clients - } - - @Test - void customizeShouldPreventMultipleElicitationCallsForSameClient() { - SyncElicitationSpecification elicitationSpec = mock(SyncElicitationSpecification.class); - when(elicitationSpec.clients()).thenReturn(new String[] { "test-client" }); - when(elicitationSpec.elicitationHandler()).thenReturn(request -> null); - this.elicitationSpecs.add(elicitationSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // First call should succeed - customizer.customize("test-client", this.syncSpec); - - // Second call should throw exception - SyncSpec syncSpec2 = mock(SyncSpec.class); - assertThatThrownBy(() -> customizer.customize("test-client", syncSpec2)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has an elicitationSpec registered"); - } - - @Test - void customizeShouldPreventMultipleSamplingCallsForSameClient() { - SyncSamplingSpecification samplingSpec = mock(SyncSamplingSpecification.class); - when(samplingSpec.clients()).thenReturn(new String[] { "test-client" }); - when(samplingSpec.samplingHandler()).thenReturn(request -> null); - this.samplingSpecs.add(samplingSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // First call should succeed - customizer.customize("test-client", this.syncSpec); - - // Second call should throw exception - SyncSpec syncSpec2 = mock(SyncSpec.class); - assertThatThrownBy(() -> customizer.customize("test-client", syncSpec2)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has a samplingSpec registered"); - } - - @Test - void customizeShouldPerformCaseInsensitiveClientIdMatching() { - SyncElicitationSpecification elicitationSpec = mock(SyncElicitationSpecification.class); - when(elicitationSpec.clients()).thenReturn(new String[] { "TEST-CLIENT" }); - when(elicitationSpec.elicitationHandler()).thenReturn(request -> null); - this.elicitationSpecs.add(elicitationSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should register elicitation spec when client ID matches case-insensitively - customizer.customize("test-client", this.syncSpec); - - // Verify that a subsequent call for the same client (case-insensitive) throws - // exception - SyncSpec syncSpec2 = mock(SyncSpec.class); - assertThatThrownBy(() -> customizer.customize("test-client", syncSpec2)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has an elicitationSpec registered"); - } - - @Test - void customizeShouldHandleEmptyClientName() { - SyncLoggingSpecification loggingSpec = mock(SyncLoggingSpecification.class); - when(loggingSpec.clients()).thenReturn(new String[] { "" }); - when(loggingSpec.loggingHandler()).thenReturn(message -> { - }); - this.loggingSpecs.add(loggingSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception when customizing for empty client name - customizer.customize("", this.syncSpec); - - } - - @Test - void customizeShouldAllowMultipleLoggingSpecsForSameClient() { - SyncLoggingSpecification loggingSpec1 = mock(SyncLoggingSpecification.class); - SyncLoggingSpecification loggingSpec2 = mock(SyncLoggingSpecification.class); - - when(loggingSpec1.clients()).thenReturn(new String[] { "test-client" }); - when(loggingSpec1.loggingHandler()).thenReturn(message -> { - }); - when(loggingSpec2.clients()).thenReturn(new String[] { "test-client" }); - when(loggingSpec2.loggingHandler()).thenReturn(message -> { - }); - - this.loggingSpecs.addAll(Arrays.asList(loggingSpec1, loggingSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception for multiple logging specs for same client - customizer.customize("test-client", this.syncSpec); - - } - - @Test - void customizeShouldAllowMultipleProgressSpecsForSameClient() { - SyncProgressSpecification progressSpec1 = mock(SyncProgressSpecification.class); - SyncProgressSpecification progressSpec2 = mock(SyncProgressSpecification.class); - - when(progressSpec1.clients()).thenReturn(new String[] { "test-client" }); - when(progressSpec1.progressHandler()).thenReturn(notification -> { - }); - when(progressSpec2.clients()).thenReturn(new String[] { "test-client" }); - when(progressSpec2.progressHandler()).thenReturn(notification -> { - }); - - this.progressSpecs.addAll(Arrays.asList(progressSpec1, progressSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception for multiple progress specs for same client - customizer.customize("test-client", this.syncSpec); - } - -} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationIT.java index 7dae305197e..8d3fbf94e5e 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationIT.java @@ -33,6 +33,7 @@ import reactor.core.publisher.Mono; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.httpclient.autoconfigure.SseHttpClientTransportAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.context.annotation.UserConfigurations; @@ -56,8 +57,8 @@ public class SseHttpClientTransportAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.sse.connections.server1.url=" + host) - .withConfiguration( - AutoConfigurations.of(McpClientAutoConfiguration.class, SseHttpClientTransportAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, SseHttpClientTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java index 230605fdc46..0b52cc49ecb 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java @@ -33,6 +33,7 @@ import reactor.core.publisher.Mono; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.httpclient.autoconfigure.StreamableHttpHttpClientTransportAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.context.annotation.UserConfigurations; @@ -58,6 +59,7 @@ public class StreamableHttpHttpClientTransportAutoConfigurationIT { .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.streamable-http.connections.server1.url=" + host) .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, StreamableHttpHttpClientTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java index 674f2663a5b..b9db603a0bc 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java @@ -31,7 +31,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.model.chat.client.autoconfigure.ChatClientAutoConfiguration; import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; import org.springframework.ai.tool.ToolCallback; @@ -74,7 +73,6 @@ void mcpClientSupportsSampling() { McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class, // Tool callbacks ToolCallingAutoConfiguration.class, // Chat client for sampling @@ -122,26 +120,7 @@ void toolCallbacksRegistered() { assertThat(resolver.resolve("customToolCallbackProvider")).isNotNull(); // MCP toolcallback providers are never added to the resolver - - // Bean graph setup - var injectedProviders = (List) ctx.getBean( - "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded"); - // Beans exposed as non-MCP - var toolCallbackProvider = (ToolCallbackProvider) ctx.getBean("toolCallbackProvider"); - var customToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customToolCallbackProvider"); - // This is injected in the resolver bean, because it's exposed as a - // ToolCallbackProvider, but it's not added to the resolver - var genericMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("genericMcpToolCallbackProvider"); - - // beans exposed as MCP - var mcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("mcpToolCallbackProvider"); - var customMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customMcpToolCallbackProvider"); - - assertThat(injectedProviders) - .containsExactlyInAnyOrder(toolCallbackProvider, customToolCallbackProvider, - genericMcpToolCallbackProvider) - .doesNotContain(mcpToolCallbackProvider, customMcpToolCallbackProvider); - + // Otherwise, they would throw. }); } @@ -194,29 +173,27 @@ ToolCallbackProvider toolCallbackProvider() { return tcp; } - // This bean depends on the resolver, to ensure there are no cyclic dependencies @Bean - SyncMcpToolCallbackProvider mcpToolCallbackProvider(ToolCallbackResolver resolver) { + CustomToolCallbackProvider customToolCallbackProvider() { + return new CustomToolCallbackProvider("customToolCallbackProvider"); + } + + // Ignored by the resolver + @Bean + SyncMcpToolCallbackProvider mcpToolCallbackProvider() { var tcp = mock(SyncMcpToolCallbackProvider.class); when(tcp.getToolCallbacks()) .thenThrow(new RuntimeException("mcpToolCallbackProvider#getToolCallbacks should not be called")); return tcp; } + // Ignored by the resolver @Bean - CustomToolCallbackProvider customToolCallbackProvider() { - return new CustomToolCallbackProvider("customToolCallbackProvider"); - } - - // This bean depends on the resolver, to ensure there are no cyclic dependencies - @Bean - CustomMcpToolCallbackProvider customMcpToolCallbackProvider(ToolCallbackResolver resolver) { + CustomMcpToolCallbackProvider customMcpToolCallbackProvider() { return new CustomMcpToolCallbackProvider(); } - // This will be added to the resolver, because the visible type of the bean - // is ToolCallbackProvider ; we would need to actually instantiate the bean - // to find out that it is MCP-related + // Ignored by the resolver @Bean ToolCallbackProvider genericMcpToolCallbackProvider() { return new CustomMcpToolCallbackProvider(); diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java index 744b494cf4d..15fa7c3cb33 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java @@ -30,6 +30,7 @@ import org.testcontainers.containers.wait.strategy.Wait; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -43,8 +44,8 @@ public class SseWebFluxTransportAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.sse.connections.server1.url=" + host) - .withConfiguration( - AutoConfigurations.of(McpClientAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java index 257df12a97b..83da4876bd1 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java @@ -30,6 +30,7 @@ import org.testcontainers.containers.wait.strategy.Wait; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -45,6 +46,7 @@ public class StreamableHttpHttpClientTransportAutoConfigurationIT { .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.streamable-http.connections.server1.url=" + host) .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientWebFluxServerIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientWebFluxServerIT.java index a833402c1d1..57961417cf8 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientWebFluxServerIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientWebFluxServerIT.java @@ -62,6 +62,7 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.SseWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; @@ -93,9 +94,9 @@ public class SseWebClientWebFluxServerIT { AutoConfigurations.of(McpServerAutoConfiguration.class, McpServerObjectMapperAutoConfiguration.class, ToolCallbackConverterAutoConfiguration.class, McpServerSseWebFluxAutoConfiguration.class)); - private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, - McpClientAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); + private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner().withConfiguration( + AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); @Test void clientServerCapabilities() { @@ -518,6 +519,8 @@ McpSyncClientCustomizer clientCustomizer(TestContext testContext) { assertThat(progressNotification.total()).isEqualTo(1.0); // assertThat(progressNotification.message()).isEqualTo("processing"); }); + + mcpClientSpec.capabilities(McpSchema.ClientCapabilities.builder().elicitation().sampling().build()); }; } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StatelessWebClientWebFluxServerIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StatelessWebClientWebFluxServerIT.java index a4eb89181cd..f584a3ad91b 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StatelessWebClientWebFluxServerIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StatelessWebClientWebFluxServerIT.java @@ -49,6 +49,7 @@ import org.springframework.ai.mcp.McpToolUtils; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -84,7 +85,8 @@ public class StatelessWebClientWebFluxServerIT { private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, - McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class)); + McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class, + StreamableHttpWebFluxTransportAutoConfiguration.class)); @Test void clientServerCapabilities() { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java index 35f1d67937e..a7a56ba9e5a 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java @@ -72,7 +72,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -109,8 +108,7 @@ public class StreamableMcpAnnotations2IT { private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, - McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class)); + McpClientAnnotationScannerAutoConfiguration.class)); @Test void clientServerCapabilities() { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsIT.java index 4a0da8b3ac7..cd34bac36c6 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsIT.java @@ -73,7 +73,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -110,8 +109,7 @@ public class StreamableMcpAnnotationsIT { private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, - McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class)); + McpClientAnnotationScannerAutoConfiguration.class)); @Test void clientServerCapabilities() { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsManualIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsManualIT.java index 1f6c2490267..9d83b94e6df 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsManualIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsManualIT.java @@ -75,7 +75,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -117,7 +116,7 @@ public class StreamableMcpAnnotationsManualIT { .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, // MCP Annotations - McpClientAnnotationScannerAutoConfiguration.class, McpClientSpecificationFactoryAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, // Anthropic ChatClient Builder AnthropicChatAutoConfiguration.class, ChatClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class)); diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java index 9403b2e0bf4..5b224a154b7 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java @@ -36,14 +36,11 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springaicommunity.mcp.annotation.McpElicitation; import org.springaicommunity.mcp.annotation.McpLogging; import org.springaicommunity.mcp.annotation.McpProgress; -import org.springaicommunity.mcp.annotation.McpSampling; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.annotation.McpToolParam; import org.springaicommunity.mcp.context.McpSyncRequestContext; -import org.springaicommunity.mcp.context.StructuredElicitResult; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -51,8 +48,9 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; +import org.springframework.ai.mcp.server.autoconfigure.capabilities.McpHandlerConfiguration; +import org.springframework.ai.mcp.server.autoconfigure.capabilities.McpHandlerService; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.ToolCallbackConverterAutoConfiguration; @@ -71,6 +69,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.test.util.TestSocketUtils; @@ -98,8 +97,8 @@ public class StreamableMcpAnnotationsWithLLMIT { .withPropertyValues("spring.ai.anthropic.apiKey=" + System.getenv("ANTHROPIC_API_KEY")) .withConfiguration(anthropicAutoConfig(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, - McpClientAnnotationScannerAutoConfiguration.class, McpClientSpecificationFactoryAutoConfiguration.class, - AnthropicChatAutoConfiguration.class, ChatClientAutoConfiguration.class)); + McpClientAnnotationScannerAutoConfiguration.class, AnthropicChatAutoConfiguration.class, + ChatClientAutoConfiguration.class)); private static AutoConfigurations anthropicAutoConfig(Class... additional) { Class[] dependencies = { SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class, @@ -221,9 +220,6 @@ private static void stopHttpServer(DisposableServer server) { } } - record ElicitInput(String message) { - } - public static class TestMcpServerConfiguration { @Bean @@ -245,7 +241,8 @@ public String weather(McpSyncRequestContext ctx, @McpToolParam String cityName) ctx.ping(); // call client ping // call elicitation - var elicitationResult = ctx.elicit(e -> e.message("Test message"), ElicitInput.class); + var elicitationResult = ctx.elicit(e -> e.message("Test message"), + McpHandlerConfiguration.ElicitInput.class); ctx.progress(p -> p.progress(0.50).total(1.0).message("elicitation completed")); @@ -286,18 +283,16 @@ public static class TestContext { } + // We also include scanned beans, because those are registered differently. + @ComponentScan(basePackageClasses = McpHandlerService.class) public static class TestMcpClientHandlers { private static final Logger logger = LoggerFactory.getLogger(TestMcpClientHandlers.class); - private final ChatClient client; - private TestMcpClientConfiguration.TestContext testContext; - public TestMcpClientHandlers(TestMcpClientConfiguration.TestContext testContext, - ChatClient.Builder clientBuilder) { + public TestMcpClientHandlers(TestMcpClientConfiguration.TestContext testContext) { this.testContext = testContext; - this.client = clientBuilder.build(); } @McpProgress(clients = "server1") @@ -314,28 +309,6 @@ public void loggingHandler(McpSchema.LoggingMessageNotification loggingMessage) logger.info("MCP LOGGING: [{}] {}", loggingMessage.level(), loggingMessage.data()); } - @McpSampling(clients = "server1") - public McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest llmRequest) { - logger.info("MCP SAMPLING: {}", llmRequest); - - String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text(); - String modelHint = llmRequest.modelPreferences().hints().get(0).name(); - // In a real use-case, we would use the chat client to call the LLM again - logger.info("MCP SAMPLING: simulating using chat client {}", this.client); - - return McpSchema.CreateMessageResult.builder() - .content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint)) - .build(); - } - - @McpElicitation(clients = "server1") - public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { - logger.info("MCP ELICITATION: {}", request); - StreamableMcpAnnotationsWithLLMIT.ElicitInput elicitData = new StreamableMcpAnnotationsWithLLMIT.ElicitInput( - request.message()); - return StructuredElicitResult.builder().structuredContent(elicitData).build(); - } - } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableWebClientWebFluxServerIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableWebClientWebFluxServerIT.java index 8ee43cf8d07..96108790929 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableWebClientWebFluxServerIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableWebClientWebFluxServerIT.java @@ -63,6 +63,7 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; @@ -99,7 +100,8 @@ public class StreamableWebClientWebFluxServerIT { private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, - McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class)); + McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class, + StreamableHttpWebFluxTransportAutoConfiguration.class)); @Test void clientServerCapabilities() { @@ -521,6 +523,7 @@ McpSyncClientCustomizer clientCustomizer(TestContext testContext) { testContext.progressNotifications.add(progressNotification); testContext.progressLatch.countDown(); }); + mcpClientSpec.capabilities(McpSchema.ClientCapabilities.builder().sampling().elicitation().build()); }; } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerConfiguration.java new file mode 100644 index 00000000000..1b85779c9a3 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerConfiguration.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.autoconfigure.capabilities; + +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.context.StructuredElicitResult; + +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Scope; +import org.springframework.web.context.annotation.RequestScope; + +@Configuration +public class McpHandlerConfiguration { + + private static final Logger logger = LoggerFactory.getLogger(McpHandlerConfiguration.class); + + @Bean + ElicitationHandler elicitationHandler() { + return new ElicitationHandler(); + } + + // Ensure that we don't blow up on non-singleton beans + @Bean + @Scope(scopeName = ConfigurableBeanFactory.SCOPE_PROTOTYPE) + Foo foo() { + return new Foo(); + } + + // Ensure that we don't blow up on non-singleton beans + @Bean + @RequestScope + Bar bar(Foo foo) { + return new Bar(); + } + + record ElicitationHandler() { + + @McpElicitation(clients = "server1") + public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { + logger.info("MCP ELICITATION: {}", request); + ElicitInput elicitData = new ElicitInput(request.message()); + return StructuredElicitResult.builder().structuredContent(elicitData).build(); + } + + } + + public record ElicitInput(String message) { + } + + public static class Foo { + + } + + public static class Bar { + + } + +} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java new file mode 100644 index 00000000000..d9e872a60b2 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java @@ -0,0 +1,52 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.autoconfigure.capabilities; + +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpSampling; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.stereotype.Service; + +@Service +public class McpHandlerService { + + private static final Logger logger = LoggerFactory.getLogger(McpHandlerService.class); + + private final ChatClient client; + + public McpHandlerService(ChatClient.Builder chatClientBuilder) { + this.client = chatClientBuilder.build(); + } + + @McpSampling(clients = "server1") + public McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest llmRequest) { + logger.info("MCP SAMPLING: {}", llmRequest); + + String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text(); + String modelHint = llmRequest.modelPreferences().hints().get(0).name(); + // In a real use-case, we would use the chat client to call the LLM again + logger.info("MCP SAMPLING: simulating using chat client {}", this.client); + + return McpSchema.CreateMessageResult.builder() + .content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint)) + .build(); + } + +} diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java index e5d6699cd69..bcdad6b2bf5 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java @@ -17,7 +17,6 @@ package org.springframework.ai.model.tool.autoconfigure; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import io.micrometer.observation.ObservationRegistry; @@ -36,14 +35,7 @@ import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; -import org.springframework.beans.BeansException; import org.springframework.beans.factory.ObjectProvider; -import org.springframework.beans.factory.annotation.Qualifier; -import org.springframework.beans.factory.config.BeanDefinition; -import org.springframework.beans.factory.support.BeanDefinitionBuilder; -import org.springframework.beans.factory.support.BeanDefinitionRegistry; -import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; -import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -65,26 +57,20 @@ @AutoConfiguration @ConditionalOnClass(ChatModel.class) @EnableConfigurationProperties(ToolCallingProperties.class) -public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostProcessor { +public class ToolCallingAutoConfiguration { private static final Logger logger = LoggerFactory.getLogger(ToolCallingAutoConfiguration.class); - // Marker qualifier to exclude MCP-related ToolCallbackProviders - private static final String EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER = "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded"; - /** * The default {@link ToolCallbackResolver} resolves tools by name for methods, * functions, and {@link ToolCallbackProvider} beans. *

- * MCP providers should not be injected to avoid cyclic dependencies. If some MCP - * providers are injected, we filter them out to avoid eagerly calling - * #getToolCallbacks. + * MCP providers are excluded, to avoid initializing them early with #listTools(). */ @Bean @ConditionalOnMissingBean ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, - List toolCallbacks, - @Qualifier(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER) List tcbProviders) { + List toolCallbacks, List tcbProviders) { List allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks); tcbProviders.stream() .filter(pr -> !isMcpToolCallbackProvider(ResolvableType.forInstance(pr))) @@ -100,41 +86,6 @@ ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationC return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver)); } - /** - * Wrap {@link ToolCallbackProvider} beans that are not MCP-related into a named bean, - * which will be picked up by the - * {@link ToolCallingAutoConfiguration#toolCallbackResolver}. - *

- * MCP providers must be excluded, because they may depend on a {@code ChatClient} to - * do sampling. The chat client, in turn, depends on a {@link ToolCallbackResolver}. - * To do the detection, we depend on the exposed bean type. If a bean uses a factory - * method which returns a {@link ToolCallbackProvider}, which is an MCP provider under - * the hood, it will be included in the list. - */ - @Override - public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { - if (!(registry instanceof DefaultListableBeanFactory beanFactory)) { - return; - } - - var excludeMcpToolCallbackProviderBeanDefinition = BeanDefinitionBuilder - .genericBeanDefinition(List.class, () -> { - var providerNames = beanFactory.getBeanNamesForType(ToolCallbackProvider.class); - return Arrays.stream(providerNames) - .filter(name -> !isMcpToolCallbackProvider(beanFactory.getBeanDefinition(name).getResolvableType())) - .map(beanFactory::getBean) - .filter(ToolCallbackProvider.class::isInstance) - .map(ToolCallbackProvider.class::cast) - .toList(); - }) - .setScope(BeanDefinition.SCOPE_SINGLETON) - .setLazyInit(true) - .getBeanDefinition(); - - registry.registerBeanDefinition(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER, - excludeMcpToolCallbackProviderBeanDefinition); - } - private static boolean isMcpToolCallbackProvider(ResolvableType type) { if (type.getType().getTypeName().equals("org.springframework.ai.mcp.SyncMcpToolCallbackProvider") || type.getType().getTypeName().equals("org.springframework.ai.mcp.AsyncMcpToolCallbackProvider")) { diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java new file mode 100644 index 00000000000..c1d5c84017e --- /dev/null +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java @@ -0,0 +1,156 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.lang.annotation.Annotation; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.spec.McpSchema; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; + +import org.springframework.aop.framework.autoproxy.AutoProxyUtils; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Base class for sync and async ClientMcpHandlerRegistries. Not intended for public use. + * + * @author Daniel Garnier-Moiroux + * @see ClientMcpAsyncHandlersRegistry + * @see ClientMcpSyncHandlersRegistry + */ +abstract class AbstractClientMcpHandlerRegistry implements BeanFactoryPostProcessor { + + protected Map capabilitiesPerClient = new HashMap<>(); + + protected ConfigurableListableBeanFactory beanFactory; + + protected final Set allAnnotatedBeans = new HashSet<>(); + + static final Class[] CLIENT_MCP_ANNOTATIONS = new Class[] { McpSampling.class, + McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class, + McpPromptListChanged.class, McpResourceListChanged.class }; + + static final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null, + null); + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + Map> elicitationClientToAnnotatedBeans = new HashMap<>(); + Map> samplingClientToAnnotatedBeans = new HashMap<>(); + for (var beanName : beanFactory.getBeanDefinitionNames()) { + if (!beanFactory.getBeanDefinition(beanName).isSingleton()) { + // Only process singleton beans, not scoped beans + continue; + } + var foundAnnotations = this.scan(AutoProxyUtils.determineTargetClass(beanFactory, beanName)); + if (!foundAnnotations.isEmpty()) { + this.allAnnotatedBeans.add(beanName); + } + for (var foundAnnotation : foundAnnotations) { + if (foundAnnotation instanceof McpSampling sampling) { + for (var client : sampling.clients()) { + samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); + } + } + else if (foundAnnotation instanceof McpElicitation elicitation) { + for (var client : elicitation.clients()) { + elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); + } + } + } + } + + for (var elicitationEntry : elicitationClientToAnnotatedBeans.entrySet()) { + if (elicitationEntry.getValue().size() > 1) { + throw new IllegalArgumentException( + "Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client" + .formatted(elicitationEntry.getKey(), new LinkedHashSet<>(elicitationEntry.getValue()))); + } + } + for (var samplingEntry : samplingClientToAnnotatedBeans.entrySet()) { + if (samplingEntry.getValue().size() > 1) { + throw new IllegalArgumentException( + "Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client" + .formatted(samplingEntry.getKey(), new LinkedHashSet<>(samplingEntry.getValue()))); + } + } + + Map capsPerClient = new HashMap<>(); + for (var samplingClient : samplingClientToAnnotatedBeans.keySet()) { + capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling(); + } + for (var elicitationClient : elicitationClientToAnnotatedBeans.keySet()) { + capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder()) + .elicitation(); + } + + this.capabilitiesPerClient = capsPerClient.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build())); + } + + protected List scan(Class beanClass) { + List foundAnnotations = new ArrayList<>(); + + // Scan all methods in the bean class + ReflectionUtils.doWithMethods(beanClass, method -> { + for (var annotationType : CLIENT_MCP_ANNOTATIONS) { + Annotation annotation = AnnotationUtils.findAnnotation(method, annotationType); + if (annotation != null) { + foundAnnotations.add(annotation); + } + } + }); + return foundAnnotations; + } + + protected Map, Set> getBeansByAnnotationType() { + // Use a set in case multiple handlers are registered in the same bean + Map, Set> beansByAnnotation = new HashMap<>(); + for (var annotation : CLIENT_MCP_ANNOTATIONS) { + beansByAnnotation.put(annotation, new HashSet<>()); + } + + for (var beanName : this.allAnnotatedBeans) { + var bean = this.beanFactory.getBean(beanName); + var annotations = this.scan(bean.getClass()); + for (var annotation : annotations) { + beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean); + } + } + return beansByAnnotation; + } + +} diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java new file mode 100644 index 00000000000..536f2eec7c1 --- /dev/null +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java @@ -0,0 +1,267 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.SmartInitializingSingleton; + +/** + * Registry of methods annotated with MCP Client annotations (sampling, logging, etc.). + * All beans in the application context are scanned to find these methods automatically. + * They are then exposed by the registry by client name. + *

+ * The scanning happens in two phases: + *

+ * First, once bean definitions are available, all bean types are scanned for the presence + * of MCP annotations. In particular, this is used to prepare the result + * {@link #getCapabilities(String)}, which is then used by MCP client auto-configurations + * to configure the client capabilities without needing to instantiate the beans. + *

+ * Second, after all singleton beans have been instantiated, all annotated beans are + * scanned again, MCP handlers are created to match the annotations, and stored by client. + * + * @see McpSampling + * @see McpElicitation + * @see McpLogging + * @see McpProgress + * @see McpToolListChanged + * @see McpPromptListChanged + * @see McpResourceListChanged + * @author Daniel Garnier-Moiroux + * @since 1.1.0 + */ +public class ClientMcpAsyncHandlersRegistry extends AbstractClientMcpHandlerRegistry + implements SmartInitializingSingleton { + + private static final Logger logger = LoggerFactory.getLogger(ClientMcpAsyncHandlersRegistry.class); + + private final Map>> samplingHandlers = new HashMap<>(); + + private final Map>> elicitationHandlers = new HashMap<>(); + + private final Map>>> loggingHandlers = new HashMap<>(); + + private final Map>>> progressHandlers = new HashMap<>(); + + private final Map, Mono>>> toolListChangedHandlers = new HashMap<>(); + + private final Map, Mono>>> promptListChangedHandlers = new HashMap<>(); + + private final Map, Mono>>> resourceListChangedHandlers = new HashMap<>(); + + /** + * Obtain the MCP capabilities declared for a given MCP client. Capabilities are + * registered with the {@link McpSampling} and {@link McpElicitation} annotations. + */ + public McpSchema.ClientCapabilities getCapabilities(String clientName) { + return this.capabilitiesPerClient.getOrDefault(clientName, EMPTY_CAPABILITIES); + } + + /** + * Invoke the sampling handler for a given MCP client. + * + * @see McpSampling + */ + public Mono handleSampling(String name, + McpSchema.CreateMessageRequest samplingRequest) { + logger.debug("Handling sampling request for client {}", name); + var handler = this.samplingHandlers.get(name); + if (handler != null) { + return handler.apply(samplingRequest); + } + return Mono.error(new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Sampling not supported", Map.of("reason", "Client does not have sampling capability")))); + } + + /** + * Invoke the elicitation handler for a given MCP client. + * + * @see McpElicitation + */ + public Mono handleElicitation(String name, McpSchema.ElicitRequest elicitationRequest) { + logger.debug("Handling elicitation request for client {}", name); + var handler = this.elicitationHandlers.get(name); + if (handler != null) { + return handler.apply(elicitationRequest); + } + return Mono.error(new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Elicitation not supported", Map.of("reason", "Client does not have elicitation capability")))); + } + + /** + * Invoke all elicitation handlers for a given MCP client, sequentially. + * + * @see McpLogging + */ + public Mono handleLogging(String name, McpSchema.LoggingMessageNotification loggingMessageNotification) { + logger.debug("Handling logging notification for client {}", name); + var consumers = this.loggingHandlers.get(name); + if (consumers == null) { + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(loggingMessageNotification)).then(); + } + + /** + * Invoke all progress handlers for a given MCP client, sequentially. + * + * @see McpProgress + */ + public Mono handleProgress(String name, McpSchema.ProgressNotification progressNotification) { + logger.debug("Handling progress notification for client {}", name); + var consumers = this.progressHandlers.get(name); + if (consumers == null) { + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(progressNotification)).then(); + } + + /** + * Invoke all tool list changed handlers for a given MCP client, sequentially. + * + * @see McpToolListChanged + */ + public Mono handleToolListChanged(String name, List updatedTools) { + logger.debug("Handling tool list changed notification for client {}", name); + var consumers = this.toolListChangedHandlers.get(name); + if (consumers == null) { + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedTools)).then(); + } + + /** + * Invoke all prompt list changed handlers for a given MCP client, sequentially. + * + * @see McpPromptListChanged + */ + public Mono handlePromptListChanged(String name, List updatedPrompts) { + logger.debug("Handling prompt list changed notification for client {}", name); + var consumers = this.promptListChangedHandlers.get(name); + if (consumers == null) { + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedPrompts)).then(); + } + + /** + * Invoke all resource list changed handlers for a given MCP client, sequentially. + * + * @see McpResourceListChanged + */ + public Mono handleResourceListChanged(String name, List updatedResources) { + logger.debug("Handling resource list changed notification for client {}", name); + var consumers = this.resourceListChangedHandlers.get(name); + if (consumers == null) { + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedResources)).then(); + } + + @Override + public void afterSingletonsInstantiated() { + var beansByAnnotation = this.getBeansByAnnotationType(); + + var samplingSpecs = AsyncMcpAnnotationProviders + .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); + for (var samplingSpec : samplingSpecs) { + for (var client : samplingSpec.clients()) { + logger.debug("Registering sampling handler for {}", client); + this.samplingHandlers.put(client, samplingSpec.samplingHandler()); + } + } + + var elicitationSpecs = AsyncMcpAnnotationProviders + .elicitationSpecifications(new ArrayList<>(beansByAnnotation.get(McpElicitation.class))); + for (var elicitationSpec : elicitationSpecs) { + for (var client : elicitationSpec.clients()) { + logger.debug("Registering elicitation handler for {}", client); + this.elicitationHandlers.put(client, elicitationSpec.elicitationHandler()); + } + } + + var loggingSpecs = AsyncMcpAnnotationProviders + .loggingSpecifications(new ArrayList<>(beansByAnnotation.get(McpLogging.class))); + for (var loggingSpec : loggingSpecs) { + for (var client : loggingSpec.clients()) { + logger.debug("Registering logging handler for {}", client); + this.loggingHandlers.computeIfAbsent(client, k -> new ArrayList<>()).add(loggingSpec.loggingHandler()); + } + } + + var progressSpecs = AsyncMcpAnnotationProviders + .progressSpecifications(new ArrayList<>(beansByAnnotation.get(McpProgress.class))); + for (var progressSpec : progressSpecs) { + for (var client : progressSpec.clients()) { + logger.debug("Registering progress handler for {}", client); + this.progressHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(progressSpec.progressHandler()); + } + } + + var toolsListChangedSpecs = AsyncMcpAnnotationProviders + .toolListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpToolListChanged.class))); + for (var toolsListChangedSpec : toolsListChangedSpecs) { + for (var client : toolsListChangedSpec.clients()) { + logger.debug("Registering tool list changed handler for {}", client); + this.toolListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(toolsListChangedSpec.toolListChangeHandler()); + } + } + + var promptListChangedSpecs = AsyncMcpAnnotationProviders + .promptListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpPromptListChanged.class))); + for (var promptListChangedSpec : promptListChangedSpecs) { + for (var client : promptListChangedSpec.clients()) { + logger.debug("Registering prompt list changed handler for {}", client); + this.promptListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(promptListChangedSpec.promptListChangeHandler()); + } + } + + var resourceListChangedSpecs = AsyncMcpAnnotationProviders + .resourceListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpResourceListChanged.class))); + for (var resourceListChangedSpec : resourceListChangedSpecs) { + for (var client : resourceListChangedSpec.clients()) { + logger.debug("Registering resource list changed handler for {}", client); + this.resourceListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(resourceListChangedSpec.resourceListChangeHandler()); + } + } + + } + +} diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java new file mode 100644 index 00000000000..36a4d63fa14 --- /dev/null +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java @@ -0,0 +1,282 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; + +import org.springframework.beans.factory.SmartInitializingSingleton; + +/** + * Registry of methods annotated with MCP Client annotations (sampling, logging, etc.). + * All beans in the application context are scanned to find these methods automatically. + * They are then exposed by the registry by client name. + *

+ * The scanning happens in two phases: + *

+ * First, once bean definitions are available, all bean types are scanned for the presence + * of MCP annotations. In particular, this is used to prepare the result + * {@link #getCapabilities(String)}, which is then used by MCP client auto-configurations + * to configure the client capabilities without needing to instantiate the beans. + *

+ * Second, after all singleton beans have been instantiated, all annotated beans are + * scanned again, MCP handlers are created to match the annotations, and stored by client. + * + * @see McpSampling + * @see McpElicitation + * @see McpLogging + * @see McpProgress + * @see McpToolListChanged + * @see McpPromptListChanged + * @see McpResourceListChanged + * @author Daniel Garnier-Moiroux + * @since 1.1.0 + */ +public class ClientMcpSyncHandlersRegistry extends AbstractClientMcpHandlerRegistry + implements SmartInitializingSingleton { + + private static final Logger logger = LoggerFactory.getLogger(ClientMcpSyncHandlersRegistry.class); + + private final Map> samplingHandlers = new HashMap<>(); + + private final Map> elicitationHandlers = new HashMap<>(); + + private final Map>> loggingHandlers = new HashMap<>(); + + private final Map>> progressHandlers = new HashMap<>(); + + private final Map>>> toolListChangedHandlers = new HashMap<>(); + + private final Map>>> promptListChangedHandlers = new HashMap<>(); + + private final Map>>> resourceListChangedHandlers = new HashMap<>(); + + /** + * Obtain the MCP capabilities declared for a given MCP client. Capabilities are + * registered with the {@link McpSampling} and {@link McpElicitation} annotations. + */ + public McpSchema.ClientCapabilities getCapabilities(String clientName) { + return this.capabilitiesPerClient.getOrDefault(clientName, EMPTY_CAPABILITIES); + } + + /** + * Invoke the sampling handler for a given MCP client. + * + * @see McpSampling + */ + public McpSchema.CreateMessageResult handleSampling(String name, McpSchema.CreateMessageRequest samplingRequest) { + logger.debug("Handling sampling request for client {}", name); + + var handler = this.samplingHandlers.get(name); + if (handler != null) { + return handler.apply(samplingRequest); + } + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Sampling not supported", Map.of("reason", "Client does not have sampling capability"))); + } + + /** + * Invoke the elicitation handler for a given MCP client. + * + * @see McpElicitation + */ + public McpSchema.ElicitResult handleElicitation(String name, McpSchema.ElicitRequest elicitationRequest) { + logger.debug("Handling elicitation request for client {}", name); + + var handler = this.elicitationHandlers.get(name); + if (handler != null) { + return handler.apply(elicitationRequest); + } + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Elicitation not supported", Map.of("reason", "Client does not have elicitation capability"))); + } + + /** + * Invoke all elicitation handlers for a given MCP client, sequentially. + * + * @see McpLogging + */ + public void handleLogging(String name, McpSchema.LoggingMessageNotification loggingMessageNotification) { + logger.debug("Handling logging notification for client {}", name); + + var consumers = this.loggingHandlers.get(name); + if (consumers == null) { + return; + } + for (var consumer : consumers) { + consumer.accept(loggingMessageNotification); + } + } + + /** + * Invoke all progress handlers for a given MCP client, sequentially. + * + * @see McpProgress + */ + public void handleProgress(String name, McpSchema.ProgressNotification progressNotification) { + logger.debug("Handling progress notification for client {}", name); + + var consumers = this.progressHandlers.get(name); + if (consumers == null) { + return; + } + for (var consumer : consumers) { + consumer.accept(progressNotification); + } + } + + /** + * Invoke all tool list changed handlers for a given MCP client, sequentially. + * + * @see McpToolListChanged + */ + public void handleToolListChanged(String name, List updatedTools) { + logger.debug("Handling tool list changed notification for client {}", name); + + var consumers = this.toolListChangedHandlers.get(name); + if (consumers == null) { + return; + } + for (var consumer : consumers) { + consumer.accept(updatedTools); + } + } + + /** + * Invoke all prompt list changed handlers for a given MCP client, sequentially. + * + * @see McpPromptListChanged + */ + public void handlePromptListChanged(String name, List updatedPrompts) { + logger.debug("Handling prompt list changed notification for client {}", name); + + var consumers = this.promptListChangedHandlers.get(name); + if (consumers == null) { + return; + } + for (var consumer : consumers) { + consumer.accept(updatedPrompts); + } + } + + /** + * Invoke all resource list changed handlers for a given MCP client, sequentially. + * + * @see McpResourceListChanged + */ + public void handleResourceListChanged(String name, List updatedResources) { + logger.debug("Handling resource list changed notification for client {}", name); + + var consumers = this.resourceListChangedHandlers.get(name); + if (consumers == null) { + return; + } + for (var consumer : consumers) { + consumer.accept(updatedResources); + } + } + + @Override + public void afterSingletonsInstantiated() { + var beansByAnnotation = this.getBeansByAnnotationType(); + + var samplingSpecs = SyncMcpAnnotationProviders + .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); + for (var samplingSpec : samplingSpecs) { + for (var client : samplingSpec.clients()) { + logger.debug("Registering sampling handler for {}", client); + this.samplingHandlers.put(client, samplingSpec.samplingHandler()); + } + } + + var elicitationSpecs = SyncMcpAnnotationProviders + .elicitationSpecifications(new ArrayList<>(beansByAnnotation.get(McpElicitation.class))); + for (var elicitationSpec : elicitationSpecs) { + for (var client : elicitationSpec.clients()) { + logger.debug("Registering elicitation handler for {}", client); + this.elicitationHandlers.put(client, elicitationSpec.elicitationHandler()); + } + } + + var loggingSpecs = SyncMcpAnnotationProviders + .loggingSpecifications(new ArrayList<>(beansByAnnotation.get(McpLogging.class))); + for (var loggingSpec : loggingSpecs) { + for (var client : loggingSpec.clients()) { + logger.debug("Registering logging handler for {}", client); + this.loggingHandlers.computeIfAbsent(client, k -> new ArrayList<>()).add(loggingSpec.loggingHandler()); + } + } + + var progressSpecs = SyncMcpAnnotationProviders + .progressSpecifications(new ArrayList<>(beansByAnnotation.get(McpProgress.class))); + for (var progressSpec : progressSpecs) { + for (var client : progressSpec.clients()) { + logger.debug("Registering progress handler for {}", client); + this.progressHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(progressSpec.progressHandler()); + } + } + + var toolsListChangedSpecs = SyncMcpAnnotationProviders + .toolListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpToolListChanged.class))); + for (var toolsListChangedSpec : toolsListChangedSpecs) { + for (var client : toolsListChangedSpec.clients()) { + logger.debug("Registering tool list changed handler for {}", client); + this.toolListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(toolsListChangedSpec.toolListChangeHandler()); + } + } + + var promptListChangedSpecs = SyncMcpAnnotationProviders + .promptListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpPromptListChanged.class))); + for (var promptListChangedSpec : promptListChangedSpecs) { + for (var client : promptListChangedSpec.clients()) { + logger.debug("Registering prompt list changed handler for {}", client); + this.promptListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(promptListChangedSpec.promptListChangeHandler()); + } + } + + var resourceListChangedSpecs = SyncMcpAnnotationProviders + .resourceListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpResourceListChanged.class))); + for (var resourceListChangedSpec : resourceListChangedSpecs) { + for (var client : resourceListChangedSpec.clients()) { + logger.debug("Registering resource list changed handler for {}", client); + this.resourceListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(resourceListChangedSpec.resourceListChangeHandler()); + } + } + + } + +} diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java new file mode 100644 index 00000000000..6e7300bdd6f --- /dev/null +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java @@ -0,0 +1,527 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; +import reactor.core.publisher.Mono; + +import org.springframework.aop.framework.autoproxy.AutoProxyUtils; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + +class ClientMcpAsyncHandlersRegistryTests { + + @Test + void getCapabilitiesPerClient() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(ClientCapabilitiesConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + assertThat(registry.getCapabilities("client-2").elicitation()).isNotNull(); + assertThat(registry.getCapabilities("client-3").elicitation()).isNotNull(); + + assertThat(registry.getCapabilities("client-1").sampling()).isNotNull(); + assertThat(registry.getCapabilities("client-2").sampling()).isNull(); + assertThat(registry.getCapabilities("client-3").sampling()).isNull(); + + assertThat(registry.getCapabilities("client-1").roots()).isNull(); + assertThat(registry.getCapabilities("client-2").roots()).isNull(); + assertThat(registry.getCapabilities("client-3").roots()).isNull(); + + assertThat(registry.getCapabilities("client-1").experimental()).isNull(); + assertThat(registry.getCapabilities("client-2").experimental()).isNull(); + assertThat(registry.getCapabilities("client-3").experimental()).isNull(); + + assertThat(registry.getCapabilities("client-unknown").sampling()).isNull(); + assertThat(registry.getCapabilities("client-unknown").elicitation()).isNull(); + assertThat(registry.getCapabilities("client-unknown").roots()).isNull(); + } + + @Test + void twoHandlersElicitation() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("firstConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.First.class) + .getBeanDefinition()); + beanFactory.registerBeanDefinition("secondConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.Second.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 elicitation handlers for client [client-1], found in bean with names [firstConfig, secondConfig]. Only one @McpElicitation handler is allowed per client"); + } + + @Test + void twoHandlersSameBeanElicitation() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("elicitationConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.TwoHandlers.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 elicitation handlers for client [client-1], found in bean with names [elicitationConfig]. Only one @McpElicitation handler is allowed per client"); + } + + @Test + void twoHandlersSampling() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("firstConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.First.class) + .getBeanDefinition()); + beanFactory.registerBeanDefinition("secondConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.Second.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 sampling handlers for client [client-1], found in bean with names [firstConfig, secondConfig]. Only one @McpSampling handler is allowed per client"); + } + + @Test + void twoHandlersSameBeanSampling() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("samplingConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.TwoHandlers.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 sampling handlers for client [client-1], found in bean with names [samplingConfig]. Only one @McpSampling handler is allowed per client"); + } + + @Test + void elicitation() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.ElicitRequest.builder().message("Elicit request").progressToken("token-12345").build(); + var response = registry.handleElicitation("client-1", request).block(); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1).containsEntry("message", "Elicit request"); + assertThat(response.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + } + + @Test + void missingElicitationHandler() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder + .genericBeanDefinition(ClientMcpAsyncHandlersRegistryTests.HandlersConfiguration.class) + .getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.ElicitRequest.builder().message("Elicit request").progressToken("token-12345").build(); + assertThatThrownBy(() -> registry.handleElicitation("client-unknown", request).block()) + .hasMessage("Elicitation not supported") + .asInstanceOf(type(McpError.class)) + .extracting(McpError::getJsonRpcError) + .satisfies(error -> assertThat(error.data()) + .isEqualTo(Map.of("reason", "Client does not have elicitation capability"))) + .satisfies(error -> assertThat(error.code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND)); + } + + @Test + void sampling() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Tell a joke")))) + .build(); + var response = registry.handleSampling("client-1", request).block(); + + assertThat(response.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(response.model()).isEqualTo("testgpt-42.5"); + McpSchema.TextContent content = (McpSchema.TextContent) response.content(); + assertThat(content.text()).isEqualTo("Tell a joke"); + } + + @Test + void missingSamplingHandler() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder + .genericBeanDefinition(ClientMcpAsyncHandlersRegistryTests.HandlersConfiguration.class) + .getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Tell a joke")))) + .build(); + assertThatThrownBy(() -> registry.handleSampling("client-unknown", request).block()) + .hasMessage("Sampling not supported") + .asInstanceOf(type(McpError.class)) + .extracting(McpError::getJsonRpcError) + .satisfies(error -> assertThat(error.data()) + .isEqualTo(Map.of("reason", "Client does not have sampling capability"))) + .satisfies(error -> assertThat(error.code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND)); + } + + @Test + void logging() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + var logRequest = McpSchema.LoggingMessageNotification.builder() + .data("Hello world") + .logger("log-me") + .level(McpSchema.LoggingLevel.INFO) + .build(); + + registry.handleLogging("client-1", logRequest).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleLoggingMessage", logRequest), + new HandlersConfiguration.Call("handleLoggingMessageAgain", logRequest)); + } + + @Test + void progress() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + var progressRequest = new McpSchema.ProgressNotification("progress-12345", 13.37, 100., "progressing ..."); + + registry.handleProgress("client-1", progressRequest).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleProgress", progressRequest), + new HandlersConfiguration.Call("handleProgressAgain", progressRequest)); + } + + @Test + void toolListChanged() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedTools = List.of(McpSchema.Tool.builder().name("tool-1").build(), + McpSchema.Tool.builder().name("tool-2").build()); + + registry.handleToolListChanged("client-1", updatedTools).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleToolListChanged", updatedTools), + new HandlersConfiguration.Call("handleToolListChangedAgain", updatedTools)); + } + + @Test + void promptListChanged() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedTools = List.of( + new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), + new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); + + registry.handlePromptListChanged("client-1", updatedTools).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handlePromptListChanged", updatedTools), + new HandlersConfiguration.Call("handlePromptListChangedAgain", updatedTools)); + } + + @Test + void resourceListChanged() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedResources = List.of( + McpSchema.Resource.builder().name("resource-1").uri("file:///resource/1").build(), + McpSchema.Resource.builder().name("resource-2").uri("file:///resource/2").build()); + + registry.handleResourceListChanged("client-1", updatedResources).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleResourceListChanged", updatedResources), + new HandlersConfiguration.Call("handleResourceListChangedAgain", updatedResources)); + } + + @Test + void supportsNonResolvableTypes() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder + .genericBeanDefinition( + ClientMcpSyncHandlersRegistryTests.ClientCapabilitiesConfiguration.class.getName()) + .getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + } + + @Test + void supportsProxiedClass() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + var beanDefinition = BeanDefinitionBuilder.genericBeanDefinition(Object.class).getBeanDefinition(); + beanDefinition.setAttribute(AutoProxyUtils.ORIGINAL_TARGET_CLASS_ATTRIBUTE, + ClientMcpSyncHandlersRegistryTests.ClientCapabilitiesConfiguration.class); + beanFactory.registerBeanDefinition("myConfig", beanDefinition); + + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + } + + static class ClientCapabilitiesConfiguration { + + @McpElicitation(clients = { "client-1", "client-2" }) + public Mono elicitationHandler1(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + @McpElicitation(clients = { "client-3" }) + public Mono elicitationHandler2(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + } + + static class DoubleElicitationHandlerConfiguration { + + static class First { + + @McpElicitation(clients = { "client-1" }) + public Mono elicitationHandler1(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + } + + static class Second { + + @McpElicitation(clients = { "client-1" }) + public Mono elicitationHandler2(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + } + + static class TwoHandlers { + + @McpElicitation(clients = { "client-1" }) + public Mono elicitationHandler1(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + @McpElicitation(clients = { "client-1" }) + public Mono elicitationHandler2(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + } + + } + + static class DoubleSamplingHandlerConfiguration { + + static class First { + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler1(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + } + + static class Second { + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler2(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + } + + static class TwoHandlers { + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler1(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler2(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + } + + } + + static class HandlersConfiguration { + + private final List calls = new ArrayList<>(); + + HandlersConfiguration() { + } + + List getCalls() { + return Collections.unmodifiableList(this.calls); + } + + @McpElicitation(clients = { "client-1" }) + Mono elicitationHandler(McpSchema.ElicitRequest request) { + return Mono.just(McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build()); + } + + @McpSampling(clients = { "client-1" }) + Mono samplingHandler(McpSchema.CreateMessageRequest request) { + return Mono.just(McpSchema.CreateMessageResult.builder() + .message(((McpSchema.TextContent) request.messages().get(0).content()).text()) + .model("testgpt-42.5") + .build()); + } + + @McpLogging(clients = { "client-1" }) + Mono handleLoggingMessage(McpSchema.LoggingMessageNotification notification) { + this.calls.add(new Call("handleLoggingMessage", notification)); + return Mono.empty(); + } + + @McpLogging(clients = { "client-1" }) + Mono handleLoggingMessageAgain(McpSchema.LoggingMessageNotification notification) { + this.calls.add(new Call("handleLoggingMessageAgain", notification)); + return Mono.empty(); + } + + @McpProgress(clients = { "client-1" }) + Mono handleProgress(McpSchema.ProgressNotification notification) { + this.calls.add(new Call("handleProgress", notification)); + return Mono.empty(); + } + + @McpProgress(clients = { "client-1" }) + Mono handleProgressAgain(McpSchema.ProgressNotification notification) { + this.calls.add(new Call("handleProgressAgain", notification)); + return Mono.empty(); + } + + @McpToolListChanged(clients = { "client-1" }) + Mono handleToolListChanged(List updatedTools) { + this.calls.add(new Call("handleToolListChanged", updatedTools)); + return Mono.empty(); + } + + @McpToolListChanged(clients = { "client-1" }) + Mono handleToolListChangedAgain(List updatedTools) { + this.calls.add(new Call("handleToolListChangedAgain", updatedTools)); + return Mono.empty(); + } + + @McpPromptListChanged(clients = { "client-1" }) + Mono handlePromptListChanged(List updatedPrompts) { + this.calls.add(new Call("handlePromptListChanged", updatedPrompts)); + return Mono.empty(); + } + + @McpPromptListChanged(clients = { "client-1" }) + Mono handlePromptListChangedAgain(List updatedPrompts) { + this.calls.add(new Call("handlePromptListChangedAgain", updatedPrompts)); + return Mono.empty(); + } + + @McpResourceListChanged(clients = { "client-1" }) + Mono handleResourceListChanged(List updatedResources) { + this.calls.add(new Call("handleResourceListChanged", updatedResources)); + return Mono.empty(); + } + + @McpResourceListChanged(clients = { "client-1" }) + Mono handleResourceListChangedAgain(List updatedResources) { + this.calls.add(new Call("handleResourceListChangedAgain", updatedResources)); + return Mono.empty(); + } + + // Record calls made to this object + record Call(String name, Object callRequest) { + } + + } + +} diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java new file mode 100644 index 00000000000..9b75acf8aa6 --- /dev/null +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java @@ -0,0 +1,509 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; + +import org.springframework.aop.framework.autoproxy.AutoProxyUtils; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + +class ClientMcpSyncHandlersRegistryTests { + + @Test + void getCapabilitiesPerClient() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(ClientCapabilitiesConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + assertThat(registry.getCapabilities("client-2").elicitation()).isNotNull(); + assertThat(registry.getCapabilities("client-3").elicitation()).isNotNull(); + + assertThat(registry.getCapabilities("client-1").sampling()).isNotNull(); + assertThat(registry.getCapabilities("client-2").sampling()).isNull(); + assertThat(registry.getCapabilities("client-3").sampling()).isNull(); + + assertThat(registry.getCapabilities("client-1").roots()).isNull(); + assertThat(registry.getCapabilities("client-2").roots()).isNull(); + assertThat(registry.getCapabilities("client-3").roots()).isNull(); + + assertThat(registry.getCapabilities("client-1").experimental()).isNull(); + assertThat(registry.getCapabilities("client-2").experimental()).isNull(); + assertThat(registry.getCapabilities("client-3").experimental()).isNull(); + + assertThat(registry.getCapabilities("client-unknown").sampling()).isNull(); + assertThat(registry.getCapabilities("client-unknown").elicitation()).isNull(); + assertThat(registry.getCapabilities("client-unknown").roots()).isNull(); + } + + @Test + void twoHandlersElicitation() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("firstConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.First.class) + .getBeanDefinition()); + beanFactory.registerBeanDefinition("secondConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.Second.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 elicitation handlers for client [client-1], found in bean with names [firstConfig, secondConfig]. Only one @McpElicitation handler is allowed per client"); + } + + @Test + void twoHandlersSameBeanElicitation() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("elicitationConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.TwoHandlers.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 elicitation handlers for client [client-1], found in bean with names [elicitationConfig]. Only one @McpElicitation handler is allowed per client"); + } + + @Test + void twoHandlersSampling() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("firstConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.First.class) + .getBeanDefinition()); + beanFactory.registerBeanDefinition("secondConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.Second.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 sampling handlers for client [client-1], found in bean with names [firstConfig, secondConfig]. Only one @McpSampling handler is allowed per client"); + } + + @Test + void twoHandlersSameBeanSampling() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("samplingConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.TwoHandlers.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 sampling handlers for client [client-1], found in bean with names [samplingConfig]. Only one @McpSampling handler is allowed per client"); + } + + @Test + void elicitation() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.ElicitRequest.builder().message("Elicit request").progressToken("token-12345").build(); + var response = registry.handleElicitation("client-1", request); + + assertThat(response.content()).hasSize(1).containsEntry("message", "Elicit request"); + assertThat(response.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + } + + @Test + void missingElicitationHandler() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.ElicitRequest.builder().message("Elicit request").progressToken("token-12345").build(); + assertThatThrownBy(() -> registry.handleElicitation("client-unknown", request)) + .hasMessage("Elicitation not supported") + .asInstanceOf(type(McpError.class)) + .extracting(McpError::getJsonRpcError) + .satisfies(error -> assertThat(error.data()) + .isEqualTo(Map.of("reason", "Client does not have elicitation capability"))) + .satisfies(error -> assertThat(error.code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND)); + } + + @Test + void sampling() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Tell a joke")))) + .build(); + var response = registry.handleSampling("client-1", request); + + assertThat(response.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(response.model()).isEqualTo("testgpt-42.5"); + McpSchema.TextContent content = (McpSchema.TextContent) response.content(); + assertThat(content.text()).isEqualTo("Tell a joke"); + } + + @Test + void missingSamplingHandler() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Tell a joke")))) + .build(); + assertThatThrownBy(() -> registry.handleSampling("client-unknown", request)) + .hasMessage("Sampling not supported") + .asInstanceOf(type(McpError.class)) + .extracting(McpError::getJsonRpcError) + .satisfies(error -> assertThat(error.data()) + .isEqualTo(Map.of("reason", "Client does not have sampling capability"))) + .satisfies(error -> assertThat(error.code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND)); + } + + @Test + void logging() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + var logRequest = McpSchema.LoggingMessageNotification.builder() + .data("Hello world") + .logger("log-me") + .level(McpSchema.LoggingLevel.INFO) + .build(); + + registry.handleLogging("client-1", logRequest); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleLoggingMessage", logRequest), + new HandlersConfiguration.Call("handleLoggingMessageAgain", logRequest)); + } + + @Test + void progress() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + var progressRequest = new McpSchema.ProgressNotification("progress-12345", 13.37, 100., "progressing ..."); + + registry.handleProgress("client-1", progressRequest); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleProgress", progressRequest), + new HandlersConfiguration.Call("handleProgressAgain", progressRequest)); + } + + @Test + void toolListChanged() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedTools = List.of(McpSchema.Tool.builder().name("tool-1").build(), + McpSchema.Tool.builder().name("tool-2").build()); + + registry.handleToolListChanged("client-1", updatedTools); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleToolListChanged", updatedTools), + new HandlersConfiguration.Call("handleToolListChangedAgain", updatedTools)); + } + + @Test + void promptListChanged() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedPrompts = List.of( + new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), + new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); + + registry.handlePromptListChanged("client-1", updatedPrompts); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handlePromptListChanged", updatedPrompts), + new HandlersConfiguration.Call("handlePromptListChangedAgain", updatedPrompts)); + } + + @Test + void resourceListChanged() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedResources = List.of( + McpSchema.Resource.builder().name("resource-1").uri("file:///resource/1").build(), + McpSchema.Resource.builder().name("resource-2").uri("file:///resource/2").build()); + + registry.handleResourceListChanged("client-1", updatedResources); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleResourceListChanged", updatedResources), + new HandlersConfiguration.Call("handleResourceListChangedAgain", updatedResources)); + } + + @Test + void supportsNonResolvableTypes() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(ClientCapabilitiesConfiguration.class.getName()) + .getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + } + + @Test + void supportsProxiedClass() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + var beanDefinition = BeanDefinitionBuilder.genericBeanDefinition(Object.class).getBeanDefinition(); + beanDefinition.setAttribute(AutoProxyUtils.ORIGINAL_TARGET_CLASS_ATTRIBUTE, + ClientCapabilitiesConfiguration.class); + beanFactory.registerBeanDefinition("myConfig", beanDefinition); + + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + } + + static class ClientCapabilitiesConfiguration { + + @McpElicitation(clients = { "client-1", "client-2" }) + public McpSchema.ElicitResult elicitationHandler1(McpSchema.ElicitRequest request) { + return null; + } + + @McpElicitation(clients = { "client-3" }) + public McpSchema.ElicitResult elicitationHandler2(McpSchema.ElicitRequest request) { + return null; + } + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest request) { + return null; + } + + } + + static class DoubleElicitationHandlerConfiguration { + + static class First { + + @McpElicitation(clients = { "client-1" }) + public McpSchema.ElicitResult elicitationHandler1(McpSchema.ElicitRequest request) { + return null; + } + + } + + static class Second { + + @McpElicitation(clients = { "client-1" }) + public McpSchema.ElicitResult elicitationHandler2(McpSchema.ElicitRequest request) { + return null; + } + + } + + static class TwoHandlers { + + @McpElicitation(clients = { "client-1" }) + public McpSchema.ElicitResult elicitationHandler1(McpSchema.ElicitRequest request) { + return null; + } + + @McpElicitation(clients = { "client-1" }) + public McpSchema.ElicitResult elicitationHandler2(McpSchema.ElicitRequest request) { + return null; + } + + } + + } + + static class DoubleSamplingHandlerConfiguration { + + static class First { + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler1(McpSchema.CreateMessageRequest request) { + return null; + } + + } + + static class Second { + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler2(McpSchema.CreateMessageRequest request) { + return null; + } + + } + + static class TwoHandlers { + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler1(McpSchema.CreateMessageRequest request) { + return null; + } + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler2(McpSchema.CreateMessageRequest request) { + return null; + } + + } + + } + + static class HandlersConfiguration { + + private final List calls = new ArrayList<>(); + + HandlersConfiguration() { + } + + List getCalls() { + return Collections.unmodifiableList(this.calls); + } + + @McpElicitation(clients = { "client-1" }) + McpSchema.ElicitResult elicitationHandler(McpSchema.ElicitRequest request) { + return McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build(); + } + + @McpSampling(clients = { "client-1" }) + McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest request) { + return McpSchema.CreateMessageResult.builder() + .message(((McpSchema.TextContent) request.messages().get(0).content()).text()) + .model("testgpt-42.5") + .build(); + } + + @McpLogging(clients = { "client-1" }) + void handleLoggingMessage(McpSchema.LoggingMessageNotification notification) { + this.calls.add(new Call("handleLoggingMessage", notification)); + } + + @McpLogging(clients = { "client-1" }) + void handleLoggingMessageAgain(McpSchema.LoggingMessageNotification notification) { + this.calls.add(new Call("handleLoggingMessageAgain", notification)); + } + + @McpProgress(clients = { "client-1" }) + void handleProgress(McpSchema.ProgressNotification notification) { + this.calls.add(new Call("handleProgress", notification)); + } + + @McpProgress(clients = { "client-1" }) + void handleProgressAgain(McpSchema.ProgressNotification notification) { + this.calls.add(new Call("handleProgressAgain", notification)); + } + + @McpToolListChanged(clients = { "client-1" }) + void handleToolListChanged(List updatedTools) { + this.calls.add(new Call("handleToolListChanged", updatedTools)); + } + + @McpToolListChanged(clients = { "client-1" }) + void handleToolListChangedAgain(List updatedTools) { + this.calls.add(new Call("handleToolListChangedAgain", updatedTools)); + } + + @McpPromptListChanged(clients = { "client-1" }) + void handlePromptListChanged(List updatedPrompts) { + this.calls.add(new Call("handlePromptListChanged", updatedPrompts)); + } + + @McpPromptListChanged(clients = { "client-1" }) + void handlePromptListChangedAgain(List updatedPrompts) { + this.calls.add(new Call("handlePromptListChangedAgain", updatedPrompts)); + } + + @McpResourceListChanged(clients = { "client-1" }) + void handleResourceListChanged(List updatedResources) { + this.calls.add(new Call("handleResourceListChanged", updatedResources)); + } + + @McpResourceListChanged(clients = { "client-1" }) + void handleResourceListChangedAgain(List updatedResources) { + this.calls.add(new Call("handleResourceListChangedAgain", updatedResources)); + } + + // Record calls made to this object + record Call(String name, Object callRequest) { + } + + } + +}