Commit

bug fix
jstzwj committed Jul 8, 2024
1 parent fe933fd commit 0336380
Showing 9 changed files with 123 additions and 125 deletions.
2 changes: 1 addition & 1 deletion README_zh.md
@@ -40,7 +40,7 @@ pip install -e .
 python -m olah.server
 ```
 
-Then set the environment variable `HF_ENDPOINT` to the mirror site here http://localhost:8090
+Then set the environment variable `HF_ENDPOINT` to the mirror site (here: http://localhost:8090)
 Linux:
 ```bash
 export HF_ENDPOINT=http://localhost:8090
 ```
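For readers following along, a quick way to exercise the mirror from Python (a smoke test assumed by the editor, not part of this commit; note that `HF_ENDPOINT` must be set before `huggingface_hub` is imported, since the library reads it into its constants at import time):

```python
import os

# Assumes an olah server is already running on localhost:8090.
os.environ["HF_ENDPOINT"] = "http://localhost:8090"

from huggingface_hub import hf_hub_download  # resolves endpoints via HF_ENDPOINT

# Fetches through the mirror instead of huggingface.co directly.
path = hf_hub_download(repo_id="gpt2", filename="config.json")
print(path)
```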
9 changes: 7 additions & 2 deletions olah/configs.py
@@ -93,8 +93,8 @@ def __init__(self, path: Optional[str] = None) -> None:
         self.repos_path = "./repos"
         self.hf_url = "https://huggingface.co"
         self.hf_lfs_url = "https://cdn-lfs.huggingface.co"
-        self.mirror_url = "http://localhost:8090"
-        self.mirror_lfs_url = "http://localhost:8090"
+        self.mirror_url = f"http://{self.host}:{self.port}"
+        self.mirror_lfs_url = f"http://{self.host}:{self.port}"
 
         # accessibility
         self.offline = False
@@ -103,6 +103,11 @@ def __init__(self, path: Optional[str] = None) -> None:
 
         if path is not None:
             self.read_toml(path)
+
+        # refresh urls
+        self.mirror_url = f"http://{self.host}:{self.port}"
+        self.mirror_lfs_url = f"http://{self.host}:{self.port}"
+
 
     def empty_str(self, s: str) -> Optional[str]:
         if s == "":
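The first hunk stops hard-coding the mirror URLs; the second recomputes them after `read_toml`, so a host/port pair from the TOML file actually reaches the URLs used later for `location` header rewriting. A usage sketch (the TOML path and its values are hypothetical):

```python
from olah.configs import OlahConfig

# Without a config file the URLs come from the built-in host/port defaults.
cfg = OlahConfig()
print(cfg.mirror_url)        # e.g. http://localhost:8090

# Suppose config.toml sets host = "0.0.0.0" and port = 8091: the refresh
# after read_toml() keeps mirror_url in sync with those values.
cfg = OlahConfig("config.toml")
print(cfg.mirror_url)        # http://0.0.0.0:8091
```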
211 changes: 99 additions & 112 deletions olah/files.py
@@ -9,51 +9,79 @@
 from starlette.datastructures import URL
 
 from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT
-from olah.utls import check_cache_rules_hf, get_org_repo
-
-
-async def _file_head_cache_stream(app, save_path: str, request: Request):
-    with open(save_path, "r", encoding="utf-8") as f:
-        response_headers = json.loads(f.read())
-        if "location" in response_headers:
-            response_headers["location"] = response_headers["location"].replace(
-                app.app_settings.hf_url, app.app_settings.mirror_url
-            )
-        yield response_headers
-
+from olah.utils import check_cache_rules_hf, get_org_repo, make_dirs
+FILE_HEADER_TEMPLATE = {
+    "accept-ranges": "bytes",
+    "access-control-allow-origin": "*",
+    "cache-control": "public, max-age=604800, immutable, s-maxage=604800",
+    # "content-length": None,
+    # "content-type": "binary/octet-stream",
+    # "etag": None,
+    # "last-modified": None,
+}
 
+async def _file_cache_stream(save_path: str, head_path: str, request: Request):
+    if request.method.lower() == "head":
+        with open(head_path, "r", encoding="utf-8") as f:
+            response_headers = json.loads(f.read())
+        yield FILE_HEADER_TEMPLATE
+    elif request.method.lower() == "get":
+        yield FILE_HEADER_TEMPLATE
+    else:
+        raise Exception(f"Invalid Method type {request.method}")
+    with open(save_path, "rb") as f:
+        while True:
+            chunk = f.read(CHUNK_SIZE)
+            if not chunk:
+                break
+            yield chunk
 
-async def _file_head_realtime_stream(
-    app,
-    save_path: str,
-    url: str,
-    headers,
-    request: Request,
-    method="HEAD",
-    allow_cache=True,
+async def _file_realtime_stream(
+    app, save_path: str, head_path: str, url: str, request: Request, method="GET", allow_cache=True
 ):
-    async with httpx.AsyncClient() as client:
-        async with client.stream(
-            method=method,
-            url=url,
-            headers=headers,
-            timeout=WORKER_API_TIMEOUT,
-        ) as response:
-            response_headers = response.headers
-            response_headers = {k: v for k, v in response_headers.items()}
-            if allow_cache:
-                with open(save_path, "w", encoding="utf-8") as f:
-                    f.write(json.dumps(response_headers, ensure_ascii=False))
-            if "location" in response_headers:
-                response_headers["location"] = response_headers["location"].replace(
-                    app.app_settings.hf_url, app.app_settings.mirror_url
-                )
-            yield response_headers
-
-            async for raw_chunk in response.aiter_raw():
-                if not raw_chunk:
-                    continue
-                yield raw_chunk
+    request_headers = {k: v for k, v in request.headers.items()}
+    request_headers.pop("host")
+    temp_file_path = None
+    try:
+        async with httpx.AsyncClient() as client:
+            with tempfile.NamedTemporaryFile(mode="wb", delete=True) as temp_file:
+                if not allow_cache or request.method.lower() == "head":
+                    write_temp_file = False
+                else:
+                    write_temp_file = True
+                async with client.stream(
+                    method=method,
+                    url=url,
+                    headers=request_headers,
+                    timeout=WORKER_API_TIMEOUT,
+                ) as response:
+                    response_headers = response.headers
+                    response_headers_dict = {k: v for k, v in response_headers.items()}
+                    if allow_cache:
+                        if request.method.lower() == "head":
+                            with open(head_path, "w", encoding="utf-8") as f:
+                                f.write(json.dumps(response_headers_dict, ensure_ascii=False))
+                    if "location" in response_headers:
+                        response_headers["location"] = response_headers["location"].replace(
+                            app.app_settings.hf_lfs_url, app.app_settings.mirror_lfs_url
+                        )
+                    yield response_headers
 
+                    async for raw_chunk in response.aiter_raw():
+                        if not raw_chunk:
+                            continue
+                        if write_temp_file:
+                            temp_file.write(raw_chunk)
+                        yield raw_chunk
+                if not allow_cache:
+                    temp_file_path = None
+                else:
+                    temp_file_path = temp_file.name
+                if temp_file_path is not None:
+                    shutil.copyfile(temp_file_path, save_path)
+    finally:
+        if temp_file_path is not None and os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
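The rewritten `_file_realtime_stream` buffers the upstream body in a `NamedTemporaryFile` and only copies it into the cache once the transfer finishes, so an interrupted download can never turn into a cache hit; with `delete=True` the temp file also cleans itself up if the client disconnects mid-stream. The same pattern in isolation (a simplified sketch by the editor, not the committed code):

```python
import shutil
import tempfile

def cache_when_complete(chunks, save_path: str) -> None:
    # delete=True removes the temp file on context exit, so the copy into
    # the cache must happen while the "with" block is still open.
    with tempfile.NamedTemporaryFile(mode="wb", delete=True) as tmp:
        for chunk in chunks:
            tmp.write(chunk)
        tmp.flush()  # ensure everything is on disk before copying
        shutil.copyfile(tmp.name, save_path)
```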

 async def file_head_generator(
     app,
@@ -64,84 +92,39 @@ async def file_head_generator(
     file_path: str,
     request: Request,
 ):
-    headers = {k: v for k, v in request.headers.items()}
-    headers.pop("host")
 
     # save
     repos_path = app.app_settings.repos_path
+    head_path = os.path.join(
+        repos_path, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
+    )
     save_path = os.path.join(
-        repos_path, f"heads/{repo_type}/{org}/{repo}/resolve_head/{commit}/{file_path}"
+        repos_path, f"files/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
     )
-    save_dir = os.path.dirname(save_path)
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir, exist_ok=True)
+    make_dirs(head_path)
+    make_dirs(save_path)
 
-    use_cache = os.path.exists(save_path)
+    use_cache = os.path.exists(head_path) and os.path.exists(save_path)
    allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
 
     # proxy
     if use_cache:
-        return _file_head_cache_stream(app=app, save_path=save_path, request=request)
+        return _file_cache_stream(save_path=save_path, head_path=head_path, request=request)
     else:
         if repo_type == "models":
             url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}"
         else:
             url = f"{app.app_settings.hf_url}/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
-        return _file_head_realtime_stream(
+        return _file_realtime_stream(
             app=app,
             save_path=save_path,
+            head_path=head_path,
             url=url,
-            headers=headers,
+            request=request,
             method="HEAD",
             allow_cache=allow_cache,
         )


-async def _file_cache_stream(save_path: str, request: Request):
-    yield request.headers
-    with open(save_path, "rb") as f:
-        while True:
-            chunk = f.read(CHUNK_SIZE)
-            if not chunk:
-                break
-            yield chunk
-
-
-async def _file_realtime_stream(
-    save_path: str, url: str, headers, request: Request, method="GET", allow_cache=True
-):
-    temp_file_path = None
-    try:
-        async with httpx.AsyncClient() as client:
-            with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file:
-                if not allow_cache:
-                    temp_file = open(os.devnull, "wb")
-                async with client.stream(
-                    method=method,
-                    url=url,
-                    headers=headers,
-                    timeout=WORKER_API_TIMEOUT,
-                ) as response:
-                    response_headers = response.headers
-                    yield response_headers
-
-                    async for raw_chunk in response.aiter_raw():
-                        if not raw_chunk:
-                            continue
-                        temp_file.write(raw_chunk)
-                        yield raw_chunk
-                if not allow_cache:
-                    temp_file_path = None
-                else:
-                    temp_file_path = temp_file.name
-                if temp_file_path is not None:
-                    shutil.copyfile(temp_file_path, save_path)
-    finally:
-        if temp_file_path is not None:
-            os.remove(temp_file_path)
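The two helpers deleted above were the pre-refactor GET path; the unified replacements at the top of the file now handle both HEAD and GET. Old and new generators share one protocol: the first item yielded is a header mapping, every later item is a chunk of body bytes. A consumer-side sketch (an assumption about the FastAPI layer, which is not shown in this diff):

```python
from typing import AsyncIterator, Tuple

async def split_stream(gen: AsyncIterator) -> Tuple[dict, bytes]:
    headers = None
    body = bytearray()
    async for item in gen:
        if headers is None:
            headers = dict(item)   # first yield: response headers
        else:
            body.extend(item)      # later yields: raw body chunks
    return headers, bytes(body)
```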


 async def file_get_generator(
     app,
     repo_type: Literal["models", "datasets"],
@@ -151,32 +134,33 @@ async def file_get_generator(
     file_path: str,
     request: Request,
 ):
-    headers = {k: v for k, v in request.headers.items()}
-    headers.pop("host")
     # save
     repos_path = app.app_settings.repos_path
+    head_path = os.path.join(
+        repos_path, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
+    )
     save_path = os.path.join(
         repos_path, f"files/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
     )
-    save_dir = os.path.dirname(save_path)
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir, exist_ok=True)
+    make_dirs(head_path)
+    make_dirs(save_path)
 
-    use_cache = os.path.exists(save_path)
+    use_cache = os.path.exists(head_path) and os.path.exists(save_path)
     allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
 
     # proxy
     if use_cache:
-        return _file_cache_stream(save_path=save_path, request=request)
+        return _file_cache_stream(save_path=save_path, head_path=head_path, request=request)
     else:
         if repo_type == "models":
             url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}"
         else:
             url = f"{app.app_settings.hf_url}/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
         return _file_realtime_stream(
             app=app,
             save_path=save_path,
+            head_path=head_path,
             url=url,
-            headers=headers,
+            request=request,
             method="GET",
             allow_cache=allow_cache,
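After this change every cached object is split across two entries, headers under `heads/` and the payload under `files/`, and a cache hit requires both to exist. The layout with illustrative values:

```python
import os

repos_path = "./repos"                          # app.app_settings.repos_path
repo_type, org, repo = "models", "org", "repo"  # hypothetical repository
commit, file_path = "main", "config.json"

head_path = os.path.join(
    repos_path, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
)
save_path = os.path.join(
    repos_path, f"files/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
)

# Both entries must exist before the cached copy is served.
use_cache = os.path.exists(head_path) and os.path.exists(save_path)
```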
@@ -197,27 +181,30 @@ async def cdn_file_get_generator(
     org_repo = get_org_repo(org, repo)
     # save
     repos_path = app.app_settings.repos_path
+    head_path = os.path.join(
+        repos_path, f"heads/{repo_type}/{org}/{repo}/cdn/{file_hash}"
+    )
     save_path = os.path.join(
-        repos_path, f"files/{repo_type}/cdn/{org}/{repo}/{file_hash}"
+        repos_path, f"files/{repo_type}/{org}/{repo}/cdn/{file_hash}"
     )
-    save_dir = os.path.dirname(save_path)
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir, exist_ok=True)
+    make_dirs(head_path)
+    make_dirs(save_path)
 
-    use_cache = os.path.exists(save_path)
+    use_cache = os.path.exists(head_path) and os.path.exists(save_path)
     allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
 
     # proxy
     if use_cache:
         return _file_cache_stream(save_path=save_path, request=request)
     else:
         redirected_url = str(request.url)
-        redirected_url = redirected_url.replace(app.app_settings.hf_lfs_url, app.app_settings.mirror_lfs_url)
+        redirected_url = redirected_url.replace(app.app_settings.mirror_lfs_url, app.app_settings.hf_lfs_url)
 
         return _file_realtime_stream(
             app=app,
             save_path=save_path,
+            head_path=head_path,
             url=str(redirected_url),
-            headers=headers,
+            request=request,
             method="GET",
             allow_cache=allow_cache,
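The one-line change to `redirected_url` looks like the headline fix of this "bug fix" commit: the incoming request URL already points at the mirror, so it must be rewritten back to the upstream LFS CDN, not the other way around. With illustrative values (the hash path is made up):

```python
mirror_lfs_url = "http://localhost:8090"
hf_lfs_url = "https://cdn-lfs.huggingface.co"
request_url = "http://localhost:8090/repos/ab/cd/0123abcd"

# Old (broken): replacing hf_lfs_url in a URL that does not contain it
# was a no-op, so the proxy tried to fetch the file from itself.
assert request_url.replace(hf_lfs_url, mirror_lfs_url) == request_url

# New: map the mirror prefix back to the upstream CDN.
upstream = request_url.replace(mirror_lfs_url, hf_lfs_url)
assert upstream == "https://cdn-lfs.huggingface.co/repos/ab/cd/0123abcd"
```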
10 changes: 5 additions & 5 deletions olah/lfs.py
@@ -8,6 +8,7 @@
 import pytz
 
 from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT
+from olah.utils import make_dirs
 
 
 async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request):
@@ -17,8 +18,7 @@ async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request):
     # save
     repos_path = app.app_settings.repos_path
     save_dir = os.path.join(repos_path, f"lfs/{repo_type}/{save_path}")
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir, exist_ok=True)
+    make_dirs(save_dir)
 
     # lfs meta
     lfs_meta_path = os.path.join(save_dir, "meta.json")
@@ -109,7 +109,7 @@ async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request):
         try:
             temp_file_path = None
             async with httpx.AsyncClient() as client:
-                with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file:
+                with tempfile.NamedTemporaryFile(mode="wb", delete=True) as temp_file:
                     headers["range"] = f"bytes={block_start_pos}-{block_end_pos - 1}"
                     async with client.stream(
                         method="GET", url=lfs_url,
@@ -145,8 +145,8 @@ async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request):
                         if raw_bytes >= block_end_pos - block_start_pos:
                             break
                     temp_file_path = temp_file.name
-            shutil.copyfile(temp_file_path, save_path)
+                    shutil.copyfile(temp_file_path, save_path)
         finally:
-            if temp_file_path is not None:
+            if temp_file_path is not None and os.path.exists(temp_file_path):
                 os.remove(temp_file_path)
         cur_block += 1
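`lfs_get_generator` downloads large LFS objects block by block with HTTP range requests. The switch to `delete=True` and the relocated `shutil.copyfile` mean each block's temp file is cleaned up even on failure, and the copy now happens while the temp file still exists. A sketch of the range arithmetic implied by the `range` header above (the block size is an assumption; the real value lives in `olah.constants.LFS_FILE_BLOCK`):

```python
LFS_FILE_BLOCK = 64 * 1024 * 1024  # assumed 64 MiB; see olah.constants

def block_range(cur_block: int, file_size: int) -> str:
    # Byte range for one block, formatted like the "range" header above.
    start = cur_block * LFS_FILE_BLOCK
    end = min((cur_block + 1) * LFS_FILE_BLOCK, file_size)
    return f"bytes={start}-{end - 1}"

print(block_range(0, 200 * 1024 * 1024))  # bytes=0-67108863
```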
2 changes: 1 addition & 1 deletion olah/meta.py
@@ -11,7 +11,7 @@
 from olah.configs import OlahConfig
 from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT
 
-from olah.utls import check_cache_rules_hf
+from olah.utils import check_cache_rules_hf
 
 async def meta_cache_generator(app: FastAPI, save_path: str):
     yield {}
2 changes: 1 addition & 1 deletion olah/server.py
@@ -13,7 +13,7 @@
 from olah.files import cdn_file_get_generator, file_get_generator, file_head_generator
 from olah.lfs import lfs_get_generator
 from olah.meta import meta_generator
-from olah.utls import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf
+from olah.utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf
 
 app = FastAPI(debug=False)
8 changes: 8 additions & 0 deletions olah/utls.py → olah/utils.py
@@ -98,3 +98,11 @@ async def check_cache_rules_hf(app, repo_type: Optional[Literal["models", "datasets", …
     config: OlahConfig = app.app_settings.config
     org_repo = get_org_repo(org, repo)
     return config.cache.allow(f"{org_repo}")
+
+def make_dirs(path: str):
+    if os.path.isdir(path):
+        save_dir = path
+    else:
+        save_dir = os.path.dirname(path)
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir, exist_ok=True)
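The new helper accepts either a file path or an existing directory. Worth noting: a directory that does not exist yet is indistinguishable from a file path, so only its parent gets created. A quick illustration (paths are hypothetical):

```python
import os
from olah.utils import make_dirs

# File path: the parent directory chain is created.
make_dirs("./repos/heads/models/org/repo/resolve/main/config.json")
assert os.path.isdir("./repos/heads/models/org/repo/resolve/main")

# Existing directory: effectively a no-op.
make_dirs("./repos")

# Not-yet-existing directory: os.path.isdir() is False, so only the
# parent ("./repos/lfs") is created, not "./repos/lfs/models" itself.
make_dirs("./repos/lfs/models")
```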