Tensorflow2入门讲解(一)

一.系列前言

Tensorflow2刚出了v2版本,在1.x版本的基础上,v2版本集成度更高,并且封装了keras作为高级API,使用Tensorflow的时候还可以用keras的Sequential模型快速搭建神经网络,对初学者十分友好,但缺点在于Tensorflow2推广的时间比较短,网上的开源代码基本都基于Tensorflow1,不过除了新版本,我们肯定学习新版本呀,Tensorflow2的官方文档翻译版本可以参看https://www.tensorflow.org/guide?hl=zh-CN,笔者最近也在学Tensorflow2,希望能把学到的东西和大家分享。

二.为什么选择tensorflow

Tensorflow2入门讲解(一)

深度学习的各大框架

目前热门的深度学习框架主要有Tensorflow、Pytorch、Keras、Mxnet、Theano。后两者都是比较老的框架,目前已经冷门了。Pytorch是facebook推出的一款深度学习框架,搭建神经网络比Tensorflow1.x简单不少,近几年Pytorch框架的用户也越来越多,特别是在学术界有不可撼动的地位,但工业界用Pytorch很少。Tensorflow是google公司的一款深度学习开源框架,之前在工业界和keras基本两分天下,用的人非常对,现在Keras又被google公司收购,成为了Tensorflow的高级API,可想而知Tensorflow的前景,并且因为google公司在AI行业内无人能望其项背,google公司肯定会让他的深度学习框架变得越来越好,所以我们选择Tensorflow作为我们的深度学习框架。

Tensorflow2入门讲解(一)

2018年arxiv统计的各框架使用人数

三.Tensorflow2的简便性

Tensorflow2推荐使用其高级API--tf.keras搭建神经网络,

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(256,activation=='relu'))
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dense(10,activation='softmax'))

就这几行代码,我们就已经搭建好了一个结构如下图所示的神经网络

Tensorflow2入门讲解(一)

上文代码搭建的神经网络模型

model.compile(optimizer='adam',loss='categorical_crossentropy',metics=['accuracy'])

对网络进行配置过后就可以输入参数进行训练了,如果熟练的话,不超过2分钟就可以编写出一个能识别手写数字的神经网络。也许有人会问Tensorflow2集成化这么高会不会不好改变网络的结构,这个不要担心,我们接下来会讲解一下如何利用tf.keras的函数式API搭建自己的网络模型,以及如何自定义自己的层结构来使网络更加多样化。下文我将用一个实例--深度学习的'hello world'搭建全连接层识别mnist手写数据集一起探索一下Tensorflow的魅力!


分享到:


相關文章: