"""Tests for partial results when a search command hits its time budget.

Each test drives ``ShellFileOperations.search`` (or the
``_search_stdout_and_limit`` helper directly) with a mocked environment whose
command output ends in the timeout marker and whose return code is 124, then
checks that the parsed results exclude the marker and that the result is
flagged ``truncated`` with ``limit_reason == "search_timeout"``.
"""
from unittest.mock import MagicMock

import pytest

from tools.file_operations import ExecuteResult, ShellFileOperations, _search_stdout_and_limit


TIMEOUT = "[Command timed out after 60s]"


@pytest.fixture()
def ops():
    """ShellFileOperations over a MagicMock env that returns empty, successful output."""
    env = MagicMock(cwd="/tmp/test")
    env.execute.return_value = {"output": "", "returncode": 0}
    return ShellFileOperations(env)


def timeout_output(*lines: str) -> str:
    """Join ``lines`` with newlines and append the timeout marker as the last line."""
    return "\n".join([*lines, TIMEOUT])


def path_exists_or(output: str, returncode: int = 124):
    """Build an ``env.execute`` side effect.

    Commands containing ``test -e`` answer "exists" with return code 0
    (presumably the path probe ``search`` issues first -- confirm against the
    implementation); every other command returns ``output`` with
    ``returncode``.
    """
    def execute(command, **kwargs):
        if "test -e" in command:
            return {"output": "exists", "returncode": 0}
        return {"output": output, "returncode": returncode}

    return execute


def assert_timed_out(result):
    """Assert ``result`` is a non-error, truncated result with the timeout limit reason."""
    assert result.error is None
    assert result.truncated is True
    assert result.limit_reason == "search_timeout"
    assert result.to_dict()["limit_reason"] == "search_timeout"


def test_timeout_helper_strips_only_trailing_marker():
    """The helper removes the marker only at the end of a timed-out command's output."""
    assert _search_stdout_and_limit(ExecuteResult(timeout_output("a.py"), 124)) == ("a.py", "search_timeout")
    assert _search_stdout_and_limit(ExecuteResult("a.py\nnot a marker", 0)) == ("a.py\nnot a marker", None)
    # A marker that is not the last thing in the output is ordinary content
    # and must survive, even though the exit code still reports a timeout.
    leading = f"{TIMEOUT}\na.py"
    assert _search_stdout_and_limit(ExecuteResult(leading, 124)) == (leading, "search_timeout")
    # Without the timeout exit code nothing is stripped, marker or not.
    assert _search_stdout_and_limit(ExecuteResult(timeout_output("a.py"), 0)) == (timeout_output("a.py"), None)


@pytest.mark.parametrize(
    ("target", "output_mode", "raw", "expected"),
    [
        ("files", "content", timeout_output("src/a.py", "src/b.py"), ["src/a.py", "src/b.py"]),
        ("content", "files_only", timeout_output("src/a.py", "src/b.py"), ["src/a.py", "src/b.py"]),
        ("content", "content", timeout_output("src/a.py:10:foo", "src/b.py:20:foo"), ["src/a.py", "src/b.py"]),
    ],
)
def test_rg_timeout_returns_partial_results_without_marker(ops, monkeypatch, target, output_mode, raw, expected):
    """With rg as the only available tool, a timeout yields the partial results minus the marker."""
    ops.env.execute.side_effect = path_exists_or(raw)
    monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg")

    result = ops.search("foo", path="/big", target=target, output_mode=output_mode)

    assert_timed_out(result)
    if target == "content" and output_mode == "content":
        assert [match.path for match in result.matches] == expected
        assert all("timed out" not in match.content for match in result.matches)
    else:
        assert result.files == expected
        assert all("timed out" not in path for path in result.files)


def test_rg_count_timeout_returns_partial_counts(ops, monkeypatch):
    """Count mode keeps the per-file counts collected before the timeout."""
    ops.env.execute.side_effect = path_exists_or(timeout_output("src/a.py:3", "src/b.py:5"))
    monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg")

    result = ops.search("foo", path="/big", target="content", output_mode="count")

    assert_timed_out(result)
    assert result.counts == {"src/a.py": 3, "src/b.py": 5}


def test_rg_file_timeout_does_not_retry_unsorted(ops, monkeypatch):
    """An empty, timed-out rg file listing must not trigger the unsorted retry."""
    calls = 0

    def execute(command, **kwargs):
        nonlocal calls
        if "test -e" in command:
            return {"output": "exists", "returncode": 0}
        # Count only the search commands, not the path probe.
        calls += 1
        return {"output": timeout_output(), "returncode": 124}

    ops.env.execute.side_effect = execute
    monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg")

    result = ops.search("*.py", path="/big", target="files")

    assert calls == 1
    assert_timed_out(result)
    assert result.files == []


def test_grep_timeout_returns_partial_match(ops, monkeypatch):
    """With grep as the only available tool, a timeout still returns the match found so far."""
    ops.env.execute.side_effect = path_exists_or(timeout_output("src/a.py:10:foo"))
    monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "grep")

    result = ops.search("foo", path="/big", target="content")

    assert_timed_out(result)
    assert [match.path for match in result.matches] == ["src/a.py"]


def test_find_timeout_returns_partial_files_and_does_not_retry(ops, monkeypatch):
    """A timed-out find returns the files listed so far and runs exactly one search command."""
    calls = 0

    def execute(command, **kwargs):
        nonlocal calls
        if "test -e" in command:
            return {"output": "exists", "returncode": 0}
        # Count only the search commands, not the path probe.
        calls += 1
        return {"output": timeout_output("1700000000.0 /big/a.py"), "returncode": 124}

    ops.env.execute.side_effect = execute
    monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "find")

    result = ops.search("*.py", path="/big", target="files")

    assert calls == 1
    assert_timed_out(result)
    assert result.files == ["/big/a.py"]


def test_real_rg_error_still_hard_fails(ops, monkeypatch):
    """A genuine rg failure (exit code 2) is still reported as an error, not as a timeout."""
    ops.env.execute.side_effect = path_exists_or("rg: regex parse error:", returncode=2)
    monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg")

    result = ops.search("[", path="/big", target="content")

    assert result.error == "Search failed: rg: regex parse error:"
    assert result.limit_reason is None
return {"output": timeout_output("1700000000.0 /big/a.py"), "returncode": 124} + + ops.env.execute.side_effect = execute + monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "find") + + result = ops.search("*.py", path="/big", target="files") + + assert calls == 1 + assert_timed_out(result) + assert result.files == ["/big/a.py"] + + +def test_real_rg_error_still_hard_fails(ops, monkeypatch): + ops.env.execute.side_effect = path_exists_or("rg: regex parse error:", returncode=2) + monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg") + + result = ops.search("[", path="/big", target="content") + + assert result.error == "Search failed: rg: regex parse error:" + assert result.limit_reason is None diff --git a/tools/file_operations.py b/tools/file_operations.py index 9815673eec1..1d523d70312 100644 --- a/tools/file_operations.py +++ b/tools/file_operations.py @@ -241,10 +241,11 @@ class SearchResult: counts: Dict[str, int] = field(default_factory=dict) total_count: int = 0 truncated: bool = False + limit_reason: Optional[str] = None error: Optional[str] = None def to_dict(self) -> dict: - result = {"total_count": self.total_count} + result: dict[str, object] = {"total_count": self.total_count} if self.matches: result["matches"] = [ {"path": m.path, "line": m.line_number, "content": m.content} @@ -256,6 +257,8 @@ class SearchResult: result["counts"] = self.counts if self.truncated: result["truncated"] = True + if self.limit_reason: + result["limit_reason"] = self.limit_reason if self.error: result["error"] = self.error return result @@ -285,6 +288,16 @@ class ExecuteResult: exit_code: int = 0 +_SEARCH_TIMEOUT_MARKER_RE = re.compile(r"\n?\[Command timed out after \d+s\]\s*$") + + +def _search_stdout_and_limit(result: ExecuteResult) -> tuple[str, Optional[str]]: + """Return stdout cleaned for parsing and a limit reason for search timeouts.""" + if result.exit_code == 124: + return _SEARCH_TIMEOUT_MARKER_RE.sub("", result.stdout), "search_timeout" + 
return result.stdout, None + + def _split_tool_diagnostics(output: str) -> tuple[str, str]: """Separate rg/grep diagnostic lines from real match output. @@ -1967,15 +1980,17 @@ class ShellFileOperations(FileOperations): f"-printf '%T@ %p\\n' 2>/dev/null | sort -rn{pagination_expr}" result = self._exec(cmd, timeout=60) + stdout, limit_reason = _search_stdout_and_limit(result) - if not result.stdout.strip(): + if not stdout.strip() and not limit_reason: # Try without -printf (BSD find compatibility -- macOS) cmd_simple = f"find {self._escape_shell_arg(path)}{hidden_filter_expr} -type f -name {self._escape_shell_arg(search_pattern)} " \ f"2>/dev/null | sort -rn{pagination_expr}" result = self._exec(cmd_simple, timeout=60) + stdout, limit_reason = _search_stdout_and_limit(result) files = [] - for line in result.stdout.strip().split('\n'): + for line in stdout.strip().split('\n'): if not line: continue parts = line.split(' ', 1) @@ -2003,7 +2018,9 @@ class ShellFileOperations(FileOperations): return SearchResult( files=files, - total_count=len(files) + total_count=len(files), + truncated=bool(limit_reason), + limit_reason=limit_reason, ) def _search_files_rg(self, pattern: str, path: str, limit: int, offset: int) -> SearchResult: @@ -2029,9 +2046,10 @@ class ShellFileOperations(FileOperations): f"| head -n {fetch_limit}" ) result = self._exec(cmd_sorted, timeout=60) - all_files = [f for f in result.stdout.strip().split('\n') if f] + stdout, limit_reason = _search_stdout_and_limit(result) + all_files = [f for f in stdout.strip().split('\n') if f] - if not all_files: + if not all_files and not limit_reason: # --sortr may have failed on older rg; retry without it. 
cmd_plain = ( f"rg --files -g {self._escape_shell_arg(glob_pattern)} " @@ -2039,14 +2057,16 @@ class ShellFileOperations(FileOperations): f"| head -n {fetch_limit}" ) result = self._exec(cmd_plain, timeout=60) - all_files = [f for f in result.stdout.strip().split('\n') if f] + stdout, limit_reason = _search_stdout_and_limit(result) + all_files = [f for f in stdout.strip().split('\n') if f] page = all_files[offset:offset + limit] return SearchResult( files=page, total_count=len(all_files), - truncated=len(all_files) >= fetch_limit, + truncated=len(all_files) >= fetch_limit or bool(limit_reason), + limit_reason=limit_reason, ) def _search_content(self, pattern: str, path: str, file_glob: Optional[str], @@ -2102,12 +2122,13 @@ class ShellFileOperations(FileOperations): # introduce false errors on a successful-but-truncated search. cmd = "set -o pipefail; " + " ".join(cmd_parts) result = self._exec(cmd, timeout=60) + stdout, limit_reason = _search_stdout_and_limit(result) # _exec merges stderr into stdout (stderr=subprocess.STDOUT), so rg's # diagnostic lines ("rg: : ", "rg: regex parse error:") # are interleaved with match output. Split them out: diagnostics must # not be parsed as matches, and on a hard error they ARE the message. - diagnostics, payload = _split_tool_diagnostics(result.stdout) + diagnostics, payload = _split_tool_diagnostics(stdout) # rg exit codes: 0=matches found, 1=no matches, 2=error. rg returns 2 # even on partial errors (e.g. 
one unreadable file in a tree that @@ -2124,7 +2145,12 @@ class ShellFileOperations(FileOperations): all_files = [f for f in stdout.strip().split('\n') if f] total = len(all_files) page = all_files[offset:offset + limit] - return SearchResult(files=page, total_count=total) + return SearchResult( + files=page, + total_count=total, + truncated=bool(limit_reason), + limit_reason=limit_reason, + ) elif output_mode == "count": counts = {} @@ -2136,7 +2162,12 @@ class ShellFileOperations(FileOperations): counts[parts[0]] = int(parts[1]) except ValueError: pass - return SearchResult(counts=counts, total_count=sum(counts.values())) + return SearchResult( + counts=counts, + total_count=sum(counts.values()), + truncated=bool(limit_reason), + limit_reason=limit_reason, + ) else: # Parse content matches and context lines. @@ -2177,7 +2208,8 @@ class ShellFileOperations(FileOperations): return SearchResult( matches=page, total_count=total, - truncated=total > offset + limit + truncated=total > offset + limit or bool(limit_reason), + limit_reason=limit_reason, ) def _search_with_grep(self, pattern: str, path: str, file_glob: Optional[str], @@ -2218,12 +2250,13 @@ class ShellFileOperations(FileOperations): # pipefail does not turn truncated results into false errors. cmd = "set -o pipefail; " + " ".join(cmd_parts) result = self._exec(cmd, timeout=60) + stdout, limit_reason = _search_stdout_and_limit(result) # _exec merges stderr into stdout, so grep's diagnostic lines # ("grep: : ") are interleaved with matches. Split them # out so they're never parsed as matches and so a hard error has a # clean message. - diagnostics, payload = _split_tool_diagnostics(result.stdout) + diagnostics, payload = _split_tool_diagnostics(stdout) # grep exit codes: 0=matches found, 1=no matches, 2=error. grep # returns 2 on partial errors (e.g. 
an unreadable file) even when @@ -2238,7 +2271,12 @@ class ShellFileOperations(FileOperations): all_files = [f for f in stdout.strip().split('\n') if f] total = len(all_files) page = all_files[offset:offset + limit] - return SearchResult(files=page, total_count=total) + return SearchResult( + files=page, + total_count=total, + truncated=bool(limit_reason), + limit_reason=limit_reason, + ) elif output_mode == "count": counts = {} @@ -2250,7 +2288,12 @@ class ShellFileOperations(FileOperations): counts[parts[0]] = int(parts[1]) except ValueError: pass - return SearchResult(counts=counts, total_count=sum(counts.values())) + return SearchResult( + counts=counts, + total_count=sum(counts.values()), + truncated=bool(limit_reason), + limit_reason=limit_reason, + ) else: # grep match lines: "file:lineno:content" (colon) @@ -2288,5 +2331,6 @@ class ShellFileOperations(FileOperations): return SearchResult( matches=page, total_count=total, - truncated=total > offset + limit + truncated=total > offset + limit or bool(limit_reason), + limit_reason=limit_reason, )