diff --git a/plugins/image_gen/xai/__init__.py b/plugins/image_gen/xai/__init__.py
new file mode 100644
index 000000000..39e83e5ac
--- /dev/null
+++ b/plugins/image_gen/xai/__init__.py
@@ -0,0 +1,350 @@
+"""xAI image generation backend.
+
+Exposes xAI's ``grok-imagine-image`` model as an
+:class:`ImageGenProvider` implementation.
+
+Features:
+- Text-to-image generation
+- Image editing with reference images
+- Multiple aspect ratios (1:1, 16:9, 9:16, etc.)
+- Multiple resolutions (1K, 2K)
+- Base64 output saved to cache
+
+Selection precedence (first hit wins):
+1. ``XAI_IMAGE_MODEL`` env var
+2. ``image_gen.xai.model`` in ``config.yaml``
+3. :data:`DEFAULT_MODEL`
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import requests
+
+from agent.image_gen_provider import (
+    DEFAULT_ASPECT_RATIO,
+    ImageGenProvider,
+    error_response,
+    resolve_aspect_ratio,
+    save_b64_image,
+    success_response,
+)
+from tools.xai_http import hermes_xai_user_agent
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Model catalog
+# ---------------------------------------------------------------------------
+
+API_MODEL = "grok-imagine-image"
+
+_MODELS: Dict[str, Dict[str, Any]] = {
+    "grok-imagine-image": {
+        "display": "Grok Imagine Image",
+        "speed": "~5-10s",
+        "strengths": "Fast, high-quality, supports editing",
+    },
+}
+
+DEFAULT_MODEL = "grok-imagine-image"
+
+# xAI aspect ratios (more options than FAL/OpenAI)
+_XAI_ASPECT_RATIOS = {
+    "landscape": "16:9",
+    "square": "1:1",
+    "portrait": "9:16",
+    "4:3": "4:3",
+    "3:4": "3:4",
+    "3:2": "3:2",
+    "2:3": "2:3",
+}
+
+# xAI resolutions
+_XAI_RESOLUTIONS = {
+    "1k": "1024",
+    "2k": "2048",
+}
+
+DEFAULT_RESOLUTION = "1k"
+
+
+# ---------------------------------------------------------------------------
+# Config
+# ---------------------------------------------------------------------------
+
+
+def _load_xai_config() -> Dict[str, Any]:
+    """Read ``image_gen.xai`` from config.yaml."""
+    try:
+        from hermes_cli.config import load_config
+
+        cfg = load_config()
+        section = cfg.get("image_gen") if isinstance(cfg, dict) else None
+        xai_section = section.get("xai") if isinstance(section, dict) else None
+        return xai_section if isinstance(xai_section, dict) else {}
+    except Exception as exc:
+        logger.debug("Could not load image_gen.xai config: %s", exc)
+        return {}
+
+
+def _resolve_model() -> Tuple[str, Dict[str, Any]]:
+    """Decide which model to use and return ``(model_id, meta)``."""
+    env_override = os.environ.get("XAI_IMAGE_MODEL")
+    if env_override and env_override in _MODELS:
+        return env_override, _MODELS[env_override]
+
+    cfg = _load_xai_config()
+    candidate = cfg.get("model") if isinstance(cfg.get("model"), str) else None
+    if candidate and candidate in _MODELS:
+        return candidate, _MODELS[candidate]
+
+    return DEFAULT_MODEL, _MODELS[DEFAULT_MODEL]
+
+
+def _resolve_resolution() -> str:
+    """Get configured resolution."""
+    cfg = _load_xai_config()
+    res = cfg.get("resolution") if isinstance(cfg.get("resolution"), str) else None
+    if res and res in _XAI_RESOLUTIONS:
+        return res
+    return DEFAULT_RESOLUTION
+
+
+# ---------------------------------------------------------------------------
+# Provider
+# ---------------------------------------------------------------------------
+
+
+class XAIImageGenProvider(ImageGenProvider):
+    """xAI ``grok-imagine-image`` backend."""
+
+    @property
+    def name(self) -> str:
+        return "xai"
+
+    @property
+    def display_name(self) -> str:
+        return "xAI (Grok)"
+
+    def is_available(self) -> bool:
+        return bool(os.getenv("XAI_API_KEY"))
+
+    def list_models(self) -> List[Dict[str, Any]]:
+        return [
+            {
+                "id": model_id,
+                "display": meta.get("display", model_id),
+                "speed": meta.get("speed", ""),
+                "strengths": meta.get("strengths", ""),
+            }
+            for model_id, meta in _MODELS.items()
+        ]
+
+    def get_setup_schema(self) -> Dict[str, Any]:
+        return {
+            "name": "xAI (Grok)",
+            "badge": "paid",
+            "tag": "Native xAI image generation via grok-imagine-image",
+            "env_vars": [
+                {
+                    "key": "XAI_API_KEY",
+                    "prompt": "xAI API key",
+                    "url": "https://console.x.ai/",
+                },
+            ],
+        }
+
+    def generate(
+        self,
+        prompt: str,
+        aspect_ratio: str = DEFAULT_ASPECT_RATIO,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Generate an image using xAI's grok-imagine-image.
+
+        Args:
+            prompt: Text description of the image to generate.
+            aspect_ratio: Aspect ratio name; normalised with
+                ``resolve_aspect_ratio`` and mapped to an xAI ratio
+                (unmapped values fall back to ``1:1``).
+            **kwargs: Optional ``reference_images`` (only the first five
+                are sent) and ``edit_image`` (sent as ``image_url``).
+
+        Returns:
+            The result of ``success_response`` on success, otherwise the
+            result of ``error_response``; HTTP and transport failures are
+            reported this way rather than raised.
+        """
+        api_key = os.getenv("XAI_API_KEY", "").strip()
+        if not api_key:
+            return error_response(
+                error="XAI_API_KEY not set. Get one at https://console.x.ai/",
+                error_type="missing_api_key",
+                provider="xai",
+                aspect_ratio=aspect_ratio,
+            )
+
+        model_id, meta = _resolve_model()
+        aspect = resolve_aspect_ratio(aspect_ratio)
+        xai_ar = _XAI_ASPECT_RATIOS.get(aspect, "1:1")
+        resolution = _resolve_resolution()
+        xai_res = _XAI_RESOLUTIONS.get(resolution, "1024")
+
+        # Check for editing mode (reference images)
+        reference_images = kwargs.get("reference_images", [])
+        edit_image = kwargs.get("edit_image")
+
+        payload: Dict[str, Any] = {
+            "model": API_MODEL,
+            "prompt": prompt,
+            "aspect_ratio": xai_ar,
+            "resolution": xai_res,
+        }
+
+        # Add editing parameters if present
+        if reference_images:
+            payload["reference_images"] = reference_images[:5]
+        if edit_image:
+            payload["image_url"] = edit_image
+
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+            "User-Agent": hermes_xai_user_agent(),
+        }
+
+        base_url = (os.getenv("XAI_BASE_URL") or "https://api.x.ai/v1").strip().rstrip("/")
+
+        try:
+            response = requests.post(
+                f"{base_url}/images/generations",
+                headers=headers,
+                json=payload,
+                timeout=120,
+            )
+            response.raise_for_status()
+        except requests.HTTPError as exc:
+            # ``requests.Response`` is falsy for 4xx/5xx statuses, so test for
+            # ``None`` explicitly; a truthiness check would always report 0.
+            status = exc.response.status_code if exc.response is not None else 0
+            try:
+                err_msg = exc.response.json().get("error", {}).get("message", exc.response.text[:300])
+            except Exception:
+                err_msg = exc.response.text[:300] if exc.response is not None else str(exc)
+            logger.error("xAI image gen failed (%d): %s", status, err_msg)
+            return error_response(
+                error=f"xAI image generation failed ({status}): {err_msg}",
+                error_type="api_error",
+                provider="xai",
+                model=model_id,
+                prompt=prompt,
+                aspect_ratio=aspect,
+            )
+        except requests.Timeout:
+            return error_response(
+                error="xAI image generation timed out (120s)",
+                error_type="timeout",
+                provider="xai",
+                model=model_id,
+                prompt=prompt,
+                aspect_ratio=aspect,
+            )
+        except requests.ConnectionError as exc:
+            return error_response(
+                error=f"xAI connection error: {exc}",
+                error_type="connection_error",
+                provider="xai",
+                model=model_id,
+                prompt=prompt,
+                aspect_ratio=aspect,
+            )
+        except requests.RequestException as exc:
+            # Remaining requests failures (malformed base URL, redirect loops, ...).
+            return error_response(
+                error=f"xAI request failed: {exc}",
+                error_type="request_error",
+                provider="xai",
+                model=model_id,
+                prompt=prompt,
+                aspect_ratio=aspect,
+            )
+
+        try:
+            result = response.json()
+        except Exception as exc:
+            return error_response(
+                error=f"xAI returned invalid JSON: {exc}",
+                error_type="invalid_response",
+                provider="xai",
+                model=model_id,
+                prompt=prompt,
+                aspect_ratio=aspect,
+            )
+
+        # Parse response — xAI returns data[0].b64_json or data[0].url
+        data = result.get("data", [])
+        if not data:
+            return error_response(
+                error="xAI returned no image data",
+                error_type="empty_response",
+                provider="xai",
+                model=model_id,
+                prompt=prompt,
+                aspect_ratio=aspect,
+            )
+
+        first = data[0]
+        b64 = first.get("b64_json")
+        url = first.get("url")
+
+        if b64:
+            try:
+                saved_path = save_b64_image(b64, prefix=f"xai_{model_id}")
+            except Exception as exc:
+                return error_response(
+                    error=f"Could not save image to cache: {exc}",
+                    error_type="io_error",
+                    provider="xai",
+                    model=model_id,
+                    prompt=prompt,
+                    aspect_ratio=aspect,
+                )
+            image_ref = str(saved_path)
+        elif url:
+            image_ref = url
+        else:
+            return error_response(
+                error="xAI response contained neither b64_json nor URL",
+                error_type="empty_response",
+                provider="xai",
+                model=model_id,
+                prompt=prompt,
+                aspect_ratio=aspect,
+            )
+
+        extra: Dict[str, Any] = {}
+        if reference_images:
+            extra["reference_images"] = len(reference_images)
+
+        return success_response(
+            image=image_ref,
+            model=model_id,
+            prompt=prompt,
+            aspect_ratio=aspect,
+            provider="xai",
+            extra=extra if extra else None,
+        )
+
+
+# ---------------------------------------------------------------------------
+# Plugin registration
+# ---------------------------------------------------------------------------
+
+
+def register(ctx: Any) -> None:
+    """Register this provider with the image gen registry."""
+    ctx.register_image_gen_provider(XAIImageGenProvider())
diff --git a/plugins/image_gen/xai/plugin.yaml b/plugins/image_gen/xai/plugin.yaml
new file mode 100644
index 000000000..af735846a
--- /dev/null
+++ b/plugins/image_gen/xai/plugin.yaml
@@ -0,0 +1,7 @@
+name: xai
+version: 1.0.0
+description: "xAI image generation backend (grok-imagine-image). Supports text-to-image and editing."
+author: Julien Talbot
+kind: backend
+requires_env:
+  - XAI_API_KEY
diff --git a/tests/plugins/image_gen/test_xai_provider.py b/tests/plugins/image_gen/test_xai_provider.py
new file mode 100644
index 000000000..b69e3e18d
--- /dev/null
+++ b/tests/plugins/image_gen/test_xai_provider.py
@@ -0,0 +1,278 @@
+#!/usr/bin/env python3
+"""Tests for xAI image generation provider."""
+
+from __future__ import annotations
+
+import json
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _fake_api_key(monkeypatch):
+    """Ensure XAI_API_KEY is set for all tests."""
+    monkeypatch.setenv("XAI_API_KEY", "test-key-12345")
+
+
+# ---------------------------------------------------------------------------
+# Provider class tests
+# ---------------------------------------------------------------------------
+
+
+class TestXAIImageGenProvider:
+    def test_name(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        provider = XAIImageGenProvider()
+        assert provider.name == "xai"
+
+    def test_display_name(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        provider = XAIImageGenProvider()
+        assert provider.display_name == "xAI (Grok)"
+
+    def test_is_available_with_key(self, monkeypatch):
+        monkeypatch.setenv("XAI_API_KEY", "sk-xxx")
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        provider = XAIImageGenProvider()
+        assert provider.is_available() is True
+
+    def test_is_available_without_key(self, monkeypatch):
+        monkeypatch.delenv("XAI_API_KEY", raising=False)
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        provider = XAIImageGenProvider()
+        assert provider.is_available() is False
+
+    def test_list_models(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        provider = XAIImageGenProvider()
+        models = provider.list_models()
+        assert len(models) >= 1
+        assert models[0]["id"] == "grok-imagine-image"
+
+    def test_default_model(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        provider = XAIImageGenProvider()
+        assert provider.default_model() == "grok-imagine-image"
+
+    def test_get_setup_schema(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        provider = XAIImageGenProvider()
+        schema = provider.get_setup_schema()
+        assert schema["name"] == "xAI (Grok)"
+        assert schema["badge"] == "paid"
+        assert len(schema["env_vars"]) == 1
+        assert schema["env_vars"][0]["key"] == "XAI_API_KEY"
+
+
+# ---------------------------------------------------------------------------
+# Config tests
+# ---------------------------------------------------------------------------
+
+
+class TestConfig:
+    def test_default_model(self):
+        from plugins.image_gen.xai import _resolve_model
+
+        model_id, meta = _resolve_model()
+        assert model_id == "grok-imagine-image"
+
+    def test_default_resolution(self):
+        from plugins.image_gen.xai import _resolve_resolution
+
+        assert _resolve_resolution() == "1k"
+
+    def test_custom_model(self, monkeypatch):
+        monkeypatch.setenv("XAI_IMAGE_MODEL", "grok-imagine-image")
+        from plugins.image_gen.xai import _resolve_model
+
+        model_id, _ = _resolve_model()
+        assert model_id == "grok-imagine-image"
+
+
+# ---------------------------------------------------------------------------
+# Generate tests
+# ---------------------------------------------------------------------------
+
+
+class TestGenerate:
+    def test_missing_api_key(self, monkeypatch):
+        monkeypatch.delenv("XAI_API_KEY", raising=False)
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        provider = XAIImageGenProvider()
+        result = provider.generate(prompt="test")
+        assert result["success"] is False
+        assert "XAI_API_KEY" in result["error"]
+
+    def test_successful_generation(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.raise_for_status = MagicMock()
+        mock_resp.json.return_value = {
+            "data": [{"b64_json": "dGVzdC1pbWFnZS1kYXRh"}],  # base64 "test-image-data"
+        }
+
+        with patch("plugins.image_gen.xai.requests.post", return_value=mock_resp):
+            with patch("plugins.image_gen.xai.save_b64_image", return_value="/tmp/test.png"):
+                provider = XAIImageGenProvider()
+                result = provider.generate(prompt="A cat playing piano")
+
+        assert result["success"] is True
+        assert result["image"] == "/tmp/test.png"
+        assert result["provider"] == "xai"
+        assert result["model"] == "grok-imagine-image"
+
+    def test_successful_url_response(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.raise_for_status = MagicMock()
+        mock_resp.json.return_value = {
+            "data": [{"url": "https://xai.image/result.png"}],
+        }
+
+        with patch("plugins.image_gen.xai.requests.post", return_value=mock_resp):
+            provider = XAIImageGenProvider()
+            result = provider.generate(prompt="A cat playing piano")
+
+        assert result["success"] is True
+        assert result["image"] == "https://xai.image/result.png"
+
+    def test_api_error(self):
+        import requests as req_lib
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 401
+        mock_resp.text = "Unauthorized"
+        mock_resp.json.return_value = {"error": {"message": "Invalid API key"}}
+        mock_resp.raise_for_status.side_effect = req_lib.HTTPError(response=mock_resp)
+
+        with patch("plugins.image_gen.xai.requests.post", return_value=mock_resp):
+            provider = XAIImageGenProvider()
+            result = provider.generate(prompt="test")
+
+        assert result["success"] is False
+        assert result["error_type"] == "api_error"
+
+    def test_api_error_reports_real_status(self):
+        """A real ``Response`` is falsy for 4xx/5xx; status must still surface."""
+        import requests as req_lib
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        resp = req_lib.Response()
+        resp.status_code = 401
+        resp._content = b'{"error": {"message": "Invalid API key"}}'
+
+        with patch("plugins.image_gen.xai.requests.post", return_value=resp):
+            provider = XAIImageGenProvider()
+            result = provider.generate(prompt="test")
+
+        assert result["success"] is False
+        assert result["error_type"] == "api_error"
+        assert "(401)" in result["error"]
+        assert "Invalid API key" in result["error"]
+
+    def test_timeout(self):
+        import requests as req_lib
+
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        with patch("plugins.image_gen.xai.requests.post", side_effect=req_lib.Timeout()):
+            provider = XAIImageGenProvider()
+            result = provider.generate(prompt="test")
+
+        assert result["success"] is False
+        assert result["error_type"] == "timeout"
+
+    def test_empty_response(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.raise_for_status = MagicMock()
+        mock_resp.json.return_value = {"data": []}
+
+        with patch("plugins.image_gen.xai.requests.post", return_value=mock_resp):
+            provider = XAIImageGenProvider()
+            result = provider.generate(prompt="test")
+
+        assert result["success"] is False
+        assert result["error_type"] == "empty_response"
+
+    def test_with_reference_images(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.raise_for_status = MagicMock()
+        mock_resp.json.return_value = {
+            "data": [{"url": "https://xai.image/edited.png"}],
+        }
+
+        with patch("plugins.image_gen.xai.requests.post", return_value=mock_resp) as mock_post:
+            provider = XAIImageGenProvider()
+            result = provider.generate(
+                prompt="Edit this image",
+                reference_images=["https://example.com/ref1.png", "https://example.com/ref2.png"],
+            )
+
+        assert result["success"] is True
+        # Check that reference_images was passed in payload
+        call_args = mock_post.call_args
+        payload = call_args.kwargs.get("json") or call_args[1].get("json")
+        assert "reference_images" in payload
+        assert len(payload["reference_images"]) == 2
+
+    def test_auth_header(self):
+        from plugins.image_gen.xai import XAIImageGenProvider
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.raise_for_status = MagicMock()
+        mock_resp.json.return_value = {
+            "data": [{"url": "https://xai.image/test.png"}],
+        }
+
+        with patch("plugins.image_gen.xai.requests.post", return_value=mock_resp) as mock_post:
+            provider = XAIImageGenProvider()
+            provider.generate(prompt="test")
+
+        call_args = mock_post.call_args
+        headers = call_args.kwargs.get("headers") or call_args[1].get("headers")
+        assert "Bearer test-key-12345" in headers["Authorization"]
+        assert "Hermes-Agent" in headers["User-Agent"]
+
+
+# ---------------------------------------------------------------------------
+# Registration test
+# ---------------------------------------------------------------------------
+
+
+class TestRegistration:
+    def test_register(self):
+        from plugins.image_gen.xai import XAIImageGenProvider, register
+
+        mock_ctx = MagicMock()
+        register(mock_ctx)
+        mock_ctx.register_image_gen_provider.assert_called_once()
+        provider = mock_ctx.register_image_gen_provider.call_args[0][0]
+        assert isinstance(provider, XAIImageGenProvider)
+        assert provider.name == "xai"