論文分享:深度相互學習(Deep Mutual Learning)

寫在前面

這是⼀篇CVPR2017年的⽂章,筆者今天才看到,表示慚愧,看完以後覺得寫得很好,重點是簡單幹淨且實用,我認為這篇⽂章算是⼀篇佳作,因此在這⾥給⼤家分享。這篇論⽂主要講的通過多個模型的相互學習,從而提升各個模型各⾃的性能,效果優於模型蒸餾。與傳統的優化⽅法相比, Deep Mutual Learning(DML)不是幫助我們找到⼀個更好的或者更深層次的訓練損失最小值,而是幫助我們找到⼀個更⼴泛或者更可靠的最⼩值,它能更好地概括測試數據,更加健壯。某種程度上來說,很像⼀種深度的正則化⽅式。

論文分享:深度相互學習(Deep Mutual Learning)

DML 網絡結構

Deep Mutual Learning

⽂章的核心思想⾮常簡單,就是模型的相互學習,相互學習的⽅式就是增加⼀個監督不同模型輸出的loss,使得不同模型的預測分佈⼀致。在統計學⾥,有⼀個概念叫KL散度,在GAN⽹絡⾥⾯很常⽤,KL 散度是⼀種衡量兩個概率分佈的匹配程度的指標,兩個分佈差異越⼤,KL散度越⼤。所以這⾥的監督loss就是KL散度。KL散度寫成公式,長下面這樣,p1和p2分別表示兩個模型的概率預測:

論文分享:深度相互學習(Deep Mutual Learning)

KL散度

再加上各個模型的交叉熵損失,K個模型相互學習的完整loss可以寫為:

論文分享:深度相互學習(Deep Mutual Learning)

K個模型的訓練loss

說⽩了,就是各個模型正常訓練,最後最各個模型的輸出加上KL散度的監督 loss,優化策略也很簡單,在每次訓練迭代中,都計算兩個模型的預測,並根據另⼀個模型的預測更新兩個⽹絡的參數,這個是通過KL-loss實現的。

論文分享:深度相互學習(Deep Mutual Learning)

優化策略

實驗結果

實驗結果有很多,這⾥就放⼀個吧,可以看出,經過Deep Mutual Learning後,參加相互學習的模型各⾃的性能都得到了提升,在CIFAR-100上達到了兩個點的提升,還是很可觀的。而且我們還能發現一個現象,兩個模型使用同樣的結構也能明顯漲點,說明確實是這個優化策略起了作用。

論文分享:深度相互學習(Deep Mutual Learning)

DML對⽐實驗

寫在後面

⽂章在最後還討論了為了深度相互學習策略可以奏效。作者⽐較了DML訓練後的模型和單個模型在添加⾼斯噪聲前後訓練的損失變化。從左圖可以看出兩個模型的訓練極⼩值是幾乎相同的,但是在加⼊⾼斯噪聲後,單模型的訓練損失變化較大,而DML模型的訓練損失變化較小。這表明DML模型找到了⼀個更⼴泛,健壯的最⼩值,進⽽能夠提供更好的泛化性能。

論文分享:深度相互學習(Deep Mutual Learning)

加⼊⾼斯噪聲前後loss變化

那麼問題來了,DML是怎樣找到這個更加魯棒的最⼩值的呢?DML會通過KL散度判斷相似性要求每個網絡與其輔助網絡的概率估計分佈一致,通俗的說,如果給定網絡預測為零,⽽其對等網絡預測為⾮零,則該網絡將受到嚴重懲罰。總體上,DML是指,當每個網絡獨⽴地將⼀個關注點放在⼀個小的次概率集合上時,DML中的所有網絡都傾向於聚合它們對次級概率的預測。也就是說所有的網絡把重心放在次概率上,並且把更多重心放在更明顯的次概率上。因此,DML是通過對“合理的” 次概率預測的相互概率匹配來尋找更寬泛的最小值。很合理啊。


分享到:


相關文章: