refactor(gateway): generalize topic recovery via adapter hook

Replace the runner-introspection trick in #32998 with an explicit
`set_topic_recovery_fn` setter on `BasePlatformAdapter`. The gateway
runner installs it once at adapter init; the adapter calls
`_apply_topic_recovery(event)` before any session keying.

Also apply the hook in `BasePlatformAdapter.handle_message` so the
running-agent guard and pending-message queue key off the recovered
thread_id too — not just the text-batch coalescence.

Net change vs #32998 alone: -2 files of indirection (no
`_message_handler.__self__` peek, no separate `_normalize_text_batch_source`),
+1 generic mechanism (other adapters can install their own hook later).
This commit is contained in:
teknium1 2026-05-28 20:49:53 -07:00 committed by Teknium
parent 5407d25599
commit 100536134c
4 changed files with 71 additions and 59 deletions

View file

@ -472,6 +472,7 @@ def is_host_excluded_by_no_proxy(hostname: str, no_proxy_value: str | None = Non
return False
import dataclasses
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@ -1561,6 +1562,10 @@ class BasePlatformAdapter(ABC):
self.config = config
self.platform = platform
self._message_handler: Optional[MessageHandler] = None
# Optional hook (e.g. Telegram DM topic recovery) that rewrites
# ``event.source.thread_id`` before session keying. Returns the
# corrected thread_id or None to leave the source untouched.
self._topic_recovery_fn: Optional[Callable[[Any], Optional[str]]] = None
self._running = False
self._fatal_error_code: Optional[str] = None
self._fatal_error_message: Optional[str] = None
@ -1816,6 +1821,40 @@ class BasePlatformAdapter(ABC):
"""
self._message_handler = handler
def set_topic_recovery_fn(
self,
fn: Optional[Callable[[Any], Optional[str]]],
) -> None:
"""Install a thread_id-recovery hook (Telegram DM topic mode).
The hook is called with ``event.source`` before session keying;
a non-None return value replaces ``source.thread_id``. Pass
``None`` to clear the hook.
"""
# Guard against subclasses that initialize via ``object.__new__`` in
# tests and never run ``BasePlatformAdapter.__init__``.
self._topic_recovery_fn = fn # type: ignore[attr-defined]
def _apply_topic_recovery(self, event: MessageEvent) -> None:
"""Rewrite ``event.source.thread_id`` in place if the hook returns one."""
recover = getattr(self, "_topic_recovery_fn", None)
if recover is None:
return
source = getattr(event, "source", None)
if source is None:
return
try:
recovered = recover(source)
except Exception:
logger.debug("topic recovery hook failed", exc_info=True)
return
if recovered is None or str(recovered) == str(source.thread_id or ""):
return
try:
event.source = dataclasses.replace(source, thread_id=str(recovered))
except Exception:
logger.debug("topic recovery rewrite failed", exc_info=True)
def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None:
"""Set an optional handler for messages arriving during active sessions."""
self._busy_session_handler = handler
@ -3332,7 +3371,12 @@ class BasePlatformAdapter(ABC):
return
coerce_plaintext_gateway_command(event)
# Rewrite ``event.source.thread_id`` via the installed recovery hook
# (Telegram DM topic mode) so the session key, guard checks, and
# downstream delivery all agree on the same lane.
self._apply_topic_recovery(event)
session_key = build_session_key(
event.source,
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),

View file

@ -5027,38 +5027,20 @@ class TelegramAdapter(BasePlatformAdapter):
# ------------------------------------------------------------------
def _text_batch_key(self, event: MessageEvent) -> str:
"""Session-scoped key for text message batching."""
"""Session-scoped key for text message batching.
Applies the installed topic-recovery hook first so DM-topic batches
coalesce on (and dispatch to) the recovered lane rather than the
raw inbound ``message_thread_id`` Telegram may have attached.
"""
from gateway.session import build_session_key
source = self._normalize_text_batch_source(event)
self._apply_topic_recovery(event)
return build_session_key(
source,
event.source,
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
)
def _normalize_text_batch_source(self, event: MessageEvent):
"""Apply runner-side Telegram DM topic recovery before batching."""
source = getattr(event, "source", None)
if source is None:
return source
runner = getattr(getattr(self, "_message_handler", None), "__self__", None)
recover_fn = getattr(runner, "_recover_telegram_topic_thread_id", None)
if not callable(recover_fn):
return source
try:
recovered = recover_fn(source)
except Exception:
logger.debug("telegram text batch recovery failed", exc_info=True)
return source
if recovered is None or str(recovered) == str(source.thread_id or ""):
return source
normalized = dataclasses.replace(source, thread_id=str(recovered))
try:
event.source = normalized
except Exception:
pass
return normalized
def _enqueue_text_event(self, event: MessageEvent) -> None:
"""Buffer a text event and reset the flush timer.

View file

@ -4159,6 +4159,7 @@ class GatewayRunner:
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
adapter.set_session_store(self.session_store)
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
adapter.set_topic_recovery_fn(self._recover_telegram_topic_thread_id)
adapter._busy_text_mode = self._busy_text_mode
# Try to connect
@ -5864,6 +5865,7 @@ class GatewayRunner:
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
adapter.set_session_store(self.session_store)
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
adapter.set_topic_recovery_fn(self._recover_telegram_topic_thread_id)
adapter._busy_text_mode = self._busy_text_mode
success = await self._connect_adapter_with_timeout(adapter, platform)

View file

@ -126,16 +126,9 @@ class TestTextBatching:
async def test_dm_topic_batching_recovers_thread_before_keying(self):
"""DM-topic text batches should use the recovered topic lane."""
adapter = _make_adapter()
class _Runner:
def _recover_telegram_topic_thread_id(self, source):
return "222" if str(source.thread_id or "") == "1" else None
async def _handle_message(self, _event):
return None
runner = _Runner()
adapter._message_handler = runner._handle_message
adapter.set_topic_recovery_fn(
lambda source: "222" if str(source.thread_id or "") == "1" else None
)
event = MessageEvent(
text="hello from DM topic",
message_type=MessageType.TEXT,
@ -150,29 +143,20 @@ class TestTextBatching:
adapter._enqueue_text_event(event)
recovered_key = build_session_key(
SimpleNamespace(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_type="dm",
thread_id="222",
),
group_sessions_per_user=True,
thread_sessions_per_user=False,
)
stale_key = build_session_key(
SimpleNamespace(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_type="dm",
thread_id="1",
),
group_sessions_per_user=True,
thread_sessions_per_user=False,
)
def _key(thread_id: str) -> str:
return build_session_key(
SimpleNamespace(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_type="dm",
thread_id=thread_id,
),
group_sessions_per_user=True,
thread_sessions_per_user=False,
)
assert recovered_key in adapter._pending_text_batches
assert stale_key not in adapter._pending_text_batches
assert _key("222") in adapter._pending_text_batches
assert _key("1") not in adapter._pending_text_batches
assert event.source.thread_id == "222"
await asyncio.sleep(0.2)