淺談 Softmax 函數

點擊上方“視學算法”,馬上關注


來自 | 知乎 作者 | LinT鏈接丨https://zhuanlan.zhihu.com/p/79585726編輯 | 深度學習這件小事公眾號僅作學術交流,如有侵權,請聯繫刪除

乾貨 | 淺談 Softmax 函數


0. 引言

Softmax函數幾乎是深度學習中的標配了,在人工神經網絡中,幾乎無處不可見softmax函數的身影。可以認為softmax是arg max操作的一種平滑近似。

我將softmax的用途總結為兩種:

  • 分類:給定一系列類別,softmax可以給出某輸入被劃分到各個類別的概率分佈。由於人工智能領域的許多問題都可以抽象成分類問題,所以softmax最廣泛的應用當屬分類;
  • 尋址:由於softmax的輸出是一種概率分佈的形式,可以用它來實現一種軟性的尋址。近幾年逐漸推廣的(軟性)注意力機制就可以認為是一種軟性尋址的方式,根據各個鍵向量與查詢向量的相似度從各個值向量中獲取一定的“信息”。因此使用了softmax的注意力機制也可以用於外部記憶的訪問。

不難發現,在分類問題中,我們也可以使用arg max來找到對應的類別;在尋址問題中,一個直觀的方法也是使用arg max尋找最相似的向量/記憶。但是arg max操作並不具有良好的數學性質,其不可導的性質使其無法直接應用基於梯度的優化方法。因此在分類和尋址兩種用途中,常常都使用softmax函數替代arg max。

基於這兩種用途,softmax可以在人工神經網絡中充當什麼樣的角色,就靠諸君的想象了。這篇文章中,我想簡單、粗淺地探討一下softmax的一些性質與變種。


1. 基本形式

給定一個 維向量,softmax函數將其映射為一個概率分佈。標準的softmax函數 由下面的公式定義[1]:


其中,分母是配分函數(Partition Function),一般簡記為 ,表示所有狀態/值的總和,作為歸一化因子;分子是勢能函數(Potential Function)。

直觀上看,標準softmax函數用一個自然底數 先拉大了輸入值之間的差異,然後使用一個配分將其歸一化為一個概率分佈。在分類問題中,我們希望模型分配給正確的類別的概率接近1,其他的概率接近0,如果使用線性的歸一化方法,很難達到這種效果,而softmax有一個先拉開差異再歸一化的“兩步走”戰略,因此在分類問題中優勢顯著。

事實上,在勢能函數和配分函數中,可以採用的底數不僅僅是自然底數 ,也可以採用一些其他的底數。原則上,任意 都可以作為這裡的底數,越大的底數越能起到“拉開差異”的作用。使用 作為底數時,將產生以下的非標準softmax函數[1]:


其中 是一個實數,正的 常常在機器學習中使用,在信息檢索的工作DSSM中, 就充當了一個平滑因子[2];負的 常常在熱力學系統中使用,由於一些概率圖模型也參考了熱力學的原理,所以在概率圖模型中也常常能見到這種形式,如玻爾茲曼機。


2. 導數與優化

標準softmax具有非常漂亮的導數形式:


這裡導數的推導可以參考 @邱錫鵬 老師的《神經網絡與深度學習》[3]附錄B.2.4的推導。


在分類問題中,softmax函數常常和交叉熵損失函數一起使用,此時交叉熵損失函數 對 的導數,由下面的形式給出:


其中 是真實標籤對應的one-hot編碼向量。這樣的導數在優化時非常方便,實現起來也非常簡單。

由於softmax函數先拉大了輸入向量元素之間的差異,然後才歸一化為一個概率分佈,在應用到分類問題時,它使得各個類別的概率差異比較顯著,最大值產生的概率更接近1,這樣輸出分佈的形式更接近真實分佈。

但是當softmax函數被應用到尋址時,例如注意力機制中,softmax這個拉大元素間差異的過程可能會導致一定的問題。假設輸入向量有唯一的最大值,如果將arg max操作定義為指示最大值的一個one-hot編碼函數 ,在非標準softmax中有[1]:


* 這裡的證明參見下方的補充(A)。

如果將非標準softmax的 融入到輸入中,則容易看出:當輸入的方差/數量級較大時,softmax的輸出會比較接近一個one-hot向量。根據式 ,其導數的兩個項會比較接近,導致導數變成0矩陣。這也就導致了梯度彌散的問題,不利於優化,具體討論可以參考我先前的一篇回答[4]。這也是為什麼注意力機制中常常使用縮放點積形式的注意力。

插一句題外話,開篇提到softmax是arg max操作的一種平滑近似,而針對max操作的近似,其實有一個LogSumExp[5]操作(也叫作softmax),其導數形式就是softmax函數,是不是很有趣呢?


3. Softmax的解釋

Softmax可以由三個不同的角度來解釋。從不同角度來看softmax函數,可以對其應用場景有更深刻的理解。

3.1 是arg max的一種平滑近似[1]

前面提到過,softmax可以當作arg max的一種平滑近似,與arg max操作中暴力地選出一個最大值(產生一個one-hot向量)不同,softmax將這種輸出作了一定的平滑,即將one-hot輸出中最大值對應的1按輸入元素值的大小分配給其他位置。如式 所示,當底數增大時,softmax逐漸收斂為arg max操作。

在機器學習應用中,我們往往不(直接)需要一個arg max的操作,這時候顯然數學性質更好、更容易優化的softmax就是我們的第一選擇。

3.2 歸一化產生一個概率分佈

Softmax函數的輸出符合指數分佈族的基本形式


其中 。

不難理解,softmax將輸入向量歸一化映射到一個類別概率分佈,即 個類別上的概率分佈(前文也有提到)。這也是為什麼在深度學習中常常將softmax作為MLP的最後一層,並配合以交叉熵損失函數(對分佈間差異的一種度量)。

3.3 產生概率無向圖的聯合概率

從概率圖模型的角度來看,softmax的這種形式可以理解為一個概率無向圖上的聯合概率。因此你會發現,條件最大熵模型與softmax迴歸模型實際上是一致的,諸如這樣的例子還有很多。由於概率圖模型很大程度上借用了一些熱力學系統的理論,因此也可以從物理系統的角度賦予softmax一定的內涵。


4. Softmax的改進與變種

Softmax函數是一種簡單優美的歸一化方法,但是它也有其固有的缺陷。直觀上看,當應用到實際問題時,其最大的問題就在於配分函數的計算:當類別的數量很多時,配分函數的計算就成為了推斷和訓練時的一個瓶頸。在自然語言處理中,類別常常對應詞彙表中的所有詞彙,這個數量之大可見一斑,如果直接採用softmax計算方法,計算效率會非常低。因此一般採用一些方法改進softmax函數,加速模型訓練。這裡列舉幾個自然語言處理中的經典改進/變種[3]:

  • 層次化softmax:將扁平的 分類問題轉化為層次化的分類問題。將詞彙表中的詞分組組織成(二叉)樹形結構,這樣一個 分類問題,可以轉化為多層的二分類問題,從而將求和的次數由 降低到了樹的深度級別。這裡可以使用的一個方法是,按照詞彙的頻率求和編碼Huffman樹,從而進一步減少求和操作的計算次數。
  • 採樣方法:使用梯度上升優化時,softmax的導數涉及到一次配分函數的計算和一次所有詞彙上的softmax對詞彙的梯度的期望,這兩個計算都可以用採樣方法來近似,比如重要性採樣,這樣計算次數由 減少為採樣樣本數 的級別。這種方法的性能很受採樣策略的影響,以重要性採樣方法為例,其效果就依賴於提議分佈的選取;採樣樣本數 的選取也需要考慮精度和訓練效率的折衷。
  • 噪聲對比估計(NCE):將密度估計問題轉換為兩類分類問題(區分噪聲與真實數據),從而降低計算複雜度。其中配分函數被替換為了一個可學習的參數,這樣NCE方法能促使輸入的未歸一化向量自己適應為一個歸一化的、接近真實分佈的分佈向量。由於不再需要計算配分函數,訓練效率大大提升。這種對比學習思想在深度學習中也十分常見。



5. 總結

前面簡單討論了softmax的性質、解釋與變種,從現在來看,似乎softmax已經是神經網絡中的一根老油條了。Softmax還有哪些可以挖掘的地方呢?作為一個菜鳥,只好先把這個問題拋給諸位了。


乾貨 | 淺談 Softmax 函數


A. 補充

突然覺得式 直接放進來有一點太唐突了,覺得還是要簡單證明一下,算是完善一下之前的回答。

假設固定輸入 不變,變化參數 ,假設輸入 中有唯一的最大值 ,則有:


不妨設 ,可以分類討論一下:

1. 當 ,則 ,此時


2. 當 ,則 ,此時結合 ,可以得到即,當 取無窮大時,非標準softmax的輸出收斂到一個one-hot向量,其中最大輸入對應的輸出值是1,其他輸出是0。當輸入向量中有多個最大值,可以更寬泛地定義arg max操作,使其輸出為一個 維向量,其中 個最大值下標對應的輸出為 ,其它輸出為0。這時類似上面的證明,很容易驗證,當 時,softmax依然仍然收斂為arg max操作。


參考鏈接:


[1] Softmax function


https://en.wikipedia.org/wiki/Softmax_function


[2] Huang et al., Learning Deep Structured Semantic Models for Web Search using Clickthrough Data


https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf


[3] 邱錫鵬,《神經網絡與深度學習》


https://nndl.github.io


[4] transformer中的attention為什麼scaled? - lintongmao的回答 - 知乎


https://www.zhihu.com/question/339723385/answer/782509914


[5] LogSumExp https://en.wikipedia.org/wiki/LogSumExp


分享到:


相關文章: