diff --git a/cron/scheduler.py b/cron/scheduler.py index a8464cce6..d5967d6ab 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -45,7 +45,7 @@ _LOCK_FILE = _LOCK_DIR / ".tick.lock" def _resolve_origin(job: dict) -> Optional[dict]: - """Extract origin info from a job, returning {platform, chat_id, chat_name} or None.""" + """Extract origin info from a job, preserving any extra routing metadata.""" origin = job.get("origin") if not origin: return None @@ -69,6 +69,8 @@ def _deliver_result(job: dict, content: str) -> None: if deliver == "local": return + thread_id = None + # Resolve target platform + chat_id if deliver == "origin": if not origin: @@ -76,6 +78,7 @@ def _deliver_result(job: dict, content: str) -> None: return platform_name = origin["platform"] chat_id = origin["chat_id"] + thread_id = origin.get("thread_id") elif ":" in deliver: platform_name, chat_id = deliver.split(":", 1) else: @@ -83,6 +86,7 @@ def _deliver_result(job: dict, content: str) -> None: platform_name = deliver if origin and origin.get("platform") == platform_name: chat_id = origin["chat_id"] + thread_id = origin.get("thread_id") else: # Fall back to home channel chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "") @@ -118,13 +122,13 @@ def _deliver_result(job: dict, content: str) -> None: # Run the async send in a fresh event loop (safe from any thread) try: - result = asyncio.run(_send_to_platform(platform, pconfig, chat_id, content)) + result = asyncio.run(_send_to_platform(platform, pconfig, chat_id, content, thread_id=thread_id)) except RuntimeError: # asyncio.run() fails if there's already a running loop in this thread; # spin up a new thread to avoid that. import concurrent.futures with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: - future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content)) + future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content, thread_id=thread_id)) result = future.result(timeout=30) except Exception as e: logger.error("Job '%s': delivery to %s:%s failed: %s", job["id"], platform_name, chat_id, e) @@ -137,7 +141,7 @@ def _deliver_result(job: dict, content: str) -> None: # Mirror the delivered content into the target's gateway session try: from gateway.mirror import mirror_to_session - mirror_to_session(platform_name, chat_id, content, source_label="cron") + mirror_to_session(platform_name, chat_id, content, source_label="cron", thread_id=thread_id) except Exception as e: logger.warning("Job '%s': mirror_to_session failed: %s", job["id"], e) diff --git a/gateway/channel_directory.py b/gateway/channel_directory.py index 31406a7de..858859fd2 100644 --- a/gateway/channel_directory.py +++ b/gateway/channel_directory.py @@ -17,6 +17,26 @@ logger = logging.getLogger(__name__) DIRECTORY_PATH = Path.home() / ".hermes" / "channel_directory.json" +def _session_entry_id(origin: Dict[str, Any]) -> Optional[str]: + chat_id = origin.get("chat_id") + if not chat_id: + return None + thread_id = origin.get("thread_id") + if thread_id: + return f"{chat_id}:{thread_id}" + return str(chat_id) + + +def _session_entry_name(origin: Dict[str, Any]) -> str: + base_name = origin.get("chat_name") or origin.get("user_name") or str(origin.get("chat_id")) + thread_id = origin.get("thread_id") + if not thread_id: + return base_name + + topic_label = origin.get("chat_topic") or f"topic {thread_id}" + return f"{base_name} / {topic_label}" + + # --------------------------------------------------------------------------- # Build / refresh # --------------------------------------------------------------------------- @@ -123,14 +143,15 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]: origin = session.get("origin") or {} if origin.get("platform") != platform_name: continue - chat_id = origin.get("chat_id") - if not chat_id or chat_id in seen_ids: + entry_id = _session_entry_id(origin) + if not entry_id or entry_id in seen_ids: continue - seen_ids.add(chat_id) + seen_ids.add(entry_id) entries.append({ - "id": str(chat_id), - "name": origin.get("chat_name") or origin.get("user_name") or str(chat_id), + "id": entry_id, + "name": _session_entry_name(origin), "type": session.get("chat_type", "dm"), + "thread_id": origin.get("thread_id"), }) except Exception as e: logger.debug("Channel directory: failed to read sessions for %s: %s", platform_name, e) diff --git a/gateway/delivery.py b/gateway/delivery.py index 0093c1fb0..5bcd58f4c 100644 --- a/gateway/delivery.py +++ b/gateway/delivery.py @@ -37,6 +37,7 @@ class DeliveryTarget: """ platform: Platform chat_id: Optional[str] = None # None means use home channel + thread_id: Optional[str] = None is_origin: bool = False is_explicit: bool = False # True if chat_id was explicitly specified @@ -58,6 +59,7 @@ class DeliveryTarget: return cls( platform=origin.platform, chat_id=origin.chat_id, + thread_id=origin.thread_id, is_origin=True, ) else: @@ -150,7 +152,7 @@ class DeliveryRouter: continue # Deduplicate - key = (target.platform, target.chat_id) + key = (target.platform, target.chat_id, target.thread_id) if key not in seen_platforms: seen_platforms.add(key) targets.append(target) @@ -285,7 +287,10 @@ class DeliveryRouter: + f"\n\n... [truncated, full output saved to {saved_path}]" ) - return await adapter.send(target.chat_id, content, metadata=metadata) + send_metadata = dict(metadata or {}) + if target.thread_id and "thread_id" not in send_metadata: + send_metadata["thread_id"] = target.thread_id + return await adapter.send(target.chat_id, content, metadata=send_metadata or None) def parse_deliver_spec( diff --git a/gateway/mirror.py b/gateway/mirror.py index 1fbd55d51..f54e6e1a3 100644 --- a/gateway/mirror.py +++ b/gateway/mirror.py @@ -26,6 +26,7 @@ def mirror_to_session( chat_id: str, message_text: str, source_label: str = "cli", + thread_id: Optional[str] = None, ) -> bool: """ Append a delivery-mirror message to the target session's transcript. @@ -37,9 +38,9 @@ def mirror_to_session( All errors are caught -- this is never fatal. """ try: - session_id = _find_session_id(platform, str(chat_id)) + session_id = _find_session_id(platform, str(chat_id), thread_id=thread_id) if not session_id: - logger.debug("Mirror: no session found for %s:%s", platform, chat_id) + logger.debug("Mirror: no session found for %s:%s:%s", platform, chat_id, thread_id) return False mirror_msg = { @@ -57,11 +58,11 @@ def mirror_to_session( return True except Exception as e: - logger.debug("Mirror failed for %s:%s: %s", platform, chat_id, e) + logger.debug("Mirror failed for %s:%s:%s: %s", platform, chat_id, thread_id, e) return False -def _find_session_id(platform: str, chat_id: str) -> Optional[str]: +def _find_session_id(platform: str, chat_id: str, thread_id: Optional[str] = None) -> Optional[str]: """ Find the active session_id for a platform + chat_id pair. @@ -91,6 +92,9 @@ def _find_session_id(platform: str, chat_id: str) -> Optional[str]: origin_chat_id = str(origin.get("chat_id", "")) if origin_chat_id == str(chat_id): + origin_thread_id = origin.get("thread_id") + if thread_id is not None and str(origin_thread_id or "") != str(thread_id): + continue updated = entry.get("updated_at", "") if updated > best_updated: best_updated = updated diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 1e7436188..f4ab43ea4 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -24,7 +24,7 @@ from pathlib import Path as _Path sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig -from gateway.session import SessionSource +from gateway.session import SessionSource, build_session_key # --------------------------------------------------------------------------- @@ -646,7 +646,7 @@ class BasePlatformAdapter(ABC): if not self._message_handler: return - session_key = event.source.chat_id + session_key = build_session_key(event.source) # Check if there's already an active handler for this session if session_key in self._active_sessions: diff --git a/gateway/run.py b/gateway/run.py index be89833ac..d2c39e889 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -875,7 +875,6 @@ class GatewayRunner: if command in quick_commands: qcmd = quick_commands[command] if qcmd.get("type") == "exec": - import asyncio exec_cmd = qcmd.get("command", "") if exec_cmd: try: @@ -1067,12 +1066,14 @@ class GatewayRunner: ) _hyg_adapter = self.adapters.get(source.platform) + _hyg_meta = {"thread_id": source.thread_id} if source.thread_id else None if _hyg_adapter: try: await _hyg_adapter.send( source.chat_id, f"🗜️ Session is large ({_msg_count} messages, " - f"~{_approx_tokens:,} tokens). Auto-compressing..." + f"~{_approx_tokens:,} tokens). Auto-compressing...", + metadata=_hyg_meta, ) except Exception: pass @@ -1132,7 +1133,8 @@ class GatewayRunner: f"🗜️ Compressed: {_msg_count} → " f"{_new_count} messages, " f"~{_approx_tokens:,} → " - f"~{_new_tokens:,} tokens" + f"~{_new_tokens:,} tokens", + metadata=_hyg_meta, ) except Exception: pass @@ -1152,7 +1154,8 @@ class GatewayRunner: "after compression " f"(~{_new_tokens:,} tokens). " "Consider using /reset to start " - "fresh if you experience issues." + "fresh if you experience issues.", + metadata=_hyg_meta, ) except Exception: pass @@ -1164,6 +1167,7 @@ class GatewayRunner: # Compression failed and session is dangerously large if _approx_tokens >= _warn_token_threshold: _hyg_adapter = self.adapters.get(source.platform) + _hyg_meta = {"thread_id": source.thread_id} if source.thread_id else None if _hyg_adapter: try: await _hyg_adapter.send( @@ -1173,7 +1177,8 @@ class GatewayRunner: f"~{_approx_tokens:,} tokens) and " "auto-compression failed. Consider " "using /compress or /reset to avoid " - "issues." + "issues.", + metadata=_hyg_meta, ) except Exception: pass @@ -2765,7 +2770,7 @@ class GatewayRunner: # Restore typing indicator await asyncio.sleep(0.3) - await adapter.send_typing(source.chat_id) + await adapter.send_typing(source.chat_id, metadata=_progress_metadata) except queue.Empty: await asyncio.sleep(0.3) diff --git a/gateway/session.py b/gateway/session.py index b1cdefa5b..17ca8e4d5 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -306,6 +306,8 @@ def build_session_key(source: SessionSource) -> str: if platform == "whatsapp" and source.chat_id: return f"agent:main:{platform}:dm:{source.chat_id}" return f"agent:main:{platform}:dm" + if source.thread_id: + return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}:{source.thread_id}" return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}" diff --git a/tests/gateway/test_base_topic_sessions.py b/tests/gateway/test_base_topic_sessions.py new file mode 100644 index 000000000..e3ca7ae72 --- /dev/null +++ b/tests/gateway/test_base_topic_sessions.py @@ -0,0 +1,135 @@ +"""Tests for BasePlatformAdapter topic-aware session handling.""" + +import asyncio +from types import SimpleNamespace + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult +from gateway.session import SessionSource, build_session_key + + +class DummyTelegramAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM) + self.sent = [] + self.typing = [] + + async def connect(self) -> bool: + return True + + async def disconnect(self) -> None: + return None + + async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult: + self.sent.append( + { + "chat_id": chat_id, + "content": content, + "reply_to": reply_to, + "metadata": metadata, + } + ) + return SendResult(success=True, message_id="1") + + async def send_typing(self, chat_id: str, metadata=None) -> None: + self.typing.append({"chat_id": chat_id, "metadata": metadata}) + return None + + async def get_chat_info(self, chat_id: str): + return {"id": chat_id} + + +def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent: + return MessageEvent( + text="hello", + source=SessionSource( + platform=Platform.TELEGRAM, + chat_id=chat_id, + chat_type="group", + thread_id=thread_id, + ), + message_id=message_id, + ) + + +class TestBasePlatformTopicSessions: + @pytest.mark.asyncio + async def test_handle_message_does_not_interrupt_different_topic(self, monkeypatch): + adapter = DummyTelegramAdapter() + adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None)) + + active_event = _make_event("-1001", "10") + adapter._active_sessions[build_session_key(active_event.source)] = asyncio.Event() + + scheduled = [] + + def fake_create_task(coro): + scheduled.append(coro) + coro.close() + return SimpleNamespace() + + monkeypatch.setattr(asyncio, "create_task", fake_create_task) + + await adapter.handle_message(_make_event("-1001", "11")) + + assert len(scheduled) == 1 + assert adapter._pending_messages == {} + + @pytest.mark.asyncio + async def test_handle_message_interrupts_same_topic(self, monkeypatch): + adapter = DummyTelegramAdapter() + adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None)) + + active_event = _make_event("-1001", "10") + adapter._active_sessions[build_session_key(active_event.source)] = asyncio.Event() + + scheduled = [] + + def fake_create_task(coro): + scheduled.append(coro) + coro.close() + return SimpleNamespace() + + monkeypatch.setattr(asyncio, "create_task", fake_create_task) + + pending_event = _make_event("-1001", "10", message_id="2") + await adapter.handle_message(pending_event) + + assert scheduled == [] + assert adapter.get_pending_message(build_session_key(pending_event.source)) == pending_event + + @pytest.mark.asyncio + async def test_process_message_background_replies_in_same_topic(self): + adapter = DummyTelegramAdapter() + typing_calls = [] + + async def handler(_event): + await asyncio.sleep(0) + return "ack" + + async def hold_typing(_chat_id, interval=2.0, metadata=None): + typing_calls.append({"chat_id": _chat_id, "metadata": metadata}) + await asyncio.Event().wait() + + adapter.set_message_handler(handler) + adapter._keep_typing = hold_typing + + event = _make_event("-1001", "17585") + await adapter._process_message_background(event, build_session_key(event.source)) + + assert adapter.sent == [ + { + "chat_id": "-1001", + "content": "ack", + "reply_to": "1", + "metadata": {"thread_id": "17585"}, + } + ] + assert typing_calls == [ + { + "chat_id": "-1001", + "metadata": {"thread_id": "17585"}, + } + ] diff --git a/tests/gateway/test_channel_directory.py b/tests/gateway/test_channel_directory.py index d7562977d..9ff8ac979 100644 --- a/tests/gateway/test_channel_directory.py +++ b/tests/gateway/test_channel_directory.py @@ -111,6 +111,13 @@ class TestResolveChannelName: with self._setup(tmp_path, platforms): assert resolve_channel_name("telegram", "nonexistent") is None + def test_topic_name_resolves_to_composite_id(self, tmp_path): + platforms = { + "telegram": [{"id": "-1001:17585", "name": "Coaching Chat / topic 17585", "type": "group"}] + } + with self._setup(tmp_path, platforms): + assert resolve_channel_name("telegram", "Coaching Chat / topic 17585") == "-1001:17585" + class TestBuildFromSessions: def _write_sessions(self, tmp_path, sessions_data): @@ -169,6 +176,42 @@ class TestBuildFromSessions: assert len(entries) == 1 + def test_keeps_distinct_topics_with_same_chat_id(self, tmp_path): + self._write_sessions(tmp_path, { + "group_root": { + "origin": {"platform": "telegram", "chat_id": "-1001", "chat_name": "Coaching Chat"}, + "chat_type": "group", + }, + "topic_a": { + "origin": { + "platform": "telegram", + "chat_id": "-1001", + "chat_name": "Coaching Chat", + "thread_id": "17585", + }, + "chat_type": "group", + }, + "topic_b": { + "origin": { + "platform": "telegram", + "chat_id": "-1001", + "chat_name": "Coaching Chat", + "thread_id": "17587", + }, + "chat_type": "group", + }, + }) + + with patch.object(Path, "home", return_value=tmp_path): + entries = _build_from_sessions("telegram") + + ids = {entry["id"] for entry in entries} + names = {entry["name"] for entry in entries} + assert ids == {"-1001", "-1001:17585", "-1001:17587"} + assert "Coaching Chat" in names + assert "Coaching Chat / topic 17585" in names + assert "Coaching Chat / topic 17587" in names + class TestFormatDirectoryForDisplay: def test_empty_directory(self, tmp_path): @@ -181,6 +224,7 @@ class TestFormatDirectoryForDisplay: "telegram": [ {"id": "123", "name": "Alice", "type": "dm"}, {"id": "456", "name": "Dev Group", "type": "group"}, + {"id": "-1001:17585", "name": "Coaching Chat / topic 17585", "type": "group"}, ] }) with patch("gateway.channel_directory.DIRECTORY_PATH", cache_file): @@ -189,6 +233,7 @@ class TestFormatDirectoryForDisplay: assert "Telegram:" in result assert "telegram:Alice" in result assert "telegram:Dev Group" in result + assert "telegram:Coaching Chat / topic 17585" in result def test_discord_grouped_by_guild(self, tmp_path): cache_file = _write_directory(tmp_path, { diff --git a/tests/gateway/test_delivery.py b/tests/gateway/test_delivery.py index 124dfee72..42eba781e 100644 --- a/tests/gateway/test_delivery.py +++ b/tests/gateway/test_delivery.py @@ -24,10 +24,11 @@ class TestParseTargetPlatformChat: assert target.chat_id is None def test_origin_with_source(self): - origin = SessionSource(platform=Platform.TELEGRAM, chat_id="789") + origin = SessionSource(platform=Platform.TELEGRAM, chat_id="789", thread_id="42") target = DeliveryTarget.parse("origin", origin=origin) assert target.platform == Platform.TELEGRAM assert target.chat_id == "789" + assert target.thread_id == "42" assert target.is_origin is True def test_origin_without_source(self): @@ -64,7 +65,7 @@ class TestParseDeliverSpec: class TestTargetToStringRoundtrip: def test_origin_roundtrip(self): - origin = SessionSource(platform=Platform.TELEGRAM, chat_id="111") + origin = SessionSource(platform=Platform.TELEGRAM, chat_id="111", thread_id="42") target = DeliveryTarget.parse("origin", origin=origin) assert target.to_string() == "origin" diff --git a/tests/gateway/test_mirror.py b/tests/gateway/test_mirror.py index 928f4eac2..427e720cd 100644 --- a/tests/gateway/test_mirror.py +++ b/tests/gateway/test_mirror.py @@ -57,6 +57,26 @@ class TestFindSessionId: assert result == "sess_new" + def test_thread_id_disambiguates_same_chat(self, tmp_path): + sessions_dir, index_file = _setup_sessions(tmp_path, { + "topic_a": { + "session_id": "sess_topic_a", + "origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "10"}, + "updated_at": "2026-01-01T00:00:00", + }, + "topic_b": { + "session_id": "sess_topic_b", + "origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "11"}, + "updated_at": "2026-02-01T00:00:00", + }, + }) + + with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \ + patch.object(mirror_mod, "_SESSIONS_INDEX", index_file): + result = _find_session_id("telegram", "-1001", thread_id="10") + + assert result == "sess_topic_a" + def test_no_match_returns_none(self, tmp_path): sessions_dir, index_file = _setup_sessions(tmp_path, { "sess": { @@ -146,6 +166,29 @@ class TestMirrorToSession: assert msg["mirror"] is True assert msg["mirror_source"] == "cli" + def test_successful_mirror_uses_thread_id(self, tmp_path): + sessions_dir, index_file = _setup_sessions(tmp_path, { + "topic_a": { + "session_id": "sess_topic_a", + "origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "10"}, + "updated_at": "2026-01-01T00:00:00", + }, + "topic_b": { + "session_id": "sess_topic_b", + "origin": {"platform": "telegram", "chat_id": "-1001", "thread_id": "11"}, + "updated_at": "2026-02-01T00:00:00", + }, + }) + + with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \ + patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \ + patch("gateway.mirror._append_to_sqlite"): + result = mirror_to_session("telegram", "-1001", "Hello topic!", source_label="cron", thread_id="10") + + assert result is True + assert (sessions_dir / "sess_topic_a.jsonl").exists() + assert not (sessions_dir / "sess_topic_b.jsonl").exists() + def test_no_matching_session(self, tmp_path): sessions_dir, index_file = _setup_sessions(tmp_path, {}) diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py new file mode 100644 index 000000000..20ae712a2 --- /dev/null +++ b/tests/gateway/test_run_progress_topics.py @@ -0,0 +1,134 @@ +"""Tests for topic-aware gateway progress updates.""" + +import importlib +import sys +import time +import types +from types import SimpleNamespace + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, SendResult +from gateway.session import SessionSource + + +class ProgressCaptureAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM) + self.sent = [] + self.edits = [] + self.typing = [] + + async def connect(self) -> bool: + return True + + async def disconnect(self) -> None: + return None + + async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult: + self.sent.append( + { + "chat_id": chat_id, + "content": content, + "reply_to": reply_to, + "metadata": metadata, + } + ) + return SendResult(success=True, message_id="progress-1") + + async def edit_message(self, chat_id, message_id, content) -> SendResult: + self.edits.append( + { + "chat_id": chat_id, + "message_id": message_id, + "content": content, + } + ) + return SendResult(success=True, message_id=message_id) + + async def send_typing(self, chat_id, metadata=None) -> None: + self.typing.append({"chat_id": chat_id, "metadata": metadata}) + + async def get_chat_info(self, chat_id: str): + return {"id": chat_id} + + +class FakeAgent: + def __init__(self, **kwargs): + self.tool_progress_callback = kwargs["tool_progress_callback"] + self.tools = [] + + def run_conversation(self, message, conversation_history=None, task_id=None): + self.tool_progress_callback("terminal", "pwd") + time.sleep(0.35) + self.tool_progress_callback("browser_navigate", "https://example.com") + time.sleep(0.35) + return { + "final_response": "done", + "messages": [], + "api_calls": 1, + } + + +def _make_runner(adapter): + gateway_run = importlib.import_module("gateway.run") + GatewayRunner = gateway_run.GatewayRunner + + runner = object.__new__(GatewayRunner) + runner.adapters = {Platform.TELEGRAM: adapter} + runner._prefill_messages = [] + runner._ephemeral_system_prompt = "" + runner._reasoning_config = None + runner._provider_routing = {} + runner._fallback_model = None + runner._session_db = None + runner._running_agents = {} + runner.hooks = SimpleNamespace(loaded_hooks=False) + return runner + + +@pytest.mark.asyncio +async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_path): + monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all") + + fake_dotenv = types.ModuleType("dotenv") + fake_dotenv.load_dotenv = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = FakeAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + adapter = ProgressCaptureAdapter() + runner = _make_runner(adapter) + gateway_run = importlib.import_module("gateway.run") + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "fake"}) + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_type="group", + thread_id="17585", + ) + + result = await runner._run_agent( + message="hello", + context_prompt="", + history=[], + source=source, + session_id="sess-1", + session_key="agent:main:telegram:group:-1001:17585", + ) + + assert result["final_response"] == "done" + assert adapter.sent == [ + { + "chat_id": "-1001", + "content": '💻 terminal: "pwd"', + "reply_to": None, + "metadata": {"thread_id": "17585"}, + } + ] + assert adapter.edits + assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing) diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 7a7f4b878..e25a0a9c7 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -368,6 +368,17 @@ class TestWhatsAppDMSessionKeyConsistency: key = build_session_key(source) assert key == "agent:main:discord:group:guild-123" + def test_group_thread_includes_thread_id(self): + """Forum-style threads need a distinct session key within one group.""" + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1002285219667", + chat_type="group", + thread_id="17585", + ) + key = build_session_key(source) + assert key == "agent:main:telegram:group:-1002285219667:17585" + class TestSessionStoreEntriesAttribute: """Regression: /reset must access _entries, not _sessions.""" diff --git a/tests/gateway/test_session_hygiene.py b/tests/gateway/test_session_hygiene.py index 9ac7b8029..d627c2056 100644 --- a/tests/gateway/test_session_hygiene.py +++ b/tests/gateway/test_session_hygiene.py @@ -8,9 +8,19 @@ The hygiene system uses the SAME compression config as the agent: so CLI and messaging platforms behave identically. """ -import pytest +import importlib +import sys +import types +from datetime import datetime +from types import SimpleNamespace from unittest.mock import patch, MagicMock, AsyncMock + +import pytest + from agent.model_metadata import estimate_messages_tokens_rough +from gateway.config import GatewayConfig, Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult +from gateway.session import SessionEntry, SessionSource # --------------------------------------------------------------------------- @@ -41,6 +51,32 @@ def _make_large_history_tokens(target_tokens: int) -> list: return _make_history(n_msgs, content_size=content_size) +class HygieneCaptureAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM) + self.sent = [] + + async def connect(self) -> bool: + return True + + async def disconnect(self) -> None: + return None + + async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult: + self.sent.append( + { + "chat_id": chat_id, + "content": content, + "reply_to": reply_to, + "metadata": metadata, + } + ) + return SendResult(success=True, message_id="hygiene-1") + + async def get_chat_info(self, chat_id: str): + return {"id": chat_id} + + # --------------------------------------------------------------------------- # Detection threshold tests (model-aware, unified with compression config) # --------------------------------------------------------------------------- @@ -202,3 +238,90 @@ class TestTokenEstimation: # Should be well above the 170K threshold for a 200k model threshold = int(200_000 * 0.85) assert tokens > threshold + + +@pytest.mark.asyncio +async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, tmp_path): + fake_dotenv = types.ModuleType("dotenv") + fake_dotenv.load_dotenv = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) + + class FakeCompressAgent: + def __init__(self, **kwargs): + self.model = kwargs.get("model") + + def _compress_context(self, messages, *_args, **_kwargs): + return ([{"role": "assistant", "content": "compressed"}], None) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = FakeCompressAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + gateway_run = importlib.import_module("gateway.run") + GatewayRunner = gateway_run.GatewayRunner + + adapter = HygieneCaptureAdapter() + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig( + platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake-token")} + ) + runner.adapters = {Platform.TELEGRAM: adapter} + runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False) + runner.session_store = MagicMock() + runner.session_store.get_or_create_session.return_value = SessionEntry( + session_key="agent:main:telegram:group:-1001:17585", + session_id="sess-1", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="group", + ) + runner.session_store.load_transcript.return_value = _make_history(6, content_size=400) + runner.session_store.has_any_sessions.return_value = True + runner.session_store.rewrite_transcript = MagicMock() + runner.session_store.append_to_transcript = MagicMock() + runner._running_agents = {} + runner._pending_messages = {} + runner._pending_approvals = {} + runner._session_db = None + runner._is_user_authorized = lambda _source: True + runner._set_session_env = lambda _context: None + runner._run_agent = AsyncMock( + return_value={ + "final_response": "ok", + "messages": [], + "tools": [], + "history_offset": 0, + "last_prompt_tokens": 0, + } + ) + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "fake"}) + monkeypatch.setattr( + "agent.model_metadata.get_model_context_length", + lambda *_args, **_kwargs: 100, + ) + monkeypatch.setenv("TELEGRAM_HOME_CHANNEL", "795544298") + + event = MessageEvent( + text="hello", + source=SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_type="group", + thread_id="17585", + ), + message_id="1", + ) + + result = await runner._handle_message(event) + + assert result == "ok" + assert len(adapter.sent) == 2 + assert adapter.sent[0]["chat_id"] == "-1001" + assert "Session is large" in adapter.sent[0]["content"] + assert adapter.sent[0]["metadata"] == {"thread_id": "17585"} + assert adapter.sent[1]["chat_id"] == "-1001" + assert "Compressed:" in adapter.sent[1]["content"] + assert adapter.sent[1]["metadata"] == {"thread_id": "17585"} diff --git a/tests/tools/test_send_message_tool.py b/tests/tools/test_send_message_tool.py new file mode 100644 index 000000000..fc037bc84 --- /dev/null +++ b/tests/tools/test_send_message_tool.py @@ -0,0 +1,67 @@ +"""Tests for tools/send_message_tool.py.""" + +import asyncio +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +from gateway.config import Platform +from tools.send_message_tool import send_message_tool + + +def _run_async_immediately(coro): + return asyncio.run(coro) + + +def _make_config(): + telegram_cfg = SimpleNamespace(enabled=True, token="fake-token", extra={}) + return SimpleNamespace( + platforms={Platform.TELEGRAM: telegram_cfg}, + get_home_channel=lambda _platform: None, + ), telegram_cfg + + +class TestSendMessageTool: + def test_sends_to_explicit_telegram_topic_target(self): + config, telegram_cfg = _make_config() + + with patch("gateway.config.load_gateway_config", return_value=config), \ + patch("tools.interrupt.is_interrupted", return_value=False), \ + patch("model_tools._run_async", side_effect=_run_async_immediately), \ + patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})) as send_mock, \ + patch("gateway.mirror.mirror_to_session", return_value=True) as mirror_mock: + result = json.loads( + send_message_tool( + { + "action": "send", + "target": "telegram:-1001:17585", + "message": "hello", + } + ) + ) + + assert result["success"] is True + send_mock.assert_awaited_once_with(Platform.TELEGRAM, telegram_cfg, "-1001", "hello", thread_id="17585") + mirror_mock.assert_called_once_with("telegram", "-1001", "hello", source_label="cli", thread_id="17585") + + def test_resolved_telegram_topic_name_preserves_thread_id(self): + config, telegram_cfg = _make_config() + + with patch("gateway.config.load_gateway_config", return_value=config), \ + patch("tools.interrupt.is_interrupted", return_value=False), \ + patch("gateway.channel_directory.resolve_channel_name", return_value="-1001:17585"), \ + patch("model_tools._run_async", side_effect=_run_async_immediately), \ + patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})) as send_mock, \ + patch("gateway.mirror.mirror_to_session", return_value=True): + result = json.loads( + send_message_tool( + { + "action": "send", + "target": "telegram:Coaching Chat / topic 17585", + "message": "hello", + } + ) + ) + + assert result["success"] is True + send_mock.assert_awaited_once_with(Platform.TELEGRAM, telegram_cfg, "-1001", "hello", thread_id="17585") diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index 8f5dbb61c..f0b1dd27a 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -8,10 +8,13 @@ human-friendly channel names to IDs. Works in both CLI and gateway contexts. import json import logging import os +import re import time logger = logging.getLogger(__name__) +_TELEGRAM_TOPIC_TARGET_RE = re.compile(r"^\s*(-?\d+)(?::(\d+))?\s*$") + SEND_MESSAGE_SCHEMA = { "name": "send_message", @@ -33,7 +36,7 @@ SEND_MESSAGE_SCHEMA = { }, "target": { "type": "string", - "description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'" + "description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or Telegram topic 'telegram:chat_id:thread_id'. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'" }, "message": { "type": "string", @@ -73,23 +76,30 @@ def _handle_send(args): parts = target.split(":", 1) platform_name = parts[0].strip().lower() - chat_id = parts[1].strip() if len(parts) > 1 else None + target_ref = parts[1].strip() if len(parts) > 1 else None + chat_id = None + thread_id = None + + if target_ref: + chat_id, thread_id, is_explicit = _parse_target_ref(platform_name, target_ref) + else: + is_explicit = False # Resolve human-friendly channel names to numeric IDs - if chat_id and not chat_id.lstrip("-").isdigit(): + if target_ref and not is_explicit: try: from gateway.channel_directory import resolve_channel_name - resolved = resolve_channel_name(platform_name, chat_id) + resolved = resolve_channel_name(platform_name, target_ref) if resolved: - chat_id = resolved + chat_id, thread_id, _ = _parse_target_ref(platform_name, resolved) else: return json.dumps({ - "error": f"Could not resolve '{chat_id}' on {platform_name}. " + "error": f"Could not resolve '{target_ref}' on {platform_name}. " f"Use send_message(action='list') to see available targets." }) except Exception: return json.dumps({ - "error": f"Could not resolve '{chat_id}' on {platform_name}. " + "error": f"Could not resolve '{target_ref}' on {platform_name}. " f"Try using a numeric channel ID instead." }) @@ -134,7 +144,7 @@ def _handle_send(args): try: from model_tools import _run_async - result = _run_async(_send_to_platform(platform, pconfig, chat_id, message)) + result = _run_async(_send_to_platform(platform, pconfig, chat_id, message, thread_id=thread_id)) if used_home_channel and isinstance(result, dict) and result.get("success"): result["note"] = f"Sent to {platform_name} home channel (chat_id: {chat_id})" @@ -143,7 +153,7 @@ def _handle_send(args): try: from gateway.mirror import mirror_to_session source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli") - if mirror_to_session(platform_name, chat_id, message, source_label=source_label): + if mirror_to_session(platform_name, chat_id, message, source_label=source_label, thread_id=thread_id): result["mirrored"] = True except Exception: pass @@ -153,11 +163,22 @@ def _handle_send(args): return json.dumps({"error": f"Send failed: {e}"}) -async def _send_to_platform(platform, pconfig, chat_id, message): +def _parse_target_ref(platform_name: str, target_ref: str): + """Parse a tool target into chat_id/thread_id and whether it is explicit.""" + if platform_name == "telegram": + match = _TELEGRAM_TOPIC_TARGET_RE.fullmatch(target_ref) + if match: + return match.group(1), match.group(2), True + if target_ref.lstrip("-").isdigit(): + return target_ref, None, True + return None, None, False + + +async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None): """Route a message to the appropriate platform sender.""" from gateway.config import Platform if platform == Platform.TELEGRAM: - return await _send_telegram(pconfig.token, chat_id, message) + return await _send_telegram(pconfig.token, chat_id, message, thread_id=thread_id) elif platform == Platform.DISCORD: return await _send_discord(pconfig.token, chat_id, message) elif platform == Platform.SLACK: @@ -167,12 +188,15 @@ async def _send_to_platform(platform, pconfig, chat_id, message): return {"error": f"Direct sending not yet implemented for {platform.value}"} -async def _send_telegram(token, chat_id, message): +async def _send_telegram(token, chat_id, message, thread_id=None): """Send via Telegram Bot API (one-shot, no polling needed).""" try: from telegram import Bot bot = Bot(token=token) - msg = await bot.send_message(chat_id=int(chat_id), text=message) + send_kwargs = {"chat_id": int(chat_id), "text": message} + if thread_id is not None: + send_kwargs["message_thread_id"] = int(thread_id) + msg = await bot.send_message(**send_kwargs) return {"success": True, "platform": "telegram", "chat_id": chat_id, "message_id": str(msg.message_id)} except ImportError: return {"error": "python-telegram-bot not installed. Run: pip install python-telegram-bot"}