Skip to content

Commit 59c49bf

Browse files
committed
mcp: support stateless streamable sessions
Support stateless streamable sessions by adding a GetSessionID function to StreamableHTTPOptions. If GetSessionID returns "", the session is stateless, and no validation is performed. This is implemented by providing the session a trivial initialization state. To implement this, some parts of modelcontextprotocol#232 (distributed sessions) are copied over, since they add an API for creating an already-initialized session. In total, the following new API is added: - StreamableHTTPOptions.GetSessionID - ServerSessionOptions (a new parameter to Server.Connect) - ServerSessionState - ClientSessionOptions (a new parameter to Client.Connect, for symmetry) For modelcontextprotocol#10
1 parent be1ddf5 commit 59c49bf

18 files changed

+281
-113
lines changed

examples/client/listfeatures/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func main() {
4141
ctx := context.Background()
4242
cmd := exec.Command(args[0], args[1:]...)
4343
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil)
44-
cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
44+
cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
4545
if err != nil {
4646
log.Fatal(err)
4747
}

internal/readme/client/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func main() {
2121

2222
// Connect to a server over stdin/stdout
2323
transport := mcp.NewCommandTransport(exec.Command("myserver"))
24-
session, err := client.Connect(ctx, transport)
24+
session, err := client.Connect(ctx, transport, nil)
2525
if err != nil {
2626
log.Fatal(err)
2727
}

mcp/client.go

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,9 @@ type ClientOptions struct {
7171

7272
// bind implements the binder[*ClientSession] interface, so that Clients can
7373
// be connected using [connect].
74-
func (c *Client) bind(conn *jsonrpc2.Connection) *ClientSession {
75-
cs := &ClientSession{
76-
conn: conn,
77-
client: c,
78-
}
74+
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection) *ClientSession {
75+
assert(mcpConn != nil && conn != nil, "nil connection")
76+
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c}
7977
c.mu.Lock()
8078
defer c.mu.Unlock()
8179
c.sessions = append(c.sessions, cs)
@@ -101,14 +99,18 @@ func (e unsupportedProtocolVersionError) Error() string {
10199
return fmt.Sprintf("unsupported protocol version: %q", e.version)
102100
}
103101

102+
// ClientConnectOptions is reserved for future use.
103+
type ClientConnectOptions struct {
104+
}
105+
104106
// Connect begins an MCP session by connecting to a server over the given
105107
// transport, and initializing the session.
106108
//
107109
// Typically, it is the responsibility of the client to close the connection
108110
// when it is no longer needed. However, if the connection is closed by the
109111
// server, calls or notifications will return an error wrapping
110112
// [ErrConnectionClosed].
111-
func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, err error) {
113+
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientConnectOptions) (cs *ClientSession, err error) {
112114
cs, err = connect(ctx, t, c)
113115
if err != nil {
114116
return nil, err
@@ -133,9 +135,9 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
133135
if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) {
134136
return nil, unsupportedProtocolVersionError{res.ProtocolVersion}
135137
}
136-
cs.initializeResult = res
138+
cs.state.InitializeResult = res
137139
if hc, ok := cs.mcpConn.(clientConnection); ok {
138-
hc.initialized(res)
140+
hc.sessionUpdated(cs.state)
139141
}
140142
if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil {
141143
_ = cs.Close()
@@ -156,22 +158,25 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
156158
// Call [ClientSession.Close] to close the connection, or await server
157159
// termination with [ClientSession.Wait].
158160
type ClientSession struct {
159-
conn *jsonrpc2.Connection
160-
client *Client
161-
initializeResult *InitializeResult
162-
keepaliveCancel context.CancelFunc
163-
mcpConn Connection
161+
conn *jsonrpc2.Connection
162+
client *Client
163+
keepaliveCancel context.CancelFunc
164+
mcpConn Connection
165+
166+
// No mutex is (currently) required to guard the session state, because it is
167+
// only set synchronously during Client.Connect.
168+
state clientSessionState
164169
}
165170

166-
func (cs *ClientSession) setConn(c Connection) {
167-
cs.mcpConn = c
171+
type clientSessionState struct {
172+
InitializeResult *InitializeResult
168173
}
169174

170175
func (cs *ClientSession) ID() string {
171-
if cs.mcpConn == nil {
172-
return ""
176+
if c, ok := cs.mcpConn.(hasSessionID); ok {
177+
return c.SessionID()
173178
}
174-
return cs.mcpConn.SessionID()
179+
return ""
175180
}
176181

177182
// Close performs a graceful close of the connection, preventing new requests

mcp/client_list_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func TestList(t *testing.T) {
3838
}
3939
})
4040
t.Run("iterator", func(t *testing.T) {
41-
testIterator(ctx, t, clientSession.Tools(ctx, nil), wantTools)
41+
testIterator(t, clientSession.Tools(ctx, nil), wantTools)
4242
})
4343
})
4444

@@ -60,7 +60,7 @@ func TestList(t *testing.T) {
6060
}
6161
})
6262
t.Run("iterator", func(t *testing.T) {
63-
testIterator(ctx, t, clientSession.Resources(ctx, nil), wantResources)
63+
testIterator(t, clientSession.Resources(ctx, nil), wantResources)
6464
})
6565
})
6666

@@ -81,7 +81,7 @@ func TestList(t *testing.T) {
8181
}
8282
})
8383
t.Run("ResourceTemplatesIterator", func(t *testing.T) {
84-
testIterator(ctx, t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates)
84+
testIterator(t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates)
8585
})
8686
})
8787

@@ -102,12 +102,12 @@ func TestList(t *testing.T) {
102102
}
103103
})
104104
t.Run("iterator", func(t *testing.T) {
105-
testIterator(ctx, t, clientSession.Prompts(ctx, nil), wantPrompts)
105+
testIterator(t, clientSession.Prompts(ctx, nil), wantPrompts)
106106
})
107107
})
108108
}
109109

110-
func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, error], want []*T) {
110+
func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) {
111111
t.Helper()
112112
var got []*T
113113
for x, err := range seq {

mcp/cmd_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func TestServerRunContextCancel(t *testing.T) {
8181

8282
// send a ping to the server to ensure it's running
8383
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
84-
session, err := client.Connect(ctx, clientTransport)
84+
session, err := client.Connect(ctx, clientTransport, nil)
8585
if err != nil {
8686
t.Fatal(err)
8787
}
@@ -116,7 +116,7 @@ func TestServerInterrupt(t *testing.T) {
116116
cmd := createServerCommand(t, "default")
117117

118118
client := mcp.NewClient(testImpl, nil)
119-
_, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
119+
_, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
120120
if err != nil {
121121
t.Fatal(err)
122122
}
@@ -189,7 +189,7 @@ func TestCmdTransport(t *testing.T) {
189189
cmd := createServerCommand(t, "default")
190190

191191
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
192-
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
192+
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
193193
if err != nil {
194194
t.Fatal(err)
195195
}

mcp/conformance_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) {
135135
// Connect the server, and connect the client stream,
136136
// but don't connect an actual client.
137137
cTransport, sTransport := NewInMemoryTransports()
138-
ss, err := s.Connect(ctx, sTransport)
138+
ss, err := s.Connect(ctx, sTransport, nil)
139139
if err != nil {
140140
t.Fatal(err)
141141
}

mcp/example_middleware_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ func Example_loggingMiddleware() {
114114
ctx := context.Background()
115115

116116
// Connect server and client
117-
serverSession, _ := server.Connect(ctx, serverTransport)
117+
serverSession, _ := server.Connect(ctx, serverTransport, nil)
118118
defer serverSession.Close()
119119

120-
clientSession, _ := client.Connect(ctx, clientTransport)
120+
clientSession, _ := client.Connect(ctx, clientTransport, nil)
121121
defer clientSession.Close()
122122

123123
// Call the tool to demonstrate logging

mcp/logging.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool {
117117
// This is also checked in ServerSession.LoggingMessage, so checking it here
118118
// is just an optimization that skips building the JSON.
119119
h.ss.mu.Lock()
120-
mcpLevel := h.ss.logLevel
120+
mcpLevel := h.ss.state.LogLevel
121121
h.ss.mu.Unlock()
122122
return level >= mcpLevelToSlog(mcpLevel)
123123
}

mcp/mcp_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) {
104104
s.AddResource(resource2, readHandler)
105105

106106
// Connect the server.
107-
ss, err := s.Connect(ctx, st)
107+
ss, err := s.Connect(ctx, st, nil)
108108
if err != nil {
109109
t.Fatal(err)
110110
}
@@ -148,7 +148,7 @@ func TestEndToEnd(t *testing.T) {
148148
c.AddRoots(&Root{URI: "file://" + rootAbs})
149149

150150
// Connect the client.
151-
cs, err := c.Connect(ctx, ct)
151+
cs, err := c.Connect(ctx, ct, nil)
152152
if err != nil {
153153
t.Fatal(err)
154154
}
@@ -549,13 +549,13 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien
549549
if config != nil {
550550
config(s)
551551
}
552-
ss, err := s.Connect(ctx, st)
552+
ss, err := s.Connect(ctx, st, nil)
553553
if err != nil {
554554
t.Fatal(err)
555555
}
556556

557557
c := NewClient(testImpl, nil)
558-
cs, err := c.Connect(ctx, ct)
558+
cs, err := c.Connect(ctx, ct, nil)
559559
if err != nil {
560560
t.Fatal(err)
561561
}
@@ -598,7 +598,7 @@ func TestBatching(t *testing.T) {
598598
ct, st := NewInMemoryTransports()
599599

600600
s := NewServer(testImpl, nil)
601-
_, err := s.Connect(ctx, st)
601+
_, err := s.Connect(ctx, st, nil)
602602
if err != nil {
603603
t.Fatal(err)
604604
}
@@ -608,7 +608,7 @@ func TestBatching(t *testing.T) {
608608
// 'initialize' to block. Therefore, we can only test with a size of 1.
609609
// Since batching is being removed, we can probably just delete this.
610610
const batchSize = 1
611-
cs, err := c.Connect(ctx, ct)
611+
cs, err := c.Connect(ctx, ct, nil)
612612
if err != nil {
613613
t.Fatal(err)
614614
}
@@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) {
668668
ct, st := NewInMemoryTransports()
669669

670670
s := NewServer(testImpl, nil)
671-
ss, err := s.Connect(ctx, st)
671+
ss, err := s.Connect(ctx, st, nil)
672672
if err != nil {
673673
t.Fatal(err)
674674
}
@@ -695,7 +695,7 @@ func TestMiddleware(t *testing.T) {
695695
c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2"))
696696
c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2"))
697697

698-
cs, err := c.Connect(ctx, ct)
698+
cs, err := c.Connect(ctx, ct, nil)
699699
if err != nil {
700700
t.Fatal(err)
701701
}
@@ -777,13 +777,13 @@ func TestNoJSONNull(t *testing.T) {
777777
ct = NewLoggingTransport(ct, &logbuf)
778778

779779
s := NewServer(testImpl, nil)
780-
ss, err := s.Connect(ctx, st)
780+
ss, err := s.Connect(ctx, st, nil)
781781
if err != nil {
782782
t.Fatal(err)
783783
}
784784

785785
c := NewClient(testImpl, nil)
786-
cs, err := c.Connect(ctx, ct)
786+
cs, err := c.Connect(ctx, ct, nil)
787787
if err != nil {
788788
t.Fatal(err)
789789
}
@@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) {
845845
s := NewServer(testImpl, serverOpts)
846846
AddTool(s, greetTool(), sayHi)
847847

848-
ss, err := s.Connect(ctx, st)
848+
ss, err := s.Connect(ctx, st, nil)
849849
if err != nil {
850850
t.Fatal(err)
851851
}
@@ -855,7 +855,7 @@ func TestKeepAlive(t *testing.T) {
855855
KeepAlive: 100 * time.Millisecond,
856856
}
857857
c := NewClient(testImpl, clientOpts)
858-
cs, err := c.Connect(ctx, ct)
858+
cs, err := c.Connect(ctx, ct, nil)
859859
if err != nil {
860860
t.Fatal(err)
861861
}
@@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) {
889889
// Server without keepalive (to test one-sided keepalive)
890890
s := NewServer(testImpl, nil)
891891
AddTool(s, greetTool(), sayHi)
892-
ss, err := s.Connect(ctx, st)
892+
ss, err := s.Connect(ctx, st, nil)
893893
if err != nil {
894894
t.Fatal(err)
895895
}
@@ -899,7 +899,7 @@ func TestKeepAliveFailure(t *testing.T) {
899899
KeepAlive: 50 * time.Millisecond,
900900
}
901901
c := NewClient(testImpl, clientOpts)
902-
cs, err := c.Connect(ctx, ct)
902+
cs, err := c.Connect(ctx, ct, nil)
903903
if err != nil {
904904
t.Fatal(err)
905905
}

0 commit comments

Comments
 (0)