Skip to content

Support server to client notifications from the stateless transport #472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,13 @@ public interface McpTransportContext {
*/
McpTransportContext copy();

/**
* Sends a notification from the server to the client.
* @param method notification method name
* @param params any parameters or {@code null}
*/
default void sendNotification(String method, Object params) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative to this might be a separate McpNotifier instance that would be accessed by calling McpTransportContext.get() with a lib-defined key. wdyt?

throw new UnsupportedOperationException("Not supported in this implementation of MCP transport context");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2024-2025 the original author or authors.
*/

package io.modelcontextprotocol.server;

import java.util.function.BiConsumer;

public class StatelessMcpTransportContext implements McpTransportContext {

private final McpTransportContext delegate;

private final BiConsumer<String, Object> notificationHandler;

/**
* Create an empty instance.
*/
public StatelessMcpTransportContext(BiConsumer<String, Object> notificationHandler) {
this(new DefaultMcpTransportContext(), notificationHandler);
}

private StatelessMcpTransportContext(McpTransportContext delegate, BiConsumer<String, Object> notificationHandler) {
this.delegate = delegate;
this.notificationHandler = notificationHandler;
}

@Override
public Object get(String key) {
return this.delegate.get(key);
}

@Override
public void put(String key, Object value) {
this.delegate.put(key, value);
}

public McpTransportContext copy() {
return new StatelessMcpTransportContext(delegate.copy(), notificationHandler);
}

@Override
public void sendNotification(String method, Object params) {
notificationHandler.accept(method, params);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,11 @@

package io.modelcontextprotocol.server.transport;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.databind.ObjectMapper;

import io.modelcontextprotocol.server.DefaultMcpTransportContext;
import io.modelcontextprotocol.server.McpStatelessServerHandler;
import io.modelcontextprotocol.server.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.StatelessMcpTransportContext;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
Expand All @@ -26,8 +18,17 @@
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

/**
* Implementation of an HttpServlet based {@link McpStatelessServerTransport}.
*
Expand Down Expand Up @@ -123,11 +124,16 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
return;
}

McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
AtomicInteger nextId = new AtomicInteger(0);
AtomicBoolean upgradedToSse = new AtomicBoolean(false);
BiConsumer<String, Object> notificationHandler = buildNotificationHandler(response, upgradedToSse, nextId);
McpTransportContext transportContext = this.contextExtractor.extract(request,
new StatelessMcpTransportContext(notificationHandler));

String accept = request.getHeader(ACCEPT);
if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) {
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST,
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(),
nextId.getAndIncrement(),
new McpError("Both application/json and text/event-stream required in Accept header"));
return;
}
Expand All @@ -149,18 +155,24 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
.block();

response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_OK);

String jsonResponseText = objectMapper.writeValueAsString(jsonrpcResponse);
PrintWriter writer = response.getWriter();
writer.write(jsonResponseText);
writer.flush();
if (upgradedToSse.get()) {
sendEvent(response.getWriter(), jsonResponseText, nextId.getAndIncrement());
}
else {
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_OK);

PrintWriter writer = response.getWriter();
writer.write(jsonResponseText);
writer.flush();
}
}
catch (Exception e) {
logger.error("Failed to handle request: {}", e.getMessage());
this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, jsonrpcRequest.id(),
upgradedToSse.get(), nextId.getAndIncrement(),
new McpError("Failed to handle request: " + e.getMessage()));
}
}
Expand All @@ -173,41 +185,53 @@ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) {
}
catch (Exception e) {
logger.error("Failed to handle notification: {}", e.getMessage());
this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, null,
upgradedToSse.get(), nextId.getAndIncrement(),
new McpError("Failed to handle notification: " + e.getMessage()));
}
}
else {
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST,
new McpError("The server accepts either requests or notifications"));
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(),
nextId.getAndIncrement(), new McpError("The server accepts either requests or notifications"));
}
}
catch (IllegalArgumentException | IOException e) {
logger.error("Failed to deserialize message: {}", e.getMessage());
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Invalid message format"));
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(),
nextId.getAndIncrement(), new McpError("Invalid message format"));
}
catch (Exception e) {
logger.error("Unexpected error handling message: {}", e.getMessage());
this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
new McpError("Unexpected error: " + e.getMessage()));
this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, null, upgradedToSse.get(),
nextId.getAndIncrement(), new McpError("Unexpected error: " + e.getMessage()));
}
}

/**
* Sends an error response to the client.
* @param response The HTTP servlet response
* @param httpCode The HTTP status code
* @param upgradedToSse true if the response is upgraded to SSE, false otherwise
* @param eventIdIfNeeded if upgradedToSse, the event ID to use, otherwise ignored
* @param mcpError The MCP error to send
* @throws IOException If an I/O error occurs
*/
private void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException {
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(httpCode);
String jsonError = objectMapper.writeValueAsString(mcpError);
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
private void responseError(HttpServletResponse response, int httpCode, Object requestId, boolean upgradedToSse,
int eventIdIfNeeded, McpError mcpError) throws IOException {
if (upgradedToSse) {
String jsonError = objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION,
requestId, null, mcpError.getJsonRpcError()));
sendEvent(response.getWriter(), jsonError, eventIdIfNeeded);
}
else {
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(httpCode);
PrintWriter writer = response.getWriter();
String jsonError = objectMapper.writeValueAsString(mcpError);
writer.write(jsonError);
writer.flush();
}
}

/**
Expand Down Expand Up @@ -303,4 +327,43 @@ public HttpServletStatelessServerTransport build() {

}

private BiConsumer<String, Object> buildNotificationHandler(HttpServletResponse response,
AtomicBoolean upgradedToSse, AtomicInteger nextId) {
AtomicBoolean responseInitialized = new AtomicBoolean(false);

return (notificationMethod, params) -> {
if (responseInitialized.compareAndSet(false, true)) {
response.setContentType(TEXT_EVENT_STREAM);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_OK);
}

upgradedToSse.set(true);

McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION,
notificationMethod, params);
try {
sendEvent(response.getWriter(), objectMapper.writeValueAsString(notification),
nextId.getAndIncrement());
}
catch (IOException e) {
logger.error("Failed to handle notification: {}", e.getMessage());
throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
e.getMessage(), null));
}
};
}

private void sendEvent(PrintWriter writer, String data, int id) throws IOException {
// tested with MCP inspector. Event must consist of these two fields and only
// these two fields
writer.write("id: " + id + "\n");
writer.write("data: " + data + "\n\n");
writer.flush();

if (writer.checkError()) {
throw new IOException("Client disconnected");
}
}

}
Loading