如何讓Transformer更高效處理長序列

本文提出 Reformer—— 通過三種實用的方法使 Transformer 節約內存、加速訓練、處理長序列輸入。本文比較偏工程,算法複雜,建議瞭解即可。

概述

Transformer現在已經成了諸多NLP模型的標配,但是它的問題是,模型太大,不能處理較長的序列。

比如,就目前最大的Transformer結構來看,光是參數量就要佔2GB內存,然後儲存batch_size為8、embedding為1024的64K序列又需要佔2GB內存,內存開銷太大了!

如何減少內存開銷、加速模型訓練、處理更長序列,是一個十分現實的問題。

本文提出一種Reformer模型用於緩解上述問題:

  • 使用可逆層(reversible layers),只存儲單層激活值的一份拷貝
  • 把FF層裡的激活值進行切分
  • 使用局部敏感哈希(LSH)注意力代替傳統多頭注意力

使用上述三種方法,本文在長達64K的文本任務和長達12K的圖像任務上進行試驗,結果表明Reformer跑得比Transformer不知道快到哪裡去了,既節約了內存,又拓展了Transformer處理長序列的能力。

局部敏感哈希注意力

我們知道,Transformer中的注意力計算需要讓矩陣Q和K的轉置相乘。我們假定它們的形狀都是[ batch_size, length, dimension ],那麼如果序列長度有64K,就有得到一個64K*64K的矩陣,顯然是不現實的。

那麼一個簡單的想法是,我們不一次性算所有的Q,而對每個q(i)單獨計算即可,然而在需要BP的時候再算一次即可。雖然這樣做不那麼高效,但至少它可以處理比較長的輸入序列。

共享QK參數

Q,K,V是通過三個不同的線性變換(矩陣)得到的,我們可以讓Q,K的變換矩陣相同,讓V單獨有一個。這樣做實際上並不會有損效果。

哈希注意力

上面我們提到,限制長序列的主要原因是當length很大時[ length, length ]的矩陣是不可行的。

但實際上,我們真正關心的是softmax(QK^T),又因為softmax是被那些比較大的數支配的,從而,對每個q(i),我們只需要去找K中離q(i)最近的的那些就好了,也就是選一個子集。

局部敏感哈希

那麼怎麼找到最近鄰呢?這可以使用局部敏感哈希(LSH)。把每個向量x映射為一個哈希值h(x)叫做局部敏感哈希,如果比較近的向量能以高概率映射到同一個哈希值,而比較遠的向量能以高概率被映射到不同的哈希值。

為了得到b個不同的哈希值,我們隨機一個矩陣R,大小為[ dimension, b/2 ],然後定義h(x)=argmax([xR; -xR])。這樣,對所有的x,我們就可以把它們分配到b個哈希桶裡面。

局部敏感哈希注意力

方便起見,下面用另一種方式重寫一下q(i)關注K的方程:

如何讓Transformer更高效處理長序列

這裡z是歸一化項(不用管它)。現在來考慮LSH注意力。我們只需要讓q(i)去關注在同一個哈希桶裡面的k(j)即可:

如何讓Transformer更高效處理長序列

下圖(右a-b)是和傳統注意力的比較。(a) 表明傳統的注意力是很稀疏的,也就是說大多數的字符其實都不用關注;(b) k和q根據它們的哈希桶(注意隨機的那個矩陣R是共享的)排序好,然後再使用。

然而另一個問題是,這樣得到的哈希桶的大小很可能不均勻。我們從小到大給Q的哈希桶排序,在每個桶內部,按照位置先後排序。

在排序後的注意力矩陣中,來自同一個哈希桶的(q,k)對會聚集在矩陣的對角(下圖右c)。最後,把它們分組,每組m個,在各組內相互關注即可。

如何讓Transformer更高效處理長序列

多輪LSH哈希

為了進一步減小桶分佈不均的情況,可以用不同的哈希函數進行多輪哈希,具體參見原文。

共享QK中的自掩碼(Causal Masking for Shared-QK Attention)

在普通注意力中一個位置可以關注自己,但在共享QK中我們一般不考慮,從而我們禁止它去關注自己,除非沒有別的可以關注的了。

下表是幾種注意力方式的時空複雜度:

如何讓Transformer更高效處理長序列

可逆Transformer

如上所述,attention的複雜度可以被減少為序列長度的線性級,但是,參數量佔的複雜度依舊很高,我們想要進一步減少。

可逆殘差網絡

對輸入x,殘差網絡的輸出是y=x+F(x),一個可逆層定義在一個輸入輸出對上:

如何讓Transformer更高效處理長序列

然後,輸入輸出就可以是:

如何讓Transformer更高效處理長序列

可逆Transformer

同上,我們用F表示注意力層,用G表示FF層:

如何讓Transformer更高效處理長序列

FF層分組

可以進一步把Y(2)分組:

下表是所有變體的複雜度:

如何讓Transformer更高效處理長序列

實驗

接下來,我們在imagenet64和enwik8-64K上實驗,其他設置詳見原文。

下圖是不同的方法在這兩個數據集上的表現,可以看到,無論是共享QK還是可逆Transformer,都不會影響效果。

如何讓Transformer更高效處理長序列

下圖是不同哈希桶數的LSH注意力的表現。顯然,數量越多,效果越好,這是因為關注就越精確,同時模型代價就越高。

如何讓Transformer更高效處理長序列

最後來看看Reformer的層數。下圖(左)是Big Reformer隨層變化的不同效果,20層依然無壓力。

而下圖(右)是普通注意力和LSH注意力在不同序列長度的速度比較,當序列很長的時候,LSH具有顯著的優勢。

如何讓Transformer更高效處理長序列

小結

本文提出使用LSH(局部敏感哈希)和可逆網絡去加速Transformer訓練,節省內存以處理更長序列,在幾個實驗上取得了較好的效果。然而本文比較偏工程,而且算法也比較複雜,各位瞭解一下即可。


分享到:


相關文章: