fix(agent): preserve quoted @file references with spaces

This commit is contained in:
WAXLYY 2026-04-10 02:03:28 +03:00 committed by Teknium
parent 2c99b4e79b
commit c6e1add6f1
2 changed files with 78 additions and 7 deletions

View file

@ -13,8 +13,9 @@ from typing import Awaitable, Callable
from agent.model_metadata import estimate_tokens_rough
_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')'
REFERENCE_PATTERN = re.compile(
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
rf"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))"
)
TRAILING_PUNCTUATION = ",.;!?"
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh")
@ -81,14 +82,10 @@ def parse_context_references(message: str) -> list[ContextReference]:
value = _strip_trailing_punctuation(match.group("value") or "")
line_start = None
line_end = None
target = value
target = _strip_reference_wrappers(value)
if kind == "file":
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
if range_match:
target = range_match.group("path")
line_start = int(range_match.group("start"))
line_end = int(range_match.group("end") or range_match.group("start"))
target, line_start, line_end = _parse_file_reference_value(value)
refs.append(
ContextReference(
@ -375,6 +372,38 @@ def _strip_trailing_punctuation(value: str) -> str:
return stripped
def _strip_reference_wrappers(value: str) -> str:
if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'":
return value[1:-1]
return value
def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]:
quoted_match = re.match(
r'^(?P<quote>`|"|\')(?P<path>.+?)(?P=quote)(?::(?P<start>\d+)(?:-(?P<end>\d+))?)?$',
value,
)
if quoted_match:
line_start = quoted_match.group("start")
line_end = quoted_match.group("end")
return (
quoted_match.group("path"),
int(line_start) if line_start is not None else None,
int(line_end or line_start) if line_start is not None else None,
)
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
if range_match:
line_start = int(range_match.group("start"))
return (
range_match.group("path"),
line_start,
int(range_match.group("end") or range_match.group("start")),
)
return _strip_reference_wrappers(value), None, None
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
pieces: list[str] = []
cursor = 0

View file

@ -83,6 +83,24 @@ def test_parse_references_strips_trailing_punctuation():
assert refs[1].target == "https://example.com/docs"
def test_parse_quoted_references_with_spaces_and_preserve_unquoted_ranges():
from agent.context_references import parse_context_references
refs = parse_context_references(
'review @file:"C:\\Users\\Simba\\My Project\\main.py":7-9 '
'and @folder:"docs and specs" plus @file:src/main.py:1-2'
)
assert [ref.kind for ref in refs] == ["file", "folder", "file"]
assert refs[0].target == r"C:\Users\Simba\My Project\main.py"
assert refs[0].line_start == 7
assert refs[0].line_end == 9
assert refs[1].target == "docs and specs"
assert refs[2].target == "src/main.py"
assert refs[2].line_start == 1
assert refs[2].line_end == 2
def test_expand_file_range_and_folder_listing(sample_repo: Path):
from agent.context_references import preprocess_context_references
@ -106,6 +124,30 @@ def test_expand_file_range_and_folder_listing(sample_repo: Path):
assert not result.warnings
def test_expand_quoted_file_reference_with_spaces(tmp_path: Path):
from agent.context_references import preprocess_context_references
workspace = tmp_path / "repo"
folder = workspace / "docs and specs"
folder.mkdir(parents=True)
file_path = folder / "release notes.txt"
file_path.write_text("line 1\nline 2\nline 3\n", encoding="utf-8")
result = preprocess_context_references(
'Review @file:"docs and specs/release notes.txt":2-3',
cwd=workspace,
context_length=100_000,
)
assert result.expanded
assert result.message.startswith("Review")
assert "line 1" not in result.message
assert "line 2" in result.message
assert "line 3" in result.message
assert "release notes.txt" in result.message
assert not result.warnings
def test_expand_git_diff_staged_and_log(sample_repo: Path):
from agent.context_references import preprocess_context_references