fix(openviking): sync structured turns with tool parts

This commit is contained in:
Eurekaxun 2026-06-02 14:33:12 +08:00 committed by Hao Zhe
parent 3485bc7225
commit c7b7f92ec1
3 changed files with 639 additions and 21 deletions

View file

@ -70,6 +70,8 @@ _TIMEOUT = 30.0
_SESSION_DRAIN_TIMEOUT = 10.0
_DEFERRED_COMMIT_TIMEOUT = (_TIMEOUT * 2) + 5.0
_REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://")
_SYNC_TRACE_ENV = "HERMES_OPENVIKING_SYNC_TRACE"
_OPENVIKING_RECALL_TOOL_NAMES = {"viking_search", "viking_read", "viking_browse"}
# Maps the viking_remember `category` enum to a viking:// subdirectory.
# Keep in sync with REMEMBER_SCHEMA.parameters.properties.category.enum.
@ -156,6 +158,18 @@ def _derive_openviking_user_text(content: Any) -> str:
return extract_user_instruction_from_skill_message(content) or ""
def _sync_trace_enabled() -> bool:
    """Report whether verbose sync tracing has been switched on.

    Reads the environment variable named by ``_SYNC_TRACE_ENV`` and treats
    ``1``, ``true``, ``yes`` and ``on`` (case-insensitive, surrounding
    whitespace ignored) as enabled. An unset variable counts as disabled.
    """
    raw_flag = os.environ.get(_SYNC_TRACE_ENV, "")
    normalized = raw_flag.strip().lower()
    return normalized in ("1", "true", "yes", "on")
def _preview(value: Any, limit: int = 160) -> str:
text = "" if value is None else str(value)
text = text.replace("\n", "\\n")
if len(text) > limit:
return text[:limit] + "..."
return text
# ---------------------------------------------------------------------------
# Process-level atexit safety net — ensures pending sessions are committed
# even if shutdown_memory_provider is never called (e.g. gateway crash,
@ -2221,7 +2235,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
try:
self._client.post(f"/api/v1/sessions/{sid}/commit")
self._client.post(
f"/api/v1/sessions/{sid}/commit",
{"keep_recent_count": 0},
)
self._mark_session_committed(sid)
logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count)
return True
@ -2293,7 +2310,261 @@ class OpenVikingMemoryProvider(MemoryProvider):
with self._prefetch_lock:
self._prefetch_result = ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
@staticmethod
def _message_text(content: Any) -> str:
"""Extract text from OpenAI-style string/list content."""
if isinstance(content, str):
return content
if isinstance(content, list):
chunks = []
for block in content:
if isinstance(block, str):
chunks.append(block)
elif isinstance(block, dict):
if block.get("type") == "text" and isinstance(block.get("text"), str):
chunks.append(block["text"])
elif isinstance(block.get("content"), str):
chunks.append(block["content"])
return "\n".join(chunk for chunk in chunks if chunk)
if content is None:
return ""
return str(content)
@classmethod
def _message_matches_text(cls, message: Dict[str, Any], expected: Any) -> bool:
    """Tell whether *message*'s content equals *expected* as plain text.

    Both sides are flattened with ``_message_text`` and stripped before the
    comparison. An *expected* value with no text never matches.
    """
    wanted = cls._message_text(expected).strip()
    if wanted:
        return cls._message_text(message.get("content")).strip() == wanted
    return False
@classmethod
def _extract_current_turn_messages(
    cls,
    messages: Optional[List[Dict[str, Any]]],
    user_content: str,
    assistant_content: str,
) -> List[Dict[str, Any]]:
    """Cut the just-finished turn out of the full transcript.

    The turn ends at the last assistant message whose text equals
    *assistant_content*; failing that, at the last assistant message; failing
    that, at the final transcript entry. It starts at the closest preceding
    user message whose text equals *user_content*, or else at the closest
    preceding user message of any text.

    Returns the dict entries from start to end inclusive, or an empty list
    when the transcript is empty or no user message precedes the end.
    """
    if not messages:
        return []

    def scan_back(upper: int, role: str, expected: Any, exact: bool) -> Optional[int]:
        # Walk from *upper* down to 0 and return the first index holding a
        # dict message with the given role (and, when *exact*, equal text).
        for position in reversed(range(upper + 1)):
            candidate = messages[position]
            if not isinstance(candidate, dict) or candidate.get("role") != role:
                continue
            if not exact or cls._message_matches_text(candidate, expected):
                return position
        return None

    final_position = len(messages) - 1

    end_idx: Optional[int] = None
    if cls._message_text(assistant_content).strip():
        end_idx = scan_back(final_position, "assistant", assistant_content, True)
    if end_idx is None:
        end_idx = scan_back(final_position, "assistant", None, False)
    if end_idx is None:
        end_idx = final_position

    start_idx: Optional[int] = None
    if cls._message_text(user_content).strip():
        start_idx = scan_back(end_idx, "user", user_content, True)
    if start_idx is None:
        start_idx = scan_back(end_idx, "user", None, False)
    if start_idx is None:
        return []

    turn_slice = messages[start_idx : end_idx + 1]
    return [entry for entry in turn_slice if isinstance(entry, dict)]
@staticmethod
def _tool_call_id(tool_call: Dict[str, Any]) -> str:
return str(tool_call.get("id") or tool_call.get("tool_call_id") or "")
@staticmethod
def _tool_call_name(tool_call: Dict[str, Any]) -> str:
function = tool_call.get("function")
if isinstance(function, dict):
return str(function.get("name") or "")
return str(tool_call.get("name") or "")
@staticmethod
def _is_openviking_recall_tool_name(tool_name: Any) -> bool:
    """Tell whether *tool_name* is listed in ``_OPENVIKING_RECALL_TOOL_NAMES``.

    The name is stringified, stripped and lower-cased before the lookup;
    falsy values are treated as an empty name.
    """
    normalized = str(tool_name or "").strip().lower()
    return normalized in _OPENVIKING_RECALL_TOOL_NAMES
@staticmethod
def _tool_call_input(tool_call: Dict[str, Any]) -> Dict[str, Any]:
function = tool_call.get("function")
raw_args: Any = None
if isinstance(function, dict):
raw_args = function.get("arguments")
if raw_args is None:
raw_args = tool_call.get("args")
if raw_args is None:
return {}
if isinstance(raw_args, dict):
return raw_args
if isinstance(raw_args, str):
if not raw_args.strip():
return {}
try:
parsed = json.loads(raw_args)
except Exception:
return {"value": raw_args}
if isinstance(parsed, dict):
return parsed
return {"value": parsed}
return {"value": raw_args}
@classmethod
def _tool_result_status(cls, message: Dict[str, Any]) -> str:
    """Classify a tool-result message as ``"error"`` or ``"completed"``.

    An explicit ``status`` / ``tool_status`` field on the message decides
    when it holds a recognised word. Otherwise the message text is parsed as
    JSON; if it decodes to an object, any of these marks it as an error: a
    failure ``status``, ``success`` being ``False``, a truthy ``error``, or
    a non-zero integer ``exit_code``. Everything else counts as completed.
    """
    failure_words = {"error", "failed", "failure"}

    declared = str(message.get("status") or message.get("tool_status") or "").lower()
    if declared in failure_words:
        return "error"
    if declared in {"completed", "complete", "success", "succeeded"}:
        return "completed"

    body = cls._message_text(message.get("content")).strip()
    if not body:
        return "completed"
    try:
        decoded = json.loads(body)
    except Exception:
        # Non-JSON output carries no failure signal we can inspect.
        return "completed"
    if not isinstance(decoded, dict):
        return "completed"

    embedded_status = str(decoded.get("status") or "").lower()
    exit_code = decoded.get("exit_code")
    failed = (
        embedded_status in failure_words
        or decoded.get("success") is False
        or bool(decoded.get("error"))
        or (isinstance(exit_code, int) and exit_code != 0)
    )
    return "error" if failed else "completed"
@classmethod
def _messages_to_openviking_batch(
    cls,
    messages: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Translate canonical chat messages into OpenViking batch messages.

    Each output item is ``{"role": ..., "parts": [...]}``:

    * ``user`` / ``assistant`` messages yield a ``text`` part when they have
      text. Assistant tool calls with no tool result in *messages* are added
      as ``tool`` parts with status ``"pending"``.
    * Consecutive ``tool`` results are gathered into one ``user``-role
      message of ``tool`` parts, each carrying the input of the matching
      call, the output text and a status from ``_tool_result_status``.
    * OpenViking recall tools (calls and results) are left out, and messages
      of any other role are dropped.
    """
    dict_messages = [entry for entry in messages if isinstance(entry, dict)]

    def result_call_id(entry: Dict[str, Any]) -> str:
        return str(entry.get("tool_call_id") or entry.get("id") or "")

    # Pass 1: index every tool call by id and note which ids already have a
    # result and which belong to recall tools.
    calls_by_id: Dict[str, Dict[str, Any]] = {}
    answered_ids = set()
    recall_ids = set()
    for entry in dict_messages:
        raw_role = entry.get("role")
        if raw_role == "tool":
            call_id = result_call_id(entry)
            if call_id:
                answered_ids.add(call_id)
                if cls._is_openviking_recall_tool_name(entry.get("name")):
                    recall_ids.add(call_id)
        elif raw_role == "assistant":
            for call in entry.get("tool_calls") or []:
                if not isinstance(call, dict):
                    continue
                call_id = cls._tool_call_id(call)
                if not call_id:
                    continue
                call_name = cls._tool_call_name(call)
                calls_by_id[call_id] = {
                    "tool_name": call_name,
                    "tool_input": cls._tool_call_input(call),
                }
                if cls._is_openviking_recall_tool_name(call_name):
                    recall_ids.add(call_id)

    # Pass 2: emit payload messages in transcript order.
    batch: List[Dict[str, Any]] = []
    buffered_results: List[Dict[str, Any]] = []

    def flush_results() -> None:
        # Emit the buffered tool results as a single user-role message.
        nonlocal buffered_results
        if buffered_results:
            batch.append({"role": "user", "parts": buffered_results})
            buffered_results = []

    for entry in dict_messages:
        entry_role = str(entry.get("role") or "")

        if entry_role == "tool":
            call_id = result_call_id(entry)
            known_call = calls_by_id.get(call_id, {})
            call_name = str(entry.get("name") or known_call.get("tool_name") or "")
            is_recall = call_id in recall_ids or cls._is_openviking_recall_tool_name(call_name)
            if not is_recall:
                buffered_results.append({
                    "type": "tool",
                    "tool_id": call_id,
                    "tool_name": call_name,
                    "tool_input": known_call.get("tool_input", {}),
                    "tool_output": cls._message_text(entry.get("content")),
                    "tool_status": cls._tool_result_status(entry),
                })
            continue

        if entry_role not in ("user", "assistant"):
            # system / developer / unknown roles are not synced.
            continue

        flush_results()
        parts: List[Dict[str, Any]] = []
        body_text = cls._message_text(entry.get("content"))
        if body_text:
            parts.append({"type": "text", "text": body_text})

        if entry_role == "assistant":
            for call in entry.get("tool_calls") or []:
                if not isinstance(call, dict):
                    continue
                call_id = cls._tool_call_id(call)
                call_name = cls._tool_call_name(call)
                if (
                    call_id in recall_ids
                    or cls._is_openviking_recall_tool_name(call_name)
                    or call_id in answered_ids
                ):
                    # Recall tools are never synced; answered calls are
                    # reported through their tool-result part instead.
                    continue
                parts.append({
                    "type": "tool",
                    "tool_id": call_id,
                    "tool_name": call_name,
                    "tool_input": cls._tool_call_input(call),
                    "tool_status": "pending",
                })

        if parts:
            batch.append({"role": entry_role, "parts": parts})

    flush_results()
    return batch
def sync_turn(
self,
user_content: str,
assistant_content: str,
*,
session_id: str = "",
messages: Optional[List[Dict[str, Any]]] = None,
) -> None:
"""Record the conversation turn in OpenViking's session (non-blocking)."""
if not self._client:
return
@ -2302,6 +2573,37 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not user_content:
return
turn_messages = (
self._extract_current_turn_messages(messages, user_content, assistant_content)
if messages is not None
else []
)
if turn_messages:
turn_messages = [dict(message) for message in turn_messages]
for message in turn_messages:
if message.get("role") == "user":
message["content"] = user_content
break
batch_messages = self._messages_to_openviking_batch(turn_messages)
if _sync_trace_enabled():
logger.info(
"OpenViking sync_turn trace: session_arg=%r cached_session=%r "
"messages_param_supported=true messages_present=%s message_count=%s "
"turn_message_count=%d batch_message_count=%d user_len=%d assistant_len=%d "
"user_preview=%r assistant_preview=%r",
session_id,
self._session_id,
messages is not None,
len(messages) if messages is not None else None,
len(turn_messages),
len(batch_messages),
len(str(user_content or "")),
len(str(assistant_content or "")),
_preview(user_content),
_preview(assistant_content),
)
# Snapshot the sid and bump the turn counter atomically so a
# concurrent on_session_switch/on_session_end can't interleave its
# snapshot+reset between the read and the increment (lost turn) and so
@ -2313,24 +2615,39 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._turn_count += 1
def _sync():
try:
client = self._new_client()
def _post_turn(client: _VikingClient) -> None:
if batch_messages:
payload = {"messages": batch_messages}
if _sync_trace_enabled():
logger.info(
"OpenViking sync_turn trace: POST /api/v1/sessions/%s/messages/batch payload=%s",
sid,
json.dumps(payload, ensure_ascii=False),
)
try:
client.post(f"/api/v1/sessions/{sid}/messages/batch", payload)
return
except Exception as batch_error:
logger.warning(
"OpenViking structured sync failed; falling back to text sync: %s",
batch_error,
)
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
self._message_text(assistant_content)[:4000],
)
try:
client = self._new_client()
_post_turn(client)
except Exception as e:
logger.debug("OpenViking sync_turn failed, reconnecting: %s", e)
try:
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
_post_turn(client)
except Exception as retry_error:
logger.warning("OpenViking sync_turn failed: %s", retry_error)