Skip to content

Feature: Non-Blocking call_tool and request state externalisation #1209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
4165200
add methods to enable call tool requests to be started and joined at …
davemssavage Jul 12, 2025
04ff73a
refactor args for clearer meaning, use error vs returning none on tim…
davemssavage Jul 12, 2025
288ebe3
add resume logic to request/join call_tool functions
davemssavage Jul 14, 2025
40028da
Remove None as valid return type from join_call_tool, fix typo ImMemo…
davemssavage Jul 14, 2025
161da46
send resume on init rather than part of join, refactor resume to be g…
davemssavage Jul 14, 2025
aa2cbec
fix import error
davemssavage Jul 16, 2025
7329cba
Refactor code to send resume as part of join call rather than it, thi…
davemssavage Jul 26, 2025
e4c25b7
simplify token capture using events rather than streams, add test for…
davemssavage Jul 27, 2025
79f3c4e
avoid exceptions during join call tool on timeout as this is expected…
davemssavage Jul 28, 2025
6c47890
Merge branch 'main' into feature/call-futures
davemssavage Jul 28, 2025
f262bb6
uv ruff fixes
davemssavage Jul 28, 2025
c5eab90
add assert for pyright checks
davemssavage Jul 28, 2025
79eb3c9
update test description
davemssavage Jul 28, 2025
a8ffd71
pass related request id on progress to allow this to trigger event id…
davemssavage Aug 16, 2025
cca5e34
ruff format fixes
davemssavage Aug 16, 2025
5750c32
Merge branch 'main' into feature/call-futures
davemssavage Aug 16, 2025
f1a973a
use move on after rather than fail_after to simplify code for the sam…
davemssavage Aug 18, 2025
746d3b8
add timeout to request_call_tool to enable clients to unblock if serv…
davemssavage Aug 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 89 additions & 2 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, RequestStateManager
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
Expand Down Expand Up @@ -118,13 +118,15 @@ def __init__(
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
request_state_manager: RequestStateManager[types.ClientRequest, types.ClientResult] | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
request_state_manager=request_state_manager,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
Expand All @@ -133,6 +135,7 @@ def __init__(
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._resumable = False

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
Expand Down Expand Up @@ -170,6 +173,8 @@ async def initialize(self) -> types.InitializeResult:
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")

self._resumable = result.capabilities.resume and result.capabilities.resume.resumable

await self.send_notification(
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
)
Expand Down Expand Up @@ -281,6 +286,88 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
types.EmptyResult,
)

async def request_call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
progress_callback: ProgressFnT | None = None,
timeout: float | None = None,
cancel_if_not_resumable: bool = False,
) -> types.RequestId | None:
if self._resumable:
captured_token = None
captured = anyio.Event()

async def capture_token(token: str):
nonlocal captured_token
captured_token = token
captured.set()

metadata = ClientMessageMetadata(on_resumption_token_update=capture_token)

request_id = await self.start_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
),
)
),
progress_callback=progress_callback,
metadata=metadata,
)

try:
with anyio.fail_after(timeout):
while captured_token is None:
await captured.wait()

await self._request_state_manager.update_resume_token(request_id, captured_token)

return request_id
except TimeoutError:
if cancel_if_not_resumable:
with anyio.CancelScope(shield=True):
with anyio.move_on_after(timeout):
await self.cancel_call_tool(request_id=request_id)
return None
else:
return await self.start_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
),
)
),
progress_callback=progress_callback,
)

async def join_call_tool(
self,
request_id: types.RequestId,
progress_callback: ProgressFnT | None = None,
request_read_timeout_seconds: timedelta | None = None,
done_on_timeout: bool = True,
) -> types.CallToolResult | None:
return await self.join_request(
request_id,
types.CallToolResult,
request_read_timeout_seconds=request_read_timeout_seconds,
progress_callback=progress_callback,
done_on_timeout=done_on_timeout,
)

async def cancel_call_tool(
self,
request_id: types.RequestId,
) -> bool:
return await self.cancel_request(request_id)

async def call_tool(
self,
name: str,
Expand Down
30 changes: 21 additions & 9 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
JSONRPCRequest,
JSONRPCResponse,
RequestId,
ResumeCapability,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -136,18 +137,26 @@ def _maybe_extract_session_id_from_response(
def _maybe_extract_protocol_version_from_message(
self,
message: JSONRPCMessage,
) -> None:
) -> JSONRPCMessage:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result:
try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result)
self.protocol_version = str(init_result.protocolVersion)
logger.info(f"Negotiated protocol version: {self.protocol_version}")
if init_result.capabilities.resume is None:
# resumeablity is predicated on the server and the transport
# this assumes that if the server hasn't explicitly configured
# that streamable http transports are resumeable
init_result.capabilities.resume = ResumeCapability(resumable=True)
message.root.result = init_result.model_dump()
except Exception as exc:
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
logger.warning(f"Raw result: {message.root.result}")

return message

async def _handle_sse_event(
self,
sse: ServerSentEvent,
Expand All @@ -164,7 +173,7 @@ async def _handle_sse_event(

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
message = self._maybe_extract_protocol_version_from_message(message)

# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
Expand Down Expand Up @@ -304,7 +313,7 @@ async def _handle_json_response(

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
message = self._maybe_extract_protocol_version_from_message(message)

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
Expand Down Expand Up @@ -335,7 +344,10 @@ async def _handle_sse_response(
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
try:
await ctx.read_stream_writer.send(e)
except anyio.ClosedResourceError:
pass

async def _handle_unexpected_content_type(
self,
Expand Down Expand Up @@ -473,8 +485,8 @@ async def streamablehttp_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

async with anyio.create_task_group() as tg:
try:
try:
async with anyio.create_task_group() as tg:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

async with httpx_client_factory(
Expand Down Expand Up @@ -506,6 +518,6 @@ def start_get_stream() -> None:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
1 change: 1 addition & 0 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes
progress=progress,
total=total,
message=message,
related_request_id=self.request_context.request_id,
)

async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ async def send_progress_notification(
progress: float,
total: float | None = None,
message: str | None = None,
related_request_id: str | None = None,
related_request_id: types.RequestId | None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
Expand Down
1 change: 0 additions & 1 deletion src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,6 @@ async def send_event(event_message: EventMessage) -> None:
async with msg_reader:
async for event_message in msg_reader:
event_data = self._create_event_data(event_message)

await sse_stream_writer.send(event_data)
except Exception:
logger.exception("Error in replay sender")
Expand Down
Loading
Loading