1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        淺談Transformer+CNN混合架構(gòu):CMT以及從0-1復(fù)現(xiàn)

        共 5952字,需瀏覽 12分鐘

         ·

        2021-08-15 17:24

        ↑ 點(diǎn)擊藍(lán)字 關(guān)注極市平臺

        作者丨FlyEgle
        來源丨GiantPandaCV
        編輯丨極市平臺

        極市導(dǎo)讀

         

        本文詳細(xì)講解了華為諾亞與悉尼大學(xué)在Transformer+CNN架構(gòu)混合方面的嘗試,一種同時(shí)具有Transformer長距離建模與CNN局部特征提取能力的CMT。并給出了自己從0-1的復(fù)現(xiàn)過程以及是實(shí)驗(yàn)結(jié)果。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿

        論文鏈接: https://arxiv.org/abs/2107.06263
        論文代碼(個(gè)人實(shí)現(xiàn)版本): https://github.com/FlyEgle/CMT-pytorch
        知乎專欄:https://www.zhihu.com/people/flyegle

        寫在前面

        本篇博客講解CMT模型并給出從0-1復(fù)現(xiàn)的過程以及實(shí)驗(yàn)結(jié)果,由于論文的細(xì)節(jié)并沒有給出來,所以最后的復(fù)現(xiàn)和paper的精度有一點(diǎn)差異,等作者release代碼后,我會詳細(xì)的校對我自己的code,找找原因。

        1. 出發(fā)點(diǎn)

        • Transformers與現(xiàn)有的卷積神經(jīng)網(wǎng)絡(luò)(CNN)在性能和計(jì)算成本方面仍有差距。
        • 希望提出的模型不僅可以超越典型的Transformers,而且可以超越高性能卷積模型。

        2. 怎么做

        1. 提出混合模型(串行),通過利用Transformers來捕捉長距離的依賴關(guān)系,并利用CNN來獲取局部特征。
        2. 引入depth-wise卷積,獲取局部特征的同時(shí),減少計(jì)算量
        3. 使用類似R50模型結(jié)構(gòu)一樣的stageblock,使得模型具有下采樣增強(qiáng)感受野和遷移dense的能力。
        4. 使用conv-stem來使得圖像的分辨率縮放從VIT的1/16變?yōu)?/4,保留更多的patch信息。

        3. 模型結(jié)構(gòu)

        模型結(jié)構(gòu)
        • (a)表示的是標(biāo)準(zhǔn)的R50模型,具有4個(gè)stage,每個(gè)都會進(jìn)行一次下采樣。最后得到特征表達(dá)后,經(jīng)過AvgPool進(jìn)行分類
        • (b)表示的是標(biāo)準(zhǔn)的VIT模型,先進(jìn)行patch的劃分,然后embeeding后進(jìn)入Transformer的block,這里,由于Transformer是long range的,所以進(jìn)入什么,輸出就是什么,引入了一個(gè)非image的class token來做分類。
        • (c)表示的是本文所提出的模型框架CMT,由CMT-stem, downsampling, cmt block所組成,整體結(jié)構(gòu)則是類似于R50,所以可以很好的遷移到dense任務(wù)上去。

        3.1. CMT Stem

        使用convolution來作為transformer結(jié)構(gòu)的stem,這個(gè)觀點(diǎn)FB也有提出一篇paper,Early Convolutions Help Transformers See Better。

        https://arxiv.org/abs/2106.14881

        CMT&Conv stem共性

        • 使用4層conv3x3+stride2 + conv1x1 stride 1 等價(jià)于VIT的patch embeeding,conv16x16 stride 16.
        • 使用conv stem,可以使模型得到更好的收斂,同時(shí),可以使用SGD優(yōu)化器來訓(xùn)練模型,對于超參數(shù)的依賴沒有原始的那么敏感。好處那是大大的多啊,僅僅是改了一個(gè)conv stem。

        CMT&Conv stem異性

        • 本文僅僅做了一次conv3x3 stride2,實(shí)際上只有一次下采樣,相比conv stem,可以保留更多的patch的信息到下層。

        從時(shí)間上來說,一個(gè)20210628(conv stem), 一個(gè)是20210713(CMT stem),存在借鑒的可能性還是比較小的,也說明了conv stem的確是work。

        3.2. CMT Block

        每一個(gè)stage都是由CMT block所堆疊而成的,CMT block由于是transformer結(jié)構(gòu),所以沒有在stage里面去設(shè)計(jì)下采樣。每個(gè)CMT block都是由Local Perception Unit, Ligntweight MHSA, Inverted Residual FFN這三個(gè)模塊所組成的,下面分別介紹:

        • Local Perception Unit(LPU)

        LPU

        本文的一個(gè)核心點(diǎn)是希望模型具有l(wèi)ong-range的能力,同時(shí)還要具有l(wèi)ocal特征的能力,所以提出了LPU這個(gè)模塊,很簡單,一個(gè)3X3的DWconv,來做局部特征,同時(shí)減少點(diǎn)計(jì)算量,為了讓Transformers的模塊獲取的longrange的信息不缺失,這里做了一個(gè)shortcut,公式描述為:

        • Lightweight MHSA(LMHSA)

        LMHSA

        MHSA這個(gè)不用多說了,多頭注意力,Lightweight這個(gè)作用,PVT 曾經(jīng)有提出過,目的是為了降低復(fù)雜度,減少計(jì)算量。那本文是怎么做的呢,很簡單,假設(shè)我們的輸入為, 對其分別做一個(gè)scale,使用卷積核為,stride為的Depth Wise卷積來做了一次下采樣,得到的shape為,那么對應(yīng)的Q,K,V的shape分別為:

        我們知道,在計(jì)算MHSA的時(shí)候要遵守兩個(gè)計(jì)算原則:

        1. Q, K的序列dim要一致。
        2. K, V的token數(shù)量要一致。

        所以,本文中的MHSA計(jì)算公式如下:

        • Inverted Resdiual FFN(IRFFN)

        IRFFN

        FFN的這個(gè)模塊,其實(shí)和mbv2的block基本上就是一樣的了,不一樣的地方在于,使用的是GELU,采用的也是DW+PW來減少標(biāo)準(zhǔn)卷積的計(jì)算量。很簡單,就不多說了,公式如下:

        那么我們一個(gè)block里面的整體計(jì)算公式如下:

        3.3 patch aggregation

        每個(gè)stage都是由上述的多個(gè)CMTblock所堆疊而成, 上面也提到了,這里由于是transformer的操作,不會設(shè)計(jì)到scale尺度的問題,但是模型需要構(gòu)造下采樣,來實(shí)現(xiàn)層次結(jié)構(gòu),所以downsampling的操作單獨(dú)拎了出來,每個(gè)stage之前會做一次卷積核為2x2的,stride為2的卷積操作,以達(dá)到下采樣的效果。

        所以,整體的模型結(jié)構(gòu)就一目了然了,假設(shè)輸入為224x224x3,經(jīng)過CMT-STEM和第一次下采樣后,得到了一個(gè)56x56的featuremap,然后進(jìn)入stage1,輸出不變,經(jīng)過下采樣后,輸入為28x28,進(jìn)入stage2,輸出后經(jīng)過下采樣,輸入為14x14,進(jìn)入stage3,輸出后經(jīng)過最后的下采樣,輸入為7x7,進(jìn)入stage4,最后輸出7x7的特征圖,后面接avgpool和分類,達(dá)到分類的效果。

        我們接下來看一下怎么復(fù)現(xiàn)這篇paper。

        4. 論文復(fù)現(xiàn)

        ps: 這里的復(fù)現(xiàn)指的是沒有源碼的情況下,實(shí)現(xiàn)網(wǎng)絡(luò),訓(xùn)練等,如果是結(jié)果復(fù)現(xiàn),會標(biāo)明為復(fù)現(xiàn)精度。

        這里存在幾個(gè)問題

        • 文章的問題:我看到paper的時(shí)候,是第一個(gè)版本的arxiv,大概過了一周左右V2版本放出來了,這兩個(gè)版本有個(gè)很大的diff。
        V1
        V2

        網(wǎng)絡(luò)結(jié)構(gòu)可以說完全不同的情況下,F(xiàn)LOPs竟然一樣的,當(dāng)然可能是寫錯(cuò)了,這里就不吐槽了。不過我一開始代碼復(fù)現(xiàn)就是按下面來的,所以對于我也沒影響多少,只是體驗(yàn)有點(diǎn)差罷了。

        • 細(xì)節(jié)的問題:paper和很多的transformer一樣,都是采用了Deit的訓(xùn)練策略,但是差別在于別的paper或多或少會給出來額外的tirck,比如最后FC的dp的ratio等,或者會改變一些,再不濟(jì)會把代碼直接release了,所以只好悶頭嘗試Trick。

        4.1 復(fù)現(xiàn)難點(diǎn)

        paper里面采用的Position Embeeding和Swin是類似的,都是Relation Position Bias,但是和Swin不相同的是,我們的Q,K,V尺度是不一樣的。這里我考慮了兩種實(shí)現(xiàn)方法,一種是直接bicubic插值,另一種則是切片,切片更加直觀且embeeding我設(shè)置的可BP,所以,實(shí)現(xiàn)里面采用的是這種方法,代碼如下:

        def generate_relative_distance(number_size):
        """return relative distance, (number_size**2, number_size**2, 2)
        """
        indices = torch.tensor(np.array([[x, y] for x in range(number_size) for y in range(number_size)]))
        distances = indices[None, :, :] - indices[:, None, :]
        distances = distances + number_size - 1 # shift the zeros postion
        return distances
        ...
        elf.position_embeeding = nn.Parameter(torch.randn(2 * self.features_size - 1, 2 * self.features_size - 1))

        ...
        q_n, k_n = q.shape[1], k.shape[2]
        attn = attn + self.position_embeeding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]][:, :k_n]

        4.2 復(fù)現(xiàn)trick歷程(血與淚TT)

        一方面想要看一下model是否是work的,一方面想要順便驗(yàn)證一下DeiT的策略是否真的有效,所以從頭開始做了很多的實(shí)驗(yàn),簡單整理如下:

        • 數(shù)據(jù):
        1. 訓(xùn)練數(shù)據(jù): 20%的imagenet訓(xùn)練數(shù)據(jù)(快速實(shí)驗(yàn))。
        2. 驗(yàn)證數(shù)據(jù): 全量的imagenet驗(yàn)證數(shù)據(jù)。
        • 環(huán)境:
        1. 8xV100 32G
        2. CUDA 10.2 + pytorch 1.7.1
        • sgd優(yōu)化器實(shí)驗(yàn)記錄
        SGD實(shí)驗(yàn)

        結(jié)論: 可以看到在SGD優(yōu)化器的情況下,使用1.6的LR,訓(xùn)練300個(gè)epoch,warmup5個(gè)epoch,是用cosine衰減學(xué)習(xí)率的策略,用randaug+colorjitter+mixup+cutmix+labelsmooth,設(shè)置weightdecay為0.1的配置下,使用QKV的bias以及相對位置偏差,可以達(dá)到比baseline高11%個(gè)點(diǎn)的結(jié)果,所有的實(shí)驗(yàn)都是用FP16跑的。

        • adamw優(yōu)化器實(shí)驗(yàn)記錄
        adamw實(shí)驗(yàn)

        結(jié)論:使用AdamW的情況下,對學(xué)習(xí)率的縮放則是以512的bs為基礎(chǔ),所以對于4k的bs情況下,使用的是4e-3的LR,但是實(shí)驗(yàn)發(fā)現(xiàn)增大到6e-3的時(shí)候,還會帶來一些提升,同時(shí)放大一點(diǎn)weightsdecay,也略微有所提升,最終使用AdamW的配置為,6e-3的LR,1e-1的weightdecay,和sgd一樣的增強(qiáng)方法,然后加上了隨機(jī)深度失活設(shè)置,最后比baseline高了16%個(gè)點(diǎn),比SGD最好的結(jié)果要高0.8%個(gè)點(diǎn)。

        4.3. imagenet上的結(jié)果

        最后用全量跑,使用SGD會報(bào)nan的問題,我定位了一下發(fā)現(xiàn),running_mean和running_std有nan出現(xiàn),本以為是數(shù)據(jù)增強(qiáng)導(dǎo)致的0或者nan值出現(xiàn),結(jié)果空跑幾次數(shù)據(jù)發(fā)現(xiàn)沒問題,只好把優(yōu)化器改成了AdamW,結(jié)果上述所示,CMT-Tiny在160x160的情況下達(dá)到了75.124%的精度,相比MbV2,MbV3的確是一個(gè)不錯(cuò)的精度了,但是相比paper本身的精度還是差了將近4個(gè)點(diǎn),很是離譜。

        速度上,CMT雖然FLOPs低,但是實(shí)際的推理速度并不快,128的bs條件下,速度慢了R50將近10倍。

        5. 實(shí)驗(yàn)結(jié)果

        總體來說,CMT達(dá)到了更小的FLOPs同時(shí)有著不錯(cuò)的精度, imagenet上的結(jié)果如下:

        imagenet

        coco2017上也有這不錯(cuò)的精度

        coco2017

        6. 結(jié)論

        本文提出了一種名為CMT的新型混合架構(gòu),用于視覺識別和其他下游視覺任務(wù),以解決在計(jì)算機(jī)視覺領(lǐng)域以粗暴的方式利用Transformers的限制。所提出的CMT同時(shí)利用CNN和Transformers的優(yōu)勢來捕捉局部和全局信息,促進(jìn)網(wǎng)絡(luò)的表示能力。在ImageNet和其他下游視覺任務(wù)上進(jìn)行的大量實(shí)驗(yàn)證明了所提出的CMT架構(gòu)的有效性和優(yōu)越性。

        代碼復(fù)現(xiàn)repo:

        https://github.com/FlyEgle/CMT-pytorch

        實(shí)現(xiàn)不易,求個(gè)star!


        本文亮點(diǎn)總結(jié)


        1.CMT&Conv stem共性:
        • 使用4層conv3x3+stride2 + conv1x1 stride 1 等價(jià)于VIT的patch embeeding,conv16x16 stride 16.
        • 使用conv stem,可以使模型得到更好的收斂,同時(shí),可以使用SGD優(yōu)化器來訓(xùn)練模型,對于超參數(shù)的依賴沒有原始的那么敏感。好處那是大大的多啊,僅僅是改了一個(gè)conv stem。

        2.本文的一個(gè)核心點(diǎn)是希望模型具有l(wèi)ong-range的能力,同時(shí)還要具有l(wèi)ocal特征的能力,所以提出了LPU這個(gè)模塊,很簡單,一個(gè)3X3的DWconv,來做局部特征,同時(shí)減少點(diǎn)計(jì)算量。

        如果覺得有用,就請分享到朋友圈吧!

        △點(diǎn)擊卡片關(guān)注極市平臺,獲取最新CV干貨

        公眾號后臺回復(fù)“CVPR21檢測”獲取CVPR2021目標(biāo)檢測論文下載~


        極市干貨
        深度學(xué)習(xí)環(huán)境搭建:如何配置一臺深度學(xué)習(xí)工作站?
        實(shí)操教程:OpenVINO2021.4+YOLOX目標(biāo)檢測模型測試部署為什么你的顯卡利用率總是0%?
        算法技巧(trick):圖像分類算法優(yōu)化技巧21個(gè)深度學(xué)習(xí)調(diào)參的實(shí)用技巧


        CV技術(shù)社群邀請函 #

        △長按添加極市小助手
        添加極市小助手微信(ID : cvmart4)

        備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測-深圳)


        即可申請加入極市目標(biāo)檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


        每月大咖直播分享、真實(shí)項(xiàng)目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動交流~



        覺得有用麻煩給個(gè)在看啦~  
        瀏覽 496
        點(diǎn)贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評論
        圖片
        表情
        推薦
        點(diǎn)贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            久热福利| 污污网站免费 | 一边摸一边添高潮av | 日韩久久久久 | 男人爽到不行疯狂叫床 | 韩日不卡视频 | 国产拔擦拔擦8ⅹ8x更快乐 | 亚洲AV无码片 | 亚洲一区无码在线 | 黄色艳情片 |