diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 4300f5da57..e0c9cf8460 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -20,6 +20,7 @@ Requires: """ import asyncio +import hashlib import hmac import json import logging @@ -283,6 +284,24 @@ def _make_request_fingerprint(body: Dict[str, Any], keys: List[str]) -> str: return sha256(repr(subset).encode("utf-8")).hexdigest() +def _derive_chat_session_id( + system_prompt: Optional[str], + first_user_message: str, +) -> str: + """Derive a stable session ID from the conversation's first user message. + + OpenAI-compatible frontends (Open WebUI, LibreChat, etc.) send the full + conversation history with every request. The system prompt and first user + message are constant across all turns of the same conversation, so hashing + them produces a deterministic session ID that lets the API server reuse + the same Hermes session (and therefore the same Docker container sandbox + directory) across turns. + """ + seed = f"{system_prompt or ''}\n{first_user_message}" + digest = hashlib.sha256(seed.encode("utf-8")).hexdigest()[:16] + return f"api-{digest}" + + class APIServerAdapter(BasePlatformAdapter): """ OpenAI-compatible HTTP API server adapter. @@ -590,7 +609,16 @@ class APIServerAdapter(BasePlatformAdapter): logger.warning("Failed to load session history for %s: %s", session_id, e) history = [] else: - session_id = str(uuid.uuid4()) + # Derive a stable session ID from the conversation fingerprint so + # that consecutive messages from the same Open WebUI (or similar) + # conversation map to the same Hermes session. The first user + # message + system prompt are constant across all turns. + first_user = "" + for cm in conversation_messages: + if cm.get("role") == "user": + first_user = cm.get("content", "") + break + session_id = _derive_chat_session_id(system_prompt, first_user) # history already set from request body above completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" @@ -1366,6 +1394,7 @@ class APIServerAdapter(BasePlatformAdapter): result = agent.run_conversation( user_message=user_message, conversation_history=conversation_history, + task_id="default", ) usage = { "input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0, @@ -1532,6 +1561,7 @@ class APIServerAdapter(BasePlatformAdapter): r = agent.run_conversation( user_message=user_message, conversation_history=conversation_history, + task_id="default", ) u = { "input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0, diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index 8085a0a6f3..a1117f5ca3 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -26,6 +26,7 @@ from gateway.platforms.api_server import ( APIServerAdapter, ResponseStore, _CORS_HEADERS, + _derive_chat_session_id, check_api_server_requirements, cors_middleware, security_headers_middleware, @@ -658,6 +659,98 @@ class TestChatCompletionsEndpoint: data = await resp.json() assert "Provider failed" in data["error"]["message"] + @pytest.mark.asyncio + async def test_stable_session_id_across_turns(self, adapter): + """Same conversation (same first user message) produces the same session_id.""" + mock_result = {"final_response": "ok", "messages": [], "api_calls": 1} + + app = _create_app(adapter) + session_ids = [] + async with TestClient(TestServer(app)) as cli: + # Turn 1: single user message + with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run: + mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}) + await cli.post( + "/v1/chat/completions", + json={ + "model": "hermes-agent", + "messages": [{"role": "user", "content": "Hello"}], + }, + ) + session_ids.append(mock_run.call_args.kwargs["session_id"]) + + # Turn 2: same first message, conversation grew + with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run: + mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}) + await cli.post( + "/v1/chat/completions", + json={ + "model": "hermes-agent", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ], + }, + ) + session_ids.append(mock_run.call_args.kwargs["session_id"]) + + assert session_ids[0] == session_ids[1], "Session ID should be stable across turns" + assert session_ids[0].startswith("api-"), "Derived session IDs should have api- prefix" + + @pytest.mark.asyncio + async def test_different_conversations_get_different_session_ids(self, adapter): + """Different first messages produce different session_ids.""" + mock_result = {"final_response": "ok", "messages": [], "api_calls": 1} + + app = _create_app(adapter) + session_ids = [] + async with TestClient(TestServer(app)) as cli: + for first_msg in ["Hello", "Goodbye"]: + with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run: + mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}) + await cli.post( + "/v1/chat/completions", + json={ + "model": "hermes-agent", + "messages": [{"role": "user", "content": first_msg}], + }, + ) + session_ids.append(mock_run.call_args.kwargs["session_id"]) + + assert session_ids[0] != session_ids[1] + + +# --------------------------------------------------------------------------- +# _derive_chat_session_id unit tests +# --------------------------------------------------------------------------- + + +class TestDeriveChatSessionId: + def test_deterministic(self): + """Same inputs always produce the same session ID.""" + a = _derive_chat_session_id("sys", "hello") + b = _derive_chat_session_id("sys", "hello") + assert a == b + + def test_prefix(self): + assert _derive_chat_session_id(None, "hi").startswith("api-") + + def test_different_system_prompt(self): + a = _derive_chat_session_id("You are a pirate.", "Hello") + b = _derive_chat_session_id("You are a robot.", "Hello") + assert a != b + + def test_different_first_message(self): + a = _derive_chat_session_id(None, "Hello") + b = _derive_chat_session_id(None, "Goodbye") + assert a != b + + def test_none_system_prompt(self): + """None system prompt doesn't crash.""" + sid = _derive_chat_session_id(None, "test") + assert isinstance(sid, str) and len(sid) > 4 + # --------------------------------------------------------------------------- # /v1/responses endpoint