From 3f83ef26c8550ac2719b9d6beacd101dfadfa607 Mon Sep 17 00:00:00 2001 From: Jordan Zimmerman Date: Fri, 8 Aug 2025 08:37:15 +0100 Subject: [PATCH] Support server to client notifications from the stateless transport The MCP spec allows stateless servers to send notifications to the client during a request. The response needs to be upgraded to SSE and the notifications are send in a stream until the final result is sent. This commit adds a `sendNotification` method to the transport context allowing each transport implementation to implement it or not. In this commit, HttpServletStatelessServerTransport implements the method and when the caller first sends a notification, the response is changed to `TEXT_EVENT_STREAM` and events are then streamed until the final result. This change will allow future features such as logging, list changes, etc. should we ever decide to support sessions in some manner. Even if we don't support sessions, sending progress notifications is a useful feature by itself. --- .../server/McpTransportContext.java | 9 ++ .../server/StatelessMcpTransportContext.java | 46 ++++++ .../HttpServletStatelessServerTransport.java | 129 +++++++++++---- .../HttpServletStatelessIntegrationTests.java | 149 +++++++++++++++++- 4 files changed, 299 insertions(+), 34 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/StatelessMcpTransportContext.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java index 1cd540f72..21f751d89 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java @@ -47,4 +47,13 @@ public interface McpTransportContext { */ McpTransportContext copy(); + /** + * Sends a notification from the server to the client. + * @param method notification method name + * @param params any parameters or {@code null} + */ + default void sendNotification(String method, Object params) { + throw new UnsupportedOperationException("Not supported in this implementation of MCP transport context"); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/StatelessMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/StatelessMcpTransportContext.java new file mode 100644 index 000000000..b2b0a6cb8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/StatelessMcpTransportContext.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.function.BiConsumer; + +public class StatelessMcpTransportContext implements McpTransportContext { + + private final McpTransportContext delegate; + + private final BiConsumer notificationHandler; + + /** + * Create an empty instance. + */ + public StatelessMcpTransportContext(BiConsumer notificationHandler) { + this(new DefaultMcpTransportContext(), notificationHandler); + } + + private StatelessMcpTransportContext(McpTransportContext delegate, BiConsumer notificationHandler) { + this.delegate = delegate; + this.notificationHandler = notificationHandler; + } + + @Override + public Object get(String key) { + return this.delegate.get(key); + } + + @Override + public void put(String key, Object value) { + this.delegate.put(key, value); + } + + public McpTransportContext copy() { + return new StatelessMcpTransportContext(delegate.copy(), notificationHandler); + } + + @Override + public void sendNotification(String method, Object params) { + notificationHandler.accept(method, params); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 25b003564..041471965 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -4,19 +4,11 @@ package io.modelcontextprotocol.server.transport; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.PrintWriter; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.server.DefaultMcpTransportContext; import io.modelcontextprotocol.server.McpStatelessServerHandler; import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.StatelessMcpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpStatelessServerTransport; @@ -26,8 +18,17 @@ import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; + /** * Implementation of an HttpServlet based {@link McpStatelessServerTransport}. * @@ -123,11 +124,16 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + AtomicInteger nextId = new AtomicInteger(0); + AtomicBoolean upgradedToSse = new AtomicBoolean(false); + BiConsumer notificationHandler = buildNotificationHandler(response, upgradedToSse, nextId); + McpTransportContext transportContext = this.contextExtractor.extract(request, + new StatelessMcpTransportContext(notificationHandler)); String accept = request.getHeader(ACCEPT); if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) { - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(), + nextId.getAndIncrement(), new McpError("Both application/json and text/event-stream required in Accept header")); return; } @@ -149,18 +155,24 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) .block(); - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_OK); - String jsonResponseText = objectMapper.writeValueAsString(jsonrpcResponse); - PrintWriter writer = response.getWriter(); - writer.write(jsonResponseText); - writer.flush(); + if (upgradedToSse.get()) { + sendEvent(response.getWriter(), jsonResponseText, nextId.getAndIncrement()); + } + else { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_OK); + + PrintWriter writer = response.getWriter(); + writer.write(jsonResponseText); + writer.flush(); + } } catch (Exception e) { logger.error("Failed to handle request: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, jsonrpcRequest.id(), + upgradedToSse.get(), nextId.getAndIncrement(), new McpError("Failed to handle request: " + e.getMessage())); } } @@ -173,23 +185,25 @@ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { } catch (Exception e) { logger.error("Failed to handle notification: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, null, + upgradedToSse.get(), nextId.getAndIncrement(), new McpError("Failed to handle notification: " + e.getMessage())); } } else { - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, - new McpError("The server accepts either requests or notifications")); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(), + nextId.getAndIncrement(), new McpError("The server accepts either requests or notifications")); } } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Invalid message format")); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(), + nextId.getAndIncrement(), new McpError("Invalid message format")); } catch (Exception e) { logger.error("Unexpected error handling message: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError("Unexpected error: " + e.getMessage())); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, null, upgradedToSse.get(), + nextId.getAndIncrement(), new McpError("Unexpected error: " + e.getMessage())); } } @@ -197,17 +211,27 @@ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { * Sends an error response to the client. * @param response The HTTP servlet response * @param httpCode The HTTP status code + * @param upgradedToSse true if the response is upgraded to SSE, false otherwise + * @param eventIdIfNeeded if upgradedToSse, the event ID to use, otherwise ignored * @param mcpError The MCP error to send * @throws IOException If an I/O error occurs */ - private void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(httpCode); - String jsonError = objectMapper.writeValueAsString(mcpError); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); + private void responseError(HttpServletResponse response, int httpCode, Object requestId, boolean upgradedToSse, + int eventIdIfNeeded, McpError mcpError) throws IOException { + if (upgradedToSse) { + String jsonError = objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + requestId, null, mcpError.getJsonRpcError())); + sendEvent(response.getWriter(), jsonError, eventIdIfNeeded); + } + else { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + PrintWriter writer = response.getWriter(); + String jsonError = objectMapper.writeValueAsString(mcpError); + writer.write(jsonError); + writer.flush(); + } } /** @@ -303,4 +327,43 @@ public HttpServletStatelessServerTransport build() { } + private BiConsumer buildNotificationHandler(HttpServletResponse response, + AtomicBoolean upgradedToSse, AtomicInteger nextId) { + AtomicBoolean responseInitialized = new AtomicBoolean(false); + + return (notificationMethod, params) -> { + if (responseInitialized.compareAndSet(false, true)) { + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_OK); + } + + upgradedToSse.set(true); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + notificationMethod, params); + try { + sendEvent(response.getWriter(), objectMapper.writeValueAsString(notification), + nextId.getAndIncrement()); + } + catch (IOException e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), null)); + } + }; + } + + private void sendEvent(PrintWriter writer, String data, int id) throws IOException { + // tested with MCP inspector. Event must consist of these two fields and only + // these two fields + writer.write("id: " + id + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java index 4c3f22d76..c43caa356 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -35,12 +35,19 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.web.client.RestClient; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.time.Duration; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.util.stream.Stream; import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; @@ -61,10 +68,13 @@ class HttpServletStatelessIntegrationTests { private Tomcat tomcat; + private ObjectMapper objectMapper; + @BeforeEach public void before() { + objectMapper = new ObjectMapper(); this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) + .objectMapper(objectMapper) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); @@ -219,6 +229,143 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { mcpServer.close(); } + @Test + void testNotifications() throws Exception { + + Tool tool = Tool.builder().name("test").build(); + Tool exceptionTool = Tool.builder().name("exception").build(); + + final int PROGRESS_QTY = 1000; + final String progressMessage = "We're working on it..."; + + var progressToken = UUID.randomUUID().toString(); + var callResponse = new CallToolResult(List.of(), null, null, Map.of("progressToken", progressToken)); + + McpStatelessServerFeatures.SyncToolSpecification toolSpecification = new McpStatelessServerFeatures.SyncToolSpecification( + tool, (transportContext, request) -> { + // Simulate sending progress notifications - send enough to ensure + // that cunked transfer encoding is used + for (int i = 0; i < PROGRESS_QTY; i++) { + transportContext.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, + new McpSchema.ProgressNotification(progressToken, i, (double) PROGRESS_QTY, + progressMessage)); + } + return callResponse; + }); + + McpStatelessServerFeatures.SyncToolSpecification exceptionToolSpecification = new McpStatelessServerFeatures.SyncToolSpecification( + exceptionTool, (transportContext, request) -> { + // send 1 progress so that the response gets upgraded + transportContext.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, + new McpSchema.ProgressNotification(progressToken, 1, 5.0, progressMessage)); + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INVALID_PARAMS, + "bad tool", Map.of())); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(toolSpecification, exceptionToolSpecification) + .build(); + + HttpClient client = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(); + HttpRequest request = HttpRequest.newBuilder() + .method("POST", + HttpRequest.BodyPublishers.ofString( + objectMapper.writeValueAsString(new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + "tools/call", "1", new McpSchema.CallToolRequest("test", Map.of()))))) + .header("Content-Type", APPLICATION_JSON) + .header("Accept", APPLICATION_JSON + "," + TEXT_EVENT_STREAM) + .uri(URI.create("http://localhost:" + PORT + CUSTOM_MESSAGE_ENDPOINT)) + .build(); + + HttpResponse> response = client.send(request, HttpResponse.BodyHandlers.ofLines()); + assertThat(response.headers().firstValue("Transfer-Encoding")).contains("chunked"); + + List responseBody = response.body().toList(); + + assertThat(responseBody).hasSize((PROGRESS_QTY + 1) * 3); // 3 lines per progress + // notification + 4 + // for + // the call result + + Iterator iterator = responseBody.iterator(); + for (int i = 0; i < PROGRESS_QTY; ++i) { + String idLine = iterator.next(); + String dataLine = iterator.next(); + String blankLine = iterator.next(); + + McpSchema.ProgressNotification expectedNotification = new McpSchema.ProgressNotification(progressToken, i, + (double) PROGRESS_QTY, progressMessage); + McpSchema.JSONRPCNotification expectedJsonRpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_PROGRESS, expectedNotification); + + assertThat(idLine).isEqualTo("id: " + i); + assertThat(dataLine).isEqualTo("data: " + objectMapper.writeValueAsString(expectedJsonRpcNotification)); + assertThat(blankLine).isBlank(); + } + + String idLine = iterator.next(); + String dataLine = iterator.next(); + String blankLine = iterator.next(); + + assertThat(idLine).isEqualTo("id: " + PROGRESS_QTY); + assertThat(dataLine).isEqualTo("data: " + objectMapper + .writeValueAsString(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, "1", callResponse, null))); + assertThat(blankLine).isBlank(); + + assertThat(iterator.hasNext()).isFalse(); + + // next, test the exception tool + + request = HttpRequest.newBuilder() + .method("POST", + HttpRequest.BodyPublishers.ofString( + objectMapper.writeValueAsString(new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + "tools/call", "1", new McpSchema.CallToolRequest("exception", Map.of()))))) + .header("Content-Type", APPLICATION_JSON) + .header("Accept", APPLICATION_JSON + "," + TEXT_EVENT_STREAM) + .uri(URI.create("http://localhost:" + PORT + CUSTOM_MESSAGE_ENDPOINT)) + .build(); + + response = client.send(request, HttpResponse.BodyHandlers.ofLines()); + assertThat(response.headers().firstValue("Transfer-Encoding")).contains("chunked"); + + responseBody = response.body().toList(); + + assertThat(responseBody).hasSize(6); // 1 progress notification + the error + // response + + iterator = responseBody.iterator(); + + idLine = iterator.next(); + dataLine = iterator.next(); + blankLine = iterator.next(); + + McpSchema.ProgressNotification expectedNotification = new McpSchema.ProgressNotification(progressToken, 1, 5.0, + progressMessage); + McpSchema.JSONRPCNotification expectedJsonRpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_PROGRESS, expectedNotification); + + assertThat(idLine).isEqualTo("id: 0"); + assertThat(dataLine).isEqualTo("data: " + objectMapper.writeValueAsString(expectedJsonRpcNotification)); + assertThat(blankLine).isBlank(); + + idLine = iterator.next(); + dataLine = iterator.next(); + blankLine = iterator.next(); + + assertThat(iterator.hasNext()).isFalse(); + + assertThat(idLine).isEqualTo("id: 1"); + assertThat(dataLine).isEqualTo( + "data: " + objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, "1", + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INVALID_PARAMS, + "bad tool", Map.of())))); + assertThat(blankLine).isBlank(); + + mcpServer.close(); + } + // --------------------------------------- // Tool Structured Output Schema Tests // ---------------------------------------