From 52f6ff42e60b96717322fa17c67c53ee850a38ef Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 4 Aug 2025 12:23:03 -0400 Subject: [PATCH 1/6] server: support distributed sessions - SessionState holds the state of a ServerSession. - SessionStore allows arbitrary storage for SessionState. - ServerSession.InitSession provides a session ID, store and state to a ServerSession. --- mcp/logging.go | 2 +- mcp/server.go | 41 ++++++++++++++++------ mcp/session.go | 74 +++++++++++++++++++++++++++++++++++++++ mcp/session_test.go | 48 ++++++++++++++++++++++++++ mcp/streamable.go | 84 +++++++++++++++++++++++++++++---------------- 5 files changed, 208 insertions(+), 41 deletions(-) create mode 100644 mcp/session.go create mode 100644 mcp/session_test.go diff --git a/mcp/logging.go b/mcp/logging.go index 4880e179..4d33097a 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -117,7 +117,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { // This is also checked in ServerSession.LoggingMessage, so checking it here // is just an optimization that skips building the JSON. h.ss.mu.Lock() - mcpLevel := h.ss.logLevel + mcpLevel := h.ss.state.LogLevel h.ss.mu.Unlock() return level >= mcpLevelToSlog(mcpLevel) } diff --git a/mcp/server.go b/mcp/server.go index e69a872e..bfaa9c37 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -536,7 +536,7 @@ func (s *Server) Run(ctx context.Context, t Transport) error { // bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { - ss := &ServerSession{conn: conn, server: s} + ss := &ServerSession{conn: conn, server: s, state: &SessionState{}} s.mu.Lock() s.sessions = append(s.sessions, ss) s.mu.Unlock() @@ -619,17 +619,26 @@ type ServerSession struct { initializeParams *InitializeParams _initialized bool keepaliveCancel context.CancelFunc + sessionID string + state *SessionState + store SessionStore } func (ss *ServerSession) setConn(c Connection) { - ss.mcpConn = c } -func (ss *ServerSession) ID() string { - if ss.mcpConn == nil { - return "" - } - return ss.mcpConn.SessionID() +func (ss *ServerSession) ID() string { return ss.sessionID } + +// InitSession initializes the session with a session ID, state, and store. +// If called, it must be called immediately after the session is connected. +// If never called, the session will begin with a zero SessionState and no session +// ID or store. +func (ss *ServerSession) InitSession(sessionID string, state *SessionState, store SessionStore) { + ss.mu.Lock() + defer ss.mu.Unlock() + ss.sessionID = sessionID + ss.state = state + ss.store = store } // Ping pings the client. @@ -653,7 +662,7 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag // is below that of the last SetLevel. func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { ss.mu.Lock() - logLevel := ss.logLevel + logLevel := ss.state.LogLevel ss.mu.Unlock() if logLevel == "" { // The spec is unclear, but seems to imply that no log messages are sent until the client @@ -767,8 +776,13 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } ss.mu.Lock() - ss.initializeParams = params + ss.state.InitializeParams = params ss.mu.Unlock() + if ss.store != nil { + if err := ss.store.Store(ctx, ss.sessionID, ss.state); err != nil { + return nil, fmt.Errorf("storing session state: %w", err) + } + } // If we support the client's version, reply with it. Otherwise, reply with our // latest version. @@ -791,10 +805,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error return &emptyResult{}, nil } -func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) { +func (ss *ServerSession) setLevel(ctx context.Context, params *SetLevelParams) (*emptyResult, error) { ss.mu.Lock() defer ss.mu.Unlock() - ss.logLevel = params.Level + ss.state.LogLevel = params.Level + if ss.store != nil { + if err := ss.store.Store(ctx, ss.sessionID, ss.state); err != nil { + return nil, err + } + } return &emptyResult{}, nil } diff --git a/mcp/session.go b/mcp/session.go new file mode 100644 index 00000000..d4a944c1 --- /dev/null +++ b/mcp/session.go @@ -0,0 +1,74 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "io/fs" + "sync" +) + +// SessionState is the state of a session. +type SessionState struct { + // InitializeParams are the parameters from the initialize request. + InitializeParams *InitializeParams `json:"initializeParams"` + + // LogLevel is the logging level for the session. + LogLevel LoggingLevel `json:"logLevel"` + + // TODO: resource subscriptions +} + +// SessionStore is an interface for storing and retrieving session state. +type SessionStore interface { + // Load retrieves the session state for the given session ID. + // If there is none, it returns nil, fs.ErrNotExist. + Load(ctx context.Context, sessionID string) (*SessionState, error) + // Store saves the session state for the given session ID. + Store(ctx context.Context, sessionID string, state *SessionState) error + // Delete removes the session state for the given session ID. + Delete(ctx context.Context, sessionID string) error +} + +// MemorySessionStore is an in-memory implementation of SessionStore. +// It is safe for concurrent use. +type MemorySessionStore struct { + mu sync.Mutex + store map[string]*SessionState +} + +// NewMemorySessionStore creates a new MemorySessionStore. +func NewMemorySessionStore() *MemorySessionStore { + return &MemorySessionStore{ + store: make(map[string]*SessionState), + } +} + +// Load retrieves the session state for the given session ID. +func (s *MemorySessionStore) Load(ctx context.Context, sessionID string) (*SessionState, error) { + s.mu.Lock() + defer s.mu.Unlock() + state, ok := s.store[sessionID] + if !ok { + return nil, fs.ErrNotExist + } + return state, nil +} + +// Store saves the session state for the given session ID. +func (s *MemorySessionStore) Store(ctx context.Context, sessionID string, state *SessionState) error { + s.mu.Lock() + defer s.mu.Unlock() + s.store[sessionID] = state + return nil +} + +// Delete removes the session state for the given session ID. +func (s *MemorySessionStore) Delete(ctx context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.store, sessionID) + return nil +} diff --git a/mcp/session_test.go b/mcp/session_test.go new file mode 100644 index 00000000..a4d2e966 --- /dev/null +++ b/mcp/session_test.go @@ -0,0 +1,48 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "io/fs" + "testing" +) + +func TestMemorySessionStore(t *testing.T) { + ctx := context.Background() + store := NewMemorySessionStore() + + sessionID := "test-session" + state := &SessionState{LogLevel: "debug"} + + // Test Store and Load + if err := store.Store(ctx, sessionID, state); err != nil { + t.Fatalf("Store() error = %v", err) + } + + loadedState, err := store.Load(ctx, sessionID) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loadedState == nil { + t.Fatal("Load() returned nil state") + } + if loadedState.LogLevel != state.LogLevel { + t.Errorf("Load() LogLevel = %v, want %v", loadedState.LogLevel, state.LogLevel) + } + + // Test Delete + if err := store.Delete(ctx, sessionID); err != nil { + t.Fatalf("Delete() error = %v", err) + } + + deletedState, err := store.Load(ctx, sessionID) + if err != fs.ErrNotExist { + t.Fatalf("Load() after Delete(): got %v, want fs.ErrNotExist", err) + } + if deletedState != nil { + t.Error("Load() after Delete() returned non-nil state") + } +} diff --git a/mcp/streamable.go b/mcp/streamable.go index 108de5d2..278b3ea0 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "io/fs" "iter" "math" "math/rand/v2" @@ -35,11 +36,11 @@ const ( // // [MCP spec]: https://modelcontextprotocol.io/2025/03/26/streamable-http-transport.html type StreamableHTTPHandler struct { - getServer func(*http.Request) *Server opts StreamableHTTPOptions + getServer func(*http.Request) *Server - sessionsMu sync.Mutex - sessions map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) + transportMu sync.Mutex + transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } // StreamableHTTPOptions is a placeholder options struct for future @@ -51,6 +52,10 @@ type StreamableHTTPOptions struct { // transportOptions sets the streamable server transport options to use when // establishing a new session. transportOptions *StreamableServerTransportOptions + + // SessionStore is the store for persistent sessions. + // If nil, sessions will be stored in memory. + SessionStore SessionStore } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -60,12 +65,15 @@ type StreamableHTTPOptions struct { // If getServer returns nil, a 400 Bad Request will be served. func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { h := &StreamableHTTPHandler{ - getServer: getServer, - sessions: make(map[string]*StreamableServerTransport), + getServer: getServer, + transports: make(map[string]*StreamableServerTransport), } if opts != nil { h.opts = *opts } + if h.opts.SessionStore == nil { + h.opts.SessionStore = NewMemorySessionStore() + } return h } @@ -77,12 +85,12 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea // Should we allow passing in a session store? That would allow the handler to // be stateless. func (h *StreamableHTTPHandler) closeAll() { - h.sessionsMu.Lock() - defer h.sessionsMu.Unlock() - for _, s := range h.sessions { - s.Close() + h.transportMu.Lock() + defer h.transportMu.Unlock() + for _, t := range h.transports { + t.Close() } - h.sessions = nil + h.transports = nil } func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -110,14 +118,11 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } var session *StreamableServerTransport - if id := req.Header.Get(sessionIDHeader); id != "" { - h.sessionsMu.Lock() - session, _ = h.sessions[id] - h.sessionsMu.Unlock() - if session == nil { - http.Error(w, "session not found", http.StatusNotFound) - return - } + sessionID := req.Header.Get(sessionIDHeader) + if sessionID != "" { + h.transportMu.Lock() + session, _ = h.transports[sessionID] + h.transportMu.Unlock() } // TODO(rfindley): simplify the locking so that each request has only one @@ -128,9 +133,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } - h.sessionsMu.Lock() - delete(h.sessions, session.sessionID) - h.sessionsMu.Unlock() + h.transportMu.Lock() + delete(h.transports, session.sessionID) + h.transportMu.Unlock() session.Close() w.WriteHeader(http.StatusNoContent) return @@ -145,24 +150,45 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } if session == nil { - s := NewStreamableServerTransport(randText(), h.opts.transportOptions) + var state *SessionState + var err error + if sessionID != "" { + // The session might be in the store. + state, err = h.opts.SessionStore.Load(req.Context(), sessionID) + if errors.Is(err, fs.ErrNotExist) { + http.Error(w, fmt.Sprintf("no session with ID %s", sessionID), http.StatusNotFound) + return + } else if err != nil { + http.Error(w, fmt.Sprintf("SessionStore.Load(%q): %v", sessionID, err), http.StatusInternalServerError) + return + } + session = NewStreamableServerTransport(sessionID, nil) + } else { + state = &SessionState{} + sessionID = randText() + if err := h.opts.SessionStore.Store(req.Context(), sessionID, state); err != nil { + http.Error(w, fmt.Sprintf("SessionStore.Store, new session: %v", err), http.StatusInternalServerError) + return + } + session = NewStreamableServerTransport(sessionID, nil) + } server := h.getServer(req) if server == nil { - // The getServer argument to NewStreamableHTTPHandler returned nil. http.Error(w, "no server available", http.StatusBadRequest) return } // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - if _, err := server.Connect(req.Context(), s); err != nil { - http.Error(w, "failed connection", http.StatusInternalServerError) + ss, err := server.Connect(req.Context(), session) + if err != nil { + http.Error(w, fmt.Sprintf("failed connection: %v", err), http.StatusInternalServerError) return } - h.sessionsMu.Lock() - h.sessions[s.sessionID] = s - h.sessionsMu.Unlock() - session = s + ss.InitSession(sessionID, state, h.opts.SessionStore) + h.transportMu.Lock() + h.transports[session.sessionID] = session + h.transportMu.Unlock() } session.ServeHTTP(w, req) From d9bf37ef74116c0393ded034dc7fd3e03644e719 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 4 Aug 2025 16:02:09 -0400 Subject: [PATCH 2/6] add Connect arg (#232) --- mcp/example_middleware_test.go | 2 +- mcp/logging.go | 2 +- mcp/mcp_test.go | 14 +++--- mcp/server.go | 78 +++++++++++++++++++--------------- mcp/server_example_test.go | 4 +- mcp/sse.go | 2 +- mcp/streamable.go | 7 ++- 7 files changed, 60 insertions(+), 49 deletions(-) diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 597b9dcd..05b07d8a 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -114,7 +114,7 @@ func Example_loggingMiddleware() { ctx := context.Background() // Connect server and client - serverSession, _ := server.Connect(ctx, serverTransport) + serverSession, _ := server.Connect(ctx, serverTransport, nil) defer serverSession.Close() clientSession, _ := client.Connect(ctx, clientTransport) diff --git a/mcp/logging.go b/mcp/logging.go index 4d33097a..0c031f19 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -117,7 +117,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { // This is also checked in ServerSession.LoggingMessage, so checking it here // is just an optimization that skips building the JSON. h.ss.mu.Lock() - mcpLevel := h.ss.state.LogLevel + mcpLevel := h.ss.opts.SessionState.LogLevel h.ss.mu.Unlock() return level >= mcpLevelToSlog(mcpLevel) } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 48e95de2..abc5b398 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) { s.AddResource(resource2, readHandler) // Connect the server. - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -549,7 +549,7 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien if config != nil { config(s) } - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -598,7 +598,7 @@ func TestBatching(t *testing.T) { ct, st := NewInMemoryTransports() s := NewServer(testImpl, nil) - _, err := s.Connect(ctx, st) + _, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) { ct, st := NewInMemoryTransports() s := NewServer(testImpl, nil) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -777,7 +777,7 @@ func TestNoJSONNull(t *testing.T) { ct = NewLoggingTransport(ct, &logbuf) s := NewServer(testImpl, nil) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) { s := NewServer(testImpl, serverOpts) AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) { // Server without keepalive (to test one-sided keepalive) s := NewServer(testImpl, nil) AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/server.go b/mcp/server.go index bfaa9c37..4d153f4f 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "encoding/gob" "encoding/json" + "errors" "fmt" "iter" "maps" @@ -514,7 +515,8 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns // If no tools have been added, the server will not have the tool capability. // The same goes for other features like prompts and resources. func (s *Server) Run(ctx context.Context, t Transport) error { - ss, err := s.Connect(ctx, t) + // TODO: provide a way to pass ServerSessionOptions? + ss, err := s.Connect(ctx, t, nil) if err != nil { return err } @@ -536,7 +538,7 @@ func (s *Server) Run(ctx context.Context, t Transport) error { // bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { - ss := &ServerSession{conn: conn, server: s, state: &SessionState{}} + ss := &ServerSession{conn: conn, server: s} s.mu.Lock() s.sessions = append(s.sessions, ss) s.mu.Unlock() @@ -557,14 +559,33 @@ func (s *Server) disconnect(cc *ServerSession) { } } +type ServerSessionOptions struct { + SessionID string + SessionState *SessionState + SessionStore SessionStore +} + // Connect connects the MCP server over the given transport and starts handling // messages. // // It returns a connection object that may be used to terminate the connection // (with [Connection.Close]), or await client termination (with // [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, error) { - return connect(ctx, t, s) +func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { + if opts != nil && opts.SessionState == nil && opts.SessionStore != nil { + return nil, errors.New("ServerSessionOptions has store but no state") + } + ss, err := connect(ctx, t, s) + if err != nil { + return nil, err + } + if opts != nil { + ss.opts = *opts + } + if ss.opts.SessionState == nil { + ss.opts.SessionState = &SessionState{} + } + return ss, nil } func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { @@ -572,7 +593,7 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar ss.startKeepalive(ss.server.opts.KeepAlive) } ss.mu.Lock() - hasParams := ss.initializeParams != nil + hasParams := ss.opts.SessionState.InitializeParams != nil wasInitialized := ss._initialized if hasParams { ss._initialized = true @@ -611,34 +632,21 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot // Call [ServerSession.Close] to close the connection, or await client // termination with [ServerSession.Wait]. type ServerSession struct { - server *Server - conn *jsonrpc2.Connection - mcpConn Connection - mu sync.Mutex - logLevel LoggingLevel - initializeParams *InitializeParams - _initialized bool - keepaliveCancel context.CancelFunc - sessionID string - state *SessionState - store SessionStore + server *Server + conn *jsonrpc2.Connection + opts ServerSessionOptions + mu sync.Mutex + logLevel LoggingLevel + _initialized bool + keepaliveCancel context.CancelFunc + sessionID string } func (ss *ServerSession) setConn(c Connection) { } -func (ss *ServerSession) ID() string { return ss.sessionID } - -// InitSession initializes the session with a session ID, state, and store. -// If called, it must be called immediately after the session is connected. -// If never called, the session will begin with a zero SessionState and no session -// ID or store. -func (ss *ServerSession) InitSession(sessionID string, state *SessionState, store SessionStore) { - ss.mu.Lock() - defer ss.mu.Unlock() - ss.sessionID = sessionID - ss.state = state - ss.store = store +func (ss *ServerSession) ID() string { + return ss.opts.SessionID } // Ping pings the client. @@ -662,7 +670,7 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag // is below that of the last SetLevel. func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { ss.mu.Lock() - logLevel := ss.state.LogLevel + logLevel := ss.opts.SessionState.LogLevel ss.mu.Unlock() if logLevel == "" { // The spec is unclear, but seems to imply that no log messages are sent until the client @@ -776,10 +784,10 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } ss.mu.Lock() - ss.state.InitializeParams = params + ss.opts.SessionState.InitializeParams = params ss.mu.Unlock() - if ss.store != nil { - if err := ss.store.Store(ctx, ss.sessionID, ss.state); err != nil { + if store := ss.opts.SessionStore; store != nil { + if err := store.Store(ctx, ss.opts.SessionID, ss.opts.SessionState); err != nil { return nil, fmt.Errorf("storing session state: %w", err) } } @@ -808,9 +816,9 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error func (ss *ServerSession) setLevel(ctx context.Context, params *SetLevelParams) (*emptyResult, error) { ss.mu.Lock() defer ss.mu.Unlock() - ss.state.LogLevel = params.Level - if ss.store != nil { - if err := ss.store.Store(ctx, ss.sessionID, ss.state); err != nil { + ss.opts.SessionState.LogLevel = params.Level + if store := ss.opts.SessionStore; store != nil { + if err := store.Store(ctx, ss.opts.SessionID, ss.opts.SessionState); err != nil { return nil, err } } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 3ab7a2a4..33bde4b9 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -31,7 +31,7 @@ func ExampleServer() { server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) - serverSession, err := server.Connect(ctx, serverTransport) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } @@ -62,7 +62,7 @@ func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession server := mcp.NewServer(testImpl, nil) client := mcp.NewClient(testImpl, nil) serverTransport, clientTransport := mcp.NewInMemoryTransports() - serverSession, err := server.Connect(ctx, serverTransport) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/sse.go b/mcp/sse.go index bdc4770b..f74a3fb6 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -221,7 +221,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.Error(w, "no server available", http.StatusBadRequest) return } - ss, err := server.Connect(req.Context(), transport) + ss, err := server.Connect(req.Context(), transport, nil) if err != nil { http.Error(w, "connection failed", http.StatusInternalServerError) return diff --git a/mcp/streamable.go b/mcp/streamable.go index 278b3ea0..7c58d3bc 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -180,12 +180,15 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - ss, err := server.Connect(req.Context(), session) + _, err = server.Connect(req.Context(), session, &ServerSessionOptions{ + SessionID: sessionID, + SessionState: state, + SessionStore: h.opts.SessionStore, + }) if err != nil { http.Error(w, fmt.Sprintf("failed connection: %v", err), http.StatusInternalServerError) return } - ss.InitSession(sessionID, state, h.opts.SessionStore) h.transportMu.Lock() h.transports[session.sessionID] = session h.transportMu.Unlock() From 5b44363b18cbc059c20adea34fcdd255d7f84589 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 9 Aug 2025 13:40:49 -0400 Subject: [PATCH 3/6] fix race (#232) --- mcp/conformance_test.go | 2 +- mcp/server.go | 6 +++--- mcp/streamable.go | 25 ++++++++++++++----------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 883d8a89..8e6ea1be 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { // Connect the server, and connect the client stream, // but don't connect an actual client. cTransport, sTransport := NewInMemoryTransports() - ss, err := s.Connect(ctx, sTransport) + ss, err := s.Connect(ctx, sTransport, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/server.go b/mcp/server.go index 4d153f4f..6cbae0de 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -559,7 +559,7 @@ func (s *Server) disconnect(cc *ServerSession) { } } -type ServerSessionOptions struct { +type SessionOptions struct { SessionID string SessionState *SessionState SessionStore SessionStore @@ -571,7 +571,7 @@ type ServerSessionOptions struct { // It returns a connection object that may be used to terminate the connection // (with [Connection.Close]), or await client termination (with // [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { +func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) (*ServerSession, error) { if opts != nil && opts.SessionState == nil && opts.SessionStore != nil { return nil, errors.New("ServerSessionOptions has store but no state") } @@ -634,7 +634,7 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot type ServerSession struct { server *Server conn *jsonrpc2.Connection - opts ServerSessionOptions + opts SessionOptions mu sync.Mutex logLevel LoggingLevel _initialized bool diff --git a/mcp/streamable.go b/mcp/streamable.go index 7c58d3bc..60205ece 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -117,26 +117,26 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - var session *StreamableServerTransport + var transport *StreamableServerTransport sessionID := req.Header.Get(sessionIDHeader) if sessionID != "" { h.transportMu.Lock() - session, _ = h.transports[sessionID] + transport, _ = h.transports[sessionID] h.transportMu.Unlock() } // TODO(rfindley): simplify the locking so that each request has only one // critical section. if req.Method == http.MethodDelete { - if session == nil { + if transport == nil { // => Mcp-Session-Id was not set; else we'd have returned NotFound above. http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } h.transportMu.Lock() - delete(h.transports, session.sessionID) + delete(h.transports, transport.sessionID) h.transportMu.Unlock() - session.Close() + transport.Close() w.WriteHeader(http.StatusNoContent) return } @@ -149,7 +149,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - if session == nil { + if transport == nil { var state *SessionState var err error if sessionID != "" { @@ -162,7 +162,6 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, fmt.Sprintf("SessionStore.Load(%q): %v", sessionID, err), http.StatusInternalServerError) return } - session = NewStreamableServerTransport(sessionID, nil) } else { state = &SessionState{} sessionID = randText() @@ -170,8 +169,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, fmt.Sprintf("SessionStore.Store, new session: %v", err), http.StatusInternalServerError) return } - session = NewStreamableServerTransport(sessionID, nil) } + transport = NewStreamableServerTransport(sessionID, nil) server := h.getServer(req) if server == nil { http.Error(w, "no server available", http.StatusBadRequest) @@ -180,7 +179,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - _, err = server.Connect(req.Context(), session, &ServerSessionOptions{ + // TODO: rename SessionOptions to ConnectOptions? + _, err = server.Connect(req.Context(), transport, &SessionOptions{ SessionID: sessionID, SessionState: state, SessionStore: h.opts.SessionStore, @@ -190,11 +190,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } h.transportMu.Lock() - h.transports[session.sessionID] = session + // Check in case another request with the same stored session ID got here first. + if _, ok := h.transports[transport.sessionID]; !ok { + h.transports[transport.sessionID] = transport + } h.transportMu.Unlock() } - session.ServeHTTP(w, req) + transport.ServeHTTP(w, req) } type StreamableServerTransportOptions struct { From 5ebbb94d19609cdd2b0f152c0dadde021df9e8c6 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 11 Aug 2025 07:03:09 -0400 Subject: [PATCH 4/6] load/save state for each method; add test --- mcp/logging.go | 12 ++--- mcp/server.go | 116 +++++++++++++++++++++++++++-------------- mcp/session.go | 13 +++-- mcp/session_test.go | 6 +-- mcp/streamable.go | 6 ++- mcp/streamable_test.go | 92 ++++++++++++++++++++++++++++++-- 6 files changed, 188 insertions(+), 57 deletions(-) diff --git a/mcp/logging.go b/mcp/logging.go index 0c031f19..e09e000e 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -112,14 +112,12 @@ func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingH return lh } -// Enabled implements [slog.Handler.Enabled] by comparing level to the [ServerSession]'s level. +// Enabled implements [slog.Handler.Enabled]. func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { - // This is also checked in ServerSession.LoggingMessage, so checking it here - // is just an optimization that skips building the JSON. - h.ss.mu.Lock() - mcpLevel := h.ss.opts.SessionState.LogLevel - h.ss.mu.Unlock() - return level >= mcpLevelToSlog(mcpLevel) + // This is already checked in ServerSession.LoggingMessage. Checking it here + // would be an optimization that skips building the JSON, but it would + // end up loading the SessionState twice, so don't do it. + return true } // WithAttrs implements [slog.Handler.WithAttrs]. diff --git a/mcp/server.go b/mcp/server.go index 6cbae0de..584c1020 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -560,17 +560,22 @@ func (s *Server) disconnect(cc *ServerSession) { } type SessionOptions struct { - SessionID string + // SessionID is the session's unique ID. + SessionID string + // SessionState is the current state of the session. The default is the initial + // state. SessionState *SessionState + // SessionStore stores SessionStates. By default it is a MemorySessionStore. SessionStore SessionStore } // Connect connects the MCP server over the given transport and starts handling // messages. +// It returns a [ServerSession] for interacting with a [Client]. // -// It returns a connection object that may be used to terminate the connection -// (with [Connection.Close]), or await client termination (with -// [Connection.Wait]). +// [SessionOptions.SessionStore] should be nil only for single-session transports, +// like [StdioTransport]. Multi-session transports, like [StreamableServerTransport], +// must provide a [SessionStore]. func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) (*ServerSession, error) { if opts != nil && opts.SessionState == nil && opts.SessionStore != nil { return nil, errors.New("ServerSessionOptions has store but no state") @@ -579,11 +584,21 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) if err != nil { return nil, err } + var state *SessionState if opts != nil { - ss.opts = *opts + ss.sessionID = opts.SessionID + ss.store = opts.SessionStore + state = opts.SessionState } - if ss.opts.SessionState == nil { - ss.opts.SessionState = &SessionState{} + if ss.store == nil { + ss.store = NewMemorySessionStore() + } + if state == nil { + state = &SessionState{} + } + // TODO(jba): This store is redundant with the one in StreamableHTTPHandler.ServeHTTP; dedup. + if err := ss.storeState(ctx, state); err != nil { + return nil, err } return ss, nil } @@ -592,20 +607,21 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar if ss.server.opts.KeepAlive > 0 { ss.startKeepalive(ss.server.opts.KeepAlive) } - ss.mu.Lock() - hasParams := ss.opts.SessionState.InitializeParams != nil - wasInitialized := ss._initialized - if hasParams { - ss._initialized = true + // TODO(jba): optimistic locking + state, err := ss.loadState(ctx) + if err != nil { + return nil, err } - ss.mu.Unlock() - - if !hasParams { + if state.InitializeParams == nil { return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) } - if wasInitialized { + if state.Initialized { return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } + state.Initialized = true + if err := ss.storeState(ctx, state); err != nil { + return nil, err + } return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params) } @@ -634,19 +650,26 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot type ServerSession struct { server *Server conn *jsonrpc2.Connection - opts SessionOptions mu sync.Mutex logLevel LoggingLevel - _initialized bool keepaliveCancel context.CancelFunc sessionID string + store SessionStore } func (ss *ServerSession) setConn(c Connection) { } func (ss *ServerSession) ID() string { - return ss.opts.SessionID + return ss.sessionID +} + +func (ss *ServerSession) loadState(ctx context.Context) (*SessionState, error) { + return ss.store.Load(ctx, ss.sessionID) +} + +func (ss *ServerSession) storeState(ctx context.Context, state *SessionState) error { + return ss.store.Store(ctx, ss.sessionID, state) } // Ping pings the client. @@ -669,16 +692,19 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag // The message is not sent if the client has not called SetLevel, or if its level // is below that of the last SetLevel. func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { - ss.mu.Lock() - logLevel := ss.opts.SessionState.LogLevel - ss.mu.Unlock() - if logLevel == "" { + // TODO: Loading the state on every log message can be expensive. Consider caching it briefly, perhaps for the + // duration of a request. + state, err := ss.loadState(ctx) + if err != nil { + return err + } + if state.LogLevel == "" { // The spec is unclear, but seems to imply that no log messages are sent until the client // sets the level. // TODO(jba): read other SDKs, possibly file an issue. return nil } - if compareLevels(params.Level, logLevel) < 0 { + if compareLevels(params.Level, state.LogLevel) < 0 { return nil } return handleNotify(ctx, ss, notificationLoggingMessage, params) @@ -759,16 +785,17 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { - ss.mu.Lock() - initialized := ss._initialized - ss.mu.Unlock() + state, err := ss.loadState(ctx) + if err != nil { + return nil, err + } // From the spec: // "The client SHOULD NOT send requests other than pings before the server // has responded to the initialize request." switch req.Method { case methodInitialize, methodPing, notificationInitialized: default: - if !initialized { + if !state.Initialized { return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } @@ -776,6 +803,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, // server->client calls and notifications to the incoming request from which // they originated. See [idContextKey] for details. ctx = context.WithValue(ctx, idContextKey{}, req.ID) + // TODO(jba): pass the state down so that it doesn't get loaded again. return handleReceive(ctx, ss, req) } @@ -783,13 +811,18 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } - ss.mu.Lock() - ss.opts.SessionState.InitializeParams = params - ss.mu.Unlock() - if store := ss.opts.SessionStore; store != nil { - if err := store.Store(ctx, ss.opts.SessionID, ss.opts.SessionState); err != nil { - return nil, fmt.Errorf("storing session state: %w", err) - } + + // TODO(jba): optimistic locking + state, err := ss.loadState(ctx) + if err != nil { + return nil, err + } + if state.InitializeParams != nil { + return nil, fmt.Errorf("session %s already initialized", ss.sessionID) + } + state.InitializeParams = params + if err := ss.storeState(ctx, state); err != nil { + return nil, fmt.Errorf("storing session state: %w", err) } // If we support the client's version, reply with it. Otherwise, reply with our @@ -816,11 +849,14 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error func (ss *ServerSession) setLevel(ctx context.Context, params *SetLevelParams) (*emptyResult, error) { ss.mu.Lock() defer ss.mu.Unlock() - ss.opts.SessionState.LogLevel = params.Level - if store := ss.opts.SessionStore; store != nil { - if err := store.Store(ctx, ss.opts.SessionID, ss.opts.SessionState); err != nil { - return nil, err - } + + state, err := ss.loadState(ctx) + if err != nil { + return nil, err + } + state.LogLevel = params.Level + if err := ss.storeState(ctx, state); err != nil { + return nil, err } return &emptyResult{}, nil } diff --git a/mcp/session.go b/mcp/session.go index d4a944c1..255a3619 100644 --- a/mcp/session.go +++ b/mcp/session.go @@ -6,7 +6,8 @@ package mcp import ( "context" - "io/fs" + "errors" + "fmt" "sync" ) @@ -14,6 +15,9 @@ import ( type SessionState struct { // InitializeParams are the parameters from the initialize request. InitializeParams *InitializeParams `json:"initializeParams"` + // Initialized reports whether the session received an "initialized" notification + // from the client. + Initialized bool // LogLevel is the logging level for the session. LogLevel LoggingLevel `json:"logLevel"` @@ -21,10 +25,13 @@ type SessionState struct { // TODO: resource subscriptions } +// ErrNotSession indicates that a session is not in a SessionStore. +var ErrNoSession = errors.New("no such session") + // SessionStore is an interface for storing and retrieving session state. type SessionStore interface { // Load retrieves the session state for the given session ID. - // If there is none, it returns nil, fs.ErrNotExist. + // If there is none, it returns nil and an error wrapping ErrNoSession. Load(ctx context.Context, sessionID string) (*SessionState, error) // Store saves the session state for the given session ID. Store(ctx context.Context, sessionID string, state *SessionState) error @@ -52,7 +59,7 @@ func (s *MemorySessionStore) Load(ctx context.Context, sessionID string) (*Sessi defer s.mu.Unlock() state, ok := s.store[sessionID] if !ok { - return nil, fs.ErrNotExist + return nil, fmt.Errorf("session ID %q: %w", sessionID, ErrNoSession) } return state, nil } diff --git a/mcp/session_test.go b/mcp/session_test.go index a4d2e966..6c8f2fe1 100644 --- a/mcp/session_test.go +++ b/mcp/session_test.go @@ -6,7 +6,7 @@ package mcp import ( "context" - "io/fs" + "errors" "testing" ) @@ -39,8 +39,8 @@ func TestMemorySessionStore(t *testing.T) { } deletedState, err := store.Load(ctx, sessionID) - if err != fs.ErrNotExist { - t.Fatalf("Load() after Delete(): got %v, want fs.ErrNotExist", err) + if !errors.Is(err, ErrNoSession) { + t.Fatalf("Load() after Delete(): got %v, want ErrNoSession", err) } if deletedState != nil { t.Error("Load() after Delete() returned non-nil state") diff --git a/mcp/streamable.go b/mcp/streamable.go index 60205ece..e5e51ac1 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -163,8 +163,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } } else { + // New session: store an empty state. state = &SessionState{} - sessionID = randText() + sessionID = newSessionID() if err := h.opts.SessionStore.Store(req.Context(), sessionID, state); err != nil { http.Error(w, fmt.Sprintf("SessionStore.Store, new session: %v", err), http.StatusInternalServerError) return @@ -1233,3 +1234,6 @@ func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time return backoffDuration + jitter } + +// For testing. +var newSessionID = randText diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 24368b00..1f4a0a67 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -16,6 +16,7 @@ import ( "net/http/httptest" "net/http/httputil" "net/url" + "slices" "strings" "sync" "sync/atomic" @@ -235,9 +236,10 @@ func TestServerInitiatedSSE(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) { - notifications <- "toolListChanged" - }, + client := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) { + notifications <- "toolListChanged" + }, }) clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil)) if err != nil { @@ -821,3 +823,87 @@ func TestEventID(t *testing.T) { }) } } + +func TestDistributedSessionStore(t *testing.T) { + // To simulate a distributed server with a shared durable SessionStore, we use two distinct + // HTTP servers in memory with a shared MemorySessionStore, and two sessions with the same ID. + + defer func(f func() string) { + newSessionID = f + }(newSessionID) + newSessionID = func() string { return "test-session" } + + ctx := context.Background() + + // Start a server with a single tool. + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[int], error) { + ss.Log(ctx, &LoggingMessageParams{ + Level: "info", + Logger: "tool", + }) + return &CallToolResultFor[int]{StructuredContent: 3}, nil + }) + // indexes are: SetLevel, CallTool. + for bits := range 1 << 2 { + t.Run(fmt.Sprintf("%04b", bits), func(t *testing.T) { + indexes := bitsToSlice(bits, 4) + opts := &StreamableHTTPOptions{SessionStore: NewMemorySessionStore()} + var urls []string + for range 2 { + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts) + + defer handler.closeAll() + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + urls = append(urls, httpServer.URL) + } + + // The log handler will only be called in all cases if the SetLevel change is properly stored in the state. + logCalled := make(chan struct{}) + logHandler := func(_ context.Context, _ *ClientSession, params *LoggingMessageParams) { + close(logCalled) + } + + // Connect clients to each HTTP server. + var clientSessions []*ClientSession + for i := range 2 { + client := NewClient(testImpl, &ClientOptions{LoggingMessageHandler: logHandler}) + // TODO: split initialization handshake between servers. This will send both init messages to the same one. + cs, err := client.Connect(ctx, NewStreamableClientTransport(urls[i], nil)) + if err != nil { + t.Fatal(err) + } + clientSessions = append(clientSessions, cs) + } + + clientSessions[indexes[0]].SetLevel(ctx, &SetLevelParams{Level: "info"}) + + res, err := clientSessions[indexes[1]].CallTool(ctx, &CallToolParams{Name: "tool"}) + if err != nil { + t.Fatal(err) + } + // The logging notification might arrive after CallTool returns. + select { + case <-logCalled: + case <-time.After(time.Second): + t.Error("log not called") + } + if g, w := res.StructuredContent, 3.0; g != w { + t.Errorf("result: got %v %[1]T, want %v %[2]T", g, w) + } + }) + } +} + +// bitsToSlice splits the low-order n bits of bits into a slice of individual bit values. +// For example, 0101 => []int{0, 1, 0, 1}. +func bitsToSlice(bits, n int) []int { + var ints []int + for range n { + ints = append(ints, bits&1) + bits >>= 1 + } + slices.Reverse(ints) + return ints +} From 748de2b3c467e314230393a01203d41bca911b77 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 11 Aug 2025 08:11:15 -0400 Subject: [PATCH 5/6] fix data race --- mcp/server.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mcp/server.go b/mcp/server.go index 584c1020..d6f8103c 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -585,6 +585,7 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) return nil, err } var state *SessionState + ss.mu.Lock() if opts != nil { ss.sessionID = opts.SessionID ss.store = opts.SessionStore @@ -593,6 +594,7 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) if ss.store == nil { ss.store = NewMemorySessionStore() } + ss.mu.Unlock() if state == nil { state = &SessionState{} } @@ -665,10 +667,14 @@ func (ss *ServerSession) ID() string { } func (ss *ServerSession) loadState(ctx context.Context) (*SessionState, error) { + ss.mu.Lock() + defer ss.mu.Unlock() return ss.store.Load(ctx, ss.sessionID) } func (ss *ServerSession) storeState(ctx context.Context, state *SessionState) error { + ss.mu.Lock() + defer ss.mu.Unlock() return ss.store.Store(ctx, ss.sessionID, state) } From 0b72210248bcc4de1281024decbe4b4f4b225aae Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 11 Aug 2025 08:30:05 -0400 Subject: [PATCH 6/6] fix deadlock --- mcp/server.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index d6f8103c..68d8b3f9 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -853,9 +853,6 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error } func (ss *ServerSession) setLevel(ctx context.Context, params *SetLevelParams) (*emptyResult, error) { - ss.mu.Lock() - defer ss.mu.Unlock() - state, err := ss.loadState(ctx) if err != nil { return nil, err