diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 883d8a89..8e6ea1be 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { // Connect the server, and connect the client stream, // but don't connect an actual client. cTransport, sTransport := NewInMemoryTransports() - ss, err := s.Connect(ctx, sTransport) + ss, err := s.Connect(ctx, sTransport, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 597b9dcd..05b07d8a 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -114,7 +114,7 @@ func Example_loggingMiddleware() { ctx := context.Background() // Connect server and client - serverSession, _ := server.Connect(ctx, serverTransport) + serverSession, _ := server.Connect(ctx, serverTransport, nil) defer serverSession.Close() clientSession, _ := client.Connect(ctx, clientTransport) diff --git a/mcp/logging.go b/mcp/logging.go index 4880e179..e09e000e 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -112,14 +112,12 @@ func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingH return lh } -// Enabled implements [slog.Handler.Enabled] by comparing level to the [ServerSession]'s level. +// Enabled implements [slog.Handler.Enabled]. func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { - // This is also checked in ServerSession.LoggingMessage, so checking it here - // is just an optimization that skips building the JSON. - h.ss.mu.Lock() - mcpLevel := h.ss.logLevel - h.ss.mu.Unlock() - return level >= mcpLevelToSlog(mcpLevel) + // This is already checked in ServerSession.LoggingMessage. Checking it here + // would be an optimization that skips building the JSON, but it would + // end up loading the SessionState twice, so don't do it. + return true } // WithAttrs implements [slog.Handler.WithAttrs]. diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 48e95de2..abc5b398 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) { s.AddResource(resource2, readHandler) // Connect the server. - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -549,7 +549,7 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien if config != nil { config(s) } - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -598,7 +598,7 @@ func TestBatching(t *testing.T) { ct, st := NewInMemoryTransports() s := NewServer(testImpl, nil) - _, err := s.Connect(ctx, st) + _, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) { ct, st := NewInMemoryTransports() s := NewServer(testImpl, nil) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -777,7 +777,7 @@ func TestNoJSONNull(t *testing.T) { ct = NewLoggingTransport(ct, &logbuf) s := NewServer(testImpl, nil) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) { s := NewServer(testImpl, serverOpts) AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) { // Server without keepalive (to test one-sided keepalive) s := NewServer(testImpl, nil) AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/server.go b/mcp/server.go index e69a872e..68d8b3f9 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "encoding/gob" "encoding/json" + "errors" "fmt" "iter" "maps" @@ -514,7 +515,8 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns // If no tools have been added, the server will not have the tool capability. // The same goes for other features like prompts and resources. func (s *Server) Run(ctx context.Context, t Transport) error { - ss, err := s.Connect(ctx, t) + // TODO: provide a way to pass ServerSessionOptions? + ss, err := s.Connect(ctx, t, nil) if err != nil { return err } @@ -557,34 +559,71 @@ func (s *Server) disconnect(cc *ServerSession) { } } +type SessionOptions struct { + // SessionID is the session's unique ID. + SessionID string + // SessionState is the current state of the session. The default is the initial + // state. + SessionState *SessionState + // SessionStore stores SessionStates. By default it is a MemorySessionStore. + SessionStore SessionStore +} + // Connect connects the MCP server over the given transport and starts handling // messages. +// It returns a [ServerSession] for interacting with a [Client]. // -// It returns a connection object that may be used to terminate the connection -// (with [Connection.Close]), or await client termination (with -// [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, error) { - return connect(ctx, t, s) +// [SessionOptions.SessionStore] should be nil only for single-session transports, +// like [StdioTransport]. Multi-session transports, like [StreamableServerTransport], +// must provide a [SessionStore]. +func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) (*ServerSession, error) { + if opts != nil && opts.SessionState == nil && opts.SessionStore != nil { + return nil, errors.New("ServerSessionOptions has store but no state") + } + ss, err := connect(ctx, t, s) + if err != nil { + return nil, err + } + var state *SessionState + ss.mu.Lock() + if opts != nil { + ss.sessionID = opts.SessionID + ss.store = opts.SessionStore + state = opts.SessionState + } + if ss.store == nil { + ss.store = NewMemorySessionStore() + } + ss.mu.Unlock() + if state == nil { + state = &SessionState{} + } + // TODO(jba): This store is redundant with the one in StreamableHTTPHandler.ServeHTTP; dedup. + if err := ss.storeState(ctx, state); err != nil { + return nil, err + } + return ss, nil } func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { if ss.server.opts.KeepAlive > 0 { ss.startKeepalive(ss.server.opts.KeepAlive) } - ss.mu.Lock() - hasParams := ss.initializeParams != nil - wasInitialized := ss._initialized - if hasParams { - ss._initialized = true + // TODO(jba): optimistic locking + state, err := ss.loadState(ctx) + if err != nil { + return nil, err } - ss.mu.Unlock() - - if !hasParams { + if state.InitializeParams == nil { return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) } - if wasInitialized { + if state.Initialized { return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } + state.Initialized = true + if err := ss.storeState(ctx, state); err != nil { + return nil, err + } return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params) } @@ -611,25 +650,32 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot // Call [ServerSession.Close] to close the connection, or await client // termination with [ServerSession.Wait]. type ServerSession struct { - server *Server - conn *jsonrpc2.Connection - mcpConn Connection - mu sync.Mutex - logLevel LoggingLevel - initializeParams *InitializeParams - _initialized bool - keepaliveCancel context.CancelFunc + server *Server + conn *jsonrpc2.Connection + mu sync.Mutex + logLevel LoggingLevel + keepaliveCancel context.CancelFunc + sessionID string + store SessionStore } func (ss *ServerSession) setConn(c Connection) { - ss.mcpConn = c } func (ss *ServerSession) ID() string { - if ss.mcpConn == nil { - return "" - } - return ss.mcpConn.SessionID() + return ss.sessionID +} + +func (ss *ServerSession) loadState(ctx context.Context) (*SessionState, error) { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.store.Load(ctx, ss.sessionID) +} + +func (ss *ServerSession) storeState(ctx context.Context, state *SessionState) error { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.store.Store(ctx, ss.sessionID, state) } // Ping pings the client. @@ -652,16 +698,19 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag // The message is not sent if the client has not called SetLevel, or if its level // is below that of the last SetLevel. func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { - ss.mu.Lock() - logLevel := ss.logLevel - ss.mu.Unlock() - if logLevel == "" { + // TODO: Loading the state on every log message can be expensive. Consider caching it briefly, perhaps for the + // duration of a request. + state, err := ss.loadState(ctx) + if err != nil { + return err + } + if state.LogLevel == "" { // The spec is unclear, but seems to imply that no log messages are sent until the client // sets the level. // TODO(jba): read other SDKs, possibly file an issue. return nil } - if compareLevels(params.Level, logLevel) < 0 { + if compareLevels(params.Level, state.LogLevel) < 0 { return nil } return handleNotify(ctx, ss, notificationLoggingMessage, params) @@ -742,16 +791,17 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { - ss.mu.Lock() - initialized := ss._initialized - ss.mu.Unlock() + state, err := ss.loadState(ctx) + if err != nil { + return nil, err + } // From the spec: // "The client SHOULD NOT send requests other than pings before the server // has responded to the initialize request." switch req.Method { case methodInitialize, methodPing, notificationInitialized: default: - if !initialized { + if !state.Initialized { return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } @@ -759,6 +809,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, // server->client calls and notifications to the incoming request from which // they originated. See [idContextKey] for details. ctx = context.WithValue(ctx, idContextKey{}, req.ID) + // TODO(jba): pass the state down so that it doesn't get loaded again. return handleReceive(ctx, ss, req) } @@ -766,9 +817,19 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } - ss.mu.Lock() - ss.initializeParams = params - ss.mu.Unlock() + + // TODO(jba): optimistic locking + state, err := ss.loadState(ctx) + if err != nil { + return nil, err + } + if state.InitializeParams != nil { + return nil, fmt.Errorf("session %s already initialized", ss.sessionID) + } + state.InitializeParams = params + if err := ss.storeState(ctx, state); err != nil { + return nil, fmt.Errorf("storing session state: %w", err) + } // If we support the client's version, reply with it. Otherwise, reply with our // latest version. @@ -791,10 +852,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error return &emptyResult{}, nil } -func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) { - ss.mu.Lock() - defer ss.mu.Unlock() - ss.logLevel = params.Level +func (ss *ServerSession) setLevel(ctx context.Context, params *SetLevelParams) (*emptyResult, error) { + state, err := ss.loadState(ctx) + if err != nil { + return nil, err + } + state.LogLevel = params.Level + if err := ss.storeState(ctx, state); err != nil { + return nil, err + } return &emptyResult{}, nil } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 3ab7a2a4..33bde4b9 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -31,7 +31,7 @@ func ExampleServer() { server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) - serverSession, err := server.Connect(ctx, serverTransport) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } @@ -62,7 +62,7 @@ func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession server := mcp.NewServer(testImpl, nil) client := mcp.NewClient(testImpl, nil) serverTransport, clientTransport := mcp.NewInMemoryTransports() - serverSession, err := server.Connect(ctx, serverTransport) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/session.go b/mcp/session.go new file mode 100644 index 00000000..255a3619 --- /dev/null +++ b/mcp/session.go @@ -0,0 +1,81 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "errors" + "fmt" + "sync" +) + +// SessionState is the state of a session. +type SessionState struct { + // InitializeParams are the parameters from the initialize request. + InitializeParams *InitializeParams `json:"initializeParams"` + // Initialized reports whether the session received an "initialized" notification + // from the client. + Initialized bool + + // LogLevel is the logging level for the session. + LogLevel LoggingLevel `json:"logLevel"` + + // TODO: resource subscriptions +} + +// ErrNotSession indicates that a session is not in a SessionStore. +var ErrNoSession = errors.New("no such session") + +// SessionStore is an interface for storing and retrieving session state. +type SessionStore interface { + // Load retrieves the session state for the given session ID. + // If there is none, it returns nil and an error wrapping ErrNoSession. + Load(ctx context.Context, sessionID string) (*SessionState, error) + // Store saves the session state for the given session ID. + Store(ctx context.Context, sessionID string, state *SessionState) error + // Delete removes the session state for the given session ID. + Delete(ctx context.Context, sessionID string) error +} + +// MemorySessionStore is an in-memory implementation of SessionStore. +// It is safe for concurrent use. +type MemorySessionStore struct { + mu sync.Mutex + store map[string]*SessionState +} + +// NewMemorySessionStore creates a new MemorySessionStore. +func NewMemorySessionStore() *MemorySessionStore { + return &MemorySessionStore{ + store: make(map[string]*SessionState), + } +} + +// Load retrieves the session state for the given session ID. +func (s *MemorySessionStore) Load(ctx context.Context, sessionID string) (*SessionState, error) { + s.mu.Lock() + defer s.mu.Unlock() + state, ok := s.store[sessionID] + if !ok { + return nil, fmt.Errorf("session ID %q: %w", sessionID, ErrNoSession) + } + return state, nil +} + +// Store saves the session state for the given session ID. +func (s *MemorySessionStore) Store(ctx context.Context, sessionID string, state *SessionState) error { + s.mu.Lock() + defer s.mu.Unlock() + s.store[sessionID] = state + return nil +} + +// Delete removes the session state for the given session ID. +func (s *MemorySessionStore) Delete(ctx context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.store, sessionID) + return nil +} diff --git a/mcp/session_test.go b/mcp/session_test.go new file mode 100644 index 00000000..6c8f2fe1 --- /dev/null +++ b/mcp/session_test.go @@ -0,0 +1,48 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "errors" + "testing" +) + +func TestMemorySessionStore(t *testing.T) { + ctx := context.Background() + store := NewMemorySessionStore() + + sessionID := "test-session" + state := &SessionState{LogLevel: "debug"} + + // Test Store and Load + if err := store.Store(ctx, sessionID, state); err != nil { + t.Fatalf("Store() error = %v", err) + } + + loadedState, err := store.Load(ctx, sessionID) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loadedState == nil { + t.Fatal("Load() returned nil state") + } + if loadedState.LogLevel != state.LogLevel { + t.Errorf("Load() LogLevel = %v, want %v", loadedState.LogLevel, state.LogLevel) + } + + // Test Delete + if err := store.Delete(ctx, sessionID); err != nil { + t.Fatalf("Delete() error = %v", err) + } + + deletedState, err := store.Load(ctx, sessionID) + if !errors.Is(err, ErrNoSession) { + t.Fatalf("Load() after Delete(): got %v, want ErrNoSession", err) + } + if deletedState != nil { + t.Error("Load() after Delete() returned non-nil state") + } +} diff --git a/mcp/sse.go b/mcp/sse.go index bdc4770b..f74a3fb6 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -221,7 +221,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.Error(w, "no server available", http.StatusBadRequest) return } - ss, err := server.Connect(req.Context(), transport) + ss, err := server.Connect(req.Context(), transport, nil) if err != nil { http.Error(w, "connection failed", http.StatusInternalServerError) return diff --git a/mcp/streamable.go b/mcp/streamable.go index 108de5d2..e5e51ac1 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "io/fs" "iter" "math" "math/rand/v2" @@ -35,11 +36,11 @@ const ( // // [MCP spec]: https://modelcontextprotocol.io/2025/03/26/streamable-http-transport.html type StreamableHTTPHandler struct { - getServer func(*http.Request) *Server opts StreamableHTTPOptions + getServer func(*http.Request) *Server - sessionsMu sync.Mutex - sessions map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) + transportMu sync.Mutex + transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } // StreamableHTTPOptions is a placeholder options struct for future @@ -51,6 +52,10 @@ type StreamableHTTPOptions struct { // transportOptions sets the streamable server transport options to use when // establishing a new session. transportOptions *StreamableServerTransportOptions + + // SessionStore is the store for persistent sessions. + // If nil, sessions will be stored in memory. + SessionStore SessionStore } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -60,12 +65,15 @@ type StreamableHTTPOptions struct { // If getServer returns nil, a 400 Bad Request will be served. func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { h := &StreamableHTTPHandler{ - getServer: getServer, - sessions: make(map[string]*StreamableServerTransport), + getServer: getServer, + transports: make(map[string]*StreamableServerTransport), } if opts != nil { h.opts = *opts } + if h.opts.SessionStore == nil { + h.opts.SessionStore = NewMemorySessionStore() + } return h } @@ -77,12 +85,12 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea // Should we allow passing in a session store? That would allow the handler to // be stateless. func (h *StreamableHTTPHandler) closeAll() { - h.sessionsMu.Lock() - defer h.sessionsMu.Unlock() - for _, s := range h.sessions { - s.Close() + h.transportMu.Lock() + defer h.transportMu.Unlock() + for _, t := range h.transports { + t.Close() } - h.sessions = nil + h.transports = nil } func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -109,29 +117,26 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - var session *StreamableServerTransport - if id := req.Header.Get(sessionIDHeader); id != "" { - h.sessionsMu.Lock() - session, _ = h.sessions[id] - h.sessionsMu.Unlock() - if session == nil { - http.Error(w, "session not found", http.StatusNotFound) - return - } + var transport *StreamableServerTransport + sessionID := req.Header.Get(sessionIDHeader) + if sessionID != "" { + h.transportMu.Lock() + transport, _ = h.transports[sessionID] + h.transportMu.Unlock() } // TODO(rfindley): simplify the locking so that each request has only one // critical section. if req.Method == http.MethodDelete { - if session == nil { + if transport == nil { // => Mcp-Session-Id was not set; else we'd have returned NotFound above. http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } - h.sessionsMu.Lock() - delete(h.sessions, session.sessionID) - h.sessionsMu.Unlock() - session.Close() + h.transportMu.Lock() + delete(h.transports, transport.sessionID) + h.transportMu.Unlock() + transport.Close() w.WriteHeader(http.StatusNoContent) return } @@ -144,28 +149,56 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - if session == nil { - s := NewStreamableServerTransport(randText(), h.opts.transportOptions) + if transport == nil { + var state *SessionState + var err error + if sessionID != "" { + // The session might be in the store. + state, err = h.opts.SessionStore.Load(req.Context(), sessionID) + if errors.Is(err, fs.ErrNotExist) { + http.Error(w, fmt.Sprintf("no session with ID %s", sessionID), http.StatusNotFound) + return + } else if err != nil { + http.Error(w, fmt.Sprintf("SessionStore.Load(%q): %v", sessionID, err), http.StatusInternalServerError) + return + } + } else { + // New session: store an empty state. + state = &SessionState{} + sessionID = newSessionID() + if err := h.opts.SessionStore.Store(req.Context(), sessionID, state); err != nil { + http.Error(w, fmt.Sprintf("SessionStore.Store, new session: %v", err), http.StatusInternalServerError) + return + } + } + transport = NewStreamableServerTransport(sessionID, nil) server := h.getServer(req) if server == nil { - // The getServer argument to NewStreamableHTTPHandler returned nil. http.Error(w, "no server available", http.StatusBadRequest) return } // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - if _, err := server.Connect(req.Context(), s); err != nil { - http.Error(w, "failed connection", http.StatusInternalServerError) + // TODO: rename SessionOptions to ConnectOptions? + _, err = server.Connect(req.Context(), transport, &SessionOptions{ + SessionID: sessionID, + SessionState: state, + SessionStore: h.opts.SessionStore, + }) + if err != nil { + http.Error(w, fmt.Sprintf("failed connection: %v", err), http.StatusInternalServerError) return } - h.sessionsMu.Lock() - h.sessions[s.sessionID] = s - h.sessionsMu.Unlock() - session = s + h.transportMu.Lock() + // Check in case another request with the same stored session ID got here first. + if _, ok := h.transports[transport.sessionID]; !ok { + h.transports[transport.sessionID] = transport + } + h.transportMu.Unlock() } - session.ServeHTTP(w, req) + transport.ServeHTTP(w, req) } type StreamableServerTransportOptions struct { @@ -1201,3 +1234,6 @@ func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time return backoffDuration + jitter } + +// For testing. +var newSessionID = randText diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 24368b00..1f4a0a67 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -16,6 +16,7 @@ import ( "net/http/httptest" "net/http/httputil" "net/url" + "slices" "strings" "sync" "sync/atomic" @@ -235,9 +236,10 @@ func TestServerInitiatedSSE(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) { - notifications <- "toolListChanged" - }, + client := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) { + notifications <- "toolListChanged" + }, }) clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil)) if err != nil { @@ -821,3 +823,87 @@ func TestEventID(t *testing.T) { }) } } + +func TestDistributedSessionStore(t *testing.T) { + // To simulate a distributed server with a shared durable SessionStore, we use two distinct + // HTTP servers in memory with a shared MemorySessionStore, and two sessions with the same ID. + + defer func(f func() string) { + newSessionID = f + }(newSessionID) + newSessionID = func() string { return "test-session" } + + ctx := context.Background() + + // Start a server with a single tool. + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[int], error) { + ss.Log(ctx, &LoggingMessageParams{ + Level: "info", + Logger: "tool", + }) + return &CallToolResultFor[int]{StructuredContent: 3}, nil + }) + // indexes are: SetLevel, CallTool. + for bits := range 1 << 2 { + t.Run(fmt.Sprintf("%04b", bits), func(t *testing.T) { + indexes := bitsToSlice(bits, 4) + opts := &StreamableHTTPOptions{SessionStore: NewMemorySessionStore()} + var urls []string + for range 2 { + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts) + + defer handler.closeAll() + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + urls = append(urls, httpServer.URL) + } + + // The log handler will only be called in all cases if the SetLevel change is properly stored in the state. + logCalled := make(chan struct{}) + logHandler := func(_ context.Context, _ *ClientSession, params *LoggingMessageParams) { + close(logCalled) + } + + // Connect clients to each HTTP server. + var clientSessions []*ClientSession + for i := range 2 { + client := NewClient(testImpl, &ClientOptions{LoggingMessageHandler: logHandler}) + // TODO: split initialization handshake between servers. This will send both init messages to the same one. + cs, err := client.Connect(ctx, NewStreamableClientTransport(urls[i], nil)) + if err != nil { + t.Fatal(err) + } + clientSessions = append(clientSessions, cs) + } + + clientSessions[indexes[0]].SetLevel(ctx, &SetLevelParams{Level: "info"}) + + res, err := clientSessions[indexes[1]].CallTool(ctx, &CallToolParams{Name: "tool"}) + if err != nil { + t.Fatal(err) + } + // The logging notification might arrive after CallTool returns. + select { + case <-logCalled: + case <-time.After(time.Second): + t.Error("log not called") + } + if g, w := res.StructuredContent, 3.0; g != w { + t.Errorf("result: got %v %[1]T, want %v %[2]T", g, w) + } + }) + } +} + +// bitsToSlice splits the low-order n bits of bits into a slice of individual bit values. +// For example, 0101 => []int{0, 1, 0, 1}. +func bitsToSlice(bits, n int) []int { + var ints []int + for range n { + ints = append(ints, bits&1) + bits >>= 1 + } + slices.Reverse(ints) + return ints +}