fix(openviking): sync structured turns with tool parts

This commit is contained in:
Eurekaxun 2026-06-02 14:33:12 +08:00 committed by Hao Zhe
parent 3485bc7225
commit c7b7f92ec1
3 changed files with 639 additions and 21 deletions

View file

@ -70,6 +70,8 @@ _TIMEOUT = 30.0
_SESSION_DRAIN_TIMEOUT = 10.0
_DEFERRED_COMMIT_TIMEOUT = (_TIMEOUT * 2) + 5.0
_REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://")
_SYNC_TRACE_ENV = "HERMES_OPENVIKING_SYNC_TRACE"
_OPENVIKING_RECALL_TOOL_NAMES = {"viking_search", "viking_read", "viking_browse"}
# Maps the viking_remember `category` enum to a viking:// subdirectory.
# Keep in sync with REMEMBER_SCHEMA.parameters.properties.category.enum.
@ -156,6 +158,18 @@ def _derive_openviking_user_text(content: Any) -> str:
return extract_user_instruction_from_skill_message(content) or ""
def _sync_trace_enabled() -> bool:
return os.environ.get(_SYNC_TRACE_ENV, "").strip().lower() in {"1", "true", "yes", "on"}
def _preview(value: Any, limit: int = 160) -> str:
text = "" if value is None else str(value)
text = text.replace("\n", "\\n")
if len(text) > limit:
return text[:limit] + "..."
return text
# ---------------------------------------------------------------------------
# Process-level atexit safety net — ensures pending sessions are committed
# even if shutdown_memory_provider is never called (e.g. gateway crash,
@ -2221,7 +2235,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
try:
self._client.post(f"/api/v1/sessions/{sid}/commit")
self._client.post(
f"/api/v1/sessions/{sid}/commit",
{"keep_recent_count": 0},
)
self._mark_session_committed(sid)
logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count)
return True
@ -2293,7 +2310,261 @@ class OpenVikingMemoryProvider(MemoryProvider):
with self._prefetch_lock:
self._prefetch_result = ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
@staticmethod
def _message_text(content: Any) -> str:
"""Extract text from OpenAI-style string/list content."""
if isinstance(content, str):
return content
if isinstance(content, list):
chunks = []
for block in content:
if isinstance(block, str):
chunks.append(block)
elif isinstance(block, dict):
if block.get("type") == "text" and isinstance(block.get("text"), str):
chunks.append(block["text"])
elif isinstance(block.get("content"), str):
chunks.append(block["content"])
return "\n".join(chunk for chunk in chunks if chunk)
if content is None:
return ""
return str(content)
@classmethod
def _message_matches_text(cls, message: Dict[str, Any], expected: Any) -> bool:
expected_text = cls._message_text(expected).strip()
if not expected_text:
return False
actual_text = cls._message_text(message.get("content")).strip()
return actual_text == expected_text
@classmethod
def _extract_current_turn_messages(
cls,
messages: Optional[List[Dict[str, Any]]],
user_content: str,
assistant_content: str,
) -> List[Dict[str, Any]]:
"""Slice the completed turn out of Hermes' full canonical transcript."""
if not messages:
return []
end_idx: Optional[int] = None
if cls._message_text(assistant_content).strip():
for idx in range(len(messages) - 1, -1, -1):
message = messages[idx]
if (
isinstance(message, dict)
and message.get("role") == "assistant"
and cls._message_matches_text(message, assistant_content)
):
end_idx = idx
break
if end_idx is None:
for idx in range(len(messages) - 1, -1, -1):
message = messages[idx]
if isinstance(message, dict) and message.get("role") == "assistant":
end_idx = idx
break
if end_idx is None:
end_idx = len(messages) - 1
start_idx: Optional[int] = None
if cls._message_text(user_content).strip():
for idx in range(end_idx, -1, -1):
message = messages[idx]
if (
isinstance(message, dict)
and message.get("role") == "user"
and cls._message_matches_text(message, user_content)
):
start_idx = idx
break
if start_idx is None:
for idx in range(end_idx, -1, -1):
message = messages[idx]
if isinstance(message, dict) and message.get("role") == "user":
start_idx = idx
break
if start_idx is None:
return []
return [message for message in messages[start_idx : end_idx + 1] if isinstance(message, dict)]
@staticmethod
def _tool_call_id(tool_call: Dict[str, Any]) -> str:
return str(tool_call.get("id") or tool_call.get("tool_call_id") or "")
@staticmethod
def _tool_call_name(tool_call: Dict[str, Any]) -> str:
function = tool_call.get("function")
if isinstance(function, dict):
return str(function.get("name") or "")
return str(tool_call.get("name") or "")
@staticmethod
def _is_openviking_recall_tool_name(tool_name: Any) -> bool:
return str(tool_name or "").strip().lower() in _OPENVIKING_RECALL_TOOL_NAMES
@staticmethod
def _tool_call_input(tool_call: Dict[str, Any]) -> Dict[str, Any]:
function = tool_call.get("function")
raw_args: Any = None
if isinstance(function, dict):
raw_args = function.get("arguments")
if raw_args is None:
raw_args = tool_call.get("args")
if raw_args is None:
return {}
if isinstance(raw_args, dict):
return raw_args
if isinstance(raw_args, str):
if not raw_args.strip():
return {}
try:
parsed = json.loads(raw_args)
except Exception:
return {"value": raw_args}
if isinstance(parsed, dict):
return parsed
return {"value": parsed}
return {"value": raw_args}
@classmethod
def _tool_result_status(cls, message: Dict[str, Any]) -> str:
raw_status = str(message.get("status") or message.get("tool_status") or "").lower()
if raw_status in {"error", "failed", "failure"}:
return "error"
if raw_status in {"completed", "complete", "success", "succeeded"}:
return "completed"
text = cls._message_text(message.get("content")).strip()
if text:
try:
parsed = json.loads(text)
except Exception:
parsed = None
if isinstance(parsed, dict):
status = str(parsed.get("status") or "").lower()
exit_code = parsed.get("exit_code")
if (
status in {"error", "failed", "failure"}
or parsed.get("success") is False
or bool(parsed.get("error"))
or (isinstance(exit_code, int) and exit_code != 0)
):
return "error"
return "completed"
@classmethod
def _messages_to_openviking_batch(
cls,
messages: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Convert Hermes canonical messages into OpenViking batch payloads."""
tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
completed_tool_ids: set[str] = set()
skipped_tool_ids: set[str] = set()
for message in messages:
if not isinstance(message, dict):
continue
if message.get("role") == "tool":
tool_id = str(message.get("tool_call_id") or message.get("id") or "")
if tool_id:
completed_tool_ids.add(tool_id)
if cls._is_openviking_recall_tool_name(message.get("name")):
skipped_tool_ids.add(tool_id)
continue
if message.get("role") != "assistant":
continue
for tool_call in message.get("tool_calls") or []:
if not isinstance(tool_call, dict):
continue
tool_id = cls._tool_call_id(tool_call)
tool_name = cls._tool_call_name(tool_call)
if tool_id:
tool_calls_by_id[tool_id] = {
"tool_name": tool_name,
"tool_input": cls._tool_call_input(tool_call),
}
if cls._is_openviking_recall_tool_name(tool_name):
skipped_tool_ids.add(tool_id)
payload_messages: List[Dict[str, Any]] = []
pending_tool_parts: List[Dict[str, Any]] = []
def flush_tool_parts() -> None:
nonlocal pending_tool_parts
if pending_tool_parts:
payload_messages.append({"role": "user", "parts": pending_tool_parts})
pending_tool_parts = []
for message in messages:
if not isinstance(message, dict):
continue
role = str(message.get("role") or "")
if role in {"system", "developer"}:
continue
if role == "tool":
tool_id = str(message.get("tool_call_id") or message.get("id") or "")
prior_call = tool_calls_by_id.get(tool_id, {})
tool_name = str(message.get("name") or prior_call.get("tool_name") or "")
if tool_id in skipped_tool_ids or cls._is_openviking_recall_tool_name(tool_name):
continue
tool_part = {
"type": "tool",
"tool_id": tool_id,
"tool_name": tool_name,
"tool_input": prior_call.get("tool_input", {}),
"tool_output": cls._message_text(message.get("content")),
"tool_status": cls._tool_result_status(message),
}
pending_tool_parts.append(tool_part)
continue
if role not in {"user", "assistant"}:
continue
flush_tool_parts()
parts: List[Dict[str, Any]] = []
text = cls._message_text(message.get("content"))
if text:
parts.append({"type": "text", "text": text})
if role == "assistant":
for tool_call in message.get("tool_calls") or []:
if not isinstance(tool_call, dict):
continue
tool_id = cls._tool_call_id(tool_call)
tool_name = cls._tool_call_name(tool_call)
if tool_id in skipped_tool_ids or cls._is_openviking_recall_tool_name(tool_name):
continue
if tool_id in completed_tool_ids:
continue
parts.append({
"type": "tool",
"tool_id": tool_id,
"tool_name": tool_name,
"tool_input": cls._tool_call_input(tool_call),
"tool_status": "pending",
})
if parts:
payload_messages.append({"role": role, "parts": parts})
flush_tool_parts()
return payload_messages
def sync_turn(
self,
user_content: str,
assistant_content: str,
*,
session_id: str = "",
messages: Optional[List[Dict[str, Any]]] = None,
) -> None:
"""Record the conversation turn in OpenViking's session (non-blocking)."""
if not self._client:
return
@ -2302,6 +2573,37 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not user_content:
return
turn_messages = (
self._extract_current_turn_messages(messages, user_content, assistant_content)
if messages is not None
else []
)
if turn_messages:
turn_messages = [dict(message) for message in turn_messages]
for message in turn_messages:
if message.get("role") == "user":
message["content"] = user_content
break
batch_messages = self._messages_to_openviking_batch(turn_messages)
if _sync_trace_enabled():
logger.info(
"OpenViking sync_turn trace: session_arg=%r cached_session=%r "
"messages_param_supported=true messages_present=%s message_count=%s "
"turn_message_count=%d batch_message_count=%d user_len=%d assistant_len=%d "
"user_preview=%r assistant_preview=%r",
session_id,
self._session_id,
messages is not None,
len(messages) if messages is not None else None,
len(turn_messages),
len(batch_messages),
len(str(user_content or "")),
len(str(assistant_content or "")),
_preview(user_content),
_preview(assistant_content),
)
# Snapshot the sid and bump the turn counter atomically so a
# concurrent on_session_switch/on_session_end can't interleave its
# snapshot+reset between the read and the increment (lost turn) and so
@ -2313,24 +2615,39 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._turn_count += 1
def _sync():
try:
client = self._new_client()
def _post_turn(client: _VikingClient) -> None:
if batch_messages:
payload = {"messages": batch_messages}
if _sync_trace_enabled():
logger.info(
"OpenViking sync_turn trace: POST /api/v1/sessions/%s/messages/batch payload=%s",
sid,
json.dumps(payload, ensure_ascii=False),
)
try:
client.post(f"/api/v1/sessions/{sid}/messages/batch", payload)
return
except Exception as batch_error:
logger.warning(
"OpenViking structured sync failed; falling back to text sync: %s",
batch_error,
)
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
self._message_text(assistant_content)[:4000],
)
try:
client = self._new_client()
_post_turn(client)
except Exception as e:
logger.debug("OpenViking sync_turn failed, reconnecting: %s", e)
try:
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
_post_turn(client)
except Exception as retry_error:
logger.warning("OpenViking sync_turn failed: %s", retry_error)

View file

@ -265,6 +265,280 @@ class TestOpenVikingSkillQuerySafety:
assert RecordingVikingClient.calls == []
class TestOpenVikingTurnConversion:
def test_extract_current_turn_anchors_on_latest_matching_user_and_assistant(self):
messages = [
{"role": "user", "content": "Please inspect the repository for assemble hooks."},
{"role": "assistant", "content": "Earlier answer."},
{"role": "user", "content": "Please inspect the repository for assemble hooks."},
{
"role": "assistant",
"content": "I will search the codebase.",
"tool_calls": [
{
"id": "call_rg_1",
"type": "function",
"function": {
"name": "shell_command",
"arguments": json.dumps({"command": "rg assemble"}),
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_rg_1",
"name": "shell_command",
"content": "agent/context_engine.py: no preassemble hook",
},
{"role": "assistant", "content": "The current main does not expose assemble."},
]
turn = OpenVikingMemoryProvider._extract_current_turn_messages(
messages,
"Please inspect the repository for assemble hooks.",
"The current main does not expose assemble.",
)
assert turn == messages[2:]
def test_messages_to_openviking_batch_coalesces_tool_results(self):
turn = [
{"role": "user", "content": "Please inspect the repository for assemble hooks."},
{
"role": "assistant",
"content": "I will search the codebase.",
"tool_calls": [
{
"id": "call_rg_1",
"type": "function",
"function": {
"name": "shell_command",
"arguments": json.dumps({"command": "rg assemble"}),
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_rg_1",
"name": "shell_command",
"content": "agent/context_engine.py: no preassemble hook",
},
{"role": "assistant", "content": "The current main does not expose assemble."},
]
batch = OpenVikingMemoryProvider._messages_to_openviking_batch(turn)
assert [message["role"] for message in batch] == ["user", "assistant", "user", "assistant"]
assert batch[0]["parts"] == [
{"type": "text", "text": "Please inspect the repository for assemble hooks."}
]
assert batch[1]["parts"] == [
{"type": "text", "text": "I will search the codebase."}
]
assert batch[2]["parts"] == [
{
"type": "tool",
"tool_id": "call_rg_1",
"tool_name": "shell_command",
"tool_input": {"command": "rg assemble"},
"tool_output": "agent/context_engine.py: no preassemble hook",
"tool_status": "completed",
}
]
assert batch[3]["parts"] == [
{"type": "text", "text": "The current main does not expose assemble."}
]
def test_messages_to_openviking_batch_marks_json_tool_error_results(self):
turn = [
{"role": "user", "content": "Check the file."},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_read_1",
"type": "function",
"function": {
"name": "read_file",
"arguments": json.dumps({"path": "missing.md"}),
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_read_1",
"name": "read_file",
"content": json.dumps({"error": "File not found", "exit_code": 1}),
},
]
batch = OpenVikingMemoryProvider._messages_to_openviking_batch(turn)
assert batch[1]["parts"] == [
{
"type": "tool",
"tool_id": "call_read_1",
"tool_name": "read_file",
"tool_input": {"path": "missing.md"},
"tool_output": json.dumps({"error": "File not found", "exit_code": 1}),
"tool_status": "error",
}
]
def test_messages_to_openviking_batch_keeps_pending_tool_call_without_result(self):
turn = [
{"role": "user", "content": "Start a long running check."},
{
"role": "assistant",
"content": "Starting it now.",
"tool_calls": [
{
"id": "call_long_1",
"type": "function",
"function": {
"name": "long_check",
"arguments": json.dumps({"target": "repo"}),
},
}
],
},
]
batch = OpenVikingMemoryProvider._messages_to_openviking_batch(turn)
assert batch[1]["parts"] == [
{"type": "text", "text": "Starting it now."},
{
"type": "tool",
"tool_id": "call_long_1",
"tool_name": "long_check",
"tool_input": {"target": "repo"},
"tool_status": "pending",
},
]
def test_messages_to_openviking_batch_coalesces_adjacent_tool_results(self):
turn = [
{"role": "user", "content": "Run both tools."},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_a",
"type": "function",
"function": {
"name": "first_tool",
"arguments": json.dumps({"x": 1}),
},
},
{
"id": "call_b",
"type": "function",
"function": {
"name": "second_tool",
"arguments": json.dumps({"y": 2}),
},
},
],
},
{"role": "tool", "tool_call_id": "call_a", "name": "first_tool", "content": "a"},
{"role": "tool", "tool_call_id": "call_b", "name": "second_tool", "content": "b"},
{"role": "assistant", "content": "Done."},
]
batch = OpenVikingMemoryProvider._messages_to_openviking_batch(turn)
assert [message["role"] for message in batch] == ["user", "user", "assistant"]
assert batch[1]["parts"] == [
{
"type": "tool",
"tool_id": "call_a",
"tool_name": "first_tool",
"tool_input": {"x": 1},
"tool_output": "a",
"tool_status": "completed",
},
{
"type": "tool",
"tool_id": "call_b",
"tool_name": "second_tool",
"tool_input": {"y": 2},
"tool_output": "b",
"tool_status": "completed",
},
]
def test_messages_to_openviking_batch_skips_openviking_recall_tool_results(self):
for recall_tool_name in ("viking_search", "viking_read", "viking_browse"):
turn = [
{"role": "user", "content": "What did we decide about context assembly?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_recall_1",
"type": "function",
"function": {
"name": recall_tool_name,
"arguments": json.dumps({"query": "context assembly decision"}),
},
},
{
"id": "call_shell_1",
"type": "function",
"function": {
"name": "shell_command",
"arguments": json.dumps({"command": "rg preassemble"}),
},
},
],
},
{
"role": "tool",
"tool_call_id": "call_recall_1",
"name": recall_tool_name,
"content": json.dumps({
"results": [
{
"uri": "viking://user/hermes/memories/context",
"abstract": "Old OpenViking memory content",
}
]
}),
},
{
"role": "tool",
"tool_call_id": "call_shell_1",
"name": "shell_command",
"content": "plugins/memory/openviking/__init__.py",
},
{"role": "assistant", "content": "We decided to keep sync_turn scoped to ingestion."},
]
batch = OpenVikingMemoryProvider._messages_to_openviking_batch(turn)
assert [message["role"] for message in batch] == ["user", "user", "assistant"]
assert batch[1]["parts"] == [
{
"type": "tool",
"tool_id": "call_shell_1",
"tool_name": "shell_command",
"tool_input": {"command": "rg preassemble"},
"tool_output": "plugins/memory/openviking/__init__.py",
"tool_status": "completed",
}
]
batch_text = json.dumps(batch)
assert recall_tool_name not in batch_text
assert "Old OpenViking memory content" not in batch_text
class TestOpenVikingRead:
def test_overview_read_normalizes_uri_and_unwraps_result(self):
provider = OpenVikingMemoryProvider()

View file

@ -1975,7 +1975,10 @@ def test_on_session_switch_commits_old_session_and_rotates_id():
provider.on_session_switch("new-sid", parent_session_id="old-sid")
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
@ -1998,7 +2001,10 @@ def test_on_session_switch_commits_pending_tokens_without_turn_count():
provider.on_session_switch("new-sid")
provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid")
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
@ -2051,7 +2057,10 @@ def test_on_session_switch_waits_for_inflight_sync_thread():
provider.on_session_switch("new-sid")
assert join_calls, "expected on_session_switch to join the in-flight sync thread"
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
def test_on_session_switch_noop_on_empty_new_id():
@ -2206,7 +2215,10 @@ def test_on_session_end_marks_session_clean_after_successful_commit():
provider.on_session_end([])
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
assert provider._turn_count == 0
@ -2228,7 +2240,10 @@ def test_on_session_end_commits_pending_tokens_without_turn_count():
provider.on_session_end([])
provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid")
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
def test_end_then_switch_does_not_double_commit():
@ -2241,7 +2256,10 @@ def test_end_then_switch_does_not_double_commit():
provider.on_session_switch("new-sid", parent_session_id="old-sid")
# Exactly one commit call, on the OLD session, fired by on_session_end.
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
@ -2253,7 +2271,10 @@ def test_end_then_switch_with_pending_tokens_does_not_double_commit():
provider.on_session_end([])
provider.on_session_switch("new-sid", parent_session_id="old-sid")
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
@ -2400,7 +2421,10 @@ def test_on_session_switch_does_not_block_caller_on_slow_drain():
# Let the finalizer finish so it doesn't leak past the test.
release_drain.set()
assert provider._drain_finalizers(timeout=5.0)
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
def test_on_session_switch_defers_old_commit_to_finalizer_thread():
@ -2415,7 +2439,7 @@ def test_on_session_switch_defers_old_commit_to_finalizer_thread():
committed = threading.Event()
drain_timeouts = []
def fake_post(path):
def fake_post(path, payload=None):
committed.set()
return {}
@ -2433,7 +2457,10 @@ def test_on_session_switch_defers_old_commit_to_finalizer_thread():
assert provider._turn_count == 0
# The old-session commit lands on the finalizer thread, not inline.
assert committed.wait(timeout=5.0), "old session was not finalized off-thread"
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
provider._client.post.assert_called_once_with(
"/api/v1/sessions/old-sid/commit",
{"keep_recent_count": 0},
)
# The finalizer drains with the deferred (longer) budget, not inline 10s.
assert drain_timeouts == [_DEFERRED_COMMIT_TIMEOUT]