diff --git a/tests/tools/test_mcp_oauth.py b/tests/tools/test_mcp_oauth.py index e12149a45d3..53e0abfd615 100644 --- a/tests/tools/test_mcp_oauth.py +++ b/tests/tools/test_mcp_oauth.py @@ -23,6 +23,7 @@ from tools.mcp_oauth import ( _wait_for_callback, _make_callback_handler, _redirect_handler, + _paste_callback_reader, ) @@ -621,3 +622,135 @@ def test_build_oauth_auth_preserves_server_url_path(): assert captured["server_url"] == "https://mcp.notion.com/mcp" + +class TestPasteCallbackReader: + """_paste_callback_reader parses redirect URLs / query strings from stdin.""" + + def _empty_result(self): + return {"auth_code": None, "state": None, "error": None} + + def test_parses_full_local_redirect_url(self, monkeypatch): + result = self._empty_result() + monkeypatch.setattr( + "sys.stdin", + MagicMock(readline=lambda: "http://127.0.0.1:37949/callback?code=abc&state=xyz\n"), + ) + _paste_callback_reader(result) + assert result["auth_code"] == "abc" + assert result["state"] == "xyz" + assert result["error"] is None + + def test_parses_remote_provider_url(self, monkeypatch): + """User pastes the URL their browser ended up on, including a real host.""" + result = self._empty_result() + url = "https://mcp.linear.app/callback?code=deadbeef&state=eyJ0ZXN0Ijoi" + monkeypatch.setattr("sys.stdin", MagicMock(readline=lambda: url + "\n")) + _paste_callback_reader(result) + assert result["auth_code"] == "deadbeef" + assert result["state"] == "eyJ0ZXN0Ijoi" + + def test_parses_bare_query_string(self, monkeypatch): + result = self._empty_result() + monkeypatch.setattr( + "sys.stdin", + MagicMock(readline=lambda: "code=token123&state=st1\n"), + ) + _paste_callback_reader(result) + assert result["auth_code"] == "token123" + assert result["state"] == "st1" + + def test_parses_leading_question_mark(self, monkeypatch): + result = self._empty_result() + monkeypatch.setattr( + "sys.stdin", + MagicMock(readline=lambda: "?code=tok&state=stA\n"), + ) + _paste_callback_reader(result) + assert result["auth_code"] == "tok" + assert result["state"] == "stA" + + def test_captures_error_param(self, monkeypatch): + result = self._empty_result() + monkeypatch.setattr( + "sys.stdin", + MagicMock(readline=lambda: "https://example/cb?error=access_denied\n"), + ) + _paste_callback_reader(result) + assert result["auth_code"] is None + assert result["error"] == "access_denied" + + def test_empty_input_noop(self, monkeypatch): + result = self._empty_result() + monkeypatch.setattr("sys.stdin", MagicMock(readline=lambda: "")) + _paste_callback_reader(result) + assert result["auth_code"] is None + assert result["error"] is None + + def test_garbage_input_noop(self, monkeypatch, capsys): + result = self._empty_result() + monkeypatch.setattr( + "sys.stdin", MagicMock(readline=lambda: "not a url at all\n") + ) + _paste_callback_reader(result) + assert result["auth_code"] is None + assert result["error"] is None + err = capsys.readouterr().err + assert "did not contain" in err or "Could not parse" in err + + def test_skips_when_http_listener_already_won(self, monkeypatch): + """If HTTP listener filled the result first, paste must not overwrite.""" + result = {"auth_code": "from_http", "state": "http_state", "error": None} + monkeypatch.setattr( + "sys.stdin", + MagicMock(readline=lambda: "code=from_paste&state=paste_state\n"), + ) + _paste_callback_reader(result) + assert result["auth_code"] == "from_http" + assert result["state"] == "http_state" + + def test_swallows_stdin_errors(self, monkeypatch): + """OSError / interrupt on readline must not propagate.""" + result = self._empty_result() + def raise_oserror(): + raise OSError("stdin closed") + monkeypatch.setattr("sys.stdin", MagicMock(readline=raise_oserror)) + _paste_callback_reader(result) # must not raise + assert result["auth_code"] is None + + +class TestWaitForCallbackPasteIntegration: + """_wait_for_callback offers the paste prompt only when interactive.""" + + def test_paste_prompt_shown_on_tty(self, monkeypatch, capsys): + import tools.mcp_oauth as mod + mod._oauth_port = _find_free_port() + monkeypatch.setattr(mod, "_is_interactive", lambda: True) + # Make stdin readline block forever so HTTP listener path drives the test; + # we just want to verify the prompt was printed and the thread spawned. + def block_forever(): + import threading + threading.Event().wait() + monkeypatch.setattr("sys.stdin", MagicMock(readline=block_forever)) + + async def instant_sleep(_): + pass + with patch.object(mod.asyncio, "sleep", instant_sleep): + with pytest.raises(OAuthNonInteractiveError): + asyncio.run(_wait_for_callback()) + err = capsys.readouterr().err + assert "paste the redirect URL" in err + + def test_paste_prompt_NOT_shown_when_noninteractive(self, monkeypatch, capsys): + """Preserves existing invariant: no input() / paste prompt in headless runs.""" + import tools.mcp_oauth as mod + mod._oauth_port = _find_free_port() + monkeypatch.setattr(mod, "_is_interactive", lambda: False) + + async def instant_sleep(_): + pass + with patch.object(mod.asyncio, "sleep", instant_sleep): + with patch("builtins.input", side_effect=AssertionError("input() must not be called")): + with pytest.raises(OAuthNonInteractiveError): + asyncio.run(_wait_for_callback()) + err = capsys.readouterr().err + assert "paste the redirect URL" not in err diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index 53b4615000f..c79d999cd93 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -403,17 +403,25 @@ async def _redirect_handler(authorization_url: str) -> None: # On a remote SSH session the OAuth provider redirects to # http://127.0.0.1:/callback, which reaches the callback server on # the *remote* machine — not the user's local machine where the browser - # opened. Print a port-forward hint so the user knows to tunnel first. + # opened. Two ways out: paste the redirect URL back (default fallback, + # offered by _wait_for_callback on interactive TTYs), or set up an SSH + # port forward so the redirect tunnels through. if _oauth_port and (os.getenv("SSH_CLIENT") or os.getenv("SSH_TTY")): print( - f" Remote session detected. The OAuth provider will redirect your browser to\n" + f" Remote session detected. After you authorize, the provider redirects to\n" f" http://127.0.0.1:{_oauth_port}/callback\n" - f" which the callback listener on THIS machine is waiting on. If your browser\n" - f" is on a different machine, forward the port first in a separate terminal:\n" + f" which only the listener on THIS machine can receive. Two options:\n" f"\n" - f" ssh -N -L {_oauth_port}:127.0.0.1:{_oauth_port} @\n" + f" 1. Easiest — when your browser shows a connection error after\n" + f" authorizing, copy the full URL from the address bar and paste\n" + f" it at the prompt below. The pasted ``code=...&state=...`` is\n" + f" enough to complete the flow.\n" f"\n" - f" Then open the URL above. See: https://hermes-agent.nousresearch.com/docs/guides/oauth-over-ssh\n", + f" 2. Or forward the port first in a separate terminal:\n" + f" ssh -N -L {_oauth_port}:127.0.0.1:{_oauth_port} @\n" + f" then open the URL above and let it redirect normally.\n" + f"\n" + f" See: https://hermes-agent.nousresearch.com/docs/guides/oauth-over-ssh\n", file=sys.stderr, ) @@ -437,6 +445,12 @@ async def _wait_for_callback() -> tuple[str, str | None]: before this is ever called. Polls for the result without blocking the event loop. + On an interactive TTY, races the HTTP listener against a stdin paste + fallback so users without an SSH tunnel can copy the redirect URL (or + just the ``code=...&state=...`` query string) from a browser on another + machine and paste it back. The HTTP listener wins when the redirect + reaches it first; the paste fallback wins when it doesn't. + Raises: OAuthNonInteractiveError: If the callback times out (no user present to complete the browser auth). @@ -468,6 +482,23 @@ async def _wait_for_callback() -> tuple[str, str | None]: server_thread = threading.Thread(target=server.handle_request, daemon=True) server_thread.start() + # Optional paste-fallback thread: only on interactive TTYs. Reads one + # line from stdin and writes the parsed code/state into the shared + # result dict. The HTTP listener and this thread race for the result; + # whichever fills it first wins. + paste_thread: threading.Thread | None = None + if _is_interactive(): + print( + "\n Or paste the redirect URL here (or the ``?code=...&state=...`` " + "portion) and press Enter:", + file=sys.stderr, + flush=True, + ) + paste_thread = threading.Thread( + target=_paste_callback_reader, args=(result,), daemon=True + ) + paste_thread.start() + timeout = 300.0 poll_interval = 0.5 elapsed = 0.0 @@ -491,6 +522,70 @@ async def _wait_for_callback() -> tuple[str, str | None]: return result["auth_code"], result["state"] +def _paste_callback_reader(result: dict) -> None: + """Read one line from stdin, parse it as an OAuth redirect, write to result. + + Accepts any of: + - Full redirect URL: ``http://127.0.0.1:37949/callback?code=...&state=...`` + - The provider's own callback URL: ``https://mcp.example.com/callback?code=...&state=...`` + - Just the query string: ``?code=...&state=...`` or ``code=...&state=...`` + + Failures to parse, EOF, or interrupts are swallowed — this is best-effort + fallback alongside the HTTP listener, which remains the primary path. + """ + try: + line = sys.stdin.readline() + except (KeyboardInterrupt, OSError, ValueError): + return + if not line: + return # EOF + line = line.strip() + if not line: + return + + # Skip if HTTP listener already won. + if result.get("auth_code") is not None or result.get("error") is not None: + return + + # Strip a leading "?" if user pasted just a query string. + query = line + if "?" in line: + # Either a full URL or "?code=...". Take everything after the first "?". + query = line.split("?", 1)[1] + if query.startswith("?"): + query = query[1:] + + try: + params = parse_qs(query) + except (ValueError, TypeError): + print( + " Could not parse pasted input as an OAuth redirect — ignoring.", + file=sys.stderr, + ) + return + + code = params.get("code", [None])[0] + state = params.get("state", [None])[0] + error = params.get("error", [None])[0] + + if not code and not error: + print( + " Pasted input did not contain ``code=`` or ``error=`` — ignoring.", + file=sys.stderr, + ) + return + + # One more race-check before writing. + if result.get("auth_code") is not None or result.get("error") is not None: + return + + result["auth_code"] = code + result["state"] = state + result["error"] = error + if code: + print(" Got authorization code from paste — completing flow.", file=sys.stderr) + + # --------------------------------------------------------------------------- # Public API # ---------------------------------------------------------------------------