Block-pruned qwen2 model does not work with torch-pruning #436

Open · lifelongeeek opened this issue Nov 19, 2024 · 2 comments

@lifelongeeek

I found that a Qwen-2 pruning example was recently added to torch-pruning: https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs#rocket-qwenqwen2-7b
Thanks for updating!

I tried this script on a block-pruned (4 blocks) Qwen-2 architecture.

  • Original architecture
Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3080)
    (layers): ModuleList(
      (0-3): 4 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=3080, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3080, out_features=512, bias=True)
          (v_proj): Linear(in_features=3080, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3080, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3080, out_features=16288, bias=False)
          (up_proj): Linear(in_features=3080, out_features=16288, bias=False)
          (down_proj): Linear(in_features=16288, out_features=3080, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3080,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3080,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3080,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3080, out_features=152064, bias=False)
)

However, when I prune this model with

python prune_llm.py --model $BLOCK_PRUNED_MODEL_PATH --pruning_ratio $PRUNE_RATIO --max_seq_len 4096 --save_model $SAVE_MODEL_PATH

reloading the saved model via AutoModelForCausalLM.from_pretrained() fails with one of the following (a short sketch of the failing reload is given after the error log below):

  • case 1) hidden size / head count mismatch
    • e.g., PRUNE_RATIO=0.12
    • ValueError: hidden_size must be divisible by num_heads (got hidden_size: 3152 and num_heads: 28).
  • case 2) state_dict vs. model weight shape mismatch
    • e.g., PRUNE_RATIO=0.14
RuntimeError: Error(s) in loading state_dict for Qwen2ForCausalLM:
	size mismatch for model.layers.0.self_attn.q_proj.weight: copying a param with shape torch.Size([3584, 3080]) from checkpoint, the shape in current model is torch.Size([3080, 3080]).
	size mismatch for model.layers.0.self_attn.q_proj.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3080]).
	size mismatch for model.layers.0.self_attn.k_proj.weight: copying a param with shape torch.Size([512, 3080]) from checkpoint, the shape in current model is torch.Size([440, 3080]).
	size mismatch for model.layers.0.self_attn.k_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([440]).
	size mismatch for model.layers.0.self_attn.v_proj.weight: copying a param with shape torch.Size([512, 3080]) from checkpoint, the shape in current model is torch.Size([440, 3080]).
	size mismatch for model.layers.0.self_attn.v_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([440]).
	size mismatch for model.layers.0.self_attn.o_proj.weight: copying a param with shape torch.Size([3080, 3584]) from checkpoint, the shape in current model is torch.Size([3080, 3080]).
	size mismatch for model.layers.0.mlp.gate_proj.weight: copying a param with shape torch.Size([16288, 3080]) from checkpoint, the shape in current model is torch.Size([18944, 3080]).
	size mismatch for model.layers.0.mlp.up_proj.weight: copying a param with shape torch.Size([16288, 3080]) from checkpoint, the shape in current model is torch.Size([18944, 3080]).
	size mismatch for model.layers.0.mlp.down_proj.weight: copying a param with shape torch.Size([3080, 16288]) from checkpoint, the shape in current model is torch.Size([3080, 18944]).
	size mismatch for model.layers.1.self_attn.q_proj.weight: copying a param with shape torch.Size([3584, 3080]) from checkpoint, the shape in current model is torch.Size([3080, 3080]).
	size mismatch for model.layers.1.self_attn.q_proj.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3080]).
	size mismatch for model.layers.1.self_attn.k_proj.weight: copying a param with shape torch.Size([512, 3080]) from checkpoint, the shape in current model is torch.Size([440, 3080]).
	size mismatch for model.layers.1.self_attn.k_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([440]).
	size mismatch for model.layers.1.self_attn.v_proj.weight: copying a param with shape torch.Size([512, 3080]) from checkpoint, the shape in current model is torch.Size([440, 3080]).
	size mismatch for model.layers.1.self_attn.v_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([440]).
	size mismatch for model.layers.1.self_attn.o_proj.weight: copying a param with shape torch.Size([3080, 3584]) from checkpoint, the shape in current model is torch.Size([3080, 3080]).
	size mismatch for model.layers.1.mlp.gate_proj.weight: copying a param with shape torch.Size([16288, 3080]) from checkpoint, the shape in current model is torch.Size([18944, 3080]).
	size mismatch for model.layers.1.mlp.up_proj.weight: copying a param with shape torch.Size([16288, 3080]) from checkpoint, the shape in current model is torch.Size([18944, 3080]).
	size mismatch for model.layers.1.mlp.down_proj.weight: copying a param with shape torch.Size([3080, 16288]) from checkpoint, the shape in current model is torch.Size([3080, 18944]).
	size mismatch for model.layers.2.self_attn.q_proj.weight: copying a param with shape torch.Size([3584, 3080]) from checkpoint, the shape in current model is torch.Size([3080, 3080]).
	size mismatch for model.layers.2.self_attn.q_proj.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3080]).
	size mismatch for model.layers.2.self_attn.k_proj.weight: copying a param with shape torch.Size([512, 3080]) from checkpoint, the shape in current model is torch.Size([440, 3080]).
	size mismatch for model.layers.2.self_attn.k_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([440]).
	size mismatch for model.layers.2.self_attn.v_proj.weight: copying a param with shape torch.Size([512, 3080]) from checkpoint, the shape in current model is torch.Size([440, 3080]).
	size mismatch for model.layers.2.self_attn.v_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([440]).
	size mismatch for model.layers.2.self_attn.o_proj.weight: copying a param with shape torch.Size([3080, 3584]) from checkpoint, the shape in current model is torch.Size([3080, 3080]).
	size mismatch for model.layers.2.mlp.gate_proj.weight: copying a param with shape torch.Size([16288, 3080]) from checkpoint, the shape in current model is torch.Size([18944, 3080]).
	size mismatch for model.layers.2.mlp.up_proj.weight: copying a param with shape torch.Size([16288, 3080]) from checkpoint, the shape in current model is torch.Size([18944, 3080]).
	size mismatch for model.layers.2.mlp.down_proj.weight: copying a param with shape torch.Size([3080, 16288]) from checkpoint, the shape in current model is torch.Size([3080, 18944]).
	size mismatch for model.layers.3.self_attn.q_proj.weight: copying a param with shape torch.Size([3584, 3080]) from checkpoint, the shape in current model is torch.Size([3080, 3080]).
	size mismatch for model.layers.3.self_attn.q_proj.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3080]).
	size mismatch for model.layers.3.self_attn.k_proj.weight: copying a param with shape torch.Size([512, 3080]) from checkpoint, the shape in current model is torch.Size([440, 3080]).
	size mismatch for model.layers.3.self_attn.k_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([440]).
	size mismatch for model.layers.3.self_attn.v_proj.weight: copying a param with shape torch.Size([512, 3080]) from checkpoint, the shape in current model is torch.Size([440, 3080]).
	size mismatch for model.layers.3.self_attn.v_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([440]).
	size mismatch for model.layers.3.self_attn.o_proj.weight: copying a param with shape torch.Size([3080, 3584]) from checkpoint, the shape in current model is torch.Size([3080, 3080]).
	size mismatch for model.layers.3.mlp.gate_proj.weight: copying a param with shape torch.Size([16288, 3080]) from checkpoint, the shape in current model is torch.Size([18944, 3080]).
	size mismatch for model.layers.3.mlp.up_proj.weight: copying a param with shape torch.Size([16288, 3080]) from checkpoint, the shape in current model is torch.Size([18944, 3080]).
	size mismatch for model.layers.3.mlp.down_proj.weight: copying a param with shape torch.Size([3080, 16288]) from checkpoint, the shape in current model is torch.Size([3080, 18944]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.  
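
For reference, a minimal sketch of the failing reload and of the check behind case 1 (the path is a placeholder; the numbers are taken from the errors above):

from transformers import AutoModelForCausalLM

SAVE_MODEL_PATH = "path/to/pruned-qwen2"  # placeholder for the --save_model directory above

# Case 1 reduces to the head-divisibility check during config/attention setup:
hidden_size, num_heads = 3152, 28   # values from the PRUNE_RATIO=0.12 error
print(hidden_size % num_heads)      # 16 -> not divisible, so a ValueError is raised

# This is the reload that fails (case 1 or case 2, depending on PRUNE_RATIO):
model = AutoModelForCausalLM.from_pretrained(SAVE_MODEL_PATH)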

Could you suggest how to properly prune a block-pruned qwen-2 model?
FYI, my environment is:

torch 2.5.1
transformers 4.46.3
accelerate 1.1.1
@VainF (Owner) commented Nov 19, 2024

Hi @lifelongeeek, thanks for the information. There was indeed a bug in the config: wrong values were assigned to hidden_size and intermediate_size. This has now been resolved in this commit.

However, I am seeing some new issues with offloading. I am not sure whether they are triggered by limited GPU memory; I will check this later.
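
For context, roughly the kind of sync such a fix needs to perform (a hedged sketch, not the actual commit; module and attribute names follow the printouts above):

# After pruning, the saved config must match the pruned module shapes, e.g.:
pruned = model  # the Qwen2ForCausalLM instance after torch-pruning has run
pruned.config.hidden_size = pruned.model.embed_tokens.weight.shape[1]
pruned.config.intermediate_size = pruned.model.layers[0].mlp.gate_proj.out_features
pruned.save_pretrained(save_model_path)  # save_model_path: the --save_model argument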

@lifelongeeek (Author)

With the commit you mentioned, I still see the error message for --pruning_ratio=0.1:
ValueError: hidden_size must be divisible by num_heads (got hidden_size: 3224 and num_heads: 28).

Here is the model before/after pruning, followed by the saved config (a small divisibility check is sketched after the config dump below).

----------------- Before Pruning -----------------
Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-3): 4 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3584, out_features=152064, bias=False)
)

----------------- After Pruning -----------------
Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3224)
    (layers): ModuleList(
      (0-3): 4 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=3224, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3224, out_features=512, bias=True)
          (v_proj): Linear(in_features=3224, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3224, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3224, out_features=17048, bias=False)
          (up_proj): Linear(in_features=3224, out_features=17048, bias=False)
          (down_proj): Linear(in_features=17048, out_features=3224, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3224,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3224,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3224,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3224, out_features=152064, bias=False)
)
Qwen2Config {
  "_attn_implementation_autoset": true,
  "_name_or_path": "ed-nt/qwen-2",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 3224,
  "initializer_range": 0.02,
  "intermediate_size": 17048,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 4,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.46.3",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 152064
}
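
For what it's worth, 3224 is not a multiple of 28 (remainder 4). A hypothetical helper (not part of prune_llm.py; name and rounding strategy are my own) that keeps the pruned width a multiple of num_attention_heads would pass this check:

# Hypothetical helper: round the pruned hidden size to a multiple of num_attention_heads
# so that the reloaded Qwen2 config accepts it.
def nearest_valid_hidden(orig_hidden: int, ratio: float, num_heads: int) -> int:
    target = orig_hidden * (1.0 - ratio)   # e.g. 3584 * 0.9 = 3225.6
    return round(target / num_heads) * num_heads

print(nearest_valid_hidden(3584, 0.1, 28))  # 3220, divisible by 28 (unlike 3224 above)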
