Skip to content

Commit c6facb5

Browse files
committed
add Connect arg (#232)
1 parent 427f775 commit c6facb5

File tree

7 files changed

+55
-43
lines changed

7 files changed

+55
-43
lines changed

mcp/example_middleware_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func Example_loggingMiddleware() {
114114
ctx := context.Background()
115115

116116
// Connect server and client
117-
serverSession, _ := server.Connect(ctx, serverTransport)
117+
serverSession, _ := server.Connect(ctx, serverTransport, nil)
118118
defer serverSession.Close()
119119

120120
clientSession, _ := client.Connect(ctx, clientTransport)

mcp/logging.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool {
117117
// This is also checked in ServerSession.LoggingMessage, so checking it here
118118
// is just an optimization that skips building the JSON.
119119
h.ss.mu.Lock()
120-
mcpLevel := h.ss.state.LogLevel
120+
mcpLevel := h.ss.opts.SessionState.LogLevel
121121
h.ss.mu.Unlock()
122122
return level >= mcpLevelToSlog(mcpLevel)
123123
}

mcp/mcp_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) {
104104
s.AddResource(resource2, readHandler)
105105

106106
// Connect the server.
107-
ss, err := s.Connect(ctx, st)
107+
ss, err := s.Connect(ctx, st, nil)
108108
if err != nil {
109109
t.Fatal(err)
110110
}
@@ -549,7 +549,7 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien
549549
if config != nil {
550550
config(s)
551551
}
552-
ss, err := s.Connect(ctx, st)
552+
ss, err := s.Connect(ctx, st, nil)
553553
if err != nil {
554554
t.Fatal(err)
555555
}
@@ -598,7 +598,7 @@ func TestBatching(t *testing.T) {
598598
ct, st := NewInMemoryTransports()
599599

600600
s := NewServer(testImpl, nil)
601-
_, err := s.Connect(ctx, st)
601+
_, err := s.Connect(ctx, st, nil)
602602
if err != nil {
603603
t.Fatal(err)
604604
}
@@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) {
668668
ct, st := NewInMemoryTransports()
669669

670670
s := NewServer(testImpl, nil)
671-
ss, err := s.Connect(ctx, st)
671+
ss, err := s.Connect(ctx, st, nil)
672672
if err != nil {
673673
t.Fatal(err)
674674
}
@@ -777,7 +777,7 @@ func TestNoJSONNull(t *testing.T) {
777777
ct = NewLoggingTransport(ct, &logbuf)
778778

779779
s := NewServer(testImpl, nil)
780-
ss, err := s.Connect(ctx, st)
780+
ss, err := s.Connect(ctx, st, nil)
781781
if err != nil {
782782
t.Fatal(err)
783783
}
@@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) {
845845
s := NewServer(testImpl, serverOpts)
846846
AddTool(s, greetTool(), sayHi)
847847

848-
ss, err := s.Connect(ctx, st)
848+
ss, err := s.Connect(ctx, st, nil)
849849
if err != nil {
850850
t.Fatal(err)
851851
}
@@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) {
889889
// Server without keepalive (to test one-sided keepalive)
890890
s := NewServer(testImpl, nil)
891891
AddTool(s, greetTool(), sayHi)
892-
ss, err := s.Connect(ctx, st)
892+
ss, err := s.Connect(ctx, st, nil)
893893
if err != nil {
894894
t.Fatal(err)
895895
}

mcp/server.go

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"encoding/base64"
1111
"encoding/gob"
1212
"encoding/json"
13+
"errors"
1314
"fmt"
1415
"iter"
1516
"maps"
@@ -513,7 +514,8 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns
513514
// If no tools have been added, the server will not have the tool capability.
514515
// The same goes for other features like prompts and resources.
515516
func (s *Server) Run(ctx context.Context, t Transport) error {
516-
ss, err := s.Connect(ctx, t)
517+
// TODO: provide a way to pass ServerSessionOptions?
518+
ss, err := s.Connect(ctx, t, nil)
517519
if err != nil {
518520
return err
519521
}
@@ -535,7 +537,7 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
535537
// bind implements the binder[*ServerSession] interface, so that Servers can
536538
// be connected using [connect].
537539
func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession {
538-
ss := &ServerSession{conn: conn, server: s, state: &SessionState{}}
540+
ss := &ServerSession{conn: conn, server: s}
539541
s.mu.Lock()
540542
s.sessions = append(s.sessions, ss)
541543
s.mu.Unlock()
@@ -556,14 +558,33 @@ func (s *Server) disconnect(cc *ServerSession) {
556558
}
557559
}
558560

561+
type ServerSessionOptions struct {
562+
SessionID string
563+
SessionState *SessionState
564+
SessionStore SessionStore
565+
}
566+
559567
// Connect connects the MCP server over the given transport and starts handling
560568
// messages.
561569
//
562570
// It returns a connection object that may be used to terminate the connection
563571
// (with [Connection.Close]), or await client termination (with
564572
// [Connection.Wait]).
565-
func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, error) {
566-
return connect(ctx, t, s)
573+
func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) {
574+
if opts != nil && opts.SessionState == nil && opts.SessionStore != nil {
575+
return nil, errors.New("ServerSessionOptions has store but no state")
576+
}
577+
ss, err := connect(ctx, t, s)
578+
if err != nil {
579+
return nil, err
580+
}
581+
if opts != nil {
582+
ss.opts = *opts
583+
}
584+
if ss.opts.SessionState == nil {
585+
ss.opts.SessionState = &SessionState{}
586+
}
587+
return ss, nil
567588
}
568589

569590
func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) {
@@ -596,32 +617,20 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
596617
// Call [ServerSession.Close] to close the connection, or await client
597618
// termination with [ServerSession.Wait].
598619
type ServerSession struct {
599-
server *Server
600-
conn *jsonrpc2.Connection
620+
server *Server
621+
conn *jsonrpc2.Connection
622+
opts ServerSessionOptions
623+
601624
mu sync.Mutex
602625
initialized bool
603626
keepaliveCancel context.CancelFunc
604-
605-
sessionID string
606-
state *SessionState
607-
store SessionStore
608627
}
609628

610629
func (ss *ServerSession) setConn(c Connection) {
611630
}
612631

613-
func (ss *ServerSession) ID() string { return ss.sessionID }
614-
615-
// InitSession initializes the session with a session ID, state, and store.
616-
// If called, it must be called immediately after the session is connected.
617-
// If never called, the session will begin with a zero SessionState and no session
618-
// ID or store.
619-
func (ss *ServerSession) InitSession(sessionID string, state *SessionState, store SessionStore) {
620-
ss.mu.Lock()
621-
defer ss.mu.Unlock()
622-
ss.sessionID = sessionID
623-
ss.state = state
624-
ss.store = store
632+
func (ss *ServerSession) ID() string {
633+
return ss.opts.SessionID
625634
}
626635

627636
// Ping pings the client.
@@ -645,7 +654,7 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag
645654
// is below that of the last SetLevel.
646655
func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error {
647656
ss.mu.Lock()
648-
logLevel := ss.state.LogLevel
657+
logLevel := ss.opts.SessionState.LogLevel
649658
ss.mu.Unlock()
650659
if logLevel == "" {
651660
// The spec is unclear, but seems to imply that no log messages are sent until the client
@@ -759,10 +768,10 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam
759768
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
760769
}
761770
ss.mu.Lock()
762-
ss.state.InitializeParams = params
771+
ss.opts.SessionState.InitializeParams = params
763772
ss.mu.Unlock()
764-
if ss.store != nil {
765-
if err := ss.store.Store(ctx, ss.sessionID, ss.state); err != nil {
773+
if store := ss.opts.SessionStore; store != nil {
774+
if err := store.Store(ctx, ss.opts.SessionID, ss.opts.SessionState); err != nil {
766775
return nil, fmt.Errorf("storing session state: %w", err)
767776
}
768777
}
@@ -802,9 +811,9 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error
802811
func (ss *ServerSession) setLevel(ctx context.Context, params *SetLevelParams) (*emptyResult, error) {
803812
ss.mu.Lock()
804813
defer ss.mu.Unlock()
805-
ss.state.LogLevel = params.Level
806-
if ss.store != nil {
807-
if err := ss.store.Store(ctx, ss.sessionID, ss.state); err != nil {
814+
ss.opts.SessionState.LogLevel = params.Level
815+
if store := ss.opts.SessionStore; store != nil {
816+
if err := store.Store(ctx, ss.opts.SessionID, ss.opts.SessionState); err != nil {
808817
return nil, err
809818
}
810819
}

mcp/server_example_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func ExampleServer() {
3131
server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil)
3232
mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi)
3333

34-
serverSession, err := server.Connect(ctx, serverTransport)
34+
serverSession, err := server.Connect(ctx, serverTransport, nil)
3535
if err != nil {
3636
log.Fatal(err)
3737
}
@@ -62,7 +62,7 @@ func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession
6262
server := mcp.NewServer(testImpl, nil)
6363
client := mcp.NewClient(testImpl, nil)
6464
serverTransport, clientTransport := mcp.NewInMemoryTransports()
65-
serverSession, err := server.Connect(ctx, serverTransport)
65+
serverSession, err := server.Connect(ctx, serverTransport, nil)
6666
if err != nil {
6767
log.Fatal(err)
6868
}

mcp/sse.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
221221
http.Error(w, "no server available", http.StatusBadRequest)
222222
return
223223
}
224-
ss, err := server.Connect(req.Context(), transport)
224+
ss, err := server.Connect(req.Context(), transport, nil)
225225
if err != nil {
226226
http.Error(w, "connection failed", http.StatusInternalServerError)
227227
return

mcp/streamable.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,15 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
171171
// Pass req.Context() here, to allow middleware to add context values.
172172
// The context is detached in the jsonrpc2 library when handling the
173173
// long-running stream.
174-
session, err := server.Connect(req.Context(), transport)
174+
_, err = server.Connect(req.Context(), transport, &ServerSessionOptions{
175+
SessionID: sessionID,
176+
SessionState: state,
177+
SessionStore: h.opts.SessionStore,
178+
})
175179
if err != nil {
176180
http.Error(w, fmt.Sprintf("failed connection: %v", err), http.StatusInternalServerError)
177181
return
178182
}
179-
session.InitSession(sessionID, state, h.opts.SessionStore)
180183
h.transportMu.Lock()
181184
h.transports[transport.sessionID] = transport
182185
h.transportMu.Unlock()

0 commit comments

Comments
 (0)