Refactor chunk offsets calculations
sir-sigurd committed Jun 12, 2024
1 parent b58a355 commit 8c05751
Showing 1 changed file with 16 additions and 17 deletions.
api/python/quilt3/data_transfer.py
@@ -11,6 +11,7 @@
 import shutil
 import stat
 import threading
+import typing as T
 import types
 import warnings
 from codecs import iterdecode
@@ -268,6 +269,16 @@ def get_checksum_chunksize(file_size: int) -> int:
     return chunksize


+def get_chunk_offsets(file_size: int) -> T.Optional[list[tuple[int, int]]]:
+    if file_size < CHECKSUM_MULTIPART_THRESHOLD:
+        return None
+    chunksize = get_checksum_chunksize(file_size)
+    return [
+        (start, min(start + chunksize, file_size))
+        for start in range(0, file_size, chunksize)
+    ]
+
+
 _EMPTY_STRING_SHA256 = hashlib.sha256(b'').digest()


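The new helper centralizes the chunking logic reused by the upload, copy, and checksum paths below. A minimal standalone sketch of its behaviour follows; the threshold and chunk-size constants here are illustrative assumptions, not the real CHECKSUM_MULTIPART_THRESHOLD or get_checksum_chunksize values from data_transfer.py.

# Standalone sketch of the new helper's behaviour. The constants below are
# illustrative assumptions, not the real values used in data_transfer.py.
from typing import List, Optional, Tuple

ASSUMED_THRESHOLD = 8 * 2**20   # hypothetical multipart threshold (8 MiB)
ASSUMED_CHUNKSIZE = 8 * 2**20   # hypothetical per-part size (8 MiB)

def sketch_chunk_offsets(file_size: int) -> Optional[List[Tuple[int, int]]]:
    # None signals "small object": callers take the single-request path.
    if file_size < ASSUMED_THRESHOLD:
        return None
    # Otherwise partition [0, file_size) into half-open (start, end) ranges.
    return [
        (start, min(start + ASSUMED_CHUNKSIZE, file_size))
        for start in range(0, file_size, ASSUMED_CHUNKSIZE)
    ]

print(sketch_chunk_offsets(5 * 2**20))    # None -> single request
print(sketch_chunk_offsets(20 * 2**20))   # three parts, the last one only 4 MiB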
@@ -303,7 +314,7 @@ def _copy_local_file(ctx: WorkerContext, size: int, src_path: str, dest_path: st
 def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str):
     s3_client = ctx.s3_client_provider.standard_client
 
-    if size < CHECKSUM_MULTIPART_THRESHOLD:
+    if (chunk_offsets := get_chunk_offsets(size)) is None:
         with ReadFileChunk.from_filename(src_path, 0, size, [ctx.progress]) as fd:
             resp = s3_client.put_object(
                 Body=fd,
@@ -323,10 +334,6 @@ def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str,
         )
         upload_id = resp['UploadId']
 
-        chunksize = get_checksum_chunksize(size)
-
-        chunk_offsets = list(range(0, size, chunksize))
-
         lock = Lock()
         remaining = len(chunk_offsets)
         parts = [None] * remaining
@@ -363,8 +370,7 @@ def upload_part(i, start, end):
                 checksum, _ = resp['ChecksumSHA256'].split('-', 1)
                 ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum)
 
-        for i, start in enumerate(chunk_offsets):
-            end = min(start + chunksize, size)
+        for i, (start, end) in enumerate(chunk_offsets):
             ctx.run(upload_part, i, start, end)


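With the refactor, the dispatch loop above unpacks precomputed (start, end) pairs instead of recomputing each part's end from chunksize. A rough sketch of that pattern, with the s3_client.upload_part call replaced by a local SHA-256 so it runs standalone; the part numbering and parts list mirror what is visible in the diff, everything else is assumed.

import base64
import hashlib

data = bytes(range(256)) * 4096                       # ~1 MiB of sample bytes
chunk_offsets = [(0, len(data) // 2), (len(data) // 2, len(data))]  # pretend: two parts

parts = [None] * len(chunk_offsets)

def upload_part(i, start, end):
    # The real worker sends data[start:end] via s3_client.upload_part(...);
    # here we only record the 1-based part number and a SHA-256 checksum.
    digest = hashlib.sha256(data[start:end]).digest()
    parts[i] = {
        "PartNumber": i + 1,
        "ChecksumSHA256": base64.b64encode(digest).decode(),
    }

for i, (start, end) in enumerate(chunk_offsets):
    upload_part(i, start, end)

print(parts)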
@@ -460,7 +466,7 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: s
 
     s3_client = ctx.s3_client_provider.standard_client
 
-    if size < CHECKSUM_MULTIPART_THRESHOLD:
+    if (chunk_offsets := get_chunk_offsets(size)) is None:
         params: Dict[str, Any] = dict(
             CopySource=src_params,
             Bucket=dest_bucket,
@@ -484,10 +490,6 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: s
         )
         upload_id = resp['UploadId']
 
-        chunksize = get_checksum_chunksize(size)
-
-        chunk_offsets = list(range(0, size, chunksize))
-
         lock = Lock()
         remaining = len(chunk_offsets)
         parts = [None] * remaining
@@ -525,8 +527,7 @@ def upload_part(i, start, end):
                 checksum, _ = resp['ChecksumSHA256'].split('-', 1)
                 ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum)
 
-        for i, start in enumerate(chunk_offsets):
-            end = min(start + chunksize, size)
+        for i, (start, end) in enumerate(chunk_offsets):
             ctx.run(upload_part, i, start, end)


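The server-side copy path mirrors the upload path; its upload_part_copy calls (outside this hunk) presumably translate each half-open (start, end) range into an inclusive CopySourceRange header. A hedged sketch of that mapping, assuming standard HTTP Range semantics:

def to_copy_source_range(start: int, end: int) -> str:
    # (start, end) is half-open, while HTTP Range / S3 CopySourceRange values
    # are inclusive on both ends, hence end - 1.
    return f"bytes={start}-{end - 1}"

assert to_copy_source_range(0, 8 * 2**20) == "bytes=0-8388607"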
@@ -1229,11 +1230,9 @@ def _process_url(src, size):
 
 def calculate_checksum_bytes(data: bytes) -> str:
     size = len(data)
-    chunksize = get_checksum_chunksize(size)
 
     hashes = []
-    for start in range(0, size, chunksize):
-        end = min(start + chunksize, size)
+    for start, end in get_chunk_offsets(size) or [(0, size)]:
         hashes.append(hashlib.sha256(data[start:end]).digest())
 
     hashes_hash = hashlib.sha256(b''.join(hashes)).digest()

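In calculate_checksum_bytes, the `or [(0, size)]` fallback keeps small (or empty) payloads on the same code path: when get_chunk_offsets returns None, the whole buffer is treated as a single chunk. A self-contained sketch of the visible hash-of-chunk-hashes step; the chunk size is an assumed parameter here, and the function's final encoding of hashes_hash lies outside the hunk and is not reproduced.

import hashlib

def sketch_hash_of_hashes(data: bytes, chunksize: int = 8 * 2**20) -> bytes:
    # chunksize is an illustrative default, not quilt3's derived value.
    size = len(data)
    offsets = [
        (start, min(start + chunksize, size))
        for start in range(0, size, chunksize)
    ] or [(0, size)]          # mirrors `get_chunk_offsets(size) or [(0, size)]`
    hashes = [hashlib.sha256(data[start:end]).digest() for start, end in offsets]
    return hashlib.sha256(b"".join(hashes)).digest()

print(sketch_hash_of_hashes(b"example payload").hex())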