詳細介紹 Beam Search 及其優化方法

Beam Search 是一種受限的寬度優先搜索方法,經常用在各種 NLP 生成類任務中,例如機器翻譯、對話系統、文本摘要。本文首先介紹 Beam Search 的相關概念和得分函數優化方法,然後介紹一種新的 Best-First Beam Search 方法,Best-First Beam Search 結合了優先隊列和 A* 啟發式搜索方法,可以提升 Beam Search 的速度。

1.Beam Search

在生成文本的時候,通常需要進行解碼操作,貪心搜索 (Greedy Search) 是比較簡單的解碼。假設要把句子 "I love you" 翻譯成 "我愛你"。則貪心搜索的解碼過程如下:

詳細介紹 Beam Search 及其優化方法

貪心搜索解碼過程

貪心搜索每一時刻只選擇當前最有可能的單詞,例如在預測第一個單詞時,"我" 的概率最大,則第一個單詞預測為 "我";預測第二個單詞時,"愛" 的概率最大,則預測為 "愛"。貪心搜索具有比較高的運行效率,但是每一步考慮的均是局部最優,有時候不能得到全局最優解。

Beam Search 對貪心搜索進行了改進,擴大了搜索空間,更容易得到全局最優解。Beam Search 包含一個參數 beam size k,表示每一時刻均保留得分最高的 k 個序列,然後下一時刻用這 k 個序列繼續生成。下圖展示了 Beam Search 的過程,對應的 k=2:

詳細介紹 Beam Search 及其優化方法

Beam Search 解碼過程

在第一個時刻,"我" 和 "你" 的預測分數最高,因此 Beam Search 會保留 "我" 和 "你";在第二個時刻的解碼過程中,會分別利用 "我" 和 "你" 生成序列,其中 "我愛" 和 "你愛" 的得分最高,因此 Beam Search 會保留 "我愛" 和 "你愛";以此類推。

Beam Search 的偽代碼如下:

詳細介紹 Beam Search 及其優化方法

Beam Search 偽代碼

每一步 Beam Search 都會維護一個 k-最大堆 (即偽代碼中的 B),然後用上一步的 k 個最高得分的序列生成新序列,放入最大堆 B 裡面,選出當前得分最高的 k 個序列。偽代碼中的 score 是得分函數,通常是對數似然:

詳細介紹 Beam Search 及其優化方法

對數似然得分函數

2.Beam Search 得分函數優化

2.1 length normalization 和 coverage penalty

這個方法是論文《Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation》中提出的。包含兩個部分 length normalization 和 coverage penalty。

在採用對數似然作為得分函數時,Beam Search 通常會傾向於更短的序列。因為對數似然是負數,越長的序列在計算 score 時得分越低 (加的負數越多)。在得分函數中引入 length normalization 對長度進行歸一化可以解決這一問題。

coverage penalty 主要用於使用 Attention 的場合,通過 coverage penalty 可以讓 Decoder 均勻地關注於輸入序列 x 的每一個 token,防止一些 token 獲得過多的 Attention。

把對數似然、length normalization 和 coverage penalty 結合在一起,可以得到新的得分函數,如下面的公式所示,其中 lp 是 length normalization,cp 是 coverage penalty:

詳細介紹 Beam Search 及其優化方法

length normalization 和 coverage penalty

2.2 互信息得分函數

在對話模型中解碼器通常會生成一些過於通用的回覆,回覆多樣性不足。例如不管用戶問什麼,都回復 "我不知道"。為了解決對話模型多樣性的問題,論文《A Diversity-Promoting Objective Function for Neural Conversation Models》中提出了採用最大化互信息 MMI (Maximum Mutual Information) 作為目標函數。其也可以作為 Beam Search 的得分函數,如下面的公式所示。

詳細介紹 Beam Search 及其優化方法

互信息得分函數

最大化上面的得分函數可以提高模型回覆的多樣性,即需要時 p(y|x) 遠遠大於 p(y)。這樣子可以為每一個輸入 x 找到一個專屬的回覆,而不是通用的回覆。

3.更高效的 Beam Search

論文《Best-First Beam Search》關注於提升 Beam Search 的搜索效率,提出了 Best-First Beam Search 算法,Best-First Beam Search 可以在更短時間內得到和 Beam Search 相同的搜索結果。論文中提到 Beam Search 的時間和得分函數調用次數成正比,如下圖所示,因此作者希望通過減少得分函數的調用次數,從而提升 Beam Search 效率。

詳細介紹 Beam Search 及其優化方法

Beam Search 運行時間和得分函數調用次數的關係

Best-First Beam Search 使用了優先隊列並定義新的比較運算符,從而可以減少調用得分函數的次數,更快停止搜索。另外 Best-First Beam Search 也可以結合 A* 搜索算法,在計算得分時加上一些啟發函數,對於 A* 不瞭解的讀者可以參考下之前的文章《A* 路徑搜索算法》

3.1 減少調用得分函數的次數

Beam Search 使用的得分函數是對數似然 log p,log p 是一個負數,則 Beam Search 的得函數是一個關於序列長度 t 單調遞減的函數,即 t 越得分越低。Best-First Beam Search 就是利用這一特性,不去搜索那些必定不是最大得分的路徑。

傳統的 Beam Search 每一個時刻 t 均會保留 k 個最大得分的序列,然後對於這 k 個序列分別生成 t+1 時刻的序列。但是其中有一些搜索是沒有必要的,只需要一直搜索當前得分最大的序列 (如果有兩個得分最大的序列,則搜索更短的那個序列) ,直到得分最大的序列已經結束 (即生成結束符)。

3.2 通用的 Beam Search 偽代碼

作者給出了一種通用的 Beam Search 偽代碼,偽代碼包括 4 種可替換的關鍵成分。傳統的 Beam Search、Best-First Beam Search 和 A* Beam Search 都可以通過修改偽代碼的可替換成分得到。偽代碼如下:

詳細介紹 Beam Search 及其優化方法

通用的 Beam Search 偽代碼

偽代碼包括 4 個可替換部分:

  • 粉紅色部分為優先隊列 Q 的比較函數 comparator,通過 comparator 對比兩個預測序列的優先級。預測序列用 表示,y 是序列,s 是序列對應的得分。
  • 紫色部分是停止搜索的條件。
  • 綠色部分是 beam size k,POPS 用於統計長度為 |y| 的序列個數,如果長度為 |y| 的序列超過 k 個,就不進行處理 (和傳統 Beam Search 保留 k 個是一樣的意思)。
  • 黃色部分是啟發函數,A* Beam Search 才會使用。

通過修改這 4 個部分,就可以分別得到 Beam Search、Best-First Beam Search 和 A* Beam Search,具體定義如下圖所示。圖中第一行的 3 種均是 Beam Search 方法,第二行的 3 種是傳統的搜索方法 (即 k=∞)。我們首先看一下 Beam Search,Beam Search 的 comparator 如下:

詳細介紹 Beam Search 及其優化方法

不同 Beam Search 生成的方式

3.3 實驗結果

詳細介紹 Beam Search 及其優化方法

Best-First Beam Search 實驗結果

可以看到 Best-First Beam Search 可以減少得分函數的調用次數,k 值越大能夠減少的次數越多。

4.參考文獻

Best-First Beam Search

Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation

A Diversity-Promoting Objective Function for Neural Conversation Models


分享到:


相關文章: