Training on PQ: draft generate.py, index.py, leaving train.py
Showing 2 changed files with 251 additions and 0 deletions.
generate.py:
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Generate a parquet dataset for testing."""

import os
from argparse import ArgumentParser, Namespace
from typing import List, Tuple

import numpy as np
import pyarrow as pa
from pyarrow import parquet as pq

def parse_args() -> Namespace:
    """Parse command-line arguments.

    Returns:
        Namespace: Command-line arguments.
    """
    args = ArgumentParser()
    args.add_argument('--num_train', type=int, default=10_000_000)
    args.add_argument('--num_val', type=int, default=1_000_000)
    args.add_argument('--out', type=str, default='data/pq/')
    args.add_argument('--samples_per_shard', type=int, default=10_000)
    return args.parse_args()


_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen '
         'fifteen sixteen seventeen eighteen nineteen').split()

_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()


def say(i: int) -> List[str]:
    """Get the word form of a number.

    Args:
        i (int): The number.

    Returns:
        List[str]: The number in word form.
    """
    if i < 0:
        return ['negative'] + say(-i)
    elif i <= 19:
        return [_ones[i]]
    elif i < 100:
        return [_tens[i // 10 - 2]] + ([_ones[i % 10]] if i % 10 else [])
    elif i < 1_000:
        return [_ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else [])
    elif i < 1_000_000:
        return say(i // 1_000) + ['thousand'] + (say(i % 1_000) if i % 1_000 else [])
    elif i < 1_000_000_000:
        return say(i // 1_000_000) + ['million'] + (say(i % 1_000_000) if i % 1_000_000 else [])
    else:
        raise ValueError(f'Integer must be less than a billion, but got: {i}')
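
# For example, tracing the branches above, say(1_234) returns
# ['one', 'thousand', 'two', 'hundred', 'thirty', 'four'],
# and say(-40) returns ['negative', 'forty'].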


def generate_number() -> int:
    """Generate a random integer to say.

    Returns:
        int: The integer.
    """
    # Sign is positive with probability 0.8, negative with probability 0.2.
    sign = (np.random.uniform() < 0.8) * 2 - 1
    # Magnitude is log-uniform over [1, 10^9).
    expt = np.random.uniform(0, 9)
    mag = int(10**expt)
    return sign * mag


def generate_numbers(num_train: int, num_val: int) -> Tuple[List[int], List[int]]:
    """Get two non-overlapping splits of integers to say.

    Args:
        num_train (int): Number of training samples.
        num_val (int): Number of validation samples.

    Returns:
        Tuple[List[int], List[int]]: The two generated splits.
    """
    total = num_train + num_val
    nums = set()
    while len(nums) < total:
        nums.add(generate_number())
    # Sort for a deterministic order (set iteration order is arbitrary), then
    # shuffle so neither split is skewed toward small or large magnitudes.
    nums = sorted(nums)
    np.random.shuffle(nums)
    train_nums = nums[:num_train]
    val_nums = nums[num_train:]
    return train_nums, val_nums


def save_parquets(nums: List[int], txts: List[str], dirname: str, samples_per_shard: int) -> None:
    """Save a parquet dataset given the samples.

    Args:
        nums (List[int]): List of sample integers.
        txts (List[str]): List of sample texts.
        dirname (str): Output dirname.
        samples_per_shard (int): Output shard size in samples.
    """
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    # Ceiling division: the last shard may hold fewer than samples_per_shard.
    num_shards = (len(nums) + samples_per_shard - 1) // samples_per_shard
    for shard_id in range(num_shards):
        begin = shard_id * samples_per_shard
        end = min(begin + samples_per_shard, len(nums))
        shard_nums = nums[begin:end]
        shard_txts = txts[begin:end]
        filename = os.path.join(dirname, f'{shard_id:05}.parquet')
        obj = {
            'num': shard_nums,
            'txt': shard_txts,
        }
        table = pa.Table.from_pydict(obj)
        pq.write_table(table, filename)


def main(args: Namespace) -> None:
    """Generate a parquet dataset for testing.

    Args:
        args (Namespace): Command-line arguments.
    """
    train_nums, val_nums = generate_numbers(args.num_train, args.num_val)

    train_txts = [' '.join(say(num)) for num in train_nums]
    val_txts = [' '.join(say(num)) for num in val_nums]

    dirname = os.path.join(args.out, 'train')
    save_parquets(train_nums, train_txts, dirname, args.samples_per_shard)

    dirname = os.path.join(args.out, 'val')
    save_parquets(val_nums, val_txts, dirname, args.samples_per_shard)


if __name__ == '__main__':
    main(parse_args())
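A minimal invocation sketch (assuming the file is saved as generate.py; the flag values shown are just the script's own defaults):

python generate.py --num_train 10000000 --num_val 1000000 --out data/pq/ --samples_per_shard 10000

This writes shards named 00000.parquet, 00001.parquet, and so on under data/pq/train/ and data/pq/val/, each holding samples_per_shard rows with 'num' and 'txt' columns.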
index.py:
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Index a parquet dataset for use by Streaming."""

import json
import os
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Iterator, Tuple

from pyarrow import parquet as pq


def parse_args() -> Namespace:
    """Parse command-line arguments.

    Returns:
        Namespace: Command-line arguments.
    """
    args = ArgumentParser()
    args.add_argument('--dataset', type=str, required=True)
    args.add_argument('--shard_suffix', type=str, default='.parquet')
    return args.parse_args()


def get_dataset_relative_path(dataset_root: str, path: str) -> str:
    """Get the dataset-relative path of a shard file.

    Args:
        dataset_root (str): Dataset root directory containing this shard.
        path (str): Path to shard under dataset root dir.

    Returns:
        str: Dataset-relative shard path.
    """
    if not path.startswith(dataset_root):
        raise ValueError(f'Path {path} was not found under {dataset_root}.')
    rel_path = path[len(dataset_root):]

    # Drop any leading path separators left behind by the prefix slice.
    while rel_path.startswith(os.path.sep):
        rel_path = rel_path[1:]

    return rel_path
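
# For example, assuming POSIX separators,
# get_dataset_relative_path('data/pq/train', 'data/pq/train/00000.parquet')
# returns '00000.parquet'.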


def each_shard_path(dataset_root: str, shard_suffix: str) -> Iterator[Tuple[str, str]]:
    """Collect each Parquet shard, in order.

    Args:
        dataset_root (str): Dataset root directory.
        shard_suffix (str): Suffix of each Parquet shard file.

    Returns:
        Iterator[Tuple[str, str]]: Iterator over absolute and dataset-relative paths.
    """
    for root, _, files in os.walk(dataset_root):
        # Keep only shard files, join them to full paths, and sort within this directory.
        files = filter(lambda file: file.endswith(shard_suffix), files)
        files = (os.path.join(root, file) for file in files)
        files = sorted(files)
        for path in files:
            dataset_rel_path = get_dataset_relative_path(dataset_root, path)
            yield path, dataset_rel_path


def get_shard_info(path: str, dataset_rel_path: str) -> Dict[str, Any]:
    """Get info the index needs about a Parquet shard.

    Args:
        path (str): Absolute or relative-to-cwd file path.
        dataset_rel_path (str): Relative-to-dataset file path.

    Returns:
        Dict[str, Any]: Shard info.
    """
    num_bytes = os.stat(path).st_size
    # Reads the whole table just to count rows; acceptable for a draft indexing pass.
    table = pq.read_table(path)
    num_samples = len(table)
    return {
        'format': 'parquet',
        'raw_parquet': {
            'basename': dataset_rel_path,
            'bytes': num_bytes,
        },
        'samples': num_samples,
    }


def main(args: Namespace) -> None:
    """Index a parquet dataset for use by Streaming.

    Args:
        args (Namespace): Command-line arguments.
    """
    infos = []
    for path, dataset_rel_path in each_shard_path(args.dataset, args.shard_suffix):
        info = get_shard_info(path, dataset_rel_path)
        infos.append(info)
    obj = {
        'version': 2,
        'shards': infos,
    }
    filename = os.path.join(args.dataset, 'index.json')
    if os.path.exists(filename):
        raise ValueError(f'Index file {filename} already exists.')
    with open(filename, 'w') as out:
        json.dump(obj, out)


if __name__ == '__main__':
    main(parse_args())
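To index one split written by generate.py, a minimal sketch (the path is the generator's default train output):

python index.py --dataset data/pq/train

This walks the directory tree, reads each .parquet shard, and writes data/pq/train/index.json shaped roughly like the following (the byte count here is made up):

{
  "version": 2,
  "shards": [
    {
      "format": "parquet",
      "raw_parquet": {"basename": "00000.parquet", "bytes": 123456},
      "samples": 10000
    }
  ]
}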