mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(aux): add session_search extra_body and concurrency controls
Adds auxiliary.<task>.extra_body config passthrough so reasoning-heavy OpenAI-compatible providers can receive provider-specific request fields (e.g. enable_thinking: false on GLM) on auxiliary calls, and bounds session_search summary fan-out with auxiliary.session_search.max_concurrency (default 3, clamped 1-5) to avoid 429 bursts on small providers. - agent/auxiliary_client.py: extract _get_auxiliary_task_config helper, add _get_task_extra_body, merge config+explicit extra_body with explicit winning - hermes_cli/config.py: extra_body defaults on all aux tasks + session_search.max_concurrency; _config_version 19 -> 20 - tools/session_search_tool.py: semaphore around _summarize_all gather - tests: coverage in test_auxiliary_client, test_session_search, test_aux_config - docs: user-guide/configuration.md + fallback-providers.md Co-authored-by: Teknium <teknium@nousresearch.com>
This commit is contained in:
parent
904f20d622
commit
6ab78401c9
8 changed files with 207 additions and 26 deletions
|
|
@ -2313,7 +2313,6 @@ def _resolve_task_provider_model(
|
|||
to "custom" and the task uses that direct endpoint. api_mode is one of
|
||||
"chat_completions", "codex_responses", or None (auto-detect).
|
||||
"""
|
||||
config = {}
|
||||
cfg_provider = None
|
||||
cfg_model = None
|
||||
cfg_base_url = None
|
||||
|
|
@ -2321,16 +2320,7 @@ def _resolve_task_provider_model(
|
|||
cfg_api_mode = None
|
||||
|
||||
if task:
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
except ImportError:
|
||||
config = {}
|
||||
|
||||
aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
|
||||
task_config = aux.get(task, {}) if isinstance(aux, dict) else {}
|
||||
if not isinstance(task_config, dict):
|
||||
task_config = {}
|
||||
task_config = _get_auxiliary_task_config(task)
|
||||
cfg_provider = str(task_config.get("provider", "")).strip() or None
|
||||
cfg_model = str(task_config.get("model", "")).strip() or None
|
||||
cfg_base_url = str(task_config.get("base_url", "")).strip() or None
|
||||
|
|
@ -2360,17 +2350,25 @@ def _resolve_task_provider_model(
|
|||
_DEFAULT_AUX_TIMEOUT = 30.0
|
||||
|
||||
|
||||
def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float:
|
||||
"""Read timeout from auxiliary.{task}.timeout in config, falling back to *default*."""
|
||||
def _get_auxiliary_task_config(task: str) -> Dict[str, Any]:
|
||||
"""Return the config dict for auxiliary.<task>, or {} when unavailable."""
|
||||
if not task:
|
||||
return default
|
||||
return {}
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
except ImportError:
|
||||
return default
|
||||
return {}
|
||||
aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
|
||||
task_config = aux.get(task, {}) if isinstance(aux, dict) else {}
|
||||
return task_config if isinstance(task_config, dict) else {}
|
||||
|
||||
|
||||
def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float:
|
||||
"""Read timeout from auxiliary.{task}.timeout in config, falling back to *default*."""
|
||||
if not task:
|
||||
return default
|
||||
task_config = _get_auxiliary_task_config(task)
|
||||
raw = task_config.get("timeout")
|
||||
if raw is not None:
|
||||
try:
|
||||
|
|
@ -2380,6 +2378,15 @@ def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float
|
|||
return default
|
||||
|
||||
|
||||
def _get_task_extra_body(task: str) -> Dict[str, Any]:
|
||||
"""Read auxiliary.<task>.extra_body and return a shallow copy when valid."""
|
||||
task_config = _get_auxiliary_task_config(task)
|
||||
raw = task_config.get("extra_body")
|
||||
if isinstance(raw, dict):
|
||||
return dict(raw)
|
||||
return {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anthropic-compatible endpoint detection + image block conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -2580,6 +2587,8 @@ def call_llm(
|
|||
"""
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||
task, provider, model, base_url, api_key)
|
||||
effective_extra_body = _get_task_extra_body(task)
|
||||
effective_extra_body.update(extra_body or {})
|
||||
|
||||
if task == "vision":
|
||||
effective_provider, client, final_model = resolve_vision_provider_client(
|
||||
|
|
@ -2654,7 +2663,7 @@ def call_llm(
|
|||
kwargs = _build_call_kwargs(
|
||||
resolved_provider, final_model, messages,
|
||||
temperature=temperature, max_tokens=max_tokens,
|
||||
tools=tools, timeout=effective_timeout, extra_body=extra_body,
|
||||
tools=tools, timeout=effective_timeout, extra_body=effective_extra_body,
|
||||
base_url=_base_info or resolved_base_url)
|
||||
|
||||
# Convert image blocks for Anthropic-compatible endpoints (e.g. MiniMax)
|
||||
|
|
@ -2709,7 +2718,7 @@ def call_llm(
|
|||
fb_label, fb_model, messages,
|
||||
temperature=temperature, max_tokens=max_tokens,
|
||||
tools=tools, timeout=effective_timeout,
|
||||
extra_body=extra_body,
|
||||
extra_body=effective_extra_body,
|
||||
base_url=str(getattr(fb_client, "base_url", "") or ""))
|
||||
return _validate_llm_response(
|
||||
fb_client.chat.completions.create(**fb_kwargs), task)
|
||||
|
|
@ -2792,6 +2801,8 @@ async def async_call_llm(
|
|||
"""
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||
task, provider, model, base_url, api_key)
|
||||
effective_extra_body = _get_task_extra_body(task)
|
||||
effective_extra_body.update(extra_body or {})
|
||||
|
||||
if task == "vision":
|
||||
effective_provider, client, final_model = resolve_vision_provider_client(
|
||||
|
|
@ -2852,7 +2863,7 @@ async def async_call_llm(
|
|||
kwargs = _build_call_kwargs(
|
||||
resolved_provider, final_model, messages,
|
||||
temperature=temperature, max_tokens=max_tokens,
|
||||
tools=tools, timeout=effective_timeout, extra_body=extra_body,
|
||||
tools=tools, timeout=effective_timeout, extra_body=effective_extra_body,
|
||||
base_url=_client_base or resolved_base_url)
|
||||
|
||||
# Convert image blocks for Anthropic-compatible endpoints (e.g. MiniMax)
|
||||
|
|
@ -2891,7 +2902,7 @@ async def async_call_llm(
|
|||
fb_label, fb_model, messages,
|
||||
temperature=temperature, max_tokens=max_tokens,
|
||||
tools=tools, timeout=effective_timeout,
|
||||
extra_body=extra_body,
|
||||
extra_body=effective_extra_body,
|
||||
base_url=str(getattr(fb_client, "base_url", "") or ""))
|
||||
# Convert sync fallback client to async
|
||||
async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "")
|
||||
|
|
|
|||
|
|
@ -487,6 +487,7 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "", # direct OpenAI-compatible endpoint (takes precedence over provider)
|
||||
"api_key": "", # API key for base_url (falls back to OPENAI_API_KEY)
|
||||
"timeout": 120, # seconds — LLM API call timeout; vision payloads need generous timeout
|
||||
"extra_body": {}, # OpenAI-compatible provider-specific request fields
|
||||
"download_timeout": 30, # seconds — image HTTP download timeout; increase for slow connections
|
||||
},
|
||||
"web_extract": {
|
||||
|
|
@ -495,6 +496,7 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 360, # seconds (6min) — per-attempt LLM summarization timeout; increase for slow local models
|
||||
"extra_body": {},
|
||||
},
|
||||
"compression": {
|
||||
"provider": "auto",
|
||||
|
|
@ -502,6 +504,7 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 120, # seconds — compression summarises large contexts; increase for local models
|
||||
"extra_body": {},
|
||||
},
|
||||
"session_search": {
|
||||
"provider": "auto",
|
||||
|
|
@ -509,6 +512,8 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
"extra_body": {},
|
||||
"max_concurrency": 3, # Clamp parallel summaries to avoid request-burst 429s on small providers
|
||||
},
|
||||
"skills_hub": {
|
||||
"provider": "auto",
|
||||
|
|
@ -516,6 +521,7 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
"extra_body": {},
|
||||
},
|
||||
"approval": {
|
||||
"provider": "auto",
|
||||
|
|
@ -523,6 +529,7 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
"extra_body": {},
|
||||
},
|
||||
"mcp": {
|
||||
"provider": "auto",
|
||||
|
|
@ -530,6 +537,7 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
"extra_body": {},
|
||||
},
|
||||
"flush_memories": {
|
||||
"provider": "auto",
|
||||
|
|
@ -537,6 +545,7 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
"extra_body": {},
|
||||
},
|
||||
"title_generation": {
|
||||
"provider": "auto",
|
||||
|
|
@ -544,6 +553,7 @@ DEFAULT_CONFIG = {
|
|||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
"extra_body": {},
|
||||
},
|
||||
},
|
||||
|
||||
|
|
@ -812,7 +822,7 @@ DEFAULT_CONFIG = {
|
|||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 19,
|
||||
"_config_version": 20,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -946,6 +946,70 @@ class TestStaleBaseUrlWarning:
|
|||
"Expected a warning about stale OPENAI_BASE_URL"
|
||||
assert mod._stale_base_url_warned is True
|
||||
|
||||
|
||||
class TestAuxiliaryTaskExtraBody:
|
||||
def test_sync_call_merges_task_extra_body_from_config(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.example.com/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create.return_value = response
|
||||
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"session_search": {
|
||||
"extra_body": {
|
||||
"enable_thinking": False,
|
||||
"reasoning": {"effort": "none"},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with patch("hermes_cli.config.load_config", return_value=config), patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "glm-4.5-air"),
|
||||
):
|
||||
result = call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
extra_body={"metadata": {"source": "test"}},
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["extra_body"]["enable_thinking"] is False
|
||||
assert kwargs["extra_body"]["reasoning"] == {"effort": "none"}
|
||||
assert kwargs["extra_body"]["metadata"] == {"source": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_call_explicit_extra_body_overrides_task_config(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.example.com/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(return_value=response)
|
||||
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"session_search": {
|
||||
"extra_body": {"enable_thinking": False}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with patch("hermes_cli.config.load_config", return_value=config), patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "glm-4.5-air"),
|
||||
):
|
||||
result = await async_call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
extra_body={"enable_thinking": True},
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["extra_body"]["enable_thinking"] is True
|
||||
|
||||
def test_no_warning_when_provider_is_custom(self, monkeypatch, caplog):
|
||||
"""No warning when the provider is 'custom' — OPENAI_BASE_URL is expected."""
|
||||
import agent.auxiliary_client as mod
|
||||
|
|
|
|||
|
|
@ -39,6 +39,15 @@ def test_title_generation_present_in_default_config():
|
|||
assert tg["provider"] == "auto"
|
||||
assert tg["model"] == ""
|
||||
assert tg["timeout"] > 0
|
||||
assert tg["extra_body"] == {}
|
||||
|
||||
|
||||
def test_session_search_defaults_include_extra_body_and_concurrency():
|
||||
ss = DEFAULT_CONFIG["auxiliary"]["session_search"]
|
||||
assert ss["provider"] == "auto"
|
||||
assert ss["model"] == ""
|
||||
assert ss["extra_body"] == {}
|
||||
assert ss["max_concurrency"] == 3
|
||||
|
||||
|
||||
def test_aux_tasks_keys_all_exist_in_default_config():
|
||||
|
|
|
|||
|
|
@ -459,7 +459,7 @@ class TestCustomProviderCompatibility:
|
|||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 19
|
||||
assert raw["_config_version"] == 20
|
||||
assert raw["providers"]["openai-direct"] == {
|
||||
"api": "https://api.openai.com/v1",
|
||||
"api_key": "test-key",
|
||||
|
|
@ -606,7 +606,7 @@ class TestInterimAssistantMessageConfig:
|
|||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 19
|
||||
assert raw["_config_version"] == 20
|
||||
assert raw["display"]["tool_progress"] == "off"
|
||||
assert raw["display"]["interim_assistant_messages"] is True
|
||||
|
||||
|
|
@ -626,6 +626,6 @@ class TestDiscordChannelPromptsConfig:
|
|||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 19
|
||||
assert raw["_config_version"] == 20
|
||||
assert raw["discord"]["auto_thread"] is True
|
||||
assert raw["discord"]["channel_prompts"] == {}
|
||||
|
|
|
|||
|
|
@ -64,4 +64,4 @@ class TestCamofoxConfigDefaults:
|
|||
|
||||
# The current schema version is tracked globally; unrelated default
|
||||
# options may bump it after browser defaults are added.
|
||||
assert DEFAULT_CONFIG["_config_version"] == 19
|
||||
assert DEFAULT_CONFIG["_config_version"] == 20
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Tests for tools/session_search_tool.py — helper functions and search dispatcher."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import pytest
|
||||
|
|
@ -8,6 +9,7 @@ from tools.session_search_tool import (
|
|||
_format_timestamp,
|
||||
_format_conversation,
|
||||
_truncate_around_matches,
|
||||
_get_session_search_max_concurrency,
|
||||
_HIDDEN_SESSION_SOURCES,
|
||||
MAX_SESSION_CHARS,
|
||||
SESSION_SEARCH_SCHEMA,
|
||||
|
|
@ -181,6 +183,63 @@ class TestTruncateAroundMatches:
|
|||
assert result.lower().count("alpha beta") == 2
|
||||
|
||||
|
||||
class TestSessionSearchConcurrency:
|
||||
def test_defaults_to_three(self):
|
||||
assert _get_session_search_max_concurrency() == 3
|
||||
|
||||
def test_reads_and_clamps_configured_value(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.load_config",
|
||||
lambda: {"auxiliary": {"session_search": {"max_concurrency": 9}}},
|
||||
)
|
||||
assert _get_session_search_max_concurrency() == 5
|
||||
|
||||
def test_session_search_respects_configured_concurrency_limit(self, monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
from tools.session_search_tool import session_search
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.load_config",
|
||||
lambda: {"auxiliary": {"session_search": {"max_concurrency": 1}}},
|
||||
)
|
||||
|
||||
max_seen = {"value": 0}
|
||||
active = {"value": 0}
|
||||
|
||||
async def fake_summarize(_text, _query, _meta):
|
||||
active["value"] += 1
|
||||
max_seen["value"] = max(max_seen["value"], active["value"])
|
||||
await asyncio.sleep(0.01)
|
||||
active["value"] -= 1
|
||||
return "summary"
|
||||
|
||||
monkeypatch.setattr("tools.session_search_tool._summarize_session", fake_summarize)
|
||||
monkeypatch.setattr("model_tools._run_async", lambda coro: asyncio.run(coro))
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.search_messages.return_value = [
|
||||
{"session_id": "s1", "source": "cli", "session_started": 1709500000, "model": "test"},
|
||||
{"session_id": "s2", "source": "cli", "session_started": 1709500001, "model": "test"},
|
||||
{"session_id": "s3", "source": "cli", "session_started": 1709500002, "model": "test"},
|
||||
]
|
||||
mock_db.get_session.side_effect = lambda sid: {
|
||||
"id": sid,
|
||||
"parent_session_id": None,
|
||||
"source": "cli",
|
||||
"started_at": 1709500000,
|
||||
}
|
||||
mock_db.get_messages_as_conversation.side_effect = lambda sid: [
|
||||
{"role": "user", "content": f"message from {sid}"},
|
||||
{"role": "assistant", "content": "response"},
|
||||
]
|
||||
|
||||
result = json.loads(session_search(query="message", db=mock_db, limit=3))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["count"] == 3
|
||||
assert max_seen["value"] == 1
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# session_search (dispatcher)
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -27,6 +27,27 @@ MAX_SESSION_CHARS = 100_000
|
|||
MAX_SUMMARY_TOKENS = 10000
|
||||
|
||||
|
||||
def _get_session_search_max_concurrency(default: int = 3) -> int:
|
||||
"""Read auxiliary.session_search.max_concurrency with sane bounds."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
except ImportError:
|
||||
return default
|
||||
aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
|
||||
task_config = aux.get("session_search", {}) if isinstance(aux, dict) else {}
|
||||
if not isinstance(task_config, dict):
|
||||
return default
|
||||
raw = task_config.get("max_concurrency")
|
||||
if raw is None:
|
||||
return default
|
||||
try:
|
||||
value = int(raw)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return max(1, min(value, 5))
|
||||
|
||||
|
||||
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
||||
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
|
||||
|
||||
|
|
@ -423,9 +444,16 @@ def session_search(
|
|||
|
||||
# Summarize all sessions in parallel
|
||||
async def _summarize_all() -> List[Union[str, Exception]]:
|
||||
"""Summarize all sessions in parallel."""
|
||||
"""Summarize all sessions with bounded concurrency."""
|
||||
max_concurrency = min(_get_session_search_max_concurrency(), max(1, len(tasks)))
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def _bounded_summary(text: str, meta: Dict[str, Any]) -> Optional[str]:
|
||||
async with semaphore:
|
||||
return await _summarize_session(text, query, meta)
|
||||
|
||||
coros = [
|
||||
_summarize_session(text, query, meta)
|
||||
_bounded_summary(text, meta)
|
||||
for _, _, text, meta in tasks
|
||||
]
|
||||
return await asyncio.gather(*coros, return_exceptions=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue