mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
fix(api-server): bind request session context for tools
This commit is contained in:
parent
52ae9d9f02
commit
b23184cad4
4 changed files with 101 additions and 28 deletions
|
|
@ -3510,35 +3510,46 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _run():
    """Create the agent and run one conversation turn (blocking).

    This closure is handed to ``loop.run_in_executor``. Before the
    agent is created, the request's session identity is bound with
    ``set_session_vars``; the binding is always cleared in ``finally``
    so it is removed whether the run succeeds or raises.

    Returns:
        A ``(result, usage)`` tuple: ``result`` is whatever
        ``agent.run_conversation`` returned (with ``"session_id"``
        added when the agent exposes a non-empty string id), and
        ``usage`` is a dict of input/output/total token counts.
    """
    # Function-scope import, exactly as the commit introduces it.
    from gateway.session_context import clear_session_vars, set_session_vars

    tokens = set_session_vars(
        platform="api_server",
        chat_id=session_id or "",
        # Prefer the gateway key; fall back to the session id, then "".
        session_key=gateway_session_key or session_id or "",
        session_id=session_id or "",
    )
    try:
        agent = self._create_agent(
            ephemeral_system_prompt=ephemeral_system_prompt,
            session_id=session_id,
            stream_delta_callback=stream_delta_callback,
            tool_progress_callback=tool_progress_callback,
            tool_start_callback=tool_start_callback,
            tool_complete_callback=tool_complete_callback,
            gateway_session_key=gateway_session_key,
        )
        # Publish the agent through the caller-supplied one-slot holder.
        if agent_ref is not None:
            agent_ref[0] = agent
        # Without a session id, use a fresh UUID as the task id.
        effective_task_id = session_id or str(uuid.uuid4())
        result = agent.run_conversation(
            user_message=user_message,
            conversation_history=conversation_history,
            task_id=effective_task_id,
        )
        usage = {
            "input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0,
            "output_tokens": getattr(agent, "session_completion_tokens", 0) or 0,
            "total_tokens": getattr(agent, "session_total_tokens", 0) or 0,
        }
        # Include the effective session ID in the result so callers
        # (e.g. X-Hermes-Session-Id header) can track compression-
        # triggered session rotations. (#16938)
        _eff_sid = getattr(agent, "session_id", session_id)
        if isinstance(_eff_sid, str) and _eff_sid:
            result["session_id"] = _eff_sid
        return result, usage
    finally:
        clear_session_vars(tokens)
|
||||
|
||||
return await loop.run_in_executor(None, _run)
|
||||
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ def set_session_vars(
|
|||
user_id: str = "",
|
||||
user_name: str = "",
|
||||
session_key: str = "",
|
||||
session_id: str = "",
|
||||
message_id: str = "",
|
||||
cwd: str = "",
|
||||
) -> list:
|
||||
|
|
@ -127,6 +128,7 @@ def set_session_vars(
|
|||
_SESSION_USER_ID.set(user_id),
|
||||
_SESSION_USER_NAME.set(user_name),
|
||||
_SESSION_KEY.set(session_key),
|
||||
_SESSION_ID.set(session_id),
|
||||
_SESSION_MESSAGE_ID.set(message_id),
|
||||
]
|
||||
try:
|
||||
|
|
@ -157,6 +159,7 @@ def clear_session_vars(tokens: list) -> None:
|
|||
_SESSION_USER_ID,
|
||||
_SESSION_USER_NAME,
|
||||
_SESSION_KEY,
|
||||
_SESSION_ID,
|
||||
_SESSION_MESSAGE_ID,
|
||||
):
|
||||
var.set("")
|
||||
|
|
|
|||
|
|
@ -75,6 +75,54 @@ async def test_capabilities_advertises_session_control_surface(adapter):
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_agent_binds_api_session_context_for_tool_env(adapter, monkeypatch):
    """``_run_agent`` must expose the request's session to tools and child env.

    A stale ``HERMES_SESSION_ID`` is planted in the process environment so
    the expectations below only hold when the per-request values are the
    ones observed while the agent is running.
    """
    monkeypatch.setenv("HERMES_SESSION_ID", "stale-session")
    seen = {}

    class StubAgent:
        # Token counters; the test expects these zeros back in ``usage``.
        session_prompt_tokens = 0
        session_completion_tokens = 0
        session_total_tokens = 0

        def __init__(self, session_id: str):
            self.session_id = session_id

        def run_conversation(self, user_message, conversation_history, task_id):
            from gateway.session_context import get_session_env
            from tools.environments.local import _make_run_env

            # Record what is visible from inside the agent run: the task id,
            # the session values, and the env built by ``_make_run_env``.
            seen.update(
                task_id=task_id,
                context_session_id=get_session_env("HERMES_SESSION_ID"),
                context_platform=get_session_env("HERMES_SESSION_PLATFORM"),
                context_session_key=get_session_env("HERMES_SESSION_KEY"),
                child_session_id=_make_run_env({}).get("HERMES_SESSION_ID"),
            )
            return {"final_response": "ok"}

    monkeypatch.setattr(
        adapter,
        "_create_agent",
        lambda **kwargs: StubAgent(kwargs["session_id"]),
    )

    result, usage = await adapter._run_agent(
        user_message="hello",
        conversation_history=[],
        session_id="request-session",
        gateway_session_key="request-key",
    )

    expected_seen = {
        "task_id": "request-session",
        "context_session_id": "request-session",
        "context_platform": "api_server",
        "context_session_key": "request-key",
        "child_session_id": "request-session",
    }
    assert result["session_id"] == "request-session"
    assert usage == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
    assert seen == expected_seen
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_crud_and_message_history(adapter, session_db):
|
||||
app = _create_session_app(adapter)
|
||||
|
|
|
|||
|
|
@ -190,6 +190,17 @@ def test_session_key_falls_back_to_os_environ(monkeypatch):
|
|||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_id_set_via_contextvars(monkeypatch):
    """set_session_vars should set HERMES_SESSION_ID via contextvars.

    A stale value is placed in ``os.environ`` first, so the lookups must
    return the bound value (and then ``""`` after clearing) rather than
    the environment's.
    """
    monkeypatch.setenv("HERMES_SESSION_ID", "stale-env-session")

    tokens = set_session_vars(session_id="ctx-session-456")
    try:
        assert get_session_env("HERMES_SESSION_ID") == "ctx-session-456"
    finally:
        # Always clear, so a failed assertion cannot leak the bound
        # session id into tests that run later.
        clear_session_vars(tokens)

    assert get_session_env("HERMES_SESSION_ID") == ""
|
||||
|
||||
|
||||
def test_set_session_env_includes_session_key():
|
||||
"""_set_session_env should propagate session_key from SessionContext."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue