mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
guard kanban worker lifecycle by run id
This commit is contained in:
parent
f0d278412f
commit
56b4795115
5 changed files with 243 additions and 36 deletions
|
|
@ -943,7 +943,12 @@ def _cmd_init(args: argparse.Namespace) -> int:
|
|||
|
||||
def _cmd_heartbeat(args: argparse.Namespace) -> int:
|
||||
with kb.connect() as conn:
|
||||
ok = kb.heartbeat_worker(conn, args.task_id, note=getattr(args, "note", None))
|
||||
ok = kb.heartbeat_worker(
|
||||
conn,
|
||||
args.task_id,
|
||||
note=getattr(args, "note", None),
|
||||
expected_run_id=_worker_run_id_for(args.task_id),
|
||||
)
|
||||
if not ok:
|
||||
print(f"cannot heartbeat {args.task_id} (not running?)", file=sys.stderr)
|
||||
return 1
|
||||
|
|
@ -1406,6 +1411,18 @@ def _cmd_comment(args: argparse.Namespace) -> int:
|
|||
return 0
|
||||
|
||||
|
||||
def _worker_run_id_for(task_id: str) -> Optional[int]:
    """Return the run id this process was spawned with, if it owns ``task_id``.

    The id is read from the ``HERMES_KANBAN_RUN_ID`` environment variable
    and is only honoured when ``HERMES_KANBAN_TASK`` equals ``task_id``.

    Returns ``None`` when the task does not match, when the run-id variable
    is unset or empty, or when its value cannot be parsed as an integer.
    """
    env = os.environ
    if env.get("HERMES_KANBAN_TASK") == task_id:
        run_id_text = env.get("HERMES_KANBAN_RUN_ID")
        if run_id_text:
            try:
                return int(run_id_text)
            except ValueError:
                # Unparseable value: treated the same as "no run id".
                pass
    return None
|
||||
|
||||
|
||||
def _cmd_complete(args: argparse.Namespace) -> int:
|
||||
"""Mark one or more tasks done. Supports a single id or a list."""
|
||||
ids = list(args.task_ids or [])
|
||||
|
|
@ -1442,6 +1459,7 @@ def _cmd_complete(args: argparse.Namespace) -> int:
|
|||
result=args.result,
|
||||
summary=summary,
|
||||
metadata=metadata,
|
||||
expected_run_id=_worker_run_id_for(tid),
|
||||
):
|
||||
failed.append(tid)
|
||||
print(f"cannot complete {tid} (unknown id or terminal state)", file=sys.stderr)
|
||||
|
|
@ -1487,7 +1505,12 @@ def _cmd_block(args: argparse.Namespace) -> int:
|
|||
for tid in ids:
|
||||
if reason:
|
||||
kb.add_comment(conn, tid, author, f"BLOCKED: {reason}")
|
||||
if not kb.block_task(conn, tid, reason=reason):
|
||||
if not kb.block_task(
|
||||
conn,
|
||||
tid,
|
||||
reason=reason,
|
||||
expected_run_id=_worker_run_id_for(tid),
|
||||
):
|
||||
failed.append(tid)
|
||||
print(f"cannot block {tid}", file=sys.stderr)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -2098,6 +2098,7 @@ def complete_task(
|
|||
summary: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
created_cards: Optional[Iterable[str]] = None,
|
||||
expected_run_id: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Transition ``running|ready -> done`` and record ``result``.
|
||||
|
||||
|
|
@ -2157,20 +2158,37 @@ def complete_task(
|
|||
verified_cards = []
|
||||
|
||||
with write_txn(conn):
|
||||
cur = conn.execute(
|
||||
"""
|
||||
UPDATE tasks
|
||||
SET status = 'done',
|
||||
result = ?,
|
||||
completed_at = ?,
|
||||
claim_lock = NULL,
|
||||
claim_expires= NULL,
|
||||
worker_pid = NULL
|
||||
WHERE id = ?
|
||||
AND status IN ('running', 'ready', 'blocked')
|
||||
""",
|
||||
(result, now, task_id),
|
||||
)
|
||||
if expected_run_id is None:
|
||||
cur = conn.execute(
|
||||
"""
|
||||
UPDATE tasks
|
||||
SET status = 'done',
|
||||
result = ?,
|
||||
completed_at = ?,
|
||||
claim_lock = NULL,
|
||||
claim_expires= NULL,
|
||||
worker_pid = NULL
|
||||
WHERE id = ?
|
||||
AND status IN ('running', 'ready', 'blocked')
|
||||
""",
|
||||
(result, now, task_id),
|
||||
)
|
||||
else:
|
||||
cur = conn.execute(
|
||||
"""
|
||||
UPDATE tasks
|
||||
SET status = 'done',
|
||||
result = ?,
|
||||
completed_at = ?,
|
||||
claim_lock = NULL,
|
||||
claim_expires= NULL,
|
||||
worker_pid = NULL
|
||||
WHERE id = ?
|
||||
AND status IN ('running', 'ready', 'blocked')
|
||||
AND current_run_id = ?
|
||||
""",
|
||||
(result, now, task_id, int(expected_run_id)),
|
||||
)
|
||||
if cur.rowcount != 1:
|
||||
return False
|
||||
run_id = _end_run(
|
||||
|
|
@ -2310,21 +2328,37 @@ def block_task(
|
|||
task_id: str,
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
expected_run_id: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Transition ``running -> blocked``."""
|
||||
with write_txn(conn):
|
||||
cur = conn.execute(
|
||||
"""
|
||||
UPDATE tasks
|
||||
SET status = 'blocked',
|
||||
claim_lock = NULL,
|
||||
claim_expires= NULL,
|
||||
worker_pid = NULL
|
||||
WHERE id = ?
|
||||
AND status IN ('running', 'ready')
|
||||
""",
|
||||
(task_id,),
|
||||
)
|
||||
if expected_run_id is None:
|
||||
cur = conn.execute(
|
||||
"""
|
||||
UPDATE tasks
|
||||
SET status = 'blocked',
|
||||
claim_lock = NULL,
|
||||
claim_expires= NULL,
|
||||
worker_pid = NULL
|
||||
WHERE id = ?
|
||||
AND status IN ('running', 'ready')
|
||||
""",
|
||||
(task_id,),
|
||||
)
|
||||
else:
|
||||
cur = conn.execute(
|
||||
"""
|
||||
UPDATE tasks
|
||||
SET status = 'blocked',
|
||||
claim_lock = NULL,
|
||||
claim_expires= NULL,
|
||||
worker_pid = NULL
|
||||
WHERE id = ?
|
||||
AND status IN ('running', 'ready')
|
||||
AND current_run_id = ?
|
||||
""",
|
||||
(task_id, int(expected_run_id)),
|
||||
)
|
||||
if cur.rowcount != 1:
|
||||
return False
|
||||
run_id = _end_run(
|
||||
|
|
@ -2596,6 +2630,7 @@ def heartbeat_worker(
|
|||
task_id: str,
|
||||
*,
|
||||
note: Optional[str] = None,
|
||||
expected_run_id: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Record a ``heartbeat`` event + touch ``last_heartbeat_at``.
|
||||
|
||||
|
|
@ -2609,14 +2644,25 @@ def heartbeat_worker(
|
|||
"""
|
||||
now = int(time.time())
|
||||
with write_txn(conn):
|
||||
cur = conn.execute(
|
||||
"UPDATE tasks SET last_heartbeat_at = ? "
|
||||
"WHERE id = ? AND status = 'running'",
|
||||
(now, task_id),
|
||||
)
|
||||
if expected_run_id is None:
|
||||
cur = conn.execute(
|
||||
"UPDATE tasks SET last_heartbeat_at = ? "
|
||||
"WHERE id = ? AND status = 'running'",
|
||||
(now, task_id),
|
||||
)
|
||||
else:
|
||||
cur = conn.execute(
|
||||
"UPDATE tasks SET last_heartbeat_at = ? "
|
||||
"WHERE id = ? AND status = 'running' AND current_run_id = ?",
|
||||
(now, task_id, int(expected_run_id)),
|
||||
)
|
||||
if cur.rowcount != 1:
|
||||
return False
|
||||
run_id = _current_run_id(conn, task_id)
|
||||
run_id = (
|
||||
int(expected_run_id)
|
||||
if expected_run_id is not None
|
||||
else _current_run_id(conn, task_id)
|
||||
)
|
||||
if run_id is not None:
|
||||
conn.execute(
|
||||
"UPDATE task_runs SET last_heartbeat_at = ? WHERE id = ?",
|
||||
|
|
@ -3219,6 +3265,10 @@ def _default_spawn(
|
|||
env["HERMES_TENANT"] = task.tenant
|
||||
env["HERMES_KANBAN_TASK"] = task.id
|
||||
env["HERMES_KANBAN_WORKSPACE"] = workspace
|
||||
if task.current_run_id is not None:
|
||||
env["HERMES_KANBAN_RUN_ID"] = str(task.current_run_id)
|
||||
if task.claim_lock:
|
||||
env["HERMES_KANBAN_CLAIM_LOCK"] = task.claim_lock
|
||||
# Pin the shared board + workspaces root the dispatcher resolved, so
|
||||
# that even when the worker activates a profile (`hermes -p <name>`
|
||||
# rewrites HERMES_HOME), its kanban paths still match the
|
||||
|
|
|
|||
|
|
@ -1186,6 +1186,79 @@ def test_multiple_attempts_preserved_as_runs(kanban_home):
|
|||
conn.close()
|
||||
|
||||
|
||||
def test_stale_run_cannot_complete_new_attempt(kanban_home, monkeypatch):
    """Completion carrying a superseded run id must be refused.

    Scenario: claim a task, make its worker look dead so crash detection
    reports the task, claim it again, then try to complete it first with
    the old run id (must fail and leave the task running) and then with
    the new run id (must succeed).
    """
    import hermes_cli.kanban_db as _kb

    conn = kb.connect()
    try:
        task_id = kb.create_task(conn, title="retry guarded", assignee="worker")

        # First attempt: claim, then make its worker look dead.
        kb.claim_task(conn, task_id)
        stale_run = kb.latest_run(conn, task_id)
        kb._set_worker_pid(conn, task_id, 98765)
        monkeypatch.setattr(_kb, "_pid_alive", lambda pid: False)
        assert kb.detect_crashed_workers(conn) == [task_id]

        # Second attempt gets its own run row.
        kb.claim_task(conn, task_id)
        active_run = kb.latest_run(conn, task_id)
        assert active_run.id != stale_run.id

        # The stale attempt tries to close the task: rejected, nothing moves.
        stale_ok = kb.complete_task(
            conn,
            task_id,
            summary="late stale completion",
            expected_run_id=stale_run.id,
        )
        assert not stale_ok
        current = kb.get_task(conn, task_id)
        assert current.status == "running"
        assert current.current_run_id == active_run.id

        # The active attempt closes it normally.
        active_ok = kb.complete_task(
            conn,
            task_id,
            summary="current completion",
            expected_run_id=active_run.id,
        )
        assert active_ok
        history = kb.list_runs(conn, task_id)
        assert [r.outcome for r in history] == ["crashed", "completed"]
        assert history[-1].summary == "current completion"
    finally:
        conn.close()
|
||||
|
||||
|
||||
def test_stale_run_cannot_block_or_heartbeat_new_attempt(kanban_home, monkeypatch):
    """Heartbeat and block calls carrying a superseded run id must be refused.

    After a simulated crash and a re-claim, calls tagged with the first
    run's id must return falsy and leave status, current run and
    ``last_heartbeat_at`` untouched; calls tagged with the second run's id
    must succeed.
    """
    import hermes_cli.kanban_db as _kb

    conn = kb.connect()
    try:
        task_id = kb.create_task(
            conn, title="retry heartbeat guarded", assignee="worker"
        )

        # First attempt: claim, then make its worker look dead.
        kb.claim_task(conn, task_id)
        stale_run = kb.latest_run(conn, task_id)
        kb._set_worker_pid(conn, task_id, 98765)
        monkeypatch.setattr(_kb, "_pid_alive", lambda pid: False)
        assert kb.detect_crashed_workers(conn) == [task_id]

        # Second attempt gets its own run row.
        kb.claim_task(conn, task_id)
        active_run = kb.latest_run(conn, task_id)
        assert active_run.id != stale_run.id

        # Stale lifecycle calls are rejected and change nothing.
        stale_beat = kb.heartbeat_worker(
            conn, task_id, note="late", expected_run_id=stale_run.id
        )
        assert not stale_beat
        stale_block = kb.block_task(
            conn, task_id, reason="late block", expected_run_id=stale_run.id
        )
        assert not stale_block
        current = kb.get_task(conn, task_id)
        assert current.status == "running"
        assert current.current_run_id == active_run.id
        assert current.last_heartbeat_at is None

        # The active attempt can heartbeat and then block.
        assert kb.heartbeat_worker(
            conn, task_id, note="current", expected_run_id=active_run.id
        )
        assert kb.block_task(
            conn, task_id, reason="current block", expected_run_id=active_run.id
        )
        assert kb.get_task(conn, task_id).status == "blocked"
    finally:
        conn.close()
|
||||
|
||||
|
||||
def test_run_on_block_with_reason(kanban_home):
|
||||
conn = kb.connect()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -611,6 +611,44 @@ def test_worker_complete_own_task_still_works(worker_env):
|
|||
assert d.get("ok") is True and d.get("task_id") == worker_env
|
||||
|
||||
|
||||
def test_worker_complete_rejects_stale_run_id(worker_env, monkeypatch):
    """The complete tool handler must refuse an out-of-date run id.

    ``HERMES_KANBAN_RUN_ID`` is set first to the superseded run's id
    (the handler must not report ``ok`` and the task must stay running on
    the new run) and then to the current run's id (the handler must
    report ``ok``).
    """
    from hermes_cli import kanban_db as kb
    import hermes_cli.kanban_db as _kb

    # Simulate a crash of the first attempt followed by a re-claim.
    conn = kb.connect()
    try:
        stale_run = kb.latest_run(conn, worker_env)
        kb._set_worker_pid(conn, worker_env, 98765)
        monkeypatch.setattr(_kb, "_pid_alive", lambda pid: False)
        assert kb.detect_crashed_workers(conn) == [worker_env]

        kb.claim_task(conn, worker_env)
        active_run = kb.latest_run(conn, worker_env)
        assert active_run.id != stale_run.id
    finally:
        conn.close()

    from tools import kanban_tools as kt

    # Old run token in the environment: completion is refused.
    monkeypatch.setenv("HERMES_KANBAN_RUN_ID", str(stale_run.id))
    stale_reply = json.loads(
        kt._handle_complete({"summary": "late stale completion"})
    )
    assert stale_reply.get("ok") is not True

    # The task is still running on the new attempt.
    conn = kb.connect()
    try:
        current = kb.get_task(conn, worker_env)
        assert current.status == "running"
        assert current.current_run_id == active_run.id
    finally:
        conn.close()

    # Current run token in the environment: completion goes through.
    monkeypatch.setenv("HERMES_KANBAN_RUN_ID", str(active_run.id))
    active_reply = json.loads(
        kt._handle_complete({"summary": "current completion"})
    )
    assert active_reply.get("ok") is True
|
||||
|
||||
|
||||
def test_orchestrator_complete_any_task_allowed(monkeypatch, tmp_path):
|
||||
"""Orchestrator profiles (no HERMES_KANBAN_TASK) can still complete
|
||||
any task via explicit task_id. The check only applies to workers."""
|
||||
|
|
|
|||
|
|
@ -79,6 +79,19 @@ def _default_task_id(arg: Optional[str]) -> Optional[str]:
|
|||
return env_tid or None
|
||||
|
||||
|
||||
def _worker_run_id(task_id: str) -> Optional[int]:
    """Return this worker's dispatcher run id when it is scoped to task_id.

    Reads ``HERMES_KANBAN_RUN_ID`` from the environment, but only when
    ``HERMES_KANBAN_TASK`` equals ``task_id``. Returns ``None`` if the task
    does not match, the run-id variable is unset or empty, or its value is
    not a valid integer.
    """
    scoped_task = os.environ.get("HERMES_KANBAN_TASK")
    if scoped_task != task_id:
        return None

    run_id_text = os.environ.get("HERMES_KANBAN_RUN_ID")
    if not run_id_text:
        return None

    try:
        parsed = int(run_id_text)
    except ValueError:
        return None
    else:
        return parsed
|
||||
|
||||
|
||||
def _enforce_worker_task_ownership(tid: str) -> Optional[str]:
|
||||
"""Reject worker-driven destructive calls on foreign task IDs.
|
||||
|
||||
|
|
@ -240,6 +253,7 @@ def _handle_complete(args: dict, **kw) -> str:
|
|||
conn, tid,
|
||||
result=result, summary=summary, metadata=metadata,
|
||||
created_cards=created_cards,
|
||||
expected_run_id=_worker_run_id(tid),
|
||||
)
|
||||
except kb.HallucinatedCardsError as hall_err:
|
||||
# Structured rejection — surface the phantom ids so the
|
||||
|
|
@ -281,7 +295,11 @@ def _handle_block(args: dict, **kw) -> str:
|
|||
try:
|
||||
kb, conn = _connect()
|
||||
try:
|
||||
ok = kb.block_task(conn, tid, reason=reason)
|
||||
ok = kb.block_task(
|
||||
conn, tid,
|
||||
reason=reason,
|
||||
expected_run_id=_worker_run_id(tid),
|
||||
)
|
||||
if not ok:
|
||||
return tool_error(
|
||||
f"could not block {tid} (unknown id or not in "
|
||||
|
|
@ -310,7 +328,12 @@ def _handle_heartbeat(args: dict, **kw) -> str:
|
|||
try:
|
||||
kb, conn = _connect()
|
||||
try:
|
||||
ok = kb.heartbeat_worker(conn, tid, note=note)
|
||||
ok = kb.heartbeat_worker(
|
||||
conn,
|
||||
tid,
|
||||
note=note,
|
||||
expected_run_id=_worker_run_id(tid),
|
||||
)
|
||||
if not ok:
|
||||
return tool_error(
|
||||
f"could not heartbeat {tid} (unknown id or not running)"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue