詳解機器學習損失函數之交叉熵

今天這篇文章和大家聊聊機器學習領域的


我在看paper的時候發現對於交叉熵的理解又有些遺忘,複習了一下之後,又有了一些新的認識。故寫下本文和大家分享。


熵這個概念應用非常廣泛,我個人認為比較經典的一個應用是在熱力學當中,反應一個系統的混亂程度。根據熱力學第二定律,一個孤立系統的熵不會減少。比如一盒乒乓球,如果把盒子掀翻了,乒乓球散出來,它的熵增加了。如果要將熵減小,那麼必須要對這個系統做功,也就是說需要有外力來將散落的乒乓球放回盒子裡,否則乒乓球的分佈只會越來越散亂。


開創了信息論的香農大佬靈光一閃,既然自然界微觀宏觀的問題都有熵,那麼信息應該也有。於是他開創性地將熵這個概念引入信息論領域,和熱力學的概念類似,信息論當中的熵指的是信息量的混亂程度,也可以理解成信息量的大小。


信息量


舉個簡單的例子,以下兩個句子,哪一個句子的信息量更大呢?


  1. 我今天沒中彩票
  2. 我今天中彩票了


從文本上來看,這兩句話的字數一致,描述的事件也基本一致,但是顯然第二句話的信息量要比第一句大得多,原因也很簡單,因為中彩票的概率要比不中彩票低得多。


相信大家也都明白了,一個信息傳遞的事件發生的概率越低,它的信息量越大。我們用對數函數來量化一個事件的信息量:


詳解機器學習損失函數之交叉熵


因為一個事件發生的概率取值範圍在0到1之間,所以log(p(X))的範圍是負無窮到0,我們加一個負號將它變成正值。畫成函數大概是下面這個樣子:

詳解機器學習損失函數之交叉熵


信息熵


我們上面的公式定義的是信息量,但是這裡有一個問題,我們定義的只是事件X的一種結果的信息量。對於一個事件來說,它可能的結果可能不止一種。我們希望定義整個事件的信息量,其實也很好辦,我們算下整個事件信息量的期望即可,這個期望就是信息熵。


期望的公式我們應該都還記得:


詳解機器學習損失函數之交叉熵


我們套入信息量的公式可以得到信息熵H(x):


詳解機器學習損失函數之交叉熵



相對熵(KL散度)


在我們介紹相對熵之前,我們先試著思考一個問題,我們為什麼需要相對熵這麼一個概念呢?


原因很簡單,因為我們希望測量我們訓練出來的模型和實際數據的差別,相對熵的目的就是為了評估和驗證模型學習的效果。也就是說相對熵是用來衡量兩個概率分佈的差值的,我們用這個差值來衡量模型預測結果與真實數據的差距。明白了這點之後,我們來看相對熵的定義:


詳解機器學習損失函數之交叉熵


如果把 {x} 看成是一個事件的所有結果,那 x_i可以理解成一個事件的一個結果。那麼所有的

P(x_i)和Q(x_i)就可以看成是兩個關於事件X的概率分佈。


P(x_i)樣本真實的分佈,我們可以求到。而 Q(x_i)是我們模型產出的分佈。KL散度越小,表示這兩個分佈越接近,說明模型的效果越好。


光看上面的KL散度公式可能會雲裡霧裡,不明白為什麼能反應P和Q兩個分佈的相似度。因為這個公式少了兩個限制條件:


詳解機器學習損失函數之交叉熵


對於單個 P(x_i)來說,當然Q(x_i)越大 P(x_i)log(P(x_i) / Q(x_i))越小。但由於所有的Q(x_i)的和是1,當前的i取的值大了,其他的i取的值就要小。


我們先來直觀地感受一下,再來證明。


我們假設 x_i只有0和1兩個取值,也就是一個事件只有發生或者不發生。我們再假設

P(x=0)=P(x=1)=0.5,我們畫一下

P(x_i)log(P(x_i) / Q(x_i))的圖像:


詳解機器學習損失函數之交叉熵


和我們預料的一樣,函數隨著 Q(x_i)的遞增而遞減。但是這只是一個x的取值,別忘了,我們相對熵計算的是整個分佈,那我們加上另一個x的取值會發生什麼呢?

詳解機器學習損失函數之交叉熵


從函數圖像上,我們很容易看出,當Q(x)=0.5的時候,KL散度取最小值,最小值為0。我們對上面的公式做一個變形:


詳解機器學習損失函數之交叉熵


這個式子左邊


詳解機器學習損失函數之交叉熵


其實就是-H(X),對於一個確定的事件X來說,它的信息熵是確定的,也就是說

H(X)是一個常數,P(x_i)也是常數。log函數是一個凹函數,-log是凸函數。我們把

P(x_i)當成常數的話,可以看出


詳解機器學習損失函數之交叉熵


是一個凸函數。


凸函數有一個jensen不等式:


詳解機器學習損失函數之交叉熵


也即:變量期望的函數值大於變量函數值的期望,有點繞口令,看不明白沒關係,可以通過下圖直觀感受:

詳解機器學習損失函數之交叉熵

我們利用這個不等式試著證明:


詳解機器學習損失函數之交叉熵


首先,我們對原式進行變形:


詳解機器學習損失函數之交叉熵


然後我們利用不等式:


詳解機器學習損失函數之交叉熵


所以KL散度是一個非負值,但是為什麼當P和Q相等時,能取到最小值呢?我們單獨拿出右邊


詳解機器學習損失函數之交叉熵


我們令


詳解機器學習損失函數之交叉熵


我們探索C(P, P) - C(P, Q)的正負性來判斷P和Q的關係。


詳解機器學習損失函數之交叉熵


因為log(x)是凸函數,所以我們利用jensen不等式,可以得到:


詳解機器學習損失函數之交叉熵


我們帶入


詳解機器學習損失函數之交叉熵


詳解機器學習損失函數之交叉熵


所以 C(P, P) - C(P, Q) <= 0, 即 C(P, P) <= C(P, Q),當且僅當 P=Q 時等號成立。


交叉熵


通過上面一系列證明,我們可以確定,KL散度可以反映兩個概率分佈的距離情況。由於P是樣本已知的分佈,所以我們可以用KL散度反映Q這個模型產出的結果與P的距離。距離越近,往往說明模型的擬合效果越好,能力越強。


我們把上面定義的C(P, Q)帶入KL散度的定義,會發現:


詳解機器學習損失函數之交叉熵


對於一個確定的樣本集來說,P(x_i)是定值,所以我們可以拋開左邊


詳解機器學習損失函數之交叉熵


不用管它,單純來看右邊。右邊我們剛剛定義的C(P, Q)其實就是交叉熵。


說白了,交叉熵就是KL散度去除掉了一個固定的定值的結果。KL散度能夠反映P和Q分佈的相似程度,同樣交叉熵也可以,而且交叉熵的計算相比KL散度來說要更為精簡一些。


如果我們只考慮二分類的場景,那麼C(P, Q) = -P(x=0)log(Q(x=0)) - P(x=1)log(Q(x=1))


由於P(x=0)結果已知,並且:

P(x=0) + P(x=1)=1, Q(x=0) + Q(x=1)=1。我們令 P(x=0) = y, Q(x=0)=y_hat


所以上式可以變形為:


詳解機器學習損失函數之交叉熵


這個式子就是我們在機器學習書上最常見到的二分類問題的交叉熵的公式在信息論上的解釋,我們經常使用,但是很少會有資料會將整個來龍去脈完整的推導一遍。對於我們算法學習者而言,我個人覺得只有將其中的細節深挖一遍,才能真正獲得提升,才能做到知其然,並且知其所以然。理解了這些,我們在面試的時候才能真正做到遊刃有餘。


當然,到這裡其實還沒有結束。仍然存在一個問題,可以反映模型與真實分佈距離的公式很多,為什麼我們訓練模型的時候單單選擇了交叉熵呢,其他的公式不行嗎?為什麼呢?


分析


我們來實際看個例子就明白了,假設我們對模型:


詳解機器學習損失函數之交叉熵


選擇MSE(均方差)作為損失函數。假設對於某個樣本x=2,y=0,

θ_0=2, θ_1 = 1 那麼θX=4,此時 σ(θX) = 0.98


此時


詳解機器學習損失函數之交叉熵


我們對它求關於θ的偏導:


詳解機器學習損失函數之交叉熵


所以如果我們通過梯度下降來學習的話,


詳解機器學習損失函數之交叉熵


這個式子看起來很正常,但是隱藏了一個問題,就是我們這樣計算出來的梯度實在是太小了。通過梯度下降去訓練模型需要消耗大量的時間才能收斂。


如果我們將損失函數換成交叉熵呢?


我們回顧一下交叉熵求梯度之後的公式:


詳解機器學習損失函數之交叉熵


我們帶入上面具體的值,可以算出來如果使用交叉上來訓練,我們算出來的梯度為1.96,要比上面算出來的0.04大了太多了。顯然這樣訓練模型的收斂速度會快很多,這也是為什麼我們訓練分類模型採用交叉熵作為損失函數的原因。


究其原因是因為如果我們使用MSE來訓練模型的話,在求梯度的過程當中免不了對sigmoid函數求導。而正是因為sigmoid的導數值非常小,才導致了我們梯度下降的速度如此緩慢。那麼為什麼sigmoid函數的導數這麼小呢?我們來看下它的圖像就知道了:

詳解機器學習損失函數之交叉熵


觀察一下上面這個圖像,可以發現當x的絕對值大於4的時候,也就是圖像當中距離原點距離大於4的兩側,函數圖像就變得非常平緩。而導數反應函數圖像的切線的斜率,顯然這些區域的斜率都非常小,而一開始參數稍微設置不合理,很容易落到這些區間裡。那麼通過梯度下降來迭代自然就會變得非常緩慢。


所以無論是機器學習還是深度學習,我們一般都會少會對sigmoid函數進行梯度下降。在之前邏輯迴歸的文章當中,我們通過極大似然推導出了交叉熵的公式,今天我們則是利用了信息論的知識推導了交叉熵的來龍去脈。兩種思路的出發點和思路不同,但是得到的結果卻是同樣的。關於這點數學之美當中給出瞭解釋,因為信息論是更高維度的理論,它反映的其實是信息領域最本質的法則。就好像物理學當中公式千千萬,都必須要遵守能量守恆、物質守恆一樣。機器學習也好,深度學習也罷,無非是信息領域的種種應用,自然也逃不脫信息論的框架。不知道看到這裡,你有沒有一點豁然開朗和一點震撼的感覺呢?


今天的文章就到這裡,公式比較多,但推導過程並不難,希望大家不要被嚇住,能冷靜看懂。如果覺得有所收穫,請順手點個關注或轉發吧,你們的支持是我最大的動力。


分享到:


相關文章: