mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(aux): add session_search extra_body and concurrency controls
Adds auxiliary.<task>.extra_body config passthrough so reasoning-heavy OpenAI-compatible providers can receive provider-specific request fields (e.g. enable_thinking: false on GLM) on auxiliary calls, and bounds session_search summary fan-out with auxiliary.session_search.max_concurrency (default 3, clamped 1-5) to avoid 429 bursts on small providers. - agent/auxiliary_client.py: extract _get_auxiliary_task_config helper, add _get_task_extra_body, merge config+explicit extra_body with explicit winning - hermes_cli/config.py: extra_body defaults on all aux tasks + session_search.max_concurrency; _config_version 19 -> 20 - tools/session_search_tool.py: semaphore around _summarize_all gather - tests: coverage in test_auxiliary_client, test_session_search, test_aux_config - docs: user-guide/configuration.md + fallback-providers.md Co-authored-by: Teknium <teknium@nousresearch.com>
This commit is contained in:
parent
904f20d622
commit
6ab78401c9
8 changed files with 207 additions and 26 deletions
|
|
@ -2313,7 +2313,6 @@ def _resolve_task_provider_model(
|
||||||
to "custom" and the task uses that direct endpoint. api_mode is one of
|
to "custom" and the task uses that direct endpoint. api_mode is one of
|
||||||
"chat_completions", "codex_responses", or None (auto-detect).
|
"chat_completions", "codex_responses", or None (auto-detect).
|
||||||
"""
|
"""
|
||||||
config = {}
|
|
||||||
cfg_provider = None
|
cfg_provider = None
|
||||||
cfg_model = None
|
cfg_model = None
|
||||||
cfg_base_url = None
|
cfg_base_url = None
|
||||||
|
|
@ -2321,16 +2320,7 @@ def _resolve_task_provider_model(
|
||||||
cfg_api_mode = None
|
cfg_api_mode = None
|
||||||
|
|
||||||
if task:
|
if task:
|
||||||
try:
|
task_config = _get_auxiliary_task_config(task)
|
||||||
from hermes_cli.config import load_config
|
|
||||||
config = load_config()
|
|
||||||
except ImportError:
|
|
||||||
config = {}
|
|
||||||
|
|
||||||
aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
|
|
||||||
task_config = aux.get(task, {}) if isinstance(aux, dict) else {}
|
|
||||||
if not isinstance(task_config, dict):
|
|
||||||
task_config = {}
|
|
||||||
cfg_provider = str(task_config.get("provider", "")).strip() or None
|
cfg_provider = str(task_config.get("provider", "")).strip() or None
|
||||||
cfg_model = str(task_config.get("model", "")).strip() or None
|
cfg_model = str(task_config.get("model", "")).strip() or None
|
||||||
cfg_base_url = str(task_config.get("base_url", "")).strip() or None
|
cfg_base_url = str(task_config.get("base_url", "")).strip() or None
|
||||||
|
|
@ -2360,17 +2350,25 @@ def _resolve_task_provider_model(
|
||||||
_DEFAULT_AUX_TIMEOUT = 30.0
|
_DEFAULT_AUX_TIMEOUT = 30.0
|
||||||
|
|
||||||
|
|
||||||
def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float:
|
def _get_auxiliary_task_config(task: str) -> Dict[str, Any]:
|
||||||
"""Read timeout from auxiliary.{task}.timeout in config, falling back to *default*."""
|
"""Return the config dict for auxiliary.<task>, or {} when unavailable."""
|
||||||
if not task:
|
if not task:
|
||||||
return default
|
return {}
|
||||||
try:
|
try:
|
||||||
from hermes_cli.config import load_config
|
from hermes_cli.config import load_config
|
||||||
config = load_config()
|
config = load_config()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return default
|
return {}
|
||||||
aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
|
aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
|
||||||
task_config = aux.get(task, {}) if isinstance(aux, dict) else {}
|
task_config = aux.get(task, {}) if isinstance(aux, dict) else {}
|
||||||
|
return task_config if isinstance(task_config, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float:
|
||||||
|
"""Read timeout from auxiliary.{task}.timeout in config, falling back to *default*."""
|
||||||
|
if not task:
|
||||||
|
return default
|
||||||
|
task_config = _get_auxiliary_task_config(task)
|
||||||
raw = task_config.get("timeout")
|
raw = task_config.get("timeout")
|
||||||
if raw is not None:
|
if raw is not None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -2380,6 +2378,15 @@ def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _get_task_extra_body(task: str) -> Dict[str, Any]:
|
||||||
|
"""Read auxiliary.<task>.extra_body and return a shallow copy when valid."""
|
||||||
|
task_config = _get_auxiliary_task_config(task)
|
||||||
|
raw = task_config.get("extra_body")
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
return dict(raw)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Anthropic-compatible endpoint detection + image block conversion
|
# Anthropic-compatible endpoint detection + image block conversion
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -2580,6 +2587,8 @@ def call_llm(
|
||||||
"""
|
"""
|
||||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||||
task, provider, model, base_url, api_key)
|
task, provider, model, base_url, api_key)
|
||||||
|
effective_extra_body = _get_task_extra_body(task)
|
||||||
|
effective_extra_body.update(extra_body or {})
|
||||||
|
|
||||||
if task == "vision":
|
if task == "vision":
|
||||||
effective_provider, client, final_model = resolve_vision_provider_client(
|
effective_provider, client, final_model = resolve_vision_provider_client(
|
||||||
|
|
@ -2654,7 +2663,7 @@ def call_llm(
|
||||||
kwargs = _build_call_kwargs(
|
kwargs = _build_call_kwargs(
|
||||||
resolved_provider, final_model, messages,
|
resolved_provider, final_model, messages,
|
||||||
temperature=temperature, max_tokens=max_tokens,
|
temperature=temperature, max_tokens=max_tokens,
|
||||||
tools=tools, timeout=effective_timeout, extra_body=extra_body,
|
tools=tools, timeout=effective_timeout, extra_body=effective_extra_body,
|
||||||
base_url=_base_info or resolved_base_url)
|
base_url=_base_info or resolved_base_url)
|
||||||
|
|
||||||
# Convert image blocks for Anthropic-compatible endpoints (e.g. MiniMax)
|
# Convert image blocks for Anthropic-compatible endpoints (e.g. MiniMax)
|
||||||
|
|
@ -2709,7 +2718,7 @@ def call_llm(
|
||||||
fb_label, fb_model, messages,
|
fb_label, fb_model, messages,
|
||||||
temperature=temperature, max_tokens=max_tokens,
|
temperature=temperature, max_tokens=max_tokens,
|
||||||
tools=tools, timeout=effective_timeout,
|
tools=tools, timeout=effective_timeout,
|
||||||
extra_body=extra_body,
|
extra_body=effective_extra_body,
|
||||||
base_url=str(getattr(fb_client, "base_url", "") or ""))
|
base_url=str(getattr(fb_client, "base_url", "") or ""))
|
||||||
return _validate_llm_response(
|
return _validate_llm_response(
|
||||||
fb_client.chat.completions.create(**fb_kwargs), task)
|
fb_client.chat.completions.create(**fb_kwargs), task)
|
||||||
|
|
@ -2792,6 +2801,8 @@ async def async_call_llm(
|
||||||
"""
|
"""
|
||||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||||
task, provider, model, base_url, api_key)
|
task, provider, model, base_url, api_key)
|
||||||
|
effective_extra_body = _get_task_extra_body(task)
|
||||||
|
effective_extra_body.update(extra_body or {})
|
||||||
|
|
||||||
if task == "vision":
|
if task == "vision":
|
||||||
effective_provider, client, final_model = resolve_vision_provider_client(
|
effective_provider, client, final_model = resolve_vision_provider_client(
|
||||||
|
|
@ -2852,7 +2863,7 @@ async def async_call_llm(
|
||||||
kwargs = _build_call_kwargs(
|
kwargs = _build_call_kwargs(
|
||||||
resolved_provider, final_model, messages,
|
resolved_provider, final_model, messages,
|
||||||
temperature=temperature, max_tokens=max_tokens,
|
temperature=temperature, max_tokens=max_tokens,
|
||||||
tools=tools, timeout=effective_timeout, extra_body=extra_body,
|
tools=tools, timeout=effective_timeout, extra_body=effective_extra_body,
|
||||||
base_url=_client_base or resolved_base_url)
|
base_url=_client_base or resolved_base_url)
|
||||||
|
|
||||||
# Convert image blocks for Anthropic-compatible endpoints (e.g. MiniMax)
|
# Convert image blocks for Anthropic-compatible endpoints (e.g. MiniMax)
|
||||||
|
|
@ -2891,7 +2902,7 @@ async def async_call_llm(
|
||||||
fb_label, fb_model, messages,
|
fb_label, fb_model, messages,
|
||||||
temperature=temperature, max_tokens=max_tokens,
|
temperature=temperature, max_tokens=max_tokens,
|
||||||
tools=tools, timeout=effective_timeout,
|
tools=tools, timeout=effective_timeout,
|
||||||
extra_body=extra_body,
|
extra_body=effective_extra_body,
|
||||||
base_url=str(getattr(fb_client, "base_url", "") or ""))
|
base_url=str(getattr(fb_client, "base_url", "") or ""))
|
||||||
# Convert sync fallback client to async
|
# Convert sync fallback client to async
|
||||||
async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "")
|
async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "")
|
||||||
|
|
|
||||||
|
|
@ -487,6 +487,7 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "", # direct OpenAI-compatible endpoint (takes precedence over provider)
|
"base_url": "", # direct OpenAI-compatible endpoint (takes precedence over provider)
|
||||||
"api_key": "", # API key for base_url (falls back to OPENAI_API_KEY)
|
"api_key": "", # API key for base_url (falls back to OPENAI_API_KEY)
|
||||||
"timeout": 120, # seconds — LLM API call timeout; vision payloads need generous timeout
|
"timeout": 120, # seconds — LLM API call timeout; vision payloads need generous timeout
|
||||||
|
"extra_body": {}, # OpenAI-compatible provider-specific request fields
|
||||||
"download_timeout": 30, # seconds — image HTTP download timeout; increase for slow connections
|
"download_timeout": 30, # seconds — image HTTP download timeout; increase for slow connections
|
||||||
},
|
},
|
||||||
"web_extract": {
|
"web_extract": {
|
||||||
|
|
@ -495,6 +496,7 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "",
|
"base_url": "",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"timeout": 360, # seconds (6min) — per-attempt LLM summarization timeout; increase for slow local models
|
"timeout": 360, # seconds (6min) — per-attempt LLM summarization timeout; increase for slow local models
|
||||||
|
"extra_body": {},
|
||||||
},
|
},
|
||||||
"compression": {
|
"compression": {
|
||||||
"provider": "auto",
|
"provider": "auto",
|
||||||
|
|
@ -502,6 +504,7 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "",
|
"base_url": "",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"timeout": 120, # seconds — compression summarises large contexts; increase for local models
|
"timeout": 120, # seconds — compression summarises large contexts; increase for local models
|
||||||
|
"extra_body": {},
|
||||||
},
|
},
|
||||||
"session_search": {
|
"session_search": {
|
||||||
"provider": "auto",
|
"provider": "auto",
|
||||||
|
|
@ -509,6 +512,8 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "",
|
"base_url": "",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"timeout": 30,
|
"timeout": 30,
|
||||||
|
"extra_body": {},
|
||||||
|
"max_concurrency": 3, # Clamp parallel summaries to avoid request-burst 429s on small providers
|
||||||
},
|
},
|
||||||
"skills_hub": {
|
"skills_hub": {
|
||||||
"provider": "auto",
|
"provider": "auto",
|
||||||
|
|
@ -516,6 +521,7 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "",
|
"base_url": "",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"timeout": 30,
|
"timeout": 30,
|
||||||
|
"extra_body": {},
|
||||||
},
|
},
|
||||||
"approval": {
|
"approval": {
|
||||||
"provider": "auto",
|
"provider": "auto",
|
||||||
|
|
@ -523,6 +529,7 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "",
|
"base_url": "",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"timeout": 30,
|
"timeout": 30,
|
||||||
|
"extra_body": {},
|
||||||
},
|
},
|
||||||
"mcp": {
|
"mcp": {
|
||||||
"provider": "auto",
|
"provider": "auto",
|
||||||
|
|
@ -530,6 +537,7 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "",
|
"base_url": "",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"timeout": 30,
|
"timeout": 30,
|
||||||
|
"extra_body": {},
|
||||||
},
|
},
|
||||||
"flush_memories": {
|
"flush_memories": {
|
||||||
"provider": "auto",
|
"provider": "auto",
|
||||||
|
|
@ -537,6 +545,7 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "",
|
"base_url": "",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"timeout": 30,
|
"timeout": 30,
|
||||||
|
"extra_body": {},
|
||||||
},
|
},
|
||||||
"title_generation": {
|
"title_generation": {
|
||||||
"provider": "auto",
|
"provider": "auto",
|
||||||
|
|
@ -544,6 +553,7 @@ DEFAULT_CONFIG = {
|
||||||
"base_url": "",
|
"base_url": "",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"timeout": 30,
|
"timeout": 30,
|
||||||
|
"extra_body": {},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|
@ -812,7 +822,7 @@ DEFAULT_CONFIG = {
|
||||||
},
|
},
|
||||||
|
|
||||||
# Config schema version - bump this when adding new required fields
|
# Config schema version - bump this when adding new required fields
|
||||||
"_config_version": 19,
|
"_config_version": 20,
|
||||||
}
|
}
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -946,6 +946,70 @@ class TestStaleBaseUrlWarning:
|
||||||
"Expected a warning about stale OPENAI_BASE_URL"
|
"Expected a warning about stale OPENAI_BASE_URL"
|
||||||
assert mod._stale_base_url_warned is True
|
assert mod._stale_base_url_warned is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuxiliaryTaskExtraBody:
|
||||||
|
def test_sync_call_merges_task_extra_body_from_config(self):
|
||||||
|
client = MagicMock()
|
||||||
|
client.base_url = "https://api.example.com/v1"
|
||||||
|
response = MagicMock()
|
||||||
|
client.chat.completions.create.return_value = response
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"auxiliary": {
|
||||||
|
"session_search": {
|
||||||
|
"extra_body": {
|
||||||
|
"enable_thinking": False,
|
||||||
|
"reasoning": {"effort": "none"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.load_config", return_value=config), patch(
|
||||||
|
"agent.auxiliary_client._get_cached_client",
|
||||||
|
return_value=(client, "glm-4.5-air"),
|
||||||
|
):
|
||||||
|
result = call_llm(
|
||||||
|
task="session_search",
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
extra_body={"metadata": {"source": "test"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is response
|
||||||
|
kwargs = client.chat.completions.create.call_args.kwargs
|
||||||
|
assert kwargs["extra_body"]["enable_thinking"] is False
|
||||||
|
assert kwargs["extra_body"]["reasoning"] == {"effort": "none"}
|
||||||
|
assert kwargs["extra_body"]["metadata"] == {"source": "test"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_call_explicit_extra_body_overrides_task_config(self):
|
||||||
|
client = MagicMock()
|
||||||
|
client.base_url = "https://api.example.com/v1"
|
||||||
|
response = MagicMock()
|
||||||
|
client.chat.completions.create = AsyncMock(return_value=response)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"auxiliary": {
|
||||||
|
"session_search": {
|
||||||
|
"extra_body": {"enable_thinking": False}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.load_config", return_value=config), patch(
|
||||||
|
"agent.auxiliary_client._get_cached_client",
|
||||||
|
return_value=(client, "glm-4.5-air"),
|
||||||
|
):
|
||||||
|
result = await async_call_llm(
|
||||||
|
task="session_search",
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
extra_body={"enable_thinking": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is response
|
||||||
|
kwargs = client.chat.completions.create.call_args.kwargs
|
||||||
|
assert kwargs["extra_body"]["enable_thinking"] is True
|
||||||
|
|
||||||
def test_no_warning_when_provider_is_custom(self, monkeypatch, caplog):
|
def test_no_warning_when_provider_is_custom(self, monkeypatch, caplog):
|
||||||
"""No warning when the provider is 'custom' — OPENAI_BASE_URL is expected."""
|
"""No warning when the provider is 'custom' — OPENAI_BASE_URL is expected."""
|
||||||
import agent.auxiliary_client as mod
|
import agent.auxiliary_client as mod
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,15 @@ def test_title_generation_present_in_default_config():
|
||||||
assert tg["provider"] == "auto"
|
assert tg["provider"] == "auto"
|
||||||
assert tg["model"] == ""
|
assert tg["model"] == ""
|
||||||
assert tg["timeout"] > 0
|
assert tg["timeout"] > 0
|
||||||
|
assert tg["extra_body"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_search_defaults_include_extra_body_and_concurrency():
|
||||||
|
ss = DEFAULT_CONFIG["auxiliary"]["session_search"]
|
||||||
|
assert ss["provider"] == "auto"
|
||||||
|
assert ss["model"] == ""
|
||||||
|
assert ss["extra_body"] == {}
|
||||||
|
assert ss["max_concurrency"] == 3
|
||||||
|
|
||||||
|
|
||||||
def test_aux_tasks_keys_all_exist_in_default_config():
|
def test_aux_tasks_keys_all_exist_in_default_config():
|
||||||
|
|
|
||||||
|
|
@ -459,7 +459,7 @@ class TestCustomProviderCompatibility:
|
||||||
migrate_config(interactive=False, quiet=True)
|
migrate_config(interactive=False, quiet=True)
|
||||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
assert raw["_config_version"] == 19
|
assert raw["_config_version"] == 20
|
||||||
assert raw["providers"]["openai-direct"] == {
|
assert raw["providers"]["openai-direct"] == {
|
||||||
"api": "https://api.openai.com/v1",
|
"api": "https://api.openai.com/v1",
|
||||||
"api_key": "test-key",
|
"api_key": "test-key",
|
||||||
|
|
@ -606,7 +606,7 @@ class TestInterimAssistantMessageConfig:
|
||||||
migrate_config(interactive=False, quiet=True)
|
migrate_config(interactive=False, quiet=True)
|
||||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
assert raw["_config_version"] == 19
|
assert raw["_config_version"] == 20
|
||||||
assert raw["display"]["tool_progress"] == "off"
|
assert raw["display"]["tool_progress"] == "off"
|
||||||
assert raw["display"]["interim_assistant_messages"] is True
|
assert raw["display"]["interim_assistant_messages"] is True
|
||||||
|
|
||||||
|
|
@ -626,6 +626,6 @@ class TestDiscordChannelPromptsConfig:
|
||||||
migrate_config(interactive=False, quiet=True)
|
migrate_config(interactive=False, quiet=True)
|
||||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
assert raw["_config_version"] == 19
|
assert raw["_config_version"] == 20
|
||||||
assert raw["discord"]["auto_thread"] is True
|
assert raw["discord"]["auto_thread"] is True
|
||||||
assert raw["discord"]["channel_prompts"] == {}
|
assert raw["discord"]["channel_prompts"] == {}
|
||||||
|
|
|
||||||
|
|
@ -64,4 +64,4 @@ class TestCamofoxConfigDefaults:
|
||||||
|
|
||||||
# The current schema version is tracked globally; unrelated default
|
# The current schema version is tracked globally; unrelated default
|
||||||
# options may bump it after browser defaults are added.
|
# options may bump it after browser defaults are added.
|
||||||
assert DEFAULT_CONFIG["_config_version"] == 19
|
assert DEFAULT_CONFIG["_config_version"] == 20
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Tests for tools/session_search_tool.py — helper functions and search dispatcher."""
|
"""Tests for tools/session_search_tool.py — helper functions and search dispatcher."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -8,6 +9,7 @@ from tools.session_search_tool import (
|
||||||
_format_timestamp,
|
_format_timestamp,
|
||||||
_format_conversation,
|
_format_conversation,
|
||||||
_truncate_around_matches,
|
_truncate_around_matches,
|
||||||
|
_get_session_search_max_concurrency,
|
||||||
_HIDDEN_SESSION_SOURCES,
|
_HIDDEN_SESSION_SOURCES,
|
||||||
MAX_SESSION_CHARS,
|
MAX_SESSION_CHARS,
|
||||||
SESSION_SEARCH_SCHEMA,
|
SESSION_SEARCH_SCHEMA,
|
||||||
|
|
@ -181,6 +183,63 @@ class TestTruncateAroundMatches:
|
||||||
assert result.lower().count("alpha beta") == 2
|
assert result.lower().count("alpha beta") == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionSearchConcurrency:
|
||||||
|
def test_defaults_to_three(self):
|
||||||
|
assert _get_session_search_max_concurrency() == 3
|
||||||
|
|
||||||
|
def test_reads_and_clamps_configured_value(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"hermes_cli.config.load_config",
|
||||||
|
lambda: {"auxiliary": {"session_search": {"max_concurrency": 9}}},
|
||||||
|
)
|
||||||
|
assert _get_session_search_max_concurrency() == 5
|
||||||
|
|
||||||
|
def test_session_search_respects_configured_concurrency_limit(self, monkeypatch):
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from tools.session_search_tool import session_search
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"hermes_cli.config.load_config",
|
||||||
|
lambda: {"auxiliary": {"session_search": {"max_concurrency": 1}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
max_seen = {"value": 0}
|
||||||
|
active = {"value": 0}
|
||||||
|
|
||||||
|
async def fake_summarize(_text, _query, _meta):
|
||||||
|
active["value"] += 1
|
||||||
|
max_seen["value"] = max(max_seen["value"], active["value"])
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
active["value"] -= 1
|
||||||
|
return "summary"
|
||||||
|
|
||||||
|
monkeypatch.setattr("tools.session_search_tool._summarize_session", fake_summarize)
|
||||||
|
monkeypatch.setattr("model_tools._run_async", lambda coro: asyncio.run(coro))
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.search_messages.return_value = [
|
||||||
|
{"session_id": "s1", "source": "cli", "session_started": 1709500000, "model": "test"},
|
||||||
|
{"session_id": "s2", "source": "cli", "session_started": 1709500001, "model": "test"},
|
||||||
|
{"session_id": "s3", "source": "cli", "session_started": 1709500002, "model": "test"},
|
||||||
|
]
|
||||||
|
mock_db.get_session.side_effect = lambda sid: {
|
||||||
|
"id": sid,
|
||||||
|
"parent_session_id": None,
|
||||||
|
"source": "cli",
|
||||||
|
"started_at": 1709500000,
|
||||||
|
}
|
||||||
|
mock_db.get_messages_as_conversation.side_effect = lambda sid: [
|
||||||
|
{"role": "user", "content": f"message from {sid}"},
|
||||||
|
{"role": "assistant", "content": "response"},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = json.loads(session_search(query="message", db=mock_db, limit=3))
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["count"] == 3
|
||||||
|
assert max_seen["value"] == 1
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# session_search (dispatcher)
|
# session_search (dispatcher)
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,27 @@ MAX_SESSION_CHARS = 100_000
|
||||||
MAX_SUMMARY_TOKENS = 10000
|
MAX_SUMMARY_TOKENS = 10000
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_search_max_concurrency(default: int = 3) -> int:
|
||||||
|
"""Read auxiliary.session_search.max_concurrency with sane bounds."""
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import load_config
|
||||||
|
config = load_config()
|
||||||
|
except ImportError:
|
||||||
|
return default
|
||||||
|
aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
|
||||||
|
task_config = aux.get("session_search", {}) if isinstance(aux, dict) else {}
|
||||||
|
if not isinstance(task_config, dict):
|
||||||
|
return default
|
||||||
|
raw = task_config.get("max_concurrency")
|
||||||
|
if raw is None:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
value = int(raw)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
return max(1, min(value, 5))
|
||||||
|
|
||||||
|
|
||||||
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
||||||
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
|
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
|
||||||
|
|
||||||
|
|
@ -423,9 +444,16 @@ def session_search(
|
||||||
|
|
||||||
# Summarize all sessions in parallel
|
# Summarize all sessions in parallel
|
||||||
async def _summarize_all() -> List[Union[str, Exception]]:
|
async def _summarize_all() -> List[Union[str, Exception]]:
|
||||||
"""Summarize all sessions in parallel."""
|
"""Summarize all sessions with bounded concurrency."""
|
||||||
|
max_concurrency = min(_get_session_search_max_concurrency(), max(1, len(tasks)))
|
||||||
|
semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
|
async def _bounded_summary(text: str, meta: Dict[str, Any]) -> Optional[str]:
|
||||||
|
async with semaphore:
|
||||||
|
return await _summarize_session(text, query, meta)
|
||||||
|
|
||||||
coros = [
|
coros = [
|
||||||
_summarize_session(text, query, meta)
|
_bounded_summary(text, meta)
|
||||||
for _, _, text, meta in tasks
|
for _, _, text, meta in tasks
|
||||||
]
|
]
|
||||||
return await asyncio.gather(*coros, return_exceptions=True)
|
return await asyncio.gather(*coros, return_exceptions=True)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue