diff --git a/Thunder/__main__.py b/Thunder/__main__.py index ff24011..f5fa296 100644 --- a/Thunder/__main__.py +++ b/Thunder/__main__.py @@ -21,6 +21,7 @@ from Thunder.utils.commands import set_commands from Thunder.utils.database import db from Thunder.utils.keepalive import ping_server +from Thunder.utils.canonical_files import drain_background_touch_tasks from Thunder.utils.logger import logger from Thunder.utils.messages import MSG_ADMIN_RESTART_DONE from Thunder.utils.rate_limiter import rate_limiter, request_executor @@ -49,6 +50,25 @@ def print_banner(): print(banner) +def schedule_index_ensure() -> None: + task = asyncio.create_task( + db.ensure_indexes(raise_on_error=False), + name="ensure_database_indexes" + ) + + def _log_index_failure(done_task: asyncio.Task) -> None: + try: + created_indexes = done_task.result() + if created_indexes: + print(" ✓ Database indexes ensured.") + else: + print(" ▶ Database indexes could not be ensured during startup.") + except Exception as e: + logger.error(f"Background database index ensure failed: {e}", exc_info=True) + + task.add_done_callback(_log_index_failure) + + async def import_plugins(): print("╠════════════════════ IMPORTING PLUGINS ════════════════════╣") plugins = glob.glob(PLUGIN_PATH) @@ -119,6 +139,7 @@ async def start_services(): await set_commands() print(" ✓ Bot commands set successfully.") + schedule_index_ensure() restart_message_data = await db.get_restart_message() if restart_message_data: @@ -212,6 +233,10 @@ async def start_services(): await rate_limiter.shutdown() except Exception: pass + try: + await drain_background_touch_tasks() + except Exception as e: + logger.error(f"Error during canonical touch task cleanup: {e}", exc_info=True) return elapsed_time = (datetime.now() - start_time).total_seconds() @@ -252,6 +277,11 @@ async def start_services(): except Exception as e: logger.error(f"Error during client cleanup: {e}") + try: + await drain_background_touch_tasks() + except Exception as e: + 
logger.error(f"Error during canonical touch task cleanup: {e}", exc_info=True) + if 'app_runner' in locals() and app_runner is not None: try: await app_runner.cleanup() diff --git a/Thunder/bot/plugins/stream.py b/Thunder/bot/plugins/stream.py index 6be9e94..2b6fe26 100644 --- a/Thunder/bot/plugins/stream.py +++ b/Thunder/bot/plugins/stream.py @@ -6,15 +6,16 @@ from pyrogram import Client, enums, filters from pyrogram.errors import FloodWait, MessageNotModified, MessageDeleteForbidden, MessageIdInvalid -from pyrogram.types import (InlineKeyboardButton, InlineKeyboardMarkup, - Message) - -from Thunder.bot import StreamBot -from Thunder.utils.bot_utils import (gen_links, is_admin, log_newusr, notify_own, - reply_user_err) -from Thunder.utils.database import db -from Thunder.utils.decorators import (check_banned, get_shortener_status, - require_token) +from pyrogram.types import (InlineKeyboardButton, InlineKeyboardMarkup, + Message) + +from Thunder.bot import StreamBot +from Thunder.utils.bot_utils import (gen_canonical_links, gen_links, is_admin, + log_newusr, notify_own, reply_user_err) +from Thunder.utils.canonical_files import get_or_create_canonical_file +from Thunder.utils.database import db +from Thunder.utils.decorators import (check_banned, get_shortener_status, + require_token) from Thunder.utils.force_channel import force_channel_check from Thunder.utils.logger import logger from Thunder.utils.messages import ( @@ -71,30 +72,55 @@ async def validate_request_common(client: Client, message: Message) -> Optional[ return await get_shortener_status(client, message) -async def send_channel_links(target_msg: Message, links: Dict[str, Any], source_info: str, source_id: int): - try: - await target_msg.reply_text( - MSG_NEW_FILE_REQUEST.format( - source_info=source_info, - id_=source_id, - online_link=links['online_link'], - stream_link=links['stream_link'] - ), - disable_web_page_preview=True, - quote=True - ) - except FloodWait as e: - await asyncio.sleep(e.value) 
- await target_msg.reply_text( - MSG_NEW_FILE_REQUEST.format( - source_info=source_info, - id_=source_id, - online_link=links['online_link'], - stream_link=links['stream_link'] - ), - disable_web_page_preview=True, - quote=True - ) +async def send_channel_links( + links: Dict[str, Any], + source_info: str, + source_id: int, + *, + target_msg: Optional[Message] = None, + reply_to_message_id: Optional[int] = None +): + try: + text = MSG_NEW_FILE_REQUEST.format( + source_info=source_info, + id_=source_id, + online_link=links['online_link'], + stream_link=links['stream_link'] + ) + if target_msg: + await target_msg.reply_text( + text, + disable_web_page_preview=True, + quote=True + ) + else: + await StreamBot.send_message( + chat_id=Var.BIN_CHANNEL, + text=text, + disable_web_page_preview=True, + reply_to_message_id=reply_to_message_id + ) + except FloodWait as e: + await asyncio.sleep(e.value) + text = MSG_NEW_FILE_REQUEST.format( + source_info=source_info, + id_=source_id, + online_link=links['online_link'], + stream_link=links['stream_link'] + ) + if target_msg: + await target_msg.reply_text( + text, + disable_web_page_preview=True, + quote=True + ) + else: + await StreamBot.send_message( + chat_id=Var.BIN_CHANNEL, + text=text, + disable_web_page_preview=True, + reply_to_message_id=reply_to_message_id + ) async def safe_edit_message(message: Message, text: str, **kwargs): @@ -313,18 +339,36 @@ async def _actual_channel_receive_handler(client: Client, message: Message, **ha f"({message.chat.title or 'Unknown'}). Ignoring message.") return - try: - stored_msg = await fwd_media(message) - if not stored_msg: - logger.error( - f"Failed to forward media from channel {message.chat.id}. 
Ignoring.") - return - shortener_val = await get_shortener_status(client, message) - links = await gen_links(stored_msg, shortener=shortener_val) - source_info = message.chat.title or "Unknown Channel" - - if notification_msg: - try: + try: + shortener_val = await get_shortener_status(client, message) + canonical_record, stored_msg, reused_existing = await get_or_create_canonical_file(message, fwd_media) + if reused_existing and stored_msg: + await safe_delete_message(stored_msg) + stored_msg = None + if canonical_record: + links = await gen_canonical_links( + file_name=canonical_record["file_name"], + file_size=int(canonical_record.get("file_size", 0) or 0), + public_hash=canonical_record["public_hash"], + shortener=shortener_val + ) + reply_to_message_id = int(canonical_record["canonical_message_id"]) + else: + if not stored_msg: + stored_msg = await fwd_media(message) + if not stored_msg: + logger.error( + f"Failed to forward media from channel {message.chat.id}. Ignoring.") + return + links = await gen_links(stored_msg, shortener=shortener_val) + reply_to_message_id = stored_msg.id + source_info = message.chat.title or "Unknown Channel" + # When we reused an existing canonical BIN copy, stored_msg is intentionally + # None so send_channel_links falls back to StreamBot.send_message(..., + # reply_to_message_id=...) and keeps the log threaded to the canonical message. 
+ + if notification_msg: + try: try: await notification_msg.edit_text( MSG_NEW_FILE_REQUEST.format( @@ -346,11 +390,23 @@ async def _actual_channel_receive_handler(client: Client, message: Message, **ha ), disable_web_page_preview=True ) - except Exception as e: - logger.error(f"Error editing notification message with links: {e}", exc_info=True) - await send_channel_links(stored_msg, links, source_info, message.chat.id) - else: - await send_channel_links(stored_msg, links, source_info, message.chat.id) + except Exception as e: + logger.error(f"Error editing notification message with links: {e}", exc_info=True) + await send_channel_links( + links, + source_info, + message.chat.id, + target_msg=stored_msg, + reply_to_message_id=reply_to_message_id + ) + else: + await send_channel_links( + links, + source_info, + message.chat.id, + target_msg=stored_msg, + reply_to_message_id=reply_to_message_id + ) try: try: @@ -388,17 +444,32 @@ async def process_single( status_msg: Message, shortener_val: bool, original_request_msg: Optional[Message] = None, - notification_msg: Optional[Message] = None -): - try: - stored_msg = await fwd_media(file_msg) - if not stored_msg: - logger.error(f"Failed to forward media for message {file_msg.id}. 
Skipping.") - return None - links = await gen_links(stored_msg, shortener=shortener_val) - if notification_msg: - await safe_edit_message( - notification_msg, + notification_msg: Optional[Message] = None +): + try: + canonical_record, stored_msg, reused_existing = await get_or_create_canonical_file(file_msg, fwd_media) + if reused_existing and stored_msg: + await safe_delete_message(stored_msg) + stored_msg = None + if canonical_record: + links = await gen_canonical_links( + file_name=canonical_record["file_name"], + file_size=int(canonical_record.get("file_size", 0) or 0), + public_hash=canonical_record["public_hash"], + shortener=shortener_val + ) + canonical_reply_id = int(canonical_record["canonical_message_id"]) + else: + if not stored_msg: + stored_msg = await fwd_media(file_msg) + if not stored_msg: + logger.error(f"Failed to forward media for message {file_msg.id}. Skipping.") + return None + links = await gen_links(stored_msg, shortener=shortener_val) + canonical_reply_id = stored_msg.id + if notification_msg: + await safe_edit_message( + notification_msg, MSG_LINKS.format( file_name=links['media_name'], file_size=links['media_size'], @@ -421,35 +492,29 @@ async def process_single( if not source_info: source_info = f"@{source_msg.from_user.username}" if source_msg.from_user.username else "Unknown User" source_id = source_msg.from_user.id - elif source_msg.chat.type == enums.ChatType.CHANNEL: - source_info = source_msg.chat.title or "Unknown Channel" - source_id = source_msg.chat.id - if source_info and source_id: - try: - await stored_msg.reply_text( - MSG_NEW_FILE_REQUEST.format( - source_info=source_info, - id_=source_id, - online_link=links['online_link'], - stream_link=links['stream_link'] - ), - disable_web_page_preview=True, - quote=True - ) - except FloodWait as e: - await asyncio.sleep(e.value) - await stored_msg.reply_text( - MSG_NEW_FILE_REQUEST.format( - source_info=source_info, - id_=source_id, - online_link=links['online_link'], - 
stream_link=links['stream_link'] - ), - disable_web_page_preview=True, - quote=True - ) - if status_msg: - await safe_delete_message(status_msg) + elif source_msg.chat.type == enums.ChatType.CHANNEL: + source_info = source_msg.chat.title or "Unknown Channel" + source_id = source_msg.chat.id + if source_info and source_id: + try: + await send_channel_links( + links, + source_info, + source_id, + target_msg=stored_msg, + reply_to_message_id=canonical_reply_id + ) + except FloodWait as e: + await asyncio.sleep(e.value) + await send_channel_links( + links, + source_info, + source_id, + target_msg=stored_msg, + reply_to_message_id=canonical_reply_id + ) + if status_msg: + await safe_delete_message(status_msg) return links except Exception as e: logger.error(f"Error processing single file for message {file_msg.id}: {e}", exc_info=True) diff --git a/Thunder/server/stream_routes.py b/Thunder/server/stream_routes.py index 4cc5a92..3e5fb7f 100644 --- a/Thunder/server/stream_routes.py +++ b/Thunder/server/stream_routes.py @@ -1,19 +1,27 @@ # Thunder/server/stream_routes.py -import re -import secrets -import time -from urllib.parse import quote, unquote - -from aiohttp import web - -from Thunder import __version__, StartTime -from Thunder.bot import StreamBot, multi_clients, work_loads -from Thunder.server.exceptions import FileNotFound, InvalidHash -from Thunder.utils.custom_dl import ByteStreamer -from Thunder.utils.logger import logger -from Thunder.utils.render_template import render_page -from Thunder.utils.time_format import get_readable_time +import re +import secrets +import time +from urllib.parse import quote, unquote + +from aiohttp import web + +from Thunder import __version__, StartTime +from Thunder.bot import StreamBot, multi_clients, work_loads +from Thunder.server.exceptions import FileNotFound, InvalidHash +from Thunder.utils.bot_utils import quote_media_name +from Thunder.utils.canonical_files import ( + PUBLIC_HASH_LENGTH, + get_file_by_hash, + 
update_cached_file_id, +) +from Thunder.utils.custom_dl import ByteStreamer +from Thunder.utils.file_properties import get_media +from Thunder.utils.logger import logger +from Thunder.utils.render_template import render_media_page, render_page +from Thunder.utils.time_format import get_readable_time +from Thunder.vars import Var routes = web.RouteTableDef() @@ -24,8 +32,9 @@ PATTERN_HASH_FIRST = re.compile( rf"^([a-zA-Z0-9_-]{{{SECURE_HASH_LENGTH}}})(\d+)(?:/.*)?$") PATTERN_ID_FIRST = re.compile(r"^(\d+)(?:/.*)?$") -VALID_HASH_REGEX = re.compile(r'^[a-zA-Z0-9_-]+$') -VALID_DISPOSITIONS = {"inline", "attachment"} +VALID_HASH_REGEX = re.compile(r'^[a-zA-Z0-9_-]+$') +VALID_PUBLIC_HASH_REGEX = re.compile(rf'^[0-9a-f]{{{PUBLIC_HASH_LENGTH}}}$') +VALID_DISPOSITIONS = {"inline", "attachment"} CORS_HEADERS = { "Access-Control-Allow-Origin": "*", @@ -43,7 +52,7 @@ def get_streamer(client_id: int) -> ByteStreamer: return streamers[client_id] -def parse_media_request(path: str, query: dict) -> tuple[int, str]: +def parse_media_request(path: str, query: dict) -> tuple[int, str]: clean_path = unquote(path).strip('/') match = PATTERN_HASH_FIRST.match(clean_path) @@ -70,7 +79,14 @@ def parse_media_request(path: str, query: dict) -> tuple[int, str]: except ValueError as e: raise InvalidHash(f"Invalid message ID format in path: {e}") from e - raise InvalidHash("Invalid URL structure or missing hash") + raise InvalidHash("Invalid URL structure or missing hash") + + +def validate_public_hash(public_hash: str) -> str: + secure_hash = public_hash.strip().lower() + if len(secure_hash) != PUBLIC_HASH_LENGTH or not VALID_PUBLIC_HASH_REGEX.match(secure_hash): + raise InvalidHash("Invalid canonical file hash") + return secure_hash def select_optimal_client() -> tuple[int, ByteStreamer]: @@ -96,7 +112,7 @@ def get_content_disposition(request: web.Request) -> str: return disposition if disposition in VALID_DISPOSITIONS else "attachment" -def parse_range_header(range_header: str, file_size: 
int) -> tuple[int, int]: +def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]: if not range_header: return 0, file_size - 1 @@ -124,7 +140,114 @@ def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]: headers={"Content-Range": f"bytes */{file_size}"} ) - return start, end + return start, end + + +def _resolve_unique_id(file_info: dict) -> str: + unique_id = file_info.get("unique_id") or file_info.get("file_unique_id") + if not unique_id: + raise FileNotFound("File unique ID not found in info.") + return unique_id + + +def _resolve_filename(file_info: dict, mime_type: str) -> str: + filename = file_info.get("file_name") + if filename: + return filename + + ext = mime_type.split('/')[-1] if '/' in mime_type else 'bin' + ext_map = {'jpeg': 'jpg', 'mpeg': 'mp3', 'octet-stream': 'bin'} + ext = ext_map.get(ext, ext) + return f"file_{secrets.token_hex(4)}.{ext}" + + +async def _serve_media_response( + request: web.Request, + *, + file_info: dict, + streamer: ByteStreamer, + client_id: int, + media_ref: int | str, + fallback_message_id: int | None = None, + on_fallback_message=None +): + file_size = int(file_info.get('file_size', 0) or 0) + if file_size == 0: + raise FileNotFound("File size is reported as zero or unavailable.") + + range_header = request.headers.get("Range", "") + start, end = parse_range_header(range_header, file_size) + content_length = end - start + 1 + + if start == 0 and end == file_size - 1: + range_header = "" + + mime_type = file_info.get('mime_type') or 'application/octet-stream' + filename = _resolve_filename(file_info, mime_type) + disposition = get_content_disposition(request) + + headers = { + "Content-Type": mime_type, + "Content-Length": str(content_length), + "Content-Disposition": ( + f"{disposition}; filename*=UTF-8''{quote(filename, safe='')}"), + "Accept-Ranges": "bytes", + "Cache-Control": "public, max-age=31536000", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + 
"Access-Control-Allow-Headers": "Range, Content-Type, *", + "Access-Control-Expose-Headers": ( + "Content-Length, Content-Range, Content-Disposition"), + "X-Content-Type-Options": "nosniff" + } + + if range_header: + headers["Content-Range"] = f"bytes {start}-{end}/{file_size}" + + if request.method == 'HEAD': + work_loads[client_id] -= 1 + return web.Response( + status=206 if range_header else 200, + headers=headers + ) + + async def stream_generator(): + try: + bytes_sent = 0 + bytes_to_skip = start % CHUNK_SIZE + + async for chunk in streamer.stream_file( + media_ref, + offset=start, + limit=content_length, + fallback_message_id=fallback_message_id, + on_fallback_message=on_fallback_message + ): + if bytes_to_skip > 0: + if len(chunk) <= bytes_to_skip: + bytes_to_skip -= len(chunk) + continue + chunk = chunk[bytes_to_skip:] + bytes_to_skip = 0 + + remaining = content_length - bytes_sent + if len(chunk) > remaining: + chunk = chunk[:remaining] + + if chunk: + yield chunk + bytes_sent += len(chunk) + + if bytes_sent >= content_length: + break + finally: + work_loads[client_id] -= 1 + + return web.Response( + status=206 if range_header else 200, + body=stream_generator(), + headers=headers + ) @routes.get("/", allow_head=True) @@ -167,16 +290,49 @@ async def status_options(request: web.Request): }) -@routes.options(r"/{path:.+}") -async def media_options(request: web.Request): +@routes.options(r"/{path:.+}") +async def media_options(request: web.Request): return web.Response(headers={ **CORS_HEADERS, "Access-Control-Max-Age": "86400" - }) - - -@routes.get(r"/watch/{path:.+}", allow_head=True) -async def media_preview(request: web.Request): + }) + + +@routes.get(r"/watch/f/{secure_hash}/{name:.+}", allow_head=True) +async def canonical_media_preview(request: web.Request): + try: + secure_hash = validate_public_hash(request.match_info["secure_hash"]) + file_record = await get_file_by_hash(secure_hash, raise_on_error=True) + if not file_record: + raise 
FileNotFound("Canonical file not found") + + file_name = file_record.get("file_name") or f"file_{secure_hash}" + src = f"{Var.URL.rstrip('/')}/f/{secure_hash}/{quote_media_name(file_name)}" + rendered_page = await render_media_page(file_name, src, requested_action='stream') + + response = web.Response( + text=rendered_page, + content_type='text/html', + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "Range, Content-Type, *", + "X-Content-Type-Options": "nosniff", + } + ) + response.enable_compression() + return response + except (InvalidHash, FileNotFound) as e: + logger.debug(f"Canonical preview error: {type(e).__name__} - {e}", exc_info=True) + raise web.HTTPNotFound(text="Resource not found") from e + except Exception as e: + error_id = secrets.token_hex(6) + logger.error(f"Canonical preview error {error_id}: {e}", exc_info=True) + raise web.HTTPInternalServerError( + text=f"Server error occurred: {error_id}") from e + + +@routes.get(r"/watch/{path:.+}", allow_head=True) +async def media_preview(request: web.Request): try: path = request.match_info["path"] message_id, secure_hash = parse_media_request(path, request.query) @@ -204,113 +360,96 @@ async def media_preview(request: web.Request): except Exception as e: error_id = secrets.token_hex(6) logger.error(f"Preview error {error_id}: {e}", exc_info=True) - raise web.HTTPInternalServerError( - text=f"Server error occurred: {error_id}") from e - - -@routes.get(r"/{path:.+}", allow_head=True) -async def media_delivery(request: web.Request): + raise web.HTTPInternalServerError( + text=f"Server error occurred: {error_id}") from e + + +@routes.get(r"/f/{secure_hash}/{name:.+}", allow_head=True) +async def canonical_media_delivery(request: web.Request): + try: + secure_hash = validate_public_hash(request.match_info["secure_hash"]) + file_record = await get_file_by_hash(secure_hash, raise_on_error=True) + if not file_record: + raise FileNotFound("Canonical file not found") + + 
client_id, streamer = select_optimal_client() + work_loads[client_id] += 1 + + try: + _resolve_unique_id(file_record) + media_ref = int(file_record["canonical_message_id"]) + if client_id == 0 and file_record.get("file_id"): + media_ref = file_record["file_id"] + fallback_message_id = int(file_record["canonical_message_id"]) + + async def persist_refreshed_file_id(message): + if client_id != 0: + return + media = get_media(message) + new_file_id = getattr(media, "file_id", None) if media else None + if new_file_id and new_file_id != file_record.get("file_id"): + try: + await update_cached_file_id(file_record, new_file_id) + except Exception as e: + logger.warning( + f"Failed to refresh cached file_id for canonical file {secure_hash}: {e}", + exc_info=True + ) + + return await _serve_media_response( + request, + file_info=file_record, + streamer=streamer, + client_id=client_id, + media_ref=media_ref, + fallback_message_id=fallback_message_id, + on_fallback_message=persist_refreshed_file_id + ) + except (FileNotFound, InvalidHash): + work_loads[client_id] -= 1 + raise + except Exception as e: + work_loads[client_id] -= 1 + error_id = secrets.token_hex(6) + logger.error(f"Canonical stream error {error_id}: {e}", exc_info=True) + raise web.HTTPInternalServerError( + text=f"Server error during streaming: {error_id}") from e + except (InvalidHash, FileNotFound) as e: + logger.debug(f"Canonical client error: {type(e).__name__} - {e}", exc_info=True) + raise web.HTTPNotFound(text="Resource not found") from e + except Exception as e: + error_id = secrets.token_hex(6) + logger.error(f"Canonical server error {error_id}: {e}", exc_info=True) + raise web.HTTPInternalServerError( + text=f"An unexpected server error occurred: {error_id}") from e + + +@routes.get(r"/{path:.+}", allow_head=True) +async def media_delivery(request: web.Request): try: path = request.match_info["path"] message_id, secure_hash = parse_media_request(path, request.query) client_id, streamer = 
select_optimal_client() - work_loads[client_id] += 1 - - try: - file_info = await streamer.get_file_info(message_id) - if not file_info.get('unique_id'): - raise FileNotFound("File unique ID not found in info.") - - if (file_info['unique_id'][:SECURE_HASH_LENGTH] != - secure_hash): - raise InvalidHash( - "Provided hash does not match file's unique ID.") - - file_size = file_info.get('file_size', 0) - if file_size == 0: - raise FileNotFound( - "File size is reported as zero or unavailable.") - - range_header = request.headers.get("Range", "") - start, end = parse_range_header(range_header, file_size) - content_length = end - start + 1 - - if start == 0 and end == file_size - 1: - range_header = "" - - mime_type = ( - file_info.get('mime_type') or 'application/octet-stream') - - filename = file_info.get('file_name') - if not filename: - ext = mime_type.split('/')[-1] if '/' in mime_type else 'bin' - ext_map = {'jpeg': 'jpg', 'mpeg': 'mp3', 'octet-stream': 'bin'} - ext = ext_map.get(ext, ext) - filename = f"file_{secrets.token_hex(4)}.{ext}" - - disposition = get_content_disposition(request) - - headers = { - "Content-Type": mime_type, - "Content-Length": str(content_length), - "Content-Disposition": ( - f"{disposition}; filename*=UTF-8''{quote(filename)}"), - "Accept-Ranges": "bytes", - "Cache-Control": "public, max-age=31536000", - "Connection": "keep-alive", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Headers": "Range, Content-Type, *", - "Access-Control-Expose-Headers": ( - "Content-Length, Content-Range, Content-Disposition"), - "X-Content-Type-Options": "nosniff" - } - - if range_header: - headers["Content-Range"] = f"bytes {start}-{end}/{file_size}" - - if request.method == 'HEAD': - work_loads[client_id] -= 1 - return web.Response( - status=206 if range_header else 200, - headers=headers - ) - - async def stream_generator(): - try: - bytes_sent = 0 - bytes_to_skip = start % CHUNK_SIZE - - async for chunk in streamer.stream_file( - message_id, 
offset=start, limit=content_length): - if bytes_to_skip > 0: - if len(chunk) <= bytes_to_skip: - bytes_to_skip -= len(chunk) - continue - chunk = chunk[bytes_to_skip:] - bytes_to_skip = 0 - - remaining = content_length - bytes_sent - if len(chunk) > remaining: - chunk = chunk[:remaining] - - if chunk: - yield chunk - bytes_sent += len(chunk) - - if bytes_sent >= content_length: - break - finally: - work_loads[client_id] -= 1 - - return web.Response( - status=206 if range_header else 200, - body=stream_generator(), - headers=headers - ) - - except (FileNotFound, InvalidHash): + work_loads[client_id] += 1 + + try: + file_info = await streamer.get_file_info(message_id) + unique_id = _resolve_unique_id(file_info) + + if unique_id[:SECURE_HASH_LENGTH] != secure_hash: + raise InvalidHash( + "Provided hash does not match file's unique ID.") + return await _serve_media_response( + request, + file_info=file_info, + streamer=streamer, + client_id=client_id, + media_ref=message_id + ) + + except (FileNotFound, InvalidHash): work_loads[client_id] -= 1 raise except Exception as e: diff --git a/Thunder/utils/bot_utils.py b/Thunder/utils/bot_utils.py index ca96b6b..f14b3c6 100644 --- a/Thunder/utils/bot_utils.py +++ b/Thunder/utils/bot_utils.py @@ -1,8 +1,8 @@ # Thunder/utils/bot_utils.py -import asyncio -from typing import Any, Dict, Optional -from urllib.parse import quote +import asyncio +from typing import Any, Dict, Optional +from urllib.parse import quote from pyrogram import Client from pyrogram.enums import ChatMemberStatus @@ -16,12 +16,64 @@ from Thunder.utils.logger import logger from Thunder.utils.messages import (MSG_BUTTON_GET_HELP, MSG_DC_UNKNOWN, MSG_DC_USER_INFO, MSG_NEW_USER) -from Thunder.utils.shortener import shorten -from Thunder.vars import Var - - - -async def notify_ch(cli: Client, txt: str): +from Thunder.utils.shortener import shorten +from Thunder.vars import Var + + +def quote_media_name(file_name: str) -> str: + return 
quote(str(file_name).replace("/", "_"), safe="") + + +async def _build_links( + *, + download_path: str, + stream_path: str, + media_name: str, + media_size: str, + shortener: bool = True +) -> Dict[str, str]: + base_url = Var.URL.rstrip("/") + slink = f"{base_url}{stream_path}" + olink = f"{base_url}{download_path}" + + if shortener and getattr(Var, "SHORTEN_MEDIA_LINKS", False): + try: + s_results = await asyncio.gather(shorten(slink), shorten(olink), return_exceptions=True) + if not isinstance(s_results[0], Exception): + slink = s_results[0] + else: + logger.warning(f"Failed to shorten stream_link: {s_results[0]}") + if not isinstance(s_results[1], Exception): + olink = s_results[1] + else: + logger.warning(f"Failed to shorten online_link: {s_results[1]}") + except Exception as e: + logger.error(f"Error during link shortening: {e}") + + return {"stream_link": slink, "online_link": olink, "media_name": media_name, "media_size": media_size} + + +async def gen_canonical_links( + *, + file_name: str, + file_size: int, + public_hash: str, + shortener: bool = True +) -> Dict[str, str]: + media_name = str(file_name) + media_size = humanbytes(file_size) + encoded_name = quote_media_name(media_name) + return await _build_links( + download_path=f"/f/{public_hash}/{encoded_name}", + stream_path=f"/watch/f/{public_hash}/{encoded_name}", + media_name=media_name, + media_size=media_size, + shortener=shortener + ) + + + +async def notify_ch(cli: Client, txt: str): if not (hasattr(Var, 'BIN_CHANNEL') and isinstance(Var.BIN_CHANNEL, int) and Var.BIN_CHANNEL != 0): return try: @@ -78,32 +130,20 @@ async def log_newusr(cli: Client, uid: int, fname: str): logger.error(f"Database error in log_newusr for user {uid}: {e}") -async def gen_links(fwd_msg: Message, shortener: bool = True) -> Dict[str, str]: - base_url = Var.URL.rstrip("/") - fid = fwd_msg.id - m_name_raw = get_fname(fwd_msg) - m_name = m_name_raw.decode('utf-8', errors='replace') if isinstance(m_name_raw, bytes) else 
str(m_name_raw) - m_size_hr = humanbytes(get_fsize(fwd_msg)) - enc_fname = quote(m_name) - f_hash = get_hash(fwd_msg) - slink = f"{base_url}/watch/{f_hash}{fid}/{enc_fname}" - olink = f"{base_url}/{f_hash}{fid}/{enc_fname}" - - if shortener and getattr(Var, "SHORTEN_MEDIA_LINKS", False): - try: - s_results = await asyncio.gather(shorten(slink), shorten(olink), return_exceptions=True) - if not isinstance(s_results[0], Exception): - slink = s_results[0] - else: - logger.warning(f"Failed to shorten stream_link: {s_results[0]}") - if not isinstance(s_results[1], Exception): - olink = s_results[1] - else: - logger.warning(f"Failed to shorten online_link: {s_results[1]}") - except Exception as e: - logger.error(f"Error during link shortening: {e}") - - return {"stream_link": slink, "online_link": olink, "media_name": m_name, "media_size": m_size_hr} +async def gen_links(fwd_msg: Message, shortener: bool = True) -> Dict[str, str]: + fid = fwd_msg.id + m_name_raw = get_fname(fwd_msg) + m_name = m_name_raw.decode('utf-8', errors='replace') if isinstance(m_name_raw, bytes) else str(m_name_raw) + m_size_hr = humanbytes(get_fsize(fwd_msg)) + enc_fname = quote_media_name(m_name) + f_hash = get_hash(fwd_msg) + return await _build_links( + download_path=f"/{f_hash}{fid}/{enc_fname}", + stream_path=f"/watch/{f_hash}{fid}/{enc_fname}", + media_name=m_name, + media_size=m_size_hr, + shortener=shortener + ) async def gen_dc_txt(usr: User) -> str: diff --git a/Thunder/utils/canonical_files.py b/Thunder/utils/canonical_files.py new file mode 100644 index 0000000..84773bb --- /dev/null +++ b/Thunder/utils/canonical_files.py @@ -0,0 +1,412 @@ +import asyncio +import datetime +import hashlib +from collections import OrderedDict +from contextlib import asynccontextmanager +from typing import Any, Awaitable, Callable, Dict, Optional, Tuple + +from pyrogram.errors import FloodWait +from pyrogram.types import Message +from pymongo.errors import DuplicateKeyError + +from Thunder.bot import 
# Thunder/utils/canonical_files.py
"""Canonical file registry: deduplicates media stored in the BIN channel.

Each unique Telegram file (keyed by ``file_unique_id``) gets at most one
"canonical" copy in the BIN channel plus one MongoDB record; repeated
uploads of the same file reuse that record instead of re-copying media.
A short public hash derived from ``file_unique_id`` is what appears in
shared links.
"""

import asyncio
import datetime
import hashlib
from collections import OrderedDict
from contextlib import asynccontextmanager
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple

from pyrogram.errors import FloodWait
from pyrogram.types import Message
from pymongo.errors import DuplicateKeyError

from Thunder.bot import StreamBot
from Thunder.utils.database import db
from Thunder.utils.file_properties import get_fname, get_media, get_uniqid
from Thunder.utils.logger import logger
from Thunder.vars import Var

PUBLIC_HASH_LENGTH = 20
_CACHE_TTL_SECONDS = 600
_CACHE_MAX_ITEMS = 4096
_INGEST_CLAIM_TTL_SECONDS = 60
_INGEST_CLAIM_WAIT_SECONDS = 15
_INGEST_CLAIM_POLL_SECONDS = 0.5

# TTL-bounded LRU caches over the same record dicts, keyed three ways.
# Values are (monotonic-insert-time, record).
_cache_by_unique_id: "OrderedDict[str, Tuple[float, Dict[str, Any]]]" = OrderedDict()
_cache_by_hash: "OrderedDict[str, Tuple[float, Dict[str, Any]]]" = OrderedDict()
_cache_by_message_id: "OrderedDict[int, Tuple[float, Dict[str, Any]]]" = OrderedDict()

# Per-file in-process locks, reference-counted so entries are dropped as
# soon as the last waiter leaves (see file_ingest_lock).
_upload_locks: dict[str, asyncio.Lock] = {}
_upload_lock_counts: dict[str, int] = {}
_upload_locks_guard = asyncio.Lock()
_background_touch_tasks: set[asyncio.Task] = set()


def build_public_hash(file_unique_id: str) -> str:
    """Derive the short, link-visible hash for a Telegram file_unique_id."""
    return hashlib.sha256(file_unique_id.encode("utf-8")).hexdigest()[:PUBLIC_HASH_LENGTH]


def _infer_mime_type(media: Any) -> str:
    """Best-effort MIME type for a pyrogram media object."""
    mime_type = getattr(media, "mime_type", None)
    if mime_type:
        return mime_type

    # Media classes without an explicit mime_type (photo/voice/video note).
    mime_map = {
        "photo": "image/jpeg",
        "voice": "audio/ogg",
        "videonote": "video/mp4",
    }
    return mime_map.get(type(media).__name__.lower(), "application/octet-stream")


def build_file_record(
    stored_message: Message,
    *,
    source_chat_id: Optional[int] = None,
    source_message_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
    """Build the MongoDB record for a canonical BIN-channel message.

    Returns None when the message carries no media or no file_unique_id.
    """
    media = get_media(stored_message)
    file_unique_id = get_uniqid(stored_message)
    if not media or not file_unique_id:
        return None

    now = datetime.datetime.utcnow()
    return {
        "file_unique_id": file_unique_id,
        "public_hash": build_public_hash(file_unique_id),
        "canonical_message_id": stored_message.id,
        "file_id": getattr(media, "file_id", None),
        "file_name": get_fname(stored_message),
        "mime_type": _infer_mime_type(media),
        "file_size": getattr(media, "file_size", 0) or 0,
        "media_type": type(media).__name__.lower(),
        "first_source_chat_id": source_chat_id,
        "first_source_message_id": source_message_id,
        "created_at": now,
        "last_seen_at": now,
        "seen_count": 1,
        "reuse_count": 0
    }


def _prune_cache(cache: "OrderedDict[Any, Tuple[float, Dict[str, Any]]]") -> None:
    """Evict expired entries, then trim oldest entries past the size cap.

    Uses the running event loop's monotonic clock, so it must only be
    called from coroutine context (all callers here are).
    """
    now = asyncio.get_running_loop().time()
    expired_keys = [key for key, (ts, _) in cache.items() if now - ts > _CACHE_TTL_SECONDS]
    for key in expired_keys:
        cache.pop(key, None)
    while len(cache) > _CACHE_MAX_ITEMS:
        cache.popitem(last=False)


def _cache_get(
    cache: "OrderedDict[Any, Tuple[float, Dict[str, Any]]]",
    key: Any
) -> Optional[Dict[str, Any]]:
    """TTL-aware LRU lookup; refreshes recency on hit, evicts on expiry."""
    if key not in cache:
        return None
    ts, value = cache[key]
    now = asyncio.get_running_loop().time()
    if now - ts > _CACHE_TTL_SECONDS:
        cache.pop(key, None)
        return None
    cache.move_to_end(key)
    return value


def _remember(record: Dict[str, Any]) -> Dict[str, Any]:
    """Insert/refresh a record in all three caches; returns the record."""
    now = asyncio.get_running_loop().time()
    file_unique_id = record.get("file_unique_id")
    public_hash = record.get("public_hash")
    canonical_message_id = record.get("canonical_message_id")

    if file_unique_id:
        _cache_by_unique_id[file_unique_id] = (now, record)
        _cache_by_unique_id.move_to_end(file_unique_id)
        _prune_cache(_cache_by_unique_id)
    if public_hash:
        _cache_by_hash[public_hash] = (now, record)
        _cache_by_hash.move_to_end(public_hash)
        _prune_cache(_cache_by_hash)
    if canonical_message_id is not None:
        _cache_by_message_id[canonical_message_id] = (now, record)
        _cache_by_message_id.move_to_end(canonical_message_id)
        _prune_cache(_cache_by_message_id)
    return record


def _forget(record: Dict[str, Any]) -> None:
    """Remove a record from all three caches (used when it goes stale)."""
    file_unique_id = record.get("file_unique_id")
    public_hash = record.get("public_hash")
    canonical_message_id = record.get("canonical_message_id")

    if file_unique_id:
        _cache_by_unique_id.pop(file_unique_id, None)
    if public_hash:
        _cache_by_hash.pop(public_hash, None)
    if canonical_message_id is not None:
        _cache_by_message_id.pop(canonical_message_id, None)


async def get_file_by_unique_id(file_unique_id: str) -> Optional[Dict[str, Any]]:
    """Cached lookup of a canonical record by Telegram file_unique_id."""
    cached = _cache_get(_cache_by_unique_id, file_unique_id)
    if cached:
        return cached
    record = await db.get_file_by_unique_id(file_unique_id)
    return _remember(record) if record else None


async def get_file_by_hash(
    public_hash: str,
    *,
    raise_on_error: bool = False
) -> Optional[Dict[str, Any]]:
    """Cached lookup of a canonical record by its public link hash."""
    cached = _cache_get(_cache_by_hash, public_hash)
    if cached:
        return cached
    record = await db.get_file_by_hash(public_hash, raise_on_error=raise_on_error)
    return _remember(record) if record else None


async def get_file_by_message_id(canonical_message_id: int) -> Optional[Dict[str, Any]]:
    """Cached lookup of a canonical record by its BIN-channel message id."""
    cached = _cache_get(_cache_by_message_id, canonical_message_id)
    if cached:
        return cached
    record = await db.get_file_by_message_id(canonical_message_id)
    return _remember(record) if record else None


async def touch_file_record(record: Dict[str, Any], *, reused: bool = False) -> None:
    """Synchronously bump seen/reuse counters in cache and database."""
    if not record.get("public_hash"):
        return
    record["last_seen_at"] = datetime.datetime.utcnow()
    record["seen_count"] = int(record.get("seen_count", 0)) + 1
    if reused:
        record["reuse_count"] = int(record.get("reuse_count", 0)) + 1
    _remember(record)
    await db.touch_file_record(record["public_hash"], reused=reused, raise_on_error=True)


def schedule_touch_file_record(record: Dict[str, Any], *, reused: bool = False) -> None:
    """Bump counters in cache now and persist to the DB in the background.

    The background task is tracked in _background_touch_tasks so shutdown
    can drain it (see drain_background_touch_tasks); failures are logged
    by the done-callback rather than raised.
    """
    if not record.get("public_hash"):
        return

    record["last_seen_at"] = datetime.datetime.utcnow()
    record["seen_count"] = int(record.get("seen_count", 0)) + 1
    if reused:
        record["reuse_count"] = int(record.get("reuse_count", 0)) + 1
    _remember(record)

    task = asyncio.create_task(
        db.touch_file_record(record["public_hash"], reused=reused),
        name=f"touch_file_record:{record['public_hash']}"
    )
    _background_touch_tasks.add(task)

    def _log_touch_failure(done_task: asyncio.Task) -> None:
        _background_touch_tasks.discard(done_task)
        try:
            touched = done_task.result()
            if not touched:
                logger.error(
                    f"Background touch did not update canonical file {record['public_hash']}"
                )
        except Exception as e:
            logger.error(
                f"Background touch failed for canonical file {record['public_hash']}: {e}",
                exc_info=True
            )

    task.add_done_callback(_log_touch_failure)


async def drain_background_touch_tasks() -> None:
    """Await all in-flight background touch tasks (called on shutdown)."""
    pending = tuple(_background_touch_tasks)
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)


async def update_cached_file_id(record: Dict[str, Any], file_id: str) -> None:
    """Persist a refreshed Telegram file_id for an existing record."""
    if not record.get("public_hash") or not file_id:
        return
    record["file_id"] = file_id
    _remember(record)
    await db.update_file_id(record["public_hash"], file_id, raise_on_error=True)


async def _fetch_canonical_message(record: Dict[str, Any]) -> Optional[Message]:
    """Fetch the canonical BIN-channel message, retrying once on FloodWait.

    Returns None for a missing/media-less message; re-raises fetch errors
    so callers can distinguish "gone" from "couldn't check".
    """
    canonical_message_id = record.get("canonical_message_id")
    if canonical_message_id is None:
        return None

    try:
        try:
            message = await StreamBot.get_messages(
                chat_id=int(Var.BIN_CHANNEL),
                message_ids=int(canonical_message_id)
            )
        except FloodWait as e:
            await asyncio.sleep(e.value)
            message = await StreamBot.get_messages(
                chat_id=int(Var.BIN_CHANNEL),
                message_ids=int(canonical_message_id)
            )
    except Exception as e:
        logger.warning(
            f"Error fetching canonical message {canonical_message_id}: {e}",
            exc_info=True
        )
        raise

    if not message or not message.media:
        return None
    return message


async def _is_canonical_record_valid(record: Dict[str, Any], file_unique_id: str) -> bool:
    """True when the canonical message still exists and holds the same file."""
    message = await _fetch_canonical_message(record)
    return bool(message and get_uniqid(message) == file_unique_id)


async def _get_reusable_canonical_record(
    file_unique_id: str
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
    """Return (reusable_record, stale_record); at most one is non-None.

    A record whose BIN message vanished (or whose validation errored) is
    evicted from the caches and returned as stale so the caller can
    re-copy the media and merge history.
    """
    existing = await get_file_by_unique_id(file_unique_id)
    if not existing:
        return None, None

    try:
        is_valid = await _is_canonical_record_valid(existing, file_unique_id)
    except Exception as e:
        logger.warning(
            f"Falling back to BIN re-copy for {file_unique_id} after canonical validation failed: {e}",
            exc_info=True
        )
        is_valid = False

    if is_valid:
        return existing, None

    _forget(existing)
    return None, existing


async def _wait_for_other_worker_canonical_record(file_unique_id: str) -> Optional[Dict[str, Any]]:
    """Poll for another worker's finished ingest while its claim is live."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + _INGEST_CLAIM_WAIT_SECONDS

    while loop.time() < deadline:
        reusable_record, _ = await _get_reusable_canonical_record(file_unique_id)
        if reusable_record:
            return reusable_record

        # Claim released without a record appearing: stop waiting.
        if not await db.is_file_ingest_claim_active(file_unique_id):
            break

        await asyncio.sleep(_INGEST_CLAIM_POLL_SECONDS)

    return None


def _merge_replacement_record(
    existing: Dict[str, Any],
    refreshed: Dict[str, Any]
) -> Dict[str, Any]:
    """Carry history (counters, first-source info) from a stale record."""
    refreshed["created_at"] = existing.get("created_at", refreshed["created_at"])
    refreshed["seen_count"] = int(existing.get("seen_count", 0)) + 1
    refreshed["reuse_count"] = int(existing.get("reuse_count", 0))
    refreshed["first_source_chat_id"] = existing.get(
        "first_source_chat_id",
        refreshed.get("first_source_chat_id")
    )
    refreshed["first_source_message_id"] = existing.get(
        "first_source_message_id",
        refreshed.get("first_source_message_id")
    )
    return refreshed


@asynccontextmanager
async def file_ingest_lock(file_unique_id: str):
    """Reference-counted per-file asyncio lock.

    The lock object exists only while at least one coroutine is inside
    (or waiting on) this context manager for the given file_unique_id.
    """
    async with _upload_locks_guard:
        lock = _upload_locks.get(file_unique_id)
        if lock is None:
            lock = asyncio.Lock()
            _upload_locks[file_unique_id] = lock
            _upload_lock_counts[file_unique_id] = 0
        _upload_lock_counts[file_unique_id] += 1

    try:
        async with lock:
            yield
    finally:
        # Fix: always drop our refcount, even when lock.acquire() was
        # cancelled before succeeding.  The previous version decremented
        # only after a successful acquire, permanently leaking the
        # _upload_locks/_upload_lock_counts entry on cancellation.
        async with _upload_locks_guard:
            remaining = _upload_lock_counts.get(file_unique_id, 1) - 1
            if remaining <= 0:
                _upload_lock_counts.pop(file_unique_id, None)
                _upload_locks.pop(file_unique_id, None)
            else:
                _upload_lock_counts[file_unique_id] = remaining


async def get_or_create_canonical_file(
    source_message: Message,
    copy_media: Callable[[Message], Awaitable[Optional[Message]]]
) -> Tuple[Optional[Dict[str, Any]], Optional[Message], bool]:
    """Return (record, stored_message, reused) for a media message.

    Reuses an existing valid canonical record when possible; otherwise
    copies the media into the BIN channel via ``copy_media`` and creates
    (or replaces) the record.  A cross-process DB claim plus the local
    per-file lock prevent duplicate copies across workers.  On the
    create path ``stored_message`` is the fresh BIN message (callers may
    need it for cleanup when the record could not be written).
    """
    file_unique_id = get_uniqid(source_message)
    if not file_unique_id:
        return None, None, False

    async with file_ingest_lock(file_unique_id):
        while True:
            reusable_record, stale_record = await _get_reusable_canonical_record(file_unique_id)
            if reusable_record:
                schedule_touch_file_record(reusable_record, reused=True)
                return reusable_record, None, True

            claim_acquired = await db.acquire_file_ingest_claim(
                file_unique_id,
                ttl_seconds=_INGEST_CLAIM_TTL_SECONDS
            )
            if not claim_acquired:
                # Another worker is ingesting; wait for its result, else retry.
                reusable_record = await _wait_for_other_worker_canonical_record(file_unique_id)
                if reusable_record:
                    schedule_touch_file_record(reusable_record, reused=True)
                    return reusable_record, None, True
                continue

            try:
                # Re-check under the claim: the other worker may have just finished.
                reusable_record, stale_record = await _get_reusable_canonical_record(file_unique_id)
                if reusable_record:
                    schedule_touch_file_record(reusable_record, reused=True)
                    return reusable_record, None, True

                stored_message = await copy_media(source_message)
                if not stored_message:
                    return None, None, False

                record = build_file_record(
                    stored_message,
                    source_chat_id=source_message.chat.id if source_message.chat else None,
                    source_message_id=source_message.id
                )
                if not record:
                    return None, stored_message, False

                try:
                    if stale_record:
                        record = _merge_replacement_record(stale_record, record)
                        await db.replace_file_record(record)
                    else:
                        await db.create_file_record(record)
                    _remember(record)
                    return record, stored_message, False
                except DuplicateKeyError:
                    # Lost an insert race despite the claim: adopt the winner's record.
                    reusable_record = await _wait_for_other_worker_canonical_record(file_unique_id)
                    if reusable_record:
                        schedule_touch_file_record(reusable_record, reused=True)
                        return reusable_record, stored_message, True
                    raise
            except FloodWait:
                raise
            except Exception as e:
                logger.error(f"Error creating canonical file for {file_unique_id}: {e}", exc_info=True)
                return None, stored_message, False
            finally:
                await db.release_file_ingest_claim(file_unique_id)
# Thunder/utils/custom_dl.py

import asyncio
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional

from pyrogram import Client
from pyrogram.errors import FloodWait
from pyrogram.types import Message

from Thunder.server.exceptions import FileNotFound
from Thunder.utils.file_properties import get_media
from Thunder.utils.logger import logger
from Thunder.vars import Var


class ByteStreamer:
    """Streams media stored in the BIN channel.

    stream_file accepts a primary media reference (message id, file_id
    string, or Message) plus an optional fallback message id tried when
    the primary reference fails before any byte was yielded.
    """

    __slots__ = ('client', 'chat_id')

    def __init__(self, client: Client) -> None:
        self.client = client
        self.chat_id = int(Var.BIN_CHANNEL)

    async def get_message(self, message_id: int) -> Message:
        """Fetch a media message from the BIN channel, retrying on FloodWait.

        Raises:
            FileNotFound: when the message is missing or carries no media.
        """
        while True:
            try:
                message = await self.client.get_messages(self.chat_id, message_id)
                break
            except FloodWait as e:
                logger.debug(f"FloodWait: get_message, sleep {e.value}s")
                await asyncio.sleep(e.value)
            except Exception as e:
                logger.debug(f"Error fetching message {message_id}: {e}", exc_info=True)
                raise FileNotFound(f"Message {message_id} not found") from e

        if not message or not message.media:
            raise FileNotFound(f"Message {message_id} not found")
        return message

    async def stream_file(
        self,
        media_ref: int | str | Message,
        offset: int = 0,
        limit: int = 0,
        fallback_message_id: int | None = None,
        on_fallback_message: Optional[Callable[[Message], Awaitable[None]]] = None
    ) -> AsyncGenerator[bytes, None]:
        """Yield the media's bytes in 1 MiB chunks.

        Args:
            media_ref: Message id, file_id string, or Message to stream.
            offset: Byte offset to start from (rounded down to a chunk).
            limit: Byte count to stream (0 = to end), rounded up to chunks.
            fallback_message_id: Secondary BIN message tried when the
                primary reference fails before any chunk was yielded.
            on_fallback_message: Awaited with the fallback Message just
                before streaming from it (e.g. to refresh a cached file_id).
        """
        chunk_size = 1024 * 1024
        chunk_offset = offset // chunk_size
        chunk_limit = ((limit + chunk_size - 1) // chunk_size) + 1 if limit > 0 else 0

        refs: list[int | str | Message] = [media_ref]
        media_id = media_ref if isinstance(media_ref, int) else None
        if isinstance(media_ref, Message):
            media_id = getattr(media_ref, "id", getattr(media_ref, "message_id", None))
        if fallback_message_id is not None and (media_id is None or fallback_message_id != media_id):
            refs.append(fallback_message_id)

        last_error: Exception | None = None
        for ref in refs:
            yielded = 0  # chunks already sent to the consumer for this ref
            while True:
                try:
                    target = await self.get_message(ref) if isinstance(ref, int) else ref
                    if (
                        on_fallback_message is not None and
                        fallback_message_id is not None and
                        ref == fallback_message_id and
                        isinstance(target, Message)
                    ):
                        await on_fallback_message(target)

                    remaining = 0
                    if chunk_limit:
                        remaining = chunk_limit - yielded
                        if remaining <= 0:
                            return
                    # Fix: resume from chunk_offset + yielded after a
                    # FloodWait.  The previous version restarted from
                    # chunk_offset, re-yielding (duplicating) chunks the
                    # consumer had already received.
                    async for chunk in self.client.stream_media(
                        target, offset=chunk_offset + yielded, limit=remaining
                    ):
                        yielded += 1
                        yield chunk
                    return
                except FloodWait as e:
                    logger.debug(f"FloodWait: stream_file, sleep {e.value}s")
                    await asyncio.sleep(e.value)
                except Exception as e:
                    last_error = e
                    logger.debug(f"Error streaming media ref {ref}: {e}", exc_info=True)
                    if yielded:
                        # Mid-stream failure: switching refs would corrupt
                        # the byte stream, so propagate.
                        raise
                    break

        raise FileNotFound(f"Unable to stream file: {last_error}")

    def get_file_info_sync(self, message: Message) -> Dict[str, Any]:
        """Extract name/size/mime/type metadata from a media message."""
        media = get_media(message)
        if not media:
            return {"message_id": message.id, "error": "No media"}

        media_type = type(media).__name__.lower()
        file_name = getattr(media, 'file_name', None)
        mime_type = getattr(media, 'mime_type', None)

        if not file_name:
            # Synthesize a name for media classes that never carry one.
            ext_map = {
                "photo": "jpg",
                "audio": "mp3",
                "voice": "ogg",
                "video": "mp4",
                "animation": "mp4",
                "videonote": "mp4",
                "sticker": "webp",
            }
            ext = ext_map.get(media_type, "bin")
            file_name = f"Thunder_{message.id}.{ext}"

        if not mime_type:
            mime_map = {
                "photo": "image/jpeg",
                "voice": "audio/ogg",
                "videonote": "video/mp4",
            }
            mime_type = mime_map.get(media_type)

        return {
            "message_id": message.id,
            "file_size": getattr(media, 'file_size', 0) or 0,
            "file_name": file_name,
            "mime_type": mime_type,
            "unique_id": getattr(media, 'file_unique_id', None),
            "media_type": media_type
        }

    async def get_file_info(self, message_id: int) -> Dict[str, Any]:
        """Fetch a message and return its metadata; errors become a dict."""
        try:
            message = await self.get_message(message_id)
            return self.get_file_info_sync(message)
        except Exception as e:
            logger.debug(f"Error getting file info for {message_id}: {e}", exc_info=True)
            return {"message_id": message_id, "error": str(e)}
Var from Thunder.utils.logger import logger @@ -17,8 +18,10 @@ def __init__(self, uri: str, database_name: str, *args, **kwargs): self.token_col: AsyncCollection = self.db.tokens self.authorized_users_col: AsyncCollection = self.db.authorized_users self.restart_message_col: AsyncCollection = self.db.restart_message + self.files_col: AsyncCollection = self.db.files + self.file_ingest_locks_col: AsyncCollection = self.db.file_ingest_locks - async def ensure_indexes(self): + async def ensure_indexes(self, *, raise_on_error: bool = True) -> bool: try: await self.banned_users_col.create_index("user_id", unique=True) await self.banned_channels_col.create_index("channel_id", unique=True) @@ -29,11 +32,20 @@ async def ensure_indexes(self): await self.token_col.create_index("activated") await self.restart_message_col.create_index("message_id", unique=True) await self.restart_message_col.create_index("timestamp", expireAfterSeconds=3600) + await self.files_col.create_index("file_unique_id", unique=True) + await self.files_col.create_index("public_hash", unique=True) + await self.files_col.create_index("canonical_message_id", unique=True) + await self.files_col.create_index("created_at") + await self.files_col.create_index("last_seen_at") + await self.file_ingest_locks_col.create_index("expires_at", expireAfterSeconds=0) logger.debug("Database indexes ensured.") + return True except Exception as e: logger.error(f"Error in ensure_indexes: {e}", exc_info=True) - raise + if raise_on_error: + raise + return False def new_user(self, user_id: int) -> dict: try: @@ -245,6 +257,165 @@ async def is_user_authorized(self, user_id: int) -> bool: logger.error(f"Error in is_user_authorized for user {user_id}: {e}", exc_info=True) return False + async def get_file_by_unique_id(self, file_unique_id: str) -> Optional[Dict[str, Any]]: + try: + return await self.files_col.find_one({"file_unique_id": file_unique_id}) + except Exception as e: + logger.error(f"Error getting file by unique_id 
{file_unique_id}: {e}", exc_info=True) + return None + + async def get_file_by_hash( + self, + public_hash: str, + *, + raise_on_error: bool = False + ) -> Optional[Dict[str, Any]]: + try: + return await self.files_col.find_one({"public_hash": public_hash}) + except Exception as e: + logger.error(f"Error getting file by hash {public_hash}: {e}", exc_info=True) + if raise_on_error: + raise + return None + + async def get_file_by_message_id(self, canonical_message_id: int) -> Optional[Dict[str, Any]]: + try: + return await self.files_col.find_one({"canonical_message_id": canonical_message_id}) + except Exception as e: + logger.error( + f"Error getting file by message_id {canonical_message_id}: {e}", + exc_info=True + ) + return None + + async def create_file_record(self, file_record: Dict[str, Any]) -> None: + try: + await self.files_col.insert_one(file_record) + except Exception as e: + logger.error( + f"Error creating canonical file record for {file_record.get('file_unique_id')}: {e}", + exc_info=True + ) + raise + + async def replace_file_record(self, file_record: Dict[str, Any]) -> None: + try: + await self.files_col.replace_one( + {"file_unique_id": file_record["file_unique_id"]}, + file_record, + upsert=True + ) + except Exception as e: + logger.error( + f"Error replacing canonical file record for {file_record.get('file_unique_id')}: {e}", + exc_info=True + ) + raise + + async def touch_file_record( + self, + public_hash: str, + *, + reused: bool = False, + raise_on_error: bool = False + ) -> bool: + try: + update_doc: Dict[str, Any] = { + "$set": {"last_seen_at": datetime.datetime.utcnow()}, + "$inc": {"seen_count": 1} + } + if reused: + update_doc["$inc"]["reuse_count"] = 1 + await self.files_col.update_one({"public_hash": public_hash}, update_doc) + return True + except Exception as e: + logger.error(f"Error touching canonical file {public_hash}: {e}", exc_info=True) + if raise_on_error: + raise + return False + + async def update_file_id( + self, + 
public_hash: str, + file_id: str, + *, + raise_on_error: bool = False + ) -> bool: + try: + await self.files_col.update_one( + {"public_hash": public_hash}, + { + "$set": { + "file_id": file_id, + "last_seen_at": datetime.datetime.utcnow() + } + } + ) + return True + except Exception as e: + logger.error(f"Error updating file_id for {public_hash}: {e}", exc_info=True) + if raise_on_error: + raise + return False + + async def acquire_file_ingest_claim( + self, + file_unique_id: str, + *, + ttl_seconds: int = 60 + ) -> bool: + now = datetime.datetime.utcnow() + claim_fields = { + "created_at": now, + "expires_at": now + datetime.timedelta(seconds=ttl_seconds) + } + try: + result = await self.file_ingest_locks_col.find_one_and_update( + { + "_id": file_unique_id, + "$or": [ + {"expires_at": {"$lte": now}}, + {"expires_at": {"$exists": False}} + ] + }, + { + "$set": claim_fields + }, + upsert=True, + return_document=False + ) + # Upsert-created documents return None; replacements of stale claims return old doc. + # Both mean the caller now owns the claim. + return result is None or bool(result) + except DuplicateKeyError: + # Another worker inserted an active claim concurrently. 
+ return False + except Exception as e: + logger.error(f"Error acquiring ingest claim for {file_unique_id}: {e}", exc_info=True) + raise + + async def release_file_ingest_claim(self, file_unique_id: str) -> bool: + try: + await self.file_ingest_locks_col.delete_one({"_id": file_unique_id}) + return True + except Exception as e: + logger.error(f"Error releasing ingest claim for {file_unique_id}: {e}", exc_info=True) + return False + + async def is_file_ingest_claim_active(self, file_unique_id: str) -> bool: + try: + claim = await self.file_ingest_locks_col.find_one( + { + "_id": file_unique_id, + "expires_at": {"$gt": datetime.datetime.utcnow()} + }, + {"_id": 1} + ) + return bool(claim) + except Exception as e: + logger.error(f"Error checking ingest claim for {file_unique_id}: {e}", exc_info=True) + raise + async def close(self): if self._client: await self._client.close() diff --git a/Thunder/utils/render_template.py b/Thunder/utils/render_template.py index 454f4a0..6fde40a 100644 --- a/Thunder/utils/render_template.py +++ b/Thunder/utils/render_template.py @@ -1,60 +1,66 @@ -# Thunder/utils/render_template.py - -import asyncio -import html as html_module -import urllib.parse - -from jinja2 import Environment, FileSystemLoader -from pyrogram.errors import FloodWait - -from Thunder.bot import StreamBot -from Thunder.server.exceptions import InvalidHash -from Thunder.utils.file_properties import get_fname, get_uniqid -from Thunder.utils.logger import logger -from Thunder.vars import Var - -template_env = Environment( - loader=FileSystemLoader('Thunder/template'), - enable_async=True, - cache_size=200, - auto_reload=False, - optimized=True +# Thunder/utils/render_template.py + +import asyncio +import urllib.parse + +from jinja2 import Environment, FileSystemLoader, select_autoescape +from pyrogram.errors import FloodWait + +from Thunder.bot import StreamBot +from Thunder.server.exceptions import InvalidHash +from Thunder.utils.file_properties import get_fname, 
from Thunder.utils.logger import logger
from Thunder.vars import Var

# Async Jinja environment; autoescaping replaces the manual html.escape
# the old code applied to file names.
template_env = Environment(
    loader=FileSystemLoader('Thunder/template'),
    autoescape=select_autoescape(enabled_extensions=("html",), default_for_string=True),
    enable_async=True,
    cache_size=200,
    auto_reload=False,
    optimized=True
)


async def render_media_page(file_name: str, src: str, requested_action: str | None = None) -> str:
    """Render the stream ('req.html') or download ('dl.html') page for a file."""
    if requested_action == 'stream':
        template = template_env.get_template('req.html')
        return await template.render_async(
            heading=f"View {file_name}",
            file_name=file_name,
            src=f"{src}?disposition=inline"
        )

    template = template_env.get_template('dl.html')
    return await template.render_async(file_name=file_name, src=src)


async def render_page(message_id: int, secure_hash: str, requested_action: str | None = None) -> str:
    """Validate the hash against the BIN message and render its media page.

    Raises:
        InvalidHash: when the message is gone or the hash doesn't match.
    """
    try:
        try:
            message = await StreamBot.get_messages(chat_id=int(Var.BIN_CHANNEL), message_ids=message_id)
        except FloodWait as e:
            # One retry after Telegram's requested back-off.
            await asyncio.sleep(e.value)
            message = await StreamBot.get_messages(chat_id=int(Var.BIN_CHANNEL), message_ids=message_id)

        if not message:
            raise InvalidHash("Message not found")

        file_unique_id = get_uniqid(message)
        file_name = get_fname(message)

        if not file_unique_id or secure_hash != file_unique_id[:6]:
            raise InvalidHash("File unique ID or secure hash mismatch during rendering.")

        sanitized = file_name.replace('/', '_')
        quoted_filename = urllib.parse.quote(sanitized, safe="")
        src = urllib.parse.urljoin(Var.URL, f'{secure_hash}{message_id}/{quoted_filename}')
        return await render_media_page(file_name, src, requested_action)
    except Exception as e:
        logger.error(
            f"Error in render_page for message_id {message_id} and hash {secure_hash}: {e}",
            exc_info=True
        )
        raise