From a11aa60062b3241632eca3fb402a6c38e80229ac Mon Sep 17 00:00:00 2001 From: Divanshu Chauhan <23524935+Divkix@users.noreply.github.com> Date: Sun, 20 Nov 2022 23:44:29 -0700 Subject: [PATCH] Maintenance update --- WebStreamer/__main__.py | 3 ++ WebStreamer/bot/plugins/admin.py | 8 +++- WebStreamer/bot/plugins/ban.py | 3 ++ WebStreamer/bot/plugins/start.py | 58 ++++++++++++++--------- WebStreamer/bot/plugins/stream.py | 6 ++- WebStreamer/db/__init__.py | 3 ++ WebStreamer/db/downloads.py | 26 +++++++++-- WebStreamer/db/mongo.py | 22 +++++---- WebStreamer/db/users.py | 36 ++++++++++++--- WebStreamer/server/__init__.py | 5 ++ WebStreamer/server/stream_routes.py | 21 +++++++-- WebStreamer/utils/broadcast_helper.py | 11 +++-- WebStreamer/utils/custom_dl.py | 66 +++++++++++++++++---------- WebStreamer/utils/human_readable.py | 5 +- WebStreamer/utils/ikb.py | 16 +++++-- WebStreamer/utils/joinCheck.py | 2 +- WebStreamer/utils/time_format.py | 3 ++ WebStreamer/vars.py | 4 ++ cf-worker/src/index.ts | 8 ++++ 19 files changed, 229 insertions(+), 77 deletions(-) diff --git a/WebStreamer/__main__.py b/WebStreamer/__main__.py index 1264b2c..720b6cc 100644 --- a/WebStreamer/__main__.py +++ b/WebStreamer/__main__.py @@ -19,6 +19,9 @@ async def start_services(): + """ + Start the bot and the web server + """ LOGGER.info("------------------- Initializing Telegram Bot -------------------") await StreamBot.start() LOGGER.info("----------------------------- DONE -----------------------------") diff --git a/WebStreamer/bot/plugins/admin.py b/WebStreamer/bot/plugins/admin.py index c0a7359..a38c4d1 100644 --- a/WebStreamer/bot/plugins/admin.py +++ b/WebStreamer/bot/plugins/admin.py @@ -21,6 +21,9 @@ filters.command("status") & filters.private & filters.user(Vars.OWNER_ID), ) async def status(_, m: Message): + """ + Get status of the bot, number of users, number of files, etc. + """ dl = Downloads() filename = "downloadList.txt" total_users = await Users().total_users_count() @@ -59,6 +62,9 @@ async def status(_, m: Message): & filters.reply, ) async def broadcast_(_, m: Message): + """ + Broadcast a message to all users + """ all_users = await Users().get_all_users() broadcast_msg = m.reply_to_message while 1: @@ -79,7 +85,7 @@ async def broadcast_(_, m: Message): ) async with open_aiofiles("broadcast.txt", "w") as broadcast_log_file: for user in all_users: - sts, msg = await send_msg(user_id=int(user), message=broadcast_msg) + sts, msg = await send_msg(user_id=int(user), m=broadcast_msg) if msg is not None: await broadcast_log_file.write(msg) if sts == 200: diff --git a/WebStreamer/bot/plugins/ban.py b/WebStreamer/bot/plugins/ban.py index 75f5eb0..4aa7e0b 100644 --- a/WebStreamer/bot/plugins/ban.py +++ b/WebStreamer/bot/plugins/ban.py @@ -7,6 +7,9 @@ @StreamBot.on_callback_query(filters.regex("^ban_")) async def ban_user(c: StreamBot, q: CallbackQuery): + """ + Ban a user from using the bot + """ user_id = int(q.data.split("_", 1)[1]) await c.ban_chat_member(Vars.AUTH_CHANNEL, user_id) await q.answer("User Banned from Updates Channel!", show_alert=True) diff --git a/WebStreamer/bot/plugins/start.py b/WebStreamer/bot/plugins/start.py index a2b11cd..2b142c1 100644 --- a/WebStreamer/bot/plugins/start.py +++ b/WebStreamer/bot/plugins/start.py @@ -46,7 +46,7 @@ class Btns: channel_and_group = [ - ("Support Group", "https://t.me/DivideProjectsDiscussion", "url"), + ("Support Group", "https://t.me/DivideSupport", "url"), ("Channel", "https://t.me/DivideProjects", "url"), ] about_me = ("About Me", "aboutbot") @@ -57,6 +57,9 @@ class Btns: @StreamBot.on_message(filters.command("start") & filters.private) @joinCheck() async def start(_, m: Message): + """ + Start the bot + """ return await m.reply_text( text=PMTEXT.format(m.from_user.mention), parse_mode=ParseMode.HTML, @@ -68,6 +71,9 @@ async def start(_, m: Message): @StreamBot.on_message(filters.command("help") & filters.private) @joinCheck() async def help_handler(_, m: Message): + """ + Help message handler + """ return await m.reply_text( HELPTEXT, parse_mode=ParseMode.HTML, @@ -77,28 +83,36 @@ async def help_handler(_, m: Message): @StreamBot.on_callback_query() async def button(_, m: CallbackQuery): + """ + handle button presses + """ cb_data = m.data msg = m.message - if cb_data == "aboutbot": - await msg.edit( - text=ABOUT, - parse_mode=ParseMode.HTML, - disable_web_page_preview=True, - reply_markup=ikb([[Btns.back]]), - ) - elif cb_data == "helptext": - await msg.edit( - text=HELPTEXT, - parse_mode=ParseMode.HTML, - disable_web_page_preview=True, - reply_markup=ikb([[Btns.back]]), - ) - elif cb_data == "gotohome": - await msg.edit( - text=PMTEXT.format(msg.from_user.mention), - parse_mode=ParseMode.HTML, - disable_web_page_preview=True, - reply_markup=ikb([Btns.channel_and_group, [Btns.about_me, Btns.help_me]]), - ) + match cb_data: + case "aboutbot": + await msg.edit( + text=ABOUT, + parse_mode=ParseMode.HTML, + disable_web_page_preview=True, + reply_markup=ikb([[Btns.back]]), + ) + case "helptext": + await msg.edit( + text=HELPTEXT, + parse_mode=ParseMode.HTML, + disable_web_page_preview=True, + reply_markup=ikb([[Btns.back]]), + ) + case "gotohome": + await msg.edit( + text=PMTEXT.format(msg.from_user.mention), + parse_mode=ParseMode.HTML, + disable_web_page_preview=True, + reply_markup=ikb( + [Btns.channel_and_group, [Btns.about_me, Btns.help_me]], + ), + ) + case _: + await msg.edit("Invalid Button Pressed!") await m.answer() diff --git a/WebStreamer/bot/plugins/stream.py b/WebStreamer/bot/plugins/stream.py index 4746a03..117a253 100644 --- a/WebStreamer/bot/plugins/stream.py +++ b/WebStreamer/bot/plugins/stream.py @@ -27,6 +27,7 @@ @DivideProjects """ +# Cache for storing how many times a user has used the bot, takes number of mimuted from Vars ttl_dict = TTLCache(maxsize=512, ttl=(Vars.FLOODCONTROL_TIME_MINUTES * 60)) @@ -35,7 +36,7 @@ & (filters.document | filters.video | filters.audio | filters.photo), group=4, ) -@joinCheck() +@joinCheck() # Check if user has joined the channel async def private_receive_handler(c: Client, m: Message): user = m.from_user user_id = user.id @@ -122,6 +123,9 @@ async def private_receive_handler(c: Client, m: Message): @StreamBot.on_callback_query(filters.regex("^delete_url.")) async def delete_download(_, q: CallbackQuery): + """ + Delete the download link from the database using a callback query + """ user_id = q.from_user.id msg = q.message url = str(q.data.split(".")[-1]) diff --git a/WebStreamer/db/__init__.py b/WebStreamer/db/__init__.py index ed73de9..9d1783f 100644 --- a/WebStreamer/db/__init__.py +++ b/WebStreamer/db/__init__.py @@ -5,6 +5,9 @@ def __connect_first(): + """ + Connect to the database before importing the models + """ _ = MongoDB("test") LOGGER.info("Initialized Database!") diff --git a/WebStreamer/db/downloads.py b/WebStreamer/db/downloads.py index 5b92443..b3ee32e 100644 --- a/WebStreamer/db/downloads.py +++ b/WebStreamer/db/downloads.py @@ -1,17 +1,28 @@ from datetime import datetime, timedelta from secrets import token_urlsafe +from typing import Tuple, Union from WebStreamer.db.mongo import MongoDB from WebStreamer.logger import LOGGER class Downloads(MongoDB): + """ + Define downloads collection here + """ + db_name = "filestreamerbot_downloads" def __init__(self): + """ + Initialize the collection + """ super().__init__(self.db_name) async def add_download(self, message_id: int, random_url: str, user_id: int) -> str: + """ + Add a download to the database + """ LOGGER.info(f"Added {random_url}: {message_id}") real_link = token_urlsafe(16) await self.insert_one( @@ -25,13 +36,19 @@ async def add_download(self, message_id: int, random_url: str, user_id: int) -> ) return real_link - async def get_actual_link(self, link: str): + async def get_actual_link(self, link: str) -> Union[str, None]: + """ + Get the actual link from the database + """ document = await self.find_one({"random_link": link}) if not document: return None return document["link"] - async def get_msg_id(self, link: str): + async def get_msg_id(self, link: str) -> Tuple[int, bool, datetime]: + """ + Get the message id from the database + """ document = await self.find_one({"link": link}) if not document: return 0, False, datetime.now() @@ -39,7 +56,10 @@ async def get_msg_id(self, link: str): valid = valid_upto > datetime.now() return document["message_id"], valid, valid_upto - async def total_downloads(self): + async def total_downloads(self) -> int: + """ + Get the total number of downloads + """ return await self.count() async def valid_downloads_list(self): diff --git a/WebStreamer/db/mongo.py b/WebStreamer/db/mongo.py index e749ccc..58c6ea2 100644 --- a/WebStreamer/db/mongo.py +++ b/WebStreamer/db/mongo.py @@ -1,3 +1,5 @@ +from typing import Any, Tuple, Union + from motor.motor_asyncio import AsyncIOMotorClient from WebStreamer.vars import Vars @@ -7,43 +9,45 @@ class MongoDB: - """Class for interacting with Bot database.""" + """ + Class for interacting with Bot database. + """ def __init__(self, collection) -> None: self.collection = main_db[collection] # Insert one entry into collection - async def insert_one(self, document): + async def insert_one(self, document) -> str: result = await self.collection.insert_one(document) return repr(result.inserted_id) # Find one entry from collection - async def find_one(self, query): + async def find_one(self, query) -> Union[bool, None, Any]: result = await self.collection.find_one(query) if result: return result return False # Find entries from collection - async def find_all(self, query=None): + async def find_all(self, query=None) -> Union[bool, None, Any]: if query is None: query = {} return [document async for document in self.collection.find(query)] # Count entries from collection - async def count(self, query=None): + async def count(self, query=None) -> int: if query is None: query = {} return await self.collection.count_documents(query) # Delete entry/entries from collection - async def delete_one(self, query): + async def delete_one(self, query) -> int: await self.collection.delete_many(query) after_delete = await self.collection.count_documents({}) return after_delete # Replace one entry in collection - async def replace(self, query, new_data): + async def replace(self, query, new_data) -> Tuple[int, int]: old = await self.collection.find_one(query) _id = old["_id"] await self.collection.replace_one({"_id": _id}, new_data) @@ -51,11 +55,11 @@ async def replace(self, query, new_data): return old, new # Update one entry from collection - async def update(self, query, update): + async def update(self, query, update) -> Tuple[int, int]: result = await self.collection.update_one(query, {"$set": update}) new_document = await self.collection.find_one(query) return result.modified_count, new_document @staticmethod - async def db_command(command: str): + async def db_command(command: str) -> Any: return await main_db.command(command) diff --git a/WebStreamer/db/users.py b/WebStreamer/db/users.py index 7556353..9d58be7 100644 --- a/WebStreamer/db/users.py +++ b/WebStreamer/db/users.py @@ -1,27 +1,48 @@ from datetime import date +from typing import List, Union from WebStreamer.db.mongo import MongoDB from WebStreamer.logger import LOGGER -def new_user(uid): - return {"id": uid, "join_date": date.today().isoformat(), "downloads": []} +def new_user(uid: int): + """ + Creates a new user in the database + """ + return { + "id": uid, + "join_date": date.today().isoformat(), + "downloads": [], + } class Users(MongoDB): + """ + Users collections to be made in the database + """ + db_name = "filestreamerbot_users" def __init__(self): super().__init__(self.db_name) - async def total_users_count(self): + async def total_users_count(self) -> int: + """ + Returns the total number of users in the database + """ return await self.count({}) - async def get_all_users(self): + async def get_all_users(self) -> List[int]: + """ + Returns a list of all users in the database + """ users = await self.find_all({}) return [user["id"] for user in users] - async def user_exists(self, user_id: int): + async def user_exists(self, user_id: int) -> bool: + """ + Checks if a user exists in the database + """ user = await self.find_one({"id": user_id}) if not user: user_data = { @@ -33,5 +54,8 @@ async def user_exists(self, user_id: int): return False return True - async def delete_user(self, user_id: int): + async def delete_user(self, user_id: int) -> Union[bool, int]: + """ + Deletes a user from the database + """ return await self.delete_one({"id": user_id}) diff --git a/WebStreamer/server/__init__.py b/WebStreamer/server/__init__.py index 41e2251..523dc64 100644 --- a/WebStreamer/server/__init__.py +++ b/WebStreamer/server/__init__.py @@ -6,11 +6,16 @@ async def web_server(): + """ + Create the web server and return it + """ web_app = web.Application(client_max_size=30000000) + # setup jinja2 with the web templates from templates folder setup_jinja2( web_app, enable_async=True, loader=FileSystemLoader("/app/WebStreamer/html/templates"), ) + # add the routes to the web app web_app.add_routes(routes) return web_app diff --git a/WebStreamer/server/stream_routes.py b/WebStreamer/server/stream_routes.py index 66e1e6f..b0057ab 100644 --- a/WebStreamer/server/stream_routes.py +++ b/WebStreamer/server/stream_routes.py @@ -2,6 +2,7 @@ from mimetypes import guess_type from secrets import token_hex from time import time +from typing import Dict, Union from aiohttp import web from aiohttp_jinja2 import template @@ -18,7 +19,10 @@ @routes.get("/", allow_head=True) -async def index_handler(_): +async def index_handler(_) -> web.StreamResponse: + """ + Index Handler for WebStreamer, the '/' route. + """ return web.json_response( { "status": "Active", @@ -32,7 +36,10 @@ async def index_handler(_): # custom download page @routes.get("/download-file-{random_link}") @template("download_page.html") -async def stream_handler(request): +async def stream_handler(request) -> Union[web.StreamResponse | Dict[str]]: + """ + Stream Handler for WebStreamer, the '/download-file-*' route. + """ try: random_link = request.match_info["random_link"] real_link = await Downloads().get_actual_link(random_link) @@ -44,7 +51,10 @@ async def stream_handler(request): # actual download link @routes.get("/{real_link}") -async def stream_handler(request): +async def stream_handler(request) -> web.StreamResponse: + """ + Stream Handler for WebStreamer, the '/{real_link}' route. + """ try: real_link = request.match_info["real_link"] message_id, valid, valid_upto = await Downloads().get_msg_id(real_link) @@ -73,7 +83,10 @@ async def stream_handler(request): raise web.HTTPNotFound -async def media_streamer(request, message_id: int): +async def media_streamer(request, message_id: int) -> web.StreamResponse: + """ + Media Streamer for WebStreamer, the '/{real_link}' route. + """ range_header = request.headers.get("Range", 0) media_msg = await StreamBot.get_messages(Vars.LOG_CHANNEL, message_id) file_properties = await TGCustomYield().generate_file_properties(media_msg) diff --git a/WebStreamer/utils/broadcast_helper.py b/WebStreamer/utils/broadcast_helper.py index 36d3de4..52dd961 100644 --- a/WebStreamer/utils/broadcast_helper.py +++ b/WebStreamer/utils/broadcast_helper.py @@ -1,5 +1,6 @@ from asyncio import sleep from traceback import format_exc +from typing import Tuple, Union from pyrogram.errors import ( FloodWait, @@ -7,15 +8,19 @@ PeerIdInvalid, UserIsBlocked, ) +from pyrogram.types import Message -async def send_msg(user_id, message): +async def send_msg(user_id: int, m: Message) -> Tuple[int, Union[Message, None, str]]: + """ + Send message to user using their user_id + """ try: - await message.forward(chat_id=user_id) + await m.forward(chat_id=user_id) return 200, None except FloodWait as e: await sleep(e.value) - return send_msg(user_id, message) + return send_msg(user_id, m) except InputUserDeactivated: return 400, f"{user_id} : deactivated\n" except UserIsBlocked: diff --git a/WebStreamer/utils/custom_dl.py b/WebStreamer/utils/custom_dl.py index 1cbe89e..6ead691 100644 --- a/WebStreamer/utils/custom_dl.py +++ b/WebStreamer/utils/custom_dl.py @@ -10,18 +10,23 @@ from WebStreamer.bot import StreamBot -async def chunk_size(length): +async def chunk_size(length) -> int: return 2 ** max(min(ceil(log2(length / 1024)), 10), 2) * 1024 -async def offset_fix(offset, chunksize): +async def offset_fix(offset, chunksize) -> int: offset -= offset % chunksize return offset class TGCustomYield: + """ + class to get the file from telegram servers + """ + def __init__(self): - """A custom method to stream files from telegram. functions: generate_file_properties: returns the properties + """ + A custom method to stream files from telegram. functions: generate_file_properties: returns the properties for a media on a specific message contained in FileId class. generate_media_session: returns the media session for the DC that contains the media file on the message. yield_file: yield a file from telegram servers for streaming. @@ -29,7 +34,10 @@ def __init__(self): self.main_bot = StreamBot @staticmethod - async def generate_file_properties(msg: Message): + async def generate_file_properties(m: Message): + """ + generate file properties from a message + """ available_media = ( "audio", "document", @@ -41,17 +49,17 @@ async def generate_file_properties(msg: Message): "video_note", ) - if isinstance(msg, Message): + if isinstance(m, Message): error_message = "This message doesn't contain any downloadable media" for kind in available_media: - media = getattr(msg, kind, None) + media = getattr(m, kind, None) if media is not None: break else: raise ValueError(error_message) else: - media = msg + media = m file_id_str = media if isinstance(media, str) else media.file_id file_id_obj = FileId.decode(file_id_str) @@ -63,28 +71,31 @@ async def generate_file_properties(msg: Message): return file_id_obj - async def generate_media_session(self, client: Client, msg: Message): - data = await self.generate_file_properties(msg) + async def generate_media_session(self, c: Client, m: Message): + """ + generate media session from a message + """ + data = await self.generate_file_properties(m) - media_session = client.media_sessions.get(data.dc_id, None) + media_session = c.media_sessions.get(data.dc_id, None) if media_session is None: - if data.dc_id != await client.storage.dc_id(): + if data.dc_id != await c.storage.dc_id(): media_session = Session( - client, + c, data.dc_id, await Auth( - client, + c, data.dc_id, - await client.storage.test_mode(), + await c.storage.test_mode(), ).create(), - await client.storage.test_mode(), + await c.storage.test_mode(), is_media=True, ) await media_session.start() for _ in range(3): - exported_auth = await client.send( + exported_auth = await c.send( raw.functions.auth.ExportAuthorization(dc_id=data.dc_id), ) @@ -104,20 +115,23 @@ async def generate_media_session(self, client: Client, msg: Message): raise AuthBytesInvalid else: media_session = Session( - client, + c, data.dc_id, - await client.storage.auth_key(), - await client.storage.test_mode(), + await c.storage.auth_key(), + await c.storage.test_mode(), is_media=True, ) await media_session.start() - client.media_sessions[data.dc_id] = media_session + c.media_sessions[data.dc_id] = media_session return media_session @staticmethod async def get_location(file_id: FileId): + """ + get location from file id + """ file_type = file_id.file_type if file_type == FileType.CHAT_PHOTO: @@ -164,6 +178,9 @@ async def yield_file( part_count: int, chunk_size_int: int, ) -> Union[str, None]: # pylint: disable=unsubscriptable-object + """ + yield a file from telegram servers for streaming + """ client = self.main_bot data = await self.generate_file_properties(media_msg) media_session = await self.generate_media_session(client, media_msg) @@ -204,10 +221,13 @@ async def yield_file( current_part += 1 - async def download_as_bytesio(self, media_msg: Message): + async def download_as_bytesio(self, m: Message): + """ + download a file as bytesio + """ client = self.main_bot - data = await self.generate_file_properties(media_msg) - media_session = await self.generate_media_session(client, media_msg) + data = await self.generate_file_properties(m) + media_session = await self.generate_media_session(client, m) location = await self.get_location(data) diff --git a/WebStreamer/utils/human_readable.py b/WebStreamer/utils/human_readable.py index e6cee81..71421b4 100644 --- a/WebStreamer/utils/human_readable.py +++ b/WebStreamer/utils/human_readable.py @@ -1,4 +1,7 @@ -def humanbytes(size): +def humanbytes(size) -> str: + """ + Returns a human readable string representation of bytes. + """ if not size: return "" power = 2**10 diff --git a/WebStreamer/utils/ikb.py b/WebStreamer/utils/ikb.py index a8a73e1..5982e96 100644 --- a/WebStreamer/utils/ikb.py +++ b/WebStreamer/utils/ikb.py @@ -1,7 +1,11 @@ from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup -def ikb(rows=None): +def ikb(rows=None) -> InlineKeyboardMarkup: + """ + :param rows: list of list of buttons + :return: InlineKeyboardMarkup + """ if rows is None: rows = [] lines = [] @@ -15,6 +19,12 @@ def ikb(rows=None): # return {'inline_keyboard': lines} -def btn(text, value, type="callback_data"): - return InlineKeyboardButton(text, **{type: value}) +def btn(text, value, t="callback_data"): + """ + :param text: button text + :param value: button value + :param t: button type + :return: InlineKeyboardButton + """ + return InlineKeyboardButton(text, **{t: value}) # return {'text': text, type: value} diff --git a/WebStreamer/utils/joinCheck.py b/WebStreamer/utils/joinCheck.py index 757a83d..b8da7b3 100644 --- a/WebStreamer/utils/joinCheck.py +++ b/WebStreamer/utils/joinCheck.py @@ -5,7 +5,7 @@ from WebStreamer.utils.ikb import ikb from WebStreamer.vars import Vars -support_group = "https://t.me/DivideProjectsDiscussion" +support_group = "https://t.me/DivideSupport" def ban_kb(user_id: int): diff --git a/WebStreamer/utils/time_format.py b/WebStreamer/utils/time_format.py index 2c86664..863c8c9 100644 --- a/WebStreamer/utils/time_format.py +++ b/WebStreamer/utils/time_format.py @@ -1,4 +1,7 @@ def get_readable_time(seconds: int) -> str: + """ + Get readable time from seconds + """ count = 0 readable_time = "" time_list = [] diff --git a/WebStreamer/vars.py b/WebStreamer/vars.py index 26f1102..903d78f 100644 --- a/WebStreamer/vars.py +++ b/WebStreamer/vars.py @@ -8,6 +8,10 @@ class Vars: + """ + Class to store all the variables + """ + API_ID = int(config("API_ID", default=None)) API_HASH = str(config("API_HASH", default=None)) BOT_TOKEN = str(config("BOT_TOKEN", default=None)) diff --git a/cf-worker/src/index.ts b/cf-worker/src/index.ts index 1f3e174..3ce8295 100644 --- a/cf-worker/src/index.ts +++ b/cf-worker/src/index.ts @@ -5,9 +5,16 @@ const app = new Hono(); // listen for get requests on / app.get("/*", async (c: Context) => { + // get the formed url from the context const formedUrl = new URL(c.req.url); + + // get the fqdn from the env let fqdn = c.env.FQDN; + + // if fqdn ends without a slash, add one to it if (!fqdn.endsWith("/")) fqdn = fqdn + "/"; + + // slice the formed url to get the path, which is the file path const dlUrl = fqdn + formedUrl.pathname.slice(1); // Fetch from origin server. @@ -24,4 +31,5 @@ app.get("/*", async (c: Context) => { return new Response(readable, response); }); +// export default app export default app;