From efdc13a13d2d8aff605509059c91266008476035 Mon Sep 17 00:00:00 2001
From: Helw150
Date: Thu, 30 Jan 2025 18:39:23 -0500
Subject: [PATCH] Tweaked Version

---
 src/levanter/data/dataset.py | 55 ++++++++++++++++++++++++++++++++++++
 src/levanter/data/mixture.py | 12 +-------
 src/levanter/data/text.py    | 13 +++++++--
 tests/test_mixture.py        |  6 ++--
 4 files changed, 68 insertions(+), 18 deletions(-)

diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py
index f448ed83b..fe590da96 100644
--- a/src/levanter/data/dataset.py
+++ b/src/levanter/data/dataset.py
@@ -122,6 +122,11 @@ def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDa
     def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
         return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)
 
+    def slice_dataset(
+        self, start_index: Optional[int] = None, end_index: Optional[int] = None
+    ) -> "SlicedAsyncDataset[U]":
+        return SlicedAsyncDataset(self, start_index, end_index)
+
     def shuffle(self, key: PRNGKey):
         import levanter.data.permutation as permutation
 
@@ -375,6 +380,53 @@ def _call_fn(self, index, item):
         return self.fn(item, *self._extra_args, **kwargs)
 
 
+class SlicedAsyncDataset(AsyncDataset[U]):
+    def __init__(
+        self,
+        dataset: AsyncDataset[U],
+        start_index: Optional[int] = None,
+        end_index: Optional[int] = None,
+    ):
+        if start_index is None:
+            start_index = 0
+        if end_index is not None and start_index > end_index:
+            raise ValueError("End index must come after start index.")
+        self.start_index = start_index
+        self.end_index = end_index
+        self.dataset = dataset
+        self._min_known_len = dataset._min_known_len if end_index is None else (end_index - start_index)
+
+    async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
+        shifted_indices = [(index + self.start_index) for index in indices]
+        max_index = max(shifted_indices)
+
+        if self.end_index is not None and max_index >= self.end_index:
+            raise ValueError("Requested indices beyond the end of the dataset")
+
+        return await self.dataset.get_batch(shifted_indices)
+
+    async def async_len(self) -> int:
+        underlying_length = await self.dataset.async_len()
+        if self.end_index is None:
+            return underlying_length - self.start_index
+        else:
+            return self.end_index - self.start_index
+
+    async def final_length_is_known(self) -> bool:
+        underlying_is_known = await self.dataset.final_length_is_known()
+        return underlying_is_known and self.end_index is not None
+
+    def is_finite(self) -> bool:
+        return self.dataset.is_finite() and self.end_index is not None
+
+    async def current_len(self) -> Optional[int]:
+        underlying_length = await self.dataset.current_len()
+        if self.end_index is not None:
+            return self.end_index - self.start_index
+        # underlying length may still be unknown (None)
+        return None if underlying_length is None else underlying_length - self.start_index
+
+
 class BatchMappedAsyncDataset(AsyncDataset[U]):
     """
     A dataset that applies a function to each batch of items in the dataset.
@@ -408,6 +460,9 @@ def is_finite(self) -> bool:
     async def current_len(self) -> Optional[int]:
         return await self.dataset.current_len()
 
+    async def wait_until_len_at_least(self, length: int) -> int:
+        return await self.dataset.wait_until_len_at_least(length)
+
     def _maybe_fold_in_key(self, key, indices: Sequence[int]):
         if key is not None:
             key = _fold_in_key_vmap(key, np.array(indices))
diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py
index 57fd792bb..188e5e426 100644
--- a/src/levanter/data/mixture.py
+++ b/src/levanter/data/mixture.py
@@ -52,7 +52,6 @@ def __init__(
         randomize_blocks: bool = True,
         key: PRNGKeyArray | int,
         stop_strategy: str = StopStrategy.RESTART_STRATEGY,
-        simulated_data_ratio: float = 1,
     ):
         super().__init__()
         if isinstance(weights, dict):
@@ -100,9 +99,6 @@ def __init__(
             raise NotImplementedError("Only restart strategy is supported for now.")
         self.stop_strategy = stop_strategy
 
-        if simulated_data_ratio > 1:
-            raise ValueError(f"Simulated data ratio must be at most 1, got {simulated_data_ratio}")
-        self.simulated_data_ratio = simulated_data_ratio
 
         # Initialize stage-related counts and IDs
         (
@@ -279,13 +275,7 @@ async def _remap_indices(self, ds, indices_into_ds):
         if self.stop_strategy == StopStrategy.RESTART_STRATEGY:
             if ds.is_finite():
                 max_elem = max(indices_into_ds)
-                # Remap Indices Earlier when simulating epoching for a larger budget
-                if self.simulated_data_ratio < 1:
-                    # Note(Will): This blocks on datasets being fully processed even for small simulated runs making simulating data size slightly latency inducing but I think that's ok
-                    true_length_of_dataset = await ds.async_len()
-                    length_of_dataset = int(true_length_of_dataset * self.simulated_data_ratio)
-                else:
-                    length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1)
+                length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1)
                 indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds]
 
         return indices_into_ds
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 355de9631..98139e8e0 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -1231,10 +1231,18 @@ def shuffle_ds(ds, key):
                 out_token_datasets[name] = shuffle_ds(ds, next(key_iter))
             token_datasets = out_token_datasets
 
         if self.experiment_budget is not None and self.target_budget is not None:
+            if self.experiment_budget > self.target_budget:
+                raise ValueError(
+                    f"Experiment budget should be smaller than target budget, got {self.experiment_budget} >"
+                    f" {self.target_budget}"
+                )
             simulated_data_ratio = self.experiment_budget / self.target_budget
-        else:
-            simulated_data_ratio = 1
+            for name, ds in token_datasets.items():
+                # Note(Will): This blocks on datasets being fully processed even for small simulated runs making simulating data size slightly latency inducing but I think that's ok
+                true_length_of_dataset = len(ds.as_sync_dataset())
+                simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio)
+                token_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset)
 
         mixture = MixtureDataset(
             datasets=token_datasets,
@@ -1242,7 +1250,6 @@ def shuffle_ds(ds, key):
             stop_strategy=self.stop_strategy,
             key=mix_key,
             block_size=self.mixture_block_size,
-            simulated_data_ratio=simulated_data_ratio,
         )
 
         return mixture
diff --git a/tests/test_mixture.py b/tests/test_mixture.py
index 458297745..52652f380 100644
--- a/tests/test_mixture.py
+++ b/tests/test_mixture.py
@@ -77,13 +77,12 @@ async def test_mixture_dataset_stop_strategy_restart():
 async def test_mixture_dataset_simulated_data_size():
     weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3}
     mixture_ds = MixtureDataset(
-        datasets(),
+        {name: dataset.slice_dataset(end_index=1) for name, dataset in datasets().items()},
         weights,
         block_size=10,
         key=key(),
         randomize_blocks=False,
         stop_strategy=StopStrategy.RESTART_STRATEGY,
-        simulated_data_ratio=0.2,
     )
     for _ in range(10):
         batch = await mixture_ds.get_batch([0, 1, 2])
@@ -91,13 +90,12 @@ async def test_mixture_dataset_simulated_data_size():
         assert all(item in [1, 10, 100] for item in batch)
 
     mixture_ds = MixtureDataset(
-        datasets(),
+        {name: dataset.slice_dataset(end_index=2) for name, dataset in datasets().items()},
         weights,
         block_size=10,
         key=key(),
         randomize_blocks=False,
         stop_strategy=StopStrategy.RESTART_STRATEGY,
-        simulated_data_ratio=0.4,
     )
     for _ in range(10):
         batch = await mixture_ds.get_batch([0, 1, 2])