diff --git a/agent/anthropic_adapter.py b/agent/anthropic_adapter.py index af25b62b0c..84869bbda0 100644 --- a/agent/anthropic_adapter.py +++ b/agent/anthropic_adapter.py @@ -1117,6 +1117,49 @@ def _sanitize_tool_id(tool_id: str) -> str: return sanitized or "tool_0" +def _normalize_tool_input_schema(schema: Any) -> Dict[str, Any]: + """Normalize tool schemas before sending them to Anthropic. + + Anthropic's tool schema validator rejects nullable unions such as + ``anyOf: [{"type": "string"}, {"type": "null"}]`` that Pydantic/MCP + commonly emits for optional fields. Tool optionality is represented by + the parent ``required`` array, so collapse nullable unions to the non-null + branch while preserving metadata like description/default. + """ + if not schema: + return {"type": "object", "properties": {}} + + def _strip_nullable_union(node: Any) -> Any: + if isinstance(node, list): + return [_strip_nullable_union(item) for item in node] + if not isinstance(node, dict): + return node + + stripped = {k: _strip_nullable_union(v) for k, v in node.items()} + for key in ("anyOf", "oneOf"): + variants = stripped.get(key) + if not isinstance(variants, list): + continue + non_null = [ + item for item in variants + if not (isinstance(item, dict) and item.get("type") == "null") + ] + if len(non_null) == 1 and len(non_null) != len(variants): + replacement = dict(non_null[0]) if isinstance(non_null[0], dict) else {} + for meta_key in ("title", "description", "default", "examples"): + if meta_key in stripped and meta_key not in replacement: + replacement[meta_key] = stripped[meta_key] + return _strip_nullable_union(replacement) + return stripped + + normalized = _strip_nullable_union(schema) + if not isinstance(normalized, dict): + return {"type": "object", "properties": {}} + if normalized.get("type") == "object" and not isinstance(normalized.get("properties"), dict): + normalized = {**normalized, "properties": {}} + return normalized + + def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]: """Convert OpenAI tool definitions to Anthropic format.""" if not tools: @@ -1127,7 +1170,9 @@ def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]: result.append({ "name": fn.get("name", ""), "description": fn.get("description", ""), - "input_schema": fn.get("parameters", {"type": "object", "properties": {}}), + "input_schema": _normalize_tool_input_schema( + fn.get("parameters", {"type": "object", "properties": {}}) + ), }) return result diff --git a/tests/agent/test_anthropic_adapter.py b/tests/agent/test_anthropic_adapter.py index 32d24666b3..b78ae48590 100644 --- a/tests/agent/test_anthropic_adapter.py +++ b/tests/agent/test_anthropic_adapter.py @@ -544,6 +544,36 @@ class TestConvertTools: assert convert_tools_to_anthropic([]) == [] assert convert_tools_to_anthropic(None) == [] + def test_strips_nullable_union_from_input_schema(self): + tools = [ + { + "type": "function", + "function": { + "name": "run", + "description": "Run command", + "parameters": { + "type": "object", + "properties": { + "command": {"type": "string"}, + "timeout": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, + }, + }, + "required": ["command"], + }, + }, + } + ] + + result = convert_tools_to_anthropic(tools) + + assert result[0]["input_schema"]["properties"]["timeout"] == { + "type": "integer", + "default": None, + } + assert result[0]["input_schema"]["required"] == ["command"] + # --------------------------------------------------------------------------- # Message conversion diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 1604d4adb5..6e19a90121 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -266,6 +266,56 @@ class TestSchemaConversion: assert schema["properties"]["items"]["items"]["properties"] == {} + def test_optional_nullable_field_is_collapsed_to_non_null_schema(self): + """Anthropic rejects MCP/Pydantic anyOf-null optional parameter schemas.""" + from tools.mcp_tool import _normalize_mcp_input_schema + + schema = _normalize_mcp_input_schema({ + "type": "object", + "properties": { + "command": {"type": "string"}, + "workdir": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "Optional working directory", + }, + }, + "required": ["command"], + }) + + assert schema["properties"]["workdir"] == { + "type": "string", + "default": None, + "description": "Optional working directory", + } + assert schema["required"] == ["command"] + + def test_nested_nullable_array_items_are_collapsed(self): + from tools.mcp_tool import _normalize_mcp_input_schema + + schema = _normalize_mcp_input_schema({ + "type": "object", + "properties": { + "filters": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "object", + "properties": {"field": {"type": "string"}}, + }, + {"type": "null"}, + ] + }, + } + }, + }) + + assert schema["properties"]["filters"]["items"] == { + "type": "object", + "properties": {"field": {"type": "string"}}, + } + def test_convert_mcp_schema_survives_missing_inputschema_attribute(self): """A Tool object without .inputSchema must not crash registration.""" import types @@ -1910,15 +1960,38 @@ class TestUtilityToolRegistration: import math import time -from mcp.types import ( - CreateMessageResult, - CreateMessageResultWithTools, - ErrorData, - SamplingCapability, - SamplingToolsCapability, - TextContent, - ToolUseContent, -) +class _CompatType: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +try: + from mcp.types import ( + CreateMessageResult, + ErrorData, + SamplingCapability, + TextContent, + ) +except ImportError: + CreateMessageResult = _CompatType + ErrorData = _CompatType + SamplingCapability = _CompatType + TextContent = _CompatType + +try: + from mcp.types import CreateMessageResultWithTools +except ImportError: + CreateMessageResultWithTools = _CompatType + +try: + from mcp.types import SamplingToolsCapability +except ImportError: + SamplingToolsCapability = _CompatType + +try: + from mcp.types import ToolUseContent +except ImportError: + ToolUseContent = _CompatType from tools.mcp_tool import SamplingHandler, _safe_numeric diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index e02219d7bc..2e056eb91b 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -868,6 +868,7 @@ class MCPServerTask: "_task", "_ready", "_shutdown_event", "_reconnect_event", "_tools", "_error", "_config", "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock", + "_rpc_lock", ) def __init__(self, name: str): @@ -890,6 +891,12 @@ class MCPServerTask: self._registered_tool_names: list[str] = [] self._auth_type: str = "" self._refresh_lock = asyncio.Lock() + # MCP stdio sessions are a single JSON-RPC stream. Some servers emit + # list_changed notifications during startup; if the notification + # handler calls list_tools while a normal tool call is in flight, the + # stream can wedge and the user-visible tool call times out. Serialize + # client-initiated RPCs per server. + self._rpc_lock = asyncio.Lock() def _is_http(self) -> bool: """Check if this server uses HTTP transport.""" @@ -916,7 +923,16 @@ class MCPServerTask: "MCP server '%s': received tools/list_changed notification", self.name, ) - await self._refresh_tools() + # Some servers (notably mongodb-mcp-server) emit + # tools/list_changed immediately after initialize, + # while the client may already be executing another + # request. Refreshing synchronously inside the SDK + # notification handler can race with that request + # and wedge the stdio JSON-RPC stream, making all + # subsequent tool calls time out. Do the refresh in + # a separate task and let the handler return + # promptly. + asyncio.create_task(self._refresh_tools()) case PromptListChangedNotification(): logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name) case ResourceListChangedNotification(): @@ -942,12 +958,15 @@ class MCPServerTask: old_tool_names = set(self._registered_tool_names) # 1. Fetch current tool list from server - tools_result = await self.session.list_tools() + async with self._rpc_lock: + tools_result = await self.session.list_tools() new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else [] - # 2. Deregister old tools from the central registry - for prefixed_name in self._registered_tool_names: - registry.deregister(prefixed_name) + # 2. Re-register with fresh tool list. Avoid deregistering first: + # live agent turns already have tool-call IDs pointing at the + # existing handler functions. Replacing entries in-place is enough + # for unchanged names and avoids transient "tool not connected" / + # stale-handler races during startup notifications. # 3. Re-register with fresh tool list self._tools = new_mcp_tools @@ -1204,7 +1223,8 @@ class MCPServerTask: """Discover tools from the connected session.""" if self.session is None: return - tools_result = await self.session.list_tools() + async with self._rpc_lock: + tools_result = await self.session.list_tools() self._tools = ( tools_result.tools if hasattr(tools_result, "tools") @@ -1954,7 +1974,8 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): }, ensure_ascii=False) async def _call(): - result = await server.session.call_tool(tool_name, arguments=args) + async with server._rpc_lock: + result = await server.session.call_tool(tool_name, arguments=args) # MCP CallToolResult has .content (list of content blocks) and .isError if result.isError: error_text = "" @@ -2052,7 +2073,8 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float): }, ensure_ascii=False) async def _call(): - result = await server.session.list_resources() + async with server._rpc_lock: + result = await server.session.list_resources() resources = [] for r in (result.resources if hasattr(result, "resources") else []): entry = {} @@ -2115,7 +2137,8 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float): return tool_error("Missing required parameter 'uri'") async def _call(): - result = await server.session.read_resource(uri) + async with server._rpc_lock: + result = await server.session.read_resource(uri) # read_resource returns ReadResourceResult with .contents list parts: List[str] = [] contents = result.contents if hasattr(result, "contents") else [] @@ -2168,7 +2191,8 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float): }, ensure_ascii=False) async def _call(): - result = await server.session.list_prompts() + async with server._rpc_lock: + result = await server.session.list_prompts() prompts = [] for p in (result.prompts if hasattr(result, "prompts") else []): entry = {} @@ -2237,7 +2261,8 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float): arguments = args.get("arguments", {}) async def _call(): - result = await server.session.get_prompt(name, arguments=arguments) + async with server._rpc_lock: + result = await server.session.get_prompt(name, arguments=arguments) # GetPromptResult has .messages list messages = [] for msg in (result.messages if hasattr(result, "messages") else []): @@ -2321,6 +2346,11 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict: * ``required`` arrays are pruned to only names that exist in ``properties``; otherwise Google AI Studio / Gemini 400s with ``property is not defined``. See PR #4651. + * MCP/Pydantic optional fields commonly arrive as + ``anyOf: [{...}, {"type": "null"}], default: null``. Anthropic rejects + nullable branches in tool input schemas, so nullable unions are collapsed + to the non-null branch and optionality remains represented solely by the + parent object's ``required`` list. All repairs are provider-agnostic and ideally produce a schema valid on OpenAI, Anthropic, Gemini, and Moonshot in one pass. @@ -2342,6 +2372,30 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict: return [_rewrite_local_refs(item) for item in node] return node + def _strip_nullable_union(node): + """Collapse JSON Schema nullable unions to provider-safe non-null schemas.""" + if isinstance(node, list): + return [_strip_nullable_union(item) for item in node] + if not isinstance(node, dict): + return node + + stripped = {k: _strip_nullable_union(v) for k, v in node.items()} + for key in ("anyOf", "oneOf"): + variants = stripped.get(key) + if not isinstance(variants, list): + continue + non_null = [ + item for item in variants + if not (isinstance(item, dict) and item.get("type") == "null") + ] + if len(non_null) == 1 and len(non_null) != len(variants): + replacement = dict(non_null[0]) if isinstance(non_null[0], dict) else {} + for meta_key in ("title", "description", "default", "examples"): + if meta_key in stripped and meta_key not in replacement: + replacement[meta_key] = stripped[meta_key] + return _strip_nullable_union(replacement) + return stripped + def _repair_object_shape(node): """Recursively repair object-shaped nodes: fill type, prune required.""" if isinstance(node, list): @@ -2381,6 +2435,7 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict: return repaired normalized = _rewrite_local_refs(schema) + normalized = _strip_nullable_union(normalized) normalized = _repair_object_shape(normalized) # Ensure top-level is a well-formed object schema