Skip to content

Commit 43ad1eb

Browse files
authored
mcp: treat pointers equivalently to non-pointers when deriving schema
As reported in #199 and #200, the fact that we return a possibly "null" schema for pointer types breaks various clients, which expect schemas to be of type "object". This is an unfortunate footgun. For now, assume that the user wants us to treat pointers equivalently to non-pointers. If we want to change this behavior in the future, we can do so behind an option. + a test Also fix the handling of nil results in the case where the output schema is non-nil: we must provide structured content in this case. (This was causing the test to fail). Fixes #199 Fixes #200
1 parent 5b1f328 commit 43ad1eb

File tree

2 files changed

+161
-23
lines changed

2 files changed

+161
-23
lines changed

mcp/mcp_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,3 +1084,104 @@ func TestNoDistributedDeadlock(t *testing.T) {
10841084
}
10851085

10861086
var testImpl = &Implementation{Name: "test", Version: "v1.0.0"}
1087+
1088+
// This test checks that when we use pointer types for tools, we get the same
1089+
// schema as when using the non-pointer types. It is too much of a footgun for
1090+
// there to be a difference (see #199 and #200).
1091+
//
1092+
// If anyone asks, we can add an option that controls how pointers are treated.
1093+
func TestPointerArgEquivalence(t *testing.T) {
1094+
type input struct {
1095+
In string
1096+
}
1097+
type output struct {
1098+
Out string
1099+
}
1100+
cs, _ := basicConnection(t, func(s *Server) {
1101+
// Add two equivalent tools, one of which operates in the 'pointer' realm,
1102+
// the other of which does not.
1103+
//
1104+
// We handle a few different types of results, to assert they behave the
1105+
// same in all cases.
1106+
AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in *input) (*CallToolResult, *output, error) {
1107+
switch in.In {
1108+
case "":
1109+
return nil, nil, fmt.Errorf("must provide input")
1110+
case "nil":
1111+
return nil, nil, nil
1112+
case "empty":
1113+
return &CallToolResult{}, nil, nil
1114+
case "ok":
1115+
return &CallToolResult{}, &output{Out: "foo"}, nil
1116+
default:
1117+
panic("unreachable")
1118+
}
1119+
})
1120+
AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in input) (*CallToolResult, output, error) {
1121+
switch in.In {
1122+
case "":
1123+
return nil, output{}, fmt.Errorf("must provide input")
1124+
case "nil":
1125+
return nil, output{}, nil
1126+
case "empty":
1127+
return &CallToolResult{}, output{}, nil
1128+
case "ok":
1129+
return &CallToolResult{}, output{Out: "foo"}, nil
1130+
default:
1131+
panic("unreachable")
1132+
}
1133+
})
1134+
})
1135+
defer cs.Close()
1136+
1137+
ctx := context.Background()
1138+
tools, err := cs.ListTools(ctx, nil)
1139+
if err != nil {
1140+
t.Fatal(err)
1141+
}
1142+
if got, want := len(tools.Tools), 2; got != want {
1143+
t.Fatalf("got %d tools, want %d", got, want)
1144+
}
1145+
t0 := tools.Tools[0]
1146+
t1 := tools.Tools[1]
1147+
1148+
// First, check that the tool schemas don't differ.
1149+
if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" {
1150+
t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
1151+
}
1152+
if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" {
1153+
t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
1154+
}
1155+
1156+
// Then, check that we handle empty input equivalently.
1157+
for _, args := range []any{nil, struct{}{}} {
1158+
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
1159+
if err != nil {
1160+
t.Fatal(err)
1161+
}
1162+
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
1163+
if err != nil {
1164+
t.Fatal(err)
1165+
}
1166+
if diff := cmp.Diff(r0, r1); diff != "" {
1167+
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
1168+
}
1169+
}
1170+
1171+
// Then, check that we handle different types of output equivalently.
1172+
for _, in := range []string{"nil", "empty", "ok"} {
1173+
t.Run(in, func(t *testing.T) {
1174+
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
1175+
if err != nil {
1176+
t.Fatal(err)
1177+
}
1178+
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
1179+
if err != nil {
1180+
t.Fatal(err)
1181+
}
1182+
if diff := cmp.Diff(r0, r1); diff != "" {
1183+
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
1184+
}
1185+
})
1186+
}
1187+
}

mcp/server.go

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -189,31 +189,26 @@ func toolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandle
189189

190190
// TODO(v0.3.0): test
191191
func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) {
192-
var err error
193192
tt := *t
194-
tt.InputSchema = t.InputSchema
195-
if tt.InputSchema == nil {
196-
tt.InputSchema, err = jsonschema.For[In](nil)
193+
var inputResolved *jsonschema.Resolved
194+
if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil {
195+
return nil, nil, fmt.Errorf("input schema: %w", err)
196+
}
197+
198+
// Handling for zero values:
199+
//
200+
// If Out is a pointer type and we've derived the output schema from its
201+
// element type, use the zero value of its element type in place of a typed
202+
// nil.
203+
var (
204+
elemZero any // only non-nil if Out is a pointer type
205+
outputResolved *jsonschema.Resolved
206+
)
207+
if reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
208+
var err error
209+
elemZero, err = setSchema[Out](&t.OutputSchema, &outputResolved)
197210
if err != nil {
198-
return nil, nil, fmt.Errorf("input schema: %w", err)
199-
}
200-
}
201-
inputResolved, err := tt.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
202-
if err != nil {
203-
return nil, nil, fmt.Errorf("resolving input schema: %w", err)
204-
}
205-
206-
if tt.OutputSchema == nil && reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
207-
tt.OutputSchema, err = jsonschema.For[Out](nil)
208-
}
209-
if err != nil {
210-
return nil, nil, fmt.Errorf("output schema: %w", err)
211-
}
212-
var outputResolved *jsonschema.Resolved
213-
if tt.OutputSchema != nil {
214-
outputResolved, err = tt.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
215-
if err != nil {
216-
return nil, nil, fmt.Errorf("resolving output schema: %w", err)
211+
return nil, nil, fmt.Errorf("output schema: %v", err)
217212
}
218213
}
219214

@@ -255,12 +250,54 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
255250
res = &CallToolResult{}
256251
}
257252
res.StructuredContent = out
253+
if elemZero != nil {
254+
// Avoid typed nil, which will serialize as JSON null.
255+
// Instead, use the zero value of the non-zero
256+
var z Out
257+
if any(out) == any(z) { // zero is only non-nil if Out is a pointer type
258+
res.StructuredContent = elemZero
259+
}
260+
}
261+
if tt.OutputSchema != nil && elemZero != nil {
262+
res.StructuredContent = elemZero
263+
}
258264
return res, nil
259265
}
260266

261267
return &tt, th, nil
262268
}
263269

270+
// setSchema sets the schema and resolved schema corresponding to the type T.
271+
//
272+
// If sfield is nil, the schema is derived from T.
273+
//
274+
// Pointers are treated equivalently to non-pointers when deriving the schema.
275+
// If an indirection occurred to derive the schema, a non-nil zero value is
276+
// returned to be used in place of the typed nil zero value.
277+
//
278+
// Note that if sfield already holds a schema, zero will be nil even if T is a
279+
// pointer: if the user provided the schema, they may have intentionally
280+
// derived it from the pointer type, and handling of zero values is up to them.
281+
//
282+
// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we
283+
// should have a jsonschema.Zero(schema) helper?
284+
func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) {
285+
rt := reflect.TypeFor[T]()
286+
if *sfield == nil {
287+
if rt.Kind() == reflect.Pointer {
288+
rt = rt.Elem()
289+
zero = reflect.Zero(rt).Interface()
290+
}
291+
// TODO: we should be able to pass nil opts here.
292+
*sfield, err = jsonschema.ForType(rt, &jsonschema.ForOptions{})
293+
}
294+
if err != nil {
295+
return zero, err
296+
}
297+
*rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
298+
return zero, err
299+
}
300+
264301
// AddTool adds a tool and handler to the server.
265302
//
266303
// A shallow copy of the tool is made first.

0 commit comments

Comments
 (0)