量化感知训练的数值稳定性

别等时光非礼了梦想. 提交于 2020-08-14 05:35:37

引言

对神经网络进行量化可以减小模型参数量和提升速度,并且一些硬件设备只支持特定的量化网络,因此在一些场景中常被采用。最常见的是对一个已经训练好的模型进行处理来做到int8量化,但这种方式一般会对网络的效果有些影响,特别是对一些图像生成任务。所以希望通过量化感知训练的方式能做一些特殊处理,以减少量化带来的效果影响。但是像pytorch这样的框架吗目前对量化感知训练支持的还不是特别好, 另外一些特殊硬件设备的量化策略跟常用的一些量化策略不太一样,因此最近在学习这方面的内容,在这其中踩了一些数值稳定性方面的坑,分享给大家。

网络的计算量很大部分来自于卷积,而卷积的计算过程可以由下面的公式表示
y = W x + b (1) y=Wx+b \tag{1} y=Wx+b(1)
其中W是卷积的权重,b是偏置bias。以int8为例,最简单的就是非对称量化,也就是希望将W的范围其量化到 [ 0 , 255 ] [0,255] [0,255]
s c a l e = 255 m a x ( W ) − m i n ( W ) (2) scale = \frac{255}{max(W)-min(W)} \tag{2} scale=max(W)min(W)255(2)
因此量化后的权重计算就可以用
w q = ( w − m i n ( W ) ) ∗ s c a l e = w ∗ s c a l e − m i n ( W ) ∗ s c a l e = w ∗ s c a l e + z e r o _ p t (3) w_q=(w-min(W))*scale=w*scale - min(W)*scale = w*scale+zero\_pt \tag{3} wq=(wmin(W))scale=wscalemin(W)scale=wscale+zero_pt(3)
其中scale一般是个浮点值,zero_pt会被量化成一个int8的值





bias经过研究发现最好不要量化[2018-arxiv] Quantizing deep convolutional networks for efficient inference: A whitepaper,不然对效果有较大影响

BatchNorm带来的影响

目前的主流网络一般都会带batchnorm,batchnorm在推理的时候会被融入到前面的卷积层。如果我们先对 W W W量化到了255,但是又经过了batchnorm中的 γ \gamma γ去缩放,那么量化的范围有可能会被集中或者上下溢出,导致量化效果不好。

BatchNorm的功能在推理的时候会固定住均值方差,可以用下面的公式表示
y = x − μ g σ g ∗ γ + β (4) y=\frac{x-\mu_g}{\sigma_g}*\gamma+\beta \tag{4} y=σgxμgγ+β(4)
而训练的时候会使用当前batch的均值方差,表示成以下形式
y = x − μ b σ b ∗ γ + β (5) y=\frac{x-\mu_b}{\sigma_b}*\gamma+\beta \tag{5} y=σbxμbγ+β(5)


由于卷积层融合BN后,卷积层的bias没有意义,会被均值减掉,因此卷积成不会带bias,所以融合卷积后整个公式变成了
y = W x − μ g σ g ∗ γ + β = γ W x σ g + ( β − γ μ g σ g ) (6) y=\frac{Wx-\mu_g}{\sigma_g}*\gamma+\beta=\frac{\gamma Wx}{\sigma_g}+(\beta-\frac{\gamma \mu_g}{\sigma_g}) \tag{6} y=σgWxμgγ+β=σgγWx+(βσgγμg)(6)

算法迭代

因此量化感知训练的时候需要模拟这里的缩放行为,因此写出了第一版的算法。


A.1 原始方法


W = γ W σ g W=\frac{\gamma W}{\sigma_g} W=σgγW

W q = Q u a n t i z e ( W ) W_q = Quantize(W) Wq=Quantize(W)

W d q = D e Q u a n t i z e ( W q ) W_{dq}=DeQuantize(W_q) Wdq=DeQuantize(Wq)

y = W d q x y=W_{dq}x y=Wdqx

y = y ∗ σ g γ y=y*\frac{\sigma_g}{\gamma} y=yγσg

y = B N ( y ) y=BN(y) y=BN(y)


A.1遇到的第一个问题就是NAN问题,NAN问题主要两点

第一点来自于A.1算法的第二行的Quantize,参考公式2,我们要算scale,但是如果权重的最高值和最低值一样的时候就会出现这样的问题

第二个带你来自于A.1算法的倒数第二行,由于 γ \gamma γ是一个实值,所以有可能是0,比如在残差网络中,有时候训练初始会把残差分支的第三个BN的 γ \gamma γ初始化成0

所以我想也没想,就用了eps方式,


A.2 一步 ϵ \epsilon ϵ


W = γ W σ g W=\frac{\gamma W}{\sigma_g} W=σgγW

W q = Q u a n t i z e ( W ) W_q = Quantize(W) Wq=Quantize(W)

W d q = D e Q u a n t i z e ( W q ) W_{dq}=DeQuantize(W_q) Wdq=DeQuantize(Wq)

y = W d q x y=W_{dq}x y=Wdqx

y = y ∗ σ g γ + ϵ y=y*\frac{\sigma_g}{\gamma+\epsilon} y=yγ+ϵσg

y = B N ( y ) y=BN(y) y=BN(y)


然后发现在resnet50的训练是这样子的

Epoch train_loss eval_loss eval_top1 eval_top5
0 6.898 6.843 0.468 2.188
1 6.124 5.463 6.386 17.77
2 5.604 5.145 9.84 24.02
3 5.330 5.663 6.9 17.964
4 5.144 8.263 2.264 6.764
5 5.593 15.387 0.132 0.666
6 6.680 6.966 0.224 1.144

可以看到loss下降后出现发散的情况我们后面会谈到这样修改的问题。

不过当时也没多想,写了接下来的变体


A.3 两步法


with no_grad():

​ y=Wx

Update_BN(y)

W = γ W σ g W=\frac{\gamma W}{\sigma_g} W=σgγW

W q = Q u a n t i z e ( W ) W_q = Quantize(W) Wq=Quantize(W)

W d q = D e Q u a n t i z e ( W q ) W_{dq}=DeQuantize(W_q) Wdq=DeQuantize(Wq)

y = W d q x y=W_{dq}x y=Wdqx

y = y ∗ σ b σ g + β − γ μ b σ b y=y*\frac{\sigma_b}{\sigma_g}+\beta-\frac{\gamma \mu_b}{\sigma_b} y=yσgσb+βσbγμb


这样就规避了 γ \gamma γ的数值稳定性问题,但是发现训练的状况是这样的

Epoch train_loss eval_loss eval_top1 eval_top5
0 6.914 6.882 0.324 1.258
1 6.128 5.470 6.385 17.452
2 5.603 5.195 8.869 21.891
3 5.350 5.716 6.310 16.723
4 5.185 7.754 0.546 1.94
5 5.088 9.702 0.19 0.836
6 4.978 10.906 0.258 1.26

可以看到虽然训练loss再缓慢下降,但是验证精度却没提升。我们猜测原因主要是在更新BN时候并没有用量化的输出,导致真实的BN参数是有bias的

一度个人以为是不是量化感知的训练就是这样的,不过经过多次实验写出了下面的变体,解决了上面的一些问题。A.2的问题其实之前也已经提到,核心还是 γ \gamma γ是实值,没有非负性保证,加个eps也不能保证除数不为0.

最终的方案其实也很简单


A.4 最终方案


W = γ W σ g W=\frac{\gamma W}{\sigma_g} W=σgγW

W q = Q u a n t i z e ( W ) W_q = Quantize(W) Wq=Quantize(W)

W d q = D e Q u a n t i z e ( W q ) W_{dq}=DeQuantize(W_q) Wdq=DeQuantize(Wq)

y = W d q x y=W_{dq}x y=Wdqx

y = y ∗ σ g ∗ f a k e _ c o f f f a k e _ γ y=y*\frac{\sigma_g*fake\_coff}{fake\_\gamma} y=yfake_γσgfake_coff

y = B N ( y ) y=BN(y) y=BN(y)


Epoch train_loss eval_loss eval_top1 eval_top5
0 6.915 6.882 0.343 1.281
1 6.081 5.279 7.958 20.57
2 5.180 4.356 16.68 36.588
3 4.423 3.533 27.680 52.539
4 3.965 2.856 37.928 64.976
5 3.688 2.760 40.058 66.861
6 3.479 2.479 45.689 71.930

我们发下最终的训练结果跟非量化感知类似,top1下降0.2个点

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!