diff --git a/README_zh.md b/README_zh.md
index be6b8b9..288078a 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -40,7 +40,7 @@ pip install -e .
 python -m olah.server
 ```
 
-Then set the environment variable `HF_ENDPOINT` to the mirror site (here it is http://localhost:8090).
+Then set the environment variable `HF_ENDPOINT` to the mirror site (here it is http://localhost:8090/).
 
 Linux:
 ```bash
diff --git a/assets/full_configs.toml b/assets/full_configs.toml
index 16faa55..d47ba69 100644
--- a/assets/full_configs.toml
+++ b/assets/full_configs.toml
@@ -4,10 +4,12 @@ port = 8090
 ssl-key = ""
 ssl-cert = ""
 repos-path = "./repos"
-hf-url = "https://huggingface.co"
-hf-lfs-url = "https://cdn-lfs.huggingface.co"
-mirror-url = "http://localhost:8090"
-mirror-lfs-url = "http://localhost:8090"
+hf-scheme = "https"
+hf-netloc = "huggingface.co"
+hf-lfs-netloc = "cdn-lfs.huggingface.co"
+mirror-scheme = "http"
+mirror-netloc = "localhost:8090"
+mirror-lfs-netloc = "localhost:8090"
 
 [accessibility]
 [[accessibility.proxy]]
diff --git a/olah/configs.py b/olah/configs.py
index 23bbd62..ae96b1d 100644
--- a/olah/configs.py
+++ b/olah/configs.py
@@ -91,10 +91,14 @@ def __init__(self, path: Optional[str] = None) -> None:
         self.ssl_key = None
         self.ssl_cert = None
         self.repos_path = "./repos"
-        self.hf_url = "https://huggingface.co"
-        self.hf_lfs_url = "https://cdn-lfs.huggingface.co"
-        self.mirror_url = f"http://{self.host}:{self.port}"
-        self.mirror_lfs_url = f"http://{self.host}:{self.port}"
+
+        self.hf_scheme: str = "https"
+        self.hf_netloc: str = "huggingface.co"
+        self.hf_lfs_netloc: str = "cdn-lfs.huggingface.co"
+
+        self.mirror_scheme: str = "http"
+        self.mirror_netloc: str = "localhost:8090"
+        self.mirror_lfs_netloc: str = "localhost:8090"
 
         # accessibility
         self.offline = False
@@ -103,12 +107,19 @@ def __init__(self, path: Optional[str] = None) -> None:
 
         if path is not None:
             self.read_toml(path)
-
-        # refresh urls
-        self.mirror_url = f"http://{self.host}:{self.port}"
-        self.mirror_lfs_url = f"http://{self.host}:{self.port}"
-
+    def hf_url_base(self) -> str:
+        return f"{self.hf_scheme}://{self.hf_netloc}"
+
+    def hf_lfs_url_base(self) -> str:
+        return f"{self.hf_scheme}://{self.hf_lfs_netloc}"
+
+    def mirror_url_base(self) -> str:
+        return f"{self.mirror_scheme}://{self.mirror_netloc}"
+
+    def mirror_lfs_url_base(self) -> str:
+        return f"{self.mirror_scheme}://{self.mirror_lfs_netloc}"
+
     def empty_str(self, s: str) -> Optional[str]:
         if s == "":
             return None
 
@@ -125,10 +136,14 @@ def read_toml(self, path: str) -> None:
         self.ssl_key = self.empty_str(basic.get("ssl-key", self.ssl_key))
         self.ssl_cert = self.empty_str(basic.get("ssl-cert", self.ssl_cert))
         self.repos_path = basic.get("repos-path", self.repos_path)
-        self.hf_url = basic.get("hf-url", self.hf_url)
-        self.hf_lfs_url = basic.get("hf-lfs-url", self.hf_lfs_url)
-        self.mirror_url = basic.get("mirror-url", self.mirror_url)
-        self.mirror_lfs_url = basic.get("mirror-lfs-url", self.mirror_lfs_url)
+
+        self.hf_scheme = basic.get("hf-scheme", self.hf_scheme)
+        self.hf_netloc = basic.get("hf-netloc", self.hf_netloc)
+        self.hf_lfs_netloc = basic.get("hf-lfs-netloc", self.hf_lfs_netloc)
+
+        self.mirror_scheme = basic.get("mirror-scheme", self.mirror_scheme)
+        self.mirror_netloc = basic.get("mirror-netloc", self.mirror_netloc)
+        self.mirror_lfs_netloc = basic.get("mirror-lfs-netloc", self.mirror_lfs_netloc)
 
     if "accessibility" in config:
         accessibility = config["accessibility"]
diff --git a/olah/files.py b/olah/files.py
index 1b32a56..44b56a6 100644
--- a/olah/files.py
+++ b/olah/files.py
@@ -2,11 +2,13 @@
 import os
 import shutil
 import tempfile
-from typing import Literal
+from typing import Literal, Optional
 from fastapi import Request
+from requests.structures import CaseInsensitiveDict
 import httpx
 from starlette.datastructures import URL
+from urllib.parse import urlparse, urljoin
 
 from olah.constants import (
     CHUNK_SIZE,
@@ -15,7 +17,8 @@ HUGGINGFACE_HEADER_X_LINKED_ETAG,
     HUGGINGFACE_HEADER_X_LINKED_SIZE,
 )
-from olah.utils import check_cache_rules_hf, get_org_repo, make_dirs
+from olah.utils.url_utils import check_cache_rules_hf, get_org_repo, get_url_tail
+from olah.utils.file_utils import make_dirs
 
 FILE_HEADER_TEMPLATE = {
     "accept-ranges": "bytes",
     "access-control-allow-origin": "*",
@@ -30,12 +33,16 @@ async def _file_cache_stream(save_path: str, head_path: str, request: Request):
     if request.method.lower() == "head":
         with open(head_path, "r", encoding="utf-8") as f:
             response_headers = json.loads(f.read())
-            new_headers = {k:v for k, v in FILE_HEADER_TEMPLATE.items()}
+            response_headers = {k.lower():v for k, v in response_headers.items()}
+            new_headers = {k.lower():v for k, v in FILE_HEADER_TEMPLATE.items()}
             new_headers["content-type"] = response_headers["content-type"]
             new_headers["content-length"] = response_headers["content-length"]
-            new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT] = response_headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT, None)
-            new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG, None)
-            new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE, None)
+            if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers:
+                new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
+            if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers:
+                new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
+            if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers:
+                new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
             new_headers["etag"] = response_headers["etag"]
             yield new_headers
     elif request.method.lower() == "get":
@@ -49,8 +56,26 @@ async def _file_cache_stream(save_path: str, head_path: str, request: Request):
             break
         yield chunk
 
+async def _get_redirected_url(client, method: str, url, headers):
+    async with client.stream(
+        method=method,
+        url=url,
+        headers=headers,
+        timeout=WORKER_API_TIMEOUT,
+    ) as response:
+        if response.status_code >= 300 and response.status_code <= 399:
+            from_url = urlparse(url)
+            parsed_url = urlparse(response.headers["location"])
+            if len(parsed_url.netloc) == 0:
+                redirect_loc = urljoin(f"{from_url.scheme}://{from_url.netloc}", response.headers["location"])
+            else:
+                redirect_loc = response.headers["location"]
+        else:
+            redirect_loc = url
+    return redirect_loc
+
 async def _file_realtime_stream(
-    app, save_path: str, head_path: str, url: str, request: Request, method="GET", allow_cache=True
+    app, save_path: str, head_path: str, url: str, request: Request, method="GET", allow_cache=True, commit: Optional[str]=None
 ):
     request_headers = {k: v for k, v in request.headers.items()}
     request_headers.pop("host")
@@ -65,33 +90,26 @@ async def _file_realtime_stream(
         else:
             write_temp_file = True
 
-        async with client.stream(
-            method=method,
-            url=url,
-            headers=request_headers,
-            timeout=WORKER_API_TIMEOUT,
-        ) as response:
-            if response.status_code >= 300 and response.status_code <= 399:
-                redirect_loc = app.app_settings.hf_url + response.headers["location"]
-            else:
-                redirect_loc = url
-
+        redirect_loc = await _get_redirected_url(client, method, url, request_headers)
         async with client.stream(
             method=method,
             url=redirect_loc,
             headers=request_headers,
             timeout=WORKER_API_TIMEOUT,
         ) as response:
-            response_headers = response.headers
-            response_headers_dict = {k: v for k, v in response_headers.items()}
+            response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
             if allow_cache:
                 if request.method.lower() == "head":
                     with open(head_path, "w", encoding="utf-8") as f:
                         f.write(json.dumps(response_headers_dict, ensure_ascii=False))
             if "location" in response_headers_dict:
-                response_headers_dict["location"] = response_headers_dict["location"].replace(
-                    app.app_settings.hf_lfs_url, app.app_settings.mirror_lfs_url
-                )
+                location_url = urlparse(response_headers_dict["location"])
+                if location_url.netloc == app.app_settings.config.hf_lfs_netloc:
+                    response_headers_dict["location"] = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(location_url))
+                else:
+                    response_headers_dict["location"] = urljoin(app.app_settings.config.mirror_url_base(), get_url_tail(location_url))
+            if commit is not None:
+                response_headers_dict[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
             yield response_headers_dict
 
             async for raw_chunk in response.aiter_raw():
@@ -137,9 +155,9 @@ async def file_head_generator(
             return _file_cache_stream(save_path=save_path, head_path=head_path, request=request)
         else:
             if repo_type == "models":
-                url = f"{app.app_settings.hf_url}/{org_repo}/resolve/{commit}/{file_path}"
+                url = urljoin(app.app_settings.config.hf_url_base(), f"/{org_repo}/resolve/{commit}/{file_path}")
             else:
-                url = f"{app.app_settings.hf_url}/{repo_type}/{org_repo}/resolve/{commit}/{file_path}"
+                url = urljoin(app.app_settings.config.hf_url_base(), f"/{repo_type}/{org_repo}/resolve/{commit}/{file_path}")
             return _file_realtime_stream(
                 app=app,
                 save_path=save_path,
@@ -148,6 +166,7 @@ async def file_head_generator(
                 request=request,
                 method="HEAD",
                 allow_cache=allow_cache,
+                commit=commit,
             )
 
 
@@ -180,9 +199,9 @@ async def file_get_generator(
             return _file_cache_stream(save_path=save_path, head_path=head_path, request=request)
         else:
             if repo_type == "models":
-                url = f"{app.app_settings.hf_url}/{org_repo}/resolve/{commit}/{file_path}"
+                url = urljoin(app.app_settings.config.hf_url_base(), f"/{org_repo}/resolve/{commit}/{file_path}")
             else:
-                url = f"{app.app_settings.hf_url}/{repo_type}/{org_repo}/resolve/{commit}/{file_path}"
+                url = urljoin(app.app_settings.config.hf_url_base(), f"/{repo_type}/{org_repo}/resolve/{commit}/{file_path}")
             return _file_realtime_stream(
                 app=app,
                 save_path=save_path,
@@ -191,6 +210,7 @@ async def file_get_generator(
                 request=request,
                 method="GET",
                 allow_cache=allow_cache,
+                commit=commit,
             )
 
 
@@ -224,8 +244,11 @@ async def cdn_file_get_generator(
         if use_cache:
             return _file_cache_stream(save_path=save_path, request=request)
         else:
-            redirected_url = str(request.url)
-            redirected_url = redirected_url.replace(app.app_settings.mirror_lfs_url, app.app_settings.hf_lfs_url)
+            request_url = urlparse(str(request.url))
+            if request_url.netloc == app.app_settings.config.mirror_lfs_netloc:
+                redirected_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(request_url))
+            else:
+                redirected_url = urljoin(app.app_settings.config.hf_url_base(), get_url_tail(request_url))
 
         return _file_realtime_stream(
             app=app,
diff --git a/olah/lfs.py b/olah/lfs.py
index 113e03f..0137bca 100644
--- a/olah/lfs.py
+++ b/olah/lfs.py
@@ -1,3 +1,7 @@
+"""
+Deprecated methods
+"""
+
 import datetime
 import json
 import os
@@ -8,7 +12,7 @@
 import pytz
 
 from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT
-from olah.utils import make_dirs
+from olah.utils.file_utils import make_dirs
 
 
 async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request):
diff --git a/olah/meta.py b/olah/meta.py
index afffd2a..3539067 100644
--- a/olah/meta.py
+++ b/olah/meta.py
@@ -5,13 +5,15 @@
 import shutil
 import tempfile
 from typing import Dict, Literal
+from urllib.parse import urljoin
 from fastapi import FastAPI, Request
 import httpx
 
 from olah.configs import OlahConfig
 from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT
-from olah.utils import check_cache_rules_hf, get_org_repo, make_dirs
+from olah.utils.url_utils import check_cache_rules_hf, get_org_repo
+from olah.utils.file_utils import make_dirs
 
 
 async def meta_cache_generator(app: FastAPI, save_path: str):
     yield {}
@@ -67,7 +69,7 @@ async def meta_generator(app: FastAPI, repo_type: Literal["models", "datasets"],
     allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
 
     org_repo = get_org_repo(org, repo)
-    meta_url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}/revision/{commit}"
+    meta_url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}")
     # proxy
     if use_cache:
         async for item in meta_cache_generator(app, save_path):
diff --git a/olah/server.py b/olah/server.py
index 61ae5b4..7fa249a 100644
--- a/olah/server.py
+++ b/olah/server.py
@@ -5,6 +5,7 @@
 import tempfile
 import shutil
 from typing import Annotated, Optional, Union
+from urllib.parse import urljoin
 from fastapi import FastAPI, Header, Request
 from fastapi.responses import HTMLResponse, StreamingResponse, Response
 import httpx
@@ -13,7 +14,7 @@
 from olah.files import cdn_file_get_generator, file_get_generator, file_head_generator
 from olah.lfs import lfs_get_generator
 from olah.meta import meta_generator
-from olah.utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo
+from olah.utils.url_utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo
 
 app = FastAPI(debug=False)
 
@@ -21,11 +22,10 @@ class AppSettings(BaseSettings):
     # The address of the model controller.
config: OlahConfig = OlahConfig() repos_path: str = "./repos" - hf_url: str = "https://huggingface.co" - hf_lfs_url: str = "https://cdn-lfs.huggingface.co" - mirror_url: str = "http://localhost:8090" - mirror_lfs_url: str = "http://localhost:8090" +# ====================== +# API Hooks +# ====================== @app.get("/api/{repo_type}/{org_repo}") async def meta_proxy(repo_type: str, org_repo: str, request: Request): org, repo = parse_org_repo(org_repo) @@ -65,6 +65,10 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) + +# ====================== +# File Head Hooks +# ====================== @app.head("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}") async def file_head_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str): if not await check_proxy_rules_hf(app, repo_type, org, repo): @@ -113,6 +117,9 @@ async def file_head_proxy_default_type(org_repo: str, commit: str, file_path: st headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) +# ====================== +# File Hooks +# ====================== @app.get("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}") async def file_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str): if not await check_proxy_rules_hf(app, repo_type, org, repo): @@ -160,7 +167,6 @@ async def file_proxy_default_type(org_repo: str, commit: str, file_path: str, re headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) - @app.get("/{repo_type}/{org_repo}/{hash_file}") async def cdn_file_proxy(org_repo: str, hash_file: str, request: Request, repo_type: str = "models"): org, repo = parse_org_repo(org_repo) @@ -174,11 +180,13 @@ async def cdn_file_proxy(org_repo: str, hash_file: str, request: Request, repo_t headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) - +# ====================== +# LFS Hooks +# ====================== @app.get("/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}") async def lfs_proxy(dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request): repo_type = "models" - lfs_url = f"{app.app_settings.hf_lfs_url}/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}" + lfs_url = urljoin(app.app_settings.config.hf_lfs_url_base(), f"/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}") save_path = f"{dir1}/{dir2}/{hash_repo}/{hash_file}" generator = lfs_get_generator(app, repo_type, lfs_url, save_path, request) headers = await generator.__anext__() @@ -187,12 +195,15 @@ async def lfs_proxy(dir1: str, dir2: str, hash_repo: str, hash_file: str, reques @app.get("/datasets/hendrycks_test/{hash_file}") async def lfs_proxy(hash_file: str, request: Request): repo_type = "datasets" - lfs_url = f"{app.app_settings.hf_lfs_url}/datasets/hendrycks_test/{hash_file}" + lfs_url = urljoin(app.app_settings.config.hf_lfs_url_base(), f"/datasets/hendrycks_test/{hash_file}") save_path = f"hendrycks_test/{hash_file}" generator = lfs_get_generator(app, repo_type, lfs_url, save_path, request) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) +# ====================== +# Web Page Hooks +# ====================== @app.get("/", response_class=HTMLResponse) async def index(): with open(os.path.join(os.path.dirname(__file__), "../static/index.html"), "r", encoding="utf-8") as f: @@ -209,10 +220,6 @@ async def index(): 
parser.add_argument("--ssl-key", type=str, default=None) parser.add_argument("--ssl-cert", type=str, default=None) parser.add_argument("--repos-path", type=str, default="./repos") - parser.add_argument("--hf-url", type=str, default="https://huggingface.co") - parser.add_argument("--hf-lfs-url", type=str, default="https://cdn-lfs.huggingface.co") - parser.add_argument("--mirror-url", type=str, default="http://localhost:8090") - parser.add_argument("--mirror-lfs-url", type=str, default="http://localhost:8090") args = parser.parse_args() def is_default_value(args, arg_name): @@ -237,22 +244,10 @@ def is_default_value(args, arg_name): args.ssl_cert = config.ssl_cert if is_default_value(args, "repos_path"): args.repos_path = config.repos_path - if is_default_value(args, "hf_url"): - args.hf_url = config.hf_url - if is_default_value(args, "hf_lfs_url"): - args.hf_lfs_url = config.hf_lfs_url - if is_default_value(args, "mirror_url"): - args.mirror_url = config.mirror_url - if is_default_value(args, "mirror_lfs_url"): - args.mirror_lfs_url = config.mirror_lfs_url app.app_settings = AppSettings( config=config, repos_path=args.repos_path, - hf_url=args.hf_url, - hf_lfs_url=args.hf_lfs_url, - mirror_url=args.mirror_url, - mirror_lfs_url=args.mirror_lfs_url, ) import uvicorn diff --git a/olah/utils/__init__.py b/olah/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/olah/utils/bitset.py b/olah/utils/bitset.py new file mode 100644 index 0000000..741ee7e --- /dev/null +++ b/olah/utils/bitset.py @@ -0,0 +1,29 @@ +class Bitset: + def __init__(self, size): + self.size = size + self.bits = bytearray((0, ) * ((size + 7) // 8)) + + def set(self, index): + if index < 0 or index >= self.size: + raise IndexError("Index out of range") + byte_index = index // 8 + bit_index = index % 8 + self.bits[byte_index] |= (1 << bit_index) + + def clear(self, index): + if index < 0 or index >= self.size: + raise IndexError("Index out of range") + self._resize_if_needed(index) + byte_index = index // 8 + bit_index = index % 8 + self.bits[byte_index] &= ~(1 << bit_index) + + def test(self, index): + if index < 0 or index >= self.size: + raise IndexError("Index out of range") + byte_index = index // 8 + bit_index = index % 8 + return bool(self.bits[byte_index] & (1 << bit_index)) + + def __str__(self): + return ''.join(bin(byte)[2:].zfill(8) for byte in self.bits) diff --git a/olah/utils/file_utils.py b/olah/utils/file_utils.py new file mode 100644 index 0000000..447f36b --- /dev/null +++ b/olah/utils/file_utils.py @@ -0,0 +1,12 @@ + + +import os + + +def make_dirs(path: str): + if os.path.isdir(path): + save_dir = path + else: + save_dir = os.path.dirname(path) + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) \ No newline at end of file diff --git a/olah/utils/olah_cache.py b/olah/utils/olah_cache.py new file mode 100644 index 0000000..6c699ef --- /dev/null +++ b/olah/utils/olah_cache.py @@ -0,0 +1,185 @@ +from io import BufferedReader +import os +import struct +from typing import BinaryIO, Optional +from .bitset import Bitset + +CURRENT_OLAH_CACHE_VERSION = 8 +DEFAULT_BLOCK_MASK_MAX = 1024 * 1024 +DEFAULT_BLOCK_SIZE = 64 * 1024 * 1024 + + +class OlahCacheHeader(object): + MAGIC_NUMBER = "OLAH".encode("ascii") + HEADER_FIX_SIZE = 36 + + def __init__( + self, + version: int = CURRENT_OLAH_CACHE_VERSION, + block_size: int = DEFAULT_BLOCK_SIZE, + file_size: int = 0, + ) -> None: + self.version = version + self.block_size = block_size + + self.file_size = file_size + 
+        self.block_number = (file_size + block_size - 1) // block_size
+
+        self.block_mask_size = DEFAULT_BLOCK_MASK_MAX
+        self.block_mask = Bitset(DEFAULT_BLOCK_MASK_MAX)
+
+    def get_header_size(self):
+        return self.HEADER_FIX_SIZE + len(self.block_mask.bits)
+
+    def _valid_header(self):
+        if self.file_size > self.block_mask_size * self.block_size:
+            raise Exception(
+                f"The size of file {self.file_size} is out of the max capability of container ({self.block_mask_size} * {self.block_size})."
+            )
+
+    @staticmethod
+    def read(stream) -> "OlahCacheHeader":
+        obj = OlahCacheHeader()
+        magic, version, block_size, file_size, block_mask_size = struct.unpack(
+            "<4sQQQQ", stream.read(OlahCacheHeader.HEADER_FIX_SIZE)
+        )
+        if magic != OlahCacheHeader.MAGIC_NUMBER:
+            raise Exception("The file is not a valid olah cache file.")
+        obj.version = version
+        obj.block_size = block_size
+        obj.file_size = file_size
+        obj.block_number = (file_size + block_size - 1) // block_size
+        obj.block_mask_size = block_mask_size
+        obj.block_mask = Bitset(block_mask_size)
+        obj.block_mask.bits = bytearray(stream.read((block_mask_size + 7) // 8))
+
+        obj._valid_header()
+        return obj
+
+    def write(self, stream):
+        bytes_header = struct.pack(
+            "<4sQQQQ",
+            self.MAGIC_NUMBER,
+            self.version,
+            self.block_size,
+            self.file_size,
+            self.block_mask_size,
+        )
+        bytes_out = bytes_header + self.block_mask.bits
+        stream.write(bytes_out)
+
+class OlahCache(object):
+    def __init__(self, path: str, block_size: int = DEFAULT_BLOCK_SIZE) -> None:
+        self.path: Optional[str] = path
+        self.header: Optional[OlahCacheHeader] = None
+        self.is_open: bool = False
+        self.open(path, block_size=block_size)
+
+    @staticmethod
+    def create(path: str, block_size: int = DEFAULT_BLOCK_SIZE):
+        return OlahCache(path, block_size=block_size)
+
+    def open(self, path: str, block_size: int = DEFAULT_BLOCK_SIZE):
+        if self.is_open:
+            raise Exception("This file is already open.")
+        if os.path.exists(path):
+            with open(path, "rb") as f:
+                f.seek(0)
+                self.header = OlahCacheHeader.read(f)
+        else:
+            # Create new file
+            with open(path, "wb") as f:
+                f.seek(0)
+                self.header = OlahCacheHeader(
+                    version=CURRENT_OLAH_CACHE_VERSION,
+                    block_size=block_size,
+                    file_size=0,
+                )
+                self.header.write(f)
+
+        self.is_open = True
+
+    def close(self):
+        if not self.is_open:
+            raise Exception("This file has already been closed.")
+
+        self._flush_header()
+        self.path = None
+        self.header = None
+
+        self.is_open = False
+
+    def _flush_header(self):
+        with open(self.path, "rb+") as f:
+            f.seek(0)
+            self.header.write(f)
+
+    def flush(self):
+        if not self.is_open:
+            raise Exception("This file has already been closed.")
+        self._flush_header()
+
+    def _has_block(self, block_index: int) -> bool:
+        return self.header.block_mask.test(block_index)
+
+    def _read_block(self, block_index: int) -> Optional[bytes]:
+        if not self.is_open:
+            raise Exception("This file has been closed.")
+
+        if block_index >= self.header.block_number:
+            raise Exception("Invalid block index.")
+
+        if not self._has_block(block_index=block_index):
+            return None
+
+        offset = self.header.get_header_size() + (block_index * self.header.block_size)
+        with open(self.path, "rb") as f:
+            f.seek(offset)
+            return f.read(self.header.block_size)
+
+    def _write_block(self, block_index: int, block_bytes: bytes) -> None:
+        if not self.is_open:
+            raise Exception("This file has been closed.")
+
+        if block_index >= self.header.block_number:
+            raise Exception("Invalid block index.")
+
+        if len(block_bytes) != self.header.block_size:
+            raise Exception("Block
size does not match the cache's block size.") + + offset = self.header.get_header_size() + (block_index * self.header.block_size) + with open(self.path, "rb+") as f: + f.seek(offset) + f.write(block_bytes) + + self.header.block_mask.set(block_index) + + def _resize_blocks(self, block_num: int): + if not self.is_open: + raise Exception("This file has been closed.") + if block_num == self.header.block_number: + return + if block_num <= self.header.block_number: + raise Exception("Invalid block number. New block number must be greater than the current block number.") + + with open(self.path, "rb") as f: + f.seek(0, os.SEEK_END) + bin_size = f.tell() + + new_bin_size = self.header.get_header_size() + block_num * self.header.block_size + with open(self.path, "rb+") as f: + # Extend file size + f.seek(0, os.SEEK_END) + f.write(b'\x00' * (new_bin_size - bin_size)) + + def resize(self, file_size: int): + if not self.is_open: + raise Exception("This file has been closed.") + new_block_num = (file_size + self.header.block_size - 1) // self.header.block_size + self._resize_blocks(new_block_num) + + self.header.block_number = new_block_num + self.header.file_size = file_size + self.header._valid_header() + self._flush_header() + \ No newline at end of file diff --git a/olah/utils.py b/olah/utils/url_utils.py similarity index 79% rename from olah/utils.py rename to olah/utils/url_utils.py index 5b29ce6..575af8a 100644 --- a/olah/utils.py +++ b/olah/utils/url_utils.py @@ -4,6 +4,7 @@ import glob from typing import Literal, Optional, Tuple import json +from urllib.parse import ParseResult, urljoin import httpx from olah.configs import OlahConfig from olah.constants import WORKER_API_TIMEOUT @@ -50,7 +51,7 @@ async def get_newest_commit_hf_offline(app, repo_type: Optional[Literal["models" return time_revisions[-1][1] async def get_newest_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str) -> str: - url = f"{app.app_settings.hf_url}/api/{repo_type}/{org}/{repo}" + url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org}/{repo}") if app.app_settings.config.offline: return get_newest_commit_hf_offline(app, repo_type, org, repo) try: @@ -74,7 +75,7 @@ async def get_commit_hf_offline(app, repo_type: Optional[Literal["models", "data async def get_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: str) -> str: org_repo = get_org_repo(org, repo) - url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}/revision/{commit}" + url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}") if app.app_settings.config.offline: return await get_commit_hf_offline(app, repo_type, org, repo, commit) try: @@ -90,9 +91,9 @@ async def get_commit_hf(app, repo_type: Optional[Literal["models", "datasets", " async def check_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: Optional[str]=None) -> bool: org_repo = get_org_repo(org, repo) if commit is None: - url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}" + url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}") else: - url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}/revision/{commit}" + url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}") async with httpx.AsyncClient() as client: response = await client.get(url, 
timeout=WORKER_API_TIMEOUT) @@ -108,10 +109,27 @@ async def check_cache_rules_hf(app, repo_type: Optional[Literal["models", "datas org_repo = get_org_repo(org, repo) return config.cache.allow(f"{org_repo}") -def make_dirs(path: str): - if os.path.isdir(path): - save_dir = path +def get_url_tail(parsed_url: ParseResult) -> str: + url_tail = parsed_url.path + if len(parsed_url.params) != 0: + url_tail += f";{parsed_url.params}" + if len(parsed_url.query) != 0: + url_tail += f"?{parsed_url.query}" + if len(parsed_url.fragment) != 0: + url_tail += f"#{parsed_url.fragment}" + return url_tail + +def parse_range_params(file_range: str, file_size: int) -> Tuple[int, int]: + # 'bytes=1887436800-' + if file_range.startswith("bytes="): + file_range = file_range[6:] + start_pos, end_pos = file_range.split("-") + if len(start_pos) != 0: + start_pos = int(start_pos) else: - save_dir = os.path.dirname(path) - if not os.path.exists(save_dir): - os.makedirs(save_dir, exist_ok=True) \ No newline at end of file + start_pos = 0 + if len(end_pos) != 0: + end_pos = int(end_pos) + else: + end_pos = file_size + return start_pos, end_pos
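
Reviewer note (not part of the patch): two short sketches of how the pieces introduced above compose. Module paths follow this diff; the URL, file name, and block sizes are made-up examples.

The first sketch mirrors the redirect-rewriting logic in `olah/files.py`: a `Location` header returned by the upstream CDN is re-based onto the mirror, with `get_url_tail()` preserving path, params, query, and fragment while `urljoin()` swaps in the mirror's scheme and netloc.

```python
from urllib.parse import urlparse, urljoin

from olah.configs import OlahConfig
from olah.utils.url_utils import get_url_tail

config = OlahConfig()  # defaults: hf-lfs-netloc=cdn-lfs.huggingface.co, mirror-netloc=localhost:8090

# Hypothetical redirect target returned by the upstream for an LFS object:
location = "https://cdn-lfs.huggingface.co/repos/ab/cd/0123abcd?X-Amz-Expires=3600"

parsed = urlparse(location)
if parsed.netloc == config.hf_lfs_netloc:
    # Keep the client on the mirror; only the scheme://netloc changes.
    rewritten = urljoin(config.mirror_lfs_url_base(), get_url_tail(parsed))
else:
    rewritten = urljoin(config.mirror_url_base(), get_url_tail(parsed))

print(rewritten)  # http://localhost:8090/repos/ab/cd/0123abcd?X-Amz-Expires=3600
```

The second sketch exercises the new `OlahCache` container: a fixed-size header (magic, version, block size, file size, block-mask size) followed by a bitmask recording which fixed-size blocks have been fetched, then the block data itself. `_read_block`/`_write_block` are leading-underscore internals, so a streaming layer inside olah would presumably drive them rather than user code.

```python
from olah.utils.olah_cache import OlahCache

cache = OlahCache.create("demo.cache", block_size=1024)  # tiny blocks for the demo
cache.resize(file_size=4096)              # grows the backing file to hold 4 blocks

block = b"x" * 1024                       # writes must be exactly block_size bytes
cache._write_block(0, block)              # also sets bit 0 in the block mask
assert cache._read_block(0) == block
assert cache._read_block(1) is None       # blocks not fetched yet read as None
cache.close()
```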