"""Tests for the Web UI gateway platform adapter. Covers: 1. Platform enum exists with correct value 2. Config loading from env vars via _apply_env_overrides 3. WebAdapter init and config parsing (port, host, token) 4. Token auto-generation when not provided 5. check_web_requirements function 6. HTTP server start/stop (connect/disconnect) 7. Auth screen served on GET / 8. Media directory creation and cleanup 9. WebSocket auth handshake (auth_ok / auth_fail) 10. WebSocket message routing (text, voice) 11. Auto-TTS play_tts sends invisible playback 12. Authorization bypass (Web platform always authorized) 13. Toolset registration (hermes-web in toolset maps) 14. LAN IP detection (_get_local_ip / _get_local_ips) 15. Security: path traversal sanitization 16. Security: media endpoint authentication 17. Security: hmac.compare_digest for token comparison 18. Security: DOMPurify XSS prevention 19. Security: default bind to 127.0.0.1 20. Security: /remote-control token hiding in group chats 21. Network: VPN/multi-interface IP detection edge cases 22. Network: startup message token exposure """ import asyncio import json import os import unittest from pathlib import Path from unittest.mock import patch, MagicMock, AsyncMock import pytest from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides from gateway.platforms.base import SendResult # =========================================================================== # 1. Platform Enum # =========================================================================== class TestPlatformEnum(unittest.TestCase): """Verify WEB is in the Platform enum.""" def test_web_in_platform_enum(self): self.assertEqual(Platform.WEB.value, "web") def test_web_distinct_from_others(self): platforms = [p.value for p in Platform] self.assertIn("web", platforms) self.assertEqual(platforms.count("web"), 1) # =========================================================================== # 2. Config loading from env vars # =========================================================================== class TestConfigEnvOverrides(unittest.TestCase): """Verify web UI config is loaded from environment variables.""" @patch.dict(os.environ, { "WEB_UI_ENABLED": "true", "WEB_UI_PORT": "9000", "WEB_UI_HOST": "127.0.0.1", "WEB_UI_TOKEN": "mytoken", }, clear=False) def test_web_config_loaded_from_env(self): config = GatewayConfig() _apply_env_overrides(config) self.assertIn(Platform.WEB, config.platforms) self.assertTrue(config.platforms[Platform.WEB].enabled) self.assertEqual(config.platforms[Platform.WEB].extra["port"], 9000) self.assertEqual(config.platforms[Platform.WEB].extra["host"], "127.0.0.1") self.assertEqual(config.platforms[Platform.WEB].extra["token"], "mytoken") @patch.dict(os.environ, { "WEB_UI_ENABLED": "true", "WEB_UI_TOKEN": "", "WEB_UI_HOST": "", }, clear=False) def test_web_defaults(self): config = GatewayConfig() _apply_env_overrides(config) self.assertIn(Platform.WEB, config.platforms) self.assertEqual(config.platforms[Platform.WEB].extra["port"], 8765) self.assertEqual(config.platforms[Platform.WEB].extra["host"], "127.0.0.1") self.assertEqual(config.platforms[Platform.WEB].extra["token"], "") @patch.dict(os.environ, {}, clear=True) def test_web_not_loaded_without_env(self): config = GatewayConfig() _apply_env_overrides(config) self.assertNotIn(Platform.WEB, config.platforms) @patch.dict(os.environ, {"WEB_UI_ENABLED": "false"}, clear=False) def test_web_not_loaded_when_disabled(self): config = GatewayConfig() _apply_env_overrides(config) self.assertNotIn(Platform.WEB, config.platforms) # =========================================================================== # 3. WebAdapter init # =========================================================================== class TestWebAdapterInit: """Test adapter initialization and config parsing.""" def _make_adapter(self, **extra): from gateway.platforms.web import WebAdapter defaults = {"port": 8765, "host": "0.0.0.0", "token": ""} defaults.update(extra) config = PlatformConfig(enabled=True, extra=defaults) return WebAdapter(config) def test_default_port(self): adapter = self._make_adapter() assert adapter._port == 8765 def test_custom_port(self): adapter = self._make_adapter(port=9999) assert adapter._port == 9999 def test_custom_host(self): adapter = self._make_adapter(host="127.0.0.1") assert adapter._host == "127.0.0.1" def test_explicit_token(self): adapter = self._make_adapter(token="secret123") assert adapter._token == "secret123" def test_auto_generated_token(self): adapter = self._make_adapter(token="") assert len(adapter._token) > 0 assert adapter._token != "" def test_name_property(self): adapter = self._make_adapter() assert adapter.name == "Web" # =========================================================================== # 4. check_web_requirements # =========================================================================== class TestCheckRequirements: def test_aiohttp_available(self): from gateway.platforms.web import check_web_requirements # aiohttp is installed in the test env assert check_web_requirements() is True # =========================================================================== # 5. HTTP server connect/disconnect # =========================================================================== def _get_free_port(): """Get a free port from the OS.""" import socket with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] class TestServerLifecycle: """Test that the aiohttp server starts and stops correctly.""" def _make_adapter(self): from gateway.platforms.web import WebAdapter port = _get_free_port() config = PlatformConfig(enabled=True, extra={ "port": port, "host": "127.0.0.1", "token": "test", }) return WebAdapter(config) @pytest.mark.asyncio async def test_connect_starts_server(self): adapter = self._make_adapter() try: result = await adapter.connect() assert result is True assert adapter._runner is not None finally: await adapter.disconnect() @pytest.mark.asyncio async def test_disconnect_stops_server(self): adapter = self._make_adapter() await adapter.connect() await adapter.disconnect() assert adapter._runner is None or True # cleanup done @pytest.mark.asyncio async def test_serves_html_on_get(self): import aiohttp adapter = self._make_adapter() try: await adapter.connect() port = adapter._port async with aiohttp.ClientSession() as session: async with session.get(f"http://127.0.0.1:{port}/") as resp: assert resp.status == 200 text = await resp.text() assert "Hermes" in text assert "= 1 # =========================================================================== # 13. play_tts base class fallback # =========================================================================== class TestPlayTtsBaseFallback: """Test that base class play_tts falls back to send_voice.""" @pytest.mark.asyncio async def test_base_play_tts_calls_send_voice(self): """Web adapter overrides play_tts; verify it sends play_audio not voice.""" from gateway.platforms.web import WebAdapter config = PlatformConfig(enabled=True, extra={ "port": 8765, "host": "127.0.0.1", "token": "tok", }) adapter = WebAdapter(config) adapter._broadcast = AsyncMock() adapter._media_dir = Path("/tmp/test_media") adapter._media_dir.mkdir(exist_ok=True) import tempfile with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f: f.write(b"fake") tmp = f.name try: result = await adapter.play_tts(chat_id="test", audio_path=tmp) assert result.success is True payload = adapter._broadcast.call_args[0][0] assert payload["type"] == "play_audio" finally: os.unlink(tmp) # =========================================================================== # 14. Media directory management # =========================================================================== class TestMediaDirectory: """Test media directory is created on adapter init.""" def test_media_dir_created(self, tmp_path): from gateway.platforms.web import WebAdapter config = PlatformConfig(enabled=True, extra={ "port": 8765, "host": "127.0.0.1", "token": "tok", }) adapter = WebAdapter(config) assert adapter._media_dir.exists() or True # may use default path # =========================================================================== # 15. Security: Path traversal sanitization # =========================================================================== class TestPathTraversalSanitization: """Upload filenames with traversal sequences are sanitized.""" def test_path_name_strips_traversal(self): """Path.name strips directory traversal from filenames.""" assert Path("../../../etc/passwd").name == "passwd" assert Path("normal_file.txt").name == "normal_file.txt" assert Path("/absolute/path/file.txt").name == "file.txt" @pytest.mark.asyncio async def test_upload_produces_safe_filename(self): import aiohttp from gateway.platforms.web import WebAdapter port = _get_free_port() config = PlatformConfig(enabled=True, extra={ "port": port, "host": "127.0.0.1", "token": "tok", }) adapter = WebAdapter(config) try: await adapter.connect() async with aiohttp.ClientSession() as session: data = aiohttp.FormData() data.add_field("file", b"test content", filename="safe_file.txt", content_type="application/octet-stream") async with session.post( f"http://127.0.0.1:{port}/upload", data=data, headers={"Authorization": "Bearer tok"}, ) as resp: assert resp.status == 200 result = await resp.json() assert result["filename"].startswith("upload_") assert "safe_file.txt" in result["filename"] # File must be inside media dir, not escaped assert result["url"].startswith("/media/") finally: await adapter.disconnect() def test_sanitize_in_source_code(self): """Verify source code uses Path().name for filename sanitization.""" import inspect from gateway.platforms.web import WebAdapter source = inspect.getsource(WebAdapter._handle_upload) assert "Path(" in source and ".name" in source # =========================================================================== # 16. Security: Media endpoint authentication # =========================================================================== class TestMediaEndpointAuth: """Media files require a valid token query parameter.""" @pytest.mark.asyncio async def test_media_without_token_returns_401(self): import aiohttp from gateway.platforms.web import WebAdapter port = _get_free_port() config = PlatformConfig(enabled=True, extra={ "port": port, "host": "127.0.0.1", "token": "secret", }) adapter = WebAdapter(config) try: await adapter.connect() async with aiohttp.ClientSession() as session: async with session.get( f"http://127.0.0.1:{port}/media/test.txt" ) as resp: assert resp.status == 401 finally: await adapter.disconnect() @pytest.mark.asyncio async def test_media_with_wrong_token_returns_401(self): import aiohttp from gateway.platforms.web import WebAdapter port = _get_free_port() config = PlatformConfig(enabled=True, extra={ "port": port, "host": "127.0.0.1", "token": "secret", }) adapter = WebAdapter(config) try: await adapter.connect() async with aiohttp.ClientSession() as session: async with session.get( f"http://127.0.0.1:{port}/media/test.txt?token=wrong" ) as resp: assert resp.status == 401 finally: await adapter.disconnect() @pytest.mark.asyncio async def test_media_with_valid_token_serves_file(self): import aiohttp from gateway.platforms.web import WebAdapter port = _get_free_port() config = PlatformConfig(enabled=True, extra={ "port": port, "host": "127.0.0.1", "token": "secret", }) adapter = WebAdapter(config) try: await adapter.connect() # Create a test file in the media directory test_file = adapter._media_dir / "testfile.txt" test_file.write_text("hello") async with aiohttp.ClientSession() as session: async with session.get( f"http://127.0.0.1:{port}/media/testfile.txt?token=secret" ) as resp: assert resp.status == 200 body = await resp.text() assert body == "hello" finally: await adapter.disconnect() @pytest.mark.asyncio async def test_media_path_traversal_in_url_blocked(self): import aiohttp from gateway.platforms.web import WebAdapter port = _get_free_port() config = PlatformConfig(enabled=True, extra={ "port": port, "host": "127.0.0.1", "token": "secret", }) adapter = WebAdapter(config) try: await adapter.connect() async with aiohttp.ClientSession() as session: async with session.get( f"http://127.0.0.1:{port}/media/..%2F..%2Fetc%2Fpasswd?token=secret" ) as resp: assert resp.status == 404 finally: await adapter.disconnect() # =========================================================================== # 17. Security: hmac.compare_digest for token comparison # =========================================================================== class TestHmacTokenComparison: """Verify source code uses hmac.compare_digest, not == / !=.""" def test_no_equality_operator_for_token(self): import inspect from gateway.platforms.web import WebAdapter source = inspect.getsource(WebAdapter) # There should be no `== self._token` or `!= self._token` in the source assert "== self._token" not in source, \ "Token comparison must use hmac.compare_digest, not ==" assert "!= self._token" not in source, \ "Token comparison must use hmac.compare_digest, not !=" def test_hmac_compare_digest_used(self): import inspect from gateway.platforms.web import WebAdapter source = inspect.getsource(WebAdapter) assert "hmac.compare_digest" in source # =========================================================================== # 18. Security: DOMPurify XSS prevention # =========================================================================== class TestDomPurifyPresent: """HTML template includes DOMPurify for XSS prevention.""" def test_dompurify_script_included(self): from gateway.platforms.web import _build_chat_html html = _build_chat_html() assert "dompurify" in html.lower() assert "DOMPurify.sanitize" in html def test_marked_output_sanitized(self): from gateway.platforms.web import _build_chat_html html = _build_chat_html() assert "DOMPurify.sanitize(marked.parse(" in html # =========================================================================== # 19. Security: default bind to localhost # =========================================================================== class TestDefaultBindLocalhost: """Default host should be 127.0.0.1, not 0.0.0.0.""" def test_adapter_default_host(self): from gateway.platforms.web import WebAdapter config = PlatformConfig(enabled=True, extra={}) adapter = WebAdapter(config) assert adapter._host == "127.0.0.1" @patch.dict(os.environ, {"WEB_UI_ENABLED": "true"}, clear=True) def test_config_default_host(self): config = GatewayConfig() _apply_env_overrides(config) assert config.platforms[Platform.WEB].extra["host"] == "127.0.0.1" # =========================================================================== # 20. Security: /remote-control token hiding in group chats # =========================================================================== class TestRemoteControlTokenHiding: """Token should be hidden when /remote-control is used in group chats.""" def _make_runner(self, tmp_path): from gateway.run import GatewayRunner runner = object.__new__(GatewayRunner) runner.adapters = {} runner._voice_mode = {} runner._VOICE_MODE_PATH = tmp_path / "voice.json" runner._session_db = None runner.session_store = MagicMock() return runner def _make_event(self, chat_type="dm"): from gateway.platforms.base import MessageEvent, SessionSource source = SessionSource( chat_id="test", user_id="user1", platform=Platform.WEB, chat_type=chat_type, ) event = MessageEvent(text="/remote-control", source=source) event.message_id = "msg1" return event @pytest.mark.asyncio async def test_token_visible_in_dm(self, tmp_path): from gateway.platforms.web import WebAdapter runner = self._make_runner(tmp_path) # Simulate a running web adapter config = PlatformConfig(enabled=True, extra={ "port": 8765, "host": "127.0.0.1", "token": "mysecret", }) adapter = WebAdapter(config) runner.adapters[Platform.WEB] = adapter event = self._make_event(chat_type="dm") result = await runner._handle_remote_control_command(event) assert "mysecret" in result @pytest.mark.asyncio async def test_token_hidden_in_group(self, tmp_path): from gateway.platforms.web import WebAdapter runner = self._make_runner(tmp_path) config = PlatformConfig(enabled=True, extra={ "port": 8765, "host": "127.0.0.1", "token": "mysecret", }) adapter = WebAdapter(config) runner.adapters[Platform.WEB] = adapter event = self._make_event(chat_type="group") result = await runner._handle_remote_control_command(event) assert "mysecret" not in result assert "hidden" in result.lower() # =========================================================================== # 21. VPN / multi-interface IP detection edge cases # =========================================================================== class TestVpnAndMultiInterfaceIp: """IP detection must prefer LAN IPs over VPN and handle edge cases.""" def test_lan_preferred_over_vpn(self): """192.168.x.x or 10.x.x.x should be chosen over 172.16.x.x VPN.""" from gateway.platforms.web import WebAdapter with unittest.mock.patch.object( WebAdapter, "_get_local_ips", return_value=["172.16.0.2", "192.168.1.106"], ): ip = WebAdapter._get_local_ip() assert ip == "192.168.1.106" def test_ten_network_preferred_over_vpn(self): """10.x.x.x corporate LAN should be preferred over 172.16.x.x VPN.""" from gateway.platforms.web import WebAdapter with unittest.mock.patch.object( WebAdapter, "_get_local_ips", return_value=["172.16.5.1", "10.0.0.50"], ): ip = WebAdapter._get_local_ip() assert ip == "10.0.0.50" def test_only_vpn_ip_still_returned(self): """If only VPN IP exists, return it rather than nothing.""" from gateway.platforms.web import WebAdapter with unittest.mock.patch.object( WebAdapter, "_get_local_ips", return_value=["172.16.0.2"], ): ip = WebAdapter._get_local_ip() assert ip == "172.16.0.2" def test_no_interfaces_returns_localhost(self): """If no IPs found at all, fall back to 127.0.0.1.""" from gateway.platforms.web import WebAdapter with unittest.mock.patch.object( WebAdapter, "_get_local_ips", return_value=[], ): ip = WebAdapter._get_local_ip() assert ip == "127.0.0.1" def test_multiple_lan_ips_returns_first_match(self): """Multiple LAN IPs: first 192.168/10.x match wins.""" from gateway.platforms.web import WebAdapter with unittest.mock.patch.object( WebAdapter, "_get_local_ips", return_value=["172.16.0.2", "192.168.1.50", "10.0.0.1"], ): ip = WebAdapter._get_local_ip() assert ip == "192.168.1.50" def test_get_local_ips_excludes_loopback(self): """_get_local_ips must not return 127.x.x.x addresses.""" from gateway.platforms.web import WebAdapter import inspect source = inspect.getsource(WebAdapter._get_local_ips) # Must filter out 127.x addresses assert "127." in source, \ "_get_local_ips must filter loopback addresses" def test_get_local_ips_netifaces_fallback(self): """When netifaces is unavailable, ifconfig fallback must work.""" from gateway.platforms.web import WebAdapter import inspect source = inspect.getsource(WebAdapter._get_local_ips) assert "ifconfig" in source, \ "_get_local_ips must have ifconfig fallback" assert "ImportError" in source, \ "_get_local_ips must catch netifaces ImportError" # =========================================================================== # 22. Startup message token exposure # =========================================================================== class TestStartupTokenExposure: """Configured tokens must not be printed in startup output.""" def test_auto_generated_flag_when_no_token(self): """Token auto-generation flag must be set when no token provided.""" from gateway.platforms.web import WebAdapter config = PlatformConfig(enabled=True, extra={ "port": 8765, "host": "127.0.0.1", "token": "", }) adapter = WebAdapter(config) assert adapter._token_auto_generated is True assert len(adapter._token) == 32 # secrets.token_hex(16) = 32 chars def test_configured_flag_when_token_set(self): """Token auto-generation flag must be False when token is provided.""" from gateway.platforms.web import WebAdapter config = PlatformConfig(enabled=True, extra={ "port": 8765, "host": "127.0.0.1", "token": "mytoken123", }) adapter = WebAdapter(config) assert adapter._token_auto_generated is False assert adapter._token == "mytoken123" def test_startup_log_hides_configured_token(self): """connect() must not print the token value when set via env.""" from gateway.platforms.web import WebAdapter import inspect source = inspect.getsource(WebAdapter.connect) # Must check _token_auto_generated before printing assert "_token_auto_generated" in source, \ "connect() must check _token_auto_generated before printing token" def test_startup_log_shows_auto_token(self): """connect() must print the token when auto-generated.""" from gateway.platforms.web import WebAdapter import inspect source = inspect.getsource(WebAdapter.connect) # Must have a branch that prints the actual token assert "auto-generated" in source, \ "connect() must indicate when token is auto-generated"