From ea97ba0e60462070b94228213f5fbf3374c39270 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Wed, 7 Aug 2024 04:16:50 +0800 Subject: [PATCH] meta_proxy_generator update --- olah/proxy/meta.py | 49 ++++++++++++++++------------------------ olah/server.py | 22 ++++++++++++++++-- olah/utils/repo_utils.py | 5 ++-- requirements.txt | 1 + 4 files changed, 44 insertions(+), 33 deletions(-) diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py index 9116521..1fac6ab 100644 --- a/olah/proxy/meta.py +++ b/olah/proxy/meta.py @@ -75,37 +75,28 @@ async def meta_proxy_generator( allow_cache: bool, save_path: str, ): - try: - temp_file_path = None - async with httpx.AsyncClient(follow_redirects=True) as client: - with tempfile.NamedTemporaryFile(mode="wb", delete=True) as temp_file: - temp_file_path = temp_file.name - if not allow_cache: - write_temp_file = False - else: - write_temp_file = True - async with client.stream( - method="GET", - url=meta_url, - headers=headers, - timeout=WORKER_API_TIMEOUT, - ) as response: - response_headers = response.headers - yield response_headers + async with httpx.AsyncClient(follow_redirects=True) as client: + content_chunks = [] + async with client.stream( + method="GET", + url=meta_url, + headers=headers, + timeout=WORKER_API_TIMEOUT, + ) as response: + response_headers = response.headers + yield response_headers - async for raw_chunk in response.aiter_raw(): - if not raw_chunk: - continue - if write_temp_file: - temp_file.write(raw_chunk) - yield raw_chunk - if temp_file_path is not None: - temp_file.flush() - shutil.copyfile(temp_file_path, save_path) - finally: - if temp_file_path is not None and os.path.exists(temp_file_path): - os.remove(temp_file_path) + async for raw_chunk in response.aiter_raw(): + if not raw_chunk: + continue + content_chunks.append(raw_chunk) + yield raw_chunk + content = bytearray() + for chunk in content_chunks: + content += chunk + with open(save_path, "wb") as f: + f.write(bytes(content)) async def meta_generator( app: FastAPI, diff --git a/olah/server.py b/olah/server.py index 397039b..f5db66c 100644 --- a/olah/server.py +++ b/olah/server.py @@ -25,7 +25,25 @@ import git import httpx -from pydantic import BaseSettings + +BASE_SETTINGS = False +if not BASE_SETTINGS: + try: + from pydantic import BaseSettings + BASE_SETTINGS = True + except ImportError: + BASE_SETTINGS = False + +if not BASE_SETTINGS: + try: + from pydantic_settings import BaseSettings + BASE_SETTINGS = True + except ImportError: + BASE_SETTINGS = False + +if not BASE_SETTINGS: + raise Exception("Cannot import BaseSettings from pydantic or pydantic-settings") + from olah.configs import OlahConfig from olah.errors import error_repo_not_found, error_page_not_found from olah.mirror.repos import LocalMirrorRepo @@ -61,7 +79,7 @@ async def check_connection(url: str) -> bool: return False -@repeat_every(seconds=60) +@repeat_every(seconds=60*5) async def check_hf_connection() -> None: if app.app_settings.config.offline: return diff --git a/olah/utils/repo_utils.py b/olah/utils/repo_utils.py index eaec020..11711af 100644 --- a/olah/utils/repo_utils.py +++ b/olah/utils/repo_utils.py @@ -311,5 +311,6 @@ async def check_commit_hf( if authorization is not None: headers["authorization"] = authorization async with httpx.AsyncClient() as client: - response = await client.get(url, headers=headers, timeout=WORKER_API_TIMEOUT) - return response.status_code in [200, 307] + response = await client.request(method="HEAD", url=url, headers=headers, timeout=WORKER_API_TIMEOUT) + status_code = response.status_code + return status_code in [200, 307] diff --git a/requirements.txt b/requirements.txt index 6e87169..d1c261c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ fastapi-utils==0.7.0 GitPython==3.1.43 httpx==0.27.0 pydantic==2.8.2 +pydantic-setting==2.2.1 toml==0.10.2 huggingface_hub==0.23.4 pytest==8.2.2