Merge pull request #298 from Modalities/combined_dataset_fast
Improved efficiency of CombinedDataset with binary search
mali-git authored Jan 31, 2025
2 parents d43f02f + 698e8b6 commit 3ebaa91
Showing 1 changed file with 8 additions and 25 deletions.
src/modalities/dataloader/dataset.py: 33 changes (8 additions & 25 deletions)
@@ -11,11 +11,10 @@
 from tqdm import tqdm
 from transformers import BatchEncoding
 
+from modalities.dataloader.create_packed_data import EmbeddedStreamData
+from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
 from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
 
-from ..dataloader.large_file_lines_reader import LargeFileLinesReader
-from .create_packed_data import EmbeddedStreamData
-
 
 class Dataset(TorchdataSet):
     """Dataset class."""
@@ -411,9 +410,9 @@ def _generate_packing_index(self) -> list[tuple[int, int]]:
 class CombinedDataset(Dataset):
     """Combines multiple datasets into one large dataset at runtime.
 
-    Note: When using this class to combine multiple `PackedMemMapDataset`s, then each packed sample
+    Note: When using this class to combine multiple `PackedMemMapDataset`s, each packed sample
     is packed from a single dataset (i.e., the samples are not mixed between datasets).
-    In the Dataloader a batch will still contain packed samples from different datasets.
+    In the Dataloader, a batch will still contain packed samples from different datasets.
     """
 
     def __init__(self, datasets: list[Dataset]):
@@ -423,29 +422,13 @@ def __init__(self, datasets: list[Dataset]):
             datasets (list[Dataset]): A list of datasets to combine.
         """
         self.datasets = datasets
-        self.cumulative_sizes = CombinedDataset._get_cumulated_sizes(datasets=datasets)
-
-    @staticmethod
-    def _get_cumulated_sizes(datasets: list[Dataset]) -> list[int]:
-        total = 0
-        cumulated_sizes = [0]
-        for dataset in datasets:
-            total += len(dataset)
-            cumulated_sizes.append(total)
-        return cumulated_sizes
-
-    def _find_dataset_idx(self, idx: int) -> int:
-        for i, cumulative_size in enumerate(self.cumulative_sizes):
-            if idx < cumulative_size:
-                return i - 1
-        raise IndexError(f"Index {idx} is out of bounds.")
+        self.cumulative_sizes = np.cumsum([len(ds) for ds in datasets], dtype=np.int64)
 
     def __len__(self) -> int:
         return self.cumulative_sizes[-1]
 
     def __getitem__(self, idx: int) -> dict:
-        dataset_idx = self._find_dataset_idx(idx)
-        local_idx = idx - self.cumulative_sizes[dataset_idx]
+        dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right")
+        local_idx = idx - (self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0)
 
-        sample = self.datasets[dataset_idx][local_idx]
-        return sample
+        return self.datasets[dataset_idx][local_idx]

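For context: the removed `_find_dataset_idx` scanned the cumulative-size list linearly, so each lookup was O(n) in the number of datasets, while `np.searchsorted` performs a binary search in O(log n). The new `cumulative_sizes` array also drops the leading 0 that the old helper prepended, which is why `__getitem__` now guards the `dataset_idx > 0` case. Below is a minimal standalone sketch, not part of the commit, of how the new lookup maps a global sample index to a `(dataset_idx, local_idx)` pair; the dataset sizes 3 and 2 are made up for illustration:

    import numpy as np

    # Toy stand-ins for the real Dataset objects: two datasets of sizes 3 and 2.
    sizes = [3, 2]
    cumulative_sizes = np.cumsum(sizes, dtype=np.int64)  # array([3, 5])

    for idx in range(int(cumulative_sizes[-1])):
        # side="right" pushes an index that equals a cumulative boundary
        # (idx=3 here) into the next dataset rather than the previous one.
        dataset_idx = int(np.searchsorted(cumulative_sizes, idx, side="right"))
        # Subtract the sizes of all preceding datasets to get the local index.
        local_idx = idx - (cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0)
        print(idx, "->", (dataset_idx, int(local_idx)))

    # Output:
    # 0 -> (0, 0)
    # 1 -> (0, 1)
    # 2 -> (0, 2)
    # 3 -> (1, 0)
    # 4 -> (1, 1)

An out-of-range index (here, 5) yields `dataset_idx == len(datasets)`, so the subsequent list access raises an IndexError, matching the explicit raise in the removed helper.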