diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 24d7120a9b..8f5325092b 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -66,6 +66,13 @@ class ContextCompressor(ContextEngine): def name(self) -> str: return "compressor" + def on_session_reset(self) -> None: + """Reset all per-session state for /new or /reset.""" + super().on_session_reset() + self._context_probed = False + self._context_probe_persistable = False + self._previous_summary = None + def __init__( self, model: str, diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index 7323bbd011..94ec20836d 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -201,8 +201,7 @@ class PluginContext: The *setup_fn* receives an argparse subparser and should add any arguments/sub-subparsers. If *handler_fn* is provided it is set - as the default dispatch function via ``set_defaults(func=...)``. - """ + as the default dispatch function via ``set_defaults(func=...)``.""" self._manager._cli_commands[name] = { "name": name, "help": help, @@ -213,6 +212,38 @@ class PluginContext: } logger.debug("Plugin %s registered CLI command: %s", self.manifest.name, name) + # -- context engine registration ----------------------------------------- + + def register_context_engine(self, engine) -> None: + """Register a context engine to replace the built-in ContextCompressor. + + Only one context engine plugin is allowed. If a second plugin tries + to register one, it is rejected with a warning. + + The engine must be an instance of ``agent.context_engine.ContextEngine``. + """ + if self._manager._context_engine is not None: + logger.warning( + "Plugin '%s' tried to register a context engine, but one is " + "already registered. Only one context engine plugin is allowed.", + self.manifest.name, + ) + return + # Defer the import to avoid circular deps at module level + from agent.context_engine import ContextEngine + if not isinstance(engine, ContextEngine): + logger.warning( + "Plugin '%s' tried to register a context engine that does not " + "inherit from ContextEngine. Ignoring.", + self.manifest.name, + ) + return + self._manager._context_engine = engine + logger.info( + "Plugin '%s' registered context engine: %s", + self.manifest.name, engine.name, + ) + # -- hook registration -------------------------------------------------- def register_hook(self, hook_name: str, callback: Callable) -> None: @@ -245,6 +276,7 @@ class PluginManager: self._hooks: Dict[str, List[Callable]] = {} self._plugin_tool_names: Set[str] = set() self._cli_commands: Dict[str, dict] = {} + self._context_engine = None # Set by a plugin via register_context_engine() self._discovered: bool = False self._cli_ref = None # Set by CLI after plugin discovery @@ -566,6 +598,11 @@ def get_plugin_cli_commands() -> Dict[str, dict]: return dict(get_plugin_manager()._cli_commands) +def get_plugin_context_engine(): + """Return the plugin-registered context engine, or None.""" + return get_plugin_manager()._context_engine + + def get_plugin_toolsets() -> List[tuple]: """Return plugin toolsets as ``(key, label, description)`` tuples. diff --git a/run_agent.py b/run_agent.py index 7ac077d784..2af911af0d 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1268,19 +1268,32 @@ class AIAgent: pass break - self.context_compressor = ContextCompressor( - model=self.model, - threshold_percent=compression_threshold, - protect_first_n=3, - protect_last_n=compression_protect_last, - summary_target_ratio=compression_target_ratio, - summary_model_override=compression_summary_model, - quiet_mode=self.quiet_mode, - base_url=self.base_url, - api_key=getattr(self, "api_key", ""), - config_context_length=_config_context_length, - provider=self.provider, - ) + # Check if a plugin registered a custom context engine (e.g. LCM) + _plugin_engine = None + try: + from hermes_cli.plugins import get_plugin_context_engine + _plugin_engine = get_plugin_context_engine() + except Exception: + pass + + if _plugin_engine is not None: + self.context_compressor = _plugin_engine + if not self.quiet_mode: + logger.info("Using plugin context engine: %s", _plugin_engine.name) + else: + self.context_compressor = ContextCompressor( + model=self.model, + threshold_percent=compression_threshold, + protect_first_n=3, + protect_last_n=compression_protect_last, + summary_target_ratio=compression_target_ratio, + summary_model_override=compression_summary_model, + quiet_mode=self.quiet_mode, + base_url=self.base_url, + api_key=getattr(self, "api_key", ""), + config_context_length=_config_context_length, + provider=self.provider, + ) self.compression_enabled = compression_enabled self._subdirectory_hints = SubdirectoryHintTracker( working_dir=os.getenv("TERMINAL_CWD") or None, @@ -1397,15 +1410,9 @@ class AIAgent: # Turn counter (added after reset_session_state was first written — #2635) self._user_turn_count = 0 - # Context compressor internal counters (if present) + # Context engine reset (works for both built-in compressor and plugins) if hasattr(self, "context_compressor") and self.context_compressor: - self.context_compressor.last_prompt_tokens = 0 - self.context_compressor.last_completion_tokens = 0 - self.context_compressor.compression_count = 0 - self.context_compressor._context_probed = False - self.context_compressor._context_probe_persistable = False - # Iterative summary from previous session must not bleed into new one (#2635) - self.context_compressor._previous_summary = None + self.context_compressor.on_session_reset() def switch_model(self, new_model, new_provider, api_key='', base_url='', api_mode=''): """Switch the model/provider in-place for a live agent. diff --git a/tests/agent/test_context_engine.py b/tests/agent/test_context_engine.py new file mode 100644 index 0000000000..a06285dc2a --- /dev/null +++ b/tests/agent/test_context_engine.py @@ -0,0 +1,250 @@ +"""Tests for the ContextEngine ABC and plugin slot.""" + +import json +import pytest +from typing import Any, Dict, List + +from agent.context_engine import ContextEngine +from agent.context_compressor import ContextCompressor + + +# --------------------------------------------------------------------------- +# A minimal concrete engine for testing the ABC +# --------------------------------------------------------------------------- + +class StubEngine(ContextEngine): + """Minimal engine that satisfies the ABC without doing real work.""" + + def __init__(self, context_length=200000, threshold_pct=0.50): + self.context_length = context_length + self.threshold_tokens = int(context_length * threshold_pct) + self._compress_called = False + self._tools_called = [] + + @property + def name(self) -> str: + return "stub" + + def update_from_response(self, usage: Dict[str, Any]) -> None: + self.last_prompt_tokens = usage.get("prompt_tokens", 0) + self.last_completion_tokens = usage.get("completion_tokens", 0) + self.last_total_tokens = usage.get("total_tokens", 0) + + def should_compress(self, prompt_tokens: int = None) -> bool: + tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens + return tokens >= self.threshold_tokens + + def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]: + self._compress_called = True + self.compression_count += 1 + # Trivial: just return as-is + return messages + + def get_tool_schemas(self) -> List[Dict[str, Any]]: + return [ + { + "name": "stub_search", + "description": "Search the stub engine", + "parameters": {"type": "object", "properties": {}}, + } + ] + + def handle_tool_call(self, name: str, args: Dict[str, Any]) -> str: + self._tools_called.append(name) + return json.dumps({"ok": True, "tool": name}) + + +# --------------------------------------------------------------------------- +# ABC contract tests +# --------------------------------------------------------------------------- + +class TestContextEngineABC: + """Verify the ABC enforces the required interface.""" + + def test_cannot_instantiate_abc_directly(self): + with pytest.raises(TypeError): + ContextEngine() + + def test_missing_methods_raises(self): + """A subclass missing required methods cannot be instantiated.""" + class Incomplete(ContextEngine): + @property + def name(self): + return "incomplete" + with pytest.raises(TypeError): + Incomplete() + + def test_stub_engine_satisfies_abc(self): + engine = StubEngine() + assert isinstance(engine, ContextEngine) + assert engine.name == "stub" + + def test_compressor_is_context_engine(self): + c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000) + assert isinstance(c, ContextEngine) + assert c.name == "compressor" + + +# --------------------------------------------------------------------------- +# Default method behavior +# --------------------------------------------------------------------------- + +class TestDefaults: + """Verify ABC default implementations work correctly.""" + + def test_default_tool_schemas_empty(self): + engine = StubEngine() + # StubEngine overrides this, so test the base via super + assert ContextEngine.get_tool_schemas(engine) == [] + + def test_default_handle_tool_call_returns_error(self): + engine = StubEngine() + result = ContextEngine.handle_tool_call(engine, "unknown", {}) + data = json.loads(result) + assert "error" in data + + def test_default_get_status(self): + engine = StubEngine() + engine.last_prompt_tokens = 50000 + status = engine.get_status() + assert status["last_prompt_tokens"] == 50000 + assert status["context_length"] == 200000 + assert status["threshold_tokens"] == 100000 + assert 0 < status["usage_percent"] <= 100 + + def test_on_session_reset(self): + engine = StubEngine() + engine.last_prompt_tokens = 999 + engine.compression_count = 3 + engine.on_session_reset() + assert engine.last_prompt_tokens == 0 + assert engine.compression_count == 0 + + def test_should_compress_preflight_default_false(self): + engine = StubEngine() + assert engine.should_compress_preflight([]) is False + + +# --------------------------------------------------------------------------- +# StubEngine behavior +# --------------------------------------------------------------------------- + +class TestStubEngine: + + def test_should_compress(self): + engine = StubEngine(context_length=100000, threshold_pct=0.50) + assert not engine.should_compress(40000) + assert engine.should_compress(50000) + assert engine.should_compress(60000) + + def test_compress_tracks_count(self): + engine = StubEngine() + msgs = [{"role": "user", "content": "hello"}] + result = engine.compress(msgs) + assert result == msgs + assert engine._compress_called + assert engine.compression_count == 1 + + def test_tool_schemas(self): + engine = StubEngine() + schemas = engine.get_tool_schemas() + assert len(schemas) == 1 + assert schemas[0]["name"] == "stub_search" + + def test_handle_tool_call(self): + engine = StubEngine() + result = engine.handle_tool_call("stub_search", {}) + assert json.loads(result)["ok"] is True + assert "stub_search" in engine._tools_called + + def test_update_from_response(self): + engine = StubEngine() + engine.update_from_response({"prompt_tokens": 1000, "completion_tokens": 200, "total_tokens": 1200}) + assert engine.last_prompt_tokens == 1000 + assert engine.last_completion_tokens == 200 + + +# --------------------------------------------------------------------------- +# ContextCompressor session reset via ABC +# --------------------------------------------------------------------------- + +class TestCompressorSessionReset: + """Verify ContextCompressor.on_session_reset() clears all state.""" + + def test_reset_clears_state(self): + c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000) + c.last_prompt_tokens = 50000 + c.compression_count = 3 + c._previous_summary = "some old summary" + c._context_probed = True + c._context_probe_persistable = True + + c.on_session_reset() + + assert c.last_prompt_tokens == 0 + assert c.last_completion_tokens == 0 + assert c.last_total_tokens == 0 + assert c.compression_count == 0 + assert c._context_probed is False + assert c._context_probe_persistable is False + assert c._previous_summary is None + + +# --------------------------------------------------------------------------- +# Plugin slot (PluginManager integration) +# --------------------------------------------------------------------------- + +class TestPluginContextEngineSlot: + """Test register_context_engine on PluginContext.""" + + def test_register_engine(self): + from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest + mgr = PluginManager() + manifest = PluginManifest(name="test-lcm") + ctx = PluginContext(manifest, mgr) + + engine = StubEngine() + ctx.register_context_engine(engine) + + assert mgr._context_engine is engine + assert mgr._context_engine.name == "stub" + + def test_reject_second_engine(self): + from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest + mgr = PluginManager() + manifest = PluginManifest(name="test-lcm") + ctx = PluginContext(manifest, mgr) + + engine1 = StubEngine() + engine2 = StubEngine() + ctx.register_context_engine(engine1) + ctx.register_context_engine(engine2) # should be rejected + + assert mgr._context_engine is engine1 + + def test_reject_non_engine(self): + from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest + mgr = PluginManager() + manifest = PluginManifest(name="test-bad") + ctx = PluginContext(manifest, mgr) + + ctx.register_context_engine("not an engine") + assert mgr._context_engine is None + + def test_get_plugin_context_engine(self): + from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest, get_plugin_context_engine, _plugin_manager + import hermes_cli.plugins as plugins_mod + + # Inject a test manager + old_mgr = plugins_mod._plugin_manager + try: + mgr = PluginManager() + plugins_mod._plugin_manager = mgr + + assert get_plugin_context_engine() is None + + engine = StubEngine() + mgr._context_engine = engine + assert get_plugin_context_engine() is engine + finally: + plugins_mod._plugin_manager = old_mgr