
Commit

fix
flybird11111 committed Nov 11, 2024
1 parent c8782c2 commit e39a5c3
Showing 9 changed files with 135 additions and 57 deletions.
2 changes: 1 addition & 1 deletion applications/Colossal-LLaMA/train.py
@@ -105,7 +105,6 @@ def train(args) -> None:
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
@@ -171,6 +170,7 @@ def train(args) -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# TODO chatglm doesn't support lora now
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
8 changes: 4 additions & 4 deletions applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
@@ -1,13 +1,13 @@
SAVE_DIR=""
SAVE_DIR="/home/nvme-share/home/jiangmingyan/workspace/ColossalAI/applications/ColossalChat/examples/data_preparation_scripts/tokenized_data"

rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow

python prepare_dataset.py --type sft \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_input_dirs /home/nvme-share/home/jiangmingyan/workspace/ColossalAI/applications/ColossalChat/examples/data_preparation_scripts/data/ \
--conversation_template_config /home/nvme-share/home/jiangmingyan/workspace/ColossalAI/applications/ColossalChat/conversation_template/THUDM_chatglm2-6b.json \
--tokenizer_dir "/home/nvme-share/share/models/ZhipuAI/chatglm2-6b" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
2 changes: 1 addition & 1 deletion applications/ColossalChat/examples/requirements.txt
@@ -1,4 +1,4 @@
pandas>=1.4.1
sentencepiece
colossalai==0.4.0
# colossalai==0.4.0
prompt_toolkit
8 changes: 4 additions & 4 deletions applications/ColossalChat/requirements.txt
@@ -1,9 +1,9 @@
transformers==4.39.3
# transformers==4.39.3
tqdm
datasets==2.14.7
loralib
colossalai>=0.4.0
torch>=2.1.0
# colossalai>=0.4.0
# torch>=2.1.0
langchain
tokenizers
fastapi
@@ -19,5 +19,5 @@ six==1.16.0
datasets
ninja==1.11.1
sentencepiece==0.1.99
flash-attn
# flash-attn
tiktoken
2 changes: 1 addition & 1 deletion colossalai/pipeline/p2p.py
@@ -160,7 +160,7 @@ def _check_for_nccl_hccl_backend(group):


def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
is_nccl_backend = _check_for_nccl_hccl_backend(group)
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device(get_accelerator().current_device())
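Note: the one-line fix above calls the helper that also recognizes the HCCL backend (Ascend NPU), so those process groups are routed onto the accelerator device instead of silently falling back to CPU. A rough sketch of the intent, using plain torch.distributed names as an assumption rather than ColossalAI's actual implementation:

    # Illustrative sketch only, not the real function bodies.
    import torch
    import torch.distributed as dist

    def _check_for_nccl_hccl_backend(group) -> bool:
        # dist.get_backend returns the backend name of the given process group.
        return dist.get_backend(group) in ("nccl", "hccl")

    def _check_device(group) -> torch.device:
        if _check_for_nccl_hccl_backend(group):
            # Simplification: the real code asks the accelerator abstraction,
            # which also covers NPUs; here we assume CUDA.
            return torch.device("cuda", torch.cuda.current_device())
        return torch.device("cpu")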
152 changes: 107 additions & 45 deletions colossalai/shardformer/modeling/chatglm2.py
@@ -9,7 +9,7 @@

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_sp_output,
@@ -25,42 +25,7 @@ def get_flash_core_attention_forward():

def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
attention_mask_type = AttnMaskType.CAUSAL
attn_bias = torch.zeros(
query_layer.shape[0],
1,
query_layer.shape[2],
key_layer.shape[2],
dtype=query_layer.dtype,
device=query_layer.device,
)
temp_mask = (
torch.ones(
query_layer.shape[2],
key_layer.shape[2],
dtype=torch.bool,
device=query_layer.device,
)
.tril(diagonal=0)
.expand(query_layer.shape[0], 1, -1, -1)
)
attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
else:
attention_mask_type = AttnMaskType.CUSTOM
if attention_mask is not None:
attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
dropout_p = self.attention_dropout.p if self.training else 0.0
context_layer = ColoAttention.attention(
query_layer,
key_layer,
value_layer,
attention_mask=attn_bias,
attention_mask_type=attention_mask_type,
dropout_p=dropout_p,
scale=1.0 / self.norm_factor,
)
context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, **attention_mask)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
@@ -180,9 +145,21 @@ def chatglm_model_forward(
],
dim=-1,
)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

if shard_config.enable_flash_attention:
mask_shape = (batch_size, 1, seq_length, seq_length)
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
print("full_attention_mask", full_attention_mask)
else:
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

# Support SP + PP
sp_size = shard_config.sequence_parallel_size
@@ -237,7 +214,7 @@ def chatglm_model_forward(
layer_ret = torch.utils.checkpoint.checkpoint(
layer,
hidden_states,
attention_mask,
full_attention_mask,
rotary_pos_emb,
past_key_values[idx],
use_cache,
@@ -402,10 +379,19 @@ def forward(
],
dim=-1,
)

if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
if shard_config.enable_flash_attention:
mask_shape = (batch_size, 1, seq_length, seq_length)
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
@@ -652,3 +638,79 @@ def forward(
return output, kv_cache

return forward


def get_flash_attention_forward_for_chat_glm_model():
from .chatglm2_6b.modeling_chatglm import ChatGLMModel

def forward(
self: ChatGLMModel,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

batch_size, seq_length = input_ids.shape

if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)

if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
)
if attention_mask is not None:
attention_mask = torch.cat(
[attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
)

mask_shape = (batch_size, 1, seq_length, seq_length)
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
inputs_embeds.device,
q_padding_mask=attention_mask,
is_causal=True,
)

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

# Run encoder.
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
full_attention_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)

return forward
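Note: in the patched forwards above, ColoAttention.prepare_attn_kwargs builds a kwargs dict (causal and padding-mask metadata) that is carried around as full_attention_mask and later unpacked into ColoAttention.attention(...). A small usage sketch in which the two call patterns mirror this diff, while the shapes, dtype, device, and padding mask are illustrative assumptions:

    import torch
    from colossalai.shardformer.layer import ColoAttention

    batch_size, num_heads, seq_len, head_dim = 2, 4, 16, 64
    dtype, device = torch.float16, "cuda"

    # Assumed padding mask in the usual HF layout (1 = keep, 0 = padded).
    padding_mask = torch.ones(batch_size, seq_len, dtype=torch.int64, device=device)
    padding_mask[:, -4:] = 0

    # Same call pattern as the patched forwards above.
    attn_kwargs = ColoAttention.prepare_attn_kwargs(
        (batch_size, 1, seq_len, seq_len),  # mask shape
        dtype,
        device,
        q_padding_mask=padding_mask,
        is_causal=True,
    )

    q = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
    k, v = torch.randn_like(q), torch.randn_like(q)
    # The kwargs dict is unpacked directly, as in get_flash_core_attention_forward.
    out = ColoAttention.attention(q, k, v, **attn_kwargs)  # (batch, heads, seq, head_dim)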
8 changes: 8 additions & 0 deletions colossalai/shardformer/policies/chatglm2.py
@@ -11,6 +11,7 @@
from ..modeling.chatglm2 import (
get_chatglm_sequence_parallel_attention_forward,
get_chatglm_sequence_parallel_forward_fn,
get_flash_attention_forward_for_chat_glm_model,
get_flash_core_attention_forward,
get_jit_fused_glm_block_forward,
)
@@ -203,6 +204,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy=policy,
target_key="CoreAttention",
)
self.append_or_create_method_replacement(
description={
"forward": get_flash_attention_forward_for_chat_glm_model(),
},
policy=policy,
target_key="ChatGLMModel",
)

# use sequence parallel
if self.shard_config.enable_sequence_parallelism:
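Note: the new policy entry registers get_flash_attention_forward_for_chat_glm_model() as a replacement for ChatGLMModel.forward. Conceptually, a method replacement amounts to rebinding the class's forward to the closure returned by the getter; the toy sketch below shows only that idea, not ColossalAI's actual replacement mechanism.

    class ToyModel:
        def forward(self, x):
            return x  # original behavior

    def get_patched_forward():
        def forward(self, x):
            return 2 * x  # replacement behavior (stands in for the flash-attn forward)
        return forward

    ToyModel.forward = get_patched_forward()
    assert ToyModel().forward(3) == 6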
4 changes: 4 additions & 0 deletions examples/language/data_utils.py
@@ -113,6 +113,10 @@ def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size:
)
self.attention_mask = torch.ones_like(self.input_ids)

half_length = max_length // 2

self.attention_mask[:, half_length:] = 0

def __len__(self):
return self.num_samples

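Note: a tiny illustration of what the added lines do to the random benchmark data (the shapes here are assumptions): the second half of every sample is marked as padding, so padded-attention code paths actually get exercised.

    import torch

    num_samples, max_length = 2, 8
    input_ids = torch.randint(0, 100, (num_samples, max_length))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, max_length // 2 :] = 0
    print(attention_mask[0].tolist())  # [1, 1, 1, 1, 0, 0, 0, 0]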
6 changes: 5 additions & 1 deletion examples/language/llama/benchmark.py
@@ -292,9 +292,13 @@ def empty_init():

model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
if config.model_type == "chatglm":
num_layers = model.config.num_layers
else:
num_layers = model.config.num_hidden_layers
performance_evaluator = PerformanceEvaluator(
model_numel,
model.config.num_hidden_layers,
num_layers,
model.config.hidden_size,
model.config.vocab_size,
args.grad_checkpoint,
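Note: the branch above is needed because ChatGLM configs name the field num_layers while most other Hugging Face configs name it num_hidden_layers. An equivalent standalone helper, as an illustration rather than the committed code:

    from types import SimpleNamespace

    def get_num_layers(config) -> int:
        # ChatGLM configs expose `num_layers`; most other HF configs expose
        # `num_hidden_layers`.
        if getattr(config, "model_type", None) == "chatglm":
            return config.num_layers
        return config.num_hidden_layers

    print(get_num_layers(SimpleNamespace(model_type="chatglm", num_layers=28)))       # 28
    print(get_num_layers(SimpleNamespace(model_type="llama", num_hidden_layers=32)))  # 32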
