mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
chore: add tests
This commit is contained in:
parent
f035796381
commit
cc5ca0fe42
1 changed files with 365 additions and 0 deletions
365
tests/test_serve.py
Normal file
365
tests/test_serve.py
Normal file
|
|
@ -0,0 +1,365 @@
|
|||
"""Tests for the serve layer (serve.py) and event_queue integration.
|
||||
|
||||
Covers:
|
||||
- _emit_event: queue attached, no queue, queue full
|
||||
- extra_tags merging in _build_api_kwargs for Nous API
|
||||
- FastAPI /health endpoint
|
||||
- FastAPI /v1/agent/stream SSE endpoint (mocked AIAgent)
|
||||
|
||||
Run with: python -m pytest tests/test_serve.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import queue
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": n,
|
||||
"description": f"{n} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent_no_queue():
|
||||
"""AIAgent without an event_queue (CLI/gateway mode)."""
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
a.client = MagicMock()
|
||||
return a
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent_with_queue():
|
||||
"""AIAgent with an event_queue attached (serve mode)."""
|
||||
eq = queue.Queue(maxsize=128)
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
event_queue=eq,
|
||||
)
|
||||
a.client = MagicMock()
|
||||
return a, eq
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def nous_agent():
|
||||
"""AIAgent pointing at a Nous inference URL with extra_tags."""
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
a = AIAgent(
|
||||
base_url="https://stg-inference-api.nousresearch.com/v1",
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
extra_tags=["user=test-user", "tier=paid"],
|
||||
)
|
||||
a.client = MagicMock()
|
||||
return a
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Group 1: _emit_event
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEmitEvent:
|
||||
def test_no_queue_is_noop(self, agent_no_queue):
|
||||
"""_emit_event should silently do nothing when no queue is attached."""
|
||||
agent_no_queue._emit_event({"type": "text", "text": "hello"})
|
||||
|
||||
def test_event_pushed_to_queue(self, agent_with_queue):
|
||||
agent, eq = agent_with_queue
|
||||
event = {"type": "text", "text": "hello"}
|
||||
agent._emit_event(event)
|
||||
assert not eq.empty()
|
||||
assert eq.get_nowait() == event
|
||||
|
||||
def test_multiple_events_ordered(self, agent_with_queue):
|
||||
agent, eq = agent_with_queue
|
||||
events = [
|
||||
{"type": "tool-call", "name": "terminal", "status": "calling"},
|
||||
{"type": "tool-result", "name": "terminal", "status": "complete"},
|
||||
{"type": "text", "text": "done"},
|
||||
{"type": "done"},
|
||||
]
|
||||
for e in events:
|
||||
agent._emit_event(e)
|
||||
received = []
|
||||
while not eq.empty():
|
||||
received.append(eq.get_nowait())
|
||||
assert received == events
|
||||
|
||||
def test_full_queue_does_not_raise(self):
|
||||
"""When the queue is full, _emit_event should silently drop the event."""
|
||||
eq = queue.Queue(maxsize=1)
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
event_queue=eq,
|
||||
)
|
||||
eq.put({"type": "filler"})
|
||||
assert eq.full()
|
||||
a._emit_event({"type": "text", "text": "overflow"})
|
||||
assert eq.qsize() == 1
|
||||
assert eq.get_nowait()["type"] == "filler"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Group 2: extra_tags in _build_api_kwargs
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExtraTags:
|
||||
def test_no_tags_on_openrouter(self, agent_no_queue):
|
||||
"""OpenRouter requests should NOT include Nous product tags."""
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent_no_queue._build_api_kwargs(messages)
|
||||
extra = kwargs.get("extra_body", {})
|
||||
assert "tags" not in extra
|
||||
|
||||
def test_default_product_tag_on_nous(self, nous_agent):
|
||||
"""Nous API requests should always include product=hermes-agent."""
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = nous_agent._build_api_kwargs(messages)
|
||||
tags = kwargs["extra_body"]["tags"]
|
||||
assert "product=hermes-agent" in tags
|
||||
|
||||
def test_extra_tags_merged(self, nous_agent):
|
||||
"""Caller-supplied tags should appear alongside the product tag."""
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = nous_agent._build_api_kwargs(messages)
|
||||
tags = kwargs["extra_body"]["tags"]
|
||||
assert "user=test-user" in tags
|
||||
assert "tier=paid" in tags
|
||||
assert "product=hermes-agent" in tags
|
||||
|
||||
def test_extra_tags_empty_by_default(self, agent_no_queue):
|
||||
"""Agent without extra_tags should have an empty list."""
|
||||
assert agent_no_queue._extra_tags == []
|
||||
|
||||
def test_extra_tags_does_not_mutate_original(self, nous_agent):
|
||||
"""Calling _build_api_kwargs should not grow _extra_tags each time."""
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
nous_agent._build_api_kwargs(messages)
|
||||
nous_agent._build_api_kwargs(messages)
|
||||
assert nous_agent._extra_tags.count("product=hermes-agent") == 0
|
||||
assert len(nous_agent._extra_tags) == 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Group 3: FastAPI endpoints (serve.py)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fastapi_app():
|
||||
"""Import the FastAPI app from serve.py."""
|
||||
from serve import app
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestHealthEndpoint:
|
||||
async def test_health_returns_ok(self, fastapi_app):
|
||||
transport = ASGITransport(app=fastapi_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentStreamEndpoint:
|
||||
async def test_stream_returns_sse_events(self, fastapi_app):
|
||||
"""Mock AIAgent to emit known events and verify SSE output."""
|
||||
mock_result = {
|
||||
"final_response": "Hello!",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
"completed": True,
|
||||
}
|
||||
|
||||
def fake_run_conversation(user_message, conversation_history=None):
|
||||
agent_instance = fake_init.agent_ref
|
||||
if agent_instance and agent_instance.event_queue:
|
||||
eq = agent_instance.event_queue
|
||||
eq.put({"type": "tool-call", "name": "terminal", "args": "echo hi", "status": "calling"})
|
||||
eq.put({"type": "tool-result", "name": "terminal", "output": "hi", "status": "complete", "duration": 0.1})
|
||||
eq.put({"type": "text", "text": "Hello!"})
|
||||
eq.put({"type": "done"})
|
||||
return mock_result
|
||||
|
||||
class fake_init:
|
||||
agent_ref = None
|
||||
|
||||
original_init = AIAgent.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
self.run_conversation = fake_run_conversation
|
||||
fake_init.agent_ref = self
|
||||
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
patch.object(AIAgent, "__init__", patched_init),
|
||||
):
|
||||
transport = ASGITransport(app=fastapi_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.post(
|
||||
"/v1/agent/stream",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say hello"}],
|
||||
"model": "test/model",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert "text/event-stream" in resp.headers["content-type"]
|
||||
|
||||
lines = resp.text.strip().split("\n")
|
||||
events = []
|
||||
for line in lines:
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
|
||||
types = [e["type"] for e in events]
|
||||
assert "tool-call" in types
|
||||
assert "tool-result" in types
|
||||
assert "text" in types
|
||||
assert types[-1] == "done"
|
||||
|
||||
text_event = next(e for e in events if e["type"] == "text")
|
||||
assert text_event["text"] == "Hello!"
|
||||
|
||||
tool_call = next(e for e in events if e["type"] == "tool-call")
|
||||
assert tool_call["name"] == "terminal"
|
||||
|
||||
async def test_stream_error_propagated(self, fastapi_app):
|
||||
"""When AIAgent raises, an error event should be streamed."""
|
||||
original_init = AIAgent.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
def exploding_run(user_message, conversation_history=None):
|
||||
raise RuntimeError("kaboom")
|
||||
|
||||
self.run_conversation = exploding_run
|
||||
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
patch.object(AIAgent, "__init__", patched_init),
|
||||
):
|
||||
transport = ASGITransport(app=fastapi_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.post(
|
||||
"/v1/agent/stream",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "fail"}],
|
||||
"model": "test/model",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
events = []
|
||||
for line in resp.text.strip().split("\n"):
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
|
||||
error_events = [e for e in events if e["type"] == "error"]
|
||||
assert len(error_events) >= 1
|
||||
assert "kaboom" in error_events[0]["error"]
|
||||
assert events[-1]["type"] == "done"
|
||||
|
||||
async def test_stream_passes_base_url_and_tags(self, fastapi_app):
|
||||
"""Verify base_url, api_key, and tags from the request body reach AIAgent."""
|
||||
captured = {}
|
||||
original_init = AIAgent.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
captured["base_url"] = kwargs.get("base_url")
|
||||
captured["api_key"] = kwargs.get("api_key")
|
||||
captured["extra_tags"] = kwargs.get("extra_tags")
|
||||
original_init(self, *args, **kwargs)
|
||||
self.run_conversation = lambda **kw: (
|
||||
self.event_queue.put({"type": "text", "text": "ok"}) if self.event_queue else None,
|
||||
self.event_queue.put({"type": "done"}) if self.event_queue else None,
|
||||
{"final_response": "ok", "messages": [], "api_calls": 1, "completed": True},
|
||||
)[-1]
|
||||
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
patch.object(AIAgent, "__init__", patched_init),
|
||||
):
|
||||
transport = ASGITransport(app=fastapi_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
await client.post(
|
||||
"/v1/agent/stream",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"model": "test/model",
|
||||
"base_url": "https://my-api.example.com/v1",
|
||||
"api_key": "sk-test-key",
|
||||
"tags": ["user=alice", "tier=free"],
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
assert captured["base_url"] == "https://my-api.example.com/v1"
|
||||
assert captured["api_key"] == "sk-test-key"
|
||||
assert captured["extra_tags"] == ["user=alice", "tier=free"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue