1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        圖像分類訓練技巧之數(shù)據(jù)增強方法總結

        共 20921字,需瀏覽 42分鐘

         ·

        2022-02-11 06:19

        點擊上方“程序員大白”,選擇“星標”公眾號

        重磅干貨,第一時間送達

        作者丨小小將@知乎(已授權)
        來源丨h(huán)ttps://zhuanlan.zhihu.com/p/430563265
        編輯丨極市平臺

        導讀

        ?

        一個模型的性能除了和網(wǎng)絡結構本身有關,還非常依賴具體的訓練策略,比如優(yōu)化器,數(shù)據(jù)增強以及正則化策略等。本文簡單介紹了圖像分類訓練技巧中的常用數(shù)據(jù)增強策略。?

        一個模型的性能除了和網(wǎng)絡結構本身有關,還非常依賴具體的訓練策略,比如優(yōu)化器,數(shù)據(jù)增強以及正則化策略等(當然也很訓練數(shù)據(jù)強相關,訓練數(shù)據(jù)量往往決定模型性能的上線)。近年來,圖像分類模型在ImageNet數(shù)據(jù)集的top1 acc已經由原來的56.5(AlexNet,2012)提升至90.88(CoAtNet,2021,用了額外的數(shù)據(jù)集JFT-3B),這進步除了主要歸功于模型,算力和數(shù)據(jù)的提升,也與訓練策略的提升緊密相關。最近剛興起的vision transformer相比CNN模型往往也需要更heavy的數(shù)據(jù)增強和正則化策略。這里簡單介紹圖像分類訓練技巧中的常用數(shù)據(jù)增強策略。

        baseline

        ImageNet數(shù)據(jù)集訓練常用的數(shù)據(jù)增強策略如下,訓練過程的數(shù)據(jù)增強包括隨機縮放裁剪(RandomResizedCrop,這種處理方式源自谷歌的Inception,所以稱為 Inception-style pre-processing)和水平翻轉(RandomHorizontalFlip),而測試階段是執(zhí)行縮放和中心裁剪。這其實是一種輕量級的策略,這里稱之為baseline。torchvision的實現(xiàn)的ResNet50訓練采用的策略就是這個,在ImageNet上的top1 acc可以達到76.1。

        from?torchvision?import?transforms

        normalize?=?transforms.Normalize(mean=[0.485,?0.456,?0.406],
        ?????????????????????????????????std=[0.229,?0.224,?0.225])
        #?訓練
        train_transform?=?transforms.Compose([
        ????#?這里的scale指的是面積,ratio是寬高比
        ????#?具體實現(xiàn)每次先隨機確定scale和ratio,可以生成w和h,然后隨機確定裁剪位置進行crop
        ????#?最后是resize到target?size
        ????transforms.RandomResizedCrop(224,?scale=(0.08,?1.0),?ratio=(3.?/?4.,?4.?/?3.)),
        ????transforms.RandomHorizontalFlip(),
        ????transforms.ToTensor(),
        ????normalize
        ?])
        #?測試
        test_transform?=?transforms.Compose([
        ????transforms.Resize(256),
        ????transforms.CenterCrop(224),
        ????transforms.ToTensor(),
        ????normalize,
        ?])

        AutoAugment

        谷歌在2018年提出通過AutoML來自動搜索數(shù)據(jù)增強策略,稱之為AutoAugment(算是自動數(shù)據(jù)增強開山之作)。搜索方法采用強化學習,和NAS類似,只不過搜索空間是數(shù)據(jù)增強策略,而不是網(wǎng)絡架構。在搜索空間里,一個policy包含5個sub-policies,每個sub-policy包含兩個串行的圖像增強操作,每個增強操作有兩個超參數(shù):進行該操作的概率圖像增強的幅度(magnitude,這個表示數(shù)據(jù)增強的強度,比如對于旋轉,旋轉的角度就是增強幅度,旋轉角度越大,增強越大)。每個policy在執(zhí)行時,首先隨機從5個策略中隨機選擇一個sub-policy,然后序列執(zhí)行兩個圖像操作。

        搜索空間一共有16種圖像增強類型,具體如下所示,大部分操作都定義了圖像增強的幅度范圍,在搜索時需要將幅度值離散化,具體地是將幅度值在定義范圍內均勻地取10個值。

        論文在不同的數(shù)據(jù)集上( CIFAR-10 , SVHN, ImageNet)做了實驗,這里給出在ImageNet數(shù)據(jù)集上搜索得到的最優(yōu)policy(最后實際上是將搜索得到的前5個最好的policies合成了一個policy,所以這里包含25個sub-policies):

        #?operation,?probability,?magnitude
        (("Posterize",?0.4,?8),?("Rotate",?0.6,?9)),
        (("Solarize",?0.6,?5),?("AutoContrast",?0.6,?None)),??????????????????????????????????????????????????????????
        (("Equalize",?0.8,?None),?("Equalize",?0.6,?None)),
        (("Posterize",?0.6,?7),?("Posterize",?0.6,?6)),
        (("Equalize",?0.4,?None),?("Solarize",?0.2,?4)),
        (("Equalize",?0.4,?None),?("Rotate",?0.8,?8)),
        (("Solarize",?0.6,?3),?("Equalize",?0.6,?None)),
        (("Posterize",?0.8,?5),?("Equalize",?1.0,?None)),
        (("Rotate",?0.2,?3),?("Solarize",?0.6,?8)),
        (("Equalize",?0.6,?None),?("Posterize",?0.4,?6)),
        (("Rotate",?0.8,?8),?("Color",?0.4,?0)),
        (("Rotate",?0.4,?9),?("Equalize",?0.6,?None)),
        (("Equalize",?0.0,?None),?("Equalize",?0.8,?None)),
        (("Invert",?0.6,?None),?("Equalize",?1.0,?None)),
        (("Color",?0.6,?4),?("Contrast",?1.0,?8)),
        (("Rotate",?0.8,?8),?("Color",?1.0,?2)),
        (("Color",?0.8,?8),?("Solarize",?0.8,?7)),
        (("Sharpness",?0.4,?7),?("Invert",?0.6,?None)),
        (("ShearX",?0.6,?5),?("Equalize",?1.0,?None)),
        (("Color",?0.4,?0),?("Equalize",?0.6,?None)),
        (("Equalize",?0.4,?None),?("Solarize",?0.2,?4)),
        (("Solarize",?0.6,?5),?("AutoContrast",?0.6,?None)),
        (("Invert",?0.6,?None),?("Equalize",?1.0,?None)),
        (("Color",?0.6,?4),?("Contrast",?1.0,?8)),
        (("Equalize",?0.8,?None),?("Equalize",?0.6,?None))

        基于搜索得到的AutoAugment訓練可以將ResNet50在ImageNet數(shù)據(jù)集上的top1 acc從76.3提升至77.6。一個比較重要的問題,這些從某一個數(shù)據(jù)集搜索得到的策略是否只對固定的數(shù)據(jù)集有效,論文也通過具體實驗證明了AutoAugment的遷移能力,比如將ImageNet數(shù)據(jù)集上得到的策略用在5個 FGVC數(shù)據(jù)集(與ImageNet圖像輸入大小相似)也均有提升。

        目前torchvision庫已經實現(xiàn)了AutoAugment,具體使用如下所示(注意AutoAug前也需要包括一個RandomResizedCrop):

        from?torchvision.transforms?import?autoaugment,?transforms

        train_transform?=?transforms.Compose([
        ????transforms.RandomResizedCrop(crop_size,?interpolation=interpolation),
        ????transforms.RandomHorizontalFlip(hflip_prob),
        ????#?這里policy屬于torchvision.transforms.autoaugment.AutoAugmentPolicy,
        ????#?對于ImageNet就是?AutoAugmentPolicy.IMAGENET
        ????#?此時aa_policy?=?autoaugment.AutoAugmentPolicy('imagenet')
        ????autoaugment.AutoAugment(policy=aa_policy,?interpolation=interpolation),
        ?transforms.PILToTensor(),
        ????transforms.ConvertImageDtype(torch.float),
        ????transforms.Normalize(mean=mean,?std=std)
        ?])

        RandAugment

        AutoAugment存在的一個問題是搜索空間巨大,這使得搜索只能在代理任務中進行:使用小的模型在ImageNet的一個小的子集( 120類和6000圖片)搜索。谷歌在2019年又提出了一個更簡單的數(shù)據(jù)增強策略:RandAugment。這篇論文首先發(fā)現(xiàn)AutoAugment這樣在小數(shù)據(jù)集上搜索出來的策略在大的數(shù)據(jù)集上應用會存在問題,這主要是因為數(shù)據(jù)增強策略和模型大小和數(shù)據(jù)量大小存在強相關,如下圖所示可以看到模型或者訓練數(shù)據(jù)量越大,其最優(yōu)的數(shù)據(jù)增強的幅度越大,這說明AutoAugment得到的結果應該是次優(yōu)的。另外,Population Based Augmentation這篇論文發(fā)現(xiàn)最優(yōu)的數(shù)據(jù)增強幅度是隨訓練過程增加,而且不同的增強操作遵循類似的規(guī)律,這啟發(fā)作者采用固定的增強幅度而不是去搜索。RandAugment相比AutoAugment的策略空間很?。?span style="cursor: pointer;"> vs ),所以它不需要采用代理任務,甚至直接采用簡單的網(wǎng)格搜索。

        具體地,RandAugment共包含兩個超參數(shù):圖像增強操作的數(shù)量N和一個全局的增強幅度M,其實現(xiàn)代碼如下所示,每次從候選操作集合(共14種策略)隨機選擇N個操作(等概率),然后串行執(zhí)行(這里沒有判斷概率,是一定執(zhí)行)。這里的M取值范圍為{0, . . . , 30}(每個圖像增強操作歸一化到同樣的幅度范圍),而N取值范圍一般為 {1, 2, 3}。

        #?Identity是恒等變換,不做任何增強
        transforms?=?['Identity',?'AutoContrast',?'Equalize',?'Rotate',?'Solarize',?'Color',?'Posterize',?
        ??????????????'Contrast',?'Brightness',?'Sharpness',?'ShearX',?'ShearY',?'TranslateX',?'TranslateY']

        def?randaugment(N,?M):
        ?"""Generate?a?set?of?distortions.
        ?Args:
        ?N:?Number?of?augmentation?transformations?to
        ?apply?sequentially.
        ?M:?Magnitude?for?all?the?transformations.
        ?"""

        ?sampled_ops?=?np.random.choice(transforms,?N)
        ?return?[(op,?M)?for?op?in?sampled_ops]

        對于ResNet50,其搜索得到的N=2,M=9,RandAugment相比AutoAugment可以在ImageNet得到相似的效果(77.6),不過DeiT中發(fā)現(xiàn)使用RandAugment效果更好一些( DeiT-B:81.8 vs 81.2)。目前torchvision庫也已經實現(xiàn)了RandAugment,具體使用如下所示:

        from?torchvision.transforms?import?autoaugment,?transforms

        train_transform?=?transforms.Compose([
        ????transforms.RandomResizedCrop(crop_size,?interpolation=interpolation),
        ????transforms.RandomHorizontalFlip(hflip_prob),
        ????autoaugment.RandAugment(interpolation=interpolation),
        ?transforms.PILToTensor(),
        ????transforms.ConvertImageDtype(torch.float),
        ????transforms.Normalize(mean=mean,?std=std)
        ?])

        TrivialAugment

        雖然RandAugment的搜索空間極小,但是對于不同的數(shù)據(jù)集還是需要確定最優(yōu)的N和M,這依然有較大的實驗成本。RandAugment后,華為提出了UniformAugment,這種策略不需要搜索也能取得較好的結果。不過這里我們介紹一項更新的工作:TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation。TrivialAugment也不需要任何搜索,整個方法非常簡單:每次隨機選擇一個圖像增強操作,然后隨機確定它的增強幅度,并對圖像進行增強。由于沒有任何超參數(shù),所以不需要任何搜索。從實驗結果上看,TA可以在多個數(shù)據(jù)集上取得更好的結果,如在ImageNet數(shù)據(jù)集上,ResNet50的top1 acc可以達到78.1,超過RandAugment。

        TrivialAugment的圖像增強集合和RandAugment基本一樣,不過TA也定義了一套更寬的增強幅度,目前torchvision中已經實現(xiàn)了TrivialAugmentWide,具體使用代碼如下所示:

        from?torchvision.transforms?import?autoaugment,?transforms

        augmentation_space?=?{
        ????#?op_name:?(magnitudes,?signed)
        ????"Identity":?(torch.tensor(0.0),?False),
        ????"ShearX":?(torch.linspace(0.0,?0.99,?num_bins),?True),
        ????"ShearY":?(torch.linspace(0.0,?0.99,?num_bins),?True),
        ????"TranslateX":?(torch.linspace(0.0,?32.0,?num_bins),?True),
        ????"TranslateY":?(torch.linspace(0.0,?32.0,?num_bins),?True),
        ????"Rotate":?(torch.linspace(0.0,?135.0,?num_bins),?True),
        ????"Brightness":?(torch.linspace(0.0,?0.99,?num_bins),?True),
        ????"Color":?(torch.linspace(0.0,?0.99,?num_bins),?True),
        ????"Contrast":?(torch.linspace(0.0,?0.99,?num_bins),?True),
        ????"Sharpness":?(torch.linspace(0.0,?0.99,?num_bins),?True),
        ????"Posterize":?(8?-?(torch.arange(num_bins)?/?((num_bins?-?1)?/?6)).round().int(),?False),
        ????"Solarize":?(torch.linspace(255.0,?0.0,?num_bins),?False),
        ????"AutoContrast":?(torch.tensor(0.0),?False),
        ????"Equalize":?(torch.tensor(0.0),?False),
        }

        train_transform?=?transforms.Compose([
        ????transforms.RandomResizedCrop(crop_size,?interpolation=interpolation),
        ????transforms.RandomHorizontalFlip(hflip_prob),
        ????autoaugment.TrivialAugmentWide(interpolation=interpolation),
        ?transforms.PILToTensor(),
        ????transforms.ConvertImageDtype(torch.float),
        ????transforms.Normalize(mean=mean,?std=std)
        ?])

        RandomErasing

        RandomErasing是廈門大學在2017年提出的一種簡單的數(shù)據(jù)增強(這個策略和同期的CutOut基本一樣),基本原理是:隨機從圖像中擦除一個矩形區(qū)域而不改變圖像的原始標簽。DeiT的訓練策略中也包括了RandomErasing。

        目前torchvision也實現(xiàn)了RandomErasing,其具體使用代碼如下(注意這個op不支持PIL圖像,需要在轉換為torch.tensor后使用):

        train_transform?=?transforms.Compose([
        ????transforms.RandomResizedCrop(224,?scale=(0.08,?1.0),?ratio=(3.?/?4.,?4.?/?3.)),
        ????transforms.RandomHorizontalFlip(),
        ????transforms.PILToTensor()
        ????transforms.ConvertImageDtype(torch.float),
        ????normalize,
        ????#?scale是指相對于原圖的擦除面積范圍
        ????#?ratio是指擦除區(qū)域的寬高比
        ????#?value是指擦除區(qū)域的值,如果是int,也可以是tuple(RGB3個通道值),或者是str,需為'random',表示隨機生成
        ????transforms.RandomErasing(p=0.5,?scale=(0.02,?0.33),?ratio=(0.3,?3.3),?value=0,?inplace=False),
        ?])

        MixUp

        MixUp在FAIR在2017年提出的一種數(shù)據(jù)增強方法:兩張不同的圖像隨機線性組合,而同時生成線性組合的標簽。

        這里的是兩張不同的圖像,是它們對應的one-hot標簽,而是線性組合系數(shù),每次執(zhí)行時隨機生成。假定圖像分類任務是2分類(區(qū)分狗和貓),兩張輸入圖像分別是狗和貓(如下圖所示),它們對應的one-hot標簽分別是[1,0]和[0, 1]。在進行mixup之前,首先對它們進行必要的數(shù)據(jù)增強得到aug_img1和aug_img2,然后隨機生成線性組合系數(shù),對于得到的圖像是mix_img1,標簽變?yōu)閇0.7, 0.3],而得到的圖像是mix_img2,標簽變?yōu)閇0.3, 0.7]。

        目前timm和torchvision中已經實現(xiàn)了mixup,這里以torchvision為例來講述具體的代碼實現(xiàn)。由于mixup需要兩個輸入,而不單單是對當前圖像進行操作,所以一般是在得到batch數(shù)據(jù)后再進行mixup,這也意味著圖像也已經完成了其它的數(shù)據(jù)增強如RandAugment,對于batch中的每個樣本可以隨機選擇另外一個樣本進行mixup。具體的實現(xiàn)代碼如下所示:

        #?from?https://github.com/pytorch/vision/blob/main/references/classification/transforms.py
        class?RandomMixup(torch.nn.Module):
        ????"""Randomly?apply?Mixup?to?the?provided?batch?and?targets.
        ????The?class?implements?the?data?augmentations?as?described?in?the?paper
        ????`"mixup:?Beyond?Empirical?Risk?Minimization"?`_.
        ????Args:
        ????????num_classes?(int):?number?of?classes?used?for?one-hot?encoding.
        ????????p?(float):?probability?of?the?batch?being?transformed.?Default?value?is?0.5.
        ????????alpha?(float):?hyperparameter?of?the?Beta?distribution?used?for?mixup.
        ????????????Default?value?is?1.0.?#?beta分布超參數(shù)
        ????????inplace?(bool):?boolean?to?make?this?transform?inplace.?Default?set?to?False.
        ????"""


        ????def?__init__(self,?num_classes:?int,?p:?float?=?0.5,?alpha:?float?=?1.0,?inplace:?bool?=?False)?->?None:
        ????????super().__init__()
        ????????assert?num_classes?>?0,?"Please?provide?a?valid?positive?value?for?the?num_classes."
        ????????assert?alpha?>?0,?"Alpha?param?can't?be?zero."

        ????????self.num_classes?=?num_classes
        ????????self.p?=?p
        ????????self.alpha?=?alpha
        ????????self.inplace?=?inplace

        ????def?forward(self,?batch:?Tensor,?target:?Tensor)?->?Tuple[Tensor,?Tensor]:
        ????????"""
        ????????Args:
        ????????????batch?(Tensor):?Float?tensor?of?size?(B,?C,?H,?W)
        ????????????target?(Tensor):?Integer?tensor?of?size?(B,?)
        ????????Returns:
        ????????????Tensor:?Randomly?transformed?batch.
        ????????"""

        ????????if?batch.ndim?!=?4:
        ????????????raise?ValueError(f"Batch?ndim?should?be?4.?Got?{batch.ndim}")
        ????????if?target.ndim?!=?1:
        ????????????raise?ValueError(f"Target?ndim?should?be?1.?Got?{target.ndim}")
        ????????if?not?batch.is_floating_point():
        ????????????raise?TypeError(f"Batch?dtype?should?be?a?float?tensor.?Got?{batch.dtype}.")
        ????????if?target.dtype?!=?torch.int64:
        ????????????raise?TypeError(f"Target?dtype?should?be?torch.int64.?Got?{target.dtype}")

        ????????if?not?self.inplace:
        ????????????batch?=?batch.clone()
        ????????????target?=?target.clone()
        ??
        ????????#?建立one-hot標簽
        ????????if?target.ndim?==?1:
        ????????????target?=?torch.nn.functional.one_hot(target,?num_classes=self.num_classes).to(dtype=batch.dtype)
        ??
        ????????#?判斷是否進行mixup
        ????????if?torch.rand(1).item()?>=?self.p:
        ????????????return?batch,?target
        ??
        ????????#?這里將batch數(shù)據(jù)平移一個單位,產生mixup的圖像對,這意味著每個圖像與相鄰的下一個圖像進行mixup
        ????????#?timm實現(xiàn)是通過flip來做的,這意味著第一個圖像和最后一個圖像進行mixup
        ????????#?It's?faster?to?roll?the?batch?by?one?instead?of?shuffling?it?to?create?image?pairs
        ????????batch_rolled?=?batch.roll(1,?0)
        ????????target_rolled?=?target.roll(1,?0)
        ??
        ????????#?隨機生成組合系數(shù)
        ????????#?Implemented?as?on?mixup?paper,?page?3.
        ????????lambda_param?=?float(torch._sample_dirichlet(torch.tensor([self.alpha,?self.alpha]))[0])
        ????????batch_rolled.mul_(1.0?-?lambda_param)
        ????????batch.mul_(lambda_param).add_(batch_rolled)?#?得到mixup后的圖像

        ????????target_rolled.mul_(1.0?-?lambda_param)
        ????????target.mul_(lambda_param).add_(target_rolled)?#?得到mixup后的標簽

        ????????return?batch,?target

        然后可以將MixUp操作放在DataLoader的collate_fn中,這個函數(shù)要實現(xiàn)的是將多個樣本合并成一個mini-batch,所以可以將MixUp插在得到mini-batch后,具體實現(xiàn)如下所示:

        from?torch.utils.data.dataloader?import?default_collate

        mixup_transform?=?RandomMixup(num_classes,?p=1.0,?alpha=mixup_alpha)
        collate_fn?=?lambda?batch:?mixup_transform(*default_collate(batch))
        data_loader?=?torch.utils.data.DataLoader(dataset,?batch_size=batch_size,
        ????sampler=train_sampler,?collate_fn=collate_fn)

        對于MixUp,還要注意兩個兩點。第一個是如果同時采用了label smoothing,那么在創(chuàng)建one-hot標簽時要直接得到smooth后的標簽,具體實現(xiàn)如下(參考timm):

        def?one_hot(x,?num_classes,?on_value=1.,?off_value=0.,?device='cuda'):
        ????x?=?x.long().view(-1,?1)
        ????return?torch.full((x.size()[0],?num_classes),?off_value,?device=device).scatter_(1,?x,?on_value)

        off_value?=?smoothing?/?num_classes
        on_value?=?1.?-?smoothing?+?off_value
        smooth_one_hot?=?one_hot(target,?num_classes,?on_value=on_value,?off_value=off_value)

        第二個要注意的是MixUp后得到標簽時soft label,不能直接采用torch.nn.CrossEntropyLoss來計算loss,而是直接計算交叉熵(參考timm):

        class?SoftTargetCrossEntropy(nn.Module):

        ????def?__init__(self):
        ????????super(SoftTargetCrossEntropy,?self).__init__()

        ????def?forward(self,?x:?torch.Tensor,?target:?torch.Tensor)?->?torch.Tensor:
        ????????loss?=?torch.sum(-target?*?F.log_softmax(x,?dim=-1),?dim=-1)
        ????????return?loss.mean()

        注意在PyTorch1.10版本之后,torch.nn.CrossEntropyLoss已經支持直接送入的target是probabilities for each class,原來只支持target是class indices;而且也支持label_smoothing參數(shù),所以上述兩個注意點就不再需要了。

        說到計算loss,timm作者近期在ResNet strikes back: An improved training procedure in timm指出采用MixUp后可以將多分類改成多標簽分類(multi-label classification),即從N分類變成N個2分類(直接采用BinaryCrossEntropy),這應該更符合MixUp后圖像的語義,從對比實驗來看效果有微弱的提升。MixUp除了可以用于圖像分類任務,還可以用于物體檢測任務中,比如YOLOX就采用了MixUp,這里面的做法是對圖像mixup后,其box為兩個圖像的box的合并集合,而沒有對標簽軟化,這塊也可以見論文Bag of Freebies for Training Object Detection Neural Networks。

        CutMix

        CutMix是2019年提出的一項和MixUp和類似的數(shù)據(jù)增強策略,它也是同時對兩個圖像和標簽進行混合,與MixUp不同的是它的圖像混合方式。CutMix不是對兩個圖像線性組合,而是從另外一張圖像隨機剪切一個patch并粘貼到第一張圖像上,patch的起始坐標隨機生成,而寬高是由來控制:

        這里是原始圖像的寬和高,所以其實決定的是patch和原圖的面積比:。下圖展示了分別取0.7和0.3的混合效果,越小,粘貼的patch越大。對于標簽,其處理方式和MixUp一樣,通過來得到兩張圖像的線性組合。

        CutMix做了ImageNet上的對比實驗,相比MixUp,ResNet50的top1 acc大約能提升一個點(77.4 vs 78.6):

        目前timm和torchvision中也已經實現(xiàn)了CutMix,這里還是以torchvision為例來講述具體的代碼實現(xiàn),如下所示(和MixUp基本類似,只不過內部處理存在差異):

        class?RandomCutmix(torch.nn.Module):
        ????"""Randomly?apply?Cutmix?to?the?provided?batch?and?targets.
        ????The?class?implements?the?data?augmentations?as?described?in?the?paper
        ????`"CutMix:?Regularization?Strategy?to?Train?Strong?Classifiers?with?Localizable?Features"
        ????`_.
        ????Args:
        ????????num_classes?(int):?number?of?classes?used?for?one-hot?encoding.
        ????????p?(float):?probability?of?the?batch?being?transformed.?Default?value?is?0.5.
        ????????alpha?(float):?hyperparameter?of?the?Beta?distribution?used?for?cutmix.
        ????????????Default?value?is?1.0.
        ????????inplace?(bool):?boolean?to?make?this?transform?inplace.?Default?set?to?False.
        ????"""


        ????def?__init__(self,?num_classes:?int,?p:?float?=?0.5,?alpha:?float?=?1.0,?inplace:?bool?=?False)?->?None:
        ????????super().__init__()
        ????????assert?num_classes?>?0,?"Please?provide?a?valid?positive?value?for?the?num_classes."
        ????????assert?alpha?>?0,?"Alpha?param?can't?be?zero."

        ????????self.num_classes?=?num_classes
        ????????self.p?=?p
        ????????self.alpha?=?alpha
        ????????self.inplace?=?inplace

        ????def?forward(self,?batch:?Tensor,?target:?Tensor)?->?Tuple[Tensor,?Tensor]:
        ????????"""
        ????????Args:
        ????????????batch?(Tensor):?Float?tensor?of?size?(B,?C,?H,?W)
        ????????????target?(Tensor):?Integer?tensor?of?size?(B,?)
        ????????Returns:
        ????????????Tensor:?Randomly?transformed?batch.
        ????????"""

        ????????if?batch.ndim?!=?4:
        ????????????raise?ValueError(f"Batch?ndim?should?be?4.?Got?{batch.ndim}")
        ????????if?target.ndim?!=?1:
        ????????????raise?ValueError(f"Target?ndim?should?be?1.?Got?{target.ndim}")
        ????????if?not?batch.is_floating_point():
        ????????????raise?TypeError(f"Batch?dtype?should?be?a?float?tensor.?Got?{batch.dtype}.")
        ????????if?target.dtype?!=?torch.int64:
        ????????????raise?TypeError(f"Target?dtype?should?be?torch.int64.?Got?{target.dtype}")

        ????????if?not?self.inplace:
        ????????????batch?=?batch.clone()
        ????????????target?=?target.clone()

        ????????if?target.ndim?==?1:
        ????????????target?=?torch.nn.functional.one_hot(target,?num_classes=self.num_classes).to(dtype=batch.dtype)

        ????????if?torch.rand(1).item()?>=?self.p:
        ????????????return?batch,?target

        ????????#?It's?faster?to?roll?the?batch?by?one?instead?of?shuffling?it?to?create?image?pairs
        ????????batch_rolled?=?batch.roll(1,?0)
        ????????target_rolled?=?target.roll(1,?0)

        ????????#?Implemented?as?on?cutmix?paper,?page?12?(with?minor?corrections?on?typos).
        ????????lambda_param?=?float(torch._sample_dirichlet(torch.tensor([self.alpha,?self.alpha]))[0])
        ????????W,?H?=?F.get_image_size(batch)
        ??
        ????????#?確定patch的起點
        ????????r_x?=?torch.randint(W,?(1,))
        ????????r_y?=?torch.randint(H,?(1,))
        ??
        ????????#?確定patch的w和h(其實是一半大小)
        ????????r?=?0.5?*?math.sqrt(1.0?-?lambda_param)
        ????????r_w_half?=?int(r?*?W)
        ????????r_h_half?=?int(r?*?H)
        ??
        ????????#?越界處理
        ????????x1?=?int(torch.clamp(r_x?-?r_w_half,?min=0))
        ????????y1?=?int(torch.clamp(r_y?-?r_h_half,?min=0))
        ????????x2?=?int(torch.clamp(r_x?+?r_w_half,?max=W))
        ????????y2?=?int(torch.clamp(r_y?+?r_h_half,?max=H))

        ????????batch[:,?:,?y1:y2,?x1:x2]?=?batch_rolled[:,?:,?y1:y2,?x1:x2]
        ????????#?由于越界處理,?λ可能發(fā)生改變,所以要重新計算
        ????????lambda_param?=?float(1.0?-?(x2?-?x1)?*?(y2?-?y1)?/?(W?*?H))

        ????????target_rolled.mul_(1.0?-?lambda_param)
        ????????target.mul_(lambda_param).add_(target_rolled)

        ????????return?batch,?target

        其它使用和MixUp一樣。

        Repeated Augmentation

        Repeated Augmentation (RA)是FAIR在MultiGrain提出的一種抽樣策略,一般情況下,訓練的mini-batch包含的增強過的sample都是來自不同的圖像,但是RA這種抽樣策略允許一個mini-batch中包含來自同一個圖像的不同增強版本,此時mini-batch的各個樣本并非是完全獨立的,這相當于對同一個樣本進行重復抽樣,所以稱為Repeated Augmentation。這篇論文認為在一個mini-batch學習來自同一個圖像的不同增強版本能讓模型更容易學習到增強不變的特征。關于RA,其實另外一篇較早的論文Augment your batch: better training with larger batches也提出了類似的策略,另外DeepMind在最近的論文Drawing Multiple Augmentation Samples Per Image During Training Efficiently Decreases Test Error也進一步通過實驗來證明這種策略的效果。

        DeiT的訓練也采用了RA,嚴格來說RA不屬于數(shù)據(jù)增強策略,而是一種mini-batch抽樣方法,這里也簡單給出DeiT實現(xiàn)的RA(可以替換torch.utils.data.DistributedSampler):

        class?RASampler(torch.utils.data.Sampler):
        ????"""Sampler?that?restricts?data?loading?to?a?subset?of?the?dataset?for?distributed,
        ????with?repeated?augmentation.
        ????It?ensures?that?different?each?augmented?version?of?a?sample?will?be?visible?to?a
        ????different?process?(GPU)
        ????Heavily?based?on?torch.utils.data.DistributedSampler
        ????"""


        ????def?__init__(self,?dataset,?num_replicas=None,?rank=None,?shuffle=True):
        ????????if?num_replicas?is?None:
        ????????????if?not?dist.is_available():
        ????????????????raise?RuntimeError("Requires?distributed?package?to?be?available")
        ????????????num_replicas?=?dist.get_world_size()
        ????????if?rank?is?None:
        ????????????if?not?dist.is_available():
        ????????????????raise?RuntimeError("Requires?distributed?package?to?be?available")
        ????????????rank?=?dist.get_rank()
        ????????self.dataset?=?dataset
        ????????self.num_replicas?=?num_replicas
        ????????self.rank?=?rank
        ????????self.epoch?=?0
        ????????#?重復采樣后每個replica的樣本量
        ????????self.num_samples?=?int(math.ceil(len(self.dataset)?*?3.0?/?self.num_replicas))
        ????????#?重復采樣后的總樣本量
        ????????self.total_size?=?self.num_samples?*?self.num_replicas
        ????????#?self.num_selected_samples?=?int(math.ceil(len(self.dataset)?/?self.num_replicas))
        ????????#?每個replica實際樣本量,即不重復采樣時的每個replica的樣本量
        ????????self.num_selected_samples?=?int(math.floor(len(self.dataset)?//?256?*?256?/?self.num_replicas))
        ????????self.shuffle?=?shuffle

        ????def?__iter__(self):
        ????????#?deterministically?shuffle?based?on?epoch
        ????????g?=?torch.Generator()
        ????????g.manual_seed(self.epoch)
        ????????if?self.shuffle:
        ????????????indices?=?torch.randperm(len(self.dataset),?generator=g).tolist()
        ????????else:
        ????????????indices?=?list(range(len(self.dataset)))

        ????????#?add?extra?samples?to?make?it?evenly?divisible
        ????????indices?=?[ele?for?ele?in?indices?for?i?in?range(3)]?#?重復3次
        ????????indices?+=?indices[:(self.total_size?-?len(indices))]
        ????????assert?len(indices)?==?self.total_size

        ????????#?subsample:?使得同一個樣本的重復版本進入不同的進程(GPU)
        ????????indices?=?indices[self.rank:self.total_size:self.num_replicas]
        ????????assert?len(indices)?==?self.num_samples

        ????????return?iter(indices[:self.num_selected_samples])?#?截取實際樣本量

        ????def?__len__(self):
        ????????return?self.num_selected_samples

        ????def?set_epoch(self,?epoch):
        ????????self.epoch?=?epoch

        小結

        這里簡單介紹了幾種常用且有效的數(shù)據(jù)增強策略,這些策略在vision transformer模型被使用,而且timm訓練的ResNet新baseline也使用了這些策略。

        參考

        1. Training data-efficient image transformers & distillation through attention? (https://arxiv.org/abs/2012.12877)
        2. AutoAugment: Learning Augmentation Policies from Data? (https://arxiv.org/abs/1805.09501)
        3. RandAugment: Practical automated data augmentation with a reduced search space? (https://arxiv.org/abs/1909.13719)
        4. TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation? (https://arxiv.org/abs/2103.10158)
        5. Random Erasing Data Augmentation(https://arxiv.org/abs/1708.04896)
        6. Augment your batch: better training with larger batches? (https://arxiv.org/abs/1901.09335)
        7. MultiGrain: a unified image embedding for classes and instances(https://arxiv.org/abs/1902.05509)
        8. mixup: Beyond Empirical Risk Minimization? (https://arxiv.org/abs/1710.09412)
        9. CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features? (https://arxiv.org/abs/1905.04899)

        13個你一定要知道的PyTorch特性

        解讀:為什么要做特征歸一化/標準化?

        一文搞懂 PyTorch 內部機制

        張一鳴:每個逆襲的年輕人,都具備的底層能力




        ,,西,[]!


        瀏覽 100
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            日本无遮挡 | 国产大屌日 | 男人和女人曰逼 | 一级A片女处破 | 中文字幕乱码在线蜜乳欧美字幕 | 手机福利视频一区二区 | 日本www色 | 亚洲欧美天堂 | 免费无码又爽又高潮视频蜜柚视频 | 日日日日日 |