Skip to content

Commit fb69971

Browse files
authored
mcp: cleanup transport after keepalive ping fails
An onClose function is passed to the ServerSession and ClientSession to help cleanup resources from the caller. The onClose function will be executed as part of the ServerSession and ClientSession closure. Fixes #258
1 parent 9754a2a commit fb69971

File tree

5 files changed

+118
-11
lines changed

5 files changed

+118
-11
lines changed

mcp/client.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ type ClientOptions struct {
7676

7777
// bind implements the binder[*ClientSession] interface, so that Clients can
7878
// be connected using [connect].
79-
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession {
79+
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession {
8080
assert(mcpConn != nil && conn != nil, "nil connection")
81-
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c}
81+
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose}
8282
if state != nil {
8383
cs.state = *state
8484
}
@@ -130,7 +130,7 @@ func (c *Client) capabilities() *ClientCapabilities {
130130
// server, calls or notifications will return an error wrapping
131131
// [ErrConnectionClosed].
132132
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
133-
cs, err = connect(ctx, t, c, (*clientSessionState)(nil))
133+
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
134134
if err != nil {
135135
return nil, err
136136
}
@@ -173,6 +173,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
173173
// Call [ClientSession.Close] to close the connection, or await server
174174
// termination with [ClientSession.Wait].
175175
type ClientSession struct {
176+
onClose func()
177+
176178
conn *jsonrpc2.Connection
177179
client *Client
178180
keepaliveCancel context.CancelFunc
@@ -208,7 +210,13 @@ func (cs *ClientSession) Close() error {
208210
if cs.keepaliveCancel != nil {
209211
cs.keepaliveCancel()
210212
}
211-
return cs.conn.Close()
213+
err := cs.conn.Close()
214+
215+
if cs.onClose != nil {
216+
cs.onClose()
217+
}
218+
219+
return err
212220
}
213221

214222
// Wait waits for the connection to be closed by the server.

mcp/server.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,9 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
698698

699699
// bind implements the binder[*ServerSession] interface, so that Servers can
700700
// be connected using [connect].
701-
func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState) *ServerSession {
701+
func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession {
702702
assert(mcpConn != nil && conn != nil, "nil connection")
703-
ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s}
703+
ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose}
704704
if state != nil {
705705
ss.state = *state
706706
}
@@ -727,6 +727,8 @@ func (s *Server) disconnect(cc *ServerSession) {
727727
// ServerSessionOptions configures the server session.
728728
type ServerSessionOptions struct {
729729
State *ServerSessionState
730+
731+
onClose func()
730732
}
731733

732734
// Connect connects the MCP server over the given transport and starts handling
@@ -739,10 +741,12 @@ type ServerSessionOptions struct {
739741
// If opts.State is non-nil, it is the initial state for the server.
740742
func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) {
741743
var state *ServerSessionState
744+
var onClose func()
742745
if opts != nil {
743746
state = opts.State
747+
onClose = opts.onClose
744748
}
745-
return connect(ctx, t, s, state)
749+
return connect(ctx, t, s, state, onClose)
746750
}
747751

748752
// TODO: (nit) move all ServerSession methods below the ServerSession declaration.
@@ -809,6 +813,8 @@ func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] {
809813
// Call [ServerSession.Close] to close the connection, or await client
810814
// termination with [ServerSession.Wait].
811815
type ServerSession struct {
816+
onClose func()
817+
812818
server *Server
813819
conn *jsonrpc2.Connection
814820
mcpConn Connection
@@ -1043,7 +1049,13 @@ func (ss *ServerSession) Close() error {
10431049
// Close is idempotent and conn.Close() handles concurrent calls correctly
10441050
ss.keepaliveCancel()
10451051
}
1046-
return ss.conn.Close()
1052+
err := ss.conn.Close()
1053+
1054+
if ss.onClose != nil {
1055+
ss.onClose()
1056+
}
1057+
1058+
return err
10471059
}
10481060

10491061
// Wait waits for the connection to be closed by the client.

mcp/streamable.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type StreamableHTTPHandler struct {
4040
getServer func(*http.Request) *Server
4141
opts StreamableHTTPOptions
4242

43+
onTransportDeletion func(sessionID string) // for testing only
44+
4345
mu sync.Mutex
4446
// TODO: we should store the ServerSession along with the transport, because
4547
// we need to cancel keepalive requests when closing the transport.
@@ -283,6 +285,19 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
283285
connectOpts = &ServerSessionOptions{
284286
State: state,
285287
}
288+
} else {
289+
// Cleanup is only required in stateful mode, as transportation is
290+
// not stored in the map otherwise.
291+
connectOpts = &ServerSessionOptions{
292+
onClose: func() {
293+
h.mu.Lock()
294+
delete(h.transports, transport.SessionID)
295+
h.mu.Unlock()
296+
if h.onTransportDeletion != nil {
297+
h.onTransportDeletion(transport.SessionID)
298+
}
299+
},
300+
}
286301
}
287302

288303
// Pass req.Context() here, to allow middleware to add context values.

mcp/streamable_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"bytes"
99
"context"
1010
"encoding/json"
11+
"errors"
1112
"fmt"
1213
"io"
1314
"maps"
@@ -332,6 +333,76 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
332333
}
333334
}
334335

336+
func TestServerTransportCleanup(t *testing.T) {
337+
server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond})
338+
339+
nClient := 3
340+
341+
var mu sync.Mutex
342+
var id int = -1 // session id starting from "0", "1", "2"...
343+
chans := make(map[string]chan struct{}, nClient)
344+
345+
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
346+
GetSessionID: func() string {
347+
mu.Lock()
348+
defer mu.Unlock()
349+
id++
350+
if id == nClient {
351+
t.Errorf("creating more than %v session", nClient)
352+
}
353+
chans[fmt.Sprint(id)] = make(chan struct{}, 1)
354+
return fmt.Sprint(id)
355+
},
356+
})
357+
358+
handler.onTransportDeletion = func(sessionID string) {
359+
chans[sessionID] <- struct{}{}
360+
}
361+
362+
httpServer := httptest.NewServer(handler)
363+
defer httpServer.Close()
364+
365+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
366+
defer cancel()
367+
368+
// Spin up clients connect to the same server but refuse to ping request.
369+
for range nClient {
370+
client := NewClient(testImpl, nil)
371+
pingMiddleware := func(next MethodHandler) MethodHandler {
372+
return func(
373+
ctx context.Context,
374+
method string,
375+
req Request,
376+
) (Result, error) {
377+
if method == "ping" {
378+
return &emptyResult{}, errors.New("ping error")
379+
}
380+
return next(ctx, method, req)
381+
}
382+
}
383+
client.AddReceivingMiddleware(pingMiddleware)
384+
clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
385+
if err != nil {
386+
t.Fatalf("client.Connect() failed: %v", err)
387+
}
388+
defer clientSession.Close()
389+
}
390+
391+
for _, ch := range chans {
392+
select {
393+
case <-ctx.Done():
394+
t.Errorf("did not capture transport deletion event from all session in 10 seconds")
395+
case <-ch: // Received transport deletion signal of this session
396+
}
397+
}
398+
399+
handler.mu.Lock()
400+
if len(handler.transports) != 0 {
401+
t.Errorf("want empty transports map, find %v entries from handler's transports map", len(handler.transports))
402+
}
403+
handler.mu.Unlock()
404+
}
405+
335406
// TestServerInitiatedSSE verifies that the persistent SSE connection remains
336407
// open and can receive server-initiated events.
337408
func TestServerInitiatedSSE(t *testing.T) {

mcp/transport.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,16 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) {
122122
}
123123

124124
type binder[T handler, State any] interface {
125-
bind(Connection, *jsonrpc2.Connection, State) T
125+
// TODO(rfindley): the bind API has gotten too complicated. Simplify.
126+
bind(Connection, *jsonrpc2.Connection, State, func()) T
126127
disconnect(T)
127128
}
128129

129130
type handler interface {
130131
handle(ctx context.Context, req *jsonrpc.Request) (any, error)
131132
}
132133

133-
func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State) (H, error) {
134+
func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) {
134135
var zero H
135136
mcpConn, err := t.Connect(ctx)
136137
if err != nil {
@@ -143,7 +144,7 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
143144
preempter canceller
144145
)
145146
bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler {
146-
h = b.bind(mcpConn, conn, s)
147+
h = b.bind(mcpConn, conn, s, onClose)
147148
preempter.conn = conn
148149
return jsonrpc2.HandlerFunc(h.handle)
149150
}

0 commit comments

Comments
 (0)