diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 1853ce7c1..bcf80d62a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -151,7 +151,6 @@ async def initialize(self) -> types.InitializeResult: result = await self.send_request( types.ClientRequest( types.InitializeRequest( - method="initialize", params=types.InitializeRequestParams( protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( @@ -170,20 +169,14 @@ async def initialize(self) -> types.InitializeResult: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") - await self.send_notification( - types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) - ) + await self.send_notification(types.ClientNotification(types.InitializedNotification())) return result async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( - types.ClientRequest( - types.PingRequest( - method="ping", - ) - ), + types.ClientRequest(types.PingRequest()), types.EmptyResult, ) @@ -198,7 +191,6 @@ async def send_progress_notification( await self.send_notification( types.ClientNotification( types.ProgressNotification( - method="notifications/progress", params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, @@ -214,7 +206,6 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul return await self.send_request( types.ClientRequest( types.SetLevelRequest( - method="logging/setLevel", params=types.SetLevelRequestParams(level=level), ) ), @@ -226,7 +217,6 @@ async def list_resources(self, cursor: str | None = None) -> types.ListResources return await self.send_request( types.ClientRequest( types.ListResourcesRequest( - method="resources/list", params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), @@ -238,7 +228,6 @@ async def list_resource_templates(self, cursor: str | None = None) -> types.List return await self.send_request( types.ClientRequest( types.ListResourceTemplatesRequest( - method="resources/templates/list", params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), @@ -250,7 +239,6 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: return await self.send_request( types.ClientRequest( types.ReadResourceRequest( - method="resources/read", params=types.ReadResourceRequestParams(uri=uri), ) ), @@ -262,7 +250,6 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: return await self.send_request( types.ClientRequest( types.SubscribeRequest( - method="resources/subscribe", params=types.SubscribeRequestParams(uri=uri), ) ), @@ -274,7 +261,6 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: return await self.send_request( types.ClientRequest( types.UnsubscribeRequest( - method="resources/unsubscribe", params=types.UnsubscribeRequestParams(uri=uri), ) ), @@ -293,7 +279,6 @@ async def call_tool( result = await self.send_request( types.ClientRequest( types.CallToolRequest( - method="tools/call", params=types.CallToolRequestParams( name=name, arguments=arguments, @@ -337,7 +322,6 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu return await self.send_request( types.ClientRequest( types.ListPromptsRequest( - method="prompts/list", params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), @@ -349,7 +333,6 @@ async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) - return await self.send_request( types.ClientRequest( types.GetPromptRequest( - method="prompts/get", params=types.GetPromptRequestParams(name=name, arguments=arguments), ) ), @@ -370,7 +353,6 @@ async def complete( return await self.send_request( types.ClientRequest( types.CompleteRequest( - method="completion/complete", params=types.CompleteRequestParams( ref=ref, argument=types.CompletionArgument(**argument), @@ -386,7 +368,6 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: result = await self.send_request( types.ClientRequest( types.ListToolsRequest( - method="tools/list", params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), @@ -402,13 +383,7 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - await self.send_notification( - types.ClientNotification( - types.RootsListChangedNotification( - method="notifications/roots/list_changed", - ) - ) - ) + await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession, Any]( diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 5c696b136..48df1171d 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -186,7 +186,6 @@ async def send_log_message( await self.send_notification( types.ServerNotification( types.LoggingMessageNotification( - method="notifications/message", params=types.LoggingMessageNotificationParams( level=level, data=data, @@ -202,7 +201,6 @@ async def send_resource_updated(self, uri: AnyUrl) -> None: await self.send_notification( types.ServerNotification( types.ResourceUpdatedNotification( - method="notifications/resources/updated", params=types.ResourceUpdatedNotificationParams(uri=uri), ) ) @@ -225,7 +223,6 @@ async def create_message( return await self.send_request( request=types.ServerRequest( types.CreateMessageRequest( - method="sampling/createMessage", params=types.CreateMessageRequestParams( messages=messages, systemPrompt=system_prompt, @@ -247,11 +244,7 @@ async def create_message( async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" return await self.send_request( - types.ServerRequest( - types.ListRootsRequest( - method="roots/list", - ) - ), + types.ServerRequest(types.ListRootsRequest()), types.ListRootsResult, ) @@ -273,7 +266,6 @@ async def elicit( return await self.send_request( types.ServerRequest( types.ElicitRequest( - method="elicitation/create", params=types.ElicitRequestParams( message=message, requestedSchema=requestedSchema, @@ -287,11 +279,7 @@ async def elicit( async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( - types.ServerRequest( - types.PingRequest( - method="ping", - ) - ), + types.ServerRequest(types.PingRequest()), types.EmptyResult, ) @@ -307,7 +295,6 @@ async def send_progress_notification( await self.send_notification( types.ServerNotification( types.ProgressNotification( - method="notifications/progress", params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, @@ -321,33 +308,15 @@ async def send_progress_notification( async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" - await self.send_notification( - types.ServerNotification( - types.ResourceListChangedNotification( - method="notifications/resources/list_changed", - ) - ) - ) + await self.send_notification(types.ServerNotification(types.ResourceListChangedNotification())) async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" - await self.send_notification( - types.ServerNotification( - types.ToolListChangedNotification( - method="notifications/tools/list_changed", - ) - ) - ) + await self.send_notification(types.ServerNotification(types.ToolListChangedNotification())) async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" - await self.send_notification( - types.ServerNotification( - types.PromptListChangedNotification( - method="notifications/prompts/list_changed", - ) - ) - ) + await self.send_notification(types.ServerNotification(types.PromptListChangedNotification())) async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/types.py b/src/mcp/types.py index 98fefa080..62feda87a 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -326,7 +326,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]) to begin initialization. """ - method: Literal["initialize"] + method: Literal["initialize"] = "initialize" params: InitializeRequestParams @@ -347,7 +347,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n finished. """ - method: Literal["notifications/initialized"] + method: Literal["notifications/initialized"] = "notifications/initialized" params: NotificationParams | None = None @@ -357,7 +357,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]): still alive. """ - method: Literal["ping"] + method: Literal["ping"] = "ping" params: RequestParams | None = None @@ -390,14 +390,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not long-running request. """ - method: Literal["notifications/progress"] + method: Literal["notifications/progress"] = "notifications/progress" params: ProgressNotificationParams class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]): """Sent from the client to request a list of resources the server has.""" - method: Literal["resources/list"] + method: Literal["resources/list"] = "resources/list" class Annotations(BaseModel): @@ -464,7 +464,7 @@ class ListResourcesResult(PaginatedResult): class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]): """Sent from the client to request a list of resource templates the server has.""" - method: Literal["resources/templates/list"] + method: Literal["resources/templates/list"] = "resources/templates/list" class ListResourceTemplatesResult(PaginatedResult): @@ -487,7 +487,7 @@ class ReadResourceRequestParams(RequestParams): class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]): """Sent from the client to the server, to read a specific resource URI.""" - method: Literal["resources/read"] + method: Literal["resources/read"] = "resources/read" params: ReadResourceRequestParams @@ -537,7 +537,7 @@ class ResourceListChangedNotification( of resources it can read from has changed. """ - method: Literal["notifications/resources/list_changed"] + method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed" params: NotificationParams | None = None @@ -558,7 +558,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr whenever a particular resource changes. """ - method: Literal["resources/subscribe"] + method: Literal["resources/subscribe"] = "resources/subscribe" params: SubscribeRequestParams @@ -576,7 +576,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un the server. """ - method: Literal["resources/unsubscribe"] + method: Literal["resources/unsubscribe"] = "resources/unsubscribe" params: UnsubscribeRequestParams @@ -599,14 +599,14 @@ class ResourceUpdatedNotification( changed and may need to be read again. """ - method: Literal["notifications/resources/updated"] + method: Literal["notifications/resources/updated"] = "notifications/resources/updated" params: ResourceUpdatedNotificationParams class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]): """Sent from the client to request a list of prompts and prompt templates.""" - method: Literal["prompts/list"] + method: Literal["prompts/list"] = "prompts/list" class PromptArgument(BaseModel): @@ -655,7 +655,7 @@ class GetPromptRequestParams(RequestParams): class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]): """Used by the client to get a prompt provided by the server.""" - method: Literal["prompts/get"] + method: Literal["prompts/get"] = "prompts/get" params: GetPromptRequestParams @@ -782,14 +782,14 @@ class PromptListChangedNotification( of prompts it offers has changed. """ - method: Literal["notifications/prompts/list_changed"] + method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed" params: NotificationParams | None = None class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]): """Sent from the client to request a list of tools the server has.""" - method: Literal["tools/list"] + method: Literal["tools/list"] = "tools/list" class ToolAnnotations(BaseModel): @@ -879,7 +879,7 @@ class CallToolRequestParams(RequestParams): class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): """Used by the client to invoke a tool provided by the server.""" - method: Literal["tools/call"] + method: Literal["tools/call"] = "tools/call" params: CallToolRequestParams @@ -898,7 +898,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera of tools it offers has changed. """ - method: Literal["notifications/tools/list_changed"] + method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed" params: NotificationParams | None = None @@ -916,7 +916,7 @@ class SetLevelRequestParams(RequestParams): class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]): """A request from the client to the server, to enable or adjust logging.""" - method: Literal["logging/setLevel"] + method: Literal["logging/setLevel"] = "logging/setLevel" params: SetLevelRequestParams @@ -938,7 +938,7 @@ class LoggingMessageNotificationParams(NotificationParams): class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]): """Notification of a log message passed from server to client.""" - method: Literal["notifications/message"] + method: Literal["notifications/message"] = "notifications/message" params: LoggingMessageNotificationParams @@ -1033,7 +1033,7 @@ class CreateMessageRequestParams(RequestParams): class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]): """A request from the server to sample an LLM via the client.""" - method: Literal["sampling/createMessage"] + method: Literal["sampling/createMessage"] = "sampling/createMessage" params: CreateMessageRequestParams @@ -1105,7 +1105,7 @@ class CompleteRequestParams(RequestParams): class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): """A request from the client to the server, to ask for completion options.""" - method: Literal["completion/complete"] + method: Literal["completion/complete"] = "completion/complete" params: CompleteRequestParams @@ -1144,7 +1144,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): structure or access specific locations that the client has permission to read from. """ - method: Literal["roots/list"] + method: Literal["roots/list"] = "roots/list" params: RequestParams | None = None @@ -1193,7 +1193,7 @@ class RootsListChangedNotification( using the ListRootsRequest. """ - method: Literal["notifications/roots/list_changed"] + method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed" params: NotificationParams | None = None @@ -1213,7 +1213,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n previously-issued request. """ - method: Literal["notifications/cancelled"] + method: Literal["notifications/cancelled"] = "notifications/cancelled" params: CancelledNotificationParams @@ -1259,7 +1259,7 @@ class ElicitRequestParams(RequestParams): class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]): """A request from the server to elicit information from the client.""" - method: Literal["elicitation/create"] + method: Literal["elicitation/create"] = "elicitation/create" params: ElicitRequestParams diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index 0752d649f..e0b481581 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -35,7 +35,7 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er ) # Create a test request - request = ClientRequest(PingRequest(method="ping")) + request = ClientRequest(PingRequest()) # Patch the _write_stream.send method to raise an exception async def mock_send(*args: Any, **kwargs: Any): diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 4bedb15d5..ec9264c47 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -24,7 +24,7 @@ def get_user_profile(user_id: str) -> str: # Note: list_resource_templates() returns a decorator that wraps the handler # The handler returns a ServerResult with a ListResourceTemplatesResult inside result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest]( - types.ListResourceTemplatesRequest(method="resources/templates/list", params=None) + types.ListResourceTemplatesRequest(params=None) ) assert isinstance(result.root, types.ListResourceTemplatesResult) templates = result.root.resourceTemplates diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 6a6e410c7..da5695997 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -54,7 +54,6 @@ async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]: # Create a request request = ReadResourceRequest( - method="resources/read", params=ReadResourceRequestParams(uri=AnyUrl("test://resource")), ) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index e7149826b..516642c4b 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -61,7 +61,6 @@ async def first_request(): await client.send_request( ClientRequest( CallToolRequest( - method="tools/call", params=CallToolRequestParams(name="test_tool", arguments={}), ) ), @@ -83,7 +82,6 @@ async def first_request(): await client.send_notification( ClientNotification( CancelledNotification( - method="notifications/cancelled", params=CancelledNotificationParams( requestId=first_request_id, reason="Testing server recovery", @@ -96,7 +94,6 @@ async def first_request(): result = await client.send_request( ClientRequest( CallToolRequest( - method="tools/call", params=CallToolRequestParams(name="test_tool", arguments={}), ) ), diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 91f6ef8c8..d97477e10 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -35,7 +35,6 @@ async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]: # Create a request request = types.ReadResourceRequest( - method="resources/read", params=types.ReadResourceRequestParams(uri=FileUrl(temp_file.as_uri())), ) @@ -63,7 +62,6 @@ async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]: # Create a request request = types.ReadResourceRequest( - method="resources/read", params=types.ReadResourceRequestParams(uri=FileUrl(temp_file.as_uri())), ) @@ -95,7 +93,6 @@ async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]: # Create a request request = types.ReadResourceRequest( - method="resources/read", params=types.ReadResourceRequestParams(uri=FileUrl(temp_file.as_uri())), ) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index c2c023c71..320693786 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -88,7 +88,6 @@ async def make_request(client_session: ClientSession): await client_session.send_request( ClientRequest( types.CallToolRequest( - method="tools/call", params=types.CallToolRequestParams(name="slow_tool", arguments={}), ) ), @@ -113,7 +112,6 @@ async def make_request(client_session: ClientSession): await client_session.send_notification( ClientNotification( CancelledNotification( - method="notifications/cancelled", params=CancelledNotificationParams(requestId=request_id), ) ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index ecbe6eb08..55800da33 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1138,7 +1138,6 @@ async def run_tool(): await session.send_request( types.ClientRequest( types.CallToolRequest( - method="tools/call", params=types.CallToolRequestParams( name="wait_for_lock_with_notification", arguments={} ), @@ -1180,7 +1179,6 @@ async def run_tool(): result = await session.send_request( types.ClientRequest( types.CallToolRequest( - method="tools/call", params=types.CallToolRequestParams(name="release_lock", arguments={}), ) ), @@ -1193,7 +1191,6 @@ async def run_tool(): result = await session.send_request( types.ClientRequest( types.CallToolRequest( - method="tools/call", params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), ) ), diff --git a/tests/test_types.py b/tests/test_types.py index d7f2ac831..415eba66a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,6 +1,15 @@ import pytest -from mcp.types import LATEST_PROTOCOL_VERSION, ClientRequest, JSONRPCMessage, JSONRPCRequest +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + ClientRequest, + Implementation, + InitializeRequest, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCRequest, +) @pytest.mark.anyio @@ -25,3 +34,25 @@ async def test_jsonrpc_request(): assert request.root.method == "initialize" assert request.root.params is not None assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION + + +@pytest.mark.anyio +async def test_method_initialization(): + """ + Test that the method is automatically set on object creation. + Testing just for InitializeRequest to keep the test simple, but should be set for other types as well. + """ + initialize_request = InitializeRequest( + params=InitializeRequestParams( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + clientInfo=Implementation( + name="mcp", + version="0.1.0", + ), + ) + ) + + assert initialize_request.method == "initialize", "method should be set to 'initialize'" + assert initialize_request.params is not None + assert initialize_request.params.protocolVersion == LATEST_PROTOCOL_VERSION