Skip to content

Commit

Permalink
stream redirect bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 9, 2024
1 parent 5e8de7b commit 9ca8a1f
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 60 deletions.
10 changes: 10 additions & 0 deletions olah/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,13 @@
# Timeout handed to outbound httpx requests to the upstream hub
# (httpx interprets a bare number as seconds).
WORKER_API_TIMEOUT = 15
# NOTE(review): presumably the size in bytes of each chunk when streaming
# files — the consuming code is not visible here; confirm.
CHUNK_SIZE = 4096
# 64 * 1024 * 1024 = 64 MiB.
# NOTE(review): presumably the block size in bytes used for LFS files — confirm against olah/lfs.py.
LFS_FILE_BLOCK = 64 * 1024 * 1024


from huggingface_hub.constants import (
_HF_DEFAULT_ENDPOINT,
_HF_DEFAULT_STAGING_ENDPOINT,
HUGGINGFACE_CO_URL_TEMPLATE,
HUGGINGFACE_HEADER_X_REPO_COMMIT,
HUGGINGFACE_HEADER_X_LINKED_ETAG,
HUGGINGFACE_HEADER_X_LINKED_SIZE,
)
41 changes: 31 additions & 10 deletions olah/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import httpx
from starlette.datastructures import URL

from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT
from olah.constants import (
CHUNK_SIZE,
WORKER_API_TIMEOUT,
HUGGINGFACE_HEADER_X_REPO_COMMIT,
HUGGINGFACE_HEADER_X_LINKED_ETAG,
HUGGINGFACE_HEADER_X_LINKED_SIZE,
)
from olah.utils import check_cache_rules_hf, get_org_repo, make_dirs
FILE_HEADER_TEMPLATE = {
"accept-ranges": "bytes",
Expand All @@ -27,7 +33,9 @@ async def _file_cache_stream(save_path: str, head_path: str, request: Request):
new_headers = {k:v for k, v in FILE_HEADER_TEMPLATE.items()}
new_headers["content-type"] = response_headers["content-type"]
new_headers["content-length"] = response_headers["content-length"]
new_headers["x-repo-commit"] = response_headers["x-repo-commit"]
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT] = response_headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT, None)
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG, None)
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE, None)
new_headers["etag"] = response_headers["etag"]
yield new_headers
elif request.method.lower() == "get":
Expand Down Expand Up @@ -56,24 +64,35 @@ async def _file_realtime_stream(
write_temp_file = False
else:
write_temp_file = True

async with client.stream(
method=method,
url=url,
headers=request_headers,
timeout=WORKER_API_TIMEOUT,
) as response:
if response.status_code >= 300 and response.status_code <= 399:
redirect_loc = app.app_settings.hf_url + response.headers["location"]
else:
redirect_loc = url

async with client.stream(
method=method,
url=redirect_loc,
headers=request_headers,
timeout=WORKER_API_TIMEOUT,
) as response:
response_headers = response.headers
response_headers_dict = {k: v for k, v in response_headers.items()}
if allow_cache:
if request.method.lower() == "head":
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
if "location" in response_headers:
response_headers["location"] = response_headers["location"].replace(
if "location" in response_headers_dict:
response_headers_dict["location"] = response_headers_dict["location"].replace(
app.app_settings.hf_lfs_url, app.app_settings.mirror_lfs_url
)
yield response_headers
yield response_headers_dict

async for raw_chunk in response.aiter_raw():
if not raw_chunk:
Expand All @@ -98,6 +117,7 @@ async def file_head_generator(
file_path: str,
request: Request,
):
org_repo = get_org_repo(org, repo)
# save
repos_path = app.app_settings.repos_path
head_path = os.path.join(
Expand All @@ -117,9 +137,9 @@ async def file_head_generator(
return _file_cache_stream(save_path=save_path, head_path=head_path, request=request)
else:
if repo_type == "models":
url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}"
url = f"{app.app_settings.hf_url}/{org_repo}/resolve/{commit}/{file_path}"
else:
url = f"{app.app_settings.hf_url}/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
url = f"{app.app_settings.hf_url}/{repo_type}/{org_repo}/resolve/{commit}/{file_path}"
return _file_realtime_stream(
app=app,
save_path=save_path,
Expand All @@ -140,6 +160,7 @@ async def file_get_generator(
file_path: str,
request: Request,
):
org_repo = get_org_repo(org, repo)
# save
repos_path = app.app_settings.repos_path
head_path = os.path.join(
Expand All @@ -159,9 +180,9 @@ async def file_get_generator(
return _file_cache_stream(save_path=save_path, head_path=head_path, request=request)
else:
if repo_type == "models":
url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}"
url = f"{app.app_settings.hf_url}/{org_repo}/resolve/{commit}/{file_path}"
else:
url = f"{app.app_settings.hf_url}/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
url = f"{app.app_settings.hf_url}/{repo_type}/{org_repo}/resolve/{commit}/{file_path}"
return _file_realtime_stream(
app=app,
save_path=save_path,
Expand Down
13 changes: 7 additions & 6 deletions olah/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from olah.configs import OlahConfig
from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT

from olah.utils import check_cache_rules_hf
from olah.utils import check_cache_rules_hf, get_org_repo, make_dirs

async def meta_cache_generator(app: FastAPI, save_path: str):
yield {}
Expand All @@ -25,7 +25,7 @@ async def meta_cache_generator(app: FastAPI, save_path: str):
async def meta_proxy_generator(app: FastAPI, headers: Dict[str, str], meta_url: str, allow_cache: bool, save_path: str):
try:
temp_file_path = None
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(follow_redirects=True) as client:
with tempfile.NamedTemporaryFile(mode="wb", delete=True) as temp_file:
temp_file_path = temp_file.name
if not allow_cache:
Expand Down Expand Up @@ -61,12 +61,13 @@ async def meta_generator(app: FastAPI, repo_type: Literal["models", "datasets"],
repos_path = app.app_settings.repos_path
save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}")
save_path = os.path.join(save_dir, "meta.json")
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)

make_dirs(save_path)

use_cache = os.path.exists(save_path)
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
meta_url = f"{app.app_settings.hf_url}/api/{repo_type}/{org}/{repo}/revision/{commit}"

org_repo = get_org_repo(org, repo)
meta_url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}/revision/{commit}"
# proxy
if use_cache:
async for item in meta_cache_generator(app, save_path):
Expand Down
102 changes: 62 additions & 40 deletions olah/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from olah.files import cdn_file_get_generator, file_get_generator, file_head_generator
from olah.lfs import lfs_get_generator
from olah.meta import meta_generator
from olah.utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf
from olah.utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo

app = FastAPI(debug=False)

Expand All @@ -28,13 +28,10 @@ class AppSettings(BaseSettings):

@app.get("/api/{repo_type}/{org_repo}")
async def meta_proxy(repo_type: str, org_repo: str, request: Request):
if "/" in org_repo and org_repo.count("/") != 1:
org, repo = parse_org_repo(org_repo)
if org is None and repo is None:
return Response(content="This repository is not accessible.", status_code=404)
if "/" in org_repo:
org, repo = org_repo.split("/")
else:
org = None
repo = org_repo

if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror.", status_code=403)
if not await check_commit_hf(app, repo_type, org, repo, None):
Expand All @@ -56,13 +53,10 @@ async def meta_proxy_commit2(repo_type: str, org: str, repo: str, commit: str, r

@app.get("/api/{repo_type}/{org_repo}/revision/{commit}")
async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request):
if "/" in org_repo and org_repo.count("/") != 1:
org, repo = parse_org_repo(org_repo)
if org is None and repo is None:
return Response(content="This repository is not accessible.", status_code=404)
if "/" in org_repo:
org, repo = org_repo.split("/")
else:
org = None
repo = org_repo

if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
if not await check_commit_hf(app, repo_type, org, repo, commit):
Expand All @@ -72,8 +66,7 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request:
return StreamingResponse(generator, headers=headers)

@app.head("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}")
@app.head("/{org}/{repo}/resolve/{commit}/{file_path:path}")
async def file_head_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str = "models"):
async def file_head_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str):
if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
if not await check_commit_hf(app, repo_type, org, repo, commit):
Expand All @@ -83,29 +76,45 @@ async def file_head_proxy2(org: str, repo: str, commit: str, file_path: str, req
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)

@app.head("/{repo_type}/{org_repo}/resolve/{commit}/{file_path:path}")
@app.head("/{org_or_repo_type}/{repo}/resolve/{commit}/{file_path:path}")
async def file_head_proxy(org_or_repo_type: str, repo: str, commit: str, file_path: str, request: Request):
if org_or_repo_type in ["models", "datasets", "spaces"]:
repo_type: str = org_or_repo_type
org, repo = parse_org_repo(repo)
if org is None and repo is None:
return Response(content="This repository is not accessible.", status_code=404)
else:
repo_type: str = "models"
org, repo = org_or_repo_type, repo

if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit):
return Response(content="This repository is not accessible. ", status_code=404)
commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request)
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)

@app.head("/{org_repo}/resolve/{commit}/{file_path:path}")
async def file_head_proxy(org_repo: str, commit: str, file_path: str, request: Request, repo_type: str = "models"):
if "/" in org_repo and org_repo.count("/") != 1:
async def file_head_proxy_default_type(org_repo: str, commit: str, file_path: str, request: Request):
repo_type: str = "models"
org, repo = parse_org_repo(org_repo)
if org is None and repo is None:
return Response(content="This repository is not accessible.", status_code=404)
if "/" in org_repo:
org, repo = org_repo.split("/")
else:
org = None
repo = org_repo

if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
if not await check_commit_hf(app, repo_type, org, repo, commit):
if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit):
return Response(content="This repository is not accessible. ", status_code=404)

commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request)
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)

@app.get("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}")
@app.get("/{org}/{repo}/resolve/{commit}/{file_path:path}")
async def file_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str = "models"):
async def file_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str):
if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
if not await check_commit_hf(app, repo_type, org, repo, commit):
Expand All @@ -115,16 +124,32 @@ async def file_proxy2(org: str, repo: str, commit: str, file_path: str, request:
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)

@app.get("/{repo_type}/{org_repo}/resolve/{commit}/{file_path:path}")
@app.get("/{org_or_repo_type}/{repo}/resolve/{commit}/{file_path:path}")
async def file_proxy2_default_type(org_or_repo_type: str, repo: str, commit: str, file_path: str, request: Request):
if org_or_repo_type in ["models", "datasets", "spaces"]:
repo_type: str = org_or_repo_type
org, repo = parse_org_repo(repo)
if org is None and repo is None:
return Response(content="This repository is not accessible.", status_code=404)
else:
repo_type: str = "models"
org, repo = org_or_repo_type, repo

if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit):
return Response(content="This repository is not accessible. ", status_code=404)
commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path, request)
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)

@app.get("/{org_repo}/resolve/{commit}/{file_path:path}")
async def file_proxy(org_repo: str, commit: str, file_path: str, request: Request, repo_type: str = "models"):
if "/" in org_repo and org_repo.count("/") != 1:
async def file_proxy_default_type(org_repo: str, commit: str, file_path: str, request: Request):
repo_type: str = "models"
org, repo = parse_org_repo(org_repo)
if org is None and repo is None:
return Response(content="This repository is not accessible.", status_code=404)
if "/" in org_repo:
org, repo = org_repo.split("/")
else:
org = None
repo = org_repo

if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
Expand All @@ -135,15 +160,12 @@ async def file_proxy(org_repo: str, commit: str, file_path: str, request: Reques
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)


@app.get("/{repo_type}/{org_repo}/{hash_file}")
async def cdn_file_proxy(org_repo: str, hash_file: str, request: Request, repo_type: str = "models"):
if "/" in org_repo and org_repo.count("/") != 1:
org, repo = parse_org_repo(org_repo)
if org is None and repo is None:
return Response(content="This repository is not accessible.", status_code=404)
if "/" in org_repo:
org, repo = org_repo.split("/")
else:
org = None
repo = org_repo

if not await check_proxy_rules_hf(app, repo_type, org, repo):
return Response(content="This repository is forbidden by the mirror. ", status_code=403)
Expand Down
17 changes: 13 additions & 4 deletions olah/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
import os
import glob
from typing import Literal, Optional
from typing import Literal, Optional, Tuple
import json
import httpx
from olah.configs import OlahConfig
Expand All @@ -15,6 +15,16 @@ def get_org_repo(org: Optional[str], repo: str) -> str:
org_repo = f"{org}/{repo}"
return org_repo

def parse_org_repo(org_repo: str) -> Tuple[Optional[str], Optional[str]]:
    """Split an ``org/repo`` identifier into its ``(org, repo)`` parts.

    Args:
        org_repo: Repository identifier, either ``"org/repo"`` or a bare
            ``"repo"`` with no organisation prefix.

    Returns:
        ``(org, repo)`` when the identifier contains exactly one ``/``;
        ``(None, repo)`` when it contains no ``/``;
        ``(None, None)`` when it contains more than one ``/``, which callers
        treat as an invalid identifier.
    """
    parts = org_repo.split("/")
    if len(parts) > 2:
        # More than one separator: not a valid "org/repo" identifier.
        return None, None
    if len(parts) == 2:
        org, repo = parts
        return org, repo
    # No separator at all: the repository has no organisation.
    return None, org_repo

def get_meta_save_path(repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str) -> str:
return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}")

Expand Down Expand Up @@ -69,9 +79,8 @@ async def get_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "
return await get_commit_hf_offline(app, repo_type, org, repo, commit)
try:
async with httpx.AsyncClient() as client:
response = await client.get(url,
timeout=WORKER_API_TIMEOUT)
if response.status_code != 200:
response = await client.get(url, timeout=WORKER_API_TIMEOUT, follow_redirects=True)
if response.status_code not in [200, 307]:
return await get_commit_hf_offline(app, repo_type, org, repo, commit)
obj = json.loads(response.text)
return obj.get("sha", None)
Expand Down

0 comments on commit 9ca8a1f

Please sign in to comment.