1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        圖像分類訓練技巧之數(shù)據(jù)增強總結

        共 38479字,需瀏覽 77分鐘

         ·

        2023-08-18 20:19

        點擊上方小白學視覺”,選擇加"星標"或“置頂

        重磅干貨,第一時間送達

        僅作學術分享,不代表本公眾號立場,侵權聯(lián)系刪除
        轉(zhuǎn)載于:作者丨小小將@知乎(已授權)
        來源丨h(huán)ttps://zhuanlan.zhihu.com/p/430563265
        編輯丨極市平臺
        一個模型的性能除了和網(wǎng)絡結構本身有關,還非常依賴具體的訓練策略,比如優(yōu)化器,數(shù)據(jù)增強以及正則化策略等(當然也很訓練數(shù)據(jù)強相關,訓練數(shù)據(jù)量往往決定模型性能的上線)。近年來,圖像分類模型在ImageNet數(shù)據(jù)集的top1 acc已經(jīng)由原來的56.5(AlexNet,2012)提升至90.88(CoAtNet,2021,用了額外的數(shù)據(jù)集JFT-3B),這進步除了主要歸功于模型,算力和數(shù)據(jù)的提升,也與訓練策略的提升緊密相關。最近剛興起的vision transformer相比CNN模型往往也需要更heavy的數(shù)據(jù)增強和正則化策略。這里簡單介紹圖像分類訓練技巧中的常用數(shù)據(jù)增強策略。

        baseline

        ImageNet數(shù)據(jù)集訓練常用的數(shù)據(jù)增強策略如下,訓練過程的數(shù)據(jù)增強包括隨機縮放裁剪(RandomResizedCrop,這種處理方式源自谷歌的Inception,所以稱為 Inception-style pre-processing)和水平翻轉(zhuǎn)(RandomHorizontalFlip),而測試階段是執(zhí)行縮放和中心裁剪。這其實是一種輕量級的策略,這里稱之為baseline。torchvision的實現(xiàn)的ResNet50訓練采用的策略就是這個,在ImageNet上的top1 acc可以達到76.1。

        from torchvision import transforms

        normalize = transforms.Normalize(mean=[0.4850.4560.406],
                                         std=[0.2290.2240.225])
        # 訓練
        train_transform = transforms.Compose([
            # 這里的scale指的是面積,ratio是寬高比
            # 具體實現(xiàn)每次先隨機確定scale和ratio,可以生成w和h,然后隨機確定裁剪位置進行crop
            # 最后是resize到target size
            transforms.RandomResizedCrop(224, scale=(0.081.0), ratio=(3. / 4.4. / 3.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
         ])
        # 測試
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
         ])

        AutoAugment

        谷歌在2018年提出通過AutoML來自動搜索數(shù)據(jù)增強策略,稱之為AutoAugment(算是自動數(shù)據(jù)增強開山之作)。搜索方法采用強化學習,和NAS類似,只不過搜索空間是數(shù)據(jù)增強策略,而不是網(wǎng)絡架構。在搜索空間里,一個policy包含5個sub-policies,每個sub-policy包含兩個串行的圖像增強操作,每個增強操作有兩個超參數(shù):進行該操作的概率圖像增強的幅度(magnitude,這個表示數(shù)據(jù)增強的強度,比如對于旋轉(zhuǎn),旋轉(zhuǎn)的角度就是增強幅度,旋轉(zhuǎn)角度越大,增強越大)。每個policy在執(zhí)行時,首先隨機從5個策略中隨機選擇一個sub-policy,然后序列執(zhí)行兩個圖像操作。

        搜索空間一共有16種圖像增強類型,具體如下所示,大部分操作都定義了圖像增強的幅度范圍,在搜索時需要將幅度值離散化,具體地是將幅度值在定義范圍內(nèi)均勻地取10個值。

        論文在不同的數(shù)據(jù)集上( CIFAR-10 , SVHN, ImageNet)做了實驗,這里給出在ImageNet數(shù)據(jù)集上搜索得到的最優(yōu)policy(最后實際上是將搜索得到的前5個最好的policies合成了一個policy,所以這里包含25個sub-policies):

        # operation, probability, magnitude
        (("Posterize"0.48), ("Rotate"0.69)),
        (("Solarize"0.65), ("AutoContrast"0.6None)),                                                          
        (("Equalize"0.8None), ("Equalize"0.6None)),
        (("Posterize"0.67), ("Posterize"0.66)),
        (("Equalize"0.4None), ("Solarize"0.24)),
        (("Equalize"0.4None), ("Rotate"0.88)),
        (("Solarize"0.63), ("Equalize"0.6None)),
        (("Posterize"0.85), ("Equalize"1.0None)),
        (("Rotate"0.23), ("Solarize"0.68)),
        (("Equalize"0.6None), ("Posterize"0.46)),
        (("Rotate"0.88), ("Color"0.40)),
        (("Rotate"0.49), ("Equalize"0.6None)),
        (("Equalize"0.0None), ("Equalize"0.8None)),
        (("Invert"0.6None), ("Equalize"1.0None)),
        (("Color"0.64), ("Contrast"1.08)),
        (("Rotate"0.88), ("Color"1.02)),
        (("Color"0.88), ("Solarize"0.87)),
        (("Sharpness"0.47), ("Invert"0.6None)),
        (("ShearX"0.65), ("Equalize"1.0None)),
        (("Color"0.40), ("Equalize"0.6None)),
        (("Equalize"0.4None), ("Solarize"0.24)),
        (("Solarize"0.65), ("AutoContrast"0.6None)),
        (("Invert"0.6None), ("Equalize"1.0None)),
        (("Color"0.64), ("Contrast"1.08)),
        (("Equalize"0.8None), ("Equalize"0.6None))

        基于搜索得到的AutoAugment訓練可以將ResNet50在ImageNet數(shù)據(jù)集上的top1 acc從76.3提升至77.6。一個比較重要的問題,這些從某一個數(shù)據(jù)集搜索得到的策略是否只對固定的數(shù)據(jù)集有效,論文也通過具體實驗證明了AutoAugment的遷移能力,比如將ImageNet數(shù)據(jù)集上得到的策略用在5個 FGVC數(shù)據(jù)集(與ImageNet圖像輸入大小相似)也均有提升。

        目前torchvision庫已經(jīng)實現(xiàn)了AutoAugment,具體使用如下所示(注意AutoAug前也需要包括一個RandomResizedCrop):

        from torchvision.transforms import autoaugment, transforms

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
            transforms.RandomHorizontalFlip(hflip_prob),
            # 這里policy屬于torchvision.transforms.autoaugment.AutoAugmentPolicy,
            # 對于ImageNet就是 AutoAugmentPolicy.IMAGENET
            # 此時aa_policy = autoaugment.AutoAugmentPolicy('imagenet')
            autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation),
         transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std)
         ])

        RandAugment

        AutoAugment存在的一個問題是搜索空間巨大,這使得搜索只能在代理任務中進行:使用小的模型在ImageNet的一個小的子集( 120類和6000圖片)搜索。谷歌在2019年又提出了一個更簡單的數(shù)據(jù)增強策略:RandAugment。這篇論文首先發(fā)現(xiàn)AutoAugment這樣在小數(shù)據(jù)集上搜索出來的策略在大的數(shù)據(jù)集上應用會存在問題,這主要是因為數(shù)據(jù)增強策略和模型大小和數(shù)據(jù)量大小存在強相關,如下圖所示可以看到模型或者訓練數(shù)據(jù)量越大,其最優(yōu)的數(shù)據(jù)增強的幅度越大,這說明AutoAugment得到的結果應該是次優(yōu)的。另外,Population Based Augmentation這篇論文發(fā)現(xiàn)最優(yōu)的數(shù)據(jù)增強幅度是隨訓練過程增加,而且不同的增強操作遵循類似的規(guī)律,這啟發(fā)作者采用固定的增強幅度而不是去搜索。RandAugment相比AutoAugment的策略空間很?。?span style="outline: 0px;max-width: 100%;cursor: pointer;box-sizing: border-box !important;overflow-wrap: break-word !important;">  vs  ),所以它不需要采用代理任務,甚至直接采用簡單的網(wǎng)格搜索。

        具體地,RandAugment共包含兩個超參數(shù):圖像增強操作的數(shù)量N和一個全局的增強幅度M,其實現(xiàn)代碼如下所示,每次從候選操作集合(共14種策略)隨機選擇N個操作(等概率),然后串行執(zhí)行(這里沒有判斷概率,是一定執(zhí)行)。這里的M取值范圍為{0, . . . , 30}(每個圖像增強操作歸一化到同樣的幅度范圍),而N取值范圍一般為 {1, 2, 3}。

        # Identity是恒等變換,不做任何增強
        transforms = ['Identity''AutoContrast''Equalize''Rotate''Solarize''Color''Posterize'
                      'Contrast''Brightness''Sharpness''ShearX''ShearY''TranslateX''TranslateY']

        def randaugment(N, M):
         """Generate a set of distortions.
         Args:
         N: Number of augmentation transformations to
         apply sequentially.
         M: Magnitude for all the transformations.
         """

         sampled_ops = np.random.choice(transforms, N)
         return [(op, M) for op in sampled_ops]

        對于ResNet50,其搜索得到的N=2,M=9,RandAugment相比AutoAugment可以在ImageNet得到相似的效果(77.6),不過DeiT中發(fā)現(xiàn)使用RandAugment效果更好一些( DeiT-B:81.8 vs 81.2)。目前torchvision庫也已經(jīng)實現(xiàn)了RandAugment,具體使用如下所示:

        from torchvision.transforms import autoaugment, transforms

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
            transforms.RandomHorizontalFlip(hflip_prob),
            autoaugment.RandAugment(interpolation=interpolation),
         transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std)
         ])

        TrivialAugment

        雖然RandAugment的搜索空間極小,但是對于不同的數(shù)據(jù)集還是需要確定最優(yōu)的N和M,這依然有較大的實驗成本。RandAugment后,華為提出了UniformAugment,這種策略不需要搜索也能取得較好的結果。不過這里我們介紹一項更新的工作:TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation。TrivialAugment也不需要任何搜索,整個方法非常簡單:每次隨機選擇一個圖像增強操作,然后隨機確定它的增強幅度,并對圖像進行增強。由于沒有任何超參數(shù),所以不需要任何搜索。從實驗結果上看,TA可以在多個數(shù)據(jù)集上取得更好的結果,如在ImageNet數(shù)據(jù)集上,ResNet50的top1 acc可以達到78.1,超過RandAugment。

        TrivialAugment的圖像增強集合和RandAugment基本一樣,不過TA也定義了一套更寬的增強幅度,目前torchvision中已經(jīng)實現(xiàn)了TrivialAugmentWide,具體使用代碼如下所示:

        from torchvision.transforms import autoaugment, transforms

        augmentation_space = {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.00.99, num_bins), True),
            "ShearY": (torch.linspace(0.00.99, num_bins), True),
            "TranslateX": (torch.linspace(0.032.0, num_bins), True),
            "TranslateY": (torch.linspace(0.032.0, num_bins), True),
            "Rotate": (torch.linspace(0.0135.0, num_bins), True),
            "Brightness": (torch.linspace(0.00.99, num_bins), True),
            "Color": (torch.linspace(0.00.99, num_bins), True),
            "Contrast": (torch.linspace(0.00.99, num_bins), True),
            "Sharpness": (torch.linspace(0.00.99, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            "Solarize": (torch.linspace(255.00.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
            transforms.RandomHorizontalFlip(hflip_prob),
            autoaugment.TrivialAugmentWide(interpolation=interpolation),
         transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std)
         ])

        RandomErasing

        RandomErasing是廈門大學在2017年提出的一種簡單的數(shù)據(jù)增強(這個策略和同期的CutOut基本一樣),基本原理是:隨機從圖像中擦除一個矩形區(qū)域而不改變圖像的原始標簽。DeiT的訓練策略中也包括了RandomErasing。

        目前torchvision也實現(xiàn)了RandomErasing,其具體使用代碼如下(注意這個op不支持PIL圖像,需要在轉(zhuǎn)換為torch.tensor后使用):

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.081.0), ratio=(3. / 4.4. / 3.)),
            transforms.RandomHorizontalFlip(),
            transforms.PILToTensor()
            transforms.ConvertImageDtype(torch.float),
            normalize,
            # scale是指相對于原圖的擦除面積范圍
            # ratio是指擦除區(qū)域的寬高比
            # value是指擦除區(qū)域的值,如果是int,也可以是tuple(RGB3個通道值),或者是str,需為'random',表示隨機生成
            transforms.RandomErasing(p=0.5, scale=(0.020.33), ratio=(0.33.3), value=0, inplace=False),
         ])

        MixUp

        MixUp在FAIR在2017年提出的一種數(shù)據(jù)增強方法:兩張不同的圖像隨機線性組合,而同時生成線性組合的標簽。



        這里的 是兩張不同的圖像, 是它們對應的one-hot標簽,而 是線性組合系數(shù),每次執(zhí)行時隨機生成。假定圖像分類任務是2分類(區(qū)分狗和貓),兩張輸入圖像分別是狗和貓(如下圖所示),它們對應的one-hot標簽分別是[1,0]和[0, 1]。在進行mixup之前,首先對它們進行必要的數(shù)據(jù)增強得到aug_img1和aug_img2,然后隨機生成線性組合系數(shù),對于 得到的圖像是mix_img1,標簽變?yōu)閇0.7, 0.3],而 得到的圖像是mix_img2,標簽變?yōu)閇0.3, 0.7]。

        目前timm和torchvision中已經(jīng)實現(xiàn)了mixup,這里以torchvision為例來講述具體的代碼實現(xiàn)。由于mixup需要兩個輸入,而不單單是對當前圖像進行操作,所以一般是在得到batch數(shù)據(jù)后再進行mixup,這也意味著圖像也已經(jīng)完成了其它的數(shù)據(jù)增強如RandAugment,對于batch中的每個樣本可以隨機選擇另外一個樣本進行mixup。具體的實現(xiàn)代碼如下所示:

        # from https://github.com/pytorch/vision/blob/main/references/classification/transforms.py
        class RandomMixup(torch.nn.Module):
            """Randomly apply Mixup to the provided batch and targets.
            The class implements the data augmentations as described in the paper
            `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
            Args:
                num_classes (int): number of classes used for one-hot encoding.
                p (float): probability of the batch being transformed. Default value is 0.5.
                alpha (float): hyperparameter of the Beta distribution used for mixup.
                    Default value is 1.0. # beta分布超參數(shù)
                inplace (bool): boolean to make this transform inplace. Default set to False.
            """


            def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
                super().__init__()
                assert num_classes > 0"Please provide a valid positive value for the num_classes."
                assert alpha > 0"Alpha param can't be zero."

                self.num_classes = num_classes
                self.p = p
                self.alpha = alpha
                self.inplace = inplace

            def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
                """
                Args:
                    batch (Tensor): Float tensor of size (B, C, H, W)
                    target (Tensor): Integer tensor of size (B, )
                Returns:
                    Tensor: Randomly transformed batch.
                """

                if batch.ndim != 4:
                    raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
                if target.ndim != 1:
                    raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
                if not batch.is_floating_point():
                    raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
                if target.dtype != torch.int64:
                    raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

                if not self.inplace:
                    batch = batch.clone()
                    target = target.clone()
          
                # 建立one-hot標簽
                if target.ndim == 1:
                    target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
          
                # 判斷是否進行mixup
                if torch.rand(1).item() >= self.p:
                    return batch, target
          
                # 這里將batch數(shù)據(jù)平移一個單位,產(chǎn)生mixup的圖像對,這意味著每個圖像與相鄰的下一個圖像進行mixup
                # timm實現(xiàn)是通過flip來做的,這意味著第一個圖像和最后一個圖像進行mixup
                # It's faster to roll the batch by one instead of shuffling it to create image pairs
                batch_rolled = batch.roll(10)
                target_rolled = target.roll(10)
          
                # 隨機生成組合系數(shù)
                # Implemented as on mixup paper, page 3.
                lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
                batch_rolled.mul_(1.0 - lambda_param)
                batch.mul_(lambda_param).add_(batch_rolled) # 得到mixup后的圖像

                target_rolled.mul_(1.0 - lambda_param)
                target.mul_(lambda_param).add_(target_rolled) # 得到mixup后的標簽

                return batch, target

        然后可以將MixUp操作放在DataLoader的collate_fn中,這個函數(shù)要實現(xiàn)的是將多個樣本合并成一個mini-batch,所以可以將MixUp插在得到mini-batch后,具體實現(xiàn)如下所示:

        from torch.utils.data.dataloader import default_collate

        mixup_transform = RandomMixup(num_classes, p=1.0, alpha=mixup_alpha)
        collate_fn = lambda batch: mixup_transform(*default_collate(batch))
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
            sampler=train_sampler, collate_fn=collate_fn)

        對于MixUp,還要注意兩個兩點。第一個是如果同時采用了label smoothing,那么在創(chuàng)建one-hot標簽時要直接得到smooth后的標簽,具體實現(xiàn)如下(參考timm):

        def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
            x = x.long().view(-11)
            return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

        off_value = smoothing / num_classes
        on_value = 1. - smoothing + off_value
        smooth_one_hot = one_hot(target, num_classes, on_value=on_value, off_value=off_value)

        第二個要注意的是MixUp后得到標簽時soft label,不能直接采用torch.nn.CrossEntropyLoss來計算loss,而是直接計算交叉熵(參考timm):

        class SoftTargetCrossEntropy(nn.Module):

            def __init__(self):
                super(SoftTargetCrossEntropy, self).__init__()

            def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
                return loss.mean()

        注意在PyTorch1.10版本之后,torch.nn.CrossEntropyLoss已經(jīng)支持直接送入的target是probabilities for each class,原來只支持target是class indices;而且也支持label_smoothing參數(shù),所以上述兩個注意點就不再需要了。

        說到計算loss,timm作者近期在ResNet strikes back: An improved training procedure in timm指出采用MixUp后可以將多分類改成多標簽分類(multi-label classification),即從N分類變成N個2分類(直接采用BinaryCrossEntropy),這應該更符合MixUp后圖像的語義,從對比實驗來看效果有微弱的提升。MixUp除了可以用于圖像分類任務,還可以用于物體檢測任務中,比如YOLOX就采用了MixUp,這里面的做法是對圖像mixup后,其box為兩個圖像的box的合并集合,而沒有對標簽軟化,這塊也可以見論文Bag of Freebies for Training Object Detection Neural Networks。

        CutMix

        CutMix是2019年提出的一項和MixUp和類似的數(shù)據(jù)增強策略,它也是同時對兩個圖像和標簽進行混合,與MixUp不同的是它的圖像混合方式。CutMix不是對兩個圖像線性組合,而是從另外一張圖像隨機剪切一個patch并粘貼到第一張圖像上,patch的起始坐標隨機生成,而寬高是由 來控制:



        這里 是原始圖像的寬和高,所以 其實決定的是patch和原圖的面積比: 。下圖展示了 分別取0.7和0.3的混合效果, 越小,粘貼的patch越大。對于標簽,其處理方式和MixUp一樣,通過 來得到兩張圖像的線性組合。

        CutMix做了ImageNet上的對比實驗,相比MixUp,ResNet50的top1 acc大約能提升一個點(77.4 vs 78.6):

        目前timm和torchvision中也已經(jīng)實現(xiàn)了CutMix,這里還是以torchvision為例來講述具體的代碼實現(xiàn),如下所示(和MixUp基本類似,只不過內(nèi)部處理存在差異):

        class RandomCutmix(torch.nn.Module):
            """Randomly apply Cutmix to the provided batch and targets.
            The class implements the data augmentations as described in the paper
            `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
            <https://arxiv.org/abs/1905.04899>`_.
            Args:
                num_classes (int): number of classes used for one-hot encoding.
                p (float): probability of the batch being transformed. Default value is 0.5.
                alpha (float): hyperparameter of the Beta distribution used for cutmix.
                    Default value is 1.0.
                inplace (bool): boolean to make this transform inplace. Default set to False.
            """


            def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
                super().__init__()
                assert num_classes > 0"Please provide a valid positive value for the num_classes."
                assert alpha > 0"Alpha param can't be zero."

                self.num_classes = num_classes
                self.p = p
                self.alpha = alpha
                self.inplace = inplace

            def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
                """
                Args:
                    batch (Tensor): Float tensor of size (B, C, H, W)
                    target (Tensor): Integer tensor of size (B, )
                Returns:
                    Tensor: Randomly transformed batch.
                """

                if batch.ndim != 4:
                    raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
                if target.ndim != 1:
                    raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
                if not batch.is_floating_point():
                    raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
                if target.dtype != torch.int64:
                    raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

                if not self.inplace:
                    batch = batch.clone()
                    target = target.clone()

                if target.ndim == 1:
                    target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

                if torch.rand(1).item() >= self.p:
                    return batch, target

                # It's faster to roll the batch by one instead of shuffling it to create image pairs
                batch_rolled = batch.roll(10)
                target_rolled = target.roll(10)

                # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
                lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
                W, H = F.get_image_size(batch)
          
                # 確定patch的起點
                r_x = torch.randint(W, (1,))
                r_y = torch.randint(H, (1,))
          
                # 確定patch的w和h(其實是一半大?。?/span>
                r = 0.5 * math.sqrt(1.0 - lambda_param)
                r_w_half = int(r * W)
                r_h_half = int(r * H)
          
                # 越界處理
                x1 = int(torch.clamp(r_x - r_w_half, min=0))
                y1 = int(torch.clamp(r_y - r_h_half, min=0))
                x2 = int(torch.clamp(r_x + r_w_half, max=W))
                y2 = int(torch.clamp(r_y + r_h_half, max=H))

                batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
                # 由于越界處理, λ可能發(fā)生改變,所以要重新計算
                lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

                target_rolled.mul_(1.0 - lambda_param)
                target.mul_(lambda_param).add_(target_rolled)

                return batch, target

        其它使用和MixUp一樣。

        Repeated Augmentation

        Repeated Augmentation (RA)是FAIR在MultiGrain提出的一種抽樣策略,一般情況下,訓練的mini-batch包含的增強過的sample都是來自不同的圖像,但是RA這種抽樣策略允許一個mini-batch中包含來自同一個圖像的不同增強版本,此時mini-batch的各個樣本并非是完全獨立的,這相當于對同一個樣本進行重復抽樣,所以稱為Repeated Augmentation。這篇論文認為在一個mini-batch學習來自同一個圖像的不同增強版本能讓模型更容易學習到增強不變的特征。關于RA,其實另外一篇較早的論文Augment your batch: better training with larger batches也提出了類似的策略,另外DeepMind在最近的論文Drawing Multiple Augmentation Samples Per Image During Training Efficiently Decreases Test Error也進一步通過實驗來證明這種策略的效果。

        DeiT的訓練也采用了RA,嚴格來說RA不屬于數(shù)據(jù)增強策略,而是一種mini-batch抽樣方法,這里也簡單給出DeiT實現(xiàn)的RA(可以替換torch.utils.data.DistributedSampler):

        class RASampler(torch.utils.data.Sampler):
            """Sampler that restricts data loading to a subset of the dataset for distributed,
            with repeated augmentation.
            It ensures that different each augmented version of a sample will be visible to a
            different process (GPU)
            Heavily based on torch.utils.data.DistributedSampler
            """


            def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
                if num_replicas is None:
                    if not dist.is_available():
                        raise RuntimeError("Requires distributed package to be available")
                    num_replicas = dist.get_world_size()
                if rank is None:
                    if not dist.is_available():
                        raise RuntimeError("Requires distributed package to be available")
                    rank = dist.get_rank()
                self.dataset = dataset
                self.num_replicas = num_replicas
                self.rank = rank
                self.epoch = 0
                # 重復采樣后每個replica的樣本量
                self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
                # 重復采樣后的總樣本量
                self.total_size = self.num_samples * self.num_replicas
                # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
                # 每個replica實際樣本量,即不重復采樣時的每個replica的樣本量
                self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
                self.shuffle = shuffle

            def __iter__(self):
                # deterministically shuffle based on epoch
                g = torch.Generator()
                g.manual_seed(self.epoch)
                if self.shuffle:
                    indices = torch.randperm(len(self.dataset), generator=g).tolist()
                else:
                    indices = list(range(len(self.dataset)))

                # add extra samples to make it evenly divisible
                indices = [ele for ele in indices for i in range(3)] # 重復3次
                indices += indices[:(self.total_size - len(indices))]
                assert len(indices) == self.total_size

                # subsample: 使得同一個樣本的重復版本進入不同的進程(GPU)
                indices = indices[self.rank:self.total_size:self.num_replicas]
                assert len(indices) == self.num_samples

                return iter(indices[:self.num_selected_samples]) # 截取實際樣本量

            def __len__(self):
                return self.num_selected_samples

            def set_epoch(self, epoch):
                self.epoch = epoch

        小結

        這里簡單介紹了幾種常用且有效的數(shù)據(jù)增強策略,這些策略在vision transformer模型被使用,而且timm訓練的ResNet新baseline也使用了這些策略。

        參考

        1. Training data-efficient image transformers & distillation through attention  (https://arxiv.org/abs/2012.12877)
        2. AutoAugment: Learning Augmentation Policies from Data  (https://arxiv.org/abs/1805.09501)
        3. RandAugment: Practical automated data augmentation with a reduced search space  (https://arxiv.org/abs/1909.13719)
        4. TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation  (https://arxiv.org/abs/2103.10158)
        5. Random Erasing Data Augmentation(https://arxiv.org/abs/1708.04896)
        6. Augment your batch: better training with larger batches  (https://arxiv.org/abs/1901.09335)
        7. MultiGrain: a unified image embedding for classes and instances(https://arxiv.org/abs/1902.05509)
        8. mixup: Beyond Empirical Risk Minimization  (https://arxiv.org/abs/1710.09412)
        9. CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features  (https://arxiv.org/abs/1905.04899)
              
        下載1:OpenCV-Contrib擴展模塊中文版教程
        在「小白學視覺」公眾號后臺回復:擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

        下載2:Python視覺實戰(zhàn)項目52講
        小白學視覺公眾號后臺回復:Python視覺實戰(zhàn)項目即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。

        下載3:OpenCV實戰(zhàn)項目20講
        小白學視覺公眾號后臺回復:OpenCV實戰(zhàn)項目20講即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。

        交流群


        歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~


        瀏覽 224
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            日韩一区二区婬片国产欧美在线 | 色精品| 男女无遮挡毛片免费视频网站 | 99久久人妻无码中文字幕系列 | 透明内裤被巴捣出白浆 | 久久夜色网 | 91看片视频 | 同学扒了我内裤还玩我全身男 | 国产人妻 精品无码免费 | 国产精品毛片一区视频播不卡 |