diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py
index 39cb3ded9f..a88ac214b9 100644
--- a/composer/loggers/remote_uploader_downloader.py
+++ b/composer/loggers/remote_uploader_downloader.py
@@ -673,8 +673,8 @@ def _upload_worker(
 
         # defining as a function-in-function to use decorator notation with num_attempts as an argument
         @retry(ObjectStoreTransientError, num_attempts=num_attempts)
-        def upload_file():
-            if not overwrite:
+        def upload_file(retry_index: int = 0):
+            if retry_index == 0 and not overwrite:
                 try:
                     remote_backend.get_object_size(remote_file_name)
                 except FileNotFoundError:
diff --git a/composer/utils/retrying.py b/composer/utils/retrying.py
index 5684018061..d2cec5b9a2 100644
--- a/composer/utils/retrying.py
+++ b/composer/utils/retrying.py
@@ -7,6 +7,7 @@
 
 import collections.abc
 import functools
+import inspect
 import logging
 import random
 import time
@@ -46,18 +47,16 @@ def retry(  # type: ignore
     Attempts are spaced out with
     ``initial_backoff + 2**num_attempts + random.random() * max_jitter`` seconds.
 
+    Optionally, the decorated function can specify ``retry_index`` as an argument to receive the current attempt number (starting at 0).
+
     Example:
         .. testcode::
 
             from composer.utils import retry
 
-            num_tries = 0
-
             @retry(RuntimeError, num_attempts=3, initial_backoff=0.1)
-            def flaky_function():
-                global num_tries
-                if num_tries < 2:
-                    num_tries += 1
+            def flaky_function(retry_index: int):
+                if retry_index < 2:
                     raise RuntimeError("Called too soon!")
                 return "Third time's a charm."
 
@@ -84,9 +83,12 @@ def wrapped_func(func: TCallable) -> TCallable:
 
         @functools.wraps(func)
         def new_func(*args: Any, **kwargs: Any):
+            retry_index_param = 'retry_index'
             i = 0
             while True:
                 try:
+                    if retry_index_param in inspect.signature(func).parameters:
+                        kwargs[retry_index_param] = i
                     return func(*args, **kwargs)
                 except exc_class as e:
                     log.debug(f'Attempt {i} failed. Exception type: {type(e)}, message: {str(e)}.')
diff --git a/tests/loggers/test_remote_uploader_downloader.py b/tests/loggers/test_remote_uploader_downloader.py
index fe0159db6d..b3f389f991 100644
--- a/tests/loggers/test_remote_uploader_downloader.py
+++ b/tests/loggers/test_remote_uploader_downloader.py
@@ -15,7 +15,7 @@
 
 from composer.core import Event, State
 from composer.loggers import Logger, RemoteUploaderDownloader
-from composer.utils.object_store.object_store import ObjectStore
+from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError
 
 
 class DummyObjectStore(ObjectStore):
@@ -190,6 +190,64 @@ def test_remote_uploader_downloader_no_overwrite(
     )
 
 
+def test_allow_overwrite_on_retry(tmp_path: pathlib.Path, dummy_state: State):
+    file_path = tmp_path / 'samples' / 'sample'
+    os.makedirs(tmp_path / 'samples')
+    with open(file_path, 'w') as f:
+        f.write('sample')
+
+    # Dummy object store that fails the first two upload attempts.
+    # This tests that the remote uploader downloader allows overwriting a partially uploaded file on a retry.
+    class RetryDummyObjectStore(DummyObjectStore):
+
+        def __init__(
+            self,
+            dir: Optional[pathlib.Path] = None,
+            always_fail: bool = False,
+            **kwargs: Dict[str, Any],
+        ) -> None:
+            self._retry = 0
+            super().__init__(dir, always_fail, **kwargs)
+
+        def upload_object(
+            self,
+            object_name: str,
+            filename: Union[str, pathlib.Path],
+            callback: Optional[Callable[[int, int], None]] = None,
+        ) -> None:
+            if self._retry < 2:
+                self._retry += 1  # Takes two retries to upload the file
+                raise ObjectStoreTransientError('Retry this')
+            self._retry += 1
+            return super().upload_object(object_name, filename, callback)
+
+        def get_object_size(self, object_name: str) -> int:
+            if self._retry > 0:
+                return 1  # The 0th upload resulted in a partial upload
+            return super().get_object_size(object_name)
+
+    fork_context = multiprocessing.get_context('fork')
+    with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', RetryDummyObjectStore):
+        with patch('composer.loggers.remote_uploader_downloader.multiprocessing.get_context', lambda _: fork_context):
+            remote_uploader_downloader = RemoteUploaderDownloader(
+                bucket_uri=f"s3://{tmp_path}/'object_store_backend",
+                backend_kwargs={
+                    'dir': tmp_path / 'object_store_backend',
+                },
+                num_concurrent_uploads=4,
+                upload_staging_folder=str(tmp_path / 'staging_folder'),
+                use_procs=True,
+                num_attempts=3,
+            )
+            logger = Logger(dummy_state, destinations=[remote_uploader_downloader])
+
+            remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger)
+            remote_file_name = 'remote_file_name'
+            remote_uploader_downloader.upload_file(dummy_state, remote_file_name, file_path, overwrite=False)
+            remote_uploader_downloader.close(dummy_state, logger=logger)
+            remote_uploader_downloader.post_close()
+
+
 @pytest.mark.parametrize('use_procs', [True, False])
 def test_race_with_overwrite(tmp_path: pathlib.Path, use_procs: bool, dummy_state: State):
     # Test a race condition with the object store logger where multiple files with the same name are logged in rapid succession
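
For illustration, a minimal sketch of the retry_index mechanism this diff introduces (assuming the retrying.py patch above is applied; the function name `flaky` and its messages are illustrative only, not part of the patch). Because the wrapped function declares a retry_index parameter, new_func finds it via inspect.signature and injects the current attempt number as a keyword argument, which is what lets upload_file run its no-overwrite existence check only on attempt 0, since a failed attempt may already have partially uploaded the object.

from composer.utils import retry

@retry(RuntimeError, num_attempts=3, initial_backoff=0.1)
def flaky(retry_index: int = 0) -> str:
    # The decorator passes retry_index=0, 1, 2, ... on successive attempts.
    if retry_index < 2:
        # Attempts 0 and 1 raise the retried exception class, triggering backoff.
        raise RuntimeError(f'attempt {retry_index} failed')
    return f'succeeded on attempt {retry_index}'

assert flaky() == 'succeeded on attempt 2'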