Skip to content

server: support distributed sessions #232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) {
// Connect the server, and connect the client stream,
// but don't connect an actual client.
cTransport, sTransport := NewInMemoryTransports()
ss, err := s.Connect(ctx, sTransport)
ss, err := s.Connect(ctx, sTransport, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion mcp/example_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func Example_loggingMiddleware() {
ctx := context.Background()

// Connect server and client
serverSession, _ := server.Connect(ctx, serverTransport)
serverSession, _ := server.Connect(ctx, serverTransport, nil)
defer serverSession.Close()

clientSession, _ := client.Connect(ctx, clientTransport)
Expand Down
12 changes: 5 additions & 7 deletions mcp/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,12 @@ func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingH
return lh
}

// Enabled implements [slog.Handler.Enabled] by comparing level to the [ServerSession]'s level.
// Enabled implements [slog.Handler.Enabled].
func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool {
// This is also checked in ServerSession.LoggingMessage, so checking it here
// is just an optimization that skips building the JSON.
h.ss.mu.Lock()
mcpLevel := h.ss.logLevel
h.ss.mu.Unlock()
return level >= mcpLevelToSlog(mcpLevel)
// This is already checked in ServerSession.LoggingMessage. Checking it here
// would be an optimization that skips building the JSON, but it would
// end up loading the SessionState twice, so don't do it.
return true
}

// WithAttrs implements [slog.Handler.WithAttrs].
Expand Down
14 changes: 7 additions & 7 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) {
s.AddResource(resource2, readHandler)

// Connect the server.
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -549,7 +549,7 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien
if config != nil {
config(s)
}
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -598,7 +598,7 @@ func TestBatching(t *testing.T) {
ct, st := NewInMemoryTransports()

s := NewServer(testImpl, nil)
_, err := s.Connect(ctx, st)
_, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) {
ct, st := NewInMemoryTransports()

s := NewServer(testImpl, nil)
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -777,7 +777,7 @@ func TestNoJSONNull(t *testing.T) {
ct = NewLoggingTransport(ct, &logbuf)

s := NewServer(testImpl, nil)
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) {
s := NewServer(testImpl, serverOpts)
AddTool(s, greetTool(), sayHi)

ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) {
// Server without keepalive (to test one-sided keepalive)
s := NewServer(testImpl, nil)
AddTool(s, greetTool(), sayHi)
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
154 changes: 110 additions & 44 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/base64"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"iter"
"maps"
Expand Down Expand Up @@ -514,7 +515,8 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns
// If no tools have been added, the server will not have the tool capability.
// The same goes for other features like prompts and resources.
func (s *Server) Run(ctx context.Context, t Transport) error {
ss, err := s.Connect(ctx, t)
// TODO: provide a way to pass ServerSessionOptions?
ss, err := s.Connect(ctx, t, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -557,34 +559,71 @@ func (s *Server) disconnect(cc *ServerSession) {
}
}

type SessionOptions struct {
// SessionID is the session's unique ID.
SessionID string
// SessionState is the current state of the session. The default is the initial
// state.
SessionState *SessionState
// SessionStore stores SessionStates. By default it is a MemorySessionStore.
SessionStore SessionStore
}

// Connect connects the MCP server over the given transport and starts handling
// messages.
// It returns a [ServerSession] for interacting with a [Client].
//
// It returns a connection object that may be used to terminate the connection
// (with [Connection.Close]), or await client termination (with
// [Connection.Wait]).
func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, error) {
return connect(ctx, t, s)
// [SessionOptions.SessionStore] should be nil only for single-session transports,
// like [StdioTransport]. Multi-session transports, like [StreamableServerTransport],
// must provide a [SessionStore].
func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) (*ServerSession, error) {
if opts != nil && opts.SessionState == nil && opts.SessionStore != nil {
return nil, errors.New("ServerSessionOptions has store but no state")
}
ss, err := connect(ctx, t, s)
if err != nil {
return nil, err
}
var state *SessionState
ss.mu.Lock()
if opts != nil {
ss.sessionID = opts.SessionID
ss.store = opts.SessionStore
state = opts.SessionState
}
if ss.store == nil {
ss.store = NewMemorySessionStore()
}
ss.mu.Unlock()
if state == nil {
state = &SessionState{}
}
// TODO(jba): This store is redundant with the one in StreamableHTTPHandler.ServeHTTP; dedup.
if err := ss.storeState(ctx, state); err != nil {
return nil, err
}
return ss, nil
}

func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) {
if ss.server.opts.KeepAlive > 0 {
ss.startKeepalive(ss.server.opts.KeepAlive)
}
ss.mu.Lock()
hasParams := ss.initializeParams != nil
wasInitialized := ss._initialized
if hasParams {
ss._initialized = true
// TODO(jba): optimistic locking
state, err := ss.loadState(ctx)
if err != nil {
return nil, err
}
ss.mu.Unlock()

if !hasParams {
if state.InitializeParams == nil {
return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize)
}
if wasInitialized {
if state.Initialized {
return nil, fmt.Errorf("duplicate %q received", notificationInitialized)
}
state.Initialized = true
if err := ss.storeState(ctx, state); err != nil {
return nil, err
}
return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params)
}

Expand All @@ -611,25 +650,32 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
// Call [ServerSession.Close] to close the connection, or await client
// termination with [ServerSession.Wait].
type ServerSession struct {
server *Server
conn *jsonrpc2.Connection
mcpConn Connection
mu sync.Mutex
logLevel LoggingLevel
initializeParams *InitializeParams
_initialized bool
keepaliveCancel context.CancelFunc
server *Server
conn *jsonrpc2.Connection
mu sync.Mutex
logLevel LoggingLevel
keepaliveCancel context.CancelFunc
sessionID string
store SessionStore
}

func (ss *ServerSession) setConn(c Connection) {
ss.mcpConn = c
}

func (ss *ServerSession) ID() string {
if ss.mcpConn == nil {
return ""
}
return ss.mcpConn.SessionID()
return ss.sessionID
}

func (ss *ServerSession) loadState(ctx context.Context) (*SessionState, error) {
ss.mu.Lock()
defer ss.mu.Unlock()
return ss.store.Load(ctx, ss.sessionID)
}

func (ss *ServerSession) storeState(ctx context.Context, state *SessionState) error {
ss.mu.Lock()
defer ss.mu.Unlock()
return ss.store.Store(ctx, ss.sessionID, state)
}

// Ping pings the client.
Expand All @@ -652,16 +698,19 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag
// The message is not sent if the client has not called SetLevel, or if its level
// is below that of the last SetLevel.
func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error {
ss.mu.Lock()
logLevel := ss.logLevel
ss.mu.Unlock()
if logLevel == "" {
// TODO: Loading the state on every log message can be expensive. Consider caching it briefly, perhaps for the
// duration of a request.
state, err := ss.loadState(ctx)
if err != nil {
return err
}
if state.LogLevel == "" {
// The spec is unclear, but seems to imply that no log messages are sent until the client
// sets the level.
// TODO(jba): read other SDKs, possibly file an issue.
return nil
}
if compareLevels(params.Level, logLevel) < 0 {
if compareLevels(params.Level, state.LogLevel) < 0 {
return nil
}
return handleNotify(ctx, ss, notificationLoggingMessage, params)
Expand Down Expand Up @@ -742,33 +791,45 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }

// handle invokes the method described by the given JSON RPC request.
func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
ss.mu.Lock()
initialized := ss._initialized
ss.mu.Unlock()
state, err := ss.loadState(ctx)
if err != nil {
return nil, err
}
// From the spec:
// "The client SHOULD NOT send requests other than pings before the server
// has responded to the initialize request."
switch req.Method {
case methodInitialize, methodPing, notificationInitialized:
default:
if !initialized {
if !state.Initialized {
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
}
}
// For the streamable transport, we need the request ID to correlate
// server->client calls and notifications to the incoming request from which
// they originated. See [idContextKey] for details.
ctx = context.WithValue(ctx, idContextKey{}, req.ID)
// TODO(jba): pass the state down so that it doesn't get loaded again.
return handleReceive(ctx, ss, req)
}

func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) {
if params == nil {
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
}
ss.mu.Lock()
ss.initializeParams = params
ss.mu.Unlock()

// TODO(jba): optimistic locking
state, err := ss.loadState(ctx)
if err != nil {
return nil, err
}
if state.InitializeParams != nil {
return nil, fmt.Errorf("session %s already initialized", ss.sessionID)
}
state.InitializeParams = params
if err := ss.storeState(ctx, state); err != nil {
return nil, fmt.Errorf("storing session state: %w", err)
}

// If we support the client's version, reply with it. Otherwise, reply with our
// latest version.
Expand All @@ -791,10 +852,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error
return &emptyResult{}, nil
}

func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) {
ss.mu.Lock()
defer ss.mu.Unlock()
ss.logLevel = params.Level
func (ss *ServerSession) setLevel(ctx context.Context, params *SetLevelParams) (*emptyResult, error) {
state, err := ss.loadState(ctx)
if err != nil {
return nil, err
}
state.LogLevel = params.Level
if err := ss.storeState(ctx, state); err != nil {
return nil, err
}
return &emptyResult{}, nil
}

Expand Down
4 changes: 2 additions & 2 deletions mcp/server_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func ExampleServer() {
server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil)
mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi)

serverSession, err := server.Connect(ctx, serverTransport)
serverSession, err := server.Connect(ctx, serverTransport, nil)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -62,7 +62,7 @@ func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession
server := mcp.NewServer(testImpl, nil)
client := mcp.NewClient(testImpl, nil)
serverTransport, clientTransport := mcp.NewInMemoryTransports()
serverSession, err := server.Connect(ctx, serverTransport)
serverSession, err := server.Connect(ctx, serverTransport, nil)
if err != nil {
log.Fatal(err)
}
Expand Down
Loading