From adc6f945b5795723411a15653b01c492b1d89a27 Mon Sep 17 00:00:00 2001 From: Kamdyn Shaeffer Date: Thu, 14 Aug 2025 17:59:15 -0400 Subject: [PATCH] mcp/server.go: implemented server-side logging throughout codebase --- mcp/logging.go | 21 +++++++++++++++++++++ mcp/server.go | 36 +++++++++++++++++++++++++++++++++--- mcp/sse.go | 41 +++++++++++++++++++++++++++++++++-------- mcp/transport.go | 5 +++-- 4 files changed, 90 insertions(+), 13 deletions(-) diff --git a/mcp/logging.go b/mcp/logging.go index 4d33097a..593c7190 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -87,6 +87,27 @@ type LoggingHandler struct { handler slog.Handler } +// discardHandler is a slog.Handler that drops all logs. +type discardHandler struct{} + +func (discardHandler) Enabled(context.Context, slog.Level) bool { return false } +func (discardHandler) Handle(context.Context, slog.Record) error { return nil } +func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} } +func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} } + +// ensureLogger returns l if non-nil, otherwise a discard logger. +func ensureLogger(l *slog.Logger) *slog.Logger { + if l != nil { + return l + } + return slog.New(discardHandler{}) +} + +// internalLogger is used for package-internal logging where we don't have a +// specific server/handler context. It defaults to a discard logger to avoid +// unsolicited output from library code. +var internalLogger = slog.New(discardHandler{}) + // NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a // [slog.JSONHandler]. func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler { diff --git a/mcp/server.go b/mcp/server.go index e39372dc..c45522e4 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "iter" + "log/slog" "maps" "net/url" "path/filepath" @@ -32,8 +33,9 @@ const DefaultPageSize = 1000 // sessions by using [Server.Run]. type Server struct { // fixed at creation - impl *Implementation - opts ServerOptions + impl *Implementation + opts ServerOptions + logger *slog.Logger mu sync.Mutex prompts *featureSet[*serverPrompt] @@ -50,6 +52,8 @@ type Server struct { type ServerOptions struct { // Optional instructions for connected clients. Instructions string + // Logger is used for server-side logging. If nil, slog.Default() is used. + Logger *slog.Logger // If non-nil, called when "notifications/initialized" is received. InitializedHandler func(context.Context, *ServerRequest[*InitializedParams]) // PageSize is the maximum number of items to return in a single page for @@ -108,9 +112,11 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { panic("UnsubscribeHandler requires SubscribeHandler") } + l := ensureLogger(opts.Logger) return &Server{ impl: impl, opts: *opts, + logger: l, prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), @@ -462,6 +468,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot sessions := slices.Collect(maps.Keys(subscribedSessions)) s.mu.Unlock() notifySessions(sessions, notificationResourceUpdated, params) + s.logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) return nil } @@ -479,6 +486,7 @@ func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribePar s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) } s.resourceSubscriptions[req.Params.URI][req.Session] = true + s.logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) return &emptyResult{}, nil } @@ -500,6 +508,7 @@ func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*Unsubscrib delete(s.resourceSubscriptions, req.Params.URI) } } + s.logger.Info("resource unsubscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) return &emptyResult{}, nil } @@ -518,8 +527,10 @@ func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*Unsubscrib // It need not be called on servers that are used for multiple concurrent connections, // as with [StreamableHTTPHandler]. func (s *Server) Run(ctx context.Context, t Transport) error { + s.logger.Info("server run start") ss, err := s.Connect(ctx, t, nil) if err != nil { + s.logger.Error("server connect failed", "error", err) return err } @@ -531,8 +542,14 @@ func (s *Server) Run(ctx context.Context, t Transport) error { select { case <-ctx.Done(): ss.Close() + s.logger.Info("server run cancelled", "error", ctx.Err()) return ctx.Err() case err := <-ssClosed: + if err != nil { + s.logger.Error("server session ended with error", "error", err) + } else { + s.logger.Info("server session ended") + } return err } } @@ -548,6 +565,7 @@ func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *Serv s.mu.Lock() s.sessions = append(s.sessions, ss) s.mu.Unlock() + s.logger.Info("server session connected", "session_id", ss.ID()) return ss } @@ -563,6 +581,7 @@ func (s *Server) disconnect(cc *ServerSession) { for _, subscribedSessions := range s.resourceSubscriptions { delete(subscribedSessions, cc) } + s.logger.Info("server session disconnected", "session_id", cc.ID()) } // ServerSessionOptions configures the server session. @@ -583,7 +602,13 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp if opts != nil { state = opts.State } - return connect(ctx, t, s, state) + s.logger.Info("server connecting") + ss, err := connect(ctx, t, s, state) + if err != nil { + s.logger.Error("server connect error", "error", err) + return nil, err + } + return ss, nil } // TODO: (nit) move all ServerSession methods below the ServerSession declaration. @@ -606,14 +631,17 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar }) if !wasInit { + ss.server.logger.Warn("initialized before initialize") return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) } if wasInitd { + ss.server.logger.Warn("duplicate initialized notification") return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } if h := ss.server.opts.InitializedHandler; h != nil { h(ctx, serverRequestFor(ss, params)) } + ss.server.logger.Info("session initialized") return nil, nil } @@ -798,6 +826,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, case methodInitialize, methodPing, notificationInitialized: default: if !initialized { + ss.server.logger.Warn("method invalid during initialization", "method", req.Method) return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } @@ -842,6 +871,7 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e ss.updateState(func(state *ServerSessionState) { state.LogLevel = params.Level }) + ss.server.logger.Info("client log level set", "level", params.Level) return &emptyResult{}, nil } diff --git a/mcp/sse.go b/mcp/sse.go index b7f0d4e2..78b0de5e 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -9,6 +9,7 @@ import ( "context" "fmt" "io" + "log/slog" "net/http" "net/url" "sync" @@ -47,6 +48,7 @@ type SSEHandler struct { mu sync.Mutex sessions map[string]*SSEServerTransport + logger *slog.Logger } // NewSSEHandler returns a new [SSEHandler] that creates and manages MCP @@ -68,9 +70,12 @@ func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { return &SSEHandler{ getServer: getServer, sessions: make(map[string]*SSEServerTransport), + logger: internalLogger, } } +func (h *SSEHandler) ensureLogger() { h.logger = ensureLogger(h.logger) } + // A SSEServerTransport is a logical SSE session created through a hanging GET // request. // @@ -100,6 +105,10 @@ type SSEServerTransport struct { // Response is the hanging response body to the incoming GET request. Response http.ResponseWriter + // logger is used for per-POST diagnostics and transport-level logs. + // If nil, logging is disabled. + logger *slog.Logger + // incoming is the queue of incoming messages. // It is never closed, and by convention, incoming is non-nil if and only if // the transport is connected. @@ -114,6 +123,8 @@ type SSEServerTransport struct { done chan struct{} // closed when the connection is closed } +func (t *SSEServerTransport) ensureLogger() { t.logger = ensureLogger(t.logger) } + // NewSSEServerTransport creates a new SSE transport for the given messages // endpoint, and hanging GET response. // @@ -124,11 +135,13 @@ func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTra return &SSEServerTransport{ Endpoint: endpoint, Response: w, + logger: ensureLogger(nil), } } // ServeHTTP handles POST requests to the transport endpoint. func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + t.ensureLogger() if t.incoming == nil { http.Error(w, "session not connected", http.StatusInternalServerError) return @@ -137,6 +150,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) // Read and parse the message. data, err := io.ReadAll(req.Body) if err != nil { + t.logger.Error("sse: failed to read body", "error", err) http.Error(w, "failed to read body", http.StatusBadRequest) return } @@ -145,11 +159,13 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) // useful msg, err := jsonrpc2.DecodeMessage(data) if err != nil { + t.logger.Error("sse: failed to parse body", "error", err) http.Error(w, "failed to parse body", http.StatusBadRequest) return } if req, ok := msg.(*jsonrpc.Request); ok { if _, err := checkRequest(req, serverMethodInfos); err != nil { + t.logger.Warn("sse: request validation failed", "error", err) http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -158,6 +174,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) case t.incoming <- msg: w.WriteHeader(http.StatusAccepted) case <-t.done: + t.logger.Info("sse: session closed while posting message") http.Error(w, "session closed", http.StatusBadRequest) } } @@ -181,6 +198,7 @@ func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { } func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + h.ensureLogger() sessionID := req.URL.Query().Get("sessionid") // TODO: consider checking Content-Type here. For now, we are lax. @@ -221,11 +239,24 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { sessionID = randText() endpoint, err := req.URL.Parse("?sessionid=" + sessionID) if err != nil { + h.logger.Error("sse: failed to create endpoint", "error", err) http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) return } - transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w} + // Determine the server instance and pick a logger for the transport. + server := h.getServer(req) + if server == nil { + // The getServer argument to NewSSEHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } + // Prefer the server's logger if available; otherwise use the handler's. + lg := server.logger + if lg == nil { + lg = h.logger + } + transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w, logger: ensureLogger(lg)} // The session is terminated when the request exits. h.mu.Lock() @@ -236,15 +267,9 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { delete(h.sessions, sessionID) h.mu.Unlock() }() - - server := h.getServer(req) - if server == nil { - // The getServer argument to NewSSEHandler returned nil. - http.Error(w, "no server available", http.StatusBadRequest) - return - } ss, err := server.Connect(req.Context(), transport, nil) if err != nil { + h.logger.Error("sse: server connect failed", "error", err) http.Error(w, "connection failed", http.StatusInternalServerError) return } diff --git a/mcp/transport.go b/mcp/transport.go index 76b79986..62a23f06 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "log" "net" "os" "sync" @@ -157,7 +156,9 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, OnDone: func() { b.disconnect(h) }, - OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, + OnInternalError: func(err error) { + internalLogger.Error("jsonrpc2 internal error", "error", err) + }, }) assert(preempter.conn != nil, "unbound preempter") return h, nil