diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index d925b24f..2d6314ba 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -53,6 +53,53 @@ public static partial class McpServerBuilderExtensions return builder; } + /// Adds instances to the service collection backing . + /// The tool type. + /// The builder instance. + /// The target instance from which the tools should be sourced. + /// The serializer options governing tool parameter marshalling. + /// The builder provided in . + /// is . + /// + /// + /// This method discovers all methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those tools directly without scanning for methods on . + /// + /// + public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TToolType>( + this IMcpServerBuilder builder, + TToolType target, + JsonSerializerOptions? serializerOptions = null) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + if (target is IEnumerable tools) + { + return builder.WithTools(tools); + } + + foreach (var toolMethod in typeof(TToolType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (toolMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerTool.Create( + toolMethod, + toolMethod.IsStatic ? null : target, + new() { Services = services, SerializerOptions = serializerOptions })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -137,7 +184,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, IEnume /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)] @@ -193,6 +240,50 @@ where t.GetCustomAttribute() is not null return builder; } + /// Adds instances to the service collection backing . + /// The prompt type. + /// The builder instance. + /// The target instance from which the prompts should be sourced. + /// The serializer options governing prompt parameter marshalling. + /// The builder provided in . + /// is . + /// + /// + /// This method discovers all methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those prompts directly without scanning for methods on . + /// + /// + public static IMcpServerBuilder WithPrompts<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TPromptType>( + this IMcpServerBuilder builder, + TPromptType target, + JsonSerializerOptions? serializerOptions = null) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + if (target is IEnumerable prompts) + { + return builder.WithPrompts(prompts); + } + + foreach (var promptMethod in typeof(TPromptType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (promptMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerPrompt.Create(promptMethod, target, new() { Services = services, SerializerOptions = serializerOptions })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -277,7 +368,7 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, IEnu /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)] @@ -311,7 +402,8 @@ where t.GetCustomAttribute() is not null /// instance for each. For instance members, an instance will be constructed for each invocation of the resource. /// public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers( - DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods | + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods | DynamicallyAccessedMemberTypes.PublicConstructors)] TResourceType>( this IMcpServerBuilder builder) { @@ -330,6 +422,48 @@ where t.GetCustomAttribute() is not null return builder; } + /// Adds instances to the service collection backing . + /// The resource type. + /// The builder instance. + /// The target instance from which the prompts should be sourced. + /// The builder provided in . + /// is . + /// + /// + /// This method discovers all methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those resources directly without scanning for methods on . + /// + /// + public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TResourceType>( + this IMcpServerBuilder builder, + TResourceType target) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + if (target is IEnumerable resources) + { + return builder.WithResources(resources); + } + + foreach (var resourceTemplateMethod in typeof(TResourceType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (resourceTemplateMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerResource.Create(resourceTemplateMethod, target, new() { Services = services })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -412,7 +546,7 @@ public static IMcpServerBuilder WithResources(this IMcpServerBuilder builder, IE /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithResourcesRequiresUnreferencedCodeMessage)] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 3fa2ec78..d697b979 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -4,7 +4,10 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using Moq; +using System.Collections; using System.ComponentModel; +using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Channels; @@ -217,13 +220,63 @@ public void WithPrompts_InvalidArgs_Throws() Assert.Throws("prompts", () => builder.WithPrompts((IEnumerable)null!)); Assert.Throws("promptTypes", () => builder.WithPrompts((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithPrompts(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithPrompts()); + Assert.Throws("builder", () => nullBuilder.WithPrompts(new object())); Assert.Throws("builder", () => nullBuilder.WithPrompts(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithPromptsFromAssembly()); } + [Fact] + public async Task WithPrompts_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new SimplePrompts(new ObjectWithId() { Id = "42" }); + sc.AddMcpServer().WithPrompts(target); + + McpServerPrompt prompt = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolPrompt.Name == "returns_string"); + var result = await prompt.GetAsync(new RequestContext(new Mock().Object) + { + Params = new GetPromptRequestParams + { + Name = "returns_string", + Arguments = new Dictionary + { + ["message"] = JsonSerializer.SerializeToElement("hello", AIJsonUtilities.DefaultOptions), + } + } + }, TestContext.Current.CancellationToken); + + Assert.Equal(target.ReturnsString("hello"), (result.Messages[0].Content as TextContentBlock)?.Text); + } + + [Fact] + public async Task WithPrompts_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithPrompts(new MyPromptProvider()); + + var prompts = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, prompts.Length); + Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns42"); + Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns43"); + } + + private sealed class MyPromptProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerPrompt.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerPrompt.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public void Empty_Enumerables_Is_Allowed() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index ed930b17..e6b177f5 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -4,8 +4,12 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using Moq; +using System.Collections; using System.ComponentModel; +using System.Text.Json; using System.Threading.Channels; +using static ModelContextProtocol.Tests.Configuration.McpServerBuilderExtensionsPromptsTests; namespace ModelContextProtocol.Tests.Configuration; @@ -243,13 +247,59 @@ public void WithResources_InvalidArgs_Throws() Assert.Throws("resourceTemplates", () => builder.WithResources((IEnumerable)null!)); Assert.Throws("resourceTemplateTypes", () => builder.WithResources((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithResources(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithResources()); + Assert.Throws("builder", () => nullBuilder.WithResources(new object())); Assert.Throws("builder", () => nullBuilder.WithResources(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithResourcesFromAssembly()); } + [Fact] + public async Task WithResources_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new ResourceWithId(new ObjectWithId() { Id = "42" }); + sc.AddMcpServer().WithResources(target); + + McpServerResource resource = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolResource?.Name == "returns_string"); + var result = await resource.ReadAsync(new RequestContext(new Mock().Object) + { + Params = new() + { + Uri = "returns://string" + } + }, TestContext.Current.CancellationToken); + + Assert.Equal(target.ReturnsString(), (result?.Contents[0] as TextResourceContents)?.Text); + } + + [Fact] + public async Task WithResources_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithResources(new MyResourceProvider()); + + var resources = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, resources.Length); + Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns42"); + Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns43"); + } + + private sealed class MyResourceProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerResource.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerResource.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public void Empty_Enumerables_Is_Allowed() { @@ -307,4 +357,11 @@ public sealed class MoreResources [McpServerResource, Description("Another neat direct resource")] public static string AnotherNeatDirectResource() => "This is a neat resource"; } + + [McpServerResourceType] + public sealed class ResourceWithId(ObjectWithId id) + { + [McpServerResource(UriTemplate = "returns://string")] + public string ReturnsString() => $"Id: {id.Id}"; + } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 35f833d5..dbea036d 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -5,6 +5,8 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using Moq; +using System.Collections; using System.Collections.Concurrent; using System.ComponentModel; using System.IO.Pipelines; @@ -403,9 +405,11 @@ public void WithTools_InvalidArgs_Throws() Assert.Throws("tools", () => builder.WithTools((IEnumerable)null!)); Assert.Throws("toolTypes", () => builder.WithTools((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithTools(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithTools()); + Assert.Throws("builder", () => nullBuilder.WithTools(new object())); Assert.Throws("builder", () => nullBuilder.WithTools(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithToolsFromAssembly()); } @@ -503,6 +507,44 @@ public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime } } + [Fact] + public async Task WithTools_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new EchoTool(new ObjectWithId()); + sc.AddMcpServer().WithTools(target, BuilderToolsJsonContext.Default.Options); + + McpServerTool tool = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolTool.Name == "get_ctor_parameter"); + var result = await tool.InvokeAsync(new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); + + Assert.Equal(target.GetCtorParameter(), (result.Content[0] as TextContentBlock)?.Text); + } + + [Fact] + public async Task WithTools_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithTools(new MyToolProvider()); + + var tools = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, tools.Length); + Assert.Contains(tools, t => t.ProtocolTool.Name == "Returns42"); + Assert.Contains(tools, t => t.ProtocolTool.Name == "Returns43"); + } + + private sealed class MyToolProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerTool.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerTool.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public async Task Recognizes_Parameter_Types() {