From cb8861a088a7d4bce70c695c133c2610ccc87fbf Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 13 Aug 2025 13:20:16 -0400 Subject: [PATCH 1/2] Add WithXx overloads that take target instance --- .../McpServerBuilderExtensions.cs | 106 +++++++++++++++++- .../McpServerBuilderExtensionsPromptsTests.cs | 28 +++++ ...cpServerBuilderExtensionsResourcesTests.cs | 32 ++++++ .../McpServerBuilderExtensionsToolsTests.cs | 17 +++ 4 files changed, 179 insertions(+), 4 deletions(-) diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index d925b24f..c845846e 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -53,6 +53,39 @@ public static partial class McpServerBuilderExtensions return builder; } + /// Adds instances to the service collection backing . + /// The tool type. + /// The builder instance. + /// The target instance from which the tools should be sourced. + /// The serializer options governing tool parameter marshalling. + /// The builder provided in . + /// is . + /// + /// This method discovers all instance methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance. + /// + public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TToolType>( + this IMcpServerBuilder builder, + TToolType target, + JsonSerializerOptions? serializerOptions = null) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + foreach (var toolMethod in typeof(TToolType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (toolMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, target, new() { Services = services, SerializerOptions = serializerOptions })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -137,7 +170,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, IEnume /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)] @@ -193,6 +226,39 @@ where t.GetCustomAttribute() is not null return builder; } + /// Adds instances to the service collection backing . + /// The prompt type. + /// The builder instance. + /// The target instance from which the prompts should be sourced. + /// The serializer options governing prompt parameter marshalling. + /// The builder provided in . + /// is . + /// + /// This method discovers all instance methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance. + /// + public static IMcpServerBuilder WithPrompts<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TPromptType>( + this IMcpServerBuilder builder, + TPromptType target, + JsonSerializerOptions? serializerOptions = null) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + foreach (var promptMethod in typeof(TPromptType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (promptMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerPrompt.Create(promptMethod, target, new() { Services = services, SerializerOptions = serializerOptions })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -277,7 +343,7 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, IEnu /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)] @@ -311,7 +377,8 @@ where t.GetCustomAttribute() is not null /// instance for each. For instance members, an instance will be constructed for each invocation of the resource. /// public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers( - DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods | + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods | DynamicallyAccessedMemberTypes.PublicConstructors)] TResourceType>( this IMcpServerBuilder builder) { @@ -330,6 +397,37 @@ where t.GetCustomAttribute() is not null return builder; } + /// Adds instances to the service collection backing . + /// The resource type. + /// The builder instance. + /// The target instance from which the prompts should be sourced. + /// The builder provided in . + /// is . + /// + /// This method discovers all instance methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance. + /// + public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TResourceType>( + this IMcpServerBuilder builder, + TResourceType target) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + foreach (var resourceTemplateMethod in typeof(TResourceType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (resourceTemplateMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerResource.Create(resourceTemplateMethod, target, new() { Services = services })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -412,7 +510,7 @@ public static IMcpServerBuilder WithResources(this IMcpServerBuilder builder, IE /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithResourcesRequiresUnreferencedCodeMessage)] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 3fa2ec78..0f3ea550 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -4,7 +4,9 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using Moq; using System.ComponentModel; +using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Channels; @@ -217,13 +219,39 @@ public void WithPrompts_InvalidArgs_Throws() Assert.Throws("prompts", () => builder.WithPrompts((IEnumerable)null!)); Assert.Throws("promptTypes", () => builder.WithPrompts((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithPrompts(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithPrompts()); + Assert.Throws("builder", () => nullBuilder.WithPrompts(new object())); Assert.Throws("builder", () => nullBuilder.WithPrompts(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithPromptsFromAssembly()); } + [Fact] + public async Task WithPrompts_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new SimplePrompts(new ObjectWithId() { Id = "42" }); + sc.AddMcpServer().WithPrompts(target); + + McpServerPrompt prompt = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolPrompt.Name == "returns_string"); + var result = await prompt.GetAsync(new RequestContext(new Mock().Object) + { + Params = new GetPromptRequestParams + { + Name = "returns_string", + Arguments = new Dictionary + { + ["message"] = JsonSerializer.SerializeToElement("hello", AIJsonUtilities.DefaultOptions), + } + } + }, TestContext.Current.CancellationToken); + + Assert.Equal(target.ReturnsString("hello"), (result.Messages[0].Content as TextContentBlock)?.Text); + } + [Fact] public void Empty_Enumerables_Is_Allowed() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index ed930b17..aa4a8607 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -4,8 +4,11 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using Moq; using System.ComponentModel; +using System.Text.Json; using System.Threading.Channels; +using static ModelContextProtocol.Tests.Configuration.McpServerBuilderExtensionsPromptsTests; namespace ModelContextProtocol.Tests.Configuration; @@ -243,13 +246,35 @@ public void WithResources_InvalidArgs_Throws() Assert.Throws("resourceTemplates", () => builder.WithResources((IEnumerable)null!)); Assert.Throws("resourceTemplateTypes", () => builder.WithResources((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithResources(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithResources()); + Assert.Throws("builder", () => nullBuilder.WithResources(new object())); Assert.Throws("builder", () => nullBuilder.WithResources(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithResourcesFromAssembly()); } + [Fact] + public async Task WithResources_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new ResourceWithId(new ObjectWithId() { Id = "42" }); + sc.AddMcpServer().WithResources(target); + + McpServerResource resource = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolResource?.Name == "returns_string"); + var result = await resource.ReadAsync(new RequestContext(new Mock().Object) + { + Params = new() + { + Uri = "returns://string" + } + }, TestContext.Current.CancellationToken); + + Assert.Equal(target.ReturnsString(), (result?.Contents[0] as TextResourceContents)?.Text); + } + [Fact] public void Empty_Enumerables_Is_Allowed() { @@ -307,4 +332,11 @@ public sealed class MoreResources [McpServerResource, Description("Another neat direct resource")] public static string AnotherNeatDirectResource() => "This is a neat resource"; } + + [McpServerResourceType] + public sealed class ResourceWithId(ObjectWithId id) + { + [McpServerResource(UriTemplate = "returns://string")] + public string ReturnsString() => $"Id: {id.Id}"; + } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 35f833d5..4e430fce 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using Moq; using System.Collections.Concurrent; using System.ComponentModel; using System.IO.Pipelines; @@ -403,9 +404,11 @@ public void WithTools_InvalidArgs_Throws() Assert.Throws("tools", () => builder.WithTools((IEnumerable)null!)); Assert.Throws("toolTypes", () => builder.WithTools((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithTools(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithTools()); + Assert.Throws("builder", () => nullBuilder.WithTools(new object())); Assert.Throws("builder", () => nullBuilder.WithTools(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithToolsFromAssembly()); } @@ -503,6 +506,20 @@ public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime } } + [Fact] + public async Task WithTools_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new EchoTool(new ObjectWithId()); + sc.AddMcpServer().WithTools(target, BuilderToolsJsonContext.Default.Options); + + McpServerTool tool = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolTool.Name == "get_ctor_parameter"); + var result = await tool.InvokeAsync(new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); + + Assert.Equal(target.GetCtorParameter(), (result.Content[0] as TextContentBlock)?.Text); + } + [Fact] public async Task Recognizes_Parameter_Types() { From 09a4a9905a1efa24565dae38da41fd3e666131e2 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 14 Aug 2025 18:22:28 -0400 Subject: [PATCH 2/2] Special-case enumerables --- .../McpServerBuilderExtensions.cs | 52 ++++++++++++++++--- .../McpServerBuilderExtensionsPromptsTests.cs | 25 +++++++++ ...cpServerBuilderExtensionsResourcesTests.cs | 25 +++++++++ .../McpServerBuilderExtensionsToolsTests.cs | 25 +++++++++ 4 files changed, 119 insertions(+), 8 deletions(-) diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index c845846e..2d6314ba 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -61,9 +61,15 @@ public static partial class McpServerBuilderExtensions /// The builder provided in . /// is . /// - /// This method discovers all instance methods (public and non-public) on the specified + /// + /// This method discovers all methods (public and non-public) on the specified /// type, where the methods are attributed as , and adds an - /// instance for each, using as the associated instance. + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those tools directly without scanning for methods on . + /// /// public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers( DynamicallyAccessedMemberTypes.PublicMethods | @@ -75,11 +81,19 @@ public static partial class McpServerBuilderExtensions Throw.IfNull(builder); Throw.IfNull(target); + if (target is IEnumerable tools) + { + return builder.WithTools(tools); + } + foreach (var toolMethod in typeof(TToolType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) { if (toolMethod.GetCustomAttribute() is not null) { - builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, target, new() { Services = services, SerializerOptions = serializerOptions })); + builder.Services.AddSingleton(services => McpServerTool.Create( + toolMethod, + toolMethod.IsStatic ? null : target, + new() { Services = services, SerializerOptions = serializerOptions })); } } @@ -234,9 +248,15 @@ where t.GetCustomAttribute() is not null /// The builder provided in . /// is . /// - /// This method discovers all instance methods (public and non-public) on the specified + /// + /// This method discovers all methods (public and non-public) on the specified /// type, where the methods are attributed as , and adds an - /// instance for each, using as the associated instance. + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those prompts directly without scanning for methods on . + /// /// public static IMcpServerBuilder WithPrompts<[DynamicallyAccessedMembers( DynamicallyAccessedMemberTypes.PublicMethods | @@ -248,6 +268,11 @@ where t.GetCustomAttribute() is not null Throw.IfNull(builder); Throw.IfNull(target); + if (target is IEnumerable prompts) + { + return builder.WithPrompts(prompts); + } + foreach (var promptMethod in typeof(TPromptType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) { if (promptMethod.GetCustomAttribute() is not null) @@ -404,9 +429,15 @@ where t.GetCustomAttribute() is not null /// The builder provided in . /// is . /// - /// This method discovers all instance methods (public and non-public) on the specified - /// type, where the methods are attributed as , and adds an - /// instance for each, using as the associated instance. + /// + /// This method discovers all methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those resources directly without scanning for methods on . + /// /// public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers( DynamicallyAccessedMemberTypes.PublicMethods | @@ -417,6 +448,11 @@ where t.GetCustomAttribute() is not null Throw.IfNull(builder); Throw.IfNull(target); + if (target is IEnumerable resources) + { + return builder.WithResources(resources); + } + foreach (var resourceTemplateMethod in typeof(TResourceType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) { if (resourceTemplateMethod.GetCustomAttribute() is not null) diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 0f3ea550..d697b979 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Moq; +using System.Collections; using System.ComponentModel; using System.Text.Json; using System.Text.Json.Serialization; @@ -252,6 +253,30 @@ public async Task WithPrompts_TargetInstance_UsesTarget() Assert.Equal(target.ReturnsString("hello"), (result.Messages[0].Content as TextContentBlock)?.Text); } + [Fact] + public async Task WithPrompts_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithPrompts(new MyPromptProvider()); + + var prompts = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, prompts.Length); + Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns42"); + Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns43"); + } + + private sealed class MyPromptProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerPrompt.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerPrompt.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public void Empty_Enumerables_Is_Allowed() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index aa4a8607..e6b177f5 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Moq; +using System.Collections; using System.ComponentModel; using System.Text.Json; using System.Threading.Channels; @@ -275,6 +276,30 @@ public async Task WithResources_TargetInstance_UsesTarget() Assert.Equal(target.ReturnsString(), (result?.Contents[0] as TextResourceContents)?.Text); } + [Fact] + public async Task WithResources_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithResources(new MyResourceProvider()); + + var resources = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, resources.Length); + Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns42"); + Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns43"); + } + + private sealed class MyResourceProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerResource.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerResource.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public void Empty_Enumerables_Is_Allowed() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 4e430fce..dbea036d 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -6,6 +6,7 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Moq; +using System.Collections; using System.Collections.Concurrent; using System.ComponentModel; using System.IO.Pipelines; @@ -520,6 +521,30 @@ public async Task WithTools_TargetInstance_UsesTarget() Assert.Equal(target.GetCtorParameter(), (result.Content[0] as TextContentBlock)?.Text); } + [Fact] + public async Task WithTools_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithTools(new MyToolProvider()); + + var tools = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, tools.Length); + Assert.Contains(tools, t => t.ProtocolTool.Name == "Returns42"); + Assert.Contains(tools, t => t.ProtocolTool.Name == "Returns43"); + } + + private sealed class MyToolProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerTool.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerTool.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public async Task Recognizes_Parameter_Types() {