diff --git a/fastapi_mcp/openapi/convert.py b/fastapi_mcp/openapi/convert.py index 22e5c5e..9462372 100644 --- a/fastapi_mcp/openapi/convert.py +++ b/fastapi_mcp/openapi/convert.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import mcp.types as types @@ -9,6 +9,9 @@ generate_example_from_schema, resolve_schema_references, get_single_param_type_from_schema, + detect_form_encoded_content_type, + detect_multipart_content_type, + extract_form_field_names, ) logger = logging.getLogger(__name__) @@ -52,12 +55,40 @@ def convert_openapi_to_mcp_tools( logger.warning(f"Skipping operation with no operationId: {operation}") continue + # Detect content type and form fields from request body + request_body = operation.get("requestBody", {}) + try: + content_type, form_fields = _detect_content_type_and_form_fields( + request_body + ) + if content_type: + logger.info( + "Content type detection successful for operation %s: %s with %d form fields", + operation_id, + content_type, + len(form_fields), + ) + else: + logger.debug( + "No specific content type detected for operation %s, will use default JSON behavior", + operation_id, + ) + except Exception as e: + logger.error( + "Content type detection failed for operation %s: %s. Using default JSON behavior", + operation_id, + str(e), + ) + content_type, form_fields = None, [] + # Save operation details for later HTTP calls operation_map[operation_id] = { "path": path, "method": method, "parameters": operation.get("parameters", []), - "request_body": operation.get("requestBody", {}), + "request_body": request_body, + "content_type": content_type, + "form_fields": form_fields, } summary = operation.get("summary", "") @@ -85,7 +116,9 @@ def convert_openapi_to_mcp_tools( responses_to_include = responses if not describe_all_responses and success_response: # If we're not describing all responses, only include the success response - success_code = next((code for code in success_codes if str(code) in responses), None) + success_code = next( + (code for code in success_codes if str(code) in responses), None + ) if success_code: responses_to_include = {str(success_code): success_response} @@ -100,7 +133,9 @@ def convert_openapi_to_mcp_tools( # Add schema information if available if "content" in response_data: - for content_type, content_data in response_data["content"].items(): + for content_type, content_data in response_data[ + "content" + ].items(): if "schema" in content_data: schema = content_data["schema"] response_info += f"\nContent-Type: {content_type}" @@ -113,7 +148,9 @@ def convert_openapi_to_mcp_tools( # Check if content has examples if "examples" in content_data: - for example_key, example_data in content_data["examples"].items(): + for example_data in content_data[ + "examples" + ].values(): if "value" in example_data: example_response = example_data["value"] break @@ -123,33 +160,56 @@ def convert_openapi_to_mcp_tools( # If we have an example response, add it to the docs if example_response: - response_info += "\n\n**Example Response:**\n```json\n" - response_info += json.dumps(example_response, indent=2) + response_info += ( + "\n\n**Example Response:**\n```json\n" + ) + response_info += json.dumps( + example_response, indent=2 + ) response_info += "\n```" # Otherwise generate an example from the schema else: - generated_example = generate_example_from_schema(display_schema) + generated_example = generate_example_from_schema( + display_schema + ) if generated_example: - response_info += "\n\n**Example Response:**\n```json\n" - response_info += json.dumps(generated_example, indent=2) + response_info += ( + "\n\n**Example Response:**\n```json\n" + ) + response_info += json.dumps( + generated_example, indent=2 + ) response_info += "\n```" # Only include full schema information if requested if describe_full_response_schema: # Format schema information based on its type - if display_schema.get("type") == "array" and "items" in display_schema: + if ( + display_schema.get("type") == "array" + and "items" in display_schema + ): items_schema = display_schema["items"] response_info += "\n\n**Output Schema:** Array of items with the following structure:\n```json\n" - response_info += json.dumps(items_schema, indent=2) + response_info += json.dumps( + items_schema, indent=2 + ) response_info += "\n```" elif "properties" in display_schema: - response_info += "\n\n**Output Schema:**\n```json\n" - response_info += json.dumps(display_schema, indent=2) + response_info += ( + "\n\n**Output Schema:**\n```json\n" + ) + response_info += json.dumps( + display_schema, indent=2 + ) response_info += "\n```" else: - response_info += "\n\n**Output Schema:**\n```json\n" - response_info += json.dumps(display_schema, indent=2) + response_info += ( + "\n\n**Output Schema:**\n```json\n" + ) + response_info += json.dumps( + display_schema, indent=2 + ) response_info += "\n```" tool_description += response_info @@ -200,7 +260,9 @@ def convert_openapi_to_mcp_tools( for param_name, param in path_params: param_schema = param.get("schema", {}) param_desc = param.get("description", "") - param_required = param.get("required", True) # Path params are usually required + param_required = param.get( + "required", True + ) # Path params are usually required properties[param_name] = param_schema.copy() properties[param_name]["title"] = param_name @@ -225,7 +287,9 @@ def convert_openapi_to_mcp_tools( properties[param_name]["description"] = param_desc if "type" not in properties[param_name]: - properties[param_name]["type"] = get_single_param_type_from_schema(param_schema) + properties[param_name]["type"] = get_single_param_type_from_schema( + param_schema + ) if "default" in param_schema: properties[param_name]["default"] = param_schema["default"] @@ -245,7 +309,9 @@ def convert_openapi_to_mcp_tools( properties[param_name]["description"] = param_desc if "type" not in properties[param_name]: - properties[param_name]["type"] = get_single_param_type_from_schema(param_schema) + properties[param_name]["type"] = get_single_param_type_from_schema( + param_schema + ) if "default" in param_schema: properties[param_name]["default"] = param_schema["default"] @@ -254,14 +320,114 @@ def convert_openapi_to_mcp_tools( required_props.append(param_name) # Create a proper input schema for the tool - input_schema = {"type": "object", "properties": properties, "title": f"{operation_id}Arguments"} + input_schema = { + "type": "object", + "properties": properties, + "title": f"{operation_id}Arguments", + } if required_props: input_schema["required"] = required_props # Create the MCP tool definition - tool = types.Tool(name=operation_id, description=tool_description, inputSchema=input_schema) + tool = types.Tool( + name=operation_id, + description=tool_description, + inputSchema=input_schema, + ) tools.append(tool) return tools, operation_map + + +def _detect_content_type_and_form_fields( + request_body: Optional[Dict[str, Any]] +) -> Tuple[Optional[str], List[str]]: + """ + Detect the content type and form fields from a request body schema. + + Args: + request_body: The requestBody section from OpenAPI operation + + Returns: + A tuple of (content_type, form_fields) where: + - content_type is the detected content type or None + - form_fields is a list of form field names or empty list + """ + if not request_body or "content" not in request_body: + logger.debug("No request body or content found, using default JSON behavior") + return None, [] + + content = request_body["content"] + available_content_types = list(content.keys()) + logger.debug("Available content types for analysis: %s", available_content_types) + + # Priority order: form-encoded > multipart > JSON + detected_content_type = None + + # Check for form-encoded first (highest priority) + if detect_form_encoded_content_type(request_body): + detected_content_type = "application/x-www-form-urlencoded" + logger.debug("Detected form-encoded content type (priority 1)") + # Check for multipart second + elif detect_multipart_content_type(request_body): + detected_content_type = "multipart/form-data" + logger.debug("Detected multipart content type (priority 2)") + # Check for JSON as fallback + elif "application/json" in content: + detected_content_type = "application/json" + logger.debug("Detected JSON content type (fallback)") + + # If no supported content type found, log and return None + if not detected_content_type: + logger.warning( + "No supported content type found in %s, falling back to default JSON behavior", + available_content_types, + ) + return None, [] + + # Extract form fields for form-based content types + form_fields = [] + if detected_content_type in [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ]: + try: + content_schema = content[detected_content_type].get("schema", {}) + if not content_schema: + logger.warning( + "No schema found for %s content type, cannot extract form fields", + detected_content_type, + ) + return None, [] + + form_fields = extract_form_field_names(content_schema) + logger.debug( + "Successfully extracted %d form fields for %s: %s", + len(form_fields), + detected_content_type, + form_fields, + ) + + if not form_fields: + logger.warning( + "No form fields found in schema for %s, falling back to JSON behavior", + detected_content_type, + ) + return None, [] + + except Exception as e: + logger.error( + "Failed to extract form fields from schema for %s: %s. Falling back to JSON behavior", + detected_content_type, + str(e), + ) + return None, [] + + logger.info( + "Content type detection complete: %s with %d form fields", + detected_content_type, + len(form_fields), + ) + return detected_content_type, form_fields diff --git a/fastapi_mcp/openapi/utils.py b/fastapi_mcp/openapi/utils.py index 1821d57..eb55f14 100644 --- a/fastapi_mcp/openapi/utils.py +++ b/fastapi_mcp/openapi/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, List def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: @@ -162,3 +162,56 @@ def generate_example_from_schema(schema: Dict[str, Any]) -> Any: # Default case return None + + +def detect_form_encoded_content_type(request_body: Dict[str, Any]) -> bool: + """ + Detect if a request body uses application/x-www-form-urlencoded content type. + + Args: + request_body: The requestBody section from OpenAPI operation + + Returns: + True if form-encoded content type is detected, False otherwise + """ + if not request_body or "content" not in request_body: + return False + + content = request_body["content"] + return "application/x-www-form-urlencoded" in content + + +def detect_multipart_content_type(request_body: Dict[str, Any]) -> bool: + """ + Detect if a request body uses multipart/form-data content type. + + Args: + request_body: The requestBody section from OpenAPI operation + + Returns: + True if multipart content type is detected, False otherwise + """ + if not request_body or "content" not in request_body: + return False + + content = request_body["content"] + return "multipart/form-data" in content + + +def extract_form_field_names(schema: Dict[str, Any]) -> List[str]: + """ + Extract form field names from schema properties. + + Args: + schema: The schema object containing properties + + Returns: + List of form field names, or empty list if no properties found + """ + if not schema or not isinstance(schema, dict): + return [] + + if "properties" not in schema: + return [] + + return list(schema["properties"].keys()) diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index bb75106..c5f24fc 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -537,11 +537,21 @@ async def _execute_api_tool( if name.lower() in self._forward_headers: headers[name] = value - body = arguments if arguments else None + # Separate form fields from other parameters based on operation metadata + content_type = operation.get("content_type") + form_fields = operation.get("form_fields", []) + + body: Optional[Any] + if content_type in ["application/x-www-form-urlencoded", "multipart/form-data"] and form_fields: + # Only include form fields in the body for form-encoded requests + body = {k: v for k, v in arguments.items() if k in form_fields} + else: + # For JSON or other content types, include all remaining arguments + body = arguments or None try: logger.debug(f"Making {method.upper()} request to {path}") - response = await self._request(client, method, path, query, headers, body) + response = await self._request(client, method, path, query, headers, body, content_type) # TODO: Better typing for the AsyncClientProtocol. It should return a ResponseProtocol that has a json() method that returns a dict/list/etc. try: @@ -577,20 +587,40 @@ async def _request( query: Dict[str, Any], headers: Dict[str, str], body: Optional[Any], + content_type: Optional[str] = None, ) -> Any: if method.lower() == "get": return await client.get(path, params=query, headers=headers) elif method.lower() == "post": - return await client.post(path, params=query, headers=headers, json=body) + return await self._request_with_body(client, "post", path, query, headers, body, content_type) elif method.lower() == "put": - return await client.put(path, params=query, headers=headers, json=body) + return await self._request_with_body(client, "put", path, query, headers, body, content_type) elif method.lower() == "delete": return await client.delete(path, params=query, headers=headers) elif method.lower() == "patch": - return await client.patch(path, params=query, headers=headers, json=body) + return await self._request_with_body(client, "patch", path, query, headers, body, content_type) else: raise ValueError(f"Unsupported HTTP method: {method}") + async def _request_with_body( + self, + client: httpx.AsyncClient, + method: str, + path: str, + query: Dict[str, Any], + headers: Dict[str, str], + body: Optional[Any], + content_type: Optional[str] = None, + ) -> Any: + """Handle requests with body content, using appropriate encoding based on content type.""" + if content_type == "application/x-www-form-urlencoded": + return await client.request(method, path, params=query, headers=headers, data=body) + elif content_type == "multipart/form-data": + return await client.request(method, path, params=query, headers=headers, files=body) + else: + # Default to JSON for backward compatibility + return await client.request(method, path, params=query, headers=headers, json=body) + def _filter_tools(self, tools: List[types.Tool], openapi_schema: Dict[str, Any]) -> List[types.Tool]: """ Filter tools based on operation IDs and tags. diff --git a/tests/test_backward_compatibility.py b/tests/test_backward_compatibility.py new file mode 100644 index 0000000..2895627 --- /dev/null +++ b/tests/test_backward_compatibility.py @@ -0,0 +1,486 @@ +""" +Backward compatibility tests for form parameter support. + +These tests ensure that existing JSON endpoints continue to work without changes +and that mixed parameter scenarios work correctly. +""" + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from fastapi import FastAPI + +from fastapi_mcp import FastApiMCP +from mcp.types import TextContent + + +@pytest.mark.asyncio +async def test_json_endpoints_unchanged(simple_fastapi_app: FastAPI): + """Test that existing JSON endpoints continue to work without changes.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response for JSON endpoint + mock_response = MagicMock() + mock_response.json.return_value = {"id": 1, "name": "Test Item", "price": 10.0} + mock_response.status_code = 200 + mock_response.text = '{"id": 1, "name": "Test Item", "price": 10.0}' + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.request.return_value = mock_response + + # Test JSON POST request (create_item endpoint) + tool_name = "create_item" + arguments = { + "item": {"id": 1, "name": "Test Item", "price": 10.0, "tags": ["tag1"], "description": "Test description"} + } + + # Execute the tool + with patch.object(mcp, "_http_client", mock_client): + result = await mcp._execute_api_tool( + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map + ) + + # Verify the result + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the HTTP client was called with JSON (backward compatibility) + mock_client.request.assert_called_once_with("post", "/items/", params={}, headers={}, json=arguments) + + +@pytest.mark.asyncio +async def test_get_endpoints_unchanged(simple_fastapi_app: FastAPI): + """Test that GET endpoints continue to work unchanged.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response + mock_response = MagicMock() + mock_response.json.return_value = {"id": 1, "name": "Test Item"} + mock_response.status_code = 200 + mock_response.text = '{"id": 1, "name": "Test Item"}' + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + # Test GET request + tool_name = "get_item" + arguments = {"item_id": 1} + + # Execute the tool + with patch.object(mcp, "_http_client", mock_client): + result = await mcp._execute_api_tool( + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map + ) + + # Verify the result + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the HTTP client was called correctly (unchanged) + mock_client.get.assert_called_once_with("/items/1", params={}, headers={}) + + +@pytest.mark.asyncio +async def test_mixed_parameters_query_path_json(simple_fastapi_app: FastAPI): + """Test mixed parameter scenarios with query, path, and JSON body parameters.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response + mock_response = MagicMock() + mock_response.json.return_value = {"id": 1, "name": "Updated Item"} + mock_response.status_code = 200 + mock_response.text = '{"id": 1, "name": "Updated Item"}' + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.request.return_value = mock_response + + # Test PUT request with path parameter and JSON body + tool_name = "update_item" + arguments = { + "item_id": 1, # Path parameter + "item": { # JSON body + "id": 1, + "name": "Updated Item", + "price": 15.0, + "tags": ["updated"], + "description": "Updated description", + }, + } + + # Execute the tool + with patch.object(mcp, "_http_client", mock_client): + result = await mcp._execute_api_tool( + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map + ) + + # Verify the result + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the HTTP client was called correctly + # Path parameter should be in URL, body should be JSON + expected_body = {"item": arguments["item"]} + mock_client.request.assert_called_once_with("put", "/items/1", params={}, headers={}, json=expected_body) + + +@pytest.mark.asyncio +async def test_mixed_parameters_query_path_only(simple_fastapi_app: FastAPI): + """Test mixed parameter scenarios with only query and path parameters.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response + mock_response = MagicMock() + mock_response.json.return_value = [{"id": 1, "name": "Item 1"}] + mock_response.status_code = 200 + mock_response.text = '[{"id": 1, "name": "Item 1"}]' + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + # Test GET request with query parameters + tool_name = "list_items" + arguments = { + "skip": 0, # Query parameter + "limit": 10, # Query parameter + "sort_by": "name", # Query parameter + } + + # Execute the tool + with patch.object(mcp, "_http_client", mock_client): + result = await mcp._execute_api_tool( + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map + ) + + # Verify the result + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the HTTP client was called correctly + # All parameters should be in query params + mock_client.get.assert_called_once_with("/items/", params={"skip": 0, "limit": 10, "sort_by": "name"}, headers={}) + + +@pytest.mark.asyncio +async def test_error_handling_unchanged(simple_fastapi_app: FastAPI): + """Test that error handling improvements don't break existing error responses.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response with error + mock_response = MagicMock() + mock_response.json.return_value = {"detail": "Item not found"} + mock_response.status_code = 404 + mock_response.text = '{"detail": "Item not found"}' + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + # Test error response + tool_name = "get_item" + arguments = {"item_id": 999} # Non-existent item + + # Execute the tool and expect an exception + with patch.object(mcp, "_http_client", mock_client): + with pytest.raises(Exception) as exc_info: + await mcp._execute_api_tool( + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map + ) + + # Verify the error message format is unchanged + error_message = str(exc_info.value) + assert "Error calling get_item" in error_message + assert "Status code: 404" in error_message + assert "Item not found" in error_message + + +@pytest.mark.asyncio +async def test_operation_map_structure_unchanged(simple_fastapi_app: FastAPI): + """Test that operation map structure includes new fields but maintains backward compatibility.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Check that operation map has expected structure + assert "create_item" in mcp.operation_map + operation = mcp.operation_map["create_item"] + + # Verify existing fields are still present + assert "path" in operation + assert "method" in operation + assert "parameters" in operation + assert "request_body" in operation + + # Verify new fields are added (but may be None for JSON endpoints) + assert "content_type" in operation + assert "form_fields" in operation + + # For JSON endpoints, these should be None or empty + # (since simple_fastapi_app uses JSON, not form parameters) + assert operation["content_type"] is None or operation["content_type"] == "application/json" + assert operation["form_fields"] == [] + + +@pytest.mark.asyncio +async def test_no_content_type_fallback_to_json(simple_fastapi_app: FastAPI): + """Test that endpoints with no specific content type fall back to JSON behavior.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response + mock_response = MagicMock() + mock_response.json.return_value = {"id": 1, "name": "Test Item"} + mock_response.status_code = 200 + mock_response.text = '{"id": 1, "name": "Test Item"}' + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.request.return_value = mock_response + + # Test POST request without specific content type + tool_name = "create_item" + arguments = {"item": {"id": 1, "name": "Test Item", "price": 10.0}} + + # Execute the tool + with patch.object(mcp, "_http_client", mock_client): + result = await mcp._execute_api_tool( + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map + ) + + # Verify the result + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the HTTP client was called with JSON (fallback behavior) + mock_client.request.assert_called_once_with("post", "/items/", params={}, headers={}, json=arguments) + + +@pytest.mark.asyncio +async def test_delete_endpoints_unchanged(simple_fastapi_app: FastAPI): + """Test that DELETE endpoints continue to work unchanged.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response + mock_response = MagicMock() + mock_response.json.return_value = None + mock_response.status_code = 204 + mock_response.text = "" + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.delete.return_value = mock_response + + # Test DELETE request + tool_name = "delete_item" + arguments = {"item_id": 1} + + # Execute the tool + with patch.object(mcp, "_http_client", mock_client): + result = await mcp._execute_api_tool( + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map + ) + + # Verify the result + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the HTTP client was called correctly (unchanged) + mock_client.delete.assert_called_once_with("/items/1", params={}, headers={}) + + +@pytest.mark.asyncio +async def test_response_formatting_unchanged(simple_fastapi_app: FastAPI): + """Test that response formatting remains unchanged for JSON responses.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response with complex JSON + complex_response = { + "id": 1, + "name": "Test Item", + "price": 10.0, + "tags": ["tag1", "tag2"], + "metadata": {"created_at": "2023-01-01T00:00:00Z", "updated_at": "2023-01-02T00:00:00Z"}, + } + mock_response = MagicMock() + mock_response.json.return_value = complex_response + mock_response.status_code = 200 + mock_response.text = '{"id": 1, "name": "Test Item", "price": 10.0, "tags": ["tag1", "tag2"], "metadata": {"created_at": "2023-01-01T00:00:00Z", "updated_at": "2023-01-02T00:00:00Z"}}' + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + # Test GET request + tool_name = "get_item" + arguments = {"item_id": 1} + + # Execute the tool + with patch.object(mcp, "_http_client", mock_client): + result = await mcp._execute_api_tool( + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map + ) + + # Verify the result + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the response is properly formatted JSON (unchanged behavior) + response_text = result[0].text + assert "id" in response_text + assert "name" in response_text + assert "metadata" in response_text + # Should be formatted with indentation + assert " " in response_text # Indentation indicates JSON formatting + + +@pytest.mark.asyncio +async def test_header_forwarding_unchanged(simple_fastapi_app: FastAPI): + """Test that header forwarding behavior remains unchanged.""" + mcp = FastApiMCP(simple_fastapi_app) + + # Mock the HTTP client response + mock_response = MagicMock() + mock_response.json.return_value = {"id": 1, "name": "Test Item"} + mock_response.status_code = 200 + mock_response.text = '{"id": 1, "name": "Test Item"}' + + # Mock the HTTP client + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + # Create HTTP request info with authorization header + from fastapi_mcp.types import HTTPRequestInfo + + http_request_info = HTTPRequestInfo( + method="GET", + path="/test", + headers={"authorization": "Bearer test-token", "x-custom": "custom-value"}, + cookies={}, + query_params={}, + body=None, + ) + + # Test GET request with header forwarding + tool_name = "get_item" + arguments = {"item_id": 1} + + # Execute the tool + with patch.object(mcp, "_http_client", mock_client): + result = await mcp._execute_api_tool( + client=mock_client, + tool_name=tool_name, + arguments=arguments, + operation_map=mcp.operation_map, + http_request_info=http_request_info, + ) + + # Verify the result + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the HTTP client was called with forwarded authorization header + # (unchanged behavior - only authorization header should be forwarded by default) + mock_client.get.assert_called_once_with("/items/1", params={}, headers={"authorization": "Bearer test-token"}) + + +@pytest.mark.asyncio +async def test_mixed_form_and_json_endpoints(): + """Test that form parameter endpoints can coexist with JSON endpoints.""" + from fastapi import FastAPI, Form + from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools + from fastapi.openapi.utils import get_openapi + + # Create a test app with both JSON and form endpoints + app = FastAPI(title="Mixed Test App", description="Test app with both JSON and form endpoints") + + @app.post("/json-endpoint", operation_id="json_endpoint") + async def json_endpoint(data: dict): + return {"received": data} + + @app.post("/form-endpoint", operation_id="form_endpoint") + async def form_endpoint(name: str = Form(...), age: int = Form(...)): + return {"name": name, "age": age} + + # Get OpenAPI schema and convert to MCP tools + openapi_schema = get_openapi( + title=app.title, + version=app.version, + openapi_version=app.openapi_version, + description=app.description, + routes=app.routes, + ) + + tools, operation_map = convert_openapi_to_mcp_tools(openapi_schema) + + # Verify both endpoints are present + tool_names = [tool.name for tool in tools] + assert "json_endpoint" in tool_names + assert "form_endpoint" in tool_names + + # Verify operation map has correct content types + json_op = operation_map["json_endpoint"] + form_op = operation_map["form_endpoint"] + + # JSON endpoint should have JSON content type or None (fallback) + assert json_op["content_type"] in [None, "application/json"] + assert json_op["form_fields"] == [] + + # Form endpoint should have form content type and form fields + assert form_op["content_type"] == "application/x-www-form-urlencoded" + assert set(form_op["form_fields"]) == {"name", "age"} + + +@pytest.mark.asyncio +async def test_content_type_priority_with_multiple_types(): + """Test that content type priority works correctly when multiple types are available.""" + from fastapi_mcp.openapi.convert import _detect_content_type_and_form_fields + + # Mock request body with multiple content types + request_body = { + "content": { + "application/json": {"schema": {"type": "object", "properties": {"data": {"type": "string"}}}}, + "application/x-www-form-urlencoded": { + "schema": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}} + }, + } + } + + # Should prioritize form-encoded over JSON + content_type, form_fields = _detect_content_type_and_form_fields(request_body) + assert content_type == "application/x-www-form-urlencoded" + assert set(form_fields) == {"name", "age"} + + +@pytest.mark.asyncio +async def test_fallback_behavior_on_detection_failure(): + """Test that the system falls back to JSON behavior when content type detection fails.""" + from fastapi_mcp.openapi.convert import _detect_content_type_and_form_fields + + # Mock request body with unsupported content type + request_body = {"content": {"application/xml": {"schema": {"type": "string"}}}} + + # Should fall back to None (JSON behavior) + content_type, form_fields = _detect_content_type_and_form_fields(request_body) + assert content_type is None + assert form_fields == [] + + +@pytest.mark.asyncio +async def test_empty_request_body_handling(): + """Test that empty or missing request bodies are handled correctly.""" + from fastapi_mcp.openapi.convert import _detect_content_type_and_form_fields + + # Test with empty request body + content_type, form_fields = _detect_content_type_and_form_fields({}) + assert content_type is None + assert form_fields == [] + + # Test with None request body + content_type, form_fields = _detect_content_type_and_form_fields(None) + assert content_type is None + assert form_fields == [] + + # Test with request body without content + request_body = {"description": "Test endpoint"} + content_type, form_fields = _detect_content_type_and_form_fields(request_body) + assert content_type is None + assert form_fields == [] diff --git a/tests/test_mcp_execute_api_tool.py b/tests/test_mcp_execute_api_tool.py index cc05d34..85c792d 100644 --- a/tests/test_mcp_execute_api_tool.py +++ b/tests/test_mcp_execute_api_tool.py @@ -10,183 +10,151 @@ async def test_execute_api_tool_success(simple_fastapi_app: FastAPI): """Test successful execution of an API tool.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = {"id": 1, "name": "Test Item"} mock_response.status_code = 200 mock_response.text = '{"id": 1, "name": "Test Item"}' - + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "get_item" arguments = {"item_id": 1} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) assert result[0].text == '{\n "id": 1,\n "name": "Test Item"\n}' - + # Verify the HTTP client was called correctly - mock_client.get.assert_called_once_with( - "/items/1", - params={}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/1", params={}, headers={}) @pytest.mark.asyncio async def test_execute_api_tool_with_query_params(simple_fastapi_app: FastAPI): """Test execution of an API tool with query parameters.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}] mock_response.status_code = 200 mock_response.text = '[{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]' - + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "list_items" arguments = {"skip": 0, "limit": 2} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify the HTTP client was called with query parameters - mock_client.get.assert_called_once_with( - "/items/", - params={"skip": 0, "limit": 2}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/", params={"skip": 0, "limit": 2}, headers={}) @pytest.mark.asyncio async def test_execute_api_tool_with_body(simple_fastapi_app: FastAPI): """Test execution of an API tool with request body.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = {"id": 1, "name": "New Item"} mock_response.status_code = 200 mock_response.text = '{"id": 1, "name": "New Item"}' - + # Mock the HTTP client mock_client = AsyncMock() - mock_client.post.return_value = mock_response - + # Mock the request method instead of post since _request_with_body uses client.request() + mock_client.request.return_value = mock_response + # Test parameters tool_name = "create_item" arguments = { - "item": { - "id": 1, - "name": "New Item", - "price": 10.0, - "tags": ["tag1"], - "description": "New item description" - } + "item": {"id": 1, "name": "New Item", "price": 10.0, "tags": ["tag1"], "description": "New item description"} } - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify the HTTP client was called with the request body - mock_client.post.assert_called_once_with( - "/items/", - params={}, - headers={}, - json=arguments - ) + mock_client.request.assert_called_once_with("post", "/items/", params={}, headers={}, json=arguments) @pytest.mark.asyncio async def test_execute_api_tool_with_non_ascii_chars(simple_fastapi_app: FastAPI): """Test execution of an API tool with non-ASCII characters.""" mcp = FastApiMCP(simple_fastapi_app) - + # Test data with both ASCII and non-ASCII characters test_data = { "id": 1, "name": "你好 World", # Chinese characters + ASCII "price": 10.0, "tags": ["tag1", "标签2"], # Chinese characters in tags - "description": "这是一个测试描述" # All Chinese characters + "description": "这是一个测试描述", # All Chinese characters } - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = test_data mock_response.status_code = 200 - mock_response.text = '{"id": 1, "name": "你好 World", "price": 10.0, "tags": ["tag1", "标签2"], "description": "这是一个测试描述"}' - + mock_response.text = ( + '{"id": 1, "name": "你好 World", "price": 10.0, "tags": ["tag1", "标签2"], "description": "这是一个测试描述"}' + ) + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "get_item" arguments = {"item_id": 1} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify that the response contains both ASCII and non-ASCII characters response_text = result[0].text assert "你好" in response_text # Chinese characters preserved assert "World" in response_text # ASCII characters preserved assert "标签2" in response_text # Chinese characters in tags preserved assert "这是一个测试描述" in response_text # All Chinese description preserved - + # Verify the HTTP client was called correctly - mock_client.get.assert_called_once_with( - "/items/1", - params={}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/1", params={}, headers={}) diff --git a/tests/test_openapi_conversion.py b/tests/test_openapi_conversion.py index aefe643..5df9ee9 100644 --- a/tests/test_openapi_conversion.py +++ b/tests/test_openapi_conversion.py @@ -377,6 +377,129 @@ def test_body_params_descriptions_and_defaults(complex_fastapi_app: FastAPI): assert item_props["quantity"]["default"] == 1 +def test_content_type_detection(): + """Test the content type detection functionality for form parameters.""" + from fastapi_mcp.openapi.convert import _detect_content_type_and_form_fields + + # Test form-encoded content type detection + form_encoded_request_body = { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "type": "object", + "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, + } + } + } + } + + content_type, form_fields = _detect_content_type_and_form_fields(form_encoded_request_body) + assert content_type == "application/x-www-form-urlencoded" + assert set(form_fields) == {"username", "password"} + + # Test multipart content type detection + multipart_request_body = { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": {"file": {"type": "string", "format": "binary"}, "description": {"type": "string"}}, + } + } + } + } + + content_type, form_fields = _detect_content_type_and_form_fields(multipart_request_body) + assert content_type == "multipart/form-data" + assert set(form_fields) == {"file", "description"} + + # Test JSON content type detection (fallback) + json_request_body = { + "content": { + "application/json": { + "schema": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}} + } + } + } + + content_type, form_fields = _detect_content_type_and_form_fields(json_request_body) + assert content_type == "application/json" + assert form_fields == [] # No form fields for JSON + + # Test priority logic - form-encoded should win over JSON + mixed_request_body = { + "content": { + "application/json": {"schema": {"type": "object", "properties": {"data": {"type": "string"}}}}, + "application/x-www-form-urlencoded": { + "schema": {"type": "object", "properties": {"form_field": {"type": "string"}}} + }, + } + } + + content_type, form_fields = _detect_content_type_and_form_fields(mixed_request_body) + assert content_type == "application/x-www-form-urlencoded" + assert form_fields == ["form_field"] + + # Test multipart should win over JSON but lose to form-encoded + mixed_request_body_2 = { + "content": { + "application/json": {"schema": {"type": "object", "properties": {"data": {"type": "string"}}}}, + "multipart/form-data": { + "schema": {"type": "object", "properties": {"upload": {"type": "string", "format": "binary"}}} + }, + } + } + + content_type, form_fields = _detect_content_type_and_form_fields(mixed_request_body_2) + assert content_type == "multipart/form-data" + assert form_fields == ["upload"] + + # Test empty request body + content_type, form_fields = _detect_content_type_and_form_fields({}) + assert content_type is None + assert form_fields == [] + + # Test request body without content + content_type, form_fields = _detect_content_type_and_form_fields({"required": True}) + assert content_type is None + assert form_fields == [] + + # Test unsupported content type + unsupported_request_body = {"content": {"application/xml": {"schema": {"type": "string"}}}} + + content_type, form_fields = _detect_content_type_and_form_fields(unsupported_request_body) + assert content_type is None + assert form_fields == [] + + +def test_operation_map_includes_content_type_info(complex_fastapi_app: FastAPI): + """Test that the operation map includes content type and form fields information.""" + openapi_schema = get_openapi( + title=complex_fastapi_app.title, + version=complex_fastapi_app.version, + openapi_version=complex_fastapi_app.openapi_version, + description=complex_fastapi_app.description, + routes=complex_fastapi_app.routes, + ) + + tools, operation_map = convert_openapi_to_mcp_tools(openapi_schema) + + # Check that all operations have the new fields + for operation_id, operation_info in operation_map.items(): + assert "content_type" in operation_info + assert "form_fields" in operation_info + + # For the complex app, create_order should have JSON content type + if operation_id == "create_order": + assert operation_info["content_type"] == "application/json" + assert operation_info["form_fields"] == [] + + # GET operations should have no content type + if operation_info["method"] == "get": + assert operation_info["content_type"] is None + assert operation_info["form_fields"] == [] + + def test_body_params_edge_cases(complex_fastapi_app: FastAPI): """ Test handling of edge cases for body parameters, such as: