>>極市CV俠侶正式出道!請大家前往文末為他們投票打c..." />
    1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        高效 PyTorch:6個(gè)Tips,為訓(xùn)練管道加渦輪增壓

        共 11924字,需瀏覽 24分鐘

         ·

        2020-08-24 22:22

        ↑ 點(diǎn)擊藍(lán)字 關(guān)注極市平臺(tái)

        作者丨McGL@知乎
        來源丨h(huán)ttps://zhuanlan.zhihu.com/p/194303854

        極市導(dǎo)讀

         

        本文為pytorch使用者給出了六條建議,讓訓(xùn)練更快、更穩(wěn)、更強(qiáng)。>>>極市CV俠侶正式出道!請大家前往文末為他們投票打call~

        高效 PyTorch系列第二彈來了,6個(gè)建議,讓你的訓(xùn)練更快、更穩(wěn)、更強(qiáng)。

        Efficient PyTorch — Supercharging Training Pipeline

        作者:Eugene Khvedchenya

        https://medium.com/@eugenekhvedchenya/efficient-pytorch-supercharging-training-pipeline-19a26265adae

        每個(gè)深度學(xué)習(xí)項(xiàng)目的最終目標(biāo)都是為產(chǎn)品帶來價(jià)值。當(dāng)然,我們希望有最好的模型。什么是“最好的”取決于具體的業(yè)務(wù)場景,不在本文討論范圍內(nèi)。我想談?wù)?/span>如何從 train.py 腳本中獲得最大價(jià)值。

        大綱

        • 高級框架代替了自制的訓(xùn)練循環(huán)

        • 使用額外的度量(metrics)監(jiān)控訓(xùn)練的進(jìn)度

        • 使用 TensorBoard

        • 可視化模型的預(yù)測

        • 使用 Dict 作為數(shù)據(jù)集和模型的返回值

        • 檢測異常并解決數(shù)值不穩(wěn)定問題

        免責(zé)聲明: 在下一節(jié)中,我將包括一些源代碼清單。其中大多數(shù)都是為 Catalyst 框架(版本20.08)定制的,并且可以在 pytorch-toolbelt 中使用。

        不要重新發(fā)明輪子

        建議1 — 利用 PyTorch 生態(tài)中的高級訓(xùn)練框架

        從頭開始寫訓(xùn)練循環(huán)的話, PyTorch 提供了極好的靈活性和自由度。理論上,這為編寫任何訓(xùn)練邏輯提供了無限的可能性。實(shí)際上,你很少會(huì)為訓(xùn)練 CycleGAN、蒸餾 BERT 或者實(shí)現(xiàn)3D 目標(biāo)檢測從頭開始編寫新奇的訓(xùn)練循環(huán)。

        從頭開始編寫一個(gè)完整的訓(xùn)練循環(huán)是學(xué)習(xí) PyTorch 基礎(chǔ)知識(shí)的一種極好的方法。然而,我強(qiáng)烈建議一旦掌握了一些知識(shí),就切換到高級框架。有很多選擇: Catalyst,PyTorch-Lightning,F(xiàn)ast.AI,Ignite 等等。高級庫通過以下方式節(jié)省你的時(shí)間:

        • 提供經(jīng)過良好測試的訓(xùn)練循環(huán)
        • 支持配置文件
        • 支持多 GPU 和分布式訓(xùn)練
        • 檢查點(diǎn)/實(shí)驗(yàn)的管理
        • 自動(dòng)記錄訓(xùn)練進(jìn)度

        從這些高級庫中獲得最大價(jià)值需要一些時(shí)間。然而,從長遠(yuǎn)來看,這種一次性投資是值得的。

        優(yōu)點(diǎn)

        • 訓(xùn)練管道更小——代碼更少——出現(xiàn)錯(cuò)誤的可能性更小
        • 實(shí)驗(yàn)管理更容易
        • 簡化分布式及混合精度訓(xùn)練

        缺點(diǎn)

        • 多一個(gè)抽象層——像往常一樣,當(dāng)使用高級框架時(shí),我們必須在特定框架的設(shè)計(jì)原則和范式中編寫代碼
        • 時(shí)間投資——學(xué)習(xí)額外的框架需要時(shí)間

        給我展示度量

        建議2ー在訓(xùn)練過程中查看額外的度量

        幾乎每一個(gè)快速上手的圖像分類示例項(xiàng)目都有一個(gè)共同點(diǎn),那就是它們在訓(xùn)練期間和訓(xùn)練后都報(bào)告了一組最小的度量。大多數(shù)情況下,它是Top-1和Top-5的準(zhǔn)確率,錯(cuò)誤率,訓(xùn)練/驗(yàn)證損失,就這么多。雖然這些度量是必不可少的,但只是冰山一角!

        現(xiàn)代圖像分類模型有數(shù)千萬個(gè)參數(shù)。你想僅使用一個(gè)標(biāo)量值來評估嗎?

        具有最佳 Top-1精度的 CNN 分類模型在泛化方面可能不是最佳分類模型。根據(jù)你的領(lǐng)域和需求,你可能希望保存假陽性/假陰性率最低的模型或平均精度最高的模型。

        讓我給你列舉一些想法,在訓(xùn)練期間你可以記錄哪些數(shù)據(jù):

        • Grad-CAM 熱圖——查看圖像的哪些部分對某一特定類別的貢獻(xiàn)最大

        可視化 Grad-CAM 熱圖有助于確定模型做出預(yù)測是基于真實(shí)病理學(xué)還是基于圖像artifacts

        • 混淆矩陣——向你展示哪一對類別對你的模型來說最具挑戰(zhàn)性

        混淆矩陣揭示了模型對特定類型進(jìn)行錯(cuò)誤分類的頻率(Eugene Khvedchenya,ALASKA2 Image Steganalysis,Kaggle)

        • 預(yù)測的分布——給你關(guān)于最佳決策邊界的洞察

        模型的負(fù)和正預(yù)測的分布情況表明,大部分?jǐn)?shù)據(jù)模型不能有把握地進(jìn)行分類(Eugene Khvedchenya,ALASKA2 Image Steganalysis,Kaggle)

        • 所有層的梯度的最小/平均/最大值——可以確定模型中是否存在消失/爆炸梯度或初始化不佳的層

        使用dashboard工具監(jiān)控訓(xùn)練

        建議3ー使用TensorBoard或任何其他解決方案來監(jiān)測訓(xùn)練的進(jìn)展

        在訓(xùn)練模型時(shí),你最不想做的事情可能就是查看控制臺(tái)輸出。一個(gè)強(qiáng)大的dashboard,你可以一次看到所有的度量,這是檢查訓(xùn)練結(jié)果的一種更有效的方式。

        Tensorboard 可以本地快速檢查和比較你的運(yùn)行

        對于少數(shù)實(shí)驗(yàn)和非分布式環(huán)境,TensorBoard 是一個(gè)黃金標(biāo)準(zhǔn)。從版本1.3開始,PyTorch 就完全支持它,并且提供了一系列豐富的特性來管理實(shí)驗(yàn)。還有更先進(jìn)的基于云計(jì)算的解決方案,比如 Weights&Biases, Alchemy, 和 TensorBoard.dev,這使得在多臺(tái)機(jī)器上監(jiān)視和比較訓(xùn)練會(huì)話變得更加容易。

        當(dāng)使用 Tensorboard 時(shí),我通常會(huì)記錄一組度量:

        • 學(xué)習(xí)率和其他可能會(huì)改變的優(yōu)化器參數(shù)(動(dòng)量,權(quán)重衰減等)
        • 花費(fèi)在數(shù)據(jù)預(yù)處理和模型內(nèi)部的時(shí)間
        • 訓(xùn)練和驗(yàn)證的損失(每個(gè)批次和每個(gè)epoch平均)
        • 跨訓(xùn)練和驗(yàn)證的度量標(biāo)準(zhǔn)
        • 最終度量值訓(xùn)練會(huì)話的超參數(shù)
        • 混淆矩陣,精度-召回曲線,AUC (如果適用)
        • 模型預(yù)測的可視化(如果適用)

        一圖勝千言

        看到模型的預(yù)測是非常重要的。有時(shí)候訓(xùn)練數(shù)據(jù)是有噪聲的; 有時(shí)候,模型過擬合圖像的artifacts。通過可視化最好和最差的批次(基于損失或你感興趣的度量) ,你可以獲得有價(jià)值的洞察,了解你的模型在哪些情況下表現(xiàn)得好,哪些情況下表現(xiàn)得差。

        建議4ー把每個(gè)epoch最好和最差的批次可視化,它可以給你無價(jià)的洞察力

        給 Catalyst 用戶的Tip: 使用可視化回調(diào)的例子在這里: https://github.com/bloodaxe/Catalyst-inria-segmentation-Example/blob/master/fit_predict.py#l258

        例如,在全球小麥檢測挑戰(zhàn)中,我們需要檢測圖像上的小麥穗。通過可視化最佳批次的圖片(基于 mAP 度量) ,我們看到該模型在尋找小目標(biāo)方面近乎完美。

        最佳模型預(yù)測的可視化顯示模型在小目標(biāo)上表現(xiàn)良好(Eugene Khvedchenya,Global Wheat Detection,Kaggle)

        相比之下,當(dāng)我們看到最糟糕的一批的第一個(gè)樣本時(shí),我們看到這個(gè)模型很難對大型目標(biāo)做出準(zhǔn)確的預(yù)測。視覺分析為任何數(shù)據(jù)科學(xué)家提供了無價(jià)的洞察力。

        可視化最差的模型預(yù)測揭示了模型在大目標(biāo)上表現(xiàn)不佳(Eugene Khvedchenya,Global Wheat Detection,Kaggle)

        查看最差的批次也有助于發(fā)現(xiàn)數(shù)據(jù)標(biāo)簽中的錯(cuò)誤。通常情況下,有錯(cuò)誤標(biāo)簽的樣本有較大的損失,因此會(huì)出現(xiàn)在最壞的批次。通過在每個(gè)epoch對最差的批次進(jìn)行視覺檢查,你可以消除這些錯(cuò)誤:

        標(biāo)記錯(cuò)誤的例子。綠色像素表示真陽性,紅色像素表示假陰性。在這個(gè)示例中,地面ground-truth掩碼在該位置具有一個(gè)建筑足跡,而實(shí)際上在該位置沒有建筑足跡。(Eugene Khvedchenya,Inria 航空圖像標(biāo)記數(shù)據(jù)集)

        使用 Dict 作為數(shù)據(jù)集和模型的返回值

        建議5ー如果你的模型返回一個(gè)以上的值ー使用 Dict 返回結(jié)果。不要使用 tuple。

        在復(fù)雜模型中,返回多個(gè)輸出并不罕見。例如,目標(biāo)檢測模型通常返回邊界框和它們的標(biāo)簽,在圖像分割 CNN 中,我們經(jīng)常返回中間的mask用于深度監(jiān)督,多任務(wù)學(xué)習(xí)現(xiàn)在也很流行。

        在很多開源實(shí)現(xiàn)中,我經(jīng)常看到這樣的東西:

        # Bad practice, don't return tupleclass RetinaNet(nn.Module):  ...
        def forward(self, image): x = self.encoder(image) x = self.decoder(x) bboxes, scores = self.head(x) return bboxes, scores
        ...

        出于對作者的尊重,我認(rèn)為這是一個(gè)糟糕的、非常糟糕的從模型返回結(jié)果的方法。以下是我推薦的替代方法:

        class RetinaNet(nn.Module):  RETINA_NET_OUTPUT_BBOXES = "bboxes"  RETINA_NET_OUTPUT_SCORES = "scores"
        ...
        def forward(self, image): x = self.encoder(image) x = self.decoder(x) bboxes, scores = self.head(x) return { RETINA_NET_OUTPUT_BBOXES: bboxes, RETINA_NET_OUTPUT_SCORES: scores }
        ...

        這個(gè)建議在某種程度上與《 Python 之禪》(The Zen of Python)中的假設(shè)產(chǎn)生了共鳴——“明確的比隱含的好”。遵循這一規(guī)則將使你的代碼更加清晰和易于維護(hù)。

        那么,為什么我認(rèn)為第二種選擇更好呢? 原因如下:

        • 返回值有一個(gè)與之關(guān)聯(lián)的顯式名稱。你不需要記住元組中元素的確切順序
        • 如果需要訪問返回字典的特定元素,可以通過它的名稱來訪問
        • 從模型中添加新的輸出不會(huì)破壞代碼

        使用 Dict,您甚至可以改變模型的行為,以根據(jù)需要返回額外的輸出。例如,這里有一個(gè)簡短的代碼片段,演示了如何返回多個(gè)“ main”輸出和兩個(gè)用于度量學(xué)習(xí)的“輔助”輸出:

        # https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/models/timm.py#L104
        def forward(self, **kwargs): x = kwargs[self.input_key] x = self.rgb_bn(x) x = self.encoder.forward_features(x) embedding = self.pool(x) result = { OUTPUT_PRED_MODIFICATION_FLAG: self.flag_classifier(self.drop(embedding)), OUTPUT_PRED_MODIFICATION_TYPE: self.type_classifier(self.drop(embedding)), } if self.need_embedding: result[OUTPUT_PRED_EMBEDDING] = embedding if self.arc_margin is not None: result[OUTPUT_PRED_EMBEDDING_ARC_MARGIN] = self.arc_margin(embedding)
        return result

        同樣的建議也適用于 Dataset 類。對于 Cifar-10玩具示例,可以將圖像及其對應(yīng)的標(biāo)簽返回為 tuple。但是在處理多任務(wù)或多輸入模型時(shí),你希望以 Dict 類型返回?cái)?shù)據(jù)集中的樣本:

        # https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/dataset.py#L373class TrainingValidationDataset(Dataset):    def __init__(        self,        images: Union[List, np.ndarray],        targets: Optional[Union[List, np.ndarray]],        quality: Union[List, np.ndarray],        bits: Optional[Union[List, np.ndarray]],        transform: Union[A.Compose, A.BasicTransform],        features: List[str],    ):        """        :param obliterate - Augmentation that destroys embedding.        """        if targets is not None:            if len(images) != len(targets):                raise ValueError(f"Size of images and targets does not match: {len(images)} {len(targets)}")
        self.images = images self.targets = targets self.transform = transform self.features = features self.quality = quality self.bits = bits
        def __len__(self): return len(self.images)
        def __repr__(self): return f"TrainingValidationDataset(len={len(self)}, targets_hist={np.bincount(self.targets)}, qf={np.bincount(self.quality)}, features={self.features})"
        def __getitem__(self, index): image_fname = self.images[index] try: image = cv2.imread(image_fname) if image is None: raise FileNotFoundError(image_fname) except Exception as e: print("Cannot read image ", image_fname, "at index", index) print(e)
        qf = self.quality[index] data = {} data["image"] = image data.update(compute_features(image, image_fname, self.features))
        data = self.transform(**data)
        sample = {INPUT_IMAGE_ID_KEY: os.path.basename(self.images[index]), INPUT_IMAGE_QF_KEY: int(qf)}
        if self.bits is not None: # OK sample[INPUT_TRUE_PAYLOAD_BITS] = torch.tensor(self.bits[index], dtype=torch.float32)
        if self.targets is not None: target = int(self.targets[index]) sample[INPUT_TRUE_MODIFICATION_TYPE] = target sample[INPUT_TRUE_MODIFICATION_FLAG] = torch.tensor([target > 0]).float()
        for key, value in data.items(): if key in self.features: sample[key] = tensor_from_rgb_image(value)
        return sample

        當(dāng)你的代碼中有字典時(shí),你可以到處使用名字常量引用輸入/輸出。遵循這條規(guī)則將使你的訓(xùn)練流程非常清晰和易讀:

        # https://github.com/BloodAxe/Kaggle-2020-Alaska2
        callbacks += [ CriterionCallback( input_key=INPUT_TRUE_MODIFICATION_FLAG, output_key=OUTPUT_PRED_MODIFICATION_FLAG, criterion_key="bce" ), CriterionCallback( input_key=INPUT_TRUE_MODIFICATION_TYPE, output_key=OUTPUT_PRED_MODIFICATION_TYPE, criterion_key="ce" ), CompetitionMetricCallback( input_key=INPUT_TRUE_MODIFICATION_FLAG, output_key=OUTPUT_PRED_MODIFICATION_FLAG, prefix="auc", output_activation=binary_logits_to_probas, class_names=class_names, ), OutputDistributionCallback( input_key=INPUT_TRUE_MODIFICATION_FLAG, output_key=OUTPUT_PRED_MODIFICATION_FLAG, output_activation=binary_logits_to_probas, prefix="distribution/binary", ), BestMetricCheckpointCallback( target_metric="auc", target_metric_minimize=False, save_n_best=3),]

        檢測訓(xùn)練中的異常

        建議6ー在訓(xùn)練過程中使用torch.autograd.detect_anomaly()來發(fā)現(xiàn)算術(shù)異常。

        如果你在訓(xùn)練期間看到任何的 NaNs 或 Inf 的損失/度量,一個(gè)警報(bào)應(yīng)該在你的頭腦中響起。這是一個(gè)指示器,說明你的管道出了問題。通常,它可能是由以下原因引起的:

        模型或特定層的初始化不好(你可以通過查看梯度大小來檢查是哪些層)
        錯(cuò)誤的數(shù)學(xué)運(yùn)算 (torch.sqrt() 應(yīng)用在負(fù)數(shù)上, torch.log() 非正等等)
        Improper use of torch.mean()torch.sum() reduction 的錯(cuò)誤使用(零大小張量上的均值會(huì)導(dǎo)致nan,大張量上的和容易導(dǎo)致溢出)
        損失使用 x.sigmoid() 不謹(jǐn)慎 (如果你損失函數(shù)需要計(jì)算概率,一個(gè)更好的方法是x.sigmoid().clamp(eps,1-epstorch.logsigmoid(x).exp() ,可避免梯度消失)
        類Adam 優(yōu)化器中的低 epsilon 值
        fp16 使用 fp16進(jìn)行訓(xùn)練時(shí)不使用動(dòng)態(tài)損失縮放

        為了查找代碼中 Nan/Inf 第一次出現(xiàn)的確切位置,PyTorch 提供了一個(gè)易于使用的方法 torch.autograd.detect _ anomaly () :
        僅用于調(diào)試目的,平時(shí)要禁用它,因?yàn)楫惓z測會(huì)帶來額外的計(jì)算開銷,訓(xùn)練循環(huán)會(huì)變慢10-15% 左右。

        結(jié)語

        謝謝閱讀!我希望你喜歡它,并從中發(fā)現(xiàn)了一些可以用得上的東西。你想分享什么tips和tricks嗎?請?jiān)谠u論中寫下你的知識(shí),或者讓我知道大家對哪些 PyTorch 相關(guān)的話題感興趣~


        推薦閱讀



        極市七夕"CV俠侶"征稿活動(dòng)

        極市征集了大家關(guān)于陪伴的故事

        投票通道現(xiàn)已開啟

        快來為你喜愛的TA加油吧!

        極市平臺(tái)公眾號(hào)回復(fù)七夕”即可獲取投票鏈接

        每人每天有3次投票機(jī)會(huì)哦~


        目前,活動(dòng)還在進(jìn)行中

        大家可添加極小東微信(ID:cvmart3)投稿~

        △ 掃碼添加極小東微信

        添加極市小助手微信(ID : cvmart2),備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測-深圳),即可申請加入極市目標(biāo)檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群:月大咖直播分享、真實(shí)項(xiàng)目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~

        △長按添加極市小助手

        △長按關(guān)注極市平臺(tái),獲取最新CV干貨

        覺得有用麻煩給個(gè)在看啦~  
        瀏覽 50
        點(diǎn)贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評論
        圖片
        表情
        推薦
        點(diǎn)贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            午夜成人免费福利视频 | 夜夜春精品AAAAXXXX3D | 秋霞在线观看视频 | 日本三级日产三级国产三级 | 国产精品久久久久久久久久久易记 | 国产一精品一aⅴ一免费 | 麻豆最猛性XXxXXx交 | 中文字幕综合网 | 天堂在线观看视频 | 欧美精品无码久久久精品酒店 |