diff --git a/gateway/run.py b/gateway/run.py index d0dd00f5d3..e7bfb62576 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1114,6 +1114,9 @@ class GatewayRunner: # let the adapter-level batching/queueing logic absorb them. _quick_key = build_session_key(source) if _quick_key in self._running_agents: + if event.get_command() == "status": + return await self._handle_status_command(event) + if event.message_type == MessageType.PHOTO: logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20]) adapter = self.adapters.get(source.platform) @@ -1822,6 +1825,8 @@ class GatewayRunner: # Update session with actual prompt token count and model from the agent self.session_store.update_session( session_entry.session_key, + input_tokens=agent_result.get("input_tokens", 0), + output_tokens=agent_result.get("output_tokens", 0), last_prompt_tokens=agent_result.get("last_prompt_tokens", 0), model=agent_result.get("model"), ) @@ -4171,11 +4176,15 @@ class GatewayRunner: # Return final response, or a message if something went wrong final_response = result.get("final_response") - # Extract last actual prompt token count from the agent's compressor + # Extract actual token counts from the agent instance used for this run _last_prompt_toks = 0 + _input_toks = 0 + _output_toks = 0 _agent = agent_holder[0] if _agent and hasattr(_agent, "context_compressor"): _last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0) + _input_toks = getattr(_agent, "session_prompt_tokens", 0) + _output_toks = getattr(_agent, "session_completion_tokens", 0) _resolved_model = getattr(_agent, "model", None) if _agent else None if not final_response: @@ -4187,6 +4196,8 @@ class GatewayRunner: "tools": tools_holder[0] or [], "history_offset": len(agent_history), "last_prompt_tokens": _last_prompt_toks, + "input_tokens": _input_toks, + "output_tokens": _output_toks, "model": _resolved_model, } @@ -4250,6 +4261,8 @@ class GatewayRunner: "tools": 
tools_holder[0] or [], "history_offset": len(agent_history), "last_prompt_tokens": _last_prompt_toks, + "input_tokens": _input_toks, + "output_tokens": _output_toks, "model": _resolved_model, "session_id": effective_session_id, }
diff --git a/tests/gateway/test_status_command.py b/tests/gateway/test_status_command.py
new file mode 100644
index 0000000000..1c22543f75
--- /dev/null
+++ b/tests/gateway/test_status_command.py
@@ -0,0 +1,146 @@
+"""Tests for gateway /status behavior and token persistence."""
+
+from datetime import datetime
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from gateway.config import GatewayConfig, Platform, PlatformConfig
+from gateway.platforms.base import MessageEvent
+from gateway.session import SessionEntry, SessionSource, build_session_key
+
+
+def _make_source() -> SessionSource:
+    """Return the fixed Telegram DM source used by the helpers in this module."""
+    return SessionSource(
+        platform=Platform.TELEGRAM,
+        user_id="u1",
+        chat_id="c1",
+        user_name="tester",
+        chat_type="dm",
+    )
+
+
+def _make_event(text: str) -> MessageEvent:
+    """Wrap ``text`` in a MessageEvent sent from ``_make_source()``."""
+    return MessageEvent(
+        text=text,
+        source=_make_source(),
+        message_id="m1",
+    )
+
+
+def _make_session_entry(**overrides) -> SessionEntry:
+    """Build the SessionEntry keyed to ``_make_source()``.
+
+    ``overrides`` are passed through as extra SessionEntry fields
+    (for example ``total_tokens=321``).
+    """
+    return SessionEntry(
+        session_key=build_session_key(_make_source()),
+        session_id="sess-1",
+        created_at=datetime.now(),
+        updated_at=datetime.now(),
+        platform=Platform.TELEGRAM,
+        chat_type="dm",
+        **overrides,
+    )
+
+
+def _make_runner(session_entry: SessionEntry):
+    """Build a GatewayRunner with ``__init__`` bypassed and collaborators stubbed.
+
+    The session store is a MagicMock whose ``get_or_create_session`` returns
+    ``session_entry``; the only adapter is a mocked Telegram adapter.
+    """
+    from gateway.run import GatewayRunner
+
+    # object.__new__ skips GatewayRunner.__init__, so the attributes these
+    # tests rely on are assigned by hand below.
+    runner = object.__new__(GatewayRunner)
+    runner.config = GatewayConfig(
+        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
+    )
+    adapter = MagicMock()
+    adapter.send = AsyncMock()
+    runner.adapters = {Platform.TELEGRAM: adapter}
+    runner._voice_mode = {}
+    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
+    runner.session_store = MagicMock()
+    runner.session_store.get_or_create_session.return_value = session_entry
+    runner.session_store.load_transcript.return_value = []
+    runner.session_store.has_any_sessions.return_value = True
+    runner.session_store.append_to_transcript = MagicMock()
+    runner.session_store.rewrite_transcript = MagicMock()
+    runner.session_store.update_session = MagicMock()
+    runner._running_agents = {}
+    runner._pending_messages = {}
+    runner._pending_approvals = {}
+    runner._session_db = None
+    runner._reasoning_config = None
+    runner._provider_routing = {}
+    runner._fallback_model = None
+    runner._show_reasoning = False
+    runner._is_user_authorized = lambda _source: True
+    runner._set_session_env = lambda _context: None
+    runner._should_send_voice_reply = lambda *_args, **_kwargs: False
+    runner._send_voice_reply = AsyncMock()
+    runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
+    runner._emit_gateway_run_progress = AsyncMock()
+    return runner
+
+
+@pytest.mark.asyncio
+async def test_status_command_reports_running_agent_without_interrupt():
+    """/status sent while an agent is running is answered without interrupting it."""
+    session_entry = _make_session_entry(total_tokens=321)
+    runner = _make_runner(session_entry)
+    running_agent = MagicMock()
+    runner._running_agents[session_entry.session_key] = running_agent
+
+    result = await runner._handle_message(_make_event("/status"))
+
+    assert "**Tokens:** 321" in result
+    assert "**Agent Running:** Yes ⚡" in result
+    running_agent.interrupt.assert_not_called()
+    assert runner._pending_messages == {}
+
+
+@pytest.mark.asyncio
+async def test_handle_message_persists_agent_token_counts(monkeypatch):
+    """Token counts returned by ``_run_agent`` are forwarded to ``update_session``."""
+    import gateway.run as gateway_run
+
+    session_entry = _make_session_entry()
+    runner = _make_runner(session_entry)
+    runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
+    runner._run_agent = AsyncMock(
+        return_value={
+            "final_response": "ok",
+            "messages": [],
+            "tools": [],
+            "history_offset": 0,
+            "last_prompt_tokens": 80,
+            "input_tokens": 120,
+            "output_tokens": 45,
+            "model": "openai/test-model",
+        }
+    )
+
+    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
+    monkeypatch.setattr(
+        "agent.model_metadata.get_model_context_length",
+        lambda *_args, **_kwargs: 100000,
+    )
+
+    result = await runner._handle_message(_make_event("hello"))
+
+    assert result == "ok"
+    runner.session_store.update_session.assert_called_once_with(
+        session_entry.session_key,
+        input_tokens=120,
+        output_tokens=45,
+        last_prompt_tokens=80,
+        model="openai/test-model",
+    )