Commit 8ad38c1: bug fix
jstzwj committed Jul 14, 2024
1 parent e9d9ba8 commit 8ad38c1
Showing 8 changed files with 109 additions and 70 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -1,5 +1,7 @@
# olah
Olah is a self-hosted, lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian.
Olah implements true `mirroring` of huggingface resources, rather than acting as a simple `reverse proxy`.
Olah does not mirror the entire huggingface website up front; instead, it mirrors (or, one could say, caches) resources at the file-block level as users download them.

Other languages: [中文](README_zh.md)
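
Since the mirror is intended to speak the same API as huggingface.co, the official client can use it unchanged. A minimal usage sketch, assuming Olah is serving at `http://localhost:8090` (the netloc used in `assets/full_configs.toml`):

```python
import os

# huggingface_hub reads HF_ENDPOINT at import time, so set it before importing.
os.environ["HF_ENDPOINT"] = "http://localhost:8090"

from huggingface_hub import snapshot_download

# The first download streams through Olah and is cached block by block;
# later downloads of the same files are served from the mirror's cache.
snapshot_download(repo_id="cais/mmlu", repo_type="dataset")
```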

5 changes: 4 additions & 1 deletion README_zh.md
@@ -1,4 +1,7 @@
Olah is a self-hosted, lightweight HuggingFace mirror service. `Olah` means `hello` in Hilichurlian.
# olah
Olah is a self-hosted, lightweight HuggingFace mirror service. `Olah` comes from Hilichurlian, in which it means `hello`.
Olah truly implements `mirroring` of huggingface resources, rather than being just a simple `reverse proxy`.
Olah does not mirror the entire huggingface site immediately; instead, it mirrors (or, one could say, caches) resources at the file-block level as users download them.

## Advantages of Olah
Olah caches files in chunks as users download them. On a second download, the file is read directly from the cache, which greatly improves download speed and saves bandwidth.
2 changes: 2 additions & 0 deletions assets/full_configs.toml
@@ -12,6 +12,8 @@ mirror-netloc = "localhost:8090"
mirror-lfs-netloc = "localhost:8090"

[accessibility]
offline = false

[[accessibility.proxy]]
repo = "cais/mmlu"
allow = true
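
The new `offline` key is the switch that the Python changes below consult: when it is `true`, Olah serves only what is already cached and never contacts the upstream. A quick sketch of reading the flag with Python 3.11's standard-library TOML parser (Olah's own `OlahConfig` loader may work differently):

```python
import tomllib

# Parse the sample config and read the new accessibility.offline flag.
with open("assets/full_configs.toml", "rb") as f:
    config = tomllib.load(f)

offline = config["accessibility"]["offline"]
print("offline mode:", offline)  # False by default
```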
110 changes: 48 additions & 62 deletions olah/files.py
@@ -5,6 +5,7 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import hashlib
import json
import os
from typing import Dict, Literal, Optional
@@ -27,35 +28,6 @@
from olah.utils.file_utils import make_dirs
from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT

FILE_HEADER_TEMPLATE = {
    "accept-ranges": "bytes",
    "access-control-allow-origin": "*",
    "cache-control": "public, max-age=604800, immutable, s-maxage=604800",
    # "content-length": None,
    # "content-type": "binary/octet-stream",
    # "etag": None,
    # "last-modified": None,
}

async def _get_redirected_url(client: httpx.AsyncClient, method: str, url: str, headers: Dict[str, str]):
    async with client.stream(
        method=method,
        url=url,
        headers=headers,
        timeout=WORKER_API_TIMEOUT,
    ) as response:
        if response.status_code >= 300 and response.status_code <= 399:
            from_url = urlparse(url)
            parsed_url = urlparse(response.headers["location"])
            if len(parsed_url.netloc) == 0:
                redirect_loc = urljoin(f"{from_url.scheme}://{from_url.netloc}", response.headers["location"])
            else:
                redirect_loc = response.headers["location"]
        else:
            redirect_loc = url

    return redirect_loc

async def _file_full_header(
    app,
    save_path: str,
@@ -71,22 +43,27 @@ async def _file_full_header(
            response_headers = json.loads(f.read())
        response_headers_dict = {k.lower():v for k, v in response_headers.items()}
    else:
        if "range" in headers:
            headers.pop("range")
        response = await client.request(
            method=method,
            url=url,
            headers=headers,
            timeout=WORKER_API_TIMEOUT,
        )
        response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
        if allow_cache and method.lower() == "head":
            with open(head_path, "w", encoding="utf-8") as f:
                f.write(json.dumps(response_headers_dict, ensure_ascii=False))
        if not app.app_settings.config.offline:
            if "range" in headers:
                headers.pop("range")
            response = await client.request(
                method=method,
                url=url,
                headers=headers,
                timeout=WORKER_API_TIMEOUT,
            )
            response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
            if allow_cache and method.lower() == "head" and response.status_code == 200:
                with open(head_path, "w", encoding="utf-8") as f:
                    f.write(json.dumps(response_headers_dict, ensure_ascii=False))
        else:
            response_headers_dict = {}

    new_headers = {}
    new_headers["content-type"] = response_headers_dict["content-type"]
    new_headers["content-length"] = response_headers_dict["content-length"]
    if "content-type" in response_headers_dict:
        new_headers["content-type"] = response_headers_dict["content-type"]
    if "content-length" in response_headers_dict:
        new_headers["content-length"] = response_headers_dict["content-length"]
    if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict:
        new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
    if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict:
@@ -228,26 +205,28 @@ async def _file_realtime_stream(
        hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url))
    else:
        hf_url = url
    async with httpx.AsyncClient() as client:
        response = await client.request(
            method="HEAD",
            url=hf_url,
            headers=request_headers,
            timeout=WORKER_API_TIMEOUT,
        )

        if response.status_code >= 300 and response.status_code <= 399:
            from_url = urlparse(url)
            parsed_url = urlparse(response.headers["location"])
            new_headers = {k.lower():v for k, v in response.headers.items()}
            if len(parsed_url.netloc) != 0:
                new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"]))
                new_headers["location"] = new_loc

            yield response.status_code
            yield new_headers
            yield response.content
            return
    if not app.app_settings.config.offline:
        async with httpx.AsyncClient() as client:
            response = await client.request(
                method="HEAD",
                url=hf_url,
                headers=request_headers,
                timeout=WORKER_API_TIMEOUT,
            )

            if response.status_code >= 300 and response.status_code <= 399:
                from_url = urlparse(url)
                parsed_url = urlparse(response.headers["location"])
                new_headers = {k.lower():v for k, v in response.headers.items()}
                if len(parsed_url.netloc) != 0:
                    new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"]))
                    new_headers["location"] = new_loc

                yield response.status_code
                yield new_headers
                yield response.content
                return

    async with httpx.AsyncClient() as client:
        # redirect_loc = await _get_redirected_url(client, method, url, request_headers)
@@ -268,6 +247,13 @@
        response_headers["content-length"] = str(end_pos - start_pos)
        if commit is not None:
            response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit

        if app.app_settings.config.offline and "etag" not in response_headers:
            # Create fake headers when offline mode
            sha256_hash = hashlib.sha256()
            sha256_hash.update(hf_url.encode('utf-8'))
            content_hash = sha256_hash.hexdigest()
            response_headers["etag"] = f"\"{content_hash[:32]}-10\""
        yield 200
        yield response_headers
        if method.lower() == "get":
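
The fallback ETag added above can be read in isolation: in offline mode, when no upstream ETag is available, a deterministic tag is derived from the request URL so clients still receive a stable identifier. A standalone sketch of the same derivation (the example URL is hypothetical):

```python
import hashlib

def fake_etag(hf_url: str) -> str:
    # Same derivation as in _file_realtime_stream: SHA-256 of the URL,
    # truncated to 32 hex characters, with a "-10" suffix.
    content_hash = hashlib.sha256(hf_url.encode("utf-8")).hexdigest()
    return f'"{content_hash[:32]}-10"'

# The tag is stable across calls for the same URL.
print(fake_etag("https://huggingface.co/datasets/cais/mmlu/resolve/main/README.md"))
```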
34 changes: 34 additions & 0 deletions olah/meta.py
@@ -28,6 +28,40 @@ async def meta_cache_generator(app: FastAPI, save_path: str):
                break
            yield chunk

async def meta_proxy_cache(
    app: FastAPI,
    repo_type: Literal["models", "datasets", "spaces"],
    org: str,
    repo: str,
    commit: str,
    request: Request,
):
    # save
    repos_path = app.app_settings.repos_path
    save_dir = os.path.join(
        repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}"
    )
    save_path = os.path.join(save_dir, "meta.json")
    make_dirs(save_path)

    # url
    org_repo = get_org_repo(org, repo)
    meta_url = urljoin(
        app.app_settings.config.hf_url_base(),
        f"/api/{repo_type}/{org_repo}/revision/{commit}",
    )
    async with httpx.AsyncClient() as client:
        response = await client.request(
            method="GET",
            url=meta_url,
            timeout=WORKER_API_TIMEOUT,
            follow_redirects=True,
        )
        if response.status_code == 200:
            with open(save_path, "wb") as meta_file:
                meta_file.write(response.content)
        else:
            raise Exception(f"Cannot get the branch info from the url {meta_url}, status: {response.status_code}")
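
For reference, the layout `meta_proxy_cache` writes to matches the updated `get_meta_save_path` in `olah/utils/url_utils.py` below. A small sketch of the resulting path (the `./repos` root and the cais/mmlu revision are illustrative assumptions):

```python
import os

repos_path = "./repos"  # assumed repos_path
repo_type, org, repo, commit = "datasets", "cais", "mmlu", "main"
save_path = os.path.join(
    repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}", "meta.json"
)
print(save_path)  # ./repos/api/datasets/cais/mmlu/revision/main/meta.json
```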

async def meta_proxy_generator(
    app: FastAPI,
18 changes: 15 additions & 3 deletions olah/server.py
@@ -16,7 +16,7 @@
from olah.configs import OlahConfig
from olah.files import cdn_file_get_generator, file_get_generator
from olah.lfs import lfs_get_generator, lfs_head_generator
from olah.meta import meta_generator
from olah.meta import meta_generator, meta_proxy_cache
from olah.utils.url_utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo

from olah.utils.logging import build_logger
@@ -52,7 +52,13 @@ async def meta_proxy_commit2(repo_type: str, org: str, repo: str, commit: str, r
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
if not await check_commit_hf(app, repo_type, org, repo, commit):
return Response(content="This repository is not accessible. ", status_code=404)
generator = meta_generator(app, repo_type, org, repo, commit, request)
commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)

# if branch name and online mode, refresh branch info
if commit_sha != commit and not app.app_settings.config.offline:
await meta_proxy_cache(app, repo_type, org, repo, commit, request)

generator = meta_generator(app, repo_type, org, repo, commit_sha, request)
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)

@@ -66,7 +72,13 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request:
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
if not await check_commit_hf(app, repo_type, org, repo, commit):
return Response(content="This repository is not accessible. ", status_code=404)
generator = meta_generator(app, repo_type, org, repo, commit, request)
commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)

# if branch name and online mode, refresh branch info
if commit_sha != commit and not app.app_settings.config.offline:
await meta_proxy_cache(app, repo_type, org, repo, commit, request)

generator = meta_generator(app, repo_type, org, repo, commit_sha, request)
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)

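
Both endpoints now share the same revision handling: resolve the requested revision (possibly a branch name such as `main`) to a commit SHA, refresh the cached branch metadata only when the two differ and the mirror is online, then stream metadata for the resolved SHA. Condensed into a hypothetical helper for clarity (not part of the commit):

```python
from olah.meta import meta_proxy_cache
from olah.utils.url_utils import get_commit_hf

async def resolve_revision(app, repo_type, org, repo, commit, request):
    # Resolve a branch name (e.g. "main") to its current commit SHA.
    commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
    # Branch heads can move, so refresh the cached branch info -- but only
    # when the mirror is allowed to go online.
    if commit_sha != commit and not app.app_settings.config.offline:
        await meta_proxy_cache(app, repo_type, org, repo, commit, request)
    return commit_sha
```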
2 changes: 1 addition & 1 deletion olah/utils/logging.py
@@ -72,7 +72,7 @@ def build_logger(logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR) ->
    os.makedirs(logger_dir, exist_ok=True)
    filename = os.path.join(logger_dir, logger_filename)
    handler = logging.handlers.TimedRotatingFileHandler(
        filename, when="M", utc=True, encoding="utf-8"
        filename, when="H", utc=True, encoding="utf-8"
    )
    handler.setFormatter(formatter)
    handler.namer = lambda name: name.replace(".log", "") + ".log"
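
The one-character change matters because of `TimedRotatingFileHandler`'s `when` codes: `"M"` rotates every minute, while `"H"` rotates hourly, so the old setting created a new log file every minute. Equivalent standalone configuration:

```python
import logging.handlers

# when="H" rotates the log hourly; the previous when="M" rotated it
# every minute, flooding the log directory with files.
handler = logging.handlers.TimedRotatingFileHandler(
    "olah.log", when="H", utc=True, encoding="utf-8"
)
```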
6 changes: 3 additions & 3 deletions olah/utils/url_utils.py
@@ -33,7 +33,7 @@ def parse_org_repo(org_repo: str) -> Tuple[Optional[str], Optional[str]]:
    return org, repo

def get_meta_save_path(repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str) -> str:
    return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}")
    return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}/meta.json")

def get_meta_save_dir(repos_path: str, repo_type: str, org: Optional[str], repo: str) -> str:
    return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision")
@@ -108,12 +108,12 @@ async def check_commit_hf(app, repo_type: Optional[Literal["models", "datasets",
async def check_proxy_rules_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str) -> bool:
    config: OlahConfig = app.app_settings.config
    org_repo = get_org_repo(org, repo)
    return config.proxy.allow(f"{org_repo}")
    return config.proxy.allow(org_repo)

async def check_cache_rules_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str) -> bool:
    config: OlahConfig = app.app_settings.config
    org_repo = get_org_repo(org, repo)
    return config.cache.allow(f"{org_repo}")
    return config.cache.allow(org_repo)

def get_url_tail(parsed_url: Union[str, ParseResult]) -> str:
    if isinstance(parsed_url, str):
