diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 26fcb6e77d35..617720ee7950 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -40,7 +40,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                                                                 suffix="dropout",
                                                                 target_module=DropoutForReplicatedInput,
                                                             )
-            ])
+                                                        ])
 
         policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
             "attention.attention.num_attention_heads":
@@ -48,45 +48,45 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             "attention.attention.all_head_size":
                 self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
         },
-            param_replacement=[],
-            sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="attention.attention.query",
-                    target_module=col_nn.Linear1D_Col,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attention.attention.key",
-                    target_module=col_nn.Linear1D_Col,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attention.attention.value",
-                    target_module=col_nn.Linear1D_Col,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attention.attention.dropout",
-                    target_module=col_nn.DropoutForParallelInput,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attention.output.dense",
-                    target_module=col_nn.Linear1D_Row,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attention.output.dropout",
-                    target_module=col_nn.DropoutForReplicatedInput,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="intermediate.dense",
-                    target_module=col_nn.Linear1D_Col,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="output.dense",
-                    target_module=col_nn.Linear1D_Row,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="output.dropout",
-                    target_module=col_nn.DropoutForReplicatedInput,
-                ),
-            ])
+                                                   param_replacement=[],
+                                                   sub_module_replacement=[
+                                                       SubModuleReplacementDescription(
+                                                           suffix="attention.attention.query",
+                                                           target_module=col_nn.Linear1D_Col,
+                                                       ),
+                                                       SubModuleReplacementDescription(
+                                                           suffix="attention.attention.key",
+                                                           target_module=col_nn.Linear1D_Col,
+                                                       ),
+                                                       SubModuleReplacementDescription(
+                                                           suffix="attention.attention.value",
+                                                           target_module=col_nn.Linear1D_Col,
+                                                       ),
+                                                       SubModuleReplacementDescription(
+                                                           suffix="attention.attention.dropout",
+                                                           target_module=col_nn.DropoutForParallelInput,
+                                                       ),
+                                                       SubModuleReplacementDescription(
+                                                           suffix="attention.output.dense",
+                                                           target_module=col_nn.Linear1D_Row,
+                                                       ),
+                                                       SubModuleReplacementDescription(
+                                                           suffix="attention.output.dropout",
+                                                           target_module=col_nn.DropoutForReplicatedInput,
+                                                       ),
+                                                       SubModuleReplacementDescription(
+                                                           suffix="intermediate.dense",
+                                                           target_module=col_nn.Linear1D_Col,
+                                                       ),
+                                                       SubModuleReplacementDescription(
+                                                           suffix="output.dense",
+                                                           target_module=col_nn.Linear1D_Row,
+                                                       ),
+                                                       SubModuleReplacementDescription(
+                                                           suffix="output.dropout",
+                                                           target_module=col_nn.DropoutForReplicatedInput,
+                                                       ),
+                                                   ])
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py
index 67a2e90532e2..be4591d58f74 100644
--- a/tests/test_pipeline/test_stage_manager.py
+++ b/tests/test_pipeline/test_stage_manager.py
@@ -21,7 +21,7 @@ def check_stage_manager():
         1: [0, 1],
         2: [2, 3],
         3: [2, 3],
-        }
+    }
     pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()