Skip to content

Commit

Permalink
add block cache data structures
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 12, 2024
1 parent 473cdfe commit 7b6291f
Show file tree
Hide file tree
Showing 12 changed files with 369 additions and 85 deletions.
2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pip install -e .
python -m olah.server
```

然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090)。
然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090/)。

Linux:
```bash
Expand Down
10 changes: 6 additions & 4 deletions assets/full_configs.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ port = 8090
ssl-key = ""
ssl-cert = ""
repos-path = "./repos"
hf-url = "https://huggingface.co"
hf-lfs-url = "https://cdn-lfs.huggingface.co"
mirror-url = "http://localhost:8090"
mirror-lfs-url = "http://localhost:8090"
hf-scheme = "https"
hf-netloc = "huggingface.co"
hf-lfs-netloc = "cdn-lfs.huggingface.co"
mirror-scheme = "http"
mirror-netloc = "localhost:8090"
mirror-lfs-netloc = "localhost:8090"

[accessibility]
[[accessibility.proxy]]
Expand Down
41 changes: 28 additions & 13 deletions olah/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,14 @@ def __init__(self, path: Optional[str] = None) -> None:
self.ssl_key = None
self.ssl_cert = None
self.repos_path = "./repos"
self.hf_url = "https://huggingface.co"
self.hf_lfs_url = "https://cdn-lfs.huggingface.co"
self.mirror_url = f"http://{self.host}:{self.port}"
self.mirror_lfs_url = f"http://{self.host}:{self.port}"

self.hf_scheme: str = "https"
self.hf_netloc: str = "huggingface.co"
self.hf_lfs_netloc: str = "cdn-lfs.huggingface.co"

self.mirror_scheme: str = "http"
self.mirror_netloc: str = "localhost:8090"
self.mirror_lfs_netloc: str = "localhost:8090"

# accessibility
self.offline = False
Expand All @@ -103,12 +107,19 @@ def __init__(self, path: Optional[str] = None) -> None:

if path is not None:
self.read_toml(path)

# refresh urls
self.mirror_url = f"http://{self.host}:{self.port}"
self.mirror_lfs_url = f"http://{self.host}:{self.port}"


def hf_url_base(self) -> str:
    """Return the base URL of the upstream HuggingFace endpoint, e.g. "https://huggingface.co"."""
    return "{}://{}".format(self.hf_scheme, self.hf_netloc)

def hf_lfs_url_base(self) -> str:
    """Return the base URL of the upstream LFS CDN (shares its scheme with the main HF endpoint)."""
    return self.hf_scheme + "://" + self.hf_lfs_netloc

def mirror_url_base(self) -> str:
    """Return the base URL under which this mirror itself is reachable."""
    parts = (self.mirror_scheme, "://", self.mirror_netloc)
    return "".join(parts)

def mirror_lfs_url_base(self) -> str:
    """Return the base URL under which this mirror serves LFS content."""
    return "{scheme}://{netloc}".format(
        scheme=self.mirror_scheme, netloc=self.mirror_lfs_netloc
    )

def empty_str(self, s: str) -> Optional[str]:
if s == "":
return None
Expand All @@ -125,10 +136,14 @@ def read_toml(self, path: str) -> None:
self.ssl_key = self.empty_str(basic.get("ssl-key", self.ssl_key))
self.ssl_cert = self.empty_str(basic.get("ssl-cert", self.ssl_cert))
self.repos_path = basic.get("repos-path", self.repos_path)
self.hf_url = basic.get("hf-url", self.hf_url)
self.hf_lfs_url = basic.get("hf-lfs-url", self.hf_lfs_url)
self.mirror_url = basic.get("mirror-url", self.mirror_url)
self.mirror_lfs_url = basic.get("mirror-lfs-url", self.mirror_lfs_url)

self.hf_scheme = basic.get("hf-scheme", self.hf_scheme)
self.hf_netloc = basic.get("hf-netloc", self.hf_netloc)
self.hf_lfs_netloc = basic.get("hf-lfs-netloc", self.hf_lfs_netloc)

self.mirror_scheme = basic.get("mirror-scheme", self.mirror_scheme)
self.mirror_netloc = basic.get("mirror-netloc", self.mirror_netloc)
self.mirror_lfs_netloc = basic.get("mirror-lfs-netloc", self.mirror_lfs_netloc)

if "accessibility" in config:
accessibility = config["accessibility"]
Expand Down
80 changes: 51 additions & 29 deletions olah/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import os
import shutil
import tempfile
from typing import Literal
from typing import Literal, Optional
from fastapi import Request

from requests.structures import CaseInsensitiveDict
import httpx
from starlette.datastructures import URL
from urllib.parse import urlparse, urljoin

from olah.constants import (
CHUNK_SIZE,
Expand All @@ -15,7 +17,7 @@
HUGGINGFACE_HEADER_X_LINKED_ETAG,
HUGGINGFACE_HEADER_X_LINKED_SIZE,
)
from olah.utils import check_cache_rules_hf, get_org_repo, make_dirs
from olah.utils import check_cache_rules_hf, get_org_repo, get_url_tail, make_dirs
FILE_HEADER_TEMPLATE = {
"accept-ranges": "bytes",
"access-control-allow-origin": "*",
Expand All @@ -30,12 +32,16 @@ async def _file_cache_stream(save_path: str, head_path: str, request: Request):
if request.method.lower() == "head":
with open(head_path, "r", encoding="utf-8") as f:
response_headers = json.loads(f.read())
new_headers = {k:v for k, v in FILE_HEADER_TEMPLATE.items()}
response_headers = {k.lower():v for k, v in response_headers.items()}
new_headers = {k.lower():v for k, v in FILE_HEADER_TEMPLATE.items()}
new_headers["content-type"] = response_headers["content-type"]
new_headers["content-length"] = response_headers["content-length"]
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT] = response_headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT, None)
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG, None)
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE, None)
if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
new_headers["etag"] = response_headers["etag"]
yield new_headers
elif request.method.lower() == "get":
Expand All @@ -49,8 +55,26 @@ async def _file_cache_stream(save_path: str, head_path: str, request: Request):
break
yield chunk

async def _get_redirected_url(client, method: str, url, headers):
    """Issue a request and resolve a single HTTP redirect, if any.

    Sends `method` to `url` with `headers` through the given httpx async
    client. When the response is a 3xx redirect carrying a ``Location``
    header, returns the absolute redirect target: a relative location is
    resolved against the scheme/netloc of the original `url`. Non-redirect
    responses — and 3xx responses with no ``Location`` header (e.g.
    304 Not Modified) — return the original `url` unchanged.

    NOTE(review): only one level of redirection is resolved; a chained
    redirect returns the first hop only.
    """
    async with client.stream(
        method=method,
        url=url,
        headers=headers,
        timeout=WORKER_API_TIMEOUT,
    ) as response:
        # Fix: use .get() so a Location-less 3xx (e.g. 304 Not Modified)
        # falls through to the original URL instead of raising KeyError.
        location = response.headers.get("location")
        if 300 <= response.status_code < 400 and location is not None:
            if len(urlparse(location).netloc) == 0:
                # Relative redirect: resolve against the origin of the request URL.
                from_url = urlparse(url)
                redirect_loc = urljoin(f"{from_url.scheme}://{from_url.netloc}", location)
            else:
                redirect_loc = location
        else:
            redirect_loc = url
    return redirect_loc

async def _file_realtime_stream(
app, save_path: str, head_path: str, url: str, request: Request, method="GET", allow_cache=True
app, save_path: str, head_path: str, url: str, request: Request, method="GET", allow_cache=True, commit: Optional[str]=None
):
request_headers = {k: v for k, v in request.headers.items()}
request_headers.pop("host")
Expand All @@ -65,33 +89,26 @@ async def _file_realtime_stream(
else:
write_temp_file = True

async with client.stream(
method=method,
url=url,
headers=request_headers,
timeout=WORKER_API_TIMEOUT,
) as response:
if response.status_code >= 300 and response.status_code <= 399:
redirect_loc = app.app_settings.hf_url + response.headers["location"]
else:
redirect_loc = url

redirect_loc = await _get_redirected_url(client, method, url, request_headers)
async with client.stream(
method=method,
url=redirect_loc,
headers=request_headers,
timeout=WORKER_API_TIMEOUT,
) as response:
response_headers = response.headers
response_headers_dict = {k: v for k, v in response_headers.items()}
response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
if allow_cache:
if request.method.lower() == "head":
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
if "location" in response_headers_dict:
response_headers_dict["location"] = response_headers_dict["location"].replace(
app.app_settings.hf_lfs_url, app.app_settings.mirror_lfs_url
)
location_url = urlparse(response_headers_dict["location"])
if location_url.netloc == app.app_settings.config.hf_lfs_netloc:
response_headers_dict["location"] = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(location_url))
else:
response_headers_dict["location"] = urljoin(app.app_settings.config.mirror_url_base(), get_url_tail(location_url))
if commit is not None:
response_headers_dict[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
yield response_headers_dict

async for raw_chunk in response.aiter_raw():
Expand Down Expand Up @@ -137,9 +154,9 @@ async def file_head_generator(
return _file_cache_stream(save_path=save_path, head_path=head_path, request=request)
else:
if repo_type == "models":
url = f"{app.app_settings.hf_url}/{org_repo}/resolve/{commit}/{file_path}"
url = urljoin(app.app_settings.config.hf_url_base(), f"/{org_repo}/resolve/{commit}/{file_path}")
else:
url = f"{app.app_settings.hf_url}/{repo_type}/{org_repo}/resolve/{commit}/{file_path}"
url = urljoin(app.app_settings.config.hf_url_base(), f"/{repo_type}/{org_repo}/resolve/{commit}/{file_path}")
return _file_realtime_stream(
app=app,
save_path=save_path,
Expand All @@ -148,6 +165,7 @@ async def file_head_generator(
request=request,
method="HEAD",
allow_cache=allow_cache,
commit=commit,
)


Expand Down Expand Up @@ -180,9 +198,9 @@ async def file_get_generator(
return _file_cache_stream(save_path=save_path, head_path=head_path, request=request)
else:
if repo_type == "models":
url = f"{app.app_settings.hf_url}/{org_repo}/resolve/{commit}/{file_path}"
url = urljoin(app.app_settings.config.hf_url_base(), f"/{org_repo}/resolve/{commit}/{file_path}")
else:
url = f"{app.app_settings.hf_url}/{repo_type}/{org_repo}/resolve/{commit}/{file_path}"
url = urljoin(app.app_settings.config.hf_url_base(), f"/{repo_type}/{org_repo}/resolve/{commit}/{file_path}")
return _file_realtime_stream(
app=app,
save_path=save_path,
Expand All @@ -191,6 +209,7 @@ async def file_get_generator(
request=request,
method="GET",
allow_cache=allow_cache,
commit=commit,
)


Expand Down Expand Up @@ -224,8 +243,11 @@ async def cdn_file_get_generator(
if use_cache:
return _file_cache_stream(save_path=save_path, request=request)
else:
redirected_url = str(request.url)
redirected_url = redirected_url.replace(app.app_settings.mirror_lfs_url, app.app_settings.hf_lfs_url)
request_url = urlparse(request.url)
if request_url.netloc == app.app_settings.config.hf_lfs_netloc:
redirected_url = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(request_url))
else:
redirected_url = urljoin(app.app_settings.config.mirror_url_base(), get_url_tail(request_url))

return _file_realtime_stream(
app=app,
Expand Down
6 changes: 5 additions & 1 deletion olah/lfs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
废弃方法
"""

import datetime
import json
import os
Expand All @@ -8,7 +12,7 @@
import pytz

from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT
from olah.utils import make_dirs
from olah.utils.file_utils import make_dirs


async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request):
Expand Down
6 changes: 4 additions & 2 deletions olah/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import shutil
import tempfile
from typing import Dict, Literal
from urllib.parse import urljoin
from fastapi import FastAPI, Request

import httpx
from olah.configs import OlahConfig
from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT

from olah.utils import check_cache_rules_hf, get_org_repo, make_dirs
from olah.utils.url_utils import check_cache_rules_hf, get_org_repo
from olah.utils.file_utils import make_dirs

async def meta_cache_generator(app: FastAPI, save_path: str):
yield {}
Expand Down Expand Up @@ -67,7 +69,7 @@ async def meta_generator(app: FastAPI, repo_type: Literal["models", "datasets"],
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

org_repo = get_org_repo(org, repo)
meta_url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}/revision/{commit}"
meta_url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}")
# proxy
if use_cache:
async for item in meta_cache_generator(app, save_path):
Expand Down
Loading

0 comments on commit 7b6291f

Please sign in to comment.