diff --git a/tests/gateway/test_retry_replacement.py b/tests/gateway/test_retry_replacement.py new file mode 100644 index 0000000000..e62979cc73 --- /dev/null +++ b/tests/gateway/test_retry_replacement.py @@ -0,0 +1,97 @@ +"""Regression tests for /retry replacement semantics.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import GatewayConfig +from gateway.platforms.base import MessageEvent, MessageType +from gateway.run import GatewayRunner +from gateway.session import SessionStore + + +@pytest.mark.asyncio +async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path): + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=tmp_path, config=config) + store._db = None + store._loaded = True + + session_id = "retry_session" + for msg in [ + {"role": "session_meta", "tools": []}, + {"role": "user", "content": "first question"}, + {"role": "assistant", "content": "first answer"}, + {"role": "user", "content": "retry me"}, + {"role": "assistant", "content": "old answer"}, + ]: + store.append_to_transcript(session_id, msg) + + gw = GatewayRunner.__new__(GatewayRunner) + gw.config = config + gw.session_store = store + + session_entry = MagicMock(session_id=session_id) + session_entry.last_prompt_tokens = 111 + gw.session_store.get_or_create_session = MagicMock(return_value=session_entry) + + async def fake_handle_message(event): + assert event.text == "retry me" + transcript_before = store.load_transcript(session_id) + assert [m.get("content") for m in transcript_before if m.get("role") == "user"] == [ + "first question" + ] + store.append_to_transcript(session_id, {"role": "user", "content": event.text}) + store.append_to_transcript(session_id, {"role": "assistant", "content": "new answer"}) + return "new answer" + + gw._handle_message = AsyncMock(side_effect=fake_handle_message) + + result = await gw._handle_retry_command( + MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock()) + ) + + assert result == "new answer" + transcript_after = store.load_transcript(session_id) + assert [m.get("content") for m in transcript_after if m.get("role") == "user"] == [ + "first question", + "retry me", + ] + assert [m.get("content") for m in transcript_after if m.get("role") == "assistant"] == [ + "first answer", + "new answer", + ] + + +@pytest.mark.asyncio +async def test_gateway_retry_replays_original_text_not_retry_command(tmp_path): + config = MagicMock() + config.sessions_dir = tmp_path + config.max_context_messages = 20 + gw = GatewayRunner.__new__(GatewayRunner) + gw.config = config + gw.session_store = MagicMock() + + session_entry = MagicMock(session_id="test-session") + session_entry.last_prompt_tokens = 55 + gw.session_store.get_or_create_session.return_value = session_entry + gw.session_store.load_transcript.return_value = [ + {"role": "user", "content": "real message"}, + {"role": "assistant", "content": "answer"}, + ] + gw.session_store.rewrite_transcript = MagicMock() + + captured = {} + + async def fake_handle_message(event): + captured["text"] = event.text + return "ok" + + gw._handle_message = AsyncMock(side_effect=fake_handle_message) + + await gw._handle_retry_command( + MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock()) + ) + + assert captured["text"] == "real message" diff --git a/tests/test_cli_retry.py b/tests/test_cli_retry.py new file mode 100644 index 0000000000..74e2512bfe --- /dev/null +++ b/tests/test_cli_retry.py @@ -0,0 +1,49 @@ +"""Regression tests for CLI /retry history replacement semantics.""" + +from tests.test_cli_init import _make_cli + + +def test_retry_last_truncates_history_before_requeueing_message(): + cli = _make_cli() + cli.conversation_history = [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "one"}, + {"role": "user", "content": "retry me"}, + {"role": "assistant", "content": "old answer"}, + ] + + retry_msg = cli.retry_last() + + assert retry_msg == "retry me" + assert cli.conversation_history == [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "one"}, + ] + + cli.conversation_history.append({"role": "user", "content": retry_msg}) + cli.conversation_history.append({"role": "assistant", "content": "new answer"}) + + assert [m["content"] for m in cli.conversation_history if m["role"] == "user"] == [ + "first", + "retry me", + ] + + +def test_process_command_retry_requeues_original_message_not_retry_command(): + cli = _make_cli() + queued = [] + + class _Queue: + def put(self, value): + queued.append(value) + + cli._pending_input = _Queue() + cli.conversation_history = [ + {"role": "user", "content": "retry me"}, + {"role": "assistant", "content": "old answer"}, + ] + + cli.process_command("/retry") + + assert queued == ["retry me"] + assert cli.conversation_history == []