diff --git a/gateway/run.py b/gateway/run.py index 576b84151e..bea75af013 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2804,20 +2804,12 @@ class GatewayRunner: skip_db=agent_persisted, ) - # Update session with actual prompt token count and model from the agent + # Token counts and model are now persisted by the agent directly. + # Keep only last_prompt_tokens here for context-window tracking and + # compression decisions. self.session_store.update_session( session_entry.session_key, - input_tokens=agent_result.get("input_tokens", 0), - output_tokens=agent_result.get("output_tokens", 0), - cache_read_tokens=agent_result.get("cache_read_tokens", 0), - cache_write_tokens=agent_result.get("cache_write_tokens", 0), last_prompt_tokens=agent_result.get("last_prompt_tokens", 0), - model=agent_result.get("model"), - estimated_cost_usd=agent_result.get("estimated_cost_usd"), - cost_status=agent_result.get("cost_status"), - cost_source=agent_result.get("cost_source"), - provider=agent_result.get("provider"), - base_url=agent_result.get("base_url"), ) # Auto voice reply: send TTS audio before the text response diff --git a/gateway/session.py b/gateway/session.py index bcbac7193b..c3b913ef81 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -778,66 +778,18 @@ class SessionStore: def update_session( self, session_key: str, - input_tokens: int = 0, - output_tokens: int = 0, - cache_read_tokens: int = 0, - cache_write_tokens: int = 0, last_prompt_tokens: int = None, - model: str = None, - estimated_cost_usd: Optional[float] = None, - cost_status: Optional[str] = None, - cost_source: Optional[str] = None, - provider: Optional[str] = None, - base_url: Optional[str] = None, ) -> None: - """Update a session's metadata after an interaction.""" - db_session_id = None - + """Update lightweight session metadata after an interaction.""" with self._lock: self._ensure_loaded_locked() if session_key in self._entries: entry = self._entries[session_key] entry.updated_at = _now() - # Direct assignment — the gateway receives cumulative totals - # from the cached agent, not per-call deltas. - entry.input_tokens = input_tokens - entry.output_tokens = output_tokens - entry.cache_read_tokens = cache_read_tokens - entry.cache_write_tokens = cache_write_tokens if last_prompt_tokens is not None: entry.last_prompt_tokens = last_prompt_tokens - if estimated_cost_usd is not None: - entry.estimated_cost_usd = estimated_cost_usd - if cost_status: - entry.cost_status = cost_status - entry.total_tokens = ( - entry.input_tokens - + entry.output_tokens - + entry.cache_read_tokens - + entry.cache_write_tokens - ) self._save() - db_session_id = entry.session_id - - if self._db and db_session_id: - try: - self._db.set_token_counts( - db_session_id, - input_tokens=input_tokens, - output_tokens=output_tokens, - cache_read_tokens=cache_read_tokens, - cache_write_tokens=cache_write_tokens, - estimated_cost_usd=estimated_cost_usd, - cost_status=cost_status, - cost_source=cost_source, - billing_provider=provider, - billing_base_url=base_url, - model=model, - absolute=True, - ) - except Exception as e: - logger.debug("Session DB operation failed: %s", e) def reset_session(self, session_key: str) -> Optional[SessionEntry]: """Force reset a session, creating a new session ID.""" diff --git a/run_agent.py b/run_agent.py index ed0de1fd83..8822821015 100644 --- a/run_agent.py +++ b/run_agent.py @@ -7221,11 +7221,13 @@ class AIAgent: self.session_cost_source = cost_result.source # Persist token counts to session DB for /insights. - # Gateway sessions persist via session_store.update_session() - # after run_conversation returns, so only persist here for - # CLI (and other non-gateway) platforms to avoid double-counting. - if (self._session_db and self.session_id - and getattr(self, 'platform', None) == 'cli'): + # Do this for every platform with a session_id so non-CLI + # sessions (gateway, cron, delegated runs) cannot lose + # token/accounting data if a higher-level persistence path + # is skipped or fails. Gateway/session-store writes use + # absolute totals, so they safely overwrite these per-call + # deltas instead of double-counting them. + if self._session_db and self.session_id: try: self._session_db.update_token_counts( self.session_id, diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 82281acc2e..77d4993ee3 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -825,43 +825,6 @@ class TestLastPromptTokens: store.update_session("k1", last_prompt_tokens=0) assert entry.last_prompt_tokens == 0 - def test_update_session_passes_model_to_db(self, tmp_path): - """Gateway session updates should forward the resolved model to SQLite.""" - config = GatewayConfig() - with patch("gateway.session.SessionStore._ensure_loaded"): - store = SessionStore(sessions_dir=tmp_path, config=config) - store._loaded = True - store._save = MagicMock() - store._db = MagicMock() - - from gateway.session import SessionEntry - from datetime import datetime - entry = SessionEntry( - session_key="k1", - session_id="s1", - created_at=datetime.now(), - updated_at=datetime.now(), - ) - store._entries = {"k1": entry} - - store.update_session("k1", model="openai/gpt-5.4") - - store._db.set_token_counts.assert_called_once_with( - "s1", - input_tokens=0, - output_tokens=0, - cache_read_tokens=0, - cache_write_tokens=0, - estimated_cost_usd=None, - cost_status=None, - cost_source=None, - billing_provider=None, - billing_base_url=None, - model="openai/gpt-5.4", - absolute=True, - ) - - class TestRewriteTranscriptPreservesReasoning: """rewrite_transcript must not drop reasoning fields from SQLite.""" diff --git a/tests/gateway/test_status_command.py b/tests/gateway/test_status_command.py index 1378ff1cb9..328b795c63 100644 --- a/tests/gateway/test_status_command.py +++ b/tests/gateway/test_status_command.py @@ -126,15 +126,5 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch): assert result == "ok" runner.session_store.update_session.assert_called_once_with( session_entry.session_key, - input_tokens=120, - output_tokens=45, - cache_read_tokens=0, - cache_write_tokens=0, last_prompt_tokens=80, - model="openai/test-model", - estimated_cost_usd=None, - cost_status=None, - cost_source=None, - provider=None, - base_url=None, ) diff --git a/tests/test_token_persistence_non_cli.py b/tests/test_token_persistence_non_cli.py new file mode 100644 index 0000000000..d25cf07ab8 --- /dev/null +++ b/tests/test_token_persistence_non_cli.py @@ -0,0 +1,62 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from run_agent import AIAgent + + +def _mock_response(*, usage: dict, content: str = "done"): + msg = SimpleNamespace(content=content, tool_calls=None) + choice = SimpleNamespace(message=msg, finish_reason="stop") + return SimpleNamespace( + choices=[choice], + model="test/model", + usage=SimpleNamespace(**usage), + ) + + +def _make_agent(session_db, *, platform: str): + with ( + patch("run_agent.get_tool_definitions", return_value=[]), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + agent = AIAgent( + api_key="test-key", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + session_db=session_db, + session_id=f"{platform}-session", + platform=platform, + ) + agent.client = MagicMock() + agent.client.chat.completions.create.return_value = _mock_response( + usage={ + "prompt_tokens": 11, + "completion_tokens": 7, + "total_tokens": 18, + } + ) + return agent + + +def test_run_conversation_persists_tokens_for_telegram_sessions(): + session_db = MagicMock() + agent = _make_agent(session_db, platform="telegram") + + result = agent.run_conversation("hello") + + assert result["final_response"] == "done" + session_db.update_token_counts.assert_called_once() + assert session_db.update_token_counts.call_args.args[0] == "telegram-session" + + +def test_run_conversation_persists_tokens_for_cron_sessions(): + session_db = MagicMock() + agent = _make_agent(session_db, platform="cron") + + result = agent.run_conversation("hello") + + assert result["final_response"] == "done" + session_db.update_token_counts.assert_called_once() + assert session_db.update_token_counts.call_args.args[0] == "cron-session"