Skip to content

Commit b198721

Browse files
committed
mcp: cleanup transport after keepalive ping fails
An onClose function is passed to the ServerSession and ClientSession to help cleanup resources from the caller. The onClose function will be executed as part of the ServerSession and ClientSession closure. Fixes #258
1 parent 42f419f commit b198721

File tree

5 files changed

+117
-11
lines changed

5 files changed

+117
-11
lines changed

mcp/client.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ type ClientOptions struct {
7676

7777
// bind implements the binder[*ClientSession] interface, so that Clients can
7878
// be connected using [connect].
79-
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession {
79+
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession {
8080
assert(mcpConn != nil && conn != nil, "nil connection")
81-
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c}
81+
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose}
8282
if state != nil {
8383
cs.state = *state
8484
}
@@ -130,7 +130,7 @@ func (c *Client) capabilities() *ClientCapabilities {
130130
// server, calls or notifications will return an error wrapping
131131
// [ErrConnectionClosed].
132132
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
133-
cs, err = connect(ctx, t, c, (*clientSessionState)(nil))
133+
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
134134
if err != nil {
135135
return nil, err
136136
}
@@ -173,6 +173,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
173173
// Call [ClientSession.Close] to close the connection, or await server
174174
// termination with [ClientSession.Wait].
175175
type ClientSession struct {
176+
onClose func()
177+
176178
conn *jsonrpc2.Connection
177179
client *Client
178180
keepaliveCancel context.CancelFunc
@@ -208,7 +210,13 @@ func (cs *ClientSession) Close() error {
208210
if cs.keepaliveCancel != nil {
209211
cs.keepaliveCancel()
210212
}
211-
return cs.conn.Close()
213+
err := cs.conn.Close()
214+
215+
if cs.onClose != nil {
216+
cs.onClose()
217+
}
218+
219+
return err
212220
}
213221

214222
// Wait waits for the connection to be closed by the server.

mcp/server.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -681,9 +681,9 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
681681

682682
// bind implements the binder[*ServerSession] interface, so that Servers can
683683
// be connected using [connect].
684-
func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState) *ServerSession {
684+
func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession {
685685
assert(mcpConn != nil && conn != nil, "nil connection")
686-
ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s}
686+
ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose}
687687
if state != nil {
688688
ss.state = *state
689689
}
@@ -710,6 +710,8 @@ func (s *Server) disconnect(cc *ServerSession) {
710710
// ServerSessionOptions configures the server session.
711711
type ServerSessionOptions struct {
712712
State *ServerSessionState
713+
714+
onClose func()
713715
}
714716

715717
// Connect connects the MCP server over the given transport and starts handling
@@ -722,10 +724,12 @@ type ServerSessionOptions struct {
722724
// If opts.State is non-nil, it is the initial state for the server.
723725
func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) {
724726
var state *ServerSessionState
727+
var onClose func()
725728
if opts != nil {
726729
state = opts.State
730+
onClose = opts.onClose
727731
}
728-
return connect(ctx, t, s, state)
732+
return connect(ctx, t, s, state, onClose)
729733
}
730734

731735
// TODO: (nit) move all ServerSession methods below the ServerSession declaration.
@@ -792,6 +796,8 @@ func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] {
792796
// Call [ServerSession.Close] to close the connection, or await client
793797
// termination with [ServerSession.Wait].
794798
type ServerSession struct {
799+
onClose func()
800+
795801
server *Server
796802
conn *jsonrpc2.Connection
797803
mcpConn Connection
@@ -1018,7 +1024,13 @@ func (ss *ServerSession) Close() error {
10181024
// Close is idempotent and conn.Close() handles concurrent calls correctly
10191025
ss.keepaliveCancel()
10201026
}
1021-
return ss.conn.Close()
1027+
err := ss.conn.Close()
1028+
1029+
if ss.onClose != nil {
1030+
ss.onClose()
1031+
}
1032+
1033+
return err
10221034
}
10231035

10241036
// Wait waits for the connection to be closed by the client.

mcp/streamable.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type StreamableHTTPHandler struct {
4040
getServer func(*http.Request) *Server
4141
opts StreamableHTTPOptions
4242

43+
onTransportDeletion func(sessionID string) // for testing only
44+
4345
mu sync.Mutex
4446
// TODO: we should store the ServerSession along with the transport, because
4547
// we need to cancel keepalive requests when closing the transport.
@@ -283,6 +285,17 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
283285
connectOpts = &ServerSessionOptions{
284286
State: state,
285287
}
288+
} else {
289+
connectOpts = &ServerSessionOptions{
290+
onClose: func() {
291+
h.mu.Lock()
292+
delete(h.transports, transport.SessionID)
293+
h.mu.Unlock()
294+
if h.onTransportDeletion != nil {
295+
h.onTransportDeletion(transport.SessionID)
296+
}
297+
},
298+
}
286299
}
287300

288301
// Pass req.Context() here, to allow middleware to add context values.

mcp/streamable_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"bytes"
99
"context"
1010
"encoding/json"
11+
"errors"
1112
"fmt"
1213
"io"
1314
"maps"
@@ -332,6 +333,78 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
332333
}
333334
}
334335

336+
func TestServerTransportCleanup(t *testing.T) {
337+
server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond})
338+
339+
nClient := 3
340+
341+
var mu sync.Mutex
342+
var id int = -1 // session id starting from "0", "1", "2"...
343+
chans := make(map[string]chan struct{}, nClient)
344+
345+
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
346+
GetSessionID: func() string {
347+
mu.Lock()
348+
defer mu.Unlock()
349+
id++
350+
if id == nClient {
351+
t.Errorf("creating more than %v session", nClient)
352+
}
353+
chans[fmt.Sprint(id)] = make(chan struct{}, 1)
354+
return fmt.Sprint(id)
355+
},
356+
})
357+
358+
handler.onTransportDeletion = func(sessionID string) {
359+
mu.Lock()
360+
chans[sessionID] <- struct{}{}
361+
mu.Unlock()
362+
}
363+
364+
httpServer := httptest.NewServer(handler)
365+
defer httpServer.Close()
366+
367+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
368+
defer cancel()
369+
370+
// Spin up clients connect to the same server but refuse to ping request.
371+
for range nClient {
372+
client := NewClient(testImpl, nil)
373+
pingMiddleware := func(next MethodHandler) MethodHandler {
374+
return func(
375+
ctx context.Context,
376+
method string,
377+
req Request,
378+
) (Result, error) {
379+
if method == "ping" {
380+
return &emptyResult{}, errors.New("ping error")
381+
}
382+
return next(ctx, method, req)
383+
}
384+
}
385+
client.AddReceivingMiddleware(pingMiddleware)
386+
clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
387+
if err != nil {
388+
t.Fatalf("client.Connect() failed: %v", err)
389+
}
390+
defer clientSession.Close()
391+
}
392+
393+
for _, ch := range chans {
394+
select {
395+
case <-ctx.Done():
396+
t.Errorf("did not capture transport deletion event from all session in 10 seconds")
397+
case <-ch: // Received transport deletion signal of this session
398+
}
399+
}
400+
401+
handler.mu.Lock()
402+
if len(handler.transports) != 0 {
403+
t.Errorf("want empty transports map, find %v entries from handler's transports map", len(handler.transports))
404+
}
405+
handler.mu.Unlock()
406+
}
407+
335408
// TestServerInitiatedSSE verifies that the persistent SSE connection remains
336409
// open and can receive server-initiated events.
337410
func TestServerInitiatedSSE(t *testing.T) {

mcp/transport.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,15 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) {
122122
}
123123

124124
type binder[T handler, State any] interface {
125-
bind(Connection, *jsonrpc2.Connection, State) T
125+
bind(Connection, *jsonrpc2.Connection, State, func()) T
126126
disconnect(T)
127127
}
128128

129129
type handler interface {
130130
handle(ctx context.Context, req *jsonrpc.Request) (any, error)
131131
}
132132

133-
func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State) (H, error) {
133+
func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) {
134134
var zero H
135135
mcpConn, err := t.Connect(ctx)
136136
if err != nil {
@@ -143,7 +143,7 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
143143
preempter canceller
144144
)
145145
bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler {
146-
h = b.bind(mcpConn, conn, s)
146+
h = b.bind(mcpConn, conn, s, onClose)
147147
preempter.conn = conn
148148
return jsonrpc2.HandlerFunc(h.handle)
149149
}

0 commit comments

Comments
 (0)