Skip to content

mcp/server.go: implement server-side logging throughout codebase #306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions mcp/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,27 @@ type LoggingHandler struct {
handler slog.Handler
}

// discardHandler is a slog.Handler that drops all logs.
type discardHandler struct{}

func (discardHandler) Enabled(context.Context, slog.Level) bool { return false }
func (discardHandler) Handle(context.Context, slog.Record) error { return nil }
func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} }
func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} }

// ensureLogger returns l if non-nil, otherwise a discard logger.
func ensureLogger(l *slog.Logger) *slog.Logger {
if l != nil {
return l
}
return slog.New(discardHandler{})
}

// internalLogger is used for package-internal logging where we don't have a
// specific server/handler context. It defaults to a discard logger to avoid
// unsolicited output from library code.
var internalLogger = slog.New(discardHandler{})

// NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a
// [slog.JSONHandler].
func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler {
Expand Down
36 changes: 33 additions & 3 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"iter"
"log/slog"
"maps"
"net/url"
"path/filepath"
Expand All @@ -32,8 +33,9 @@ const DefaultPageSize = 1000
// sessions by using [Server.Run].
type Server struct {
// fixed at creation
impl *Implementation
opts ServerOptions
impl *Implementation
opts ServerOptions
logger *slog.Logger

mu sync.Mutex
prompts *featureSet[*serverPrompt]
Expand All @@ -50,6 +52,8 @@ type Server struct {
type ServerOptions struct {
// Optional instructions for connected clients.
Instructions string
// Logger is used for server-side logging. If nil, slog.Default() is used.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nil should mean no logging.

Logger *slog.Logger
// If non-nil, called when "notifications/initialized" is received.
InitializedHandler func(context.Context, *ServerRequest[*InitializedParams])
// PageSize is the maximum number of items to return in a single page for
Expand Down Expand Up @@ -108,9 +112,11 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server {
if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil {
panic("UnsubscribeHandler requires SubscribeHandler")
}
l := ensureLogger(opts.Logger)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inline this into L119

return &Server{
impl: impl,
opts: *opts,
logger: l,
prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }),
tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }),
resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }),
Expand Down Expand Up @@ -462,6 +468,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot
sessions := slices.Collect(maps.Keys(subscribedSessions))
s.mu.Unlock()
notifySessions(sessions, notificationResourceUpdated, params)
s.logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions))
return nil
}

Expand All @@ -479,6 +486,7 @@ func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribePar
s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool)
}
s.resourceSubscriptions[req.Params.URI][req.Session] = true
s.logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID())

return &emptyResult{}, nil
}
Expand All @@ -500,6 +508,7 @@ func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*Unsubscrib
delete(s.resourceSubscriptions, req.Params.URI)
}
}
s.logger.Info("resource unsubscribed", "uri", req.Params.URI, "session_id", req.Session.ID())

return &emptyResult{}, nil
}
Expand All @@ -518,8 +527,10 @@ func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*Unsubscrib
// It need not be called on servers that are used for multiple concurrent connections,
// as with [StreamableHTTPHandler].
func (s *Server) Run(ctx context.Context, t Transport) error {
s.logger.Info("server run start")
ss, err := s.Connect(ctx, t, nil)
if err != nil {
s.logger.Error("server connect failed", "error", err)
return err
}

Expand All @@ -531,8 +542,14 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
select {
case <-ctx.Done():
ss.Close()
s.logger.Info("server run cancelled", "error", ctx.Err())
return ctx.Err()
case err := <-ssClosed:
if err != nil {
s.logger.Error("server session ended with error", "error", err)
} else {
s.logger.Info("server session ended")
}
return err
}
}
Expand All @@ -548,6 +565,7 @@ func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *Serv
s.mu.Lock()
s.sessions = append(s.sessions, ss)
s.mu.Unlock()
s.logger.Info("server session connected", "session_id", ss.ID())
return ss
}

Expand All @@ -563,6 +581,7 @@ func (s *Server) disconnect(cc *ServerSession) {
for _, subscribedSessions := range s.resourceSubscriptions {
delete(subscribedSessions, cc)
}
s.logger.Info("server session disconnected", "session_id", cc.ID())
}

// ServerSessionOptions configures the server session.
Expand All @@ -583,7 +602,13 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp
if opts != nil {
state = opts.State
}
return connect(ctx, t, s, state)
s.logger.Info("server connecting")
ss, err := connect(ctx, t, s, state)
if err != nil {
s.logger.Error("server connect error", "error", err)
return nil, err
}
return ss, nil
}

// TODO: (nit) move all ServerSession methods below the ServerSession declaration.
Expand All @@ -606,14 +631,17 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
})

if !wasInit {
ss.server.logger.Warn("initialized before initialize")
return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize)
}
if wasInitd {
ss.server.logger.Warn("duplicate initialized notification")
return nil, fmt.Errorf("duplicate %q received", notificationInitialized)
}
if h := ss.server.opts.InitializedHandler; h != nil {
h(ctx, serverRequestFor(ss, params))
}
ss.server.logger.Info("session initialized")
return nil, nil
}

Expand Down Expand Up @@ -798,6 +826,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any,
case methodInitialize, methodPing, notificationInitialized:
default:
if !initialized {
ss.server.logger.Warn("method invalid during initialization", "method", req.Method)
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
}
}
Expand Down Expand Up @@ -842,6 +871,7 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e
ss.updateState(func(state *ServerSessionState) {
state.LogLevel = params.Level
})
ss.server.logger.Info("client log level set", "level", params.Level)
return &emptyResult{}, nil
}

Expand Down
41 changes: 33 additions & 8 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"sync"
Expand Down Expand Up @@ -47,6 +48,7 @@ type SSEHandler struct {

mu sync.Mutex
sessions map[string]*SSEServerTransport
logger *slog.Logger
}

// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP
Expand All @@ -68,9 +70,12 @@ func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler {
return &SSEHandler{
getServer: getServer,
sessions: make(map[string]*SSEServerTransport),
logger: internalLogger,
}
}

func (h *SSEHandler) ensureLogger() { h.logger = ensureLogger(h.logger) }

// A SSEServerTransport is a logical SSE session created through a hanging GET
// request.
//
Expand Down Expand Up @@ -100,6 +105,10 @@ type SSEServerTransport struct {
// Response is the hanging response body to the incoming GET request.
Response http.ResponseWriter

// logger is used for per-POST diagnostics and transport-level logs.
// If nil, logging is disabled.
logger *slog.Logger

// incoming is the queue of incoming messages.
// It is never closed, and by convention, incoming is non-nil if and only if
// the transport is connected.
Expand All @@ -114,6 +123,8 @@ type SSEServerTransport struct {
done chan struct{} // closed when the connection is closed
}

func (t *SSEServerTransport) ensureLogger() { t.logger = ensureLogger(t.logger) }

// NewSSEServerTransport creates a new SSE transport for the given messages
// endpoint, and hanging GET response.
//
Expand All @@ -124,11 +135,13 @@ func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTra
return &SSEServerTransport{
Endpoint: endpoint,
Response: w,
logger: ensureLogger(nil),
}
}

// ServeHTTP handles POST requests to the transport endpoint.
func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) {
t.ensureLogger()
if t.incoming == nil {
http.Error(w, "session not connected", http.StatusInternalServerError)
return
Expand All @@ -137,6 +150,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request)
// Read and parse the message.
data, err := io.ReadAll(req.Body)
if err != nil {
t.logger.Error("sse: failed to read body", "error", err)
http.Error(w, "failed to read body", http.StatusBadRequest)
return
}
Expand All @@ -145,11 +159,13 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request)
// useful
msg, err := jsonrpc2.DecodeMessage(data)
if err != nil {
t.logger.Error("sse: failed to parse body", "error", err)
http.Error(w, "failed to parse body", http.StatusBadRequest)
return
}
if req, ok := msg.(*jsonrpc.Request); ok {
if _, err := checkRequest(req, serverMethodInfos); err != nil {
t.logger.Warn("sse: request validation failed", "error", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
Expand All @@ -158,6 +174,7 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request)
case t.incoming <- msg:
w.WriteHeader(http.StatusAccepted)
case <-t.done:
t.logger.Info("sse: session closed while posting message")
http.Error(w, "session closed", http.StatusBadRequest)
}
}
Expand All @@ -181,6 +198,7 @@ func (t *SSEServerTransport) Connect(context.Context) (Connection, error) {
}

func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
h.ensureLogger()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this only occurs in one place, inline it.

sessionID := req.URL.Query().Get("sessionid")

// TODO: consider checking Content-Type here. For now, we are lax.
Expand Down Expand Up @@ -221,11 +239,24 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
sessionID = randText()
endpoint, err := req.URL.Parse("?sessionid=" + sessionID)
if err != nil {
h.logger.Error("sse: failed to create endpoint", "error", err)
http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError)
return
}

transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w}
// Determine the server instance and pick a logger for the transport.
server := h.getServer(req)
if server == nil {
// The getServer argument to NewSSEHandler returned nil.
http.Error(w, "no server available", http.StatusBadRequest)
return
}
// Prefer the server's logger if available; otherwise use the handler's.
lg := server.logger
if lg == nil {
lg = h.logger
}
transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w, logger: ensureLogger(lg)}

// The session is terminated when the request exits.
h.mu.Lock()
Expand All @@ -236,15 +267,9 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
delete(h.sessions, sessionID)
h.mu.Unlock()
}()

server := h.getServer(req)
if server == nil {
// The getServer argument to NewSSEHandler returned nil.
http.Error(w, "no server available", http.StatusBadRequest)
return
}
ss, err := server.Connect(req.Context(), transport, nil)
if err != nil {
h.logger.Error("sse: server connect failed", "error", err)
http.Error(w, "connection failed", http.StatusInternalServerError)
return
}
Expand Down
5 changes: 3 additions & 2 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net"
"os"
"sync"
Expand Down Expand Up @@ -157,7 +156,9 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
OnDone: func() {
b.disconnect(h)
},
OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) },
OnInternalError: func(err error) {
internalLogger.Error("jsonrpc2 internal error", "error", err)
},
})
assert(preempter.conn != nil, "unbound preempter")
return h, nil
Expand Down