chore: add tests

This commit is contained in:
Austin Pickett 2026-03-05 19:01:24 -05:00
parent f035796381
commit cc5ca0fe42

365
tests/test_serve.py Normal file
View file

@ -0,0 +1,365 @@
"""Tests for the serve layer (serve.py) and event_queue integration.
Covers:
- _emit_event: queue attached, no queue, queue full
- extra_tags merging in _build_api_kwargs for Nous API
- FastAPI /health endpoint
- FastAPI /v1/agent/stream SSE endpoint (mocked AIAgent)
Run with: python -m pytest tests/test_serve.py -v
"""
import json
import queue
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from run_agent import AIAgent
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_tool_defs(*names: str) -> list:
return [
{
"type": "function",
"function": {
"name": n,
"description": f"{n} tool",
"parameters": {"type": "object", "properties": {}},
},
}
for n in names
]
@pytest.fixture()
def agent_no_queue():
"""AIAgent without an event_queue (CLI/gateway mode)."""
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
a = AIAgent(
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
a.client = MagicMock()
return a
@pytest.fixture()
def agent_with_queue():
"""AIAgent with an event_queue attached (serve mode)."""
eq = queue.Queue(maxsize=128)
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
a = AIAgent(
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
event_queue=eq,
)
a.client = MagicMock()
return a, eq
@pytest.fixture()
def nous_agent():
"""AIAgent pointing at a Nous inference URL with extra_tags."""
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
a = AIAgent(
base_url="https://stg-inference-api.nousresearch.com/v1",
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
extra_tags=["user=test-user", "tier=paid"],
)
a.client = MagicMock()
return a
# ===========================================================================
# Group 1: _emit_event
# ===========================================================================
class TestEmitEvent:
def test_no_queue_is_noop(self, agent_no_queue):
"""_emit_event should silently do nothing when no queue is attached."""
agent_no_queue._emit_event({"type": "text", "text": "hello"})
def test_event_pushed_to_queue(self, agent_with_queue):
agent, eq = agent_with_queue
event = {"type": "text", "text": "hello"}
agent._emit_event(event)
assert not eq.empty()
assert eq.get_nowait() == event
def test_multiple_events_ordered(self, agent_with_queue):
agent, eq = agent_with_queue
events = [
{"type": "tool-call", "name": "terminal", "status": "calling"},
{"type": "tool-result", "name": "terminal", "status": "complete"},
{"type": "text", "text": "done"},
{"type": "done"},
]
for e in events:
agent._emit_event(e)
received = []
while not eq.empty():
received.append(eq.get_nowait())
assert received == events
def test_full_queue_does_not_raise(self):
"""When the queue is full, _emit_event should silently drop the event."""
eq = queue.Queue(maxsize=1)
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
a = AIAgent(
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
event_queue=eq,
)
eq.put({"type": "filler"})
assert eq.full()
a._emit_event({"type": "text", "text": "overflow"})
assert eq.qsize() == 1
assert eq.get_nowait()["type"] == "filler"
# ===========================================================================
# Group 2: extra_tags in _build_api_kwargs
# ===========================================================================
class TestExtraTags:
def test_no_tags_on_openrouter(self, agent_no_queue):
"""OpenRouter requests should NOT include Nous product tags."""
messages = [{"role": "user", "content": "hi"}]
kwargs = agent_no_queue._build_api_kwargs(messages)
extra = kwargs.get("extra_body", {})
assert "tags" not in extra
def test_default_product_tag_on_nous(self, nous_agent):
"""Nous API requests should always include product=hermes-agent."""
messages = [{"role": "user", "content": "hi"}]
kwargs = nous_agent._build_api_kwargs(messages)
tags = kwargs["extra_body"]["tags"]
assert "product=hermes-agent" in tags
def test_extra_tags_merged(self, nous_agent):
"""Caller-supplied tags should appear alongside the product tag."""
messages = [{"role": "user", "content": "hi"}]
kwargs = nous_agent._build_api_kwargs(messages)
tags = kwargs["extra_body"]["tags"]
assert "user=test-user" in tags
assert "tier=paid" in tags
assert "product=hermes-agent" in tags
def test_extra_tags_empty_by_default(self, agent_no_queue):
"""Agent without extra_tags should have an empty list."""
assert agent_no_queue._extra_tags == []
def test_extra_tags_does_not_mutate_original(self, nous_agent):
"""Calling _build_api_kwargs should not grow _extra_tags each time."""
messages = [{"role": "user", "content": "hi"}]
nous_agent._build_api_kwargs(messages)
nous_agent._build_api_kwargs(messages)
assert nous_agent._extra_tags.count("product=hermes-agent") == 0
assert len(nous_agent._extra_tags) == 2
# ===========================================================================
# Group 3: FastAPI endpoints (serve.py)
# ===========================================================================
@pytest.fixture()
def fastapi_app():
"""Import the FastAPI app from serve.py."""
from serve import app
return app
@pytest.mark.asyncio
class TestHealthEndpoint:
async def test_health_returns_ok(self, fastapi_app):
transport = ASGITransport(app=fastapi_app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get("/health")
assert resp.status_code == 200
data = resp.json()
assert data["status"] == "ok"
@pytest.mark.asyncio
class TestAgentStreamEndpoint:
async def test_stream_returns_sse_events(self, fastapi_app):
"""Mock AIAgent to emit known events and verify SSE output."""
mock_result = {
"final_response": "Hello!",
"messages": [],
"api_calls": 1,
"completed": True,
}
def fake_run_conversation(user_message, conversation_history=None):
agent_instance = fake_init.agent_ref
if agent_instance and agent_instance.event_queue:
eq = agent_instance.event_queue
eq.put({"type": "tool-call", "name": "terminal", "args": "echo hi", "status": "calling"})
eq.put({"type": "tool-result", "name": "terminal", "output": "hi", "status": "complete", "duration": 0.1})
eq.put({"type": "text", "text": "Hello!"})
eq.put({"type": "done"})
return mock_result
class fake_init:
agent_ref = None
original_init = AIAgent.__init__
def patched_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
self.run_conversation = fake_run_conversation
fake_init.agent_ref = self
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
patch.object(AIAgent, "__init__", patched_init),
):
transport = ASGITransport(app=fastapi_app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/v1/agent/stream",
json={
"messages": [{"role": "user", "content": "Say hello"}],
"model": "test/model",
},
timeout=30,
)
assert resp.status_code == 200
assert "text/event-stream" in resp.headers["content-type"]
lines = resp.text.strip().split("\n")
events = []
for line in lines:
if line.startswith("data: "):
events.append(json.loads(line[6:]))
types = [e["type"] for e in events]
assert "tool-call" in types
assert "tool-result" in types
assert "text" in types
assert types[-1] == "done"
text_event = next(e for e in events if e["type"] == "text")
assert text_event["text"] == "Hello!"
tool_call = next(e for e in events if e["type"] == "tool-call")
assert tool_call["name"] == "terminal"
async def test_stream_error_propagated(self, fastapi_app):
"""When AIAgent raises, an error event should be streamed."""
original_init = AIAgent.__init__
def patched_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
def exploding_run(user_message, conversation_history=None):
raise RuntimeError("kaboom")
self.run_conversation = exploding_run
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
patch.object(AIAgent, "__init__", patched_init),
):
transport = ASGITransport(app=fastapi_app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/v1/agent/stream",
json={
"messages": [{"role": "user", "content": "fail"}],
"model": "test/model",
},
timeout=30,
)
assert resp.status_code == 200
events = []
for line in resp.text.strip().split("\n"):
if line.startswith("data: "):
events.append(json.loads(line[6:]))
error_events = [e for e in events if e["type"] == "error"]
assert len(error_events) >= 1
assert "kaboom" in error_events[0]["error"]
assert events[-1]["type"] == "done"
async def test_stream_passes_base_url_and_tags(self, fastapi_app):
"""Verify base_url, api_key, and tags from the request body reach AIAgent."""
captured = {}
original_init = AIAgent.__init__
def patched_init(self, *args, **kwargs):
captured["base_url"] = kwargs.get("base_url")
captured["api_key"] = kwargs.get("api_key")
captured["extra_tags"] = kwargs.get("extra_tags")
original_init(self, *args, **kwargs)
self.run_conversation = lambda **kw: (
self.event_queue.put({"type": "text", "text": "ok"}) if self.event_queue else None,
self.event_queue.put({"type": "done"}) if self.event_queue else None,
{"final_response": "ok", "messages": [], "api_calls": 1, "completed": True},
)[-1]
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
patch.object(AIAgent, "__init__", patched_init),
):
transport = ASGITransport(app=fastapi_app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
await client.post(
"/v1/agent/stream",
json={
"messages": [{"role": "user", "content": "hi"}],
"model": "test/model",
"base_url": "https://my-api.example.com/v1",
"api_key": "sk-test-key",
"tags": ["user=alice", "tier=free"],
},
timeout=30,
)
assert captured["base_url"] == "https://my-api.example.com/v1"
assert captured["api_key"] == "sk-test-key"
assert captured["extra_tags"] == ["user=alice", "tier=free"]