mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-27 01:11:40 +00:00
refactor: unify gateway session title sync flow
This commit is contained in:
parent
59d45346ba
commit
3438de3623
8 changed files with 421 additions and 61 deletions
|
|
@ -56,6 +56,26 @@ def generate_title(user_message: str, assistant_response: str, timeout: float =
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_title_if_missing(
|
||||||
|
session_db,
|
||||||
|
session_id: str,
|
||||||
|
user_message: str,
|
||||||
|
assistant_response: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Return a generated title only if the session does not already have one."""
|
||||||
|
if not session_db or not session_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
existing = session_db.get_session_title(session_id)
|
||||||
|
if existing:
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return generate_title(user_message, assistant_response)
|
||||||
|
|
||||||
|
|
||||||
def auto_title_session(
|
def auto_title_session(
|
||||||
session_db,
|
session_db,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
|
@ -70,18 +90,12 @@ def auto_title_session(
|
||||||
- session already has a title (user-set or previously auto-generated)
|
- session already has a title (user-set or previously auto-generated)
|
||||||
- title generation fails
|
- title generation fails
|
||||||
"""
|
"""
|
||||||
if not session_db or not session_id:
|
title = generate_title_if_missing(
|
||||||
return
|
session_db,
|
||||||
|
session_id,
|
||||||
# Check if title already exists (user may have set one via /title before first response)
|
user_message,
|
||||||
try:
|
assistant_response,
|
||||||
existing = session_db.get_session_title(session_id)
|
)
|
||||||
if existing:
|
|
||||||
return
|
|
||||||
except Exception:
|
|
||||||
return
|
|
||||||
|
|
||||||
title = generate_title(user_message, assistant_response)
|
|
||||||
if not title:
|
if not title:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -92,6 +106,12 @@ def auto_title_session(
|
||||||
logger.debug("Failed to set auto-generated title: %s", e)
|
logger.debug("Failed to set auto-generated title: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
def should_auto_title(conversation_history: list) -> bool:
|
||||||
|
"""Return whether this history is still early enough for auto-titling."""
|
||||||
|
user_msg_count = sum(1 for m in (conversation_history or []) if m.get("role") == "user")
|
||||||
|
return user_msg_count <= 2
|
||||||
|
|
||||||
|
|
||||||
def maybe_auto_title(
|
def maybe_auto_title(
|
||||||
session_db,
|
session_db,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
|
@ -108,12 +128,7 @@ def maybe_auto_title(
|
||||||
if not session_db or not session_id or not user_message or not assistant_response:
|
if not session_db or not session_id or not user_message or not assistant_response:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Count user messages in history to detect first exchange.
|
if not should_auto_title(conversation_history):
|
||||||
# conversation_history includes the exchange that just happened,
|
|
||||||
# so for a first exchange we expect exactly 1 user message
|
|
||||||
# (or 2 counting system). Be generous: generate on first 2 exchanges.
|
|
||||||
user_msg_count = sum(1 for m in (conversation_history or []) if m.get("role") == "user")
|
|
||||||
if user_msg_count > 2:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
|
|
|
||||||
|
|
@ -1522,13 +1522,19 @@ class BasePlatformAdapter(ABC):
|
||||||
|
|
||||||
``generation`` lets callers tie the callback to a specific gateway run
|
``generation`` lets callers tie the callback to a specific gateway run
|
||||||
generation so stale runs cannot clear callbacks owned by a fresher run.
|
generation so stale runs cannot clear callbacks owned by a fresher run.
|
||||||
|
Multiple callbacks may be registered for the same session; matching
|
||||||
|
callbacks run in registration order.
|
||||||
"""
|
"""
|
||||||
if not session_key or not callable(callback):
|
if not session_key or not callable(callback):
|
||||||
return
|
return
|
||||||
if generation is None:
|
entry: Any = callback if generation is None else (int(generation), callback)
|
||||||
self._post_delivery_callbacks[session_key] = callback
|
existing = self._post_delivery_callbacks.get(session_key)
|
||||||
|
if existing is None:
|
||||||
|
self._post_delivery_callbacks[session_key] = entry
|
||||||
|
elif isinstance(existing, list):
|
||||||
|
existing.append(entry)
|
||||||
else:
|
else:
|
||||||
self._post_delivery_callbacks[session_key] = (int(generation), callback)
|
self._post_delivery_callbacks[session_key] = [existing, entry]
|
||||||
|
|
||||||
def pop_post_delivery_callback(
|
def pop_post_delivery_callback(
|
||||||
self,
|
self,
|
||||||
|
|
@ -1536,22 +1542,46 @@ class BasePlatformAdapter(ABC):
|
||||||
*,
|
*,
|
||||||
generation: int | None = None,
|
generation: int | None = None,
|
||||||
) -> Callable | None:
|
) -> Callable | None:
|
||||||
"""Pop a deferred callback, optionally requiring generation ownership."""
|
"""Pop deferred callbacks, optionally requiring generation ownership."""
|
||||||
if not session_key:
|
if not session_key:
|
||||||
return None
|
return None
|
||||||
entry = self._post_delivery_callbacks.get(session_key)
|
entry = self._post_delivery_callbacks.get(session_key)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
return None
|
return None
|
||||||
if isinstance(entry, tuple) and len(entry) == 2:
|
|
||||||
entry_generation, callback = entry
|
entries = entry if isinstance(entry, list) else [entry]
|
||||||
|
matched: list[Callable] = []
|
||||||
|
remaining: list[Any] = []
|
||||||
|
|
||||||
|
for item in entries:
|
||||||
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
|
entry_generation, callback = item
|
||||||
if generation is not None and int(entry_generation) != int(generation):
|
if generation is not None and int(entry_generation) != int(generation):
|
||||||
return None
|
remaining.append(item)
|
||||||
self._post_delivery_callbacks.pop(session_key, None)
|
continue
|
||||||
return callback if callable(callback) else None
|
if callable(callback):
|
||||||
|
matched.append(callback)
|
||||||
|
continue
|
||||||
|
|
||||||
if generation is not None:
|
if generation is not None:
|
||||||
return None
|
remaining.append(item)
|
||||||
|
continue
|
||||||
|
if callable(item):
|
||||||
|
matched.append(item)
|
||||||
|
|
||||||
|
if remaining:
|
||||||
|
self._post_delivery_callbacks[session_key] = remaining if len(remaining) > 1 else remaining[0]
|
||||||
|
else:
|
||||||
self._post_delivery_callbacks.pop(session_key, None)
|
self._post_delivery_callbacks.pop(session_key, None)
|
||||||
return entry if callable(entry) else None
|
|
||||||
|
if not matched:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _run_all() -> None:
|
||||||
|
for callback in matched:
|
||||||
|
callback()
|
||||||
|
|
||||||
|
return _run_all
|
||||||
|
|
||||||
# ── Processing lifecycle hooks ──────────────────────────────────────────
|
# ── Processing lifecycle hooks ──────────────────────────────────────────
|
||||||
# Subclasses override these to react to message processing events
|
# Subclasses override these to react to message processing events
|
||||||
|
|
|
||||||
195
gateway/run.py
195
gateway/run.py
|
|
@ -1028,6 +1028,169 @@ class GatewayRunner:
|
||||||
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
|
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _track_background_task(self, task: asyncio.Task) -> asyncio.Task:
|
||||||
|
"""Keep a background task alive until completion."""
|
||||||
|
background_tasks = getattr(self, "_background_tasks", None)
|
||||||
|
if isinstance(background_tasks, set):
|
||||||
|
background_tasks.add(task)
|
||||||
|
task.add_done_callback(background_tasks.discard)
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def _sync_session_title_to_source(self, source: SessionSource, title: str) -> bool:
|
||||||
|
"""Best-effort platform thread/topic title sync for a session title."""
|
||||||
|
if not source or not source.thread_id or not source.platform:
|
||||||
|
return False
|
||||||
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if not adapter:
|
||||||
|
return False
|
||||||
|
updater = getattr(adapter, "update_thread_title", None)
|
||||||
|
if not callable(updater):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return bool(await updater(source.chat_id, source.thread_id, title))
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Session-title sync failed for %s thread %s: %s", source.platform, source.thread_id, e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _schedule_session_title_sync_after_delivery(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_key: Optional[str],
|
||||||
|
source: SessionSource,
|
||||||
|
title: str,
|
||||||
|
generation: int | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Schedule platform title sync after the current response is delivered."""
|
||||||
|
if not session_key or not source or not source.thread_id or not source.platform:
|
||||||
|
return False
|
||||||
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if not adapter:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _after_delivery() -> None:
|
||||||
|
self._track_background_task(
|
||||||
|
asyncio.create_task(self._sync_session_title_to_source(source, title))
|
||||||
|
)
|
||||||
|
|
||||||
|
register = getattr(adapter, "register_post_delivery_callback", None)
|
||||||
|
if callable(register):
|
||||||
|
register(session_key, _after_delivery, generation=generation)
|
||||||
|
return True
|
||||||
|
|
||||||
|
post_callbacks = getattr(adapter, "_post_delivery_callbacks", None)
|
||||||
|
if isinstance(post_callbacks, dict):
|
||||||
|
post_callbacks[session_key] = _after_delivery
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _apply_session_title(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: str,
|
||||||
|
source: SessionSource,
|
||||||
|
title: str,
|
||||||
|
session_key: Optional[str] = None,
|
||||||
|
defer_thread_sync: bool = False,
|
||||||
|
sync_thread_now: bool = False,
|
||||||
|
only_if_missing: bool = False,
|
||||||
|
generation: int | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Persist a session title and optionally propagate it to the source thread."""
|
||||||
|
if not self._session_db:
|
||||||
|
return False
|
||||||
|
if only_if_missing:
|
||||||
|
changed = self._session_db.set_session_title_if_missing(session_id, title)
|
||||||
|
else:
|
||||||
|
changed = self._session_db.set_session_title(session_id, title)
|
||||||
|
if not changed:
|
||||||
|
return False
|
||||||
|
if sync_thread_now:
|
||||||
|
await self._sync_session_title_to_source(source, title)
|
||||||
|
elif defer_thread_sync:
|
||||||
|
self._schedule_session_title_sync_after_delivery(
|
||||||
|
session_key=session_key,
|
||||||
|
source=source,
|
||||||
|
title=title,
|
||||||
|
generation=generation,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _auto_title_gateway_session(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: str,
|
||||||
|
session_key: Optional[str],
|
||||||
|
source: SessionSource,
|
||||||
|
user_message: str,
|
||||||
|
assistant_response: str,
|
||||||
|
) -> None:
|
||||||
|
"""Generate a missing title for a gateway session and sync it natively."""
|
||||||
|
if not self._session_db:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
from agent.title_generator import generate_title_if_missing
|
||||||
|
|
||||||
|
title = await asyncio.to_thread(
|
||||||
|
generate_title_if_missing,
|
||||||
|
self._session_db,
|
||||||
|
session_id,
|
||||||
|
user_message,
|
||||||
|
assistant_response,
|
||||||
|
)
|
||||||
|
if not title:
|
||||||
|
return
|
||||||
|
await self._apply_session_title(
|
||||||
|
session_id=session_id,
|
||||||
|
session_key=session_key,
|
||||||
|
source=source,
|
||||||
|
title=title,
|
||||||
|
sync_thread_now=True,
|
||||||
|
only_if_missing=True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Gateway auto-title failed for %s: %s", session_id, e)
|
||||||
|
|
||||||
|
def _maybe_schedule_gateway_auto_title(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: str,
|
||||||
|
session_key: Optional[str],
|
||||||
|
source: SessionSource,
|
||||||
|
user_message: str,
|
||||||
|
assistant_response: str,
|
||||||
|
conversation_history: list,
|
||||||
|
generation: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Run gateway auto-title after delivery when this is an early exchange."""
|
||||||
|
if not self._session_db or not session_id or not user_message or not assistant_response:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
from agent.title_generator import should_auto_title
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
if not should_auto_title(conversation_history):
|
||||||
|
return
|
||||||
|
|
||||||
|
def _launch() -> None:
|
||||||
|
self._track_background_task(
|
||||||
|
asyncio.create_task(
|
||||||
|
self._auto_title_gateway_session(
|
||||||
|
session_id=session_id,
|
||||||
|
session_key=session_key,
|
||||||
|
source=source,
|
||||||
|
user_message=user_message,
|
||||||
|
assistant_response=assistant_response,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = self.adapters.get(source.platform) if source and source.platform else None
|
||||||
|
register = getattr(adapter, "register_post_delivery_callback", None) if adapter else None
|
||||||
|
if session_key and callable(register):
|
||||||
|
register(session_key, _launch, generation=generation)
|
||||||
|
else:
|
||||||
|
_launch()
|
||||||
|
|
||||||
def _resolve_session_agent_runtime(
|
def _resolve_session_agent_runtime(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|
@ -7063,17 +7226,14 @@ class GatewayRunner:
|
||||||
return "⚠️ Title is empty after cleanup. Please use printable characters."
|
return "⚠️ Title is empty after cleanup. Please use printable characters."
|
||||||
# Set the title
|
# Set the title
|
||||||
try:
|
try:
|
||||||
if self._session_db.set_session_title(session_id, sanitized):
|
if await self._apply_session_title(
|
||||||
response = f"✏️ Session title set: **{sanitized}**"
|
session_id=session_id,
|
||||||
if source.platform == Platform.TELEGRAM and source.thread_id:
|
session_key=session_entry.session_key,
|
||||||
adapter = self.adapters.get(source.platform)
|
source=source,
|
||||||
if adapter and await adapter.update_thread_title(
|
title=sanitized,
|
||||||
source.chat_id,
|
defer_thread_sync=True,
|
||||||
source.thread_id,
|
|
||||||
sanitized,
|
|
||||||
):
|
):
|
||||||
response += "\n🧵 Telegram topic renamed too."
|
return f"✏️ Session title set: **{sanitized}**"
|
||||||
return response
|
|
||||||
else:
|
else:
|
||||||
return "Session not found in database."
|
return "Session not found in database."
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
@ -10092,14 +10252,15 @@ class GatewayRunner:
|
||||||
# Auto-generate session title after first exchange (non-blocking)
|
# Auto-generate session title after first exchange (non-blocking)
|
||||||
if final_response and self._session_db:
|
if final_response and self._session_db:
|
||||||
try:
|
try:
|
||||||
from agent.title_generator import maybe_auto_title
|
|
||||||
all_msgs = result_holder[0].get("messages", []) if result_holder[0] else []
|
all_msgs = result_holder[0].get("messages", []) if result_holder[0] else []
|
||||||
maybe_auto_title(
|
self._maybe_schedule_gateway_auto_title(
|
||||||
self._session_db,
|
session_id=effective_session_id,
|
||||||
effective_session_id,
|
session_key=session_key,
|
||||||
message,
|
source=source,
|
||||||
final_response,
|
user_message=message,
|
||||||
all_msgs,
|
assistant_response=final_response,
|
||||||
|
conversation_history=all_msgs,
|
||||||
|
generation=run_generation,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -641,6 +641,31 @@ class SessionDB:
|
||||||
rowcount = self._execute_write(_do)
|
rowcount = self._execute_write(_do)
|
||||||
return rowcount > 0
|
return rowcount > 0
|
||||||
|
|
||||||
|
def set_session_title_if_missing(self, session_id: str, title: str) -> bool:
|
||||||
|
"""Atomically set a session title only when the current title is NULL."""
|
||||||
|
title = self.sanitize_title(title)
|
||||||
|
if not title:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _do(conn):
|
||||||
|
cursor = conn.execute(
|
||||||
|
"SELECT id FROM sessions WHERE title = ? AND id != ?",
|
||||||
|
(title, session_id),
|
||||||
|
)
|
||||||
|
conflict = cursor.fetchone()
|
||||||
|
if conflict:
|
||||||
|
raise ValueError(
|
||||||
|
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||||
|
)
|
||||||
|
cursor = conn.execute(
|
||||||
|
"UPDATE sessions SET title = ? WHERE id = ? AND title IS NULL",
|
||||||
|
(title, session_id),
|
||||||
|
)
|
||||||
|
return cursor.rowcount
|
||||||
|
|
||||||
|
rowcount = self._execute_write(_do)
|
||||||
|
return rowcount > 0
|
||||||
|
|
||||||
def get_session_title(self, session_id: str) -> Optional[str]:
|
def get_session_title(self, session_id: str) -> Optional[str]:
|
||||||
"""Get the title for a session, or None."""
|
"""Get the title for a session, or None."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import pytest
|
||||||
|
|
||||||
from agent.title_generator import (
|
from agent.title_generator import (
|
||||||
generate_title,
|
generate_title,
|
||||||
|
generate_title_if_missing,
|
||||||
auto_title_session,
|
auto_title_session,
|
||||||
maybe_auto_title,
|
maybe_auto_title,
|
||||||
)
|
)
|
||||||
|
|
@ -89,19 +90,26 @@ class TestAutoTitleSession:
|
||||||
def test_skips_if_no_session_db(self):
|
def test_skips_if_no_session_db(self):
|
||||||
auto_title_session(None, "sess-1", "hi", "hello") # should not crash
|
auto_title_session(None, "sess-1", "hi", "hello") # should not crash
|
||||||
|
|
||||||
def test_skips_if_title_exists(self):
|
def test_generate_title_if_missing_skips_if_title_exists(self):
|
||||||
db = MagicMock()
|
db = MagicMock()
|
||||||
db.get_session_title.return_value = "Existing Title"
|
db.get_session_title.return_value = "Existing Title"
|
||||||
|
|
||||||
with patch("agent.title_generator.generate_title") as gen:
|
with patch("agent.title_generator.generate_title") as gen:
|
||||||
auto_title_session(db, "sess-1", "hi", "hello")
|
assert generate_title_if_missing(db, "sess-1", "hi", "hello") is None
|
||||||
gen.assert_not_called()
|
gen.assert_not_called()
|
||||||
|
|
||||||
|
def test_generate_title_if_missing_returns_generated_title(self):
|
||||||
|
db = MagicMock()
|
||||||
|
db.get_session_title.return_value = None
|
||||||
|
|
||||||
|
with patch("agent.title_generator.generate_title", return_value="New Title"):
|
||||||
|
assert generate_title_if_missing(db, "sess-1", "hi", "hello") == "New Title"
|
||||||
|
|
||||||
def test_generates_and_sets_title(self):
|
def test_generates_and_sets_title(self):
|
||||||
db = MagicMock()
|
db = MagicMock()
|
||||||
db.get_session_title.return_value = None
|
db.get_session_title.return_value = None
|
||||||
|
|
||||||
with patch("agent.title_generator.generate_title", return_value="New Title"):
|
with patch("agent.title_generator.generate_title_if_missing", return_value="New Title"):
|
||||||
auto_title_session(db, "sess-1", "hi", "hello")
|
auto_title_session(db, "sess-1", "hi", "hello")
|
||||||
db.set_session_title.assert_called_once_with("sess-1", "New Title")
|
db.set_session_title.assert_called_once_with("sess-1", "New Title")
|
||||||
|
|
||||||
|
|
@ -109,7 +117,7 @@ class TestAutoTitleSession:
|
||||||
db = MagicMock()
|
db = MagicMock()
|
||||||
db.get_session_title.return_value = None
|
db.get_session_title.return_value = None
|
||||||
|
|
||||||
with patch("agent.title_generator.generate_title", return_value=None):
|
with patch("agent.title_generator.generate_title_if_missing", return_value=None):
|
||||||
auto_title_session(db, "sess-1", "hi", "hello")
|
auto_title_session(db, "sess-1", "hi", "hello")
|
||||||
db.set_session_title.assert_not_called()
|
db.set_session_title.assert_not_called()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -782,6 +782,22 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send()
|
||||||
assert released == [True]
|
assert released == [True]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_delivery_callbacks_compose_for_same_session():
|
||||||
|
"""Multiple post-delivery callbacks for a session should all run."""
|
||||||
|
adapter = ProgressCaptureAdapter()
|
||||||
|
fired = []
|
||||||
|
|
||||||
|
adapter.register_post_delivery_callback("sess", lambda: fired.append("first"), generation=3)
|
||||||
|
adapter.register_post_delivery_callback("sess", lambda: fired.append("second"), generation=3)
|
||||||
|
|
||||||
|
callback = adapter.pop_post_delivery_callback("sess", generation=3)
|
||||||
|
assert callable(callback)
|
||||||
|
callback()
|
||||||
|
|
||||||
|
assert fired == ["first", "second"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
|
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
|
||||||
import yaml
|
import yaml
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
"""Tests for /title gateway slash command.
|
"""Tests for gateway session-title flows.
|
||||||
|
|
||||||
Tests the _handle_title_command handler (set/show session titles)
|
Tests the /title handler plus native gateway session-title propagation
|
||||||
across all gateway messenger platforms.
|
for manual and auto-generated titles.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -33,6 +34,7 @@ def _make_runner(session_db=None):
|
||||||
runner.adapters = {}
|
runner.adapters = {}
|
||||||
runner._voice_mode = {}
|
runner._voice_mode = {}
|
||||||
runner._session_db = session_db
|
runner._session_db = session_db
|
||||||
|
runner._background_tasks = set()
|
||||||
|
|
||||||
# Mock session_store that returns a session entry with a known session_id
|
# Mock session_store that returns a session entry with a known session_id
|
||||||
mock_session_entry = MagicMock()
|
mock_session_entry = MagicMock()
|
||||||
|
|
@ -72,7 +74,7 @@ class TestHandleTitleCommand:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_title_renames_telegram_topic_when_in_thread(self, tmp_path):
|
async def test_set_title_renames_telegram_topic_when_in_thread(self, tmp_path):
|
||||||
"""Telegram /title should also rename the active topic thread when possible."""
|
"""Telegram /title should schedule thread-title sync via the native callback path."""
|
||||||
from hermes_state import SessionDB
|
from hermes_state import SessionDB
|
||||||
db = SessionDB(db_path=tmp_path / "state.db")
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
db.create_session("test_session_123", "telegram")
|
db.create_session("test_session_123", "telegram")
|
||||||
|
|
@ -85,14 +87,18 @@ class TestHandleTitleCommand:
|
||||||
event = _make_event(text="/title Indicative Topic", thread_id="470094")
|
event = _make_event(text="/title Indicative Topic", thread_id="470094")
|
||||||
result = await runner._handle_title_command(event)
|
result = await runner._handle_title_command(event)
|
||||||
|
|
||||||
|
adapter.register_post_delivery_callback.assert_called_once()
|
||||||
|
callback = adapter.register_post_delivery_callback.call_args.args[1]
|
||||||
|
callback()
|
||||||
|
await asyncio.sleep(0)
|
||||||
adapter.update_thread_title.assert_awaited_once_with("67890", "470094", "Indicative Topic")
|
adapter.update_thread_title.assert_awaited_once_with("67890", "470094", "Indicative Topic")
|
||||||
assert "Telegram topic renamed too" in result
|
assert "Telegram topic renamed too" not in result
|
||||||
assert db.get_session_title("test_session_123") == "Indicative Topic"
|
assert db.get_session_title("test_session_123") == "Indicative Topic"
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_title_renames_telegram_general_topic_when_thread_is_one(self, tmp_path):
|
async def test_set_title_renames_telegram_general_topic_when_thread_is_one(self, tmp_path):
|
||||||
"""Telegram General topic thread_id=1 should still trigger a rename attempt."""
|
"""Telegram General topic thread_id=1 should also use the deferred sync path."""
|
||||||
from hermes_state import SessionDB
|
from hermes_state import SessionDB
|
||||||
db = SessionDB(db_path=tmp_path / "state.db")
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
db.create_session("test_session_123", "telegram")
|
db.create_session("test_session_123", "telegram")
|
||||||
|
|
@ -105,8 +111,12 @@ class TestHandleTitleCommand:
|
||||||
event = _make_event(text="/title Lobby", thread_id="1")
|
event = _make_event(text="/title Lobby", thread_id="1")
|
||||||
result = await runner._handle_title_command(event)
|
result = await runner._handle_title_command(event)
|
||||||
|
|
||||||
|
adapter.register_post_delivery_callback.assert_called_once()
|
||||||
|
callback = adapter.register_post_delivery_callback.call_args.args[1]
|
||||||
|
callback()
|
||||||
|
await asyncio.sleep(0)
|
||||||
adapter.update_thread_title.assert_awaited_once_with("67890", "1", "Lobby")
|
adapter.update_thread_title.assert_awaited_once_with("67890", "1", "Lobby")
|
||||||
assert "Telegram topic renamed too" in result
|
assert "Telegram topic renamed too" not in result
|
||||||
assert db.get_session_title("test_session_123") == "Lobby"
|
assert db.get_session_title("test_session_123") == "Lobby"
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
@ -125,11 +135,91 @@ class TestHandleTitleCommand:
|
||||||
event = _make_event(text="/title Plain Chat Title")
|
event = _make_event(text="/title Plain Chat Title")
|
||||||
result = await runner._handle_title_command(event)
|
result = await runner._handle_title_command(event)
|
||||||
|
|
||||||
adapter.update_thread_title.assert_not_called()
|
adapter.register_post_delivery_callback.assert_not_called()
|
||||||
assert "Telegram topic renamed too" not in result
|
|
||||||
assert db.get_session_title("test_session_123") == "Plain Chat Title"
|
assert db.get_session_title("test_session_123") == "Plain Chat Title"
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGatewayAutoTitleSync:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_title_flow_uses_same_session_title_path(self, tmp_path):
|
||||||
|
"""Gateway auto-title should persist title and sync Telegram thread title."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
|
db.create_session("test_session_123", "telegram")
|
||||||
|
|
||||||
|
runner = _make_runner(session_db=db)
|
||||||
|
adapter = MagicMock()
|
||||||
|
adapter.update_thread_title = AsyncMock(return_value=True)
|
||||||
|
runner.adapters[Platform.TELEGRAM] = adapter
|
||||||
|
source = _make_event(thread_id="470094").source
|
||||||
|
|
||||||
|
with patch("agent.title_generator.generate_title_if_missing", return_value="Auto Topic"):
|
||||||
|
await runner._auto_title_gateway_session(
|
||||||
|
session_id="test_session_123",
|
||||||
|
session_key="telegram:12345:67890",
|
||||||
|
source=source,
|
||||||
|
user_message="hello",
|
||||||
|
assistant_response="hi there",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert db.get_session_title("test_session_123") == "Auto Topic"
|
||||||
|
adapter.update_thread_title.assert_awaited_once_with("67890", "470094", "Auto Topic")
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_title_skips_platform_sync_when_no_thread(self, tmp_path):
|
||||||
|
"""Gateway auto-title without a thread should remain DB-only."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
|
db.create_session("test_session_123", "telegram")
|
||||||
|
|
||||||
|
runner = _make_runner(session_db=db)
|
||||||
|
adapter = MagicMock()
|
||||||
|
adapter.update_thread_title = AsyncMock(return_value=True)
|
||||||
|
runner.adapters[Platform.TELEGRAM] = adapter
|
||||||
|
source = _make_event().source
|
||||||
|
|
||||||
|
with patch("agent.title_generator.generate_title_if_missing", return_value="Auto Session"):
|
||||||
|
await runner._auto_title_gateway_session(
|
||||||
|
session_id="test_session_123",
|
||||||
|
session_key="telegram:12345:67890",
|
||||||
|
source=source,
|
||||||
|
user_message="hello",
|
||||||
|
assistant_response="hi there",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert db.get_session_title("test_session_123") == "Auto Session"
|
||||||
|
adapter.update_thread_title.assert_not_called()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_title_skips_overwriting_existing_manual_title(self, tmp_path):
|
||||||
|
"""Gateway auto-title should not clobber a title set while generation was in flight."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
|
db.create_session("test_session_123", "telegram")
|
||||||
|
db.set_session_title("test_session_123", "Manual Title")
|
||||||
|
|
||||||
|
runner = _make_runner(session_db=db)
|
||||||
|
adapter = MagicMock()
|
||||||
|
adapter.update_thread_title = AsyncMock(return_value=True)
|
||||||
|
runner.adapters[Platform.TELEGRAM] = adapter
|
||||||
|
source = _make_event(thread_id="470094").source
|
||||||
|
|
||||||
|
with patch("agent.title_generator.generate_title_if_missing", return_value="Auto Topic"):
|
||||||
|
await runner._auto_title_gateway_session(
|
||||||
|
session_id="test_session_123",
|
||||||
|
session_key="telegram:12345:67890",
|
||||||
|
source=source,
|
||||||
|
user_message="hello",
|
||||||
|
assistant_response="hi there",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert db.get_session_title("test_session_123") == "Manual Title"
|
||||||
|
adapter.update_thread_title.assert_not_called()
|
||||||
|
db.close()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_show_title_when_set(self, tmp_path):
|
async def test_show_title_when_set(self, tmp_path):
|
||||||
"""Showing title when one is set returns the title."""
|
"""Showing title when one is set returns the title."""
|
||||||
|
|
|
||||||
|
|
@ -945,6 +945,7 @@ class TestSessionTitle:
|
||||||
|
|
||||||
def test_set_title_nonexistent_session(self, db):
|
def test_set_title_nonexistent_session(self, db):
|
||||||
assert db.set_session_title("nonexistent", "Title") is False
|
assert db.set_session_title("nonexistent", "Title") is False
|
||||||
|
assert db.set_session_title_if_missing("nonexistent", "Title") is False
|
||||||
|
|
||||||
def test_title_initially_none(self, db):
|
def test_title_initially_none(self, db):
|
||||||
db.create_session(session_id="s1", source="cli")
|
db.create_session(session_id="s1", source="cli")
|
||||||
|
|
@ -959,6 +960,20 @@ class TestSessionTitle:
|
||||||
session = db.get_session("s1")
|
session = db.get_session("s1")
|
||||||
assert session["title"] == "Updated Title"
|
assert session["title"] == "Updated Title"
|
||||||
|
|
||||||
|
def test_set_title_if_missing_only_sets_once(self, db):
|
||||||
|
db.create_session(session_id="s1", source="cli")
|
||||||
|
assert db.set_session_title_if_missing("s1", "Initial Title") is True
|
||||||
|
assert db.set_session_title_if_missing("s1", "Ignored Title") is False
|
||||||
|
assert db.get_session("s1")["title"] == "Initial Title"
|
||||||
|
|
||||||
|
def test_set_title_if_missing_respects_uniqueness(self, db):
|
||||||
|
db.create_session(session_id="s1", source="cli")
|
||||||
|
db.create_session(session_id="s2", source="cli")
|
||||||
|
db.set_session_title("s1", "Taken")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already in use"):
|
||||||
|
db.set_session_title_if_missing("s2", "Taken")
|
||||||
|
|
||||||
def test_title_in_search_sessions(self, db):
|
def test_title_in_search_sessions(self, db):
|
||||||
db.create_session(session_id="s1", source="cli")
|
db.create_session(session_id="s1", source="cli")
|
||||||
db.set_session_title("s1", "Debugging Auth")
|
db.set_session_title("s1", "Debugging Auth")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue