mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat: sort tool search results by score and add corresponding unit test
This commit is contained in:
parent
b867171291
commit
9bdfcd1b93
2 changed files with 70 additions and 3 deletions
|
|
@ -509,19 +509,24 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||||
result = resp.get("result", {})
|
result = resp.get("result", {})
|
||||||
|
|
||||||
# Format results for the model — keep it concise
|
# Format results for the model — keep it concise
|
||||||
formatted = []
|
scored_entries = []
|
||||||
for ctx_type in ("memories", "resources", "skills"):
|
for ctx_type in ("memories", "resources", "skills"):
|
||||||
items = result.get(ctx_type, [])
|
items = result.get(ctx_type, [])
|
||||||
for item in items:
|
for item in items:
|
||||||
|
raw_score = item.get("score")
|
||||||
|
sort_score = raw_score if raw_score is not None else 0.0
|
||||||
entry = {
|
entry = {
|
||||||
"uri": item.get("uri", ""),
|
"uri": item.get("uri", ""),
|
||||||
"type": ctx_type.rstrip("s"),
|
"type": ctx_type.rstrip("s"),
|
||||||
"score": round(item.get("score", 0), 3),
|
"score": round(raw_score, 3) if raw_score is not None else 0.0,
|
||||||
"abstract": item.get("abstract", ""),
|
"abstract": item.get("abstract", ""),
|
||||||
}
|
}
|
||||||
if item.get("relations"):
|
if item.get("relations"):
|
||||||
entry["related"] = [r.get("uri") for r in item["relations"][:3]]
|
entry["related"] = [r.get("uri") for r in item["relations"][:3]]
|
||||||
formatted.append(entry)
|
scored_entries.append((sort_score, entry))
|
||||||
|
|
||||||
|
scored_entries.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
formatted = [entry for _, entry in scored_entries]
|
||||||
|
|
||||||
return json.dumps({
|
return json.dumps({
|
||||||
"results": formatted,
|
"results": formatted,
|
||||||
|
|
|
||||||
62
tests/plugins/memory/test_openviking_provider.py
Normal file
62
tests/plugins/memory/test_openviking_provider.py
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from plugins.memory.openviking import OpenVikingMemoryProvider
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_search_sorts_by_raw_score_across_buckets():
|
||||||
|
provider = OpenVikingMemoryProvider()
|
||||||
|
provider._client = MagicMock()
|
||||||
|
provider._client.post.return_value = {
|
||||||
|
"result": {
|
||||||
|
"memories": [
|
||||||
|
{"uri": "viking://memories/1", "score": 0.9003, "abstract": "memory result"},
|
||||||
|
],
|
||||||
|
"resources": [
|
||||||
|
{"uri": "viking://resources/1", "score": 0.9004, "abstract": "resource result"},
|
||||||
|
],
|
||||||
|
"skills": [
|
||||||
|
{"uri": "viking://skills/1", "score": 0.8999, "abstract": "skill result"},
|
||||||
|
],
|
||||||
|
"total": 3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = json.loads(provider._tool_search({"query": "ranking"}))
|
||||||
|
|
||||||
|
assert [entry["uri"] for entry in result["results"]] == [
|
||||||
|
"viking://resources/1",
|
||||||
|
"viking://memories/1",
|
||||||
|
"viking://skills/1",
|
||||||
|
]
|
||||||
|
assert [entry["score"] for entry in result["results"]] == [0.9, 0.9, 0.9]
|
||||||
|
assert result["total"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_search_sorts_missing_raw_score_after_negative_scores():
|
||||||
|
provider = OpenVikingMemoryProvider()
|
||||||
|
provider._client = MagicMock()
|
||||||
|
provider._client.post.return_value = {
|
||||||
|
"result": {
|
||||||
|
"memories": [
|
||||||
|
{"uri": "viking://memories/missing", "abstract": "missing score"},
|
||||||
|
],
|
||||||
|
"resources": [
|
||||||
|
{"uri": "viking://resources/negative", "score": -0.25, "abstract": "negative score"},
|
||||||
|
],
|
||||||
|
"skills": [
|
||||||
|
{"uri": "viking://skills/positive", "score": 0.1, "abstract": "positive score"},
|
||||||
|
],
|
||||||
|
"total": 3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = json.loads(provider._tool_search({"query": "ranking"}))
|
||||||
|
|
||||||
|
assert [entry["uri"] for entry in result["results"]] == [
|
||||||
|
"viking://skills/positive",
|
||||||
|
"viking://memories/missing",
|
||||||
|
"viking://resources/negative",
|
||||||
|
]
|
||||||
|
assert [entry["score"] for entry in result["results"]] == [0.1, 0.0, -0.25]
|
||||||
|
assert result["total"] == 3
|
||||||
Loading…
Add table
Add a link
Reference in a new issue