diff --git a/tests/gateway/relay/test_relay_registration.py b/tests/gateway/relay/test_relay_registration.py index e40e99548c5..810c4521578 100644 --- a/tests/gateway/relay/test_relay_registration.py +++ b/tests/gateway/relay/test_relay_registration.py @@ -1,7 +1,10 @@ -"""RelayAdapter registration via the platform registry (relay Phase 1, Task 1.3). +"""RelayAdapter registration via the platform registry. -Verifies the relay platform is registered ONLY behind the flag (dark-launch), -constructed through the same registry path as plugin adapters. +The relay platform is registered when a connector relay URL is configured +(``GATEWAY_RELAY_URL`` env or ``gateway.relay_url`` in config.yaml) — the same +config-driven shape as ``gateway.proxy_url``, not a separate feature flag. With +no URL configured, registration is a no-op so direct/single-tenant deployments +are unaffected. ``force=True`` registers a transport-less adapter for tests. """ from __future__ import annotations @@ -10,46 +13,54 @@ import pytest from gateway.config import PlatformConfig from gateway.platform_registry import platform_registry -from gateway.relay import register_relay_adapter, relay_enabled +from gateway.relay import register_relay_adapter, relay_url from gateway.relay.adapter import RelayAdapter @pytest.fixture(autouse=True) def _clean_registry(monkeypatch): - """Ensure each test starts/ends with no 'relay' entry and a clean env.""" - monkeypatch.delenv("HERMES_GATEWAY_RELAY", raising=False) + """Each test starts/ends with no 'relay' entry and a clean relay env.""" + monkeypatch.delenv("GATEWAY_RELAY_URL", raising=False) + monkeypatch.delenv("GATEWAY_RELAY_PLATFORM", raising=False) + monkeypatch.delenv("GATEWAY_RELAY_BOT_ID", raising=False) platform_registry.unregister("relay") yield platform_registry.unregister("relay") -def test_off_by_default(): - assert relay_enabled() is False +def test_off_when_no_url_configured(monkeypatch): + # No GATEWAY_RELAY_URL and (assuming) no gateway.relay_url in config. + monkeypatch.setattr("gateway.relay.relay_url", lambda: None) assert register_relay_adapter() is False assert platform_registry.is_registered("relay") is False -def test_enabled_by_env_flag(monkeypatch): - monkeypatch.setenv("HERMES_GATEWAY_RELAY", "1") - assert relay_enabled() is True +def test_registers_when_url_configured(monkeypatch): + monkeypatch.setenv("GATEWAY_RELAY_URL", "wss://connector.example/relay") + assert relay_url() == "wss://connector.example/relay" assert register_relay_adapter() is True assert platform_registry.is_registered("relay") is True -def test_force_registers_without_flag(): +def test_explicit_url_arg_registers(): + assert register_relay_adapter(url="wss://connector.example/relay") is True + assert platform_registry.is_registered("relay") is True + + +def test_force_registers_without_url(): assert register_relay_adapter(force=True) is True assert platform_registry.is_registered("relay") is True +def test_trailing_slash_stripped(monkeypatch): + monkeypatch.setenv("GATEWAY_RELAY_URL", "wss://connector.example/relay/") + assert relay_url() == "wss://connector.example/relay" + + def test_create_adapter_yields_relay_adapter(): + # force=True builds a transport-less adapter (no live dial in unit tests). register_relay_adapter(force=True) adapter = platform_registry.create_adapter("relay", PlatformConfig()) assert isinstance(adapter, RelayAdapter) # Placeholder descriptor until handshake negotiates the real one. assert adapter.descriptor.platform == "relay" - - -@pytest.mark.parametrize("val,expected", [("0", False), ("", False), ("true", True), ("ON", True), ("yes", True)]) -def test_flag_parsing(monkeypatch, val, expected): - monkeypatch.setenv("HERMES_GATEWAY_RELAY", val) - assert relay_enabled() is expected diff --git a/tests/gateway/relay/test_ws_transport.py b/tests/gateway/relay/test_ws_transport.py new file mode 100644 index 00000000000..dcb3f6c714f --- /dev/null +++ b/tests/gateway/relay/test_ws_transport.py @@ -0,0 +1,179 @@ +"""WebSocketRelayTransport against a real in-process WebSocket server. + +Exercises the production transport over an actual ``websockets`` server (no +mock socket): handshake (hello -> descriptor), inbound frame -> handler, +outbound request/response correlation, and follow_up routing. Proves the wire +framing (newline-delimited JSON) and the request/response future plumbing work +end to end on a live socket. + +Skipped cleanly if the optional ``websockets`` dependency is absent. +""" + +from __future__ import annotations + +import asyncio +import json + +import pytest +import pytest_asyncio + +from gateway.relay.ws_transport import WebSocketRelayTransport, WEBSOCKETS_AVAILABLE + +pytestmark = pytest.mark.skipif(not WEBSOCKETS_AVAILABLE, reason="websockets not installed") + +if WEBSOCKETS_AVAILABLE: + import websockets + + +DESCRIPTOR = { + "contract_version": 1, + "platform": "discord", + "label": "Discord", + "max_message_length": 2000, + "supports_draft_streaming": False, + "supports_edit": True, + "supports_threads": True, + "markdown_dialect": "discord", + "len_unit": "chars", +} + + +class _StubConnectorServer: + """Minimal connector: answers hello with a descriptor, echoes outbound.""" + + def __init__(self): + self.received: list[dict] = [] + self._server = None + self.url = "" + # Push channel: tests set this to a frame dict to deliver inbound. + self._to_push: list[dict] = [] + + async def start(self): + self._server = await websockets.serve(self._handle, "127.0.0.1", 0) + sock = next(iter(self._server.sockets)) + port = sock.getsockname()[1] + self.url = f"ws://127.0.0.1:{port}" + + async def stop(self): + if self._server is not None: + self._server.close() + await self._server.wait_closed() + + async def _handle(self, ws): + async for raw in ws: + for line in str(raw).split("\n"): + if not line.strip(): + continue + frame = json.loads(line) + self.received.append(frame) + await self._on_frame(ws, frame) + + async def _on_frame(self, ws, frame): + ftype = frame.get("type") + if ftype == "hello": + await ws.send(json.dumps({"type": "descriptor", "descriptor": DESCRIPTOR}) + "\n") + # Deliver any queued inbound frames right after handshake. + for f in self._to_push: + await ws.send(json.dumps(f) + "\n") + elif ftype == "outbound": + action = frame.get("action", {}) + # Echo a successful result correlated by requestId. + result = {"success": True, "message_id": f"srv-{action.get('op')}"} + await ws.send( + json.dumps({"type": "outbound_result", "requestId": frame["requestId"], "result": result}) + + "\n" + ) + + +@pytest_asyncio.fixture +async def server(): + srv = _StubConnectorServer() + await srv.start() + yield srv + await srv.stop() + + +@pytest.mark.asyncio +async def test_handshake_negotiates_descriptor(server): + t = WebSocketRelayTransport(server.url, "discord", "appShared") + await t.connect() + try: + desc = await t.handshake() + assert desc.platform == "discord" + assert desc.max_message_length == 2000 + # The hello carried the platform + botId. + hello = next(f for f in server.received if f["type"] == "hello") + assert hello["platform"] == "discord" + assert hello["botId"] == "appShared" + finally: + await t.disconnect() + + +@pytest.mark.asyncio +async def test_inbound_frame_reaches_handler(server): + server._to_push = [ + { + "type": "inbound", + "event": { + "text": "hello from connector", + "message_type": "text", + "source": {"platform": "discord", "chat_id": "chan1", "chat_type": "group", "guild_id": "guildA"}, + }, + "bufferId": "buf-1", + } + ] + received = [] + t = WebSocketRelayTransport(server.url, "discord", "appShared") + t.set_inbound_handler(lambda ev: received.append(ev) or asyncio.sleep(0)) + await t.connect() + try: + await t.handshake() + # Give the reader a tick to deliver the pushed inbound frame. + await asyncio.sleep(0.05) + assert len(received) == 1 + assert received[0].text == "hello from connector" + assert received[0].source.guild_id == "guildA" + finally: + await t.disconnect() + + +@pytest.mark.asyncio +async def test_outbound_round_trips_with_correlation(server): + t = WebSocketRelayTransport(server.url, "discord", "appShared") + await t.connect() + try: + await t.handshake() + result = await t.send_outbound({"op": "send", "chat_id": "chan1", "content": "hi"}) + assert result["success"] is True + assert result["message_id"] == "srv-send" + finally: + await t.disconnect() + + +@pytest.mark.asyncio +async def test_follow_up_round_trips(server): + t = WebSocketRelayTransport(server.url, "discord", "appShared") + await t.connect() + try: + await t.handshake() + result = await t.send_follow_up( + {"op": "follow_up", "session_key": "s1", "kind": "discord.interaction_token", "content": "fu"} + ) + assert result["success"] is True + assert result["message_id"] == "srv-follow_up" + # The follow_up rode an outbound frame the connector saw. + outbound = [f for f in server.received if f["type"] == "outbound"] + assert any(f["action"]["op"] == "follow_up" for f in outbound) + finally: + await t.disconnect() + + +@pytest.mark.asyncio +async def test_disconnect_fails_pending_waiters_cleanly(server): + t = WebSocketRelayTransport(server.url, "discord", "appShared", outbound_timeout_s=5) + await t.connect() + await t.handshake() + await t.disconnect() + # After disconnect, an outbound returns a structured failure rather than hanging. + result = await t.send_outbound({"op": "send", "chat_id": "c", "content": "x"}) + assert result["success"] is False