feat: sort tool search results by score and add corresponding unit test

This commit is contained in:
Disaster-Terminator 2026-04-09 22:10:27 +08:00 committed by Teknium
parent b867171291
commit 9bdfcd1b93
2 changed files with 70 additions and 3 deletions

View file

@ -509,19 +509,24 @@ class OpenVikingMemoryProvider(MemoryProvider):
result = resp.get("result", {}) result = resp.get("result", {})
# Format results for the model — keep it concise # Format results for the model — keep it concise
formatted = [] scored_entries = []
for ctx_type in ("memories", "resources", "skills"): for ctx_type in ("memories", "resources", "skills"):
items = result.get(ctx_type, []) items = result.get(ctx_type, [])
for item in items: for item in items:
raw_score = item.get("score")
sort_score = raw_score if raw_score is not None else 0.0
entry = { entry = {
"uri": item.get("uri", ""), "uri": item.get("uri", ""),
"type": ctx_type.rstrip("s"), "type": ctx_type.rstrip("s"),
"score": round(item.get("score", 0), 3), "score": round(raw_score, 3) if raw_score is not None else 0.0,
"abstract": item.get("abstract", ""), "abstract": item.get("abstract", ""),
} }
if item.get("relations"): if item.get("relations"):
entry["related"] = [r.get("uri") for r in item["relations"][:3]] entry["related"] = [r.get("uri") for r in item["relations"][:3]]
formatted.append(entry) scored_entries.append((sort_score, entry))
scored_entries.sort(key=lambda x: x[0], reverse=True)
formatted = [entry for _, entry in scored_entries]
return json.dumps({ return json.dumps({
"results": formatted, "results": formatted,

View file

@ -0,0 +1,62 @@
import json
from unittest.mock import MagicMock
from plugins.memory.openviking import OpenVikingMemoryProvider
def test_tool_search_sorts_by_raw_score_across_buckets():
provider = OpenVikingMemoryProvider()
provider._client = MagicMock()
provider._client.post.return_value = {
"result": {
"memories": [
{"uri": "viking://memories/1", "score": 0.9003, "abstract": "memory result"},
],
"resources": [
{"uri": "viking://resources/1", "score": 0.9004, "abstract": "resource result"},
],
"skills": [
{"uri": "viking://skills/1", "score": 0.8999, "abstract": "skill result"},
],
"total": 3,
}
}
result = json.loads(provider._tool_search({"query": "ranking"}))
assert [entry["uri"] for entry in result["results"]] == [
"viking://resources/1",
"viking://memories/1",
"viking://skills/1",
]
assert [entry["score"] for entry in result["results"]] == [0.9, 0.9, 0.9]
assert result["total"] == 3
def test_tool_search_sorts_missing_raw_score_after_negative_scores():
provider = OpenVikingMemoryProvider()
provider._client = MagicMock()
provider._client.post.return_value = {
"result": {
"memories": [
{"uri": "viking://memories/missing", "abstract": "missing score"},
],
"resources": [
{"uri": "viking://resources/negative", "score": -0.25, "abstract": "negative score"},
],
"skills": [
{"uri": "viking://skills/positive", "score": 0.1, "abstract": "positive score"},
],
"total": 3,
}
}
result = json.loads(provider._tool_search({"query": "ranking"}))
assert [entry["uri"] for entry in result["results"]] == [
"viking://skills/positive",
"viking://memories/missing",
"viking://resources/negative",
]
assert [entry["score"] for entry in result["results"]] == [0.1, 0.0, -0.25]
assert result["total"] == 3