From 80827f93d136befcb8c7bf19114341f5a5c24cc9 Mon Sep 17 00:00:00 2001 From: Kevin Moore Date: Mon, 16 Oct 2023 08:59:35 -0700 Subject: [PATCH] Allow setting bucket-specific s3 clients To support Min.io and passing in S3 credentials, allow users to pass their own S3 clients for accessing specific buckets. TODO: how will this support bucket-to-bucket copy across different clients? --- api/python/quilt3/data_transfer.py | 38 +++++++++++++++++++----------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/api/python/quilt3/data_transfer.py b/api/python/quilt3/data_transfer.py index e8e9a29d5ff..d1a73f3a7b6 100644 --- a/api/python/quilt3/data_transfer.py +++ b/api/python/quilt3/data_transfer.py @@ -90,16 +90,26 @@ class S3ClientProvider: We assume that public buckets are read-only: write operations should always use S3ClientProvider.standard_client """ + _client_map = {} + + @classmethod + def set_s3_client(cls, bucket: str, client): + assert bucket is not None + cls._client_map[bucket] = client + def __init__(self): self._use_unsigned_client = {} # f'{action}/{bucket}' -> use_unsigned_client_bool self._standard_client = None self._unsigned_client = None - @property - def standard_client(self): - if self._standard_client is None: - self._build_standard_client() - return self._standard_client + def get_standard_client(self, bucket): + mapped_client = self.__class__._client_map.get(bucket) + if mapped_client is not None: + return mapped_client + else: + if self._standard_client is None: + self._build_standard_client() + return self._standard_client @property def unsigned_client(self): @@ -115,7 +125,7 @@ def get_correct_client(self, action: S3Api, bucket: str): if self.should_use_unsigned_client(action, bucket): return self.unsigned_client else: - return self.standard_client + return self.get_standard_client(bucket) def key(self, action: S3Api, bucket: str): return f"{action}/{bucket}" @@ -144,9 +154,9 @@ def find_correct_client(self, api_type, bucket, 
param_dict): f"API '{api_type}' is not current supported. You may want to use S3ClientProvider.standard_client " \ f"instead " check_fn = check_fn_mapper[api_type] - if check_fn(self.standard_client, param_dict): + if check_fn(self.get_standard_client(bucket), param_dict): self.set_cache(api_type, bucket, use_unsigned=False) - return self.standard_client + return self.get_standard_client(bucket) else: if check_fn(self.unsigned_client, param_dict): self.set_cache(api_type, bucket, use_unsigned=True) @@ -261,7 +271,7 @@ def _copy_local_file(ctx, size, src_path, dest_path): def _upload_file(ctx, size, src_path, dest_bucket, dest_key): - s3_client = ctx.s3_client_provider.standard_client + s3_client = ctx.s3_client_provider.get_standard_client(dest_bucket) if size < s3_transfer_config.multipart_threshold: with OSUtils().open_file_chunk_reader(src_path, 0, size, [ctx.progress]) as fd: @@ -401,7 +411,7 @@ def _copy_remote_file(ctx, size, src_bucket, src_key, src_version, VersionId=src_version ) - s3_client = ctx.s3_client_provider.standard_client + s3_client = ctx.s3_client_provider.get_standard_client(dest_bucket) if size < s3_transfer_config.multipart_threshold: params = dict( @@ -652,7 +662,7 @@ def _calculate_etag(file_path): def delete_object(bucket, key): - s3_client = S3ClientProvider().standard_client + s3_client = S3ClientProvider().get_standard_client(bucket) s3_client.head_object(Bucket=bucket, Key=key) # Make sure it exists s3_client.delete_object(Bucket=bucket, Key=key) # Actually delete it @@ -768,7 +778,7 @@ def delete_url(src: PhysicalKey): except FileNotFoundError: pass else: - s3_client = S3ClientProvider().standard_client + s3_client = S3ClientProvider().get_standard_client(src.bucket) s3_client.delete_object(Bucket=src.bucket, Key=src.path) @@ -831,7 +841,7 @@ def put_bytes(data: bytes, dest: PhysicalKey): else: if dest.version_id is not None: raise ValueError("Cannot set VersionId on destination") - s3_client = S3ClientProvider().standard_client + 
s3_client = S3ClientProvider().get_standard_client(dest.bucket) s3_client.put_object( Bucket=dest.bucket, Key=dest.path, @@ -1203,7 +1213,7 @@ def select(src, query, meta=None, raw=False, **kwargs): # S3 Select does not support anonymous access (as of Jan 2019) # https://docs.aws.amazon.com/AmazonS3/latest/API/API_SelectObjectContent.html - s3_client = S3ClientProvider().standard_client + s3_client = S3ClientProvider().get_standard_client(src.bucket) response = s3_client.select_object_content(**select_kwargs) # we don't want multiple copies of large chunks of data hanging around.