「NLP」飛躍芝麻街:XLNet 詳解

「NLP」飛躍芝麻街:XLNet 詳解

來源 | 安迪的寫作間

作者 | 準備上天的

“BERT 被碾壓了!”

看到這個消息的瞬間,不由自主地在座位上直接喊了出來,大家紛紛轉過頭來詢問,之後分享出論文鏈接,突然整個組就沒了聲音,我知道大家都在和我一樣,認真讀著 XLNet 的論文,都想參與這或許會見證歷史的時刻,知道到底發生了什麼。

現在,距 XLNet 發佈已過去一週有餘,因為其中用到多個技巧,涉及很多細節,所以反覆閱讀加上和朋友討論後,才大概將各部分給搞明白,還自己思考了一些東西。

為了應題,可以勉強將 XLNet 想象成是一架組裝飛機,為了飛躍芝麻街,用各種部件組裝起來,加上足夠多的燃料(更多數據),飛了過去。

當然這架飛機的造價也不菲,據 Reddit 熱心網友計算,大概也就 24 萬美金 (512x2.5x24x8=245760),相比起來 BERT 的訓練感覺只是灑灑水。

對於 XLNet,如果跳過其各種實現細節,我認為它顯示出的最重要兩點如下:

  1. BERT 雖然用了深層雙向信息,但沒有對被遮掩(Mask)的 token 之間的關係進行直接學習,因此 XLNet 通過提出 Permutation Language Model (PLM)對其進行了學習。
  2. 更多的數據,還有用 Transformer-XL 中的技巧帶來的更大範圍上下文,對模型有正向加強。

關於 XLNet 各個部件的關係,可以分為如下,為了更好實現 PLM,需要 Two-Stream Self-AttentionPartial Prediction,為了更大的上下文信息,需要 Transformer-XL 中的兩個技巧 Segment Recurrence Mechanism 還有 Relative Positional Encoding,最後為滿足 XLNet 像 BERT 一樣處理多段句子,加入了 Relative Segment Encoding

接下來就讓我來一一介紹這幾個部件吧,最後再給出自己的一些看法。

動力系統核心:名為 PLM 的發動機

「NLP」飛躍芝麻街:XLNet 詳解

首先是整篇論文的核心思想,Permutation Language Model。

進入正題前,先來談談文中 AR (AutoRegression, 自迴歸)AE (AutoEncoder, 自編碼器)的提法。這也是很多人覺得很有意思的一個提法。

作者們認為,當前預訓練最主要的兩個目標可分成兩類,一類便是類似 GPT 的 AR 方式,根據前面所有信息預測後一個 token,不斷重複(自迴歸),本質上是在進行某種 Density Estimation (密度估計);而另一類,則是類 BERT 的 AE 方式,做法是類似 DAE (Denoising AutoEncoder, 去噪自編碼器)中把輸入破壞掉一部分,然後還原,BERT 具體做法就是隨機將一些 token 替換成 “[MASK]” 特殊符

這種提法是很讓人耳目一新,但在我看來可能並沒有那麼重要,兩者界限並不是很明顯,特別是後者 DAE 在我看來也能看作一種 Density Estimation,而且微軟 UniLM 中的單向語言模型任務,其實已經將兩者都包含進去了,它在替換掉最後一個 token 為 "[MASK]" 的同時進行自迴歸,但也沒有顯示兩個目標結合會更好。

因此我更願意將 BERT 中的 AE 方式更為特指化,叫做內部損壞的自編碼器(Inner Corrupted AutoEncoder,ICAE),因為這裡我們想要的其實是通過破壞掉中間部分然後復原,捕捉深層次的雙向信息

所以,我想作者更想指出的是,AR 方式所帶來的自迴歸性學習了預測 token 之間的依賴,這是 BERT 所沒有的;而 BERT 的 ICAE 帶來的對深層次雙向信息的學習,卻又是像 GPT 還有 UniLM 單向語言模型所沒有的,不管是有沒有替換 “[MASK]” .

於是,自然就會想,如何將兩者的優點統一起來?於是乎,就到了主角登場的時間。

Permutation Language Model.

作者們發現,只要在 AR 以及 AE 方式之間再加入一個步驟,就能夠完美地將兩者統一起來,那就是 Permutation.

「NLP」飛躍芝麻街:XLNet 詳解

具體實現方式是,通過隨機取一句話排列的一種,然後將末尾一定量的詞給“遮掩”(和 BERT 裡的直接替換 “[MASK]” 有些不同)掉,最後用 AR 的方式來按照這種排列方式依此預測被“遮掩”掉的詞

「NLP」飛躍芝麻街:XLNet 詳解

相信聰明的同學已經發現通過隨機取排列(Permutation)中的一種,就能非常巧妙地通過 AR 的單向方式來習得雙向信息

了。

論文中 Permutation 具體的實現方式是通過直接對 Transformer 的 Attention Mask 進行操作(對 Transformer 和 Attention Mask 不瞭解的可以查看[1],[2])。

「NLP」飛躍芝麻街:XLNet 詳解

比如說序號依次為 1234 的句子,先隨機取一種排列,3241。於是根據這個排列我們就做出類似上圖的 Attention Mask,先看第1行,因為在新的排列方式中 1 在最後一個,根據從左到右 AR 方式,1 就能看到 234 全部,於是第一行的 234 位置是紅色的(沒有遮蓋掉,會用到),以此類推,第2行,因為 2 在新排列是第二個,只能看到 3 於是 3 位置是紅色,第 3 行,因為 3 在第一個,看不到其他位置,所以全部遮蓋掉...

這就是這篇論文的核心思想,看到這裡其實已經能去和小夥伴們吹了,接下來會介紹,XLNet 對 PLM 理念的實現,文末會列出一種我認為可能的另一種實現方式。



輔助動力系統1:Two-Stream Self-Attention

「NLP」飛躍芝麻街:XLNet 詳解

為了實現 Permutation 加上 AR 預測過程,首先我們會發現,打亂順序後位置信息非常重要,同時對每個位置來說,需要預測的是內容信息(對應位置的詞),於是輸入就不能包含內容信息,不然模型學不到東西,只需要直接從輸入 copy 到輸出就好了。

於是這裡就造成了位置信息與內容信息的割裂,因此在 BERT 這樣的位置信息+內容信息輸入 Self-Attention (自注意力) 的流(Stream)之外,作者們還增加了另一個只有位置信息作為 Self-Attention 中 query 輸入的流。文中將前者稱為 Content Stream,而後者稱為 Query Stream

這樣子就能利用 Query Stream 在對需要預測位置進行預測的同時,又不會洩露當前位置的內容信息。具體操作就是用兩組隱狀態(hidden states),g 和 h,其中 g 只有位置信息,作為 Self-Attention 裡的 Q,h 包含內容信息,則作為 K 和 V.

「NLP」飛躍芝麻街:XLNet 詳解

假如說,模型只有一層的話,其實這樣只有 Query Stream 就已經夠了。但如果將層數加上去的話,為了取得更高層的 h,於是就需要 Content Stream 了。h 同時作為 Q K V。

「NLP」飛躍芝麻街:XLNet 詳解

於是組合起來就是這樣。

「NLP」飛躍芝麻街:XLNet 詳解

這篇論文一個缺點就是內容真的很多,很多地方講解不是很詳細,得看幾遍,然後讀引用文獻,才容易搞明白。比如說這幅圖中的兩點:

第一點,最下面一層藍色的 Content Stream 的輸入是 e(x) ,這個很好懂就是 x 對應的詞向量 (Embedding),不同詞對應不同向量,但看旁邊綠色的 Query Stream,就會覺得很奇怪,為什麼都是一樣的 w?這個和後面的Relative Positional Encoding 有關,之後細說。

第二點,當然這是實現細節了,(b)圖中為了便於說明,只將當前位置之外的 h 作為 K 和 V,但實際上實現中應該是所有時序上的 h 都作為 K 和 V,最後再交給 (c)圖中的 Query stream 的 Attention Mask 來完成位置的遮蓋。



輔助動力系統2:Partial Prediction

「NLP」飛躍芝麻街:XLNet 詳解

接著是動力系統最後一部分,XLNet 對 PLM 實現的一個細節,Partial Prediction (部分預測),非常好理解。

因為當我們按上面提到的實現,在 Permutation 後對每個位置進行預測的話,會導致優化過難,訓練難以收斂,於是作者們就做了和 BERT 中類似的操作。訓練時,只對每句話部分位置進行預測

這些預測位置如何選取呢,選當前排列的最後幾個位置。舉個例子,假如有 1234567,先隨機挑一個排列,5427163,那麼假設對最後兩個位置預測,於是就需要依此對6和3進行預測。通過挑結尾的位置,在 AR 中,就能在預測時用到儘可能多的可知信息。

這裡再談一個有意思的點,挑選最後幾個,那麼到底該挑選幾個呢,總得給個標準吧。於是作者這裡設了一個超參數 K,K 等於總長度除以需要預測的個數。拿上面的例子,總長為 7 而需要預測為 2,於是 K = 7/2.

而論文中實驗得出的最佳 K 值介於 6 和 7 (更好)之間,其實如果我們取 K 的倒數,然後轉為百分比,就會發現最佳的比值介於 14.3% 到 16.7% 之間,還記得 BERT 論文的同學肯定就會開始覺得眼熟了。因為 BERT 裡將 Token 遮掩成 “[MASK]” 的百分比就是 15%,正好介於它們之間,我想這並不只是偶然,肯定有更深層的聯繫。

還有一點需要格外指出,被預測之前的其實取不取 Permutation 都沒關係,因為本身位置信息也都在裡面,permutation 反而有些更難理解。

上面就是論文的主要部分,下面是一些更細節實現,比如如何從 Transformer-XL 借來各種部件。



機身:Segment Recurrence Mechanism

「NLP」飛躍芝麻街:XLNet 詳解

Transformer-XL 的重要組件之一,Segment Recurrence Mechanism(段循環機制)。

其實思想很簡單,因為一般訓練 Transformer 時,會按照一定長度,將文本處理成一段(segment)一段的。比如說 BERT 預處理時,就會先處理成一個個 512 長度的樣本,即使可能處理前的文本更長。這樣子的話,有些更長的上下文信息,模型就是學習不到的。

「NLP」飛躍芝麻街:XLNet 詳解

於是 Segment Recurrence Mechanism 想做的就是,能不能在前一段計算完後,將它計算出的隱狀態(hidden states)都保存下來,放入一個 Memory 中去,之後在當前分段計算時,

將之前存下來的隱狀態和當前段的隱狀態拼起來作為 Attention 機制的 K 和 V,從而獲得更長的上下文信息

「NLP」飛躍芝麻街:XLNet 詳解

於是乎文中 Fig1 圖中

「NLP」飛躍芝麻街:XLNet 詳解

最左邊這個一開始看很是神秘的 mem 的身份也就很明顯了,就是 Segment Recurrence Mechanism 中用到的 memory,存放著之前 segment 的隱狀態。



機翼1:Relative Positional Encoding

「NLP」飛躍芝麻街:XLNet 詳解

Transformer-XL 的另一重要組件,

Relative Positional Encoding(相對位置編碼),其實很大程度上是為了解決上一個機制中位置信息表示的問題。

這個問題是,假設在 segment1 中已經用了從 1 開始編碼的絕對位置向量,那麼在 segment2 中,我們該用什麼樣的位置編碼呢。

從 1 開始的絕對位置編碼嗎?這樣的話,在複用 segment1 時,整個過程中就會有兩個 1 位置,這樣是會出問題的,因為模型會搞不清想讓它學習的位置信息。

當然也有個做法就是從 segment1 長度 +1 開始給 segment2 加上位置編碼,但這樣會讓位置編碼表過長,而且不一定能充分學習,還有就是這樣不太符合人類寫作的常識,我們其實都是一段段寫,不會有人會認真數我現在寫到了第 1000 個字,然後第 1000 個字會和第 10 個字有什麼關係,更多會關心在某一段中一個字詞和其他字詞的相對關係。

因此就可以用上這裡提到的,相對位置編碼,不再關心句中詞的絕對信息,而是相對的,比如說兩個詞之間隔了多少個詞這樣的相對信息。Transformer-XL 中提出的相對位置編碼,雖然是為了解決上面的問題,但也非常有趣,將位置信息編碼分析得很透徹。

可以簡單介紹一下,如何從絕對位置信息編碼到相對位置信息編碼的。首先,簡單定義一下,E 是詞向量,也可以把它當作主要內容承載者,U 是絕對位置向量,可看作絕對位置信息承載者,W 主要是用於 attention 機制 QK 的轉換,於是絕對位置信息編碼的注意力由下式得出:

「NLP」飛躍芝麻街:XLNet 詳解

乍一看,WTF,這什麼鬼,其實只是簡單的矩陣運算,用上些定律,簡單點可當成類似下面的乘法運算:

「NLP」飛躍芝麻街:XLNet 詳解

具體點的話就是這樣:

「NLP」飛躍芝麻街:XLNet 詳解

這四個項也都各有各的意義,(a)表示純基於內容之間的尋址,(b)和(c)則分別是 i 位置的內容和位置信息分別相對於 j 位置的位置和內容信息進行的尋址,(d)則是純基於位置之間的尋址。於是要改進的話,就需要對後三個和位置信息相關的項進行改進。

Transformer-XL 給出的改進方案是這樣:

「NLP」飛躍芝麻街:XLNet 詳解

主要有三條改進:

  • 先把有位置信息 U_j 的地方都替換成相對位置信息 R_ij;
  • 之後將(c)和(d)裡的 U_i W_q 分別替換成,u 和 v 可學習向量;
  • 最後將 K 轉換中的矩陣 W_k,分成兩個 W_kE 和 W_kR,分別給內容向量和相對位置向量用。

這樣就獲得了文中的相對位置編碼方法。

那麼相對位置編碼是不是隻有這一種,並不是,這只是一種實現方式,比如現在我們就能想出另一種實現方式:

「NLP」飛躍芝麻街:XLNet 詳解

如果再借鑑一下這裡的第三條改進,就可以變成,

「NLP」飛躍芝麻街:XLNet 詳解

很自由的。



機翼2:Relative Segment Encodings

「NLP」飛躍芝麻街:XLNet 詳解

為了通過輸入形式 [A, SEP, B, SEP, CLS] 來處理句子對任務,於是需要加入標識 A 句和 B 句的段信息。BERT 裡面很簡單,直接準備兩個向量,一個加到 A 句上,一個加到 B 句上。

但當這個遇上 Segment Recurrence Mechanism 時,和位置向量一樣,也出問題了。萬一出現了明明不是一句,但是相同了怎麼辦,於是我們就需要最後一塊補丁,同樣準備兩個向量,s+s- 分別表示在一句話內和不在一句話內。

具體實現是在計算 attention 的時候加入一項:

「NLP」飛躍芝麻街:XLNet 詳解

當 i 和 j 位置在同一段裡就用 s+,反之用 s-,在 attention 計算權重的時候加入額外項。



燃料:更多的數據

「NLP」飛躍芝麻街:XLNet 詳解

XLNet 這架飛機硬件部分都造好了,於是就只差最後一件東西,也就是,Data(數據).

XLNet 到底用了多少數據呢?

  • BooksCorpus + English Wikipedia (13GB)
  • Giga5 (16GB)
  • ClueWeb 2012B (19GB)
  • Common Crawl (78GB)

加起來有 13+16+19+78=126GB 純文本數據,一個恐怖的數據量。

而 BERT 訓練時只用到了第一項,也就是 13GB 的數據。因此就數據而言,XLNet 就用了將近 BERT 十倍的數據。

所以也難怪會有人說 XLNet 不過是更多數據+更廣的上下文信息

最後關於訓練,值得一說的是,和 BERT 一樣也是同時構建正例(正確的連續句子)和負例(隨機下一句的例子),之後分別對每段進行 Permutation 處理,然後預測,對於正例,後一段會用前一段的信息,而對於負例就不用。

關於訓練 loss,XLNet 只用了 PLM 的 loss,卻沒有像 BERT 一樣用 Next Sentence Prediction (下句預測)loss,但是它在句子級別任務表現卻不差,對於這個現象感覺非常神奇,按理說應該是會有幫助的。



飛躍芝麻街!!!

噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠噠......

「NLP」飛躍芝麻街:XLNet 詳解

XLNet 號起飛了!它能否飛躍芝麻街呢!讓我們拭目以待!

首先是第一關,RACE 數據集,這是一個難度非常高的閱讀理解數據集,以其長度著稱,諸多模型都沒能在它上面佔到絲毫便宜,只見 XLNet 一個桶滾動作,輕鬆地越了過去,成績還比前輩們好上很多。這充分說明了 XLNet 對

長片段信息的捕捉能力

「NLP」飛躍芝麻街:XLNet 詳解

於是,我們來到了 SQuAD 閱讀理解數據集,這是由 SQuAD1.1 和更具難度的 SQuAD2.0 組成的雙重關卡。XLNet 在 RACE 上就已經取得很好的成績了,相信它在更簡單的 SQuAD 上也不成問題,果然, XLNet 一個平螺旋,刷刷,連續繞過了 SQuAD1.1 和 SQuAD2.0,在後者上甚至直接甩了 BERT 近7個點!!!

「NLP」飛躍芝麻街:XLNet 詳解

接下來,開始進入一系列文本分類障礙,它們有 IMDB, Yelp-2, Yelp-5, DBpedia, AG, Amazon-2, and Amazon-5。很多名字一看起來就像炮灰, XLNet 一系列動作下來,迅速越過,最後還輕鬆地做了一個殷麥曼(別問我這是什麼,我也不懂)動作。可以看出 XLNet 在句子表徵方面性能也很棒

「NLP」飛躍芝麻街:XLNet 詳解

之後進入一個特殊關卡,ClueWeb09-B 文檔排序數據集,主要用於測試模型生成的詞向量效果。只見 XLNet 上拉,下落,一個完美的鐘形機動,從上方繞了過去。這證明了光從模型獲得基於語境的詞向量,XLNet 也是要優於 BERT 的。

「NLP」飛躍芝麻街:XLNet 詳解

最後,終於到了最關鍵環節,GLUE 數據集,為了在自然語言理解上全面測試模型性能,這是一個由九個自然語言理解任務組成的數據集,號稱“死亡九連環”。XLNet 沒有絲毫猶豫,一頭衝了進去,旋轉,跳躍,翻滾,我閉著眼,帶著幾處輕微刮蹭成功衝了出來!正當大家心潮躍起準備歡呼,只見它機頭猛然抬起,突然失速,大家以為它要就此翻車,只見它卻又逐漸恢復水平,原來是眼鏡蛇機動!懸著的心,高高躍起,芝麻街傳來片片歡呼聲。

看看成績,發現除了一兩個任務,XLNet 都是 SOTA(最好成績),而相對於 BERT 則是全面碾壓了。

「NLP」飛躍芝麻街:XLNet 詳解



模型比較與其他

「NLP」飛躍芝麻街:XLNet 詳解

與 BERT 的比較

論文最後有消融實驗部分,將 XLNet 和 BERT 進行更公平的對比,可以發現 XLNet 中 Transformer-XL 更大的上下文信息貢獻了大概一半提高,而 PLM 貢獻剩下一半。更多詳細分析可以看張俊林博士的文章。

「NLP」飛躍芝麻街:XLNet 詳解

其實我更想看到的是 XLNet 這種 Two-Stream 實現方式,和下面提到的 BERT 替換 "[MASK]" 實現方式的對比。因為作者一直有說, BERT 的 pretrain 模式中 “[MASK]” 符號的加入,因為 finetune 時沒用到,所以會有不一致產生,從而影響性能,但並沒有用實驗來證明這一點。

PLM 的另一種實現思路

於是進入 PLM 的另一種實現思路。

當你理解了 PLM,將文中 Permutation,Two-stream 給剝離開,你就會發現,實際上整個 PLM,最主要就是為了建立預測 token 與 token 之間關係,使得 XLNet 對於相同目標能夠學到更多信息,這是單純 BERT 直接遮掩然後同時預測所沒有的,就如文中提到的 “New York” 的例子。

「NLP」飛躍芝麻街:XLNet 詳解

既然如此,那是不是我們只需要對 BERT 進行一些小小改動就能實現相同的目的,也就是,讓它被遮掩的 token 不再是單獨一起輸出結果,而是依此用類似 AR 的方式輸出。

比如說1到7位,隨機兩個位置,2和6,然後再隨機依此預測兩個,這個也同樣可以用 PLM 中對 attention mask 的操作。而各個位置內容信息洩露的問題,也可以直接用 “[MASK]” 符號避免掉。

其實,XLNet 用到的 Two-Stream 感覺只是為了可以避免用 “[MASK]” 符號,但說到預訓練與下游 finetune 任務的不一致,Two-Stream 裡也是有的,比如說預訓練的時候用的是 g 來預測,而下游任務時卻用的是 h。

因此如果不做實驗好好比較一下,是很難說 Two-Stream 就要比 BERT 的 “[MASK]” 符號替換更好。

與 MASS 的比較

其實說到 PLM,還有一篇比較類似的預訓練論文值得關注,那就是微軟的 MASS,它在今年的 WMT2019 取得非常好的成績。

「NLP」飛躍芝麻街:XLNet 詳解

它的思路是在預訓練階段,對於一句話,我們先遮蓋掉,比如說圖中的一句話 12345678,我們遮蓋掉 3456,然後將被遮蓋的 1278 輸入 encoder,獲得上下文雙向信息,之後以此為基礎,用 AR 的方式依此預測 3456.

這個方法與 XLNet 的 PLM 做法的不同大體兩點:

  1. MASS 用的是 BERT 的"[MASK]"遮蓋方法;
  2. MASS 是整段整段的遮掩,而 XLNet 是類似與 BERT 的隨機 1/K 的遮掩。

鑑於 MASS 在生成上比較好的效果,我覺得可以試試將 MASS 的方法拿來和 XLNet 一起訓練,或者將 MASS 的方法借鑑來用於 XLNet 的生成式。

還有個很好玩的地方是,MASS 的遮蓋比取 50% 的時候是最好的,這個又與 BERT 和 XLNet 裡的接近 15% 有些不同了,而造成這個不同的原因,我認為可能是因為連續遮蓋需要更長句子來習得生成所需要的依賴信息。

生成式可能性探討

最後關於 XLNet 生成式方法的探討,可能大家覺得 XLNet 因為是用 AR 做的所以生成式應該很簡單,水到渠成。

但仔細想想就會發現其實會很 tricky,如果這麼好做相信早做了,實驗結果放論文裡了。

將該方法直接用於生成式,首先需要 Permutation,然後之前 finetune 中沒用到的 g 就派上用場了,我們按照 permutation 的結果來獲得 g,預測結果,然後根據預測結果生成 h, 再生成 g,以此類推。

但這樣會帶來兩個問題

  1. 我們如何預先知道生成句子的長度;
  2. 預訓練時只用了 1/K 長度,這和之後生成時的長度不一致,有很大的 gap。

因此實際上看上去用 AR 做的 XLNet 好像很好做生成式,但其實並沒有那麼簡單。



尾聲

XLNet 帶領我們飛躍了芝麻街,可以預見之後各大榜上的模型估計也不大會繼續取芝麻街裡的名字了,我們開始離開被芝麻街統治的時代踏上新的旅途。

XLNet 結構還有很多能改進的地方,當然最重要的是它讓大家突破了 BERT 的思想,可能會在 XLNet 的基礎上再做出一波突破,非常期待之後的研究。


Reference

  1. Illustrated Transformer: https://jalammar.github.io/illustrated-transformer/
  2. The Annotated Transformer: http://nlp.seas.harvard.edu/2018/04/03/attention.html
  3. Dissecting Transformer-XL:https://mc.ai/dissecting-transformer-xl/
  4. XLNet:運行機制及和Bert的異同比較:https://zhuanlan.zhihu.com/p/70257427
  5. XLNet 論文:https://arxiv.org/abs/1906.08237
  6. Transformer-XL 論文:https://arxiv.org/abs/1901.02860
  7. UniLM 論文:https://arxiv.org/abs/1905.03197v1
  8. MASS 論文:http://arxiv.org/abs/1905.02450
  9. BERT 論文:http://arxiv.org/abs/1810.04805
  10. 特技飛行:https://www.glafly.com/Flyzixun/inforshow.aspx?fxzx_id=00031000320003400038

The End


分享到:


相關文章: