Why doesn't my QAT convert work? Still float32 #1290

Xxxgrey opened this issue Nov 15, 2024 · 5 comments

Xxxgrey commented Nov 15, 2024

I tried the original QAT code:

import torch
from torchtune.models.llama3 import llama3
# Note: the torchao import path below may differ depending on your version.
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

model = llama3(
    vocab_size=4096,
    num_layers=16,
    num_heads=16,
    num_kv_heads=4,
    embed_dim=2048,
    max_seq_len=2048,
).cuda()
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting.
model = qat_quantizer.prepare(model)

for name, param in model.named_parameters():
    print(f"Layer: {name}, Data type: {param.dtype}")

# Standard training loop
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
# loss_fn = torch.nn.CrossEntropyLoss()
# for epoch in range(num_epochs):
#     train_model(model, train_loader, optimizer, loss_fn, num_epochs)

print("================================================================")
model = qat_quantizer.convert(model)
for name, param in model.named_parameters():
    print(f"Layer: {name}, Data type: {param.dtype}")

But the output shows:

Layer: tok_embeddings.weight, Data type: torch.float32
Layer: layers.0.sa_norm.scale, Data type: torch.float32
Layer: layers.0.attn.q_proj.weight, Data type: torch.float32
Layer: layers.0.attn.k_proj.weight, Data type: torch.float32
Layer: layers.0.attn.v_proj.weight, Data type: torch.float32
Layer: layers.0.attn.output_proj.weight, Data type: torch.float32
Layer: layers.0.mlp_norm.scale, Data type: torch.float32
Layer: layers.0.mlp.w1.weight, Data type: torch.float32
Layer: layers.0.mlp.w2.weight, Data type: torch.float32
Layer: layers.0.mlp.w3.weight, Data type: torch.float32
Layer: layers.1.sa_norm.scale, Data type: torch.float32
Layer: layers.1.attn.q_proj.weight, Data type: torch.float32
Layer: layers.1.attn.k_proj.weight, Data type: torch.float32
Layer: layers.1.attn.v_proj.weight, Data type: torch.float32
Layer: layers.1.attn.output_proj.weight, Data type: torch.float32
Layer: layers.1.mlp_norm.scale, Data type: torch.float32
Layer: layers.1.mlp.w1.weight, Data type: torch.float32
Layer: layers.1.mlp.w2.weight, Data type: torch.float32
Layer: layers.1.mlp.w3.weight, Data type: torch.float32
Layer: layers.2.sa_norm.scale, Data type: torch.float32
Layer: layers.2.attn.q_proj.weight, Data type: torch.float32
Layer: layers.2.attn.k_proj.weight, Data type: torch.float32
Layer: layers.2.attn.v_proj.weight, Data type: torch.float32
Layer: layers.2.attn.output_proj.weight, Data type: torch.float32
Layer: layers.2.mlp_norm.scale, Data type: torch.float32
Layer: layers.2.mlp.w1.weight, Data type: torch.float32
Layer: layers.2.mlp.w2.weight, Data type: torch.float32
Layer: layers.2.mlp.w3.weight, Data type: torch.float32
Layer: layers.3.sa_norm.scale, Data type: torch.float32
Layer: layers.3.attn.q_proj.weight, Data type: torch.float32
Layer: layers.3.attn.k_proj.weight, Data type: torch.float32
Layer: layers.3.attn.v_proj.weight, Data type: torch.float32
Layer: layers.3.attn.output_proj.weight, Data type: torch.float32
Layer: layers.3.mlp_norm.scale, Data type: torch.float32
Layer: layers.3.mlp.w1.weight, Data type: torch.float32
Layer: layers.3.mlp.w2.weight, Data type: torch.float32
Layer: layers.3.mlp.w3.weight, Data type: torch.float32
Layer: layers.4.sa_norm.scale, Data type: torch.float32
Layer: layers.4.attn.q_proj.weight, Data type: torch.float32
Layer: layers.4.attn.k_proj.weight, Data type: torch.float32
Layer: layers.4.attn.v_proj.weight, Data type: torch.float32
Layer: layers.4.attn.output_proj.weight, Data type: torch.float32
Layer: layers.4.mlp_norm.scale, Data type: torch.float32
Layer: layers.4.mlp.w1.weight, Data type: torch.float32
Layer: layers.4.mlp.w2.weight, Data type: torch.float32
Layer: layers.4.mlp.w3.weight, Data type: torch.float32
Layer: layers.5.sa_norm.scale, Data type: torch.float32
Layer: layers.5.attn.q_proj.weight, Data type: torch.float32
Layer: layers.5.attn.k_proj.weight, Data type: torch.float32
Layer: layers.5.attn.v_proj.weight, Data type: torch.float32
Layer: layers.5.attn.output_proj.weight, Data type: torch.float32
Layer: layers.5.mlp_norm.scale, Data type: torch.float32
Layer: layers.5.mlp.w1.weight, Data type: torch.float32
Layer: layers.5.mlp.w2.weight, Data type: torch.float32
Layer: layers.5.mlp.w3.weight, Data type: torch.float32
Layer: layers.6.sa_norm.scale, Data type: torch.float32
Layer: layers.6.attn.q_proj.weight, Data type: torch.float32
Layer: layers.6.attn.k_proj.weight, Data type: torch.float32
Layer: layers.6.attn.v_proj.weight, Data type: torch.float32
Layer: layers.6.attn.output_proj.weight, Data type: torch.float32
Layer: layers.6.mlp_norm.scale, Data type: torch.float32
Layer: layers.6.mlp.w1.weight, Data type: torch.float32
Layer: layers.6.mlp.w2.weight, Data type: torch.float32
Layer: layers.6.mlp.w3.weight, Data type: torch.float32
Layer: layers.7.sa_norm.scale, Data type: torch.float32
Layer: layers.7.attn.q_proj.weight, Data type: torch.float32
Layer: layers.7.attn.k_proj.weight, Data type: torch.float32
Layer: layers.7.attn.v_proj.weight, Data type: torch.float32
Layer: layers.7.attn.output_proj.weight, Data type: torch.float32
Layer: layers.7.mlp_norm.scale, Data type: torch.float32
Layer: layers.7.mlp.w1.weight, Data type: torch.float32
Layer: layers.7.mlp.w2.weight, Data type: torch.float32
Layer: layers.7.mlp.w3.weight, Data type: torch.float32
Layer: layers.8.sa_norm.scale, Data type: torch.float32
Layer: layers.8.attn.q_proj.weight, Data type: torch.float32
Layer: layers.8.attn.k_proj.weight, Data type: torch.float32
Layer: layers.8.attn.v_proj.weight, Data type: torch.float32
Layer: layers.8.attn.output_proj.weight, Data type: torch.float32
Layer: layers.8.mlp_norm.scale, Data type: torch.float32
Layer: layers.8.mlp.w1.weight, Data type: torch.float32
Layer: layers.8.mlp.w2.weight, Data type: torch.float32
Layer: layers.8.mlp.w3.weight, Data type: torch.float32
Layer: layers.9.sa_norm.scale, Data type: torch.float32
Layer: layers.9.attn.q_proj.weight, Data type: torch.float32
Layer: layers.9.attn.k_proj.weight, Data type: torch.float32
Layer: layers.9.attn.v_proj.weight, Data type: torch.float32
Layer: layers.9.attn.output_proj.weight, Data type: torch.float32
Layer: layers.9.mlp_norm.scale, Data type: torch.float32
Layer: layers.9.mlp.w1.weight, Data type: torch.float32
Layer: layers.9.mlp.w2.weight, Data type: torch.float32
Layer: layers.9.mlp.w3.weight, Data type: torch.float32
Layer: layers.10.sa_norm.scale, Data type: torch.float32
Layer: layers.10.attn.q_proj.weight, Data type: torch.float32
Layer: layers.10.attn.k_proj.weight, Data type: torch.float32
Layer: layers.10.attn.v_proj.weight, Data type: torch.float32
Layer: layers.10.attn.output_proj.weight, Data type: torch.float32
Layer: layers.10.mlp_norm.scale, Data type: torch.float32
Layer: layers.10.mlp.w1.weight, Data type: torch.float32
Layer: layers.10.mlp.w2.weight, Data type: torch.float32
Layer: layers.10.mlp.w3.weight, Data type: torch.float32
Layer: layers.11.sa_norm.scale, Data type: torch.float32
Layer: layers.11.attn.q_proj.weight, Data type: torch.float32
Layer: layers.11.attn.k_proj.weight, Data type: torch.float32
Layer: layers.11.attn.v_proj.weight, Data type: torch.float32
Layer: layers.11.attn.output_proj.weight, Data type: torch.float32
Layer: layers.11.mlp_norm.scale, Data type: torch.float32
Layer: layers.11.mlp.w1.weight, Data type: torch.float32
Layer: layers.11.mlp.w2.weight, Data type: torch.float32
Layer: layers.11.mlp.w3.weight, Data type: torch.float32
Layer: layers.12.sa_norm.scale, Data type: torch.float32
Layer: layers.12.attn.q_proj.weight, Data type: torch.float32
Layer: layers.12.attn.k_proj.weight, Data type: torch.float32
Layer: layers.12.attn.v_proj.weight, Data type: torch.float32
Layer: layers.12.attn.output_proj.weight, Data type: torch.float32
Layer: layers.12.mlp_norm.scale, Data type: torch.float32
Layer: layers.12.mlp.w1.weight, Data type: torch.float32
Layer: layers.12.mlp.w2.weight, Data type: torch.float32
Layer: layers.12.mlp.w3.weight, Data type: torch.float32
Layer: layers.13.sa_norm.scale, Data type: torch.float32
Layer: layers.13.attn.q_proj.weight, Data type: torch.float32
Layer: layers.13.attn.k_proj.weight, Data type: torch.float32
Layer: layers.13.attn.v_proj.weight, Data type: torch.float32
Layer: layers.13.attn.output_proj.weight, Data type: torch.float32
Layer: layers.13.mlp_norm.scale, Data type: torch.float32
Layer: layers.13.mlp.w1.weight, Data type: torch.float32
Layer: layers.13.mlp.w2.weight, Data type: torch.float32
Layer: layers.13.mlp.w3.weight, Data type: torch.float32
Layer: layers.14.sa_norm.scale, Data type: torch.float32
Layer: layers.14.attn.q_proj.weight, Data type: torch.float32
Layer: layers.14.attn.k_proj.weight, Data type: torch.float32
Layer: layers.14.attn.v_proj.weight, Data type: torch.float32
Layer: layers.14.attn.output_proj.weight, Data type: torch.float32
Layer: layers.14.mlp_norm.scale, Data type: torch.float32
Layer: layers.14.mlp.w1.weight, Data type: torch.float32
Layer: layers.14.mlp.w2.weight, Data type: torch.float32
Layer: layers.14.mlp.w3.weight, Data type: torch.float32
Layer: layers.15.sa_norm.scale, Data type: torch.float32
Layer: layers.15.attn.q_proj.weight, Data type: torch.float32
Layer: layers.15.attn.k_proj.weight, Data type: torch.float32
Layer: layers.15.attn.v_proj.weight, Data type: torch.float32
Layer: layers.15.attn.output_proj.weight, Data type: torch.float32
Layer: layers.15.mlp_norm.scale, Data type: torch.float32
Layer: layers.15.mlp.w1.weight, Data type: torch.float32
Layer: layers.15.mlp.w2.weight, Data type: torch.float32
Layer: layers.15.mlp.w3.weight, Data type: torch.float32
Layer: norm.scale, Data type: torch.float32
Layer: output.weight, Data type: torch.float32
================================================================
Layer: tok_embeddings.weight, Data type: torch.float32
Layer: layers.0.sa_norm.scale, Data type: torch.float32
Layer: layers.0.mlp_norm.scale, Data type: torch.float32
Layer: layers.1.sa_norm.scale, Data type: torch.float32
Layer: layers.1.mlp_norm.scale, Data type: torch.float32
Layer: layers.2.sa_norm.scale, Data type: torch.float32
Layer: layers.2.mlp_norm.scale, Data type: torch.float32
Layer: layers.3.sa_norm.scale, Data type: torch.float32
Layer: layers.3.mlp_norm.scale, Data type: torch.float32
Layer: layers.4.sa_norm.scale, Data type: torch.float32
Layer: layers.4.mlp_norm.scale, Data type: torch.float32
Layer: layers.5.sa_norm.scale, Data type: torch.float32
Layer: layers.5.mlp_norm.scale, Data type: torch.float32
Layer: layers.6.sa_norm.scale, Data type: torch.float32
Layer: layers.6.mlp_norm.scale, Data type: torch.float32
Layer: layers.7.sa_norm.scale, Data type: torch.float32
Layer: layers.7.mlp_norm.scale, Data type: torch.float32
Layer: layers.8.sa_norm.scale, Data type: torch.float32
Layer: layers.8.mlp_norm.scale, Data type: torch.float32
Layer: layers.9.sa_norm.scale, Data type: torch.float32
Layer: layers.9.mlp_norm.scale, Data type: torch.float32
Layer: layers.10.sa_norm.scale, Data type: torch.float32
Layer: layers.10.mlp_norm.scale, Data type: torch.float32
Layer: layers.11.sa_norm.scale, Data type: torch.float32
Layer: layers.11.mlp_norm.scale, Data type: torch.float32
Layer: layers.12.sa_norm.scale, Data type: torch.float32
Layer: layers.12.mlp_norm.scale, Data type: torch.float32
Layer: layers.13.sa_norm.scale, Data type: torch.float32
Layer: layers.13.mlp_norm.scale, Data type: torch.float32
Layer: layers.14.sa_norm.scale, Data type: torch.float32
Layer: layers.14.mlp_norm.scale, Data type: torch.float32
Layer: layers.15.sa_norm.scale, Data type: torch.float32
Layer: layers.15.mlp_norm.scale, Data type: torch.float32
Layer: norm.scale, Data type: torch.float32

It doesn't convert to int8 or int4.

jerryzh168 (Contributor) commented

This is because we are producing a model that's going to be lowered to ExecuTorch for speedup, I think. Here is the doc for QAT: https://pytorch.org/blog/quantization-aware-training/
cc @andrewor14: is there a doc for lowering a QAT model to ExecuTorch?

Xxxgrey commented Nov 15, 2024

This is because we are producing a model that's going to be lowered to ExecuTorch for speedup, I think. Here is the doc for QAT: https://pytorch.org/blog/quantization-aware-training/ cc @andrewor14: is there a doc for lowering a QAT model to ExecuTorch?

Do I need to use torch.quantization.convert() to quantize my model to int8?

andrewor14 self-assigned this Nov 15, 2024
andrewor14 (Contributor) commented

Hi @Xxxgrey, I see you printed the parameter dtypes twice:

  • Right after prepare: At this step everything should be in the original (float) precision, so what you're seeing is expected. This is because the prepare step just inserts "fake" quantization into the linear layers of the model, and this only simulates the numerics of quantization without casting the activations/weights to lower bit-widths.
  • Right after convert: If you look closely at what's printed, the linear weights are not part of the parameters anymore. The Int8DynActInt4WeightQATQuantizer you're using only quantizes the linear layers, so the other parameters are left untouched (and stay at fp32). After calling convert, the linear weights are now quantized and no longer trainable. You can confirm this by printing out the linear weight dtypes specifically:
>>> type(model.layers[0].attn.q_proj.weight)
<class 'torch.Tensor'>
>>> model.layers[0].attn.q_proj.weight.dtype
torch.int8

(Note that this is in int8 not int4 because torch.int4 wasn't natively supported yet when this flow was built. We will update this in the future.)
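
A quick way to confirm this across the whole model (a minimal sketch using only plain PyTorch, no torchao-specific API) is to walk the modules and print any weight tensor that is no longer an nn.Parameter:

import torch

# After convert, the quantized linear weights are plain tensors rather than
# nn.Parameters, so named_parameters() no longer lists them. Walk the modules
# and inspect the weight attributes directly instead.
for name, module in model.named_modules():
    weight = getattr(module, "weight", None)
    if isinstance(weight, torch.Tensor) and not isinstance(weight, torch.nn.Parameter):
        print(f"{name}: weight dtype = {weight.dtype}")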

Xxxgrey commented Nov 16, 2024

Hi @Xxxgrey, I see you printed the parameter dtypes twice:

  • Right after prepare: At this step everything should be in the original (float) precision, so what you're seeing is expected. This is because the prepare step just inserts "fake" quantization into the linear layers of the model, and this only simulates the numerics of quantization without casting the activations/weights to lower bit-widths.
  • Right after convert: If you look closely at what's printed, the linear weights are not part of the parameters anymore. The Int8DynActInt4WeightQATQuantizer you're using only quantizes the linear layers, so the other parameters are left untouched (and stay at fp32). After calling convert, the linear weights are now quantized and no longer trainable. You can confirm this by printing out the linear weight dtypes specifically:
>>> type(model.layers[0].attn.q_proj.weight)
<class 'torch.Tensor'>
>>> model.layers[0].attn.q_proj.weight.dtype
torch.int8

(Note that this is in int8 not int4 because torch.int4 wasn't natively supported yet when this flow was built. We will update this in the future.)

I got it. So is there any QAT method that doesn't only quantize the linear layers? I only see these; they all seem to quantize only the linear layers:

"ComposableQATQuantizer",
    "Int4WeightOnlyQATQuantizer",
    "Int4WeightOnlyEmbeddingQATQuantizer",
    "Int8DynActInt4WeightQATQuantizer",

andrewor14 (Contributor) commented

I got it. So is there any QAT method that doesn't only quantize the linear layers? I only see these; they all seem to quantize only the linear layers.

Yeah today we only have support for linear and embedding layers:

# linear:
Int4WeightOnlyQATQuantizer
Int8DynActInt4WeightQATQuantizer

# embedding:
Int4WeightOnlyEmbeddingQATQuantizer
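
ComposableQATQuantizer from the list above can combine these, so that both linear and embedding layers are fake-quantized during training. A rough sketch (untested here; the exact import path may differ across torchao versions):

from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

# Compose the linear and embedding QAT quantizers; prepare/convert are then
# applied by each child quantizer in turn.
qat_quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(),
    Int4WeightOnlyEmbeddingQATQuantizer(),
])
model = qat_quantizer.prepare(model)
# ... training loop ...
model = qat_quantizer.convert(model)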

If you need other bit-width / quantization schemes, you can also use the generic FakeQuantizedLinear and FakeQuantizedEmbedding classes directly. Here's an example usage.
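
For reference, a minimal sketch of that kind of usage, assuming the FakeQuantizeConfig / FakeQuantizedLinear API (constructor arguments, import paths, and torch.int4 availability may differ depending on the torchao and PyTorch versions):

import torch
from torchao.quantization.qat import FakeQuantizeConfig
from torchao.quantization.qat.linear import FakeQuantizedLinear

# int8 dynamic per-token fake quantization for activations,
# int4 grouped (group_size=32) fake quantization for weights.
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_dynamic=True)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# (in_features, out_features, bias, activation_config, weight_config)
fq_linear = FakeQuantizedLinear(256, 256, False, activation_config, weight_config)
out = fq_linear(torch.randn(8, 256))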
