Skip to content

Commit b04c30b

Browse files
134130tzolov
authored andcommitted
feat: Add progress notification support for MCP operations (#407)
Implement progress tracking for long-running operations with: - New ProgressNotification schema and client/server support - Progress consumer handlers in sync/async client builders - Server exchange methods for sending progress updates - Comprehensive integration tests - Backwards compatibility maintained - Add additional tests Signed-off-by: Christian Tzolov <[email protected]>
1 parent 87cdaf8 commit b04c30b

File tree

13 files changed

+535
-8
lines changed

13 files changed

+535
-8
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,113 @@ void testLoggingNotification(String clientType) throws InterruptedException {
10161016
mcpServer.close();
10171017
}
10181018

1019+
// ---------------------------------------
1020+
// Progress Tests
1021+
// ---------------------------------------
1022+
@ParameterizedTest(name = "{0} : {displayName} ")
1023+
@ValueSource(strings = { "httpclient", "webflux" })
1024+
void testProgressNotification(String clientType) throws InterruptedException {
1025+
int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress
1026+
// token
1027+
CountDownLatch latch = new CountDownLatch(expectedNotificationsCount);
1028+
// Create a list to store received logging notifications
1029+
List<McpSchema.ProgressNotification> receivedNotifications = new CopyOnWriteArrayList<>();
1030+
1031+
var clientBuilder = clientBuilders.get(clientType);
1032+
1033+
// Create server with a tool that sends logging notifications
1034+
McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
1035+
.tool(McpSchema.Tool.builder()
1036+
.name("progress-test")
1037+
.description("Test progress notifications")
1038+
.inputSchema(emptyJsonSchema)
1039+
.build())
1040+
.callHandler((exchange, request) -> {
1041+
1042+
// Create and send notifications
1043+
var progressToken = (String) request.meta().get("progressToken");
1044+
1045+
return exchange
1046+
.progressNotification(
1047+
new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started"))
1048+
.then(exchange.progressNotification(
1049+
new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data")))
1050+
.then(// Send a progress notification with another progress value
1051+
// should
1052+
exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token",
1053+
0.0, 1.0, "Another processing started")))
1054+
.then(exchange.progressNotification(
1055+
new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed")))
1056+
.thenReturn(new CallToolResult(("Progress test completed"), false));
1057+
})
1058+
.build();
1059+
1060+
var mcpServer = McpServer.async(mcpServerTransportProvider)
1061+
.serverInfo("test-server", "1.0.0")
1062+
.capabilities(ServerCapabilities.builder().tools(true).build())
1063+
.tools(tool)
1064+
.build();
1065+
1066+
try (
1067+
// Create client with progress notification handler
1068+
var mcpClient = clientBuilder.progressConsumer(notification -> {
1069+
receivedNotifications.add(notification);
1070+
latch.countDown();
1071+
}).build()) {
1072+
1073+
// Initialize client
1074+
InitializeResult initResult = mcpClient.initialize();
1075+
assertThat(initResult).isNotNull();
1076+
1077+
// Call the tool that sends progress notifications
1078+
McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder()
1079+
.name("progress-test")
1080+
.meta(Map.of("progressToken", "test-progress-token"))
1081+
.build();
1082+
CallToolResult result = mcpClient.callTool(callToolRequest);
1083+
assertThat(result).isNotNull();
1084+
assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
1085+
assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed");
1086+
1087+
assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue();
1088+
1089+
// Should have received 3 notifications
1090+
assertThat(receivedNotifications).hasSize(expectedNotificationsCount);
1091+
1092+
Map<String, McpSchema.ProgressNotification> notificationMap = receivedNotifications.stream()
1093+
.collect(Collectors.toMap(n -> n.message(), n -> n));
1094+
1095+
// First notification should be 0.0/1.0 progress
1096+
assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token");
1097+
assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0);
1098+
assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0);
1099+
assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started");
1100+
1101+
// Second notification should be 0.5/1.0 progress
1102+
assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token");
1103+
assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5);
1104+
assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0);
1105+
assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data");
1106+
1107+
// Third notification should be another progress token with 0.0/1.0 progress
1108+
assertThat(notificationMap.get("Another processing started").progressToken())
1109+
.isEqualTo("another-progress-token");
1110+
assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0);
1111+
assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0);
1112+
assertThat(notificationMap.get("Another processing started").message())
1113+
.isEqualTo("Another processing started");
1114+
1115+
// Fourth notification should be 1.0/1.0 progress
1116+
assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token");
1117+
assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0);
1118+
assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0);
1119+
assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed");
1120+
}
1121+
finally {
1122+
mcpServer.close();
1123+
}
1124+
}
1125+
10191126
// ---------------------------------------
10201127
// Completion Tests
10211128
// ---------------------------------------

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212

1313
import java.time.Duration;
1414
import java.util.ArrayList;
15+
import java.util.List;
1516
import java.util.Map;
1617
import java.util.Objects;
18+
import java.util.concurrent.CopyOnWriteArrayList;
1719
import java.util.concurrent.atomic.AtomicBoolean;
1820
import java.util.concurrent.atomic.AtomicInteger;
1921
import java.util.concurrent.atomic.AtomicReference;
@@ -49,6 +51,7 @@
4951
import io.modelcontextprotocol.spec.McpTransport;
5052
import reactor.core.publisher.Flux;
5153
import reactor.core.publisher.Mono;
54+
import reactor.core.publisher.Sinks;
5255
import reactor.test.StepVerifier;
5356

5457
/**
@@ -420,7 +423,7 @@ void testListAllPromptsReturnsImmutableList() {
420423
.consumeNextWith(result -> {
421424
assertThat(result.prompts()).isNotNull();
422425
// Verify that the returned list is immutable
423-
assertThatThrownBy(() -> result.prompts().add(new Prompt("test", "Test", "test", null)))
426+
assertThatThrownBy(() -> result.prompts().add(new Prompt("test", "test", "test", null)))
424427
.isInstanceOf(UnsupportedOperationException.class);
425428
})
426429
.verifyComplete();
@@ -792,4 +795,39 @@ void testSampling() {
792795
});
793796
}
794797

798+
// ---------------------------------------
799+
// Progress Notification Tests
800+
// ---------------------------------------
801+
802+
@Test
803+
void testProgressConsumer() {
804+
Sinks.Many<McpSchema.ProgressNotification> sink = Sinks.many().unicast().onBackpressureBuffer();
805+
List<McpSchema.ProgressNotification> receivedNotifications = new CopyOnWriteArrayList<>();
806+
807+
withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> {
808+
receivedNotifications.add(notification);
809+
sink.tryEmitNext(notification);
810+
return Mono.empty();
811+
}), client -> {
812+
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
813+
814+
// Call a tool that sends progress notifications
815+
CallToolRequest request = CallToolRequest.builder()
816+
.name("longRunningOperation")
817+
.arguments(Map.of("duration", 1, "steps", 2))
818+
.progressToken("test-token")
819+
.build();
820+
821+
StepVerifier.create(client.callTool(request)).consumeNextWith(result -> {
822+
assertThat(result).isNotNull();
823+
}).verifyComplete();
824+
825+
// Use StepVerifier to verify the progress notifications via the sink
826+
StepVerifier.create(sink.asFlux()).expectNextCount(2).thenCancel().verify(Duration.ofSeconds(3));
827+
828+
assertThat(receivedNotifications).hasSize(2);
829+
assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token");
830+
});
831+
}
832+
795833
}

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import java.time.Duration;
1414
import java.util.List;
1515
import java.util.Map;
16+
import java.util.concurrent.CopyOnWriteArrayList;
17+
import java.util.concurrent.CountDownLatch;
18+
import java.util.concurrent.TimeUnit;
1619
import java.util.concurrent.atomic.AtomicBoolean;
1720
import java.util.concurrent.atomic.AtomicInteger;
1821
import java.util.concurrent.atomic.AtomicReference;
@@ -648,4 +651,48 @@ void testSampling() {
648651
});
649652
}
650653

654+
// ---------------------------------------
655+
// Progress Notification Tests
656+
// ---------------------------------------
657+
658+
@Test
659+
void testProgressConsumer() {
660+
AtomicInteger progressNotificationCount = new AtomicInteger(0);
661+
List<McpSchema.ProgressNotification> receivedNotifications = new CopyOnWriteArrayList<>();
662+
CountDownLatch latch = new CountDownLatch(2);
663+
664+
withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> {
665+
System.out.println("Received progress notification: " + notification);
666+
receivedNotifications.add(notification);
667+
progressNotificationCount.incrementAndGet();
668+
latch.countDown();
669+
}), client -> {
670+
client.initialize();
671+
672+
// Call a tool that sends progress notifications
673+
CallToolRequest request = CallToolRequest.builder()
674+
.name("longRunningOperation")
675+
.arguments(Map.of("duration", 1, "steps", 2))
676+
.progressToken("test-token")
677+
.build();
678+
679+
CallToolResult result = client.callTool(request);
680+
681+
assertThat(result).isNotNull();
682+
683+
try {
684+
// Wait for progress notifications to be processed
685+
latch.await(3, TimeUnit.SECONDS);
686+
}
687+
catch (InterruptedException e) {
688+
e.printStackTrace();
689+
}
690+
691+
assertThat(progressNotificationCount.get()).isEqualTo(2);
692+
693+
assertThat(receivedNotifications).isNotEmpty();
694+
assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token");
695+
});
696+
}
697+
651698
}

mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ public class McpAsyncClient {
100100
public static final TypeReference<LoggingMessageNotification> LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() {
101101
};
102102

103+
public static final TypeReference<McpSchema.ProgressNotification> PROGRESS_NOTIFICATION_TYPE_REF = new TypeReference<>() {
104+
};
105+
103106
/**
104107
* Client capabilities.
105108
*/
@@ -253,6 +256,16 @@ public class McpAsyncClient {
253256
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
254257
asyncLoggingNotificationHandler(loggingConsumersFinal));
255258

259+
// Utility Progress Notification
260+
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumersFinal = new ArrayList<>();
261+
progressConsumersFinal
262+
.add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification)));
263+
if (!Utils.isEmpty(features.progressConsumers())) {
264+
progressConsumersFinal.addAll(features.progressConsumers());
265+
}
266+
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS,
267+
asyncProgressNotificationHandler(progressConsumersFinal));
268+
256269
this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo,
257270
List.of(McpSchema.LATEST_PROTOCOL_VERSION), initializationTimeout,
258271
ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers,
@@ -846,6 +859,28 @@ public Mono<Void> setLoggingLevel(LoggingLevel loggingLevel) {
846859
});
847860
}
848861

862+
/**
863+
* Create a notification handler for progress notifications from the server. This
864+
* handler automatically distributes progress notifications to all registered
865+
* consumers.
866+
* @param progressConsumers List of consumers that will be notified when a progress
867+
* message is received. Each consumer receives the progress notification.
868+
* @return A NotificationHandler that processes progress notifications by distributing
869+
* the message to all registered consumers
870+
*/
871+
private NotificationHandler asyncProgressNotificationHandler(
872+
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers) {
873+
874+
return params -> {
875+
McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params,
876+
PROGRESS_NOTIFICATION_TYPE_REF);
877+
878+
return Flux.fromIterable(progressConsumers)
879+
.flatMap(consumer -> consumer.apply(progressNotification))
880+
.then();
881+
};
882+
}
883+
849884
/**
850885
* This method is package-private and used for test only. Should not be called by user
851886
* code.

0 commit comments

Comments
 (0)