mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(discord): decouple readiness from slash sync
This commit is contained in:
parent
fa7cd44b92
commit
cfbfc4c3f1
6 changed files with 250 additions and 24 deletions
|
|
@ -456,6 +456,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# show the standard typing gateway event for bots)
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._bot_task: Optional[asyncio.Task] = None
|
||||
self._post_connect_task: Optional[asyncio.Task] = None
|
||||
# Dedup cache: prevents duplicate bot responses when Discord
|
||||
# RESUME replays events after reconnects.
|
||||
self._dedup = MessageDeduplicator()
|
||||
|
|
@ -545,15 +546,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
# Resolve any usernames in the allowed list to numeric IDs
|
||||
await adapter_self._resolve_allowed_usernames()
|
||||
|
||||
# Sync slash commands with Discord
|
||||
try:
|
||||
synced = await adapter_self._client.tree.sync()
|
||||
logger.info("[%s] Synced %d slash command(s)", adapter_self.name, len(synced))
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.warning("[%s] Slash command sync failed: %s", adapter_self.name, e, exc_info=True)
|
||||
adapter_self._ready_event.set()
|
||||
|
||||
if adapter_self._post_connect_task and not adapter_self._post_connect_task.done():
|
||||
adapter_self._post_connect_task.cancel()
|
||||
adapter_self._post_connect_task = asyncio.create_task(
|
||||
adapter_self._run_post_connect_initialization()
|
||||
)
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Dedup: Discord RESUME replays events after reconnects (#4777)
|
||||
|
|
@ -686,14 +686,36 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.warning("[%s] Error during disconnect: %s", self.name, e, exc_info=True)
|
||||
|
||||
if self._post_connect_task and not self._post_connect_task.done():
|
||||
self._post_connect_task.cancel()
|
||||
try:
|
||||
await self._post_connect_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._running = False
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
self._post_connect_task = None
|
||||
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
async def _run_post_connect_initialization(self) -> None:
|
||||
"""Finish non-critical startup work after Discord is connected."""
|
||||
if not self._client:
|
||||
return
|
||||
try:
|
||||
synced = await asyncio.wait_for(self._client.tree.sync(), timeout=30)
|
||||
logger.info("[%s] Synced %d slash command(s)", self.name, len(synced))
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("[%s] Slash command sync timed out after 30s", self.name)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.warning("[%s] Slash command sync failed: %s", self.name, e, exc_info=True)
|
||||
|
||||
async def _add_reaction(self, message: Any, emoji: str) -> bool:
|
||||
"""Add an emoji reaction to a Discord message."""
|
||||
if not message or not hasattr(message, "add_reaction"):
|
||||
|
|
|
|||
|
|
@ -916,6 +916,12 @@ class GatewayRunner:
|
|||
adapter.fatal_error_code or "unknown",
|
||||
adapter.fatal_error_message or "unknown error",
|
||||
)
|
||||
self._update_platform_runtime_status(
|
||||
adapter.platform.value,
|
||||
platform_state="retrying" if adapter.fatal_error_retryable else "fatal",
|
||||
error_code=adapter.fatal_error_code,
|
||||
error_message=adapter.fatal_error_message,
|
||||
)
|
||||
|
||||
existing = self.adapters.get(adapter.platform)
|
||||
if existing is adapter:
|
||||
|
|
@ -993,6 +999,25 @@ class GatewayRunner:
|
|||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _update_platform_runtime_status(
|
||||
self,
|
||||
platform: str,
|
||||
*,
|
||||
platform_state: Optional[str] = None,
|
||||
error_code: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(
|
||||
platform=platform,
|
||||
platform_state=platform_state,
|
||||
error_code=error_code,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _load_prefill_messages() -> List[Dict[str, Any]]:
|
||||
|
|
@ -1498,16 +1523,34 @@ class GatewayRunner:
|
|||
|
||||
# Try to connect
|
||||
logger.info("Connecting to %s...", platform.value)
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="connecting",
|
||||
error_code=None,
|
||||
error_message=None,
|
||||
)
|
||||
try:
|
||||
success = await adapter.connect()
|
||||
if success:
|
||||
self.adapters[platform] = adapter
|
||||
self._sync_voice_mode_state_to_adapter(adapter)
|
||||
connected_count += 1
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="connected",
|
||||
error_code=None,
|
||||
error_message=None,
|
||||
)
|
||||
logger.info("✓ %s connected", platform.value)
|
||||
else:
|
||||
logger.warning("✗ %s failed to connect", platform.value)
|
||||
if adapter.has_fatal_error:
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="retrying" if adapter.fatal_error_retryable else "fatal",
|
||||
error_code=adapter.fatal_error_code,
|
||||
error_message=adapter.fatal_error_message,
|
||||
)
|
||||
target = (
|
||||
startup_retryable_errors
|
||||
if adapter.fatal_error_retryable
|
||||
|
|
@ -1524,6 +1567,12 @@ class GatewayRunner:
|
|||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
else:
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="retrying",
|
||||
error_code=None,
|
||||
error_message="failed to connect",
|
||||
)
|
||||
startup_retryable_errors.append(
|
||||
f"{platform.value}: failed to connect"
|
||||
)
|
||||
|
|
@ -1535,6 +1584,12 @@ class GatewayRunner:
|
|||
}
|
||||
except Exception as e:
|
||||
logger.error("✗ %s error: %s", platform.value, e)
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="retrying",
|
||||
error_code=None,
|
||||
error_message=str(e),
|
||||
)
|
||||
startup_retryable_errors.append(f"{platform.value}: {e}")
|
||||
# Unexpected exceptions are typically transient — queue for retry
|
||||
self._failed_platforms[platform] = {
|
||||
|
|
@ -1813,6 +1868,12 @@ class GatewayRunner:
|
|||
self._sync_voice_mode_state_to_adapter(adapter)
|
||||
self.delivery_router.adapters = self.adapters
|
||||
del self._failed_platforms[platform]
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="connected",
|
||||
error_code=None,
|
||||
error_message=None,
|
||||
)
|
||||
logger.info("✓ %s reconnected successfully", platform.value)
|
||||
|
||||
# Rebuild channel directory with the new adapter
|
||||
|
|
@ -1824,12 +1885,24 @@ class GatewayRunner:
|
|||
else:
|
||||
# Check if the failure is non-retryable
|
||||
if adapter.has_fatal_error and not adapter.fatal_error_retryable:
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="fatal",
|
||||
error_code=adapter.fatal_error_code,
|
||||
error_message=adapter.fatal_error_message,
|
||||
)
|
||||
logger.warning(
|
||||
"Reconnect %s: non-retryable error (%s), removing from retry queue",
|
||||
platform.value, adapter.fatal_error_message,
|
||||
)
|
||||
del self._failed_platforms[platform]
|
||||
else:
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="retrying",
|
||||
error_code=adapter.fatal_error_code,
|
||||
error_message=adapter.fatal_error_message or "failed to reconnect",
|
||||
)
|
||||
backoff = min(30 * (2 ** (attempt - 1)), _BACKOFF_CAP)
|
||||
info["attempts"] = attempt
|
||||
info["next_retry"] = time.monotonic() + backoff
|
||||
|
|
@ -1838,6 +1911,12 @@ class GatewayRunner:
|
|||
platform.value, backoff,
|
||||
)
|
||||
except Exception as e:
|
||||
self._update_platform_runtime_status(
|
||||
platform.value,
|
||||
platform_state="retrying",
|
||||
error_code=None,
|
||||
error_message=str(e),
|
||||
)
|
||||
backoff = min(30 * (2 ** (attempt - 1)), _BACKOFF_CAP)
|
||||
info["attempts"] = attempt
|
||||
info["next_retry"] = time.monotonic() + backoff
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ _GATEWAY_KIND = "hermes-gateway"
|
|||
_RUNTIME_STATUS_FILE = "gateway_state.json"
|
||||
_LOCKS_DIRNAME = "gateway-locks"
|
||||
_IS_WINDOWS = sys.platform == "win32"
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
def _get_pid_path() -> Path:
|
||||
|
|
@ -218,14 +219,14 @@ def write_pid_file() -> None:
|
|||
|
||||
def write_runtime_status(
|
||||
*,
|
||||
gateway_state: Optional[str] = None,
|
||||
exit_reason: Optional[str] = None,
|
||||
restart_requested: Optional[bool] = None,
|
||||
active_agents: Optional[int] = None,
|
||||
platform: Optional[str] = None,
|
||||
platform_state: Optional[str] = None,
|
||||
error_code: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
gateway_state: Any = _UNSET,
|
||||
exit_reason: Any = _UNSET,
|
||||
restart_requested: Any = _UNSET,
|
||||
active_agents: Any = _UNSET,
|
||||
platform: Any = _UNSET,
|
||||
platform_state: Any = _UNSET,
|
||||
error_code: Any = _UNSET,
|
||||
error_message: Any = _UNSET,
|
||||
) -> None:
|
||||
"""Persist gateway runtime health information for diagnostics/status."""
|
||||
path = _get_runtime_status_path()
|
||||
|
|
@ -236,22 +237,22 @@ def write_runtime_status(
|
|||
payload["start_time"] = _get_process_start_time(os.getpid())
|
||||
payload["updated_at"] = _utc_now_iso()
|
||||
|
||||
if gateway_state is not None:
|
||||
if gateway_state is not _UNSET:
|
||||
payload["gateway_state"] = gateway_state
|
||||
if exit_reason is not None:
|
||||
if exit_reason is not _UNSET:
|
||||
payload["exit_reason"] = exit_reason
|
||||
if restart_requested is not None:
|
||||
if restart_requested is not _UNSET:
|
||||
payload["restart_requested"] = bool(restart_requested)
|
||||
if active_agents is not None:
|
||||
if active_agents is not _UNSET:
|
||||
payload["active_agents"] = max(0, int(active_agents))
|
||||
|
||||
if platform is not None:
|
||||
if platform is not _UNSET:
|
||||
platform_payload = payload["platforms"].get(platform, {})
|
||||
if platform_state is not None:
|
||||
if platform_state is not _UNSET:
|
||||
platform_payload["state"] = platform_state
|
||||
if error_code is not None:
|
||||
if error_code is not _UNSET:
|
||||
platform_payload["error_code"] = error_code
|
||||
if error_message is not None:
|
||||
if error_message is not _UNSET:
|
||||
platform_payload["error_message"] = error_message
|
||||
platform_payload["updated_at"] = _utc_now_iso()
|
||||
payload["platforms"][platform] = platform_payload
|
||||
|
|
|
|||
|
|
@ -74,6 +74,26 @@ class FakeBot:
|
|||
return None
|
||||
|
||||
|
||||
class SlowSyncTree(FakeTree):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.started = asyncio.Event()
|
||||
self.allow_finish = asyncio.Event()
|
||||
|
||||
async def _slow_sync():
|
||||
self.started.set()
|
||||
await self.allow_finish.wait()
|
||||
return []
|
||||
|
||||
self.sync = AsyncMock(side_effect=_slow_sync)
|
||||
|
||||
|
||||
class SlowSyncBot(FakeBot):
|
||||
def __init__(self, *, intents, proxy=None):
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
self.tree = SlowSyncTree()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("allowed_users", "expected_members_intent"),
|
||||
|
|
@ -138,3 +158,36 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
|||
assert ok is False
|
||||
assert released == [("discord-bot-token", "test-token")]
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_does_not_wait_for_slash_sync(monkeypatch):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
|
||||
|
||||
monkeypatch.setattr("gateway.status.acquire_scoped_lock", lambda scope, identity, metadata=None: (True, None))
|
||||
monkeypatch.setattr("gateway.status.release_scoped_lock", lambda scope, identity: None)
|
||||
|
||||
intents = SimpleNamespace(message_content=False, dm_messages=False, guild_messages=False, members=False, voice_states=False)
|
||||
monkeypatch.setattr(discord_platform.Intents, "default", lambda: intents)
|
||||
|
||||
created = {}
|
||||
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None):
|
||||
bot = SlowSyncBot(intents=intents, proxy=proxy)
|
||||
created["bot"] = bot
|
||||
return bot
|
||||
|
||||
monkeypatch.setattr(discord_platform.commands, "Bot", fake_bot_factory)
|
||||
monkeypatch.setattr(adapter, "_resolve_allowed_usernames", AsyncMock())
|
||||
|
||||
ok = await asyncio.wait_for(adapter.connect(), timeout=1.0)
|
||||
|
||||
assert ok is True
|
||||
assert adapter._ready_event.is_set()
|
||||
|
||||
await asyncio.wait_for(created["bot"].tree.started.wait(), timeout=1.0)
|
||||
assert created["bot"].tree.sync.await_count == 1
|
||||
|
||||
created["bot"].tree.allow_finish.set()
|
||||
await asyncio.sleep(0)
|
||||
await adapter.disconnect()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
|
|
@ -45,6 +46,23 @@ class _DisabledAdapter(BasePlatformAdapter):
|
|||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _SuccessfulAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.DISCORD)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
|
@ -65,7 +83,7 @@ async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch,
|
|||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "startup_failed"
|
||||
assert "temporary DNS resolution failure" in state["exit_reason"]
|
||||
assert state["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert state["platforms"]["telegram"]["state"] == "retrying"
|
||||
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
|
||||
|
||||
|
||||
|
|
@ -89,6 +107,31 @@ async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkey
|
|||
assert state["gateway_state"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_records_connected_platform_state_on_success(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter())
|
||||
monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None)
|
||||
monkeypatch.setattr(runner.hooks, "emit", AsyncMock())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "running"
|
||||
assert state["platforms"]["discord"]["state"] == "connected"
|
||||
assert state["platforms"]["discord"]["error_code"] is None
|
||||
assert state["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
|
|
|||
|
|
@ -104,6 +104,34 @@ class TestGatewayRuntimeStatus:
|
|||
assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict"
|
||||
assert payload["platforms"]["telegram"]["error_message"] == "another poller is active"
|
||||
|
||||
def test_write_runtime_status_explicit_none_clears_stale_fields(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="startup_failed",
|
||||
exit_reason="stale error",
|
||||
platform="discord",
|
||||
platform_state="fatal",
|
||||
error_code="discord_timeout",
|
||||
error_message="stale platform error",
|
||||
)
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="running",
|
||||
exit_reason=None,
|
||||
platform="discord",
|
||||
platform_state="connected",
|
||||
error_code=None,
|
||||
error_message=None,
|
||||
)
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["gateway_state"] == "running"
|
||||
assert payload["exit_reason"] is None
|
||||
assert payload["platforms"]["discord"]["state"] == "connected"
|
||||
assert payload["platforms"]["discord"]["error_code"] is None
|
||||
assert payload["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
class TestTerminatePid:
|
||||
def test_force_uses_taskkill_on_windows(self, monkeypatch):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue