mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(mcp): add sampling support — server-initiated LLM requests (#753)
Add MCP sampling/createMessage capability via SamplingHandler class. Text-only sampling + tool use in sampling with governance (rate limits, model whitelist, token caps, tool loop limits). Per-server audit metrics. Based on concept from PR #366 by eren-karakus0. Restructured as class-based design with bug fixes and tests using real MCP SDK types. 50 new tests, 2600 total passing.
This commit is contained in:
parent
1f0944de21
commit
654e16187e
5 changed files with 1307 additions and 4 deletions
|
|
@ -555,6 +555,21 @@ toolsets:
|
|||
# args: ["-y", "@modelcontextprotocol/server-github"]
|
||||
# env:
|
||||
# GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
|
||||
#
|
||||
# Sampling (server-initiated LLM requests) — enabled by default.
|
||||
# Per-server config under the 'sampling' key:
|
||||
# analysis:
|
||||
# command: npx
|
||||
# args: ["-y", "analysis-server"]
|
||||
# sampling:
|
||||
# enabled: true # default: true
|
||||
# model: "gemini-3-flash" # override model (optional)
|
||||
# max_tokens_cap: 4096 # max tokens per request
|
||||
# timeout: 30 # LLM call timeout (seconds)
|
||||
# max_rpm: 10 # max requests per minute
|
||||
# allowed_models: [] # model whitelist (empty = all)
|
||||
# max_tool_rounds: 5 # tool loop limit (0 = disable)
|
||||
# log_level: "info" # audit verbosity
|
||||
|
||||
# =============================================================================
|
||||
# Voice Transcription (Speech-to-Text)
|
||||
|
|
|
|||
|
|
@ -321,6 +321,32 @@ mcp_servers:
|
|||
|
||||
All tools from all servers are registered and available simultaneously. Each server's tools are prefixed with its name to avoid collisions.
|
||||
|
||||
## Sampling (Server-Initiated LLM Requests)
|
||||
|
||||
Hermes supports MCP's `sampling/createMessage` capability — MCP servers can request LLM completions through the agent during tool execution. This enables agent-in-the-loop workflows (data analysis, content generation, decision-making).
|
||||
|
||||
Sampling is **enabled by default**. Configure per server:
|
||||
|
||||
```yaml
|
||||
mcp_servers:
|
||||
my_server:
|
||||
command: "npx"
|
||||
args: ["-y", "my-mcp-server"]
|
||||
sampling:
|
||||
enabled: true # default: true
|
||||
model: "gemini-3-flash" # model override (optional)
|
||||
max_tokens_cap: 4096 # max tokens per request
|
||||
timeout: 30 # LLM call timeout (seconds)
|
||||
max_rpm: 10 # max requests per minute
|
||||
allowed_models: [] # model whitelist (empty = all)
|
||||
max_tool_rounds: 5 # tool loop limit (0 = disable)
|
||||
log_level: "info" # audit verbosity
|
||||
```
|
||||
|
||||
Servers can also include `tools` in sampling requests for multi-turn tool-augmented workflows. The `max_tool_rounds` config prevents infinite tool loops. Per-server audit metrics (requests, errors, tokens, tool use count) are tracked via `get_mcp_status()`.
|
||||
|
||||
Disable sampling for untrusted servers with `sampling: { enabled: false }`.
|
||||
|
||||
## Notes
|
||||
|
||||
- MCP tools are called synchronously from the agent's perspective but run asynchronously on a dedicated background event loop
|
||||
|
|
|
|||
|
|
@ -1489,3 +1489,781 @@ class TestUtilityToolRegistration:
|
|||
assert entry.check_fn() is False
|
||||
|
||||
_servers.pop("chk", None)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# SamplingHandler tests
|
||||
# ===========================================================================
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
SamplingToolsCapability,
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
|
||||
from tools.mcp_tool import SamplingHandler, _safe_numeric
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for sampling tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_sampling_params(
|
||||
messages=None,
|
||||
max_tokens=100,
|
||||
system_prompt=None,
|
||||
model_preferences=None,
|
||||
temperature=None,
|
||||
stop_sequences=None,
|
||||
tools=None,
|
||||
tool_choice=None,
|
||||
):
|
||||
"""Create a fake CreateMessageRequestParams using SimpleNamespace.
|
||||
|
||||
Each message must have a ``content_as_list`` attribute that mirrors
|
||||
the SDK helper so that ``_convert_messages`` works correctly.
|
||||
"""
|
||||
if messages is None:
|
||||
content = SimpleNamespace(text="Hello")
|
||||
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
|
||||
messages = [msg]
|
||||
|
||||
params = SimpleNamespace(
|
||||
messages=messages,
|
||||
maxTokens=max_tokens,
|
||||
modelPreferences=model_preferences,
|
||||
temperature=temperature,
|
||||
stopSequences=stop_sequences,
|
||||
tools=tools,
|
||||
toolChoice=tool_choice,
|
||||
)
|
||||
if system_prompt is not None:
|
||||
params.systemPrompt = system_prompt
|
||||
return params
|
||||
|
||||
|
||||
def _make_llm_response(
|
||||
content="LLM response",
|
||||
model="test-model",
|
||||
finish_reason="stop",
|
||||
tool_calls=None,
|
||||
):
|
||||
"""Create a fake OpenAI chat completion response (text)."""
|
||||
message = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
choice = SimpleNamespace(
|
||||
finish_reason=finish_reason,
|
||||
message=message,
|
||||
)
|
||||
usage = SimpleNamespace(total_tokens=42)
|
||||
return SimpleNamespace(choices=[choice], model=model, usage=usage)
|
||||
|
||||
|
||||
def _make_llm_tool_response(tool_calls_data=None, model="test-model"):
|
||||
"""Create a fake response with tool_calls.
|
||||
|
||||
``tool_calls_data``: list of (id, name, arguments_json) tuples.
|
||||
"""
|
||||
if tool_calls_data is None:
|
||||
tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')]
|
||||
|
||||
tc_list = [
|
||||
SimpleNamespace(
|
||||
id=tc_id,
|
||||
function=SimpleNamespace(name=name, arguments=args),
|
||||
)
|
||||
for tc_id, name, args in tool_calls_data
|
||||
]
|
||||
return _make_llm_response(
|
||||
content=None,
|
||||
model=model,
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=tc_list,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. _safe_numeric helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSafeNumeric:
|
||||
def test_int_passthrough(self):
|
||||
assert _safe_numeric(10, 5, int) == 10
|
||||
|
||||
def test_string_coercion(self):
|
||||
assert _safe_numeric("20", 5, int) == 20
|
||||
|
||||
def test_none_returns_default(self):
|
||||
assert _safe_numeric(None, 7, int) == 7
|
||||
|
||||
def test_inf_returns_default(self):
|
||||
assert _safe_numeric(float("inf"), 3.0, float) == 3.0
|
||||
|
||||
def test_nan_returns_default(self):
|
||||
assert _safe_numeric(float("nan"), 4.0, float) == 4.0
|
||||
|
||||
def test_below_minimum_clamps(self):
|
||||
assert _safe_numeric(-5, 10, int, minimum=1) == 1
|
||||
|
||||
def test_minimum_zero_allowed(self):
|
||||
assert _safe_numeric(0, 10, int, minimum=0) == 0
|
||||
|
||||
def test_non_numeric_string_returns_default(self):
|
||||
assert _safe_numeric("abc", 42, int) == 42
|
||||
|
||||
def test_float_coercion(self):
|
||||
assert _safe_numeric("3.5", 1.0, float) == 3.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. SamplingHandler initialization and config parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingHandlerInit:
|
||||
def test_defaults(self):
|
||||
h = SamplingHandler("srv", {})
|
||||
assert h.server_name == "srv"
|
||||
assert h.max_rpm == 10
|
||||
assert h.timeout == 30
|
||||
assert h.max_tokens_cap == 4096
|
||||
assert h.max_tool_rounds == 5
|
||||
assert h.model_override is None
|
||||
assert h.allowed_models == []
|
||||
assert h.metrics == {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
|
||||
|
||||
def test_custom_config(self):
|
||||
cfg = {
|
||||
"max_rpm": 20,
|
||||
"timeout": 60,
|
||||
"max_tokens_cap": 2048,
|
||||
"max_tool_rounds": 3,
|
||||
"model": "gpt-4o",
|
||||
"allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
|
||||
"log_level": "debug",
|
||||
}
|
||||
h = SamplingHandler("custom", cfg)
|
||||
assert h.max_rpm == 20
|
||||
assert h.timeout == 60.0
|
||||
assert h.max_tokens_cap == 2048
|
||||
assert h.max_tool_rounds == 3
|
||||
assert h.model_override == "gpt-4o"
|
||||
assert h.allowed_models == ["gpt-4o", "gpt-3.5-turbo"]
|
||||
|
||||
def test_string_numeric_config_values(self):
|
||||
"""YAML sometimes delivers numeric values as strings."""
|
||||
cfg = {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"}
|
||||
h = SamplingHandler("s", cfg)
|
||||
assert h.max_rpm == 15
|
||||
assert h.timeout == 45.5
|
||||
assert h.max_tokens_cap == 1024
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Rate limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRateLimit:
|
||||
def setup_method(self):
|
||||
self.handler = SamplingHandler("rl", {"max_rpm": 3})
|
||||
|
||||
def test_allows_under_limit(self):
|
||||
assert self.handler._check_rate_limit() is True
|
||||
assert self.handler._check_rate_limit() is True
|
||||
assert self.handler._check_rate_limit() is True
|
||||
|
||||
def test_rejects_over_limit(self):
|
||||
for _ in range(3):
|
||||
self.handler._check_rate_limit()
|
||||
assert self.handler._check_rate_limit() is False
|
||||
|
||||
def test_window_expiry(self):
|
||||
"""Old timestamps should be purged from the sliding window."""
|
||||
for _ in range(3):
|
||||
self.handler._check_rate_limit()
|
||||
# Simulate timestamps from 61 seconds ago
|
||||
self.handler._rate_timestamps[:] = [time.time() - 61] * 3
|
||||
assert self.handler._check_rate_limit() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveModel:
|
||||
def setup_method(self):
|
||||
self.handler = SamplingHandler("mr", {})
|
||||
|
||||
def test_no_preference_no_override(self):
|
||||
assert self.handler._resolve_model(None) is None
|
||||
|
||||
def test_config_override_wins(self):
|
||||
self.handler.model_override = "override-model"
|
||||
prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
|
||||
assert self.handler._resolve_model(prefs) == "override-model"
|
||||
|
||||
def test_hint_used_when_no_override(self):
|
||||
prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
|
||||
assert self.handler._resolve_model(prefs) == "hint-model"
|
||||
|
||||
def test_empty_hints(self):
|
||||
prefs = SimpleNamespace(hints=[])
|
||||
assert self.handler._resolve_model(prefs) is None
|
||||
|
||||
def test_hint_without_name(self):
|
||||
prefs = SimpleNamespace(hints=[SimpleNamespace(name=None)])
|
||||
assert self.handler._resolve_model(prefs) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Message conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConvertMessages:
|
||||
def setup_method(self):
|
||||
self.handler = SamplingHandler("mc", {})
|
||||
|
||||
def test_single_text_message(self):
|
||||
content = SimpleNamespace(text="Hello world")
|
||||
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
|
||||
params = _make_sampling_params(messages=[msg])
|
||||
result = self.handler._convert_messages(params)
|
||||
assert len(result) == 1
|
||||
assert result[0] == {"role": "user", "content": "Hello world"}
|
||||
|
||||
def test_image_message(self):
|
||||
text_block = SimpleNamespace(text="Look at this")
|
||||
img_block = SimpleNamespace(data="abc123", mimeType="image/png")
|
||||
msg = SimpleNamespace(
|
||||
role="user",
|
||||
content=[text_block, img_block],
|
||||
content_as_list=[text_block, img_block],
|
||||
)
|
||||
params = _make_sampling_params(messages=[msg])
|
||||
result = self.handler._convert_messages(params)
|
||||
assert len(result) == 1
|
||||
parts = result[0]["content"]
|
||||
assert len(parts) == 2
|
||||
assert parts[0] == {"type": "text", "text": "Look at this"}
|
||||
assert parts[1]["type"] == "image_url"
|
||||
assert "data:image/png;base64,abc123" in parts[1]["image_url"]["url"]
|
||||
|
||||
def test_tool_result_message(self):
|
||||
inner = SimpleNamespace(text="42 degrees")
|
||||
tr_block = SimpleNamespace(toolUseId="call_1", content=[inner])
|
||||
msg = SimpleNamespace(
|
||||
role="user",
|
||||
content=[tr_block],
|
||||
content_as_list=[tr_block],
|
||||
)
|
||||
params = _make_sampling_params(messages=[msg])
|
||||
result = self.handler._convert_messages(params)
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "tool"
|
||||
assert result[0]["tool_call_id"] == "call_1"
|
||||
assert result[0]["content"] == "42 degrees"
|
||||
|
||||
def test_tool_use_message(self):
|
||||
tu_block = SimpleNamespace(
|
||||
id="call_2", name="get_weather", input={"city": "London"}
|
||||
)
|
||||
msg = SimpleNamespace(
|
||||
role="assistant",
|
||||
content=[tu_block],
|
||||
content_as_list=[tu_block],
|
||||
)
|
||||
params = _make_sampling_params(messages=[msg])
|
||||
result = self.handler._convert_messages(params)
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert len(result[0]["tool_calls"]) == 1
|
||||
assert result[0]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert json.loads(result[0]["tool_calls"][0]["function"]["arguments"]) == {"city": "London"}
|
||||
|
||||
def test_mixed_text_and_tool_use(self):
|
||||
"""Assistant message with both text and tool_calls."""
|
||||
text_block = SimpleNamespace(text="Let me check the weather")
|
||||
tu_block = SimpleNamespace(
|
||||
id="call_3", name="get_weather", input={"city": "Paris"}
|
||||
)
|
||||
msg = SimpleNamespace(
|
||||
role="assistant",
|
||||
content=[text_block, tu_block],
|
||||
content_as_list=[text_block, tu_block],
|
||||
)
|
||||
params = _make_sampling_params(messages=[msg])
|
||||
result = self.handler._convert_messages(params)
|
||||
assert len(result) == 1
|
||||
assert result[0]["content"] == "Let me check the weather"
|
||||
assert len(result[0]["tool_calls"]) == 1
|
||||
|
||||
def test_fallback_without_content_as_list(self):
|
||||
"""When content_as_list is absent, falls back to content."""
|
||||
content = SimpleNamespace(text="Fallback text")
|
||||
msg = SimpleNamespace(role="user", content=content)
|
||||
params = _make_sampling_params(messages=[msg])
|
||||
result = self.handler._convert_messages(params)
|
||||
assert len(result) == 1
|
||||
assert result[0]["content"] == "Fallback text"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Text-only sampling callback (full flow)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingCallbackText:
|
||||
def setup_method(self):
|
||||
self.handler = SamplingHandler("txt", {})
|
||||
|
||||
def test_text_response(self):
|
||||
"""Full flow: text response returns CreateMessageResult."""
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response(
|
||||
content="Hello from LLM"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
params = _make_sampling_params()
|
||||
result = asyncio.run(self.handler(None, params))
|
||||
|
||||
assert isinstance(result, CreateMessageResult)
|
||||
assert isinstance(result.content, TextContent)
|
||||
assert result.content.text == "Hello from LLM"
|
||||
assert result.model == "test-model"
|
||||
assert result.role == "assistant"
|
||||
assert result.stopReason == "endTurn"
|
||||
|
||||
def test_system_prompt_prepended(self):
|
||||
"""System prompt is inserted as the first message."""
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
params = _make_sampling_params(system_prompt="Be helpful")
|
||||
asyncio.run(self.handler(None, params))
|
||||
|
||||
call_args = fake_client.chat.completions.create.call_args
|
||||
messages = call_args.kwargs["messages"]
|
||||
assert messages[0] == {"role": "system", "content": "Be helpful"}
|
||||
|
||||
def test_length_stop_reason(self):
|
||||
"""finish_reason='length' maps to stopReason='maxTokens'."""
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response(
|
||||
finish_reason="length"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
params = _make_sampling_params()
|
||||
result = asyncio.run(self.handler(None, params))
|
||||
|
||||
assert isinstance(result, CreateMessageResult)
|
||||
assert result.stopReason == "maxTokens"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Tool use sampling callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingCallbackToolUse:
|
||||
def setup_method(self):
|
||||
self.handler = SamplingHandler("tu", {})
|
||||
|
||||
def test_tool_use_response(self):
|
||||
"""LLM tool_calls response returns CreateMessageResultWithTools."""
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
params = _make_sampling_params()
|
||||
result = asyncio.run(self.handler(None, params))
|
||||
|
||||
assert isinstance(result, CreateMessageResultWithTools)
|
||||
assert result.stopReason == "toolUse"
|
||||
assert result.model == "test-model"
|
||||
assert len(result.content) == 1
|
||||
tc = result.content[0]
|
||||
assert isinstance(tc, ToolUseContent)
|
||||
assert tc.name == "get_weather"
|
||||
assert tc.id == "call_1"
|
||||
assert tc.input == {"city": "London"}
|
||||
|
||||
def test_multiple_tool_calls(self):
|
||||
"""Multiple tool_calls in a single response."""
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_tool_response(
|
||||
tool_calls_data=[
|
||||
("call_a", "func_a", '{"x": 1}'),
|
||||
("call_b", "func_b", '{"y": 2}'),
|
||||
]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
result = asyncio.run(self.handler(None, _make_sampling_params()))
|
||||
|
||||
assert isinstance(result, CreateMessageResultWithTools)
|
||||
assert len(result.content) == 2
|
||||
assert result.content[0].name == "func_a"
|
||||
assert result.content[1].name == "func_b"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Tool loop governance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolLoopGovernance:
|
||||
def test_max_tool_rounds_enforcement(self):
|
||||
"""After max_tool_rounds consecutive tool responses, an error is returned."""
|
||||
handler = SamplingHandler("tl", {"max_tool_rounds": 2})
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
params = _make_sampling_params()
|
||||
# Round 1, 2: allowed
|
||||
r1 = asyncio.run(handler(None, params))
|
||||
assert isinstance(r1, CreateMessageResultWithTools)
|
||||
r2 = asyncio.run(handler(None, params))
|
||||
assert isinstance(r2, CreateMessageResultWithTools)
|
||||
# Round 3: exceeds limit
|
||||
r3 = asyncio.run(handler(None, params))
|
||||
assert isinstance(r3, ErrorData)
|
||||
assert "Tool loop limit exceeded" in r3.message
|
||||
|
||||
def test_text_response_resets_counter(self):
|
||||
"""A text response resets the tool loop counter."""
|
||||
handler = SamplingHandler("tl2", {"max_tool_rounds": 1})
|
||||
fake_client = MagicMock()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
# Tool response (round 1 of 1 allowed)
|
||||
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
|
||||
r1 = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(r1, CreateMessageResultWithTools)
|
||||
|
||||
# Text response resets counter
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response()
|
||||
r2 = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(r2, CreateMessageResult)
|
||||
|
||||
# Tool response again (should succeed since counter was reset)
|
||||
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
|
||||
r3 = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(r3, CreateMessageResultWithTools)
|
||||
|
||||
def test_max_tool_rounds_zero_disables(self):
|
||||
"""max_tool_rounds=0 means tool loops are disabled entirely."""
|
||||
handler = SamplingHandler("tl3", {"max_tool_rounds": 0})
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
result = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(result, ErrorData)
|
||||
assert "Tool loops disabled" in result.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. Error paths: rate limit, timeout, no provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingErrors:
|
||||
def test_rate_limit_error(self):
|
||||
handler = SamplingHandler("rle", {"max_rpm": 1})
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
# First call succeeds
|
||||
r1 = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(r1, CreateMessageResult)
|
||||
# Second call is rate limited
|
||||
r2 = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(r2, ErrorData)
|
||||
assert "rate limit" in r2.message.lower()
|
||||
assert handler.metrics["errors"] == 1
|
||||
|
||||
def test_timeout_error(self):
|
||||
handler = SamplingHandler("to", {"timeout": 0.05})
|
||||
fake_client = MagicMock()
|
||||
|
||||
def slow_call(**kwargs):
|
||||
import threading
|
||||
# Use an event to ensure the thread truly blocks long enough
|
||||
evt = threading.Event()
|
||||
evt.wait(5) # blocks for up to 5 seconds (cancelled by timeout)
|
||||
return _make_llm_response()
|
||||
|
||||
fake_client.chat.completions.create.side_effect = slow_call
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
result = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(result, ErrorData)
|
||||
assert "timed out" in result.message.lower()
|
||||
assert handler.metrics["errors"] == 1
|
||||
|
||||
def test_no_provider_error(self):
|
||||
handler = SamplingHandler("np", {})
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(None, None),
|
||||
):
|
||||
result = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(result, ErrorData)
|
||||
assert "No LLM provider" in result.message
|
||||
assert handler.metrics["errors"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. Model whitelist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModelWhitelist:
|
||||
def test_allowed_model_passes(self):
|
||||
handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]})
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "test-model"),
|
||||
):
|
||||
result = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(result, CreateMessageResult)
|
||||
|
||||
def test_disallowed_model_rejected(self):
|
||||
handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"]})
|
||||
fake_client = MagicMock()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "gpt-3.5-turbo"),
|
||||
):
|
||||
result = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(result, ErrorData)
|
||||
assert "not allowed" in result.message
|
||||
assert handler.metrics["errors"] == 1
|
||||
|
||||
def test_empty_whitelist_allows_all(self):
|
||||
handler = SamplingHandler("wl3", {"allowed_models": []})
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "any-model"),
|
||||
):
|
||||
result = asyncio.run(handler(None, _make_sampling_params()))
|
||||
assert isinstance(result, CreateMessageResult)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11. Malformed tool_call arguments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMalformedToolCallArgs:
|
||||
def test_invalid_json_wrapped_as_raw(self):
|
||||
"""Malformed JSON arguments get wrapped in {"_raw": ...}."""
|
||||
handler = SamplingHandler("mf", {})
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_tool_response(
|
||||
tool_calls_data=[("call_x", "some_tool", "not valid json {{{")]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
result = asyncio.run(handler(None, _make_sampling_params()))
|
||||
|
||||
assert isinstance(result, CreateMessageResultWithTools)
|
||||
tc = result.content[0]
|
||||
assert isinstance(tc, ToolUseContent)
|
||||
assert tc.input == {"_raw": "not valid json {{{"}
|
||||
|
||||
def test_dict_args_pass_through(self):
|
||||
"""When arguments are already a dict, they pass through directly."""
|
||||
handler = SamplingHandler("mf2", {})
|
||||
|
||||
# Build a tool call where arguments is already a dict
|
||||
tc_obj = SimpleNamespace(
|
||||
id="call_d",
|
||||
function=SimpleNamespace(name="do_stuff", arguments={"key": "val"}),
|
||||
)
|
||||
message = SimpleNamespace(content=None, tool_calls=[tc_obj])
|
||||
choice = SimpleNamespace(finish_reason="tool_calls", message=message)
|
||||
usage = SimpleNamespace(total_tokens=10)
|
||||
response = SimpleNamespace(choices=[choice], model="m", usage=usage)
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = response
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
result = asyncio.run(handler(None, _make_sampling_params()))
|
||||
|
||||
assert isinstance(result, CreateMessageResultWithTools)
|
||||
assert result.content[0].input == {"key": "val"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 12. Metrics tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMetricsTracking:
|
||||
def test_request_and_token_metrics(self):
|
||||
handler = SamplingHandler("met", {})
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
asyncio.run(handler(None, _make_sampling_params()))
|
||||
|
||||
assert handler.metrics["requests"] == 1
|
||||
assert handler.metrics["tokens_used"] == 42
|
||||
assert handler.metrics["errors"] == 0
|
||||
|
||||
def test_tool_use_count_metric(self):
|
||||
handler = SamplingHandler("met2", {})
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(fake_client, "default-model"),
|
||||
):
|
||||
asyncio.run(handler(None, _make_sampling_params()))
|
||||
|
||||
assert handler.metrics["tool_use_count"] == 1
|
||||
assert handler.metrics["requests"] == 1
|
||||
|
||||
def test_error_metric_incremented(self):
|
||||
handler = SamplingHandler("met3", {})
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.get_text_auxiliary_client",
|
||||
return_value=(None, None),
|
||||
):
|
||||
asyncio.run(handler(None, _make_sampling_params()))
|
||||
|
||||
assert handler.metrics["errors"] == 1
|
||||
assert handler.metrics["requests"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 13. session_kwargs()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionKwargs:
|
||||
def test_returns_correct_keys(self):
|
||||
handler = SamplingHandler("sk", {})
|
||||
kwargs = handler.session_kwargs()
|
||||
assert "sampling_callback" in kwargs
|
||||
assert "sampling_capabilities" in kwargs
|
||||
assert kwargs["sampling_callback"] is handler
|
||||
|
||||
def test_sampling_capabilities_type(self):
|
||||
handler = SamplingHandler("sk2", {})
|
||||
kwargs = handler.session_kwargs()
|
||||
cap = kwargs["sampling_capabilities"]
|
||||
assert isinstance(cap, SamplingCapability)
|
||||
assert isinstance(cap.tools, SamplingToolsCapability)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. MCPServerTask integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMCPServerTaskSamplingIntegration:
|
||||
def test_sampling_handler_created_when_enabled(self):
|
||||
"""MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
|
||||
from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES
|
||||
|
||||
server = MCPServerTask("int_test")
|
||||
config = {
|
||||
"command": "fake",
|
||||
"sampling": {"enabled": True, "max_rpm": 5},
|
||||
}
|
||||
# We only need to test the setup logic, not the actual connection.
|
||||
# Calling run() would attempt a real connection, so we test the
|
||||
# sampling setup portion directly.
|
||||
server._config = config
|
||||
sampling_config = config.get("sampling", {})
|
||||
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
|
||||
server._sampling = SamplingHandler(server.name, sampling_config)
|
||||
else:
|
||||
server._sampling = None
|
||||
|
||||
assert server._sampling is not None
|
||||
assert isinstance(server._sampling, SamplingHandler)
|
||||
assert server._sampling.server_name == "int_test"
|
||||
assert server._sampling.max_rpm == 5
|
||||
|
||||
def test_sampling_handler_none_when_disabled(self):
|
||||
"""MCPServerTask._sampling is None when sampling is disabled."""
|
||||
from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES
|
||||
|
||||
server = MCPServerTask("int_test2")
|
||||
config = {
|
||||
"command": "fake",
|
||||
"sampling": {"enabled": False},
|
||||
}
|
||||
server._config = config
|
||||
sampling_config = config.get("sampling", {})
|
||||
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
|
||||
server._sampling = SamplingHandler(server.name, sampling_config)
|
||||
else:
|
||||
server._sampling = None
|
||||
|
||||
assert server._sampling is None
|
||||
|
||||
def test_session_kwargs_used_in_stdio(self):
|
||||
"""When sampling is set, session_kwargs() are passed to ClientSession."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
server = MCPServerTask("sk_test")
|
||||
server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
|
||||
kwargs = server._sampling.session_kwargs()
|
||||
assert "sampling_callback" in kwargs
|
||||
assert "sampling_capabilities" in kwargs
|
||||
|
|
|
|||
|
|
@ -29,6 +29,18 @@ Example config::
|
|||
headers:
|
||||
Authorization: "Bearer sk-..."
|
||||
timeout: 180
|
||||
analysis:
|
||||
command: "npx"
|
||||
args: ["-y", "analysis-server"]
|
||||
sampling: # server-initiated LLM requests
|
||||
enabled: true # default: true
|
||||
model: "gemini-3-flash" # override model (optional)
|
||||
max_tokens_cap: 4096 # max tokens per request
|
||||
timeout: 30 # LLM call timeout (seconds)
|
||||
max_rpm: 10 # max requests per minute
|
||||
allowed_models: [] # model whitelist (empty = all)
|
||||
max_tool_rounds: 5 # tool loop limit (0 = disable)
|
||||
log_level: "info" # audit verbosity
|
||||
|
||||
Features:
|
||||
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
|
||||
|
|
@ -37,6 +49,8 @@ Features:
|
|||
- Credential stripping in error messages returned to the LLM
|
||||
- Configurable per-server timeouts for tool calls and connections
|
||||
- Thread-safe architecture with dedicated background event loop
|
||||
- Sampling support: MCP servers can request LLM completions via
|
||||
sampling/createMessage (text and tool-use responses)
|
||||
|
||||
Architecture:
|
||||
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
|
||||
|
|
@ -58,9 +72,11 @@ Thread safety:
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -71,6 +87,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
_MCP_AVAILABLE = False
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
_MCP_SAMPLING_TYPES = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
|
@ -80,6 +97,20 @@ try:
|
|||
_MCP_HTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
# Sampling types -- separated so older SDK versions don't break MCP support
|
||||
try:
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
SamplingToolsCapability,
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
except ImportError:
|
||||
logger.debug("mcp package not installed -- MCP tool support disabled")
|
||||
|
||||
|
|
@ -145,6 +176,386 @@ def _sanitize_error(text: str) -> str:
|
|||
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_numeric(value, default, coerce=int, minimum=1):
|
||||
"""Coerce a config value to a numeric type, returning *default* on failure.
|
||||
|
||||
Handles string values from YAML (e.g. ``"10"`` instead of ``10``),
|
||||
non-finite floats, and values below *minimum*.
|
||||
"""
|
||||
try:
|
||||
result = coerce(value)
|
||||
if isinstance(result, float) and not math.isfinite(result):
|
||||
return default
|
||||
return max(result, minimum)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return default
|
||||
|
||||
|
||||
class SamplingHandler:
|
||||
"""Handles sampling/createMessage requests for a single MCP server.
|
||||
|
||||
Each MCPServerTask that has sampling enabled creates one SamplingHandler.
|
||||
The handler is callable and passed directly to ``ClientSession`` as
|
||||
the ``sampling_callback``. All state (rate-limit timestamps, metrics,
|
||||
tool-loop counters) lives on the instance -- no module-level globals.
|
||||
|
||||
The callback is async and runs on the MCP background event loop. The
|
||||
sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
|
||||
it doesn't block the event loop.
|
||||
"""
|
||||
|
||||
_STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}
|
||||
|
||||
def __init__(self, server_name: str, config: dict):
|
||||
self.server_name = server_name
|
||||
self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
|
||||
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
|
||||
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
|
||||
self.max_tool_rounds = _safe_numeric(
|
||||
config.get("max_tool_rounds", 5), 5, int, minimum=0,
|
||||
)
|
||||
self.model_override = config.get("model")
|
||||
self.allowed_models = config.get("allowed_models", [])
|
||||
|
||||
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
|
||||
self.audit_level = _log_levels.get(
|
||||
str(config.get("log_level", "info")).lower(), logging.INFO,
|
||||
)
|
||||
|
||||
# Per-instance state
|
||||
self._rate_timestamps: List[float] = []
|
||||
self._tool_loop_count = 0
|
||||
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
|
||||
|
||||
# -- Rate limiting -------------------------------------------------------
|
||||
|
||||
def _check_rate_limit(self) -> bool:
|
||||
"""Sliding-window rate limiter. Returns True if request is allowed."""
|
||||
now = time.time()
|
||||
window = now - 60
|
||||
self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window]
|
||||
if len(self._rate_timestamps) >= self.max_rpm:
|
||||
return False
|
||||
self._rate_timestamps.append(now)
|
||||
return True
|
||||
|
||||
# -- Model resolution ----------------------------------------------------
|
||||
|
||||
def _resolve_model(self, preferences) -> Optional[str]:
|
||||
"""Config override > server hint > None (use default)."""
|
||||
if self.model_override:
|
||||
return self.model_override
|
||||
if preferences and hasattr(preferences, "hints") and preferences.hints:
|
||||
for hint in preferences.hints:
|
||||
if hasattr(hint, "name") and hint.name:
|
||||
return hint.name
|
||||
return None
|
||||
|
||||
# -- Message conversion --------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_result_text(block) -> str:
|
||||
"""Extract text from a ToolResultContent block."""
|
||||
if not hasattr(block, "content") or block.content is None:
|
||||
return ""
|
||||
items = block.content if isinstance(block.content, list) else [block.content]
|
||||
return "\n".join(item.text for item in items if hasattr(item, "text"))
|
||||
|
||||
def _convert_messages(self, params) -> List[dict]:
|
||||
"""Convert MCP SamplingMessages to OpenAI format.
|
||||
|
||||
Uses ``msg.content_as_list`` (SDK helper) so single-block and
|
||||
list-of-blocks are handled uniformly. Dispatches per block type
|
||||
with ``isinstance`` on real SDK types when available, falling back
|
||||
to duck-typing via ``hasattr`` for compatibility.
|
||||
"""
|
||||
messages: List[dict] = []
|
||||
for msg in params.messages:
|
||||
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
|
||||
msg.content if isinstance(msg.content, list) else [msg.content]
|
||||
)
|
||||
|
||||
# Separate blocks by kind
|
||||
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
|
||||
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
|
||||
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
|
||||
|
||||
# Emit tool result messages (role: tool)
|
||||
for tr in tool_results:
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.toolUseId,
|
||||
"content": self._extract_tool_result_text(tr),
|
||||
})
|
||||
|
||||
# Emit assistant tool_calls message
|
||||
if tool_uses:
|
||||
tc_list = []
|
||||
for tu in tool_uses:
|
||||
tc_list.append({
|
||||
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tu.name,
|
||||
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
|
||||
},
|
||||
})
|
||||
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
|
||||
# Include any accompanying text
|
||||
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
|
||||
if text_parts:
|
||||
msg_dict["content"] = "\n".join(text_parts)
|
||||
messages.append(msg_dict)
|
||||
elif content_blocks:
|
||||
# Pure text/image content
|
||||
if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
|
||||
messages.append({"role": msg.role, "content": content_blocks[0].text})
|
||||
else:
|
||||
parts = []
|
||||
for block in content_blocks:
|
||||
if hasattr(block, "text"):
|
||||
parts.append({"type": "text", "text": block.text})
|
||||
elif hasattr(block, "data") and hasattr(block, "mimeType"):
|
||||
parts.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
|
||||
})
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported sampling content block type: %s (skipped)",
|
||||
type(block).__name__,
|
||||
)
|
||||
if parts:
|
||||
messages.append({"role": msg.role, "content": parts})
|
||||
|
||||
return messages
|
||||
|
||||
# -- Error helper --------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str, code: int = -1):
|
||||
"""Return ErrorData (MCP spec) or raise as fallback."""
|
||||
if _MCP_SAMPLING_TYPES:
|
||||
return ErrorData(code=code, message=message)
|
||||
raise Exception(message)
|
||||
|
||||
# -- Response building ---------------------------------------------------
|
||||
|
||||
def _build_tool_use_result(self, choice, response):
|
||||
"""Build a CreateMessageResultWithTools from an LLM tool_calls response."""
|
||||
self.metrics["tool_use_count"] += 1
|
||||
|
||||
# Tool loop governance
|
||||
if self.max_tool_rounds == 0:
|
||||
self._tool_loop_count = 0
|
||||
return self._error(
|
||||
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
|
||||
)
|
||||
|
||||
self._tool_loop_count += 1
|
||||
if self._tool_loop_count > self.max_tool_rounds:
|
||||
self._tool_loop_count = 0
|
||||
return self._error(
|
||||
f"Tool loop limit exceeded for server '{self.server_name}' "
|
||||
f"(max {self.max_tool_rounds} rounds)"
|
||||
)
|
||||
|
||||
content_blocks = []
|
||||
for tc in choice.message.tool_calls:
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
parsed = json.loads(args)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.warning(
|
||||
"MCP server '%s': malformed tool_calls arguments "
|
||||
"from LLM (wrapping as raw): %.100s",
|
||||
self.server_name, args,
|
||||
)
|
||||
parsed = {"_raw": args}
|
||||
else:
|
||||
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
|
||||
|
||||
content_blocks.append(ToolUseContent(
|
||||
type="tool_use",
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
input=parsed,
|
||||
))
|
||||
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
|
||||
self.server_name, response.model,
|
||||
getattr(getattr(response, "usage", None), "total_tokens", "?"),
|
||||
len(content_blocks),
|
||||
)
|
||||
|
||||
return CreateMessageResultWithTools(
|
||||
role="assistant",
|
||||
content=content_blocks,
|
||||
model=response.model,
|
||||
stopReason="toolUse",
|
||||
)
|
||||
|
||||
def _build_text_result(self, choice, response):
|
||||
"""Build a CreateMessageResult from a normal text response."""
|
||||
self._tool_loop_count = 0 # reset on text response
|
||||
response_text = choice.message.content or ""
|
||||
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling response: model=%s, tokens=%s",
|
||||
self.server_name, response.model,
|
||||
getattr(getattr(response, "usage", None), "total_tokens", "?"),
|
||||
)
|
||||
|
||||
return CreateMessageResult(
|
||||
role="assistant",
|
||||
content=TextContent(type="text", text=_sanitize_error(response_text)),
|
||||
model=response.model,
|
||||
stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
|
||||
)
|
||||
|
||||
# -- Session kwargs helper -----------------------------------------------
|
||||
|
||||
def session_kwargs(self) -> dict:
|
||||
"""Return kwargs to pass to ClientSession for sampling support."""
|
||||
return {
|
||||
"sampling_callback": self,
|
||||
"sampling_capabilities": SamplingCapability(
|
||||
tools=SamplingToolsCapability(),
|
||||
),
|
||||
}
|
||||
|
||||
# -- Main callback -------------------------------------------------------
|
||||
|
||||
async def __call__(self, context, params):
|
||||
"""Sampling callback invoked by the MCP SDK.
|
||||
|
||||
Conforms to ``SamplingFnT`` protocol. Returns
|
||||
``CreateMessageResult``, ``CreateMessageResultWithTools``, or
|
||||
``ErrorData``.
|
||||
"""
|
||||
# Rate limit
|
||||
if not self._check_rate_limit():
|
||||
logger.warning(
|
||||
"MCP server '%s' sampling rate limit exceeded (%d/min)",
|
||||
self.server_name, self.max_rpm,
|
||||
)
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling rate limit exceeded for server '{self.server_name}' "
|
||||
f"({self.max_rpm} requests/minute)"
|
||||
)
|
||||
|
||||
# Resolve model
|
||||
model = self._resolve_model(getattr(params, "modelPreferences", None))
|
||||
|
||||
# Get auxiliary LLM client
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
client, default_model = get_text_auxiliary_client()
|
||||
if client is None:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error("No LLM provider available for sampling")
|
||||
|
||||
resolved_model = model or default_model
|
||||
|
||||
# Model whitelist check
|
||||
if self.allowed_models and resolved_model not in self.allowed_models:
|
||||
logger.warning(
|
||||
"MCP server '%s' requested model '%s' not in allowed_models",
|
||||
self.server_name, resolved_model,
|
||||
)
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Model '{resolved_model}' not allowed for server "
|
||||
f"'{self.server_name}'. Allowed: {', '.join(self.allowed_models)}"
|
||||
)
|
||||
|
||||
# Convert messages
|
||||
messages = self._convert_messages(params)
|
||||
if hasattr(params, "systemPrompt") and params.systemPrompt:
|
||||
messages.insert(0, {"role": "system", "content": params.systemPrompt})
|
||||
|
||||
# Build LLM call kwargs
|
||||
max_tokens = min(params.maxTokens, self.max_tokens_cap)
|
||||
call_kwargs: dict = {
|
||||
"model": resolved_model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
if hasattr(params, "temperature") and params.temperature is not None:
|
||||
call_kwargs["temperature"] = params.temperature
|
||||
if stop := getattr(params, "stopSequences", None):
|
||||
call_kwargs["stop"] = stop
|
||||
|
||||
# Forward server-provided tools
|
||||
server_tools = getattr(params, "tools", None)
|
||||
if server_tools:
|
||||
call_kwargs["tools"] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": getattr(t, "name", ""),
|
||||
"description": getattr(t, "description", "") or "",
|
||||
"parameters": getattr(t, "inputSchema", {}) or {},
|
||||
},
|
||||
}
|
||||
for t in server_tools
|
||||
]
|
||||
if tool_choice := getattr(params, "toolChoice", None):
|
||||
mode = getattr(tool_choice, "mode", "auto")
|
||||
call_kwargs["tool_choice"] = {"auto": "auto", "required": "required", "none": "none"}.get(mode, "auto")
|
||||
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
|
||||
self.server_name, resolved_model, max_tokens, len(messages),
|
||||
)
|
||||
|
||||
# Offload sync LLM call to thread (non-blocking)
|
||||
def _sync_call():
|
||||
return client.chat.completions.create(**call_kwargs)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
asyncio.to_thread(_sync_call), timeout=self.timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling LLM call timed out after {self.timeout}s "
|
||||
f"for server '{self.server_name}'"
|
||||
)
|
||||
except Exception as exc:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
|
||||
)
|
||||
|
||||
# Track metrics
|
||||
choice = response.choices[0]
|
||||
self.metrics["requests"] += 1
|
||||
total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
|
||||
if isinstance(total_tokens, int):
|
||||
self.metrics["tokens_used"] += total_tokens
|
||||
|
||||
# Dispatch based on response type
|
||||
if (
|
||||
choice.finish_reason == "tool_calls"
|
||||
and hasattr(choice.message, "tool_calls")
|
||||
and choice.message.tool_calls
|
||||
):
|
||||
return self._build_tool_use_result(choice, response)
|
||||
|
||||
return self._build_text_result(choice, response)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -162,6 +573,7 @@ class MCPServerTask:
|
|||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
|
|
@ -174,6 +586,7 @@ class MCPServerTask:
|
|||
self._tools: list = []
|
||||
self._error: Optional[Exception] = None
|
||||
self._config: dict = {}
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
|
|
@ -197,8 +610,9 @@ class MCPServerTask:
|
|||
env=safe_env if safe_env else None,
|
||||
)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
|
|
@ -218,12 +632,13 @@ class MCPServerTask:
|
|||
headers = config.get("headers")
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
async with streamablehttp_client(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=float(connect_timeout),
|
||||
) as (read_stream, write_stream, _get_session_id):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
|
|
@ -250,6 +665,13 @@ class MCPServerTask:
|
|||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
|
||||
# Set up sampling handler if enabled and SDK types are available
|
||||
sampling_config = config.get("sampling", {})
|
||||
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
|
||||
self._sampling = SamplingHandler(self.name, sampling_config)
|
||||
else:
|
||||
self._sampling = None
|
||||
|
||||
# Validate: warn if both url and command are present
|
||||
if "url" in config and "command" in config:
|
||||
logger.warning(
|
||||
|
|
@ -975,12 +1397,15 @@ def get_mcp_status() -> List[dict]:
|
|||
transport = "http" if "url" in cfg else "stdio"
|
||||
server = active_servers.get(name)
|
||||
if server and server.session is not None:
|
||||
result.append({
|
||||
entry = {
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": len(server._tools),
|
||||
"connected": True,
|
||||
})
|
||||
}
|
||||
if server._sampling:
|
||||
entry["sampling"] = dict(server._sampling.metrics)
|
||||
result.append(entry)
|
||||
else:
|
||||
result.append({
|
||||
"name": name,
|
||||
|
|
|
|||
|
|
@ -271,3 +271,62 @@ You can reload MCP servers without restarting Hermes:
|
|||
|
||||
- In the CLI: the agent reconnects automatically
|
||||
- In messaging: send `/reload-mcp`
|
||||
|
||||
## Sampling (Server-Initiated LLM Requests)
|
||||
|
||||
MCP's `sampling/createMessage` capability allows MCP servers to request LLM completions through the Hermes agent. This enables agent-in-the-loop workflows where servers can leverage the LLM during tool execution — for example, a database server asking the LLM to interpret query results, or a code analysis server requesting the LLM to review findings.
|
||||
|
||||
### How It Works
|
||||
|
||||
When an MCP server sends a `sampling/createMessage` request:
|
||||
|
||||
1. The sampling callback validates against rate limits and model whitelist
|
||||
2. Resolves which model to use (config override > server hint > default)
|
||||
3. Converts MCP messages to OpenAI-compatible format
|
||||
4. Offloads the LLM call to a thread via `asyncio.to_thread()` (non-blocking)
|
||||
5. Returns the response (text or tool use) back to the server
|
||||
|
||||
### Configuration
|
||||
|
||||
Sampling is **enabled by default** for all MCP servers. No extra setup needed — if you have an auxiliary LLM client configured, sampling works automatically.
|
||||
|
||||
```yaml
|
||||
mcp_servers:
|
||||
analysis_server:
|
||||
command: "npx"
|
||||
args: ["-y", "my-analysis-server"]
|
||||
sampling:
|
||||
enabled: true # default: true
|
||||
model: "gemini-3-flash" # override model (optional)
|
||||
max_tokens_cap: 4096 # max tokens per request (default: 4096)
|
||||
timeout: 30 # LLM call timeout in seconds (default: 30)
|
||||
max_rpm: 10 # max requests per minute (default: 10)
|
||||
allowed_models: [] # model whitelist (empty = allow all)
|
||||
max_tool_rounds: 5 # max consecutive tool use rounds (0 = disable)
|
||||
log_level: "info" # audit verbosity: debug, info, warning
|
||||
```
|
||||
|
||||
### Tool Use in Sampling
|
||||
|
||||
Servers can include `tools` and `toolChoice` in sampling requests, enabling multi-turn tool-augmented workflows within a single sampling session. The callback forwards tool definitions to the LLM, handles tool use responses with proper `ToolUseContent` types, and enforces `max_tool_rounds` to prevent infinite loops.
|
||||
|
||||
### Security
|
||||
|
||||
- **Rate limiting**: Per-server sliding window (default: 10 req/min)
|
||||
- **Token cap**: Servers can't request more than `max_tokens_cap` (default: 4096)
|
||||
- **Model whitelist**: `allowed_models` restricts which models a server can use
|
||||
- **Tool loop limit**: `max_tool_rounds` caps consecutive tool use rounds
|
||||
- **Credential stripping**: LLM responses are sanitized before returning to the server
|
||||
- **Non-blocking**: LLM calls run in a separate thread via `asyncio.to_thread()`
|
||||
- **Typed errors**: All failures return structured `ErrorData` per MCP spec
|
||||
|
||||
To disable sampling for untrusted servers:
|
||||
|
||||
```yaml
|
||||
mcp_servers:
|
||||
untrusted:
|
||||
command: "npx"
|
||||
args: ["-y", "untrusted-server"]
|
||||
sampling:
|
||||
enabled: false
|
||||
```
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue