diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index 142b7b748b..3a00a8889f 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -91,7 +91,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock, maybe_create_object_store_from_uri.return_value = mock_object_store parse_uri.return_value = ('s3', 'fake-test-bucket', str(remote_folder)) - def call_convert_text_to_mds(processes: int) -> None: + def call_convert_text_to_mds() -> None: convert_text_to_mds( tokenizer_name=tokenizer_name, output_folder=f's3://fake-test-output-path', @@ -106,7 +106,7 @@ def call_convert_text_to_mds(processes: int) -> None: reprocess=False, ) - call_convert_text_to_mds(processes=processes) + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes # called once per process @@ -128,7 +128,7 @@ def call_convert_text_to_mds(processes: int) -> None: _assert_files_exist(prefix=remote_folder, files=['index.json', DONE_FILENAME] + shards) - call_convert_text_to_mds(processes=processes) + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes # No changes because we shoudn't reprocess @@ -141,7 +141,7 @@ def call_convert_text_to_mds(processes: int) -> None: mock_object_store = Mock(wraps=object_store) maybe_create_object_store_from_uri.return_value = mock_object_store - call_convert_text_to_mds(processes=processes) + call_convert_text_to_mds() # Check call counts assert download_and_convert.call_count == processes * 2 # called once per process