Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.slf4j.LoggerFactory;
import reactor.core.Disposable;
import reactor.core.Disposables;
import reactor.core.Exceptions;
import reactor.core.publisher.Mono;

import java.util.Optional;
Expand Down Expand Up @@ -77,8 +78,36 @@ public void close() {

@Override
public Mono<Void> closeGracefully() {
return Mono.from(this.onClose.apply(this.sessionId.get()))
.then(Mono.fromRunnable(this.openConnections::dispose));
return Mono.defer(() -> {
final String sessionId = this.sessionId.get();

final AtomicReference<Throwable> primary = new AtomicReference<>(null);

// Subscribe to onClose publisher and capture any error
return Mono.from(this.onClose.apply(sessionId)).onErrorResume(err -> {
primary.set(err);
return Mono.empty();
})
// Always dispose openConnections
.then(Mono.defer(() -> {
try {
this.openConnections.dispose();
}
catch (Throwable disposeEx) {
if (primary.get() != null) {
primary.get().addSuppressed(disposeEx);
}
else {
primary.set(disposeEx);
}
}

// Re-emit the original error (with suppressed dispose error),
// complete
Throwable throwable = primary.get();
return (throwable == null) ? Mono.empty() : Mono.error(Exceptions.propagate(throwable));
}));
});
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package io.modelcontextprotocol.spec;

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.reactivestreams.Publisher;
import org.springframework.util.ReflectionUtils;
import reactor.core.Disposable;
import reactor.core.publisher.Mono;

import java.lang.reflect.Field;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Tests for {@link DefaultMcpTransportSession}.
*
* @author Phani Pemmaraju
*/
class DefaultMcpTransportSessionTests {

/** Minimal Disposable to flag that dispose() was called. */
static final class FlagDisposable implements Disposable {

final AtomicBoolean disposed = new AtomicBoolean(false);

@Override
public void dispose() {
disposed.set(true);
}

@Override
public boolean isDisposed() {
return disposed.get();
}

}

@Test
void closeGracefully_disposes_when_onClose_throws() {
@SuppressWarnings("unchecked")
Function<String, Publisher<Void>> onClose = Mockito.mock(Function.class);
Mockito.when(onClose.apply(Mockito.any())).thenReturn(Mono.error(new RuntimeException("runtime-exception")));

// construct session with required ctor
var session = new DefaultMcpTransportSession(onClose);

// seed session id
setField(session, "sessionId", new AtomicReference<>("sessionId-123"));

// get the existing final composite and add a child flag-disposable
Disposable.Composite composite = (Disposable.Composite) getField(session, "openConnections");
FlagDisposable flag = new FlagDisposable();
composite.add(flag);

// act + assert: original onClose error is propagated
assertThatThrownBy(() -> session.closeGracefully().block()).isInstanceOf(RuntimeException.class)
.hasMessageContaining("runtime-exception");

// and the child disposable was disposed => proves composite.dispose() executed
assertThat(flag.isDisposed()).isTrue();
}

@Test
void closeGracefully_propagates_onClose_error_and_disposes_children() {
// onClose fails again
@SuppressWarnings("unchecked")
Function<String, Publisher<Void>> onClose = Mockito.mock(Function.class);
Mockito.when(onClose.apply(Mockito.any())).thenReturn(Mono.error(new RuntimeException("runtime-exception")));

var session = new DefaultMcpTransportSession(onClose);
setField(session, "sessionId", new AtomicReference<>("sessionId-xyz"));

Disposable.Composite composite = (Disposable.Composite) getField(session, "openConnections");
FlagDisposable a = new FlagDisposable();
FlagDisposable b = new FlagDisposable();
composite.add(a);
composite.add(b);

Throwable thrown = Assertions.catchThrowable(() -> session.closeGracefully().block());

// primary error is from onClose
assertThat(thrown).isInstanceOf(RuntimeException.class).hasMessageContaining("runtime-exception");

// both children disposed
assertThat(a.isDisposed()).isTrue();
assertThat(b.isDisposed()).isTrue();
}

private static void setField(Object target, String fieldName, Object value) {
Field f = ReflectionUtils.findField(target.getClass(), fieldName);
if (f == null)
throw new IllegalArgumentException("No such field: " + fieldName);
ReflectionUtils.makeAccessible(f);
ReflectionUtils.setField(f, target, value);
}

private static Object getField(Object target, String fieldName) {
Field f = ReflectionUtils.findField(target.getClass(), fieldName);
if (f == null)
throw new IllegalArgumentException("No such field: " + fieldName);
ReflectionUtils.makeAccessible(f);
return ReflectionUtils.getField(f, target);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package io.modelcontextprotocol;

import java.time.Duration;
import java.util.Map;
import java.util.stream.Stream;

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.*;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
import reactor.core.publisher.Hooks;
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

@Timeout(15)
public class WebFluxSseCloseGracefullyIntegrationTests extends AbstractMcpClientServerIntegrationTests {

private int port;

private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse";

private static final String DEFAULT_MESSAGE_ENDPOINT = "/mcp/message";

private DisposableServer httpServer;

private WebFluxSseServerTransportProvider mcpServerTransportProvider;

static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
.create(Map.of("important", "value"));

static Stream<Arguments> clientsForTesting() {
return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux"));
}

@Override
protected void prepareClients(int port, String mcpEndpoint) {
clientBuilders
.put("httpclient",
McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build()).requestTimeout(Duration.ofSeconds(10)));

clientBuilders.put("webflux", McpClient
.sync(WebFluxSseClientTransport.builder(org.springframework.web.reactive.function.client.WebClient.builder()
.baseUrl("http://localhost:" + port)).sseEndpoint(CUSTOM_SSE_ENDPOINT).build())
.requestTimeout(Duration.ofSeconds(10)));
}

@Override
protected AsyncSpecification<?> prepareAsyncServerBuilder() {
return McpServer.async(mcpServerTransportProvider);
}

@Override
protected SingleSessionSyncSpecification prepareSyncServerBuilder() {
return McpServer.sync(mcpServerTransportProvider);
}

@BeforeEach
void before() {
// Build the transport provider with BOTH endpoints (message required)
this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder()
.messageEndpoint(DEFAULT_MESSAGE_ENDPOINT)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
.build();

// Wire session factory
prepareSyncServerBuilder().build();

// Bind on ephemeral port and discover the actual port
var httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction());
var adapter = new ReactorHttpHandlerAdapter(httpHandler);
this.httpServer = HttpServer.create().port(0).handle(adapter).bindNow();
this.port = httpServer.port();

// Build clients using the discovered port
prepareClients(this.port, null);

// keep your onErrorDropped suppression if you need it for noisy Reactor paths
Hooks.onErrorDropped(e -> {
});
}

@AfterEach
void after() {
if (httpServer != null)
httpServer.disposeNow();
Hooks.resetOnErrorDropped();
}

@ParameterizedTest(name = "closeGracefully after outage: {0}")
@MethodSource("clientsForTesting")
@DisplayName("closeGracefully() signals failure after server outage (WebFlux/SSE, sync client)")
void closeGracefully_disposes_after_server_unavailable(String clientKey) {
var reactiveClient = io.modelcontextprotocol.client.McpClient
.async(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + this.port))
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build())
.requestTimeout(Duration.ofSeconds(10))
.build();

reactiveClient.initialize().block(Duration.ofSeconds(5));

httpServer.disposeNow();

Assertions.assertThatCode(() -> reactiveClient.closeGracefully().block(Duration.ofSeconds(5)))
.doesNotThrowAnyException();

Assertions.assertThatThrownBy(() -> reactiveClient.initialize().block(Duration.ofSeconds(3)))
.isInstanceOf(Exception.class);

}

}