From ab5683e2469889789b5aeec01f07959f6c4cb30f Mon Sep 17 00:00:00 2001 From: "guorong.zheng" <360996299@qq.com> Date: Mon, 23 Dec 2024 11:38:10 +0800 Subject: [PATCH] refactor:get_speed_m3u8(#719) --- utils/speed.py | 138 +++++++++++++++++++++++++++++++------------------ 1 file changed, 89 insertions(+), 49 deletions(-) diff --git a/utils/speed.py b/utils/speed.py index 5fd9b79e0df..308700433ba 100644 --- a/utils/speed.py +++ b/utils/speed.py @@ -6,12 +6,14 @@ import m3u8 from aiohttp import ClientSession, TCPConnector +from multidict import CIMultiDictProxy from utils.config import config from utils.tools import is_ipv6, remove_cache_info -async def get_speed_with_download(url: str, timeout: int = config.sort_timeout) -> dict[str, float | None]: +async def get_speed_with_download(url: str, session: ClientSession = None, timeout: int = config.sort_timeout) -> dict[ + str, float | None]: """ Get the speed of the url with a total timeout """ @@ -19,24 +21,61 @@ async def get_speed_with_download(url: str, timeout: int = config.sort_timeout) total_size = 0 total_time = 0 info = {'speed': None, 'delay': None} + if session is None: + session = ClientSession(connector=TCPConnector(ssl=False), trust_env=True) + created_session = True + else: + created_session = False try: - async with ClientSession( - connector=TCPConnector(ssl=False), trust_env=True - ) as session: - async with session.get(url, timeout=timeout) as response: - if response.status == 404: - return info - info['delay'] = int(round((time() - start_time) * 1000)) - async for chunk in response.content.iter_any(): - if chunk: - total_size += len(chunk) + async with session.get(url, timeout=timeout) as response: + if response.status == 404: + return info + info['delay'] = int(round((time() - start_time) * 1000)) + async for chunk in response.content.iter_any(): + if chunk: + total_size += len(chunk) except Exception as e: pass finally: - end_time = time() - total_time += end_time - start_time - info['speed'] = (total_size / total_time if total_time > 0 else 0) / 1024 / 1024 - return info + if created_session: + await session.close() + end_time = time() + total_time += end_time - start_time + info['speed'] = (total_size / total_time if total_time > 0 else 0) / 1024 / 1024 + return info + + +async def get_m3u8_headers(url: str, session: ClientSession = None, timeout: int = 5) -> CIMultiDictProxy[str] | dict[ + any, any]: + """ + Get the headers of the m3u8 url + """ + if session is None: + session = ClientSession(connector=TCPConnector(ssl=False), trust_env=True) + created_session = True + else: + created_session = False + try: + async with session.head(url, timeout=timeout) as response: + return response.headers + except: + pass + finally: + if created_session: + await session.close() + return {} + + +def check_m3u8_valid(headers: CIMultiDictProxy[str] | dict[any, any]) -> bool: + """ + Check the m3u8 url is valid + """ + content_type = headers.get('Content-Type') + if content_type: + content_type = content_type.lower() + if 'application/vnd.apple.mpegurl' in content_type: + return True + return False async def get_speed_m3u8(url: str, timeout: int = config.sort_timeout) -> dict[str, float | None]: @@ -47,44 +86,45 @@ async def get_speed_m3u8(url: str, timeout: int = config.sort_timeout) -> dict[s try: url = quote(url, safe=':/?$&=@[]').partition('$')[0] async with ClientSession(connector=TCPConnector(ssl=False), trust_env=True) as session: - async with session.head(url, timeout=5) as response: - content_type = response.headers.get('Content-Type') - if content_type: - content_type = content_type.lower() - location = response.headers.get('Location') - if 'application/vnd.apple.mpegurl' in content_type: - url = location or url + headers = await get_m3u8_headers(url, session) + if check_m3u8_valid(headers): + location = headers.get('Location') + if location: + info.update(await get_speed_m3u8(location, timeout)) + else: + m3u8_obj = m3u8.load(url, timeout=2) + playlists = m3u8_obj.data.get('playlists') + segments = m3u8_obj.segments + if not segments and playlists: + parsed_url = urlparse(url) + url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.rsplit('/', 1)[0]}/{playlists[0].get('uri', '')}" + uri_headers = await get_m3u8_headers(url, session) + if not check_m3u8_valid(uri_headers): + if uri_headers.get('Content-Length'): + info.update(await get_speed_with_download(url, session, timeout)) + return info m3u8_obj = m3u8.load(url, timeout=2) - playlists = m3u8_obj.data.get('playlists') segments = m3u8_obj.segments - if not segments and playlists: - parsed_url = urlparse(url) - url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.rsplit('/', 1)[0]}/{playlists[0].get('uri', '')}" - m3u8_obj = m3u8.load(url, timeout=2) - segments = m3u8_obj.segments - if not segments: - return info - ts_urls = [segment.absolute_uri for segment in segments] - speed_list = [] - start_time = time() - for ts_url in ts_urls: - if time() - start_time > timeout: - break - download_info = await get_speed_with_download(ts_url, timeout) - speed_list.append(download_info['speed']) - if info['delay'] is None and download_info['delay'] is not None: - info['delay'] = download_info['delay'] - info['speed'] = sum(speed_list) / len(speed_list) if speed_list else 0 - elif location: - info.update(await get_speed_m3u8(location, timeout)) - elif response.headers.get('Content-Length'): - info.update(await get_speed_with_download(url, timeout)) - else: - return info + if not segments: + return info + ts_urls = [segment.absolute_uri for segment in segments] + speed_list = [] + start_time = time() + for ts_url in ts_urls: + if time() - start_time > timeout: + break + download_info = await get_speed_with_download(ts_url, session, timeout) + speed_list.append(download_info['speed']) + if info['delay'] is None and download_info['delay'] is not None: + info['delay'] = download_info['delay'] + info['speed'] = sum(speed_list) / len(speed_list) if speed_list else 0 + elif headers.get('Content-Length'): + info.update(await get_speed_with_download(url, session, timeout)) + else: + return info except: pass - finally: - return info + return info async def get_delay_requests(url, timeout=config.sort_timeout, proxy=None):