diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py
index b252b5989..4e76063df 100644
--- a/data_juicer/core/ray_data.py
+++ b/data_juicer/core/ray_data.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
 import os
 from functools import partial
+from typing import Any, Dict, List, Literal, Optional, Union
 
-import pyarrow as pa
+import pyarrow
 from loguru import logger
 
 from data_juicer import cuda_device_count
@@ -12,6 +15,7 @@ from data_juicer.utils.process_utils import calculate_np
 
 rd = LazyLoader('rd', 'ray.data')
+ds = LazyLoader('ds', 'ray.data.datasource')
 
 
 def get_abs_path(path, dataset_dir):
@@ -33,7 +37,7 @@ def convert_to_absolute_paths(samples, dataset_dir, path_keys):
                 samples[key][idx] = [
                     get_abs_path(item, dataset_dir) for item in paths
                 ]
-    return pa.Table.from_pydict(samples)
+    return pyarrow.Table.from_pydict(samples)
 
 
 # TODO: check path for nestdataset
@@ -63,7 +67,7 @@ def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
         dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
     if Fields.stats not in columns:
 
-        def process_batch_arrow(table: pa.Table) -> pa.Table:
+        def process_batch_arrow(table: pyarrow.Table) -> pyarrow.Table:
             new_column_data = [{} for _ in range(len(table))]
             new_talbe = table.append_column(Fields.stats, [new_column_data])
             return new_talbe
@@ -81,7 +85,7 @@ def get_num_gpus(op, op_proc):
 
 
 def filter_batch(batch, filter_func):
-    mask = pa.array(filter_func(batch.to_pydict()))
+    mask = pyarrow.array(filter_func(batch.to_pydict()))
     return batch.filter(mask)
@@ -148,3 +152,86 @@ def _run_single_op(self, op):
             import traceback
             traceback.print_exc()
             exit(1)
+
+    @classmethod
+    def read_json(cls, paths: Union[str, List[str]]) -> rd.Dataset:
+        # Note: a temporary solution for streaming JSON reads
+        # TODO: replace with ray.data.read_json_stream once it is available
+        import pyarrow.json as js
+        try:
+            js.open_json
+            return read_json_stream(paths)
+        except AttributeError:
+            return rd.read_json(paths)
+
+
+class JSONStreamDatasource(ds.JSONDatasource):
+
+    def _read_stream(self, f: 'pyarrow.NativeFile', path: str):
+        from pyarrow.json import open_json
+
+        try:
+            reader = open_json(
+                f,
+                read_options=self.read_options,
+                **self.arrow_json_args,
+            )
+            schema = None
+            while True:
+                try:
+                    batch = reader.read_next_batch()
+                    table = pyarrow.Table.from_batches([batch], schema=schema)
+                    if schema is None:
+                        schema = table.schema
+                    yield table
+                except StopIteration:
+                    return
+        except pyarrow.lib.ArrowInvalid as e:
+            raise ValueError(
+                f'Failed to read JSON file: {path}. '
+                'Please check that the JSON file is well-formed, or filter '
+                "out non-JSON files with the 'partition_filter' field. See "
+                'the read_json() documentation for more details.') from e
+
+
+def read_json_stream(
+    paths: Union[str, List[str]],
+    *,
+    filesystem: Optional['pyarrow.fs.FileSystem'] = None,
+    parallelism: int = -1,
+    ray_remote_args: Optional[Dict[str, Any]] = None,
+    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
+    meta_provider=None,
+    partition_filter=None,
+    partitioning=ds.partitioning.Partitioning('hive'),
+    include_paths: bool = False,
+    ignore_missing_paths: bool = False,
+    shuffle: Union[Literal['files'], None] = None,
+    file_extensions: Optional[List[str]] = ['json', 'jsonl'],
+    concurrency: Optional[int] = None,
+    override_num_blocks: Optional[int] = None,
+    **arrow_json_args,
+) -> rd.Dataset:
+    if meta_provider is None:
+        meta_provider = ds.file_meta_provider.DefaultFileMetadataProvider()
+
+    datasource = JSONStreamDatasource(
+        paths,
+        arrow_json_args=arrow_json_args,
+        filesystem=filesystem,
+        open_stream_args=arrow_open_stream_args,
+        meta_provider=meta_provider,
+        partition_filter=partition_filter,
+        partitioning=partitioning,
+        ignore_missing_paths=ignore_missing_paths,
+        shuffle=shuffle,
+        include_paths=include_paths,
+        file_extensions=file_extensions,
+    )
+    return rd.read_datasource(
+        datasource,
+        parallelism=parallelism,
+        ray_remote_args=ray_remote_args,
+        concurrency=concurrency,
+        override_num_blocks=override_num_blocks,
+    )
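
The new `RayDataset.read_json` probes for `pyarrow.json.open_json` and, when it is present (newer pyarrow versions ship the streaming JSON reader), routes reads through the streaming `JSONStreamDatasource`; otherwise it falls back to the plain `ray.data.read_json`. A minimal sketch of exercising the new reader — the `demo.jsonl` and `data_dir/` paths and the local `ray.init()` are placeholders, not part of this change:

    import ray

    from data_juicer.core.ray_data import RayDataset, read_json_stream

    ray.init()
    # Streams record batches instead of materializing each file whole;
    # silently falls back to ray.data.read_json on older pyarrow.
    dataset = RayDataset.read_json('demo.jsonl')
    # read_json_stream can also be called directly, e.g. to restrict which
    # files in a directory are picked up:
    dataset = read_json_stream('data_dir/', file_extensions=['jsonl'])
    print(dataset.schema())
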
diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py
index 1d90e31b3..41990b36a 100644
--- a/data_juicer/core/ray_executor.py
+++ b/data_juicer/core/ray_executor.py
@@ -61,7 +61,7 @@ def run(self, load_data_np=None):
             from data_juicer.format.formatter import FORMATTERS
             dataset = FORMATTERS.modules[obj_name](**args).load_dataset()
         else:
-            dataset = rd.read_json(self.cfg.dataset_path)
+            dataset = RayDataset.read_json(self.cfg.dataset_path)
 
         # convert all the path in dataset to absolute path
         dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg)
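
On the executor side, the dataset now comes in through `RayDataset.read_json`, so existing configs keep working and pick up the streaming path automatically when the installed pyarrow supports it. A rough end-to-end sketch, assuming the usual `init_configs` entry point and a hypothetical `configs/demo.yaml` whose `dataset_path` points at JSON/JSONL files:

    from data_juicer.config import init_configs
    from data_juicer.core.ray_executor import RayExecutor

    # Parse the demo config (path is hypothetical) and run the Ray pipeline;
    # RayExecutor.run() will now call RayDataset.read_json under the hood.
    cfg = init_configs(args=['--config', 'configs/demo.yaml'])
    executor = RayExecutor(cfg)
    executor.run()
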