From fd9c90e53c2457a3f82d99067cd42780fc58aebe Mon Sep 17 00:00:00 2001 From: vp Date: Tue, 24 Feb 2026 22:12:56 +0300 Subject: [PATCH 1/6] feat: secure file-path tools with allowlisted roots - implement server-side allowlist via CLI positional roots (fallback)\n- implement client MCP Roots override semantics (replace server roots when available)\n- add realpath + in-root validation, traversal/glob rejection, extension and size checks\n- make write path default to /downloads when file_path is omitted\n- reintroduce upload_file tool with the same path security model\n- update README with security model and usage\n- add tests for root resolution, replacement semantics, traversal checks, and default write path\n- add pytest and pytest-asyncio to dev dependencies --- README.md | 27 ++- main.py | 430 ++++++++++++++++++++++++++++++++----- pyproject.toml | 2 + test_file_path_security.py | 159 ++++++++++++++ uv.lock | 63 ++++++ 5 files changed, 626 insertions(+), 55 deletions(-) create mode 100644 test_file_path_security.py diff --git a/README.md b/README.md index b0eb441..97a0b33 100644 --- a/README.md +++ b/README.md @@ -101,12 +101,19 @@ This MCP server exposes a huge suite of Telegram tools. **Every major Telegram/T ### User & Profile - **get_me()**: Get your user info - **update_profile(first_name, last_name, about)**: Update your profile +- **set_profile_photo(file_path)**: Set a profile photo from an allowed root path - **delete_profile_photo()**: Remove your profile photo - **get_user_photos(user_id, limit)**: Get a user's profile photos - **get_user_status(user_id)**: Get a user's online status ### Media - **get_media_info(chat_id, message_id)**: Get info about media in a message +- **send_file(chat_id, file_path, caption)**: Send a local file from allowed roots +- **download_media(chat_id, message_id, file_path)**: Save message media under allowed roots +- **upload_file(file_path)**: Upload a local file and return upload metadata +- **send_voice(chat_id, file_path)**: Send `.ogg/.opus` voice note from allowed roots +- **send_sticker(chat_id, file_path)**: Send `.webp` sticker from allowed roots +- **edit_chat_photo(chat_id, file_path)**: Update chat photo from allowed roots ### Search & Discovery - **search_public_chats(query)**: Search public chats/channels/bots @@ -142,11 +149,25 @@ To improve robustness, all functions accepting `chat_id` or `user_id` parameters The server will automatically validate the input and convert it to the correct format before making a request to Telegram. If the input is invalid, a clear error message will be returned. -## Removed Functionality +## File-path Tools Security Model -Please note that tools requiring direct file path access on the server (`send_file`, `download_media`, `set_profile_photo`, `edit_chat_photo`, `send_voice`, `send_sticker`, `upload_file`) have been removed from `main.py`. This is due to limitations in the current MCP environment regarding handling file attachments and local file system paths. +File-path tools are available, but **disabled by default** until allowed roots are configured. -Additionally, GIF-related tools (`get_gif_search`, `get_saved_gifs`, `send_gif`) have been removed due to ongoing issues with reliability in the Telethon library or Telegram API interactions. +Supported file-path tools: +- `send_file`, `download_media`, `set_profile_photo`, `edit_chat_photo`, `send_voice`, `send_sticker`, `upload_file` + +Security semantics (aligned with MCP filesystem server): +- Server-side allowlist via CLI positional arguments (fallback). +- Client-provided MCP Roots replace the server allowlist when available. +- All paths are resolved via realpath and must stay inside an allowed root. +- Traversal/glob-like patterns are rejected (`..`, `*`, `?`, `~`, etc.). +- Relative paths resolve against the first allowed root. +- Write tools default to `/downloads/` when `file_path` is omitted. + +Example server launch with allowlisted roots: +```bash +uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp +``` --- diff --git a/main.py b/main.py index 9721118..3c123cd 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import argparse import os import sys import json @@ -9,12 +10,15 @@ import mimetypes from datetime import datetime, timedelta from enum import Enum from typing import List, Dict, Optional, Union, Any +from pathlib import Path +from urllib.parse import unquote, urlparse # Third-party libraries import nest_asyncio from dotenv import load_dotenv -from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import FastMCP, Context from mcp.types import ToolAnnotations +from mcp.shared.exceptions import McpError from pythonjsonlogger import jsonlogger from telethon import TelegramClient, functions, types, utils from telethon.sessions import StringSession @@ -139,6 +143,26 @@ except Exception as log_error: logger.error(f"Failed to set up log file handler: {log_error}") +# File-path tool security configuration +SERVER_ALLOWED_ROOTS: list[Path] = [] +DEFAULT_DOWNLOAD_SUBDIR = "downloads" +DISALLOWED_PATH_PATTERNS = ("*", "?", "[", "]", "{", "}", "~", "\x00") +EXTENSION_ALLOWLISTS: dict[str, set[str]] = { + "send_voice": {".ogg", ".opus"}, + "send_sticker": {".webp"}, + "set_profile_photo": {".jpg", ".jpeg", ".png", ".webp"}, + "edit_chat_photo": {".jpg", ".jpeg", ".png", ".webp"}, +} +MAX_FILE_BYTES: dict[str, int] = { + "send_file": 200 * 1024 * 1024, # 200 MB + "upload_file": 200 * 1024 * 1024, + "send_voice": 100 * 1024 * 1024, + "send_sticker": 10 * 1024 * 1024, + "set_profile_photo": 50 * 1024 * 1024, + "edit_chat_photo": 50 * 1024 * 1024, +} + + # Error code prefix mapping for better error tracing class ErrorCategory(str, Enum): CHAT = "CHAT" @@ -363,6 +387,219 @@ def get_engagement_info(message) -> str: return f" | {', '.join(engagement_parts)}" if engagement_parts else "" +def _dedupe_paths(paths: List[Path]) -> List[Path]: + seen: set[str] = set() + result: List[Path] = [] + for path in paths: + key = str(path) + if key in seen: + continue + seen.add(key) + result.append(path) + return result + + +def _contains_forbidden_path_patterns(raw_path: str) -> Optional[str]: + value = raw_path.strip() + if not value: + return "Path must not be empty." + if any(token in value for token in DISALLOWED_PATH_PATTERNS): + return "Path contains disallowed wildcard/shell patterns." + if ".." in Path(value).parts: + return "Path traversal is not allowed." + return None + + +def _coerce_root_uri_to_path(uri: str) -> Path: + parsed = urlparse(uri) + if parsed.scheme != "file": + raise ValueError(f"Unsupported root URI scheme: {parsed.scheme}") + + decoded_path = unquote(parsed.path or "") + if parsed.netloc and parsed.netloc not in ("", "localhost"): + decoded_path = f"//{parsed.netloc}{decoded_path}" + if os.name == "nt" and decoded_path.startswith("/") and len(decoded_path) > 2: + # file:///C:/tmp -> C:/tmp on Windows + if decoded_path[2] == ":": + decoded_path = decoded_path[1:] + return Path(decoded_path).resolve(strict=True) + + +def _path_is_within_root(candidate: Path, root: Path) -> bool: + root = root.resolve() + if root.is_file(): + return candidate == root + return candidate == root or root in candidate.parents + + +def _path_is_within_any_root(candidate: Path, roots: List[Path]) -> bool: + return any(_path_is_within_root(candidate, root) for root in roots) + + +def _first_resolution_root(roots: List[Path]) -> Path: + first = roots[0] + return first if first.is_dir() else first.parent + + +def _ensure_extension_allowed(tool_name: str, candidate: Path) -> Optional[str]: + allowlist = EXTENSION_ALLOWLISTS.get(tool_name) + if not allowlist: + return None + if candidate.suffix.lower() not in allowlist: + allowed = ", ".join(sorted(allowlist)) + return f"File extension is not allowed for {tool_name}. Allowed: {allowed}." + return None + + +def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]: + max_bytes = MAX_FILE_BYTES.get(tool_name) + if not max_bytes: + return None + size = candidate.stat().st_size + if size > max_bytes: + return f"File is too large for {tool_name}: {size} bytes " f"(limit: {max_bytes} bytes)." + return None + + +async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: + if ctx is not None: + try: + list_roots_result = await ctx.session.list_roots() + client_roots: List[Path] = [] + for root in list_roots_result.roots: + try: + client_roots.append(_coerce_root_uri_to_path(str(root.uri))) + except Exception: + # Ignore invalid root entries supplied by a client. + continue + return _dedupe_paths(client_roots) + except (McpError, Exception): + # Fall back to server-side allowlist when roots are unsupported + # or not available in this MCP client session. + pass + return list(SERVER_ALLOWED_ROOTS) + + +async def _ensure_allowed_roots( + ctx: Optional[Context], tool_name: str +) -> tuple[List[Path], Optional[str]]: + roots = await _get_effective_allowed_roots(ctx) + if not roots: + return ( + [], + ( + f"{tool_name} is disabled until allowed roots are configured. " + "Provide server CLI roots and/or client MCP Roots." + ), + ) + return roots, None + + +async def _resolve_readable_file_path( + *, + raw_path: str, + ctx: Optional[Context], + tool_name: str, +) -> tuple[Optional[Path], Optional[str]]: + roots, error = await _ensure_allowed_roots(ctx, tool_name) + if error: + return None, error + + pattern_error = _contains_forbidden_path_patterns(raw_path) + if pattern_error: + return None, pattern_error + + candidate = Path(raw_path.strip()) + if not candidate.is_absolute(): + candidate = _first_resolution_root(roots) / candidate + + try: + candidate = candidate.resolve(strict=True) + except FileNotFoundError: + return None, f"File not found: {raw_path}" + + if not _path_is_within_any_root(candidate, roots): + return None, "Path is outside allowed roots." + if not candidate.is_file(): + return None, f"Path is not a file: {candidate}" + if not os.access(candidate, os.R_OK): + return None, f"File is not readable: {candidate}" + + extension_error = _ensure_extension_allowed(tool_name, candidate) + if extension_error: + return None, extension_error + + size_error = _ensure_size_within_limit(tool_name, candidate) + if size_error: + return None, size_error + + return candidate, None + + +async def _resolve_writable_file_path( + *, + raw_path: Optional[str], + default_filename: str, + ctx: Optional[Context], + tool_name: str, +) -> tuple[Optional[Path], Optional[str]]: + roots, error = await _ensure_allowed_roots(ctx, tool_name) + if error: + return None, error + + if raw_path and raw_path.strip(): + pattern_error = _contains_forbidden_path_patterns(raw_path) + if pattern_error: + return None, pattern_error + candidate = Path(raw_path.strip()) + if not candidate.is_absolute(): + candidate = _first_resolution_root(roots) / candidate + else: + safe_name = Path(default_filename).name + candidate = _first_resolution_root(roots) / DEFAULT_DOWNLOAD_SUBDIR / safe_name + + candidate = candidate.resolve(strict=False) + parent = candidate.parent.resolve(strict=False) + if not _path_is_within_any_root(candidate, roots) or not _path_is_within_any_root( + parent, roots + ): + return None, "Path is outside allowed roots." + + extension_error = _ensure_extension_allowed(tool_name, candidate) + if extension_error: + return None, extension_error + + parent.mkdir(parents=True, exist_ok=True) + if not os.access(parent, os.W_OK): + return None, f"Directory not writable: {parent}" + + return candidate, None + + +def _configure_allowed_roots_from_cli(argv: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser( + prog="telegram-mcp", + add_help=False, + description=( + "Optional positional arguments define server-side allowed roots " + "for file-path tools." + ), + ) + parser.add_argument("allowed_roots", nargs="*") + parsed, _unknown = parser.parse_known_args(argv or []) + + resolved_roots: List[Path] = [] + for raw_root in parsed.allowed_roots: + root = Path(raw_root).expanduser() + if not root.exists(): + raise SystemExit(f"Allowed root does not exist: {root}") + resolved = root.resolve(strict=True) + resolved_roots.append(resolved) + + global SERVER_ALLOWED_ROOTS + SERVER_ALLOWED_ROOTS = _dedupe_paths(resolved_roots) + + @mcp.tool(annotations=ToolAnnotations(title="Get Chats", openWorldHint=True, readOnlyHint=True)) async def get_chats(page: int = 1, page_size: int = 20) -> str: """ @@ -1730,22 +1967,30 @@ async def get_participants(chat_id: Union[int, str]) -> str: @mcp.tool(annotations=ToolAnnotations(title="Send File", openWorldHint=True, destructiveHint=True)) @validate_id("chat_id") -async def send_file(chat_id: Union[int, str], file_path: str, caption: str = None) -> str: +async def send_file( + chat_id: Union[int, str], + file_path: str, + caption: str = None, + ctx: Optional[Context] = None, +) -> str: """ Send a file to a chat. Args: chat_id: The chat ID or username. - file_path: Absolute path to the file to send (must exist and be readable). + file_path: Absolute or relative path to the file under allowed roots. caption: Optional caption for the file. """ try: - if not os.path.isfile(file_path): - return f"File not found: {file_path}" - if not os.access(file_path, os.R_OK): - return f"File is not readable: {file_path}" + safe_path, path_error = await _resolve_readable_file_path( + raw_path=file_path, + ctx=ctx, + tool_name="send_file", + ) + if path_error: + return path_error entity = await client.get_entity(chat_id) - await client.send_file(entity, file_path, caption=caption) - return f"File sent to chat {chat_id}." + await client.send_file(entity, str(safe_path), caption=caption) + return f"File sent to chat {chat_id} from {safe_path}." except Exception as e: return log_and_format_error( "send_file", e, chat_id=chat_id, file_path=file_path, caption=caption @@ -1753,30 +1998,51 @@ async def send_file(chat_id: Union[int, str], file_path: str, caption: str = Non @mcp.tool( - annotations=ToolAnnotations(title="Download Media", openWorldHint=True, readOnlyHint=True) + annotations=ToolAnnotations(title="Download Media", openWorldHint=True, destructiveHint=True) ) @validate_id("chat_id") -async def download_media(chat_id: Union[int, str], message_id: int, file_path: str) -> str: +async def download_media( + chat_id: Union[int, str], + message_id: int, + file_path: Optional[str] = None, + ctx: Optional[Context] = None, +) -> str: """ Download media from a message in a chat. Args: chat_id: The chat ID or username. message_id: The message ID containing the media. - file_path: Absolute path to save the downloaded file (must be writable). + file_path: Optional absolute or relative path under allowed roots. + If omitted, saves into `/downloads/`. """ try: entity = await client.get_entity(chat_id) msg = await client.get_messages(entity, ids=message_id) if not msg or not msg.media: return "No media found in the specified message." - # Check if directory is writable - dir_path = os.path.dirname(file_path) or "." - if not os.access(dir_path, os.W_OK): - return f"Directory not writable: {dir_path}" - await client.download_media(msg, file=file_path) - if not os.path.isfile(file_path): - return f"Download failed: file not created at {file_path}" - return f"Media downloaded to {file_path}." + + default_name = f"telegram_{chat_id}_{message_id}_{int(time.time())}" + out_path, path_error = await _resolve_writable_file_path( + raw_path=file_path, + default_filename=default_name, + ctx=ctx, + tool_name="download_media", + ) + if path_error: + return path_error + + downloaded = await client.download_media(msg, file=str(out_path)) + if not downloaded: + return f"Download failed for message {message_id}." + + final_path = Path(downloaded).resolve(strict=True) + roots, roots_error = await _ensure_allowed_roots(ctx, "download_media") + if roots_error: + return roots_error + if not _path_is_within_any_root(final_path, roots): + return "Download failed: resulting path is outside allowed roots." + + return f"Media downloaded to {final_path}." except Exception as e: return log_and_format_error( "download_media", @@ -1814,15 +2080,24 @@ async def update_profile(first_name: str = None, last_name: str = None, about: s title="Set Profile Photo", openWorldHint=True, destructiveHint=True, idempotentHint=True ) ) -async def set_profile_photo(file_path: str) -> str: +async def set_profile_photo(file_path: str, ctx: Optional[Context] = None) -> str: """ Set a new profile photo. """ try: - await client( - functions.photos.UploadProfilePhotoRequest(file=await client.upload_file(file_path)) + safe_path, path_error = await _resolve_readable_file_path( + raw_path=file_path, + ctx=ctx, + tool_name="set_profile_photo", ) - return "Profile photo updated." + if path_error: + return path_error + await client( + functions.photos.UploadProfilePhotoRequest( + file=await client.upload_file(str(safe_path)) + ) + ) + return f"Profile photo updated from {safe_path}." except Exception as e: return log_and_format_error("set_profile_photo", e, file_path=file_path) @@ -2077,18 +2352,25 @@ async def edit_chat_title(chat_id: Union[int, str], title: str) -> str: ) ) @validate_id("chat_id") -async def edit_chat_photo(chat_id: Union[int, str], file_path: str) -> str: +async def edit_chat_photo( + chat_id: Union[int, str], + file_path: str, + ctx: Optional[Context] = None, +) -> str: """ Edit the photo of a chat, group, or channel. Requires a file path to an image. """ try: - if not os.path.isfile(file_path): - return f"Photo file not found: {file_path}" - if not os.access(file_path, os.R_OK): - return f"Photo file not readable: {file_path}" + safe_path, path_error = await _resolve_readable_file_path( + raw_path=file_path, + ctx=ctx, + tool_name="edit_chat_photo", + ) + if path_error: + return path_error entity = await client.get_entity(chat_id) - uploaded_file = await client.upload_file(file_path) + uploaded_file = await client.upload_file(str(safe_path)) if isinstance(entity, Channel): # For channels/supergroups, use EditPhotoRequest with InputChatUploadedPhoto @@ -2103,7 +2385,7 @@ async def edit_chat_photo(chat_id: Union[int, str], file_path: str) -> str: else: return f"Cannot edit photo for this entity type ({type(entity)})." - return f"Chat {chat_id} photo updated." + return f"Chat {chat_id} photo updated from {safe_path}." except Exception as e: logger.exception(f"edit_chat_photo failed (chat_id={chat_id}, file_path='{file_path}')") return log_and_format_error("edit_chat_photo", e, chat_id=chat_id, file_path=file_path) @@ -2606,38 +2888,76 @@ async def import_chat_invite(hash: str) -> str: annotations=ToolAnnotations(title="Send Voice", openWorldHint=True, destructiveHint=True) ) @validate_id("chat_id") -async def send_voice(chat_id: Union[int, str], file_path: str) -> str: +async def send_voice( + chat_id: Union[int, str], + file_path: str, + ctx: Optional[Context] = None, +) -> str: """ Send a voice message to a chat. File must be an OGG/OPUS voice note. Args: chat_id: The chat ID or username. - file_path: Absolute path to the OGG/OPUS file. + file_path: Absolute or relative path under allowed roots to the OGG/OPUS file. """ try: - if not os.path.isfile(file_path): - return f"File not found: {file_path}" - if not os.access(file_path, os.R_OK): - return f"File is not readable: {file_path}" + safe_path, path_error = await _resolve_readable_file_path( + raw_path=file_path, + ctx=ctx, + tool_name="send_voice", + ) + if path_error: + return path_error - mime, _ = mimetypes.guess_type(file_path) + mime, _ = mimetypes.guess_type(str(safe_path)) if not ( mime and ( mime == "audio/ogg" - or file_path.lower().endswith(".ogg") - or file_path.lower().endswith(".opus") + or str(safe_path).lower().endswith(".ogg") + or str(safe_path).lower().endswith(".opus") ) ): return "Voice file must be .ogg or .opus format." entity = await client.get_entity(chat_id) - await client.send_file(entity, file_path, voice_note=True) - return f"Voice message sent to chat {chat_id}." + await client.send_file(entity, str(safe_path), voice_note=True) + return f"Voice message sent to chat {chat_id} from {safe_path}." except Exception as e: return log_and_format_error("send_voice", e, chat_id=chat_id, file_path=file_path) +@mcp.tool( + annotations=ToolAnnotations(title="Upload File", openWorldHint=True, destructiveHint=True) +) +async def upload_file(file_path: str, ctx: Optional[Context] = None) -> str: + """ + Upload a local file to Telegram and return upload metadata. + + Args: + file_path: Absolute or relative path under allowed roots. + """ + try: + safe_path, path_error = await _resolve_readable_file_path( + raw_path=file_path, + ctx=ctx, + tool_name="upload_file", + ) + if path_error: + return path_error + + uploaded = await client.upload_file(str(safe_path)) + payload = { + "path": str(safe_path), + "name": getattr(uploaded, "name", safe_path.name), + "size": getattr(uploaded, "size", safe_path.stat().st_size), + "md5_checksum": getattr(uploaded, "md5_checksum", None), + } + return json.dumps(payload, indent=2, default=json_serializer) + except Exception as e: + return log_and_format_error("upload_file", e, file_path=file_path) + + @mcp.tool( annotations=ToolAnnotations(title="Forward Message", openWorldHint=True, destructiveHint=True) ) @@ -3003,25 +3323,30 @@ async def get_sticker_sets() -> str: annotations=ToolAnnotations(title="Send Sticker", openWorldHint=True, destructiveHint=True) ) @validate_id("chat_id") -async def send_sticker(chat_id: Union[int, str], file_path: str) -> str: +async def send_sticker( + chat_id: Union[int, str], + file_path: str, + ctx: Optional[Context] = None, +) -> str: """ Send a sticker to a chat. File must be a valid .webp sticker file. Args: chat_id: The chat ID or username. - file_path: Absolute path to the .webp sticker file. + file_path: Absolute or relative path under allowed roots to the .webp sticker file. """ try: - if not os.path.isfile(file_path): - return f"Sticker file not found: {file_path}" - if not os.access(file_path, os.R_OK): - return f"Sticker file is not readable: {file_path}" - if not file_path.lower().endswith(".webp"): - return "Sticker file must be a .webp file." + safe_path, path_error = await _resolve_readable_file_path( + raw_path=file_path, + ctx=ctx, + tool_name="send_sticker", + ) + if path_error: + return path_error entity = await client.get_entity(chat_id) - await client.send_file(entity, file_path, force_document=False) - return f"Sticker sent to chat {chat_id}." + await client.send_file(entity, str(safe_path), force_document=False) + return f"Sticker sent to chat {chat_id} from {safe_path}." except Exception as e: return log_and_format_error("send_sticker", e, chat_id=chat_id, file_path=file_path) @@ -4228,6 +4553,7 @@ async def _main() -> None: def main() -> None: + _configure_allowed_roots_from_cli(sys.argv[1:]) nest_asyncio.apply() asyncio.run(_main()) diff --git a/pyproject.toml b/pyproject.toml index 4fcf097..b52dc17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,4 +64,6 @@ exclude = [ dev = [ "black>=25.9.0", "flake8>=7.3.0", + "pytest>=9.0.2", + "pytest-asyncio>=1.3.0", ] diff --git a/test_file_path_security.py b/test_file_path_security.py new file mode 100644 index 0000000..bc54da6 --- /dev/null +++ b/test_file_path_security.py @@ -0,0 +1,159 @@ +import os +from pathlib import Path + +import pytest +from mcp import types + +os.environ.setdefault("TELEGRAM_API_ID", "12345") +os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash") + +import main + + +class _DummySession: + def __init__(self, roots): + self._roots = roots + + async def list_roots(self): + return types.ListRootsResult(roots=self._roots) + + +class _DummyContext: + def __init__(self, roots): + self.session = _DummySession(roots) + + +@pytest.mark.asyncio +async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch): + root = (tmp_path / "root").resolve() + root.mkdir(parents=True) + target = root / "document.txt" + target.write_text("ok", encoding="utf-8") + + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root]) + + resolved, error = await main._resolve_readable_file_path( + raw_path="document.txt", + ctx=None, + tool_name="send_file", + ) + + assert error is None + assert resolved == target.resolve() + + +@pytest.mark.asyncio +async def test_readable_path_rejects_traversal(tmp_path, monkeypatch): + root = (tmp_path / "root").resolve() + root.mkdir(parents=True) + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root]) + + resolved, error = await main._resolve_readable_file_path( + raw_path="../etc/passwd", + ctx=None, + tool_name="send_file", + ) + + assert resolved is None + assert error == "Path traversal is not allowed." + + +@pytest.mark.asyncio +async def test_readable_path_rejects_outside_root(tmp_path, monkeypatch): + root = (tmp_path / "root").resolve() + outside_root = (tmp_path / "outside").resolve() + root.mkdir(parents=True) + outside_root.mkdir(parents=True) + + outside_file = outside_root / "outside.txt" + outside_file.write_text("no", encoding="utf-8") + + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root]) + + resolved, error = await main._resolve_readable_file_path( + raw_path=str(outside_file), + ctx=None, + tool_name="send_file", + ) + + assert resolved is None + assert error == "Path is outside allowed roots." + + +@pytest.mark.asyncio +async def test_client_roots_replace_server_allowlist(tmp_path, monkeypatch): + server_root = (tmp_path / "server_root").resolve() + client_root = (tmp_path / "client_root").resolve() + server_root.mkdir(parents=True) + client_root.mkdir(parents=True) + + (server_root / "server.txt").write_text("server", encoding="utf-8") + client_file = client_root / "client.txt" + client_file.write_text("client", encoding="utf-8") + + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root]) + ctx = _DummyContext([types.Root(uri=client_root.as_uri())]) + + roots = await main._get_effective_allowed_roots(ctx) + assert roots == [client_root] + + resolved, error = await main._resolve_readable_file_path( + raw_path="client.txt", + ctx=ctx, + tool_name="send_file", + ) + assert error is None + assert resolved == client_file.resolve() + + +@pytest.mark.asyncio +async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch): + root = (tmp_path / "root").resolve() + root.mkdir(parents=True) + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root]) + + resolved, error = await main._resolve_writable_file_path( + raw_path=None, + default_filename="example.bin", + ctx=None, + tool_name="download_media", + ) + + assert error is None + assert resolved == (root / "downloads" / "example.bin").resolve() + assert resolved.parent.exists() + + +@pytest.mark.asyncio +async def test_extension_allowlist_is_enforced_for_sticker(tmp_path, monkeypatch): + root = (tmp_path / "root").resolve() + root.mkdir(parents=True) + file_path = root / "sticker.txt" + file_path.write_text("bad", encoding="utf-8") + + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root]) + + resolved, error = await main._resolve_readable_file_path( + raw_path=str(file_path), + ctx=None, + tool_name="send_sticker", + ) + + assert resolved is None + assert error is not None + assert "extension is not allowed" in error + + +@pytest.mark.asyncio +async def test_file_tools_disabled_without_any_roots(monkeypatch): + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", []) + + resolved, error = await main._resolve_readable_file_path( + raw_path="anything.txt", + ctx=None, + tool_name="send_file", + ) + + assert resolved is None + assert error is not None + assert "disabled" in error diff --git a/uv.lock b/uv.lock index c7729fb..d65eaea 100644 --- a/uv.lock +++ b/uv.lock @@ -40,6 +40,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "black" version = "25.9.0" @@ -394,6 +403,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "jsonschema" version = "4.25.1" @@ -527,6 +545,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "pyaes" version = "1.6.1" @@ -708,6 +735,38 @@ crypto = [ { name = "cryptography", version = "46.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14' and platform_python_implementation != 'PyPy'" }, ] +[[package]] +name = "pytest" +version = "9.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, +] + +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "python-dotenv" version = "1.1.0" @@ -989,6 +1048,8 @@ dependencies = [ dev = [ { name = "black" }, { name = "flake8" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, ] [package.metadata] @@ -1006,6 +1067,8 @@ requires-dist = [ dev = [ { name = "black", specifier = ">=25.9.0" }, { name = "flake8", specifier = ">=7.3.0" }, + { name = "pytest", specifier = ">=9.0.2" }, + { name = "pytest-asyncio", specifier = ">=1.3.0" }, ] [[package]] From e7b92cd89b87bd71cac0760748acc4c96993c317 Mon Sep 17 00:00:00 2001 From: vp Date: Tue, 24 Feb 2026 22:21:08 +0300 Subject: [PATCH 2/6] docs: restore GIF tools removal note in README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 97a0b33..6e35157 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,8 @@ Example server launch with allowlisted roots: uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp ``` +Additionally, GIF-related tools (`get_gif_search`, `get_saved_gifs`, `send_gif`) have been removed due to ongoing issues with reliability in the Telethon library or Telegram API interactions. + --- ## 📋 Requirements From a7df9da9f522473e4a5acbc9d568bff810d9b2ab Mon Sep 17 00:00:00 2001 From: vp Date: Tue, 24 Feb 2026 22:35:30 +0300 Subject: [PATCH 3/6] fix: harden roots fallback semantics and docs --- README.md | 2 +- main.py | 48 +++++++++++++++++--------- test_file_path_security.py | 69 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 6e35157..bd473a7 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ Example server launch with allowlisted roots: uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp ``` -Additionally, GIF-related tools (`get_gif_search`, `get_saved_gifs`, `send_gif`) have been removed due to ongoing issues with reliability in the Telethon library or Telegram API interactions. +GIF tools are currently limited: `get_gif_search` and `send_gif` are available, while `get_saved_gifs` is not implemented due to reliability limits in Telethon/Telegram API interactions. --- diff --git a/main.py b/main.py index 3c123cd..53e8945 100644 --- a/main.py +++ b/main.py @@ -161,6 +161,7 @@ MAX_FILE_BYTES: dict[str, int] = { "set_profile_photo": 50 * 1024 * 1024, "edit_chat_photo": 50 * 1024 * 1024, } +ROOTS_UNSUPPORTED_ERROR_CODES = {-32601} # Error code prefix mapping for better error tracing @@ -462,22 +463,39 @@ def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]: async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: - if ctx is not None: + fallback_roots = list(SERVER_ALLOWED_ROOTS) + if ctx is None: + return fallback_roots + + try: + list_roots_result = await ctx.session.list_roots() + except McpError as error: + error_code = getattr(getattr(error, "error", None), "code", None) + error_message = ( + getattr(getattr(error, "error", None), "message", None) or str(error) + ).lower() + if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message: + # Fallback is allowed only when roots are unsupported in this MCP client session. + return fallback_roots + logger.error("MCP roots request failed; disabling file-path tools for safety.", exc_info=True) + return [] + except Exception: + logger.error("Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True) + return [] + + client_roots: List[Path] = [] + for root in list_roots_result.roots: try: - list_roots_result = await ctx.session.list_roots() - client_roots: List[Path] = [] - for root in list_roots_result.roots: - try: - client_roots.append(_coerce_root_uri_to_path(str(root.uri))) - except Exception: - # Ignore invalid root entries supplied by a client. - continue - return _dedupe_paths(client_roots) - except (McpError, Exception): - # Fall back to server-side allowlist when roots are unsupported - # or not available in this MCP client session. - pass - return list(SERVER_ALLOWED_ROOTS) + client_roots.append(_coerce_root_uri_to_path(str(root.uri))) + except Exception: + # Ignore invalid root entries supplied by a client. + continue + + if client_roots: + return _dedupe_paths(client_roots) + + # If client returned an empty roots list, keep server-side fallback roots. + return fallback_roots async def _ensure_allowed_roots( diff --git a/test_file_path_security.py b/test_file_path_security.py index bc54da6..90f4db1 100644 --- a/test_file_path_security.py +++ b/test_file_path_security.py @@ -3,6 +3,8 @@ from pathlib import Path import pytest from mcp import types +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData os.environ.setdefault("TELEGRAM_API_ID", "12345") os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash") @@ -23,6 +25,19 @@ class _DummyContext: self.session = _DummySession(roots) +class _FailingSession: + def __init__(self, error): + self._error = error + + async def list_roots(self): + raise self._error + + +class _FailingContext: + def __init__(self, error): + self.session = _FailingSession(error) + + @pytest.mark.asyncio async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch): root = (tmp_path / "root").resolve() @@ -106,6 +121,60 @@ async def test_client_roots_replace_server_allowlist(tmp_path, monkeypatch): assert resolved == client_file.resolve() +@pytest.mark.asyncio +async def test_empty_client_roots_fall_back_to_server_allowlist(tmp_path, monkeypatch): + server_root = (tmp_path / "server_root").resolve() + server_root.mkdir(parents=True) + server_file = server_root / "server.txt" + server_file.write_text("server", encoding="utf-8") + + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root]) + ctx = _DummyContext([]) + + roots = await main._get_effective_allowed_roots(ctx) + assert roots == [server_root] + + resolved, error = await main._resolve_readable_file_path( + raw_path="server.txt", + ctx=ctx, + tool_name="send_file", + ) + assert error is None + assert resolved == server_file.resolve() + + +@pytest.mark.asyncio +async def test_mcp_method_not_found_falls_back_to_server_allowlist(tmp_path, monkeypatch): + server_root = (tmp_path / "server_root").resolve() + server_root.mkdir(parents=True) + + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root]) + ctx = _FailingContext(McpError(ErrorData(code=-32601, message="Method not found"))) + + roots = await main._get_effective_allowed_roots(ctx) + assert roots == [server_root] + + +@pytest.mark.asyncio +async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch): + server_root = (tmp_path / "server_root").resolve() + server_root.mkdir(parents=True) + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root]) + + ctx = _FailingContext(RuntimeError("transport failure")) + roots = await main._get_effective_allowed_roots(ctx) + assert roots == [] + + resolved, error = await main._resolve_readable_file_path( + raw_path="anything.txt", + ctx=ctx, + tool_name="send_file", + ) + assert resolved is None + assert error is not None + assert "disabled" in error + + @pytest.mark.asyncio async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch): root = (tmp_path / "root").resolve() From 3b820620701cd0c81cd3066c18aae6e4113d49f9 Mon Sep 17 00:00:00 2001 From: vp Date: Tue, 24 Feb 2026 22:40:30 +0300 Subject: [PATCH 4/6] style: format main.py with black --- main.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 53e8945..20349f0 100644 --- a/main.py +++ b/main.py @@ -477,10 +477,14 @@ async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message: # Fallback is allowed only when roots are unsupported in this MCP client session. return fallback_roots - logger.error("MCP roots request failed; disabling file-path tools for safety.", exc_info=True) + logger.error( + "MCP roots request failed; disabling file-path tools for safety.", exc_info=True + ) return [] except Exception: - logger.error("Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True) + logger.error( + "Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True + ) return [] client_roots: List[Path] = [] From 172fedaf7a7e4145a2e0792715bfc556ce924f18 Mon Sep 17 00:00:00 2001 From: vp Date: Tue, 24 Feb 2026 22:52:23 +0300 Subject: [PATCH 5/6] fix: treat empty MCP roots as explicit deny-all --- README.md | 3 ++- main.py | 4 ++-- test_file_path_security.py | 11 +++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index bd473a7..752d524 100644 --- a/README.md +++ b/README.md @@ -157,8 +157,9 @@ Supported file-path tools: - `send_file`, `download_media`, `set_profile_photo`, `edit_chat_photo`, `send_voice`, `send_sticker`, `upload_file` Security semantics (aligned with MCP filesystem server): -- Server-side allowlist via CLI positional arguments (fallback). +- Server-side allowlist via CLI positional arguments (fallback when Roots API is unsupported). - Client-provided MCP Roots replace the server allowlist when available. +- If the client returns an empty Roots list, file-path tools are disabled (deny-all). - All paths are resolved via realpath and must stay inside an allowed root. - Traversal/glob-like patterns are rejected (`..`, `*`, `?`, `~`, etc.). - Relative paths resolve against the first allowed root. diff --git a/main.py b/main.py index 20349f0..b531556 100644 --- a/main.py +++ b/main.py @@ -498,8 +498,8 @@ async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: if client_roots: return _dedupe_paths(client_roots) - # If client returned an empty roots list, keep server-side fallback roots. - return fallback_roots + # Roots API succeeded; an empty roots list is treated as explicit deny-all. + return [] async def _ensure_allowed_roots( diff --git a/test_file_path_security.py b/test_file_path_security.py index 90f4db1..e8e03a5 100644 --- a/test_file_path_security.py +++ b/test_file_path_security.py @@ -122,25 +122,24 @@ async def test_client_roots_replace_server_allowlist(tmp_path, monkeypatch): @pytest.mark.asyncio -async def test_empty_client_roots_fall_back_to_server_allowlist(tmp_path, monkeypatch): +async def test_empty_client_roots_disable_file_tools(tmp_path, monkeypatch): server_root = (tmp_path / "server_root").resolve() server_root.mkdir(parents=True) - server_file = server_root / "server.txt" - server_file.write_text("server", encoding="utf-8") monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root]) ctx = _DummyContext([]) roots = await main._get_effective_allowed_roots(ctx) - assert roots == [server_root] + assert roots == [] resolved, error = await main._resolve_readable_file_path( raw_path="server.txt", ctx=ctx, tool_name="send_file", ) - assert error is None - assert resolved == server_file.resolve() + assert resolved is None + assert error is not None + assert "disabled" in error @pytest.mark.asyncio From 6ec5c95c916c9db53cc16e309d52d3bd657b0d56 Mon Sep 17 00:00:00 2001 From: vp Date: Tue, 24 Feb 2026 23:02:38 +0300 Subject: [PATCH 6/6] fix: harden roots unsupported fallback and deny-all messaging --- main.py | 75 +++++++++++++++++++++++++++++--------- test_file_path_security.py | 24 +++++++++++- 2 files changed, 80 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index b531556..bcd89ce 100644 --- a/main.py +++ b/main.py @@ -162,6 +162,11 @@ MAX_FILE_BYTES: dict[str, int] = { "edit_chat_photo": 50 * 1024 * 1024, } ROOTS_UNSUPPORTED_ERROR_CODES = {-32601} +ROOTS_STATUS_READY = "ready" +ROOTS_STATUS_NOT_CONFIGURED = "not_configured" +ROOTS_STATUS_UNSUPPORTED_FALLBACK = "unsupported_fallback" +ROOTS_STATUS_CLIENT_DENY_ALL = "client_deny_all" +ROOTS_STATUS_ERROR = "error" # Error code prefix mapping for better error tracing @@ -463,29 +468,47 @@ def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]: async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: - fallback_roots = list(SERVER_ALLOWED_ROOTS) - if ctx is None: - return fallback_roots + roots, _status = await _get_effective_allowed_roots_with_status(ctx) + return roots - try: - list_roots_result = await ctx.session.list_roots() - except McpError as error: + +def _is_roots_unsupported_error(error: Exception) -> bool: + if isinstance(error, McpError): error_code = getattr(getattr(error, "error", None), "code", None) error_message = ( getattr(getattr(error, "error", None), "message", None) or str(error) ).lower() - if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message: - # Fallback is allowed only when roots are unsupported in this MCP client session. - return fallback_roots + if error_code in ROOTS_UNSUPPORTED_ERROR_CODES: + return True + return "method not found" in error_message or "not implemented" in error_message + + if isinstance(error, NotImplementedError): + return True + if isinstance(error, AttributeError): + return "list_roots" in str(error) + return False + + +async def _get_effective_allowed_roots_with_status( + ctx: Optional[Context], +) -> tuple[List[Path], str]: + fallback_roots = list(SERVER_ALLOWED_ROOTS) + if ctx is None: + if fallback_roots: + return fallback_roots, ROOTS_STATUS_READY + return [], ROOTS_STATUS_NOT_CONFIGURED + + try: + list_roots_result = await ctx.session.list_roots() + except Exception as error: + if _is_roots_unsupported_error(error): + if fallback_roots: + return fallback_roots, ROOTS_STATUS_UNSUPPORTED_FALLBACK + return [], ROOTS_STATUS_NOT_CONFIGURED logger.error( "MCP roots request failed; disabling file-path tools for safety.", exc_info=True ) - return [] - except Exception: - logger.error( - "Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True - ) - return [] + return [], ROOTS_STATUS_ERROR client_roots: List[Path] = [] for root in list_roots_result.roots: @@ -496,17 +519,33 @@ async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: continue if client_roots: - return _dedupe_paths(client_roots) + return _dedupe_paths(client_roots), ROOTS_STATUS_READY # Roots API succeeded; an empty roots list is treated as explicit deny-all. - return [] + return [], ROOTS_STATUS_CLIENT_DENY_ALL async def _ensure_allowed_roots( ctx: Optional[Context], tool_name: str ) -> tuple[List[Path], Optional[str]]: - roots = await _get_effective_allowed_roots(ctx) + roots, status = await _get_effective_allowed_roots_with_status(ctx) if not roots: + if status == ROOTS_STATUS_CLIENT_DENY_ALL: + return ( + [], + ( + f"{tool_name} is disabled because the client provided an empty " + "MCP Roots list (deny-all)." + ), + ) + if status == ROOTS_STATUS_ERROR: + return ( + [], + ( + f"{tool_name} is disabled because MCP Roots could not be verified safely. " + "Check MCP client/server logs." + ), + ) return ( [], ( diff --git a/test_file_path_security.py b/test_file_path_security.py index e8e03a5..665c413 100644 --- a/test_file_path_security.py +++ b/test_file_path_security.py @@ -38,6 +38,15 @@ class _FailingContext: self.session = _FailingSession(error) +class _MissingRootsSession: + pass + + +class _MissingRootsContext: + def __init__(self): + self.session = _MissingRootsSession() + + @pytest.mark.asyncio async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch): root = (tmp_path / "root").resolve() @@ -139,7 +148,8 @@ async def test_empty_client_roots_disable_file_tools(tmp_path, monkeypatch): ) assert resolved is None assert error is not None - assert "disabled" in error + assert "empty MCP Roots list" in error + assert "deny-all" in error @pytest.mark.asyncio @@ -154,6 +164,18 @@ async def test_mcp_method_not_found_falls_back_to_server_allowlist(tmp_path, mon assert roots == [server_root] +@pytest.mark.asyncio +async def test_missing_list_roots_method_falls_back_to_server_allowlist(tmp_path, monkeypatch): + server_root = (tmp_path / "server_root").resolve() + server_root.mkdir(parents=True) + + monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root]) + ctx = _MissingRootsContext() + + roots = await main._get_effective_allowed_roots(ctx) + assert roots == [server_root] + + @pytest.mark.asyncio async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch): server_root = (tmp_path / "server_root").resolve()