diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py index 7c39c9e1c4..94157651ca 100644 --- a/tests/tools/test_code_execution.py +++ b/tests/tools/test_code_execution.py @@ -393,5 +393,56 @@ class TestStubSchemaDrift(unittest.TestCase): self.assertIn("mode", src) +class TestHeadTailTruncation(unittest.TestCase): + """Tests for head+tail truncation of large stdout in execute_code.""" + + def _run(self, code): + with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call): + result = execute_code( + code=code, + task_id="test-task", + enabled_tools=list(SANDBOX_ALLOWED_TOOLS), + ) + return json.loads(result) + + def test_short_output_not_truncated(self): + """Output under MAX_STDOUT_BYTES should not be truncated.""" + result = self._run('print("small output")') + self.assertEqual(result["status"], "success") + self.assertIn("small output", result["output"]) + self.assertNotIn("TRUNCATED", result["output"]) + + def test_large_output_preserves_head_and_tail(self): + """Output exceeding MAX_STDOUT_BYTES keeps both head and tail.""" + code = ''' +# Print HEAD marker, then filler, then TAIL marker +print("HEAD_MARKER_START") +for i in range(15000): + print(f"filler_line_{i:06d}_padding_to_fill_buffer") +print("TAIL_MARKER_END") +''' + result = self._run(code) + self.assertEqual(result["status"], "success") + output = result["output"] + # Head should be preserved + self.assertIn("HEAD_MARKER_START", output) + # Tail should be preserved (this is the key improvement) + self.assertIn("TAIL_MARKER_END", output) + # Truncation notice should be present + self.assertIn("TRUNCATED", output) + + def test_truncation_notice_format(self): + """Truncation notice includes character counts.""" + code = ''' +for i in range(15000): + print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx") +''' + result = self._run(code) + output = result["output"] + if "TRUNCATED" in output: + self.assertIn("chars omitted", output) + self.assertIn("total", output) + + if __name__ == "__main__": unittest.main() diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index 7ea8fa8e40..b16b1d870d 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -457,11 +457,17 @@ def execute_code( # --- Poll loop: watch for exit, timeout, and interrupt --- deadline = time.monotonic() + timeout - stdout_chunks: list = [] stderr_chunks: list = [] - # Background readers to avoid pipe buffer deadlocks + # Background readers to avoid pipe buffer deadlocks. + # For stdout we use a head+tail strategy: keep the first HEAD_BYTES + # and a rolling window of the last TAIL_BYTES so the final print() + # output is never lost. Stderr keeps head-only (errors appear early). + _STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head + _STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail + def _drain(pipe, chunks, max_bytes): + """Simple head-only drain (used for stderr).""" total = 0 try: while True: @@ -475,8 +481,48 @@ def execute_code( except (ValueError, OSError): pass + stdout_total_bytes = [0] # mutable ref for total bytes seen + + def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref): + """Drain stdout keeping both head and tail data.""" + head_collected = 0 + from collections import deque + tail_buf = deque() + tail_collected = 0 + try: + while True: + data = pipe.read(4096) + if not data: + break + total_ref[0] += len(data) + # Fill head buffer first + if head_collected < head_bytes: + keep = min(len(data), head_bytes - head_collected) + head_chunks.append(data[:keep]) + head_collected += keep + data = data[keep:] # remaining goes to tail + if not data: + continue + # Everything past head goes into rolling tail buffer + tail_buf.append(data) + tail_collected += len(data) + # Evict old tail data to stay within tail_bytes budget + while tail_collected > tail_bytes and tail_buf: + oldest = tail_buf.popleft() + tail_collected -= len(oldest) + except (ValueError, OSError): + pass + # Transfer final tail to output list + tail_chunks.extend(tail_buf) + + stdout_head_chunks: list = [] + stdout_tail_chunks: list = [] + stdout_reader = threading.Thread( - target=_drain, args=(proc.stdout, stdout_chunks, MAX_STDOUT_BYTES), daemon=True + target=_drain_head_tail, + args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks, + _STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes), + daemon=True ) stderr_reader = threading.Thread( target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True @@ -500,12 +546,21 @@ def execute_code( stdout_reader.join(timeout=3) stderr_reader.join(timeout=3) - stdout_text = b"".join(stdout_chunks).decode("utf-8", errors="replace") + stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace") + stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace") stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace") - # Truncation notice - if len(stdout_text) >= MAX_STDOUT_BYTES: - stdout_text = stdout_text[:MAX_STDOUT_BYTES] + "\n[output truncated at 50KB]" + # Assemble stdout with head+tail truncation + total_stdout = stdout_total_bytes[0] + if total_stdout > MAX_STDOUT_BYTES and stdout_tail: + omitted = total_stdout - len(stdout_head) - len(stdout_tail) + truncated_notice = ( + f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted " + f"out of {total_stdout:,} total] ...\n\n" + ) + stdout_text = stdout_head + truncated_notice + stdout_tail + else: + stdout_text = stdout_head + stdout_tail exit_code = proc.returncode if proc.returncode is not None else -1 duration = round(time.monotonic() - exec_start, 2)