1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        損失函數(shù)技術(shù)總結(jié)及Pytorch使用示例

        共 1743字,需瀏覽 4分鐘

         ·

        2022-04-12 02:47

        ↑ 點(diǎn)擊藍(lán)字?關(guān)注極市平臺(tái)

        作者丨仿佛若有光
        來源丨CV技術(shù)指南
        編輯丨極市平臺(tái)

        極市導(dǎo)讀

        ?

        本文對(duì)損失函數(shù)的類別和應(yīng)用場(chǎng)景,常見的損失函數(shù),常見損失函數(shù)的表達(dá)式,特性,應(yīng)用場(chǎng)景和使用示例作了詳細(xì)的總結(jié)。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿

        前言

        一直想寫損失函數(shù)的技術(shù)總結(jié),但網(wǎng)上已經(jīng)有諸多關(guān)于損失函數(shù)綜述的文章或博客,考慮到這點(diǎn)就一直拖著沒寫,直到有一天,我將一個(gè)二分類項(xiàng)目修改為多分類,簡(jiǎn)簡(jiǎn)單單地修改了損失函數(shù),結(jié)果一直有問題,后來才發(fā)現(xiàn)是不同函數(shù)的標(biāo)簽的設(shè)置方式并不相同。

        為了避免讀者也出現(xiàn)這樣的問題,本文中會(huì)給出每個(gè)損失函數(shù)的pytorch使用示例,這也是本文與其它相關(guān)綜述文章或博客的區(qū)別所在。希望讀者在閱讀本文時(shí),重點(diǎn)關(guān)注一下每個(gè)損失函數(shù)的使用示例中的target的設(shè)置問題。

        本文對(duì)損失函數(shù)的類別和應(yīng)用場(chǎng)景,常見的損失函數(shù),常見損失函數(shù)的表達(dá)式,特性,應(yīng)用場(chǎng)景和使用示例作了詳細(xì)的總結(jié)。

        主要涉及到L1 loss、L2 loss、Negative Log-Likelihood loss、Cross-Entropy loss、Hinge Embedding loss、Margin Ranking Loss、Triplet Margin loss、KL Divergence.

        損失函數(shù)分類與應(yīng)用場(chǎng)景

        損失函數(shù)可以分為三類:回歸損失函數(shù)(Regression loss)、分類損失函數(shù)(Classification loss)和排序損失函數(shù)(Ranking loss)。

        應(yīng)用場(chǎng)景:回歸損失:用于預(yù)測(cè)連續(xù)的值。如預(yù)測(cè)房?jī)r(jià)、年齡等。分類損失:用于預(yù)測(cè)離散的值。如圖像分類,語(yǔ)義分割等。排序損失:用于預(yù)測(cè)輸入數(shù)據(jù)之間的相對(duì)距離。如行人重識(shí)別。

        L1 loss

        也稱Mean Absolute Error,簡(jiǎn)稱MAE,計(jì)算實(shí)際值和預(yù)測(cè)值之間的絕對(duì)差之和的平均值。

        表達(dá)式如下:

        y表示標(biāo)簽,pred表示預(yù)測(cè)值。

        應(yīng)用場(chǎng)合:回歸問題。

        根據(jù)損失函數(shù)的表達(dá)式很容易了解它的特性:當(dāng)目標(biāo)變量的分布具有異常值時(shí),即與平均值相差很大的值,它被認(rèn)為對(duì)異常值具有很好的魯棒行。

        使用示例:

        input = torch.randn(3, 5, requires_grad=True)
        target = torch.randn(3, 5)

        mae_loss = torch.nn.L1Loss()
        output = mae_loss(input, target)

        L2 loss

        也稱為Mean Squared Error,簡(jiǎn)稱MSE,計(jì)算實(shí)際值和預(yù)測(cè)值之間的平方差的平均值。

        表達(dá)式如下:

        應(yīng)用場(chǎng)合:對(duì)大部分回歸問題,pytorch默認(rèn)使用L2,即MSE。

        使用平方意味著當(dāng)預(yù)測(cè)值離目標(biāo)值更遠(yuǎn)時(shí)在平方后具有更大的懲罰,預(yù)測(cè)值離目標(biāo)值更近時(shí)在平方后懲罰更小,因此,當(dāng)異常值與樣本平均值相差格外大時(shí),模型會(huì)因?yàn)閼土P更大而開始偏離,相比之下,L1對(duì)異常值的魯棒性更好。

        使用示例:

        input = torch.randn(3, 5, requires_grad=True)
        target = torch.randn(3, 5)
        mse_loss = torch.nn.MSELoss()
        output = mse_loss(input, target)

        Negative Log-Likelihood

        簡(jiǎn)稱NLL。表達(dá)式如下:

        應(yīng)用場(chǎng)景:多分類問題。

        注:NLL要求網(wǎng)絡(luò)最后一層使用softmax作為激活函數(shù)。通過softmax將輸出值映射為每個(gè)類別的概率值。

        根據(jù)表達(dá)式,它的特性是懲罰預(yù)測(cè)準(zhǔn)確而預(yù)測(cè)概率不高的情況。

        NLL 使用負(fù)號(hào),因?yàn)楦怕剩ɑ蛩迫唬┰?0 和 1 之間變化,并且此范圍內(nèi)的值的對(duì)數(shù)為負(fù)。最后,損失值變?yōu)檎怠?/p>

        在 NLL 中,最小化損失函數(shù)有助于獲得更好的輸出。從近似最大似然估計(jì) (MLE) 中檢索負(fù)對(duì)數(shù)似然。這意味著嘗試最大化模型的對(duì)數(shù)似然,從而最小化 NLL。

        使用示例

        # size of input (N x C) is = 3 x 5
        input = torch.randn(3, 5, requires_grad=True)
        # every element in target should have 0 <= value < C
        target = torch.tensor([1, 0, 4])

        m = nn.LogSoftmax(dim=1)
        nll_loss = torch.nn.NLLLoss()
        output = nll_loss(m(input), target)

        Cross-Entropy

        此損失函數(shù)計(jì)算提供的一組出現(xiàn)次數(shù)或隨機(jī)變量的兩個(gè)概率分布之間的差異。它用于計(jì)算預(yù)測(cè)值與實(shí)際值之間的平均差異的分?jǐn)?shù)。

        表達(dá)式:

        應(yīng)用場(chǎng)景:二分類及多分類。

        特性:負(fù)對(duì)數(shù)似然損失不對(duì)預(yù)測(cè)置信度懲罰,與之不同的是,交叉熵懲罰不正確但可信的預(yù)測(cè),以及正確但不太可信的預(yù)測(cè)。

        交叉熵函數(shù)有很多種變體,其中最常見的類型是Binary Cross-Entropy (BCE)。BCE Loss 主要用于二分類模型;也就是說,模型只有 2 個(gè)類。

        使用示例

        input = torch.randn(3, 5, requires_grad=True)
        target = torch.empty(3, dtype=torch.long).random_(5)

        cross_entropy_loss = torch.nn.CrossEntropyLoss()
        output = cross_entropy_loss(input, target)

        Hinge Embedding

        表達(dá)式:

        其中y為1或-1。

        應(yīng)用場(chǎng)景:

        分類問題,特別是在確定兩個(gè)輸入是否不同或相似時(shí)。

        學(xué)習(xí)非線性嵌入或半監(jiān)督學(xué)習(xí)任務(wù)。

        使用示例

        input = torch.randn(3, 5, requires_grad=True)
        target = torch.randn(3, 5)

        hinge_loss = torch.nn.HingeEmbeddingLoss()
        output = hinge_loss(input, target)

        Margin Ranking Loss

        Margin Ranking Loss 計(jì)算一個(gè)標(biāo)準(zhǔn)來預(yù)測(cè)輸入之間的相對(duì)距離。這與其他損失函數(shù)(如 MSE 或交叉熵)不同,后者學(xué)習(xí)直接從給定的輸入集進(jìn)行預(yù)測(cè)。

        表達(dá)式:

        標(biāo)簽張量 y(包含 1 或 -1)。當(dāng) y == 1 時(shí),第一個(gè)輸入將被假定為更大的值。它將排名高于第二個(gè)輸入。如果 y == -1,則第二個(gè)輸入將排名更高。

        應(yīng)用場(chǎng)景:排名問題

        使用示例

        input_one = torch.randn(3, requires_grad=True)
        input_two = torch.randn(3, requires_grad=True)
        target = torch.randn(3).sign()

        ranking_loss = torch.nn.MarginRankingLoss()
        output = ranking_loss(input_one, input_two, target)

        Triplet Margin Loss

        計(jì)算三元組的損失。

        表達(dá)式:

        三元組由a (anchor),p (正樣本) 和 n (負(fù)樣本)組成.

        應(yīng)用場(chǎng)景:

        確定樣本之間的相對(duì)相似性

        用于基于內(nèi)容的檢索問題

        使用示例

        anchor = torch.randn(100, 128, requires_grad=True)
        positive = torch.randn(100, 128, requires_grad=True)
        negative = torch.randn(100, 128, requires_grad=True)

        triplet_margin_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2)
        output = triplet_margin_loss(anchor, positive, negative)

        KL Divergence Loss

        計(jì)算兩個(gè)概率分布之間的差異。

        表達(dá)式:

        輸出表示兩個(gè)概率分布的接近程度。如果預(yù)測(cè)的概率分布與真實(shí)的概率分布相差很遠(yuǎn),就會(huì)導(dǎo)致很大的損失。如果 KL Divergence 的值為零,則表示概率分布相同。

        KL Divergence 與交叉熵?fù)p失的關(guān)鍵區(qū)別在于它們?nèi)绾翁幚眍A(yù)測(cè)概率和實(shí)際概率。交叉熵根據(jù)預(yù)測(cè)的置信度懲罰模型,而 KL Divergence 則沒有。KL Divergence 僅評(píng)估概率分布預(yù)測(cè)與ground truth分布的不同之處。

        應(yīng)用場(chǎng)景:逼近復(fù)雜函數(shù)多類分類任務(wù)確保預(yù)測(cè)的分布與訓(xùn)練數(shù)據(jù)的分布相似

        使用示例

        input = torch.randn(2, 3, requires_grad=True)
        target = torch.randn(2, 3)

        kl_loss = torch.nn.KLDivLoss(reduction = 'batchmean')
        output = kl_loss(input, target)

        原文鏈接:https://neptune.ai/blog/pytorch-loss-functions
        本文在此鏈接的基礎(chǔ)上進(jìn)行一部分而來修改。


        △點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨


        極市干貨
        YOLO教程:一文讀懂YOLO V5 與 YOLO V4大盤點(diǎn)|YOLO 系目標(biāo)檢測(cè)算法總覽全面解析YOLO V4網(wǎng)絡(luò)結(jié)構(gòu)
        實(shí)操教程:PyTorch vs LibTorch:網(wǎng)絡(luò)推理速度誰(shuí)更快?只用兩行代碼,我讓Transformer推理加速了50倍
        算法技巧(trick):深度學(xué)習(xí)訓(xùn)練tricks總結(jié)(有實(shí)驗(yàn)支撐)深度強(qiáng)化學(xué)習(xí)調(diào)參Tricks合集


        #?CV技術(shù)社群邀請(qǐng)函?#

        △長(zhǎng)按添加極市小助手
        添加極市小助手微信(ID : cvmart2)

        備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)


        即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


        每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與?10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~



        覺得有用麻煩給個(gè)在看啦~??
        瀏覽 43
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            先锋影音一区二区三区 | 水多多成人免费A片 | 日日夜夜免费精品 | 97se亚洲综合自在线尤物 | 国产午夜91 | 特黄AV | 久操最新地址 | 第四色狠狠 | xxfree性人妖hd丝袜 | 丰满奶水二区三区在线 |