fix: parse Codex image generation SSE directly

This commit is contained in:
Will Falcon 2026-05-27 09:35:01 +08:00 committed by Teknium
parent 16e86ce6a7
commit bba50977bc
2 changed files with 167 additions and 166 deletions

View file

@ -10,7 +10,6 @@ from __future__ import annotations
import importlib
from pathlib import Path
from types import SimpleNamespace
import pytest
@ -33,24 +32,6 @@ def _b64_png() -> str:
return base64.b64encode(bytes.fromhex(_PNG_HEX)).decode()
class _FakeStream:
def __init__(self, events, final_response):
self._events = list(events)
self._final = final_response
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def __iter__(self):
return iter(self._events)
def get_final_response(self):
return self._final
@pytest.fixture(autouse=True)
def _tmp_hermes_home(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
@ -127,22 +108,7 @@ class TestGenerate:
def test_generate_uses_codex_stream_path(self, provider, monkeypatch, tmp_path):
monkeypatch.setattr(codex_plugin, "_read_codex_access_token", lambda: "codex-token")
output_item = SimpleNamespace(
type="image_generation_call",
status="generating",
id="ig_test",
result=_b64_png(),
)
done_event = SimpleNamespace(type="response.output_item.done", item=output_item)
final_response = SimpleNamespace(output=[], status="completed", output_text="")
fake_client = SimpleNamespace(
responses=SimpleNamespace(
stream=lambda **kwargs: _FakeStream([done_event], final_response)
)
)
monkeypatch.setattr(codex_plugin, "_build_codex_client", lambda: fake_client)
monkeypatch.setattr(codex_plugin, "_collect_image_b64", lambda *a, **kw: _b64_png())
result = provider.generate("a cat", aspect_ratio="landscape")
@ -163,20 +129,15 @@ class TestGenerate:
captured = {}
def _stream(**kwargs):
captured.update(kwargs)
output_item = SimpleNamespace(
type="image_generation_call",
status="generating",
id="ig_test",
result=_b64_png(),
)
done_event = SimpleNamespace(type="response.output_item.done", item=output_item)
final_response = SimpleNamespace(output=[], status="completed", output_text="")
return _FakeStream([done_event], final_response)
def _collect(token, *, prompt, size, quality):
captured.update(codex_plugin._build_responses_payload(
prompt=prompt,
size=size,
quality=quality,
))
return _b64_png()
fake_client = SimpleNamespace(responses=SimpleNamespace(stream=_stream))
monkeypatch.setattr(codex_plugin, "_build_codex_client", lambda: fake_client)
monkeypatch.setattr(codex_plugin, "_collect_image_b64", _collect)
result = provider.generate("a cat", aspect_ratio="portrait")
assert result["success"] is True
@ -199,83 +160,59 @@ class TestGenerate:
assert tool["background"] == "opaque"
assert tool["partial_images"] == 1
def test_partial_image_event_used_when_done_missing(self, provider, monkeypatch):
"""If the stream never emits output_item.done, fall back to the
partial_image event so users at least get the latest preview frame."""
monkeypatch.setattr(codex_plugin, "_read_codex_access_token", lambda: "codex-token")
def test_partial_image_event_used_when_done_missing(self):
"""If output_item.done is missing, partial_image_b64 is accepted."""
payload = {
"type": "response.image_generation_call.partial_image",
"partial_image_b64": _b64_png(),
}
assert codex_plugin._extract_image_b64(payload) == _b64_png()
partial_event = SimpleNamespace(
type="response.image_generation_call.partial_image",
partial_image_b64=_b64_png(),
)
final_response = SimpleNamespace(output=[], status="completed", output_text="")
def test_sse_parser_handles_event_and_data_lines(self):
class _Response:
def iter_lines(self):
return iter([
"event: response.output_item.done",
'data: {"item": {"type": "image_generation_call", "result": "abc"}}',
"",
])
fake_client = SimpleNamespace(
responses=SimpleNamespace(
stream=lambda **kwargs: _FakeStream([partial_event], final_response)
)
)
monkeypatch.setattr(codex_plugin, "_build_codex_client", lambda: fake_client)
events = list(codex_plugin._iter_sse_json(_Response()))
assert events == [{
"type": "response.output_item.done",
"item": {"type": "image_generation_call", "result": "abc"},
}]
result = provider.generate("a cat")
assert result["success"] is True
assert Path(result["image"]).exists()
def test_final_response_sweep_recovers_image(self, provider, monkeypatch):
"""If no image_generation_call event arrives mid-stream, the
post-stream final-response sweep should still find the image."""
monkeypatch.setattr(codex_plugin, "_read_codex_access_token", lambda: "codex-token")
final_item = SimpleNamespace(
type="image_generation_call",
status="completed",
id="ig_final",
result=_b64_png(),
)
final_response = SimpleNamespace(output=[final_item], status="completed", output_text="")
fake_client = SimpleNamespace(
responses=SimpleNamespace(
stream=lambda **kwargs: _FakeStream([], final_response)
)
)
monkeypatch.setattr(codex_plugin, "_build_codex_client", lambda: fake_client)
result = provider.generate("a cat")
assert result["success"] is True
assert Path(result["image"]).exists()
def test_final_response_sweep_recovers_image(self):
"""Completed response output is found by recursive payload scanning."""
payload = {
"type": "response.completed",
"response": {
"output": [{
"type": "image_generation_call",
"status": "completed",
"id": "ig_final",
"result": _b64_png(),
}],
},
}
assert codex_plugin._extract_image_b64(payload) == _b64_png()
def test_empty_response_returns_error(self, provider, monkeypatch):
monkeypatch.setattr(codex_plugin, "_read_codex_access_token", lambda: "codex-token")
final_response = SimpleNamespace(output=[], status="completed", output_text="")
fake_client = SimpleNamespace(
responses=SimpleNamespace(
stream=lambda **kwargs: _FakeStream([], final_response)
)
)
monkeypatch.setattr(codex_plugin, "_build_codex_client", lambda: fake_client)
monkeypatch.setattr(codex_plugin, "_collect_image_b64", lambda *a, **kw: None)
result = provider.generate("a cat")
assert result["success"] is False
assert result["error_type"] == "empty_response"
def test_client_init_failure_returns_auth_error(self, provider, monkeypatch):
monkeypatch.setattr(codex_plugin, "_read_codex_access_token", lambda: "codex-token")
monkeypatch.setattr(codex_plugin, "_build_codex_client", lambda: None)
result = provider.generate("a cat")
assert result["success"] is False
assert result["error_type"] == "auth_required"
def test_stream_exception_returns_api_error(self, provider, monkeypatch):
monkeypatch.setattr(codex_plugin, "_read_codex_access_token", lambda: "codex-token")
def _boom(**kwargs):
def _boom(*args, **kwargs):
raise RuntimeError("cloudflare 403")
fake_client = SimpleNamespace(responses=SimpleNamespace(stream=_boom))
monkeypatch.setattr(codex_plugin, "_build_codex_client", lambda: fake_client)
monkeypatch.setattr(codex_plugin, "_collect_image_b64", _boom)
result = provider.generate("a cat")
assert result["success"] is False