# NOTE(review): recovered from a diff hunk appended to
# tests/tools/test_transcription_tools.py. It uses ``pytest``, ``MagicMock``,
# ``patch`` and the ``sample_ogg`` fixture, none of which are defined in this
# section — presumably provided earlier in that module; confirm.


# ============================================================================
# _transcribe_xai
# ============================================================================


@pytest.fixture
def mock_xai_http_module():
    """Inject a fake ``tools.xai_http`` module for the duration of a test.

    The fake is a ``MagicMock`` registered under ``sys.modules`` whose
    ``hermes_xai_user_agent()`` returns ``"hermes-xai/test"``. ``patch.dict``
    restores ``sys.modules`` when the fixture is torn down.

    Yields:
        The ``MagicMock`` standing in for ``tools.xai_http``.
    """
    fake_module = MagicMock()
    fake_module.hermes_xai_user_agent = MagicMock(return_value="hermes-xai/test")
    with patch.dict("sys.modules", {"tools.xai_http": fake_module}):
        yield fake_module


class TestTranscribeXAI:
    """Tests for ``_transcribe_xai`` with ``requests.post`` mocked out."""

    @staticmethod
    def _response(payload, status_code=200):
        """Build a fake HTTP response whose ``.json()`` returns *payload*."""
        response = MagicMock()
        response.status_code = status_code
        response.json.return_value = payload
        return response

    @staticmethod
    def _posted_form(mock_post):
        """Return the ``data=`` keyword given to the mocked ``requests.post``.

        ``call_args.kwargs`` and ``call_args[1]`` are the same mapping, so a
        single lookup is sufficient.
        """
        return mock_post.call_args.kwargs.get("data", {})

    def test_no_key(self, monkeypatch):
        """Without ``XAI_API_KEY`` the call fails and the error names the variable."""
        monkeypatch.delenv("XAI_API_KEY", raising=False)
        from tools.transcription_tools import _transcribe_xai
        result = _transcribe_xai("/tmp/test.ogg", "grok-stt")
        assert result["success"] is False
        assert "XAI_API_KEY" in result["error"]

    def test_successful_transcription(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """A 200 response with text yields a successful result tagged ``xai``."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
        mock_response = self._response({
            "text": "bonjour le monde",
            "language": "fr",
            "duration": 3.2,
        })

        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("requests.post", return_value=mock_response):
            from tools.transcription_tools import _transcribe_xai
            result = _transcribe_xai(sample_ogg, "grok-stt")

        assert result["success"] is True
        assert result["transcript"] == "bonjour le monde"
        assert result["provider"] == "xai"

    def test_whitespace_stripped(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """Leading and trailing whitespace is removed from the transcript."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
        mock_response = self._response({"text": "  hello world  \n"})

        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("requests.post", return_value=mock_response):
            from tools.transcription_tools import _transcribe_xai
            result = _transcribe_xai(sample_ogg, "grok-stt")

        assert result["transcript"] == "hello world"

    def test_api_error_returns_failure(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """A 400 response is reported with its status code and API message."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
        mock_response = self._response(
            {"error": {"message": "Invalid audio format"}},
            status_code=400,
        )
        mock_response.text = '{"error": {"message": "Invalid audio format"}}'

        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("requests.post", return_value=mock_response):
            from tools.transcription_tools import _transcribe_xai
            result = _transcribe_xai(sample_ogg, "grok-stt")

        assert result["success"] is False
        assert "HTTP 400" in result["error"]
        assert "Invalid audio format" in result["error"]

    def test_empty_transcript_returns_failure(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """A whitespace-only transcript is treated as a failure."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
        mock_response = self._response({"text": "   "})

        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("requests.post", return_value=mock_response):
            from tools.transcription_tools import _transcribe_xai
            result = _transcribe_xai(sample_ogg, "grok-stt")

        assert result["success"] is False
        assert "empty transcript" in result["error"]

    def test_permission_error(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """A ``PermissionError`` from ``open`` is reported as "Permission denied"."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")

        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("builtins.open", side_effect=PermissionError("denied")):
            from tools.transcription_tools import _transcribe_xai
            result = _transcribe_xai(sample_ogg, "grok-stt")

        assert result["success"] is False
        assert "Permission denied" in result["error"]

    def test_network_error_returns_failure(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """An exception raised by ``requests.post`` surfaces in the error text."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")

        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("requests.post", side_effect=ConnectionError("timeout")):
            from tools.transcription_tools import _transcribe_xai
            result = _transcribe_xai(sample_ogg, "grok-stt")

        assert result["success"] is False
        assert "timeout" in result["error"]

    def test_sends_language_and_format(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """The posted form carries ``language`` and ``format`` fields."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
        mock_response = self._response({"text": "test", "language": "fr", "duration": 1.0})

        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("requests.post", return_value=mock_response) as mock_post:
            from tools.transcription_tools import _transcribe_xai
            _transcribe_xai(sample_ogg, "grok-stt")

        data = self._posted_form(mock_post)
        # NOTE(review): the STT config is empty here, so "fr" is presumably the
        # implementation's default language — confirm against _transcribe_xai.
        assert data.get("language") == "fr"
        assert data.get("format") == "true"

    def test_custom_base_url(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """``XAI_STT_BASE_URL`` overrides the host the request is sent to."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
        monkeypatch.setenv("XAI_STT_BASE_URL", "https://custom.x.ai/v1")
        mock_response = self._response({"text": "test", "language": "en", "duration": 1.0})

        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("requests.post", return_value=mock_response) as mock_post:
            from tools.transcription_tools import _transcribe_xai
            _transcribe_xai(sample_ogg, "grok-stt")

        # Accept the URL whether it was passed positionally or as ``url=``.
        call_args = mock_post.call_args
        url = call_args[0][0] if call_args[0] else call_args.kwargs.get("url", "")
        assert "custom.x.ai" in url

    def test_diarize_sent_when_configured(self, monkeypatch, sample_ogg, mock_xai_http_module):
        """``xai.diarize: True`` in the config adds ``diarize="true"`` to the form."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
        mock_response = self._response({"text": "test", "language": "fr", "duration": 1.0})

        config = {"xai": {"diarize": True}}
        with patch("tools.transcription_tools._load_stt_config", return_value=config), \
             patch("requests.post", return_value=mock_response) as mock_post:
            from tools.transcription_tools import _transcribe_xai
            _transcribe_xai(sample_ogg, "grok-stt")

        data = self._posted_form(mock_post)
        assert data.get("diarize") == "true"


# ============================================================================
# _get_provider — xAI
# ============================================================================

class TestGetProviderXAI:
    """xAI-specific provider selection tests."""

    # API-key variables of the other providers that the auto-detect tests
    # clear so the outcome does not depend on the caller's environment.
    _OTHER_PROVIDER_KEYS = (
        "GROQ_API_KEY",
        "VOICE_TOOLS_OPENAI_KEY",
        "OPENAI_API_KEY",
        "MISTRAL_API_KEY",
    )

    @classmethod
    def _clear_other_provider_keys(cls, monkeypatch, keep=()):
        """Remove every non-xAI provider key from the environment except *keep*."""
        for name in cls._OTHER_PROVIDER_KEYS:
            if name not in keep:
                monkeypatch.delenv(name, raising=False)

    @staticmethod
    def _auto_detect(*, has_mistral):
        """Return ``_get_provider({})`` under a fixed set of availability patches.

        ``_HAS_FASTER_WHISPER`` and ``_HAS_OPENAI`` are patched to ``False``,
        ``_has_local_command`` returns ``False``, and ``_HAS_MISTRAL`` is set
        to *has_mistral*.
        """
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
             patch("tools.transcription_tools._has_local_command", return_value=False), \
             patch("tools.transcription_tools._HAS_OPENAI", False), \
             patch("tools.transcription_tools._HAS_MISTRAL", has_mistral):
            from tools.transcription_tools import _get_provider
            return _get_provider({})

    def test_xai_when_key_set(self, monkeypatch):
        """Explicit xai with a key set resolves to ``xai``."""
        monkeypatch.setenv("XAI_API_KEY", "xai-test")
        from tools.transcription_tools import _get_provider
        assert _get_provider({"provider": "xai"}) == "xai"

    def test_xai_explicit_no_key_returns_none(self, monkeypatch):
        """Explicit xai with no key returns none — no cross-provider fallback."""
        monkeypatch.delenv("XAI_API_KEY", raising=False)
        from tools.transcription_tools import _get_provider
        assert _get_provider({"provider": "xai"}) == "none"

    def test_auto_detect_xai_after_mistral(self, monkeypatch):
        """Auto-detect: xai is tried after mistral when all above are unavailable."""
        self._clear_other_provider_keys(monkeypatch)
        monkeypatch.setenv("XAI_API_KEY", "xai-test")
        assert self._auto_detect(has_mistral=False) == "xai"

    def test_auto_detect_mistral_preferred_over_xai(self, monkeypatch):
        """Auto-detect: mistral is preferred over xai."""
        self._clear_other_provider_keys(monkeypatch, keep=("MISTRAL_API_KEY",))
        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
        monkeypatch.setenv("XAI_API_KEY", "xai-test")
        assert self._auto_detect(has_mistral=True) == "mistral"

    def test_auto_detect_no_key_returns_none(self, monkeypatch):
        """Auto-detect: xai skipped when no key is set."""
        # Clear the other providers' keys too, matching the sibling
        # auto-detect tests, so a key in the developer's shell cannot
        # influence the result.
        self._clear_other_provider_keys(monkeypatch)
        monkeypatch.delenv("XAI_API_KEY", raising=False)
        assert self._auto_detect(has_mistral=False) == "none"


# ============================================================================
# transcribe_audio — xAI dispatch
# ============================================================================
class TestTranscribeAudioXAIDispatch:
    """Check that ``transcribe_audio`` hands work to ``_transcribe_xai``."""

    @staticmethod
    def _dispatch(audio_path, stt_config, backend_reply, **call_kwargs):
        """Call ``transcribe_audio`` with the provider pinned to xAI.

        ``_load_stt_config`` is patched to return *stt_config*,
        ``_get_provider`` to return ``"xai"``, and ``_transcribe_xai`` to
        return *backend_reply*. Extra keyword arguments are forwarded to
        ``transcribe_audio`` unchanged.

        Returns:
            A ``(result, xai_mock)`` pair: the value returned by
            ``transcribe_audio`` and the mock that replaced
            ``_transcribe_xai``.
        """
        config_patch = patch(
            "tools.transcription_tools._load_stt_config",
            return_value=stt_config,
        )
        provider_patch = patch(
            "tools.transcription_tools._get_provider",
            return_value="xai",
        )
        backend_patch = patch(
            "tools.transcription_tools._transcribe_xai",
            return_value=backend_reply,
        )
        with config_patch, provider_patch, backend_patch as xai_mock:
            from tools.transcription_tools import transcribe_audio
            outcome = transcribe_audio(audio_path, **call_kwargs)
        return outcome, xai_mock

    def test_dispatches_to_xai(self, sample_ogg):
        """The xAI backend is called once and its result is returned."""
        reply = {"success": True, "transcript": "hi", "provider": "xai"}
        result, mock_xai = self._dispatch(sample_ogg, {"provider": "xai"}, reply)

        assert result["success"] is True
        assert result["provider"] == "xai"
        mock_xai.assert_called_once()

    def test_model_default_is_grok_stt(self, sample_ogg):
        """With ``model=None`` the backend receives ``"grok-stt"``."""
        reply = {"success": True, "transcript": "hi"}
        _, mock_xai = self._dispatch(
            sample_ogg, {"provider": "xai"}, reply, model=None
        )

        positional = mock_xai.call_args[0]
        assert positional[1] == "grok-stt"

    def test_model_override_passed_to_xai(self, sample_ogg):
        """An explicit ``model`` argument reaches the backend unchanged."""
        reply = {"success": True, "transcript": "hi"}
        _, mock_xai = self._dispatch(sample_ogg, {}, reply, model="custom-stt")

        positional = mock_xai.call_args[0]
        assert positional[1] == "custom-stt"