fix(memory): harden OpenViking local path uploads

This commit is contained in:
Hao Zhe 2026-05-04 18:11:08 +08:00 committed by Teknium
parent 187951ec6b
commit 2b6345cee3
2 changed files with 149 additions and 14 deletions

View file

@ -35,6 +35,8 @@ import uuid
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from urllib.request import url2pathname
from agent.memory_provider import MemoryProvider from agent.memory_provider import MemoryProvider
from tools.registry import tool_error from tools.registry import tool_error
@ -43,6 +45,7 @@ logger = logging.getLogger(__name__)
_DEFAULT_ENDPOINT = "http://127.0.0.1:1933" _DEFAULT_ENDPOINT = "http://127.0.0.1:1933"
_TIMEOUT = 30.0 _TIMEOUT = 30.0
_REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -333,6 +336,40 @@ def _zip_directory(dir_path: Path) -> Path:
return zip_path return zip_path
def _is_windows_absolute_path(value: str) -> bool:
return (
len(value) >= 3
and value[0].isalpha()
and value[1] == ":"
and value[2] in ("/", "\\")
)
def _is_remote_resource_source(value: str) -> bool:
    """Return True when *value* starts with a known remote-source prefix.

    The candidates come from ``_REMOTE_RESOURCE_PREFIXES``; the comparison
    is a plain, case-sensitive ``str.startswith`` check.
    """
    return any(value.startswith(prefix) for prefix in _REMOTE_RESOURCE_PREFIXES)
def _is_local_path_reference(value: str) -> bool:
    r"""Heuristically decide whether *value* is written like a local path.

    Returns False for empty strings, strings containing a line break, and
    anything recognised as a remote resource source. Otherwise returns True
    for Windows drive-rooted paths, for strings starting with a common path
    prefix (``/``, ``./``, ``../``, ``~/`` and their backslash forms), and
    for any string containing ``/`` or ``\``.

    This only inspects the text; it never touches the filesystem.
    """
    if not value:
        return False
    # Multi-line text is treated as content, not as a path.
    if any(line_break in value for line_break in ("\n", "\r")):
        return False
    if _is_remote_resource_source(value):
        return False
    if _is_windows_absolute_path(value):
        return True

    path_prefixes = ("/", "./", "../", "~/", ".\\", "..\\", "~\\")
    if value.startswith(path_prefixes):
        return True
    # Any remaining string with a path separator still counts as a path.
    return "/" in value or "\\" in value
def _path_from_file_uri(uri: str) -> Path | str:
parsed = urlparse(uri)
if parsed.netloc not in ("", "localhost"):
return f"Unsupported non-local file URI: {uri}"
return Path(url2pathname(parsed.path)).expanduser()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# MemoryProvider implementation # MemoryProvider implementation
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -837,23 +874,39 @@ class OpenVikingMemoryProvider(MemoryProvider):
if key in args and args[key] not in (None, ""): if key in args and args[key] not in (None, ""):
payload[key] = args[key] payload[key] = args[key]
source_path = Path(url).expanduser() parsed_url = urlparse(url)
cleanup_path: Optional[Path] = None if _is_remote_resource_source(url):
if source_path.exists(): source_path = None
if source_path.is_dir(): elif parsed_url.scheme == "file":
payload["source_name"] = source_path.name source_path = _path_from_file_uri(url)
cleanup_path = _zip_directory(source_path) if isinstance(source_path, str):
upload_path = cleanup_path return tool_error(source_path)
elif source_path.is_file(): elif parsed_url.scheme and not _is_windows_absolute_path(url):
payload["source_name"] = source_path.name source_path = None
upload_path = source_path
else:
return tool_error(f"Unsupported local resource path: {url}")
payload["temp_file_id"] = self._client.upload_temp_file(upload_path)
else: else:
payload["path"] = url source_path = Path(url).expanduser()
cleanup_path: Optional[Path] = None
try: try:
if source_path is not None:
if source_path.exists():
if source_path.is_dir():
payload["source_name"] = source_path.name
cleanup_path = _zip_directory(source_path)
upload_path = cleanup_path
elif source_path.is_file():
payload["source_name"] = source_path.name
upload_path = source_path
else:
return tool_error(f"Unsupported local resource path: {url}")
payload["temp_file_id"] = self._client.upload_temp_file(upload_path)
elif _is_local_path_reference(url):
return tool_error(f"Local resource path does not exist: {url}")
else:
payload["path"] = url
else:
payload["path"] = url
resp = self._client.post("/api/v1/resources", payload) resp = self._client.post("/api/v1/resources", payload)
result = resp.get("result", {}) result = resp.get("result", {})
finally: finally:

View file

@ -93,6 +93,32 @@ def test_tool_add_resource_uploads_existing_local_file(tmp_path):
assert result["root_uri"] == "viking://resources/sample" assert result["root_uri"] == "viking://resources/sample"
def test_tool_add_resource_uploads_file_uri(tmp_path):
    """A ``file://`` URI for an existing file is uploaded as a temp file."""
    local_file = tmp_path / "sample.md"
    local_file.write_text("# Local resource\n", encoding="utf-8")

    provider = OpenVikingMemoryProvider()
    client = MagicMock()
    provider._client = client
    client.upload_temp_file.return_value = "upload_sample.md"
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/sample"},
    }

    tool_args = {"url": local_file.as_uri(), "reason": "file uri test"}
    result = json.loads(provider._tool_add_resource(tool_args))

    # The decoded path is uploaded, and the request carries the upload id
    # rather than a "path" entry.
    expected_payload = {
        "reason": "file uri test",
        "source_name": "sample.md",
        "temp_file_id": "upload_sample.md",
    }
    client.upload_temp_file.assert_called_once_with(local_file)
    client.post.assert_called_once_with("/api/v1/resources", expected_payload)
    assert result["status"] == "added"
    assert result["root_uri"] == "viking://resources/sample"
def test_tool_add_resource_uploads_existing_local_directory_and_cleans_zip(tmp_path): def test_tool_add_resource_uploads_existing_local_directory_and_cleans_zip(tmp_path):
docs = tmp_path / "docs" docs = tmp_path / "docs"
docs.mkdir() docs.mkdir()
@ -149,6 +175,40 @@ def test_tool_add_resource_cleans_local_directory_zip_when_add_fails(tmp_path):
assert not uploaded_paths[0].exists() assert not uploaded_paths[0].exists()
def test_tool_add_resource_cleans_local_directory_zip_when_upload_fails(tmp_path):
    """The path handed to a failing upload no longer exists afterwards."""
    source_dir = tmp_path / "docs"
    source_dir.mkdir()
    (source_dir / "guide.md").write_text("# Guide\n", encoding="utf-8")

    seen_uploads = []

    def record_then_fail(path):
        # Remember what was handed to the uploader before simulating failure.
        seen_uploads.append(path)
        raise RuntimeError("upload failed")

    provider = OpenVikingMemoryProvider()
    client = MagicMock()
    provider._client = client
    client.upload_temp_file.side_effect = record_then_fail

    with pytest.raises(RuntimeError, match="upload failed"):
        provider._tool_add_resource({"url": str(source_dir)})

    assert seen_uploads
    assert not seen_uploads[0].exists()
    client.post.assert_not_called()
def test_tool_add_resource_rejects_missing_local_path(tmp_path):
    """A non-existent local path yields a tool error and no client calls."""
    absent_file = tmp_path / "missing.md"

    provider = OpenVikingMemoryProvider()
    client = MagicMock()
    provider._client = client

    raw_response = provider._tool_add_resource({"url": str(absent_file)})
    result = json.loads(raw_response)

    expected_error = f"Local resource path does not exist: {absent_file}"
    assert result["error"] == expected_error
    client.upload_temp_file.assert_not_called()
    client.post.assert_not_called()
def test_tool_add_resource_sends_remote_url_as_path(): def test_tool_add_resource_sends_remote_url_as_path():
provider = OpenVikingMemoryProvider() provider = OpenVikingMemoryProvider()
provider._client = MagicMock() provider._client = MagicMock()
@ -165,6 +225,28 @@ def test_tool_add_resource_sends_remote_url_as_path():
}) })
@pytest.mark.parametrize(
    "url",
    [
        "git@github.com:org/repo.git",
        "git@ssh.dev.azure.com:v3/org/project/repo",
        "ssh://git@github.com/org/repo.git",
        "git://github.com/org/repo.git",
    ],
)
def test_tool_add_resource_sends_git_remote_sources_as_path(url):
    """Git-style remote sources are sent unchanged as ``path``, never uploaded."""
    provider = OpenVikingMemoryProvider()
    client = MagicMock()
    provider._client = client
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/repo"},
    }

    provider._tool_add_resource({"url": url})

    client.upload_temp_file.assert_not_called()
    expected_payload = {"path": url}
    client.post.assert_called_once_with("/api/v1/resources", expected_payload)
def test_viking_client_upload_temp_file_uses_multipart_identity_headers(tmp_path, monkeypatch): def test_viking_client_upload_temp_file_uses_multipart_identity_headers(tmp_path, monkeypatch):
sample = tmp_path / "sample.md" sample = tmp_path / "sample.md"
sample.write_text("# Local resource\n", encoding="utf-8") sample.write_text("# Local resource\n", encoding="utf-8")