mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-05 02:31:47 +00:00
Merge remote-tracking branch 'origin/main' into hermes/hermes-1f7bfa9e
# Conflicts: # cron/scheduler.py # tools/send_message_tool.py
This commit is contained in:
commit
e7fc6450fc
99 changed files with 9609 additions and 1075 deletions
|
|
@ -74,6 +74,26 @@ class FakeBot:
|
|||
return None
|
||||
|
||||
|
||||
class SlowSyncTree(FakeTree):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.started = asyncio.Event()
|
||||
self.allow_finish = asyncio.Event()
|
||||
|
||||
async def _slow_sync():
|
||||
self.started.set()
|
||||
await self.allow_finish.wait()
|
||||
return []
|
||||
|
||||
self.sync = AsyncMock(side_effect=_slow_sync)
|
||||
|
||||
|
||||
class SlowSyncBot(FakeBot):
|
||||
def __init__(self, *, intents, proxy=None):
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
self.tree = SlowSyncTree()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("allowed_users", "expected_members_intent"),
|
||||
|
|
@ -138,3 +158,36 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
|||
assert ok is False
|
||||
assert released == [("discord-bot-token", "test-token")]
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_does_not_wait_for_slash_sync(monkeypatch):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
|
||||
|
||||
monkeypatch.setattr("gateway.status.acquire_scoped_lock", lambda scope, identity, metadata=None: (True, None))
|
||||
monkeypatch.setattr("gateway.status.release_scoped_lock", lambda scope, identity: None)
|
||||
|
||||
intents = SimpleNamespace(message_content=False, dm_messages=False, guild_messages=False, members=False, voice_states=False)
|
||||
monkeypatch.setattr(discord_platform.Intents, "default", lambda: intents)
|
||||
|
||||
created = {}
|
||||
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None):
|
||||
bot = SlowSyncBot(intents=intents, proxy=proxy)
|
||||
created["bot"] = bot
|
||||
return bot
|
||||
|
||||
monkeypatch.setattr(discord_platform.commands, "Bot", fake_bot_factory)
|
||||
monkeypatch.setattr(adapter, "_resolve_allowed_usernames", AsyncMock())
|
||||
|
||||
ok = await asyncio.wait_for(adapter.connect(), timeout=1.0)
|
||||
|
||||
assert ok is True
|
||||
assert adapter._ready_event.is_set()
|
||||
|
||||
await asyncio.wait_for(created["bot"].tree.started.wait(), timeout=1.0)
|
||||
assert created["bot"].tree.sync.await_count == 1
|
||||
|
||||
created["bot"].tree.allow_finish.set()
|
||||
await asyncio.sleep(0)
|
||||
await adapter.disconnect()
|
||||
|
|
|
|||
355
tests/gateway/test_display_config.py
Normal file
355
tests/gateway/test_display_config.py
Normal file
|
|
@ -0,0 +1,355 @@
|
|||
"""Tests for gateway.display_config — per-platform display/verbosity resolver."""
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolver: resolution order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveDisplaySetting:
|
||||
"""resolve_display_setting() resolves with correct priority."""
|
||||
|
||||
def test_explicit_platform_override_wins(self):
|
||||
"""display.platforms.<plat>.<key> takes top priority."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"platforms": {
|
||||
"telegram": {"tool_progress": "verbose"},
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose"
|
||||
|
||||
def test_global_setting_when_no_platform_override(self):
|
||||
"""Falls back to display.<key> when no platform override exists."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "new",
|
||||
"platforms": {},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "new"
|
||||
|
||||
def test_platform_default_when_no_user_config(self):
|
||||
"""Falls back to built-in platform default."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
# Empty config — should get built-in defaults
|
||||
config = {}
|
||||
# Telegram defaults to tier_high → "all"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
# Email defaults to tier_minimal → "off"
|
||||
assert resolve_display_setting(config, "email", "tool_progress") == "off"
|
||||
|
||||
def test_global_default_for_unknown_platform(self):
|
||||
"""Unknown platforms get the global defaults."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# Unknown platform, no config → global default "all"
|
||||
assert resolve_display_setting(config, "unknown_platform", "tool_progress") == "all"
|
||||
|
||||
def test_fallback_parameter_used_last(self):
|
||||
"""Explicit fallback is used when nothing else matches."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# "nonexistent_key" isn't in any defaults
|
||||
result = resolve_display_setting(config, "telegram", "nonexistent_key", "my_fallback")
|
||||
assert result == "my_fallback"
|
||||
|
||||
def test_platform_override_only_affects_that_platform(self):
|
||||
"""Other platforms are unaffected by a specific platform override."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"platforms": {
|
||||
"slack": {"tool_progress": "off"},
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "slack", "tool_progress") == "off"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward compatibility: tool_progress_overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackwardCompat:
|
||||
"""Legacy tool_progress_overrides is still respected as a fallback."""
|
||||
|
||||
def test_legacy_overrides_read(self):
|
||||
"""tool_progress_overrides is read when no platforms entry exists."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"tool_progress_overrides": {
|
||||
"signal": "off",
|
||||
"telegram": "verbose",
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "signal", "tool_progress") == "off"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose"
|
||||
|
||||
def test_new_platforms_takes_precedence_over_legacy(self):
|
||||
"""display.platforms beats tool_progress_overrides."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"tool_progress_overrides": {"telegram": "verbose"},
|
||||
"platforms": {"telegram": {"tool_progress": "new"}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "new"
|
||||
|
||||
def test_legacy_overrides_only_for_tool_progress(self):
|
||||
"""Legacy overrides don't affect other settings."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress_overrides": {"telegram": "verbose"},
|
||||
}
|
||||
}
|
||||
# show_reasoning should NOT read from tool_progress_overrides
|
||||
assert resolve_display_setting(config, "telegram", "show_reasoning") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestYAMLNormalisation:
|
||||
"""YAML 1.1 quirks (bare off → False, on → True) are handled."""
|
||||
|
||||
def test_tool_progress_false_normalised_to_off(self):
|
||||
"""YAML's bare `off` parses as False — normalised to 'off' string."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"tool_progress": False}}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "off"
|
||||
|
||||
def test_tool_progress_true_normalised_to_all(self):
|
||||
"""YAML's bare `on` parses as True — normalised to 'all'."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"tool_progress": True}}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
|
||||
def test_show_reasoning_string_true(self):
|
||||
"""String 'true' is normalised to bool True."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"telegram": {"show_reasoning": "true"}}}}
|
||||
assert resolve_display_setting(config, "telegram", "show_reasoning") is True
|
||||
|
||||
def test_tool_preview_length_string(self):
|
||||
"""String numbers are normalised to int."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"slack": {"tool_preview_length": "80"}}}}
|
||||
assert resolve_display_setting(config, "slack", "tool_preview_length") == 80
|
||||
|
||||
def test_platform_override_false_tool_progress(self):
|
||||
"""Per-platform bare off → normalised."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"slack": {"tool_progress": False}}}}
|
||||
assert resolve_display_setting(config, "slack", "tool_progress") == "off"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in platform defaults (tier system)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPlatformDefaults:
|
||||
"""Built-in defaults reflect platform capability tiers."""
|
||||
|
||||
def test_high_tier_platforms(self):
|
||||
"""Telegram and Discord default to 'all' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("telegram", "discord"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "all", plat
|
||||
|
||||
def test_medium_tier_platforms(self):
|
||||
"""Slack, Mattermost, Matrix default to 'new' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("slack", "mattermost", "matrix", "feishu"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "new", plat
|
||||
|
||||
def test_low_tier_platforms(self):
|
||||
"""Signal, WhatsApp, etc. default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("signal", "whatsapp", "bluebubbles", "weixin", "wecom", "dingtalk"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_minimal_tier_platforms(self):
|
||||
"""Email, SMS, webhook default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("email", "sms", "webhook", "homeassistant"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_low_tier_streaming_defaults_to_false(self):
|
||||
"""Low-tier platforms default streaming to False."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
assert resolve_display_setting({}, "signal", "streaming") is False
|
||||
assert resolve_display_setting({}, "email", "streaming") is False
|
||||
|
||||
def test_high_tier_streaming_defaults_to_none(self):
|
||||
"""High-tier platforms default streaming to None (follow global)."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
assert resolve_display_setting({}, "telegram", "streaming") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_display / get_platform_defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHelpers:
|
||||
"""Helper functions return correct composite results."""
|
||||
|
||||
def test_get_effective_display_merges_correctly(self):
|
||||
from gateway.display_config import get_effective_display
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "new",
|
||||
"show_reasoning": True,
|
||||
"platforms": {
|
||||
"telegram": {"tool_progress": "verbose"},
|
||||
},
|
||||
}
|
||||
}
|
||||
eff = get_effective_display(config, "telegram")
|
||||
assert eff["tool_progress"] == "verbose" # platform override
|
||||
assert eff["show_reasoning"] is True # global
|
||||
assert "tool_preview_length" in eff # default filled in
|
||||
|
||||
def test_get_platform_defaults_returns_dict(self):
|
||||
from gateway.display_config import get_platform_defaults
|
||||
|
||||
defaults = get_platform_defaults("telegram")
|
||||
assert "tool_progress" in defaults
|
||||
assert "show_reasoning" in defaults
|
||||
# Returns a new dict (not the shared tier dict)
|
||||
defaults["tool_progress"] = "changed"
|
||||
assert get_platform_defaults("telegram")["tool_progress"] != "changed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config migration: tool_progress_overrides → display.platforms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigMigration:
|
||||
"""Version 16 migration moves tool_progress_overrides into display.platforms."""
|
||||
|
||||
def test_migration_creates_platforms_entries(self, tmp_path, monkeypatch):
|
||||
"""Old overrides are migrated into display.platforms.<plat>.tool_progress."""
|
||||
import yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config = {
|
||||
"_config_version": 15,
|
||||
"display": {
|
||||
"tool_progress_overrides": {
|
||||
"signal": "off",
|
||||
"telegram": "all",
|
||||
},
|
||||
},
|
||||
}
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
# Re-import to pick up the new HERMES_HOME
|
||||
import importlib
|
||||
import hermes_cli.config as cfg_mod
|
||||
importlib.reload(cfg_mod)
|
||||
|
||||
result = cfg_mod.migrate_config(interactive=False, quiet=True)
|
||||
# Re-read config
|
||||
updated = yaml.safe_load(config_path.read_text())
|
||||
platforms = updated.get("display", {}).get("platforms", {})
|
||||
assert platforms.get("signal", {}).get("tool_progress") == "off"
|
||||
assert platforms.get("telegram", {}).get("tool_progress") == "all"
|
||||
|
||||
def test_migration_preserves_existing_platforms_entries(self, tmp_path, monkeypatch):
|
||||
"""Existing display.platforms entries are NOT overwritten by migration."""
|
||||
import yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config = {
|
||||
"_config_version": 15,
|
||||
"display": {
|
||||
"tool_progress_overrides": {"telegram": "off"},
|
||||
"platforms": {"telegram": {"tool_progress": "verbose"}},
|
||||
},
|
||||
}
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
import importlib
|
||||
import hermes_cli.config as cfg_mod
|
||||
importlib.reload(cfg_mod)
|
||||
|
||||
cfg_mod.migrate_config(interactive=False, quiet=True)
|
||||
updated = yaml.safe_load(config_path.read_text())
|
||||
# Existing "verbose" should NOT be overwritten by legacy "off"
|
||||
assert updated["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming per-platform (None = follow global)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStreamingPerPlatform:
|
||||
"""Streaming per-platform override semantics."""
|
||||
|
||||
def test_none_means_follow_global(self):
|
||||
"""When streaming is None, the caller should use global config."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# Telegram has no streaming override in defaults → None
|
||||
result = resolve_display_setting(config, "telegram", "streaming")
|
||||
assert result is None # caller should check global StreamingConfig
|
||||
|
||||
def test_explicit_false_disables(self):
|
||||
"""Explicit False disables streaming for that platform."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"platforms": {"telegram": {"streaming": False}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "streaming") is False
|
||||
|
||||
def test_explicit_true_enables(self):
|
||||
"""Explicit True enables streaming for that platform."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"platforms": {"email": {"streaming": True}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "email", "streaming") is True
|
||||
|
|
@ -157,12 +157,44 @@ def _make_fake_mautrix():
|
|||
mautrix_crypto_store = types.ModuleType("mautrix.crypto.store")
|
||||
|
||||
class MemoryCryptoStore:
|
||||
def __init__(self, account_id="", pickle_key=""):
|
||||
def __init__(self, account_id="", pickle_key=""): # noqa: S301
|
||||
self.account_id = account_id
|
||||
self.pickle_key = pickle_key
|
||||
|
||||
mautrix_crypto_store.MemoryCryptoStore = MemoryCryptoStore
|
||||
|
||||
# --- mautrix.crypto.store.asyncpg ---
|
||||
mautrix_crypto_store_asyncpg = types.ModuleType("mautrix.crypto.store.asyncpg")
|
||||
|
||||
class PgCryptoStore:
|
||||
upgrade_table = MagicMock()
|
||||
|
||||
def __init__(self, account_id="", pickle_key="", db=None): # noqa: S301
|
||||
self.account_id = account_id
|
||||
self.pickle_key = pickle_key
|
||||
self.db = db
|
||||
|
||||
async def open(self):
|
||||
pass
|
||||
|
||||
mautrix_crypto_store_asyncpg.PgCryptoStore = PgCryptoStore
|
||||
|
||||
# --- mautrix.util ---
|
||||
mautrix_util = types.ModuleType("mautrix.util")
|
||||
|
||||
# --- mautrix.util.async_db ---
|
||||
mautrix_util_async_db = types.ModuleType("mautrix.util.async_db")
|
||||
|
||||
class Database:
|
||||
@classmethod
|
||||
def create(cls, url, upgrade_table=None):
|
||||
db = MagicMock()
|
||||
db.start = AsyncMock()
|
||||
db.stop = AsyncMock()
|
||||
return db
|
||||
|
||||
mautrix_util_async_db.Database = Database
|
||||
|
||||
return {
|
||||
"mautrix": mautrix,
|
||||
"mautrix.api": mautrix_api,
|
||||
|
|
@ -171,6 +203,9 @@ def _make_fake_mautrix():
|
|||
"mautrix.client.state_store": mautrix_client_state_store,
|
||||
"mautrix.crypto": mautrix_crypto,
|
||||
"mautrix.crypto.store": mautrix_crypto_store,
|
||||
"mautrix.crypto.store.asyncpg": mautrix_crypto_store_asyncpg,
|
||||
"mautrix.util": mautrix_util,
|
||||
"mautrix.util.async_db": mautrix_util_async_db,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -740,6 +775,12 @@ class TestMatrixAccessTokenAuth:
|
|||
mock_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"DEV123": {
|
||||
"keys": {"ed25519:DEV123": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_access_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
|
|
@ -751,6 +792,8 @@ class TestMatrixAccessTokenAuth:
|
|||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
# Patch Client constructor to return our mock
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
|
|
@ -924,6 +967,12 @@ class TestMatrixDeviceId:
|
|||
mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="WHOAMI_DEV"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"MY_STABLE_DEVICE": {
|
||||
"keys": {"ed25519:MY_STABLE_DEVICE": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_access_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
|
|
@ -934,6 +983,8 @@ class TestMatrixDeviceId:
|
|||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm)
|
||||
|
|
@ -1030,8 +1081,8 @@ class TestMatrixDeviceIdConfig:
|
|||
|
||||
class TestMatrixSyncLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_loop_shares_keys_when_encryption_enabled(self):
|
||||
"""_sync_loop should call crypto.share_keys() after each sync."""
|
||||
async def test_sync_loop_dispatches_events_and_stores_token(self):
|
||||
"""_sync_loop should call handle_sync() and persist next_batch."""
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = True
|
||||
adapter._closing = False
|
||||
|
|
@ -1046,7 +1097,6 @@ class TestMatrixSyncLoop:
|
|||
return {"rooms": {"join": {"!room:example.org": {}}}, "next_batch": "s1234"}
|
||||
|
||||
mock_crypto = MagicMock()
|
||||
mock_crypto.share_keys = AsyncMock()
|
||||
|
||||
mock_sync_store = MagicMock()
|
||||
mock_sync_store.get_next_batch = AsyncMock(return_value=None)
|
||||
|
|
@ -1062,7 +1112,6 @@ class TestMatrixSyncLoop:
|
|||
await adapter._sync_loop()
|
||||
|
||||
fake_client.sync.assert_awaited_once()
|
||||
mock_crypto.share_keys.assert_awaited_once()
|
||||
fake_client.handle_sync.assert_called_once()
|
||||
mock_sync_store.put_next_batch.assert_awaited_once_with("s1234")
|
||||
|
||||
|
|
@ -1248,6 +1297,12 @@ class TestMatrixEncryptedEventHandler:
|
|||
mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"DEV123": {
|
||||
"keys": {"ed25519:DEV123": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
|
|
@ -1258,6 +1313,8 @@ class TestMatrixEncryptedEventHandler:
|
|||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.run import _dequeue_pending_event
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -79,6 +80,26 @@ class TestQueueMessageStorage:
|
|||
# Should be consumed (cleared)
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:voice"
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="voice-q1",
|
||||
media_urls=["/tmp/voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = _dequeue_pending_event(adapter, session_key)
|
||||
|
||||
assert retrieved is event
|
||||
assert retrieved.media_urls == ["/tmp/voice.ogg"]
|
||||
assert retrieved.media_types == ["audio/ogg"]
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_queue_does_not_set_interrupt_event(self):
|
||||
"""The whole point of /queue — no interrupt signal."""
|
||||
adapter = _StubAdapter()
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from types import SimpleNamespace
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.config import Platform, PlatformConfig, StreamingConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
|
|
@ -104,6 +104,11 @@ def _make_runner(adapter):
|
|||
runner._session_db = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
runner.config = SimpleNamespace(
|
||||
thread_sessions_per_user=False,
|
||||
group_sessions_per_user=False,
|
||||
stt_enabled=False,
|
||||
)
|
||||
return runner
|
||||
|
||||
|
||||
|
|
@ -118,6 +123,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
|||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
|
|
@ -144,7 +150,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
|||
assert adapter.sent == [
|
||||
{
|
||||
"chat_id": "-1001",
|
||||
"content": '⚙️ terminal: "pwd"',
|
||||
"content": '💻 terminal: "pwd"',
|
||||
"reply_to": None,
|
||||
"metadata": {"thread_id": "17585"},
|
||||
}
|
||||
|
|
@ -334,3 +340,238 @@ def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
|
|||
content = adapter.sent[0]["content"]
|
||||
# With a 200-char cap, the 165-char command should NOT be truncated
|
||||
assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"
|
||||
|
||||
|
||||
class CommentaryAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.stream_delta_callback = kwargs.get("stream_delta_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)
|
||||
time.sleep(0.1)
|
||||
if self.stream_delta_callback:
|
||||
self.stream_delta_callback("done")
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class PreviewedResponseAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("You're welcome.", already_streamed=False)
|
||||
return {
|
||||
"final_response": "You're welcome.",
|
||||
"response_previewed": True,
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class QueuedCommentaryAgent:
|
||||
calls = 0
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
type(self).calls += 1
|
||||
if type(self).calls == 1 and self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)
|
||||
return {
|
||||
"final_response": f"final response {type(self).calls}",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
async def _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
agent_cls,
|
||||
*,
|
||||
session_id,
|
||||
pending_text=None,
|
||||
config_data=None,
|
||||
):
|
||||
if config_data:
|
||||
import yaml
|
||||
|
||||
(tmp_path / "config.yaml").write_text(yaml.dump(config_data), encoding="utf-8")
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = agent_cls
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
if config_data and "streaming" in config_data:
|
||||
runner.config.streaming = StreamingConfig.from_dict(config_data["streaming"])
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
session_key = "agent:main:telegram:group:-1001:17585"
|
||||
if pending_text is not None:
|
||||
adapter._pending_messages[session_key] = MessageEvent(
|
||||
text=pending_text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="queued-1",
|
||||
)
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
)
|
||||
return adapter, result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_surfaces_real_interim_commentary(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_surfaces_interim_commentary_by_default(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-default-on",
|
||||
)
|
||||
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_suppresses_interim_commentary_when_disabled(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-disabled",
|
||||
config_data={"display": {"interim_assistant_messages": False}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_tool_progress_does_not_control_interim_commentary(monkeypatch, tmp_path):
|
||||
"""tool_progress=all with interim_assistant_messages=false should not surface commentary."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-tool-progress",
|
||||
config_data={"display": {"tool_progress": "all", "interim_assistant_messages": False}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_streaming_does_not_enable_completed_interim_commentary(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
"""Streaming alone with interim_assistant_messages=false should not surface commentary."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-streaming",
|
||||
config_data={
|
||||
"display": {"tool_progress": "off", "interim_assistant_messages": False},
|
||||
"streaming": {"enabled": True},
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_interim_commentary_works_with_tool_progress_off(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-explicit-on",
|
||||
config_data={
|
||||
"display": {
|
||||
"tool_progress": "off",
|
||||
"interim_assistant_messages": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_previewed_final_marks_already_sent(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
PreviewedResponseAgent,
|
||||
session_id="sess-previewed",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is True
|
||||
assert [call["content"] for call in adapter.sent] == ["You're welcome."]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path):
|
||||
QueuedCommentaryAgent.calls = 0
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
QueuedCommentaryAgent,
|
||||
session_id="sess-queued-commentary",
|
||||
pending_text="queued follow-up",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
sent_texts = [call["content"] for call in adapter.sent]
|
||||
assert result["final_response"] == "final response 2"
|
||||
assert "I'll inspect the repo first." in sent_texts
|
||||
assert "final response 1" in sent_texts
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
|
|
@ -45,6 +46,23 @@ class _DisabledAdapter(BasePlatformAdapter):
|
|||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _SuccessfulAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.DISCORD)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
|
@ -65,7 +83,7 @@ async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch,
|
|||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "startup_failed"
|
||||
assert "temporary DNS resolution failure" in state["exit_reason"]
|
||||
assert state["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert state["platforms"]["telegram"]["state"] == "retrying"
|
||||
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
|
||||
|
||||
|
||||
|
|
@ -89,6 +107,31 @@ async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkey
|
|||
assert state["gateway_state"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_records_connected_platform_state_on_success(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter())
|
||||
monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None)
|
||||
monkeypatch.setattr(runner.hooks, "emit", AsyncMock())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "running"
|
||||
assert state["platforms"]["discord"]["state"] == "connected"
|
||||
assert state["platforms"]["discord"]["error_code"] is None
|
||||
assert state["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import os
|
||||
|
||||
from gateway.config import Platform
|
||||
|
|
@ -130,3 +131,99 @@ def test_set_session_env_handles_missing_optional_fields():
|
|||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SESSION_KEY contextvars tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_key_set_via_contextvars(monkeypatch):
|
||||
"""set_session_vars should set HERMES_SESSION_KEY via contextvars."""
|
||||
monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)
|
||||
|
||||
tokens = set_session_vars(
|
||||
platform="telegram",
|
||||
chat_id="-1001",
|
||||
session_key="tg:-1001:17585",
|
||||
)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
|
||||
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_key_falls_back_to_os_environ(monkeypatch):
|
||||
"""get_session_env for SESSION_KEY should fall back to os.environ."""
|
||||
monkeypatch.setenv("HERMES_SESSION_KEY", "env-session-123")
|
||||
|
||||
# No contextvar set — should read from os.environ
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
|
||||
|
||||
# Set contextvar — should prefer it
|
||||
tokens = set_session_vars(session_key="ctx-session-456")
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "ctx-session-456"
|
||||
|
||||
# Restore — should fall back to os.environ
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
|
||||
|
||||
|
||||
def test_set_session_env_includes_session_key():
|
||||
"""_set_session_env should propagate session_key from SessionContext."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(
|
||||
source=source,
|
||||
connected_platforms=[],
|
||||
home_channels={},
|
||||
session_key="tg:-1001:17585",
|
||||
)
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
|
||||
runner._clear_session_env(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_key_no_race_condition_with_contextvars(monkeypatch):
|
||||
"""Prove contextvars isolates SESSION_KEY across concurrent async tasks.
|
||||
|
||||
Two tasks set different session keys. With contextvars each task
|
||||
reads back its own value. With os.environ the second task would
|
||||
overwrite the first (the old bug).
|
||||
"""
|
||||
monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)
|
||||
|
||||
results = {}
|
||||
|
||||
async def handler(key: str, delay: float):
|
||||
tokens = set_session_vars(session_key=key)
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
read_back = get_session_env("HERMES_SESSION_KEY")
|
||||
results[key] = read_back
|
||||
finally:
|
||||
clear_session_vars(tokens)
|
||||
|
||||
async def run():
|
||||
task_a = asyncio.create_task(handler("session-A", 0.15))
|
||||
await asyncio.sleep(0.05)
|
||||
task_b = asyncio.create_task(handler("session-B", 0.05))
|
||||
await asyncio.gather(task_a, task_b)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
# Both tasks must read back their own session key
|
||||
assert results["session-A"] == "session-A", (
|
||||
f"Session A got '{results['session-A']}' instead of 'session-A' — race condition!"
|
||||
)
|
||||
assert results["session-B"] == "session-B", (
|
||||
f"Session B got '{results['session-B']}' instead of 'session-B' — race condition!"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -104,6 +104,34 @@ class TestGatewayRuntimeStatus:
|
|||
assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict"
|
||||
assert payload["platforms"]["telegram"]["error_message"] == "another poller is active"
|
||||
|
||||
def test_write_runtime_status_explicit_none_clears_stale_fields(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="startup_failed",
|
||||
exit_reason="stale error",
|
||||
platform="discord",
|
||||
platform_state="fatal",
|
||||
error_code="discord_timeout",
|
||||
error_message="stale platform error",
|
||||
)
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="running",
|
||||
exit_reason=None,
|
||||
platform="discord",
|
||||
platform_state="connected",
|
||||
error_code=None,
|
||||
error_message=None,
|
||||
)
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["gateway_state"] == "running"
|
||||
assert payload["exit_reason"] is None
|
||||
assert payload["platforms"]["discord"]["state"] == "connected"
|
||||
assert payload["platforms"]["discord"]["error_code"] is None
|
||||
assert payload["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
class TestTerminatePid:
|
||||
def test_force_uses_taskkill_on_windows(self, monkeypatch):
|
||||
|
|
|
|||
|
|
@ -505,3 +505,81 @@ class TestSegmentBreakOnToolBoundary:
|
|||
assert len(sent_texts) == 3
|
||||
assert sent_texts[0].startswith(prefix)
|
||||
assert sum(len(t) for t in sent_texts[1:]) == len(tail)
|
||||
|
||||
|
||||
class TestInterimCommentaryMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_commentary_message_stays_separate_from_final_stream(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
|
||||
)
|
||||
|
||||
consumer.on_commentary("I'll inspect the repository first.")
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["I'll inspect the repository first.", "Done."]
|
||||
assert consumer.final_response_sent is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_final_send_does_not_mark_final_response_sent(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(return_value=SimpleNamespace(success=False, message_id=None))
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
|
||||
)
|
||||
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
assert consumer.final_response_sent is False
|
||||
assert consumer.already_sent is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_without_message_id_marks_visible_and_sends_only_tail(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id=None),
|
||||
SimpleNamespace(success=True, message_id=None),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉"),
|
||||
)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "world"]
|
||||
assert consumer.already_sent is True
|
||||
assert consumer.final_response_sent is True
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import GatewayConfig, load_gateway_config
|
||||
from gateway.config import GatewayConfig, Platform, load_gateway_config
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def test_gateway_config_stt_disabled_from_dict_nested():
|
||||
|
|
@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag
|
|||
assert "No STT provider is configured" not in result
|
||||
assert "trouble transcribing" in result
|
||||
assert "caption" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=True)
|
||||
runner.adapters = {}
|
||||
runner._model = "test-model"
|
||||
runner._base_url = ""
|
||||
runner._has_setup_skill = lambda: False
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=source,
|
||||
media_urls=["/tmp/queued-voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
return_value={
|
||||
"success": True,
|
||||
"transcript": "queued voice transcript",
|
||||
"provider": "local_command",
|
||||
},
|
||||
):
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert "queued voice transcript" in result
|
||||
assert "voice message" in result.lower()
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class TestVerboseCommand:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enabled_cycles_mode(self, tmp_path, monkeypatch):
|
||||
"""When enabled, /verbose cycles tool_progress mode."""
|
||||
"""When enabled, /verbose cycles tool_progress mode per-platform."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
|
|
@ -79,10 +79,11 @@ class TestVerboseCommand:
|
|||
|
||||
# all -> verbose
|
||||
assert "VERBOSE" in result
|
||||
assert "telegram" in result.lower() # per-platform feedback
|
||||
|
||||
# Verify config was saved
|
||||
# Verify config was saved to display.platforms.telegram
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == "verbose"
|
||||
assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cycles_through_all_modes(self, tmp_path, monkeypatch):
|
||||
|
|
@ -103,8 +104,9 @@ class TestVerboseCommand:
|
|||
for mode in expected:
|
||||
result = await runner._handle_verbose_command(_make_event())
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == mode, \
|
||||
f"Expected {mode}, got {saved['display']['tool_progress']}"
|
||||
actual = saved["display"]["platforms"]["telegram"]["tool_progress"]
|
||||
assert actual == mode, \
|
||||
f"Expected {mode}, got {actual}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_defaults_to_all_when_no_tool_progress_set(self, tmp_path, monkeypatch):
|
||||
|
|
@ -122,10 +124,45 @@ class TestVerboseCommand:
|
|||
runner = _make_runner()
|
||||
result = await runner._handle_verbose_command(_make_event())
|
||||
|
||||
# default "all" -> verbose
|
||||
# Telegram default is "all" (high tier) → cycles to verbose
|
||||
assert "VERBOSE" in result
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == "verbose"
|
||||
assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_platform_isolation(self, tmp_path, monkeypatch):
|
||||
"""Cycling /verbose on Telegram doesn't change Slack's setting.
|
||||
|
||||
Without a global tool_progress, each platform uses its built-in
|
||||
default: Telegram = 'all' (high tier), Slack = 'new' (medium tier).
|
||||
"""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
# No global tool_progress → built-in platform defaults apply
|
||||
config_path.write_text(
|
||||
"display:\n tool_progress_command: true\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
runner = _make_runner()
|
||||
|
||||
# Cycle on Telegram
|
||||
await runner._handle_verbose_command(
|
||||
_make_event(platform=Platform.TELEGRAM)
|
||||
)
|
||||
# Cycle on Slack
|
||||
await runner._handle_verbose_command(
|
||||
_make_event(platform=Platform.SLACK)
|
||||
)
|
||||
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
platforms = saved["display"]["platforms"]
|
||||
# Telegram: all -> verbose (high tier default = all)
|
||||
assert platforms["telegram"]["tool_progress"] == "verbose"
|
||||
# Slack: new -> all (medium tier default = new, cycle to all)
|
||||
assert platforms["slack"]["tool_progress"] == "all"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_config_file_returns_disabled(self, tmp_path, monkeypatch):
|
||||
|
|
|
|||
185
tests/gateway/test_wecom_callback.py
Normal file
185
tests/gateway/test_wecom_callback.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""Tests for the WeCom callback-mode adapter."""
|
||||
|
||||
import asyncio
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.wecom_callback import WecomCallbackAdapter
|
||||
from gateway.platforms.wecom_crypto import WXBizMsgCrypt
|
||||
|
||||
|
||||
def _app(name="test-app", corp_id="ww1234567890", agent_id="1000002"):
|
||||
return {
|
||||
"name": name,
|
||||
"corp_id": corp_id,
|
||||
"corp_secret": "test-secret",
|
||||
"agent_id": agent_id,
|
||||
"token": "test-callback-token",
|
||||
"encoding_aes_key": "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG",
|
||||
}
|
||||
|
||||
|
||||
def _config(apps=None):
|
||||
return PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"mode": "callback", "host": "127.0.0.1", "port": 0, "apps": apps or [_app()]},
|
||||
)
|
||||
|
||||
|
||||
class TestWecomCrypto:
|
||||
def test_roundtrip_encrypt_decrypt(self):
|
||||
app = _app()
|
||||
crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"])
|
||||
encrypted_xml = crypt.encrypt(
|
||||
"<xml><Content>hello</Content></xml>", nonce="nonce123", timestamp="123456",
|
||||
)
|
||||
root = ET.fromstring(encrypted_xml)
|
||||
decrypted = crypt.decrypt(
|
||||
root.findtext("MsgSignature", default=""),
|
||||
root.findtext("TimeStamp", default=""),
|
||||
root.findtext("Nonce", default=""),
|
||||
root.findtext("Encrypt", default=""),
|
||||
)
|
||||
assert b"<Content>hello</Content>" in decrypted
|
||||
|
||||
def test_signature_mismatch_raises(self):
|
||||
app = _app()
|
||||
crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"])
|
||||
encrypted_xml = crypt.encrypt("<xml/>", nonce="n", timestamp="1")
|
||||
root = ET.fromstring(encrypted_xml)
|
||||
from gateway.platforms.wecom_crypto import SignatureError
|
||||
with pytest.raises(SignatureError):
|
||||
crypt.decrypt("bad-sig", "1", "n", root.findtext("Encrypt", default=""))
|
||||
|
||||
|
||||
class TestWecomCallbackEventConstruction:
|
||||
def test_build_event_extracts_text_message(self):
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
xml_text = """
|
||||
<xml>
|
||||
<ToUserName>ww1234567890</ToUserName>
|
||||
<FromUserName>zhangsan</FromUserName>
|
||||
<CreateTime>1710000000</CreateTime>
|
||||
<MsgType>text</MsgType>
|
||||
<Content>\u4f60\u597d</Content>
|
||||
<MsgId>123456789</MsgId>
|
||||
</xml>
|
||||
"""
|
||||
event = adapter._build_event(_app(), xml_text)
|
||||
assert event is not None
|
||||
assert event.source is not None
|
||||
assert event.source.user_id == "zhangsan"
|
||||
assert event.source.chat_id == "ww1234567890:zhangsan"
|
||||
assert event.message_id == "123456789"
|
||||
assert event.text == "\u4f60\u597d"
|
||||
|
||||
def test_build_event_returns_none_for_subscribe(self):
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
xml_text = """
|
||||
<xml>
|
||||
<ToUserName>ww1234567890</ToUserName>
|
||||
<FromUserName>zhangsan</FromUserName>
|
||||
<CreateTime>1710000000</CreateTime>
|
||||
<MsgType>event</MsgType>
|
||||
<Event>subscribe</Event>
|
||||
</xml>
|
||||
"""
|
||||
event = adapter._build_event(_app(), xml_text)
|
||||
assert event is None
|
||||
|
||||
|
||||
class TestWecomCallbackRouting:
|
||||
def test_user_app_key_scopes_across_corps(self):
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
assert adapter._user_app_key("corpA", "alice") == "corpA:alice"
|
||||
assert adapter._user_app_key("corpB", "alice") == "corpB:alice"
|
||||
assert adapter._user_app_key("corpA", "alice") != adapter._user_app_key("corpB", "alice")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_selects_correct_app_for_scoped_chat_id(self):
|
||||
apps = [
|
||||
_app(name="corp-a", corp_id="corpA", agent_id="1001"),
|
||||
_app(name="corp-b", corp_id="corpB", agent_id="2002"),
|
||||
]
|
||||
adapter = WecomCallbackAdapter(_config(apps=apps))
|
||||
adapter._user_app_map["corpB:alice"] = "corp-b"
|
||||
adapter._access_tokens["corp-b"] = {"token": "tok-b", "expires_at": 9999999999}
|
||||
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def json(self):
|
||||
return {"errcode": 0, "msgid": "ok1"}
|
||||
|
||||
class FakeClient:
|
||||
async def post(self, url, json):
|
||||
calls["url"] = url
|
||||
calls["json"] = json
|
||||
return FakeResponse()
|
||||
|
||||
adapter._http_client = FakeClient()
|
||||
result = await adapter.send("corpB:alice", "hello")
|
||||
|
||||
assert result.success is True
|
||||
assert calls["json"]["touser"] == "alice"
|
||||
assert calls["json"]["agentid"] == 2002
|
||||
assert "tok-b" in calls["url"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_falls_back_from_bare_user_id_when_unique(self):
|
||||
apps = [_app(name="corp-a", corp_id="corpA", agent_id="1001")]
|
||||
adapter = WecomCallbackAdapter(_config(apps=apps))
|
||||
adapter._user_app_map["corpA:alice"] = "corp-a"
|
||||
adapter._access_tokens["corp-a"] = {"token": "tok-a", "expires_at": 9999999999}
|
||||
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def json(self):
|
||||
return {"errcode": 0, "msgid": "ok2"}
|
||||
|
||||
class FakeClient:
|
||||
async def post(self, url, json):
|
||||
calls["url"] = url
|
||||
calls["json"] = json
|
||||
return FakeResponse()
|
||||
|
||||
adapter._http_client = FakeClient()
|
||||
result = await adapter.send("alice", "hello")
|
||||
|
||||
assert result.success is True
|
||||
assert calls["json"]["agentid"] == 1001
|
||||
|
||||
|
||||
class TestWecomCallbackPollLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_loop_dispatches_handle_message(self, monkeypatch):
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
calls = []
|
||||
|
||||
async def fake_handle_message(event):
|
||||
calls.append(event.text)
|
||||
|
||||
monkeypatch.setattr(adapter, "handle_message", fake_handle_message)
|
||||
event = adapter._build_event(
|
||||
_app(),
|
||||
"""
|
||||
<xml>
|
||||
<ToUserName>ww1234567890</ToUserName>
|
||||
<FromUserName>lisi</FromUserName>
|
||||
<CreateTime>1710000000</CreateTime>
|
||||
<MsgType>text</MsgType>
|
||||
<Content>test</Content>
|
||||
<MsgId>m2</MsgId>
|
||||
</xml>
|
||||
""",
|
||||
)
|
||||
task = asyncio.create_task(adapter._poll_loop())
|
||||
await adapter._message_queue.put(event)
|
||||
await asyncio.sleep(0.05)
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
assert calls == ["test"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue