diff --git a/run_agent.py b/run_agent.py index 64daad4c8..3d957b390 100644 --- a/run_agent.py +++ b/run_agent.py @@ -4318,6 +4318,7 @@ class AIAgent: try: with active_client.responses.stream(**api_kwargs) as stream: for event in stream: + self._touch_activity("receiving stream response") if self._interrupt_requested: break event_type = getattr(event, "type", "") @@ -4442,6 +4443,7 @@ class AIAgent: collected_text_deltas: list = [] try: for event in stream_or_response: + self._touch_activity("receiving stream response") event_type = getattr(event, "type", None) if not event_type and isinstance(event, dict): event_type = event.get("type") @@ -5074,12 +5076,9 @@ class AIAgent: role = "assistant" reasoning_parts: list = [] usage_obj = None - _first_chunk_seen = False for chunk in stream: last_chunk_time["t"] = time.time() - if not _first_chunk_seen: - _first_chunk_seen = True - self._touch_activity("receiving stream response") + self._touch_activity("receiving stream response") if self._interrupt_requested: break @@ -5255,6 +5254,7 @@ class AIAgent: # actively arriving (the chat_completions path # already does this at the top of its chunk loop). last_chunk_time["t"] = time.time() + self._touch_activity("receiving stream response") if self._interrupt_requested: break diff --git a/tests/run_agent/test_streaming.py b/tests/run_agent/test_streaming.py index 1943b0611..97dcffc67 100644 --- a/tests/run_agent/test_streaming.py +++ b/tests/run_agent/test_streaming.py @@ -291,6 +291,38 @@ class TestStreamingCallbacks: assert len(first_delta_calls) == 1 + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_chat_stream_refreshes_activity_on_every_chunk(self, mock_close, mock_create): + """Each streamed chat chunk should refresh the activity timestamp.""" + from run_agent import AIAgent + + chunks = [ + _make_stream_chunk(content="a"), + _make_stream_chunk(content="b"), + _make_stream_chunk(finish_reason="stop"), + ] + + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = iter(chunks) + mock_create.return_value = mock_client + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + touch_calls = [] + agent._touch_activity = lambda desc: touch_calls.append(desc) + + agent._interruptible_streaming_api_call({}) + + assert touch_calls.count("receiving stream response") == len(chunks) + @patch("run_agent.AIAgent._create_request_openai_client") @patch("run_agent.AIAgent._close_request_openai_client") def test_tool_only_does_not_fire_callback(self, mock_close, mock_create): @@ -693,6 +725,55 @@ class TestCodexStreamCallbacks: response = agent._run_codex_stream({}, client=mock_client) assert "Hello from Codex!" in deltas + def test_codex_stream_refreshes_activity_on_every_event(self): + from run_agent import AIAgent + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "codex_responses" + agent._interrupt_requested = False + + touch_calls = [] + agent._touch_activity = lambda desc: touch_calls.append(desc) + + mock_event_text_1 = SimpleNamespace( + type="response.output_text.delta", + delta="Hello", + ) + mock_event_text_2 = SimpleNamespace( + type="response.output_text.delta", + delta=" world", + ) + mock_event_done = SimpleNamespace( + type="response.completed", + delta="", + ) + + mock_stream = MagicMock() + mock_stream.__enter__ = MagicMock(return_value=mock_stream) + mock_stream.__exit__ = MagicMock(return_value=False) + mock_stream.__iter__ = MagicMock( + return_value=iter([mock_event_text_1, mock_event_text_2, mock_event_done]) + ) + mock_stream.get_final_response.return_value = SimpleNamespace( + output=[SimpleNamespace( + type="message", + content=[SimpleNamespace(type="output_text", text="Hello world")], + )], + status="completed", + ) + + mock_client = MagicMock() + mock_client.responses.stream.return_value = mock_stream + + agent._run_codex_stream({}, client=mock_client) + + assert touch_calls.count("receiving stream response") == 3 + def test_codex_remote_protocol_error_falls_back_to_create_stream(self): from run_agent import AIAgent import httpx @@ -724,3 +805,102 @@ class TestCodexStreamCallbacks: assert response is fallback_response mock_fallback.assert_called_once_with({}, client=mock_client) + + def test_codex_create_stream_fallback_refreshes_activity_on_every_event(self): + from run_agent import AIAgent + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "codex_responses" + + touch_calls = [] + agent._touch_activity = lambda desc: touch_calls.append(desc) + + events = [ + SimpleNamespace(type="response.output_text.delta", delta="Hello"), + SimpleNamespace(type="response.output_item.done", item=SimpleNamespace(type="message")), + SimpleNamespace( + type="response.completed", + response=SimpleNamespace( + output=[SimpleNamespace( + type="message", + content=[SimpleNamespace(type="output_text", text="Hello")], + )] + ), + ), + ] + + class _FakeCreateStream: + def __iter__(self_inner): + return iter(events) + + def close(self_inner): + return None + + mock_stream = _FakeCreateStream() + + mock_client = MagicMock() + mock_client.responses.create.return_value = mock_stream + + agent._run_codex_create_stream_fallback( + {"model": "test/model", "instructions": "hi", "input": []}, + client=mock_client, + ) + + assert touch_calls.count("receiving stream response") == len(events) + + +class TestAnthropicStreamCallbacks: + """Verify Anthropic streaming refreshes activity on every event.""" + + def test_anthropic_stream_refreshes_activity_on_every_event(self): + from run_agent import AIAgent + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "anthropic_messages" + agent._interrupt_requested = False + + touch_calls = [] + agent._touch_activity = lambda desc: touch_calls.append(desc) + + events = [ + SimpleNamespace( + type="content_block_delta", + delta=SimpleNamespace(type="text_delta", text="Hello"), + ), + SimpleNamespace( + type="content_block_delta", + delta=SimpleNamespace(type="thinking_delta", thinking="thinking"), + ), + SimpleNamespace( + type="content_block_start", + content_block=SimpleNamespace(type="tool_use", name="terminal"), + ), + ] + + final_message = SimpleNamespace( + content=[], + stop_reason="end_turn", + ) + + mock_stream = MagicMock() + mock_stream.__enter__ = MagicMock(return_value=mock_stream) + mock_stream.__exit__ = MagicMock(return_value=False) + mock_stream.__iter__ = MagicMock(return_value=iter(events)) + mock_stream.get_final_message.return_value = final_message + + agent._anthropic_client = MagicMock() + agent._anthropic_client.messages.stream.return_value = mock_stream + + agent._interruptible_streaming_api_call({}) + + assert touch_calls.count("receiving stream response") == len(events)