diff --git a/gateway/run.py b/gateway/run.py index 2f15361c6..00156f126 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -667,6 +667,7 @@ class GatewayRunner: def _flush_memories_for_session( self, old_session_id: str, + session_key: Optional[str] = None, ): """Prompt the agent to save memories/skills before context is lost. @@ -685,15 +686,12 @@ class GatewayRunner: return from run_agent import AIAgent - runtime_kwargs = _resolve_runtime_agent_kwargs() + model, runtime_kwargs = self._resolve_session_agent_runtime( + session_key=session_key, + ) if not runtime_kwargs.get("api_key"): return - # Resolve model from config — AIAgent's default is OpenRouter- - # formatted ("anthropic/claude-opus-4.6") which fails when the - # active provider is openai-codex. - model = _resolve_gateway_model() - tmp_agent = AIAgent( **runtime_kwargs, model=model, @@ -773,6 +771,7 @@ class GatewayRunner: async def _async_flush_memories( self, old_session_id: str, + session_key: Optional[str] = None, ): """Run the sync memory flush in a thread pool so it won't block the event loop.""" loop = asyncio.get_event_loop() @@ -780,6 +779,7 @@ class GatewayRunner: None, self._flush_memories_for_session, old_session_id, + session_key, ) @property @@ -814,6 +814,46 @@ class GatewayRunner: thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False), ) + def _resolve_session_agent_runtime( + self, + *, + source: Optional[SessionSource] = None, + session_key: Optional[str] = None, + user_config: Optional[dict] = None, + ) -> tuple[str, dict]: + """Resolve model/runtime for a session, honoring session-scoped /model overrides. + + If the session override already contains a complete provider bundle + (provider/api_key/base_url/api_mode), prefer it directly instead of + resolving fresh global runtime state first. + """ + resolved_session_key = session_key + if not resolved_session_key and source is not None: + try: + resolved_session_key = self._session_key_for_source(source) + except Exception: + resolved_session_key = None + + model = _resolve_gateway_model(user_config) + override = self._session_model_overrides.get(resolved_session_key) if resolved_session_key else None + if override: + override_model = override.get("model", model) + override_runtime = { + "provider": override.get("provider"), + "api_key": override.get("api_key"), + "base_url": override.get("base_url"), + "api_mode": override.get("api_mode"), + } + if override_runtime.get("api_key"): + return override_model, override_runtime + + runtime_kwargs = _resolve_runtime_agent_kwargs() + if override and resolved_session_key: + model, runtime_kwargs = self._apply_session_model_override( + resolved_session_key, model, runtime_kwargs + ) + return model, runtime_kwargs + def _resolve_turn_agent_config(self, user_message: str, model: str, runtime_kwargs: dict) -> dict: from agent.smart_model_routing import resolve_turn_route from hermes_cli.models import resolve_fast_mode_overrides @@ -1598,7 +1638,7 @@ class GatewayRunner: for key, entry in _expired_entries: try: - await self._async_flush_memories(entry.session_id) + await self._async_flush_memories(entry.session_id, key) # Shut down memory provider and close tool resources # on the cached agent. Idle agents live in # _agent_cache (not _running_agents), so look there. @@ -2867,6 +2907,7 @@ class GatewayRunner: _hyg_provider = None _hyg_base_url = None _hyg_api_key = None + _hyg_data = {} try: _hyg_cfg_path = _hermes_home / "config.yaml" if _hyg_cfg_path.exists(): @@ -2901,15 +2942,17 @@ class GatewayRunner: _comp_cfg.get("enabled", True) ).lower() in ("true", "1", "yes") - # Resolve provider/base_url from runtime if not in config - if not _hyg_provider or not _hyg_base_url: - try: - _hyg_runtime = _resolve_runtime_agent_kwargs() - _hyg_provider = _hyg_provider or _hyg_runtime.get("provider") - _hyg_base_url = _hyg_base_url or _hyg_runtime.get("base_url") - _hyg_api_key = _hyg_runtime.get("api_key") - except Exception: - pass + try: + _hyg_model, _hyg_runtime = self._resolve_session_agent_runtime( + source=source, + session_key=session_key, + user_config=_hyg_data if isinstance(_hyg_data, dict) else None, + ) + _hyg_provider = _hyg_runtime.get("provider") or _hyg_provider + _hyg_base_url = _hyg_runtime.get("base_url") or _hyg_base_url + _hyg_api_key = _hyg_runtime.get("api_key") or _hyg_api_key + except Exception: + pass # Check custom_providers per-model context_length # (same fallback as run_agent.py lines 1171-1189). @@ -2996,7 +3039,11 @@ class GatewayRunner: try: from run_agent import AIAgent - _hyg_runtime = _resolve_runtime_agent_kwargs() + _hyg_model, _hyg_runtime = self._resolve_session_agent_runtime( + source=source, + session_key=session_key, + user_config=_hyg_data if isinstance(_hyg_data, dict) else None, + ) if _hyg_runtime.get("api_key"): _hyg_msgs = [ {"role": m.get("role"), "content": m.get("content")} @@ -3652,7 +3699,7 @@ class GatewayRunner: old_entry = self.session_store._entries.get(session_key) if old_entry: _flush_task = asyncio.create_task( - self._async_flush_memories(old_entry.session_id) + self._async_flush_memories(old_entry.session_id, session_key) ) self._background_tasks.add(_flush_task) _flush_task.add_done_callback(self._background_tasks.discard) @@ -4973,7 +5020,11 @@ class GatewayRunner: _thread_metadata = {"thread_id": source.thread_id} if source.thread_id else None try: - runtime_kwargs = _resolve_runtime_agent_kwargs() + user_config = _load_gateway_config() + model, runtime_kwargs = self._resolve_session_agent_runtime( + source=source, + user_config=user_config, + ) if not runtime_kwargs.get("api_key"): await adapter.send( source.chat_id, @@ -4982,8 +5033,6 @@ class GatewayRunner: ) return - user_config = _load_gateway_config() - model = _resolve_gateway_model(user_config) platform_key = _platform_config_key(source.platform) from hermes_cli.tools_config import _get_platform_tools @@ -5143,7 +5192,12 @@ class GatewayRunner: _thread_meta = {"thread_id": source.thread_id} if source.thread_id else None try: - runtime_kwargs = _resolve_runtime_agent_kwargs() + user_config = _load_gateway_config() + model, runtime_kwargs = self._resolve_session_agent_runtime( + source=source, + session_key=session_key, + user_config=user_config, + ) if not runtime_kwargs.get("api_key"): await adapter.send( source.chat_id, @@ -5152,8 +5206,6 @@ class GatewayRunner: ) return - user_config = _load_gateway_config() - model = _resolve_gateway_model(user_config) platform_key = _platform_config_key(source.platform) reasoning_config = self._load_reasoning_config() self._service_tier = self._load_service_tier() @@ -5490,13 +5542,14 @@ class GatewayRunner: from agent.manual_compression_feedback import summarize_manual_compression from agent.model_metadata import estimate_messages_tokens_rough - runtime_kwargs = _resolve_runtime_agent_kwargs() + session_key = self._session_key_for_source(source) + model, runtime_kwargs = self._resolve_session_agent_runtime( + source=source, + session_key=session_key, + ) if not runtime_kwargs.get("api_key"): return "No provider configured -- cannot compress." - # Resolve model from config (same reason as memory flush above). - model = _resolve_gateway_model() - msgs = [ {"role": m.get("role"), "content": m.get("content")} for m in history @@ -5656,7 +5709,7 @@ class GatewayRunner: # Flush memories for current session before switching try: _flush_task = asyncio.create_task( - self._async_flush_memories(current_entry.session_id) + self._async_flush_memories(current_entry.session_id, session_key) ) self._background_tasks.add(_flush_task) _flush_task.add_done_callback(self._background_tasks.discard) @@ -7227,10 +7280,12 @@ class GatewayRunner: except Exception: pass - model = _resolve_gateway_model(user_config) - try: - runtime_kwargs = _resolve_runtime_agent_kwargs() + model, runtime_kwargs = self._resolve_session_agent_runtime( + source=source, + session_key=session_key, + user_config=user_config, + ) except Exception as exc: return { "final_response": f"⚠️ Provider authentication failed: {exc}", @@ -7239,11 +7294,6 @@ class GatewayRunner: "tools": [], } - # /model overrides take precedence over config.yaml defaults. - model, runtime_kwargs = self._apply_session_model_override( - session_key, model, runtime_kwargs - ) - pr = self._provider_routing reasoning_config = self._load_reasoning_config() self._reasoning_config = reasoning_config diff --git a/tests/gateway/test_resume_command.py b/tests/gateway/test_resume_command.py index dc788f74f..4c82f4894 100644 --- a/tests/gateway/test_resume_command.py +++ b/tests/gateway/test_resume_command.py @@ -221,5 +221,6 @@ class TestHandleResumeCommand: runner._async_flush_memories.assert_called_once_with( "current_session_001", + "agent:main:telegram:dm:67890", ) db.close() diff --git a/tests/gateway/test_session_model_override_routing.py b/tests/gateway/test_session_model_override_routing.py new file mode 100644 index 000000000..340d01fdc --- /dev/null +++ b/tests/gateway/test_session_model_override_routing.py @@ -0,0 +1,160 @@ +"""Regression tests for session-scoped model/provider overrides in gateway agents. + +These cover the bug where `/model ...` stored a session override, but fresh +agent constructions still resolved model/provider from global config/runtime. +That let helper agents (and cache-miss main agents) route GPT-5.4 to the wrong +provider, e.g. Nous instead of OpenAI Codex. +""" + +import asyncio +import sys +import threading +import types +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import gateway.run as gateway_run +from gateway.config import Platform +from gateway.session import SessionSource + + +class _CapturingAgent: + """Fake agent that records init kwargs for assertions.""" + + last_init = None + + def __init__(self, *args, **kwargs): + type(self).last_init = dict(kwargs) + self.tools = [] + + def run_conversation(self, user_message: str, conversation_history=None, task_id=None): + return { + "final_response": "ok", + "messages": [], + "api_calls": 1, + } + + +def _make_runner(): + runner = object.__new__(gateway_run.GatewayRunner) + runner.adapters = {} + runner.session_store = None + runner.config = None + runner._voice_mode = {} + runner._ephemeral_system_prompt = "" + runner._prefill_messages = [] + runner._reasoning_config = None + runner._show_reasoning = False + runner._provider_routing = {} + runner._fallback_model = None + runner._service_tier = None + runner._running_agents = {} + runner._running_agents_ts = {} + runner._background_tasks = set() + runner._session_db = None + runner._session_model_overrides = {} + runner._pending_model_notes = {} + runner._pending_approvals = {} + runner._agent_cache = {} + runner._agent_cache_lock = threading.Lock() + runner._get_or_create_gateway_honcho = lambda session_key: (None, None) + runner.hooks = MagicMock() + runner.hooks.emit = AsyncMock() + runner.hooks.loaded_hooks = [] + return runner + + +def _codex_override(): + return { + "model": "gpt-5.4", + "provider": "openai-codex", + "api_key": "***", + "base_url": "https://chatgpt.com/backend-api/codex", + "api_mode": "codex_responses", + } + + +def _explode_runtime_resolution(): + raise AssertionError( + "global runtime resolution should not run when a complete session override exists" + ) + + +def test_run_agent_prefers_session_override_over_global_runtime(monkeypatch): + monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {}) + monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _explode_runtime_resolution) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = _CapturingAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + _CapturingAgent.last_init = None + runner = _make_runner() + + source = SessionSource( + platform=Platform.LOCAL, + chat_id="cli", + chat_name="CLI", + chat_type="dm", + user_id="user-1", + ) + session_key = "agent:main:local:dm" + runner._session_model_overrides[session_key] = _codex_override() + + result = asyncio.run( + runner._run_agent( + message="ping", + context_prompt="", + history=[], + source=source, + session_id="session-1", + session_key=session_key, + ) + ) + + assert result["final_response"] == "ok" + assert _CapturingAgent.last_init is not None + assert _CapturingAgent.last_init["model"] == "gpt-5.4" + assert _CapturingAgent.last_init["provider"] == "openai-codex" + assert _CapturingAgent.last_init["api_mode"] == "codex_responses" + assert _CapturingAgent.last_init["base_url"] == "https://chatgpt.com/backend-api/codex" + assert _CapturingAgent.last_init["api_key"] == "***" + + +@pytest.mark.asyncio +async def test_background_task_prefers_session_override_over_global_runtime(monkeypatch): + monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {}) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _explode_runtime_resolution) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = _CapturingAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + _CapturingAgent.last_init = None + runner = _make_runner() + + adapter = AsyncMock() + adapter.send = AsyncMock() + adapter.extract_media = MagicMock(return_value=([], "ok")) + adapter.extract_images = MagicMock(return_value=([], "ok")) + runner.adapters[Platform.TELEGRAM] = adapter + + source = SessionSource( + platform=Platform.TELEGRAM, + user_id="12345", + chat_id="67890", + user_name="testuser", + ) + session_key = runner._session_key_for_source(source) + runner._session_model_overrides[session_key] = _codex_override() + + await runner._run_background_task("say hello", source, "bg_test") + + assert _CapturingAgent.last_init is not None + assert _CapturingAgent.last_init["model"] == "gpt-5.4" + assert _CapturingAgent.last_init["provider"] == "openai-codex" + assert _CapturingAgent.last_init["api_mode"] == "codex_responses" + assert _CapturingAgent.last_init["base_url"] == "https://chatgpt.com/backend-api/codex" + assert _CapturingAgent.last_init["api_key"] == "***"