fix: follow-up improvements for watch notification routing (#9537)

- Populate watcher_* routing fields for watch-only processes (not just
  notify_on_complete), so watch-pattern events carry direct metadata
  instead of relying solely on session_key parsing fallback
- Extract _parse_session_key() helper to dedupe session key parsing
  at two call sites in gateway/run.py
- Add negative test proving cross-thread leakage doesn't happen
- Add edge-case tests for _build_process_event_source returning None
  (empty evt, invalid platform, short session_key)
- Add unit tests for _parse_session_key helper
This commit is contained in:
kshitijk4poor 2026-04-15 22:52:30 +05:30
parent f871ec5a69
commit e47abf5c69
3 changed files with 140 additions and 24 deletions

View file

@ -482,6 +482,23 @@ def _resolve_hermes_bin() -> Optional[list[str]]:
return None
def _parse_session_key(session_key: str) -> "dict | None":
"""Parse a session key into its component parts.
Session keys follow the format ``agent:main:{platform}:{chat_type}:{chat_id}``.
Returns a dict with ``platform``, ``chat_type``, and ``chat_id`` keys,
or None if the key doesn't match the expected format.
"""
parts = session_key.split(":")
if len(parts) >= 5 and parts[0] == "agent" and parts[1] == "main":
return {
"platform": parts[2],
"chat_type": parts[3],
"chat_id": parts[4],
}
return None
def _format_gateway_process_notification(evt: dict) -> "str | None":
"""Format a watch pattern event from completion_queue into a [SYSTEM:] message."""
evt_type = evt.get("type", "completion")
@ -1489,12 +1506,11 @@ class GatewayRunner:
notified: set = set()
for session_key in active:
# Parse platform + chat_id from the session key.
# Format: agent:main:{platform}:{chat_type}:{chat_id}[:{extra}...]
parts = session_key.split(":")
if len(parts) < 5:
_parsed = _parse_session_key(session_key)
if not _parsed:
continue
platform_str = parts[2]
chat_id = parts[4]
platform_str = _parsed["platform"]
chat_id = _parsed["chat_id"]
# Deduplicate: one notification per chat, even if multiple
# sessions (different users/threads) share the same chat.
@ -7479,11 +7495,11 @@ class GatewayRunner:
exc,
)
parts = session_key.split(":")
if len(parts) >= 5 and parts[0] == "agent" and parts[1] == "main":
derived_platform = parts[2]
derived_chat_type = parts[3]
derived_chat_id = parts[4]
_parsed = _parse_session_key(session_key)
if _parsed:
derived_platform = _parsed["platform"]
derived_chat_type = _parsed["chat_type"]
derived_chat_id = _parsed["chat_id"]
platform_name = str(evt.get("platform") or derived_platform or "").strip().lower()
chat_type = str(evt.get("chat_type") or derived_chat_type or "").strip().lower()

View file

@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from gateway.config import GatewayConfig, Platform
from gateway.run import GatewayRunner
from gateway.run import GatewayRunner, _parse_session_key
# ---------------------------------------------------------------------------
@ -302,3 +302,97 @@ def test_build_process_event_source_falls_back_to_session_key_chat_type(monkeypa
assert source.thread_id == "42"
assert source.user_id == "123"
assert source.user_name == "Emiliyan"
@pytest.mark.asyncio
async def test_inject_watch_notification_ignores_foreground_event_source(monkeypatch, tmp_path):
"""Negative test: watch notification must NOT route to the foreground thread."""
from gateway.session import SessionSource
runner = _build_runner(monkeypatch, tmp_path, "all")
adapter = runner.adapters[Platform.TELEGRAM]
# Session store has the process's original thread (thread 42)
runner.session_store._entries["agent:main:telegram:group:-100:42"] = SimpleNamespace(
origin=SessionSource(
platform=Platform.TELEGRAM,
chat_id="-100",
chat_type="group",
thread_id="42",
user_id="proc_owner",
user_name="alice",
)
)
# The evt dict carries the correct session_key — NOT a foreground event
evt = {
"session_id": "proc_cross_thread",
"session_key": "agent:main:telegram:group:-100:42",
}
await runner._inject_watch_notification("[SYSTEM: watch match]", evt)
adapter.handle_message.assert_awaited_once()
synth_event = adapter.handle_message.await_args.args[0]
# Must route to thread 42 (process origin), NOT some other thread
assert synth_event.source.thread_id == "42"
assert synth_event.source.user_id == "proc_owner"
def test_build_process_event_source_returns_none_for_empty_evt(monkeypatch, tmp_path):
"""Missing session_key and no platform metadata → None (drop notification)."""
runner = _build_runner(monkeypatch, tmp_path, "all")
source = runner._build_process_event_source({"session_id": "proc_orphan"})
assert source is None
def test_build_process_event_source_returns_none_for_invalid_platform(monkeypatch, tmp_path):
"""Invalid platform string → None."""
runner = _build_runner(monkeypatch, tmp_path, "all")
evt = {
"session_id": "proc_bad",
"platform": "not_a_real_platform",
"chat_type": "dm",
"chat_id": "123",
}
source = runner._build_process_event_source(evt)
assert source is None
def test_build_process_event_source_returns_none_for_short_session_key(monkeypatch, tmp_path):
"""Session key with <5 parts doesn't parse, falls through to empty metadata → None."""
runner = _build_runner(monkeypatch, tmp_path, "all")
evt = {
"session_id": "proc_short",
"session_key": "agent:main:telegram", # Too few parts
}
source = runner._build_process_event_source(evt)
assert source is None
# ---------------------------------------------------------------------------
# _parse_session_key helper
# ---------------------------------------------------------------------------
def test_parse_session_key_valid():
result = _parse_session_key("agent:main:telegram:group:-100")
assert result == {"platform": "telegram", "chat_type": "group", "chat_id": "-100"}
def test_parse_session_key_with_extra_parts():
"""Extra trailing parts (thread_id etc.) are ignored — only first 5 matter."""
result = _parse_session_key("agent:main:discord:group:chan123:thread456")
assert result == {"platform": "discord", "chat_type": "group", "chat_id": "chan123"}
def test_parse_session_key_too_short():
assert _parse_session_key("agent:main:telegram") is None
assert _parse_session_key("") is None
def test_parse_session_key_wrong_prefix():
assert _parse_session_key("cron:main:telegram:dm:123") is None
assert _parse_session_key("agent:cron:telegram:dm:123") is None

View file

@ -1384,14 +1384,10 @@ def terminal_tool(
if pty_disabled_reason:
result_data["pty_note"] = pty_disabled_reason
# Mark for agent notification on completion
if notify_on_complete and background:
proc_session.notify_on_complete = True
result_data["notify_on_complete"] = True
# In gateway mode, auto-register a fast watcher so the
# gateway can detect completion and trigger a new agent
# turn. CLI mode uses the completion_queue directly.
# Populate routing metadata on the session so that
# watch-pattern and completion notifications can be
# routed back to the correct chat/thread.
if background and (notify_on_complete or watch_patterns):
from gateway.session_context import get_session_env as _gse
_gw_platform = _gse("HERMES_SESSION_PLATFORM", "")
if _gw_platform:
@ -1404,16 +1400,26 @@ def terminal_tool(
proc_session.watcher_user_id = _gw_user_id
proc_session.watcher_user_name = _gw_user_name
proc_session.watcher_thread_id = _gw_thread_id
# Mark for agent notification on completion
if notify_on_complete and background:
proc_session.notify_on_complete = True
result_data["notify_on_complete"] = True
# In gateway mode, auto-register a fast watcher so the
# gateway can detect completion and trigger a new agent
# turn. CLI mode uses the completion_queue directly.
if proc_session.watcher_platform:
proc_session.watcher_interval = 5
process_registry.pending_watchers.append({
"session_id": proc_session.id,
"check_interval": 5,
"session_key": session_key,
"platform": _gw_platform,
"chat_id": _gw_chat_id,
"user_id": _gw_user_id,
"user_name": _gw_user_name,
"thread_id": _gw_thread_id,
"platform": proc_session.watcher_platform,
"chat_id": proc_session.watcher_chat_id,
"user_id": proc_session.watcher_user_id,
"user_name": proc_session.watcher_user_name,
"thread_id": proc_session.watcher_thread_id,
"notify_on_complete": True,
})