
Commit

[WIP] Implementing gather_for_metrics with dedup for non tensor objects (#1937)

* [feat] implementing gather_for_metrics for objects

* [lint] make style result

* [docs] improve fn docs gather for metrics

Co-authored-by: Zach Mueller <[email protected]>

* [docs] update args description gather for metrics

Co-authored-by: Zach Mueller <[email protected]>

* [refactor] gather for metrics for non tensor obj

Co-authored-by: Zach Mueller <[email protected]>

* [fix] renaming tensor to data (was not defined and it is not just a tensor)

* [fix] else state

* [test] gather for metrics with non tensor objects

* [lint] make style result

* Update src/accelerate/accelerator.py

Co-authored-by: Zach Mueller <[email protected]>

* Update src/accelerate/accelerator.py

Co-authored-by: Zach Mueller <[email protected]>

* [test] removing useless assertion

Co-authored-by: Zach Mueller <[email protected]>

* [test] add running on main

* [lint] style autoformat

---------

Co-authored-by: Lorenzobattistela <[email protected]>
Co-authored-by: Zach Mueller <[email protected]>
3 people authored Sep 12, 2023
1 parent d9b5ce6 commit 5d558f2
Showing 2 changed files with 57 additions and 11 deletions.
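
The net effect of the change: `Accelerator.gather_for_metrics` now accepts non-tensor payloads (strings, dicts, or any picklable objects) in addition to tensors, routing them through `gather_object` instead of failing inside `gather`. A minimal usage sketch, assuming a script started with `accelerate launch` on two or more processes; the string payload is purely illustrative:

from accelerate import Accelerator

accelerator = Accelerator()

# Tensors still take the regular `gather` path and are deduplicated at the
# end of a prepared dataloader, exactly as before.
# logits = accelerator.gather_for_metrics(logits)

# Non-tensor data is now gathered too, one contribution per process.
predictions = [f"prediction from process {accelerator.process_index}"]
all_predictions = accelerator.gather_for_metrics(predictions)

if accelerator.is_main_process:
    print(all_predictions)  # e.g. two strings on a 2-process launch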
33 changes: 22 additions & 11 deletions src/accelerate/accelerator.py
@@ -68,6 +68,7 @@
     convert_outputs_to_fp32,
     extract_model_from_parallel,
     gather,
+    gather_object,
     get_mixed_precision_context_manager,
     get_pretty_name,
     has_transformer_engine_layers,
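
For reference, the newly imported `gather_object` (exported from `accelerate.utils`) pickles a Python object on every process and returns the contributions from all ranks; for lists it behaves like a concatenation across processes. A rough sketch of how it is typically used, hedged since the exact output ordering is an assumption:

from accelerate import Accelerator
from accelerate.utils import gather_object

accelerator = Accelerator()

# Each rank contributes its own list; every rank receives the combined list.
local_names = [f"rank {accelerator.process_index}"]
all_names = gather_object(local_names)
# On a 2-process launch this is expected to be ["rank 0", "rank 1"]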
@@ -2099,14 +2100,14 @@ def gather(self, tensor):
         """
         return gather(tensor)

-    def gather_for_metrics(self, tensor):
+    def gather_for_metrics(self, input_data):
         """
-        Gathers `tensor` and potentially drops duplicates in the last batch if on a distributed system. Should be used
-        for gathering the inputs and targets for metric calculation.
+        Gathers `input_data` and potentially drops duplicates in the last batch if on a distributed system. Should be
+        used for gathering the inputs and targets for metric calculation.

         Args:
-            tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
-                The tensors for calculating metrics across all processes.
+            input_data (`torch.Tensor`, `object`, a nested tuple/list/dictionary of `torch.Tensor`, or a nested tuple/list/dictionary of `object`):
+                The tensors or objects for calculating metrics across all processes.

         Example:

@@ -2124,7 +2125,17 @@ def gather_for_metrics(self, tensor):
         9
         ```
         """
-        tensor = self.gather(tensor)
+
+        try:
+            recursively_apply(lambda x: x, input_data, error_on_other_type=True)
+            all_tensors = True
+        except TypeError:
+            all_tensors = False
+
+        if not all_tensors:
+            data = gather_object(input_data)
+        else:
+            data = self.gather(input_data)

         try:
             if self.gradient_state.end_of_dataloader:
@@ -2134,22 +2145,22 @@
                     logger.info(
                         "The used dataset had no length, returning gathered tensors. You should drop the remainder yourself."
                     )
-                    return tensor
+                    return data
                 elif self.gradient_state.remainder > 0:
                     # Last batch needs to be truncated on distributed systems as it contains additional samples
                     def _adjust_samples(tensor):
                         return tensor[: self.gradient_state.remainder]

-                    return recursively_apply(_adjust_samples, tensor)
+                    return recursively_apply(_adjust_samples, data)
                 else:  # remainder is 0
                     # no remainder even though at end of dataloader, so nothing to do.
-                    return tensor
+                    return data
             else:
                 # Not at the end of the dataloader, no need to adjust the tensors
-                return tensor
+                return data
         except Exception:
             # Dataset had no length or raised an error
-            return tensor
+            return data

     def reduce(self, tensor, reduction="sum", scale=1.0):
         """
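The deduplication itself is unchanged: at the end of a prepared dataloader, `GradientState.remainder` records how many samples of the final gathered batch are real (the rest are padding added so every process sees the same number of batches), and `_adjust_samples` truncates to that count. A worked sketch with hypothetical numbers (25 samples, 2 processes, batch size 4, so the last joint batch holds 8 slots but only 1 real sample):

import torch

remainder = 25 % (2 * 4)  # -> 1 real sample in the final gathered batch

gathered_last_batch = torch.arange(8)  # 8 gathered values, 7 of them padding


def _adjust_samples(tensor):
    # Same truncation gather_for_metrics applies when end_of_dataloader is True.
    return tensor[:remainder]


print(_adjust_samples(gathered_last_batch))  # tensor([0])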
35 changes: 35 additions & 0 deletions src/accelerate/test_utils/scripts/external_deps/test_metrics.py
@@ -149,6 +149,39 @@ def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):
         ), f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n"


+def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset():
+    class DummyIterableDataset(IterableDataset):
+        def __init__(self, data):
+            self.data = data
+
+        def __len__(self):
+            return len(self.data)
+
+        def __iter__(self):
+            for element in self.data:
+                yield element
+
+    iterable_dataset = DummyIterableDataset([n for n in range(30)])
+    dataloader = DataLoader(iterable_dataset, batch_size=4)
+    accelerator = Accelerator()
+    prepared_dataloader = accelerator.prepare(dataloader)
+
+    if accelerator.is_main_process:
+        logger = logging.root.manager.loggerDict["accelerate.accelerator"]
+        list_handler = ListHandler()
+        logger.addHandler(list_handler)
+
+    batches_for_metrics = []
+    for batch in prepared_dataloader:
+        batches_for_metrics.append(accelerator.gather_for_metrics(batch))
+
+    assert torch.cat(batches_for_metrics).size(0) == 30
+
+    if accelerator.is_main_process:
+        assert len(list_handler.logs) == 0
+        logger.removeHandler(list_handler)
+
+
 def test_gather_for_metrics_with_iterable_dataset():
     class DummyIterableDataset(IterableDataset):
         def __init__(self, data):
@@ -206,6 +239,8 @@ def main():
     accelerator.state._reset_state()
     print("test_gather_for_metrics_with_iterable_dataset")
     test_gather_for_metrics_with_iterable_dataset()
+    print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset")
+    test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()
     if accelerator.is_local_main_process:
         print("**Test torch metrics**")
         for split_batches in [True, False]:
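The new test attaches a `ListHandler` to the `accelerate.accelerator` logger and asserts that no records were captured: since `DummyIterableDataset` defines `__len__`, the remainder can be computed and the "dataset had no length" message should never be logged. `ListHandler` is defined elsewhere in `test_metrics.py` and is not part of this diff; a minimal sketch of what such a handler plausibly looks like, offered as an assumption rather than the actual implementation:

import logging


class ListHandler(logging.Handler):
    # Collects every emitted record so a test can assert on what was logged.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logs = []

    def emit(self, record):
        self.logs.append(record)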
