Refactor chunk offsets calculations #3998

Closed · wants to merge 1 commit
api/python/quilt3/data_transfer.py (33 changes: 16 additions & 17 deletions)
@@ -11,6 +11,7 @@
 import shutil
 import stat
 import threading
+import typing as T
 import types
 import warnings
 from codecs import iterdecode
@@ -268,6 +269,16 @@ def get_checksum_chunksize(file_size: int) -> int:
     return chunksize


+def get_chunk_offsets(file_size: int) -> T.Optional[list[tuple[int, int]]]:
+    if file_size < CHECKSUM_MULTIPART_THRESHOLD:
+        return None
+    chunksize = get_checksum_chunksize(file_size)
+    return [
+        (start, min(start + chunksize, file_size))
+        for start in range(0, file_size, chunksize)
+    ]
+
+
 _EMPTY_STRING_SHA256 = hashlib.sha256(b'').digest()
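For orientation, here is a small standalone sketch of the contract the new helper establishes: None below the multipart threshold, otherwise half-open (start, end) byte ranges covering the file. The threshold and chunk size below are placeholder values for illustration only; the real ones come from CHECKSUM_MULTIPART_THRESHOLD and get_checksum_chunksize() elsewhere in data_transfer.py.

# Standalone sketch with placeholder sizes (not quilt3's real constants).
THRESHOLD = 8 * 1024 * 1024
CHUNKSIZE = 8 * 1024 * 1024

def chunk_offsets_sketch(file_size):
    if file_size < THRESHOLD:
        return None  # single-part path: callers keep using put_object / copy_object
    return [(start, min(start + CHUNKSIZE, file_size))
            for start in range(0, file_size, CHUNKSIZE)]

assert chunk_offsets_sketch(1024) is None
assert chunk_offsets_sketch(20 * 1024 * 1024) == [
    (0, 8 * 1024 * 1024),
    (8 * 1024 * 1024, 16 * 1024 * 1024),
    (16 * 1024 * 1024, 20 * 1024 * 1024),
]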


@@ -303,7 +314,7 @@ def _copy_local_file(ctx: WorkerContext, size: int, src_path: str, dest_path: st
 def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str):
     s3_client = ctx.s3_client_provider.standard_client

-    if size < CHECKSUM_MULTIPART_THRESHOLD:
+    if (chunk_offsets := get_chunk_offsets(size)) is None:
         with ReadFileChunk.from_filename(src_path, 0, size, [ctx.progress]) as fd:
             resp = s3_client.put_object(
                 Body=fd,
@@ -323,10 +334,6 @@ def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str,
         )
         upload_id = resp['UploadId']

-        chunksize = get_checksum_chunksize(size)
-
-        chunk_offsets = list(range(0, size, chunksize))
-
         lock = Lock()
         remaining = len(chunk_offsets)
         parts = [None] * remaining
@@ -363,8 +370,7 @@ def upload_part(i, start, end):
                 checksum, _ = resp['ChecksumSHA256'].split('-', 1)
                 ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum)

-        for i, start in enumerate(chunk_offsets):
-            end = min(start + chunksize, size)
+        for i, (start, end) in enumerate(chunk_offsets):
             ctx.run(upload_part, i, start, end)
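The loop change here (and its twin in _copy_remote_file below) is mechanical: the end offset is now precomputed by get_chunk_offsets instead of being re-derived with min() at each call site. A quick standalone equivalence check, using illustrative sizes rather than quilt3's real chunk size:

size, chunksize = 20, 8  # illustrative sizes only

# Old call-site pattern: iterate over start offsets, derive end with min().
old_pairs = []
for start in range(0, size, chunksize):
    end = min(start + chunksize, size)
    old_pairs.append((start, end))

# New pattern: the helper hands back ready-made (start, end) pairs.
new_pairs = [(start, min(start + chunksize, size)) for start in range(0, size, chunksize)]

assert old_pairs == new_pairs == [(0, 8), (8, 16), (16, 20)]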


@@ -460,7 +466,7 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: s

     s3_client = ctx.s3_client_provider.standard_client

-    if size < CHECKSUM_MULTIPART_THRESHOLD:
+    if (chunk_offsets := get_chunk_offsets(size)) is None:
         params: Dict[str, Any] = dict(
             CopySource=src_params,
             Bucket=dest_bucket,
@@ -484,10 +490,6 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: s
         )
         upload_id = resp['UploadId']

-        chunksize = get_checksum_chunksize(size)
-
-        chunk_offsets = list(range(0, size, chunksize))
-
         lock = Lock()
         remaining = len(chunk_offsets)
         parts = [None] * remaining
@@ -525,8 +527,7 @@ def upload_part(i, start, end):
                 checksum, _ = resp['ChecksumSHA256'].split('-', 1)
                 ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum)

-        for i, start in enumerate(chunk_offsets):
-            end = min(start + chunksize, size)
+        for i, (start, end) in enumerate(chunk_offsets):
             ctx.run(upload_part, i, start, end)


@@ -1229,11 +1230,9 @@ def _process_url(src, size):

 def calculate_checksum_bytes(data: bytes) -> str:
     size = len(data)
-    chunksize = get_checksum_chunksize(size)

     hashes = []
-    for start in range(0, size, chunksize):
-        end = min(start + chunksize, size)
+    for start, end in get_chunk_offsets(size) or [(0, size)]:
         hashes.append(hashlib.sha256(data[start:end]).digest())

     hashes_hash = hashlib.sha256(b''.join(hashes)).digest()
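For small payloads get_chunk_offsets returns None, so the `or [(0, size)]` fallback hashes the whole buffer as a single chunk while keeping one code path. A minimal sketch of the same hashing shape, assuming only the structure visible in the diff above (the real calculate_checksum_bytes returns a string, so it presumably encodes the final digest further):

import hashlib

def checksum_sketch(data: bytes, chunk_offsets):
    # chunk_offsets is None for small payloads, mirroring get_chunk_offsets().
    size = len(data)
    hashes = [
        hashlib.sha256(data[start:end]).digest()
        for start, end in (chunk_offsets or [(0, size)])
    ]
    # Hash of the concatenated per-chunk hashes, as in the diff above.
    return hashlib.sha256(b''.join(hashes)).digest()

# With chunk_offsets=None the whole payload is hashed as a single chunk.
print(checksum_sketch(b'hello world', None).hex())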