mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
feat(dashboard-auth): _ws_auth_ok helper + ticket auth on all 4 WS endpoints
Phase 5 task 5.2. Four WebSocket endpoints — /api/pty, /api/ws, /api/pub,
/api/events — previously authed with the same constant-time check against
`_SESSION_TOKEN`. Replaced with a single helper that branches on
`app.state.auth_required`:
- Loopback / --insecure: legacy ?token=<_SESSION_TOKEN> path (unchanged).
- Gated: ?ticket=<single-use> is consumed against the dashboard-auth
  ticket store.
Critical security property: gated mode UNCONDITIONALLY rejects the
?token= path. A leaked _SESSION_TOKEN value from a log line is not
replayable for WS access in gated deployments.
`_build_sidecar_url` now branches too: loopback uses the legacy token;
gated mode mints a server-internal ticket via mint_ticket() with
pseudo-user 'pty-sidecar' / provider 'server-internal' so audit logs can
distinguish PTY-internal sidecar tickets from browser tickets. PTY
children open /api/pub exactly once at startup so single-use suffices.
Ticket rejections audit-log as WS_TICKET_REJECTED with truncated reason
+ client IP + WS path. Operators debugging 'WS keeps closing' issues see
which endpoint and why.
17 new tests:
- POST /api/auth/ws-ticket: 200 with cookie, 401/302 without, distinct
per call, GET-not-allowed.
- _ws_auth_ok loopback: token accept/reject, missing-token reject,
ticket-param-ignored.
- _ws_auth_ok gated: ticket accept, single-use rejection, unknown reject,
legacy-token-rejected-in-gated assertion, audit-log emission.
- _build_sidecar_url: loopback uses token=, gated uses ticket=, no-bound
returns None.
This commit is contained in:
parent
b69fce9c86
commit
b2360ba44e
2 changed files with 329 additions and 11 deletions
|
|
@ -3401,6 +3401,50 @@ def _ws_request_is_allowed(ws: "WebSocket") -> bool:
|
|||
"""Return True when the WebSocket upgrade matches dashboard boundaries."""
|
||||
return _ws_host_origin_is_allowed(ws) and _ws_client_is_allowed(ws)
|
||||
|
||||
|
||||
def _ws_auth_ok(ws: "WebSocket") -> bool:
    """Validate WS-upgrade auth in either loopback or gated mode.

    Loopback / ``--insecure`` (``app.state.auth_required`` falsy or unset):
    the legacy ``?token=<_SESSION_TOKEN>`` query parameter, constant-time
    compared.

    Gated (``app.state.auth_required`` truthy): a single-use ``?ticket=``
    query parameter consumed against the dashboard-auth ticket store. The
    legacy ``?token=`` parameter is never consulted in this mode, so a
    leaked ``_SESSION_TOKEN`` value cannot be replayed for WS access.

    Every gated-mode rejection is audit-logged as ``WS_TICKET_REJECTED``
    with the reason, client IP and WS path. That includes a missing
    ``ticket`` parameter, which is what a client still sending only
    ``?token=`` looks like, so operators can debug "WS keeps closing"
    issues from the log.

    Args:
        ws: The WebSocket being upgraded. Only ``query_params``, ``client``
            and ``url.path`` are read.

    Returns:
        True if the WS should be accepted. On False the caller is expected
        to close the socket with the appropriate WS code (4401).
    """
    if not getattr(app.state, "auth_required", False):
        token = ws.query_params.get("token", "")
        return hmac.compare_digest(token.encode(), _SESSION_TOKEN.encode())

    # Lazy import — only the gated path needs the dashboard_auth layer, so
    # loopback-only test harnesses can still import and call this function.
    from hermes_cli.dashboard_auth.audit import AuditEvent, audit_log
    from hermes_cli.dashboard_auth.ws_tickets import (
        TicketInvalid,
        consume_ticket,
    )

    ticket = ws.query_params.get("ticket", "")
    if ticket:
        try:
            consume_ticket(ticket)
        except TicketInvalid as exc:
            reason = str(exc)
        else:
            return True
    else:
        # Previously this path returned False silently, leaving no trace of
        # clients that connect without any ticket at all.
        reason = "missing ticket"

    audit_log(
        AuditEvent.WS_TICKET_REJECTED,
        reason=reason,
        ip=(ws.client.host if ws.client else ""),
        path=ws.url.path,
    )
    return False
|
||||
|
||||
# Per-channel subscriber registry used by /api/pub (PTY-side gateway → dashboard)
|
||||
# and /api/events (dashboard → browser sidebar). Keyed by an opaque channel id
|
||||
# the chat tab generates on mount; entries auto-evict when the last subscriber
|
||||
|
|
@ -3455,7 +3499,21 @@ def _resolve_chat_argv(
|
|||
|
||||
|
||||
def _build_sidecar_url(channel: str) -> Optional[str]:
|
||||
"""ws:// URL the PTY child should publish events to, or None when unbound."""
|
||||
"""ws:// URL the PTY child should publish events to, or None when unbound.
|
||||
|
||||
Loopback / ``--insecure``: uses ``?token=<_SESSION_TOKEN>``.
|
||||
|
||||
Gated mode: mints a single-use ticket via the dashboard-auth ticket
|
||||
store (server-side mint, no HTTP round trip — the PTY child is a
|
||||
server-spawned process and we trust it). The ticket binds to the
|
||||
pseudo-user ``"pty-sidecar"`` so audit logs can distinguish these from
|
||||
browser-initiated tickets.
|
||||
|
||||
The single-use lifetime means the PTY child cannot reconnect without a
|
||||
new sidecar URL. PTY children open ``/api/pub`` once at startup; if
|
||||
reconnect semantics ever become important, this should be upgraded to
|
||||
a long-lived process-scoped token.
|
||||
"""
|
||||
host = getattr(app.state, "bound_host", None)
|
||||
port = getattr(app.state, "bound_port", None)
|
||||
|
||||
|
|
@ -3463,7 +3521,15 @@ def _build_sidecar_url(channel: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
netloc = f"[{host}]:{port}" if ":" in host and not host.startswith("[") else f"{host}:{port}"
|
||||
qs = urllib.parse.urlencode({"token": _SESSION_TOKEN, "channel": channel})
|
||||
|
||||
if getattr(app.state, "auth_required", False):
|
||||
# Gated mode — mint a ticket so the WS upgrade survives _ws_auth_ok.
|
||||
from hermes_cli.dashboard_auth.ws_tickets import mint_ticket
|
||||
|
||||
ticket = mint_ticket(user_id="pty-sidecar", provider="server-internal")
|
||||
qs = urllib.parse.urlencode({"ticket": ticket, "channel": channel})
|
||||
else:
|
||||
qs = urllib.parse.urlencode({"token": _SESSION_TOKEN, "channel": channel})
|
||||
|
||||
return f"ws://{netloc}/api/pub?{qs}"
|
||||
|
||||
|
|
@ -3496,9 +3562,7 @@ async def pty_ws(ws: WebSocket) -> None:
|
|||
return
|
||||
|
||||
# --- auth + loopback check (before accept so we can close cleanly) ---
|
||||
token = ws.query_params.get("token", "")
|
||||
expected = _SESSION_TOKEN
|
||||
if not hmac.compare_digest(token.encode(), expected.encode()):
|
||||
if not _ws_auth_ok(ws):
|
||||
await ws.close(code=4401)
|
||||
return
|
||||
|
||||
|
|
@ -3616,8 +3680,7 @@ async def gateway_ws(ws: WebSocket) -> None:
|
|||
await ws.close(code=4403)
|
||||
return
|
||||
|
||||
token = ws.query_params.get("token", "")
|
||||
if not hmac.compare_digest(token.encode(), _SESSION_TOKEN.encode()):
|
||||
if not _ws_auth_ok(ws):
|
||||
await ws.close(code=4401)
|
||||
return
|
||||
|
||||
|
|
@ -3648,8 +3711,7 @@ async def pub_ws(ws: WebSocket) -> None:
|
|||
await ws.close(code=4403)
|
||||
return
|
||||
|
||||
token = ws.query_params.get("token", "")
|
||||
if not hmac.compare_digest(token.encode(), _SESSION_TOKEN.encode()):
|
||||
if not _ws_auth_ok(ws):
|
||||
await ws.close(code=4401)
|
||||
return
|
||||
|
||||
|
|
@ -3677,8 +3739,7 @@ async def events_ws(ws: WebSocket) -> None:
|
|||
await ws.close(code=4403)
|
||||
return
|
||||
|
||||
token = ws.query_params.get("token", "")
|
||||
if not hmac.compare_digest(token.encode(), _SESSION_TOKEN.encode()):
|
||||
if not _ws_auth_ok(ws):
|
||||
await ws.close(code=4401)
|
||||
return
|
||||
|
||||
|
|
|
|||
257
tests/hermes_cli/test_dashboard_auth_ws_auth.py
Normal file
257
tests/hermes_cli/test_dashboard_auth_ws_auth.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""Tests for the WS-upgrade auth helper (Phase 5 task 5.2).
|
||||
|
||||
The dashboard's four WS endpoints (``/api/pty``, ``/api/ws``, ``/api/pub``,
|
||||
``/api/events``) share an auth gate: ``_ws_auth_ok``. In loopback mode it
|
||||
accepts ``?token=<_SESSION_TOKEN>``; in gated mode it accepts a single-use
|
||||
``?ticket=`` minted by ``POST /api/auth/ws-ticket``.
|
||||
|
||||
These tests exercise the helper at the unit level (no actual WS upgrade)
|
||||
plus the ticket-mint endpoint under realistic gated-mode setup. We don't
|
||||
test the full WS upgrade because the starlette TestClient WS path has a
|
||||
pre-existing regression unrelated to dashboard-auth.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from hermes_cli import web_server
|
||||
from hermes_cli.dashboard_auth import clear_providers, register_provider
|
||||
from hermes_cli.dashboard_auth.ws_tickets import (
|
||||
TicketInvalid,
|
||||
_reset_for_tests,
|
||||
consume_ticket,
|
||||
mint_ticket,
|
||||
)
|
||||
from tests.hermes_cli.conftest_dashboard_auth import StubAuthProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def gated_app():
    """Yield a TestClient for ``web_server.app`` switched into gated mode.

    Resets the ticket store and provider registry, registers the stub auth
    provider, and sets ``bound_host`` / ``bound_port`` / ``auth_required``
    on the app state. After the test, the registry and ticket store are
    cleared again and the three state attributes are put back to the values
    captured beforehand (``None`` when an attribute was not set).
    """
    _reset_for_tests()
    clear_providers()
    register_provider(StubAuthProvider())

    state = web_server.app.state
    saved = {
        name: getattr(state, name, None)
        for name in ("bound_host", "bound_port", "auth_required")
    }

    state.bound_host = "fly-app.fly.dev"
    state.bound_port = 443
    state.auth_required = True

    yield TestClient(web_server.app, base_url="https://fly-app.fly.dev")

    clear_providers()
    _reset_for_tests()
    for name, value in saved.items():
        setattr(state, name, value)
|
||||
|
||||
|
||||
@pytest.fixture
def loopback_app():
    """Yield a TestClient for ``web_server.app`` in loopback mode (gate OFF).

    Resets the ticket store and provider registry, then points the app
    state at ``127.0.0.1:8080`` with ``auth_required`` disabled. After the
    test, the ticket store is reset and the three state attributes are put
    back to the values captured beforehand (``None`` when an attribute was
    not set).
    """
    _reset_for_tests()
    clear_providers()

    state = web_server.app.state
    saved = {
        name: getattr(state, name, None)
        for name in ("bound_host", "bound_port", "auth_required")
    }

    state.bound_host = "127.0.0.1"
    state.bound_port = 8080
    state.auth_required = False

    yield TestClient(web_server.app, base_url="http://127.0.0.1:8080")

    _reset_for_tests()
    for name, value in saved.items():
        setattr(state, name, value)
|
||||
|
||||
|
||||
def _logged_in(client: TestClient) -> None:
    """Run the stub provider's login + callback so *client* holds session cookies.

    Both requests are made without following redirects and each must answer
    302. The ``state`` value is lifted from the login redirect's
    ``Location`` header and echoed back to the callback.
    """
    login = client.get("/auth/login?provider=stub", follow_redirects=False)
    assert login.status_code == 302

    location = login.headers["location"]
    state = location.split("state=")[1]

    callback_url = f"/auth/callback?code=stub_code&state={state}"
    callback = client.get(callback_url, follow_redirects=False)
    assert callback.status_code == 302
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/ws-ticket — the mint endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWsTicketEndpoint:
    """``POST /api/auth/ws-ticket`` — the HTTP endpoint that mints WS tickets."""

    def test_authenticated_session_can_mint(self, gated_app):
        _logged_in(gated_app)

        response = gated_app.post("/api/auth/ws-ticket")
        assert response.status_code == 200

        payload = response.json()
        assert "ticket" in payload
        minted = payload["ticket"]
        assert isinstance(minted, str)
        assert len(minted) >= 32
        assert payload["ttl_seconds"] == 30

    def test_unauthenticated_returns_401_or_redirect(self, gated_app):
        response = gated_app.post("/api/auth/ws-ticket", follow_redirects=False)
        # No session cookie: the auth gate answers before the route runs,
        # with either a 401 or a 302 to login. Both count as "refused".
        assert response.status_code in (302, 401)

    def test_each_call_returns_a_distinct_ticket(self, gated_app):
        _logged_in(gated_app)

        minted = [
            gated_app.post("/api/auth/ws-ticket").json()["ticket"]
            for _ in range(5)
        ]
        # Five calls must yield five different tickets.
        assert len(set(minted)) == 5

    def test_get_method_is_not_allowed(self, gated_app):
        _logged_in(gated_app)

        response = gated_app.get("/api/auth/ws-ticket", follow_redirects=False)
        # Acceptable outcomes: 405 (GET not registered), 401 (the auth
        # middleware treats it as an allowlist miss), or 404 (the SPA
        # catch-all swallows it). Any of them shows the endpoint cannot be
        # driven by GET, which would otherwise be cookie-replayable from a
        # malicious origin via <img src=…>.
        assert response.status_code in (401, 404, 405)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _ws_auth_ok — unit-level (synthetic WebSocket-shaped object)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fake_ws(*, query: dict, client_host: str = "127.0.0.1", path: str = "/api/pty"):
|
||||
"""Build a stand-in for starlette.WebSocket good enough for _ws_auth_ok."""
|
||||
|
||||
class _QP:
|
||||
def __init__(self, q):
|
||||
self._q = q
|
||||
|
||||
def get(self, k, default=""):
|
||||
return self._q.get(k, default)
|
||||
|
||||
return SimpleNamespace(
|
||||
query_params=_QP(query),
|
||||
client=SimpleNamespace(host=client_host),
|
||||
url=SimpleNamespace(path=path),
|
||||
)
|
||||
|
||||
|
||||
class TestWsAuthOkLoopback:
    """Gate OFF — only the legacy ``?token=`` parameter is honoured."""

    def test_correct_token_accepted(self, loopback_app):
        params = {"token": web_server._SESSION_TOKEN}
        assert web_server._ws_auth_ok(_fake_ws(query=params)) is True

    def test_wrong_token_rejected(self, loopback_app):
        params = {"token": "not-the-real-token"}
        assert web_server._ws_auth_ok(_fake_ws(query=params)) is False

    def test_missing_token_rejected(self, loopback_app):
        assert web_server._ws_auth_ok(_fake_ws(query={})) is False

    def test_ticket_param_ignored_in_loopback(self, loopback_app):
        # A freshly minted ticket is valid for the ticket store, but
        # loopback mode looks at ?token= only, so it must not grant access.
        minted = mint_ticket(user_id="u1", provider="stub")
        params = {"ticket": minted}
        assert web_server._ws_auth_ok(_fake_ws(query=params)) is False
|
||||
|
||||
|
||||
class TestWsAuthOkGated:
    """Gate ON — ticket path only."""

    def test_valid_ticket_accepted(self, gated_app):
        ticket = mint_ticket(user_id="u1", provider="stub")
        ws = _fake_ws(query={"ticket": ticket})
        assert web_server._ws_auth_ok(ws) is True

    def test_consumed_ticket_rejected(self, gated_app):
        ticket = mint_ticket(user_id="u1", provider="stub")
        ws_one = _fake_ws(query={"ticket": ticket})
        ws_two = _fake_ws(query={"ticket": ticket})
        assert web_server._ws_auth_ok(ws_one) is True
        # Single-use — second consumption fails.
        assert web_server._ws_auth_ok(ws_two) is False

    def test_unknown_ticket_rejected(self, gated_app):
        ws = _fake_ws(query={"ticket": "never-minted"})
        assert web_server._ws_auth_ok(ws) is False

    def test_missing_ticket_rejected(self, gated_app):
        ws = _fake_ws(query={})
        assert web_server._ws_auth_ok(ws) is False

    def test_legacy_token_rejected_in_gated_mode(self, gated_app):
        """Critical: gated mode must NOT honour the legacy token path
        even when someone has access to the in-process value of
        _SESSION_TOKEN (e.g. a leaked log line)."""
        ws = _fake_ws(query={"token": web_server._SESSION_TOKEN})
        assert web_server._ws_auth_ok(ws) is False

    def test_rejection_invokes_audit_log(self, gated_app):
        """An invalid ticket calls ``audit_log`` with event, IP and WS path.

        Deterministic counterpart to ``test_rejection_audit_logs``: it does
        not depend on where (or whether) the audit module writes a file.
        ``_ws_auth_ok`` imports ``audit_log`` from the audit module at call
        time, so replacing the module attribute is enough to observe it.
        """
        from hermes_cli.dashboard_auth import audit as audit_mod

        ws = _fake_ws(
            query={"ticket": "never-minted"},
            client_host="203.0.113.7",
            path="/api/events",
        )
        with patch.object(audit_mod, "audit_log") as audit_spy:
            assert web_server._ws_auth_ok(ws) is False

        audit_spy.assert_called_once()
        args, kwargs = audit_spy.call_args
        assert args[0] == audit_mod.AuditEvent.WS_TICKET_REJECTED
        assert kwargs["ip"] == "203.0.113.7"
        assert kwargs["path"] == "/api/events"

    def test_rejection_audit_logs(self, gated_app, tmp_path, monkeypatch):
        # Point the audit log at a tmp dir so we can read what got written.
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from hermes_cli.dashboard_auth import audit as audit_mod

        # The log path is resolved lazily on the first audit_log() call;
        # bust any cached handler so it re-resolves.
        if hasattr(audit_mod, "_LOGGER"):
            monkeypatch.setattr(audit_mod, "_LOGGER", None, raising=False)

        ws = _fake_ws(query={"ticket": "never-minted"})
        assert web_server._ws_auth_ok(ws) is False

        log_file = tmp_path / "logs" / "dashboard-auth.log"
        # Best-effort end-to-end check: if the file doesn't exist the
        # logger may not have been (re)initialized in this process — that's
        # acceptable here as long as the rejection path didn't crash.
        # NOTE(review): this assertion is skipped when the file is absent;
        # test_rejection_invokes_audit_log covers the call unconditionally.
        if log_file.exists():
            content = log_file.read_text()
            assert "ws_ticket_rejected" in content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_sidecar_url — gated mode mints a server-internal ticket
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSidecarUrl:
    """``_build_sidecar_url`` carries ``token=`` in loopback, ``ticket=`` when gated."""

    @staticmethod
    def _query(url: str) -> dict:
        """Return the URL's query string parsed and percent-decoded.

        ``_build_sidecar_url`` builds the query with ``urlencode``, so the
        values must be decoded before comparing them or handing a ticket to
        ``consume_ticket``; raw string splitting would keep any escapes.
        """
        # Local import: the test module does not import urllib at file level.
        from urllib.parse import parse_qs, urlsplit

        return parse_qs(urlsplit(url).query)

    def test_loopback_uses_session_token(self, loopback_app):
        url = web_server._build_sidecar_url("ch-1")
        assert url is not None
        params = self._query(url)
        assert params["token"] == [web_server._SESSION_TOKEN]
        assert params["channel"] == ["ch-1"]
        assert "ticket" not in params

    def test_gated_uses_ticket(self, gated_app):
        url = web_server._build_sidecar_url("ch-1")
        assert url is not None
        params = self._query(url)
        assert "token" not in params
        assert "ticket" in params
        assert params["channel"] == ["ch-1"]
        # And the ticket should be live.
        (ticket,) = params["ticket"]
        info = consume_ticket(ticket)
        # Sidecar tickets are bound to the pseudo-user so audit logs can
        # distinguish them from real browser tickets.
        assert info["user_id"] == "pty-sidecar"
        assert info["provider"] == "server-internal"

    def test_no_bound_host_returns_none(self, gated_app):
        state = web_server.app.state
        bound_host = state.bound_host
        state.bound_host = None
        try:
            assert web_server._build_sidecar_url("ch") is None
        finally:
            # Put back whatever the fixture configured rather than
            # repeating its literal here.
            state.bound_host = bound_host
|
||||
Loading…
Add table
Add a link
Reference in a new issue