Skip to content

Commit

Permalink
Add retries to downloads in convert_text_to_mds.py (#1238)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored May 24, 2024
1 parent ff92f3c commit fdaa58b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
29 changes: 26 additions & 3 deletions llmfoundry/utils/data_prep_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import List, Optional

from composer.utils import ObjectStore
from composer.utils.object_store import ObjectStoreTransientError
from composer.utils.retrying import retry

__all__ = [
'merge_shard_groups',
Expand Down Expand Up @@ -78,6 +80,26 @@ def merge_shard_groups(root: str) -> None:
out.write(text)


@retry(ObjectStoreTransientError, num_attempts=5)
def download_file(
object_store: ObjectStore,
object_name: str,
output_filename: str,
) -> None:
"""Downloads a file from an object store.
Args:
object_store (ObjectStore): Object store to download from
object_name (str): Name of object to download
output_filename (str): Local filename to write to
"""
object_store.download_object(
object_name=object_name,
filename=output_filename,
overwrite=True,
)


class DownloadingIterable:

def __init__(
Expand Down Expand Up @@ -110,10 +132,11 @@ def __iter__(self):
self.output_folder,
object_name.strip('/'),
)
self.object_store.download_object(

download_file(
object_store=self.object_store,
object_name=object_name,
filename=output_filename,
overwrite=True,
output_filename=output_filename,
)

with open(output_filename) as _txt_file:
Expand Down
11 changes: 8 additions & 3 deletions scripts/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from llmfoundry.utils import maybe_create_mosaicml_logger
from llmfoundry.utils.data_prep_utils import (
DownloadingIterable,
download_file,
merge_shard_groups,
)
from llmfoundry.utils.exceptions import (
Expand Down Expand Up @@ -329,9 +330,13 @@ def is_already_processed(
try:
with tempfile.TemporaryDirectory() as tmp_dir:
done_file = os.path.join(tmp_dir, DONE_FILENAME)
output_object_store.download_object(
os.path.join(output_folder_prefix, DONE_FILENAME),
done_file,
download_file(
object_store=output_object_store,
object_name=os.path.join(
output_folder_prefix,
DONE_FILENAME,
),
output_filename=done_file,
)
with open(done_file) as df:
done_file_contents = df.read().splitlines()
Expand Down

0 comments on commit fdaa58b

Please sign in to comment.