Skip to content

Commit

Permalink
improve dataset saving
Browse files Browse the repository at this point in the history
unify val to validation
make `save_dataset` support saving multiple files
allow passing filename in `save_dataset`

Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Oct 22, 2024
1 parent e12f96f commit 38e8a13
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 22 deletions.
12 changes: 5 additions & 7 deletions multimolecule/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def infer_task(
) -> Task:
if max_seq_length is not None and seq_length_offset is not None:
max_seq_length -= seq_length_offset
if isinstance(sequence, ChunkedArray) and sequence.num_chunks == 1:
sequence = sequence.chunks[0]
if isinstance(column, ChunkedArray) and column.num_chunks == 1:
column = column.chunks[0]
if isinstance(sequence, ChunkedArray):
sequence = sequence.combine_chunks()
if isinstance(column, ChunkedArray):
column = column.combine_chunks()
flattened, levels = flatten_column(column, truncation, max_seq_length)
dtype = flattened.type
unique = flattened.unique()
Expand Down Expand Up @@ -145,14 +145,12 @@ def flatten_column(


def get_num_tokens(sequence: Array | ListArray, seq_length_offset: int | None = None) -> Tuple[int, int]:
if isinstance(sequence, StringArray):
if isinstance(sequence, StringArray) or isinstance(sequence[0], pa.lib.StringScalar):
return sum(len(i.as_py()) for i in sequence), sum(len(i.as_py()) ** 2 for i in sequence)
# remove <bos> and <eos> tokens in length calculation
if seq_length_offset is None:
warn("seq_length_offset not specified, automatically detecting <bos> and <eos> tokens")
seq_length_offset = 0
if isinstance(sequence[0], pa.lib.StringScalar):
raise ValueError("seq_length_offset must be specified for StringScalar sequences")
if len({i[0] for i in sequence}) == 1:
seq_length_offset += 1
if len({i[-1] for i in sequence}) == 1:
Expand Down
2 changes: 1 addition & 1 deletion multimolecule/datasets/bprna_new/bprna_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def convert_bpseq(bpseq):

def convert_dataset(convert_config):
data = [convert_bpseq(file) for file in tqdm(get_files(convert_config.dataset_path))]
save_dataset(convert_config, data)
save_dataset(convert_config, data, filename="test.parquet")


class ConvertConfig(ConvertConfig_):
Expand Down
14 changes: 2 additions & 12 deletions multimolecule/datasets/bprna_spot/bprna_spot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,17 @@
from __future__ import annotations

import os
from collections.abc import Mapping

import torch
from tqdm import tqdm

from multimolecule.datasets.bprna.bprna import convert_sta
from multimolecule.datasets.conversion_utils import ConvertConfig as ConvertConfig_
from multimolecule.datasets.conversion_utils import copy_readme, get_files, push_to_hub, write_data
from multimolecule.datasets.conversion_utils import get_files, save_dataset

torch.manual_seed(1016)


def save_dataset(convert_config: ConvertConfig, data: Mapping, compression: str = "brotli", level: int = 4):
root, output_path = convert_config.root, convert_config.output_path
os.makedirs(output_path, exist_ok=True)
for name, d in data.items():
write_data(d, output_path, name + ".parquet", compression, level)
copy_readme(root, output_path)
push_to_hub(convert_config, output_path)


def _convert_dataset(dataset):
files = get_files(dataset)
return [convert_sta(file) for file in tqdm(files, total=len(files))]
Expand All @@ -46,7 +36,7 @@ def _convert_dataset(dataset):
def convert_dataset(convert_config):
data = {
"train": _convert_dataset(os.path.join(convert_config.dataset_path, "TR0")),
"val": _convert_dataset(os.path.join(convert_config.dataset_path, "VL0")),
"validation": _convert_dataset(os.path.join(convert_config.dataset_path, "VL0")),
"test": _convert_dataset(os.path.join(convert_config.dataset_path, "TS0")),
}
save_dataset(convert_config, data)
Expand Down
16 changes: 14 additions & 2 deletions multimolecule/datasets/conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import os
import shutil
from collections.abc import Mapping
from warnings import warn

import pyarrow as pa
from chanfig import Config
Expand Down Expand Up @@ -74,11 +76,21 @@ def push_to_hub(convert_config: ConvertConfig, output_path: str, repo_type: str


def save_dataset(
convert_config: ConvertConfig, data: Table | list | dict | DataFrame, compression: str = "brotli", level: int = 4
convert_config: ConvertConfig,
data: Table | list | dict | DataFrame,
filename: str = "data.parquet",
compression: str = "brotli",
level: int = 4,
):
root, output_path = convert_config.root, convert_config.output_path
os.makedirs(output_path, exist_ok=True)
write_data(data, output_path, compression=compression, level=level)
if isinstance(data, Mapping):
if filename != "data.parquet":
warn("Filename is ignored when saving multiple datasets.")
for name, d in data.items():
write_data(d, output_path, filename=name + ".parquet", compression=compression, level=level)
else:
write_data(data, output_path, filename=filename, compression=compression, level=level)
copy_readme(root, output_path)
push_to_hub(convert_config, output_path)

Expand Down

0 comments on commit 38e8a13

Please sign in to comment.