
Commit 0c2d924: activate checks
flybird11111 committed Aug 6, 2023
1 parent b8eb7c0 commit 0c2d924
Showing 6 changed files with 26 additions and 21 deletions.
1 change: 1 addition & 0 deletions colossalai/shardformer/modeling/opt.py
@@ -9,6 +9,7 @@ def get_opt_flash_attention_forward():
     from transformers.models.opt.modeling_opt import OPTAttention

     from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.lazy import LazyTensor

     def forward(
         self: OPTAttention,
5 changes: 1 addition & 4 deletions colossalai/shardformer/policies/chatglm.py
@@ -37,12 +37,11 @@ def preprocess(self):
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)

-
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

-        from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
+        from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock

         policy = {}

@@ -207,7 +206,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
         return []


-
 class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):

     def module_policy(self):
@@ -228,4 +226,3 @@ def get_held_layers(self) -> List[nn.Module]:
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         """No shared params in ChatGLMForConditionalGenerationModel."""
         return []
-
5 changes: 2 additions & 3 deletions tests/test_shardformer/test_model/_utils.py
@@ -1,6 +1,5 @@
 import copy
 from contextlib import nullcontext
-from typing import Optional
 from typing import Any, Callable, Dict, List, Optional

 import torch
@@ -16,8 +15,8 @@
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.shardformer._utils import getattr_
+from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


@@ -33,9 +32,9 @@ def build_model(model_fn,
     # create new model
     org_model = model_fn()
     model_copy = copy.deepcopy(org_model)
-    # shard model
     if use_lazy_init:
         ctx.materialize(org_model)
+    # shard model
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism,
                                enable_flash_attention=enable_flash_attention,
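Apart from the import cleanup, the build_model change just moves the "# shard model" comment below the lazy-init branch, matching the ordering the helper relies on: a model built under LazyInitContext is materialized before any sharding is applied. A minimal sketch of that ordering (DummyModel is a hypothetical stand-in for a model_zoo entry; assumes the colossalai version used by this commit):

from contextlib import nullcontext

import torch.nn as nn

from colossalai.lazy import LazyInitContext


class DummyModel(nn.Module):
    # hypothetical stand-in for a model_zoo model_fn

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)


use_lazy_init = True
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
    org_model = DummyModel()       # parameters are created lazily inside the context
if use_lazy_init:
    ctx.materialize(org_model)     # materialize first ...
# ... and only then build ShardConfig / run ShardFormer, as build_model does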
16 changes: 10 additions & 6 deletions tests/test_shardformer/test_model/test_shard_opt.py
@@ -42,21 +42,25 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check grad
     col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
     row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
-    check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
+    check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
     check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)


+@parameterize('use_lazy_init', [True, False])
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
 @parameterize('enable_flash_attention', [True, False])
 @parameterize('enable_jit_fused', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
-                 use_lazy_init):
+def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
+                 enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
+        org_model, sharded_model = build_model(model_fn,
+                                               enable_fused_normalization,
+                                               enable_tensor_parallelism,
+                                               enable_flash_attention=enable_flash_attention,
+                                               enable_jit_fused=enable_jit_fused,
+                                               use_lazy_init=use_lazy_init)
         check_state_dict(org_model, sharded_model, name=name)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
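The decorator shuffle above only changes which parameter is listed first; stacked @parameterize decorators still run the test body once per combination, passing each swept value as a keyword argument. A small sketch of that behaviour (run_demo is a made-up stand-in for run_opt_test; the behaviour is assumed from how these tests use colossalai.testing.parameterize):

from colossalai.testing import parameterize


@parameterize('use_lazy_init', [True, False])
@parameterize('enable_flash_attention', [True, False])
def run_demo(use_lazy_init, enable_flash_attention):
    # invoked once per combination, with both values supplied as keyword arguments
    print(f'use_lazy_init={use_lazy_init}, enable_flash_attention={enable_flash_attention}')


run_demo()  # expands to the 4 combinations of the two swept parameters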
18 changes: 11 additions & 7 deletions tests/test_shardformer/test_model/test_shard_t5.py
@@ -33,8 +33,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check grad
     col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
     row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
-    check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-7, rtol=1e-5, dim=0, verbose=False)
-    check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-7, rtol=1e-5, dim=1, verbose=False)
+    check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
+    check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)

     # check weights are tied
     if hasattr(org_model, 'lm_head'):
@@ -44,15 +44,19 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('use_lazy_init', [False, True])
 @parameterize('enable_flash_attention', [True, False])
 @parameterize('enable_jit_fused', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
-                use_lazy_init):
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
+                enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
+        org_model, sharded_model = build_model(model_fn,
+                                               enable_fused_normalization=enable_fused_normalization,
+                                               enable_tensor_parallelism=enable_tensor_parallelism,
+                                               enable_flash_attention=enable_flash_attention,
+                                               enable_jit_fused=enable_jit_fused,
+                                               use_lazy_init=use_lazy_init)
         check_state_dict(org_model, sharded_model, name=name)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
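Both this file and the OPT test relax atol in the gradient checks from 1e-7 to 1e-6. check_grad's internals are not part of this diff, but tolerances like these conventionally follow the PyTorch closeness rule, |actual - expected| <= atol + rtol * |expected| (as in torch.allclose); a quick illustration of what the looser atol admits:

import torch

expected = torch.tensor([1.0])
actual = torch.tensor([1.0 + 5e-7])    # roughly 5e-7 of absolute error

print(torch.allclose(actual, expected, atol=1e-6, rtol=0.0))    # True: passes the relaxed atol
print(torch.allclose(actual, expected, atol=1e-7, rtol=0.0))    # False: fails the old, stricter atol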
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_vit.py
@@ -21,7 +21,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                   output_transform_fn, loss_fn)

-    assert_hf_output_close(org_output, shard_output, atol=1e-4, rtol=1e-4)
+    assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)

     # do backward
     org_loss.backward()
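assert_hf_output_close is a shared test helper whose implementation is not part of this diff; conceptually it walks the HuggingFace-style output and compares tensors within the given tolerances, so raising atol/rtol to 1e-3 admits the slightly larger numeric drift that fused or flash kernels can introduce. A rough, hypothetical equivalent (outputs_close is not the real helper) just to show what the looser bound accepts:

import torch


def outputs_close(org_output: dict, shard_output: dict, atol: float, rtol: float) -> bool:
    # hypothetical stand-in for assert_hf_output_close: compare tensor entries pairwise
    return all(
        torch.allclose(org_output[key], shard_output[key], atol=atol, rtol=rtol)
        for key in org_output if isinstance(org_output[key], torch.Tensor))


org = {'logits': torch.ones(2, 4)}
shard = {'logits': torch.ones(2, 4) + 5e-4}    # e.g. small drift from a different attention kernel

print(outputs_close(org, shard, atol=1e-3, rtol=1e-3))    # True under the relaxed tolerances
print(outputs_close(org, shard, atol=1e-4, rtol=1e-4))    # False under the previous ones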
