diff --git a/llmfoundry/utils/data_prep_utils.py b/llmfoundry/utils/data_prep_utils.py index b5b606a57f..9601e8618e 100644 --- a/llmfoundry/utils/data_prep_utils.py +++ b/llmfoundry/utils/data_prep_utils.py @@ -7,6 +7,8 @@ from typing import List, Optional from composer.utils import ObjectStore +from composer.utils.object_store import ObjectStoreTransientError +from composer.utils.retrying import retry __all__ = [ 'merge_shard_groups', @@ -78,6 +80,26 @@ def merge_shard_groups(root: str) -> None: out.write(text) +@retry(ObjectStoreTransientError, num_attempts=5) +def download_file( + object_store: ObjectStore, + object_name: str, + output_filename: str, +) -> None: + """Downloads a file from an object store. + + Args: + object_store (ObjectStore): Object store to download from + object_name (str): Name of object to download + output_filename (str): Local filename to write to + """ + object_store.download_object( + object_name=object_name, + filename=output_filename, + overwrite=True, + ) + + class DownloadingIterable: def __init__( @@ -110,10 +132,11 @@ def __iter__(self): self.output_folder, object_name.strip('/'), ) - self.object_store.download_object( + + download_file( + object_store=self.object_store, object_name=object_name, - filename=output_filename, - overwrite=True, + output_filename=output_filename, ) with open(output_filename) as _txt_file: diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index 1e36a681f9..98e9843570 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -24,6 +24,7 @@ from llmfoundry.utils import maybe_create_mosaicml_logger from llmfoundry.utils.data_prep_utils import ( DownloadingIterable, + download_file, merge_shard_groups, ) from llmfoundry.utils.exceptions import ( @@ -329,9 +330,13 @@ def is_already_processed( try: with tempfile.TemporaryDirectory() as tmp_dir: done_file = os.path.join(tmp_dir, DONE_FILENAME) - output_object_store.download_object( - os.path.join(output_folder_prefix, DONE_FILENAME), - done_file, + download_file( + object_store=output_object_store, + object_name=os.path.join( + output_folder_prefix, + DONE_FILENAME, + ), + output_filename=done_file, ) with open(done_file) as df: done_file_contents = df.read().splitlines()