Skip to content

Commit

Permalink
Test
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Oct 11, 2023
1 parent f8933c8 commit 123510a
Showing 1 changed file with 45 additions and 28 deletions.
73 changes: 45 additions & 28 deletions src/accelerate/test_utils/scripts/external_deps/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from accelerate import Accelerator
from accelerate.data_loader import DataLoaderDispatcher
from accelerate.test_utils import RegressionDataset, RegressionModel
from accelerate.utils import is_tpu_available, set_seed
from accelerate.utils import set_seed


os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
Expand Down Expand Up @@ -219,6 +219,20 @@ def __iter__(self):
logger.removeHandler(list_handler)


def test_gather_for_metrics_drop_last():
    """Verify that `gather_for_metrics` honors a dataloader's `drop_last=True`.

    The dataset is one element longer than an even multiple of the total
    batch size, so the trailing sample must be discarded and every gathered
    batch must contain exactly `batch_size * num_processes` items.
    """
    accelerator = Accelerator()
    # One extra element beyond an even split across processes; with
    # drop_last=True the final, short batch must be dropped entirely.
    raw_loader = DataLoader(range((10 * accelerator.num_processes) + 1), batch_size=5, drop_last=True)
    prepared_loader = accelerator.prepare(raw_loader)

    batches = iter(prepared_loader)
    # Discard the first batch, e.g. tensor([0, 1, 2, 3, 4]) on process 0.
    next(batches)
    second_batch = next(batches)
    gathered = accelerator.gather_for_metrics(second_batch)
    assert gathered.size(0) == (
        5 * accelerator.num_processes
    ), f"Expected number of items: {(5*accelerator.num_processes)}, Actual: {gathered.size(0)}"


def main():
accelerator = Accelerator(split_batches=False, dispatch_batches=False)
if accelerator.is_local_main_process:
Expand All @@ -227,34 +241,37 @@ def main():
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# These are a bit slower so they should only be ran on the GPU or TPU
if torch.cuda.is_available() or is_tpu_available():
if accelerator.is_local_main_process:
print("**Testing gather_for_metrics**")
for split_batches in [True, False]:
for dispatch_batches in [True, False]:
if accelerator.is_local_main_process:
print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`")
test_mrpc(dispatch_batches, split_batches)
accelerator.state._reset_state()
print("test_gather_for_metrics_with_iterable_dataset")
test_gather_for_metrics_with_iterable_dataset()
print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset")
test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()
if accelerator.is_local_main_process:
print("**Test torch metrics**")
for split_batches in [True, False]:
for dispatch_batches in [True, False]:
accelerator = Accelerator(split_batches=split_batches, dispatch_batches=dispatch_batches)
if accelerator.is_local_main_process:
print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99")
test_torch_metrics(accelerator, 99)
accelerator.state._reset_state()
if accelerator.is_local_main_process:
print("**Test last batch is not dropped when perfectly divisible**")
# # These are a bit slower so they should only be ran on the GPU or TPU
# if torch.cuda.is_available() or is_tpu_available():
# if accelerator.is_local_main_process:
# print("**Testing gather_for_metrics**")
# for split_batches in [True, False]:
# for dispatch_batches in [True, False]:
# if accelerator.is_local_main_process:
# print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`")
# test_mrpc(dispatch_batches, split_batches)
# accelerator.state._reset_state()
# print("test_gather_for_metrics_with_iterable_dataset")
# test_gather_for_metrics_with_iterable_dataset()
# print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset")
# test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()
# if accelerator.is_local_main_process:
# print("**Test torch metrics**")
# for split_batches in [True, False]:
# for dispatch_batches in [True, False]:
# accelerator = Accelerator(split_batches=split_batches, dispatch_batches=dispatch_batches)
# if accelerator.is_local_main_process:
# print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99")
# test_torch_metrics(accelerator, 99)
# accelerator.state._reset_state()
# if accelerator.is_local_main_process:
# print("**Test last batch is not dropped when perfectly divisible**")
accelerator = Accelerator()
test_torch_metrics(accelerator, 512)
accelerator.state._reset_state()
# test_torch_metrics(accelerator, 512)
# accelerator.state._reset_state()
if accelerator.is_local_main_process:
print("**Test that `drop_last` is taken into account**")
test_gather_for_metrics_drop_last()


def _mp_fn(index):
Expand Down

0 comments on commit 123510a

Please sign in to comment.