From 464b51d455fe2caab2691b3331e31d1adf94d733 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Fri, 22 May 2026 16:00:59 -0700 Subject: [PATCH] Support media in session chat API --- gateway/platforms/api_server.py | 26 ++++++++++---- tests/gateway/test_session_api.py | 60 +++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 1a132fdec16..b39e83f43a4 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -319,6 +319,20 @@ def _multimodal_validation_error(exc: ValueError, *, param: str) -> "web.Respons ) +def _session_chat_user_message(body: Dict[str, Any], *, param: str = "message") -> tuple[Any, Optional["web.Response"]]: + """Parse and normalize session chat ``message`` / ``input`` like chat completions.""" + user_message = body.get("message") or body.get("input") + if not _content_has_visible_payload(user_message): + return None, web.json_response( + _openai_error("Missing 'message' field", code="missing_message"), + status=400, + ) + try: + return _normalize_multimodal_content(user_message), None + except ValueError as exc: + return None, _multimodal_validation_error(exc, param=param) + + def check_api_server_requirements() -> bool: """Check if API server dependencies are available.""" return AIOHTTP_AVAILABLE @@ -1483,9 +1497,9 @@ class APIServerAdapter(BasePlatformAdapter): body, err = await self._read_json_body(request) if err: return err - user_message = body.get("message") or body.get("input") - if not _content_has_visible_payload(user_message): - return web.json_response(_openai_error("Missing 'message' field", code="missing_message"), status=400) + user_message, err = _session_chat_user_message(body) + if err is not None: + return err system_prompt = body.get("system_message") or body.get("instructions") if system_prompt is not None and not isinstance(system_prompt, str): return web.json_response(_openai_error("system_message must be a string", code="invalid_system_message"), status=400) @@ -1527,9 +1541,9 @@ class APIServerAdapter(BasePlatformAdapter): body, err = await self._read_json_body(request) if err: return err - user_message = body.get("message") or body.get("input") - if not _content_has_visible_payload(user_message): - return web.json_response(_openai_error("Missing 'message' field", code="missing_message"), status=400) + user_message, err = _session_chat_user_message(body) + if err is not None: + return err system_prompt = body.get("system_message") or body.get("instructions") if system_prompt is not None and not isinstance(system_prompt, str): return web.json_response(_openai_error("system_message must be a string", code="invalid_system_message"), status=400) diff --git a/tests/gateway/test_session_api.py b/tests/gateway/test_session_api.py index 7c99bc8c24a..afc80108317 100644 --- a/tests/gateway/test_session_api.py +++ b/tests/gateway/test_session_api.py @@ -180,6 +180,66 @@ async def test_session_chat_loads_history_and_preserves_session_headers(auth_ada ] +@pytest.mark.asyncio +async def test_session_chat_accepts_multimodal_message(auth_adapter, session_db): + session_id = session_db.create_session("image-session", "api_server") + image_payload = [ + {"type": "input_text", "text": "What's in this image?"}, + {"type": "input_image", "image_url": "data:image/png;base64,AAAA"}, + ] + expected_user_message = [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}, + ] + + mock_run = AsyncMock(return_value=({"final_response": "A cat.", "session_id": session_id}, {"total_tokens": 4})) + app = _create_session_app(auth_adapter) + with patch.object(auth_adapter, "_run_agent", mock_run): + async with TestClient(TestServer(app)) as cli: + resp = await cli.post( + f"/api/sessions/{session_id}/chat", + json={"message": image_payload}, + headers={"Authorization": "Bearer sk-test"}, + ) + assert resp.status == 200, await resp.text() + + _, kwargs = mock_run.call_args + assert kwargs["user_message"] == expected_user_message + + +@pytest.mark.asyncio +async def test_session_chat_stream_accepts_multimodal_message(adapter, session_db): + session_id = session_db.create_session("image-stream-session", "api_server") + image_payload = [ + {"type": "input_text", "text": "What's in this image?"}, + {"type": "input_image", "image_url": "data:image/png;base64,AAAA"}, + ] + expected_user_message = [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}, + ] + captured_kwargs = {} + + async def fake_run(**kwargs): + captured_kwargs.update(kwargs) + kwargs["stream_delta_callback"]("A cat.") + return {"final_response": "A cat.", "session_id": session_id}, {"total_tokens": 4} + + app = _create_session_app(adapter) + with patch.object(adapter, "_run_agent", side_effect=fake_run): + async with TestClient(TestServer(app)) as cli: + resp = await cli.post( + f"/api/sessions/{session_id}/chat/stream", + json={"message": image_payload}, + ) + assert resp.status == 200, await resp.text() + assert resp.headers["Content-Type"].startswith("text/event-stream") + body = await resp.text() + + assert "event: assistant.completed" in body + assert captured_kwargs["user_message"] == expected_user_message + + @pytest.mark.asyncio async def test_session_chat_stream_emits_lifecycle_events_and_keepalive_safe_shape(adapter, session_db): session_id = session_db.create_session("stream-session", "api_server")