fix(codex): record app-server token usage in session accounting

This commit is contained in:
JP Lew 2026-06-09 12:52:52 +05:30 committed by Teknium
parent 85852b71d8
commit cb4cc08b0a
4 changed files with 266 additions and 1 deletions

View file

@ -25,6 +25,154 @@ from typing import Any, Dict, List
logger = logging.getLogger(__name__)
def _coerce_usage_int(value: Any) -> int:
if isinstance(value, bool):
return 0
if isinstance(value, int):
return max(value, 0)
if isinstance(value, float):
return max(int(value), 0)
if isinstance(value, str):
try:
return max(int(value), 0)
except ValueError:
return 0
return 0
def _record_codex_app_server_usage(agent, turn) -> dict[str, Any]:
"""Translate Codex app-server token usage into Hermes accounting.
Codex app-server reports usage via thread/tokenUsage/updated as:
inputTokens, cachedInputTokens, outputTokens, reasoningOutputTokens,
totalTokens.
Hermes' canonical prompt bucket includes uncached input + cached input.
The Codex app-server protocol does not currently expose cache-write tokens,
so that bucket remains zero on this runtime.
Even when Codex omits usage for a turn, Hermes should still count that turn
as one API call for session/status accounting.
"""
agent.session_api_calls += 1
usage = getattr(turn, "token_usage_last", None)
if not isinstance(usage, dict) or not usage:
if agent._session_db and agent.session_id:
try:
if not agent._session_db_created:
agent._ensure_db_session()
agent._session_db.update_token_counts(
agent.session_id,
model=agent.model,
api_call_count=1,
)
except Exception as exc:
logger.debug(
"Codex app-server api-call persistence failed (session=%s): %s",
agent.session_id, exc,
)
return {}
from agent.usage_pricing import CanonicalUsage, estimate_usage_cost
input_tokens = _coerce_usage_int(usage.get("inputTokens"))
cache_read_tokens = _coerce_usage_int(usage.get("cachedInputTokens"))
output_tokens = _coerce_usage_int(usage.get("outputTokens"))
reasoning_tokens = _coerce_usage_int(usage.get("reasoningOutputTokens"))
reported_total = _coerce_usage_int(usage.get("totalTokens"))
canonical_usage = CanonicalUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=0,
reasoning_tokens=reasoning_tokens,
raw_usage=usage,
)
prompt_tokens = canonical_usage.prompt_tokens
completion_tokens = canonical_usage.output_tokens
total_tokens = reported_total or canonical_usage.total_tokens
usage_dict = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"input_tokens": canonical_usage.input_tokens,
"output_tokens": canonical_usage.output_tokens,
"cache_read_tokens": canonical_usage.cache_read_tokens,
"cache_write_tokens": canonical_usage.cache_write_tokens,
"reasoning_tokens": canonical_usage.reasoning_tokens,
}
compressor = getattr(agent, "context_compressor", None)
if compressor is not None:
try:
compressor.update_from_response(usage_dict)
context_window = getattr(turn, "model_context_window", None)
if isinstance(context_window, int) and context_window > 0:
compressor.context_length = context_window
except Exception:
logger.debug("codex app-server usage update failed", exc_info=True)
agent.session_prompt_tokens += prompt_tokens
agent.session_completion_tokens += completion_tokens
agent.session_total_tokens += total_tokens
agent.session_input_tokens += canonical_usage.input_tokens
agent.session_output_tokens += canonical_usage.output_tokens
agent.session_cache_read_tokens += canonical_usage.cache_read_tokens
agent.session_cache_write_tokens += canonical_usage.cache_write_tokens
agent.session_reasoning_tokens += canonical_usage.reasoning_tokens
cost_result = estimate_usage_cost(
agent.model,
canonical_usage,
provider=agent.provider,
base_url=agent.base_url,
api_key=getattr(agent, "api_key", ""),
)
if cost_result.amount_usd is not None:
agent.session_estimated_cost_usd += float(cost_result.amount_usd)
agent.session_cost_status = cost_result.status
agent.session_cost_source = cost_result.source
if agent._session_db and agent.session_id:
try:
if not agent._session_db_created:
agent._ensure_db_session()
agent._session_db.update_token_counts(
agent.session_id,
input_tokens=canonical_usage.input_tokens,
output_tokens=canonical_usage.output_tokens,
cache_read_tokens=canonical_usage.cache_read_tokens,
cache_write_tokens=canonical_usage.cache_write_tokens,
reasoning_tokens=canonical_usage.reasoning_tokens,
estimated_cost_usd=float(cost_result.amount_usd)
if cost_result.amount_usd is not None else None,
cost_status=cost_result.status,
cost_source=cost_result.source,
billing_provider=agent.provider,
billing_base_url=agent.base_url,
billing_mode="subscription_included"
if cost_result.status == "included" else None,
model=agent.model,
api_call_count=1,
)
except Exception as exc:
logger.debug(
"Codex app-server token persistence failed (session=%s, tokens=%d): %s",
agent.session_id, total_tokens, exc,
)
return {
**usage_dict,
"last_prompt_tokens": prompt_tokens,
"estimated_cost_usd": float(cost_result.amount_usd)
if cost_result.amount_usd is not None else None,
"cost_status": cost_result.status,
"cost_source": cost_result.source,
}
def run_codex_app_server_turn(
agent,
*,
@ -120,6 +268,8 @@ def run_codex_app_server_turn(
agent._iters_since_skill = (
getattr(agent, "_iters_since_skill", 0) + turn.tool_iterations
)
usage_result = _record_codex_app_server_usage(agent, turn)
api_calls = 1
# Now check the skill nudge AFTER iters were incremented — same
# pattern the chat_completions path uses (line ~15432).
@ -164,12 +314,13 @@ def run_codex_app_server_turn(
return {
"final_response": turn.final_text,
"messages": messages,
"api_calls": 1, # one app-server "turn" maps to one logical API call
"api_calls": api_calls,
"completed": not turn.interrupted and turn.error is None,
"partial": turn.interrupted or turn.error is not None,
"error": turn.error,
"codex_thread_id": turn.thread_id,
"codex_turn_id": turn.turn_id,
**usage_result,
}

View file

@ -72,6 +72,9 @@ class TurnResult:
error: Optional[str] = None # Set if turn ended in a non-recoverable error
turn_id: Optional[str] = None
thread_id: Optional[str] = None
token_usage_last: Optional[dict[str, Any]] = None
token_usage_total: Optional[dict[str, Any]] = None
model_context_window: Optional[int] = None
# Hint to the caller that the underlying codex subprocess is likely
# wedged (turn-level timeout fired, post-tool watchdog tripped, or
# token-refresh failure killed the child). The caller should retire
@ -501,6 +504,7 @@ class CodexAppServerSession:
pending = self._client.take_notification(timeout=0)
if pending is None:
break
_apply_token_usage_notification(result, pending)
self._track_pending_file_change(pending)
proj = projector.project(pending)
if proj.messages:
@ -536,6 +540,8 @@ class CodexAppServerSession:
except Exception: # pragma: no cover - display callback
logger.debug("on_event callback raised", exc_info=True)
_apply_token_usage_notification(result, note)
# Track in-progress fileChange items so the approval bridge
# can surface a real change summary when codex requests
# approval (the approval params themselves don't carry the
@ -802,6 +808,30 @@ class CodexAppServerSession:
return cached
def _apply_token_usage_notification(result: TurnResult, note: dict) -> None:
"""Capture Codex app-server token usage updates for caller accounting.
Codex does not put token usage on turn/completed. It emits a separate
thread/tokenUsage/updated notification containing cumulative totals and
the latest turn breakdown.
"""
if not isinstance(note, dict) or note.get("method") != "thread/tokenUsage/updated":
return
params = note.get("params") or {}
token_usage = params.get("tokenUsage") or {}
if not isinstance(token_usage, dict):
return
last = token_usage.get("last")
total = token_usage.get("total")
if isinstance(last, dict):
result.token_usage_last = dict(last)
if isinstance(total, dict):
result.token_usage_total = dict(total)
window = token_usage.get("modelContextWindow")
if isinstance(window, int) and window > 0:
result.model_context_window = window
def _approval_choice_to_codex_decision(choice: str) -> str:
"""Map Hermes approval choices onto codex's CommandExecutionApprovalDecision
/ FileChangeApprovalDecision wire values.

View file

@ -196,6 +196,40 @@ class TestRunTurn:
# turn_id propagated for downstream session-DB linkage
assert r.turn_id == "turn-fake-001"
def test_token_usage_notification_is_captured(self):
client = FakeClient()
client.queue_notification(
"thread/tokenUsage/updated",
threadId="thread-fake-001",
turnId="turn-fake-001",
tokenUsage={
"last": {
"totalTokens": 130,
"inputTokens": 80,
"cachedInputTokens": 20,
"outputTokens": 25,
"reasoningOutputTokens": 5,
},
"total": {
"totalTokens": 500,
"inputTokens": 300,
"cachedInputTokens": 75,
"outputTokens": 100,
"reasoningOutputTokens": 25,
},
"modelContextWindow": 200000,
},
)
client.queue_notification(
"turn/completed",
threadId="t",
turn={"id": "tu1", "status": "completed", "error": None},
)
r = make_session(client).run_turn("hi", turn_timeout=2.0)
assert r.token_usage_last["totalTokens"] == 130
assert r.token_usage_total["totalTokens"] == 500
assert r.model_context_window == 200000
def test_rich_content_turn_is_collapsed_to_text_payload(self):
client = FakeClient()
client.queue_notification(

View file

@ -84,6 +84,56 @@ class TestRunConversationCodexPath:
assert result["codex_thread_id"] == "thread-stub-1"
assert result["codex_turn_id"] == "turn-stub-1"
def test_codex_app_server_token_usage_updates_session_accounting(self, monkeypatch):
def fake_run_turn(self, user_input: str, **kwargs):
return TurnResult(
final_text="done",
projected_messages=[{"role": "assistant", "content": "done"}],
turn_id="turn-usage-1",
thread_id="thread-usage-1",
token_usage_last={
"totalTokens": 130,
"inputTokens": 80,
"cachedInputTokens": 20,
"outputTokens": 25,
"reasoningOutputTokens": 5,
},
model_context_window=200000,
)
monkeypatch.setattr(CodexAppServerSession, "run_turn", fake_run_turn)
monkeypatch.setattr(
CodexAppServerSession, "ensure_started", lambda self: "thread-usage-1"
)
agent = _make_codex_agent()
with patch.object(agent, "_spawn_background_review", return_value=None):
result = agent.run_conversation("hello")
assert result["api_calls"] == 1
assert result["prompt_tokens"] == 100
assert result["completion_tokens"] == 25
assert result["total_tokens"] == 130
assert result["input_tokens"] == 80
assert result["output_tokens"] == 25
assert result["cache_read_tokens"] == 20
assert result["cache_write_tokens"] == 0
assert result["reasoning_tokens"] == 5
assert result["last_prompt_tokens"] == 100
assert agent.session_api_calls == 1
assert agent.session_prompt_tokens == 100
assert agent.session_completion_tokens == 25
assert agent.session_total_tokens == 130
assert agent.session_input_tokens == 80
assert agent.session_output_tokens == 25
assert agent.session_cache_read_tokens == 20
assert agent.session_cache_write_tokens == 0
assert agent.session_reasoning_tokens == 5
assert agent.context_compressor.last_prompt_tokens == 100
assert agent.context_compressor.last_completion_tokens == 25
assert agent.context_compressor.last_total_tokens == 130
assert agent.context_compressor.context_length == 200000
def test_projected_messages_are_spliced(self, fake_session):
agent = _make_codex_agent()
with patch.object(agent, "_spawn_background_review", return_value=None):