mcp/streamable: use event store to fix unbounded memory issues #335
@@ -348,7 +348,7 @@ type stream struct {
 	// that there are messages available to write into the HTTP response.
 	// In addition, the presence of a channel guarantees that at most one HTTP response
 	// can receive messages for a logical stream. After claiming the stream, incoming
-	// requests should read from outgoing, to ensure that no new messages are missed.
+	// requests should read from the event store, to ensure that no new messages are missed.
 	//
 	// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
 	// you can't set an atomic.Value to nil.
@@ -360,22 +360,23 @@ type stream struct {
 	// The following mutable fields are protected by the mutex of the containing
 	// StreamableServerTransport.

-	// outgoing is the list of outgoing messages, enqueued by server methods that
-	// write notifications and responses, and dequeued by streamResponse.
-	outgoing [][]byte
-
 	// streamRequests is the set of unanswered incoming RPCs for the stream.
 	//
-	// Requests persist until their response data has been added to outgoing.
+	// Requests persist until their response data has been added to the event store.
 	requests map[jsonrpc.ID]struct{}
+
+	// lastWriteIndex tracks the index of the last message written to the event store for this stream.
+	lastWriteIndex atomic.Int64
 }

 func newStream(id StreamID, jsonResponse bool) *stream {
-	return &stream{
+	s := &stream{
 		id:           id,
 		jsonResponse: jsonResponse,
 		requests:     make(map[jsonrpc.ID]struct{}),
 	}
+	s.lastWriteIndex.Store(-1)
+	return s
 }

 func signalChanPtr() *chan struct{} {
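For readers without the full file: the `Append`/`After` call sites in this diff imply an event-store contract roughly like the sketch below. This is a hedged reconstruction, not the package's actual declarations — the `StreamID` definition, method set, and signatures are assumptions, and the real interface may have additional methods.

```go
package mcp

import (
	"context"
	"encoding/json"
	"iter"
)

// StreamID identifies a logical stream within a session. Its concrete
// definition here is an assumption for illustration.
type StreamID string

// EventStore sketches the contract implied by the call sites in this diff.
type EventStore interface {
	// Append durably stores data as the next event (index lastWriteIndex+1)
	// for the given session and stream.
	Append(ctx context.Context, sessionID string, streamID StreamID, data json.RawMessage) error

	// After yields stored events with indices strictly greater than index,
	// in order. Passing -1 replays the stream from its first event, which
	// is why newStream initializes lastWriteIndex to -1. It yields a
	// non-nil error (e.g. ErrEventsPurged, ErrUnknownSession, or
	// ErrUnknownStream) when replay is impossible.
	After(ctx context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[json.RawMessage, error]
}
```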
@@ -559,8 +560,8 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter

 	var msgs []json.RawMessage
 	ctx := req.Context()
-	for msg, ok := range c.messages(ctx, stream, false) {
-		if !ok {
+	for msg, err := range c.messages(ctx, stream, false, -1) {
+		if err != nil {
 			if ctx.Err() != nil {
 				w.WriteHeader(http.StatusNoContent)
 				return
@@ -623,44 +624,20 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
 		}
 	}

-	if lastIndex >= 0 {
-		// Resume.
-		for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) {
-			if err != nil {
-				// TODO: reevaluate these status codes.
-				// Maybe distinguish between storage errors, which are 500s, and missing
-				// session or stream ID--can these arise from bad input?
-				status := http.StatusInternalServerError
-				if errors.Is(err, ErrEventsPurged) {
-					status = http.StatusInsufficientStorage
-				}
-				errorf(status, "failed to read events: %v", err)
-				return
-			}
-			// The iterator yields events beginning just after lastIndex, or it would have
-			// yielded an error.
-			if !write(data) {
-				return
-			}
-		}
-	}
-
 	// Repeatedly collect pending outgoing events and send them.
 	ctx := req.Context()
-	for msg, ok := range c.messages(ctx, stream, persistent) {
-		if !ok {
+	for msg, err := range c.messages(ctx, stream, persistent, lastIndex) {
+		if err != nil {
 			if ctx.Err() != nil && writes == 0 {
 				// This probably doesn't matter, but respond with NoContent if the client disconnected.
 				w.WriteHeader(http.StatusNoContent)
+			} else if errors.Is(err, ErrEventsPurged) {
+				errorf(http.StatusInsufficientStorage, "failed to read events: %v", err)
+			} else {
+				errorf(http.StatusGone, "stream terminated")
 			}
 			return
 		}
-		if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil {
-			errorf(http.StatusInternalServerError, "storing event: %v", err.Error())
-			return
-		}
 		if !write(msg) {
 			return
 		}
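The new else-if chain absorbs the status selection that lived in the deleted resume block. As a hedged illustration of that mapping (the real code inlines it; `sseStatus`, `clientGone`, and the local `errEventsPurged` sentinel are hypothetical names for this sketch):

```go
package mcp

import (
	"errors"
	"net/http"
)

// errEventsPurged stands in for the package's ErrEventsPurged sentinel.
var errEventsPurged = errors.New("events purged")

// sseStatus mirrors the else-if chain above: a hypothetical helper,
// not the actual implementation.
func sseStatus(err error, clientGone bool) int {
	switch {
	case clientGone:
		// The client went away before anything was written; there is
		// no useful body to send.
		return http.StatusNoContent
	case errors.Is(err, errEventsPurged):
		// Events needed for replay were evicted from the event store.
		return http.StatusInsufficientStorage
	default:
		// The session closed or the stream cannot be resumed.
		return http.StatusGone
	}
}
```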
@@ -675,41 +652,48 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
 // If the stream did not terminate normally, it is either because ctx was
 // cancelled, or the connection is closed: check the ctx.Err() to differentiate
 // these cases.
-func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] {
-	return func(yield func(json.RawMessage, bool) bool) {
+func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] {
+	return func(yield func(json.RawMessage, error) bool) {
 		for {
-			c.mu.Lock()
-			outgoing := stream.outgoing
-			stream.outgoing = nil
-			nOutstanding := len(stream.requests)
-			c.mu.Unlock()
-
-			for _, data := range outgoing {
-				if !yield(data, true) {
+			for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) {
+				if err != nil {
+					// Wait for session initialization before yielding.
+					if errors.Is(err, ErrUnknownSession) || errors.Is(err, ErrUnknownStream) {
[Review thread on the ErrUnknownSession/ErrUnknownStream check]

Reviewer: Instead of doing it this way, I would avoid calling After at all if there is no session or stream. If there is a session and stream and After returns one of these errors, I think it is a real error and should be yielded.

Author: I am not sure there is an easy way to do that, because the session and stream may exist without existing in the event store yet. Append could happen before or after the After call, which is why we need After to report the error to us.

Reviewer: You should only call After when the client sends Last-Event-ID. If they send it too early, the server should return an error. I don't understand the state where After is called before Append.

Author: We can call After with an index of -1 to start writing from the beginning of the stream, which allows us to simplify the logic even if Last-Event-ID is not sent. After is called when respondSSE is called, which is disjoint from when Append is called in Write. These events can happen in any order, which is why we case on ErrUnknownSession and ErrUnknownStream to skip to the logic below, which waits for a stream signal.

Reviewer: Why don't we just call Append with nil data when we create the stream? Here's the problem: I'd like the event store to be able to completely clean up the stream or session at will, so when we get an unknown session or stream, we should fail this connection because it will never be recoverable.
+						break
+					}
+					yield(nil, err)
+					return
+				}
+				if !yield(data, nil) {
 					return
 				}
+				lastIndex++
 			}
+			c.mu.Lock()
+			nOutstanding := len(stream.requests)
+			c.mu.Unlock()

 			// If all requests have been handled and replied to, we should terminate this connection.
 			// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
 			// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
 			// We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET
 			// (stream ID 0) will never have requests, and should remain open indefinitely.
-			if nOutstanding == 0 && !persistent {
+			if nOutstanding == 0 && !persistent && lastIndex >= int(stream.lastWriteIndex.Load()) {
[Review comment on the lastWriteIndex check]

Reviewer: I originally thought that this commit reintroduced the bug that was fixed in findleyr@f4a9396. However, I see that it probably doesn't, because of this atomic check. I think it would be simpler to just move the check for nOutstanding above the After loop. Then you don't need lastWriteIndex. WDYT? I prefer to avoid atomics when there's already a synchronization mechanism (mu), because it's hard to reason about the relationship between the atomics and critical sections.
 				return
 			}

 			select {
 			case <-*stream.signal.Load(): // there are new outgoing messages
 				// return to top of loop
 			case <-c.done: // session is closed
-				yield(nil, false)
+				yield(nil, errors.New("session is closed"))
 				return
 			case <-ctx.Done():
-				yield(nil, false)
+				yield(nil, ctx.Err())
 				return
 			}
 		}
 	}
 }
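The signature change from `iter.Seq2[json.RawMessage, bool]` to `iter.Seq2[json.RawMessage, error]` follows the common Go 1.23 range-over-func convention for fallible streams: values are yielded with a nil error until a terminal error explains why the stream stopped. A self-contained toy (all names hypothetical) showing the producer/consumer shape:

```go
package main

import (
	"errors"
	"fmt"
	"iter"
)

// events yields each element of src with a nil error, then yields a
// terminal error if stop is non-nil. This mirrors how messages yields
// stored events and reports why it stopped.
func events(src []string, stop error) iter.Seq2[string, error] {
	return func(yield func(string, error) bool) {
		for _, e := range src {
			if !yield(e, nil) {
				return // the consumer broke out of its range loop
			}
		}
		if stop != nil {
			yield("", stop)
		}
	}
}

func main() {
	for e, err := range events([]string{"a", "b"}, errors.New("session is closed")) {
		if err != nil {
			fmt.Println("stream ended:", err)
			return
		}
		fmt.Println("event:", e)
	}
}
```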
@@ -812,9 +796,10 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
 		stream = c.streams[""]
 	}

-	// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
-	// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
-	stream.outgoing = append(stream.outgoing, data)
+	if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil {
+		return fmt.Errorf("error storing event: %w", err)
+	}
+	stream.lastWriteIndex.Add(1)
 	if isResponse {
 		// Once we've put the reply on the queue, it's no longer outstanding.
 		delete(stream.requests, forRequest)
[Review comment on "error storing event"]

Reviewer: Nit: s/return/wrap — we don't return this error value exactly, but one that wraps it.
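The nit refers to fmt.Errorf's %w verb: Write returns a new error that wraps the store's error rather than the store's error value itself, so callers can still match the original with errors.Is. A minimal demonstration (errStore is a stand-in sentinel):

```go
package main

import (
	"errors"
	"fmt"
)

var errStore = errors.New("event store full")

func main() {
	// %w wraps errStore rather than returning it directly, so the
	// returned value is a distinct error that still matches the
	// original under errors.Is.
	wrapped := fmt.Errorf("error storing event: %w", errStore)
	fmt.Println(wrapped)                      // error storing event: event store full
	fmt.Println(wrapped == errStore)          // false: not the same value
	fmt.Println(errors.Is(wrapped, errStore)) // true: it wraps the original
}
```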