Generative Adversarial Network系列: Vanilla GAN和Divergence

Shu-Yu Huang
Feb 20, 2021

--

最初的GAN以及最核心的loss

GAN框架

GAN的框架是這樣:輸入一些feature進Generator,輸出generated data (in domain A),在送進Discriminator,判別這些generated data是否跟real data(in domain B)是屬於同個domain的資料,例如都是手寫文字。這些feature跟data可以是各種不同形式包含聲音、影像文字、或者電訊號、還有雜訊等等。
元祖的Vanilla GAN的目標是做目標domain(domain B)的資料生成,以達成data augmentation的效果,為了做randomized augmentation,他的輸入就是使用雜訊而已,希望透過訓練Generator(後面簡稱G網路)及Discriminator(後面簡稱D網路)來增進資料擴增的效果。

而具體訓練方法是這樣:
首先先使G網路生出第一個batch的資料丟進D網路,用D網路去區辨它和真實資料。兩個網路有不同的目標,G網路的目標是騙過D網路讓它無法分辨生成跟真實資料;而D網路的目標就是希望正確區分真實和生成資料。這兩個網路會輪流訓練,train G網路時D網路會freeze住;接下來train D網路時G網路會freeze住。當D網路訓練到能分辨後,G網路會訓練到能欺騙,接著D網路再更精進的練到能分辨,如此往復,會讓Generator生出來的資料越來越像真實資料。

Generator的訓練

在數學上,Generator就是一個random vector/matrix z進入G網路後要在高維空間中產生一個matrix x=G(z)。這個x在圖像生成上可看成某個特定的圖。由G網路的生成母空間中產生這個圖的機率是PG(x),那由原本dataset母空間中挑出這個圖的機率是Pdata(x),Generator會希望PG(x)的分布跟Pdata(x)的分布越接近越好。那就有了以下的empirical risk minimization公式:

意思是希望這個G可使PG,Pdata間的divergence最小。

-李弘毅老師投影片-

這個divergence粗略來講可以說是兩個分布之間的差異,當兩組機率分布差越多,divergence越大,反之則越小。例如下圖這兩個機率密度分布(藍色和綠色),就差很多,那我們會期待有個公式可以算出他們兩個的差異很大。

Discriminator的訓練

另一方面D網路的訓練與G網路相反,是希望PG(x)的分布跟Pdata(x)的分布越不像越好:

那實際上我們minimize binary cross entropy(BCE)就有這樣效果。
BCE是這樣:

中間x是任意可能的圖,y表達是否為真實data,1是true、0是fake。那我們把BCE取負號:

而minimize BCE等於maximize -BCE。在要maximize的情況下,左邊的期望值代表希望data進去會predict 1;右邊的期望值代表希望GAN出來的資料進去會predict 0。
這個目標函式的最佳解為

(推導見Saxena, D., & Cao, J. (2020). Generative Adversarial Networks (GANs): Challenges, Solutions, and Future Directions. arXiv preprint arXiv:2005.00065.)

帶回原式,會變這樣:

其中KL指KL divergence,JS指JS divergence。
上述式子證明最佳化D網路時可以使得D網路等效於計算JS divergence,既然是某個divergence就可以拿來計算兩個分布差異程度。

也就是說,優化過的D網路可以拿來訓練G網路。

而實際上要訓練G網路會把上述式子簡化成為:

意思是希望生成資料進去最好的D可以騙過最好的D網路。

熵是什麼,KL/JS Divergence又是什麼

熵(Entropy)是事件帶來資訊量期望值。我們用氣象預報的例子來說明:
假設今天事件有兩種,一個晴天一個雨天,機率各半。那樣若是今天晴天,則可帶來1 bit的消息,或者說,可以使用一個bit來傳遞這個消息,雨天亦然。那這個傳遞消息使用bit數的期望值就是1,這就是這個機率分布的entropy。

那站在氣象局的角度,今天在看到晴天資訊前,所知道的機率是晴雨各50%,而知道今天晴天以後,就修正為晴天100%、雨天0%
那剛剛計算的bit數在這邊也可以看做表達機率修正比率所需的bit數。

在可能的天氣變多的時候,表達天氣所需要的bit數(entropy),也隨之上漲,

而每種天氣出現機率分布不均時,entropy會小於平均分布。可以想見,如果天氣只會有晴天就不用傳了,所需bit數是0。

在更細的來講可以用frequency encoding的做法來理解,機率高的事件採用比較高優先權的bit來表達,會降低機率分布不均時的bit數期望值。
如下圖:
先看bit0,如果是0則表示晴天,這個出現機率最高,那有很高的機率不需要看到後面的bit。若bit0不是0再來看bit 1,若bit1為0則表示晴時多雲,這第二高機率。若bit1不是0再來看bit 2,若bit2為0是雨天,否則為雷雨。
以上情況代表50%機率只看1個bit、25%機率看2個bits、25%機率看3個bits,平均看1.75個bits,跟上面算的一樣。

剛解釋完Entropy,那cross entropy就發生在實際發生機率和編碼方式不同的時候。
例如在冬天我觀察到雨天比較多那我的模型中雨天佔75%、晴天25%,我依此用較少bit來編碼我傳輸的方式。
可是來到夏天的時候晴雨各半,那我依然用冬天方式來傳資料,這樣會使平均所需的bit數改變。
這個誤用時的平均bit數,也就是以第二種資訊以第一種機率傳輸的平均資訊量。

那KL divergence公式中將log中的被除數拿出來化為減去log可以得到一個相減的式子。表明KL divergence就是Cross entropy減去entropy。
第一,若以上述機率修正比來看,就是傳輸Pdata到PG的修正比所需平均資訊量。
第二,若看作是cross entropy減去entropy的話,就是在兩個分布的資訊量上的差值以PGPG為機率分布的積分,基本上Pdata比PG大的地方會是負的影響,反之則有正向影響。那JS divergence就是從兩分布的均分分布(x點的機率是(Pdata(x)+PG(x))/2)對兩種分布的KL divergence平均值。

用積分來講可以解釋為何後期要用到Wessestain Distance而不用JS divergence。算JS divergence就是要算上述那個面積積分的量,如果兩個分布完全沒交集,在兩個分布形狀不改變的情況下不論離得多近多遠,JS divergence都是一樣值,這樣會導致無法做gradient descent。

--

--

Shu-Yu Huang

AI engineer in Taiwan AI Academy| Former process engineer in TSMC| Former Research Assistance in NYMU| Studying few-shot-learning and GAN