From 0e17f2fba2368f87636ebecc2b796f37de3b0498 Mon Sep 17 00:00:00 2001
From: Hongxiang Jiang
Date: Fri, 22 Aug 2025 14:23:31 -0400
Subject: [PATCH 1/2] mcp: cleanup transport after keepalive ping fails

An onClose function is passed to the ServerSession and ClientSession so
that the caller can clean up its resources. The onClose function is
executed as part of ServerSession and ClientSession closure.

Fixes #258
---
 mcp/client.go          | 16 ++++++---
 mcp/server.go          | 20 +++++++++---
 mcp/streamable.go      | 15 +++++++++
 mcp/streamable_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++
 mcp/transport.go       |  6 ++--
 5 files changed, 119 insertions(+), 11 deletions(-)

diff --git a/mcp/client.go b/mcp/client.go
index 3b1741b3..1ed3b048 100644
--- a/mcp/client.go
+++ b/mcp/client.go
@@ -76,9 +76,9 @@ type ClientOptions struct {
 
 // bind implements the binder[*ClientSession] interface, so that Clients can
 // be connected using [connect].
-func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession {
+func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession {
 	assert(mcpConn != nil && conn != nil, "nil connection")
-	cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c}
+	cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose}
 	if state != nil {
 		cs.state = *state
 	}
@@ -130,7 +130,7 @@ func (c *Client) capabilities() *ClientCapabilities {
 // server, calls or notifications will return an error wrapping
 // [ErrConnectionClosed].
 func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
-	cs, err = connect(ctx, t, c, (*clientSessionState)(nil))
+	cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
 	if err != nil {
 		return nil, err
 	}
@@ -173,6 +173,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
 // Call [ClientSession.Close] to close the connection, or await server
 // termination with [ClientSession.Wait].
 type ClientSession struct {
+	onClose func()
+
 	conn            *jsonrpc2.Connection
 	client          *Client
 	keepaliveCancel context.CancelFunc
@@ -208,7 +210,13 @@ func (cs *ClientSession) Close() error {
 	if cs.keepaliveCancel != nil {
 		cs.keepaliveCancel()
 	}
-	return cs.conn.Close()
+	err := cs.conn.Close()
+
+	if cs.onClose != nil {
+		cs.onClose()
+	}
+
+	return err
 }
 
 // Wait waits for the connection to be closed by the server.
diff --git a/mcp/server.go b/mcp/server.go
index c44dfeb6..49a3f56e 100644
--- a/mcp/server.go
+++ b/mcp/server.go
@@ -681,9 +681,9 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
 
 // bind implements the binder[*ServerSession] interface, so that Servers can
 // be connected using [connect].
-func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState) *ServerSession {
+func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession {
 	assert(mcpConn != nil && conn != nil, "nil connection")
-	ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s}
+	ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose}
 	if state != nil {
 		ss.state = *state
 	}
@@ -710,6 +710,8 @@ func (s *Server) disconnect(cc *ServerSession) {
 // ServerSessionOptions configures the server session.
 type ServerSessionOptions struct {
 	State *ServerSessionState
+
+	onClose func()
 }
 
 // Connect connects the MCP server over the given transport and starts handling
@@ -722,10 +724,12 @@ type ServerSessionOptions struct {
 // If opts.State is non-nil, it is the initial state for the server.
 func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) {
 	var state *ServerSessionState
+	var onClose func()
 	if opts != nil {
 		state = opts.State
+		onClose = opts.onClose
 	}
-	return connect(ctx, t, s, state)
+	return connect(ctx, t, s, state, onClose)
 }
 
 // TODO: (nit) move all ServerSession methods below the ServerSession declaration.
@@ -792,6 +796,8 @@ func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] {
 // Call [ServerSession.Close] to close the connection, or await client
 // termination with [ServerSession.Wait].
 type ServerSession struct {
+	onClose func()
+
 	server  *Server
 	conn    *jsonrpc2.Connection
 	mcpConn Connection
@@ -1018,7 +1024,13 @@ func (ss *ServerSession) Close() error {
 		// Close is idempotent and conn.Close() handles concurrent calls correctly
 		ss.keepaliveCancel()
 	}
-	return ss.conn.Close()
+	err := ss.conn.Close()
+
+	if ss.onClose != nil {
+		ss.onClose()
+	}
+
+	return err
 }
 
 // Wait waits for the connection to be closed by the client.
diff --git a/mcp/streamable.go b/mcp/streamable.go
index 99fbe422..f56b7084 100644
--- a/mcp/streamable.go
+++ b/mcp/streamable.go
@@ -40,6 +40,8 @@ type StreamableHTTPHandler struct {
 	getServer func(*http.Request) *Server
 	opts      StreamableHTTPOptions
 
+	onTransportDeletion func(sessionID string) // for testing only
+
 	mu sync.Mutex
 	// TODO: we should store the ServerSession along with the transport, because
 	// we need to cancel keepalive requests when closing the transport.
@@ -283,6 +285,19 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
 		connectOpts = &ServerSessionOptions{
 			State: state,
 		}
+	} else {
+		// Cleanup is only required in stateful mode, as the transport is
+		// not stored in the map otherwise.
+		connectOpts = &ServerSessionOptions{
+			onClose: func() {
+				h.mu.Lock()
+				delete(h.transports, transport.SessionID)
+				h.mu.Unlock()
+				if h.onTransportDeletion != nil {
+					h.onTransportDeletion(transport.SessionID)
+				}
+			},
+		}
 	}
 
 	// Pass req.Context() here, to allow middleware to add context values.
diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go
index 603be473..9d75c4ad 100644
--- a/mcp/streamable_test.go
+++ b/mcp/streamable_test.go
@@ -8,6 +8,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"maps"
@@ -332,6 +333,78 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
 	}
 }
 
+func TestServerTransportCleanup(t *testing.T) {
+	server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond})
+
+	nClient := 3
+
+	var mu sync.Mutex
+	var id int = -1 // session id starting from "0", "1", "2"...
+	chans := make(map[string]chan struct{}, nClient)
+
+	handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
+		GetSessionID: func() string {
+			mu.Lock()
+			defer mu.Unlock()
+			id++
+			if id == nClient {
+				t.Errorf("creating more than %v sessions", nClient)
+			}
+			chans[fmt.Sprint(id)] = make(chan struct{}, 1)
+			return fmt.Sprint(id)
+		},
+	})
+
+	handler.onTransportDeletion = func(sessionID string) {
+		mu.Lock()
+		chans[sessionID] <- struct{}{}
+		mu.Unlock()
+	}
+
+	httpServer := httptest.NewServer(handler)
+	defer httpServer.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// Spin up clients that connect to the same server but reject ping requests.
+	for range nClient {
+		client := NewClient(testImpl, nil)
+		pingMiddleware := func(next MethodHandler) MethodHandler {
+			return func(
+				ctx context.Context,
+				method string,
+				req Request,
+			) (Result, error) {
+				if method == "ping" {
+					return &emptyResult{}, errors.New("ping error")
+				}
+				return next(ctx, method, req)
+			}
+		}
+		client.AddReceivingMiddleware(pingMiddleware)
+		clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
+		if err != nil {
+			t.Fatalf("client.Connect() failed: %v", err)
+		}
+		defer clientSession.Close()
+	}
+
+	for _, ch := range chans {
+		select {
+		case <-ctx.Done():
+			t.Errorf("did not capture transport deletion events from all sessions in 10 seconds")
+		case <-ch: // received transport deletion signal for this session
+		}
+	}
+
+	handler.mu.Lock()
+	if len(handler.transports) != 0 {
+		t.Errorf("want empty transports map, got %v entries in handler's transports map", len(handler.transports))
+	}
+	handler.mu.Unlock()
+}
+
 // TestServerInitiatedSSE verifies that the persistent SSE connection remains
 // open and can receive server-initiated events.
 func TestServerInitiatedSSE(t *testing.T) {
diff --git a/mcp/transport.go b/mcp/transport.go
index 2bcd8d7d..36485508 100644
--- a/mcp/transport.go
+++ b/mcp/transport.go
@@ -122,7 +122,7 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) {
 }
 
 type binder[T handler, State any] interface {
-	bind(Connection, *jsonrpc2.Connection, State) T
+	bind(Connection, *jsonrpc2.Connection, State, func()) T
 	disconnect(T)
 }
 
@@ -130,7 +130,7 @@ type handler interface {
 	handle(ctx context.Context, req *jsonrpc.Request) (any, error)
 }
 
-func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State) (H, error) {
+func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) {
 	var zero H
 	mcpConn, err := t.Connect(ctx)
 	if err != nil {
@@ -143,7 +143,7 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
 		preempter canceller
 	)
 	bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler {
-		h = b.bind(mcpConn, conn, s)
+		h = b.bind(mcpConn, conn, s, onClose)
 		preempter.conn = conn
 		return jsonrpc2.HandlerFunc(h.handle)
 	}

From 61288e41f03a92da075f27e8611d9b9abfe24d73 Mon Sep 17 00:00:00 2001
From: Hongxiang Jiang
Date: Fri, 22 Aug 2025 15:18:56 -0400
Subject: [PATCH 2/2] mcp: resolve comments

- remove unnecessary lock for sending channel
- add todo to refactor the bind method
---
 mcp/streamable_test.go | 2 --
 mcp/transport.go       | 1 +
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go
index 9d75c4ad..3d2e1c02 100644
--- a/mcp/streamable_test.go
+++ b/mcp/streamable_test.go
@@ -356,9 +356,7 @@ func TestServerTransportCleanup(t *testing.T) {
 	})
 
 	handler.onTransportDeletion = func(sessionID string) {
-		mu.Lock()
 		chans[sessionID] <- struct{}{}
-		mu.Unlock()
 	}
 
 	httpServer := httptest.NewServer(handler)
 	defer httpServer.Close()
diff --git a/mcp/transport.go b/mcp/transport.go
index 36485508..fac640a6 100644
--- a/mcp/transport.go
+++ b/mcp/transport.go
@@ -122,6 +122,7 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) {
 }
 
 type binder[T handler, State any] interface {
+	// TODO(rfindley): the bind API has gotten too complicated. Simplify.
 	bind(Connection, *jsonrpc2.Connection, State, func()) T
 	disconnect(T)
 }
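
Taken together, the two patches thread an optional onClose callback from connect() through bind() into each session, and Close() invokes it after closing the underlying connection; in stateful mode the StreamableHTTPHandler uses the callback to drop the session's entry from its transports map when a failed keepalive ping closes the session. The sketch below illustrates the same pattern outside the SDK; the session and registry types and their methods are hypothetical stand-ins (roughly ServerSession and the handler's transports map), not part of the go-sdk API.

    package main

    import (
    	"fmt"
    	"sync"
    )

    // session is a hypothetical stand-in for a ServerSession: it owns a
    // connection-like resource and runs onClose once as part of Close.
    type session struct {
    	id      string
    	onClose func()
    	once    sync.Once
    }

    func (s *session) Close() error {
    	s.once.Do(func() {
    		// ... close the underlying connection first ...
    		if s.onClose != nil {
    			s.onClose()
    		}
    	})
    	return nil
    }

    // registry is a hypothetical stand-in for the handler's transports map:
    // it tracks live sessions by ID and must not leak entries when a session
    // is closed internally (for example, after a failed keepalive ping).
    type registry struct {
    	mu       sync.Mutex
    	sessions map[string]*session
    }

    func (r *registry) connect(id string) *session {
    	s := &session{id: id}
    	// Inject the registry's cleanup as the onClose hook, so every code
    	// path that closes the session also removes it from the map.
    	s.onClose = func() {
    		r.mu.Lock()
    		delete(r.sessions, id)
    		r.mu.Unlock()
    	}
    	r.mu.Lock()
    	r.sessions[id] = s
    	r.mu.Unlock()
    	return s
    }

    func main() {
    	r := &registry{sessions: make(map[string]*session)}
    	s := r.connect("session-1")
    	fmt.Println("live sessions:", len(r.sessions)) // live sessions: 1

    	// Simulate the keepalive failure path: the session is closed by its
    	// owner, and the onClose hook removes it from the registry.
    	s.Close()
    	fmt.Println("live sessions:", len(r.sessions)) // live sessions: 0
    }

The sketch uses sync.Once so a second Close is a no-op; the patch instead relies on conn.Close() tolerating repeated calls, as its comment notes, but the map-cleanup wiring is the same idea.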