fix: preserve reasoning_content on Kimi replay

This commit is contained in:
helix4u 2026-04-21 23:40:31 -06:00 committed by Teknium
parent 30ec12970b
commit a7d78d3bfd
7 changed files with 187 additions and 16 deletions

View file

@ -7216,6 +7216,7 @@ class GatewayRunner:
tool_calls=msg.get("tool_calls"), tool_calls=msg.get("tool_calls"),
tool_call_id=msg.get("tool_call_id"), tool_call_id=msg.get("tool_call_id"),
reasoning=msg.get("reasoning"), reasoning=msg.get("reasoning"),
reasoning_content=msg.get("reasoning_content"),
) )
except Exception: except Exception:
pass # Best-effort copy pass # Best-effort copy

View file

@ -1147,6 +1147,10 @@ class SessionStore:
tool_name=message.get("tool_name"), tool_name=message.get("tool_name"),
tool_calls=message.get("tool_calls"), tool_calls=message.get("tool_calls"),
tool_call_id=message.get("tool_call_id"), tool_call_id=message.get("tool_call_id"),
reasoning=message.get("reasoning") if message.get("role") == "assistant" else None,
reasoning_content=message.get("reasoning_content") if message.get("role") == "assistant" else None,
reasoning_details=message.get("reasoning_details") if message.get("role") == "assistant" else None,
codex_reasoning_items=message.get("codex_reasoning_items") if message.get("role") == "assistant" else None,
) )
except Exception as e: except Exception as e:
logger.debug("Session DB operation failed: %s", e) logger.debug("Session DB operation failed: %s", e)
@ -1176,6 +1180,7 @@ class SessionStore:
tool_calls=msg.get("tool_calls"), tool_calls=msg.get("tool_calls"),
tool_call_id=msg.get("tool_call_id"), tool_call_id=msg.get("tool_call_id"),
reasoning=msg.get("reasoning") if role == "assistant" else None, reasoning=msg.get("reasoning") if role == "assistant" else None,
reasoning_content=msg.get("reasoning_content") if role == "assistant" else None,
reasoning_details=msg.get("reasoning_details") if role == "assistant" else None, reasoning_details=msg.get("reasoning_details") if role == "assistant" else None,
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None, codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
) )

View file

@ -31,7 +31,7 @@ T = TypeVar("T")
DEFAULT_DB_PATH = get_hermes_home() / "state.db" DEFAULT_DB_PATH = get_hermes_home() / "state.db"
SCHEMA_VERSION = 6 SCHEMA_VERSION = 7
SCHEMA_SQL = """ SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS schema_version ( CREATE TABLE IF NOT EXISTS schema_version (
@ -80,6 +80,7 @@ CREATE TABLE IF NOT EXISTS messages (
token_count INTEGER, token_count INTEGER,
finish_reason TEXT, finish_reason TEXT,
reasoning TEXT, reasoning TEXT,
reasoning_content TEXT,
reasoning_details TEXT, reasoning_details TEXT,
codex_reasoning_items TEXT codex_reasoning_items TEXT
); );
@ -329,6 +330,15 @@ class SessionDB:
except sqlite3.OperationalError: except sqlite3.OperationalError:
pass # Column already exists pass # Column already exists
cursor.execute("UPDATE schema_version SET version = 6") cursor.execute("UPDATE schema_version SET version = 6")
if current_version < 7:
# v7: preserve provider-native reasoning_content separately from
# normalized reasoning text. Kimi/Moonshot replay can require
# this field on assistant tool-call messages when thinking is on.
try:
cursor.execute('ALTER TABLE messages ADD COLUMN "reasoning_content" TEXT')
except sqlite3.OperationalError:
pass # Column already exists
cursor.execute("UPDATE schema_version SET version = 7")
# Unique title index — always ensure it exists (safe to run after migrations # Unique title index — always ensure it exists (safe to run after migrations
# since the title column is guaranteed to exist at this point) # since the title column is guaranteed to exist at this point)
@ -922,6 +932,7 @@ class SessionDB:
token_count: int = None, token_count: int = None,
finish_reason: str = None, finish_reason: str = None,
reasoning: str = None, reasoning: str = None,
reasoning_content: str = None,
reasoning_details: Any = None, reasoning_details: Any = None,
codex_reasoning_items: Any = None, codex_reasoning_items: Any = None,
) -> int: ) -> int:
@ -951,8 +962,8 @@ class SessionDB:
cursor = conn.execute( cursor = conn.execute(
"""INSERT INTO messages (session_id, role, content, tool_call_id, """INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason, tool_calls, tool_name, timestamp, token_count, finish_reason,
reasoning, reasoning_details, codex_reasoning_items) reasoning, reasoning_content, reasoning_details, codex_reasoning_items)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
( (
session_id, session_id,
role, role,
@ -964,6 +975,7 @@ class SessionDB:
token_count, token_count,
finish_reason, finish_reason,
reasoning, reasoning,
reasoning_content,
reasoning_details_json, reasoning_details_json,
codex_items_json, codex_items_json,
), ),
@ -1014,7 +1026,7 @@ class SessionDB:
with self._lock: with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT role, content, tool_call_id, tool_calls, tool_name, " "SELECT role, content, tool_call_id, tool_calls, tool_name, "
"reasoning, reasoning_details, codex_reasoning_items " "reasoning, reasoning_content, reasoning_details, codex_reasoning_items "
"FROM messages WHERE session_id = ? ORDER BY timestamp, id", "FROM messages WHERE session_id = ? ORDER BY timestamp, id",
(session_id,), (session_id,),
) )
@ -1038,6 +1050,8 @@ class SessionDB:
if row["role"] == "assistant": if row["role"] == "assistant":
if row["reasoning"]: if row["reasoning"]:
msg["reasoning"] = row["reasoning"] msg["reasoning"] = row["reasoning"]
if row["reasoning_content"] is not None:
msg["reasoning_content"] = row["reasoning_content"]
if row["reasoning_details"]: if row["reasoning_details"]:
try: try:
msg["reasoning_details"] = json.loads(row["reasoning_details"]) msg["reasoning_details"] = json.loads(row["reasoning_details"])

View file

@ -2966,6 +2966,7 @@ class AIAgent:
tool_call_id=msg.get("tool_call_id"), tool_call_id=msg.get("tool_call_id"),
finish_reason=msg.get("finish_reason"), finish_reason=msg.get("finish_reason"),
reasoning=msg.get("reasoning") if role == "assistant" else None, reasoning=msg.get("reasoning") if role == "assistant" else None,
reasoning_content=msg.get("reasoning_content") if role == "assistant" else None,
reasoning_details=msg.get("reasoning_details") if role == "assistant" else None, reasoning_details=msg.get("reasoning_details") if role == "assistant" else None,
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None, codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
) )
@ -7003,6 +7004,11 @@ class AIAgent:
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
if hasattr(assistant_message, "reasoning_content"):
raw_reasoning_content = getattr(assistant_message, "reasoning_content", None)
if raw_reasoning_content is not None:
msg["reasoning_content"] = _sanitize_surrogates(raw_reasoning_content)
if hasattr(assistant_message, 'reasoning_details') and assistant_message.reasoning_details: if hasattr(assistant_message, 'reasoning_details') and assistant_message.reasoning_details:
# Pass reasoning_details back unmodified so providers (OpenRouter, # Pass reasoning_details back unmodified so providers (OpenRouter,
# Anthropic, OpenAI) can maintain reasoning continuity across turns. # Anthropic, OpenAI) can maintain reasoning continuity across turns.
@ -7077,6 +7083,30 @@ class AIAgent:
return msg return msg
def _copy_reasoning_content_for_api(self, source_msg: dict, api_msg: dict) -> None:
"""Copy provider-facing reasoning fields onto an API replay message."""
if source_msg.get("role") != "assistant":
return
explicit_reasoning = source_msg.get("reasoning_content")
if isinstance(explicit_reasoning, str):
api_msg["reasoning_content"] = explicit_reasoning
return
normalized_reasoning = source_msg.get("reasoning")
if isinstance(normalized_reasoning, str) and normalized_reasoning:
api_msg["reasoning_content"] = normalized_reasoning
return
kimi_requires_reasoning = (
self.provider in {"kimi-coding", "kimi-coding-cn"}
or base_url_host_matches(self.base_url, "api.kimi.com")
or base_url_host_matches(self.base_url, "moonshot.ai")
or base_url_host_matches(self.base_url, "moonshot.cn")
)
if kimi_requires_reasoning and source_msg.get("tool_calls"):
api_msg["reasoning_content"] = ""
@staticmethod @staticmethod
def _sanitize_tool_calls_for_strict_api(api_msg: dict) -> dict: def _sanitize_tool_calls_for_strict_api(api_msg: dict) -> dict:
"""Strip Codex Responses API fields from tool_calls for strict providers. """Strip Codex Responses API fields from tool_calls for strict providers.
@ -7160,10 +7190,7 @@ class AIAgent:
api_messages = [] api_messages = []
for msg in messages: for msg in messages:
api_msg = msg.copy() api_msg = msg.copy()
if msg.get("role") == "assistant": self._copy_reasoning_content_for_api(msg, api_msg)
reasoning = msg.get("reasoning")
if reasoning:
api_msg["reasoning_content"] = reasoning
api_msg.pop("reasoning", None) api_msg.pop("reasoning", None)
api_msg.pop("finish_reason", None) api_msg.pop("finish_reason", None)
api_msg.pop("_flush_sentinel", None) api_msg.pop("_flush_sentinel", None)
@ -8923,11 +8950,7 @@ class AIAgent:
# For ALL assistant messages, pass reasoning back to the API # For ALL assistant messages, pass reasoning back to the API
# This ensures multi-turn reasoning context is preserved # This ensures multi-turn reasoning context is preserved
if msg.get("role") == "assistant": self._copy_reasoning_content_for_api(msg, api_msg)
reasoning_text = msg.get("reasoning")
if reasoning_text:
# Add reasoning_content for API compatibility (Moonshot AI, Novita, OpenRouter)
api_msg["reasoning_content"] = reasoning_text
# Remove 'reasoning' field - it's for trajectory storage only # Remove 'reasoning' field - it's for trajectory storage only
# We've copied it to 'reasoning_content' for the API above # We've copied it to 'reasoning_content' for the API above

View file

@ -1059,6 +1059,7 @@ class TestRewriteTranscriptPreservesReasoning:
role="assistant", role="assistant",
content="The answer is 42.", content="The answer is 42.",
reasoning="I need to think step by step.", reasoning="I need to think step by step.",
reasoning_content="provider scratchpad",
reasoning_details=[{"type": "summary", "text": "step by step"}], reasoning_details=[{"type": "summary", "text": "step by step"}],
codex_reasoning_items=[{"id": "r1", "type": "reasoning"}], codex_reasoning_items=[{"id": "r1", "type": "reasoning"}],
) )
@ -1066,6 +1067,7 @@ class TestRewriteTranscriptPreservesReasoning:
# Verify all three were stored # Verify all three were stored
before = db.get_messages_as_conversation(session_id) before = db.get_messages_as_conversation(session_id)
assert before[0].get("reasoning") == "I need to think step by step." assert before[0].get("reasoning") == "I need to think step by step."
assert before[0].get("reasoning_content") == "provider scratchpad"
assert before[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}] assert before[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
assert before[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}] assert before[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
@ -1082,5 +1084,6 @@ class TestRewriteTranscriptPreservesReasoning:
# Load again — all three reasoning fields must survive # Load again — all three reasoning fields must survive
after = db.get_messages_as_conversation(session_id) after = db.get_messages_as_conversation(session_id)
assert after[0].get("reasoning") == "I need to think step by step." assert after[0].get("reasoning") == "I need to think step by step."
assert after[0].get("reasoning_content") == "provider scratchpad"
assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}] assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}] assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]

View file

@ -1216,6 +1216,15 @@ class TestBuildAssistantMessage:
result = agent._build_assistant_message(msg, "stop") result = agent._build_assistant_message(msg, "stop")
assert result["reasoning"] == "thinking" assert result["reasoning"] == "thinking"
def test_reasoning_content_preserved_separately(self, agent):
msg = _mock_assistant_msg(
content="answer",
reasoning="summary",
reasoning_content="provider scratchpad",
)
result = agent._build_assistant_message(msg, "stop")
assert result["reasoning_content"] == "provider scratchpad"
def test_with_tool_calls(self, agent): def test_with_tool_calls(self, agent):
tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1") tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
msg = _mock_assistant_msg(content="", tool_calls=[tc]) msg = _mock_assistant_msg(content="", tool_calls=[tc])
@ -4188,6 +4197,90 @@ class TestPersistUserMessageOverride:
assert first_db_write["content"] == "Hello there" assert first_db_write["content"] == "Hello there"
class TestReasoningReplayForStrictProviders:
"""Assistant replay must preserve provider-native reasoning fields."""
def _setup_agent(self, agent):
agent._cached_system_prompt = "You are helpful."
agent._use_prompt_caching = False
agent.tool_delay = 0
agent.compression_enabled = False
agent.save_trajectories = False
def test_kimi_tool_replay_includes_empty_reasoning_content(self, agent):
self._setup_agent(agent)
agent.base_url = "https://api.kimi.com/coding/v1"
agent._base_url_lower = agent.base_url.lower()
agent.provider = "kimi-coding"
prior_assistant = {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "c1",
"type": "function",
"function": {"name": "terminal", "arguments": "{\"command\":\"date\"}"},
}
],
}
tool_result = {"role": "tool", "tool_call_id": "c1", "content": "Tue Apr 21"}
final_resp = _mock_response(content="done", finish_reason="stop")
agent.client.chat.completions.create.return_value = final_resp
with (
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
result = agent.run_conversation(
"next step",
conversation_history=[prior_assistant, tool_result],
)
assert result["completed"] is True
sent_messages = agent.client.chat.completions.create.call_args.kwargs["messages"]
replayed_assistant = next(msg for msg in sent_messages if msg.get("role") == "assistant")
assert replayed_assistant["role"] == "assistant"
assert replayed_assistant["tool_calls"][0]["function"]["name"] == "terminal"
assert "reasoning_content" in replayed_assistant
assert replayed_assistant["reasoning_content"] == ""
def test_explicit_reasoning_content_beats_normalized_reasoning_on_replay(self, agent):
self._setup_agent(agent)
prior_assistant = {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "c1",
"type": "function",
"function": {"name": "web_search", "arguments": "{\"q\":\"test\"}"},
}
],
"reasoning": "summary reasoning",
"reasoning_content": "provider-native scratchpad",
}
tool_result = {"role": "tool", "tool_call_id": "c1", "content": "ok"}
final_resp = _mock_response(content="done", finish_reason="stop")
agent.client.chat.completions.create.return_value = final_resp
with (
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
result = agent.run_conversation(
"next step",
conversation_history=[prior_assistant, tool_result],
)
assert result["completed"] is True
sent_messages = agent.client.chat.completions.create.call_args.kwargs["messages"]
replayed_assistant = next(msg for msg in sent_messages if msg.get("role") == "assistant")
assert replayed_assistant["reasoning_content"] == "provider-native scratchpad"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Bugfix: _vprint force=True on error messages during TTS # Bugfix: _vprint force=True on error messages during TTS
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -255,6 +255,38 @@ class TestMessageStorage:
assert msg["reasoning"] == "Thinking about what to say" assert msg["reasoning"] == "Thinking about what to say"
assert msg["reasoning_details"] == details assert msg["reasoning_details"] == details
def test_reasoning_content_persisted_and_restored(self, db):
"""reasoning_content must survive session replay as its own field."""
db.create_session(session_id="s1", source="cli")
db.append_message(
"s1",
role="assistant",
content="Hello",
reasoning="Short summary",
reasoning_content="Longer provider-native scratchpad",
)
conv = db.get_messages_as_conversation("s1")
assert len(conv) == 1
assert conv[0]["reasoning"] == "Short summary"
assert conv[0]["reasoning_content"] == "Longer provider-native scratchpad"
def test_reasoning_content_empty_string_restored_for_assistant(self, db):
"""Empty reasoning_content still needs to round-trip for strict replays."""
db.create_session(session_id="s1", source="cli")
db.append_message(
"s1",
role="assistant",
content="",
tool_calls=[{"id": "c1", "type": "function", "function": {"name": "date", "arguments": "{}"}}],
reasoning_content="",
)
conv = db.get_messages_as_conversation("s1")
assert len(conv) == 1
assert "reasoning_content" in conv[0]
assert conv[0]["reasoning_content"] == ""
def test_reasoning_not_set_for_non_assistant(self, db): def test_reasoning_not_set_for_non_assistant(self, db):
"""reasoning is never leaked onto user or tool messages.""" """reasoning is never leaked onto user or tool messages."""
db.create_session(session_id="s1", source="telegram") db.create_session(session_id="s1", source="telegram")
@ -1120,7 +1152,7 @@ class TestSchemaInit:
def test_schema_version(self, db): def test_schema_version(self, db):
cursor = db._conn.execute("SELECT version FROM schema_version") cursor = db._conn.execute("SELECT version FROM schema_version")
version = cursor.fetchone()[0] version = cursor.fetchone()[0]
assert version == 6 assert version == 7
def test_title_column_exists(self, db): def test_title_column_exists(self, db):
"""Verify the title column was created in the sessions table.""" """Verify the title column was created in the sessions table."""
@ -1176,12 +1208,12 @@ class TestSchemaInit:
conn.commit() conn.commit()
conn.close() conn.close()
# Open with SessionDB — should migrate to v6 # Open with SessionDB — should migrate to v7
migrated_db = SessionDB(db_path=db_path) migrated_db = SessionDB(db_path=db_path)
# Verify migration # Verify migration
cursor = migrated_db._conn.execute("SELECT version FROM schema_version") cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
assert cursor.fetchone()[0] == 6 assert cursor.fetchone()[0] == 7
# Verify title column exists and is NULL for existing sessions # Verify title column exists and is NULL for existing sessions
session = migrated_db.get_session("existing") session = migrated_db.get_session("existing")