fix(search): keep partial results on search timeout (#36142)

Treat search command budget timeouts as soft truncation so partial results survive, while real search failures still return structured errors.
This commit is contained in:
Teknium 2026-06-13 14:35:21 -07:00 committed by GitHub
parent 069bfd6545
commit 1fa761f8de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 193 additions and 16 deletions

View file

@ -0,0 +1,133 @@
from unittest.mock import MagicMock
import pytest
from tools.file_operations import ExecuteResult, ShellFileOperations, _search_stdout_and_limit
TIMEOUT = "[Command timed out after 60s]"
@pytest.fixture()
def ops():
env = MagicMock(cwd="/tmp/test")
env.execute.return_value = {"output": "", "returncode": 0}
return ShellFileOperations(env)
def timeout_output(*lines: str) -> str:
return "\n".join([*lines, TIMEOUT])
def path_exists_or(output: str, returncode: int = 124):
def execute(command, **kwargs):
if "test -e" in command:
return {"output": "exists", "returncode": 0}
return {"output": output, "returncode": returncode}
return execute
def assert_timed_out(result):
assert result.error is None
assert result.truncated is True
assert result.limit_reason == "search_timeout"
assert result.to_dict()["limit_reason"] == "search_timeout"
def test_timeout_helper_strips_only_trailing_marker():
assert _search_stdout_and_limit(ExecuteResult(timeout_output("a.py"), 124)) == ("a.py", "search_timeout")
assert _search_stdout_and_limit(ExecuteResult("a.py\nnot a marker", 0)) == ("a.py\nnot a marker", None)
@pytest.mark.parametrize(
("target", "output_mode", "raw", "expected"),
[
("files", "content", timeout_output("src/a.py", "src/b.py"), ["src/a.py", "src/b.py"]),
("content", "files_only", timeout_output("src/a.py", "src/b.py"), ["src/a.py", "src/b.py"]),
("content", "content", timeout_output("src/a.py:10:foo", "src/b.py:20:foo"), ["src/a.py", "src/b.py"]),
],
)
def test_rg_timeout_returns_partial_results_without_marker(ops, monkeypatch, target, output_mode, raw, expected):
ops.env.execute.side_effect = path_exists_or(raw)
monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg")
result = ops.search("foo", path="/big", target=target, output_mode=output_mode)
assert_timed_out(result)
if target == "content" and output_mode == "content":
assert [match.path for match in result.matches] == expected
assert all("timed out" not in match.content for match in result.matches)
else:
assert result.files == expected
assert all("timed out" not in path for path in result.files)
def test_rg_count_timeout_returns_partial_counts(ops, monkeypatch):
ops.env.execute.side_effect = path_exists_or(timeout_output("src/a.py:3", "src/b.py:5"))
monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg")
result = ops.search("foo", path="/big", target="content", output_mode="count")
assert_timed_out(result)
assert result.counts == {"src/a.py": 3, "src/b.py": 5}
def test_rg_file_timeout_does_not_retry_unsorted(ops, monkeypatch):
calls = 0
def execute(command, **kwargs):
nonlocal calls
if "test -e" in command:
return {"output": "exists", "returncode": 0}
calls += 1
return {"output": timeout_output(), "returncode": 124}
ops.env.execute.side_effect = execute
monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg")
result = ops.search("*.py", path="/big", target="files")
assert calls == 1
assert_timed_out(result)
assert result.files == []
def test_grep_timeout_returns_partial_match(ops, monkeypatch):
ops.env.execute.side_effect = path_exists_or(timeout_output("src/a.py:10:foo"))
monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "grep")
result = ops.search("foo", path="/big", target="content")
assert_timed_out(result)
assert [match.path for match in result.matches] == ["src/a.py"]
def test_find_timeout_returns_partial_files_and_does_not_retry(ops, monkeypatch):
calls = 0
def execute(command, **kwargs):
nonlocal calls
if "test -e" in command:
return {"output": "exists", "returncode": 0}
calls += 1
return {"output": timeout_output("1700000000.0 /big/a.py"), "returncode": 124}
ops.env.execute.side_effect = execute
monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "find")
result = ops.search("*.py", path="/big", target="files")
assert calls == 1
assert_timed_out(result)
assert result.files == ["/big/a.py"]
def test_real_rg_error_still_hard_fails(ops, monkeypatch):
ops.env.execute.side_effect = path_exists_or("rg: regex parse error:", returncode=2)
monkeypatch.setattr(ops, "_has_command", lambda cmd: cmd == "rg")
result = ops.search("[", path="/big", target="content")
assert result.error == "Search failed: rg: regex parse error:"
assert result.limit_reason is None

View file

@ -241,10 +241,11 @@ class SearchResult:
counts: Dict[str, int] = field(default_factory=dict)
total_count: int = 0
truncated: bool = False
limit_reason: Optional[str] = None
error: Optional[str] = None
def to_dict(self) -> dict:
result = {"total_count": self.total_count}
result: dict[str, object] = {"total_count": self.total_count}
if self.matches:
result["matches"] = [
{"path": m.path, "line": m.line_number, "content": m.content}
@ -256,6 +257,8 @@ class SearchResult:
result["counts"] = self.counts
if self.truncated:
result["truncated"] = True
if self.limit_reason:
result["limit_reason"] = self.limit_reason
if self.error:
result["error"] = self.error
return result
@ -285,6 +288,16 @@ class ExecuteResult:
exit_code: int = 0
_SEARCH_TIMEOUT_MARKER_RE = re.compile(r"\n?\[Command timed out after \d+s\]\s*$")
def _search_stdout_and_limit(result: ExecuteResult) -> tuple[str, Optional[str]]:
"""Return stdout cleaned for parsing and a limit reason for search timeouts."""
if result.exit_code == 124:
return _SEARCH_TIMEOUT_MARKER_RE.sub("", result.stdout), "search_timeout"
return result.stdout, None
def _split_tool_diagnostics(output: str) -> tuple[str, str]:
"""Separate rg/grep diagnostic lines from real match output.
@ -1967,15 +1980,17 @@ class ShellFileOperations(FileOperations):
f"-printf '%T@ %p\\n' 2>/dev/null | sort -rn{pagination_expr}"
result = self._exec(cmd, timeout=60)
stdout, limit_reason = _search_stdout_and_limit(result)
if not result.stdout.strip():
if not stdout.strip() and not limit_reason:
# Try without -printf (BSD find compatibility -- macOS)
cmd_simple = f"find {self._escape_shell_arg(path)}{hidden_filter_expr} -type f -name {self._escape_shell_arg(search_pattern)} " \
f"2>/dev/null | sort -rn{pagination_expr}"
result = self._exec(cmd_simple, timeout=60)
stdout, limit_reason = _search_stdout_and_limit(result)
files = []
for line in result.stdout.strip().split('\n'):
for line in stdout.strip().split('\n'):
if not line:
continue
parts = line.split(' ', 1)
@ -2003,7 +2018,9 @@ class ShellFileOperations(FileOperations):
return SearchResult(
files=files,
total_count=len(files)
total_count=len(files),
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
def _search_files_rg(self, pattern: str, path: str, limit: int, offset: int) -> SearchResult:
@ -2029,9 +2046,10 @@ class ShellFileOperations(FileOperations):
f"| head -n {fetch_limit}"
)
result = self._exec(cmd_sorted, timeout=60)
all_files = [f for f in result.stdout.strip().split('\n') if f]
stdout, limit_reason = _search_stdout_and_limit(result)
all_files = [f for f in stdout.strip().split('\n') if f]
if not all_files:
if not all_files and not limit_reason:
# --sortr may have failed on older rg; retry without it.
cmd_plain = (
f"rg --files -g {self._escape_shell_arg(glob_pattern)} "
@ -2039,14 +2057,16 @@ class ShellFileOperations(FileOperations):
f"| head -n {fetch_limit}"
)
result = self._exec(cmd_plain, timeout=60)
all_files = [f for f in result.stdout.strip().split('\n') if f]
stdout, limit_reason = _search_stdout_and_limit(result)
all_files = [f for f in stdout.strip().split('\n') if f]
page = all_files[offset:offset + limit]
return SearchResult(
files=page,
total_count=len(all_files),
truncated=len(all_files) >= fetch_limit,
truncated=len(all_files) >= fetch_limit or bool(limit_reason),
limit_reason=limit_reason,
)
def _search_content(self, pattern: str, path: str, file_glob: Optional[str],
@ -2102,12 +2122,13 @@ class ShellFileOperations(FileOperations):
# introduce false errors on a successful-but-truncated search.
cmd = "set -o pipefail; " + " ".join(cmd_parts)
result = self._exec(cmd, timeout=60)
stdout, limit_reason = _search_stdout_and_limit(result)
# _exec merges stderr into stdout (stderr=subprocess.STDOUT), so rg's
# diagnostic lines ("rg: <file>: <error>", "rg: regex parse error:")
# are interleaved with match output. Split them out: diagnostics must
# not be parsed as matches, and on a hard error they ARE the message.
diagnostics, payload = _split_tool_diagnostics(result.stdout)
diagnostics, payload = _split_tool_diagnostics(stdout)
# rg exit codes: 0=matches found, 1=no matches, 2=error. rg returns 2
# even on partial errors (e.g. one unreadable file in a tree that
@ -2124,7 +2145,12 @@ class ShellFileOperations(FileOperations):
all_files = [f for f in stdout.strip().split('\n') if f]
total = len(all_files)
page = all_files[offset:offset + limit]
return SearchResult(files=page, total_count=total)
return SearchResult(
files=page,
total_count=total,
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
elif output_mode == "count":
counts = {}
@ -2136,7 +2162,12 @@ class ShellFileOperations(FileOperations):
counts[parts[0]] = int(parts[1])
except ValueError:
pass
return SearchResult(counts=counts, total_count=sum(counts.values()))
return SearchResult(
counts=counts,
total_count=sum(counts.values()),
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
else:
# Parse content matches and context lines.
@ -2177,7 +2208,8 @@ class ShellFileOperations(FileOperations):
return SearchResult(
matches=page,
total_count=total,
truncated=total > offset + limit
truncated=total > offset + limit or bool(limit_reason),
limit_reason=limit_reason,
)
def _search_with_grep(self, pattern: str, path: str, file_glob: Optional[str],
@ -2218,12 +2250,13 @@ class ShellFileOperations(FileOperations):
# pipefail does not turn truncated results into false errors.
cmd = "set -o pipefail; " + " ".join(cmd_parts)
result = self._exec(cmd, timeout=60)
stdout, limit_reason = _search_stdout_and_limit(result)
# _exec merges stderr into stdout, so grep's diagnostic lines
# ("grep: <file>: <error>") are interleaved with matches. Split them
# out so they're never parsed as matches and so a hard error has a
# clean message.
diagnostics, payload = _split_tool_diagnostics(result.stdout)
diagnostics, payload = _split_tool_diagnostics(stdout)
# grep exit codes: 0=matches found, 1=no matches, 2=error. grep
# returns 2 on partial errors (e.g. an unreadable file) even when
@ -2238,7 +2271,12 @@ class ShellFileOperations(FileOperations):
all_files = [f for f in stdout.strip().split('\n') if f]
total = len(all_files)
page = all_files[offset:offset + limit]
return SearchResult(files=page, total_count=total)
return SearchResult(
files=page,
total_count=total,
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
elif output_mode == "count":
counts = {}
@ -2250,7 +2288,12 @@ class ShellFileOperations(FileOperations):
counts[parts[0]] = int(parts[1])
except ValueError:
pass
return SearchResult(counts=counts, total_count=sum(counts.values()))
return SearchResult(
counts=counts,
total_count=sum(counts.values()),
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
else:
# grep match lines: "file:lineno:content" (colon)
@ -2288,5 +2331,6 @@ class ShellFileOperations(FileOperations):
return SearchResult(
matches=page,
total_count=total,
truncated=total > offset + limit
truncated=total > offset + limit or bool(limit_reason),
limit_reason=limit_reason,
)