Support media in session chat API

This commit is contained in:
Jonathan 2026-05-22 16:00:59 -07:00 committed by Teknium
parent f7527b0fdb
commit 464b51d455
2 changed files with 80 additions and 6 deletions

View file

@ -319,6 +319,20 @@ def _multimodal_validation_error(exc: ValueError, *, param: str) -> "web.Respons
)
def _session_chat_user_message(body: Dict[str, Any], *, param: str = "message") -> tuple[Any, Optional["web.Response"]]:
"""Parse and normalize session chat ``message`` / ``input`` like chat completions."""
user_message = body.get("message") or body.get("input")
if not _content_has_visible_payload(user_message):
return None, web.json_response(
_openai_error("Missing 'message' field", code="missing_message"),
status=400,
)
try:
return _normalize_multimodal_content(user_message), None
except ValueError as exc:
return None, _multimodal_validation_error(exc, param=param)
def check_api_server_requirements() -> bool:
"""Check if API server dependencies are available."""
return AIOHTTP_AVAILABLE
@ -1483,9 +1497,9 @@ class APIServerAdapter(BasePlatformAdapter):
body, err = await self._read_json_body(request)
if err:
return err
user_message = body.get("message") or body.get("input")
if not _content_has_visible_payload(user_message):
return web.json_response(_openai_error("Missing 'message' field", code="missing_message"), status=400)
user_message, err = _session_chat_user_message(body)
if err is not None:
return err
system_prompt = body.get("system_message") or body.get("instructions")
if system_prompt is not None and not isinstance(system_prompt, str):
return web.json_response(_openai_error("system_message must be a string", code="invalid_system_message"), status=400)
@ -1527,9 +1541,9 @@ class APIServerAdapter(BasePlatformAdapter):
body, err = await self._read_json_body(request)
if err:
return err
user_message = body.get("message") or body.get("input")
if not _content_has_visible_payload(user_message):
return web.json_response(_openai_error("Missing 'message' field", code="missing_message"), status=400)
user_message, err = _session_chat_user_message(body)
if err is not None:
return err
system_prompt = body.get("system_message") or body.get("instructions")
if system_prompt is not None and not isinstance(system_prompt, str):
return web.json_response(_openai_error("system_message must be a string", code="invalid_system_message"), status=400)

View file

@ -180,6 +180,66 @@ async def test_session_chat_loads_history_and_preserves_session_headers(auth_ada
]
@pytest.mark.asyncio
async def test_session_chat_accepts_multimodal_message(auth_adapter, session_db):
session_id = session_db.create_session("image-session", "api_server")
image_payload = [
{"type": "input_text", "text": "What's in this image?"},
{"type": "input_image", "image_url": "data:image/png;base64,AAAA"},
]
expected_user_message = [
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}},
]
mock_run = AsyncMock(return_value=({"final_response": "A cat.", "session_id": session_id}, {"total_tokens": 4}))
app = _create_session_app(auth_adapter)
with patch.object(auth_adapter, "_run_agent", mock_run):
async with TestClient(TestServer(app)) as cli:
resp = await cli.post(
f"/api/sessions/{session_id}/chat",
json={"message": image_payload},
headers={"Authorization": "Bearer sk-test"},
)
assert resp.status == 200, await resp.text()
_, kwargs = mock_run.call_args
assert kwargs["user_message"] == expected_user_message
@pytest.mark.asyncio
async def test_session_chat_stream_accepts_multimodal_message(adapter, session_db):
session_id = session_db.create_session("image-stream-session", "api_server")
image_payload = [
{"type": "input_text", "text": "What's in this image?"},
{"type": "input_image", "image_url": "data:image/png;base64,AAAA"},
]
expected_user_message = [
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}},
]
captured_kwargs = {}
async def fake_run(**kwargs):
captured_kwargs.update(kwargs)
kwargs["stream_delta_callback"]("A cat.")
return {"final_response": "A cat.", "session_id": session_id}, {"total_tokens": 4}
app = _create_session_app(adapter)
with patch.object(adapter, "_run_agent", side_effect=fake_run):
async with TestClient(TestServer(app)) as cli:
resp = await cli.post(
f"/api/sessions/{session_id}/chat/stream",
json={"message": image_payload},
)
assert resp.status == 200, await resp.text()
assert resp.headers["Content-Type"].startswith("text/event-stream")
body = await resp.text()
assert "event: assistant.completed" in body
assert captured_kwargs["user_message"] == expected_user_message
@pytest.mark.asyncio
async def test_session_chat_stream_emits_lifecycle_events_and_keepalive_safe_shape(adapter, session_db):
session_id = session_db.create_session("stream-session", "api_server")