DropBlock的原理和實(shí)現(xiàn)
點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
DropBlock是谷歌在2018年提出的一種用于CNN的正則化方法。普通的DropOut只是隨機(jī)屏蔽掉一部分特征,而DropBlock是隨機(jī)屏蔽掉一部分連續(xù)區(qū)域,如下圖所示。圖像是一個(gè)2D結(jié)構(gòu),像素或者特征點(diǎn)之間在空間上存在依賴關(guān)系,這樣普通的DropOut在屏蔽語義就不夠有效,但是DropBlock這樣屏蔽連續(xù)區(qū)域塊就能有效移除某些語義信息比如狗的頭,從而起到有效的正則化作用。DropBlock和CutOut有點(diǎn)類似,只不過CutOut是用于圖像的一種數(shù)據(jù)增強(qiáng)方法,而DropBlock是用在CNN的特征上的一種正則化手段。
DropBlock的原理很簡(jiǎn)單,它和DropOut的最大區(qū)別是就是屏蔽的地方是一個(gè)連續(xù)的方塊區(qū)域,其偽代碼如下所示:
DropBlock有兩個(gè)主要參數(shù):block_size和,其中block_size為方塊區(qū)域的邊長(zhǎng),而控制被屏蔽的特征數(shù)量大小。對(duì)于DropBlock,首先要用參數(shù)為的伯努利分布生成一個(gè)center mask,這個(gè)center mask產(chǎn)生的是要屏蔽的block的中心點(diǎn),然后將mask中的每個(gè)點(diǎn)擴(kuò)展到block_size大小的方塊區(qū)域,從而生成最終的block mask。假定輸入的特征大小為,那么center mask的大小應(yīng)該為,而block mask的大小為,在實(shí)現(xiàn)上我們可以先對(duì)center mask進(jìn)行padding,然后用一個(gè)kernel_size為block_size的max pooling來得到block mask。最后我們將特征乘以block mask即可,不過和DropOut類似,為了保證訓(xùn)練和測(cè)試的一致性,還需要對(duì)特征進(jìn)行歸一化:乘以count(block mask)/count_ones(block mask)。
對(duì)于DropBlock,我們往往像DropOut那樣直接設(shè)置一個(gè)keep_prob(或者drop_prob),這個(gè)概率值控制特征被屏蔽的量。此時(shí)我們需要將keep_prob轉(zhuǎn)換為,兩個(gè)參數(shù)帶來的效果應(yīng)該是等價(jià)的,所以有:
這里為特征圖大小,那么有了keep_prob就可以計(jì)算出:
不過這里并沒有考慮到兩個(gè)block可能會(huì)發(fā)生重疊,所以上述公式只是估算。DropBlock往往采用較大的keep_prob,如下圖所示采用0.9的效果是最好的。另外,論文中發(fā)現(xiàn)對(duì)keep_prob采用一個(gè)線性遞減的scheduler可以進(jìn)一步增加效果:keep_prob從1.0線性遞減到設(shè)定值如0.9。
對(duì)于block_size,實(shí)驗(yàn)發(fā)現(xiàn)采用block_size=7效果是最好的,如下所示:
以ResNet50為例,使用DropBlock后top-1 acc可以從76.5%提升至78.3%,超過其它dropout方法:
對(duì)于DropBlock的使用位置,論文發(fā)現(xiàn)對(duì)ResNet50來說,用在group3和group4中的卷積層中(即最后兩個(gè)stage),效果最好。
下面為DropBlock在PyTorch的具體實(shí)現(xiàn):
class?DropBlock2d(nn.Module):
????"""
????Implements?DropBlock2d?from?`"DropBlock:?A?regularization?method?for?convolutional?networks"
????`.
????Args:
????????p?(float):?Probability?of?an?element?to?be?dropped.
????????block_size?(int):?Size?of?the?block?to?drop.
????????inplace?(bool):?If?set?to?``True``,?will?do?this?operation?in-place.?Default:?``False``
????"""
????def?__init__(self,?p:?float,?block_size:?int,?inplace:?bool?=?False)?->?None:
????????super().__init__()
????????if?p?0.0?or?p?>?1.0:
????????????raise?ValueError(f"drop?probability?has?to?be?between?0?and?1,?but?got?{p}")
????????self.p?=?p
????????self.block_size?=?block_size
????????self.inplace?=?inplace
????def?forward(self,?input:?Tensor)?->?Tensor:
????????"""
????????Args:
????????????input?(Tensor):?Input?feature?map?on?which?some?areas?will?be?randomly
????????????????dropped.
????????Returns:
????????????Tensor:?The?tensor?after?DropBlock?layer.
????????"""
????????if?not?self.training:
????????????return?input
????????N,?C,?H,?W?=?input.size()
????????#?compute?the?gamma?of?Bernoulli?distribution
????????gamma?=?(self.p?*?H?*?W)?/?((self.block_size?**?2)?*?((H?-?self.block_size?+?1)?*?(W?-?self.block_size?+?1)))
????????mask_shape?=?(N,?C,?H?-?self.block_size?+?1,?W?-?self.block_size?+?1)
????????mask?=?torch.bernoulli(torch.full(mask_shape,?gamma,?device=input.device))
????????mask?=?F.pad(mask,?[self.block_size?//?2]?*?4,?value=0)
????????mask?=?F.max_pool2d(mask,?stride=(1,?1),?kernel_size=(self.block_size,?self.block_size),?padding=self.block_size?//?2)
????????mask?=?1?-?mask
????????normalize_scale?=?mask.numel()?/?(1e-6?+?mask.sum())
????????if?self.inplace:
????????????input.mul_(mask?*?normalize_scale)
????????else:
????????????input?=?input?*?mask?*?normalize_scale
????????return?input
????def?__repr__(self)?->?str:
????????s?=?f"{self.__class__.__name__}(p={self.p},?block_size={self.block_size},?inplace={self.inplace})"
????????return?s
參考
DropBlock: A regularization method for convolutional networks https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/plugins/dropblock.py
推薦閱讀
輔助模塊加速收斂,精度大幅提升!移動(dòng)端實(shí)時(shí)的NanoDet-Plus來了!
SSD的torchvision版本實(shí)現(xiàn)詳解
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)

