fix: robust context engine interface — config selection, plugin discovery, ABC completeness

Follow-up fixes for the context engine plugin slot (PR #5700):

- Enhance ContextEngine ABC: add threshold_percent, protect_first_n,
  protect_last_n as class attributes; complete update_model() default
  with threshold recalculation; clarify on_session_end() lifecycle docs
- Add ContextCompressor.update_model() override for model/provider/
  base_url/api_key updates
- Replace all direct compressor internal access in run_agent.py with
  ABC interface: switch_model(), fallback restore, context probing
  all use update_model() now; _context_probed guarded with getattr/
  hasattr for plugin engine compatibility
- Create plugins/context_engine/ directory with discovery module
  (mirrors plugins/memory/ pattern) — discover_context_engines(),
  load_context_engine()
- Add context.engine config key to DEFAULT_CONFIG (default: compressor)
- Config-driven engine selection in run_agent.__init__: checks config,
  then plugins/context_engine/<name>/, then general plugin system,
  falls back to built-in ContextCompressor
- Wire on_session_end() in shutdown_memory_provider() at real session
  boundaries (CLI exit, /reset, gateway expiry)
This commit is contained in:
Teknium 2026-04-08 04:16:58 -07:00 committed by Teknium
parent 5d8dd622bc
commit 3fe6938176
5 changed files with 388 additions and 64 deletions

View file

@ -73,6 +73,22 @@ class ContextCompressor(ContextEngine):
self._context_probe_persistable = False self._context_probe_persistable = False
self._previous_summary = None self._previous_summary = None
def update_model(
    self,
    model: str,
    context_length: int,
    base_url: str = "",
    api_key: str = "",
    provider: str = "",
) -> None:
    """Re-point the compressor at a different model/endpoint.

    Stores the new model identity and connection settings, records the
    new context window size, and recomputes ``threshold_tokens`` as
    ``threshold_percent`` of that window (truncated to an int).
    """
    # Identity / connection settings for the new model.
    self.model, self.base_url = model, base_url
    self.api_key, self.provider = api_key, provider

    # Window size, then the compaction trigger derived from it.
    window = context_length
    self.context_length = window
    self.threshold_tokens = int(window * self.threshold_percent)
def __init__( def __init__(
self, self,
model: str, model: str,

View file

@ -3,7 +3,11 @@
A context engine controls how conversation context is managed when A context engine controls how conversation context is managed when
approaching the model's token limit. The built-in ContextCompressor approaching the model's token limit. The built-in ContextCompressor
is the default implementation. Third-party engines (e.g. LCM) can is the default implementation. Third-party engines (e.g. LCM) can
replace it by registering via the plugin system. replace it via the plugin system or by being placed in the
``plugins/context_engine/<name>/`` directory.
Selection is config-driven: ``context.engine`` in config.yaml.
Default is ``"compressor"`` (the built-in). Only one engine is active.
The engine is responsible for: The engine is responsible for:
- Deciding when compaction should fire - Deciding when compaction should fire
@ -17,7 +21,8 @@ Lifecycle:
3. update_from_response() called after each API response with usage data 3. update_from_response() called after each API response with usage data
4. should_compress() checked after each turn 4. should_compress() checked after each turn
5. compress() called when should_compress() returns True 5. compress() called when should_compress() returns True
6. on_session_end() called when the conversation ends 6. on_session_end() called at real session boundaries (CLI exit, /reset,
gateway session expiry) — NOT per-turn
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -45,6 +50,16 @@ class ContextEngine(ABC):
context_length: int = 0 context_length: int = 0
compression_count: int = 0 compression_count: int = 0
# -- Compaction parameters (read by run_agent.py for preflight) --------
#
# These control the preflight compression check. Subclasses may
# override via __init__ or property; defaults are sensible for most
# engines.
threshold_percent: float = 0.75
protect_first_n: int = 3
protect_last_n: int = 6
# -- Core interface ---------------------------------------------------- # -- Core interface ----------------------------------------------------
@abstractmethod @abstractmethod
@ -93,9 +108,10 @@ class ContextEngine(ABC):
""" """
def on_session_end(self, session_id: str, messages: List[Dict[str, Any]]) -> None: def on_session_end(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
"""Called when the conversation ends. """Called at real session boundaries (CLI exit, /reset, gateway expiry).
Use this to flush state, close DB connections, etc. Use this to flush state, close DB connections, etc.
NOT called per-turn — only when the session truly ends.
""" """
def on_session_reset(self) -> None: def on_session_reset(self) -> None:
@ -158,9 +174,11 @@ class ContextEngine(ABC):
api_key: str = "", api_key: str = "",
provider: str = "", provider: str = "",
) -> None: ) -> None:
"""Called when the user switches models mid-session. """Called when the user switches models or on fallback activation.
Default updates context_length and threshold_tokens. Override if Default updates context_length and recalculates threshold_tokens
your engine needs to do more (e.g. recalculate DAG budgets). from threshold_percent. Override if your engine needs more
(e.g. recalculate DAG budgets, switch summary models).
""" """
self.context_length = context_length self.context_length = context_length
self.threshold_tokens = int(context_length * self.threshold_percent)

View file

@ -504,6 +504,16 @@ DEFAULT_CONFIG = {
"max_ms": 2500, "max_ms": 2500,
}, },
# Context engine -- controls how the context window is managed when
# approaching the model's token limit.
# "compressor" = built-in lossy summarization (default).
# Set to a plugin name to activate an alternative engine (e.g. "lcm"
# for Lossless Context Management). The engine must be installed as
# a plugin in plugins/context_engine/<name>/ or ~/.hermes/plugins/.
"context": {
"engine": "compressor",
},
# Persistent memory -- bounded curated memory injected into system prompt # Persistent memory -- bounded curated memory injected into system prompt
"memory": { "memory": {
"memory_enabled": True, "memory_enabled": True,
@ -1450,7 +1460,7 @@ _KNOWN_ROOT_KEYS = {
"_config_version", "model", "providers", "fallback_model", "_config_version", "model", "providers", "fallback_model",
"fallback_providers", "credential_pool_strategies", "toolsets", "fallback_providers", "credential_pool_strategies", "toolsets",
"agent", "terminal", "display", "compression", "delegation", "agent", "terminal", "display", "compression", "delegation",
"auxiliary", "custom_providers", "memory", "gateway", "auxiliary", "custom_providers", "context", "memory", "gateway",
} }
# Valid fields inside a custom_providers list entry # Valid fields inside a custom_providers list entry

View file

@ -0,0 +1,219 @@
"""Context engine plugin discovery.
Scans ``plugins/context_engine/<name>/`` directories for context engine
plugins. Each subdirectory must contain ``__init__.py`` with a class
implementing the ContextEngine ABC.
Context engines are separate from the general plugin system — they live
in the repo and are always available without user installation. Only ONE
can be active at a time, selected via ``context.engine`` in config.yaml.
The default engine is ``"compressor"`` (the built-in ContextCompressor).
Usage:
from plugins.context_engine import discover_context_engines, load_context_engine
available = discover_context_engines() # [(name, desc, available), ...]
engine = load_context_engine("lcm") # ContextEngine instance
"""
from __future__ import annotations
import importlib
import importlib.util
import logging
import sys
from pathlib import Path
from typing import List, Optional, Tuple
logger = logging.getLogger(__name__)
_CONTEXT_ENGINE_PLUGINS_DIR = Path(__file__).parent
def discover_context_engines() -> List[Tuple[str, str, bool]]:
    """Scan plugins/context_engine/ for available engines.

    Returns a list of ``(name, description, is_available)`` tuples, ordered
    by directory name. Sub-directories whose name starts with ``_`` or
    ``.`` and sub-directories without an ``__init__.py`` are skipped.

    ``description`` is read from an optional ``plugin.yaml``; it is ``""``
    when the file is missing, unreadable, or has no usable description.

    NOTE: the availability check is not metadata-only. Each candidate is
    loaded through ``_load_engine_from_dir`` (which imports the package and
    obtains an engine instance) so its optional ``is_available()`` hook can
    be consulted. Any failure during that check marks the engine as
    unavailable instead of raising.
    """
    results: List[Tuple[str, str, bool]] = []
    if not _CONTEXT_ENGINE_PLUGINS_DIR.is_dir():
        return results

    for child in sorted(_CONTEXT_ENGINE_PLUGINS_DIR.iterdir()):
        if not child.is_dir() or child.name.startswith(("_", ".")):
            continue
        if not (child / "__init__.py").exists():
            continue

        # Read description from plugin.yaml if available. Broken metadata
        # must never hide an engine, so every error degrades to "".
        desc = ""
        yaml_file = child / "plugin.yaml"
        if yaml_file.exists():
            try:
                import yaml
                # Explicit encoding: the platform default is not UTF-8 everywhere.
                with open(yaml_file, encoding="utf-8") as f:
                    meta = yaml.safe_load(f) or {}
                if isinstance(meta, dict):
                    # Coerce so the tuple always carries a str, even when the
                    # YAML value is null or a non-string scalar.
                    desc = str(meta.get("description") or "")
            except Exception as e:
                logger.debug("Could not read %s: %s", yaml_file, e)

        # Availability check — load the engine and call is_available() if
        # the engine defines it; engines without the hook count as available.
        available = True
        try:
            engine = _load_engine_from_dir(child)
            if engine is None:
                available = False
            elif hasattr(engine, "is_available"):
                available = bool(engine.is_available())
        except Exception as e:
            logger.debug("Availability check failed for context engine '%s': %s", child.name, e)
            available = False

        results.append((child.name, desc, available))

    return results
def load_context_engine(name: str) -> Optional["ContextEngine"]:
    """Load and return a ContextEngine instance by name.

    ``name`` must be a single directory name inside the context-engine
    plugin directory. Empty names, names containing path separators, and
    names starting with ``_`` or ``.`` (the same ones discovery skips) are
    rejected, so a config value can never resolve to the plugin directory
    itself or to a location outside it.

    Returns None if the name is invalid, the engine is not found, or it
    fails to load.
    """
    if (
        not isinstance(name, str)
        or not name
        or name.startswith(("_", "."))
        or Path(name).name != name  # more than one path component
    ):
        logger.warning("Invalid context engine name %r", name)
        return None

    engine_dir = _CONTEXT_ENGINE_PLUGINS_DIR / name
    if not engine_dir.is_dir():
        logger.debug("Context engine '%s' not found in %s", name, _CONTEXT_ENGINE_PLUGINS_DIR)
        return None

    try:
        engine = _load_engine_from_dir(engine_dir)
    except Exception as e:
        logger.warning("Failed to load context engine '%s': %s", name, e)
        return None

    # Compare against None explicitly: an engine object that happens to be
    # falsy (defines __bool__/__len__) is still a valid engine.
    if engine is not None:
        return engine
    logger.warning("Context engine '%s' loaded but no engine instance found", name)
    return None
def _load_engine_from_dir(engine_dir: Path) -> Optional["ContextEngine"]:
    """Import an engine package and extract a ContextEngine instance from it.

    The package's ``__init__.py`` must provide either:
    - a ``register(ctx)`` function (plugin-style) — it is called with a
      stand-in ``ctx`` that records ``register_context_engine()`` calls, or
    - a top-level class that extends ContextEngine — it is instantiated
      with no arguments.

    Args:
        engine_dir: Directory of the engine package. Its name becomes the
            last component of the ``plugins.context_engine.<name>`` module.

    Returns:
        The engine instance, or None when there is no ``__init__.py``, the
        package fails to import, or no engine can be obtained from it.

    Errors raised while importing ``agent.context_engine`` in the fallback
    path are not caught here and propagate to the caller.
    """
    name = engine_dir.name
    module_name = f"plugins.context_engine.{name}"
    init_file = engine_dir / "__init__.py"

    if not init_file.exists():
        return None

    # Check if already loaded
    if module_name in sys.modules:
        mod = sys.modules[module_name]
    else:
        # Handle relative imports within the plugin.
        # First ensure the parent packages are registered.
        for parent in ("plugins", "plugins.context_engine"):
            if parent in sys.modules:
                continue
            parent_path = Path(__file__).parent
            if parent == "plugins":
                parent_path = parent_path.parent
            parent_init = parent_path / "__init__.py"
            if not parent_init.exists():
                continue
            parent_spec = importlib.util.spec_from_file_location(
                parent, str(parent_init),
                submodule_search_locations=[str(parent_path)],
            )
            if parent_spec:
                parent_mod = importlib.util.module_from_spec(parent_spec)
                sys.modules[parent] = parent_mod
                try:
                    parent_spec.loader.exec_module(parent_mod)
                except Exception as e:
                    # Best-effort: the entry stays registered so relative
                    # imports in the engine can still resolve the package.
                    logger.debug("Failed to exec parent package %s: %s", parent, e)

        # Now load the engine module
        spec = importlib.util.spec_from_file_location(
            module_name, str(init_file),
            submodule_search_locations=[str(engine_dir)],
        )
        if not spec:
            return None

        mod = importlib.util.module_from_spec(spec)
        # Registered before execution so submodules can import the package.
        sys.modules[module_name] = mod

        # Register submodules so relative imports work
        for sub_file in engine_dir.glob("*.py"):
            if sub_file.name == "__init__.py":
                continue
            full_sub_name = f"{module_name}.{sub_file.stem}"
            if full_sub_name in sys.modules:
                continue
            sub_spec = importlib.util.spec_from_file_location(
                full_sub_name, str(sub_file)
            )
            if not sub_spec:
                continue
            sub_mod = importlib.util.module_from_spec(sub_spec)
            sys.modules[full_sub_name] = sub_mod
            try:
                sub_spec.loader.exec_module(sub_mod)
            except Exception as e:
                logger.debug("Failed to load submodule %s: %s", full_sub_name, e)
                # Do not leave a half-initialised module behind: a later
                # import of it would get the broken object (and a confusing
                # "cannot import name" error) instead of a fresh attempt.
                sys.modules.pop(full_sub_name, None)

        try:
            spec.loader.exec_module(mod)
        except Exception as e:
            logger.debug("Failed to exec_module %s: %s", module_name, e)
            sys.modules.pop(module_name, None)
            return None

    # Try register(ctx) pattern first (how plugins are written)
    if hasattr(mod, "register"):
        collector = _EngineCollector()
        try:
            mod.register(collector)
            # Explicit None check: a registered engine may be falsy.
            if collector.engine is not None:
                return collector.engine
        except Exception as e:
            logger.debug("register() failed for %s: %s", name, e)

    # Fallback: find a ContextEngine subclass and instantiate it
    from agent.context_engine import ContextEngine
    for attr_name in dir(mod):
        attr = getattr(mod, attr_name, None)
        if (isinstance(attr, type) and issubclass(attr, ContextEngine)
                and attr is not ContextEngine):
            try:
                return attr()
            except Exception as e:
                # Keep looking; another class in the module may be usable.
                logger.debug(
                    "Could not instantiate %s from %s: %s", attr_name, module_name, e
                )

    return None
class _EngineCollector:
"""Fake plugin context that captures register_context_engine calls."""
def __init__(self):
self.engine = None
def register_context_engine(self, engine):
self.engine = engine
# No-op for other registration methods
def register_tool(self, *args, **kwargs):
pass
def register_hook(self, *args, **kwargs):
pass
def register_cli_command(self, *args, **kwargs):
pass
def register_memory_provider(self, *args, **kwargs):
pass

View file

@ -1268,18 +1268,54 @@ class AIAgent:
pass pass
break break
# Check if a plugin registered a custom context engine (e.g. LCM) # Select context engine: config-driven (like memory providers).
_plugin_engine = None # 1. Check config.yaml context.engine setting
# 2. Check plugins/context_engine/<name>/ directory (repo-shipped)
# 3. Check general plugin system (user-installed plugins)
# 4. Fall back to built-in ContextCompressor
_selected_engine = None
_engine_name = "compressor" # default
try: try:
from hermes_cli.plugins import get_plugin_context_engine _ctx_cfg = _agent_cfg.get("context", {}) if isinstance(_agent_cfg, dict) else {}
_plugin_engine = get_plugin_context_engine() _engine_name = _ctx_cfg.get("engine", "compressor") or "compressor"
except Exception: except Exception:
pass pass
if _plugin_engine is not None: if _engine_name != "compressor":
self.context_compressor = _plugin_engine # Try loading from plugins/context_engine/<name>/
try:
from plugins.context_engine import load_context_engine
_selected_engine = load_context_engine(_engine_name)
except Exception as _ce_load_err:
logger.debug("Context engine load from plugins/context_engine/: %s", _ce_load_err)
# Try general plugin system as fallback
if _selected_engine is None:
try:
from hermes_cli.plugins import get_plugin_context_engine
_candidate = get_plugin_context_engine()
if _candidate and _candidate.name == _engine_name:
_selected_engine = _candidate
except Exception:
pass
if _selected_engine is None:
logger.warning(
"Context engine '%s' not found — falling back to built-in compressor",
_engine_name,
)
else:
# Even with default config, check if a plugin registered one
try:
from hermes_cli.plugins import get_plugin_context_engine
_selected_engine = get_plugin_context_engine()
except Exception:
pass
if _selected_engine is not None:
self.context_compressor = _selected_engine
if not self.quiet_mode: if not self.quiet_mode:
logger.info("Using plugin context engine: %s", _plugin_engine.name) logger.info("Using context engine: %s", _selected_engine.name)
else: else:
self.context_compressor = ContextCompressor( self.context_compressor = ContextCompressor(
model=self.model, model=self.model,
@ -1385,11 +1421,13 @@ class AIAgent:
"api_key": getattr(self, "api_key", ""), "api_key": getattr(self, "api_key", ""),
"client_kwargs": dict(self._client_kwargs), "client_kwargs": dict(self._client_kwargs),
"use_prompt_caching": self._use_prompt_caching, "use_prompt_caching": self._use_prompt_caching,
# Compressor state that _try_activate_fallback() overwrites # Context engine state that _try_activate_fallback() overwrites.
"compressor_model": _cc.model, # Use getattr for model/base_url/api_key/provider since plugin
"compressor_base_url": _cc.base_url, # engines may not have these (they're ContextCompressor-specific).
"compressor_model": getattr(_cc, "model", self.model),
"compressor_base_url": getattr(_cc, "base_url", self.base_url),
"compressor_api_key": getattr(_cc, "api_key", ""), "compressor_api_key": getattr(_cc, "api_key", ""),
"compressor_provider": _cc.provider, "compressor_provider": getattr(_cc, "provider", self.provider),
"compressor_context_length": _cc.context_length, "compressor_context_length": _cc.context_length,
"compressor_threshold_tokens": _cc.threshold_tokens, "compressor_threshold_tokens": _cc.threshold_tokens,
} }
@ -1518,13 +1556,12 @@ class AIAgent:
provider=self.provider, provider=self.provider,
config_context_length=getattr(self, "_config_context_length", None), config_context_length=getattr(self, "_config_context_length", None),
) )
self.context_compressor.model = self.model self.context_compressor.update_model(
self.context_compressor.base_url = self.base_url model=self.model,
self.context_compressor.api_key = self.api_key context_length=new_context_length,
self.context_compressor.provider = self.provider base_url=self.base_url,
self.context_compressor.context_length = new_context_length api_key=getattr(self, "api_key", ""),
self.context_compressor.threshold_tokens = int( provider=self.provider,
new_context_length * self.context_compressor.threshold_percent
) )
# ── Invalidate cached system prompt so it rebuilds next turn ── # ── Invalidate cached system prompt so it rebuilds next turn ──
@ -1540,10 +1577,10 @@ class AIAgent:
"api_key": getattr(self, "api_key", ""), "api_key": getattr(self, "api_key", ""),
"client_kwargs": dict(self._client_kwargs), "client_kwargs": dict(self._client_kwargs),
"use_prompt_caching": self._use_prompt_caching, "use_prompt_caching": self._use_prompt_caching,
"compressor_model": _cc.model if _cc else self.model, "compressor_model": getattr(_cc, "model", self.model) if _cc else self.model,
"compressor_base_url": _cc.base_url if _cc else self.base_url, "compressor_base_url": getattr(_cc, "base_url", self.base_url) if _cc else self.base_url,
"compressor_api_key": getattr(_cc, "api_key", "") if _cc else "", "compressor_api_key": getattr(_cc, "api_key", "") if _cc else "",
"compressor_provider": _cc.provider if _cc else self.provider, "compressor_provider": getattr(_cc, "provider", self.provider) if _cc else self.provider,
"compressor_context_length": _cc.context_length if _cc else 0, "compressor_context_length": _cc.context_length if _cc else 0,
"compressor_threshold_tokens": _cc.threshold_tokens if _cc else 0, "compressor_threshold_tokens": _cc.threshold_tokens if _cc else 0,
} }
@ -2740,10 +2777,11 @@ class AIAgent:
} }
def shutdown_memory_provider(self, messages: list = None) -> None: def shutdown_memory_provider(self, messages: list = None) -> None:
"""Shut down the memory provider — call at actual session boundaries. """Shut down the memory provider and context engine — call at actual session boundaries.
This calls on_session_end() then shutdown_all() on the memory This calls on_session_end() then shutdown_all() on the memory
manager. NOT called per-turn only at CLI exit, /reset, gateway manager, and on_session_end() on the context engine.
NOT called per-turn — only at CLI exit, /reset, gateway
session expiry, etc. session expiry, etc.
""" """
if self._memory_manager: if self._memory_manager:
@ -2755,6 +2793,15 @@ class AIAgent:
self._memory_manager.shutdown_all() self._memory_manager.shutdown_all()
except Exception: except Exception:
pass pass
# Notify context engine of session end (flush DAG, close DBs, etc.)
if hasattr(self, "context_compressor") and self.context_compressor:
try:
self.context_compressor.on_session_end(
self.session_id or "",
messages or [],
)
except Exception:
pass
def close(self) -> None: def close(self) -> None:
"""Release all resources held by this agent instance. """Release all resources held by this agent instance.
@ -5272,13 +5319,12 @@ class AIAgent:
self.model, base_url=self.base_url, self.model, base_url=self.base_url,
api_key=self.api_key, provider=self.provider, api_key=self.api_key, provider=self.provider,
) )
self.context_compressor.model = self.model self.context_compressor.update_model(
self.context_compressor.base_url = self.base_url model=self.model,
self.context_compressor.api_key = self.api_key context_length=fb_context_length,
self.context_compressor.provider = self.provider base_url=self.base_url,
self.context_compressor.context_length = fb_context_length api_key=getattr(self, "api_key", ""),
self.context_compressor.threshold_tokens = int( provider=self.provider,
fb_context_length * self.context_compressor.threshold_percent
) )
self._emit_status( self._emit_status(
@ -5338,14 +5384,15 @@ class AIAgent:
shared=True, shared=True,
) )
# ── Restore context compressor state ── # ── Restore context engine state ──
cc = self.context_compressor cc = self.context_compressor
cc.model = rt["compressor_model"] cc.update_model(
cc.base_url = rt["compressor_base_url"] model=rt["compressor_model"],
cc.api_key = rt["compressor_api_key"] context_length=rt["compressor_context_length"],
cc.provider = rt["compressor_provider"] base_url=rt["compressor_base_url"],
cc.context_length = rt["compressor_context_length"] api_key=rt["compressor_api_key"],
cc.threshold_tokens = rt["compressor_threshold_tokens"] provider=rt["compressor_provider"],
)
# ── Reset fallback chain for the new turn ── # ── Reset fallback chain for the new turn ──
self._fallback_activated = False self._fallback_activated = False
@ -8247,7 +8294,7 @@ class AIAgent:
# Cache discovered context length after successful call. # Cache discovered context length after successful call.
# Only persist limits confirmed by the provider (parsed # Only persist limits confirmed by the provider (parsed
# from the error message), not guessed probe tiers. # from the error message), not guessed probe tiers.
if self.context_compressor._context_probed: if getattr(self.context_compressor, "_context_probed", False):
ctx = self.context_compressor.context_length ctx = self.context_compressor.context_length
if getattr(self.context_compressor, "_context_probe_persistable", False): if getattr(self.context_compressor, "_context_probe_persistable", False):
save_context_length(self.model, self.base_url, ctx) save_context_length(self.model, self.base_url, ctx)
@ -8586,16 +8633,22 @@ class AIAgent:
compressor = self.context_compressor compressor = self.context_compressor
old_ctx = compressor.context_length old_ctx = compressor.context_length
if old_ctx > _reduced_ctx: if old_ctx > _reduced_ctx:
compressor.context_length = _reduced_ctx compressor.update_model(
compressor.threshold_tokens = int( model=self.model,
_reduced_ctx * compressor.threshold_percent context_length=_reduced_ctx,
base_url=self.base_url,
api_key=getattr(self, "api_key", ""),
provider=self.provider,
) )
compressor._context_probed = True # Context probing flags — only set on built-in
# Don't persist — this is a subscription-tier # compressor (plugin engines manage their own).
# limitation, not a model capability. If the user if hasattr(compressor, "_context_probed"):
# later enables extra usage the 1M limit should compressor._context_probed = True
# come back automatically. # Don't persist — this is a subscription-tier
compressor._context_probe_persistable = False # limitation, not a model capability. If the
# user later enables extra usage the 1M limit
# should come back automatically.
compressor._context_probe_persistable = False
self._vprint( self._vprint(
f"{self.log_prefix}⚠️ Anthropic long-context tier " f"{self.log_prefix}⚠️ Anthropic long-context tier "
f"requires extra usage — reducing context: " f"requires extra usage — reducing context: "
@ -8759,17 +8812,25 @@ class AIAgent:
new_ctx = get_next_probe_tier(old_ctx) new_ctx = get_next_probe_tier(old_ctx)
if new_ctx and new_ctx < old_ctx: if new_ctx and new_ctx < old_ctx:
compressor.context_length = new_ctx compressor.update_model(
compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent) model=self.model,
compressor._context_probed = True context_length=new_ctx,
# Only persist limits parsed from the provider's base_url=self.base_url,
# error message (a real number). Guessed fallback api_key=getattr(self, "api_key", ""),
# tiers from get_next_probe_tier() should stay provider=self.provider,
# in-memory only — persisting them pollutes the
# cache with wrong values.
compressor._context_probe_persistable = bool(
parsed_limit and parsed_limit == new_ctx
) )
# Context probing flags — only set on built-in
# compressor (plugin engines manage their own).
if hasattr(compressor, "_context_probed"):
compressor._context_probed = True
# Only persist limits parsed from the provider's
# error message (a real number). Guessed fallback
# tiers from get_next_probe_tier() should stay
# in-memory only — persisting them pollutes the
# cache with wrong values.
compressor._context_probe_persistable = bool(
parsed_limit and parsed_limit == new_ctx
)
self._vprint(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,}{new_ctx:,} tokens", force=True) self._vprint(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,}{new_ctx:,} tokens", force=True)
else: else:
self._vprint(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...", force=True) self._vprint(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...", force=True)