fix(agent): route compression aux through live session runtime

This commit is contained in:
Harish Kukreja 2026-04-12 00:10:19 -04:00 committed by Teknium
parent c52f6348b6
commit b1f13a8c5f
6 changed files with 216 additions and 11 deletions

View file

@ -1021,6 +1021,23 @@ _AUTO_PROVIDER_LABELS = {
_AGGREGATOR_PROVIDERS = frozenset({"openrouter", "nous"})
_MAIN_RUNTIME_FIELDS = ("provider", "model", "base_url", "api_key", "api_mode")
def _normalize_main_runtime(main_runtime: Optional[Dict[str, Any]]) -> Dict[str, str]:
"""Return a sanitized copy of a live main-runtime override."""
if not isinstance(main_runtime, dict):
return {}
normalized: Dict[str, str] = {}
for field in _MAIN_RUNTIME_FIELDS:
value = main_runtime.get(field)
if isinstance(value, str) and value.strip():
normalized[field] = value.strip()
provider = normalized.get("provider")
if provider:
normalized["provider"] = provider.lower()
return normalized
def _get_provider_chain() -> List[tuple]:
"""Return the ordered provider detection chain.
@ -1130,7 +1147,7 @@ def _try_payment_fallback(
return None, None, ""
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Optional[OpenAI], Optional[str]]:
"""Full auto-detection chain.
Priority:
@ -1142,6 +1159,12 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
"""
global auxiliary_is_nous, _stale_base_url_warned
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
runtime = _normalize_main_runtime(main_runtime)
runtime_provider = runtime.get("provider", "")
runtime_model = runtime.get("model", "")
runtime_base_url = runtime.get("base_url", "")
runtime_api_key = runtime.get("api_key", "")
runtime_api_mode = runtime.get("api_mode", "")
# ── Warn once if OPENAI_BASE_URL is set but config.yaml uses a named
# provider (not 'custom'). This catches the common "env poisoning"
@ -1149,7 +1172,7 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
# old OPENAI_BASE_URL lingers in ~/.hermes/.env. ──
if not _stale_base_url_warned:
_env_base = os.getenv("OPENAI_BASE_URL", "").strip()
_cfg_provider = _read_main_provider()
_cfg_provider = runtime_provider or _read_main_provider()
if (_env_base and _cfg_provider
and _cfg_provider != "custom"
and not _cfg_provider.startswith("custom:")):
@ -1163,12 +1186,25 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
_stale_base_url_warned = True
# ── Step 1: non-aggregator main provider → use main model directly ──
main_provider = _read_main_provider()
main_model = _read_main_model()
main_provider = runtime_provider or _read_main_provider()
main_model = runtime_model or _read_main_model()
if (main_provider and main_model
and main_provider not in _AGGREGATOR_PROVIDERS
and main_provider not in ("auto", "")):
client, resolved = resolve_provider_client(main_provider, main_model)
resolved_provider = main_provider
explicit_base_url = None
explicit_api_key = None
if runtime_base_url and (main_provider == "custom" or main_provider.startswith("custom:")):
resolved_provider = "custom"
explicit_base_url = runtime_base_url
explicit_api_key = runtime_api_key or None
client, resolved = resolve_provider_client(
resolved_provider,
main_model,
explicit_base_url=explicit_base_url,
explicit_api_key=explicit_api_key,
api_mode=runtime_api_mode or None,
)
if client is not None:
logger.info("Auxiliary auto-detect: using main provider %s (%s)",
main_provider, resolved or main_model)
@ -1249,6 +1285,7 @@ def resolve_provider_client(
explicit_base_url: str = None,
explicit_api_key: str = None,
api_mode: str = None,
main_runtime: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[Any], Optional[str]]:
"""Central router: given a provider name and optional model, return a
configured client with the correct auth, base URL, and API format.
@ -1319,7 +1356,7 @@ def resolve_provider_client(
# ── Auto: try all providers in priority order ────────────────────
if provider == "auto":
client, resolved = _resolve_auto()
client, resolved = _resolve_auto(main_runtime=main_runtime)
if client is None:
return None, None
# When auto-detection lands on a non-OpenRouter provider (e.g. a
@ -1543,7 +1580,11 @@ def resolve_provider_client(
# ── Public API ──────────────────────────────────────────────────────────────
def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optional[str]]:
def get_text_auxiliary_client(
task: str = "",
*,
main_runtime: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[OpenAI], Optional[str]]:
"""Return (client, default_model_slug) for text-only auxiliary tasks.
Args:
@ -1560,10 +1601,11 @@ def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optiona
explicit_base_url=base_url,
explicit_api_key=api_key,
api_mode=api_mode,
main_runtime=main_runtime,
)
def get_async_text_auxiliary_client(task: str = ""):
def get_async_text_auxiliary_client(task: str = "", *, main_runtime: Optional[Dict[str, Any]] = None):
"""Return (async_client, model_slug) for async consumers.
For standard providers returns (AsyncOpenAI, model). For Codex returns
@ -1578,6 +1620,7 @@ def get_async_text_auxiliary_client(task: str = ""):
explicit_base_url=base_url,
explicit_api_key=api_key,
api_mode=api_mode,
main_runtime=main_runtime,
)
@ -1892,6 +1935,7 @@ def _get_cached_client(
base_url: str = None,
api_key: str = None,
api_mode: str = None,
main_runtime: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[Any], Optional[str]]:
"""Get or create a cached client for the given provider.
@ -1915,7 +1959,9 @@ def _get_cached_client(
loop_id = id(current_loop)
except RuntimeError:
pass
cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", loop_id)
runtime = _normalize_main_runtime(main_runtime)
runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", loop_id, runtime_key)
with _client_cache_lock:
if cache_key in _client_cache:
cached_client, cached_default, cached_loop = _client_cache[cache_key]
@ -1940,6 +1986,7 @@ def _get_cached_client(
explicit_base_url=base_url,
explicit_api_key=api_key,
api_mode=api_mode,
main_runtime=runtime,
)
if client is not None:
# For async clients, remember which loop they were created on so we
@ -2149,6 +2196,7 @@ def call_llm(
model: str = None,
base_url: str = None,
api_key: str = None,
main_runtime: Optional[Dict[str, Any]] = None,
messages: list,
temperature: float = None,
max_tokens: int = None,
@ -2214,6 +2262,7 @@ def call_llm(
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
main_runtime=main_runtime,
)
if client is None:
# When the user explicitly chose a non-OpenRouter provider but no
@ -2234,7 +2283,7 @@ def call_llm(
if not resolved_base_url:
logger.info("Auxiliary %s: provider %s unavailable, trying auto-detection chain",
task or "call", resolved_provider)
client, final_model = _get_cached_client("auto")
client, final_model = _get_cached_client("auto", main_runtime=main_runtime)
if client is None:
raise RuntimeError(
f"No LLM provider configured for task={task} provider={resolved_provider}. "

View file

@ -86,12 +86,14 @@ class ContextCompressor(ContextEngine):
base_url: str = "",
api_key: str = "",
provider: str = "",
api_mode: str = "",
) -> None:
"""Update model info after a model switch or fallback activation."""
self.model = model
self.base_url = base_url
self.api_key = api_key
self.provider = provider
self.api_mode = api_mode
self.context_length = context_length
self.threshold_tokens = max(
int(context_length * self.threshold_percent),
@ -111,11 +113,13 @@ class ContextCompressor(ContextEngine):
api_key: str = "",
config_context_length: int | None = None,
provider: str = "",
api_mode: str = "",
):
self.model = model
self.base_url = base_url
self.api_key = api_key
self.provider = provider
self.api_mode = api_mode
self.threshold_percent = threshold_percent
self.protect_first_n = protect_first_n
self.protect_last_n = protect_last_n
@ -438,6 +442,13 @@ The user has requested that this compaction PRIORITISE preserving all informatio
try:
call_kwargs = {
"task": "compression",
"main_runtime": {
"model": self.model,
"provider": self.provider,
"base_url": self.base_url,
"api_key": self.api_key,
"api_mode": self.api_mode,
},
"messages": [{"role": "user", "content": prompt}],
"max_tokens": summary_budget * 2,
# timeout resolved from auxiliary.compression.timeout config by call_llm

View file

@ -1307,6 +1307,7 @@ class AIAgent:
api_key=getattr(self, "api_key", ""),
config_context_length=_config_context_length,
provider=self.provider,
api_mode=self.api_mode,
)
self.compression_enabled = compression_enabled
@ -1563,6 +1564,7 @@ class AIAgent:
base_url=self.base_url,
api_key=getattr(self, "api_key", ""),
provider=self.provider,
api_mode=self.api_mode,
)
# ── Invalidate cached system prompt so it rebuilds next turn ──
@ -1696,6 +1698,16 @@ class AIAgent:
except Exception:
logger.debug("status_callback error in _emit_status", exc_info=True)
def _current_main_runtime(self) -> Dict[str, str]:
"""Return the live main runtime for session-scoped auxiliary routing."""
return {
"model": getattr(self, "model", "") or "",
"provider": getattr(self, "provider", "") or "",
"base_url": getattr(self, "base_url", "") or "",
"api_key": getattr(self, "api_key", "") or "",
"api_mode": getattr(self, "api_mode", "") or "",
}
def _check_compression_model_feasibility(self) -> None:
"""Warn at session start if the auxiliary compression model's context
window is smaller than the main model's compression threshold.
@ -1716,7 +1728,10 @@ class AIAgent:
from agent.auxiliary_client import get_text_auxiliary_client
from agent.model_metadata import get_model_context_length
client, aux_model = get_text_auxiliary_client("compression")
client, aux_model = get_text_auxiliary_client(
"compression",
main_runtime=self._current_main_runtime(),
)
if client is None or not aux_model:
msg = (
"⚠ No auxiliary LLM provider configured — context "

View file

@ -971,6 +971,74 @@ class TestTaskSpecificOverrides:
client, model = get_text_auxiliary_client("compression")
assert model == "google/gemini-3-flash-preview" # auto → OpenRouter
def test_resolve_auto_prefers_live_main_runtime_over_persisted_config(self, monkeypatch, tmp_path):
"""Session-only live model switches should override persisted config for auto routing."""
hermes_home = tmp_path / "hermes"
hermes_home.mkdir(parents=True, exist_ok=True)
(hermes_home / "config.yaml").write_text(
"""model:
default: glm-5.1
provider: opencode-go
compression:
summary_provider: auto
"""
)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
calls = []
def _fake_resolve(provider, model=None, *args, **kwargs):
calls.append((provider, model, kwargs))
return MagicMock(), model or "resolved-model"
with patch("agent.auxiliary_client.resolve_provider_client", side_effect=_fake_resolve):
client, model = _resolve_auto(
main_runtime={
"provider": "openai-codex",
"model": "gpt-5.4",
"api_mode": "codex_responses",
}
)
assert client is not None
assert model == "gpt-5.4"
assert calls[0][0] == "openai-codex"
assert calls[0][1] == "gpt-5.4"
assert calls[0][2]["api_mode"] == "codex_responses"
def test_explicit_compression_pin_still_wins_over_live_main_runtime(self, monkeypatch, tmp_path):
"""Task-level compression config should beat a live session override."""
hermes_home = tmp_path / "hermes"
hermes_home.mkdir(parents=True, exist_ok=True)
(hermes_home / "config.yaml").write_text(
"""auxiliary:
compression:
provider: openrouter
model: google/gemini-3-flash-preview
model:
default: glm-5.1
provider: opencode-go
"""
)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
with patch("agent.auxiliary_client.resolve_provider_client", return_value=(MagicMock(), "google/gemini-3-flash-preview")) as mock_resolve:
client, model = get_text_auxiliary_client(
"compression",
main_runtime={
"provider": "openai-codex",
"model": "gpt-5.4",
},
)
assert client is not None
assert model == "google/gemini-3-flash-preview"
assert mock_resolve.call_args.args[0] == "openrouter"
assert mock_resolve.call_args.kwargs["main_runtime"] == {
"provider": "openai-codex",
"model": "gpt-5.4",
}
def test_compression_summary_base_url_from_config(self, monkeypatch, tmp_path):
"""compression.summary_base_url should produce a custom-endpoint client."""
hermes_home = tmp_path / "hermes"

View file

@ -191,6 +191,37 @@ class TestNonStringContent:
kwargs = mock_call.call_args.kwargs
assert "temperature" not in kwargs
def test_summary_call_passes_live_main_runtime(self):
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "ok"
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
c = ContextCompressor(
model="gpt-5.4",
provider="openai-codex",
base_url="https://chatgpt.com/backend-api/codex",
api_key="codex-token",
api_mode="codex_responses",
quiet_mode=True,
)
messages = [
{"role": "user", "content": "do something"},
{"role": "assistant", "content": "ok"},
]
with patch("agent.context_compressor.call_llm", return_value=mock_response) as mock_call:
c._generate_summary(messages)
assert mock_call.call_args.kwargs["main_runtime"] == {
"model": "gpt-5.4",
"provider": "openai-codex",
"base_url": "https://chatgpt.com/backend-api/codex",
"api_key": "codex-token",
"api_mode": "codex_responses",
}
class TestSummaryFailureCooldown:
def test_summary_failure_enters_cooldown_and_skips_retry(self):

View file

@ -26,6 +26,7 @@ def _make_agent(
agent.provider = "openrouter"
agent.base_url = "https://openrouter.ai/api/v1"
agent.api_key = "sk-test"
agent.api_mode = "chat_completions"
agent.quiet_mode = True
agent.log_prefix = ""
agent.compression_enabled = compression_enabled
@ -99,6 +100,36 @@ def test_no_warning_when_aux_context_sufficient(mock_get_client, mock_ctx_len):
assert agent._compression_warning is None
def test_feasibility_check_passes_live_main_runtime():
"""Compression feasibility should probe using the live session runtime."""
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
agent.model = "gpt-5.4"
agent.provider = "openai-codex"
agent.base_url = "https://chatgpt.com/backend-api/codex"
agent.api_key = "codex-token"
agent.api_mode = "codex_responses"
mock_client = MagicMock()
mock_client.base_url = "https://chatgpt.com/backend-api/codex"
mock_client.api_key = "codex-token"
with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(mock_client, "gpt-5.4")) as mock_get_client, \
patch("agent.model_metadata.get_model_context_length", return_value=200_000):
agent._emit_status = lambda msg: None
agent._check_compression_model_feasibility()
mock_get_client.assert_called_once_with(
"compression",
main_runtime={
"model": "gpt-5.4",
"provider": "openai-codex",
"base_url": "https://chatgpt.com/backend-api/codex",
"api_key": "codex-token",
"api_mode": "codex_responses",
},
)
@patch("agent.auxiliary_client.get_text_auxiliary_client")
def test_warns_when_no_auxiliary_provider(mock_get_client):
"""Warning emitted when no auxiliary provider is configured."""