mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
feat(dashboard-auth): single-use WS tickets + POST /api/auth/ws-ticket
Phase 5 task 5.1. Browsers cannot set Authorization on a WebSocket
upgrade, so in gated mode the SPA needs an alternative way to bind the
upgrade to its authenticated session.
hermes_cli/dashboard_auth/ws_tickets.py — in-memory single-use ticket
store with 30s TTL. Thread-safe (threading.Lock), token_urlsafe(32)
values, ticket value truncated to 8 chars in error messages for log
hygiene. Module-level state with _reset_for_tests() helper.
hermes_cli/dashboard_auth/routes.py — adds POST /api/auth/ws-ticket.
Auth-required (the gate middleware already attaches Session to
request.state.session). Returns {ticket, ttl_seconds}; emits
WS_TICKET_MINTED audit event with user_id + provider + ip.
hermes_cli/dashboard_auth/audit.py — adds WS_TICKET_REJECTED enum
value for the consume-side rejection event (wired into the WS
endpoints in task 5.2).
11 new tests covering round-trip, single-use, TTL boundary, unknown
ticket rejection, secret-hygiene truncation in error messages, and
concurrent mint+consume from 20 threads.
This commit is contained in:
parent
848baeb0a8
commit
b69fce9c86
4 changed files with 285 additions and 0 deletions
|
|
@ -46,6 +46,7 @@ class AuditEvent(enum.Enum):
|
|||
REVOKE = "revoke"
|
||||
SESSION_VERIFY_FAILURE = "session_verify_failure"
|
||||
WS_TICKET_MINTED = "ws_ticket_minted"
|
||||
WS_TICKET_REJECTED = "ws_ticket_rejected"
|
||||
|
||||
|
||||
def _resolve_log_path() -> Path:
|
||||
|
|
|
|||
|
|
@ -302,3 +302,39 @@ async def api_auth_me(request: Request):
|
|||
"provider": sess.provider,
|
||||
"expires_at": sess.expires_at,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth-required: WS upgrade ticket (Phase 5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/api/auth/ws-ticket", name="auth_ws_ticket")
|
||||
async def api_auth_ws_ticket(request: Request):
|
||||
"""Mint a short-lived single-use ticket for the authenticated session.
|
||||
|
||||
Browsers cannot set ``Authorization`` on a WebSocket upgrade, so in
|
||||
gated mode the SPA POSTs this endpoint to get a ``?ticket=`` value to
|
||||
append to ``/api/pty``, ``/api/ws``, ``/api/pub``, or ``/api/events``.
|
||||
|
||||
The ticket has a 30-second TTL and is single-use. Calling this endpoint
|
||||
multiple times in quick succession (e.g. one ticket per WS) is the
|
||||
expected pattern.
|
||||
"""
|
||||
sess = getattr(request.state, "session", None)
|
||||
if sess is None:
|
||||
# Middleware should already have rejected, but check defensively.
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
# Import here so the routes module stays usable in test contexts that
|
||||
# don't load the ticket store.
|
||||
from hermes_cli.dashboard_auth.ws_tickets import TTL_SECONDS, mint_ticket
|
||||
|
||||
ticket = mint_ticket(user_id=sess.user_id, provider=sess.provider)
|
||||
audit_log(
|
||||
AuditEvent.WS_TICKET_MINTED,
|
||||
provider=sess.provider,
|
||||
user_id=sess.user_id,
|
||||
ip=_client_ip(request),
|
||||
)
|
||||
return {"ticket": ticket, "ttl_seconds": TTL_SECONDS}
|
||||
|
|
|
|||
87
hermes_cli/dashboard_auth/ws_tickets.py
Normal file
87
hermes_cli/dashboard_auth/ws_tickets.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Short-lived single-use tickets for WS-upgrade auth in gated mode.
|
||||
|
||||
Browsers cannot set ``Authorization`` on a WebSocket upgrade. In loopback
|
||||
mode the legacy ``?token=<_SESSION_TOKEN>`` query param works because the
|
||||
token is injected into the SPA bundle. In gated mode there is no injected
|
||||
token — the SPA gets a fresh ticket via the authenticated REST endpoint
|
||||
``POST /api/auth/ws-ticket`` and passes that as ``?ticket=`` on the
|
||||
WS upgrade.
|
||||
|
||||
Tickets are single-use, TTL = 30 seconds. In-memory; the dashboard is a
|
||||
single process so no distributed coordination is needed. The module
|
||||
exposes a small functional API rather than a class so tests can patch
|
||||
``time.time`` cleanly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
#: Time-to-live for newly-minted tickets in seconds. 30 s is long enough
|
||||
#: that the SPA can call ``getWsTicket()`` and immediately open the WS,
|
||||
#: short enough that a leaked ticket is uninteresting.
|
||||
TTL_SECONDS = 30
|
||||
|
||||
_lock = threading.Lock()
|
||||
_tickets: Dict[str, Tuple[int, Dict[str, Any]]] = {} # ticket -> (expires_at, info)
|
||||
|
||||
|
||||
class TicketInvalid(Exception):
|
||||
"""Ticket missing, expired, or already consumed."""
|
||||
|
||||
|
||||
def mint_ticket(*, user_id: str, provider: str) -> str:
|
||||
"""Generate a one-shot ticket bound to this user identity.
|
||||
|
||||
The returned token is base64url, 43 bytes of entropy (32-byte random
|
||||
seed). Stash returns the ``info`` dict to the caller on consume so the
|
||||
WS handler can carry the identity forward into its session log.
|
||||
"""
|
||||
ticket = secrets.token_urlsafe(32)
|
||||
info = {
|
||||
"user_id": user_id,
|
||||
"provider": provider,
|
||||
"minted_at": int(time.time()),
|
||||
}
|
||||
with _lock:
|
||||
_tickets[ticket] = (int(time.time()) + TTL_SECONDS, info)
|
||||
_gc_expired_locked()
|
||||
return ticket
|
||||
|
||||
|
||||
def consume_ticket(ticket: str) -> Dict[str, Any]:
|
||||
"""Validate and consume. Raises :class:`TicketInvalid` on missing/expired/used.
|
||||
|
||||
Single-use semantics: a successful consume immediately removes the
|
||||
ticket from the store, so a second call with the same value raises
|
||||
``TicketInvalid("unknown ticket: …")``.
|
||||
"""
|
||||
now = int(time.time())
|
||||
with _lock:
|
||||
entry = _tickets.pop(ticket, None)
|
||||
if entry is None:
|
||||
# Truncate ticket value in the error so misuse never logs the
|
||||
# secret in full.
|
||||
truncated = (ticket[:8] + "…") if ticket else "<empty>"
|
||||
raise TicketInvalid(f"unknown ticket: {truncated}")
|
||||
expires_at, info = entry
|
||||
if expires_at < now:
|
||||
raise TicketInvalid("expired")
|
||||
return info
|
||||
|
||||
|
||||
def _gc_expired_locked() -> None:
|
||||
"""Drop expired tickets. Caller must hold ``_lock``."""
|
||||
now = int(time.time())
|
||||
expired = [t for t, (exp, _) in _tickets.items() if exp < now]
|
||||
for t in expired:
|
||||
_tickets.pop(t, None)
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
|
||||
"""Test-only: drop all tickets."""
|
||||
with _lock:
|
||||
_tickets.clear()
|
||||
161
tests/hermes_cli/test_dashboard_auth_ws_tickets.py
Normal file
161
tests/hermes_cli/test_dashboard_auth_ws_tickets.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Tests for the WS-upgrade ticket store (Phase 5 task 5.1).
|
||||
|
||||
The store is process-local and threading-safe. Tests run with xdist so
|
||||
each worker has its own module instance — no cross-worker bleed — but we
|
||||
call ``_reset_for_tests`` between tests to keep things deterministic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.dashboard_auth import ws_tickets
|
||||
from hermes_cli.dashboard_auth.ws_tickets import (
|
||||
TTL_SECONDS,
|
||||
TicketInvalid,
|
||||
_reset_for_tests,
|
||||
consume_ticket,
|
||||
mint_ticket,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset():
|
||||
_reset_for_tests()
|
||||
yield
|
||||
_reset_for_tests()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMintAndConsume:
|
||||
def test_round_trip(self):
|
||||
ticket = mint_ticket(user_id="u1", provider="nous")
|
||||
info = consume_ticket(ticket)
|
||||
assert info["user_id"] == "u1"
|
||||
assert info["provider"] == "nous"
|
||||
assert "minted_at" in info
|
||||
|
||||
def test_ticket_has_minimum_length(self):
|
||||
# ``secrets.token_urlsafe(32)`` produces ~43 chars; enforce a floor
|
||||
# so a future refactor can't accidentally shrink the entropy.
|
||||
ticket = mint_ticket(user_id="u1", provider="nous")
|
||||
assert len(ticket) >= 32
|
||||
|
||||
def test_ticket_values_are_unique(self):
|
||||
seen = {mint_ticket(user_id="u1", provider="x") for _ in range(50)}
|
||||
assert len(seen) == 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single-use
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleUse:
|
||||
def test_second_consume_raises(self):
|
||||
ticket = mint_ticket(user_id="u1", provider="stub")
|
||||
consume_ticket(ticket)
|
||||
with pytest.raises(TicketInvalid, match="unknown"):
|
||||
consume_ticket(ticket)
|
||||
|
||||
def test_unknown_ticket_rejected(self):
|
||||
with pytest.raises(TicketInvalid, match="unknown"):
|
||||
consume_ticket("nope-never-minted")
|
||||
|
||||
def test_empty_ticket_rejected(self):
|
||||
with pytest.raises(TicketInvalid):
|
||||
consume_ticket("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TTL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTTL:
|
||||
def test_constant_is_30_seconds(self):
|
||||
# Pinned so a refactor that doubled the lifetime would surface here.
|
||||
assert TTL_SECONDS == 30
|
||||
|
||||
def test_expired_ticket_rejected(self, monkeypatch):
|
||||
# Mock time inside the ws_tickets module so mint and consume see
|
||||
# different clocks. We have to patch the symbol the module actually
|
||||
# binds; ``time`` is module-level there.
|
||||
clock = {"now": 1_000_000}
|
||||
|
||||
def fake_time():
|
||||
return clock["now"]
|
||||
|
||||
monkeypatch.setattr(ws_tickets.time, "time", fake_time)
|
||||
|
||||
ticket = mint_ticket(user_id="u1", provider="stub")
|
||||
clock["now"] += TTL_SECONDS + 1
|
||||
with pytest.raises(TicketInvalid, match="expired"):
|
||||
consume_ticket(ticket)
|
||||
|
||||
def test_at_exact_ttl_boundary_still_valid(self, monkeypatch):
|
||||
clock = {"now": 1_000_000}
|
||||
monkeypatch.setattr(ws_tickets.time, "time", lambda: clock["now"])
|
||||
|
||||
ticket = mint_ticket(user_id="u1", provider="stub")
|
||||
clock["now"] += TTL_SECONDS # exactly at boundary; expires_at == now
|
||||
# Implementation: ``expires_at < now`` (strict), so == passes.
|
||||
info = consume_ticket(ticket)
|
||||
assert info["user_id"] == "u1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Truncated value in error message (secret hygiene)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestErrorMessages:
|
||||
def test_unknown_ticket_error_truncates_value(self):
|
||||
long_value = "a" * 100
|
||||
with pytest.raises(TicketInvalid) as exc_info:
|
||||
consume_ticket(long_value)
|
||||
# Never log more than the first 8 chars of an opaque ticket.
|
||||
message = str(exc_info.value)
|
||||
assert long_value not in message
|
||||
assert long_value[:8] in message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread safety: mint + consume from many threads doesn't deadlock or
|
||||
# return duplicates.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConcurrency:
|
||||
def test_mint_and_consume_concurrent(self):
|
||||
results: list[dict] = []
|
||||
errors: list[Exception] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
def worker(i: int):
|
||||
try:
|
||||
t = mint_ticket(user_id=f"u{i}", provider="stub")
|
||||
info = consume_ticket(t)
|
||||
with lock:
|
||||
results.append(info)
|
||||
except Exception as exc: # noqa: BLE001 — collect for assert
|
||||
with lock:
|
||||
errors.append(exc)
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(20)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5.0)
|
||||
assert not t.is_alive(), "thread deadlocked"
|
||||
|
||||
assert errors == []
|
||||
assert len(results) == 20
|
||||
# Every consume returns a distinct user_id (no cross-thread bleed).
|
||||
assert {r["user_id"] for r in results} == {f"u{i}" for i in range(20)}
|
||||
Loading…
Add table
Add a link
Reference in a new issue