diff --git a/agent/context_references.py b/agent/context_references.py index 1b8ac9481a..7ecb90c497 100644 --- a/agent/context_references.py +++ b/agent/context_references.py @@ -13,8 +13,9 @@ from typing import Awaitable, Callable from agent.model_metadata import estimate_tokens_rough +_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')' REFERENCE_PATTERN = re.compile( - r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))" + rf"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))" ) TRAILING_PUNCTUATION = ",.;!?" _SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh") @@ -81,14 +82,10 @@ def parse_context_references(message: str) -> list[ContextReference]: value = _strip_trailing_punctuation(match.group("value") or "") line_start = None line_end = None - target = value + target = _strip_reference_wrappers(value) if kind == "file": - range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value) - if range_match: - target = range_match.group("path") - line_start = int(range_match.group("start")) - line_end = int(range_match.group("end") or range_match.group("start")) + target, line_start, line_end = _parse_file_reference_value(value) refs.append( ContextReference( @@ -375,6 +372,38 @@ def _strip_trailing_punctuation(value: str) -> str: return stripped +def _strip_reference_wrappers(value: str) -> str: + if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'": + return value[1:-1] + return value + + +def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]: + quoted_match = re.match( + r'^(?P<quote>`|"|\')(?P<path>.+?)(?P=quote)(?::(?P<start>\d+)(?:-(?P<end>\d+))?)?$', + value, + ) + if quoted_match: + line_start = quoted_match.group("start") + line_end = quoted_match.group("end") + return ( + quoted_match.group("path"), + int(line_start) if line_start is not None else None, + int(line_end or line_start) if line_start is not None else None, + ) + + range_match = 
re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value) + if range_match: + line_start = int(range_match.group("start")) + return ( + range_match.group("path"), + line_start, + int(range_match.group("end") or range_match.group("start")), + ) + + return _strip_reference_wrappers(value), None, None + + def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str: pieces: list[str] = [] cursor = 0 diff --git a/tests/agent/test_context_references.py b/tests/agent/test_context_references.py index 92712c4d20..ea5579c568 100644 --- a/tests/agent/test_context_references.py +++ b/tests/agent/test_context_references.py @@ -83,6 +83,24 @@ def test_parse_references_strips_trailing_punctuation(): assert refs[1].target == "https://example.com/docs" +def test_parse_quoted_references_with_spaces_and_preserve_unquoted_ranges(): + from agent.context_references import parse_context_references + + refs = parse_context_references( + 'review @file:"C:\\Users\\Simba\\My Project\\main.py":7-9 ' + 'and @folder:"docs and specs" plus @file:src/main.py:1-2' + ) + + assert [ref.kind for ref in refs] == ["file", "folder", "file"] + assert refs[0].target == r"C:\Users\Simba\My Project\main.py" + assert refs[0].line_start == 7 + assert refs[0].line_end == 9 + assert refs[1].target == "docs and specs" + assert refs[2].target == "src/main.py" + assert refs[2].line_start == 1 + assert refs[2].line_end == 2 + + def test_expand_file_range_and_folder_listing(sample_repo: Path): from agent.context_references import preprocess_context_references @@ -106,6 +124,30 @@ def test_expand_file_range_and_folder_listing(sample_repo: Path): assert not result.warnings +def test_expand_quoted_file_reference_with_spaces(tmp_path: Path): + from agent.context_references import preprocess_context_references + + workspace = tmp_path / "repo" + folder = workspace / "docs and specs" + folder.mkdir(parents=True) + file_path = folder / "release notes.txt" + file_path.write_text("line 1\nline 2\nline 3\n", 
encoding="utf-8") + + result = preprocess_context_references( + 'Review @file:"docs and specs/release notes.txt":2-3', + cwd=workspace, + context_length=100_000, + ) + + assert result.expanded + assert result.message.startswith("Review") + assert "line 1" not in result.message + assert "line 2" in result.message + assert "line 3" in result.message + assert "release notes.txt" in result.message + assert not result.warnings + + def test_expand_git_diff_staged_and_log(sample_repo: Path): from agent.context_references import preprocess_context_references