feat: ensure feature parity once again

This commit is contained in:
Brooklyn Nicholson 2026-04-11 14:02:36 -05:00
parent bf6af95ff5
commit e2ea8934d4
6 changed files with 922 additions and 112 deletions

View file

@ -229,6 +229,122 @@ def _resolve_model() -> str:
return "anthropic/claude-sonnet-4"
def _write_config_key(key_path: str, value):
cfg = _load_cfg()
current = cfg
keys = key_path.split(".")
for key in keys[:-1]:
if key not in current or not isinstance(current.get(key), dict):
current[key] = {}
current = current[key]
current[keys[-1]] = value
_save_cfg(cfg)
def _load_reasoning_config() -> dict | None:
from hermes_constants import parse_reasoning_effort
effort = str(_load_cfg().get("agent", {}).get("reasoning_effort", "") or "").strip()
return parse_reasoning_effort(effort)
def _load_service_tier() -> str | None:
raw = str(_load_cfg().get("agent", {}).get("service_tier", "") or "").strip().lower()
if not raw or raw in {"normal", "default", "standard", "off", "none"}:
return None
if raw in {"fast", "priority", "on"}:
return "priority"
return None
def _load_show_reasoning() -> bool:
return bool(_load_cfg().get("display", {}).get("show_reasoning", False))
def _load_tool_progress_mode() -> str:
raw = _load_cfg().get("display", {}).get("tool_progress", "all")
if raw is False:
return "off"
if raw is True:
return "all"
mode = str(raw or "all").strip().lower()
return mode if mode in {"off", "new", "all", "verbose"} else "all"
def _session_show_reasoning(sid: str) -> bool:
return bool(_sessions.get(sid, {}).get("show_reasoning", False))
def _session_tool_progress_mode(sid: str) -> str:
return str(_sessions.get(sid, {}).get("tool_progress_mode", "all") or "all")
def _tool_progress_enabled(sid: str) -> bool:
return _session_tool_progress_mode(sid) != "off"
def _restart_slash_worker(session: dict):
worker = session.get("slash_worker")
if worker:
try:
worker.close()
except Exception:
pass
try:
session["slash_worker"] = _SlashWorker(session["session_key"], getattr(session.get("agent"), "model", _resolve_model()))
except Exception:
session["slash_worker"] = None
def _apply_model_switch(sid: str, session: dict, raw_input: str) -> str:
agent = session.get("agent")
if not agent:
os.environ["HERMES_MODEL"] = raw_input
return raw_input
from hermes_cli.model_switch import switch_model
result = switch_model(
raw_input=raw_input,
current_provider=getattr(agent, "provider", "") or "",
current_model=getattr(agent, "model", "") or "",
current_base_url=getattr(agent, "base_url", "") or "",
current_api_key=getattr(agent, "api_key", "") or "",
)
if not result.success:
raise ValueError(result.error_message or "model switch failed")
agent.switch_model(
new_model=result.new_model,
new_provider=result.target_provider,
api_key=result.api_key,
base_url=result.base_url,
api_mode=result.api_mode,
)
os.environ["HERMES_MODEL"] = result.new_model
_restart_slash_worker(session)
_emit("session.info", sid, _session_info(agent))
return result.new_model
def _compress_session_history(session: dict) -> tuple[int, dict]:
from agent.model_metadata import estimate_messages_tokens_rough
agent = session["agent"]
history = list(session.get("history", []))
if len(history) < 4:
return 0, _get_usage(agent)
approx_tokens = estimate_messages_tokens_rough(history)
compressed, _ = agent._compress_context(
history,
getattr(agent, "_cached_system_prompt", "") or "",
approx_tokens=approx_tokens,
)
session["history"] = compressed
session["history_version"] = int(session.get("history_version", 0)) + 1
return len(history) - len(compressed), _get_usage(agent)
def _get_usage(agent) -> dict:
g = lambda k, fb=None: getattr(agent, k, 0) or (getattr(agent, fb, 0) if fb else 0)
usage = {
@ -320,14 +436,48 @@ def _tool_ctx(name: str, args: dict) -> str:
return ""
def _on_tool_start(sid: str, tool_call_id: str, name: str, args: dict):
session = _sessions.get(sid)
if session is not None:
try:
from agent.display import capture_local_edit_snapshot
snapshot = capture_local_edit_snapshot(name, args)
if snapshot is not None:
session.setdefault("edit_snapshots", {})[tool_call_id] = snapshot
except Exception:
pass
if _tool_progress_enabled(sid):
_emit("tool.start", sid, {"tool_id": tool_call_id, "name": name, "context": _tool_ctx(name, args)})
def _on_tool_complete(sid: str, tool_call_id: str, name: str, args: dict, result: str):
payload = {"tool_id": tool_call_id, "name": name}
session = _sessions.get(sid)
snapshot = None
if session is not None:
snapshot = session.setdefault("edit_snapshots", {}).pop(tool_call_id, None)
try:
from agent.display import render_edit_diff_with_delta
rendered: list[str] = []
if render_edit_diff_with_delta(name, result, function_args=args, snapshot=snapshot, print_fn=rendered.append):
payload["inline_diff"] = "\n".join(rendered)
except Exception:
pass
if _tool_progress_enabled(sid) or payload.get("inline_diff"):
_emit("tool.complete", sid, payload)
def _agent_cbs(sid: str) -> dict:
return dict(
tool_start_callback=lambda tc_id, name, args: _emit("tool.start", sid, {"tool_id": tc_id, "name": name, "context": _tool_ctx(name, args)}),
tool_complete_callback=lambda tc_id, name, args, result: _emit("tool.complete", sid, {"tool_id": tc_id, "name": name}),
tool_progress_callback=lambda name, preview, args: _emit("tool.progress", sid, {"name": name, "preview": preview}),
tool_gen_callback=lambda name: _emit("tool.generating", sid, {"name": name}),
tool_start_callback=lambda tc_id, name, args: _on_tool_start(sid, tc_id, name, args),
tool_complete_callback=lambda tc_id, name, args, result: _on_tool_complete(sid, tc_id, name, args, result),
tool_progress_callback=lambda name, preview, args: _tool_progress_enabled(sid)
and _emit("tool.progress", sid, {"name": name, "preview": preview}),
tool_gen_callback=lambda name: _tool_progress_enabled(sid) and _emit("tool.generating", sid, {"name": name}),
thinking_callback=lambda text: _emit("thinking.delta", sid, {"text": text}),
reasoning_callback=lambda text: _emit("reasoning.delta", sid, {"text": text}),
reasoning_callback=lambda text: _session_show_reasoning(sid) and _emit("reasoning.delta", sid, {"text": text}),
status_callback=lambda kind, text=None: _status_update(sid, str(kind), None if text is None else str(text)),
clarify_callback=lambda q, c: _block("clarify.request", sid, {"question": q, "choices": c}),
)
@ -357,7 +507,12 @@ def _make_agent(sid: str, key: str, session_id: str | None = None):
cfg = _load_cfg()
system_prompt = cfg.get("agent", {}).get("system_prompt", "") or ""
return AIAgent(
model=_resolve_model(), quiet_mode=True, platform="tui",
model=_resolve_model(),
quiet_mode=True,
verbose_logging=_load_tool_progress_mode() == "verbose",
reasoning_config=_load_reasoning_config(),
service_tier=_load_service_tier(),
platform="tui",
session_id=session_id or key, session_db=_get_db(),
ephemeral_system_prompt=system_prompt or None,
**_agent_cbs(sid),
@ -369,10 +524,16 @@ def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
"agent": agent,
"session_key": key,
"history": history,
"history_lock": threading.Lock(),
"history_version": 0,
"running": False,
"attached_images": [],
"image_counter": 0,
"cols": cols,
"slash_worker": None,
"show_reasoning": _load_show_reasoning(),
"tool_progress_mode": _load_tool_progress_mode(),
"edit_snapshots": {},
}
try:
_sessions[sid]["slash_worker"] = _SlashWorker(key, getattr(agent, "model", _resolve_model()))
@ -397,6 +558,17 @@ def _with_checkpoints(session, fn):
return fn(session["agent"]._checkpoint_mgr, os.getenv("TERMINAL_CWD", os.getcwd()))
def _resolve_checkpoint_hash(mgr, cwd: str, ref: str) -> str:
try:
checkpoints = mgr.list_checkpoints(cwd)
idx = int(ref) - 1
except ValueError:
return ref
if 0 <= idx < len(checkpoints):
return checkpoints[idx].get("hash", ref)
raise ValueError(f"Invalid checkpoint number. Use 1-{len(checkpoints)}.")
def _enrich_with_attached_images(user_text: str, image_paths: list[str]) -> str:
"""Pre-analyze attached images via vision and prepend descriptions to user text."""
import asyncio, json as _json
@ -561,11 +733,17 @@ def _(rid, params: dict) -> dict:
session, err = _sess(params, rid)
if err:
return err
history, removed = session.get("history", []), 0
while history and history[-1].get("role") in ("assistant", "tool"):
history.pop(); removed += 1
if history and history[-1].get("role") == "user":
history.pop(); removed += 1
removed = 0
with session["history_lock"]:
history = session.get("history", [])
while history and history[-1].get("role") in ("assistant", "tool"):
history.pop()
removed += 1
if history and history[-1].get("role") == "user":
history.pop()
removed += 1
if removed:
session["history_version"] = int(session.get("history_version", 0)) + 1
return _ok(rid, {"removed": removed})
@ -574,11 +752,11 @@ def _(rid, params: dict) -> dict:
session, err = _sess(params, rid)
if err:
return err
agent = session["agent"]
try:
if hasattr(agent, "compress_context"):
agent.compress_context()
return _ok(rid, {"status": "compressed", "usage": _get_usage(agent)})
with session["history_lock"]:
removed, usage = _compress_session_history(session)
_emit("session.info", params.get("session_id", ""), _session_info(session["agent"]))
return _ok(rid, {"status": "compressed", "removed": removed, "usage": usage})
except Exception as e:
return _err(rid, 5005, str(e))
@ -606,7 +784,8 @@ def _(rid, params: dict) -> dict:
return err
db = _get_db()
old_key = session["session_key"]
history = session.get("history", [])
with session["history_lock"]:
history = [dict(msg) for msg in session.get("history", [])]
if not history:
return _err(rid, 4008, "nothing to branch — send a message first")
new_key = _new_session_key()
@ -666,15 +845,47 @@ def _(rid, params: dict) -> dict:
session = _sessions.get(sid)
if not session:
return _err(rid, 4001, "session not found")
agent, history = session["agent"], session["history"]
with session["history_lock"]:
if session.get("running"):
return _err(rid, 4009, "session busy")
session["running"] = True
history = list(session["history"])
history_version = int(session.get("history_version", 0))
images = list(session.get("attached_images", []))
session["attached_images"] = []
agent = session["agent"]
_emit("message.start", sid)
def run():
approval_token = None
try:
from tools.approval import reset_current_session_key, set_current_session_key
approval_token = set_current_session_key(session["session_key"])
cols = session.get("cols", 80)
streamer = make_stream_renderer(cols)
images = session.pop("attached_images", [])
prompt = _enrich_with_attached_images(text, images) if images else text
prompt = text
if isinstance(prompt, str) and "@" in prompt:
from agent.context_references import preprocess_context_references
from agent.model_metadata import get_model_context_length
ctx_len = get_model_context_length(
getattr(agent, "model", "") or _resolve_model(),
base_url=getattr(agent, "base_url", "") or "",
api_key=getattr(agent, "api_key", "") or "",
)
ctx = preprocess_context_references(
prompt,
cwd=os.environ.get("TERMINAL_CWD", os.getcwd()),
allowed_root=os.environ.get("TERMINAL_CWD", os.getcwd()),
context_length=ctx_len,
)
if ctx.blocked:
_emit("error", sid, {"message": "\n".join(ctx.warnings) or "Context injection refused."})
return
prompt = ctx.message
prompt = _enrich_with_attached_images(prompt, images) if images else prompt
def _stream(delta):
payload = {"text": delta}
@ -689,7 +900,10 @@ def _(rid, params: dict) -> dict:
if isinstance(result, dict):
if isinstance(result.get("messages"), list):
session["history"] = result["messages"]
with session["history_lock"]:
if int(session.get("history_version", 0)) == history_version:
session["history"] = result["messages"]
session["history_version"] = history_version + 1
raw = result.get("final_response", "")
status = "interrupted" if result.get("interrupted") else "error" if result.get("error") else "complete"
else:
@ -703,6 +917,14 @@ def _(rid, params: dict) -> dict:
_emit("message.complete", sid, payload)
except Exception as e:
_emit("error", sid, {"message": str(e)})
finally:
try:
if approval_token is not None:
reset_current_session_key(approval_token)
except Exception:
pass
with session["history_lock"]:
session["running"] = False
threading.Thread(target=run, daemon=True).start()
return _ok(rid, {"status": "streaming"})
@ -733,6 +955,84 @@ def _(rid, params: dict) -> dict:
return _ok(rid, {"attached": True, "path": str(img_path), "count": len(session["attached_images"])})
@method("image.attach")
def _(rid, params: dict) -> dict:
session, err = _sess(params, rid)
if err:
return err
raw = str(params.get("path", "") or "").strip()
if not raw:
return _err(rid, 4015, "path required")
try:
from cli import _IMAGE_EXTENSIONS, _resolve_attachment_path, _split_path_input
path_token, remainder = _split_path_input(raw)
image_path = _resolve_attachment_path(path_token)
if image_path is None:
return _err(rid, 4016, f"image not found: {path_token}")
if image_path.suffix.lower() not in _IMAGE_EXTENSIONS:
return _err(rid, 4016, f"unsupported image: {image_path.name}")
session.setdefault("attached_images", []).append(str(image_path))
return _ok(
rid,
{
"attached": True,
"path": str(image_path),
"name": image_path.name,
"count": len(session["attached_images"]),
"remainder": remainder,
"text": remainder or f"[User attached image: {image_path.name}]",
},
)
except Exception as e:
return _err(rid, 5027, str(e))
@method("input.detect_drop")
def _(rid, params: dict) -> dict:
session, err = _sess(params, rid)
if err:
return err
try:
from cli import _detect_file_drop
raw = str(params.get("text", "") or "")
dropped = _detect_file_drop(raw)
if not dropped:
return _ok(rid, {"matched": False})
drop_path = dropped["path"]
remainder = dropped["remainder"]
if dropped["is_image"]:
session.setdefault("attached_images", []).append(str(drop_path))
text = remainder or f"[User attached image: {drop_path.name}]"
return _ok(
rid,
{
"matched": True,
"is_image": True,
"path": str(drop_path),
"name": drop_path.name,
"count": len(session["attached_images"]),
"text": text,
},
)
text = f"[User attached file: {drop_path}]" + (f"\n{remainder}" if remainder else "")
return _ok(
rid,
{
"matched": True,
"is_image": False,
"path": str(drop_path),
"name": drop_path.name,
"text": text,
},
)
except Exception as e:
return _err(rid, 5027, str(e))
@method("prompt.background")
def _(rid, params: dict) -> dict:
text, parent = params.get("text", ""), params.get("session_id", "")
@ -819,39 +1119,94 @@ def _(rid, params: dict) -> dict:
@method("config.set")
def _(rid, params: dict) -> dict:
key, value = params.get("key", ""), params.get("value", "")
session = _sessions.get(params.get("session_id", ""))
if key == "model":
os.environ["HERMES_MODEL"] = value
return _ok(rid, {"key": key, "value": value})
try:
if not value:
return _err(rid, 4002, "model value required")
if session:
value = _apply_model_switch(params.get("session_id", ""), session, value)
else:
os.environ["HERMES_MODEL"] = value
return _ok(rid, {"key": key, "value": value})
except Exception as e:
return _err(rid, 5001, str(e))
if key == "verbose":
cycle = ["off", "new", "all", "verbose"]
cur = session.get("tool_progress_mode", _load_tool_progress_mode()) if session else _load_tool_progress_mode()
if value and value != "cycle":
os.environ["HERMES_VERBOSE"] = value
return _ok(rid, {"key": key, "value": value})
cur = os.environ.get("HERMES_VERBOSE", "all")
try:
idx = cycle.index(cur)
except ValueError:
idx = 2
nv = cycle[(idx + 1) % len(cycle)]
os.environ["HERMES_VERBOSE"] = nv
nv = str(value).strip().lower()
if nv not in cycle:
return _err(rid, 4002, f"unknown verbose mode: {value}")
else:
try:
idx = cycle.index(cur)
except ValueError:
idx = 2
nv = cycle[(idx + 1) % len(cycle)]
_write_config_key("display.tool_progress", nv)
if session:
session["tool_progress_mode"] = nv
agent = session.get("agent")
if agent is not None:
agent.verbose_logging = nv == "verbose"
return _ok(rid, {"key": key, "value": nv})
if key == "yolo":
nv = "0" if os.environ.get("HERMES_YOLO", "0") == "1" else "1"
os.environ["HERMES_YOLO"] = nv
return _ok(rid, {"key": key, "value": nv})
try:
if session:
from tools.approval import (
disable_session_yolo,
enable_session_yolo,
is_session_yolo_enabled,
)
current = is_session_yolo_enabled(session["session_key"])
if current:
disable_session_yolo(session["session_key"])
nv = "0"
else:
enable_session_yolo(session["session_key"])
nv = "1"
else:
current = bool(os.environ.get("HERMES_YOLO_MODE"))
if current:
os.environ.pop("HERMES_YOLO_MODE", None)
nv = "0"
else:
os.environ["HERMES_YOLO_MODE"] = "1"
nv = "1"
return _ok(rid, {"key": key, "value": nv})
except Exception as e:
return _err(rid, 5001, str(e))
if key == "reasoning":
if value in ("show", "on"):
os.environ["HERMES_SHOW_REASONING"] = "1"
return _ok(rid, {"key": key, "value": "show"})
if value in ("hide", "off"):
os.environ.pop("HERMES_SHOW_REASONING", None)
return _ok(rid, {"key": key, "value": "hide"})
os.environ["HERMES_REASONING"] = value
return _ok(rid, {"key": key, "value": value})
try:
from hermes_constants import parse_reasoning_effort
arg = str(value or "").strip().lower()
if arg in ("show", "on"):
_write_config_key("display.show_reasoning", True)
if session:
session["show_reasoning"] = True
return _ok(rid, {"key": key, "value": "show"})
if arg in ("hide", "off"):
_write_config_key("display.show_reasoning", False)
if session:
session["show_reasoning"] = False
return _ok(rid, {"key": key, "value": "hide"})
parsed = parse_reasoning_effort(arg)
if parsed is None:
return _err(rid, 4002, f"unknown reasoning value: {value}")
_write_config_key("agent.reasoning_effort", arg)
if session and session.get("agent") is not None:
session["agent"].reasoning_config = parsed
return _ok(rid, {"key": key, "value": arg})
except Exception as e:
return _err(rid, 5001, str(e))
if key in ("prompt", "personality", "skin"):
try:
@ -900,6 +1255,12 @@ def _(rid, params: dict) -> dict:
return _ok(rid, {"prompt": _load_cfg().get("custom_prompt", "")})
if key == "skin":
return _ok(rid, {"value": _load_cfg().get("display", {}).get("skin", "default")})
if key == "mtime":
cfg_path = _hermes_home / "config.yaml"
try:
return _ok(rid, {"mtime": cfg_path.stat().st_mtime if cfg_path.exists() else 0})
except Exception:
return _ok(rid, {"mtime": 0})
return _err(rid, 4002, f"unknown config key: {key}")
@ -1235,30 +1596,23 @@ def _mirror_slash_side_effects(sid: str, session: dict, command: str):
try:
if name == "model" and arg and agent:
from hermes_cli.model_switch import switch_model
result = switch_model(
raw_input=arg,
current_provider=getattr(agent, "provider", "") or "",
current_model=getattr(agent, "model", "") or "",
current_base_url=getattr(agent, "base_url", "") or "",
current_api_key=getattr(agent, "api_key", "") or "",
)
if result.success:
agent.switch_model(
new_model=result.new_model,
new_provider=result.target_provider,
api_key=result.api_key,
base_url=result.base_url,
api_mode=result.api_mode,
)
_emit("session.info", sid, _session_info(agent))
_apply_model_switch(sid, session, arg)
elif name in ("personality", "prompt") and agent:
cfg = _load_cfg()
new_prompt = cfg.get("agent", {}).get("system_prompt", "") or ""
agent.ephemeral_system_prompt = new_prompt or None
agent._cached_system_prompt = None
elif name == "compress" and agent:
(getattr(agent, "compress_context", None) or getattr(agent, "context_compressor", agent).compress)()
with session["history_lock"]:
_compress_session_history(session)
_emit("session.info", sid, _session_info(agent))
elif name == "fast" and agent:
mode = arg.lower()
if mode in {"fast", "on"}:
agent.service_tier = "priority"
elif mode in {"normal", "off"}:
agent.service_tier = None
_emit("session.info", sid, _session_info(agent))
elif name == "reload-mcp" and agent and hasattr(agent, "reload_mcp_tools"):
agent.reload_mcp_tools()
elif name == "stop":
@ -1384,10 +1738,29 @@ def _(rid, params: dict) -> dict:
if err:
return err
target = params.get("hash", "")
file_path = params.get("file_path", "")
if not target:
return _err(rid, 4014, "hash required")
try:
return _ok(rid, _with_checkpoints(session, lambda mgr, cwd: mgr.restore(cwd, target)))
def go(mgr, cwd):
resolved = _resolve_checkpoint_hash(mgr, cwd, target)
result = mgr.restore(cwd, resolved, file_path=file_path or None)
if result.get("success") and not file_path:
removed = 0
with session["history_lock"]:
history = session.get("history", [])
while history and history[-1].get("role") in ("assistant", "tool"):
history.pop()
removed += 1
if history and history[-1].get("role") == "user":
history.pop()
removed += 1
if removed:
session["history_version"] = int(session.get("history_version", 0)) + 1
result["history_removed"] = removed
return result
return _ok(rid, _with_checkpoints(session, go))
except Exception as e:
return _err(rid, 5021, str(e))
@ -1401,7 +1774,7 @@ def _(rid, params: dict) -> dict:
if not target:
return _err(rid, 4014, "hash required")
try:
r = _with_checkpoints(session, lambda mgr, cwd: mgr.diff(cwd, target))
r = _with_checkpoints(session, lambda mgr, cwd: mgr.diff(cwd, _resolve_checkpoint_hash(mgr, cwd, target)))
raw = r.get("diff", "")[:4000]
payload = {"stat": r.get("stat", ""), "diff": raw}
rendered = render_diff(raw, session.get("cols", 80))