PyTorch Source Code Walkthrough: BN & SyncBN

2020-12-23 01:55

Author: OpenMMLab @205120

        https://zhuanlan.zhihu.com/p/337732517

Republished with the original author's permission.



1. How BatchNorm Works

2. The PyTorch Implementation of BatchNorm

2.1 The _NormBase class

2.1.1 Initialization

2.1.2 Simulating BN forward

2.1.3 Updating running_mean and running_var

2.1.4 Updating \gamma and \beta

2.1.5 eval mode

2.2 The BatchNormNd classes

3. The PyTorch Implementation of SyncBatchNorm

3.1 forward

3.2 backward

1. How BatchNorm Works


BatchNorm was first proposed for fully connected networks, where it normalizes the input of each neuron. Extended to CNNs, it normalizes the input of each convolution kernel, i.e. it normalizes over all dimensions except the channel dimension. BN brings many benefits; here are a few:

• Preventing overfitting: the output for a single sample depends on the whole mini-batch, which discourages overfitting to any individual sample.

• Faster convergence: during gradient descent the weights and biases of every layer keep changing, so the distribution of each layer's outputs keeps shifting and later layers constantly have to adapt to it. With BN, the input distribution of each layer stays approximately fixed.

• Mitigating vanishing gradients: during the forward pass, activations tend to drift toward the saturated ends of the nonlinearity (take Sigmoid as an example), where the gradients reaching later layers become very small and training stalls.

The math of BN is:

\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i,\qquad \sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_\mathcal{B})^2,\qquad \hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}},\qquad y_i = \gamma\,\hat{x}_i + \beta

This introduces a scale factor \gamma and a shift factor \beta; the authors explain their role in the paper:

• Normalizing to zero mean and unit variance alone would make the new distribution lose the features and knowledge passed on by the previous layers.

• Taking Sigmoid as an example, adding \gamma and \beta keeps most values from being squeezed into the approximately linear middle region of the activation, which would leave its nonlinear part unused.
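As the BN paper also notes (restated here in the notation above; this remark is not part of the original article), the affine parameters are expressive enough to undo the normalization entirely, so in principle no representational capacity is lost:

\gamma = \sqrt{\sigma_\mathcal{B}^2 + \epsilon},\qquad \beta = \mu_\mathcal{B} \quad\Longrightarrow\quad y_i = \gamma\,\hat{x}_i + \beta = x_i

In practice \gamma and \beta are simply learned by gradient descent together with the rest of the network (see section 2.1.4).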

2. The PyTorch Implementation of BatchNorm

The BN-related classes in PyTorch live in torch.nn.modules.batchnorm and include the following (a quick check of this hierarchy is sketched right after the list):

• _NormBase: a subclass of nn.Module that defines BN's attributes along with the methods for initializing them and loading state;

• _BatchNorm: a subclass of _NormBase that defines the forward method;

• BatchNorm1d & BatchNorm2d & BatchNorm3d: subclasses of _BatchNorm that each define their own _check_input_dim method.
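The hierarchy can be confirmed by printing the method resolution order of BatchNorm2d. Note that _NormBase only exists in relatively recent PyTorch releases (in older ones _BatchNorm derives from Module directly), so treat this as a version-dependent check rather than a guarantee:

import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

# print the class hierarchy of BatchNorm2d; on recent versions it contains
# BatchNorm2d -> _BatchNorm -> _NormBase -> Module -> object
for klass in nn.BatchNorm2d.__mro__:
    print(klass.__name__)

assert issubclass(nn.BatchNorm2d, _BatchNorm)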

2.1 The _NormBase class

2.1.1 Initialization

The _NormBase class defines a number of BN-related attributes, listed in the table below:

attribute | meaning
num_features | number of input channels
track_running_stats | defaults to True; whether to track running_mean and running_var
running_mean | mean of the inputs tracked during training, used later for inference
running_var | variance of the inputs tracked during training, used later for inference
momentum | defaults to 0.1; the momentum used when updating running_mean and running_var
num_batches_tracked | added after PyTorch 0.4; when momentum is set to None, num_batches_tracked is used to compute the update factor of each step
affine | defaults to True; if True, weight and bias are trained, otherwise they are left unchanged
weight | \gamma in the formula, initialized to an all-ones tensor
bias | \beta in the formula, initialized to an all-zeros tensor
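A quick aside (my addition, not from the original article): the affine parameters are registered as trainable Parameters, while the running statistics are registered as buffers, so they show up in different containers of the module:

import torch.nn as nn

bn = nn.BatchNorm2d(num_features=3)

# weight (gamma) and bias (beta) are trainable parameters
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']

# running_mean, running_var and num_batches_tracked are buffers: they are saved
# in the state_dict but never returned by parameters(), so the optimizer ignores them
print([name for name, _ in bn.named_buffers()])  # ['running_mean', 'running_var', 'num_batches_tracked']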

Here is the PyTorch source code:

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    # _version is used when loading checkpoints to tell whether they were
    # saved before or after PyTorch 0.4.1
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            # with affine enabled, use a scale factor and a shift factor
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        # whether to track mean and variance statistics during training
        if self.track_running_stats:
            # buffers do not show up in self.parameters()
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        # implemented in BN1d, BN2d, BN3d to validate the input shape
        raise NotImplementedError

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            # this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                # older checkpoints do not have this key, so default it to 0
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_NormBase, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)


class _BatchNorm(_NormBase):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        # in train mode with track_running_stats enabled,
        # the running statistics need to be updated
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                # if momentum is None, weight the update by num_batches_tracked
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

2.1.2 Simulating BN forward

The Python side of PyTorch's BN mainly handles initialization, argument passing and dispatching to the low-level implementation. Here we simulate that low-level BN computation in Python.

import torch
import torch.nn as nn
import torch.nn.modules.batchnorm

# create random inputs
def create_inputs():
    return torch.randn(8, 3, 20, 20)

# take BatchNorm2d as an example
# if mean_val / var_val are not None, use them directly instead of computing
# the statistics from the input
def dummy_bn_forward(x, bn_weight, bn_bias, eps, mean_val=None, var_val=None):
    if mean_val is None:
        mean_val = x.mean([0, 2, 3])
    if var_val is None:
        # note: torch.var computes the unbiased estimate by default,
        # so unbiased=False has to be set explicitly here
        var_val = x.var([0, 2, 3], unbiased=False)

    x = x - mean_val[None, ..., None, None]
    x = x / torch.sqrt(var_val[None, ..., None, None] + eps)
    x = x * bn_weight[..., None, None] + bn_bias[..., None, None]
    return mean_val, var_val, x

Verify that the dummy BN produces the correct output:

bn_layer = nn.BatchNorm2d(num_features=3)
inputs = create_inputs()
# forward with the PyTorch implementation
bn_outputs = bn_layer(inputs)
# forward with the dummy BN
_, _, expected_outputs = dummy_bn_forward(
    inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps)
assert torch.allclose(expected_outputs, bn_outputs)

No assertion is raised, so the computed values are correct.

2.1.3 Updating running_mean and running_var

BatchNorm has track_running_stats enabled by default, so every forward pass updates running_mean and running_var with the statistics of the current minibatch.

momentum defaults to 0.1 and controls the relative weight of the historical statistics versus the current minibatch when running_mean and running_var are updated:

\text{running\_mean} = (1 - \text{momentum}) \times \text{running\_mean} + \text{momentum} \times \mu_\mathcal{B}

\text{running\_var} = (1 - \text{momentum}) \times \text{running\_var} + \text{momentum} \times \hat{\sigma}_\mathcal{B}^2

where \mu_\mathcal{B} and \hat{\sigma}_\mathcal{B}^2 are the mean and variance of the current minibatch x. Note that the variance tracked here is the unbiased estimate, consistent with the paper. This process can be simulated by hand as follows:

running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
momentum = 0.1  # this is also the default momentum of a freshly created BN layer
bn_layer = nn.BatchNorm2d(num_features=3, momentum=momentum)

# simulate 10 forward passes
for t in range(10):
    inputs = create_inputs()
    bn_outputs = bn_layer(inputs)
    inputs_mean, inputs_var, _ = dummy_bn_forward(
        inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps
    )
    n = inputs.numel() / inputs.size(1)
    # update running_var and running_mean
    running_var = running_var * (1 - momentum) + momentum * inputs_var * n / (n - 1)
    running_mean = running_mean * (1 - momentum) + momentum * inputs_mean

assert torch.allclose(running_var, bn_layer.running_var)
assert torch.allclose(running_mean, bn_layer.running_mean)
print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'dummy bn running_mean is {running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
print(f'dummy bn running_var is {running_var}')

Output:

        bn_layer running_mean is tensor([ 0.0101, -0.0013, 0.0101])
        dummy bn running_mean is tensor([ 0.0101, -0.0013, 0.0101])
        bn_layer running_var is tensor([0.9857, 0.9883, 1.0205])
        dummy bn running_var is tensor([0.9857, 0.9883, 1.0205])

running_mean starts at 0 and changes after the forward passes; the running_mean and running_var of the simulated BN also match the results of the PyTorch implementation.

The discussion above assumes momentum is used. Since PyTorch 0.4.1 there is also the num_batches_tracked attribute, which counts how many minibatches BN has seen. When momentum is set to None, num_batches_tracked takes over and weights the history against the current minibatch with a cumulative moving average:

\text{factor} = \frac{1}{\text{num\_batches\_tracked}},\qquad \text{running\_x} = (1 - \text{factor}) \times \text{running\_x} + \text{factor} \times \text{batch\_stat}

Next, let's simulate this process by hand:

running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
num_batches_tracked = 0
# with momentum set to None, num_batches_tracked is used to update the statistics
bn_layer = nn.BatchNorm2d(num_features=3, momentum=None)

# again simulate 10 forward passes
for t in range(10):
    inputs = create_inputs()
    bn_outputs = bn_layer(inputs)
    inputs_mean, inputs_var, _ = dummy_bn_forward(
        inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps
    )
    num_batches_tracked += 1
    # exponential_average_factor
    eaf = 1.0 / num_batches_tracked
    n = inputs.numel() / inputs.size(1)
    # update running_var and running_mean
    running_var = running_var * (1 - eaf) + eaf * inputs_var * n / (n - 1)
    running_mean = running_mean * (1 - eaf) + eaf * inputs_mean

assert torch.allclose(running_var, bn_layer.running_var)
assert torch.allclose(running_mean, bn_layer.running_mean)

bn_layer.train(mode=False)
inference_inputs = create_inputs()
bn_outputs = bn_layer(inference_inputs)
_, _, dummy_outputs = dummy_bn_forward(
    inference_inputs, bn_layer.weight,
    bn_layer.bias, bn_layer.eps,
    running_mean, running_var)
assert torch.allclose(dummy_outputs, bn_outputs)
print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'dummy bn running_mean is {running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
print(f'dummy bn running_var is {running_var}')

Output:

        bn_layer running_mean is tensor([-0.0040, 0.0074, -0.0162])
        dummy bn running_mean is tensor([-0.0040, 0.0074, -0.0162])
        bn_layer running_var is tensor([1.0097, 1.0086, 0.9815])
        dummy bn running_var is tensor([1.0097, 1.0086, 0.9815])

The hand-simulated results again match PyTorch's.

2.1.4 Updating \gamma and \beta

BatchNorm's weight and bias correspond to \gamma and \beta in the formula, and they are updated by gradient descent.

import torchvision
from torchvision.transforms import Normalize, ToTensor, Compose
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

# use MNIST as a toy dataset
mnist = torchvision.datasets.MNIST(root='mnist', download=True, transform=ToTensor())
dataloader = DataLoader(dataset=mnist, batch_size=8)

# build a simple model containing a BN layer
toy_model = nn.Sequential(nn.Linear(28 ** 2, 128), nn.BatchNorm1d(128),
                          nn.ReLU(), nn.Linear(128, 10), nn.Sigmoid())
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)

bn_1d_layer = toy_model[1]
print(f'Initial weight is {bn_1d_layer.weight[:4].tolist()}...')
print(f'Initial bias is {bn_1d_layer.bias[:4].tolist()}...\n')
# take two parameter-update steps
for (i, data) in enumerate(dataloader):
    output = toy_model(data[0].view(data[0].shape[0], -1))
    (F.cross_entropy(output, data[1])).backward()
    # print the gradients of a few entries to verify that weight and bias
    # really are updated by gradient descent
    print(f'Gradient of weight is {bn_1d_layer.weight.grad[:4].tolist()}...')
    print(f'Gradient of bias is {bn_1d_layer.bias.grad[:4].tolist()}...')
    optimizer.step()
    optimizer.zero_grad()
    if i == 1:
        break

print(f'\nNow weight is {bn_1d_layer.weight[:4].tolist()}...')
print(f'Now bias is {bn_1d_layer.bias[:4].tolist()}...')

inputs = torch.randn(4, 128)
bn_outputs = bn_1d_layer(inputs)
new_bn = nn.BatchNorm1d(128)
bn_outputs_no_weight_bias = new_bn(inputs)

assert not torch.allclose(bn_outputs, bn_outputs_no_weight_bias)

Output:

        Initial weight is [0.9999354481697083, 1.0033478736877441, 1.0019147396087646, 0.9986106157302856]...
        Initial bias is [-0.0012734815245494246, 0.001349383033812046, 0.0013358002761378884, -0.0007148777367547154]...

        Gradient of weight is [-0.0004475426103454083, -0.0021388232707977295, -0.0032624618615955114, -0.0009599098702892661]...
        Gradient of bias is [0.00011698803427862003, -0.001291472464799881, -0.0023048489820212126, -0.0009493136312812567]...
        Gradient of weight is [-0.00035325769567862153, -0.0014295700239017606, -0.002102235099300742, 0.000851186050567776]...
        Gradient of bias is [-0.00026844028616324067, -0.00025666248984634876, -0.0017800561618059874, 0.00024933076929301023]...

        Now weight is [1.0000154972076416, 1.0037046670913696, 1.0024511814117432, 0.9986214637756348]...
        Now bias is [-0.0012583363568410277, 0.0015041964361444116, 0.0017442908138036728, -0.0006448794738389552]...

2.1.5 eval mode

Everything verified above concerns BN in train mode; for eval mode a few attributes matter.

• track_running_stats: defaults to True. In train mode the layer tracks running_mean and running_var; in eval mode those statistics are used as the mean and variance for normalization. When it is set to False, eval mode computes the mean and variance directly from the input instead.

• running_mean, running_var: the statistics tracked in train mode.

In other words, BN.training is not the only flag that determines BN's behavior. Whenever BN.training or not BN.track_running_stats holds, the mean and variance are computed from the input batch; otherwise the tracked statistics are used.

# switch to eval mode
bn_layer.train(mode=False)
inference_inputs = create_inputs()
# print running_mean and running_var before and after the forward pass
# to verify that the statistics are no longer updated in eval mode
print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
bn_outputs = bn_layer(inference_inputs)
print(f'Now bn_layer running_mean is {bn_layer.running_mean}')
print(f'Now bn_layer running_var is {bn_layer.running_var}')
# feed the previously tracked running_mean and running_var into dummy_bn_forward
# in place of the input statistics
_, _, dummy_outputs = dummy_bn_forward(
    inference_inputs, bn_layer.weight,
    bn_layer.bias, bn_layer.eps,
    running_mean, running_var)
assert torch.allclose(dummy_outputs, bn_outputs)

# with track_running_stats turned off, the mean and var of the input
# are computed even in eval mode
bn_layer.track_running_stats = False
bn_outputs_notrack = bn_layer(inference_inputs)
_, _, dummy_outputs_notrack = dummy_bn_forward(
    inference_inputs, bn_layer.weight,
    bn_layer.bias, bn_layer.eps)

assert torch.allclose(dummy_outputs_notrack, bn_outputs_notrack)
assert not torch.allclose(bn_outputs, bn_outputs_notrack)

The output is:

        bn_layer running_mean is tensor([-0.0143,  0.0089, -0.0062])
        bn_layer running_var is tensor([0.9611, 1.0380, 1.0181])
        Now bn_layer running_mean is tensor([-0.0143, 0.0089, -0.0062])
        Now bn_layer running_var is tensor([0.9611, 1.0380, 1.0181])

2.2 The BatchNormNd classes

This covers BatchNorm1d, BatchNorm2d and BatchNorm3d. They differ only in how they validate the input; here is the implementation of BatchNorm2d:

class BatchNorm2d(_BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

BatchNorm1d accepts 2D or 3D input, BatchNorm2d accepts 4D input, and BatchNorm3d accepts 5D input, as the short check below illustrates.
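A small illustration of these shape checks (my addition, not from the original article); an input with the wrong number of dimensions is rejected by _check_input_dim with a ValueError:

import torch
import torch.nn as nn

bn2d = nn.BatchNorm2d(num_features=3)
bn2d(torch.randn(8, 3, 20, 20))   # 4D input: accepted

try:
    bn2d(torch.randn(8, 3))       # 2D input: rejected by _check_input_dim
except ValueError as e:
    print(e)                      # expected 4D input (got 2D input)

bn1d = nn.BatchNorm1d(num_features=3)
bn1d(torch.randn(8, 3))           # 2D input: accepted
bn1d(torch.randn(8, 3, 20))       # 3D input: accepted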

3. The PyTorch Implementation of SyncBatchNorm

BN's behavior depends heavily on the batch size: the larger the batch, the more accurate BN's statistics. For tasks such as detection, however, memory usage is high and a single GPU can often fit only a few images (e.g. 2) per batch, which degrades BN. One fix is SyncBN: all GPUs share a single BN layer and compute global statistics.
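To make the batch-size argument concrete, here is a small experiment (my addition, not from the original article): with a per-card batch size of 2 the per-batch channel means fluctuate far more than with a batch size of 64, so the statistics BN normalizes with are much noisier.

import torch

torch.manual_seed(0)
data = torch.randn(4096, 3, 20, 20)   # stand-in for a whole training set

def batch_mean_spread(batch_size):
    # mean of channel 0, computed independently on every minibatch
    means = [chunk.mean(dim=[0, 2, 3])[0] for chunk in data.split(batch_size)]
    # how much those per-batch means fluctuate
    return torch.stack(means).std()

print('spread of per-batch means, batch_size=2 :', batch_mean_spread(2).item())
print('spread of per-batch means, batch_size=64:', batch_mean_spread(64).item())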

PyTorch's SyncBN is implemented in torch/nn/modules/batchnorm.py and torch/nn/modules/_functions.py. The former validates the input, prepares the arguments according to momentum and the other settings, and calls the latter; the latter computes the per-GPU statistics and handles inter-process communication.

class SyncBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, process_group=None):
        super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        self.process_group = process_group
        # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
        # under supported condition (single GPU per process)
        self.ddp_gpu_size = None

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError('expected at least 2D input (got {}D input)'
                             .format(input.dim()))

    def _specify_ddp_gpu_num(self, gpu_size):
        if gpu_size > 1:
            raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
        self.ddp_gpu_size = gpu_size

    def forward(self, input):
        if not input.is_cuda:
            raise ValueError('SyncBatchNorm expected input tensor to be on GPU')

        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        # this part differs little from the plain BN forward
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            self.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # in train mode, or when track_running_stats is off, the global mean
        # and variance need to be synchronized across processes
        need_sync = self.training or not self.track_running_stats
        if need_sync:
            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # if no synchronization is needed, SyncBN behaves exactly like plain BN
        if not need_sync:
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor, self.eps)
        else:
            if not self.ddp_gpu_size:
                raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')

            return sync_batch_norm.apply(
                input, self.weight, self.bias, self.running_mean, self.running_var,
                self.eps, exponential_average_factor, process_group, world_size)

    # convert plain BN layers to SyncBN, mostly by copying parameters over
    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(module.num_features,
                                                   module.eps, module.momentum,
                                                   module.affine,
                                                   module.track_running_stats,
                                                   process_group)
            if module.affine:
                with torch.no_grad():
                    module_output.weight.copy_(module.weight)
                    module_output.bias.copy_(module.bias)
                # keep requires_grad unchanged
                module_output.weight.requires_grad = module.weight.requires_grad
                module_output.bias.requires_grad = module.bias.requires_grad
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
        del module
        return module_output
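For context, converting an existing model with convert_sync_batchnorm is the usual way SyncBN is enabled in a DDP training script. A rough sketch (the model here is a placeholder, and the DDP wrapping is shown only as a comment since it requires an initialized process group):

import torch
import torch.nn as nn

# an ordinary model with plain BN layers
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())

# recursively replace every _BatchNorm submodule with SyncBatchNorm
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)

# in a real distributed run (after torch.distributed.init_process_group) the
# converted model would then be wrapped, e.g.:
# model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])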

        3.1 forward

Let's recall how the variance can be computed:

\mathrm{Var}[x] = E[x^2] - (E[x])^2 = \frac{1}{N}\sum_{i=1}^{N} x_i^2 - \left(\frac{1}{N}\sum_{i=1}^{N} x_i\right)^2

BN on a single card computes the mean and variance of that card's own input and then normalizes; SyncBN needs the global statistics, i.e. the mean and variance of the inputs across all cards. A straightforward idea is to do it in two steps:

1. each card computes its own mean, followed by one synchronization to obtain the global mean;

2. each card uses the global mean to compute its contribution to the variance, followed by a second synchronization to obtain the global variance.

But two synchronizations cost extra time; in fact a single synchronization is enough to obtain both the global mean and the global variance:

\mu = \frac{\sum_{i=1}^{N} x_i}{N},\qquad \sigma^2 = \frac{\sum_{i=1}^{N} x_i^2}{N} - \mu^2

So it suffices for each card to compute \sum_i x_i and \sum_i x_i^2 (together with its element count) and exchange them in that single synchronization step, as the sketch below also demonstrates. [The original post illustrates this process with a diagram, omitted here.]
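A minimal single-process sketch of this one-synchronization trick (my addition; the "cards" are simulated as chunks of one tensor, and the all_reduce of a real SyncBN is replaced by a plain sum):

import torch

torch.manual_seed(0)
N_CARDS = 4
# simulate each card's input of shape (batch, C, H, W)
per_card_inputs = [torch.randn(2, 3, 20, 20) for _ in range(N_CARDS)]

# each "card" computes only its local per-channel sum, sum of squares and element count
local_sum = torch.stack([x.sum(dim=[0, 2, 3]) for x in per_card_inputs])           # (N_CARDS, C)
local_sqsum = torch.stack([(x ** 2).sum(dim=[0, 2, 3]) for x in per_card_inputs])  # (N_CARDS, C)
local_count = torch.tensor([x.numel() / x.size(1) for x in per_card_inputs])       # elements per channel

# the single "synchronization": add the three quantities across cards
total_sum, total_sqsum, total_count = local_sum.sum(0), local_sqsum.sum(0), local_count.sum()

# global statistics recovered from one pass
global_mean = total_sum / total_count
global_var = total_sqsum / total_count - global_mean ** 2

# cross-check against computing the statistics on the concatenated input directly
all_inputs = torch.cat(per_card_inputs, dim=0)
assert torch.allclose(global_mean, all_inputs.mean(dim=[0, 2, 3]), atol=1e-5)
assert torch.allclose(global_var, all_inputs.var(dim=[0, 2, 3], unbiased=False), atol=1e-5)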



In the implementation, batchnorm.SyncBatchNorm prepares the arguments according to its hyperparameters and the train/eval state, then calls _functions.SyncBatchNorm, whose interface is def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size). The forward first computes the mean and variance on the local card:

# compute invstd directly, i.e. 1 / sqrt(var + eps)
mean, invstd = torch.batch_norm_stats(input, eps)

It then synchronizes the per-card results to obtain mean_all and invstd_all, computes the global statistics from them, and updates running_mean and running_var:

# compute the global mean and invstd
mean, invstd = torch.batch_norm_gather_stats_with_counts(
    input,
    mean_all,
    invstd_all,
    running_mean,
    running_var,
    momentum,
    eps,
    count_all.view(-1).long().tolist()
)

        3.2 backward

Because the different processes share the same set of BN parameters, communication is needed both before and after the gradients pass through BN. This is implemented in _functions.SyncBatchNorm:

# calculate local stats as well as grad_weight / grad_bias
sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
    grad_output,
    saved_input,
    mean,
    invstd,
    weight,
    self.needs_input_grad[0],
    self.needs_input_grad[1],
    self.needs_input_grad[2]
)

This yields the gradients of weight and bias, as well as sum_dy and sum_dy_xmu, which are needed to compute the gradient with respect to the input x:

# all_reduce the gradient terms across processes
sum_dy_all_reduce = torch.distributed.all_reduce(
    sum_dy, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
sum_dy_xmu_all_reduce = torch.distributed.all_reduce(
    sum_dy_xmu, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
# ...
# average the summed gradients by the total element count
divisor = count_tensor.sum()
mean_dy = sum_dy / divisor
mean_dy_xmu = sum_dy_xmu / divisor
# backward pass for gradient calculation
grad_input = torch.batch_norm_backward_elemt(
    grad_output,
    saved_input,
    mean,
    invstd,
    weight,
    mean_dy,
    mean_dy_xmu
)
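For reference, the elementwise gradient produced by this last step corresponds to the standard BN backward formula (written out here from the usual derivation rather than extracted from the kernel, so treat it as a restatement, not a quote of the implementation):

\frac{\partial L}{\partial x_i} = \gamma \cdot \mathrm{invstd} \cdot \left( \frac{\partial L}{\partial y_i} - \text{mean\_dy} - (x_i - \mu)\,\mathrm{invstd}^2 \cdot \text{mean\_dy\_xmu} \right)

where mean_dy is the global mean of \partial L / \partial y and mean_dy_xmu is the global mean of \partial L / \partial y \cdot (x - \mu), i.e. exactly the two all-reduced quantities above.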

