fix(openviking): sanitize skill memory input

This commit is contained in:
Hao Zhe 2026-05-26 22:56:07 +08:00 committed by Teknium
parent e236bb87eb
commit e3adbb5ae9
3 changed files with 318 additions and 0 deletions

View file

@ -66,6 +66,61 @@ _MEMORY_WRITE_TARGET_SUBDIR_MAP = {
"memory": "patterns",
}
# Literal fragments of the scaffolding that wraps slash-skill invocations.
# They are matched as plain substrings, so they must stay character-for-character
# identical to the text the scaffolding producer emits.
_SKILL_INVOCATION_PREFIX = "[IMPORTANT: The user has invoked the "
_SINGLE_SKILL_MARKER = "The full skill content is loaded below.]"
_SINGLE_SKILL_INSTRUCTION = (
    "The user has provided the following instruction alongside the skill invocation: "
)
_BUNDLE_MARKER = " skill bundle,"
_BUNDLE_USER_INSTRUCTION = "\nUser instruction: "
_BUNDLE_FIRST_SKILL_BLOCK = "\n\n[Loaded as part of the "
_RUNTIME_NOTE = "\n\n[Runtime note:"


def _derive_openviking_user_text(content: Any) -> str:
    """Strip Hermes slash-skill scaffolding before sending content to OpenViking.

    Args:
        content: The raw user-turn content. Anything that is not a ``str``
            is treated as having no usable text.

    Returns:
        * ``""`` for non-string input;
        * ``content`` unchanged when it does not start with the skill
          invocation prefix;
        * otherwise only the instruction the user typed alongside the skill
          or bundle invocation, or ``""`` when there is none or the
          scaffolding is not recognised.
    """
    if not isinstance(content, str):
        return ""
    if not content.startswith(_SKILL_INVOCATION_PREFIX):
        return content

    # Decide single-skill vs. bundle from the header line only. Skill bodies
    # are arbitrary text: a single-skill body that merely mentions
    # " skill bundle," must not be routed to the bundle extractor, which would
    # find no "User instruction:" line and silently drop the user's text.
    header = content.partition("\n")[0]
    if _BUNDLE_MARKER in header:
        return _extract_bundle_user_instruction(content)
    if _SINGLE_SKILL_MARKER in header:
        return _extract_single_skill_user_instruction(content)

    # Header did not match either shape on its first line (e.g. it was
    # wrapped); fall back to probing the whole message so such input is
    # handled exactly as before.
    if _BUNDLE_MARKER in content:
        return _extract_bundle_user_instruction(content)
    if _SINGLE_SKILL_MARKER in content:
        return _extract_single_skill_user_instruction(content)
    return ""


def _extract_single_skill_user_instruction(message: str) -> str:
    """Return the user's instruction from a single-skill invocation message.

    Returns ``""`` when the message contains no instruction sentence. Any
    trailing runtime note is removed and the result is whitespace-stripped.
    """
    # Single-skill format appends the user instruction after the skill body, so
    # the last occurrence is the user-provided one; the body may quote this text.
    marker_idx = message.rfind(_SINGLE_SKILL_INSTRUCTION)
    if marker_idx < 0:
        return ""
    instruction = message[marker_idx + len(_SINGLE_SKILL_INSTRUCTION):]
    # Cut at the runtime-note block, if one follows the instruction.
    runtime_idx = instruction.find(_RUNTIME_NOTE)
    if runtime_idx >= 0:
        instruction = instruction[:runtime_idx]
    return instruction.strip()


def _extract_bundle_user_instruction(message: str) -> str:
    """Return the user's instruction from a skill-bundle invocation message.

    Returns ``""`` when the message has no ``User instruction:`` line. The
    text is cut at the first loaded-skill block and whitespace-stripped.
    """
    # Bundle format puts the user instruction before the loaded skills, so the
    # first occurrence is the user-provided one.
    marker_idx = message.find(_BUNDLE_USER_INSTRUCTION)
    if marker_idx < 0:
        return ""
    instruction = message[marker_idx + len(_BUNDLE_USER_INSTRUCTION):]
    first_skill_idx = instruction.find(_BUNDLE_FIRST_SKILL_BLOCK)
    if first_skill_idx >= 0:
        instruction = instruction[:first_skill_idx]
    return instruction.strip()
# ---------------------------------------------------------------------------
# Process-level atexit safety net — ensures pending sessions are committed
@ -531,6 +586,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
"""Fire a background search to pre-load relevant context."""
query = _derive_openviking_user_text(query)
if not self._client or not query:
return
@ -570,6 +626,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client:
return
user_content = _derive_openviking_user_text(user_content)
if not user_content:
return
self._turn_count += 1
def _sync():

View file

@ -2,9 +2,26 @@
import json
import plugins.memory.openviking as openviking_plugin
from plugins.memory.openviking import OpenVikingMemoryProvider
def _write_skill(skills_dir, name, body="Do the thing."):
    """Create a minimal skill fixture on disk and return its directory.

    Writes ``<skills_dir>/<name>/SKILL.md`` containing a front-matter block
    (``name`` and ``description``), a ``# <name>`` heading and ``body``.

    Args:
        skills_dir: Path-like root under which the skill directory is created.
        name: Skill name; used as the directory name and inside the file.
        body: Text placed after the heading.

    Returns:
        The path of the created skill directory.
    """
    skill_dir = skills_dir / name
    skill_dir.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the fixture does not depend on the platform's
    # locale default encoding when ``name`` or ``body`` is non-ASCII.
    (skill_dir / "SKILL.md").write_text(
        f"---\nname: {name}\ndescription: Description for {name}\n---\n\n# {name}\n\n{body}\n",
        encoding="utf-8",
    )
    return skill_dir
def _write_bundle(bundles_dir, slug, skills):
    """Write a ``<slug>.yaml`` bundle fixture listing ``skills``.

    Args:
        bundles_dir: Path-like directory for bundle files; created if missing.
        slug: Bundle name; used for both the file name and the ``name:`` field.
        skills: Iterable of skill names, emitted one per line under ``skills:``.
    """
    bundles_dir.mkdir(parents=True, exist_ok=True)
    lines = [f"name: {slug}", "skills:"]
    lines.extend(f" - {skill}" for skill in skills)
    # Explicit UTF-8 so the fixture does not depend on the platform's
    # locale default encoding when a name is non-ASCII.
    (bundles_dir / f"{slug}.yaml").write_text("\n".join(lines) + "\n", encoding="utf-8")
class FakeVikingClient:
def __init__(self, responses):
self.responses = responses
@ -17,6 +34,24 @@ class FakeVikingClient:
raise response
return response
def post(self, path, payload=None, **kwargs):
    """Record the call and return (or raise) the canned response for it.

    The canned response is looked up in ``self.responses`` under the key
    ``(path, tuple(sorted(payload.items())))``; a missing key yields ``{}``.
    A canned value that is an ``Exception`` instance is raised instead of
    returned. Extra keyword arguments are accepted and ignored.
    """
    body = payload or {}
    self.calls.append((path, body))
    lookup_key = (path, tuple(sorted(body.items())))
    canned = self.responses.get(lookup_key, {})
    if not isinstance(canned, Exception):
        return canned
    raise canned
class RecordingVikingClient:
    """Stand-in client that records every ``post`` call and returns an empty result."""

    # Deliberately class-level: every instance appends to the same list, so
    # callers that never hold the instance can reset and inspect it through
    # the class (``RecordingVikingClient.calls``).
    calls = []

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""

    def post(self, path, payload=None, **kwargs):
        """Log ``(path, payload or {})`` and return an empty memories/resources result."""
        recorded = (path, payload or {})
        self.calls.append(recorded)
        return {"result": {"memories": [], "resources": []}}
class TestOpenVikingSummaryUriNormalization:
def test_normalize_summary_uri_maps_pseudo_files_to_parent_directory(self):
@ -26,6 +61,196 @@ class TestOpenVikingSummaryUriNormalization:
assert OpenVikingMemoryProvider._normalize_summary_uri("viking://user/hermes/memories/profile.md") == "viking://user/hermes/memories/profile.md"
class TestOpenVikingSkillQuerySafety:
    """Slash-skill scaffolding must never reach OpenViking; only the user's own text may."""

    @staticmethod
    def _recording_provider(monkeypatch, *, configured=True, session_id=None):
        """Build a provider whose ``_VikingClient`` is ``RecordingVikingClient``.

        Resets the shared call log first. ``_client`` is set to a truthy
        placeholder so the provider's ``if not self._client`` guards pass.
        With ``configured=True`` the connection attributes are filled in
        (presumably read when the provider builds a client on its worker
        thread — not visible from this test module). ``session_id`` is
        assigned only when given.
        """
        RecordingVikingClient.calls = []
        monkeypatch.setattr(openviking_plugin, "_VikingClient", RecordingVikingClient)
        provider = OpenVikingMemoryProvider()
        provider._client = object()
        if configured:
            provider._endpoint = "http://openviking.test"
            provider._api_key = ""
            provider._account = "default"
            provider._user = "default"
            provider._agent = "hermes"
        if session_id is not None:
            provider._session_id = session_id
        return provider

    def test_derive_returns_empty_string_for_non_string_input(self):
        """Non-string content yields an empty string."""
        for value in (None, 123, [{"text": "hi"}]):
            assert openviking_plugin._derive_openviking_user_text(value) == ""

    def test_derive_passes_through_non_skill_content(self):
        """Ordinary messages are returned unchanged."""
        assert (
            openviking_plugin._derive_openviking_user_text("regular user message")
            == "regular user message"
        )

    def test_derive_returns_empty_for_skill_scaffolding_with_no_instruction(self):
        """A skill invocation without a user instruction yields an empty string."""
        scaffolding_only = (
            '[IMPORTANT: The user has invoked the "example" skill, indicating they want '
            "you to follow its instructions. The full skill content is loaded below.]\n\n"
            "# Example\n\n"
            "Skill body only, no instruction."
        )
        assert openviking_plugin._derive_openviking_user_text(scaffolding_only) == ""

    def test_skill_markers_match_hermes_scaffolding(self, tmp_path, monkeypatch):
        """The plugin's marker constants appear in messages built by the real builders."""
        import agent.skill_bundles as skill_bundles
        import agent.skill_commands as skill_commands
        import tools.skills_tool as skills_tool

        # Lay out one skill and one bundle on disk, then point the loaders at
        # them and clear their module-level caches.
        skills_dir = tmp_path / "skills"
        bundles_dir = tmp_path / "skill-bundles"
        _write_skill(skills_dir, "example")
        _write_bundle(bundles_dir, "demo", ["example"])
        monkeypatch.setattr(skills_tool, "SKILLS_DIR", skills_dir)
        monkeypatch.setenv("HERMES_BUNDLES_DIR", str(bundles_dir))
        monkeypatch.setattr(skill_commands, "_skill_commands", {})
        monkeypatch.setattr(skill_commands, "_skill_commands_platform", None)
        monkeypatch.setattr(skill_bundles, "_bundles_cache", {})
        monkeypatch.setattr(skill_bundles, "_bundles_cache_mtime", None)

        skill_commands.scan_skill_commands()
        single = skill_commands.build_skill_invocation_message(
            "/example",
            user_instruction="hello",
            runtime_note="runtime detail",
        )
        assert single is not None
        for marker in (
            openviking_plugin._SKILL_INVOCATION_PREFIX,
            openviking_plugin._SINGLE_SKILL_MARKER,
            openviking_plugin._SINGLE_SKILL_INSTRUCTION,
            openviking_plugin._RUNTIME_NOTE,
        ):
            assert marker in single

        skill_bundles.scan_bundles()
        bundle_result = skill_bundles.build_bundle_invocation_message(
            "/demo",
            user_instruction="hello",
        )
        assert bundle_result is not None
        bundle_text, _, _ = bundle_result
        for marker in (
            openviking_plugin._BUNDLE_MARKER,
            openviking_plugin._BUNDLE_USER_INSTRUCTION,
            openviking_plugin._BUNDLE_FIRST_SKILL_BLOCK,
        ):
            assert marker in bundle_text

    def test_queue_prefetch_searches_only_slash_skill_user_instruction(self, monkeypatch):
        """Prefetch for a single-skill invocation searches only the user's instruction."""
        provider = self._recording_provider(monkeypatch)
        invocation = (
            '[IMPORTANT: The user has invoked the "skill-creator" skill, indicating they want '
            "you to follow its instructions. The full skill content is loaded below.]\n\n"
            "# Skill Creator\n\n"
            "Large skill body that must not be searched or embedded.\n\n"
            "The user has provided the following instruction alongside the skill invocation: "
            "make a skill for release triage"
        )

        provider.queue_prefetch(invocation)
        provider._prefetch_thread.join(timeout=5.0)

        expected_call = (
            "/api/v1/search/find",
            {"query": "make a skill for release triage", "top_k": 5},
        )
        assert RecordingVikingClient.calls == [expected_call]

    def test_queue_prefetch_searches_only_skill_bundle_user_instruction(self, monkeypatch):
        """Prefetch for a bundle invocation searches only the user's instruction."""
        provider = self._recording_provider(monkeypatch)
        invocation = (
            '[IMPORTANT: The user has invoked the "backend-dev" skill bundle, '
            "loading 2 skills together. Treat every skill below as active guidance for this turn.]\n\n"
            "Bundle: backend-dev\n"
            "Skills loaded: test-driven-development, code-review\n\n"
            "User instruction: fix the failing retrieval test\n\n"
            '[Loaded as part of the "backend-dev" skill bundle.]\n\n'
            "Large bundled skill body that must not be searched or embedded."
        )

        provider.queue_prefetch(invocation)
        provider._prefetch_thread.join(timeout=5.0)

        expected_call = (
            "/api/v1/search/find",
            {"query": "fix the failing retrieval test", "top_k": 5},
        )
        assert RecordingVikingClient.calls == [expected_call]

    def test_queue_prefetch_skips_slash_skill_without_user_instruction(self, monkeypatch):
        """No prefetch thread is started when the invocation has no user instruction."""
        provider = self._recording_provider(monkeypatch, configured=False)
        invocation = (
            '[IMPORTANT: The user has invoked the "skill-creator" skill, indicating they want '
            "you to follow its instructions. The full skill content is loaded below.]\n\n"
            "# Skill Creator\n\n"
            "Large skill body that must not be searched or embedded."
        )

        provider.queue_prefetch(invocation)

        assert provider._prefetch_thread is None
        assert RecordingVikingClient.calls == []

    def test_sync_turn_stores_only_slash_skill_user_instruction(self, monkeypatch):
        """sync_turn posts the user's instruction, not the skill body, as user content."""
        provider = self._recording_provider(monkeypatch, session_id="session-1")
        invocation = (
            '[IMPORTANT: The user has invoked the "skill-creator" skill, indicating they want '
            "you to follow its instructions. The full skill content is loaded below.]\n\n"
            "# Skill Creator\n\n"
            "Large skill body that must not be stored as user content.\n\n"
            "The user has provided the following instruction alongside the skill invocation: "
            "make a skill for release triage"
        )

        provider.sync_turn(invocation, "Done.")
        provider._sync_thread.join(timeout=5.0)

        messages_path = "/api/v1/sessions/session-1/messages"
        assert RecordingVikingClient.calls == [
            (messages_path, {"role": "user", "content": "make a skill for release triage"}),
            (messages_path, {"role": "assistant", "content": "Done."}),
        ]

    def test_sync_turn_skips_slash_skill_without_user_instruction(self, monkeypatch):
        """No sync thread is started when the invocation has no user instruction."""
        provider = self._recording_provider(monkeypatch, configured=False)
        invocation = (
            '[IMPORTANT: The user has invoked the "skill-creator" skill, indicating they want '
            "you to follow its instructions. The full skill content is loaded below.]\n\n"
            "# Skill Creator\n\n"
            "Large skill body that must not be stored as user content."
        )

        provider.sync_turn(invocation, "Done.")

        assert provider._sync_thread is None
        assert RecordingVikingClient.calls == []
class TestOpenVikingRead:
def test_overview_read_normalizes_uri_and_unwraps_result(self):
provider = OpenVikingMemoryProvider()

View file

@ -130,6 +130,39 @@ class TestSyncExternalMemoryForTurn:
messages=messages,
)
def test_completed_skill_turn_keeps_original_message_for_memory_manager(self):
    """The memory manager receives the raw, unshaped slash-skill message.

    Query shaping is a provider-specific concern and belongs inside the
    provider. The MemoryManager fan-out contract stays raw so that
    non-OpenViking providers can decide for themselves whether
    slash-skill-expanded content is useful.
    """
    raw_turn = (
        '[IMPORTANT: The user has invoked the "skill-creator" skill, indicating they want '
        "you to follow its instructions. The full skill content is loaded below.]\n\n"
        "# Skill Creator\n\n"
        "Large skill body that must not be searched or embedded.\n\n"
        "The user has provided the following instruction alongside the skill invocation: "
        "make a skill for release triage"
    )
    agent = _bare_agent()

    agent._sync_external_memory_for_turn(
        original_user_message=raw_turn,
        final_response="Done.",
        interrupted=False,
    )

    # Both fan-out calls must carry the untouched message and the session id.
    manager = agent._memory_manager
    manager.sync_all.assert_called_once_with(
        raw_turn,
        "Done.",
        session_id="test_session_001",
    )
    manager.queue_prefetch_all.assert_called_once_with(
        raw_turn,
        session_id="test_session_001",
    )
# --- Edge cases (pre-existing behaviour preserved) ------------------
def test_no_final_response_skips(self):