Skip to content

Improve testing for StreamableHttpHandler and IdleTrackingBackgroundService #345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
<PackageVersion Include="Microsoft.Extensions.Logging" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.Logging.Console" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.Options" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.TimeProvider.Testing" Version="9.4.0" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageVersion Include="Moq" Version="4.20.72" />
<PackageVersion Include="OpenTelemetry" Version="1.11.2" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public static class HttpMcpServerBuilderExtensions
public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder, Action<HttpServerTransportOptions>? configureOptions = null)
{
ArgumentNullException.ThrowIfNull(builder);

builder.Services.TryAddSingleton<StreamableHttpHandler>();
builder.Services.TryAddSingleton<SseHandler>();
builder.Services.AddHostedService<IdleTrackingBackgroundService>();
Expand Down
15 changes: 10 additions & 5 deletions src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,18 @@ public async ValueTask DisposeAsync()
}
finally
{
if (Server is not null)
try
{
await Server.DisposeAsync();
if (Server is not null)
{
await Server.DisposeAsync();
}
}
finally
{
await Transport.DisposeAsync();
_disposeCts.Dispose();
}

await Transport.DisposeAsync();
_disposeCts.Dispose();
}
}

Expand Down
12 changes: 10 additions & 2 deletions src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,17 @@ public class HttpServerTransportOptions
/// Represents the duration of time the server will wait between any active requests before timing out an
/// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will
/// receive a 404 status code and should restart their session. A client can keep their session open by
/// keeping a GET request open. The default value is set to 2 minutes.
/// keeping a GET request open. The default value is set to 2 hours.
/// </summary>
public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromMinutes(2);
public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromHours(2);

/// <summary>
/// The maximum number of idle sessions to track. This is used to limit the number of sessions that can be idle at once.
/// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached
/// their <see cref="IdleTimeout"/> until the idle session count is below this limit. Clients that keep their session open by
/// keeping a GET request open will not count towards this limit. The default value is set to 10,000 sessions.
/// </summary>
public int MaxIdleSessionCount { get; set; } = 10_000;

/// <summary>
/// Used for testing the <see cref="IdleTimeout"/>.
Expand Down
107 changes: 81 additions & 26 deletions src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,39 @@ namespace ModelContextProtocol.AspNetCore;
internal sealed partial class IdleTrackingBackgroundService(
StreamableHttpHandler handler,
IOptions<HttpServerTransportOptions> options,
IHostApplicationLifetime appLifetime,
ILogger<IdleTrackingBackgroundService> logger) : BackgroundService
{
// The compiler will complain about the parameter being unused otherwise despite the source generator.
private ILogger _logger = logger;

// We can make this configurable once we properly harden the MCP server. In the meantime, anyone running
// this should be taking a cattle not pets approach to their servers and be able to launch more processes
// to handle more than 10,000 idle sessions at a time.
private const int MaxIdleSessionCount = 10_000;

protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
var timeProvider = options.Value.TimeProvider;
using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider);
// Still run loop given infinite IdleTimeout to enforce the MaxIdleSessionCount and assist graceful shutdown.
if (options.Value.IdleTimeout != Timeout.InfiniteTimeSpan)
{
ArgumentOutOfRangeException.ThrowIfLessThan(options.Value.IdleTimeout, TimeSpan.Zero);
}
ArgumentOutOfRangeException.ThrowIfLessThan(options.Value.MaxIdleSessionCount, 0);

try
{
var timeProvider = options.Value.TimeProvider;
using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider);

var idleTimeoutTicks = options.Value.IdleTimeout.Ticks;
var maxIdleSessionCount = options.Value.MaxIdleSessionCount;

var idleSessions = new SortedSet<(string SessionId, long Timestamp)>(SessionTimestampComparer.Instance);

while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken))
{
var idleActivityCutoff = timeProvider.GetTimestamp() - options.Value.IdleTimeout.Ticks;
var idleActivityCutoff = idleTimeoutTicks switch
{
< 0 => long.MinValue,
var ticks => timeProvider.GetTimestamp() - ticks,
};

var idleCount = 0;
foreach (var (_, session) in handler.Sessions)
{
if (session.IsActive || session.SessionClosed.IsCancellationRequested)
Expand All @@ -38,34 +49,40 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
continue;
}

idleCount++;
if (idleCount == MaxIdleSessionCount)
{
// Emit critical log at most once every 5 seconds the idle count it exceeded,
//since the IdleTimeout will no longer be respected.
LogMaxSessionIdleCountExceeded();
}
else if (idleCount < MaxIdleSessionCount && session.LastActivityTicks > idleActivityCutoff)
if (session.LastActivityTicks < idleActivityCutoff)
{
RemoveAndCloseSession(session.Id);
continue;
}

if (handler.Sessions.TryRemove(session.Id, out var removedSession))
idleSessions.Add((session.Id, session.LastActivityTicks));

// Emit critical log at most once every 5 seconds the idle count it exceeded,
// since the IdleTimeout will no longer be respected.
if (idleSessions.Count == maxIdleSessionCount + 1)
{
LogSessionIdle(removedSession.Id);
LogMaxSessionIdleCountExceeded(maxIdleSessionCount);
}
}

// Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown.
_ = DisposeSessionAsync(removedSession);
if (idleSessions.Count > maxIdleSessionCount)
{
var sessionsToPrune = idleSessions.ToArray()[..^maxIdleSessionCount];
foreach (var (id, _) in sessionsToPrune)
{
RemoveAndCloseSession(id);
}
}

idleSessions.Clear();
}
}
catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested)
{
}
finally
{
if (stoppingToken.IsCancellationRequested)
try
{
List<Task> disposeSessionTasks = [];

Expand All @@ -79,7 +96,29 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)

await Task.WhenAll(disposeSessionTasks);
}
finally
{
if (!stoppingToken.IsCancellationRequested)
{
// Something went terribly wrong. A very unexpected exception must be bubbling up, but let's ensure we also stop the application,
// so that it hopefully gets looked at and restarted. This shouldn't really be reachable.
appLifetime.StopApplication();
IdleTrackingBackgroundServiceStoppedUnexpectedly();
}
}
}
}

private void RemoveAndCloseSession(string sessionId)
{
if (!handler.Sessions.TryRemove(sessionId, out var session))
{
return;
}

LogSessionIdle(session.Id);
// Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown.
_ = DisposeSessionAsync(session);
}

private async Task DisposeSessionAsync(HttpMcpSession<StreamableHttpServerTransport> session)
Expand All @@ -94,12 +133,28 @@ private async Task DisposeSessionAsync(HttpMcpSession<StreamableHttpServerTransp
}
}

private sealed class SessionTimestampComparer : IComparer<(string SessionId, long Timestamp)>
{
public static SessionTimestampComparer Instance { get; } = new();

public int Compare((string SessionId, long Timestamp) x, (string SessionId, long Timestamp) y) =>
x.Timestamp.CompareTo(y.Timestamp) switch
{
// Use a SessionId comparison as tiebreaker to ensure uniqueness in the SortedSet.
0 => string.CompareOrdinal(x.SessionId, y.SessionId),
var timestampComparison => timestampComparison,
};
}

[LoggerMessage(Level = LogLevel.Information, Message = "Closing idle session {sessionId}.")]
private partial void LogSessionIdle(string sessionId);

[LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded static maximum of 10,000 idle connections. Now clearing all inactive connections regardless of timeout.")]
private partial void LogMaxSessionIdleCountExceeded();

[LoggerMessage(Level = LogLevel.Error, Message = "Error disposing the IMcpServer for session {sessionId}.")]
[LoggerMessage(Level = LogLevel.Error, Message = "Error disposing session {sessionId}.")]
private partial void LogSessionDisposeError(string sessionId, Exception ex);

[LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded maximum of {maxIdleSessionCount} idle sessions. Now closing sessions active more recently than configured IdleTimeout.")]
private partial void LogMaxSessionIdleCountExceeded(int maxIdleSessionCount);

[LoggerMessage(Level = LogLevel.Critical, Message = "The IdleTrackingBackgroundService has stopped unexpectedly.")]
private partial void IdleTrackingBackgroundServiceStoppedUnexpectedly();
}
40 changes: 19 additions & 21 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Net.Http.Headers;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
Expand All @@ -23,18 +24,19 @@ internal sealed class StreamableHttpHandler(
IServiceProvider applicationServices)
{
private static JsonTypeInfo<JsonRpcError> s_errorTypeInfo = GetRequiredJsonTypeInfo<JsonRpcError>();
private static MediaTypeHeaderValue ApplicationJsonMediaType = new("application/json");
private static MediaTypeHeaderValue TextEventStreamMediaType = new("text/event-stream");

public ConcurrentDictionary<string, HttpMcpSession<StreamableHttpServerTransport>> Sessions { get; } = new(StringComparer.Ordinal);

public async Task HandlePostRequestAsync(HttpContext context)
{
// The Streamable HTTP spec mandates the client MUST accept both application/json and text/event-stream.
// ASP.NET Core Minimal APIs mostly ry to stay out of the business of response content negotiation, so
// we have to do this manually. The spec doesn't mandate that servers MUST reject these requests, but it's
// probably good to at least start out trying to be strict.
var acceptHeader = context.Request.Headers.Accept.ToString();
if (!acceptHeader.Contains("application/json", StringComparison.Ordinal) ||
!acceptHeader.Contains("text/event-stream", StringComparison.Ordinal))
// ASP.NET Core Minimal APIs mostly try to stay out of the business of response content negotiation,
// so we have to do this manually. The spec doesn't mandate that servers MUST reject these requests,
// but it's probably good to at least start out trying to be strict.
var acceptHeaders = context.Request.GetTypedHeaders().Accept;
if (!acceptHeaders.Contains(ApplicationJsonMediaType) || !acceptHeaders.Contains(TextEventStreamMediaType))
{
await WriteJsonRpcErrorAsync(context,
"Not Acceptable: Client must accept both application/json and text/event-stream",
Expand All @@ -49,9 +51,8 @@ await WriteJsonRpcErrorAsync(context,
}

using var _ = session.AcquireReference();
using var cts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, session.SessionClosed);
InitializeSseResponse(context);
var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), cts.Token);
var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted);
if (!wroteResponse)
{
// We wound up writing nothing, so there should be no Content-Type response header.
Expand All @@ -62,8 +63,8 @@ await WriteJsonRpcErrorAsync(context,

public async Task HandleGetRequestAsync(HttpContext context)
{
var acceptHeader = context.Request.Headers.Accept.ToString();
if (!acceptHeader.Contains("application/json", StringComparison.Ordinal))
var acceptHeaders = context.Request.GetTypedHeaders().Accept;
if (!acceptHeaders.Contains(TextEventStreamMediaType))
{
await WriteJsonRpcErrorAsync(context,
"Not Acceptable: Client must accept text/event-stream",
Expand Down Expand Up @@ -105,12 +106,6 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
}
}

private void InitializeSessionResponse(HttpContext context, HttpMcpSession<StreamableHttpServerTransport> session)
{
context.Response.Headers["mcp-session-id"] = session.Id;
context.Features.Set(session.Server);
}

private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>?> GetSessionAsync(HttpContext context, string sessionId)
{
if (Sessions.TryGetValue(sessionId, out var existingSession))
Expand All @@ -123,7 +118,8 @@ await WriteJsonRpcErrorAsync(context,
return null;
}

InitializeSessionResponse(context, existingSession);
context.Response.Headers["mcp-session-id"] = existingSession.Id;
context.Features.Set(existingSession.Server);
return existingSession;
}

Expand All @@ -138,11 +134,10 @@ await WriteJsonRpcErrorAsync(context,
private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>?> GetOrCreateSessionAsync(HttpContext context)
{
var sessionId = context.Request.Headers["mcp-session-id"].ToString();
HttpMcpSession<StreamableHttpServerTransport>? session;

if (string.IsNullOrEmpty(sessionId))
{
session = await CreateSessionAsync(context);
var session = await CreateSessionAsync(context);

if (!Sessions.TryAdd(session.Id, session))
{
Expand All @@ -159,6 +154,9 @@ await WriteJsonRpcErrorAsync(context,

private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> CreateSessionAsync(HttpContext context)
{
var sessionId = MakeNewSessionId();
context.Response.Headers["mcp-session-id"] = sessionId;

var mcpServerOptions = mcpServerOptionsSnapshot.Value;
if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions)
{
Expand All @@ -169,16 +167,16 @@ private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> CreateSes
var transport = new StreamableHttpServerTransport();
// Use application instead of request services, because the session will likely outlive the first initialization request.
var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, applicationServices);
context.Features.Set(server);

var session = new HttpMcpSession<StreamableHttpServerTransport>(MakeNewSessionId(), transport, context.User, httpMcpServerOptions.Value.TimeProvider)
var session = new HttpMcpSession<StreamableHttpServerTransport>(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider)
{
Server = server,
};

var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync;
session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed);

InitializeSessionResponse(context, session);
return session;
}

Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Protocol/Transport/SseWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
{
Throw.IfNull(message);

using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);
using var _ = await _disposeLock.LockAsync(cancellationToken).ConfigureAwait(false);

if (_disposed)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c
throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
}

using var getCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken);
await _sseWriter.WriteAllAsync(sseResponseStream, getCts.Token).ConfigureAwait(false);
// We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully.
await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down
32 changes: 32 additions & 0 deletions tests/Common/Utils/MockLoggerProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Microsoft.Extensions.Logging;
using System.Collections.Concurrent;

namespace ModelContextProtocol.Tests.Utils;

public class MockLoggerProvider() : ILoggerProvider
{
public ConcurrentQueue<(string Category, LogLevel LogLevel, string Message, Exception? Exception)> LogMessages { get; } = [];

public ILogger CreateLogger(string categoryName)
{
return new MockLogger(this, categoryName);
}

public void Dispose()
{
}

private class MockLogger(MockLoggerProvider mockProvider, string category) : ILogger
{
public void Log<TState>(
LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter)
{
mockProvider.LogMessages.Enqueue((category, logLevel, formatter(state, exception), exception));
}

public bool IsEnabled(LogLevel logLevel) => true;

// The MockLoggerProvider is a convenient NoopDisposable
public IDisposable BeginScope<TState>(TState state) where TState : notnull => mockProvider;
}
}
Loading
Loading