Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func main() {

// Connect to a server over stdin/stdout
transport := mcp.NewCommandTransport(exec.Command("myserver"))
session, err := client.Connect(ctx, transport)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/client/listfeatures/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func main() {
ctx := context.Background()
cmd := exec.Command(args[0], args[1:]...)
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil)
cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/readme/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func main() {

// Connect to a server over stdin/stdout
transport := mcp.NewCommandTransport(exec.Command("myserver"))
session, err := client.Connect(ctx, transport)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
log.Fatal(err)
}
Expand Down
44 changes: 26 additions & 18 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ type ClientOptions struct {

// bind implements the binder[*ClientSession] interface, so that Clients can
// be connected using [connect].
func (c *Client) bind(conn *jsonrpc2.Connection) *ClientSession {
cs := &ClientSession{
conn: conn,
client: c,
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession {
assert(mcpConn != nil && conn != nil, "nil connection")
cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c}
if state != nil {
cs.state = *state
}
c.mu.Lock()
defer c.mu.Unlock()
Expand All @@ -101,15 +102,19 @@ func (e unsupportedProtocolVersionError) Error() string {
return fmt.Sprintf("unsupported protocol version: %q", e.version)
}

// ClientSessionOptions is reserved for future use.
type ClientSessionOptions struct {
}

// Connect begins an MCP session by connecting to a server over the given
// transport, and initializing the session.
//
// Typically, it is the responsibility of the client to close the connection
// when it is no longer needed. However, if the connection is closed by the
// server, calls or notifications will return an error wrapping
// [ErrConnectionClosed].
func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, err error) {
cs, err = connect(ctx, t, c)
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
cs, err = connect(ctx, t, c, (*clientSessionState)(nil))
if err != nil {
return nil, err
}
Expand All @@ -133,9 +138,9 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) {
return nil, unsupportedProtocolVersionError{res.ProtocolVersion}
}
cs.initializeResult = res
cs.state.InitializeResult = res
if hc, ok := cs.mcpConn.(clientConnection); ok {
hc.initialized(res)
hc.sessionUpdated(cs.state)
}
if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil {
_ = cs.Close()
Expand All @@ -156,22 +161,25 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
// Call [ClientSession.Close] to close the connection, or await server
// termination with [ClientSession.Wait].
type ClientSession struct {
conn *jsonrpc2.Connection
client *Client
initializeResult *InitializeResult
keepaliveCancel context.CancelFunc
mcpConn Connection
conn *jsonrpc2.Connection
client *Client
keepaliveCancel context.CancelFunc
mcpConn Connection

// No mutex is (currently) required to guard the session state, because it is
// only set synchronously during Client.Connect.
state clientSessionState
}

func (cs *ClientSession) setConn(c Connection) {
cs.mcpConn = c
type clientSessionState struct {
InitializeResult *InitializeResult
}

func (cs *ClientSession) ID() string {
if cs.mcpConn == nil {
return ""
if c, ok := cs.mcpConn.(hasSessionID); ok {
return c.SessionID()
}
return cs.mcpConn.SessionID()
return ""
}

// Close performs a graceful close of the connection, preventing new requests
Expand Down
10 changes: 5 additions & 5 deletions mcp/client_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestList(t *testing.T) {
}
})
t.Run("iterator", func(t *testing.T) {
testIterator(ctx, t, clientSession.Tools(ctx, nil), wantTools)
testIterator(t, clientSession.Tools(ctx, nil), wantTools)
})
})

Expand All @@ -60,7 +60,7 @@ func TestList(t *testing.T) {
}
})
t.Run("iterator", func(t *testing.T) {
testIterator(ctx, t, clientSession.Resources(ctx, nil), wantResources)
testIterator(t, clientSession.Resources(ctx, nil), wantResources)
})
})

Expand All @@ -81,7 +81,7 @@ func TestList(t *testing.T) {
}
})
t.Run("ResourceTemplatesIterator", func(t *testing.T) {
testIterator(ctx, t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates)
testIterator(t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates)
})
})

Expand All @@ -102,12 +102,12 @@ func TestList(t *testing.T) {
}
})
t.Run("iterator", func(t *testing.T) {
testIterator(ctx, t, clientSession.Prompts(ctx, nil), wantPrompts)
testIterator(t, clientSession.Prompts(ctx, nil), wantPrompts)
})
})
}

func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, error], want []*T) {
func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) {
t.Helper()
var got []*T
for x, err := range seq {
Expand Down
6 changes: 3 additions & 3 deletions mcp/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestServerRunContextCancel(t *testing.T) {

// send a ping to the server to ensure it's running
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
session, err := client.Connect(ctx, clientTransport)
session, err := client.Connect(ctx, clientTransport, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestServerInterrupt(t *testing.T) {
cmd := createServerCommand(t, "default")

client := mcp.NewClient(testImpl, nil)
_, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
_, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -189,7 +189,7 @@ func TestCmdTransport(t *testing.T) {
cmd := createServerCommand(t, "default")

client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil)
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion mcp/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) {
// Connect the server, and connect the client stream,
// but don't connect an actual client.
cTransport, sTransport := NewInMemoryTransports()
ss, err := s.Connect(ctx, sTransport)
ss, err := s.Connect(ctx, sTransport, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions mcp/example_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ func Example_loggingMiddleware() {
ctx := context.Background()

// Connect server and client
serverSession, _ := server.Connect(ctx, serverTransport)
serverSession, _ := server.Connect(ctx, serverTransport, nil)
defer serverSession.Close()

clientSession, _ := client.Connect(ctx, clientTransport)
clientSession, _ := client.Connect(ctx, clientTransport, nil)
defer clientSession.Close()

// Call the tool to demonstrate logging
Expand Down
2 changes: 1 addition & 1 deletion mcp/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool {
// This is also checked in ServerSession.LoggingMessage, so checking it here
// is just an optimization that skips building the JSON.
h.ss.mu.Lock()
mcpLevel := h.ss.logLevel
mcpLevel := h.ss.state.LogLevel
h.ss.mu.Unlock()
return level >= mcpLevelToSlog(mcpLevel)
}
Expand Down
28 changes: 14 additions & 14 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) {
s.AddResource(resource2, readHandler)

// Connect the server.
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestEndToEnd(t *testing.T) {
c.AddRoots(&Root{URI: "file://" + rootAbs})

// Connect the client.
cs, err := c.Connect(ctx, ct)
cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -549,13 +549,13 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien
if config != nil {
config(s)
}
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}

c := NewClient(testImpl, nil)
cs, err := c.Connect(ctx, ct)
cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -598,7 +598,7 @@ func TestBatching(t *testing.T) {
ct, st := NewInMemoryTransports()

s := NewServer(testImpl, nil)
_, err := s.Connect(ctx, st)
_, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -608,7 +608,7 @@ func TestBatching(t *testing.T) {
// 'initialize' to block. Therefore, we can only test with a size of 1.
// Since batching is being removed, we can probably just delete this.
const batchSize = 1
cs, err := c.Connect(ctx, ct)
cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) {
ct, st := NewInMemoryTransports()

s := NewServer(testImpl, nil)
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -695,7 +695,7 @@ func TestMiddleware(t *testing.T) {
c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2"))
c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2"))

cs, err := c.Connect(ctx, ct)
cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -777,13 +777,13 @@ func TestNoJSONNull(t *testing.T) {
ct = NewLoggingTransport(ct, &logbuf)

s := NewServer(testImpl, nil)
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}

c := NewClient(testImpl, nil)
cs, err := c.Connect(ctx, ct)
cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) {
s := NewServer(testImpl, serverOpts)
AddTool(s, greetTool(), sayHi)

ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -855,7 +855,7 @@ func TestKeepAlive(t *testing.T) {
KeepAlive: 100 * time.Millisecond,
}
c := NewClient(testImpl, clientOpts)
cs, err := c.Connect(ctx, ct)
cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) {
// Server without keepalive (to test one-sided keepalive)
s := NewServer(testImpl, nil)
AddTool(s, greetTool(), sayHi)
ss, err := s.Connect(ctx, st)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -899,7 +899,7 @@ func TestKeepAliveFailure(t *testing.T) {
KeepAlive: 50 * time.Millisecond,
}
c := NewClient(testImpl, clientOpts)
cs, err := c.Connect(ctx, ct)
cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading
Loading