-
Notifications
You must be signed in to change notification settings - Fork 138
mcp/server.go: implement server-side logging throughout codebase #306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ import ( | |
"encoding/json" | ||
"fmt" | ||
"iter" | ||
"log/slog" | ||
"maps" | ||
"net/url" | ||
"path/filepath" | ||
|
@@ -32,8 +33,9 @@ const DefaultPageSize = 1000 | |
// sessions by using [Server.Run]. | ||
type Server struct { | ||
// fixed at creation | ||
impl *Implementation | ||
opts ServerOptions | ||
impl *Implementation | ||
opts ServerOptions | ||
logger *slog.Logger | ||
|
||
mu sync.Mutex | ||
prompts *featureSet[*serverPrompt] | ||
|
@@ -50,6 +52,8 @@ type Server struct { | |
type ServerOptions struct { | ||
// Optional instructions for connected clients. | ||
Instructions string | ||
// Logger is used for server-side logging. If nil, slog.Default() is used. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nil should mean no logging. |
||
Logger *slog.Logger | ||
// If non-nil, called when "notifications/initialized" is received. | ||
InitializedHandler func(context.Context, *ServerRequest[*InitializedParams]) | ||
// PageSize is the maximum number of items to return in a single page for | ||
|
@@ -108,9 +112,11 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { | |
if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { | ||
panic("UnsubscribeHandler requires SubscribeHandler") | ||
} | ||
l := ensureLogger(opts.Logger) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inline this into L119 |
||
return &Server{ | ||
impl: impl, | ||
opts: *opts, | ||
logger: l, | ||
prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), | ||
tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), | ||
resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), | ||
|
@@ -462,6 +468,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot | |
sessions := slices.Collect(maps.Keys(subscribedSessions)) | ||
s.mu.Unlock() | ||
notifySessions(sessions, notificationResourceUpdated, params) | ||
s.logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) | ||
return nil | ||
} | ||
|
||
|
@@ -479,6 +486,7 @@ func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribePar | |
s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) | ||
} | ||
s.resourceSubscriptions[req.Params.URI][req.Session] = true | ||
s.logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) | ||
|
||
return &emptyResult{}, nil | ||
} | ||
|
@@ -500,6 +508,7 @@ func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*Unsubscrib | |
delete(s.resourceSubscriptions, req.Params.URI) | ||
} | ||
} | ||
s.logger.Info("resource unsubscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) | ||
|
||
return &emptyResult{}, nil | ||
} | ||
|
@@ -518,8 +527,10 @@ func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*Unsubscrib | |
// It need not be called on servers that are used for multiple concurrent connections, | ||
// as with [StreamableHTTPHandler]. | ||
func (s *Server) Run(ctx context.Context, t Transport) error { | ||
s.logger.Info("server run start") | ||
ss, err := s.Connect(ctx, t, nil) | ||
if err != nil { | ||
s.logger.Error("server connect failed", "error", err) | ||
return err | ||
} | ||
|
||
|
@@ -531,8 +542,14 @@ func (s *Server) Run(ctx context.Context, t Transport) error { | |
select { | ||
case <-ctx.Done(): | ||
ss.Close() | ||
s.logger.Info("server run cancelled", "error", ctx.Err()) | ||
return ctx.Err() | ||
case err := <-ssClosed: | ||
if err != nil { | ||
s.logger.Error("server session ended with error", "error", err) | ||
} else { | ||
s.logger.Info("server session ended") | ||
} | ||
return err | ||
} | ||
} | ||
|
@@ -548,6 +565,7 @@ func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *Serv | |
s.mu.Lock() | ||
s.sessions = append(s.sessions, ss) | ||
s.mu.Unlock() | ||
s.logger.Info("server session connected", "session_id", ss.ID()) | ||
return ss | ||
} | ||
|
||
|
@@ -563,6 +581,7 @@ func (s *Server) disconnect(cc *ServerSession) { | |
for _, subscribedSessions := range s.resourceSubscriptions { | ||
delete(subscribedSessions, cc) | ||
} | ||
s.logger.Info("server session disconnected", "session_id", cc.ID()) | ||
} | ||
|
||
// ServerSessionOptions configures the server session. | ||
|
@@ -583,7 +602,13 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp | |
if opts != nil { | ||
state = opts.State | ||
} | ||
return connect(ctx, t, s, state) | ||
s.logger.Info("server connecting") | ||
ss, err := connect(ctx, t, s, state) | ||
if err != nil { | ||
s.logger.Error("server connect error", "error", err) | ||
return nil, err | ||
} | ||
return ss, nil | ||
} | ||
|
||
// TODO: (nit) move all ServerSession methods below the ServerSession declaration. | ||
|
@@ -606,14 +631,17 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar | |
}) | ||
|
||
if !wasInit { | ||
ss.server.logger.Warn("initialized before initialize") | ||
return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) | ||
} | ||
if wasInitd { | ||
ss.server.logger.Warn("duplicate initialized notification") | ||
return nil, fmt.Errorf("duplicate %q received", notificationInitialized) | ||
} | ||
if h := ss.server.opts.InitializedHandler; h != nil { | ||
h(ctx, serverRequestFor(ss, params)) | ||
} | ||
ss.server.logger.Info("session initialized") | ||
return nil, nil | ||
} | ||
|
||
|
@@ -798,6 +826,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, | |
case methodInitialize, methodPing, notificationInitialized: | ||
default: | ||
if !initialized { | ||
ss.server.logger.Warn("method invalid during initialization", "method", req.Method) | ||
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) | ||
} | ||
} | ||
|
@@ -842,6 +871,7 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e | |
ss.updateState(func(state *ServerSessionState) { | ||
state.LogLevel = params.Level | ||
}) | ||
ss.server.logger.Info("client log level set", "level", params.Level) | ||
return &emptyResult{}, nil | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ import ( | |
"context" | ||
"fmt" | ||
"io" | ||
"log/slog" | ||
"net/http" | ||
"net/url" | ||
"sync" | ||
|
@@ -47,6 +48,7 @@ type SSEHandler struct { | |
|
||
mu sync.Mutex | ||
sessions map[string]*SSEServerTransport | ||
logger *slog.Logger | ||
} | ||
|
||
// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP | ||
|
@@ -68,9 +70,12 @@ func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { | |
return &SSEHandler{ | ||
getServer: getServer, | ||
sessions: make(map[string]*SSEServerTransport), | ||
logger: internalLogger, | ||
} | ||
} | ||
|
||
func (h *SSEHandler) ensureLogger() { h.logger = ensureLogger(h.logger) } | ||
|
||
// A SSEServerTransport is a logical SSE session created through a hanging GET | ||
// request. | ||
// | ||
|
@@ -100,6 +105,10 @@ type SSEServerTransport struct { | |
// Response is the hanging response body to the incoming GET request. | ||
Response http.ResponseWriter | ||
|
||
// logger is used for per-POST diagnostics and transport-level logs. | ||
// If nil, logging is disabled. | ||
logger *slog.Logger | ||
|
||
// incoming is the queue of incoming messages. | ||
// It is never closed, and by convention, incoming is non-nil if and only if | ||
// the transport is connected. | ||
|
@@ -114,6 +123,8 @@ type SSEServerTransport struct { | |
done chan struct{} // closed when the connection is closed | ||
} | ||
|
||
func (t *SSEServerTransport) ensureLogger() { t.logger = ensureLogger(t.logger) } | ||
|
||
// NewSSEServerTransport creates a new SSE transport for the given messages | ||
// endpoint, and hanging GET response. | ||
// | ||
|
@@ -124,11 +135,13 @@ func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTra | |
return &SSEServerTransport{ | ||
Endpoint: endpoint, | ||
Response: w, | ||
logger: ensureLogger(nil), | ||
} | ||
} | ||
|
||
// ServeHTTP handles POST requests to the transport endpoint. | ||
func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { | ||
t.ensureLogger() | ||
if t.incoming == nil { | ||
http.Error(w, "session not connected", http.StatusInternalServerError) | ||
return | ||
|
@@ -137,6 +150,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) | |
// Read and parse the message. | ||
data, err := io.ReadAll(req.Body) | ||
if err != nil { | ||
t.logger.Error("sse: failed to read body", "error", err) | ||
http.Error(w, "failed to read body", http.StatusBadRequest) | ||
return | ||
} | ||
|
@@ -145,11 +159,13 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) | |
// useful | ||
msg, err := jsonrpc2.DecodeMessage(data) | ||
if err != nil { | ||
t.logger.Error("sse: failed to parse body", "error", err) | ||
http.Error(w, "failed to parse body", http.StatusBadRequest) | ||
return | ||
} | ||
if req, ok := msg.(*jsonrpc.Request); ok { | ||
if _, err := checkRequest(req, serverMethodInfos); err != nil { | ||
t.logger.Warn("sse: request validation failed", "error", err) | ||
http.Error(w, err.Error(), http.StatusBadRequest) | ||
return | ||
} | ||
|
@@ -158,6 +174,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) | |
case t.incoming <- msg: | ||
w.WriteHeader(http.StatusAccepted) | ||
case <-t.done: | ||
t.logger.Info("sse: session closed while posting message") | ||
http.Error(w, "session closed", http.StatusBadRequest) | ||
} | ||
} | ||
|
@@ -181,6 +198,7 @@ func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { | |
} | ||
|
||
func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { | ||
h.ensureLogger() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this only occurs in one place, inline it. |
||
sessionID := req.URL.Query().Get("sessionid") | ||
|
||
// TODO: consider checking Content-Type here. For now, we are lax. | ||
|
@@ -221,11 +239,24 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { | |
sessionID = randText() | ||
endpoint, err := req.URL.Parse("?sessionid=" + sessionID) | ||
if err != nil { | ||
h.logger.Error("sse: failed to create endpoint", "error", err) | ||
http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w} | ||
// Determine the server instance and pick a logger for the transport. | ||
server := h.getServer(req) | ||
if server == nil { | ||
// The getServer argument to NewSSEHandler returned nil. | ||
http.Error(w, "no server available", http.StatusBadRequest) | ||
return | ||
} | ||
// Prefer the server's logger if available; otherwise use the handler's. | ||
lg := server.logger | ||
if lg == nil { | ||
lg = h.logger | ||
} | ||
transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w, logger: ensureLogger(lg)} | ||
|
||
// The session is terminated when the request exits. | ||
h.mu.Lock() | ||
|
@@ -236,15 +267,9 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { | |
delete(h.sessions, sessionID) | ||
h.mu.Unlock() | ||
}() | ||
|
||
server := h.getServer(req) | ||
if server == nil { | ||
// The getServer argument to NewSSEHandler returned nil. | ||
http.Error(w, "no server available", http.StatusBadRequest) | ||
return | ||
} | ||
ss, err := server.Connect(req.Context(), transport, nil) | ||
if err != nil { | ||
h.logger.Error("sse: server connect failed", "error", err) | ||
http.Error(w, "connection failed", http.StatusInternalServerError) | ||
return | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.