Skip to content

Commit

Permalink
Add better error for non-empty local output folder in convert_text_to…
Browse files Browse the repository at this point in the history
…_mds.py (#891)
  • Loading branch information
irenedea authored Jan 23, 2024
1 parent f2614a4 commit 36fcb5e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 39 deletions.
4 changes: 4 additions & 0 deletions scripts/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ def convert_text_to_mds(
local_output_folder = tempfile.TemporaryDirectory(
).name if is_remote_output else output_folder

if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0:
raise FileExistsError(
f'{output_folder=} is not empty. Please remove or empty it.')

if processes > 1:
# Download and convert the text files in parallel
args = get_task_args(object_names, local_output_folder, input_folder,
Expand Down
82 changes: 43 additions & 39 deletions tests/a_scripts/data_prep/test_convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
import pathlib
import shutil
from concurrent.futures import ProcessPoolExecutor
from glob import glob
from typing import Callable, Iterable, List
Expand Down Expand Up @@ -55,23 +56,6 @@ def upload_object(self, object_name: str, filename: str):
remote_file.write(local_file.read())


def _call_convert_text_to_mds(processes: int, tokenizer_name: str,
                              concat_tokens: int) -> None:
    """Invoke ``convert_text_to_mds`` with fixed fake S3 input/output URIs.

    Only the three arguments below vary; every other option passed to
    ``convert_text_to_mds`` is a constant.

    Args:
        processes (int): Forwarded as ``processes``.
        tokenizer_name (str): Forwarded as ``tokenizer_name``.
        concat_tokens (int): Forwarded as ``concat_tokens``.
    """
    convert_text_to_mds(
        tokenizer_name=tokenizer_name,
        # Plain string literals: the URIs contain no placeholders, so an
        # f-prefix would be misleading (and is flagged by linters).
        output_folder='s3://fake-test-output-path',
        input_folder='s3://fake-test-input-path',
        concat_tokens=concat_tokens,
        eos_text='',
        bos_text='',
        no_wrap=False,
        compression='zstd',
        processes=processes,
        args_str='Namespace()',
        reprocess=False,
    )


# Mock starmap with no multiprocessing
def _mock_map(func: Callable, args: Iterable) -> Iterable:
for arg in args:
Expand Down Expand Up @@ -107,9 +91,22 @@ def test_single_and_multi_process(merge_shard_groups: Mock,
maybe_create_object_store_from_uri.return_value = mock_object_store
parse_uri.return_value = ('s3', 'fake-test-bucket', str(remote_folder))

_call_convert_text_to_mds(processes=processes,
tokenizer_name=tokenizer_name,
concat_tokens=concat_tokens)
def call_convert_text_to_mds() -> None:
    """Invoke ``convert_text_to_mds`` with fixed fake S3 input/output URIs.

    ``tokenizer_name``, ``concat_tokens`` and ``processes`` are read from the
    enclosing scope; every other option is a constant.
    """
    convert_text_to_mds(
        tokenizer_name=tokenizer_name,
        # Plain string literals: the URIs contain no placeholders, so an
        # f-prefix would be misleading (and is flagged by linters).
        output_folder='s3://fake-test-output-path',
        input_folder='s3://fake-test-input-path',
        concat_tokens=concat_tokens,
        eos_text='',
        bos_text='',
        no_wrap=False,
        compression='zstd',
        processes=processes,
        args_str='Namespace()',
        reprocess=False,
    )

call_convert_text_to_mds()

# Check call counts
assert download_and_convert.call_count == processes # called once per process
Expand All @@ -131,9 +128,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock,
_assert_files_exist(prefix=remote_folder,
files=['index.json', DONE_FILENAME] + shards)

_call_convert_text_to_mds(processes=processes,
tokenizer_name=tokenizer_name,
concat_tokens=concat_tokens)
call_convert_text_to_mds()

# Check call counts
assert download_and_convert.call_count == processes # No changes because we shouldn't reprocess
Expand All @@ -146,9 +141,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock,
mock_object_store = Mock(wraps=object_store)
maybe_create_object_store_from_uri.return_value = mock_object_store

_call_convert_text_to_mds(processes=processes,
tokenizer_name=tokenizer_name,
concat_tokens=concat_tokens)
call_convert_text_to_mds()

# Check call counts
assert download_and_convert.call_count == processes * 2 # called once per process
Expand Down Expand Up @@ -187,31 +180,42 @@ def test_local_path(tmp_path: pathlib.Path):
input_folder = tmp_path / 'input'
output_folder = tmp_path / 'output'

def call_convert_text_to_mds(reprocess: bool):
    """Run ``convert_text_to_mds`` on the enclosing scope's input/output folders.

    Args:
        reprocess (bool): Forwarded as ``reprocess``; all other options are
            fixed.
    """
    # Gather the keyword arguments first so the call itself stays short.
    conversion_kwargs = {
        'tokenizer_name': 'mosaicml/mpt-7b',
        'output_folder': str(output_folder),
        'input_folder': str(input_folder),
        'concat_tokens': 1,
        'eos_text': '',
        'bos_text': '',
        'no_wrap': False,
        'compression': 'zstd',
        'processes': 1,
        'args_str': 'Namespace()',
        'reprocess': reprocess,
    }
    convert_text_to_mds(**conversion_kwargs)

# Create input text data
os.makedirs(input_folder, exist_ok=True)
with open(input_folder / 'test.txt', 'w') as f:
f.write('test')

# Convert text data to mds
convert_text_to_mds(
tokenizer_name='mosaicml/mpt-7b',
output_folder=str(output_folder),
input_folder=str(input_folder),
concat_tokens=1,
eos_text='',
bos_text='',
no_wrap=False,
compression='zstd',
processes=1,
args_str='Namespace()',
reprocess=False,
)
call_convert_text_to_mds(reprocess=False)

# Make sure all the files exist as expected.
assert os.path.exists(output_folder / '.text_to_mds_conversion_done')
assert os.path.exists(output_folder / 'index.json')
assert os.path.exists(output_folder / 'shard.00000.mds.zstd')

# Test reprocessing.
with pytest.raises(FileExistsError):
call_convert_text_to_mds(reprocess=True)

shutil.rmtree(output_folder)

call_convert_text_to_mds(reprocess=True)


def test_is_already_processed(tmp_path: pathlib.Path):
tmp_path_str = str(tmp_path)
Expand Down

0 comments on commit 36fcb5e

Please sign in to comment.