fix(gateway): persist runtime session metadata

This commit is contained in:
sgaofen 2026-04-23 00:59:28 -07:00
parent d1ce358646
commit 4776a34711
5 changed files with 75 additions and 9 deletions

View file

@@ -4751,6 +4751,9 @@ class GatewayRunner:
self.session_store.update_session(
session_entry.session_key,
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
model=agent_result.get("model"),
model_provider=agent_result.get("provider"),
context_tokens=agent_result.get("context_length"),
)
# Auto voice reply: send TTS audio before the text response
@@ -10114,6 +10117,12 @@ class GatewayRunner:
_input_toks = getattr(_agent, "session_prompt_tokens", 0)
_output_toks = getattr(_agent, "session_completion_tokens", 0)
_resolved_model = getattr(_agent, "model", None) if _agent else None
_resolved_provider = getattr(_agent, "provider", None) if _agent else None
_context_length = (
getattr(getattr(_agent, "context_compressor", None), "context_length", 0)
if _agent
else 0
)
if not final_response:
error_msg = f"⚠️ {result['error']}" if result.get("error") else ""
@@ -10128,7 +10137,9 @@ class GatewayRunner:
"last_prompt_tokens": _last_prompt_toks,
"input_tokens": _input_toks,
"output_tokens": _output_toks,
"model": _resolved_model,
"model": _resolved_model or result.get("model"),
"provider": _resolved_provider or result.get("provider"),
"context_length": _context_length or result.get("context_length", 0),
}
# Scan tool results for MEDIA:<path> tags that need to be delivered
@@ -10217,7 +10228,9 @@ class GatewayRunner:
"last_prompt_tokens": _last_prompt_toks,
"input_tokens": _input_toks,
"output_tokens": _output_toks,
"model": _resolved_model,
"model": _resolved_model or result.get("model"),
"provider": _resolved_provider or result.get("provider"),
"context_length": _context_length or result.get("context_length", 0),
"session_id": effective_session_id,
"response_previewed": result.get("response_previewed", False),
}

View file

@@ -360,6 +360,9 @@ class SessionEntry:
# Last API-reported prompt tokens (for accurate compression pre-check)
last_prompt_tokens: int = 0
model: Optional[str] = None
model_provider: Optional[str] = None
context_tokens: int = 0
# Set when a session was created because the previous one expired;
# consumed once by the message handler to inject a notice into context
@@ -405,6 +408,9 @@ class SessionEntry:
"cache_write_tokens": self.cache_write_tokens,
"total_tokens": self.total_tokens,
"last_prompt_tokens": self.last_prompt_tokens,
"model": self.model,
"model_provider": self.model_provider,
"context_tokens": self.context_tokens,
"estimated_cost_usd": self.estimated_cost_usd,
"cost_status": self.cost_status,
"memory_flushed": self.memory_flushed,
@@ -457,6 +463,9 @@ class SessionEntry:
cache_write_tokens=data.get("cache_write_tokens", 0),
total_tokens=data.get("total_tokens", 0),
last_prompt_tokens=data.get("last_prompt_tokens", 0),
model=data.get("model"),
model_provider=data.get("model_provider"),
context_tokens=data.get("context_tokens", 0),
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
cost_status=data.get("cost_status", "unknown"),
memory_flushed=data.get("memory_flushed", False),
@@ -840,6 +849,9 @@ class SessionStore:
self,
session_key: str,
last_prompt_tokens: int = None,
model: Optional[str] = None,
model_provider: Optional[str] = None,
context_tokens: int = None,
) -> None:
"""Update lightweight session metadata after an interaction."""
with self._lock:
@@ -850,6 +862,12 @@ class SessionStore:
entry.updated_at = _now()
if last_prompt_tokens is not None:
entry.last_prompt_tokens = last_prompt_tokens
if model is not None:
entry.model = model
if model_provider is not None:
entry.model_provider = model_provider
if context_tokens is not None:
entry.context_tokens = context_tokens
self._save()
def suspend_session(self, session_key: str) -> bool:

View file

@@ -11865,6 +11865,7 @@ class AIAgent:
"estimated_cost_usd": self.session_estimated_cost_usd,
"cost_status": self.session_cost_status,
"cost_source": self.session_cost_source,
"context_length": getattr(self.context_compressor, "context_length", 0) or 0,
}
# If a /steer landed after the final assistant turn (no more tool
# batches to drain into), hand it back to the caller so it can be

View file

@@ -944,7 +944,7 @@ class TestLastPromptTokens:
assert entry.last_prompt_tokens == 0
def test_session_entry_roundtrip(self):
"""last_prompt_tokens should survive serialization/deserialization."""
"""Runtime session metadata should survive serialization/deserialization."""
from gateway.session import SessionEntry
from datetime import datetime
entry = SessionEntry(
@@ -953,14 +953,23 @@ class TestLastPromptTokens:
created_at=datetime.now(),
updated_at=datetime.now(),
last_prompt_tokens=42000,
model="gpt-5.4",
model_provider="openai",
context_tokens=200000,
)
d = entry.to_dict()
assert d["last_prompt_tokens"] == 42000
assert d["model"] == "gpt-5.4"
assert d["model_provider"] == "openai"
assert d["context_tokens"] == 200000
restored = SessionEntry.from_dict(d)
assert restored.last_prompt_tokens == 42000
assert restored.model == "gpt-5.4"
assert restored.model_provider == "openai"
assert restored.context_tokens == 200000
def test_session_entry_from_old_data(self):
"""Old session data without last_prompt_tokens should default to 0."""
"""Old session data without runtime metadata should keep safe defaults."""
from gateway.session import SessionEntry
data = {
"session_key": "test",
@@ -974,9 +983,12 @@ class TestLastPromptTokens:
}
entry = SessionEntry.from_dict(data)
assert entry.last_prompt_tokens == 0
assert entry.model is None
assert entry.model_provider is None
assert entry.context_tokens == 0
def test_update_session_sets_last_prompt_tokens(self, tmp_path):
"""update_session should store the actual prompt token count."""
"""update_session should store the actual prompt token count and runtime metadata."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
@@ -994,11 +1006,20 @@ class TestLastPromptTokens:
)
store._entries = {"k1": entry}
store.update_session("k1", last_prompt_tokens=85000)
store.update_session(
"k1",
last_prompt_tokens=85000,
model="claude-sonnet-4-6",
model_provider="anthropic",
context_tokens=400000,
)
assert entry.last_prompt_tokens == 85000
assert entry.model == "claude-sonnet-4-6"
assert entry.model_provider == "anthropic"
assert entry.context_tokens == 400000
def test_update_session_none_does_not_change(self, tmp_path):
"""update_session with default (None) should not change last_prompt_tokens."""
"""update_session with default (None) should not change stored metadata."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
@@ -1014,14 +1035,20 @@ class TestLastPromptTokens:
created_at=datetime.now(),
updated_at=datetime.now(),
last_prompt_tokens=50000,
model="gpt-4o",
model_provider="openrouter",
context_tokens=128000,
)
store._entries = {"k1": entry}
store.update_session("k1") # No last_prompt_tokens arg
assert entry.last_prompt_tokens == 50000 # unchanged
assert entry.model == "gpt-4o"
assert entry.model_provider == "openrouter"
assert entry.context_tokens == 128000
def test_update_session_zero_resets(self, tmp_path):
"""update_session with last_prompt_tokens=0 should reset the field."""
"""update_session accepts zero values for numeric runtime fields."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
@@ -1037,11 +1064,13 @@ class TestLastPromptTokens:
created_at=datetime.now(),
updated_at=datetime.now(),
last_prompt_tokens=85000,
context_tokens=200000,
)
store._entries = {"k1": entry}
store.update_session("k1", last_prompt_tokens=0)
store.update_session("k1", last_prompt_tokens=0, context_tokens=0)
assert entry.last_prompt_tokens == 0
assert entry.context_tokens == 0
class TestRewriteTranscriptPreservesReasoning:
"""rewrite_transcript must not drop reasoning fields from SQLite."""

View file

@@ -206,6 +206,8 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
"input_tokens": 120,
"output_tokens": 45,
"model": "openai/test-model",
"provider": "openai",
"context_length": 100000,
}
)
@@ -221,6 +223,9 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
runner.session_store.update_session.assert_called_once_with(
session_entry.session_key,
last_prompt_tokens=80,
model="openai/test-model",
model_provider="openai",
context_tokens=100000,
)