diff --git a/agent/credential_pool.py b/agent/credential_pool.py index 2d5accd41..6fbb553d2 100644 --- a/agent/credential_pool.py +++ b/agent/credential_pool.py @@ -1078,9 +1078,10 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup # env vars (COPILOT_GITHUB_TOKEN / GH_TOKEN). They don't live in # the auth store or credential pool, so we resolve them here. try: - from hermes_cli.copilot_auth import resolve_copilot_token + from hermes_cli.copilot_auth import resolve_copilot_token, get_copilot_api_token token, source = resolve_copilot_token() if token: + api_token = get_copilot_api_token(token) source_name = "gh_cli" if "gh" in source.lower() else f"env:{source}" if not _is_suppressed(provider, source_name): active_sources.add(source_name) @@ -1092,7 +1093,7 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup { "source": source_name, "auth_type": AUTH_TYPE_API_KEY, - "access_token": token, + "access_token": api_token, "base_url": pconfig.inference_base_url if pconfig else "", "label": source, }, diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index ff0028875..207dc719f 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -426,10 +426,10 @@ def _resolve_api_key_provider_secret( if provider_id == "copilot": # Use the dedicated copilot auth module for proper token validation try: - from hermes_cli.copilot_auth import resolve_copilot_token + from hermes_cli.copilot_auth import resolve_copilot_token, get_copilot_api_token token, source = resolve_copilot_token() if token: - return token, source + return get_copilot_api_token(token), source except ValueError as exc: logger.warning("Copilot token validation failed: %s", exc) except Exception: diff --git a/hermes_cli/copilot_auth.py b/hermes_cli/copilot_auth.py index 24859da1a..348e4efe8 100644 --- a/hermes_cli/copilot_auth.py +++ b/hermes_cli/copilot_auth.py @@ -275,6 +275,99 @@ def copilot_device_code_login( return None +# ─── Copilot Token Exchange ──────────────────────────────────────────────── + +# Module-level cache for exchanged Copilot API tokens. +# Maps raw_token_fingerprint -> (api_token, expires_at_epoch). +_jwt_cache: dict[str, tuple[str, float]] = {} +_JWT_REFRESH_MARGIN_SECONDS = 120 # refresh 2 min before expiry + +# Token exchange endpoint and headers (matching VS Code / Copilot CLI) +_TOKEN_EXCHANGE_URL = "https://api.github.com/copilot_internal/v2/token" +_EDITOR_VERSION = "vscode/1.104.1" +_EXCHANGE_USER_AGENT = "GitHubCopilotChat/0.26.7" + + +def _token_fingerprint(raw_token: str) -> str: + """Short fingerprint of a raw token for cache keying (avoids storing full token).""" + import hashlib + return hashlib.sha256(raw_token.encode()).hexdigest()[:16] + + +def exchange_copilot_token(raw_token: str, *, timeout: float = 10.0) -> tuple[str, float]: + """Exchange a raw GitHub token for a short-lived Copilot API token. + + Calls ``GET https://api.github.com/copilot_internal/v2/token`` with + the raw GitHub token and returns ``(api_token, expires_at)``. + + The returned token is a semicolon-separated string (not a standard JWT) + used as ``Authorization: Bearer `` for Copilot API requests. + + Results are cached in-process and reused until close to expiry. + Raises ``ValueError`` on failure. + """ + import urllib.request + + fp = _token_fingerprint(raw_token) + + # Check cache first + cached = _jwt_cache.get(fp) + if cached: + api_token, expires_at = cached + if time.time() < expires_at - _JWT_REFRESH_MARGIN_SECONDS: + return api_token, expires_at + + req = urllib.request.Request( + _TOKEN_EXCHANGE_URL, + method="GET", + headers={ + "Authorization": f"token {raw_token}", + "User-Agent": _EXCHANGE_USER_AGENT, + "Accept": "application/json", + "Editor-Version": _EDITOR_VERSION, + }, + ) + + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + data = json.loads(resp.read().decode()) + except Exception as exc: + raise ValueError(f"Copilot token exchange failed: {exc}") from exc + + api_token = data.get("token", "") + expires_at = data.get("expires_at", 0) + if not api_token: + raise ValueError("Copilot token exchange returned empty token") + + # Convert expires_at to float if needed + expires_at = float(expires_at) if expires_at else time.time() + 1800 + + _jwt_cache[fp] = (api_token, expires_at) + logger.debug( + "Copilot token exchanged, expires_at=%s", + expires_at, + ) + return api_token, expires_at + + +def get_copilot_api_token(raw_token: str) -> str: + """Exchange a raw GitHub token for a Copilot API token, with fallback. + + Convenience wrapper: returns the exchanged token on success, or the + raw token unchanged if the exchange fails (e.g. network error, unsupported + account type). This preserves existing behaviour for accounts that don't + need exchange while enabling access to internal-only models for those that do. + """ + if not raw_token: + return raw_token + try: + api_token, _ = exchange_copilot_token(raw_token) + return api_token + except Exception as exc: + logger.debug("Copilot token exchange failed, using raw token: %s", exc) + return raw_token + + # ─── Copilot API Headers ─────────────────────────────────────────────────── def copilot_request_headers( diff --git a/tests/hermes_cli/test_copilot_token_exchange.py b/tests/hermes_cli/test_copilot_token_exchange.py new file mode 100644 index 000000000..9c6a219ab --- /dev/null +++ b/tests/hermes_cli/test_copilot_token_exchange.py @@ -0,0 +1,159 @@ +"""Tests for Copilot token exchange (raw GitHub token → Copilot API token).""" + +from __future__ import annotations + +import json +import time +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture(autouse=True) +def _clear_jwt_cache(): + """Reset the module-level JWT cache before each test.""" + import hermes_cli.copilot_auth as mod + mod._jwt_cache.clear() + yield + mod._jwt_cache.clear() + + +class TestExchangeCopilotToken: + """Tests for exchange_copilot_token().""" + + def _mock_urlopen(self, token="tid=abc;exp=123;sku=copilot_individual", expires_at=None): + """Create a mock urlopen context manager returning a token response.""" + if expires_at is None: + expires_at = time.time() + 1800 + resp_data = json.dumps({"token": token, "expires_at": expires_at}).encode() + mock_resp = MagicMock() + mock_resp.read.return_value = resp_data + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + return mock_resp + + @patch("urllib.request.urlopen") + def test_exchanges_token_successfully(self, mock_urlopen): + from hermes_cli.copilot_auth import exchange_copilot_token + + mock_urlopen.return_value = self._mock_urlopen(token="tid=abc;exp=999") + api_token, expires_at = exchange_copilot_token("gho_test123") + + assert api_token == "tid=abc;exp=999" + assert isinstance(expires_at, float) + + # Verify request was made with correct headers + call_args = mock_urlopen.call_args + req = call_args[0][0] + assert req.get_header("Authorization") == "token gho_test123" + assert "GitHubCopilotChat" in req.get_header("User-agent") + + @patch("urllib.request.urlopen") + def test_caches_result(self, mock_urlopen): + from hermes_cli.copilot_auth import exchange_copilot_token + + future = time.time() + 1800 + mock_urlopen.return_value = self._mock_urlopen(expires_at=future) + + exchange_copilot_token("gho_test123") + exchange_copilot_token("gho_test123") + + assert mock_urlopen.call_count == 1 + + @patch("urllib.request.urlopen") + def test_refreshes_expired_cache(self, mock_urlopen): + from hermes_cli.copilot_auth import exchange_copilot_token, _jwt_cache, _token_fingerprint + + # Seed cache with expired entry + fp = _token_fingerprint("gho_test123") + _jwt_cache[fp] = ("old_token", time.time() - 10) + + mock_urlopen.return_value = self._mock_urlopen( + token="new_token", expires_at=time.time() + 1800 + ) + api_token, _ = exchange_copilot_token("gho_test123") + + assert api_token == "new_token" + assert mock_urlopen.call_count == 1 + + @patch("urllib.request.urlopen") + def test_raises_on_empty_token(self, mock_urlopen): + from hermes_cli.copilot_auth import exchange_copilot_token + + resp_data = json.dumps({"token": "", "expires_at": 0}).encode() + mock_resp = MagicMock() + mock_resp.read.return_value = resp_data + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_resp + + with pytest.raises(ValueError, match="empty token"): + exchange_copilot_token("gho_test123") + + @patch("urllib.request.urlopen", side_effect=Exception("network error")) + def test_raises_on_network_error(self, mock_urlopen): + from hermes_cli.copilot_auth import exchange_copilot_token + + with pytest.raises(ValueError, match="network error"): + exchange_copilot_token("gho_test123") + + +class TestGetCopilotApiToken: + """Tests for get_copilot_api_token() — the fallback wrapper.""" + + @patch("hermes_cli.copilot_auth.exchange_copilot_token") + def test_returns_exchanged_token(self, mock_exchange): + from hermes_cli.copilot_auth import get_copilot_api_token + + mock_exchange.return_value = ("exchanged_jwt", time.time() + 1800) + assert get_copilot_api_token("gho_raw") == "exchanged_jwt" + + @patch("hermes_cli.copilot_auth.exchange_copilot_token", side_effect=ValueError("fail")) + def test_falls_back_to_raw_token(self, mock_exchange): + from hermes_cli.copilot_auth import get_copilot_api_token + + assert get_copilot_api_token("gho_raw") == "gho_raw" + + def test_empty_token_passthrough(self): + from hermes_cli.copilot_auth import get_copilot_api_token + + assert get_copilot_api_token("") == "" + + +class TestTokenFingerprint: + """Tests for _token_fingerprint().""" + + def test_consistent(self): + from hermes_cli.copilot_auth import _token_fingerprint + + fp1 = _token_fingerprint("gho_abc123") + fp2 = _token_fingerprint("gho_abc123") + assert fp1 == fp2 + + def test_different_tokens_different_fingerprints(self): + from hermes_cli.copilot_auth import _token_fingerprint + + fp1 = _token_fingerprint("gho_abc123") + fp2 = _token_fingerprint("gho_xyz789") + assert fp1 != fp2 + + def test_length(self): + from hermes_cli.copilot_auth import _token_fingerprint + + assert len(_token_fingerprint("gho_test")) == 16 + + +class TestCallerIntegration: + """Test that callers correctly use token exchange.""" + + @patch("hermes_cli.copilot_auth.resolve_copilot_token", return_value=("gho_raw", "GH_TOKEN")) + @patch("hermes_cli.copilot_auth.get_copilot_api_token", return_value="exchanged_jwt") + def test_auth_resolve_uses_exchange(self, mock_exchange, mock_resolve): + from hermes_cli.auth import _resolve_api_key_provider_secret + + # Create a minimal pconfig mock + pconfig = MagicMock() + token, source = _resolve_api_key_provider_secret("copilot", pconfig) + assert token == "exchanged_jwt" + assert source == "GH_TOKEN" + mock_exchange.assert_called_once_with("gho_raw")