feat: Implement input validation decorator for chat/user IDs
Adds a new `@validate_id` decorator to validate `chat_id` and `user_id` parameters in `main.py`. This decorator ensures that all IDs passed to functions are in a valid format before making RPC calls to the Telegram API. It handles: - Integer IDs (positive and negative) - String representations of integer IDs - Usernames (e.g., "@username") - Lists of IDs Key changes: - Created a `validate_id` decorator in `main.py`. - Applied the decorator to all functions that accept `chat_id`, `user_id`, or similar parameters. - Updated the central `log_and_format_error` function to handle custom validation error messages and a specific `VALIDATION-001` error code for logging. - Added a new test suite (`test_validation.py`) with comprehensive tests for the decorator. - Updated `README.md` and `session_string_generator.py` with documentation about the new validation.
This commit is contained in:
parent
9464b4de75
commit
34bdd58905
5 changed files with 316 additions and 85 deletions
|
|
@ -1 +1 @@
|
|||
3.13
|
||||
3.12.11
|
||||
|
|
|
|||
10
README.md
10
README.md
|
|
@ -116,6 +116,16 @@ This MCP server exposes a huge suite of Telegram tools. **Every major Telegram/T
|
|||
- **unarchive_chat(chat_id)**: Unarchive a chat
|
||||
- **get_recent_actions(chat_id)**: Get recent admin actions
|
||||
|
||||
### Input Validation
|
||||
|
||||
To improve robustness, all functions accepting `chat_id` or `user_id` parameters now include input validation. You can use any of the following formats for these IDs:
|
||||
|
||||
- **Integer ID**: The direct integer ID for a user, chat, or channel (e.g., `123456789` or `-1001234567890`).
|
||||
- **String ID**: The integer ID provided as a string (e.g., `"123456789"`).
|
||||
- **Username**: The public username for a user or channel (e.g., `"@username"` or `"username"`).
|
||||
|
||||
The server will automatically validate the input and convert it to the correct format before making a request to Telegram. If the input is invalid, a clear error message will be returned.
|
||||
|
||||
## Removed Functionality
|
||||
|
||||
Please note that tools requiring direct file path access on the server (`send_file`, `download_media`, `set_profile_photo`, `edit_chat_photo`, `send_voice`, `send_sticker`, `upload_file`) have been removed from `main.py`. This is due to limitations in the current MCP environment regarding handling file attachments and local file system paths.
|
||||
|
|
|
|||
301
main.py
301
main.py
|
|
@ -30,9 +30,16 @@ from telethon.tl.types import (
|
|||
InputPeerChat,
|
||||
InputPeerChannel,
|
||||
)
|
||||
import re
|
||||
from functools import wraps
|
||||
import telethon.errors.rpcerrorlist
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
"""Custom exception for validation errors."""
|
||||
pass
|
||||
|
||||
|
||||
def json_serializer(obj):
|
||||
"""Helper function to convert non-serializable objects for JSON serialization."""
|
||||
if isinstance(obj, datetime):
|
||||
|
|
@ -109,7 +116,7 @@ ERROR_PREFIXES = {
|
|||
|
||||
|
||||
def log_and_format_error(
|
||||
function_name: str, error: Exception, prefix: str = None, **kwargs
|
||||
function_name: str, error: Exception, prefix: str = None, user_message: str = None, **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Centralized error handling function that logs the error and returns a formatted user-friendly message.
|
||||
|
|
@ -118,13 +125,16 @@ def log_and_format_error(
|
|||
function_name: Name of the function where error occurred
|
||||
error: The exception that was raised
|
||||
prefix: Error code prefix (e.g., "CHAT", "MSG") - if None, will be derived from function_name
|
||||
user_message: A custom user-facing message to return. If None, a generic one is created.
|
||||
**kwargs: Additional context parameters to include in log
|
||||
|
||||
Returns:
|
||||
A user-friendly error message with error code
|
||||
"""
|
||||
# Generate a consistent error code
|
||||
if prefix is None:
|
||||
if prefix == "VALIDATION-001":
|
||||
error_code = prefix
|
||||
elif prefix is None:
|
||||
# Try to derive prefix from function name
|
||||
for key, value in ERROR_PREFIXES.items():
|
||||
if key in function_name.lower():
|
||||
|
|
@ -132,19 +142,93 @@ def log_and_format_error(
|
|||
break
|
||||
if prefix is None:
|
||||
prefix = "GEN" # Generic prefix if none matches
|
||||
error_code = f"{prefix}-ERR-{abs(hash(function_name)) % 1000:03d}"
|
||||
else:
|
||||
error_code = f"{prefix}-ERR-{abs(hash(function_name)) % 1000:03d}"
|
||||
|
||||
error_code = f"{prefix}-ERR-{abs(hash(function_name)) % 1000:03d}"
|
||||
|
||||
# Format the additional context parameters
|
||||
context = ", ".join(f"{k}={v}" for k, v in kwargs.items())
|
||||
|
||||
# Log the full technical error
|
||||
logger.exception(f"{function_name} failed ({context}): {error}")
|
||||
logger.error(f"Error in {function_name} ({context}) - Code: {error_code}", exc_info=True)
|
||||
|
||||
# Return a user-friendly message
|
||||
if user_message:
|
||||
return user_message
|
||||
|
||||
return f"An error occurred (code: {error_code}). Check mcp_errors.log for details."
|
||||
|
||||
|
||||
def validate_id(*param_names_to_validate):
|
||||
"""
|
||||
Decorator to validate chat_id and user_id parameters, including lists of IDs.
|
||||
It checks for valid integer ranges, string representations of integers,
|
||||
and username formats.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
for param_name in param_names_to_validate:
|
||||
if param_name not in kwargs or kwargs[param_name] is None:
|
||||
continue
|
||||
|
||||
param_value = kwargs[param_name]
|
||||
|
||||
def validate_single_id(value, p_name):
|
||||
# Handle integer IDs
|
||||
if isinstance(value, int):
|
||||
if not (-2**63 <= value <= 2**63 - 1):
|
||||
return None, f"Invalid {p_name}: {value}. ID is out of the valid integer range."
|
||||
return value, None
|
||||
|
||||
# Handle string IDs
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
int_value = int(value)
|
||||
if not (-2**63 <= int_value <= 2**63 - 1):
|
||||
return None, f"Invalid {p_name}: {value}. ID is out of the valid integer range."
|
||||
return int_value, None
|
||||
except ValueError:
|
||||
if re.match(r'^@?[a-zA-Z0-9_]{5,}$', value):
|
||||
return value, None
|
||||
else:
|
||||
return None, f"Invalid {p_name}: '{value}'. Must be a valid integer ID, or a username string."
|
||||
|
||||
# Handle other invalid types
|
||||
return None, f"Invalid {p_name}: {value}. Type must be an integer or a string."
|
||||
|
||||
if isinstance(param_value, list):
|
||||
validated_list = []
|
||||
for item in param_value:
|
||||
validated_item, error_msg = validate_single_id(item, param_name)
|
||||
if error_msg:
|
||||
return log_and_format_error(
|
||||
func.__name__,
|
||||
ValidationError(error_msg),
|
||||
prefix="VALIDATION-001",
|
||||
user_message=error_msg,
|
||||
**{param_name: param_value}
|
||||
)
|
||||
validated_list.append(validated_item)
|
||||
kwargs[param_name] = validated_list
|
||||
else:
|
||||
validated_value, error_msg = validate_single_id(param_value, param_name)
|
||||
if error_msg:
|
||||
return log_and_format_error(
|
||||
func.__name__,
|
||||
ValidationError(error_msg),
|
||||
prefix="VALIDATION-001",
|
||||
user_message=error_msg,
|
||||
**{param_name: param_value}
|
||||
)
|
||||
kwargs[param_name] = validated_value
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def format_entity(entity) -> Dict[str, Any]:
|
||||
"""Helper function to format entity information consistently."""
|
||||
result = {"id": entity.id}
|
||||
|
|
@ -231,11 +315,12 @@ async def get_chats(page: int = 1, page_size: int = 20) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_messages(chat_id: int, page: int = 1, page_size: int = 20) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_messages(chat_id: Union[int, str], page: int = 1, page_size: int = 20) -> str:
|
||||
"""
|
||||
Get paginated messages from a specific chat.
|
||||
Args:
|
||||
chat_id: The ID of the chat.
|
||||
chat_id: The ID or username of the chat.
|
||||
page: Page number (1-indexed).
|
||||
page_size: Number of messages per page.
|
||||
"""
|
||||
|
|
@ -262,11 +347,12 @@ async def get_messages(chat_id: int, page: int = 1, page_size: int = 20) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def send_message(chat_id: int, message: str) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def send_message(chat_id: Union[int, str], message: str) -> str:
|
||||
"""
|
||||
Send a message to a specific chat.
|
||||
Args:
|
||||
chat_id: The ID of the chat.
|
||||
chat_id: The ID or username of the chat.
|
||||
message: The message content to send.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -346,8 +432,9 @@ async def get_contact_ids() -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
@validate_id('chat_id')
|
||||
async def list_messages(
|
||||
chat_id: int,
|
||||
chat_id: Union[int, str],
|
||||
limit: int = 20,
|
||||
search_query: str = None,
|
||||
from_date: str = None,
|
||||
|
|
@ -357,7 +444,7 @@ async def list_messages(
|
|||
Retrieve messages with optional filters.
|
||||
|
||||
Args:
|
||||
chat_id: The ID of the chat to get messages from.
|
||||
chat_id: The ID or username of the chat to get messages from.
|
||||
limit: Maximum number of messages to retrieve.
|
||||
search_query: Filter messages containing this text.
|
||||
from_date: Filter messages starting from this date (format: YYYY-MM-DD).
|
||||
|
|
@ -501,12 +588,13 @@ async def list_chats(chat_type: str = None, limit: int = 20) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_chat(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_chat(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Get detailed information about a specific chat.
|
||||
|
||||
Args:
|
||||
chat_id: The ID of the chat.
|
||||
chat_id: The ID or username of the chat.
|
||||
"""
|
||||
try:
|
||||
entity = await client.get_entity(chat_id)
|
||||
|
|
@ -636,12 +724,13 @@ async def get_direct_chat_by_contact(contact_query: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_contact_chats(contact_id: int) -> str:
|
||||
@validate_id('contact_id')
|
||||
async def get_contact_chats(contact_id: Union[int, str]) -> str:
|
||||
"""
|
||||
List all chats involving a specific contact.
|
||||
|
||||
Args:
|
||||
contact_id: The ID of the contact.
|
||||
contact_id: The ID or username of the contact.
|
||||
"""
|
||||
try:
|
||||
# Get contact info
|
||||
|
|
@ -688,12 +777,13 @@ async def get_contact_chats(contact_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_last_interaction(contact_id: int) -> str:
|
||||
@validate_id('contact_id')
|
||||
async def get_last_interaction(contact_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Get the most recent message with a contact.
|
||||
|
||||
Args:
|
||||
contact_id: The ID of the contact.
|
||||
contact_id: The ID or username of the contact.
|
||||
"""
|
||||
try:
|
||||
# Get contact info
|
||||
|
|
@ -724,12 +814,13 @@ async def get_last_interaction(contact_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_message_context(chat_id: int, message_id: int, context_size: int = 3) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_message_context(chat_id: Union[int, str], message_id: int, context_size: int = 3) -> str:
|
||||
"""
|
||||
Retrieve context around a specific message.
|
||||
|
||||
Args:
|
||||
chat_id: The ID of the chat.
|
||||
chat_id: The ID or username of the chat.
|
||||
message_id: The ID of the central message.
|
||||
context_size: Number of messages before and after to include.
|
||||
"""
|
||||
|
|
@ -841,11 +932,12 @@ async def add_contact(phone: str, first_name: str, last_name: str = "") -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def delete_contact(user_id: int) -> str:
|
||||
@validate_id('user_id')
|
||||
async def delete_contact(user_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Delete a contact by user ID.
|
||||
Args:
|
||||
user_id: The Telegram user ID of the contact to delete.
|
||||
user_id: The Telegram user ID or username of the contact to delete.
|
||||
"""
|
||||
try:
|
||||
user = await client.get_entity(user_id)
|
||||
|
|
@ -856,11 +948,12 @@ async def delete_contact(user_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def block_user(user_id: int) -> str:
|
||||
@validate_id('user_id')
|
||||
async def block_user(user_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Block a user by user ID.
|
||||
Args:
|
||||
user_id: The Telegram user ID to block.
|
||||
user_id: The Telegram user ID or username to block.
|
||||
"""
|
||||
try:
|
||||
user = await client.get_entity(user_id)
|
||||
|
|
@ -871,11 +964,12 @@ async def block_user(user_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def unblock_user(user_id: int) -> str:
|
||||
@validate_id('user_id')
|
||||
async def unblock_user(user_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Unblock a user by user ID.
|
||||
Args:
|
||||
user_id: The Telegram user ID to unblock.
|
||||
user_id: The Telegram user ID or username to unblock.
|
||||
"""
|
||||
try:
|
||||
user = await client.get_entity(user_id)
|
||||
|
|
@ -898,13 +992,14 @@ async def get_me() -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def create_group(title: str, user_ids: list) -> str:
|
||||
@validate_id('user_ids')
|
||||
async def create_group(title: str, user_ids: List[Union[int, str]]) -> str:
|
||||
"""
|
||||
Create a new group or supergroup and add users.
|
||||
|
||||
Args:
|
||||
title: Title for the new group
|
||||
user_ids: List of user IDs to add to the group
|
||||
user_ids: List of user IDs or usernames to add to the group
|
||||
"""
|
||||
try:
|
||||
# Convert user IDs to entities
|
||||
|
|
@ -956,13 +1051,14 @@ async def create_group(title: str, user_ids: list) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def invite_to_group(group_id: int, user_ids: list) -> str:
|
||||
@validate_id('group_id', 'user_ids')
|
||||
async def invite_to_group(group_id: Union[int, str], user_ids: List[Union[int, str]]) -> str:
|
||||
"""
|
||||
Invite users to a group or channel.
|
||||
|
||||
Args:
|
||||
group_id: The ID of the group/channel.
|
||||
user_ids: List of user IDs to invite.
|
||||
group_id: The ID or username of the group/channel.
|
||||
user_ids: List of user IDs or usernames to invite.
|
||||
"""
|
||||
try:
|
||||
entity = await client.get_entity(group_id)
|
||||
|
|
@ -1005,12 +1101,13 @@ async def invite_to_group(group_id: int, user_ids: list) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def leave_chat(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def leave_chat(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Leave a group or channel by chat ID.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID to leave.
|
||||
chat_id: The chat ID or username to leave.
|
||||
"""
|
||||
try:
|
||||
entity = await client.get_entity(chat_id)
|
||||
|
|
@ -1084,11 +1181,12 @@ async def leave_chat(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_participants(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_participants(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
List all participants in a group or channel.
|
||||
Args:
|
||||
chat_id: The group or channel ID.
|
||||
chat_id: The group or channel ID or username.
|
||||
"""
|
||||
try:
|
||||
participants = await client.get_participants(chat_id)
|
||||
|
|
@ -1102,11 +1200,12 @@ async def get_participants(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def send_file(chat_id: int, file_path: str, caption: str = None) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def send_file(chat_id: Union[int, str], file_path: str, caption: str = None) -> str:
|
||||
"""
|
||||
Send a file to a chat.
|
||||
Args:
|
||||
chat_id: The chat ID.
|
||||
chat_id: The chat ID or username.
|
||||
file_path: Absolute path to the file to send (must exist and be readable).
|
||||
caption: Optional caption for the file.
|
||||
"""
|
||||
|
|
@ -1125,11 +1224,12 @@ async def send_file(chat_id: int, file_path: str, caption: str = None) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def download_media(chat_id: int, message_id: int, file_path: str) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def download_media(chat_id: Union[int, str], message_id: int, file_path: str) -> str:
|
||||
"""
|
||||
Download media from a message in a chat.
|
||||
Args:
|
||||
chat_id: The chat ID.
|
||||
chat_id: The chat ID or username.
|
||||
message_id: The message ID containing the media.
|
||||
file_path: Absolute path to save the downloaded file (must be writable).
|
||||
"""
|
||||
|
|
@ -1226,16 +1326,17 @@ async def get_privacy_settings() -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
@validate_id('allow_users', 'disallow_users')
|
||||
async def set_privacy_settings(
|
||||
key: str, allow_users: list = None, disallow_users: list = None
|
||||
key: str, allow_users: Optional[List[Union[int, str]]] = None, disallow_users: Optional[List[Union[int, str]]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Set privacy settings (e.g., last seen, phone, etc.).
|
||||
|
||||
Args:
|
||||
key: The privacy setting to modify ('status' for last seen, 'phone', 'profile_photo', etc.)
|
||||
allow_users: List of user IDs to allow
|
||||
disallow_users: List of user IDs to disallow
|
||||
allow_users: List of user IDs or usernames to allow
|
||||
disallow_users: List of user IDs or usernames to disallow
|
||||
"""
|
||||
try:
|
||||
# Import needed types
|
||||
|
|
@ -1382,7 +1483,8 @@ async def create_channel(title: str, about: str = "", megagroup: bool = False) -
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def edit_chat_title(chat_id: int, title: str) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def edit_chat_title(chat_id: Union[int, str], title: str) -> str:
|
||||
"""
|
||||
Edit the title of a chat, group, or channel.
|
||||
"""
|
||||
|
|
@ -1401,7 +1503,8 @@ async def edit_chat_title(chat_id: int, title: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def edit_chat_photo(chat_id: int, file_path: str) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def edit_chat_photo(chat_id: Union[int, str], file_path: str) -> str:
|
||||
"""
|
||||
Edit the photo of a chat, group, or channel. Requires a file path to an image.
|
||||
"""
|
||||
|
|
@ -1434,7 +1537,8 @@ async def edit_chat_photo(chat_id: int, file_path: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def delete_chat_photo(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def delete_chat_photo(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Delete the photo of a chat, group, or channel.
|
||||
"""
|
||||
|
|
@ -1462,13 +1566,14 @@ async def delete_chat_photo(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def promote_admin(group_id: int, user_id: int, rights: dict = None) -> str:
|
||||
@validate_id('group_id', 'user_id')
|
||||
async def promote_admin(group_id: Union[int, str], user_id: Union[int, str], rights: dict = None) -> str:
|
||||
"""
|
||||
Promote a user to admin in a group/channel.
|
||||
|
||||
Args:
|
||||
group_id: ID of the group/channel
|
||||
user_id: User ID to promote
|
||||
group_id: ID or username of the group/channel
|
||||
user_id: User ID or username to promote
|
||||
rights: Admin rights to give (optional)
|
||||
"""
|
||||
try:
|
||||
|
|
@ -1526,13 +1631,14 @@ async def promote_admin(group_id: int, user_id: int, rights: dict = None) -> str
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def demote_admin(group_id: int, user_id: int) -> str:
|
||||
@validate_id('group_id', 'user_id')
|
||||
async def demote_admin(group_id: Union[int, str], user_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Demote a user from admin in a group/channel.
|
||||
|
||||
Args:
|
||||
group_id: ID of the group/channel
|
||||
user_id: User ID to demote
|
||||
group_id: ID or username of the group/channel
|
||||
user_id: User ID or username to demote
|
||||
"""
|
||||
try:
|
||||
chat = await client.get_entity(group_id)
|
||||
|
|
@ -1574,13 +1680,14 @@ async def demote_admin(group_id: int, user_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def ban_user(chat_id: int, user_id: int) -> str:
|
||||
@validate_id('chat_id', 'user_id')
|
||||
async def ban_user(chat_id: Union[int, str], user_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Ban a user from a group or channel.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the group/channel
|
||||
user_id: User ID to ban
|
||||
chat_id: ID or username of the group/channel
|
||||
user_id: User ID or username to ban
|
||||
"""
|
||||
try:
|
||||
chat = await client.get_entity(chat_id)
|
||||
|
|
@ -1620,13 +1727,14 @@ async def ban_user(chat_id: int, user_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def unban_user(chat_id: int, user_id: int) -> str:
|
||||
@validate_id('chat_id', 'user_id')
|
||||
async def unban_user(chat_id: Union[int, str], user_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Unban a user from a group or channel.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the group/channel
|
||||
user_id: User ID to unban
|
||||
chat_id: ID or username of the group/channel
|
||||
user_id: User ID or username to unban
|
||||
"""
|
||||
try:
|
||||
chat = await client.get_entity(chat_id)
|
||||
|
|
@ -1666,7 +1774,8 @@ async def unban_user(chat_id: int, user_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_admins(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_admins(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Get all admins in a group or channel.
|
||||
"""
|
||||
|
|
@ -1684,7 +1793,8 @@ async def get_admins(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_banned_users(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_banned_users(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Get all banned users in a group or channel.
|
||||
"""
|
||||
|
|
@ -1704,7 +1814,8 @@ async def get_banned_users(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_invite_link(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_invite_link(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Get the invite link for a group or channel.
|
||||
"""
|
||||
|
|
@ -1807,7 +1918,8 @@ async def join_chat_by_link(link: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def export_chat_invite(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def export_chat_invite(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Export a chat invite link.
|
||||
"""
|
||||
|
|
@ -1896,11 +2008,12 @@ async def import_chat_invite(hash: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def send_voice(chat_id: int, file_path: str) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def send_voice(chat_id: Union[int, str], file_path: str) -> str:
|
||||
"""
|
||||
Send a voice message to a chat. File must be an OGG/OPUS voice note.
|
||||
Args:
|
||||
chat_id: The chat ID.
|
||||
chat_id: The chat ID or username.
|
||||
file_path: Absolute path to the OGG/OPUS file.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -1926,7 +2039,8 @@ async def send_voice(chat_id: int, file_path: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def forward_message(from_chat_id: int, message_id: int, to_chat_id: int) -> str:
|
||||
@validate_id('from_chat_id', 'to_chat_id')
|
||||
async def forward_message(from_chat_id: Union[int, str], message_id: int, to_chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Forward a message from one chat to another.
|
||||
"""
|
||||
|
|
@ -1946,7 +2060,8 @@ async def forward_message(from_chat_id: int, message_id: int, to_chat_id: int) -
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def edit_message(chat_id: int, message_id: int, new_text: str) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def edit_message(chat_id: Union[int, str], message_id: int, new_text: str) -> str:
|
||||
"""
|
||||
Edit a message you sent.
|
||||
"""
|
||||
|
|
@ -1961,7 +2076,8 @@ async def edit_message(chat_id: int, message_id: int, new_text: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def delete_message(chat_id: int, message_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def delete_message(chat_id: Union[int, str], message_id: int) -> str:
|
||||
"""
|
||||
Delete a message by ID.
|
||||
"""
|
||||
|
|
@ -1974,7 +2090,8 @@ async def delete_message(chat_id: int, message_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def pin_message(chat_id: int, message_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def pin_message(chat_id: Union[int, str], message_id: int) -> str:
|
||||
"""
|
||||
Pin a message in a chat.
|
||||
"""
|
||||
|
|
@ -1987,7 +2104,8 @@ async def pin_message(chat_id: int, message_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def unpin_message(chat_id: int, message_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def unpin_message(chat_id: Union[int, str], message_id: int) -> str:
|
||||
"""
|
||||
Unpin a message in a chat.
|
||||
"""
|
||||
|
|
@ -2000,7 +2118,8 @@ async def unpin_message(chat_id: int, message_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def mark_as_read(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def mark_as_read(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Mark all messages as read in a chat.
|
||||
"""
|
||||
|
|
@ -2013,7 +2132,8 @@ async def mark_as_read(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def reply_to_message(chat_id: int, message_id: int, text: str) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def reply_to_message(chat_id: Union[int, str], message_id: int, text: str) -> str:
|
||||
"""
|
||||
Reply to a specific message in a chat.
|
||||
"""
|
||||
|
|
@ -2028,11 +2148,12 @@ async def reply_to_message(chat_id: int, message_id: int, text: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_media_info(chat_id: int, message_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_media_info(chat_id: Union[int, str], message_id: int) -> str:
|
||||
"""
|
||||
Get info about media in a message.
|
||||
Args:
|
||||
chat_id: The chat ID.
|
||||
chat_id: The chat ID or username.
|
||||
message_id: The message ID.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -2058,7 +2179,8 @@ async def search_public_chats(query: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_messages(chat_id: int, query: str, limit: int = 20) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def search_messages(chat_id: Union[int, str], query: str, limit: int = 20) -> str:
|
||||
"""
|
||||
Search for messages in a chat by text.
|
||||
"""
|
||||
|
|
@ -2094,7 +2216,8 @@ async def resolve_username(username: str) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def mute_chat(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def mute_chat(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Mute notifications for a chat.
|
||||
"""
|
||||
|
|
@ -2132,7 +2255,8 @@ async def mute_chat(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def unmute_chat(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def unmute_chat(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Unmute notifications for a chat.
|
||||
"""
|
||||
|
|
@ -2170,7 +2294,8 @@ async def unmute_chat(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def archive_chat(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def archive_chat(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Archive a chat.
|
||||
"""
|
||||
|
|
@ -2186,7 +2311,8 @@ async def archive_chat(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def unarchive_chat(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def unarchive_chat(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Unarchive a chat.
|
||||
"""
|
||||
|
|
@ -2214,11 +2340,12 @@ async def get_sticker_sets() -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def send_sticker(chat_id: int, file_path: str) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def send_sticker(chat_id: Union[int, str], file_path: str) -> str:
|
||||
"""
|
||||
Send a sticker to a chat. File must be a valid .webp sticker file.
|
||||
Args:
|
||||
chat_id: The chat ID.
|
||||
chat_id: The chat ID or username.
|
||||
file_path: Absolute path to the .webp sticker file.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -2291,11 +2418,12 @@ async def get_gif_search(query: str, limit: int = 10) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def send_gif(chat_id: int, gif_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def send_gif(chat_id: Union[int, str], gif_id: int) -> str:
|
||||
"""
|
||||
Send a GIF to a chat by Telegram GIF document ID (not a file path).
|
||||
Args:
|
||||
chat_id: The chat ID.
|
||||
chat_id: The chat ID or username.
|
||||
gif_id: Telegram document ID for the GIF (from get_gif_search).
|
||||
"""
|
||||
try:
|
||||
|
|
@ -2393,7 +2521,8 @@ async def set_bot_commands(bot_username: str, commands: list) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_history(chat_id: int, limit: int = 100) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_history(chat_id: Union[int, str], limit: int = 100) -> str:
|
||||
"""
|
||||
Get full chat history (up to limit).
|
||||
"""
|
||||
|
|
@ -2415,7 +2544,8 @@ async def get_history(chat_id: int, limit: int = 100) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_user_photos(user_id: int, limit: int = 10) -> str:
|
||||
@validate_id('user_id')
|
||||
async def get_user_photos(user_id: Union[int, str], limit: int = 10) -> str:
|
||||
"""
|
||||
Get profile photos of a user.
|
||||
"""
|
||||
|
|
@ -2430,7 +2560,8 @@ async def get_user_photos(user_id: int, limit: int = 10) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_user_status(user_id: int) -> str:
|
||||
@validate_id('user_id')
|
||||
async def get_user_status(user_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Get the online status of a user.
|
||||
"""
|
||||
|
|
@ -2442,7 +2573,8 @@ async def get_user_status(user_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_recent_actions(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_recent_actions(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Get recent admin actions (admin log) in a group or channel.
|
||||
"""
|
||||
|
|
@ -2464,7 +2596,8 @@ async def get_recent_actions(chat_id: int) -> str:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_pinned_messages(chat_id: int) -> str:
|
||||
@validate_id('chat_id')
|
||||
async def get_pinned_messages(chat_id: Union[int, str]) -> str:
|
||||
"""
|
||||
Get all pinned messages in a chat.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -12,6 +12,11 @@ Usage:
|
|||
Requirements:
|
||||
- telethon
|
||||
- python-dotenv
|
||||
|
||||
Note on ID Formats:
|
||||
When using the MCP server, please be aware that all `chat_id` and `user_id`
|
||||
parameters support integer IDs, string representations of IDs (e.g., "123456"),
|
||||
and usernames (e.g., "@mychannel").
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
|
|||
83
test_validation.py
Normal file
83
test_validation.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import pytest
|
||||
import os
|
||||
os.environ["TELEGRAM_API_ID"] = "12345"
|
||||
os.environ["TELEGRAM_API_HASH"] = "dummy_hash"
|
||||
from main import validate_id, ValidationError, log_and_format_error
|
||||
from functools import wraps
|
||||
import asyncio
|
||||
from typing import Union, List, Optional
|
||||
|
||||
# A simple async function to be decorated for testing
|
||||
@validate_id('user_id', 'chat_id', 'user_ids')
|
||||
async def dummy_function(**kwargs):
|
||||
return "success", kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_integer_id():
|
||||
result, kwargs = await dummy_function(user_id=12345)
|
||||
assert result == "success"
|
||||
assert kwargs['user_id'] == 12345
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_negative_integer_id():
|
||||
result, kwargs = await dummy_function(chat_id=-100123456)
|
||||
assert result == "success"
|
||||
assert kwargs['chat_id'] == -100123456
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_string_integer_id():
|
||||
result, kwargs = await dummy_function(user_id="12345")
|
||||
assert result == "success"
|
||||
assert kwargs['user_id'] == 12345
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_username():
|
||||
result, kwargs = await dummy_function(user_id="@test_user")
|
||||
assert result == "success"
|
||||
assert kwargs['user_id'] == "@test_user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_username_without_at():
|
||||
result, kwargs = await dummy_function(user_id="test_user_long_enough")
|
||||
assert result == "success"
|
||||
assert kwargs['user_id'] == "test_user_long_enough"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_list_of_ids():
|
||||
result, kwargs = await dummy_function(user_ids=[123, "456", "@test_user"])
|
||||
assert result == "success"
|
||||
assert kwargs['user_ids'] == [123, 456, "@test_user"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_float_id():
|
||||
result = await dummy_function(user_id=123.45)
|
||||
assert "Invalid user_id" in result
|
||||
assert "Type must be an integer or a string" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_string_id():
|
||||
result = await dummy_function(user_id="inv") # too short
|
||||
assert "Invalid user_id" in result
|
||||
assert "Must be a valid integer ID, or a username string" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_integer_out_of_range():
|
||||
result = await dummy_function(user_id=2**64)
|
||||
assert "Invalid user_id" in result
|
||||
assert "out of the valid integer range" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_item_in_list():
|
||||
result = await dummy_function(user_ids=[123, "456", 123.45])
|
||||
assert "Invalid user_ids" in result
|
||||
assert "Type must be an integer or a string" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_id_provided():
|
||||
result, kwargs = await dummy_function()
|
||||
assert result == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_id_provided():
|
||||
result, kwargs = await dummy_function(user_id=None)
|
||||
assert result == "success"
|
||||
Loading…
Reference in a new issue