fix: harden roots unsupported fallback and deny-all messaging

This commit is contained in:
vp 2026-02-24 23:02:38 +03:00
parent 172fedaf7a
commit 6ec5c95c91
2 changed files with 80 additions and 19 deletions

75
main.py
View file

@ -162,6 +162,11 @@ MAX_FILE_BYTES: dict[str, int] = {
"edit_chat_photo": 50 * 1024 * 1024,
}
ROOTS_UNSUPPORTED_ERROR_CODES = {-32601}
ROOTS_STATUS_READY = "ready"
ROOTS_STATUS_NOT_CONFIGURED = "not_configured"
ROOTS_STATUS_UNSUPPORTED_FALLBACK = "unsupported_fallback"
ROOTS_STATUS_CLIENT_DENY_ALL = "client_deny_all"
ROOTS_STATUS_ERROR = "error"
# Error code prefix mapping for better error tracing
@ -463,29 +468,47 @@ def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]:
async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
fallback_roots = list(SERVER_ALLOWED_ROOTS)
if ctx is None:
return fallback_roots
roots, _status = await _get_effective_allowed_roots_with_status(ctx)
return roots
try:
list_roots_result = await ctx.session.list_roots()
except McpError as error:
def _is_roots_unsupported_error(error: Exception) -> bool:
if isinstance(error, McpError):
error_code = getattr(getattr(error, "error", None), "code", None)
error_message = (
getattr(getattr(error, "error", None), "message", None) or str(error)
).lower()
if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message:
# Fallback is allowed only when roots are unsupported in this MCP client session.
return fallback_roots
if error_code in ROOTS_UNSUPPORTED_ERROR_CODES:
return True
return "method not found" in error_message or "not implemented" in error_message
if isinstance(error, NotImplementedError):
return True
if isinstance(error, AttributeError):
return "list_roots" in str(error)
return False
async def _get_effective_allowed_roots_with_status(
ctx: Optional[Context],
) -> tuple[List[Path], str]:
fallback_roots = list(SERVER_ALLOWED_ROOTS)
if ctx is None:
if fallback_roots:
return fallback_roots, ROOTS_STATUS_READY
return [], ROOTS_STATUS_NOT_CONFIGURED
try:
list_roots_result = await ctx.session.list_roots()
except Exception as error:
if _is_roots_unsupported_error(error):
if fallback_roots:
return fallback_roots, ROOTS_STATUS_UNSUPPORTED_FALLBACK
return [], ROOTS_STATUS_NOT_CONFIGURED
logger.error(
"MCP roots request failed; disabling file-path tools for safety.", exc_info=True
)
return []
except Exception:
logger.error(
"Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True
)
return []
return [], ROOTS_STATUS_ERROR
client_roots: List[Path] = []
for root in list_roots_result.roots:
@ -496,17 +519,33 @@ async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
continue
if client_roots:
return _dedupe_paths(client_roots)
return _dedupe_paths(client_roots), ROOTS_STATUS_READY
# Roots API succeeded; an empty roots list is treated as explicit deny-all.
return []
return [], ROOTS_STATUS_CLIENT_DENY_ALL
async def _ensure_allowed_roots(
ctx: Optional[Context], tool_name: str
) -> tuple[List[Path], Optional[str]]:
roots = await _get_effective_allowed_roots(ctx)
roots, status = await _get_effective_allowed_roots_with_status(ctx)
if not roots:
if status == ROOTS_STATUS_CLIENT_DENY_ALL:
return (
[],
(
f"{tool_name} is disabled because the client provided an empty "
"MCP Roots list (deny-all)."
),
)
if status == ROOTS_STATUS_ERROR:
return (
[],
(
f"{tool_name} is disabled because MCP Roots could not be verified safely. "
"Check MCP client/server logs."
),
)
return (
[],
(

View file

@ -38,6 +38,15 @@ class _FailingContext:
self.session = _FailingSession(error)
class _MissingRootsSession:
pass
class _MissingRootsContext:
def __init__(self):
self.session = _MissingRootsSession()
@pytest.mark.asyncio
async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve()
@ -139,7 +148,8 @@ async def test_empty_client_roots_disable_file_tools(tmp_path, monkeypatch):
)
assert resolved is None
assert error is not None
assert "disabled" in error
assert "empty MCP Roots list" in error
assert "deny-all" in error
@pytest.mark.asyncio
@ -154,6 +164,18 @@ async def test_mcp_method_not_found_falls_back_to_server_allowlist(tmp_path, mon
assert roots == [server_root]
@pytest.mark.asyncio
async def test_missing_list_roots_method_falls_back_to_server_allowlist(tmp_path, monkeypatch):
server_root = (tmp_path / "server_root").resolve()
server_root.mkdir(parents=True)
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
ctx = _MissingRootsContext()
roots = await main._get_effective_allowed_roots(ctx)
assert roots == [server_root]
@pytest.mark.asyncio
async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch):
server_root = (tmp_path / "server_root").resolve()