diff --git a/README.md b/README.md
index db4dd7f..bdc2f54 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,11 @@
 Olah is self-hosted lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian.
 Other languages: [中文](README_zh.md)
+
+## Advantages of Olah
+Olah can cache files in chunks while users download them. Subsequent downloads are then served directly from the cache, greatly improving download speed and saving bandwidth.
+Additionally, Olah offers a range of cache-control policies: through a configuration file, administrators can set which repositories are accessible and which are cached.
+
 ## Features
 * Huggingface Data Cache
 * Models mirror
diff --git a/README_zh.md b/README_zh.md
index 288078a..4ae8e27 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -1,5 +1,9 @@
 Olah是一种自托管的轻量级HuggingFace镜像服务。`Olah`在丘丘人语中意味着`你好`。
+## Olah的优势
+Olah能够在用户下载的同时分块缓存文件。当第二次下载时,直接从缓存中读取,极大地提升了下载速度并节约了流量。
+同时Olah提供了丰富的缓存控制策略,管理员可以通过配置文件设置哪些仓库可以访问,哪些仓库可以缓存。
+
 ## 特性
 * 数据缓存,减少下载流量
 * 模型镜像
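The chunk-level caching described in the two README sections above can be pictured with a small sketch. This is an illustration only, not Olah's actual code; `BLOCK_SIZE`, `fetch_block`, and the `block-{i}.bin` layout are hypothetical stand-ins:

```python
import os

BLOCK_SIZE = 8 * 1024 * 1024  # hypothetical chunk size

def read_block(cache_dir: str, index: int, fetch_block) -> bytes:
    """Serve one chunk, fetching from upstream and caching it on a miss."""
    block_path = os.path.join(cache_dir, f"block-{index}.bin")
    if os.path.exists(block_path):
        # Cache hit: the second download never touches upstream.
        with open(block_path, "rb") as f:
            return f.read()
    # Cache miss: fetch the byte range for this chunk, then persist it
    # so the next request for the same chunk is served locally.
    data = fetch_block(index * BLOCK_SIZE, (index + 1) * BLOCK_SIZE)
    with open(block_path, "wb") as f:
        f.write(data)
    return data
```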
diff --git a/olah/files.py b/olah/files.py
index 3201d7d..451f6fe 100644
--- a/olah/files.py
+++ b/olah/files.py
@@ -48,6 +48,7 @@ async def _get_redirected_url(client: httpx.AsyncClient, method: str, url: str,
             redirect_loc = response.headers["location"]
         else:
             redirect_loc = url
+    return redirect_loc
 
 
 async def _file_full_header(
@@ -67,16 +68,16 @@ async def _file_full_header(
     else:
         if "range" in headers:
             headers.pop("range")
-    async with client.stream(
+    response = await client.request(
         method=method,
         url=url,
         headers=headers,
         timeout=WORKER_API_TIMEOUT,
-    ) as response:
-        response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
-        if allow_cache and method.lower() == "head":
-            with open(head_path, "w", encoding="utf-8") as f:
-                f.write(json.dumps(response_headers_dict, ensure_ascii=False))
+    )
+    response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
+    if allow_cache and method.lower() == "head":
+        with open(head_path, "w", encoding="utf-8") as f:
+            f.write(json.dumps(response_headers_dict, ensure_ascii=False))
 
     new_headers = {}
     new_headers["content-type"] = response_headers_dict["content-type"]
@@ -108,12 +109,15 @@ async def _get_file_block_from_remote(client: httpx.AsyncClient, remote_info: Re
         headers=remote_info.headers,
         timeout=WORKER_API_TIMEOUT,
     ) as response:
+        response_content_length = int(response.headers['content-length'])
         async for raw_chunk in response.aiter_raw():
             if not raw_chunk:
                 continue
            raw_block += raw_chunk
     # print(remote_info.url, remote_info.method, remote_info.headers)
     # print(block_start_pos, block_end_pos)
+    if len(raw_block) != response_content_length:
+        raise Exception(f"The content of the response is incomplete. Expected-{response_content_length}. Accepted-{len(raw_block)}")
     if len(raw_block) != (block_end_pos - block_start_pos):
         raise Exception(f"The block is incomplete. Expected-{block_end_pos - block_start_pos}. Accepted-{len(raw_block)}")
     if len(raw_block) < cache_file._get_block_size():
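A note on the integrity check added in the hunk above: the upstream block request uses an HTTP `Range` header with an inclusive end index, so a complete ranged response must carry exactly `block_end_pos - block_start_pos` bytes, which is also the `content-length` a well-behaved server reports. A toy illustration (arbitrary numbers, not project code):

```python
block_start_pos, block_end_pos = 16, 32                        # example block boundaries
range_header = f"bytes={block_start_pos}-{block_end_pos - 1}"  # "bytes=16-31", end inclusive
expected = block_end_pos - block_start_pos                     # 16 bytes expected

raw_block = b"x" * 16              # stand-in for bytes accumulated from aiter_raw()
assert len(raw_block) == expected  # mirrors the new completeness checks
```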
@@ -215,15 +219,40 @@ async def _file_realtime_stream(
     request_headers = {k: v for k, v in request.headers.items()}
     request_headers.pop("host")
 
+    if urlparse(url).netloc == app.app_settings.config.mirror_netloc:
+        hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url))
+    else:
+        hf_url = url
+
+    async with httpx.AsyncClient() as client:
+        response = await client.request(
+            method="HEAD",
+            url=hf_url,
+            headers=request_headers,
+            timeout=WORKER_API_TIMEOUT,
+        )
+
+    if response.status_code >= 300 and response.status_code <= 399:
+        from_url = urlparse(url)
+        parsed_url = urlparse(response.headers["location"])
+        new_headers = {k.lower(): v for k, v in response.headers.items()}
+        if len(parsed_url.netloc) != 0:
+            new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"]))
+            new_headers["location"] = new_loc
+
+        yield response.status_code
+        yield new_headers
+        yield response.content
+        return
+
     async with httpx.AsyncClient() as client:
-        redirect_loc = await _get_redirected_url(client, method, url, request_headers)
+        # redirect_loc = await _get_redirected_url(client, method, url, request_headers)
         head_info = await _file_full_header(
             app=app,
             save_path=save_path,
             head_path=head_path,
             client=client,
-            method=method,
-            url=redirect_loc,
+            method="HEAD",
+            url=hf_url,
             headers=request_headers,
             allow_cache=allow_cache,
         )
@@ -234,6 +263,7 @@ async def _file_realtime_stream(
     response_headers["content-length"] = str(end_pos - start_pos)
     if commit is not None:
         response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
+    yield 200
     yield response_headers
     if method.lower() == "get":
         async for each_chunk in _file_chunk_get(
@@ -242,7 +272,7 @@ async def _file_realtime_stream(
             head_path=head_path,
             client=client,
             method=method,
-            url=redirect_loc,
+            url=hf_url,
             headers=request_headers,
             allow_cache=allow_cache,
             file_size=file_size,
@@ -255,7 +285,7 @@ async def _file_realtime_stream(
             head_path=head_path,
             client=client,
             method=method,
-            url=redirect_loc,
+            url=hf_url,
             headers=request_headers,
             allow_cache=allow_cache,
             file_size=0,
@@ -264,48 +294,6 @@ async def _file_realtime_stream(
     else:
         raise Exception(f"Unsupported method: {method}")
 
-
-async def file_head_generator(
-    app,
-    repo_type: Literal["models", "datasets"],
-    org: str,
-    repo: str,
-    commit: str,
-    file_path: str,
-    request: Request,
-):
-    org_repo = get_org_repo(org, repo)
-    # save
-    repos_path = app.app_settings.repos_path
-    head_path = os.path.join(
-        repos_path, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
-    )
-    save_path = os.path.join(
-        repos_path, f"files/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
-    )
-    make_dirs(head_path)
-    make_dirs(save_path)
-
-    # use_cache = os.path.exists(head_path) and os.path.exists(save_path)
-    allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
-
-    # proxy
-    if repo_type == "models":
-        url = urljoin(app.app_settings.config.hf_url_base(), f"/{org_repo}/resolve/{commit}/{file_path}")
-    else:
-        url = urljoin(app.app_settings.config.hf_url_base(), f"/{repo_type}/{org_repo}/resolve/{commit}/{file_path}")
-    return _file_realtime_stream(
-        app=app,
-        save_path=save_path,
-        head_path=head_path,
-        url=url,
-        request=request,
-        method="HEAD",
-        allow_cache=allow_cache,
-        commit=commit,
-    )
-
-
 async def file_get_generator(
     app,
     repo_type: Literal["models", "datasets"],
     org: str,
@@ -313,6 +301,7 @@ async def file_get_generator(
     repo: str,
     commit: str,
     file_path: str,
+    method: Literal["HEAD", "GET"],
     request: Request,
 ):
     org_repo = get_org_repo(org, repo)
@@ -341,18 +330,18 @@ async def file_get_generator(
         head_path=head_path,
         url=url,
         request=request,
-        method="GET",
+        method=method,
         allow_cache=allow_cache,
         commit=commit,
     )
 
-
 async def cdn_file_get_generator(
     app,
     repo_type: Literal["models", "datasets"],
     org: str,
     repo: str,
     file_hash: str,
+    method: Literal["HEAD", "GET"],
     request: Request,
 ):
     headers = {k: v for k, v in request.headers.items()}
@@ -374,18 +363,18 @@ async def cdn_file_get_generator(
     allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
 
     # proxy
-    request_url = urlparse(request.url)
-    if request_url.netloc == app.app_settings.config.hf_lfs_netloc:
-        redirected_url = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(request_url))
-    else:
-        redirected_url = urljoin(app.app_settings.config.mirror_url_base(), get_url_tail(request_url))
+    # request_url = urlparse(str(request.url))
+    # if request_url.netloc == app.app_settings.config.hf_lfs_netloc:
+    #     redirected_url = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(request_url))
+    # else:
+    #     redirected_url = urljoin(app.app_settings.config.mirror_url_base(), get_url_tail(request_url))
     return _file_realtime_stream(
         app=app,
         save_path=save_path,
         head_path=head_path,
-        url=redirected_url,
+        url=str(request.url),
         request=request,
-        method="GET",
+        method=method,
         allow_cache=allow_cache,
     )
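After this change, `_file_realtime_stream` yields the HTTP status code first, then the response headers, then the body chunks; every route handler therefore pulls two items off the generator before handing the rest to `StreamingResponse`, as the `server.py` hunks below show. A minimal consumer sketch of that three-phase protocol (assuming `generator` follows it):

```python
from fastapi.responses import StreamingResponse

async def respond(generator) -> StreamingResponse:
    status_code = await generator.__anext__()  # 1) integer status code (e.g. 200 or a redirect)
    headers = await generator.__anext__()      # 2) dict of response headers
    # 3) whatever remains is the body, streamed chunk by chunk
    return StreamingResponse(generator, headers=headers, status_code=status_code)
```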
diff --git a/olah/lfs.py b/olah/lfs.py
index 0137bca..46d4131 100644
--- a/olah/lfs.py
+++ b/olah/lfs.py
@@ -1,157 +1,70 @@
-"""
-废弃方法
-"""
-
 import datetime
 import json
 import os
-import tempfile
-import shutil
+from typing import Literal
 
 from fastapi import FastAPI, Header, Request
-import httpx
-import pytz
 
-from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT
+from olah.files import _file_realtime_stream
 from olah.utils.file_utils import make_dirs
+from olah.utils.url_utils import check_cache_rules_hf, get_org_repo
 
-async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request):
-    headers = {k: v for k, v in request.headers.items()}
-    headers.pop("host")
-
+async def lfs_head_generator(
+    app,
+    dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request
+):
     # save
     repos_path = app.app_settings.repos_path
-    save_dir = os.path.join(repos_path, f"lfs/{repo_type}/{save_path}")
-    make_dirs(save_dir)
-
-    # lfs meta
-    lfs_meta_path = os.path.join(save_dir, "meta.json")
-    if os.path.exists(lfs_meta_path):
-        with open(lfs_meta_path, "r", encoding="utf-8") as f:
-            lfs_meta = json.loads(f.read())
-    else:
-        async with httpx.AsyncClient() as client:
-            async with client.stream(
-                method="GET", url=lfs_url,
-                headers={"range": "-"},
-                params=request.query_params,
-                timeout=WORKER_API_TIMEOUT,
-            ) as response:
-                file_size = response.headers["content-length"]
-                req_headers = {k: v for k, v in response.headers.items()}
-        lfs_meta = {
-            "lfs_file_block": LFS_FILE_BLOCK,
-            "file_size": int(file_size),
-            "req_headers": req_headers,
-        }
-        with open(lfs_meta_path, "w", encoding="utf-8") as f:
-            f.write(json.dumps(lfs_meta))
-    # range
-    file_size = lfs_meta["file_size"]
-    if "range" in headers:
-        file_range = headers['range']  # 'bytes=1887436800-'
-        if file_range.startswith("bytes="):
-            file_range = file_range[6:]
-        start_pos, end_pos = file_range.split("-")
-        if len(start_pos) != 0:
-            start_pos = int(start_pos)
-        else:
-            start_pos = 0
-        if len(end_pos) != 0:
-            end_pos = int(end_pos)
-        else:
-            end_pos = file_size
-    else:
-        start_pos = 0
-        end_pos = file_size
-
-    # block
-    lfs_file_block = lfs_meta["lfs_file_block"]
-    start_block = start_pos // lfs_file_block
-    end_block = end_pos // lfs_file_block
-
-    new_headers = lfs_meta["req_headers"]
-    new_headers["date"] = datetime.datetime.now(pytz.timezone('GMT')).strftime('%a, %d %b %Y %H:%M:%S %Z')
-    new_headers["content-length"] = str(end_pos - start_pos)
-
-    yield new_headers
-    cur_pos = start_pos
-    cur_block = start_block
+    head_path = os.path.join(
+        repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}"
+    )
+    save_path = os.path.join(
+        repos_path, f"lfs/files/{dir1}/{dir2}/{hash_repo}/{hash_file}"
+    )
+    make_dirs(head_path)
+    make_dirs(save_path)
 
-    while cur_block <= end_block:
-        save_path = os.path.join(save_dir, f"block-{cur_block}.bin")
-        use_cache = os.path.exists(save_path)
-        block_start_pos = cur_block * lfs_file_block
-        block_end_pos = min((cur_block + 1) * lfs_file_block, file_size)
+    # use_cache = os.path.exists(head_path) and os.path.exists(save_path)
+    allow_cache = True
 
-        # proxy
-        if use_cache:
-            with open(save_path, "rb") as f:
-                sub_chunk_start_pos = block_start_pos
-                while True:
-                    raw_chunk = f.read(CHUNK_SIZE)
-                    if not raw_chunk:
-                        break
+    # proxy
+    return _file_realtime_stream(
+        app=app,
+        save_path=save_path,
+        head_path=head_path,
+        url=str(request.url),
+        request=request,
+        method="HEAD",
+        allow_cache=allow_cache,
+        commit=None,
+    )
 
-                    chunk = raw_chunk
-                    if cur_pos >= sub_chunk_start_pos and cur_pos < sub_chunk_start_pos + len(raw_chunk):
-                        chunk = chunk[cur_pos - sub_chunk_start_pos:]
-                    elif cur_pos >= sub_chunk_start_pos + len(raw_chunk):
-                        chunk = bytes([])
-                    elif cur_pos < sub_chunk_start_pos:
-                        pass
-
-                    if cur_pos + len(chunk) > block_end_pos:
-                        chunk = chunk[:-(cur_pos + len(chunk) - block_end_pos)]
-                        print("Warning: This maybe a bug, sending chunk is larger than content length.")
-
-                    if len(chunk) != 0:
-                        yield chunk
-                    cur_pos += len(chunk)
-                    sub_chunk_start_pos += len(raw_chunk)
-        else:
-            try:
-                temp_file_path = None
-                async with httpx.AsyncClient() as client:
-                    with tempfile.NamedTemporaryFile(mode="wb", delete=True) as temp_file:
-                        headers["range"] = f"bytes={block_start_pos}-{block_end_pos - 1}"
-                        async with client.stream(
-                            method="GET", url=lfs_url,
-                            headers=headers,
-                            params=request.query_params,
-                            timeout=WORKER_API_TIMEOUT,
-                        ) as response:
-                            raw_bytes = 0
-                            sub_chunk_start_pos = block_start_pos
-                            async for raw_chunk in response.aiter_raw():
-                                if not raw_chunk:
-                                    continue
-                                temp_file.write(raw_chunk)
-
-                                stream_chunk = raw_chunk
-
-                                if cur_pos > sub_chunk_start_pos and cur_pos < sub_chunk_start_pos + len(raw_chunk):
-                                    stream_chunk = stream_chunk[cur_pos - sub_chunk_start_pos:]
-                                elif cur_pos >= sub_chunk_start_pos + len(raw_chunk):
-                                    stream_chunk = bytes([])
-                                elif cur_pos < sub_chunk_start_pos:
-                                    pass
+async def lfs_get_generator(
+    app,
+    dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request
+):
+    # save
+    repos_path = app.app_settings.repos_path
+    head_path = os.path.join(
+        repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}"
+    )
+    save_path = os.path.join(
+        repos_path, f"lfs/files/{dir1}/{dir2}/{hash_repo}/{hash_file}"
+    )
+    make_dirs(head_path)
+    make_dirs(save_path)
 
-                                if cur_pos + len(stream_chunk) > block_end_pos:
-                                    stream_chunk = stream_chunk[:-(cur_pos + len(stream_chunk) - block_end_pos)]
-                                    print("Warning: This maybe a bug, sending chunk is larger than content length.")
+    # use_cache = os.path.exists(head_path) and os.path.exists(save_path)
+    allow_cache = True
 
-                                if len(stream_chunk) != 0:
-                                    yield stream_chunk
-                                cur_pos += len(stream_chunk)
-                                raw_bytes += len(raw_chunk)
-                                sub_chunk_start_pos += len(raw_chunk)
-                                if raw_bytes >= block_end_pos - block_start_pos:
-                                    break
-                        temp_file_path = temp_file.name
-                        temp_file.flush()
-                        shutil.copyfile(temp_file_path, save_path)
-            finally:
-                if temp_file_path is not None and os.path.exists(temp_file_path):
-                    os.remove(temp_file_path)
-        cur_block += 1
+    # proxy
+    return _file_realtime_stream(
+        app=app,
+        save_path=save_path,
+        head_path=head_path,
+        url=str(request.url),
+        request=request,
+        method="GET",
+        allow_cache=allow_cache,
+        commit=None,
+    )
\ No newline at end of file
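The rewritten LFS generators drop the old `meta.json` plus `block-{n}.bin` scheme and reuse `_file_realtime_stream`'s head/file cache layout. For a CDN-style request such as `/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}`, the cache paths work out as below (a sketch with made-up hash values; `repos_path` is the configured cache root):

```python
import os

repos_path = "./repos"                         # example cache root
dir1, dir2 = "ab", "cd"                        # hypothetical two-character prefixes
hash_repo, hash_file = "deadbeef", "cafef00d"  # hypothetical hashes

head_path = os.path.join(repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}")
save_path = os.path.join(repos_path, f"lfs/files/{dir1}/{dir2}/{hash_repo}/{hash_file}")
# -> repos/lfs/heads/ab/cd/deadbeef/cafef00d and repos/lfs/files/ab/cd/deadbeef/cafef00d
```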
diff --git a/olah/server.py b/olah/server.py
index 7fa249a..665e8fa 100644
--- a/olah/server.py
+++ b/olah/server.py
@@ -11,8 +11,8 @@
 import httpx
 from pydantic import BaseSettings
 
 from olah.configs import OlahConfig
-from olah.files import cdn_file_get_generator, file_get_generator, file_head_generator
-from olah.lfs import lfs_get_generator
+from olah.files import cdn_file_get_generator, file_get_generator
+from olah.lfs import lfs_get_generator, lfs_head_generator
 from olah.meta import meta_generator
 from olah.utils.url_utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo
 
@@ -70,38 +70,40 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request:
 # ======================
 # File Head Hooks
 # ======================
 @app.head("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}")
-async def file_head_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str):
+async def file_head3(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str):
     if not await check_proxy_rules_hf(app, repo_type, org, repo):
         return Response(content="This repository is forbidden by the mirror. ", status_code=403)
     if not await check_commit_hf(app, repo_type, org, repo, commit):
         return Response(content="This repository is not accessible. ", status_code=404)
     commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-    generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request)
+    generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="HEAD", request=request)
+    status_code = await generator.__anext__()
     headers = await generator.__anext__()
-    return StreamingResponse(generator, headers=headers)
+    return StreamingResponse(generator, headers=headers, status_code=status_code)
 
-@app.head("/{org_or_repo_type}/{repo}/resolve/{commit}/{file_path:path}")
-async def file_head_proxy(org_or_repo_type: str, repo: str, commit: str, file_path: str, request: Request):
+@app.head("/{org_or_repo_type}/{repo_name}/resolve/{commit}/{file_path:path}")
+async def file_head2(org_or_repo_type: str, repo_name: str, commit: str, file_path: str, request: Request):
     if org_or_repo_type in ["models", "datasets", "spaces"]:
         repo_type: str = org_or_repo_type
-        org, repo = parse_org_repo(repo)
+        org, repo = parse_org_repo(repo_name)
         if org is None and repo is None:
             return Response(content="This repository is not accessible.", status_code=404)
     else:
         repo_type: str = "models"
-        org, repo = org_or_repo_type, repo
+        org, repo = org_or_repo_type, repo_name
 
     if not await check_proxy_rules_hf(app, repo_type, org, repo):
         return Response(content="This repository is forbidden by the mirror. ", status_code=403)
     if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit):
         return Response(content="This repository is not accessible. ", status_code=404)
     commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-    generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request)
+    generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="HEAD", request=request)
+    status_code = await generator.__anext__()
     headers = await generator.__anext__()
-    return StreamingResponse(generator, headers=headers)
+    return StreamingResponse(generator, headers=headers, status_code=status_code)
 
 @app.head("/{org_repo}/resolve/{commit}/{file_path:path}")
-async def file_head_proxy_default_type(org_repo: str, commit: str, file_path: str, request: Request):
+async def file_head(org_repo: str, commit: str, file_path: str, request: Request):
     repo_type: str = "models"
     org, repo = parse_org_repo(org_repo)
     if org is None and repo is None:
@@ -113,46 +115,64 @@ async def file_head_proxy_default_type(org_repo: str, commit: str, file_path: st
         return Response(content="This repository is not accessible. ", status_code=404)
     commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-    generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request)
+    generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="HEAD", request=request)
+    status_code = await generator.__anext__()
     headers = await generator.__anext__()
-    return StreamingResponse(generator, headers=headers)
+    return StreamingResponse(generator, headers=headers, status_code=status_code)
+
+@app.head("/{org_repo}/{hash_file}")
+@app.head("/{repo_type}/{org_repo}/{hash_file}")
+async def cdn_file_head(org_repo: str, hash_file: str, request: Request, repo_type: str = "models"):
+    org, repo = parse_org_repo(org_repo)
+    if org is None and repo is None:
+        return Response(content="This repository is not accessible.", status_code=404)
+
+    if not await check_proxy_rules_hf(app, repo_type, org, repo):
+        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
+
+    generator = await cdn_file_get_generator(app, repo_type, org, repo, hash_file, method="HEAD", request=request)
+    status_code = await generator.__anext__()
+    headers = await generator.__anext__()
+    return StreamingResponse(generator, headers=headers, status_code=status_code)
", status_code=403) if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit): return Response(content="This repository is not accessible. ", status_code=404) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) - generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request) + generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="HEAD", request=request) + status_code = await generator.__anext__() headers = await generator.__anext__() - return StreamingResponse(generator, headers=headers) + return StreamingResponse(generator, headers=headers, status_code=status_code) @app.head("/{org_repo}/resolve/{commit}/{file_path:path}") -async def file_head_proxy_default_type(org_repo: str, commit: str, file_path: str, request: Request): +async def file_head(org_repo: str, commit: str, file_path: str, request: Request): repo_type: str = "models" org, repo = parse_org_repo(org_repo) if org is None and repo is None: @@ -113,46 +115,64 @@ async def file_head_proxy_default_type(org_repo: str, commit: str, file_path: st return Response(content="This repository is not accessible. ", status_code=404) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) - generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request) + generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="HEAD", request=request) + status_code = await generator.__anext__() headers = await generator.__anext__() - return StreamingResponse(generator, headers=headers) + return StreamingResponse(generator, headers=headers, status_code=status_code) + +@app.head("/{org_repo}/{hash_file}") +@app.head("/{repo_type}/{org_repo}/{hash_file}") +async def cdn_file_head(org_repo: str, hash_file: str, request: Request, repo_type: str = "models"): + org, repo = parse_org_repo(org_repo) + if org is None and repo is None: + return Response(content="This repository is not accessible.", status_code=404) + + if not await check_proxy_rules_hf(app, repo_type, org, repo): + return Response(content="This repository is forbidden by the mirror. ", status_code=403) + + generator = await cdn_file_get_generator(app, repo_type, org, repo, hash_file, method="HEAD", request=request) + status_code = await generator.__anext__() + headers = await generator.__anext__() + return StreamingResponse(generator, headers=headers, status_code=status_code) # ====================== # File Hooks # ====================== @app.get("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}") -async def file_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str): +async def file_get3(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str): if not await check_proxy_rules_hf(app, repo_type, org, repo): return Response(content="This repository is forbidden by the mirror. ", status_code=403) if not await check_commit_hf(app, repo_type, org, repo, commit): return Response(content="This repository is not accessible. 
", status_code=404) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) - generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path, request) + generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="GET", request=request) + status_code = await generator.__anext__() headers = await generator.__anext__() - return StreamingResponse(generator, headers=headers) + return StreamingResponse(generator, headers=headers, status_code=status_code) -@app.get("/{org_or_repo_type}/{repo}/resolve/{commit}/{file_path:path}") -async def file_proxy2_default_type(org_or_repo_type: str, repo: str, commit: str, file_path: str, request: Request): +@app.get("/{org_or_repo_type}/{repo_name}/resolve/{commit}/{file_path:path}") +async def file_get2(org_or_repo_type: str, repo_name: str, commit: str, file_path: str, request: Request): if org_or_repo_type in ["models", "datasets", "spaces"]: repo_type: str = org_or_repo_type - org, repo = parse_org_repo(repo) + org, repo = parse_org_repo(repo_name) if org is None and repo is None: return Response(content="This repository is not accessible.", status_code=404) else: repo_type: str = "models" - org, repo = org_or_repo_type, repo + org, repo = org_or_repo_type, repo_name if not await check_proxy_rules_hf(app, repo_type, org, repo): return Response(content="This repository is forbidden by the mirror. ", status_code=403) if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit): return Response(content="This repository is not accessible. ", status_code=404) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) - generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path, request) + generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="GET", request=request) + status_code = await generator.__anext__() headers = await generator.__anext__() - return StreamingResponse(generator, headers=headers) + return StreamingResponse(generator, headers=headers, status_code=status_code) @app.get("/{org_repo}/resolve/{commit}/{file_path:path}") -async def file_proxy_default_type(org_repo: str, commit: str, file_path: str, request: Request): +async def file_get(org_repo: str, commit: str, file_path: str, request: Request): repo_type: str = "models" org, repo = parse_org_repo(org_repo) if org is None and repo is None: @@ -163,12 +183,14 @@ async def file_proxy_default_type(org_repo: str, commit: str, file_path: str, re if not await check_commit_hf(app, repo_type, org, repo, commit): return Response(content="This repository is not accessible. 
", status_code=404) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) - generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path, request) + generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="GET", request=request) + status_code = await generator.__anext__() headers = await generator.__anext__() - return StreamingResponse(generator, headers=headers) + return StreamingResponse(generator, headers=headers, status_code=status_code) +@app.get("/{org_repo}/{hash_file}") @app.get("/{repo_type}/{org_repo}/{hash_file}") -async def cdn_file_proxy(org_repo: str, hash_file: str, request: Request, repo_type: str = "models"): +async def cdn_file_get(org_repo: str, hash_file: str, request: Request, repo_type: str = "models"): org, repo = parse_org_repo(org_repo) if org is None and repo is None: return Response(content="This repository is not accessible.", status_code=404) @@ -176,30 +198,27 @@ async def cdn_file_proxy(org_repo: str, hash_file: str, request: Request, repo_t if not await check_proxy_rules_hf(app, repo_type, org, repo): return Response(content="This repository is forbidden by the mirror. ", status_code=403) - generator = await cdn_file_get_generator(app, repo_type, org, repo, hash_file, request) + generator = await cdn_file_get_generator(app, repo_type, org, repo, hash_file, method="GET", request=request) + status_code = await generator.__anext__() headers = await generator.__anext__() - return StreamingResponse(generator, headers=headers) + return StreamingResponse(generator, headers=headers, status_code=status_code) # ====================== # LFS Hooks # ====================== -@app.get("/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}") -async def lfs_proxy(dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request): - repo_type = "models" - lfs_url = urljoin(app.app_settings.config.hf_lfs_url_base(), f"/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}") - save_path = f"{dir1}/{dir2}/{hash_repo}/{hash_file}" - generator = lfs_get_generator(app, repo_type, lfs_url, save_path, request) +@app.head("/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}") +async def lfs_head(dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request): + generator = await lfs_head_generator(app, dir1, dir2, hash_repo, hash_file, request) + status_code = await generator.__anext__() headers = await generator.__anext__() - return StreamingResponse(generator, headers=headers) + return StreamingResponse(generator, headers=headers, status_code=status_code) -@app.get("/datasets/hendrycks_test/{hash_file}") -async def lfs_proxy(hash_file: str, request: Request): - repo_type = "datasets" - lfs_url = urljoin(app.app_settings.config.hf_lfs_url_base(), f"/datasets/hendrycks_test/{hash_file}") - save_path = f"hendrycks_test/{hash_file}" - generator = lfs_get_generator(app, repo_type, lfs_url, save_path, request) +@app.get("/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}") +async def lfs_get(dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request): + generator = await lfs_get_generator(app, dir1, dir2, hash_repo, hash_file, request) + status_code = await generator.__anext__() headers = await generator.__anext__() - return StreamingResponse(generator, headers=headers) + return StreamingResponse(generator, headers=headers, status_code=status_code) # ====================== # Web Page Hooks diff --git a/olah/utils/url_utils.py b/olah/utils/url_utils.py index e55c643..a072693 100644 --- a/olah/utils/url_utils.py 
diff --git a/requirements.txt b/requirements.txt
index 8d4816b..177b67a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-fastapi
-httpx
+fastapi==0.111.0
+httpx==0.27.0
 pydantic==2.8.2
 toml==0.10.2
 huggingface_hub==0.23.4