Skip to content

mcp: add and implement SessionStorage interface #215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/sse/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ func main() {
default:
return nil
}
})
}, nil)
log.Fatal(http.ListenAndServe(addr, handler))
}
73 changes: 73 additions & 0 deletions mcp/session_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

package mcp

import (
"iter"
"sync"
)

// ServerSessionStore is a store of [Transport] sessions.
//
// The store must be thread-safe.
type ServerSessionStore[T Transport] interface {
Get(id string) (T, error)
Set(id string, session T) error
Delete(id string) error
Reset() error
All() (iter.Seq[T], error)
}

// MemoryServerSessionStore is a simple in-memory implementation of
// [ServerSessionStore].
type MemoryServerSessionStore[T Transport] struct {
mu sync.Mutex
sessions map[string]T
}

func NewMemoryServerSessionStore[T Transport]() *MemoryServerSessionStore[T] {
return &MemoryServerSessionStore[T]{
sessions: make(map[string]T),
}
}

func (s *MemoryServerSessionStore[T]) Get(id string) (T, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.sessions[id], nil
}

func (s *MemoryServerSessionStore[T]) Set(id string, session T) error {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[id] = session
return nil
}

func (s *MemoryServerSessionStore[T]) Delete(id string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, id)
return nil
}

func (s *MemoryServerSessionStore[T]) All() (iter.Seq[T], error) {
return func(yield func(T) bool) {
s.mu.Lock()
defer s.mu.Unlock()
for _, session := range s.sessions {
if !yield(session) {
return
}
}
}, nil
}

func (s *MemoryServerSessionStore[T]) Reset() error {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions = make(map[string]T)
return nil
}
40 changes: 40 additions & 0 deletions mcp/session_store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

package mcp

import (
"net/http/httptest"
"testing"
)

func TestMemorySessionStorePersistence(t *testing.T) {
store := NewMemoryServerSessionStore[*SSEServerTransport]()

sessionID := "session-1"
rr := httptest.NewRecorder()
expectedSession := NewSSEServerTransport("endpoint-1", rr)
store.Set(sessionID, expectedSession)

actualSession, err := store.Get(sessionID)
if err != nil {
t.Error("unexpected session Get error", err)
}
if actualSession != expectedSession {
t.Errorf("wanted %s to be %v but got %v", sessionID, expectedSession, actualSession)
}

err = store.Delete(sessionID)
if err != nil {
t.Error("unexpected session Delete error", err)
}

actualSession, err = store.Get(sessionID)
if err != nil {
t.Error("unexpected session Get error", err)
}
if actualSession != nil {
t.Errorf("wanted %s to be nil but got %v", sessionID, actualSession)
}
}
42 changes: 28 additions & 14 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,14 @@ import (
type SSEHandler struct {
getServer func(request *http.Request) *Server
onConnection func(*ServerSession) // for testing; must not block
sessionStore ServerSessionStore[*SSEServerTransport]
}

mu sync.Mutex
sessions map[string]*SSEServerTransport
// SSEOptions is a placeholder options struct for future
// configuration of the SSEHandler.
type SSEOptions struct {
// TODO: support configurable session ID generation (?)
SessionStore ServerSessionStore[*SSEServerTransport]
}

// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP
Expand All @@ -64,10 +69,17 @@ type SSEHandler struct {
// will return a 400 Bad Request.
//
// TODO(rfindley): add options.
func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler {
func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptions) *SSEHandler {
var sessionStore ServerSessionStore[*SSEServerTransport]
if opts != nil {
sessionStore = opts.SessionStore
}
if sessionStore == nil {
sessionStore = NewMemoryServerSessionStore[*SSEServerTransport]()
}
return &SSEHandler{
getServer: getServer,
sessions: make(map[string]*SSEServerTransport),
getServer: getServer,
sessionStore: sessionStore,
}
}

Expand Down Expand Up @@ -164,9 +176,11 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.Error(w, "sessionid must be provided", http.StatusBadRequest)
return
}
h.mu.Lock()
session := h.sessions[sessionID]
h.mu.Unlock()
session, err := h.sessionStore.Get(sessionID)
if err != nil {
http.Error(w, "failed to get session", http.StatusInternalServerError)
return
}
if session == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
Expand Down Expand Up @@ -200,13 +214,13 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
transport := NewSSEServerTransport(endpoint.RequestURI(), w)

// The session is terminated when the request exits.
h.mu.Lock()
h.sessions[sessionID] = transport
h.mu.Unlock()
err = h.sessionStore.Set(sessionID, transport)
if err != nil {
http.Error(w, "internal error: failed to set session", http.StatusInternalServerError)
return
}
defer func() {
h.mu.Lock()
delete(h.sessions, sessionID)
h.mu.Unlock()
h.sessionStore.Delete(sessionID)
}()

server := h.getServer(req)
Expand Down
2 changes: 1 addition & 1 deletion mcp/sse_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func ExampleSSEHandler() {
server := mcp.NewServer(&mcp.Implementation{Name: "adder", Version: "v0.0.1"}, nil)
mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add two numbers"}, Add)

handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server })
handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }, nil)
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

Expand Down
2 changes: 1 addition & 1 deletion mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestSSEServer(t *testing.T) {
server := NewServer(testImpl, nil)
AddTool(server, &Tool{Name: "greet"}, sayHi)

sseHandler := NewSSEHandler(func(*http.Request) *Server { return server })
sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil)

conns := make(chan *ServerSession, 1)
sseHandler.onConnection = func(cc *ServerSession) {
Expand Down
54 changes: 35 additions & 19 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ const (
type StreamableHTTPHandler struct {
getServer func(*http.Request) *Server

sessionsMu sync.Mutex
sessions map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
sessionStore ServerSessionStore[*StreamableServerTransport]
}

// StreamableHTTPOptions is a placeholder options struct for future
// configuration of the StreamableHTTP handler.
type StreamableHTTPOptions struct {
// TODO: support configurable session ID generation (?)
// TODO: support session retention (?)
SessionStore ServerSessionStore[*StreamableServerTransport]
}

// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
Expand All @@ -52,9 +52,16 @@ type StreamableHTTPOptions struct {
// sessions. It is OK for getServer to return the same server multiple times.
// If getServer returns nil, a 400 Bad Request will be served.
func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler {
var sessionStore ServerSessionStore[*StreamableServerTransport]
if opts != nil {
sessionStore = opts.SessionStore
}
if sessionStore == nil {
sessionStore = NewMemoryServerSessionStore[*StreamableServerTransport]()
}
return &StreamableHTTPHandler{
getServer: getServer,
sessions: make(map[string]*StreamableServerTransport),
getServer: getServer,
sessionStore: sessionStore,
}
}

Expand All @@ -65,13 +72,15 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
//
// Should we allow passing in a session store? That would allow the handler to
// be stateless.
func (h *StreamableHTTPHandler) closeAll() {
h.sessionsMu.Lock()
defer h.sessionsMu.Unlock()
for _, s := range h.sessions {
s.Close()
func (h *StreamableHTTPHandler) closeAll() error {
sessions, err := h.sessionStore.All()
if err != nil {
return err
}
for session := range sessions {
session.Close()
}
h.sessions = nil
return h.sessionStore.Reset()
}

func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -100,9 +109,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque

var session *StreamableServerTransport
if id := req.Header.Get(sessionIDHeader); id != "" {
h.sessionsMu.Lock()
session, _ = h.sessions[id]
h.sessionsMu.Unlock()
var err error
session, err = h.sessionStore.Get(id)
if err != nil {
http.Error(w, "failed to get session", http.StatusInternalServerError)
return
}
if session == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
Expand All @@ -117,9 +129,11 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
return
}
h.sessionsMu.Lock()
delete(h.sessions, session.sessionID)
h.sessionsMu.Unlock()
err := h.sessionStore.Delete(session.sessionID)
if err != nil {
http.Error(w, "failed to delete session", http.StatusInternalServerError)
return
}
session.Close()
w.WriteHeader(http.StatusNoContent)
return
Expand Down Expand Up @@ -148,9 +162,11 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
http.Error(w, "failed connection", http.StatusInternalServerError)
return
}
h.sessionsMu.Lock()
h.sessions[s.sessionID] = s
h.sessionsMu.Unlock()
err := h.sessionStore.Set(s.sessionID, s)
if err != nil {
http.Error(w, "failed to save session", http.StatusInternalServerError)
return
}
session = s
}

Expand Down