1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        輕松學(xué)Pytorch – 構(gòu)建生成對(duì)抗網(wǎng)絡(luò)

        共 5229字,需瀏覽 11分鐘

         ·

        2022-05-24 10:10

        點(diǎn)擊上方小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂

        重磅干貨,第一時(shí)間送達(dá)

        又好久沒有繼續(xù)寫了,這個(gè)是我寫的第21篇文章,我還在繼續(xù)堅(jiān)持寫下去,雖然經(jīng)常各種拖延癥,但是我還記得,一直沒有敢忘記!今天給大家分享一下Pytorch生成對(duì)抗網(wǎng)絡(luò)代碼實(shí)現(xiàn)。

        ?

        01.什么是生成對(duì)抗網(wǎng)絡(luò)


        Ian J. Goodfellow在2014年提出生成對(duì)抗網(wǎng)絡(luò),從此打開了深度學(xué)習(xí)中另外一個(gè)重要分支,讓生成對(duì)抗網(wǎng)絡(luò)(GAN)成為與卷積神經(jīng)網(wǎng)絡(luò)(CNN)、循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN/LSTM)可以并駕齊驅(qū)的分支領(lǐng)域。今天GAN仍然是計(jì)算機(jī)視覺領(lǐng)域研究熱點(diǎn)之一,每年還有大量相關(guān)的論文產(chǎn)生,GAN已經(jīng)被用在視覺任務(wù)的很多方面,主要包括:

        • 圖像合成與數(shù)據(jù)增廣

        • 圖像翻譯與變換

        • 缺陷檢測(cè)

        • 圖像去噪與重建

        • 圖像分割

        但是GAN最基本的核心思想還是2014年Ian J. Goodfellow在論文中提到的兩個(gè)基本的模型分別是:生成器與判別器

        生成器(G):

        根據(jù)輸入噪聲Z生成輸出樣本G(z)目標(biāo):通過生成樣本與目標(biāo)樣本分布一致,成功欺騙鑒別器

        判別器(D):

        根據(jù)輸入樣本數(shù)據(jù)來分辨真實(shí)樣本概率從數(shù)據(jù)中學(xué)習(xí)樣本數(shù)據(jù)的差異性

        從a到d,可以看到輸入噪聲的生成分布越來越接近真實(shí)分布X,最終達(dá)到一種平衡狀態(tài),這種穩(wěn)定的平衡狀態(tài)叫納什均衡,還有一部電影跟這個(gè)有關(guān)系叫《美麗心靈》。

        ?

        02.GAN代碼實(shí)現(xiàn)


        下面的代碼實(shí)現(xiàn)了基于Mnist數(shù)據(jù)集實(shí)現(xiàn)判別器與生成器,最終通過生成器可以自動(dòng)生成手寫數(shù)字識(shí)別的圖像,輸入的z=100是隨機(jī)噪聲,輸出的是784個(gè)數(shù)據(jù)表示28x28大小的手寫數(shù)字樣本,損失主要來自兩個(gè)部分,生成器生成損失,判別器分別判別真實(shí)與虛構(gòu)樣本概率,基于反向傳播訓(xùn)練兩個(gè)網(wǎng)絡(luò),設(shè)置epoch=100,得到最終的生成器生成結(jié)果如下:


        生成器與判別器代碼實(shí)現(xiàn)如下


        判別器與生成器代碼:(后面文字忽略)2004論文中提出,其主要思想可以通過下面一張圖像解釋:

         1transform?=?tv.transforms.Compose([tv.transforms.ToTensor(),
        2???????????????????????????????????tv.transforms.Normalize((0.5,),?(0.5,))])
        3train_ts?=?tv.datasets.MNIST(root='./data',?train=True,?download=True,?transform=transform)
        4test_ts?=?tv.datasets.MNIST(root='./data',?train=False,?download=True,?transform=transform)
        5train_dl?=?DataLoader(train_ts,?batch_size=128,?shuffle=True,?drop_last=False)
        6test_dl?=?DataLoader(test_ts,?batch_size=128,?shuffle=True,?drop_last=False)
        7
        8
        9class?Generator(t.nn.Module):
        10????def?__init__(self,?g_input_dim,?g_output_dim):
        11????????super(Generator,?self).__init__()
        12????????self.fc1?=?t.nn.Linear(g_input_dim,?256)
        13????????self.fc2?=?t.nn.Linear(self.fc1.out_features,?self.fc1.out_features?*?2)
        14????????self.fc3?=?t.nn.Linear(self.fc2.out_features,?self.fc2.out_features?*?2)
        15????????self.fc4?=?t.nn.Linear(self.fc3.out_features,?g_output_dim)
        16
        17????#?forward?method
        18????def?forward(self,?x):
        19????????x?=?F.leaky_relu(self.fc1(x),?0.2)
        20????????x?=?F.leaky_relu(self.fc2(x),?0.2)
        21????????x?=?F.leaky_relu(self.fc3(x),?0.2)
        22????????return?t.tanh(self.fc4(x))
        23
        24
        25class?Discriminator(t.nn.Module):
        26????def?__init__(self,?d_input_dim):
        27????????super(Discriminator,?self).__init__()
        28????????self.fc1?=?t.nn.Linear(d_input_dim,?1024)
        29????????self.fc2?=?t.nn.Linear(self.fc1.out_features,?self.fc1.out_features?//?2)
        30????????self.fc3?=?t.nn.Linear(self.fc2.out_features,?self.fc2.out_features?//?2)
        31????????self.fc4?=?t.nn.Linear(self.fc3.out_features,?1)
        32
        33????#?forward?method
        34????def?forward(self,?x):
        35????????x?=?F.leaky_relu(self.fc1(x),?0.2)
        36????????x?=?F.dropout(x,?0.3)
        37????????x?=?F.leaky_relu(self.fc2(x),?0.2)
        38????????x?=?F.dropout(x,?0.3)
        39????????x?=?F.leaky_relu(self.fc3(x),?0.2)
        40????????x?=?F.dropout(x,?0.3)
        41????????return?t.sigmoid(self.fc4(x))


        損失與訓(xùn)練代碼如下


        分別定義生成網(wǎng)絡(luò)訓(xùn)練與鑒別網(wǎng)絡(luò)的訓(xùn)練方法,然后開始訓(xùn)練即可,代碼實(shí)現(xiàn)如下:

         1#?生成者與判別者
        2bs?=?128
        3z_dim?=?100
        4mnist_dim?=?784
        5#?loss
        6criterion?=?t.nn.BCELoss()
        7
        8#?optimizer
        9device?=?"cuda"
        10gnet?=?Generator(g_input_dim?=?z_dim,?g_output_dim?=?mnist_dim).to(device)
        11dnet?=?Discriminator(mnist_dim).to(device)
        12lr?=?0.0002
        13G_optimizer?=?t.optim.Adam(gnet.parameters(),?lr=lr)
        14D_optimizer?=?t.optim.Adam(dnet.parameters(),?lr=lr)
        15
        16
        17def?D_train(x):
        18????#?=======================Train?the?discriminator=======================#
        19????dnet.zero_grad()
        20
        21????#?train?discriminator?on?real
        22????x_real,?y_real?=?x.view(-1,?mnist_dim),?t.ones(bs,?1)
        23????x_real,?y_real?=?Variable(x_real.to(device)),?Variable(y_real.to(device))
        24
        25????D_output?=?dnet(x_real)
        26????D_real_loss?=?criterion(D_output,?y_real)
        27
        28????#?train?discriminator?on?facke
        29????z?=?Variable(t.randn(bs,?z_dim).to(device))
        30????x_fake,?y_fake?=?gnet(z),?Variable(t.zeros(bs,?1).to(device))
        31
        32????D_output?=?dnet(x_fake)
        33????D_fake_loss?=?criterion(D_output,?y_fake)
        34
        35????#?gradient?backprop?&?optimize?ONLY?D's?parameters
        36????D_loss?=?D_real_loss?+?D_fake_loss
        37????D_loss.backward()
        38????D_optimizer.step()
        39
        40????return?D_loss.data.item()
        41
        42
        43def?G_train(x):
        44????#?=======================Train?the?generator=======================#
        45????gnet.zero_grad()
        46
        47????z?=?Variable(t.randn(bs,?z_dim).to(device))
        48????y?=?Variable(t.ones(bs,?1).to(device))
        49
        50????G_output?=?gnet(z)
        51????D_output?=?dnet(G_output)
        52????G_loss?=?criterion(D_output,?y)
        53
        54????#?gradient?backprop?&?optimize?ONLY?G's?parameters
        55????G_loss.backward()
        56????G_optimizer.step()
        57
        58????return?G_loss.data.item()
        59
        60
        61n_epoch?=?100
        62for?epoch?in?range(1,?n_epoch+1):
        63????D_losses,?G_losses?=?[],?[]
        64????for?batch_idx,?(x,?_)?in?enumerate(train_dl):
        65????????bs_,?_,_,_?=?x.size()
        66????????bs?=?bs_
        67????????D_losses.append(D_train(x))
        68????????G_losses.append(G_train(x))
        69
        70????print('[%d/%d]:?loss_d:?%.3f,?loss_g:?%.3f'?%?(
        71????????????(epoch),?n_epoch,?t.mean(t.FloatTensor(D_losses)),?t.mean(t.FloatTensor(G_losses))))



        下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
        在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

        下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講
        小白學(xué)視覺公眾號(hào)后臺(tái)回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測(cè)、車道線檢測(cè)、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。

        下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
        小白學(xué)視覺公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。

        交流群


        歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請(qǐng)按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~


        瀏覽 51
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            欧美在线A片 | 特级西西4444wwww人体视频 | 97国产精品A片 | 男人天堂网在线视频观看 | 国产精品久久久久久久久久直播 | 亚洲 自拍 另类小说 | 久久一 | 女人脱下内裤让男人捅 | 考逼网站 | 欧美成人影片一区 |