From 6ec5c95c916c9db53cc16e309d52d3bd657b0d56 Mon Sep 17 00:00:00 2001 From: vp Date: Tue, 24 Feb 2026 23:02:38 +0300 Subject: [PATCH] fix: harden roots unsupported fallback and deny-all messaging --- main.py | 75 +++++++++++++++++++++++++++++--------- test_file_path_security.py | 24 +++++++++++- 2 files changed, 80 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index b531556..bcd89ce 100644 --- a/main.py +++ b/main.py @@ -162,6 +162,11 @@ MAX_FILE_BYTES: dict[str, int] = { "edit_chat_photo": 50 * 1024 * 1024, } ROOTS_UNSUPPORTED_ERROR_CODES = {-32601} +ROOTS_STATUS_READY = "ready" +ROOTS_STATUS_NOT_CONFIGURED = "not_configured" +ROOTS_STATUS_UNSUPPORTED_FALLBACK = "unsupported_fallback" +ROOTS_STATUS_CLIENT_DENY_ALL = "client_deny_all" +ROOTS_STATUS_ERROR = "error" # Error code prefix mapping for better error tracing @@ -463,29 +468,47 @@ def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]: async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: - fallback_roots = list(SERVER_ALLOWED_ROOTS) - if ctx is None: - return fallback_roots + roots, _status = await _get_effective_allowed_roots_with_status(ctx) + return roots - try: - list_roots_result = await ctx.session.list_roots() - except McpError as error: + +def _is_roots_unsupported_error(error: Exception) -> bool: + if isinstance(error, McpError): error_code = getattr(getattr(error, "error", None), "code", None) error_message = ( getattr(getattr(error, "error", None), "message", None) or str(error) ).lower() - if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message: - # Fallback is allowed only when roots are unsupported in this MCP client session. - return fallback_roots + if error_code in ROOTS_UNSUPPORTED_ERROR_CODES: + return True + return "method not found" in error_message or "not implemented" in error_message + + if isinstance(error, NotImplementedError): + return True + if isinstance(error, AttributeError): + return "list_roots" in str(error) + return False + + +async def _get_effective_allowed_roots_with_status( + ctx: Optional[Context], +) -> tuple[List[Path], str]: + fallback_roots = list(SERVER_ALLOWED_ROOTS) + if ctx is None: + if fallback_roots: + return fallback_roots, ROOTS_STATUS_READY + return [], ROOTS_STATUS_NOT_CONFIGURED + + try: + list_roots_result = await ctx.session.list_roots() + except Exception as error: + if _is_roots_unsupported_error(error): + if fallback_roots: + return fallback_roots, ROOTS_STATUS_UNSUPPORTED_FALLBACK + return [], ROOTS_STATUS_NOT_CONFIGURED logger.error( "MCP roots request failed; disabling file-path tools for safety.", exc_info=True ) - return [] - except Exception: - logger.error( - "Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True - ) - return [] + return [], ROOTS_STATUS_ERROR client_roots: List[Path] = [] for root in list_roots_result.roots: @@ -496,17 +519,33 @@ async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: continue if client_roots: - return _dedupe_paths(client_roots) + return _dedupe_paths(client_roots), ROOTS_STATUS_READY # Roots API succeeded; an empty roots list is treated as explicit deny-all. - return [] + return [], ROOTS_STATUS_CLIENT_DENY_ALL async def _ensure_allowed_roots( ctx: Optional[Context], tool_name: str ) -> tuple[List[Path], Optional[str]]: - roots = await _get_effective_allowed_roots(ctx) + roots, status = await _get_effective_allowed_roots_with_status(ctx) if not roots: + if status == ROOTS_STATUS_CLIENT_DENY_ALL: + return ( + [], + ( + f"{tool_name} is disabled because the client provided an empty " + "MCP Roots list (deny-all)." + ), + ) + if status == ROOTS_STATUS_ERROR: + return ( + [], + ( + f"{tool_name} is disabled because MCP Roots could not be verified safely. " + "Check MCP client/server logs." + ), + ) return ( [], ( diff --git a/test_file_path_security.py b/test_file_path_security.py index e8e03a5..665c413 100644 --- a/test_file_path_security.py +++ b/test_file_path_security.py @@ -38,6 +38,15 @@ class _FailingContext: self.session = _FailingSession(error) +class _MissingRootsSession: + pass + + +class _MissingRootsContext: + def __init__(self): + self.session = _MissingRootsSession() + + @pytest.mark.asyncio async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch): root = (tmp_path / "root").resolve() @@ -139,7 +148,8 @@ async def test_empty_client_roots_disable_file_tools(tmp_path, monkeypatch): ) assert resolved is None assert error is not None - assert "disabled" in error + assert "empty MCP Roots list" in error + assert "deny-all" in error @pytest.mark.asyncio @@ -154,6 +164,18 @@ async def test_mcp_method_not_found_falls_back_to_server_allowlist(tmp_path, mon assert roots == [server_root] +@pytest.mark.asyncio +async def test_missing_list_roots_method_falls_back_to_server_allowlist(tmp_path, monkeypatch): + server_root = (tmp_path / "server_root").resolve() + server_root.mkdir(parents=True) + + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root]) + ctx = _MissingRootsContext() + + roots = await main._get_effective_allowed_roots(ctx) + assert roots == [server_root] + + @pytest.mark.asyncio async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch): server_root = (tmp_path / "server_root").resolve()