SimMIM:一種更簡(jiǎn)單的MIM方法
點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
自從何愷明的MEA:視覺(jué)無(wú)監(jiān)督訓(xùn)練新范式出來(lái)之后,基于MIM(Masked Image Modeling)的無(wú)監(jiān)督學(xué)習(xí)方法越來(lái)越受到關(guān)注。這里介紹一篇和MAE同期的工作:SimMIM: A Simple Framework for Masked Image Modeling,研究團(tuán)隊(duì)是微軟亞研院。SimMIM和MAE有很多相似的設(shè)計(jì)和結(jié)論,而且效果也比較接近,比如基于ViT-B的模型無(wú)監(jiān)督訓(xùn)練后再finetune可以ImageNet數(shù)據(jù)集達(dá)到83.8%的top1 accuray(MAE為83.6%)。不過(guò)相比MAE,SimMIM更加簡(jiǎn)單,而且也可以用來(lái)無(wú)監(jiān)督訓(xùn)練金字塔結(jié)構(gòu)的vision transformer模型如swin transformer等。目前SimMIM實(shí)現(xiàn)代碼已經(jīng)開(kāi)源,本文將基于論文和源碼對(duì)SimMIM方法進(jìn)行解讀。
算法原理
SimMIM采用最簡(jiǎn)單的MIM方法:隨機(jī)mask掉輸入圖像的一部分patch,然后通過(guò)encoder-decoder來(lái)預(yù)測(cè)masked patchs的原始像素值。算法原理圖如上圖所示,從設(shè)計(jì)方面和MAE基本一致。SimMIM的主要結(jié)論如下:
直接對(duì)圖像采用簡(jiǎn)單的random mask是非常簡(jiǎn)單有效的方法; 直接回歸原始的像素的RGB值不比BEiT采用的分類(lèi)效果差; decoder采用輕量級(jí)的設(shè)計(jì)(直接采用一個(gè)線(xiàn)性層)也能得到很好的效果;
這些結(jié)論也是在MAE論文中得到了驗(yàn)證。那么SimMIM和MAE的區(qū)別在哪里呢?主要有以下兩點(diǎn):
SimMIM的encoder同時(shí)處理visible tokens和masked tokens,而MAE的encoder只處理visible tokens; SimMIM的decoder只采用一個(gè)線(xiàn)性層來(lái)回歸像素值,而MAE的decoder采用transformer結(jié)構(gòu);
第2個(gè)差異帶來(lái)的影響相對(duì)很小,因?yàn)閮蓚€(gè)論文都證明了decoder設(shè)計(jì)對(duì)性能影響較小。主要的差異點(diǎn)是第一個(gè),MAE訓(xùn)練時(shí)只處理visible tokens一方面可以加速訓(xùn)練(減少了計(jì)算量),同時(shí)也可以減少pre-training和deploy之間的gap(deploy時(shí)輸入是非masked的圖像,無(wú)masked token),MAE實(shí)驗(yàn)也證明只處理visible tokens可以提升linear probing性能:73.5% vs 59.6%。而SimMIM是處理所有的tokens,從實(shí)驗(yàn)結(jié)果上看也符合MAE的結(jié)論,SimMIM方法得到的ViT-B模型的linear probing只有56.7%,不過(guò)這不并不會(huì)影響finetune后的性能,關(guān)于這點(diǎn)MAE論文也論證了。不過(guò)SimMIM這樣做帶來(lái)的一個(gè)好處是可以用來(lái)訓(xùn)練其它非“同質(zhì)結(jié)構(gòu)”模型,比如swin transformer,由于它各個(gè)stage間要對(duì)patch進(jìn)行merge操作,所以token并不是像ViT那樣一成不變的。下面我們具體介紹SimMIM的各個(gè)部分,這里默認(rèn)實(shí)驗(yàn)都是以Swin-B為encoder,為了減少實(shí)驗(yàn)成本,輸入圖像大小為192x192(原來(lái)是224),window size設(shè)置為6(原來(lái)是7),預(yù)訓(xùn)練epoch為100。
Masking Strategy
SimMIM的masking策略按照一定mask ratio隨機(jī)mask掉一部分patch。在MAE中,masked patch size和ViT的patch size是一致的,比如ViT-B/16模型,masked patch size就要設(shè)計(jì)為16x16,然后用一個(gè)可學(xué)習(xí)的masked token來(lái)代替。但是對(duì)于SimMIM,其設(shè)計(jì)masked patch size不一定等于模型的patch size,比如ViT模型masked patch size可以是32x32,理論上mask patch size只要是ViT模型patch size的整數(shù)倍就可以,因此此時(shí)每個(gè)mask掉的patch可以整分成和模型patch一樣大小的若干個(gè)patch。對(duì)于金字塔結(jié)構(gòu)的swin transformer,每個(gè)stage的patch size是不同的,比如第一個(gè)stage的patch size是4x4,而最后一個(gè)stage的patch size是32x32,此時(shí)設(shè)計(jì)的mask patch size只需要是第一個(gè)stage的patch size整數(shù)就好。無(wú)論是ViT還是swin transformer,masked token對(duì)應(yīng)的patch size都是其patch embedding層對(duì)應(yīng)的patch size,對(duì)于ViT就是16x16,而對(duì)于swin transformer就是4x4,而mask patch size只需要是masked token的patch size的整數(shù)倍即可。所以SimMIM采用更靈活的mask patch size,不同mask patch size的可視化效果如下圖所示。對(duì)于ViT和swin transformer,SimMIM都默認(rèn)采用:mask ratio=0.6,mask patch size=32x32。
不同的mask type,mask patch size和mask ratio對(duì)模型效果(finetune)的影響如下表所示,可以看到不同的設(shè)置均可以取得類(lèi)似的效果,其中random+masked patch size=32x32+mask ratio=0.5可取得最優(yōu)的效果83.0%。
從表中可以看出,采用較小的masked patch size(4x4,8x8,16x16),模型效果隨著mask ratio的增加而提升,而對(duì)于更大的masked patch size(64x64),需要采用較小的mask ratio才能得到較好的結(jié)果。masked patch size和mask ratio影響的是MIM任務(wù)的難度,兩者越大,MIM任務(wù)越難,要想取得較好的模型訓(xùn)練效果,MIM任務(wù)的難度要適當(dāng)大一些。論文也提出了AvgDist指標(biāo)來(lái)進(jìn)一步分析masked patch size和mask ratio對(duì)模型finetune效果的影響,這里AvgDist指標(biāo)計(jì)算的是所有masked pixels到最近的visible pixels的平均歐式距離,它綜合了masked patch size和mask ratio對(duì)MIM任務(wù)的影響。從下圖可以看出,AvgDist隨著mask ratio的增加而增加,對(duì)于較小的masked patch size,其AvgDist在較大的mask ratio下依然較小,而較大的masked patch size,其AvgDist在較小的mask ratio下就比較大。從右圖可以看出,AvgDist在[10, 20]區(qū)間內(nèi)都可以取得較好的finetune效果,這個(gè)可以用來(lái)指導(dǎo)選擇不同masked patch size和mask ratio組合。
采用不同的masked patch size,其預(yù)測(cè)的圖像效果如下所示,可以看到masked patch size越小,圖像還原度越高,這也比較合理。但是MIM本身并不是為了更好地恢復(fù)圖像,而是希望encoder學(xué)習(xí)到好的特征以遷移到下游任務(wù)。
隨機(jī)mask策略的實(shí)現(xiàn)比較簡(jiǎn)單,在對(duì)每個(gè)圖像進(jìn)行數(shù)據(jù)增強(qiáng)后,同時(shí)隨機(jī)生成一個(gè)mask;在模型forward時(shí),將masked patch用mask token來(lái)替換,注意由于masked patch size和model_patch_size不一定相等,所以要將隨機(jī)生成mask轉(zhuǎn)換成和model_patch_size一致的mask。具體實(shí)現(xiàn)代碼如下所示:
class?MaskGenerator:
????def?__init__(self,?input_size=192,?mask_patch_size=32,?model_patch_size=4,?mask_ratio=0.6):
????????self.input_size?=?input_size?#?輸入圖像大小
????????self.mask_patch_size?=?mask_patch_size?#?masked?patch大小
????????self.model_patch_size?=?model_patch_size?#?模型patch?embed層的patch大小
????????self.mask_ratio?=?mask_ratio
????????
????????assert?self.input_size?%?self.mask_patch_size?==?0
????????assert?self.mask_patch_size?%?self.model_patch_size?==?0
????????
????????self.rand_size?=?self.input_size?//?self.mask_patch_size
????????self.scale?=?self.mask_patch_size?//?self.model_patch_size
????????
????????self.token_count?=?self.rand_size?**?2
????????self.mask_count?=?int(np.ceil(self.token_count?*?self.mask_ratio))
????????
????def?__call__(self):
????????mask_idx?=?np.random.permutation(self.token_count)[:self.mask_count]
????????mask?=?np.zeros(self.token_count,?dtype=int)
????????mask[mask_idx]?=?1
????????
????????#?要轉(zhuǎn)換成和model_patch?size一致的mask
????????mask?=?mask.reshape((self.rand_size,?self.rand_size))
????????mask?=?mask.repeat(self.scale,?axis=0).repeat(self.scale,?axis=1)
????????
????????return?mask
????
class?SimMIMTransform:
????def?__init__(self,?config):
????????self.transform_img?=?T.Compose([
????????????T.Lambda(lambda?img:?img.convert('RGB')?if?img.mode?!=?'RGB'?else?img),
????????????T.RandomResizedCrop(config.DATA.IMG_SIZE,?scale=(0.67,?1.),?ratio=(3.?/?4.,?4.?/?3.)),
????????????T.RandomHorizontalFlip(),
????????????T.ToTensor(),
????????????T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
????????])
?
????????if?config.MODEL.TYPE?==?'swin':
????????????model_patch_size=config.MODEL.SWIN.PATCH_SIZE
????????elif?config.MODEL.TYPE?==?'vit':
????????????model_patch_size=config.MODEL.VIT.PATCH_SIZE
????????else:
????????????raise?NotImplementedError
????????
????????self.mask_generator?=?MaskGenerator(
????????????input_size=config.DATA.IMG_SIZE,
????????????mask_patch_size=config.DATA.MASK_PATCH_SIZE,
????????????model_patch_size=model_patch_size,
????????????mask_ratio=config.DATA.MASK_RATIO,
????????)
????
????def?__call__(self,?img):
????????img?=?self.transform_img(img)?#?圖像數(shù)據(jù)增強(qiáng)
????????mask?=?self.mask_generator()?#?生成mask
????????
????????return?img,?mask
????
?class?SwinTransformerForSimMIM(SwinTransformer):
????def?__init__(self,?**kwargs):
????????super().__init__(**kwargs)
????????assert?self.num_classes?==?0
????????self.mask_token?=?nn.Parameter(torch.zeros(1,?1,?self.embed_dim))
????????trunc_normal_(self.mask_token,?mean=0.,?std=.02)
????def?forward(self,?x,?mask):
????????x?=?self.patch_embed(x)
????????assert?mask?is?not?None
????????B,?L,?_?=?x.shape
????????mask_tokens?=?self.mask_token.expand(B,?L,?-1)
????????w?=?mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
????????x?=?x?*?(1.?-?w)?+?mask_tokens?*?w
????????if?self.ape:
????????????x?=?x?+?self.absolute_pos_embed
????????x?=?self.pos_drop(x)
????????for?layer?in?self.layers:
????????????x?=?layer(x)
????????x?=?self.norm(x)
????????x?=?x.transpose(1,?2)
????????B,?C,?L?=?x.shape
????????H?=?W?=?int(L?**?0.5)
????????x?=?x.reshape(B,?C,?H,?W)
????????return?x
????
#?基于swinT的SimMIM
class?SwinTransformerForSimMIM(SwinTransformer):
????def?__init__(self,?**kwargs):
????????super().__init__(**kwargs)
????????assert?self.num_classes?==?0
??
????????#?定義可學(xué)習(xí)的masked?token
????????self.mask_token?=?nn.Parameter(torch.zeros(1,?1,?self.embed_dim))
????????trunc_normal_(self.mask_token,?mean=0.,?std=.02)
????def?forward(self,?x,?mask):
????????x?=?self.patch_embed(x)
????????assert?mask?is?not?None
????????B,?L,?_?=?x.shape
????????mask_tokens?=?self.mask_token.expand(B,?L,?-1)
????????w?=?mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
????????x?=?x?*?(1.?-?w)?+?mask_tokens?*?w?#?用masked?token替換masked?patch對(duì)應(yīng)的patch?embedding
????????if?self.ape:
????????????x?=?x?+?self.absolute_pos_embed
????????x?=?self.pos_drop(x)
????????for?layer?in?self.layers:
????????????x?=?layer(x)
????????x?=?self.norm(x)
????????x?=?x.transpose(1,?2)
????????B,?C,?L?=?x.shape
????????H?=?W?=?int(L?**?0.5)
????????x?=?x.reshape(B,?C,?H,?W)
????????return?x
Prediction Head
這里的prediction head指的就是decoder,用來(lái)預(yù)測(cè)masked patch的原始像素值。論文發(fā)現(xiàn)采用一個(gè)非常輕量級(jí)的decoder(只用1個(gè)linear層)就非常有效。采用更復(fù)雜的head,效果沒(méi)有提升,反而會(huì)增加訓(xùn)練成本。MAE也指出decoder的設(shè)計(jì)對(duì)finetune性能影響較小,但是卻會(huì)影響linear probing效果,如果采用較輕的decoder,那么encoder的后面部分層就要承擔(dān)一部分像素預(yù)測(cè)任務(wù)(無(wú)監(jiān)督訓(xùn)練代理任務(wù)),但這個(gè)卻不是圖像分類(lèi)任務(wù)所需要的,所以會(huì)帶來(lái)linear probing的下降,所以如果要想得到比較好的linear probing效果,就需要設(shè)計(jì)一個(gè)適當(dāng)?shù)膁ecoder以將預(yù)測(cè)任務(wù)集中在decoder上。
SimMIM默認(rèn)采用單個(gè)linear層來(lái)預(yù)測(cè)像素值,在實(shí)現(xiàn)上采用一個(gè)1x1卷積層。對(duì)于swin transformer,其得到的特征圖(恢復(fù)成hxw)是原來(lái)圖像的1/32大小,那么卷積層輸出channels等于3072=32x32x3,每個(gè)特征點(diǎn)預(yù)測(cè)32x32個(gè)pixels的RGB值。
Prediction Tragets
SimMIM是直接回歸masked patch的原始像素值,所以target就是原始圖像的RGB值,而回歸損失采用L1 loss,注意這里和MAE一樣,只計(jì)算masked pixels的損失,論文也發(fā)現(xiàn)如果對(duì)所有pixels計(jì)算loss,效果會(huì)下降(82.8% -> 81.7%),prediction而不是reconstruction能更好地讓encoder學(xué)習(xí)到更強(qiáng)的特征。另外一個(gè)參數(shù)是prediction resolution,SimMIM默認(rèn)的prediction resolution是原始圖像大小,但也可以對(duì)原始圖像進(jìn)行下采樣,從而降低prediction resolution,從實(shí)驗(yàn)結(jié)果來(lái)看,采用不同的prediction resolution均能得到較好的結(jié)果,除了1/32表現(xiàn)相對(duì)差一些(圖像損失比較嚴(yán)重):
論文也對(duì)比了其它類(lèi)型的targets,比如像BEiT那樣用dVAE將回歸變成分類(lèi)任務(wù),或者像IGPT那樣采用color clustering。從下表的對(duì)比結(jié)果可以看到直接回歸像素值并不比這些更復(fù)雜的設(shè)計(jì)差。
loss計(jì)算部分的實(shí)現(xiàn)也比較簡(jiǎn)單,具體的代碼如下所示(注意這里回歸的像素值是歸一化后的像素值):
class?SimMIM(nn.Module):
????def?__init__(self,?encoder,?encoder_stride):
????????super().__init__()
????????self.encoder?=?encoder
????????self.encoder_stride?=?encoder_stride
??
????????#?定義encoder
????????self.decoder?=?nn.Sequential(
????????????nn.Conv2d(
????????????????in_channels=self.encoder.num_features,
????????????????out_channels=self.encoder_stride?**?2?*?3,?kernel_size=1),?#?1x1?conv等價(jià)于linear
????????????nn.PixelShuffle(self.encoder_stride),?#?[B,?3*r*r,?h,?w]?->?[B,?3,?h*r,?w*r]
????????)
????????self.in_chans?=?self.encoder.in_chans
????????self.patch_size?=?self.encoder.patch_size
????def?forward(self,?x,?mask):
????????z?=?self.encoder(x,?mask)?#?encoder提取特征
????????x_rec?=?self.decoder(z)?#?decoder預(yù)測(cè)圖像
??
????????#?mask轉(zhuǎn)變成和原始圖像一樣大小
????????mask?=?mask.repeat_interleave(self.patch_size,?1).repeat_interleave(self.patch_size,?2).unsqueeze(1).contiguous()
????????loss_recon?=?F.l1_loss(x,?x_rec,?reduction='none')?#?L1?loss?
????????loss?=?(loss_recon?*?mask).sum()?/?(mask.sum()?+?1e-5)?/?self.in_chans?#?只計(jì)算masked?pixels并取mean
????????return?loss
實(shí)驗(yàn)設(shè)置及對(duì)比結(jié)果
前面的實(shí)驗(yàn)都是以Swin-B為backbone,預(yù)訓(xùn)練的epoch為100,而最后的實(shí)驗(yàn)訓(xùn)練800個(gè)epoch,batch size為2048。在數(shù)據(jù)增強(qiáng)方面,只采用random resize croping:RandomResizedCrop(192, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.))以及水平翻轉(zhuǎn),和MAE一樣屬于輕量級(jí)的數(shù)據(jù)增強(qiáng),這說(shuō)明MIM方法確實(shí)不像對(duì)比學(xué)習(xí)那樣需要較heavy的數(shù)據(jù)增強(qiáng)。對(duì)于ViT,預(yù)訓(xùn)練的圖像大小是224,而SwinT采用的圖像大小為192,對(duì)比結(jié)果如下表所示??梢钥吹剑?/p>
基于SimMIM訓(xùn)練的ViT-B優(yōu)于BEiT方法(83.8 vs 83.2),訓(xùn)練成本也比較低,但是linear probing效果均比較差(56.7); 基于SimMIM預(yù)訓(xùn)練的SwinT也優(yōu)于有監(jiān)督訓(xùn)練的模型,對(duì)于Swin-B,預(yù)訓(xùn)練800epoch相比100epoch有一定提升(82.8 vs 84.0),這里也包含SwinV2的實(shí)驗(yàn),其中30億參數(shù)的SwinV2-G的效果可達(dá)到90.2%。

下圖是一些masked圖像重建后的可視化,可以看出經(jīng)過(guò)SimMIM訓(xùn)練后,模型能學(xué)習(xí)到一定的推理能力,比如mask掉一個(gè)物體或者人后,模型能學(xué)會(huì)補(bǔ)全背景。
小結(jié)
總結(jié)來(lái)看,SimMIM和MAE方法大致相同,兩者的差異大概源自SimMIM是為Swin設(shè)計(jì)的,而MAE是為單純的ViT結(jié)構(gòu)設(shè)計(jì)的。一個(gè)缺憾是SimMIM方法雖然在SwinV2上做了驗(yàn)證,但是沒(méi)有直接在下游檢測(cè)和分割任務(wù)上的對(duì)比實(shí)驗(yàn),而MAE方法在隨后的工作Benchmarking Detection Transfer Learning with Vision Transformers中論證了其遷移到實(shí)例分割任務(wù)上的有效性。
參考
Masked Autoencoders Are Scalable Vision Learners SimMIM: A Simple Framework for Masked Image Modeling https://github.com/microsoft/SimMIM
推薦閱讀
RegNet:設(shè)計(jì)網(wǎng)絡(luò)設(shè)計(jì)空間
PyTorch1.10發(fā)布:ZeroRedundancyOptimizer和Join
谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!
Transformer在語(yǔ)義分割上的應(yīng)用
"未來(lái)"的經(jīng)典之作ViT:transformer is all you need!
PVT:可用于密集任務(wù)backbone的金字塔視覺(jué)transformer!
漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA
不妨試試MoCo,來(lái)替換ImageNet上pretrain模型!
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)

