[Hotfix] Fix llama fwd replacement bug (#6031)
Co-authored-by: Edenzzzz <[email protected]>
Edenzzzz authored Aug 23, 2024
1 parent 39e2597 commit 7cf9df0
Showing 1 changed file with 14 additions and 13 deletions: colossalai/shardformer/policies/llama.py
@@ -95,19 +95,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 policy=policy,
                 target_key=attn_cls,
             )
-            if self.pipeline_stage_manager is None:
-                self.append_or_create_method_replacement(
-                    description={
-                        "forward": get_llama_flash_attention_model_forward(
-                            self.shard_config,
-                            sp_mode=sp_mode,
-                            sp_size=sp_size,
-                            sp_group=sp_group,
-                        ),
-                    },
-                    policy=policy,
-                    target_key=LlamaModel,
-                )
+
+        if self.pipeline_stage_manager is None:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_llama_flash_attention_model_forward(
+                        self.shard_config,
+                        sp_mode=sp_mode,
+                        sp_size=sp_size,
+                        sp_group=sp_group,
+                    ),
+                },
+                policy=policy,
+                target_key=LlamaModel,
+            )
 
         if self.shard_config.enable_tensor_parallelism:
             assert (
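
Note on the change: the hunk re-registers the same LlamaModel forward replacement, dedented by one level, so it now runs whenever self.pipeline_stage_manager is None rather than only when the enclosing branch of module_policy is taken. The toy sketch below is not ColossalAI code; the flag names and the "outer feature branch" are assumptions used only to illustrate why the placement of the pipeline check matters.

# Toy sketch of the control-flow change in this hunk (assumed structure,
# for illustration only).

def build_policy_before_fix(outer_branch_enabled: bool, has_pipeline: bool) -> dict:
    policy = {}
    if outer_branch_enabled:
        policy["attn_cls.forward"] = "flash_attention_forward"
        # Old placement: nested, so the model-level replacement is only
        # registered when the outer branch is taken.
        if not has_pipeline:
            policy["LlamaModel.forward"] = "flash_attention_model_forward"
    return policy


def build_policy_after_fix(outer_branch_enabled: bool, has_pipeline: bool) -> dict:
    policy = {}
    if outer_branch_enabled:
        policy["attn_cls.forward"] = "flash_attention_forward"
    # New placement: dedented, so the model-level replacement is registered
    # whenever pipeline parallelism is off, regardless of the outer branch.
    if not has_pipeline:
        policy["LlamaModel.forward"] = "flash_attention_model_forward"
    return policy


if __name__ == "__main__":
    # The two versions differ only when the outer branch is off and there is
    # no pipeline stage manager: only the fixed version replaces the model forward.
    for outer in (True, False):
        for pipeline in (True, False):
            before = build_policy_before_fix(outer, pipeline)
            after = build_policy_after_fix(outer, pipeline)
            print(f"outer={outer}, pipeline={pipeline}, same={before == after}")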
