diff --git a/gateway/run.py b/gateway/run.py index 670ec4c86..f9bf9a38b 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -482,6 +482,23 @@ def _resolve_hermes_bin() -> Optional[list[str]]: return None +def _parse_session_key(session_key: str) -> "dict | None": + """Parse a session key into its component parts. + + Session keys follow the format ``agent:main:{platform}:{chat_type}:{chat_id}``. + Returns a dict with ``platform``, ``chat_type``, and ``chat_id`` keys, + or None if the key doesn't match the expected format. + """ + parts = session_key.split(":") + if len(parts) >= 5 and parts[0] == "agent" and parts[1] == "main": + return { + "platform": parts[2], + "chat_type": parts[3], + "chat_id": parts[4], + } + return None + + def _format_gateway_process_notification(evt: dict) -> "str | None": """Format a watch pattern event from completion_queue into a [SYSTEM:] message.""" evt_type = evt.get("type", "completion") @@ -1489,12 +1506,11 @@ class GatewayRunner: notified: set = set() for session_key in active: # Parse platform + chat_id from the session key. - # Format: agent:main:{platform}:{chat_type}:{chat_id}[:{extra}...] - parts = session_key.split(":") - if len(parts) < 5: + _parsed = _parse_session_key(session_key) + if not _parsed: continue - platform_str = parts[2] - chat_id = parts[4] + platform_str = _parsed["platform"] + chat_id = _parsed["chat_id"] # Deduplicate: one notification per chat, even if multiple # sessions (different users/threads) share the same chat. @@ -7479,11 +7495,11 @@ class GatewayRunner: exc, ) - parts = session_key.split(":") - if len(parts) >= 5 and parts[0] == "agent" and parts[1] == "main": - derived_platform = parts[2] - derived_chat_type = parts[3] - derived_chat_id = parts[4] + _parsed = _parse_session_key(session_key) + if _parsed: + derived_platform = _parsed["platform"] + derived_chat_type = _parsed["chat_type"] + derived_chat_id = _parsed["chat_id"] platform_name = str(evt.get("platform") or derived_platform or "").strip().lower() chat_type = str(evt.get("chat_type") or derived_chat_type or "").strip().lower() diff --git a/tests/gateway/test_background_process_notifications.py b/tests/gateway/test_background_process_notifications.py index 90e9e063a..68eb5e304 100644 --- a/tests/gateway/test_background_process_notifications.py +++ b/tests/gateway/test_background_process_notifications.py @@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, patch import pytest from gateway.config import GatewayConfig, Platform -from gateway.run import GatewayRunner +from gateway.run import GatewayRunner, _parse_session_key # --------------------------------------------------------------------------- @@ -302,3 +302,97 @@ def test_build_process_event_source_falls_back_to_session_key_chat_type(monkeypa assert source.thread_id == "42" assert source.user_id == "123" assert source.user_name == "Emiliyan" + + +@pytest.mark.asyncio +async def test_inject_watch_notification_ignores_foreground_event_source(monkeypatch, tmp_path): + """Negative test: watch notification must NOT route to the foreground thread.""" + from gateway.session import SessionSource + + runner = _build_runner(monkeypatch, tmp_path, "all") + adapter = runner.adapters[Platform.TELEGRAM] + + # Session store has the process's original thread (thread 42) + runner.session_store._entries["agent:main:telegram:group:-100:42"] = SimpleNamespace( + origin=SessionSource( + platform=Platform.TELEGRAM, + chat_id="-100", + chat_type="group", + thread_id="42", + user_id="proc_owner", + user_name="alice", + ) + ) + + # The evt dict carries the correct session_key — NOT a foreground event + evt = { + "session_id": "proc_cross_thread", + "session_key": "agent:main:telegram:group:-100:42", + } + + await runner._inject_watch_notification("[SYSTEM: watch match]", evt) + + adapter.handle_message.assert_awaited_once() + synth_event = adapter.handle_message.await_args.args[0] + # Must route to thread 42 (process origin), NOT some other thread + assert synth_event.source.thread_id == "42" + assert synth_event.source.user_id == "proc_owner" + + +def test_build_process_event_source_returns_none_for_empty_evt(monkeypatch, tmp_path): + """Missing session_key and no platform metadata → None (drop notification).""" + runner = _build_runner(monkeypatch, tmp_path, "all") + + source = runner._build_process_event_source({"session_id": "proc_orphan"}) + assert source is None + + +def test_build_process_event_source_returns_none_for_invalid_platform(monkeypatch, tmp_path): + """Invalid platform string → None.""" + runner = _build_runner(monkeypatch, tmp_path, "all") + + evt = { + "session_id": "proc_bad", + "platform": "not_a_real_platform", + "chat_type": "dm", + "chat_id": "123", + } + source = runner._build_process_event_source(evt) + assert source is None + + +def test_build_process_event_source_returns_none_for_short_session_key(monkeypatch, tmp_path): + """Session key with <5 parts doesn't parse, falls through to empty metadata → None.""" + runner = _build_runner(monkeypatch, tmp_path, "all") + + evt = { + "session_id": "proc_short", + "session_key": "agent:main:telegram", # Too few parts + } + source = runner._build_process_event_source(evt) + assert source is None + + +# --------------------------------------------------------------------------- +# _parse_session_key helper +# --------------------------------------------------------------------------- + +def test_parse_session_key_valid(): + result = _parse_session_key("agent:main:telegram:group:-100") + assert result == {"platform": "telegram", "chat_type": "group", "chat_id": "-100"} + + +def test_parse_session_key_with_extra_parts(): + """Extra trailing parts (thread_id etc.) are ignored — only first 5 matter.""" + result = _parse_session_key("agent:main:discord:group:chan123:thread456") + assert result == {"platform": "discord", "chat_type": "group", "chat_id": "chan123"} + + +def test_parse_session_key_too_short(): + assert _parse_session_key("agent:main:telegram") is None + assert _parse_session_key("") is None + + +def test_parse_session_key_wrong_prefix(): + assert _parse_session_key("cron:main:telegram:dm:123") is None + assert _parse_session_key("agent:cron:telegram:dm:123") is None diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 65f84e146..55f4c10a8 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -1384,14 +1384,10 @@ def terminal_tool( if pty_disabled_reason: result_data["pty_note"] = pty_disabled_reason - # Mark for agent notification on completion - if notify_on_complete and background: - proc_session.notify_on_complete = True - result_data["notify_on_complete"] = True - - # In gateway mode, auto-register a fast watcher so the - # gateway can detect completion and trigger a new agent - # turn. CLI mode uses the completion_queue directly. + # Populate routing metadata on the session so that + # watch-pattern and completion notifications can be + # routed back to the correct chat/thread. + if background and (notify_on_complete or watch_patterns): from gateway.session_context import get_session_env as _gse _gw_platform = _gse("HERMES_SESSION_PLATFORM", "") if _gw_platform: @@ -1404,16 +1400,26 @@ def terminal_tool( proc_session.watcher_user_id = _gw_user_id proc_session.watcher_user_name = _gw_user_name proc_session.watcher_thread_id = _gw_thread_id + + # Mark for agent notification on completion + if notify_on_complete and background: + proc_session.notify_on_complete = True + result_data["notify_on_complete"] = True + + # In gateway mode, auto-register a fast watcher so the + # gateway can detect completion and trigger a new agent + # turn. CLI mode uses the completion_queue directly. + if proc_session.watcher_platform: proc_session.watcher_interval = 5 process_registry.pending_watchers.append({ "session_id": proc_session.id, "check_interval": 5, "session_key": session_key, - "platform": _gw_platform, - "chat_id": _gw_chat_id, - "user_id": _gw_user_id, - "user_name": _gw_user_name, - "thread_id": _gw_thread_id, + "platform": proc_session.watcher_platform, + "chat_id": proc_session.watcher_chat_id, + "user_id": proc_session.watcher_user_id, + "user_name": proc_session.watcher_user_name, + "thread_id": proc_session.watcher_thread_id, "notify_on_complete": True, })