From 98ac6e4961984afee1806464b845a2e994366008 Mon Sep 17 00:00:00 2001 From: "Yuxuan.Zhang" <2448370773@qq.com> Date: Tue, 26 Nov 2024 22:05:42 +0800 Subject: [PATCH] change apply_rotary_pos_emb of Glmmodel for GLM-Edge Series model (#34629) * change apply_rotary_pos_emb * upload for glm-edge * remove useless part * follow the suggestion * fix * format * format * test * format again * format again * remove modular change * remove modular change * this apply_rotary_pos_emb need modify? * fix with this * format * format * ruff check * modify modular_glm failed * remove partial_rotary_factor of function partial_rotary_factor * fix wrong change of examples/research_projects * revert * remove line 118 * use q_rot --- .../models/glm/configuration_glm.py | 3 + .../models/glm/convert_glm_weights_to_hf.py | 71 ++++++++++++------- src/transformers/models/glm/modeling_glm.py | 17 +++-- src/transformers/models/glm/modular_glm.py | 17 +++-- 4 files changed, 69 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/glm/configuration_glm.py b/src/transformers/models/glm/configuration_glm.py index 85d32a7c691a18..de0e80e8c65ba9 100644 --- a/src/transformers/models/glm/configuration_glm.py +++ b/src/transformers/models/glm/configuration_glm.py @@ -45,6 +45,7 @@ class GlmConfig(PretrainedConfig): by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): The factor of the partial rotary position. head_dim (`int`, *optional*, defaults to 128): The attention head dimension. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): @@ -93,6 +94,7 @@ def __init__( num_hidden_layers=40, num_attention_heads=32, num_key_value_heads=2, + partial_rotary_factor=0.5, head_dim=128, hidden_act="silu", attention_dropout=0.0, @@ -114,6 +116,7 @@ def __init__( self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads + self.partial_rotary_factor = partial_rotary_factor self.head_dim = head_dim self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act diff --git a/src/transformers/models/glm/convert_glm_weights_to_hf.py b/src/transformers/models/glm/convert_glm_weights_to_hf.py index 3878ce0d25814a..1053f984d7f053 100644 --- a/src/transformers/models/glm/convert_glm_weights_to_hf.py +++ b/src/transformers/models/glm/convert_glm_weights_to_hf.py @@ -37,16 +37,28 @@ # fmt: on -def merge_safetensors(input_dir: str): - all_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".safetensors")] - all_files = sorted(all_files, key=lambda x: int(x.rsplit("-", 3)[1])) +def load_weights(input_dir: str): + safetensor_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".safetensors")] + bin_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".bin")] all_weights = {} - for file in all_files: - tensors = load_file(file) - all_weights.update(tensors) - return all_weights + if safetensor_files: + safetensor_files = sorted(safetensor_files, key=lambda x: int(x.rsplit("-", 3)[1])) + for file in safetensor_files: + tensors = load_file(file) + all_weights.update(tensors) + return all_weights + + elif bin_files: + bin_files = sorted(bin_files, key=lambda x: int(x.rsplit("-", 3)[1])) + for file in bin_files: + tensors = torch.load(file, 
map_location="cpu") + all_weights.update(tensors) + return all_weights + + else: + raise ValueError("No .safetensors or .bin files found in the specified directory.") def map_old_key_to_new(old_key): @@ -100,7 +112,8 @@ def convert_config(original_config: dict): "attention_bias": "add_qkv_bias", } similar_keys_to_keep = [ - "num_attention_heads" "hidden_size", + "num_attention_heads", + "hidden_size", "attention_dropout", "use_cache", "eos_token_id", @@ -120,24 +133,27 @@ def convert_config(original_config: dict): return new_config -def convert_glm_tokenizer(input_dir): +def convert_glm_tokenizer(input_dir, use_post_processor=False): fast_tok = PreTrainedTokenizerFast.from_pretrained(input_dir, model_input_names=["input_ids", "attention_mask"]) - # Add the two tokens automatically with post processor - fast_tok._tokenizer.post_processor = processors.Sequence( - [ - processors.ByteLevel(trim_offsets=False), - processors.TemplateProcessing( - single="[gMASK]:0 :0 $A:0", - pair="[gMASK]:0 :0 $A:0 $B:1", - special_tokens=[("[gMASK]", 151331), ("", 151333)], - ), - ], - ) - + if use_post_processor: + fast_tok._tokenizer.post_processor = processors.Sequence( + [ + processors.ByteLevel(trim_offsets=False), + processors.TemplateProcessing( + single="[gMASK]:0 :0 $A:0", + pair="[gMASK]:0 :0 $A:0 $B:1", + special_tokens=[("[gMASK]", 151331), ("", 151333)], + ), + ], + ) + else: + fast_tok._tokenizer.post_processor = processors.Sequence( + [processors.ByteLevel(trim_offsets=False)], + ) return fast_tok -def convert_glm_model(input_dir, output_dir): +def convert_glm_model(input_dir, output_dir, use_post_processor=False): # Load and convert config with open(os.path.join(input_dir, "config.json")) as f: original_config = json.load(f) @@ -145,7 +161,7 @@ def convert_glm_model(input_dir, output_dir): config.save_pretrained(output_dir) # Load and convert weights - original_state_dict = merge_safetensors(input_dir) + original_state_dict = load_weights(input_dir) new_dict = convert_state_dict(original_state_dict, config) with torch.device("meta"): model = GlmForCausalLM(config) @@ -153,7 +169,7 @@ def convert_glm_model(input_dir, output_dir): model.save_pretrained(output_dir) # Load and convert tokenizer - tokenizer = convert_glm_tokenizer(input_dir) + tokenizer = convert_glm_tokenizer(input_dir, use_post_processor) tokenizer.save_pretrained(output_dir) @@ -169,6 +185,11 @@ def convert_glm_model(input_dir, output_dir): type=str, help="Location to write HF model and tokenizer", ) + parser.add_argument( + "--use_post_processor", + action="store_true", + help="Whether to apply post processor with special tokens", + ) args = parser.parse_args() - convert_glm_model(args.input_dir, args.output_dir) + convert_glm_model(args.input_dir, args.output_dir, args.use_post_processor) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 9080b5b9cc7c39..16a724f69464a9 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -169,13 +169,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) - # Keep half for later concatenation - q, q_pass = q[..., : q.shape[-1] // 2], q[..., q.shape[-1] // 2 :] - k, k_pass = k[..., : k.shape[-1] // 2], k[..., k.shape[-1] // 2 :] + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass 
= q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - # Apply rotary embeddings on the first half - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) @@ -705,7 +706,9 @@ def __init__(self, config: GlmConfig): ) self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = GlmRotaryEmbedding( - dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta + dim=int(config.head_dim * config.partial_rotary_factor), + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, ) self.gradient_checkpointing = False if getattr(config, "pretraining_tp", 1) != 1: diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index 39ee4a2ad5803e..48605c15d30be3 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -95,13 +95,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) - # Keep half for later concatenation - q, q_pass = q[..., : q.shape[-1] // 2], q[..., q.shape[-1] // 2 :] - k, k_pass = k[..., : k.shape[-1] // 2], k[..., k.shape[-1] // 2 :] + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - # Apply rotary embeddings on the first half - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) @@ -152,7 +153,9 @@ def __init__(self, config: GlmConfig): ) self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = GlmRotaryEmbedding( - dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta + dim=int(config.head_dim * config.partial_rotary_factor), + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, ) self.gradient_checkpointing = False
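
A minimal, self-contained sketch of what the change to apply_rotary_pos_emb does at runtime: with partial_rotary_factor < 1, only the first rotary_dim = int(head_dim * partial_rotary_factor) channels of each query/key head are rotated, and the remaining channels pass through untouched. The helper name apply_partial_rotary is made up for the example, the cos/sin construction collapses GlmRotaryEmbedding plus the slicing at the top of apply_rotary_pos_emb into a single step, and rotate_half is written in the interleaved form that the repeat_interleave'd cos/sin imply; treat it as a sketch of the mechanism, not as the library code.

    import torch


    def rotate_half(x):
        # Interleaved pairing: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...),
        # consistent with the repeat_interleave'd cos/sin used above.
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)


    def apply_partial_rotary(q, k, cos, sin, unsqueeze_dim=1):
        # cos/sin carry only `rotary_dim` channels; everything past that index in
        # q/k is left unrotated and concatenated back, as in the patched function.
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        rotary_dim = cos.shape[-1]
        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
        k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
        q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
        k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
        return q_embed, k_embed


    # Toy shapes; head_dim=128 and partial_rotary_factor=0.5 follow the GlmConfig defaults above.
    head_dim, partial_rotary_factor, seq_len = 128, 0.5, 4
    rotary_dim = int(head_dim * partial_rotary_factor)  # 64, as passed to GlmRotaryEmbedding
    q = torch.randn(1, 2, seq_len, head_dim)  # (batch, heads, seq, head_dim)
    k = torch.randn(1, 2, seq_len, head_dim)

    # Build rotary_dim-wide cos/sin directly (collapsing GlmRotaryEmbedding plus the
    # slicing/repeat_interleave at the top of apply_rotary_pos_emb into one step).
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq, rotary_dim // 2)
    cos = freqs.cos().repeat_interleave(2, dim=-1)[None]          # (1, seq, rotary_dim)
    sin = freqs.sin().repeat_interleave(2, dim=-1)[None]

    q_out, k_out = apply_partial_rotary(q, k, cos, sin)
    # Only the first rotary_dim channels were rotated; the rest are untouched.
    assert torch.equal(q_out[..., rotary_dim:], q[..., rotary_dim:])
    assert not torch.equal(q_out[..., :rotary_dim], q[..., :rotary_dim])

On the conversion side, the new --use_post_processor switch is opt-in: without it the converted tokenizer only gets the ByteLevel post-processor, and with it the [gMASK] template processing is attached as before. Assuming the existing flags are --input_dir and --output_dir (as the args.input_dir / args.output_dir usage above suggests), a conversion run would look roughly like: python src/transformers/models/glm/convert_glm_weights_to_hf.py --input_dir <path to original checkpoint> --output_dir <path for HF model> --use_post_processor.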