Skip to content

Commit

Permalink
_test_header_block bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 13, 2024
1 parent e34cbce commit 7df7586
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 91 deletions.
108 changes: 20 additions & 88 deletions olah/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,6 @@
# "last-modified": None,
}

async def _file_cache_stream(save_path: str, head_path: str, request: Request):
if request.method.lower() == "head":
with open(head_path, "r", encoding="utf-8") as f:
response_headers = json.loads(f.read())
response_headers = {k.lower():v for k, v in response_headers.items()}
new_headers = {k.lower():v for k, v in FILE_HEADER_TEMPLATE.items()}
new_headers["content-type"] = response_headers["content-type"]
new_headers["content-length"] = response_headers["content-length"]
if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
new_headers["etag"] = response_headers["etag"]
yield new_headers
elif request.method.lower() == "get":
yield FILE_HEADER_TEMPLATE
else:
raise Exception(f"Invalid Method type {request.method}")
with open(save_path, "rb") as f:
while True:
chunk = f.read(CHUNK_SIZE)
if not chunk:
break
yield chunk

async def _get_redirected_url(client: httpx.AsyncClient, method: str, url: str, headers: Dict[str, str]):
async with client.stream(
method=method,
Expand Down Expand Up @@ -86,7 +59,6 @@ async def _file_full_header(
url: str,
headers: Dict[str, str],
allow_cache: bool,
commit: Optional[str] = None,
):
if os.path.exists(head_path):
with open(head_path, "r", encoding="utf-8") as f:
Expand Down Expand Up @@ -118,49 +90,11 @@ async def _file_full_header(
new_headers["etag"] = response_headers_dict["etag"]
return new_headers

async def _file_header(
app,
save_path: str,
head_path: str,
client: httpx.AsyncClient,
method: str,
url: str,
headers: Dict[str, str],
allow_cache: bool,
commit: Optional[str] = None,
):
if os.path.exists(head_path):
with open(head_path, "r", encoding="utf-8") as f:
response_headers = json.loads(f.read())
response_headers_dict = {k.lower():v for k, v in response_headers.items()}
else:
async with client.stream(
method=method,
url=url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
if allow_cache and method.lower() == "head":
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))

new_headers = {}
new_headers["content-type"] = response_headers_dict["content-type"]
new_headers["content-length"] = response_headers_dict["content-length"]
if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
new_headers["etag"] = response_headers_dict["etag"]
return new_headers

async def _get_file_block_from_cache(cache_file: OlahCache, block_index: int):
return cache_file.read_block(block_index)
raw_block = cache_file.read_block(block_index)
return raw_block

async def _get_file_block_from_remote(client, remote_info: RemoteInfo, cache_file: OlahCache, block_index: int):
async def _get_file_block_from_remote(client: httpx.AsyncClient, remote_info: RemoteInfo, cache_file: OlahCache, block_index: int):
block_start_pos = block_index * cache_file._get_block_size()
block_end_pos = min(
(block_index + 1) * cache_file._get_block_size(), cache_file._get_file_size()
Expand All @@ -183,8 +117,7 @@ async def _get_file_block_from_remote(client, remote_info: RemoteInfo, cache_fil
raise Exception(f"The block is incomplete. Expected-{block_end_pos - block_start_pos}. Accepted-{len(raw_block)}")
if len(raw_block) < cache_file.header.block_size:
raw_block += b"\x00" * (cache_file.header.block_size - len(raw_block))
# print(len(raw_block))
return bytes(raw_block)
return raw_block

async def _file_chunk_get(
app,
Expand All @@ -196,7 +129,6 @@ async def _file_chunk_get(
headers: Dict[str, str],
allow_cache: bool,
file_size: int,
commit: Optional[str] = None,
):
# Redirect Chunks
if os.path.exists(save_path):
Expand All @@ -220,7 +152,7 @@ async def _file_chunk_get(
)
if cache_file.has_block(cur_block):
raw_block = await _get_file_block_from_cache(
cache_file, cur_block
cache_file, cur_block,
)
else:
raw_block = await _get_file_block_from_remote(
Expand All @@ -229,17 +161,19 @@ async def _file_chunk_get(
cache_file,
cur_block,
)
cache_file.write_block(cur_block, raw_block)

if len(raw_block) != cache_file._get_block_size():
raise Exception(f"The size of raw block {len(raw_block)} is different from blocksize {cache_file._get_block_size()}.")

s = cur_pos - block_start_pos
e = block_end_pos - block_start_pos
chunk = raw_block[s:e]

if len(chunk) != 0:
yield chunk
yield bytes(chunk)
cur_pos += len(chunk)

if len(raw_block) != cache_file._get_block_size():
raise Exception(f"The size of raw block {len(raw_block)} is different from blocksize {cache_file._get_block_size()}.")
if not cache_file.has_block(cur_block) and allow_cache:
cache_file.write_block(cur_block, raw_block)

cur_block += 1
finally:
cache_file.close()
Expand All @@ -254,7 +188,6 @@ async def _file_chunk_head(
headers: Dict[str, str],
allow_cache: bool,
file_size: int,
commit: Optional[str] = None,
):
async with client.stream(
method=method,
Expand Down Expand Up @@ -292,13 +225,12 @@ async def _file_realtime_stream(
url=redirect_loc,
headers=request_headers,
allow_cache=allow_cache,
commit=commit,
)
file_size = int(head_info["content-length"])
response_headers = {k: v for k,v in head_info.items()}
if "range" in request_headers:
start_pos, end_pos = parse_range_params(request_headers.get("range", f"bytes={0}-{file_size}"), file_size)
response_headers["content-length"] = end_pos - start_pos
response_headers["content-length"] = str(end_pos - start_pos)
if commit is not None:
response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
yield response_headers
Expand All @@ -313,10 +245,9 @@ async def _file_realtime_stream(
headers=request_headers,
allow_cache=allow_cache,
file_size=file_size,
commit=commit,
):
yield each_chunk
else:
elif method.lower() == "head":
async for each_chunk in _file_chunk_head(
app=app,
save_path=save_path,
Expand All @@ -327,9 +258,10 @@ async def _file_realtime_stream(
headers=request_headers,
allow_cache=allow_cache,
file_size=0,
commit=commit,
):
yield each_chunk
else:
raise Exception(f"Unsupported method: {method}")


async def file_head_generator(
Expand All @@ -353,7 +285,7 @@ async def file_head_generator(
make_dirs(head_path)
make_dirs(save_path)

use_cache = os.path.exists(head_path) and os.path.exists(save_path)
# use_cache = os.path.exists(head_path) and os.path.exists(save_path)
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

# proxy
Expand Down Expand Up @@ -394,7 +326,7 @@ async def file_get_generator(
make_dirs(head_path)
make_dirs(save_path)

use_cache = os.path.exists(head_path) and os.path.exists(save_path)
# use_cache = os.path.exists(head_path) and os.path.exists(save_path)
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

# proxy
Expand Down Expand Up @@ -437,7 +369,7 @@ async def cdn_file_get_generator(
make_dirs(head_path)
make_dirs(save_path)

use_cache = os.path.exists(head_path) and os.path.exists(save_path)
# use_cache = os.path.exists(head_path) and os.path.exists(save_path)
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

# proxy
Expand Down
7 changes: 4 additions & 3 deletions olah/utils/olah_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

CURRENT_OLAH_CACHE_VERSION = 8
DEFAULT_BLOCK_MASK_MAX = 1024 * 1024
DEFAULT_BLOCK_SIZE = 8 * 1024 * 1024
DEFAULT_BLOCK_SIZE = 16 * 1024 * 1024


class OlahCacheHeader(object):
Expand Down Expand Up @@ -183,15 +183,16 @@ def _set_header_block(self, block_index: int):

def _test_header_block(self, block_index: int):
with self.header_lock:
self.header.block_mask.test(block_index)
result = self.header.block_mask.test(block_index)
return result

def flush(self):
if not self.is_open:
raise Exception("This file has been close.")
self._flush_header()

def has_block(self, block_index: int) -> bool:
self._test_header_block(block_index)
return self._test_header_block(block_index)

def read_block(self, block_index: int) -> Optional[bytes]:
if not self.is_open:
Expand Down

0 comments on commit 7df7586

Please sign in to comment.