diff --git a/cli.py b/cli.py index 48af2c69f..624139076 100644 --- a/cli.py +++ b/cli.py @@ -19,6 +19,7 @@ import shutil import sys import json import re +import concurrent.futures import base64 import atexit import tempfile @@ -65,6 +66,7 @@ from agent.usage_pricing import ( format_duration_compact, format_token_count_compact, ) +from agent.account_usage import fetch_account_usage, render_account_usage_lines from hermes_cli.banner import _format_context_length, format_banner_version_label _COMMAND_SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏") @@ -7018,6 +7020,27 @@ class HermesCLI: if cost_result.status == "unknown": print(f" Note: Pricing unknown for {agent.model}") + # Account limits -- fetched off-thread with a hard timeout so slow + # provider APIs don't hang the prompt. + provider = getattr(agent, "provider", None) or getattr(self, "provider", None) + base_url = getattr(agent, "base_url", None) or getattr(self, "base_url", None) + api_key = getattr(agent, "api_key", None) or getattr(self, "api_key", None) + account_snapshot = None + if provider: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as _pool: + try: + account_snapshot = _pool.submit( + fetch_account_usage, provider, + base_url=base_url, api_key=api_key, + ).result(timeout=10.0) + except (concurrent.futures.TimeoutError, Exception): + account_snapshot = None + account_lines = [f" {line}" for line in render_account_usage_lines(account_snapshot)] + if account_lines: + print() + for line in account_lines: + print(line) + if self.verbose: logging.getLogger().setLevel(logging.DEBUG) for noisy in ('openai', 'openai._base_client', 'httpx', 'httpcore', 'asyncio', 'hpack', 'grpc', 'modal'): diff --git a/gateway/run.py b/gateway/run.py index 0343790b0..c19303e61 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -30,6 +30,8 @@ from pathlib import Path from datetime import datetime from typing import Dict, Optional, Any, List +from agent.account_usage import fetch_account_usage, render_account_usage_lines + # --- Agent cache tuning --------------------------------------------------- # Bounds the per-session AIAgent cache to prevent unbounded growth in # long-lived gateways (each AIAgent holds LLM clients, tool schemas, @@ -7262,6 +7264,38 @@ class GatewayRunner: if cached: agent = cached[0] + # Resolve provider/base_url/api_key for the account-usage fetch. + # Prefer the live agent; fall back to persisted billing data on the + # SessionDB row so `/usage` still returns account info between turns + # when no agent is resident. + provider = getattr(agent, "provider", None) if agent and agent is not _AGENT_PENDING_SENTINEL else None + base_url = getattr(agent, "base_url", None) if agent and agent is not _AGENT_PENDING_SENTINEL else None + api_key = getattr(agent, "api_key", None) if agent and agent is not _AGENT_PENDING_SENTINEL else None + if not provider and getattr(self, "_session_db", None) is not None: + try: + _entry_for_billing = self.session_store.get_or_create_session(source) + persisted = self._session_db.get_session(_entry_for_billing.session_id) or {} + except Exception: + persisted = {} + provider = provider or persisted.get("billing_provider") + base_url = base_url or persisted.get("billing_base_url") + + # Fetch account usage off the event loop so slow provider APIs don't + # block the gateway. Failures are non-fatal -- account_lines stays [].
+ account_lines: list[str] = [] + if provider: + try: + account_snapshot = await asyncio.to_thread( + fetch_account_usage, + provider, + base_url=base_url, + api_key=api_key, + ) + except Exception: + account_snapshot = None + if account_snapshot: + account_lines = render_account_usage_lines(account_snapshot, markdown=True) + if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0: lines = [] @@ -7319,6 +7353,10 @@ class GatewayRunner: if ctx.compression_count: lines.append(f"Compressions: {ctx.compression_count}") + if account_lines: + lines.append("") + lines.extend(account_lines) + return "\n".join(lines) # No agent at all -- check session history for a rough count @@ -7328,12 +7366,18 @@ class GatewayRunner: from agent.model_metadata import estimate_messages_tokens_rough msgs = [m for m in history if m.get("role") in ("user", "assistant") and m.get("content")] approx = estimate_messages_tokens_rough(msgs) - return ( - f"📊 **Session Info**\n" - f"Messages: {len(msgs)}\n" - f"Estimated context: ~{approx:,} tokens\n" - f"_(Detailed usage available after the first agent response)_" - ) + lines = [ + "📊 **Session Info**", + f"Messages: {len(msgs)}", + f"Estimated context: ~{approx:,} tokens", + "_(Detailed usage available after the first agent response)_", + ] + if account_lines: + lines.append("") + lines.extend(account_lines) + return "\n".join(lines) + if account_lines: + return "\n".join(account_lines) return "No usage data available for this session." 
async def _handle_insights_command(self, event: MessageEvent) -> str: diff --git a/tests/gateway/test_usage_command.py b/tests/gateway/test_usage_command.py index 291581089..feced75b2 100644 --- a/tests/gateway/test_usage_command.py +++ b/tests/gateway/test_usage_command.py @@ -175,3 +175,79 @@ class TestUsageCachedAgent: result = await runner._handle_usage_command(event) assert "Cost: included" in result + + +class TestUsageAccountSection: + """Account-limits section appended to /usage output (PR #2486).""" + + @pytest.mark.asyncio + async def test_usage_command_includes_account_section(self, monkeypatch): + agent = _make_mock_agent(provider="openai-codex") + agent.base_url = "https://chatgpt.com/backend-api/codex" + agent.api_key = "unused" + runner = _make_runner(SK, cached_agent=agent) + event = MagicMock() + + monkeypatch.setattr( + "gateway.run.fetch_account_usage", + lambda provider, base_url=None, api_key=None: object(), + ) + monkeypatch.setattr( + "gateway.run.render_account_usage_lines", + lambda snapshot, markdown=False: [ + "📈 **Account limits**", + "Provider: openai-codex (Pro)", + "Session: 85% remaining (15% used)", + ], + ) + with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \ + patch("agent.usage_pricing.estimate_usage_cost") as mock_cost: + mock_cost.return_value = MagicMock(amount_usd=None, status="included") + result = await runner._handle_usage_command(event) + + assert "📊 **Session Token Usage**" in result + assert "📈 **Account limits**" in result + assert "Provider: openai-codex (Pro)" in result + + @pytest.mark.asyncio + async def test_usage_command_uses_persisted_provider_when_agent_not_running(self, monkeypatch): + runner = _make_runner(SK) + runner._session_db = MagicMock() + runner._session_db.get_session.return_value = { + "billing_provider": "openai-codex", + "billing_base_url": "https://chatgpt.com/backend-api/codex", + } + session_entry = MagicMock() + session_entry.session_id = "sess-1" + 
runner.session_store.get_or_create_session.return_value = session_entry + runner.session_store.load_transcript.return_value = [ + {"role": "user", "content": "earlier"}, + ] + + calls = {} + + async def _fake_to_thread(fn, *args, **kwargs): + calls["args"] = args + calls["kwargs"] = kwargs + return fn(*args, **kwargs) + + monkeypatch.setattr("gateway.run.asyncio.to_thread", _fake_to_thread) + monkeypatch.setattr( + "gateway.run.fetch_account_usage", + lambda provider, base_url=None, api_key=None: object(), + ) + monkeypatch.setattr( + "gateway.run.render_account_usage_lines", + lambda snapshot, markdown=False: [ + "📈 **Account limits**", + "Provider: openai-codex (Pro)", + ], + ) + + event = MagicMock() + result = await runner._handle_usage_command(event) + + assert calls["args"] == ("openai-codex",) + assert calls["kwargs"]["base_url"] == "https://chatgpt.com/backend-api/codex" + assert "📊 **Session Info**" in result + assert "📈 **Account limits**" in result