fix(api-server): bind request session context for tools

This commit is contained in:
helix4u 2026-06-08 17:52:03 -06:00 committed by Teknium
parent 52ae9d9f02
commit b23184cad4
4 changed files with 101 additions and 28 deletions

View file

@ -3510,35 +3510,46 @@ class APIServerAdapter(BasePlatformAdapter):
loop = asyncio.get_running_loop()
def _run():
agent = self._create_agent(
ephemeral_system_prompt=ephemeral_system_prompt,
session_id=session_id,
stream_delta_callback=stream_delta_callback,
tool_progress_callback=tool_progress_callback,
tool_start_callback=tool_start_callback,
tool_complete_callback=tool_complete_callback,
gateway_session_key=gateway_session_key,
from gateway.session_context import clear_session_vars, set_session_vars
tokens = set_session_vars(
platform="api_server",
chat_id=session_id or "",
session_key=gateway_session_key or session_id or "",
session_id=session_id or "",
)
if agent_ref is not None:
agent_ref[0] = agent
effective_task_id = session_id or str(uuid.uuid4())
result = agent.run_conversation(
user_message=user_message,
conversation_history=conversation_history,
task_id=effective_task_id,
)
usage = {
"input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0,
"output_tokens": getattr(agent, "session_completion_tokens", 0) or 0,
"total_tokens": getattr(agent, "session_total_tokens", 0) or 0,
}
# Include the effective session ID in the result so callers
# (e.g. X-Hermes-Session-Id header) can track compression-
# triggered session rotations. (#16938)
_eff_sid = getattr(agent, "session_id", session_id)
if isinstance(_eff_sid, str) and _eff_sid:
result["session_id"] = _eff_sid
return result, usage
try:
agent = self._create_agent(
ephemeral_system_prompt=ephemeral_system_prompt,
session_id=session_id,
stream_delta_callback=stream_delta_callback,
tool_progress_callback=tool_progress_callback,
tool_start_callback=tool_start_callback,
tool_complete_callback=tool_complete_callback,
gateway_session_key=gateway_session_key,
)
if agent_ref is not None:
agent_ref[0] = agent
effective_task_id = session_id or str(uuid.uuid4())
result = agent.run_conversation(
user_message=user_message,
conversation_history=conversation_history,
task_id=effective_task_id,
)
usage = {
"input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0,
"output_tokens": getattr(agent, "session_completion_tokens", 0) or 0,
"total_tokens": getattr(agent, "session_total_tokens", 0) or 0,
}
# Include the effective session ID in the result so callers
# (e.g. X-Hermes-Session-Id header) can track compression-
# triggered session rotations. (#16938)
_eff_sid = getattr(agent, "session_id", session_id)
if isinstance(_eff_sid, str) and _eff_sid:
result["session_id"] = _eff_sid
return result, usage
finally:
clear_session_vars(tokens)
return await loop.run_in_executor(None, _run)

View file

@ -106,6 +106,7 @@ def set_session_vars(
user_id: str = "",
user_name: str = "",
session_key: str = "",
session_id: str = "",
message_id: str = "",
cwd: str = "",
) -> list:
@ -127,6 +128,7 @@ def set_session_vars(
_SESSION_USER_ID.set(user_id),
_SESSION_USER_NAME.set(user_name),
_SESSION_KEY.set(session_key),
_SESSION_ID.set(session_id),
_SESSION_MESSAGE_ID.set(message_id),
]
try:
@ -157,6 +159,7 @@ def clear_session_vars(tokens: list) -> None:
_SESSION_USER_ID,
_SESSION_USER_NAME,
_SESSION_KEY,
_SESSION_ID,
_SESSION_MESSAGE_ID,
):
var.set("")

View file

@ -75,6 +75,54 @@ async def test_capabilities_advertises_session_control_surface(adapter):
}
@pytest.mark.asyncio
async def test_run_agent_binds_api_session_context_for_tool_env(adapter, monkeypatch):
"""API-server request sessions should reach tools and terminal subprocess env."""
monkeypatch.setenv("HERMES_SESSION_ID", "stale-session")
observed = {}
class FakeAgent:
session_prompt_tokens = 0
session_completion_tokens = 0
session_total_tokens = 0
def __init__(self, session_id: str):
self.session_id = session_id
def run_conversation(self, user_message, conversation_history, task_id):
from gateway.session_context import get_session_env
from tools.environments.local import _make_run_env
observed["task_id"] = task_id
observed["context_session_id"] = get_session_env("HERMES_SESSION_ID")
observed["context_platform"] = get_session_env("HERMES_SESSION_PLATFORM")
observed["context_session_key"] = get_session_env("HERMES_SESSION_KEY")
observed["child_session_id"] = _make_run_env({}).get("HERMES_SESSION_ID")
return {"final_response": "ok"}
def fake_create_agent(**kwargs):
return FakeAgent(kwargs["session_id"])
monkeypatch.setattr(adapter, "_create_agent", fake_create_agent)
result, usage = await adapter._run_agent(
user_message="hello",
conversation_history=[],
session_id="request-session",
gateway_session_key="request-key",
)
assert result["session_id"] == "request-session"
assert usage == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
assert observed == {
"task_id": "request-session",
"context_session_id": "request-session",
"context_platform": "api_server",
"context_session_key": "request-key",
"child_session_id": "request-session",
}
@pytest.mark.asyncio
async def test_session_crud_and_message_history(adapter, session_db):
app = _create_session_app(adapter)

View file

@ -190,6 +190,17 @@ def test_session_key_falls_back_to_os_environ(monkeypatch):
assert get_session_env("HERMES_SESSION_KEY") == ""
def test_session_id_set_via_contextvars(monkeypatch):
"""set_session_vars should set HERMES_SESSION_ID via contextvars."""
monkeypatch.setenv("HERMES_SESSION_ID", "stale-env-session")
tokens = set_session_vars(session_id="ctx-session-456")
assert get_session_env("HERMES_SESSION_ID") == "ctx-session-456"
clear_session_vars(tokens)
assert get_session_env("HERMES_SESSION_ID") == ""
def test_set_session_env_includes_session_key():
"""_set_session_env should propagate session_key from SessionContext."""
runner = object.__new__(GatewayRunner)