Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
package io.modelcontextprotocol.common;

import java.util.Map;
import java.util.Optional;

import io.modelcontextprotocol.spec.HttpHeaders;
import io.modelcontextprotocol.util.Assert;

/**
Expand All @@ -28,6 +30,21 @@ public Object get(String key) {
return this.metadata.get(key);
}

@Override
public Optional<String> lastEventId() {
return Optional.ofNullable(metadata.get(HttpHeaders.LAST_EVENT_ID)).map(Object::toString);
}

@Override
public Optional<String> sessionId() {
return Optional.ofNullable(metadata.get(HttpHeaders.MCP_SESSION_ID)).map(Object::toString);
}

@Override
public Optional<String> protocolVersion() {
return Optional.ofNullable(metadata.get(HttpHeaders.PROTOCOL_VERSION)).map(Object::toString);
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@

package io.modelcontextprotocol.common;

import io.modelcontextprotocol.spec.ProtocolVersions;

import java.security.Principal;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

/**
* Context associated with the transport layer. It allows to add transport-level metadata
Expand Down Expand Up @@ -36,11 +42,58 @@ static McpTransportContext create(Map<String, Object> metadata) {
return new DefaultMcpTransportContext(metadata);
}

/**
* Returns a Map with entries for MCP transport concepts such as Protocol version,
* session ID and Last Event ID.
* @param headers Function typically backed by an HTTP Request Headers implementation.
* @return Map with entries for MCP transport concepts such as Protocol version,
* session ID and Last Event ID.
*/
static Map<String, Object> createMetadata(Function<String, String> headers) {
Map<String, Object> metadata = new HashMap<>(3);
metadata.put(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION,
Optional.ofNullable(headers.apply(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION))
.orElse(ProtocolVersions.MCP_2025_03_26));
Optional.ofNullable(headers.apply(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID))
.ifPresent(v -> metadata.put(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID, v));
Optional.ofNullable(headers.apply(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID))
.ifPresent(v -> metadata.put(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID, v));
return metadata;
}

/**
* Extract a value from the context.
* @param key the key under the data is expected
* @return the associated value or {@code null} if missing.
*/
Object get(String key);

/**
* @return The MCP Protocl Version
*/
default Optional<String> protocolVersion() {
return Optional.empty();
}

/**
* @return The Session ID
*/
default Optional<String> sessionId() {
return Optional.empty();
}

/**
* @return The Last Event ID
*/
default Optional<String> lastEventId() {
return Optional.empty();
}

/**
* @return The Principal. it may represent the authenticated user.
*/
default Optional<Principal> principal() {
return Optional.empty();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2025-2025 the original author or authors.
*/
package io.modelcontextprotocol.server.servlet;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import jakarta.servlet.http.HttpServletRequest;

/**
* {@link McpTransportContextExtractor} implementation for {@link HttpServletRequest}.
*/
public class HttpServletRequestMcpTransportContextExtractor
implements McpTransportContextExtractor<HttpServletRequest> {

@Override
public McpTransportContext extract(HttpServletRequest request) {
return McpTransportContext.create(McpTransportContext.createMetadata(request::getHeader));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/*
* Copyright 2025-2025 the original author or authors.
*/
/**
* Classes related with servlet support.
*/
package io.modelcontextprotocol.server.servlet;
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
Expand Down Expand Up @@ -503,8 +504,7 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

private Duration keepAliveInterval;

Expand Down Expand Up @@ -594,7 +594,8 @@ public HttpServletSseServerTransportProvider build() {
}
return new HttpServletSseServerTransportProvider(
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, baseUrl, messageEndpoint, sseEndpoint,
keepAliveInterval, contextExtractor);
keepAliveInterval,
contextExtractor == null ? new HttpServletRequestMcpTransportContextExtractor() : contextExtractor);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.io.PrintWriter;

import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -240,8 +241,7 @@ public static class Builder {

private String mcpEndpoint = "/mcp";

private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

private Builder() {
// used by a static method
Expand Down Expand Up @@ -297,7 +297,8 @@ public Builder contextExtractor(McpTransportContextExtractor<HttpServletRequest>
public HttpServletStatelessServerTransport build() {
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
return new HttpServletStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
mcpEndpoint, contextExtractor);
mcpEndpoint,
contextExtractor == null ? new HttpServletRequestMcpTransportContextExtractor() : contextExtractor);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
/*
* Copyright 2024-2024 the original author or authors.
*/

package io.modelcontextprotocol.server.transport;

import java.io.BufferedReader;
Expand All @@ -13,6 +12,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -769,8 +769,7 @@ public static class Builder {

private boolean disallowDelete = false;

private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

private Duration keepAliveInterval;

Expand Down Expand Up @@ -843,7 +842,8 @@ public HttpServletStreamableServerTransportProvider build() {
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
return new HttpServletStreamableServerTransportProvider(
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete,
contextExtractor, keepAliveInterval);
contextExtractor == null ? new HttpServletRequestMcpTransportContextExtractor() : contextExtractor,
keepAliveInterval);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package io.modelcontextprotocol.common;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;

Expand All @@ -18,6 +19,7 @@
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport;
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
Expand Down Expand Up @@ -91,10 +93,20 @@ public class AsyncServerMcpTransportContextIntegrationTests {
return Mono.just(builder);
};

private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = (HttpServletRequest r) -> {
var headerValue = r.getHeader(HEADER_NAME);
return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
: McpTransportContext.EMPTY;
private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = new HttpServletRequestMcpTransportContextExtractor() {
@Override
public McpTransportContext extract(HttpServletRequest request) {
return McpTransportContext.create(metadata(request));
}

private Map<String, Object> metadata(HttpServletRequest r) {
Map<String, Object> m = new HashMap<>(McpTransportContext.createMetadata(r::getHeader));
var headerValue = r.getHeader(HEADER_NAME);
if (headerValue != null) {
m.put("server-side-header-value", headerValue);
}
return m;
}
};

private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2024-2025 the original author or authors.
*/

package io.modelcontextprotocol.common;

import io.modelcontextprotocol.spec.HttpHeaders;
import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

class DefaultMcpTransportContextTest {

@Test
void protocolVersionNotPresent() {
var ctx = new DefaultMcpTransportContext(Collections.emptyMap());
assertFalse(ctx.protocolVersion().isPresent());
}

@Test
void sessionIdNotPresent() {
var ctx = new DefaultMcpTransportContext(Collections.emptyMap());
assertFalse(ctx.sessionId().isPresent());
}

@Test
void lastEventIdNotPresent() {
var ctx = new DefaultMcpTransportContext(Collections.emptyMap());
assertFalse(ctx.lastEventId().isPresent());
}

@Test
void protocolVersion_returnsProvidedValue() {
var ctx = new DefaultMcpTransportContext(Map.of(HttpHeaders.PROTOCOL_VERSION, "2025-01-01",
HttpHeaders.MCP_SESSION_ID, "session-123", HttpHeaders.LAST_EVENT_ID, "evt-456"));
assertEquals("2025-01-01", ctx.protocolVersion().orElseThrow());
}

@Test
void sessionId_returnsProvidedValue() {
var ctx = new DefaultMcpTransportContext(Map.of(HttpHeaders.PROTOCOL_VERSION, "2025-01-01",
HttpHeaders.MCP_SESSION_ID, "session-abc", HttpHeaders.LAST_EVENT_ID, "evt-456"));
assertEquals("session-abc", ctx.sessionId().orElseThrow());
}

@Test
void lastEventId_returnsProvidedValue() {
var ctx = new DefaultMcpTransportContext(Map.of(HttpHeaders.PROTOCOL_VERSION, "2025-01-01",
HttpHeaders.MCP_SESSION_ID, "session-abc", HttpHeaders.LAST_EVENT_ID, "evt-999"));
assertEquals("evt-999", ctx.lastEventId().orElseThrow());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport;
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
import io.modelcontextprotocol.server.transport.TomcatTestUtil;
import io.modelcontextprotocol.spec.McpSchema;
import jakarta.servlet.Servlet;
import jakarta.servlet.http.HttpServletRequest;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Supplier;
Expand Down Expand Up @@ -71,10 +74,20 @@ public class SyncServerMcpTransportContextIntegrationTests {
}
};

private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = (HttpServletRequest r) -> {
var headerValue = r.getHeader(HEADER_NAME);
return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
: McpTransportContext.EMPTY;
private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = new HttpServletRequestMcpTransportContextExtractor() {
@Override
public McpTransportContext extract(HttpServletRequest request) {
return McpTransportContext.create(metadata(request));
}

private Map<String, Object> metadata(HttpServletRequest r) {
Map<String, Object> m = new HashMap<>(McpTransportContext.createMetadata(r::getHeader));
var headerValue = r.getHeader(HEADER_NAME);
if (headerValue != null) {
m.put("server-side-header-value", headerValue);
}
return m;
}
};

private final BiFunction<McpTransportContext, McpSchema.CallToolRequest, McpSchema.CallToolResult> statelessHandler = (
Expand Down
Loading