Skip to content

Commit

Permalink
Allow setting bucket-specific s3 clients
Browse files Browse the repository at this point in the history
To support Min.io and passing in S3 credentials, allow users to pass their own S3 clients for accessing specific buckets.

TODO: how will this support bucket-to-bucket copy across different clients?
  • Loading branch information
Kevin Moore committed Oct 16, 2023
1 parent 6ba9401 commit 80827f9
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions api/python/quilt3/data_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,26 @@ class S3ClientProvider:
We assume that public buckets are read-only: write operations should always use S3ClientProvider.standard_client
"""

_client_map = {}

@classmethod
def set_s3_client(cls, bucket: str, client):
assert bucket is not None
cls._client_map[bucket] = client

def __init__(self):
self._use_unsigned_client = {} # f'{action}/{bucket}' -> use_unsigned_client_bool
self._standard_client = None
self._unsigned_client = None

@property
def standard_client(self):
if self._standard_client is None:
self._build_standard_client()
return self._standard_client
def get_standard_client(self, bucket):
mapped_client = self.__class__._client_map.get(bucket)
if mapped_client is not None:
return mapped_client
else:
if self._standard_client is None:
self._build_standard_client()
return self._standard_client

@property
def unsigned_client(self):
Expand All @@ -115,7 +125,7 @@ def get_correct_client(self, action: S3Api, bucket: str):
if self.should_use_unsigned_client(action, bucket):
return self.unsigned_client
else:
return self.standard_client
return self.get_standard_client(bucket)

def key(self, action: S3Api, bucket: str):
return f"{action}/{bucket}"
Expand Down Expand Up @@ -144,9 +154,9 @@ def find_correct_client(self, api_type, bucket, param_dict):
f"API '{api_type}' is not current supported. You may want to use S3ClientProvider.standard_client " \
f"instead "
check_fn = check_fn_mapper[api_type]
if check_fn(self.standard_client, param_dict):
if check_fn(self.get_standard_client(bucket), param_dict):
self.set_cache(api_type, bucket, use_unsigned=False)
return self.standard_client
return self.get_standard_client(bucket)
else:
if check_fn(self.unsigned_client, param_dict):
self.set_cache(api_type, bucket, use_unsigned=True)
Expand Down Expand Up @@ -261,7 +271,7 @@ def _copy_local_file(ctx, size, src_path, dest_path):


def _upload_file(ctx, size, src_path, dest_bucket, dest_key):
s3_client = ctx.s3_client_provider.standard_client
s3_client = ctx.s3_client_provider.get_standard_client(dest_bucket)

if size < s3_transfer_config.multipart_threshold:
with OSUtils().open_file_chunk_reader(src_path, 0, size, [ctx.progress]) as fd:
Expand Down Expand Up @@ -401,7 +411,7 @@ def _copy_remote_file(ctx, size, src_bucket, src_key, src_version,
VersionId=src_version
)

s3_client = ctx.s3_client_provider.standard_client
s3_client = ctx.s3_client_provider.get_standard_client(dest_bucket)

if size < s3_transfer_config.multipart_threshold:
params = dict(
Expand Down Expand Up @@ -652,7 +662,7 @@ def _calculate_etag(file_path):


def delete_object(bucket, key):
s3_client = S3ClientProvider().standard_client
s3_client = S3ClientProvider().get_standard_client(bucke)

s3_client.head_object(Bucket=bucket, Key=key) # Make sure it exists
s3_client.delete_object(Bucket=bucket, Key=key) # Actually delete it
Expand Down Expand Up @@ -768,7 +778,7 @@ def delete_url(src: PhysicalKey):
except FileNotFoundError:
pass
else:
s3_client = S3ClientProvider().standard_client
s3_client = S3ClientProvider().get_standard_client(src.bucket)
s3_client.delete_object(Bucket=src.bucket, Key=src.path)


Expand Down Expand Up @@ -831,7 +841,7 @@ def put_bytes(data: bytes, dest: PhysicalKey):
else:
if dest.version_id is not None:
raise ValueError("Cannot set VersionId on destination")
s3_client = S3ClientProvider().standard_client
s3_client = S3ClientProvider().get_standard_client(dest.bucket)
s3_client.put_object(
Bucket=dest.bucket,
Key=dest.path,
Expand Down Expand Up @@ -1203,7 +1213,7 @@ def select(src, query, meta=None, raw=False, **kwargs):

# S3 Select does not support anonymous access (as of Jan 2019)
# https://docs.aws.amazon.com/AmazonS3/latest/API/API_SelectObjectContent.html
s3_client = S3ClientProvider().standard_client
s3_client = S3ClientProvider().get_standard_client(src.bucket)
response = s3_client.select_object_content(**select_kwargs)

# we don't want multiple copies of large chunks of data hanging around.
Expand Down

0 comments on commit 80827f9

Please sign in to comment.