fix: harden roots unsupported fallback and deny-all messaging

This commit is contained in:
vp 2026-02-24 23:02:38 +03:00
parent 172fedaf7a
commit 6ec5c95c91
2 changed files with 80 additions and 19 deletions

75
main.py
View file

@ -162,6 +162,11 @@ MAX_FILE_BYTES: dict[str, int] = {
"edit_chat_photo": 50 * 1024 * 1024, "edit_chat_photo": 50 * 1024 * 1024,
} }
ROOTS_UNSUPPORTED_ERROR_CODES = {-32601} ROOTS_UNSUPPORTED_ERROR_CODES = {-32601}
ROOTS_STATUS_READY = "ready"
ROOTS_STATUS_NOT_CONFIGURED = "not_configured"
ROOTS_STATUS_UNSUPPORTED_FALLBACK = "unsupported_fallback"
ROOTS_STATUS_CLIENT_DENY_ALL = "client_deny_all"
ROOTS_STATUS_ERROR = "error"
# Error code prefix mapping for better error tracing # Error code prefix mapping for better error tracing
@ -463,29 +468,47 @@ def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]:
async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
fallback_roots = list(SERVER_ALLOWED_ROOTS) roots, _status = await _get_effective_allowed_roots_with_status(ctx)
if ctx is None: return roots
return fallback_roots
try:
list_roots_result = await ctx.session.list_roots() def _is_roots_unsupported_error(error: Exception) -> bool:
except McpError as error: if isinstance(error, McpError):
error_code = getattr(getattr(error, "error", None), "code", None) error_code = getattr(getattr(error, "error", None), "code", None)
error_message = ( error_message = (
getattr(getattr(error, "error", None), "message", None) or str(error) getattr(getattr(error, "error", None), "message", None) or str(error)
).lower() ).lower()
if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message: if error_code in ROOTS_UNSUPPORTED_ERROR_CODES:
# Fallback is allowed only when roots are unsupported in this MCP client session. return True
return fallback_roots return "method not found" in error_message or "not implemented" in error_message
if isinstance(error, NotImplementedError):
return True
if isinstance(error, AttributeError):
return "list_roots" in str(error)
return False
async def _get_effective_allowed_roots_with_status(
ctx: Optional[Context],
) -> tuple[List[Path], str]:
fallback_roots = list(SERVER_ALLOWED_ROOTS)
if ctx is None:
if fallback_roots:
return fallback_roots, ROOTS_STATUS_READY
return [], ROOTS_STATUS_NOT_CONFIGURED
try:
list_roots_result = await ctx.session.list_roots()
except Exception as error:
if _is_roots_unsupported_error(error):
if fallback_roots:
return fallback_roots, ROOTS_STATUS_UNSUPPORTED_FALLBACK
return [], ROOTS_STATUS_NOT_CONFIGURED
logger.error( logger.error(
"MCP roots request failed; disabling file-path tools for safety.", exc_info=True "MCP roots request failed; disabling file-path tools for safety.", exc_info=True
) )
return [] return [], ROOTS_STATUS_ERROR
except Exception:
logger.error(
"Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True
)
return []
client_roots: List[Path] = [] client_roots: List[Path] = []
for root in list_roots_result.roots: for root in list_roots_result.roots:
@ -496,17 +519,33 @@ async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
continue continue
if client_roots: if client_roots:
return _dedupe_paths(client_roots) return _dedupe_paths(client_roots), ROOTS_STATUS_READY
# Roots API succeeded; an empty roots list is treated as explicit deny-all. # Roots API succeeded; an empty roots list is treated as explicit deny-all.
return [] return [], ROOTS_STATUS_CLIENT_DENY_ALL
async def _ensure_allowed_roots( async def _ensure_allowed_roots(
ctx: Optional[Context], tool_name: str ctx: Optional[Context], tool_name: str
) -> tuple[List[Path], Optional[str]]: ) -> tuple[List[Path], Optional[str]]:
roots = await _get_effective_allowed_roots(ctx) roots, status = await _get_effective_allowed_roots_with_status(ctx)
if not roots: if not roots:
if status == ROOTS_STATUS_CLIENT_DENY_ALL:
return (
[],
(
f"{tool_name} is disabled because the client provided an empty "
"MCP Roots list (deny-all)."
),
)
if status == ROOTS_STATUS_ERROR:
return (
[],
(
f"{tool_name} is disabled because MCP Roots could not be verified safely. "
"Check MCP client/server logs."
),
)
return ( return (
[], [],
( (

View file

@ -38,6 +38,15 @@ class _FailingContext:
self.session = _FailingSession(error) self.session = _FailingSession(error)
class _MissingRootsSession:
pass
class _MissingRootsContext:
def __init__(self):
self.session = _MissingRootsSession()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch): async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve() root = (tmp_path / "root").resolve()
@ -139,7 +148,8 @@ async def test_empty_client_roots_disable_file_tools(tmp_path, monkeypatch):
) )
assert resolved is None assert resolved is None
assert error is not None assert error is not None
assert "disabled" in error assert "empty MCP Roots list" in error
assert "deny-all" in error
@pytest.mark.asyncio @pytest.mark.asyncio
@ -154,6 +164,18 @@ async def test_mcp_method_not_found_falls_back_to_server_allowlist(tmp_path, mon
assert roots == [server_root] assert roots == [server_root]
@pytest.mark.asyncio
async def test_missing_list_roots_method_falls_back_to_server_allowlist(tmp_path, monkeypatch):
server_root = (tmp_path / "server_root").resolve()
server_root.mkdir(parents=True)
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
ctx = _MissingRootsContext()
roots = await main._get_effective_allowed_roots(ctx)
assert roots == [server_root]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch): async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch):
server_root = (tmp_path / "server_root").resolve() server_root = (tmp_path / "server_root").resolve()