diff --git a/mcp/client.go b/mcp/client.go index 3b1741b3..1ed3b048 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -76,9 +76,9 @@ type ClientOptions struct { // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. -func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession { +func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { assert(mcpConn != nil && conn != nil, "nil connection") - cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c} + cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose} if state != nil { cs.state = *state } @@ -130,7 +130,7 @@ func (c *Client) capabilities() *ClientCapabilities { // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) { - cs, err = connect(ctx, t, c, (*clientSessionState)(nil)) + cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil) if err != nil { return nil, err } @@ -173,6 +173,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio // Call [ClientSession.Close] to close the connection, or await server // termination with [ClientSession.Wait]. type ClientSession struct { + onClose func() + conn *jsonrpc2.Connection client *Client keepaliveCancel context.CancelFunc @@ -208,7 +210,13 @@ func (cs *ClientSession) Close() error { if cs.keepaliveCancel != nil { cs.keepaliveCancel() } - return cs.conn.Close() + err := cs.conn.Close() + + if cs.onClose != nil { + cs.onClose() + } + + return err } // Wait waits for the connection to be closed by the server. 
diff --git a/mcp/server.go b/mcp/server.go index c44dfeb6..49a3f56e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -681,9 +681,9 @@ func (s *Server) Run(ctx context.Context, t Transport) error { // bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. -func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState) *ServerSession { +func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession { assert(mcpConn != nil && conn != nil, "nil connection") - ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s} + ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose} if state != nil { ss.state = *state } @@ -710,6 +710,8 @@ func (s *Server) disconnect(cc *ServerSession) { // ServerSessionOptions configures the server session. type ServerSessionOptions struct { State *ServerSessionState + + onClose func() } // Connect connects the MCP server over the given transport and starts handling @@ -722,10 +724,12 @@ type ServerSessionOptions struct { // If opts.State is non-nil, it is the initial state for the server. func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { var state *ServerSessionState + var onClose func() if opts != nil { state = opts.State + onClose = opts.onClose } - return connect(ctx, t, s, state) + return connect(ctx, t, s, state, onClose) } // TODO: (nit) move all ServerSession methods below the ServerSession declaration. @@ -792,6 +796,8 @@ func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { // Call [ServerSession.Close] to close the connection, or await client // termination with [ServerSession.Wait]. 
type ServerSession struct { + onClose func() + server *Server conn *jsonrpc2.Connection mcpConn Connection @@ -1018,7 +1024,13 @@ func (ss *ServerSession) Close() error { // Close is idempotent and conn.Close() handles concurrent calls correctly ss.keepaliveCancel() } - return ss.conn.Close() + err := ss.conn.Close() + + if ss.onClose != nil { + ss.onClose() + } + + return err } // Wait waits for the connection to be closed by the client. diff --git a/mcp/streamable.go b/mcp/streamable.go index 99fbe422..f56b7084 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -40,6 +40,8 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions + onTransportDeletion func(sessionID string) // for testing only + mu sync.Mutex // TODO: we should store the ServerSession along with the transport, because // we need to cancel keepalive requests when closing the transport. @@ -283,6 +285,19 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque connectOpts = &ServerSessionOptions{ State: state, } + } else { + // Cleanup is only required in stateful mode, as the transport is + // not stored in the map otherwise. + connectOpts = &ServerSessionOptions{ + onClose: func() { + h.mu.Lock() + delete(h.transports, transport.SessionID) + h.mu.Unlock() + if h.onTransportDeletion != nil { + h.onTransportDeletion(transport.SessionID) + } + }, + } } // Pass req.Context() here, to allow middleware to add context values.
diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 603be473..3d2e1c02 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "maps" @@ -332,6 +333,76 @@ func testClientReplay(t *testing.T, test clientReplayTest) { } } +func TestServerTransportCleanup(t *testing.T) { + server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond}) + + nClient := 3 + + var mu sync.Mutex + var id int = -1 // session id starting from "0", "1", "2"... + chans := make(map[string]chan struct{}, nClient) + + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + GetSessionID: func() string { + mu.Lock() + defer mu.Unlock() + id++ + if id == nClient { + t.Errorf("creating more than %v session", nClient) + } + chans[fmt.Sprint(id)] = make(chan struct{}, 1) + return fmt.Sprint(id) + }, + }) + + handler.onTransportDeletion = func(sessionID string) { + chans[sessionID] <- struct{}{} + } + + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Spin up clients that connect to the same server but refuse ping requests. 
+ for range nClient { + client := NewClient(testImpl, nil) + pingMiddleware := func(next MethodHandler) MethodHandler { + return func( + ctx context.Context, + method string, + req Request, + ) (Result, error) { + if method == "ping" { + return &emptyResult{}, errors.New("ping error") + } + return next(ctx, method, req) + } + } + client.AddReceivingMiddleware(pingMiddleware) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + } + + for _, ch := range chans { + select { + case <-ctx.Done(): + t.Errorf("did not capture transport deletion event from all session in 10 seconds") + case <-ch: // Received transport deletion signal of this session + } + } + + handler.mu.Lock() + if len(handler.transports) != 0 { + t.Errorf("want empty transports map, find %v entries from handler's transports map", len(handler.transports)) + } + handler.mu.Unlock() +} + // TestServerInitiatedSSE verifies that the persistent SSE connection remains // open and can receive server-initiated events. func TestServerInitiatedSSE(t *testing.T) { diff --git a/mcp/transport.go b/mcp/transport.go index 2bcd8d7d..fac640a6 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -122,7 +122,8 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { } type binder[T handler, State any] interface { - bind(Connection, *jsonrpc2.Connection, State) T + // TODO(rfindley): the bind API has gotten too complicated. Simplify. 
+ bind(Connection, *jsonrpc2.Connection, State, func()) T disconnect(T) } @@ -130,7 +131,7 @@ type handler interface { handle(ctx context.Context, req *jsonrpc.Request) (any, error) } -func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State) (H, error) { +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) { var zero H mcpConn, err := t.Connect(ctx) if err != nil { @@ -143,7 +144,7 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, preempter canceller ) bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler { - h = b.bind(mcpConn, conn, s) + h = b.bind(mcpConn, conn, s, onClose) preempter.conn = conn return jsonrpc2.HandlerFunc(h.handle) }