MEA:視覺無(wú)監(jiān)督訓(xùn)練新范式
點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
近日,F(xiàn)AIR的最新論文Masked Autoencoders Are Scalable Vision Learners(何愷明一作)提出了一種更簡(jiǎn)單有效的用于ViT無(wú)監(jiān)督訓(xùn)練的方法MAE,并在ImageNet-1K數(shù)據(jù)集上的top-1 acc達(dá)到新的SOTA:87.8%(無(wú)額外訓(xùn)練數(shù)據(jù))。自從ViT火了之后,一些研究者就開始嘗試研究ViT的無(wú)監(jiān)督學(xué)習(xí),比如Mocov3用對(duì)比學(xué)習(xí)的方法無(wú)監(jiān)督訓(xùn)練ViT,此外也有一些研究開始借鑒BERT中的MLM(masked language modeling)方法,比如BEiT提出了用于圖像的無(wú)監(jiān)督學(xué)習(xí)方法:MIM(masked image modeling)。無(wú)疑,MAE方法也落在MIM的范疇,但整個(gè)論文會(huì)給人更震撼之感,因?yàn)镸EA方法更簡(jiǎn)單有效。
NLP領(lǐng)域的BERT提出的預(yù)訓(xùn)練方法本質(zhì)上也是一種masked autoencoding:去除數(shù)據(jù)的一部分然后學(xué)習(xí)恢復(fù)。這種masked autoencoding方法也很早就在圖像領(lǐng)域應(yīng)用,比如Stacked Denoising Autoencoders。但是NLP領(lǐng)域已經(jīng)在BERT之后采用這種方法在無(wú)監(jiān)督學(xué)習(xí)上取得非常大的進(jìn)展,比如目前已經(jīng)可以訓(xùn)練超過(guò)1000億參數(shù)的大模型,但是圖像領(lǐng)域卻遠(yuǎn)遠(yuǎn)落后,而且目前主流的無(wú)監(jiān)督訓(xùn)練還是對(duì)比學(xué)習(xí)。那么究竟是什么造成了masked autoencoding方法在NLP和CV上的差異呢?MEA論文從三個(gè)方面做了分析,這也是MEA方法的立意:
圖像的主流模型是CNN,而NLP的主流模型是transformer,CNN和transformer的架構(gòu)不同導(dǎo)致NLP的BERT很難直接遷移到CV。但是vision transformer的出現(xiàn)已經(jīng)解決這個(gè)問(wèn)題; 圖像和文本的信息密度不同,文本是高語(yǔ)義的人工創(chuàng)造的符號(hào),而圖像是一種自然信號(hào),兩者采用masked autoencoding建模任務(wù)難度就不一樣,從句子中預(yù)測(cè)丟失的詞本身就是一種復(fù)雜的語(yǔ)言理解任務(wù),但是圖像存在很大的信息冗余,一個(gè)丟失的圖像塊很容易利用周邊的圖像區(qū)域進(jìn)行恢復(fù); 用于重建的decoder在圖像和文本任務(wù)發(fā)揮的角色有區(qū)別,從句子中預(yù)測(cè)單詞屬于高語(yǔ)義任務(wù),encoder和decoder的gap小,所以BERT的decoder部分微不足道(只需要一個(gè)MLP),而對(duì)圖像重建像素屬于低語(yǔ)義任務(wù)(相比圖像分類),encoder需要發(fā)揮更大作用:將高語(yǔ)義的中間表征恢復(fù)成低語(yǔ)義的像素值。
基于這三個(gè)的分析,論文提出了一種用于圖像領(lǐng)域(ViT模型)的更簡(jiǎn)單有效的無(wú)監(jiān)督訓(xùn)練方法:MAE(masked autoencoder),隨機(jī)mask掉部分patchs然后進(jìn)行重建,其整體架構(gòu)如下所示。MAE采用encoder-decoder結(jié)構(gòu)(分析3,需要單獨(dú)的decoder),但屬于非對(duì)稱結(jié)構(gòu),一方面decoder采用比encoder更輕量級(jí)設(shè)計(jì),另外一方面encoder只處理一部分patchs(visible patchs,除了masked patchs之外的patchs),而encoder處理所有的patchs。一個(gè)很重要的點(diǎn),MEA采用很高的masking ratio(比如75%甚至更高),這契合分析2,這樣構(gòu)建的學(xué)習(xí)任務(wù)大大降低了信息冗余,也使得encoder能學(xué)習(xí)到更高級(jí)的特征。由于encoder只處理visible patchs,所以很高的masking ratio可以大大降低計(jì)算量。

MEA采用的masking策略是簡(jiǎn)單的隨機(jī)mask:基于均勻分布從圖像的patchs隨機(jī)抽樣一部分patchs進(jìn)行mask。每個(gè)被mask的patch采用mask token來(lái)替代,mask token是一個(gè)共享且可學(xué)習(xí)的向量。MEA的encoder采用ViT模型,只處理visible patchs,visible patchs通過(guò)linear projection得到patch embedding輸入到ViT的transformer blocks進(jìn)行處理;而decoder是一個(gè)輕量級(jí)模塊,主體包含幾個(gè)transformer blocks,而最后一層是一個(gè)linear層(輸出是和一個(gè)patch像素?cái)?shù)一致),用來(lái)直接預(yù)測(cè)masked patch的像素值。decoder的輸入是所有的tokens:encoded visible patchs和mask tokens,它們要加上對(duì)應(yīng)的positional embeddings。訓(xùn)練的loss采用簡(jiǎn)單的MSE:計(jì)算預(yù)測(cè)像素值和原始像素值的均方誤差,不過(guò)loss只計(jì)算masked patchs。MEA的實(shí)現(xiàn)非常簡(jiǎn)單:首先對(duì)輸入的patch進(jìn)行l(wèi)inear projection得到patch embeddings,并加上positional embeddings(采用sine-cosine版本);然后對(duì)tokens列表進(jìn)行random shuffle,根據(jù)masking ratio去掉列表中后面的一部分tokens,然后送入encoder中,這里注意ViT中需要一個(gè)class token來(lái)做圖像分類,所以這里的輸入也要增加一個(gè)dummy token(如果最后分類采用global avg pooling就不需要這個(gè));encoder處理后,在tokens列表后面補(bǔ)足mask tokens,然后通過(guò)unshuffle來(lái)恢復(fù)tokens列表中tokens的原始位置,然后再加上positional embeddings(mask tokens本身并無(wú)位置信息,所以還要此操作)送入decoder中進(jìn)行處理。
論文選擇ViT-Large(ViT-L/16)作為encoder在ImageNet-1K上實(shí)驗(yàn),首先進(jìn)行無(wú)監(jiān)督預(yù)訓(xùn)練,然后進(jìn)行監(jiān)督訓(xùn)練以評(píng)估encoder的表征能力,包括常用linear probing和finetune兩個(gè)實(shí)驗(yàn)結(jié)果。下表是baseline MEA方法的實(shí)驗(yàn)結(jié)果,可以看到經(jīng)過(guò)MEA預(yù)訓(xùn)練后finetune的效果要超過(guò)直接從頭訓(xùn)練(84.9 vs 82.5):
更重要的是,論文做了MEA各個(gè)部分的不同設(shè)置對(duì)比實(shí)驗(yàn),這些實(shí)驗(yàn)?zāi)軌蚪沂綧EA更多的特性。首先是masking ratio,從下圖可以看到,最優(yōu)的設(shè)置是75%的masking ratio,此時(shí)linear probing和finetune效果最好,這比之前的研究要高很多,比如BEiT的masking ratio是40%。另外也可以看到linear probing和finetune的表現(xiàn)不一樣,linear probing效果隨著masking ratio的增加逐漸提高直至一個(gè)峰值后出現(xiàn)下降,而finetune效果在不同making ratio下差異小,masking ratio在40%~80%范圍內(nèi)均能表現(xiàn)較好。
這么高的masking ratio,模型到底能學(xué)習(xí)到什么?這里采用預(yù)訓(xùn)練好的模型在驗(yàn)證集進(jìn)行重建,效果如下所示,可以看到decoder重建出來(lái)的圖像還是比較讓人驚艷的(95%的masking ratio竟然也能work?。@或許說(shuō)明模型已經(jīng)學(xué)習(xí)到比較好的特征。
第二個(gè)是encoder的設(shè)計(jì),這里主要探討decoder的深度(transformer blocks數(shù)量)和寬度(channels數(shù)量)對(duì)效果的影響,實(shí)驗(yàn)結(jié)果如下表所示。首先,要想得到比較好的linear probing效果,就需要一個(gè)比較深的decoder,這不難理解,前面說(shuō)過(guò)重建圖像和圖像識(shí)別兩個(gè)任務(wù)的gap較大,如果decoder比較深,那么decoder就有足夠的容量學(xué)習(xí)到重建能力,這樣encoder可以更專注于提取特征。但是不同的深度對(duì)finetune效果影響較小,只用一個(gè)transformer block就可以work。相比之下,網(wǎng)絡(luò)寬度對(duì)linear probing影響比網(wǎng)絡(luò)深度要小一點(diǎn)。論文選擇的默認(rèn)設(shè)置是:8個(gè)blocks,width為512,一個(gè)token的FLOPs只有encoder的9%。
第三個(gè)是mask token,這里探討的是encoder是否處理mask tokens帶來(lái)的影響,從對(duì)比實(shí)驗(yàn)來(lái)看,encoder不處理mask tokens不僅效果更好而且訓(xùn)練更高效,首先linear probing的效果差異非常大,如果encoder也處理mask tokens,此時(shí)linear probing的效果較差,這主要是訓(xùn)練和測(cè)試的不一致帶來(lái)的,因?yàn)闇y(cè)試時(shí)都是正常的圖像,但經(jīng)過(guò)finetune后也能得到較好的效果。最重要的是,不處理mask tokens模型的FLOPs大大降低(3.3x),而且訓(xùn)練也能加速2.8倍,這里也可以看到采用較小的decoder可以進(jìn)一步加速訓(xùn)練。
第四個(gè)是探討不同的重建目標(biāo)對(duì)效果的影響,從對(duì)比實(shí)驗(yàn)看,如果對(duì)像素值做歸一化處理(用patch所有像素點(diǎn)的mean和std),效果有一定提升,采用PCA處理效果無(wú)提升。這里也實(shí)驗(yàn)了BEiT采用的dVAE tokenizer,此時(shí)訓(xùn)練loss是交叉熵,從效果上看比baseline有一定提升(finetune有提升,但是linear probing下降),但不如歸一化處理的結(jié)果。注意的是dVAE tokenizer需要非常大的數(shù)據(jù)來(lái)單獨(dú)訓(xùn)練,這是非常不方便的。
第五個(gè)是數(shù)據(jù)增強(qiáng)的影響,這里讓人驚奇的是MEA在無(wú)數(shù)據(jù)增強(qiáng)下(center crop)依然可以表現(xiàn)出好的效果,如果采用random crop(固定size或隨機(jī)size)+random horizontal flipping(其實(shí)也屬于輕量級(jí))效果有微弱的提升,但加上color jit效果反而有所下降。相比之下,對(duì)比學(xué)習(xí)往往需要非常heavy的數(shù)據(jù)增強(qiáng)。這差異的背后主要是因?yàn)镸EA采用的random mask patch已經(jīng)起到了數(shù)據(jù)增強(qiáng)的效果。
第六個(gè)是mask sampling策略的影響,相比BEiT采用的block-wise或grid-wise方式,random sampling效果最好。
另外,論文也發(fā)現(xiàn)MEA和對(duì)比學(xué)習(xí)方法在training schedule上也存在差異,之前的實(shí)驗(yàn)都是基于800 epoch的訓(xùn)練時(shí)長(zhǎng),而實(shí)驗(yàn)發(fā)現(xiàn)訓(xùn)練到更長(zhǎng)的epoch(1600 epoch+),模型的linear probing性能依然還在上升,而MoCoV3在300 epoch后就飽和了。不過(guò),MEA在75%的masking ratio下每個(gè)epoch其實(shí)只相當(dāng)于見了25%的數(shù)據(jù),而對(duì)比學(xué)習(xí)往往學(xué)習(xí)two-crop和multi-crop,每個(gè)epoch見到的數(shù)據(jù)在200%以上,這也意味著MEA可以訓(xùn)練更多的epoch。雖然MEA訓(xùn)練更長(zhǎng),但是由于其特殊的設(shè)置,基于ViT-L的MEA訓(xùn)練1600 epoch的時(shí)長(zhǎng)比MoCoV3訓(xùn)練300 epoch還要短(31h vs 36h)。

MEA與其它無(wú)監(jiān)督方法的對(duì)比如下所示,可以看到在同樣條件下MEA要比BEiT更好,而且也超過(guò)有監(jiān)督訓(xùn)練,其中ViT-H在448大小finetune后在ImageNet上達(dá)到了87.8%的top1 acc。不過(guò)MEA的效果還是比谷歌采用JFT300M訓(xùn)練的ViT要差一些,這說(shuō)明訓(xùn)練數(shù)據(jù)量可能是一個(gè)瓶頸。
同時(shí),論文也對(duì)比了MEA訓(xùn)練的encoder在下游任務(wù)(檢測(cè)和分割)的遷移能力,同等條件下,MEA均能超過(guò)有監(jiān)督訓(xùn)練或者其它無(wú)監(jiān)督訓(xùn)練方法:
論文最后還有一個(gè)額外的部分,那就是對(duì)linear probing評(píng)估方式的討論。從前面的實(shí)驗(yàn)我們看到,雖然MEA訓(xùn)練的encoder在finetune下能取得比較SOTA的結(jié)果,但是其linear probing和finetune效果存在不小的差異,單從linear probing效果來(lái)看,MEA并不比MoCoV3要好(ViT-L:73.5 vs 77.6)。雖然linear probing一直是無(wú)監(jiān)督訓(xùn)練的最常用的評(píng)估方法,但是它追求的是encoder提取特征的線性可分能力,這不并能成為唯一的一個(gè)評(píng)價(jià)指標(biāo),而且linear probing也不能很好地和下游任務(wù)遷移能力關(guān)聯(lián)起來(lái)。所以論文額外做了partial fine-tuning的實(shí)驗(yàn),這里可以看到如果僅對(duì)encoder的最后一個(gè)block進(jìn)行finetune的話,MAE就能達(dá)到和MoCoV3一樣的效果,如果finetune更多的blocks,MAE就會(huì)超過(guò)MoCoV3。這說(shuō)明雖然MAE得到的特征線性可分能力差了點(diǎn),但是它其實(shí)是更強(qiáng)的非線性特征。
最后談一點(diǎn)自己對(duì)MEA的認(rèn)識(shí):首先MEA并不是第一個(gè)基于MIM方法做無(wú)監(jiān)督訓(xùn)練,之前微軟的BEiT基于MIM也取得了很好的效果,還有MST和iBOT等工作。但是MEA讓人看起來(lái)更簡(jiǎn)單有效,比如BEiT需要單獨(dú)訓(xùn)練的tokenizer,而其它的一些工作往往引入了對(duì)比學(xué)習(xí)的類似設(shè)計(jì)。對(duì)于MEA的成功,我覺得是一些突破常規(guī)的設(shè)計(jì),比如很高的masking ratio,這是很難想象會(huì)work的,但MEA卻證明了這是成功的關(guān)鍵。
參考
Mocov3: An Empirical Study of Training Self-Supervised Vision Transformers DINO: Emerging Properties in Self-Supervised Vision Transformers MST: Masked Self-Supervised Transformer for Visual Representation BEiT: BERT Pre-Training of Image Transformers EsViT: Efficient Self-supervised Vision Transformers for Representation Learning Image BERT Pre-training with Online Tokenizer Masked Autoencoders Are Scalable Vision Learners
推薦閱讀
RegNet:設(shè)計(jì)網(wǎng)絡(luò)設(shè)計(jì)空間
PyTorch1.10發(fā)布:ZeroRedundancyOptimizer和Join
谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!
Transformer在語(yǔ)義分割上的應(yīng)用
"未來(lái)"的經(jīng)典之作ViT:transformer is all you need!
PVT:可用于密集任務(wù)backbone的金字塔視覺transformer!
漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA
不妨試試MoCo,來(lái)替換ImageNet上pretrain模型!
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)

