損失函數技術總結及Pytorch使用示例
點擊下方“AI算法與圖像處理”,一起進步!
重磅干貨,第一時間送達
導讀
?本文對損失函數的類別和應用場景,常見的損失函數,常見損失函數的表達式,特性,應用場景和使用示例作了詳細的總結。?
前言
一直想寫損失函數的技術總結,但網上已經有諸多關于損失函數綜述的文章或博客,考慮到這點就一直拖著沒寫,直到有一天,我將一個二分類項目修改為多分類,簡簡單單地修改了損失函數,結果一直有問題,后來才發(fā)現是不同函數的標簽的設置方式并不相同。
為了避免讀者也出現這樣的問題,本文中會給出每個損失函數的pytorch使用示例,這也是本文與其它相關綜述文章或博客的區(qū)別所在。希望讀者在閱讀本文時,重點關注一下每個損失函數的使用示例中的target的設置問題。
本文對損失函數的類別和應用場景,常見的損失函數,常見損失函數的表達式,特性,應用場景和使用示例作了詳細的總結。
主要涉及到L1 loss、L2 loss、Negative Log-Likelihood loss、Cross-Entropy loss、Hinge Embedding loss、Margin Ranking Loss、Triplet Margin loss、KL Divergence.
損失函數分類與應用場景
損失函數可以分為三類:回歸損失函數(Regression loss)、分類損失函數(Classification loss)和排序損失函數(Ranking loss)。
應用場景:回歸損失:用于預測連續(xù)的值。如預測房價、年齡等。分類損失:用于預測離散的值。如圖像分類,語義分割等。排序損失:用于預測輸入數據之間的相對距離。如行人重識別。
L1 loss
也稱Mean Absolute Error,簡稱MAE,計算實際值和預測值之間的絕對差之和的平均值。
表達式如下:
y表示標簽,pred表示預測值。
應用場合:回歸問題。
根據損失函數的表達式很容易了解它的特性:當目標變量的分布具有異常值時,即與平均值相差很大的值,它被認為對異常值具有很好的魯棒行。
使用示例:
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
mae_loss = torch.nn.L1Loss()
output = mae_loss(input, target)
L2 loss
也稱為Mean Squared Error,簡稱MSE,計算實際值和預測值之間的平方差的平均值。
表達式如下:
應用場合:對大部分回歸問題,pytorch默認使用L2,即MSE。
使用平方意味著當預測值離目標值更遠時在平方后具有更大的懲罰,預測值離目標值更近時在平方后懲罰更小,因此,當異常值與樣本平均值相差格外大時,模型會因為懲罰更大而開始偏離,相比之下,L1對異常值的魯棒性更好。
使用示例:
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
mse_loss = torch.nn.MSELoss()
output = mse_loss(input, target)
Negative Log-Likelihood
簡稱NLL。表達式如下:
應用場景:多分類問題。
注:NLL要求網絡最后一層使用softmax作為激活函數。通過softmax將輸出值映射為每個類別的概率值。
根據表達式,它的特性是懲罰預測準確而預測概率不高的情況。
NLL 使用負號,因為概率(或似然)在 0 和 1 之間變化,并且此范圍內的值的對數為負。最后,損失值變?yōu)檎怠?/p>
在 NLL 中,最小化損失函數有助于獲得更好的輸出。從近似最大似然估計 (MLE) 中檢索負對數似然。這意味著嘗試最大化模型的對數似然,從而最小化 NLL。
使用示例
# size of input (N x C) is = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# every element in target should have 0 <= value < C
target = torch.tensor([1, 0, 4])
m = nn.LogSoftmax(dim=1)
nll_loss = torch.nn.NLLLoss()
output = nll_loss(m(input), target)
Cross-Entropy
此損失函數計算提供的一組出現次數或隨機變量的兩個概率分布之間的差異。它用于計算預測值與實際值之間的平均差異的分數。
表達式:
應用場景:二分類及多分類。
特性:負對數似然損失不對預測置信度懲罰,與之不同的是,交叉熵懲罰不正確但可信的預測,以及正確但不太可信的預測。
交叉熵函數有很多種變體,其中最常見的類型是Binary Cross-Entropy (BCE)。BCE Loss 主要用于二分類模型;也就是說,模型只有 2 個類。
使用示例
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
cross_entropy_loss = torch.nn.CrossEntropyLoss()
output = cross_entropy_loss(input, target)
Hinge Embedding
表達式:
其中y為1或-1。
應用場景:
分類問題,特別是在確定兩個輸入是否不同或相似時。
學習非線性嵌入或半監(jiān)督學習任務。
使用示例
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
hinge_loss = torch.nn.HingeEmbeddingLoss()
output = hinge_loss(input, target)
Margin Ranking Loss
Margin Ranking Loss 計算一個標準來預測輸入之間的相對距離。這與其他損失函數(如 MSE 或交叉熵)不同,后者學習直接從給定的輸入集進行預測。
表達式:
標簽張量 y(包含 1 或 -1)。當 y == 1 時,第一個輸入將被假定為更大的值。它將排名高于第二個輸入。如果 y == -1,則第二個輸入將排名更高。
應用場景:排名問題
使用示例
input_one = torch.randn(3, requires_grad=True)
input_two = torch.randn(3, requires_grad=True)
target = torch.randn(3).sign()
ranking_loss = torch.nn.MarginRankingLoss()
output = ranking_loss(input_one, input_two, target)
Triplet Margin Loss
計算三元組的損失。
表達式:
三元組由a (anchor),p (正樣本) 和 n (負樣本)組成.
應用場景:
確定樣本之間的相對相似性
用于基于內容的檢索問題
使用示例
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)
triplet_margin_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2)
output = triplet_margin_loss(anchor, positive, negative)
KL Divergence Loss
計算兩個概率分布之間的差異。
表達式:
輸出表示兩個概率分布的接近程度。如果預測的概率分布與真實的概率分布相差很遠,就會導致很大的損失。如果 KL Divergence 的值為零,則表示概率分布相同。
KL Divergence 與交叉熵損失的關鍵區(qū)別在于它們如何處理預測概率和實際概率。交叉熵根據預測的置信度懲罰模型,而 KL Divergence 則沒有。KL Divergence 僅評估概率分布預測與ground truth分布的不同之處。
應用場景:逼近復雜函數多類分類任務確保預測的分布與訓練數據的分布相似
使用示例
input = torch.randn(2, 3, requires_grad=True)
target = torch.randn(2, 3)
kl_loss = torch.nn.KLDivLoss(reduction = 'batchmean')
output = kl_loss(input, target)
原文鏈接:https://neptune.ai/blog/pytorch-loss-functions
本文在此鏈接的基礎上進行一部分而來修改。
??
交流群
歡迎加入公眾號讀者群一起和同行交流,目前有美顏、三維視覺、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群
個人微信(如果沒有備注不拉群!) 請注明:地區(qū)+學校/企業(yè)+研究方向+昵稱
下載1:何愷明頂會分享
在「AI算法與圖像處理」公眾號后臺回復:何愷明,即可下載。總共有6份PDF,涉及 ResNet、Mask RCNN等經典工作的總結分析
下載2:終身受益的編程指南:Google編程風格指南
在「AI算法與圖像處理」公眾號后臺回復:c++,即可下載。歷經十年考驗,最權威的編程規(guī)范!
下載3 CVPR2021 在「AI算法與圖像處理」公眾號后臺回復:CVPR,即可下載1467篇CVPR?2020論文 和 CVPR 2021 最新論文

