diff --git a/mcp/event.go b/mcp/event.go index f4f4eee..fcc35ba 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -283,6 +283,12 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID // index is no longer available. var ErrEventsPurged = errors.New("data purged") +// ErrUnknownSession is the error that [EventStore.After] should return if the session ID is unknown. +var ErrUnknownSession = errors.New("unknown session ID") + +// ErrUnknownSession is the error that [EventStore.After] should return if the stream ID is unknown. +var ErrUnknownStream = errors.New("unknown stream ID") + // After implements [EventStore.After]. func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[[]byte, error] { // Return the data items to yield. @@ -292,11 +298,11 @@ func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID S defer s.mu.Unlock() streamMap, ok := s.store[sessionID] if !ok { - return nil, fmt.Errorf("MemoryEventStore.After: unknown session ID %q", sessionID) + return nil, fmt.Errorf("MemoryEventStore.After: session ID %v: %w", sessionID, ErrUnknownSession) } dl, ok := streamMap[streamID] if !ok { - return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID) + return nil, fmt.Errorf("MemoryEventStore.After: stream ID %v in session %q: %w", streamID, sessionID, ErrUnknownStream) } start := index + 1 if dl.first > start { diff --git a/mcp/streamable.go b/mcp/streamable.go index 1ecf201..e1fc88a 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -348,7 +348,7 @@ type stream struct { // that there are messages available to write into the HTTP response. // In addition, the presence of a channel guarantees that at most one HTTP response // can receive messages for a logical stream. After claiming the stream, incoming - // requests should read from outgoing, to ensure that no new messages are missed. + // requests should read from the event store, to ensure that no new messages are missed. // // To simplify locking, signal is an atomic. We need an atomic.Pointer, because // you can't set an atomic.Value to nil. @@ -360,22 +360,23 @@ type stream struct { // The following mutable fields are protected by the mutex of the containing // StreamableServerTransport. - // outgoing is the list of outgoing messages, enqueued by server methods that - // write notifications and responses, and dequeued by streamResponse. - outgoing [][]byte - // streamRequests is the set of unanswered incoming RPCs for the stream. // - // Requests persist until their response data has been added to outgoing. + // Requests persist until their response data has been added to the event store. requests map[jsonrpc.ID]struct{} + + // lastWriteIndex tracks the index of the last message written to the event store for this stream. + lastWriteIndex atomic.Int64 } func newStream(id StreamID, jsonResponse bool) *stream { - return &stream{ + s := &stream{ id: id, jsonResponse: jsonResponse, requests: make(map[jsonrpc.ID]struct{}), } + s.lastWriteIndex.Store(-1) + return s } func signalChanPtr() *chan struct{} { @@ -559,8 +560,8 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter var msgs []json.RawMessage ctx := req.Context() - for msg, ok := range c.messages(ctx, stream, false) { - if !ok { + for msg, err := range c.messages(ctx, stream, false, -1) { + if err != nil { if ctx.Err() != nil { w.WriteHeader(http.StatusNoContent) return @@ -623,44 +624,20 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, } } - if lastIndex >= 0 { - // Resume. - for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) { - if err != nil { - // TODO: reevaluate these status codes. - // Maybe distinguish between storage errors, which are 500s, and missing - // session or stream ID--can these arise from bad input? - status := http.StatusInternalServerError - if errors.Is(err, ErrEventsPurged) { - status = http.StatusInsufficientStorage - } - errorf(status, "failed to read events: %v", err) - return - } - // The iterator yields events beginning just after lastIndex, or it would have - // yielded an error. - if !write(data) { - return - } - } - } - // Repeatedly collect pending outgoing events and send them. ctx := req.Context() - for msg, ok := range c.messages(ctx, stream, persistent) { - if !ok { + for msg, err := range c.messages(ctx, stream, persistent, lastIndex) { + if err != nil { if ctx.Err() != nil && writes == 0 { // This probably doesn't matter, but respond with NoContent if the client disconnected. w.WriteHeader(http.StatusNoContent) + } else if errors.Is(err, ErrEventsPurged) { + errorf(http.StatusInsufficientStorage, "failed to read events: %v", err) } else { errorf(http.StatusGone, "stream terminated") } return } - if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil { - errorf(http.StatusInternalServerError, "storing event: %v", err.Error()) - return - } if !write(msg) { return } @@ -675,27 +652,33 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, // If the stream did not terminate normally, it is either because ctx was // cancelled, or the connection is closed: check the ctx.Err() to differentiate // these cases. -func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] { - return func(yield func(json.RawMessage, bool) bool) { +func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] { + return func(yield func(json.RawMessage, error) bool) { for { - c.mu.Lock() - outgoing := stream.outgoing - stream.outgoing = nil - nOutstanding := len(stream.requests) - c.mu.Unlock() - - for _, data := range outgoing { - if !yield(data, true) { + for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) { + if err != nil { + // Wait for session initialization before yielding. + if errors.Is(err, ErrUnknownSession) || errors.Is(err, ErrUnknownStream) { + break + } + yield(nil, err) return } + if !yield(data, nil) { + return + } + lastIndex++ } + c.mu.Lock() + nOutstanding := len(stream.requests) + c.mu.Unlock() // If all requests have been handled and replied to, we should terminate this connection. // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." // ยง6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server // We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET // (stream ID 0) will never have requests, and should remain open indefinitely. - if nOutstanding == 0 && !persistent { + if nOutstanding == 0 && !persistent && lastIndex >= int(stream.lastWriteIndex.Load()) { return } @@ -703,13 +686,14 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per case <-*stream.signal.Load(): // there are new outgoing messages // return to top of loop case <-c.done: // session is closed - yield(nil, false) + yield(nil, errors.New("session is closed")) return case <-ctx.Done(): - yield(nil, false) + yield(nil, ctx.Err()) return } } + } } @@ -812,9 +796,10 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e stream = c.streams[""] } - // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == "" - // and the client never did a GET), then memory will grow without bound. Consider a mitigation. - stream.outgoing = append(stream.outgoing, data) + if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil { + return fmt.Errorf("error storing event: %w", err) + } + stream.lastWriteIndex.Add(1) if isResponse { // Once we've put the reply on the queue, it's no longer outstanding. delete(stream.requests, forRequest)