diff --git a/.python-version b/.python-version index 24ee5b1..7eebfaf 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.13 +3.12.11 diff --git a/README.md b/README.md index 125f847..efca7da 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,16 @@ This MCP server exposes a huge suite of Telegram tools. **Every major Telegram/T - **unarchive_chat(chat_id)**: Unarchive a chat - **get_recent_actions(chat_id)**: Get recent admin actions +### Input Validation + +To improve robustness, all functions accepting `chat_id` or `user_id` parameters now include input validation. You can use any of the following formats for these IDs: + +- **Integer ID**: The direct integer ID for a user, chat, or channel (e.g., `123456789` or `-1001234567890`). +- **String ID**: The integer ID provided as a string (e.g., `"123456789"`). +- **Username**: The public username for a user or channel (e.g., `"@username"` or `"username"`). + +The server will automatically validate the input and convert it to the correct format before making a request to Telegram. If the input is invalid, a clear error message will be returned. + ## Removed Functionality Please note that tools requiring direct file path access on the server (`send_file`, `download_media`, `set_profile_photo`, `edit_chat_photo`, `send_voice`, `send_sticker`, `upload_file`) have been removed from `main.py`. This is due to limitations in the current MCP environment regarding handling file attachments and local file system paths. diff --git a/main.py b/main.py index 6fb590e..dc6c49c 100644 --- a/main.py +++ b/main.py @@ -30,9 +30,16 @@ from telethon.tl.types import ( InputPeerChat, InputPeerChannel, ) +import re +from functools import wraps import telethon.errors.rpcerrorlist +class ValidationError(Exception): + """Custom exception for validation errors.""" + pass + + def json_serializer(obj): """Helper function to convert non-serializable objects for JSON serialization.""" if isinstance(obj, datetime): @@ -109,7 +116,7 @@ ERROR_PREFIXES = { def log_and_format_error( - function_name: str, error: Exception, prefix: str = None, **kwargs + function_name: str, error: Exception, prefix: str = None, user_message: str = None, **kwargs ) -> str: """ Centralized error handling function that logs the error and returns a formatted user-friendly message. @@ -118,13 +125,16 @@ def log_and_format_error( function_name: Name of the function where error occurred error: The exception that was raised prefix: Error code prefix (e.g., "CHAT", "MSG") - if None, will be derived from function_name + user_message: A custom user-facing message to return. If None, a generic one is created. **kwargs: Additional context parameters to include in log Returns: A user-friendly error message with error code """ # Generate a consistent error code - if prefix is None: + if prefix == "VALIDATION-001": + error_code = prefix + elif prefix is None: # Try to derive prefix from function name for key, value in ERROR_PREFIXES.items(): if key in function_name.lower(): @@ -132,19 +142,93 @@ def log_and_format_error( break if prefix is None: prefix = "GEN" # Generic prefix if none matches + error_code = f"{prefix}-ERR-{abs(hash(function_name)) % 1000:03d}" + else: + error_code = f"{prefix}-ERR-{abs(hash(function_name)) % 1000:03d}" - error_code = f"{prefix}-ERR-{abs(hash(function_name)) % 1000:03d}" # Format the additional context parameters context = ", ".join(f"{k}={v}" for k, v in kwargs.items()) # Log the full technical error - logger.exception(f"{function_name} failed ({context}): {error}") + logger.error(f"Error in {function_name} ({context}) - Code: {error_code}", exc_info=True) # Return a user-friendly message + if user_message: + return user_message + return f"An error occurred (code: {error_code}). Check mcp_errors.log for details." +def validate_id(*param_names_to_validate): + """ + Decorator to validate chat_id and user_id parameters, including lists of IDs. + It checks for valid integer ranges, string representations of integers, + and username formats. + """ + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + for param_name in param_names_to_validate: + if param_name not in kwargs or kwargs[param_name] is None: + continue + + param_value = kwargs[param_name] + + def validate_single_id(value, p_name): + # Handle integer IDs + if isinstance(value, int): + if not (-2**63 <= value <= 2**63 - 1): + return None, f"Invalid {p_name}: {value}. ID is out of the valid integer range." + return value, None + + # Handle string IDs + if isinstance(value, str): + try: + int_value = int(value) + if not (-2**63 <= int_value <= 2**63 - 1): + return None, f"Invalid {p_name}: {value}. ID is out of the valid integer range." + return int_value, None + except ValueError: + if re.match(r'^@?[a-zA-Z0-9_]{5,}$', value): + return value, None + else: + return None, f"Invalid {p_name}: '{value}'. Must be a valid integer ID, or a username string." + + # Handle other invalid types + return None, f"Invalid {p_name}: {value}. Type must be an integer or a string." + + if isinstance(param_value, list): + validated_list = [] + for item in param_value: + validated_item, error_msg = validate_single_id(item, param_name) + if error_msg: + return log_and_format_error( + func.__name__, + ValidationError(error_msg), + prefix="VALIDATION-001", + user_message=error_msg, + **{param_name: param_value} + ) + validated_list.append(validated_item) + kwargs[param_name] = validated_list + else: + validated_value, error_msg = validate_single_id(param_value, param_name) + if error_msg: + return log_and_format_error( + func.__name__, + ValidationError(error_msg), + prefix="VALIDATION-001", + user_message=error_msg, + **{param_name: param_value} + ) + kwargs[param_name] = validated_value + + return await func(*args, **kwargs) + return wrapper + return decorator + + def format_entity(entity) -> Dict[str, Any]: """Helper function to format entity information consistently.""" result = {"id": entity.id} @@ -231,11 +315,12 @@ async def get_chats(page: int = 1, page_size: int = 20) -> str: @mcp.tool() -async def get_messages(chat_id: int, page: int = 1, page_size: int = 20) -> str: +@validate_id('chat_id') +async def get_messages(chat_id: Union[int, str], page: int = 1, page_size: int = 20) -> str: """ Get paginated messages from a specific chat. Args: - chat_id: The ID of the chat. + chat_id: The ID or username of the chat. page: Page number (1-indexed). page_size: Number of messages per page. """ @@ -262,11 +347,12 @@ async def get_messages(chat_id: int, page: int = 1, page_size: int = 20) -> str: @mcp.tool() -async def send_message(chat_id: int, message: str) -> str: +@validate_id('chat_id') +async def send_message(chat_id: Union[int, str], message: str) -> str: """ Send a message to a specific chat. Args: - chat_id: The ID of the chat. + chat_id: The ID or username of the chat. message: The message content to send. """ try: @@ -346,8 +432,9 @@ async def get_contact_ids() -> str: @mcp.tool() +@validate_id('chat_id') async def list_messages( - chat_id: int, + chat_id: Union[int, str], limit: int = 20, search_query: str = None, from_date: str = None, @@ -357,7 +444,7 @@ async def list_messages( Retrieve messages with optional filters. Args: - chat_id: The ID of the chat to get messages from. + chat_id: The ID or username of the chat to get messages from. limit: Maximum number of messages to retrieve. search_query: Filter messages containing this text. from_date: Filter messages starting from this date (format: YYYY-MM-DD). @@ -501,12 +588,13 @@ async def list_chats(chat_type: str = None, limit: int = 20) -> str: @mcp.tool() -async def get_chat(chat_id: int) -> str: +@validate_id('chat_id') +async def get_chat(chat_id: Union[int, str]) -> str: """ Get detailed information about a specific chat. Args: - chat_id: The ID of the chat. + chat_id: The ID or username of the chat. """ try: entity = await client.get_entity(chat_id) @@ -636,12 +724,13 @@ async def get_direct_chat_by_contact(contact_query: str) -> str: @mcp.tool() -async def get_contact_chats(contact_id: int) -> str: +@validate_id('contact_id') +async def get_contact_chats(contact_id: Union[int, str]) -> str: """ List all chats involving a specific contact. Args: - contact_id: The ID of the contact. + contact_id: The ID or username of the contact. """ try: # Get contact info @@ -688,12 +777,13 @@ async def get_contact_chats(contact_id: int) -> str: @mcp.tool() -async def get_last_interaction(contact_id: int) -> str: +@validate_id('contact_id') +async def get_last_interaction(contact_id: Union[int, str]) -> str: """ Get the most recent message with a contact. Args: - contact_id: The ID of the contact. + contact_id: The ID or username of the contact. """ try: # Get contact info @@ -724,12 +814,13 @@ async def get_last_interaction(contact_id: int) -> str: @mcp.tool() -async def get_message_context(chat_id: int, message_id: int, context_size: int = 3) -> str: +@validate_id('chat_id') +async def get_message_context(chat_id: Union[int, str], message_id: int, context_size: int = 3) -> str: """ Retrieve context around a specific message. Args: - chat_id: The ID of the chat. + chat_id: The ID or username of the chat. message_id: The ID of the central message. context_size: Number of messages before and after to include. """ @@ -841,11 +932,12 @@ async def add_contact(phone: str, first_name: str, last_name: str = "") -> str: @mcp.tool() -async def delete_contact(user_id: int) -> str: +@validate_id('user_id') +async def delete_contact(user_id: Union[int, str]) -> str: """ Delete a contact by user ID. Args: - user_id: The Telegram user ID of the contact to delete. + user_id: The Telegram user ID or username of the contact to delete. """ try: user = await client.get_entity(user_id) @@ -856,11 +948,12 @@ async def delete_contact(user_id: int) -> str: @mcp.tool() -async def block_user(user_id: int) -> str: +@validate_id('user_id') +async def block_user(user_id: Union[int, str]) -> str: """ Block a user by user ID. Args: - user_id: The Telegram user ID to block. + user_id: The Telegram user ID or username to block. """ try: user = await client.get_entity(user_id) @@ -871,11 +964,12 @@ async def block_user(user_id: int) -> str: @mcp.tool() -async def unblock_user(user_id: int) -> str: +@validate_id('user_id') +async def unblock_user(user_id: Union[int, str]) -> str: """ Unblock a user by user ID. Args: - user_id: The Telegram user ID to unblock. + user_id: The Telegram user ID or username to unblock. """ try: user = await client.get_entity(user_id) @@ -898,13 +992,14 @@ async def get_me() -> str: @mcp.tool() -async def create_group(title: str, user_ids: list) -> str: +@validate_id('user_ids') +async def create_group(title: str, user_ids: List[Union[int, str]]) -> str: """ Create a new group or supergroup and add users. Args: title: Title for the new group - user_ids: List of user IDs to add to the group + user_ids: List of user IDs or usernames to add to the group """ try: # Convert user IDs to entities @@ -956,13 +1051,14 @@ async def create_group(title: str, user_ids: list) -> str: @mcp.tool() -async def invite_to_group(group_id: int, user_ids: list) -> str: +@validate_id('group_id', 'user_ids') +async def invite_to_group(group_id: Union[int, str], user_ids: List[Union[int, str]]) -> str: """ Invite users to a group or channel. Args: - group_id: The ID of the group/channel. - user_ids: List of user IDs to invite. + group_id: The ID or username of the group/channel. + user_ids: List of user IDs or usernames to invite. """ try: entity = await client.get_entity(group_id) @@ -1005,12 +1101,13 @@ async def invite_to_group(group_id: int, user_ids: list) -> str: @mcp.tool() -async def leave_chat(chat_id: int) -> str: +@validate_id('chat_id') +async def leave_chat(chat_id: Union[int, str]) -> str: """ Leave a group or channel by chat ID. Args: - chat_id: The chat ID to leave. + chat_id: The chat ID or username to leave. """ try: entity = await client.get_entity(chat_id) @@ -1084,11 +1181,12 @@ async def leave_chat(chat_id: int) -> str: @mcp.tool() -async def get_participants(chat_id: int) -> str: +@validate_id('chat_id') +async def get_participants(chat_id: Union[int, str]) -> str: """ List all participants in a group or channel. Args: - chat_id: The group or channel ID. + chat_id: The group or channel ID or username. """ try: participants = await client.get_participants(chat_id) @@ -1102,11 +1200,12 @@ async def get_participants(chat_id: int) -> str: @mcp.tool() -async def send_file(chat_id: int, file_path: str, caption: str = None) -> str: +@validate_id('chat_id') +async def send_file(chat_id: Union[int, str], file_path: str, caption: str = None) -> str: """ Send a file to a chat. Args: - chat_id: The chat ID. + chat_id: The chat ID or username. file_path: Absolute path to the file to send (must exist and be readable). caption: Optional caption for the file. """ @@ -1125,11 +1224,12 @@ async def send_file(chat_id: int, file_path: str, caption: str = None) -> str: @mcp.tool() -async def download_media(chat_id: int, message_id: int, file_path: str) -> str: +@validate_id('chat_id') +async def download_media(chat_id: Union[int, str], message_id: int, file_path: str) -> str: """ Download media from a message in a chat. Args: - chat_id: The chat ID. + chat_id: The chat ID or username. message_id: The message ID containing the media. file_path: Absolute path to save the downloaded file (must be writable). """ @@ -1226,16 +1326,17 @@ async def get_privacy_settings() -> str: @mcp.tool() +@validate_id('allow_users', 'disallow_users') async def set_privacy_settings( - key: str, allow_users: list = None, disallow_users: list = None + key: str, allow_users: Optional[List[Union[int, str]]] = None, disallow_users: Optional[List[Union[int, str]]] = None ) -> str: """ Set privacy settings (e.g., last seen, phone, etc.). Args: key: The privacy setting to modify ('status' for last seen, 'phone', 'profile_photo', etc.) - allow_users: List of user IDs to allow - disallow_users: List of user IDs to disallow + allow_users: List of user IDs or usernames to allow + disallow_users: List of user IDs or usernames to disallow """ try: # Import needed types @@ -1382,7 +1483,8 @@ async def create_channel(title: str, about: str = "", megagroup: bool = False) - @mcp.tool() -async def edit_chat_title(chat_id: int, title: str) -> str: +@validate_id('chat_id') +async def edit_chat_title(chat_id: Union[int, str], title: str) -> str: """ Edit the title of a chat, group, or channel. """ @@ -1401,7 +1503,8 @@ async def edit_chat_title(chat_id: int, title: str) -> str: @mcp.tool() -async def edit_chat_photo(chat_id: int, file_path: str) -> str: +@validate_id('chat_id') +async def edit_chat_photo(chat_id: Union[int, str], file_path: str) -> str: """ Edit the photo of a chat, group, or channel. Requires a file path to an image. """ @@ -1434,7 +1537,8 @@ async def edit_chat_photo(chat_id: int, file_path: str) -> str: @mcp.tool() -async def delete_chat_photo(chat_id: int) -> str: +@validate_id('chat_id') +async def delete_chat_photo(chat_id: Union[int, str]) -> str: """ Delete the photo of a chat, group, or channel. """ @@ -1462,13 +1566,14 @@ async def delete_chat_photo(chat_id: int) -> str: @mcp.tool() -async def promote_admin(group_id: int, user_id: int, rights: dict = None) -> str: +@validate_id('group_id', 'user_id') +async def promote_admin(group_id: Union[int, str], user_id: Union[int, str], rights: dict = None) -> str: """ Promote a user to admin in a group/channel. Args: - group_id: ID of the group/channel - user_id: User ID to promote + group_id: ID or username of the group/channel + user_id: User ID or username to promote rights: Admin rights to give (optional) """ try: @@ -1526,13 +1631,14 @@ async def promote_admin(group_id: int, user_id: int, rights: dict = None) -> str @mcp.tool() -async def demote_admin(group_id: int, user_id: int) -> str: +@validate_id('group_id', 'user_id') +async def demote_admin(group_id: Union[int, str], user_id: Union[int, str]) -> str: """ Demote a user from admin in a group/channel. Args: - group_id: ID of the group/channel - user_id: User ID to demote + group_id: ID or username of the group/channel + user_id: User ID or username to demote """ try: chat = await client.get_entity(group_id) @@ -1574,13 +1680,14 @@ async def demote_admin(group_id: int, user_id: int) -> str: @mcp.tool() -async def ban_user(chat_id: int, user_id: int) -> str: +@validate_id('chat_id', 'user_id') +async def ban_user(chat_id: Union[int, str], user_id: Union[int, str]) -> str: """ Ban a user from a group or channel. Args: - chat_id: ID of the group/channel - user_id: User ID to ban + chat_id: ID or username of the group/channel + user_id: User ID or username to ban """ try: chat = await client.get_entity(chat_id) @@ -1620,13 +1727,14 @@ async def ban_user(chat_id: int, user_id: int) -> str: @mcp.tool() -async def unban_user(chat_id: int, user_id: int) -> str: +@validate_id('chat_id', 'user_id') +async def unban_user(chat_id: Union[int, str], user_id: Union[int, str]) -> str: """ Unban a user from a group or channel. Args: - chat_id: ID of the group/channel - user_id: User ID to unban + chat_id: ID or username of the group/channel + user_id: User ID or username to unban """ try: chat = await client.get_entity(chat_id) @@ -1666,7 +1774,8 @@ async def unban_user(chat_id: int, user_id: int) -> str: @mcp.tool() -async def get_admins(chat_id: int) -> str: +@validate_id('chat_id') +async def get_admins(chat_id: Union[int, str]) -> str: """ Get all admins in a group or channel. """ @@ -1684,7 +1793,8 @@ async def get_admins(chat_id: int) -> str: @mcp.tool() -async def get_banned_users(chat_id: int) -> str: +@validate_id('chat_id') +async def get_banned_users(chat_id: Union[int, str]) -> str: """ Get all banned users in a group or channel. """ @@ -1704,7 +1814,8 @@ async def get_banned_users(chat_id: int) -> str: @mcp.tool() -async def get_invite_link(chat_id: int) -> str: +@validate_id('chat_id') +async def get_invite_link(chat_id: Union[int, str]) -> str: """ Get the invite link for a group or channel. """ @@ -1807,7 +1918,8 @@ async def join_chat_by_link(link: str) -> str: @mcp.tool() -async def export_chat_invite(chat_id: int) -> str: +@validate_id('chat_id') +async def export_chat_invite(chat_id: Union[int, str]) -> str: """ Export a chat invite link. """ @@ -1896,11 +2008,12 @@ async def import_chat_invite(hash: str) -> str: @mcp.tool() -async def send_voice(chat_id: int, file_path: str) -> str: +@validate_id('chat_id') +async def send_voice(chat_id: Union[int, str], file_path: str) -> str: """ Send a voice message to a chat. File must be an OGG/OPUS voice note. Args: - chat_id: The chat ID. + chat_id: The chat ID or username. file_path: Absolute path to the OGG/OPUS file. """ try: @@ -1926,7 +2039,8 @@ async def send_voice(chat_id: int, file_path: str) -> str: @mcp.tool() -async def forward_message(from_chat_id: int, message_id: int, to_chat_id: int) -> str: +@validate_id('from_chat_id', 'to_chat_id') +async def forward_message(from_chat_id: Union[int, str], message_id: int, to_chat_id: Union[int, str]) -> str: """ Forward a message from one chat to another. """ @@ -1946,7 +2060,8 @@ async def forward_message(from_chat_id: int, message_id: int, to_chat_id: int) - @mcp.tool() -async def edit_message(chat_id: int, message_id: int, new_text: str) -> str: +@validate_id('chat_id') +async def edit_message(chat_id: Union[int, str], message_id: int, new_text: str) -> str: """ Edit a message you sent. """ @@ -1961,7 +2076,8 @@ async def edit_message(chat_id: int, message_id: int, new_text: str) -> str: @mcp.tool() -async def delete_message(chat_id: int, message_id: int) -> str: +@validate_id('chat_id') +async def delete_message(chat_id: Union[int, str], message_id: int) -> str: """ Delete a message by ID. """ @@ -1974,7 +2090,8 @@ async def delete_message(chat_id: int, message_id: int) -> str: @mcp.tool() -async def pin_message(chat_id: int, message_id: int) -> str: +@validate_id('chat_id') +async def pin_message(chat_id: Union[int, str], message_id: int) -> str: """ Pin a message in a chat. """ @@ -1987,7 +2104,8 @@ async def pin_message(chat_id: int, message_id: int) -> str: @mcp.tool() -async def unpin_message(chat_id: int, message_id: int) -> str: +@validate_id('chat_id') +async def unpin_message(chat_id: Union[int, str], message_id: int) -> str: """ Unpin a message in a chat. """ @@ -2000,7 +2118,8 @@ async def unpin_message(chat_id: int, message_id: int) -> str: @mcp.tool() -async def mark_as_read(chat_id: int) -> str: +@validate_id('chat_id') +async def mark_as_read(chat_id: Union[int, str]) -> str: """ Mark all messages as read in a chat. """ @@ -2013,7 +2132,8 @@ async def mark_as_read(chat_id: int) -> str: @mcp.tool() -async def reply_to_message(chat_id: int, message_id: int, text: str) -> str: +@validate_id('chat_id') +async def reply_to_message(chat_id: Union[int, str], message_id: int, text: str) -> str: """ Reply to a specific message in a chat. """ @@ -2028,11 +2148,12 @@ async def reply_to_message(chat_id: int, message_id: int, text: str) -> str: @mcp.tool() -async def get_media_info(chat_id: int, message_id: int) -> str: +@validate_id('chat_id') +async def get_media_info(chat_id: Union[int, str], message_id: int) -> str: """ Get info about media in a message. Args: - chat_id: The chat ID. + chat_id: The chat ID or username. message_id: The message ID. """ try: @@ -2058,7 +2179,8 @@ async def search_public_chats(query: str) -> str: @mcp.tool() -async def search_messages(chat_id: int, query: str, limit: int = 20) -> str: +@validate_id('chat_id') +async def search_messages(chat_id: Union[int, str], query: str, limit: int = 20) -> str: """ Search for messages in a chat by text. """ @@ -2094,7 +2216,8 @@ async def resolve_username(username: str) -> str: @mcp.tool() -async def mute_chat(chat_id: int) -> str: +@validate_id('chat_id') +async def mute_chat(chat_id: Union[int, str]) -> str: """ Mute notifications for a chat. """ @@ -2132,7 +2255,8 @@ async def mute_chat(chat_id: int) -> str: @mcp.tool() -async def unmute_chat(chat_id: int) -> str: +@validate_id('chat_id') +async def unmute_chat(chat_id: Union[int, str]) -> str: """ Unmute notifications for a chat. """ @@ -2170,7 +2294,8 @@ async def unmute_chat(chat_id: int) -> str: @mcp.tool() -async def archive_chat(chat_id: int) -> str: +@validate_id('chat_id') +async def archive_chat(chat_id: Union[int, str]) -> str: """ Archive a chat. """ @@ -2186,7 +2311,8 @@ async def archive_chat(chat_id: int) -> str: @mcp.tool() -async def unarchive_chat(chat_id: int) -> str: +@validate_id('chat_id') +async def unarchive_chat(chat_id: Union[int, str]) -> str: """ Unarchive a chat. """ @@ -2214,11 +2340,12 @@ async def get_sticker_sets() -> str: @mcp.tool() -async def send_sticker(chat_id: int, file_path: str) -> str: +@validate_id('chat_id') +async def send_sticker(chat_id: Union[int, str], file_path: str) -> str: """ Send a sticker to a chat. File must be a valid .webp sticker file. Args: - chat_id: The chat ID. + chat_id: The chat ID or username. file_path: Absolute path to the .webp sticker file. """ try: @@ -2291,11 +2418,12 @@ async def get_gif_search(query: str, limit: int = 10) -> str: @mcp.tool() -async def send_gif(chat_id: int, gif_id: int) -> str: +@validate_id('chat_id') +async def send_gif(chat_id: Union[int, str], gif_id: int) -> str: """ Send a GIF to a chat by Telegram GIF document ID (not a file path). Args: - chat_id: The chat ID. + chat_id: The chat ID or username. gif_id: Telegram document ID for the GIF (from get_gif_search). """ try: @@ -2393,7 +2521,8 @@ async def set_bot_commands(bot_username: str, commands: list) -> str: @mcp.tool() -async def get_history(chat_id: int, limit: int = 100) -> str: +@validate_id('chat_id') +async def get_history(chat_id: Union[int, str], limit: int = 100) -> str: """ Get full chat history (up to limit). """ @@ -2415,7 +2544,8 @@ async def get_history(chat_id: int, limit: int = 100) -> str: @mcp.tool() -async def get_user_photos(user_id: int, limit: int = 10) -> str: +@validate_id('user_id') +async def get_user_photos(user_id: Union[int, str], limit: int = 10) -> str: """ Get profile photos of a user. """ @@ -2430,7 +2560,8 @@ async def get_user_photos(user_id: int, limit: int = 10) -> str: @mcp.tool() -async def get_user_status(user_id: int) -> str: +@validate_id('user_id') +async def get_user_status(user_id: Union[int, str]) -> str: """ Get the online status of a user. """ @@ -2442,7 +2573,8 @@ async def get_user_status(user_id: int) -> str: @mcp.tool() -async def get_recent_actions(chat_id: int) -> str: +@validate_id('chat_id') +async def get_recent_actions(chat_id: Union[int, str]) -> str: """ Get recent admin actions (admin log) in a group or channel. """ @@ -2464,7 +2596,8 @@ async def get_recent_actions(chat_id: int) -> str: @mcp.tool() -async def get_pinned_messages(chat_id: int) -> str: +@validate_id('chat_id') +async def get_pinned_messages(chat_id: Union[int, str]) -> str: """ Get all pinned messages in a chat. """ diff --git a/session_string_generator.py b/session_string_generator.py index 5647ab5..327e570 100755 --- a/session_string_generator.py +++ b/session_string_generator.py @@ -12,6 +12,11 @@ Usage: Requirements: - telethon - python-dotenv + +Note on ID Formats: +When using the MCP server, please be aware that all `chat_id` and `user_id` +parameters support integer IDs, string representations of IDs (e.g., "123456"), +and usernames (e.g., "@mychannel"). """ import os diff --git a/test_validation.py b/test_validation.py new file mode 100644 index 0000000..6e7354b --- /dev/null +++ b/test_validation.py @@ -0,0 +1,83 @@ +import pytest +import os +os.environ["TELEGRAM_API_ID"] = "12345" +os.environ["TELEGRAM_API_HASH"] = "dummy_hash" +from main import validate_id, ValidationError, log_and_format_error +from functools import wraps +import asyncio +from typing import Union, List, Optional + +# A simple async function to be decorated for testing +@validate_id('user_id', 'chat_id', 'user_ids') +async def dummy_function(**kwargs): + return "success", kwargs + +@pytest.mark.asyncio +async def test_valid_integer_id(): + result, kwargs = await dummy_function(user_id=12345) + assert result == "success" + assert kwargs['user_id'] == 12345 + +@pytest.mark.asyncio +async def test_valid_negative_integer_id(): + result, kwargs = await dummy_function(chat_id=-100123456) + assert result == "success" + assert kwargs['chat_id'] == -100123456 + +@pytest.mark.asyncio +async def test_valid_string_integer_id(): + result, kwargs = await dummy_function(user_id="12345") + assert result == "success" + assert kwargs['user_id'] == 12345 + +@pytest.mark.asyncio +async def test_valid_username(): + result, kwargs = await dummy_function(user_id="@test_user") + assert result == "success" + assert kwargs['user_id'] == "@test_user" + +@pytest.mark.asyncio +async def test_valid_username_without_at(): + result, kwargs = await dummy_function(user_id="test_user_long_enough") + assert result == "success" + assert kwargs['user_id'] == "test_user_long_enough" + +@pytest.mark.asyncio +async def test_valid_list_of_ids(): + result, kwargs = await dummy_function(user_ids=[123, "456", "@test_user"]) + assert result == "success" + assert kwargs['user_ids'] == [123, 456, "@test_user"] + +@pytest.mark.asyncio +async def test_invalid_float_id(): + result = await dummy_function(user_id=123.45) + assert "Invalid user_id" in result + assert "Type must be an integer or a string" in result + +@pytest.mark.asyncio +async def test_invalid_string_id(): + result = await dummy_function(user_id="inv") # too short + assert "Invalid user_id" in result + assert "Must be a valid integer ID, or a username string" in result + +@pytest.mark.asyncio +async def test_integer_out_of_range(): + result = await dummy_function(user_id=2**64) + assert "Invalid user_id" in result + assert "out of the valid integer range" in result + +@pytest.mark.asyncio +async def test_invalid_item_in_list(): + result = await dummy_function(user_ids=[123, "456", 123.45]) + assert "Invalid user_ids" in result + assert "Type must be an integer or a string" in result + +@pytest.mark.asyncio +async def test_no_id_provided(): + result, kwargs = await dummy_function() + assert result == "success" + +@pytest.mark.asyncio +async def test_none_id_provided(): + result, kwargs = await dummy_function(user_id=None) + assert result == "success"