diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index c32bb94b868..2a5e7a213fe 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -2372,6 +2372,7 @@ def _make_xai_callback_handler(expected_path: str) -> tuple[type[BaseHTTPRequest "error": None, "error_description": None, } + result_lock = threading.Lock() class _XAICallbackHandler(BaseHTTPRequestHandler): def _maybe_write_cors_headers(self) -> None: @@ -2398,16 +2399,27 @@ def _make_xai_callback_handler(expected_path: str) -> tuple[type[BaseHTTPRequest return params = parse_qs(parsed.query) - result["code"] = params.get("code", [None])[0] - result["state"] = params.get("state", [None])[0] - result["error"] = params.get("error", [None])[0] - result["error_description"] = params.get("error_description", [None])[0] + incoming = { + "code": params.get("code", [None])[0], + "state": params.get("state", [None])[0], + "error": params.get("error", [None])[0], + "error_description": params.get("error_description", [None])[0], + } + # ThreadingHTTPServer allows a fallback/manual callback to complete + # while a browser connection is stuck. Once we have a terminal + # OAuth result (code or error), keep the first one so a later + # concurrent/invalid callback cannot overwrite state before + # validation in _xai_oauth_loopback_login(). + if incoming["code"] or incoming["error"]: + with result_lock: + if not (result["code"] or result["error"]): + result.update(incoming) self.send_response(200) self._maybe_write_cors_headers() self.send_header("Content-Type", "text/html; charset=utf-8") self.end_headers() - if result["error"]: + if incoming["error"]: body = "