mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(copilot): exchange raw GitHub token for Copilot API JWT
Raw GitHub tokens (gho_/github_pat_/ghu_) are now exchanged for short-lived Copilot API tokens via /copilot_internal/v2/token before being used as Bearer credentials. This is required to access internal-only models (e.g. claude-opus-4.6-1m with 1M context). Implementation: - exchange_copilot_token(): calls the token exchange endpoint with in-process caching (dict keyed by SHA-256 fingerprint), refreshed 2 minutes before expiry. No disk persistence — gateway is long-running so in-memory cache is sufficient. - get_copilot_api_token(): convenience wrapper with graceful fallback — returns exchanged token on success, raw token on failure. - Both callers (hermes_cli/auth.py and agent/credential_pool.py) now pipe the raw token through get_copilot_api_token() before use. 12 new tests covering exchange, caching, expiry, error handling, fingerprinting, and caller integration. All 185 existing copilot/auth tests pass. Part 2 of #7731.
This commit is contained in:
parent
2cab8129d1
commit
d7ad07d6fe
4 changed files with 257 additions and 4 deletions
|
|
@ -1078,9 +1078,10 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
|||
# env vars (COPILOT_GITHUB_TOKEN / GH_TOKEN). They don't live in
|
||||
# the auth store or credential pool, so we resolve them here.
|
||||
try:
|
||||
from hermes_cli.copilot_auth import resolve_copilot_token
|
||||
from hermes_cli.copilot_auth import resolve_copilot_token, get_copilot_api_token
|
||||
token, source = resolve_copilot_token()
|
||||
if token:
|
||||
api_token = get_copilot_api_token(token)
|
||||
source_name = "gh_cli" if "gh" in source.lower() else f"env:{source}"
|
||||
if not _is_suppressed(provider, source_name):
|
||||
active_sources.add(source_name)
|
||||
|
|
@ -1092,7 +1093,7 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
|||
{
|
||||
"source": source_name,
|
||||
"auth_type": AUTH_TYPE_API_KEY,
|
||||
"access_token": token,
|
||||
"access_token": api_token,
|
||||
"base_url": pconfig.inference_base_url if pconfig else "",
|
||||
"label": source,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -426,10 +426,10 @@ def _resolve_api_key_provider_secret(
|
|||
if provider_id == "copilot":
|
||||
# Use the dedicated copilot auth module for proper token validation
|
||||
try:
|
||||
from hermes_cli.copilot_auth import resolve_copilot_token
|
||||
from hermes_cli.copilot_auth import resolve_copilot_token, get_copilot_api_token
|
||||
token, source = resolve_copilot_token()
|
||||
if token:
|
||||
return token, source
|
||||
return get_copilot_api_token(token), source
|
||||
except ValueError as exc:
|
||||
logger.warning("Copilot token validation failed: %s", exc)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -275,6 +275,99 @@ def copilot_device_code_login(
|
|||
return None
|
||||
|
||||
|
||||
# ─── Copilot Token Exchange ────────────────────────────────────────────────
|
||||
|
||||
# Module-level cache for exchanged Copilot API tokens.
|
||||
# Maps raw_token_fingerprint -> (api_token, expires_at_epoch).
|
||||
_jwt_cache: dict[str, tuple[str, float]] = {}
|
||||
_JWT_REFRESH_MARGIN_SECONDS = 120 # refresh 2 min before expiry
|
||||
|
||||
# Token exchange endpoint and headers (matching VS Code / Copilot CLI)
|
||||
_TOKEN_EXCHANGE_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
_EDITOR_VERSION = "vscode/1.104.1"
|
||||
_EXCHANGE_USER_AGENT = "GitHubCopilotChat/0.26.7"
|
||||
|
||||
|
||||
def _token_fingerprint(raw_token: str) -> str:
|
||||
"""Short fingerprint of a raw token for cache keying (avoids storing full token)."""
|
||||
import hashlib
|
||||
return hashlib.sha256(raw_token.encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
def exchange_copilot_token(raw_token: str, *, timeout: float = 10.0) -> tuple[str, float]:
|
||||
"""Exchange a raw GitHub token for a short-lived Copilot API token.
|
||||
|
||||
Calls ``GET https://api.github.com/copilot_internal/v2/token`` with
|
||||
the raw GitHub token and returns ``(api_token, expires_at)``.
|
||||
|
||||
The returned token is a semicolon-separated string (not a standard JWT)
|
||||
used as ``Authorization: Bearer <token>`` for Copilot API requests.
|
||||
|
||||
Results are cached in-process and reused until close to expiry.
|
||||
Raises ``ValueError`` on failure.
|
||||
"""
|
||||
import urllib.request
|
||||
|
||||
fp = _token_fingerprint(raw_token)
|
||||
|
||||
# Check cache first
|
||||
cached = _jwt_cache.get(fp)
|
||||
if cached:
|
||||
api_token, expires_at = cached
|
||||
if time.time() < expires_at - _JWT_REFRESH_MARGIN_SECONDS:
|
||||
return api_token, expires_at
|
||||
|
||||
req = urllib.request.Request(
|
||||
_TOKEN_EXCHANGE_URL,
|
||||
method="GET",
|
||||
headers={
|
||||
"Authorization": f"token {raw_token}",
|
||||
"User-Agent": _EXCHANGE_USER_AGENT,
|
||||
"Accept": "application/json",
|
||||
"Editor-Version": _EDITOR_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Copilot token exchange failed: {exc}") from exc
|
||||
|
||||
api_token = data.get("token", "")
|
||||
expires_at = data.get("expires_at", 0)
|
||||
if not api_token:
|
||||
raise ValueError("Copilot token exchange returned empty token")
|
||||
|
||||
# Convert expires_at to float if needed
|
||||
expires_at = float(expires_at) if expires_at else time.time() + 1800
|
||||
|
||||
_jwt_cache[fp] = (api_token, expires_at)
|
||||
logger.debug(
|
||||
"Copilot token exchanged, expires_at=%s",
|
||||
expires_at,
|
||||
)
|
||||
return api_token, expires_at
|
||||
|
||||
|
||||
def get_copilot_api_token(raw_token: str) -> str:
|
||||
"""Exchange a raw GitHub token for a Copilot API token, with fallback.
|
||||
|
||||
Convenience wrapper: returns the exchanged token on success, or the
|
||||
raw token unchanged if the exchange fails (e.g. network error, unsupported
|
||||
account type). This preserves existing behaviour for accounts that don't
|
||||
need exchange while enabling access to internal-only models for those that do.
|
||||
"""
|
||||
if not raw_token:
|
||||
return raw_token
|
||||
try:
|
||||
api_token, _ = exchange_copilot_token(raw_token)
|
||||
return api_token
|
||||
except Exception as exc:
|
||||
logger.debug("Copilot token exchange failed, using raw token: %s", exc)
|
||||
return raw_token
|
||||
|
||||
|
||||
# ─── Copilot API Headers ───────────────────────────────────────────────────
|
||||
|
||||
def copilot_request_headers(
|
||||
|
|
|
|||
159
tests/hermes_cli/test_copilot_token_exchange.py
Normal file
159
tests/hermes_cli/test_copilot_token_exchange.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
"""Tests for Copilot token exchange (raw GitHub token → Copilot API token)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_jwt_cache():
|
||||
"""Reset the module-level JWT cache before each test."""
|
||||
import hermes_cli.copilot_auth as mod
|
||||
mod._jwt_cache.clear()
|
||||
yield
|
||||
mod._jwt_cache.clear()
|
||||
|
||||
|
||||
class TestExchangeCopilotToken:
|
||||
"""Tests for exchange_copilot_token()."""
|
||||
|
||||
def _mock_urlopen(self, token="tid=abc;exp=123;sku=copilot_individual", expires_at=None):
|
||||
"""Create a mock urlopen context manager returning a token response."""
|
||||
if expires_at is None:
|
||||
expires_at = time.time() + 1800
|
||||
resp_data = json.dumps({"token": token, "expires_at": expires_at}).encode()
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = resp_data
|
||||
mock_resp.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
return mock_resp
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_exchanges_token_successfully(self, mock_urlopen):
|
||||
from hermes_cli.copilot_auth import exchange_copilot_token
|
||||
|
||||
mock_urlopen.return_value = self._mock_urlopen(token="tid=abc;exp=999")
|
||||
api_token, expires_at = exchange_copilot_token("gho_test123")
|
||||
|
||||
assert api_token == "tid=abc;exp=999"
|
||||
assert isinstance(expires_at, float)
|
||||
|
||||
# Verify request was made with correct headers
|
||||
call_args = mock_urlopen.call_args
|
||||
req = call_args[0][0]
|
||||
assert req.get_header("Authorization") == "token gho_test123"
|
||||
assert "GitHubCopilotChat" in req.get_header("User-agent")
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_caches_result(self, mock_urlopen):
|
||||
from hermes_cli.copilot_auth import exchange_copilot_token
|
||||
|
||||
future = time.time() + 1800
|
||||
mock_urlopen.return_value = self._mock_urlopen(expires_at=future)
|
||||
|
||||
exchange_copilot_token("gho_test123")
|
||||
exchange_copilot_token("gho_test123")
|
||||
|
||||
assert mock_urlopen.call_count == 1
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_refreshes_expired_cache(self, mock_urlopen):
|
||||
from hermes_cli.copilot_auth import exchange_copilot_token, _jwt_cache, _token_fingerprint
|
||||
|
||||
# Seed cache with expired entry
|
||||
fp = _token_fingerprint("gho_test123")
|
||||
_jwt_cache[fp] = ("old_token", time.time() - 10)
|
||||
|
||||
mock_urlopen.return_value = self._mock_urlopen(
|
||||
token="new_token", expires_at=time.time() + 1800
|
||||
)
|
||||
api_token, _ = exchange_copilot_token("gho_test123")
|
||||
|
||||
assert api_token == "new_token"
|
||||
assert mock_urlopen.call_count == 1
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_raises_on_empty_token(self, mock_urlopen):
|
||||
from hermes_cli.copilot_auth import exchange_copilot_token
|
||||
|
||||
resp_data = json.dumps({"token": "", "expires_at": 0}).encode()
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = resp_data
|
||||
mock_resp.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_resp
|
||||
|
||||
with pytest.raises(ValueError, match="empty token"):
|
||||
exchange_copilot_token("gho_test123")
|
||||
|
||||
@patch("urllib.request.urlopen", side_effect=Exception("network error"))
|
||||
def test_raises_on_network_error(self, mock_urlopen):
|
||||
from hermes_cli.copilot_auth import exchange_copilot_token
|
||||
|
||||
with pytest.raises(ValueError, match="network error"):
|
||||
exchange_copilot_token("gho_test123")
|
||||
|
||||
|
||||
class TestGetCopilotApiToken:
|
||||
"""Tests for get_copilot_api_token() — the fallback wrapper."""
|
||||
|
||||
@patch("hermes_cli.copilot_auth.exchange_copilot_token")
|
||||
def test_returns_exchanged_token(self, mock_exchange):
|
||||
from hermes_cli.copilot_auth import get_copilot_api_token
|
||||
|
||||
mock_exchange.return_value = ("exchanged_jwt", time.time() + 1800)
|
||||
assert get_copilot_api_token("gho_raw") == "exchanged_jwt"
|
||||
|
||||
@patch("hermes_cli.copilot_auth.exchange_copilot_token", side_effect=ValueError("fail"))
|
||||
def test_falls_back_to_raw_token(self, mock_exchange):
|
||||
from hermes_cli.copilot_auth import get_copilot_api_token
|
||||
|
||||
assert get_copilot_api_token("gho_raw") == "gho_raw"
|
||||
|
||||
def test_empty_token_passthrough(self):
|
||||
from hermes_cli.copilot_auth import get_copilot_api_token
|
||||
|
||||
assert get_copilot_api_token("") == ""
|
||||
|
||||
|
||||
class TestTokenFingerprint:
|
||||
"""Tests for _token_fingerprint()."""
|
||||
|
||||
def test_consistent(self):
|
||||
from hermes_cli.copilot_auth import _token_fingerprint
|
||||
|
||||
fp1 = _token_fingerprint("gho_abc123")
|
||||
fp2 = _token_fingerprint("gho_abc123")
|
||||
assert fp1 == fp2
|
||||
|
||||
def test_different_tokens_different_fingerprints(self):
|
||||
from hermes_cli.copilot_auth import _token_fingerprint
|
||||
|
||||
fp1 = _token_fingerprint("gho_abc123")
|
||||
fp2 = _token_fingerprint("gho_xyz789")
|
||||
assert fp1 != fp2
|
||||
|
||||
def test_length(self):
|
||||
from hermes_cli.copilot_auth import _token_fingerprint
|
||||
|
||||
assert len(_token_fingerprint("gho_test")) == 16
|
||||
|
||||
|
||||
class TestCallerIntegration:
|
||||
"""Test that callers correctly use token exchange."""
|
||||
|
||||
@patch("hermes_cli.copilot_auth.resolve_copilot_token", return_value=("gho_raw", "GH_TOKEN"))
|
||||
@patch("hermes_cli.copilot_auth.get_copilot_api_token", return_value="exchanged_jwt")
|
||||
def test_auth_resolve_uses_exchange(self, mock_exchange, mock_resolve):
|
||||
from hermes_cli.auth import _resolve_api_key_provider_secret
|
||||
|
||||
# Create a minimal pconfig mock
|
||||
pconfig = MagicMock()
|
||||
token, source = _resolve_api_key_provider_secret("copilot", pconfig)
|
||||
assert token == "exchanged_jwt"
|
||||
assert source == "GH_TOKEN"
|
||||
mock_exchange.assert_called_once_with("gho_raw")
|
||||
Loading…
Add table
Add a link
Reference in a new issue