diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index ee4ff239198..1599eda9e6d 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -3510,35 +3510,46 @@ class APIServerAdapter(BasePlatformAdapter): loop = asyncio.get_running_loop() def _run(): - agent = self._create_agent( - ephemeral_system_prompt=ephemeral_system_prompt, - session_id=session_id, - stream_delta_callback=stream_delta_callback, - tool_progress_callback=tool_progress_callback, - tool_start_callback=tool_start_callback, - tool_complete_callback=tool_complete_callback, - gateway_session_key=gateway_session_key, + from gateway.session_context import clear_session_vars, set_session_vars + + tokens = set_session_vars( + platform="api_server", + chat_id=session_id or "", + session_key=gateway_session_key or session_id or "", + session_id=session_id or "", ) - if agent_ref is not None: - agent_ref[0] = agent - effective_task_id = session_id or str(uuid.uuid4()) - result = agent.run_conversation( - user_message=user_message, - conversation_history=conversation_history, - task_id=effective_task_id, - ) - usage = { - "input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0, - "output_tokens": getattr(agent, "session_completion_tokens", 0) or 0, - "total_tokens": getattr(agent, "session_total_tokens", 0) or 0, - } - # Include the effective session ID in the result so callers - # (e.g. X-Hermes-Session-Id header) can track compression- - # triggered session rotations. (#16938) - _eff_sid = getattr(agent, "session_id", session_id) - if isinstance(_eff_sid, str) and _eff_sid: - result["session_id"] = _eff_sid - return result, usage + try: + agent = self._create_agent( + ephemeral_system_prompt=ephemeral_system_prompt, + session_id=session_id, + stream_delta_callback=stream_delta_callback, + tool_progress_callback=tool_progress_callback, + tool_start_callback=tool_start_callback, + tool_complete_callback=tool_complete_callback, + gateway_session_key=gateway_session_key, + ) + if agent_ref is not None: + agent_ref[0] = agent + effective_task_id = session_id or str(uuid.uuid4()) + result = agent.run_conversation( + user_message=user_message, + conversation_history=conversation_history, + task_id=effective_task_id, + ) + usage = { + "input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0, + "output_tokens": getattr(agent, "session_completion_tokens", 0) or 0, + "total_tokens": getattr(agent, "session_total_tokens", 0) or 0, + } + # Include the effective session ID in the result so callers + # (e.g. X-Hermes-Session-Id header) can track compression- + # triggered session rotations. (#16938) + _eff_sid = getattr(agent, "session_id", session_id) + if isinstance(_eff_sid, str) and _eff_sid: + result["session_id"] = _eff_sid + return result, usage + finally: + clear_session_vars(tokens) return await loop.run_in_executor(None, _run) diff --git a/gateway/session_context.py b/gateway/session_context.py index 8dfc84cac80..c8c5cf438c7 100644 --- a/gateway/session_context.py +++ b/gateway/session_context.py @@ -106,6 +106,7 @@ def set_session_vars( user_id: str = "", user_name: str = "", session_key: str = "", + session_id: str = "", message_id: str = "", cwd: str = "", ) -> list: @@ -127,6 +128,7 @@ def set_session_vars( _SESSION_USER_ID.set(user_id), _SESSION_USER_NAME.set(user_name), _SESSION_KEY.set(session_key), + _SESSION_ID.set(session_id), _SESSION_MESSAGE_ID.set(message_id), ] try: @@ -157,6 +159,7 @@ def clear_session_vars(tokens: list) -> None: _SESSION_USER_ID, _SESSION_USER_NAME, _SESSION_KEY, + _SESSION_ID, _SESSION_MESSAGE_ID, ): var.set("") diff --git a/tests/gateway/test_session_api.py b/tests/gateway/test_session_api.py index 28d15e9a554..5d943e97348 100644 --- a/tests/gateway/test_session_api.py +++ b/tests/gateway/test_session_api.py @@ -75,6 +75,54 @@ async def test_capabilities_advertises_session_control_surface(adapter): } +@pytest.mark.asyncio +async def test_run_agent_binds_api_session_context_for_tool_env(adapter, monkeypatch): + """API-server request sessions should reach tools and terminal subprocess env.""" + monkeypatch.setenv("HERMES_SESSION_ID", "stale-session") + observed = {} + + class FakeAgent: + session_prompt_tokens = 0 + session_completion_tokens = 0 + session_total_tokens = 0 + + def __init__(self, session_id: str): + self.session_id = session_id + + def run_conversation(self, user_message, conversation_history, task_id): + from gateway.session_context import get_session_env + from tools.environments.local import _make_run_env + + observed["task_id"] = task_id + observed["context_session_id"] = get_session_env("HERMES_SESSION_ID") + observed["context_platform"] = get_session_env("HERMES_SESSION_PLATFORM") + observed["context_session_key"] = get_session_env("HERMES_SESSION_KEY") + observed["child_session_id"] = _make_run_env({}).get("HERMES_SESSION_ID") + return {"final_response": "ok"} + + def fake_create_agent(**kwargs): + return FakeAgent(kwargs["session_id"]) + + monkeypatch.setattr(adapter, "_create_agent", fake_create_agent) + + result, usage = await adapter._run_agent( + user_message="hello", + conversation_history=[], + session_id="request-session", + gateway_session_key="request-key", + ) + + assert result["session_id"] == "request-session" + assert usage == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + assert observed == { + "task_id": "request-session", + "context_session_id": "request-session", + "context_platform": "api_server", + "context_session_key": "request-key", + "child_session_id": "request-session", + } + + @pytest.mark.asyncio async def test_session_crud_and_message_history(adapter, session_db): app = _create_session_app(adapter) diff --git a/tests/gateway/test_session_env.py b/tests/gateway/test_session_env.py index 2b6c983a769..1da1e2a3b81 100644 --- a/tests/gateway/test_session_env.py +++ b/tests/gateway/test_session_env.py @@ -190,6 +190,17 @@ def test_session_key_falls_back_to_os_environ(monkeypatch): assert get_session_env("HERMES_SESSION_KEY") == "" +def test_session_id_set_via_contextvars(monkeypatch): + """set_session_vars should set HERMES_SESSION_ID via contextvars.""" + monkeypatch.setenv("HERMES_SESSION_ID", "stale-env-session") + + tokens = set_session_vars(session_id="ctx-session-456") + assert get_session_env("HERMES_SESSION_ID") == "ctx-session-456" + + clear_session_vars(tokens) + assert get_session_env("HERMES_SESSION_ID") == "" + + def test_set_session_env_includes_session_key(): """_set_session_env should propagate session_key from SessionContext.""" runner = object.__new__(GatewayRunner)