Skip to content

Commit dae8853

Browse files
authored
mcp: support stateless streamable sessions (#277)
Support stateless streamable sessions by adding a GetSessionID function to StreamableHTTPOptions. If GetSessionID returns "", the session is stateless, and no validation is performed. This is implemented by providing the session a trivial initialization state. To implement this, some parts of #232 (distributed sessions) are copied over, since they add an API for creating an already-initialized session. In total, the following new API is added: - StreamableHTTPOptions.GetSessionID - ServerSessionOptions (a new parameter to Server.Connect) - ServerSessionState - ClientSessionOptions (a new parameter to Client.Connect, for symmetry) For #10
1 parent 119a583 commit dae8853

19 files changed

+305
-119
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func main() {
7474

7575
// Connect to a server over stdin/stdout
7676
transport := mcp.NewCommandTransport(exec.Command("myserver"))
77-
session, err := client.Connect(ctx, transport)
77+
session, err := client.Connect(ctx, transport, nil)
7878
if err != nil {
7979
log.Fatal(err)
8080
}

examples/client/listfeatures/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func main() {
4141
ctx := context.Background()
4242
cmd := exec.Command(args[0], args[1:]...)
4343
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil)
44-
cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
44+
cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
4545
if err != nil {
4646
log.Fatal(err)
4747
}

internal/readme/client/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func main() {
2121

2222
// Connect to a server over stdin/stdout
2323
transport := mcp.NewCommandTransport(exec.Command("myserver"))
24-
session, err := client.Connect(ctx, transport)
24+
session, err := client.Connect(ctx, transport, nil)
2525
if err != nil {
2626
log.Fatal(err)
2727
}

mcp/client.go

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,11 @@ type ClientOptions struct {
7171

7272
// bind implements the binder[*ClientSession] interface, so that Clients can
7373
// be connected using [connect].
74-
func (c *Client) bind(conn *jsonrpc2.Connection) *ClientSession {
75-
cs := &ClientSession{
76-
conn: conn,
77-
client: c,
74+
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession {
75+
assert(mcpConn != nil && conn != nil, "nil connection")
76+
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c}
77+
if state != nil {
78+
cs.state = *state
7879
}
7980
c.mu.Lock()
8081
defer c.mu.Unlock()
@@ -101,15 +102,19 @@ func (e unsupportedProtocolVersionError) Error() string {
101102
return fmt.Sprintf("unsupported protocol version: %q", e.version)
102103
}
103104

105+
// ClientSessionOptions is reserved for future use.
106+
type ClientSessionOptions struct {
107+
}
108+
104109
// Connect begins an MCP session by connecting to a server over the given
105110
// transport, and initializing the session.
106111
//
107112
// Typically, it is the responsibility of the client to close the connection
108113
// when it is no longer needed. However, if the connection is closed by the
109114
// server, calls or notifications will return an error wrapping
110115
// [ErrConnectionClosed].
111-
func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, err error) {
112-
cs, err = connect(ctx, t, c)
116+
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
117+
cs, err = connect(ctx, t, c, (*clientSessionState)(nil))
113118
if err != nil {
114119
return nil, err
115120
}
@@ -133,9 +138,9 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
133138
if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) {
134139
return nil, unsupportedProtocolVersionError{res.ProtocolVersion}
135140
}
136-
cs.initializeResult = res
141+
cs.state.InitializeResult = res
137142
if hc, ok := cs.mcpConn.(clientConnection); ok {
138-
hc.initialized(res)
143+
hc.sessionUpdated(cs.state)
139144
}
140145
if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil {
141146
_ = cs.Close()
@@ -156,22 +161,25 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
156161
// Call [ClientSession.Close] to close the connection, or await server
157162
// termination with [ClientSession.Wait].
158163
type ClientSession struct {
159-
conn *jsonrpc2.Connection
160-
client *Client
161-
initializeResult *InitializeResult
162-
keepaliveCancel context.CancelFunc
163-
mcpConn Connection
164+
conn *jsonrpc2.Connection
165+
client *Client
166+
keepaliveCancel context.CancelFunc
167+
mcpConn Connection
168+
169+
// No mutex is (currently) required to guard the session state, because it is
170+
// only set synchronously during Client.Connect.
171+
state clientSessionState
164172
}
165173

166-
func (cs *ClientSession) setConn(c Connection) {
167-
cs.mcpConn = c
174+
type clientSessionState struct {
175+
InitializeResult *InitializeResult
168176
}
169177

170178
func (cs *ClientSession) ID() string {
171-
if cs.mcpConn == nil {
172-
return ""
179+
if c, ok := cs.mcpConn.(hasSessionID); ok {
180+
return c.SessionID()
173181
}
174-
return cs.mcpConn.SessionID()
182+
return ""
175183
}
176184

177185
// Close performs a graceful close of the connection, preventing new requests

mcp/client_list_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func TestList(t *testing.T) {
3838
}
3939
})
4040
t.Run("iterator", func(t *testing.T) {
41-
testIterator(ctx, t, clientSession.Tools(ctx, nil), wantTools)
41+
testIterator(t, clientSession.Tools(ctx, nil), wantTools)
4242
})
4343
})
4444

@@ -60,7 +60,7 @@ func TestList(t *testing.T) {
6060
}
6161
})
6262
t.Run("iterator", func(t *testing.T) {
63-
testIterator(ctx, t, clientSession.Resources(ctx, nil), wantResources)
63+
testIterator(t, clientSession.Resources(ctx, nil), wantResources)
6464
})
6565
})
6666

@@ -81,7 +81,7 @@ func TestList(t *testing.T) {
8181
}
8282
})
8383
t.Run("ResourceTemplatesIterator", func(t *testing.T) {
84-
testIterator(ctx, t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates)
84+
testIterator(t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates)
8585
})
8686
})
8787

@@ -102,12 +102,12 @@ func TestList(t *testing.T) {
102102
}
103103
})
104104
t.Run("iterator", func(t *testing.T) {
105-
testIterator(ctx, t, clientSession.Prompts(ctx, nil), wantPrompts)
105+
testIterator(t, clientSession.Prompts(ctx, nil), wantPrompts)
106106
})
107107
})
108108
}
109109

110-
func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, error], want []*T) {
110+
func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) {
111111
t.Helper()
112112
var got []*T
113113
for x, err := range seq {

mcp/cmd_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func TestServerRunContextCancel(t *testing.T) {
8181

8282
// send a ping to the server to ensure it's running
8383
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
84-
session, err := client.Connect(ctx, clientTransport)
84+
session, err := client.Connect(ctx, clientTransport, nil)
8585
if err != nil {
8686
t.Fatal(err)
8787
}
@@ -116,7 +116,7 @@ func TestServerInterrupt(t *testing.T) {
116116
cmd := createServerCommand(t, "default")
117117

118118
client := mcp.NewClient(testImpl, nil)
119-
_, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
119+
_, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
120120
if err != nil {
121121
t.Fatal(err)
122122
}
@@ -198,7 +198,7 @@ func TestCmdTransport(t *testing.T) {
198198
cmd := createServerCommand(t, "default")
199199

200200
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
201-
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
201+
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
202202
if err != nil {
203203
t.Fatal(err)
204204
}

mcp/conformance_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) {
135135
// Connect the server, and connect the client stream,
136136
// but don't connect an actual client.
137137
cTransport, sTransport := NewInMemoryTransports()
138-
ss, err := s.Connect(ctx, sTransport)
138+
ss, err := s.Connect(ctx, sTransport, nil)
139139
if err != nil {
140140
t.Fatal(err)
141141
}

mcp/example_middleware_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ func Example_loggingMiddleware() {
114114
ctx := context.Background()
115115

116116
// Connect server and client
117-
serverSession, _ := server.Connect(ctx, serverTransport)
117+
serverSession, _ := server.Connect(ctx, serverTransport, nil)
118118
defer serverSession.Close()
119119

120-
clientSession, _ := client.Connect(ctx, clientTransport)
120+
clientSession, _ := client.Connect(ctx, clientTransport, nil)
121121
defer clientSession.Close()
122122

123123
// Call the tool to demonstrate logging

mcp/logging.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool {
117117
// This is also checked in ServerSession.LoggingMessage, so checking it here
118118
// is just an optimization that skips building the JSON.
119119
h.ss.mu.Lock()
120-
mcpLevel := h.ss.logLevel
120+
mcpLevel := h.ss.state.LogLevel
121121
h.ss.mu.Unlock()
122122
return level >= mcpLevelToSlog(mcpLevel)
123123
}

mcp/mcp_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) {
104104
s.AddResource(resource2, readHandler)
105105

106106
// Connect the server.
107-
ss, err := s.Connect(ctx, st)
107+
ss, err := s.Connect(ctx, st, nil)
108108
if err != nil {
109109
t.Fatal(err)
110110
}
@@ -148,7 +148,7 @@ func TestEndToEnd(t *testing.T) {
148148
c.AddRoots(&Root{URI: "file://" + rootAbs})
149149

150150
// Connect the client.
151-
cs, err := c.Connect(ctx, ct)
151+
cs, err := c.Connect(ctx, ct, nil)
152152
if err != nil {
153153
t.Fatal(err)
154154
}
@@ -549,13 +549,13 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien
549549
if config != nil {
550550
config(s)
551551
}
552-
ss, err := s.Connect(ctx, st)
552+
ss, err := s.Connect(ctx, st, nil)
553553
if err != nil {
554554
t.Fatal(err)
555555
}
556556

557557
c := NewClient(testImpl, nil)
558-
cs, err := c.Connect(ctx, ct)
558+
cs, err := c.Connect(ctx, ct, nil)
559559
if err != nil {
560560
t.Fatal(err)
561561
}
@@ -598,7 +598,7 @@ func TestBatching(t *testing.T) {
598598
ct, st := NewInMemoryTransports()
599599

600600
s := NewServer(testImpl, nil)
601-
_, err := s.Connect(ctx, st)
601+
_, err := s.Connect(ctx, st, nil)
602602
if err != nil {
603603
t.Fatal(err)
604604
}
@@ -608,7 +608,7 @@ func TestBatching(t *testing.T) {
608608
// 'initialize' to block. Therefore, we can only test with a size of 1.
609609
// Since batching is being removed, we can probably just delete this.
610610
const batchSize = 1
611-
cs, err := c.Connect(ctx, ct)
611+
cs, err := c.Connect(ctx, ct, nil)
612612
if err != nil {
613613
t.Fatal(err)
614614
}
@@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) {
668668
ct, st := NewInMemoryTransports()
669669

670670
s := NewServer(testImpl, nil)
671-
ss, err := s.Connect(ctx, st)
671+
ss, err := s.Connect(ctx, st, nil)
672672
if err != nil {
673673
t.Fatal(err)
674674
}
@@ -695,7 +695,7 @@ func TestMiddleware(t *testing.T) {
695695
c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2"))
696696
c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2"))
697697

698-
cs, err := c.Connect(ctx, ct)
698+
cs, err := c.Connect(ctx, ct, nil)
699699
if err != nil {
700700
t.Fatal(err)
701701
}
@@ -777,13 +777,13 @@ func TestNoJSONNull(t *testing.T) {
777777
ct = NewLoggingTransport(ct, &logbuf)
778778

779779
s := NewServer(testImpl, nil)
780-
ss, err := s.Connect(ctx, st)
780+
ss, err := s.Connect(ctx, st, nil)
781781
if err != nil {
782782
t.Fatal(err)
783783
}
784784

785785
c := NewClient(testImpl, nil)
786-
cs, err := c.Connect(ctx, ct)
786+
cs, err := c.Connect(ctx, ct, nil)
787787
if err != nil {
788788
t.Fatal(err)
789789
}
@@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) {
845845
s := NewServer(testImpl, serverOpts)
846846
AddTool(s, greetTool(), sayHi)
847847

848-
ss, err := s.Connect(ctx, st)
848+
ss, err := s.Connect(ctx, st, nil)
849849
if err != nil {
850850
t.Fatal(err)
851851
}
@@ -855,7 +855,7 @@ func TestKeepAlive(t *testing.T) {
855855
KeepAlive: 100 * time.Millisecond,
856856
}
857857
c := NewClient(testImpl, clientOpts)
858-
cs, err := c.Connect(ctx, ct)
858+
cs, err := c.Connect(ctx, ct, nil)
859859
if err != nil {
860860
t.Fatal(err)
861861
}
@@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) {
889889
// Server without keepalive (to test one-sided keepalive)
890890
s := NewServer(testImpl, nil)
891891
AddTool(s, greetTool(), sayHi)
892-
ss, err := s.Connect(ctx, st)
892+
ss, err := s.Connect(ctx, st, nil)
893893
if err != nil {
894894
t.Fatal(err)
895895
}
@@ -899,7 +899,7 @@ func TestKeepAliveFailure(t *testing.T) {
899899
KeepAlive: 50 * time.Millisecond,
900900
}
901901
c := NewClient(testImpl, clientOpts)
902-
cs, err := c.Connect(ctx, ct)
902+
cs, err := c.Connect(ctx, ct, nil)
903903
if err != nil {
904904
t.Fatal(err)
905905
}

0 commit comments

Comments
 (0)