
Commit

local mirrors support
jstzwj committed Jul 19, 2024
1 parent 7933abe commit e962f2c
Showing 15 changed files with 362 additions and 60 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -159,6 +159,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

/mirrors_dir/
/model_dir/
/dataset_dir/
/repos/
1 change: 1 addition & 0 deletions assets/full_configs.toml
@@ -10,6 +10,7 @@ hf-lfs-netloc = "cdn-lfs.huggingface.co"
mirror-scheme = "http"
mirror-netloc = "localhost:8090"
mirror-lfs-netloc = "localhost:8090"
mirrors-path = ["./mirrors_dir"]

[accessibility]
offline = false
4 changes: 4 additions & 0 deletions olah/configs.py
@@ -92,6 +92,8 @@ def __init__(self, path: Optional[str] = None) -> None:
self.mirror_netloc: str = "localhost:8090"
self.mirror_lfs_netloc: str = "localhost:8090"

self.mirrors_path: List[str] = []

# accessibility
self.offline = False
self.proxy = OlahRuleList.from_list(DEFAULT_PROXY_RULES)
@@ -139,6 +141,8 @@ def read_toml(self, path: str) -> None:
"mirror-lfs-netloc", self.mirror_lfs_netloc
)

self.mirrors_path = basic.get("mirrors-path", self.mirrors_path)

if "accessibility" in config:
accessibility = config["accessibility"]
self.offline = accessibility.get("offline", self.offline)
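
The new `mirrors-path` list tells Olah which directories hold local git mirrors. A minimal lookup sketch follows; the `<repo_type>s/<org>/<repo>` on-disk layout and the `find_local_mirror` helper are assumptions for illustration, not code from this commit.

```python
# Hypothetical sketch: resolving a repository against the configured
# mirrors-path directories. The layout "<repo_type>s/<org>/<repo>" and this
# helper's name are assumptions, not part of the commit.
import os
from typing import List, Optional

def find_local_mirror(
    mirrors_path: List[str], repo_type: str, org: str, repo: str
) -> Optional[str]:
    for mirror_dir in mirrors_path:
        candidate = os.path.join(mirror_dir, f"{repo_type}s", org, repo)
        if os.path.isdir(candidate):  # expected to be a git clone of the repo
            return candidate
    return None

# Usage: find_local_mirror(config.mirrors_path, "model", "some-org", "some-repo")
```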
1 change: 1 addition & 0 deletions olah/constants.py
@@ -12,6 +12,7 @@
DEFAULT_LOGGER_DIR = "./logs"

from huggingface_hub.constants import (
REPO_TYPES_MAPPING,
HUGGINGFACE_CO_URL_TEMPLATE,
HUGGINGFACE_HEADER_X_REPO_COMMIT,
HUGGINGFACE_HEADER_X_LINKED_ETAG,
15 changes: 15 additions & 0 deletions olah/errors.py
@@ -0,0 +1,15 @@



from fastapi.responses import JSONResponse


def error_repo_not_found() -> JSONResponse:
return JSONResponse(
content={"error": "Repository not found"},
headers={
"x-error-code": "RepoNotFound",
"x-error-message": "Repository not found",
},
status_code=401,
)
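
A minimal sketch of how this helper might be used from a FastAPI route; the route path and the in-memory repo set are illustrative only. The 401 status with `x-error-code: RepoNotFound` mirrors the response the Hugging Face Hub gives for repositories a client cannot access.

```python
# Hypothetical usage sketch; the route and KNOWN_REPOS are illustrative,
# not part of this commit.
from fastapi import FastAPI
from olah.errors import error_repo_not_found

app = FastAPI()
KNOWN_REPOS = {"my-org/my-model"}  # pretend registry of mirrored repos

@app.get("/api/models/{org}/{repo}")
async def model_meta(org: str, repo: str):
    if f"{org}/{repo}" not in KNOWN_REPOS:
        # JSON body plus x-error-code/x-error-message headers, status 401
        return error_repo_not_found()
    return {"id": f"{org}/{repo}"}
```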
Empty file added olah/mirror/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions olah/mirror/meta.py
@@ -0,0 +1,40 @@


class RepoMeta(object):
def __init__(self) -> None:
self._id = None
self.id = None
self.author = None
self.sha = None
self.lastModified = None
self.private = False
self.gated = False
self.disabled = False
self.tags = []
self.description = ""
self.paperswithcode_id = None
self.downloads = 0
self.likes = 0
self.cardData = None
self.siblings = None
self.createdAt = None

def to_dict(self):
return {
"_id": self._id,
"id": self.id,
"author": self.author,
"sha": self.sha,
"lastModified": self.lastModified,
"private": self.private,
"gated": self.gated,
"disabled": self.disabled,
"tags": self.tags,
"description": self.description,
"paperswithcode_id": self.paperswithcode_id,
"downloads": self.downloads,
"likes": self.likes,
"cardData": self.cardData,
"siblings": self.siblings,
"createdAt": self.createdAt,
}
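
RepoMeta mirrors the fields of the Hub's repo-info API payload. For reference, a freshly constructed instance serializes as follows (a sketch; the values are simply the defaults set in `__init__`):

```python
from olah.mirror.meta import RepoMeta

print(RepoMeta().to_dict())
# {'_id': None, 'id': None, 'author': None, 'sha': None, 'lastModified': None,
#  'private': False, 'gated': False, 'disabled': False, 'tags': [],
#  'description': '', 'paperswithcode_id': None, 'downloads': 0, 'likes': 0,
#  'cardData': None, 'siblings': None, 'createdAt': None}
```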
167 changes: 167 additions & 0 deletions olah/mirror/repos.py
@@ -0,0 +1,167 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import hashlib
import io
import os
import re
from typing import Any, Dict, List, Optional, Union
import gitdb
from git import Commit, Repo, Tree
from gitdb.base import OStream
import yaml

from olah.mirror.meta import RepoMeta
class LocalMirrorRepo(object):
def __init__(self, path: str, repo_type: str, org: str, repo: str) -> None:
self._path = path
self._repo_type = repo_type
self._org = org
self._repo = repo

self._git_repo = Repo(self._path)

def _sha256(self, text: Union[str, bytes]) -> str:
if isinstance(text, bytes) or isinstance(text, bytearray):
bin = text
elif isinstance(text, str):
bin = text.encode('utf-8')
else:
raise Exception("Invalid sha256 param type.")
sha256_hash = hashlib.sha256()
sha256_hash.update(bin)
hashed_string = sha256_hash.hexdigest()
return hashed_string

def _match_card(self, readme: str) -> str:
pattern = r'\s*---(.*?)---'

match = re.match(pattern, readme, flags=re.S)

if match:
card_string = match.group(1)
return card_string
else:
return ""
def _remove_card(self, readme: str) -> str:
pattern = r'\s*---(.*?)---'
out = re.sub(pattern, "", readme, flags=re.S)
return out

def _get_readme(self, commit: Commit) -> str:
if "README.md" not in commit.tree:
return ""
else:
out: bytes = commit.tree["README.md"].data_stream.read()
return out.decode()

def _get_description(self, commit: Commit) -> str:
readme = self._get_readme(commit)
return self._remove_card(readme)

def _get_entry_files(self, tree, include_dir=False) -> List[str]:
out_paths = []
for entry in tree:
if entry.type == "tree":
out_paths.extend(self._get_entry_files(entry))
if include_dir:
out_paths.append(entry.path)
else:
out_paths.append(entry.path)
return out_paths

def _get_tree_files(self, commit: Commit) -> List[str]:
return self._get_entry_files(commit.tree)


def _get_earliest_commit(self) -> Commit:
earliest_commit = None
earliest_commit_date = None

for commit in self._git_repo.iter_commits():
commit_date = commit.committed_datetime

if earliest_commit_date is None or commit_date < earliest_commit_date:
earliest_commit = commit
earliest_commit_date = commit_date

return earliest_commit

    def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
try:
commit = self._git_repo.commit(commit_hash)
except gitdb.exc.BadName:
return None
meta = RepoMeta()

meta._id = self._sha256(f"{self._org}/{self._repo}/{commit.hexsha}")
meta.id = f"{self._org}/{self._repo}"
meta.author = self._org
meta.sha = commit.hexsha
meta.lastModified = self._git_repo.head.commit.committed_datetime.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
meta.private = False
meta.gated = False
meta.disabled = False
meta.tags = []
meta.description = self._get_description(commit)
meta.paperswithcode_id = None
meta.downloads = 0
meta.likes = 0
meta.cardData = yaml.load(self._match_card(self._get_readme(commit)), Loader=yaml.CLoader)
meta.siblings = [{"rfilename": p} for p in self._get_tree_files(commit)]
meta.createdAt = self._get_earliest_commit().committed_datetime.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
return meta.to_dict()

def _contain_path(self, path: str, tree: Tree) -> bool:
norm_p = os.path.normpath(path).replace("\\", "/")
parts = norm_p.split("/")
for part in parts:
if all([t.name != part for t in tree]):
return False
else:
entry = tree[part]
if entry.type == "tree":
tree = entry
else:
tree = {}
return True

def get_file_head(self, commit_hash: str, path: str) -> Optional[Dict[str, Any]]:
try:
commit = self._git_repo.commit(commit_hash)
except gitdb.exc.BadName:
return None

if not self._contain_path(path, commit.tree):
return None
else:
header = {}
header["content-length"] = str(commit.tree[path].data_stream.size)
header["x-repo-commit"] = commit.hexsha
header["etag"] = self._sha256(commit.tree[path].data_stream.read())
return header

def get_file(self, commit_hash: str, path: str) -> Optional[OStream]:
try:
commit = self._git_repo.commit(commit_hash)
except gitdb.exc.BadName:
return None

def stream_wrapper(file_bytes: bytes):
file_stream = io.BytesIO(file_bytes)
while True:
chunk = file_stream.read(4096)
if len(chunk) == 0:
break
else:
yield chunk

if not self._contain_path(path, commit.tree):
return None
else:
return stream_wrapper(commit.tree[path].data_stream.read())
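
A usage sketch for LocalMirrorRepo, assuming a local git clone under ./mirrors_dir; the path and file name below are illustrative, not taken from this commit.

```python
from olah.mirror.repos import LocalMirrorRepo

# Illustrative path: a git clone of the repository under a mirrors-path entry.
repo = LocalMirrorRepo(
    "./mirrors_dir/models/my-org/my-model", "model", "my-org", "my-model"
)

meta = repo.get_meta("main")  # None if the revision cannot be resolved
if meta is not None:
    print(meta["sha"], [s["rfilename"] for s in meta["siblings"]])

head = repo.get_file_head("main", "config.json")  # header dict, or None
stream = repo.get_file("main", "config.json")     # generator of 4 KiB chunks, or None
if stream is not None:
    data = b"".join(stream)
```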


Empty file added olah/proxy/__init__.py
Empty file.
82 changes: 56 additions & 26 deletions olah/files.py → olah/proxy/files.py
@@ -28,6 +28,22 @@
from olah.utils.file_utils import make_dirs
from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT

async def _write_cache_request(head_path: str, status_code: int, headers: Dict[str, str], content: bytes):
rq = {
"status_code": status_code,
"headers": headers,
"content": content.hex(),
}
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(rq, ensure_ascii=False))

async def _read_cache_request(head_path: str):
with open(head_path, "r", encoding="utf-8") as f:
rq = json.loads(f.read())

rq["content"] = bytes.fromhex(rq["content"])
return rq

async def _file_full_header(
app,
save_path: str,
@@ -38,12 +54,18 @@ async def _file_full_header(
headers: Dict[str, str],
allow_cache: bool,
) -> Tuple[int, Dict[str, str], bytes]:
if os.path.exists(head_path):
with open(head_path, "r", encoding="utf-8") as f:
response_headers = json.loads(f.read())
response_headers_dict = {k.lower():v for k, v in response_headers.items()}
else:
if not app.app_settings.config.offline:
assert method.lower() == "head"
if not app.app_settings.config.offline:
if os.path.exists(head_path):
cache_rq = await _read_cache_request(head_path)
response_headers_dict = {k.lower():v for k, v in cache_rq["headers"].items()}
if "location" in response_headers_dict:
parsed_url = urlparse(response_headers_dict["location"])
if len(parsed_url.netloc) != 0:
new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response_headers_dict["location"]))
response_headers_dict["location"] = new_loc
return cache_rq["status_code"], response_headers_dict, cache_rq["content"]
else:
if "range" in headers:
headers.pop("range")
response = await client.request(
@@ -55,11 +77,9 @@
response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
if allow_cache and method.lower() == "head":
if response.status_code == 200:
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
await _write_cache_request(head_path, response.status_code, response_headers_dict, response.content)
elif response.status_code >= 300 and response.status_code <= 399:
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
await _write_cache_request(head_path, response.status_code, response_headers_dict, response.content)
from_url = urlparse(url)
parsed_url = urlparse(response.headers["location"])
if len(parsed_url.netloc) != 0:
@@ -68,25 +88,34 @@
else:
raise Exception(f"Unexpected HTTP status code {response.status_code}")
return response.status_code, response_headers_dict, response.content
else:
if os.path.exists(head_path):
cache_rq = await _read_cache_request(head_path)
response_headers_dict = {k.lower():v for k, v in cache_rq["headers"].items()}
else:
response_headers_dict = {}
cache_rq = {
"status_code": 200,
"headers": response_headers_dict,
"content": b"",
}

new_headers = {}
if "content-type" in response_headers_dict:
new_headers["content-type"] = response_headers_dict["content-type"]
if "content-length" in response_headers_dict:
new_headers["content-length"] = response_headers_dict["content-length"]
if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
if "etag" in response_headers_dict:
new_headers["etag"] = response_headers_dict["etag"]
if "location" in response_headers_dict:
new_headers["location"] = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response_headers_dict["location"]))
return 200, new_headers, b""
new_headers = {}
if "content-type" in response_headers_dict:
new_headers["content-type"] = response_headers_dict["content-type"]
if "content-length" in response_headers_dict:
new_headers["content-length"] = response_headers_dict["content-length"]
if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
if "etag" in response_headers_dict:
new_headers["etag"] = response_headers_dict["etag"]
if "location" in response_headers_dict:
new_headers["location"] = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response_headers_dict["location"]))
return cache_rq["status_code"], new_headers, cache_rq["content"]

async def _get_file_block_from_cache(cache_file: OlahCache, block_index: int):
raw_block = cache_file.read_block(block_index)
@@ -240,6 +269,7 @@ async def _file_realtime_stream(
yield head_info
yield content
return

file_size = int(head_info["content-length"])
response_headers = {k: v for k,v in head_info.items()}
if "range" in request_headers:
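
The `_write_cache_request`/`_read_cache_request` helpers added above change the head-cache format: instead of a bare headers dict, the cached JSON now records the status code, the headers, and the hex-encoded body, so redirects and small contents survive a round trip. A minimal round-trip sketch, assuming the module path after this commit; the file name is illustrative.

```python
import asyncio

from olah.proxy.files import _read_cache_request, _write_cache_request

async def demo():
    # Write a fake upstream HEAD response, then read it back.
    await _write_cache_request(
        "head_main.json", 200, {"content-length": "5", "etag": "abc"}, b"hello"
    )
    cached = await _read_cache_request("head_main.json")
    assert cached["status_code"] == 200
    assert cached["content"] == b"hello"  # decoded back from hex

asyncio.run(demo())
```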
2 changes: 1 addition & 1 deletion olah/lfs.py → olah/proxy/lfs.py
Expand Up @@ -9,7 +9,7 @@
from typing import Literal
from fastapi import FastAPI, Header, Request

from olah.files import _file_realtime_stream
from olah.proxy.files import _file_realtime_stream
from olah.utils.file_utils import make_dirs
from olah.utils.url_utils import check_cache_rules_hf, get_org_repo

Expand Down
File renamed without changes.
(Diffs for the remaining changed files are not shown here.)
