import re

# One pattern per identity label.  Each matches a whole ``Label: ...`` line up
# to, but NOT including, its line terminator.  Excluding ``\r`` as well as
# ``\n`` keeps the ``\r`` of a CRLF-terminated line in place, so rewriting the
# line and later rewriting it back reproduces the original bytes.
_IDENTITY_LINE_PATTERNS = {
    label: re.compile(rf"(?m)^{label}: [^\r\n]*")
    for label in ("Model", "Provider")
}


def rewrite_prompt_model_identity(agent, model: str, provider: str) -> None:
    """Point the cached system prompt's ``Model:``/``Provider:`` lines at
    the active runtime after a provider switch.

    The system prompt is session-stable and replayed verbatim for prefix-cache
    warmth, but after a failover the new backend's cache is cold anyway —
    while a stale identity line makes the agent misreport which model it is
    when asked. The lines are rewritten in place on
    ``agent._cached_system_prompt`` only; this function does not persist
    anything, so a stored copy of the prompt keeps the primary's labels and
    rewriting back to the primary's values yields the same bytes again.

    Only the LAST occurrence of each line is touched — the identity lines
    live in the volatile tail of the prompt, and earlier matches could be
    user content (memory snapshots, context files).

    Call sites in this change: ``try_activate_fallback`` passes the fallback's
    ``fb_model``/``fb_provider``; ``restore_primary_runtime`` passes
    ``rt["model"]``/``rt["provider"]`` to undo the rewrite.

    Args:
        agent: Object carrying ``_cached_system_prompt``. Nothing happens when
            the attribute is missing, not a ``str``, or empty.
        model: Value for the last ``Model:`` line; falsy leaves it untouched.
        provider: Value for the last ``Provider:`` line; falsy leaves it
            untouched.
    """
    prompt = getattr(agent, "_cached_system_prompt", None)
    if not isinstance(prompt, str) or not prompt:
        return

    for label, value in (("Model", model), ("Provider", provider)):
        if not value:
            continue
        matches = list(_IDENTITY_LINE_PATTERNS[label].finditer(prompt))
        if not matches:
            continue
        last = matches[-1]
        # Slice-and-join instead of re.sub so backslashes or group references
        # in ``value`` are inserted literally.
        prompt = f"{prompt[:last.start()]}{label}: {value}{prompt[last.end():]}"

    agent._cached_system_prompt = prompt


def _sync_failover_system_message(agent, api_messages, active_system_prompt):
    """Refresh the in-flight system message after a provider failover.

    ``try_activate_fallback`` rewrites the ``Model:``/``Provider:`` identity
    lines on ``agent._cached_system_prompt`` (via
    ``rewrite_prompt_model_identity``) so the agent reports the model that is
    actually answering. The current call block's ``api_messages`` were built
    from the pre-failover prompt, so without this sync the rest of the turn
    would still ship the stale identity.

    ``run_conversation`` calls this right after every successful
    ``agent._try_activate_fallback()`` and assigns the result to its
    ``active_system_prompt``.

    Args:
        agent: Object carrying ``_cached_system_prompt`` and, optionally,
            ``ephemeral_system_prompt``.
        api_messages: Message list for the in-flight request. When its first
            entry has role ``"system"``, that entry's ``content`` is
            overwritten in place.
        active_system_prompt: Prompt currently used for call-block rebuilds;
            returned unchanged when there is no usable cached prompt.

    Returns:
        The cached system prompt (without the ephemeral suffix), or
        ``active_system_prompt`` when the cached prompt is missing, not a
        ``str``, or empty.
    """
    cached_prompt = getattr(agent, "_cached_system_prompt", None)
    if not isinstance(cached_prompt, str) or not cached_prompt:
        return active_system_prompt

    if api_messages and api_messages[0].get("role") == "system":
        # Read defensively, like the cached prompt above: an agent without an
        # ephemeral prompt attribute simply has nothing to append.
        ephemeral = getattr(agent, "ephemeral_system_prompt", None)
        content = cached_prompt
        if ephemeral:
            content = (cached_prompt + "\n\n" + ephemeral).strip()
        api_messages[0]["content"] = content

    return cached_prompt
"""Tests for system-prompt model-identity sync across provider failover.

The system prompt is session-stable and embeds ``Model:``/``Provider:``
identity lines. When ``try_activate_fallback`` swaps the runtime, the
prompt must be rewritten in place (and synced into the in-flight
``api_messages``) or the agent reports the primary model's name while a
fallback model is answering — e.g. a local gemma fallback claiming to be
gpt-5.4-mini after a Codex usage-limit 429.
"""

from types import SimpleNamespace

from agent.chat_completion_helpers import rewrite_prompt_model_identity
from agent.conversation_loop import _sync_failover_system_message


# Fixture prompt: one decoy ``Model:`` line early on (standing in for user
# content) followed by the real identity lines at the tail.
_PROMPT = (
    "You are a helpful assistant.\n"
    "\n"
    "Memory note at line start:\n"
    "Model: decoy-from-memory\n"
    "\n"
    "Conversation started: Wednesday, June 10, 2026\n"
    "Model: gpt-5.4-mini\n"
    "Provider: openai-codex"
)


def _agent(prompt=_PROMPT, ephemeral=None):
    """Build a minimal stand-in exposing the two attributes the helpers read."""
    return SimpleNamespace(
        _cached_system_prompt=prompt,
        ephemeral_system_prompt=ephemeral,
    )


class TestRewritePromptModelIdentity:
    """Behaviour of ``rewrite_prompt_model_identity`` on the cached prompt."""

    def test_swaps_identity_lines_to_fallback_runtime(self):
        agent = _agent()
        rewrite_prompt_model_identity(agent, "gemma4:e2b-mlx", "custom")
        assert "Model: gemma4:e2b-mlx" in agent._cached_system_prompt
        assert "Provider: custom" in agent._cached_system_prompt
        assert "Model: gpt-5.4-mini" not in agent._cached_system_prompt
        assert "Provider: openai-codex" not in agent._cached_system_prompt

    def test_only_last_occurrence_is_rewritten(self):
        agent = _agent()
        rewrite_prompt_model_identity(agent, "gemma4:e2b-mlx", "custom")
        # Earlier matching lines may be user content (memory snapshots,
        # context files) and must survive untouched.
        assert "Model: decoy-from-memory" in agent._cached_system_prompt

    def test_round_trip_restores_byte_identical_prompt(self):
        # restore_primary_runtime rewrites the lines back; the result must
        # match the stored prompt byte-for-byte so the primary's prefix
        # cache still hits after restoration.
        agent = _agent()
        rewrite_prompt_model_identity(agent, "gemma4:e2b-mlx", "custom")
        rewrite_prompt_model_identity(agent, "gpt-5.4-mini", "openai-codex")
        assert agent._cached_system_prompt == _PROMPT

    def test_noop_when_prompt_missing_or_empty(self):
        for prompt in (None, ""):
            agent = _agent(prompt=prompt)
            rewrite_prompt_model_identity(agent, "m", "p")
            assert agent._cached_system_prompt == prompt

    def test_empty_values_leave_lines_unchanged(self):
        agent = _agent()
        rewrite_prompt_model_identity(agent, "", "")
        assert agent._cached_system_prompt == _PROMPT


class TestSyncFailoverSystemMessage:
    """Behaviour of ``_sync_failover_system_message`` on in-flight messages."""

    def test_patches_in_flight_system_message(self):
        agent = _agent()
        rewrite_prompt_model_identity(agent, "gemma4:e2b-mlx", "custom")
        user_message = {"role": "user", "content": "what model are you?"}
        api_messages = [
            {"role": "system", "content": _PROMPT},
            dict(user_message),
        ]
        result = _sync_failover_system_message(agent, api_messages, _PROMPT)
        assert "Model: gemma4:e2b-mlx" in api_messages[0]["content"]
        # Exact match: the whole rewritten prompt is shipped, not a fragment.
        assert api_messages[0]["content"] == agent._cached_system_prompt
        # Only the system entry is touched.
        assert api_messages[1] == user_message
        assert result == agent._cached_system_prompt

    def test_appends_ephemeral_system_prompt(self):
        agent = _agent(ephemeral="Stay terse.")
        api_messages = [{"role": "system", "content": _PROMPT}]
        result = _sync_failover_system_message(agent, api_messages, _PROMPT)
        # Exact equality so a regression that drops the cached prompt or the
        # blank-line separator cannot hide behind an ``endswith`` check.
        assert api_messages[0]["content"] == _PROMPT + "\n\nStay terse."
        # The returned prompt excludes the ephemeral suffix.
        assert result == _PROMPT

    def test_noop_without_cached_prompt(self):
        agent = _agent(prompt=None)
        api_messages = [{"role": "system", "content": "original"}]
        result = _sync_failover_system_message(agent, api_messages, "active")
        assert api_messages[0]["content"] == "original"
        assert result == "active"

    def test_noop_when_first_message_is_not_system(self):
        agent = _agent()
        api_messages = [{"role": "user", "content": "hi"}]
        result = _sync_failover_system_message(agent, api_messages, "active")
        assert api_messages == [{"role": "user", "content": "hi"}]
        # Still returns the cached prompt for subsequent call-block rebuilds.
        assert result == agent._cached_system_prompt