diff --git a/tests/test_trajectory_compressor.py b/tests/test_trajectory_compressor.py index 74d63002923..8fcbfc38cfe 100644 --- a/tests/test_trajectory_compressor.py +++ b/tests/test_trajectory_compressor.py @@ -507,3 +507,124 @@ class TestGenerateSummary: summary = await tc._generate_summary_async("Turn content", metrics) assert summary == "[CONTEXT SUMMARY]:" + + +# --------------------------------------------------------------------------- +# TrajectoryCompressor — compression boundary must not split tool pairs +# --------------------------------------------------------------------------- + + +def _gpt_with_tool_call(label, tokens): + """A 'gpt' turn carrying a marker, padded to ~`tokens` tokens.""" + body = f"\n{{\"name\": \"{label}\"}}\n" + pad = max(0, tokens * 4 - len(body)) + return {"from": "gpt", "value": body + "x" * pad} + + +def _tool_response(label, tokens): + """A 'tool' turn carrying a marker, padded to ~`tokens` tokens.""" + body = f"\n{{\"name\": \"{label}\"}}\n" + pad = max(0, tokens * 4 - len(body)) + return {"from": "tool", "value": body + "x" * pad} + + +def _count_marker(trajectory, marker): + return sum(turn["value"].count(marker) for turn in trajectory) + + +def _paired_trajectory(): + """A 10-turn trajectory of gpt/tool pairs with one oversized middle gpt turn. + + Layout (index): system, human, gpt#0, tool#0, gpt#1(big), tool#1, gpt#2, + tool#2, gpt(final), human. With ``protect_last_n_turns=2`` the compressible + region is [4, 8) and the oversized gpt#1 at index 4 is large enough that the + token-accumulation boundary stops at index 5 — i.e. between gpt#1's + and tool#1's . + """ + return [ + {"from": "system", "value": "You are an agent. " * 4}, + {"from": "human", "value": "Please do the task. " * 4}, + _gpt_with_tool_call("a", 12), + _tool_response("a", 12), + _gpt_with_tool_call("b", 400), # oversized — forces a mid-pair boundary + _tool_response("b", 12), + _gpt_with_tool_call("c", 12), + _tool_response("c", 12), + {"from": "gpt", "value": "\n\nAll done."}, + {"from": "human", "value": "Thanks!"}, + ] + + +def _target_that_splits_after_index_4(tc, trajectory): + """Pick a target so token accumulation breaks right after index 4 (a gpt).""" + turn_tokens = tc.count_turn_tokens(trajectory) + total = sum(turn_tokens) + # threshold == turn_tokens[4] makes the loop break at compress_until = 5, + # which lands on the tool turn paired with gpt#1. + return total - turn_tokens[4] + tc.config.summary_target_tokens + + +class TestCompressionToolPairIntegrity: + def _config(self): + config = CompressionConfig() + config.protect_last_n_turns = 2 + config.summary_target_tokens = 4 + return config + + def test_sync_compression_does_not_orphan_tool_markers(self): + tc = _make_compressor(self._config()) + tc._generate_summary = MagicMock( + return_value="[CONTEXT SUMMARY]: middle turns summarized." + ) + trajectory = _paired_trajectory() + tc.config.target_max_tokens = _target_that_splits_after_index_4(tc, trajectory) + + compressed, metrics = tc.compress_trajectory(trajectory) + + assert metrics.was_compressed + # Every must keep its matching . + assert _count_marker(compressed, "") == _count_marker( + compressed, "" + ) + # A kept 'tool' turn must always immediately follow its 'gpt' turn — + # never the inserted summary (a 'human' turn) or another 'tool' turn. + for i, turn in enumerate(compressed): + if turn.get("from") == "tool": + assert i > 0 and compressed[i - 1].get("from") == "gpt" + + @pytest.mark.asyncio + async def test_async_compression_does_not_orphan_tool_markers(self): + tc = _make_compressor(self._config()) + tc._generate_summary_async = AsyncMock( + return_value="[CONTEXT SUMMARY]: middle turns summarized." + ) + trajectory = _paired_trajectory() + tc.config.target_max_tokens = _target_that_splits_after_index_4(tc, trajectory) + + compressed, metrics = await tc.compress_trajectory_async(trajectory) + + assert metrics.was_compressed + assert _count_marker(compressed, "") == _count_marker( + compressed, "" + ) + for i, turn in enumerate(compressed): + if turn.get("from") == "tool": + assert i > 0 and compressed[i - 1].get("from") == "gpt" + + def test_snap_boundary_skips_tool_turn_forward(self): + tc = _make_compressor() + trajectory = _paired_trajectory() + # Index 5 is a 'tool' turn; the boundary should move forward to 6. + assert tc._snap_boundary(trajectory, 5, 4, 8) == 6 + # Index 4 is a 'gpt' turn and already clean. + assert tc._snap_boundary(trajectory, 4, 4, 8) == 4 + + def test_snap_boundary_falls_back_to_backward(self): + tc = _make_compressor() + # Protected tail begins on a 'tool' turn at max_idx: no clean boundary + # ahead, so the boundary must retreat onto the preceding 'gpt' turn. + trajectory = [ + {"from": "gpt", "value": "a"}, + {"from": "tool", "value": "a"}, + ] + assert tc._snap_boundary(trajectory, 1, 0, 1) == 0 diff --git a/trajectory_compressor.py b/trajectory_compressor.py index 7ef396daa8b..9dc3826a854 100644 --- a/trajectory_compressor.py +++ b/trajectory_compressor.py @@ -524,9 +524,48 @@ class TrajectoryCompressor: compressible_start = max(head_protected) + 1 if head_protected else 0 compressible_end = min(tail_protected) if tail_protected else n - + return protected, compressible_start, compressible_end - + + @staticmethod + def _is_boundary_clean(trajectory: List[Dict[str, str]], idx: int) -> bool: + """Return True if a region boundary at ``idx`` does not split a turn pair. + + In the from/value trajectory format a ``tool`` turn (carrying + ```` markers) is always emitted immediately after the + ``gpt`` turn whose ```` it answers. A compression boundary + that lands *on* a ``tool`` turn therefore cuts between a tool call and + its response. A boundary is only clean when it sits at the very end of + the trajectory or on a non-``tool`` turn. + """ + return idx >= len(trajectory) or trajectory[idx].get("from") != "tool" + + @classmethod + def _snap_boundary( + cls, + trajectory: List[Dict[str, str]], + idx: int, + min_idx: int, + max_idx: int, + ) -> int: + """Move a compression boundary onto the nearest clean turn boundary. + + Moving forward is preferred so that an orphaned ``tool`` turn is folded + into the region that already holds its ``gpt`` turn; if no clean + boundary exists ahead (for example the protected tail itself begins on a + ``tool`` turn) the boundary is moved backward instead. The result is + clamped to ``[min_idx, max_idx]``. + """ + forward = idx + while forward < max_idx and not cls._is_boundary_clean(trajectory, forward): + forward += 1 + if cls._is_boundary_clean(trajectory, forward): + return forward + backward = idx + while backward > min_idx and not cls._is_boundary_clean(trajectory, backward): + backward -= 1 + return backward + def _extract_turn_content_for_summary(self, trajectory: List[Dict[str, str]], start: int, end: int) -> str: """ Extract content from turns to be summarized. @@ -746,7 +785,11 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" # Find protected regions protected, compress_start, compress_end = self._find_protected_indices(trajectory) - + + # Snap the head boundary so the compressible region never *starts* on an + # orphaned whose lives in the protected head. + compress_start = self._snap_boundary(trajectory, compress_start, compress_start, compress_end) + # Check if there's anything to compress if compress_start >= compress_end: # Nothing to compress, return as-is @@ -780,17 +823,29 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end: compress_until = compress_end accumulated_tokens = sum(turn_tokens[compress_start:compress_end]) - + + # Snap the tail boundary so we never cut between a and its + # : the summary replaces [compress_start, compress_until) + # and the remainder is kept verbatim, so a boundary on a tool turn would + # leave an orphaned marker and corrupt the training trajectory. + compress_until = self._snap_boundary(trajectory, compress_until, compress_start, compress_end) + if compress_until <= compress_start: + # Snapping collapsed the region; nothing can be safely compressed. + metrics.compressed_tokens = total_tokens + metrics.compressed_turns = len(trajectory) + metrics.still_over_limit = total_tokens > self.config.target_max_tokens + return trajectory, metrics + # Record compression region metrics.turns_compressed_start_idx = compress_start metrics.turns_compressed_end_idx = compress_until metrics.turns_in_compressed_region = compress_until - compress_start - + # Extract content for summary content_to_summarize = self._extract_turn_content_for_summary( trajectory, compress_start, compress_until ) - + # Generate summary summary = self._generate_summary(content_to_summarize, metrics) @@ -853,7 +908,11 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" # Find protected regions protected, compress_start, compress_end = self._find_protected_indices(trajectory) - + + # Snap the head boundary so the compressible region never *starts* on an + # orphaned whose lives in the protected head. + compress_start = self._snap_boundary(trajectory, compress_start, compress_start, compress_end) + # Check if there's anything to compress if compress_start >= compress_end: metrics.compressed_tokens = total_tokens @@ -879,17 +938,29 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end: compress_until = compress_end accumulated_tokens = sum(turn_tokens[compress_start:compress_end]) - + + # Snap the tail boundary so we never cut between a and its + # : the summary replaces [compress_start, compress_until) + # and the remainder is kept verbatim, so a boundary on a tool turn would + # leave an orphaned marker and corrupt the training trajectory. + compress_until = self._snap_boundary(trajectory, compress_until, compress_start, compress_end) + if compress_until <= compress_start: + # Snapping collapsed the region; nothing can be safely compressed. + metrics.compressed_tokens = total_tokens + metrics.compressed_turns = len(trajectory) + metrics.still_over_limit = total_tokens > self.config.target_max_tokens + return trajectory, metrics + # Record compression region metrics.turns_compressed_start_idx = compress_start metrics.turns_compressed_end_idx = compress_until metrics.turns_in_compressed_region = compress_until - compress_start - + # Extract content for summary content_to_summarize = self._extract_turn_content_for_summary( trajectory, compress_start, compress_until ) - + # Generate summary (ASYNC) summary = await self._generate_summary_async(content_to_summarize, metrics)