mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-07 02:51:50 +00:00
feat: add streaming token output and simplify CLI to plain stdout
This commit is contained in:
parent
c754135965
commit
4d6c90c6d0
7 changed files with 599 additions and 982 deletions
256
tests/test_streaming.py
Normal file
256
tests/test_streaming.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
"""Tests for streaming token output — accumulator shape, callback order, fallback."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_defs(*names):
|
||||
return [
|
||||
{"type": "function", "function": {"name": n, "description": f"{n}", "parameters": {"type": "object", "properties": {}}}}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent():
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
cb = MagicMock()
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=cb,
|
||||
)
|
||||
a.client = MagicMock()
|
||||
a._stream_cb = cb
|
||||
return a
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — fake streaming chunks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _chunk(content=None, tool_call_delta=None, finish_reason=None, usage=None, model=None):
|
||||
delta = SimpleNamespace(content=content, tool_calls=tool_call_delta)
|
||||
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason)
|
||||
c = SimpleNamespace(choices=[choice])
|
||||
if usage is not None:
|
||||
c.usage = SimpleNamespace(**usage)
|
||||
if model:
|
||||
c.model = model
|
||||
return c
|
||||
|
||||
|
||||
def _usage_chunk(**kw):
|
||||
c = SimpleNamespace(choices=[], usage=SimpleNamespace(**kw))
|
||||
return c
|
||||
|
||||
|
||||
def _tc_delta(index, id=None, name=None, arguments=None, type=None):
|
||||
fn = SimpleNamespace(name=name, arguments=arguments)
|
||||
return SimpleNamespace(index=index, id=id, type=type, function=fn)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: accumulator shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamingAccumulator:
|
||||
def test_text_only_response(self, agent):
|
||||
"""Streaming text-only response produces correct synthetic shape."""
|
||||
chunks = [
|
||||
_chunk(content="Hello", model="test/m"),
|
||||
_chunk(content=" world"),
|
||||
_chunk(finish_reason="stop"),
|
||||
_usage_chunk(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "Hello world"
|
||||
assert resp.choices[0].message.tool_calls is None
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert resp.usage.prompt_tokens == 10
|
||||
assert resp.model == "test/m"
|
||||
|
||||
def test_tool_call_response(self, agent):
|
||||
"""Streaming tool-call response accumulates function name + arguments."""
|
||||
chunks = [
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="call_1", name="web_search", arguments='{"q', type="function")]),
|
||||
_chunk(tool_call_delta=[_tc_delta(0, arguments='uery": "hi"}')]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert tc is not None
|
||||
assert len(tc) == 1
|
||||
assert tc[0].id == "call_1"
|
||||
assert tc[0].function.name == "web_search"
|
||||
assert tc[0].function.arguments == '{"query": "hi"}'
|
||||
assert resp.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
def test_mixed_content_and_tool_calls(self, agent):
|
||||
"""Content + tool calls in same stream are both accumulated."""
|
||||
chunks = [
|
||||
_chunk(content="Let me check."),
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="web_search", arguments="{}", type="function")]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "Let me check."
|
||||
assert len(resp.choices[0].message.tool_calls) == 1
|
||||
|
||||
|
||||
class TestStreamingCallbacks:
|
||||
def test_deltas_fire_in_order(self, agent):
|
||||
"""stream_delta_callback receives content deltas in order."""
|
||||
received = []
|
||||
agent.stream_delta_callback = lambda t: received.append(t)
|
||||
chunks = [_chunk(content="a"), _chunk(content="b"), _chunk(content="c"), _chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert received == ["a", "b", "c"]
|
||||
|
||||
def test_on_first_delta_fires_once(self, agent):
|
||||
first = MagicMock()
|
||||
chunks = [_chunk(content="x"), _chunk(content="y"), _chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"}, on_first_delta=first)
|
||||
|
||||
first.assert_called_once()
|
||||
|
||||
def test_tool_only_does_not_fire_callback(self, agent):
|
||||
"""Tool-call-only stream does not invoke stream_delta_callback."""
|
||||
received = []
|
||||
agent.stream_delta_callback = lambda t: received.append(t)
|
||||
chunks = [
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="t", arguments="{}", type="function")]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert received == []
|
||||
|
||||
|
||||
class TestStreamingFallback:
|
||||
def test_stream_error_falls_back(self, agent):
|
||||
"""When streaming fails with 'not support', falls back to non-streaming."""
|
||||
agent.client.chat.completions.create.side_effect = [
|
||||
Exception("streaming not supported by this provider"),
|
||||
SimpleNamespace(
|
||||
choices=[SimpleNamespace(
|
||||
message=SimpleNamespace(content="ok", tool_calls=None, reasoning=None, reasoning_content=None, reasoning_details=None),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
model="test/m",
|
||||
),
|
||||
]
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "ok"
|
||||
assert agent.client.chat.completions.create.call_count == 2
|
||||
|
||||
def test_non_stream_error_raises(self, agent):
|
||||
"""Non-stream-related errors propagate normally."""
|
||||
agent.client.chat.completions.create.side_effect = ValueError("bad request")
|
||||
|
||||
with pytest.raises(ValueError, match="bad request"):
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: base.py already_sent contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAlreadySentContract:
|
||||
def _make_adapter(self, send_side_effect=None):
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
class FakeAdapter(BasePlatformAdapter):
|
||||
async def connect(self): return True
|
||||
async def disconnect(self): pass
|
||||
async def get_chat_info(self, chat_id): return {"name": "test"}
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
if send_side_effect is not None:
|
||||
send_side_effect(content)
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
cfg = PlatformConfig(enabled=True)
|
||||
adapter = FakeAdapter(cfg, Platform.TELEGRAM)
|
||||
adapter._running = True
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_sent_skips_send(self):
|
||||
"""Handler returning already_sent=True prevents base from calling send()."""
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
sent = []
|
||||
adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c))
|
||||
|
||||
async def handler(event):
|
||||
return {"content": "hello", "already_sent": True}
|
||||
adapter.set_message_handler(handler)
|
||||
|
||||
event = MessageEvent(
|
||||
text="hi",
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
|
||||
)
|
||||
await adapter._process_message_background(event, "s1")
|
||||
|
||||
assert sent == [], "send() should not be called when already_sent=True"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_response_sends_normally(self):
|
||||
"""Handler returning a plain string triggers send() as before."""
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
sent = []
|
||||
adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c))
|
||||
|
||||
async def handler(event):
|
||||
return "hello"
|
||||
adapter.set_message_handler(handler)
|
||||
|
||||
event = MessageEvent(
|
||||
text="hi",
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
|
||||
)
|
||||
await adapter._process_message_background(event, "s1")
|
||||
|
||||
assert "hello" in sent
|
||||
Loading…
Add table
Add a link
Reference in a new issue