Close google/azure/s3 client when closing GoogleTransfer object [DDB-1160] #187

Merged · 4 commits · Jul 23, 2024
45 changes: 27 additions & 18 deletions rohmu/object_storage/azure.py
@@ -71,15 +71,15 @@ def __init__(
         self.account_key = account_key
         self.container_name = bucket_name
         self.sas_token = sas_token
-        conn_str = self.conn_string(
+        self._conn_str = self.conn_string(
             account_name=account_name,
             account_key=account_key,
             azure_cloud=azure_cloud,
             host=host,
             port=port,
             is_secure=is_secure,
         )
-        config: dict[str, Any] = {"max_block_size": MAX_BLOCK_SIZE}
+        self._config: dict[str, Any] = {"max_block_size": MAX_BLOCK_SIZE}
         if proxy_info:
             username = proxy_info.get("user")
             password = proxy_info.get("pass")
@@ -93,16 +93,25 @@ def __init__(
                 schema = "socks5"
             else:
                 schema = "http"
-            config["proxies"] = {"https": f"{schema}://{auth}{proxy_host}:{proxy_port}"}
-
-        self.conn: BlobServiceClient = BlobServiceClient.from_connection_string(
-            conn_str=conn_str,
-            credential=self.sas_token,
-            **config,
-        )
+            self._config["proxies"] = {"https": f"{schema}://{auth}{proxy_host}:{proxy_port}"}
+        self._blob_service_client: Optional[BlobServiceClient] = None
         self.container = self.get_or_create_container(self.container_name)
         self.log.debug("AzureTransfer initialized, %r", self.container_name)

+    def get_blob_service_client(self) -> BlobServiceClient:
+        if self._blob_service_client is None:
+            self._blob_service_client = BlobServiceClient.from_connection_string(
+                conn_str=self._conn_str,
+                credential=self.sas_token,
+                **self._config,
+            )
+        return self._blob_service_client
+
+    def close(self) -> None:
+        if self._blob_service_client is not None:
+            self._blob_service_client.close()
+            self._blob_service_client = None
+
     @staticmethod
     def conn_string(
         account_name: str,
@@ -142,11 +151,11 @@ def _copy_file_from_bucket(
         timeout: float = 15.0,
     ) -> None:
         source_path = source_bucket.format_key_for_backend(source_key, remove_slash_prefix=True, trailing_slash=False)
-        source_client = source_bucket.conn.get_blob_client(source_bucket.container_name, source_path)
+        source_client = source_bucket.get_blob_service_client().get_blob_client(source_bucket.container_name, source_path)
         source_url = source_client.url

         destination_path = self.format_key_for_backend(destination_key, remove_slash_prefix=True, trailing_slash=False)
-        destination_client = self.conn.get_blob_client(self.container_name, destination_path)
+        destination_client = self.get_blob_service_client().get_blob_client(self.container_name, destination_path)
         start = time.monotonic()
         destination_client.start_copy_from_url(source_url, metadata=metadata, timeout=timeout)
         while True:
@@ -219,7 +228,7 @@ def iter_key(

     def _iter_key(self, *, path: str, with_metadata: bool, deep: bool) -> Iterator[IterKeyItem]:
         include = "metadata" if with_metadata else None
-        container_client = self.conn.get_container_client(self.container_name)
+        container_client = self.get_blob_service_client().get_container_client(self.container_name)
         name_starts_with = None
         delimiter = ""
         if path:
@@ -254,7 +263,7 @@ def delete_key(self, key: str) -> None:
         path = self.format_key_for_backend(key, remove_slash_prefix=True)
         self.log.debug("Deleting key: %r", path)
         try:
-            blob_client = self.conn.get_blob_client(container=self.container_name, blob=path)
+            blob_client = self.get_blob_service_client().get_blob_client(container=self.container_name, blob=path)
             result = blob_client.delete_blob()
         except azure.core.exceptions.ResourceNotFoundError as ex:
             raise FileNotFoundFromStorageError(path) from ex
@@ -283,9 +292,9 @@ def _stream_blob(
         allows reading entire blob into memory at once or returning data from random offsets"""
         file_size = None
         start_range = byte_range[0] if byte_range else 0
-        chunk_size = self.conn._config.max_chunk_get_size  # type: ignore[attr-defined]
+        chunk_size = self.get_blob_service_client()._config.max_chunk_get_size  # type: ignore[attr-defined]
         end_range = chunk_size - 1
-        blob = self.conn.get_blob_client(self.container_name, key)
+        blob = self.get_blob_service_client().get_blob_client(self.container_name, key)
         while True:
             try:
                 if byte_range:
@@ -337,7 +346,7 @@ def get_contents_to_fileobj(
     def get_file_size(self, key: str) -> int:
         path = self.format_key_for_backend(key, remove_slash_prefix=True)
         try:
-            blob_client = self.conn.get_blob_client(self.container_name, path)
+            blob_client = self.get_blob_service_client().get_blob_client(self.container_name, path)
             return blob_client.get_blob_properties().size
         except azure.core.exceptions.ResourceNotFoundError as ex:
             raise FileNotFoundFromStorageError(path) from ex
@@ -376,7 +385,7 @@ def progress_callback(pipeline_response: Any) -> None:
         fd.tell = lambda: None  # type: ignore[assignment,method-assign,return-value]
         sanitized_metadata = self.sanitize_metadata(metadata, replace_hyphen_with="_")
         try:
-            blob_client = self.conn.get_blob_client(self.container_name, path)
+            blob_client = self.get_blob_service_client().get_blob_client(self.container_name, path)
             blob_client.upload_blob(
                 fd,
                 blob_type=BlobType.BlockBlob,  # type: ignore[arg-type]
@@ -400,7 +409,7 @@ def get_or_create_container(self, container_name: str) -> str:
            container_name = container_name.value
         start_time = time.monotonic()
         try:
-            self.conn.create_container(container_name)
+            self.get_blob_service_client().create_container(container_name)
         except ResourceExistsError:
             pass
         except HttpResponseError as e:
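
The shape of the azure.py change, reduced to a standalone sketch: the service client is created on first use and released by close(), so close() is idempotent and a closed transfer can lazily reconnect. Client below is a stand-in for an SDK client such as BlobServiceClient, not the Azure SDK itself:

    from typing import Optional


    class Client:
        # Stand-in for an SDK client such as BlobServiceClient.
        def __init__(self, conn_str: str) -> None:
            self.conn_str = conn_str

        def close(self) -> None:
            print("client closed")


    class LazyTransfer:
        def __init__(self, conn_str: str) -> None:
            self._conn_str = conn_str
            self._client: Optional[Client] = None  # no connection held yet

        def get_client(self) -> Client:
            # The first call creates the client; later calls reuse it.
            if self._client is None:
                self._client = Client(self._conn_str)
            return self._client

        def close(self) -> None:
            # Resetting to None makes repeated close() calls safe, and a
            # later get_client() transparently creates a fresh client.
            if self._client is not None:
                self._client.close()
                self._client = None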
4 changes: 4 additions & 0 deletions rohmu/object_storage/base.py
@@ -82,6 +82,10 @@ def __init__(
         self.notifier = notifier or NullNotifier()
         self.stats = StatsClient(statsd_info)

+    def close(self) -> None:
+        """Release all resources associated with the Transfer object."""
+        pass
+
     @staticmethod
     def _incremental_to_proportional_progress(
         *, size: int, cb: ProgressProportionCallbackType

Review thread on close():

Khatskevich (Contributor) commented on Jul 19, 2024:

    I do not understand who calls close(). I could not find an API in this repo that enforces the close call. Should the "user" of the Transfer object remember to close it after every usage?

Contributor:

    Otherwise a client pool could be used, or any other simple approach that keeps this easy to use and reliable.

joelynch (Contributor) commented on Jul 19, 2024:

    Close is part of the API for this class. A docstring might help here.

Author (Collaborator):

    Files, sockets, etc. need to be closed to reliably avoid leaks, and the same is true for APIs that use them. I'll add a docstring to make it clear that this is part of the API.
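
On the reviewers' question of who enforces the close call: nothing in this PR does, but a caller can get the guarantee from the standard library, since contextlib.closing() works with any object that exposes close(). A minimal sketch, assuming rohmu's get_transfer factory and a storage_config dict defined elsewhere (both outside this diff):

    from contextlib import closing

    from rohmu import get_transfer  # factory assumed here for illustration

    # storage_config: a rohmu storage configuration dict, defined elsewhere.
    # closing() guarantees transfer.close() runs on exit, even if the body raises.
    with closing(get_transfer(storage_config)) as transfer:
        size = transfer.get_file_size("backups/base.tar")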
10 changes: 9 additions & 1 deletion rohmu/object_storage/google.py
@@ -188,10 +188,18 @@ def __init__(
         self.proxy_info = proxy_info
         self.google_creds = get_credentials(credential_file=credential_file, credentials=credentials)
         self.gs: Optional[Resource] = self._init_google_client()
-        self.gs_object_client = None
+        self.gs_object_client: Any = None
         self.bucket_name = self.get_or_create_bucket(bucket_name)
         self.log.debug("GoogleTransfer initialized")

+    def close(self) -> None:
+        if self.gs_object_client is not None:
+            self.gs_object_client.close()
+            self.gs_object_client = None
+        if self.gs is not None:
+            self.gs.close()
+            self.gs = None
+
     def _init_google_client(self) -> Resource:
         start_time = time.monotonic()
         delay = 2
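
If a more enforced lifecycle is wanted later (the client-pool remark in the review thread), close() also composes with the context-manager protocol without touching the SDK clients again. A possible follow-up sketch, not part of this PR:

    from types import TracebackType
    from typing import Any, Optional, Type


    class ClosingTransfer:
        # Context-manager wrapper for any transfer object exposing close().
        def __init__(self, transfer: Any) -> None:
            self._transfer = transfer

        def __enter__(self) -> Any:
            return self._transfer

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc: Optional[BaseException],
            tb: Optional[TracebackType],
        ) -> bool:
            self._transfer.close()  # always release the underlying clients
            return False  # propagate exceptions from the with block

Usage: with ClosingTransfer(transfer) as t: ... guarantees close() runs regardless of how the block exits.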