告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔

現階段最好的神經網絡結構中,規範化(normalization)是一項重要技術。儘管規範化有利於網絡訓練的原因仍不清楚,但它已經被廣泛的應用於提高模型的泛化能力、穩定訓練過程、以及加速模型收斂,同時規範化也使得使用更高的學習率訓練模型成為了可能。

近日,來自 MIT、谷歌大腦、斯坦福的三名研究人員提出了一項具有挑戰性的工作(這項工作在 Facebook 工作期間完成),他們認為規範化的好處並不是唯一的。論文《Fixup Initialization: Residual Learning Without Normalization》中提出了一種固定更新初始化(fixed-update initialization,Fixup),該論文已被 ICLR2019 接收。我們對該方法做了簡要介紹,本文是 AI 前線第 70 篇論文導讀。

背 景

在人工智能應用發展火熱的今天,創新的網絡模型及其訓練技巧成為了人工智能發展的核心技術。例如在 2016 年何凱明及其團隊提出 ResNet 的殘差學習結構後,大部分 SOTA 的識別系統都是基於這種在卷積網絡與加性殘差連接堆疊的結構上添加一些規範化機制而設計的。除了圖像分類,多種類的規範化技術在很多其他領域,如機器翻譯、生成模型等方面都扮演著重要的角色。

對於規範化技術能產生上述好處的原因,學術界目前還沒有達成一種共識。這篇論文對以下這些方面進行了研究:

  1. 在不使用規範化的情況下,深度殘差網絡能否可靠地訓練?
  2. 在不使用規範化的情況下,深度殘差網絡能否使用相同的學習率進行訓練、能否以相同的速度收斂以及能夠獲得一樣好的泛化能力,或者更好?

令人驚訝的是,研究人員發現對於上述的兩個問題,回答是肯定的。同時,這篇論文也對以下幾個問題進行了討論(後文均使用第一人稱編譯):

  • 為什麼規範化有助於訓練過程。在殘差網絡的初始化過程中,我們對一種低邊界的梯度範數進行了研究,這可以很好的解釋為什麼使用標準初始化、規範化技術對於深度殘差網絡在大學習率的情況下是必要的。
  • 在不使用規範化的情況下進行訓練。我們提出了 Fixup 技術——一種通過修改網絡的結構,對殘差分支的標準初始化進行放縮的方法。Fixup 在不使用規範化對極深的殘差網絡訓練時,可以使用最大的學習率。
  • 圖像分類。我們使用 Fixup 代替了批規範化過程,在圖像分類的基準數據庫 CIFAR-10(Wide-ResNet)與 ImageNet(ResNet)上進行了驗證。實驗結果表明,在使用了適當的正則化後,使用 Fixup 的方法訓練出的網絡可以與使用規範化的方法訓練的網絡達到相同的效果。
  • 機器翻譯。我們使用 Fixup 替代了規範化層,在機器翻譯的基準數據庫 IWSLT 和 WMT 使用 Transformer(一種發表於 2017 年 NeurIPS 的目前最好的機器翻譯算法 )模型進行了訓練。發現該方法可以超過基準線並在相同的框架下取得的目前最好的結果。


告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


圖 1:左:ResNet 標準殘差塊,紅色表示批歸一化層。中:一種簡單的網絡塊,在堆疊時可以保持穩定。右:Fixup 結構,通過添加偏置更好地提升網絡穩定性。

使用標準初始化的 ResNet 會導致梯度爆炸

標準初始化方法是為了初始化網絡參數,防止發生梯度消失或梯度爆炸而提出的。然而,在不使用如 BatchNorm 等一些規範化的方法時,標準初始化並不能很好的解決殘差連接中梯度爆炸的問題。針對該問題,Balduzzi 提出了一種 ReLU 網絡,我們使用正定同質激活函數將 ReLU 網絡的思想泛化到殘差網絡。不使用規範化層時,原始的 ResNet 中每個殘差塊的激活函數可以表達為:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


為了便於分析,我們分析使用了正定同質函數,首先介紹兩個定義:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


p.h. 函數的例子在神經網絡中是大量存在的。包括很多無偏的線性操作,例如全連接層、卷積層、池化層、串聯層和 dropout 等,當然也包括 ReLU 這種非線性函數。此外,我們還引入了以下兩個命題:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


  1. 猜想 1 一個使用交叉熵損失的分類網絡 f 可以看成是一個若干個網絡塊組成的集合,其中每個網絡塊都是一個 p.h. 函數
  2. 猜想 2 FC 層的權重是從零均值對稱分佈中獨立同分布採樣得到的。

在一個使用了 ReLU 函數的殘差網絡中,如果我們移除了所有的規範化層,根據這些猜想可得到所有的偏置都被初始化為 0。

我們的結論可以總結為如下兩條,證明過程在論文的附件中:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


在圖 2 中,我們提供了不使用規範化時 ResNet 中的三種 p.h. 集合的例子。基於理論 2,如果未歸一化概率值 z 在初始化的時候被放大了,這三個例子都有可能會遭遇梯度爆炸問題。而在 ResNet 中,如果不使用規範化,而使用傳統方法初始化參數,上述情況是很有可能出現的。這就是我們提出一種新的初始化方法的動機。

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


圖 2:ResNet 中不使用規劃範時 p.h. 集的例子。(1) 最大池化前的第一個卷積層;(2)softmax 層之前的全連接層;(3) 主幹網絡中的空間降採樣層和與其相對應的殘差分支中的卷積層。

Fixup:在 SGD 的每一步更新殘差網絡

在這一部分,我們提出了一種由上至下的新的初始化方法。通過對標準初始化簡單地放縮,可以保證在合適的範圍更新網絡函數。首先,我們假設學習率η並設定我們的目標:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


從另一個角度講,我們的目標是設計一種初始化方法,從而使 SGD 可以不依賴深度且以正確的比例對網絡函數進行更新。

我們將捷徑(Shortcut)定義為殘差網絡中從輸入到輸出的最短路徑。捷徑是一種典型的可訓練的淺層網絡。假設使用標準方法對捷徑進行初始化,我們來分析一下殘差分支的初始化情況。

殘差分支同步更新網絡。

最初,我們發現 SGD 在更新每個殘差分支的時,會在高度相關的方向上改變網絡的輸出。這意味著,如果一個殘差網絡含有L個殘差分支,那麼每個殘差分支在一次迭代中,應當對網絡的輸出平均改變Θ(η/L) 從而達到對整個網路輸出改變為Θ(η)。

對於標量分支的研究。

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


最終我們可以得到,當且僅當滿足以下條件時:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


偏置和乘法器的影響。

通過對殘差分支中權重的放縮,一個殘差網絡在 SGD 的每一步更新值為Θ(η),到這裡我們的目標已經達成了。然而,為了使訓練性能與使用規範化時相匹配,我們還需要考慮偏置和乘法器的情況。

在全連接層和卷積層使用偏置是一種很通用的方法。在規範化方法中,偏置和縮放參數被用來重建特徵經過規範化操作後的表徵能力。因為權重層與激活層的輸入 / 輸出均值可能不同,因此在不使用規範化的殘差網絡中插入偏置項同樣可以有助於訓練網絡。我們發現在每個權重層和非線性激活層前僅插入一個標量偏置可以顯著地提高訓練性能。

乘法器會對殘差分支的輸出進行放縮,這與批規範化中的放縮參數功能類似(批規範化中兩個重要參數為:scale 和 shift,即放縮與滑動)。我們發現,在每個殘差分支中插入一個標量乘法器可以模仿使用規範化的網絡中權重模值動態變化的過程。最終,我們得到了訓練不使用規範化方法的殘差網絡的完整方法:

Fixup 初始化

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


實 驗

我們在 CIFAR-10 數據庫上測試了第一輪迭代(即數據庫中所有圖像通過模型一次)結束後模型的測試準確率,發現對於多種深度的卷積神經網絡,在學習率相同的情況下 Fixup 可以達到與 BatchNorm 相同的效果。實驗結果如下圖所示:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


圖 1 CIFAR-10 數據庫上各種方法的訓練結果比較,值越大表示結果越好。

此外,我們還對比了使用 Fixup 訓練不同深度的 ResNet 和其他方法在 ImageNet 數據庫上的結果,實驗結果如下表所示:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


可以看出 Fixup 與組規範化方法的性能不相上下,該實驗中通過交叉驗證得到了三種方法的最優偏置標量,對於批規範化、組規範化和 Fixup 分別為 0.2,0.1 和 0.7。

此外,在機器翻譯的 SOTA 方法中我們同樣使用 Fixup 代替規範化層進行了實驗。我們驚奇地發現,當使用 Fixup 代替規範化層可以更好地防止模型過擬合,我們認為這要歸功於 dropout 操作的正則化。在兩個數據庫上,使用 Fixup 都取得了目前最好的結果,實驗結果如下表:

告別規範化!MIT 谷歌等提出全新殘差學習方法,效果驚豔


結 論

Fixup 通過對標準初始化進行適當的放縮來解決訓練過程中梯度爆炸和梯度消失的問題。在不使用規範化的情況下,使用 Fixup 訓練的殘差網絡可以達到與使用規範化訓練時相同的穩定性,甚至在網絡層數達到了 10000 層時也可以不相上下。此外,在使用了合適正則化方法的情況下,通過 Fixup 訓練的不使用規範化的殘差網絡在圖像分類和機器翻譯上達到了目前最好的水平。在理論和應用兩方面,這篇工作都給出了一種新的嘗試。在理論層面,去除規範化有利於更簡便地分析殘差網絡。在應用層面,Fixup 對於發展正則化方法提供了可能,比如結合 ZeroInit 等。

論文鏈接:https://arxiv.org/pdf/1901.09321.pdf


分享到:


相關文章: