[shardformer] support interleaved pipeline (hpcaitech#4448)
* support interleaved pipeline

* fix unit test

* remove virtual stage test in stage mgr

* add dropped type hint and updated bwd
Gy-Lu authored Aug 16, 2023
1 parent 26e29d5 commit a78daf6
Showing 7 changed files with 642 additions and 109 deletions.
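
For context, an interleaved pipeline schedule gives each pipeline rank several non-contiguous model chunks (virtual stages) instead of one contiguous block, so ranks alternate between chunks and the pipeline bubble shrinks. The sketch below only illustrates that placement idea; the function name interleaved_layout and the round-robin split are assumptions for illustration, not code from this commit.

def interleaved_layout(num_layers: int, num_stages: int, num_chunks: int) -> dict:
    """Assign contiguous blocks of layers round-robin to (stage, chunk) pairs."""
    assert num_layers % (num_stages * num_chunks) == 0, "layers must split evenly"
    block = num_layers // (num_stages * num_chunks)
    layout = {}
    for b in range(num_stages * num_chunks):
        stage, chunk = b % num_stages, b // num_stages
        layout[(stage, chunk)] = list(range(b * block, (b + 1) * block))
    return layout

# With 8 layers, 2 pipeline stages and 2 chunks per stage:
# {(0, 0): [0, 1], (1, 0): [2, 3], (0, 1): [4, 5], (1, 1): [6, 7]}
print(interleaved_layout(8, 2, 2))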
10 changes: 8 additions & 2 deletions colossalai/cluster/process_group_mesh.py
@@ -94,17 +94,23 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
         return np.unravel_index(rank, shape)
 
     @staticmethod
-    def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
+    def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
         """Convert a coordinate to a rank.
+           mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
+           with wrap, index out of range would be wrapped around.
+           For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2)
 
         Args:
             coords (Tuple[int, ...]): Coordinate to be converted.
             shape (Tuple[int, ...]): Shape of the process group mesh.
+            mode (Optional[str]): The mode for numpy.ravel_multi_index.
 
         Returns:
             int: Rank of the coordinate.
         """
-        return np.ravel_multi_index(coord, shape)
+
+        assert mode in ["raise", "wrap", "clip"]
+        return np.ravel_multi_index(coord, shape, mode)
 
     def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
         """Get the process group with the given ranks. It the process group doesn't exist, it will be created.
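
The new mode argument is passed straight through to numpy.ravel_multi_index. With 'wrap', a coordinate that runs past the end of an axis is reduced modulo that axis length instead of raising, which is presumably what lets stage coordinates wrap around the pipeline dimension in a ring. A quick standalone NumPy check of that behaviour (plain NumPy, independent of the ravel helper shown above):

import numpy as np

shape = (1, 2, 1)
# The default mode='raise' rejects the out-of-range coordinate (0, 3, 0);
# mode='wrap' folds it back into range: the result is 3 % 2 == 1.
print(np.ravel_multi_index((0, 3, 0), shape, mode='wrap'))  # 1
print(np.ravel_multi_index((0, 1, 0), shape))               # 1, in-range coordinates behave the same in every mode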
45 changes: 17 additions & 28 deletions colossalai/pipeline/p2p.py
@@ -173,14 +173,10 @@ def recv_forward(self, prev_rank: int = None) -> Any:
         Returns:
             Any: The input tensor or input tensor list.
         """
-        if self.stage_manager.is_first_stage():
-            input_tensor = None
-        else:
-            if prev_rank is None:
-                prev_rank = self.stage_manager.get_prev_rank()
-            cur_rank = self.stage_manager.get_rank()
-            input_tensor = _recv_object(prev_rank, cur_rank,
-                                        self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
+        if prev_rank is None:
+            prev_rank = self.stage_manager.get_prev_rank()
+        cur_rank = self.stage_manager.get_rank()
+        input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
 
         return input_tensor
 
@@ -193,14 +189,11 @@ def recv_backward(self, next_rank: int = None) -> Any:
         Returns:
             Any: The input gradient tensor or gradient tensor list.
         """
-        if self.stage_manager.is_last_stage():
-            output_tensor_grad = None
-        else:
-            if next_rank is None:
-                next_rank = self.stage_manager.get_next_rank()
-            cur_rank = self.stage_manager.get_rank()
-            output_tensor_grad = _recv_object(next_rank, cur_rank,
-                                              self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
+        cur_rank = self.stage_manager.get_rank()
+        output_tensor_grad = _recv_object(next_rank, cur_rank,
+                                          self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
 
         return output_tensor_grad
 
@@ -211,12 +204,10 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None:
             output_object (Any): Object to be sent.
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
-        if not self.stage_manager.is_last_stage():
-            if next_rank is None:
-                next_rank = self.stage_manager.get_next_rank()
-            cur_rank = self.stage_manager.get_rank()
-            _send_object(output_object, cur_rank, next_rank,
-                         self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
+        cur_rank = self.stage_manager.get_rank()
+        _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
 
     def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
 
@@ -225,9 +216,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
             input_object (Any): Object to be sent.
             prev_rank (int, optional): The rank of the recipient of the tensor
         """
-        if not self.stage_manager.is_first_stage():
-            if prev_rank is None:
-                prev_rank = self.stage_manager.get_prev_rank()
-            cur_rank = self.stage_manager.get_rank()
-            _send_object(input_object, cur_rank, prev_rank,
-                         self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
+        if prev_rank is None:
+            prev_rank = self.stage_manager.get_prev_rank()
+        cur_rank = self.stage_manager.get_rank()
+        _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
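
With the first/last-stage short-circuits removed, these primitives now always perform their send or receive, and the schedule that drives them decides when each call is appropriate; under an interleaved schedule that decision depends on the virtual (chunk-level) stage rather than the physical rank, so the check no longer belongs inside the p2p helper. A hedged sketch of a caller-side guard in a plain one-chunk-per-rank schedule (the comm and stage_manager objects and the surrounding control flow are assumptions for illustration, not code from this commit):

# Hypothetical schedule-side logic replacing the check that used to live in recv_forward().
if stage_manager.is_first_stage():
    input_tensor = None  # the first stage has no upstream activation to receive
else:
    input_tensor = comm.recv_forward()  # always performs the receive now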