def set_token_counts(
    self,
    session_id: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
    model: Optional[str] = None,
    cache_read_tokens: int = 0,
    cache_write_tokens: int = 0,
    reasoning_tokens: int = 0,
    estimated_cost_usd: Optional[float] = None,
    actual_cost_usd: Optional[float] = None,
    cost_status: Optional[str] = None,
    cost_source: Optional[str] = None,
    pricing_version: Optional[str] = None,
    billing_provider: Optional[str] = None,
    billing_base_url: Optional[str] = None,
    billing_mode: Optional[str] = None,
) -> None:
    """Set a session's token counters to absolute values (not increments).

    Use this when the caller provides cumulative totals from a completed
    conversation run (e.g. the gateway, where the cached agent's
    session_prompt_tokens already reflects the running total).

    Column update rules for the row whose ``id`` equals *session_id*:

    * ``input_tokens``, ``output_tokens``, ``cache_read_tokens``,
      ``cache_write_tokens`` and ``reasoning_tokens`` are always
      overwritten with the given values.
    * ``estimated_cost_usd``, ``actual_cost_usd``, ``cost_status``,
      ``cost_source`` and ``pricing_version`` are overwritten only when a
      non-``None`` value is passed; ``None`` keeps the stored value.
    * ``billing_provider``, ``billing_base_url``, ``billing_mode`` and
      ``model`` are filled in only when the stored value is NULL; an
      already-stored value always wins.

    If no row matches *session_id* the UPDATE affects nothing and no
    error is raised.
    """
    sql = """UPDATE sessions SET
           input_tokens = ?,
           output_tokens = ?,
           cache_read_tokens = ?,
           cache_write_tokens = ?,
           reasoning_tokens = ?,
           estimated_cost_usd = COALESCE(?, estimated_cost_usd),
           actual_cost_usd = COALESCE(?, actual_cost_usd),
           cost_status = COALESCE(?, cost_status),
           cost_source = COALESCE(?, cost_source),
           pricing_version = COALESCE(?, pricing_version),
           billing_provider = COALESCE(billing_provider, ?),
           billing_base_url = COALESCE(billing_base_url, ?),
           billing_mode = COALESCE(billing_mode, ?),
           model = COALESCE(model, ?)
           WHERE id = ?"""
    # Order must mirror the placeholders in the statement above.
    params = (
        input_tokens,
        output_tokens,
        cache_read_tokens,
        cache_write_tokens,
        reasoning_tokens,
        # A missing estimate must not erase one recorded by an earlier
        # call, so None falls back to the stored value (same rule as the
        # other optional cost fields).
        estimated_cost_usd,
        actual_cost_usd,
        cost_status,
        cost_source,
        pricing_version,
        billing_provider,
        billing_base_url,
        billing_mode,
        model,
        session_id,
    )
    with self._lock:
        self._conn.execute(sql, params)
        self._conn.commit()