From 2ef8821f0087f070aef2bbb82e9cbb63d1c8ea21 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Wed, 20 Aug 2025 16:49:45 -0400 Subject: [PATCH] Remove `McpEndpoint` & `IMcpClient` --- README.md | 4 +- samples/InMemoryTransport/Program.cs | 2 +- .../AssemblyNameHelper.cs | 9 + .../Client/IClientTransport.cs | 4 +- .../Client/IMcpClient.cs | 47 --- .../Client/McpClient.cs | 236 --------------- .../Client/McpClientExtensions.cs | 48 ++-- .../Client/McpClientFactory.cs | 28 +- .../Client/McpClientOptions.cs | 2 +- .../Client/McpClientPrompt.cs | 4 +- .../Client/McpClientResource.cs | 6 +- .../Client/McpClientResourceTemplate.cs | 4 +- .../Client/McpClientSession.cs | 272 ++++++++++++++++++ .../Client/McpClientTool.cs | 10 +- src/ModelContextProtocol.Core/IMcpEndpoint.cs | 2 +- src/ModelContextProtocol.Core/McpEndpoint.cs | 144 ---------- .../McpEndpointExtensions.cs | 4 +- .../{McpSession.cs => McpSessionHandler.cs} | 50 +++- src/ModelContextProtocol.Core/README.md | 4 +- .../Server/DestinationBoundMcpServer.cs | 3 +- .../Server/McpServerExtensions.cs | 4 +- .../Server/McpServerFactory.cs | 2 +- .../{McpServer.cs => McpServerSession.cs} | 70 +++-- .../Server/StdioServerTransport.cs | 2 +- .../TokenProgress.cs | 6 +- .../HttpServerIntegrationTests.cs | 2 +- .../MapMcpTests.cs | 2 +- .../SseIntegrationTests.cs | 2 +- .../SseServerIntegrationTestFixture.cs | 2 +- .../StatelessServerTests.cs | 2 +- .../Client/McpClientExtensionsTests.cs | 22 +- .../Client/McpClientFactoryTests.cs | 4 +- .../Client/McpClientResourceTemplateTests.cs | 2 +- .../ClientIntegrationTestFixture.cs | 2 +- .../ClientServerTestBase.cs | 2 +- .../McpServerBuilderExtensionsPromptsTests.cs | 12 +- ...cpServerBuilderExtensionsResourcesTests.cs | 12 +- .../McpServerBuilderExtensionsToolsTests.cs | 32 +-- .../Configuration/McpServerScopedTests.cs | 2 +- .../DiagnosticTests.cs | 4 +- .../Protocol/ElicitationTests.cs | 2 +- .../Protocol/NotificationHandlerTests.cs | 10 +- 42 files changed, 497 insertions(+), 586 deletions(-) create mode 100644 src/ModelContextProtocol.Core/AssemblyNameHelper.cs delete mode 100644 src/ModelContextProtocol.Core/Client/IMcpClient.cs delete mode 100644 src/ModelContextProtocol.Core/Client/McpClient.cs create mode 100644 src/ModelContextProtocol.Core/Client/McpClientSession.cs delete mode 100644 src/ModelContextProtocol.Core/McpEndpoint.cs rename src/ModelContextProtocol.Core/{McpSession.cs => McpSessionHandler.cs} (95%) rename src/ModelContextProtocol.Core/Server/{McpServer.cs => McpServerSession.cs} (90%) diff --git a/README.md b/README.md index 163d57f8..5c9cfb42 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,8 @@ dotnet add package ModelContextProtocol --prerelease ## Getting Started (Client) -To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `IMcpClient` -to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. +To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `McpClientSession` +to a server. Once you have an `McpClientSession`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp var clientTransport = new StdioClientTransport(new StdioClientTransportOptions diff --git a/samples/InMemoryTransport/Program.cs b/samples/InMemoryTransport/Program.cs index 67e2d320..5aa00390 100644 --- a/samples/InMemoryTransport/Program.cs +++ b/samples/InMemoryTransport/Program.cs @@ -21,7 +21,7 @@ _ = server.RunAsync(); // Connect a client using a stream-based transport over the same in-memory pipe. -await using IMcpClient client = await McpClientFactory.CreateAsync( +await using McpClientSession client = await McpClientFactory.CreateAsync( new StreamClientTransport(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream())); // List all tools. diff --git a/src/ModelContextProtocol.Core/AssemblyNameHelper.cs b/src/ModelContextProtocol.Core/AssemblyNameHelper.cs new file mode 100644 index 00000000..292ed2f9 --- /dev/null +++ b/src/ModelContextProtocol.Core/AssemblyNameHelper.cs @@ -0,0 +1,9 @@ +using System.Reflection; + +namespace ModelContextProtocol; + +internal static class AssemblyNameHelper +{ + /// Cached naming information used for MCP session name/version when none is specified. + public static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); +} diff --git a/src/ModelContextProtocol.Core/Client/IClientTransport.cs b/src/ModelContextProtocol.Core/Client/IClientTransport.cs index 52517895..9cffa09d 100644 --- a/src/ModelContextProtocol.Core/Client/IClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/IClientTransport.cs @@ -11,7 +11,7 @@ namespace ModelContextProtocol.Client; /// and servers, allowing different transport protocols to be used interchangeably. /// /// -/// When creating an , is typically used, and is +/// When creating an , is typically used, and is /// provided with the based on expected server configuration. /// /// @@ -35,7 +35,7 @@ public interface IClientTransport /// /// /// The lifetime of the returned instance is typically managed by the - /// that uses this transport. When the client is disposed, it will dispose + /// that uses this transport. When the client is disposed, it will dispose /// the transport session as well. /// /// diff --git a/src/ModelContextProtocol.Core/Client/IMcpClient.cs b/src/ModelContextProtocol.Core/Client/IMcpClient.cs deleted file mode 100644 index 68a92a2d..00000000 --- a/src/ModelContextProtocol.Core/Client/IMcpClient.cs +++ /dev/null @@ -1,47 +0,0 @@ -using ModelContextProtocol.Protocol; - -namespace ModelContextProtocol.Client; - -/// -/// Represents an instance of a Model Context Protocol (MCP) client that connects to and communicates with an MCP server. -/// -public interface IMcpClient : IMcpEndpoint -{ - /// - /// Gets the capabilities supported by the connected server. - /// - /// The client is not connected. - ServerCapabilities ServerCapabilities { get; } - - /// - /// Gets the implementation information of the connected server. - /// - /// - /// - /// This property provides identification details about the connected server, including its name and version. - /// It is populated during the initialization handshake and is available after a successful connection. - /// - /// - /// This information can be useful for logging, debugging, compatibility checks, and displaying server - /// information to users. - /// - /// - /// The client is not connected. - Implementation ServerInfo { get; } - - /// - /// Gets any instructions describing how to use the connected server and its features. - /// - /// - /// - /// This property contains instructions provided by the server during initialization that explain - /// how to effectively use its capabilities. These instructions can include details about available - /// tools, expected input formats, limitations, or any other helpful information. - /// - /// - /// This can be used by clients to improve an LLM's understanding of available tools, prompts, and resources. - /// It can be thought of like a "hint" to the model and may be added to a system prompt. - /// - /// - string? ServerInstructions { get; } -} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs deleted file mode 100644 index dd8c7fe0..00000000 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ /dev/null @@ -1,236 +0,0 @@ -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Client; - -/// -internal sealed partial class McpClient : McpEndpoint, IMcpClient -{ - private static Implementation DefaultImplementation { get; } = new() - { - Name = DefaultAssemblyName.Name ?? nameof(McpClient), - Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", - }; - - private readonly IClientTransport _clientTransport; - private readonly McpClientOptions _options; - - private ITransport? _sessionTransport; - private CancellationTokenSource? _connectCts; - - private ServerCapabilities? _serverCapabilities; - private Implementation? _serverInfo; - private string? _serverInstructions; - - /// - /// Initializes a new instance of the class. - /// - /// The transport to use for communication with the server. - /// Options for the client, defining protocol version and capabilities. - /// The logger factory. - public McpClient(IClientTransport clientTransport, McpClientOptions? options, ILoggerFactory? loggerFactory) - : base(loggerFactory) - { - options ??= new(); - - _clientTransport = clientTransport; - _options = options; - - EndpointName = clientTransport.Name; - - if (options.Capabilities is { } capabilities) - { - if (capabilities.NotificationHandlers is { } notificationHandlers) - { - NotificationHandlers.RegisterRange(notificationHandlers); - } - - if (capabilities.Sampling is { } samplingCapability) - { - if (samplingCapability.SamplingHandler is not { } samplingHandler) - { - throw new InvalidOperationException("Sampling capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.SamplingCreateMessage, - (request, _, cancellationToken) => samplingHandler( - request, - request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult); - } - - if (capabilities.Roots is { } rootsCapability) - { - if (rootsCapability.RootsHandler is not { } rootsHandler) - { - throw new InvalidOperationException("Roots capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.RootsList, - (request, _, cancellationToken) => rootsHandler(request, cancellationToken), - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult); - } - - if (capabilities.Elicitation is { } elicitationCapability) - { - if (elicitationCapability.ElicitationHandler is not { } elicitationHandler) - { - throw new InvalidOperationException("Elicitation capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.ElicitationCreate, - (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult); - } - } - } - - /// - public string? SessionId - { - get - { - if (_sessionTransport is null) - { - throw new InvalidOperationException("Must have already initialized a session when invoking this property."); - } - - return _sessionTransport.SessionId; - } - } - - /// - public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); - - /// - public Implementation ServerInfo => _serverInfo ?? throw new InvalidOperationException("The client is not connected."); - - /// - public string? ServerInstructions => _serverInstructions; - - /// - public override string EndpointName { get; } - - /// - /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. - /// - public async Task ConnectAsync(CancellationToken cancellationToken = default) - { - _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - cancellationToken = _connectCts.Token; - - try - { - // Connect transport - _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - InitializeSession(_sessionTransport); - // We don't want the ConnectAsync token to cancel the session after we've successfully connected. - // The base class handles cleaning up the session in DisposeAsync without our help. - StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); - - // Perform initialization sequence - using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - initializationCts.CancelAfter(_options.InitializationTimeout); - - try - { - // Send initialize request - string requestProtocol = _options.ProtocolVersion ?? McpSession.LatestProtocolVersion; - var initializeResponse = await this.SendRequestAsync( - RequestMethods.Initialize, - new InitializeRequestParams - { - ProtocolVersion = requestProtocol, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo ?? DefaultImplementation, - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult, - cancellationToken: initializationCts.Token).ConfigureAwait(false); - - // Store server information - if (_logger.IsEnabled(LogLevel.Information)) - { - LogServerCapabilitiesReceived(EndpointName, - capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), - serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); - } - - _serverCapabilities = initializeResponse.Capabilities; - _serverInfo = initializeResponse.ServerInfo; - _serverInstructions = initializeResponse.Instructions; - - // Validate protocol version - bool isResponseProtocolValid = - _options.ProtocolVersion is { } optionsProtocol ? optionsProtocol == initializeResponse.ProtocolVersion : - McpSession.SupportedProtocolVersions.Contains(initializeResponse.ProtocolVersion); - if (!isResponseProtocolValid) - { - LogServerProtocolVersionMismatch(EndpointName, requestProtocol, initializeResponse.ProtocolVersion); - throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); - } - - // Send initialized notification - await this.SendNotificationAsync( - NotificationMethods.InitializedNotification, - new InitializedNotificationParams(), - McpJsonUtilities.JsonContext.Default.InitializedNotificationParams, - cancellationToken: initializationCts.Token).ConfigureAwait(false); - - } - catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) - { - LogClientInitializationTimeout(EndpointName); - throw new TimeoutException("Initialization timed out", oce); - } - } - catch (Exception e) - { - LogClientInitializationError(EndpointName, e); - await DisposeAsync().ConfigureAwait(false); - throw; - } - } - - /// - public override async ValueTask DisposeUnsynchronizedAsync() - { - try - { - if (_connectCts is not null) - { - await _connectCts.CancelAsync().ConfigureAwait(false); - _connectCts.Dispose(); - } - - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - finally - { - if (_sessionTransport is not null) - { - await _sessionTransport.DisposeAsync().ConfigureAwait(false); - } - } - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] - private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")] - private partial void LogClientInitializationError(string endpointName, Exception exception); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")] - private partial void LogClientInitializationTimeout(string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")] - private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received); -} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs index 60a9c3a6..4e082b07 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs @@ -8,7 +8,7 @@ namespace ModelContextProtocol.Client; /// -/// Provides extension methods for interacting with an . +/// Provides extension methods for interacting with an . /// /// /// @@ -38,7 +38,7 @@ public static class McpClientExtensions /// /// is . /// Thrown when the server cannot be reached or returns an error response. - public static Task PingAsync(this IMcpClient client, CancellationToken cancellationToken = default) + public static Task PingAsync(this McpClientSession client, CancellationToken cancellationToken = default) { Throw.IfNull(client); @@ -90,7 +90,7 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat /// /// is . public static async ValueTask> ListToolsAsync( - this IMcpClient client, + this McpClientSession client, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) { @@ -156,7 +156,7 @@ public static async ValueTask> ListToolsAsync( /// /// is . public static async IAsyncEnumerable EnumerateToolsAsync( - this IMcpClient client, + this McpClientSession client, JsonSerializerOptions? serializerOptions = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -203,7 +203,7 @@ public static async IAsyncEnumerable EnumerateToolsAsync( /// /// is . public static async ValueTask> ListPromptsAsync( - this IMcpClient client, CancellationToken cancellationToken = default) + this McpClientSession client, CancellationToken cancellationToken = default) { Throw.IfNull(client); @@ -259,7 +259,7 @@ public static async ValueTask> ListPromptsAsync( /// /// is . public static async IAsyncEnumerable EnumeratePromptsAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) + this McpClientSession client, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Throw.IfNull(client); @@ -309,7 +309,7 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( /// Thrown when the prompt does not exist, when required arguments are missing, or when the server encounters an error processing the prompt. /// is . public static ValueTask GetPromptAsync( - this IMcpClient client, + this McpClientSession client, string name, IReadOnlyDictionary? arguments = null, JsonSerializerOptions? serializerOptions = null, @@ -347,7 +347,7 @@ public static ValueTask GetPromptAsync( /// /// is . public static async ValueTask> ListResourceTemplatesAsync( - this IMcpClient client, CancellationToken cancellationToken = default) + this McpClientSession client, CancellationToken cancellationToken = default) { Throw.IfNull(client); @@ -404,7 +404,7 @@ public static async ValueTask> ListResourceTemp /// /// is . public static async IAsyncEnumerable EnumerateResourceTemplatesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) + this McpClientSession client, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Throw.IfNull(client); @@ -458,7 +458,7 @@ public static async IAsyncEnumerable EnumerateResourc /// /// is . public static async ValueTask> ListResourcesAsync( - this IMcpClient client, CancellationToken cancellationToken = default) + this McpClientSession client, CancellationToken cancellationToken = default) { Throw.IfNull(client); @@ -515,7 +515,7 @@ public static async ValueTask> ListResourcesAsync( /// /// is . public static async IAsyncEnumerable EnumerateResourcesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) + this McpClientSession client, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Throw.IfNull(client); @@ -549,7 +549,7 @@ public static async IAsyncEnumerable EnumerateResourcesAsync( /// is . /// is empty or composed entirely of whitespace. public static ValueTask ReadResourceAsync( - this IMcpClient client, string uri, CancellationToken cancellationToken = default) + this McpClientSession client, string uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNullOrWhiteSpace(uri); @@ -571,7 +571,7 @@ public static ValueTask ReadResourceAsync( /// is . /// is . public static ValueTask ReadResourceAsync( - this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) + this McpClientSession client, Uri uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNull(uri); @@ -590,7 +590,7 @@ public static ValueTask ReadResourceAsync( /// is . /// is empty or composed entirely of whitespace. public static ValueTask ReadResourceAsync( - this IMcpClient client, string uriTemplate, IReadOnlyDictionary arguments, CancellationToken cancellationToken = default) + this McpClientSession client, string uriTemplate, IReadOnlyDictionary arguments, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNullOrWhiteSpace(uriTemplate); @@ -633,7 +633,7 @@ public static ValueTask ReadResourceAsync( /// is . /// is empty or composed entirely of whitespace. /// The server returned an error response. - public static ValueTask CompleteAsync(this IMcpClient client, Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) + public static ValueTask CompleteAsync(this McpClientSession client, Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNull(reference); @@ -676,7 +676,7 @@ public static ValueTask CompleteAsync(this IMcpClient client, Re /// is . /// is . /// is empty or composed entirely of whitespace. - public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) + public static Task SubscribeToResourceAsync(this McpClientSession client, string uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNullOrWhiteSpace(uri); @@ -713,7 +713,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, /// /// is . /// is . - public static Task SubscribeToResourceAsync(this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) + public static Task SubscribeToResourceAsync(this McpClientSession client, Uri uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNull(uri); @@ -745,7 +745,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, Uri uri, Can /// is . /// is . /// is empty or composed entirely of whitespace. - public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) + public static Task UnsubscribeFromResourceAsync(this McpClientSession client, string uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNullOrWhiteSpace(uri); @@ -781,7 +781,7 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u /// /// is . /// is . - public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) + public static Task UnsubscribeFromResourceAsync(this McpClientSession client, Uri uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNull(uri); @@ -825,7 +825,7 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, /// /// public static ValueTask CallToolAsync( - this IMcpClient client, + this McpClientSession client, string toolName, IReadOnlyDictionary? arguments = null, IProgress? progress = null, @@ -854,7 +854,7 @@ public static ValueTask CallToolAsync( cancellationToken: cancellationToken); static async ValueTask SendRequestWithProgressAsync( - IMcpClient client, + McpClientSession client, string toolName, IReadOnlyDictionary? arguments, IProgress progress, @@ -1035,7 +1035,7 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat /// /// /// is . - public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, CancellationToken cancellationToken = default) + public static Task SetLoggingLevel(this McpClientSession client, LoggingLevel level, CancellationToken cancellationToken = default) { Throw.IfNull(client); @@ -1070,8 +1070,8 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C /// /// /// is . - public static Task SetLoggingLevel(this IMcpClient client, LogLevel level, CancellationToken cancellationToken = default) => - SetLoggingLevel(client, McpServer.ToLoggingLevel(level), cancellationToken); + public static Task SetLoggingLevel(this McpClientSession client, LogLevel level, CancellationToken cancellationToken = default) => + SetLoggingLevel(client, McpServerSession.ToLoggingLevel(level), cancellationToken); /// Convers a dictionary with values to a dictionary with values. private static Dictionary? ToArgumentsDictionary( diff --git a/src/ModelContextProtocol.Core/Client/McpClientFactory.cs b/src/ModelContextProtocol.Core/Client/McpClientFactory.cs index 30b3a947..57827a42 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientFactory.cs @@ -6,13 +6,13 @@ namespace ModelContextProtocol.Client; /// Provides factory methods for creating Model Context Protocol (MCP) clients. /// /// -/// This factory class is the primary way to instantiate instances +/// This factory class is the primary way to instantiate instances /// that connect to MCP servers. It handles the creation and connection /// of appropriate implementations through the supplied transport. /// -public static partial class McpClientFactory +public static class McpClientFactory { - /// Creates an , connecting it to the specified server. + /// Creates an , connecting it to the specified server. /// The transport instance used to communicate with the server. /// /// A client configuration object which specifies client capabilities and protocol version. @@ -20,10 +20,10 @@ public static partial class McpClientFactory /// /// A logger factory for creating loggers for clients. /// The to monitor for cancellation requests. The default is . - /// An that's connected to the specified server. + /// An that's connected to the specified server. /// is . /// is . - public static async Task CreateAsync( + public static async Task CreateAsync( IClientTransport clientTransport, McpClientOptions? clientOptions = null, ILoggerFactory? loggerFactory = null, @@ -31,24 +31,20 @@ public static async Task CreateAsync( { Throw.IfNull(clientTransport); - McpClient client = new(clientTransport, clientOptions, loggerFactory); + var transport = await clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + var endpointName = clientTransport.Name; + + var clientSession = new McpClientSession(transport, endpointName, clientOptions, loggerFactory); try { - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - if (loggerFactory?.CreateLogger(typeof(McpClientFactory)) is ILogger logger) - { - logger.LogClientCreated(client.EndpointName); - } + await clientSession.ConnectAsync(cancellationToken).ConfigureAwait(false); } catch { - await client.DisposeAsync().ConfigureAwait(false); + await clientSession.DisposeAsync().ConfigureAwait(false); throw; } - return client; + return clientSession; } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] - private static partial void LogClientCreated(this ILogger logger, string endpointName); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs index 76099d0d..93145a3a 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs @@ -3,7 +3,7 @@ namespace ModelContextProtocol.Client; /// -/// Provides configuration options for creating instances. +/// Provides configuration options for creating instances. /// /// /// These options are typically passed to when creating a client. diff --git a/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs b/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs index 43fc759a..4726b46e 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs @@ -20,9 +20,9 @@ namespace ModelContextProtocol.Client; /// public sealed class McpClientPrompt { - private readonly IMcpClient _client; + private readonly McpClientSession _client; - internal McpClientPrompt(IMcpClient client, Prompt prompt) + internal McpClientPrompt(McpClientSession client, Prompt prompt) { _client = client; ProtocolPrompt = prompt; diff --git a/src/ModelContextProtocol.Core/Client/McpClientResource.cs b/src/ModelContextProtocol.Core/Client/McpClientResource.cs index 06f8aff6..91ebf4d4 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientResource.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientResource.cs @@ -15,9 +15,9 @@ namespace ModelContextProtocol.Client; /// public sealed class McpClientResource { - private readonly IMcpClient _client; + private readonly McpClientSession _client; - internal McpClientResource(IMcpClient client, Resource resource) + internal McpClientResource(McpClientSession client, Resource resource) { _client = client; ProtocolResource = resource; @@ -58,7 +58,7 @@ internal McpClientResource(IMcpClient client, Resource resource) /// A containing the resource's result with content and messages. /// /// - /// This is a convenience method that internally calls . + /// This is a convenience method that internally calls . /// /// public ValueTask ReadAsync( diff --git a/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs b/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs index 4da1bd0c..678f7ccc 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs @@ -15,9 +15,9 @@ namespace ModelContextProtocol.Client; /// public sealed class McpClientResourceTemplate { - private readonly IMcpClient _client; + private readonly McpClientSession _client; - internal McpClientResourceTemplate(IMcpClient client, ResourceTemplate resourceTemplate) + internal McpClientResourceTemplate(McpClientSession client, ResourceTemplate resourceTemplate) { _client = client; ProtocolResourceTemplate = resourceTemplate; diff --git a/src/ModelContextProtocol.Core/Client/McpClientSession.cs b/src/ModelContextProtocol.Core/Client/McpClientSession.cs new file mode 100644 index 00000000..584b1c7e --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpClientSession.cs @@ -0,0 +1,272 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Text.Json; + +namespace ModelContextProtocol.Client; + +/// +/// Represents an instance of a Model Context Protocol (MCP) client session that connects to and communicates with an MCP server. +/// +public sealed partial class McpClientSession : IMcpEndpoint +{ + private static Implementation DefaultImplementation { get; } = new() + { + Name = AssemblyNameHelper.DefaultAssemblyName.Name ?? nameof(McpClientSession), + Version = AssemblyNameHelper.DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + + private readonly ILogger _logger; + private readonly ITransport _transport; + private readonly string _endpointName; + private readonly McpClientOptions _options; + private readonly McpSessionHandler _sessionHandler; + + private CancellationTokenSource? _connectCts; + + private ServerCapabilities? _serverCapabilities; + private Implementation? _serverInfo; + private string? _serverInstructions; + + private int _isDisposed; + + /// + /// Initializes a new instance of the class. + /// + /// The transport to use for communication with the server. + /// The name of the endpoint for logging and debug purposes. + /// Options for the client, defining protocol version and capabilities. + /// The logger factory. + internal McpClientSession(ITransport transport, string endpointName, McpClientOptions? options, ILoggerFactory? loggerFactory) + { + options ??= new(); + + _transport = transport; + _endpointName = $"Client ({options.ClientInfo?.Name ?? DefaultImplementation.Name} {options.ClientInfo?.Version ?? DefaultImplementation.Version})"; + _options = options; + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + var notificationHandlers = new NotificationHandlers(); + var requestHandlers = new RequestHandlers(); + + if (options.Capabilities is { } capabilities) + { + RegisterHandlers(capabilities, notificationHandlers, requestHandlers); + } + + _sessionHandler = new McpSessionHandler(isServer: false, transport, endpointName, requestHandlers, notificationHandlers, _logger); + } + + private void RegisterHandlers(ClientCapabilities capabilities, NotificationHandlers notificationHandlers, RequestHandlers requestHandlers) + { + if (capabilities.NotificationHandlers is { } notificationHandlersFromCapabilities) + { + notificationHandlers.RegisterRange(notificationHandlersFromCapabilities); + } + + if (capabilities.Sampling is { } samplingCapability) + { + if (samplingCapability.SamplingHandler is not { } samplingHandler) + { + throw new InvalidOperationException("Sampling capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.SamplingCreateMessage, + (request, _, cancellationToken) => samplingHandler( + request, + request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken), + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult); + } + + if (capabilities.Roots is { } rootsCapability) + { + if (rootsCapability.RootsHandler is not { } rootsHandler) + { + throw new InvalidOperationException("Roots capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.RootsList, + (request, _, cancellationToken) => rootsHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult); + } + + if (capabilities.Elicitation is { } elicitationCapability) + { + if (elicitationCapability.ElicitationHandler is not { } elicitationHandler) + { + throw new InvalidOperationException("Elicitation capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.ElicitationCreate, + (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult); + } + } + + /// + public string? SessionId => _transport.SessionId; + + /// + /// Gets the capabilities supported by the connected server. + /// + /// The client is not connected. + public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); + + /// + /// Gets the implementation information of the connected server. + /// + /// + /// + /// This property provides identification details about the connected server, including its name and version. + /// It is populated during the initialization handshake and is available after a successful connection. + /// + /// + /// This information can be useful for logging, debugging, compatibility checks, and displaying server + /// information to users. + /// + /// + /// The client is not connected. + public Implementation ServerInfo => _serverInfo ?? throw new InvalidOperationException("The client is not connected."); + + /// + /// Gets any instructions describing how to use the connected server and its features. + /// + /// + /// + /// This property contains instructions provided by the server during initialization that explain + /// how to effectively use its capabilities. These instructions can include details about available + /// tools, expected input formats, limitations, or any other helpful information. + /// + /// + /// This can be used by clients to improve an LLM's understanding of available tools, prompts, and resources. + /// It can be thought of like a "hint" to the model and may be added to a system prompt. + /// + /// + public string? ServerInstructions => _serverInstructions; + + /// + /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) + { + _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = _connectCts.Token; + + try + { + // We don't want the ConnectAsync token to cancel the message processing loop after we've successfully connected. + // The session handler handles cancelling the loop upon its disposal. + _ = _sessionHandler.ProcessMessagesAsync(CancellationToken.None); + + // Perform initialization sequence + using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + initializationCts.CancelAfter(_options.InitializationTimeout); + + try + { + // Send initialize request + string requestProtocol = _options.ProtocolVersion ?? McpSessionHandler.LatestProtocolVersion; + var initializeResponse = await this.SendRequestAsync( + RequestMethods.Initialize, + new InitializeRequestParams + { + ProtocolVersion = requestProtocol, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo ?? DefaultImplementation, + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult, + cancellationToken: initializationCts.Token).ConfigureAwait(false); + + // Store server information + if (_logger.IsEnabled(LogLevel.Information)) + { + LogServerCapabilitiesReceived(_endpointName, + capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), + serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); + } + + _serverCapabilities = initializeResponse.Capabilities; + _serverInfo = initializeResponse.ServerInfo; + _serverInstructions = initializeResponse.Instructions; + + // Validate protocol version + bool isResponseProtocolValid = + _options.ProtocolVersion is { } optionsProtocol ? optionsProtocol == initializeResponse.ProtocolVersion : + McpSessionHandler.SupportedProtocolVersions.Contains(initializeResponse.ProtocolVersion); + if (!isResponseProtocolValid) + { + LogServerProtocolVersionMismatch(_endpointName, requestProtocol, initializeResponse.ProtocolVersion); + throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); + } + + // Send initialized notification + await this.SendNotificationAsync( + NotificationMethods.InitializedNotification, + new InitializedNotificationParams(), + McpJsonUtilities.JsonContext.Default.InitializedNotificationParams, + cancellationToken: initializationCts.Token).ConfigureAwait(false); + + } + catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) + { + LogClientInitializationTimeout(_endpointName); + throw new TimeoutException("Initialization timed out", oce); + } + } + catch (Exception e) + { + LogClientInitializationError(_endpointName, e); + await DisposeAsync().ConfigureAwait(false); + throw; + } + + LogClientConnected(_endpointName); + } + + /// + public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => _sessionHandler.SendRequestAsync(request, cancellationToken); + + /// + public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => _sessionHandler.SendMessageAsync(message, cancellationToken); + + /// + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + => _sessionHandler.RegisterNotificationHandler(method, handler); + + /// + public async ValueTask DisposeAsync() + { + if (Interlocked.CompareExchange(ref _isDisposed, 1, 0) != 0) + { + return; + } + + await _sessionHandler.DisposeAsync().ConfigureAwait(false); + await _transport.DisposeAsync().ConfigureAwait(false); + } + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] + private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")] + private partial void LogClientInitializationError(string endpointName, Exception exception); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")] + private partial void LogClientInitializationTimeout(string endpointName); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")] + private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] + private partial void LogClientConnected(string endpointName); +} diff --git a/src/ModelContextProtocol.Core/Client/McpClientTool.cs b/src/ModelContextProtocol.Core/Client/McpClientTool.cs index 1810e9c5..d0be8929 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientTool.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientTool.cs @@ -6,11 +6,11 @@ namespace ModelContextProtocol.Client; /// -/// Provides an that calls a tool via an . +/// Provides an that calls a tool via an . /// /// /// -/// The class encapsulates an along with a description of +/// The class encapsulates an along with a description of /// a tool available via that client, allowing it to be invoked as an . This enables integration /// with AI models that support function calling capabilities. /// @@ -20,7 +20,7 @@ namespace ModelContextProtocol.Client; /// /// /// Typically, you would get instances of this class by calling the -/// or extension methods on an instance. +/// or extension methods on an instance. /// /// public sealed class McpClientTool : AIFunction @@ -32,13 +32,13 @@ public sealed class McpClientTool : AIFunction ["Strict"] = false, // some MCP schemas may not meet "strict" requirements }); - private readonly IMcpClient _client; + private readonly McpClientSession _client; private readonly string _name; private readonly string _description; private readonly IProgress? _progress; internal McpClientTool( - IMcpClient client, + McpClientSession client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, diff --git a/src/ModelContextProtocol.Core/IMcpEndpoint.cs b/src/ModelContextProtocol.Core/IMcpEndpoint.cs index ea825e68..d48cc2d1 100644 --- a/src/ModelContextProtocol.Core/IMcpEndpoint.cs +++ b/src/ModelContextProtocol.Core/IMcpEndpoint.cs @@ -17,7 +17,7 @@ namespace ModelContextProtocol; /// /// /// -/// serves as the base interface for both and +/// serves as the base interface for both and /// interfaces, providing the common functionality needed for MCP protocol /// communication. Most applications will use these more specific interfaces rather than working with /// directly. diff --git a/src/ModelContextProtocol.Core/McpEndpoint.cs b/src/ModelContextProtocol.Core/McpEndpoint.cs deleted file mode 100644 index 0d0ccbb9..00000000 --- a/src/ModelContextProtocol.Core/McpEndpoint.cs +++ /dev/null @@ -1,144 +0,0 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; - -namespace ModelContextProtocol; - -/// -/// Base class for an MCP JSON-RPC endpoint. This covers both MCP clients and servers. -/// It is not supported, nor necessary, to implement both client and server functionality in the same class. -/// If an application needs to act as both a client and a server, it should use separate objects for each. -/// This is especially true as a client represents a connection to one and only one server, and vice versa. -/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction. -/// -internal abstract partial class McpEndpoint : IAsyncDisposable -{ - /// Cached naming information used for name/version when none is specified. - internal static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); - - private McpSession? _session; - private CancellationTokenSource? _sessionCts; - - private readonly SemaphoreSlim _disposeLock = new(1, 1); - private bool _disposed; - - protected readonly ILogger _logger; - - /// - /// Initializes a new instance of the class. - /// - /// The logger factory. - protected McpEndpoint(ILoggerFactory? loggerFactory = null) - { - _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; - } - - protected RequestHandlers RequestHandlers { get; } = []; - - protected NotificationHandlers NotificationHandlers { get; } = new(); - - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) - => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); - - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); - - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => - GetSessionOrThrow().RegisterNotificationHandler(method, handler); - - /// - /// Gets the name of the endpoint for logging and debug purposes. - /// - public abstract string EndpointName { get; } - - /// - /// Task that processes incoming messages from the transport. - /// - protected Task? MessageProcessingTask { get; private set; } - - protected void InitializeSession(ITransport sessionTransport) - { - _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger); - } - - [MemberNotNull(nameof(MessageProcessingTask))] - protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken) - { - _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); - MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token); - } - - protected void CancelSession() => _sessionCts?.Cancel(); - - public async ValueTask DisposeAsync() - { - using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); - - if (_disposed) - { - return; - } - _disposed = true; - - await DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - - /// - /// Cleans up the endpoint and releases resources. - /// - /// - public virtual async ValueTask DisposeUnsynchronizedAsync() - { - LogEndpointShuttingDown(EndpointName); - - try - { - if (_sessionCts is not null) - { - await _sessionCts.CancelAsync().ConfigureAwait(false); - } - - if (MessageProcessingTask is not null) - { - try - { - await MessageProcessingTask.ConfigureAwait(false); - } - catch (OperationCanceledException) - { - // Ignore cancellation - } - } - } - finally - { - _session?.Dispose(); - _sessionCts?.Dispose(); - } - - LogEndpointShutDown(EndpointName); - } - - protected McpSession GetSessionOrThrow() - { -#if NET - ObjectDisposedException.ThrowIf(_disposed, this); -#else - if (_disposed) - { - throw new ObjectDisposedException(GetType().Name); - } -#endif - - return _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shutting down.")] - private partial void LogEndpointShuttingDown(string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shut down.")] - private partial void LogEndpointShutDown(string endpointName); -} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/McpEndpointExtensions.cs b/src/ModelContextProtocol.Core/McpEndpointExtensions.cs index 4e4abe5c..345c1053 100644 --- a/src/ModelContextProtocol.Core/McpEndpointExtensions.cs +++ b/src/ModelContextProtocol.Core/McpEndpointExtensions.cs @@ -16,8 +16,8 @@ namespace ModelContextProtocol; /// simplifying JSON-RPC communication by handling serialization and deserialization of parameters and results. /// /// -/// These extension methods are designed to be used with both client () and -/// server () implementations of the interface. +/// These extension methods are designed to be used with both client () and +/// server () implementations of the interface. /// /// public static class McpEndpointExtensions diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs similarity index 95% rename from src/ModelContextProtocol.Core/McpSession.cs rename to src/ModelContextProtocol.Core/McpSessionHandler.cs index da954205..463e796a 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -17,7 +17,7 @@ namespace ModelContextProtocol; /// /// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers. /// -internal sealed partial class McpSession : IDisposable +internal sealed partial class McpSessionHandler : IAsyncDisposable { private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); @@ -61,8 +61,11 @@ internal sealed partial class McpSession : IDisposable private readonly string _sessionId = Guid.NewGuid().ToString("N"); private long _lastRequestId; + private CancellationTokenSource? _messageProcessingCts; + private Task? _messageProcessingTask; + /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// true if this is a server; false if it's a client. /// An MCP transport implementation. @@ -70,7 +73,7 @@ internal sealed partial class McpSession : IDisposable /// A collection of request handlers. /// A collection of notification handlers. /// The logger. - public McpSession( + public McpSessionHandler( bool isServer, ITransport transport, string endpointName, @@ -107,7 +110,21 @@ public McpSession( /// Starts processing messages from the transport. This method will block until the transport is disconnected. /// This is generally started in a background task or thread from the initialization logic of the derived class. /// - public async Task ProcessMessagesAsync(CancellationToken cancellationToken) + public Task ProcessMessagesAsync(CancellationToken cancellationToken) + { + if (_messageProcessingTask is not null) + { + throw new InvalidOperationException("The message processing loop has already started."); + } + + Debug.Assert(_messageProcessingCts is null); + + _messageProcessingCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _messageProcessingTask = ProcessMessagesCoreAsync(_messageProcessingCts.Token); + return _messageProcessingTask; + } + + private async Task ProcessMessagesCoreAsync(CancellationToken cancellationToken) { try { @@ -344,7 +361,7 @@ private CancellationTokenRegistration RegisterCancellation(CancellationToken can return cancellationToken.Register(static objState => { - var state = (Tuple)objState!; + var state = (Tuple)objState!; _ = state.Item1.SendMessageAsync(new JsonRpcNotification { Method = NotificationMethods.CancelledNotification, @@ -372,6 +389,8 @@ public IAsyncDisposable RegisterNotificationHandler(string method, FuncA task containing the server's response. public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { + Throw.IfNull(request); + cancellationToken.ThrowIfCancellationRequested(); Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; @@ -682,7 +701,7 @@ private static void FinalizeDiagnostics( } } - public void Dispose() + public async ValueTask DisposeAsync() { Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; if (durationMetric.Enabled) @@ -695,13 +714,30 @@ public void Dispose() durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); } - // Complete all pending requests with cancellation foreach (var entry in _pendingRequests) { entry.Value.TrySetCanceled(); } _pendingRequests.Clear(); + + if (_messageProcessingCts is not null) + { + await _messageProcessingCts.CancelAsync().ConfigureAwait(false); + } + + if (_messageProcessingTask is not null) + { + try + { + await _messageProcessingTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Ignore cancellation + } + } + LogSessionDisposed(EndpointName, _sessionId, _transportKind); } diff --git a/src/ModelContextProtocol.Core/README.md b/src/ModelContextProtocol.Core/README.md index beb365c8..121c6184 100644 --- a/src/ModelContextProtocol.Core/README.md +++ b/src/ModelContextProtocol.Core/README.md @@ -27,8 +27,8 @@ dotnet add package ModelContextProtocol.Core --prerelease ## Getting Started (Client) -To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `IMcpClient` -to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. +To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `McpClientSession` +to a server. Once you have an `McpClientSession`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp var clientTransport = new StdioClientTransport(new StdioClientTransportOptions diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index d286d1ef..e1160468 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -3,9 +3,8 @@ namespace ModelContextProtocol.Server; -internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer +internal sealed class DestinationBoundMcpServer(McpServerSession server, ITransport? transport) : IMcpServer { - public string EndpointName => server.EndpointName; public string? SessionId => transport?.SessionId ?? server.SessionId; public ClientCapabilities? ClientCapabilities => server.ClientCapabilities; public Implementation? ClientInfo => server.ClientInfo; diff --git a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs index 277ed737..4f19adff 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs @@ -333,7 +333,7 @@ private sealed class ClientLogger(IMcpServer server, string categoryName) : ILog /// public bool IsEnabled(LogLevel logLevel) => server?.LoggingLevel is { } loggingLevel && - McpServer.ToLoggingLevel(logLevel) >= loggingLevel; + McpServerSession.ToLoggingLevel(logLevel) >= loggingLevel; /// public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) @@ -351,7 +351,7 @@ void Log(LogLevel logLevel, string message) { _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams { - Level = McpServer.ToLoggingLevel(logLevel), + Level = McpServerSession.ToLoggingLevel(logLevel), Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), Logger = categoryName, }); diff --git a/src/ModelContextProtocol.Core/Server/McpServerFactory.cs b/src/ModelContextProtocol.Core/Server/McpServerFactory.cs index 50d4188b..79e3d8c1 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerFactory.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerFactory.cs @@ -31,6 +31,6 @@ public static IMcpServer Create( Throw.IfNull(transport); Throw.IfNull(serverOptions); - return new McpServer(transport, serverOptions, loggerFactory, serviceProvider); + return new McpServerSession(transport, serverOptions, loggerFactory, serviceProvider); } } diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServerSession.cs similarity index 90% rename from src/ModelContextProtocol.Core/Server/McpServer.cs rename to src/ModelContextProtocol.Core/Server/McpServerSession.cs index 6c5858f9..1614399c 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerSession.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.Runtime.CompilerServices; using System.Text.Json.Serialization.Metadata; @@ -7,22 +8,28 @@ namespace ModelContextProtocol.Server; /// -internal sealed class McpServer : McpEndpoint, IMcpServer +public sealed class McpServerSession : IMcpServer { internal static Implementation DefaultImplementation { get; } = new() { - Name = DefaultAssemblyName.Name ?? nameof(McpServer), - Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + Name = AssemblyNameHelper.DefaultAssemblyName.Name ?? nameof(McpServerSession), + Version = AssemblyNameHelper.DefaultAssemblyName.Version?.ToString() ?? "1.0.0", }; + private readonly ILogger _logger; private readonly ITransport _sessionTransport; private readonly bool _servicesScopePerRequest; private readonly List _disposables = []; + private readonly NotificationHandlers _notificationHandlers; + private readonly RequestHandlers _requestHandlers; + private readonly McpSessionHandler _sessionHandler; private readonly string _serverOnlyEndpointName; - private string? _endpointName; + private string _endpointName; private int _started; + private int _isDisposed; + /// Holds a boxed value for the server. /// /// Initialized to non-null the first time SetLevel is used. This is stored as a strong box @@ -31,7 +38,7 @@ internal sealed class McpServer : McpEndpoint, IMcpServer private StrongBox? _loggingLevel; /// - /// Creates a new instance of . + /// Creates a new instance of . /// /// Transport to use for the server representing an already-established session. /// Configuration options for this server, including capabilities. @@ -39,8 +46,7 @@ internal sealed class McpServer : McpEndpoint, IMcpServer /// Logger factory to use for logging /// Optional service provider to use for dependency injection /// The server was incorrectly configured. - public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) - : base(loggerFactory) + public McpServerSession(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) { Throw.IfNull(transport); Throw.IfNull(options); @@ -51,11 +57,16 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? ServerOptions = options; Services = serviceProvider; _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; + _endpointName = _serverOnlyEndpointName; _servicesScopePerRequest = options.ScopeRequests; + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; ClientInfo = options.KnownClientInfo; UpdateEndpointNameWithClientInfo(); + _notificationHandlers = new(); + _requestHandlers = []; + // Configure all request handlers based on the supplied options. ServerCapabilities = new(); ConfigureInitialize(options); @@ -70,7 +81,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? // Register any notification handlers that were provided. if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) { - NotificationHandlers.RegisterRange(notificationHandlers); + _notificationHandlers.RegisterRange(notificationHandlers); } // Now that everything has been configured, subscribe to any necessary notifications. @@ -93,7 +104,7 @@ void Register(McpServerPrimitiveCollection? collection, } // And initialize the session. - InitializeSession(transport); + _sessionHandler = new McpSessionHandler(isServer: true, _sessionTransport, _endpointName!, _requestHandlers, _notificationHandlers, _logger); } /// @@ -114,9 +125,6 @@ void Register(McpServerPrimitiveCollection? collection, /// public IServiceProvider? Services { get; } - /// - public override string EndpointName => _endpointName ?? _serverOnlyEndpointName; - /// public LoggingLevel? LoggingLevel => _loggingLevel?.Value; @@ -130,8 +138,7 @@ public async Task RunAsync(CancellationToken cancellationToken = default) try { - StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken); - await MessageProcessingTask.ConfigureAwait(false); + await _sessionHandler.ProcessMessagesAsync(cancellationToken).ConfigureAwait(false); } finally { @@ -139,10 +146,29 @@ public async Task RunAsync(CancellationToken cancellationToken = default) } } - public override async ValueTask DisposeUnsynchronizedAsync() + + /// + public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => _sessionHandler.SendRequestAsync(request, cancellationToken); + + /// + public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => _sessionHandler.SendMessageAsync(message, cancellationToken); + + /// + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + => _sessionHandler.RegisterNotificationHandler(method, handler); + + /// + public async ValueTask DisposeAsync() { + if (Interlocked.CompareExchange(ref _isDisposed, 1, 0) != 0) + { + return; + } + _disposables.ForEach(d => d()); - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); + await _sessionHandler.DisposeAsync().ConfigureAwait(false); } private void ConfigurePing() @@ -155,7 +181,7 @@ private void ConfigurePing() private void ConfigureInitialize(McpServerOptions options) { - RequestHandlers.Set(RequestMethods.Initialize, + _requestHandlers.Set(RequestMethods.Initialize, async (request, _, _) => { ClientCapabilities = request?.Capabilities ?? new(); @@ -163,7 +189,7 @@ private void ConfigureInitialize(McpServerOptions options) // Use the ClientInfo to update the session EndpointName for logging. UpdateEndpointNameWithClientInfo(); - GetSessionOrThrow().EndpointName = EndpointName; + _sessionHandler.EndpointName = _endpointName; // Negotiate a protocol version. If the server options provide one, use that. // Otherwise, try to use whatever the client requested as long as it's supported. @@ -171,9 +197,9 @@ private void ConfigureInitialize(McpServerOptions options) string? protocolVersion = options.ProtocolVersion; if (protocolVersion is null) { - protocolVersion = request?.ProtocolVersion is string clientProtocolVersion && McpSession.SupportedProtocolVersions.Contains(clientProtocolVersion) ? + protocolVersion = request?.ProtocolVersion is string clientProtocolVersion && McpSessionHandler.SupportedProtocolVersions.Contains(clientProtocolVersion) ? clientProtocolVersion : - McpSession.LatestProtocolVersion; + McpSessionHandler.LatestProtocolVersion; } return new InitializeResult @@ -496,7 +522,7 @@ private void ConfigureLogging(McpServerOptions options) ServerCapabilities.Logging = new(); ServerCapabilities.Logging.SetLoggingLevelHandler = setLoggingLevelHandler; - RequestHandlers.Set( + _requestHandlers.Set( RequestMethods.LoggingSetLevel, (request, destinationTransport, cancellationToken) => { @@ -566,7 +592,7 @@ private void SetHandler( JsonTypeInfo requestTypeInfo, JsonTypeInfo responseTypeInfo) { - RequestHandlers.Set(method, + _requestHandlers.Set(method, (request, destinationTransport, cancellationToken) => InvokeHandlerAsync(handler, request, destinationTransport, cancellationToken), requestTypeInfo, responseTypeInfo); diff --git a/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs b/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs index 556a3115..26641cf6 100644 --- a/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs @@ -37,7 +37,7 @@ private static string GetServerName(McpServerOptions serverOptions) { Throw.IfNull(serverOptions); - return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name; + return serverOptions.ServerInfo?.Name ?? McpServerSession.DefaultImplementation.Name; } // Neither WindowsConsoleStream nor UnixConsoleStream respect CancellationTokens or cancel any I/O on Dispose. diff --git a/src/ModelContextProtocol.Core/TokenProgress.cs b/src/ModelContextProtocol.Core/TokenProgress.cs index f222fbf7..e0506011 100644 --- a/src/ModelContextProtocol.Core/TokenProgress.cs +++ b/src/ModelContextProtocol.Core/TokenProgress.cs @@ -4,13 +4,13 @@ namespace ModelContextProtocol; /// /// Provides an tied to a specific progress token and that will issue -/// progress notifications on the supplied endpoint. +/// progress notifications on the supplied session. /// -internal sealed class TokenProgress(IMcpEndpoint endpoint, ProgressToken progressToken) : IProgress +internal sealed class TokenProgress(IMcpEndpoint session, ProgressToken progressToken) : IProgress { /// public void Report(ProgressNotificationValue value) { - _ = endpoint.NotifyProgressAsync(progressToken, value, CancellationToken.None); + _ = session.NotifyProgressAsync(progressToken, value, CancellationToken.None); } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 9b3c91b9..844fd734 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -23,7 +23,7 @@ public override void Dispose() protected abstract SseClientTransportOptions ClientTransportOptions { get; } - private Task GetClientAsync(McpClientOptions? options = null) + private Task GetClientAsync(McpClientOptions? options = null) { return _fixture.ConnectMcpClientAsync(options, LoggerFactory); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 4d0d7356..7ada1f9f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -23,7 +23,7 @@ protected void ConfigureStateless(HttpServerTransportOptions options) options.Stateless = Stateless; } - protected async Task ConnectAsync( + protected async Task ConnectAsync( string? path = null, SseClientTransportOptions? transportOptions = null, McpClientOptions? clientOptions = null) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 8191f609..7a74eb31 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -21,7 +21,7 @@ public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : Kestr Name = "In-memory SSE Client", }; - private Task ConnectMcpClientAsync(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) + private Task ConnectMcpClientAsync(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) => McpClientFactory.CreateAsync( new SseClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), loggerFactory: LoggerFactory, diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index 2aa675c8..8cdc7645 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -44,7 +44,7 @@ public SseServerIntegrationTestFixture() public HttpClient HttpClient { get; } - public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) + public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) { return McpClientFactory.CreateAsync( new SseClientTransport(DefaultTransportOptions, HttpClient, loggerFactory), diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index b50a43ed..865aaf6e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -58,7 +58,7 @@ private async Task StartAsync() HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); } - private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) + private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) => McpClientFactory.CreateAsync( new SseClientTransport(DefaultTransportOptions, HttpClient, LoggerFactory), clientOptions, LoggerFactory, TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index e3d7ce44..87622719 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -197,7 +197,7 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages() [Fact] public async Task ListToolsAsync_AllToolsReturned() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(12, tools.Count); @@ -223,7 +223,7 @@ public async Task ListToolsAsync_AllToolsReturned() [Fact] public async Task EnumerateToolsAsync_AllToolsReturned() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)) { @@ -242,7 +242,7 @@ public async Task EnumerateToolsAsync_AllToolsReturned() public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); bool hasTools = false; await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken)) @@ -263,7 +263,7 @@ public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First(); await Assert.ThrowsAsync(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken)); @@ -273,7 +273,7 @@ public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() public async Task SendRequestAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.SendRequestAsync("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -282,7 +282,7 @@ public async Task SendRequestAsync_HonorsJsonSerializerOptions() public async Task SendNotificationAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -291,7 +291,7 @@ public async Task SendNotificationAsync_HonorsJsonSerializerOptions() public async Task GetPromptsAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.GetPromptAsync("Prompt", new Dictionary { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -300,7 +300,7 @@ public async Task GetPromptsAsync_HonorsJsonSerializerOptions() public async Task WithName_ChangesToolName() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First(); var originalName = tool.Name; @@ -315,7 +315,7 @@ public async Task WithName_ChangesToolName() public async Task WithDescription_ChangesToolDescription() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault(); var originalDescription = tool?.Description; var redescribedTool = tool?.WithDescription("ToolWithNewDescription"); @@ -344,7 +344,7 @@ public async Task WithProgress_ProgressReported() return 42; }, new() { Name = "ProgressReporter" })); - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter"); @@ -372,7 +372,7 @@ private sealed class SynchronousProgress(Action callb [Fact] public async Task AsClientLoggerProvider_MessagesSentToClient() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); ILoggerProvider loggerProvider = Server.AsClientLoggerProvider(); Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index 7516a218..6c59d94a 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -85,9 +85,9 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) }; var clientTransport = (IClientTransport)Activator.CreateInstance(transportType)!; - IMcpClient? client = null; + McpClientSession? client = null; - var actionTask = McpClientFactory.CreateAsync(clientTransport, clientOptions, new Mock().Object, CancellationToken.None); + var actionTask = McpClientFactory.CreateAsync(clientTransport, clientOptions, loggerFactory: null, CancellationToken.None); // Act if (clientTransport is FailureTransport) diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs index 48c3c370..a063e4c5 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs @@ -73,7 +73,7 @@ public static IEnumerable UriTemplate_InputsProduceExpectedOutputs_Mem public async Task UriTemplate_InputsProduceExpectedOutputs( IReadOnlyDictionary variables, string uriTemplate, object expected) { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.ReadResourceAsync(uriTemplate, variables, TestContext.Current.CancellationToken); Assert.NotNull(result); diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index ebc7171e..f36e2621 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -41,7 +41,7 @@ public void Initialize(ILoggerFactory loggerFactory) _loggerFactory = loggerFactory; } - public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => + public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => McpClientFactory.CreateAsync(new StdioClientTransport(clientId switch { "everything" => EverythingServerTransportOptions, diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index ec1c8510..8d1b8590 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -62,7 +62,7 @@ public async ValueTask DisposeAsync() Dispose(); } - protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) + protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) { return await McpClientFactory.CreateAsync( new StreamClientTransport( diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 3fa2ec78..cc2107e0 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -95,7 +95,7 @@ public void Adds_Prompts_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Prompts() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -124,7 +124,7 @@ public async Task Can_List_And_Call_Registered_Prompts() [Fact] public async Task Can_Be_Notified_Of_Prompt_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -165,7 +165,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(prompts); @@ -179,7 +179,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task Throws_When_Prompt_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.GetPromptAsync( nameof(SimplePrompts.ThrowsException), @@ -189,7 +189,7 @@ await Assert.ThrowsAsync(async () => await client.GetPromptAsync( [Fact] public async Task Throws_Exception_On_Unknown_Prompt() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "NotRegisteredPrompt", @@ -201,7 +201,7 @@ public async Task Throws_Exception_On_Unknown_Prompt() [Fact] public async Task Throws_Exception_Missing_Parameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "returns_chat_messages", diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index ed930b17..f2b90c8a 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -122,7 +122,7 @@ public void Adds_Resources_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Resources() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); Assert.NotNull(client.ServerCapabilities.Resources); @@ -141,7 +141,7 @@ public async Task Can_List_And_Call_Registered_Resources() [Fact] public async Task Can_List_And_Call_Registered_ResourceTemplates() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var resources = await client.ListResourceTemplatesAsync(TestContext.Current.CancellationToken); Assert.Equal(3, resources.Count); @@ -158,7 +158,7 @@ public async Task Can_List_And_Call_Registered_ResourceTemplates() [Fact] public async Task Can_Be_Notified_Of_Resource_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var resources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); Assert.Equal(5, resources.Count); @@ -199,7 +199,7 @@ public async Task Can_Be_Notified_Of_Resource_Changes() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(resources); @@ -217,7 +217,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task Throws_When_Resource_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( $"resource://mcp/{nameof(SimpleResources.ThrowsException)}", @@ -227,7 +227,7 @@ await Assert.ThrowsAsync(async () => await client.ReadResourceAsyn [Fact] public async Task Throws_Exception_On_Unknown_Resource() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( "test:///NotRegisteredResource", diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 35f833d5..748ee50f 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -121,7 +121,7 @@ public void Adds_Tools_To_Server() [Fact] public async Task Can_List_Registered_Tools() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -185,7 +185,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -226,7 +226,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() [Fact] public async Task Can_Call_Registered_Tool() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo", @@ -245,7 +245,7 @@ public async Task Can_Call_Registered_Tool() [Fact] public async Task Can_Call_Registered_Tool_With_Array_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo_array", @@ -268,7 +268,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Null_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_null", @@ -282,7 +282,7 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Json_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_json", @@ -299,7 +299,7 @@ public async Task Can_Call_Registered_Tool_With_Json_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Int_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_integer", @@ -314,7 +314,7 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() [Fact] public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo_complex", @@ -331,7 +331,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() [Fact] public async Task Can_Call_Registered_Tool_With_Instance_Method() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); string[][] parts = new string[2][]; for (int i = 0; i < 2; i++) @@ -360,7 +360,7 @@ public async Task Can_Call_Registered_Tool_With_Instance_Method() [Fact] public async Task Returns_IsError_Content_When_Tool_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "throw_exception", @@ -375,7 +375,7 @@ public async Task Returns_IsError_Content_When_Tool_Fails() [Fact] public async Task Throws_Exception_On_Unknown_Tool() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( "NotRegisteredTool", @@ -387,7 +387,7 @@ public async Task Throws_Exception_On_Unknown_Tool() [Fact] public async Task Returns_IsError_Missing_Parameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo", @@ -506,7 +506,7 @@ public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime [Fact] public async Task Recognizes_Parameter_Types() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -581,7 +581,7 @@ public void Create_ExtractsToolAnnotations_SomeSet() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -597,7 +597,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task HandlesIProgressParameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -651,7 +651,7 @@ public async Task HandlesIProgressParameter() [Fact] public async Task CancellationNotificationsPropagateToToolTokens() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs index b940c1c7..4eec471e 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs @@ -22,7 +22,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task InjectScopedServiceAsArgument() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(McpServerScopedTestsJsonContext.Default.Options, TestContext.Current.CancellationToken); var tool = tools.First(t => t.Name == "echo_complex"); diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 116c62a1..fe08438e 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -128,7 +128,7 @@ await RunConnected(async (client, server) => Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); } - private static async Task RunConnected(Func action, List clientToServerLog) + private static async Task RunConnected(Func action, List clientToServerLog) { Pipe clientToServerPipe = new(), serverToClientPipe = new(); StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()); @@ -153,7 +153,7 @@ private static async Task RunConnected(Func action { serverTask = server.RunAsync(TestContext.Current.CancellationToken); - await using (IMcpClient client = await McpClientFactory.CreateAsync( + await using (McpClientSession client = await McpClientFactory.CreateAsync( clientTransport, cancellationToken: TestContext.Current.CancellationToken)) { diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs index f4474391..2b735ed3 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs @@ -67,7 +67,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task Can_Elicit_Information() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClientSession client = await CreateMcpClientForServer(new McpClientOptions { Capabilities = new() { diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 0d18667e..d8fc4b94 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -13,7 +13,7 @@ public NotificationHandlerTests(ITestOutputHelper testOutputHelper) public async Task RegistrationsAreRemovedWhenDisposed() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); const int Iterations = 10; @@ -40,7 +40,7 @@ public async Task RegistrationsAreRemovedWhenDisposed() public async Task MultipleRegistrationsResultInMultipleCallbacks() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -80,7 +80,7 @@ public async Task MultipleRegistrationsResultInMultipleCallbacks() public async Task MultipleHandlersRunEvenIfOneThrows() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -122,7 +122,7 @@ public async Task MultipleHandlersRunEvenIfOneThrows() public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -163,7 +163,7 @@ public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int nu public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClientSession client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);