Close S3 client when closing S3Transfer
[DDB-1160]
kmichel-aiven committed Jul 22, 2024
1 parent 4514a9e · commit 7ef096d
Showing 3 changed files with 110 additions and 76 deletions.
175 changes: 100 additions & 75 deletions rohmu/object_storage/s3.py
@@ -127,72 +127,97 @@ def __init__(
) -> None:
super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info)
self.bucket_name = bucket_name
self.location = ""
self.region = region
timeouts: dict[str, Any] = {}
if connect_timeout:
timeouts["connect_timeout"] = connect_timeout
if read_timeout:
timeouts["read_timeout"] = read_timeout
if not host or not port:
custom_config: dict[str, Any] = {**timeouts}
if proxy_info:
proxy_url = get_proxy_url(proxy_info)
custom_config["proxies"] = {"https": proxy_url}
if use_dualstack_endpoint is True:
custom_config["use_dualstack_endpoint"] = True
with self._get_session() as session:
self.s3_client = create_s3_client(
session=session,
config=botocore.config.Config(**custom_config),
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=region,
)
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.host = host
self.port = port
self.addressing_style = addressing_style
self.is_secure = is_secure
self.is_verify_tls = is_verify_tls
self.cert_path = cert_path
self.proxy_info = proxy_info
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
self.aws_session_token = aws_session_token
self.use_dualstack_endpoint = use_dualstack_endpoint
self.multipart_chunk_size = segment_size
self.encrypted = encrypted
self.s3_client: Optional[S3Client] = None
self.location = ""
if not self.host or not self.port:
if self.region and self.region != "us-east-1":
self.location = self.region
else:
scheme = "https" if is_secure else "http"
custom_url = f"{scheme}://{host}:{port}"
if self.region:
signature_version = "s3v4"
self.location = self.region
else:
signature_version = "s3"
proxies: Optional[dict[str, str]] = None
if proxy_info:
proxies = {"https": get_proxy_url(proxy_info)}
boto_config = botocore.client.Config(
s3={"addressing_style": S3AddressingStyle(addressing_style).value},
signature_version=signature_version,
proxies=proxies,
retries={
"max_attempts": 10,
"mode": "standard",
},
**timeouts,
)
if not is_verify_tls and cert_path is not None:
if not self.is_verify_tls and self.cert_path is not None:
raise ValueError("cert_path is set but is_verify_tls is False")
self.check_or_create_bucket()
self.log.debug("S3Transfer initialized")

with self._get_session() as session:
self.s3_client = create_s3_client(
session=session,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
config=boto_config,
endpoint_url=custom_url,
region_name=region,
verify=str(cert_path) if cert_path is not None and is_verify_tls else is_verify_tls,
def get_client(self) -> S3Client:
if self.s3_client is None:
timeouts: dict[str, Any] = {}
if self.connect_timeout:
timeouts["connect_timeout"] = self.connect_timeout
if self.read_timeout:
timeouts["read_timeout"] = self.read_timeout
if not self.host or not self.port:
custom_config: dict[str, Any] = {**timeouts}
if self.proxy_info:
proxy_url = get_proxy_url(self.proxy_info)
custom_config["proxies"] = {"https": proxy_url}
if self.use_dualstack_endpoint is True:
custom_config["use_dualstack_endpoint"] = True
with self._get_session() as session:
self.s3_client = create_s3_client(
session=session,
config=botocore.config.Config(**custom_config),
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
region_name=self.region,
)
else:
scheme = "https" if self.is_secure else "http"
custom_url = f"{scheme}://{self.host}:{self.port}"
if self.region:
signature_version = "s3v4"
else:
signature_version = "s3"
proxies: Optional[dict[str, str]] = None
if self.proxy_info:
proxies = {"https": get_proxy_url(self.proxy_info)}
boto_config = botocore.client.Config(
s3={"addressing_style": S3AddressingStyle(self.addressing_style).value},
signature_version=signature_version,
proxies=proxies,
retries={
"max_attempts": 10,
"mode": "standard",
},
**timeouts,
)
with self._get_session() as session:
self.s3_client = create_s3_client(
session=session,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
config=boto_config,
endpoint_url=custom_url,
region_name=self.region,
verify=str(self.cert_path)
if self.cert_path is not None and self.is_verify_tls
else self.is_verify_tls,
)
return self.s3_client

self.check_or_create_bucket()

self.multipart_chunk_size = segment_size
self.encrypted = encrypted
self.log.debug("S3Transfer initialized")
def close(self) -> None:
if self.s3_client is not None:
self.s3_client.close()
self.s3_client = None

# It is advantageous to share the Session as much as possible since the very
# large service model files (eg botocore/data/ec2/2016-11-15/service-2.json)
@@ -228,7 +253,7 @@ def _copy_file_from_bucket(
destination_path = self.format_key_for_backend(destination_key, remove_slash_prefix=True)
self.stats.operation(StorageOperation.copy_file)
try:
self.s3_client.copy_object(
self.get_client().copy_object(
Bucket=self.bucket_name,
CopySource=source_path,
Key=destination_path,
@@ -257,7 +282,7 @@ def get_metadata_for_key(self, key: str) -> Metadata:
def _metadata_for_key(self, key: str) -> Metadata:
self.stats.operation(StorageOperation.metadata_for_key)
try:
response = self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
response = self.get_client().head_object(Bucket=self.bucket_name, Key=key)
except botocore.exceptions.ClientError as ex:
status_code = ex.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
if status_code == 404:
@@ -272,13 +297,13 @@ def delete_key(self, key: str) -> None:
self.log.debug("Deleting key: %r", path)
self._metadata_for_key(path) # check that key exists
self.stats.operation(StorageOperation.delete_key)
self.s3_client.delete_object(Bucket=self.bucket_name, Key=path)
self.get_client().delete_object(Bucket=self.bucket_name, Key=path)
self.notifier.object_deleted(key=key)

def delete_keys(self, keys: Collection[str]) -> None:
self.stats.operation(StorageOperation.delete_key, count=len(keys))
for batch in batched(keys, 1000): # Cannot delete more than 1000 objects at a time
self.s3_client.delete_objects(
self.get_client().delete_objects(
Bucket=self.bucket_name,
Delete={"Objects": [{"Key": self.format_key_for_backend(key, remove_slash_prefix=True)} for key in batch]},
)
@@ -303,7 +328,7 @@ def iter_key(
if continuation_token:
args["ContinuationToken"] = continuation_token
self.stats.operation(StorageOperation.iter_key)
response = self.s3_client.list_objects_v2(**args)
response = self.get_client().list_objects_v2(**args)

for item in response.get("Contents", []):
if with_metadata:
@@ -345,7 +370,7 @@ def _get_object_stream(self, key: str, byte_range: Optional[tuple[int, int]]) ->
# Actual usage is accounted for in
# _read_object_to_fileobj, although that omits the initial
# get_object call if it fails.
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=path, **kwargs)
response = self.get_client().get_object(Bucket=self.bucket_name, Key=path, **kwargs)
except botocore.exceptions.ClientError as ex:
status_code = ex.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
if status_code == 404:
@@ -387,7 +412,7 @@ def get_file_size(self, key: str) -> int:
path = self.format_key_for_backend(key, remove_slash_prefix=True)
self.stats.operation(StorageOperation.get_file_size)
try:
response = self.s3_client.head_object(Bucket=self.bucket_name, Key=path)
response = self.get_client().head_object(Bucket=self.bucket_name, Key=path)
return int(response["ContentLength"])
except botocore.exceptions.ClientError as ex:
if ex.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
@@ -420,7 +445,7 @@ def multipart_upload_file_object(

self.stats.operation(StorageOperation.create_multipart_upload)
try:
cmu_response = self.s3_client.create_multipart_upload(**args)
cmu_response = self.get_client().create_multipart_upload(**args)
except botocore.exceptions.ClientError as ex:
raise StorageError(f"Failed to initiate multipart upload for {path}") from ex

@@ -434,7 +459,7 @@ def multipart_upload_file_object(
start_of_part_upload = time.monotonic()
self.stats.operation(StorageOperation.store_file, size=len(data))
try:
cup_response = self.s3_client.upload_part(
cup_response = self.get_client().upload_part(
Body=data,
Bucket=self.bucket_name,
Key=path,
@@ -445,7 +470,7 @@ def multipart_upload_file_object(
self.log.exception("Uploading part %d for %s failed", part_number, path)
self.stats.operation(StorageOperation.multipart_aborted)
try:
self.s3_client.abort_multipart_upload(
self.get_client().abort_multipart_upload(
Bucket=self.bucket_name,
Key=path,
UploadId=mp_id,
@@ -475,7 +500,7 @@ def multipart_upload_file_object(

self.stats.operation(StorageOperation.multipart_complete)
try:
self.s3_client.complete_multipart_upload(
self.get_client().complete_multipart_upload(
Bucket=self.bucket_name,
Key=path,
MultipartUpload={"Parts": parts},
@@ -484,7 +509,7 @@ def multipart_upload_file_object(
except botocore.exceptions.ClientError as ex:
try:
self.stats.operation(StorageOperation.multipart_aborted)
self.s3_client.abort_multipart_upload(
self.get_client().abort_multipart_upload(
Bucket=self.bucket_name,
Key=path,
UploadId=mp_id,
@@ -529,7 +554,7 @@ def store_file_from_memory(
if mimetype is not None:
args["ContentType"] = mimetype
self.stats.operation(StorageOperation.store_file, size=len(data))
self.s3_client.put_object(**args)
self.get_client().put_object(**args)
self.notifier.object_created(key=key, size=len(data), metadata=sanitized_metadata)

def store_file_object(
@@ -565,7 +590,7 @@ def check_or_create_bucket(self) -> None:
create_bucket = False
self.stats.operation(StorageOperation.head_request)
try:
self.s3_client.head_bucket(Bucket=self.bucket_name)
self.get_client().head_bucket(Bucket=self.bucket_name)
except botocore.exceptions.ClientError as ex:
# https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadBucket.html
status_code = ex.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
@@ -590,7 +615,7 @@ def check_or_create_bucket(self) -> None:
}

self.stats.operation(StorageOperation.create_bucket)
self.s3_client.create_bucket(**args)
self.get_client().create_bucket(**args)

def create_concurrent_upload(
self,
@@ -603,7 +628,7 @@ def create_concurrent_upload(

self.stats.operation(StorageOperation.create_multipart_upload)
try:
cmu_response = self.s3_client.create_multipart_upload(**args)
cmu_response = self.get_client().create_multipart_upload(**args)
except botocore.exceptions.ClientError as ex:
raise ConcurrentUploadError(f"Failed to initiate multipart upload for {path}") from ex

@@ -617,7 +642,7 @@ def complete_concurrent_upload(self, upload: ConcurrentUpload) -> None:
)
try:
self.stats.operation(StorageOperation.multipart_complete)
self.s3_client.complete_multipart_upload(
self.get_client().complete_multipart_upload(
Bucket=self.bucket_name,
Key=backend_key,
MultipartUpload={"Parts": sorted_chunks},
@@ -631,7 +656,7 @@ def abort_concurrent_upload(self, upload: ConcurrentUpload) -> None:
backend_key = self.format_key_for_backend(upload.key, remove_slash_prefix=True)
try:
self.stats.operation(StorageOperation.multipart_aborted)
self.s3_client.abort_multipart_upload(
self.get_client().abort_multipart_upload(
Bucket=self.bucket_name,
Key=backend_key,
UploadId=upload.backend_id,
@@ -654,7 +679,7 @@ def upload_concurrent_chunk(
backend_key = self.format_key_for_backend(upload.key, remove_slash_prefix=True)
try:
upload_func = partial(
self.s3_client.upload_part,
self.get_client().upload_part,
Bucket=self.bucket_name,
Key=backend_key,
UploadId=upload.backend_id,
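Taken together, the s3.py changes above replace eager client construction in __init__ with a lazy lifecycle: get_client() builds and caches the botocore client on first use, and close() releases it and clears the cached reference so a later get_client() call can re-create it. The following is a minimal, self-contained sketch of that pattern; FakeClient and LazyTransfer are purely illustrative names, not rohmu's actual classes.

from typing import Optional


class FakeClient:
    # Stand-in for the botocore S3 client, only here to keep the sketch runnable.
    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


class LazyTransfer:
    # Illustrative only; the real S3Transfer stores its connection settings here.
    def __init__(self) -> None:
        self.s3_client: Optional[FakeClient] = None  # nothing is created yet

    def get_client(self) -> FakeClient:
        # Build the client on first use and cache it for subsequent calls.
        if self.s3_client is None:
            self.s3_client = FakeClient()
        return self.s3_client

    def close(self) -> None:
        # Release the cached client and drop the reference.
        if self.s3_client is not None:
            self.s3_client.close()
            self.s3_client = None


transfer = LazyTransfer()
assert transfer.s3_client is None        # no client at construction time
client = transfer.get_client()           # first use creates and caches it
assert transfer.get_client() is client   # later calls reuse the same instance
transfer.close()
assert client.closed and transfer.s3_client is None

This is also why every self.s3_client call site in the diff becomes self.get_client(): call sites can no longer assume the client already exists.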
8 changes: 8 additions & 0 deletions test/object_storage/test_s3.py
@@ -52,6 +52,14 @@ def _get_session(cls: S3Transfer) -> Iterator[MagicMock]:
yield S3Infra(notifier, operation, s3_client, transfer)


def test_close(infra: S3Infra) -> None:
infra.transfer.get_client()
assert infra.transfer.s3_client is not None
infra.transfer.close()
assert infra.transfer.s3_client is None
infra.s3_client.close.assert_called_once()


def test_store_file_from_disk(infra: S3Infra) -> None:
test_data = b"test-data"
metadata = {"Content-Length": len(test_data), "some-date": datetime(2022, 11, 15, 18, 30, 58, 486644)}
3 changes: 2 additions & 1 deletion test/test_factory.py
@@ -49,11 +49,12 @@ def test_get_transfer_s3(
mock_config_model.return_value = S3ObjectStorageConfig(**expected_config_arg)

transfer_object = get_transfer(config)
assert isinstance(transfer_object, S3Transfer)
transfer_object.get_client()

mock_config_model.assert_called_once_with(**expected_config_arg)
mock_from_model.assert_called_once_with(mock_config_model(), mock_notifier.return_value)
mock_notifier.assert_called_once_with(url=config["notifier"]["url"])
assert isinstance(transfer_object, S3Transfer)
assert transfer_object.bucket_name == "dummy-bucket"
mock_botocore_config.assert_called_once_with(**expected_botocore_config)
mock_s3_client.assert_called_once_with(
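The test_factory.py change invokes get_client() on the transfer before asserting on the client-construction mock: with construction now deferred, the test has to trigger first use explicitly before checking the arguments the mock received. A tiny standalone illustration of that ordering, using only unittest.mock; the names here are hypothetical, not the test's actual fixtures.

from unittest.mock import MagicMock

create_client = MagicMock(name="create_client")


class LazyHolder:
    # Hypothetical minimal holder mirroring the deferred construction above.
    def __init__(self) -> None:
        self.client = None

    def get_client(self) -> MagicMock:
        if self.client is None:
            self.client = create_client()
        return self.client


holder = LazyHolder()
create_client.assert_not_called()   # constructing the holder builds nothing
holder.get_client()
create_client.assert_called_once()  # only the first get_client() call does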
