Skip to content

Commit b067261

Browse files
authored
Prune idle sessions before starting new ones (#701)
1 parent 70960e1 commit b067261

File tree

16 files changed

+472
-281
lines changed

16 files changed

+472
-281
lines changed

samples/ProtectedMcpServer/Tools/HttpClientExt.cs

Lines changed: 0 additions & 13 deletions
This file was deleted.

samples/ProtectedMcpServer/Tools/WeatherTools.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ public async Task<string> GetAlerts(
2121
[Description("The US state to get alerts for. Use the 2 letter abbreviation for the state (e.g. NY).")] string state)
2222
{
2323
var client = _httpClientFactory.CreateClient("WeatherApi");
24-
using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}");
25-
var jsonElement = jsonDocument.RootElement;
26-
var alerts = jsonElement.GetProperty("features").EnumerateArray();
24+
using var jsonDocument = await client.GetFromJsonAsync<JsonDocument>($"/alerts/active/area/{state}")
25+
?? throw new McpException("No JSON returned from alerts endpoint");
26+
27+
var alerts = jsonDocument.RootElement.GetProperty("features").EnumerateArray();
2728

2829
if (!alerts.Any())
2930
{
@@ -50,12 +51,14 @@ public async Task<string> GetForecast(
5051
{
5152
var client = _httpClientFactory.CreateClient("WeatherApi");
5253
var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}");
53-
using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl);
54-
var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString()
55-
?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}");
5654

57-
using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl);
58-
var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray();
55+
using var locationDocument = await client.GetFromJsonAsync<JsonDocument>(pointUrl);
56+
var forecastUrl = locationDocument?.RootElement.GetProperty("properties").GetProperty("forecast").GetString()
57+
?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}");
58+
59+
using var forecastDocument = await client.GetFromJsonAsync<JsonDocument>(forecastUrl);
60+
var periods = forecastDocument?.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray()
61+
?? throw new McpException("No JSON returned from forecast endpoint");
5962

6063
return string.Join("\n---\n", periods.Select(period => $"""
6164
{period.GetProperty("name").GetString()}

samples/QuickstartWeatherServer/Tools/WeatherTools.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ public static async Task<string> GetForecast(
4343
[Description("Longitude of the location.")] double longitude)
4444
{
4545
var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}");
46-
using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl);
47-
var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString()
48-
?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}");
46+
using var locationDocument = await client.ReadJsonDocumentAsync(pointUrl);
47+
var forecastUrl = locationDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString()
48+
?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}");
4949

5050
using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl);
5151
var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray();

src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder
2323
{
2424
ArgumentNullException.ThrowIfNull(builder);
2525

26+
builder.Services.TryAddSingleton<StatefulSessionManager>();
2627
builder.Services.TryAddSingleton<StreamableHttpHandler>();
2728
builder.Services.TryAddSingleton<SseHandler>();
2829
builder.Services.AddHostedService<IdleTrackingBackgroundService>();

src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs

Lines changed: 0 additions & 85 deletions
This file was deleted.

src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ public class HttpServerTransportOptions
6666
/// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached
6767
/// their <see cref="IdleTimeout"/> until the idle session count is below this limit. Clients that keep their session open by
6868
/// keeping a GET request open will not count towards this limit.
69-
/// Defaults to 100,000 sessions.
69+
/// Defaults to 10,000 sessions.
7070
/// </remarks>
71-
public int MaxIdleSessionCount { get; set; } = 100_000;
71+
public int MaxIdleSessionCount { get; set; } = 10_000;
7272

7373
/// <summary>
7474
/// Used for testing the <see cref="IdleTimeout"/>.
Lines changed: 5 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
1-
using System.Runtime.InteropServices;
2-
using Microsoft.Extensions.Hosting;
1+
using Microsoft.Extensions.Hosting;
32
using Microsoft.Extensions.Logging;
43
using Microsoft.Extensions.Options;
5-
using ModelContextProtocol.Server;
64

75
namespace ModelContextProtocol.AspNetCore;
86

97
internal sealed partial class IdleTrackingBackgroundService(
10-
StreamableHttpHandler handler,
8+
StatefulSessionManager sessions,
119
IOptions<HttpServerTransportOptions> options,
1210
IHostApplicationLifetime appLifetime,
1311
ILogger<IdleTrackingBackgroundService> logger) : BackgroundService
1412
{
15-
// The compiler will complain about the parameter being unused otherwise despite the source generator.
13+
// Workaround for https://github.com/dotnet/runtime/issues/91121. This is fixed in .NET 9 and later.
1614
private readonly ILogger _logger = logger;
1715

1816
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
@@ -30,65 +28,9 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
3028
var timeProvider = options.Value.TimeProvider;
3129
using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider);
3230

33-
var idleTimeoutTicks = options.Value.IdleTimeout.Ticks;
34-
var maxIdleSessionCount = options.Value.MaxIdleSessionCount;
35-
36-
// Create two lists that will be reused between runs.
37-
// This assumes that the number of idle sessions is not breached frequently.
38-
// If the idle sessions often breach the maximum, a priority queue could be considered.
39-
var idleSessionsTimestamps = new List<long>();
40-
var idleSessionSessionIds = new List<string>();
41-
4231
while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken))
4332
{
44-
var idleActivityCutoff = idleTimeoutTicks switch
45-
{
46-
< 0 => long.MinValue,
47-
var ticks => timeProvider.GetTimestamp() - ticks,
48-
};
49-
50-
foreach (var (_, session) in handler.Sessions)
51-
{
52-
if (session.IsActive || session.SessionClosed.IsCancellationRequested)
53-
{
54-
// There's a request currently active or the session is already being closed.
55-
continue;
56-
}
57-
58-
if (session.LastActivityTicks < idleActivityCutoff)
59-
{
60-
RemoveAndCloseSession(session.Id);
61-
continue;
62-
}
63-
64-
// Add the timestamp and the session
65-
idleSessionsTimestamps.Add(session.LastActivityTicks);
66-
idleSessionSessionIds.Add(session.Id);
67-
68-
// Emit critical log at most once every 5 seconds the idle count it exceeded,
69-
// since the IdleTimeout will no longer be respected.
70-
if (idleSessionsTimestamps.Count == maxIdleSessionCount + 1)
71-
{
72-
LogMaxSessionIdleCountExceeded(maxIdleSessionCount);
73-
}
74-
}
75-
76-
if (idleSessionsTimestamps.Count > maxIdleSessionCount)
77-
{
78-
var timestamps = CollectionsMarshal.AsSpan(idleSessionsTimestamps);
79-
80-
// Sort only if the maximum is breached and sort solely by the timestamp. Sort both collections.
81-
timestamps.Sort(CollectionsMarshal.AsSpan(idleSessionSessionIds));
82-
83-
var sessionsToPrune = CollectionsMarshal.AsSpan(idleSessionSessionIds)[..^maxIdleSessionCount];
84-
foreach (var id in sessionsToPrune)
85-
{
86-
RemoveAndCloseSession(id);
87-
}
88-
}
89-
90-
idleSessionsTimestamps.Clear();
91-
idleSessionSessionIds.Clear();
33+
await sessions.PruneIdleSessionsAsync(stoppingToken);
9234
}
9335
}
9436
catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested)
@@ -98,17 +40,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
9840
{
9941
try
10042
{
101-
List<Task> disposeSessionTasks = [];
102-
103-
foreach (var (sessionKey, _) in handler.Sessions)
104-
{
105-
if (handler.Sessions.TryRemove(sessionKey, out var session))
106-
{
107-
disposeSessionTasks.Add(DisposeSessionAsync(session));
108-
}
109-
}
110-
111-
await Task.WhenAll(disposeSessionTasks);
43+
await sessions.DisposeAllSessionsAsync();
11244
}
11345
finally
11446
{
@@ -123,39 +55,6 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
12355
}
12456
}
12557

126-
private void RemoveAndCloseSession(string sessionId)
127-
{
128-
if (!handler.Sessions.TryRemove(sessionId, out var session))
129-
{
130-
return;
131-
}
132-
133-
LogSessionIdle(session.Id);
134-
// Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown.
135-
_ = DisposeSessionAsync(session);
136-
}
137-
138-
private async Task DisposeSessionAsync(HttpMcpSession<StreamableHttpServerTransport> session)
139-
{
140-
try
141-
{
142-
await session.DisposeAsync();
143-
}
144-
catch (Exception ex)
145-
{
146-
LogSessionDisposeError(session.Id, ex);
147-
}
148-
}
149-
150-
[LoggerMessage(Level = LogLevel.Information, Message = "Closing idle session {sessionId}.")]
151-
private partial void LogSessionIdle(string sessionId);
152-
153-
[LoggerMessage(Level = LogLevel.Error, Message = "Error disposing session {sessionId}.")]
154-
private partial void LogSessionDisposeError(string sessionId, Exception ex);
155-
156-
[LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded maximum of {maxIdleSessionCount} idle sessions. Now closing sessions active more recently than configured IdleTimeout.")]
157-
private partial void LogMaxSessionIdleCountExceeded(int maxIdleSessionCount);
158-
15958
[LoggerMessage(Level = LogLevel.Critical, Message = "The IdleTrackingBackgroundService has stopped unexpectedly.")]
16059
private partial void IdleTrackingBackgroundServiceStoppedUnexpectedly();
16160
}

src/ModelContextProtocol.AspNetCore/SseHandler.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ internal sealed class SseHandler(
1616
IHostApplicationLifetime hostApplicationLifetime,
1717
ILoggerFactory loggerFactory)
1818
{
19-
private readonly ConcurrentDictionary<string, HttpMcpSession<SseResponseStreamTransport>> _sessions = new(StringComparer.Ordinal);
19+
private readonly ConcurrentDictionary<string, SseSession> _sessions = new(StringComparer.Ordinal);
2020

2121
public async Task HandleSseRequestAsync(HttpContext context)
2222
{
@@ -34,9 +34,9 @@ public async Task HandleSseRequestAsync(HttpContext context)
3434
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId);
3535

3636
var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User);
37-
await using var httpMcpSession = new HttpMcpSession<SseResponseStreamTransport>(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider);
37+
var sseSession = new SseSession(transport, userIdClaim);
3838

39-
if (!_sessions.TryAdd(sessionId, httpMcpSession))
39+
if (!_sessions.TryAdd(sessionId, sseSession))
4040
{
4141
throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
4242
}
@@ -55,12 +55,10 @@ public async Task HandleSseRequestAsync(HttpContext context)
5555
try
5656
{
5757
await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices);
58-
httpMcpSession.Server = mcpServer;
5958
context.Features.Set(mcpServer);
6059

6160
var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? StreamableHttpHandler.RunSessionAsync;
62-
httpMcpSession.ServerRunTask = runSessionAsync(context, mcpServer, cancellationToken);
63-
await httpMcpSession.ServerRunTask;
61+
await runSessionAsync(context, mcpServer, cancellationToken);
6462
}
6563
finally
6664
{
@@ -87,13 +85,13 @@ public async Task HandleMessageRequestAsync(HttpContext context)
8785
return;
8886
}
8987

90-
if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession))
88+
if (!_sessions.TryGetValue(sessionId.ToString(), out var sseSession))
9189
{
9290
await Results.BadRequest($"Session ID not found.").ExecuteAsync(context);
9391
return;
9492
}
9593

96-
if (!httpMcpSession.HasSameUserId(context.User))
94+
if (sseSession.UserId != StreamableHttpHandler.GetUserIdClaim(context.User))
9795
{
9896
await Results.Forbid().ExecuteAsync(context);
9997
return;
@@ -106,8 +104,10 @@ public async Task HandleMessageRequestAsync(HttpContext context)
106104
return;
107105
}
108106

109-
await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted);
107+
await sseSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted);
110108
context.Response.StatusCode = StatusCodes.Status202Accepted;
111109
await context.Response.WriteAsync("Accepted");
112110
}
111+
112+
private record SseSession(SseResponseStreamTransport Transport, UserIdClaim? UserId);
113113
}

0 commit comments

Comments
 (0)