1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        淺談CMT以及從0-1復(fù)現(xiàn)

        共 9070字,需瀏覽 19分鐘

         ·

        2021-08-11 13:07

        【GiantPandaCV導(dǎo)語】本篇博客講解CMT模型并給出從0-1復(fù)現(xiàn)的過程以及實驗結(jié)果,由于論文的細(xì)節(jié)并沒有給出來,所以最后的復(fù)現(xiàn)和paper的精度有一點差異,等作者release代碼后,我會詳細(xì)的校對我自己的code,找找原因。

        論文鏈接: https://arxiv.org/abs/2107.06263

        論文代碼(個人實現(xiàn)版本): https://github.com/FlyEgle/CMT-pytorch

        知乎專欄:https://www.zhihu.com/people/flyegle

        1. 出發(fā)點

        • Transformers與現(xiàn)有的卷積神經(jīng)網(wǎng)絡(luò)(CNN)在性能和計算成本方面仍有差距。
        • 希望提出的模型不僅可以超越典型的Transformers,而且可以超越高性能卷積模型。

        2. 怎么做

        1. 提出混合模型(串行),通過利用Transformers來捕捉長距離的依賴關(guān)系,并利用CNN來獲取局部特征。
        2. 引入depth-wise卷積,獲取局部特征的同時,減少計算量
        3. 使用類似R50模型結(jié)構(gòu)一樣的stageblock,使得模型具有下采樣增強感受野和遷移dense的能力。
        4. 使用conv-stem來使得圖像的分辨率縮放從VIT的1/16變?yōu)?/4,保留更多的patch信息。

        3. 模型結(jié)構(gòu)

        模型結(jié)構(gòu)
        • (a)表示的是標(biāo)準(zhǔn)的R50模型,具有4個stage,每個都會進(jìn)行一次下采樣。最后得到特征表達(dá)后,經(jīng)過AvgPool進(jìn)行分類
        • (b)表示的是標(biāo)準(zhǔn)的VIT模型,先進(jìn)行patch的劃分,然后embeeding后進(jìn)入Transformer的block,這里,由于Transformer是long range的,所以進(jìn)入什么,輸出就是什么,引入了一個非image的class token來做分類。
        • (c)表示的是本文所提出的模型框架CMT,由CMT-stem, downsampling, cmt block所組成,整體結(jié)構(gòu)則是類似于R50,所以可以很好的遷移到dense任務(wù)上去。

        3.1. CMT Stem

        使用convolution來作為transformer結(jié)構(gòu)的stem,這個觀點FB也有提出一篇paper,Early Convolutions Help Transformers See Better。

        CMT&Conv stem共性

        • 使用4層conv3x3+stride2 + conv1x1 stride 1 等價于VIT的patch embeeding,conv16x16 stride 16.
        • 使用conv stem,可以使模型得到更好的收斂,同時,可以使用SGD優(yōu)化器來訓(xùn)練模型,對于超參數(shù)的依賴沒有原始的那么敏感。好處那是大大的多啊,僅僅是改了一個conv stem。

        CMT&Conv stem異性

        • 本文僅僅做了一次conv3x3 stride2,實際上只有一次下采樣,相比conv stem,可以保留更多的patch的信息到下層。

        從時間上來說,一個20210628(conv stem), 一個是20210713(CMT stem),存在借鑒的可能性還是比較小的,也說明了conv stem的確是work。

        3.2. CMT Block

        每一個stage都是由CMT block所堆疊而成的,CMT block由于是transformer結(jié)構(gòu),所以沒有在stage里面去設(shè)計下采樣。每個CMT block都是由Local Perception Unit, Ligntweight MHSA, Inverted Residual FFN這三個模塊所組成的,下面分別介紹:

        • Local Perception Unit(LPU)

        本文的一個核心點是希望模型具有l(wèi)ong-range的能力,同時還要具有l(wèi)ocal特征的能力,所以提出了LPU這個模塊,很簡單,一個3X3的DWconv,來做局部特征,同時減少點計算量,為了讓Transformers的模塊獲取的longrange的信息不缺失,這里做了一個shortcut,公式描述為:

        • Lightweight MHSA(LMHSA)

        MHSA這個不用多說了,多頭注意力,Lightweight這個作用,PVT(鏈接:https://arxiv.org/abs/2102.12122)曾經(jīng)有提出過,目的是為了降低復(fù)雜度,減少計算量。那本文是怎么做的呢,很簡單,假設(shè)我們的輸入為, 對其分別做一個scale,使用卷積核為,stride為的Depth Wise卷積來做了一次下采樣,得到的shape為,那么對應(yīng)的Q,K,V的shape分別為:

        我們知道,在計算MHSA的時候要遵守兩個計算原則:

        1. Q, K的序列dim要一致。
        2. K, V的token數(shù)量要一致。

        所以,本文中的MHSA計算公式如下:

        • Inverted Resdiual FFN(IRFFN)
        ffn

        FFN的這個模塊,其實和mbv2的block基本上就是一樣的了,不一樣的地方在于,使用的是GELU,采用的也是DW+PW來減少標(biāo)準(zhǔn)卷積的計算量。很簡單,就不多說了,公式如下:

        那么我們一個block里面的整體計算公式如下:

        3.3 patch aggregation

        每個stage都是由上述的多個CMTblock所堆疊而成, 上面也提到了,這里由于是transformer的操作,不會設(shè)計到scale尺度的問題,但是模型需要構(gòu)造下采樣,來實現(xiàn)層次結(jié)構(gòu),所以downsampling的操作單獨拎了出來,每個stage之前會做一次卷積核為2x2的,stride為2的卷積操作,以達(dá)到下采樣的效果。

        所以,整體的模型結(jié)構(gòu)就一目了然了,假設(shè)輸入為224x224x3,經(jīng)過CMT-STEM和第一次下采樣后,得到了一個56x56的featuremap,然后進(jìn)入stage1,輸出不變,經(jīng)過下采樣后,輸入為28x28,進(jìn)入stage2,輸出后經(jīng)過下采樣,輸入為14x14,進(jìn)入stage3,輸出后經(jīng)過最后的下采樣,輸入為7x7,進(jìn)入stage4,最后輸出7x7的特征圖,后面接avgpool和分類,達(dá)到分類的效果。

        我們接下來看一下怎么復(fù)現(xiàn)這篇paper。

        4. 論文復(fù)現(xiàn)

        ps: 這里的復(fù)現(xiàn)指的是沒有源碼的情況下,實現(xiàn)網(wǎng)絡(luò),訓(xùn)練等,如果是結(jié)果復(fù)現(xiàn),會標(biāo)明為復(fù)現(xiàn)精度。

        這里存在幾個問題

        • 文章的問題:我看到paper的時候,是第一個版本的arxiv,大概過了一周左右V2版本放出來了,這兩個版本有個很大的diff。Version1Version2網(wǎng)絡(luò)結(jié)構(gòu)可以說完全不同的情況下,F(xiàn)LOPs竟然一樣的,當(dāng)然可能是寫錯了,這里就不吐槽了。不過我一開始代碼復(fù)現(xiàn)就是按下面來的,所以對于我也沒影響多少,只是體驗有點差罷了。
        • 細(xì)節(jié)的問題:paper和很多的transformer一樣,都是采用了Deit的訓(xùn)練策略,但是差別在于別的paper或多或少會給出來額外的tirck,比如最后FC的dp的ratio等,或者會改變一些,再不濟(jì)會把代碼直接release了,所以只好悶頭嘗試Trick。

        4.1 復(fù)現(xiàn)難點

        paper里面采用的Position Embeeding和Swin是類似的,都是Relation Position Bias,但是和Swin不相同的是,我們的Q,K,V尺度是不一樣的。這里我考慮了兩種實現(xiàn)方法,一種是直接bicubic插值,另一種則是切片,切片更加直觀且embeeding我設(shè)置的可BP,所以,實現(xiàn)里面采用的是這種方法,代碼如下:

        def generate_relative_distance(number_size):
            """return relative distance, (number_size**2, number_size**2, 2)
            """

            indices = torch.tensor(np.array([[x, y] for x in range(number_size) for y in range(number_size)]))
            distances = indices[None, :, :] - indices[:, None, :]
            distances = distances + number_size - 1   # shift the zeros postion
            return distances
        ...
        elf.position_embeeding = nn.Parameter(torch.randn(2 * self.features_size - 12 * self.features_size - 1))

        ...
        q_n, k_n = q.shape[1], k.shape[2]
        attn = attn + self.position_embeeding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]][:, :k_n]

        4.2 復(fù)現(xiàn)trick歷程(血與淚TT)

        一方面想要看一下model是否是work的,一方面想要順便驗證一下DeiT的策略是否真的有效,所以從頭開始做了很多的實驗,簡單整理如下:

        • 數(shù)據(jù):

          1. 訓(xùn)練數(shù)據(jù): 20%的imagenet訓(xùn)練數(shù)據(jù)(快速實驗)。
          2. 驗證數(shù)據(jù): 全量的imagenet驗證數(shù)據(jù)。
        • 環(huán)境:

          1. 8xV100 32G
          2. CUDA 10.2 + pytorch 1.7.1
        • sgd優(yōu)化器實驗記錄

        modelaugmentsresolutionbatchsizeepochoptimizerLRstrategyweightdecaytop-1@acc
        CMT-TINYcrop+flip184->160512X8120SGD1.6cosine1.00E-040.55076
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD1.6cosine1.00E-040.59714
        CMT-TINYcrop+flip+colorjitter+randaug+mixup184->160512X8120SGD1.6cosine1.00E-040.57034
        CMT-TINYcrop+flip+colorjitter+randaug+cutmix184->160512X8120SGD1.6cosine1.00E-040.57264
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD1.6cosine5.00E-050.59452
        CMT-TINYcrop+flip+colorjitter+randaug+mixup184->160512X8200SGD1.6cosine1.00E-040.60532
        CMT-TINYcrop+flip+colorjitter+randaug+cutmix184->160512X8300SGD1.6cosine1.00E-040.61192
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8200SGD1.6cosine5.00E-050.60172
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD+ape(wrong->resolution)1.6cosine1.00E-040.60276
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD+rpe1.6cosine1.00E-040.6016
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD+ape(real->resolution)1.6cosine1.00E-040.60368
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD+pe_nd1.6cosine1.00E-040.59494
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD+qkv_bias1.6cosine1.00E-040.59902
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD+qkv_bias+rpe1.6cosine1.00E-040.6023
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120SGD+qkv_bias+ape1.6cosine1.00E-040.5986
        CMT-TINYcrop+flip+colorjitter+randaug+no mixup+no_cutmix+labelsmoothing184->160512X8300SGD+qkv_bias+rpe1.6cosine1.00E-040.62108
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300SGD+qkv_bias+rpe1.6cosine1.00E-040.6612

        結(jié)論: 可以看到在SGD優(yōu)化器的情況下,使用1.6的LR,訓(xùn)練300個epoch,warmup5個epoch,是用cosine衰減學(xué)習(xí)率的策略,用randaug+colorjitter+mixup+cutmix+labelsmooth,設(shè)置weightdecay為0.1的配置下,使用QKV的bias以及相對位置偏差,可以達(dá)到比baseline高11%個點的結(jié)果,所有的實驗都是用FP16跑的。

        • adamw優(yōu)化器實驗記錄
        modelaugmentsresolutionbatchsizeepochoptimizerLRstrategyweightdecaytop-1@acc
        CMT-TINYcrop+flip184->160512X8120AdamW4.00E-03cosine5.00E-020.50994
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8300AdamW4.00E-03cosine5.00E-020.57646
        CMT-TINYcrop+flip+colorjitter+randaug184->160512X8120AdamW4.00E-03cosine1.00E-040.56504
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300adamw+qkv_bias+rpe4.00E-03cosine1.00E-040.63606
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing + repsampler184->160512X8300adamw+qkv_bias+rpe4.00E-03cosine1.00E-040.61826
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300adamw+qkv_bias+rpe4.00E-03cosine5.00E-020.64228
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300adamw+qkv_bias+rpe1.00E-04cosine5.00E-020.4049
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing + repsampler184->160512X8300adamw+qkv_bias+rpe4.00E-03cosine5.00E-020.63816
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300adamw+qkv_bias+rpe8.00E-03cosine5.00E-02不收斂
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300adamw+qkv_bias+rpe5.00E-03cosine5.00E-020.65118
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300adamw+qkv_bias+rpe6.00E-03cosine5.00E-020.65194
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300adamw+qkv_bias+rpe6.00E-03cosine5.00E-030.63726
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing184->160512X8300adamw+qkv_bias+rpe6.00E-03cosine1.00E-010.65502
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing+warmup20184->160512X8300adamw+qkv_bias+rpe6.00E-03cosine1.00E-010.65082
        CMT-TINYcrop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing+droppath184->160512X8300adamw+qkv_bias+rpe6.00E-03cosine1.00E-010.66908

        結(jié)論:使用AdamW的情況下,對學(xué)習(xí)率的縮放則是以512的bs為基礎(chǔ),所以對于4k的bs情況下,使用的是4e-3的LR,但是實驗發(fā)現(xiàn)增大到6e-3的時候,還會帶來一些提升,同時放大一點weightsdecay,也略微有所提升,最終使用AdamW的配置為,6e-3的LR,1e-1的weightdecay,和sgd一樣的增強方法,然后加上了隨機(jī)深度失活設(shè)置,最后比baseline高了16%個點,比SGD最好的結(jié)果要高0.8%個點。

        4.3. imagenet上的結(jié)果

        result

        最后用全量跑,使用SGD會報nan的問題,我定位了一下發(fā)現(xiàn),running_mean和running_std有nan出現(xiàn),本以為是數(shù)據(jù)增強導(dǎo)致的0或者nan值出現(xiàn),結(jié)果空跑幾次數(shù)據(jù)發(fā)現(xiàn)沒問題,只好把優(yōu)化器改成了AdamW,結(jié)果上述所示,CMT-Tiny在160x160的情況下達(dá)到了75.124%的精度,相比MbV2,MbV3的確是一個不錯的精度了,但是相比paper本身的精度還是差了將近4個點,很是離譜。

        速度上,CMT雖然FLOPs低,但是實際的推理速度并不快,128的bs條件下,速度慢了R50將近10倍。

        5. 實驗結(jié)果

        總體來說,CMT達(dá)到了更小的FLOPs同時有著不錯的精度, imagenet上的結(jié)果如下:

        coco2017上也有這不錯的精度

        6. 結(jié)論

        本文提出了一種名為CMT的新型混合架構(gòu),用于視覺識別和其他下游視覺任務(wù),以解決在計算機(jī)視覺領(lǐng)域以粗暴的方式利用Transformers的限制。所提出的CMT同時利用CNN和Transformers的優(yōu)勢來捕捉局部和全局信息,促進(jìn)網(wǎng)絡(luò)的表示能力。在ImageNet和其他下游視覺任務(wù)上進(jìn)行的大量實驗證明了所提出的CMT架構(gòu)的有效性和優(yōu)越性。

        代碼復(fù)現(xiàn)repo: https://github.com/FlyEgle/CMT-pytorch, 實現(xiàn)不易,求個star!


        - END - 

        歡迎添加微信,加入GiantPandaCV交流群

        瀏覽 49
        點贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            国产精品久久久久久久白晢女i | 亚洲人成人无码一区二区三区 | 色噜噜综合在线 | 免费 成人 结九幺视频 | 国产成人精品无码一区二区蜜柚 | 尤物视频网站免费观看 | 少妇下面好紧好舒服 | 国产福利三区 | 黄色片A片| 狠狠的挺进女同学的小泬漫画 |