mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix: use ceiling division for token estimation, deduplicate inline formula
Switch estimate_tokens_rough(), estimate_messages_tokens_rough(), and estimate_request_tokens_rough() from floor division (len // 4) to ceiling division ((len + 3) // 4). Short texts (1-3 chars) previously estimated as 0 tokens, causing the compressor and pre-flight checks to systematically undercount when many short tool results are present. Also replaced the inline duplicate formula in run_conversation() (total_chars // 4) with a call to the shared estimate_messages_tokens_rough() function. Updated 4 tests that hardcoded floor-division expected values. Related: issue #6217, PR #6629
This commit is contained in:
parent
6d272ba477
commit
5c2ecdec49
3 changed files with 19 additions and 11 deletions
|
|
@ -1045,16 +1045,21 @@ def get_model_context_length(
|
|||
|
||||
|
||||
def estimate_tokens_rough(text: str) -> int:
|
||||
"""Rough token estimate (~4 chars/token) for pre-flight checks."""
|
||||
"""Rough token estimate (~4 chars/token) for pre-flight checks.
|
||||
|
||||
Uses ceiling division so short texts (1-3 chars) never estimate as
|
||||
0 tokens, which would cause the compressor and pre-flight checks to
|
||||
systematically undercount when many short tool results are present.
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
return len(text) // 4
|
||||
return (len(text) + 3) // 4
|
||||
|
||||
|
||||
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
|
||||
"""Rough token estimate for a message list (pre-flight only)."""
|
||||
total_chars = sum(len(str(msg)) for msg in messages)
|
||||
return total_chars // 4
|
||||
return (total_chars + 3) // 4
|
||||
|
||||
|
||||
def estimate_request_tokens_rough(
|
||||
|
|
@ -1077,4 +1082,4 @@ def estimate_request_tokens_rough(
|
|||
total_chars += sum(len(str(msg)) for msg in messages)
|
||||
if tools:
|
||||
total_chars += len(str(tools))
|
||||
return total_chars // 4
|
||||
return (total_chars + 3) // 4
|
||||
|
|
|
|||
|
|
@ -8049,7 +8049,7 @@ class AIAgent:
|
|||
|
||||
# Calculate approximate request size for logging
|
||||
total_chars = sum(len(str(msg)) for msg in api_messages)
|
||||
approx_tokens = total_chars // 4 # Rough estimate: 4 chars per token
|
||||
approx_tokens = estimate_messages_tokens_rough(api_messages)
|
||||
|
||||
# Thinking spinner for quiet mode (animated during API call)
|
||||
thinking_spinner = None
|
||||
|
|
|
|||
|
|
@ -50,7 +50,8 @@ class TestEstimateTokensRough:
|
|||
assert estimate_tokens_rough("a" * 400) == 100
|
||||
|
||||
def test_short_text(self):
|
||||
assert estimate_tokens_rough("hello") == 1
|
||||
# "hello" = 5 chars → ceil(5/4) = 2
|
||||
assert estimate_tokens_rough("hello") == 2
|
||||
|
||||
def test_proportional(self):
|
||||
short = estimate_tokens_rough("hello world")
|
||||
|
|
@ -68,10 +69,11 @@ class TestEstimateMessagesTokensRough:
|
|||
assert estimate_messages_tokens_rough([]) == 0
|
||||
|
||||
def test_single_message_concrete_value(self):
|
||||
"""Verify against known str(msg) length."""
|
||||
"""Verify against known str(msg) length (ceiling division)."""
|
||||
msg = {"role": "user", "content": "a" * 400}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
expected = len(str(msg)) // 4
|
||||
n = len(str(msg))
|
||||
expected = (n + 3) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_multiple_messages_additive(self):
|
||||
|
|
@ -80,7 +82,8 @@ class TestEstimateMessagesTokensRough:
|
|||
{"role": "assistant", "content": "Hi there, how can I help?"},
|
||||
]
|
||||
result = estimate_messages_tokens_rough(msgs)
|
||||
expected = sum(len(str(m)) for m in msgs) // 4
|
||||
n = sum(len(str(m)) for m in msgs)
|
||||
expected = (n + 3) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_tool_call_message(self):
|
||||
|
|
@ -89,7 +92,7 @@ class TestEstimateMessagesTokensRough:
|
|||
"tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result > 0
|
||||
assert result == len(str(msg)) // 4
|
||||
assert result == (len(str(msg)) + 3) // 4
|
||||
|
||||
def test_message_with_list_content(self):
|
||||
"""Vision messages with multimodal content arrays."""
|
||||
|
|
@ -98,7 +101,7 @@ class TestEstimateMessagesTokensRough:
|
|||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
|
||||
]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result == len(str(msg)) // 4
|
||||
assert result == (len(str(msg)) + 3) // 4
|
||||
|
||||
|
||||
# =========================================================================
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue