GAN的入門與實踐
點擊上方“小白學視覺”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達
引言
生成對抗網(wǎng)絡(Generative Adversarial Nets,GAN)是由open ai研究員Good fellow在2014年提出的一種生成式模型,自從提出后在深度學習領域收到了廣泛的關注和研究。目前,深度學習領域的圖像生成,風格遷移,圖像變換,圖像描述,無監(jiān)督學習,甚至強化學習領域都能看到GAN 的身影。GAN主要針對的是一種生成類問題。目前深度學習領域可以分為兩大類,其中一個是檢測識別,比如圖像分類,目標識別等等,此類模型主要是VGG, GoogLenet,residual net等等,目前幾乎所有的網(wǎng)絡都是基于識別的;另一種是圖像生成,即解決如何從一些數(shù)據(jù)里生成出圖像的問題,生成類模型主要有深度信念網(wǎng)(DBN),變分自編碼器(VAE)。而某種程度上,在生成能力上,GAN遠遠超過DBN、VAE。經(jīng)過改進后的GAN足以生成以假亂真的圖像。本文將首先介紹一些GAN 的原理和公式推導,另外會詳細給出GAN生成圖像的Tensorflow的實現(xiàn),基于python語言。
GAN主要解決的是生成類問題,即如何從一段任意的隨機數(shù)中生成圖像。假設給定一段100維的向量X{x1, x2,…, x100 }作為網(wǎng)絡的輸入,其中x是產(chǎn)生的隨機數(shù),一般按照高斯分布或者均勻分布產(chǎn)生,GAN通過對抗訓練的方式,可以生成清晰的圖像,這個過程是通過GAN不斷模擬訓練集中圖像的像素分布來實現(xiàn)的??赐晗挛腉AN的原理后或許你會對這個過程有一個清晰的認識。

圖1?
首先,附上一張GAN的網(wǎng)絡流程圖,如圖1所示。不同于以往的判別網(wǎng)絡模型,GAN包括兩個網(wǎng)絡模型,一個生成模型G(generator)和一個判別模型D(discriminator),其中D就是識別檢測類模型中經(jīng)常使用的網(wǎng)絡。GAN的大概流程是,G以隨機噪聲作為輸入,生成出一張圖像G(z),暫且不管生成質(zhì)量多好,然后D以G(z)和真實圖像x作為輸入,對G(z)和x做一個二分類,檢測誰是真實圖像誰是生成的假圖像。D的輸出是一個概率值,比如G(z)作為輸入時D輸出0.15,那么代表D認為G(z)有15%的概率是真圖像。然后G和D會根據(jù)D輸出的情況不斷改進自己,G提高G(z)和x的相似度,盡可能的欺騙D,而D則會通過學習盡可能的不被G欺騙。二者相當于是做一個極大極小的博弈過程,稱為零和博弈。可以用一個簡單的例子描述他們之間的過程,我們把G想象成制造假幣的團伙,視D為警察,G不斷產(chǎn)生假幣,而D任務就是從真錢幣中分辨出G的假幣,剛開始時,G沒有經(jīng)驗,制造的假幣太假,D很容易就能分辨出來,所以G不斷改進自己的技術,產(chǎn)生的假幣越來越真實,D可能就沒有那么容易判別出真假了,所以D也根據(jù)自己的情況不斷改進自己,經(jīng)過很多次這樣的循環(huán)之后,G產(chǎn)生的假幣足以以假亂真了,D很難分出真假。對應到圖像生成上,此時G足以生成出一般的分類神經(jīng)網(wǎng)絡分辨不出真假的圖像了,G從而獲得了生成圖像的能力。
與傳統(tǒng)神經(jīng)網(wǎng)絡訓練不一樣的且有趣的地方,就是訓練生成器的方法不同,生成器參數(shù)的更新來自于D的反傳梯度。生成器一心想要“騙過”判別器。使用博弈理論分析技術,可以證明這里面存在一種納什均衡。

這里就是他們的損失函數(shù)定義,實際上是一個交叉熵,判別器的目的是盡可能的令D(x)接近1,令D(G(z))接近0,所以D主要是最大化上面的損失函數(shù),G恰恰相反,他主要是最小化上述損失函數(shù)。
訓練過程:

(圖2)
圖2展示了GAN訓練的偽代碼,首先在迭代次數(shù)范圍內(nèi),首先對z和x采樣一個批次,獲得他們的數(shù)據(jù)分布,然后通過隨機梯度下降的方法先對D做k次更新,之后對G做一次更新,這樣做的主要目的是保證D一直有足夠的能力去分辨真假。實際在代碼中我們可能會多更新幾次G只更新一次D,不然D學習的太好,會導致訓練前期發(fā)生梯度消失的問題。
在求平衡點之前,我們先做一個數(shù)學假設,即G固定情況下D的最優(yōu)形式,然后根據(jù)D的最優(yōu)形式再去觀察G最小化損失函數(shù)的問題。
假設在G固定的條件下,并將損失函數(shù)化為如下簡單形式:

D的目標就是最大化L,我們可以通過對L求導,并令導數(shù)為0,計算出L取最大值時y的取值如下:

所以,換為原來的式子D的最優(yōu)解形式為:

到這里我們得出了結論,當G固定時,D的最優(yōu)形式是上面形式。
接下來我們求一下D最優(yōu)時,G最小化損失函數(shù)到什么形式才能達到二者相互博弈的平衡點。
帶入到損失函數(shù)里面后,損失函數(shù)可以寫為如下形式:

這時觀察到,上面式子仍然是一個交叉熵也稱KL散度的形式,KL散度通常用來衡量分布之間的距離,它是非對稱的。同樣還有另一個衡量數(shù)據(jù)分布距離的散度--JS散度,他們之間有如下關系。

不過JS散度有一個很重要的性質(zhì)就是總是大于等于0的,當且僅當 P1=P2上面的式子取得最小值0,
所以我們可以將C(G)寫成JS散度的形式:

也即是當且僅當Pg=Pdata時,C(G)取得最小值-log(4),也即是D最優(yōu)時,G能將損失函數(shù)最小化到-log(4),最小點處Pg=Pdata。即真實數(shù)據(jù)的分布和生成數(shù)據(jù)的分布相等。
分析到這里,直觀上也很好理解了,Pg=Pdata意味著此時D恰好等于0.5,就是D有一半的概率認為D(G(z))是真的數(shù)據(jù),有一半概率認為是假的數(shù)據(jù),這不就和猜硬幣正反面一樣嘛。也說明了此時G生成的數(shù)據(jù)足以以假亂真。
到這里,GAN的原理和數(shù)學推導就介紹完了,理論上說明了GAN只要循規(guī)蹈矩的訓練,G就可以完美的模擬數(shù)據(jù)分布并生成真實的圖像,但是我們做數(shù)學推導的時候為了證明方便做了一些假設,實際上并不是這樣,GAN存在訓練困難、梯度消失、模式崩潰的問題,這些問題在這里不做重點介紹。
首先,建立一個train.py文件,在文件里建立一個名為Train的類,在類的初始化函數(shù)里進行一些初始化:

Self.build_model()函數(shù)用來存放構建流圖部分的代碼,下面會介紹,其他初始化的都是一些簡單的參數(shù)。
下面先介紹生成器和判別器的網(wǎng)絡:

生成器傳進去三個參數(shù),分別是名字,輸入數(shù)據(jù),和一個bool型狀態(tài)變量reuse,用來表示生成器是否復用,reuse=True代表網(wǎng)絡復用,F(xiàn)alse代表不復用。
生成器一共包括1個全連接層和4個轉(zhuǎn)置卷積層,每一層后面都跟一個batchnorm層,激活函數(shù)都選擇relu。其中fc(),deconv2d()函數(shù)和bn()函數(shù)都是我們封裝好的函數(shù),代表全鏈接層,轉(zhuǎn)制卷積層,和歸一化層,其形式如下:

全連接層fc的輸入?yún)?shù)value指輸入向量,output_shape指經(jīng)過全連接層后輸出的向量維度,比如我們生成器這里噪聲向量維度是128,我們輸出的是4*4*8*64維。

其中Ksize指卷積核的大小,outshape指輸出的張量的shape,sted是一個bool類型的參數(shù),表示用不同的方式初始化參數(shù)
bn()函數(shù)我是直接放在了train的類里面,其形式如下:

我們都希望權重都能初始化到一個比較好的數(shù),所以這里我沒有直接用固定方差的高斯分布去初始化權重,而是根據(jù)每一層的輸入輸出通道數(shù)量的不同計算出一個合適的方差去做初始化。同理,我們還封裝了卷積操作,其形式如下:

好了,目前已經(jīng)介紹了生成器的結構和一些基本函數(shù),下面來介紹一下判別網(wǎng)絡,其代碼如下所示:

與生成器不同的是,我們使用leakrelu作為激活函數(shù),

這些函數(shù)的定義都是放在了layer.py文件里,


這里有兩個GAN可供選擇,DCGAN 和WGAN-GP,他們唯一不同的地方是損失函數(shù)的計算不同,網(wǎng)絡結構都是一樣的,二者都是GAN的改進版,WGAN-GP效果好更好一些,這里我們使用WGAN-GP。DCGAN訓練的時候容易遇到訓練不穩(wěn)定的問題。
?
到這里我們已經(jīng)介紹完了所有的初始化過程,接下來就是訓練數(shù)據(jù)的提取和網(wǎng)絡的訓練部分了,訓練數(shù)據(jù)我們使用cele名人數(shù)據(jù)集,一共20萬張圖像左右,數(shù)據(jù)集里的圖像size并不是很一致,我們可以使用一小段代碼把圖像的人臉截取下來,并resize到64*64大小。
代碼如下:

把數(shù)據(jù)集下載下來后解壓到img_align_celeba文件夾里面,然后運行face_detec.py就可以了,截取下來的圖像會放到64_crop文件夾里,本來有20萬張圖像的,截取過后就剩15萬張了。
?
下面就是訓練部分了,首先是讀取數(shù)據(jù),load_data()函數(shù)每次會讀取一個batch_size的數(shù)據(jù)作為網(wǎng)絡的輸入,在訓練過程中,我們選擇訓練一次D訓練兩次G,而不是訓練多次D之后訓練一次G,不然容易發(fā)生訓練不穩(wěn)定的問題,因為D總是學的太好,很容易就判別出真假,所以導致G不論怎么改進都沒有用,有些太打擊G的造假積極性了。

Plot()函數(shù)會每訓練100步后繪出網(wǎng)絡loss的變化圖像,是另外封裝的函數(shù)
同時我們選擇每訓練400步生成一張圖像,看一下生成器的效果。
load_data()函數(shù)我們并沒有使用隊列或者轉(zhuǎn)化為record文件讀取,這樣的方式肯定會快一些,讀取圖像我們使用scipy.misc 來讀取,
具體是import scipy.misc as scm

可以看到,我們首先對所有的圖像做一個排序,返回一個列表,列表里存放的是每個圖像的位置索引,這樣做就是每次將一個batch_size的數(shù)據(jù)讀到了內(nèi)存里,讀取的數(shù)據(jù)做了一個歸一化操作,我們選擇歸一化到[-0.5,+0.5]。
?
接下來就是展示結果的時候了,其中訓練過程loss的變化如下所示:


由圖可見,經(jīng)過一次比較大的震蕩之后,網(wǎng)絡就收斂的比較好了。
接下來是展示生成結果了:
我測試的時候設置了bach_size是16:
訓練1epoch的時候是這樣子的:

訓練一段時間后:

再往后訓練效果看上去反而差了一些,而且明顯沒有學習到眼鏡的特征(最后一行第二個)估計是數(shù)據(jù)集里眼鏡比較少,GAN學習不到足夠的特征,眼睛鼻子嘴巴學習的還是很好的。

訓練失敗的結果:

下面談一談我訓練GAN的感受,GAN是在是太難訓練了,即使是使用WGAN,WGAN-GP,還是遇到了訓練困難的問題,以上這些結果都是我做了好幾次實驗得出來的結果,有些實驗中間得到的生成結果其實是慘不忍睹的,就像是下面這樣,我總結了一部分原因,一個原因是網(wǎng)絡結構太簡單,我本次使用的網(wǎng)絡是幾年前流行的DCGAN的網(wǎng)絡結構,有很大的改進空間,現(xiàn)在基本上用的不多了,我也試了BEGAN,不得不說BEGAN是真好訓練,只要寫好代碼就讓他自己跑去吧,基本上不會出問題,而且效果還很好;另一個原因是優(yōu)化器的選擇和學習率等超參數(shù)的設置。設置好的超參數(shù)對GAN的訓練是很有幫助的,至于優(yōu)化器,盡量不要選擇SGD,因為GAN的平衡點是一個鞍點,鞍點附近梯度幾乎為0,使用梯度的優(yōu)化方法很難收斂到最優(yōu)點,另外就是SGD訓練震蕩,很容易引起訓練不穩(wěn)定。理論上是這樣,實際的問題比這復雜的多。
交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

