本文提出 Reformer—— 通過三種實用的方法使 Transformer 節約內存、加速訓練、處理長序列輸入。本文比較偏工程,算法複雜,建議瞭解即可。
概述
Transformer現在已經成了諸多NLP模型的標配,但是它的問題是,模型太大,不能處理較長的序列。
比如,就目前最大的Transformer結構來看,光是參數量就要佔2GB內存,然後儲存batch_size為8、embedding為1024的64K序列又需要佔2GB內存,內存開銷太大了!
如何減少內存開銷、加速模型訓練、處理更長序列,是一個十分現實的問題。
本文提出一種Reformer模型用於緩解上述問題:
- 使用可逆層(reversible layers),只存儲單層激活值的一份拷貝
- 把FF層裡的激活值進行切分
- 使用局部敏感哈希(LSH)注意力代替傳統多頭注意力
使用上述三種方法,本文在長達64K的文本任務和長達12K的圖像任務上進行試驗,結果表明Reformer跑得比Transformer不知道快到哪裡去了,既節約了內存,又拓展了Transformer處理長序列的能力。
局部敏感哈希注意力
我們知道,Transformer中的注意力計算需要讓矩陣Q和K的轉置相乘。我們假定它們的形狀都是[ batch_size, length, dimension ],那麼如果序列長度有64K,就有得到一個64K*64K的矩陣,顯然是不現實的。
那麼一個簡單的想法是,我們不一次性算所有的Q,而對每個q(i)單獨計算即可,然而在需要BP的時候再算一次即可。雖然這樣做不那麼高效,但至少它可以處理比較長的輸入序列。
共享QK參數
Q,K,V是通過三個不同的線性變換(矩陣)得到的,我們可以讓Q,K的變換矩陣相同,讓V單獨有一個。這樣做實際上並不會有損效果。
哈希注意力
上面我們提到,限制長序列的主要原因是當length很大時[ length, length ]的矩陣是不可行的。
但實際上,我們真正關心的是softmax(QK^T),又因為softmax是被那些比較大的數支配的,從而,對每個q(i),我們只需要去找K中離q(i)最近的的那些就好了,也就是選一個子集。
局部敏感哈希
那麼怎麼找到最近鄰呢?這可以使用局部敏感哈希(LSH)。把每個向量x映射為一個哈希值h(x)叫做局部敏感哈希,如果比較近的向量能以高概率映射到同一個哈希值,而比較遠的向量能以高概率被映射到不同的哈希值。
為了得到b個不同的哈希值,我們隨機一個矩陣R,大小為[ dimension, b/2 ],然後定義h(x)=argmax([xR; -xR])。這樣,對所有的x,我們就可以把它們分配到b個哈希桶裡面。
局部敏感哈希注意力
方便起見,下面用另一種方式重寫一下q(i)關注K的方程:
這裡z是歸一化項(不用管它)。現在來考慮LSH注意力。我們只需要讓q(i)去關注在同一個哈希桶裡面的k(j)即可:
下圖(右a-b)是和傳統注意力的比較。(a) 表明傳統的注意力是很稀疏的,也就是說大多數的字符其實都不用關注;(b) k和q根據它們的哈希桶(注意隨機的那個矩陣R是共享的)排序好,然後再使用。
然而另一個問題是,這樣得到的哈希桶的大小很可能不均勻。我們從小到大給Q的哈希桶排序,在每個桶內部,按照位置先後排序。
在排序後的注意力矩陣中,來自同一個哈希桶的(q,k)對會聚集在矩陣的對角(下圖右c)。最後,把它們分組,每組m個,在各組內相互關注即可。
多輪LSH哈希
為了進一步減小桶分佈不均的情況,可以用不同的哈希函數進行多輪哈希,具體參見原文。
共享QK中的自掩碼(Causal Masking for Shared-QK Attention)
在普通注意力中一個位置可以關注自己,但在共享QK中我們一般不考慮,從而我們禁止它去關注自己,除非沒有別的可以關注的了。
下表是幾種注意力方式的時空複雜度:
可逆Transformer
如上所述,attention的複雜度可以被減少為序列長度的線性級,但是,參數量佔的複雜度依舊很高,我們想要進一步減少。
可逆殘差網絡
對輸入x,殘差網絡的輸出是y=x+F(x),一個可逆層定義在一個輸入輸出對上:
然後,輸入輸出就可以是:
可逆Transformer
同上,我們用F表示注意力層,用G表示FF層:
FF層分組
可以進一步把Y(2)分組:
下表是所有變體的複雜度:
實驗
接下來,我們在imagenet64和enwik8-64K上實驗,其他設置詳見原文。
下圖是不同的方法在這兩個數據集上的表現,可以看到,無論是共享QK還是可逆Transformer,都不會影響效果。
下圖是不同哈希桶數的LSH注意力的表現。顯然,數量越多,效果越好,這是因為關注就越精確,同時模型代價就越高。
最後來看看Reformer的層數。下圖(左)是Big Reformer隨層變化的不同效果,20層依然無壓力。
而下圖(右)是普通注意力和LSH注意力在不同序列長度的速度比較,當序列很長的時候,LSH具有顯著的優勢。
小結
本文提出使用LSH(局部敏感哈希)和可逆網絡去加速Transformer訓練,節省內存以處理更長序列,在幾個實驗上取得了較好的效果。然而本文比較偏工程,而且算法也比較複雜,各位瞭解一下即可。
閱讀更多 sandag 的文章