指針網絡 Pointer Network

傳統的 Seq2Seq 模型中 Decoder 輸出的目標數量是固定的,例如翻譯時 Decoder 預測的目標數量等於字典的大小。這導致 Seq2Seq 不能用於一些組合優化的問題,例如凸包問題,三角剖分,旅行商問題 (TSP) 等。Pointer Network 可以解決輸出字典大小可變的問題,Pointer Network 的輸出字典大小等於 Encoder 輸入序列的長度並修改了 Attention 的方法,根據 Attention 的值從 Encoder 的輸入中選擇一個作為 Decoder 的輸出。

1.Pointer Network

Seq2Seq 模型是一種包含 Encoder 和 Decoder 的模型,可以將一個序列轉成另外一個序列。但是 Seq2Seq 模型的預測輸出目標大小是固定的,對於一些輸出目標大小會變的情況,例如很多組合優化問題。

組合優化問題的輸出目標的數量依賴於輸入序列的長度,例如旅行商問題中包含5個城市 (1, 2, 3, 4, 5),輸出預測的時候目標數量為 5。Pointer Network 改變了傳統 Attention 的方式,從而可以用於這些組合優化的問題,Pointer Network 在預測輸出時會根據 Attention 得到輸入序列中每一個城市的概率 (即輸出從輸入中選擇)。

傳統 Attention

傳統的 Attention 會根據 Attention 值融合 Encoder 的每一個時刻的輸出,然後和 Decoder 當前時刻的輸出混在一起再預測輸出。如下面的公式所示,e 表示 Encoder 的輸出,d 表示 Decoder 的輸出,Wv 都是可以學習的參數。

指針網絡 Pointer Network

Seq2Seq 的 attention

Pointer Network 的 Attention

Pointer Network 計算 Attention 值之後不會把 Encoder 的輸出融合,而是將 Attention 作為輸入序列 P 中每一個位置輸出的概率。

指針網絡 Pointer Network

Pointer Network 的 Attention

Pointer Network 和 Seq2Seq 的區別如下圖所示,圖中展示了凸包問題。Seq2Seq 的 Decoder 會預測每一個位置的輸出 (但是輸出目標的數量是固定的),而 Pointer Network 的 Decoder 直接根據 Attention 得到輸入序列中每一個位置的概率,取概率最大的輸入位置作為當前輸出。

指針網絡 Pointer Network

Seq2Seq 和 Pointer Network

2.實驗結果

指針網絡 Pointer Network

圖中底部的都是 Pointer Network 的實驗結果,m 是訓練數據中點的個數,n 是測試數據中點的個數。圖 (a) 中是使用 LSTM 的 Seq2Seq,Seq2Seq 訓練和測試必須使用相同點的個數。

3.參考文獻

Pointer Networks


分享到:


相關文章: