Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into update_dist_dataloader
DesmonDay committed May 7, 2024
2 parents d5dae85 + 9f3cf82 commit acdf480
Showing 171 changed files with 6,018 additions and 632 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
@@ -24,4 +24,6 @@ jobs:
       - name: run the command
         run: make test
       - name: Upload coverage to Codecov
-        uses: codecov/codecov-action@v3
+        uses: codecov/codecov-action@v4
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
4 changes: 3 additions & 1 deletion README.md
@@ -30,7 +30,9 @@

 ## News 📢

+* **2024.04.24 [PaddleNLP v2.8](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.8.0)**: our RsLoRA+ algorithm converges extremely fast, substantially improving PEFT training speed and quality; high-performance generation acceleration is now built into the RLHF PPO algorithm, removing the generation bottleneck in PPO training and giving it a clear performance lead. Generalized support for training optimizations such as FastFNN and FusedQKV makes large-model training faster and more stable.
+
 * **2024.01.04 [PaddleNLP v2.7](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.7.1)**: a fully upgraded large-model experience with a unified toolchain entry point. The implementations of pre-training, fine-tuning, compression, inference, and deployment are unified under the `PaddleNLP/llm` directory. The new [LLM toolchain documentation](https://paddlenlp.readthedocs.io/zh/latest/llm/finetune.html) offers one-stop guidance from getting started to production deployment. The Unified Checkpoint mechanism greatly improves the portability of large-model checkpoints. Efficient fine-tuning is upgraded: efficient fine-tuning and LoRA can now be combined, and algorithms such as QLoRA are supported.

 * **2023.08.15 [PaddleNLP v2.6](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.6.0)**: released the [full-pipeline LLM toolchain](./llm) covering pre-training, fine-tuning, compression, inference, and deployment, giving users an end-to-end large-model solution and a one-stop development experience; built-in [4D-parallel distributed Trainer](./docs/trainer.md), [efficient fine-tuning algorithms LoRA/Prefix Tuning](./llm#33-lora), [self-developed INT8/INT4 quantization algorithms](./llm#6-量化), and more; full support for mainstream large models such as [LLaMA 1/2](./llm/llama), [BLOOM](./llm/bloom), [ChatGLM 1/2](./llm/chatglm), [GLM](./llm/glm), and [OPT](./llm/opt).
16 changes: 15 additions & 1 deletion docs/trainer.md
@@ -521,6 +521,20 @@ Trainer is a simple but fully featured Paddle training and evaluation module
    Defaults to -1, which disables tensor parallelism. tensor_parallel_degree <= 8 is suggested for better performance.
    Note that this requires model support in the source code; currently GPT/BLOOM/LLAMA/GLM/CHATGLM are supported.
--tensor_parallel_config
    For tensor parallelism, several options affect training performance; they are managed centrally here and passed in as a single string.
    The following options are supported:
    enable_delay_scale_loss: accumulate gradients during the optimizer stage and divide them by the number of accumulation steps, instead of dividing the loss by the accumulation steps directly.
    sync_param: during the optimizer stage, broadcast to synchronize all parameters whose is_distributed attribute is False.
    sync_grad: during the optimizer stage, broadcast to synchronize all gradients whose is_distributed attribute is False.
    sync_moment: during the optimizer stage, broadcast to synchronize all momentum buffers whose is_distributed attribute is False.
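A toy sketch makes the enable_delay_scale_loss semantics concrete. This is plain Python with made-up numbers, not PaddleNLP code: both orderings produce the same averaged gradient, but the delayed variant divides once at the optimizer step rather than once per micro-batch, which accumulates less rounding error in low-precision training.

```python
# Toy illustration of enable_delay_scale_loss (hypothetical values).
micro_batch_grads = [0.9, 1.1, 1.0, 1.0]  # gradient from each micro-batch
k = len(micro_batch_grads)                # accumulation steps

# Default: the loss (hence each gradient) is scaled by 1/k per micro-batch.
early = sum(g / k for g in micro_batch_grads)

# enable_delay_scale_loss: accumulate raw gradients, divide once at step time.
late = sum(micro_batch_grads) / k

assert abs(early - late) < 1e-12  # mathematically identical results
```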
--pipeline_parallel_degree
    Pipeline parallelism is the layer-wise partitioning scheme proposed in the Megatron paper for multi-layer Transformer architectures.
@@ -549,7 +563,7 @@ Trainer is a simple but fully featured Paddle training and evaluation module
    The following options are supported:
    disable_p2p_cache_shape: set this if your maximum sequence length varies.
    disable_partial_send_recv: optimize send speed for tensor parallelism.
-   enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.
+   enable_delay_scale_loss: accumulate gradients until the optimizer step and divide them by the inner pipeline accumulation steps, instead of dividing the loss by the accumulation steps directly.
    enable_dp_comm_overlap: fuse data-parallel gradient communication.
--data_parallel_config
4 changes: 2 additions & 2 deletions examples/language_model/moe/dygraph/modeling.py
@@ -748,8 +748,8 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=F
         if position_ids is None:
             past_length = 0
             if cache is not None:
-                past_length = paddle.shape(cache[0].k)[-2]
-            position_ids = paddle.arange(past_length, paddle.shape(input_ids)[-1] + past_length, dtype="int64")
+                past_length = cache[0].k.shape[-2]
+            position_ids = paddle.arange(past_length, input_ids.shape[-1] + past_length, dtype="int64")
             position_ids = position_ids.unsqueeze(0)
             # .expand_as(input_ids)
             position_ids = paddle.expand_as(position_ids, input_ids)
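Most of the code changes in this commit follow one mechanical pattern: `paddle.shape(x)[i]` becomes `x.shape[i]`. A minimal sketch of the difference, with illustrative shapes:

```python
import paddle

x = paddle.rand([4, 16])

# Old style: paddle.shape(x) returns a Tensor holding the runtime shape.
# That is useful under static-graph tracing, but in dynamic-graph code it
# adds an op just to read an already-known value.
length_tensor = paddle.shape(x)[-1]  # 0-D int32 Tensor

# New style: in dynamic-graph mode, Tensor.shape is a plain Python list,
# so indexing it yields an int with no extra computation.
length = x.shape[-1]                 # Python int: 16
```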
12 changes: 5 additions & 7 deletions examples/model_interpretation/task/senti/rnn/model.py
@@ -207,7 +207,7 @@ def forward(self, input, mask=None):
         # Shape: (batch_size, max_seq_len, hidden_size)
         h = paddle.add_n([forward_input, backward_input])
         # Shape: (batch_size, hidden_size, 1)
-        att_weight = self.att_weight.tile(repeat_times=(paddle.shape(h)[0], 1, 1))
+        att_weight = self.att_weight.tile(repeat_times=(h.shape[0], 1, 1))
         # Shape: (batch_size, max_seq_len, 1)
         att_score = paddle.bmm(paddle.tanh(h), att_weight)
         if mask is not None:
@@ -246,20 +246,18 @@ def forward(self, input, mask=None):
                 Tensor is a bool tensor whose elements identify whether each input word id is the pad token.
                 Defaults to `None`.
         """
-        weight = self.input_weight.tile(
-            repeat_times=(paddle.shape(input)[0], 1, 1)
-        )  # tensor[batch, hidden_size, hidden_size]
-        bias = self.bias.tile(repeat_times=(paddle.shape(input)[0], 1, 1))  # tensor[batch, 1, hidden_size]
+        weight = self.input_weight.tile(repeat_times=(input.shape[0], 1, 1))  # tensor[batch, hidden_size, hidden_size]
+        bias = self.bias.tile(repeat_times=(input.shape[0], 1, 1))  # tensor[batch, 1, hidden_size]
         word_squish = paddle.bmm(input, weight) + bias  # Shape: (batch_size, seq_len, hidden_size)
         att_context_vector = self.att_context_vector.tile(
-            repeat_times=(paddle.shape(input)[0], 1, 1)
+            repeat_times=(input.shape[0], 1, 1)
         )  # Shape: (batch_size, hidden_size, 1)
         att_score = paddle.bmm(word_squish, att_context_vector)  # tensor[batch_size, seq_len, 1]
         if mask is not None:
             # mask, remove the effect of 'PAD'
             mask = paddle.cast(mask, dtype="float32")
             mask = mask.unsqueeze(axis=-1)
-            inf_tensor = paddle.full(shape=paddle.shape(mask), dtype="float32", fill_value=-INF)
+            inf_tensor = paddle.full(shape=mask.shape, dtype="float32", fill_value=-INF)
             att_score = paddle.multiply(att_score, mask) + paddle.multiply(inf_tensor, (1 - mask))
         att_weight = F.softmax(att_score, axis=1)  # tensor[batch_size, seq_len, 1]

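The mask branch in these hunks implements the usual masked-softmax trick: scores at padded positions are pushed toward -INF before the softmax, so 'PAD' tokens receive near-zero attention weight. A small numpy stand-in with hypothetical scores:

```python
import numpy as np

INF = 1e6
att_score = np.array([2.0, 1.0, 3.0, 0.5])  # (seq_len,) attention scores
mask = np.array([1.0, 1.0, 1.0, 0.0])       # last position is 'PAD'

# Keep real scores; push PAD scores to -1e6.
masked = att_score * mask + (-INF) * (1 - mask)
att_weight = np.exp(masked) / np.exp(masked).sum()
print(att_weight.round(3))  # PAD position gets ~0.0 weight
```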
2 changes: 1 addition & 1 deletion examples/simultaneous_translation/stacl/demo/model_demo.py
@@ -34,7 +34,7 @@ def greedy_search(self, src_word, max_len=256, waitk=-1, caches=None, bos_id=Non
         So, it needs the previous state (caches) and the last generated
         token ids from the previous step.
         """
-        src_max_len = paddle.shape(src_word)[-1]
+        src_max_len = src_word.shape[-1]
         base_attn_bias = (
             paddle.cast(src_word == self.bos_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
         )
10 changes: 5 additions & 5 deletions examples/simultaneous_translation/stacl/model.py
@@ -15,11 +15,11 @@
 from __future__ import print_function

 import numpy as np

 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
-from paddlenlp.transformers import WordEmbedding, PositionalEmbedding
+
+from paddlenlp.transformers import PositionalEmbedding, WordEmbedding


 class CrossEntropyCriterion(nn.Layer):
@@ -190,8 +190,8 @@ def __init__(
         self.linear = nn.Linear(in_features=d_model, out_features=trg_vocab_size, bias_attr=False)

     def forward(self, src_word, trg_word):
-        src_max_len = paddle.shape(src_word)[-1]
-        trg_max_len = paddle.shape(trg_word)[-1]
+        src_max_len = src_word.shape[-1]
+        trg_max_len = trg_word.shape[-1]
         base_attn_bias = (
             paddle.cast(src_word == self.bos_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
         )
@@ -236,7 +236,7 @@ def beam_search(self, src_word, beam_size=4, max_len=256, waitk=-1):
         raise NotImplementedError

     def greedy_search(self, src_word, max_len=256, waitk=-1):
-        src_max_len = paddle.shape(src_word)[-1]
+        src_max_len = src_word.shape[-1]
         base_attn_bias = (
             paddle.cast(src_word == self.bos_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
         )
10 changes: 5 additions & 5 deletions examples/text_classification/rnn/model.py
@@ -253,7 +253,7 @@ def forward(self, input, mask=None):
         # Shape: (batch_size, max_seq_len, hidden_size)
         h = paddle.add_n([forward_input, backward_input])
         # Shape: (batch_size, hidden_size, 1)
-        att_weight = self.att_weight.tile(repeat_times=(paddle.shape(h)[0], 1, 1))
+        att_weight = self.att_weight.tile(repeat_times=(h.shape[0], 1, 1))
         # Shape: (batch_size, max_seq_len, 1)
         att_score = paddle.bmm(paddle.tanh(h), att_weight)
         if mask is not None:
@@ -292,19 +292,19 @@ def forward(self, input, mask=None):
                 Tensor is a bool tensor whose elements identify whether each input word id is the pad token.
                 Defaults to `None`.
         """
-        weight = self.input_weight.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
-        bias = self.bias.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
+        weight = self.input_weight.tile(repeat_times=(input.shape[0], 1, 1))
+        bias = self.bias.tile(repeat_times=(input.shape[0], 1, 1))
         # Shape: (batch_size, max_seq_len, hidden_size)
         word_squish = paddle.bmm(input, weight) + bias

-        att_context_vector = self.att_context_vector.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
+        att_context_vector = self.att_context_vector.tile(repeat_times=(input.shape[0], 1, 1))
         # Shape: (batch_size, max_seq_len, 1)
         att_score = paddle.bmm(word_squish, att_context_vector)
         if mask is not None:
             # mask, remove the effect of 'PAD'
             mask = paddle.cast(mask, dtype="float32")
             mask = mask.unsqueeze(axis=-1)
-            inf_tensor = paddle.full(shape=paddle.shape(mask), dtype="float32", fill_value=-INF)
+            inf_tensor = paddle.full(shape=mask.shape, dtype="float32", fill_value=-INF)
             att_score = paddle.multiply(att_score, mask) + paddle.multiply(inf_tensor, (1 - mask))
         att_weight = F.softmax(att_score, axis=1)

2 changes: 1 addition & 1 deletion examples/text_to_sql/RAT-SQL/text2sql/utils/nn_utils.py
@@ -74,7 +74,7 @@ def batch_gather_2d(var, indices):
             "shape of indices error. it should be a 2-D layers. " "but got shape = %s" % (str(indices.shape),)
         )

-    batch_size = paddle.shape(indices)[0]
+    batch_size = indices.shape[0]

     zero = paddle.to_tensor([0], dtype="int64")
     one = paddle.to_tensor([1], dtype="int64")
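For context, batch_gather_2d selects, for each batch row b, the elements var[b, indices[b, i]]. The numpy equivalent below (illustrative data, not the project's implementation) shows the intended semantics:

```python
import numpy as np

var = np.arange(12).reshape(3, 4)             # (batch=3, seq=4)
indices = np.array([[0, 3], [1, 1], [2, 0]])  # positions to gather per row

out = np.take_along_axis(var, indices, axis=1)
print(out)  # [[ 0  3] [ 5  5] [10  8]]
```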
2 changes: 1 addition & 1 deletion llm/README.md
@@ -56,7 +56,7 @@ PaddleNLP integrates Paddle's 4D parallelism strategies into the Trainer API; users only need to modify the Trainer configuration

 This project supports pre-training of large models such as LLaMA, GPT-3, Baichuan, and Qwen. Switching the config file is all it takes to run any of them.

-For the detailed data preparation pipeline, see [here](https://paddlenlp.readthedocs.io/zh/latest/pretraining/dataset.html): https://paddlenlp.readthedocs.io/zh/latest/pretraining/dataset.html
+For the detailed data preparation pipeline, see [here](https://paddlenlp.readthedocs.io/zh/latest/llm/pretraining/dataset.html): https://paddlenlp.readthedocs.io/zh/latest/llm/pretraining/dataset.html

 To make it easy to run and test the model, this project provides a preprocessed training sample of 100k docs.
8 changes: 4 additions & 4 deletions llm/data.py
@@ -44,11 +44,11 @@ def get_convert_example(model):

     if base_model_prefix == "chatglm":
         return convert_example_chatglm
-    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral"]:
+    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]:
         return convert_example_common
     else:
         raise ValueError(
-            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral"
+            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma"
         )

@@ -90,7 +90,7 @@ def tokenize_example(tokenizer, example, data_args):
     return tokenized_source, tokenized_target_input_ids


-def tokenize_rounds_example(tokenizer, example, data_args):
+def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
     """tokenize multi-rounds examples with chat_template.json
     Args:
@@ -117,7 +117,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):

     # 1. only tokenize input_ids
     conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs(
-        conversations, context_data=context_data
+        conversations, context_data=context_data, **kwargs
     )
     system_ids = conversation_result.pop("system", []) or []

3 changes: 2 additions & 1 deletion llm/docs/finetune.md
@@ -202,7 +202,8 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./

```
python merge_tp_and_pp_params.py \
-    --model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100
+    --model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100 \
+    --pp 2 --tp 4
```
The added --pp and --tp flags tell the script the pipeline-parallel and tensor-parallel degrees the checkpoint was saved with.

<summary>&emsp; Script parameters</summary><div>
4 changes: 2 additions & 2 deletions llm/ernie-3.5-se/modeling.py
@@ -142,7 +142,7 @@ def scaled_dot_product_attention(
     query_states, key_states, value_states, attention_mask, output_attentions, config, is_causal=True
 ):

-    bsz, q_len, num_heads, _ = paddle.shape(query_states)
+    bsz, q_len, num_heads, _ = query_states.shape
     head_dim = config.hidden_size // config.num_attention_heads
     _, kv_seq_len, _, _ = value_states.shape

@@ -1054,7 +1054,7 @@ def forward(
         seq_length_with_past = seq_length
         cache_length = 0
         if past_key_values[0] is not None:
-            cache_length = paddle.shape(past_key_values[0][0])[1]
+            cache_length = past_key_values[0][0].shape[1]
             seq_length_with_past += cache_length
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype)
18 changes: 18 additions & 0 deletions llm/gemma/README.md
@@ -0,0 +1,18 @@
# Gemma

## 1. Model Introduction

[Gemma](https://blog.google/technology/developers/gemma-open-models/), developed by Google DeepMind and other Google teams, is a family of lightweight, state-of-the-art open models built with the same research and technology as the Gemini models.

**Supported model weights:**

| Model              |
| ------------------ |
| google/gemma-7b    |
| google/gemma-7b-it |
| google/gemma-2b    |
| google/gemma-2b-it |

## 2. Model Fine-tuning

See the [LLM pipeline tool introduction](../README.md).
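For orientation, a minimal generation sketch with these weights. Assumptions to note: AutoTokenizer/AutoModelForCausalLM are PaddleNLP's usual entry points, and generate() returning an (ids, scores) pair reflects PaddleNLP's historical convention; treat this as illustrative rather than the project's documented recipe.

```python
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", dtype="bfloat16")

inputs = tokenizer("Why is the sky blue?", return_tensors="pd")  # Paddle tensors
ids, _ = model.generate(**inputs, max_new_tokens=32)             # (ids, scores)
print(tokenizer.batch_decode(ids, skip_special_tokens=True))
```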
30 changes: 30 additions & 0 deletions llm/gemma/sft_argument.json
@@ -0,0 +1,30 @@
{
"model_name_or_path": "google/gemma-2b/",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/gemma_sft_ckpts",
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 1,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 512,
"max_length": 1024,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 2,
"zero_padding": false,
"use_flash_attention": false
}
32 changes: 32 additions & 0 deletions llm/gemma/sft_argument_7b.json
@@ -0,0 +1,32 @@
{
"model_name_or_path": "google/gemma-7b",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/gemma_sft_ckpts",
"per_device_train_batch_size": 8,
"gradient_accumulation_steps": 1,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":1,
"num_train_epochs": 3,
"learning_rate": 3e-06,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 512,
"max_length": 1024,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"do_predict": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"pipeline_parallel_degree": 1,
"zero_padding": false,
"use_flash_attention": false
}
33 changes: 33 additions & 0 deletions llm/gemma/sft_argument_7b_sharding.json
@@ -0,0 +1,33 @@
{
"model_name_or_path": "google/gemma-7b",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/llama_sft_ckpts",
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 1,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":1,
"num_train_epochs": 3,
"learning_rate": 3e-06,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"do_predict": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"sharding_parallel_degree": 8,
"sharding": "stage3",
"pipeline_parallel_degree": 1,
"zero_padding": false,
"use_flash_attention": false
}