淺談CMT以及從0-1復(fù)現(xiàn)
【GiantPandaCV導(dǎo)語】本篇博客講解CMT模型并給出從0-1復(fù)現(xiàn)的過程以及實驗結(jié)果,由于論文的細(xì)節(jié)并沒有給出來,所以最后的復(fù)現(xiàn)和paper的精度有一點差異,等作者release代碼后,我會詳細(xì)的校對我自己的code,找找原因。
論文鏈接: https://arxiv.org/abs/2107.06263
論文代碼(個人實現(xiàn)版本): https://github.com/FlyEgle/CMT-pytorch
知乎專欄:https://www.zhihu.com/people/flyegle
1. 出發(fā)點
Transformers與現(xiàn)有的卷積神經(jīng)網(wǎng)絡(luò)(CNN)在性能和計算成本方面仍有差距。 希望提出的模型不僅可以超越典型的Transformers,而且可以超越高性能卷積模型。
2. 怎么做
提出混合模型(串行),通過利用Transformers來捕捉長距離的依賴關(guān)系,并利用CNN來獲取局部特征。 引入depth-wise卷積,獲取局部特征的同時,減少計算量 使用類似R50模型結(jié)構(gòu)一樣的stageblock,使得模型具有下采樣增強感受野和遷移dense的能力。 使用conv-stem來使得圖像的分辨率縮放從VIT的1/16變?yōu)?/4,保留更多的patch信息。
3. 模型結(jié)構(gòu)

(a)表示的是標(biāo)準(zhǔn)的R50模型,具有4個stage,每個都會進(jìn)行一次下采樣。最后得到特征表達(dá)后,經(jīng)過AvgPool進(jìn)行分類 (b)表示的是標(biāo)準(zhǔn)的VIT模型,先進(jìn)行patch的劃分,然后embeeding后進(jìn)入Transformer的block,這里,由于Transformer是long range的,所以進(jìn)入什么,輸出就是什么,引入了一個非image的class token來做分類。 (c)表示的是本文所提出的模型框架CMT,由CMT-stem, downsampling, cmt block所組成,整體結(jié)構(gòu)則是類似于R50,所以可以很好的遷移到dense任務(wù)上去。
3.1. CMT Stem
使用convolution來作為transformer結(jié)構(gòu)的stem,這個觀點FB也有提出一篇paper,Early Convolutions Help Transformers See Better。
CMT&Conv stem共性
使用4層conv3x3+stride2 + conv1x1 stride 1 等價于VIT的patch embeeding,conv16x16 stride 16. 使用conv stem,可以使模型得到更好的收斂,同時,可以使用SGD優(yōu)化器來訓(xùn)練模型,對于超參數(shù)的依賴沒有原始的那么敏感。好處那是大大的多啊,僅僅是改了一個conv stem。
CMT&Conv stem異性
本文僅僅做了一次conv3x3 stride2,實際上只有一次下采樣,相比conv stem,可以保留更多的patch的信息到下層。
從時間上來說,一個20210628(conv stem), 一個是20210713(CMT stem),存在借鑒的可能性還是比較小的,也說明了conv stem的確是work。
3.2. CMT Block
每一個stage都是由CMT block所堆疊而成的,CMT block由于是transformer結(jié)構(gòu),所以沒有在stage里面去設(shè)計下采樣。每個CMT block都是由Local Perception Unit, Ligntweight MHSA, Inverted Residual FFN這三個模塊所組成的,下面分別介紹:
Local Perception Unit(LPU)
本文的一個核心點是希望模型具有l(wèi)ong-range的能力,同時還要具有l(wèi)ocal特征的能力,所以提出了LPU這個模塊,很簡單,一個3X3的DWconv,來做局部特征,同時減少點計算量,為了讓Transformers的模塊獲取的longrange的信息不缺失,這里做了一個shortcut,公式描述為:
Lightweight MHSA(LMHSA)
MHSA這個不用多說了,多頭注意力,Lightweight這個作用,PVT(鏈接:https://arxiv.org/abs/2102.12122)曾經(jīng)有提出過,目的是為了降低復(fù)雜度,減少計算量。那本文是怎么做的呢,很簡單,假設(shè)我們的輸入為, 對其分別做一個scale,使用卷積核為,stride為的Depth Wise卷積來做了一次下采樣,得到的shape為,那么對應(yīng)的Q,K,V的shape分別為:
我們知道,在計算MHSA的時候要遵守兩個計算原則:
Q, K的序列dim要一致。 K, V的token數(shù)量要一致。
所以,本文中的MHSA計算公式如下:
Inverted Resdiual FFN(IRFFN)

FFN的這個模塊,其實和mbv2的block基本上就是一樣的了,不一樣的地方在于,使用的是GELU,采用的也是DW+PW來減少標(biāo)準(zhǔn)卷積的計算量。很簡單,就不多說了,公式如下:
那么我們一個block里面的整體計算公式如下:
3.3 patch aggregation
每個stage都是由上述的多個CMTblock所堆疊而成, 上面也提到了,這里由于是transformer的操作,不會設(shè)計到scale尺度的問題,但是模型需要構(gòu)造下采樣,來實現(xiàn)層次結(jié)構(gòu),所以downsampling的操作單獨拎了出來,每個stage之前會做一次卷積核為2x2的,stride為2的卷積操作,以達(dá)到下采樣的效果。
所以,整體的模型結(jié)構(gòu)就一目了然了,假設(shè)輸入為224x224x3,經(jīng)過CMT-STEM和第一次下采樣后,得到了一個56x56的featuremap,然后進(jìn)入stage1,輸出不變,經(jīng)過下采樣后,輸入為28x28,進(jìn)入stage2,輸出后經(jīng)過下采樣,輸入為14x14,進(jìn)入stage3,輸出后經(jīng)過最后的下采樣,輸入為7x7,進(jìn)入stage4,最后輸出7x7的特征圖,后面接avgpool和分類,達(dá)到分類的效果。
我們接下來看一下怎么復(fù)現(xiàn)這篇paper。
4. 論文復(fù)現(xiàn)
ps: 這里的復(fù)現(xiàn)指的是沒有源碼的情況下,實現(xiàn)網(wǎng)絡(luò),訓(xùn)練等,如果是結(jié)果復(fù)現(xiàn),會標(biāo)明為復(fù)現(xiàn)精度。
這里存在幾個問題
文章的問題:我看到paper的時候,是第一個版本的arxiv,大概過了一周左右V2版本放出來了,這兩個版本有個很大的diff。Version1
Version2
網(wǎng)絡(luò)結(jié)構(gòu)可以說完全不同的情況下,F(xiàn)LOPs竟然一樣的,當(dāng)然可能是寫錯了,這里就不吐槽了。不過我一開始代碼復(fù)現(xiàn)就是按下面來的,所以對于我也沒影響多少,只是體驗有點差罷了。細(xì)節(jié)的問題:paper和很多的transformer一樣,都是采用了Deit的訓(xùn)練策略,但是差別在于別的paper或多或少會給出來額外的tirck,比如最后FC的dp的ratio等,或者會改變一些,再不濟(jì)會把代碼直接release了,所以只好悶頭嘗試Trick。
4.1 復(fù)現(xiàn)難點
paper里面采用的Position Embeeding和Swin是類似的,都是Relation Position Bias,但是和Swin不相同的是,我們的Q,K,V尺度是不一樣的。這里我考慮了兩種實現(xiàn)方法,一種是直接bicubic插值,另一種則是切片,切片更加直觀且embeeding我設(shè)置的可BP,所以,實現(xiàn)里面采用的是這種方法,代碼如下:
def generate_relative_distance(number_size):
"""return relative distance, (number_size**2, number_size**2, 2)
"""
indices = torch.tensor(np.array([[x, y] for x in range(number_size) for y in range(number_size)]))
distances = indices[None, :, :] - indices[:, None, :]
distances = distances + number_size - 1 # shift the zeros postion
return distances
...
elf.position_embeeding = nn.Parameter(torch.randn(2 * self.features_size - 1, 2 * self.features_size - 1))
...
q_n, k_n = q.shape[1], k.shape[2]
attn = attn + self.position_embeeding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]][:, :k_n]
4.2 復(fù)現(xiàn)trick歷程(血與淚TT)
一方面想要看一下model是否是work的,一方面想要順便驗證一下DeiT的策略是否真的有效,所以從頭開始做了很多的實驗,簡單整理如下:
數(shù)據(jù):
訓(xùn)練數(shù)據(jù): 20%的imagenet訓(xùn)練數(shù)據(jù)(快速實驗)。 驗證數(shù)據(jù): 全量的imagenet驗證數(shù)據(jù)。 環(huán)境:
8xV100 32G CUDA 10.2 + pytorch 1.7.1 sgd優(yōu)化器實驗記錄
| model | augments | resolution | batchsize | epoch | optimizer | LR | strategy | weightdecay | top-1@acc |
|---|---|---|---|---|---|---|---|---|---|
| CMT-TINY | crop+flip | 184->160 | 512X8 | 120 | SGD | 1.6 | cosine | 1.00E-04 | 0.55076 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD | 1.6 | cosine | 1.00E-04 | 0.59714 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup | 184->160 | 512X8 | 120 | SGD | 1.6 | cosine | 1.00E-04 | 0.57034 |
| CMT-TINY | crop+flip+colorjitter+randaug+cutmix | 184->160 | 512X8 | 120 | SGD | 1.6 | cosine | 1.00E-04 | 0.57264 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD | 1.6 | cosine | 5.00E-05 | 0.59452 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup | 184->160 | 512X8 | 200 | SGD | 1.6 | cosine | 1.00E-04 | 0.60532 |
| CMT-TINY | crop+flip+colorjitter+randaug+cutmix | 184->160 | 512X8 | 300 | SGD | 1.6 | cosine | 1.00E-04 | 0.61192 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 200 | SGD | 1.6 | cosine | 5.00E-05 | 0.60172 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD+ape(wrong->resolution) | 1.6 | cosine | 1.00E-04 | 0.60276 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD+rpe | 1.6 | cosine | 1.00E-04 | 0.6016 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD+ape(real->resolution) | 1.6 | cosine | 1.00E-04 | 0.60368 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD+pe_nd | 1.6 | cosine | 1.00E-04 | 0.59494 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD+qkv_bias | 1.6 | cosine | 1.00E-04 | 0.59902 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD+qkv_bias+rpe | 1.6 | cosine | 1.00E-04 | 0.6023 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | SGD+qkv_bias+ape | 1.6 | cosine | 1.00E-04 | 0.5986 |
| CMT-TINY | crop+flip+colorjitter+randaug+no mixup+no_cutmix+labelsmoothing | 184->160 | 512X8 | 300 | SGD+qkv_bias+rpe | 1.6 | cosine | 1.00E-04 | 0.62108 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | SGD+qkv_bias+rpe | 1.6 | cosine | 1.00E-04 | 0.6612 |
結(jié)論: 可以看到在SGD優(yōu)化器的情況下,使用1.6的LR,訓(xùn)練300個epoch,warmup5個epoch,是用cosine衰減學(xué)習(xí)率的策略,用randaug+colorjitter+mixup+cutmix+labelsmooth,設(shè)置weightdecay為0.1的配置下,使用QKV的bias以及相對位置偏差,可以達(dá)到比baseline高11%個點的結(jié)果,所有的實驗都是用FP16跑的。
adamw優(yōu)化器實驗記錄
| model | augments | resolution | batchsize | epoch | optimizer | LR | strategy | weightdecay | top-1@acc |
|---|---|---|---|---|---|---|---|---|---|
| CMT-TINY | crop+flip | 184->160 | 512X8 | 120 | AdamW | 4.00E-03 | cosine | 5.00E-02 | 0.50994 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 300 | AdamW | 4.00E-03 | cosine | 5.00E-02 | 0.57646 |
| CMT-TINY | crop+flip+colorjitter+randaug | 184->160 | 512X8 | 120 | AdamW | 4.00E-03 | cosine | 1.00E-04 | 0.56504 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 4.00E-03 | cosine | 1.00E-04 | 0.63606 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing + repsampler | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 4.00E-03 | cosine | 1.00E-04 | 0.61826 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 4.00E-03 | cosine | 5.00E-02 | 0.64228 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 1.00E-04 | cosine | 5.00E-02 | 0.4049 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing + repsampler | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 4.00E-03 | cosine | 5.00E-02 | 0.63816 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 8.00E-03 | cosine | 5.00E-02 | 不收斂 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 5.00E-03 | cosine | 5.00E-02 | 0.65118 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 6.00E-03 | cosine | 5.00E-02 | 0.65194 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 6.00E-03 | cosine | 5.00E-03 | 0.63726 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 6.00E-03 | cosine | 1.00E-01 | 0.65502 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing+warmup20 | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 6.00E-03 | cosine | 1.00E-01 | 0.65082 |
| CMT-TINY | crop+flip+colorjitter+randaug+mixup+cutmix+labelsmoothing+droppath | 184->160 | 512X8 | 300 | adamw+qkv_bias+rpe | 6.00E-03 | cosine | 1.00E-01 | 0.66908 |
結(jié)論:使用AdamW的情況下,對學(xué)習(xí)率的縮放則是以512的bs為基礎(chǔ),所以對于4k的bs情況下,使用的是4e-3的LR,但是實驗發(fā)現(xiàn)增大到6e-3的時候,還會帶來一些提升,同時放大一點weightsdecay,也略微有所提升,最終使用AdamW的配置為,6e-3的LR,1e-1的weightdecay,和sgd一樣的增強方法,然后加上了隨機(jī)深度失活設(shè)置,最后比baseline高了16%個點,比SGD最好的結(jié)果要高0.8%個點。
4.3. imagenet上的結(jié)果

最后用全量跑,使用SGD會報nan的問題,我定位了一下發(fā)現(xiàn),running_mean和running_std有nan出現(xiàn),本以為是數(shù)據(jù)增強導(dǎo)致的0或者nan值出現(xiàn),結(jié)果空跑幾次數(shù)據(jù)發(fā)現(xiàn)沒問題,只好把優(yōu)化器改成了AdamW,結(jié)果上述所示,CMT-Tiny在160x160的情況下達(dá)到了75.124%的精度,相比MbV2,MbV3的確是一個不錯的精度了,但是相比paper本身的精度還是差了將近4個點,很是離譜。
速度上,CMT雖然FLOPs低,但是實際的推理速度并不快,128的bs條件下,速度慢了R50將近10倍。
5. 實驗結(jié)果
總體來說,CMT達(dá)到了更小的FLOPs同時有著不錯的精度, imagenet上的結(jié)果如下:
coco2017上也有這不錯的精度
6. 結(jié)論
本文提出了一種名為CMT的新型混合架構(gòu),用于視覺識別和其他下游視覺任務(wù),以解決在計算機(jī)視覺領(lǐng)域以粗暴的方式利用Transformers的限制。所提出的CMT同時利用CNN和Transformers的優(yōu)勢來捕捉局部和全局信息,促進(jìn)網(wǎng)絡(luò)的表示能力。在ImageNet和其他下游視覺任務(wù)上進(jìn)行的大量實驗證明了所提出的CMT架構(gòu)的有效性和優(yōu)越性。
代碼復(fù)現(xiàn)repo: https://github.com/FlyEgle/CMT-pytorch, 實現(xiàn)不易,求個star!
- END -
歡迎添加微信,加入GiantPandaCV交流群
