fix(aux): add per-task extra_body and session_search concurrency cap

Adds auxiliary.<task>.extra_body config passthrough so auxiliary calls to
OpenAI-compatible providers can carry provider-specific request fields
(e.g. enable_thinking: false to disable reasoning on GLM), and bounds the
session_search summary fan-out with auxiliary.session_search.max_concurrency
(default 3, clamped to 1-5) to avoid 429 bursts against small providers.
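
For illustration, the merge order is config first, caller last. A minimal
sketch using the helpers added in agent/auxiliary_client.py (not the full
call path):

    # Base: auxiliary.<task>.extra_body read from config, e.g. {"enable_thinking": False}
    effective_extra_body = _get_task_extra_body(task)
    # The caller's explicit extra_body is layered on top, so per-call keys win
    effective_extra_body.update(extra_body or {})
    # The merged dict is forwarded on both the primary and fallback chat requests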

- agent/auxiliary_client.py: extract _get_auxiliary_task_config helper,
  add _get_task_extra_body, merge config+explicit extra_body with explicit winning
- hermes_cli/config.py: extra_body defaults on all aux tasks +
  session_search.max_concurrency; _config_version 19 -> 20
- tools/session_search_tool.py: semaphore around the _summarize_all gather;
  see the concurrency sketch after this list
- tests: coverage in test_auxiliary_client, test_session_search, test_aux_config
- docs: user-guide/configuration.md + fallback-providers.md
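
The fan-out bound is the standard semaphore-around-gather pattern. A minimal
standalone sketch (hypothetical helper name; assumes the summary coroutines
are already built):

    import asyncio

    async def bounded_gather(coros, max_concurrency=3):
        # Allow at most max_concurrency summaries in flight; the rest wait
        # for a free slot instead of hitting the provider all at once.
        semaphore = asyncio.Semaphore(max_concurrency)

        async def run_one(coro):
            async with semaphore:
                return await coro

        return await asyncio.gather(*(run_one(c) for c in coros), return_exceptions=True)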

Co-authored-by: Teknium <teknium@nousresearch.com>
helix4u authored on 2026-04-20 00:44:32 -07:00; committed by Teknium
parent 904f20d622
commit 6ab78401c9
8 changed files with 207 additions and 26 deletions

@@ -2313,7 +2313,6 @@ def _resolve_task_provider_model(
     to "custom" and the task uses that direct endpoint. api_mode is one of
     "chat_completions", "codex_responses", or None (auto-detect).
     """
-    config = {}
     cfg_provider = None
     cfg_model = None
     cfg_base_url = None
@@ -2321,16 +2320,7 @@
     cfg_api_mode = None
     if task:
-        try:
-            from hermes_cli.config import load_config
-            config = load_config()
-        except ImportError:
-            config = {}
-        aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
-        task_config = aux.get(task, {}) if isinstance(aux, dict) else {}
-        if not isinstance(task_config, dict):
-            task_config = {}
+        task_config = _get_auxiliary_task_config(task)
         cfg_provider = str(task_config.get("provider", "")).strip() or None
         cfg_model = str(task_config.get("model", "")).strip() or None
         cfg_base_url = str(task_config.get("base_url", "")).strip() or None
@@ -2360,17 +2350,25 @@
 _DEFAULT_AUX_TIMEOUT = 30.0
 
-def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float:
-    """Read timeout from auxiliary.{task}.timeout in config, falling back to *default*."""
+def _get_auxiliary_task_config(task: str) -> Dict[str, Any]:
+    """Return the config dict for auxiliary.<task>, or {} when unavailable."""
     if not task:
-        return default
+        return {}
     try:
         from hermes_cli.config import load_config
         config = load_config()
     except ImportError:
-        return default
+        return {}
     aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
     task_config = aux.get(task, {}) if isinstance(aux, dict) else {}
+    return task_config if isinstance(task_config, dict) else {}
+
+
+def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float:
+    """Read timeout from auxiliary.{task}.timeout in config, falling back to *default*."""
+    if not task:
+        return default
+    task_config = _get_auxiliary_task_config(task)
     raw = task_config.get("timeout")
     if raw is not None:
         try:
@@ -2380,6 +2378,15 @@ def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float
             return default
 
 
+def _get_task_extra_body(task: str) -> Dict[str, Any]:
+    """Read auxiliary.<task>.extra_body and return a shallow copy when valid."""
+    task_config = _get_auxiliary_task_config(task)
+    raw = task_config.get("extra_body")
+    if isinstance(raw, dict):
+        return dict(raw)
+    return {}
+
+
 # ---------------------------------------------------------------------------
 # Anthropic-compatible endpoint detection + image block conversion
 # ---------------------------------------------------------------------------
@@ -2580,6 +2587,8 @@ def call_llm(
     """
     resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
         task, provider, model, base_url, api_key)
+    effective_extra_body = _get_task_extra_body(task)
+    effective_extra_body.update(extra_body or {})
 
     if task == "vision":
         effective_provider, client, final_model = resolve_vision_provider_client(
@@ -2654,7 +2663,7 @@ def call_llm(
     kwargs = _build_call_kwargs(
         resolved_provider, final_model, messages,
         temperature=temperature, max_tokens=max_tokens,
-        tools=tools, timeout=effective_timeout, extra_body=extra_body,
+        tools=tools, timeout=effective_timeout, extra_body=effective_extra_body,
         base_url=_base_info or resolved_base_url)
 
     # Convert image blocks for Anthropic-compatible endpoints (e.g. MiniMax)
@@ -2709,7 +2718,7 @@ def call_llm(
                 fb_label, fb_model, messages,
                 temperature=temperature, max_tokens=max_tokens,
                 tools=tools, timeout=effective_timeout,
-                extra_body=extra_body,
+                extra_body=effective_extra_body,
                 base_url=str(getattr(fb_client, "base_url", "") or ""))
             return _validate_llm_response(
                 fb_client.chat.completions.create(**fb_kwargs), task)
@@ -2792,6 +2801,8 @@ async def async_call_llm(
     """
     resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
         task, provider, model, base_url, api_key)
+    effective_extra_body = _get_task_extra_body(task)
+    effective_extra_body.update(extra_body or {})
 
     if task == "vision":
         effective_provider, client, final_model = resolve_vision_provider_client(
@@ -2852,7 +2863,7 @@ async def async_call_llm(
     kwargs = _build_call_kwargs(
         resolved_provider, final_model, messages,
         temperature=temperature, max_tokens=max_tokens,
-        tools=tools, timeout=effective_timeout, extra_body=extra_body,
+        tools=tools, timeout=effective_timeout, extra_body=effective_extra_body,
         base_url=_client_base or resolved_base_url)
 
     # Convert image blocks for Anthropic-compatible endpoints (e.g. MiniMax)
@@ -2891,7 +2902,7 @@ async def async_call_llm(
                 fb_label, fb_model, messages,
                 temperature=temperature, max_tokens=max_tokens,
                 tools=tools, timeout=effective_timeout,
-                extra_body=extra_body,
+                extra_body=effective_extra_body,
                 base_url=str(getattr(fb_client, "base_url", "") or ""))
             # Convert sync fallback client to async
             async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "")

@@ -487,6 +487,7 @@ DEFAULT_CONFIG = {
             "base_url": "",  # direct OpenAI-compatible endpoint (takes precedence over provider)
             "api_key": "",  # API key for base_url (falls back to OPENAI_API_KEY)
             "timeout": 120,  # seconds — LLM API call timeout; vision payloads need generous timeout
+            "extra_body": {},  # OpenAI-compatible provider-specific request fields
             "download_timeout": 30,  # seconds — image HTTP download timeout; increase for slow connections
         },
         "web_extract": {
@@ -495,6 +496,7 @@ DEFAULT_CONFIG = {
             "base_url": "",
             "api_key": "",
             "timeout": 360,  # seconds (6min) — per-attempt LLM summarization timeout; increase for slow local models
+            "extra_body": {},
         },
         "compression": {
             "provider": "auto",
@@ -502,6 +504,7 @@ DEFAULT_CONFIG = {
             "base_url": "",
             "api_key": "",
             "timeout": 120,  # seconds — compression summarises large contexts; increase for local models
+            "extra_body": {},
         },
         "session_search": {
             "provider": "auto",
@@ -509,6 +512,8 @@ DEFAULT_CONFIG = {
             "base_url": "",
             "api_key": "",
             "timeout": 30,
+            "extra_body": {},
+            "max_concurrency": 3,  # Clamp parallel summaries to avoid request-burst 429s on small providers
         },
         "skills_hub": {
             "provider": "auto",
@@ -516,6 +521,7 @@ DEFAULT_CONFIG = {
             "base_url": "",
             "api_key": "",
             "timeout": 30,
+            "extra_body": {},
         },
         "approval": {
             "provider": "auto",
@@ -523,6 +529,7 @@ DEFAULT_CONFIG = {
             "base_url": "",
             "api_key": "",
             "timeout": 30,
+            "extra_body": {},
         },
         "mcp": {
             "provider": "auto",
@@ -530,6 +537,7 @@ DEFAULT_CONFIG = {
             "base_url": "",
             "api_key": "",
             "timeout": 30,
+            "extra_body": {},
         },
         "flush_memories": {
             "provider": "auto",
@@ -537,6 +545,7 @@ DEFAULT_CONFIG = {
             "base_url": "",
             "api_key": "",
             "timeout": 30,
+            "extra_body": {},
         },
         "title_generation": {
             "provider": "auto",
@@ -544,6 +553,7 @@ DEFAULT_CONFIG = {
             "base_url": "",
             "api_key": "",
             "timeout": 30,
+            "extra_body": {},
         },
     },
@@ -812,7 +822,7 @@ DEFAULT_CONFIG = {
     },
 
     # Config schema version - bump this when adding new required fields
-    "_config_version": 19,
+    "_config_version": 20,
 }
 
 # =============================================================================

@@ -946,6 +946,70 @@ class TestStaleBaseUrlWarning:
             "Expected a warning about stale OPENAI_BASE_URL"
         assert mod._stale_base_url_warned is True
 
 
+class TestAuxiliaryTaskExtraBody:
+    def test_sync_call_merges_task_extra_body_from_config(self):
+        client = MagicMock()
+        client.base_url = "https://api.example.com/v1"
+        response = MagicMock()
+        client.chat.completions.create.return_value = response
+        config = {
+            "auxiliary": {
+                "session_search": {
+                    "extra_body": {
+                        "enable_thinking": False,
+                        "reasoning": {"effort": "none"},
+                    }
+                }
+            }
+        }
+        with patch("hermes_cli.config.load_config", return_value=config), patch(
+            "agent.auxiliary_client._get_cached_client",
+            return_value=(client, "glm-4.5-air"),
+        ):
+            result = call_llm(
+                task="session_search",
+                messages=[{"role": "user", "content": "hello"}],
+                extra_body={"metadata": {"source": "test"}},
+            )
+        assert result is response
+        kwargs = client.chat.completions.create.call_args.kwargs
+        assert kwargs["extra_body"]["enable_thinking"] is False
+        assert kwargs["extra_body"]["reasoning"] == {"effort": "none"}
+        assert kwargs["extra_body"]["metadata"] == {"source": "test"}
+
+    @pytest.mark.asyncio
+    async def test_async_call_explicit_extra_body_overrides_task_config(self):
+        client = MagicMock()
+        client.base_url = "https://api.example.com/v1"
+        response = MagicMock()
+        client.chat.completions.create = AsyncMock(return_value=response)
+        config = {
+            "auxiliary": {
+                "session_search": {
+                    "extra_body": {"enable_thinking": False}
+                }
+            }
+        }
+        with patch("hermes_cli.config.load_config", return_value=config), patch(
+            "agent.auxiliary_client._get_cached_client",
+            return_value=(client, "glm-4.5-air"),
+        ):
+            result = await async_call_llm(
+                task="session_search",
+                messages=[{"role": "user", "content": "hello"}],
+                extra_body={"enable_thinking": True},
+            )
+        assert result is response
+        kwargs = client.chat.completions.create.call_args.kwargs
+        assert kwargs["extra_body"]["enable_thinking"] is True
+
+
     def test_no_warning_when_provider_is_custom(self, monkeypatch, caplog):
         """No warning when the provider is 'custom' — OPENAI_BASE_URL is expected."""
         import agent.auxiliary_client as mod

@@ -39,6 +39,15 @@ def test_title_generation_present_in_default_config():
     assert tg["provider"] == "auto"
     assert tg["model"] == ""
     assert tg["timeout"] > 0
+    assert tg["extra_body"] == {}
+
+
+def test_session_search_defaults_include_extra_body_and_concurrency():
+    ss = DEFAULT_CONFIG["auxiliary"]["session_search"]
+    assert ss["provider"] == "auto"
+    assert ss["model"] == ""
+    assert ss["extra_body"] == {}
+    assert ss["max_concurrency"] == 3
+
+
 def test_aux_tasks_keys_all_exist_in_default_config():

@@ -459,7 +459,7 @@ class TestCustomProviderCompatibility:
         migrate_config(interactive=False, quiet=True)
 
         raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
-        assert raw["_config_version"] == 19
+        assert raw["_config_version"] == 20
         assert raw["providers"]["openai-direct"] == {
             "api": "https://api.openai.com/v1",
             "api_key": "test-key",
@@ -606,7 +606,7 @@ class TestInterimAssistantMessageConfig:
         migrate_config(interactive=False, quiet=True)
 
         raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
-        assert raw["_config_version"] == 19
+        assert raw["_config_version"] == 20
         assert raw["display"]["tool_progress"] == "off"
         assert raw["display"]["interim_assistant_messages"] is True
@@ -626,6 +626,6 @@ class TestDiscordChannelPromptsConfig:
         migrate_config(interactive=False, quiet=True)
 
         raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
-        assert raw["_config_version"] == 19
+        assert raw["_config_version"] == 20
         assert raw["discord"]["auto_thread"] is True
         assert raw["discord"]["channel_prompts"] == {}

@@ -64,4 +64,4 @@ class TestCamofoxConfigDefaults:
         # The current schema version is tracked globally; unrelated default
         # options may bump it after browser defaults are added.
-        assert DEFAULT_CONFIG["_config_version"] == 19
+        assert DEFAULT_CONFIG["_config_version"] == 20

@@ -1,5 +1,6 @@
 """Tests for tools/session_search_tool.py — helper functions and search dispatcher."""
+import asyncio
 import json
 import time
 
 import pytest
@@ -8,6 +9,7 @@ from tools.session_search_tool import (
     _format_timestamp,
     _format_conversation,
     _truncate_around_matches,
+    _get_session_search_max_concurrency,
     _HIDDEN_SESSION_SOURCES,
     MAX_SESSION_CHARS,
     SESSION_SEARCH_SCHEMA,
@@ -181,6 +183,63 @@ class TestTruncateAroundMatches:
         assert result.lower().count("alpha beta") == 2
 
 
+class TestSessionSearchConcurrency:
+    def test_defaults_to_three(self):
+        assert _get_session_search_max_concurrency() == 3
+
+    def test_reads_and_clamps_configured_value(self, monkeypatch):
+        monkeypatch.setattr(
+            "hermes_cli.config.load_config",
+            lambda: {"auxiliary": {"session_search": {"max_concurrency": 9}}},
+        )
+        assert _get_session_search_max_concurrency() == 5
+
+    def test_session_search_respects_configured_concurrency_limit(self, monkeypatch):
+        from unittest.mock import MagicMock
+
+        from tools.session_search_tool import session_search
+
+        monkeypatch.setattr(
+            "hermes_cli.config.load_config",
+            lambda: {"auxiliary": {"session_search": {"max_concurrency": 1}}},
+        )
+
+        max_seen = {"value": 0}
+        active = {"value": 0}
+
+        async def fake_summarize(_text, _query, _meta):
+            active["value"] += 1
+            max_seen["value"] = max(max_seen["value"], active["value"])
+            await asyncio.sleep(0.01)
+            active["value"] -= 1
+            return "summary"
+
+        monkeypatch.setattr("tools.session_search_tool._summarize_session", fake_summarize)
+        monkeypatch.setattr("model_tools._run_async", lambda coro: asyncio.run(coro))
+
+        mock_db = MagicMock()
+        mock_db.search_messages.return_value = [
+            {"session_id": "s1", "source": "cli", "session_started": 1709500000, "model": "test"},
+            {"session_id": "s2", "source": "cli", "session_started": 1709500001, "model": "test"},
+            {"session_id": "s3", "source": "cli", "session_started": 1709500002, "model": "test"},
+        ]
+        mock_db.get_session.side_effect = lambda sid: {
+            "id": sid,
+            "parent_session_id": None,
+            "source": "cli",
+            "started_at": 1709500000,
+        }
+        mock_db.get_messages_as_conversation.side_effect = lambda sid: [
+            {"role": "user", "content": f"message from {sid}"},
+            {"role": "assistant", "content": "response"},
+        ]
+
+        result = json.loads(session_search(query="message", db=mock_db, limit=3))
+
+        assert result["success"] is True
+        assert result["count"] == 3
+        assert max_seen["value"] == 1
+
+
 # =========================================================================
 # session_search (dispatcher)
 # =========================================================================

@@ -27,6 +27,27 @@ MAX_SESSION_CHARS = 100_000
 MAX_SUMMARY_TOKENS = 10000
 
 
+def _get_session_search_max_concurrency(default: int = 3) -> int:
+    """Read auxiliary.session_search.max_concurrency with sane bounds."""
+    try:
+        from hermes_cli.config import load_config
+        config = load_config()
+    except ImportError:
+        return default
+    aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
+    task_config = aux.get("session_search", {}) if isinstance(aux, dict) else {}
+    if not isinstance(task_config, dict):
+        return default
+    raw = task_config.get("max_concurrency")
+    if raw is None:
+        return default
+    try:
+        value = int(raw)
+    except (TypeError, ValueError):
+        return default
+    return max(1, min(value, 5))
+
+
 def _format_timestamp(ts: Union[int, float, str, None]) -> str:
     """Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
@@ -423,9 +444,16 @@ def session_search(
     # Summarize all sessions in parallel
     async def _summarize_all() -> List[Union[str, Exception]]:
-        """Summarize all sessions in parallel."""
+        """Summarize all sessions with bounded concurrency."""
+        max_concurrency = min(_get_session_search_max_concurrency(), max(1, len(tasks)))
+        semaphore = asyncio.Semaphore(max_concurrency)
+
+        async def _bounded_summary(text: str, meta: Dict[str, Any]) -> Optional[str]:
+            async with semaphore:
+                return await _summarize_session(text, query, meta)
+
         coros = [
-            _summarize_session(text, query, meta)
+            _bounded_summary(text, meta)
             for _, _, text, meta in tasks
         ]
         return await asyncio.gather(*coros, return_exceptions=True)