Skip to content

Commit bf79d78

Browse files
authored
mcp: remove references to ClientRequest[T] for concrete T (#343)
See related PR about ServerRequest[T].
1 parent d8e18b3 commit bf79d78

File tree

6 files changed

+60
-47
lines changed

6 files changed

+60
-47
lines changed

mcp/client.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client {
5555
type ClientOptions struct {
5656
// Handler for sampling.
5757
// Called when a server calls CreateMessage.
58-
CreateMessageHandler func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error)
58+
CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error)
5959
// Handlers for notifications from the server.
60-
ToolListChangedHandler func(context.Context, *ClientRequest[*ToolListChangedParams])
61-
PromptListChangedHandler func(context.Context, *ClientRequest[*PromptListChangedParams])
62-
ResourceListChangedHandler func(context.Context, *ClientRequest[*ResourceListChangedParams])
63-
ResourceUpdatedHandler func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams])
64-
LoggingMessageHandler func(context.Context, *ClientRequest[*LoggingMessageParams])
65-
ProgressNotificationHandler func(context.Context, *ClientRequest[*ProgressNotificationParams])
60+
ToolListChangedHandler func(context.Context, *ToolListChangedRequest)
61+
PromptListChangedHandler func(context.Context, *PromptListChangedRequest)
62+
ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest)
63+
ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest)
64+
LoggingMessageHandler func(context.Context, *LoggingMessageRequest)
65+
ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest)
6666
// If non-zero, defines an interval for regular "ping" requests.
6767
// If the peer fails to respond to pings originating from the keepalive check,
6868
// the session is automatically closed.
@@ -132,7 +132,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
132132
ClientInfo: c.impl,
133133
Capabilities: c.capabilities(),
134134
}
135-
req := &ClientRequest[*InitializeParams]{Session: cs, Params: params}
135+
req := &InitializeRequest{Session: cs, Params: params}
136136
res, err := handleSend[*InitializeResult](ctx, methodInitialize, req)
137137
if err != nil {
138138
_ = cs.Close()
@@ -145,7 +145,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
145145
if hc, ok := cs.mcpConn.(clientConnection); ok {
146146
hc.sessionUpdated(cs.state)
147147
}
148-
req2 := &ClientRequest[*InitializedParams]{Session: cs, Params: &InitializedParams{}}
148+
req2 := &InitializedClientRequest{Session: cs, Params: &InitializedParams{}}
149149
if err := handleNotify(ctx, notificationInitialized, req2); err != nil {
150150
_ = cs.Close()
151151
return nil, err
@@ -248,7 +248,7 @@ func changeAndNotify[P Params](c *Client, notification string, params P, change
248248
notifySessions(sessions, notification, params)
249249
}
250250

251-
func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParams]) (*ListRootsResult, error) {
251+
func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) {
252252
c.mu.Lock()
253253
defer c.mu.Unlock()
254254
roots := slices.Collect(c.roots.all())
@@ -260,7 +260,7 @@ func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParam
260260
}, nil
261261
}
262262

263-
func (c *Client) createMessage(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
263+
func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
264264
if c.opts.CreateMessageHandler == nil {
265265
// TODO: wrap or annotate this error? Pick a standard code?
266266
return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support CreateMessage")
@@ -436,35 +436,35 @@ func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribePar
436436
return err
437437
}
438438

439-
func (c *Client) callToolChangedHandler(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) (Result, error) {
439+
func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) {
440440
if h := c.opts.ToolListChangedHandler; h != nil {
441441
h(ctx, req)
442442
}
443443
return nil, nil
444444
}
445445

446-
func (c *Client) callPromptChangedHandler(ctx context.Context, req *ClientRequest[*PromptListChangedParams]) (Result, error) {
446+
func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) {
447447
if h := c.opts.PromptListChangedHandler; h != nil {
448448
h(ctx, req)
449449
}
450450
return nil, nil
451451
}
452452

453-
func (c *Client) callResourceChangedHandler(ctx context.Context, req *ClientRequest[*ResourceListChangedParams]) (Result, error) {
453+
func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) {
454454
if h := c.opts.ResourceListChangedHandler; h != nil {
455455
h(ctx, req)
456456
}
457457
return nil, nil
458458
}
459459

460-
func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ClientRequest[*ResourceUpdatedNotificationParams]) (Result, error) {
460+
func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) {
461461
if h := c.opts.ResourceUpdatedHandler; h != nil {
462462
h(ctx, req)
463463
}
464464
return nil, nil
465465
}
466466

467-
func (c *Client) callLoggingHandler(ctx context.Context, req *ClientRequest[*LoggingMessageParams]) (Result, error) {
467+
func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequest) (Result, error) {
468468
if h := c.opts.LoggingMessageHandler; h != nil {
469469
h(ctx, req)
470470
}

mcp/client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ func TestClientCapabilities(t *testing.T) {
211211
name: "With sampling",
212212
configureClient: func(s *Client) {},
213213
clientOpts: ClientOptions{
214-
CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
214+
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
215215
return nil, nil
216216
},
217217
},

mcp/mcp_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ func TestEndToEnd(t *testing.T) {
7474
}
7575

7676
sopts := &ServerOptions{
77-
InitializedHandler: func(context.Context, *InitializedRequest) {
77+
InitializedHandler: func(context.Context, *InitializedServerRequest) {
7878
notificationChans["initialized"] <- 0
7979
},
8080
RootsListChangedHandler: func(context.Context, *RootsListChangedRequest) {
8181
notificationChans["roots"] <- 0
8282
},
83-
ProgressNotificationHandler: func(context.Context, *ProgressNotificationRequest) {
83+
ProgressNotificationHandler: func(context.Context, *ProgressNotificationServerRequest) {
8484
notificationChans["progress_server"] <- 0
8585
},
8686
SubscribeHandler: func(context.Context, *SubscribeRequest) error {
@@ -129,25 +129,25 @@ func TestEndToEnd(t *testing.T) {
129129

130130
loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging
131131
opts := &ClientOptions{
132-
CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
132+
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
133133
return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil
134134
},
135-
ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) {
135+
ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) {
136136
notificationChans["tools"] <- 0
137137
},
138-
PromptListChangedHandler: func(context.Context, *ClientRequest[*PromptListChangedParams]) {
138+
PromptListChangedHandler: func(context.Context, *PromptListChangedRequest) {
139139
notificationChans["prompts"] <- 0
140140
},
141-
ResourceListChangedHandler: func(context.Context, *ClientRequest[*ResourceListChangedParams]) {
141+
ResourceListChangedHandler: func(context.Context, *ResourceListChangedRequest) {
142142
notificationChans["resources"] <- 0
143143
},
144-
LoggingMessageHandler: func(_ context.Context, req *ClientRequest[*LoggingMessageParams]) {
144+
LoggingMessageHandler: func(_ context.Context, req *LoggingMessageRequest) {
145145
loggingMessages <- req.Params
146146
},
147-
ProgressNotificationHandler: func(context.Context, *ClientRequest[*ProgressNotificationParams]) {
147+
ProgressNotificationHandler: func(context.Context, *ProgressNotificationClientRequest) {
148148
notificationChans["progress_client"] <- 0
149149
},
150-
ResourceUpdatedHandler: func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams]) {
150+
ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) {
151151
notificationChans["resource_updated"] <- 0
152152
},
153153
}
@@ -992,10 +992,10 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
992992
func TestSynchronousNotifications(t *testing.T) {
993993
var toolsChanged atomic.Bool
994994
clientOpts := &ClientOptions{
995-
ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) {
995+
ToolListChangedHandler: func(ctx context.Context, req *ToolListChangedRequest) {
996996
toolsChanged.Store(true)
997997
},
998-
CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
998+
CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
999999
if !toolsChanged.Load() {
10001000
return nil, fmt.Errorf("didn't get a tools changed notification")
10011001
}
@@ -1057,7 +1057,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
10571057
// possible, and in any case making tool calls asynchronous by default
10581058
// delegates synchronization to the user.
10591059
clientOpts := &ClientOptions{
1060-
CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
1060+
CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
10611061
req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"})
10621062
return &CreateMessageResult{Content: &TextContent{}}, nil
10631063
},

mcp/requests.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,30 @@ package mcp
88

99
// TODO: expand the aliases
1010
type (
11-
CallToolRequest = ServerRequest[*CallToolParams]
12-
CompleteRequest = ServerRequest[*CompleteParams]
13-
GetPromptRequest = ServerRequest[*GetPromptParams]
14-
InitializedRequest = ServerRequest[*InitializedParams]
15-
ListPromptsRequest = ServerRequest[*ListPromptsParams]
16-
ListResourcesRequest = ServerRequest[*ListResourcesParams]
17-
ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams]
18-
ListToolsRequest = ServerRequest[*ListToolsParams]
19-
ProgressNotificationRequest = ServerRequest[*ProgressNotificationParams]
20-
ReadResourceRequest = ServerRequest[*ReadResourceParams]
21-
RootsListChangedRequest = ServerRequest[*RootsListChangedParams]
22-
SubscribeRequest = ServerRequest[*SubscribeParams]
23-
UnsubscribeRequest = ServerRequest[*UnsubscribeParams]
11+
CallToolRequest = ServerRequest[*CallToolParams]
12+
CompleteRequest = ServerRequest[*CompleteParams]
13+
GetPromptRequest = ServerRequest[*GetPromptParams]
14+
InitializedServerRequest = ServerRequest[*InitializedParams]
15+
ListPromptsRequest = ServerRequest[*ListPromptsParams]
16+
ListResourcesRequest = ServerRequest[*ListResourcesParams]
17+
ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams]
18+
ListToolsRequest = ServerRequest[*ListToolsParams]
19+
ProgressNotificationServerRequest = ServerRequest[*ProgressNotificationParams]
20+
ReadResourceRequest = ServerRequest[*ReadResourceParams]
21+
RootsListChangedRequest = ServerRequest[*RootsListChangedParams]
22+
SubscribeRequest = ServerRequest[*SubscribeParams]
23+
UnsubscribeRequest = ServerRequest[*UnsubscribeParams]
24+
)
25+
26+
type (
27+
CreateMessageRequest = ClientRequest[*CreateMessageParams]
28+
InitializedClientRequest = ClientRequest[*InitializedParams]
29+
InitializeRequest = ClientRequest[*InitializeParams]
30+
ListRootsRequest = ClientRequest[*ListRootsParams]
31+
LoggingMessageRequest = ClientRequest[*LoggingMessageParams]
32+
ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams]
33+
PromptListChangedRequest = ClientRequest[*PromptListChangedParams]
34+
ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams]
35+
ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams]
36+
ToolListChangedRequest = ClientRequest[*ToolListChangedParams]
2437
)

mcp/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ type ServerOptions struct {
5454
// Optional instructions for connected clients.
5555
Instructions string
5656
// If non-nil, called when "notifications/initialized" is received.
57-
InitializedHandler func(context.Context, *InitializedRequest)
57+
InitializedHandler func(context.Context, *InitializedServerRequest)
5858
// PageSize is the maximum number of items to return in a single page for
5959
// list methods (e.g. ListTools).
6060
PageSize int
6161
// If non-nil, called when "notifications/roots/list_changed" is received.
6262
RootsListChangedHandler func(context.Context, *RootsListChangedRequest)
6363
// If non-nil, called when "notifications/progress" is received.
64-
ProgressNotificationHandler func(context.Context, *ProgressNotificationRequest)
64+
ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest)
6565
// If non-nil, called when "completion/complete" is received.
6666
CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error)
6767
// If non-zero, defines an interval for regular "ping" requests.

mcp/streamable_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func TestStreamableTransports(t *testing.T) {
121121
HTTPClient: httpClient,
122122
}
123123
client := NewClient(testImpl, &ClientOptions{
124-
CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
124+
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
125125
return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil
126126
},
127127
})
@@ -255,7 +255,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
255255
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
256256
defer cancel()
257257
client := NewClient(testImpl, &ClientOptions{
258-
ProgressNotificationHandler: func(ctx context.Context, req *ClientRequest[*ProgressNotificationParams]) {
258+
ProgressNotificationHandler: func(ctx context.Context, req *ProgressNotificationClientRequest) {
259259
notifications <- req.Params.Message
260260
},
261261
})
@@ -344,7 +344,7 @@ func TestServerInitiatedSSE(t *testing.T) {
344344
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
345345
defer cancel()
346346
client := NewClient(testImpl, &ClientOptions{
347-
ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) {
347+
ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) {
348348
notifications <- "toolListChanged"
349349
},
350350
})

0 commit comments

Comments
 (0)