fix: harden roots fallback semantics and docs
This commit is contained in:
parent
e7b92cd89b
commit
a7df9da9f5
3 changed files with 103 additions and 16 deletions
|
|
@ -169,7 +169,7 @@ Example server launch with allowlisted roots:
|
|||
uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp
|
||||
```
|
||||
|
||||
Additionally, GIF-related tools (`get_gif_search`, `get_saved_gifs`, `send_gif`) have been removed due to ongoing issues with reliability in the Telethon library or Telegram API interactions.
|
||||
GIF tools are currently limited: `get_gif_search` and `send_gif` are available, while `get_saved_gifs` is not implemented due to reliability limits in Telethon/Telegram API interactions.
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
30
main.py
30
main.py
|
|
@ -161,6 +161,7 @@ MAX_FILE_BYTES: dict[str, int] = {
|
|||
"set_profile_photo": 50 * 1024 * 1024,
|
||||
"edit_chat_photo": 50 * 1024 * 1024,
|
||||
}
|
||||
ROOTS_UNSUPPORTED_ERROR_CODES = {-32601}
|
||||
|
||||
|
||||
# Error code prefix mapping for better error tracing
|
||||
|
|
@ -462,9 +463,26 @@ def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]:
|
|||
|
||||
|
||||
async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
|
||||
if ctx is not None:
|
||||
fallback_roots = list(SERVER_ALLOWED_ROOTS)
|
||||
if ctx is None:
|
||||
return fallback_roots
|
||||
|
||||
try:
|
||||
list_roots_result = await ctx.session.list_roots()
|
||||
except McpError as error:
|
||||
error_code = getattr(getattr(error, "error", None), "code", None)
|
||||
error_message = (
|
||||
getattr(getattr(error, "error", None), "message", None) or str(error)
|
||||
).lower()
|
||||
if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message:
|
||||
# Fallback is allowed only when roots are unsupported in this MCP client session.
|
||||
return fallback_roots
|
||||
logger.error("MCP roots request failed; disabling file-path tools for safety.", exc_info=True)
|
||||
return []
|
||||
except Exception:
|
||||
logger.error("Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True)
|
||||
return []
|
||||
|
||||
client_roots: List[Path] = []
|
||||
for root in list_roots_result.roots:
|
||||
try:
|
||||
|
|
@ -472,12 +490,12 @@ async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
|
|||
except Exception:
|
||||
# Ignore invalid root entries supplied by a client.
|
||||
continue
|
||||
|
||||
if client_roots:
|
||||
return _dedupe_paths(client_roots)
|
||||
except (McpError, Exception):
|
||||
# Fall back to server-side allowlist when roots are unsupported
|
||||
# or not available in this MCP client session.
|
||||
pass
|
||||
return list(SERVER_ALLOWED_ROOTS)
|
||||
|
||||
# If client returned an empty roots list, keep server-side fallback roots.
|
||||
return fallback_roots
|
||||
|
||||
|
||||
async def _ensure_allowed_roots(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
from mcp import types
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import ErrorData
|
||||
|
||||
os.environ.setdefault("TELEGRAM_API_ID", "12345")
|
||||
os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash")
|
||||
|
|
@ -23,6 +25,19 @@ class _DummyContext:
|
|||
self.session = _DummySession(roots)
|
||||
|
||||
|
||||
class _FailingSession:
|
||||
def __init__(self, error):
|
||||
self._error = error
|
||||
|
||||
async def list_roots(self):
|
||||
raise self._error
|
||||
|
||||
|
||||
class _FailingContext:
|
||||
def __init__(self, error):
|
||||
self.session = _FailingSession(error)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch):
|
||||
root = (tmp_path / "root").resolve()
|
||||
|
|
@ -106,6 +121,60 @@ async def test_client_roots_replace_server_allowlist(tmp_path, monkeypatch):
|
|||
assert resolved == client_file.resolve()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_client_roots_fall_back_to_server_allowlist(tmp_path, monkeypatch):
|
||||
server_root = (tmp_path / "server_root").resolve()
|
||||
server_root.mkdir(parents=True)
|
||||
server_file = server_root / "server.txt"
|
||||
server_file.write_text("server", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
|
||||
ctx = _DummyContext([])
|
||||
|
||||
roots = await main._get_effective_allowed_roots(ctx)
|
||||
assert roots == [server_root]
|
||||
|
||||
resolved, error = await main._resolve_readable_file_path(
|
||||
raw_path="server.txt",
|
||||
ctx=ctx,
|
||||
tool_name="send_file",
|
||||
)
|
||||
assert error is None
|
||||
assert resolved == server_file.resolve()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_method_not_found_falls_back_to_server_allowlist(tmp_path, monkeypatch):
|
||||
server_root = (tmp_path / "server_root").resolve()
|
||||
server_root.mkdir(parents=True)
|
||||
|
||||
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
|
||||
ctx = _FailingContext(McpError(ErrorData(code=-32601, message="Method not found")))
|
||||
|
||||
roots = await main._get_effective_allowed_roots(ctx)
|
||||
assert roots == [server_root]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch):
|
||||
server_root = (tmp_path / "server_root").resolve()
|
||||
server_root.mkdir(parents=True)
|
||||
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
|
||||
|
||||
ctx = _FailingContext(RuntimeError("transport failure"))
|
||||
roots = await main._get_effective_allowed_roots(ctx)
|
||||
assert roots == []
|
||||
|
||||
resolved, error = await main._resolve_readable_file_path(
|
||||
raw_path="anything.txt",
|
||||
ctx=ctx,
|
||||
tool_name="send_file",
|
||||
)
|
||||
assert resolved is None
|
||||
assert error is not None
|
||||
assert "disabled" in error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch):
|
||||
root = (tmp_path / "root").resolve()
|
||||
|
|
|
|||
Loading…
Reference in a new issue