diff --git a/cli-config.yaml.example b/cli-config.yaml.example index ec7ccb620..fb1af78fc 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -555,6 +555,21 @@ toolsets: # args: ["-y", "@modelcontextprotocol/server-github"] # env: # GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..." +# +# Sampling (server-initiated LLM requests) — enabled by default. +# Per-server config under the 'sampling' key: +# analysis: +# command: npx +# args: ["-y", "analysis-server"] +# sampling: +# enabled: true # default: true +# model: "gemini-3-flash" # override model (optional) +# max_tokens_cap: 4096 # max tokens per request +# timeout: 30 # LLM call timeout (seconds) +# max_rpm: 10 # max requests per minute +# allowed_models: [] # model whitelist (empty = all) +# max_tool_rounds: 5 # tool loop limit (0 = disable) +# log_level: "info" # audit verbosity # ============================================================================= # Voice Transcription (Speech-to-Text) diff --git a/skills/mcp/native-mcp/SKILL.md b/skills/mcp/native-mcp/SKILL.md index 4362c6cf8..e56bf3fc1 100644 --- a/skills/mcp/native-mcp/SKILL.md +++ b/skills/mcp/native-mcp/SKILL.md @@ -321,6 +321,32 @@ mcp_servers: All tools from all servers are registered and available simultaneously. Each server's tools are prefixed with its name to avoid collisions. +## Sampling (Server-Initiated LLM Requests) + +Hermes supports MCP's `sampling/createMessage` capability — MCP servers can request LLM completions through the agent during tool execution. This enables agent-in-the-loop workflows (data analysis, content generation, decision-making). + +Sampling is **enabled by default**. 
Configure per server: + +```yaml +mcp_servers: + my_server: + command: "npx" + args: ["-y", "my-mcp-server"] + sampling: + enabled: true # default: true + model: "gemini-3-flash" # model override (optional) + max_tokens_cap: 4096 # max tokens per request + timeout: 30 # LLM call timeout (seconds) + max_rpm: 10 # max requests per minute + allowed_models: [] # model whitelist (empty = all) + max_tool_rounds: 5 # tool loop limit (0 = disable) + log_level: "info" # audit verbosity +``` + +Servers can also include `tools` in sampling requests for multi-turn tool-augmented workflows. The `max_tool_rounds` config prevents infinite tool loops. Per-server audit metrics (requests, errors, tokens, tool use count) are tracked via `get_mcp_status()`. + +Disable sampling for untrusted servers with `sampling: { enabled: false }`. + ## Notes - MCP tools are called synchronously from the agent's perspective but run asynchronously on a dedicated background event loop diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 7da383a95..1acbdfa12 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -1489,3 +1489,781 @@ class TestUtilityToolRegistration: assert entry.check_fn() is False _servers.pop("chk", None) + + +# =========================================================================== +# SamplingHandler tests +# =========================================================================== + +import math +import time + +from mcp.types import ( + CreateMessageResult, + CreateMessageResultWithTools, + ErrorData, + SamplingCapability, + SamplingToolsCapability, + TextContent, + ToolUseContent, +) + +from tools.mcp_tool import SamplingHandler, _safe_numeric + + +# --------------------------------------------------------------------------- +# Helpers for sampling tests +# --------------------------------------------------------------------------- + +def _make_sampling_params( + messages=None, + max_tokens=100, + system_prompt=None, + 
model_preferences=None, + temperature=None, + stop_sequences=None, + tools=None, + tool_choice=None, +): + """Create a fake CreateMessageRequestParams using SimpleNamespace. + + Each message must have a ``content_as_list`` attribute that mirrors + the SDK helper so that ``_convert_messages`` works correctly. + """ + if messages is None: + content = SimpleNamespace(text="Hello") + msg = SimpleNamespace(role="user", content=content, content_as_list=[content]) + messages = [msg] + + params = SimpleNamespace( + messages=messages, + maxTokens=max_tokens, + modelPreferences=model_preferences, + temperature=temperature, + stopSequences=stop_sequences, + tools=tools, + toolChoice=tool_choice, + ) + if system_prompt is not None: + params.systemPrompt = system_prompt + return params + + +def _make_llm_response( + content="LLM response", + model="test-model", + finish_reason="stop", + tool_calls=None, +): + """Create a fake OpenAI chat completion response (text).""" + message = SimpleNamespace(content=content, tool_calls=tool_calls) + choice = SimpleNamespace( + finish_reason=finish_reason, + message=message, + ) + usage = SimpleNamespace(total_tokens=42) + return SimpleNamespace(choices=[choice], model=model, usage=usage) + + +def _make_llm_tool_response(tool_calls_data=None, model="test-model"): + """Create a fake response with tool_calls. + + ``tool_calls_data``: list of (id, name, arguments_json) tuples. + """ + if tool_calls_data is None: + tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')] + + tc_list = [ + SimpleNamespace( + id=tc_id, + function=SimpleNamespace(name=name, arguments=args), + ) + for tc_id, name, args in tool_calls_data + ] + return _make_llm_response( + content=None, + model=model, + finish_reason="tool_calls", + tool_calls=tc_list, + ) + + +# --------------------------------------------------------------------------- +# 1. 
# ---------------------------------------------------------------------------
# 1. _safe_numeric helper
# ---------------------------------------------------------------------------

class TestSafeNumeric:
    """Coercion, fallback and clamping behaviour of ``_safe_numeric``."""

    def test_int_passthrough(self):
        coerced = _safe_numeric(10, 5, int)
        assert coerced == 10

    def test_string_coercion(self):
        coerced = _safe_numeric("20", 5, int)
        assert coerced == 20

    def test_none_returns_default(self):
        coerced = _safe_numeric(None, 7, int)
        assert coerced == 7

    def test_inf_returns_default(self):
        infinite = float("inf")
        assert _safe_numeric(infinite, 3.0, float) == 3.0

    def test_nan_returns_default(self):
        not_a_number = float("nan")
        assert _safe_numeric(not_a_number, 4.0, float) == 4.0

    def test_below_minimum_clamps(self):
        coerced = _safe_numeric(-5, 10, int, minimum=1)
        assert coerced == 1

    def test_minimum_zero_allowed(self):
        coerced = _safe_numeric(0, 10, int, minimum=0)
        assert coerced == 0

    def test_non_numeric_string_returns_default(self):
        coerced = _safe_numeric("abc", 42, int)
        assert coerced == 42

    def test_float_coercion(self):
        coerced = _safe_numeric("3.5", 1.0, float)
        assert coerced == 3.5
# ---------------------------------------------------------------------------
# 2. SamplingHandler initialization and config parsing
# ---------------------------------------------------------------------------

class TestSamplingHandlerInit:
    """Constructor defaults and config parsing of ``SamplingHandler``."""

    def test_defaults(self):
        handler = SamplingHandler("srv", {})
        expected = {
            "server_name": "srv",
            "max_rpm": 10,
            "timeout": 30,
            "max_tokens_cap": 4096,
            "max_tool_rounds": 5,
            "allowed_models": [],
            "metrics": {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0},
        }
        for attr, want in expected.items():
            assert getattr(handler, attr) == want
        assert handler.model_override is None

    def test_custom_config(self):
        handler = SamplingHandler(
            "custom",
            {
                "max_rpm": 20,
                "timeout": 60,
                "max_tokens_cap": 2048,
                "max_tool_rounds": 3,
                "model": "gpt-4o",
                "allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
                "log_level": "debug",
            },
        )
        expected = {
            "max_rpm": 20,
            "timeout": 60.0,
            "max_tokens_cap": 2048,
            "max_tool_rounds": 3,
            "model_override": "gpt-4o",
            "allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
        }
        for attr, want in expected.items():
            assert getattr(handler, attr) == want

    def test_string_numeric_config_values(self):
        """YAML sometimes delivers numeric values as strings."""
        handler = SamplingHandler(
            "s", {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"}
        )
        parsed = (handler.max_rpm, handler.timeout, handler.max_tokens_cap)
        assert parsed[0] == 15
        assert parsed[1] == 45.5
        assert parsed[2] == 1024
# ---------------------------------------------------------------------------
# 3. Rate limiting
# ---------------------------------------------------------------------------

class TestRateLimit:
    """Sliding-window limiter checks against a 3-requests-per-minute handler."""

    def setup_method(self):
        self.handler = SamplingHandler("rl", {"max_rpm": 3})

    def _consume(self, count):
        """Issue *count* rate-limit checks, ignoring their results."""
        for _ in range(count):
            self.handler._check_rate_limit()

    def test_allows_under_limit(self):
        for _ in range(3):
            assert self.handler._check_rate_limit() is True

    def test_rejects_over_limit(self):
        self._consume(3)
        assert self.handler._check_rate_limit() is False

    def test_window_expiry(self):
        """Old timestamps should be purged from the sliding window."""
        self._consume(3)
        # Pretend all three recorded requests happened 61 seconds ago.
        stale = time.time() - 61
        self.handler._rate_timestamps[:] = [stale, stale, stale]
        assert self.handler._check_rate_limit() is True


# ---------------------------------------------------------------------------
# 4. Model resolution
# ---------------------------------------------------------------------------

class TestResolveModel:
    """Precedence rules of ``_resolve_model``: override, then hint, then None."""

    def setup_method(self):
        self.handler = SamplingHandler("mr", {})

    @staticmethod
    def _prefs(*hint_names):
        """Build fake model preferences with one hint per given name."""
        return SimpleNamespace(
            hints=[SimpleNamespace(name=hint) for hint in hint_names]
        )

    def test_no_preference_no_override(self):
        resolved = self.handler._resolve_model(None)
        assert resolved is None

    def test_config_override_wins(self):
        self.handler.model_override = "override-model"
        resolved = self.handler._resolve_model(self._prefs("hint-model"))
        assert resolved == "override-model"

    def test_hint_used_when_no_override(self):
        resolved = self.handler._resolve_model(self._prefs("hint-model"))
        assert resolved == "hint-model"

    def test_empty_hints(self):
        resolved = self.handler._resolve_model(self._prefs())
        assert resolved is None

    def test_hint_without_name(self):
        resolved = self.handler._resolve_model(self._prefs(None))
        assert resolved is None
# ---------------------------------------------------------------------------
# 5. Message conversion
# ---------------------------------------------------------------------------

class TestConvertMessages:
    """MCP sampling messages -> OpenAI chat message dicts."""

    def setup_method(self):
        self.handler = SamplingHandler("mc", {})

    def _convert_single(self, message):
        """Convert one message and return the single resulting dict."""
        converted = self.handler._convert_messages(
            _make_sampling_params(messages=[message])
        )
        assert len(converted) == 1
        return converted[0]

    @staticmethod
    def _block_message(role, *blocks):
        """Message whose ``content`` and ``content_as_list`` hold *blocks*."""
        return SimpleNamespace(
            role=role,
            content=list(blocks),
            content_as_list=list(blocks),
        )

    def test_single_text_message(self):
        text = SimpleNamespace(text="Hello world")
        message = SimpleNamespace(role="user", content=text, content_as_list=[text])
        converted = self._convert_single(message)
        assert converted == {"role": "user", "content": "Hello world"}

    def test_image_message(self):
        message = self._block_message(
            "user",
            SimpleNamespace(text="Look at this"),
            SimpleNamespace(data="abc123", mimeType="image/png"),
        )
        parts = self._convert_single(message)["content"]
        assert len(parts) == 2
        text_part, image_part = parts
        assert text_part == {"type": "text", "text": "Look at this"}
        assert image_part["type"] == "image_url"
        assert "data:image/png;base64,abc123" in image_part["image_url"]["url"]

    def test_tool_result_message(self):
        tool_result = SimpleNamespace(
            toolUseId="call_1",
            content=[SimpleNamespace(text="42 degrees")],
        )
        converted = self._convert_single(self._block_message("user", tool_result))
        assert converted["role"] == "tool"
        assert converted["tool_call_id"] == "call_1"
        assert converted["content"] == "42 degrees"

    def test_tool_use_message(self):
        tool_use = SimpleNamespace(
            id="call_2", name="get_weather", input={"city": "London"}
        )
        converted = self._convert_single(self._block_message("assistant", tool_use))
        assert converted["role"] == "assistant"
        calls = converted["tool_calls"]
        assert len(calls) == 1
        function = calls[0]["function"]
        assert function["name"] == "get_weather"
        assert json.loads(function["arguments"]) == {"city": "London"}

    def test_mixed_text_and_tool_use(self):
        """Assistant message with both text and tool_calls."""
        message = self._block_message(
            "assistant",
            SimpleNamespace(text="Let me check the weather"),
            SimpleNamespace(id="call_3", name="get_weather", input={"city": "Paris"}),
        )
        converted = self._convert_single(message)
        assert converted["content"] == "Let me check the weather"
        assert len(converted["tool_calls"]) == 1

    def test_fallback_without_content_as_list(self):
        """When content_as_list is absent, falls back to content."""
        message = SimpleNamespace(
            role="user", content=SimpleNamespace(text="Fallback text")
        )
        converted = self._convert_single(message)
        assert converted["content"] == "Fallback text"
# ---------------------------------------------------------------------------
# 6. Text-only sampling callback (full flow)
# ---------------------------------------------------------------------------

def _invoke_with_response(handler, llm_response, **param_overrides):
    """Run *handler* once against a mocked client returning *llm_response*.

    Returns ``(result, client)`` so callers can inspect both the callback
    result and the arguments the mocked client was called with.
    """
    client = MagicMock()
    client.chat.completions.create.return_value = llm_response
    with patch(
        "agent.auxiliary_client.get_text_auxiliary_client",
        return_value=(client, "default-model"),
    ):
        request = _make_sampling_params(**param_overrides)
        outcome = asyncio.run(handler(None, request))
    return outcome, client


class TestSamplingCallbackText:
    """Plain-text completions flowing through the sampling callback."""

    def setup_method(self):
        self.handler = SamplingHandler("txt", {})

    def test_text_response(self):
        """Full flow: text response returns CreateMessageResult."""
        outcome, _ = _invoke_with_response(
            self.handler, _make_llm_response(content="Hello from LLM")
        )

        assert isinstance(outcome, CreateMessageResult)
        assert isinstance(outcome.content, TextContent)
        assert outcome.content.text == "Hello from LLM"
        assert outcome.model == "test-model"
        assert outcome.role == "assistant"
        assert outcome.stopReason == "endTurn"

    def test_system_prompt_prepended(self):
        """System prompt is inserted as the first message."""
        _, client = _invoke_with_response(
            self.handler, _make_llm_response(), system_prompt="Be helpful"
        )

        sent = client.chat.completions.create.call_args.kwargs["messages"]
        assert sent[0] == {"role": "system", "content": "Be helpful"}

    def test_length_stop_reason(self):
        """finish_reason='length' maps to stopReason='maxTokens'."""
        outcome, _ = _invoke_with_response(
            self.handler, _make_llm_response(finish_reason="length")
        )

        assert isinstance(outcome, CreateMessageResult)
        assert outcome.stopReason == "maxTokens"


# ---------------------------------------------------------------------------
# 7. Tool use sampling callback
# ---------------------------------------------------------------------------

class TestSamplingCallbackToolUse:
    """Completions carrying tool calls flowing through the callback."""

    def setup_method(self):
        self.handler = SamplingHandler("tu", {})

    def test_tool_use_response(self):
        """LLM tool_calls response returns CreateMessageResultWithTools."""
        outcome, _ = _invoke_with_response(self.handler, _make_llm_tool_response())

        assert isinstance(outcome, CreateMessageResultWithTools)
        assert outcome.stopReason == "toolUse"
        assert outcome.model == "test-model"
        assert len(outcome.content) == 1
        call = outcome.content[0]
        assert isinstance(call, ToolUseContent)
        assert call.name == "get_weather"
        assert call.id == "call_1"
        assert call.input == {"city": "London"}

    def test_multiple_tool_calls(self):
        """Multiple tool_calls in a single response."""
        two_calls = _make_llm_tool_response(
            tool_calls_data=[
                ("call_a", "func_a", '{"x": 1}'),
                ("call_b", "func_b", '{"y": 2}'),
            ]
        )
        outcome, _ = _invoke_with_response(self.handler, two_calls)

        assert isinstance(outcome, CreateMessageResultWithTools)
        assert len(outcome.content) == 2
        first, second = outcome.content
        assert first.name == "func_a"
        assert second.name == "func_b"
# ---------------------------------------------------------------------------
# 8. Tool loop governance
# ---------------------------------------------------------------------------

def _sample_once(handler, client, default_model="default-model", params=None):
    """Invoke *handler* once with the auxiliary client lookup patched.

    The patched lookup returns ``(client, default_model)``. A fresh default
    params object is built unless *params* is supplied.
    """
    if params is None:
        params = _make_sampling_params()
    with patch(
        "agent.auxiliary_client.get_text_auxiliary_client",
        return_value=(client, default_model),
    ):
        return asyncio.run(handler(None, params))


def _client_returning(llm_response):
    """MagicMock chat client whose ``create`` call returns *llm_response*."""
    client = MagicMock()
    client.chat.completions.create.return_value = llm_response
    return client


class TestToolLoopGovernance:
    """Enforcement of the ``max_tool_rounds`` limit."""

    def test_max_tool_rounds_enforcement(self):
        """After max_tool_rounds consecutive tool responses, an error is returned."""
        handler = SamplingHandler("tl", {"max_tool_rounds": 2})
        client = _client_returning(_make_llm_tool_response())
        shared_params = _make_sampling_params()

        # Rounds 1 and 2 are within the limit.
        for _ in range(2):
            allowed = _sample_once(handler, client, params=shared_params)
            assert isinstance(allowed, CreateMessageResultWithTools)

        # Round 3 exceeds it.
        rejected = _sample_once(handler, client, params=shared_params)
        assert isinstance(rejected, ErrorData)
        assert "Tool loop limit exceeded" in rejected.message

    def test_text_response_resets_counter(self):
        """A text response resets the tool loop counter."""
        handler = SamplingHandler("tl2", {"max_tool_rounds": 1})
        client = MagicMock()
        create = client.chat.completions.create

        # Tool response (round 1 of 1 allowed).
        create.return_value = _make_llm_tool_response()
        assert isinstance(_sample_once(handler, client), CreateMessageResultWithTools)

        # Text response resets the counter.
        create.return_value = _make_llm_response()
        assert isinstance(_sample_once(handler, client), CreateMessageResult)

        # Another tool response succeeds because the counter was reset.
        create.return_value = _make_llm_tool_response()
        assert isinstance(_sample_once(handler, client), CreateMessageResultWithTools)

    def test_max_tool_rounds_zero_disables(self):
        """max_tool_rounds=0 means tool loops are disabled entirely."""
        handler = SamplingHandler("tl3", {"max_tool_rounds": 0})
        outcome = _sample_once(handler, _client_returning(_make_llm_tool_response()))
        assert isinstance(outcome, ErrorData)
        assert "Tool loops disabled" in outcome.message


# ---------------------------------------------------------------------------
# 9. Error paths: rate limit, timeout, no provider
# ---------------------------------------------------------------------------

class TestSamplingErrors:
    """Error results for rate limiting, timeouts and a missing provider."""

    def test_rate_limit_error(self):
        handler = SamplingHandler("rle", {"max_rpm": 1})
        client = _client_returning(_make_llm_response())

        # First call succeeds.
        first = _sample_once(handler, client)
        assert isinstance(first, CreateMessageResult)
        # Second call is rate limited.
        second = _sample_once(handler, client)
        assert isinstance(second, ErrorData)
        assert "rate limit" in second.message.lower()
        assert handler.metrics["errors"] == 1

    def test_timeout_error(self):
        handler = SamplingHandler("to", {"timeout": 0.05})
        client = MagicMock()

        def slow_call(**kwargs):
            import threading
            # A worker thread cannot be cancelled once started, and
            # asyncio.run() joins default-executor threads on shutdown, so
            # this wait bounds the test's wall-clock time. Two seconds is
            # still longer than the handler's effective timeout: the
            # configured 0.05 is clamped up to the 1-second minimum that
            # _safe_numeric applies by default.
            threading.Event().wait(2)
            return _make_llm_response()

        client.chat.completions.create.side_effect = slow_call

        outcome = _sample_once(handler, client)
        assert isinstance(outcome, ErrorData)
        assert "timed out" in outcome.message.lower()
        assert handler.metrics["errors"] == 1

    def test_no_provider_error(self):
        handler = SamplingHandler("np", {})

        outcome = _sample_once(handler, None, default_model=None)
        assert isinstance(outcome, ErrorData)
        assert "No LLM provider" in outcome.message
        assert handler.metrics["errors"] == 1


# ---------------------------------------------------------------------------
# 10. Model whitelist
# ---------------------------------------------------------------------------

class TestModelWhitelist:
    """Behaviour of the ``allowed_models`` whitelist."""

    def test_allowed_model_passes(self):
        handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]})
        client = _client_returning(_make_llm_response())

        outcome = _sample_once(handler, client, default_model="test-model")
        assert isinstance(outcome, CreateMessageResult)

    def test_disallowed_model_rejected(self):
        handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"]})

        outcome = _sample_once(handler, MagicMock(), default_model="gpt-3.5-turbo")
        assert isinstance(outcome, ErrorData)
        assert "not allowed" in outcome.message
        assert handler.metrics["errors"] == 1

    def test_empty_whitelist_allows_all(self):
        handler = SamplingHandler("wl3", {"allowed_models": []})
        client = _client_returning(_make_llm_response())

        outcome = _sample_once(handler, client, default_model="any-model")
        assert isinstance(outcome, CreateMessageResult)


# ---------------------------------------------------------------------------
# 11. Malformed tool_call arguments
# ---------------------------------------------------------------------------

class TestMalformedToolCallArgs:
    """Handling of tool-call arguments that are not a JSON string of a dict."""

    def test_invalid_json_wrapped_as_raw(self):
        """Malformed JSON arguments get wrapped in {"_raw": ...}."""
        handler = SamplingHandler("mf", {})
        broken = _make_llm_tool_response(
            tool_calls_data=[("call_x", "some_tool", "not valid json {{{")]
        )

        outcome = _sample_once(handler, _client_returning(broken))

        assert isinstance(outcome, CreateMessageResultWithTools)
        call = outcome.content[0]
        assert isinstance(call, ToolUseContent)
        assert call.input == {"_raw": "not valid json {{{"}

    def test_dict_args_pass_through(self):
        """When arguments are already a dict, they pass through directly."""
        handler = SamplingHandler("mf2", {})

        # Hand-build the response so ``arguments`` is a dict, not a JSON string.
        dict_call = SimpleNamespace(
            id="call_d",
            function=SimpleNamespace(name="do_stuff", arguments={"key": "val"}),
        )
        response = SimpleNamespace(
            choices=[
                SimpleNamespace(
                    finish_reason="tool_calls",
                    message=SimpleNamespace(content=None, tool_calls=[dict_call]),
                )
            ],
            model="m",
            usage=SimpleNamespace(total_tokens=10),
        )

        outcome = _sample_once(handler, _client_returning(response))

        assert isinstance(outcome, CreateMessageResultWithTools)
        assert outcome.content[0].input == {"key": "val"}
# ---------------------------------------------------------------------------
# 12. Metrics tracking
# ---------------------------------------------------------------------------

class TestMetricsTracking:
    """Per-handler audit counters after a single sampling call."""

    @staticmethod
    def _run(handler, provider):
        """Invoke *handler* once with the client lookup returning *provider*."""
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=provider,
        ):
            asyncio.run(handler(None, _make_sampling_params()))

    def test_request_and_token_metrics(self):
        handler = SamplingHandler("met", {})
        client = MagicMock()
        client.chat.completions.create.return_value = _make_llm_response()

        self._run(handler, (client, "default-model"))

        counters = handler.metrics
        assert counters["requests"] == 1
        assert counters["tokens_used"] == 42
        assert counters["errors"] == 0

    def test_tool_use_count_metric(self):
        handler = SamplingHandler("met2", {})
        client = MagicMock()
        client.chat.completions.create.return_value = _make_llm_tool_response()

        self._run(handler, (client, "default-model"))

        counters = handler.metrics
        assert counters["tool_use_count"] == 1
        assert counters["requests"] == 1

    def test_error_metric_incremented(self):
        handler = SamplingHandler("met3", {})

        # No provider available: the lookup yields (None, None).
        self._run(handler, (None, None))

        counters = handler.metrics
        assert counters["errors"] == 1
        assert counters["requests"] == 0
# ---------------------------------------------------------------------------
# 13. session_kwargs()
# ---------------------------------------------------------------------------

class TestSessionKwargs:
    """Shape of the kwargs produced by ``SamplingHandler.session_kwargs()``."""

    def test_returns_correct_keys(self):
        handler = SamplingHandler("sk", {})
        session_kwargs = handler.session_kwargs()
        for key in ("sampling_callback", "sampling_capabilities"):
            assert key in session_kwargs
        assert session_kwargs["sampling_callback"] is handler

    def test_sampling_capabilities_type(self):
        handler = SamplingHandler("sk2", {})
        capability = handler.session_kwargs()["sampling_capabilities"]
        assert isinstance(capability, SamplingCapability)
        assert isinstance(capability.tools, SamplingToolsCapability)


# ---------------------------------------------------------------------------
# 14. MCPServerTask integration
# ---------------------------------------------------------------------------

class TestMCPServerTaskSamplingIntegration:
    """Sampling setup on ``MCPServerTask`` without opening a real connection."""

    @staticmethod
    def _apply_sampling_setup(server, config, sampling_types_available):
        """Replay the sampling-setup step on *server* for *config*.

        Calling ``run()`` would attempt a real connection, so these tests
        reproduce only the sampling setup portion directly.
        """
        server._config = config
        sampling_config = config.get("sampling", {})
        enabled = sampling_config.get("enabled", True)
        if enabled and sampling_types_available:
            server._sampling = SamplingHandler(server.name, sampling_config)
        else:
            server._sampling = None

    def test_sampling_handler_created_when_enabled(self):
        """MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES

        server = MCPServerTask("int_test")
        self._apply_sampling_setup(
            server,
            {"command": "fake", "sampling": {"enabled": True, "max_rpm": 5}},
            _MCP_SAMPLING_TYPES,
        )

        sampling = server._sampling
        assert sampling is not None
        assert isinstance(sampling, SamplingHandler)
        assert sampling.server_name == "int_test"
        assert sampling.max_rpm == 5

    def test_sampling_handler_none_when_disabled(self):
        """MCPServerTask._sampling is None when sampling is disabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES

        server = MCPServerTask("int_test2")
        self._apply_sampling_setup(
            server,
            {"command": "fake", "sampling": {"enabled": False}},
            _MCP_SAMPLING_TYPES,
        )

        assert server._sampling is None

    def test_session_kwargs_used_in_stdio(self):
        """When sampling is set, session_kwargs() are passed to ClientSession."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("sk_test")
        server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
        session_kwargs = server._sampling.session_kwargs()
        for key in ("sampling_callback", "sampling_capabilities"):
            assert key in session_kwargs
timeout: 180 + analysis: + command: "npx" + args: ["-y", "analysis-server"] + sampling: # server-initiated LLM requests + enabled: true # default: true + model: "gemini-3-flash" # override model (optional) + max_tokens_cap: 4096 # max tokens per request + timeout: 30 # LLM call timeout (seconds) + max_rpm: 10 # max requests per minute + allowed_models: [] # model whitelist (empty = all) + max_tool_rounds: 5 # tool loop limit (0 = disable) + log_level: "info" # audit verbosity Features: - Stdio transport (command + args) and HTTP/StreamableHTTP transport (url) @@ -37,6 +49,8 @@ Features: - Credential stripping in error messages returned to the LLM - Configurable per-server timeouts for tool calls and connections - Thread-safe architecture with dedicated background event loop + - Sampling support: MCP servers can request LLM completions via + sampling/createMessage (text and tool-use responses) Architecture: A dedicated background event loop (_mcp_loop) runs in a daemon thread. @@ -58,9 +72,11 @@ Thread safety: import asyncio import json import logging +import math import os import re import threading +import time from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) @@ -71,6 +87,7 @@ logger = logging.getLogger(__name__) _MCP_AVAILABLE = False _MCP_HTTP_AVAILABLE = False +_MCP_SAMPLING_TYPES = False try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client @@ -80,6 +97,20 @@ try: _MCP_HTTP_AVAILABLE = True except ImportError: _MCP_HTTP_AVAILABLE = False + # Sampling types -- separated so older SDK versions don't break MCP support + try: + from mcp.types import ( + CreateMessageResult, + CreateMessageResultWithTools, + ErrorData, + SamplingCapability, + SamplingToolsCapability, + TextContent, + ToolUseContent, + ) + _MCP_SAMPLING_TYPES = True + except ImportError: + logger.debug("MCP sampling types not available -- sampling disabled") except ImportError: logger.debug("mcp package not installed -- 
# ---------------------------------------------------------------------------
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------

def _safe_numeric(value, default, coerce=int, minimum=1):
    """Coerce a config value to a numeric type, returning *default* on failure.

    Handles string values from YAML (e.g. ``"10"`` instead of ``10``),
    non-finite floats, and values below *minimum*.

    Args:
        value: Raw config value (any type).
        default: Value returned when *value* cannot be coerced or is a
            non-finite float.  It is returned as-is (not clamped).
        coerce: Callable used for the conversion (``int`` or ``float``).
        minimum: Lower bound; successfully coerced results are clamped to it.

    Returns:
        The coerced value clamped to at least *minimum*, or *default*.
    """
    try:
        result = coerce(value)
        if isinstance(result, float) and not math.isfinite(result):
            return default
        return max(result, minimum)
    except (TypeError, ValueError, OverflowError):
        # int(float("inf")) raises OverflowError; int("abc") ValueError;
        # int(None) TypeError -- all fall back to the default.
        return default


class SamplingHandler:
    """Handles sampling/createMessage requests for a single MCP server.

    Each MCPServerTask that has sampling enabled creates one SamplingHandler.
    The handler is callable and passed directly to ``ClientSession`` as
    the ``sampling_callback``.  All state (rate-limit timestamps, metrics,
    tool-loop counters) lives on the instance -- no module-level globals.

    The callback is async and runs on the MCP background event loop.  The
    sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
    it doesn't block the event loop.
    """

    # OpenAI ``finish_reason`` -> MCP ``stopReason``
    _STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}

    def __init__(self, server_name: str, config: dict):
        """Read the per-server ``sampling`` config and initialise state.

        Args:
            server_name: Name of the MCP server (used in logs and errors).
            config: The server's ``sampling`` config mapping.  ``None`` is
                treated as an empty mapping (a bare ``sampling:`` key in
                YAML parses to ``None``).
        """
        config = config or {}
        self.server_name = server_name
        self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
        self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
        self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
        # minimum=0 because 0 is meaningful here: it disables tool loops.
        self.max_tool_rounds = _safe_numeric(
            config.get("max_tool_rounds", 5), 5, int, minimum=0,
        )
        self.model_override = config.get("model")
        self.allowed_models = self._normalize_allowed_models(
            config.get("allowed_models"),
        )

        _log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
        self.audit_level = _log_levels.get(
            str(config.get("log_level", "info")).lower(), logging.INFO,
        )

        # Per-instance state
        self._rate_timestamps: List[float] = []
        self._tool_loop_count = 0
        self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}

    @staticmethod
    def _normalize_allowed_models(value) -> list:
        """Return the ``allowed_models`` config value as a list.

        A scalar string must become a one-element list: otherwise the
        whitelist check ``model not in allowed_models`` degrades to a
        substring test (``"gpt-4o" in "gpt-4o-mini"`` is true) and lets
        unlisted models through.  ``None`` means "no whitelist".
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple, set, frozenset)):
            return list(value)
        # Any other scalar (e.g. a bare number in YAML): keep it as a
        # single entry rather than silently disabling the whitelist.
        return [str(value)]

    # -- Rate limiting -------------------------------------------------------

    def _check_rate_limit(self) -> bool:
        """Sliding-window rate limiter.  Returns True if request is allowed.

        Keeps the timestamps of requests accepted in the last 60 seconds;
        a request is refused (and not recorded) once ``max_rpm`` of them
        are inside the window.
        """
        now = time.time()
        window = now - 60
        # Slice-assign so the list object itself is kept.
        self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window]
        if len(self._rate_timestamps) >= self.max_rpm:
            return False
        self._rate_timestamps.append(now)
        return True

    # -- Model resolution ----------------------------------------------------

    def _resolve_model(self, preferences) -> Optional[str]:
        """Config override > server hint > None (use default).

        Args:
            preferences: The request's ``modelPreferences`` (or ``None``).
                Only its ``hints`` are consulted; the first hint with a
                non-empty ``name`` wins.
        """
        if self.model_override:
            return self.model_override
        if preferences and hasattr(preferences, "hints") and preferences.hints:
            for hint in preferences.hints:
                if hasattr(hint, "name") and hint.name:
                    return hint.name
        return None

    # -- Message conversion --------------------------------------------------

    @staticmethod
    def _extract_tool_result_text(block) -> str:
        """Extract text from a ToolResultContent block.

        Joins the ``text`` of every item in ``block.content`` with
        newlines; items without a ``text`` attribute are ignored.  Returns
        ``""`` when the block has no content.
        """
        if not hasattr(block, "content") or block.content is None:
            return ""
        items = block.content if isinstance(block.content, list) else [block.content]
        return "\n".join(item.text for item in items if hasattr(item, "text"))

    def _convert_messages(self, params) -> List[dict]:
        """Convert MCP SamplingMessages to OpenAI format.

        Uses ``msg.content_as_list`` when the message provides it, so
        single-block and list-of-blocks content are handled uniformly
        (otherwise the content is wrapped in a list here).  Blocks are
        classified by duck-typing:

        * has ``toolUseId``                 -> tool result (``role: tool``)
        * has ``name`` and ``input``        -> tool use (``tool_calls`` entry)
        * has ``text``                      -> text part
        * has ``data`` and ``mimeType``     -> image part (data URL)

        Anything else is logged and skipped.

        Returns:
            A list of OpenAI-style chat message dicts.  One MCP message may
            yield several entries (each tool result is its own message).
        """
        messages: List[dict] = []
        for msg in params.messages:
            blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
                msg.content if isinstance(msg.content, list) else [msg.content]
            )

            # Separate blocks by kind
            tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
            tool_uses = [
                b for b in blocks
                if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
            ]
            content_blocks = [
                b for b in blocks
                if not hasattr(b, "toolUseId")
                and not (hasattr(b, "name") and hasattr(b, "input"))
            ]

            # Emit tool result messages (role: tool)
            for tr in tool_results:
                messages.append({
                    "role": "tool",
                    "tool_call_id": tr.toolUseId,
                    "content": self._extract_tool_result_text(tr),
                })

            # Emit assistant tool_calls message
            if tool_uses:
                tc_list = []
                for tu in tool_uses:
                    tc_list.append({
                        # Synthesize a positional id if the block has none.
                        "id": getattr(tu, "id", f"call_{len(tc_list)}"),
                        "type": "function",
                        "function": {
                            "name": tu.name,
                            "arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
                        },
                    })
                msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
                # Include any accompanying text
                text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
                if text_parts:
                    msg_dict["content"] = "\n".join(text_parts)
                messages.append(msg_dict)
            elif content_blocks:
                # Pure text/image content
                if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
                    # Single text block -> plain string content
                    messages.append({"role": msg.role, "content": content_blocks[0].text})
                else:
                    parts = []
                    for block in content_blocks:
                        if hasattr(block, "text"):
                            parts.append({"type": "text", "text": block.text})
                        elif hasattr(block, "data") and hasattr(block, "mimeType"):
                            parts.append({
                                "type": "image_url",
                                "image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
                            })
                        else:
                            logger.warning(
                                "Unsupported sampling content block type: %s (skipped)",
                                type(block).__name__,
                            )
                    if parts:
                        messages.append({"role": msg.role, "content": parts})

        return messages

    # -- Error helper --------------------------------------------------------

    @staticmethod
    def _error(message: str, code: int = -1):
        """Return ErrorData (MCP spec) or raise as fallback.

        Raises:
            Exception: When the MCP sampling types could not be imported
                (``_MCP_SAMPLING_TYPES`` is false), so no ``ErrorData``
                can be built.
        """
        if _MCP_SAMPLING_TYPES:
            return ErrorData(code=code, message=message)
        raise Exception(message)

    # -- Response building ---------------------------------------------------

    def _build_tool_use_result(self, choice, response):
        """Build a CreateMessageResultWithTools from an LLM tool_calls response.

        Enforces ``max_tool_rounds`` first: ``0`` rejects every tool-use
        response, and more than ``max_tool_rounds`` consecutive tool-use
        responses are rejected as well.  Either rejection resets the
        consecutive-round counter and returns the ``_error`` result.

        Tool-call arguments that are not valid JSON (or not a dict) are
        wrapped as ``{"_raw": ...}`` rather than dropped.
        """
        self.metrics["tool_use_count"] += 1

        # Tool loop governance
        if self.max_tool_rounds == 0:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
            )

        self._tool_loop_count += 1
        if self._tool_loop_count > self.max_tool_rounds:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loop limit exceeded for server '{self.server_name}' "
                f"(max {self.max_tool_rounds} rounds)"
            )

        content_blocks = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            if isinstance(args, str):
                try:
                    parsed = json.loads(args)
                except (json.JSONDecodeError, ValueError):
                    logger.warning(
                        "MCP server '%s': malformed tool_calls arguments "
                        "from LLM (wrapping as raw): %.100s",
                        self.server_name, args,
                    )
                    parsed = {"_raw": args}
            else:
                parsed = args if isinstance(args, dict) else {"_raw": str(args)}

            content_blocks.append(ToolUseContent(
                type="tool_use",
                id=tc.id,
                name=tc.function.name,
                input=parsed,
            ))

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
            len(content_blocks),
        )

        return CreateMessageResultWithTools(
            role="assistant",
            content=content_blocks,
            model=response.model,
            stopReason="toolUse",
        )

    def _build_text_result(self, choice, response):
        """Build a CreateMessageResult from a normal text response.

        Resets the consecutive tool-round counter and passes the response
        text through ``_sanitize_error`` before returning it.  An unknown
        ``finish_reason`` maps to ``"endTurn"``.
        """
        self._tool_loop_count = 0  # reset on text response
        response_text = choice.message.content or ""

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
        )

        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=_sanitize_error(response_text)),
            model=response.model,
            stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
        )

    # -- Session kwargs helper -----------------------------------------------

    def session_kwargs(self) -> dict:
        """Return kwargs to pass to ClientSession for sampling support."""
        return {
            "sampling_callback": self,
            "sampling_capabilities": SamplingCapability(
                tools=SamplingToolsCapability(),
            ),
        }

    # -- Main callback -------------------------------------------------------

    async def __call__(self, context, params):
        """Sampling callback invoked by the MCP SDK.

        Conforms to ``SamplingFnT`` protocol.  Returns
        ``CreateMessageResult``, ``CreateMessageResultWithTools``, or
        ``ErrorData``.

        Steps: rate-limit check, model resolution and whitelist check,
        message conversion, LLM call in a worker thread bounded by
        ``self.timeout``, metrics update, then dispatch to the tool-use or
        text result builder.

        Args:
            context: Request context supplied by the SDK (not used here).
            params: The ``sampling/createMessage`` request parameters.
        """
        # Rate limit
        if not self._check_rate_limit():
            logger.warning(
                "MCP server '%s' sampling rate limit exceeded (%d/min)",
                self.server_name, self.max_rpm,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling rate limit exceeded for server '{self.server_name}' "
                f"({self.max_rpm} requests/minute)"
            )

        # Resolve model
        model = self._resolve_model(getattr(params, "modelPreferences", None))

        # Get auxiliary LLM client
        from agent.auxiliary_client import get_text_auxiliary_client
        client, default_model = get_text_auxiliary_client()
        if client is None:
            self.metrics["errors"] += 1
            return self._error("No LLM provider available for sampling")

        resolved_model = model or default_model

        # Model whitelist check
        if self.allowed_models and resolved_model not in self.allowed_models:
            logger.warning(
                "MCP server '%s' requested model '%s' not in allowed_models",
                self.server_name, resolved_model,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Model '{resolved_model}' not allowed for server "
                f"'{self.server_name}'. Allowed: {', '.join(map(str, self.allowed_models))}"
            )

        # Convert messages
        messages = self._convert_messages(params)
        if hasattr(params, "systemPrompt") and params.systemPrompt:
            messages.insert(0, {"role": "system", "content": params.systemPrompt})

        # Build LLM call kwargs
        max_tokens = min(params.maxTokens, self.max_tokens_cap)
        call_kwargs: dict = {
            "model": resolved_model,
            "messages": messages,
            "max_tokens": max_tokens,
        }
        if hasattr(params, "temperature") and params.temperature is not None:
            call_kwargs["temperature"] = params.temperature
        if stop := getattr(params, "stopSequences", None):
            call_kwargs["stop"] = stop

        # Forward server-provided tools
        server_tools = getattr(params, "tools", None)
        if server_tools:
            call_kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": getattr(t, "name", ""),
                        "description": getattr(t, "description", "") or "",
                        "parameters": getattr(t, "inputSchema", {}) or {},
                    },
                }
                for t in server_tools
            ]
            if tool_choice := getattr(params, "toolChoice", None):
                mode = getattr(tool_choice, "mode", "auto")
                call_kwargs["tool_choice"] = {"auto": "auto", "required": "required", "none": "none"}.get(mode, "auto")

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
            self.server_name, resolved_model, max_tokens, len(messages),
        )

        # Offload sync LLM call to thread (non-blocking)
        def _sync_call():
            return client.chat.completions.create(**call_kwargs)

        try:
            response = await asyncio.wait_for(
                asyncio.to_thread(_sync_call), timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call timed out after {self.timeout}s "
                f"for server '{self.server_name}'"
            )
        except Exception as exc:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
            )

        # A response without choices would otherwise raise IndexError out
        # of the callback; report it like every other failure instead.
        choices = getattr(response, "choices", None)
        if not choices:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call returned no choices for server '{self.server_name}'"
            )

        # Track metrics
        choice = choices[0]
        self.metrics["requests"] += 1
        total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
        if isinstance(total_tokens, int):
            self.metrics["tokens_used"] += total_tokens

        # Dispatch based on response type
        if (
            choice.finish_reason == "tool_calls"
            and hasattr(choice.message, "tool_calls")
            and choice.message.tool_calls
        ):
            return self._build_tool_use_result(choice, response)

        return self._build_text_result(choice, response)
ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: await session.initialize() self.session = session await self._discover_tools() @@ -250,6 +665,13 @@ class MCPServerTask: self._config = config self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT) + # Set up sampling handler if enabled and SDK types are available + sampling_config = config.get("sampling", {}) + if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES: + self._sampling = SamplingHandler(self.name, sampling_config) + else: + self._sampling = None + # Validate: warn if both url and command are present if "url" in config and "command" in config: logger.warning( @@ -975,12 +1397,15 @@ def get_mcp_status() -> List[dict]: transport = "http" if "url" in cfg else "stdio" server = active_servers.get(name) if server and server.session is not None: - result.append({ + entry = { "name": name, "transport": transport, "tools": len(server._tools), "connected": True, - }) + } + if server._sampling: + entry["sampling"] = dict(server._sampling.metrics) + result.append(entry) else: result.append({ "name": name, diff --git a/website/docs/user-guide/features/mcp.md b/website/docs/user-guide/features/mcp.md index 0297b152d..9a29d4316 100644 --- a/website/docs/user-guide/features/mcp.md +++ b/website/docs/user-guide/features/mcp.md @@ -271,3 +271,62 @@ You can reload MCP servers without restarting Hermes: - In the CLI: the agent reconnects automatically - In messaging: send `/reload-mcp` + +## Sampling (Server-Initiated LLM Requests) + +MCP's `sampling/createMessage` capability allows MCP servers to request LLM completions through the Hermes agent. This enables agent-in-the-loop workflows where servers can leverage the LLM during tool execution — for example, a database server asking the LLM to interpret query results, or a code analysis server requesting the LLM to review findings. 
+ +### How It Works + +When an MCP server sends a `sampling/createMessage` request: + +1. The sampling callback validates against rate limits and model whitelist +2. Resolves which model to use (config override > server hint > default) +3. Converts MCP messages to OpenAI-compatible format +4. Offloads the LLM call to a thread via `asyncio.to_thread()` (non-blocking) +5. Returns the response (text or tool use) back to the server + +### Configuration + +Sampling is **enabled by default** for all MCP servers. No extra setup needed — if you have an auxiliary LLM client configured, sampling works automatically. + +```yaml +mcp_servers: + analysis_server: + command: "npx" + args: ["-y", "my-analysis-server"] + sampling: + enabled: true # default: true + model: "gemini-3-flash" # override model (optional) + max_tokens_cap: 4096 # max tokens per request (default: 4096) + timeout: 30 # LLM call timeout in seconds (default: 30) + max_rpm: 10 # max requests per minute (default: 10) + allowed_models: [] # model whitelist (empty = allow all) + max_tool_rounds: 5 # max consecutive tool use rounds (0 = disable) + log_level: "info" # audit verbosity: debug, info, warning +``` + +### Tool Use in Sampling + +Servers can include `tools` and `toolChoice` in sampling requests, enabling multi-turn tool-augmented workflows within a single sampling session. The callback forwards tool definitions to the LLM, handles tool use responses with proper `ToolUseContent` types, and enforces `max_tool_rounds` to prevent infinite loops. 
+ +### Security + +- **Rate limiting**: Per-server sliding window (default: 10 req/min) +- **Token cap**: Servers can't request more than `max_tokens_cap` (default: 4096) +- **Model whitelist**: `allowed_models` restricts which models a server can use +- **Tool loop limit**: `max_tool_rounds` caps consecutive tool use rounds +- **Credential stripping**: LLM responses are sanitized before returning to the server +- **Non-blocking**: LLM calls run in a separate thread via `asyncio.to_thread()` +- **Typed errors**: All failures return structured `ErrorData` per MCP spec + +To disable sampling for untrusted servers: + +```yaml +mcp_servers: + untrusted: + command: "npx" + args: ["-y", "untrusted-server"] + sampling: + enabled: false +```