mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-30 06:41:51 +00:00
240 lines
10 KiB
Python
240 lines
10 KiB
Python
"""Focused tests for API server session-control endpoints."""
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from aiohttp import web
|
|
from aiohttp.test_utils import TestClient, TestServer
|
|
|
|
from gateway.config import PlatformConfig
|
|
from gateway.platforms.api_server import APIServerAdapter
|
|
from hermes_state import SessionDB
|
|
|
|
|
|
@pytest.fixture
def session_db(tmp_path):
    """Yield a ``SessionDB`` stored under the per-test temp directory.

    After the test finishes, the database is closed if the object exposes
    a callable ``close`` attribute.
    """
    store = SessionDB(tmp_path / "state.db")
    try:
        yield store
    finally:
        # Look the closer up defensively: only call it when it exists and is callable.
        closer = getattr(store, "close", None)
        if callable(closer):
            closer()
|
|
|
|
|
|
@pytest.fixture
def adapter(session_db):
    """Return an enabled ``APIServerAdapter`` (no ``extra`` settings) wired to the test session DB."""
    config = PlatformConfig(enabled=True)
    server = APIServerAdapter(config)
    # Swap in the fixture-provided database so tests can inspect it directly.
    server._session_db = session_db
    return server
|
|
|
|
|
|
@pytest.fixture
def auth_adapter(session_db):
    """Return an enabled ``APIServerAdapter`` configured with ``extra={"key": "sk-test"}``.

    The adapter is wired to the test session DB; tests using it send
    ``Authorization: Bearer sk-test`` to authenticate.
    """
    config = PlatformConfig(enabled=True, extra={"key": "sk-test"})
    server = APIServerAdapter(config)
    # Swap in the fixture-provided database so tests can inspect it directly.
    server._session_db = session_db
    return server
|
|
|
|
|
|
def _create_session_app(adapter: APIServerAdapter) -> web.Application:
    """Build an aiohttp application exposing the adapter's session-control handlers.

    Routes are registered in a fixed order: capabilities first, then the
    session collection, the individual session resource, and finally the
    per-session message, fork, chat and streaming-chat endpoints.
    """
    app = web.Application()
    router = app.router
    # (registration method, path, handler) triples, applied in order.
    route_table = (
        (router.add_get, "/v1/capabilities", adapter._handle_capabilities),
        (router.add_get, "/api/sessions", adapter._handle_list_sessions),
        (router.add_post, "/api/sessions", adapter._handle_create_session),
        (router.add_get, "/api/sessions/{session_id}", adapter._handle_get_session),
        (router.add_patch, "/api/sessions/{session_id}", adapter._handle_patch_session),
        (router.add_delete, "/api/sessions/{session_id}", adapter._handle_delete_session),
        (router.add_get, "/api/sessions/{session_id}/messages", adapter._handle_session_messages),
        (router.add_post, "/api/sessions/{session_id}/fork", adapter._handle_fork_session),
        (router.add_post, "/api/sessions/{session_id}/chat", adapter._handle_session_chat),
        (router.add_post, "/api/sessions/{session_id}/chat/stream", adapter._handle_session_chat_stream),
    )
    for register, path, handler in route_table:
        register(path, handler)
    return app
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_capabilities_advertises_session_control_surface(adapter):
    """GET /v1/capabilities reports the session feature flags and endpoint descriptors."""
    async with TestClient(TestServer(_create_session_app(adapter))) as cli:
        resp = await cli.get("/v1/capabilities")
        assert resp.status == 200
        data = await resp.json()

    features = data["features"]
    # Flags the test expects to be switched on.
    for flag in ("session_resources", "session_chat", "session_chat_streaming", "session_fork"):
        assert features[flag] is True
    # Flags the test expects to be switched off.
    for flag in ("admin_config_rw", "memory_write_api", "skills_api", "realtime_voice"):
        assert features[flag] is False

    endpoints = data["endpoints"]
    assert endpoints["sessions"] == {"method": "GET", "path": "/api/sessions"}
    expected_stream = {
        "method": "POST",
        "path": "/api/sessions/{session_id}/chat/stream",
    }
    assert endpoints["session_chat_stream"] == expected_stream
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_crud_and_message_history(adapter, session_db):
    """Walk one session through create, list, get, message history, rename and delete over HTTP."""
    async with TestClient(TestServer(_create_session_app(adapter))) as cli:
        # Create
        new_session = {"title": "Mobile chat", "model": "test-model"}
        create_resp = await cli.post("/api/sessions", json=new_session)
        assert create_resp.status == 201
        created = await create_resp.json()
        session_id = created["session"]["id"]
        assert created["object"] == "hermes.session"
        assert created["session"]["title"] == "Mobile chat"

        # Seed two messages directly through the database.
        for role, text in (("user", "hello from phone"), ("assistant", "hello from hermes")):
            session_db.append_message(session_id, role, text)

        # List
        list_resp = await cli.get("/api/sessions?limit=10&offset=0")
        assert list_resp.status == 200
        listed = await list_resp.json()
        assert listed["object"] == "list"
        listed_ids = [entry["id"] for entry in listed["data"]]
        assert listed_ids == [session_id]
        assert listed["data"][0]["message_count"] == 2

        # Get
        get_resp = await cli.get(f"/api/sessions/{session_id}")
        assert get_resp.status == 200
        fetched = (await get_resp.json())["session"]
        assert fetched["id"] == session_id
        assert fetched["message_count"] == 2

        # Message history
        messages_resp = await cli.get(f"/api/sessions/{session_id}/messages")
        assert messages_resp.status == 200
        messages = await messages_resp.json()
        assert messages["object"] == "list"
        roles = [entry["role"] for entry in messages["data"]]
        assert roles == ["user", "assistant"]
        assert messages["data"][0]["content"] == "hello from phone"

        # Rename
        patch_resp = await cli.patch(f"/api/sessions/{session_id}", json={"title": "Renamed"})
        assert patch_resp.status == 200
        patched = await patch_resp.json()
        assert patched["session"]["title"] == "Renamed"

        # Delete
        delete_resp = await cli.delete(f"/api/sessions/{session_id}")
        assert delete_resp.status == 200
        deleted = await delete_resp.json()
        expected_deleted = {"object": "hermes.session.deleted", "id": session_id, "deleted": True}
        assert deleted == expected_deleted
        assert session_db.get_session(session_id) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_fork_uses_current_sessiondb_branch_primitives(adapter, session_db):
    """Forking a session over HTTP yields a child with copied messages and marks the source as branched."""
    source_id = session_db.create_session("source-session", "api_server", model="test-model")
    session_db.set_session_title(source_id, "Original")
    for role, text in (("user", "first path"), ("assistant", "answer")):
        session_db.append_message(source_id, role, text)

    async with TestClient(TestServer(_create_session_app(adapter))) as cli:
        resp = await cli.post(f"/api/sessions/{source_id}/fork", json={"title": "Alternative"})
        assert resp.status == 201
        payload = await resp.json()

    fork = payload["session"]
    assert payload["object"] == "hermes.session"
    assert fork["id"] != source_id
    assert fork["parent_session_id"] == source_id
    assert fork["title"] == "Alternative"

    # The fork must carry the source's messages, in order.
    forked_contents = [entry["content"] for entry in session_db.get_messages(fork["id"])]
    assert forked_contents == ["first path", "answer"]

    source_row = session_db.get_session(source_id)
    assert source_row["end_reason"] == "branched"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_chat_loads_history_and_preserves_session_headers(auth_adapter, session_db):
    """Session chat passes stored history to the agent run and echoes the session headers back."""
    session_id = session_db.create_session("chat-session", "api_server")
    session_db.set_session_title(session_id, "Chat")
    for role, text in (("user", "earlier"), ("assistant", "prior answer")):
        session_db.append_message(session_id, role, text)

    # Stand-in for the agent run: returns a (result, usage) pair.
    agent_result = {"final_response": "fresh answer", "session_id": session_id}
    usage = {"total_tokens": 3}
    mock_run = AsyncMock(return_value=(agent_result, usage))

    app = _create_session_app(auth_adapter)
    with patch.object(auth_adapter, "_run_agent", mock_run):
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                f"/api/sessions/{session_id}/chat",
                json={"message": "next", "system_message": "stay focused"},
                headers={"Authorization": "Bearer sk-test", "X-Hermes-Session-Key": "client-42"},
            )
            assert resp.status == 200
            payload = await resp.json()

    # Response headers and body.
    assert resp.headers["X-Hermes-Session-Id"] == session_id
    assert resp.headers["X-Hermes-Session-Key"] == "client-42"
    assert payload["object"] == "hermes.session.chat.completion"
    assert payload["session_id"] == session_id
    reply = payload["message"]
    assert reply["role"] == "assistant"
    assert reply["content"] == "fresh answer"

    # Keyword arguments the handler forwarded to the agent run.
    mock_run.assert_awaited_once()
    forwarded = mock_run.call_args.kwargs
    assert forwarded["session_id"] == session_id
    assert forwarded["gateway_session_key"] == "client-42"
    assert forwarded["ephemeral_system_prompt"] == "stay focused"
    expected_history = [
        {"role": "user", "content": "earlier"},
        {"role": "assistant", "content": "prior answer"},
    ]
    assert forwarded["conversation_history"] == expected_history
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_chat_stream_emits_lifecycle_events_and_keepalive_safe_shape(adapter, session_db):
    """Streaming chat responds as an event stream containing the expected lifecycle event names."""
    session_id = session_db.create_session("stream-session", "api_server")
    session_db.set_session_title(session_id, "Stream")

    async def fake_run(**kwargs):
        # Emit two text deltas, then one tool-progress notification, then finish.
        emit_delta = kwargs["stream_delta_callback"]
        emit_delta("Hello")
        emit_delta(" world")
        report_progress = kwargs["tool_progress_callback"]
        report_progress("reasoning.available", tool_name="_thinking", preview="thinking")
        result = {"final_response": "Hello world", "session_id": session_id}
        usage = {"total_tokens": 2}
        return result, usage

    app = _create_session_app(adapter)
    with patch.object(adapter, "_run_agent", side_effect=fake_run):
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                f"/api/sessions/{session_id}/chat/stream",
                json={"message": "stream please"},
            )
            assert resp.status == 200
            assert resp.headers["Content-Type"].startswith("text/event-stream")
            body = await resp.text()

    # Every marker below must appear somewhere in the streamed body.
    expected_markers = (
        "event: run.started",
        "event: message.started",
        "event: assistant.delta",
        "Hello world",
        "event: tool.progress",
        "event: assistant.completed",
        "event: run.completed",
        "event: done",
    )
    for marker in expected_markers:
        assert marker in body
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_endpoints_require_auth_when_key_configured(auth_adapter):
    """Listing sessions returns 401 without credentials and 200 with the configured bearer key."""
    async with TestClient(TestServer(_create_session_app(auth_adapter))) as cli:
        # No Authorization header: rejected.
        denied = await cli.get("/api/sessions")
        assert denied.status == 401
        denied_body = await denied.json()
        assert denied_body["error"]["code"] == "invalid_api_key"

        # Matching bearer token: accepted, with an empty listing.
        bearer = {"Authorization": "Bearer sk-test"}
        ok = await cli.get("/api/sessions", headers=bearer)
        assert ok.status == 200
        listing = await ok.json()
        assert listing["object"] == "list"
        assert listing["data"] == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_header_rejected_without_api_key(adapter, session_db):
    """Sending X-Hermes-Session-Key to an adapter built without a key yields a 403 with an explanatory message."""
    session_id = session_db.create_session("unsafe-session", "api_server")
    chat_url = f"/api/sessions/{session_id}/chat"

    async with TestClient(TestServer(_create_session_app(adapter))) as cli:
        resp = await cli.post(
            chat_url,
            json={"message": "hello"},
            headers={"X-Hermes-Session-Key": "client-42"},
        )
        assert resp.status == 403
        error = (await resp.json())["error"]
        assert "X-Hermes-Session-Key requires API key" in error["message"]
|