diff --git a/tests/tools/test_process_data.py b/tests/tools/test_process_data.py index 8e026f8ad..f318f6150 100644 --- a/tests/tools/test_process_data.py +++ b/tests/tools/test_process_data.py @@ -182,12 +182,16 @@ def test_ray_image(self): self.assertTrue(osp.exists(tmp_out_path)) - import ray - res_ds = ray.data.read_json(tmp_out_path) - res_ds = res_ds.to_pandas().to_dict(orient='records') - - self.assertEqual(len(res_ds), 3) - for item in res_ds: + from datasets import load_dataset + jsonl_files = [os.path.join(tmp_out_path, f) \ + for f in os.listdir(tmp_out_path) \ + if f.endswith('.json')] + dataset = load_dataset( + 'json', + data_files={'jsonl': jsonl_files}) + + self.assertEqual(len(dataset['jsonl']), 3) + for item in dataset['jsonl']: self.assertIn('aspect_ratios', item['__dj__stats__'])