diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index c1643f2ee8..4c2a4bf15f 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -35,6 +35,8 @@ import uuid import zipfile from pathlib import Path from typing import Any, Dict, List, Optional +from urllib.parse import urlparse +from urllib.request import url2pathname from agent.memory_provider import MemoryProvider from tools.registry import tool_error @@ -43,6 +45,7 @@ logger = logging.getLogger(__name__) _DEFAULT_ENDPOINT = "http://127.0.0.1:1933" _TIMEOUT = 30.0 +_REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://") # --------------------------------------------------------------------------- @@ -333,6 +336,40 @@ def _zip_directory(dir_path: Path) -> Path: return zip_path +def _is_windows_absolute_path(value: str) -> bool: + return ( + len(value) >= 3 + and value[0].isalpha() + and value[1] == ":" + and value[2] in ("/", "\\") + ) + + +def _is_remote_resource_source(value: str) -> bool: + return value.startswith(_REMOTE_RESOURCE_PREFIXES) + + +def _is_local_path_reference(value: str) -> bool: + if not value or "\n" in value or "\r" in value: + return False + if _is_remote_resource_source(value): + return False + if _is_windows_absolute_path(value): + return True + return ( + value.startswith(("/", "./", "../", "~/", ".\\", "..\\", "~\\")) + or "/" in value + or "\\" in value + ) + + +def _path_from_file_uri(uri: str) -> Path | str: + parsed = urlparse(uri) + if parsed.netloc not in ("", "localhost"): + return f"Unsupported non-local file URI: {uri}" + return Path(url2pathname(parsed.path)).expanduser() + + # --------------------------------------------------------------------------- # MemoryProvider implementation # --------------------------------------------------------------------------- @@ -837,23 +874,39 @@ class OpenVikingMemoryProvider(MemoryProvider): if key in args and args[key] not in (None, 
""): payload[key] = args[key] - source_path = Path(url).expanduser() - cleanup_path: Optional[Path] = None - if source_path.exists(): - if source_path.is_dir(): - payload["source_name"] = source_path.name - cleanup_path = _zip_directory(source_path) - upload_path = cleanup_path - elif source_path.is_file(): - payload["source_name"] = source_path.name - upload_path = source_path - else: - return tool_error(f"Unsupported local resource path: {url}") - payload["temp_file_id"] = self._client.upload_temp_file(upload_path) + parsed_url = urlparse(url) + if _is_remote_resource_source(url): + source_path = None + elif parsed_url.scheme == "file": + source_path = _path_from_file_uri(url) + if isinstance(source_path, str): + return tool_error(source_path) + elif parsed_url.scheme and not _is_windows_absolute_path(url): + source_path = None else: - payload["path"] = url + source_path = Path(url).expanduser() + cleanup_path: Optional[Path] = None try: + if source_path is not None: + if source_path.exists(): + if source_path.is_dir(): + payload["source_name"] = source_path.name + cleanup_path = _zip_directory(source_path) + upload_path = cleanup_path + elif source_path.is_file(): + payload["source_name"] = source_path.name + upload_path = source_path + else: + return tool_error(f"Unsupported local resource path: {url}") + payload["temp_file_id"] = self._client.upload_temp_file(upload_path) + elif _is_local_path_reference(url): + return tool_error(f"Local resource path does not exist: {url}") + else: + payload["path"] = url + else: + payload["path"] = url + resp = self._client.post("/api/v1/resources", payload) result = resp.get("result", {}) finally: diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index d5b115600f..56691ec7e2 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -93,6 +93,32 @@ def 
def test_tool_add_resource_uploads_file_uri(tmp_path):
    """A ``file:`` URI for an existing file is uploaded and posted by temp-file id."""
    local_file = tmp_path / "sample.md"
    local_file.write_text("# Local resource\n", encoding="utf-8")

    provider = OpenVikingMemoryProvider()
    client = MagicMock()
    provider._client = client
    client.upload_temp_file.return_value = "upload_sample.md"
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/sample"},
    }

    tool_args = {"url": local_file.as_uri(), "reason": "file uri test"}
    result = json.loads(provider._tool_add_resource(tool_args))

    expected_payload = {
        "reason": "file uri test",
        "source_name": "sample.md",
        "temp_file_id": "upload_sample.md",
    }
    client.upload_temp_file.assert_called_once_with(local_file)
    client.post.assert_called_once_with("/api/v1/resources", expected_payload)
    assert result["status"] == "added"
    assert result["root_uri"] == "viking://resources/sample"


def test_tool_add_resource_cleans_local_directory_zip_when_upload_fails(tmp_path):
    """The archive handed to a failing upload no longer exists afterwards, and nothing is posted."""
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()
    (docs_dir / "guide.md").write_text("# Guide\n", encoding="utf-8")

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    seen_uploads = []

    def record_then_fail(path):
        # Remember what was handed to the uploader before blowing up.
        seen_uploads.append(path)
        raise RuntimeError("upload failed")

    provider._client.upload_temp_file.side_effect = record_then_fail

    with pytest.raises(RuntimeError, match="upload failed"):
        provider._tool_add_resource({"url": str(docs_dir)})

    assert seen_uploads
    first_upload = seen_uploads[0]
    assert not first_upload.exists()
    provider._client.post.assert_not_called()


def test_tool_add_resource_rejects_missing_local_path(tmp_path):
    """A nonexistent local path yields an error result without any client call."""
    absent = tmp_path / "missing.md"
    provider = OpenVikingMemoryProvider()
    client = MagicMock()
    provider._client = client

    raw_result = provider._tool_add_resource({"url": str(absent)})
    result = json.loads(raw_result)

    assert result["error"] == f"Local resource path does not exist: {absent}"
    client.upload_temp_file.assert_not_called()
    client.post.assert_not_called()


@pytest.mark.parametrize("url", [
    "git@github.com:org/repo.git",
    "git@ssh.dev.azure.com:v3/org/project/repo",
    "ssh://git@github.com/org/repo.git",
    "git://github.com/org/repo.git",
])
def test_tool_add_resource_sends_git_remote_sources_as_path(url):
    """Git-style remote sources are posted verbatim as ``path`` with no upload."""
    provider = OpenVikingMemoryProvider()
    client = MagicMock()
    provider._client = client
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/repo"},
    }

    provider._tool_add_resource({"url": url})

    client.upload_temp_file.assert_not_called()
    expected_payload = {"path": url}
    client.post.assert_called_once_with("/api/v1/resources", expected_payload)