模型量化策略:PTQ、QAT、LSQ
对于一个float x,将其写成:$x=S(q_x-Z)$,其中S是scale,Z是zero-point。
对于一个浮点数运算,比如$y=ax+b$,其中a和b也需要各自量化为定点数q,那么可以全部带入计算,就会发现量化后的$q_y$可以用$q_a,q_x,q_b$的计算加上某个缩放得到,因此,可以在计算过程中只使用定点数,然后将结果反量化,即可得到输出结果。
需要注意的一点是:定点数范围有限,需要防止溢出,提前确定好合适的bit位数。
最简单的是PTQ,它的基本思路就是找到每个层的S和Z,进行缩放。这个过程通过给定一些输入数据用来做校准,主要涉及feature map和激活函数输出的min max的计算。
为了更好利用数据信息,得到量化损失更小的量化策略,可以用QAT。
QAT的基本原理是:在训练过程中,对待量化的位置插入伪量化节点(参数由训练过程中当前的真实的分布计算得到),然后进行浮点训练,更新待量化参数。
这个过程主要是希望fake_quant节点的操作,模拟真实量化的情况,使得网络在面对量化时尽可能保持精度。
QAT训练完成后,将伪量化节点撤掉,变成真正的INT8量化模型。
注意:训练过程中采用了Straight-through estimator(STE)策略,因为quant操作不连续,因此直接将输出的导数直接给到输入。
将scale也加入梯度反传,从而直接优化更新scale(不像一般的QAT中利用feature算quant节点参数,然后量化前传,然后将loss直接传给feature,避免quant带来的不连续)。