mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
fix(web-server): move event channel state from module globals to app.state (#37683)
A module-level asyncio.Lock() can end up tied to an event loop that no longer
exists: older Pythons bind it at construction (import time here), while
Python 3.10+ binds it on the first contended acquire. When the same web_server
module is reused across multiple TestClient instances (or across uvicorn
reloads), the old lock still references a defunct loop, causing 'attached to a
different loop' errors and flaky subscriber-registration races in CI.
Replace the module-level _event_channels dict + _event_lock with:
- _lifespan() async context manager that creates both on the running
event loop during FastAPI startup (guaranteed correct loop binding)
- _get_event_state() lazy accessor that initialises the state on app.state when
TestClient is used without a `with` block (preserves backward compatibility)
All call sites (_broadcast_event, /api/pub, /api/events) now receive the
app reference and read state via _get_event_state(app) instead of the
module globals. The test polling loop is updated to check
app.state.event_channels rather than the removed module attribute.
This commit is contained in:
parent
a429a2a0bf
commit
cbc82511ea
2 changed files with 52 additions and 13 deletions
|
|
@ -9,6 +9,8 @@ Usage:
|
|||
python -m hermes_cli.main web --port 8080
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
|
|
@ -84,7 +86,43 @@ except ImportError:
|
|||
WEB_DIST = Path(os.environ["HERMES_WEB_DIST"]) if "HERMES_WEB_DIST" in os.environ else Path(__file__).parent / "web_dist"
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="Hermes Agent", version=__version__)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-channel subscriber registry used by /api/pub (PTY-side gateway → dashboard)
|
||||
# and /api/events (dashboard → browser sidebar). Keyed by an opaque channel id
|
||||
# the chat tab generates on mount; entries auto-evict when the last subscriber
|
||||
# drops AND the publisher has disconnected.
|
||||
#
|
||||
# State lives on app.state (not module-level globals) so that asyncio.Lock is
|
||||
# created on the running event loop during lifespan startup. A module-level
|
||||
# asyncio.Lock() binds to whatever loop was active at import time, which breaks
|
||||
# when the same module is used across TestClient instances or uvicorn reloads.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@asynccontextmanager
async def _lifespan(app: "FastAPI"):
    """Create the event-channel registry and its lock for one app lifetime.

    On startup, stores a fresh ``event_channels`` dict (channel id -> set of
    subscribers) and a fresh ``asyncio.Lock`` on ``app.state``, so both are
    created while the serving event loop is running.

    On shutdown, both attributes are replaced with new, unused objects.
    Without this, the lock from the finished run stays on ``app.state`` and
    would be handed to any later use that skips the lifespan, even though
    it may still be tied to the loop that just ended.
    NOTE(review): on Python 3.10+ a freshly built Lock is not bound to a
    loop until it is first contended; on older versions the replacement is
    no worse than leaving the old lock in place - confirm the minimum
    supported Python version.
    """
    app.state.event_channels = {}  # dict[str, set]
    app.state.event_lock = asyncio.Lock()
    try:
        yield
    finally:
        # Keep the attributes present (callers may read app.state directly)
        # but never leak this run's lock or subscriber sets into the next.
        app.state.event_channels = {}
        app.state.event_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _get_event_state(app: "FastAPI") -> "tuple[dict, asyncio.Lock]":
    """Return ``(event_channels, event_lock)`` from ``app.state``.

    Lazily initialises the state if the lifespan hasn't run (e.g. when
    TestClient is constructed without a ``with`` block). The lifespan
    path is preferred because it creates the Lock while the serving event
    loop is running, but the lazy path lets existing non-``with``
    TestClient usages keep working.

    Each attribute is created independently and only when it is missing
    (or ``None``), so an existing channel registry is never discarded just
    because the lock is absent, and vice versa.

    Args:
        app: The application whose ``state`` object holds the registry.

    Returns:
        A ``(event_channels, event_lock)`` tuple; repeated calls return the
        same two objects until something replaces them on ``app.state``.
    """
    state = app.state

    channels = getattr(state, "event_channels", None)
    if channels is None:
        channels = state.event_channels = {}

    lock = getattr(state, "event_lock", None)
    if lock is None:
        lock = state.event_lock = asyncio.Lock()

    return channels, lock
|
||||
|
||||
|
||||
app = FastAPI(title="Hermes Agent", version=__version__, lifespan=_lifespan)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session token for protecting sensitive endpoints (reveal).
|
||||
|
|
@ -6342,8 +6380,7 @@ def _ws_auth_ok(ws: "WebSocket") -> bool:
|
|||
# and /api/events (dashboard → browser sidebar). Keyed by an opaque channel id
|
||||
# the chat tab generates on mount; entries auto-evict when the last subscriber
|
||||
# drops AND the publisher has disconnected.
|
||||
_event_channels: dict[str, set] = {}
|
||||
_event_lock = asyncio.Lock()
|
||||
# (State is initialised in _lifespan on app startup — see above.)
|
||||
|
||||
|
||||
def _resolve_chat_argv(
|
||||
|
|
@ -6452,10 +6489,11 @@ def _build_sidecar_url(channel: str) -> Optional[str]:
|
|||
return f"ws://{netloc}/api/pub?{qs}"
|
||||
|
||||
|
||||
async def _broadcast_event(channel: str, payload: str) -> None:
|
||||
async def _broadcast_event(app: Any, channel: str, payload: str) -> None:
|
||||
"""Fan out one publisher frame to every subscriber on `channel`."""
|
||||
async with _event_lock:
|
||||
subs = list(_event_channels.get(channel, ()))
|
||||
event_channels, event_lock = _get_event_state(app)
|
||||
async with event_lock:
|
||||
subs = list(event_channels.get(channel, ()))
|
||||
|
||||
for sub in subs:
|
||||
try:
|
||||
|
|
@ -6646,7 +6684,7 @@ async def pub_ws(ws: WebSocket) -> None:
|
|||
|
||||
try:
|
||||
while True:
|
||||
await _broadcast_event(channel, await ws.receive_text())
|
||||
await _broadcast_event(ws.app, channel, await ws.receive_text())
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
|
|
@ -6672,8 +6710,9 @@ async def events_ws(ws: WebSocket) -> None:
|
|||
|
||||
await ws.accept()
|
||||
|
||||
async with _event_lock:
|
||||
_event_channels.setdefault(channel, set()).add(ws)
|
||||
event_channels, event_lock = _get_event_state(ws.app)
|
||||
async with event_lock:
|
||||
event_channels.setdefault(channel, set()).add(ws)
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
|
@ -6684,14 +6723,14 @@ async def events_ws(ws: WebSocket) -> None:
|
|||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
async with _event_lock:
|
||||
subs = _event_channels.get(channel)
|
||||
async with event_lock:
|
||||
subs = event_channels.get(channel)
|
||||
|
||||
if subs is not None:
|
||||
subs.discard(ws)
|
||||
|
||||
if not subs:
|
||||
_event_channels.pop(channel, None)
|
||||
event_channels.pop(channel, None)
|
||||
|
||||
|
||||
def _normalise_prefix(raw: Optional[str]) -> str:
|
||||
|
|
|
|||
|
|
@ -3415,7 +3415,7 @@ class TestPtyWebSocket:
|
|||
# subscriber registration and the message is dropped.
|
||||
deadline = time.monotonic() + 5.0
|
||||
while time.monotonic() < deadline:
|
||||
if ws_mod._event_channels.get("broadcast-test"):
|
||||
if ws_mod.app.state.event_channels.get("broadcast-test"):
|
||||
break
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue