Skip to content

mcp: use RawMessage for CallToolParams.Arguments #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ func main() {
defer session.Close()

// Call a tool on the server.
params := &mcp.CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "you"},
}
res, err := session.CallTool(ctx, params)
res, err := session.CallTool(ctx, &mcp.CallToolParams{
Name: "greet",
}, map[string]any{
"name": "you",
})
if err != nil {
log.Fatalf("CallTool failed: %v", err)
}
Expand Down
4 changes: 0 additions & 4 deletions examples/server/memory/kb.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,6 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.CallToolRequ
&mcp.TextContent{Text: "Entities created successfully"},
}

res.StructuredContent = CreateEntitiesResult{
Entities: entities,
}

return &res, CreateEntitiesResult{Entities: entities}, nil
}

Expand Down
10 changes: 5 additions & 5 deletions internal/readme/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func main() {
defer session.Close()

// Call a tool on the server.
params := &mcp.CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "you"},
}
res, err := session.CallTool(ctx, params)
res, err := session.CallTool(ctx, &mcp.CallToolParams{
Name: "greet",
}, map[string]any{
"name": "you",
})
if err != nil {
log.Fatalf("CallTool failed: %v", err)
}
Expand Down
18 changes: 15 additions & 3 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package mcp

import (
"context"
"encoding/json"
"fmt"
"iter"
"slices"
Expand Down Expand Up @@ -386,14 +387,25 @@ func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams)
}

// CallTool calls the tool with the given name and arguments.
// The arguments can be any value that marshals into a JSON object.
func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) {
//
// If args is non-nil, it is marshalled into params.Arguments.
// CallToolPanics if args is non-nil and params.Arguments is non-empty.
func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams, args any) (*CallToolResult, error) {
if params == nil {
params = new(CallToolParams)
}
if args != nil {
assert(len(params.Arguments) == 0, "non-nil args with non-empty params.Arguments")

data, err := json.Marshal(args)
if err != nil {
return nil, fmt.Errorf("marshalling args: %v", err)
}
params.Arguments = json.RawMessage(data)
}
if params.Arguments == nil {
// Avoid sending nil over the wire.
params.Arguments = map[string]any{}
params.Arguments = json.RawMessage(`{}`)
}
return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))
}
Expand Down
5 changes: 3 additions & 2 deletions mcp/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,9 @@ func TestCmdTransport(t *testing.T) {
t.Fatal(err)
}
got, err := session.CallTool(ctx, &mcp.CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "user"},
Name: "greet",
}, map[string]any{
"name": "user",
})
if err != nil {
t.Fatal(err)
Expand Down
5 changes: 2 additions & 3 deletions mcp/example_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ func Example_loggingMiddleware() {
// Call the tool to demonstrate logging
result, _ := clientSession.CallTool(ctx, &mcp.CallToolParams{
Name: "greet",
Arguments: map[string]any{
"name": "World",
},
}, map[string]any{
"name": "World",
})

fmt.Printf("Tool result: %s\n", result.Content[0].(*mcp.TextContent).Text)
Expand Down
46 changes: 24 additions & 22 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,9 @@ func TestEndToEnd(t *testing.T) {
t.Run("tools", func(t *testing.T) {
// ListTools is tested in client_list_test.go.
gotHi, err := cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "user"},
Name: "greet",
}, map[string]any{
"name": "user",
})
if err != nil {
t.Fatal(err)
Expand All @@ -228,10 +229,7 @@ func TestEndToEnd(t *testing.T) {
t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
}

gotFail, err := cs.CallTool(ctx, &CallToolParams{
Name: "fail",
Arguments: map[string]any{},
})
gotFail, err := cs.CallTool(ctx, &CallToolParams{Name: "fail"}, nil)
// Counter-intuitively, when a tool fails, we don't expect an RPC error for
// call tool: instead, the failure is embedded in the result.
if err != nil {
Expand Down Expand Up @@ -605,16 +603,18 @@ func TestServerClosing(t *testing.T) {
wg.Done()
}()
if _, err := cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "user"},
Name: "greet",
}, map[string]any{
"name": "user",
}); err != nil {
t.Fatalf("after connecting: %v", err)
}
ss.Close()
wg.Wait()
if _, err := cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "user"},
Name: "greet",
}, map[string]any{
"name": "user",
}); !errors.Is(err, ErrConnectionClosed) {
t.Errorf("after disconnection, got error %v, want EOF", err)
}
Expand Down Expand Up @@ -679,7 +679,7 @@ func TestCancellation(t *testing.T) {
defer cs.Close()

ctx, cancel := context.WithCancel(context.Background())
go cs.CallTool(ctx, &CallToolParams{Name: "slow"})
go cs.CallTool(ctx, &CallToolParams{Name: "slow"}, nil)
<-start
cancel()
select {
Expand Down Expand Up @@ -892,8 +892,9 @@ func TestKeepAlive(t *testing.T) {

// Test that the connection is still alive by making a call
result, err := cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"Name": "user"},
Name: "greet",
}, map[string]any{
"Name": "user",
})
if err != nil {
t.Fatalf("call failed after keepalive: %v", err)
Expand Down Expand Up @@ -942,8 +943,9 @@ func TestKeepAliveFailure(t *testing.T) {
deadline := time.Now().Add(1 * time.Second)
for time.Now().Before(deadline) {
_, err = cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"Name": "user"},
Name: "greet",
}, map[string]any{
"Name": "user",
})
if errors.Is(err, ErrConnectionClosed) {
return // Test passed
Expand Down Expand Up @@ -1025,7 +1027,7 @@ func TestSynchronousNotifications(t *testing.T) {

t.Run("from client", func(t *testing.T) {
client.AddRoots(&Root{Name: "myroot", URI: "file://foo"})
res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"})
res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"}, nil)
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}
Expand Down Expand Up @@ -1058,7 +1060,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
// delegates synchronization to the user.
clientOpts := &ClientOptions{
CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"})
req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"}, nil)
return &CreateMessageResult{Content: &TextContent{}}, nil
},
}
Expand All @@ -1077,7 +1079,7 @@ func TestNoDistributedDeadlock(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}); err != nil {
if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}, nil); err != nil {
// should not deadlock
t.Fatalf("CallTool failed: %v", err)
}
Expand Down Expand Up @@ -1155,11 +1157,11 @@ func TestPointerArgEquivalence(t *testing.T) {

// Then, check that we handle empty input equivalently.
for _, args := range []any{nil, struct{}{}} {
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name}, args)
if err != nil {
t.Fatal(err)
}
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name}, args)
if err != nil {
t.Fatal(err)
}
Expand All @@ -1171,11 +1173,11 @@ func TestPointerArgEquivalence(t *testing.T) {
// Then, check that we handle different types of output equivalently.
for _, in := range []string{"nil", "empty", "ok"} {
t.Run(in, func(t *testing.T) {
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name}, input{In: in})
if err != nil {
t.Fatal(err)
}
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name}, input{In: in})
if err != nil {
t.Fatal(err)
}
Expand Down
22 changes: 3 additions & 19 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,8 @@ type CallToolParams struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Meta `json:"_meta,omitempty"`
Name string `json:"name"`
Arguments any `json:"arguments,omitempty"`
}

// When unmarshalling CallToolParams on the server side, we need to delay unmarshaling of the arguments.
func (c *CallToolParams) UnmarshalJSON(data []byte) error {
var raw struct {
Meta `json:"_meta,omitempty"`
Name string `json:"name"`
RawArguments json.RawMessage `json:"arguments,omitempty"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
c.Meta = raw.Meta
c.Name = raw.Name
c.Arguments = raw.RawArguments
return nil
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments,omitempty"`
}

// The server's response to a tool call.
Expand All @@ -74,7 +58,7 @@ type CallToolResult struct {
Content []Content `json:"content"`
// An optional JSON object that represents the structured result of the tool
// call.
StructuredContent any `json:"structuredContent,omitempty"`
StructuredContent json.RawMessage `json:"structuredContent,omitempty"`
// Whether the tool call ended in an error.
//
// If not set, this is assumed to be false (the call was successful).
Expand Down
4 changes: 2 additions & 2 deletions mcp/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ func TestContentUnmarshal(t *testing.T) {
Meta: Meta{"m": true},
Content: content,
IsError: true,
StructuredContent: map[string]any{"s": "x"},
StructuredContent: mustMarshal(map[string]any{"s": "x"}),
}
var got CallToolResult
roundtrip(ctr, &got)
Expand All @@ -519,7 +519,7 @@ func TestContentUnmarshal(t *testing.T) {
Meta: Meta{"m": true},
Content: content,
IsError: true,
StructuredContent: 3.0,
StructuredContent: mustMarshal(3.0),
}
var gotf CallToolResult
roundtrip(ctrf, &gotf)
Expand Down
19 changes: 12 additions & 7 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ func (s *Server) RemovePrompts(names ...string) {
// or one where any input is valid, set [Tool.InputSchema] to the empty schema,
// &jsonschema.Schema{}.
//
// When the handler is invoked as part of a CallTool request, req.Params.Arguments
// will be a json.RawMessage. Unmarshaling the arguments and validating them against the
// input schema are the handler author's responsibility.
//
// Most users will prefer the top-level function [AddTool].
func (s *Server) AddTool(t *Tool, h ToolHandler) {
if t.InputSchema == nil {
Expand Down Expand Up @@ -214,7 +210,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan

th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
// Unmarshal and validate args.
rawArgs := req.Params.Arguments.(json.RawMessage)
rawArgs := req.Params.Arguments
var in In
if rawArgs != nil {
if err := unmarshalSchema(rawArgs, inputResolved, &in); err != nil {
Expand Down Expand Up @@ -249,14 +245,23 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
if res == nil {
res = &CallToolResult{}
}
res.StructuredContent = out
var toMarshal any = out
if elemZero != nil {
// Avoid typed nil, which will serialize as JSON null.
// Instead, use the zero value of the non-zero
var z Out
if any(out) == any(z) { // zero is only non-nil if Out is a pointer type
res.StructuredContent = elemZero
toMarshal = elemZero
}
}
if reflect.ValueOf(toMarshal).IsValid() {
// TODO: we should probably also check that toMarshal is a valid JSON
// object type--a (pointer to) map or struct.
structuredOut, err := json.Marshal(toMarshal)
if err != nil {
return nil, fmt.Errorf("marshalling result: %v", err)
}
res.StructuredContent = structuredOut
}
return res, nil
}
Expand Down
5 changes: 3 additions & 2 deletions mcp/server_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ func ExampleServer() {
}

res, err := clientSession.CallTool(ctx, &mcp.CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "user"},
Name: "greet",
}, map[string]any{
"name": "user",
})
if err != nil {
log.Fatal(err)
Expand Down
5 changes: 3 additions & 2 deletions mcp/sse_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ func ExampleSSEHandler() {
defer cs.Close()

res, err := cs.CallTool(ctx, &mcp.CallToolParams{
Name: "add",
Arguments: map[string]any{"x": 1, "y": 2},
Name: "add",
}, map[string]any{
"x": 1, "y": 2,
})
if err != nil {
log.Fatal(err)
Expand Down
5 changes: 3 additions & 2 deletions mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ func TestSSEServer(t *testing.T) {
}
ss := <-serverSessions
gotHi, err := cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"Name": "user"},
Name: "greet",
}, map[string]any{
"Name": "user",
})
if err != nil {
t.Fatal(err)
Expand Down
Loading
Loading