Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,16 @@ func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]

// clientMethodInfos maps from the RPC method name to serverMethodInfos.
var clientMethodInfos = map[string]methodInfo{
methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete), true),
methodPing: newMethodInfo(sessionMethod((*ClientSession).ping), true),
methodListRoots: newMethodInfo(clientMethod((*Client).listRoots), true),
methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage), true),
notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler), false),
notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler), false),
notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler), false),
notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), false),
notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler), false),
notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler), false),
methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete), true, nil),
methodPing: newMethodInfo(sessionMethod((*ClientSession).ping), true, nil),
methodListRoots: newMethodInfo(clientMethod((*Client).listRoots), true, nil),
methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage), true, nil),
notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler), false, nil),
notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler), false, nil),
notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler), false, nil),
notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), false, nil),
notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler), false, nil),
notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler), false, nil),
}

func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo {
Expand Down
49 changes: 32 additions & 17 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/internal/util"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"github.com/modelcontextprotocol/go-sdk/jsonschema"
)

const DefaultPageSize = 1000
Expand Down Expand Up @@ -682,24 +683,38 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession]
addMiddleware(&s.receivingMethodHandler_, middleware)
}

func must[T any](t T, err error) T {
if err != nil {
panic(err)
}
return t
}

// serverMethodInfos maps from the RPC method name to serverMethodInfos.
var serverMethodInfos = map[string]methodInfo{
methodComplete: newMethodInfo(serverMethod((*Server).complete), true),
methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize), true),
methodPing: newMethodInfo(sessionMethod((*ServerSession).ping), true),
methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts), true),
methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt), true),
methodListTools: newMethodInfo(serverMethod((*Server).listTools), true),
methodCallTool: newMethodInfo(serverMethod((*Server).callTool), true),
methodListResources: newMethodInfo(serverMethod((*Server).listResources), true),
methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates), true),
methodReadResource: newMethodInfo(serverMethod((*Server).readResource), true),
methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), true),
methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), true),
methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), true),
notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), false),
notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), false),
notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), false),
var serverMethodInfos map[string]methodInfo

func init() {
initializeSchema := must(jsonschema.For[*InitializeParams]())
initializeSchema.Required = []string{"capabilities", "clientInfo", "protocolVersion"}

serverMethodInfos = map[string]methodInfo{
methodComplete: newMethodInfo(serverMethod((*Server).complete), true, nil),
methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize), true, initializeSchema),
methodPing: newMethodInfo(sessionMethod((*ServerSession).ping), true, nil),
methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts), true, nil),
methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt), true, nil),
methodListTools: newMethodInfo(serverMethod((*Server).listTools), true, nil),
methodCallTool: newMethodInfo(serverMethod((*Server).callTool), true, nil),
methodListResources: newMethodInfo(serverMethod((*Server).listResources), true, nil),
methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates), true, nil),
methodReadResource: newMethodInfo(serverMethod((*Server).readResource), true, nil),
methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), true, nil),
methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), true, nil),
methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), true, nil),
notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), false, nil),
notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), false, nil),
notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), false, nil),
}
}

func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos }
Expand Down
20 changes: 18 additions & 2 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"github.com/modelcontextprotocol/go-sdk/jsonschema"
)

// latestProtocolVersion is the latest protocol version that this version of the SDK supports.
Expand Down Expand Up @@ -197,14 +198,29 @@ type paramsPtr[T any] interface {
//
// If isRequest is set, the method is treated as a request rather than a
// notification.
func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R], isRequest bool) methodInfo {
func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R], isRequest bool, paramSchema *jsonschema.Schema) methodInfo {
var resolved *jsonschema.Resolved
if paramSchema != nil {
var err error
resolved, err = paramSchema.Resolve(nil)
if err != nil {
panic(err)
}
}
return methodInfo{
isRequest: isRequest,
unmarshalParams: func(m json.RawMessage) (Params, error) {
var p P
if m != nil {
if err := json.Unmarshal(m, &p); err != nil {
return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err)
return nil, fmt.Errorf("unmarshaling params into a %T: %w", p, err)
}
} else if resolved != nil {
return nil, fmt.Errorf(`missing required "params"`)
}
if resolved != nil {
if err := resolved.Validate(p); err != nil {
return nil, fmt.Errorf("invalid params: %v", err)
}
}
return orZero[Params](p), nil
Expand Down
15 changes: 14 additions & 1 deletion mcp/testdata/conformance/server/bad_requests.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ bad requests.
Fixed bugs:
- No id in 'initialize' should not panic (#197).
- No id in 'ping' should not panic (#194).
- No params in 'initialize' should not panic (#195).
- Notifications with IDs should not be treated like requests.

TODO:
- No params in 'initialize' should not panic (#195).

-- prompts --
code_review
Expand All @@ -22,6 +22,11 @@ code_review
"clientInfo": { "name": "ExampleClient", "version": "1.0.0" }
}
}
{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
}
{
"jsonrpc": "2.0",
"id": 2,
Expand All @@ -36,6 +41,14 @@ code_review
{"jsonrpc":"2.0", "method":"ping"}

-- server --
{
"jsonrpc": "2.0",
"id": 1,
"error": {
"code": 0,
"message": "handleRequest \"initialize\": missing required \"params\""
}
}
{
"jsonrpc": "2.0",
"id": 2,
Expand Down
Loading