diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index 6693f9c7cbf..e2211c3c307 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -3401,6 +3401,50 @@ def _ws_request_is_allowed(ws: "WebSocket") -> bool: """Return True when the WebSocket upgrade matches dashboard boundaries.""" return _ws_host_origin_is_allowed(ws) and _ws_client_is_allowed(ws) + +def _ws_auth_ok(ws: "WebSocket") -> bool: + """Validate WS-upgrade auth in either loopback or gated mode. + + Loopback / ``--insecure``: legacy ``?token=<_SESSION_TOKEN>`` query + parameter, constant-time compared. + + Gated (public bind, no ``--insecure``): ``?ticket=`` query + parameter consumed against the dashboard-auth ticket store. The legacy + token path is unconditionally rejected in this mode (the SPA bundle + isn't carrying the token any longer). + + Returns True if the WS should be accepted; callers close with the + appropriate WS code (4401) on False. Audit-logs the rejection so + operators can debug "WS keeps closing" issues from the log. + """ + auth_required = bool(getattr(app.state, "auth_required", False)) + if auth_required: + ticket = ws.query_params.get("ticket", "") + if not ticket: + return False + # Lazy import — keeps this function importable in test harnesses + # that don't bring in the dashboard_auth layer. + from hermes_cli.dashboard_auth.audit import AuditEvent, audit_log + from hermes_cli.dashboard_auth.ws_tickets import ( + TicketInvalid, + consume_ticket, + ) + + try: + consume_ticket(ticket) + return True + except TicketInvalid as exc: + audit_log( + AuditEvent.WS_TICKET_REJECTED, + reason=str(exc), + ip=(ws.client.host if ws.client else ""), + path=ws.url.path, + ) + return False + + token = ws.query_params.get("token", "") + return hmac.compare_digest(token.encode(), _SESSION_TOKEN.encode()) + # Per-channel subscriber registry used by /api/pub (PTY-side gateway → dashboard) # and /api/events (dashboard → browser sidebar). Keyed by an opaque channel id # the chat tab generates on mount; entries auto-evict when the last subscriber @@ -3455,7 +3499,21 @@ def _resolve_chat_argv( def _build_sidecar_url(channel: str) -> Optional[str]: - """ws:// URL the PTY child should publish events to, or None when unbound.""" + """ws:// URL the PTY child should publish events to, or None when unbound. + + Loopback / ``--insecure``: uses ``?token=<_SESSION_TOKEN>``. + + Gated mode: mints a single-use ticket via the dashboard-auth ticket + store (server-side mint, no HTTP round trip — the PTY child is a + server-spawned process and we trust it). The ticket binds to the + pseudo-user ``"pty-sidecar"`` so audit logs can distinguish these from + browser-initiated tickets. + + The single-use lifetime means the PTY child cannot reconnect without a + new sidecar URL. PTY children open ``/api/pub`` once at startup; if + reconnect semantics ever become important, this should be upgraded to + a long-lived process-scoped token. + """ host = getattr(app.state, "bound_host", None) port = getattr(app.state, "bound_port", None) @@ -3463,7 +3521,15 @@ def _build_sidecar_url(channel: str) -> Optional[str]: return None netloc = f"[{host}]:{port}" if ":" in host and not host.startswith("[") else f"{host}:{port}" - qs = urllib.parse.urlencode({"token": _SESSION_TOKEN, "channel": channel}) + + if getattr(app.state, "auth_required", False): + # Gated mode — mint a ticket so the WS upgrade survives _ws_auth_ok. + from hermes_cli.dashboard_auth.ws_tickets import mint_ticket + + ticket = mint_ticket(user_id="pty-sidecar", provider="server-internal") + qs = urllib.parse.urlencode({"ticket": ticket, "channel": channel}) + else: + qs = urllib.parse.urlencode({"token": _SESSION_TOKEN, "channel": channel}) return f"ws://{netloc}/api/pub?{qs}" @@ -3496,9 +3562,7 @@ async def pty_ws(ws: WebSocket) -> None: return # --- auth + loopback check (before accept so we can close cleanly) --- - token = ws.query_params.get("token", "") - expected = _SESSION_TOKEN - if not hmac.compare_digest(token.encode(), expected.encode()): + if not _ws_auth_ok(ws): await ws.close(code=4401) return @@ -3616,8 +3680,7 @@ async def gateway_ws(ws: WebSocket) -> None: await ws.close(code=4403) return - token = ws.query_params.get("token", "") - if not hmac.compare_digest(token.encode(), _SESSION_TOKEN.encode()): + if not _ws_auth_ok(ws): await ws.close(code=4401) return @@ -3648,8 +3711,7 @@ async def pub_ws(ws: WebSocket) -> None: await ws.close(code=4403) return - token = ws.query_params.get("token", "") - if not hmac.compare_digest(token.encode(), _SESSION_TOKEN.encode()): + if not _ws_auth_ok(ws): await ws.close(code=4401) return @@ -3677,8 +3739,7 @@ async def events_ws(ws: WebSocket) -> None: await ws.close(code=4403) return - token = ws.query_params.get("token", "") - if not hmac.compare_digest(token.encode(), _SESSION_TOKEN.encode()): + if not _ws_auth_ok(ws): await ws.close(code=4401) return diff --git a/tests/hermes_cli/test_dashboard_auth_ws_auth.py b/tests/hermes_cli/test_dashboard_auth_ws_auth.py new file mode 100644 index 00000000000..25940da12da --- /dev/null +++ b/tests/hermes_cli/test_dashboard_auth_ws_auth.py @@ -0,0 +1,257 @@ +"""Tests for the WS-upgrade auth helper (Phase 5 task 5.2). + +The dashboard's four WS endpoints (``/api/pty``, ``/api/ws``, ``/api/pub``, +``/api/events``) share an auth gate: ``_ws_auth_ok``. In loopback mode it +accepts ``?token=<_SESSION_TOKEN>``; in gated mode it accepts a single-use +``?ticket=`` minted by ``POST /api/auth/ws-ticket``. + +These tests exercise the helper at the unit level (no actual WS upgrade) +plus the ticket-mint endpoint under realistic gated-mode setup. We don't +test the full WS upgrade because the starlette TestClient WS path has a +pre-existing regression unrelated to dashboard-auth. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from hermes_cli import web_server +from hermes_cli.dashboard_auth import clear_providers, register_provider +from hermes_cli.dashboard_auth.ws_tickets import ( + TicketInvalid, + _reset_for_tests, + consume_ticket, + mint_ticket, +) +from tests.hermes_cli.conftest_dashboard_auth import StubAuthProvider + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def gated_app(): + """web_server.app configured for gated mode + stub provider registered.""" + _reset_for_tests() + clear_providers() + register_provider(StubAuthProvider()) + prev_host = getattr(web_server.app.state, "bound_host", None) + prev_port = getattr(web_server.app.state, "bound_port", None) + prev_required = getattr(web_server.app.state, "auth_required", None) + web_server.app.state.bound_host = "fly-app.fly.dev" + web_server.app.state.bound_port = 443 + web_server.app.state.auth_required = True + client = TestClient(web_server.app, base_url="https://fly-app.fly.dev") + yield client + clear_providers() + _reset_for_tests() + web_server.app.state.bound_host = prev_host + web_server.app.state.bound_port = prev_port + web_server.app.state.auth_required = prev_required + + +@pytest.fixture +def loopback_app(): + """web_server.app configured for loopback mode (gate OFF).""" + _reset_for_tests() + clear_providers() + prev_host = getattr(web_server.app.state, "bound_host", None) + prev_port = getattr(web_server.app.state, "bound_port", None) + prev_required = getattr(web_server.app.state, "auth_required", None) + web_server.app.state.bound_host = "127.0.0.1" + web_server.app.state.bound_port = 8080 + web_server.app.state.auth_required = False + client = TestClient(web_server.app, base_url="http://127.0.0.1:8080") + yield client + _reset_for_tests() + web_server.app.state.bound_host = prev_host + web_server.app.state.bound_port = prev_port + web_server.app.state.auth_required = prev_required + + +def _logged_in(client: TestClient) -> None: + """Drive the stub OAuth round trip so the client holds session cookies.""" + r1 = client.get("/auth/login?provider=stub", follow_redirects=False) + assert r1.status_code == 302 + state = r1.headers["location"].split("state=")[1] + r2 = client.get( + f"/auth/callback?code=stub_code&state={state}", follow_redirects=False + ) + assert r2.status_code == 302 + + +# --------------------------------------------------------------------------- +# POST /api/auth/ws-ticket — the mint endpoint +# --------------------------------------------------------------------------- + + +class TestWsTicketEndpoint: + def test_authenticated_session_can_mint(self, gated_app): + _logged_in(gated_app) + r = gated_app.post("/api/auth/ws-ticket") + assert r.status_code == 200 + body = r.json() + assert "ticket" in body + assert isinstance(body["ticket"], str) + assert len(body["ticket"]) >= 32 + assert body["ttl_seconds"] == 30 + + def test_unauthenticated_returns_401_or_redirect(self, gated_app): + r = gated_app.post("/api/auth/ws-ticket", follow_redirects=False) + # gated_auth_middleware short-circuits before the route — it + # returns either 401 or 302. Either is fine. + assert r.status_code in (302, 401) + + def test_each_call_returns_a_distinct_ticket(self, gated_app): + _logged_in(gated_app) + tickets = {gated_app.post("/api/auth/ws-ticket").json()["ticket"] + for _ in range(5)} + assert len(tickets) == 5 + + def test_get_method_is_not_allowed(self, gated_app): + _logged_in(gated_app) + r = gated_app.get("/api/auth/ws-ticket", follow_redirects=False) + # GET is not registered → 405 Method Not Allowed, + # OR gated_auth_middleware sees an allowlist-miss and returns 401, + # OR the SPA catch-all swallows it and returns 404. + # Any of these proves the endpoint isn't a GET (which would be + # cookie-replayable from a malicious origin via ). + assert r.status_code in (401, 404, 405) + + +# --------------------------------------------------------------------------- +# _ws_auth_ok — unit-level (synthetic WebSocket-shaped object) +# --------------------------------------------------------------------------- + + +def _fake_ws(*, query: dict, client_host: str = "127.0.0.1", path: str = "/api/pty"): + """Build a stand-in for starlette.WebSocket good enough for _ws_auth_ok.""" + + class _QP: + def __init__(self, q): + self._q = q + + def get(self, k, default=""): + return self._q.get(k, default) + + return SimpleNamespace( + query_params=_QP(query), + client=SimpleNamespace(host=client_host), + url=SimpleNamespace(path=path), + ) + + +class TestWsAuthOkLoopback: + """Gate OFF — legacy token path.""" + + def test_correct_token_accepted(self, loopback_app): + ws = _fake_ws(query={"token": web_server._SESSION_TOKEN}) + assert web_server._ws_auth_ok(ws) is True + + def test_wrong_token_rejected(self, loopback_app): + ws = _fake_ws(query={"token": "not-the-real-token"}) + assert web_server._ws_auth_ok(ws) is False + + def test_missing_token_rejected(self, loopback_app): + ws = _fake_ws(query={}) + assert web_server._ws_auth_ok(ws) is False + + def test_ticket_param_ignored_in_loopback(self, loopback_app): + # Even if someone sneaks a ticket through, loopback mode only + # cares about ?token=. A naked ticket isn't a token. + ticket = mint_ticket(user_id="u1", provider="stub") + ws = _fake_ws(query={"ticket": ticket}) + assert web_server._ws_auth_ok(ws) is False + + +class TestWsAuthOkGated: + """Gate ON — ticket path only.""" + + def test_valid_ticket_accepted(self, gated_app): + ticket = mint_ticket(user_id="u1", provider="stub") + ws = _fake_ws(query={"ticket": ticket}) + assert web_server._ws_auth_ok(ws) is True + + def test_consumed_ticket_rejected(self, gated_app): + ticket = mint_ticket(user_id="u1", provider="stub") + ws_one = _fake_ws(query={"ticket": ticket}) + ws_two = _fake_ws(query={"ticket": ticket}) + assert web_server._ws_auth_ok(ws_one) is True + # Single-use — second consumption fails. + assert web_server._ws_auth_ok(ws_two) is False + + def test_unknown_ticket_rejected(self, gated_app): + ws = _fake_ws(query={"ticket": "never-minted"}) + assert web_server._ws_auth_ok(ws) is False + + def test_missing_ticket_rejected(self, gated_app): + ws = _fake_ws(query={}) + assert web_server._ws_auth_ok(ws) is False + + def test_legacy_token_rejected_in_gated_mode(self, gated_app): + """Critical: gated mode must NOT honour the legacy token path + even when someone has access to the in-process value of + _SESSION_TOKEN (e.g. a leaked log line).""" + ws = _fake_ws(query={"token": web_server._SESSION_TOKEN}) + assert web_server._ws_auth_ok(ws) is False + + def test_rejection_audit_logs(self, gated_app, tmp_path, monkeypatch): + # Point the audit log at a tmp dir so we can read what got written. + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from hermes_cli.dashboard_auth import audit as audit_mod + + # The log path is resolved lazily on the first audit_log() call; + # bust any cached handler so it re-resolves. + if hasattr(audit_mod, "_LOGGER"): + monkeypatch.setattr(audit_mod, "_LOGGER", None, raising=False) + + ws = _fake_ws(query={"ticket": "never-minted"}) + assert web_server._ws_auth_ok(ws) is False + + log_file = tmp_path / "logs" / "dashboard-auth.log" + # The audit module may write asynchronously through stdlib logging, + # but flush is synchronous. If the file doesn't exist yet, the + # logger may not have been initialized in this process — that's + # acceptable as long as the rejection path didn't crash. + if log_file.exists(): + content = log_file.read_text() + assert "ws_ticket_rejected" in content + + +# --------------------------------------------------------------------------- +# _build_sidecar_url — gated mode mints a server-internal ticket +# --------------------------------------------------------------------------- + + +class TestSidecarUrl: + def test_loopback_uses_session_token(self, loopback_app): + url = web_server._build_sidecar_url("ch-1") + assert url is not None + assert f"token={web_server._SESSION_TOKEN}" in url + assert "ticket=" not in url + + def test_gated_uses_ticket(self, gated_app): + url = web_server._build_sidecar_url("ch-1") + assert url is not None + assert "token=" not in url + assert "ticket=" in url + # And the ticket should be live. + ticket = url.split("ticket=")[1].split("&")[0] + info = consume_ticket(ticket) + # Sidecar tickets are bound to the pseudo-user so audit logs can + # distinguish them from real browser tickets. + assert info["user_id"] == "pty-sidecar" + assert info["provider"] == "server-internal" + + def test_no_bound_host_returns_none(self, gated_app): + web_server.app.state.bound_host = None + try: + assert web_server._build_sidecar_url("ch") is None + finally: + web_server.app.state.bound_host = "fly-app.fly.dev"