引言
对神经网络进行量化可以减小模型参数量和提升速度,并且一些硬件设备只支持特定的量化网络,因此在一些场景中常被采用。最常见的是对一个已经训练好的模型进行处理来做到int8量化,但这种方式一般会对网络的效果有些影响,特别是对一些图像生成任务。所以希望通过量化感知训练的方式能做一些特殊处理,以减少量化带来的效果影响。但是像pytorch这样的框架吗目前对量化感知训练支持的还不是特别好, 另外一些特殊硬件设备的量化策略跟常用的一些量化策略不太一样,因此最近在学习这方面的内容,在这其中踩了一些数值稳定性方面的坑,分享给大家。
网络的计算量很大部分来自于卷积,而卷积的计算过程可以由下面的公式表示
y = W x + b (1) y=Wx+b \tag{1} y=Wx+b(1)
其中W是卷积的权重,b是偏置bias。以int8为例,最简单的就是非对称量化,也就是希望将W的范围其量化到 [ 0 , 255 ] [0,255] [0,255],
s c a l e = 255 m a x ( W ) − m i n ( W ) (2) scale = \frac{255}{max(W)-min(W)} \tag{2} scale=max(W)−min(W)255(2)
因此量化后的权重计算就可以用
w q = ( w − m i n ( W ) ) ∗ s c a l e = w ∗ s c a l e − m i n ( W ) ∗ s c a l e = w ∗ s c a l e + z e r o _ p t (3) w_q=(w-min(W))*scale=w*scale - min(W)*scale = w*scale+zero\_pt \tag{3} wq=(w−min(W))∗scale=w∗scale−min(W)∗scale=w∗scale+zero_pt(3)
其中scale一般是个浮点值,zero_pt会被量化成一个int8的值
bias经过研究发现最好不要量化[2018-arxiv] Quantizing deep convolutional networks for efficient inference: A whitepaper,不然对效果有较大影响
BatchNorm带来的影响
目前的主流网络一般都会带batchnorm,batchnorm在推理的时候会被融入到前面的卷积层。如果我们先对 W W W量化到了255,但是又经过了batchnorm中的 γ \gamma γ去缩放,那么量化的范围有可能会被集中或者上下溢出,导致量化效果不好。
BatchNorm的功能在推理的时候会固定住均值方差,可以用下面的公式表示
y = x − μ g σ g ∗ γ + β (4) y=\frac{x-\mu_g}{\sigma_g}*\gamma+\beta \tag{4} y=σgx−μg∗γ+β(4)
而训练的时候会使用当前batch的均值方差,表示成以下形式
y = x − μ b σ b ∗ γ + β (5) y=\frac{x-\mu_b}{\sigma_b}*\gamma+\beta \tag{5} y=σbx−μb∗γ+β(5)
由于卷积层融合BN后,卷积层的bias没有意义,会被均值减掉,因此卷积成不会带bias,所以融合卷积后整个公式变成了
y = W x − μ g σ g ∗ γ + β = γ W x σ g + ( β − γ μ g σ g ) (6) y=\frac{Wx-\mu_g}{\sigma_g}*\gamma+\beta=\frac{\gamma Wx}{\sigma_g}+(\beta-\frac{\gamma \mu_g}{\sigma_g}) \tag{6} y=σgWx−μg∗γ+β=σgγWx+(β−σgγμg)(6)
算法迭代
因此量化感知训练的时候需要模拟这里的缩放行为,因此写出了第一版的算法。
A.1 原始方法
W = γ W σ g W=\frac{\gamma W}{\sigma_g} W=σgγW
W q = Q u a n t i z e ( W ) W_q = Quantize(W) Wq=Quantize(W)
W d q = D e Q u a n t i z e ( W q ) W_{dq}=DeQuantize(W_q) Wdq=DeQuantize(Wq)
y = W d q x y=W_{dq}x y=Wdqx
y = y ∗ σ g γ y=y*\frac{\sigma_g}{\gamma} y=y∗γσg
y = B N ( y ) y=BN(y) y=BN(y)
A.1遇到的第一个问题就是NAN问题,NAN问题主要两点
第一点来自于A.1算法的第二行的Quantize,参考公式2,我们要算scale,但是如果权重的最高值和最低值一样的时候就会出现这样的问题
第二个带你来自于A.1算法的倒数第二行,由于 γ \gamma γ是一个实值,所以有可能是0,比如在残差网络中,有时候训练初始会把残差分支的第三个BN的 γ \gamma γ初始化成0
所以我想也没想,就用了eps方式,
A.2 一步 ϵ \epsilon ϵ法
W = γ W σ g W=\frac{\gamma W}{\sigma_g} W=σgγW
W q = Q u a n t i z e ( W ) W_q = Quantize(W) Wq=Quantize(W)
W d q = D e Q u a n t i z e ( W q ) W_{dq}=DeQuantize(W_q) Wdq=DeQuantize(Wq)
y = W d q x y=W_{dq}x y=Wdqx
y = y ∗ σ g γ + ϵ y=y*\frac{\sigma_g}{\gamma+\epsilon} y=y∗γ+ϵσg
y = B N ( y ) y=BN(y) y=BN(y)
然后发现在resnet50的训练是这样子的
Epoch | train_loss | eval_loss | eval_top1 | eval_top5 |
---|---|---|---|---|
0 | 6.898 | 6.843 | 0.468 | 2.188 |
1 | 6.124 | 5.463 | 6.386 | 17.77 |
2 | 5.604 | 5.145 | 9.84 | 24.02 |
3 | 5.330 | 5.663 | 6.9 | 17.964 |
4 | 5.144 | 8.263 | 2.264 | 6.764 |
5 | 5.593 | 15.387 | 0.132 | 0.666 |
6 | 6.680 | 6.966 | 0.224 | 1.144 |
可以看到loss下降后出现发散的情况我们后面会谈到这样修改的问题。
不过当时也没多想,写了接下来的变体
A.3 两步法
with no_grad():
y=Wx
Update_BN(y)
W = γ W σ g W=\frac{\gamma W}{\sigma_g} W=σgγW
W q = Q u a n t i z e ( W ) W_q = Quantize(W) Wq=Quantize(W)
W d q = D e Q u a n t i z e ( W q ) W_{dq}=DeQuantize(W_q) Wdq=DeQuantize(Wq)
y = W d q x y=W_{dq}x y=Wdqx
y = y ∗ σ b σ g + β − γ μ b σ b y=y*\frac{\sigma_b}{\sigma_g}+\beta-\frac{\gamma \mu_b}{\sigma_b} y=y∗σgσb+β−σbγμb
这样就规避了 γ \gamma γ的数值稳定性问题,但是发现训练的状况是这样的
Epoch | train_loss | eval_loss | eval_top1 | eval_top5 |
---|---|---|---|---|
0 | 6.914 | 6.882 | 0.324 | 1.258 |
1 | 6.128 | 5.470 | 6.385 | 17.452 |
2 | 5.603 | 5.195 | 8.869 | 21.891 |
3 | 5.350 | 5.716 | 6.310 | 16.723 |
4 | 5.185 | 7.754 | 0.546 | 1.94 |
5 | 5.088 | 9.702 | 0.19 | 0.836 |
6 | 4.978 | 10.906 | 0.258 | 1.26 |
可以看到虽然训练loss再缓慢下降,但是验证精度却没提升。我们猜测原因主要是在更新BN时候并没有用量化的输出,导致真实的BN参数是有bias的
一度个人以为是不是量化感知的训练就是这样的,不过经过多次实验写出了下面的变体,解决了上面的一些问题。A.2的问题其实之前也已经提到,核心还是 γ \gamma γ是实值,没有非负性保证,加个eps也不能保证除数不为0.
最终的方案其实也很简单
A.4 最终方案
W = γ W σ g W=\frac{\gamma W}{\sigma_g} W=σgγW
W q = Q u a n t i z e ( W ) W_q = Quantize(W) Wq=Quantize(W)
W d q = D e Q u a n t i z e ( W q ) W_{dq}=DeQuantize(W_q) Wdq=DeQuantize(Wq)
y = W d q x y=W_{dq}x y=Wdqx
y = y ∗ σ g ∗ f a k e _ c o f f f a k e _ γ y=y*\frac{\sigma_g*fake\_coff}{fake\_\gamma} y=y∗fake_γσg∗fake_coff
y = B N ( y ) y=BN(y) y=BN(y)
Epoch | train_loss | eval_loss | eval_top1 | eval_top5 |
---|---|---|---|---|
0 | 6.915 | 6.882 | 0.343 | 1.281 |
1 | 6.081 | 5.279 | 7.958 | 20.57 |
2 | 5.180 | 4.356 | 16.68 | 36.588 |
3 | 4.423 | 3.533 | 27.680 | 52.539 |
4 | 3.965 | 2.856 | 37.928 | 64.976 |
5 | 3.688 | 2.760 | 40.058 | 66.861 |
6 | 3.479 | 2.479 | 45.689 | 71.930 |
我们发下最终的训练结果跟非量化感知类似,top1下降0.2个点
来源:oschina
链接:https://my.oschina.net/u/4261335/blog/4463306