diff --git a/.gitignore b/.gitignore index fbafc9a..87e4e38 100644 --- a/.gitignore +++ b/.gitignore @@ -159,6 +159,7 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +/mirrors_dir/ /model_dir/ /dataset_dir/ /repos/ diff --git a/assets/full_configs.toml b/assets/full_configs.toml index a66fc61..e5acd6a 100644 --- a/assets/full_configs.toml +++ b/assets/full_configs.toml @@ -10,6 +10,7 @@ hf-lfs-netloc = "cdn-lfs.huggingface.co" mirror-scheme = "http" mirror-netloc = "localhost:8090" mirror-lfs-netloc = "localhost:8090" +mirrors-path = ["./mirrors_dir"] [accessibility] offline = false diff --git a/olah/configs.py b/olah/configs.py index 7dee165..e276ccc 100644 --- a/olah/configs.py +++ b/olah/configs.py @@ -92,6 +92,8 @@ def __init__(self, path: Optional[str] = None) -> None: self.mirror_netloc: str = "localhost:8090" self.mirror_lfs_netloc: str = "localhost:8090" + self.mirrors_path: List[str] = [] + # accessibility self.offline = False self.proxy = OlahRuleList.from_list(DEFAULT_PROXY_RULES) @@ -139,6 +141,8 @@ def read_toml(self, path: str) -> None: "mirror-lfs-netloc", self.mirror_lfs_netloc ) + self.mirrors_path = basic.get("mirrors-path", self.mirrors_path) + if "accessibility" in config: accessibility = config["accessibility"] self.offline = accessibility.get("offline", self.offline) diff --git a/olah/constants.py b/olah/constants.py index c4018f2..e10b81b 100644 --- a/olah/constants.py +++ b/olah/constants.py @@ -12,6 +12,7 @@ DEFAULT_LOGGER_DIR = "./logs" from huggingface_hub.constants import ( + REPO_TYPES_MAPPING, HUGGINGFACE_CO_URL_TEMPLATE, HUGGINGFACE_HEADER_X_REPO_COMMIT, HUGGINGFACE_HEADER_X_LINKED_ETAG, diff --git a/olah/errors.py b/olah/errors.py new file mode 100644 index 0000000..165300b --- /dev/null +++ b/olah/errors.py @@ -0,0 +1,15 @@ + + + +from fastapi.responses import JSONResponse + + +def error_repo_not_found() -> JSONResponse: + return JSONResponse( + content={"error": "Repository not found"}, + headers={ + "x-error-code": "RepoNotFound", + "x-error-message": "Repository not found", + }, + status_code=401, + ) \ No newline at end of file diff --git a/olah/mirror/__init__.py b/olah/mirror/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/olah/mirror/meta.py b/olah/mirror/meta.py new file mode 100644 index 0000000..fe302cb --- /dev/null +++ b/olah/mirror/meta.py @@ -0,0 +1,40 @@ + + +class RepoMeta(object): + def __init__(self) -> None: + self._id = None + self.id = None + self.author = None + self.sha = None + self.lastModified = None + self.private = False + self.gated = False + self.disabled = False + self.tags = [] + self.description = "" + self.paperswithcode_id = None + self.downloads = 0 + self.likes = 0 + self.cardData = None + self.siblings = None + self.createdAt = None + + def to_dict(self): + return { + "_id": self._id, + "id": self.id, + "author": self.author, + "sha": self.sha, + "lastModified": self.lastModified, + "private": self.private, + "gated": self.gated, + "disabled": self.disabled, + "tags": self.tags, + "description": self.description, + "paperswithcode_id": self.paperswithcode_id, + "downloads": self.downloads, + "likes": self.likes, + "cardData": self.cardData, + "siblings": self.siblings, + "createdAt": self.createdAt, + } diff --git a/olah/mirror/repos.py b/olah/mirror/repos.py new file mode 100644 index 0000000..2296039 --- /dev/null +++ b/olah/mirror/repos.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2024 XiaHan +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. +import hashlib +import io +import os +import re +from typing import Any, Dict, List, Union +import gitdb +from git import Commit, Optional, Repo, Tree +from gitdb.base import OStream +import yaml + +from olah.mirror.meta import RepoMeta +class LocalMirrorRepo(object): + def __init__(self, path: str, repo_type: str, org: str, repo: str) -> None: + self._path = path + self._repo_type = repo_type + self._org = org + self._repo = repo + + self._git_repo = Repo(self._path) + + def _sha256(self, text: Union[str, bytes]) -> str: + if isinstance(text, bytes) or isinstance(text, bytearray): + bin = text + elif isinstance(text, str): + bin = text.encode('utf-8') + else: + raise Exception("Invalid sha256 param type.") + sha256_hash = hashlib.sha256() + sha256_hash.update(bin) + hashed_string = sha256_hash.hexdigest() + return hashed_string + + def _match_card(self, readme: str) -> str: + pattern = r'\s*---(.*?)---' + + match = re.match(pattern, readme, flags=re.S) + + if match: + card_string = match.group(1) + return card_string + else: + return "" + def _remove_card(self, readme: str) -> str: + pattern = r'\s*---(.*?)---' + out = re.sub(pattern, "", readme, flags=re.S) + return out + + def _get_readme(self, commit: Commit) -> str: + if "README.md" not in commit.tree: + return "" + else: + out: bytes = commit.tree["README.md"].data_stream.read() + return out.decode() + + def _get_description(self, commit: Commit) -> str: + readme = self._get_readme(commit) + return self._remove_card(readme) + + def _get_entry_files(self, tree, include_dir=False) -> List[str]: + out_paths = [] + for entry in tree: + if entry.type == "tree": + out_paths.extend(self._get_entry_files(entry)) + if include_dir: + out_paths.append(entry.path) + else: + out_paths.append(entry.path) + return out_paths + + def _get_tree_files(self, commit: Commit) -> List[str]: + return self._get_entry_files(commit.tree) + + + def _get_earliest_commit(self) -> Commit: + earliest_commit = None + earliest_commit_date = None + + for commit in self._git_repo.iter_commits(): + commit_date = commit.committed_datetime + + if earliest_commit_date is None or commit_date < earliest_commit_date: + earliest_commit = commit + earliest_commit_date = commit_date + + return earliest_commit + + def get_meta(self, commit_hash: str) -> Dict[str, Any]: + try: + commit = self._git_repo.commit(commit_hash) + except gitdb.exc.BadName: + return None + meta = RepoMeta() + + meta._id = self._sha256(f"{self._org}/{self._repo}/{commit.hexsha}") + meta.id = f"{self._org}/{self._repo}" + meta.author = self._org + meta.sha = commit.hexsha + meta.lastModified = self._git_repo.head.commit.committed_datetime.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + meta.private = False + meta.gated = False + meta.disabled = False + meta.tags = [] + meta.description = self._get_description(commit) + meta.paperswithcode_id = None + meta.downloads = 0 + meta.likes = 0 + meta.cardData = yaml.load(self._match_card(self._get_readme(commit)), Loader=yaml.CLoader) + meta.siblings = [{"rfilename": p} for p in self._get_tree_files(commit)] + meta.createdAt = self._get_earliest_commit().committed_datetime.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + return meta.to_dict() + + def _contain_path(self, path: str, tree: Tree) -> bool: + norm_p = os.path.normpath(path).replace("\\", "/") + parts = norm_p.split("/") + for part in parts: + if all([t.name != part for t in tree]): + return False + else: + entry = tree[part] + if entry.type == "tree": + tree = entry + else: + tree = {} + return True + + def get_file_head(self, commit_hash: str, path: str) -> Optional[Dict[str, Any]]: + try: + commit = self._git_repo.commit(commit_hash) + except gitdb.exc.BadName: + return None + + if not self._contain_path(path, commit.tree): + return None + else: + header = {} + header["content-length"] = str(commit.tree[path].data_stream.size) + header["x-repo-commit"] = commit.hexsha + header["etag"] = self._sha256(commit.tree[path].data_stream.read()) + return header + + def get_file(self, commit_hash: str, path: str) -> Optional[OStream]: + try: + commit = self._git_repo.commit(commit_hash) + except gitdb.exc.BadName: + return None + + def stream_wrapper(file_bytes: bytes): + file_stream = io.BytesIO(file_bytes) + while True: + chunk = file_stream.read(4096) + if len(chunk) == 0: + break + else: + yield chunk + + if not self._contain_path(path, commit.tree): + return None + else: + return stream_wrapper(commit.tree[path].data_stream.read()) + + \ No newline at end of file diff --git a/olah/proxy/__init__.py b/olah/proxy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/olah/files.py b/olah/proxy/files.py similarity index 79% rename from olah/files.py rename to olah/proxy/files.py index e41cddf..fd3e67d 100644 --- a/olah/files.py +++ b/olah/proxy/files.py @@ -28,6 +28,22 @@ from olah.utils.file_utils import make_dirs from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT +async def _write_cache_request(head_path: str, status_code: int, headers: Dict[str, str], content: bytes): + rq = { + "status_code": status_code, + "headers": headers, + "content": content.hex(), + } + with open(head_path, "w", encoding="utf-8") as f: + f.write(json.dumps(rq, ensure_ascii=False)) + +async def _read_cache_request(head_path: str): + with open(head_path, "r", encoding="utf-8") as f: + rq = json.loads(f.read()) + + rq["content"] = bytes.fromhex(rq["content"]) + return rq + async def _file_full_header( app, save_path: str, @@ -38,12 +54,18 @@ async def _file_full_header( headers: Dict[str, str], allow_cache: bool, ) -> Tuple[int, Dict[str, str], bytes]: - if os.path.exists(head_path): - with open(head_path, "r", encoding="utf-8") as f: - response_headers = json.loads(f.read()) - response_headers_dict = {k.lower():v for k, v in response_headers.items()} - else: - if not app.app_settings.config.offline: + assert method.lower() == "head" + if not app.app_settings.config.offline: + if os.path.exists(head_path): + cache_rq = await _read_cache_request(head_path) + response_headers_dict = {k.lower():v for k, v in cache_rq["headers"].items()} + if "location" in response_headers_dict: + parsed_url = urlparse(response_headers_dict["location"]) + if len(parsed_url.netloc) != 0: + new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response_headers_dict["location"])) + response_headers_dict["location"] = new_loc + return cache_rq["status_code"], response_headers_dict, cache_rq["content"] + else: if "range" in headers: headers.pop("range") response = await client.request( @@ -55,11 +77,9 @@ async def _file_full_header( response_headers_dict = {k.lower(): v for k, v in response.headers.items()} if allow_cache and method.lower() == "head": if response.status_code == 200: - with open(head_path, "w", encoding="utf-8") as f: - f.write(json.dumps(response_headers_dict, ensure_ascii=False)) + await _write_cache_request(head_path, response.status_code, response_headers_dict, response.content) elif response.status_code >= 300 and response.status_code <= 399: - with open(head_path, "w", encoding="utf-8") as f: - f.write(json.dumps(response_headers_dict, ensure_ascii=False)) + await _write_cache_request(head_path, response.status_code, response_headers_dict, response.content) from_url = urlparse(url) parsed_url = urlparse(response.headers["location"]) if len(parsed_url.netloc) != 0: @@ -68,25 +88,34 @@ async def _file_full_header( else: raise Exception(f"Unexpected HTTP status code {response.status_code}") return response.status_code, response_headers_dict, response.content + else: + if os.path.exists(head_path): + cache_rq = await _read_cache_request(head_path) + response_headers_dict = {k.lower():v for k, v in cache_rq["headers"].items()} else: response_headers_dict = {} + cache_rq = { + "status_code": 200, + "headers": response_headers_dict, + "content": b"", + } - new_headers = {} - if "content-type" in response_headers_dict: - new_headers["content-type"] = response_headers_dict["content-type"] - if "content-length" in response_headers_dict: - new_headers["content-length"] = response_headers_dict["content-length"] - if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict: - new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "") - if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict: - new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "") - if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict: - new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "") - if "etag" in response_headers_dict: - new_headers["etag"] = response_headers_dict["etag"] - if "location" in response_headers_dict: - new_headers["location"] = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response_headers_dict["location"])) - return 200, new_headers, b"" + new_headers = {} + if "content-type" in response_headers_dict: + new_headers["content-type"] = response_headers_dict["content-type"] + if "content-length" in response_headers_dict: + new_headers["content-length"] = response_headers_dict["content-length"] + if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict: + new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "") + if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict: + new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "") + if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict: + new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "") + if "etag" in response_headers_dict: + new_headers["etag"] = response_headers_dict["etag"] + if "location" in response_headers_dict: + new_headers["location"] = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response_headers_dict["location"])) + return cache_rq["status_code"], new_headers, cache_rq["content"] async def _get_file_block_from_cache(cache_file: OlahCache, block_index: int): raw_block = cache_file.read_block(block_index) @@ -240,6 +269,7 @@ async def _file_realtime_stream( yield head_info yield content return + file_size = int(head_info["content-length"]) response_headers = {k: v for k,v in head_info.items()} if "range" in request_headers: diff --git a/olah/lfs.py b/olah/proxy/lfs.py similarity index 97% rename from olah/lfs.py rename to olah/proxy/lfs.py index d4474dc..7588ff4 100644 --- a/olah/lfs.py +++ b/olah/proxy/lfs.py @@ -9,7 +9,7 @@ from typing import Literal from fastapi import FastAPI, Header, Request -from olah.files import _file_realtime_stream +from olah.proxy.files import _file_realtime_stream from olah.utils.file_utils import make_dirs from olah.utils.url_utils import check_cache_rules_hf, get_org_repo diff --git a/olah/meta.py b/olah/proxy/meta.py similarity index 100% rename from olah/meta.py rename to olah/proxy/meta.py diff --git a/olah/server.py b/olah/server.py index 2706377..57869e5 100644 --- a/olah/server.py +++ b/olah/server.py @@ -12,16 +12,19 @@ from typing import Annotated, Optional, Union from urllib.parse import urljoin from fastapi import FastAPI, Header, Request -from fastapi.responses import HTMLResponse, StreamingResponse, Response +from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse, Response, JSONResponse from fastapi_utils.tasks import repeat_every +import git import httpx from pydantic import BaseSettings from olah.configs import OlahConfig -from olah.files import cdn_file_get_generator, file_get_generator -from olah.lfs import lfs_get_generator, lfs_head_generator -from olah.meta import meta_generator, meta_proxy_cache +from olah.errors import error_repo_not_found +from olah.mirror.repos import LocalMirrorRepo +from olah.proxy.files import cdn_file_get_generator, file_get_generator +from olah.proxy.lfs import lfs_get_generator, lfs_head_generator +from olah.proxy.meta import meta_generator, meta_proxy_cache from olah.utils.url_utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo - +from olah.constants import REPO_TYPES_MAPPING from olah.utils.logging import build_logger # ====================== @@ -81,25 +84,37 @@ class AppSettings(BaseSettings): # API Hooks # ====================== async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: - if not await check_proxy_rules_hf(app, repo_type, org, repo): + if repo_type not in REPO_TYPES_MAPPING.keys(): return Response( - content="This repository is forbidden by the mirror. ", status_code=403 + content="Invalid repository type. ", status_code=403 ) + if not await check_proxy_rules_hf(app, repo_type, org, repo): + return error_repo_not_found() + # Check Mirror Path + for mirror_path in app.app_settings.config.mirrors_path: + try: + git_path = os.path.join(mirror_path, repo_type, org, repo) + if os.path.exists(git_path): + local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) + meta_data = local_repo.get_meta(commit) + if meta_data is None: + continue + return JSONResponse(content=meta_data) + except git.exc.InvalidGitRepositoryError: + logger.warning(f"Local repository {git_path} is not a valid git reposity.") + continue + + # Proxy the HF File Meta try: if not app.app_settings.config.offline and not await check_commit_hf( app, repo_type, org, repo, commit ): - return Response( - content="This repository is not accessible. ", status_code=404 - ) + return error_repo_not_found() commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) if commit_sha is None: - return Response( - content="This repository is not accessible. ", status_code=404 - ) - + return error_repo_not_found() # if branch name and online mode, refresh branch info - if commit_sha != commit and not app.app_settings.config.offline: + if not app.app_settings.config.offline and commit_sha != commit: await meta_proxy_cache(app, repo_type, org, repo, commit, request) generator = meta_generator(app, repo_type, org, repo, commit_sha, request) @@ -150,22 +165,36 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: async def file_head_common( repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request ) -> Response: - if not await check_proxy_rules_hf(app, repo_type, org, repo): + if repo_type not in REPO_TYPES_MAPPING.keys(): return Response( - content="This repository is forbidden by the mirror. ", status_code=403 + content="Invalid repository type. ", status_code=403 ) + if not await check_proxy_rules_hf(app, repo_type, org, repo): + return error_repo_not_found() + + # Check Mirror Path + for mirror_path in app.app_settings.config.mirrors_path: + try: + git_path = os.path.join(mirror_path, repo_type, org, repo) + if os.path.exists(git_path): + local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) + head = local_repo.get_file_head(commit_hash=commit, path=file_path) + if head is None: + continue + return Response(headers=head) + except git.exc.InvalidGitRepositoryError: + logger.warning(f"Local repository {git_path} is not a valid git reposity.") + continue + + # Proxy the HF File Head try: if not app.app_settings.config.offline and not await check_commit_hf( app, repo_type, org, repo, commit ): - return Response( - content="This repository is not accessible. ", status_code=404 - ) + return error_repo_not_found() commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) if commit_sha is None: - return Response( - content="This repository is not accessible. ", status_code=404 - ) + return error_repo_not_found() generator = await file_get_generator( app, repo_type, @@ -264,20 +293,31 @@ async def cdn_file_head(org_repo: str, hash_file: str, request: Request, repo_ty async def file_get_common( repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request ) -> Response: - if not await check_proxy_rules_hf(app, repo_type, org, repo): + if repo_type not in REPO_TYPES_MAPPING.keys(): return Response( - content="This repository is forbidden by the mirror. ", status_code=403 + content="Invalid repository type. ", status_code=403 ) + if not await check_proxy_rules_hf(app, repo_type, org, repo): + return error_repo_not_found() + # Check Mirror Path + for mirror_path in app.app_settings.config.mirrors_path: + try: + git_path = os.path.join(mirror_path, repo_type, org, repo) + if os.path.exists(git_path): + local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) + content_stream = local_repo.get_file(commit_hash=commit, path=file_path) + if content_stream is None: + continue + return StreamingResponse(content_stream) + except git.exc.InvalidGitRepositoryError: + logger.warning(f"Local repository {git_path} is not a valid git reposity.") + continue try: if not app.app_settings.config.offline and not await check_commit_hf(app, repo_type, org, repo, commit): - return Response( - content="This repository is not accessible. ", status_code=404 - ) + return error_repo_not_found() commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) if commit_sha is None: - return Response( - content="This repository is not accessible. ", status_code=404 - ) + return error_repo_not_found() generator = await file_get_generator( app, repo_type, diff --git a/pyproject.toml b/pyproject.toml index 0f8342b..38bdd39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ classifiers = [ ] dependencies = [ "fastapi", "fastapi-utils", "httpx", "numpy", "pydantic<=2.8.2", "requests", "toml", - "rich>=10.0.0", "shortuuid", "uvicorn", "tenacity>=8.2.2", "pytz", "cachetools" + "rich>=10.0.0", "shortuuid", "uvicorn", "tenacity>=8.2.2", "pytz", "cachetools", "GitPython", + "PyYAML" ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index dc90e65..6e87169 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,10 @@ fastapi==0.111.0 fastapi-utils==0.7.0 +GitPython==3.1.43 httpx==0.27.0 pydantic==2.8.2 toml==0.10.2 huggingface_hub==0.23.4 pytest==8.2.2 -cachetools==5.4.0 \ No newline at end of file +cachetools==5.4.0 +PyYAML==6.0.1 \ No newline at end of file