Commit 5f09243

[fix] fix linear (no tp) ops func name;
duanjunwen committed Oct 31, 2024
1 parent d2e05a9 commit 5f09243
Showing 6 changed files with 19 additions and 41 deletions.
colossalai/shardformer/layer/__init__.py (2 additions, 2 deletions)
@@ -2,7 +2,7 @@
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
-from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
+from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@@ -11,7 +11,7 @@
__all__ = [
"Embedding1D",
"VocabParallelEmbedding1D",
"Linear1D",
"LinearWithGradAccum",
"Linear1D_Col",
"Linear1D_Row",
"GPT2FusedLinearConv1D_Col",
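For quick verification of the rename, the new symbol is what the package now exports; a minimal import check (a sketch, assuming a checkout of this commit is installed):

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum

# Linear1D is intentionally gone from the public API after this commit.
print(LinearWithGradAccum.__name__)
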
colossalai/shardformer/layer/_operation.py (4 additions, 6 deletions)
@@ -235,17 +235,16 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
return grad_input, grad_weight, grad_bias, None, None, None, None


-class LinearBase(torch.autograd.Function):
+class LinearWithGradAccum(torch.autograd.Function):
"""
Linear layer baseline (no tensor parallel version).
"""

@staticmethod
-def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
+def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
-ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
@@ -258,7 +257,6 @@ def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=F
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
-ctx.fp8_communication
use_zbv = ctx.use_zbv

def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
@@ -1201,8 +1199,8 @@ def linear_with_async_comm(
)


-def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
-    return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)
+def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
+    return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)


def linear_gather_forward_reducescatter_backward(
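A minimal usage sketch of the renamed wrapper above; the tensor shapes, the CUDA device, and the assumption that the non-ZBV backward path behaves like the old linear_base outside a pipeline schedule are all illustrative, not guaranteed by this diff:

import torch

from colossalai.shardformer.layer._operation import linear_with_grad_accum

# Plain (no tensor parallel) linear forward through the custom autograd.Function.
x = torch.randn(4, 32, device="cuda", requires_grad=True)
w = torch.randn(128, 32, device="cuda", requires_grad=True)
b = torch.randn(128, device="cuda", requires_grad=True)

out = linear_with_grad_accum(x, w, b, async_grad_allreduce=False, use_zbv=False)
out.sum().backward()  # grad_weight / grad_bias are computed in the backward shown above
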
colossalai/shardformer/layer/linear.py (5 additions, 16 deletions)
@@ -25,21 +25,21 @@
from ._operation import (
gather_forward_reducescatter_backward,
gather_forward_split_backward,
-linear_base,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
+linear_with_grad_accum,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"]
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]


-class Linear1D(ParallelModule):
+class LinearWithGradAccum(ParallelModule):
r"""Linear layer with no parallelism.
Args:
@@ -69,16 +69,11 @@ def __init__(
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
-gather_output: bool = False,
-seq_parallel_mode: str = None,
-seq_parallel_dim: int = 1,
-overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-fp8_communication: bool = False,
use_zbv: bool = False,
**kwargs,
):
@@ -87,13 +82,8 @@
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
-self.gather_output = gather_output
-self.seq_parallel_mode = seq_parallel_mode
-self.seq_parallel_dim = seq_parallel_dim
-self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
-self.fp8_communication = fp8_communication
self.use_zbv = use_zbv

if skip_bias_add and not bias:
@@ -143,7 +133,7 @@ def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
bias = module.bias is not None
device = module.weight.device

-linear_1d = Linear1D(
+linear_1d = LinearWithGradAccum(
in_features=in_features,
out_features=out_features,
bias=bias,
@@ -174,12 +164,11 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:

# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
-output_parallel = linear_base(
+output_parallel = linear_with_grad_accum(
input_parallel,
self.weight,
bias,
False,
-fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)

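At the module level, the conversion path exercised in the tests below still applies; a sketch, assuming from_native_module forwards extra keyword arguments to the constructor and that forward returns a plain tensor when skip_bias_add is left at its default:

import torch
import torch.nn as nn

from colossalai.shardformer.layer import LinearWithGradAccum

# Wrap a vanilla nn.Linear with the renamed non-tensor-parallel layer.
native = nn.Linear(32, 128).cuda()
linear_base = LinearWithGradAccum.from_native_module(native, parallel_input=False, use_zbv=False)

x = torch.randn(4, 32).cuda()
out = linear_base(x)
assert out.shape == torch.Size([4, 128])
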
colossalai/shardformer/policies/mixtral.py (3 additions, 12 deletions)
@@ -52,10 +52,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size
-if self.pipeline_stage_manager:
-    use_zbv = self.pipeline_stage_manager.use_zbv
-else:
-    use_zbv = False
+use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
@@ -334,10 +331,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
-if self.pipeline_stage_manager:
-    use_zbv = self.pipeline_stage_manager.use_zbv
-else:
-    use_zbv = False
+use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
@@ -400,10 +394,7 @@ def module_policy(self):
from transformers import MixtralForSequenceClassification

policy = super().module_policy()
-if self.pipeline_stage_manager:
-    use_zbv = self.pipeline_stage_manager.use_zbv
-else:
-    use_zbv = False
+use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
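All three policies above now derive use_zbv with the same one-liner, which is equivalent to the removed if/else because `and` short-circuits and never reads .use_zbv when the stage manager is absent. A self-contained sketch of that equivalence (the _StageManager class and resolve_use_zbv helper are illustrative stand-ins, not ColossalAI APIs):

class _StageManager:
    def __init__(self, use_zbv: bool):
        self.use_zbv = use_zbv

def resolve_use_zbv(stage_manager) -> bool:
    # Short-circuit: .use_zbv is only touched when a stage manager exists.
    return stage_manager is not None and stage_manager.use_zbv

assert resolve_use_zbv(None) is False
assert resolve_use_zbv(_StageManager(use_zbv=False)) is False
assert resolve_use_zbv(_StageManager(use_zbv=True)) is True
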
examples/language/llama/benchmark.py (2 additions, 2 deletions)
@@ -366,10 +366,10 @@ def empty_init():
)
loss = outputs["loss"]
if args.pp_style == "zbv":
-if dist.get_rank() == 0:
+if coordinator.is_master():
print(f"Step {step} loss: {loss}")
else:
-if dist.get_rank() == dist.get_world_size() - 1:
+if coordinator.is_last_process():
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()
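The benchmark guard reads more clearly as a small helper; a sketch assuming DistCoordinator exposes is_master() and is_last_process() as used above, and assuming the zbv schedule leaves the loss on the first pipeline stage (which is what the original rank-0 check suggests):

from colossalai.cluster import DistCoordinator

def log_step_loss(coordinator: DistCoordinator, pp_style: str, step: int, loss) -> None:
    # zbv: loss is available on the master process; other schedules: on the last stage.
    if pp_style == "zbv":
        if coordinator.is_master():
            print(f"Step {step} loss: {loss}")
    elif coordinator.is_last_process():
        print(f"Step {step} loss: {loss}")
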
tests/test_shardformer/test_layer/test_linear_1d.py (3 additions, 3 deletions)
@@ -9,7 +9,7 @@
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.weight_grad_store import WeightGradStore
-from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -124,7 +124,7 @@ def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: b
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
-linear_base = Linear1D.from_native_module(
+linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
)
assert linear_base.weight.shape == torch.Size([128, 32])
@@ -164,7 +164,7 @@ def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
-linear_base = Linear1D.from_native_module(
+linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
)
assert linear_base.weight.shape == torch.Size([128, 32])
