From 472dcffbc014738d37b766db7308e2ced76b6112 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Thu, 5 Sep 2024 13:47:54 +0800 Subject: [PATCH] _get_file_range_from_remote authorization bug fix --- olah/proxy/files.py | 11 ++++++++++- olah/server.py | 17 +++++++++++++++-- olah/utils/repo_utils.py | 3 ++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/olah/proxy/files.py b/olah/proxy/files.py index 0f17c59..1fd9b8f 100644 --- a/olah/proxy/files.py +++ b/olah/proxy/files.py @@ -238,7 +238,9 @@ async def _get_file_range_from_remote( end_pos: int, ): headers = {} + headers["authorization"] = remote_info.headers.get("authorization", None) headers["range"] = f"bytes={start_pos}-{end_pos - 1}" + chunk_bytes = 0 raw_data = b"" async with client.stream( @@ -537,7 +539,14 @@ async def _file_realtime_stream( else: if method.lower() == "head": async with httpx.AsyncClient() as client: - response = await client.request(method="head", url=hf_url,headers={},timeout=WORKER_API_TIMEOUT) + response = await client.request( + method="head", + url=hf_url, + headers={ + "authorization": request.headers.get("authorization", None) + }, + timeout=WORKER_API_TIMEOUT, + ) if "etag" in response.headers: response_headers["etag"] = response.headers["etag"] else: diff --git a/olah/server.py b/olah/server.py index aea230a..f9a6a7a 100644 --- a/olah/server.py +++ b/olah/server.py @@ -255,7 +255,13 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request): if org is None and repo is None: return error_repo_not_found() if not app.app_settings.config.offline: - new_commit = await get_newest_commit_hf(app, repo_type, org, repo) + new_commit = await get_newest_commit_hf( + app, + repo_type, + org, + repo, + authorization=request.headers.get("authorization", None), + ) if new_commit is None: return error_repo_not_found() else: @@ -269,11 +275,18 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request): authorization=request.headers.get("authorization", None), ) + @app.head("/api/{repo_type}/{org}/{repo}") @app.get("/api/{repo_type}/{org}/{repo}") async def meta_proxy(repo_type: str, org: str, repo: str, request: Request): if not app.app_settings.config.offline: - new_commit = await get_newest_commit_hf(app, repo_type, org, repo) + new_commit = await get_newest_commit_hf( + app, + repo_type, + org, + repo, + authorization=request.headers.get("authorization", None), + ) if new_commit is None: return error_repo_not_found() else: diff --git a/olah/utils/repo_utils.py b/olah/utils/repo_utils.py index a006859..0f7f784 100644 --- a/olah/utils/repo_utils.py +++ b/olah/utils/repo_utils.py @@ -167,6 +167,7 @@ async def get_newest_commit_hf( repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, + authorization: Optional[str] = None, ) -> Optional[str]: """ Retrieves the newest commit hash for a repository. @@ -188,7 +189,7 @@ async def get_newest_commit_hf( return await get_newest_commit_hf_offline(app, repo_type, org, repo) try: async with httpx.AsyncClient() as client: - response = await client.get(url, timeout=WORKER_API_TIMEOUT) + response = await client.get(url, headers={"authorization": authorization}, timeout=WORKER_API_TIMEOUT) if response.status_code != 200: return await get_newest_commit_hf_offline(app, repo_type, org, repo) obj = json.loads(response.text)