mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-31 06:51:29 +00:00
refactor(gateway): generalize topic recovery via adapter hook
Replace the runner-introspection trick in #32998 with an explicit `set_topic_recovery_fn` setter on `BasePlatformAdapter`. The gateway runner installs it once at adapter init; the adapter calls `_apply_topic_recovery(event)` before any session keying. Also apply the hook in `BasePlatformAdapter.handle_message` so the running-agent guard and pending-message queue key off the recovered thread_id too — not just the text-batch coalescence. Net change vs #32998 alone: -2 files of indirection (no `_message_handler.__self__` peek, no separate `_normalize_text_batch_source`), +1 generic mechanism (other adapters can install their own hook later).
This commit is contained in:
parent
5407d25599
commit
100536134c
4 changed files with 71 additions and 59 deletions
|
|
@ -472,6 +472,7 @@ def is_host_excluded_by_no_proxy(hostname: str, no_proxy_value: str | None = Non
|
|||
return False
|
||||
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
|
@ -1561,6 +1562,10 @@ class BasePlatformAdapter(ABC):
|
|||
self.config = config
|
||||
self.platform = platform
|
||||
self._message_handler: Optional[MessageHandler] = None
|
||||
# Optional hook (e.g. Telegram DM topic recovery) that rewrites
|
||||
# ``event.source.thread_id`` before session keying. Returns the
|
||||
# corrected thread_id or None to leave the source untouched.
|
||||
self._topic_recovery_fn: Optional[Callable[[Any], Optional[str]]] = None
|
||||
self._running = False
|
||||
self._fatal_error_code: Optional[str] = None
|
||||
self._fatal_error_message: Optional[str] = None
|
||||
|
|
@ -1816,6 +1821,40 @@ class BasePlatformAdapter(ABC):
|
|||
"""
|
||||
self._message_handler = handler
|
||||
|
||||
def set_topic_recovery_fn(
|
||||
self,
|
||||
fn: Optional[Callable[[Any], Optional[str]]],
|
||||
) -> None:
|
||||
"""Install a thread_id-recovery hook (Telegram DM topic mode).
|
||||
|
||||
The hook is called with ``event.source`` before session keying;
|
||||
a non-None return value replaces ``source.thread_id``. Pass
|
||||
``None`` to clear the hook.
|
||||
"""
|
||||
# Guard against subclasses that initialize via ``object.__new__`` in
|
||||
# tests and never run ``BasePlatformAdapter.__init__``.
|
||||
self._topic_recovery_fn = fn # type: ignore[attr-defined]
|
||||
|
||||
def _apply_topic_recovery(self, event: MessageEvent) -> None:
|
||||
"""Rewrite ``event.source.thread_id`` in place if the hook returns one."""
|
||||
recover = getattr(self, "_topic_recovery_fn", None)
|
||||
if recover is None:
|
||||
return
|
||||
source = getattr(event, "source", None)
|
||||
if source is None:
|
||||
return
|
||||
try:
|
||||
recovered = recover(source)
|
||||
except Exception:
|
||||
logger.debug("topic recovery hook failed", exc_info=True)
|
||||
return
|
||||
if recovered is None or str(recovered) == str(source.thread_id or ""):
|
||||
return
|
||||
try:
|
||||
event.source = dataclasses.replace(source, thread_id=str(recovered))
|
||||
except Exception:
|
||||
logger.debug("topic recovery rewrite failed", exc_info=True)
|
||||
|
||||
def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None:
|
||||
"""Set an optional handler for messages arriving during active sessions."""
|
||||
self._busy_session_handler = handler
|
||||
|
|
@ -3332,7 +3371,12 @@ class BasePlatformAdapter(ABC):
|
|||
return
|
||||
|
||||
coerce_plaintext_gateway_command(event)
|
||||
|
||||
|
||||
# Rewrite ``event.source.thread_id`` via the installed recovery hook
|
||||
# (Telegram DM topic mode) so the session key, guard checks, and
|
||||
# downstream delivery all agree on the same lane.
|
||||
self._apply_topic_recovery(event)
|
||||
|
||||
session_key = build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
|
|
|
|||
|
|
@ -5027,38 +5027,20 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# ------------------------------------------------------------------
|
||||
|
||||
def _text_batch_key(self, event: MessageEvent) -> str:
|
||||
"""Session-scoped key for text message batching."""
|
||||
"""Session-scoped key for text message batching.
|
||||
|
||||
Applies the installed topic-recovery hook first so DM-topic batches
|
||||
coalesce on (and dispatch to) the recovered lane rather than the
|
||||
raw inbound ``message_thread_id`` Telegram may have attached.
|
||||
"""
|
||||
from gateway.session import build_session_key
|
||||
source = self._normalize_text_batch_source(event)
|
||||
self._apply_topic_recovery(event)
|
||||
return build_session_key(
|
||||
source,
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
|
||||
)
|
||||
|
||||
def _normalize_text_batch_source(self, event: MessageEvent):
|
||||
"""Apply runner-side Telegram DM topic recovery before batching."""
|
||||
source = getattr(event, "source", None)
|
||||
if source is None:
|
||||
return source
|
||||
runner = getattr(getattr(self, "_message_handler", None), "__self__", None)
|
||||
recover_fn = getattr(runner, "_recover_telegram_topic_thread_id", None)
|
||||
if not callable(recover_fn):
|
||||
return source
|
||||
try:
|
||||
recovered = recover_fn(source)
|
||||
except Exception:
|
||||
logger.debug("telegram text batch recovery failed", exc_info=True)
|
||||
return source
|
||||
if recovered is None or str(recovered) == str(source.thread_id or ""):
|
||||
return source
|
||||
normalized = dataclasses.replace(source, thread_id=str(recovered))
|
||||
try:
|
||||
event.source = normalized
|
||||
except Exception:
|
||||
pass
|
||||
return normalized
|
||||
|
||||
def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||||
"""Buffer a text event and reset the flush timer.
|
||||
|
||||
|
|
|
|||
|
|
@ -4159,6 +4159,7 @@ class GatewayRunner:
|
|||
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
||||
adapter.set_session_store(self.session_store)
|
||||
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
|
||||
adapter.set_topic_recovery_fn(self._recover_telegram_topic_thread_id)
|
||||
adapter._busy_text_mode = self._busy_text_mode
|
||||
|
||||
# Try to connect
|
||||
|
|
@ -5864,6 +5865,7 @@ class GatewayRunner:
|
|||
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
||||
adapter.set_session_store(self.session_store)
|
||||
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
|
||||
adapter.set_topic_recovery_fn(self._recover_telegram_topic_thread_id)
|
||||
adapter._busy_text_mode = self._busy_text_mode
|
||||
|
||||
success = await self._connect_adapter_with_timeout(adapter, platform)
|
||||
|
|
|
|||
|
|
@ -126,16 +126,9 @@ class TestTextBatching:
|
|||
async def test_dm_topic_batching_recovers_thread_before_keying(self):
|
||||
"""DM-topic text batches should use the recovered topic lane."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
class _Runner:
|
||||
def _recover_telegram_topic_thread_id(self, source):
|
||||
return "222" if str(source.thread_id or "") == "1" else None
|
||||
|
||||
async def _handle_message(self, _event):
|
||||
return None
|
||||
|
||||
runner = _Runner()
|
||||
adapter._message_handler = runner._handle_message
|
||||
adapter.set_topic_recovery_fn(
|
||||
lambda source: "222" if str(source.thread_id or "") == "1" else None
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="hello from DM topic",
|
||||
message_type=MessageType.TEXT,
|
||||
|
|
@ -150,29 +143,20 @@ class TestTextBatching:
|
|||
|
||||
adapter._enqueue_text_event(event)
|
||||
|
||||
recovered_key = build_session_key(
|
||||
SimpleNamespace(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
thread_id="222",
|
||||
),
|
||||
group_sessions_per_user=True,
|
||||
thread_sessions_per_user=False,
|
||||
)
|
||||
stale_key = build_session_key(
|
||||
SimpleNamespace(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
thread_id="1",
|
||||
),
|
||||
group_sessions_per_user=True,
|
||||
thread_sessions_per_user=False,
|
||||
)
|
||||
def _key(thread_id: str) -> str:
|
||||
return build_session_key(
|
||||
SimpleNamespace(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
thread_id=thread_id,
|
||||
),
|
||||
group_sessions_per_user=True,
|
||||
thread_sessions_per_user=False,
|
||||
)
|
||||
|
||||
assert recovered_key in adapter._pending_text_batches
|
||||
assert stale_key not in adapter._pending_text_batches
|
||||
assert _key("222") in adapter._pending_text_batches
|
||||
assert _key("1") not in adapter._pending_text_batches
|
||||
assert event.source.thread_id == "222"
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue