diff --git a/gateway/delivery.py b/gateway/delivery.py index 88701cd61d5..8d4e43fdcd2 100644 --- a/gateway/delivery.py +++ b/gateway/delivery.py @@ -44,6 +44,25 @@ def _looks_like_int(value: Optional[str]) -> bool: return False +def _send_result_failed(result: Any) -> bool: + if isinstance(result, dict): + return result.get("success") is False + return getattr(result, "success", True) is False + + +def _send_result_error(result: Any) -> Optional[str]: + if isinstance(result, dict): + error = result.get("error") + else: + error = getattr(result, "error", None) + return str(error) if error else None + + +def _is_thread_not_found_delivery_error(result: Any) -> bool: + error = _send_result_error(result) + return bool(error and "thread not found" in error.lower()) + + @dataclass class DeliveryTarget: """ @@ -268,6 +287,8 @@ class DeliveryRouter: ) send_metadata = dict(metadata or {}) + is_named_telegram_private_topic = False + named_telegram_private_topic_name: Optional[str] = None if target.thread_id: has_explicit_direct_topic = ( "direct_messages_topic_id" in send_metadata @@ -283,6 +304,7 @@ class DeliveryRouter: and not has_explicit_direct_topic ) if is_named_telegram_private_topic: + named_telegram_private_topic_name = target_thread_id ensure_dm_topic = getattr(adapter, "ensure_dm_topic", None) if ensure_dm_topic is None: raise RuntimeError( @@ -318,8 +340,37 @@ class DeliveryRouter: elif "thread_id" not in send_metadata and "message_thread_id" not in send_metadata and not has_explicit_direct_topic: send_metadata["thread_id"] = target_thread_id result = await adapter.send(target.chat_id, content, metadata=send_metadata or None) - if getattr(result, "success", True) is False: - raise RuntimeError(getattr(result, "error", None) or f"{target.platform.value} delivery failed") + if _send_result_failed(result): + if ( + is_named_telegram_private_topic + and named_telegram_private_topic_name + and _is_thread_not_found_delivery_error(result) + ): + ensure_dm_topic = getattr(adapter, "ensure_dm_topic", None) + if ensure_dm_topic is None: + raise RuntimeError( + "Telegram adapter cannot refresh named private DM topics" + ) + try: + refreshed_thread_id = await ensure_dm_topic( + target.chat_id, + named_telegram_private_topic_name, + force_create=True, + ) + except TypeError: + refreshed_thread_id = await ensure_dm_topic( + target.chat_id, + named_telegram_private_topic_name, + ) + if not refreshed_thread_id: + raise RuntimeError( + f"Failed to refresh Telegram private DM topic '{named_telegram_private_topic_name}'" + ) + send_metadata["thread_id"] = str(refreshed_thread_id) + send_metadata["telegram_dm_topic_created_for_send"] = True + result = await adapter.send(target.chat_id, content, metadata=send_metadata or None) + if _send_result_failed(result): + raise RuntimeError(_send_result_error(result) or f"{target.platform.value} delivery failed") return result diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 56fdfc7ca3d..e8baff74e1d 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -1192,7 +1192,7 @@ class TelegramAdapter(BasePlatformAdapter): thread_id = await self._create_dm_topic(chat_id_int, name=name) return str(thread_id) if thread_id else None - async def ensure_dm_topic(self, chat_id: str, topic_name: str) -> Optional[str]: + async def ensure_dm_topic(self, chat_id: str, topic_name: str, force_create: bool = False) -> Optional[str]: """Return a private DM topic thread id, creating and persisting it if needed.""" name = str(topic_name or "").strip() if not name: @@ -1204,7 +1204,7 @@ class TelegramAdapter(BasePlatformAdapter): cache_key = f"{chat_id_int}:{name}" cached = self._dm_topics.get(cache_key) - if cached: + if cached and not force_create: return str(cached) topic_conf: Optional[Dict[str, Any]] = None @@ -1219,7 +1219,7 @@ class TelegramAdapter(BasePlatformAdapter): break break - if topic_conf and topic_conf.get("thread_id"): + if topic_conf and topic_conf.get("thread_id") and not force_create: thread_id = int(topic_conf["thread_id"]) self._dm_topics[cache_key] = thread_id return str(thread_id) @@ -1242,7 +1242,7 @@ class TelegramAdapter(BasePlatformAdapter): topic_conf["thread_id"] = thread_id self._dm_topics[cache_key] = int(thread_id) - self._persist_dm_topic_thread_id(chat_id_int, name, int(thread_id)) + self._persist_dm_topic_thread_id(chat_id_int, name, int(thread_id), replace_existing=force_create) return str(thread_id) async def rename_dm_topic( @@ -1268,7 +1268,13 @@ class TelegramAdapter(BasePlatformAdapter): self.name, chat_id, thread_id, name, ) - def _persist_dm_topic_thread_id(self, chat_id: int, topic_name: str, thread_id: int) -> None: + def _persist_dm_topic_thread_id( + self, + chat_id: int, + topic_name: str, + thread_id: int, + replace_existing: bool = False, + ) -> None: """Save a newly created thread_id back into config.yaml so it persists across restarts.""" try: from hermes_constants import get_hermes_home @@ -1301,9 +1307,10 @@ class TelegramAdapter(BasePlatformAdapter): matching_chat_entry = chat_entry for t in chat_entry.setdefault("topics", []): if t.get("name") == topic_name: - if not t.get("thread_id"): - t["thread_id"] = thread_id - changed = True + if replace_existing or not t.get("thread_id"): + if t.get("thread_id") != thread_id: + t["thread_id"] = thread_id + changed = True break else: chat_entry.setdefault("topics", []).append( diff --git a/tests/gateway/test_delivery.py b/tests/gateway/test_delivery.py index 69a62fb4330..f94836e3159 100644 --- a/tests/gateway/test_delivery.py +++ b/tests/gateway/test_delivery.py @@ -134,11 +134,31 @@ class RecordingAdapter: self.calls.append({"chat_id": chat_id, "content": content, "metadata": metadata}) return {"success": True} - async def ensure_dm_topic(self, chat_id, topic_name): - self.ensure_dm_topic_calls.append({"chat_id": chat_id, "topic_name": topic_name}) + async def ensure_dm_topic(self, chat_id, topic_name, force_create=False): + self.ensure_dm_topic_calls.append( + {"chat_id": chat_id, "topic_name": topic_name, "force_create": force_create} + ) return "38049" +class StaleTopicAdapter: + def __init__(self): + self.calls = [] + self.ensure_dm_topic_calls = [] + + async def send(self, chat_id, content, metadata=None): + self.calls.append({"chat_id": chat_id, "content": content, "metadata": dict(metadata or {})}) + if len(self.calls) == 1: + return SendResult(success=False, error="Bad Request: message thread not found") + return SendResult(success=True, message_id="fresh-message") + + async def ensure_dm_topic(self, chat_id, topic_name, force_create=False): + self.ensure_dm_topic_calls.append( + {"chat_id": chat_id, "topic_name": topic_name, "force_create": force_create} + ) + return "38064" if force_create else "32343" + + @pytest.mark.asyncio async def test_explicit_telegram_private_thread_requires_reply_anchor(tmp_path, monkeypatch): monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path) @@ -162,7 +182,7 @@ async def test_named_telegram_private_topic_is_created_before_delivery(tmp_path, await router._deliver_to_platform(target, "hello", metadata=None) assert adapter.ensure_dm_topic_calls == [ - {"chat_id": "722341991", "topic_name": "Hermes API Test"} + {"chat_id": "722341991", "topic_name": "Hermes API Test", "force_create": False} ] assert adapter.calls == [ { @@ -176,6 +196,24 @@ async def test_named_telegram_private_topic_is_created_before_delivery(tmp_path, ] +@pytest.mark.asyncio +async def test_named_telegram_private_topic_refreshes_stale_thread_id(tmp_path, monkeypatch): + monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path) + adapter = StaleTopicAdapter() + router = DeliveryRouter(GatewayConfig(), adapters={Platform.TELEGRAM: adapter}) + target = DeliveryTarget.parse("telegram:722341991:Personal") + + result = await router._deliver_to_platform(target, "hello", metadata=None) + + assert getattr(result, "message_id", None) == "fresh-message" + assert adapter.ensure_dm_topic_calls == [ + {"chat_id": "722341991", "topic_name": "Personal", "force_create": False}, + {"chat_id": "722341991", "topic_name": "Personal", "force_create": True}, + ] + assert [call["metadata"]["thread_id"] for call in adapter.calls] == ["32343", "38064"] + assert all(call["metadata"]["telegram_dm_topic_created_for_send"] is True for call in adapter.calls) + + @pytest.mark.asyncio async def test_explicit_telegram_private_thread_uses_reply_fallback_with_anchor(tmp_path, monkeypatch): monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path) diff --git a/tests/gateway/test_dm_topics.py b/tests/gateway/test_dm_topics.py index b50f80291ef..332375229c5 100644 --- a/tests/gateway/test_dm_topics.py +++ b/tests/gateway/test_dm_topics.py @@ -224,7 +224,33 @@ async def test_ensure_dm_topic_creates_on_demand_and_persists(): assert adapter._dm_topics_config == [ {"chat_id": 111, "topics": [{"name": "On Demand", "thread_id": 444}]} ] - adapter._persist_dm_topic_thread_id.assert_called_once_with(111, "On Demand", 444) + adapter._persist_dm_topic_thread_id.assert_called_once_with( + 111, "On Demand", 444, replace_existing=False + ) + + +@pytest.mark.asyncio +async def test_ensure_dm_topic_force_create_replaces_persisted_thread_id(): + """Refreshing a stale named topic should replace the cached persisted thread_id.""" + adapter = _make_adapter() + bot = AsyncMock() + bot.create_forum_topic.return_value = SimpleNamespace(message_thread_id=777) + adapter._bot = bot + adapter._persist_dm_topic_thread_id = MagicMock() + adapter._dm_topics = {"111:General": 500} + adapter._dm_topics_config = [ + {"chat_id": 111, "topics": [{"name": "General", "thread_id": 500}]} + ] + + result = await adapter.ensure_dm_topic("111", "General", force_create=True) + + assert result == "777" + bot.create_forum_topic.assert_called_once_with(chat_id=111, name="General") + assert adapter._dm_topics["111:General"] == 777 + assert adapter._dm_topics_config[0]["topics"][0]["thread_id"] == 777 + adapter._persist_dm_topic_thread_id.assert_called_once_with( + 111, "General", 777, replace_existing=True + ) # ── _persist_dm_topic_thread_id ── @@ -309,6 +335,45 @@ def test_persist_dm_topic_thread_id_skips_if_already_set(tmp_path): assert topics[0]["thread_id"] == 500 # unchanged +def test_persist_dm_topic_thread_id_replaces_existing_when_requested(tmp_path): + """Forced refresh should overwrite a stale persisted thread_id.""" + import yaml + + config_data = { + "platforms": { + "telegram": { + "extra": { + "dm_topics": [ + { + "chat_id": 111, + "topics": [ + {"name": "General", "icon_color": 123, "thread_id": 500}, + ], + } + ] + } + } + } + } + + config_file = tmp_path / ".hermes" / "config.yaml" + config_file.parent.mkdir(parents=True) + with open(config_file, "w") as f: + yaml.dump(config_data, f) + + adapter = _make_adapter() + + with patch.object(Path, "home", return_value=tmp_path), \ + patch.dict(os.environ, {"HERMES_HOME": str(tmp_path / ".hermes")}): + adapter._persist_dm_topic_thread_id(111, "General", 999, replace_existing=True) + + with open(config_file) as f: + result = yaml.safe_load(f) + + topics = result["platforms"]["telegram"]["extra"]["dm_topics"][0]["topics"] + assert topics[0]["thread_id"] == 999 + + # ── _get_dm_topic_info ── diff --git a/tests/gateway/test_telegram_thread_fallback.py b/tests/gateway/test_telegram_thread_fallback.py index 1900ff09124..86aa2eefeb5 100644 --- a/tests/gateway/test_telegram_thread_fallback.py +++ b/tests/gateway/test_telegram_thread_fallback.py @@ -624,6 +624,33 @@ async def test_send_created_private_topic_uses_message_thread_without_anchor(): assert "direct_messages_topic_id" not in call_log[0] +@pytest.mark.asyncio +async def test_created_private_topic_thread_not_found_fails_without_root_fallback(): + """Created private-topic sends must not retry into All Messages on stale thread IDs.""" + adapter = _make_adapter() + call_log = [] + + async def mock_send_message(**kwargs): + call_log.append(dict(kwargs)) + raise FakeBadRequest("Message thread not found") + + adapter._bot = SimpleNamespace(send_message=mock_send_message) + + result = await adapter.send( + chat_id="123", + content="created topic message", + metadata={ + "thread_id": "32343", + "telegram_dm_topic_created_for_send": True, + }, + ) + + assert result.success is False + assert "thread not found" in str(result.error).lower() + assert len(call_log) == 1 + assert call_log[0]["message_thread_id"] == 32343 + + @pytest.mark.asyncio async def test_send_uses_metadata_reply_fallback_for_streaming_dm_topics(): """Metadata-only sends still stay in Hermes-created Telegram DM topics."""