mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-23 10:42:00 +00:00
feat(cli): add native Antigravity OAuth provider
This commit is contained in:
parent
29176ffecf
commit
8baa4e9976
25 changed files with 2371 additions and 18 deletions
|
|
@ -1394,6 +1394,21 @@ def create_openai_client(agent, client_kwargs: dict, *, reason: str, shared: boo
|
|||
agent._client_log_context(),
|
||||
)
|
||||
return client
|
||||
if agent.provider == "google-antigravity" or str(client_kwargs.get("base_url", "")).startswith("antigravity-pa://"):
|
||||
from agent.antigravity_cloudcode_adapter import AntigravityCloudCodeClient
|
||||
|
||||
safe_kwargs = {
|
||||
k: v for k, v in client_kwargs.items()
|
||||
if k in {"api_key", "base_url", "default_headers", "project_id", "timeout"}
|
||||
}
|
||||
client = AntigravityCloudCodeClient(**safe_kwargs)
|
||||
_ra().logger.info(
|
||||
"Antigravity Code Assist client created (%s, shared=%s) %s",
|
||||
reason,
|
||||
shared,
|
||||
agent._client_log_context(),
|
||||
)
|
||||
return client
|
||||
if agent.provider == "gemini":
|
||||
from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url
|
||||
|
||||
|
|
|
|||
164
agent/antigravity_cloudcode_adapter.py
Normal file
164
agent/antigravity_cloudcode_adapter.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""OpenAI-compatible facade for Antigravity native OAuth inference."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from agent import antigravity_oauth
|
||||
from agent.antigravity_code_assist import (
|
||||
ANTIGRAVITY_CODE_ASSIST_ENDPOINT,
|
||||
CodeAssistError,
|
||||
ProjectContext,
|
||||
build_headers,
|
||||
resolve_project_context,
|
||||
)
|
||||
from agent.gemini_cloudcode_adapter import (
|
||||
GeminiCloudCodeClient,
|
||||
_GeminiStreamChunk,
|
||||
_gemini_http_error,
|
||||
_iter_sse_events,
|
||||
_translate_gemini_response,
|
||||
_translate_stream_event,
|
||||
build_gemini_request,
|
||||
wrap_code_assist_request,
|
||||
)
|
||||
|
||||
MARKER_BASE_URL = "antigravity-pa://google"
|
||||
|
||||
|
||||
class AntigravityCloudCodeClient(GeminiCloudCodeClient):
|
||||
"""Minimal OpenAI-SDK-compatible facade over Antigravity Code Assist."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
project_id: str = "",
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or "antigravity-oauth",
|
||||
base_url=base_url or MARKER_BASE_URL,
|
||||
default_headers=default_headers,
|
||||
project_id=project_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _ensure_project_context(self, access_token: str, model: str) -> ProjectContext:
|
||||
if self._project_context is not None:
|
||||
return self._project_context # type: ignore[return-value]
|
||||
|
||||
env_project = antigravity_oauth.resolve_project_id_from_env()
|
||||
creds = antigravity_oauth.load_credentials()
|
||||
stored_project = creds.project_id if creds else ""
|
||||
if stored_project:
|
||||
self._project_context = ProjectContext(
|
||||
project_id=stored_project,
|
||||
managed_project_id=creds.managed_project_id if creds else "",
|
||||
source="stored",
|
||||
)
|
||||
return self._project_context
|
||||
|
||||
ctx = resolve_project_context(
|
||||
access_token,
|
||||
configured_project_id=self._configured_project_id,
|
||||
env_project_id=env_project,
|
||||
)
|
||||
if ctx.project_id or ctx.managed_project_id:
|
||||
antigravity_oauth.update_project_ids(
|
||||
project_id=ctx.project_id,
|
||||
managed_project_id=ctx.managed_project_id,
|
||||
)
|
||||
self._project_context = ctx
|
||||
return ctx
|
||||
|
||||
def _create_chat_completion(
|
||||
self,
|
||||
*,
|
||||
model: str = "gemini-3-flash-agent",
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
stream: bool = False,
|
||||
tools: Any = None,
|
||||
tool_choice: Any = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
stop: Any = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Any = None,
|
||||
**_: Any,
|
||||
) -> Any:
|
||||
access_token = antigravity_oauth.get_valid_access_token()
|
||||
ctx = self._ensure_project_context(access_token, model)
|
||||
|
||||
thinking_config = None
|
||||
if isinstance(extra_body, dict):
|
||||
thinking_config = extra_body.get("thinking_config") or extra_body.get("thinkingConfig")
|
||||
|
||||
inner = build_gemini_request(
|
||||
messages=messages or [],
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
thinking_config=thinking_config,
|
||||
)
|
||||
wrapped = wrap_code_assist_request(
|
||||
project_id=ctx.project_id,
|
||||
model=model,
|
||||
inner_request=inner,
|
||||
)
|
||||
|
||||
headers = build_headers(access_token)
|
||||
headers.update(self._default_headers)
|
||||
|
||||
if stream:
|
||||
return self._stream_completion(model=model, wrapped=wrapped, headers=headers)
|
||||
|
||||
url = f"{ANTIGRAVITY_CODE_ASSIST_ENDPOINT}/v1internal:generateContent"
|
||||
response = self._http.post(url, json=wrapped, headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise _gemini_http_error(response)
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError as exc:
|
||||
raise CodeAssistError(
|
||||
f"Invalid JSON from Antigravity Code Assist: {exc}",
|
||||
code="antigravity_code_assist_invalid_json",
|
||||
) from exc
|
||||
return _translate_gemini_response(payload, model=model)
|
||||
|
||||
def _stream_completion(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
wrapped: Dict[str, Any],
|
||||
headers: Dict[str, str],
|
||||
) -> Iterator[_GeminiStreamChunk]:
|
||||
url = f"{ANTIGRAVITY_CODE_ASSIST_ENDPOINT}/v1internal:streamGenerateContent?alt=sse"
|
||||
stream_headers = dict(headers)
|
||||
stream_headers["Accept"] = "text/event-stream"
|
||||
|
||||
def _generator() -> Iterator[_GeminiStreamChunk]:
|
||||
try:
|
||||
with self._http.stream("POST", url, json=wrapped, headers=stream_headers) as response:
|
||||
if response.status_code != 200:
|
||||
response.read()
|
||||
raise _gemini_http_error(response)
|
||||
tool_call_counter: List[int] = [0]
|
||||
for event in _iter_sse_events(response):
|
||||
for chunk in _translate_stream_event(event, model, tool_call_counter):
|
||||
yield chunk
|
||||
except httpx.HTTPError as exc:
|
||||
raise CodeAssistError(
|
||||
f"Antigravity streaming request failed: {exc}",
|
||||
code="antigravity_code_assist_stream_error",
|
||||
) from exc
|
||||
|
||||
return _generator()
|
||||
276
agent/antigravity_code_assist.py
Normal file
276
agent/antigravity_code_assist.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
"""Antigravity Code Assist control-plane helpers.
|
||||
|
||||
The new Antigravity CLI uses the same v1internal Code Assist family as
|
||||
gemini-cli, but with Antigravity OAuth scopes, metadata and model catalog. This
|
||||
module keeps that provider-specific surface separate from
|
||||
``agent.google_code_assist``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from agent.google_code_assist import CodeAssistError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ANTIGRAVITY_CODE_ASSIST_ENDPOINT = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
ANTIGRAVITY_MODEL_ENDPOINTS = [
|
||||
ANTIGRAVITY_CODE_ASSIST_ENDPOINT,
|
||||
"https://cloudcode-pa.googleapis.com",
|
||||
"https://autopush-cloudcode-pa.sandbox.googleapis.com",
|
||||
]
|
||||
|
||||
ANTIGRAVITY_CLIENT_METADATA = {
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
}
|
||||
ANTIGRAVITY_USER_AGENT = "antigravity/1.0.0 windows/amd64"
|
||||
ANTIGRAVITY_X_GOOG_API_CLIENT = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||
|
||||
DEFAULT_AGENT_MODEL_IDS = [
|
||||
"gemini-3-flash-agent",
|
||||
"gemini-3.5-flash-low",
|
||||
"gemini-pro-agent",
|
||||
"gemini-3.1-pro-low",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-opus-4-6-thinking",
|
||||
"gpt-oss-120b-medium",
|
||||
]
|
||||
|
||||
DEPRECATED_MODEL_REPLACEMENTS = {
|
||||
"gemini-3.1-pro-high": "gemini-pro-agent",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AntigravityProjectInfo:
|
||||
project_id: str = ""
|
||||
raw: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProjectContext:
|
||||
project_id: str = ""
|
||||
managed_project_id: str = ""
|
||||
tier_id: str = ""
|
||||
source: str = ""
|
||||
|
||||
|
||||
def _client_metadata() -> Dict[str, str]:
|
||||
return dict(ANTIGRAVITY_CLIENT_METADATA)
|
||||
|
||||
|
||||
def build_headers(access_token: str, *, accept: str = "application/json") -> Dict[str, str]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": accept,
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"User-Agent": ANTIGRAVITY_USER_AGENT,
|
||||
"X-Goog-Api-Client": ANTIGRAVITY_X_GOOG_API_CLIENT,
|
||||
"Client-Metadata": json.dumps(_client_metadata(), separators=(",", ":")),
|
||||
"x-activity-request-id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
|
||||
def _post_json(
|
||||
url: str,
|
||||
body: Dict[str, Any],
|
||||
access_token: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
) -> Dict[str, Any]:
|
||||
data = json.dumps(body).encode("utf-8")
|
||||
request = urllib.request.Request(
|
||||
url,
|
||||
data=data,
|
||||
method="POST",
|
||||
headers=build_headers(access_token),
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=timeout) as response:
|
||||
raw = response.read().decode("utf-8", errors="replace")
|
||||
return json.loads(raw) if raw else {}
|
||||
except urllib.error.HTTPError as exc:
|
||||
detail = ""
|
||||
try:
|
||||
detail = exc.read().decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
pass
|
||||
raise CodeAssistError(
|
||||
f"Antigravity Code Assist HTTP {exc.code}: {detail or exc.reason}",
|
||||
code=f"antigravity_code_assist_http_{exc.code}",
|
||||
) from exc
|
||||
except urllib.error.URLError as exc:
|
||||
raise CodeAssistError(
|
||||
f"Antigravity Code Assist request failed: {exc}",
|
||||
code="antigravity_code_assist_network_error",
|
||||
) from exc
|
||||
|
||||
|
||||
def load_code_assist(
|
||||
access_token: str,
|
||||
*,
|
||||
project_id: str = "",
|
||||
endpoint: str = ANTIGRAVITY_CODE_ASSIST_ENDPOINT,
|
||||
) -> AntigravityProjectInfo:
|
||||
metadata = _client_metadata()
|
||||
if project_id:
|
||||
metadata["duetProject"] = project_id
|
||||
body: Dict[str, Any] = {"metadata": metadata}
|
||||
if project_id:
|
||||
body["cloudaicompanionProject"] = project_id
|
||||
resp = _post_json(f"{endpoint}/v1internal:loadCodeAssist", body, access_token)
|
||||
project = (
|
||||
str(resp.get("cloudaicompanionProject") or "").strip()
|
||||
or str(resp.get("project") or "").strip()
|
||||
)
|
||||
return AntigravityProjectInfo(project_id=project, raw=resp)
|
||||
|
||||
|
||||
def resolve_project_context(
|
||||
access_token: str,
|
||||
*,
|
||||
configured_project_id: str = "",
|
||||
env_project_id: str = "",
|
||||
) -> ProjectContext:
|
||||
if configured_project_id:
|
||||
return ProjectContext(project_id=configured_project_id, source="config")
|
||||
if env_project_id:
|
||||
return ProjectContext(project_id=env_project_id, source="env")
|
||||
info = load_code_assist(access_token)
|
||||
return ProjectContext(
|
||||
project_id=info.project_id,
|
||||
managed_project_id=info.project_id,
|
||||
source="discovered" if info.project_id else "unknown",
|
||||
)
|
||||
|
||||
|
||||
def fetch_available_models(
|
||||
access_token: str,
|
||||
*,
|
||||
project_id: str = "",
|
||||
endpoint: str = ANTIGRAVITY_CODE_ASSIST_ENDPOINT,
|
||||
) -> Dict[str, Any]:
|
||||
body: Dict[str, Any] = {}
|
||||
if project_id:
|
||||
body["project"] = project_id
|
||||
return _post_json(f"{endpoint}/v1internal:fetchAvailableModels", body, access_token)
|
||||
|
||||
|
||||
def fetch_available_models_with_fallbacks(
|
||||
access_token: str,
|
||||
*,
|
||||
project_id: str = "",
|
||||
endpoints: Optional[Iterable[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
last_err: Optional[Exception] = None
|
||||
for endpoint in endpoints or ANTIGRAVITY_MODEL_ENDPOINTS:
|
||||
try:
|
||||
return fetch_available_models(
|
||||
access_token,
|
||||
project_id=project_id,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
except Exception as exc:
|
||||
last_err = exc
|
||||
logger.debug("Antigravity fetchAvailableModels failed on %s: %s", endpoint, exc)
|
||||
if last_err:
|
||||
raise last_err
|
||||
return {}
|
||||
|
||||
|
||||
def _model_id_from_value(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
for key in ("modelId", "model_id", "id", "name"):
|
||||
candidate = str(value.get(key) or "").strip()
|
||||
if candidate:
|
||||
return candidate
|
||||
return ""
|
||||
|
||||
|
||||
def _ids_from_sort(sort: Dict[str, Any]) -> List[str]:
|
||||
ids: List[str] = []
|
||||
for key in ("modelIds", "model_ids", "models", "modelSorts"):
|
||||
value = sort.get(key)
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
mid = _model_id_from_value(item)
|
||||
if mid:
|
||||
ids.append(mid)
|
||||
elif isinstance(value, dict):
|
||||
mid = _model_id_from_value(value)
|
||||
if mid:
|
||||
ids.append(mid)
|
||||
return ids
|
||||
|
||||
|
||||
def _is_recommended_sort(sort: Dict[str, Any]) -> bool:
|
||||
label = " ".join(
|
||||
str(sort.get(key) or "")
|
||||
for key in ("name", "displayName", "title", "category", "group")
|
||||
).lower()
|
||||
return "recommended" in label
|
||||
|
||||
|
||||
def _raw_model_ids(payload: Dict[str, Any]) -> List[str]:
|
||||
ids: List[str] = []
|
||||
models = payload.get("models")
|
||||
if isinstance(models, list):
|
||||
for item in models:
|
||||
mid = _model_id_from_value(item)
|
||||
if mid:
|
||||
ids.append(mid)
|
||||
return ids
|
||||
|
||||
|
||||
def filter_agent_model_ids(ids: Iterable[str]) -> List[str]:
|
||||
seen: set[str] = set()
|
||||
filtered: List[str] = []
|
||||
raw = [str(mid).strip() for mid in ids if str(mid).strip()]
|
||||
replacements = set(DEPRECATED_MODEL_REPLACEMENTS.values())
|
||||
for mid in raw:
|
||||
if mid in seen:
|
||||
continue
|
||||
if mid.startswith(("chat_", "tab_")):
|
||||
continue
|
||||
if mid in DEPRECATED_MODEL_REPLACEMENTS and DEPRECATED_MODEL_REPLACEMENTS[mid] in raw:
|
||||
continue
|
||||
if mid in replacements and mid in seen:
|
||||
continue
|
||||
seen.add(mid)
|
||||
filtered.append(mid)
|
||||
return filtered
|
||||
|
||||
|
||||
def parse_agent_model_ids(payload: Dict[str, Any]) -> List[str]:
|
||||
"""Return the user-facing Antigravity agent model list in display order."""
|
||||
sorts = payload.get("agentModelSorts")
|
||||
ordered: List[str] = []
|
||||
if isinstance(sorts, list):
|
||||
recommended = [s for s in sorts if isinstance(s, dict) and _is_recommended_sort(s)]
|
||||
rest = [s for s in sorts if isinstance(s, dict) and not _is_recommended_sort(s)]
|
||||
for sort in recommended + rest:
|
||||
ordered.extend(_ids_from_sort(sort))
|
||||
|
||||
if not ordered:
|
||||
default_id = str(payload.get("defaultAgentModelId") or "").strip()
|
||||
if default_id:
|
||||
ordered.append(default_id)
|
||||
for mid in DEFAULT_AGENT_MODEL_IDS:
|
||||
ordered.append(mid)
|
||||
ordered.extend(_raw_model_ids(payload))
|
||||
|
||||
filtered = filter_agent_model_ids(ordered)
|
||||
if filtered:
|
||||
return filtered
|
||||
return list(DEFAULT_AGENT_MODEL_IDS)
|
||||
872
agent/antigravity_oauth.py
Normal file
872
agent/antigravity_oauth.py
Normal file
|
|
@ -0,0 +1,872 @@
|
|||
"""Google OAuth PKCE flow for the Antigravity (google-antigravity) provider.
|
||||
|
||||
Tokens are stored separately from the existing ``google-gemini-cli`` provider so
|
||||
development and production credentials do not accidentally bleed across:
|
||||
|
||||
~/.hermes/auth/antigravity_oauth.json
|
||||
|
||||
The on-disk schema matches ``agent.google_oauth`` so the runtime resolver can
|
||||
share the same refresh/project-id packing convention.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import hashlib
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import shutil
|
||||
import stat
|
||||
import threading
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import webbrowser
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from utils import atomic_replace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENV_CLIENT_ID = "HERMES_ANTIGRAVITY_CLIENT_ID"
|
||||
ENV_CLIENT_SECRET = "HERMES_ANTIGRAVITY_CLIENT_SECRET"
|
||||
ENV_CLI_PATH = "HERMES_ANTIGRAVITY_CLI_PATH"
|
||||
|
||||
_CLIENT_ID_PATTERN = re.compile(
|
||||
r"([0-9]{8,}-[a-z0-9]{20,}\.apps\.googleusercontent\.com)"
|
||||
)
|
||||
_CLIENT_SECRET_PATTERN = re.compile(r"(GOCSPX-[A-Za-z0-9_-]{20,80})")
|
||||
_DISCOVERY_MAX_FILE_BYTES = 25 * 1024 * 1024
|
||||
_DISCOVERY_MAX_AGY_BINARY_BYTES = 220 * 1024 * 1024
|
||||
_DISCOVERY_MAX_FILES = 600
|
||||
_DISCOVERY_EXTENSIONS = {
|
||||
"",
|
||||
".cjs",
|
||||
".exe",
|
||||
".js",
|
||||
".json",
|
||||
".mjs",
|
||||
".node",
|
||||
".ts",
|
||||
}
|
||||
_DISCOVERY_SKIP_DIRS = {
|
||||
".system_generated",
|
||||
"brain",
|
||||
"conversations",
|
||||
"log",
|
||||
"logs",
|
||||
"scratch",
|
||||
}
|
||||
|
||||
AUTH_ENDPOINT = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"
|
||||
USERINFO_ENDPOINT = "https://www.googleapis.com/oauth2/v1/userinfo"
|
||||
|
||||
OAUTH_SCOPES = (
|
||||
"https://www.googleapis.com/auth/cloud-platform "
|
||||
"https://www.googleapis.com/auth/userinfo.email "
|
||||
"https://www.googleapis.com/auth/userinfo.profile "
|
||||
"https://www.googleapis.com/auth/cclog "
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
)
|
||||
|
||||
DEFAULT_REDIRECT_PORT = 51121
|
||||
REDIRECT_HOST = "localhost"
|
||||
CALLBACK_PATH = "/oauth-callback"
|
||||
REFRESH_SKEW_SECONDS = 60
|
||||
TOKEN_REQUEST_TIMEOUT_SECONDS = 20.0
|
||||
CALLBACK_WAIT_SECONDS = 300
|
||||
LOCK_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
|
||||
class AntigravityOAuthError(RuntimeError):
|
||||
def __init__(self, message: str, *, code: str = "antigravity_oauth_error") -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
|
||||
|
||||
def _credentials_path() -> Path:
|
||||
return get_hermes_home() / "auth" / "antigravity_oauth.json"
|
||||
|
||||
|
||||
def _lock_path() -> Path:
|
||||
return _credentials_path().with_suffix(".json.lock")
|
||||
|
||||
|
||||
_lock_state = threading.local()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _credentials_lock(timeout_seconds: float = LOCK_TIMEOUT_SECONDS):
|
||||
depth = getattr(_lock_state, "depth", 0)
|
||||
if depth > 0:
|
||||
_lock_state.depth = depth + 1
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_lock_state.depth -= 1
|
||||
return
|
||||
|
||||
lock_file_path = _lock_path()
|
||||
lock_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = os.open(str(lock_file_path), os.O_CREAT | os.O_RDWR, 0o600)
|
||||
acquired = False
|
||||
try:
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError:
|
||||
fcntl = None
|
||||
|
||||
if fcntl is not None:
|
||||
deadline = time.monotonic() + max(0.0, float(timeout_seconds))
|
||||
while True:
|
||||
try:
|
||||
fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
acquired = True
|
||||
break
|
||||
except BlockingIOError:
|
||||
if time.monotonic() >= deadline:
|
||||
raise TimeoutError(
|
||||
f"Timed out acquiring Antigravity OAuth credentials lock at {lock_file_path}."
|
||||
)
|
||||
time.sleep(0.05)
|
||||
else:
|
||||
try:
|
||||
import msvcrt # type: ignore[import-not-found]
|
||||
|
||||
deadline = time.monotonic() + max(0.0, float(timeout_seconds))
|
||||
while True:
|
||||
try:
|
||||
msvcrt.locking(fd, msvcrt.LK_NBLCK, 1)
|
||||
acquired = True
|
||||
break
|
||||
except OSError:
|
||||
if time.monotonic() >= deadline:
|
||||
raise TimeoutError(
|
||||
f"Timed out acquiring Antigravity OAuth credentials lock at {lock_file_path}."
|
||||
)
|
||||
time.sleep(0.05)
|
||||
except ImportError:
|
||||
acquired = True
|
||||
|
||||
_lock_state.depth = 1
|
||||
yield
|
||||
finally:
|
||||
try:
|
||||
if acquired:
|
||||
try:
|
||||
import fcntl
|
||||
|
||||
fcntl.flock(fd, fcntl.LOCK_UN)
|
||||
except ImportError:
|
||||
try:
|
||||
import msvcrt # type: ignore[import-not-found]
|
||||
|
||||
try:
|
||||
msvcrt.locking(fd, msvcrt.LK_UNLCK, 1)
|
||||
except OSError:
|
||||
pass
|
||||
except ImportError:
|
||||
pass
|
||||
finally:
|
||||
os.close(fd)
|
||||
_lock_state.depth = 0
|
||||
|
||||
|
||||
_discovered_creds_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _secret_candidates(raw: str) -> list[str]:
|
||||
candidates: list[str] = []
|
||||
for length in (35, 34, 36, 33, 37, 38, 39, 40, 41, 42):
|
||||
if len(raw) >= length:
|
||||
candidates.append(raw[:length])
|
||||
candidates.append(raw)
|
||||
return list(dict.fromkeys(candidates))
|
||||
|
||||
|
||||
def _candidate_discovery_roots() -> list[Path]:
|
||||
roots: list[Path] = []
|
||||
|
||||
explicit = (os.getenv(ENV_CLI_PATH) or "").strip()
|
||||
if explicit:
|
||||
roots.append(Path(explicit))
|
||||
|
||||
for command in ("agy", "agy.exe", "antigravity", "antigravity.exe"):
|
||||
found = shutil.which(command)
|
||||
if found:
|
||||
roots.append(Path(found))
|
||||
|
||||
for env_key in ("LOCALAPPDATA", "APPDATA", "ProgramFiles", "ProgramFiles(x86)"):
|
||||
base = os.getenv(env_key)
|
||||
if not base:
|
||||
continue
|
||||
base_path = Path(base)
|
||||
roots.extend((
|
||||
base_path / "agy",
|
||||
base_path / "agy" / "bin" / "agy.exe",
|
||||
base_path / "Programs" / "Antigravity",
|
||||
base_path / "Programs" / "Antigravity CLI",
|
||||
base_path / "Google" / "Antigravity",
|
||||
base_path / "Google" / "Antigravity CLI",
|
||||
))
|
||||
|
||||
home = Path.home()
|
||||
for root in (
|
||||
home / ".gemini" / "antigravity-cli",
|
||||
home / ".antigravitycli",
|
||||
home / ".antigravity",
|
||||
):
|
||||
roots.append(root)
|
||||
|
||||
unique: list[Path] = []
|
||||
seen: set[str] = set()
|
||||
for root in roots:
|
||||
try:
|
||||
key = str(root.expanduser().resolve())
|
||||
except OSError:
|
||||
key = str(root.expanduser())
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique.append(root)
|
||||
return unique
|
||||
|
||||
|
||||
def _iter_discovery_files() -> list[Path]:
|
||||
files: list[Path] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
def add(path: Path) -> None:
|
||||
if len(files) >= _DISCOVERY_MAX_FILES:
|
||||
return
|
||||
if path.suffix.lower() not in _DISCOVERY_EXTENSIONS:
|
||||
return
|
||||
try:
|
||||
stat_info = path.stat()
|
||||
max_bytes = (
|
||||
_DISCOVERY_MAX_AGY_BINARY_BYTES
|
||||
if path.name.lower() in {"agy", "agy.exe", "antigravity", "antigravity.exe"}
|
||||
else _DISCOVERY_MAX_FILE_BYTES
|
||||
)
|
||||
if not path.is_file() or stat_info.st_size > max_bytes:
|
||||
return
|
||||
key = str(path.resolve())
|
||||
except OSError:
|
||||
return
|
||||
if key in seen:
|
||||
return
|
||||
seen.add(key)
|
||||
files.append(path)
|
||||
|
||||
for root in _candidate_discovery_roots():
|
||||
if len(files) >= _DISCOVERY_MAX_FILES:
|
||||
break
|
||||
try:
|
||||
if root.is_file():
|
||||
add(root)
|
||||
continue
|
||||
if not root.is_dir():
|
||||
continue
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
dirnames[:] = [
|
||||
d for d in dirnames
|
||||
if d not in _DISCOVERY_SKIP_DIRS and not d.startswith(".git")
|
||||
]
|
||||
for filename in filenames:
|
||||
add(Path(dirpath) / filename)
|
||||
if len(files) >= _DISCOVERY_MAX_FILES:
|
||||
break
|
||||
if len(files) >= _DISCOVERY_MAX_FILES:
|
||||
break
|
||||
return files
|
||||
|
||||
|
||||
def _extract_client_credential_candidates_from_text(content: str) -> list[Tuple[str, str]]:
|
||||
client_ids = list(dict.fromkeys(match.group(1) for match in _CLIENT_ID_PATTERN.finditer(content)))
|
||||
secrets: list[str] = []
|
||||
for match in _CLIENT_SECRET_PATTERN.finditer(content):
|
||||
secrets.extend(_secret_candidates(match.group(1)))
|
||||
secrets = list(dict.fromkeys(secrets))
|
||||
return [(client_id, secret) for client_id in client_ids for secret in secrets]
|
||||
|
||||
|
||||
def _discover_client_credentials() -> Tuple[str, str]:
|
||||
if _discovered_creds_cache.get("resolved"):
|
||||
return (
|
||||
_discovered_creds_cache.get("client_id", ""),
|
||||
_discovered_creds_cache.get("client_secret", ""),
|
||||
)
|
||||
|
||||
for path in _iter_discovery_files():
|
||||
try:
|
||||
content = path.read_bytes().decode("utf-8", errors="ignore")
|
||||
except OSError:
|
||||
continue
|
||||
candidates = _extract_client_credential_candidates_from_text(content)
|
||||
if candidates:
|
||||
client_id, client_secret = candidates[0]
|
||||
_discovered_creds_cache.update({
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"candidates": candidates,
|
||||
"resolved": "1",
|
||||
})
|
||||
logger.info("Discovered Antigravity OAuth client credentials from %s", path)
|
||||
return client_id, client_secret
|
||||
|
||||
_discovered_creds_cache["resolved"] = "1"
|
||||
return "", ""
|
||||
|
||||
|
||||
def _get_client_id() -> str:
|
||||
env_val = (os.getenv(ENV_CLIENT_ID) or "").strip()
|
||||
if env_val:
|
||||
return env_val
|
||||
discovered, _ = _discover_client_credentials()
|
||||
return discovered
|
||||
|
||||
|
||||
def _get_client_secret() -> str:
|
||||
env_val = (os.getenv(ENV_CLIENT_SECRET) or "").strip()
|
||||
if env_val:
|
||||
return env_val
|
||||
_, discovered = _discover_client_credentials()
|
||||
return discovered
|
||||
|
||||
|
||||
def _iter_client_credential_candidates() -> list[Tuple[str, str]]:
|
||||
env_id = (os.getenv(ENV_CLIENT_ID) or "").strip()
|
||||
env_secret = (os.getenv(ENV_CLIENT_SECRET) or "").strip()
|
||||
if env_id and env_secret:
|
||||
return [(env_id, env_secret)]
|
||||
|
||||
_discover_client_credentials()
|
||||
cached = _discovered_creds_cache.get("candidates")
|
||||
if isinstance(cached, list):
|
||||
return [
|
||||
(str(client_id), str(client_secret))
|
||||
for client_id, client_secret in cached
|
||||
if client_id and client_secret
|
||||
]
|
||||
client_id = str(_discovered_creds_cache.get("client_id") or "")
|
||||
client_secret = str(_discovered_creds_cache.get("client_secret") or "")
|
||||
return [(client_id, client_secret)] if client_id and client_secret else []
|
||||
|
||||
|
||||
def _require_client_id() -> str:
|
||||
client_id = _get_client_id()
|
||||
if not client_id:
|
||||
raise AntigravityOAuthError(
|
||||
"Antigravity OAuth client ID is not available. Install Antigravity CLI "
|
||||
"so Hermes can discover its desktop OAuth client, set "
|
||||
f"{ENV_CLI_PATH} to the agy executable, or set {ENV_CLIENT_ID} and "
|
||||
f"{ENV_CLIENT_SECRET} in ~/.hermes/.env.",
|
||||
code="antigravity_oauth_client_id_missing",
|
||||
)
|
||||
return client_id
|
||||
|
||||
|
||||
def _require_client_secret() -> str:
|
||||
client_secret = _get_client_secret()
|
||||
if not client_secret:
|
||||
raise AntigravityOAuthError(
|
||||
"Antigravity OAuth client secret is not available. Install Antigravity CLI "
|
||||
"so Hermes can discover its desktop OAuth client, set "
|
||||
f"{ENV_CLI_PATH} to the agy executable, or set {ENV_CLIENT_ID} and "
|
||||
f"{ENV_CLIENT_SECRET} in ~/.hermes/.env.",
|
||||
code="antigravity_oauth_client_secret_missing",
|
||||
)
|
||||
return client_secret
|
||||
|
||||
|
||||
def _require_client_credentials() -> Tuple[str, str]:
|
||||
candidates = _iter_client_credential_candidates()
|
||||
if not candidates:
|
||||
_require_client_id()
|
||||
_require_client_secret()
|
||||
return candidates[0]
|
||||
|
||||
|
||||
def _generate_pkce_pair() -> Tuple[str, str]:
|
||||
verifier = secrets.token_urlsafe(64)
|
||||
digest = hashlib.sha256(verifier.encode("ascii")).digest()
|
||||
challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
@dataclass
|
||||
class RefreshParts:
|
||||
refresh_token: str
|
||||
project_id: str = ""
|
||||
managed_project_id: str = ""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, packed: str) -> "RefreshParts":
|
||||
if not packed:
|
||||
return cls(refresh_token="")
|
||||
parts = packed.split("|", 2)
|
||||
return cls(
|
||||
refresh_token=parts[0],
|
||||
project_id=parts[1] if len(parts) > 1 else "",
|
||||
managed_project_id=parts[2] if len(parts) > 2 else "",
|
||||
)
|
||||
|
||||
def format(self) -> str:
|
||||
if not self.refresh_token:
|
||||
return ""
|
||||
if not self.project_id and not self.managed_project_id:
|
||||
return self.refresh_token
|
||||
return f"{self.refresh_token}|{self.project_id}|{self.managed_project_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AntigravityCredentials:
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_ms: int
|
||||
email: str = ""
|
||||
project_id: str = ""
|
||||
managed_project_id: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"refresh": RefreshParts(
|
||||
refresh_token=self.refresh_token,
|
||||
project_id=self.project_id,
|
||||
managed_project_id=self.managed_project_id,
|
||||
).format(),
|
||||
"access": self.access_token,
|
||||
"expires": int(self.expires_ms),
|
||||
"email": self.email,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AntigravityCredentials":
|
||||
parts = RefreshParts.parse(str(data.get("refresh", "") or ""))
|
||||
return cls(
|
||||
access_token=str(data.get("access", "") or ""),
|
||||
refresh_token=parts.refresh_token,
|
||||
expires_ms=int(data.get("expires", 0) or 0),
|
||||
email=str(data.get("email", "") or ""),
|
||||
project_id=parts.project_id,
|
||||
managed_project_id=parts.managed_project_id,
|
||||
)
|
||||
|
||||
def access_token_expired(self, skew_seconds: int = REFRESH_SKEW_SECONDS) -> bool:
|
||||
if not self.access_token or not self.expires_ms:
|
||||
return True
|
||||
return (time.time() + max(0, skew_seconds)) * 1000 >= self.expires_ms
|
||||
|
||||
|
||||
def load_credentials() -> Optional[AntigravityCredentials]:
|
||||
path = _credentials_path()
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
with _credentials_lock():
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
data = json.loads(raw)
|
||||
except (json.JSONDecodeError, OSError, IOError) as exc:
|
||||
logger.warning("Failed to read Antigravity OAuth credentials at %s: %s", path, exc)
|
||||
return None
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
creds = AntigravityCredentials.from_dict(data)
|
||||
if not creds.access_token:
|
||||
return None
|
||||
return creds
|
||||
|
||||
|
||||
def save_credentials(creds: AntigravityCredentials) -> Path:
|
||||
path = _credentials_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
os.chmod(path.parent, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
payload = json.dumps(creds.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
with _credentials_lock():
|
||||
tmp_path = path.with_suffix(f".tmp.{os.getpid()}.{secrets.token_hex(4)}")
|
||||
try:
|
||||
fd = os.open(
|
||||
str(tmp_path),
|
||||
os.O_WRONLY | os.O_CREAT | os.O_EXCL,
|
||||
stat.S_IRUSR | stat.S_IWUSR,
|
||||
)
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||
fh.write(payload)
|
||||
fh.flush()
|
||||
os.fsync(fh.fileno())
|
||||
atomic_replace(tmp_path, path)
|
||||
finally:
|
||||
try:
|
||||
if tmp_path.exists():
|
||||
tmp_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
return path
|
||||
|
||||
|
||||
def clear_credentials() -> None:
|
||||
path = _credentials_path()
|
||||
with _credentials_lock():
|
||||
try:
|
||||
path.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError as exc:
|
||||
logger.warning("Failed to remove Antigravity OAuth credentials at %s: %s", path, exc)
|
||||
|
||||
|
||||
def _post_form(url: str, data: Dict[str, str], timeout: float) -> Dict[str, Any]:
|
||||
body = urllib.parse.urlencode(data).encode("ascii")
|
||||
request = urllib.request.Request(
|
||||
url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=timeout) as response:
|
||||
raw = response.read().decode("utf-8", errors="replace")
|
||||
return json.loads(raw)
|
||||
except urllib.error.HTTPError as exc:
|
||||
detail = ""
|
||||
try:
|
||||
detail = exc.read().decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
pass
|
||||
code = "antigravity_oauth_token_http_error"
|
||||
if "invalid_grant" in detail.lower():
|
||||
code = "antigravity_oauth_invalid_grant"
|
||||
elif "invalid_client" in detail.lower():
|
||||
code = "antigravity_oauth_invalid_client"
|
||||
raise AntigravityOAuthError(
|
||||
f"Antigravity OAuth token endpoint returned HTTP {exc.code}: {detail or exc.reason}",
|
||||
code=code,
|
||||
) from exc
|
||||
except urllib.error.URLError as exc:
|
||||
raise AntigravityOAuthError(
|
||||
f"Antigravity OAuth token request failed: {exc}",
|
||||
code="antigravity_oauth_token_network_error",
|
||||
) from exc
|
||||
|
||||
|
||||
def exchange_code(
|
||||
code: str,
|
||||
verifier: str,
|
||||
redirect_uri: str,
|
||||
*,
|
||||
timeout: float = TOKEN_REQUEST_TIMEOUT_SECONDS,
|
||||
) -> Dict[str, Any]:
|
||||
last_error: Optional[AntigravityOAuthError] = None
|
||||
candidates = _iter_client_credential_candidates()
|
||||
if not candidates:
|
||||
candidates = [_require_client_credentials()]
|
||||
for client_id, client_secret in candidates:
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"code_verifier": verifier,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
try:
|
||||
return _post_form(TOKEN_ENDPOINT, data, timeout)
|
||||
except AntigravityOAuthError as exc:
|
||||
last_error = exc
|
||||
if exc.code != "antigravity_oauth_invalid_client":
|
||||
raise
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise AntigravityOAuthError(
|
||||
"Antigravity OAuth client credentials are unavailable.",
|
||||
code="antigravity_oauth_client_missing",
|
||||
)
|
||||
|
||||
|
||||
def refresh_access_token(
|
||||
refresh_token: str,
|
||||
*,
|
||||
timeout: float = TOKEN_REQUEST_TIMEOUT_SECONDS,
|
||||
) -> Dict[str, Any]:
|
||||
if not refresh_token:
|
||||
raise AntigravityOAuthError(
|
||||
"Cannot refresh: refresh_token is empty. Re-run OAuth login.",
|
||||
code="antigravity_oauth_refresh_token_missing",
|
||||
)
|
||||
last_error: Optional[AntigravityOAuthError] = None
|
||||
candidates = _iter_client_credential_candidates()
|
||||
if not candidates:
|
||||
candidates = [_require_client_credentials()]
|
||||
for client_id, client_secret in candidates:
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
}
|
||||
try:
|
||||
return _post_form(TOKEN_ENDPOINT, data, timeout)
|
||||
except AntigravityOAuthError as exc:
|
||||
last_error = exc
|
||||
if exc.code not in {
|
||||
"antigravity_oauth_invalid_client",
|
||||
"antigravity_oauth_invalid_grant",
|
||||
}:
|
||||
raise
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise AntigravityOAuthError(
|
||||
"Antigravity OAuth client credentials are unavailable.",
|
||||
code="antigravity_oauth_client_missing",
|
||||
)
|
||||
|
||||
|
||||
def _fetch_user_email(access_token: str, timeout: float = TOKEN_REQUEST_TIMEOUT_SECONDS) -> str:
|
||||
try:
|
||||
request = urllib.request.Request(
|
||||
USERINFO_ENDPOINT + "?alt=json",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
with urllib.request.urlopen(request, timeout=timeout) as response:
|
||||
raw = response.read().decode("utf-8", errors="replace")
|
||||
data = json.loads(raw)
|
||||
return str(data.get("email", "") or "")
|
||||
except Exception as exc:
|
||||
logger.debug("Antigravity userinfo fetch failed (non-fatal): %s", exc)
|
||||
return ""
|
||||
|
||||
|
||||
_refresh_inflight: Dict[str, threading.Event] = {}
|
||||
_refresh_inflight_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_valid_access_token(*, force_refresh: bool = False) -> str:
|
||||
creds = load_credentials()
|
||||
if creds is None:
|
||||
raise AntigravityOAuthError(
|
||||
"No Antigravity OAuth credentials found. Run `hermes login --provider google-antigravity` first.",
|
||||
code="antigravity_oauth_not_logged_in",
|
||||
)
|
||||
if not force_refresh and not creds.access_token_expired():
|
||||
return creds.access_token
|
||||
|
||||
rt = creds.refresh_token
|
||||
with _refresh_inflight_lock:
|
||||
event = _refresh_inflight.get(rt)
|
||||
if event is None:
|
||||
event = threading.Event()
|
||||
_refresh_inflight[rt] = event
|
||||
owner = True
|
||||
else:
|
||||
owner = False
|
||||
|
||||
if not owner:
|
||||
event.wait(timeout=LOCK_TIMEOUT_SECONDS)
|
||||
fresh = load_credentials()
|
||||
if fresh is not None and not fresh.access_token_expired():
|
||||
return fresh.access_token
|
||||
|
||||
try:
|
||||
try:
|
||||
resp = refresh_access_token(rt)
|
||||
except AntigravityOAuthError as exc:
|
||||
if exc.code == "antigravity_oauth_invalid_grant":
|
||||
clear_credentials()
|
||||
raise
|
||||
new_access = str(resp.get("access_token", "") or "").strip()
|
||||
if not new_access:
|
||||
raise AntigravityOAuthError(
|
||||
"Refresh response did not include an access_token.",
|
||||
code="antigravity_oauth_refresh_empty",
|
||||
)
|
||||
creds.access_token = new_access
|
||||
creds.refresh_token = str(resp.get("refresh_token", "") or "").strip() or creds.refresh_token
|
||||
expires_in = int(resp.get("expires_in", 0) or 0)
|
||||
creds.expires_ms = int((time.time() + max(60, expires_in)) * 1000)
|
||||
save_credentials(creds)
|
||||
return creds.access_token
|
||||
finally:
|
||||
if owner:
|
||||
with _refresh_inflight_lock:
|
||||
_refresh_inflight.pop(rt, None)
|
||||
event.set()
|
||||
|
||||
|
||||
def update_project_ids(project_id: str = "", managed_project_id: str = "") -> None:
|
||||
creds = load_credentials()
|
||||
if creds is None:
|
||||
return
|
||||
if project_id:
|
||||
creds.project_id = project_id
|
||||
if managed_project_id:
|
||||
creds.managed_project_id = managed_project_id
|
||||
save_credentials(creds)
|
||||
|
||||
|
||||
class _OAuthCallbackHandler(http.server.BaseHTTPRequestHandler):
|
||||
expected_state: str = ""
|
||||
captured_code: Optional[str] = None
|
||||
captured_error: Optional[str] = None
|
||||
ready: Optional[threading.Event] = None
|
||||
|
||||
def log_message(self, format: str, *args: Any) -> None: # noqa: A002, N802
|
||||
logger.debug("Antigravity OAuth callback: " + format, *args)
|
||||
|
||||
def do_GET(self) -> None: # noqa: N802
|
||||
parsed = urllib.parse.urlparse(self.path)
|
||||
if parsed.path != CALLBACK_PATH:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
params = urllib.parse.parse_qs(parsed.query)
|
||||
state = (params.get("state") or [""])[0]
|
||||
error = (params.get("error") or [""])[0]
|
||||
code = (params.get("code") or [""])[0]
|
||||
|
||||
handler_cls = type(self)
|
||||
if state != self.expected_state:
|
||||
handler_cls.captured_error = "OAuth state mismatch."
|
||||
elif error:
|
||||
handler_cls.captured_error = error
|
||||
elif not code:
|
||||
handler_cls.captured_error = "OAuth callback did not include a code."
|
||||
else:
|
||||
handler_cls.captured_code = code
|
||||
|
||||
ok = not handler_cls.captured_error
|
||||
self.send_response(200 if ok else 400)
|
||||
self.send_header("Content-Type", "text/html; charset=utf-8")
|
||||
self.end_headers()
|
||||
msg = "Antigravity OAuth complete. You can return to Hermes." if ok else handler_cls.captured_error
|
||||
self.wfile.write(f"<html><body><p>{msg}</p></body></html>".encode("utf-8"))
|
||||
if handler_cls.ready is not None:
|
||||
handler_cls.ready.set()
|
||||
|
||||
|
||||
class _ReusableHTTPServer(http.server.HTTPServer):
|
||||
allow_reuse_address = True
|
||||
|
||||
|
||||
def resolve_project_id_from_env() -> str:
|
||||
for key in ("HERMES_ANTIGRAVITY_PROJECT_ID", "GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT_ID"):
|
||||
value = (os.getenv(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
def start_oauth_flow(
|
||||
*,
|
||||
force_relogin: bool = False,
|
||||
open_browser: bool = True,
|
||||
port: int = DEFAULT_REDIRECT_PORT,
|
||||
project_id: str = "",
|
||||
) -> AntigravityCredentials:
|
||||
if not force_relogin:
|
||||
existing = load_credentials()
|
||||
if existing and not existing.access_token_expired():
|
||||
return existing
|
||||
|
||||
verifier, challenge = _generate_pkce_pair()
|
||||
state = secrets.token_urlsafe(24)
|
||||
client_id, _ = _require_client_credentials()
|
||||
|
||||
ready = threading.Event()
|
||||
handler_cls = type("AntigravityOAuthCallbackHandler", (_OAuthCallbackHandler,), {})
|
||||
handler_cls.expected_state = state
|
||||
handler_cls.captured_code = None
|
||||
handler_cls.captured_error = None
|
||||
handler_cls.ready = ready
|
||||
|
||||
try:
|
||||
server = _ReusableHTTPServer((REDIRECT_HOST, int(port)), handler_cls)
|
||||
except OSError:
|
||||
server = _ReusableHTTPServer((REDIRECT_HOST, 0), handler_cls)
|
||||
actual_port = int(server.server_address[1])
|
||||
redirect_uri = f"http://{REDIRECT_HOST}:{actual_port}{CALLBACK_PATH}"
|
||||
|
||||
thread = threading.Thread(target=server.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
try:
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": OAUTH_SCOPES,
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"state": state,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
auth_url = AUTH_ENDPOINT + "?" + urllib.parse.urlencode(params)
|
||||
print("Open this URL to authorize Antigravity OAuth:")
|
||||
print(auth_url)
|
||||
if open_browser:
|
||||
webbrowser.open(auth_url)
|
||||
if not ready.wait(timeout=CALLBACK_WAIT_SECONDS):
|
||||
raise AntigravityOAuthError(
|
||||
"Timed out waiting for Antigravity OAuth callback.",
|
||||
code="antigravity_oauth_callback_timeout",
|
||||
)
|
||||
if handler_cls.captured_error:
|
||||
raise AntigravityOAuthError(
|
||||
handler_cls.captured_error,
|
||||
code="antigravity_oauth_callback_error",
|
||||
)
|
||||
code = handler_cls.captured_code or ""
|
||||
token = exchange_code(code, verifier, redirect_uri)
|
||||
finally:
|
||||
server.shutdown()
|
||||
server.server_close()
|
||||
|
||||
access_token = str(token.get("access_token", "") or "").strip()
|
||||
refresh_token = str(token.get("refresh_token", "") or "").strip()
|
||||
if not access_token or not refresh_token:
|
||||
raise AntigravityOAuthError(
|
||||
"Antigravity OAuth response did not include both access_token and refresh_token.",
|
||||
code="antigravity_oauth_missing_token",
|
||||
)
|
||||
expires_in = int(token.get("expires_in", 0) or 0)
|
||||
creds = AntigravityCredentials(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_ms=int((time.time() + max(60, expires_in)) * 1000),
|
||||
email=_fetch_user_email(access_token),
|
||||
project_id=project_id,
|
||||
)
|
||||
save_credentials(creds)
|
||||
return creds
|
||||
|
||||
|
||||
def run_antigravity_oauth_login_pure() -> Dict[str, Any]:
|
||||
creds = start_oauth_flow(
|
||||
force_relogin=True,
|
||||
project_id=resolve_project_id_from_env(),
|
||||
)
|
||||
return {
|
||||
"access_token": creds.access_token,
|
||||
"refresh_token": creds.refresh_token,
|
||||
"expires_at_ms": creds.expires_ms,
|
||||
"email": creds.email,
|
||||
"project_id": creds.project_id,
|
||||
}
|
||||
|
|
@ -93,11 +93,14 @@ def _translate_tool_call_to_gemini(tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
|||
args = {"_raw": args_raw}
|
||||
if not isinstance(args, dict):
|
||||
args = {"_value": args}
|
||||
function_call = {
|
||||
"name": fn.get("name") or "",
|
||||
"args": args,
|
||||
}
|
||||
if tool_call.get("id"):
|
||||
function_call["id"] = str(tool_call["id"])
|
||||
return {
|
||||
"functionCall": {
|
||||
"name": fn.get("name") or "",
|
||||
"args": args,
|
||||
},
|
||||
"functionCall": function_call,
|
||||
# Sentinel signature — matches opencode-gemini-auth's approach.
|
||||
# Without this, Code Assist rejects function calls that originated
|
||||
# outside its own chain.
|
||||
|
|
@ -122,12 +125,13 @@ def _translate_tool_result_to_gemini(message: Dict[str, Any]) -> Dict[str, Any]:
|
|||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
response = parsed if isinstance(parsed, dict) else {"output": content}
|
||||
return {
|
||||
"functionResponse": {
|
||||
"name": name,
|
||||
"response": response,
|
||||
},
|
||||
function_response = {
|
||||
"name": name,
|
||||
"response": response,
|
||||
}
|
||||
if message.get("tool_call_id"):
|
||||
function_response["id"] = str(message["tool_call_id"])
|
||||
return {"functionResponse": function_response}
|
||||
|
||||
|
||||
def _build_gemini_contents(
|
||||
|
|
@ -358,8 +362,9 @@ def _translate_gemini_response(
|
|||
args_str = json.dumps(fc.get("args") or {}, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
args_str = "{}"
|
||||
call_id = str(fc.get("id") or "").strip() or f"call_{uuid.uuid4().hex[:12]}"
|
||||
tool_calls.append(SimpleNamespace(
|
||||
id=f"call_{uuid.uuid4().hex[:12]}",
|
||||
id=call_id,
|
||||
type="function",
|
||||
index=i,
|
||||
function=SimpleNamespace(name=str(fc["name"]), arguments=args_str),
|
||||
|
|
@ -554,6 +559,7 @@ def _translate_stream_event(
|
|||
model=model,
|
||||
tool_call_delta={
|
||||
"index": idx,
|
||||
"id": str(fc.get("id") or "").strip(),
|
||||
"name": name,
|
||||
"arguments": args_str,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -437,7 +437,7 @@ class ChatCompletionsTransport(ProviderTransport):
|
|||
extra_body["extra_body"] = openai_compat_extra
|
||||
elif raw_thinking_config:
|
||||
extra_body["thinking_config"] = raw_thinking_config
|
||||
elif provider_name == "google-gemini-cli":
|
||||
elif provider_name in {"google-gemini-cli", "google-antigravity"}:
|
||||
thinking_config = _build_gemini_thinking_config(model, reasoning_config)
|
||||
if thinking_config:
|
||||
extra_body["thinking_config"] = thinking_config
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue