TensorFlow2学习四、Keras 保存和加载模型

TensorFlow2学习四、Keras 保存和加载模型

1. 权重保存和加载

# 保存为TensorFlow checkpoint格式
model.save_weights('./my_model')
# 保存为TensorFlow HDF5格式
model.save_weights('model.h5', save_format='h5')

# 加载
model.load_weights('my_model')

2. 保存和加载网络结构

保存一个模型的配置,序列化过程中不包含权重。保存的配置可以用来重新创建、初始化出相同的模型,即使没有模型原始的定义代码。 Keras支持JSON、YAML序列化格式。

保存

import tensorflow as tf
import numpy as np
from tensorflow import keras
import json

model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])

model.compile(optimizer='sgd', loss='mean_squared_error')

xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)

model.fit(xs, ys, epochs=100)
print(model.predict([10.0]))

json_string = model.to_json()
print(json_string)
'''
{"class_name": "Sequential", "config": {"name": "sequential_8", "layers": [{"class_name": "Dense", "config": {"name": "dense_8", "trainable": true, "batch_input_shape": [null, 1], "dtype": "float32", "units": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.2.4-tf", "backend": "tensorflow"}
'''
with open("model.json","w") as json_file:
json_file.write(json_string)

加载

# load json and create model
json_file = open('model.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)

如果使用yaml格式,将model.to_json()换成model.to_yaml(),model_from_json()换成model_from_yaml()

3. 保存整个模型

整个模型可以保存到一个文件里,包含:权重、模型配置、优化器配置等。
可以保存状态后从完全相同的状态恢复训练。

# Create a trivial model
model = keras.Sequential([
keras.layers.Dense(10, activation='softmax', input_shape=(32,)),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, targets, batch_size=32, epochs=5)

# Save entire model to a HDF5 file
model.save('my_model.h5')

# Recreate the exact same model, including weights and optimizer.
model = keras.models.load_model('my_model.h5')


分享到:


相關文章: