實(shí)操教程|我的PyTorch模型比內(nèi)存還大,怎么訓(xùn)練呀?
極市導(dǎo)讀
本文介紹了一種技術(shù):梯度檢查點(diǎn)。通過從計(jì)算圖中省略一些激活值,減少了計(jì)算圖使用的內(nèi)存,降低了總體內(nèi)存壓力。 >>公眾號(hào)后臺(tái)回復(fù)“79”或者“陳鑫”獲得CVPR 2021:TransT 直播鏈接
隨著深度學(xué)習(xí)的飛速發(fā)展,模型越來越臃腫,哦不,先進(jìn),運(yùn)行SOTA模型的主要困難之一就是怎么把它塞到 GPU 上,畢竟,你無法訓(xùn)練一個(gè)設(shè)備裝不下的模型。改善這個(gè)問題的技術(shù)有很多種,例如,分布式訓(xùn)練和混合精度訓(xùn)練。
本文將介紹另一種技術(shù): 梯度檢查點(diǎn)(gradient checkpointing)。簡(jiǎn)單的說,梯度檢查點(diǎn)的工作原理是在反向時(shí)重新計(jì)算深層神經(jīng)網(wǎng)絡(luò)的中間值(而通常情況是在前向時(shí)存儲(chǔ)的)。這個(gè)策略是用時(shí)間(重新計(jì)算這些值兩次的時(shí)間成本)來換空間(提前存儲(chǔ)這些值的內(nèi)存成本)。
文末有一個(gè)示例基準(zhǔn)測(cè)試,它顯示了梯度檢查點(diǎn)減少了模型 60% 的內(nèi)存開銷(以增加 25% 的訓(xùn)練時(shí)間為代價(jià))。
詳細(xì)代碼請(qǐng)查看我的 GitHub 庫: https://github.com/spellml/tweet-sentiment-extraction/blob/master/notebooks/5-checkpointing.ipynb
>>> 神經(jīng)網(wǎng)絡(luò)如何使用內(nèi)存
為了理解梯度檢查點(diǎn)是如何起作用的,我們首先需要了解一下模型內(nèi)存分配是如何工作的。
神經(jīng)網(wǎng)絡(luò)使用的總內(nèi)存基本上是兩個(gè)部分的和。
第一部分是模型使用的靜態(tài)內(nèi)存。盡管 PyTorch 模型中內(nèi)置了一些固定開銷,但總的來說幾乎完全由模型權(quán)重決定。當(dāng)今生產(chǎn)中使用的現(xiàn)代深度學(xué)習(xí)模型的總參數(shù)在100萬到10億之間。作為參考,一個(gè)帶 16GB GPU 內(nèi)存的 NVIDIA T4 的實(shí)際限制大約在1-1.5億個(gè)參數(shù)之間。
第二部分是模型的計(jì)算圖所占用的動(dòng)態(tài)內(nèi)存。在訓(xùn)練模式下,每次通過神經(jīng)網(wǎng)絡(luò)的前向傳播都為網(wǎng)絡(luò)中的每個(gè)神經(jīng)元計(jì)算一個(gè)激活值,這個(gè)值隨后被存儲(chǔ)在所謂的計(jì)算圖中。必須為批中的每個(gè)單個(gè)訓(xùn)練樣本存儲(chǔ)一個(gè)值,因此數(shù)量會(huì)迅速的累積起來??傞_銷由模型大小和批次大小決定,一般設(shè)置最大批次大小限制來適配你的 GPU 內(nèi)存。
要了解更多關(guān)于 PyTorch autograd 的信息,請(qǐng)查看我的 Kaggle 筆記本《PyTorch autograd 解釋》: https://www.kaggle.com/residentmario/pytorch-autograd-explained
>>> 梯度檢查點(diǎn)是如何起作用的
大型模型在靜態(tài)和動(dòng)態(tài)方面都很耗資源。首先,它們很難適配 GPU,而且哪怕你把它們放到了設(shè)備上,也很難訓(xùn)練,因?yàn)榕未笮”黄认拗频奶《鵁o法收斂。
現(xiàn)有的各種技術(shù)可以改善這些問題中的一個(gè)或兩個(gè)。梯度檢查點(diǎn)就是這樣一種技術(shù); 分布式訓(xùn)練,是另一種技術(shù)。
梯度檢查點(diǎn)(gradient checkpointing) 的工作原理是從計(jì)算圖中省略一些激活值。這減少了計(jì)算圖使用的內(nèi)存,降低了總體內(nèi)存壓力(并允許在處理過程中使用更大的批次大?。?。
但是,一開始存儲(chǔ)激活的原因是,在反向傳播期間計(jì)算梯度時(shí)需要用到激活。在計(jì)算圖中忽略它們將迫使 PyTorch 在任何出現(xiàn)這些值的地方重新計(jì)算,從而降低了整體計(jì)算速度。
因此,梯度檢查點(diǎn)是計(jì)算機(jī)科學(xué)中折衷的一個(gè)經(jīng)典例子,即在內(nèi)存和計(jì)算之間的權(quán)衡。
PyTorch 通過 torch.utils.checkpoint.checkpoint 和 torch.utils.checkpoint.checkpoint_sequential 提供梯度檢查點(diǎn),根據(jù)官方文檔的 notes,它實(shí)現(xiàn)了如下功能,在前向傳播時(shí),PyTorch 將保存模型中的每個(gè)函數(shù)的輸入元組。在反向傳播過程中,對(duì)于每個(gè)函數(shù),輸入元組和函數(shù)的組合以實(shí)時(shí)的方式重新計(jì)算,插入到每個(gè)需要它的函數(shù)的梯度公式中,然后丟棄。網(wǎng)絡(luò)計(jì)算開銷大致相當(dāng)于每個(gè)樣本通過模型前向傳播開銷的兩倍。
梯度檢查點(diǎn)首次發(fā)表在2016年的論文 《Training Deep Nets With Sublinear Memory Cost》 中。論文聲稱提出的梯度檢查點(diǎn)算法將模型的動(dòng)態(tài)內(nèi)存開銷從 O(n)(n 為模型中的層數(shù))降低到 O(sqrt(n)),并通過實(shí)驗(yàn)展示了將 ImageNet 的一個(gè)變種從 48GB 壓縮到了 7GB 內(nèi)存占用。
>>> 測(cè)試 API
PyTorch API 中有兩個(gè)不同的梯度檢查點(diǎn)方法,都在 torch.utils.checkpoint 命名空間中。兩者中比較簡(jiǎn)單的一個(gè)是 checkpoint_sequential,它被限制用于順序模型(例如使用 torch.nn.Sequential wrapper 的模型)。另一個(gè)是更靈活的 checkpoint,可以用于任何模塊。
下面是一個(gè)完整的代碼示例,顯示了 checkpoint_sequential 的實(shí)際用法:
import torchimport torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential
# a trivial modelmodel = nn.Sequential( nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 20), nn.ReLU(), nn.Linear(20, 5), nn.ReLU())
# model inputinput_var = torch.randn(1, 100, requires_grad=True)
# the number of segments to divide the model intosegments = 2
# finally, apply checkpointing to the model# note the code that this replaces:# out = model(input_var)out = checkpoint_sequential(modules, segments, input_var)
# backpropagateout.sum().backwards()
如你所見,checkpoint_sequential 替換了 module 對(duì)象上的 forward 或 __call__ 方法。out 幾乎和我們調(diào)用 model(input_var) 時(shí)得到的張量一樣; 關(guān)鍵的區(qū)別在于它缺少了累積值,并且附加了一些額外的元數(shù)據(jù),指示 PyTorch 在 out.backward() 期間需要這些值時(shí)重新計(jì)算。
值得注意的是,checkpoint_sequential 接受整數(shù)值的片段數(shù)作為輸入。checkpoint_sequential 將模型分割成 n 個(gè)縱向片段,并對(duì)除了最后一個(gè)的每個(gè)片段應(yīng)用檢查點(diǎn)。
這工作很容易,但有一些主要的限制。你無法控制片段的邊界在哪里,也無法對(duì)整個(gè)模塊應(yīng)用檢查點(diǎn)(而是其中的一部分)。
替代方法是使用更靈活的 checkpoint API. 下面展示了一個(gè)簡(jiǎn)單的卷積模型:
class CIFAR10Model(nn.Module):def __init__(self):super().__init__()self.cnn_block_1 = nn.Sequential(*[nn.Conv2d(3, 32, 3, padding=1),nn.ReLU(),nn.Conv2d(32, 64, 3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Dropout(0.25)])self.cnn_block_2 = nn.Sequential(*[nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(),nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Dropout(0.25)])self.flatten = lambda inp: torch.flatten(inp, 1)self.head = nn.Sequential(*[nn.Linear(64 * 8 * 8, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 10)])def forward(self, X):X = self.cnn_block_1(X)X = self.cnn_block_2(X)X = self.flatten(X)X = self.head(X)return X
這種模型有兩個(gè)卷積塊,一些 dropout,和一個(gè)線性頭(10個(gè)輸出對(duì)應(yīng) CIFAR10 的10類)。
下面是這個(gè)模型使用梯度檢查點(diǎn)的更新版本:
class CIFAR10Model(nn.Module): def __init__(self): super().__init__() self.cnn_block_1 = nn.Sequential(*[ nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2) ]) self.dropout_1 = nn.Dropout(0.25) self.cnn_block_2 = nn.Sequential(*[ nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2) ]) self.dropout_2 = nn.Dropout(0.25) self.flatten = lambda inp: torch.flatten(inp, 1) self.linearize = nn.Sequential(*[ nn.Linear(64 * 8 * 8, 512), nn.ReLU() ]) self.dropout_3 = nn.Dropout(0.5) self.out = nn.Linear(512, 10)
def forward(self, X): X = self.cnn_block_1(X) X = self.dropout_1(X) X = checkpoint(self.cnn_block_2, X) X = self.dropout_2(X) X = self.flatten(X) X = self.linearize(X) X = self.dropout_3(X) X = self.out(X) return X
在 forward 中顯示的 checkpoint 接受一個(gè)模塊(或任何可調(diào)用的模塊,如函數(shù))及其參數(shù)作為輸入。參數(shù)將在前向時(shí)被保存,然后用于在反向時(shí)重新計(jì)算其輸出值。
為了使其能夠工作,我們必須對(duì)模型定義進(jìn)行一些額外的更改。
首先,你會(huì)注意到我們從卷積塊里刪除了 nn.Dropout 層; 這是因?yàn)闄z查點(diǎn)與 dropout 不兼容(回想一下,樣本有效地通過模型兩次 —— dropout 會(huì)在每次通過時(shí)任意丟失不同的值,從而產(chǎn)生不同的輸出)?;旧?,任何在重新運(yùn)行時(shí)表現(xiàn)出非冪等(non-idempotent )行為的層都不應(yīng)該應(yīng)用檢查點(diǎn)(nn.BatchNorm 是另一個(gè)例子)。解決方案是重構(gòu)模塊,這樣問題層就不會(huì)被排除在檢查點(diǎn)片段之外,這正是我們?cè)谶@里所做的。
其次,你會(huì)注意到我們?cè)谀P椭械牡诙矸e塊上使用了檢查點(diǎn),但是第一個(gè)卷積塊上沒有使用檢查點(diǎn)。這是因?yàn)闄z查點(diǎn)簡(jiǎn)單地通過檢查輸入張量的 requires_grad 行為來決定它的輸入函數(shù)是否需要梯度下降(例如,它是否處于 requires_grad=True 或 requires_grad=False模式)。模型的輸入張量幾乎總是處于 requires_grad=False 模式,因?yàn)槲覀兏信d趣的是計(jì)算相對(duì)于網(wǎng)絡(luò)權(quán)重而不是輸入樣本本身的梯度。因此,模型中的第一個(gè)子模塊應(yīng)用檢查點(diǎn)沒多少意義: 它反而會(huì)凍結(jié)現(xiàn)有的權(quán)重,阻止它們進(jìn)行任何訓(xùn)練。更多細(xì)節(jié)請(qǐng)參考這個(gè) PyTorch 論壇帖子:https://discuss.pytorch.org/t/use-of-torch-utils-checkpoint-checkpoint-causes-simple-model-to-diverge/116271
在 PyTorch 文檔(https://pytorch.org/docs/stable/checkpoint.html#)中還討論了 RNG 狀態(tài)以及與分離張量不兼容的一些其他細(xì)節(jié)。
完整的訓(xùn)練代碼示例可以看這里:https://gist.github.com/ResidentMario/e3254172b4706191089bb63ecd610e21
和這里: https://gist.github.com/ResidentMario/9c3a90504d1a027aab926fd65ae08139
>>> 基準(zhǔn)測(cè)試
作為一個(gè)快速的基準(zhǔn)測(cè)試,我在 tweet-sentiment-extraction 上啟用了模型檢查點(diǎn),這是一個(gè)基于 Twitter 數(shù)據(jù)的帶有 BERT 主干的情感分類器模型。你可以在這里看到代碼:https://github.com/spellml/tweet-sentiment-extraction。transformers 已經(jīng)將模型檢查點(diǎn)作為 API 的一個(gè)可選部分來實(shí)現(xiàn); 為我們的模型啟用它就像翻轉(zhuǎn)一個(gè)布爾值標(biāo)記一樣簡(jiǎn)單:
# code from model_5.py
cfg = transformers.PretrainedConfig.get_config_dict("bert-base-uncased")[0]cfg["output_hidden_states"] = Truecfg["gradient_checkpointing"] = True # NEW!cfg = transformers.BertConfig.from_dict(cfg)self.bert = transformers.BertModel.from_pretrained( "bert-base-uncased", config=cfg)
我對(duì)這個(gè)模型進(jìn)行了四次訓(xùn)練: 分別在 NVIDIA T4和 NVIDIA V100 GPU 上,包括檢查點(diǎn)和無檢查點(diǎn)模式。所有運(yùn)行的批次大小為 64。以下是結(jié)果:

第一行是在模型檢查點(diǎn)關(guān)閉的情況下進(jìn)行的訓(xùn)練,第二行是在模型檢查點(diǎn)開啟的情況下進(jìn)行的訓(xùn)練。
模型檢查點(diǎn)降低了峰值模型內(nèi)存使用量 60% ,同時(shí)增加了模型訓(xùn)練時(shí)間 25% 。
當(dāng)然,你想要使用檢查點(diǎn)的主要原因可能是,這樣你就可以在 GPU 上使用更大的批次大小。在另一篇博文:https://qywu.github.io/2019/05/22/explore-gradient-checkpointing.html 中演示了這個(gè)很好的例子: 在他們的例子中,每批次樣本從 24 個(gè)提高到驚人的 132 個(gè)!
要處理大型神經(jīng)網(wǎng)絡(luò),模型檢查點(diǎn)顯然是一個(gè)非常強(qiáng)大和有用的工具。
原文:https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs
本文亮點(diǎn)總結(jié)
如果覺得有用,就請(qǐng)分享到朋友圈吧!
公眾號(hào)后臺(tái)回復(fù)“李鐸”獲取【極市線下沙龍】CVPR2021:通過反轉(zhuǎn)卷積的內(nèi)在性質(zhì)進(jìn)行視覺識(shí)別資源
YOLO教程:YOLO系列(從V1到V5)模型解讀|YOLO算法最全綜述:從YOLOv1到Y(jié)OLOv5
實(shí)操教程:使用Transformer來做物體檢測(cè)?DETR模型完整指南|PyTorch編譯并調(diào)用自定義CUDA算子的三種方式
算法技巧(trick):半監(jiān)督深度學(xué)習(xí)訓(xùn)練和實(shí)現(xiàn)|8點(diǎn)PyTorch提速技巧匯總
最新CV競(jìng)賽:2021 高通人工智能應(yīng)用創(chuàng)新大賽|CVPR 2021 | Short-video Face Parsing Challenge
# CV技術(shù)社群邀請(qǐng)函 #
備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)
即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~
