1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        神經(jīng)網(wǎng)絡(luò)量化入門--量化感知訓(xùn)練

        共 6451字,需瀏覽 13分鐘

         ·

        2021-07-29 06:15

        上一篇文章介紹了后訓(xùn)練量化的基本流程,并用 pytorch 演示了最簡(jiǎn)單的后訓(xùn)練量化算法。

        后訓(xùn)練量化雖然操作簡(jiǎn)單,并且大部分推理框架都提供了這類離線量化算法 (如 tensorrt、ncnn,SNPE 等),但有時(shí)候這種方法并不能保證足夠的精度,因此本文介紹另一種比后訓(xùn)練量化更有效的量化方法——量化感知訓(xùn)練。

        量化感知訓(xùn)練,顧名思義,就是在量化的過(guò)程中,對(duì)網(wǎng)絡(luò)進(jìn)行訓(xùn)練,從而讓網(wǎng)絡(luò)參數(shù)能更好地適應(yīng)量化帶來(lái)的信息損失。這種方式更加靈活,因此準(zhǔn)確性普遍比后訓(xùn)練量化要高。當(dāng)然,它的一大缺點(diǎn)是操作起來(lái)不方便,這一點(diǎn)后面會(huì)詳談。

        同樣地,這篇文章會(huì)講解最簡(jiǎn)單的量化訓(xùn)練算法流程,并沿用之前文章的代碼框架,用 pytorch 從零構(gòu)建量化訓(xùn)練算法的流程。

        量化訓(xùn)練的困難

        要理解量化訓(xùn)練的困難之處,需要了解量化訓(xùn)練相比普通的全精度訓(xùn)練有什么區(qū)別。為了看清這一點(diǎn),我們回顧一下上一篇文章中卷積量化的代碼:

        class?QConv2d(QModule):

        ????def?forward(self,?x):
        ????????if?hasattr(self,?'qi'):
        ????????????self.qi.update(x)

        ????????self.qw.update(self.conv_module.weight.data)

        ????????self.conv_module.weight.data?=?self.qw.quantize_tensor(self.conv_module.weight.data)
        ????????self.conv_module.weight.data?=?self.qw.dequantize_tensor(self.conv_module.weight.data)

        ????????x?=?self.conv_module(x)

        ????????if?hasattr(self,?'qo'):
        ????????????self.qo.update(x)

        ????????return?x

        這里面區(qū)別于全精度模型的地方在于,我們?cè)诰矸e運(yùn)算前先對(duì) weight 做了一遍量化,然后又再反量化成 float。這一步在后訓(xùn)練量化中其實(shí)可有可無(wú),但量化感知訓(xùn)練中卻是需要的。「之前為了代碼上的一致,我提前把這一步加上去了」

        那這一步有什么特別嗎?可以回顧一下量化的具體操作:

        def?quantize_tensor(x,?scale,?zero_point,?num_bits=8,?signed=False):
        ????if?signed:
        ????????qmin?=?-?2.?**?(num_bits?-?1)
        ????????qmax?=?2.?**?(num_bits?-?1)?-?1
        ????else:
        ????????qmin?=?0.
        ????????qmax?=?2.**num_bits?-?1.
        ?
        ????q_x?=?zero_point?+?x?/?scale
        ????q_x.clamp_(qmin,?qmax).round_()
        ????
        ????return?q_x.float()

        這里面有個(gè) round 函數(shù),而這個(gè)函數(shù)是沒(méi)法訓(xùn)練的。它的函數(shù)圖像如下:

        bf7f7925e0e2f2b0897824ef0eb18ef3.webp

        這個(gè)函數(shù)幾乎每一處的梯度都是 0,如果網(wǎng)絡(luò)中存在該函數(shù),會(huì)導(dǎo)致反向傳播的梯度也變成 0。

        可以看個(gè)例子:

        conv?=?nn.Conv2d(3,?1,?3,?1)

        def?quantize(weight):
        ????w?=?weight.round()
        ????return?w

        class?QuantConv(nn.Module):

        ????def?__init__(self,?conv_module):
        ????????super(QuantConv,?self).__init__()
        ????????self.conv_module?=?conv_module

        ????def?forward(self,?x):
        ????????return?F.conv2d(x,?quantize(self.conv_module.weight),?self.conv_module.bias,?3,?1)


        x?=?torch.randn((1,?3,?4,?4))

        quantconv?=?QuantConv(conv)

        a?=?quantconv(x).sum().backward()

        print(quantconv.conv_module.weight.grad)

        這個(gè)例子里面,我將權(quán)重 weight 做了一遍 round 操作后,再進(jìn)行卷積運(yùn)算,但返回的梯度全是 0:

        tensor([[[[0.,?0.,?0.],
        ??????????[0.,?0.,?0.],
        ??????????[0.,?0.,?0.]],

        ?????????[[0.,?0.,?0.],
        ??????????[0.,?0.,?0.],
        ??????????[0.,?0.,?0.]],

        ?????????[[0.,?0.,?0.],
        ??????????[0.,?0.,?0.],
        ??????????[0.,?0.,?0.]]]])

        換言之,這個(gè)函數(shù)是沒(méi)法學(xué)習(xí)的,從而導(dǎo)致量化訓(xùn)練進(jìn)行不下去。

        Straight Through Estimator

        那要怎么解決這個(gè)問(wèn)題呢?

        一個(gè)很容易想到的方法是,直接跳過(guò)偽量化的過(guò)程,避開(kāi) round。直接把卷積層的梯度回傳到偽量化之前的 weight 上。這樣一來(lái),由于卷積中用的 weight 是經(jīng)過(guò)偽量化操作的,因此可以模擬量化誤差,把這些誤差的梯度回傳到原來(lái)的 weight,又可以更新權(quán)重,使其適應(yīng)量化產(chǎn)生的誤差,量化訓(xùn)練就可以正常進(jìn)行下去了。

        這個(gè)方法就叫做 Straight Through Estimator(STE)。

        pytorch實(shí)現(xiàn)

        本文的相關(guān)代碼都可以在 https://github.com/Jermmy/pytorch-quantization-demo 上找到。

        偽量化節(jié)點(diǎn)實(shí)現(xiàn)

        上面講完量化訓(xùn)練最基本的思路,下面我們繼續(xù)沿用前文的代碼框架,加入量化訓(xùn)練的部分。

        首先,我們需要修改偽量化的寫(xiě)法,之前的代碼是直接對(duì) weight 的數(shù)值做了偽量化:

        self.conv_module.weight.data?=?self.qw.quantize_tensor(self.conv_module.weight.data)
        self.conv_module.weight.data?=?self.qw.dequantize_tensor(self.conv_module.weight.data)

        這在后訓(xùn)練量化里面沒(méi)有問(wèn)題,但在 pytorch 中,這種寫(xiě)法是沒(méi)法回傳梯度的,因此量化訓(xùn)練里面,需要重新修改偽量化節(jié)點(diǎn)的寫(xiě)法。

        另外,STE 需要我們重新定義反向傳播的梯度。因此,需要借助 pytorch 中的 Function 接口來(lái)重新定義偽量化的過(guò)程:

        from?torch.autograd?import?Function

        class?FakeQuantize(Function):

        ????@staticmethod
        ????def?forward(ctx,?x,?qparam):
        ????????x?=?qparam.quantize_tensor(x)
        ????????x?=?qparam.dequantize_tensor(x)
        ????????return?x

        ????@staticmethod
        ????def?backward(ctx,?grad_output):
        ????????return?grad_output,?None

        這里面的 forward 函數(shù),和之前的寫(xiě)法是類似的,就是把數(shù)值量化之后再反量化回去。但在 backward 中,我們直接返回了后一層傳過(guò)來(lái)的梯度 grad_output,相當(dāng)于直接跳過(guò)了偽量化這一層的梯度計(jì)算,讓梯度直接流到前一層 (Straight Through)。

        pytorch 定義 backward 函數(shù)的返回變量需要與 forward 的輸入?yún)?shù)對(duì)應(yīng),分別表示對(duì)應(yīng)輸入的梯度。由于 qparam 只是統(tǒng)計(jì) min、max,不需要梯度,因此返回給它的梯度是 None。

        量化卷積代碼

        量化卷積層的代碼除了 forward 中需要修改偽量化節(jié)點(diǎn)外,其余的和之前的文章基本一致:

        class?QConv2d(QModule):

        ????def?forward(self,?x):
        ????????if?hasattr(self,?'qi'):
        ????????????self.qi.update(x)
        ????????????x?=?FakeQuantize.apply(x,?self.qi)

        ????????self.qw.update(self.conv_module.weight.data)

        ????????x?=?F.conv2d(x,?FakeQuantize.apply(self.conv_module.weight,?self.qw),
        ?????????????????????self.conv_module.bias,?
        ?????????????????????stride=self.conv_module.stride,
        ?????????????????????padding=self.conv_module.padding,?dilation=self.conv_module.dilation,?
        ?????????????????????groups=self.conv_module.groups)

        ????????if?hasattr(self,?'qo'):
        ????????????self.qo.update(x)
        ????????????x?=?FakeQuantize.apply(x,?self.qo)

        ????????return?x

        由于我們需要先對(duì) weight 做一些偽量化的操作,根據(jù) pytorch 中的規(guī)則,在做卷積運(yùn)算的時(shí)候,不能像之前一樣用 x = self.conv_module(x) 的寫(xiě)法,而要用 F.conv2d 來(lái)調(diào)用。另外,之前的代碼中輸入輸出沒(méi)有加偽量化節(jié)點(diǎn),這在后訓(xùn)練量化中沒(méi)有問(wèn)題,但在量化訓(xùn)練中最好加上,方便網(wǎng)絡(luò)更好地感知量化帶來(lái)的損失。

        由于上一篇文章中做量化推理的時(shí)候,我發(fā)現(xiàn)精度損失不算太重,3 個(gè) bit 的情況下,準(zhǔn)確率依然能達(dá)到 96%。為了更好地體會(huì)量化訓(xùn)練帶來(lái)的收益,我們把量化推理的代碼再細(xì)致一點(diǎn),加大量化損失:

        class?QConv2d(QModule):

        ????def?quantize_inference(self,?x):
        ????????x?=?x?-?self.qi.zero_point
        ????????x?=?self.conv_module(x)
        ????????x?=?self.M?*?x
        ????????x.round_()??????#?多加一個(gè)round操作
        ????????x?=?x?+?self.qo.zero_point????????
        ????????x.clamp_(0.,?2.**self.num_bits-1.).round_()
        ????????return?x

        相比之前的代碼,其實(shí)就是多加了個(gè) round,讓量化推理更接近真實(shí)的推理過(guò)程。

        量化訓(xùn)練的收益

        這里仍然沿用之前文章里的小網(wǎng)絡(luò),在 mnist 上測(cè)試分類準(zhǔn)確率。由于量化推理有修改,為了方便對(duì)比,我重新跑了一遍后訓(xùn)練量化的準(zhǔn)確率:

        bit12345678
        accuracy10%47%83%96%98%98%98%98%

        接下來(lái),測(cè)試一下量化訓(xùn)練的效果,下面是 bit=3 時(shí)輸出的 log:

        Test?set:?Full?Model?Accuracy:?98%

        Quantization?bit:?3
        Quantize?Aware?Training?Epoch:?1?[3200/60000]???Loss:?0.087867
        Quantize?Aware?Training?Epoch:?1?[6400/60000]???Loss:?0.219696
        Quantize?Aware?Training?Epoch:?1?[9600/60000]???Loss:?0.283124
        Quantize?Aware?Training?Epoch:?1?[12800/60000]??Loss:?0.172751
        Quantize?Aware?Training?Epoch:?1?[16000/60000]??Loss:?0.315173
        Quantize?Aware?Training?Epoch:?1?[19200/60000]??Loss:?0.302261
        Quantize?Aware?Training?Epoch:?1?[22400/60000]??Loss:?0.218039
        Quantize?Aware?Training?Epoch:?1?[25600/60000]??Loss:?0.301568
        Quantize?Aware?Training?Epoch:?1?[28800/60000]??Loss:?0.252994
        Quantize?Aware?Training?Epoch:?1?[32000/60000]??Loss:?0.138346
        Quantize?Aware?Training?Epoch:?1?[35200/60000]??Loss:?0.203350

        ...

        Test?set:?Quant?Model?Accuracy:?90%

        總的實(shí)驗(yàn)結(jié)果如下:

        bit12345678
        accuracy10%63%90%97%98%98%98%98%

        用曲線把它們 plot 在一起:

        202351fffe47cb27b1366730212c4b66.webp

        灰色線是量化訓(xùn)練,橙色線是后訓(xùn)練量化,可以看到,在 bit = 2、3 的時(shí)候,量化訓(xùn)練能帶來(lái)很明顯的提升。

        在 bit = 1 的時(shí)候,我發(fā)現(xiàn)量化訓(xùn)練回傳的梯度為 0,訓(xùn)練基本失敗了。這是因?yàn)?bit = 1 的時(shí)候,整個(gè)網(wǎng)絡(luò)已經(jīng)退化成一個(gè)二值網(wǎng)絡(luò)了,而低比特量化訓(xùn)練本身不是一件容易的事情,雖然我們前面用 STE 解決了梯度的問(wèn)題,但由于低比特會(huì)使得網(wǎng)絡(luò)的信息損失巨大,因此通常的訓(xùn)練方式很難起到作用。

        另外,量化訓(xùn)練本身存在很多 trick,在這個(gè)實(shí)驗(yàn)中我發(fā)現(xiàn),學(xué)習(xí)率對(duì)結(jié)果的影響非常顯著,尤其是低比特量化的時(shí)候,學(xué)習(xí)率太高容易導(dǎo)致梯度變?yōu)?0,導(dǎo)致量化訓(xùn)練完全不起作用「一度以為代碼出錯(cuò)」。

        量化訓(xùn)練部署

        前面說(shuō)過(guò),量化訓(xùn)練雖然收益明顯,但實(shí)際應(yīng)用起來(lái)卻比后訓(xùn)練量化麻煩得多。

        目前大部分主流推理框架在處理后訓(xùn)練量化時(shí),只需要用戶把模型和數(shù)據(jù)扔進(jìn)去,就可以得到量化模型,然后直接部署。但很少有框架支持量化訓(xùn)練。

        目前量化訓(xùn)練缺少統(tǒng)一的規(guī)范,各家推理引擎的量化算法雖然本質(zhì)一樣,但很多細(xì)節(jié)處很難做到一致。而目前大家做模型訓(xùn)練的前端框架是不統(tǒng)一的「當(dāng)然主流還是 tf 和 pytorch」,如果各家的推理引擎需要支持不同前端的量化訓(xùn)練,就需要針對(duì)不同的前端框架,按照后端部署的實(shí)現(xiàn)規(guī)則「比如哪些層的量化需要合并、weight 是否采用對(duì)稱量化等」,從頭再搭一套量化訓(xùn)練框架,這個(gè)工作量想想就嚇人。

        總結(jié)

        這篇文章主要介紹了量化訓(xùn)練的基本方法,并用 pytorch 構(gòu)建了一個(gè)簡(jiǎn)單的量化訓(xùn)練實(shí)例。下一篇文章會(huì)介紹這系列教程的最后一篇文章——關(guān)于 fold BatchNorm 相關(guān)的知識(shí)。

        參考

        • Torch.round() gradient
        • pytorch實(shí)現(xiàn)簡(jiǎn)單的straight-through estimator(STE)


        瀏覽 86
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            青青草激动视频 | 在线观看黄色av 国产精品国产三级国产专业不 | av研究院 | 高清不卡a v | 男生插女生的逼 | 北岛玲heyzo一区二区 | 青青操视频在线观看 | 男人和女人在羞羞视频 | 欧美大香蕉性爱 | 在线视频欧美色图 |