From 6afa233cf019b6468ed26cccf4b4582a45782df8 Mon Sep 17 00:00:00 2001
From: James Knighton
Date: Fri, 22 Sep 2023 15:09:05 -0700
Subject: [PATCH] Training on PQ: draft generate.py, index.py, leaving train.py

---
 scripts/parquet/generate.py | 141 ++++++++++++++++++++++++++++++++++++
 scripts/parquet/index.py    | 110 ++++++++++++++++++++++++++++
 2 files changed, 251 insertions(+)
 create mode 100644 scripts/parquet/generate.py
 create mode 100644 scripts/parquet/index.py

diff --git a/scripts/parquet/generate.py b/scripts/parquet/generate.py
new file mode 100644
index 000000000..de14e6a00
--- /dev/null
+++ b/scripts/parquet/generate.py
@@ -0,0 +1,141 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Generate a Parquet dataset for testing."""
+
+import os
+from argparse import ArgumentParser, Namespace
+from typing import List, Tuple
+
+import numpy as np
+import pyarrow as pa
+from pyarrow import parquet as pq
+
+
+def parse_args() -> Namespace:
+    """Parse command-line arguments.
+
+    Returns:
+        Namespace: Command-line arguments.
+    """
+    args = ArgumentParser()
+    args.add_argument('--num_train', type=int, default=10_000_000)
+    args.add_argument('--num_val', type=int, default=1_000_000)
+    args.add_argument('--out', type=str, default='data/pq/')
+    args.add_argument('--samples_per_shard', type=int, default=10_000)
+    return args.parse_args()
+
+
+_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen '
+         'fifteen sixteen seventeen eighteen nineteen').split()
+
+_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()
+
+
+def say(i: int) -> List[str]:
+    """Get the word form of a number.
+
+    Args:
+        i (int): The number.
+
+    Returns:
+        List[str]: The number in word form.
+    """
+    if i < 0:
+        return ['negative'] + say(-i)
+    elif i <= 19:
+        return [_ones[i]]
+    elif i < 100:
+        return [_tens[i // 10 - 2]] + ([_ones[i % 10]] if i % 10 else [])
+    elif i < 1_000:
+        return [_ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else [])
+    elif i < 1_000_000:
+        return say(i // 1_000) + ['thousand'] + (say(i % 1_000) if i % 1_000 else [])
+    elif i < 1_000_000_000:
+        return say(i // 1_000_000) + ['million'] + (say(i % 1_000_000) if i % 1_000_000 else [])
+    else:
+        raise ValueError(f'Integer must be less than a billion, but got: {i}')
+
+
+def generate_number() -> int:
+    """Generate a random integer to say.
+
+    Returns:
+        int: The integer, log-uniform in magnitude and positive 80% of the time.
+    """
+    sign = (np.random.uniform() < 0.8) * 2 - 1  # Positive with probability 0.8.
+    expt = np.random.uniform(0, 9)  # Uniform exponent, so magnitude is log-uniform.
+    mag = int(10**expt)
+    return sign * mag
+
+
+def generate_numbers(num_train: int, num_val: int) -> Tuple[List[int], List[int]]:
+    """Get two non-overlapping splits of integers to say.
+
+    Args:
+        num_train (int): Number of training samples.
+        num_val (int): Number of validation samples.
+
+    Returns:
+        Tuple[List[int], List[int]]: The two generated splits.
+    """
+    total = num_train + num_val
+    nums = set()
+    while len(nums) < total:
+        num = generate_number()
+        if num in nums:  # Resample duplicates so every sample is unique.
+            continue
+        nums.add(num)
+    nums = sorted(nums)
+    np.random.shuffle(nums)
+    train_nums = nums[:num_train]
+    val_nums = nums[num_train:]
+    return train_nums, val_nums
+
+
+def save_parquets(nums: List[int], txts: List[str], dirname: str, samples_per_shard: int) -> None:
+    """Save a Parquet dataset given the samples.
+
+    Args:
+        nums (List[int]): List of sample integers.
+        txts (List[str]): List of sample texts.
+        dirname (str): Output dirname.
+        samples_per_shard (int): Output shard size in samples.
+    """
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+    num_shards = (len(nums) + samples_per_shard - 1) // samples_per_shard  # Ceiling division.
+    for shard_id in range(num_shards):
+        begin = shard_id * samples_per_shard
+        end = min(begin + samples_per_shard, len(nums))
+        shard_nums = nums[begin:end]
+        shard_txts = txts[begin:end]
+        filename = os.path.join(dirname, f'{shard_id:05}.parquet')
+        obj = {
+            'num': shard_nums,
+            'txt': shard_txts,
+        }
+        table = pa.Table.from_pydict(obj)
+        pq.write_table(table, filename)
+
+
+def main(args: Namespace) -> None:
+    """Generate a Parquet dataset for testing.
+
+    Args:
+        args (Namespace): Command-line arguments.
+    """
+    train_nums, val_nums = generate_numbers(args.num_train, args.num_val)
+
+    train_txts = [' '.join(say(num)) for num in train_nums]
+    val_txts = [' '.join(say(num)) for num in val_nums]
+
+    dirname = os.path.join(args.out, 'train')
+    save_parquets(train_nums, train_txts, dirname, args.samples_per_shard)
+
+    dirname = os.path.join(args.out, 'val')
+    save_parquets(val_nums, val_txts, dirname, args.samples_per_shard)
+
+
+if __name__ == '__main__':
+    main(parse_args())
diff --git a/scripts/parquet/index.py b/scripts/parquet/index.py
new file mode 100644
index 000000000..2168fe92e
--- /dev/null
+++ b/scripts/parquet/index.py
@@ -0,0 +1,110 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Index a Parquet dataset for use by Streaming."""
+
+import json
+import os
+from argparse import ArgumentParser, Namespace
+from typing import Any, Dict, Iterator, Tuple
+
+from pyarrow import parquet as pq
+
+
+def parse_args() -> Namespace:
+    """Parse command-line arguments.
+
+    Returns:
+        Namespace: Command-line arguments.
+    """
+    args = ArgumentParser()
+    args.add_argument('--dataset', type=str, required=True)
+    args.add_argument('--shard_suffix', type=str, default='.parquet')
+    return args.parse_args()
+
+
+def get_dataset_relative_path(dataset_root: str, path: str) -> str:
+    """Get the dataset-relative path of a shard file.
+
+    Args:
+        dataset_root (str): Dataset root directory containing this shard.
+        path (str): Path to shard under dataset root dir.
+
+    Returns:
+        str: Dataset-relative shard path.
+    """
+    if not path.startswith(dataset_root):
+        raise ValueError(f'Path {path} was not found under {dataset_root}.')
+    rel_path = path[len(dataset_root):]
+
+    while rel_path.startswith(os.path.sep):
+        rel_path = rel_path[1:]
+
+    return rel_path
+
+
+def each_shard_path(dataset_root: str, shard_suffix: str) -> Iterator[Tuple[str, str]]:
+    """Collect each Parquet shard path, in sorted order.
+
+    Args:
+        dataset_root (str): Dataset root directory.
+        shard_suffix (str): Suffix of each Parquet shard file.
+
+    Returns:
+        Iterator[Tuple[str, str]]: Iterator over absolute and dataset-relative paths.
+    """
+    for root, _, files in os.walk(dataset_root):
+        files = filter(lambda file: file.endswith(shard_suffix), files)
+        files = (os.path.join(root, file) for file in files)
+        files = sorted(files)
+        for path in files:
+            dataset_rel_path = get_dataset_relative_path(dataset_root, path)
+            yield path, dataset_rel_path
+
+
+def get_shard_info(path: str, dataset_rel_path: str) -> Dict[str, Any]:
+    """Get the info the index needs about a Parquet shard.
+
+    Args:
+        path (str): Absolute or relative-to-cwd file path.
+        dataset_rel_path (str): Relative-to-dataset file path.
+
+    Returns:
+        Dict[str, Any]: Shard info.
+    """
+    num_bytes = os.stat(path).st_size
+    table = pq.read_table(path)  # Load the shard to count its rows.
+    num_samples = len(table)
+    return {
+        'format': 'parquet',
+        'raw_parquet': {
+            'basename': dataset_rel_path,
+            'bytes': num_bytes,
+        },
+        'samples': num_samples,
+    }
+
+
+def main(args: Namespace) -> None:
+    """Index a Parquet dataset for use by Streaming.
+
+    Args:
+        args (Namespace): Command-line arguments.
+    """
+    infos = []
+    for path, dataset_rel_path in each_shard_path(args.dataset, args.shard_suffix):
+        info = get_shard_info(path, dataset_rel_path)
+        infos.append(info)
+    obj = {
+        'version': 2,
+        'shards': infos,
+    }
+    filename = os.path.join(args.dataset, 'index.json')
+    if os.path.exists(filename):
+        raise ValueError(f'Index file {filename} already exists.')
+    with open(filename, 'w') as out:
+        json.dump(obj, out)
+
+
+if __name__ == '__main__':
+    main(parse_args())
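
A minimal end-to-end sketch of how the two scripts are meant to be used together, assuming the argparse defaults above (the sample counts here are scaled down so a smoke test runs quickly; the flag values are otherwise illustrative):

    python scripts/parquet/generate.py --num_train 1000 --num_val 100 --out data/pq/ --samples_per_shard 100
    python scripts/parquet/index.py --dataset data/pq/train
    python scripts/parquet/index.py --dataset data/pq/val

Each index.py run then writes an index.json next to the shards. From get_shard_info, the file should look roughly like this (the bytes value is illustrative, since it depends on the Parquet encoding):

    {
      "version": 2,
      "shards": [
        {"format": "parquet", "raw_parquet": {"basename": "00000.parquet", "bytes": 12345}, "samples": 100},
        ...
      ]
    }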