Commit

change apply_rotary_pos_emb of GlmModel for GLM-Edge Series model (huggingface#34629)

* change apply_rotary_pos_emb

* upload for glm-edge

* remove useless part

* follow the suggestion

* fix

* format

* format

* test

* format again

* format again

* remove modular change

* remove modular change

* this apply_rotary_pos_emb need modify?

* fix with this

* format

* format

* ruff check

* modify modular_glm failed

* remove partial_rotary_factor of function  partial_rotary_factor

* fix wrong change of examples/research_projects

* revert

* remove line 118

* use q_rot
zRzRzRzRzRzRzR authored and BernardZach committed Dec 6, 2024
1 parent bf74b85 commit 98ac6e4
Showing 4 changed files with 69 additions and 39 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/glm/configuration_glm.py
@@ -45,6 +45,7 @@ class GlmConfig(PretrainedConfig):
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
partial_rotary_factor (`float`, *optional*, defaults to 0.5): The factor of the partial rotary position.
head_dim (`int`, *optional*, defaults to 128):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
@@ -93,6 +94,7 @@ def __init__(
num_hidden_layers=40,
num_attention_heads=32,
num_key_value_heads=2,
partial_rotary_factor=0.5,
head_dim=128,
hidden_act="silu",
attention_dropout=0.0,
@@ -114,6 +116,7 @@ def __init__(
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.partial_rotary_factor = partial_rotary_factor
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
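For context, a minimal sketch (not part of the diff) of how the new partial_rotary_factor field determines the rotary dimension. It assumes GlmConfig is importable from the top level of transformers; the 1.0 value is only an illustration of a full-rotary setting such as a GLM-Edge checkpoint might use.

from transformers import GlmConfig

# Default GLM-4 style configuration: rotary embeddings cover half of each head.
config = GlmConfig(head_dim=128, partial_rotary_factor=0.5)
print(int(config.head_dim * config.partial_rotary_factor))  # 64 -> matches the old hard-coded head_dim // 2

# Hypothetical full-rotary configuration (illustrative value).
edge_config = GlmConfig(head_dim=128, partial_rotary_factor=1.0)
print(int(edge_config.head_dim * edge_config.partial_rotary_factor))  # 128 -> the whole head is rotated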
71 changes: 46 additions & 25 deletions src/transformers/models/glm/convert_glm_weights_to_hf.py
@@ -37,16 +37,28 @@
# fmt: on


def merge_safetensors(input_dir: str):
all_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".safetensors")]
all_files = sorted(all_files, key=lambda x: int(x.rsplit("-", 3)[1]))
def load_weights(input_dir: str):
safetensor_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".safetensors")]
bin_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".bin")]

all_weights = {}
for file in all_files:
tensors = load_file(file)
all_weights.update(tensors)

return all_weights
if safetensor_files:
safetensor_files = sorted(safetensor_files, key=lambda x: int(x.rsplit("-", 3)[1]))
for file in safetensor_files:
tensors = load_file(file)
all_weights.update(tensors)
return all_weights

elif bin_files:
bin_files = sorted(bin_files, key=lambda x: int(x.rsplit("-", 3)[1]))
for file in bin_files:
tensors = torch.load(file, map_location="cpu")
all_weights.update(tensors)
return all_weights

else:
raise ValueError("No .safetensors or .bin files found in the specified directory.")


def map_old_key_to_new(old_key):
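A note on the sorting key: load_weights assumes sharded checkpoint names of the usual model-XXXXX-of-YYYYY form, so x.rsplit("-", 3)[1] picks out the shard index. A small standalone illustration (file names are made up):

files = [
    "model-00002-of-00003.safetensors",
    "model-00001-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
]

# rsplit("-", 3) on "model-00002-of-00003.safetensors" gives
# ["model", "00002", "of", "00003.safetensors"]; element 1 is the shard number.
ordered = sorted(files, key=lambda x: int(x.rsplit("-", 3)[1]))
print(ordered)  # shards come back in 00001, 00002, 00003 order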
@@ -100,7 +112,8 @@ def convert_config(original_config: dict):
"attention_bias": "add_qkv_bias",
}
similar_keys_to_keep = [
"num_attention_heads" "hidden_size",
"num_attention_heads",
"hidden_size",
"attention_dropout",
"use_cache",
"eos_token_id",
@@ -120,40 +133,43 @@ def convert_config(original_config: dict):
return new_config


def convert_glm_tokenizer(input_dir):
def convert_glm_tokenizer(input_dir, use_post_processor=False):
fast_tok = PreTrainedTokenizerFast.from_pretrained(input_dir, model_input_names=["input_ids", "attention_mask"])
# Add the two tokens automatically with post processor
fast_tok._tokenizer.post_processor = processors.Sequence(
[
processors.ByteLevel(trim_offsets=False),
processors.TemplateProcessing(
single="[gMASK]:0 <sop>:0 $A:0",
pair="[gMASK]:0 <sop>:0 $A:0 $B:1",
special_tokens=[("[gMASK]", 151331), ("<sop>", 151333)],
),
],
)

if use_post_processor:
fast_tok._tokenizer.post_processor = processors.Sequence(
[
processors.ByteLevel(trim_offsets=False),
processors.TemplateProcessing(
single="[gMASK]:0 <sop>:0 $A:0",
pair="[gMASK]:0 <sop>:0 $A:0 $B:1",
special_tokens=[("[gMASK]", 151331), ("<sop>", 151333)],
),
],
)
else:
fast_tok._tokenizer.post_processor = processors.Sequence(
[processors.ByteLevel(trim_offsets=False)],
)
return fast_tok


def convert_glm_model(input_dir, output_dir):
def convert_glm_model(input_dir, output_dir, use_post_processor=False):
# Load and convert config
with open(os.path.join(input_dir, "config.json")) as f:
original_config = json.load(f)
config = convert_config(original_config)
config.save_pretrained(output_dir)

# Load and convert weights
original_state_dict = merge_safetensors(input_dir)
original_state_dict = load_weights(input_dir)
new_dict = convert_state_dict(original_state_dict, config)
with torch.device("meta"):
model = GlmForCausalLM(config)
model.load_state_dict(new_dict, strict=True, assign=True)
model.save_pretrained(output_dir)

# Load and convert tokenizer
tokenizer = convert_glm_tokenizer(input_dir)
tokenizer = convert_glm_tokenizer(input_dir, use_post_processor)
tokenizer.save_pretrained(output_dir)


@@ -169,6 +185,11 @@ def convert_glm_model(input_dir, output_dir):
type=str,
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--use_post_processor",
action="store_true",
help="Whether to apply post processor with special tokens",
)

args = parser.parse_args()
convert_glm_model(args.input_dir, args.output_dir)
convert_glm_model(args.input_dir, args.output_dir, args.use_post_processor)
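To make the new flag concrete, a hedged sketch of what --use_post_processor changes at tokenization time; the model path is a placeholder, and the token ids come from the TemplateProcessing block above.

from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast.from_pretrained("path/to/converted-glm")  # placeholder path

ids = tok("hello")["input_ids"]
# Converted with --use_post_processor (GLM-4 style): ids start with 151331 ("[gMASK]") and 151333 ("<sop>").
# Converted without the flag (the new default):      only the ByteLevel step runs, so no special tokens
#                                                    are prepended automatically.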
17 changes: 10 additions & 7 deletions src/transformers/models/glm/modeling_glm.py
@@ -169,13 +169,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)

# Keep half for later concatenation
q, q_pass = q[..., : q.shape[-1] // 2], q[..., q.shape[-1] // 2 :]
k, k_pass = k[..., : k.shape[-1] // 2], k[..., k.shape[-1] // 2 :]
# Keep half or full tensor for later concatenation
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

# Apply rotary embeddings on the first half
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
# Apply rotary embeddings on the first half or full tensor
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
@@ -705,7 +706,9 @@ def __init__(self, config: GlmConfig):
)
self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = GlmRotaryEmbedding(
dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
dim=int(config.head_dim * config.partial_rotary_factor),
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
)
self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
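The effect of the slicing change, shown as a standalone sketch rather than the library code (it uses the common half-split rotate_half for brevity, not GLM's interleaved variant, and dummy cos/sin values): when partial_rotary_factor is 0.5 the pass-through slice reproduces the old hard-coded behaviour, and when it is 1.0 the slice is empty so the whole head is rotated.

import torch

def rotate_half(x):
    # Simplified rotate_half for illustration only.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len = 8, 4
q = torch.randn(1, 1, seq_len, head_dim)

for partial_rotary_factor in (0.5, 1.0):
    rotary_dim = int(head_dim * partial_rotary_factor)  # mirrors dim=int(config.head_dim * config.partial_rotary_factor)
    cos = torch.ones(1, 1, seq_len, rotary_dim)          # dummy values with the shape apply_rotary_pos_emb sees
    sin = torch.zeros(1, 1, seq_len, rotary_dim)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
    print(partial_rotary_factor, "untouched dims:", q_pass.shape[-1], "output shape:", tuple(q_embed.shape))
    # 0.5 -> 4 untouched dims (old behaviour); 1.0 -> 0 untouched dims (full rotary)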
17 changes: 10 additions & 7 deletions src/transformers/models/glm/modular_glm.py
@@ -95,13 +95,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)

# Keep half for later concatenation
q, q_pass = q[..., : q.shape[-1] // 2], q[..., q.shape[-1] // 2 :]
k, k_pass = k[..., : k.shape[-1] // 2], k[..., k.shape[-1] // 2 :]
# Keep half or full tensor for later concatenation
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

# Apply rotary embeddings on the first half
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
# Apply rotary embeddings on the first half or full tensor
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
@@ -152,7 +153,9 @@ def __init__(self, config: GlmConfig):
)
self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = GlmRotaryEmbedding(
dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
dim=int(config.head_dim * config.partial_rotary_factor),
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
)
self.gradient_checkpointing = False
