Skip to content

Commit

Permalink
Training on PQ: draft generate.py, index.py, leaving train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
knighton committed Sep 22, 2023
1 parent d5ff35f commit 6afa233
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 0 deletions.
141 changes: 141 additions & 0 deletions scripts/parquet/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Generate a parquet dataset for testing."""

import os
from argparse import ArgumentParser, Namespace
from typing import List, Tuple

import numpy as np
import pyarrow as pa
from pyarrow import parquet as pq


def parse_args() -> Namespace:
    """Parse command-line arguments.

    Returns:
        Namespace: Command-line arguments.
    """
    parser = ArgumentParser()
    # Split sizes (samples per split).
    parser.add_argument('--num_train', type=int, default=10_000_000)
    parser.add_argument('--num_val', type=int, default=1_000_000)
    # Output dataset root and shard granularity.
    parser.add_argument('--out', type=str, default='data/pq/')
    parser.add_argument('--samples_per_shard', type=int, default=10_000)
    return parser.parse_args()


# Word tables: _ones covers 0-19, _tens covers the multiples 20-90.
_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen '
         'fifteen sixteen seventeen eighteen nineteen').split()

_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()


def say(i: int) -> List[str]:
    """Get the word form of a number.

    Args:
        i (int): The number.

    Returns:
        List[str]: The number in word form.

    Raises:
        ValueError: If ``i`` is a billion or more (unsupported magnitude).
    """
    if i < 0:
        return ['negative'] + say(-i)
    elif i <= 19:
        return [_ones[i]]
    elif i < 100:
        # Tens word, plus the ones word only when the ones digit is nonzero.
        return [_tens[i // 10 - 2]] + ([_ones[i % 10]] if i % 10 else [])
    elif i < 1_000:
        return [_ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else [])
    elif i < 1_000_000:
        return say(i // 1_000) + ['thousand'] + (say(i % 1_000) if i % 1_000 else [])
    elif i < 1_000_000_000:
        return say(i // 1_000_000) + ['million'] + (say(i % 1_000_000) if i % 1_000_000 else [])
    else:
        # Bug fix: the original string lacked the f-prefix, so the literal
        # text '{i}' was emitted instead of the offending value.
        raise ValueError(f'Integer must be less than a billion, but got: {i}')


def generate_number() -> int:
    """Generate a random integer to say.

    Returns:
        int: The integer.
    """
    # Positive with probability 0.8, negative otherwise.
    is_positive = np.random.uniform() < 0.8
    sign = is_positive * 2 - 1
    # Log-uniform magnitude: exponent drawn uniformly over [0, 9).
    exponent = np.random.uniform(0, 9)
    magnitude = int(10**exponent)
    return sign * magnitude


def generate_numbers(num_train: int, num_val: int) -> Tuple[List[int], List[int]]:
    """Get two non-overlapping splits of integers to say.

    Args:
        num_train (int): Number of training samples.
        num_val (int): Number of validation samples.

    Returns:
        Tuple[List[int], List[int]]: The two generated splits.
    """
    target = num_train + num_val
    # Rejection-sample until we have enough distinct integers (re-adding a
    # duplicate to the set is a no-op, so each loop draws exactly once).
    seen = set()
    while len(seen) < target:
        seen.add(generate_number())
    # Sort for a deterministic ordering, then shuffle before splitting.
    pool = sorted(seen)
    np.random.shuffle(pool)
    return pool[:num_train], pool[num_train:]


def save_parquets(nums: List[int], txts: List[str], dirname: str, samples_per_shard: int) -> None:
    """Save a parquet dataset given the samples.

    Args:
        nums (List[int]): List of sample integers.
        txts (List[str]): List of sample texts.
        dirname (str): Output dirname.
        samples_per_shard (int): Output shard size in samples.
    """
    # exist_ok avoids the TOCTOU race between an exists() check and makedirs().
    os.makedirs(dirname, exist_ok=True)
    # Ceiling division: the final shard may hold fewer than samples_per_shard.
    num_shards = (len(nums) + samples_per_shard - 1) // samples_per_shard
    for shard_id in range(num_shards):
        begin = shard_id * samples_per_shard
        end = min(begin + samples_per_shard, len(nums))
        shard_nums = nums[begin:end]
        shard_txts = txts[begin:end]
        # Zero-padded basenames keep shards in order when listed lexically.
        filename = os.path.join(dirname, f'{shard_id:05}.parquet')
        obj = {
            'num': shard_nums,
            'txt': shard_txts,
        }
        table = pa.Table.from_pydict(obj)
        pq.write_table(table, filename)


def main(args: Namespace) -> None:
    """Generate a parquet dataset for testing.

    Args:
        args (Namespace): Command-line arguments.
    """
    # Draw two disjoint pools of integers, one per split.
    train_nums, val_nums = generate_numbers(args.num_train, args.num_val)

    # Render each integer as its space-joined English word form.
    train_txts = [' '.join(say(num)) for num in train_nums]
    val_txts = [' '.join(say(num)) for num in val_nums]

    # Write each split under its own subdirectory of the output root.
    save_parquets(train_nums, train_txts, os.path.join(args.out, 'train'),
                  args.samples_per_shard)
    save_parquets(val_nums, val_txts, os.path.join(args.out, 'val'),
                  args.samples_per_shard)


# Entry point: parse CLI arguments and generate the dataset.
if __name__ == '__main__':
    main(parse_args())
110 changes: 110 additions & 0 deletions scripts/parquet/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Index a parquet dataset for use by Streaming."""

import json
import os
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Iterator, Tuple

from pyarrow import parquet as pq


def parse_args() -> Namespace:
    """Parse command-line arguments.

    Returns:
        Namespace: Command-line arguments.
    """
    parser = ArgumentParser()
    # Root directory of the Parquet dataset to index.
    parser.add_argument('--dataset', type=str, required=True)
    # Filename suffix that identifies shard files.
    parser.add_argument('--shard_suffix', type=str, default='.parquet')
    return parser.parse_args()


def get_dataset_relative_path(dataset_root: str, path: str) -> str:
    """Get the dataset-relative path of a shard file.

    Args:
        dataset_root (str): Dataset root directory containing this shard.
        path (str): Path to shard under dataset root dir.

    Returns:
        str: Dataset-relative shard path.

    Raises:
        ValueError: If ``path`` is not under ``dataset_root``.
    """
    if not path.startswith(dataset_root):
        # Bug fix: the original string lacked the f-prefix, so the literal
        # text '{path}' appeared in the error message.
        raise ValueError(f'Path {path} was not found under {dataset_root}.')
    rel_path = path[len(dataset_root):]
    # Drop any leading separator(s) left over after removing the root prefix.
    return rel_path.lstrip(os.path.sep)


def each_shard_path(dataset_root: str, shard_suffix: str) -> Iterator[Tuple[str, str]]:
    """Collect each Parquet shard, in order.

    Args:
        dataset_root (str): Dataset root directory.
        shard_suffix (str): Suffix of each Parquet shard file.

    Returns:
        Iterator[Tuple[str, str]]: Iterator over absolute and dataset-relative paths.
    """
    for root, _, basenames in os.walk(dataset_root):
        # Keep only shard files, joined to their directory, in sorted order.
        paths = sorted(
            os.path.join(root, basename)
            for basename in basenames
            if basename.endswith(shard_suffix)
        )
        for path in paths:
            yield path, get_dataset_relative_path(dataset_root, path)


def get_shard_info(path: str, dataset_rel_path: str) -> Dict[str, Any]:
    """Get info the index needs about a Parquet shard.

    Args:
        path (str): Absolute or relative-to-cwd file path.
        dataset_rel_path (str): Relative-to-dataset file path.

    Returns:
        Dict[str, Any]: Shard info.
    """
    # Size on disk, and row count obtained by reading the table.
    size_in_bytes = os.stat(path).st_size
    sample_count = len(pq.read_table(path))
    return {
        'format': 'parquet',
        'raw_parquet': {
            'basename': dataset_rel_path,
            'bytes': size_in_bytes,
        },
        'samples': sample_count,
    }


def main(args: Namespace) -> None:
    """Index a parquet dataset for use by Streaming.

    Args:
        args (Namespace): Command-line arguments.

    Raises:
        ValueError: If the dataset already has an index file.
    """
    filename = os.path.join(args.dataset, 'index.json')
    # Fail fast before scanning every shard; bug fix: the original message was
    # the literal text 'Index file (unknown) already exists.' — the filename
    # placeholder had been lost, so report the actual path.
    if os.path.exists(filename):
        raise ValueError(f'Index file {filename} already exists.')
    infos = [
        get_shard_info(path, dataset_rel_path)
        for path, dataset_rel_path in each_shard_path(args.dataset, args.shard_suffix)
    ]
    obj = {
        'version': 2,
        'shards': infos,
    }
    with open(filename, 'w') as out:
        json.dump(obj, out)


# Entry point: parse CLI arguments and build the index.
if __name__ == '__main__':
    main(parse_args())

0 comments on commit 6afa233

Please sign in to comment.