|
19 | 19 | import java.util.List;
|
20 | 20 | import java.util.Map;
|
21 | 21 | import java.util.concurrent.ConcurrentHashMap;
|
| 22 | +import java.util.stream.Stream; |
22 | 23 |
|
| 24 | +import io.modelcontextprotocol.client.McpClient.AsyncSpec; |
23 | 25 | import org.slf4j.Logger;
|
24 | 26 | import org.slf4j.LoggerFactory;
|
25 | 27 | import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification;
|
|
29 | 31 | import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification;
|
30 | 32 | import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification;
|
31 | 33 | import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification;
|
| 34 | + |
32 | 35 | import org.springframework.ai.mcp.customizer.McpAsyncClientCustomizer;
|
33 | 36 | import org.springframework.util.CollectionUtils;
|
34 | 37 |
|
35 |
| -import io.modelcontextprotocol.client.McpClient.AsyncSpec; |
36 |
| - |
37 | 38 | /**
|
38 | 39 | * @author Christian Tzolov
|
39 | 40 | */
|
@@ -80,85 +81,99 @@ public McpAsyncAnnotationCustomizer(List<AsyncSamplingSpecification> asyncSampli
|
80 | 81 | @Override
|
81 | 82 | public void customize(String name, AsyncSpec clientSpec) {
|
82 | 83 |
|
83 |
| - if (!CollectionUtils.isEmpty(asyncElicitationSpecifications)) { |
| 84 | + if (!CollectionUtils.isEmpty(this.asyncElicitationSpecifications)) { |
84 | 85 | this.asyncElicitationSpecifications.forEach(elicitationSpec -> {
|
85 |
| - if (elicitationSpec.clientId().equalsIgnoreCase(name)) { |
| 86 | + Stream.of(elicitationSpec.clients()).forEach(clientId -> { |
| 87 | + if (clientId.equalsIgnoreCase(name)) { |
86 | 88 |
|
87 |
| - // Check if client already has an elicitation spec |
88 |
| - if (clientElicitationSpecs.containsKey(name)) { |
89 |
| - throw new IllegalArgumentException("Client '" + name |
90 |
| - + "' already has an elicitationSpec registered. Only one elicitationSpec is allowed per client."); |
91 |
| - } |
| 89 | + // Check if client already has an elicitation spec |
| 90 | + if (this.clientElicitationSpecs.containsKey(name)) { |
| 91 | + throw new IllegalArgumentException("Client '" + name |
| 92 | + + "' already has an elicitationSpec registered. Only one elicitationSpec is allowed per client."); |
| 93 | + } |
92 | 94 |
|
93 |
| - clientElicitationSpecs.put(name, Boolean.TRUE); |
94 |
| - clientSpec.elicitation(elicitationSpec.elicitationHandler()); |
| 95 | + this.clientElicitationSpecs.put(name, Boolean.TRUE); |
| 96 | + clientSpec.elicitation(elicitationSpec.elicitationHandler()); |
95 | 97 |
|
96 |
| - logger.info("Registered elicitationSpec for client '{}'.", name); |
| 98 | + logger.info("Registered elicitationSpec for client '{}'.", name); |
97 | 99 |
|
98 |
| - } |
| 100 | + } |
| 101 | + }); |
99 | 102 | });
|
100 | 103 | }
|
101 | 104 |
|
102 |
| - if (!CollectionUtils.isEmpty(asyncSamplingSpecifications)) { |
| 105 | + if (!CollectionUtils.isEmpty(this.asyncSamplingSpecifications)) { |
103 | 106 | this.asyncSamplingSpecifications.forEach(samplingSpec -> {
|
104 |
| - if (samplingSpec.clientId().equalsIgnoreCase(name)) { |
| 107 | + Stream.of(samplingSpec.clients()).forEach(clientId -> { |
| 108 | + if (clientId.equalsIgnoreCase(name)) { |
105 | 109 |
|
106 |
| - // Check if client already has a sampling spec |
107 |
| - if (clientSamplingSpecs.containsKey(name)) { |
108 |
| - throw new IllegalArgumentException("Client '" + name |
109 |
| - + "' already has a samplingSpec registered. Only one samplingSpec is allowed per client."); |
110 |
| - } |
111 |
| - clientSamplingSpecs.put(name, Boolean.TRUE); |
| 110 | + // Check if client already has a sampling spec |
| 111 | + if (this.clientSamplingSpecs.containsKey(name)) { |
| 112 | + throw new IllegalArgumentException("Client '" + name |
| 113 | + + "' already has a samplingSpec registered. Only one samplingSpec is allowed per client."); |
| 114 | + } |
| 115 | + this.clientSamplingSpecs.put(name, Boolean.TRUE); |
112 | 116 |
|
113 |
| - clientSpec.sampling(samplingSpec.samplingHandler()); |
| 117 | + clientSpec.sampling(samplingSpec.samplingHandler()); |
114 | 118 |
|
115 |
| - logger.info("Registered samplingSpec for client '{}'.", name); |
116 |
| - } |
| 119 | + logger.info("Registered samplingSpec for client '{}'.", name); |
| 120 | + } |
| 121 | + }); |
117 | 122 | });
|
118 | 123 | }
|
119 | 124 |
|
120 |
| - if (!CollectionUtils.isEmpty(asyncLoggingSpecifications)) { |
| 125 | + if (!CollectionUtils.isEmpty(this.asyncLoggingSpecifications)) { |
121 | 126 | this.asyncLoggingSpecifications.forEach(loggingSpec -> {
|
122 |
| - if (loggingSpec.clientId().equalsIgnoreCase(name)) { |
123 |
| - clientSpec.loggingConsumer(loggingSpec.loggingHandler()); |
124 |
| - logger.info("Registered loggingSpec for client '{}'.", name); |
125 |
| - } |
| 127 | + Stream.of(loggingSpec.clients()).forEach(clientId -> { |
| 128 | + if (clientId.equalsIgnoreCase(name)) { |
| 129 | + clientSpec.loggingConsumer(loggingSpec.loggingHandler()); |
| 130 | + logger.info("Registered loggingSpec for client '{}'.", name); |
| 131 | + } |
| 132 | + }); |
126 | 133 | });
|
127 | 134 | }
|
128 | 135 |
|
129 |
| - if (!CollectionUtils.isEmpty(asyncProgressSpecifications)) { |
| 136 | + if (!CollectionUtils.isEmpty(this.asyncProgressSpecifications)) { |
130 | 137 | this.asyncProgressSpecifications.forEach(progressSpec -> {
|
131 |
| - if (progressSpec.clientId().equalsIgnoreCase(name)) { |
132 |
| - clientSpec.progressConsumer(progressSpec.progressHandler()); |
133 |
| - logger.info("Registered progressSpec for client '{}'.", name); |
134 |
| - } |
| 138 | + Stream.of(progressSpec.clients()).forEach(clientId -> { |
| 139 | + if (clientId.equalsIgnoreCase(name)) { |
| 140 | + clientSpec.progressConsumer(progressSpec.progressHandler()); |
| 141 | + logger.info("Registered progressSpec for client '{}'.", name); |
| 142 | + } |
| 143 | + }); |
135 | 144 | });
|
136 | 145 | }
|
137 | 146 |
|
138 |
| - if (!CollectionUtils.isEmpty(asyncToolListChangedSpecifications)) { |
| 147 | + if (!CollectionUtils.isEmpty(this.asyncToolListChangedSpecifications)) { |
139 | 148 | this.asyncToolListChangedSpecifications.forEach(toolListChangedSpec -> {
|
140 |
| - if (toolListChangedSpec.clientId().equalsIgnoreCase(name)) { |
141 |
| - clientSpec.toolsChangeConsumer(toolListChangedSpec.toolListChangeHandler()); |
142 |
| - logger.info("Registered toolListChangedSpec for client '{}'.", name); |
143 |
| - } |
| 149 | + Stream.of(toolListChangedSpec.clients()).forEach(clientId -> { |
| 150 | + if (clientId.equalsIgnoreCase(name)) { |
| 151 | + clientSpec.toolsChangeConsumer(toolListChangedSpec.toolListChangeHandler()); |
| 152 | + logger.info("Registered toolListChangedSpec for client '{}'.", name); |
| 153 | + } |
| 154 | + }); |
144 | 155 | });
|
145 | 156 | }
|
146 | 157 |
|
147 |
| - if (!CollectionUtils.isEmpty(asyncResourceListChangedSpecifications)) { |
| 158 | + if (!CollectionUtils.isEmpty(this.asyncResourceListChangedSpecifications)) { |
148 | 159 | this.asyncResourceListChangedSpecifications.forEach(resourceListChangedSpec -> {
|
149 |
| - if (resourceListChangedSpec.clientId().equalsIgnoreCase(name)) { |
150 |
| - clientSpec.resourcesChangeConsumer(resourceListChangedSpec.resourceListChangeHandler()); |
151 |
| - logger.info("Registered resourceListChangedSpec for client '{}'.", name); |
152 |
| - } |
| 160 | + Stream.of(resourceListChangedSpec.clients()).forEach(clientId -> { |
| 161 | + if (clientId.equalsIgnoreCase(name)) { |
| 162 | + clientSpec.resourcesChangeConsumer(resourceListChangedSpec.resourceListChangeHandler()); |
| 163 | + logger.info("Registered resourceListChangedSpec for client '{}'.", name); |
| 164 | + } |
| 165 | + }); |
153 | 166 | });
|
154 | 167 | }
|
155 | 168 |
|
156 |
| - if (!CollectionUtils.isEmpty(asyncPromptListChangedSpecifications)) { |
| 169 | + if (!CollectionUtils.isEmpty(this.asyncPromptListChangedSpecifications)) { |
157 | 170 | this.asyncPromptListChangedSpecifications.forEach(promptListChangedSpec -> {
|
158 |
| - if (promptListChangedSpec.clientId().equalsIgnoreCase(name)) { |
159 |
| - clientSpec.promptsChangeConsumer(promptListChangedSpec.promptListChangeHandler()); |
160 |
| - logger.info("Registered promptListChangedSpec for client '{}'.", name); |
161 |
| - } |
| 171 | + Stream.of(promptListChangedSpec.clients()).forEach(clientId -> { |
| 172 | + if (clientId.equalsIgnoreCase(name)) { |
| 173 | + clientSpec.promptsChangeConsumer(promptListChangedSpec.promptListChangeHandler()); |
| 174 | + logger.info("Registered promptListChangedSpec for client '{}'.", name); |
| 175 | + } |
| 176 | + }); |
162 | 177 | });
|
163 | 178 | }
|
164 | 179 | }
|
|
0 commit comments