diff --git a/plugins/image_gen/xai/__init__.py b/plugins/image_gen/xai/__init__.py index b1ec4368ef..93fd10ce39 100644 --- a/plugins/image_gen/xai/__init__.py +++ b/plugins/image_gen/xai/__init__.py @@ -203,11 +203,12 @@ class XAIImageGenProvider(ImageGenProvider): ) response.raise_for_status() except requests.HTTPError as exc: - status = exc.response.status_code if exc.response else 0 + response = exc.response + status = response.status_code if response is not None else 0 try: - err_msg = exc.response.json().get("error", {}).get("message", exc.response.text[:300]) + err_msg = response.json().get("error", {}).get("message", response.text[:300]) except Exception: - err_msg = exc.response.text[:300] if exc.response else str(exc) + err_msg = response.text[:300] if response is not None else str(exc) logger.error("xAI image gen failed (%d): %s", status, err_msg) return error_response( error=f"xAI image generation failed ({status}): {err_msg}", diff --git a/tests/plugins/image_gen/test_xai_provider.py b/tests/plugins/image_gen/test_xai_provider.py index ab1bf88345..0da46d43ec 100644 --- a/tests/plugins/image_gen/test_xai_provider.py +++ b/tests/plugins/image_gen/test_xai_provider.py @@ -172,6 +172,27 @@ class TestGenerate: assert result["success"] is False assert result["error_type"] == "api_error" + def test_api_error_preserves_real_response_status(self): + import requests as req_lib + from plugins.image_gen.xai import XAIImageGenProvider + + response = req_lib.Response() + response.status_code = 401 + response._content = json.dumps({"error": {"message": "Invalid API key"}}).encode() + response.headers["Content-Type"] = "application/json" + + response.raise_for_status = MagicMock( + side_effect=req_lib.HTTPError(response=response) + ) + + with patch("plugins.image_gen.xai.requests.post", return_value=response): + provider = XAIImageGenProvider() + result = provider.generate(prompt="test") + + assert result["success"] is False + assert result["error_type"] == "api_error" + assert "xAI image generation failed (401): Invalid API key" in result["error"] + def test_timeout(self): import requests as req_lib