Skip to content

Commit

Permalink
Only strip object names when creating new output path (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Nov 29, 2023
1 parent 3a96b69 commit 1191267
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
11 changes: 6 additions & 5 deletions llmfoundry/utils/data_prep_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,16 @@ def __init__(

def __iter__(self):
for object_name in self.object_names:
object_name = object_name.strip('/')
output_filename = os.path.join(self.output_folder, object_name)
# Default output_filename, used for local paths.
output_filename = object_name

# Download objects if remote path.
if self.object_store is not None:
output_filename = os.path.join(self.output_folder,
object_name.strip('/'))
self.object_store.download_object(object_name=object_name,
filename=output_filename,
overwrite=True)
else:
# Inputs are local so we do not need to download them.
output_filename = object_name

with open(output_filename) as _txt_file:
txt = _txt_file.read()
Expand Down
31 changes: 31 additions & 0 deletions tests/test_convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,37 @@ def test_single_and_multi_process(merge_shard_groups: Mock,
assert n_tokens == expected_n_tokens


def test_local_path(tmp_path: pathlib.Path):
# Input/output folders
input_folder = tmp_path / 'input'
output_folder = tmp_path / 'output'

# Create input text data
os.makedirs(input_folder, exist_ok=True)
with open(input_folder / 'test.txt', 'w') as f:
f.write('test')

# Convert text data to mds
convert_text_to_mds(
tokenizer_name='mosaicml/mpt-7b',
output_folder=str(output_folder),
input_folder=str(input_folder),
concat_tokens=1,
eos_text='',
bos_text='',
no_wrap=False,
compression='zstd',
processes=1,
args_str='Namespace()',
reprocess=False,
)

# Make sure all the files exist as expected.
assert os.path.exists(output_folder / '.text_to_mds_conversion_done')
assert os.path.exists(output_folder / 'index.json')
assert os.path.exists(output_folder / 'shard.00000.mds.zstd')


def test_is_already_processed(tmp_path: pathlib.Path):
tmp_path_str = str(tmp_path)
args_str = 'Namespace(x = 5)'
Expand Down

0 comments on commit 1191267

Please sign in to comment.