hermes-agent/agent/antigravity_code_assist.py
Teknium b7a912ea45 fix(antigravity): bake in public OAuth client + default project fallback
Salvage follow-up on top of @pmos69's #29474. The PR resolved the
Antigravity OAuth client purely by discovering it from an installed `agy`
binary or HERMES_ANTIGRAVITY_CLIENT_ID/SECRET env vars, so users without
agy installed hit a hard 'client ID not available' error.

Antigravity's desktop OAuth client is a public, non-confidential installed-app
client (PKCE provides the security), baked into every copy of the Antigravity
CLI — same posture as the gemini-cli credentials Hermes already ships in
google_oauth.py. Bake it in as the final fallback (env -> discovery -> public
default) and add the public default Code Assist project as the discovery
fallback, matching the reference Antigravity flow. Now consumers can
authenticate directly without agy installed.
2026-06-21 16:41:30 -07:00

286 lines
8.8 KiB
Python

"""Antigravity Code Assist control-plane helpers.
The new Antigravity CLI uses the same v1internal Code Assist family as
gemini-cli, but with Antigravity OAuth scopes, metadata and model catalog. This
module keeps that provider-specific surface separate from
``agent.google_code_assist``.
"""
from __future__ import annotations
import json
import logging
import urllib.error
import urllib.request
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional
from agent.google_code_assist import CodeAssistError
logger = logging.getLogger(__name__)
# Base URL used by default for loadCodeAssist and fetchAvailableModels calls.
ANTIGRAVITY_CODE_ASSIST_ENDPOINT = "https://daily-cloudcode-pa.sandbox.googleapis.com"

# Base URLs tried, in this order, by fetch_available_models_with_fallbacks().
ANTIGRAVITY_MODEL_ENDPOINTS = [
    ANTIGRAVITY_CODE_ASSIST_ENDPOINT,
    "https://cloudcode-pa.googleapis.com",
    "https://autopush-cloudcode-pa.sandbox.googleapis.com",
]

# Client description sent both as the JSON "Client-Metadata" header and as the
# "metadata" object of the loadCodeAssist request body.
ANTIGRAVITY_CLIENT_METADATA = {
    "ideType": "ANTIGRAVITY",
    "platform": "PLATFORM_UNSPECIFIED",
    "pluginType": "GEMINI",
}

# Fixed identification headers attached to every request by build_headers().
ANTIGRAVITY_USER_AGENT = "antigravity/1.0.0 windows/amd64"
ANTIGRAVITY_X_GOOG_API_CLIENT = "google-cloud-sdk vscode_cloudshelleditor/0.1"

# Model IDs used by parse_agent_model_ids() when the service payload does not
# provide a usable ordering (and as the final fallback list).
DEFAULT_AGENT_MODEL_IDS = [
    "gemini-3-flash-agent",
    "gemini-3.5-flash-low",
    "gemini-pro-agent",
    "gemini-3.1-pro-low",
    "claude-sonnet-4-6",
    "claude-opus-4-6-thinking",
    "gpt-oss-120b-medium",
]

# Maps a deprecated model ID to its replacement; filter_agent_model_ids()
# hides the deprecated ID whenever the replacement is also offered.
DEPRECATED_MODEL_REPLACEMENTS = {
    "gemini-3.1-pro-high": "gemini-pro-agent",
}
@dataclass
class AntigravityProjectInfo:
    """Result of a ``loadCodeAssist`` call.

    Both fields are optional at construction time; every instance gets its
    own ``raw`` dictionary.
    """

    # Project ID extracted from the response; empty when none was reported.
    project_id: str = ""
    # The decoded response payload the project ID was taken from.
    raw: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ProjectContext:
    """Project selection produced by ``resolve_project_context``.

    Every field defaults to the empty string.
    """

    # Project ID that requests should be issued against.
    project_id: str = ""
    # Filled in only when the project came from discovery or the default
    # fallback rather than from explicit configuration.
    managed_project_id: str = ""
    # Not populated by anything in this module.
    tier_id: str = ""
    # Where the project ID came from: "config", "env", "discovered" or
    # "default".
    source: str = ""
def _client_metadata() -> Dict[str, str]:
    """Return a fresh shallow copy of ``ANTIGRAVITY_CLIENT_METADATA``.

    Callers add keys to the result, so the shared module-level constant must
    never be handed out directly.
    """
    return {**ANTIGRAVITY_CLIENT_METADATA}
def build_headers(access_token: str, *, accept: str = "application/json") -> Dict[str, str]:
    """Build the HTTP headers for an Antigravity Code Assist request.

    Args:
        access_token: OAuth access token placed in the ``Authorization``
            header as a bearer token.
        accept: Value of the ``Accept`` header.

    Returns:
        A new header dictionary. ``x-activity-request-id`` carries a freshly
        generated UUID4, so two calls never return identical headers.
    """
    # Compact separators keep the JSON header value free of whitespace.
    metadata_json = json.dumps(_client_metadata(), separators=(",", ":"))
    request_id = str(uuid.uuid4())
    return {
        "Content-Type": "application/json",
        "Accept": accept,
        "Authorization": f"Bearer {access_token}",
        "User-Agent": ANTIGRAVITY_USER_AGENT,
        "X-Goog-Api-Client": ANTIGRAVITY_X_GOOG_API_CLIENT,
        "Client-Metadata": metadata_json,
        "x-activity-request-id": request_id,
    }
def _post_json(
    url: str,
    body: Dict[str, Any],
    access_token: str,
    *,
    timeout: float = 30.0,
) -> Dict[str, Any]:
    """POST ``body`` as JSON to ``url`` and return the decoded JSON object.

    Args:
        url: Full request URL.
        body: JSON-serialisable request payload.
        access_token: OAuth access token forwarded to ``build_headers``.
        timeout: Timeout in seconds passed to ``urllib.request.urlopen``.

    Returns:
        The decoded response object, or ``{}`` when the response body is
        empty.

    Raises:
        CodeAssistError: On an HTTP error status, on a network/socket
            failure, or when the response body is not a JSON object. Every
            failure mode surfaces as this one exception type so callers do
            not have to handle raw ``json`` or socket exceptions.
    """
    request = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        method="POST",
        headers=build_headers(access_token),
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            raw = response.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as exc:
        detail = ""
        try:
            detail = exc.read().decode("utf-8", errors="replace")
        except Exception:
            # Best effort only: the error body is optional context, so fall
            # back to the HTTP reason phrase when it cannot be read.
            pass
        raise CodeAssistError(
            f"Antigravity Code Assist HTTP {exc.code}: {detail or exc.reason}",
            code=f"antigravity_code_assist_http_{exc.code}",
        ) from exc
    except OSError as exc:
        # URLError is an OSError subclass. Catching the base class also covers
        # socket errors raised while reading the body (read timeouts,
        # connection resets), which urllib does not wrap in URLError.
        raise CodeAssistError(
            f"Antigravity Code Assist request failed: {exc}",
            code="antigravity_code_assist_network_error",
        ) from exc

    if not raw:
        return {}
    try:
        payload = json.loads(raw)
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        raise CodeAssistError(
            f"Antigravity Code Assist returned invalid JSON: {exc}",
            code="antigravity_code_assist_invalid_response",
        ) from exc
    if not isinstance(payload, dict):
        # Callers use dict methods on the result; reject other JSON types
        # here rather than failing later with an AttributeError.
        raise CodeAssistError(
            "Antigravity Code Assist returned a non-object JSON response "
            f"({type(payload).__name__})",
            code="antigravity_code_assist_invalid_response",
        )
    return payload
def load_code_assist(
    access_token: str,
    *,
    project_id: str = "",
    endpoint: str = ANTIGRAVITY_CODE_ASSIST_ENDPOINT,
) -> AntigravityProjectInfo:
    """Call ``v1internal:loadCodeAssist`` and extract the reported project.

    Args:
        access_token: OAuth access token used for the request.
        project_id: Optional project to announce; when non-empty it is sent
            both as ``metadata.duetProject`` and as
            ``cloudaicompanionProject``.
        endpoint: Base URL of the Code Assist service.

    Returns:
        An ``AntigravityProjectInfo`` holding the project ID read from the
        response (``cloudaicompanionProject`` first, then ``project``; empty
        if neither is set) together with the raw response.
    """
    client_meta = _client_metadata()
    payload: Dict[str, Any] = {"metadata": client_meta}
    if project_id:
        client_meta["duetProject"] = project_id
        payload["cloudaicompanionProject"] = project_id

    reply = _post_json(f"{endpoint}/v1internal:loadCodeAssist", payload, access_token)

    # First non-blank candidate wins; later keys are not consulted after that.
    discovered = ""
    for key in ("cloudaicompanionProject", "project"):
        discovered = str(reply.get(key) or "").strip()
        if discovered:
            break
    return AntigravityProjectInfo(project_id=discovered, raw=reply)
def resolve_project_context(
    access_token: str,
    *,
    configured_project_id: str = "",
    env_project_id: str = "",
) -> ProjectContext:
    """Choose the project to use for Code Assist requests.

    Precedence: explicitly configured ID, then the environment-provided ID,
    then whatever ``load_code_assist`` discovers, and finally the default
    project from ``agent.antigravity_oauth``. No network call is made when
    either explicit ID is given.

    Args:
        access_token: OAuth access token used for discovery.
        configured_project_id: Project ID from configuration, if any.
        env_project_id: Project ID from the environment, if any.

    Returns:
        A ``ProjectContext`` whose ``source`` is one of ``"config"``,
        ``"env"``, ``"discovered"`` or ``"default"``.
    """
    explicit_choices = (
        (configured_project_id, "config"),
        (env_project_id, "env"),
    )
    for candidate, origin in explicit_choices:
        if candidate:
            return ProjectContext(project_id=candidate, source=origin)

    discovered = load_code_assist(access_token).project_id
    if discovered:
        return ProjectContext(
            project_id=discovered,
            managed_project_id=discovered,
            source="discovered",
        )

    # Discovery came back without a project (common on fresh consumer accounts
    # that have not been onboarded). Use the public default project so the
    # call chain still succeeds, mirroring the Antigravity CLI reference flow.
    # The import stays local so it only happens on this fallback path.
    from agent.antigravity_oauth import DEFAULT_PROJECT_ID

    return ProjectContext(
        project_id=DEFAULT_PROJECT_ID,
        managed_project_id=DEFAULT_PROJECT_ID,
        source="default",
    )
def fetch_available_models(
    access_token: str,
    *,
    project_id: str = "",
    endpoint: str = ANTIGRAVITY_CODE_ASSIST_ENDPOINT,
) -> Dict[str, Any]:
    """Call ``v1internal:fetchAvailableModels`` on a single endpoint.

    Args:
        access_token: OAuth access token used for the request.
        project_id: Optional project; sent as ``project`` only when
            non-empty.
        endpoint: Base URL of the Code Assist service.

    Returns:
        The decoded response payload from ``_post_json``.
    """
    request_body: Dict[str, Any] = {"project": project_id} if project_id else {}
    target = f"{endpoint}/v1internal:fetchAvailableModels"
    return _post_json(target, request_body, access_token)
def fetch_available_models_with_fallbacks(
    access_token: str,
    *,
    project_id: str = "",
    endpoints: Optional[Iterable[str]] = None,
) -> Dict[str, Any]:
    """Fetch the model catalog, trying each endpoint until one succeeds.

    Args:
        access_token: OAuth access token used for every attempt.
        project_id: Optional project forwarded to ``fetch_available_models``.
        endpoints: Base URLs to try in order; a falsy value selects
            ``ANTIGRAVITY_MODEL_ENDPOINTS``.

    Returns:
        The payload from the first endpoint that answers, or ``{}`` when
        there were no endpoints to try.

    Raises:
        Exception: The error from the last endpoint attempted, re-raised
            unchanged once every endpoint has failed.
    """
    candidates = endpoints or ANTIGRAVITY_MODEL_ENDPOINTS
    failure: Optional[Exception] = None
    for base_url in candidates:
        try:
            catalog = fetch_available_models(
                access_token,
                project_id=project_id,
                endpoint=base_url,
            )
        except Exception as exc:
            # Deliberately broad: any failure just moves on to the next
            # endpoint; only the most recent error is kept for reporting.
            failure = exc
            logger.debug("Antigravity fetchAvailableModels failed on %s: %s", base_url, exc)
        else:
            return catalog
    if not failure:
        return {}
    raise failure
def _model_id_from_value(value: Any) -> str:
    """Extract a model ID from a string or a model-describing dict.

    A string is returned stripped. For a dict, the keys ``modelId``,
    ``model_id``, ``id`` and ``name`` are checked in that order and the first
    one with a non-blank value (after ``str()`` and stripping) is returned.
    Anything else, or a dict without a usable key, yields ``""``.
    """
    if isinstance(value, dict):
        # Lazy generator: later keys are only looked up if earlier ones
        # turned out blank.
        texts = (
            str(value.get(key) or "").strip()
            for key in ("modelId", "model_id", "id", "name")
        )
        return next((text for text in texts if text), "")
    return value.strip() if isinstance(value, str) else ""
def _ids_from_sort(sort: Dict[str, Any]) -> List[str]:
    """Collect the model IDs listed inside one sort entry.

    The keys ``modelIds``, ``model_ids``, ``models`` and ``modelSorts`` are
    examined in that order. A list value contributes one ID per usable item;
    a dict value is treated as a single item. Other value types are ignored.
    Duplicates are kept and order is preserved.
    """
    collected: List[str] = []
    for field_name in ("modelIds", "model_ids", "models", "modelSorts"):
        entry = sort.get(field_name)
        if isinstance(entry, list):
            items = entry
        elif isinstance(entry, dict):
            items = [entry]
        else:
            continue
        collected.extend(
            model_id for model_id in map(_model_id_from_value, items) if model_id
        )
    return collected
def _is_recommended_sort(sort: Dict[str, Any]) -> bool:
    """Tell whether a sort entry is labelled as "recommended".

    The check is a case-insensitive substring match over the entry's
    ``name``, ``displayName``, ``title``, ``category`` and ``group`` values.
    """
    label_keys = ("name", "displayName", "title", "category", "group")
    parts = [str(sort.get(key) or "") for key in label_keys]
    combined = " ".join(parts).lower()
    return "recommended" in combined
def _raw_model_ids(payload: Dict[str, Any]) -> List[str]:
    """Return the model IDs found in the payload's top-level ``models`` list.

    Yields an empty list when ``models`` is missing or is not a list; items
    without a usable ID are skipped.
    """
    entries = payload.get("models")
    if not isinstance(entries, list):
        return []
    return [model_id for model_id in map(_model_id_from_value, entries) if model_id]
def filter_agent_model_ids(
    ids: Iterable[str],
    *,
    deprecated_replacements: Optional[Dict[str, str]] = None,
) -> List[str]:
    """Clean up a list of model IDs for display, preserving first-seen order.

    Each ID is converted with ``str()`` and stripped; blank results are
    dropped. From the remainder the function removes:

    * duplicates (only the first occurrence is kept),
    * IDs starting with ``chat_`` or ``tab_``,
    * deprecated IDs whose replacement appears anywhere in the input.

    Args:
        ids: Candidate model IDs, in display order.
        deprecated_replacements: Mapping of deprecated ID to replacement ID.
            ``None`` (the default) uses ``DEPRECATED_MODEL_REPLACEMENTS``.

    Returns:
        The filtered IDs as a new list.
    """
    if deprecated_replacements is None:
        deprecated_replacements = DEPRECATED_MODEL_REPLACEMENTS

    cleaned = [text for text in (str(mid).strip() for mid in ids) if text]
    # Built once so each "is the replacement offered?" check is O(1).
    offered = set(cleaned)

    seen: set[str] = set()
    kept: List[str] = []
    for mid in cleaned:
        if mid in seen or mid.startswith(("chat_", "tab_")):
            continue
        replacement = deprecated_replacements.get(mid)
        if replacement is not None and replacement in offered:
            # The replacement is offered too, so hide the deprecated ID.
            continue
        seen.add(mid)
        kept.append(mid)
    return kept
def parse_agent_model_ids(payload: Dict[str, Any]) -> List[str]:
    """Return the user-facing Antigravity agent model list in display order.

    IDs are gathered from ``agentModelSorts`` with sort entries labelled
    "recommended" placed first. If that yields nothing, the payload's
    ``defaultAgentModelId`` (when set) and ``DEFAULT_AGENT_MODEL_IDS`` are
    used instead. IDs from the top-level ``models`` list are always appended,
    and the combined list is passed through ``filter_agent_model_ids``. When
    filtering leaves nothing, a copy of ``DEFAULT_AGENT_MODEL_IDS`` is
    returned.
    """
    sort_groups = payload.get("agentModelSorts")
    candidates: List[str] = []

    if isinstance(sort_groups, list):
        # Partition while keeping the relative order inside each bucket.
        recommended: List[Dict[str, Any]] = []
        others: List[Dict[str, Any]] = []
        for group in sort_groups:
            if not isinstance(group, dict):
                continue
            if _is_recommended_sort(group):
                recommended.append(group)
            else:
                others.append(group)
        for group in recommended + others:
            candidates.extend(_ids_from_sort(group))

    if not candidates:
        fallback_default = str(payload.get("defaultAgentModelId") or "").strip()
        if fallback_default:
            candidates.append(fallback_default)
        candidates.extend(DEFAULT_AGENT_MODEL_IDS)

    candidates.extend(_raw_model_ids(payload))
    return filter_agent_model_ids(candidates) or list(DEFAULT_AGENT_MODEL_IDS)