PyTorch 進階之路:在 GPU 上訓練深度神經(jīng)網(wǎng)絡(luò)
點擊上方“小白學視覺”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達
選自 | Medium
作者 | Aakash N S
參與| Panda
本文是該系列的第四篇,將介紹如何在 GPU 上使用 PyTorch 訓練深度神經(jīng)網(wǎng)絡(luò)。
在之前的教程中,我們基于 MNIST 數(shù)據(jù)集訓練了一個識別手寫數(shù)字的 logistic 回歸模型,并且達到了約 86% 的準確度。

但是,我們也注意到,由于模型能力有限,很難再進一步將準確度提升到 87% 以上。在本文中,我們將嘗試使用前向神經(jīng)網(wǎng)絡(luò)來提升準確度。本教程的大部分內(nèi)容受到了 Jeremy Howard 的 FastAI 開發(fā)筆記的啟發(fā):https://github.com/fastai/fastai_old/tree/master/dev_nb
如果你想一邊閱讀一邊運行代碼,你可以通過下面的鏈接找到本教程的 Jupyter Notebook:
https://jvn.io/aakashns/fdaae0bf32cf4917a931ac415a5c31b0
你可以克隆這個筆記,使用 conda 安裝必要的依賴包,然后通過在終端運行以下命令來啟動 Jupyter:
pip install jovian --upgrade # Install the jovian library
jovian clone fdaae0bf32cf4917a931ac415a5c31b0 # Download notebook
cd 04-feedforward-nn # Enter the created directory
jovian install # Install the dependencies
conda activate 04-feedfoward-nn # Activate virtual env
jupyter notebook # Start Jupyter
如果你的 conda 版本更舊一些,你也許需要運行 source activate 04-feedforward-nn 來激活虛擬環(huán)境。對以上步驟的更詳細解釋可參閱本教程的本系列教程第一篇文章。
這里的數(shù)據(jù)準備流程和前一篇教程完全一樣。我們首先導入所需的模塊和類。

我們使用 torchvision.datasets 的 MNIST 類下載數(shù)據(jù)并創(chuàng)建一個 PyTorch 數(shù)據(jù)集。

接下來,我們定義并使用一個函數(shù) split_indices 來隨機選取 20% 圖像作為驗證集。

現(xiàn)在,我們可以使用 SubsetRandomSampler 為每個子集創(chuàng)建 PyTorch 數(shù)據(jù)加載器,它可從一個給定的索引列表中隨機地采樣元素,同時創(chuàng)建分批數(shù)據(jù)。

要在 logistic 回歸的基礎(chǔ)上實現(xiàn)進一步提升,我們將創(chuàng)建一個帶有一個隱藏層(hidden layer)的神經(jīng)網(wǎng)絡(luò)。這是我們的做法:
我們不再使用單個 nn.Linear 對象將輸入批(像素強度)轉(zhuǎn)換成輸出批(類別概率),而是將使用兩個 nn.Linear 對象。其中每一個對象都被稱為一層,而該模型本身則被稱為一個網(wǎng)絡(luò)。
第一層(也被稱為隱藏層)可將大小為 batch_size x 784 的輸入矩陣轉(zhuǎn)換成大小為 batch_size x hidden_size 的中間輸出矩陣,其中 hidden_size 是一個預配置的參數(shù)(比如 32 或 64)。
然后,這個中間輸出會被傳遞給一個非線性激活函數(shù),它操作的是這個輸出矩陣的各個元素。
這個激活函數(shù)的結(jié)果的大小也為 batch_size x hidden_size,其會被傳遞給第二層(也被稱為輸出層)。該層可將隱藏層的結(jié)果轉(zhuǎn)換成一個大小為 batch_size x 10 的矩陣,這與 logistic 回歸模型的輸出一樣。
引入隱藏層和激活函數(shù)讓模型學習輸入與目標之間更復雜的、多層的和非線性的關(guān)系。看起來像是這樣(藍框表示單張輸入圖像的層輸出):

我們這里將使用的激活函數(shù)是整流線性單元(ReLU),它的公式很簡單:relu(x) = max(0,x),即如果一個元素為負,則將其替換成 0,否則保持不變。
為了定義模型,我們對 nn.Module 類進行擴展,就像我們使用 logistic 回歸時那樣。

我們將創(chuàng)建一個帶有 32 個激活的隱藏層的模型。

我們看看模型的參數(shù)??梢灶A見每一層都有一個權(quán)重和偏置矩陣。

我們試試用我們的模型生成一些輸出。我們從我們的數(shù)據(jù)集取第一批 100 張圖像,并將其傳入我們的模型。

隨著我們的模型和數(shù)據(jù)集規(guī)模增大,為了在合理的時間內(nèi)完成模型訓練,我們需要使用 GPU(圖形處理器,也被稱為顯卡)來訓練我們的模型。GPU 包含數(shù)百個核,這些核針對成本高昂的浮點數(shù)矩陣運算進行了優(yōu)化,讓我們可以在較短時間內(nèi)完成這些計算;這也因此使得 GPU 非常適合用于訓練具有很多層的深度神經(jīng)網(wǎng)絡(luò)。你可以在 Kaggle kernels 或 Google Colab 上免費使用 GPU,也可以租用 Google Cloud Platform、Amazon Web Services 或 Paperspace 等 GPU 使用服務(wù)。你可以使用 torch.cuda.is_available 檢查 GPU 是否可用以及是否已經(jīng)安裝了所需的英偉達驅(qū)動和 CUDA 庫。

我們定義一個輔助函數(shù),以便在有 GPU 時選擇 GPU 為目標設(shè)備,否則就默認選擇 CPU。

接下來,我們定義一個可將數(shù)據(jù)移動到所選設(shè)備的函數(shù)。

最后,我們定義一個 DeviceDataLoader 類(受 FastAI 的啟發(fā))來封裝我們已有的數(shù)據(jù)加載器并在讀取數(shù)據(jù)批時將數(shù)據(jù)移動到所選設(shè)備。有意思的是,我們不需要擴展已有的類來創(chuàng)建 PyTorch 數(shù)據(jù)加載器。我們只需要用 __iter__ 方法來檢索數(shù)據(jù)批并使用 __len__ 方法來獲取批數(shù)量即可。

我們現(xiàn)在可以使用 DeviceDataLoader 來封裝我們的數(shù)據(jù)加載器了。

已被移動到 GPU 的 RAM 的張量有一個 device 屬性,其中包含 cuda 這個詞。我們通過查看 valid_dl 的一批數(shù)據(jù)來驗證這一點。

和 logistic 回歸一樣,我們可以使用交叉熵作為損失函數(shù),使用準確度作為模型的評估指標。訓練循環(huán)也是一樣的,所以我們可以復用前一個教程的 loss_batch、evaluate 和 fit 函數(shù)。
loss_batch 函數(shù)計算的是一批數(shù)據(jù)的損失和指標值,并且如果提供了優(yōu)化器就可選擇執(zhí)行梯度下降。

evaluate 函數(shù)是為驗證集計算整體損失(如果有,還計算一個指標)。

和之前教程中定義的一樣,fit 函數(shù)包含實際的訓練循環(huán)。我們將對 fit 函數(shù)進行一些改進:
我們沒有人工地定義優(yōu)化器,而是將傳入學習率并在該函數(shù)中創(chuàng)建一個優(yōu)化器。這讓我們在有需要時能以不同的學習率訓練模型。
我們將記錄每 epoch 結(jié)束時的驗證損失和準確度,并返回這個歷史作為 fit 函數(shù)的輸出。

我們還要定義一個 accuracy 函數(shù),其計算的是模型在整批輸出上的整體準確度,所以我們可將其用作 fit 中的指標。

在我們訓練模型之前,我們需要確保數(shù)據(jù)和模型參數(shù)(權(quán)重和偏置)都在同一設(shè)備上(CPU 或 GPU)。我們可以復用 to_device 函數(shù)來將模型參數(shù)移至正確的設(shè)備。

我們看看使用初始權(quán)重和偏置時,模型在驗證集上的表現(xiàn)。

初始準確度大約是 10%,這符合我們對隨機初始化模型的預期(其有十分之一的可能性得到正確標簽)。
現(xiàn)在可以開始訓練模型了。我們先訓練 5 epoch 看看結(jié)果。我們可以使用相對較高的學習率 0.5。

95% 非常好了!我們再以更低的學習率 0.1 訓練 5 epoch,以進一步提升準確度。

現(xiàn)在我們可以繪制準確度圖表,看看模型隨時間的提升情況。

我們當前的模型極大地優(yōu)于 logistic 模型(僅能達到約 86% 的準確度)!它很快就達到了 96% 的準確度,但沒能實現(xiàn)進一步提升。如果要進一步提升準確度,我們需要讓模型更加強大。你可能也已經(jīng)猜到了,通過增大隱藏層的規(guī)?;蛱砑痈嚯[藏層可以實現(xiàn)這一目標。
最后,我們可以使用 jovian 庫保存和提交我們的成果。

jovian 會將筆記上傳到 https://jvn.io,并會獲取其 Python 環(huán)境并為該筆記創(chuàng)建一個可分享的鏈接。你可以使用該鏈接共享你的成果,讓任何人都能使用 jovian 克隆命令輕松復現(xiàn)它。jovian 還有一個強大的評論接口,讓你和其他人都能討論和點評你的筆記的各個部分。
本教程涵蓋的主題總結(jié)如下:
我們創(chuàng)建了一個帶有一個隱藏層的神經(jīng)網(wǎng)絡(luò),以在前一個教程的 logistic 回歸模型基礎(chǔ)上實現(xiàn)進一步提升。
我們使用了 ReLU 激活函數(shù)來引入非線性,讓模型可以學習輸入和輸出之間的更復雜的關(guān)系。
我們定義了 get_default_device、to_device 和 DeviceDataLoader 等一些實用程序,以便在可使用 GPU 時利用它,并將輸入數(shù)據(jù)和模型參數(shù)移動到合適的設(shè)備。
我們可以使用我們之前定義的同樣的訓練循環(huán):fit 函數(shù),來訓練我們的模型以及在驗證數(shù)據(jù)集上評估它。
其中有很多可以實驗的地方,我建議你使用 Jupyter 的交互性質(zhì)試試各種不同的參數(shù)。這里有一些想法:
試試修改隱藏層的大小或添加更多隱藏層,看你能否實現(xiàn)更高的準確度。
試試修改批大小和學習率,看你能否用更少的 epoch 實現(xiàn)同樣的準確度。
比較在 CPU 和 GPU 上的訓練時間。你看到存在顯著差異嗎?數(shù)據(jù)集的大小和模型的大?。?quán)重和參數(shù)的數(shù)量)對其有何影響?
試試為不同的數(shù)據(jù)集構(gòu)建模型,比如 CIFAR10 或 CIFAR100 數(shù)據(jù)集。
最后,分享一些適合進一步學習的好資源:
神經(jīng)網(wǎng)絡(luò)可以計算任何函數(shù)的視覺式證明,也被稱為通用近似定理:http://neuralnetworksanddeeplearning.com/chap4.html
神經(jīng)網(wǎng)絡(luò)究竟是什么?——通過視覺和直觀的介紹解釋了神經(jīng)網(wǎng)絡(luò)以及中間層所表示的內(nèi)容:https://www.youtube.com/watch?v=aircAruvnKk
斯坦福大學 CS229 關(guān)于反向傳播的講義——更數(shù)學地解釋了多層神經(jīng)網(wǎng)絡(luò)計算梯度和更新權(quán)重的方式:http://cs229.stanford.edu/notes/cs229-notes-backprop.pdf
吳恩達的 Coursera 課程:關(guān)于激活函數(shù)的視頻課程:https://www.coursera.org/lecture/neural-networks-deep-learning/activation-functions-4dDC1
好消息!
小白學視覺知識星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴展模塊中文版教程 在「小白學視覺」公眾號后臺回復:擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實戰(zhàn)項目52講 在「小白學視覺」公眾號后臺回復:Python視覺實戰(zhàn)項目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。 下載3:OpenCV實戰(zhàn)項目20講 在「小白學視覺」公眾號后臺回復:OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。 交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

