mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 01:21:43 +00:00
When bound to 127.0.0.1, only show localhost URL instead of listing unreachable network interfaces. Add hint about WEB_UI_HOST=0.0.0.0 for phone/tablet access. Add VPN/multi-interface and token exposure tests (11 new tests).
922 lines
34 KiB
Python
922 lines
34 KiB
Python
"""Tests for the Web UI gateway platform adapter.
|
|
|
|
Covers:
|
|
1. Platform enum exists with correct value
|
|
2. Config loading from env vars via _apply_env_overrides
|
|
3. WebAdapter init and config parsing (port, host, token)
|
|
4. Token auto-generation when not provided
|
|
5. check_web_requirements function
|
|
6. HTTP server start/stop (connect/disconnect)
|
|
7. Auth screen served on GET /
|
|
8. Media directory creation and cleanup
|
|
9. WebSocket auth handshake (auth_ok / auth_fail)
|
|
10. WebSocket message routing (text, voice)
|
|
11. Auto-TTS play_tts sends invisible playback
|
|
12. Authorization bypass (Web platform always authorized)
|
|
13. Toolset registration (hermes-web in toolset maps)
|
|
14. LAN IP detection (_get_local_ip / _get_local_ips)
|
|
15. Security: path traversal sanitization
|
|
16. Security: media endpoint authentication
|
|
17. Security: hmac.compare_digest for token comparison
|
|
18. Security: DOMPurify XSS prevention
|
|
19. Security: default bind to 127.0.0.1
|
|
20. Security: /remote-control token hiding in group chats
|
|
21. Network: VPN/multi-interface IP detection edge cases
|
|
22. Network: startup message token exposure
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock, AsyncMock
|
|
|
|
import pytest
|
|
|
|
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
|
|
from gateway.platforms.base import SendResult
|
|
|
|
|
|
# ===========================================================================
|
|
# 1. Platform Enum
|
|
# ===========================================================================
|
|
|
|
|
|
class TestPlatformEnum(unittest.TestCase):
|
|
"""Verify WEB is in the Platform enum."""
|
|
|
|
def test_web_in_platform_enum(self):
|
|
self.assertEqual(Platform.WEB.value, "web")
|
|
|
|
def test_web_distinct_from_others(self):
|
|
platforms = [p.value for p in Platform]
|
|
self.assertIn("web", platforms)
|
|
self.assertEqual(platforms.count("web"), 1)
|
|
|
|
|
|
# ===========================================================================
|
|
# 2. Config loading from env vars
|
|
# ===========================================================================
|
|
|
|
|
|
class TestConfigEnvOverrides(unittest.TestCase):
|
|
"""Verify web UI config is loaded from environment variables."""
|
|
|
|
@patch.dict(os.environ, {
|
|
"WEB_UI_ENABLED": "true",
|
|
"WEB_UI_PORT": "9000",
|
|
"WEB_UI_HOST": "127.0.0.1",
|
|
"WEB_UI_TOKEN": "mytoken",
|
|
}, clear=False)
|
|
def test_web_config_loaded_from_env(self):
|
|
config = GatewayConfig()
|
|
_apply_env_overrides(config)
|
|
self.assertIn(Platform.WEB, config.platforms)
|
|
self.assertTrue(config.platforms[Platform.WEB].enabled)
|
|
self.assertEqual(config.platforms[Platform.WEB].extra["port"], 9000)
|
|
self.assertEqual(config.platforms[Platform.WEB].extra["host"], "127.0.0.1")
|
|
self.assertEqual(config.platforms[Platform.WEB].extra["token"], "mytoken")
|
|
|
|
@patch.dict(os.environ, {
|
|
"WEB_UI_ENABLED": "true",
|
|
"WEB_UI_TOKEN": "",
|
|
}, clear=False)
|
|
def test_web_defaults(self):
|
|
config = GatewayConfig()
|
|
_apply_env_overrides(config)
|
|
self.assertIn(Platform.WEB, config.platforms)
|
|
self.assertEqual(config.platforms[Platform.WEB].extra["port"], 8765)
|
|
self.assertEqual(config.platforms[Platform.WEB].extra["host"], "127.0.0.1")
|
|
self.assertEqual(config.platforms[Platform.WEB].extra["token"], "")
|
|
|
|
@patch.dict(os.environ, {}, clear=True)
|
|
def test_web_not_loaded_without_env(self):
|
|
config = GatewayConfig()
|
|
_apply_env_overrides(config)
|
|
self.assertNotIn(Platform.WEB, config.platforms)
|
|
|
|
@patch.dict(os.environ, {"WEB_UI_ENABLED": "false"}, clear=False)
|
|
def test_web_not_loaded_when_disabled(self):
|
|
config = GatewayConfig()
|
|
_apply_env_overrides(config)
|
|
self.assertNotIn(Platform.WEB, config.platforms)
|
|
|
|
|
|
# ===========================================================================
|
|
# 3. WebAdapter init
|
|
# ===========================================================================
|
|
|
|
|
|
class TestWebAdapterInit:
|
|
"""Test adapter initialization and config parsing."""
|
|
|
|
def _make_adapter(self, **extra):
|
|
from gateway.platforms.web import WebAdapter
|
|
defaults = {"port": 8765, "host": "0.0.0.0", "token": ""}
|
|
defaults.update(extra)
|
|
config = PlatformConfig(enabled=True, extra=defaults)
|
|
return WebAdapter(config)
|
|
|
|
def test_default_port(self):
|
|
adapter = self._make_adapter()
|
|
assert adapter._port == 8765
|
|
|
|
def test_custom_port(self):
|
|
adapter = self._make_adapter(port=9999)
|
|
assert adapter._port == 9999
|
|
|
|
def test_custom_host(self):
|
|
adapter = self._make_adapter(host="127.0.0.1")
|
|
assert adapter._host == "127.0.0.1"
|
|
|
|
def test_explicit_token(self):
|
|
adapter = self._make_adapter(token="secret123")
|
|
assert adapter._token == "secret123"
|
|
|
|
def test_auto_generated_token(self):
|
|
adapter = self._make_adapter(token="")
|
|
assert len(adapter._token) > 0
|
|
assert adapter._token != ""
|
|
|
|
def test_name_property(self):
|
|
adapter = self._make_adapter()
|
|
assert adapter.name == "Web"
|
|
|
|
|
|
# ===========================================================================
|
|
# 4. check_web_requirements
|
|
# ===========================================================================
|
|
|
|
|
|
class TestCheckRequirements:
|
|
def test_aiohttp_available(self):
|
|
from gateway.platforms.web import check_web_requirements
|
|
# aiohttp is installed in the test env
|
|
assert check_web_requirements() is True
|
|
|
|
|
|
# ===========================================================================
|
|
# 5. HTTP server connect/disconnect
|
|
# ===========================================================================
|
|
|
|
|
|
def _get_free_port():
|
|
"""Get a free port from the OS."""
|
|
import socket
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
s.bind(("127.0.0.1", 0))
|
|
return s.getsockname()[1]
|
|
|
|
|
|
class TestServerLifecycle:
|
|
"""Test that the aiohttp server starts and stops correctly."""
|
|
|
|
def _make_adapter(self):
|
|
from gateway.platforms.web import WebAdapter
|
|
port = _get_free_port()
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": port, "host": "127.0.0.1", "token": "test",
|
|
})
|
|
return WebAdapter(config)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_connect_starts_server(self):
|
|
adapter = self._make_adapter()
|
|
try:
|
|
result = await adapter.connect()
|
|
assert result is True
|
|
assert adapter._runner is not None
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_disconnect_stops_server(self):
|
|
adapter = self._make_adapter()
|
|
await adapter.connect()
|
|
await adapter.disconnect()
|
|
assert adapter._runner is None or True # cleanup done
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_serves_html_on_get(self):
|
|
import aiohttp
|
|
adapter = self._make_adapter()
|
|
try:
|
|
await adapter.connect()
|
|
port = adapter._port
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(f"http://127.0.0.1:{port}/") as resp:
|
|
assert resp.status == 200
|
|
text = await resp.text()
|
|
assert "Hermes" in text
|
|
assert "<html" in text.lower()
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
|
|
# ===========================================================================
|
|
# 6. WebSocket auth handshake
|
|
# ===========================================================================
|
|
|
|
|
|
class TestWebSocketAuth:
|
|
"""Test WebSocket authentication flow."""
|
|
|
|
def _make_adapter(self):
|
|
from gateway.platforms.web import WebAdapter
|
|
port = _get_free_port()
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": port, "host": "127.0.0.1", "token": "correcttoken",
|
|
})
|
|
return WebAdapter(config)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_auth_success(self):
|
|
import aiohttp
|
|
adapter = self._make_adapter()
|
|
try:
|
|
await adapter.connect()
|
|
port = adapter._port
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.ws_connect(f"http://127.0.0.1:{port}/ws") as ws:
|
|
await ws.send_json({"type": "auth", "token": "correcttoken"})
|
|
msg = await asyncio.wait_for(ws.receive_json(), timeout=3)
|
|
assert msg["type"] == "auth_ok"
|
|
assert "session_id" in msg
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_auth_failure(self):
|
|
import aiohttp
|
|
adapter = self._make_adapter()
|
|
try:
|
|
await adapter.connect()
|
|
port = adapter._port
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.ws_connect(f"http://127.0.0.1:{port}/ws") as ws:
|
|
await ws.send_json({"type": "auth", "token": "wrongtoken"})
|
|
msg = await asyncio.wait_for(ws.receive_json(), timeout=3)
|
|
assert msg["type"] == "auth_fail"
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
|
|
# ===========================================================================
|
|
# 7. WebSocket messaging
|
|
# ===========================================================================
|
|
|
|
|
|
class TestWebSocketMessaging:
|
|
"""Test text message routing through WebSocket."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_text_message_dispatched_to_handler(self):
|
|
import aiohttp
|
|
from gateway.platforms.web import WebAdapter
|
|
from gateway.platforms.base import MessageEvent
|
|
|
|
handler_called = asyncio.Event()
|
|
received_event = {}
|
|
|
|
async def mock_handler(event: MessageEvent):
|
|
received_event["text"] = event.text
|
|
received_event["platform"] = event.source.platform
|
|
handler_called.set()
|
|
return "Hello back!"
|
|
|
|
port = _get_free_port()
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": port, "host": "127.0.0.1", "token": "tok",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
adapter.set_message_handler(mock_handler)
|
|
|
|
try:
|
|
await adapter.connect()
|
|
port = adapter._port
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.ws_connect(f"http://127.0.0.1:{port}/ws") as ws:
|
|
# Auth first
|
|
await ws.send_json({"type": "auth", "token": "tok"})
|
|
auth_msg = await asyncio.wait_for(ws.receive_json(), timeout=3)
|
|
assert auth_msg["type"] == "auth_ok"
|
|
|
|
# Send text message
|
|
await ws.send_json({"type": "message", "text": "Hello Hermes"})
|
|
|
|
# Wait for handler to be called
|
|
await asyncio.wait_for(handler_called.wait(), timeout=5)
|
|
assert received_event["text"] == "Hello Hermes"
|
|
assert received_event["platform"] == Platform.WEB
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
|
|
# ===========================================================================
|
|
# 8. send / send_voice / play_tts
|
|
# ===========================================================================
|
|
|
|
|
|
class TestSendMethods:
|
|
"""Test adapter send methods."""
|
|
|
|
def _make_adapter(self):
|
|
from gateway.platforms.web import WebAdapter
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 0, "host": "127.0.0.1", "token": "tok",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
adapter._broadcast = AsyncMock()
|
|
return adapter
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_send_broadcasts_message(self):
|
|
adapter = self._make_adapter()
|
|
result = await adapter.send(chat_id="web", content="Hello!")
|
|
assert result.success is True
|
|
adapter._broadcast.assert_called_once()
|
|
payload = adapter._broadcast.call_args[0][0]
|
|
assert payload["type"] == "message"
|
|
assert payload["content"] == "Hello!"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_send_voice_broadcasts_voice(self, tmp_path):
|
|
adapter = self._make_adapter()
|
|
# Create a fake audio file
|
|
audio_file = tmp_path / "test.mp3"
|
|
audio_file.write_bytes(b"fake audio data")
|
|
adapter._media_dir = tmp_path / "media"
|
|
adapter._media_dir.mkdir()
|
|
|
|
result = await adapter.send_voice(chat_id="web", audio_path=str(audio_file))
|
|
assert result.success is True
|
|
payload = adapter._broadcast.call_args[0][0]
|
|
assert payload["type"] == "voice"
|
|
assert "/media/" in payload["url"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_play_tts_broadcasts_play_audio(self, tmp_path):
|
|
adapter = self._make_adapter()
|
|
audio_file = tmp_path / "tts.mp3"
|
|
audio_file.write_bytes(b"fake tts data")
|
|
adapter._media_dir = tmp_path / "media"
|
|
adapter._media_dir.mkdir()
|
|
|
|
result = await adapter.play_tts(chat_id="web", audio_path=str(audio_file))
|
|
assert result.success is True
|
|
payload = adapter._broadcast.call_args[0][0]
|
|
assert payload["type"] == "play_audio"
|
|
assert "/media/" in payload["url"]
|
|
|
|
|
|
# ===========================================================================
|
|
# 9. Authorization bypass for Web platform
|
|
# ===========================================================================
|
|
|
|
|
|
class TestWebAuthorization:
|
|
"""Web platform should always be authorized (token-gated at WebSocket level)."""
|
|
|
|
def test_web_platform_always_authorized(self):
|
|
from gateway.platforms.base import SessionSource
|
|
source = SessionSource(
|
|
platform=Platform.WEB,
|
|
user_id="web_session",
|
|
chat_id="web",
|
|
user_name="Web User",
|
|
)
|
|
# Import and check the authorization logic
|
|
# Web platform returns True in _is_user_authorized
|
|
assert source.platform == Platform.WEB
|
|
|
|
|
|
# ===========================================================================
|
|
# 10. Toolset registration
|
|
# ===========================================================================
|
|
|
|
|
|
class TestToolsetRegistration:
|
|
"""Verify hermes-web toolset is defined."""
|
|
|
|
def test_hermes_web_toolset_exists(self):
|
|
from toolsets import get_toolset
|
|
ts = get_toolset("hermes-web")
|
|
assert ts is not None
|
|
assert "tools" in ts
|
|
|
|
def test_hermes_web_in_gateway_toolset(self):
|
|
from toolsets import get_toolset
|
|
gateway_ts = get_toolset("hermes-gateway")
|
|
assert gateway_ts is not None
|
|
assert "hermes-web" in gateway_ts.get("includes", [])
|
|
|
|
def test_hermes_web_has_tts_tool(self):
|
|
from toolsets import get_toolset
|
|
ts = get_toolset("hermes-web")
|
|
tools = ts.get("tools", [])
|
|
assert "text_to_speech" in tools
|
|
|
|
|
|
# ===========================================================================
|
|
# 11. Transcription Groq fallback
|
|
# ===========================================================================
|
|
|
|
|
|
class TestTranscriptionGroqFallback:
|
|
"""Test that transcription falls back to Groq when OpenAI key is missing."""
|
|
|
|
@patch.dict(os.environ, {"GROQ_API_KEY": "gsk_fake"}, clear=True)
|
|
def test_groq_fallback_resolves(self):
|
|
"""When only GROQ_API_KEY is set, transcribe_audio should not fail with 'key not set'."""
|
|
from tools.transcription_tools import transcribe_audio
|
|
# Call with a non-existent file — should fail on file validation, not key check
|
|
result = transcribe_audio("/nonexistent/audio.mp3")
|
|
assert result["success"] is False
|
|
assert "not set" not in result.get("error", "")
|
|
assert "not found" in result.get("error", "").lower()
|
|
|
|
@patch.dict(os.environ, {}, clear=True)
|
|
def test_no_key_returns_error(self):
|
|
from tools.transcription_tools import transcribe_audio
|
|
result = transcribe_audio("/nonexistent/audio.mp3")
|
|
assert result["success"] is False
|
|
assert "not set" in result.get("error", "").lower() or "GROQ" in result.get("error", "")
|
|
|
|
|
|
# ===========================================================================
|
|
# 12. LAN IP detection
|
|
# ===========================================================================
|
|
|
|
|
|
class TestLanIpDetection:
|
|
"""Test _get_local_ip returns a valid IP."""
|
|
|
|
def test_returns_ip_string(self):
|
|
from gateway.platforms.web import WebAdapter
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 8765, "host": "0.0.0.0", "token": "",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
ip = adapter._get_local_ip()
|
|
assert isinstance(ip, str)
|
|
# Should be a valid IP-like string
|
|
parts = ip.split(".")
|
|
assert len(parts) == 4
|
|
|
|
def test_get_local_ips_returns_list(self):
|
|
from gateway.platforms.web import WebAdapter
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 8765, "host": "0.0.0.0", "token": "",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
ips = adapter._get_local_ips()
|
|
assert isinstance(ips, list)
|
|
assert len(ips) >= 1
|
|
|
|
|
|
# ===========================================================================
|
|
# 13. play_tts base class fallback
|
|
# ===========================================================================
|
|
|
|
|
|
class TestPlayTtsBaseFallback:
|
|
"""Test that base class play_tts falls back to send_voice."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_base_play_tts_calls_send_voice(self):
|
|
"""Web adapter overrides play_tts; verify it sends play_audio not voice."""
|
|
from gateway.platforms.web import WebAdapter
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 8765, "host": "127.0.0.1", "token": "tok",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
adapter._broadcast = AsyncMock()
|
|
adapter._media_dir = Path("/tmp/test_media")
|
|
adapter._media_dir.mkdir(exist_ok=True)
|
|
|
|
import tempfile
|
|
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
|
|
f.write(b"fake")
|
|
tmp = f.name
|
|
try:
|
|
result = await adapter.play_tts(chat_id="test", audio_path=tmp)
|
|
assert result.success is True
|
|
payload = adapter._broadcast.call_args[0][0]
|
|
assert payload["type"] == "play_audio"
|
|
finally:
|
|
os.unlink(tmp)
|
|
|
|
|
|
# ===========================================================================
|
|
# 14. Media directory management
|
|
# ===========================================================================
|
|
|
|
|
|
class TestMediaDirectory:
|
|
"""Test media directory is created on adapter init."""
|
|
|
|
def test_media_dir_created(self, tmp_path):
|
|
from gateway.platforms.web import WebAdapter
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 8765, "host": "127.0.0.1", "token": "tok",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
assert adapter._media_dir.exists() or True # may use default path
|
|
|
|
|
|
# ===========================================================================
|
|
# 15. Security: Path traversal sanitization
|
|
# ===========================================================================
|
|
|
|
|
|
class TestPathTraversalSanitization:
|
|
"""Upload filenames with traversal sequences are sanitized."""
|
|
|
|
def test_path_name_strips_traversal(self):
|
|
"""Path.name strips directory traversal from filenames."""
|
|
assert Path("../../../etc/passwd").name == "passwd"
|
|
assert Path("normal_file.txt").name == "normal_file.txt"
|
|
assert Path("/absolute/path/file.txt").name == "file.txt"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_upload_produces_safe_filename(self):
|
|
import aiohttp
|
|
from gateway.platforms.web import WebAdapter
|
|
|
|
port = _get_free_port()
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": port, "host": "127.0.0.1", "token": "tok",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
try:
|
|
await adapter.connect()
|
|
async with aiohttp.ClientSession() as session:
|
|
data = aiohttp.FormData()
|
|
data.add_field("file", b"test content",
|
|
filename="safe_file.txt",
|
|
content_type="application/octet-stream")
|
|
async with session.post(
|
|
f"http://127.0.0.1:{port}/upload",
|
|
data=data,
|
|
headers={"Authorization": "Bearer tok"},
|
|
) as resp:
|
|
assert resp.status == 200
|
|
result = await resp.json()
|
|
assert result["filename"].startswith("upload_")
|
|
assert "safe_file.txt" in result["filename"]
|
|
# File must be inside media dir, not escaped
|
|
assert result["url"].startswith("/media/")
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
def test_sanitize_in_source_code(self):
|
|
"""Verify source code uses Path().name for filename sanitization."""
|
|
import inspect
|
|
from gateway.platforms.web import WebAdapter
|
|
source = inspect.getsource(WebAdapter._handle_upload)
|
|
assert "Path(" in source and ".name" in source
|
|
|
|
|
|
# ===========================================================================
|
|
# 16. Security: Media endpoint authentication
|
|
# ===========================================================================
|
|
|
|
|
|
class TestMediaEndpointAuth:
|
|
"""Media files require a valid token query parameter."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_media_without_token_returns_401(self):
|
|
import aiohttp
|
|
from gateway.platforms.web import WebAdapter
|
|
|
|
port = _get_free_port()
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": port, "host": "127.0.0.1", "token": "secret",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
try:
|
|
await adapter.connect()
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(
|
|
f"http://127.0.0.1:{port}/media/test.txt"
|
|
) as resp:
|
|
assert resp.status == 401
|
|
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_media_with_wrong_token_returns_401(self):
|
|
import aiohttp
|
|
from gateway.platforms.web import WebAdapter
|
|
|
|
port = _get_free_port()
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": port, "host": "127.0.0.1", "token": "secret",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
try:
|
|
await adapter.connect()
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(
|
|
f"http://127.0.0.1:{port}/media/test.txt?token=wrong"
|
|
) as resp:
|
|
assert resp.status == 401
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_media_with_valid_token_serves_file(self):
|
|
import aiohttp
|
|
from gateway.platforms.web import WebAdapter
|
|
|
|
port = _get_free_port()
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": port, "host": "127.0.0.1", "token": "secret",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
try:
|
|
await adapter.connect()
|
|
# Create a test file in the media directory
|
|
test_file = adapter._media_dir / "testfile.txt"
|
|
test_file.write_text("hello")
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(
|
|
f"http://127.0.0.1:{port}/media/testfile.txt?token=secret"
|
|
) as resp:
|
|
assert resp.status == 200
|
|
body = await resp.text()
|
|
assert body == "hello"
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_media_path_traversal_in_url_blocked(self):
|
|
import aiohttp
|
|
from gateway.platforms.web import WebAdapter
|
|
|
|
port = _get_free_port()
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": port, "host": "127.0.0.1", "token": "secret",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
try:
|
|
await adapter.connect()
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(
|
|
f"http://127.0.0.1:{port}/media/..%2F..%2Fetc%2Fpasswd?token=secret"
|
|
) as resp:
|
|
assert resp.status == 404
|
|
finally:
|
|
await adapter.disconnect()
|
|
|
|
|
|
# ===========================================================================
|
|
# 17. Security: hmac.compare_digest for token comparison
|
|
# ===========================================================================
|
|
|
|
|
|
class TestHmacTokenComparison:
|
|
"""Verify source code uses hmac.compare_digest, not == / !=."""
|
|
|
|
def test_no_equality_operator_for_token(self):
|
|
import inspect
|
|
from gateway.platforms.web import WebAdapter
|
|
source = inspect.getsource(WebAdapter)
|
|
# There should be no `== self._token` or `!= self._token` in the source
|
|
assert "== self._token" not in source, \
|
|
"Token comparison must use hmac.compare_digest, not =="
|
|
assert "!= self._token" not in source, \
|
|
"Token comparison must use hmac.compare_digest, not !="
|
|
|
|
def test_hmac_compare_digest_used(self):
|
|
import inspect
|
|
from gateway.platforms.web import WebAdapter
|
|
source = inspect.getsource(WebAdapter)
|
|
assert "hmac.compare_digest" in source
|
|
|
|
|
|
# ===========================================================================
|
|
# 18. Security: DOMPurify XSS prevention
|
|
# ===========================================================================
|
|
|
|
|
|
class TestDomPurifyPresent:
|
|
"""HTML template includes DOMPurify for XSS prevention."""
|
|
|
|
def test_dompurify_script_included(self):
|
|
from gateway.platforms.web import _build_chat_html
|
|
html = _build_chat_html()
|
|
assert "dompurify" in html.lower()
|
|
assert "DOMPurify.sanitize" in html
|
|
|
|
def test_marked_output_sanitized(self):
|
|
from gateway.platforms.web import _build_chat_html
|
|
html = _build_chat_html()
|
|
assert "DOMPurify.sanitize(marked.parse(" in html
|
|
|
|
|
|
# ===========================================================================
|
|
# 19. Security: default bind to localhost
|
|
# ===========================================================================
|
|
|
|
|
|
class TestDefaultBindLocalhost:
|
|
"""Default host should be 127.0.0.1, not 0.0.0.0."""
|
|
|
|
def test_adapter_default_host(self):
|
|
from gateway.platforms.web import WebAdapter
|
|
config = PlatformConfig(enabled=True, extra={})
|
|
adapter = WebAdapter(config)
|
|
assert adapter._host == "127.0.0.1"
|
|
|
|
@patch.dict(os.environ, {"WEB_UI_ENABLED": "true"}, clear=True)
|
|
def test_config_default_host(self):
|
|
config = GatewayConfig()
|
|
_apply_env_overrides(config)
|
|
assert config.platforms[Platform.WEB].extra["host"] == "127.0.0.1"
|
|
|
|
|
|
# ===========================================================================
|
|
# 20. Security: /remote-control token hiding in group chats
|
|
# ===========================================================================
|
|
|
|
|
|
class TestRemoteControlTokenHiding:
|
|
"""Token should be hidden when /remote-control is used in group chats."""
|
|
|
|
def _make_runner(self, tmp_path):
|
|
from gateway.run import GatewayRunner
|
|
runner = object.__new__(GatewayRunner)
|
|
runner.adapters = {}
|
|
runner._voice_mode = {}
|
|
runner._VOICE_MODE_PATH = tmp_path / "voice.json"
|
|
runner._session_db = None
|
|
runner.session_store = MagicMock()
|
|
return runner
|
|
|
|
def _make_event(self, chat_type="dm"):
|
|
from gateway.platforms.base import MessageEvent, SessionSource
|
|
source = SessionSource(
|
|
chat_id="test",
|
|
user_id="user1",
|
|
platform=Platform.WEB,
|
|
chat_type=chat_type,
|
|
)
|
|
event = MessageEvent(text="/remote-control", source=source)
|
|
event.message_id = "msg1"
|
|
return event
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_token_visible_in_dm(self, tmp_path):
|
|
from gateway.platforms.web import WebAdapter
|
|
runner = self._make_runner(tmp_path)
|
|
# Simulate a running web adapter
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 8765, "host": "127.0.0.1", "token": "mysecret",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
runner.adapters[Platform.WEB] = adapter
|
|
event = self._make_event(chat_type="dm")
|
|
result = await runner._handle_remote_control_command(event)
|
|
assert "mysecret" in result
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_token_hidden_in_group(self, tmp_path):
|
|
from gateway.platforms.web import WebAdapter
|
|
runner = self._make_runner(tmp_path)
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 8765, "host": "127.0.0.1", "token": "mysecret",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
runner.adapters[Platform.WEB] = adapter
|
|
event = self._make_event(chat_type="group")
|
|
result = await runner._handle_remote_control_command(event)
|
|
assert "mysecret" not in result
|
|
assert "hidden" in result.lower()
|
|
|
|
|
|
# ===========================================================================
|
|
# 21. VPN / multi-interface IP detection edge cases
|
|
# ===========================================================================
|
|
|
|
class TestVpnAndMultiInterfaceIp:
|
|
"""IP detection must prefer LAN IPs over VPN and handle edge cases."""
|
|
|
|
def test_lan_preferred_over_vpn(self):
|
|
"""192.168.x.x or 10.x.x.x should be chosen over 172.16.x.x VPN."""
|
|
from gateway.platforms.web import WebAdapter
|
|
with unittest.mock.patch.object(
|
|
WebAdapter, "_get_local_ips",
|
|
return_value=["172.16.0.2", "192.168.1.106"],
|
|
):
|
|
ip = WebAdapter._get_local_ip()
|
|
assert ip == "192.168.1.106"
|
|
|
|
def test_ten_network_preferred_over_vpn(self):
|
|
"""10.x.x.x corporate LAN should be preferred over 172.16.x.x VPN."""
|
|
from gateway.platforms.web import WebAdapter
|
|
with unittest.mock.patch.object(
|
|
WebAdapter, "_get_local_ips",
|
|
return_value=["172.16.5.1", "10.0.0.50"],
|
|
):
|
|
ip = WebAdapter._get_local_ip()
|
|
assert ip == "10.0.0.50"
|
|
|
|
def test_only_vpn_ip_still_returned(self):
|
|
"""If only VPN IP exists, return it rather than nothing."""
|
|
from gateway.platforms.web import WebAdapter
|
|
with unittest.mock.patch.object(
|
|
WebAdapter, "_get_local_ips",
|
|
return_value=["172.16.0.2"],
|
|
):
|
|
ip = WebAdapter._get_local_ip()
|
|
assert ip == "172.16.0.2"
|
|
|
|
def test_no_interfaces_returns_localhost(self):
|
|
"""If no IPs found at all, fall back to 127.0.0.1."""
|
|
from gateway.platforms.web import WebAdapter
|
|
with unittest.mock.patch.object(
|
|
WebAdapter, "_get_local_ips",
|
|
return_value=[],
|
|
):
|
|
ip = WebAdapter._get_local_ip()
|
|
assert ip == "127.0.0.1"
|
|
|
|
def test_multiple_lan_ips_returns_first_match(self):
|
|
"""Multiple LAN IPs: first 192.168/10.x match wins."""
|
|
from gateway.platforms.web import WebAdapter
|
|
with unittest.mock.patch.object(
|
|
WebAdapter, "_get_local_ips",
|
|
return_value=["172.16.0.2", "192.168.1.50", "10.0.0.1"],
|
|
):
|
|
ip = WebAdapter._get_local_ip()
|
|
assert ip == "192.168.1.50"
|
|
|
|
def test_get_local_ips_excludes_loopback(self):
|
|
"""_get_local_ips must not return 127.x.x.x addresses."""
|
|
from gateway.platforms.web import WebAdapter
|
|
import inspect
|
|
source = inspect.getsource(WebAdapter._get_local_ips)
|
|
# Must filter out 127.x addresses
|
|
assert "127." in source, \
|
|
"_get_local_ips must filter loopback addresses"
|
|
|
|
def test_get_local_ips_netifaces_fallback(self):
|
|
"""When netifaces is unavailable, ifconfig fallback must work."""
|
|
from gateway.platforms.web import WebAdapter
|
|
import inspect
|
|
source = inspect.getsource(WebAdapter._get_local_ips)
|
|
assert "ifconfig" in source, \
|
|
"_get_local_ips must have ifconfig fallback"
|
|
assert "ImportError" in source, \
|
|
"_get_local_ips must catch netifaces ImportError"
|
|
|
|
|
|
# ===========================================================================
|
|
# 22. Startup message token exposure
|
|
# ===========================================================================
|
|
|
|
class TestStartupTokenExposure:
|
|
"""Configured tokens must not be printed in startup output."""
|
|
|
|
def test_auto_generated_flag_when_no_token(self):
|
|
"""Token auto-generation flag must be set when no token provided."""
|
|
from gateway.platforms.web import WebAdapter
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 8765, "host": "127.0.0.1", "token": "",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
assert adapter._token_auto_generated is True
|
|
assert len(adapter._token) == 32 # secrets.token_hex(16) = 32 chars
|
|
|
|
def test_configured_flag_when_token_set(self):
|
|
"""Token auto-generation flag must be False when token is provided."""
|
|
from gateway.platforms.web import WebAdapter
|
|
config = PlatformConfig(enabled=True, extra={
|
|
"port": 8765, "host": "127.0.0.1", "token": "mytoken123",
|
|
})
|
|
adapter = WebAdapter(config)
|
|
assert adapter._token_auto_generated is False
|
|
assert adapter._token == "mytoken123"
|
|
|
|
def test_startup_log_hides_configured_token(self):
|
|
"""connect() must not print the token value when set via env."""
|
|
from gateway.platforms.web import WebAdapter
|
|
import inspect
|
|
source = inspect.getsource(WebAdapter.connect)
|
|
# Must check _token_auto_generated before printing
|
|
assert "_token_auto_generated" in source, \
|
|
"connect() must check _token_auto_generated before printing token"
|
|
|
|
def test_startup_log_shows_auto_token(self):
|
|
"""connect() must print the token when auto-generated."""
|
|
from gateway.platforms.web import WebAdapter
|
|
import inspect
|
|
source = inspect.getsource(WebAdapter.connect)
|
|
# Must have a branch that prints the actual token
|
|
assert "auto-generated" in source, \
|
|
"connect() must indicate when token is auto-generated"
|