From 4776a34711a5bb03db85ff4e2b10e03a49ad6520 Mon Sep 17 00:00:00 2001 From: sgaofen <135070653+sgaofen@users.noreply.github.com> Date: Thu, 23 Apr 2026 00:59:28 -0700 Subject: [PATCH] fix(gateway): persist runtime session metadata --- gateway/run.py | 17 +++++++++-- gateway/session.py | 18 ++++++++++++ run_agent.py | 1 + tests/gateway/test_session.py | 43 +++++++++++++++++++++++----- tests/gateway/test_status_command.py | 5 ++++ 5 files changed, 75 insertions(+), 9 deletions(-) diff --git a/gateway/run.py b/gateway/run.py index a024649cbd..e712cdc981 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -4751,6 +4751,9 @@ class GatewayRunner: self.session_store.update_session( session_entry.session_key, last_prompt_tokens=agent_result.get("last_prompt_tokens", 0), + model=agent_result.get("model"), + model_provider=agent_result.get("provider"), + context_tokens=agent_result.get("context_length"), ) # Auto voice reply: send TTS audio before the text response @@ -10114,6 +10117,12 @@ class GatewayRunner: _input_toks = getattr(_agent, "session_prompt_tokens", 0) _output_toks = getattr(_agent, "session_completion_tokens", 0) _resolved_model = getattr(_agent, "model", None) if _agent else None + _resolved_provider = getattr(_agent, "provider", None) if _agent else None + _context_length = ( + getattr(getattr(_agent, "context_compressor", None), "context_length", 0) + if _agent + else 0 + ) if not final_response: error_msg = f"⚠️ {result['error']}" if result.get("error") else "" @@ -10128,7 +10137,9 @@ class GatewayRunner: "last_prompt_tokens": _last_prompt_toks, "input_tokens": _input_toks, "output_tokens": _output_toks, - "model": _resolved_model, + "model": _resolved_model or result.get("model"), + "provider": _resolved_provider or result.get("provider"), + "context_length": _context_length or result.get("context_length", 0), } # Scan tool results for MEDIA: tags that need to be delivered @@ -10217,7 +10228,9 @@ class GatewayRunner: "last_prompt_tokens": _last_prompt_toks, "input_tokens": _input_toks, "output_tokens": _output_toks, - "model": _resolved_model, + "model": _resolved_model or result.get("model"), + "provider": _resolved_provider or result.get("provider"), + "context_length": _context_length or result.get("context_length", 0), "session_id": effective_session_id, "response_previewed": result.get("response_previewed", False), } diff --git a/gateway/session.py b/gateway/session.py index db90d31217..aaec9daf67 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -360,6 +360,9 @@ class SessionEntry: # Last API-reported prompt tokens (for accurate compression pre-check) last_prompt_tokens: int = 0 + model: Optional[str] = None + model_provider: Optional[str] = None + context_tokens: int = 0 # Set when a session was created because the previous one expired; # consumed once by the message handler to inject a notice into context @@ -405,6 +408,9 @@ class SessionEntry: "cache_write_tokens": self.cache_write_tokens, "total_tokens": self.total_tokens, "last_prompt_tokens": self.last_prompt_tokens, + "model": self.model, + "model_provider": self.model_provider, + "context_tokens": self.context_tokens, "estimated_cost_usd": self.estimated_cost_usd, "cost_status": self.cost_status, "memory_flushed": self.memory_flushed, @@ -457,6 +463,9 @@ class SessionEntry: cache_write_tokens=data.get("cache_write_tokens", 0), total_tokens=data.get("total_tokens", 0), last_prompt_tokens=data.get("last_prompt_tokens", 0), + model=data.get("model"), + model_provider=data.get("model_provider"), + context_tokens=data.get("context_tokens", 0), estimated_cost_usd=data.get("estimated_cost_usd", 0.0), cost_status=data.get("cost_status", "unknown"), memory_flushed=data.get("memory_flushed", False), @@ -840,6 +849,9 @@ class SessionStore: self, session_key: str, last_prompt_tokens: int = None, + model: Optional[str] = None, + model_provider: Optional[str] = None, + context_tokens: int = None, ) -> None: """Update lightweight session metadata after an interaction.""" with self._lock: @@ -850,6 +862,12 @@ class SessionStore: entry.updated_at = _now() if last_prompt_tokens is not None: entry.last_prompt_tokens = last_prompt_tokens + if model is not None: + entry.model = model + if model_provider is not None: + entry.model_provider = model_provider + if context_tokens is not None: + entry.context_tokens = context_tokens self._save() def suspend_session(self, session_key: str) -> bool: diff --git a/run_agent.py b/run_agent.py index eaafac5b43..e44ad7d928 100644 --- a/run_agent.py +++ b/run_agent.py @@ -11865,6 +11865,7 @@ class AIAgent: "estimated_cost_usd": self.session_estimated_cost_usd, "cost_status": self.session_cost_status, "cost_source": self.session_cost_source, + "context_length": getattr(self.context_compressor, "context_length", 0) or 0, } # If a /steer landed after the final assistant turn (no more tool # batches to drain into), hand it back to the caller so it can be diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 539b12a5e1..7deda3e637 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -944,7 +944,7 @@ class TestLastPromptTokens: assert entry.last_prompt_tokens == 0 def test_session_entry_roundtrip(self): - """last_prompt_tokens should survive serialization/deserialization.""" + """Runtime session metadata should survive serialization/deserialization.""" from gateway.session import SessionEntry from datetime import datetime entry = SessionEntry( @@ -953,14 +953,23 @@ class TestLastPromptTokens: created_at=datetime.now(), updated_at=datetime.now(), last_prompt_tokens=42000, + model="gpt-5.4", + model_provider="openai", + context_tokens=200000, ) d = entry.to_dict() assert d["last_prompt_tokens"] == 42000 + assert d["model"] == "gpt-5.4" + assert d["model_provider"] == "openai" + assert d["context_tokens"] == 200000 restored = SessionEntry.from_dict(d) assert restored.last_prompt_tokens == 42000 + assert restored.model == "gpt-5.4" + assert restored.model_provider == "openai" + assert restored.context_tokens == 200000 def test_session_entry_from_old_data(self): - """Old session data without last_prompt_tokens should default to 0.""" + """Old session data without runtime metadata should keep safe defaults.""" from gateway.session import SessionEntry data = { "session_key": "test", @@ -974,9 +983,12 @@ class TestLastPromptTokens: } entry = SessionEntry.from_dict(data) assert entry.last_prompt_tokens == 0 + assert entry.model is None + assert entry.model_provider is None + assert entry.context_tokens == 0 def test_update_session_sets_last_prompt_tokens(self, tmp_path): - """update_session should store the actual prompt token count.""" + """update_session should store the actual prompt token count and runtime metadata.""" config = GatewayConfig() with patch("gateway.session.SessionStore._ensure_loaded"): store = SessionStore(sessions_dir=tmp_path, config=config) @@ -994,11 +1006,20 @@ class TestLastPromptTokens: ) store._entries = {"k1": entry} - store.update_session("k1", last_prompt_tokens=85000) + store.update_session( + "k1", + last_prompt_tokens=85000, + model="claude-sonnet-4-6", + model_provider="anthropic", + context_tokens=400000, + ) assert entry.last_prompt_tokens == 85000 + assert entry.model == "claude-sonnet-4-6" + assert entry.model_provider == "anthropic" + assert entry.context_tokens == 400000 def test_update_session_none_does_not_change(self, tmp_path): - """update_session with default (None) should not change last_prompt_tokens.""" + """update_session with default (None) should not change stored metadata.""" config = GatewayConfig() with patch("gateway.session.SessionStore._ensure_loaded"): store = SessionStore(sessions_dir=tmp_path, config=config) @@ -1014,14 +1035,20 @@ class TestLastPromptTokens: created_at=datetime.now(), updated_at=datetime.now(), last_prompt_tokens=50000, + model="gpt-4o", + model_provider="openrouter", + context_tokens=128000, ) store._entries = {"k1": entry} store.update_session("k1") # No last_prompt_tokens arg assert entry.last_prompt_tokens == 50000 # unchanged + assert entry.model == "gpt-4o" + assert entry.model_provider == "openrouter" + assert entry.context_tokens == 128000 def test_update_session_zero_resets(self, tmp_path): - """update_session with last_prompt_tokens=0 should reset the field.""" + """update_session accepts zero values for numeric runtime fields.""" config = GatewayConfig() with patch("gateway.session.SessionStore._ensure_loaded"): store = SessionStore(sessions_dir=tmp_path, config=config) @@ -1037,11 +1064,13 @@ class TestLastPromptTokens: created_at=datetime.now(), updated_at=datetime.now(), last_prompt_tokens=85000, + context_tokens=200000, ) store._entries = {"k1": entry} - store.update_session("k1", last_prompt_tokens=0) + store.update_session("k1", last_prompt_tokens=0, context_tokens=0) assert entry.last_prompt_tokens == 0 + assert entry.context_tokens == 0 class TestRewriteTranscriptPreservesReasoning: """rewrite_transcript must not drop reasoning fields from SQLite.""" diff --git a/tests/gateway/test_status_command.py b/tests/gateway/test_status_command.py index 50e1c52cc2..2ff5585675 100644 --- a/tests/gateway/test_status_command.py +++ b/tests/gateway/test_status_command.py @@ -206,6 +206,8 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch): "input_tokens": 120, "output_tokens": 45, "model": "openai/test-model", + "provider": "openai", + "context_length": 100000, } ) @@ -221,6 +223,9 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch): runner.session_store.update_session.assert_called_once_with( session_entry.session_key, last_prompt_tokens=80, + model="openai/test-model", + model_provider="openai", + context_tokens=100000, )