文本生成中的beam search解碼器原理與實現

文本生成中的beam search解碼器原理與實現

自然語言處理任務,比如caption generation(圖片描述文本生成)、機器翻譯中,都需要進行詞或者字符序列的生成。常見於seq2seq模型或者RNNLM模型中。

這篇博文主要介紹文本生成解碼過程中用的greedy search 和beam search算法實現。其中,greedy search 比較簡單,著重介紹beam search算法的實現。

我們在文本生成解碼時,實際上是想找對最有的文本序列,或者說是概率,可能性最大的文本序列。而要在全局搜索這個最有解空間,往往是不可能的(因為詞典太大),建設生成序列長度為N,詞典大小為V, 則複雜度為 V^N次方。這實際上是一個NP難題。退而求其次,我們使用啟發式算法,來找到可能的最優解,或者說足夠好的解。

假設序列數據(假設每個位置詞的概率都已經給出):

data = [[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1]]

data = array(data)

1、greedy search decoder

非常簡單,我們用argmax就可以實現

# greedy decoder

def greedy_decoder(data):

# 每一行最大概率詞的索引

return [argmax(s) for s in data]

完整代碼

from numpy import array

from numpy import argmax

# greedy decoder

def greedy_decoder(data):

# 每一行最大概率詞的索引

return [argmax(s) for s in data]

# 定義一個句子,長度為10,詞典大小為5

data = [[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1]]

data = array(data)

# 使用greedy search解碼

result = greedy_decoder(data)

print(result)

2. beam search

與greedy search不同,beam search返回多個最有可能的解碼結果(具體多少個,由參數k執行)。

greedy search每一步都都採用最大概率的詞,而beam search每一步都保留k個最有可能的結果,在每一步,基於之前的k個可能最優結果,繼續搜索下一步。(參考下面示意圖理解)

示例圖(設置返回解碼結果為2個):

文本生成中的beam search解碼器原理與實現

from math import log

from numpy import array

from numpy import argmax

# beam search

def beam_search_decoder(data, k):

sequences = [[list(), 1.0]]

for row in data:

all_candidates = list()

for i in range(len(sequences)):

seq, score = sequences[i]

for j in range(len(row)):

candidate = [seq + [j], score * -log(row[j])]

all_candidates.append(candidate)

# 所有候選根據分值排序

ordered = sorted(all_candidates, key=lambda tup:tup[1])

# 選擇前k個

sequences = ordered[:k]

return sequences

# 定義一個句子,長度為10,詞典大小為5

data = [[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1],

[0.1, 0.2, 0.3, 0.4, 0.5],

[0.5, 0.4, 0.3, 0.2, 0.1]]

data = array(data)

# 解碼

result = beam_search_decoder(data, 3)

# print result

for seq in result:

print(seq)

相關資料:

  • Argmax on Wikipedia
  • Numpy argmax API
  • Beam search on Wikipedia
  • Beam Search Strategies for Neural Machine Translation, 2017.
  • Artificial Intelligence: A Modern Approach (3rd Edition), 2009.
  • Neural Network Methods in Natural Language Processing, 2017.
  • Handbook of Natural Language Processing and Machine Translation, 2011.
  • Pharaoh: a beam search decoder for phrase-based statistical machine translation models, 2004.


分享到:


相關文章: