From 3438de362300a4adc7754c6e2428a0555324e197 Mon Sep 17 00:00:00 2001 From: aaron Date: Thu, 23 Apr 2026 09:09:33 +0000 Subject: [PATCH] refactor: unify gateway session title sync flow --- agent/title_generator.py | 51 ++++-- gateway/platforms/base.py | 54 ++++-- gateway/run.py | 197 ++++++++++++++++++++-- hermes_state.py | 25 +++ tests/agent/test_title_generator.py | 16 +- tests/gateway/test_run_progress_topics.py | 16 ++ tests/gateway/test_title_command.py | 108 +++++++++++- tests/test_hermes_state.py | 15 ++ 8 files changed, 421 insertions(+), 61 deletions(-) diff --git a/agent/title_generator.py b/agent/title_generator.py index d6ed9200a2..22d4f35c0a 100644 --- a/agent/title_generator.py +++ b/agent/title_generator.py @@ -56,6 +56,26 @@ def generate_title(user_message: str, assistant_response: str, timeout: float = return None +def generate_title_if_missing( + session_db, + session_id: str, + user_message: str, + assistant_response: str, +) -> Optional[str]: + """Return a generated title only if the session does not already have one.""" + if not session_db or not session_id: + return None + + try: + existing = session_db.get_session_title(session_id) + if existing: + return None + except Exception: + return None + + return generate_title(user_message, assistant_response) + + def auto_title_session( session_db, session_id: str, @@ -70,18 +90,12 @@ def auto_title_session( - session already has a title (user-set or previously auto-generated) - title generation fails """ - if not session_db or not session_id: - return - - # Check if title already exists (user may have set one via /title before first response) - try: - existing = session_db.get_session_title(session_id) - if existing: - return - except Exception: - return - - title = generate_title(user_message, assistant_response) + title = generate_title_if_missing( + session_db, + session_id, + user_message, + assistant_response, + ) if not title: return @@ -92,6 +106,12 @@ def auto_title_session( logger.debug("Failed to set auto-generated title: %s", e) +def should_auto_title(conversation_history: list) -> bool: + """Return whether this history is still early enough for auto-titling.""" + user_msg_count = sum(1 for m in (conversation_history or []) if m.get("role") == "user") + return user_msg_count <= 2 + + def maybe_auto_title( session_db, session_id: str, @@ -108,12 +128,7 @@ def maybe_auto_title( if not session_db or not session_id or not user_message or not assistant_response: return - # Count user messages in history to detect first exchange. - # conversation_history includes the exchange that just happened, - # so for a first exchange we expect exactly 1 user message - # (or 2 counting system). Be generous: generate on first 2 exchanges. - user_msg_count = sum(1 for m in (conversation_history or []) if m.get("role") == "user") - if user_msg_count > 2: + if not should_auto_title(conversation_history): return thread = threading.Thread( diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 7e32382728..b82c7baefb 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -1522,13 +1522,19 @@ class BasePlatformAdapter(ABC): ``generation`` lets callers tie the callback to a specific gateway run generation so stale runs cannot clear callbacks owned by a fresher run. + Multiple callbacks may be registered for the same session; matching + callbacks run in registration order. """ if not session_key or not callable(callback): return - if generation is None: - self._post_delivery_callbacks[session_key] = callback + entry: Any = callback if generation is None else (int(generation), callback) + existing = self._post_delivery_callbacks.get(session_key) + if existing is None: + self._post_delivery_callbacks[session_key] = entry + elif isinstance(existing, list): + existing.append(entry) else: - self._post_delivery_callbacks[session_key] = (int(generation), callback) + self._post_delivery_callbacks[session_key] = [existing, entry] def pop_post_delivery_callback( self, @@ -1536,22 +1542,46 @@ class BasePlatformAdapter(ABC): *, generation: int | None = None, ) -> Callable | None: - """Pop a deferred callback, optionally requiring generation ownership.""" + """Pop deferred callbacks, optionally requiring generation ownership.""" if not session_key: return None entry = self._post_delivery_callbacks.get(session_key) if entry is None: return None - if isinstance(entry, tuple) and len(entry) == 2: - entry_generation, callback = entry - if generation is not None and int(entry_generation) != int(generation): - return None + + entries = entry if isinstance(entry, list) else [entry] + matched: list[Callable] = [] + remaining: list[Any] = [] + + for item in entries: + if isinstance(item, tuple) and len(item) == 2: + entry_generation, callback = item + if generation is not None and int(entry_generation) != int(generation): + remaining.append(item) + continue + if callable(callback): + matched.append(callback) + continue + + if generation is not None: + remaining.append(item) + continue + if callable(item): + matched.append(item) + + if remaining: + self._post_delivery_callbacks[session_key] = remaining if len(remaining) > 1 else remaining[0] + else: self._post_delivery_callbacks.pop(session_key, None) - return callback if callable(callback) else None - if generation is not None: + + if not matched: return None - self._post_delivery_callbacks.pop(session_key, None) - return entry if callable(entry) else None + + def _run_all() -> None: + for callback in matched: + callback() + + return _run_all # ── Processing lifecycle hooks ────────────────────────────────────────── # Subclasses override these to react to message processing events diff --git a/gateway/run.py b/gateway/run.py index 0debaae27d..981d900e18 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1028,6 +1028,169 @@ class GatewayRunner: thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False), ) + def _track_background_task(self, task: asyncio.Task) -> asyncio.Task: + """Keep a background task alive until completion.""" + background_tasks = getattr(self, "_background_tasks", None) + if isinstance(background_tasks, set): + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + return task + + async def _sync_session_title_to_source(self, source: SessionSource, title: str) -> bool: + """Best-effort platform thread/topic title sync for a session title.""" + if not source or not source.thread_id or not source.platform: + return False + adapter = self.adapters.get(source.platform) + if not adapter: + return False + updater = getattr(adapter, "update_thread_title", None) + if not callable(updater): + return False + try: + return bool(await updater(source.chat_id, source.thread_id, title)) + except Exception as e: + logger.debug("Session-title sync failed for %s thread %s: %s", source.platform, source.thread_id, e) + return False + + def _schedule_session_title_sync_after_delivery( + self, + *, + session_key: Optional[str], + source: SessionSource, + title: str, + generation: int | None = None, + ) -> bool: + """Schedule platform title sync after the current response is delivered.""" + if not session_key or not source or not source.thread_id or not source.platform: + return False + adapter = self.adapters.get(source.platform) + if not adapter: + return False + + def _after_delivery() -> None: + self._track_background_task( + asyncio.create_task(self._sync_session_title_to_source(source, title)) + ) + + register = getattr(adapter, "register_post_delivery_callback", None) + if callable(register): + register(session_key, _after_delivery, generation=generation) + return True + + post_callbacks = getattr(adapter, "_post_delivery_callbacks", None) + if isinstance(post_callbacks, dict): + post_callbacks[session_key] = _after_delivery + return True + return False + + async def _apply_session_title( + self, + *, + session_id: str, + source: SessionSource, + title: str, + session_key: Optional[str] = None, + defer_thread_sync: bool = False, + sync_thread_now: bool = False, + only_if_missing: bool = False, + generation: int | None = None, + ) -> bool: + """Persist a session title and optionally propagate it to the source thread.""" + if not self._session_db: + return False + if only_if_missing: + changed = self._session_db.set_session_title_if_missing(session_id, title) + else: + changed = self._session_db.set_session_title(session_id, title) + if not changed: + return False + if sync_thread_now: + await self._sync_session_title_to_source(source, title) + elif defer_thread_sync: + self._schedule_session_title_sync_after_delivery( + session_key=session_key, + source=source, + title=title, + generation=generation, + ) + return True + + async def _auto_title_gateway_session( + self, + *, + session_id: str, + session_key: Optional[str], + source: SessionSource, + user_message: str, + assistant_response: str, + ) -> None: + """Generate a missing title for a gateway session and sync it natively.""" + if not self._session_db: + return + try: + from agent.title_generator import generate_title_if_missing + + title = await asyncio.to_thread( + generate_title_if_missing, + self._session_db, + session_id, + user_message, + assistant_response, + ) + if not title: + return + await self._apply_session_title( + session_id=session_id, + session_key=session_key, + source=source, + title=title, + sync_thread_now=True, + only_if_missing=True, + ) + except Exception as e: + logger.debug("Gateway auto-title failed for %s: %s", session_id, e) + + def _maybe_schedule_gateway_auto_title( + self, + *, + session_id: str, + session_key: Optional[str], + source: SessionSource, + user_message: str, + assistant_response: str, + conversation_history: list, + generation: int | None = None, + ) -> None: + """Run gateway auto-title after delivery when this is an early exchange.""" + if not self._session_db or not session_id or not user_message or not assistant_response: + return + try: + from agent.title_generator import should_auto_title + except Exception: + return + if not should_auto_title(conversation_history): + return + + def _launch() -> None: + self._track_background_task( + asyncio.create_task( + self._auto_title_gateway_session( + session_id=session_id, + session_key=session_key, + source=source, + user_message=user_message, + assistant_response=assistant_response, + ) + ) + ) + + adapter = self.adapters.get(source.platform) if source and source.platform else None + register = getattr(adapter, "register_post_delivery_callback", None) if adapter else None + if session_key and callable(register): + register(session_key, _launch, generation=generation) + else: + _launch() + def _resolve_session_agent_runtime( self, *, @@ -7063,17 +7226,14 @@ class GatewayRunner: return "⚠️ Title is empty after cleanup. Please use printable characters." # Set the title try: - if self._session_db.set_session_title(session_id, sanitized): - response = f"✏️ Session title set: **{sanitized}**" - if source.platform == Platform.TELEGRAM and source.thread_id: - adapter = self.adapters.get(source.platform) - if adapter and await adapter.update_thread_title( - source.chat_id, - source.thread_id, - sanitized, - ): - response += "\n🧵 Telegram topic renamed too." - return response + if await self._apply_session_title( + session_id=session_id, + session_key=session_entry.session_key, + source=source, + title=sanitized, + defer_thread_sync=True, + ): + return f"✏️ Session title set: **{sanitized}**" else: return "Session not found in database." except ValueError as e: @@ -10092,14 +10252,15 @@ class GatewayRunner: # Auto-generate session title after first exchange (non-blocking) if final_response and self._session_db: try: - from agent.title_generator import maybe_auto_title all_msgs = result_holder[0].get("messages", []) if result_holder[0] else [] - maybe_auto_title( - self._session_db, - effective_session_id, - message, - final_response, - all_msgs, + self._maybe_schedule_gateway_auto_title( + session_id=effective_session_id, + session_key=session_key, + source=source, + user_message=message, + assistant_response=final_response, + conversation_history=all_msgs, + generation=run_generation, ) except Exception: pass diff --git a/hermes_state.py b/hermes_state.py index 2d8a0fd4af..2a3084ba35 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -641,6 +641,31 @@ class SessionDB: rowcount = self._execute_write(_do) return rowcount > 0 + def set_session_title_if_missing(self, session_id: str, title: str) -> bool: + """Atomically set a session title only when the current title is NULL.""" + title = self.sanitize_title(title) + if not title: + return False + + def _do(conn): + cursor = conn.execute( + "SELECT id FROM sessions WHERE title = ? AND id != ?", + (title, session_id), + ) + conflict = cursor.fetchone() + if conflict: + raise ValueError( + f"Title '{title}' is already in use by session {conflict['id']}" + ) + cursor = conn.execute( + "UPDATE sessions SET title = ? WHERE id = ? AND title IS NULL", + (title, session_id), + ) + return cursor.rowcount + + rowcount = self._execute_write(_do) + return rowcount > 0 + def get_session_title(self, session_id: str) -> Optional[str]: """Get the title for a session, or None.""" with self._lock: diff --git a/tests/agent/test_title_generator.py b/tests/agent/test_title_generator.py index 98fb8fb213..38900ba7eb 100644 --- a/tests/agent/test_title_generator.py +++ b/tests/agent/test_title_generator.py @@ -7,6 +7,7 @@ import pytest from agent.title_generator import ( generate_title, + generate_title_if_missing, auto_title_session, maybe_auto_title, ) @@ -89,19 +90,26 @@ class TestAutoTitleSession: def test_skips_if_no_session_db(self): auto_title_session(None, "sess-1", "hi", "hello") # should not crash - def test_skips_if_title_exists(self): + def test_generate_title_if_missing_skips_if_title_exists(self): db = MagicMock() db.get_session_title.return_value = "Existing Title" with patch("agent.title_generator.generate_title") as gen: - auto_title_session(db, "sess-1", "hi", "hello") + assert generate_title_if_missing(db, "sess-1", "hi", "hello") is None gen.assert_not_called() + def test_generate_title_if_missing_returns_generated_title(self): + db = MagicMock() + db.get_session_title.return_value = None + + with patch("agent.title_generator.generate_title", return_value="New Title"): + assert generate_title_if_missing(db, "sess-1", "hi", "hello") == "New Title" + def test_generates_and_sets_title(self): db = MagicMock() db.get_session_title.return_value = None - with patch("agent.title_generator.generate_title", return_value="New Title"): + with patch("agent.title_generator.generate_title_if_missing", return_value="New Title"): auto_title_session(db, "sess-1", "hi", "hello") db.set_session_title.assert_called_once_with("sess-1", "New Title") @@ -109,7 +117,7 @@ class TestAutoTitleSession: db = MagicMock() db.get_session_title.return_value = None - with patch("agent.title_generator.generate_title", return_value=None): + with patch("agent.title_generator.generate_title_if_missing", return_value=None): auto_title_session(db, "sess-1", "hi", "hello") db.set_session_title.assert_not_called() diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index 59e9fa0408..24eb6e1d3b 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py @@ -782,6 +782,22 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send() assert released == [True] +@pytest.mark.asyncio +async def test_post_delivery_callbacks_compose_for_same_session(): + """Multiple post-delivery callbacks for a session should all run.""" + adapter = ProgressCaptureAdapter() + fired = [] + + adapter.register_post_delivery_callback("sess", lambda: fired.append("first"), generation=3) + adapter.register_post_delivery_callback("sess", lambda: fired.append("second"), generation=3) + + callback = adapter.pop_post_delivery_callback("sess", generation=3) + assert callable(callback) + callback() + + assert fired == ["first", "second"] + + @pytest.mark.asyncio async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path): import yaml diff --git a/tests/gateway/test_title_command.py b/tests/gateway/test_title_command.py index 0107df75c8..37d82a1024 100644 --- a/tests/gateway/test_title_command.py +++ b/tests/gateway/test_title_command.py @@ -1,9 +1,10 @@ -"""Tests for /title gateway slash command. +"""Tests for gateway session-title flows. -Tests the _handle_title_command handler (set/show session titles) -across all gateway messenger platforms. +Tests the /title handler plus native gateway session-title propagation +for manual and auto-generated titles. """ +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -33,6 +34,7 @@ def _make_runner(session_db=None): runner.adapters = {} runner._voice_mode = {} runner._session_db = session_db + runner._background_tasks = set() # Mock session_store that returns a session entry with a known session_id mock_session_entry = MagicMock() @@ -72,7 +74,7 @@ class TestHandleTitleCommand: @pytest.mark.asyncio async def test_set_title_renames_telegram_topic_when_in_thread(self, tmp_path): - """Telegram /title should also rename the active topic thread when possible.""" + """Telegram /title should schedule thread-title sync via the native callback path.""" from hermes_state import SessionDB db = SessionDB(db_path=tmp_path / "state.db") db.create_session("test_session_123", "telegram") @@ -85,14 +87,18 @@ class TestHandleTitleCommand: event = _make_event(text="/title Indicative Topic", thread_id="470094") result = await runner._handle_title_command(event) + adapter.register_post_delivery_callback.assert_called_once() + callback = adapter.register_post_delivery_callback.call_args.args[1] + callback() + await asyncio.sleep(0) adapter.update_thread_title.assert_awaited_once_with("67890", "470094", "Indicative Topic") - assert "Telegram topic renamed too" in result + assert "Telegram topic renamed too" not in result assert db.get_session_title("test_session_123") == "Indicative Topic" db.close() @pytest.mark.asyncio async def test_set_title_renames_telegram_general_topic_when_thread_is_one(self, tmp_path): - """Telegram General topic thread_id=1 should still trigger a rename attempt.""" + """Telegram General topic thread_id=1 should also use the deferred sync path.""" from hermes_state import SessionDB db = SessionDB(db_path=tmp_path / "state.db") db.create_session("test_session_123", "telegram") @@ -105,8 +111,12 @@ class TestHandleTitleCommand: event = _make_event(text="/title Lobby", thread_id="1") result = await runner._handle_title_command(event) + adapter.register_post_delivery_callback.assert_called_once() + callback = adapter.register_post_delivery_callback.call_args.args[1] + callback() + await asyncio.sleep(0) adapter.update_thread_title.assert_awaited_once_with("67890", "1", "Lobby") - assert "Telegram topic renamed too" in result + assert "Telegram topic renamed too" not in result assert db.get_session_title("test_session_123") == "Lobby" db.close() @@ -125,11 +135,91 @@ class TestHandleTitleCommand: event = _make_event(text="/title Plain Chat Title") result = await runner._handle_title_command(event) - adapter.update_thread_title.assert_not_called() - assert "Telegram topic renamed too" not in result + adapter.register_post_delivery_callback.assert_not_called() assert db.get_session_title("test_session_123") == "Plain Chat Title" db.close() + +class TestGatewayAutoTitleSync: + @pytest.mark.asyncio + async def test_auto_title_flow_uses_same_session_title_path(self, tmp_path): + """Gateway auto-title should persist title and sync Telegram thread title.""" + from hermes_state import SessionDB + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("test_session_123", "telegram") + + runner = _make_runner(session_db=db) + adapter = MagicMock() + adapter.update_thread_title = AsyncMock(return_value=True) + runner.adapters[Platform.TELEGRAM] = adapter + source = _make_event(thread_id="470094").source + + with patch("agent.title_generator.generate_title_if_missing", return_value="Auto Topic"): + await runner._auto_title_gateway_session( + session_id="test_session_123", + session_key="telegram:12345:67890", + source=source, + user_message="hello", + assistant_response="hi there", + ) + + assert db.get_session_title("test_session_123") == "Auto Topic" + adapter.update_thread_title.assert_awaited_once_with("67890", "470094", "Auto Topic") + db.close() + + @pytest.mark.asyncio + async def test_auto_title_skips_platform_sync_when_no_thread(self, tmp_path): + """Gateway auto-title without a thread should remain DB-only.""" + from hermes_state import SessionDB + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("test_session_123", "telegram") + + runner = _make_runner(session_db=db) + adapter = MagicMock() + adapter.update_thread_title = AsyncMock(return_value=True) + runner.adapters[Platform.TELEGRAM] = adapter + source = _make_event().source + + with patch("agent.title_generator.generate_title_if_missing", return_value="Auto Session"): + await runner._auto_title_gateway_session( + session_id="test_session_123", + session_key="telegram:12345:67890", + source=source, + user_message="hello", + assistant_response="hi there", + ) + + assert db.get_session_title("test_session_123") == "Auto Session" + adapter.update_thread_title.assert_not_called() + db.close() + + @pytest.mark.asyncio + async def test_auto_title_skips_overwriting_existing_manual_title(self, tmp_path): + """Gateway auto-title should not clobber a title set while generation was in flight.""" + from hermes_state import SessionDB + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("test_session_123", "telegram") + db.set_session_title("test_session_123", "Manual Title") + + runner = _make_runner(session_db=db) + adapter = MagicMock() + adapter.update_thread_title = AsyncMock(return_value=True) + runner.adapters[Platform.TELEGRAM] = adapter + source = _make_event(thread_id="470094").source + + with patch("agent.title_generator.generate_title_if_missing", return_value="Auto Topic"): + await runner._auto_title_gateway_session( + session_id="test_session_123", + session_key="telegram:12345:67890", + source=source, + user_message="hello", + assistant_response="hi there", + ) + + assert db.get_session_title("test_session_123") == "Manual Title" + adapter.update_thread_title.assert_not_called() + db.close() + @pytest.mark.asyncio async def test_show_title_when_set(self, tmp_path): """Showing title when one is set returns the title.""" diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index dfb2445c55..508db86d16 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -945,6 +945,7 @@ class TestSessionTitle: def test_set_title_nonexistent_session(self, db): assert db.set_session_title("nonexistent", "Title") is False + assert db.set_session_title_if_missing("nonexistent", "Title") is False def test_title_initially_none(self, db): db.create_session(session_id="s1", source="cli") @@ -959,6 +960,20 @@ class TestSessionTitle: session = db.get_session("s1") assert session["title"] == "Updated Title" + def test_set_title_if_missing_only_sets_once(self, db): + db.create_session(session_id="s1", source="cli") + assert db.set_session_title_if_missing("s1", "Initial Title") is True + assert db.set_session_title_if_missing("s1", "Ignored Title") is False + assert db.get_session("s1")["title"] == "Initial Title" + + def test_set_title_if_missing_respects_uniqueness(self, db): + db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="cli") + db.set_session_title("s1", "Taken") + + with pytest.raises(ValueError, match="already in use"): + db.set_session_title_if_missing("s2", "Taken") + def test_title_in_search_sessions(self, db): db.create_session(session_id="s1", source="cli") db.set_session_title("s1", "Debugging Auth")