Merge pull request #44738 from kshitijk4poor/salvage/memory-sync-multimodal-content

fix(memory): flatten multimodal content before provider sync
This commit is contained in:
kshitij 2026-06-12 00:40:31 -07:00 committed by GitHub
commit 046f444ddc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 151 additions and 11 deletions

View file

@ -127,14 +127,21 @@ def _chat_content_to_responses_parts(content: Any, *, role: str = "user") -> Lis
return converted
def _summarize_user_message_for_log(content: Any) -> str:
"""Return a short text summary of a user message for logging/trajectory.
def _summarize_user_message_for_log(content: Any, *, sep: str = " ") -> str:
"""Flatten message content to a plain-text summary.
Multimodal messages arrive as a list of ``{type:"text"|"image_url", ...}``
parts from the API server. Logging, spinner previews, and trajectory
files all want a plain string — this helper extracts the first chunk of
text and notes any attached images. Returns an empty string for empty
lists and ``str(content)`` for unexpected scalar types.
parts from the API server. Several consumers want a plain string:
- Logging, spinner previews, and trajectory files (the default ``sep=" "``).
- External memory providers, which feed the text to regexes
(``sanitize_context``) and text APIs — a raw list crashes the sync with
``expected string or bytes-like object, got 'list'`` (use ``sep="\\n"``).
Text parts are joined with ``sep``; images become a ``[N image(s)]`` marker
so the turn isn't recorded as if the attachment never existed. Returns an
empty string for empty lists and ``str(content)`` for unexpected scalar
types.
"""
if content is None:
return ""
@ -157,7 +164,7 @@ def _summarize_user_message_for_log(content: Any) -> str:
text_bits.append(text)
elif ptype in {"image_url", "input_image"}:
image_count += 1
summary = " ".join(text_bits).strip()
summary = sep.join(text_bits).strip()
if image_count:
note = f"[{image_count} image{'s' if image_count != 1 else ''}]"
summary = f"{note} {summary}" if summary else note

View file

@ -169,7 +169,7 @@ from agent.codex_responses_adapter import (
_derive_responses_function_call_id as _codex_derive_responses_function_call_id,
_deterministic_call_id as _codex_deterministic_call_id,
_split_responses_tool_id as _codex_split_responses_tool_id,
_summarize_user_message_for_log, # noqa: F401 # re-exported for tests
_summarize_user_message_for_log, # also used by _sync_external_memory_for_turn (memory boundary)
)
from agent.tool_guardrails import (
ToolGuardrailDecision,
@ -2990,17 +2990,24 @@ class AIAgent:
return
if not (self._memory_manager and final_response and original_user_message):
return
# Multimodal turns carry content as a list of typed parts; providers
# expect plain strings, so flatten to text first (newline-joined for
# memory, vs the default space-join used for log/trajectory previews).
user_text = _summarize_user_message_for_log(original_user_message, sep="\n")
response_text = _summarize_user_message_for_log(final_response, sep="\n")
if not (user_text and response_text):
return
try:
sync_kwargs = {"session_id": self.session_id or ""}
if messages is not None:
sync_kwargs["messages"] = messages
self._memory_manager.sync_all(
original_user_message,
final_response,
user_text,
response_text,
**sync_kwargs,
)
self._memory_manager.queue_prefetch_all(
original_user_message,
user_text,
session_id=self.session_id or "",
)
except Exception:

View file

@ -979,6 +979,81 @@ class TestMemoryContextFencing:
assert combined.index("weather") < fence_start
class TestFlattenMessageContent:
    """Flattening of multimodal message content at the memory boundary.

    Content that arrives as a list of typed parts has to be reduced to a
    plain string before it reaches providers; a raw list crashes their regex
    sanitization with ``expected string or bytes-like object, got 'list'``.

    The memory boundary reuses ``_summarize_user_message_for_log`` (the same
    helper logging/trajectory use) with ``sep="\\n"`` rather than a forked
    copy, so these tests exercise that helper directly.
    """

    @staticmethod
    def _flatten(content, **kwargs):
        """Import the helper at call time and apply it to *content*."""
        from agent.codex_responses_adapter import _summarize_user_message_for_log

        return _summarize_user_message_for_log(content, **kwargs)

    def test_string_passthrough(self):
        flattened = self._flatten("hello", sep="\n")
        assert flattened == "hello"

    def test_none_is_empty(self):
        flattened = self._flatten(None, sep="\n")
        assert flattened == ""

    def test_text_parts_joined_with_sep(self):
        parts = [
            {"type": "text", "text": "first"},
            {"type": "text", "text": "second"},
        ]
        assert self._flatten(parts, sep="\n") == "first\nsecond"

    def test_default_sep_is_space(self):
        """Logging/trajectory callers (the default) keep the space-join."""
        parts = [
            {"type": "text", "text": "first"},
            {"type": "text", "text": "second"},
        ]
        # No ``sep`` is passed on purpose: this pins the default.
        assert self._flatten(parts) == "first second"

    def test_image_part_becomes_marker(self):
        parts = [
            {"type": "text", "text": "look at this"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,xyz"}},
        ]
        assert self._flatten(parts, sep="\n") == "[1 image] look at this"

    def test_image_only_message(self):
        parts = [
            {"type": "image_url", "image_url": {"url": "data:..."}},
            {"type": "image_url", "image_url": {"url": "data:..."}},
        ]
        assert self._flatten(parts, sep="\n") == "[2 images]"

    def test_unknown_parts_skipped(self):
        parts = [{"type": "audio", "data": "..."}, {"type": "text", "text": "ok"}, 42]
        assert self._flatten(parts, sep="\n") == "ok"

    def test_bare_strings_in_list(self):
        flattened = self._flatten(["plain", "strings"], sep="\n")
        assert flattened == "plain\nstrings"

    def test_scalar_fallback(self):
        flattened = self._flatten(42, sep="\n")
        assert flattened == "42"

    def test_flattened_output_is_regex_safe(self):
        """The original failure: sanitize_context(list) raised TypeError."""
        parts = [
            {"type": "text", "text": "fix this bug"},
            {"type": "image_url", "image_url": {"url": "data:..."}},
        ]
        flattened = self._flatten(parts, sep="\n")

        from agent.memory_manager import sanitize_context

        # Must not raise.
        assert sanitize_context(flattened)
# ---------------------------------------------------------------------------
# AIAgent.commit_memory_session — routes to MemoryManager.on_session_end
# ---------------------------------------------------------------------------

View file

@ -207,6 +207,57 @@ class TestSyncExternalMemoryForTurn:
# sync_all still happened before the prefetch blew up.
agent._memory_manager.sync_all.assert_called_once()
# --- Multimodal content flattening ----------------------------------
def test_multimodal_user_message_is_flattened(self):
    """The user message of an image turn is flattened before syncing.

    A turn with an attached image carries the user message as a list of
    typed parts. Providers feed the content to regexes (sanitize_context),
    so a raw list raised ``expected string or bytes-like object, got
    'list'`` and the turn silently never synced. The boundary must
    flatten to text first.
    """
    screenshot_turn = [
        {"type": "text", "text": "what is in this screenshot?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
    ]
    # Both the sync and the prefetch must receive the same flattened text.
    expected_user_text = "[1 image] what is in this screenshot?"

    agent = _bare_agent()
    agent._sync_external_memory_for_turn(
        original_user_message=screenshot_turn,
        final_response="A terminal window showing a stack trace.",
        interrupted=False,
    )

    memory = agent._memory_manager
    memory.sync_all.assert_called_once_with(
        expected_user_text,
        "A terminal window showing a stack trace.",
        session_id="test_session_001",
    )
    memory.queue_prefetch_all.assert_called_once_with(
        expected_user_text,
        session_id="test_session_001",
    )
def test_multimodal_response_is_flattened(self):
    """A list-of-parts final response reaches ``sync_all`` as plain text."""
    response_parts = [{"type": "text", "text": "a cat"}]

    agent = _bare_agent()
    agent._sync_external_memory_for_turn(
        original_user_message="describe it",
        final_response=response_parts,
        interrupted=False,
    )

    agent._memory_manager.sync_all.assert_called_once_with(
        "describe it",
        "a cat",
        session_id="test_session_001",
    )
def test_multimodal_with_no_text_at_all_skips(self):
    """Unknown-typed parts flatten to an empty string — don't sync a
    turn with no recoverable text."""
    audio_only_turn = [{"type": "audio", "data": "..."}]

    agent = _bare_agent()
    agent._sync_external_memory_for_turn(
        original_user_message=audio_only_turn,
        final_response="noted",
        interrupted=False,
    )

    # Neither the sync nor the prefetch may run for this turn.
    memory = agent._memory_manager
    memory.sync_all.assert_not_called()
    memory.queue_prefetch_all.assert_not_called()
# --- The specific matrix the reporter asked about ------------------
@pytest.mark.parametrize("interrupted,final,user,expect_sync", [