From 7933abe40b278e46de46348eb0020fdf1cc915c1 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Fri, 19 Jul 2024 04:44:59 +0800 Subject: [PATCH] check hf timeout --- olah/files.py | 70 +++++++++++++++++++------------------------------- olah/server.py | 51 ++++++++++++++++++++++++++++++++++-- 2 files changed, 75 insertions(+), 46 deletions(-) diff --git a/olah/files.py b/olah/files.py index 63ddb28..e41cddf 100644 --- a/olah/files.py +++ b/olah/files.py @@ -8,7 +8,7 @@ import hashlib import json import os -from typing import Dict, Literal, Optional +from typing import Dict, Literal, Optional, Tuple from fastapi import Request from requests.structures import CaseInsensitiveDict @@ -37,7 +37,7 @@ async def _file_full_header( url: str, headers: Dict[str, str], allow_cache: bool, - ): + ) -> Tuple[int, Dict[str, str], bytes]: if os.path.exists(head_path): with open(head_path, "r", encoding="utf-8") as f: response_headers = json.loads(f.read()) @@ -53,9 +53,21 @@ async def _file_full_header( timeout=WORKER_API_TIMEOUT, ) response_headers_dict = {k.lower(): v for k, v in response.headers.items()} - if allow_cache and method.lower() == "head" and response.status_code == 200: - with open(head_path, "w", encoding="utf-8") as f: - f.write(json.dumps(response_headers_dict, ensure_ascii=False)) + if allow_cache and method.lower() == "head": + if response.status_code == 200: + with open(head_path, "w", encoding="utf-8") as f: + f.write(json.dumps(response_headers_dict, ensure_ascii=False)) + elif response.status_code >= 300 and response.status_code <= 399: + with open(head_path, "w", encoding="utf-8") as f: + f.write(json.dumps(response_headers_dict, ensure_ascii=False)) + from_url = urlparse(url) + parsed_url = urlparse(response.headers["location"]) + if len(parsed_url.netloc) != 0: + new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"])) + response_headers_dict["location"] = new_loc + else: + raise Exception(f"Unexpected HTTP status code {response.status_code}") + return response.status_code, response_headers_dict, response.content else: response_headers_dict = {} @@ -72,7 +84,9 @@ async def _file_full_header( new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "") if "etag" in response_headers_dict: new_headers["etag"] = response_headers_dict["etag"] - return new_headers + if "location" in response_headers_dict: + new_headers["location"] = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response_headers_dict["location"])) + return 200, new_headers, b"" async def _get_file_block_from_cache(cache_file: OlahCache, block_index: int): raw_block = cache_file.read_block(block_index) @@ -208,47 +222,10 @@ async def _file_realtime_stream( hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url)) else: hf_url = url - - # Handle Redirection - if not app.app_settings.config.offline: - async with httpx.AsyncClient() as client: - response = await client.request( - method="HEAD", - url=hf_url, - headers=request_headers, - timeout=WORKER_API_TIMEOUT, - ) - - if response.status_code >= 300 and response.status_code <= 399: - from_url = urlparse(url) - parsed_url = urlparse(response.headers["location"]) - new_headers = {k.lower():v for k, v in response.headers.items()} - if len(parsed_url.netloc) != 0: - new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"])) - new_headers["location"] = new_loc - - if allow_cache: - with open(head_path, "w", encoding="utf-8") as f: - f.write(json.dumps(new_headers, ensure_ascii=False)) - yield response.status_code - yield new_headers - yield response.content - return - else: - if os.path.exists(head_path): - with open(head_path, "r", encoding="utf-8") as f: - head_content = json.loads(f.read()) - - if "location" in head_content: - yield 302 - yield head_content - yield b"" - return - async with httpx.AsyncClient() as client: # redirect_loc = await _get_redirected_url(client, method, url, request_headers) - head_info = await _file_full_header( + status_code, head_info, content = await _file_full_header( app=app, save_path=save_path, head_path=head_path, @@ -258,6 +235,11 @@ async def _file_realtime_stream( headers=request_headers, allow_cache=allow_cache, ) + if status_code != 200: + yield status_code + yield head_info + yield content + return file_size = int(head_info["content-length"]) response_headers = {k: v for k,v in head_info.items()} if "range" in request_headers: diff --git a/olah/server.py b/olah/server.py index 929b0ae..2706377 100644 --- a/olah/server.py +++ b/olah/server.py @@ -5,6 +5,7 @@ # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. +from contextlib import asynccontextmanager import os import argparse import traceback @@ -12,6 +13,7 @@ from urllib.parse import urljoin from fastapi import FastAPI, Header, Request from fastapi.responses import HTMLResponse, StreamingResponse, Response +from fastapi_utils.tasks import repeat_every import httpx from pydantic import BaseSettings from olah.configs import OlahConfig @@ -22,14 +24,59 @@ from olah.utils.logging import build_logger -app = FastAPI(debug=False) +# ====================== +# Utilities +# ====================== +async def check_connection(url: str) -> bool: + try: + async with httpx.AsyncClient() as client: + response = await client.request( + method="HEAD", + url=url, + timeout=10, + ) + if response.status_code != 200: + return False + else: + return True + except httpx.TimeoutException: + return False + + +@repeat_every(seconds=60) +async def check_hf_connection() -> None: + if app.app_settings.config.offline: + return + hf_online_status = await check_connection( + "https://huggingface.co/datasets/Salesforce/wikitext/resolve/main/.gitattributes" + ) + if not hf_online_status: + logger.info( + "Cannot reach Huggingface Official Site. Trying to connect hf-mirror." + ) + hf_mirror_online_status = await check_connection( + "https://hf-mirror.com/datasets/Salesforce/wikitext/resolve/main/.gitattributes" + ) + if not hf_online_status and not hf_mirror_online_status: + logger.error("Failed to reach Huggingface Official Site.") + logger.error("Failed to reach hf-mirror Site.") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await check_hf_connection() + yield + +# ====================== +# Application +# ====================== +app = FastAPI(lifespan=lifespan, debug=False) class AppSettings(BaseSettings): # The address of the model controller. config: OlahConfig = OlahConfig() repos_path: str = "./repos" - # ====================== # API Hooks # ======================