1. 权重保存和加载
# 保存为TensorFlow checkpoint格式
model.save_weights('./my_model')
# 保存为TensorFlow HDF5格式
model.save_weights('model.h5', save_format='h5')
# 加载
model.load_weights('my_model')
2. 保存和加载网络结构
保存一个模型的配置,序列化过程中不包含权重。保存的配置可以用来重新创建、初始化出相同的模型,即使没有模型原始的定义代码。 Keras支持JSON、YAML序列化格式。
保存
import tensorflow as tf
import numpy as np
from tensorflow import keras
import json
model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
model.fit(xs, ys, epochs=100)
print(model.predict([10.0]))
json_string = model.to_json()
print(json_string)
'''
{"class_name": "Sequential", "config": {"name": "sequential_8", "layers": [{"class_name": "Dense", "config": {"name": "dense_8", "trainable": true, "batch_input_shape": [null, 1], "dtype": "float32", "units": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.2.4-tf", "backend": "tensorflow"}
'''
with open("model.json","w") as json_file:
json_file.write(json_string)
加载
# load json and create model
json_file = open('model.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)
如果使用yaml格式,将model.to_json()换成model.to_yaml(),model_from_json()换成model_from_yaml()
3. 保存整个模型
整个模型可以保存到一个文件里,包含:权重、模型配置、优化器配置等。
可以保存状态后从完全相同的状态恢复训练。
# Create a trivial model
model = keras.Sequential([
keras.layers.Dense(10, activation='softmax', input_shape=(32,)),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, targets, batch_size=32, epochs=5)
# Save entire model to a HDF5 file
model.save('my_model.h5')
# Recreate the exact same model, including weights and optimizer.
model = keras.models.load_model('my_model.h5')
閱讀更多 編程圈 的文章