Persist IterableDataset epoch in workers (#6710)
* persist IterableDataset epoch in workers

* more tests

* comment

* re-share memory after pickling

* Update src/datasets/iterable_dataset.py
lhoestq authored Jul 1, 2024
1 parent 100361d commit 4ba47a3
Showing 2 changed files with 49 additions and 6 deletions.
33 changes: 27 additions & 6 deletions src/datasets/iterable_dataset.py
@@ -7,7 +7,7 @@
from dataclasses import dataclass
from functools import partial
from itertools import cycle, islice
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

import fsspec.asyn
import numpy as np
@@ -26,6 +26,9 @@
from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs


if TYPE_CHECKING:
import torch

logger = get_logger(__name__)

Key = Union[int, str]
@@ -1690,6 +1693,18 @@ def _maybe_add_torch_iterable_dataset_parent_class(cls):
cls.__bases__ += (torch.utils.data.IterableDataset,)


def _maybe_share_with_torch_persistent_workers(value: Union[int, "torch.Tensor"]) -> Union[int, "torch.Tensor"]:
if config.TORCH_AVAILABLE:
import torch

if isinstance(value, torch.Tensor):
return value.share_memory_()
else:
return torch.tensor(value).share_memory_()
else:
return value


class IterableDataset(DatasetInfoMixin):
"""A Dataset backed by an iterable."""

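The helper above is the core of the change: storing the epoch as a tensor in shared memory means an in-place update made in the main process is visible to DataLoader workers that already hold a reference to the same storage. A minimal standalone sketch of that behaviour (not part of the commit; the `worker` function and the `Event`-based hand-off are purely illustrative):

```python
import torch
import torch.multiprocessing as mp


def worker(epoch_tensor, bumped, out):
    bumped.wait()                # block until the parent has updated the epoch
    out.put(int(epoch_tensor))   # read the current value from shared memory


if __name__ == "__main__":
    epoch = torch.tensor(0).share_memory_()  # same call the helper makes
    bumped, out = mp.Event(), mp.Queue()
    p = mp.Process(target=worker, args=(epoch, bumped, out))
    p.start()
    epoch += 1                   # in-place write into the shared storage
    bumped.set()
    print(out.get())             # prints 1: the already-running worker sees the update
    p.join()
```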
@@ -1722,8 +1737,8 @@ def __init__(
self._formatting = formatting
self._shuffling = shuffling
self._distributed = distributed
self._epoch = 0
self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
self._starting_state_dict: Optional[dict] = None
self._prepared_ex_iterable = self._prepare_ex_iterable_for_iteration()
self._state_dict = self._prepared_ex_iterable._init_state_dict()
@@ -1841,18 +1856,24 @@ def __getstate__(self):

def __setstate__(self, d):
self.__dict__ = d
# Re-add torch shared memory, since shared memory is not always kept when pickling
self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch)
# Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
_maybe_add_torch_iterable_dataset_parent_class(self.__class__)

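The `__setstate__` hook re-shares the epoch because, with the standard `pickle` module, a shared tensor round-trips as an ordinary by-value copy, which is presumably why the comment says shared memory is "not always kept". A small sketch of that assumption:

```python
import pickle

import torch

t = torch.tensor(0).share_memory_()
assert t.is_shared()

t2 = pickle.loads(pickle.dumps(t))  # plain pickle copies the data by value
assert not t2.is_shared()           # sharing is lost on the copy...

t2 = t2.share_memory_()             # ...so it has to be re-established, as __setstate__ does
assert t2.is_shared()
```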
def _head(self, n=5):
return _examples_to_batch(list(self.take(n)))

@property
def epoch(self) -> int:
return int(self._epoch)

def _effective_generator(self):
if self._shuffling and self._epoch == 0:
if self._shuffling and self.epoch == 0:
return self._shuffling.generator
elif self._shuffling:
# Create effective seed using self._epoch (we subtract in order to avoid overflow in long_scalars)
effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self._epoch
# Create effective seed using self.epoch (we subtract in order to avoid overflow in long_scalars)
effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self.epoch
effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed
return np.random.default_rng(effective_seed)
else:
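For context, `_effective_generator` derives one deterministic RNG per epoch from the base shuffling generator, which is what makes the reshuffle both different across epochs and reproducible for a given seed. A self-contained sketch of that derivation (illustrative only; `base_generator` stands in for `self._shuffling.generator`, and the seed is cast to a Python int here for simplicity):

```python
from copy import deepcopy

import numpy as np

base_generator = np.random.default_rng(42)  # stands in for self._shuffling.generator


def effective_generator(epoch: int) -> np.random.Generator:
    if epoch == 0:
        return deepcopy(base_generator)
    # Subtract the epoch (rather than add it) so the value cannot exceed the
    # int64 upper bound, then wrap negative results back into [0, 2**63).
    effective_seed = int(deepcopy(base_generator).integers(0, 1 << 63)) - epoch
    if effective_seed < 0:
        effective_seed += 1 << 63
    return np.random.default_rng(effective_seed)


# Same seed and epoch give the same order; a new epoch gives a new order.
assert (effective_generator(1).permutation(8) == effective_generator(1).permutation(8)).all()
```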
@@ -2465,7 +2486,7 @@ def shuffle(
)

def set_epoch(self, epoch: int):
self._epoch = epoch
self._epoch += epoch - self._epoch # update torch value in shared memory in-place

def skip(self, n: int) -> "IterableDataset":
"""
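The `set_epoch` change is the subtle part: rebinding `self._epoch` to a fresh value would sever the link with workers that already hold the shared tensor, so the new epoch is written into the existing storage with an in-place `+=`. A tiny single-process illustration of the difference (`worker_view` merely stands in for the reference a persistent worker holds):

```python
import torch

epoch = torch.tensor(0).share_memory_()
worker_view = epoch              # the reference a persistent worker would hold

epoch += 3 - int(epoch)          # in-place update: the shared storage now holds 3
assert int(worker_view) == 3

epoch = torch.tensor(5)          # rebinding creates a brand-new tensor instead...
assert int(worker_view) == 3     # ...which the worker's reference never sees
```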
22 changes: 22 additions & 0 deletions tests/test_iterable_dataset.py
@@ -1641,6 +1641,28 @@ def test_iterable_dataset_is_torch_iterable_dataset(dataset: IterableDataset):
assert len(out) == DEFAULT_N_EXAMPLES


@require_torch
def test_iterable_dataset_persists_epoch_in_torch_workers(dataset: IterableDataset):
from torch.utils.data import DataLoader

dataset = dataset.shuffle(seed=42)
dataloader = DataLoader(dataset, num_workers=1, persistent_workers=True)
epoch0 = list(dataloader)
assert list(dataloader) == epoch0
dataset.set_epoch(1)
assert list(dataloader) != epoch0

# Make sure pickle works even with torch objects in shared memory
dataset_copy: IterableDataset = pickle.loads(pickle.dumps(dataset))
dataloader = DataLoader(dataset_copy, num_workers=1, persistent_workers=True)
epoch1 = list(dataloader)
assert list(dataloader) == epoch1
dataset.set_epoch(2) # this should not affect the copy
assert list(dataloader) == epoch1
dataset_copy.set_epoch(2)
assert list(dataloader) != epoch1


@pytest.mark.parametrize("n", [0, 2, int(1e10)])
def test_iterable_dataset_skip(dataset: IterableDataset, n):
skip_dataset = dataset.skip(n)
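In user code, the behaviour the new test exercises looks roughly like this (a usage sketch, not taken from the repository; the dataset name and loader settings are placeholders):

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset("some/streaming-dataset", split="train", streaming=True).shuffle(seed=42)
dataloader = DataLoader(ds, batch_size=32, num_workers=2, persistent_workers=True)

for epoch in range(3):
    ds.set_epoch(epoch)          # persistent workers pick this up through shared memory
    for batch in dataloader:
        ...                      # training step; the shuffle order differs per epoch
```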
