Skip to content

Commit

Permalink
lfs bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 14, 2024
1 parent 721ec02 commit b763c23
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 253 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
Olah is a self-hosted, lightweight Hugging Face mirror service. `Olah` means `hello` in Hilichurlian.

Other languages: [中文](README_zh.md)

## Advantages of Olah
Olah caches files in chunks as users download them. On subsequent downloads, the files are served directly from the cache, greatly improving download speed and saving bandwidth.
Additionally, Olah offers a range of cache control policies. Administrators can configure which repositories are accessible and which ones can be cached through a configuration file.

## Features
* Huggingface Data Cache
* Models mirror
Expand Down
4 changes: 4 additions & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
Olah是一种自托管的轻量级HuggingFace镜像服务。`Olah`在丘丘人语中意味着`你好`

## Olah的优势
Olah能够在用户下载的同时分块缓存文件。当第二次下载时,直接从缓存中读取,极大地提升了下载速度并节约了流量。
同时Olah提供了丰富的缓存控制策略,管理员可以通过配置文件设置哪些仓库可以访问,哪些仓库可以缓存。

## 特性
* 数据缓存,减少下载流量
* 模型镜像
Expand Down
113 changes: 51 additions & 62 deletions olah/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ async def _get_redirected_url(client: httpx.AsyncClient, method: str, url: str,
redirect_loc = response.headers["location"]
else:
redirect_loc = url

return redirect_loc

async def _file_full_header(
Expand All @@ -67,16 +68,16 @@ async def _file_full_header(
else:
if "range" in headers:
headers.pop("range")
async with client.stream(
response = await client.request(
method=method,
url=url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
if allow_cache and method.lower() == "head":
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
)
response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
if allow_cache and method.lower() == "head":
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))

new_headers = {}
new_headers["content-type"] = response_headers_dict["content-type"]
Expand Down Expand Up @@ -108,12 +109,15 @@ async def _get_file_block_from_remote(client: httpx.AsyncClient, remote_info: Re
headers=remote_info.headers,
timeout=WORKER_API_TIMEOUT,
) as response:
response_content_length = int(response.headers['content-length'])
async for raw_chunk in response.aiter_raw():
if not raw_chunk:
continue
raw_block += raw_chunk
# print(remote_info.url, remote_info.method, remote_info.headers)
# print(block_start_pos, block_end_pos)
if len(raw_block) != response_content_length:
raise Exception(f"The content of the response is incomplete. Expected-{response_content_length}. Accepted-{len(raw_block)}")
if len(raw_block) != (block_end_pos - block_start_pos):
raise Exception(f"The block is incomplete. Expected-{block_end_pos - block_start_pos}. Accepted-{len(raw_block)}")
if len(raw_block) < cache_file._get_block_size():
Expand Down Expand Up @@ -215,15 +219,40 @@ async def _file_realtime_stream(
request_headers = {k: v for k, v in request.headers.items()}
request_headers.pop("host")

if urlparse(url).netloc == app.app_settings.config.mirror_netloc:
hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url))
else:
hf_url = url
async with httpx.AsyncClient() as client:
response = await client.request(
method="HEAD",
url=hf_url,
headers=request_headers,
timeout=WORKER_API_TIMEOUT,
)

if response.status_code >= 300 and response.status_code <= 399:
from_url = urlparse(url)
parsed_url = urlparse(response.headers["location"])
new_headers = {k.lower():v for k, v in response.headers.items()}
if len(parsed_url.netloc) != 0:
new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"]))
new_headers["location"] = new_loc

yield response.status_code
yield new_headers
yield response.content
return

async with httpx.AsyncClient() as client:
redirect_loc = await _get_redirected_url(client, method, url, request_headers)
# redirect_loc = await _get_redirected_url(client, method, url, request_headers)
head_info = await _file_full_header(
app=app,
save_path=save_path,
head_path=head_path,
client=client,
method=method,
url=redirect_loc,
method="HEAD",
url=hf_url,
headers=request_headers,
allow_cache=allow_cache,
)
Expand All @@ -234,6 +263,7 @@ async def _file_realtime_stream(
response_headers["content-length"] = str(end_pos - start_pos)
if commit is not None:
response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
yield 200
yield response_headers
if method.lower() == "get":
async for each_chunk in _file_chunk_get(
Expand All @@ -242,7 +272,7 @@ async def _file_realtime_stream(
head_path=head_path,
client=client,
method=method,
url=redirect_loc,
url=hf_url,
headers=request_headers,
allow_cache=allow_cache,
file_size=file_size,
Expand All @@ -255,7 +285,7 @@ async def _file_realtime_stream(
head_path=head_path,
client=client,
method=method,
url=redirect_loc,
url=hf_url,
headers=request_headers,
allow_cache=allow_cache,
file_size=0,
Expand All @@ -264,55 +294,14 @@ async def _file_realtime_stream(
else:
raise Exception(f"Unsupported method: {method}")


async def file_head_generator(
    app,
    repo_type: Literal["models", "datasets"],
    org: str,
    repo: str,
    commit: str,
    file_path: str,
    request: Request,
):
    """Prepare a HEAD request for one repository file and return its stream.

    Derives the on-disk locations for the cached response headers and the
    cached file body under ``app.app_settings.repos_path``, looks up whether
    caching is allowed for this repository, builds the upstream ``resolve``
    URL, and hands everything to ``_file_realtime_stream`` with
    ``method="HEAD"``.

    Args:
        app: Application object exposing ``app_settings.repos_path`` and
            ``app_settings.config.hf_url_base()``.
        repo_type: ``"models"`` or ``"datasets"``; only non-model types are
            prefixed in the upstream URL path.
        org: Organization / user part of the repository id.
        repo: Repository name.
        commit: Revision used in the ``resolve`` path.
        file_path: Path of the file inside the repository.
        request: Incoming request, forwarded unchanged.

    Returns:
        Whatever ``_file_realtime_stream`` returns for these arguments.
    """
    org_repo = get_org_repo(org, repo)

    # Local cache layout: headers and payload live in parallel trees.
    cache_root = app.app_settings.repos_path
    head_path = os.path.join(
        cache_root, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
    )
    save_path = os.path.join(
        cache_root, f"files/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
    )
    # NOTE(review): make_dirs presumably ensures the parent directories
    # exist - confirm against its definition.
    for cache_path in (head_path, save_path):
        make_dirs(cache_path)

    allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

    # Upstream URL: models sit at the site root, other repo types are
    # namespaced by their type.
    url_tail = (
        f"/{org_repo}/resolve/{commit}/{file_path}"
        if repo_type == "models"
        else f"/{repo_type}/{org_repo}/resolve/{commit}/{file_path}"
    )
    url = urljoin(app.app_settings.config.hf_url_base(), url_tail)

    return _file_realtime_stream(
        app=app,
        save_path=save_path,
        head_path=head_path,
        url=url,
        request=request,
        method="HEAD",
        allow_cache=allow_cache,
        commit=commit,
    )


async def file_get_generator(
app,
repo_type: Literal["models", "datasets"],
org: str,
repo: str,
commit: str,
file_path: str,
method: Literal["HEAD", "GET"],
request: Request,
):
org_repo = get_org_repo(org, repo)
Expand Down Expand Up @@ -341,18 +330,18 @@ async def file_get_generator(
head_path=head_path,
url=url,
request=request,
method="GET",
method=method,
allow_cache=allow_cache,
commit=commit,
)


async def cdn_file_get_generator(
app,
repo_type: Literal["models", "datasets"],
org: str,
repo: str,
file_hash: str,
method: Literal["HEAD", "GET"],
request: Request,
):
headers = {k: v for k, v in request.headers.items()}
Expand All @@ -374,18 +363,18 @@ async def cdn_file_get_generator(
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

# proxy
request_url = urlparse(request.url)
if request_url.netloc == app.app_settings.config.hf_lfs_netloc:
redirected_url = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(request_url))
else:
redirected_url = urljoin(app.app_settings.config.mirror_url_base(), get_url_tail(request_url))
# request_url = urlparse(str(request.url))
# if request_url.netloc == app.app_settings.config.hf_lfs_netloc:
# redirected_url = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(request_url))
# else:
# redirected_url = urljoin(app.app_settings.config.mirror_url_base(), get_url_tail(request_url))

return _file_realtime_stream(
app=app,
save_path=save_path,
head_path=head_path,
url=redirected_url,
url=str(request.url),
request=request,
method="GET",
method=method,
allow_cache=allow_cache,
)
Loading

0 comments on commit b763c23

Please sign in to comment.