[shardformer/sequence parallel] polish code
FoolPlayer committed Aug 16, 2023
1 parent e477215 commit 5dd0a86
Showing 4 changed files with 5 additions and 11 deletions.
4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/_operation.py
@@ -238,7 +238,7 @@ def backward(ctx, grad_output):
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()

- torch.cuda.synchronize()
+ torch.cuda.current_stream().wait_stream(calculate_stream)

reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
with torch.cuda.stream(calculate_stream):
@@ -248,7 +248,7 @@ def backward(ctx, grad_output):
print(grad_output.shape, input_parallel.shape)
grad_weight = grad_output.t().matmul(input_parallel)

- torch.cuda.synchronize()
+ torch.cuda.current_stream().wait_stream(calculate_stream)

return output, grad_weight, grad_bias, None, None, None, None

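These two edits replace a device-wide torch.cuda.synchronize() with a targeted torch.cuda.current_stream().wait_stream(calculate_stream), so the default stream only waits on the side stream that runs the gradient matmuls while the asynchronous reduce-scatter stays in flight. Below is a minimal sketch of that overlap pattern; the function name, the handle.wait() placement, and the stream setup are assumptions for illustration, not taken from the file:

import torch
import torch.distributed as dist

def backward_with_overlap(grad_output, input_parallel, input_list, process_group):
    # Sketch only: overlap an async reduce-scatter with local gradient matmuls
    # on a dedicated CUDA stream instead of stopping the whole device.
    calculate_stream = torch.cuda.Stream()
    output = torch.empty_like(input_list[0]).contiguous()

    # Replaces torch.cuda.synchronize(): only wait for work already queued on
    # the side stream before the collective reads its inputs.
    torch.cuda.current_stream().wait_stream(calculate_stream)

    handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)

    with torch.cuda.stream(calculate_stream):
        # The local weight-gradient matmul runs while the collective is in flight.
        grad_weight = grad_output.t().matmul(input_parallel)

    handle.wait()  # assumed: the real code waits on reducescatter_handle elsewhere
    torch.cuda.current_stream().wait_stream(calculate_stream)
    return output, grad_weight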
1 change: 1 addition & 0 deletions colossalai/shardformer/modeling/gpt2_seq.py
@@ -14,6 +14,7 @@
logger = logging.get_logger(__name__)


+ # TODO: put all contents in `gpt2.py` and make it compatible with pipeline
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):

def forward(
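gpt2_sequence_parallel_forward_fn follows the closure-factory pattern: it captures the ShardConfig and returns a forward function that the policy later installs on GPT2Model. A purely illustrative sketch of that shape follows; the body is hypothetical and is not the real GPT-2 sequence-parallel forward:

import torch.distributed as dist

def example_sequence_parallel_forward_fn(shard_config):
    # Hypothetical sketch of the factory shape: capture shard_config in a
    # closure and return the replacement forward.
    def forward(self, hidden_states, *args, **kwargs):
        # Illustration only: keep this rank's shard of the sequence dimension.
        rank = dist.get_rank()
        local = hidden_states.chunk(shard_config.tensor_parallel_size, dim=1)[rank]
        return local

    return forward

The policy change in gpt2.py below installs the returned function through a method_replacement entry.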
9 changes: 1 addition & 8 deletions colossalai/shardformer/policies/gpt2.py
@@ -51,14 +51,7 @@ def module_policy(self):
),
])
if self.shard_config.enable_sequence_parallelism:
- policy[GPT2Model] = ModulePolicyDescription(
-     sub_module_replacement=[
-         SubModuleReplacementDescription(
-             suffix="wte",
-             target_module=col_nn.VocabParallelEmbedding1D,
-         ),
-     ],
-     method_replacement={"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)})
+ policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}

policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
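The simplified branch assumes policy[GPT2Model] already carries a ModulePolicyDescription (the wte embedding replacement is handled elsewhere) and only attaches the sequence-parallel forward. For intuition, here is a hypothetical helper showing how a method_replacement entry could be bound onto a module instance; this is not ShardFormer's actual mechanism:

import types

def apply_method_replacement(module, method_replacement):
    # Hypothetical illustration: bind each replacement function to the module
    # instance so calls like module.forward(...) dispatch to the new function.
    for name, fn in method_replacement.items():
        setattr(module, name, types.MethodType(fn, module))

# Usage sketch (names assumed):
# apply_method_replacement(gpt2_model,
#                          {"forward": gpt2_sequence_parallel_forward_fn(shard_config)})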
2 changes: 1 addition & 1 deletion tests/kit/model_zoo/transformers/gpt.py
@@ -67,7 +67,7 @@ def data_gen_for_sequence_classification():

config = transformers.GPT2Config(n_layer=2,
n_head=4,
- vocab_size=50260,
+ vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,
resid_pdrop=0,
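The commit does not state why 50258 was chosen, but one common constraint is that a vocab-parallel embedding gives each tensor-parallel rank vocab_size // tp_size rows, so the test vocabulary should split cleanly across ranks. A small illustrative helper, with names and numbers assumed rather than taken from the test:

def vocab_partition(vocab_size: int, tp_size: int, rank: int):
    # Illustration: each tensor-parallel rank owns a contiguous slice of the vocabulary.
    assert vocab_size % tp_size == 0, "vocab_size should split evenly across ranks"
    rows_per_rank = vocab_size // tp_size
    start = rank * rows_per_rank
    return start, start + rows_per_rank

# e.g. vocab_partition(50258, 2, 1) -> (25129, 50258)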
