fix: harden roots fallback semantics and docs
This commit is contained in:
parent
e7b92cd89b
commit
a7df9da9f5
3 changed files with 103 additions and 16 deletions
|
|
@ -169,7 +169,7 @@ Example server launch with allowlisted roots:
|
||||||
uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp
|
uv --directory /full/path/to/telegram-mcp run main.py /data/telegram /tmp/telegram-mcp
|
||||||
```
|
```
|
||||||
|
|
||||||
Additionally, GIF-related tools (`get_gif_search`, `get_saved_gifs`, `send_gif`) have been removed due to ongoing issues with reliability in the Telethon library or Telegram API interactions.
|
GIF tools are currently limited: `get_gif_search` and `send_gif` are available, while `get_saved_gifs` is not implemented due to reliability limits in Telethon/Telegram API interactions.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
|
||||||
48
main.py
48
main.py
|
|
@ -161,6 +161,7 @@ MAX_FILE_BYTES: dict[str, int] = {
|
||||||
"set_profile_photo": 50 * 1024 * 1024,
|
"set_profile_photo": 50 * 1024 * 1024,
|
||||||
"edit_chat_photo": 50 * 1024 * 1024,
|
"edit_chat_photo": 50 * 1024 * 1024,
|
||||||
}
|
}
|
||||||
|
ROOTS_UNSUPPORTED_ERROR_CODES = {-32601}
|
||||||
|
|
||||||
|
|
||||||
# Error code prefix mapping for better error tracing
|
# Error code prefix mapping for better error tracing
|
||||||
|
|
@ -462,22 +463,39 @@ def _ensure_size_within_limit(tool_name: str, candidate: Path) -> Optional[str]:
|
||||||
|
|
||||||
|
|
||||||
async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
|
async def _get_effective_allowed_roots(ctx: Optional[Context]) -> List[Path]:
|
||||||
if ctx is not None:
|
fallback_roots = list(SERVER_ALLOWED_ROOTS)
|
||||||
|
if ctx is None:
|
||||||
|
return fallback_roots
|
||||||
|
|
||||||
|
try:
|
||||||
|
list_roots_result = await ctx.session.list_roots()
|
||||||
|
except McpError as error:
|
||||||
|
error_code = getattr(getattr(error, "error", None), "code", None)
|
||||||
|
error_message = (
|
||||||
|
getattr(getattr(error, "error", None), "message", None) or str(error)
|
||||||
|
).lower()
|
||||||
|
if error_code in ROOTS_UNSUPPORTED_ERROR_CODES or "method not found" in error_message:
|
||||||
|
# Fallback is allowed only when roots are unsupported in this MCP client session.
|
||||||
|
return fallback_roots
|
||||||
|
logger.error("MCP roots request failed; disabling file-path tools for safety.", exc_info=True)
|
||||||
|
return []
|
||||||
|
except Exception:
|
||||||
|
logger.error("Unexpected MCP roots failure; disabling file-path tools for safety.", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
client_roots: List[Path] = []
|
||||||
|
for root in list_roots_result.roots:
|
||||||
try:
|
try:
|
||||||
list_roots_result = await ctx.session.list_roots()
|
client_roots.append(_coerce_root_uri_to_path(str(root.uri)))
|
||||||
client_roots: List[Path] = []
|
except Exception:
|
||||||
for root in list_roots_result.roots:
|
# Ignore invalid root entries supplied by a client.
|
||||||
try:
|
continue
|
||||||
client_roots.append(_coerce_root_uri_to_path(str(root.uri)))
|
|
||||||
except Exception:
|
if client_roots:
|
||||||
# Ignore invalid root entries supplied by a client.
|
return _dedupe_paths(client_roots)
|
||||||
continue
|
|
||||||
return _dedupe_paths(client_roots)
|
# If client returned an empty roots list, keep server-side fallback roots.
|
||||||
except (McpError, Exception):
|
return fallback_roots
|
||||||
# Fall back to server-side allowlist when roots are unsupported
|
|
||||||
# or not available in this MCP client session.
|
|
||||||
pass
|
|
||||||
return list(SERVER_ALLOWED_ROOTS)
|
|
||||||
|
|
||||||
|
|
||||||
async def _ensure_allowed_roots(
|
async def _ensure_allowed_roots(
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from mcp import types
|
from mcp import types
|
||||||
|
from mcp.shared.exceptions import McpError
|
||||||
|
from mcp.types import ErrorData
|
||||||
|
|
||||||
os.environ.setdefault("TELEGRAM_API_ID", "12345")
|
os.environ.setdefault("TELEGRAM_API_ID", "12345")
|
||||||
os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash")
|
os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash")
|
||||||
|
|
@ -23,6 +25,19 @@ class _DummyContext:
|
||||||
self.session = _DummySession(roots)
|
self.session = _DummySession(roots)
|
||||||
|
|
||||||
|
|
||||||
|
class _FailingSession:
|
||||||
|
def __init__(self, error):
|
||||||
|
self._error = error
|
||||||
|
|
||||||
|
async def list_roots(self):
|
||||||
|
raise self._error
|
||||||
|
|
||||||
|
|
||||||
|
class _FailingContext:
|
||||||
|
def __init__(self, error):
|
||||||
|
self.session = _FailingSession(error)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch):
|
async def test_readable_relative_path_resolves_inside_first_server_root(tmp_path, monkeypatch):
|
||||||
root = (tmp_path / "root").resolve()
|
root = (tmp_path / "root").resolve()
|
||||||
|
|
@ -106,6 +121,60 @@ async def test_client_roots_replace_server_allowlist(tmp_path, monkeypatch):
|
||||||
assert resolved == client_file.resolve()
|
assert resolved == client_file.resolve()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_client_roots_fall_back_to_server_allowlist(tmp_path, monkeypatch):
|
||||||
|
server_root = (tmp_path / "server_root").resolve()
|
||||||
|
server_root.mkdir(parents=True)
|
||||||
|
server_file = server_root / "server.txt"
|
||||||
|
server_file.write_text("server", encoding="utf-8")
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
|
||||||
|
ctx = _DummyContext([])
|
||||||
|
|
||||||
|
roots = await main._get_effective_allowed_roots(ctx)
|
||||||
|
assert roots == [server_root]
|
||||||
|
|
||||||
|
resolved, error = await main._resolve_readable_file_path(
|
||||||
|
raw_path="server.txt",
|
||||||
|
ctx=ctx,
|
||||||
|
tool_name="send_file",
|
||||||
|
)
|
||||||
|
assert error is None
|
||||||
|
assert resolved == server_file.resolve()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_method_not_found_falls_back_to_server_allowlist(tmp_path, monkeypatch):
|
||||||
|
server_root = (tmp_path / "server_root").resolve()
|
||||||
|
server_root.mkdir(parents=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
|
||||||
|
ctx = _FailingContext(McpError(ErrorData(code=-32601, message="Method not found")))
|
||||||
|
|
||||||
|
roots = await main._get_effective_allowed_roots(ctx)
|
||||||
|
assert roots == [server_root]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch):
|
||||||
|
server_root = (tmp_path / "server_root").resolve()
|
||||||
|
server_root.mkdir(parents=True)
|
||||||
|
monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_root])
|
||||||
|
|
||||||
|
ctx = _FailingContext(RuntimeError("transport failure"))
|
||||||
|
roots = await main._get_effective_allowed_roots(ctx)
|
||||||
|
assert roots == []
|
||||||
|
|
||||||
|
resolved, error = await main._resolve_readable_file_path(
|
||||||
|
raw_path="anything.txt",
|
||||||
|
ctx=ctx,
|
||||||
|
tool_name="send_file",
|
||||||
|
)
|
||||||
|
assert resolved is None
|
||||||
|
assert error is not None
|
||||||
|
assert "disabled" in error
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch):
|
async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch):
|
||||||
root = (tmp_path / "root").resolve()
|
root = (tmp_path / "root").resolve()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue