diff --git a/.gitignore b/.gitignore index 694735b68..9fcc0d6d3 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,6 @@ out .DS_Store dist/ + +# claude +.claude/ \ No newline at end of file diff --git a/src/client/index.ts b/src/client/index.ts index 3e8d8ec80..9fd788a68 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -172,6 +172,11 @@ export class Client< this._instructions = result.instructions; + // Handle session assignment from server + if (result.sessionId) { + this.createSession(result.sessionId, result.sessionTimeout); + } + await this.notification({ method: "notifications/initialized", }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 12714ea44..1c4768d13 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -1,4 +1,5 @@ import { Transport, FetchLike } from "../shared/transport.js"; +import { SessionState } from "../shared/protocol.js"; import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { EventSourceParserStream } from "eventsource-parser/stream"; @@ -129,6 +130,7 @@ export class StreamableHTTPClientTransport implements Transport { private _authProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _sessionId?: string; + private _sessionState?: SessionState; // For protocol-level session support private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; @@ -504,7 +506,12 @@ export class StreamableHTTPClientTransport implements Transport { } get sessionId(): string | undefined { - return this._sessionId; + // Prefer protocol-level session state, fallback to legacy _sessionId + return this._sessionState?.sessionId || this._sessionId; + } + + setSessionState(sessionState: SessionState): void { + this._sessionState = sessionState; } /** diff --git a/src/inMemory.ts b/src/inMemory.ts index 5dd6e81e0..97596f94f 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -1,6 +1,7 @@ import { Transport } from "./shared/transport.js"; import { JSONRPCMessage, RequestId } from "./types.js"; import { AuthInfo } from "./server/auth/types.js"; +import { SessionState } from "./shared/protocol.js"; interface QueuedMessage { message: JSONRPCMessage; @@ -17,7 +18,23 @@ export class InMemoryTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; - sessionId?: string; + + private _sessionState?: SessionState; + + get sessionId(): string | undefined { + return this._sessionState?.sessionId; + } + + getLegacySessionOptions(): undefined { + // InMemoryTransport has no legacy session configuration + return undefined; + } + + setSessionState(sessionState: SessionState): void { + // Store session state for sessionId getter + // InMemoryTransport doesn't use session state for other purposes + this._sessionState = sessionState; + } /** * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. diff --git a/src/integration-tests/taskResumability.test.ts b/src/integration-tests/taskResumability.test.ts index efd2611f8..fe4c0b667 100644 --- a/src/integration-tests/taskResumability.test.ts +++ b/src/integration-tests/taskResumability.test.ts @@ -186,7 +186,7 @@ describe('Transport resumability', () => { name: 'run-notifications', arguments: { count: 3, - interval: 10 + interval: 50 // Increased interval for more reliable timing } } }, CallToolResultSchema, { @@ -194,8 +194,10 @@ describe('Transport resumability', () => { onresumptiontoken: onLastEventIdUpdate }); - // Wait for some notifications to arrive (not all) - shorter wait time - await new Promise(resolve => setTimeout(resolve, 20)); + // Wait for some notifications to arrive (not all) + // With 50ms interval, first notification should arrive immediately, + // second at 50ms. We wait 75ms to ensure we get at least 1-2 notifications + await new Promise(resolve => setTimeout(resolve, 75)); // Verify we received some notifications and lastEventId was updated expect(notifications.length).toBeGreaterThan(0); @@ -219,7 +221,7 @@ describe('Transport resumability', () => { // Add a short delay to ensure clean disconnect before reconnecting - await new Promise(resolve => setTimeout(resolve, 10)); + await new Promise(resolve => setTimeout(resolve, 50)); // Wait for the rejection to be handled await catchPromise; diff --git a/src/server/index.ts b/src/server/index.ts index 10ae2fadc..fa7729ba7 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -3,7 +3,10 @@ import { Protocol, ProtocolOptions, RequestOptions, + SessionOptions, + SessionState, } from "../shared/protocol.js"; +import { Transport } from "../shared/transport.js"; import { ClientCapabilities, CreateMessageRequest, @@ -32,6 +35,8 @@ import { ServerRequest, ServerResult, SUPPORTED_PROTOCOL_VERSIONS, + SessionTerminateRequestSchema, + SessionTerminateRequest, } from "../types.js"; import Ajv from "ajv"; @@ -85,12 +90,21 @@ export class Server< private _clientVersion?: Implementation; private _capabilities: ServerCapabilities; private _instructions?: string; + private _sessionOptions?: SessionOptions; /** * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). */ oninitialized?: () => void; + /** + * Returns the connected transport instance. + * Used for session-to-server routing in examples. + */ + getTransport() { + return this.transport; + } + /** * Initializes this server with the given name and version information. */ @@ -98,18 +112,66 @@ export class Server< private _serverInfo: Implementation, options?: ServerOptions, ) { - super(options); + // Extract session options before passing to super + const { sessions, ...protocolOptions } = options ?? {}; + super(protocolOptions); + this._sessionOptions = sessions; this._capabilities = options?.capabilities ?? {}; this._instructions = options?.instructions; this.setRequestHandler(InitializeRequestSchema, (request) => this._oninitialize(request), ); + this.setRequestHandler(SessionTerminateRequestSchema, (request) => + this._onSessionTerminate(request), + ); this.setNotificationHandler(InitializedNotificationSchema, () => this.oninitialized?.(), ); } + /** + * Handles initialization request synchronously for HTTP transport backward compatibility. + * This bypasses the Protocol's async request handling to allow immediate error detection. + * @internal + */ + async handleInitializeSync(request: InitializeRequest): Promise { + // Call the internal initialization handler directly + const result = await this._oninitialize(request); + return result; + } + + /** + * Connect to a transport, handling legacy session options from the transport. + */ + async connect(transport: Transport): Promise { + // Handle legacy session options delegation from transport + const legacySessionOptions = transport.getLegacySessionOptions?.(); + if (legacySessionOptions) { + if (this._sessionOptions) { + // Both server session options and transport legacy session options provided. Using server options. + } else { + this._sessionOptions = legacySessionOptions; + } + } + + // Register synchronous initialization handler if transport supports it + if (transport.setInitializeHandler) { + transport.setInitializeHandler((request: InitializeRequest) => + this.handleInitializeSync(request) + ); + } + + // Register synchronous termination handler if transport supports it + if (transport.setTerminateHandler) { + transport.setTerminateHandler((sessionId?: string) => + this.terminateSession(sessionId) + ); + } + + await super.connect(transport); + } + /** * Registers new capabilities. This can only be called before connecting to a transport. * @@ -269,12 +331,76 @@ export class Server< ? requestedVersion : LATEST_PROTOCOL_VERSION; - return { + const result: InitializeResult = { protocolVersion, capabilities: this.getCapabilities(), serverInfo: this._serverInfo, ...(this._instructions && { instructions: this._instructions }), }; + + // Generate session if supported + if (this._sessionOptions?.sessionIdGenerator) { + const sessionId = this._sessionOptions.sessionIdGenerator(); + result.sessionId = sessionId; + result.sessionTimeout = this._sessionOptions.sessionTimeout; + + await this.initializeSession(sessionId, this._sessionOptions.sessionTimeout); + } + + return result; + } + + private async initializeSession(sessionId: string, timeout?: number): Promise { + // Create the session + this.createSession(sessionId, timeout); + + // Try to call the initialization callback, but if it fails, + // store the error in session state and rethrow + try { + await this._sessionOptions?.onsessioninitialized?.(sessionId); + } catch (error) { + // Store the error in session state for the transport to check + const sessionState = this.getSessionState(); + if (sessionState) { + sessionState.callbackError = error instanceof Error ? error : new Error(String(error)); + } + throw error; + } + } + + protected async terminateSession(sessionId?: string): Promise { + // Get the current session ID before termination + const currentSessionId = this.getSessionState()?.sessionId; + + // Call parent's terminateSession to clear the session state + await super.terminateSession(sessionId); + + // Now call the callback if we had a session + if (currentSessionId) { + try { + await this._sessionOptions?.onsessionclosed?.(currentSessionId); + } catch (error) { + // Re-create minimal session state just to store the error for transport to check + const sessionState: SessionState = { + sessionId: currentSessionId, + createdAt: Date.now(), + lastActivity: Date.now(), + callbackError: error instanceof Error ? error : new Error(String(error)) + }; + // Notify transport of the error state + this.transport?.setSessionState?.(sessionState); + throw error; + } + } + } + + private async _onSessionTerminate( + request: SessionTerminateRequest + ): Promise { + // Use the same termination logic as the protocol method + // sessionId comes directly from the protocol request + await this.terminateSession(request.sessionId); + return {}; } /** diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 10e550df4..9c5d08d54 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -14,7 +14,8 @@ import { LoggingMessageNotificationSchema, Notification, TextContent, - ElicitRequestSchema + ElicitRequestSchema, + InitializeResultSchema } from "../types.js"; import { ResourceTemplate } from "./mcp.js"; import { completable } from "./completable.js"; @@ -1342,6 +1343,10 @@ describe("tool()", () => { const mcpServer = new McpServer({ name: "test server", version: "1.0", + }, { + sessions: { + sessionIdGenerator: () => "test-session-123" + } }); const client = new Client({ @@ -1349,7 +1354,7 @@ describe("tool()", () => { version: "1.0", }); - let receivedSessionId: string | undefined; + let receivedSessionId: string | number | undefined; mcpServer.tool("test-tool", async (extra) => { receivedSessionId = extra.sessionId; return { @@ -1363,20 +1368,32 @@ describe("tool()", () => { }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // Set a test sessionId on the server transport - serverTransport.sessionId = "test-session-123"; await Promise.all([ client.connect(clientTransport), mcpServer.server.connect(serverTransport), ]); + // Initialize to create session + await client.request( + { + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test client", version: "1.0" } + } + }, + InitializeResultSchema + ); + await client.request( { method: "tools/call", params: { name: "test-tool", }, + sessionId: "test-session-123", // Protocol-level session approach }, CallToolResultSchema, ); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 791facef1..f352d1464 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -85,6 +85,14 @@ export class McpServer { await this.server.close(); } + /** + * Returns the connected transport instance. + * Used for session-to-server routing in examples. + */ + getTransport() { + return this.server.getTransport(); + } + private _toolHandlersInitialized = false; private setToolRequestHandlers() { diff --git a/src/server/server-session.test.ts b/src/server/server-session.test.ts new file mode 100644 index 000000000..7d2c473c8 --- /dev/null +++ b/src/server/server-session.test.ts @@ -0,0 +1,117 @@ +import { describe, it, expect, jest, beforeEach } from '@jest/globals'; +import { Server } from './index.js'; +import { JSONRPCMessage, MessageExtraInfo } from '../types.js'; +import { Transport } from '../shared/transport.js'; + +// Mock transport for testing +class MockTransport implements Transport { + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; + + sentMessages: JSONRPCMessage[] = []; + + async start(): Promise {} + async close(): Promise {} + + async send(message: JSONRPCMessage): Promise { + this.sentMessages.push(message); + } +} + +describe('Server Session Integration', () => { + let server: Server; + let transport: MockTransport; + + beforeEach(() => { + transport = new MockTransport(); + }); + + describe('Session Configuration', () => { + it('should accept session options through constructor', async () => { + const mockCallback = jest.fn() as jest.MockedFunction<(sessionId: string | number) => void>; + + server = new Server( + { name: 'test-server', version: '1.0.0' }, + { + sessions: { + sessionIdGenerator: () => 'test-session-123', + sessionTimeout: 3600, + onsessioninitialized: mockCallback, + onsessionclosed: mockCallback + } + } + ); + + await server.connect(transport); + + // Verify server was created successfully with session options + expect(server).toBeDefined(); + expect(server.getTransport()).toBe(transport); + }); + + it('should work without session options', async () => { + server = new Server( + { name: 'test-server', version: '1.0.0' } + ); + + await server.connect(transport); + + // Should work fine without session configuration + expect(server).toBeDefined(); + expect(server.getTransport()).toBe(transport); + }); + }); + + describe('Transport Access', () => { + it('should expose transport via getTransport method', async () => { + server = new Server( + { name: 'test-server', version: '1.0.0' } + ); + await server.connect(transport); + + expect(server.getTransport()).toBe(transport); + }); + + it('should return undefined when not connected', () => { + server = new Server( + { name: 'test-server', version: '1.0.0' } + ); + + expect(server.getTransport()).toBeUndefined(); + }); + }); + + describe('Session Handler Registration', () => { + it('should register session terminate handler when created', async () => { + server = new Server( + { name: 'test-server', version: '1.0.0' }, + { + sessions: { + sessionIdGenerator: () => 'test-session' + } + } + ); + await server.connect(transport); + + // Test that session/terminate handler exists by sending a terminate message + // and verifying we don't get "method not found" error + const terminateMessage = { + jsonrpc: '2.0' as const, + id: 1, + method: 'session/terminate', + sessionId: 'test-session' + }; + + transport.onmessage!(terminateMessage); + + // Check if a "method not found" error was sent + const methodNotFoundError = transport.sentMessages.find(msg => + 'error' in msg && msg.error.code === -32601 + ); + + // Handler should exist, so no "method not found" error + expect(methodNotFoundError).toBeUndefined(); + }); + }); +}); \ No newline at end of file diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 3a0a5c066..e6848097f 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -262,6 +262,23 @@ describe("StreamableHTTPServerTransport", () => { expect(response.headers.get("mcp-session-id")).toBeDefined(); }); + it("should create transport without options (backward compatibility)", async () => { + // Test that StreamableHTTPServerTransport can be created without any options + const minimalTransport = new StreamableHTTPServerTransport(); + expect(minimalTransport).toBeDefined(); + + // Test that it can connect to a server + const minimalMcpServer = new McpServer( + { name: "minimal-server", version: "1.0.0" }, + { capabilities: {} } + ); + + await expect(minimalMcpServer.connect(minimalTransport)).resolves.not.toThrow(); + + // Clean up + await minimalTransport.close(); + }); + it("should reject second initialization request", async () => { // First initialize const sessionId = await initializeServer(); @@ -289,6 +306,7 @@ describe("StreamableHTTPServerTransport", () => { params: { clientInfo: { name: "test-client-2", version: "1.0" }, protocolVersion: "2025-03-26", + capabilities: {}, }, id: "init-2", } diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 3bf84e430..cd8c562a0 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -1,6 +1,7 @@ import { IncomingMessage, ServerResponse } from "node:http"; import { Transport } from "../shared/transport.js"; -import { MessageExtraInfo, RequestInfo, isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId, SUPPORTED_PROTOCOL_VERSIONS, DEFAULT_NEGOTIATED_PROTOCOL_VERSION } from "../types.js"; +import { SessionState, SessionOptions } from "../shared/protocol.js"; +import { MessageExtraInfo, RequestInfo, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId, SUPPORTED_PROTOCOL_VERSIONS, DEFAULT_NEGOTIATED_PROTOCOL_VERSION, isInitializeRequest, InitializeRequest, InitializeResult } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; import { randomUUID } from "node:crypto"; @@ -128,36 +129,82 @@ export interface StreamableHTTPServerTransportOptions { * - No session validation is performed */ export class StreamableHTTPServerTransport implements Transport { - // when sessionId is not set (undefined), it means the transport is in stateless mode - private sessionIdGenerator: (() => string) | undefined; private _started: boolean = false; private _streamMapping: Map = new Map(); private _requestToStreamMapping: Map = new Map(); private _requestResponseMap: Map = new Map(); - private _initialized: boolean = false; private _enableJsonResponse: boolean = false; private _standaloneSseStreamId: string = '_GET_stream'; private _eventStore?: EventStore; - private _onsessioninitialized?: (sessionId: string) => void | Promise; - private _onsessionclosed?: (sessionId: string) => void | Promise; private _allowedHosts?: string[]; private _allowedOrigins?: string[]; private _enableDnsRebindingProtection: boolean; - - sessionId?: string; + private _sessionState?: SessionState; // Reference to server's session state + private _legacySessionCallbacks?: SessionOptions; // Legacy callbacks for backward compatibility + private _initializeHandler?: (request: InitializeRequest) => Promise; // Special handler for synchronous initialization + private _terminateHandler?: (sessionId?: string) => Promise; // Special handler for synchronous termination + private _pendingInitResponse?: JSONRPCMessage; // Pending initialization response to send via SSE onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; - constructor(options: StreamableHTTPServerTransportOptions) { - this.sessionIdGenerator = options.sessionIdGenerator; - this._enableJsonResponse = options.enableJsonResponse ?? false; - this._eventStore = options.eventStore; - this._onsessioninitialized = options.onsessioninitialized; - this._onsessionclosed = options.onsessionclosed; - this._allowedHosts = options.allowedHosts; - this._allowedOrigins = options.allowedOrigins; - this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false; + /** + * Sets the session state reference for HTTP header handling. + * Called by the server when session is created. + */ + setSessionState(sessionState: SessionState): void { + this._sessionState = sessionState; + } + + /** + * Sets a special handler for initialization requests that bypasses async protocol handling. + * This allows the transport to get immediate error feedback for HTTP status codes. + * @internal + */ + setInitializeHandler(handler: (request: InitializeRequest) => Promise): void { + this._initializeHandler = handler; + } + + /** + * Sets a handler for synchronous session termination processing. + * This allows the transport to get immediate error feedback for HTTP status codes. + * @internal + */ + setTerminateHandler(handler: (sessionId?: string) => Promise): void { + this._terminateHandler = handler; + } + + /** + * Gets the current sessionId for HTTP headers. + * Returns undefined if no session is active. + */ + get sessionId(): string | undefined { + const sessionId = this._sessionState?.sessionId; + return sessionId; + } + + /** + * Gets legacy session options for delegation to server. + * Used for backward compatibility when server connects. + */ + getLegacySessionOptions(): SessionOptions | undefined { + return this._legacySessionCallbacks; + } + + constructor(options?: StreamableHTTPServerTransportOptions) { + // Store legacy session callbacks for delegation to server + this._legacySessionCallbacks = options ? { + sessionIdGenerator: options.sessionIdGenerator, + onsessioninitialized: options.onsessioninitialized, + onsessionclosed: options.onsessionclosed + } : undefined; + + // Transport options + this._enableJsonResponse = options?.enableJsonResponse ?? false; + this._eventStore = options?.eventStore; + this._allowedHosts = options?.allowedHosts; + this._allowedOrigins = options?.allowedOrigins; + this._enableDnsRebindingProtection = options?.enableDnsRebindingProtection ?? false; } /** @@ -248,12 +295,7 @@ export class StreamableHTTPServerTransport implements Transport { return; } - // If an Mcp-Session-Id is returned by the server during initialization, - // clients using the Streamable HTTP transport MUST include it - // in the Mcp-Session-Id header on all of their subsequent HTTP requests. - if (!this.validateSession(req, res)) { - return; - } + // Session validation now handled by server through protocol layer if (!this.validateProtocolVersion(req, res)) { return; } @@ -374,6 +416,11 @@ export class StreamableHTTPServerTransport implements Transport { */ private async handlePostRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { try { + // Validate protocol version first + if (!this.validateProtocolVersion(req, res)) { + return; + } + // Validate the Accept header const acceptHeader = req.headers.accept; // The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types. @@ -426,58 +473,173 @@ export class StreamableHTTPServerTransport implements Transport { messages = [JSONRPCMessageSchema.parse(rawMessage)]; } - // Check if this is an initialization request - // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ - const isInitializationRequest = messages.some(isInitializeRequest); - if (isInitializationRequest) { - // If it's a server with session management and the session ID is already set we should reject the request - // to avoid re-initialization. - if (this._initialized && this.sessionId !== undefined) { - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32600, - message: "Invalid Request: Server already initialized" - }, - id: null - })); - return; - } - if (messages.length > 1) { - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32600, - message: "Invalid Request: Only one initialization request is allowed" - }, - id: null - })); - return; - } - this.sessionId = this.sessionIdGenerator?.(); - this._initialized = true; - - // If we have a session ID and an onsessioninitialized handler, call it immediately - // This is needed in cases where the server needs to keep track of multiple sessions - if (this.sessionId && this._onsessioninitialized) { - await Promise.resolve(this._onsessioninitialized(this.sessionId)); + // Inject sessionId from HTTP headers into protocol messages (for backward compatibility) + const headerSessionId = req.headers["mcp-session-id"]; + if (headerSessionId && !Array.isArray(headerSessionId)) { + // Check for sessionId mismatches first + for (const message of messages) { + if ('sessionId' in message && message.sessionId !== undefined) { + if (message.sessionId !== headerSessionId) { + // SessionId mismatch between header and protocol message + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: SessionId mismatch between header and protocol message" + }, + id: null + })); + return; // Fail entire request + } + } } + + // No mismatches, proceed with injection + messages = messages.map(message => { + // Inject header sessionId if message doesn't have one + if (!('sessionId' in message) || message.sessionId === undefined) { + return { ...message, sessionId: headerSessionId }; + } + return message; // Keep existing sessionId + }); + } + // Count initialization requests for validation + const initRequests = messages.filter(isInitializeRequest); + + // Check for multiple initialization requests in batch + if (initRequests.length > 1) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32600, + message: "Only one initialization request is allowed per batch" + }, + id: null + })); + return; } - if (!isInitializationRequest) { - // If an Mcp-Session-Id is returned by the server during initialization, - // clients using the Streamable HTTP transport MUST include it - // in the Mcp-Session-Id header on all of their subsequent HTTP requests. - if (!this.validateSession(req, res)) { - return; + + // Process initialization messages first to create session state before SSE headers + const processedInitMessages = new Set(); + for (const message of messages) { + if (isInitializeRequest(message)) { + // Use synchronous initialization handler if available for immediate error detection + if (this._initializeHandler && isJSONRPCRequest(message)) { + try { + // Check if already initialized + if (this._sessionState) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32600, + message: "Server already initialized" + }, + id: message.id + })); + return; + } + + // Both type guards ensure message is InitializeRequest with id + const result = await this._initializeHandler(message); + // Create the response message and mark it as processed + const response = { + jsonrpc: "2.0" as const, + id: message.id, + result + }; + processedInitMessages.add(JSON.stringify(message)); + // Store the response to send later via SSE + this._pendingInitResponse = response; + } catch (error) { + // Initialization failed - return HTTP error immediately + const errorMessage = error instanceof Error ? error.message : String(error); + res.writeHead(400, { "Content-Type": "text/plain" }); + res.end(`Session initialization failed: ${errorMessage}`); + return; + } + } else { + // Fallback to async processing via onmessage + await Promise.resolve(this.onmessage?.(message, { authInfo, requestInfo })); + processedInitMessages.add(JSON.stringify(message)); + + // Check if session initialization failed (callback threw) + if (this._sessionState?.callbackError) { + res.writeHead(400, { "Content-Type": "text/plain" }); + res.end(`Session initialization failed: ${this._sessionState.callbackError.message}`); + return; + } + } } - // Mcp-Protocol-Version header is required for all requests after initialization. - if (!this.validateProtocolVersion(req, res)) { - return; + } + // Session should now be created and available for HTTP headers + + // Validate session for non-initialization requests (backward compatibility for HTTP transport) + // This provides appropriate HTTP status codes before starting SSE stream + const sessionsEnabled = this._legacySessionCallbacks?.sessionIdGenerator !== undefined; + if (sessionsEnabled) { + // Sessions are enabled, validate for non-initialization requests + // Skip messages that have already been processed as initialization + for (const message of messages) { + const messageStr = JSON.stringify(message); + if (isJSONRPCRequest(message) && !isInitializeRequest(message) && !processedInitMessages.has(messageStr)) { + const messageSessionId = message.sessionId; + + // Check if session ID is missing when required + if (!messageSessionId) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Session ID required" + }, + id: null + })); + return; + } + + // Check if server is not initialized yet + if (!this._sessionState) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Server not initialized" + }, + id: null + })); + return; + } + + // Check if we have an active session and validate the ID + if (messageSessionId !== this._sessionState.sessionId) { + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return; + } + + // If no session exists yet but sessionId was provided, it's invalid + if (!this._sessionState) { + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return; + } + } } } - // check if it contains requests const hasRequests = messages.some(isJSONRPCRequest); @@ -489,7 +651,7 @@ export class StreamableHTTPServerTransport implements Transport { for (const message of messages) { this.onmessage?.(message, { authInfo, requestInfo }); } - } else if (hasRequests) { + } else { // The default behavior is to use SSE streaming // but in some cases server will return JSON responses const streamId = randomUUID(); @@ -504,7 +666,6 @@ export class StreamableHTTPServerTransport implements Transport { if (this.sessionId !== undefined) { headers["mcp-session-id"] = this.sessionId; } - res.writeHead(200, headers); } // Store the response for this request to send messages back through this connection @@ -520,8 +681,18 @@ export class StreamableHTTPServerTransport implements Transport { this._streamMapping.delete(streamId); }); - // handle each message + // Send pending initialization response if we have one + if (this._pendingInitResponse) { + await this.send(this._pendingInitResponse); + this._pendingInitResponse = undefined; + } + + // handle each message (skip already processed initialization messages) for (const message of messages) { + const messageStr = JSON.stringify(message); + if (processedInitMessages.has(messageStr)) { + continue; + } this.onmessage?.(message, { authInfo, requestInfo }); } // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses @@ -543,83 +714,86 @@ export class StreamableHTTPServerTransport implements Transport { } /** - * Handles DELETE requests to terminate sessions + * Handles DELETE requests to terminate sessions + * + * Note: backward compatibility. Handler delegates via a SessionTerminateRequest message to the server */ private async handleDeleteRequest(req: IncomingMessage, res: ServerResponse): Promise { - if (!this.validateSession(req, res)) { - return; - } if (!this.validateProtocolVersion(req, res)) { return; } - await Promise.resolve(this._onsessionclosed?.(this.sessionId!)); - await this.close(); - res.writeHead(200).end(); - } - - /** - * Validates session ID for non-initialization requests - * Returns true if the session is valid, false otherwise - */ - private validateSession(req: IncomingMessage, res: ServerResponse): boolean { - if (this.sessionIdGenerator === undefined) { - // If the sessionIdGenerator ID is not set, the session management is disabled - // and we don't need to validate the session ID - return true; - } - if (!this._initialized) { - // If the server has not been initialized yet, reject all requests + + // Extract sessionId from header and convert to session/terminate protocol message + const headerSessionId = req.headers["mcp-session-id"]; + if (!headerSessionId || Array.isArray(headerSessionId)) { res.writeHead(400).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, - message: "Bad Request: Server not initialized" + message: "Bad Request: Mcp-Session-Id header required for session termination" }, id: null })); - return false; + return; } - const sessionId = req.headers["mcp-session-id"]; - - if (!sessionId) { - // Non-initialization requests without a session ID should return 400 Bad Request - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Bad Request: Mcp-Session-Id header is required" - }, - id: null - })); - return false; - } else if (Array.isArray(sessionId)) { - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Bad Request: Mcp-Session-Id header must be a single value" - }, - id: null - })); - return false; + // Validate session exists before attempting termination (HTTP transport backward compatibility) + if (this._sessionState) { + if (headerSessionId !== this._sessionState.sessionId) { + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return; + } } - else if (sessionId !== this.sessionId) { - // Reject requests with invalid session ID with 404 Not Found - res.writeHead(404).end(JSON.stringify({ + + // Use synchronous termination handler if available for immediate error detection + if (this._terminateHandler) { + try { + await this._terminateHandler(headerSessionId); + // Success + res.writeHead(200).end(); + } catch (error) { + // Termination failed - return HTTP error immediately + const errorMessage = error instanceof Error ? error.message : String(error); + res.writeHead(500, { "Content-Type": "text/plain" }); + res.end(`Session termination failed: ${errorMessage}`); + return; + } + } else { + // Fallback to async processing via onmessage + // Create session/terminate protocol message + const terminateMessage: JSONRPCMessage = { jsonrpc: "2.0", - error: { - code: -32001, - message: "Session not found" - }, - id: null + id: Date.now(), // Simple ID for internal message + method: "session/terminate", + sessionId: headerSessionId + }; + + // Send to server for processing (server handles validation and termination) + await Promise.resolve(this.onmessage?.(terminateMessage, { + requestInfo: { headers: req.headers } })); - return false; + + // Check if termination failed (onsessionclosed threw) + if (this._sessionState?.callbackError) { + res.writeHead(500, { "Content-Type": "text/plain" }); + res.end(`Session termination failed: ${this._sessionState.callbackError.message}`); + return; + } + + // Success + res.writeHead(200).end(); } - - return true; } + // Session validation now handled entirely by server through protocol layer + private validateProtocolVersion(req: IncomingMessage, res: ServerResponse): boolean { let protocolVersion = req.headers["mcp-protocol-version"] ?? DEFAULT_NEGOTIATED_PROTOCOL_VERSION; if (Array.isArray(protocolVersion)) { @@ -662,6 +836,7 @@ export class StreamableHTTPServerTransport implements Transport { // Check if this message should be sent on the standalone SSE stream (no request ID) // Ignore notifications from tools (which have relatedRequestId set) // Those will be sent via dedicated response SSE streams + if (requestId === undefined) { // For standalone SSE streams, we can only send requests and notifications if (isJSONRPCResponse(message) || isJSONRPCError(message)) { @@ -723,8 +898,9 @@ export class StreamableHTTPServerTransport implements Transport { const headers: Record = { 'Content-Type': 'application/json', }; - if (this.sessionId !== undefined) { - headers['mcp-session-id'] = this.sessionId; + const sessionId = this.sessionId; + if (sessionId !== undefined) { + headers['mcp-session-id'] = sessionId; } const responses = relatedIds diff --git a/src/shared/protocol-session.test.ts b/src/shared/protocol-session.test.ts new file mode 100644 index 000000000..602d0e366 --- /dev/null +++ b/src/shared/protocol-session.test.ts @@ -0,0 +1,182 @@ +import { describe, it, expect, beforeEach } from '@jest/globals'; +import { Protocol, SessionState } from './protocol.js'; +import { ErrorCode, JSONRPCRequest, JSONRPCMessage, Request, Notification, Result, MessageExtraInfo } from '../types.js'; +import { Transport } from './transport.js'; + +// Mock transport for testing +class MockTransport implements Transport { + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; + + sentMessages: JSONRPCMessage[] = []; + + async start(): Promise {} + async close(): Promise {} + + async send(message: JSONRPCMessage): Promise { + this.sentMessages.push(message); + } +} + +// Test implementation of Protocol +class TestProtocol extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + + // Expose protected methods for testing + public testValidateSessionId(sessionId?: string) { + return this.validateSessionId(sessionId); + } + + public testCreateSession(sessionId: string, timeout?: number) { + return this.createSession(sessionId, timeout); + } + + public testTerminateSession(sessionId?: string) { + return this.terminateSession(sessionId); + } + + public testUpdateSessionActivity() { + return this.updateSessionActivity(); + } + + public testIsSessionExpired() { + return this.isSessionExpired(); + } + + public getSessionState(): SessionState | undefined { + return (this as unknown as { _sessionState?: SessionState })._sessionState; + } +} + +describe('Protocol Session Management', () => { + let protocol: TestProtocol; + let transport: MockTransport; + + beforeEach(() => { + transport = new MockTransport(); + }); + + describe('Session Validation', () => { + it('should allow sessionless operation when no session options', async () => { + protocol = new TestProtocol(); + await protocol.connect(transport); + + // Should validate successfully with no session + expect(protocol.testValidateSessionId(undefined)).toBe(true); + expect(protocol.testValidateSessionId('some-session')).toBe(false); + }); + + it('should validate session correctly when session exists', async () => { + protocol = new TestProtocol(); + await protocol.connect(transport); + + // Create a session + protocol.testCreateSession('test-session-123'); + + // Valid session should pass + expect(protocol.testValidateSessionId('test-session-123')).toBe(true); + + // Invalid session should fail + expect(protocol.testValidateSessionId('wrong-session')).toBe(false); + + // No session when one exists should fail + expect(protocol.testValidateSessionId(undefined)).toBe(false); + }); + + it('should validate sessionless correctly when no active session', async () => { + protocol = new TestProtocol(); + await protocol.connect(transport); + + // No active session, no message session = valid + expect(protocol.testValidateSessionId(undefined)).toBe(true); + + // No active session, message has session = invalid + expect(protocol.testValidateSessionId('some-session')).toBe(false); + }); + }); + + describe('Session Lifecycle', () => { + it('should create session with correct state', async () => { + protocol = new TestProtocol(); + await protocol.connect(transport); + + protocol.testCreateSession('test-session-123', 60); + + const sessionState = protocol.getSessionState(); + expect(sessionState).toBeDefined(); + expect(sessionState!.sessionId).toBe('test-session-123'); + expect(sessionState!.timeout).toBe(60); + expect(sessionState!.createdAt).toBeCloseTo(Date.now(), -2); + expect(sessionState!.lastActivity).toBeCloseTo(Date.now(), -2); + }); + + it('should terminate session correctly', async () => { + protocol = new TestProtocol(); + await protocol.connect(transport); + + protocol.testCreateSession('test-session-123'); + expect(protocol.getSessionState()).toBeDefined(); + + await protocol.testTerminateSession('test-session-123'); + + expect(protocol.getSessionState()).toBeUndefined(); + }); + + it('should reject termination with wrong sessionId', async () => { + protocol = new TestProtocol(); + await protocol.connect(transport); + + protocol.testCreateSession('test-session-123'); + + await expect(protocol.testTerminateSession('wrong-session')) + .rejects.toThrow('Internal error'); + + // Session should still exist + expect(protocol.getSessionState()).toBeDefined(); + }); + }); + + describe('Message Handling with Sessions', () => { + beforeEach(async () => { + protocol = new TestProtocol(); + await protocol.connect(transport); + protocol.testCreateSession('test-session'); + }); + + it('should reject messages with invalid sessionId', () => { + const invalidMessage: JSONRPCRequest = { + jsonrpc: '2.0', + id: 1, + method: 'test', + sessionId: 'wrong-session' + }; + + // Simulate message handling + transport.onmessage!(invalidMessage); + + // Should send error response + expect(transport.sentMessages).toHaveLength(1); + const errorMessage = transport.sentMessages[0] as JSONRPCMessage & { error: { code: number } }; + expect(errorMessage.error.code).toBe(ErrorCode.InvalidSession); + }); + + it('should reject sessionless messages when session exists', () => { + const sessionlessMessage: JSONRPCRequest = { + jsonrpc: '2.0', + id: 1, + method: 'test' + // No sessionId + }; + + transport.onmessage!(sessionlessMessage); + + // Should send error response + expect(transport.sentMessages).toHaveLength(1); + const errorMessage = transport.sentMessages[0] as JSONRPCMessage & { error: { code: number } }; + expect(errorMessage.error.code).toBe(ErrorCode.InvalidSession); + }); + }); +}); \ No newline at end of file diff --git a/src/shared/protocol-transport-handling.test.ts b/src/shared/protocol-transport-handling.test.ts index 3baa9b638..27e75f6cd 100644 --- a/src/shared/protocol-transport-handling.test.ts +++ b/src/shared/protocol-transport-handling.test.ts @@ -43,6 +43,112 @@ describe("Protocol transport handling bug", () => { transportB = new MockTransport("B"); }); + test("should handle initialize request correctly when transport switches mid-flight", async () => { + // Set up a handler for initialize that simulates processing time + let resolveHandler: (value: Result) => void; + const handlerPromise = new Promise((resolve) => { + resolveHandler = resolve; + }); + + const InitializeRequestSchema = z.object({ + method: z.literal("initialize"), + params: z.object({ + protocolVersion: z.string(), + capabilities: z.object({}), + clientInfo: z.object({ + name: z.string(), + version: z.string() + }) + }) + }); + + protocol.setRequestHandler( + InitializeRequestSchema, + async (request) => { + console.log(`Processing initialize from ${request.params.clientInfo.name}`); + return handlerPromise; + } + ); + + // Client A connects and sends initialize request + await protocol.connect(transportA); + + const initFromA = { + jsonrpc: "2.0" as const, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { + name: "clientA", + version: "1.0" + } + }, + id: 1 + }; + + // Simulate client A sending initialize request + transportA.onmessage?.(initFromA); + + // While A's initialize is being processed, client B connects + // This overwrites the transport reference in the protocol + await protocol.connect(transportB); + + const initFromB = { + jsonrpc: "2.0" as const, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { + name: "clientB", + version: "1.0" + } + }, + id: 2 + }; + + // Client B sends its own initialize request + transportB.onmessage?.(initFromB); + + // Now complete A's initialize request with session info + resolveHandler!({ + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test-server", version: "1.0" }, + sessionId: "session-for-A" + } as Result); + + // Wait for async operations to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Check where the responses went + console.log("Transport A received:", transportA.sentMessages); + console.log("Transport B received:", transportB.sentMessages); + + // Transport A should receive response for its initialize request + expect(transportA.sentMessages.length).toBe(1); + expect(transportA.sentMessages[0]).toMatchObject({ + jsonrpc: "2.0", + id: 1, + result: { + protocolVersion: "2025-06-18", + sessionId: "session-for-A" + } + }); + + // Transport B should receive its own response (when handler completes) + expect(transportB.sentMessages.length).toBe(1); + expect(transportB.sentMessages[0]).toMatchObject({ + jsonrpc: "2.0", + id: 2, + result: { + protocolVersion: "2025-06-18", + sessionId: "session-for-A" // Same handler result in this test + } + }); + }); + test("should send response to the correct transport when multiple clients are connected", async () => { // Set up a request handler that simulates processing time let resolveHandler: (value: Result) => void; @@ -121,6 +227,140 @@ describe("Protocol transport handling bug", () => { }); }); + test("should prevent re-initialization when transport switches after successful init", async () => { + // Server-side protocol with session support + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + + // Expose session methods for testing + public testGetSessionState() { + return this.getSessionState(); + } + + public testCreateSession(sessionId: string) { + return this.createSession(sessionId); + } + })(); + + const InitializeRequestSchema = z.object({ + method: z.literal("initialize"), + params: z.object({ + protocolVersion: z.string(), + capabilities: z.object({}), + clientInfo: z.object({ + name: z.string(), + version: z.string() + }) + }) + }); + + let initializeCount = 0; + serverProtocol.setRequestHandler( + InitializeRequestSchema, + async (request) => { + initializeCount++; + console.log(`Initialize handler called, count=${initializeCount}, client=${request.params.clientInfo.name}`); + // Simulate session creation on server side + const sessionId = `session-${initializeCount}`; + serverProtocol.testCreateSession(sessionId); + + return { + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test-server", version: "1.0" }, + sessionId + } as Result; + } + ); + + // First client connects and initializes + await serverProtocol.connect(transportA); + + const initFromA = { + jsonrpc: "2.0" as const, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { + name: "clientA", + version: "1.0" + } + }, + id: 1 + }; + + transportA.onmessage?.(initFromA); + + // Wait for initialization to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify session was created for transport A + expect(serverProtocol.testGetSessionState()).toBeDefined(); + expect(serverProtocol.testGetSessionState()?.sessionId).toBe("session-1"); + + // Now client B connects (transport switches) + await serverProtocol.connect(transportB); + + // Note: Session state is NOT automatically cleared when transport switches + // This could lead to session ID mismatches if the same protocol instance + // is reused with different transports + expect(serverProtocol.testGetSessionState()).toBeDefined(); + expect(serverProtocol.testGetSessionState()?.sessionId).toBe("session-1"); + + const initFromB = { + jsonrpc: "2.0" as const, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { + name: "clientB", + version: "1.0" + } + }, + id: 2 + }; + + transportB.onmessage?.(initFromB); + + // Wait for second initialization attempt + await new Promise(resolve => setTimeout(resolve, 10)); + + // The session state should remain from the first initialization + // The protocol doesn't allow re-initialization once a session exists + expect(serverProtocol.testGetSessionState()).toBeDefined(); + expect(serverProtocol.testGetSessionState()?.sessionId).toBe("session-1"); + + // Verify transport A got success response + expect(transportA.sentMessages.length).toBe(1); + expect(transportA.sentMessages[0]).toMatchObject({ + jsonrpc: "2.0", + id: 1, + result: { + sessionId: "session-1" + } + }); + + // Transport B's initialize request is rejected because it lacks a valid session ID + // The server has an active session from transport A, so requests without + // the correct session ID are rejected + expect(transportB.sentMessages.length).toBe(1); + expect(transportB.sentMessages[0]).toMatchObject({ + jsonrpc: "2.0", + id: 2, + error: expect.objectContaining({ + code: -32003, // Invalid session error code + message: "Invalid or expired session" + }) + }); + + // Verify the handler was only called once + expect(initializeCount).toBe(1); + }); + test("demonstrates the timing issue with multiple rapid connections", async () => { const delays: number[] = []; const results: { transport: string; response: JSONRPCMessage[] }[] = []; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 7df190ba1..53ca2084d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -11,6 +11,7 @@ import { JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + JSONRPCMessage, McpError, Notification, PingRequestSchema, @@ -24,6 +25,7 @@ import { RequestMeta, MessageExtraInfo, RequestInfo, + SessionId, } from "../types.js"; import { Transport, TransportSendOptions } from "./transport.js"; import { AuthInfo } from "../server/auth/types.js"; @@ -33,6 +35,27 @@ import { AuthInfo } from "../server/auth/types.js"; */ export type ProgressCallback = (progress: Progress) => void; +/** + * Session state for protocol-level session management. + */ +export interface SessionState { + sessionId: SessionId; + createdAt: number; + lastActivity: number; + timeout?: number; // seconds + callbackError?: Error; // Stores error if session callbacks fail +} + +/** + * Session configuration options. + */ +export interface SessionOptions { + sessionIdGenerator?: () => SessionId; + sessionTimeout?: number; // seconds + onsessioninitialized?: (sessionId: SessionId) => void | Promise; + onsessionclosed?: (sessionId: SessionId) => void | Promise; +} + /** * Additional initialization options. */ @@ -52,6 +75,10 @@ export type ProtocolOptions = { * e.g., ['notifications/tools/list_changed'] */ debouncedNotificationMethods?: string[]; + /** + * Session configuration options. + */ + sessions?: SessionOptions; }; /** @@ -121,9 +148,9 @@ export type RequestHandlerExtra { private _transport?: Transport; private _requestMessageId = 0; + private _sessionState?: SessionState; private _requestHandlers: Map< string, ( @@ -290,6 +318,73 @@ export abstract class Protocol< } } + // Session management methods + protected validateSessionId(messageSessionId?: SessionId): boolean { + if (!messageSessionId && !this._sessionState) return true; // Both sessionless + if (!messageSessionId || !this._sessionState) return false; // Mismatch + return messageSessionId === this._sessionState.sessionId; + } + + protected createSession(sessionId: SessionId, timeout?: number): void { + this._sessionState = { + sessionId, + createdAt: Date.now(), + lastActivity: Date.now(), + timeout + }; + // Don't reset counter when creating session - only reset on reconnect/terminate + + // Notify transport of session state for HTTP header handling + this._transport?.setSessionState?.(this._sessionState); + } + + + protected updateSessionActivity(): void { + if (this._sessionState) { + this._sessionState.lastActivity = Date.now(); + } + } + + protected isSessionExpired(): boolean { + if (!this._sessionState?.timeout) return false; + const now = Date.now(); + const expiry = this._sessionState.lastActivity + (this._sessionState.timeout * 1000); + return now > expiry; + } + + protected async terminateSession(sessionId?: SessionId): Promise { + // Validate sessionId - mismatch is internal error since sessionId should be validated on incoming message + if (sessionId && sessionId !== this._sessionState?.sessionId) { + throw new Error(`Internal error: terminateSession called with sessionId ${sessionId} but current session is ${this._sessionState?.sessionId}`); + } + + // Terminate session (same cleanup as protocol handler) + if (this._sessionState) { + this._sessionState = undefined; + this._requestMessageId = 0; // Reset counter + } + } + + + protected getSessionState() { + return this._sessionState; + } + + private sendInvalidSessionError(message: JSONRPCMessage): void { + if ('id' in message && message.id !== undefined) { + const errorResponse: JSONRPCError = { + jsonrpc: "2.0", + id: message.id, + error: { + code: ErrorCode.InvalidSession, + message: "Invalid or expired session", + data: { sessionId: 'sessionId' in message ? message.sessionId : null } + } + }; + this._transport?.send(errorResponse).catch(err => this._onerror(err)); + } + } + /** * Attaches to the given transport, starts it, and starts listening for messages. * @@ -309,9 +404,32 @@ export abstract class Protocol< this._onerror(error); }; + const _onmessage = this._transport?.onmessage; this._transport.onmessage = (message, extra) => { _onmessage?.(message, extra); + + // Only validate session for incoming requests (server-side only) + // Don't validate responses or notifications as they are outgoing from server + if (this._sessionState && isJSONRPCRequest(message)) { + // Check for session expiry BEFORE updating activity + if (this.isSessionExpired()) { + this.sendInvalidSessionError(message); + return; + } + + const messageSessionId = 'sessionId' in message ? message.sessionId : undefined; + if (!this.validateSessionId(messageSessionId)) { + // Send invalid session error + this.sendInvalidSessionError(message); + return; + } + // Only update activity if message has valid sessionId + if (messageSessionId) { + this.updateSessionActivity(); + } + } + if (isJSONRPCResponse(message) || isJSONRPCError(message)) { this._onresponse(message); } else if (isJSONRPCRequest(message)) { @@ -370,8 +488,9 @@ export abstract class Protocol< const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; - // Capture the current transport at request time to ensure responses go to the correct client + // Capture the current transport and session state at request time to ensure responses go to the correct client const capturedTransport = this._transport; + const capturedSessionState = this._sessionState ? { ...this._sessionState } : undefined; if (handler === undefined) { capturedTransport @@ -396,7 +515,7 @@ export abstract class Protocol< const fullExtra: RequestHandlerExtra = { signal: abortController.signal, - sessionId: capturedTransport?.sessionId, + sessionId: capturedSessionState?.sessionId, _meta: request.params?._meta, sendNotification: (notification) => @@ -417,11 +536,16 @@ export abstract class Protocol< return; } - return capturedTransport?.send({ - result, - jsonrpc: "2.0", + const resultWithSession = { + ...result, + ...(this._sessionState && { sessionId: this._sessionState.sessionId }), + }; + const responseMessage = { + result: resultWithSession, + jsonrpc: "2.0" as const, id: request.id, - }); + }; + return capturedTransport?.send(responseMessage); }, (error) => { if (abortController.signal.aborted) { @@ -562,8 +686,14 @@ export abstract class Protocol< options?.signal?.throwIfAborted(); const messageId = this._requestMessageId++; - const jsonrpcRequest: JSONRPCRequest = { + // Add sessionId to request if not already present and we have session state + const requestWithSession = { ...request, + ...(this._sessionState && !request.sessionId && { sessionId: this._sessionState.sessionId }), + }; + + const jsonrpcRequest: JSONRPCRequest = { + ...requestWithSession, jsonrpc: "2.0", id: messageId, }; @@ -674,8 +804,14 @@ export abstract class Protocol< return; } - const jsonrpcNotification: JSONRPCNotification = { + // Add sessionId to notification if not already present and we have session state + const notificationWithSession = { ...notification, + ...(this._sessionState && !notification.sessionId && { sessionId: this._sessionState.sessionId }), + }; + + const jsonrpcNotification: JSONRPCNotification = { + ...notificationWithSession, jsonrpc: "2.0", }; // Send the notification, but don't await it here to avoid blocking. @@ -687,8 +823,14 @@ export abstract class Protocol< return; } - const jsonrpcNotification: JSONRPCNotification = { + // Add sessionId to notification if not already present and we have session state + const notificationWithSession = { ...notification, + ...(this._sessionState && !notification.sessionId && { sessionId: this._sessionState.sessionId }), + }; + + const jsonrpcNotification: JSONRPCNotification = { + ...notificationWithSession, jsonrpc: "2.0", }; diff --git a/src/shared/transport.ts b/src/shared/transport.ts index 386b6bae5..53f1402ec 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -1,4 +1,5 @@ -import { JSONRPCMessage, MessageExtraInfo, RequestId } from "../types.js"; +import { JSONRPCMessage, MessageExtraInfo, RequestId, InitializeRequest, InitializeResult } from "../types.js"; +import { SessionOptions, SessionState } from "./protocol.js"; export type FetchLike = (url: string | URL, init?: RequestInit) => Promise; @@ -74,12 +75,38 @@ export interface Transport { onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; /** - * The session ID generated for this connection. + * The session ID for this connection (read-only). + * Available for backward compatibility only - returns the current session state's sessionId. + * Session management should be done through server session options, not transport properties. */ - sessionId?: string; + readonly sessionId?: string; + + /** + * Gets legacy session configuration for backward compatibility. + * Used by server to delegate transport-level session configuration. + */ + getLegacySessionOptions?: () => SessionOptions | undefined; + + /** + * Sets the session state reference for HTTP header handling. + * Used by server to notify transport of session creation. + */ + setSessionState?: (sessionState: SessionState) => void; /** * Sets the protocol version used for the connection (called when the initialize response is received). */ setProtocolVersion?: (version: string) => void; + + /** + * Sets a handler for synchronous initialization processing. + * Used by HTTP transport to handle initialization before sending response headers. + */ + setInitializeHandler?: (handler: (request: InitializeRequest) => Promise) => void; + + /** + * Sets a handler for synchronous session termination processing. + * Used by HTTP transport to handle termination before sending response headers. + */ + setTerminateHandler?: (handler: (sessionId?: string) => Promise) => void; } diff --git a/src/types.ts b/src/types.ts index 323e37389..6884f165f 100644 --- a/src/types.ts +++ b/src/types.ts @@ -23,6 +23,11 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); */ export const CursorSchema = z.string(); +/** + * A unique identifier for a session. + */ +export const SessionIdSchema = z.string(); + const RequestMetaSchema = z .object({ /** @@ -41,6 +46,7 @@ const BaseRequestParamsSchema = z export const RequestSchema = z.object({ method: z.string(), params: z.optional(BaseRequestParamsSchema), + sessionId: z.optional(SessionIdSchema), }); const BaseNotificationParamsSchema = z @@ -56,6 +62,7 @@ const BaseNotificationParamsSchema = z export const NotificationSchema = z.object({ method: z.string(), params: z.optional(BaseNotificationParamsSchema), + sessionId: z.optional(SessionIdSchema), }); export const ResultSchema = z @@ -65,6 +72,7 @@ export const ResultSchema = z * for notes on _meta usage. */ _meta: z.optional(z.object({}).passthrough()), + sessionId: z.optional(SessionIdSchema), }) .passthrough(); @@ -123,6 +131,9 @@ export enum ErrorCode { // SDK error codes ConnectionClosed = -32000, RequestTimeout = -32001, + + // MCP-specific error codes + InvalidSession = -32003, // Standard JSON-RPC error codes ParseError = -32700, @@ -359,6 +370,14 @@ export const InitializeResultSchema = ResultSchema.extend({ * This can be used by clients to improve the LLM's understanding of available tools, resources, etc. It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt. */ instructions: z.optional(z.string()), + /** + * Optional session identifier assigned by the server. + */ + sessionId: z.optional(SessionIdSchema), + /** + * Optional session timeout hint in seconds. + */ + sessionTimeout: z.optional(z.number().int().positive()), }); /** @@ -1352,6 +1371,15 @@ export const CompleteResultSchema = ResultSchema.extend({ .passthrough(), }); +/* Sessions */ +/** + * Request to terminate a session. + */ +export const SessionTerminateRequestSchema = RequestSchema.extend({ + method: z.literal("session/terminate"), + // No params - sessionId in request envelope +}); + /* Roots */ /** * Represents a root directory or file that the server can operate on. @@ -1400,6 +1428,7 @@ export const RootsListChangedNotificationSchema = NotificationSchema.extend({ export const ClientRequestSchema = z.union([ PingRequestSchema, InitializeRequestSchema, + SessionTerminateRequestSchema, CompleteRequestSchema, SetLevelRequestSchema, GetPromptRequestSchema, @@ -1522,6 +1551,7 @@ export type RequestMeta = Infer; export type Notification = Infer; export type Result = Infer; export type RequestId = Infer; +export type SessionId = Infer; export type JSONRPCRequest = Infer; export type JSONRPCNotification = Infer; export type JSONRPCResponse = Infer; @@ -1628,6 +1658,9 @@ export type PromptReference = Infer; export type CompleteRequest = Infer; export type CompleteResult = Infer; +/* Sessions */ +export type SessionTerminateRequest = Infer; + /* Roots */ export type Root = Infer; export type ListRootsRequest = Infer;