hermes-agent/tests/tools/test_rl_training_tool.py
alt-glitch 79a5f03f92 refactor(types): simplify pass on P1 batch
Follow-up to 15ac253b per /simplify review:

- gateway/platforms/discord.py:3638 - move self.resolved = True *after*
  the `if interaction.data is None: return` guard. Previously the view
  was marked resolved before the None-guard, so a None data payload
  silently rejected the user's next click.
- agent/display.py:732 - replace `if self.start_time is None: continue`
  with `assert self.start_time is not None`. start() sets start_time
  before the animate thread starts, so the None branch was dead; the
  `continue` form would have busy-looped (skipping the 0.12s sleep).
- tests/hermes_cli/test_config_shapes.py - drop __total__ dunder
  restatement test (it just echoes the class declaration); trim commit
  narration from module docstring.
- tests/agent/test_credential_pool.py, tests/tools/test_rl_training_tool.py -
  drop "added in commit ..." banners (narrates the change per CLAUDE.md).
2026-04-23 17:41:18 +05:30

167 lines
5.4 KiB
Python

"""Tests for rl_training_tool.py — file handle lifecycle and cleanup.
Verifies that _stop_training_run properly closes log file handles,
terminates processes, and handles edge cases on failure paths.
Inspired by PR #715 (0xbyt4).
"""
import dataclasses
import io
from unittest.mock import MagicMock
import pytest
from tools.rl_training_tool import RunState, _stop_training_run
def _make_run_state(**overrides) -> RunState:
"""Create a minimal RunState for testing."""
defaults = {
"run_id": "test-run-001",
"environment": "test_env",
"config": {},
}
defaults.update(overrides)
return RunState(**defaults)
class TestStopTrainingRunFileHandles:
"""Verify that _stop_training_run closes log file handles stored as attributes."""
def test_closes_all_log_file_handles(self):
state = _make_run_state()
files = {}
for attr in ("api_log_file", "trainer_log_file", "env_log_file"):
fh = MagicMock()
setattr(state, attr, fh)
files[attr] = fh
_stop_training_run(state)
for attr, fh in files.items():
fh.close.assert_called_once()
assert getattr(state, attr) is None
def test_clears_file_attrs_to_none(self):
state = _make_run_state()
state.api_log_file = MagicMock()
_stop_training_run(state)
assert state.api_log_file is None
def test_close_exception_does_not_propagate(self):
"""If a file handle .close() raises, it must not crash."""
state = _make_run_state()
bad_fh = MagicMock()
bad_fh.close.side_effect = OSError("already closed")
good_fh = MagicMock()
state.api_log_file = bad_fh
state.trainer_log_file = good_fh
_stop_training_run(state) # should not raise
bad_fh.close.assert_called_once()
good_fh.close.assert_called_once()
def test_handles_missing_file_attrs(self):
"""RunState without log file attrs should not crash."""
state = _make_run_state()
# No log file attrs set at all — getattr(..., None) should handle it
_stop_training_run(state) # should not raise
class TestStopTrainingRunProcesses:
"""Verify that _stop_training_run terminates processes correctly."""
def test_terminates_running_processes(self):
state = _make_run_state()
for attr in ("api_process", "trainer_process", "env_process"):
proc = MagicMock()
proc.poll.return_value = None # still running
setattr(state, attr, proc)
_stop_training_run(state)
for attr in ("api_process", "trainer_process", "env_process"):
getattr(state, attr).terminate.assert_called_once()
def test_does_not_terminate_exited_processes(self):
state = _make_run_state()
proc = MagicMock()
proc.poll.return_value = 0 # already exited
state.api_process = proc
_stop_training_run(state)
proc.terminate.assert_not_called()
def test_handles_none_processes(self):
state = _make_run_state()
# All process attrs are None by default
_stop_training_run(state) # should not raise
def test_handles_mixed_running_and_exited_processes(self):
state = _make_run_state()
# api still running
api = MagicMock()
api.poll.return_value = None
state.api_process = api
# trainer already exited
trainer = MagicMock()
trainer.poll.return_value = 0
state.trainer_process = trainer
# env is None
state.env_process = None
_stop_training_run(state)
api.terminate.assert_called_once()
trainer.terminate.assert_not_called()
class TestRunStateLogFileFields:
def test_log_file_fields_default_none(self):
"""All three log_file fields should default to None."""
state = _make_run_state()
assert state.api_log_file is None
assert state.trainer_log_file is None
assert state.env_log_file is None
def test_accepts_file_handle_for_api_log(self):
"""api_log_file should accept an open file-like object."""
api_log = io.StringIO()
state = _make_run_state(api_log_file=api_log)
assert state.api_log_file is api_log
def test_log_file_fields_present_in_dataclass(self):
"""All three field names must be declared on the RunState dataclass."""
field_names = {f.name for f in dataclasses.fields(RunState)}
assert "api_log_file" in field_names
assert "trainer_log_file" in field_names
assert "env_log_file" in field_names
class TestStopTrainingRunStatus:
"""Verify status transitions in _stop_training_run."""
def test_sets_status_to_stopped_when_running(self):
state = _make_run_state(status="running")
_stop_training_run(state)
assert state.status == "stopped"
def test_does_not_change_status_when_failed(self):
state = _make_run_state(status="failed")
_stop_training_run(state)
assert state.status == "failed"
def test_does_not_change_status_when_pending(self):
state = _make_run_state(status="pending")
_stop_training_run(state)
assert state.status == "pending"
def test_no_crash_with_no_processes_and_no_files(self):
state = _make_run_state()
_stop_training_run(state) # should not raise
assert state.status == "pending"