Commit

clean code

pan-x-c committed Dec 19, 2024
1 parent 46062f8 commit ca18958
Showing 2 changed files with 92 additions and 5 deletions.
95 changes: 91 additions & 4 deletions data_juicer/core/ray_data.py
@@ -1,7 +1,10 @@
from __future__ import annotations

import os
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Union

-import pyarrow as pa
+import pyarrow
from loguru import logger

from data_juicer import cuda_device_count
@@ -12,6 +15,7 @@
from data_juicer.utils.process_utils import calculate_np

rd = LazyLoader('rd', 'ray.data')
ds = LazyLoader('ds', 'ray.data.datasource')


def get_abs_path(path, dataset_dir):
@@ -33,7 +37,7 @@ def convert_to_absolute_paths(samples, dataset_dir, path_keys):
samples[key][idx] = [
get_abs_path(item, dataset_dir) for item in paths
]
-    return pa.Table.from_pydict(samples)
+    return pyarrow.Table.from_pydict(samples)


# TODO: check path for nestdataset
@@ -63,7 +67,7 @@ def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
if Fields.stats not in columns:

-        def process_batch_arrow(table: pa.Table) -> pa.Table:
+        def process_batch_arrow(table: pyarrow.Table) -> pyarrow.Table:
            new_column_data = [{} for _ in range(len(table))]
            new_table = table.append_column(Fields.stats, [new_column_data])
            return new_table
@@ -81,7 +85,7 @@ def get_num_gpus(op, op_proc):


def filter_batch(batch, filter_func):
-    mask = pyarrow.array(filter_func(batch.to_pydict()))
+    mask = pyarrow.array(filter_func(batch.to_pydict()))
return batch.filter(mask)


@@ -148,3 +152,86 @@ def _run_single_op(self, op):
import traceback
traceback.print_exc()
exit(1)

@classmethod
def read_json(cls, paths: Union[str, List[str]]) -> RayDataset:
# Note: a temp solution for reading json stream
# TODO: replace with ray.data.read_json_stream once it is available
import pyarrow.json as js
try:
js.open_json
return read_json_stream(paths)
except AttributeError:
return rd.read_json(paths)


class JSONStreamDatasource(ds.JSONDatasource):

def _read_stream(self, f: 'pyarrow.NativeFile', path: str):
from pyarrow.json import open_json

try:
reader = open_json(
f,
read_options=self.read_options,
**self.arrow_json_args,
)
schema = None
while True:
try:
batch = reader.read_next_batch()
table = pyarrow.Table.from_batches([batch], schema=schema)
if schema is None:
schema = table.schema
yield table
except StopIteration:
return
except pyarrow.lib.ArrowInvalid as e:
raise ValueError(
f'Failed to read JSON file: {path}. '
'Please check the JSON file has correct format, or filter out '
"non-JSON file with 'partition_filter' field. See read_csv() "
'documentation for more details.') from e


def read_json_stream(
paths: Union[str, List[str]],
*,
filesystem: Optional['pyarrow.fs.FileSystem'] = None,
parallelism: int = -1,
ray_remote_args: Dict[str, Any] = None,
arrow_open_stream_args: Optional[Dict[str, Any]] = None,
meta_provider=None,
partition_filter=None,
partitioning=ds.partitioning.Partitioning('hive'),
include_paths: bool = False,
ignore_missing_paths: bool = False,
shuffle: Union[Literal['files'], None] = None,
file_extensions: Optional[List[str]] = ['json', 'jsonl'],
concurrency: Optional[int] = None,
override_num_blocks: Optional[int] = None,
**arrow_json_args,
) -> rd.Dataset:
if meta_provider is None:
meta_provider = ds.file_meta_provider.DefaultFileMetadataProvider()

datasource = JSONStreamDatasource(
paths,
arrow_json_args=arrow_json_args,
filesystem=filesystem,
open_stream_args=arrow_open_stream_args,
meta_provider=meta_provider,
partition_filter=partition_filter,
partitioning=partitioning,
ignore_missing_paths=ignore_missing_paths,
shuffle=shuffle,
include_paths=include_paths,
file_extensions=file_extensions,
)
return rd.read_datasource(
datasource,
parallelism=parallelism,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
override_num_blocks=override_num_blocks,
)
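
The new RayDataset.read_json entry point probes for pyarrow.json.open_json and falls back to ray.data.read_json when the installed pyarrow is too old to stream. A minimal usage sketch, assuming a running Ray cluster and a local demo.jsonl file (both illustrative, not part of this commit):

import ray
from data_juicer.core.ray_data import RayDataset

ray.init()  # assumes a local Ray runtime; not part of the commit

# Streams via pyarrow.json.open_json when available, otherwise
# falls back to ray.data.read_json.
dataset = RayDataset.read_json('demo.jsonl')  # 'demo.jsonl' is hypothetical
print(dataset.schema())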
2 changes: 1 addition & 1 deletion data_juicer/core/ray_executor.py
@@ -61,7 +61,7 @@ def run(self, load_data_np=None):
from data_juicer.format.formatter import FORMATTERS
dataset = FORMATTERS.modules[obj_name](**args).load_dataset()
else:
-            dataset = rd.read_json(self.cfg.dataset_path)
+            dataset = RayDataset.read_json(self.cfg.dataset_path)

# convert all the path in dataset to absolute path
dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg)
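
With this change the executor no longer calls rd.read_json directly, so the streaming fallback lives in one place. That dispatch relies on plain attribute feature detection; a standalone sketch of the same idiom (illustrative only, names are hypothetical):

import pyarrow.json as js

def has_streaming_json_reader() -> bool:
    # pyarrow only ships open_json in newer releases; touching the
    # attribute raises AttributeError on older builds.
    try:
        js.open_json
        return True
    except AttributeError:
        return False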
