diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index b1d07e5d65..43a9338d78 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -456,6 +456,7 @@ class DiscordAdapter(BasePlatformAdapter): # show the standard typing gateway event for bots) self._typing_tasks: Dict[str, asyncio.Task] = {} self._bot_task: Optional[asyncio.Task] = None + self._post_connect_task: Optional[asyncio.Task] = None # Dedup cache: prevents duplicate bot responses when Discord # RESUME replays events after reconnects. self._dedup = MessageDeduplicator() @@ -545,15 +546,14 @@ class DiscordAdapter(BasePlatformAdapter): # Resolve any usernames in the allowed list to numeric IDs await adapter_self._resolve_allowed_usernames() - - # Sync slash commands with Discord - try: - synced = await adapter_self._client.tree.sync() - logger.info("[%s] Synced %d slash command(s)", adapter_self.name, len(synced)) - except Exception as e: # pragma: no cover - defensive logging - logger.warning("[%s] Slash command sync failed: %s", adapter_self.name, e, exc_info=True) adapter_self._ready_event.set() + if adapter_self._post_connect_task and not adapter_self._post_connect_task.done(): + adapter_self._post_connect_task.cancel() + adapter_self._post_connect_task = asyncio.create_task( + adapter_self._run_post_connect_initialization() + ) + @self._client.event async def on_message(message: DiscordMessage): # Dedup: Discord RESUME replays events after reconnects (#4777) @@ -686,14 +686,36 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: # pragma: no cover - defensive logging logger.warning("[%s] Error during disconnect: %s", self.name, e, exc_info=True) + if self._post_connect_task and not self._post_connect_task.done(): + self._post_connect_task.cancel() + try: + await self._post_connect_task + except asyncio.CancelledError: + pass + self._running = False self._client = None self._ready_event.clear() + self._post_connect_task = None self._release_platform_lock() logger.info("[%s] Disconnected", self.name) + async def _run_post_connect_initialization(self) -> None: + """Finish non-critical startup work after Discord is connected.""" + if not self._client: + return + try: + synced = await asyncio.wait_for(self._client.tree.sync(), timeout=30) + logger.info("[%s] Synced %d slash command(s)", self.name, len(synced)) + except asyncio.TimeoutError: + logger.warning("[%s] Slash command sync timed out after 30s", self.name) + except asyncio.CancelledError: + raise + except Exception as e: # pragma: no cover - defensive logging + logger.warning("[%s] Slash command sync failed: %s", self.name, e, exc_info=True) + async def _add_reaction(self, message: Any, emoji: str) -> bool: """Add an emoji reaction to a Discord message.""" if not message or not hasattr(message, "add_reaction"): diff --git a/gateway/run.py b/gateway/run.py index bb566ca5da..560ccee4a0 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -916,6 +916,12 @@ class GatewayRunner: adapter.fatal_error_code or "unknown", adapter.fatal_error_message or "unknown error", ) + self._update_platform_runtime_status( + adapter.platform.value, + platform_state="retrying" if adapter.fatal_error_retryable else "fatal", + error_code=adapter.fatal_error_code, + error_message=adapter.fatal_error_message, + ) existing = self.adapters.get(adapter.platform) if existing is adapter: @@ -993,6 +999,25 @@ class GatewayRunner: ) except Exception: pass + + def _update_platform_runtime_status( + self, + platform: str, + *, + platform_state: Optional[str] = None, + error_code: Optional[str] = None, + error_message: Optional[str] = None, + ) -> None: + try: + from gateway.status import write_runtime_status + write_runtime_status( + platform=platform, + platform_state=platform_state, + error_code=error_code, + error_message=error_message, + ) + except Exception: + pass @staticmethod def _load_prefill_messages() -> List[Dict[str, Any]]: @@ -1498,16 +1523,34 @@ class GatewayRunner: # Try to connect logger.info("Connecting to %s...", platform.value) + self._update_platform_runtime_status( + platform.value, + platform_state="connecting", + error_code=None, + error_message=None, + ) try: success = await adapter.connect() if success: self.adapters[platform] = adapter self._sync_voice_mode_state_to_adapter(adapter) connected_count += 1 + self._update_platform_runtime_status( + platform.value, + platform_state="connected", + error_code=None, + error_message=None, + ) logger.info("✓ %s connected", platform.value) else: logger.warning("✗ %s failed to connect", platform.value) if adapter.has_fatal_error: + self._update_platform_runtime_status( + platform.value, + platform_state="retrying" if adapter.fatal_error_retryable else "fatal", + error_code=adapter.fatal_error_code, + error_message=adapter.fatal_error_message, + ) target = ( startup_retryable_errors if adapter.fatal_error_retryable @@ -1524,6 +1567,12 @@ class GatewayRunner: "next_retry": time.monotonic() + 30, } else: + self._update_platform_runtime_status( + platform.value, + platform_state="retrying", + error_code=None, + error_message="failed to connect", + ) startup_retryable_errors.append( f"{platform.value}: failed to connect" ) @@ -1535,6 +1584,12 @@ class GatewayRunner: } except Exception as e: logger.error("✗ %s error: %s", platform.value, e) + self._update_platform_runtime_status( + platform.value, + platform_state="retrying", + error_code=None, + error_message=str(e), + ) startup_retryable_errors.append(f"{platform.value}: {e}") # Unexpected exceptions are typically transient — queue for retry self._failed_platforms[platform] = { @@ -1813,6 +1868,12 @@ class GatewayRunner: self._sync_voice_mode_state_to_adapter(adapter) self.delivery_router.adapters = self.adapters del self._failed_platforms[platform] + self._update_platform_runtime_status( + platform.value, + platform_state="connected", + error_code=None, + error_message=None, + ) logger.info("✓ %s reconnected successfully", platform.value) # Rebuild channel directory with the new adapter @@ -1824,12 +1885,24 @@ class GatewayRunner: else: # Check if the failure is non-retryable if adapter.has_fatal_error and not adapter.fatal_error_retryable: + self._update_platform_runtime_status( + platform.value, + platform_state="fatal", + error_code=adapter.fatal_error_code, + error_message=adapter.fatal_error_message, + ) logger.warning( "Reconnect %s: non-retryable error (%s), removing from retry queue", platform.value, adapter.fatal_error_message, ) del self._failed_platforms[platform] else: + self._update_platform_runtime_status( + platform.value, + platform_state="retrying", + error_code=adapter.fatal_error_code, + error_message=adapter.fatal_error_message or "failed to reconnect", + ) backoff = min(30 * (2 ** (attempt - 1)), _BACKOFF_CAP) info["attempts"] = attempt info["next_retry"] = time.monotonic() + backoff @@ -1838,6 +1911,12 @@ class GatewayRunner: platform.value, backoff, ) except Exception as e: + self._update_platform_runtime_status( + platform.value, + platform_state="retrying", + error_code=None, + error_message=str(e), + ) backoff = min(30 * (2 ** (attempt - 1)), _BACKOFF_CAP) info["attempts"] = attempt info["next_retry"] = time.monotonic() + backoff diff --git a/gateway/status.py b/gateway/status.py index 5423461c2f..d7f357b363 100644 --- a/gateway/status.py +++ b/gateway/status.py @@ -26,6 +26,7 @@ _GATEWAY_KIND = "hermes-gateway" _RUNTIME_STATUS_FILE = "gateway_state.json" _LOCKS_DIRNAME = "gateway-locks" _IS_WINDOWS = sys.platform == "win32" +_UNSET = object() def _get_pid_path() -> Path: @@ -218,14 +219,14 @@ def write_pid_file() -> None: def write_runtime_status( *, - gateway_state: Optional[str] = None, - exit_reason: Optional[str] = None, - restart_requested: Optional[bool] = None, - active_agents: Optional[int] = None, - platform: Optional[str] = None, - platform_state: Optional[str] = None, - error_code: Optional[str] = None, - error_message: Optional[str] = None, + gateway_state: Any = _UNSET, + exit_reason: Any = _UNSET, + restart_requested: Any = _UNSET, + active_agents: Any = _UNSET, + platform: Any = _UNSET, + platform_state: Any = _UNSET, + error_code: Any = _UNSET, + error_message: Any = _UNSET, ) -> None: """Persist gateway runtime health information for diagnostics/status.""" path = _get_runtime_status_path() @@ -236,22 +237,22 @@ def write_runtime_status( payload["start_time"] = _get_process_start_time(os.getpid()) payload["updated_at"] = _utc_now_iso() - if gateway_state is not None: + if gateway_state is not _UNSET: payload["gateway_state"] = gateway_state - if exit_reason is not None: + if exit_reason is not _UNSET: payload["exit_reason"] = exit_reason - if restart_requested is not None: + if restart_requested is not _UNSET: payload["restart_requested"] = bool(restart_requested) - if active_agents is not None: + if active_agents is not _UNSET: payload["active_agents"] = max(0, int(active_agents)) - if platform is not None: + if platform is not _UNSET: platform_payload = payload["platforms"].get(platform, {}) - if platform_state is not None: + if platform_state is not _UNSET: platform_payload["state"] = platform_state - if error_code is not None: + if error_code is not _UNSET: platform_payload["error_code"] = error_code - if error_message is not None: + if error_message is not _UNSET: platform_payload["error_message"] = error_message platform_payload["updated_at"] = _utc_now_iso() payload["platforms"][platform] = platform_payload diff --git a/tests/gateway/test_discord_connect.py b/tests/gateway/test_discord_connect.py index 9f094dd0dd..04490f2462 100644 --- a/tests/gateway/test_discord_connect.py +++ b/tests/gateway/test_discord_connect.py @@ -74,6 +74,26 @@ class FakeBot: return None +class SlowSyncTree(FakeTree): + def __init__(self): + super().__init__() + self.started = asyncio.Event() + self.allow_finish = asyncio.Event() + + async def _slow_sync(): + self.started.set() + await self.allow_finish.wait() + return [] + + self.sync = AsyncMock(side_effect=_slow_sync) + + +class SlowSyncBot(FakeBot): + def __init__(self, *, intents, proxy=None): + super().__init__(intents=intents, proxy=proxy) + self.tree = SlowSyncTree() + + @pytest.mark.asyncio @pytest.mark.parametrize( ("allowed_users", "expected_members_intent"), @@ -138,3 +158,36 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch): assert ok is False assert released == [("discord-bot-token", "test-token")] assert adapter._platform_lock_identity is None + + +@pytest.mark.asyncio +async def test_connect_does_not_wait_for_slash_sync(monkeypatch): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token")) + + monkeypatch.setattr("gateway.status.acquire_scoped_lock", lambda scope, identity, metadata=None: (True, None)) + monkeypatch.setattr("gateway.status.release_scoped_lock", lambda scope, identity: None) + + intents = SimpleNamespace(message_content=False, dm_messages=False, guild_messages=False, members=False, voice_states=False) + monkeypatch.setattr(discord_platform.Intents, "default", lambda: intents) + + created = {} + + def fake_bot_factory(*, command_prefix, intents, proxy=None): + bot = SlowSyncBot(intents=intents, proxy=proxy) + created["bot"] = bot + return bot + + monkeypatch.setattr(discord_platform.commands, "Bot", fake_bot_factory) + monkeypatch.setattr(adapter, "_resolve_allowed_usernames", AsyncMock()) + + ok = await asyncio.wait_for(adapter.connect(), timeout=1.0) + + assert ok is True + assert adapter._ready_event.is_set() + + await asyncio.wait_for(created["bot"].tree.started.wait(), timeout=1.0) + assert created["bot"].tree.sync.await_count == 1 + + created["bot"].tree.allow_finish.set() + await asyncio.sleep(0) + await adapter.disconnect() diff --git a/tests/gateway/test_runner_startup_failures.py b/tests/gateway/test_runner_startup_failures.py index 1be67b71bb..787cb0adad 100644 --- a/tests/gateway/test_runner_startup_failures.py +++ b/tests/gateway/test_runner_startup_failures.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import AsyncMock from gateway.config import GatewayConfig, Platform, PlatformConfig from gateway.platforms.base import BasePlatformAdapter @@ -45,6 +46,23 @@ class _DisabledAdapter(BasePlatformAdapter): return {"id": chat_id} +class _SuccessfulAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="***"), Platform.DISCORD) + + async def connect(self) -> bool: + return True + + async def disconnect(self) -> None: + self._mark_disconnected() + + async def send(self, chat_id, content, reply_to=None, metadata=None): + raise NotImplementedError + + async def get_chat_info(self, chat_id): + return {"id": chat_id} + + @pytest.mark.asyncio async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) @@ -65,7 +83,7 @@ async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, state = read_runtime_status() assert state["gateway_state"] == "startup_failed" assert "temporary DNS resolution failure" in state["exit_reason"] - assert state["platforms"]["telegram"]["state"] == "fatal" + assert state["platforms"]["telegram"]["state"] == "retrying" assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error" @@ -89,6 +107,31 @@ async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkey assert state["gateway_state"] == "running" +@pytest.mark.asyncio +async def test_runner_records_connected_platform_state_on_success(monkeypatch, tmp_path): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig(enabled=True, token="***") + }, + sessions_dir=tmp_path / "sessions", + ) + runner = GatewayRunner(config) + + monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter()) + monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None) + monkeypatch.setattr(runner.hooks, "emit", AsyncMock()) + + ok = await runner.start() + + assert ok is True + state = read_runtime_status() + assert state["gateway_state"] == "running" + assert state["platforms"]["discord"]["state"] == "connected" + assert state["platforms"]["discord"]["error_code"] is None + assert state["platforms"]["discord"]["error_message"] is None + + @pytest.mark.asyncio async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_path): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) diff --git a/tests/gateway/test_status.py b/tests/gateway/test_status.py index 6792061f92..16d4bfc5e8 100644 --- a/tests/gateway/test_status.py +++ b/tests/gateway/test_status.py @@ -104,6 +104,34 @@ class TestGatewayRuntimeStatus: assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict" assert payload["platforms"]["telegram"]["error_message"] == "another poller is active" + def test_write_runtime_status_explicit_none_clears_stale_fields(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + status.write_runtime_status( + gateway_state="startup_failed", + exit_reason="stale error", + platform="discord", + platform_state="fatal", + error_code="discord_timeout", + error_message="stale platform error", + ) + + status.write_runtime_status( + gateway_state="running", + exit_reason=None, + platform="discord", + platform_state="connected", + error_code=None, + error_message=None, + ) + + payload = status.read_runtime_status() + assert payload["gateway_state"] == "running" + assert payload["exit_reason"] is None + assert payload["platforms"]["discord"]["state"] == "connected" + assert payload["platforms"]["discord"]["error_code"] is None + assert payload["platforms"]["discord"]["error_message"] is None + class TestTerminatePid: def test_force_uses_taskkill_on_windows(self, monkeypatch):