mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
fix(gateway): preserve thread routing from cached live session sources
This commit is contained in:
parent
5bf12eb44a
commit
176b93575a
4 changed files with 90 additions and 0 deletions
|
|
@ -1086,6 +1086,7 @@ class GatewayRunner:
|
|||
self._pending_native_image_paths_by_session: Dict[str, List[str]] = {}
|
||||
self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce)
|
||||
self._session_run_generation: Dict[str, int] = {}
|
||||
self._session_sources: Dict[str, "SessionSource"] = {}
|
||||
|
||||
# Cache AIAgent instances per session to preserve prompt caching.
|
||||
# Without this, a new AIAgent is created per message, rebuilding the
|
||||
|
|
@ -2451,6 +2452,9 @@ class GatewayRunner:
|
|||
e,
|
||||
)
|
||||
|
||||
if source is None:
|
||||
source = self._get_cached_session_source(session_key)
|
||||
|
||||
if source is not None:
|
||||
platform_str = source.platform.value
|
||||
chat_id = str(source.chat_id)
|
||||
|
|
@ -6006,6 +6010,26 @@ class GatewayRunner:
|
|||
return []
|
||||
return list(pending_native.pop(session_key, []) or [])
|
||||
|
||||
def _cache_session_source(self, session_key: str, source) -> None:
|
||||
if not session_key or source is None:
|
||||
return
|
||||
cached_sources = getattr(self, "_session_sources", None)
|
||||
if cached_sources is None:
|
||||
cached_sources = {}
|
||||
self._session_sources = cached_sources
|
||||
try:
|
||||
cached_sources[session_key] = dataclasses.replace(source)
|
||||
except Exception:
|
||||
logger.debug("Failed to cache live session source for %s", session_key, exc_info=True)
|
||||
|
||||
def _get_cached_session_source(self, session_key: str):
|
||||
if not session_key:
|
||||
return None
|
||||
cached_sources = getattr(self, "_session_sources", None)
|
||||
if not cached_sources:
|
||||
return None
|
||||
return cached_sources.get(session_key)
|
||||
|
||||
async def _handle_message_with_agent(self, event, source, _quick_key: str, run_generation: int):
|
||||
"""Inner handler that runs under the _running_agents sentinel guard."""
|
||||
_msg_start_time = time.time()
|
||||
|
|
@ -6020,6 +6044,7 @@ class GatewayRunner:
|
|||
# Get or create session
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
self._cache_session_source(session_key, source)
|
||||
if self._is_telegram_topic_lane(source):
|
||||
try:
|
||||
binding = self._session_db.get_telegram_topic_binding(
|
||||
|
|
@ -11894,6 +11919,10 @@ class GatewayRunner:
|
|||
exc,
|
||||
)
|
||||
|
||||
cached_source = self._get_cached_session_source(session_key)
|
||||
if cached_source is not None:
|
||||
return cached_source
|
||||
|
||||
_parsed = _parse_session_key(session_key)
|
||||
if _parsed:
|
||||
derived_platform = _parsed["platform"]
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ def make_restart_runner(
|
|||
runner._update_prompt_pending = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._session_sources = {}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
runner._update_runtime_status = MagicMock()
|
||||
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(
|
||||
|
|
@ -115,6 +116,12 @@ def make_restart_runner(
|
|||
runner._notify_active_sessions_of_shutdown = (
|
||||
GatewayRunner._notify_active_sessions_of_shutdown.__get__(runner, GatewayRunner)
|
||||
)
|
||||
runner._cache_session_source = GatewayRunner._cache_session_source.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._get_cached_session_source = GatewayRunner._get_cached_session_source.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._launch_detached_restart_command = GatewayRunner._launch_detached_restart_command.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
|
|
|
|||
|
|
@ -304,6 +304,40 @@ def test_build_process_event_source_falls_back_to_session_key_chat_type(monkeypa
|
|||
assert source.user_name == "Emiliyan"
|
||||
|
||||
|
||||
def test_build_process_event_source_uses_cached_live_source_before_session_key_parse(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
from gateway.session import SessionSource
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
runner._cache_session_source(
|
||||
"agent:main:telegram:group:-100:42",
|
||||
SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
chat_type="group",
|
||||
thread_id="42",
|
||||
user_id="proc_owner",
|
||||
user_name="alice",
|
||||
),
|
||||
)
|
||||
|
||||
source = runner._build_process_event_source(
|
||||
{
|
||||
"session_id": "proc_watch",
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
}
|
||||
)
|
||||
|
||||
assert source is not None
|
||||
assert source.platform == Platform.TELEGRAM
|
||||
assert source.chat_id == "-100"
|
||||
assert source.chat_type == "group"
|
||||
assert source.thread_id == "42"
|
||||
assert source.user_id == "proc_owner"
|
||||
assert source.user_name == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_watch_notification_ignores_foreground_event_source(monkeypatch, tmp_path):
|
||||
"""Negative test: watch notification must NOT route to the foreground thread."""
|
||||
|
|
|
|||
|
|
@ -603,3 +603,23 @@ async def test_send_restart_notification_logs_info_on_sendresult_success(
|
|||
f"got records: {[(r.levelname, r.getMessage()) for r in caplog.records]}"
|
||||
)
|
||||
assert not notify_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_notifications_use_cached_live_thread_source_when_origin_missing():
|
||||
runner, adapter = make_restart_runner()
|
||||
source = make_restart_source(chat_id="parent-42", chat_type="group", thread_id="topic-7")
|
||||
session_key = build_session_key(source)
|
||||
|
||||
runner._running_agents[session_key] = object()
|
||||
runner.session_store._entries[session_key] = MagicMock(origin=None)
|
||||
runner._cache_session_source(session_key, source)
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="shutdown"))
|
||||
|
||||
await runner._notify_active_sessions_of_shutdown()
|
||||
|
||||
adapter.send.assert_awaited_once_with(
|
||||
"parent-42",
|
||||
"⚠️ Gateway shutting down — Your current task will be interrupted.",
|
||||
metadata={"thread_id": "topic-7"},
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue