feat(mcp): add sampling support — server-initiated LLM requests (#753)

Add MCP sampling/createMessage capability via SamplingHandler class.

Text-only sampling + tool use in sampling with governance (rate limits,
model whitelist, token caps, tool loop limits). Per-server audit metrics.

Based on concept from PR #366 by eren-karakus0. Restructured as class-based
design with bug fixes and tests using real MCP SDK types.

50 new tests, 2600 total passing.
This commit is contained in:
Teknium 2026-03-09 03:37:38 -07:00 committed by GitHub
parent 1f0944de21
commit 654e16187e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 1307 additions and 4 deletions

View file

@ -555,6 +555,21 @@ toolsets:
# args: ["-y", "@modelcontextprotocol/server-github"]
# env:
# GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
#
# Sampling (server-initiated LLM requests) — enabled by default.
# Per-server config under the 'sampling' key:
# analysis:
# command: npx
# args: ["-y", "analysis-server"]
# sampling:
# enabled: true # default: true
# model: "gemini-3-flash" # override model (optional)
# max_tokens_cap: 4096 # max tokens per request
# timeout: 30 # LLM call timeout (seconds)
# max_rpm: 10 # max requests per minute
# allowed_models: [] # model whitelist (empty = all)
# max_tool_rounds: 5 # tool loop limit (0 = disable)
# log_level: "info" # audit verbosity
# =============================================================================
# Voice Transcription (Speech-to-Text)

View file

@ -321,6 +321,32 @@ mcp_servers:
All tools from all servers are registered and available simultaneously. Each server's tools are prefixed with its name to avoid collisions.
## Sampling (Server-Initiated LLM Requests)
Hermes supports MCP's `sampling/createMessage` capability — MCP servers can request LLM completions through the agent during tool execution. This enables agent-in-the-loop workflows (data analysis, content generation, decision-making).
Sampling is **enabled by default**. Configure per server:
```yaml
mcp_servers:
my_server:
command: "npx"
args: ["-y", "my-mcp-server"]
sampling:
enabled: true # default: true
model: "gemini-3-flash" # model override (optional)
max_tokens_cap: 4096 # max tokens per request
timeout: 30 # LLM call timeout (seconds)
max_rpm: 10 # max requests per minute
allowed_models: [] # model whitelist (empty = all)
max_tool_rounds: 5 # tool loop limit (0 = disable)
log_level: "info" # audit verbosity
```
Servers can also include `tools` in sampling requests for multi-turn tool-augmented workflows. The `max_tool_rounds` config prevents infinite tool loops. Per-server audit metrics (requests, errors, tokens, tool use count) are tracked via `get_mcp_status()`.
Disable sampling for untrusted servers with `sampling: { enabled: false }`.
## Notes
- MCP tools are called synchronously from the agent's perspective but run asynchronously on a dedicated background event loop

View file

@ -1489,3 +1489,781 @@ class TestUtilityToolRegistration:
assert entry.check_fn() is False
_servers.pop("chk", None)
# ===========================================================================
# SamplingHandler tests
# ===========================================================================
import math
import time
from mcp.types import (
CreateMessageResult,
CreateMessageResultWithTools,
ErrorData,
SamplingCapability,
SamplingToolsCapability,
TextContent,
ToolUseContent,
)
from tools.mcp_tool import SamplingHandler, _safe_numeric
# ---------------------------------------------------------------------------
# Helpers for sampling tests
# ---------------------------------------------------------------------------
def _make_sampling_params(
messages=None,
max_tokens=100,
system_prompt=None,
model_preferences=None,
temperature=None,
stop_sequences=None,
tools=None,
tool_choice=None,
):
"""Create a fake CreateMessageRequestParams using SimpleNamespace.
Each message must have a ``content_as_list`` attribute that mirrors
the SDK helper so that ``_convert_messages`` works correctly.
"""
if messages is None:
content = SimpleNamespace(text="Hello")
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
messages = [msg]
params = SimpleNamespace(
messages=messages,
maxTokens=max_tokens,
modelPreferences=model_preferences,
temperature=temperature,
stopSequences=stop_sequences,
tools=tools,
toolChoice=tool_choice,
)
if system_prompt is not None:
params.systemPrompt = system_prompt
return params
def _make_llm_response(
content="LLM response",
model="test-model",
finish_reason="stop",
tool_calls=None,
):
"""Create a fake OpenAI chat completion response (text)."""
message = SimpleNamespace(content=content, tool_calls=tool_calls)
choice = SimpleNamespace(
finish_reason=finish_reason,
message=message,
)
usage = SimpleNamespace(total_tokens=42)
return SimpleNamespace(choices=[choice], model=model, usage=usage)
def _make_llm_tool_response(tool_calls_data=None, model="test-model"):
"""Create a fake response with tool_calls.
``tool_calls_data``: list of (id, name, arguments_json) tuples.
"""
if tool_calls_data is None:
tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')]
tc_list = [
SimpleNamespace(
id=tc_id,
function=SimpleNamespace(name=name, arguments=args),
)
for tc_id, name, args in tool_calls_data
]
return _make_llm_response(
content=None,
model=model,
finish_reason="tool_calls",
tool_calls=tc_list,
)
# ---------------------------------------------------------------------------
# 1. _safe_numeric helper
# ---------------------------------------------------------------------------
class TestSafeNumeric:
def test_int_passthrough(self):
assert _safe_numeric(10, 5, int) == 10
def test_string_coercion(self):
assert _safe_numeric("20", 5, int) == 20
def test_none_returns_default(self):
assert _safe_numeric(None, 7, int) == 7
def test_inf_returns_default(self):
assert _safe_numeric(float("inf"), 3.0, float) == 3.0
def test_nan_returns_default(self):
assert _safe_numeric(float("nan"), 4.0, float) == 4.0
def test_below_minimum_clamps(self):
assert _safe_numeric(-5, 10, int, minimum=1) == 1
def test_minimum_zero_allowed(self):
assert _safe_numeric(0, 10, int, minimum=0) == 0
def test_non_numeric_string_returns_default(self):
assert _safe_numeric("abc", 42, int) == 42
def test_float_coercion(self):
assert _safe_numeric("3.5", 1.0, float) == 3.5
# ---------------------------------------------------------------------------
# 2. SamplingHandler initialization and config parsing
# ---------------------------------------------------------------------------
class TestSamplingHandlerInit:
def test_defaults(self):
h = SamplingHandler("srv", {})
assert h.server_name == "srv"
assert h.max_rpm == 10
assert h.timeout == 30
assert h.max_tokens_cap == 4096
assert h.max_tool_rounds == 5
assert h.model_override is None
assert h.allowed_models == []
assert h.metrics == {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
def test_custom_config(self):
cfg = {
"max_rpm": 20,
"timeout": 60,
"max_tokens_cap": 2048,
"max_tool_rounds": 3,
"model": "gpt-4o",
"allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
"log_level": "debug",
}
h = SamplingHandler("custom", cfg)
assert h.max_rpm == 20
assert h.timeout == 60.0
assert h.max_tokens_cap == 2048
assert h.max_tool_rounds == 3
assert h.model_override == "gpt-4o"
assert h.allowed_models == ["gpt-4o", "gpt-3.5-turbo"]
def test_string_numeric_config_values(self):
"""YAML sometimes delivers numeric values as strings."""
cfg = {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"}
h = SamplingHandler("s", cfg)
assert h.max_rpm == 15
assert h.timeout == 45.5
assert h.max_tokens_cap == 1024
# ---------------------------------------------------------------------------
# 3. Rate limiting
# ---------------------------------------------------------------------------
class TestRateLimit:
def setup_method(self):
self.handler = SamplingHandler("rl", {"max_rpm": 3})
def test_allows_under_limit(self):
assert self.handler._check_rate_limit() is True
assert self.handler._check_rate_limit() is True
assert self.handler._check_rate_limit() is True
def test_rejects_over_limit(self):
for _ in range(3):
self.handler._check_rate_limit()
assert self.handler._check_rate_limit() is False
def test_window_expiry(self):
"""Old timestamps should be purged from the sliding window."""
for _ in range(3):
self.handler._check_rate_limit()
# Simulate timestamps from 61 seconds ago
self.handler._rate_timestamps[:] = [time.time() - 61] * 3
assert self.handler._check_rate_limit() is True
# ---------------------------------------------------------------------------
# 4. Model resolution
# ---------------------------------------------------------------------------
class TestResolveModel:
def setup_method(self):
self.handler = SamplingHandler("mr", {})
def test_no_preference_no_override(self):
assert self.handler._resolve_model(None) is None
def test_config_override_wins(self):
self.handler.model_override = "override-model"
prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
assert self.handler._resolve_model(prefs) == "override-model"
def test_hint_used_when_no_override(self):
prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
assert self.handler._resolve_model(prefs) == "hint-model"
def test_empty_hints(self):
prefs = SimpleNamespace(hints=[])
assert self.handler._resolve_model(prefs) is None
def test_hint_without_name(self):
prefs = SimpleNamespace(hints=[SimpleNamespace(name=None)])
assert self.handler._resolve_model(prefs) is None
# ---------------------------------------------------------------------------
# 5. Message conversion
# ---------------------------------------------------------------------------
class TestConvertMessages:
def setup_method(self):
self.handler = SamplingHandler("mc", {})
def test_single_text_message(self):
content = SimpleNamespace(text="Hello world")
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0] == {"role": "user", "content": "Hello world"}
def test_image_message(self):
text_block = SimpleNamespace(text="Look at this")
img_block = SimpleNamespace(data="abc123", mimeType="image/png")
msg = SimpleNamespace(
role="user",
content=[text_block, img_block],
content_as_list=[text_block, img_block],
)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
parts = result[0]["content"]
assert len(parts) == 2
assert parts[0] == {"type": "text", "text": "Look at this"}
assert parts[1]["type"] == "image_url"
assert "data:image/png;base64,abc123" in parts[1]["image_url"]["url"]
def test_tool_result_message(self):
inner = SimpleNamespace(text="42 degrees")
tr_block = SimpleNamespace(toolUseId="call_1", content=[inner])
msg = SimpleNamespace(
role="user",
content=[tr_block],
content_as_list=[tr_block],
)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0]["role"] == "tool"
assert result[0]["tool_call_id"] == "call_1"
assert result[0]["content"] == "42 degrees"
def test_tool_use_message(self):
tu_block = SimpleNamespace(
id="call_2", name="get_weather", input={"city": "London"}
)
msg = SimpleNamespace(
role="assistant",
content=[tu_block],
content_as_list=[tu_block],
)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0]["role"] == "assistant"
assert len(result[0]["tool_calls"]) == 1
assert result[0]["tool_calls"][0]["function"]["name"] == "get_weather"
assert json.loads(result[0]["tool_calls"][0]["function"]["arguments"]) == {"city": "London"}
def test_mixed_text_and_tool_use(self):
"""Assistant message with both text and tool_calls."""
text_block = SimpleNamespace(text="Let me check the weather")
tu_block = SimpleNamespace(
id="call_3", name="get_weather", input={"city": "Paris"}
)
msg = SimpleNamespace(
role="assistant",
content=[text_block, tu_block],
content_as_list=[text_block, tu_block],
)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0]["content"] == "Let me check the weather"
assert len(result[0]["tool_calls"]) == 1
def test_fallback_without_content_as_list(self):
"""When content_as_list is absent, falls back to content."""
content = SimpleNamespace(text="Fallback text")
msg = SimpleNamespace(role="user", content=content)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0]["content"] == "Fallback text"
# ---------------------------------------------------------------------------
# 6. Text-only sampling callback (full flow)
# ---------------------------------------------------------------------------
class TestSamplingCallbackText:
def setup_method(self):
self.handler = SamplingHandler("txt", {})
def test_text_response(self):
"""Full flow: text response returns CreateMessageResult."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response(
content="Hello from LLM"
)
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
assert isinstance(result, CreateMessageResult)
assert isinstance(result.content, TextContent)
assert result.content.text == "Hello from LLM"
assert result.model == "test-model"
assert result.role == "assistant"
assert result.stopReason == "endTurn"
def test_system_prompt_prepended(self):
"""System prompt is inserted as the first message."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params(system_prompt="Be helpful")
asyncio.run(self.handler(None, params))
call_args = fake_client.chat.completions.create.call_args
messages = call_args.kwargs["messages"]
assert messages[0] == {"role": "system", "content": "Be helpful"}
def test_length_stop_reason(self):
"""finish_reason='length' maps to stopReason='maxTokens'."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response(
finish_reason="length"
)
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
assert isinstance(result, CreateMessageResult)
assert result.stopReason == "maxTokens"
# ---------------------------------------------------------------------------
# 7. Tool use sampling callback
# ---------------------------------------------------------------------------
class TestSamplingCallbackToolUse:
def setup_method(self):
self.handler = SamplingHandler("tu", {})
def test_tool_use_response(self):
"""LLM tool_calls response returns CreateMessageResultWithTools."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
assert isinstance(result, CreateMessageResultWithTools)
assert result.stopReason == "toolUse"
assert result.model == "test-model"
assert len(result.content) == 1
tc = result.content[0]
assert isinstance(tc, ToolUseContent)
assert tc.name == "get_weather"
assert tc.id == "call_1"
assert tc.input == {"city": "London"}
def test_multiple_tool_calls(self):
"""Multiple tool_calls in a single response."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response(
tool_calls_data=[
("call_a", "func_a", '{"x": 1}'),
("call_b", "func_b", '{"y": 2}'),
]
)
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(self.handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResultWithTools)
assert len(result.content) == 2
assert result.content[0].name == "func_a"
assert result.content[1].name == "func_b"
# ---------------------------------------------------------------------------
# 8. Tool loop governance
# ---------------------------------------------------------------------------
class TestToolLoopGovernance:
def test_max_tool_rounds_enforcement(self):
"""After max_tool_rounds consecutive tool responses, an error is returned."""
handler = SamplingHandler("tl", {"max_tool_rounds": 2})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params()
# Round 1, 2: allowed
r1 = asyncio.run(handler(None, params))
assert isinstance(r1, CreateMessageResultWithTools)
r2 = asyncio.run(handler(None, params))
assert isinstance(r2, CreateMessageResultWithTools)
# Round 3: exceeds limit
r3 = asyncio.run(handler(None, params))
assert isinstance(r3, ErrorData)
assert "Tool loop limit exceeded" in r3.message
def test_text_response_resets_counter(self):
"""A text response resets the tool loop counter."""
handler = SamplingHandler("tl2", {"max_tool_rounds": 1})
fake_client = MagicMock()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
# Tool response (round 1 of 1 allowed)
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
r1 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r1, CreateMessageResultWithTools)
# Text response resets counter
fake_client.chat.completions.create.return_value = _make_llm_response()
r2 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r2, CreateMessageResult)
# Tool response again (should succeed since counter was reset)
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
r3 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r3, CreateMessageResultWithTools)
def test_max_tool_rounds_zero_disables(self):
"""max_tool_rounds=0 means tool loops are disabled entirely."""
handler = SamplingHandler("tl3", {"max_tool_rounds": 0})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "Tool loops disabled" in result.message
# ---------------------------------------------------------------------------
# 9. Error paths: rate limit, timeout, no provider
# ---------------------------------------------------------------------------
class TestSamplingErrors:
def test_rate_limit_error(self):
handler = SamplingHandler("rle", {"max_rpm": 1})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
# First call succeeds
r1 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r1, CreateMessageResult)
# Second call is rate limited
r2 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r2, ErrorData)
assert "rate limit" in r2.message.lower()
assert handler.metrics["errors"] == 1
def test_timeout_error(self):
handler = SamplingHandler("to", {"timeout": 0.05})
fake_client = MagicMock()
def slow_call(**kwargs):
import threading
# Use an event to ensure the thread truly blocks long enough
evt = threading.Event()
evt.wait(5) # blocks for up to 5 seconds (cancelled by timeout)
return _make_llm_response()
fake_client.chat.completions.create.side_effect = slow_call
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "timed out" in result.message.lower()
assert handler.metrics["errors"] == 1
def test_no_provider_error(self):
handler = SamplingHandler("np", {})
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(None, None),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "No LLM provider" in result.message
assert handler.metrics["errors"] == 1
# ---------------------------------------------------------------------------
# 10. Model whitelist
# ---------------------------------------------------------------------------
class TestModelWhitelist:
def test_allowed_model_passes(self):
handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "test-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResult)
def test_disallowed_model_rejected(self):
handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"]})
fake_client = MagicMock()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "gpt-3.5-turbo"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "not allowed" in result.message
assert handler.metrics["errors"] == 1
def test_empty_whitelist_allows_all(self):
handler = SamplingHandler("wl3", {"allowed_models": []})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "any-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResult)
# ---------------------------------------------------------------------------
# 11. Malformed tool_call arguments
# ---------------------------------------------------------------------------
class TestMalformedToolCallArgs:
def test_invalid_json_wrapped_as_raw(self):
"""Malformed JSON arguments get wrapped in {"_raw": ...}."""
handler = SamplingHandler("mf", {})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response(
tool_calls_data=[("call_x", "some_tool", "not valid json {{{")]
)
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResultWithTools)
tc = result.content[0]
assert isinstance(tc, ToolUseContent)
assert tc.input == {"_raw": "not valid json {{{"}
def test_dict_args_pass_through(self):
"""When arguments are already a dict, they pass through directly."""
handler = SamplingHandler("mf2", {})
# Build a tool call where arguments is already a dict
tc_obj = SimpleNamespace(
id="call_d",
function=SimpleNamespace(name="do_stuff", arguments={"key": "val"}),
)
message = SimpleNamespace(content=None, tool_calls=[tc_obj])
choice = SimpleNamespace(finish_reason="tool_calls", message=message)
usage = SimpleNamespace(total_tokens=10)
response = SimpleNamespace(choices=[choice], model="m", usage=usage)
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = response
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResultWithTools)
assert result.content[0].input == {"key": "val"}
# ---------------------------------------------------------------------------
# 12. Metrics tracking
# ---------------------------------------------------------------------------
class TestMetricsTracking:
def test_request_and_token_metrics(self):
handler = SamplingHandler("met", {})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
asyncio.run(handler(None, _make_sampling_params()))
assert handler.metrics["requests"] == 1
assert handler.metrics["tokens_used"] == 42
assert handler.metrics["errors"] == 0
def test_tool_use_count_metric(self):
handler = SamplingHandler("met2", {})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
asyncio.run(handler(None, _make_sampling_params()))
assert handler.metrics["tool_use_count"] == 1
assert handler.metrics["requests"] == 1
def test_error_metric_incremented(self):
handler = SamplingHandler("met3", {})
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(None, None),
):
asyncio.run(handler(None, _make_sampling_params()))
assert handler.metrics["errors"] == 1
assert handler.metrics["requests"] == 0
# ---------------------------------------------------------------------------
# 13. session_kwargs()
# ---------------------------------------------------------------------------
class TestSessionKwargs:
def test_returns_correct_keys(self):
handler = SamplingHandler("sk", {})
kwargs = handler.session_kwargs()
assert "sampling_callback" in kwargs
assert "sampling_capabilities" in kwargs
assert kwargs["sampling_callback"] is handler
def test_sampling_capabilities_type(self):
handler = SamplingHandler("sk2", {})
kwargs = handler.session_kwargs()
cap = kwargs["sampling_capabilities"]
assert isinstance(cap, SamplingCapability)
assert isinstance(cap.tools, SamplingToolsCapability)
# ---------------------------------------------------------------------------
# 14. MCPServerTask integration
# ---------------------------------------------------------------------------
class TestMCPServerTaskSamplingIntegration:
def test_sampling_handler_created_when_enabled(self):
"""MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES
server = MCPServerTask("int_test")
config = {
"command": "fake",
"sampling": {"enabled": True, "max_rpm": 5},
}
# We only need to test the setup logic, not the actual connection.
# Calling run() would attempt a real connection, so we test the
# sampling setup portion directly.
server._config = config
sampling_config = config.get("sampling", {})
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
server._sampling = SamplingHandler(server.name, sampling_config)
else:
server._sampling = None
assert server._sampling is not None
assert isinstance(server._sampling, SamplingHandler)
assert server._sampling.server_name == "int_test"
assert server._sampling.max_rpm == 5
def test_sampling_handler_none_when_disabled(self):
"""MCPServerTask._sampling is None when sampling is disabled."""
from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES
server = MCPServerTask("int_test2")
config = {
"command": "fake",
"sampling": {"enabled": False},
}
server._config = config
sampling_config = config.get("sampling", {})
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
server._sampling = SamplingHandler(server.name, sampling_config)
else:
server._sampling = None
assert server._sampling is None
def test_session_kwargs_used_in_stdio(self):
"""When sampling is set, session_kwargs() are passed to ClientSession."""
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("sk_test")
server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
kwargs = server._sampling.session_kwargs()
assert "sampling_callback" in kwargs
assert "sampling_capabilities" in kwargs

View file

@ -29,6 +29,18 @@ Example config::
headers:
Authorization: "Bearer sk-..."
timeout: 180
analysis:
command: "npx"
args: ["-y", "analysis-server"]
sampling: # server-initiated LLM requests
enabled: true # default: true
model: "gemini-3-flash" # override model (optional)
max_tokens_cap: 4096 # max tokens per request
timeout: 30 # LLM call timeout (seconds)
max_rpm: 10 # max requests per minute
allowed_models: [] # model whitelist (empty = all)
max_tool_rounds: 5 # tool loop limit (0 = disable)
log_level: "info" # audit verbosity
Features:
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
@ -37,6 +49,8 @@ Features:
- Credential stripping in error messages returned to the LLM
- Configurable per-server timeouts for tool calls and connections
- Thread-safe architecture with dedicated background event loop
- Sampling support: MCP servers can request LLM completions via
sampling/createMessage (text and tool-use responses)
Architecture:
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
@ -58,9 +72,11 @@ Thread safety:
import asyncio
import json
import logging
import math
import os
import re
import threading
import time
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@ -71,6 +87,7 @@ logger = logging.getLogger(__name__)
_MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False
_MCP_SAMPLING_TYPES = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@ -80,6 +97,20 @@ try:
_MCP_HTTP_AVAILABLE = True
except ImportError:
_MCP_HTTP_AVAILABLE = False
# Sampling types -- separated so older SDK versions don't break MCP support
try:
from mcp.types import (
CreateMessageResult,
CreateMessageResultWithTools,
ErrorData,
SamplingCapability,
SamplingToolsCapability,
TextContent,
ToolUseContent,
)
_MCP_SAMPLING_TYPES = True
except ImportError:
logger.debug("MCP sampling types not available -- sampling disabled")
except ImportError:
logger.debug("mcp package not installed -- MCP tool support disabled")
@ -145,6 +176,386 @@ def _sanitize_error(text: str) -> str:
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
# ---------------------------------------------------------------------------
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------
def _safe_numeric(value, default, coerce=int, minimum=1):
"""Coerce a config value to a numeric type, returning *default* on failure.
Handles string values from YAML (e.g. ``"10"`` instead of ``10``),
non-finite floats, and values below *minimum*.
"""
try:
result = coerce(value)
if isinstance(result, float) and not math.isfinite(result):
return default
return max(result, minimum)
except (TypeError, ValueError, OverflowError):
return default
class SamplingHandler:
"""Handles sampling/createMessage requests for a single MCP server.
Each MCPServerTask that has sampling enabled creates one SamplingHandler.
The handler is callable and passed directly to ``ClientSession`` as
the ``sampling_callback``. All state (rate-limit timestamps, metrics,
tool-loop counters) lives on the instance -- no module-level globals.
The callback is async and runs on the MCP background event loop. The
sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
it doesn't block the event loop.
"""
_STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}
def __init__(self, server_name: str, config: dict):
self.server_name = server_name
self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
self.max_tool_rounds = _safe_numeric(
config.get("max_tool_rounds", 5), 5, int, minimum=0,
)
self.model_override = config.get("model")
self.allowed_models = config.get("allowed_models", [])
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
self.audit_level = _log_levels.get(
str(config.get("log_level", "info")).lower(), logging.INFO,
)
# Per-instance state
self._rate_timestamps: List[float] = []
self._tool_loop_count = 0
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
# -- Rate limiting -------------------------------------------------------
def _check_rate_limit(self) -> bool:
"""Sliding-window rate limiter. Returns True if request is allowed."""
now = time.time()
window = now - 60
self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window]
if len(self._rate_timestamps) >= self.max_rpm:
return False
self._rate_timestamps.append(now)
return True
# -- Model resolution ----------------------------------------------------
def _resolve_model(self, preferences) -> Optional[str]:
"""Config override > server hint > None (use default)."""
if self.model_override:
return self.model_override
if preferences and hasattr(preferences, "hints") and preferences.hints:
for hint in preferences.hints:
if hasattr(hint, "name") and hint.name:
return hint.name
return None
# -- Message conversion --------------------------------------------------
@staticmethod
def _extract_tool_result_text(block) -> str:
"""Extract text from a ToolResultContent block."""
if not hasattr(block, "content") or block.content is None:
return ""
items = block.content if isinstance(block.content, list) else [block.content]
return "\n".join(item.text for item in items if hasattr(item, "text"))
def _convert_messages(self, params) -> List[dict]:
"""Convert MCP SamplingMessages to OpenAI format.
Uses ``msg.content_as_list`` (SDK helper) so single-block and
list-of-blocks are handled uniformly. Dispatches per block type
with ``isinstance`` on real SDK types when available, falling back
to duck-typing via ``hasattr`` for compatibility.
"""
messages: List[dict] = []
for msg in params.messages:
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
msg.content if isinstance(msg.content, list) else [msg.content]
)
# Separate blocks by kind
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
# Emit tool result messages (role: tool)
for tr in tool_results:
messages.append({
"role": "tool",
"tool_call_id": tr.toolUseId,
"content": self._extract_tool_result_text(tr),
})
# Emit assistant tool_calls message
if tool_uses:
tc_list = []
for tu in tool_uses:
tc_list.append({
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
"type": "function",
"function": {
"name": tu.name,
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
},
})
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
# Include any accompanying text
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
if text_parts:
msg_dict["content"] = "\n".join(text_parts)
messages.append(msg_dict)
elif content_blocks:
# Pure text/image content
if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
messages.append({"role": msg.role, "content": content_blocks[0].text})
else:
parts = []
for block in content_blocks:
if hasattr(block, "text"):
parts.append({"type": "text", "text": block.text})
elif hasattr(block, "data") and hasattr(block, "mimeType"):
parts.append({
"type": "image_url",
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
})
else:
logger.warning(
"Unsupported sampling content block type: %s (skipped)",
type(block).__name__,
)
if parts:
messages.append({"role": msg.role, "content": parts})
return messages
# -- Error helper --------------------------------------------------------
@staticmethod
def _error(message: str, code: int = -1):
"""Return ErrorData (MCP spec) or raise as fallback."""
if _MCP_SAMPLING_TYPES:
return ErrorData(code=code, message=message)
raise Exception(message)
# -- Response building ---------------------------------------------------
def _build_tool_use_result(self, choice, response):
"""Build a CreateMessageResultWithTools from an LLM tool_calls response."""
self.metrics["tool_use_count"] += 1
# Tool loop governance
if self.max_tool_rounds == 0:
self._tool_loop_count = 0
return self._error(
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
)
self._tool_loop_count += 1
if self._tool_loop_count > self.max_tool_rounds:
self._tool_loop_count = 0
return self._error(
f"Tool loop limit exceeded for server '{self.server_name}' "
f"(max {self.max_tool_rounds} rounds)"
)
content_blocks = []
for tc in choice.message.tool_calls:
args = tc.function.arguments
if isinstance(args, str):
try:
parsed = json.loads(args)
except (json.JSONDecodeError, ValueError):
logger.warning(
"MCP server '%s': malformed tool_calls arguments "
"from LLM (wrapping as raw): %.100s",
self.server_name, args,
)
parsed = {"_raw": args}
else:
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
content_blocks.append(ToolUseContent(
type="tool_use",
id=tc.id,
name=tc.function.name,
input=parsed,
))
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
self.server_name, response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
len(content_blocks),
)
return CreateMessageResultWithTools(
role="assistant",
content=content_blocks,
model=response.model,
stopReason="toolUse",
)
def _build_text_result(self, choice, response):
"""Build a CreateMessageResult from a normal text response."""
self._tool_loop_count = 0 # reset on text response
response_text = choice.message.content or ""
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s",
self.server_name, response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
)
return CreateMessageResult(
role="assistant",
content=TextContent(type="text", text=_sanitize_error(response_text)),
model=response.model,
stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
)
# -- Session kwargs helper -----------------------------------------------
def session_kwargs(self) -> dict:
"""Return kwargs to pass to ClientSession for sampling support."""
return {
"sampling_callback": self,
"sampling_capabilities": SamplingCapability(
tools=SamplingToolsCapability(),
),
}
# -- Main callback -------------------------------------------------------
async def __call__(self, context, params):
"""Sampling callback invoked by the MCP SDK.
Conforms to ``SamplingFnT`` protocol. Returns
``CreateMessageResult``, ``CreateMessageResultWithTools``, or
``ErrorData``.
"""
# Rate limit
if not self._check_rate_limit():
logger.warning(
"MCP server '%s' sampling rate limit exceeded (%d/min)",
self.server_name, self.max_rpm,
)
self.metrics["errors"] += 1
return self._error(
f"Sampling rate limit exceeded for server '{self.server_name}' "
f"({self.max_rpm} requests/minute)"
)
# Resolve model
model = self._resolve_model(getattr(params, "modelPreferences", None))
# Get auxiliary LLM client
from agent.auxiliary_client import get_text_auxiliary_client
client, default_model = get_text_auxiliary_client()
if client is None:
self.metrics["errors"] += 1
return self._error("No LLM provider available for sampling")
resolved_model = model or default_model
# Model whitelist check
if self.allowed_models and resolved_model not in self.allowed_models:
logger.warning(
"MCP server '%s' requested model '%s' not in allowed_models",
self.server_name, resolved_model,
)
self.metrics["errors"] += 1
return self._error(
f"Model '{resolved_model}' not allowed for server "
f"'{self.server_name}'. Allowed: {', '.join(self.allowed_models)}"
)
# Convert messages
messages = self._convert_messages(params)
if hasattr(params, "systemPrompt") and params.systemPrompt:
messages.insert(0, {"role": "system", "content": params.systemPrompt})
# Build LLM call kwargs
max_tokens = min(params.maxTokens, self.max_tokens_cap)
call_kwargs: dict = {
"model": resolved_model,
"messages": messages,
"max_tokens": max_tokens,
}
if hasattr(params, "temperature") and params.temperature is not None:
call_kwargs["temperature"] = params.temperature
if stop := getattr(params, "stopSequences", None):
call_kwargs["stop"] = stop
# Forward server-provided tools
server_tools = getattr(params, "tools", None)
if server_tools:
call_kwargs["tools"] = [
{
"type": "function",
"function": {
"name": getattr(t, "name", ""),
"description": getattr(t, "description", "") or "",
"parameters": getattr(t, "inputSchema", {}) or {},
},
}
for t in server_tools
]
if tool_choice := getattr(params, "toolChoice", None):
mode = getattr(tool_choice, "mode", "auto")
call_kwargs["tool_choice"] = {"auto": "auto", "required": "required", "none": "none"}.get(mode, "auto")
logger.log(
self.audit_level,
"MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
self.server_name, resolved_model, max_tokens, len(messages),
)
# Offload sync LLM call to thread (non-blocking)
def _sync_call():
return client.chat.completions.create(**call_kwargs)
try:
response = await asyncio.wait_for(
asyncio.to_thread(_sync_call), timeout=self.timeout,
)
except asyncio.TimeoutError:
self.metrics["errors"] += 1
return self._error(
f"Sampling LLM call timed out after {self.timeout}s "
f"for server '{self.server_name}'"
)
except Exception as exc:
self.metrics["errors"] += 1
return self._error(
f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
)
# Track metrics
choice = response.choices[0]
self.metrics["requests"] += 1
total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
if isinstance(total_tokens, int):
self.metrics["tokens_used"] += total_tokens
# Dispatch based on response type
if (
choice.finish_reason == "tool_calls"
and hasattr(choice.message, "tool_calls")
and choice.message.tool_calls
):
return self._build_tool_use_result(choice, response)
return self._build_text_result(choice, response)
# ---------------------------------------------------------------------------
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------
@ -162,6 +573,7 @@ class MCPServerTask:
__slots__ = (
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
"_sampling",
)
def __init__(self, name: str):
@ -174,6 +586,7 @@ class MCPServerTask:
self._tools: list = []
self._error: Optional[Exception] = None
self._config: dict = {}
self._sampling: Optional[SamplingHandler] = None
def _is_http(self) -> bool:
"""Check if this server uses HTTP transport."""
@ -197,8 +610,9 @@ class MCPServerTask:
env=safe_env if safe_env else None,
)
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
await session.initialize()
self.session = session
await self._discover_tools()
@ -218,12 +632,13 @@ class MCPServerTask:
headers = config.get("headers")
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
async with streamablehttp_client(
url,
headers=headers,
timeout=float(connect_timeout),
) as (read_stream, write_stream, _get_session_id):
async with ClientSession(read_stream, write_stream) as session:
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
await session.initialize()
self.session = session
await self._discover_tools()
@ -250,6 +665,13 @@ class MCPServerTask:
self._config = config
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
# Set up sampling handler if enabled and SDK types are available
sampling_config = config.get("sampling", {})
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
self._sampling = SamplingHandler(self.name, sampling_config)
else:
self._sampling = None
# Validate: warn if both url and command are present
if "url" in config and "command" in config:
logger.warning(
@ -975,12 +1397,15 @@ def get_mcp_status() -> List[dict]:
transport = "http" if "url" in cfg else "stdio"
server = active_servers.get(name)
if server and server.session is not None:
result.append({
entry = {
"name": name,
"transport": transport,
"tools": len(server._tools),
"connected": True,
})
}
if server._sampling:
entry["sampling"] = dict(server._sampling.metrics)
result.append(entry)
else:
result.append({
"name": name,

View file

@ -271,3 +271,62 @@ You can reload MCP servers without restarting Hermes:
- In the CLI: the agent reconnects automatically
- In messaging: send `/reload-mcp`
## Sampling (Server-Initiated LLM Requests)
MCP's `sampling/createMessage` capability allows MCP servers to request LLM completions through the Hermes agent. This enables agent-in-the-loop workflows where servers can leverage the LLM during tool execution — for example, a database server asking the LLM to interpret query results, or a code analysis server requesting the LLM to review findings.
### How It Works
When an MCP server sends a `sampling/createMessage` request:
1. The sampling callback validates against rate limits and model whitelist
2. Resolves which model to use (config override > server hint > default)
3. Converts MCP messages to OpenAI-compatible format
4. Offloads the LLM call to a thread via `asyncio.to_thread()` (non-blocking)
5. Returns the response (text or tool use) back to the server
### Configuration
Sampling is **enabled by default** for all MCP servers. No extra setup needed — if you have an auxiliary LLM client configured, sampling works automatically.
```yaml
mcp_servers:
analysis_server:
command: "npx"
args: ["-y", "my-analysis-server"]
sampling:
enabled: true # default: true
model: "gemini-3-flash" # override model (optional)
max_tokens_cap: 4096 # max tokens per request (default: 4096)
timeout: 30 # LLM call timeout in seconds (default: 30)
max_rpm: 10 # max requests per minute (default: 10)
allowed_models: [] # model whitelist (empty = allow all)
max_tool_rounds: 5 # max consecutive tool use rounds (0 = disable)
log_level: "info" # audit verbosity: debug, info, warning
```
### Tool Use in Sampling
Servers can include `tools` and `toolChoice` in sampling requests, enabling multi-turn tool-augmented workflows within a single sampling session. The callback forwards tool definitions to the LLM, handles tool use responses with proper `ToolUseContent` types, and enforces `max_tool_rounds` to prevent infinite loops.
### Security
- **Rate limiting**: Per-server sliding window (default: 10 req/min)
- **Token cap**: Servers can't request more than `max_tokens_cap` (default: 4096)
- **Model whitelist**: `allowed_models` restricts which models a server can use
- **Tool loop limit**: `max_tool_rounds` caps consecutive tool use rounds
- **Credential stripping**: LLM responses are sanitized before returning to the server
- **Non-blocking**: LLM calls run in a separate thread via `asyncio.to_thread()`
- **Typed errors**: All failures return structured `ErrorData` per MCP spec
To disable sampling for untrusted servers:
```yaml
mcp_servers:
untrusted:
command: "npx"
args: ["-y", "untrusted-server"]
sampling:
enabled: false
```