在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網絡

Generative A

dversarial Network 是深度學習中非常有趣的一種方法。GAN最早源自Ian Goodfellow的這篇論文。LeCun對GAN給出了極高的評價:

“There are many interesting recent development in deep learning…The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.” – Yann LeCun

那麼我們就看看GAN究竟是怎麼回事吧:

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網絡

如上圖所示,GAN包含兩個互相對抗的網絡:G(Generator)和D(Discriminator)。正如它的名字所暗示的那樣,它們的功能分別是:

  • Generator是一個生成器的網絡,它接收一個隨機的噪聲,通過這個噪聲生成圖片,記做G(z)。
  • Discriminator是一個鑑別器網絡,判別一張圖片或者一個輸入是不是“真實的”。它的輸入x是數據或者圖片,輸出D(x)代表x為真實圖片的概率,如果為1,就代表100%是真實的圖片,而輸出為0,就代表不可能是真實的圖片。

在訓練過程中,生成網絡G的目標就是儘量生成真實的圖片去欺騙判別網絡D。而D的目標就是儘量把G生成的圖片和真實的圖片分別開來。這樣,G和D構成了一個動態的“博弈過程”。在最理想的狀態下,G可以生成足以“以假亂真”的圖片G(z)。對於D來說,它難以判定G生成的圖片究竟是不是真實的,因此D(G(z)) = 0.5。

最後,我們就可以使用生成器和隨機輸入來生成不同的數據或者圖片了。

上面的描述大家可能都能理解,但是把它變成數學語言,可能你就蒙B了。

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網絡

如上圖所示,x是輸入,z是隨機噪聲。D(x)是鑑別器的判定數據為真的概率,D(G(z))是判定生成數據為真的概率。生成器希望這個D(G(z))越大越好,這個時候整個表達式的值應該變小。而鑑別器的目的是能夠有效區分真實數據和假數據,所以D(x)應該趨向於變大,D(G(z))趨向於變小,整個表達式就變大。也就是說訓練過程,生成器和辨別器互相對抗,一個使上述表達式變小,另一個使其變大,最後訓練趨向於平衡,而生成器這時候應該生成真假難辨的數據,這就是我們的最終目的。

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網絡

上圖是GAN算法訓練的具體過程,這裡我們不做過多的解釋,直接運行一個例子。

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網絡

我們用MINST數據集來看看如何使用TensorflowJS來訓練一個GAN,模擬生成手寫數字。

代碼見我的codepen

function gen(xs) {
const l1 = tf.leakyRelu(xs.matMul(G1w).add(G1b));
const l2 = tf.leakyRelu(l1.matMul(G2w).add(G2b));
const l3 = tf.tanh(l2.matMul(G3w).add(G3b));
return l3;

}

function disReal(xs) {
const l1 = tf.leakyRelu(xs.matMul(D1w).add(D1b));
const l2 = tf.leakyRelu(l1.matMul(D2w).add(D2b));
const logits = l2.matMul(D3w).add(D3b);
const output = tf.sigmoid(logits);
return [logits, output];
}

function disFake(xs) {
return disReal(gen(xs));
}

GAN的兩個網絡分別用gen和disReal創建。gen是生成器網絡,disReal是辨別器的網絡。disFake是把生成數據用辨別器來辨別。這裡的網絡使用leakyrelu。使得輸出在-inf到+inf,利用sigmoid映射到【0,1】,這是辨別器模型輸出一個0-1之間的概率。

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網絡

通常我們會創建一個比生成器更復雜的鑑別器網絡使得鑑別器有足夠的分辨能力。但在這個例子裡,兩個網絡的複雜程度類似。

計算損失的函數使用 tf.sigmoidCrossEntropyWithLogits,值得注意的是,在最新的0.13版本中,這個交叉熵被移除了,你需要自己實現該方法。

訓練過程如下:

async function trainBatch(realBatch, fakeBatch) {
const dcost = dOptimizer.minimize(
() => {
const [logitsReal, outputReal] = disReal(realBatch);
const [logitsFake, outputFake] = disFake(fakeBatch);

const lossReal = sigmoidCrossEntropyWithLogits(ONES_PRIME, logitsReal);
const lossFake = sigmoidCrossEntropyWithLogits(ZEROS, logitsFake);
return lossReal.add(lossFake).mean();
},
true,
[D1w, D1b, D2w, D2b, D3w, D3b]
);
await tf.nextFrame();
const gcost = gOptimizer.minimize(
() => {
const [logitsFake, outputFake] = disFake(fakeBatch);

const lossFake = sigmoidCrossEntropyWithLogits(ONES, logitsFake);
return lossFake.mean();
},
true,
[G1w, G1b, G2w, G2b, G3w, G3b]
);
await tf.nextFrame();

return [dcost, gcost];
}

訓練使用了兩個optimizer,

  1. 第一步,計算實際數據的辨別結果和1的交叉熵,以及生成器生成數據的辨別結果和0的交叉熵。也就是說,我們希望辨別器儘可能的判斷出生成數據都是假的而實際數據都是真的。使得這兩個交叉熵的均值最小。
  2. 第二步開始對抗,要讓生成數據儘可能被判別為真。

下圖是某個訓練過程的損失:

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網絡

這個是經過1000個迭代後的生成圖:

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網絡

大家可以嘗試調整學習率,增加網絡複雜度,加大迭代次數來獲得更好的生成模型。

GAN的學習其實還是比較複雜的,參數和損失選擇都不容易,好在有一些現成的工具可以使用,另外推薦大家去https://poloclub.github.io/ganlab/,提供了很直觀的GAN學習的過程。這個也是用TensorflowJS來實現的。


參考









分享到:


相關文章: