Tensorflow2入門講解(一)

一.系列前言

Tensorflow2剛出了v2版本,在1.x版本的基礎上,v2版本集成度更高,並且封裝了keras作為高級API,使用Tensorflow的時候還可以用keras的Sequential模型快速搭建神經網絡,對初學者十分友好,但缺點在於Tensorflow2推廣的時間比較短,網上的開源代碼基本都基於Tensorflow1,不過除了新版本,我們肯定學習新版本呀,Tensorflow2的官方文檔翻譯版本可以參看https://www.tensorflow.org/guide?hl=zh-CN,筆者最近也在學Tensorflow2,希望能把學到的東西和大家分享。

二.為什麼選擇tensorflow

Tensorflow2入門講解(一)

深度學習的各大框架

目前熱門的深度學習框架主要有Tensorflow、Pytorch、Keras、Mxnet、Theano。後兩者都是比較老的框架,目前已經冷門了。Pytorch是facebook推出的一款深度學習框架,搭建神經網絡比Tensorflow1.x簡單不少,近幾年Pytorch框架的用戶也越來越多,特別是在學術界有不可撼動的地位,但工業界用Pytorch很少。Tensorflow是google公司的一款深度學習開源框架,之前在工業界和keras基本兩分天下,用的人非常對,現在Keras又被google公司收購,成為了Tensorflow的高級API,可想而知Tensorflow的前景,並且因為google公司在AI行業內無人能望其項背,google公司肯定會讓他的深度學習框架變得越來越好,所以我們選擇Tensorflow作為我們的深度學習框架。

Tensorflow2入門講解(一)

2018年arxiv統計的各框架使用人數

三.Tensorflow2的簡便性

Tensorflow2推薦使用其高級API--tf.keras搭建神經網絡,

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(256,activation=='relu'))
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dense(10,activation='softmax'))

就這幾行代碼,我們就已經搭建好了一個結構如下圖所示的神經網絡

Tensorflow2入門講解(一)

上文代碼搭建的神經網絡模型

model.compile(optimizer='adam',loss='categorical_crossentropy',metics=['accuracy'])

對網絡進行配置過後就可以輸入參數進行訓練了,如果熟練的話,不超過2分鐘就可以編寫出一個能識別手寫數字的神經網絡。也許有人會問Tensorflow2集成化這麼高會不會不好改變網絡的結構,這個不要擔心,我們接下來會講解一下如何利用tf.keras的函數式API搭建自己的網絡模型,以及如何自定義自己的層結構來使網絡更加多樣化。下文我將用一個實例--深度學習的'hello world'搭建全連接層識別mnist手寫數據集一起探索一下Tensorflow的魅力!


分享到:


相關文章: