diff --git a/hermes_cli/dashboard_auth/audit.py b/hermes_cli/dashboard_auth/audit.py index d20cdb7def8..9e52ca75ebe 100644 --- a/hermes_cli/dashboard_auth/audit.py +++ b/hermes_cli/dashboard_auth/audit.py @@ -46,6 +46,7 @@ class AuditEvent(enum.Enum): REVOKE = "revoke" SESSION_VERIFY_FAILURE = "session_verify_failure" WS_TICKET_MINTED = "ws_ticket_minted" + WS_TICKET_REJECTED = "ws_ticket_rejected" def _resolve_log_path() -> Path: diff --git a/hermes_cli/dashboard_auth/routes.py b/hermes_cli/dashboard_auth/routes.py index b14ff60905f..b4acf581af0 100644 --- a/hermes_cli/dashboard_auth/routes.py +++ b/hermes_cli/dashboard_auth/routes.py @@ -302,3 +302,39 @@ async def api_auth_me(request: Request): "provider": sess.provider, "expires_at": sess.expires_at, } + + +# --------------------------------------------------------------------------- +# Auth-required: WS upgrade ticket (Phase 5) +# --------------------------------------------------------------------------- + + +@router.post("/api/auth/ws-ticket", name="auth_ws_ticket") +async def api_auth_ws_ticket(request: Request): + """Mint a short-lived single-use ticket for the authenticated session. + + Browsers cannot set ``Authorization`` on a WebSocket upgrade, so in + gated mode the SPA POSTs this endpoint to get a ``?ticket=`` value to + append to ``/api/pty``, ``/api/ws``, ``/api/pub``, or ``/api/events``. + + The ticket has a 30-second TTL and is single-use. Calling this endpoint + multiple times in quick succession (e.g. one ticket per WS) is the + expected pattern. + """ + sess = getattr(request.state, "session", None) + if sess is None: + # Middleware should already have rejected, but check defensively. + raise HTTPException(status_code=401, detail="Unauthorized") + + # Import here so the routes module stays usable in test contexts that + # don't load the ticket store. + from hermes_cli.dashboard_auth.ws_tickets import TTL_SECONDS, mint_ticket + + ticket = mint_ticket(user_id=sess.user_id, provider=sess.provider) + audit_log( + AuditEvent.WS_TICKET_MINTED, + provider=sess.provider, + user_id=sess.user_id, + ip=_client_ip(request), + ) + return {"ticket": ticket, "ttl_seconds": TTL_SECONDS} diff --git a/hermes_cli/dashboard_auth/ws_tickets.py b/hermes_cli/dashboard_auth/ws_tickets.py new file mode 100644 index 00000000000..6ebad217e46 --- /dev/null +++ b/hermes_cli/dashboard_auth/ws_tickets.py @@ -0,0 +1,87 @@ +"""Short-lived single-use tickets for WS-upgrade auth in gated mode. + +Browsers cannot set ``Authorization`` on a WebSocket upgrade. In loopback +mode the legacy ``?token=<_SESSION_TOKEN>`` query param works because the +token is injected into the SPA bundle. In gated mode there is no injected +token — the SPA gets a fresh ticket via the authenticated REST endpoint +``POST /api/auth/ws-ticket`` and passes that as ``?ticket=`` on the +WS upgrade. + +Tickets are single-use, TTL = 30 seconds. In-memory; the dashboard is a +single process so no distributed coordination is needed. The module +exposes a small functional API rather than a class so tests can patch +``time.time`` cleanly. +""" + +from __future__ import annotations + +import secrets +import threading +import time +from typing import Any, Dict, Tuple + +#: Time-to-live for newly-minted tickets in seconds. 30 s is long enough +#: that the SPA can call ``getWsTicket()`` and immediately open the WS, +#: short enough that a leaked ticket is uninteresting. +TTL_SECONDS = 30 + +_lock = threading.Lock() +_tickets: Dict[str, Tuple[int, Dict[str, Any]]] = {} # ticket -> (expires_at, info) + + +class TicketInvalid(Exception): + """Ticket missing, expired, or already consumed.""" + + +def mint_ticket(*, user_id: str, provider: str) -> str: + """Generate a one-shot ticket bound to this user identity. + + The returned token is base64url, 43 bytes of entropy (32-byte random + seed). Stash returns the ``info`` dict to the caller on consume so the + WS handler can carry the identity forward into its session log. + """ + ticket = secrets.token_urlsafe(32) + info = { + "user_id": user_id, + "provider": provider, + "minted_at": int(time.time()), + } + with _lock: + _tickets[ticket] = (int(time.time()) + TTL_SECONDS, info) + _gc_expired_locked() + return ticket + + +def consume_ticket(ticket: str) -> Dict[str, Any]: + """Validate and consume. Raises :class:`TicketInvalid` on missing/expired/used. + + Single-use semantics: a successful consume immediately removes the + ticket from the store, so a second call with the same value raises + ``TicketInvalid("unknown ticket: …")``. + """ + now = int(time.time()) + with _lock: + entry = _tickets.pop(ticket, None) + if entry is None: + # Truncate ticket value in the error so misuse never logs the + # secret in full. + truncated = (ticket[:8] + "…") if ticket else "" + raise TicketInvalid(f"unknown ticket: {truncated}") + expires_at, info = entry + if expires_at < now: + raise TicketInvalid("expired") + return info + + +def _gc_expired_locked() -> None: + """Drop expired tickets. Caller must hold ``_lock``.""" + now = int(time.time()) + expired = [t for t, (exp, _) in _tickets.items() if exp < now] + for t in expired: + _tickets.pop(t, None) + + +def _reset_for_tests() -> None: + """Test-only: drop all tickets.""" + with _lock: + _tickets.clear() diff --git a/tests/hermes_cli/test_dashboard_auth_ws_tickets.py b/tests/hermes_cli/test_dashboard_auth_ws_tickets.py new file mode 100644 index 00000000000..6eeefbed54a --- /dev/null +++ b/tests/hermes_cli/test_dashboard_auth_ws_tickets.py @@ -0,0 +1,161 @@ +"""Tests for the WS-upgrade ticket store (Phase 5 task 5.1). + +The store is process-local and threading-safe. Tests run with xdist so +each worker has its own module instance — no cross-worker bleed — but we +call ``_reset_for_tests`` between tests to keep things deterministic. +""" + +from __future__ import annotations + +import threading + +import pytest + +from hermes_cli.dashboard_auth import ws_tickets +from hermes_cli.dashboard_auth.ws_tickets import ( + TTL_SECONDS, + TicketInvalid, + _reset_for_tests, + consume_ticket, + mint_ticket, +) + + +@pytest.fixture(autouse=True) +def _reset(): + _reset_for_tests() + yield + _reset_for_tests() + + +# --------------------------------------------------------------------------- +# Happy path +# --------------------------------------------------------------------------- + + +class TestMintAndConsume: + def test_round_trip(self): + ticket = mint_ticket(user_id="u1", provider="nous") + info = consume_ticket(ticket) + assert info["user_id"] == "u1" + assert info["provider"] == "nous" + assert "minted_at" in info + + def test_ticket_has_minimum_length(self): + # ``secrets.token_urlsafe(32)`` produces ~43 chars; enforce a floor + # so a future refactor can't accidentally shrink the entropy. + ticket = mint_ticket(user_id="u1", provider="nous") + assert len(ticket) >= 32 + + def test_ticket_values_are_unique(self): + seen = {mint_ticket(user_id="u1", provider="x") for _ in range(50)} + assert len(seen) == 50 + + +# --------------------------------------------------------------------------- +# Single-use +# --------------------------------------------------------------------------- + + +class TestSingleUse: + def test_second_consume_raises(self): + ticket = mint_ticket(user_id="u1", provider="stub") + consume_ticket(ticket) + with pytest.raises(TicketInvalid, match="unknown"): + consume_ticket(ticket) + + def test_unknown_ticket_rejected(self): + with pytest.raises(TicketInvalid, match="unknown"): + consume_ticket("nope-never-minted") + + def test_empty_ticket_rejected(self): + with pytest.raises(TicketInvalid): + consume_ticket("") + + +# --------------------------------------------------------------------------- +# TTL +# --------------------------------------------------------------------------- + + +class TestTTL: + def test_constant_is_30_seconds(self): + # Pinned so a refactor that doubled the lifetime would surface here. + assert TTL_SECONDS == 30 + + def test_expired_ticket_rejected(self, monkeypatch): + # Mock time inside the ws_tickets module so mint and consume see + # different clocks. We have to patch the symbol the module actually + # binds; ``time`` is module-level there. + clock = {"now": 1_000_000} + + def fake_time(): + return clock["now"] + + monkeypatch.setattr(ws_tickets.time, "time", fake_time) + + ticket = mint_ticket(user_id="u1", provider="stub") + clock["now"] += TTL_SECONDS + 1 + with pytest.raises(TicketInvalid, match="expired"): + consume_ticket(ticket) + + def test_at_exact_ttl_boundary_still_valid(self, monkeypatch): + clock = {"now": 1_000_000} + monkeypatch.setattr(ws_tickets.time, "time", lambda: clock["now"]) + + ticket = mint_ticket(user_id="u1", provider="stub") + clock["now"] += TTL_SECONDS # exactly at boundary; expires_at == now + # Implementation: ``expires_at < now`` (strict), so == passes. + info = consume_ticket(ticket) + assert info["user_id"] == "u1" + + +# --------------------------------------------------------------------------- +# Truncated value in error message (secret hygiene) +# --------------------------------------------------------------------------- + + +class TestErrorMessages: + def test_unknown_ticket_error_truncates_value(self): + long_value = "a" * 100 + with pytest.raises(TicketInvalid) as exc_info: + consume_ticket(long_value) + # Never log more than the first 8 chars of an opaque ticket. + message = str(exc_info.value) + assert long_value not in message + assert long_value[:8] in message + + +# --------------------------------------------------------------------------- +# Thread safety: mint + consume from many threads doesn't deadlock or +# return duplicates. +# --------------------------------------------------------------------------- + + +class TestConcurrency: + def test_mint_and_consume_concurrent(self): + results: list[dict] = [] + errors: list[Exception] = [] + lock = threading.Lock() + + def worker(i: int): + try: + t = mint_ticket(user_id=f"u{i}", provider="stub") + info = consume_ticket(t) + with lock: + results.append(info) + except Exception as exc: # noqa: BLE001 — collect for assert + with lock: + errors.append(exc) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(20)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5.0) + assert not t.is_alive(), "thread deadlocked" + + assert errors == [] + assert len(results) == 20 + # Every consume returns a distinct user_id (no cross-thread bleed). + assert {r["user_id"] for r in results} == {f"u{i}" for i in range(20)}