高效 PyTorch:6個(gè)Tips,為訓(xùn)練管道加渦輪增壓
極市導(dǎo)讀
本文為pytorch使用者給出了六條建議,讓訓(xùn)練更快、更穩(wěn)、更強(qiáng)。>>>極市CV俠侶正式出道!請大家前往文末為他們投票打call~
Efficient PyTorch — Supercharging Training Pipeline
作者:Eugene Khvedchenya
https://medium.com/@eugenekhvedchenya/efficient-pytorch-supercharging-training-pipeline-19a26265adae
每個(gè)深度學(xué)習(xí)項(xiàng)目的最終目標(biāo)都是為產(chǎn)品帶來價(jià)值。當(dāng)然,我們希望有最好的模型。什么是“最好的”取決于具體的業(yè)務(wù)場景,不在本文討論范圍內(nèi)。我想談?wù)?/span>如何從 train.py 腳本中獲得最大價(jià)值。
大綱
高級框架代替了自制的訓(xùn)練循環(huán)
使用額外的度量(metrics)監(jiān)控訓(xùn)練的進(jìn)度
使用 TensorBoard
可視化模型的預(yù)測
使用 Dict 作為數(shù)據(jù)集和模型的返回值
檢測異常并解決數(shù)值不穩(wěn)定問題
免責(zé)聲明: 在下一節(jié)中,我將包括一些源代碼清單。其中大多數(shù)都是為 Catalyst 框架(版本20.08)定制的,并且可以在 pytorch-toolbelt 中使用。
不要重新發(fā)明輪子
建議1 — 利用 PyTorch 生態(tài)中的高級訓(xùn)練框架
從頭開始寫訓(xùn)練循環(huán)的話, PyTorch 提供了極好的靈活性和自由度。理論上,這為編寫任何訓(xùn)練邏輯提供了無限的可能性。實(shí)際上,你很少會(huì)為訓(xùn)練 CycleGAN、蒸餾 BERT 或者實(shí)現(xiàn)3D 目標(biāo)檢測從頭開始編寫新奇的訓(xùn)練循環(huán)。
從頭開始編寫一個(gè)完整的訓(xùn)練循環(huán)是學(xué)習(xí) PyTorch 基礎(chǔ)知識(shí)的一種極好的方法。然而,我強(qiáng)烈建議一旦掌握了一些知識(shí),就切換到高級框架。有很多選擇: Catalyst,PyTorch-Lightning,F(xiàn)ast.AI,Ignite 等等。高級庫通過以下方式節(jié)省你的時(shí)間:
-
提供經(jīng)過良好測試的訓(xùn)練循環(huán) -
支持配置文件 -
支持多 GPU 和分布式訓(xùn)練 -
檢查點(diǎn)/實(shí)驗(yàn)的管理 -
自動(dòng)記錄訓(xùn)練進(jìn)度
從這些高級庫中獲得最大價(jià)值需要一些時(shí)間。然而,從長遠(yuǎn)來看,這種一次性投資是值得的。
優(yōu)點(diǎn)
-
訓(xùn)練管道更小——代碼更少——出現(xiàn)錯(cuò)誤的可能性更小 -
實(shí)驗(yàn)管理更容易 -
簡化分布式及混合精度訓(xùn)練
缺點(diǎn)
-
多一個(gè)抽象層——像往常一樣,當(dāng)使用高級框架時(shí),我們必須在特定框架的設(shè)計(jì)原則和范式中編寫代碼 -
時(shí)間投資——學(xué)習(xí)額外的框架需要時(shí)間
給我展示度量
建議2ー在訓(xùn)練過程中查看額外的度量
幾乎每一個(gè)快速上手的圖像分類示例項(xiàng)目都有一個(gè)共同點(diǎn),那就是它們在訓(xùn)練期間和訓(xùn)練后都報(bào)告了一組最小的度量。大多數(shù)情況下,它是Top-1和Top-5的準(zhǔn)確率,錯(cuò)誤率,訓(xùn)練/驗(yàn)證損失,就這么多。雖然這些度量是必不可少的,但只是冰山一角!
現(xiàn)代圖像分類模型有數(shù)千萬個(gè)參數(shù)。你想僅使用一個(gè)標(biāo)量值來評估嗎?
具有最佳 Top-1精度的 CNN 分類模型在泛化方面可能不是最佳分類模型。根據(jù)你的領(lǐng)域和需求,你可能希望保存假陽性/假陰性率最低的模型或平均精度最高的模型。
讓我給你列舉一些想法,在訓(xùn)練期間你可以記錄哪些數(shù)據(jù):
-
Grad-CAM 熱圖——查看圖像的哪些部分對某一特定類別的貢獻(xiàn)最大
可視化 Grad-CAM 熱圖有助于確定模型做出預(yù)測是基于真實(shí)病理學(xué)還是基于圖像artifacts
-
混淆矩陣——向你展示哪一對類別對你的模型來說最具挑戰(zhàn)性
混淆矩陣揭示了模型對特定類型進(jìn)行錯(cuò)誤分類的頻率(Eugene Khvedchenya,ALASKA2 Image Steganalysis,Kaggle)
-
預(yù)測的分布——給你關(guān)于最佳決策邊界的洞察
模型的負(fù)和正預(yù)測的分布情況表明,大部分?jǐn)?shù)據(jù)模型不能有把握地進(jìn)行分類(Eugene Khvedchenya,ALASKA2 Image Steganalysis,Kaggle)
-
所有層的梯度的最小/平均/最大值——可以確定模型中是否存在消失/爆炸梯度或初始化不佳的層
使用dashboard工具監(jiān)控訓(xùn)練
建議3ー使用TensorBoard或任何其他解決方案來監(jiān)測訓(xùn)練的進(jìn)展
在訓(xùn)練模型時(shí),你最不想做的事情可能就是查看控制臺(tái)輸出。一個(gè)強(qiáng)大的dashboard,你可以一次看到所有的度量,這是檢查訓(xùn)練結(jié)果的一種更有效的方式。
Tensorboard 可以本地快速檢查和比較你的運(yùn)行
對于少數(shù)實(shí)驗(yàn)和非分布式環(huán)境,TensorBoard 是一個(gè)黃金標(biāo)準(zhǔn)。從版本1.3開始,PyTorch 就完全支持它,并且提供了一系列豐富的特性來管理實(shí)驗(yàn)。還有更先進(jìn)的基于云計(jì)算的解決方案,比如 Weights&Biases, Alchemy, 和 TensorBoard.dev,這使得在多臺(tái)機(jī)器上監(jiān)視和比較訓(xùn)練會(huì)話變得更加容易。
當(dāng)使用 Tensorboard 時(shí),我通常會(huì)記錄一組度量:
-
學(xué)習(xí)率和其他可能會(huì)改變的優(yōu)化器參數(shù)(動(dòng)量,權(quán)重衰減等) -
花費(fèi)在數(shù)據(jù)預(yù)處理和模型內(nèi)部的時(shí)間 -
訓(xùn)練和驗(yàn)證的損失(每個(gè)批次和每個(gè)epoch平均) -
跨訓(xùn)練和驗(yàn)證的度量標(biāo)準(zhǔn) -
最終度量值訓(xùn)練會(huì)話的超參數(shù) -
混淆矩陣,精度-召回曲線,AUC (如果適用) -
模型預(yù)測的可視化(如果適用)
一圖勝千言
看到模型的預(yù)測是非常重要的。有時(shí)候訓(xùn)練數(shù)據(jù)是有噪聲的; 有時(shí)候,模型過擬合圖像的artifacts。通過可視化最好和最差的批次(基于損失或你感興趣的度量) ,你可以獲得有價(jià)值的洞察,了解你的模型在哪些情況下表現(xiàn)得好,哪些情況下表現(xiàn)得差。
建議4ー把每個(gè)epoch最好和最差的批次可視化,它可以給你無價(jià)的洞察力
給 Catalyst 用戶的Tip: 使用可視化回調(diào)的例子在這里: https://github.com/bloodaxe/Catalyst-inria-segmentation-Example/blob/master/fit_predict.py#l258
例如,在全球小麥檢測挑戰(zhàn)中,我們需要檢測圖像上的小麥穗。通過可視化最佳批次的圖片(基于 mAP 度量) ,我們看到該模型在尋找小目標(biāo)方面近乎完美。
最佳模型預(yù)測的可視化顯示模型在小目標(biāo)上表現(xiàn)良好(Eugene Khvedchenya,Global Wheat Detection,Kaggle)
相比之下,當(dāng)我們看到最糟糕的一批的第一個(gè)樣本時(shí),我們看到這個(gè)模型很難對大型目標(biāo)做出準(zhǔn)確的預(yù)測。視覺分析為任何數(shù)據(jù)科學(xué)家提供了無價(jià)的洞察力。
可視化最差的模型預(yù)測揭示了模型在大目標(biāo)上表現(xiàn)不佳(Eugene Khvedchenya,Global Wheat Detection,Kaggle)
查看最差的批次也有助于發(fā)現(xiàn)數(shù)據(jù)標(biāo)簽中的錯(cuò)誤。通常情況下,有錯(cuò)誤標(biāo)簽的樣本有較大的損失,因此會(huì)出現(xiàn)在最壞的批次。通過在每個(gè)epoch對最差的批次進(jìn)行視覺檢查,你可以消除這些錯(cuò)誤:
標(biāo)記錯(cuò)誤的例子。綠色像素表示真陽性,紅色像素表示假陰性。在這個(gè)示例中,地面ground-truth掩碼在該位置具有一個(gè)建筑足跡,而實(shí)際上在該位置沒有建筑足跡。(Eugene Khvedchenya,Inria 航空圖像標(biāo)記數(shù)據(jù)集)
使用 Dict 作為數(shù)據(jù)集和模型的返回值
建議5ー如果你的模型返回一個(gè)以上的值ー使用 Dict 返回結(jié)果。不要使用 tuple。
在復(fù)雜模型中,返回多個(gè)輸出并不罕見。例如,目標(biāo)檢測模型通常返回邊界框和它們的標(biāo)簽,在圖像分割 CNN 中,我們經(jīng)常返回中間的mask用于深度監(jiān)督,多任務(wù)學(xué)習(xí)現(xiàn)在也很流行。
在很多開源實(shí)現(xiàn)中,我經(jīng)常看到這樣的東西:
# Bad practice, don't return tupleclass RetinaNet(nn.Module):...def forward(self, image):x = self.encoder(image)x = self.decoder(x)bboxes, scores = self.head(x)return bboxes, scores...
出于對作者的尊重,我認(rèn)為這是一個(gè)糟糕的、非常糟糕的從模型返回結(jié)果的方法。以下是我推薦的替代方法:
class RetinaNet(nn.Module):RETINA_NET_OUTPUT_BBOXES = "bboxes"RETINA_NET_OUTPUT_SCORES = "scores"...def forward(self, image):x = self.encoder(image)x = self.decoder(x)bboxes, scores = self.head(x)return { RETINA_NET_OUTPUT_BBOXES: bboxes,RETINA_NET_OUTPUT_SCORES: scores }...
這個(gè)建議在某種程度上與《 Python 之禪》(The Zen of Python)中的假設(shè)產(chǎn)生了共鳴——“明確的比隱含的好”。遵循這一規(guī)則將使你的代碼更加清晰和易于維護(hù)。
那么,為什么我認(rèn)為第二種選擇更好呢? 原因如下:
-
返回值有一個(gè)與之關(guān)聯(lián)的顯式名稱。你不需要記住元組中元素的確切順序 -
如果需要訪問返回字典的特定元素,可以通過它的名稱來訪問 -
從模型中添加新的輸出不會(huì)破壞代碼
使用 Dict,您甚至可以改變模型的行為,以根據(jù)需要返回額外的輸出。例如,這里有一個(gè)簡短的代碼片段,演示了如何返回多個(gè)“ main”輸出和兩個(gè)用于度量學(xué)習(xí)的“輔助”輸出:
# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/models/timm.py#L104def forward(self, **kwargs):x = kwargs[self.input_key]x = self.rgb_bn(x)x = self.encoder.forward_features(x)embedding = self.pool(x)result = {OUTPUT_PRED_MODIFICATION_FLAG: self.flag_classifier(self.drop(embedding)),OUTPUT_PRED_MODIFICATION_TYPE: self.type_classifier(self.drop(embedding)),}if self.need_embedding:result[OUTPUT_PRED_EMBEDDING] = embeddingif self.arc_margin is not None:result[OUTPUT_PRED_EMBEDDING_ARC_MARGIN] = self.arc_margin(embedding)return result
同樣的建議也適用于 Dataset 類。對于 Cifar-10玩具示例,可以將圖像及其對應(yīng)的標(biāo)簽返回為 tuple。但是在處理多任務(wù)或多輸入模型時(shí),你希望以 Dict 類型返回?cái)?shù)據(jù)集中的樣本:
# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/dataset.py#L373class TrainingValidationDataset(Dataset):def __init__(self,images: Union[List, np.ndarray],targets: Optional[Union[List, np.ndarray]],quality: Union[List, np.ndarray],bits: Optional[Union[List, np.ndarray]],transform: Union[A.Compose, A.BasicTransform],features: List[str],):""":param obliterate - Augmentation that destroys embedding."""if targets is not None:if len(images) != len(targets):raise ValueError(f"Size of images and targets does not match: {len(images)} {len(targets)}")self.images = imagesself.targets = targetsself.transform = transformself.features = featuresself.quality = qualityself.bits = bitsdef __len__(self):return len(self.images)def __repr__(self):return f"TrainingValidationDataset(len={len(self)}, targets_hist={np.bincount(self.targets)}, qf={np.bincount(self.quality)}, features={self.features})"def __getitem__(self, index):image_fname = self.images[index]try:image = cv2.imread(image_fname)if image is None:raise FileNotFoundError(image_fname)except Exception as e:print("Cannot read image ", image_fname, "at index", index)print(e)qf = self.quality[index]data = {}data["image"] = imagedata.update(compute_features(image, image_fname, self.features))data = self.transform(**data)sample = {INPUT_IMAGE_ID_KEY: os.path.basename(self.images[index]), INPUT_IMAGE_QF_KEY: int(qf)}if self.bits is not None:# OKsample[INPUT_TRUE_PAYLOAD_BITS] = torch.tensor(self.bits[index], dtype=torch.float32)if self.targets is not None:target = int(self.targets[index])sample[INPUT_TRUE_MODIFICATION_TYPE] = targetsample[INPUT_TRUE_MODIFICATION_FLAG] = torch.tensor([target > 0]).float()for key, value in data.items():if key in self.features:sample[key] = tensor_from_rgb_image(value)return sample
當(dāng)你的代碼中有字典時(shí),你可以到處使用名字常量引用輸入/輸出。遵循這條規(guī)則將使你的訓(xùn)練流程非常清晰和易讀:
# https://github.com/BloodAxe/Kaggle-2020-Alaska2callbacks += [CriterionCallback(input_key=INPUT_TRUE_MODIFICATION_FLAG,output_key=OUTPUT_PRED_MODIFICATION_FLAG,criterion_key="bce"),CriterionCallback(input_key=INPUT_TRUE_MODIFICATION_TYPE,output_key=OUTPUT_PRED_MODIFICATION_TYPE,criterion_key="ce"),CompetitionMetricCallback(input_key=INPUT_TRUE_MODIFICATION_FLAG,output_key=OUTPUT_PRED_MODIFICATION_FLAG,prefix="auc",output_activation=binary_logits_to_probas,class_names=class_names,),OutputDistributionCallback(input_key=INPUT_TRUE_MODIFICATION_FLAG,output_key=OUTPUT_PRED_MODIFICATION_FLAG,output_activation=binary_logits_to_probas,prefix="distribution/binary",),BestMetricCheckpointCallback(target_metric="auc",target_metric_minimize=False,save_n_best=3),]
檢測訓(xùn)練中的異常
建議6ー在訓(xùn)練過程中使用torch.autograd.detect_anomaly()來發(fā)現(xiàn)算術(shù)異常。
如果你在訓(xùn)練期間看到任何的 NaNs 或 Inf 的損失/度量,一個(gè)警報(bào)應(yīng)該在你的頭腦中響起。這是一個(gè)指示器,說明你的管道出了問題。通常,它可能是由以下原因引起的:
torch.sqrt() 應(yīng)用在負(fù)數(shù)上, torch.log() 非正等等)
torch.mean() 和 torch.sum() reduction 的錯(cuò)誤使用(零大小張量上的均值會(huì)導(dǎo)致nan,大張量上的和容易導(dǎo)致溢出)
x.sigmoid() 不謹(jǐn)慎 (如果你損失函數(shù)需要計(jì)算概率,一個(gè)更好的方法是x.sigmoid().clamp(eps,1-eps 或 torch.logsigmoid(x).exp() ,可避免梯度消失)
為了查找代碼中 Nan/Inf 第一次出現(xiàn)的確切位置,PyTorch 提供了一個(gè)易于使用的方法 torch.autograd.detect _ anomaly () :
僅用于調(diào)試目的,平時(shí)要禁用它,因?yàn)楫惓z測會(huì)帶來額外的計(jì)算開銷,訓(xùn)練循環(huán)會(huì)變慢10-15% 左右。
結(jié)語
謝謝閱讀!我希望你喜歡它,并從中發(fā)現(xiàn)了一些可以用得上的東西。你想分享什么tips和tricks嗎?請?jiān)谠u論中寫下你的知識(shí),或者讓我知道大家對哪些 PyTorch 相關(guān)的話題感興趣~
推薦閱讀
極市征集了大家關(guān)于陪伴的故事
投票通道現(xiàn)已開啟
快來為你喜愛的TA加油吧!
極市平臺(tái)公眾號(hào)回復(fù)“七夕”即可獲取投票鏈接
每人每天有3次投票機(jī)會(huì)哦~
目前,活動(dòng)還在進(jìn)行中
大家可添加極小東微信(ID:cvmart3)投稿~
△ 掃碼添加極小東微信
