From 51665a25cce58b4ccc683de3d967888108cfcca7 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 2 Apr 2024 17:59:25 +0800
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/linear.py        | 23 +++++++++++--------
 .../shardformer/layer/parallel_module.py      |  3 +++
 colossalai/shardformer/policies/bert.py       |  9 --------
 3 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 5c2cc9445313..f8c978861da3 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -67,8 +67,8 @@ class Linear1D_Col(ParallelModule):
 
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
+        in_features: int = None,
+        out_features: int = None,
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
@@ -82,8 +82,10 @@ def __init__(
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        *args,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(*args, **kwargs)
 
         # Keep input parameters
         self.in_features = in_features
@@ -509,7 +511,7 @@ def forward(self, input: Tensor) -> Tensor:
         return output
 
 
-class VocabParallelLMHead1D(PaddingParallelModule, Linear1D_Col):
+class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
     r"""Linear layer with column parallelism.
 
     The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
@@ -540,8 +542,8 @@ class VocabParallelLMHead1D(PaddingParallelModule, Linear1D_Col):
 
     def __init__(
        self,
-        in_features: int,
-        out_features: int,
+        in_features: int = None,
+        out_features: int = None,
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
@@ -570,10 +572,6 @@ def __init__(
             new_out_features = out_features + multiple - (out_features % multiple)
 
         super().__init__(
-            new_num_embeddings=new_out_features,
-            old_num_embeddings=out_features,
-            weight_A=weight,
-            bias_A=bias_,
             in_features=in_features,
             out_features=new_out_features,
             bias=bias,
@@ -583,7 +581,12 @@ def __init__(
             bias_=bias_,
             *args,
             **kwargs,
+            new_num_embeddings=new_out_features,
+            old_num_embeddings=out_features,
+            weight_A=weight,
+            bias_A=bias_,
         )
+
         # get the length of valid embeddings
         tp_rank = dist.get_rank(process_group)
         partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py
index facf3a90260d..f38e467480a4 100644
--- a/colossalai/shardformer/layer/parallel_module.py
+++ b/colossalai/shardformer/layer/parallel_module.py
@@ -25,6 +25,9 @@
 
 
 class ParallelModule(nn.Module, ABC):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
     @abstractmethod
     def from_native_module(
         module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 9f9d577b0a1a..5ad5179ab2df 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -39,15 +39,6 @@ def preprocess(self):
         self.tie_weight = self.tie_weight_check()
         return self.model
 
-    def tie_weight_check(self):
-        input_embedding = self.model.get_input_embeddings()
-        output_embedding = self.model.get_output_embeddings()
-        return (
-            input_embedding is not None
-            and output_embedding is not None
-            and id(input_embedding.weight) == id(output_embedding.weight)
-        )
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import (
            BertEmbeddings,
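
The core of the patch is making the layer constructors cooperative: ParallelModule and Linear1D_Col now accept and forward *args/**kwargs, and VocabParallelLMHead1D lists Linear1D_Col first among its bases so a single super().__init__ call can route the padding-related keywords on to PaddingParallelModule. The sketch below illustrates that routing in isolation. It is a minimal, assumed simplification: the class names mirror the real ones for readability, but the bodies are stand-ins (no torch tensors, no process groups), not the actual ColossalAI implementations.

from abc import ABC


class ParallelModule(ABC):
    # Cooperative base: accept extra arguments and forward them so each class
    # in the MRO can consume the keywords it owns and pass the rest along.
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)


class PaddingParallelModule(ParallelModule):
    def __init__(self, new_num_embeddings: int = 0, old_num_embeddings: int = 0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.new_num_embeddings = new_num_embeddings
        self.old_num_embeddings = old_num_embeddings


class Linear1D_Col(ParallelModule):
    def __init__(self, in_features: int = None, out_features: int = None, **kwargs) -> None:
        # Forward unconsumed kwargs instead of calling super().__init__() bare,
        # so they reach PaddingParallelModule further down the MRO.
        super().__init__(**kwargs)
        self.in_features = in_features
        self.out_features = out_features


class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
    def __init__(self, in_features: int = None, out_features: int = None, **kwargs) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            new_num_embeddings=out_features,  # the real layer passes the padded vocab size here
            old_num_embeddings=out_features,
            **kwargs,
        )


head = VocabParallelLMHead1D(in_features=16, out_features=32)
# MRO: VocabParallelLMHead1D -> Linear1D_Col -> PaddingParallelModule -> ParallelModule
print([cls.__name__ for cls in type(head).__mro__[:4]])
print(head.in_features, head.new_num_embeddings)  # 16 32

Printing the MRO makes the routing visible: keywords that Linear1D_Col does not consume flow on to PaddingParallelModule, which appears to be why the diff reorders the base classes, threads *args/**kwargs through ParallelModule and Linear1D_Col, and moves the padding keyword arguments after **kwargs in the VocabParallelLMHead1D super().__init__ call.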