批歸一化到底做了什麼?DeepMind研究者進行了拆解

選自arXiv

作者:Soham De、Samuel L. Smith

機器之心編譯

參與:魔王

批歸一化有很多作用,其最重要的一項功能是大幅提升殘差網絡的最大可訓練深度。DeepMind 這項研究探尋了其中的原因,並進行了大量驗證。

批歸一化到底做了什麼?DeepMind研究者進行了拆解


論文鏈接:https://arxiv.org/abs/2002.10444

批歸一化用處很多。它可以改善損失分佈(loss landscape),同時還是效果驚人的正則化項。但是,它最重要的一項功能出現在殘差網絡中——大幅提升網絡的最大可訓練深度

DeepMind 近期一項研究找到了這項功能的原因:在初始化階段,批歸一化使用與網絡深度的平方根成比例的歸一化因子來縮小與跳躍連接相關的殘差分支的大小。這可以確保在訓練初期,深度歸一化殘差網絡計算的函數由具備表現良好的梯度的淺路徑(shallow path)主導。

該研究基於此想法開發了一種簡單的初始化機制,可以在不使用歸一化的情況下訓練非常深的殘差網絡。研究者還發現,儘管批歸一化可以維持模型以較大的學習率進行穩定訓練,但這隻在批大小較大的並行化訓練中才有用。這一結果有助於釐清批歸一化在不同架構中的不同功能。

批歸一化到底幹了什麼

跳躍連接和批歸一化結合起來可以大幅提升神經網絡的最大可訓練深度。

DeepMind 研究者將殘差網絡看作多個路徑的集成,這些路徑共享權重,但是深度各有不同(與 Veit 等人 2016 年的研究類似),進而發現批歸一化如何確保非常深的殘差網絡(數萬層)在訓練初期被僅包含幾十個層的淺路徑主導。原因在於,批歸一化使用與網絡深度的平方根成比例的因子縮小與跳躍連接相關的殘差分支的大小。這就為深度歸一化殘差網絡在訓練初期可得到高效優化提供了直觀解釋,它們只是把具備表現良好的梯度的淺層網絡集成起來罷了。

上述觀察表明,要想在不使用歸一化或不進行認真初始化的前提下訓練深度殘差網絡,只需要縮小殘差分支即可。

為了確認這一點,研究者改動了一行代碼,實現不使用歸一化的深度殘差網絡訓練(SkipInit)。結合額外的正則化後,SkipInit 網絡的性能可與經過批歸一化的對應網絡不相上下(該網絡使用常規的批大小設置)。

為什麼深度歸一化殘差網絡是可訓練的?

殘差分支經過歸一化後,假設 ˆf_i 的輸出方差為 1。每個殘差塊的方差增加 1,則第 i 個殘差塊前的激活的預期方差為 i。因此,對於任意遍歷第 i 個殘差分支的路徑,其方差縮小到 1/i,這說明隱藏層激活縮小到 1/√ i。

如圖 3 所示,該縮小因子很強大,可確保具備 10000 個殘差塊的網絡 97% 的方差來自遍歷 15 個或者更少殘差分支的淺路徑。典型殘差塊的深度與殘差塊總數 d 成比例,這表明批歸一化將殘差分支縮小到 1/√ d。

批歸一化到底做了什麼?DeepMind研究者進行了拆解

圖 3:此圖模擬了初始化階段不同深度的路徑對 logits 方差的貢獻。

為了驗證這一觀點,研究者評估兩個歸一化殘差網絡的不同通道的方差,以及批統計量(batch statistics),如下圖 4 所示。

圖 4(a) 中,深度線性 ResNet 的跳躍路徑方差幾乎等於當前深度 i,而每個殘差分支末端的方差約為 1。這是因為批歸一化移動方差約等於深度,從而證實歸一化將殘差分支縮小到原本的 1/√ i。

圖 4(b) 中,研究者在 CIFAR-10 數據集上評估使用 ReLU 激活函數的卷積 ResNet。跳躍路徑的方差仍與深度成正比,但係數略低於 1。這些關聯也導致批歸一化移動平均數的平方隨著深度的增加而增大。

批歸一化到底做了什麼?DeepMind研究者進行了拆解

圖4。

這就為「深度歸一化殘差網絡是可訓練的」提供了簡潔的解釋。這一觀點可以擴展至其他歸一化方法和模型架構。

SkipInit:歸一化的替代方案

研究者發現,歸一化之所以能夠確保深度殘差網絡的訓練,是因為它在初始化階段按與網絡深度平方根成正比的歸一化因子縮小殘差分支。

為了驗證該觀點,研究者提出了一個簡單的替代方法——SkipInit:在每個殘差分支末端放置一個標量乘數,並將每個乘數初始化為 α。

批歸一化到底做了什麼?DeepMind研究者進行了拆解

圖 1:A) 使用批歸一化的殘差塊。B) SkipInit 用一個可學習標量 α 替代了批歸一化。

移除歸一化之後,只需改動一行代碼即可實現 SkipInit。研究者證明,按 (1/ √ d) 或更小的值初始化 α 就可以訓練深度殘差網絡(d 表示殘差塊數量)。

研究者引入了 Fixup,它也可以確保殘差塊在初始化時表示 identity。但是,Fixup 包含多個額外組件。在實踐中,研究者發現 Fixup 的組件 1 或組件 2 就足以在不使用歸一化的前提下訓練深度 ResNet-V2 了。

實證研究

下表 1 展示了 n-2 Wide-ResNet 在 CIFAR-10 數據集上訓練 200 epoch 後的平均性能,模型深度 n 在 16 到 1000 層之間。

批歸一化到底做了什麼?DeepMind研究者進行了拆解

表 1:批歸一化使得我們可以訓練深度殘差網絡。然而在殘差分支末端添加標量乘數 α 後,不使用歸一化也能實現同樣的效果。

下表 2 驗證了,當 α = 1 時使用 SkipInit 無法訓練深度殘差網絡,因此必須縮小殘差分支。研究者還確認了,對於未經歸一化的殘差網絡,只確保激活函數不在前向傳播上爆炸還不夠(只需在每次殘差分支和跳過路徑合併時將激活乘以 (1/ √ 2) 即可實現)。

批歸一化到底做了什麼?DeepMind研究者進行了拆解

表 2:如果 α = 1,我們無法訓練深度殘差網絡。

批歸一化的主要功能是改善損失分佈,增加最大穩定學習率。下圖 5 提供了 16-4 Wide-ResNet 在 CIFAR-10 數據集上訓練 200 epoch 後的平均性能,批大小的範圍很大。

批歸一化到底做了什麼?DeepMind研究者進行了拆解

圖 5:使用批歸一化要比不使用獲得的測試準確率更高,研究者還能夠以非常大的批大小執行高效訓練。

為了更好地理解批歸一化網絡能夠以更大批大小進行高效訓練的原因,研究者在下圖 6 中展示了最優學習率,它可以最大化測試準確率、最小化訓練損失。

批歸一化到底做了什麼?DeepMind研究者進行了拆解

圖 6:使用和不使用批歸一化情況下的最優學習率。

研究者在 ImageNet 數據集上對 SkipInit、Fixup 初始化和批歸一化進行了實驗對比,證明 SkipInit 可擴展至大型高難度數據分佈。

下表 3 展示了最優驗證準確率。研究者發現卷積層包含偏置可使 SkipInit 的驗證準確率出現小幅提升,因此研究者在所有 SkipInit 運行中添加了偏置。SkipInit 的驗證性能與批歸一化相當,與使用標準批大小 256 的 Fixup 相當。但是,當批大小非常大時,SkipInit 和 Fixup 的性能不如批歸一化。

批歸一化到底做了什麼?DeepMind研究者進行了拆解

表 3:研究者訓練了 90 個 epoch,並執行網格搜索,以找出最優學習率,從而最大化模型在 ImageNet 數據集上的 top-1 驗證準確率。


分享到:


相關文章: