A Thorough Walkthrough of Data Augmentation for Image Classification
Author: 小小將 (Zhihu)
Original post: https://zhuanlan.zhihu.com/p/430563265

A model's performance depends not only on the network architecture itself but also heavily on the training recipe: the optimizer, the data augmentation, the regularization strategy, and so on (it is of course also strongly tied to the training data, whose quantity often sets the upper bound on model performance). In recent years, the top-1 accuracy of image classification models on ImageNet has risen from 56.5 (AlexNet, 2012) to 90.88 (CoAtNet, 2021, with the extra JFT-3B dataset); besides better models, more compute, and more data, this progress also owes a great deal to better training recipes. The recently popular vision transformers also tend to require heavier data augmentation and regularization than CNN models. This post briefly surveys the data augmentation strategies commonly used for training image classifiers.
Baseline
The standard augmentation pipeline for ImageNet training is shown below. During training it applies a random resized crop (RandomResizedCrop; this processing originated in Google's Inception, hence the name Inception-style pre-processing) and a random horizontal flip (RandomHorizontalFlip), while at test time it resizes the image and takes a center crop. This is a rather lightweight recipe, which we will call the baseline here. It is exactly what torchvision's ResNet50 training uses, reaching 76.1 top-1 acc on ImageNet.
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# training
train_transform = transforms.Compose([
    # scale is the area range of the crop, ratio its aspect-ratio range;
    # the op first samples a scale and ratio to derive w and h, then samples
    # a crop position, crops, and finally resizes to the target size
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# testing
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

AutoAugment
In 2018 Google proposed searching for data augmentation policies automatically with AutoML, a method called AutoAugment (arguably the seminal work on automated augmentation). The search uses reinforcement learning, much like NAS, except that the search space consists of augmentation policies instead of network architectures. In this space, a policy contains 5 sub-policies, each sub-policy consists of two augmentation operations applied in sequence, and each operation has two hyperparameters: the probability of applying it and the magnitude of the augmentation (its strength; for rotation, the rotation angle is the magnitude, and a larger angle means a stronger augmentation). When a policy is executed, one of its 5 sub-policies is first chosen at random, and its two operations are then applied one after the other.

The search space contains 16 types of image augmentation ops. Most of them define a magnitude range, and during the search the magnitude is discretized by taking 10 uniformly spaced values within that range.
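To make the mechanics concrete, here is a minimal sketch of how a sub-policy could be executed; it is illustrative rather than the paper's code, and the two ops with the magnitude ranges in OP_RANGES are stand-ins for the full 16-op library:

import random
import numpy as np
from PIL import Image, ImageOps

# illustrative magnitude ranges for two ops (not the paper's exact table)
OP_RANGES = {"Rotate": (0.0, 30.0), "Posterize": (4, 8)}

def magnitude_to_value(level, min_val, max_val, num_levels=10):
    # AutoAugment discretizes each magnitude range into 10 uniform levels
    return np.linspace(min_val, max_val, num_levels)[level]

def apply_op(img, op_name, value):
    # tiny stand-in for the 16-op library; only two ops implemented here
    if op_name == "Rotate":
        return img.rotate(value)
    if op_name == "Posterize":
        return ImageOps.posterize(img, int(value))
    raise ValueError(op_name)

def apply_sub_policy(img, sub_policy):
    # a sub-policy = two (op_name, probability, magnitude_level) triples,
    # executed in sequence, each fired with its own probability
    for op_name, prob, level in sub_policy:
        if random.random() < prob:
            value = magnitude_to_value(level, *OP_RANGES[op_name])
            img = apply_op(img, op_name, value)
    return img

out = apply_sub_policy(Image.new("RGB", (224, 224)), [("Posterize", 0.4, 8), ("Rotate", 0.6, 9)])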

The paper ran experiments on several datasets (CIFAR-10, SVHN, ImageNet). Below is the optimal policy found on ImageNet (the top 5 policies from the search were actually concatenated into a single policy, so it contains 25 sub-policies):
# operation, probability, magnitude
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None))
Training with the searched AutoAugment policy raises ResNet50's top-1 acc on ImageNet from 76.3 to 77.6. One important question is whether a policy searched on one dataset only works for that dataset; the paper demonstrates AutoAugment's transferability with experiments, e.g. the policy found on ImageNet also brings gains on 5 FGVC datasets (whose input sizes are similar to ImageNet's).
torchvision now ships an AutoAugment implementation; usage is shown below (note that a RandomResizedCrop should still precede AutoAugment):
import torch
from torchvision.transforms import InterpolationMode, autoaugment, transforms

# config values, as used in torchvision's reference training script
crop_size, hflip_prob = 224, 0.5
interpolation = InterpolationMode.BILINEAR
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# policy is a torchvision.transforms.autoaugment.AutoAugmentPolicy;
# for ImageNet it is AutoAugmentPolicy.IMAGENET,
# equivalently aa_policy = autoaugment.AutoAugmentPolicy('imagenet')
aa_policy = autoaugment.AutoAugmentPolicy.IMAGENET

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
    transforms.RandomHorizontalFlip(hflip_prob),
    autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=mean, std=std),
])

RandAugment
One problem with AutoAugment is its huge search space, which forces the search onto a proxy task: a small model trained on a small subset of ImageNet (120 classes, 6,000 images). In 2019 Google proposed a much simpler strategy: RandAugment. The paper first shows that policies searched on small datasets become problematic when applied at larger scale, because the optimal augmentation strength correlates strongly with both model size and dataset size: the larger the model or the training set, the larger the optimal augmentation magnitude, which implies that AutoAugment's searched result is sub-optimal. In addition, the Population Based Augmentation paper found that the optimal magnitude grows over the course of training and that different ops follow a similar schedule, which inspired the authors to use one fixed global magnitude rather than searching per op. RandAugment's policy space is tiny compared to AutoAugment's (on the order of 10^2 vs. roughly 10^32), so it needs no proxy task and can even be tuned with a plain grid search.

Concretely, RandAugment has just two hyperparameters: the number of augmentation ops N and a single global magnitude M. The reference implementation is below: each time, N ops are sampled uniformly at random from the candidate set (14 ops) and applied in sequence (there is no per-op probability here; a sampled op is always applied). M takes values in {0, ..., 30} (every op's magnitude is normalized to this common range), while N is typically in {1, 2, 3}.
import numpy as np

# Identity is the identity transform, i.e. no augmentation at all
transforms = ['Identity', 'AutoContrast', 'Equalize', 'Rotate', 'Solarize', 'Color', 'Posterize',
              'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY']

def randaugment(N, M):
    """Generate a set of distortions.
    Args:
        N: Number of augmentation transformations to apply sequentially.
        M: Magnitude for all the transformations.
    """
    sampled_ops = np.random.choice(transforms, N)
    return [(op, M) for op in sampled_ops]

# e.g. randaugment(2, 9) -> [('ShearY', 9), ('Rotate', 9)]
For ResNet50 the searched values are N=2 and M=9, with which RandAugment achieves results comparable to AutoAugment on ImageNet (77.6); DeiT, however, found RandAugment to work slightly better (DeiT-B: 81.8 vs. 81.2). torchvision also implements RandAugment; usage is as follows:
import torch
from torchvision.transforms import autoaugment, transforms

# config values (crop_size, hflip_prob, interpolation, mean, std) as in the AutoAugment example
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
    transforms.RandomHorizontalFlip(hflip_prob),
    # defaults to num_ops=2, magnitude=9
    autoaugment.RandAugment(interpolation=interpolation),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=mean, std=std),
])

TrivialAugment
Although RandAugment's search space is tiny, the optimal N and M still have to be determined for each dataset, which carries a non-trivial experimental cost. After RandAugment, Huawei proposed UniformAugment, which achieves good results without any search. Here, though, we look at a more recent work: TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation. TrivialAugment also requires no search, and the whole method is extremely simple: each time, randomly pick one augmentation op, randomly pick its magnitude, and apply it to the image. Since there are no hyperparameters, there is nothing to search. Empirically, TA achieves better results on several datasets; on ImageNet it brings ResNet50 to 78.1 top-1 acc, surpassing RandAugment.
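The core procedure fits in a few lines. A minimal self-contained sketch (the three PIL-backed ops and their magnitude mappings are illustrative placeholders for the full op set):

import random
from PIL import Image, ImageOps

# illustrative subset of the op set, each taking a normalized strength m in [0, 1)
AUG_OPS = {
    "Rotate": lambda img, m: img.rotate(m * 135.0),
    "Posterize": lambda img, m: ImageOps.posterize(img, 8 - int(m * 6)),
    "Equalize": lambda img, m: ImageOps.equalize(img),
}

def trivial_augment(img):
    # TrivialAugment: one uniformly sampled op, one uniformly sampled
    # magnitude, always applied -- no probabilities, nothing to tune
    op = random.choice(list(AUG_OPS))
    magnitude = random.random()
    return AUG_OPS[op](img, magnitude)

out = trivial_augment(Image.new("RGB", (224, 224)))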

TrivialAugment's op set is essentially the same as RandAugment's, but TA also defines a wider set of magnitude ranges. torchvision implements this variant as TrivialAugmentWide; its augmentation space and usage are shown below:
import torch
from torchvision.transforms import autoaugment, transforms

# the augmentation space used internally by TrivialAugmentWide
num_bins = 31
augmentation_space = {
    # op_name: (magnitudes, signed)
    "Identity": (torch.tensor(0.0), False),
    "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
    "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
    "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
    "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
    "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
    "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
    "Color": (torch.linspace(0.0, 0.99, num_bins), True),
    "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
    "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
    "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
    "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
    "AutoContrast": (torch.tensor(0.0), False),
    "Equalize": (torch.tensor(0.0), False),
}

# config values (crop_size, hflip_prob, interpolation, mean, std) as before
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
    transforms.RandomHorizontalFlip(hflip_prob),
    autoaugment.TrivialAugmentWide(interpolation=interpolation),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=mean, std=std),
])

RandomErasing
RandomErasing is a simple augmentation proposed by Xiamen University in 2017 (essentially the same idea as the contemporaneous CutOut). The principle: randomly erase a rectangular region of the image without changing the image's original label. DeiT's training recipe also includes RandomErasing.
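The sampling logic is easy to reproduce. A minimal numpy sketch (the parameter names mirror torchvision's RandomErasing, but this is not its implementation):

import math
import random
import numpy as np

def random_erase(img, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0.0):
    # erase one random rectangle of a CHW float array in place
    _, H, W = img.shape
    for _ in range(10):  # retry until a sampled box fits inside the image
        area = random.uniform(*scale) * H * W  # target erase area
        # aspect ratio sampled log-uniformly within the given range
        aspect = math.exp(random.uniform(math.log(ratio[0]), math.log(ratio[1])))
        h = int(round(math.sqrt(area * aspect)))
        w = int(round(math.sqrt(area / aspect)))
        if h < H and w < W:
            top, left = random.randint(0, H - h), random.randint(0, W - w)
            img[:, top:top + h, left:left + w] = value
            return img
    return img  # leave the image untouched if no box fit

img = random_erase(np.random.rand(3, 224, 224).astype(np.float32))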

torchvision also implements RandomErasing; usage is below (note this op does not support PIL images, so it must be applied after conversion to torch.Tensor):
# torch, transforms, and normalize as defined in the earlier examples
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)),
    transforms.RandomHorizontalFlip(),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    normalize,
    # scale is the erased area range, relative to the input image area
    # ratio is the aspect-ratio range of the erased region
    # value fills the erased region: an int, a tuple of 3 RGB channel values,
    # or the string 'random' for randomly generated values
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
])

MixUp
MixUp is a data augmentation method proposed by FAIR in 2017: two different images are blended by a random linear combination, and the label is the same linear combination of their labels:

x? = λ·x? + (1 − λ)·x?,  y? = λ·y? + (1 − λ)·y?

Here x? and x? are two different images, y? and y? are their corresponding one-hot labels, and λ is the mixing coefficient, sampled anew on every call (from a Beta(α, α) distribution, as in the code below). Suppose the task is binary classification (dog vs. cat) with a dog image and a cat image whose one-hot labels are [1, 0] and [0, 1]. Before mixup, the usual augmentations are first applied to obtain aug_img1 and aug_img2, and then λ is sampled: with λ = 0.7 the resulting image mix_img1 gets the label [0.7, 0.3], while with λ = 0.3 the resulting image mix_img2 gets the label [0.3, 0.7].
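At its core this is just two lines of tensor math. A minimal sketch for a single pair (illustrative; the batched implementation used in practice follows below):

import numpy as np

alpha = 1.0  # Beta-distribution hyperparameter; Beta(1, 1) is uniform on [0, 1]
lam = np.random.beta(alpha, alpha)

img1, img2 = np.random.rand(3, 224, 224), np.random.rand(3, 224, 224)
y1, y2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])  # one-hot labels

mix_img = lam * img1 + (1 - lam) * img2  # blended image
mix_label = lam * y1 + (1 - lam) * y2    # blended soft label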

Both timm and torchvision implement mixup; here we walk through torchvision's code. Because mixup takes two inputs rather than operating on a single image, it is usually applied after a batch has been assembled, which also means every image has already gone through its other augmentations such as RandAugment; each sample in the batch is then mixed with another randomly chosen sample. The implementation is as follows:
# from https://github.com/pytorch/vision/blob/main/references/classification/transforms.py
from typing import Tuple

import torch
from torch import Tensor


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )
        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # build one-hot targets
        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        # decide whether to apply mixup to this batch at all
        if torch.rand(1).item() >= self.p:
            return batch, target

        # Roll the batch by one position to build the mixup pairs, so each image
        # is mixed with its neighbor (timm instead flips the batch, pairing the
        # first image with the last).
        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # sample the mixing coefficient
        # Implemented as on mixup paper, page 3.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)  # mixed images
        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)  # mixed targets
        return batch, target
The MixUp op can then be placed in the DataLoader's collate_fn. Since collate_fn is the function that merges individual samples into a mini-batch, MixUp can be inserted right after the mini-batch is formed:
from torch.utils.data.dataloader import default_collate

mixup_transform = RandomMixup(num_classes, p=1.0, alpha=mixup_alpha)
collate_fn = lambda batch: mixup_transform(*default_collate(batch))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
    sampler=train_sampler, collate_fn=collate_fn)
Two further points deserve attention with MixUp. First, if label smoothing is used at the same time, the one-hot targets should be built directly in their smoothed form, as follows (adapted from timm; smoothing, target, and num_classes are given):
import torch

def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
smooth_one_hot = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
Second, the targets after MixUp are soft labels, so torch.nn.CrossEntropyLoss cannot be applied to them directly; instead the cross-entropy is computed explicitly (from timm):
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTargetCrossEntropy(nn.Module):
    def __init__(self):
        super(SoftTargetCrossEntropy, self).__init__()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
        return loss.mean()
Note that since PyTorch 1.10, torch.nn.CrossEntropyLoss accepts targets given as per-class probabilities (previously only class indices were supported) and also exposes a label_smoothing argument, so neither of the two workarounds above is needed anymore.
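For example (a small sketch; the shapes are illustrative):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # requires PyTorch >= 1.10

logits = torch.randn(4, 10)  # (batch, num_classes)
soft_targets = torch.softmax(torch.randn(4, 10), dim=-1)  # per-class probabilities

loss = criterion(logits, soft_targets)  # soft targets accepted directly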
Speaking of losses: the timm authors recently pointed out in ResNet strikes back: An improved training procedure in timm that with MixUp the problem can be recast as multi-label classification, i.e. N-way classification becomes N binary classifications (using BinaryCrossEntropy directly), which arguably matches the semantics of a mixed image better; their ablations show a small gain. Beyond image classification, MixUp is also used in object detection, e.g. in YOLOX, where after mixing the images the box set is the union of the two images' boxes and the labels are not softened; see also Bag of Freebies for Training Object Detection Neural Networks.
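A sketch of that binary-cross-entropy variant (this follows the general idea rather than timm's exact code; shapes are illustrative):

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()  # N-way classification as N binary problems

logits = torch.randn(4, 10)        # (batch, num_classes)
mixed_targets = torch.rand(4, 10)  # mixup-blended one-hot targets, values in [0, 1]

loss = criterion(logits, mixed_targets)  # each class is scored independently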

CutMix
CutMix, proposed in 2019, is an augmentation strategy similar to MixUp: it also mixes two images together with their labels, but it differs in how the images are combined. Instead of a linear blend, CutMix cuts a patch out of the second image at a random position and pastes it onto the first; the patch position is sampled at random, while its width and height are controlled by λ:

r_w = W·√(1 − λ),  r_h = H·√(1 − λ)

Here W and H are the width and height of the original image, so λ in fact determines the patch-to-image area ratio: (r_w · r_h) / (W · H) = 1 − λ. The paper illustrates the blends for λ = 0.7 and λ = 0.3; the smaller λ is, the larger the pasted patch. The labels are handled exactly as in MixUp, i.e. linearly combined with weight λ.

In the paper's ImageNet comparison, CutMix improves ResNet50's top-1 acc by roughly one point over MixUp (77.4 vs. 78.6).

timm and torchvision both implement CutMix as well. Again using torchvision as the example, the code is shown below (it closely parallels MixUp, differing only in the internal mixing step):
import math
from typing import Tuple

import torch
from torch import Tensor
from torchvision.transforms import functional as F


class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.
    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )
        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        W, H = F.get_image_size(batch)

        # sample the patch center
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        # half-width and half-height of the patch
        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        # clip the box to the image boundaries
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]

        # clipping may change the actual patch area, so recompute λ
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)
        return batch, target
Everything else works the same way as with MixUp.
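In practice the two are often combined by randomly picking one of them per batch. A sketch following torchvision's reference training script (assuming the RandomMixup/RandomCutmix classes above and the num_classes/default_collate names from the earlier snippet; the alpha values are illustrative defaults):

import torchvision

# per batch, randomly pick either MixUp or CutMix
mixup_cutmix = torchvision.transforms.RandomChoice([
    RandomMixup(num_classes, p=1.0, alpha=0.2),
    RandomCutmix(num_classes, p=1.0, alpha=1.0),
])
collate_fn = lambda batch: mixup_cutmix(*default_collate(batch))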
Repeated Augmentation
Repeated Augmentation (RA) is a sampling strategy proposed by FAIR in MultiGrain. Normally the augmented samples in a training mini-batch all come from distinct images, but RA allows a mini-batch to contain several differently augmented versions of the same image. The samples within a mini-batch are then no longer fully independent, which amounts to sampling the same image repeatedly, hence the name Repeated Augmentation. The paper argues that learning from different augmented versions of the same image within one mini-batch makes it easier for the model to pick up augmentation-invariant features. A somewhat earlier paper, Augment your batch: better training with larger batches, proposed a similar strategy, and DeepMind's recent Drawing Multiple Augmentation Samples Per Image During Training Efficiently Decreases Test Error further confirms its effectiveness experimentally.
DeiT also trains with RA. Strictly speaking, RA is not a data augmentation but a mini-batch sampling scheme; DeiT's implementation is given below (it can serve as a drop-in replacement for torch.utils.data.DistributedSampler):
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # number of samples per replica after 3x repeated sampling
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        # total number of samples after repeated sampling
        self.total_size = self.num_samples * self.num_replicas
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        # number of samples each replica actually consumes per epoch,
        # i.e. the per-replica count without repeated sampling
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        indices = [ele for ele in indices for i in range(3)]  # repeat each index 3 times
        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample so that repeated versions of one sample go to different processes (GPUs)
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices[:self.num_selected_samples])  # truncate to the actual per-epoch count

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
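Usage mirrors DistributedSampler; a sketch assuming an initialized distributed process group plus a dataset, batch_size, and num_epochs:

train_sampler = RASampler(dataset, shuffle=True)  # drop-in for DistributedSampler
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, sampler=train_sampler, num_workers=8)

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)  # reshuffle deterministically each epoch
    for images, targets in data_loader:
        ...  # training step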
Summary
This post briefly introduced several commonly used and effective data augmentation strategies. They are used when training vision transformer models, and the new ResNet baselines trained by timm also rely on them.
References
Training data-efficient image transformers & distillation through attention (https://arxiv.org/abs/2012.12877)
AutoAugment: Learning Augmentation Policies from Data (https://arxiv.org/abs/1805.09501)
RandAugment: Practical automated data augmentation with a reduced search space (https://arxiv.org/abs/1909.13719)
TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation (https://arxiv.org/abs/2103.10158)
Random Erasing Data Augmentation (https://arxiv.org/abs/1708.04896)
Augment your batch: better training with larger batches (https://arxiv.org/abs/1901.09335)
MultiGrain: a unified image embedding for classes and instances (https://arxiv.org/abs/1902.05509)
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
