Skip to content

Add WithXx overloads that take target instance #706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 138 additions & 4 deletions src/ModelContextProtocol/McpServerBuilderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,53 @@ public static partial class McpServerBuilderExtensions
return builder;
}

/// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <typeparam name="TToolType">The tool type.</typeparam>
/// <param name="builder">The builder instance.</param>
/// <param name="target">The target instance from which the tools should be sourced.</param>
/// <param name="serializerOptions">The serializer options governing tool parameter marshalling.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
/// <remarks>
/// <para>
/// This method discovers all methods (public and non-public) on the specified <typeparamref name="TToolType"/>
/// type, where the methods are attributed as <see cref="McpServerToolAttribute"/>, and adds an <see cref="McpServerTool"/>
/// instance for each, using <paramref name="target"/> as the associated instance for instance methods.
/// </para>
/// <para>
/// However, if <typeparamref name="TToolType"/> is itself an <see cref="IEnumerable{T}"/> of <see cref="McpServerTool"/>,
/// this method will register those tools directly without scanning for methods on <typeparamref name="TToolType"/>.
/// </para>
/// </remarks>
public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers(
DynamicallyAccessedMemberTypes.PublicMethods |
DynamicallyAccessedMemberTypes.NonPublicMethods)] TToolType>(
this IMcpServerBuilder builder,
TToolType target,
JsonSerializerOptions? serializerOptions = null)
{
Throw.IfNull(builder);
Throw.IfNull(target);

if (target is IEnumerable<McpServerTool> tools)
{
return builder.WithTools(tools);
}

foreach (var toolMethod in typeof(TToolType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
{
if (toolMethod.GetCustomAttribute<McpServerToolAttribute>() is not null)
{
builder.Services.AddSingleton(services => McpServerTool.Create(
toolMethod,
toolMethod.IsStatic ? null : target,
new() { Services = services, SerializerOptions = serializerOptions }));
}
}

return builder;
}

/// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <param name="builder">The builder instance.</param>
/// <param name="tools">The <see cref="McpServerTool"/> instances to add to the server.</param>
Expand Down Expand Up @@ -137,7 +184,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, IEnume
/// </para>
/// <para>
/// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For
/// Native AOT compatibility, consider using the generic <see cref="WithTools{TToolType}"/> method instead.
/// Native AOT compatibility, consider using the generic <see cref="M:WithTools"/> method instead.
/// </para>
/// </remarks>
[RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)]
Expand Down Expand Up @@ -193,6 +240,50 @@ where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
return builder;
}

/// <summary>Adds <see cref="McpServerPrompt"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <typeparam name="TPromptType">The prompt type.</typeparam>
/// <param name="builder">The builder instance.</param>
/// <param name="target">The target instance from which the prompts should be sourced.</param>
/// <param name="serializerOptions">The serializer options governing prompt parameter marshalling.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
/// <remarks>
/// <para>
/// This method discovers all methods (public and non-public) on the specified <typeparamref name="TPromptType"/>
/// type, where the methods are attributed as <see cref="McpServerPromptAttribute"/>, and adds an <see cref="McpServerPrompt"/>
/// instance for each, using <paramref name="target"/> as the associated instance for instance methods.
/// </para>
/// <para>
/// However, if <typeparamref name="TPromptType"/> is itself an <see cref="IEnumerable{T}"/> of <see cref="McpServerPrompt"/>,
/// this method will register those prompts directly without scanning for methods on <typeparamref name="TPromptType"/>.
/// </para>
/// </remarks>
public static IMcpServerBuilder WithPrompts<[DynamicallyAccessedMembers(
DynamicallyAccessedMemberTypes.PublicMethods |
DynamicallyAccessedMemberTypes.NonPublicMethods)] TPromptType>(
this IMcpServerBuilder builder,
TPromptType target,
JsonSerializerOptions? serializerOptions = null)
{
Throw.IfNull(builder);
Throw.IfNull(target);

if (target is IEnumerable<McpServerPrompt> prompts)
{
return builder.WithPrompts(prompts);
}

foreach (var promptMethod in typeof(TPromptType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
{
if (promptMethod.GetCustomAttribute<McpServerPromptAttribute>() is not null)
{
builder.Services.AddSingleton(services => McpServerPrompt.Create(promptMethod, target, new() { Services = services, SerializerOptions = serializerOptions }));
}
}

return builder;
}

/// <summary>Adds <see cref="McpServerPrompt"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <param name="builder">The builder instance.</param>
/// <param name="prompts">The <see cref="McpServerPrompt"/> instances to add to the server.</param>
Expand Down Expand Up @@ -277,7 +368,7 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, IEnu
/// </para>
/// <para>
/// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For
/// Native AOT compatibility, consider using the generic <see cref="WithPrompts{TPromptType}"/> method instead.
/// Native AOT compatibility, consider using the generic <see cref="M:WithPrompts"/> method instead.
/// </para>
/// </remarks>
[RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)]
Expand Down Expand Up @@ -311,7 +402,8 @@ where t.GetCustomAttribute<McpServerPromptTypeAttribute>() is not null
/// instance for each. For instance members, an instance will be constructed for each invocation of the resource.
/// </remarks>
public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers(
DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods |
DynamicallyAccessedMemberTypes.PublicMethods |
DynamicallyAccessedMemberTypes.NonPublicMethods |
DynamicallyAccessedMemberTypes.PublicConstructors)] TResourceType>(
this IMcpServerBuilder builder)
{
Expand All @@ -330,6 +422,48 @@ where t.GetCustomAttribute<McpServerPromptTypeAttribute>() is not null
return builder;
}

/// <summary>Adds <see cref="McpServerResource"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <typeparam name="TResourceType">The resource type.</typeparam>
/// <param name="builder">The builder instance.</param>
/// <param name="target">The target instance from which the prompts should be sourced.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
/// <remarks>
/// <para>
/// This method discovers all methods (public and non-public) on the specified <typeparamref name="TResourceType"/>
/// type, where the methods are attributed as <see cref="McpServerResourceAttribute"/>, and adds an <see cref="McpServerResource"/>
/// instance for each, using <paramref name="target"/> as the associated instance for instance methods.
/// </para>
/// <para>
/// However, if <typeparamref name="TResourceType"/> is itself an <see cref="IEnumerable{T}"/> of <see cref="McpServerResource"/>,
/// this method will register those resources directly without scanning for methods on <typeparamref name="TResourceType"/>.
/// </para>
/// </remarks>
public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers(
DynamicallyAccessedMemberTypes.PublicMethods |
DynamicallyAccessedMemberTypes.NonPublicMethods)] TResourceType>(
this IMcpServerBuilder builder,
TResourceType target)
{
Throw.IfNull(builder);
Throw.IfNull(target);

if (target is IEnumerable<McpServerResource> resources)
{
return builder.WithResources(resources);
}

foreach (var resourceTemplateMethod in typeof(TResourceType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
{
if (resourceTemplateMethod.GetCustomAttribute<McpServerResourceAttribute>() is not null)
{
builder.Services.AddSingleton(services => McpServerResource.Create(resourceTemplateMethod, target, new() { Services = services }));
}
}

return builder;
}

/// <summary>Adds <see cref="McpServerResource"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <param name="builder">The builder instance.</param>
/// <param name="resourceTemplates">The <see cref="McpServerResource"/> instances to add to the server.</param>
Expand Down Expand Up @@ -412,7 +546,7 @@ public static IMcpServerBuilder WithResources(this IMcpServerBuilder builder, IE
/// </para>
/// <para>
/// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For
/// Native AOT compatibility, consider using the generic <see cref="WithResources{TResourceType}"/> method instead.
/// Native AOT compatibility, consider using the generic <see cref="M:WithResources"/> method instead.
/// </para>
/// </remarks>
[RequiresUnreferencedCode(WithResourcesRequiresUnreferencedCodeMessage)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Moq;
using System.Collections;
using System.ComponentModel;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Channels;

Expand Down Expand Up @@ -217,13 +220,63 @@ public void WithPrompts_InvalidArgs_Throws()

Assert.Throws<ArgumentNullException>("prompts", () => builder.WithPrompts((IEnumerable<McpServerPrompt>)null!));
Assert.Throws<ArgumentNullException>("promptTypes", () => builder.WithPrompts((IEnumerable<Type>)null!));
Assert.Throws<ArgumentNullException>("target", () => builder.WithPrompts<object>(target: null!));

IMcpServerBuilder nullBuilder = null!;
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithPrompts<object>());
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithPrompts(new object()));
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithPrompts(Array.Empty<Type>()));
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithPromptsFromAssembly());
}

[Fact]
public async Task WithPrompts_TargetInstance_UsesTarget()
{
ServiceCollection sc = new();

var target = new SimplePrompts(new ObjectWithId() { Id = "42" });
sc.AddMcpServer().WithPrompts(target);

McpServerPrompt prompt = sc.BuildServiceProvider().GetServices<McpServerPrompt>().First(t => t.ProtocolPrompt.Name == "returns_string");
var result = await prompt.GetAsync(new RequestContext<GetPromptRequestParams>(new Mock<IMcpServer>().Object)
{
Params = new GetPromptRequestParams
{
Name = "returns_string",
Arguments = new Dictionary<string, JsonElement>
{
["message"] = JsonSerializer.SerializeToElement("hello", AIJsonUtilities.DefaultOptions),
}
}
}, TestContext.Current.CancellationToken);

Assert.Equal(target.ReturnsString("hello"), (result.Messages[0].Content as TextContentBlock)?.Text);
}

[Fact]
public async Task WithPrompts_TargetInstance_UsesEnumerableImplementation()
{
ServiceCollection sc = new();

sc.AddMcpServer().WithPrompts(new MyPromptProvider());

var prompts = sc.BuildServiceProvider().GetServices<McpServerPrompt>().ToArray();
Assert.Equal(2, prompts.Length);
Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns42");
Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns43");
}

private sealed class MyPromptProvider : IEnumerable<McpServerPrompt>
{
public IEnumerator<McpServerPrompt> GetEnumerator()
{
yield return McpServerPrompt.Create(() => "42", new() { Name = "Returns42" });
yield return McpServerPrompt.Create(() => "43", new() { Name = "Returns43" });
}

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

[Fact]
public void Empty_Enumerables_Is_Allowed()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Moq;
using System.Collections;
using System.ComponentModel;
using System.Text.Json;
using System.Threading.Channels;
using static ModelContextProtocol.Tests.Configuration.McpServerBuilderExtensionsPromptsTests;

namespace ModelContextProtocol.Tests.Configuration;

Expand Down Expand Up @@ -243,13 +247,59 @@ public void WithResources_InvalidArgs_Throws()

Assert.Throws<ArgumentNullException>("resourceTemplates", () => builder.WithResources((IEnumerable<McpServerResource>)null!));
Assert.Throws<ArgumentNullException>("resourceTemplateTypes", () => builder.WithResources((IEnumerable<Type>)null!));
Assert.Throws<ArgumentNullException>("target", () => builder.WithResources<object>(target: null!));

IMcpServerBuilder nullBuilder = null!;
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithResources<object>());
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithResources(new object()));
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithResources(Array.Empty<Type>()));
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithResourcesFromAssembly());
}

[Fact]
public async Task WithResources_TargetInstance_UsesTarget()
{
ServiceCollection sc = new();

var target = new ResourceWithId(new ObjectWithId() { Id = "42" });
sc.AddMcpServer().WithResources(target);

McpServerResource resource = sc.BuildServiceProvider().GetServices<McpServerResource>().First(t => t.ProtocolResource?.Name == "returns_string");
var result = await resource.ReadAsync(new RequestContext<ReadResourceRequestParams>(new Mock<IMcpServer>().Object)
{
Params = new()
{
Uri = "returns://string"
}
}, TestContext.Current.CancellationToken);

Assert.Equal(target.ReturnsString(), (result?.Contents[0] as TextResourceContents)?.Text);
}

[Fact]
public async Task WithResources_TargetInstance_UsesEnumerableImplementation()
{
ServiceCollection sc = new();

sc.AddMcpServer().WithResources(new MyResourceProvider());

var resources = sc.BuildServiceProvider().GetServices<McpServerResource>().ToArray();
Assert.Equal(2, resources.Length);
Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns42");
Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns43");
}

private sealed class MyResourceProvider : IEnumerable<McpServerResource>
{
public IEnumerator<McpServerResource> GetEnumerator()
{
yield return McpServerResource.Create(() => "42", new() { Name = "Returns42" });
yield return McpServerResource.Create(() => "43", new() { Name = "Returns43" });
}

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

[Fact]
public void Empty_Enumerables_Is_Allowed()
{
Expand Down Expand Up @@ -307,4 +357,11 @@ public sealed class MoreResources
[McpServerResource, Description("Another neat direct resource")]
public static string AnotherNeatDirectResource() => "This is a neat resource";
}

[McpServerResourceType]
public sealed class ResourceWithId(ObjectWithId id)
{
[McpServerResource(UriTemplate = "returns://string")]
public string ReturnsString() => $"Id: {id.Id}";
}
}
Loading
Loading