mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-30 06:41:51 +00:00
fix(acp): refresh session info after auto-title
This commit is contained in:
parent
eda1c97a1e
commit
741a349458
2 changed files with 81 additions and 0 deletions
|
|
@ -46,6 +46,7 @@ from acp.schema import (
|
|||
ResourceContentBlock,
|
||||
SessionCapabilities,
|
||||
SessionForkCapabilities,
|
||||
SessionInfoUpdate,
|
||||
SessionListCapabilities,
|
||||
SessionMode,
|
||||
SessionModeState,
|
||||
|
|
@ -707,6 +708,35 @@ class HermesACPAgent(acp.Agent):
|
|||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _send_session_info_update(self, session_id: str) -> None:
|
||||
"""Send ACP native session metadata after Hermes changes it."""
|
||||
if not self._conn:
|
||||
return
|
||||
try:
|
||||
row = self.session_manager._get_db().get_session(session_id)
|
||||
except Exception:
|
||||
logger.debug("Could not read ACP session info for %s", session_id, exc_info=True)
|
||||
return
|
||||
if not row:
|
||||
return
|
||||
|
||||
title = row.get("title")
|
||||
updated_at = row.get("updated_at")
|
||||
if updated_at is not None and not isinstance(updated_at, str):
|
||||
updated_at = str(updated_at)
|
||||
update = SessionInfoUpdate(
|
||||
session_update="session_info_update",
|
||||
title=title if isinstance(title, str) and title.strip() else None,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
try:
|
||||
await self._conn.session_update(
|
||||
session_id=session_id,
|
||||
update=update,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not send ACP session info update for %s", session_id, exc_info=True)
|
||||
|
||||
def _schedule_usage_update(self, state: SessionState) -> None:
|
||||
"""Schedule native context indicator refresh after ACP responses."""
|
||||
if not self._conn:
|
||||
|
|
@ -1471,12 +1501,20 @@ class HermesACPAgent(acp.Agent):
|
|||
try:
|
||||
from agent.title_generator import maybe_auto_title
|
||||
|
||||
def _notify_title_update(_title: str) -> None:
|
||||
if conn:
|
||||
loop.call_soon_threadsafe(
|
||||
asyncio.create_task,
|
||||
self._send_session_info_update(session_id),
|
||||
)
|
||||
|
||||
maybe_auto_title(
|
||||
self.session_manager._get_db(),
|
||||
session_id,
|
||||
user_text,
|
||||
final_response,
|
||||
state.history,
|
||||
title_callback=_notify_title_update,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from acp.schema import (
|
|||
SetSessionModelResponse,
|
||||
SetSessionModeResponse,
|
||||
SessionInfo,
|
||||
SessionInfoUpdate,
|
||||
TextContentBlock,
|
||||
ToolCallProgress,
|
||||
ToolCallStart,
|
||||
|
|
@ -1140,6 +1141,48 @@ class TestPrompt:
|
|||
assert mock_title.call_args.args[1] == new_resp.session_id
|
||||
assert mock_title.call_args.args[2] == "fix the broken ACP history"
|
||||
assert mock_title.call_args.args[3] == "Here is the fix."
|
||||
assert callable(mock_title.call_args.kwargs["title_callback"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_sends_session_info_update_after_auto_title(self, agent):
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
resp = await agent.new_session(cwd="/tmp")
|
||||
state = agent.session_manager.get_session(resp.session_id)
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "Done.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "fix zed titles"},
|
||||
{"role": "assistant", "content": "Done."},
|
||||
],
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
})
|
||||
|
||||
def fake_auto_title(db, session_id, user_text, final_response, history, **kwargs):
|
||||
db.set_session_title(session_id, "Fix Zed titles")
|
||||
kwargs["title_callback"]("Fix Zed titles")
|
||||
|
||||
with patch("agent.title_generator.maybe_auto_title", side_effect=fake_auto_title):
|
||||
mock_conn.session_update.reset_mock()
|
||||
await agent.prompt(
|
||||
session_id=resp.session_id,
|
||||
prompt=[TextContentBlock(type="text", text="fix zed titles")],
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
updates = [
|
||||
call.kwargs.get("update") or call.args[1]
|
||||
for call in mock_conn.session_update.await_args_list
|
||||
]
|
||||
info_updates = [u for u in updates if isinstance(u, SessionInfoUpdate)]
|
||||
assert len(info_updates) == 1
|
||||
assert info_updates[0].session_update == "session_info_update"
|
||||
assert info_updates[0].title == "Fix Zed titles"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue