diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py
index 1cc05625..7c953e71 100644
--- a/src/modalities/dataloader/dataset.py
+++ b/src/modalities/dataloader/dataset.py
@@ -11,11 +11,10 @@
 from tqdm import tqdm
 from transformers import BatchEncoding
 
+from modalities.dataloader.create_packed_data import EmbeddedStreamData
+from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
 from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
 
-from ..dataloader.large_file_lines_reader import LargeFileLinesReader
-from .create_packed_data import EmbeddedStreamData
-
 
 class Dataset(TorchdataSet):
     """Dataset class."""
@@ -411,9 +410,9 @@ def _generate_packing_index(self) -> list[tuple[int, int]]:
 class CombinedDataset(Dataset):
     """Combines multiple datasets into one large dataset at runtime.
 
-    Note: When using this class to combine multiple `PackedMemMapDataset`s, then each packed sample
+    Note: When using this class to combine multiple `PackedMemMapDataset`s, each packed sample
     is packed from a single dataset (i.e., the samples are not mixed between datasets).
-    In the Dataloader a batch will still contain packed samples from different datasets.
+    In the Dataloader, a batch will still contain packed samples from different datasets.
     """
 
     def __init__(self, datasets: list[Dataset]):
@@ -423,29 +422,13 @@ def __init__(self, datasets: list[Dataset]):
             datasets (list[Dataset]): A list of datasets to combine.
         """
         self.datasets = datasets
-        self.cumulative_sizes = CombinedDataset._get_cumulated_sizes(datasets=datasets)
-
-    @staticmethod
-    def _get_cumulated_sizes(datasets: list[Dataset]) -> list[int]:
-        total = 0
-        cumulated_sizes = [0]
-        for dataset in datasets:
-            total += len(dataset)
-            cumulated_sizes.append(total)
-        return cumulated_sizes
-
-    def _find_dataset_idx(self, idx: int) -> int:
-        for i, cumulative_size in enumerate(self.cumulative_sizes):
-            if idx < cumulative_size:
-                return i - 1
-        raise IndexError(f"Index {idx} is out of bounds.")
+        self.cumulative_sizes = np.cumsum([len(ds) for ds in datasets], dtype=np.int64)
 
     def __len__(self) -> int:
         return self.cumulative_sizes[-1]
 
     def __getitem__(self, idx: int) -> dict:
-        dataset_idx = self._find_dataset_idx(idx)
-        local_idx = idx - self.cumulative_sizes[dataset_idx]
+        dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right")
+        local_idx = idx - (self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0)
 
-        sample = self.datasets[dataset_idx][local_idx]
-        return sample
+        return self.datasets[dataset_idx][local_idx]
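
For reference, here is a minimal standalone sketch (not part of the patch) of the index arithmetic the new CombinedDataset.__getitem__ relies on: the cumulative sizes from np.cumsum together with np.searchsorted(..., side="right") map a global sample index to a (dataset_idx, local_idx) pair. The dataset lengths and the resolve helper below are made up purely for illustration.

    # Illustration only -- not part of the diff above.
    import numpy as np

    lengths = [3, 4, 2]  # hypothetical sizes of three combined datasets
    cumulative_sizes = np.cumsum(lengths, dtype=np.int64)  # array([3, 7, 9])


    def resolve(idx: int) -> tuple[int, int]:
        """Return (dataset_idx, local_idx) for a global sample index."""
        # side="right" makes idx == cumulative_sizes[i] roll over to dataset i + 1
        dataset_idx = int(np.searchsorted(cumulative_sizes, idx, side="right"))
        local_idx = idx - (int(cumulative_sizes[dataset_idx - 1]) if dataset_idx > 0 else 0)
        return dataset_idx, local_idx


    assert resolve(0) == (0, 0)  # first sample of the first dataset
    assert resolve(2) == (0, 2)  # last sample of the first dataset
    assert resolve(3) == (1, 0)  # boundary: first sample of the second dataset
    assert resolve(8) == (2, 1)  # last sample overall (total length == 9)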