mcp: cleanup transport after keepalive ping fails #360

Merged · 2 commits · Aug 22, 2025
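In outline, the change threads an onClose callback from the streamable HTTP handler down to the session, so that when a session is closed (for example, after keepalive pings fail) the handler can drop that session's transport from its map. Below is a simplified, self-contained sketch of that shape; the names session, handler, and connect are abbreviations for this illustration, not the SDK's API.

package main

import (
	"fmt"
	"sync"
)

// session stands in for ClientSession/ServerSession: it runs an optional
// callback once its underlying connection has been closed.
type session struct {
	onClose func()
}

func (s *session) Close() error {
	// ... close the underlying jsonrpc2 connection first ...
	if s.onClose != nil {
		s.onClose()
	}
	return nil
}

// handler stands in for StreamableHTTPHandler: it tracks one transport per
// session ID and registers a callback that removes the entry on close.
type handler struct {
	mu         sync.Mutex
	transports map[string]struct{} // sessionID -> transport (value elided here)
}

func (h *handler) connect(sessionID string) *session {
	h.mu.Lock()
	h.transports[sessionID] = struct{}{}
	h.mu.Unlock()
	return &session{onClose: func() {
		h.mu.Lock()
		delete(h.transports, sessionID)
		h.mu.Unlock()
	}}
}

func main() {
	h := &handler{transports: make(map[string]struct{})}
	s := h.connect("session-1")
	s.Close() // e.g. triggered after a failed keepalive ping
	fmt.Println(len(h.transports)) // 0: the stale entry was removed
}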
16 changes: 12 additions & 4 deletions mcp/client.go
@@ -76,9 +76,9 @@ type ClientOptions struct {

// bind implements the binder[*ClientSession] interface, so that Clients can
// be connected using [connect].
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession {
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession {
assert(mcpConn != nil && conn != nil, "nil connection")
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c}
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose}
if state != nil {
cs.state = *state
}
@@ -130,7 +130,7 @@ func (c *Client) capabilities() *ClientCapabilities {
// server, calls or notifications will return an error wrapping
// [ErrConnectionClosed].
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
cs, err = connect(ctx, t, c, (*clientSessionState)(nil))
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
if err != nil {
return nil, err
}
@@ -173,6 +173,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
// Call [ClientSession.Close] to close the connection, or await server
// termination with [ClientSession.Wait].
type ClientSession struct {
onClose func()

conn *jsonrpc2.Connection
client *Client
keepaliveCancel context.CancelFunc
@@ -208,7 +210,13 @@ func (cs *ClientSession) Close() error {
if cs.keepaliveCancel != nil {
cs.keepaliveCancel()
}
return cs.conn.Close()
err := cs.conn.Close()

if cs.onClose != nil {
cs.onClose()
}

return err
}

// Wait waits for the connection to be closed by the server.
20 changes: 16 additions & 4 deletions mcp/server.go
@@ -681,9 +681,9 @@ func (s *Server) Run(ctx context.Context, t Transport) error {

// bind implements the binder[*ServerSession] interface, so that Servers can
// be connected using [connect].
func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState) *ServerSession {
func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession {
assert(mcpConn != nil && conn != nil, "nil connection")
ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s}
ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose}
if state != nil {
ss.state = *state
}
@@ -710,6 +710,8 @@ func (s *Server) disconnect(cc *ServerSession) {
// ServerSessionOptions configures the server session.
type ServerSessionOptions struct {
State *ServerSessionState

onClose func()
}

// Connect connects the MCP server over the given transport and starts handling
@@ -722,10 +724,12 @@ type ServerSessionOptions struct {
// If opts.State is non-nil, it is the initial state for the server.
func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) {
var state *ServerSessionState
var onClose func()
if opts != nil {
state = opts.State
onClose = opts.onClose
}
return connect(ctx, t, s, state)
return connect(ctx, t, s, state, onClose)
}

// TODO: (nit) move all ServerSession methods below the ServerSession declaration.
@@ -792,6 +796,8 @@ func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] {
// Call [ServerSession.Close] to close the connection, or await client
// termination with [ServerSession.Wait].
type ServerSession struct {
onClose func()

server *Server
conn *jsonrpc2.Connection
mcpConn Connection
@@ -1018,7 +1024,13 @@ func (ss *ServerSession) Close() error {
// Close is idempotent and conn.Close() handles concurrent calls correctly
ss.keepaliveCancel()
}
return ss.conn.Close()
err := ss.conn.Close()

if ss.onClose != nil {
ss.onClose()
}

return err
}

// Wait waits for the connection to be closed by the client.
15 changes: 15 additions & 0 deletions mcp/streamable.go
@@ -40,6 +40,8 @@ type StreamableHTTPHandler struct {
getServer func(*http.Request) *Server
opts StreamableHTTPOptions

onTransportDeletion func(sessionID string) // for testing only
@h9jiang (Collaborator, Author) commented on Aug 22, 2025:

I would prefer not to add this field only for testing purposes, but the test needs some signal from the production code indicating that a transport has been deleted.

Let me know if you have a better solution.


mu sync.Mutex
// TODO: we should store the ServerSession along with the transport, because
// we need to cancel keepalive requests when closing the transport.
@@ -283,6 +285,19 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
connectOpts = &ServerSessionOptions{
State: state,
}
} else {
// Cleanup is only required in stateful mode, as the transport is
// not stored in the map otherwise.
connectOpts = &ServerSessionOptions{
onClose: func() {
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
if h.onTransportDeletion != nil {
h.onTransportDeletion(transport.SessionID)
}
},
}
}

// Pass req.Context() here, to allow middleware to add context values.
71 changes: 71 additions & 0 deletions mcp/streamable_test.go
@@ -8,6 +8,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
@@ -332,6 +333,76 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
}
}

func TestServerTransportCleanup(t *testing.T) {
server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond})

nClient := 3

var mu sync.Mutex
var id int = -1 // session id starting from "0", "1", "2"...
chans := make(map[string]chan struct{}, nClient)

handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
GetSessionID: func() string {
mu.Lock()
defer mu.Unlock()
id++
if id == nClient {
t.Errorf("creating more than %v session", nClient)
}
chans[fmt.Sprint(id)] = make(chan struct{}, 1)
return fmt.Sprint(id)
},
})

handler.onTransportDeletion = func(sessionID string) {
chans[sessionID] <- struct{}{}
}

httpServer := httptest.NewServer(handler)
defer httpServer.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// Spin up clients that connect to the same server but refuse ping requests.
for range nClient {
client := NewClient(testImpl, nil)
pingMiddleware := func(next MethodHandler) MethodHandler {
return func(
ctx context.Context,
method string,
req Request,
) (Result, error) {
if method == "ping" {
return &emptyResult{}, errors.New("ping error")
}
return next(ctx, method, req)
}
}
client.AddReceivingMiddleware(pingMiddleware)
clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer clientSession.Close()
}

for _, ch := range chans {
select {
case <-ctx.Done():
t.Errorf("did not capture transport deletion event from all session in 10 seconds")
case <-ch: // Received transport deletion signal of this session
}
}

handler.mu.Lock()
if len(handler.transports) != 0 {
t.Errorf("want empty transports map, find %v entries from handler's transports map", len(handler.transports))
}
handler.mu.Unlock()
}

// TestServerInitiatedSSE verifies that the persistent SSE connection remains
// open and can receive server-initiated events.
func TestServerInitiatedSSE(t *testing.T) {
7 changes: 4 additions & 3 deletions mcp/transport.go
@@ -122,15 +122,16 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) {
}

type binder[T handler, State any] interface {
bind(Connection, *jsonrpc2.Connection, State) T
// TODO(rfindley): the bind API has gotten too complicated. Simplify.
bind(Connection, *jsonrpc2.Connection, State, func()) T
disconnect(T)
}

type handler interface {
handle(ctx context.Context, req *jsonrpc.Request) (any, error)
}

func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State) (H, error) {
func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) {
var zero H
mcpConn, err := t.Connect(ctx)
if err != nil {
@@ -143,7 +144,7 @@ connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
preempter canceller
)
bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler {
h = b.bind(mcpConn, conn, s)
h = b.bind(mcpConn, conn, s, onClose)
preempter.conn = conn
return jsonrpc2.HandlerFunc(h.handle)
}
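For context on the transport.go change above: binder is internal plumbing where connect creates the jsonrpc2 connection and asks the client or server to bind a session to it, and this PR adds a func() parameter so the onClose callback can ride along. The following is a minimal sketch of that generic pattern; the types echoBinder and echoSession are toys invented for this illustration, not SDK types, and the signatures are simplified.

package main

import "fmt"

// handler is the message-handling behavior shared by both session types.
type handler interface {
	handle(msg string) string
}

// binder mirrors the shape of the internal binder interface after this
// change: bind now also receives an onClose callback.
type binder[T handler, State any] interface {
	bind(conn string, state State, onClose func()) T
	disconnect(T)
}

// connect wires a connection to whichever side implements binder, threading
// the onClose callback through to bind.
func connect[H handler, State any](b binder[H, State], conn string, state State, onClose func()) H {
	return b.bind(conn, state, onClose)
}

// echoSession and echoBinder are toy implementations for this sketch only.
type echoSession struct{ onClose func() }

func (s *echoSession) handle(msg string) string { return "echo: " + msg }

type echoBinder struct{}

func (echoBinder) bind(conn string, state string, onClose func()) *echoSession {
	return &echoSession{onClose: onClose}
}

func (echoBinder) disconnect(*echoSession) {}

func main() {
	s := connect[*echoSession, string](echoBinder{}, "conn-1", "initial state", func() {
		fmt.Println("session closed; the handler can now clean up its transport entry")
	})
	fmt.Println(s.handle("hi"))
	s.onClose()
}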