元學習方法介紹

元學習方法介紹

人工智能的一個基本問題是它無法像人類一樣高效地學習。許多深度學習分類器顯示了超人的表現,但需要數百萬個訓練樣本。知識不共享,並且每個任務都獨立於其他任務進行訓練。在本文中,我們將該研究問題,然後檢查一些建議的解決方案。

問題

與人類相比,大多數最先進的深度學習方法都有兩個關鍵的弱點:

  1. 樣本效率:深度學習的樣本效率很差。例如,為了識別數字,我們通常每個數字需要6000個樣本。
  2. 可移植性差。我們不會從以前的經驗或學到的知識中學習。

元學習

那麼什麼是元學習?我們試圖將其定義為“學習如何學習”。但是實際上,我們還不知道確切的定義或解決方案。因此,它仍然是一個寬鬆的術語,指的是不同的方法。在本文中,我們將重點關注以下領域:

  • 循環模型
  • 元優化
  • 度量學習

少量學習(Few-Shot)

但讓我們先定義一些基本概念。在CIFAR-10,我們有10個不同類別的60000張圖片。換句話說,我們有10個分類任務,每個分類任務有6000個訓練樣本。在少樣本學習中,我們訓練的模型包含大量的任務,但每個任務只有一個或幾個樣本。我們的最終目標是將知識一般化,並將其應用到我們從未訓練過的新任務中。

例如,在任務1中,我們被要求學習3個表情符號。然後,用一個新的表情符號,我們通過訓練模型把它和之前的一個樣本聯繫起來。

元學習方法介紹

在我們的第二個任務中,我們用字母訓練它。

元學習方法介紹

我們用不同的任務重複這個過程很多次。一旦訓練完成,我們通過測試一個我們以前沒有執行過的任務來測量模型的通用性,識別漢字。該模型可以正確地將測試樣本與輸入關聯起來。

元學習方法介紹

我們可能想知道Few-Shot訓練和使用大數據集的傳統DL之間的區別。在DL中,我們使用正則化來確保我們沒有用一個小數據集過擬合我們的模型。但是通過使用這麼多的樣本和迭代來訓練模型,我們在我們的任務過擬合了。我們所學到的東西不能推廣到其他任務上。

讓我來演示一下DL中的一些問題。當我們測試數據集中不常見的樣本時,我們經常會陷入困境。例如,在玩具分類中,如下圖所示,黃色的玩具鴨分類很差。在Few-Shot訓練中,關鍵的目標是處理我們以前沒有訓練過的數據。

元學習方法介紹

在One-Shot訓練中,我們只會為每個類別提供一個訓練樣本。在下面的示例中,訓練包含多個數據集。每個數據集包含一個1-shot-5類的分類任務,即來自5個不同類的5個樣本。

元學習方法介紹

在這種One-Shot訓練中,我們經常訓練一個RNN來學習訓練數據和標籤。當我們用一個測試輸入表示時,我們應該正確地預測它的標籤。

元學習方法介紹

在元測試中,我們再次使用以前從未訓練過的類來提供數據集。在這個例子中,元學習的重點是學習對象分類的秘密。一旦我們學習了上百個任務,我們就不應該只關注單個的類。相反,我們應該發現對象分類的一般模式。因此,即使我們面對的是從未見過的類,我們也應該設法解決這個問題。

如果我們更聰明地收集任務,我們會學得更好。

Omniglot

在我們討論細節之前。讓我們介紹Omniglot。它是一個流行的Few-Shot學習數據集。以下是來自Omniglot的20幅代表不同的20個類的畫。

元學習方法介紹

循環模型

第一種元學習方法是循環模型。我們將數據輸入到一個類似於rnn的模型中,以記住我們目前看到的情況。當我們面對一個測試輸入時,我們從記憶中回憶它是什麼。然而,我們沒有足夠的內存來容納我們所看到的一切。循環模型存儲特徵,我們使用類似於word-embedding來關聯信息。

元學習方法介紹

讓我們先回顧一下內存網絡(MN)。MN使用一個控制器從輸入中提取特徵。然後我們使用這些特徵來訪問內存。

元學習方法介紹

例如,你接了一個電話,但你不能立即識別聲音。這個聲音聽起來很像你的堂兄(概率0.7),但也很像你哥哥(概率0.3)的聲音。在上圖中,每一行代表一個對象。我們計算每一行的權值w來衡量它與輸入的相關性。然後,我們計算所有行的加權和,以回憶與該輸入相關的內存。在我們的例子中,這個權重和指的是一個同學。

元學習方法介紹

記憶增強神經網絡是利用RNN作為外部記憶網絡的元學習方法之一。在監督學習中,我們在同一時間步t中提供輸入和標籤。但是,在這個模型中,標籤直到下一次時間戳t+1才被提供(如下圖所示)。這是一種阻止RNN單元將輸入直接映射到類標籤的技術。我們希望我們的模型記住經驗。

元學習方法介紹

訓練記憶增強神經網絡

元學習方法介紹

在記憶增強神經網絡中,我們使用外部存儲器來存儲樣本表示和類標籤信息。通常作為LSTM實現的控制器從輸入中生成一個鍵,該鍵要麼存儲在外部內存中,要麼用於檢索特定的內存。然後用反向傳播對整個系統進行訓練。具體建議讀者閱讀原始論文。

元學習方法介紹

如果我們能從經驗中學習,我們會學得更好。

學習優化器

在第二種元學習方法中,我們試圖更有效地優化模型。在每個任務的訓練之後,我們可以使用這些信息來更新模型。

元學習方法介紹

然而,我們正在學習一個特定的任務,而不是找到所有學習任務背後的基礎知識。因此,我們不是立即更新模型,而是等待一批任務完成。稍後,我們將從這些任務中學到的所有知識合併到一個更新中。這種方法實現了“學我們所學(learn what we learn)”的概念。

元學習方法介紹

模型無關元學習(MAML)利用上面的概念來更新模型。它是簡單的,它幾乎是相同的,我們的傳統DL梯度下降與增加一行代碼如下。在這裡,我們不會在每個任務之後立即更新模型參數。相反,我們一直等到一批任務完成

元學習方法介紹

元學習方法介紹

對於每個任務,我們使用反向傳播來計算建議的模型。

元學習方法介紹

然後合併訓練任務的損耗,並將損耗進行反向傳播,進行下一次模型更新:

元學習方法介紹

從概念上講,我們正在尋找一個最小化任務損失的模型。

元學習方法介紹

從圖形上看,每個任務可以將模型參數驅動到不同的方向。通過引入元學習步驟和少樣本數據集,我們學習了一個只處理任務而不處理樣本的模型。

還有其他一些優化器的目標是更有效地學習。例如,OpenAI提出了另一個名為Reptile的優化器。在隨機梯度下降法中,我們計算一個梯度下降並更新模型。然後我們為下一次迭代獲取下一批數據。在Reptile中,它對每個任務執行多步梯度下降,並使用最後一步的結果更新模型,使用與運行平均值類似的概念。

元學習方法介紹

在OpenAI的論文中,它從數學上論證了為什麼MAML和Reptile的行為相似。

我們建議你們閱讀原始的論文。

如果我們優化得更好,我們就學得更好。

度量學習

我們將討論的第三種元學習方法是度量學習。你還記得逐像素的圖片嗎?不。為了學習,我們需要用最少的內存獲取最多的信息。因此,第三種元學習方法關注的是我們如何提取特徵,但不要過度提取。在Siamese神經網絡中(如下圖所示),我們使用兩個具有相同模型參數值的相同網絡來提取兩個樣本的特徵。然後將提取出的特徵輸入鑑別器,判斷兩個樣本是否屬於同一類對象。例如,我們可以計算其特徵向量的餘弦相似度(p)。如果它們相似,p應該接近1。否則,它們應該接近0。根據樣本的標籤和p,我們對網絡進行相應的訓練。簡而言之,我們希望找到使樣例屬於同一類或將它們區分開來的特性。

元學習方法介紹

還有一種方法叫做Matching網絡,它與Siamese神經網絡非常相似。

元學習方法介紹

g和f是特徵提取器,使用深度來提取特徵,用於我們的輸入和測試樣本。通常,g和f是相同的,共享相同的深度網絡。然後我們比較它們的相似度,並使用一個softmax函數來計算它們是否相似。同樣,我們從預測中計算一個成本函數來訓練我們的特徵提取器。以下是數學公式:

元學習方法介紹

元學習方法介紹

元學習方法介紹

如果我們知道如何更好地表示數據,我們就學得更好。

其他方法

還有其他元學習方法關注於如何使用更好的超參數優化模型。例如,學習如何巧妙地調整超參數。或者,我們可以結合DL層來動態地形成一個新的模型。

元學習方法介紹

這些方法使模型更準確,但不一定更有效的學習較少的樣本。所以我們不會在元學習的討論中進一步討論。

想法

學習如何更好地學習不僅是對機器的挑戰,也是對人類的挑戰。元學習已經被研究了幾十年,但是我們還沒有完全理解它是如何實現的。為了結束我們的思想,這裡是目前提高學習效率相關的研究領域。

  • 收集更好的信息來學習。
  • 更好地從過去的經驗中學習。
  • 更好地知道如何表示信息。
  • 如何更好地優化模型。
  • 探索更好的方法。
  • 聯繫變得更好。
  • 泛化變得更好。


分享到:


相關文章: