Skip to content

Commit 016ddfe

Browse files
committed
mcp: use RawMessage for CallToolParams.Arguments
WIP DO NOT REVIEW DO NOT SUBMIT
1 parent bf79d78 commit 016ddfe

File tree

15 files changed

+93
-98
lines changed

15 files changed

+93
-98
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ func main() {
8181
defer session.Close()
8282

8383
// Call a tool on the server.
84-
params := &mcp.CallToolParams{
85-
Name: "greet",
86-
Arguments: map[string]any{"name": "you"},
87-
}
88-
res, err := session.CallTool(ctx, params)
84+
res, err := session.CallTool(ctx, &mcp.CallToolParams{
85+
Name: "greet",
86+
}, map[string]any{
87+
"name": "you",
88+
})
8989
if err != nil {
9090
log.Fatalf("CallTool failed: %v", err)
9191
}

examples/server/memory/kb.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,6 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.CallToolRequ
443443
&mcp.TextContent{Text: "Entities created successfully"},
444444
}
445445

446-
res.StructuredContent = CreateEntitiesResult{
447-
Entities: entities,
448-
}
449-
450446
return &res, CreateEntitiesResult{Entities: entities}, nil
451447
}
452448

internal/readme/client/client.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ func main() {
2828
defer session.Close()
2929

3030
// Call a tool on the server.
31-
params := &mcp.CallToolParams{
32-
Name: "greet",
33-
Arguments: map[string]any{"name": "you"},
34-
}
35-
res, err := session.CallTool(ctx, params)
31+
res, err := session.CallTool(ctx, &mcp.CallToolParams{
32+
Name: "greet",
33+
}, map[string]any{
34+
"name": "you",
35+
})
3636
if err != nil {
3737
log.Fatalf("CallTool failed: %v", err)
3838
}

mcp/client.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package mcp
66

77
import (
88
"context"
9+
"encoding/json"
910
"fmt"
1011
"iter"
1112
"slices"
@@ -387,13 +388,22 @@ func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams)
387388

388389
// CallTool calls the tool with the given name and arguments.
389390
// The arguments can be any value that marshals into a JSON object.
390-
func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) {
391+
func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams, args any) (*CallToolResult, error) {
391392
if params == nil {
392393
params = new(CallToolParams)
393394
}
395+
if args != nil {
396+
assert(len(params.Arguments) == 0, "non-nil args with non-empty params.Arguments")
397+
398+
data, err := json.Marshal(args)
399+
if err != nil {
400+
return nil, fmt.Errorf("marshalling args: %v", err)
401+
}
402+
params.Arguments = json.RawMessage(data)
403+
}
394404
if params.Arguments == nil {
395405
// Avoid sending nil over the wire.
396-
params.Arguments = map[string]any{}
406+
params.Arguments = json.RawMessage(`{}`)
397407
}
398408
return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))
399409
}

mcp/cmd_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,9 @@ func TestCmdTransport(t *testing.T) {
203203
t.Fatal(err)
204204
}
205205
got, err := session.CallTool(ctx, &mcp.CallToolParams{
206-
Name: "greet",
207-
Arguments: map[string]any{"name": "user"},
206+
Name: "greet",
207+
}, map[string]any{
208+
"name": "user",
208209
})
209210
if err != nil {
210211
t.Fatal(err)

mcp/example_middleware_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,8 @@ func Example_loggingMiddleware() {
121121
// Call the tool to demonstrate logging
122122
result, _ := clientSession.CallTool(ctx, &mcp.CallToolParams{
123123
Name: "greet",
124-
Arguments: map[string]any{
125-
"name": "World",
126-
},
124+
}, map[string]any{
125+
"name": "World",
127126
})
128127

129128
fmt.Printf("Tool result: %s\n", result.Content[0].(*mcp.TextContent).Text)

mcp/mcp_test.go

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ func TestEndToEnd(t *testing.T) {
213213
t.Run("tools", func(t *testing.T) {
214214
// ListTools is tested in client_list_test.go.
215215
gotHi, err := cs.CallTool(ctx, &CallToolParams{
216-
Name: "greet",
217-
Arguments: map[string]any{"name": "user"},
216+
Name: "greet",
217+
}, map[string]any{
218+
"name": "user",
218219
})
219220
if err != nil {
220221
t.Fatal(err)
@@ -228,10 +229,7 @@ func TestEndToEnd(t *testing.T) {
228229
t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
229230
}
230231

231-
gotFail, err := cs.CallTool(ctx, &CallToolParams{
232-
Name: "fail",
233-
Arguments: map[string]any{},
234-
})
232+
gotFail, err := cs.CallTool(ctx, &CallToolParams{Name: "fail"}, nil)
235233
// Counter-intuitively, when a tool fails, we don't expect an RPC error for
236234
// call tool: instead, the failure is embedded in the result.
237235
if err != nil {
@@ -605,16 +603,18 @@ func TestServerClosing(t *testing.T) {
605603
wg.Done()
606604
}()
607605
if _, err := cs.CallTool(ctx, &CallToolParams{
608-
Name: "greet",
609-
Arguments: map[string]any{"name": "user"},
606+
Name: "greet",
607+
}, map[string]any{
608+
"name": "user",
610609
}); err != nil {
611610
t.Fatalf("after connecting: %v", err)
612611
}
613612
ss.Close()
614613
wg.Wait()
615614
if _, err := cs.CallTool(ctx, &CallToolParams{
616-
Name: "greet",
617-
Arguments: map[string]any{"name": "user"},
615+
Name: "greet",
616+
}, map[string]any{
617+
"name": "user",
618618
}); !errors.Is(err, ErrConnectionClosed) {
619619
t.Errorf("after disconnection, got error %v, want EOF", err)
620620
}
@@ -679,7 +679,7 @@ func TestCancellation(t *testing.T) {
679679
defer cs.Close()
680680

681681
ctx, cancel := context.WithCancel(context.Background())
682-
go cs.CallTool(ctx, &CallToolParams{Name: "slow"})
682+
go cs.CallTool(ctx, &CallToolParams{Name: "slow"}, nil)
683683
<-start
684684
cancel()
685685
select {
@@ -892,8 +892,9 @@ func TestKeepAlive(t *testing.T) {
892892

893893
// Test that the connection is still alive by making a call
894894
result, err := cs.CallTool(ctx, &CallToolParams{
895-
Name: "greet",
896-
Arguments: map[string]any{"Name": "user"},
895+
Name: "greet",
896+
}, map[string]any{
897+
"Name": "user",
897898
})
898899
if err != nil {
899900
t.Fatalf("call failed after keepalive: %v", err)
@@ -942,8 +943,9 @@ func TestKeepAliveFailure(t *testing.T) {
942943
deadline := time.Now().Add(1 * time.Second)
943944
for time.Now().Before(deadline) {
944945
_, err = cs.CallTool(ctx, &CallToolParams{
945-
Name: "greet",
946-
Arguments: map[string]any{"Name": "user"},
946+
Name: "greet",
947+
}, map[string]any{
948+
"Name": "user",
947949
})
948950
if errors.Is(err, ErrConnectionClosed) {
949951
return // Test passed
@@ -1025,7 +1027,7 @@ func TestSynchronousNotifications(t *testing.T) {
10251027

10261028
t.Run("from client", func(t *testing.T) {
10271029
client.AddRoots(&Root{Name: "myroot", URI: "file://foo"})
1028-
res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"})
1030+
res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"}, nil)
10291031
if err != nil {
10301032
t.Fatalf("CallTool failed: %v", err)
10311033
}
@@ -1058,7 +1060,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
10581060
// delegates synchronization to the user.
10591061
clientOpts := &ClientOptions{
10601062
CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
1061-
req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"})
1063+
req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"}, nil)
10621064
return &CreateMessageResult{Content: &TextContent{}}, nil
10631065
},
10641066
}
@@ -1077,7 +1079,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
10771079

10781080
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
10791081
defer cancel()
1080-
if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}); err != nil {
1082+
if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}, nil); err != nil {
10811083
// should not deadlock
10821084
t.Fatalf("CallTool failed: %v", err)
10831085
}
@@ -1155,11 +1157,11 @@ func TestPointerArgEquivalence(t *testing.T) {
11551157

11561158
// Then, check that we handle empty input equivalently.
11571159
for _, args := range []any{nil, struct{}{}} {
1158-
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
1160+
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name}, args)
11591161
if err != nil {
11601162
t.Fatal(err)
11611163
}
1162-
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
1164+
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name}, args)
11631165
if err != nil {
11641166
t.Fatal(err)
11651167
}
@@ -1171,11 +1173,11 @@ func TestPointerArgEquivalence(t *testing.T) {
11711173
// Then, check that we handle different types of output equivalently.
11721174
for _, in := range []string{"nil", "empty", "ok"} {
11731175
t.Run(in, func(t *testing.T) {
1174-
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
1176+
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name}, input{In: in})
11751177
if err != nil {
11761178
t.Fatal(err)
11771179
}
1178-
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
1180+
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name}, input{In: in})
11791181
if err != nil {
11801182
t.Fatal(err)
11811183
}

mcp/protocol.go

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,24 +44,8 @@ type CallToolParams struct {
4444
// This property is reserved by the protocol to allow clients and servers to
4545
// attach additional metadata to their responses.
4646
Meta `json:"_meta,omitempty"`
47-
Name string `json:"name"`
48-
Arguments any `json:"arguments,omitempty"`
49-
}
50-
51-
// When unmarshalling CallToolParams on the server side, we need to delay unmarshaling of the arguments.
52-
func (c *CallToolParams) UnmarshalJSON(data []byte) error {
53-
var raw struct {
54-
Meta `json:"_meta,omitempty"`
55-
Name string `json:"name"`
56-
RawArguments json.RawMessage `json:"arguments,omitempty"`
57-
}
58-
if err := json.Unmarshal(data, &raw); err != nil {
59-
return err
60-
}
61-
c.Meta = raw.Meta
62-
c.Name = raw.Name
63-
c.Arguments = raw.RawArguments
64-
return nil
47+
Name string `json:"name"`
48+
Arguments json.RawMessage `json:"arguments,omitempty"`
6549
}
6650

6751
// The server's response to a tool call.
@@ -74,7 +58,7 @@ type CallToolResult struct {
7458
Content []Content `json:"content"`
7559
// An optional JSON object that represents the structured result of the tool
7660
// call.
77-
StructuredContent any `json:"structuredContent,omitempty"`
61+
StructuredContent json.RawMessage `json:"structuredContent,omitempty"`
7862
// Whether the tool call ended in an error.
7963
//
8064
// If not set, this is assumed to be false (the call was successful).

mcp/protocol_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ func TestContentUnmarshal(t *testing.T) {
510510
Meta: Meta{"m": true},
511511
Content: content,
512512
IsError: true,
513-
StructuredContent: map[string]any{"s": "x"},
513+
StructuredContent: mustMarshal(map[string]any{"s": "x"}),
514514
}
515515
var got CallToolResult
516516
roundtrip(ctr, &got)
@@ -519,7 +519,7 @@ func TestContentUnmarshal(t *testing.T) {
519519
Meta: Meta{"m": true},
520520
Content: content,
521521
IsError: true,
522-
StructuredContent: 3.0,
522+
StructuredContent: mustMarshal(3.0),
523523
}
524524
var gotf CallToolResult
525525
roundtrip(ctrf, &gotf)

mcp/server.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,6 @@ func (s *Server) RemovePrompts(names ...string) {
149149
// or one where any input is valid, set [Tool.InputSchema] to the empty schema,
150150
// &jsonschema.Schema{}.
151151
//
152-
// When the handler is invoked as part of a CallTool request, req.Params.Arguments
153-
// will be a json.RawMessage. Unmarshaling the arguments and validating them against the
154-
// input schema are the handler author's responsibility.
155-
//
156152
// Most users will prefer the top-level function [AddTool].
157153
func (s *Server) AddTool(t *Tool, h ToolHandler) {
158154
if t.InputSchema == nil {
@@ -214,7 +210,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
214210

215211
th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
216212
// Unmarshal and validate args.
217-
rawArgs := req.Params.Arguments.(json.RawMessage)
213+
rawArgs := req.Params.Arguments
218214
var in In
219215
if rawArgs != nil {
220216
if err := unmarshalSchema(rawArgs, inputResolved, &in); err != nil {
@@ -249,14 +245,23 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
249245
if res == nil {
250246
res = &CallToolResult{}
251247
}
252-
res.StructuredContent = out
248+
var toMarshal any = out
253249
if elemZero != nil {
254250
// Avoid typed nil, which will serialize as JSON null.
255251
// Instead, use the zero value of the non-zero
256252
var z Out
257253
if any(out) == any(z) { // zero is only non-nil if Out is a pointer type
258-
res.StructuredContent = elemZero
254+
toMarshal = elemZero
255+
}
256+
}
257+
if reflect.ValueOf(toMarshal).IsValid() {
258+
// TODO: we should probably also check that toMarshal is a valid JSON
259+
// object type--a (pointer to) map or struct.
260+
structuredOut, err := json.Marshal(toMarshal)
261+
if err != nil {
262+
return nil, fmt.Errorf("marshalling result: %v", err)
259263
}
264+
res.StructuredContent = structuredOut
260265
}
261266
return res, nil
262267
}

0 commit comments

Comments
 (0)