Skip to content

Commit

Permalink
Better refactor
Browse files: browse the repository at this point in the history
  • Loading branch information
irenedea committed Jan 20, 2024
1 parent 137341c commit 8ab7f6d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/a_scripts/data_prep/test_convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock,
maybe_create_object_store_from_uri.return_value = mock_object_store
parse_uri.return_value = ('s3', 'fake-test-bucket', str(remote_folder))

def call_convert_text_to_mds(processes: int) -> None:
def call_convert_text_to_mds() -> None:
convert_text_to_mds(
tokenizer_name=tokenizer_name,
output_folder=f's3://fake-test-output-path',
Expand All @@ -106,7 +106,7 @@ def call_convert_text_to_mds(processes: int) -> None:
reprocess=False,
)

call_convert_text_to_mds(processes=processes)
call_convert_text_to_mds()

# Check call counts
assert download_and_convert.call_count == processes # called once per process
Expand All @@ -128,7 +128,7 @@ def call_convert_text_to_mds(processes: int) -> None:
_assert_files_exist(prefix=remote_folder,
files=['index.json', DONE_FILENAME] + shards)

call_convert_text_to_mds(processes=processes)
call_convert_text_to_mds()

# Check call counts
assert download_and_convert.call_count == processes # No changes because we shouldn't reprocess
Expand All @@ -141,7 +141,7 @@ def call_convert_text_to_mds(processes: int) -> None:
mock_object_store = Mock(wraps=object_store)
maybe_create_object_store_from_uri.return_value = mock_object_store

call_convert_text_to_mds(processes=processes)
call_convert_text_to_mds()

# Check call counts
assert download_and_convert.call_count == processes * 2 # called once per process again, so the cumulative count doubles
Expand Down

0 comments on commit 8ab7f6d

Please sign in to comment.