Skip to content

Commit

Permalink
[shardformer/sequence parallel] Overlap input gather and grad computa…
Browse files Browse the repository at this point in the history
…tion during col backward (#4401)

* overlap gather input / grad computing during col backward

* modify test for overlap

* simplify code

* fix code and modify cuda stream synchronize
  • Loading branch information
FoolPlayer authored Aug 11, 2023
1 parent 19b48bc commit ab0e294
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 56 deletions.
133 changes: 85 additions & 48 deletions colossalai/shardformer/layer/_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,18 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap

input_parallel = _gather(input_, dim, process_group)

Expand All @@ -175,42 +177,80 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

if not overlap:
# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)

total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_reduce_scatter:
handle.wait()

# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)

total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
new_shape = list(input_parallel.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_reduce_scatter:
handle.wait()

# print(output, output.shape)
return output, grad_weight, grad_bias, None, None, None
else:
# create new stream for calculate the gradient
calculate_stream = torch.cuda.Stream()

# do all gather in default stream
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

# calculate gradient in calculate_stream
with torch.cuda.stream(calculate_stream):
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None

# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()

torch.cuda.synchronize()

reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
with torch.cuda.stream(calculate_stream):
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
print(grad_output.shape, input_parallel.shape)
grad_weight = grad_output.t().matmul(input_parallel)

torch.cuda.synchronize()

return output, grad_weight, grad_bias, None, None, None, None


class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
Expand Down Expand Up @@ -294,14 +334,10 @@ def backward(ctx, grad_output):
# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
new_shape = list(input_parallel.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
Expand Down Expand Up @@ -442,7 +478,7 @@ def _gather(input_, dim=-1, process_group=None):
return output


def _reduce_scatter(intput_, dim=1, process_group=None):
def _reduce_scatter(input_, dim=1, process_group=None):
""" Do reduce-scatter operation.
Args:
Expand All @@ -452,15 +488,15 @@ def _reduce_scatter(intput_, dim=1, process_group=None):
"""
world_size = dist.get_world_size(process_group)
if world_size == 1:
return intput_
return input_

# reduce-scatter
new_shape = list(intput_.shape)
new_shape = list(input_.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // world_size
output = torch.empty(new_shape, dtype=intput_.dtype, device=intput_.device)
dist.reduce_scatter(output, intput_, group=process_group)
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_, group=process_group)

return output

Expand All @@ -473,9 +509,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim)
async_grad_reduce_scatter, dim, overlap)


def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
Expand Down
5 changes: 4 additions & 1 deletion colossalai/shardformer/layer/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Linear1D_Col(ParallelModule):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
Expand All @@ -73,6 +74,7 @@ def __init__(self,
process_group: ProcessGroup = None,
gather_output: bool = False,
seq_parallel: bool = False,
overlap: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
Expand All @@ -85,6 +87,7 @@ def __init__(self,
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
Expand Down Expand Up @@ -179,7 +182,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel:
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True, 1)
self.process_group, True, 1, self.overlap)
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

Expand Down
17 changes: 10 additions & 7 deletions tests/test_shardformer/test_layer/test_linear_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def check_linear_1d_col(lazy_init: bool, seq_parallel: bool):
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
seq_parallel=seq_parallel)
seq_parallel=seq_parallel,
overlap=overlap)

# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
Expand Down Expand Up @@ -111,7 +112,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)


def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):
def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()

linear_1 = nn.Linear(32, 128).cuda()
Expand All @@ -123,7 +124,8 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):
linear_col = Linear1D_Col.from_native_module(linear_1_copy,
process_group=None,
gather_output=False,
seq_parallel=seq_parallel)
seq_parallel=seq_parallel,
overlap=overlap)
linear_row = Linear1D_Row.from_native_module(linear_2_copy,
process_group=None,
parallel_input=True,
Expand Down Expand Up @@ -166,10 +168,11 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):

@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
def run_dist_linear_test(lazy_init, seq_parallel):
check_linear_1d_col(lazy_init, seq_parallel)
@parameterize('overlap', [False, True])
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
check_linear_1d_col(lazy_init, seq_parallel, overlap)
check_linear_1d_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel, overlap)


def check_dist_linear(rank, world_size, port):
Expand Down

0 comments on commit ab0e294

Please sign in to comment.