Skip to content

Commit

Permalink
Tweaked Version
Browse files Browse the repository at this point in the history
  • Loading branch information
Helw150 committed Jan 30, 2025
1 parent bbec00b commit efdc13a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 18 deletions.
55 changes: 55 additions & 0 deletions src/levanter/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDa
def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
    """
    Wrap this dataset in a ``BatchMappedAsyncDataset`` built around ``fn``.

    ``extra_args`` and ``extra_kwargs`` are handed to the wrapper's constructor unchanged.
    """
    batch_mapped = BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)
    return batch_mapped

def slice_dataset(
    self, start_index: Optional[int] = None, end_index: Optional[int] = None
) -> "SlicedAsyncDataset[U]":
    """
    Return a ``SlicedAsyncDataset`` view of this dataset.

    Both bounds are forwarded unchanged to the ``SlicedAsyncDataset`` constructor.
    """
    sliced = SlicedAsyncDataset(self, start_index=start_index, end_index=end_index)
    return sliced

def shuffle(self, key: PRNGKey):
import levanter.data.permutation as permutation

Expand Down Expand Up @@ -375,6 +380,53 @@ def _call_fn(self, index, item):
return self.fn(item, *self._extra_args, **kwargs)


class SlicedAsyncDataset(AsyncDataset[U]):
    """
    A view onto the half-open index range ``[start_index, end_index)`` of another dataset.

    Index ``i`` of the slice corresponds to index ``start_index + i`` of the wrapped dataset.
    With ``end_index=None`` the slice is open-ended and its length follows the wrapped dataset's.
    """

    def __init__(
        self,
        dataset: AsyncDataset[U],
        start_index: Optional[int] = None,
        end_index: Optional[int] = None,
    ):
        """
        Args:
            dataset: The dataset to slice.
            start_index: First index of ``dataset`` included in the slice. ``None`` means 0.
            end_index: Index one past the last included item. ``None`` means no upper bound.

        Raises:
            ValueError: If ``start_index`` is greater than ``end_index``.
        """
        # Run base-class setup first so it cannot overwrite the attributes assigned below.
        super().__init__()
        if start_index is None:
            start_index = 0
        if end_index is not None and start_index > end_index:
            raise ValueError("End index must come after start index.")
        self.start_index = start_index
        self.end_index = end_index
        self.dataset = dataset
        # NOTE(review): presumably the base class treats `_min_known_len` as a lower bound on
        # the dataset length -- confirm against AsyncDataset.
        if end_index is None:
            # Open-ended slice: whatever the wrapped dataset guarantees, minus the skipped prefix.
            self._min_known_len = max(dataset._min_known_len - start_index, 0)
        else:
            # Bounded slice: the length is fixed by the bounds.
            # Assumes the wrapped dataset holds at least `end_index` items -- TODO confirm.
            self._min_known_len = end_index - start_index

    async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
        """
        Fetch the items at ``indices``, interpreted relative to the start of the slice.

        Raises:
            ValueError: If any requested index falls at or beyond ``end_index``.
        """
        shifted_indices = [index + self.start_index for index in indices]

        # `end_index` is exclusive, so a shifted index equal to it is already out of range.
        # The emptiness check keeps max() from raising on an empty request.
        if shifted_indices and self.end_index is not None and max(shifted_indices) >= self.end_index:
            raise ValueError("Requested indices beyond the end of the dataset")

        # Read from the wrapped dataset at the shifted positions, not the slice-relative ones.
        return await self.dataset.get_batch(shifted_indices)

    async def async_len(self) -> int:
        """Return the length of the slice."""
        # Awaited even for a bounded slice, so the call still waits on the wrapped dataset.
        underlying_length = await self.dataset.async_len()
        if self.end_index is None:
            # Never report a negative length when the start lies past the end of the data.
            return max(underlying_length - self.start_index, 0)
        return self.end_index - self.start_index

    async def final_length_is_known(self) -> bool:
        """Return whether the wrapped dataset's final length (and hence the slice's) is known."""
        # An open-ended slice has a known final length as soon as the wrapped dataset does.
        return await self.dataset.final_length_is_known()

    def is_finite(self) -> bool:
        """Return whether the slice has a finite number of items."""
        # An explicit end bound makes the slice finite on its own; otherwise it is exactly as
        # finite as the wrapped dataset.
        return self.end_index is not None or self.dataset.is_finite()

    async def current_len(self) -> Optional[int]:
        """
        Return the current length of the slice.

        Returns ``None`` for an open-ended slice whose wrapped dataset reports no length yet.
        """
        underlying_length = await self.dataset.current_len()
        if self.end_index is not None:
            return self.end_index - self.start_index
        if underlying_length is None:
            return None
        return max(underlying_length - self.start_index, 0)


class BatchMappedAsyncDataset(AsyncDataset[U]):
"""
A dataset that applies a function to each batch of items in the dataset.
Expand Down Expand Up @@ -408,6 +460,9 @@ def is_finite(self) -> bool:
async def current_len(self) -> Optional[int]:
    """Report the wrapped dataset's current length, passed through unchanged."""
    wrapped_len = await self.dataset.current_len()
    return wrapped_len

async def wait_until_len_at_least(self, length: int) -> int:
    """Forward the request to the wrapped dataset and return whatever length it reports."""
    reached_len = await self.dataset.wait_until_len_at_least(length)
    return reached_len

def _maybe_fold_in_key(self, key, indices: Sequence[int]):
if key is not None:
key = _fold_in_key_vmap(key, np.array(indices))
Expand Down
12 changes: 1 addition & 11 deletions src/levanter/data/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(
randomize_blocks: bool = True,
key: PRNGKeyArray | int,
stop_strategy: str = StopStrategy.RESTART_STRATEGY,
simulated_data_ratio: float = 1,
):
super().__init__()
if isinstance(weights, dict):
Expand Down Expand Up @@ -100,9 +99,6 @@ def __init__(
raise NotImplementedError("Only restart strategy is supported for now.")

self.stop_strategy = stop_strategy
if simulated_data_ratio > 1:
raise ValueError(f"Simulated data ratio must be at most 1, got {simulated_data_ratio}")
self.simulated_data_ratio = simulated_data_ratio

# Initialize stage-related counts and IDs
(
Expand Down Expand Up @@ -279,13 +275,7 @@ async def _remap_indices(self, ds, indices_into_ds):
if self.stop_strategy == StopStrategy.RESTART_STRATEGY:
if ds.is_finite():
max_elem = max(indices_into_ds)
# Remap Indices Earlier when simulating epoching for a larger budget
if self.simulated_data_ratio < 1:
# Note(Will): This blocks on datasets being fully processed even for small simulated runs making simulating data size slightly latency inducing but I think that's ok
true_length_of_dataset = await ds.async_len()
length_of_dataset = int(true_length_of_dataset * self.simulated_data_ratio)
else:
length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1)
length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1)
indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds]

return indices_into_ds
Expand Down
13 changes: 10 additions & 3 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,18 +1231,25 @@ def shuffle_ds(ds, key):
out_token_datasets[name] = shuffle_ds(ds, next(key_iter))
token_datasets = out_token_datasets

if self.experiment_budget > self.target_budget:
raise ValueError(
f"Experiment budget should be smaller than target budget, got {self.experiment_budget} >"
f" {self.target_budget}"
)
if self.experiment_budget is not None and self.target_budget is not None:
simulated_data_ratio = self.experiment_budget / self.target_budget
else:
simulated_data_ratio = 1
for name, ds in self.datasets:
# Note(Will): This blocks until each dataset is fully processed, even for small simulated runs, so simulating a smaller data size adds some latency, but I think that's ok.
true_length_of_dataset = len(ds.as_sync_dataset())
simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio)
token_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset)

mixture = MixtureDataset(
datasets=token_datasets,
weights=self.train_weights,
stop_strategy=self.stop_strategy,
key=mix_key,
block_size=self.mixture_block_size,
simulated_data_ratio=simulated_data_ratio,
)

return mixture
Expand Down
6 changes: 2 additions & 4 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,25 @@ async def test_mixture_dataset_stop_strategy_restart():
async def test_mixture_dataset_simulated_data_size():
weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3}
mixture_ds = MixtureDataset(
datasets(),
{name: dataset.slice_dataset(end_index=1) for name, dataset in datasets().items()},
weights,
block_size=10,
key=key(),
randomize_blocks=False,
stop_strategy=StopStrategy.RESTART_STRATEGY,
simulated_data_ratio=0.2,
)
for _ in range(10):
batch = await mixture_ds.get_batch([0, 1, 2])
assert len(batch) == 3
assert all(item in [1, 10, 100] for item in batch)

mixture_ds = MixtureDataset(
datasets(),
{name: dataset.slice_dataset(end_index=2) for name, dataset in datasets().items()},
weights,
block_size=10,
key=key(),
randomize_blocks=False,
stop_strategy=StopStrategy.RESTART_STRATEGY,
simulated_data_ratio=0.4,
)
for _ in range(10):
batch = await mixture_ds.get_batch([0, 1, 2])
Expand Down

0 comments on commit efdc13a

Please sign in to comment.