fix: deep review — prefix matching, tool_calls extraction, query perf, serialization

Issues found and fixed during deep code path review:

1. CRITICAL: Prefix matching returned wrong prices for dated model names
   - 'gpt-4o-mini-2024-07-18' matched gpt-4o ($2.50) instead of gpt-4o-mini ($0.15)
   - Same for o3-mini→o3 (9x), gpt-4.1-mini→gpt-4.1 (5x), gpt-4.1-nano→gpt-4.1 (20x)
   - Fix: use longest-match-wins strategy instead of first-match
   - Removed dangerous key.startswith(bare) reverse matching

2. CRITICAL: Top Tools section was empty for CLI sessions
   - run_agent.py doesn't set tool_name on tool response messages (a pre-existing gap, not introduced by this change)
   - Insights now also extracts tool names from tool_calls JSON on assistant
     messages, which IS populated for all sessions
   - Uses max() merge strategy to avoid double-counting between sources

3. SELECT * replaced with explicit column list
   - Skips system_prompt and model_config blobs (can be thousands of chars)
   - Reduces memory and I/O for large session counts

4. Sets in overview dict converted to sorted lists
   - models_with_pricing / models_without_pricing were Python sets
   - Sets aren't JSON-serializable — would crash json.dumps()

5. Negative duration guard
   - end > start check prevents negative durations from clock drift

6. Model breakdown sort fallback
   - When all tokens are 0, now sorts by session count instead of arbitrary order

7. Removed unused timedelta import

Added 6 new tests: dated model pricing (4), tool_calls JSON extraction,
JSON serialization safety. Total: 69 tests.
This commit is contained in:
teknium1 2026-03-06 14:50:57 -08:00
parent 75f523f5c0
commit 585f8528b2
2 changed files with 169 additions and 19 deletions

View file

@ -16,9 +16,10 @@ Usage:
print(engine.format_terminal(report)) print(engine.format_terminal(report))
""" """
import json
import time import time
from collections import Counter, defaultdict from collections import Counter, defaultdict
from datetime import datetime, timedelta from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
# ========================================================================= # =========================================================================
@ -82,12 +83,18 @@ def _get_pricing(model_name: str) -> Dict[str, float]:
if bare in MODEL_PRICING: if bare in MODEL_PRICING:
return MODEL_PRICING[bare] return MODEL_PRICING[bare]
# Fuzzy prefix match # Fuzzy prefix match — prefer the LONGEST matching key to avoid
# e.g. "gpt-4o" matching before "gpt-4o-mini" for "gpt-4o-mini-2024-07-18"
best_match = None
best_len = 0
for key, price in MODEL_PRICING.items(): for key, price in MODEL_PRICING.items():
if bare.startswith(key) or key.startswith(bare): if bare.startswith(key) and len(key) > best_len:
return price best_match = price
best_len = len(key)
if best_match:
return best_match
# Keyword heuristics # Keyword heuristics (checked in most-specific-first order)
if "opus" in bare: if "opus" in bare:
return {"input": 15.00, "output": 75.00} return {"input": 15.00, "output": 75.00}
if "sonnet" in bare: if "sonnet" in bare:
@ -211,26 +218,39 @@ class InsightsEngine:
# Data gathering (SQL queries) # Data gathering (SQL queries)
# ========================================================================= # =========================================================================
# Columns we actually need (skip system_prompt, model_config blobs)
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
"message_count, tool_call_count, input_tokens, output_tokens")
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]: def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
"""Fetch sessions within the time window.""" """Fetch sessions within the time window."""
if source: if source:
cursor = self._conn.execute( cursor = self._conn.execute(
"""SELECT * FROM sessions f"""SELECT {self._SESSION_COLS} FROM sessions
WHERE started_at >= ? AND source = ? WHERE started_at >= ? AND source = ?
ORDER BY started_at DESC""", ORDER BY started_at DESC""",
(cutoff, source), (cutoff, source),
) )
else: else:
cursor = self._conn.execute( cursor = self._conn.execute(
"""SELECT * FROM sessions f"""SELECT {self._SESSION_COLS} FROM sessions
WHERE started_at >= ? WHERE started_at >= ?
ORDER BY started_at DESC""", ORDER BY started_at DESC""",
(cutoff,), (cutoff,),
) )
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]: def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
"""Get tool call counts from messages.""" """Get tool call counts from messages.
Uses two sources:
1. tool_name column on 'tool' role messages (set by gateway)
2. tool_calls JSON on 'assistant' role messages (covers CLI where
tool_name is not populated on tool responses)
"""
tool_counts = Counter()
# Source 1: explicit tool_name on tool response messages
if source: if source:
cursor = self._conn.execute( cursor = self._conn.execute(
"""SELECT m.tool_name, COUNT(*) as count """SELECT m.tool_name, COUNT(*) as count
@ -253,7 +273,64 @@ class InsightsEngine:
ORDER BY count DESC""", ORDER BY count DESC""",
(cutoff,), (cutoff,),
) )
return [dict(row) for row in cursor.fetchall()] for row in cursor.fetchall():
tool_counts[row["tool_name"]] += row["count"]
# Source 2: extract from tool_calls JSON on assistant messages
# (covers CLI sessions where tool_name is NULL on tool responses)
if source:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff, source),
)
else:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff,),
)
tool_calls_counts = Counter()
for row in cursor2.fetchall():
try:
calls = row["tool_calls"]
if isinstance(calls, str):
calls = json.loads(calls)
if isinstance(calls, list):
for call in calls:
func = call.get("function", {}) if isinstance(call, dict) else {}
name = func.get("name")
if name:
tool_calls_counts[name] += 1
except (json.JSONDecodeError, TypeError, AttributeError):
continue
# Merge: prefer tool_name source, supplement with tool_calls source
# for tools not already counted
if not tool_counts and tool_calls_counts:
# No tool_name data at all — use tool_calls exclusively
tool_counts = tool_calls_counts
elif tool_counts and tool_calls_counts:
# Both sources have data — use whichever has the higher count per tool
# (they may overlap, so take the max to avoid double-counting)
all_tools = set(tool_counts) | set(tool_calls_counts)
merged = Counter()
for tool in all_tools:
merged[tool] = max(tool_counts.get(tool, 0), tool_calls_counts.get(tool, 0))
tool_counts = merged
# Convert to the expected format
return [
{"tool_name": name, "count": count}
for name, count in tool_counts.most_common()
]
def _get_message_stats(self, cutoff: float, source: str = None) -> Dict: def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
"""Get aggregate message statistics.""" """Get aggregate message statistics."""
@ -314,12 +391,12 @@ class InsightsEngine:
else: else:
models_without_pricing.add(display) models_without_pricing.add(display)
# Session duration stats # Session duration stats (guard against negative durations from clock drift)
durations = [] durations = []
for s in sessions: for s in sessions:
start = s.get("started_at") start = s.get("started_at")
end = s.get("ended_at") end = s.get("ended_at")
if start and end: if start and end and end > start:
durations.append(end - start) durations.append(end - start)
total_hours = sum(durations) / 3600 if durations else 0 total_hours = sum(durations) / 3600 if durations else 0
@ -347,8 +424,8 @@ class InsightsEngine:
"tool_messages": message_stats.get("tool_messages") or 0, "tool_messages": message_stats.get("tool_messages") or 0,
"date_range_start": date_range_start, "date_range_start": date_range_start,
"date_range_end": date_range_end, "date_range_end": date_range_end,
"models_with_pricing": models_with_pricing, "models_with_pricing": sorted(models_with_pricing),
"models_without_pricing": models_without_pricing, "models_without_pricing": sorted(models_without_pricing),
} }
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]: def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
@ -377,7 +454,8 @@ class InsightsEngine:
{"model": model, **data} {"model": model, **data}
for model, data in model_data.items() for model, data in model_data.items()
] ]
result.sort(key=lambda x: x["total_tokens"], reverse=True) # Sort by tokens first, fall back to session count when tokens are 0
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
return result return result
def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]: def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]:

View file

@ -176,6 +176,26 @@ class TestPricing:
pricing = _get_pricing("gemini-3.0-ultra") pricing = _get_pricing("gemini-3.0-ultra")
assert pricing["input"] == 0.15 assert pricing["input"] == 0.15
def test_dated_model_gpt4o_mini(self):
    """gpt-4o-mini-2024-07-18 should match gpt-4o-mini, NOT gpt-4o."""
    p = _get_pricing("gpt-4o-mini-2024-07-18")
    # Longest-prefix match must pick gpt-4o-mini ($0.15), not gpt-4o ($2.50).
    assert p["input"] == 0.15
def test_dated_model_o3_mini(self):
    """o3-mini-2025-01-31 should match o3-mini, NOT o3."""
    p = _get_pricing("o3-mini-2025-01-31")
    # Longest-prefix match must pick o3-mini ($1.10), not o3 ($10.00).
    assert p["input"] == 1.10
def test_dated_model_gpt41_mini(self):
    """gpt-4.1-mini-2025-04-14 should match gpt-4.1-mini, NOT gpt-4.1."""
    p = _get_pricing("gpt-4.1-mini-2025-04-14")
    # Longest-prefix match must pick gpt-4.1-mini ($0.40), not gpt-4.1 ($2.00).
    assert p["input"] == 0.40
def test_dated_model_gpt41_nano(self):
    """gpt-4.1-nano-2025-04-14 should match gpt-4.1-nano, NOT gpt-4.1."""
    p = _get_pricing("gpt-4.1-nano-2025-04-14")
    # Longest-prefix match must pick gpt-4.1-nano ($0.10), not gpt-4.1 ($2.00).
    assert p["input"] == 0.10
class TestHasKnownPricing: class TestHasKnownPricing:
def test_known_commercial_model(self): def test_known_commercial_model(self):
@ -585,6 +605,58 @@ class TestEdgeCases:
assert custom["cost"] == 0.0 assert custom["cost"] == 0.0
assert custom["has_pricing"] is False assert custom["has_pricing"] is False
def test_tool_usage_from_tool_calls_json(self, db):
    """Tool usage should be extracted from tool_calls JSON when tool_name is NULL.

    Mirrors what the CLI writes: assistant messages carry a populated
    tool_calls list, while the matching 'tool' role responses have a
    tool_call_id but no tool_name. The insights engine must still report
    per-tool counts by parsing the tool_calls JSON.
    """
    # Fix: removed unused `import json as _json` from the original body.
    db.create_session(session_id="s1", source="cli", model="test")
    # (assistant text, call id, tool name, tool response text) — one
    # assistant/tool pair per entry, exactly as the CLI produces them.
    exchanges = [
        ("Let me search", "call_1", "search_files", "found results"),
        ("Now reading", "call_2", "read_file", "file content"),
        ("And searching again", "call_3", "search_files", "more results"),
    ]
    for text, call_id, tool, response in exchanges:
        db.append_message("s1", role="assistant", content=text,
                          tool_calls=[{"id": call_id, "type": "function",
                                       "function": {"name": tool, "arguments": "{}"}}])
        # Tool response WITHOUT tool_name (this is the CLI bug)
        db.append_message("s1", role="tool", content=response,
                          tool_call_id=call_id)
    db._conn.commit()

    engine = InsightsEngine(db)
    report = engine.generate(days=30)
    tools = report["tools"]
    # Should find tools from tool_calls JSON even though tool_name is NULL
    tool_names = [t["tool"] for t in tools]
    assert "search_files" in tool_names
    assert "read_file" in tool_names
    # search_files was called twice
    sf = next(t for t in tools if t["tool"] == "search_files")
    assert sf["count"] == 2
def test_overview_pricing_sets_are_lists(self, db):
    """models_with/without_pricing should be JSON-serializable lists."""
    import json as _json
    db.create_session(session_id="s1", source="cli", model="gpt-4o")
    db.create_session(session_id="s2", source="cli", model="my-custom")
    db._conn.commit()

    overview = InsightsEngine(db).generate(days=30)["overview"]
    for key in ("models_with_pricing", "models_without_pricing"):
        assert isinstance(overview[key], list)
    # Round-tripping through json.dumps would raise if any sets remained.
    _json.dumps(overview)
def test_mixed_commercial_and_custom_models(self, db): def test_mixed_commercial_and_custom_models(self, db):
"""Mix of commercial and custom models: only commercial ones get costs.""" """Mix of commercial and custom models: only commercial ones get costs."""
db.create_session(session_id="s1", source="cli", model="gpt-4o") db.create_session(session_id="s1", source="cli", model="gpt-4o")
@ -599,7 +671,7 @@ class TestEdgeCases:
# Cost should only come from gpt-4o, not from the custom model # Cost should only come from gpt-4o, not from the custom model
overview = report["overview"] overview = report["overview"]
assert overview["estimated_cost"] > 0 assert overview["estimated_cost"] > 0
assert "gpt-4o" in overview["models_with_pricing"] assert "gpt-4o" in overview["models_with_pricing"] # list now, not set
assert "my-local-llama" in overview["models_without_pricing"] assert "my-local-llama" in overview["models_without_pricing"]
# Verify individual model entries # Verify individual model entries