在瀏覽器中進行深度學習:TensorFlow.js (二)線性迴歸

筆者在上一篇文章中介紹了TensorFlow.js中的基本概念,以及機器學習的數學基礎,線性代數的基本知識。在這一遍文章裡,我們來看一看如何利用TensorFlow.js來構建數學模型,以及進行學習的基本過程。

學習的過程基本如下:

  1. 準備訓練數據
  2. 構建一個模型
  3. 利用訓練數據和模型,進行迭代的學習
  4. 模型訓練完畢,用這個模型對新的數據進行預測(這裡我們先略過對模型的驗證部分)

好了,我們以最簡單的線性迴歸為例子,看看這個過程。

準備數據

在瀏覽器中進行深度學習:TensorFlow.js (二)線性迴歸

如上圖所示,我在二維座標系中生成了7個點,讓它們在我假想的某條直線附近。我以這幾個點作為我的訓練數據。

訓練數據的初始化代碼如下,這裡tx是所有點數據的x座標,ty是所有點數據的座標。

const train_x = tf.tensor1d(tx);
const train_y = tf.tensor1d(ty);

模型選擇

所有的模型都是錯的,有的模型更好。

所謂的模型,也就是一個函數f,對應於某個輸入數據,計算出某些輸出數據。模型可以複雜,可以簡單。簡單的模型不一定不好,複雜的模型也不一定好。

我們用線性模型舉例,數學上就是假定 Y = wX + b

在這個模型中,有兩個參數需要確定,w和b。

模型既然是個函數,那麼它的代碼也就很容易理解了:

const f = x => w.mul(x).add(b); 

當然你也可以這樣寫:

const f = function(x){
return w.mul(x).add(b);
}
}

迭代學習

學習的過程我們稱作訓練,訓練通常是一個迭代的過程,這個過程中,通常需要這幾樣東西:

  • 一個損失函數(loss function),損失函數定義了模型是不是足夠好,通常loss越小越好。
  • 一個優化器 (optimizer),優化器通過某種算法來決定如何改變參數的值,使得損失函數最小化。
  • 迭代循環, 通過循環 -> 調用優化器,得到新的參數,計算損失, 最終當損失足夠小時,可以認為訓練結束了。

訓練代碼如下:

初始化參數,這裡使用隨機數來作為參數的初始值。(注意,初始參數並不總是隨機選擇的。)

const w = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));

初始化學習參數,

  • numIterations是迭代的次數,一般次數越多,模型的擬合就越好,但是就需要花費越多的計算
  • learningRate是學習率,這個值越大,學的速度就越快,但是也會更加容易錯過極值點。
const numIterations = 200;
const learningRate = 1;

選擇一個優化器,這裡我選擇了adam。TensorFlow.js提供了多種優化器,例如sgd,momentum等等,大家可以根據自己的需要來選擇。

const optimizer = tf.train.adam(learningRate);

對於損失函數,我們採用的是均方差

在瀏覽器中進行深度學習:TensorFlow.js (二)線性迴歸

const loss = (pred, label) => pred.sub(label).square().mean();

或者可以寫作:

function loss(predictions, labels) {
const meanSquareError = predictions.sub(labels).square().mean();
return meanSquareError;
}

然後就是訓練的過程啦:

for (let iter = 0; iter < numIterations; iter++) {
optimizer.minimize(() => {
const loss_var = loss(f(train_x), train_y);
loss_var.print();
return loss_var;
})
}

在訓練過程中,我們調用tensor的print()方法打印出損失的值,看看訓練過程是不是收斂。當選擇的模型,參數,優化器不合適的時候,有可能訓練過程並不收斂。

訓練的結果我們就等到了w和b的值。也就是確定了直線的斜率和截距。

我們可以看到學習過程中是如何慢慢收斂到最後的結果的直線。

總結

本文描述了一個使用tensoflow.js來進行最簡單的線性迴歸模型的學習的過程。希望大家可以通過這個簡單的例子瞭解機器學習的基本思路。


分享到:


相關文章: