1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        PyTorch中模型的可復現(xiàn)性

        共 4627字,需瀏覽 10分鐘

         ·

        2021-05-16 07:07

        點擊上方小白學視覺”,選擇加"星標"或“置頂

        重磅干貨,第一時間送達

        本文轉(zhuǎn)自:AI算法與圖像處理

        在深度學習模型的訓練過程中,難免引入隨機因素,這就會對模型的可復現(xiàn)性產(chǎn)生不好的影響。但是對于研究人員來講,模型的可復現(xiàn)性是很重要的。這篇文章收集并總結了可能導致模型難以復現(xiàn)的原因,雖然不可能完全避免隨機因素,但是可以通過一些設置盡可能降低模型的隨機性。

        常規(guī)操作


        PyTorch官方提供了一些關于可復現(xiàn)性的解釋和說明。

        在PyTorch發(fā)行版中,不同的版本或不同的平臺上,不能保證完全可重復的結果。此外,即使在使用相同種子的情況下,結果也不能保證在CPU和GPU上再現(xiàn)。

        但是,為了使計算能夠在一個特定平臺和PyTorch版本上確定特定問題,需要采取幾個步驟。

        PyTorch中涉及兩個偽隨機數(shù)生成器,需要手動對其進行播種以使運行可重復。此外,還應確保代碼所依賴的所有其他庫以及使用隨機數(shù)的庫也使用固定種子。

        常用的固定seed的方法有:

        import torch
        import numpy as np
        import random

        seed=0

        random.seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Remove randomness (may be slower on Tesla GPUs)
        # https://pytorch.org/docs/stable/notes/randomness.html
        if seed == 0:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        API中也揭示了原因,PyTorch使用的CUDA實現(xiàn)中,有一部分是原子操作,尤其是atomicAdd,使用這個操作就代表數(shù)據(jù)不能夠并行處理,需要串行處理,使用到atomicAdd之后就會按照不確定的并行加法順序執(zhí)行,從而引入了不確定因素。PyTorch中使用到的atomicAdd的方法:

        前向傳播時:

        • torch.Tensor.index_add_()
        • torch.Tensor.scatter_add_()
        • torch.bincount()

        反向傳播時:

        • torch.nn.functional.embedding_bag()
        • torch.nn.functional.ctc_loss()
        • 其他pooling,padding, sampling操作

        可以說由于需要并行計算,從而引入atomicAdd之后,必然會引入不確定性,目前沒有一種簡單的方法可以完全避免不確定性。

        upsample層


        upsample導致模型可復現(xiàn)性變差,這一點在PyTorch的官方庫issue#12207中有提到。也有很多熱心的人提供了這個的解決方案:

        import torch.nn as nn
        class UpsampleDeterministic(nn.Module):
        def __init__(self,upscale=2):
        super(UpsampleDeterministic, self).__init__()
        self.upscale = upscale

        def forward(self, x):
        '''
        x: 4-dim tensor. shape is (batch,channel,h,w)
        output: 4-dim tensor. shape is (batch,channel,self.upscale*h,self.upscale*w)
        '''

        return x[:, :, :, None, :, None]\
        .expand(-1, -1, -1, self.upscale, -1, self.upscale)\
        .reshape(x.size(0), x.size(1), x.size(2)\
        *self.upscale, x.size(3)*self.upscale)

        # or
        def upsample_deterministic(x,upscale):
        return x[:, :, :, None, :, None]\
        .expand(-1, -1, -1, upscale, -1, upscale)\
        .reshape(x.size(0), x.size(1), x.size(2)\
        *upscale, x.size(3)*upscale)

        可以將以上模塊替換掉官方的nn.Upsample函數(shù)來避免不確定性。

        Batch Size


        Batch Size這個超參數(shù)很容易被人忽視,很多時候都是看目前剩余的顯存,然后再進行設置合適的Batch Size參數(shù)。模型復現(xiàn)時Batch Size大小是必須相同的。

        Batch Size對模型的影響很大,Batch Size決定了要經(jīng)過多少對數(shù)據(jù)的學習以后,進行一次反向傳播。

        Batch Size過大:

        • 占用顯存過大,在很多情況下很難滿足要求。對內(nèi)存的容量也有更高的要求。
        • 容易陷入局部最小值或者鞍點,模型會在發(fā)生過擬合,在訓練集上表現(xiàn)非常好,但是測試集上表現(xiàn)差。

        Batch Size過?。?/p>

        • 假設bs=1,這就屬于在線學習,每次的修正方向以各自樣本的梯度方向修正,很可能將難以收斂。
        • 訓練時間過長,難以提高資源利用率

        另外,由于CUDA的原因,Batch Size設置為2的冪次的時候速度更快一些。所以嘗試修改Batch Size的時候就按照4,8,16,32,...這樣進行設置。

        數(shù)據(jù)在線增強


        在這里參考的庫是ultralytics的yolov3實現(xiàn),數(shù)據(jù)增強分為在線增強離線增強

        • 在線增強:在獲得 batch 數(shù)據(jù)之后,然后對這個 batch 的數(shù)據(jù)進行增強,如旋轉(zhuǎn)、平移、翻折等相應的變化,由于有些數(shù)據(jù)集不能接受線性級別的增長,這種方法常常用于大的數(shù)據(jù)集。
        • 離線增強:直接對數(shù)據(jù)集進行處理,數(shù)據(jù)的數(shù)目會變成增強因子 x 原數(shù)據(jù)集的數(shù)目 ,這種方法常常用于數(shù)據(jù)集很小的時候。

        在yolov3中使用的就是在線增強,比如其中一部分增強方法:

        if self.augment:
        # 隨機左右翻轉(zhuǎn)
        lr_flip = True
        if lr_flip and random.random() < 0.5:
        img = np.fliplr(img)
        if nL:
        labels[:, 1] = 1 - labels[:, 1]

        # 隨機上下翻轉(zhuǎn)
        ud_flip = False
        if ud_flip and random.random() < 0.5:
        img = np.flipud(img)
        if nL:
        labels[:, 2] = 1 - labels[:, 2]

        可以看到,如果設置了在線增強,那么模型會以一定的概率進行增強,這樣會導致每次運行得到的訓練樣本可能是不一致的,這也就造成了模型的不可復現(xiàn)。為了復現(xiàn),這里暫時將在線增強的功能關掉。

        多線程操作


        FP32(或者FP16 apex)中的隨機性是由多線程引入的,在PyTorch中設置DataLoader中的num_worker參數(shù)為0,或者直接不使用GPU,通過--device cpu指定使用CPU都可以避免程序使用多線程。但是這明顯不是一個很好的解決方案,因為兩種操作都會顯著地影響訓練速度。

        任何多線程操作都可能會引入問題,甚至是對單個向量求和,因為線程求和將導致FP16 / 32的精度損失,從而執(zhí)行的順序和線程數(shù)將對結果產(chǎn)生輕微影響。


        其他
        • 所有模型涉及到的文件中使用到random或者np.random的部分都需要設置seed

        • dropout可能也會帶來隨機性。

        • 多GPU并行訓練會帶來一定程度的隨機性。

        • 可能還有一些其他問題,感興趣的話可以看一下知乎上問題: PyTorch 有哪些坑/bug?


        總結


        上面大概梳理了一下可能導致PyTorch的模型可復現(xiàn)性出現(xiàn)問題的原因??梢钥闯鰜?,有很多問題是難以避免的,比如使用到官方提及的幾個方法、涉及到atomicAdd的操作、多線程操作等等。

        筆者也在yolov3基礎上修改了以上提到的內(nèi)容,固定了seed,batch size,關閉了數(shù)據(jù)增強。在模型運行了10個epoch左右的時候,前后兩次訓練的結果是一模一樣的,但是隨著epoch越來越多,也會產(chǎn)生一定的波動

        總之,應該盡量滿足可復現(xiàn)性的要求,我們可以通過設置固定seed等操作,盡可能保證前后兩次相同實驗得到的結果波動不能太大,不然就很難判斷模型的提升是由于隨機性導致的還是對模型的改進導致的。

        目前筆者進行了多次試驗來研究模型的可復現(xiàn)性,偶爾會出現(xiàn)兩次一模一樣的訓練結果,但是更多實驗中,兩次的訓練結果都是略有不同的,不過通過以上設置,可以讓訓練結果差距在1%以內(nèi)。

        在目前的實驗中還無法達到每次前后兩次完全一樣,如果有讀者有類似的經(jīng)驗,歡迎來交流。


        下載1:OpenCV-Contrib擴展模塊中文版教程
        在「小白學視覺」公眾號后臺回復:擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

        下載2:Python視覺實戰(zhàn)項目52講
        小白學視覺公眾號后臺回復:Python視覺實戰(zhàn)項目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。

        下載3:OpenCV實戰(zhàn)項目20講
        小白學視覺公眾號后臺回復:OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。

        交流群


        歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~


        瀏覽 40
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            亚洲无码中文字幕在线观看 | 日本国产在线视频 | 97资源在线 | 欧美成人在线视频 | 精品一区二区久久久久久久网站 | 91女人18毛片水多的意思 | 深夜办公室老板揉我胸摸下边 | 亚洲AV成人精品毛片 | 观看成人永久免费视频 | 美女扒开胸罩免费视频网 |