diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py new file mode 100644 index 0000000000..fd7893aae2 --- /dev/null +++ b/tests/agent/test_auxiliary_client.py @@ -0,0 +1,167 @@ +"""Tests for agent/auxiliary_client.py — API client resolution chain.""" + +import json +import os +from pathlib import Path +from unittest.mock import patch + +from agent.auxiliary_client import ( + _read_nous_auth, + _nous_api_key, + _nous_base_url, + auxiliary_max_tokens_param, + get_auxiliary_extra_body, + _AUTH_JSON_PATH, + _NOUS_DEFAULT_BASE_URL, + NOUS_EXTRA_BODY, +) + + +# --------------------------------------------------------------------------- +# _read_nous_auth +# --------------------------------------------------------------------------- + + +class TestReadNousAuth: + def test_missing_file(self, tmp_path): + with patch("agent.auxiliary_client._AUTH_JSON_PATH", tmp_path / "nope.json"): + assert _read_nous_auth() is None + + def test_wrong_active_provider(self, tmp_path): + auth_file = tmp_path / "auth.json" + auth_file.write_text(json.dumps({ + "active_provider": "openrouter", + "providers": {"nous": {"access_token": "tok"}} + })) + with patch("agent.auxiliary_client._AUTH_JSON_PATH", auth_file): + assert _read_nous_auth() is None + + def test_missing_tokens(self, tmp_path): + auth_file = tmp_path / "auth.json" + auth_file.write_text(json.dumps({ + "active_provider": "nous", + "providers": {"nous": {}} + })) + with patch("agent.auxiliary_client._AUTH_JSON_PATH", auth_file): + assert _read_nous_auth() is None + + def test_valid_access_token(self, tmp_path): + auth_file = tmp_path / "auth.json" + auth_file.write_text(json.dumps({ + "active_provider": "nous", + "providers": {"nous": {"access_token": "my-token"}} + })) + with patch("agent.auxiliary_client._AUTH_JSON_PATH", auth_file): + result = _read_nous_auth() + assert result is not None + assert result["access_token"] == "my-token" + + def test_valid_agent_key(self, tmp_path): + auth_file = 
tmp_path / "auth.json" + auth_file.write_text(json.dumps({ + "active_provider": "nous", + "providers": {"nous": {"agent_key": "agent-key-123"}} + })) + with patch("agent.auxiliary_client._AUTH_JSON_PATH", auth_file): + result = _read_nous_auth() + assert result is not None + + def test_corrupt_json(self, tmp_path): + auth_file = tmp_path / "auth.json" + auth_file.write_text("not json{{{") + with patch("agent.auxiliary_client._AUTH_JSON_PATH", auth_file): + assert _read_nous_auth() is None + + +# --------------------------------------------------------------------------- +# _nous_api_key +# --------------------------------------------------------------------------- + + +class TestNousApiKey: + def test_prefers_agent_key(self): + provider = {"agent_key": "agent-key", "access_token": "access-tok"} + assert _nous_api_key(provider) == "agent-key" + + def test_falls_back_to_access_token(self): + provider = {"access_token": "access-tok"} + assert _nous_api_key(provider) == "access-tok" + + def test_empty_provider(self): + assert _nous_api_key({}) == "" + + +# --------------------------------------------------------------------------- +# _nous_base_url +# --------------------------------------------------------------------------- + + +class TestNousBaseUrl: + def test_default(self): + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("NOUS_INFERENCE_BASE_URL", None) + assert _nous_base_url() == _NOUS_DEFAULT_BASE_URL + + def test_env_override(self): + with patch.dict(os.environ, {"NOUS_INFERENCE_BASE_URL": "https://custom.api/v1"}): + assert _nous_base_url() == "https://custom.api/v1" + + +# --------------------------------------------------------------------------- +# auxiliary_max_tokens_param +# --------------------------------------------------------------------------- + + +class TestAuxiliaryMaxTokensParam: + def test_openrouter_uses_max_tokens(self): + with patch.dict(os.environ, {"OPENROUTER_API_KEY": "key"}): + result = auxiliary_max_tokens_param(1000) 
+ assert result == {"max_tokens": 1000} + + def test_direct_openai_uses_max_completion_tokens(self, tmp_path): + """Direct api.openai.com endpoint uses max_completion_tokens.""" + auth_file = tmp_path / "auth.json" + auth_file.write_text(json.dumps({"active_provider": "other"})) + + env = { + "OPENAI_BASE_URL": "https://api.openai.com/v1", + "OPENAI_API_KEY": "sk-test", + } + with patch.dict(os.environ, env, clear=False), \ + patch("agent.auxiliary_client._AUTH_JSON_PATH", auth_file): + os.environ.pop("OPENROUTER_API_KEY", None) + result = auxiliary_max_tokens_param(500) + assert result == {"max_completion_tokens": 500} + + def test_custom_non_openai_uses_max_tokens(self, tmp_path): + auth_file = tmp_path / "auth.json" + auth_file.write_text(json.dumps({"active_provider": "other"})) + + env = { + "OPENAI_BASE_URL": "https://my-custom-api.com/v1", + "OPENAI_API_KEY": "key", + } + with patch.dict(os.environ, env, clear=False), \ + patch("agent.auxiliary_client._AUTH_JSON_PATH", auth_file): + os.environ.pop("OPENROUTER_API_KEY", None) + result = auxiliary_max_tokens_param(500) + assert result == {"max_tokens": 500} + + +# --------------------------------------------------------------------------- +# get_auxiliary_extra_body +# --------------------------------------------------------------------------- + + +class TestGetAuxiliaryExtraBody: + def test_returns_nous_tags_when_nous(self): + with patch("agent.auxiliary_client.auxiliary_is_nous", True): + result = get_auxiliary_extra_body() + assert result == NOUS_EXTRA_BODY + # Should be a copy, not the original + assert result is not NOUS_EXTRA_BODY + + def test_returns_empty_when_not_nous(self): + with patch("agent.auxiliary_client.auxiliary_is_nous", False): + result = get_auxiliary_extra_body() + assert result == {} diff --git a/tests/gateway/test_pairing.py b/tests/gateway/test_pairing.py new file mode 100644 index 0000000000..e9e2e6f2e7 --- /dev/null +++ b/tests/gateway/test_pairing.py @@ -0,0 +1,349 @@ +"""Tests for 
gateway/pairing.py — DM pairing security system.""" + +import json +import os +import time +from pathlib import Path +from unittest.mock import patch + +from gateway.pairing import ( + PairingStore, + ALPHABET, + CODE_LENGTH, + CODE_TTL_SECONDS, + RATE_LIMIT_SECONDS, + MAX_PENDING_PER_PLATFORM, + MAX_FAILED_ATTEMPTS, + LOCKOUT_SECONDS, + _secure_write, +) + + +def _make_store(tmp_path): + """Create a PairingStore with PAIRING_DIR pointed to tmp_path.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + return PairingStore() + + +# --------------------------------------------------------------------------- +# _secure_write +# --------------------------------------------------------------------------- + + +class TestSecureWrite: + def test_creates_parent_dirs(self, tmp_path): + target = tmp_path / "sub" / "dir" / "file.json" + _secure_write(target, '{"hello": "world"}') + assert target.exists() + assert json.loads(target.read_text()) == {"hello": "world"} + + def test_sets_file_permissions(self, tmp_path): + target = tmp_path / "secret.json" + _secure_write(target, "data") + mode = oct(target.stat().st_mode & 0o777) + assert mode == "0o600" + + +# --------------------------------------------------------------------------- +# Code generation +# --------------------------------------------------------------------------- + + +class TestCodeGeneration: + def test_code_format(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + assert code is not None + assert len(code) == CODE_LENGTH + assert all(c in ALPHABET for c in code) + + def test_code_uniqueness(self, tmp_path): + """Multiple codes for different users should be distinct.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + codes = set() + for i in range(3): + code = store.generate_code("telegram", f"user{i}") + assert code is not None + codes.add(code) + assert len(codes) 
== 3 + + def test_stores_pending_entry(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + pending = store.list_pending("telegram") + assert len(pending) == 1 + assert pending[0]["code"] == code + assert pending[0]["user_id"] == "user1" + assert pending[0]["user_name"] == "Alice" + + +# --------------------------------------------------------------------------- +# Rate limiting +# --------------------------------------------------------------------------- + + +class TestRateLimiting: + def test_same_user_rate_limited(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code1 = store.generate_code("telegram", "user1") + code2 = store.generate_code("telegram", "user1") + assert code1 is not None + assert code2 is None # rate limited + + def test_different_users_not_rate_limited(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code1 = store.generate_code("telegram", "user1") + code2 = store.generate_code("telegram", "user2") + assert code1 is not None + assert code2 is not None + + def test_rate_limit_expires(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code1 = store.generate_code("telegram", "user1") + assert code1 is not None + + # Simulate rate limit expiry + limits = store._load_json(store._rate_limit_path()) + limits["telegram:user1"] = time.time() - RATE_LIMIT_SECONDS - 1 + store._save_json(store._rate_limit_path(), limits) + + code2 = store.generate_code("telegram", "user1") + assert code2 is not None + + +# --------------------------------------------------------------------------- +# Max pending limit +# --------------------------------------------------------------------------- + + +class TestMaxPending: + def test_max_pending_per_platform(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", 
tmp_path): + store = PairingStore() + codes = [] + for i in range(MAX_PENDING_PER_PLATFORM + 1): + code = store.generate_code("telegram", f"user{i}") + codes.append(code) + + # First MAX_PENDING_PER_PLATFORM should succeed + assert all(c is not None for c in codes[:MAX_PENDING_PER_PLATFORM]) + # Next one should be blocked + assert codes[MAX_PENDING_PER_PLATFORM] is None + + def test_different_platforms_independent(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + for i in range(MAX_PENDING_PER_PLATFORM): + store.generate_code("telegram", f"user{i}") + # Different platform should still work + code = store.generate_code("discord", "user0") + assert code is not None + + +# --------------------------------------------------------------------------- +# Approval flow +# --------------------------------------------------------------------------- + + +class TestApprovalFlow: + def test_approve_valid_code(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + result = store.approve_code("telegram", code) + + assert result is not None + assert result["user_id"] == "user1" + assert result["user_name"] == "Alice" + + def test_approved_user_is_approved(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + store.approve_code("telegram", code) + assert store.is_approved("telegram", "user1") is True + + def test_unapproved_user_not_approved(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + assert store.is_approved("telegram", "nonexistent") is False + + def test_approve_removes_from_pending(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1") + store.approve_code("telegram", code) + 
pending = store.list_pending("telegram") + assert len(pending) == 0 + + def test_approve_case_insensitive(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + result = store.approve_code("telegram", code.lower()) + assert result is not None + + def test_approve_strips_whitespace(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + result = store.approve_code("telegram", f" {code} ") + assert result is not None + + def test_invalid_code_returns_none(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + result = store.approve_code("telegram", "INVALIDCODE") + assert result is None + + +# --------------------------------------------------------------------------- +# Lockout after failed attempts +# --------------------------------------------------------------------------- + + +class TestLockout: + def test_lockout_after_max_failures(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + # Generate a valid code so platform has data + store.generate_code("telegram", "user1") + + # Exhaust failed attempts + for _ in range(MAX_FAILED_ATTEMPTS): + store.approve_code("telegram", "WRONGCODE") + + # Platform should now be locked out — can't generate new codes + assert store._is_locked_out("telegram") is True + + def test_lockout_blocks_code_generation(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + for _ in range(MAX_FAILED_ATTEMPTS): + store.approve_code("telegram", "WRONG") + + code = store.generate_code("telegram", "newuser") + assert code is None + + def test_lockout_expires(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + for _ in range(MAX_FAILED_ATTEMPTS): + 
store.approve_code("telegram", "WRONG") + + # Simulate lockout expiry + limits = store._load_json(store._rate_limit_path()) + lockout_key = "_lockout:telegram" + limits[lockout_key] = time.time() - 1 # expired + store._save_json(store._rate_limit_path(), limits) + + assert store._is_locked_out("telegram") is False + + +# --------------------------------------------------------------------------- +# Code expiry +# --------------------------------------------------------------------------- + + +class TestCodeExpiry: + def test_expired_codes_cleaned_up(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1") + + # Manually expire the code + pending = store._load_json(store._pending_path("telegram")) + pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1 + store._save_json(store._pending_path("telegram"), pending) + + # Cleanup happens on next operation + remaining = store.list_pending("telegram") + assert len(remaining) == 0 + + def test_expired_code_cannot_be_approved(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1") + + # Expire it + pending = store._load_json(store._pending_path("telegram")) + pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1 + store._save_json(store._pending_path("telegram"), pending) + + result = store.approve_code("telegram", code) + assert result is None + + +# --------------------------------------------------------------------------- +# Revoke +# --------------------------------------------------------------------------- + + +class TestRevoke: + def test_revoke_approved_user(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + store.approve_code("telegram", code) + assert store.is_approved("telegram", "user1") is True + + revoked 
= store.revoke("telegram", "user1") + assert revoked is True + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + assert store.is_approved("telegram", "user1") is False + + def test_revoke_nonexistent_returns_false(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + assert store.revoke("telegram", "nobody") is False + + +# --------------------------------------------------------------------------- +# List & clear +# --------------------------------------------------------------------------- + + +class TestListAndClear: + def test_list_approved(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + store.approve_code("telegram", code) + approved = store.list_approved("telegram") + assert len(approved) == 1 + assert approved[0]["user_id"] == "user1" + assert approved[0]["platform"] == "telegram" + + def test_list_approved_all_platforms(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + c1 = store.generate_code("telegram", "user1") + store.approve_code("telegram", c1) + c2 = store.generate_code("discord", "user2") + store.approve_code("discord", c2) + approved = store.list_approved() + assert len(approved) == 2 + + def test_clear_pending(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + store.generate_code("telegram", "user1") + store.generate_code("telegram", "user2") + count = store.clear_pending("telegram") + remaining = store.list_pending("telegram") + assert count == 2 + assert len(remaining) == 0 + + def test_clear_pending_all_platforms(self, tmp_path): + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + store.generate_code("telegram", "user1") + store.generate_code("discord", "user2") + count = store.clear_pending() + assert count == 2 diff --git 
a/tests/honcho_integration/test_session.py b/tests/honcho_integration/test_session.py new file mode 100644 index 0000000000..356be3a407 --- /dev/null +++ b/tests/honcho_integration/test_session.py @@ -0,0 +1,189 @@ +"""Tests for honcho_integration/session.py — HonchoSession and helpers.""" + +from datetime import datetime +from unittest.mock import MagicMock + +from honcho_integration.session import ( + HonchoSession, + HonchoSessionManager, +) + + +# --------------------------------------------------------------------------- +# HonchoSession dataclass +# --------------------------------------------------------------------------- + + +class TestHonchoSession: + def _make_session(self): + return HonchoSession( + key="telegram:12345", + user_peer_id="user-telegram-12345", + assistant_peer_id="hermes-assistant", + honcho_session_id="telegram-12345", + ) + + def test_initial_state(self): + session = self._make_session() + assert session.key == "telegram:12345" + assert session.messages == [] + assert isinstance(session.created_at, datetime) + assert isinstance(session.updated_at, datetime) + + def test_add_message(self): + session = self._make_session() + session.add_message("user", "Hello!") + assert len(session.messages) == 1 + assert session.messages[0]["role"] == "user" + assert session.messages[0]["content"] == "Hello!" 
+ assert "timestamp" in session.messages[0] + + def test_add_message_with_kwargs(self): + session = self._make_session() + session.add_message("assistant", "Hi!", source="gateway") + assert session.messages[0]["source"] == "gateway" + + def test_add_message_updates_timestamp(self): + session = self._make_session() + original = session.updated_at + session.add_message("user", "test") + assert session.updated_at >= original + + def test_get_history(self): + session = self._make_session() + session.add_message("user", "msg1") + session.add_message("assistant", "msg2") + history = session.get_history() + assert len(history) == 2 + assert history[0] == {"role": "user", "content": "msg1"} + assert history[1] == {"role": "assistant", "content": "msg2"} + + def test_get_history_strips_extra_fields(self): + session = self._make_session() + session.add_message("user", "hello", extra="metadata") + history = session.get_history() + assert "extra" not in history[0] + assert set(history[0].keys()) == {"role", "content"} + + def test_get_history_max_messages(self): + session = self._make_session() + for i in range(10): + session.add_message("user", f"msg{i}") + history = session.get_history(max_messages=3) + assert len(history) == 3 + assert history[0]["content"] == "msg7" + assert history[2]["content"] == "msg9" + + def test_get_history_max_messages_larger_than_total(self): + session = self._make_session() + session.add_message("user", "only one") + history = session.get_history(max_messages=100) + assert len(history) == 1 + + def test_clear(self): + session = self._make_session() + session.add_message("user", "msg1") + session.add_message("user", "msg2") + session.clear() + assert session.messages == [] + + def test_clear_updates_timestamp(self): + session = self._make_session() + session.add_message("user", "msg") + original = session.updated_at + session.clear() + assert session.updated_at >= original + + +# 
--------------------------------------------------------------------------- +# HonchoSessionManager._sanitize_id +# --------------------------------------------------------------------------- + + +class TestSanitizeId: + def test_clean_id_unchanged(self): + mgr = HonchoSessionManager() + assert mgr._sanitize_id("telegram-12345") == "telegram-12345" + + def test_colons_replaced(self): + mgr = HonchoSessionManager() + assert mgr._sanitize_id("telegram:12345") == "telegram-12345" + + def test_special_chars_replaced(self): + mgr = HonchoSessionManager() + result = mgr._sanitize_id("user@chat#room!") + assert "@" not in result + assert "#" not in result + assert "!" not in result + + def test_alphanumeric_preserved(self): + mgr = HonchoSessionManager() + assert mgr._sanitize_id("abc123_XYZ-789") == "abc123_XYZ-789" + + +# --------------------------------------------------------------------------- +# HonchoSessionManager._format_migration_transcript +# --------------------------------------------------------------------------- + + +class TestFormatMigrationTranscript: + def test_basic_transcript(self): + messages = [ + {"role": "user", "content": "Hello", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "Hi!", "timestamp": "2026-01-01T00:01:00"}, + ] + result = HonchoSessionManager._format_migration_transcript("telegram:123", messages) + assert isinstance(result, bytes) + text = result.decode("utf-8") + assert "<transcript" in text + assert "user: Hello" in text + assert "assistant: Hi!" 
in text + assert 'session_key="telegram:123"' in text + assert 'message_count="2"' in text + + def test_empty_messages(self): + result = HonchoSessionManager._format_migration_transcript("key", []) + text = result.decode("utf-8") + assert "<transcript" in text + assert "</transcript>" in text + + def test_missing_fields_handled(self): + messages = [{"role": "user"}]  # no content, no timestamp + result = HonchoSessionManager._format_migration_transcript("key", messages) + text = result.decode("utf-8") + assert "user: " in text  # empty content + + +# --------------------------------------------------------------------------- +# HonchoSessionManager.delete / list_sessions +# --------------------------------------------------------------------------- + + +class TestManagerCacheOps: + def test_delete_cached_session(self): + mgr = HonchoSessionManager() + session = HonchoSession( + key="test", user_peer_id="u", assistant_peer_id="a", + honcho_session_id="s", + ) + mgr._cache["test"] = session + assert mgr.delete("test") is True + assert "test" not in mgr._cache + + def test_delete_nonexistent_returns_false(self): + mgr = HonchoSessionManager() + assert mgr.delete("nonexistent") is False + + def test_list_sessions(self): + mgr = HonchoSessionManager() + s1 = HonchoSession(key="k1", user_peer_id="u", assistant_peer_id="a", honcho_session_id="s1") + s2 = HonchoSession(key="k2", user_peer_id="u", assistant_peer_id="a", honcho_session_id="s2") + s1.add_message("user", "hi") + mgr._cache["k1"] = s1 + mgr._cache["k2"] = s2 + sessions = mgr.list_sessions() + assert len(sessions) == 2 + keys = {s["key"] for s in sessions} + assert keys == {"k1", "k2"} + s1_info = next(s for s in sessions if s["key"] == "k1") + assert s1_info["message_count"] == 1 diff --git a/tests/tools/test_skill_manager_tool.py b/tests/tools/test_skill_manager_tool.py new file mode 100644 index 0000000000..12779ad574 --- /dev/null +++ b/tests/tools/test_skill_manager_tool.py @@ -0,0 +1,366 @@ +"""Tests for tools/skill_manager_tool.py 
— skill creation, editing, and deletion.""" + +import json +from pathlib import Path +from unittest.mock import patch + +from tools.skill_manager_tool import ( + _validate_name, + _validate_frontmatter, + _validate_file_path, + _find_skill, + _resolve_skill_dir, + _create_skill, + _edit_skill, + _patch_skill, + _delete_skill, + _write_file, + _remove_file, + skill_manage, + VALID_NAME_RE, + ALLOWED_SUBDIRS, + MAX_NAME_LENGTH, +) + + +VALID_SKILL_CONTENT = """\ +--- +name: test-skill +description: A test skill for unit testing. +--- + +# Test Skill + +Step 1: Do the thing. +""" + +VALID_SKILL_CONTENT_2 = """\ +--- +name: test-skill +description: Updated description. +--- + +# Test Skill v2 + +Step 1: Do the new thing. +""" + + +# --------------------------------------------------------------------------- +# _validate_name +# --------------------------------------------------------------------------- + + +class TestValidateName: + def test_valid_names(self): + assert _validate_name("my-skill") is None + assert _validate_name("skill123") is None + assert _validate_name("my_skill.v2") is None + assert _validate_name("a") is None + + def test_empty_name(self): + assert _validate_name("") is not None + + def test_too_long(self): + assert _validate_name("a" * (MAX_NAME_LENGTH + 1)) is not None + + def test_uppercase_rejected(self): + assert _validate_name("MySkill") is not None + + def test_starts_with_hyphen_rejected(self): + assert _validate_name("-invalid") is not None + + def test_special_chars_rejected(self): + assert _validate_name("skill/name") is not None + assert _validate_name("skill name") is not None + assert _validate_name("skill@name") is not None + + +# --------------------------------------------------------------------------- +# _validate_frontmatter +# --------------------------------------------------------------------------- + + +class TestValidateFrontmatter: + def test_valid_content(self): + assert _validate_frontmatter(VALID_SKILL_CONTENT) is None + 
+ def test_empty_content(self): + assert _validate_frontmatter("") is not None + assert _validate_frontmatter(" ") is not None + + def test_no_frontmatter(self): + err = _validate_frontmatter("# Just a heading\nSome content.\n") + assert err is not None + assert "frontmatter" in err.lower() + + def test_unclosed_frontmatter(self): + content = "---\nname: test\ndescription: desc\nBody content.\n" + assert _validate_frontmatter(content) is not None + + def test_missing_name_field(self): + content = "---\ndescription: desc\n---\n\nBody.\n" + assert _validate_frontmatter(content) is not None + + def test_missing_description_field(self): + content = "---\nname: test\n---\n\nBody.\n" + assert _validate_frontmatter(content) is not None + + def test_no_body_after_frontmatter(self): + content = "---\nname: test\ndescription: desc\n---\n" + assert _validate_frontmatter(content) is not None + + def test_invalid_yaml(self): + content = "---\n: invalid: yaml: {{{\n---\n\nBody.\n" + assert _validate_frontmatter(content) is not None + + +# --------------------------------------------------------------------------- +# _validate_file_path — path traversal prevention +# --------------------------------------------------------------------------- + + +class TestValidateFilePath: + def test_valid_paths(self): + assert _validate_file_path("references/api.md") is None + assert _validate_file_path("templates/config.yaml") is None + assert _validate_file_path("scripts/train.py") is None + assert _validate_file_path("assets/image.png") is None + + def test_empty_path(self): + assert _validate_file_path("") is not None + + def test_path_traversal_blocked(self): + err = _validate_file_path("references/../../../etc/passwd") + assert err is not None + assert "traversal" in err.lower() + + def test_disallowed_subdirectory(self): + err = _validate_file_path("secret/hidden.txt") + assert err is not None + + def test_directory_only_rejected(self): + err = _validate_file_path("references") + assert 
err is not None + + def test_root_level_file_rejected(self): + err = _validate_file_path("malicious.py") + assert err is not None + + +# --------------------------------------------------------------------------- +# CRUD operations +# --------------------------------------------------------------------------- + + +class TestCreateSkill: + def test_create_skill(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + result = _create_skill("my-skill", VALID_SKILL_CONTENT) + assert result["success"] is True + assert (tmp_path / "my-skill" / "SKILL.md").exists() + + def test_create_with_category(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + result = _create_skill("my-skill", VALID_SKILL_CONTENT, category="devops") + assert result["success"] is True + assert (tmp_path / "devops" / "my-skill" / "SKILL.md").exists() + assert result["category"] == "devops" + + def test_create_duplicate_blocked(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _create_skill("my-skill", VALID_SKILL_CONTENT) + assert result["success"] is False + assert "already exists" in result["error"] + + def test_create_invalid_name(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + result = _create_skill("Invalid Name!", VALID_SKILL_CONTENT) + assert result["success"] is False + + def test_create_invalid_content(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + result = _create_skill("my-skill", "no frontmatter here") + assert result["success"] is False + + +class TestEditSkill: + def test_edit_existing_skill(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _edit_skill("my-skill", VALID_SKILL_CONTENT_2) + assert result["success"] is True + content = (tmp_path / "my-skill" / "SKILL.md").read_text() + 
assert "Updated description" in content + + def test_edit_nonexistent_skill(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + result = _edit_skill("nonexistent", VALID_SKILL_CONTENT) + assert result["success"] is False + assert "not found" in result["error"] + + def test_edit_invalid_content_rejected(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _edit_skill("my-skill", "no frontmatter") + assert result["success"] is False + # Original content should be preserved + content = (tmp_path / "my-skill" / "SKILL.md").read_text() + assert "A test skill" in content + + +class TestPatchSkill: + def test_patch_unique_match(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _patch_skill("my-skill", "Do the thing.", "Do the new thing.") + assert result["success"] is True + content = (tmp_path / "my-skill" / "SKILL.md").read_text() + assert "Do the new thing." in content + + def test_patch_nonexistent_string(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _patch_skill("my-skill", "this text does not exist", "replacement") + assert result["success"] is False + assert "not found" in result["error"] + + def test_patch_ambiguous_match_rejected(self, tmp_path): + content = """\ +--- +name: test-skill +description: A test skill. +--- + +# Test + +word word +""" + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", content) + result = _patch_skill("my-skill", "word", "replaced") + assert result["success"] is False + assert "matched" in result["error"] + + def test_patch_replace_all(self, tmp_path): + content = """\ +--- +name: test-skill +description: A test skill. 
+--- + +# Test + +word word +""" + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", content) + result = _patch_skill("my-skill", "word", "replaced", replace_all=True) + assert result["success"] is True + + def test_patch_supporting_file(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + _write_file("my-skill", "references/api.md", "old text here") + result = _patch_skill("my-skill", "old text", "new text", file_path="references/api.md") + assert result["success"] is True + + def test_patch_skill_not_found(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + result = _patch_skill("nonexistent", "old", "new") + assert result["success"] is False + + +class TestDeleteSkill: + def test_delete_existing(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _delete_skill("my-skill") + assert result["success"] is True + assert not (tmp_path / "my-skill").exists() + + def test_delete_nonexistent(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + result = _delete_skill("nonexistent") + assert result["success"] is False + + def test_delete_cleans_empty_category_dir(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT, category="devops") + _delete_skill("my-skill") + assert not (tmp_path / "devops").exists() + + +# --------------------------------------------------------------------------- +# write_file / remove_file +# --------------------------------------------------------------------------- + + +class TestWriteFile: + def test_write_reference_file(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _write_file("my-skill", "references/api.md", "# 
API\nEndpoint docs.") + assert result["success"] is True + assert (tmp_path / "my-skill" / "references" / "api.md").exists() + + def test_write_to_nonexistent_skill(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + result = _write_file("nonexistent", "references/doc.md", "content") + assert result["success"] is False + + def test_write_to_disallowed_path(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _write_file("my-skill", "secret/evil.py", "malicious") + assert result["success"] is False + + +class TestRemoveFile: + def test_remove_existing_file(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + _write_file("my-skill", "references/api.md", "content") + result = _remove_file("my-skill", "references/api.md") + assert result["success"] is True + assert not (tmp_path / "my-skill" / "references" / "api.md").exists() + + def test_remove_nonexistent_file(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + result = _remove_file("my-skill", "references/nope.md") + assert result["success"] is False + + +# --------------------------------------------------------------------------- +# skill_manage dispatcher +# --------------------------------------------------------------------------- + + +class TestSkillManageDispatcher: + def test_unknown_action(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + raw = skill_manage(action="explode", name="test") + result = json.loads(raw) + assert result["success"] is False + assert "Unknown action" in result["error"] + + def test_create_without_content(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + raw = skill_manage(action="create", name="test") + result = json.loads(raw) + assert 
result["success"] is False + assert "content" in result["error"].lower() + + def test_patch_without_old_string(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + raw = skill_manage(action="patch", name="test") + result = json.loads(raw) + assert result["success"] is False + + def test_full_create_via_dispatcher(self, tmp_path): + with patch("tools.skill_manager_tool.SKILLS_DIR", tmp_path): + raw = skill_manage(action="create", name="test-skill", content=VALID_SKILL_CONTENT) + result = json.loads(raw) + assert result["success"] is True diff --git a/tests/tools/test_skills_tool.py b/tests/tools/test_skills_tool.py new file mode 100644 index 0000000000..07ea76b87c --- /dev/null +++ b/tests/tools/test_skills_tool.py @@ -0,0 +1,334 @@ +"""Tests for tools/skills_tool.py — skill discovery and viewing.""" + +import json +from pathlib import Path +from unittest.mock import patch + +from tools.skills_tool import ( + _parse_frontmatter, + _parse_tags, + _get_category_from_path, + _estimate_tokens, + _find_all_skills, + _load_category_description, + skills_list, + skills_categories, + skill_view, + SKILLS_DIR, + MAX_NAME_LENGTH, + MAX_DESCRIPTION_LENGTH, +) + + +def _make_skill(skills_dir, name, frontmatter_extra="", body="Step 1: Do the thing.", category=None): + """Helper to create a minimal skill directory.""" + if category: + skill_dir = skills_dir / category / name + else: + skill_dir = skills_dir / name + skill_dir.mkdir(parents=True, exist_ok=True) + content = f"""\ +--- +name: {name} +description: Description for {name}. 
+{frontmatter_extra}--- + +# {name} + +{body} +""" + (skill_dir / "SKILL.md").write_text(content) + return skill_dir + + +# --------------------------------------------------------------------------- +# _parse_frontmatter +# --------------------------------------------------------------------------- + + +class TestParseFrontmatter: + def test_valid_frontmatter(self): + content = "---\nname: test\ndescription: A test.\n---\n\n# Body\n" + fm, body = _parse_frontmatter(content) + assert fm["name"] == "test" + assert fm["description"] == "A test." + assert "# Body" in body + + def test_no_frontmatter(self): + content = "# Just a heading\nSome content.\n" + fm, body = _parse_frontmatter(content) + assert fm == {} + assert body == content + + def test_empty_frontmatter(self): + content = "---\n---\n\n# Body\n" + fm, body = _parse_frontmatter(content) + assert fm == {} + + def test_nested_yaml(self): + content = "---\nname: test\nmetadata:\n hermes:\n tags: [a, b]\n---\n\nBody.\n" + fm, body = _parse_frontmatter(content) + assert fm["metadata"]["hermes"]["tags"] == ["a", "b"] + + def test_malformed_yaml_fallback(self): + """Malformed YAML falls back to simple key:value parsing.""" + content = "---\nname: test\ndescription: desc\n: invalid\n---\n\nBody.\n" + fm, body = _parse_frontmatter(content) + # Should still parse what it can via fallback + assert "name" in fm + + +# --------------------------------------------------------------------------- +# _parse_tags +# --------------------------------------------------------------------------- + + +class TestParseTags: + def test_list_input(self): + assert _parse_tags(["a", "b", "c"]) == ["a", "b", "c"] + + def test_comma_separated_string(self): + assert _parse_tags("a, b, c") == ["a", "b", "c"] + + def test_bracket_wrapped_string(self): + assert _parse_tags("[a, b, c]") == ["a", "b", "c"] + + def test_empty_input(self): + assert _parse_tags("") == [] + assert _parse_tags(None) == [] + assert _parse_tags([]) == [] + + def 
test_strips_quotes(self): + result = _parse_tags('"tag1", \'tag2\'') + assert "tag1" in result + assert "tag2" in result + + def test_filters_empty_items(self): + assert _parse_tags([None, "", "valid"]) == ["valid"] + + +# --------------------------------------------------------------------------- +# _get_category_from_path +# --------------------------------------------------------------------------- + + +class TestGetCategoryFromPath: + def test_categorized_skill(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + skill_md = tmp_path / "mlops" / "axolotl" / "SKILL.md" + skill_md.parent.mkdir(parents=True) + skill_md.touch() + assert _get_category_from_path(skill_md) == "mlops" + + def test_uncategorized_skill(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + skill_md = tmp_path / "my-skill" / "SKILL.md" + skill_md.parent.mkdir(parents=True) + skill_md.touch() + assert _get_category_from_path(skill_md) is None + + def test_outside_skills_dir(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path / "skills"): + skill_md = tmp_path / "other" / "SKILL.md" + assert _get_category_from_path(skill_md) is None + + +# --------------------------------------------------------------------------- +# _estimate_tokens +# --------------------------------------------------------------------------- + + +class TestEstimateTokens: + def test_estimate(self): + assert _estimate_tokens("1234") == 1 + assert _estimate_tokens("12345678") == 2 + assert _estimate_tokens("") == 0 + + +# --------------------------------------------------------------------------- +# _find_all_skills +# --------------------------------------------------------------------------- + + +class TestFindAllSkills: + def test_finds_skills(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "skill-a") + _make_skill(tmp_path, "skill-b") + skills = _find_all_skills() + assert len(skills) == 2 + names = {s["name"] 
for s in skills} + assert "skill-a" in names + assert "skill-b" in names + + def test_empty_directory(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + skills = _find_all_skills() + assert skills == [] + + def test_nonexistent_directory(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path / "nope"): + skills = _find_all_skills() + assert skills == [] + + def test_categorized_skills(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "axolotl", category="mlops") + skills = _find_all_skills() + assert len(skills) == 1 + assert skills[0]["category"] == "mlops" + + def test_description_from_body_when_missing(self, tmp_path): + """If no description in frontmatter, first non-header line is used.""" + skill_dir = tmp_path / "no-desc" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\n\n# Heading\n\nFirst paragraph.\n") + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + skills = _find_all_skills() + assert skills[0]["description"] == "First paragraph." 
+ + def test_long_description_truncated(self, tmp_path): + long_desc = "x" * (MAX_DESCRIPTION_LENGTH + 100) + skill_dir = tmp_path / "long-desc" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text(f"---\nname: long\ndescription: {long_desc}\n---\n\nBody.\n") + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + skills = _find_all_skills() + assert len(skills[0]["description"]) <= MAX_DESCRIPTION_LENGTH + + def test_skips_git_directories(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "real-skill") + git_dir = tmp_path / ".git" / "fake-skill" + git_dir.mkdir(parents=True) + (git_dir / "SKILL.md").write_text("---\nname: fake\ndescription: x\n---\n\nBody.\n") + skills = _find_all_skills() + assert len(skills) == 1 + assert skills[0]["name"] == "real-skill" + + +# --------------------------------------------------------------------------- +# skills_list +# --------------------------------------------------------------------------- + + +class TestSkillsList: + def test_empty_creates_directory(self, tmp_path): + skills_dir = tmp_path / "skills" + with patch("tools.skills_tool.SKILLS_DIR", skills_dir): + raw = skills_list() + result = json.loads(raw) + assert result["success"] is True + assert result["skills"] == [] + assert skills_dir.exists() + + def test_lists_skills(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "alpha") + _make_skill(tmp_path, "beta") + raw = skills_list() + result = json.loads(raw) + assert result["count"] == 2 + + def test_category_filter(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "skill-a", category="devops") + _make_skill(tmp_path, "skill-b", category="mlops") + raw = skills_list(category="devops") + result = json.loads(raw) + assert result["count"] == 1 + assert result["skills"][0]["name"] == "skill-a" + + +# --------------------------------------------------------------------------- 
+# skill_view +# --------------------------------------------------------------------------- + + +class TestSkillView: + def test_view_existing_skill(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "my-skill") + raw = skill_view("my-skill") + result = json.loads(raw) + assert result["success"] is True + assert result["name"] == "my-skill" + assert "Step 1" in result["content"] + + def test_view_nonexistent_skill(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "other-skill") + raw = skill_view("nonexistent") + result = json.loads(raw) + assert result["success"] is False + assert "not found" in result["error"].lower() + assert "available_skills" in result + + def test_view_reference_file(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + skill_dir = _make_skill(tmp_path, "my-skill") + refs_dir = skill_dir / "references" + refs_dir.mkdir() + (refs_dir / "api.md").write_text("# API Docs\nEndpoint info.") + raw = skill_view("my-skill", file_path="references/api.md") + result = json.loads(raw) + assert result["success"] is True + assert "Endpoint info" in result["content"] + + def test_view_nonexistent_file(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "my-skill") + raw = skill_view("my-skill", file_path="references/nope.md") + result = json.loads(raw) + assert result["success"] is False + + def test_view_shows_linked_files(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + skill_dir = _make_skill(tmp_path, "my-skill") + refs_dir = skill_dir / "references" + refs_dir.mkdir() + (refs_dir / "guide.md").write_text("guide content") + raw = skill_view("my-skill") + result = json.loads(raw) + assert result["linked_files"] is not None + assert "references" in result["linked_files"] + + def test_view_tags_from_metadata(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", 
tmp_path): + _make_skill(tmp_path, "tagged", frontmatter_extra="metadata:\n hermes:\n tags: [fine-tuning, llm]\n") + raw = skill_view("tagged") + result = json.loads(raw) + assert "fine-tuning" in result["tags"] + assert "llm" in result["tags"] + + def test_view_nonexistent_skills_dir(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path / "nope"): + raw = skill_view("anything") + result = json.loads(raw) + assert result["success"] is False + + +# --------------------------------------------------------------------------- +# skills_categories +# --------------------------------------------------------------------------- + + +class TestSkillsCategories: + def test_lists_categories(self, tmp_path): + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill(tmp_path, "s1", category="devops") + _make_skill(tmp_path, "s2", category="mlops") + raw = skills_categories() + result = json.loads(raw) + assert result["success"] is True + names = {c["name"] for c in result["categories"]} + assert "devops" in names + assert "mlops" in names + + def test_empty_skills_dir(self, tmp_path): + skills_dir = tmp_path / "skills" + with patch("tools.skills_tool.SKILLS_DIR", skills_dir): + raw = skills_categories() + result = json.loads(raw) + assert result["success"] is True + assert result["categories"] == []