diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..d401c23
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,19 @@
+# How can I contribute to Olah?
+
+Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
+
+It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
+
+However you choose to contribute, please be mindful and respect our code of conduct.
+
+## Ways to contribute
+
+There are lots of ways you can contribute to Olah:
+* Submitting issues on GitHub to report bugs or make feature requests
+* Fixing outstanding issues with the existing code
+* Implementing new features
+* Contributing to the examples or to the documentation
+
+*All are equally valuable to the community.*
+
+#### This guide was heavily inspired by the awesome [transformers guide to contributing](https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md)
diff --git a/README.md b/README.md
index 2084af9..569af08 100644
--- a/README.md
+++ b/README.md
@@ -185,7 +185,6 @@ allow = false
 
 ## Future Work
 
-* Authentication
 * Administrator and user system
 * OOS backend support
 * Mirror Update Schedule Task
diff --git a/README_zh.md b/README_zh.md
index 99e4a4a..46cdb54 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -4,7 +4,7 @@

 自托管的轻量级HuggingFace镜像服务
 
-Olah是一种自托管的轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。
+Olah是开源的自托管轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。
 
 Olah真正地实现了huggingface资源的`镜像`功能,而不仅仅是一个简单的`反向代理`。
 Olah并不会立刻对huggingface全站进行镜像,而是在用户下载的同时在文件块级别对资源进行镜像(或者我们可以说是缓存)。
diff --git a/docs/en/main.md b/docs/en/main.md
new file mode 100644
index 0000000..88c8d4f
--- /dev/null
+++ b/docs/en/main.md
@@ -0,0 +1,8 @@
+
+Olah Document
+
+Self-hosted Lightweight Huggingface Mirror Service
+
+Olah is a self-hosted lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian.
+Olah implements true `mirroring` of huggingface resources, rather than just a simple `reverse proxy`.
+Olah does not mirror the entire huggingface website immediately; instead, it mirrors (or rather, caches) resources at the file-block level as users download them.
diff --git a/docs/en/quickstart.md b/docs/en/quickstart.md
new file mode 100644
index 0000000..e69de29
diff --git a/docs/zh/main.md b/docs/zh/main.md
new file mode 100644
index 0000000..c7ed505
--- /dev/null
+++ b/docs/zh/main.md
@@ -0,0 +1,9 @@
+
+Olah 文档
+
+自托管的轻量级HuggingFace镜像服务 + +Olah是开源的自托管轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。 +Olah真正地实现了huggingface资源的`镜像`功能,而不仅仅是一个简单的`反向代理`。 +Olah并不会立刻对huggingface全站进行镜像,而是在用户下载的同时在文件块级别对资源进行镜像(或者我们可以说是缓存)。 diff --git a/docs/zh/quickstart.md b/docs/zh/quickstart.md new file mode 100644 index 0000000..e69de29 diff --git a/olah/auth/__init__.py b/olah/auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/olah/constants.py b/olah/constants.py index e10b81b..04bfc07 100644 --- a/olah/constants.py +++ b/olah/constants.py @@ -11,6 +11,8 @@ DEFAULT_LOGGER_DIR = "./logs" +ORIGINAL_LOC = "oriloc" + from huggingface_hub.constants import ( REPO_TYPES_MAPPING, HUGGINGFACE_CO_URL_TEMPLATE, diff --git a/olah/proxy/files.py b/olah/proxy/files.py index fd3e67d..8ae43f1 100644 --- a/olah/proxy/files.py +++ b/olah/proxy/files.py @@ -1,6 +1,6 @@ # coding=utf-8 # Copyright 2024 XiaHan -# +# # Use of this source code is governed by an MIT-style # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. @@ -22,13 +22,39 @@ HUGGINGFACE_HEADER_X_REPO_COMMIT, HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_LINKED_SIZE, + ORIGINAL_LOC, ) from olah.utils.olah_cache import OlahCache -from olah.utils.url_utils import RemoteInfo, check_cache_rules_hf, get_org_repo, get_url_tail, parse_range_params +from olah.utils.url_utils import ( + RemoteInfo, + add_query_param, + check_url_has_param_name, + get_url_param_name, + get_url_tail, + parse_range_params, + remove_query_param, +) +from olah.utils.repo_utils import get_org_repo +from olah.utils.rule_utils import check_cache_rules_hf from olah.utils.file_utils import make_dirs from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT -async def _write_cache_request(head_path: str, status_code: int, headers: Dict[str, str], content: bytes): + +async def _write_cache_request( + head_path: str, status_code: int, headers: Dict[str, str], content: bytes +)-> None: + """ + Write the request's status code, headers, and content to a cache file. + + Args: + head_path (str): The path to the cache file. + status_code (int): The status code of the request. + headers (Dict[str, str]): The dictionary of response headers. + content (bytes): The content of the request. + + Returns: + None + """ rq = { "status_code": status_code, "headers": headers, @@ -37,13 +63,24 @@ async def _write_cache_request(head_path: str, status_code: int, headers: Dict[s with open(head_path, "w", encoding="utf-8") as f: f.write(json.dumps(rq, ensure_ascii=False)) -async def _read_cache_request(head_path: str): + +async def _read_cache_request(head_path: str) -> Dict[str, str]: + """ + Read the request's status code, headers, and content from a cache file. + + Args: + head_path (str): The path to the cache file. + + Returns: + Dict[str, str]: A dictionary containing the status code, headers, and content of the request. 
+ """ with open(head_path, "r", encoding="utf-8") as f: rq = json.loads(f.read()) - + rq["content"] = bytes.fromhex(rq["content"]) return rq + async def _file_full_header( app, save_path: str, @@ -85,6 +122,13 @@ async def _file_full_header( if len(parsed_url.netloc) != 0: new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"])) response_headers_dict["location"] = new_loc + # Redirect, add original location info + if check_url_has_param_name(response_headers_dict["location"], ORIGINAL_LOC): + raise Exception(f"Invalid field {ORIGINAL_LOC} in the url.") + else: + response_headers_dict["location"] = add_query_param(response_headers_dict["location"], ORIGINAL_LOC, response.headers["location"]) + elif response.status_code == 403: + pass else: raise Exception(f"Unexpected HTTP status code {response.status_code}") return response.status_code, response_headers_dict, response.content @@ -244,13 +288,26 @@ async def _file_realtime_stream( allow_cache=True, commit: Optional[str] = None, ): - request_headers = {k: v for k, v in request.headers.items()} - request_headers.pop("host") - - if urlparse(url).netloc == app.app_settings.config.mirror_netloc: - hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url)) + if check_url_has_param_name(url, ORIGINAL_LOC): + clean_url = remove_query_param(url, ORIGINAL_LOC) + original_loc = get_url_param_name(url, ORIGINAL_LOC) + if urlparse(url).netloc in [app.app_settings.config.mirror_netloc, app.app_settings.config.mirror_lfs_netloc]: + hf_loc = urlparse(original_loc) + if len(hf_loc.netloc) != 0: + hf_url = urljoin(f"{hf_loc.scheme}://{hf_loc.netloc}", get_url_tail(clean_url)) + else: + hf_url = url + else: + hf_url = url else: - hf_url = url + if urlparse(url).netloc in [app.app_settings.config.mirror_netloc, app.app_settings.config.mirror_lfs_netloc]: + hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url)) + else: + hf_url = url + + request_headers = {k: v for k, v in request.headers.items()} + if "host" in request_headers: + request_headers["host"] = urlparse(hf_url).netloc async with httpx.AsyncClient() as client: # redirect_loc = await _get_redirected_url(client, method, url, request_headers) @@ -270,6 +327,7 @@ async def _file_realtime_stream( yield content return + async with httpx.AsyncClient() as client: file_size = int(head_info["content-length"]) response_headers = {k: v for k,v in head_info.items()} if "range" in request_headers: diff --git a/olah/proxy/lfs.py b/olah/proxy/lfs.py index 7588ff4..a85dd26 100644 --- a/olah/proxy/lfs.py +++ b/olah/proxy/lfs.py @@ -11,7 +11,6 @@ from olah.proxy.files import _file_realtime_stream from olah.utils.file_utils import make_dirs -from olah.utils.url_utils import check_cache_rules_hf, get_org_repo async def lfs_head_generator( diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py index ea7eebf..9116521 100644 --- a/olah/proxy/meta.py +++ b/olah/proxy/meta.py @@ -15,7 +15,8 @@ import httpx from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT -from olah.utils.url_utils import check_cache_rules_hf, get_org_repo +from olah.utils.rule_utils import check_cache_rules_hf +from olah.utils.repo_utils import get_org_repo from olah.utils.file_utils import make_dirs @@ -50,10 +51,14 @@ async def meta_proxy_cache( app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}", ) + headers = {} + if "authorization" in request.headers: + headers["authorization"] = request.headers["authorization"] async 
with httpx.AsyncClient() as client: response = await client.request( method="GET", url=meta_url, + headers=headers, timeout=WORKER_API_TIMEOUT, follow_redirects=True, ) diff --git a/olah/server.py b/olah/server.py index b99882a..377808b 100644 --- a/olah/server.py +++ b/olah/server.py @@ -26,7 +26,8 @@ from olah.proxy.files import cdn_file_get_generator, file_get_generator from olah.proxy.lfs import lfs_get_generator, lfs_head_generator from olah.proxy.meta import meta_generator, meta_proxy_cache -from olah.utils.url_utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, get_org_repo, parse_org_repo +from olah.utils.rule_utils import check_proxy_rules_hf, get_org_repo +from olah.utils.repo_utils import check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo from olah.constants import REPO_TYPES_MAPPING from olah.utils.logging import build_logger @@ -85,7 +86,8 @@ class AppSettings(BaseSettings): repos_path: str = "./repos" # ====================== -# API Hooks +# File Meta Info API Hooks +# See also: https://huggingface.co/docs/hub/api#repo-listing-api # ====================== async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: if repo_type not in REPO_TYPES_MAPPING.keys(): @@ -109,16 +111,27 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, re # Proxy the HF File Meta try: if not app.app_settings.config.offline and not await check_commit_hf( - app, repo_type, org, repo, commit + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), ): return error_repo_not_found() - commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) + commit_sha = await get_commit_hf( + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), + ) if commit_sha is None: return error_repo_not_found() # if branch name and online mode, refresh branch info if not app.app_settings.config.offline and commit_sha != commit: await meta_proxy_cache(app, repo_type, org, repo, commit, request) - generator = meta_generator(app, repo_type, org, repo, commit_sha, request) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) @@ -160,6 +173,29 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: repo_type=repo_type, org=org, repo=repo, commit=commit, request=request ) +# ====================== +# Authentication API Hooks +# ====================== +@app.get("/api/whoami-v2") +async def whoami_v2(request: Request): + """ + Sensitive Information!!! 
+ """ + new_headers = {k.lower(): v for k, v in request.headers.items()} + new_headers["host"] = app.app_settings.config.hf_netloc + async with httpx.AsyncClient() as client: + response = await client.request( + method="GET", + url=urljoin(app.app_settings.config.hf_url_base(), "/api/whoami-v2"), + headers=new_headers, + timeout=10, + ) + return Response( + content=response.content, + status_code=response.status_code, + headers=response.headers, + ) + # ====================== # File Head Hooks @@ -185,14 +221,26 @@ async def file_head_common( except git.exc.InvalidGitRepositoryError: logger.warning(f"Local repository {git_path} is not a valid git reposity.") continue - + # Proxy the HF File Head try: if not app.app_settings.config.offline and not await check_commit_hf( - app, repo_type, org, repo, commit + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), ): return error_repo_not_found() - commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) + commit_sha = await get_commit_hf( + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), + ) if commit_sha is None: return error_repo_not_found() generator = await file_get_generator( @@ -309,9 +357,23 @@ async def file_get_common( logger.warning(f"Local repository {git_path} is not a valid git reposity.") continue try: - if not app.app_settings.config.offline and not await check_commit_hf(app, repo_type, org, repo, commit): + if not app.app_settings.config.offline and not await check_commit_hf( + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), + ): return error_repo_not_found() - commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) + commit_sha = await get_commit_hf( + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), + ) if commit_sha is None: return error_repo_not_found() generator = await file_get_generator( @@ -333,7 +395,9 @@ async def file_get_common( @app.get("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}") -async def file_get3(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str): +async def file_get3( + org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str +): return await file_get_common( repo_type=repo_type, org=org, @@ -343,8 +407,11 @@ async def file_get3(org: str, repo: str, commit: str, file_path: str, request: R request=request, ) + @app.get("/{org_or_repo_type}/{repo_name}/resolve/{commit}/{file_path:path}") -async def file_get2(org_or_repo_type: str, repo_name: str, commit: str, file_path: str, request: Request): +async def file_get2( + org_or_repo_type: str, repo_name: str, commit: str, file_path: str, request: Request +): if org_or_repo_type in REPO_TYPES_MAPPING.keys(): repo_type: str = org_or_repo_type org, repo = parse_org_repo(repo_name) @@ -363,6 +430,7 @@ async def file_get2(org_or_repo_type: str, repo_name: str, commit: str, file_pat request=request, ) + @app.get("/{org_repo}/resolve/{commit}/{file_path:path}") async def file_get(org_repo: str, commit: str, file_path: str, request: Request): repo_type: str = "models" @@ -379,9 +447,12 @@ async def file_get(org_repo: str, commit: str, file_path: str, request: Request) request=request, ) + @app.get("/{org_repo}/{hash_file}") @app.get("/{repo_type}/{org_repo}/{hash_file}") -async def cdn_file_get(org_repo: str, hash_file: str, request: 
Request, repo_type: str = "models"): +async def cdn_file_get( + org_repo: str, hash_file: str, request: Request, repo_type: str = "models" +): org, repo = parse_org_repo(org_repo) if org is None and repo is None: return error_repo_not_found() @@ -389,13 +460,16 @@ async def cdn_file_get(org_repo: str, hash_file: str, request: Request, repo_typ if not await check_proxy_rules_hf(app, repo_type, org, repo): return error_repo_not_found() try: - generator = await cdn_file_get_generator(app, repo_type, org, repo, hash_file, method="GET", request=request) + generator = await cdn_file_get_generator( + app, repo_type, org, repo, hash_file, method="GET", request=request + ) status_code = await generator.__anext__() headers = await generator.__anext__() return StreamingResponse(generator, headers=headers, status_code=status_code) except httpx.ConnectTimeout: return Response(status_code=504) + # ====================== # LFS Hooks # ====================== diff --git a/olah/utils/bitset.py b/olah/utils/bitset.py index 9fc4cea..b290d3a 100644 --- a/olah/utils/bitset.py +++ b/olah/utils/bitset.py @@ -1,23 +1,48 @@ # coding=utf-8 # Copyright 2024 XiaHan -# +# # Use of this source code is governed by an MIT-style # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. + class Bitset: - def __init__(self, size): + def __init__(self, size) -> None: + """ + Initializes a Bitset object with a given size. + + Args: + size (int): The number of bits in the Bitset. + """ self.size = size - self.bits = bytearray((0, ) * ((size + 7) // 8)) + self.bits = bytearray((0,) * ((size + 7) // 8)) + + def set(self, index: int) -> None: + """ + Sets the bit at the specified index to 1. - def set(self, index): + Args: + index (int): The index of the bit to be set. + + Raises: + IndexError: If the index is out of range. + """ if index < 0 or index >= self.size: raise IndexError("Index out of range") byte_index = index // 8 bit_index = index % 8 - self.bits[byte_index] |= (1 << bit_index) + self.bits[byte_index] |= 1 << bit_index + + def clear(self, index: int) -> None: + """ + Sets the bit at the specified index to 0. - def clear(self, index): + Args: + index (int): The index of the bit to be cleared. + + Raises: + IndexError: If the index is out of range. + """ if index < 0 or index >= self.size: raise IndexError("Index out of range") self._resize_if_needed(index) @@ -25,7 +50,19 @@ def clear(self, index): bit_index = index % 8 self.bits[byte_index] &= ~(1 << bit_index) - def test(self, index): + def test(self, index: int) -> None: + """ + Checks the value of the bit at the specified index. + + Args: + index (int): The index of the bit to be checked. + + Returns: + bool: True if the bit is set (1), False if the bit is cleared (0). + + Raises: + IndexError: If the index is out of range. + """ if index < 0 or index >= self.size: raise IndexError("Index out of range") byte_index = index // 8 @@ -33,4 +70,10 @@ def test(self, index): return bool(self.bits[byte_index] & (1 << bit_index)) def __str__(self): - return ''.join(bin(byte)[2:].zfill(8) for byte in self.bits) + """ + Returns a string representation of the Bitset. + + Returns: + str: A string representation of the Bitset object, showing the binary representation of each byte. 
+ """ + return "".join(bin(byte)[2:].zfill(8) for byte in self.bits) diff --git a/olah/utils/olah_cache.py b/olah/utils/olah_cache.py index 39e0d3b..a6e6701 100644 --- a/olah/utils/olah_cache.py +++ b/olah/utils/olah_cache.py @@ -1,6 +1,6 @@ # coding=utf-8 # Copyright 2024 XiaHan -# +# # Use of this source code is governed by an MIT-style # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. @@ -34,11 +34,11 @@ def __init__( self._block_mask_size = DEFAULT_BLOCK_MASK_MAX self._block_mask = Bitset(DEFAULT_BLOCK_MASK_MAX) - + @property def version(self) -> int: return self._version - + @property def block_size(self) -> int: return self._block_size @@ -46,7 +46,7 @@ def block_size(self) -> int: @property def file_size(self) -> int: return self._file_size - + @property def block_number(self) -> int: return self._block_number @@ -67,7 +67,7 @@ def _valid_header(self): raise Exception( f"This Olah Cache file is created by older version Olah. Please remove cache files and retry." ) - + if self._version > CURRENT_OLAH_CACHE_VERSION: raise Exception( f"This Olah Cache file is created by newer version Olah. Please remove cache files and retry." @@ -104,12 +104,13 @@ def write(self, stream): btyes_out = btyes_header + self._block_mask.bits stream.write(btyes_out) + class OlahCache(object): def __init__(self, path: str, block_size: int = DEFAULT_BLOCK_SIZE) -> None: self.path: Optional[str] = path self.header: Optional[OlahCacheHeader] = None self.is_open: bool = False - + # Lock self._header_lock = threading.Lock() @@ -118,7 +119,6 @@ def __init__(self, path: str, block_size: int = DEFAULT_BLOCK_SIZE) -> None: self._prefech_blocks: int = 16 self.open(path, block_size=block_size) - @staticmethod def create(path: str, block_size: int = DEFAULT_BLOCK_SIZE): @@ -155,7 +155,7 @@ def close(self): self.header = None self._blocks_read_cache.clear() - + self.is_open = False def _flush_header(self): @@ -173,23 +173,23 @@ def _get_block_number(self) -> int: with self._header_lock: block_number = self.header.block_number return block_number - + def _get_block_size(self) -> int: with self._header_lock: block_size = self.header.block_size return block_size - + def _get_header_size(self) -> int: with self._header_lock: header_size = self.header.get_header_size() return header_size - + def _resize_header(self, block_num: int, file_size: int): with self._header_lock: self.header._block_number = block_num self.header._file_size = file_size self.header._valid_header() - + def _set_header_block(self, block_index: int): with self._header_lock: self.header.block_mask.set(block_index) @@ -198,7 +198,7 @@ def _test_header_block(self, block_index: int): with self._header_lock: result = self.header.block_mask.test(block_index) return result - + def _pad_block(self, raw_block: bytes): if len(raw_block) < self._get_block_size(): block = raw_block + b"\x00" * (self._get_block_size() - len(raw_block)) @@ -210,7 +210,7 @@ def flush(self): if not self.is_open: raise Exception("This file has been close.") self._flush_header() - + def has_block(self, block_index: int) -> bool: return self._test_header_block(block_index) @@ -220,11 +220,11 @@ def read_block(self, block_index: int) -> Optional[bytes]: if block_index >= self._get_block_number(): raise Exception("Invalid block index.") - + # Check Cache if block_index in self._blocks_read_cache: return self._blocks_read_cache[block_index] - + if not self.has_block(block_index=block_index): return None @@ -240,11 +240,13 @@ def read_block(self, 
block_index: int) -> Optional[bytes]: self._blocks_read_cache[block_index + block_offset] = None else: prefetch_raw_block = f.read(self._get_block_size()) - self._blocks_read_cache[block_index + block_offset] = self._pad_block(prefetch_raw_block) + self._blocks_read_cache[block_index + block_offset] = ( + self._pad_block(prefetch_raw_block) + ) block = self._pad_block(raw_block) return block - + def write_block(self, block_index: int, block_bytes: bytes) -> None: if not self.is_open: raise Exception("This file has been closed.") @@ -259,25 +261,29 @@ def write_block(self, block_index: int, block_bytes: bytes) -> None: with open(self.path, "rb+") as f: f.seek(offset) if (block_index + 1) * self._get_block_size() > self._get_file_size(): - real_block_bytes = block_bytes[:self._get_file_size() - block_index * self._get_block_size()] + real_block_bytes = block_bytes[ + : self._get_file_size() - block_index * self._get_block_size() + ] else: real_block_bytes = block_bytes f.write(real_block_bytes) - + self._set_header_block(block_index) self._flush_header() # Clear Cache if block_index in self._blocks_read_cache: del self._blocks_read_cache[block_index] - + def _resize_file_size(self, file_size: int): if not self.is_open: raise Exception("This file has been closed.") if file_size == self._get_file_size(): return if file_size < self._get_file_size(): - raise Exception("Invalid resize file size. New file size must be greater than the current file size.") + raise Exception( + "Invalid resize file size. New file size must be greater than the current file size." + ) with open(self.path, "rb") as f: f.seek(0, os.SEEK_END) @@ -287,7 +293,7 @@ def _resize_file_size(self, file_size: int): with open(self.path, "rb+") as f: # Extend file size f.seek(0, os.SEEK_END) - f.write(b'\x00' * (new_bin_size - bin_size)) + f.write(b"\x00" * (new_bin_size - bin_size)) def resize(self, file_size: int): if not self.is_open: @@ -297,4 +303,3 @@ def resize(self, file_size: int): self._resize_file_size(file_size) self._resize_header(new_block_num, file_size) self._flush_header() - \ No newline at end of file diff --git a/olah/utils/repo_utils.py b/olah/utils/repo_utils.py new file mode 100644 index 0000000..2afb00a --- /dev/null +++ b/olah/utils/repo_utils.py @@ -0,0 +1,211 @@ +# coding=utf-8 +# Copyright 2024 XiaHan +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. 
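The block cache reworked above (olah/utils/olah_cache.py) is easiest to understand as a sparse file plus a header bitmask of fetched blocks. A minimal round-trip sketch in Python — assuming a non-default `block_size` is accepted and that `write_block()` expects exactly one full block of bytes (neither is confirmed by this diff):

```python
import os
import tempfile

from olah.utils.olah_cache import OlahCache

BLOCK = 4096  # small block size for the sketch; the real default is DEFAULT_BLOCK_SIZE
path = os.path.join(tempfile.mkdtemp(), "demo.bin")

OlahCache.create(path, block_size=BLOCK)   # write a fresh header to disk
cache = OlahCache(path, block_size=BLOCK)  # __init__() opens the file
cache.resize(file_size=10_000)             # 10_000 bytes -> 3 blocks of 4096

payload = b"hello olah"
# Pad to a full block before writing (assumption, see above).
cache.write_block(0, payload.ljust(BLOCK, b"\x00"))
assert cache.has_block(0) and not cache.has_block(1)
assert cache.read_block(0)[: len(payload)] == payload
assert cache.read_block(1) is None         # never fetched -> miss, not zeros
cache.close()
```

The bitmask is what lets `read_block()` distinguish a block that was never fetched (`None`) from a fetched block that happens to be all zeros.
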
+ +import datetime +import os +import glob +from typing import Dict, Literal, Optional, Tuple, Union +import json +from urllib.parse import urljoin +import httpx +from olah.constants import WORKER_API_TIMEOUT + + +def get_org_repo(org: Optional[str], repo: str) -> str: + if org is None: + org_repo = repo + else: + org_repo = f"{org}/{repo}" + return org_repo + + +def parse_org_repo(org_repo: str) -> Tuple[Optional[str], Optional[str]]: + if "/" in org_repo and org_repo.count("/") != 1: + return None, None + if "/" in org_repo: + org, repo = org_repo.split("/") + else: + org = None + repo = org_repo + return org, repo + + +def get_meta_save_path( + repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str +) -> str: + return os.path.join( + repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}/meta.json" + ) + + +def get_meta_save_dir( + repos_path: str, repo_type: str, org: Optional[str], repo: str +) -> str: + return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision") + + +def get_file_save_path( + repos_path: str, + repo_type: str, + org: Optional[str], + repo: str, + commit: str, + file_path: str, +) -> str: + return os.path.join( + repos_path, f"heads/{repo_type}/{org}/{repo}/resolve_head/{commit}/{file_path}" + ) + + +async def get_newest_commit_hf_offline( + app, + repo_type: Optional[Literal["models", "datasets", "spaces"]], + org: str, + repo: str, +) -> str: + repos_path = app.app_settings.repos_path + save_dir = get_meta_save_dir(repos_path, repo_type, org, repo) + files = glob.glob(os.path.join(save_dir, "*", "meta.json")) + + time_revisions = [] + for file in files: + with open(file, "r", encoding="utf-8") as f: + obj = json.loads(f.read()) + datetime_object = datetime.datetime.fromisoformat(obj["lastModified"]) + time_revisions.append((datetime_object, obj["sha"])) + + time_revisions = sorted(time_revisions) + return time_revisions[-1][1] + + +async def get_newest_commit_hf( + app, + repo_type: Optional[Literal["models", "datasets", "spaces"]], + org: Optional[str], + repo: str, +) -> Optional[str]: + url = urljoin( + app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org}/{repo}" + ) + if app.app_settings.config.offline: + return get_newest_commit_hf_offline(app, repo_type, org, repo) + try: + async with httpx.AsyncClient() as client: + response = await client.get(url, timeout=WORKER_API_TIMEOUT) + if response.status_code != 200: + return get_newest_commit_hf_offline(app, repo_type, org, repo) + obj = json.loads(response.text) + return obj.get("sha", None) + except: + return get_newest_commit_hf_offline(app, repo_type, org, repo) + + +async def get_commit_hf_offline( + app, + repo_type: Optional[Literal["models", "datasets", "spaces"]], + org: Optional[str], + repo: str, + commit: str, +) -> Optional[str]: + """ + Retrieves the commit SHA for a given repository and commit from the offline cache. + + This function is used when the application is in offline mode and the commit information is not available from the API. + + Args: + app: The application instance. + repo_type: Optional. The type of repository ("models", "datasets", or "spaces"). + org: Optional. The organization name for the repository. + repo: The name of the repository. + commit: The commit identifier. + + Returns: + The commit SHA as a string if available in the offline cache, or None if the information is not cached. 
+ """ + repos_path = app.app_settings.repos_path + save_path = get_meta_save_path(repos_path, repo_type, org, repo, commit) + if os.path.exists(save_path): + with open(save_path, "r", encoding="utf-8") as f: + obj = json.loads(f.read()) + return obj["sha"] + else: + return None + + +async def get_commit_hf( + app, + repo_type: Optional[Literal["models", "datasets", "spaces"]], + org: Optional[str], + repo: str, + commit: str, + authorization: Optional[str] = None, +) -> Optional[str]: + """ + Retrieves the commit SHA for a given repository and commit from the Hugging Face API. + + Args: + app: The application instance. + repo_type: Optional. The type of repository ("models", "datasets", or "spaces"). + org: Optional. The organization name for the repository. + repo: The name of the repository. + commit: The commit identifier. + authorization: Optional. The authorization token for accessing the API. + + Returns: + The commit SHA as a string, or None if the commit cannot be retrieved. + + Raises: + This function does not raise any explicit exceptions but may propagate exceptions from underlying functions. + """ + org_repo = get_org_repo(org, repo) + url = urljoin( + app.app_settings.config.hf_url_base(), + f"/api/{repo_type}/{org_repo}/revision/{commit}", + ) + if app.app_settings.config.offline: + return await get_commit_hf_offline(app, repo_type, org, repo, commit) + try: + headers = {} + if authorization is not None: + headers["authorization"] = authorization + async with httpx.AsyncClient() as client: + response = await client.get( + url, headers=headers, timeout=WORKER_API_TIMEOUT, follow_redirects=True + ) + if response.status_code not in [200, 307]: + return await get_commit_hf_offline(app, repo_type, org, repo, commit) + obj = json.loads(response.text) + return obj.get("sha", None) + except: + return await get_commit_hf_offline(app, repo_type, org, repo, commit) + + +async def check_commit_hf( + app, + repo_type: Optional[Literal["models", "datasets", "spaces"]], + org: Optional[str], + repo: str, + commit: Optional[str] = None, + authorization: Optional[str] = None, +) -> bool: + org_repo = get_org_repo(org, repo) + if commit is None: + url = urljoin( + app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}" + ) + else: + url = urljoin( + app.app_settings.config.hf_url_base(), + f"/api/{repo_type}/{org_repo}/revision/{commit}", + ) + + headers = {} + if authorization is not None: + headers["authorization"] = authorization + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, timeout=WORKER_API_TIMEOUT) + return response.status_code in [200, 307] diff --git a/olah/utils/rule_utils.py b/olah/utils/rule_utils.py new file mode 100644 index 0000000..0ad95e7 --- /dev/null +++ b/olah/utils/rule_utils.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# Copyright 2024 XiaHan +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. 
+ + +from typing import Dict, Literal, Optional, Tuple, Union +from olah.configs import OlahConfig +from .repo_utils import get_org_repo + + +async def check_proxy_rules_hf( + app, + repo_type: Optional[Literal["models", "datasets", "spaces"]], + org: Optional[str], + repo: str, +) -> bool: + config: OlahConfig = app.app_settings.config + org_repo = get_org_repo(org, repo) + return config.proxy.allow(org_repo) + + +async def check_cache_rules_hf( + app, + repo_type: Optional[Literal["models", "datasets", "spaces"]], + org: Optional[str], + repo: str, +) -> bool: + config: OlahConfig = app.app_settings.config + org_repo = get_org_repo(org, repo) + return config.cache.allow(org_repo) diff --git a/olah/utils/url_utils.py b/olah/utils/url_utils.py index 85d8452..3f4d236 100644 --- a/olah/utils/url_utils.py +++ b/olah/utils/url_utils.py @@ -1,6 +1,6 @@ # coding=utf-8 # Copyright 2024 XiaHan -# +# # Use of this source code is governed by an MIT-style # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. @@ -10,113 +10,22 @@ import glob from typing import Dict, Literal, Optional, Tuple, Union import json -from urllib.parse import ParseResult, urljoin, urlparse +from urllib.parse import ParseResult, urlencode, urljoin, urlparse, parse_qs, urlunparse import httpx from olah.configs import OlahConfig from olah.constants import WORKER_API_TIMEOUT -def get_org_repo(org: Optional[str], repo: str) -> str: - if org is None: - org_repo = repo - else: - org_repo = f"{org}/{repo}" - return org_repo - -def parse_org_repo(org_repo: str) -> Tuple[Optional[str], Optional[str]]: - if "/" in org_repo and org_repo.count("/") != 1: - return None, None - if "/" in org_repo: - org, repo = org_repo.split("/") - else: - org = None - repo = org_repo - return org, repo - -def get_meta_save_path(repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str) -> str: - return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}/meta.json") - -def get_meta_save_dir(repos_path: str, repo_type: str, org: Optional[str], repo: str) -> str: - return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision") - -def get_file_save_path(repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str, file_path: str) -> str: - return os.path.join(repos_path, f"heads/{repo_type}/{org}/{repo}/resolve_head/{commit}/{file_path}") - -async def get_newest_commit_hf_offline(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: str, repo: str) -> str: - repos_path = app.app_settings.repos_path - save_dir = get_meta_save_dir(repos_path, repo_type, org, repo) - files = glob.glob(os.path.join(save_dir, "*", "meta.json")) - - time_revisions = [] - for file in files: - with open(file, "r", encoding="utf-8") as f: - obj = json.loads(f.read()) - datetime_object = datetime.datetime.fromisoformat(obj["lastModified"]) - time_revisions.append((datetime_object, obj["sha"])) - - time_revisions = sorted(time_revisions) - return time_revisions[-1][1] - -async def get_newest_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str) -> str: - url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org}/{repo}") - if app.app_settings.config.offline: - return get_newest_commit_hf_offline(app, repo_type, org, repo) - try: - async with httpx.AsyncClient() as client: - response = await client.get(url, timeout=WORKER_API_TIMEOUT) - if response.status_code != 200: - return 
get_newest_commit_hf_offline(app, repo_type, org, repo) - obj = json.loads(response.text) - return obj.get("sha", None) - except: - return get_newest_commit_hf_offline(app, repo_type, org, repo) - -async def get_commit_hf_offline(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: str) -> str: - repos_path = app.app_settings.repos_path - save_path = get_meta_save_path(repos_path, repo_type, org, repo, commit) - if os.path.exists(save_path): - with open(save_path, "r", encoding="utf-8") as f: - obj = json.loads(f.read()) - return obj["sha"] - else: - return None - -async def get_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: str) -> str: - org_repo = get_org_repo(org, repo) - url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}") - if app.app_settings.config.offline: - return await get_commit_hf_offline(app, repo_type, org, repo, commit) - try: - async with httpx.AsyncClient() as client: - response = await client.get(url, timeout=WORKER_API_TIMEOUT, follow_redirects=True) - if response.status_code not in [200, 307]: - return await get_commit_hf_offline(app, repo_type, org, repo, commit) - obj = json.loads(response.text) - return obj.get("sha", None) - except: - return await get_commit_hf_offline(app, repo_type, org, repo, commit) - -async def check_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: Optional[str]=None) -> bool: - org_repo = get_org_repo(org, repo) - if commit is None: - url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}") - else: - url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}") - - async with httpx.AsyncClient() as client: - response = await client.get(url, timeout=WORKER_API_TIMEOUT) - return response.status_code in [200, 307] - -async def check_proxy_rules_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str) -> bool: - config: OlahConfig = app.app_settings.config - org_repo = get_org_repo(org, repo) - return config.proxy.allow(org_repo) - -async def check_cache_rules_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str) -> bool: - config: OlahConfig = app.app_settings.config - org_repo = get_org_repo(org, repo) - return config.cache.allow(org_repo) def get_url_tail(parsed_url: Union[str, ParseResult]) -> str: + """ + Extracts the tail of a URL, including path, parameters, query, and fragment. + + Args: + parsed_url (Union[str, ParseResult]): The parsed URL or a string URL. + + Returns: + str: The tail of the URL, including path, parameters, query, and fragment. + """ if isinstance(parsed_url, str): parsed_url = urlparse(parsed_url) url_tail = parsed_url.path @@ -128,8 +37,22 @@ def get_url_tail(parsed_url: Union[str, ParseResult]) -> str: url_tail += f"#{parsed_url.fragment}" return url_tail + def parse_range_params(file_range: str, file_size: int) -> Tuple[int, int]: - # 'bytes=1887436800-' + """ + Parses the range parameters for a file request. + + Args: + file_range (str): The range parameter string, e.g., 'bytes=1887436800-'. + file_size (int): The size of the file. + + Returns: + Tuple[int, int]: A tuple of start and end positions for the file range. 
+ """ + if "/" in file_range: + file_range, _file_size = file_range.split("/", maxsplit=1) + else: + file_range = file_range if file_range.startswith("bytes="): file_range = file_range[6:] start_pos, end_pos = file_range.split("-") @@ -146,6 +69,96 @@ def parse_range_params(file_range: str, file_size: int) -> Tuple[int, int]: class RemoteInfo(object): def __init__(self, method: str, url: str, headers: Dict[str, str]) -> None: + """ + Represents information about a remote request. + + Args: + method (str): The HTTP method of the request. + url (str): The URL of the request. + headers (Dict[str, str]): The headers of the request. + """ self.method = method self.url = url - self.headers = headers \ No newline at end of file + self.headers = headers + + +def check_url_has_param_name(url: str, param_name: str) -> bool: + """ + Checks if a URL contains a specific query parameter. + + Args: + url (str): The URL to check. + param_name (str): The name of the query parameter. + + Returns: + bool: True if the URL contains the parameter, False otherwise. + """ + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + return param_name in query_params + + +def get_url_param_name(url: str, param_name: str) -> Optional[str]: + """ + Retrieves the value of a specific query parameter from a URL. + + Args: + url (str): The URL to retrieve the parameter from. + param_name (str): The name of the query parameter. + + Returns: + Optional[str]: The value of the query parameter if found, None otherwise. + """ + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + original_location = query_params.get(param_name) + if original_location: + return original_location[0] + else: + return None + + +def add_query_param(url: str, param_name: str, param_value: str) -> str: + """ + Adds a query parameter to a URL. + + Args: + url (str): The URL to add the parameter to. + param_name (str): The name of the query parameter. + param_value (str): The value of the query parameter. + + Returns: + str: The modified URL with the added query parameter. + """ + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + + query_params[param_name] = [param_value] + + new_query = urlencode(query_params, doseq=True) + new_url = urlunparse(parsed_url._replace(query=new_query)) + + return new_url + + +def remove_query_param(url: str, param_name: str) -> str: + """ + Removes a query parameter from a URL. + + Args: + url (str): The URL to remove the parameter from. + param_name (str): The name of the query parameter. + + Returns: + str: The modified URL with the parameter removed. + """ + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + + if param_name in query_params: + del query_params[param_name] + + new_query = urlencode(query_params, doseq=True) + new_url = urlunparse(parsed_url._replace(query=new_query)) + + return new_url