Revised PoSE (PaddlePaddle#8822)
* Yarn

* Revise

* Delete paddlenlp/transformers/llama/modeling_new.py

* Delete paddlenlp/transformers/llama/modeling_sparse.py

* Delete paddlenlp/transformers/long_sequence_strategies/embedding_strategies_old.py

* Delete paddlenlp/transformers/long_sequence_strategies/embedding_strategies_yarn.py

* Delete llm/run_finetune_old.py

* Update data.py

* Update test_long_sequence_strategies_pose.py

* Update test_long_sequence_strategies_pose.py

* Create test_long_sequence_strategies_pose.py

* Update and rename test_long_sequence_strategies_pose.py to test_yarn.py

* Delete tests/peft/test_yarn.py

* Update run_finetune.py

* Update pyproject.toml

* Update pyproject.toml

* Add files via upload

* Update pyproject.toml

* Revise

* Revised

* Revised

* Revised

* update

* update

* update
whf313 authored Nov 21, 2024
1 parent 5a343cf commit 6813e40
Showing 10 changed files with 5,339 additions and 23 deletions.
6 changes: 6 additions & 0 deletions llm/docs/finetune.md
@@ -130,6 +130,10 @@ python merge_lora_params.py \
- `neftune_noise_alpha`: NEFT alpha parameter. Defaults to 5.0.
- `vera`: Whether to enable the VeRA fine-tuning strategy. Defaults to False.
- `vera_rank`: Rank used by the VeRA algorithm. Defaults to 8.
- `use_long_sequence_strategies`: Whether to use a long-sequence extension strategy. Defaults to False.
- `strategy_type`: Type of the long-sequence extension strategy. Defaults to None.
- `strategy_name`: Name of the specific long-sequence extension strategy. Defaults to None.
- `rope_scaling_factor`: Scaling factor applied when a RoPE extension strategy is used.
</div>

<summary>&emsp; Data arguments (DataArgument)</summary><div>
@@ -140,6 +144,8 @@
- `src_length`: Maximum token length of the model's input context. Defaults to 1024.
- `max_length`: Maximum token length of the model input (context + generated content). Defaults to 2048. When `zero_padding` is True, this is also the maximum training input length for the Zero Padding data stream; it is usually recommended to set it to the model's maximum allowed input length, set `per_device_train_batch_size` to 1, and control the batch size with `gradient_accumulation_steps`.
- `lazy`: Use `MapDataset` when set to False and `IterDataset` when set to True. Defaults to False. Setting it to True is recommended for large datasets, since `IterDataset` avoids reading all data into memory at once; note that `max_steps` must be set and both `evaluation_strategy` and `save_strategy` must be set to `steps`.
- `autoregressive`: Whether to train autoregressively, i.e. the training data is unsupervised. Defaults to False.
- `use_pose_convert`: Whether to use the PoSE data conversion. Defaults to False. A combined configuration example is sketched after this list.

</div>
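
The new model and data parameters above are typically used together. Below is a minimal sketch of a configuration that enables a 4x YaRN context extension with PoSE-style data conversion; the model name, lengths, and the exact `strategy_type`/`strategy_name` strings are illustrative assumptions, not values taken from this commit.

```python
# Hypothetical config passed to llm/run_finetune.py (all values are assumptions for illustration).
long_context_config = {
    "model_name_or_path": "meta-llama/Llama-2-7b",   # any RoPE-based model (assumed)
    "max_length": 2048,                               # original context window
    "autoregressive": True,                           # unsupervised data, as used by PoSE-style training
    "use_pose_convert": True,                         # enable PoSE chunked-position data conversion
    "use_long_sequence_strategies": True,
    "strategy_type": "embedding_strategies",          # assumed module name
    "strategy_name": "YaRNScalingRotaryEmbedding",    # the strategy added in this commit
    "rope_scaling_factor": 4.0,                       # target window = 4 * 2048 = 8192 tokens
}
```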

18 changes: 18 additions & 0 deletions llm/run_finetune.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import inspect
import json
import os
import sys
@@ -161,6 +162,22 @@ def main():

model_config.seq_length = data_args.max_length

    # Config for model using long sequence strategy
if model_args.use_long_sequence_strategies:
data_args.scaled_max_length = int(data_args.max_length * model_args.rope_scaling_factor)
model_config.use_long_sequence_strategies = True
model_config.long_sequence_strategy_type = model_args.strategy_type
model_config.long_sequence_strategy_name = model_args.strategy_name
model_config.rope_scaling_factor = model_args.rope_scaling_factor
model_config.long_sequence_init_args = {
"dim": int(model_config.hidden_size / model_config.num_attention_heads),
"max_position_embeddings": data_args.scaled_max_length, # extended context window
"base": model_config.rope_theta,
"scaling_factor": model_args.rope_scaling_factor,
}
if model_args.strategy_name == "YaRNScalingRotaryEmbedding":
model_config.long_sequence_init_args["original_max_position_embeddings"] = data_args.max_length

logger.info(f"Final model config: {model_config}")

model_class = AutoModelForCausalLM
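
For readers tracing the numbers, here is a self-contained sketch of what the block above derives for one assumed setup (Llama-7B-like geometry, 4x YaRN extension); the `SimpleNamespace` objects merely stand in for the real argument and config classes, and all concrete values are assumptions.

```python
# Standalone illustration of the derived long-sequence init args; numbers are assumptions.
from types import SimpleNamespace

model_args = SimpleNamespace(strategy_name="YaRNScalingRotaryEmbedding", rope_scaling_factor=4.0)
data_args = SimpleNamespace(max_length=2048)
model_config = SimpleNamespace(hidden_size=4096, num_attention_heads=32, rope_theta=10000)

data_args.scaled_max_length = int(data_args.max_length * model_args.rope_scaling_factor)  # 8192
long_sequence_init_args = {
    "dim": model_config.hidden_size // model_config.num_attention_heads,  # 128 (per-head dim)
    "max_position_embeddings": data_args.scaled_max_length,               # 8192 (extended window)
    "base": model_config.rope_theta,                                      # 10000
    "scaling_factor": model_args.rope_scaling_factor,                     # 4.0
}
if model_args.strategy_name == "YaRNScalingRotaryEmbedding":
    # YaRN additionally needs the original, pre-extension window length
    long_sequence_init_args["original_max_position_embeddings"] = data_args.max_length  # 2048
print(long_sequence_init_args)
```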
@@ -365,6 +382,7 @@ def neft_post_hook(module, input, output):
if ptq_ds is not None
else None
)

eval_zero_padding = data_args.zero_padding
if data_args.zero_padding and data_args.eval_with_do_generation:
logger.warning(
14 changes: 14 additions & 0 deletions llm/utils/argument.py
@@ -136,6 +136,12 @@ class DataArgument:
default=False,
metadata={"help": "Pad the input sequence to `max_length`."},
)
autoregressive: bool = field(
default=False,
metadata={"help": "Whether to use autoregressive mode."},
)
    # PoSE related parameters
    use_pose_convert: bool = field(default=False, metadata={"help": "Whether to use the PoSE data conversion function"})

def __post_init__(self):
if self.task_name_or_path is not None:
@@ -229,6 +235,14 @@ class ModelArgument:
neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"})
flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash_mask in flash attention."})

# long sequence strategy
use_long_sequence_strategies: bool = field(
default=False, metadata={"help": "Whether to use long sequence strategy"}
)
rope_scaling_factor: float = field(default=1.0, metadata={"help": "Rope extension scaling factor"})
strategy_type: str = field(default=None, metadata={"help": "Long sequence strategy type"})
strategy_name: str = field(default=None, metadata={"help": "Long sequence strategy name"})


@dataclass
class QuantArgument:
110 changes: 87 additions & 23 deletions llm/utils/data.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
@@ -69,6 +71,27 @@ class DataFormatError(ValueError):
pass


def tokenize_unsupervised_example(tokenizer, example, data_args, is_test=True, zero_padding=False, flash_mask=False):
if "src" in example:
source = example["src"][0] if isinstance(example["src"], list) else example["src"]
else:
raise DataFormatError(
f"Example format is wrong, please check: {example} or rewrite tokenize_example in data.py "
)
tokenized_source = tokenizer(
source,
truncation=False,
padding=True,
max_length=data_args.scaled_max_length,
add_special_tokens=True,
)

if data_args.use_pose_convert:
tokenized_source = get_example_pose(tokenized_source, tokenizer, data_args)

return tokenized_source


def tokenize_example(tokenizer, example, data_args):
if "src" in example and "tgt" in example:
source = example["src"][0] if isinstance(example["src"], list) else example["src"]
@@ -177,33 +200,49 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):


def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False):
    if data_args.autoregressive:
        tokenized_source = tokenize_unsupervised_example(
            tokenizer, example, data_args, is_test=True, zero_padding=False, flash_mask=False
        )
        input_ids = tokenized_source["input_ids"]
        if "labels" in tokenized_source:
            labels = tokenized_source["labels"]
        else:
            labels = input_ids
            input_ids = input_ids[:-1] + [tokenizer.eos_token_id]
            labels = labels[1:] + [-100]
        features = {"input_ids": input_ids, "labels": labels}
        if "position_ids" in tokenized_source:
            features["position_ids"] = tokenized_source["position_ids"]
    else:
        if tokenizer.chat_template is not None:
            return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding, flash_mask)
        else:
            tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)

            if is_test:
                return {
                    **tokenized_source,
                    "labels": tokenized_target_input_ids,
                }
            else:
                input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
                source_length = len(tokenized_source["input_ids"])
                labels = [-100] * source_length + input_ids[source_length:]
                # shift input_ids and labels
                input_ids, labels = input_ids[:-1], labels[1:]
                seq_length = len(input_ids)
                features = {"input_ids": input_ids, "labels": labels}
                if "position_ids" in tokenized_source:
                    features["position_ids"] = list(range(seq_length))
                # maybe change here to suit flash_mask with longlora
                if zero_padding:
                    if flash_mask:
                        features["attn_mask_startend_row_indices"] = [seq_length] * seq_length
                    else:
                        features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
    return features


def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False):
@@ -289,3 +328,28 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, zero_pa
features["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)

return features


def get_example_pose(tokenized_source, tokenizer, data_args):

ids = tokenized_source["input_ids"]
len_chunk = min(len(ids), data_args.max_length)
if len(tokenized_source["input_ids"]) <= data_args.max_length:
tokenized_source["input_ids"] += [tokenizer.eos_token_id]

len_input = len(ids)

lt1 = 0 # chunk1 start pos
rt1 = random.randint(1, (len_chunk) // 2) # chunk1 end pos

rt2 = random.randint(lt1 + len_chunk, len_input - 1) # chunk2 end pos
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # chunk2 start pos
chunked_ids = ids[lt1:rt1] + ids[lt2:rt2]
labels = ids[lt1 + 1 : rt1 + 1] + ids[lt2 + 1 : rt2 + 1]

pos_ids = range(len(chunked_ids))
pos_ids = [x + lt1 if i < rt1 - lt1 else x + (lt2 - (rt1 - lt1)) for i, x in enumerate(pos_ids)]

features = {"input_ids": chunked_ids, "labels": labels, "position_ids": pos_ids}

return features
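
To make the PoSE chunking concrete, here is a toy run of the same sampling logic with small assumed numbers (`max_length=8`, a 20-token source); the specific draws mentioned in the comments are hypothetical.

```python
# Toy illustration of the chunking performed by get_example_pose above.
# With max_length = 8 and 20 input ids, len_chunk = 8. If the sampler happens to
# pick rt1 = 3 and rt2 = 15, then lt2 = 15 - (8 - 3) = 10, so the example becomes
#   chunk 1: ids[0:3]   with position ids 0, 1, 2
#   chunk 2: ids[10:15] with position ids 10, 11, 12, 13, 14
# i.e. only 8 tokens are trained on, but position ids reach 14, which is how PoSE
# simulates a context window longer than the actual training length.
import random

ids = list(range(100, 120))                            # 20 fake token ids
max_length = 8
len_chunk = min(len(ids), max_length)
len_input = len(ids)

lt1 = 0
rt1 = random.randint(1, len_chunk // 2)                # chunk 1 end
rt2 = random.randint(lt1 + len_chunk, len_input - 1)   # chunk 2 end
lt2 = rt2 - (len_chunk - (rt1 - lt1))                  # chunk 2 start

chunked_ids = ids[lt1:rt1] + ids[lt2:rt2]
labels = ids[lt1 + 1 : rt1 + 1] + ids[lt2 + 1 : rt2 + 1]
pos_ids = [i + lt1 if i < rt1 - lt1 else i + (lt2 - (rt1 - lt1)) for i in range(len(chunked_ids))]
print(chunked_ids)
print(labels)
print(pos_ids)
```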
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
from paddle import nn

@@ -20,6 +22,7 @@
"LinearScalingRotaryEmbedding",
"NTKScalingRotaryEmbedding",
"DynamicNTKScalingRotaryEmbedding",
"YaRNScalingRotaryEmbedding",
]


@@ -120,3 +123,101 @@ def forward(self, seq_len=None, ntk_alpha=None):
self._scale_cos_sin(seq_len=seq_len, ntk_alpha=ntk_alpha)

return self.cos_cached[:, :], self.sin_cached[:, :]


class YaRNScalingRotaryEmbedding(nn.Layer):
"""RotaryEmbedding extended with YaRN scaling."""

def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1,
):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.scaling_factor = scaling_factor # scaling_factor
self.original_max_position_embeddings = original_max_position_embeddings
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow

self.yarn()

self._set_cos_sin_cache(seq_len=self.max_position_embeddings)

def _set_cos_sin_cache(self, seq_len):
self.max_seq_len_cached = seq_len
# [seq_len]
t = paddle.arange(seq_len, dtype=paddle.float32)
# [seq_len, dim/2]
with paddle.amp.auto_cast(enable=False):
freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq)
# [seq_len, dim]
emb = paddle.concat([freqs, freqs], axis=-1)
self.cos_cached = emb.cos()[:, :] * self.mscale
self.sin_cached = emb.sin()[:, :] * self.mscale

def _scale_cos_sin(self, seq_len):
self.max_seq_len_cached = seq_len

t = paddle.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
emb = paddle.concat((freqs, freqs), axis=-1)

self.cos_cached = emb.cos()[:, :] * self.mscale
self.sin_cached = emb.sin()[:, :] * self.mscale

def forward(self, seq_len=None, ntk_alpha=None):
if seq_len > self.max_seq_len_cached:
self._scale_cos_sin(seq_len=seq_len)

return self.cos_cached[:, :], self.sin_cached[:, :]

def yarn(self):
inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim))

low, high = self._yarn_find_correction_range(
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
)
inv_freq_mask = (
1 - paddle.cast(self._yarn_linear_ramp_mask(low, high, self.dim // 2), dtype=paddle.float32)
) * self.extrapolation_factor

inv_freq = inv_freq / ((1 - inv_freq_mask) * self.scaling_factor + inv_freq_mask)
self.register_buffer("inv_freq", inv_freq)
self.mscale = self._yarn_get_mscale(self.scaling_factor) * self.attn_factor

@classmethod
def _yarn_find_correction_dim(cls, num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

@classmethod
def _yarn_find_correction_range(cls, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(cls._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(cls._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case

@classmethod
def _yarn_linear_ramp_mask(cls, low, high, dim):
if low == high:
high += 0.001 # Prevent singularity

linear_func = (paddle.arange(dim, dtype=paddle.float32) - low) / (high - low)
ramp_func = paddle.clip(linear_func, 0, 1)
return ramp_func

@classmethod
def _yarn_get_mscale(cls, scaling_factor=1):
if scaling_factor <= 1:
return 1.0
return 0.1 * math.log(scaling_factor) + 1.0
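
A minimal usage sketch of the new class follows; in practice it is constructed indirectly through the long-sequence-strategy config assembled in run_finetune.py, and the head dimension, scaling factor, and window lengths below are assumptions chosen for illustration. Note that for `scaling_factor > 1` the cached cos/sin tables are multiplied by `mscale = 0.1 * ln(scaling_factor) + 1` (see `_yarn_get_mscale`).

```python
# Minimal sketch: instantiate YaRNScalingRotaryEmbedding with the same kind of init
# args that run_finetune.py assembles, then query cos/sin tables for a long sequence.
# All concrete numbers are assumptions (128-dim heads, 2048 -> 8192 extension).
rope = YaRNScalingRotaryEmbedding(
    dim=128,
    max_position_embeddings=8192,           # extended context window
    base=10000,
    scaling_factor=4.0,
    original_max_position_embeddings=2048,  # pre-extension window
)
cos, sin = rope(seq_len=8192)
print(cos.shape, sin.shape)                 # expected: [8192, 128] each
```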
1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ testpaths = [
"tests/generation",
"tests/layers",
"tests/metrics",
"tests/pose",
"tests/ops",
"tests/trainer",
"tests/transformers",