Decouple prepare_data_loader() from Accelerator (#3047)
siddk authored Aug 26, 2024
1 parent 726140c commit 2789933
Showing 2 changed files with 126 additions and 17 deletions.
13 changes: 6 additions & 7 deletions src/accelerate/data_loader.py
@@ -20,7 +20,7 @@
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler

from .logging import get_logger
from .state import AcceleratorState, DistributedType, GradientState, PartialState, is_torch_xla_available
from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
from .utils import (
RNGType,
broadcast,
@@ -720,7 +720,7 @@ def __init__(
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)

self.gradient_state = GradientState()
self.state = AcceleratorState()
self.state = PartialState()
self._drop_last = _drop_last
self._non_blocking = _non_blocking
self.skip_batches = skip_batches
@@ -937,10 +937,9 @@ def prepare_data_loader(
device (`torch.device`):
The target device for the returned `DataLoader`.
num_processes (`int`, *optional*):
The number of processes running concurrently. Will default to the value given by
[`~state.AcceleratorState`].
The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
process_index (`int`, *optional*):
The index of the current process. Will default to the value given by [`~state.AcceleratorState`].
The index of the current process. Will default to the value given by [`~state.PartialState`].
split_batches (`bool`, *optional*, defaults to `False`):
Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
@@ -1009,8 +1008,8 @@ def prepare_data_loader(

if dispatch_batches and not put_on_device:
raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
# Grab defaults from AcceleratorState
state = AcceleratorState()
# Grab defaults from PartialState
state = PartialState()
if num_processes is None:
num_processes = state.num_processes
if process_index is None:
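The net effect of the data_loader.py changes above is that `prepare_data_loader` no longer requires an `Accelerator` (or a fully initialized `AcceleratorState`); its process and device defaults now come from `PartialState`. A minimal usage sketch, not part of the commit, assuming a single-process run (the dataset and variable names are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import PartialState
from accelerate.data_loader import prepare_data_loader

# Lightweight process/device state; no Accelerator is constructed anywhere.
state = PartialState()

dataset = TensorDataset(torch.arange(64, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=4)

# num_processes and process_index now default to the values held by PartialState.
prepared = prepare_data_loader(dataloader, device=state.device, put_on_device=True)
for (batch,) in prepared:
    assert batch.device.type == state.device.type
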
130 changes: 120 additions & 10 deletions tests/test_data_loader.py
@@ -20,7 +20,7 @@
from parameterized import parameterized
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

from accelerate import Accelerator
from accelerate import Accelerator, PartialState
from accelerate.data_loader import (
BatchSamplerShard,
DataLoaderDispatcher,
@@ -29,11 +29,12 @@
IterableDatasetShard,
SkipBatchSampler,
SkipDataLoader,
prepare_data_loader,
skip_first_batches,
)
from accelerate.state import GradientState
from accelerate.test_utils.testing import require_torchdata_stateful_dataloader
from accelerate.utils import is_torchdata_stateful_dataloader_available
from accelerate.utils.dataclasses import DataLoaderConfiguration


if is_torchdata_stateful_dataloader_available():
@@ -401,9 +402,8 @@ def test_iterable_dataset_shard(self):

def test_iterable_dataset_using_none_batch_size(self):
dataset = SimpleIterableDataset(100)
accelerator = Accelerator()
dataloader = DataLoader(dataset, batch_size=None)
dataloader = accelerator.prepare(dataloader)
dataloader = prepare_data_loader(dataloader)
for d in dataloader:
assert isinstance(d, torch.Tensor)

@@ -417,7 +417,6 @@ def test_dataloader_inheritance(self):
`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter
are instances of DataLoader and DataLoaderStateMixin.
"""
Accelerator()
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
dl_shard = DataLoaderShard(range(16), batch_size=4)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
@@ -454,7 +453,6 @@ def test_end_of_dataloader(self):
assert dataloader.end_of_dataloader == (idx == 3)

def test_end_of_dataloader_dispatcher(self):
Accelerator()
dataloader = DataLoaderDispatcher(range(16), batch_size=4)
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
@@ -492,7 +490,6 @@ def test_end_of_dataloader(self):

@require_torchdata_stateful_dataloader
def test_end_of_dataloader_dispatcher(self):
Accelerator()
dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
assert isinstance(dataloader, StatefulDataLoader)
for idx, _ in enumerate(dataloader):
@@ -535,8 +532,6 @@ def test_dataloader_dispatcher_state_dict(self, num_workers):
"""
Test that saving a stateful dataloader's state, then loading it back, gives the same results.
"""
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
Accelerator(dataloader_config=dataloader_config)
dataset = list(range(16))
dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)

@@ -565,7 +560,6 @@ def test_dataloader_inheritance(self):
`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True,
subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.
"""
Accelerator()
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
@@ -689,3 +683,119 @@ def get_all_batches(dl, device):
assert expected_batch_results[1] == dl_results[1]

assert accelerator.gradient_state.active_dataloader is None

@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
@require_torchdata_stateful_dataloader
def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):
"""
Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce
the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader` when *not* using
Accelerator (and instead using the decoupled `PartialState` workflow).
"""
dataset = list(range(64))

# Set the seed for reproducibility
def g():
return torch.Generator().manual_seed(42)

state = PartialState()
stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
skip_dl = SkipDataLoader(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dl_shard = DataLoaderShard(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dl_dispatcher = DataLoaderDispatcher(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)

dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]

num_batches_to_skip = 8

def get_first_n_batches(dl, n, device):
"""
Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.
"""
batches = []
for idx, batch in enumerate(dl):
if idx == n - 1:
if hasattr(dl, "end"):
dl.end()
break
batches.append(batch.to(device))
return batches

# Iterate over all of the dataloaders identically, expect the same values
expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, state.device)
batches_from_dataloaders = [
get_first_n_batches(dl, num_batches_to_skip, state.device) for dl in dataloaders_under_test
]

for dl_batches in batches_from_dataloaders:
for expected, actual in zip(expected_batches, dl_batches):
assert torch.allclose(expected, actual)

# The adapters should all produce the same state_dict as the reference stateful dataloader
expected_state_dict = stateful_dl.state_dict()
skip_dl_state_dict = skip_dl.state_dict()
dl_shard_state_dict = dl_shard.state_dict()
dl_dispatcher_state_dict = dl_dispatcher.state_dict()

assert expected_state_dict == skip_dl_state_dict
assert expected_state_dict == dl_shard_state_dict
assert expected_state_dict == dl_dispatcher_state_dict

# Load the state dict into new dataloaders
manual_skip_dl = SkipDataLoader(
dataset,
batch_size=4,
num_workers=num_workers,
generator=g(),
skip_batches=num_batches_to_skip,
use_stateful_dataloader=True,
)
loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
loaded_stateful_dl.load_state_dict(expected_state_dict)
loaded_skip_dl = SkipDataLoader(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_skip_dl.load_state_dict(expected_state_dict)
loaded_dl_shard = DataLoaderShard(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_dl_shard.load_state_dict(expected_state_dict)
loaded_dl_dispatcher = DataLoaderDispatcher(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_dl_dispatcher.load_state_dict(expected_state_dict)

# Continue the iteration, expecting identical behavior across the board
def get_all_batches(dl, device):
"""
Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
"""
batches = []
num_batches_yielded = 0
for batch in dl:
batches.append(batch.to(device))
num_batches_yielded += 1
return (batches, num_batches_yielded)

expected_batch_results = get_all_batches(loaded_stateful_dl, state.device)
dataloader_batch_results = [
get_all_batches(dl, state.device)
for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
]
for dl_results in dataloader_batch_results:
for expected, actual in zip(expected_batch_results[0], dl_results[0]):
assert torch.allclose(expected, actual)
assert expected_batch_results[1] == dl_results[1]

# Using the decoupled (`PartialState`) workflow, GradientState should be automatically initialized (with
# default parameters) by `DataLoaderDispatcher`
assert GradientState._shared_state != {}, "GradientState should already be initialized!"

gradient_state = GradientState()
assert gradient_state.active_dataloader is None
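
One detail behind the final assertions of the new test: `GradientState` shares its state across instances, so simply constructing a `DataLoaderDispatcher` is enough to initialize it, even though no `Accelerator` is ever created. A small illustrative sketch, not part of the commit:

from accelerate.data_loader import DataLoaderDispatcher
from accelerate.state import GradientState

# Constructing a DataLoaderAdapter subclass initializes the shared GradientState.
_ = DataLoaderDispatcher(range(16), batch_size=4)

assert GradientState._shared_state != {}  # shared state is now populated
assert GradientState().active_dataloader is None  # nothing is being iterated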
