Skip to content

Commit

Permalink
authorization bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 19, 2024
1 parent 6298207 commit ba4b833
Show file tree
Hide file tree
Showing 18 changed files with 643 additions and 165 deletions.
19 changes: 19 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# How can I contribute to Olah?

Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.

It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.

However you choose to contribute, please be mindful and respect our code of conduct.

## Ways to contribute

There are lots of ways you can contribute to Olah:
* Submitting issues on Github to report bugs or make feature requests
* Fixing outstanding issues with the existing code
* Implementing new features
* Contributing to the examples or to the documentation

*All are equally valuable to the community.*

#### This guide was heavily inspired by the awesome [transformers guide to contributing](https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md)
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ allow = false

## Future Work

* Authentication
* Administrator and user system
* OOS backend support
* Mirror Update Schedule Task
Expand Down
2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<p align="center">
<b>自托管的轻量级HuggingFace镜像服务</b>

Olah是一种自托管的轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。
Olah是开源的自托管轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。
Olah真正地实现了huggingface资源的`镜像`功能,而不仅仅是一个简单的`反向代理`。
Olah并不会立刻对huggingface全站进行镜像,而是在用户下载的同时在文件块级别对资源进行镜像(或者我们可以说是缓存)。

Expand Down
8 changes: 8 additions & 0 deletions docs/en/main.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<h1 align="center">Olah Document</h1>

<p align="center">
<b>Self-hosted Lightweight Huggingface Mirror Service</b>

Olah is a self-hosted lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian.
Olah implemented the `mirroring` feature for huggingface resources, rather than just a simple `reverse proxy`.
Olah does not immediately mirror the entire huggingface website but mirrors the resources at the file block level when users download them (or we can say cache them).
Empty file added docs/en/quickstart.md
Empty file.
9 changes: 9 additions & 0 deletions docs/zh/main.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<h1 align="center">Olah 文档</h1>


<p align="center">
<b>自托管的轻量级HuggingFace镜像服务</b>

Olah是开源的自托管轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。
Olah真正地实现了huggingface资源的`镜像`功能,而不仅仅是一个简单的`反向代理`。
Olah并不会立刻对huggingface全站进行镜像,而是在用户下载的同时在文件块级别对资源进行镜像(或者我们可以说是缓存)。
Empty file added docs/zh/quickstart.md
Empty file.
Empty file added olah/auth/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions olah/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

DEFAULT_LOGGER_DIR = "./logs"

ORIGINAL_LOC = "oriloc"

from huggingface_hub.constants import (
REPO_TYPES_MAPPING,
HUGGINGFACE_CO_URL_TEMPLATE,
Expand Down
80 changes: 69 additions & 11 deletions olah/proxy/files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
Expand All @@ -22,13 +22,39 @@
HUGGINGFACE_HEADER_X_REPO_COMMIT,
HUGGINGFACE_HEADER_X_LINKED_ETAG,
HUGGINGFACE_HEADER_X_LINKED_SIZE,
ORIGINAL_LOC,
)
from olah.utils.olah_cache import OlahCache
from olah.utils.url_utils import RemoteInfo, check_cache_rules_hf, get_org_repo, get_url_tail, parse_range_params
from olah.utils.url_utils import (
RemoteInfo,
add_query_param,
check_url_has_param_name,
get_url_param_name,
get_url_tail,
parse_range_params,
remove_query_param,
)
from olah.utils.repo_utils import get_org_repo
from olah.utils.rule_utils import check_cache_rules_hf
from olah.utils.file_utils import make_dirs
from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT

async def _write_cache_request(head_path: str, status_code: int, headers: Dict[str, str], content: bytes):

async def _write_cache_request(
head_path: str, status_code: int, headers: Dict[str, str], content: bytes
)-> None:
"""
Write the request's status code, headers, and content to a cache file.

Args:
head_path (str): The path to the cache file.
status_code (int): The status code of the request.
headers (Dict[str, str]): The dictionary of response headers.
content (bytes): The content of the request.

Returns:
None
"""
rq = {
"status_code": status_code,
"headers": headers,
Expand All @@ -37,13 +63,24 @@ async def _write_cache_request(head_path: str, status_code: int, headers: Dict[s
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(rq, ensure_ascii=False))

async def _read_cache_request(head_path: str):

async def _read_cache_request(head_path: str) -> Dict[str, str]:
"""
Read the request's status code, headers, and content from a cache file.

Args:
head_path (str): The path to the cache file.

Returns:
Dict[str, str]: A dictionary containing the status code, headers, and content of the request.
"""
with open(head_path, "r", encoding="utf-8") as f:
rq = json.loads(f.read())

rq["content"] = bytes.fromhex(rq["content"])
return rq


async def _file_full_header(
app,
save_path: str,
Expand Down Expand Up @@ -85,6 +122,13 @@ async def _file_full_header(
if len(parsed_url.netloc) != 0:
new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"]))
response_headers_dict["location"] = new_loc
# Redirect, add original location info
if check_url_has_param_name(response_headers_dict["location"], ORIGINAL_LOC):
raise Exception(f"Invalid field {ORIGINAL_LOC} in the url.")
else:
response_headers_dict["location"] = add_query_param(response_headers_dict["location"], ORIGINAL_LOC, response.headers["location"])
elif response.status_code == 403:
pass
else:
raise Exception(f"Unexpected HTTP status code {response.status_code}")
return response.status_code, response_headers_dict, response.content
Expand Down Expand Up @@ -244,13 +288,26 @@ async def _file_realtime_stream(
allow_cache=True,
commit: Optional[str] = None,
):
request_headers = {k: v for k, v in request.headers.items()}
request_headers.pop("host")

if urlparse(url).netloc == app.app_settings.config.mirror_netloc:
hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url))
if check_url_has_param_name(url, ORIGINAL_LOC):
clean_url = remove_query_param(url, ORIGINAL_LOC)
original_loc = get_url_param_name(url, ORIGINAL_LOC)
if urlparse(url).netloc in [app.app_settings.config.mirror_netloc, app.app_settings.config.mirror_lfs_netloc]:
hf_loc = urlparse(original_loc)
if len(hf_loc.netloc) != 0:
hf_url = urljoin(f"{hf_loc.scheme}://{hf_loc.netloc}", get_url_tail(clean_url))
else:
hf_url = url
else:
hf_url = url
else:
hf_url = url
if urlparse(url).netloc in [app.app_settings.config.mirror_netloc, app.app_settings.config.mirror_lfs_netloc]:
hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url))
else:
hf_url = url

request_headers = {k: v for k, v in request.headers.items()}
if "host" in request_headers:
request_headers["host"] = urlparse(hf_url).netloc

async with httpx.AsyncClient() as client:
# redirect_loc = await _get_redirected_url(client, method, url, request_headers)
Expand All @@ -270,6 +327,7 @@ async def _file_realtime_stream(
yield content
return

async with httpx.AsyncClient() as client:
file_size = int(head_info["content-length"])
response_headers = {k: v for k,v in head_info.items()}
if "range" in request_headers:
Expand Down
1 change: 0 additions & 1 deletion olah/proxy/lfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from olah.proxy.files import _file_realtime_stream
from olah.utils.file_utils import make_dirs
from olah.utils.url_utils import check_cache_rules_hf, get_org_repo


async def lfs_head_generator(
Expand Down
7 changes: 6 additions & 1 deletion olah/proxy/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import httpx
from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT

from olah.utils.url_utils import check_cache_rules_hf, get_org_repo
from olah.utils.rule_utils import check_cache_rules_hf
from olah.utils.repo_utils import get_org_repo
from olah.utils.file_utils import make_dirs


Expand Down Expand Up @@ -50,10 +51,14 @@ async def meta_proxy_cache(
app.app_settings.config.hf_url_base(),
f"/api/{repo_type}/{org_repo}/revision/{commit}",
)
headers = {}
if "authorization" in request.headers:
headers["authorization"] = request.headers["authorization"]
async with httpx.AsyncClient() as client:
response = await client.request(
method="GET",
url=meta_url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
follow_redirects=True,
)
Expand Down
Loading

0 comments on commit ba4b833

Please sign in to comment.