diff --git a/llm/README.md b/llm/README.md index bb2315958789..6286c61f566a 100644 --- a/llm/README.md +++ b/llm/README.md @@ -11,7 +11,7 @@ | [GPT-3](./gpt-3) | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 | | [OPT](./opt) | 🚧 | ✅ | ✅ | 🚧 | ✅ | 🚧 | | [GLM](./glm) | ❌ | ✅ | ✅ | 🚧 | ✅ | 🚧 | -| [Qwen](./qwen) | ❌ | ✅ | ✅ | ✅ | ✅ | 🚧 | +| [Qwen](./qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | * ✅: Supported @@ -39,6 +39,11 @@ ## 2. 预训练 [LLaMA v1/v2](./llama)、[GPT-3](./gpt-3) 目录中提供了模型预训练的数据准备和训练细节,后续我们将支持更多的模型预训练。 +``` +# 千问模型预训练 +python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py ./qwen/pretrain_argument_stage2.json + +``` ## 3. 精调 目前精调统一脚本只支持[LLaMA v1/v2](./llama)、[ChatGLM-6B](./chatglm)、[ChatGLM2-6B](./chatglm2)、[Bloom](./bloom)、[OPT](./opt)、[Qwen](./qwen),其他模型精调使用详见对应模型目录。接下来我们将以**Llama 2**为例介绍如何使用统一脚本进行SFT、LoRA、Prefix Tuning。更多LoRA、Prefix Tuning请参见[PEFT文档](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/peft.md)。 diff --git a/llm/ernie-3.5-se/run_pretrain.py b/llm/ernie-3.5-se/run_pretrain.py index 2fd4abcce880..6faefcedac5a 100644 --- a/llm/ernie-3.5-se/run_pretrain.py +++ b/llm/ernie-3.5-se/run_pretrain.py @@ -397,7 +397,7 @@ def main(): use_progressive_seq_len=True, ) else: - model = model_class._from_config(config, dtype=dtype) + model = model_class.from_config(config, dtype=dtype) # Create the learning_rate sheduler and optimizer if training_args.decay_steps is None: diff --git a/llm/gpt-3/run_pretrain.py b/llm/gpt-3/run_pretrain.py index ba603057675a..603dc4e748c2 100644 --- a/llm/gpt-3/run_pretrain.py +++ b/llm/gpt-3/run_pretrain.py @@ -411,7 +411,7 @@ def main(): dtype=dtype, ) else: - model = model_class._from_config(config, dtype=dtype) + model = model_class.from_config(config, dtype=dtype) # Create the learning_rate sheduler and optimizer if training_args.decay_steps is None: diff --git a/llm/qwen/pretrain_argument_stage2.json b/llm/qwen/pretrain_argument_stage2.json new file mode 100644 index 000000000000..1345021f3d19 --- /dev/null +++ b/llm/qwen/pretrain_argument_stage2.json @@ -0,0 +1,39 @@ +{ + "model_name_or_path": "qwen/qwen-7b", + "tokenizer_name_or_path": "qwen/qwen-7b", + "input_dir": "./data", + "output_dir": "./checkpoints/qwen_pretrain_ckpts", + "per_device_train_batch_size": 2, + "gradient_accumulation_steps": 1, + "per_device_eval_batch_size": 2, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "sharding": "stage2", + "virtual_pp_degree": 1, + "sequence_parallel": 0, + "use_flash_attention": true, + "use_fused_rms_norm": true, + "max_seq_length": 4096, + "learning_rate": 3e-05, + "min_learning_rate": 3e-06, + "warmup_steps": 30, + "logging_steps": 1, + "max_steps": 10000, + "save_steps": 5000, + "eval_steps": 1000, + "weight_decay": 0.01, + "bf16": true, + "fp16_opt_level": "O2", + "warmup_ratio": 0.01, + "max_grad_norm": 1.0, + "dataloader_num_workers": 1, + "continue_training": 1, + "do_train": true, + "do_eval": true, + "do_predict": true, + "disable_tqdm": true, + "recompute": true, + "distributed_dataloader": 1, + "recompute_granularity": "full", + "save_total_limit": 2 + } diff --git a/llm/qwen/pretrain_argument_tp2pp4.json b/llm/qwen/pretrain_argument_tp2pp4.json new file mode 100644 index 000000000000..4b060a490d60 --- /dev/null +++ b/llm/qwen/pretrain_argument_tp2pp4.json @@ -0,0 +1,39 @@ +{ + "model_name_or_path": "qwen/qwen-7b", + "tokenizer_name_or_path": "qwen/qwen-7b", + "input_dir": "./data", + "output_dir": "./checkpoints/qwen_pretrain_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 16, + 
"per_device_eval_batch_size": 16, + "tensor_parallel_degree": 2, + "pipeline_parallel_degree": 4, + "sharding": "stage1", + "virtual_pp_degree": 1, + "sequence_parallel": 0, + "use_flash_attention": true, + "use_fused_rms_norm": true, + "max_seq_length": 4096, + "learning_rate": 3e-05, + "min_learning_rate": 3e-06, + "warmup_steps": 30, + "logging_steps": 1, + "max_steps": 10000, + "save_steps": 5000, + "eval_steps": 1000, + "weight_decay": 0.01, + "bf16": true, + "fp16_opt_level": "O2", + "warmup_ratio": 0.01, + "max_grad_norm": 1.0, + "dataloader_num_workers": 1, + "continue_training": 1, + "do_train": true, + "do_eval": true, + "do_predict": true, + "disable_tqdm": true, + "recompute": true, + "distributed_dataloader": 1, + "recompute_granularity": "full", + "save_total_limit": 2 + } diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py new file mode 100644 index 000000000000..d684e0245cbd --- /dev/null +++ b/llm/run_pretrain.py @@ -0,0 +1,553 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os +import random +import sys +import time +from dataclasses import dataclass, field +from typing import List, Optional + +import numpy as np +import paddle + +from paddlenlp.data.causal_dataset import ( + build_train_valid_test_datasets, + check_data_split, + print_rank_0, +) +from paddlenlp.trainer import ( + PdArgumentParser, + Trainer, + TrainingArguments, + get_last_checkpoint, + speed_metrics, +) +from paddlenlp.transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForCausalLMPipe, + AutoTokenizer, + CosineAnnealingWithWarmupDecay, + LinearAnnealingWithWarmupDecay, + register_sequence_parallel_allreduce_hooks, +) +from paddlenlp.utils.batch_sampler import DistributedBatchSampler +from paddlenlp.utils.log import logger + + +def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class PreTrainingArguments(TrainingArguments): + min_learning_rate: float = field( + default=1e-5, + metadata={"help": "Minimum learning rate deacyed to."}, + ) + decay_steps: float = field( + default=None, + metadata={ + "help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate." + }, + ) + enable_linear_fused_grad_add: bool = field( + default=False, + metadata={ + "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ." + }, + ) + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and evaluating. + Using `PdArgumentParser` we can turn this class into argparse arguments to be able to + specify them on the command line. 
+ """ + + input_dir: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) + + max_seq_length: int = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + share_folder: bool = field( + default=False, + metadata={"help": "Use share folder for data dir and output dir on multi machine."}, + ) + + data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) + skip_warmup: bool = field( + default=True, + metadata={"help": "Whether to skip the warmup process of mmap files."}, + ) + data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to pre-train from. + """ + + model_name_or_path: str = field( + default="__internal_testing__/tiny-random-llama", + metadata={ + "help": "Path to pretrained model or model identifier from https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + use_flash_attention: bool = field( + default=False, + metadata={"help": "use_flash_attention"}, + ) + use_fused_rms_norm: bool = field( + default=False, + metadata={"help": "llama or other model, use_fused_rms_norm"}, + ) + fuse_attention_qkv: bool = field( + default=False, + metadata={"help": "whether to fuse attention qkv"}, + ) + fuse_attention_ffn: bool = field( + default=False, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) + recompute_granularity: str = field( + default="full", + metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"}, + ) + virtual_pp_degree: int = field( + default=1, + metadata={"help": "virtual_pp_degree"}, + ) + continue_training: bool = field( + default=False, + metadata={ + "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models." + }, + ) + sequence_parallel: bool = field( + default=False, + metadata={"help": "whether to use sequence parallel"}, + ) + fuse_sequence_parallel_allreduce: bool = field( + default=False, + metadata={"help": "whether to use fuse sequence parallel allreduce"}, + ) + use_fused_rope: Optional[bool] = field( + default=False, + metadata={"help": "Enable rope fusion or not."}, + ) + no_recompute_layers: Optional[List[int]] = field( + default=None, + metadata={"help": "Specify the full transformer layers that should not be recomputed."}, + ) + pp_recompute_interval: int = field( + default=1, + metadata={ + "help": "The interval for the number of layers at which recomputation occurs. A value of 0 indicates no recomputation. Default is 0." 
+ }, + ) + recompute_use_reentrant: bool = field( + default=False, + metadata={"help": "recompute_use_reentrant"}, + ) + + +def create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=True, +): + + check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) + + train_val_test_num_samples = [ + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps, + training_args.per_device_eval_batch_size + * training_args.dataset_world_size + * training_args.eval_iters + * (training_args.max_steps // training_args.eval_steps + 1), + training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters, + ] + + print_rank_0(" > datasets target sizes (minimum size):") + if training_args.do_train: + print_rank_0(" train: {}".format(train_val_test_num_samples[0])) + if training_args.do_eval: + print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) + if training_args.do_predict: + print_rank_0(" test: {}".format(train_val_test_num_samples[2])) + + # Build the datasets. + train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=data_file, + data_impl=data_args.data_impl, + splits_string=data_args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=data_args.max_seq_length, + seed=training_args.seed, + skip_warmup=data_args.skip_warmup, + share_folder=data_args.share_folder, + data_cache_path=data_args.data_cache, + need_data=need_data, + ) + + def print_dataset(data, mode="train"): + logger.info(f"Sample data for {mode} mode.") + # input_ids, loss_mask, attention_mask, position_ids, labels = data + input_ids = data["text"] + logger.info(tokenizer._decode(list(input_ids))) + + from paddlenlp.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = tokens_[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + if need_data: + if training_args.do_train: + print_dataset(train_dataset[0], "train") + if training_args.do_eval: + print_dataset(valid_dataset[0], "valid") + if training_args.do_predict: + print_dataset(test_dataset[0], "test") + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def get_train_data_file(args): + if len(args.input_dir.split()) > 1: + # weight-1 data-prefix-1 weight-2 data-prefix-2 ... 
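+        # Illustrative example (hypothetical paths): input_dir="0.7 ./data/wiki_prefix 0.3 ./data/book_prefix"
+        # yields alternating sampling-weight / data-prefix entries, matching the weighted list built in the
+        # multi-file branch below.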
+ return args.input_dir.split() + else: + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) + ] + files = [x.replace("_idx.npz", "") for x in files] + files = [x.replace(".idx", "") for x in files] # add + + if len(files) > 1: + ret = [] + logger.info("You are using multi-dataset:") + for x in files: + ret.append(1.0) + ret.append(x) + logger.info(" > set weight of %s dataset to 1.0" % x) + return ret + + return files + + +def set_seed(args): + if args.device == "cpu": + idx = 0 + else: + idx = paddle.distributed.get_rank() + random.seed(args.seed + idx) + np.random.seed(args.seed + idx) + paddle.seed(args.seed + idx) + + +class PretrainingTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"): + # keep eval_dataloader + eval_dataloader = getattr(self, "eval_dataloader", None) + if eval_dataloader is None: + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + # must call data loader, otherwise, it will init many times, cause OOM error. + self.eval_dataloader = eval_dataloader() + + start_time = time.time() + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + eval_loop = self.evaluation_loop + + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + # Only evaluate max_eval_iters + max_eval_iters=self.args.eval_iters, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + return output.metrics + + def _get_eval_sampler(self, eval_dataset) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.enable_linear_fused_grad_add: + from fused_layers import mock_layers + + mock_layers() + + if model_args.tokenizer_name_or_path is None: + model_args.tokenizer_name_or_path = model_args.model_name_or_path + + if data_args.data_cache is not None: + 
os.makedirs(data_args.data_cache, exist_ok=True) + + set_seed(training_args) + paddle.set_device(training_args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + training_args.eval_iters = 10 + training_args.test_iters = training_args.eval_iters * 10 + + # Log model and data config + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + # if last_checkpoint is None and len( + # os.listdir(training_args.output_dir)) > 1: + # raise ValueError( + # f"Output directory ({training_args.output_dir}) already exists and is not empty. " + # "Use --overwrite_output_dir to overcome.") + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + + config.seq_length = data_args.max_seq_length + # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings + if not model_args.continue_training: + config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length) + + if not model_args.continue_training: + config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128) + logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.") + + if model_args.no_recompute_layers is not None: + model_args.no_recompute_layers.sort() + + config.use_flash_attention = model_args.use_flash_attention + config.use_fused_rms_norm = model_args.use_fused_rms_norm + config.fuse_attention_qkv = model_args.fuse_attention_qkv + config.fuse_attention_ffn = model_args.fuse_attention_ffn + config.recompute_granularity = model_args.recompute_granularity + config.virtual_pp_degree = model_args.virtual_pp_degree + config.sequence_parallel = model_args.sequence_parallel + config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce + config.use_fused_rope = model_args.use_fused_rope + config.no_recompute_layers = model_args.no_recompute_layers + config.pp_recompute_interval = model_args.pp_recompute_interval + config.recompute_use_reentrant = model_args.recompute_use_reentrant + + config.use_recompute = training_args.recompute + config.tensor_parallel_degree = training_args.tensor_parallel_degree + config.tensor_parallel_rank = training_args.tensor_parallel_rank + + print("Final pre-training config:", config) + + # Set the dtype for loading model + dtype = "float32" + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + + model_class = AutoModelForCausalLM + if training_args.pipeline_parallel_degree > 1: + model_class = 
AutoModelForCausalLMPipe + + if model_args.continue_training: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + ) + else: + model = model_class.from_config(config, dtype=dtype) + + if model_args.sequence_parallel: + register_sequence_parallel_allreduce_hooks( + model, training_args.gradient_accumulation_steps, model_args.fuse_sequence_parallel_allreduce + ) + + if training_args.recompute: + model.recompute_enable() + + # Create the learning_rate sheduler and optimizer + if training_args.decay_steps is None: + training_args.decay_steps = training_args.max_steps + warmup_steps = training_args.warmup_ratio * training_args.max_steps + + lr_scheduler = None + if training_args.lr_scheduler_type.value == "cosine": + lr_scheduler = CosineAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + elif training_args.lr_scheduler_type.value == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + + data_file = get_train_data_file(data_args) + train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=training_args.should_load_dataset, + ) + + total_effective_tokens = ( + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps + * data_args.max_seq_length + ) + + trainer = PretrainingTrainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + optimizers=(None, lr_scheduler), + tokenizer=tokenizer, + ) + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.do_predict: + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + + if training_args.should_load_dataset: + effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"] + print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}") + print(f"ips: {effective_tokens_per_second:.2f} tokens/s") + + +if __name__ == "__main__": + main() diff --git a/model_zoo/gpt/run_pretrain_trainer.py b/model_zoo/gpt/run_pretrain_trainer.py index 055622400772..941b82b669bb 100644 --- a/model_zoo/gpt/run_pretrain_trainer.py +++ b/model_zoo/gpt/run_pretrain_trainer.py @@ -408,7 +408,7 @@ def main(): dtype=dtype, ) else: - model = model_class._from_config(config, dtype=dtype) + model = model_class.from_config(config, dtype=dtype) # Create the learning_rate sheduler and optimizer if training_args.decay_steps is None: diff --git a/paddlenlp/data/dist_dataloader.py b/paddlenlp/data/dist_dataloader.py index 838ff05edb77..73f199bce163 100644 --- 
a/paddlenlp/data/dist_dataloader.py +++ b/paddlenlp/data/dist_dataloader.py @@ -65,13 +65,12 @@ def __init__( self._hcg = fleet.get_hybrid_communicate_group() - # init pp data comm group + # Init pp data comm group. if self._hcg.get_pipe_parallel_world_size() > 1: self._pp_data_group = self._init_dataloader_comm_group() else: self._pp_data_group = None - # tensor parallel message self.mp_group = self._hcg.get_model_parallel_group() self.mp_rank = self._hcg.get_model_parallel_rank() self.mp_src_rank = self._hcg.get_model_parallel_group_src_rank() @@ -80,8 +79,10 @@ def __init__( self.dp_rank = self._hcg.get_data_parallel_rank() sharding_rank = self._hcg.get_sharding_parallel_rank() self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0) - self._data_int64_keys, self._data_int64_keys_size = None, None - self._data_fp32_keys, self._data_fp32_keys_size = None, None + + # When needed other data types, we can modify dtype_list. + self.dtype_list = [paddle.int64, paddle.float32, paddle.int32] + self._data_keys_list, self._data_keys_size = None, None if self._need_data: self._dataloader = paddle.io.DataLoader( @@ -139,99 +140,75 @@ def __iter__(self): return self def __next__(self): - data_int64_keys_size, data_fp32_keys_size = 0, 0 + data_keys_size = [0 for i in range(len(self.dtype_list))] if self._need_data: - # {'input_ids': int64, 'labels': int64} data = next(self._dataloader_iter) data_keys = list(data.keys()) - # TODO(daisiming): Better methods are needed to support new data types. - type_check = [paddle.int64, paddle.float32] for key in data_keys: - if data[key].dtype not in type_check: + if data[key].dtype not in self.dtype_list: raise ValueError( - f"Dist dataloader requires dtype == `int64` or dtype == 'float32', but got: {data[key].dtype}" + f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}" ) - data_int64_list = [data[key] for key in data_keys if data[key].dtype == paddle.int64] - data_int64_keys = [key for key in data_keys if data[key].dtype == paddle.int64] - data_fp32_list = [data[key] for key in data_keys if data[key].dtype == paddle.float32] - data_fp32_keys = [key for key in data_keys if data[key].dtype == paddle.float32] - data_int64_keys_size, data_fp32_keys_size = len(data_int64_keys), len(data_fp32_keys) + data_list, data_keys_list = [], [] + for i, dtype in enumerate(self.dtype_list): + data_list.append([data[key] for key in data_keys if data[key].dtype == dtype]) + data_keys_list.append([key for key in data_keys if data[key].dtype == dtype]) + data_keys_size = [len(keys) for keys in data_keys_list] - # broadcast data keys size - data_int64_keys_size = paddle.to_tensor(data_int64_keys_size) - data_fp32_keys_size = paddle.to_tensor(data_fp32_keys_size) - if self._data_int64_keys_size is None: + # Broadcast data keys size. 
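+        # (the per-dtype key counts are plain Python ints here, so broadcast_object_list is
+        # used for them instead of converting them to tensors and using tensor broadcast)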
+ if self._data_keys_size is None: if self.mp_group is not None and self.pp_rank == 0: - paddle.distributed.broadcast(data_int64_keys_size, src=self.mp_src_rank, group=self.mp_group) - paddle.distributed.broadcast(data_fp32_keys_size, src=self.mp_src_rank, group=self.mp_group) + paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group) if self._pp_data_group is not None: - paddle.distributed.broadcast( - data_int64_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group - ) - paddle.distributed.broadcast( - data_fp32_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group + paddle.distributed.broadcast_object_list( + data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group ) - self._data_int64_keys_size = int(data_int64_keys_size.item()) - self._data_fp32_keys_size = int(data_fp32_keys_size.item()) + self._data_keys_size = data_keys_size if not self._need_data: - data_int64_keys = [None for i in range(self._data_int64_keys_size)] - data_fp32_keys = [None for i in range(self._data_fp32_keys_size)] + data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size] - # broadcast data keys name - if self._data_int64_keys is None: + # Broadcast data keys name. + if self._data_keys_list is None: if self.mp_group is not None and self.pp_rank == 0: - paddle.distributed.broadcast_object_list(data_int64_keys, src=self.mp_src_rank, group=self.mp_group) - if self._data_fp32_keys_size > 0: - paddle.distributed.broadcast_object_list(data_fp32_keys, src=self.mp_src_rank, group=self.mp_group) + paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group) if self._pp_data_group is not None: paddle.distributed.broadcast_object_list( - data_int64_keys, src=self._pp_data_group.ranks[0], group=self._pp_data_group + data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group ) - if self._data_fp32_keys_size > 0: - paddle.distributed.broadcast_object_list( - data_fp32_keys, src=self._pp_data_group.ranks[0], group=self._pp_data_group - ) - self._data_int64_keys = data_int64_keys - self._data_fp32_keys = data_fp32_keys + self._data_keys_list = data_keys_list - # broadcast data + # Broadcast data. if not self._need_data: - data_int64_list = [None for i in range(self._data_int64_keys_size)] - data_fp32_list = [None for i in range(self._data_fp32_keys_size)] + data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size] + if self.mp_group is not None and self.pp_rank == 0: - data_int64_list = broadcast_data_list( - data_int64_list, paddle.int64, self.mp_rank, self.mp_group, self.mp_src_rank - ) - if self._data_fp32_keys_size > 0: - data_fp32_list = broadcast_data_list( - data_fp32_list, paddle.float32, self.mp_rank, self.mp_group, self.mp_src_rank - ) + for i, dtype in enumerate(self.dtype_list): + if self._data_keys_size[i] > 0: + data_list[i] = broadcast_data_list( + data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank + ) if self._pp_data_group is not None: # Note(daisimng): In last stage of pp, we don't need input_ids. # It will be removed in future. 
- data_int64_list = broadcast_data_list( - data_int64_list, - paddle.int64, - self.pp_rank, - self._pp_data_group, - self._pp_data_group.ranks[0], - ) - if self._data_fp32_keys_size > 0: - data_fp32_list = broadcast_data_list( - data_fp32_list, - paddle.float32, - self.pp_rank, - self._pp_data_group, - self._pp_data_group.ranks[0], - ) + for i, dtype in enumerate(self.dtype_list): + if self._data_keys_size[i] > 0: + data_list[i] = broadcast_data_list( + data_list[i], + dtype, + self.pp_rank, + self._pp_data_group, + self._pp_data_group.ranks[0], + ) + + out_data = {} + for keys, datas in zip(self._data_keys_list, data_list): + out_data.update([(k, d) for k, d in zip(keys, datas)]) - out = dict([(key, data) for key, data in zip(self._data_int64_keys, data_int64_list)]) - out.update([(key, data) for key, data in zip(self._data_fp32_keys, data_fp32_list)]) - return out + return out_data def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0): diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 5697dc8d6052..1546cb16d3dd 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -279,9 +279,7 @@ from .rw.modeling import * from .rw.configuration import * from .rw.tokenizer import * -from .qwen.modeling import * -from .qwen.configuration import * -from .qwen.tokenizer import * +from .qwen import * # For faster tokenizer from ..utils.import_utils import is_fast_tokenizer_available diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py index 6e07d112f02d..dc23a9aee29d 100644 --- a/paddlenlp/transformers/auto/modeling.py +++ b/paddlenlp/transformers/auto/modeling.py @@ -211,9 +211,10 @@ def __init__(self, *args, **kwargs): # TODO: Refactor into AutoConfig when available @classmethod - def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file_path): - with io.open(config_file_path, encoding="utf-8") as f: - config = json.load(f) + def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file_path, config=None): + if config is None: + with io.open(config_file_path, encoding="utf-8") as f: + config = json.load(f) # Get class name corresponds to this configuration if is_standard_config(config): @@ -262,6 +263,11 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file + f" to load '{pretrained_model_name_or_path}'\n" ) + @classmethod + def from_config(cls, config, **kwargs): + model_class = cls._get_model_class_from_config(None, None, config) + return model_class._from_config(config, **kwargs) + @classmethod def _from_pretrained(cls, pretrained_model_name_or_path, task=None, *model_args, **kwargs): if task: diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index b10a7073254a..60cb74311a72 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -962,6 +962,17 @@ def _from_config(cls, config, **kwargs): return model + @classmethod + def from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + + Args: + dtype (`paddle.dtype`, *optional*): + Override the default `paddle.dtype` and load the model under this dtype. 
+ """ + return cls._from_config(config, **kwargs) + @property def base_model(self): """ diff --git a/paddlenlp/transformers/qwen/__init__.py b/paddlenlp/transformers/qwen/__init__.py index 3b5b28f31501..640b75d167d2 100644 --- a/paddlenlp/transformers/qwen/__init__.py +++ b/paddlenlp/transformers/qwen/__init__.py @@ -10,3 +10,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .configuration import * +from .modeling import * +from .modeling_pp import * +from .tokenizer import * diff --git a/paddlenlp/transformers/qwen/configuration.py b/paddlenlp/transformers/qwen/configuration.py index 72a0e703df44..facc3fcad170 100644 --- a/paddlenlp/transformers/qwen/configuration.py +++ b/paddlenlp/transformers/qwen/configuration.py @@ -14,6 +14,8 @@ from paddlenlp.transformers import PretrainedConfig +__all__ = ["QWenConfig"] + class QWenConfig(PretrainedConfig): model_type = "qwen" @@ -41,6 +43,7 @@ def __init__( use_fused_rms_norm=False, use_fused_rope=False, intermediate_size=22016, + tensor_parallel_output=True, no_bias=True, tie_word_embeddings=False, **kwargs, diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index 17c846e04526..fb8557fecdd8 100644 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -13,6 +13,7 @@ # limitations under the License. import math +import warnings from functools import partial from typing import List @@ -35,6 +36,16 @@ from ..model_outputs import ModelOutput from .configuration import QWenConfig +__all__ = [ + "QWenBlock", + "QWenForCausalLM", + "QWenPretrainedModel", + "QWenModel", + "QWenLMHead", + "QWenPretrainingCriterion", +] + + MAX_NTK_SEQ_LENGTH = 32768 try: @@ -48,19 +59,6 @@ fused_rotary_position_embedding = None -def get_triangle_upper_mask(x, mask=None): - if mask is not None: - return mask - # [bsz, n_head, q_len, kv_seq_len] - shape = x.shape - # [bsz, 1, q_len, kv_seq_len] - shape[1] = 1 - mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) - mask = paddle.triu(mask, diagonal=1) - mask.stop_gradient = True - return mask - - def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): is_fleet_init = True tensor_parallel_degree = 1 @@ -86,6 +84,19 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): return logits +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + class QWenAttention(nn.Layer): def __init__(self, config): super().__init__() @@ -324,10 +335,9 @@ def forward(self, hidden_states): class QWenBlock(nn.Layer): def __init__(self, config): super().__init__() - self.ln_1 = RMSNorm(config) + self.ln_1 = QWenRMSNorm(config) self.attn = QWenAttention(config) - self.ln_2 = RMSNorm(config) - + self.ln_2 = QWenRMSNorm(config) self.mlp = QWenMLP(config) def forward( @@ -367,10 +377,14 @@ def forward( else: outputs = (hidden_states,) + outputs[1:] + # remove empty tuple for pipeline parallel + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + return outputs -class QWenPreTrainedModel(PretrainedModel): +class QWenPretrainedModel(PretrainedModel): config_class = QWenConfig 
base_model_prefix = "qwen" @@ -515,7 +529,7 @@ def _init_weights(self, module): ) -class QWenModel(QWenPreTrainedModel): +class QWenModel(QWenPretrainedModel): def __init__(self, config): super().__init__(config) self.config = config @@ -541,7 +555,7 @@ def __init__(self, config): for i in range(config.num_hidden_layers) ] ) - self.ln_f = RMSNorm(config) + self.ln_f = QWenRMSNorm(config) def get_input_embeddings(self): return self.wte @@ -696,7 +710,11 @@ def forward( output_attentions=output_attentions, ) - hidden_states = outputs[0] + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + if use_cache is True: presents = presents + (outputs[2 if output_attentions else 1],) @@ -746,13 +764,49 @@ def forward(self, hidden_states, tensor_parallel_output=None): return logits -class QWenForCausalLM(QWenPreTrainedModel): +class QWenPretrainingCriterion(paddle.nn.Layer): + """ + Criterion for Llama. + It calculates the final loss. + """ + + def __init__(self, config): + + super(QWenPretrainingCriterion, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + # skip ignore_index which loss == 0 + masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32") + loss = paddle.mean(masked_lm_loss) + + return loss + + +class QWenForCausalLM(QWenPretrainedModel): _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] def __init__(self, config): super().__init__(config) self.qwen = QWenModel(config) self.lm_head = QWenLMHead(config) + self.criterion = QWenPretrainingCriterion(config) def get_output_embeddings(self): return self.lm_head @@ -846,12 +900,24 @@ def forward( ) hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + + lm_logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) loss = None if labels is not None: - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(lm_logits, labels) + loss = self.criterion(lm_logits, labels) + + # lm_logits = self.lm_head(hidden_states) + + # loss = None + # if labels is not None: + # loss_fct = nn.CrossEntropyLoss() + # loss = loss_fct(lm_logits, labels) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -918,7 +984,7 @@ def rms_norm_fused(x_in, w, eps): return fused_ln.fused_rms_norm(x_in, w, eps)[0] 
-class RMSNorm(nn.Layer): +class QWenRMSNorm(nn.Layer): def __init__(self, config): super().__init__() self.config = config diff --git a/paddlenlp/transformers/qwen/modeling_pp.py b/paddlenlp/transformers/qwen/modeling_pp.py new file mode 100644 index 000000000000..623968c19bc9 --- /dev/null +++ b/paddlenlp/transformers/qwen/modeling_pp.py @@ -0,0 +1,198 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.distributed.fleet as fleet +import paddle.nn as nn +from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer + +from paddlenlp.transformers.model_utils import PipelinePretrainedModel + +from .modeling import ( + QWenBlock, + QWenConfig, + QWenLMHead, + QWenModel, + QWenPretrainedModel, + QWenPretrainingCriterion, + QWenRMSNorm, +) + +__all__ = [ + "QWenForCausalLMPipe", +] + + +def parse_args(args): + if isinstance(args, tuple): + if len(args) == 3: + hidden_states, attention_mask, position_ids = args + elif len(args) == 2: + hidden_states, attention_mask = args + position_ids = None + elif len(args) == 1: + hidden_states = args + attention_mask, position_ids = None, None + else: + hidden_states = args + attention_mask, position_ids = None, None + + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + return hidden_states, attention_mask, position_ids + + +def return_args(hidden_states, attention_mask=None, position_ids=None): + ret = (hidden_states,) + + if attention_mask is not None: + ret += (attention_mask.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if len(ret) == 1: + ret = ret[0] + + return ret + + +class QWenEmbeddingPipe(nn.Layer): + """Extends QWenEmbeddings to forward attention_mask through the pipeline.""" + + def __init__(self, config): + super(QWenEmbeddingPipe, self).__init__() + self.hidden_size = config.hidden_size + if config.tensor_parallel_degree > 1: + self.wte = fleet.meta_parallel.VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.wte = nn.Embedding(config.vocab_size, config.hidden_size) + + def forward(self, args): + """_summary_ + + Args: + input (_type_): _description_ + + Returns: + _type_: _description_ + """ + input_ids, attention_mask, position_ids = parse_args(args) + input_embeds = self.wte(input_ids) + + batch_size, seq_length = input_ids.shape + if attention_mask is not None: + attention_mask = QWenModel._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), 0, input_embeds.dtype + ) + attention_mask.stop_gradient = True + + return return_args(input_embeds, attention_mask, position_ids) + + +class QWenBlockPipe(QWenBlock): + def forward(self, args): + hidden_states, attention_mask, position_ids = parse_args(args) + hidden_states = super().forward(hidden_states, attention_mask=attention_mask) + 
return return_args(hidden_states, attention_mask, position_ids) + + +class QWenRMSNormPipe(QWenRMSNorm): + def forward(self, args): + hidden_states, attention_mask, position_ids = parse_args(args) + return super().forward(hidden_states) + + +class QWenForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): + """QWenForPretraining adapted for pipeline parallelism. + + The largest change is flattening the QWenModel class so we can express it as a + sequence of layers including embedding, transformer layers, and output. + """ + + config_class = QWenConfig + + _get_tensor_parallel_mappings = QWenPretrainedModel._get_tensor_parallel_mappings + _init_weights = QWenPretrainedModel._init_weights + _keys_to_ignore_on_load_unexpected = QWenPretrainedModel._keys_to_ignore_on_load_unexpected + + # DONOT Add base_model_prefix !!!! + + def __init__(self, config): + self.config = config + + self.use_recompute = self.config.use_recompute + self.recompute_granularity = self.config.recompute_granularity + self.pp_recompute_interval = self.config.pp_recompute_interval + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + if self.recompute_granularity == "full": + assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" + + virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + + def get_hcg(): + return fleet.get_hybrid_communicate_group() + + hcg = get_hcg() + tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) + tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) + + # TODO: fix tensor_parallel_degree rewrite in here + config.tensor_parallel_degree = tensor_parallel_degree + config.tensor_parallel_rank = tensor_parallel_rank + + self.add_sequential_layer(LayerDesc(QWenEmbeddingPipe, config=config), "qwen") + for i in range(config.num_hidden_layers): + self.add_sequential_layer( + LayerDesc(QWenBlockPipe, config=config), + f"qwen.h.{i}", + ) + self.add_sequential_layer(LayerDesc(QWenRMSNormPipe, config=config), "qwen.ln_f") + self.add_sequential_layer(LayerDesc(QWenLMHead, config=config), "lm_head") + + recompute_interval = 0 + if self.use_recompute and self.recompute_granularity == "full": + assert self.config.pp_recompute_interval <= config.num_hidden_layers // ( + virtual_pp_degree * get_hcg().topology().get_dim_size("pipe") + ), "pp recompute interval should smaller than num layers of each pp chunk" + recompute_interval = self.config.pp_recompute_interval + + seg_method = "layer:QWenBlock" + if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0: + seg_method = "uniform" + + PipelineLayer.__init__( + self, + layers=self.get_sequential_layers(), + loss_fn=QWenPretrainingCriterion(config), + topology=get_hcg().topology(), + seg_method=seg_method, + recompute_interval=recompute_interval, + recompute_ctx={ + "mp_group": get_hcg().get_model_parallel_group(), + "offload": False, + "partition": False, + }, + num_virtual_pipeline_stages=virtual_pp_degree, + ) + # You should call init here, since there is a diamond inheritance problem + self.apply(self._init_weights) + # DON'T init PipelinePretrainedModel + # PipelinePretrainedModel.__init__(self.super(), config=config) diff --git a/paddlenlp/transformers/qwen/tokenizer.py b/paddlenlp/transformers/qwen/tokenizer.py index d2b2a4d27b66..c4c8eb0a762b 100644 --- a/paddlenlp/transformers/qwen/tokenizer.py +++ b/paddlenlp/transformers/qwen/tokenizer.py @@ -30,6 +30,9 @@ PaddingStrategy, ) +__all__ = 
["QWenTokenizer"] + + VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"} PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" diff --git a/tests/transformers/qwen/test_modeling.py b/tests/transformers/qwen/test_modeling.py index be4278c8ca74..4096d10478b7 100644 --- a/tests/transformers/qwen/test_modeling.py +++ b/tests/transformers/qwen/test_modeling.py @@ -55,7 +55,7 @@ def __init__( bias=False, parallel_attn=True, output_attentions=False, - use_flash_attn=False, + use_flash_attention=False, ): self.parent = parent self.batch_size = batch_size @@ -85,7 +85,7 @@ def __init__( self.bias = bias self.parallel_attn = parallel_attn self.output_attentions = output_attentions - self.use_flash_attn = use_flash_attn + self.use_flash_attention = use_flash_attention def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, dtype="int64") @@ -148,7 +148,7 @@ def get_config(self): output_attentions=self.output_attentions, seq_length=self.seq_length, kv_channels=self.hidden_size // self.num_attention_heads, - use_flash_attn=self.use_flash_attn, + use_flash_attention=self.use_flash_attention, ) def create_and_check_model( diff --git a/tests/transformers/test_modeling_common.py b/tests/transformers/test_modeling_common.py index 620d6944417b..2a14fea66e17 100644 --- a/tests/transformers/test_modeling_common.py +++ b/tests/transformers/test_modeling_common.py @@ -800,7 +800,7 @@ def test_model_from_config_paddle_hub(self): if self.paddlehub_remote_test_model_path is None or self.base_model_class is None: return config = self.base_model_class.config_class.from_pretrained(self.paddlehub_remote_test_model_path) - model = self.base_model_class._from_config(config) + model = self.base_model_class.from_config(config) self.assertIsNotNone(model) @slow diff --git a/tests/transformers/test_shard_checkpoint.py b/tests/transformers/test_shard_checkpoint.py index 69c50d27d414..8b317a7ae52a 100644 --- a/tests/transformers/test_shard_checkpoint.py +++ b/tests/transformers/test_shard_checkpoint.py @@ -74,7 +74,7 @@ def test_from_pretrained_low_cpu_mem_usage_functional(self): def test_keep_in_fp32_modules(self): with tempfile.TemporaryDirectory() as tempdir: config = PretrainedConfig() - model = FakeModel._from_config(config, dtype="float16") + model = FakeModel.from_config(config, dtype="float16") model.config = config model.save_pretrained(tempdir) @@ -94,7 +94,7 @@ def test_load_sharded_checkpoint(self): with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir, max_shard_size="200kiB") - model_load = BertModel._from_config(config) + model_load = BertModel.from_config(config) missing_keys, unexpected_keys = load_sharded_checkpoint(model_load, tmp_dir) self.assertEqual(missing_keys, []) @@ -115,7 +115,7 @@ def inner_convert_test(src_dtype, dst_dtype): str_dst_dtype = str(dst_dtype)[dtype_prefix_len:] config = AutoConfig.from_pretrained("__internal_testing__/tiny-random-bert") - model = BertModel._from_config(config, dtype=str_src_dtype) + model = BertModel.from_config(config, dtype=str_src_dtype) with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) diff --git a/tests/transformers/test_tensor_parallel.py b/tests/transformers/test_tensor_parallel.py index 209ce43d1521..e3f02d20e242 100644 --- a/tests/transformers/test_tensor_parallel.py +++ b/tests/transformers/test_tensor_parallel.py @@ -96,7 +96,7 @@ def _test_llama(): config = LlamaConfig() config 
= prepare_config(config) - model = LlamaForCausalLM._from_config(config) + model = LlamaForCausalLM.from_config(config) common_test_merge(model, LlamaForCausalLM) @@ -105,7 +105,7 @@ def _test_chatglm(): config = ChatGLMConfig() config = prepare_config(config) - model = ChatGLMForCausalLM._from_config(config) + model = ChatGLMForCausalLM.from_config(config) common_test_merge(model, ChatGLMForCausalLM) @@ -114,7 +114,7 @@ def _test_bloom(): config = BloomConfig() config = prepare_config(config) - model = BloomForCausalLM._from_config(config) + model = BloomForCausalLM.from_config(config) common_test_merge(model, BloomForCausalLM)
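For context, a minimal sketch of how the public `from_config` entry point introduced in this patch is expected to be used in place of direct calls to the private `_from_config`; the config identifier and dtype below are illustrative and assume a PaddleNLP build that already contains these changes:

```python
# Sketch only: build a randomly initialized causal LM from a config using the
# new `from_config` classmethod added in this patch.
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM

# "qwen/qwen-7b" mirrors the example configs in this PR; any registered model
# config would work the same way.
config = AutoConfig.from_pretrained("qwen/qwen-7b")

# AutoModelForCausalLM.from_config resolves the concrete class from the config
# (here QWenForCausalLM) and then delegates to that class's _from_config.
model = AutoModelForCausalLM.from_config(config, dtype="bfloat16")
print(type(model).__name__)
```

This mirrors the call sites updated above in `llm/run_pretrain.py`, the GPT-3/ERNIE pretraining scripts, and the tests.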