mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(copilot): exchange raw GitHub token for Copilot API JWT
Raw GitHub tokens (gho_/github_pat_/ghu_) are now exchanged for short-lived Copilot API tokens via /copilot_internal/v2/token before being used as Bearer credentials. This is required to access internal-only models (e.g. claude-opus-4.6-1m with 1M context). Implementation: - exchange_copilot_token(): calls the token exchange endpoint with in-process caching (dict keyed by SHA-256 fingerprint), refreshed 2 minutes before expiry. No disk persistence — gateway is long-running so in-memory cache is sufficient. - get_copilot_api_token(): convenience wrapper with graceful fallback — returns exchanged token on success, raw token on failure. - Both callers (hermes_cli/auth.py and agent/credential_pool.py) now pipe the raw token through get_copilot_api_token() before use. 12 new tests covering exchange, caching, expiry, error handling, fingerprinting, and caller integration. All 185 existing copilot/auth tests pass. Part 2 of #7731.
This commit is contained in:
parent
2cab8129d1
commit
d7ad07d6fe
4 changed files with 257 additions and 4 deletions
|
|
@ -1078,9 +1078,10 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||||
# env vars (COPILOT_GITHUB_TOKEN / GH_TOKEN). They don't live in
|
# env vars (COPILOT_GITHUB_TOKEN / GH_TOKEN). They don't live in
|
||||||
# the auth store or credential pool, so we resolve them here.
|
# the auth store or credential pool, so we resolve them here.
|
||||||
try:
|
try:
|
||||||
from hermes_cli.copilot_auth import resolve_copilot_token
|
from hermes_cli.copilot_auth import resolve_copilot_token, get_copilot_api_token
|
||||||
token, source = resolve_copilot_token()
|
token, source = resolve_copilot_token()
|
||||||
if token:
|
if token:
|
||||||
|
api_token = get_copilot_api_token(token)
|
||||||
source_name = "gh_cli" if "gh" in source.lower() else f"env:{source}"
|
source_name = "gh_cli" if "gh" in source.lower() else f"env:{source}"
|
||||||
if not _is_suppressed(provider, source_name):
|
if not _is_suppressed(provider, source_name):
|
||||||
active_sources.add(source_name)
|
active_sources.add(source_name)
|
||||||
|
|
@ -1092,7 +1093,7 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||||
{
|
{
|
||||||
"source": source_name,
|
"source": source_name,
|
||||||
"auth_type": AUTH_TYPE_API_KEY,
|
"auth_type": AUTH_TYPE_API_KEY,
|
||||||
"access_token": token,
|
"access_token": api_token,
|
||||||
"base_url": pconfig.inference_base_url if pconfig else "",
|
"base_url": pconfig.inference_base_url if pconfig else "",
|
||||||
"label": source,
|
"label": source,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -426,10 +426,10 @@ def _resolve_api_key_provider_secret(
|
||||||
if provider_id == "copilot":
|
if provider_id == "copilot":
|
||||||
# Use the dedicated copilot auth module for proper token validation
|
# Use the dedicated copilot auth module for proper token validation
|
||||||
try:
|
try:
|
||||||
from hermes_cli.copilot_auth import resolve_copilot_token
|
from hermes_cli.copilot_auth import resolve_copilot_token, get_copilot_api_token
|
||||||
token, source = resolve_copilot_token()
|
token, source = resolve_copilot_token()
|
||||||
if token:
|
if token:
|
||||||
return token, source
|
return get_copilot_api_token(token), source
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
logger.warning("Copilot token validation failed: %s", exc)
|
logger.warning("Copilot token validation failed: %s", exc)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -275,6 +275,99 @@ def copilot_device_code_login(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Copilot Token Exchange ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Module-level cache for exchanged Copilot API tokens.
|
||||||
|
# Maps raw_token_fingerprint -> (api_token, expires_at_epoch).
|
||||||
|
_jwt_cache: dict[str, tuple[str, float]] = {}
|
||||||
|
_JWT_REFRESH_MARGIN_SECONDS = 120 # refresh 2 min before expiry
|
||||||
|
|
||||||
|
# Token exchange endpoint and headers (matching VS Code / Copilot CLI)
|
||||||
|
_TOKEN_EXCHANGE_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||||
|
_EDITOR_VERSION = "vscode/1.104.1"
|
||||||
|
_EXCHANGE_USER_AGENT = "GitHubCopilotChat/0.26.7"
|
||||||
|
|
||||||
|
|
||||||
|
def _token_fingerprint(raw_token: str) -> str:
|
||||||
|
"""Short fingerprint of a raw token for cache keying (avoids storing full token)."""
|
||||||
|
import hashlib
|
||||||
|
return hashlib.sha256(raw_token.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
def exchange_copilot_token(raw_token: str, *, timeout: float = 10.0) -> tuple[str, float]:
|
||||||
|
"""Exchange a raw GitHub token for a short-lived Copilot API token.
|
||||||
|
|
||||||
|
Calls ``GET https://api.github.com/copilot_internal/v2/token`` with
|
||||||
|
the raw GitHub token and returns ``(api_token, expires_at)``.
|
||||||
|
|
||||||
|
The returned token is a semicolon-separated string (not a standard JWT)
|
||||||
|
used as ``Authorization: Bearer <token>`` for Copilot API requests.
|
||||||
|
|
||||||
|
Results are cached in-process and reused until close to expiry.
|
||||||
|
Raises ``ValueError`` on failure.
|
||||||
|
"""
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
fp = _token_fingerprint(raw_token)
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
cached = _jwt_cache.get(fp)
|
||||||
|
if cached:
|
||||||
|
api_token, expires_at = cached
|
||||||
|
if time.time() < expires_at - _JWT_REFRESH_MARGIN_SECONDS:
|
||||||
|
return api_token, expires_at
|
||||||
|
|
||||||
|
req = urllib.request.Request(
|
||||||
|
_TOKEN_EXCHANGE_URL,
|
||||||
|
method="GET",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"token {raw_token}",
|
||||||
|
"User-Agent": _EXCHANGE_USER_AGENT,
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Editor-Version": _EDITOR_VERSION,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||||
|
data = json.loads(resp.read().decode())
|
||||||
|
except Exception as exc:
|
||||||
|
raise ValueError(f"Copilot token exchange failed: {exc}") from exc
|
||||||
|
|
||||||
|
api_token = data.get("token", "")
|
||||||
|
expires_at = data.get("expires_at", 0)
|
||||||
|
if not api_token:
|
||||||
|
raise ValueError("Copilot token exchange returned empty token")
|
||||||
|
|
||||||
|
# Convert expires_at to float if needed
|
||||||
|
expires_at = float(expires_at) if expires_at else time.time() + 1800
|
||||||
|
|
||||||
|
_jwt_cache[fp] = (api_token, expires_at)
|
||||||
|
logger.debug(
|
||||||
|
"Copilot token exchanged, expires_at=%s",
|
||||||
|
expires_at,
|
||||||
|
)
|
||||||
|
return api_token, expires_at
|
||||||
|
|
||||||
|
|
||||||
|
def get_copilot_api_token(raw_token: str) -> str:
|
||||||
|
"""Exchange a raw GitHub token for a Copilot API token, with fallback.
|
||||||
|
|
||||||
|
Convenience wrapper: returns the exchanged token on success, or the
|
||||||
|
raw token unchanged if the exchange fails (e.g. network error, unsupported
|
||||||
|
account type). This preserves existing behaviour for accounts that don't
|
||||||
|
need exchange while enabling access to internal-only models for those that do.
|
||||||
|
"""
|
||||||
|
if not raw_token:
|
||||||
|
return raw_token
|
||||||
|
try:
|
||||||
|
api_token, _ = exchange_copilot_token(raw_token)
|
||||||
|
return api_token
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Copilot token exchange failed, using raw token: %s", exc)
|
||||||
|
return raw_token
|
||||||
|
|
||||||
|
|
||||||
# ─── Copilot API Headers ───────────────────────────────────────────────────
|
# ─── Copilot API Headers ───────────────────────────────────────────────────
|
||||||
|
|
||||||
def copilot_request_headers(
|
def copilot_request_headers(
|
||||||
|
|
|
||||||
159
tests/hermes_cli/test_copilot_token_exchange.py
Normal file
159
tests/hermes_cli/test_copilot_token_exchange.py
Normal file
|
|
@ -0,0 +1,159 @@
|
||||||
|
"""Tests for Copilot token exchange (raw GitHub token → Copilot API token)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clear_jwt_cache():
|
||||||
|
"""Reset the module-level JWT cache before each test."""
|
||||||
|
import hermes_cli.copilot_auth as mod
|
||||||
|
mod._jwt_cache.clear()
|
||||||
|
yield
|
||||||
|
mod._jwt_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestExchangeCopilotToken:
|
||||||
|
"""Tests for exchange_copilot_token()."""
|
||||||
|
|
||||||
|
def _mock_urlopen(self, token="tid=abc;exp=123;sku=copilot_individual", expires_at=None):
|
||||||
|
"""Create a mock urlopen context manager returning a token response."""
|
||||||
|
if expires_at is None:
|
||||||
|
expires_at = time.time() + 1800
|
||||||
|
resp_data = json.dumps({"token": token, "expires_at": expires_at}).encode()
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.read.return_value = resp_data
|
||||||
|
mock_resp.__enter__ = MagicMock(return_value=mock_resp)
|
||||||
|
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
def test_exchanges_token_successfully(self, mock_urlopen):
|
||||||
|
from hermes_cli.copilot_auth import exchange_copilot_token
|
||||||
|
|
||||||
|
mock_urlopen.return_value = self._mock_urlopen(token="tid=abc;exp=999")
|
||||||
|
api_token, expires_at = exchange_copilot_token("gho_test123")
|
||||||
|
|
||||||
|
assert api_token == "tid=abc;exp=999"
|
||||||
|
assert isinstance(expires_at, float)
|
||||||
|
|
||||||
|
# Verify request was made with correct headers
|
||||||
|
call_args = mock_urlopen.call_args
|
||||||
|
req = call_args[0][0]
|
||||||
|
assert req.get_header("Authorization") == "token gho_test123"
|
||||||
|
assert "GitHubCopilotChat" in req.get_header("User-agent")
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
def test_caches_result(self, mock_urlopen):
|
||||||
|
from hermes_cli.copilot_auth import exchange_copilot_token
|
||||||
|
|
||||||
|
future = time.time() + 1800
|
||||||
|
mock_urlopen.return_value = self._mock_urlopen(expires_at=future)
|
||||||
|
|
||||||
|
exchange_copilot_token("gho_test123")
|
||||||
|
exchange_copilot_token("gho_test123")
|
||||||
|
|
||||||
|
assert mock_urlopen.call_count == 1
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
def test_refreshes_expired_cache(self, mock_urlopen):
|
||||||
|
from hermes_cli.copilot_auth import exchange_copilot_token, _jwt_cache, _token_fingerprint
|
||||||
|
|
||||||
|
# Seed cache with expired entry
|
||||||
|
fp = _token_fingerprint("gho_test123")
|
||||||
|
_jwt_cache[fp] = ("old_token", time.time() - 10)
|
||||||
|
|
||||||
|
mock_urlopen.return_value = self._mock_urlopen(
|
||||||
|
token="new_token", expires_at=time.time() + 1800
|
||||||
|
)
|
||||||
|
api_token, _ = exchange_copilot_token("gho_test123")
|
||||||
|
|
||||||
|
assert api_token == "new_token"
|
||||||
|
assert mock_urlopen.call_count == 1
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen")
|
||||||
|
def test_raises_on_empty_token(self, mock_urlopen):
|
||||||
|
from hermes_cli.copilot_auth import exchange_copilot_token
|
||||||
|
|
||||||
|
resp_data = json.dumps({"token": "", "expires_at": 0}).encode()
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.read.return_value = resp_data
|
||||||
|
mock_resp.__enter__ = MagicMock(return_value=mock_resp)
|
||||||
|
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||||
|
mock_urlopen.return_value = mock_resp
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="empty token"):
|
||||||
|
exchange_copilot_token("gho_test123")
|
||||||
|
|
||||||
|
@patch("urllib.request.urlopen", side_effect=Exception("network error"))
|
||||||
|
def test_raises_on_network_error(self, mock_urlopen):
|
||||||
|
from hermes_cli.copilot_auth import exchange_copilot_token
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="network error"):
|
||||||
|
exchange_copilot_token("gho_test123")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCopilotApiToken:
|
||||||
|
"""Tests for get_copilot_api_token() — the fallback wrapper."""
|
||||||
|
|
||||||
|
@patch("hermes_cli.copilot_auth.exchange_copilot_token")
|
||||||
|
def test_returns_exchanged_token(self, mock_exchange):
|
||||||
|
from hermes_cli.copilot_auth import get_copilot_api_token
|
||||||
|
|
||||||
|
mock_exchange.return_value = ("exchanged_jwt", time.time() + 1800)
|
||||||
|
assert get_copilot_api_token("gho_raw") == "exchanged_jwt"
|
||||||
|
|
||||||
|
@patch("hermes_cli.copilot_auth.exchange_copilot_token", side_effect=ValueError("fail"))
|
||||||
|
def test_falls_back_to_raw_token(self, mock_exchange):
|
||||||
|
from hermes_cli.copilot_auth import get_copilot_api_token
|
||||||
|
|
||||||
|
assert get_copilot_api_token("gho_raw") == "gho_raw"
|
||||||
|
|
||||||
|
def test_empty_token_passthrough(self):
|
||||||
|
from hermes_cli.copilot_auth import get_copilot_api_token
|
||||||
|
|
||||||
|
assert get_copilot_api_token("") == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenFingerprint:
|
||||||
|
"""Tests for _token_fingerprint()."""
|
||||||
|
|
||||||
|
def test_consistent(self):
|
||||||
|
from hermes_cli.copilot_auth import _token_fingerprint
|
||||||
|
|
||||||
|
fp1 = _token_fingerprint("gho_abc123")
|
||||||
|
fp2 = _token_fingerprint("gho_abc123")
|
||||||
|
assert fp1 == fp2
|
||||||
|
|
||||||
|
def test_different_tokens_different_fingerprints(self):
|
||||||
|
from hermes_cli.copilot_auth import _token_fingerprint
|
||||||
|
|
||||||
|
fp1 = _token_fingerprint("gho_abc123")
|
||||||
|
fp2 = _token_fingerprint("gho_xyz789")
|
||||||
|
assert fp1 != fp2
|
||||||
|
|
||||||
|
def test_length(self):
|
||||||
|
from hermes_cli.copilot_auth import _token_fingerprint
|
||||||
|
|
||||||
|
assert len(_token_fingerprint("gho_test")) == 16
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallerIntegration:
|
||||||
|
"""Test that callers correctly use token exchange."""
|
||||||
|
|
||||||
|
@patch("hermes_cli.copilot_auth.resolve_copilot_token", return_value=("gho_raw", "GH_TOKEN"))
|
||||||
|
@patch("hermes_cli.copilot_auth.get_copilot_api_token", return_value="exchanged_jwt")
|
||||||
|
def test_auth_resolve_uses_exchange(self, mock_exchange, mock_resolve):
|
||||||
|
from hermes_cli.auth import _resolve_api_key_provider_secret
|
||||||
|
|
||||||
|
# Create a minimal pconfig mock
|
||||||
|
pconfig = MagicMock()
|
||||||
|
token, source = _resolve_api_key_provider_secret("copilot", pconfig)
|
||||||
|
assert token == "exchanged_jwt"
|
||||||
|
assert source == "GH_TOKEN"
|
||||||
|
mock_exchange.assert_called_once_with("gho_raw")
|
||||||
Loading…
Add table
Add a link
Reference in a new issue