From 0336380353087567df8afb12defa9f6716def328 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Tue, 9 Jul 2024 04:08:13 +0800 Subject: [PATCH] bug fix --- README_zh.md | 2 +- olah/configs.py | 9 +- olah/files.py | 211 +++++++++++++++++-------------------- olah/lfs.py | 10 +- olah/meta.py | 2 +- olah/server.py | 2 +- olah/{utls.py => utils.py} | 8 ++ pyproject.toml | 2 +- tests/simple_test.py | 2 - 9 files changed, 123 insertions(+), 125 deletions(-) rename olah/{utls.py => utils.py} (95%) diff --git a/README_zh.md b/README_zh.md index 1267b82..ae6abc1 100644 --- a/README_zh.md +++ b/README_zh.md @@ -40,7 +40,7 @@ pip install -e . python -m olah.server ``` -然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090)。 +然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090)。 Linux: ```bash export HF_ENDPOINT=http://localhost:8090 diff --git a/olah/configs.py b/olah/configs.py index 3427288..23bbd62 100644 --- a/olah/configs.py +++ b/olah/configs.py @@ -93,8 +93,8 @@ def __init__(self, path: Optional[str] = None) -> None: self.repos_path = "./repos" self.hf_url = "https://huggingface.co" self.hf_lfs_url = "https://cdn-lfs.huggingface.co" - self.mirror_url = "http://localhost:8090" - self.mirror_lfs_url = "http://localhost:8090" + self.mirror_url = f"http://{self.host}:{self.port}" + self.mirror_lfs_url = f"http://{self.host}:{self.port}" # accessibility self.offline = False @@ -103,6 +103,11 @@ def __init__(self, path: Optional[str] = None) -> None: if path is not None: self.read_toml(path) + + # refresh urls + self.mirror_url = f"http://{self.host}:{self.port}" + self.mirror_lfs_url = f"http://{self.host}:{self.port}" + def empty_str(self, s: str) -> Optional[str]: if s == "": diff --git a/olah/files.py b/olah/files.py index 385b746..4629693 100644 --- a/olah/files.py +++ b/olah/files.py @@ -9,51 +9,79 @@ from starlette.datastructures import URL from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT -from olah.utls import check_cache_rules_hf, get_org_repo - - -async def _file_head_cache_stream(app, save_path: str, request: Request): - with open(save_path, "r", encoding="utf-8") as f: - response_headers = json.loads(f.read()) - if "location" in response_headers: - response_headers["location"] = response_headers["location"].replace( - app.app_settings.hf_url, app.app_settings.mirror_url - ) - yield response_headers - +from olah.utils import check_cache_rules_hf, get_org_repo, make_dirs +FILE_HEADER_TEMPLATE = { + "accept-ranges": "bytes", + "access-control-allow-origin": "*", + "cache-control": "public, max-age=604800, immutable, s-maxage=604800", + # "content-length": None, + # "content-type": "binary/octet-stream", + # "etag": None, + # "last-modified": None, +} + +async def _file_cache_stream(save_path: str, head_path: str, request: Request): + if request.method.lower() == "head": + with open(head_path, "r", encoding="utf-8") as f: + response_headers = json.loads(f.read()) + yield FILE_HEADER_TEMPLATE + elif request.method.lower() == "get": + yield FILE_HEADER_TEMPLATE + else: + raise Exception(f"Invalid Method type {request.method}") + with open(save_path, "rb") as f: + while True: + chunk = f.read(CHUNK_SIZE) + if not chunk: + break + yield chunk -async def _file_head_realtime_stream( - app, - save_path: str, - url: str, - headers, - request: Request, - method="HEAD", - allow_cache=True, +async def _file_realtime_stream( + app, save_path: str, head_path: str, url: str, request: Request, method="GET", allow_cache=True ): - async with httpx.AsyncClient() as client: - async with client.stream( - method=method, - url=url, - headers=headers, - timeout=WORKER_API_TIMEOUT, - ) as response: - response_headers = response.headers - response_headers = {k: v for k, v in response_headers.items()} - if allow_cache: - with open(save_path, "w", encoding="utf-8") as f: - f.write(json.dumps(response_headers, ensure_ascii=False)) - if "location" in response_headers: - response_headers["location"] = response_headers["location"].replace( - app.app_settings.hf_url, app.app_settings.mirror_url - ) - yield response_headers - - async for raw_chunk in response.aiter_raw(): - if not raw_chunk: - continue - yield raw_chunk + request_headers = {k: v for k, v in request.headers.items()} + request_headers.pop("host") + temp_file_path = None + try: + async with httpx.AsyncClient() as client: + with tempfile.NamedTemporaryFile(mode="wb", delete=True) as temp_file: + if not allow_cache or request.method.lower() == "head": + write_temp_file = False + else: + write_temp_file = True + async with client.stream( + method=method, + url=url, + headers=request_headers, + timeout=WORKER_API_TIMEOUT, + ) as response: + response_headers = response.headers + response_headers_dict = {k: v for k, v in response_headers.items()} + if allow_cache: + if request.method.lower() == "head": + with open(head_path, "w", encoding="utf-8") as f: + f.write(json.dumps(response_headers_dict, ensure_ascii=False)) + if "location" in response_headers: + response_headers["location"] = response_headers["location"].replace( + app.app_settings.hf_lfs_url, app.app_settings.mirror_lfs_url + ) + yield response_headers + async for raw_chunk in response.aiter_raw(): + if not raw_chunk: + continue + if write_temp_file: + temp_file.write(raw_chunk) + yield raw_chunk + if not allow_cache: + temp_file_path = None + else: + temp_file_path = temp_file.name + if temp_file_path is not None: + shutil.copyfile(temp_file_path, save_path) + finally: + if temp_file_path is not None and os.path.exists(temp_file_path): + os.remove(temp_file_path) async def file_head_generator( app, @@ -64,84 +92,39 @@ async def file_head_generator( file_path: str, request: Request, ): - headers = {k: v for k, v in request.headers.items()} - headers.pop("host") - # save repos_path = app.app_settings.repos_path + head_path = os.path.join( + repos_path, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" + ) save_path = os.path.join( - repos_path, f"heads/{repo_type}/{org}/{repo}/resolve_head/{commit}/{file_path}" + repos_path, f"files/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" ) - save_dir = os.path.dirname(save_path) - if not os.path.exists(save_dir): - os.makedirs(save_dir, exist_ok=True) + make_dirs(head_path) + make_dirs(save_path) - use_cache = os.path.exists(save_path) + use_cache = os.path.exists(head_path) and os.path.exists(save_path) allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) # proxy if use_cache: - return _file_head_cache_stream(app=app, save_path=save_path, request=request) + return _file_cache_stream(save_path=save_path, head_path=head_path, request=request) else: if repo_type == "models": url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}" else: url = f"{app.app_settings.hf_url}/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" - return _file_head_realtime_stream( + return _file_realtime_stream( app=app, save_path=save_path, + head_path=head_path, url=url, - headers=headers, request=request, method="HEAD", allow_cache=allow_cache, ) -async def _file_cache_stream(save_path: str, request: Request): - yield request.headers - with open(save_path, "rb") as f: - while True: - chunk = f.read(CHUNK_SIZE) - if not chunk: - break - yield chunk - - -async def _file_realtime_stream( - save_path: str, url: str, headers, request: Request, method="GET", allow_cache=True -): - temp_file_path = None - try: - async with httpx.AsyncClient() as client: - with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file: - if not allow_cache: - temp_file = open(os.devnull, "wb") - async with client.stream( - method=method, - url=url, - headers=headers, - timeout=WORKER_API_TIMEOUT, - ) as response: - response_headers = response.headers - yield response_headers - - async for raw_chunk in response.aiter_raw(): - if not raw_chunk: - continue - temp_file.write(raw_chunk) - yield raw_chunk - if not allow_cache: - temp_file_path = None - else: - temp_file_path = temp_file.name - if temp_file_path is not None: - shutil.copyfile(temp_file_path, save_path) - finally: - if temp_file_path is not None: - os.remove(temp_file_path) - - async def file_get_generator( app, repo_type: Literal["models", "datasets"], @@ -151,32 +134,33 @@ async def file_get_generator( file_path: str, request: Request, ): - headers = {k: v for k, v in request.headers.items()} - headers.pop("host") # save repos_path = app.app_settings.repos_path + head_path = os.path.join( + repos_path, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" + ) save_path = os.path.join( repos_path, f"files/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" ) - save_dir = os.path.dirname(save_path) - if not os.path.exists(save_dir): - os.makedirs(save_dir, exist_ok=True) + make_dirs(head_path) + make_dirs(save_path) - use_cache = os.path.exists(save_path) + use_cache = os.path.exists(head_path) and os.path.exists(save_path) allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) # proxy if use_cache: - return _file_cache_stream(save_path=save_path, request=request) + return _file_cache_stream(save_path=save_path, head_path=head_path, request=request) else: if repo_type == "models": url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}" else: url = f"{app.app_settings.hf_url}/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" return _file_realtime_stream( + app=app, save_path=save_path, + head_path=head_path, url=url, - headers=headers, request=request, method="GET", allow_cache=allow_cache, @@ -197,14 +181,16 @@ async def cdn_file_get_generator( org_repo = get_org_repo(org, repo) # save repos_path = app.app_settings.repos_path + head_path = os.path.join( + repos_path, f"heads/{repo_type}/{org}/{repo}/cdn/{file_hash}" + ) save_path = os.path.join( - repos_path, f"files/{repo_type}/cdn/{org}/{repo}/{file_hash}" + repos_path, f"files/{repo_type}/{org}/{repo}/cdn/{file_hash}" ) - save_dir = os.path.dirname(save_path) - if not os.path.exists(save_dir): - os.makedirs(save_dir, exist_ok=True) + make_dirs(head_path) + make_dirs(save_path) - use_cache = os.path.exists(save_path) + use_cache = os.path.exists(head_path) and os.path.exists(save_path) allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) # proxy @@ -212,12 +198,13 @@ async def cdn_file_get_generator( return _file_cache_stream(save_path=save_path, request=request) else: redirected_url = str(request.url) - redirected_url = redirected_url.replace(app.app_settings.hf_lfs_url, app.app_settings.mirror_lfs_url) + redirected_url = redirected_url.replace(app.app_settings.mirror_lfs_url, app.app_settings.hf_lfs_url) return _file_realtime_stream( + app=app, save_path=save_path, + head_path=head_path, url=str(redirected_url), - headers=headers, request=request, method="GET", allow_cache=allow_cache, diff --git a/olah/lfs.py b/olah/lfs.py index a0c3930..4c521fd 100644 --- a/olah/lfs.py +++ b/olah/lfs.py @@ -8,6 +8,7 @@ import pytz from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT +from olah.utils import make_dirs async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request): @@ -17,8 +18,7 @@ async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, r # save repos_path = app.app_settings.repos_path save_dir = os.path.join(repos_path, f"lfs/{repo_type}/{save_path}") - if not os.path.exists(save_dir): - os.makedirs(save_dir, exist_ok=True) + make_dirs(save_dir) # lfs meta lfs_meta_path = os.path.join(save_dir, "meta.json") @@ -109,7 +109,7 @@ async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, r try: temp_file_path = None async with httpx.AsyncClient() as client: - with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file: + with tempfile.NamedTemporaryFile(mode="wb", delete=True) as temp_file: headers["range"] = f"bytes={block_start_pos}-{block_end_pos - 1}" async with client.stream( method="GET", url=lfs_url, @@ -145,8 +145,8 @@ async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, r if raw_bytes >= block_end_pos - block_start_pos: break temp_file_path = temp_file.name - shutil.copyfile(temp_file_path, save_path) + shutil.copyfile(temp_file_path, save_path) finally: - if temp_file_path is not None: + if temp_file_path is not None and os.path.exists(temp_file_path): os.remove(temp_file_path) cur_block += 1 diff --git a/olah/meta.py b/olah/meta.py index efe9e00..ea98538 100644 --- a/olah/meta.py +++ b/olah/meta.py @@ -11,7 +11,7 @@ from olah.configs import OlahConfig from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT -from olah.utls import check_cache_rules_hf +from olah.utils import check_cache_rules_hf async def meta_cache_generator(app: FastAPI, save_path: str): yield {} diff --git a/olah/server.py b/olah/server.py index 32bdd5d..3525521 100644 --- a/olah/server.py +++ b/olah/server.py @@ -13,7 +13,7 @@ from olah.files import cdn_file_get_generator, file_get_generator, file_head_generator from olah.lfs import lfs_get_generator from olah.meta import meta_generator -from olah.utls import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf +from olah.utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf app = FastAPI(debug=False) diff --git a/olah/utls.py b/olah/utils.py similarity index 95% rename from olah/utls.py rename to olah/utils.py index 9613618..0ded8ee 100644 --- a/olah/utls.py +++ b/olah/utils.py @@ -98,3 +98,11 @@ async def check_cache_rules_hf(app, repo_type: Optional[Literal["models", "datas config: OlahConfig = app.app_settings.config org_repo = get_org_repo(org, repo) return config.cache.allow(f"{org_repo}") + +def make_dirs(path: str): + if os.path.isdir(path): + save_dir = path + else: + save_dir = os.path.dirname(path) + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 227a823..a095a33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "olah" -version = "0.0.4" +version = "0.0.5" description = "Self-hosted lightweight huggingface mirror." readme = "README.md" requires-python = ">=3.8" diff --git a/tests/simple_test.py b/tests/simple_test.py index 9d3cab1..3195405 100644 --- a/tests/simple_test.py +++ b/tests/simple_test.py @@ -11,7 +11,6 @@ def test_dataset(): snapshot_download(repo_id='Nerfgun3/bad_prompt', repo_type='dataset', local_dir='./dataset_dir', max_workers=8) - # 终止子进程 process.terminate() def test_model(): @@ -21,5 +20,4 @@ def test_model(): snapshot_download(repo_id='prajjwal1/bert-tiny', repo_type='model', local_dir='./model_dir', max_workers=8) - # 终止子进程 process.terminate() \ No newline at end of file