diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py
index ab730ab52c90..17c5abd53236 100644
--- a/applications/Colossal-LLaMA/train.py
+++ b/applications/Colossal-LLaMA/train.py
@@ -105,7 +105,6 @@ def train(args) -> None:
             enable_fused_normalization=get_accelerator().is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
-            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
             microbatch_size=args.microbatch_size,
@@ -171,6 +170,7 @@ def train(args) -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
+    # TODO chatglm doesn't support lora now
     init_ctx = (
         LazyInitContext(default_device=get_current_device())
         if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
index 84bae0027c83..e77bcb833ff3 100755
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
@@ -1,13 +1,13 @@
-SAVE_DIR=""
+SAVE_DIR="/home/nvme-share/home/jiangmingyan/workspace/ColossalAI/applications/ColossalChat/examples/data_preparation_scripts/tokenized_data"
 
 rm -rf $SAVE_DIR/cache
 rm -rf $SAVE_DIR/jsonl
 rm -rf $SAVE_DIR/arrow
 
 python prepare_dataset.py --type sft \
-    --data_input_dirs /PATH/TO/SFT/DATASET \
-    --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
-    --tokenizer_dir "" \
+    --data_input_dirs /home/nvme-share/home/jiangmingyan/workspace/ColossalAI/applications/ColossalChat/examples/data_preparation_scripts/data/ \
+    --conversation_template_config /home/nvme-share/home/jiangmingyan/workspace/ColossalAI/applications/ColossalChat/conversation_template/THUDM_chatglm2-6b.json \
+    --tokenizer_dir "/home/nvme-share/share/models/ZhipuAI/chatglm2-6b" \
     --data_cache_dir $SAVE_DIR/cache \
     --data_jsonl_output_dir $SAVE_DIR/jsonl \
     --data_arrow_output_dir $SAVE_DIR/arrow \
diff --git a/applications/ColossalChat/examples/requirements.txt b/applications/ColossalChat/examples/requirements.txt
index 91f25a5cf843..91ef7afde0ee 100644
--- a/applications/ColossalChat/examples/requirements.txt
+++ b/applications/ColossalChat/examples/requirements.txt
@@ -1,4 +1,4 @@
 pandas>=1.4.1
 sentencepiece
-colossalai==0.4.0
+# colossalai==0.4.0
 prompt_toolkit
diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt
index ac40ae821d0a..a7e92a39bb94 100755
--- a/applications/ColossalChat/requirements.txt
+++ b/applications/ColossalChat/requirements.txt
@@ -1,9 +1,9 @@
-transformers==4.39.3
+# transformers==4.39.3
 tqdm
 datasets==2.14.7
 loralib
-colossalai>=0.4.0
-torch>=2.1.0
+# colossalai>=0.4.0
+# torch>=2.1.0
 langchain
 tokenizers
 fastapi
@@ -19,5 +19,5 @@ six==1.16.0
 datasets
 ninja==1.11.1
 sentencepiece==0.1.99
-flash-attn
+# flash-attn
 tiktoken
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index c3c6f36011fb..b9a324f709d2 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -160,7 +160,7 @@ def _check_for_nccl_hccl_backend(group):
 
 
 def _check_device(group):
-    is_nccl_backend = _check_for_nccl_backend(group)
+    is_nccl_backend = _check_for_nccl_hccl_backend(group)
     current_device = torch.device("cpu")
     if is_nccl_backend:
         current_device = torch.device(get_accelerator().current_device())
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index a9be5c74dba8..bc4a4cc9af51 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -9,7 +9,7 @@
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
-from colossalai.shardformer.layer import AttnMaskType, ColoAttention
+from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_sp_output,
@@ -25,42 +25,7 @@ def get_flash_core_attention_forward():
 
     def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
         query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-            attention_mask_type = AttnMaskType.CAUSAL
-            attn_bias = torch.zeros(
-                query_layer.shape[0],
-                1,
-                query_layer.shape[2],
-                key_layer.shape[2],
-                dtype=query_layer.dtype,
-                device=query_layer.device,
-            )
-            temp_mask = (
-                torch.ones(
-                    query_layer.shape[2],
-                    key_layer.shape[2],
-                    dtype=torch.bool,
-                    device=query_layer.device,
-                )
-                .tril(diagonal=0)
-                .expand(query_layer.shape[0], 1, -1, -1)
-            )
-            attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
-        else:
-            attention_mask_type = AttnMaskType.CUSTOM
-            if attention_mask is not None:
-                attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
-                attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
-        dropout_p = self.attention_dropout.p if self.training else 0.0
-        context_layer = ColoAttention.attention(
-            query_layer,
-            key_layer,
-            value_layer,
-            attention_mask=attn_bias,
-            attention_mask_type=attention_mask_type,
-            dropout_p=dropout_p,
-            scale=1.0 / self.norm_factor,
-        )
+        context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, **attention_mask)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
@@ -180,9 +145,21 @@ def chatglm_model_forward(
                 ],
                 dim=-1,
             )
-    if full_attention_mask is None:
-        if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
-            full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+    if shard_config.enable_flash_attention:
+        mask_shape = (batch_size, 1, seq_length, seq_length)
+        full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
+            mask_shape,
+            hidden_states.dtype,
+            hidden_states.device,
+            q_padding_mask=attention_mask,
+            is_causal=True,
+        )
+        print("full_attention_mask", full_attention_mask)
+    else:
+        if full_attention_mask is None:
+            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
 
     # Support SP + PP
     sp_size = shard_config.sequence_parallel_size
@@ -237,7 +214,7 @@ def chatglm_model_forward(
                 layer_ret = torch.utils.checkpoint.checkpoint(
                     layer,
                     hidden_states,
-                    attention_mask,
+                    full_attention_mask,
                     rotary_pos_emb,
                     past_key_values[idx],
                     use_cache,
@@ -402,10 +379,19 @@ def forward(
                     ],
                     dim=-1,
                 )
-
-        if full_attention_mask is None:
-            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
-                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+        if shard_config.enable_flash_attention:
+            mask_shape = (batch_size, 1, seq_length, seq_length)
+            full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                hidden_states.dtype,
+                hidden_states.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
+            )
+        else:
+            if full_attention_mask is None:
+                if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                    full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
 
         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
@@ -652,3 +638,79 @@ def forward(
             return output, kv_cache
 
     return forward
+
+
+def get_flash_attention_forward_for_chat_glm_model():
+    from .chatglm2_6b.modeling_chatglm import ChatGLMModel
+
+    def forward(
+        self: ChatGLMModel,
+        input_ids,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        full_attention_mask: Optional[torch.BoolTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, seq_length = input_ids.shape
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embedding(input_ids)
+
+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(
+                    batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
+                )
+            if attention_mask is not None:
+                attention_mask = torch.cat(
+                    [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
+                )
+
+        mask_shape = (batch_size, 1, seq_length, seq_length)
+        full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
+            mask_shape,
+            inputs_embeds.dtype,
+            inputs_embeds.device,
+            q_padding_mask=attention_mask,
+            is_causal=True,
+        )
+
+        # Rotary positional embeddings
+        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+        if position_ids is not None:
+            rotary_pos_emb = rotary_pos_emb[position_ids]
+        else:
+            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+        # Run encoder.
+        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+            inputs_embeds,
+            full_attention_mask,
+            rotary_pos_emb=rotary_pos_emb,
+            kv_caches=past_key_values,
+            use_cache=use_cache,
+            output_hidden_states=output_hidden_states,
+        )
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    return forward
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index c003570a0582..4ddcf8bfce6b 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -11,6 +11,7 @@
 from ..modeling.chatglm2 import (
     get_chatglm_sequence_parallel_attention_forward,
     get_chatglm_sequence_parallel_forward_fn,
+    get_flash_attention_forward_for_chat_glm_model,
     get_flash_core_attention_forward,
     get_jit_fused_glm_block_forward,
 )
@@ -203,6 +204,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 policy=policy,
                 target_key="CoreAttention",
             )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_flash_attention_forward_for_chat_glm_model(),
+                },
+                policy=policy,
+                target_key="ChatGLMModel",
+            )
 
         # use sequence parallel
         if self.shard_config.enable_sequence_parallelism:
diff --git a/examples/language/data_utils.py b/examples/language/data_utils.py
index 6b9e8ef28eb7..2b31e62192c3 100644
--- a/examples/language/data_utils.py
+++ b/examples/language/data_utils.py
@@ -113,6 +113,10 @@ def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size:
         )
         self.attention_mask = torch.ones_like(self.input_ids)
 
+        half_length = max_length // 2
+
+        self.attention_mask[:, half_length:] = 0
+
     def __len__(self):
         return self.num_samples
 
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index d6b009724bf4..ad3b9a34bc0d 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -292,9 +292,13 @@ def empty_init():
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
 
+    if config.model_type == "chatglm":
+        num_layers = model.config.num_layers
+    else:
+        num_layers = model.config.num_hidden_layers
     performance_evaluator = PerformanceEvaluator(
         model_numel,
-        model.config.num_hidden_layers,
+        num_layers,
         model.config.hidden_size,
         model.config.vocab_size,
         args.grad_checkpoint,