Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using System.Net.Http.Headers;

namespace ModelContextProtocol.Authentication;

Expand All @@ -20,7 +19,7 @@ internal override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage r
{
if (request.Headers.Authorization == null)
{
await AddAuthorizationHeaderAsync(request, _currentScheme, cancellationToken).ConfigureAwait(false);
await credentialProvider.AddAuthorizationHeaderAsync(request, _currentScheme, cancellationToken).ConfigureAwait(false);
}

var response = await base.SendAsync(request, message, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -78,7 +77,7 @@ private async Task<HttpResponseMessage> HandleUnauthorizedResponseAsync(
}
}

await AddAuthorizationHeaderAsync(retryRequest, _currentScheme, cancellationToken).ConfigureAwait(false);
await credentialProvider.AddAuthorizationHeaderAsync(retryRequest, _currentScheme, cancellationToken).ConfigureAwait(false);
return await base.SendAsync(retryRequest, originalJsonRpcMessage, cancellationToken).ConfigureAwait(false);
}

Expand All @@ -96,23 +95,4 @@ private static HashSet<string> ExtractServerSupportedSchemes(HttpResponseMessage

return serverSchemes;
}

/// <summary>
/// Adds an authorization header to the request.
/// </summary>
private async Task AddAuthorizationHeaderAsync(HttpRequestMessage request, string scheme, CancellationToken cancellationToken)
{
if (request.RequestUri is null)
{
return;
}

var token = await credentialProvider.GetCredentialAsync(scheme, request.RequestUri, cancellationToken).ConfigureAwait(false);
if (string.IsNullOrEmpty(token))
{
return;
}

request.Headers.Authorization = new AuthenticationHeaderValue(scheme, token);
}
}
41 changes: 41 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ public sealed class ClientOAuthOptions
/// </remarks>
public AuthorizationRedirectDelegate? AuthorizationRedirectDelegate { get; set; }

/// <summary>
/// Gets or sets the delegate used for handling the dynamic client registration response.
/// </summary>
/// <remarks>
/// <para>
/// This delegate is responsible for processing the response from the dynamic client registration endpoint.
/// </para>
/// <para>
/// The implementation should save the client credentials securely for future use.
/// </para>
/// </remarks>
public DynamicClientRegistrationDelegate? DynamicClientRegistrationDelegate { get; set; }

/// <summary>
/// Gets or sets the authorization server selector function.
/// </summary>
Expand Down Expand Up @@ -85,6 +98,34 @@ public sealed class ClientOAuthOptions
/// </remarks>
public Uri? ClientUri { get; set; }

/// <summary>
/// Gets or sets the client type to use during dynamic client registration.
/// </summary>
/// <remarks>
/// <para>
/// This indicates whether the client is confidential (requires a client secret) or public (does not require a client secret).
/// Only used when a <see cref="ClientId"/> is not specified.
/// </para>
/// <para>
/// When not specified, the client type will default to <see cref="OAuthClientType.Confidential"/>.
/// </para>
/// </remarks>
public OAuthClientType? ClientType { get; set; }

/// <summary>
/// Gets or sets the initial access token to use during dynamic client registration.
/// </summary>
/// <remarks>
/// <para>
/// This token is used to authenticate the client during the registration process.
/// Only used when a <see cref="ClientId"/> is not specified.
/// </para>
/// <para>
/// This is required if the authorization server does not allow anonymous client registration.
/// </para>
/// </remarks>
public string? InitialAccessToken { get; set; }

/// <summary>
/// Gets or sets additional parameters to include in the query string of the OAuth authorization request
/// providing extra information or fulfilling specific requirements of the OAuth provider.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.Extensions.Logging.Abstractions;
using System.Collections.Specialized;
using System.Diagnostics.CodeAnalysis;
using System.Net.Http.Headers;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
Expand All @@ -27,10 +28,12 @@ internal sealed partial class ClientOAuthProvider
private readonly IDictionary<string, string> _additionalAuthorizationParameters;
private readonly Func<IReadOnlyList<Uri>, Uri?> _authServerSelector;
private readonly AuthorizationRedirectDelegate _authorizationRedirectDelegate;
private readonly DynamicClientRegistrationDelegate? _dynamicClientRegistrationDelegate;

// _clientName and _client URI is used for dynamic client registration (RFC 7591)
// _clientName, _clientUri, and _clientType is used for dynamic client registration (RFC 7591)
private readonly string? _clientName;
private readonly Uri? _clientUri;
private readonly OAuthClientType _clientType;

private readonly HttpClient _httpClient;
private readonly ILogger _logger;
Expand Down Expand Up @@ -69,6 +72,7 @@ public ClientOAuthProvider(
_redirectUri = options.RedirectUri ?? throw new ArgumentException("ClientOAuthOptions.RedirectUri must configured.");
_clientName = options.ClientName;
_clientUri = options.ClientUri;
_clientType = options.ClientType ?? OAuthClientType.Confidential;
_scopes = options.Scopes?.ToArray();
_additionalAuthorizationParameters = options.AdditionalAuthorizationParameters;

Expand All @@ -77,6 +81,20 @@ public ClientOAuthProvider(

// Set up authorization URL handler (use default if not provided)
_authorizationRedirectDelegate = options.AuthorizationRedirectDelegate ?? DefaultAuthorizationUrlHandler;

// Set up dynamic client registration delegate
_dynamicClientRegistrationDelegate = options.DynamicClientRegistrationDelegate;

if (options.InitialAccessToken is not null)
{
_token = new()
{
AccessToken = options.InitialAccessToken,
ExpiresIn = 900,
TokenType = BearerScheme,
ObtainedAt = DateTimeOffset.UtcNow,
};
}
}

/// <summary>
Expand Down Expand Up @@ -175,6 +193,25 @@ public async Task HandleUnauthorizedResponseAsync(
await PerformOAuthAuthorizationAsync(response, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Adds an authorization header to the request.
/// </summary>
internal async Task AddAuthorizationHeaderAsync(HttpRequestMessage request, string scheme, CancellationToken cancellationToken)
{
if (request.RequestUri is null)
{
return;
}

var token = await GetCredentialAsync(scheme, request.RequestUri, cancellationToken).ConfigureAwait(false);
if (string.IsNullOrEmpty(token))
{
return;
}

request.Headers.Authorization = new AuthenticationHeaderValue(scheme, token);
}

/// <summary>
/// Performs OAuth authorization by selecting an appropriate authorization server and completing the OAuth flow.
/// </summary>
Expand Down Expand Up @@ -442,7 +479,7 @@ private async Task PerformDynamicClientRegistrationAsync(
RedirectUris = [_redirectUri.ToString()],
GrantTypes = ["authorization_code", "refresh_token"],
ResponseTypes = ["code"],
TokenEndpointAuthMethod = "client_secret_post",
TokenEndpointAuthMethod = _clientType == OAuthClientType.Confidential ? "client_secret_post" : "none",
ClientName = _clientName,
ClientUri = _clientUri?.ToString(),
Scope = _scopes is not null ? string.Join(" ", _scopes) : null
Expand All @@ -456,6 +493,11 @@ private async Task PerformDynamicClientRegistrationAsync(
Content = requestContent
};

if (_token is not null)
{
await AddAuthorizationHeaderAsync(request, _token.TokenType, cancellationToken).ConfigureAwait(false);
}

using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);

if (!httpResponse.IsSuccessStatusCode)
Expand Down Expand Up @@ -483,6 +525,11 @@ private async Task PerformDynamicClientRegistrationAsync(
}

LogDynamicClientRegistrationSuccessful(_clientId!);

if (_dynamicClientRegistrationDelegate is not null)
{
await _dynamicClientRegistrationDelegate(registrationResponse, cancellationToken).ConfigureAwait(false);
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

namespace ModelContextProtocol.Authentication;

/// <summary>
/// Represents a method that handles the dynamic client registration response.
/// </summary>
/// <param name="response">The dynamic client registration response containing the client credentials.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A task that represents the asynchronous operation.</returns>
/// <remarks>
/// The implementation should save the client credentials securely for future use.
/// </remarks>
public delegate Task DynamicClientRegistrationDelegate(DynamicClientRegistrationResponse response, CancellationToken cancellationToken);
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication;
/// <summary>
/// Represents a client registration response for OAuth 2.0 Dynamic Client Registration (RFC 7591).
/// </summary>
internal sealed class DynamicClientRegistrationResponse
public sealed class DynamicClientRegistrationResponse
{
/// <summary>
/// Gets or sets the client identifier.
Expand Down
17 changes: 17 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/OAuthClientType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace ModelContextProtocol.Authentication;

/// <summary>
/// Represents the type of OAuth client.
/// </summary>
public enum OAuthClientType
{
/// <summary>
/// A confidential client, typically a server-side application that can securely store credentials.
/// </summary>
Confidential,

/// <summary>
/// A public client, typically a client-side application that cannot securely store credentials.
/// </summary>
Public,
}
Loading