[feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore
duanjunwen committed Oct 30, 2024
1 parent 982e4ee commit d2e05a9
Showing 4 changed files with 352 additions and 4 deletions.
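For orientation before the diffs: a minimal usage sketch (not part of the commit) of the non-tensor-parallel Linear1D layer this commit introduces. It assumes the usual ColossalAI distributed context has already been launched, since the layer offsets its RNG seed by the process rank; the constructor arguments follow the diff below.

import torch

from colossalai.shardformer.layer import Linear1D

# Build the layer directly; no process group is involved because nothing is sharded.
layer = Linear1D(in_features=16, out_features=32, bias=True, dtype=torch.float32)

x = torch.randn(4, 16)
out = layer(x)            # plain F.linear under the hood, routed through the new LinearBase autograd function
out.sum().backward()      # layer.weight.grad is computed eagerly (use_zbv defaults to False)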
3 changes: 2 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -2,7 +2,7 @@
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@@ -11,6 +11,7 @@
__all__ = [
    "Embedding1D",
    "VocabParallelEmbedding1D",
    "Linear1D",
    "Linear1D_Col",
    "Linear1D_Row",
    "GPT2FusedLinearConv1D_Col",
106 changes: 105 additions & 1 deletion colossalai/shardformer/layer/_operation.py
@@ -154,7 +154,6 @@ def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_
            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)

        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
            # _grad_output_.t().matmul(_input_)
            return wgrad_gemm_func(_grad_output_.t(), _input_)

        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
@@ -236,6 +235,107 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
        return grad_input, grad_weight, grad_bias, None, None, None, None


class LinearBase(torch.autograd.Function):
    """
    Linear layer baseline (no tensor parallel version).
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
        ctx.save_for_backward(input_, weight, bias)
        ctx.use_bias = bias is not None
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.fp8_communication = fp8_communication
        ctx.use_zbv = use_zbv
        if bias is not None:
            output = F.linear(input_, weight, bias)
        else:
            output = F.linear(input_, weight)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        use_bias = ctx.use_bias
        ctx.fp8_communication
        use_zbv = ctx.use_zbv

        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)

        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
            return wgrad_gemm_func(_grad_output_.t(), _input_)

        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
        if use_bias:
            bias.view(bias.shape)

        total_input = input.contiguous()
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        if len(grad_output.shape) > 2:
            grad_output = grad_output.view(-1, grad_output.shape[-1])
            total_input = total_input.view(-1, total_input.shape[-1])

        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
            if use_zbv:
                # TODO: append input, grad_output_, weight, grad func to WeightGradStore
                if grad.dtype == torch.float32:
                    WeightGradStore.put(
                        total_input,
                        grad_output,
                        weight,
                        functools.partial(
                            execute_w_pass_grad_accum,
                            wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
                        ),
                    )
                    grad_weight = None
                elif grad.dtype in (torch.float16, torch.bfloat16):
                    WeightGradStore.put(
                        total_input,
                        grad_output,
                        weight,
                        functools.partial(
                            execute_w_pass_grad_accum,
                            wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
                        ),
                    )
                    grad_weight = None
                else:
                    raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
            else:
                if grad.dtype == torch.float32:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
                    grad_weight = None
                elif grad.dtype == torch.float16:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
                    grad_weight = None
                else:
                    grad_weight = grad_output.t().matmul(total_input)
        else:
            if use_zbv:
                WeightGradStore.put(
                    total_input,
                    grad_output,
                    weight,
                    functools.partial(
                        execute_w_pass,
                        wgrad_gemm_func=torch.matmul,
                    ),
                )
                grad_weight = None
            else:
                grad_weight = grad_output.t().matmul(total_input)

        grad_bias = grad_output.sum(dim=0) if use_bias else None

        return grad_input, grad_weight, grad_bias, None, None, None, None


def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
    # currently only support one single tensor as output
    group_size = dist.get_world_size(process_group)
@@ -1101,6 +1201,10 @@ def linear_with_async_comm(
    )


def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
    return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)


def linear_gather_forward_reducescatter_backward(
    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
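To make the new autograd path concrete, here is a hedged sketch (not taken from the commit's test file) of calling linear_base directly. With use_zbv=False the weight gradient is computed eagerly inside LinearBase.backward; with use_zbv=True the backward pass only enqueues (input, grad_output, weight, grad_fn) through WeightGradStore.put and returns None for the weight, leaving the actual GEMM to be executed later, presumably when the zero-bubble pipeline scheduler drains the store.

import torch

from colossalai.shardformer.layer._operation import linear_base

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=True)
b = torch.zeros(32, requires_grad=True)

# async_grad_allreduce=False: there is no tensor-parallel group to reduce over.
y = linear_base(x, w, b, False, fp8_communication=False, use_zbv=False)
y.mean().backward()
print(w.grad.shape)       # torch.Size([32, 16]); weight gradient produced immediately

# Deferred path: the gradient work is stored rather than executed, so w2.grad is
# expected to stay None until the store is drained by the scheduler.
w2 = torch.randn(32, 16, requires_grad=True)
y2 = linear_base(x, w2, b, False, use_zbv=True)
y2.mean().backward()
assert w2.grad is None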
155 changes: 154 additions & 1 deletion colossalai/shardformer/layer/linear.py
@@ -25,6 +25,7 @@
from ._operation import (
    gather_forward_reducescatter_backward,
    gather_forward_split_backward,
    linear_base,
    linear_gather_forward_reducescatter_backward,
    linear_reducescatter_forward_gather_backward,
    linear_with_async_comm,
@@ -35,7 +36,159 @@
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ["Linear1D_Col", "Linear1D_Row"]
__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"]


class Linear1D(ParallelModule):
    r"""Linear layer with no parallelism.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
        gather_output (bool, optional): Kept for interface compatibility with the tensor-parallel linear layers;
            the output is already complete on every device since no weight is sharded, defaults to False.
        seq_parallel_mode (`str`): The sequence parallelism mode, defaults to None.
        seq_parallel_dim (`int`): The dimension along which the sequence is split, defaults to 1.
        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False.
        weight_initializer (`typing.Callable`):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (`typing.Callable`):
            The initializer of bias, defaults to xavier uniform initializer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        gather_output: bool = False,
        seq_parallel_mode: str = None,
        seq_parallel_dim: int = 1,
        overlap: torch.cuda.Stream = None,
        skip_bias_add: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
        fp8_communication: bool = False,
        use_zbv: bool = False,
        **kwargs,
    ):
        super().__init__(weight=weight, bias_=bias_, **kwargs)

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.seq_parallel_mode = seq_parallel_mode
        self.seq_parallel_dim = seq_parallel_dim
        self.overlap = overlap
        self.skip_bias_add = skip_bias_add
        self.device = device
        self.fp8_communication = fp8_communication
        self.use_zbv = use_zbv

        if skip_bias_add and not bias:
            raise ValueError("cannot skip bias addition if bias is None")

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()

        self.randomizer = create_randomizer_with_offset(seed, process_group=None)

        # sanity check
        if weight is not None:
            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
        else:
            assert bias_ is None, "bias_ must be None if weight is None"

        # Parameters.
        if weight is None:
            factory_kwargs = {"device": device, "dtype": dtype}
            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
        else:
            weight.data = weight.data.to(device=device, dtype=dtype)
            self.weight = weight

        if bias:
            if bias_ is None:
                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
            else:
                bias_.data = bias_.data.to(device=device, dtype=dtype)
                self.bias = bias_
        else:
            self.bias = None

        if weight is None:
            # init weights
            self.reset_parameters(weight_initializer, bias_initializer)

    @staticmethod
    def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
        r"""
        Convert a native PyTorch linear layer to a Linear1D layer (no parallelism applied).
        """
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        linear_1d = Linear1D(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

        return linear_1d

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        with self.randomizer.fork_rng(enable_cpu=True):
            fan_in, fan_out = self.in_features, self.out_features
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
            if self.bias is not None:
                bias_initializer(self.bias, fan_in=fan_in)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        assert (
            input_.shape[-1] == self.weight.shape[-1]
        ), "Invalid shapes in Linear1D forward: input={}, weight={}. Expected last dim of input {}.".format(
            input_.shape, self.weight.shape, self.weight.shape[-1]
        )

        # No tensor parallelism here, so the input is used as-is.
        input_parallel = input_

        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = linear_base(
            input_parallel,
            self.weight,
            bias,
            False,
            fp8_communication=self.fp8_communication,
            use_zbv=self.use_zbv,
        )

        output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output


class Linear1D_Col(ParallelModule):
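The fourth changed file (the new test exercising both the WeightGradStore and the non-WeightGradStore paths, per the commit title) is not rendered in this view. A minimal, hypothetical check in the same spirit could compare the wrapper against the native layer it was built from; check_linear_1d_matches_native is an illustrative name, and the sketch again assumes the ColossalAI distributed context is already set up, as shardformer layer tests usually run inside spawned workers.

import torch
import torch.nn as nn

from colossalai.shardformer.layer import Linear1D

def check_linear_1d_matches_native():
    native = nn.Linear(24, 48)
    wrapped = Linear1D.from_native_module(native, use_zbv=False)

    x = torch.randn(2, 24)
    # The wrapper shares the native module's Parameters, so outputs must match exactly.
    torch.testing.assert_close(wrapped(x), native(x))

    wrapped(x).sum().backward()
    assert native.weight.grad is not None   # same Parameter object, so the gradient is visible on both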