1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        SimMIM:一種更簡(jiǎn)單的MIM方法

        共 10895字,需瀏覽 22分鐘

         ·

        2021-12-14 13:12

        點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師

        設(shè)為星標(biāo),干貨直達(dá)!


        自從何愷明的MEA:視覺(jué)無(wú)監(jiān)督訓(xùn)練新范式出來(lái)之后,基于MIM(Masked Image Modeling)的無(wú)監(jiān)督學(xué)習(xí)方法越來(lái)越受到關(guān)注。這里介紹一篇和MAE同期的工作:SimMIM: A Simple Framework for Masked Image Modeling,研究團(tuán)隊(duì)是微軟亞研院。SimMIM和MAE有很多相似的設(shè)計(jì)和結(jié)論,而且效果也比較接近,比如基于ViT-B的模型無(wú)監(jiān)督訓(xùn)練后再finetune可以ImageNet數(shù)據(jù)集達(dá)到83.8%的top1 accuray(MAE為83.6%)。不過(guò)相比MAE,SimMIM更加簡(jiǎn)單,而且也可以用來(lái)無(wú)監(jiān)督訓(xùn)練金字塔結(jié)構(gòu)的vision transformer模型如swin transformer等。目前SimMIM實(shí)現(xiàn)代碼已經(jīng)開(kāi)源,本文將基于論文和源碼對(duì)SimMIM方法進(jìn)行解讀。

        算法原理

        SimMIM采用最簡(jiǎn)單的MIM方法:隨機(jī)mask掉輸入圖像的一部分patch,然后通過(guò)encoder-decoder來(lái)預(yù)測(cè)masked patchs的原始像素值。算法原理圖如上圖所示,從設(shè)計(jì)方面和MAE基本一致。SimMIM的主要結(jié)論如下:

        • 直接對(duì)圖像采用簡(jiǎn)單的random mask是非常簡(jiǎn)單有效的方法;
        • 直接回歸原始的像素的RGB值不比BEiT采用的分類(lèi)效果差;
        • decoder采用輕量級(jí)的設(shè)計(jì)(直接采用一個(gè)線(xiàn)性層)也能得到很好的效果;

        這些結(jié)論也是在MAE論文中得到了驗(yàn)證。那么SimMIM和MAE的區(qū)別在哪里呢?主要有以下兩點(diǎn):

        • SimMIM的encoder同時(shí)處理visible tokens和masked tokens,而MAE的encoder只處理visible tokens;
        • SimMIM的decoder只采用一個(gè)線(xiàn)性層來(lái)回歸像素值,而MAE的decoder采用transformer結(jié)構(gòu);

        第2個(gè)差異帶來(lái)的影響相對(duì)很小,因?yàn)閮蓚€(gè)論文都證明了decoder設(shè)計(jì)對(duì)性能影響較小。主要的差異點(diǎn)是第一個(gè),MAE訓(xùn)練時(shí)只處理visible tokens一方面可以加速訓(xùn)練(減少了計(jì)算量),同時(shí)也可以減少pre-training和deploy之間的gap(deploy時(shí)輸入是非masked的圖像,無(wú)masked token),MAE實(shí)驗(yàn)也證明只處理visible tokens可以提升linear probing性能:73.5% vs 59.6%。而SimMIM是處理所有的tokens,從實(shí)驗(yàn)結(jié)果上看也符合MAE的結(jié)論,SimMIM方法得到的ViT-B模型的linear probing只有56.7%,不過(guò)這不并不會(huì)影響finetune后的性能,關(guān)于這點(diǎn)MAE論文也論證了。不過(guò)SimMIM這樣做帶來(lái)的一個(gè)好處是可以用來(lái)訓(xùn)練其它非“同質(zhì)結(jié)構(gòu)”模型,比如swin transformer,由于它各個(gè)stage間要對(duì)patch進(jìn)行merge操作,所以token并不是像ViT那樣一成不變的。下面我們具體介紹SimMIM的各個(gè)部分,這里默認(rèn)實(shí)驗(yàn)都是以Swin-B為encoder,為了減少實(shí)驗(yàn)成本,輸入圖像大小為192x192(原來(lái)是224),window size設(shè)置為6(原來(lái)是7),預(yù)訓(xùn)練epoch為100。

        Masking Strategy

        SimMIM的masking策略按照一定mask ratio隨機(jī)mask掉一部分patch。在MAE中,masked patch size和ViT的patch size是一致的,比如ViT-B/16模型,masked patch size就要設(shè)計(jì)為16x16,然后用一個(gè)可學(xué)習(xí)的masked token來(lái)代替。但是對(duì)于SimMIM,其設(shè)計(jì)masked patch size不一定等于模型的patch size,比如ViT模型masked patch size可以是32x32,理論上mask patch size只要是ViT模型patch size的整數(shù)倍就可以,因此此時(shí)每個(gè)mask掉的patch可以整分成和模型patch一樣大小的若干個(gè)patch。對(duì)于金字塔結(jié)構(gòu)的swin transformer,每個(gè)stage的patch size是不同的,比如第一個(gè)stage的patch size是4x4,而最后一個(gè)stage的patch size是32x32,此時(shí)設(shè)計(jì)的mask patch size只需要是第一個(gè)stage的patch size整數(shù)就好。無(wú)論是ViT還是swin transformer,masked token對(duì)應(yīng)的patch size都是其patch embedding層對(duì)應(yīng)的patch size,對(duì)于ViT就是16x16,而對(duì)于swin transformer就是4x4,而mask patch size只需要是masked token的patch size的整數(shù)倍即可。所以SimMIM采用更靈活的mask patch size,不同mask patch size的可視化效果如下圖所示。對(duì)于ViT和swin transformer,SimMIM都默認(rèn)采用:mask ratio=0.6,mask patch size=32x32。不同的mask type,mask patch size和mask ratio對(duì)模型效果(finetune)的影響如下表所示,可以看到不同的設(shè)置均可以取得類(lèi)似的效果,其中random+masked patch size=32x32+mask ratio=0.5可取得最優(yōu)的效果83.0%。從表中可以看出,采用較小的masked patch size(4x4,8x8,16x16),模型效果隨著mask ratio的增加而提升,而對(duì)于更大的masked patch size(64x64),需要采用較小的mask ratio才能得到較好的結(jié)果。masked patch size和mask ratio影響的是MIM任務(wù)的難度,兩者越大,MIM任務(wù)越難,要想取得較好的模型訓(xùn)練效果,MIM任務(wù)的難度要適當(dāng)大一些。論文也提出了AvgDist指標(biāo)來(lái)進(jìn)一步分析masked patch size和mask ratio對(duì)模型finetune效果的影響,這里AvgDist指標(biāo)計(jì)算的是所有masked pixels到最近的visible pixels的平均歐式距離,它綜合了masked patch size和mask ratio對(duì)MIM任務(wù)的影響。從下圖可以看出,AvgDist隨著mask ratio的增加而增加,對(duì)于較小的masked patch size,其AvgDist在較大的mask ratio下依然較小,而較大的masked patch size,其AvgDist在較小的mask ratio下就比較大。從右圖可以看出,AvgDist在[10, 20]區(qū)間內(nèi)都可以取得較好的finetune效果,這個(gè)可以用來(lái)指導(dǎo)選擇不同masked patch size和mask ratio組合。

        采用不同的masked patch size,其預(yù)測(cè)的圖像效果如下所示,可以看到masked patch size越小,圖像還原度越高,這也比較合理。但是MIM本身并不是為了更好地恢復(fù)圖像,而是希望encoder學(xué)習(xí)到好的特征以遷移到下游任務(wù)。

        隨機(jī)mask策略的實(shí)現(xiàn)比較簡(jiǎn)單,在對(duì)每個(gè)圖像進(jìn)行數(shù)據(jù)增強(qiáng)后,同時(shí)隨機(jī)生成一個(gè)mask;在模型forward時(shí),將masked patch用mask token來(lái)替換,注意由于masked patch size和model_patch_size不一定相等,所以要將隨機(jī)生成mask轉(zhuǎn)換成和model_patch_size一致的mask。具體實(shí)現(xiàn)代碼如下所示:

        class?MaskGenerator:
        ????def?__init__(self,?input_size=192,?mask_patch_size=32,?model_patch_size=4,?mask_ratio=0.6):
        ????????self.input_size?=?input_size?#?輸入圖像大小
        ????????self.mask_patch_size?=?mask_patch_size?#?masked?patch大小
        ????????self.model_patch_size?=?model_patch_size?#?模型patch?embed層的patch大小
        ????????self.mask_ratio?=?mask_ratio
        ????????
        ????????assert?self.input_size?%?self.mask_patch_size?==?0
        ????????assert?self.mask_patch_size?%?self.model_patch_size?==?0
        ????????
        ????????self.rand_size?=?self.input_size?//?self.mask_patch_size
        ????????self.scale?=?self.mask_patch_size?//?self.model_patch_size
        ????????
        ????????self.token_count?=?self.rand_size?**?2
        ????????self.mask_count?=?int(np.ceil(self.token_count?*?self.mask_ratio))
        ????????
        ????def?__call__(self):
        ????????mask_idx?=?np.random.permutation(self.token_count)[:self.mask_count]
        ????????mask?=?np.zeros(self.token_count,?dtype=int)
        ????????mask[mask_idx]?=?1
        ????????
        ????????#?要轉(zhuǎn)換成和model_patch?size一致的mask
        ????????mask?=?mask.reshape((self.rand_size,?self.rand_size))
        ????????mask?=?mask.repeat(self.scale,?axis=0).repeat(self.scale,?axis=1)
        ????????
        ????????return?mask
        ????

        class?SimMIMTransform:
        ????def?__init__(self,?config):
        ????????self.transform_img?=?T.Compose([
        ????????????T.Lambda(lambda?img:?img.convert('RGB')?if?img.mode?!=?'RGB'?else?img),
        ????????????T.RandomResizedCrop(config.DATA.IMG_SIZE,?scale=(0.67,?1.),?ratio=(3.?/?4.,?4.?/?3.)),
        ????????????T.RandomHorizontalFlip(),
        ????????????T.ToTensor(),
        ????????????T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
        ????????])
        ?
        ????????if?config.MODEL.TYPE?==?'swin':
        ????????????model_patch_size=config.MODEL.SWIN.PATCH_SIZE
        ????????elif?config.MODEL.TYPE?==?'vit':
        ????????????model_patch_size=config.MODEL.VIT.PATCH_SIZE
        ????????else:
        ????????????raise?NotImplementedError
        ????????
        ????????self.mask_generator?=?MaskGenerator(
        ????????????input_size=config.DATA.IMG_SIZE,
        ????????????mask_patch_size=config.DATA.MASK_PATCH_SIZE,
        ????????????model_patch_size=model_patch_size,
        ????????????mask_ratio=config.DATA.MASK_RATIO,
        ????????)
        ????
        ????def?__call__(self,?img):
        ????????img?=?self.transform_img(img)?#?圖像數(shù)據(jù)增強(qiáng)
        ????????mask?=?self.mask_generator()?#?生成mask
        ????????
        ????????return?img,?mask
        ????
        ?class?SwinTransformerForSimMIM(SwinTransformer):
        ????def?__init__(self,?**kwargs):
        ????????super().__init__(**kwargs)

        ????????assert?self.num_classes?==?0

        ????????self.mask_token?=?nn.Parameter(torch.zeros(1,?1,?self.embed_dim))
        ????????trunc_normal_(self.mask_token,?mean=0.,?std=.02)

        ????def?forward(self,?x,?mask):
        ????????x?=?self.patch_embed(x)

        ????????assert?mask?is?not?None
        ????????B,?L,?_?=?x.shape

        ????????mask_tokens?=?self.mask_token.expand(B,?L,?-1)
        ????????w?=?mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
        ????????x?=?x?*?(1.?-?w)?+?mask_tokens?*?w

        ????????if?self.ape:
        ????????????x?=?x?+?self.absolute_pos_embed
        ????????x?=?self.pos_drop(x)

        ????????for?layer?in?self.layers:
        ????????????x?=?layer(x)
        ????????x?=?self.norm(x)

        ????????x?=?x.transpose(1,?2)
        ????????B,?C,?L?=?x.shape
        ????????H?=?W?=?int(L?**?0.5)
        ????????x?=?x.reshape(B,?C,?H,?W)
        ????????return?x
        ????
        #?基于swinT的SimMIM
        class?SwinTransformerForSimMIM(SwinTransformer):
        ????def?__init__(self,?**kwargs):
        ????????super().__init__(**kwargs)

        ????????assert?self.num_classes?==?0
        ??
        ????????#?定義可學(xué)習(xí)的masked?token
        ????????self.mask_token?=?nn.Parameter(torch.zeros(1,?1,?self.embed_dim))
        ????????trunc_normal_(self.mask_token,?mean=0.,?std=.02)

        ????def?forward(self,?x,?mask):
        ????????x?=?self.patch_embed(x)

        ????????assert?mask?is?not?None
        ????????B,?L,?_?=?x.shape

        ????????mask_tokens?=?self.mask_token.expand(B,?L,?-1)
        ????????w?=?mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
        ????????x?=?x?*?(1.?-?w)?+?mask_tokens?*?w?#?用masked?token替換masked?patch對(duì)應(yīng)的patch?embedding

        ????????if?self.ape:
        ????????????x?=?x?+?self.absolute_pos_embed
        ????????x?=?self.pos_drop(x)

        ????????for?layer?in?self.layers:
        ????????????x?=?layer(x)
        ????????x?=?self.norm(x)

        ????????x?=?x.transpose(1,?2)
        ????????B,?C,?L?=?x.shape
        ????????H?=?W?=?int(L?**?0.5)
        ????????x?=?x.reshape(B,?C,?H,?W)
        ????????return?x


        Prediction Head

        這里的prediction head指的就是decoder,用來(lái)預(yù)測(cè)masked patch的原始像素值。論文發(fā)現(xiàn)采用一個(gè)非常輕量級(jí)的decoder(只用1個(gè)linear層)就非常有效。采用更復(fù)雜的head,效果沒(méi)有提升,反而會(huì)增加訓(xùn)練成本。MAE也指出decoder的設(shè)計(jì)對(duì)finetune性能影響較小,但是卻會(huì)影響linear probing效果,如果采用較輕的decoder,那么encoder的后面部分層就要承擔(dān)一部分像素預(yù)測(cè)任務(wù)(無(wú)監(jiān)督訓(xùn)練代理任務(wù)),但這個(gè)卻不是圖像分類(lèi)任務(wù)所需要的,所以會(huì)帶來(lái)linear probing的下降,所以如果要想得到比較好的linear probing效果,就需要設(shè)計(jì)一個(gè)適當(dāng)?shù)膁ecoder以將預(yù)測(cè)任務(wù)集中在decoder上。SimMIM默認(rèn)采用單個(gè)linear層來(lái)預(yù)測(cè)像素值,在實(shí)現(xiàn)上采用一個(gè)1x1卷積層。對(duì)于swin transformer,其得到的特征圖(恢復(fù)成hxw)是原來(lái)圖像的1/32大小,那么卷積層輸出channels等于3072=32x32x3,每個(gè)特征點(diǎn)預(yù)測(cè)32x32個(gè)pixels的RGB值。

        Prediction Tragets

        SimMIM是直接回歸masked patch的原始像素值,所以target就是原始圖像的RGB值,而回歸損失采用L1 loss,注意這里和MAE一樣,只計(jì)算masked pixels的損失,論文也發(fā)現(xiàn)如果對(duì)所有pixels計(jì)算loss,效果會(huì)下降(82.8% -> 81.7%),prediction而不是reconstruction能更好地讓encoder學(xué)習(xí)到更強(qiáng)的特征。另外一個(gè)參數(shù)是prediction resolution,SimMIM默認(rèn)的prediction resolution是原始圖像大小,但也可以對(duì)原始圖像進(jìn)行下采樣,從而降低prediction resolution,從實(shí)驗(yàn)結(jié)果來(lái)看,采用不同的prediction resolution均能得到較好的結(jié)果,除了1/32表現(xiàn)相對(duì)差一些(圖像損失比較嚴(yán)重):

        論文也對(duì)比了其它類(lèi)型的targets,比如像BEiT那樣用dVAE將回歸變成分類(lèi)任務(wù),或者像IGPT那樣采用color clustering。從下表的對(duì)比結(jié)果可以看到直接回歸像素值并不比這些更復(fù)雜的設(shè)計(jì)差。loss計(jì)算部分的實(shí)現(xiàn)也比較簡(jiǎn)單,具體的代碼如下所示(注意這里回歸的像素值是歸一化后的像素值):

        class?SimMIM(nn.Module):
        ????def?__init__(self,?encoder,?encoder_stride):
        ????????super().__init__()
        ????????self.encoder?=?encoder
        ????????self.encoder_stride?=?encoder_stride
        ??
        ????????#?定義encoder
        ????????self.decoder?=?nn.Sequential(
        ????????????nn.Conv2d(
        ????????????????in_channels=self.encoder.num_features,
        ????????????????out_channels=self.encoder_stride?**?2?*?3,?kernel_size=1),?#?1x1?conv等價(jià)于linear
        ????????????nn.PixelShuffle(self.encoder_stride),?#?[B,?3*r*r,?h,?w]?->?[B,?3,?h*r,?w*r]
        ????????)

        ????????self.in_chans?=?self.encoder.in_chans
        ????????self.patch_size?=?self.encoder.patch_size

        ????def?forward(self,?x,?mask):
        ????????z?=?self.encoder(x,?mask)?#?encoder提取特征
        ????????x_rec?=?self.decoder(z)?#?decoder預(yù)測(cè)圖像
        ??
        ????????#?mask轉(zhuǎn)變成和原始圖像一樣大小
        ????????mask?=?mask.repeat_interleave(self.patch_size,?1).repeat_interleave(self.patch_size,?2).unsqueeze(1).contiguous()
        ????????loss_recon?=?F.l1_loss(x,?x_rec,?reduction='none')?#?L1?loss?
        ????????loss?=?(loss_recon?*?mask).sum()?/?(mask.sum()?+?1e-5)?/?self.in_chans?#?只計(jì)算masked?pixels并取mean
        ????????return?loss

        實(shí)驗(yàn)設(shè)置及對(duì)比結(jié)果

        前面的實(shí)驗(yàn)都是以Swin-B為backbone,預(yù)訓(xùn)練的epoch為100,而最后的實(shí)驗(yàn)訓(xùn)練800個(gè)epoch,batch size為2048。在數(shù)據(jù)增強(qiáng)方面,只采用random resize croping:RandomResizedCrop(192, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.))以及水平翻轉(zhuǎn),和MAE一樣屬于輕量級(jí)的數(shù)據(jù)增強(qiáng),這說(shuō)明MIM方法確實(shí)不像對(duì)比學(xué)習(xí)那樣需要較heavy的數(shù)據(jù)增強(qiáng)。對(duì)于ViT,預(yù)訓(xùn)練的圖像大小是224,而SwinT采用的圖像大小為192,對(duì)比結(jié)果如下表所示??梢钥吹剑?/p>

        • 基于SimMIM訓(xùn)練的ViT-B優(yōu)于BEiT方法(83.8 vs 83.2),訓(xùn)練成本也比較低,但是linear probing效果均比較差(56.7);
        • 基于SimMIM預(yù)訓(xùn)練的SwinT也優(yōu)于有監(jiān)督訓(xùn)練的模型,對(duì)于Swin-B,預(yù)訓(xùn)練800epoch相比100epoch有一定提升(82.8 vs 84.0),這里也包含SwinV2的實(shí)驗(yàn),其中30億參數(shù)的SwinV2-G的效果可達(dá)到90.2%。
        image.png

        下圖是一些masked圖像重建后的可視化,可以看出經(jīng)過(guò)SimMIM訓(xùn)練后,模型能學(xué)習(xí)到一定的推理能力,比如mask掉一個(gè)物體或者人后,模型能學(xué)會(huì)補(bǔ)全背景。

        小結(jié)

        總結(jié)來(lái)看,SimMIM和MAE方法大致相同,兩者的差異大概源自SimMIM是為Swin設(shè)計(jì)的,而MAE是為單純的ViT結(jié)構(gòu)設(shè)計(jì)的。一個(gè)缺憾是SimMIM方法雖然在SwinV2上做了驗(yàn)證,但是沒(méi)有直接在下游檢測(cè)和分割任務(wù)上的對(duì)比實(shí)驗(yàn),而MAE方法在隨后的工作Benchmarking Detection Transfer Learning with Vision Transformers中論證了其遷移到實(shí)例分割任務(wù)上的有效性。

        參考

        • Masked Autoencoders Are Scalable Vision Learners
        • SimMIM: A Simple Framework for Masked Image Modeling
        • https://github.com/microsoft/SimMIM



        推薦閱讀

        CPVT:一個(gè)卷積就可以隱式編碼位置信息

        SOTA模型Swin Transformer是如何煉成的!

        快來(lái)解鎖PyTorch新技能:torch.fix

        RegNet:設(shè)計(jì)網(wǎng)絡(luò)設(shè)計(jì)空間

        PyTorch1.10發(fā)布:ZeroRedundancyOptimizer和Join

        谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!

        BatchNorm的避坑指南(上)

        BatchNorm的避坑指南(下)

        目標(biāo)跟蹤入門(mén)篇-相關(guān)濾波

        SOTA模型Swin Transformer是如何煉成的!

        MoCo V3:我并不是你想的那樣!

        Transformer在語(yǔ)義分割上的應(yīng)用

        "未來(lái)"的經(jīng)典之作ViT:transformer is all you need!

        PVT:可用于密集任務(wù)backbone的金字塔視覺(jué)transformer!

        漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA

        Transformer為何能闖入CV界秒殺CNN?

        不妨試試MoCo,來(lái)替換ImageNet上pretrain模型!


        機(jī)器學(xué)習(xí)算法工程師


        ? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)

        瀏覽 156
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            在线播放东京热一n1154 | 在线视频精品播放 | 国产精品久久久久久久久午夜福利 | 色香蕉在线视频 | 伦人伦影院A片在线播放一区 | 办公室女职员交换性bd | 国产麻豆一区 | 把裸睡少妇邻居摸到高潮 | 粉嫩小泬无遮挡BBBB | 欧洲美一区二区三区亚洲 |