refactor: unify gateway session title sync flow

This commit is contained in:
aaron 2026-04-23 09:09:33 +00:00
parent 59d45346ba
commit 3438de3623
8 changed files with 421 additions and 61 deletions

View file

@ -1028,6 +1028,169 @@ class GatewayRunner:
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
)
def _track_background_task(self, task: asyncio.Task) -> asyncio.Task:
"""Keep a background task alive until completion."""
background_tasks = getattr(self, "_background_tasks", None)
if isinstance(background_tasks, set):
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
return task
async def _sync_session_title_to_source(self, source: SessionSource, title: str) -> bool:
"""Best-effort platform thread/topic title sync for a session title."""
if not source or not source.thread_id or not source.platform:
return False
adapter = self.adapters.get(source.platform)
if not adapter:
return False
updater = getattr(adapter, "update_thread_title", None)
if not callable(updater):
return False
try:
return bool(await updater(source.chat_id, source.thread_id, title))
except Exception as e:
logger.debug("Session-title sync failed for %s thread %s: %s", source.platform, source.thread_id, e)
return False
def _schedule_session_title_sync_after_delivery(
self,
*,
session_key: Optional[str],
source: SessionSource,
title: str,
generation: int | None = None,
) -> bool:
"""Schedule platform title sync after the current response is delivered."""
if not session_key or not source or not source.thread_id or not source.platform:
return False
adapter = self.adapters.get(source.platform)
if not adapter:
return False
def _after_delivery() -> None:
self._track_background_task(
asyncio.create_task(self._sync_session_title_to_source(source, title))
)
register = getattr(adapter, "register_post_delivery_callback", None)
if callable(register):
register(session_key, _after_delivery, generation=generation)
return True
post_callbacks = getattr(adapter, "_post_delivery_callbacks", None)
if isinstance(post_callbacks, dict):
post_callbacks[session_key] = _after_delivery
return True
return False
async def _apply_session_title(
self,
*,
session_id: str,
source: SessionSource,
title: str,
session_key: Optional[str] = None,
defer_thread_sync: bool = False,
sync_thread_now: bool = False,
only_if_missing: bool = False,
generation: int | None = None,
) -> bool:
"""Persist a session title and optionally propagate it to the source thread."""
if not self._session_db:
return False
if only_if_missing:
changed = self._session_db.set_session_title_if_missing(session_id, title)
else:
changed = self._session_db.set_session_title(session_id, title)
if not changed:
return False
if sync_thread_now:
await self._sync_session_title_to_source(source, title)
elif defer_thread_sync:
self._schedule_session_title_sync_after_delivery(
session_key=session_key,
source=source,
title=title,
generation=generation,
)
return True
async def _auto_title_gateway_session(
self,
*,
session_id: str,
session_key: Optional[str],
source: SessionSource,
user_message: str,
assistant_response: str,
) -> None:
"""Generate a missing title for a gateway session and sync it natively."""
if not self._session_db:
return
try:
from agent.title_generator import generate_title_if_missing
title = await asyncio.to_thread(
generate_title_if_missing,
self._session_db,
session_id,
user_message,
assistant_response,
)
if not title:
return
await self._apply_session_title(
session_id=session_id,
session_key=session_key,
source=source,
title=title,
sync_thread_now=True,
only_if_missing=True,
)
except Exception as e:
logger.debug("Gateway auto-title failed for %s: %s", session_id, e)
def _maybe_schedule_gateway_auto_title(
self,
*,
session_id: str,
session_key: Optional[str],
source: SessionSource,
user_message: str,
assistant_response: str,
conversation_history: list,
generation: int | None = None,
) -> None:
"""Run gateway auto-title after delivery when this is an early exchange."""
if not self._session_db or not session_id or not user_message or not assistant_response:
return
try:
from agent.title_generator import should_auto_title
except Exception:
return
if not should_auto_title(conversation_history):
return
def _launch() -> None:
self._track_background_task(
asyncio.create_task(
self._auto_title_gateway_session(
session_id=session_id,
session_key=session_key,
source=source,
user_message=user_message,
assistant_response=assistant_response,
)
)
)
adapter = self.adapters.get(source.platform) if source and source.platform else None
register = getattr(adapter, "register_post_delivery_callback", None) if adapter else None
if session_key and callable(register):
register(session_key, _launch, generation=generation)
else:
_launch()
def _resolve_session_agent_runtime(
self,
*,
@ -7063,17 +7226,14 @@ class GatewayRunner:
return "⚠️ Title is empty after cleanup. Please use printable characters."
# Set the title
try:
if self._session_db.set_session_title(session_id, sanitized):
response = f"✏️ Session title set: **{sanitized}**"
if source.platform == Platform.TELEGRAM and source.thread_id:
adapter = self.adapters.get(source.platform)
if adapter and await adapter.update_thread_title(
source.chat_id,
source.thread_id,
sanitized,
):
response += "\n🧵 Telegram topic renamed too."
return response
if await self._apply_session_title(
session_id=session_id,
session_key=session_entry.session_key,
source=source,
title=sanitized,
defer_thread_sync=True,
):
return f"✏️ Session title set: **{sanitized}**"
else:
return "Session not found in database."
except ValueError as e:
@ -10092,14 +10252,15 @@ class GatewayRunner:
# Auto-generate session title after first exchange (non-blocking)
if final_response and self._session_db:
try:
from agent.title_generator import maybe_auto_title
all_msgs = result_holder[0].get("messages", []) if result_holder[0] else []
maybe_auto_title(
self._session_db,
effective_session_id,
message,
final_response,
all_msgs,
self._maybe_schedule_gateway_auto_title(
session_id=effective_session_id,
session_key=session_key,
source=source,
user_message=message,
assistant_response=final_response,
conversation_history=all_msgs,
generation=run_generation,
)
except Exception:
pass