fix: harden roots fallback semantics and docs

This commit is contained in:
vp 2026-02-24 22:35:30 +03:00
parent e7b92cd89b
commit a7df9da9f5
3 changed files with 103 additions and 16 deletions

View file

@ -169,7 +169,7 @@ Example server launch with allowlisted roots:
uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp
``` ```
Additionally, GIF-related tools (`get_gif_search`, `get_saved_gifs`, `send_gif`) have been removed due to ongoing issues with reliability in the Telethon library or Telegram API interactions. GIF tools are currently limited: `get_gif_search` and `send_gif` are available, while `get_saved_gifs` is not implemented due to reliability limits in Telethon/Telegram API interactions.
--- ---

48
main.py
View file

@ -161,6 +161,7 @@ MAX_FILE_BYTES: dict[str, int] = {
"set_profile_photo": 50 * 1024 * 1024, "set_profile_photo": 50 * 1024 * 1024,
"edit_chat_photo": 50 * 1024 * 1024, "edit_chat_photo": 50 * 1024 * 1024,
} }
ROOTS_UNSUPPORTED_ERROR_CODES = {-32601}
# Error code prefix mapping for better error tracing # Error code prefix mapping for better error tracing
@ -462,22 +463,39 @@ def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]:
async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]: async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
if ctx is not None: fallback_roots = list(SERVER_ALLOWED_ROOTS)
if ctx is None:
return fallback_roots
try:
list_roots_result = await ctx.session.list_roots()
except McpError as error:
error_code = getattr(getattr(error, "error", None), "code", None)
error_message = (
getattr(getattr(error, "error", None), "message", None) or str(error)
).lower()
if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message:
# Fallback is allowed only when roots are unsupported in this MCP client session.
return fallback_roots
logger.error("MCP roots request failed; disabling file-path tools for safety.", exc_info=True)
return []
except Exception:
logger.error("Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True)
return []
client_roots: List[Path] = []
for root in list_roots_result.roots:
try: try:
list_roots_result = await ctx.session.list_roots() client_roots.append(_coerce_root_uri_to_path(str(root.uri)))
client_roots: List[Path] = [] except Exception:
for root in list_roots_result.roots: # Ignore invalid root entries supplied by a client.
try: continue
client_roots.append(_coerce_root_uri_to_path(str(root.uri)))
except Exception: if client_roots:
# Ignore invalid root entries supplied by a client. return _dedupe_paths(client_roots)
continue
return _dedupe_paths(client_roots) # If client returned an empty roots list, keep server-side fallback roots.
except (McpError, Exception): return fallback_roots
# Fall back to server-side allowlist when roots are unsupported
# or not available in this MCP client session.
pass
return list(SERVER_ALLOWED_ROOTS)
async def _ensure_allowed_roots( async def _ensure_allowed_roots(

View file

@ -3,6 +3,8 @@ from pathlib import Path
import pytest import pytest
from mcp import types from mcp import types
from mcp.shared.exceptions import McpError
from mcp.types import ErrorData
os.environ.setdefault("TELEGRAM_API_ID", "12345") os.environ.setdefault("TELEGRAM_API_ID", "12345")
os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash") os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash")
@ -23,6 +25,19 @@ class _DummyContext:
self.session = _DummySession(roots) self.session = _DummySession(roots)
class _FailingSession:
def __init__(self, error):
self._error = error
async def list_roots(self):
raise self._error
class _FailingContext:
def __init__(self, error):
self.session = _FailingSession(error)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch): async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve() root = (tmp_path / "root").resolve()
@ -106,6 +121,60 @@ async def test_client_roots_replace_server_allowlist(tmp_path, monkeypatch):
assert resolved == client_file.resolve() assert resolved == client_file.resolve()
@pytest.mark.asyncio
async def test_empty_client_roots_fall_back_to_server_allowlist(tmp_path, monkeypatch):
server_root = (tmp_path / "server_root").resolve()
server_root.mkdir(parents=True)
server_file = server_root / "server.txt"
server_file.write_text("server", encoding="utf-8")
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
ctx = _DummyContext([])
roots = await main._get_effective_allowed_roots(ctx)
assert roots == [server_root]
resolved, error = await main._resolve_readable_file_path(
raw_path="server.txt",
ctx=ctx,
tool_name="send_file",
)
assert error is None
assert resolved == server_file.resolve()
@pytest.mark.asyncio
async def test_mcp_method_not_found_falls_back_to_server_allowlist(tmp_path, monkeypatch):
server_root = (tmp_path / "server_root").resolve()
server_root.mkdir(parents=True)
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
ctx = _FailingContext(McpError(ErrorData(code=-32601, message="Method not found")))
roots = await main._get_effective_allowed_roots(ctx)
assert roots == [server_root]
@pytest.mark.asyncio
async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch):
server_root = (tmp_path / "server_root").resolve()
server_root.mkdir(parents=True)
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
ctx = _FailingContext(RuntimeError("transport failure"))
roots = await main._get_effective_allowed_roots(ctx)
assert roots == []
resolved, error = await main._resolve_readable_file_path(
raw_path="anything.txt",
ctx=ctx,
tool_name="send_file",
)
assert resolved is None
assert error is not None
assert "disabled" in error
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch): async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch):
root = (tmp_path / "root").resolve() root = (tmp_path / "root").resolve()