Skip to content

Commit c631641

Browse files
authored
mcp: check output schema type (#358)
Server.AddTool now checks that an output schema, if any, has type "object". Also add doc and a test.
1 parent d16ce9c commit c631641

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

mcp/server.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,17 @@ func (s *Server) RemovePrompts(names ...string) {
145145
// AddTool adds a [Tool] to the server, or replaces one with the same name.
146146
// The Tool argument must not be modified after this call.
147147
//
148-
// The tool's input schema must be non-nil. For a tool that takes no input,
149-
// or one where any input is valid, set [Tool.InputSchema] to the empty schema,
150-
// &jsonschema.Schema{}.
148+
// The tool's input schema must be non-nil and have the type "object". For a tool
149+
// that takes no input, or one where any input is valid, set [Tool.InputSchema] to
150+
// &jsonschema.Schema{Type: "object"}.
151+
//
152+
// If present, the output schema must also have type "object".
151153
//
152154
// When the handler is invoked as part of a CallTool request, req.Params.Arguments
153155
// will be a json.RawMessage. Unmarshaling the arguments and validating them against the
154156
// input schema are the handler author's responsibility.
155157
//
156-
// Most users will prefer the top-level function [AddTool].
158+
// Most users should use the top-level function [AddTool].
157159
func (s *Server) AddTool(t *Tool, h ToolHandler) {
158160
if t.InputSchema == nil {
159161
// This prevents the tool author from forgetting to write a schema where
@@ -165,6 +167,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
165167
if t.InputSchema.Type != "object" {
166168
panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name))
167169
}
170+
if t.OutputSchema != nil && t.OutputSchema.Type != "object" {
171+
panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name))
172+
}
168173
st := &serverTool{tool: t, handler: h}
169174
// Assume there was a change, since add replaces existing tools.
170175
// (It's possible a tool was replaced with an identical one, but not worth checking.)
@@ -176,9 +181,12 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
176181

177182
// ToolFor returns a shallow copy of t and a [ToolHandler] that wraps h.
178183
// If the tool's input schema is nil, it is set to the schema inferred from the In
179-
// type parameter, using [jsonschema.For].
184+
// type parameter, using [jsonschema.For]. The In type parameter must be a map
185+
// or a struct, so that its inferred JSON Schema has type "object".
186+
//
180187
// If the tool's output schema is nil and the Out type parameter is not the empty
181-
// interface, then the output schema is set to the schema inferred from Out.
188+
// interface, then the output schema is set to the schema inferred from Out, which
189+
// must be a map or a struct.
182190
//
183191
// Most users will call [AddTool]. Use [ToolFor] if you wish to modify the tool's
184192
// schemas or wrap the ToolHandler before calling [Server.AddTool].
@@ -305,12 +313,7 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved)
305313
}
306314

307315
// AddTool adds a tool and handler to the server.
308-
//
309-
// A shallow copy of the tool is made first.
310-
// If the tool's input schema is nil, the copy's input schema is set to the schema
311-
// inferred from the In type parameter, using [jsonschema.For].
312-
// If the tool's output schema is nil and the Out type parameter is not the empty
313-
// interface, then the copy's output schema is set to the schema inferred from Out.
316+
// It is a convenience for s.AddTool(ToolFor(t, h)).
314317
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
315318
s.AddTool(ToolFor(t, h))
316319
}

mcp/server_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,36 @@ func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) {
456456
t.Fatal("expected ServerSession.keepaliveCancel to be nil after we manually niled it and re-initialized")
457457
}
458458
}
459+
460+
// panicks reports whether f() panics.
461+
func panics(f func()) (b bool) {
462+
defer func() {
463+
b = recover() != nil
464+
}()
465+
f()
466+
return false
467+
}
468+
469+
func TestAddTool(t *testing.T) {
470+
// AddTool should panic if In or Out are not JSON objects.
471+
s := NewServer(testImpl, nil)
472+
if !panics(func() {
473+
AddTool(s, &Tool{Name: "T1"}, func(context.Context, *CallToolRequest, string) (*CallToolResult, any, error) { return nil, nil, nil })
474+
}) {
475+
t.Error("bad In: expected panic")
476+
}
477+
if panics(func() {
478+
AddTool(s, &Tool{Name: "T2"}, func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) {
479+
return nil, nil, nil
480+
})
481+
}) {
482+
t.Error("good In: expected no panic")
483+
}
484+
if !panics(func() {
485+
AddTool(s, &Tool{Name: "T2"}, func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, int, error) {
486+
return nil, 0, nil
487+
})
488+
}) {
489+
t.Error("bad Out: expected panic")
490+
}
491+
}

0 commit comments

Comments
 (0)