Skip to content

Commit 1657a51

Browse files
authored
Merge branch 'modelcontextprotocol:main' into setting-default-method
2 parents 558aa5a + 9a8592e commit 1657a51

File tree

7 files changed

+197
-16
lines changed

7 files changed

+197
-16
lines changed

src/mcp/client/auth.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,6 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
546546
logger.exception("OAuth flow error")
547547
raise
548548

549-
# Retry with new tokens
550-
self._add_auth_header(request)
551-
yield request
549+
# Retry with new tokens
550+
self._add_auth_header(request)
551+
yield request

src/mcp/client/stdio/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"HOMEPATH",
3333
"LOCALAPPDATA",
3434
"PATH",
35+
"PATHEXT",
3536
"PROCESSOR_ARCHITECTURE",
3637
"SYSTEMDRIVE",
3738
"SYSTEMROOT",

src/mcp/server/fastmcp/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from importlib.metadata import version
44

55
from .server import Context, FastMCP
6-
from .utilities.types import Image
6+
from .utilities.types import Audio, Image
77

88
__version__ = version("mcp")
9-
__all__ = ["FastMCP", "Context", "Image"]
9+
__all__ = ["FastMCP", "Context", "Image", "Audio"]

src/mcp/server/fastmcp/utilities/func_metadata.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from mcp.server.fastmcp.exceptions import InvalidSignature
2323
from mcp.server.fastmcp.utilities.logging import get_logger
24-
from mcp.server.fastmcp.utilities.types import Image
24+
from mcp.server.fastmcp.utilities.types import Audio, Image
2525
from mcp.types import ContentBlock, TextContent
2626

2727
logger = get_logger(__name__)
@@ -506,6 +506,9 @@ def _convert_to_content(
506506
if isinstance(result, Image):
507507
return [result.to_image_content()]
508508

509+
if isinstance(result, Audio):
510+
return [result.to_audio_content()]
511+
509512
if isinstance(result, list | tuple):
510513
return list(
511514
chain.from_iterable(

src/mcp/server/fastmcp/utilities/types.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import base64
44
from pathlib import Path
55

6-
from mcp.types import ImageContent
6+
from mcp.types import AudioContent, ImageContent
77

88

99
class Image:
@@ -52,3 +52,50 @@ def to_image_content(self) -> ImageContent:
5252
raise ValueError("No image data available")
5353

5454
return ImageContent(type="image", data=data, mimeType=self._mime_type)
55+
56+
57+
class Audio:
58+
"""Helper class for returning audio from tools."""
59+
60+
def __init__(
61+
self,
62+
path: str | Path | None = None,
63+
data: bytes | None = None,
64+
format: str | None = None,
65+
):
66+
if not bool(path) ^ bool(data):
67+
raise ValueError("Either path or data can be provided")
68+
69+
self.path = Path(path) if path else None
70+
self.data = data
71+
self._format = format
72+
self._mime_type = self._get_mime_type()
73+
74+
def _get_mime_type(self) -> str:
75+
"""Get MIME type from format or guess from file extension."""
76+
if self._format:
77+
return f"audio/{self._format.lower()}"
78+
79+
if self.path:
80+
suffix = self.path.suffix.lower()
81+
return {
82+
".wav": "audio/wav",
83+
".mp3": "audio/mpeg",
84+
".ogg": "audio/ogg",
85+
".flac": "audio/flac",
86+
".aac": "audio/aac",
87+
".m4a": "audio/mp4",
88+
}.get(suffix, "application/octet-stream")
89+
return "audio/wav" # default for raw binary data
90+
91+
def to_audio_content(self) -> AudioContent:
92+
"""Convert to MCP AudioContent."""
93+
if self.path:
94+
with open(self.path, "rb") as f:
95+
data = base64.b64encode(f.read()).decode()
96+
elif self.data is not None:
97+
data = base64.b64encode(self.data).decode()
98+
else:
99+
raise ValueError("No audio data available")
100+
101+
return AudioContent(type="audio", data=data, mimeType=self._mime_type)

tests/client/test_auth.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,19 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
361361
),
362362
request=token_request,
363363
)
364-
token_request = await auth_flow.asend(token_response)
364+
365+
# After OAuth flow completes, the original request is retried with auth header
366+
final_request = await auth_flow.asend(token_response)
367+
assert final_request.headers["Authorization"] == "Bearer new_access_token"
368+
assert final_request.method == "GET"
369+
assert str(final_request.url) == "https://api.example.com/v1/mcp"
370+
371+
# Send final success response to properly close the generator
372+
final_response = httpx.Response(200, request=final_request)
373+
try:
374+
await auth_flow.asend(final_response)
375+
except StopAsyncIteration:
376+
pass # Expected - generator should complete
365377

366378
@pytest.mark.anyio
367379
async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider):
@@ -694,11 +706,61 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
694706
assert final_request.method == "GET"
695707
assert str(final_request.url) == "https://api.example.com/mcp"
696708

709+
# Send final success response to properly close the generator
710+
final_response = httpx.Response(200, request=final_request)
711+
try:
712+
await auth_flow.asend(final_response)
713+
except StopAsyncIteration:
714+
pass # Expected - generator should complete
715+
697716
# Verify tokens were stored
698717
assert oauth_provider.context.current_tokens is not None
699718
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
700719
assert oauth_provider.context.token_expiry_time is not None
701720

721+
@pytest.mark.anyio
722+
async def test_auth_flow_no_unnecessary_retry_after_oauth(
723+
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
724+
):
725+
"""Test that requests are not retried unnecessarily - the core bug that caused 2x performance degradation."""
726+
# Pre-store valid tokens so no OAuth flow is needed
727+
await mock_storage.set_tokens(valid_tokens)
728+
oauth_provider.context.current_tokens = valid_tokens
729+
oauth_provider.context.token_expiry_time = time.time() + 1800
730+
oauth_provider._initialized = True
731+
732+
test_request = httpx.Request("GET", "https://api.example.com/mcp")
733+
auth_flow = oauth_provider.async_auth_flow(test_request)
734+
735+
# Count how many times the request is yielded
736+
request_yields = 0
737+
738+
# First request - should have auth header already
739+
request = await auth_flow.__anext__()
740+
request_yields += 1
741+
assert request.headers["Authorization"] == "Bearer test_access_token"
742+
743+
# Send a successful 200 response
744+
response = httpx.Response(200, request=request)
745+
746+
# In the buggy version, this would yield the request AGAIN unconditionally
747+
# In the fixed version, this should end the generator
748+
try:
749+
await auth_flow.asend(response) # extra request
750+
request_yields += 1
751+
# If we reach here, the bug is present
752+
pytest.fail(
753+
f"Unnecessary retry detected! Request was yielded {request_yields} times. "
754+
f"This indicates the retry logic bug that caused 2x performance degradation. "
755+
f"The request should only be yielded once for successful responses."
756+
)
757+
except StopAsyncIteration:
758+
# This is the expected behavior - no unnecessary retry
759+
pass
760+
761+
# Verify exactly one request was yielded (no double-sending)
762+
assert request_yields == 1, f"Expected 1 request yield, got {request_yields}"
763+
702764

703765
@pytest.mark.parametrize(
704766
(

tests/server/fastmcp/test_server.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mcp.server.fastmcp import Context, FastMCP
1111
from mcp.server.fastmcp.prompts.base import Message, UserMessage
1212
from mcp.server.fastmcp.resources import FileResource, FunctionResource
13-
from mcp.server.fastmcp.utilities.types import Image
13+
from mcp.server.fastmcp.utilities.types import Audio, Image
1414
from mcp.server.session import ServerSession
1515
from mcp.shared.exceptions import McpError
1616
from mcp.shared.memory import (
@@ -195,6 +195,10 @@ def image_tool_fn(path: str) -> Image:
195195
return Image(path)
196196

197197

198+
def audio_tool_fn(path: str) -> Audio:
199+
return Audio(path)
200+
201+
198202
def mixed_content_tool_fn() -> list[ContentBlock]:
199203
return [
200204
TextContent(type="text", text="Hello"),
@@ -300,6 +304,60 @@ async def test_tool_image_helper(self, tmp_path: Path):
300304
# Check structured content - Image return type should NOT have structured output
301305
assert result.structuredContent is None
302306

307+
@pytest.mark.anyio
308+
async def test_tool_audio_helper(self, tmp_path: Path):
309+
# Create a test audio
310+
audio_path = tmp_path / "test.wav"
311+
audio_path.write_bytes(b"fake wav data")
312+
313+
mcp = FastMCP()
314+
mcp.add_tool(audio_tool_fn)
315+
async with client_session(mcp._mcp_server) as client:
316+
result = await client.call_tool("audio_tool_fn", {"path": str(audio_path)})
317+
assert len(result.content) == 1
318+
content = result.content[0]
319+
assert isinstance(content, AudioContent)
320+
assert content.type == "audio"
321+
assert content.mimeType == "audio/wav"
322+
# Verify base64 encoding
323+
decoded = base64.b64decode(content.data)
324+
assert decoded == b"fake wav data"
325+
# Check structured content - Image return type should NOT have structured output
326+
assert result.structuredContent is None
327+
328+
@pytest.mark.parametrize(
329+
"filename,expected_mime_type",
330+
[
331+
("test.wav", "audio/wav"),
332+
("test.mp3", "audio/mpeg"),
333+
("test.ogg", "audio/ogg"),
334+
("test.flac", "audio/flac"),
335+
("test.aac", "audio/aac"),
336+
("test.m4a", "audio/mp4"),
337+
("test.unknown", "application/octet-stream"), # Unknown extension fallback
338+
],
339+
)
340+
@pytest.mark.anyio
341+
async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, expected_mime_type: str):
342+
"""Test that Audio helper correctly detects MIME types from file suffixes"""
343+
mcp = FastMCP()
344+
mcp.add_tool(audio_tool_fn)
345+
346+
# Create a test audio file with the specific extension
347+
audio_path = tmp_path / filename
348+
audio_path.write_bytes(b"fake audio data")
349+
350+
async with client_session(mcp._mcp_server) as client:
351+
result = await client.call_tool("audio_tool_fn", {"path": str(audio_path)})
352+
assert len(result.content) == 1
353+
content = result.content[0]
354+
assert isinstance(content, AudioContent)
355+
assert content.type == "audio"
356+
assert content.mimeType == expected_mime_type
357+
# Verify base64 encoding
358+
decoded = base64.b64decode(content.data)
359+
assert decoded == b"fake audio data"
360+
303361
@pytest.mark.anyio
304362
async def test_tool_mixed_content(self):
305363
mcp = FastMCP()
@@ -332,19 +390,24 @@ async def test_tool_mixed_content(self):
332390
assert structured_result[i][key] == value
333391

334392
@pytest.mark.anyio
335-
async def test_tool_mixed_list_with_image(self, tmp_path: Path):
393+
async def test_tool_mixed_list_with_audio_and_image(self, tmp_path: Path):
336394
"""Test that lists containing Image objects and other types are handled
337395
correctly"""
338396
# Create a test image
339397
image_path = tmp_path / "test.png"
340398
image_path.write_bytes(b"test image data")
341399

400+
# Create a test audio
401+
audio_path = tmp_path / "test.wav"
402+
audio_path.write_bytes(b"test audio data")
403+
342404
# TODO(Marcelo): It seems if we add the proper type hint, it generates an invalid JSON schema.
343405
# We need to fix this.
344406
def mixed_list_fn() -> list: # type: ignore
345407
return [ # type: ignore
346408
"text message",
347409
Image(image_path),
410+
Audio(audio_path),
348411
{"key": "value"},
349412
TextContent(type="text", text="direct content"),
350413
]
@@ -353,7 +416,7 @@ def mixed_list_fn() -> list: # type: ignore
353416
mcp.add_tool(mixed_list_fn) # type: ignore
354417
async with client_session(mcp._mcp_server) as client:
355418
result = await client.call_tool("mixed_list_fn", {})
356-
assert len(result.content) == 4
419+
assert len(result.content) == 5
357420
# Check text conversion
358421
content1 = result.content[0]
359422
assert isinstance(content1, TextContent)
@@ -363,14 +426,19 @@ def mixed_list_fn() -> list: # type: ignore
363426
assert isinstance(content2, ImageContent)
364427
assert content2.mimeType == "image/png"
365428
assert base64.b64decode(content2.data) == b"test image data"
366-
# Check dict conversion
429+
# Check audio conversion
367430
content3 = result.content[2]
368-
assert isinstance(content3, TextContent)
369-
assert '"key": "value"' in content3.text
370-
# Check direct TextContent
431+
assert isinstance(content3, AudioContent)
432+
assert content3.mimeType == "audio/wav"
433+
assert base64.b64decode(content3.data) == b"test audio data"
434+
# Check dict conversion
371435
content4 = result.content[3]
372436
assert isinstance(content4, TextContent)
373-
assert content4.text == "direct content"
437+
assert '"key": "value"' in content4.text
438+
# Check direct TextContent
439+
content5 = result.content[4]
440+
assert isinstance(content5, TextContent)
441+
assert content5.text == "direct content"
374442
# Check structured content - untyped list with Image objects should NOT have structured output
375443
assert result.structuredContent is None
376444

0 commit comments

Comments
 (0)