Skip to content

Commit

Permalink
Allow overwrite on upload retry in remote uploader downloader (mosaic…
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored May 20, 2024
1 parent 8e20ea8 commit 552958c
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 9 deletions.
4 changes: 2 additions & 2 deletions composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,8 @@ def _upload_worker(

# defining as a function-in-function to use decorator notation with num_attempts as an argument
@retry(ObjectStoreTransientError, num_attempts=num_attempts)
def upload_file():
if not overwrite:
def upload_file(retry_index: int = 0):
if retry_index == 0 and not overwrite:
try:
remote_backend.get_object_size(remote_file_name)
except FileNotFoundError:
Expand Down
14 changes: 8 additions & 6 deletions composer/utils/retrying.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import collections.abc
import functools
import inspect
import logging
import random
import time
Expand Down Expand Up @@ -46,18 +47,16 @@ def retry( # type: ignore
Attempts are spaced out with ``initial_backoff + 2**num_attempts + random.random() * max_jitter`` seconds.
Optionally, the decorated function can specify `retry_index` as an argument to receive the current attempt number.
Example:
.. testcode::
from composer.utils import retry
num_tries = 0
@retry(RuntimeError, num_attempts=3, initial_backoff=0.1)
def flaky_function():
global num_tries
if num_tries < 2:
num_tries += 1
def flaky_function(retry_index: int):
if retry_index < 2:
raise RuntimeError("Called too soon!")
return "Third time's a charm."
Expand All @@ -84,9 +83,12 @@ def wrapped_func(func: TCallable) -> TCallable:

@functools.wraps(func)
def new_func(*args: Any, **kwargs: Any):
retry_index_param = 'retry_index'
i = 0
while True:
try:
if retry_index_param in inspect.signature(func).parameters:
kwargs[retry_index_param] = i
return func(*args, **kwargs)
except exc_class as e:
log.debug(f'Attempt {i} failed. Exception type: {type(e)}, message: {str(e)}.')
Expand Down
60 changes: 59 additions & 1 deletion tests/loggers/test_remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from composer.core import Event, State
from composer.loggers import Logger, RemoteUploaderDownloader
from composer.utils.object_store.object_store import ObjectStore
from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError


class DummyObjectStore(ObjectStore):
Expand Down Expand Up @@ -190,6 +190,64 @@ def test_remote_uploader_downloader_no_overwrite(
)


def test_allow_overwrite_on_retry(tmp_path: pathlib.Path, dummy_state: State):
    # Regression test: a partially uploaded file left behind by a failed
    # attempt must be overwritable on the subsequent retry, even though the
    # caller passed overwrite=False.
    sample_dir = tmp_path / 'samples'
    os.makedirs(sample_dir)
    file_path = sample_dir / 'sample'
    with open(file_path, 'w') as f:
        f.write('sample')

    # Object store stub whose first two upload attempts raise a transient
    # error, simulating a partial upload that the retry must overwrite.
    class RetryDummyObjectStore(DummyObjectStore):

        def __init__(
            self,
            dir: Optional[pathlib.Path] = None,
            always_fail: bool = False,
            **kwargs: Dict[str, Any],
        ) -> None:
            # Counts upload attempts seen so far.
            self._retry = 0
            super().__init__(dir, always_fail, **kwargs)

        def upload_object(
            self,
            object_name: str,
            filename: Union[str, pathlib.Path],
            callback: Optional[Callable[[int, int], None]] = None,
        ) -> None:
            attempt = self._retry
            self._retry += 1
            if attempt < 2:
                # The first two attempts fail with a retryable error.
                raise ObjectStoreTransientError('Retry this')
            return super().upload_object(object_name, filename, callback)

        def get_object_size(self, object_name: str) -> int:
            # Once any upload attempt has happened, report a nonzero size so
            # the overwrite pre-check sees a (partially) uploaded object.
            return 1 if self._retry > 0 else super().get_object_size(object_name)

    fork_ctx = multiprocessing.get_context('fork')
    with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', RetryDummyObjectStore):
        with patch('composer.loggers.remote_uploader_downloader.multiprocessing.get_context', lambda _: fork_ctx):
            remote_uploader_downloader = RemoteUploaderDownloader(
                bucket_uri=f"s3://{tmp_path}/'object_store_backend",
                backend_kwargs={
                    'dir': tmp_path / 'object_store_backend',
                },
                num_concurrent_uploads=4,
                upload_staging_folder=str(tmp_path / 'staging_folder'),
                use_procs=True,
                num_attempts=3,
            )
            logger = Logger(dummy_state, destinations=[remote_uploader_downloader])

            remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger)
            remote_uploader_downloader.upload_file(dummy_state, 'remote_file_name', file_path, overwrite=False)
            remote_uploader_downloader.close(dummy_state, logger=logger)
            remote_uploader_downloader.post_close()


@pytest.mark.parametrize('use_procs', [True, False])
def test_race_with_overwrite(tmp_path: pathlib.Path, use_procs: bool, dummy_state: State):
# Test a race condition with the object store logger where multiple files with the same name are logged in rapid succession
Expand Down

0 comments on commit 552958c

Please sign in to comment.