feat: secure file-path tools with allowlisted roots

- implement server-side allowlist via CLI positional roots (fallback)\n- implement client MCP Roots override semantics (replace server roots when available)\n- add realpath + in-root validation, traversal/glob rejection, extension and size checks\n- make write path default to <first_root>/downloads when file_path is omitted\n- reintroduce upload_file tool with the same path security model\n- update README with security model and usage\n- add tests for root resolution, replacement semantics, traversal checks, and default write path\n- add pytest and pytest-asyncio to dev dependencies
This commit is contained in:
vp 2026-02-24 22:12:56 +03:00
parent 594e27e53a
commit fd9c90e53c
5 changed files with 626 additions and 55 deletions

View file

@ -101,12 +101,19 @@ This MCP server exposes a huge suite of Telegram tools. **Every major Telegram/T
### User & Profile
- **get_me()**: Get your user info
- **update_profile(first_name, last_name, about)**: Update your profile
- **set_profile_photo(file_path)**: Set a profile photo from an allowed root path
- **delete_profile_photo()**: Remove your profile photo
- **get_user_photos(user_id, limit)**: Get a user's profile photos
- **get_user_status(user_id)**: Get a user's online status
### Media
- **get_media_info(chat_id, message_id)**: Get info about media in a message
- **send_file(chat_id, file_path, caption)**: Send a local file from allowed roots
- **download_media(chat_id, message_id, file_path)**: Save message media under allowed roots
- **upload_file(file_path)**: Upload a local file and return upload metadata
- **send_voice(chat_id, file_path)**: Send `.ogg/.opus` voice note from allowed roots
- **send_sticker(chat_id, file_path)**: Send `.webp` sticker from allowed roots
- **edit_chat_photo(chat_id, file_path)**: Update chat photo from allowed roots
### Search & Discovery
- **search_public_chats(query)**: Search public chats/channels/bots
@ -142,11 +149,25 @@ To improve robustness, all functions accepting `chat_id` or `user_id` parameters
The server will automatically validate the input and convert it to the correct format before making a request to Telegram. If the input is invalid, a clear error message will be returned.
## Removed Functionality
## File-path Tools Security Model
Please note that tools requiring direct file path access on the server (`send_file`, `download_media`, `set_profile_photo`, `edit_chat_photo`, `send_voice`, `send_sticker`, `upload_file`) have been removed from `main.py`. This is due to limitations in the current MCP environment regarding handling file attachments and local file system paths.
File-path tools are available, but **disabled by default** until allowed roots are configured.
Additionally, GIF-related tools (`get_gif_search`, `get_saved_gifs`, `send_gif`) have been removed due to ongoing issues with reliability in the Telethon library or Telegram API interactions.
Supported file-path tools:
- `send_file`, `download_media`, `set_profile_photo`, `edit_chat_photo`, `send_voice`, `send_sticker`, `upload_file`
Security semantics (aligned with MCP filesystem server):
- Server-side allowlist via CLI positional arguments (fallback).
- Client-provided MCP Roots replace the server allowlist when available.
- All paths are resolved via realpath and must stay inside an allowed root.
- Traversal/glob-like patterns are rejected (`..`, `*`, `?`, `~`, etc.).
- Relative paths resolve against the first allowed root.
- Write tools default to `<first_root>/downloads/` when `file_path` is omitted.
Example server launch with allowlisted roots:
```bash
uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp
```
---

430
main.py
View file

@ -1,3 +1,4 @@
import argparse
import os
import sys
import json
@ -9,12 +10,15 @@ import mimetypes
from datetime import datetime, timedelta
from enum import Enum
from typing import List, Dict, Optional, Union, Any
from pathlib import Path
from urllib.parse import unquote, urlparse
# Third-party libraries
import nest_asyncio
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp import FastMCP, Context
from mcp.types import ToolAnnotations
from mcp.shared.exceptions import McpError
from pythonjsonlogger import jsonlogger
from telethon import TelegramClient, functions, types, utils
from telethon.sessions import StringSession
@ -139,6 +143,26 @@ except Exception as log_error:
logger.error(f"Failed to set up log file handler: {log_error}")
# File-path tool security configuration
SERVER_ALLOWED_ROOTS: list[Path] = []
DEFAULT_DOWNLOAD_SUBDIR = "downloads"
DISALLOWED_PATH_PATTERNS = ("*", "?", "[", "]", "{", "}", "~", "\x00")
EXTENSION_ALLOWLISTS: dict[str, set[str]] = {
"send_voice": {".ogg", ".opus"},
"send_sticker": {".webp"},
"set_profile_photo": {".jpg", ".jpeg", ".png", ".webp"},
"edit_chat_photo": {".jpg", ".jpeg", ".png", ".webp"},
}
MAX_FILE_BYTES: dict[str, int] = {
"send_file": 200 * 1024 * 1024, # 200 MB
"upload_file": 200 * 1024 * 1024,
"send_voice": 100 * 1024 * 1024,
"send_sticker": 10 * 1024 * 1024,
"set_profile_photo": 50 * 1024 * 1024,
"edit_chat_photo": 50 * 1024 * 1024,
}
# Error code prefix mapping for better error tracing
class ErrorCategory(str, Enum):
CHAT = "CHAT"
@ -363,6 +387,219 @@ def get_engagement_info(message) -> str:
return f" | {', '.join(engagement_parts)}" if engagement_parts else ""
def _dedupe_paths(paths: List[Path]) -> List[Path]:
seen: set[str] = set()
result: List[Path] = []
for path in paths:
key = str(path)
if key in seen:
continue
seen.add(key)
result.append(path)
return result
def _contains_forbidden_path_patterns(raw_path: str) -> Optional[str]:
value = raw_path.strip()
if not value:
return "Path must not be empty."
if any(token in value for token in DISALLOWED_PATH_PATTERNS):
return "Path contains disallowed wildcard/shell patterns."
if ".." in Path(value).parts:
return "Path traversal is not allowed."
return None
def _coerce_root_uri_to_path(uri: str) -> Path:
parsed = urlparse(uri)
if parsed.scheme != "file":
raise ValueError(f"Unsupported root URI scheme: {parsed.scheme}")
decoded_path = unquote(parsed.path or "")
if parsed.netloc and parsed.netloc not in ("", "localhost"):
decoded_path = f"//{parsed.netloc}{decoded_path}"
if os.name == "nt" and decoded_path.startswith("/") and len(decoded_path) > 2:
# file:///C:/tmp -> C:/tmp on Windows
if decoded_path[2] == ":":
decoded_path = decoded_path[1:]
return Path(decoded_path).resolve(strict=True)
def _path_is_within_root(candidate: Path, root: Path) -> bool:
root = root.resolve()
if root.is_file():
return candidate == root
return candidate == root or root in candidate.parents
def _path_is_within_any_root(candidate: Path, roots: List[Path]) -> bool:
return any(_path_is_within_root(candidate, root) for root in roots)
def _first_resolution_root(roots: List[Path]) -> Path:
first = roots[0]
return first if first.is_dir() else first.parent
def _ensure_extension_allowed(tool_name: str, candidate: Path) -> Optional[str]:
allowlist = EXTENSION_ALLOWLISTS.get(tool_name)
if not allowlist:
return None
if candidate.suffix.lower() not in allowlist:
allowed = ", ".join(sorted(allowlist))
return f"File extension is not allowed for {tool_name}. Allowed: {allowed}."
return None
def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]:
max_bytes = MAX_FILE_BYTES.get(tool_name)
if not max_bytes:
return None
size = candidate.stat().st_size
if size > max_bytes:
return f"File is too large for {tool_name}: {size} bytes " f"(limit: {max_bytes} bytes)."
return None
async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
if ctx is not None:
try:
list_roots_result = await ctx.session.list_roots()
client_roots: List[Path] = []
for root in list_roots_result.roots:
try:
client_roots.append(_coerce_root_uri_to_path(str(root.uri)))
except Exception:
# Ignore invalid root entries supplied by a client.
continue
return _dedupe_paths(client_roots)
except (McpError, Exception):
# Fall back to server-side allowlist when roots are unsupported
# or not available in this MCP client session.
pass
return list(SERVER_ALLOWED_ROOTS)
async def _ensure_allowed_roots(
ctx: Optional[Context], tool_name: str
) -> tuple[List[Path], Optional[str]]:
roots = await _get_effective_allowed_roots(ctx)
if not roots:
return (
[],
(
f"{tool_name} is disabled until allowed roots are configured. "
"Provide server CLI roots and/or client MCP Roots."
),
)
return roots, None
async def _resolve_readable_file_path(
*,
raw_path: str,
ctx: Optional[Context],
tool_name: str,
) -> tuple[Optional[Path], Optional[str]]:
roots, error = await _ensure_allowed_roots(ctx, tool_name)
if error:
return None, error
pattern_error = _contains_forbidden_path_patterns(raw_path)
if pattern_error:
return None, pattern_error
candidate = Path(raw_path.strip())
if not candidate.is_absolute():
candidate = _first_resolution_root(roots) / candidate
try:
candidate = candidate.resolve(strict=True)
except FileNotFoundError:
return None, f"File not found: {raw_path}"
if not _path_is_within_any_root(candidate, roots):
return None, "Path is outside allowed roots."
if not candidate.is_file():
return None, f"Path is not a file: {candidate}"
if not os.access(candidate, os.R_OK):
return None, f"File is not readable: {candidate}"
extension_error = _ensure_extension_allowed(tool_name, candidate)
if extension_error:
return None, extension_error
size_error = _ensure_size_within_limit(tool_name, candidate)
if size_error:
return None, size_error
return candidate, None
async def _resolve_writable_file_path(
*,
raw_path: Optional[str],
default_filename: str,
ctx: Optional[Context],
tool_name: str,
) -> tuple[Optional[Path], Optional[str]]:
roots, error = await _ensure_allowed_roots(ctx, tool_name)
if error:
return None, error
if raw_path and raw_path.strip():
pattern_error = _contains_forbidden_path_patterns(raw_path)
if pattern_error:
return None, pattern_error
candidate = Path(raw_path.strip())
if not candidate.is_absolute():
candidate = _first_resolution_root(roots) / candidate
else:
safe_name = Path(default_filename).name
candidate = _first_resolution_root(roots) / DEFAULT_DOWNLOAD_SUBDIR / safe_name
candidate = candidate.resolve(strict=False)
parent = candidate.parent.resolve(strict=False)
if not _path_is_within_any_root(candidate, roots) or not _path_is_within_any_root(
parent, roots
):
return None, "Path is outside allowed roots."
extension_error = _ensure_extension_allowed(tool_name, candidate)
if extension_error:
return None, extension_error
parent.mkdir(parents=True, exist_ok=True)
if not os.access(parent, os.W_OK):
return None, f"Directory not writable: {parent}"
return candidate, None
def _configure_allowed_roots_from_cli(argv: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser(
prog="telegram-mcp",
add_help=False,
description=(
"Optional positional arguments define server-side allowed roots "
"for file-path tools."
),
)
parser.add_argument("allowed_roots", nargs="*")
parsed, _unknown = parser.parse_known_args(argv or [])
resolved_roots: List[Path] = []
for raw_root in parsed.allowed_roots:
root = Path(raw_root).expanduser()
if not root.exists():
raise SystemExit(f"Allowed root does not exist: {root}")
resolved = root.resolve(strict=True)
resolved_roots.append(resolved)
global SERVER_ALLOWED_ROOTS
SERVER_ALLOWED_ROOTS = _dedupe_paths(resolved_roots)
@mcp.tool(annotations=ToolAnnotations(title="Get Chats", openWorldHint=True, readOnlyHint=True))
async def get_chats(page: int = 1, page_size: int = 20) -> str:
"""
@ -1730,22 +1967,30 @@ async def get_participants(chat_id: Union[int, str]) -> str:
@mcp.tool(annotations=ToolAnnotations(title="Send File", openWorldHint=True, destructiveHint=True))
@validate_id("chat_id")
async def send_file(chat_id: Union[int, str], file_path: str, caption: str = None) -> str:
async def send_file(
chat_id: Union[int, str],
file_path: str,
caption: str = None,
ctx: Optional[Context] = None,
) -> str:
"""
Send a file to a chat.
Args:
chat_id: The chat ID or username.
file_path: Absolute path to the file to send (must exist and be readable).
file_path: Absolute or relative path to the file under allowed roots.
caption: Optional caption for the file.
"""
try:
if not os.path.isfile(file_path):
return f"File not found: {file_path}"
if not os.access(file_path, os.R_OK):
return f"File is not readable: {file_path}"
safe_path, path_error = await _resolve_readable_file_path(
raw_path=file_path,
ctx=ctx,
tool_name="send_file",
)
if path_error:
return path_error
entity = await client.get_entity(chat_id)
await client.send_file(entity, file_path, caption=caption)
return f"File sent to chat {chat_id}."
await client.send_file(entity, str(safe_path), caption=caption)
return f"File sent to chat {chat_id} from {safe_path}."
except Exception as e:
return log_and_format_error(
"send_file", e, chat_id=chat_id, file_path=file_path, caption=caption
@ -1753,30 +1998,51 @@ async def send_file(chat_id: Union[int, str], file_path: str, caption: str = Non
@mcp.tool(
annotations=ToolAnnotations(title="Download Media", openWorldHint=True, readOnlyHint=True)
annotations=ToolAnnotations(title="Download Media", openWorldHint=True, destructiveHint=True)
)
@validate_id("chat_id")
async def download_media(chat_id: Union[int, str], message_id: int, file_path: str) -> str:
async def download_media(
chat_id: Union[int, str],
message_id: int,
file_path: Optional[str] = None,
ctx: Optional[Context] = None,
) -> str:
"""
Download media from a message in a chat.
Args:
chat_id: The chat ID or username.
message_id: The message ID containing the media.
file_path: Absolute path to save the downloaded file (must be writable).
file_path: Optional absolute or relative path under allowed roots.
If omitted, saves into `<first_root>/downloads/`.
"""
try:
entity = await client.get_entity(chat_id)
msg = await client.get_messages(entity, ids=message_id)
if not msg or not msg.media:
return "No media found in the specified message."
# Check if directory is writable
dir_path = os.path.dirname(file_path) or "."
if not os.access(dir_path, os.W_OK):
return f"Directory not writable: {dir_path}"
await client.download_media(msg, file=file_path)
if not os.path.isfile(file_path):
return f"Download failed: file not created at {file_path}"
return f"Media downloaded to {file_path}."
default_name = f"telegram_{chat_id}_{message_id}_{int(time.time())}"
out_path, path_error = await _resolve_writable_file_path(
raw_path=file_path,
default_filename=default_name,
ctx=ctx,
tool_name="download_media",
)
if path_error:
return path_error
downloaded = await client.download_media(msg, file=str(out_path))
if not downloaded:
return f"Download failed for message {message_id}."
final_path = Path(downloaded).resolve(strict=True)
roots, roots_error = await _ensure_allowed_roots(ctx, "download_media")
if roots_error:
return roots_error
if not _path_is_within_any_root(final_path, roots):
return "Download failed: resulting path is outside allowed roots."
return f"Media downloaded to {final_path}."
except Exception as e:
return log_and_format_error(
"download_media",
@ -1814,15 +2080,24 @@ async def update_profile(first_name: str = None, last_name: str = None, about: s
title="Set Profile Photo", openWorldHint=True, destructiveHint=True, idempotentHint=True
)
)
async def set_profile_photo(file_path: str) -> str:
async def set_profile_photo(file_path: str, ctx: Optional[Context] = None) -> str:
"""
Set a new profile photo.
"""
try:
await client(
functions.photos.UploadProfilePhotoRequest(file=await client.upload_file(file_path))
safe_path, path_error = await _resolve_readable_file_path(
raw_path=file_path,
ctx=ctx,
tool_name="set_profile_photo",
)
return "Profile photo updated."
if path_error:
return path_error
await client(
functions.photos.UploadProfilePhotoRequest(
file=await client.upload_file(str(safe_path))
)
)
return f"Profile photo updated from {safe_path}."
except Exception as e:
return log_and_format_error("set_profile_photo", e, file_path=file_path)
@ -2077,18 +2352,25 @@ async def edit_chat_title(chat_id: Union[int, str], title: str) -> str:
)
)
@validate_id("chat_id")
async def edit_chat_photo(chat_id: Union[int, str], file_path: str) -> str:
async def edit_chat_photo(
chat_id: Union[int, str],
file_path: str,
ctx: Optional[Context] = None,
) -> str:
"""
Edit the photo of a chat, group, or channel. Requires a file path to an image.
"""
try:
if not os.path.isfile(file_path):
return f"Photo file not found: {file_path}"
if not os.access(file_path, os.R_OK):
return f"Photo file not readable: {file_path}"
safe_path, path_error = await _resolve_readable_file_path(
raw_path=file_path,
ctx=ctx,
tool_name="edit_chat_photo",
)
if path_error:
return path_error
entity = await client.get_entity(chat_id)
uploaded_file = await client.upload_file(file_path)
uploaded_file = await client.upload_file(str(safe_path))
if isinstance(entity, Channel):
# For channels/supergroups, use EditPhotoRequest with InputChatUploadedPhoto
@ -2103,7 +2385,7 @@ async def edit_chat_photo(chat_id: Union[int, str], file_path: str) -> str:
else:
return f"Cannot edit photo for this entity type ({type(entity)})."
return f"Chat {chat_id} photo updated."
return f"Chat {chat_id} photo updated from {safe_path}."
except Exception as e:
logger.exception(f"edit_chat_photo failed (chat_id={chat_id}, file_path='{file_path}')")
return log_and_format_error("edit_chat_photo", e, chat_id=chat_id, file_path=file_path)
@ -2606,38 +2888,76 @@ async def import_chat_invite(hash: str) -> str:
annotations=ToolAnnotations(title="Send Voice", openWorldHint=True, destructiveHint=True)
)
@validate_id("chat_id")
async def send_voice(chat_id: Union[int, str], file_path: str) -> str:
async def send_voice(
chat_id: Union[int, str],
file_path: str,
ctx: Optional[Context] = None,
) -> str:
"""
Send a voice message to a chat. File must be an OGG/OPUS voice note.
Args:
chat_id: The chat ID or username.
file_path: Absolute path to the OGG/OPUS file.
file_path: Absolute or relative path under allowed roots to the OGG/OPUS file.
"""
try:
if not os.path.isfile(file_path):
return f"File not found: {file_path}"
if not os.access(file_path, os.R_OK):
return f"File is not readable: {file_path}"
safe_path, path_error = await _resolve_readable_file_path(
raw_path=file_path,
ctx=ctx,
tool_name="send_voice",
)
if path_error:
return path_error
mime, _ = mimetypes.guess_type(file_path)
mime, _ = mimetypes.guess_type(str(safe_path))
if not (
mime
and (
mime == "audio/ogg"
or file_path.lower().endswith(".ogg")
or file_path.lower().endswith(".opus")
or str(safe_path).lower().endswith(".ogg")
or str(safe_path).lower().endswith(".opus")
)
):
return "Voice file must be .ogg or .opus format."
entity = await client.get_entity(chat_id)
await client.send_file(entity, file_path, voice_note=True)
return f"Voice message sent to chat {chat_id}."
await client.send_file(entity, str(safe_path), voice_note=True)
return f"Voice message sent to chat {chat_id} from {safe_path}."
except Exception as e:
return log_and_format_error("send_voice", e, chat_id=chat_id, file_path=file_path)
@mcp.tool(
annotations=ToolAnnotations(title="Upload File", openWorldHint=True, destructiveHint=True)
)
async def upload_file(file_path: str, ctx: Optional[Context] = None) -> str:
"""
Upload a local file to Telegram and return upload metadata.
Args:
file_path: Absolute or relative path under allowed roots.
"""
try:
safe_path, path_error = await _resolve_readable_file_path(
raw_path=file_path,
ctx=ctx,
tool_name="upload_file",
)
if path_error:
return path_error
uploaded = await client.upload_file(str(safe_path))
payload = {
"path": str(safe_path),
"name": getattr(uploaded, "name", safe_path.name),
"size": getattr(uploaded, "size", safe_path.stat().st_size),
"md5_checksum": getattr(uploaded, "md5_checksum", None),
}
return json.dumps(payload, indent=2, default=json_serializer)
except Exception as e:
return log_and_format_error("upload_file", e, file_path=file_path)
@mcp.tool(
annotations=ToolAnnotations(title="Forward Message", openWorldHint=True, destructiveHint=True)
)
@ -3003,25 +3323,30 @@ async def get_sticker_sets() -> str:
annotations=ToolAnnotations(title="Send Sticker", openWorldHint=True, destructiveHint=True)
)
@validate_id("chat_id")
async def send_sticker(chat_id: Union[int, str], file_path: str) -> str:
async def send_sticker(
chat_id: Union[int, str],
file_path: str,
ctx: Optional[Context] = None,
) -> str:
"""
Send a sticker to a chat. File must be a valid .webp sticker file.
Args:
chat_id: The chat ID or username.
file_path: Absolute path to the .webp sticker file.
file_path: Absolute or relative path under allowed roots to the .webp sticker file.
"""
try:
if not os.path.isfile(file_path):
return f"Sticker file not found: {file_path}"
if not os.access(file_path, os.R_OK):
return f"Sticker file is not readable: {file_path}"
if not file_path.lower().endswith(".webp"):
return "Sticker file must be a .webp file."
safe_path, path_error = await _resolve_readable_file_path(
raw_path=file_path,
ctx=ctx,
tool_name="send_sticker",
)
if path_error:
return path_error
entity = await client.get_entity(chat_id)
await client.send_file(entity, file_path, force_document=False)
return f"Sticker sent to chat {chat_id}."
await client.send_file(entity, str(safe_path), force_document=False)
return f"Sticker sent to chat {chat_id} from {safe_path}."
except Exception as e:
return log_and_format_error("send_sticker", e, chat_id=chat_id, file_path=file_path)
@ -4228,6 +4553,7 @@ async def _main() -> None:
def main() -> None:
_configure_allowed_roots_from_cli(sys.argv[1:])
nest_asyncio.apply()
asyncio.run(_main())

View file

@ -64,4 +64,6 @@ exclude = [
dev = [
"black>=25.9.0",
"flake8>=7.3.0",
"pytest>=9.0.2",
"pytest-asyncio>=1.3.0",
]

159
test_file_path_security.py Normal file
View file

@ -0,0 +1,159 @@
import os
from pathlib import Path
import pytest
from mcp import types
os.environ.setdefault("TELEGRAM_API_ID", "12345")
os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash")
import main
class _DummySession:
def __init__(self, roots):
self._roots = roots
async def list_roots(self):
return types.ListRootsResult(roots=self._roots)
class _DummyContext:
def __init__(self, roots):
self.session = _DummySession(roots)
@pytest.mark.asyncio
async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve()
root.mkdir(parents=True)
target = root / "document.txt"
target.write_text("ok", encoding="utf-8")
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root])
resolved, error = await main._resolve_readable_file_path(
raw_path="document.txt",
ctx=None,
tool_name="send_file",
)
assert error is None
assert resolved == target.resolve()
@pytest.mark.asyncio
async def test_readable_path_rejects_traversal(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve()
root.mkdir(parents=True)
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root])
resolved, error = await main._resolve_readable_file_path(
raw_path="../etc/passwd",
ctx=None,
tool_name="send_file",
)
assert resolved is None
assert error == "Path traversal is not allowed."
@pytest.mark.asyncio
async def test_readable_path_rejects_outside_root(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve()
outside_root = (tmp_path / "outside").resolve()
root.mkdir(parents=True)
outside_root.mkdir(parents=True)
outside_file = outside_root / "outside.txt"
outside_file.write_text("no", encoding="utf-8")
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root])
resolved, error = await main._resolve_readable_file_path(
raw_path=str(outside_file),
ctx=None,
tool_name="send_file",
)
assert resolved is None
assert error == "Path is outside allowed roots."
@pytest.mark.asyncio
async def test_client_roots_replace_server_allowlist(tmp_path, monkeypatch):
server_root = (tmp_path / "server_root").resolve()
client_root = (tmp_path / "client_root").resolve()
server_root.mkdir(parents=True)
client_root.mkdir(parents=True)
(server_root / "server.txt").write_text("server", encoding="utf-8")
client_file = client_root / "client.txt"
client_file.write_text("client", encoding="utf-8")
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
ctx = _DummyContext([types.Root(uri=client_root.as_uri())])
roots = await main._get_effective_allowed_roots(ctx)
assert roots == [client_root]
resolved, error = await main._resolve_readable_file_path(
raw_path="client.txt",
ctx=ctx,
tool_name="send_file",
)
assert error is None
assert resolved == client_file.resolve()
@pytest.mark.asyncio
async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve()
root.mkdir(parents=True)
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root])
resolved, error = await main._resolve_writable_file_path(
raw_path=None,
default_filename="example.bin",
ctx=None,
tool_name="download_media",
)
assert error is None
assert resolved == (root / "downloads" / "example.bin").resolve()
assert resolved.parent.exists()
@pytest.mark.asyncio
async def test_extension_allowlist_is_enforced_for_sticker(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve()
root.mkdir(parents=True)
file_path = root / "sticker.txt"
file_path.write_text("bad", encoding="utf-8")
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [root])
resolved, error = await main._resolve_readable_file_path(
raw_path=str(file_path),
ctx=None,
tool_name="send_sticker",
)
assert resolved is None
assert error is not None
assert "extension is not allowed" in error
@pytest.mark.asyncio
async def test_file_tools_disabled_without_any_roots(monkeypatch):
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [])
resolved, error = await main._resolve_readable_file_path(
raw_path="anything.txt",
ctx=None,
tool_name="send_file",
)
assert resolved is None
assert error is not None
assert "disabled" in error

63
uv.lock
View file

@ -40,6 +40,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" },
]
[[package]]
name = "backports-asyncio-runner"
version = "1.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" },
]
[[package]]
name = "black"
version = "25.9.0"
@ -394,6 +403,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" },
]
[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
[[package]]
name = "jsonschema"
version = "4.25.1"
@ -527,6 +545,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pyaes"
version = "1.6.1"
@ -708,6 +735,38 @@ crypto = [
{ name = "cryptography", version = "46.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14' and platform_python_implementation != 'PyPy'" },
]
[[package]]
name = "pytest"
version = "9.0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
]
[[package]]
name = "pytest-asyncio"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" },
{ name = "pytest" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" },
]
[[package]]
name = "python-dotenv"
version = "1.1.0"
@ -989,6 +1048,8 @@ dependencies = [
dev = [
{ name = "black" },
{ name = "flake8" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
]
[package.metadata]
@ -1006,6 +1067,8 @@ requires-dist = [
dev = [
{ name = "black", specifier = ">=25.9.0" },
{ name = "flake8", specifier = ">=7.3.0" },
{ name = "pytest", specifier = ">=9.0.2" },
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
]
[[package]]