[hotfix] fix unsafe calculation in zero #4404

Merged
merged 3 commits on Aug 11, 2023

55 changes: 36 additions & 19 deletions colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -13,15 +13,20 @@ class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)

# init and reset
# init
self.current_group_id = 0
self._num_elements_in_bucket = 0
# mapping between gradient slices and parameters
self.grad_to_param_mapping = dict()

self._grad_in_bucket = dict()
self._param_list = []
self._padding_size = []
for rank in range(self._world_size):
self._grad_in_bucket[rank] = []

self.reset()
# offset_list records the number of tensors in the bucket before each reduction
self.offset_list = [0]

def num_elements_in_bucket(self) -> int:
"""Return the total number of elements in bucket
@@ -32,6 +37,12 @@ def num_elements_in_bucket(self) -> int:

return self._num_elements_in_bucket

def reset_num_elements_in_bucket(self):
"""Set the number of elements in bucket to zero.
"""

self._num_elements_in_bucket = 0

def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
"""Add a param to bucket and record the padding size of a param for gradient padding

@@ -46,28 +57,32 @@ def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
self._num_elements_in_bucket += (param.numel() + padding_size)
self.current_group_id = group_id

# number of tensors in current bucket
self.offset_list[-1] += 1

def build_grad_in_bucket(self):
"""Orgnize parameters' gradient(padding and split), follows the paramters' splitting method

Data structure of self._grad_in_bucket:
{
rank0: [grad0_rank0, grad1_rank0, ...]
rank1: [grad1_rank1, grad1_rank1, ...]
rank1: [grad0_rank1, grad1_rank1, ...]
}
"""

for param, padding_size in zip(self._param_list, self._padding_size):
with torch.no_grad():
grad = param.grad.detach().flatten()
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
grad_list = grad.split(grad.numel() // self._world_size)
for rank in range(self._world_size):
grad_current_rank = grad_list[rank].detach()
self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
self._grad_in_bucket[rank].append(grad_current_rank)
grad = param.grad.clone().detach().flatten()
if padding_size > 0:
with torch.no_grad():
grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
grad_list = grad.split(grad.numel() // self._world_size)
for rank in range(self._world_size):
grad_current_rank = grad_list[rank].clone().detach()
self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
self._grad_in_bucket[rank].append(grad_current_rank)
param.grad = None

self.offset_list.append(0)

def get_grad(self) -> Dict:
"""Return the dictionary of gradients slices, of which the keys are ranks

@@ -104,10 +119,12 @@ def get_param_id_of_grad(self, grad: Tensor) -> int:
return self.grad_to_param_mapping[id(grad)]

def reset(self):
self.grad_to_param_mapping = dict()
self._num_elements_in_bucket = 0
self._param_list = []
self._padding_size = []
self._grad_in_bucket = dict()
"""Reset the bucket storage after reduction, only release the tensors have been reduced
"""
cur_offset = self.offset_list.pop(0)
self._param_list = self._param_list[cur_offset:]
self._padding_size = self._padding_size[cur_offset:]
for _ in range(cur_offset):
del self.grad_to_param_mapping[next(iter(self.grad_to_param_mapping))]
for rank in range(self._world_size):
self._grad_in_bucket[rank] = []
self._grad_in_bucket[rank] = self._grad_in_bucket[rank][cur_offset:]
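
To make the new reset() behavior easier to follow, here is a minimal standalone sketch of the offset-list idea (hypothetical class and names, not the actual BucketStore API): each pending reduction records how many tensors were added before it, so reset() can pop the oldest offset and release only the entries that reduction covered.

```python
# Minimal sketch of the offset-list bookkeeping (hypothetical, standalone example).
class OffsetBucket:
    def __init__(self):
        self.items = []          # gradient slices queued for reduction
        self.offset_list = [0]   # number of items added before each pending reduction

    def add(self, item):
        self.items.append(item)
        self.offset_list[-1] += 1

    def mark_reduction(self):
        # a reduction has been launched for everything added so far
        self.offset_list.append(0)

    def reset(self):
        # release only the items covered by the oldest launched reduction
        cur_offset = self.offset_list.pop(0)
        self.items = self.items[cur_offset:]


bucket = OffsetBucket()
bucket.add("grad_a")
bucket.add("grad_b")
bucket.mark_reduction()   # grad_a and grad_b are being reduced
bucket.add("grad_c")      # arrives while that reduction is still in flight
bucket.reset()            # drops only grad_a and grad_b
print(bucket.items)       # ['grad_c']
```

Gradients added while a reduction is still in flight therefore survive the next reset() instead of being cleared.
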
9 changes: 9 additions & 0 deletions colossalai/zero/low_level/low_level_optim.py
@@ -242,10 +242,19 @@ def _attach_reduction_hook(self):
def _run_reduction(self):
if self._bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket()

flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size

# ready to add other tensors to bucket
self._bucket_store.reset_num_elements_in_bucket()

if self._overlap_communication:
stream = self._comm_stream
# in case of the memory being reused in the default stream
flat_grads.record_stream(stream)
# wait for ops in the default stream to finish
stream.wait_stream(torch.cuda.current_stream())
else:
stream = torch.cuda.current_stream()
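
The record_stream and wait_stream calls above follow the usual PyTorch idiom for handing a tensor to a side stream. A generic sketch of that idiom (not the optimizer's actual reduction code; it assumes a CUDA device and uses an arbitrary in-place op as a stand-in for the real communication):

```python
# Generic sketch of handing a tensor to a side stream safely (assumes CUDA).
import torch

def run_on_side_stream(flat_grads: torch.Tensor, comm_stream: torch.cuda.Stream) -> torch.Tensor:
    # Tell the caching allocator the tensor is also used on comm_stream,
    # so the default stream does not reuse its memory too early.
    flat_grads.record_stream(comm_stream)
    # Make the side stream wait until the default stream has finished producing flat_grads.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        # real code would launch communication here (e.g. an all-reduce);
        # this in-place op is just a placeholder
        flat_grads.mul_(0.5)
    return flat_grads

if torch.cuda.is_available():
    grads = torch.ones(1024, device="cuda")
    side = torch.cuda.Stream()
    run_on_side_stream(grads, side)
    torch.cuda.current_stream().wait_stream(side)  # sync before reading the result
    print(grads[:4])
```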

2 changes: 1 addition & 1 deletion tests/test_zero/test_low_level/test_zero1_2.py
@@ -137,7 +137,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=True,
initial_scale=1,
reduce_bucket_size=262144)
reduce_bucket_size=1024 * 1024)

torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
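
For context, reduce_bucket_size caps how large the gradient bucket may grow before a reduction is flushed; judging from num_elements_in_bucket() above it is counted in elements, and 1024 * 1024 is 1,048,576 elements versus the previous 262,144. A simplified sketch of that kind of threshold check (illustrative only, hypothetical names, not the optimizer's exact logic):

```python
# Simplified sketch of a bucket-size trigger (illustrative; names are hypothetical).
def should_flush(bucket_numel: int, incoming_numel: int, reduce_bucket_size: int) -> bool:
    """Return True when adding the incoming gradient would overflow the bucket."""
    return bucket_numel + incoming_numel > reduce_bucket_size

print(should_flush(900_000, 200_000, 1024 * 1024))  # True:  flush and reduce first
print(should_flush(100_000, 200_000, 1024 * 1024))  # False: keep filling the bucket
```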
