diff --git a/gateway/run.py b/gateway/run.py index b75b0e1f0b..662e089413 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -5274,27 +5274,76 @@ class GatewayRunner: ) async def _handle_usage_command(self, event: MessageEvent) -> str: - """Handle /usage command -- show token usage for the session's last agent run.""" + """Handle /usage command -- show token usage for the current session. + + Checks both _running_agents (mid-turn) and _agent_cache (between turns) + so that rate limits, cost estimates, and detailed token breakdowns are + available whenever the user asks, not only while the agent is running. + """ source = event.source session_key = self._session_key_for_source(source) + # Try running agent first (mid-turn), then cached agent (between turns) agent = self._running_agents.get(session_key) + if not agent or agent is _AGENT_PENDING_SENTINEL: + _cache_lock = getattr(self, "_agent_cache_lock", None) + _cache = getattr(self, "_agent_cache", None) + if _cache_lock and _cache is not None: + with _cache_lock: + cached = _cache.get(session_key) + if cached: + agent = cached[0] + if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0: lines = [] - # Rate limits first (when available from provider headers) + # Rate limits (when available from provider headers) rl_state = agent.get_rate_limit_state() if rl_state and rl_state.has_data: from agent.rate_limit_tracker import format_rate_limit_compact lines.append(f"⏱️ **Rate Limits:** {format_rate_limit_compact(rl_state)}") lines.append("") - # Session token usage + # Session token usage — detailed breakdown matching CLI + input_tokens = getattr(agent, "session_input_tokens", 0) or 0 + output_tokens = getattr(agent, "session_output_tokens", 0) or 0 + cache_read = getattr(agent, "session_cache_read_tokens", 0) or 0 + cache_write = getattr(agent, "session_cache_write_tokens", 0) or 0 + lines.append("📊 **Session Token Usage**") - lines.append(f"Prompt (input): {agent.session_prompt_tokens:,}") - lines.append(f"Completion (output): {agent.session_completion_tokens:,}") + lines.append(f"Model: `{agent.model}`") + lines.append(f"Input tokens: {input_tokens:,}") + if cache_read: + lines.append(f"Cache read tokens: {cache_read:,}") + if cache_write: + lines.append(f"Cache write tokens: {cache_write:,}") + lines.append(f"Output tokens: {output_tokens:,}") lines.append(f"Total: {agent.session_total_tokens:,}") lines.append(f"API calls: {agent.session_api_calls}") + + # Cost estimation + try: + from agent.usage_pricing import CanonicalUsage, estimate_usage_cost + cost_result = estimate_usage_cost( + agent.model, + CanonicalUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_read_tokens=cache_read, + cache_write_tokens=cache_write, + ), + provider=getattr(agent, "provider", None), + base_url=getattr(agent, "base_url", None), + ) + if cost_result.amount_usd is not None: + prefix = "~" if cost_result.status == "estimated" else "" + lines.append(f"Cost: {prefix}${float(cost_result.amount_usd):.4f}") + elif cost_result.status == "included": + lines.append("Cost: included") + except Exception: + pass + + # Context window and compressions ctx = agent.context_compressor if ctx.last_prompt_tokens: pct = min(100, ctx.last_prompt_tokens / ctx.context_length * 100) if ctx.context_length else 0 @@ -5304,7 +5353,7 @@ class GatewayRunner: return "\n".join(lines) - # No running agent -- check session history for a rough count + # No agent at all -- check session history for a rough count session_entry = self.session_store.get_or_create_session(source) history = self.session_store.load_transcript(session_entry.session_id) if history: @@ -5315,7 +5364,7 @@ class GatewayRunner: f"📊 **Session Info**\n" f"Messages: {len(msgs)}\n" f"Estimated context: ~{approx:,} tokens\n" - f"_(Detailed usage available during active conversations)_" + f"_(Detailed usage available after the first agent response)_" ) return "No usage data available for this session." diff --git a/tests/gateway/test_usage_command.py b/tests/gateway/test_usage_command.py new file mode 100644 index 0000000000..2915810891 --- /dev/null +++ b/tests/gateway/test_usage_command.py @@ -0,0 +1,177 @@ +"""Tests for gateway /usage command — agent cache lookup and output fields.""" + +import asyncio +import threading +from unittest.mock import MagicMock, patch + +import pytest + + +def _make_mock_agent(**overrides): + """Create a mock AIAgent with realistic session counters.""" + agent = MagicMock() + defaults = { + "model": "anthropic/claude-sonnet-4.6", + "provider": "openrouter", + "base_url": None, + "session_total_tokens": 50_000, + "session_api_calls": 5, + "session_prompt_tokens": 40_000, + "session_completion_tokens": 10_000, + "session_input_tokens": 35_000, + "session_output_tokens": 10_000, + "session_cache_read_tokens": 5_000, + "session_cache_write_tokens": 2_000, + } + defaults.update(overrides) + for k, v in defaults.items(): + setattr(agent, k, v) + + # Rate limit state + rl = MagicMock() + rl.has_data = True + agent.get_rate_limit_state.return_value = rl + + # Context compressor + ctx = MagicMock() + ctx.last_prompt_tokens = 30_000 + ctx.context_length = 200_000 + ctx.compression_count = 1 + agent.context_compressor = ctx + + return agent + + +def _make_runner(session_key, agent=None, cached_agent=None): + """Build a bare GatewayRunner with just the fields _handle_usage_command needs.""" + from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL + + runner = object.__new__(GatewayRunner) + runner._running_agents = {} + runner._running_agents_ts = {} + runner._agent_cache = {} + runner._agent_cache_lock = threading.Lock() + runner.session_store = MagicMock() + + if agent is not None: + runner._running_agents[session_key] = agent + + if cached_agent is not None: + runner._agent_cache[session_key] = (cached_agent, "sig") + + # Wire helper + runner._session_key_for_source = MagicMock(return_value=session_key) + + return runner + + +SK = "agent:main:telegram:private:12345" + + +class TestUsageCachedAgent: + """The main fix: /usage should find agents in _agent_cache between turns.""" + + @pytest.mark.asyncio + async def test_cached_agent_shows_detailed_usage(self): + agent = _make_mock_agent() + runner = _make_runner(SK, cached_agent=agent) + event = MagicMock() + + with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \ + patch("agent.usage_pricing.estimate_usage_cost") as mock_cost: + mock_cost.return_value = MagicMock(amount_usd=0.1234, status="estimated") + result = await runner._handle_usage_command(event) + + assert "claude-sonnet-4.6" in result + assert "35,000" in result # input tokens + assert "10,000" in result # output tokens + assert "5,000" in result # cache read + assert "2,000" in result # cache write + assert "50,000" in result # total + assert "$0.1234" in result + assert "30,000" in result # context + assert "Compressions: 1" in result + + @pytest.mark.asyncio + async def test_running_agent_preferred_over_cache(self): + """When agent is in both dicts, the running one wins.""" + running = _make_mock_agent(session_api_calls=10, session_total_tokens=80_000) + cached = _make_mock_agent(session_api_calls=5, session_total_tokens=50_000) + runner = _make_runner(SK, agent=running, cached_agent=cached) + event = MagicMock() + + with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \ + patch("agent.usage_pricing.estimate_usage_cost") as mock_cost: + mock_cost.return_value = MagicMock(amount_usd=None, status="unknown") + result = await runner._handle_usage_command(event) + + assert "80,000" in result # running agent's total + assert "API calls: 10" in result + + @pytest.mark.asyncio + async def test_sentinel_skipped_uses_cache(self): + """PENDING sentinel in _running_agents should fall through to cache.""" + from gateway.run import _AGENT_PENDING_SENTINEL + + cached = _make_mock_agent() + runner = _make_runner(SK, cached_agent=cached) + runner._running_agents[SK] = _AGENT_PENDING_SENTINEL + event = MagicMock() + + with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \ + patch("agent.usage_pricing.estimate_usage_cost") as mock_cost: + mock_cost.return_value = MagicMock(amount_usd=None, status="unknown") + result = await runner._handle_usage_command(event) + + assert "claude-sonnet-4.6" in result + assert "Session Token Usage" in result + + @pytest.mark.asyncio + async def test_no_agent_anywhere_falls_to_history(self): + """No running or cached agent → rough estimate from transcript.""" + runner = _make_runner(SK) + event = MagicMock() + + session_entry = MagicMock() + session_entry.session_id = "sess123" + runner.session_store.get_or_create_session.return_value = session_entry + runner.session_store.load_transcript.return_value = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi there"}, + ] + + with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=500): + result = await runner._handle_usage_command(event) + + assert "Session Info" in result + assert "Messages: 2" in result + assert "~500" in result + + @pytest.mark.asyncio + async def test_cache_read_write_hidden_when_zero(self): + """Cache token lines should be omitted when zero.""" + agent = _make_mock_agent(session_cache_read_tokens=0, session_cache_write_tokens=0) + runner = _make_runner(SK, cached_agent=agent) + event = MagicMock() + + with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \ + patch("agent.usage_pricing.estimate_usage_cost") as mock_cost: + mock_cost.return_value = MagicMock(amount_usd=None, status="unknown") + result = await runner._handle_usage_command(event) + + assert "Cache read" not in result + assert "Cache write" not in result + + @pytest.mark.asyncio + async def test_cost_included_status(self): + """Subscription-included providers show 'included' instead of dollar amount.""" + agent = _make_mock_agent(provider="openai-codex") + runner = _make_runner(SK, cached_agent=agent) + event = MagicMock() + + with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \ + patch("agent.usage_pricing.estimate_usage_cost") as mock_cost: + mock_cost.return_value = MagicMock(amount_usd=None, status="included") + result = await runner._handle_usage_command(event) + + assert "Cost: included" in result