mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
147 lines
4.8 KiB
Python
147 lines
4.8 KiB
Python
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from gateway.platforms.base import MessageEvent
|
|
from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
|
|
from gateway.session import build_session_key
|
|
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
|
_runner, adapter = make_restart_runner()
|
|
release = asyncio.Event()
|
|
|
|
async def block_forever(_event):
|
|
await release.wait()
|
|
return None
|
|
|
|
adapter.set_message_handler(block_forever)
|
|
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
|
|
|
await adapter.handle_message(event)
|
|
await asyncio.sleep(0)
|
|
|
|
session_key = build_session_key(event.source)
|
|
assert session_key in adapter._active_sessions
|
|
assert adapter._background_tasks
|
|
|
|
await adapter.cancel_background_tasks()
|
|
|
|
assert adapter._background_tasks == set()
|
|
assert adapter._active_sessions == {}
|
|
assert adapter._pending_messages == {}
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
|
runner, adapter = make_restart_runner()
|
|
runner._pending_messages = {"session": "pending text"}
|
|
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
|
runner._restart_drain_timeout = 0.0
|
|
|
|
release = asyncio.Event()
|
|
|
|
async def block_forever(_event):
|
|
await release.wait()
|
|
return None
|
|
|
|
adapter.set_message_handler(block_forever)
|
|
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
|
await adapter.handle_message(event)
|
|
await asyncio.sleep(0)
|
|
|
|
disconnect_mock = AsyncMock()
|
|
adapter.disconnect = disconnect_mock
|
|
|
|
session_key = build_session_key(event.source)
|
|
running_agent = MagicMock()
|
|
runner._running_agents = {session_key: running_agent}
|
|
|
|
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
|
await runner.stop()
|
|
|
|
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
|
disconnect_mock.assert_awaited_once()
|
|
assert runner.adapters == {}
|
|
assert runner._running_agents == {}
|
|
assert runner._pending_messages == {}
|
|
assert runner._pending_approvals == {}
|
|
assert runner._shutdown_event.is_set() is True
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_gateway_stop_drains_running_agents_before_disconnect():
|
|
runner, adapter = make_restart_runner()
|
|
disconnect_mock = AsyncMock()
|
|
adapter.disconnect = disconnect_mock
|
|
|
|
running_agent = MagicMock()
|
|
runner._running_agents = {"session": running_agent}
|
|
|
|
async def finish_agent():
|
|
await asyncio.sleep(0.05)
|
|
runner._running_agents.clear()
|
|
|
|
asyncio.create_task(finish_agent())
|
|
|
|
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
|
await runner.stop()
|
|
|
|
running_agent.interrupt.assert_not_called()
|
|
disconnect_mock.assert_awaited_once()
|
|
assert runner._shutdown_event.is_set() is True
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_gateway_stop_interrupts_after_drain_timeout():
|
|
runner, adapter = make_restart_runner()
|
|
runner._restart_drain_timeout = 0.05
|
|
|
|
disconnect_mock = AsyncMock()
|
|
adapter.disconnect = disconnect_mock
|
|
|
|
running_agent = MagicMock()
|
|
runner._running_agents = {"session": running_agent}
|
|
|
|
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
|
await runner.stop()
|
|
|
|
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
|
disconnect_mock.assert_awaited_once()
|
|
assert runner._shutdown_event.is_set() is True
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_gateway_stop_service_restart_sets_named_exit_code():
|
|
runner, adapter = make_restart_runner()
|
|
adapter.disconnect = AsyncMock()
|
|
|
|
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
|
await runner.stop(restart=True, service_restart=True)
|
|
|
|
assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_drain_active_agents_throttles_status_updates():
|
|
runner, _adapter = make_restart_runner()
|
|
runner._update_runtime_status = MagicMock()
|
|
|
|
runner._running_agents = {"a": MagicMock(), "b": MagicMock()}
|
|
|
|
async def finish_agents():
|
|
await asyncio.sleep(0.12)
|
|
runner._running_agents.pop("a")
|
|
await asyncio.sleep(0.12)
|
|
runner._running_agents.clear()
|
|
|
|
task = asyncio.create_task(finish_agents())
|
|
await runner._drain_active_agents(1.0)
|
|
await task
|
|
|
|
# Start, one count-change update, and final update. Allow one extra update
|
|
# if the loop observes the zero-agent state before exiting.
|
|
assert 3 <= runner._update_runtime_status.call_count <= 4
|