Skip to content

Commit 8d679f2

Browse files
committed
fix: validate content type and protocol version in REST routes
- Removing consumes from route annotations and adding explicit ContentTypeNotSupportedError validation in sendMessage/sendMessageStreaming - Moving VersionNotSupportedError HTTP status from 501 -> 400 - Adding unit and integration tests for both scenarios fix: HTTP+JSON error mapping is incorrect for ContentTypeNotSupportedError & VersionNotSupportedError Signed-off-by: Emmanuel Hugonnet <ehugonne@redhat.com>
1 parent 4bcc122 commit 8d679f2

File tree

5 files changed

+120
-7
lines changed

5 files changed

+120
-7
lines changed

reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import io.a2a.server.extensions.A2AExtensions;
3232
import io.a2a.server.util.async.Internal;
3333
import io.a2a.spec.A2AError;
34+
import io.a2a.spec.ContentTypeNotSupportedError;
3435
import io.a2a.spec.InternalError;
3536
import io.a2a.spec.InvalidParamsError;
3637
import io.a2a.spec.MethodNotFoundError;
@@ -165,8 +166,13 @@ public class A2AServerRoutes {
165166
* @param body the JSON request body
166167
* @param rc the Vert.x routing context
167168
*/
168-
@Route(regex = "^\\/(?<tenant>[^\\/]*\\/?)message:send$", order = 1, methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING)
169+
@Route(regex = "^\\/(?<tenant>[^\\/]*\\/?)message:send$", order = 1, methods = {Route.HttpMethod.POST}, type = Route.HandlerType.BLOCKING)
169170
public void sendMessage(@Body String body, RoutingContext rc) {
171+
String contentType = rc.request().getHeader(CONTENT_TYPE);
172+
if (contentType == null || !contentType.contains(APPLICATION_JSON)) {
173+
sendResponse(rc, jsonRestHandler.createErrorResponse(new ContentTypeNotSupportedError(null, null, null)));
174+
return;
175+
}
170176
ServerCallContext context = createCallContext(rc, SEND_MESSAGE_METHOD);
171177
HTTPRestResponse response = null;
172178
try {
@@ -198,8 +204,13 @@ public void sendMessage(@Body String body, RoutingContext rc) {
198204
* @param body the JSON request body
199205
* @param rc the Vert.x routing context
200206
*/
201-
@Route(regex = "^\\/(?<tenant>[^\\/]*\\/?)message:stream$", order = 1, methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING)
207+
@Route(regex = "^\\/(?<tenant>[^\\/]*\\/?)message:stream$", order = 1, methods = {Route.HttpMethod.POST}, type = Route.HandlerType.BLOCKING)
202208
public void sendMessageStreaming(@Body String body, RoutingContext rc) {
209+
String contentType = rc.request().getHeader(CONTENT_TYPE);
210+
if (contentType == null || !contentType.contains(APPLICATION_JSON)) {
211+
sendResponse(rc, jsonRestHandler.createErrorResponse(new ContentTypeNotSupportedError(null, null, null)));
212+
return;
213+
}
203214
ServerCallContext context = createCallContext(rc, SEND_STREAMING_MESSAGE_METHOD);
204215
HTTPRestStreamingResponse streamingResponse = null;
205216
HTTPRestResponse error = null;
@@ -339,6 +350,11 @@ public void getTask(RoutingContext rc) {
339350
*/
340351
@Route(regex = "^\\/(?<tenant>[^\\/]*\\/?)tasks\\/(?<taskId>[^/]+):cancel$", order = 1, methods = {Route.HttpMethod.POST}, type = Route.HandlerType.BLOCKING)
341352
public void cancelTask(@Body String body, RoutingContext rc) {
353+
String contentType = rc.request().getHeader(CONTENT_TYPE);
354+
if (contentType == null || !contentType.contains(APPLICATION_JSON)) {
355+
sendResponse(rc, jsonRestHandler.createErrorResponse(new ContentTypeNotSupportedError(null, null, null)));
356+
return;
357+
}
342358
String taskId = rc.pathParam("taskId");
343359
ServerCallContext context = createCallContext(rc, CANCEL_TASK_METHOD);
344360
HTTPRestResponse response = null;
@@ -443,8 +459,13 @@ public void subscribeToTask(RoutingContext rc) {
443459
* @param body the JSON request body with notification configuration
444460
* @param rc the Vert.x routing context (taskId extracted from path)
445461
*/
446-
@Route(regex = "^\\/(?<tenant>[^\\/]*\\/?)tasks\\/(?<taskId>[^/]+)\\/pushNotificationConfigs$", order = 1, methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING)
462+
@Route(regex = "^\\/(?<tenant>[^\\/]*\\/?)tasks\\/(?<taskId>[^/]+)\\/pushNotificationConfigs$", order = 1, methods = {Route.HttpMethod.POST}, type = Route.HandlerType.BLOCKING)
447463
public void CreateTaskPushNotificationConfiguration(@Body String body, RoutingContext rc) {
464+
String contentType = rc.request().getHeader(CONTENT_TYPE);
465+
if (contentType == null || !contentType.contains(APPLICATION_JSON)) {
466+
sendResponse(rc, jsonRestHandler.createErrorResponse(new ContentTypeNotSupportedError(null, null, null)));
467+
return;
468+
}
448469
String taskId = rc.pathParam("taskId");
449470
ServerCallContext context = createCallContext(rc, SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD);
450471
HTTPRestResponse response = null;

reference/rest/src/test/java/io/a2a/server/rest/quarkus/A2AServerRoutesTest.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.mockito.ArgumentMatchers.anyString;
1919
import static org.mockito.ArgumentMatchers.eq;
2020
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.never;
2122
import static org.mockito.Mockito.verify;
2223
import static org.mockito.Mockito.when;
2324

@@ -26,6 +27,8 @@
2627
import jakarta.enterprise.inject.Instance;
2728

2829
import io.a2a.server.ServerCallContext;
30+
import io.a2a.spec.ContentTypeNotSupportedError;
31+
import io.a2a.spec.VersionNotSupportedError;
2932
import io.a2a.transport.rest.handler.RestHandler;
3033
import io.a2a.transport.rest.handler.RestHandler.HTTPRestResponse;
3134
import io.vertx.core.Future;
@@ -80,6 +83,7 @@ public void setUp() {
8083
when(mockRoutingContext.user()).thenReturn(null);
8184
when(mockRequest.headers()).thenReturn(mockHeaders);
8285
when(mockRequest.params()).thenReturn(mockParams);
86+
when(mockRequest.getHeader(any(CharSequence.class))).thenReturn("application/json");
8387
when(mockRoutingContext.body()).thenReturn(mockRequestBody);
8488
when(mockRequestBody.asString()).thenReturn("{}");
8589
when(mockResponse.setStatusCode(any(Integer.class))).thenReturn(mockResponse);
@@ -438,6 +442,61 @@ public void testDeleteTaskPushNotificationConfiguration_MethodNameSetInContext()
438442
assertEquals(DELETE_TASK_PUSH_NOTIFICATION_CONFIG_METHOD, capturedContext.getState().get(METHOD_NAME_KEY));
439443
}
440444

445+
@Test
446+
public void testSendMessage_UnsupportedContentType_ReturnsContentTypeNotSupportedError() {
447+
// Arrange
448+
HTTPRestResponse mockErrorResponse = mock(HTTPRestResponse.class);
449+
when(mockErrorResponse.getStatusCode()).thenReturn(415);
450+
when(mockErrorResponse.getContentType()).thenReturn("application/problem+json");
451+
when(mockErrorResponse.getBody()).thenReturn("{\"type\":\"https://a2a-protocol.org/errors/content-type-not-supported\"}");
452+
when(mockRestHandler.createErrorResponse(any(ContentTypeNotSupportedError.class))).thenReturn(mockErrorResponse);
453+
when(mockRequest.getHeader(any(CharSequence.class))).thenReturn("text/plain");
454+
455+
// Act
456+
routes.sendMessage("{}", mockRoutingContext);
457+
458+
// Assert: createErrorResponse called with ContentTypeNotSupportedError, sendMessage NOT called
459+
verify(mockRestHandler).createErrorResponse(any(ContentTypeNotSupportedError.class));
460+
verify(mockRestHandler, never()).sendMessage(any(ServerCallContext.class), anyString(), anyString());
461+
}
462+
463+
@Test
464+
public void testSendMessageStreaming_UnsupportedContentType_ReturnsContentTypeNotSupportedError() {
465+
// Arrange
466+
HTTPRestResponse mockErrorResponse = mock(HTTPRestResponse.class);
467+
when(mockErrorResponse.getStatusCode()).thenReturn(415);
468+
when(mockErrorResponse.getContentType()).thenReturn("application/problem+json");
469+
when(mockErrorResponse.getBody()).thenReturn("{\"type\":\"https://a2a-protocol.org/errors/content-type-not-supported\"}");
470+
when(mockRestHandler.createErrorResponse(any(ContentTypeNotSupportedError.class))).thenReturn(mockErrorResponse);
471+
when(mockRequest.getHeader(any(CharSequence.class))).thenReturn("text/plain");
472+
473+
// Act
474+
routes.sendMessageStreaming("{}", mockRoutingContext);
475+
476+
// Assert: createErrorResponse called with ContentTypeNotSupportedError, sendStreamingMessage NOT called
477+
verify(mockRestHandler).createErrorResponse(any(ContentTypeNotSupportedError.class));
478+
verify(mockRestHandler, never()).sendStreamingMessage(any(ServerCallContext.class), anyString(), anyString());
479+
}
480+
481+
@Test
482+
public void testSendMessage_UnsupportedProtocolVersion_ReturnsVersionNotSupportedError() {
483+
// Arrange: content type is OK, but RestHandler returns a VersionNotSupportedError response
484+
HTTPRestResponse mockErrorResponse = mock(HTTPRestResponse.class);
485+
when(mockErrorResponse.getStatusCode()).thenReturn(400);
486+
when(mockErrorResponse.getContentType()).thenReturn("application/problem+json");
487+
when(mockErrorResponse.getBody()).thenReturn("{\"type\":\"https://a2a-protocol.org/errors/version-not-supported\"}");
488+
when(mockRequest.getHeader(any(CharSequence.class))).thenReturn("application/json");
489+
when(mockRestHandler.sendMessage(any(ServerCallContext.class), anyString(), anyString()))
490+
.thenReturn(mockErrorResponse);
491+
492+
// Act
493+
routes.sendMessage("{}", mockRoutingContext);
494+
495+
// Assert: sendMessage was called and error response forwarded
496+
verify(mockRestHandler).sendMessage(any(ServerCallContext.class), anyString(), eq("{}"));
497+
verify(mockResponse).setStatusCode(400);
498+
}
499+
441500
/**
442501
* Helper method to set a field via reflection for testing purposes.
443502
*/

reference/rest/src/test/java/io/a2a/server/rest/quarkus/QuarkusA2ARestTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,39 @@ protected String getTransportUrl() {
3030
@Override
3131
protected abstract void configureTransport(ClientBuilder builder);
3232

33+
@Test
34+
public void testSendMessageWithUnsupportedContentType() throws Exception {
35+
HttpClient client = HttpClient.newBuilder()
36+
.version(HttpClient.Version.HTTP_2)
37+
.build();
38+
HttpRequest request = HttpRequest.newBuilder()
39+
.uri(URI.create("http://localhost:" + serverPort + "/message:send"))
40+
.POST(HttpRequest.BodyPublishers.ofString("test body"))
41+
.header("Content-Type", "text/plain")
42+
.build();
43+
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
44+
Assertions.assertEquals(415, response.statusCode());
45+
Assertions.assertTrue(response.body().contains("content-type-not-supported"),
46+
"Expected content-type-not-supported in response body: " + response.body());
47+
}
48+
49+
@Test
50+
public void testSendMessageWithUnsupportedProtocolVersion() throws Exception {
51+
HttpClient client = HttpClient.newBuilder()
52+
.version(HttpClient.Version.HTTP_2)
53+
.build();
54+
HttpRequest request = HttpRequest.newBuilder()
55+
.uri(URI.create("http://localhost:" + serverPort + "/message:send"))
56+
.POST(HttpRequest.BodyPublishers.ofString("{}"))
57+
.header("Content-Type", APPLICATION_JSON)
58+
.header("A2A-Version", "0.4.0")
59+
.build();
60+
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
61+
Assertions.assertEquals(400, response.statusCode());
62+
Assertions.assertTrue(response.body().contains("version-not-supported"),
63+
"Expected version-not-supported in response body: " + response.body());
64+
}
65+
3366
@Test
3467
public void testMethodNotFound() throws Exception {
3568
// Create the client

transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -770,8 +770,7 @@ private static int mapErrorToHttpStatus(A2AError error) {
770770
return 409;
771771
}
772772
if (error instanceof PushNotificationNotSupportedError
773-
|| error instanceof UnsupportedOperationError
774-
|| error instanceof VersionNotSupportedError) {
773+
|| error instanceof UnsupportedOperationError) {
775774
return 501;
776775
}
777776
if (error instanceof ContentTypeNotSupportedError) {
@@ -781,7 +780,8 @@ private static int mapErrorToHttpStatus(A2AError error) {
781780
return 502;
782781
}
783782
if (error instanceof ExtendedAgentCardNotConfiguredError
784-
|| error instanceof ExtensionSupportRequiredError) {
783+
|| error instanceof ExtensionSupportRequiredError
784+
|| error instanceof VersionNotSupportedError) {
785785
return 400;
786786
}
787787
if (error instanceof InternalError) {

transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,7 @@ public void testVersionNotSupportedErrorOnSendMessage() {
777777

778778
RestHandler.HTTPRestResponse response = handler.sendMessage(contextWithVersion, "", requestBody);
779779

780-
assertProblemDetail(response, 501,
780+
assertProblemDetail(response, 400,
781781
"https://a2a-protocol.org/errors/version-not-supported",
782782
"Protocol version '2.0' is not supported. Supported versions: [1.0]");
783783
}

0 commit comments

Comments
 (0)